In [None]:
import os
import torch
import torchaudio
import librosa
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from google.colab import drive

# Mount Google Drive for saving checkpoints and precomputed data
drive.mount('/content/drive')
CHECKPOINT_DIR = '/content/drive/MyDrive/checkpoints'
PRECOMPUTE_DIR = '/content/drive/MyDrive/precomputed_spectrograms'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(PRECOMPUTE_DIR, exist_ok=True)

# Preprocessing function
def preprocess_audio(file_path, sr=16000, top_db=30, duration=5, n_mels=128, n_fft=1024, hop_length=256):
    """
    Preprocess audio: load, trim, pad/truncate, normalize, convert to mel spectrogram.

    Args:
        file_path (str): Path to audio file
        sr (int): Sampling rate (16000 Hz)
        top_db (int): Trimming threshold (30 dB)
        duration (int): Target duration (5 seconds)
        n_mels (int): Number of mel bins
        n_fft (int): FFT window size
        hop_length (int): Hop length for STFT

    Returns:
        np.ndarray: Mel spectrogram in dB, normalized to [0,1]
    """
    y, _ = librosa.load(file_path, sr=sr)
    y, _ = librosa.effects.trim(y, top_db=top_db)
    target_length = sr * duration
    if len(y) < target_length:
        y = np.pad(y, (0, target_length - len(y)))
    else:
        y = y[:target_length]
    y = y / np.max(np.abs(y)) if np.max(np.abs(y)) != 0 else y  # Normalize to [-1,1]
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
    S_dB = librosa.power_to_db(S, ref=np.max)
    S_dB = (S_dB - np.min(S_dB)) / (np.max(S_dB) - np.min(S_dB))  # Normalize to [0,1]
    return S_dB

# Function to precompute spectrograms
def precompute_spectrograms(clean_dir, noisy_dir, save_dir):
    clean_files = sorted(os.listdir(clean_dir))
    noisy_files = sorted(os.listdir(noisy_dir))
    for i, (clean_file, noisy_file) in enumerate(zip(clean_files, noisy_files)):
        clean_mel = preprocess_audio(os.path.join(clean_dir, clean_file))
        noisy_mel = preprocess_audio(os.path.join(noisy_dir, noisy_file))
        np.save(os.path.join(save_dir, f'clean_{i}.npy'), clean_mel)
        np.save(os.path.join(save_dir, f'noisy_{i}.npy'), noisy_mel)
    print(f"Precomputed spectrograms saved to {save_dir}")


Mounted at /content/drive


In [2]:
# Dataset class
class VoiceBankDataset(Dataset):
    def __init__(self, clean_dir, noisy_dir, precompute_dir=None):
        """
        Dataset for VoiceBank clean and noisy audio pairs.
        If precompute_dir is provided, loads precomputed spectrograms.
        """
        self.clean_files = sorted(os.listdir(clean_dir))
        self.noisy_files = sorted(os.listdir(noisy_dir))
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        self.precompute_dir = precompute_dir
        if precompute_dir:
            self.clean_mels = [os.path.join(precompute_dir, f'clean_{i}.npy') for i in range(len(self.clean_files))]
            self.noisy_mels = [os.path.join(precompute_dir, f'noisy_{i}.npy') for i in range(len(self.noisy_files))]

    def __len__(self):
        return len(self.clean_files)

    def __getitem__(self, idx):
        if self.precompute_dir:
            clean_mel = np.load(self.clean_mels[idx])
            noisy_mel = np.load(self.noisy_mels[idx])
        else:
            clean_mel = preprocess_audio(os.path.join(self.clean_dir, self.clean_files[idx]))
            noisy_mel = preprocess_audio(os.path.join(self.noisy_dir, self.noisy_files[idx]))
        return torch.tensor(noisy_mel, dtype=torch.float32), torch.tensor(clean_mel, dtype=torch.float32)


In [None]:
# Generator model (BLSTM + CNN)
class Generator(nn.Module):
    def __init__(self, n_mels=128, time_frames=313):
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.blstm = nn.LSTM(32 * n_mels, 256, bidirectional=True, batch_first=True)
        self.dense = nn.Linear(512, n_mels)  # 512 = 256 * 2 (bidirectional)

    def forward(self, x):
        x = nn.functional.leaky_relu(self.conv1(x), 0.2)
        x = nn.functional.leaky_relu(self.conv2(x), 0.2)
        batch, channels, height, time = x.shape
        x = x.permute(0, 3, 2, 1).reshape(batch, time, channels * height)
        x, _ = self.blstm(x)
        x = self.dense(x)
        x = x.permute(0, 2, 1).unsqueeze(1)
        return torch.sigmoid(x)  # output is in [0,1]

# Discriminator model (CNN)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        self.flatten = nn.Flatten()
        self.dense = nn.Linear(512 * 8 * 19, 1)

    def forward(self, x):
        x = nn.functional.leaky_relu(self.conv1(x), 0.2)
        x = nn.functional.leaky_relu(self.conv2(x), 0.2)
        x = nn.functional.leaky_relu(self.conv3(x), 0.2)
        x = nn.functional.leaky_relu(self.conv4(x), 0.2)
        x = self.flatten(x)
        return torch.sigmoid(self.dense(x))

In [4]:
# Training function
def train(generator, discriminator, train_loader, num_epochs, device, lambda_l1=100):
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))  # Reduced LR for D
    criterion_bce = nn.BCELoss()
    criterion_l1 = nn.L1Loss()

    # Load checkpoint if exists
    start_epoch = 0
    checkpoint_path = os.path.join(CHECKPOINT_DIR, 'latest_checkpoint.pth')
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        generator.load_state_dict(checkpoint['generator_state_dict'])
        discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch}")

    for epoch in range(start_epoch, num_epochs):
        for batch_idx, (noisy_mel, clean_mel) in enumerate(train_loader):
            noisy_mel = noisy_mel.unsqueeze(1).to(device)  # Add channel dim: [batch, 1, 128, 313]
            clean_mel = clean_mel.unsqueeze(1).to(device)
            batch_size = noisy_mel.size(0)

            # Labels
            real_labels = torch.full((batch_size, 1), 0.9, device=device)  # Label smoothing
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # ---------------------
            # Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            output_real = discriminator(clean_mel)
            loss_D_real = criterion_bce(output_real, real_labels)

            enhanced_mel = generator(noisy_mel)
            output_fake = discriminator(enhanced_mel.detach())
            loss_D_fake = criterion_bce(output_fake, fake_labels)

            loss_D = loss_D_real + loss_D_fake
            loss_D.backward()
            optimizer_D.step()

            # -----------------
            # Train Generator
            # -----------------
            optimizer_G.zero_grad()
            output_fake = discriminator(enhanced_mel)
            loss_G_bce = criterion_bce(output_fake, real_labels)  # Fool discriminator
            loss_G_l1 = criterion_l1(enhanced_mel, clean_mel)     # Reconstruction loss
            loss_G = loss_G_bce + lambda_l1 * loss_G_l1
            loss_G.backward()
            optimizer_G.step()

            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}/{num_epochs-1}, Batch {batch_idx}/{len(train_loader)-1}, "
                      f"Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")

        # Save checkpoint after each epoch
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
        }, checkpoint_path)
        print(f"Checkpoint saved at epoch {epoch}")


In [None]:
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Dataset paths (replace with your VoiceBank paths)
clean_train_dir = '/content/dataset/clean_trainset_28spk_wav'
noisy_train_dir = '/content/dataset/noisy_trainset_28spk_wav'
clean_test_dir = '/content/dataset/clean_testset_wav'
noisy_test_dir = '/content/dataset/noisy_testset_wav'

# Precompute spectrograms to speed up training 
#precompute_spectrograms(clean_train_dir, noisy_train_dir, PRECOMPUTE_DIR)

# Load dataset
train_dataset = VoiceBankDataset(clean_train_dir, noisy_train_dir, precompute_dir=PRECOMPUTE_DIR)
test_dataset = VoiceBankDataset(clean_test_dir, noisy_test_dir, precompute_dir=PRECOMPUTE_DIR)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=2)

# Initialize models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Train
train(generator, discriminator, train_loader, num_epochs=50, device=device)

Using device: cuda
Resuming training from epoch 1
Epoch 1/49, Batch 0/723, Loss D: 0.5114, Loss G: 12.1968
Epoch 1/49, Batch 100/723, Loss D: 0.7939, Loss G: 10.6354
Epoch 1/49, Batch 200/723, Loss D: 0.8932, Loss G: 12.2324
Epoch 1/49, Batch 300/723, Loss D: 0.9438, Loss G: 12.5707
Epoch 1/49, Batch 400/723, Loss D: 0.7437, Loss G: 11.6717
Epoch 1/49, Batch 500/723, Loss D: 1.3386, Loss G: 11.2023
Epoch 1/49, Batch 600/723, Loss D: 0.9567, Loss G: 11.6519
Epoch 1/49, Batch 700/723, Loss D: 0.8060, Loss G: 11.2174
Checkpoint saved at epoch 1
Epoch 2/49, Batch 0/723, Loss D: 1.0915, Loss G: 9.6956
Epoch 2/49, Batch 100/723, Loss D: 0.8416, Loss G: 10.0197
Epoch 2/49, Batch 200/723, Loss D: 0.6451, Loss G: 10.2824
Epoch 2/49, Batch 300/723, Loss D: 0.9079, Loss G: 10.0525
Epoch 2/49, Batch 400/723, Loss D: 0.8713, Loss G: 10.4867
Epoch 2/49, Batch 500/723, Loss D: 0.8588, Loss G: 10.7408
Epoch 2/49, Batch 600/723, Loss D: 0.8769, Loss G: 10.2894
Epoch 2/49, Batch 700/723, Loss D: 0.7795,