In [4]:

!pip install torch torchaudio torchvision

import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import spectral_norm
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import Audio, display
import copy

# Colab-only setup: mount the user's Google Drive at /content/drive so files
# stored there can be read by the rest of this cell. `google.colab` is only
# importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# 1. DATASET WITH NORMALIZATION

class ImprovedAudioDataset(Dataset):
    """Log-mel spectrogram dataset over a ``root_dir/<category>/*.wav`` tree.

    Each item is a ``(log_spec, label_vec)`` pair where ``log_spec`` has shape
    ``(1, 128, max_frames)`` and ``label_vec`` is a float one-hot vector of
    length ``len(categories)``.

    Args:
        root_dir: Directory containing one sub-directory per category.
        categories: Category (sub-directory) names; their order defines the
            class indices.
        max_frames: Number of time frames every spectrogram is padded or
            truncated to.
        fraction: Fraction in ``[0, 1]`` of each category's files to keep,
            drawn with ``random.sample``.

    Raises:
        ValueError: If ``fraction`` is outside ``[0, 1]``.
    """

    def __init__(self, root_dir, categories, max_frames=512, fraction=1.0):
        if not 0.0 <= fraction <= 1.0:
            raise ValueError(f"fraction must be in [0, 1], got {fraction}")

        self.root_dir = root_dir
        self.categories = categories
        self.max_frames = max_frames
        self.file_list = []
        self.class_to_idx = {cat: i for i, cat in enumerate(categories)}
        # One MelSpectrogram per sample rate, built lazily; constructing the
        # transform (and its mel filterbank) on every item is wasted work.
        self._mel_transforms = {}

        for cat_name in self.categories:
            cat_dir = os.path.join(root_dir, cat_name)
            # os.listdir order is arbitrary; sorting makes a seeded
            # random.sample pick the same subset on every run.
            files_in_cat = sorted(
                os.path.join(cat_dir, f)
                for f in os.listdir(cat_dir)
                if f.endswith(".wav")
            )
            num_to_sample = int(len(files_in_cat) * fraction)
            sampled_files = random.sample(files_in_cat, num_to_sample)
            label_idx = self.class_to_idx[cat_name]
            self.file_list.extend((file_path, label_idx) for file_path in sampled_files)

    def __len__(self):
        return len(self.file_list)

    def _mel_transform(self, sample_rate):
        """Return the cached MelSpectrogram transform for ``sample_rate``."""
        transform = self._mel_transforms.get(sample_rate)
        if transform is None:
            transform = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate, n_fft=1024, hop_length=256, n_mels=128
            )
            self._mel_transforms[sample_rate] = transform
        return transform

    def __getitem__(self, idx):
        path, label = self.file_list[idx]
        wav, sr = torchaudio.load(path)
        # Down-mix multi-channel audio to mono.
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)

        mel_spec = self._mel_transform(sr)(wav)
        log_spec = torch.log1p(mel_spec)

        # Per-clip standardization (zero mean, unit variance). This does NOT
        # bound values to [-1, 1].
        # NOTE(review): the generator in this file ends in Tanh, so real and
        # generated value ranges differ - confirm this is intended.
        log_spec = (log_spec - log_spec.mean()) / (log_spec.std() + 1e-8)

        # Fix the time axis to max_frames: right-pad with zeros (the
        # post-standardization mean) or truncate.
        n_frames = log_spec.shape[-1]
        if n_frames < self.max_frames:
            log_spec = F.pad(log_spec, (0, self.max_frames - n_frames))
        else:
            log_spec = log_spec[:, :, :self.max_frames]

        label_vec = F.one_hot(torch.tensor(label), num_classes=len(self.categories)).float()
        return log_spec, label_vec

# 2. IMPROVED GENERATOR

class ImprovedGenerator(nn.Module):
    """Conditional generator mapping (noise, one-hot label) to a spectrogram.

    A fully connected layer produces a ``256 x H/16 x W/16`` feature map that
    four stride-2 transposed convolutions upsample by 16x to ``1 x H x W``.
    The output passes through ``tanh`` and therefore lies in ``[-1, 1]``.

    Args:
        latent_dim: Size of the noise vector ``z``.
        num_categories: Length of the one-hot label vector ``y``.
        spec_shape: ``(H, W)`` of the generated spectrogram. Both must be
            positive multiples of 16. The default ``(128, 512)`` gives the
            8x32 base feature map.

    Raises:
        ValueError: If ``spec_shape`` is not a pair of positive multiples
            of 16.
    """

    def __init__(self, latent_dim, num_categories, spec_shape=(128, 512)):
        super().__init__()
        height, width = spec_shape
        if min(height, width) < 16 or height % 16 or width % 16:
            raise ValueError(
                f"spec_shape must be positive multiples of 16, got {spec_shape}"
            )

        self.latent_dim = latent_dim
        self.num_categories = num_categories
        self.spec_shape = spec_shape

        # Four stride-2 upsampling stages => base map is 1/16 of the output.
        base_h, base_w = height // 16, width // 16
        fc_features = 256 * base_h * base_w

        self.fc = nn.Sequential(
            nn.Linear(latent_dim + num_categories, fc_features),
            nn.BatchNorm1d(fc_features),
            nn.ReLU(inplace=True)
        )
        self.unflatten_shape = (256, base_h, base_w)

        self.net = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # H/8  x W/8
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # H/4  x W/4
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 32, 4, 2, 1),    # H/2  x W/2
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(32, 1, 4, 2, 1),     # H    x W
            nn.Tanh()  # Output bounded to [-1, 1]
        )

    def forward(self, z, y):
        """Generate spectrograms.

        Args:
            z: Noise of shape ``(batch, latent_dim)``.
            y: Labels of shape ``(batch, num_categories)``.

        Returns:
            Tensor of shape ``(batch, 1, H, W)``.
        """
        h = torch.cat([z, y], dim=1)
        h = self.fc(h)
        h = h.view(-1, *self.unflatten_shape)
        return self.net(h)

# 3. IMPROVED DISCRIMINATOR WITH SPECTRAL NORMALIZATION


class ImprovedDiscriminator(nn.Module):
    """Conditional critic scoring a spectrogram given a one-hot label.

    The label is projected by a linear layer to an ``H x W`` map and stacked
    with the spectrogram as a second input channel. Four stride-2
    spectrally-normalized convolutions downsample by 16x; a final convolution
    whose kernel covers the whole remaining ``H/16 x W/16`` map yields one
    unbounded score per sample.

    Args:
        num_categories: Length of the one-hot label vector.
        spec_shape: ``(H, W)`` of the input spectrogram. Both must be positive
            multiples of 16. The default ``(128, 512)`` gives the (8, 32)
            final kernel.

    Raises:
        ValueError: If ``spec_shape`` is not a pair of positive multiples
            of 16.
    """

    def __init__(self, num_categories, spec_shape=(128, 512)):
        super().__init__()
        H, W = spec_shape
        if min(H, W) < 16 or H % 16 or W % 16:
            raise ValueError(
                f"spec_shape must be positive multiples of 16, got {spec_shape}"
            )

        self.num_categories = num_categories
        self.spec_shape = spec_shape

        self.label_embedding = nn.Linear(num_categories, H * W)

        # Spectral normalization for Lipschitz constraint
        self.conv1 = spectral_norm(nn.Conv2d(2, 64, 4, 2, 1))
        self.conv2 = spectral_norm(nn.Conv2d(64, 128, 4, 2, 1))
        self.conv3 = spectral_norm(nn.Conv2d(128, 256, 4, 2, 1))
        self.conv4 = spectral_norm(nn.Conv2d(256, 512, 4, 2, 1))
        # Kernel spans the full feature map left after 16x downsampling, so
        # the output is exactly 1x1 for any valid spec_shape.
        self.conv5 = spectral_norm(nn.Conv2d(512, 1, (H // 16, W // 16), 1, 0))

    def forward(self, spec, y):
        """Score spectrograms.

        Args:
            spec: Tensor of shape ``(batch, 1, H, W)``.
            y: Labels of shape ``(batch, num_categories)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with raw (unbounded) scores.
        """
        label_map = self.label_embedding(y).view(-1, 1, *self.spec_shape)
        h = torch.cat([spec, label_map], dim=1)

        h = F.leaky_relu(self.conv1(h), 0.2)
        h = F.leaky_relu(self.conv2(h), 0.2)
        h = F.leaky_relu(self.conv3(h), 0.2)
        h = F.leaky_relu(self.conv4(h), 0.2)
        logit = self.conv5(h)

        return logit.view(-1, 1)

# 4. WGAN-GP: GRADIENT PENALTY

def compute_gradient_penalty(discriminator, real_samples, fake_samples, labels, device):
    """Return the WGAN-GP gradient penalty for one batch.

    Random per-sample interpolates between real and fake inputs are scored by
    the discriminator; the penalty is the mean squared deviation of the
    input-gradient L2 norm from 1.

    Args:
        discriminator: Callable ``(samples, labels) -> scores``.
        real_samples: Real batch; the first dimension is the batch size.
        fake_samples: Generated batch with the same shape as ``real_samples``.
        labels: Conditioning passed through to ``discriminator`` unchanged.
        device: Device on which the interpolation coefficients are created.

    Returns:
        Scalar tensor. It stays differentiable with respect to the
        discriminator so it can be added to the critic loss.
    """
    batch_size = real_samples.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)

    # Detach both inputs: the penalty regularizes the critic only, so its
    # backward pass must not travel into the generator graph when the caller
    # passes non-detached fakes.
    real_samples = real_samples.detach()
    fake_samples = fake_samples.detach()
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)

    d_interpolates = discriminator(interpolates, labels)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,  # the penalty itself is back-propagated later
        retain_graph=True,
        only_inputs=True
    )[0]

    # reshape (not view): autograd may return a non-contiguous gradient.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# 5. EXPONENTIAL MOVING AVERAGE FOR GENERATOR

def update_ema(ema_model, model, decay=0.999):
    """Blend ``model`` into ``ema_model`` in place.

    Parameters are updated as ``ema = decay * ema + (1 - decay) * param``.
    Buffers (for example BatchNorm running statistics and batch counters) are
    copied verbatim. Without this the EMA copy would keep the buffers it had
    when it was cloned and, in ``eval()`` mode, normalize with stale stats.

    Args:
        ema_model: Module holding the running average; modified in place.
        model: Module being trained, with the same architecture.
        decay: Weight kept from the previous EMA value.
    """
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.mul_(decay).add_(param, alpha=1 - decay)
        for ema_buffer, buffer in zip(ema_model.buffers(), model.buffers()):
            ema_buffer.copy_(buffer)

# 6. CHECKPOINT MANAGEMENT

def save_checkpoint(gen, disc, opt_g, opt_d, epoch, path):
    """Write model and optimizer state to ``path`` with ``torch.save``.

    The stored dict has the keys ``epoch``, ``gen_state``, ``disc_state``,
    ``opt_g_state`` and ``opt_d_state``.

    When ``path`` is a filesystem path, the parent directory is created if
    needed and the file is written to ``<path>.tmp`` first, then moved into
    place. An interrupted save therefore never leaves a truncated checkpoint
    under the final name.

    Args:
        gen: Generator module.
        disc: Discriminator module.
        opt_g: Generator optimizer.
        opt_d: Discriminator optimizer.
        epoch: Epoch number recorded in the checkpoint.
        path: Destination path (or a file-like object accepted by
            ``torch.save``).
    """
    state = {
        'epoch': epoch,
        'gen_state': gen.state_dict(),
        'disc_state': disc.state_dict(),
        'opt_g_state': opt_g.state_dict(),
        'opt_d_state': opt_d.state_dict()
    }

    if isinstance(path, (str, os.PathLike)):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        tmp_path = f"{os.fspath(path)}.tmp"
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
    else:
        # File-like destination: nothing to create or rename.
        torch.save(state, path)

    print(f"✓ Checkpoint saved: {path}")

def load_checkpoint(gen, disc, opt_g, opt_d, path, device):
    """Restore model (and optionally optimizer) state from a checkpoint.

    Expects the dict layout with keys ``epoch``, ``gen_state``,
    ``disc_state``, ``opt_g_state`` and ``opt_d_state``.

    Args:
        gen: Generator module to load into.
        disc: Discriminator module to load into.
        opt_g: Generator optimizer, or ``None`` to skip restoring it
            (e.g. when loading for inference only).
        opt_d: Discriminator optimizer, or ``None`` to skip restoring it.
        path: Checkpoint file path.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The epoch number stored in the checkpoint.
    """
    checkpoint = torch.load(path, map_location=device)
    gen.load_state_dict(checkpoint['gen_state'])
    disc.load_state_dict(checkpoint['disc_state'])
    if opt_g is not None:
        opt_g.load_state_dict(checkpoint['opt_g_state'])
    if opt_d is not None:
        opt_d.load_state_dict(checkpoint['opt_d_state'])
    print(f"✓ Loaded checkpoint from epoch {checkpoint['epoch']}")
    return checkpoint['epoch']

# 7. IMPROVED AUDIO GENERATION

def generate_audio(generator, category_idx, num_samples, device, categories,
                   mean=0, std=1, sample_rate=22050):
    """Sample spectrograms for one category and invert them to waveforms.

    The generator output is de-normalized with ``mean``/``std``, mapped back
    from the ``log1p`` domain, converted from mel to linear frequency, and
    phase-reconstructed with Griffin-Lim.

    Args:
        generator: Module with a ``latent_dim`` attribute, callable as
            ``generator(z, y)``. It is switched to ``eval()`` mode and left
            that way.
        category_idx: Index into ``categories`` of the class to synthesize.
        num_samples: Number of clips to generate.
        device: Device used for sampling and inversion.
        categories: All category names; only its length is used here.
        mean: Offset added to the generator output after scaling.
        std: Scale applied to the generator output.
        sample_rate: Sample rate handed to the inverse mel filterbank.

    Returns:
        CPU tensor of waveforms, one row per generated sample.
    """
    generator.eval()
    y = F.one_hot(torch.tensor([category_idx]), num_classes=len(categories)).float().to(device)
    # Repeat the single label row so its batch dimension matches z; without
    # this the concatenation inside the generator fails for num_samples > 1.
    y = y.expand(num_samples, -1)
    z = torch.randn(num_samples, generator.latent_dim, device=device)

    with torch.no_grad():
        log_spec_gen = generator(z, y)

    # Denormalize
    log_spec_gen = log_spec_gen * std + mean
    # expm1 of a negative value is negative, which is not a valid mel power;
    # clamp at zero before inversion.
    spec_gen = torch.expm1(log_spec_gen).clamp(min=0.0).squeeze(1)

    # n_stft = n_fft // 2 + 1 = 513 for the n_fft=1024 used by Griffin-Lim below.
    inverse_mel = torchaudio.transforms.InverseMelScale(
        n_stft=513, n_mels=128, sample_rate=sample_rate
    ).to(device)
    linear_spec = inverse_mel(spec_gen)

    # More Griffin-Lim iterations for better quality
    griffin = torchaudio.transforms.GriffinLim(
        n_fft=1024, hop_length=256, win_length=1024, n_iter=64
    ).to(device)

    waveform = griffin(linear_spec)
    return waveform.cpu()

# 8. TRAINING FUNCTION WITH WGAN-GP


def train_improved_gan(generator, discriminator, dataloader, device, categories,
                       epochs, lr_g=1e-4, lr_d=4e-4, latent_dim=100,
                       lambda_gp=10, n_critic=5, ema_decay=0.999,
                       resume_path=None):
    """Train a conditional WGAN-GP, periodically saving samples and checkpoints.

    Per batch, the discriminator is updated ``n_critic`` times with the
    Wasserstein loss plus ``lambda_gp`` times the gradient penalty, then the
    generator is updated once and an EMA copy of it is refreshed. At epoch 1
    and every 10th epoch, one spectrogram plot and one audio clip per category
    are produced from the EMA generator (written to ``improved_plots/`` and
    ``improved_audio/``). Every 25th epoch a checkpoint is written to
    ``checkpoints/``.

    Args:
        generator: Generator module, called as ``generator(z, labels)``.
        discriminator: Critic module, called as ``discriminator(specs, labels)``.
        dataloader: Yields ``(spectrograms, one_hot_labels)`` batches.
        device: Device for training tensors.
        categories: Category names, in class-index order.
        epochs: Last epoch number to run (inclusive).
        lr_g: Generator Adam learning rate.
        lr_d: Discriminator Adam learning rate.
        latent_dim: Size of the noise vector fed to the generator.
        lambda_gp: Gradient-penalty weight.
        n_critic: Discriminator updates per generator update.
        ema_decay: Decay used for the EMA generator.
        resume_path: Optional checkpoint to resume from; ignored when the
            file does not exist.
    """
    # Separate learning rates (TTUR: Two Time-scale Update Rule)
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=(0.0, 0.9))
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=(0.0, 0.9))

    # Create directories
    os.makedirs("improved_audio", exist_ok=True)
    os.makedirs("improved_plots", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    start_epoch = 1
    if resume_path and os.path.exists(resume_path):
        start_epoch = load_checkpoint(generator, discriminator, opt_g, opt_d, resume_path, device) + 1

    # EMA generator for better samples. Cloned AFTER any checkpoint has been
    # loaded so that, on resume, the average starts from the restored weights
    # instead of the random initialization.
    ema_generator = copy.deepcopy(generator).to(device)

    # Compute dataset statistics for denormalization
    all_specs = []
    for batch_specs, _ in dataloader:
        all_specs.append(batch_specs)
        if len(all_specs) >= 10:  # Sample from first 10 batches
            break
    all_specs = torch.cat(all_specs, dim=0)
    data_mean = all_specs.mean().item()
    data_std = all_specs.std().item()

    for epoch in range(start_epoch, epochs + 1):
        generator.train()
        discriminator.train()

        loop = tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}")
        epoch_d_loss = 0
        epoch_g_loss = 0

        for batch_idx, (real_specs, labels) in enumerate(loop):
            real_specs = real_specs.to(device)
            labels = labels.to(device)
            batch_size = real_specs.size(0)

            # Train Discriminator (n_critic times)
            for _ in range(n_critic):
                opt_d.zero_grad()

                real_validity = discriminator(real_specs, labels)

                z = torch.randn(batch_size, latent_dim, device=device)
                # The critic step never needs generator gradients, so build
                # the fakes without an autograd graph. This also keeps the
                # gradient penalty from back-propagating into the generator.
                with torch.no_grad():
                    fake_specs = generator(z, labels)
                fake_validity = discriminator(fake_specs, labels)

                gp = compute_gradient_penalty(discriminator, real_specs, fake_specs, labels, device)
                d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gp

                d_loss.backward()
                opt_d.step()

            # Train Generator
            opt_g.zero_grad()

            z = torch.randn(batch_size, latent_dim, device=device)
            fake_specs = generator(z, labels)
            fake_validity = discriminator(fake_specs, labels)
            g_loss = -torch.mean(fake_validity)

            g_loss.backward()
            opt_g.step()

            # Update EMA generator
            update_ema(ema_generator, generator, ema_decay)

            # Only the last critic iteration's loss is accumulated.
            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()

            loop.set_postfix(D_loss=d_loss.item(), G_loss=g_loss.item(), GP=gp.item())

        avg_d_loss = epoch_d_loss / len(dataloader)
        avg_g_loss = epoch_g_loss / len(dataloader)
        print(f"\n📊 Epoch {epoch}: D_Loss={avg_d_loss:.4f}, G_Loss={avg_g_loss:.4f}")

        # Generate samples every 10 epochs
        if epoch % 10 == 0 or epoch == 1:
            print(f"\n🎨 Generating samples (Epoch {epoch})...")
            ema_generator.eval()

            fig, axes = plt.subplots(1, len(categories), figsize=(4 * len(categories), 4))
            # plt.subplots returns a bare Axes (not an array) for one column.
            if len(categories) == 1:
                axes = [axes]

            for cat_idx, cat_name in enumerate(categories):
                y_cond = F.one_hot(torch.tensor([cat_idx]), num_classes=len(categories)).float().to(device)
                z_sample = torch.randn(1, latent_dim).to(device)

                with torch.no_grad():
                    spec_gen = ema_generator(z_sample, y_cond)

                spec_np = spec_gen.squeeze().cpu().numpy()
                axes[cat_idx].imshow(spec_np, aspect='auto', origin='lower', cmap='viridis')
                axes[cat_idx].set_title(f'{cat_name}\n(Epoch {epoch})')
                axes[cat_idx].axis('off')

            plt.tight_layout()
            plt.savefig(f'improved_plots/epoch_{epoch:03d}.png')
            plt.show()
            plt.close(fig)

            # Generate audio
            for cat_idx, cat_name in enumerate(categories):
                wav = generate_audio(ema_generator, cat_idx, 1, device, categories,
                                    data_mean, data_std)
                fname = f"improved_audio/{cat_name}_ep{epoch:03d}.wav"
                torchaudio.save(fname, wav, sample_rate=22050)
                print(f"💾 Saved: {fname}")
                display(Audio(data=wav.numpy(), rate=22050))

        # Save checkpoint every 25 epochs
        if epoch % 25 == 0:
            save_checkpoint(generator, discriminator, opt_g, opt_d, epoch,
                          f'checkpoints/checkpoint_epoch_{epoch:03d}.pth')

# 9. MAIN EXECUTION

if __name__ == '__main__':
    # --- Hyperparameters -------------------------------------------------
    if torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"
    LATENT_DIM = 128
    EPOCHS = 200
    BATCH_SIZE = 32
    LR_G = 1e-4  # Generator learning rate
    LR_D = 4e-4  # Discriminator learning rate (4x generator for WGAN-GP)
    LAMBDA_GP = 10
    N_CRITIC = 5
    EMA_DECAY = 0.999

    # --- Data location: every sub-directory of train/ is one category ----
    BASE_PATH = 'drive/MyDrive/organized_dataset/'
    TRAIN_PATH = os.path.join(BASE_PATH, 'train')
    train_categories = []
    for entry in sorted(os.listdir(TRAIN_PATH)):
        if os.path.isdir(os.path.join(TRAIN_PATH, entry)):
            train_categories.append(entry)

    print(f" Device: {DEVICE}")
    print(f" Categories ({len(train_categories)}): {train_categories}")

    # --- Dataset and loader ----------------------------------------------
    train_dataset = ImprovedAudioDataset(TRAIN_PATH, train_categories)
    loader_options = dict(
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )
    train_loader = DataLoader(train_dataset, **loader_options)

    # --- Models (generator first, then discriminator) --------------------
    generator = ImprovedGenerator(LATENT_DIM, len(train_categories))
    generator = generator.to(DEVICE)
    discriminator = ImprovedDiscriminator(len(train_categories))
    discriminator = discriminator.to(DEVICE)

    print(f"\n Generator params: {sum(p.numel() for p in generator.parameters()):,}")
    print(f"Discriminator params: {sum(p.numel() for p in discriminator.parameters()):,}")

    # --- Training ----------------------------------------------------------
    training_options = dict(
        generator=generator,
        discriminator=discriminator,
        dataloader=train_loader,
        device=DEVICE,
        categories=train_categories,
        epochs=EPOCHS,
        lr_g=LR_G,
        lr_d=LR_D,
        latent_dim=LATENT_DIM,
        lambda_gp=LAMBDA_GP,
        n_critic=N_CRITIC,
        ema_decay=EMA_DECAY,
        resume_path=None,  # Set to checkpoint path to resume
    )
    train_improved_gan(**training_options)

    print("\nTraining complete!")


Output hidden; open in https://colab.research.google.com to view.