# 🎙️ Speech Denoising - Google Colab Training

This notebook trains the Speech Denoising U-Net model on Google Colab with GPU acceleration.

## Overview
- **Model**: U-Net with Complex Ratio Mask (CRM)
- **Dataset**: VoiceBank + DEMAND
- **Training Time**: ~1-2 hours on Colab GPU (T4/P100)
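
A complex ratio mask is applied to the noisy spectrogram by complex multiplication, bin by bin. The sketch below shows that arithmetic on plain Python tuples; it illustrates the general CRM definition and is not taken from the project's `UNetDenoiser`, whose internals (for example, any bounding of the mask values) are not shown in this notebook. The function name `apply_complex_ratio_mask` is hypothetical.

```python
def apply_complex_ratio_mask(mask, noisy):
    """Multiply each noisy STFT bin by its complex mask value.

    Both arguments are lists of (real, imag) pairs, one pair per bin.
    Hypothetical helper for illustration only.
    """
    enhanced = []
    for (m_re, m_im), (y_re, y_im) in zip(mask, noisy):
        # (m_re + j*m_im) * (y_re + j*y_im)
        s_re = m_re * y_re - m_im * y_im
        s_im = m_re * y_im + m_im * y_re
        enhanced.append((s_re, s_im))
    return enhanced


noisy_bins = [(1.0, 2.0), (0.5, -0.5)]
mask_bins = [(0.5, 0.0), (0.0, 1.0)]  # halve the first bin, rotate the second by 90 degrees
enhanced_bins = apply_complex_ratio_mask(mask_bins, noisy_bins)
print(enhanced_bins)
```

A purely real mask value only scales a bin's magnitude, while an imaginary component also changes its phase. That is why the model works with two channels (real and imaginary) rather than a single magnitude channel.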

---

## 1️⃣ Setup Environment

In [None]:
# Check if GPU is available
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU detected! Go to Runtime > Change runtime type > GPU")

In [None]:
# Clone the repository (uncomment if running from GitHub)
# !git clone https://github.com/YOUR_USERNAME/speech_denoising.git
# %cd speech_denoising

# Or upload files manually and set the working directory
import os
# Uncomment the next line if you uploaded the project to a specific folder
# os.chdir('/content/speech_denoising')

print(f"Working directory: {os.getcwd()}")

In [None]:
# Install dependencies
!pip install -q torch torchaudio --upgrade
!pip install -q librosa soundfile scipy numpy pandas
!pip install -q pystoi matplotlib seaborn tensorboard
!pip install -q tqdm pyyaml

# Optional: Install PESQ for additional metrics (may fail on some systems)
!pip install -q pesq || echo "PESQ installation failed - continuing without it"

print("\n✅ Dependencies installed!")

## 2️⃣ Download Dataset

You have two options:
1. **Option A**: Download directly to Colab (faster, but the data is lost when the session ends)
2. **Option B**: Mount Google Drive and store dataset there (persistent)

In [None]:
# Option A: Download directly to Colab
# This is faster but data will be lost when the session ends

DOWNLOAD_DIRECTLY = True  # Set to False if using Google Drive

if DOWNLOAD_DIRECTLY:
    !mkdir -p data
    
    # Download VoiceBank + DEMAND dataset from Edinburgh DataShare
    # Note: These are large files (~3.3 GB total)
    
    print("📥 Downloading VoiceBank + DEMAND dataset...")
    print("This may take 10-20 minutes depending on connection speed.\n")
    
    # Dataset URLs from Edinburgh DataShare
    BASE_URL = "https://datashare.ed.ac.uk/bitstream/handle/10283/2791"
    
    files = [
        "clean_trainset_28spk_wav.zip",
        "noisy_trainset_28spk_wav.zip", 
        "clean_testset_wav.zip",
        "noisy_testset_wav.zip"
    ]
    
    for f in files:
        if not os.path.exists(f"data/{f.replace('.zip', '')}"):
            print(f"Downloading {f}...")
            !wget -q --show-progress -O data/{f} "{BASE_URL}/{f}?sequence=1&isAllowed=y"
            print(f"Extracting {f}...")
            !cd data && unzip -q {f} && rm {f}
        else:
            print(f"✓ {f.replace('.zip', '')} already exists")
    
    print("\n✅ Dataset downloaded and extracted!")

In [None]:
# Option B: Mount Google Drive (persistent storage)
# Set USE_GOOGLE_DRIVE = True below if you want to use Google Drive

USE_GOOGLE_DRIVE = False  # Set to True to use Google Drive
GDRIVE_DATA_PATH = "/content/drive/MyDrive/speech_denoising_data"  # Change this path

if USE_GOOGLE_DRIVE:
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Create symbolic link to data directory
    !mkdir -p "{GDRIVE_DATA_PATH}"
    !ln -sf "{GDRIVE_DATA_PATH}" data
    
    print(f"\n📁 Data directory linked to: {GDRIVE_DATA_PATH}")
    print("If dataset is not there, download it using Option A first,")
    print("then copy the data folder to Google Drive for future sessions.")

In [None]:
# Verify dataset
import os
from pathlib import Path

data_dir = Path("data")
required_dirs = [
    "clean_trainset_28spk_wav",
    "noisy_trainset_28spk_wav",
    "clean_testset_wav",
    "noisy_testset_wav"
]

print("📊 Dataset verification:")
all_ok = True
for dir_name in required_dirs:
    dir_path = data_dir / dir_name
    if dir_path.exists():
        count = len(list(dir_path.glob("*.wav")))
        print(f"  ✅ {dir_name}: {count} files")
    else:
        print(f"  ❌ {dir_name}: NOT FOUND")
        all_ok = False

if all_ok:
    print("\n✅ Dataset is ready for training!")
else:
    print("\n⚠️ Some directories are missing. Please download the dataset.")
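
The verification cell only counts files per directory. A further check that can catch partial downloads is to confirm that every clean file has a noisy counterpart. The sketch below assumes clean and noisy recordings share the same file name in their respective directories; `find_unpaired` is a hypothetical helper, not part of the project code.

```python
from pathlib import Path


def find_unpaired(clean_dir, noisy_dir):
    """Return (clean_only, noisy_only): sorted file names present on one side only.

    Assumes paired files share the same name in both directories.
    """
    clean_names = {p.name for p in Path(clean_dir).glob("*.wav")}
    noisy_names = {p.name for p in Path(noisy_dir).glob("*.wav")}
    return sorted(clean_names - noisy_names), sorted(noisy_names - clean_names)


clean_only, noisy_only = find_unpaired(
    "data/clean_trainset_28spk_wav", "data/noisy_trainset_28spk_wav"
)
print(f"Clean files without a noisy partner: {len(clean_only)}")
print(f"Noisy files without a clean partner: {len(noisy_only)}")
```

Both counts should be zero for a complete download. Note that they are also zero when both directories are missing, so this check complements the directory check above rather than replacing it.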

## 3️⃣ Configuration

Adjust training parameters here. The defaults are optimized for Colab GPU.

In [None]:
# Training configuration for Colab
# These settings are optimized for Colab's GPU (T4/P100)

CONFIG = {
    # Data paths
    'data': {
        'train_clean_dir': './data/clean_trainset_28spk_wav',
        'train_noisy_dir': './data/noisy_trainset_28spk_wav',
        'test_clean_dir': './data/clean_testset_wav',
        'test_noisy_dir': './data/noisy_testset_wav',
        'sample_rate': 16000,
        'segment_length': 32000,  # 2 seconds
    },
    
    # STFT parameters
    'stft': {
        'n_fft': 512,
        'hop_length': 128,
        'win_length': 512,
    },
    
    # Model parameters
    'model': {
        'name': 'UNetDenoiser',
        'encoder_channels': [32, 64, 128, 256, 512],
        'use_attention': True,
        'dropout': 0.1,
    },
    
    # Training parameters (optimized for Colab)
    'training': {
        'batch_size': 8,  # Reduced for Colab GPU memory
        'num_epochs': 50,  # Reduced for faster training
        'learning_rate': 0.0001,
        'weight_decay': 1e-5,
        'scheduler': {
            'patience': 5,
            'factor': 0.5,
            'min_lr': 1e-6,
        },
        'early_stopping_patience': 10,
        'grad_clip': 5.0,
        'num_workers': 2,  # Reduced for Colab
        'use_amp': True,  # Mixed precision for faster training
    },
    
    # Loss
    'loss': {
        'l1_weight': 1.0,
        'stft_weight': 1.0,
    },
    
    # Checkpoints
    'checkpoint': {
        'save_dir': './checkpoints',
        'save_every': 5,
        'keep_last': 3,
    },
    
    # Logging
    'logging': {
        'log_dir': './logs',
        'log_every': 50,
    },
    
    # Evaluation
    'eval': {
        'output_dir': './outputs',
        'compute_pesq': True,
        'compute_stoi': True,
    }
}

print("📋 Configuration loaded!")
print(f"  Batch size: {CONFIG['training']['batch_size']}")
print(f"  Epochs: {CONFIG['training']['num_epochs']}")
print(f"  Learning rate: {CONFIG['training']['learning_rate']}")
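
To see what the STFT settings imply, the following sketch works out the spectrogram size for one training segment. It assumes one-sided, centered framing (the `torch.stft` defaults); the project's `AudioProcessor` may frame the signal differently, so treat the frame count as an estimate.

```python
sample_rate = 16000
segment_length = 32000
n_fft = 512
hop_length = 128

# One-sided spectrum of a real signal: n_fft / 2 + 1 frequency bins
freq_bins = n_fft // 2 + 1

# Assumes centered framing (center=True), which pads the signal at both ends
time_frames = 1 + segment_length // hop_length

segment_seconds = segment_length / sample_rate
window_ms = 1000 * n_fft / sample_rate  # win_length equals n_fft in this config
hop_ms = 1000 * hop_length / sample_rate

print(f"Spectrogram per segment: {freq_bins} bins x {time_frames} frames")
print(f"Segment: {segment_seconds} s, window: {window_ms} ms, hop: {hop_ms} ms")
```

Under these assumptions each 2-second segment becomes a 257 x 251 grid with two channels (real and imaginary), which is the input the model sees after the `permute` in the training loop.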

## 4️⃣ Initialize Model and Data

In [None]:
import sys
sys.path.insert(0, '.')

from data.dataset import VoiceBankDEMANDDataset, create_dataloaders
from models.unet import UNetDenoiser
from models.loss import DenoiserLoss
from utils.metrics import evaluate_batch, is_pesq_available
from utils.audio_utils import AudioProcessor

print("✅ Modules imported successfully!")

In [None]:
# Create dataloaders
data_cfg = CONFIG['data']
stft_cfg = CONFIG['stft']
train_cfg = CONFIG['training']

print("📂 Loading dataset...")
train_loader, val_loader = create_dataloaders(
    train_clean_dir=data_cfg['train_clean_dir'],
    train_noisy_dir=data_cfg['train_noisy_dir'],
    test_clean_dir=data_cfg['test_clean_dir'],
    test_noisy_dir=data_cfg['test_noisy_dir'],
    sample_rate=data_cfg['sample_rate'],
    segment_length=data_cfg['segment_length'],
    batch_size=train_cfg['batch_size'],
    num_workers=train_cfg['num_workers'],
    n_fft=stft_cfg['n_fft'],
    hop_length=stft_cfg['hop_length'],
    win_length=stft_cfg['win_length']
)

print(f"\n✅ Data loaded!")

In [None]:
# Create model
model_cfg = CONFIG['model']
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNetDenoiser(
    in_channels=2,
    out_channels=2,
    encoder_channels=model_cfg['encoder_channels'],
    use_attention=model_cfg['use_attention'],
    dropout=model_cfg['dropout'],
    mask_type='CRM'
).to(device)

print(f"🧠 Model: {model_cfg['name']}")
print(f"   Parameters: {model.count_parameters():,}")
print(f"   Device: {device}")

## 5️⃣ Training

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from tqdm.notebook import tqdm
from pathlib import Path
from datetime import datetime

# Create directories
ckpt_dir = Path(CONFIG['checkpoint']['save_dir'])
ckpt_dir.mkdir(parents=True, exist_ok=True)

# Loss function
loss_cfg = CONFIG['loss']
criterion = DenoiserLoss(
    complex_weight=loss_cfg['l1_weight'],
    magnitude_weight=1.0,
    stft_weight=loss_cfg['stft_weight'],
    use_mr_stft=True,
    n_fft=stft_cfg['n_fft'],
    hop_length=stft_cfg['hop_length'],
    win_length=stft_cfg['win_length']
).to(device)

# Audio processor
audio_processor = AudioProcessor(
    n_fft=stft_cfg['n_fft'],
    hop_length=stft_cfg['hop_length'],
    win_length=stft_cfg['win_length']
)

# Optimizer and scheduler
optimizer = optim.AdamW(
    model.parameters(),
    lr=train_cfg['learning_rate'],
    weight_decay=train_cfg['weight_decay']
)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=train_cfg['scheduler']['factor'],
    patience=train_cfg['scheduler']['patience'],
    min_lr=train_cfg['scheduler']['min_lr']
)

# Mixed precision scaler
scaler = GradScaler() if train_cfg['use_amp'] else None

print("✅ Training components initialized!")
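
With `factor=0.5` and `min_lr=1e-6`, the scheduler can only halve the learning rate a limited number of times before it reaches the floor. The sketch below reproduces just that arithmetic (new rate = old rate times factor, clamped at `min_lr`); it does not model when a reduction is triggered, and `lr_after_reductions` is a hypothetical helper.

```python
def lr_after_reductions(initial_lr, factor, min_lr, reductions):
    """Learning rate after a given number of plateau-triggered reductions."""
    lr = initial_lr
    for _ in range(reductions):
        lr = max(lr * factor, min_lr)  # never drop below the configured floor
    return lr


schedule = [lr_after_reductions(1e-4, 0.5, 1e-6, k) for k in range(9)]
for k, lr in enumerate(schedule):
    print(f"after {k} reductions: {lr:.3e}")
```

Starting from 1e-4, the seventh reduction is the first one clamped to the floor, so later plateaus no longer change the learning rate.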

In [None]:
# Training function
def train_epoch(model, train_loader, optimizer, criterion, audio_processor, device, scaler=None):
    model.train()
    total_loss = 0
    num_batches = 0
    
    pbar = tqdm(train_loader, desc="Training")
    for batch in pbar:
        noisy_stft = batch['noisy_stft'].to(device)
        clean_stft = batch['clean_stft'].to(device)
        clean_wav = batch['clean'].to(device)
        
        # Reshape: [batch, freq, time, 2] -> [batch, 2, freq, time]
        noisy_stft = noisy_stft.permute(0, 3, 1, 2)
        clean_stft = clean_stft.permute(0, 3, 1, 2)
        
        optimizer.zero_grad()
        
        if scaler is not None:
            with autocast():
                pred_stft = model(noisy_stft)
                
                # Reconstruct waveform
                pred_stft_istft = pred_stft.permute(0, 2, 3, 1)
                pred_wav = audio_processor.istft(pred_stft_istft)
                
                min_len = min(pred_wav.shape[-1], clean_wav.shape[-1])
                pred_wav = pred_wav[..., :min_len]
                clean_wav_trim = clean_wav[..., :min_len]
                
                losses = criterion(pred_stft, clean_stft, pred_wav, clean_wav_trim)
            
            scaler.scale(losses['total_loss']).backward()
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), train_cfg['grad_clip'])
            scaler.step(optimizer)
            scaler.update()
        else:
            pred_stft = model(noisy_stft)
            pred_stft_istft = pred_stft.permute(0, 2, 3, 1)
            pred_wav = audio_processor.istft(pred_stft_istft)
            
            min_len = min(pred_wav.shape[-1], clean_wav.shape[-1])
            pred_wav = pred_wav[..., :min_len]
            clean_wav_trim = clean_wav[..., :min_len]
            
            losses = criterion(pred_stft, clean_stft, pred_wav, clean_wav_trim)
            losses['total_loss'].backward()
            nn.utils.clip_grad_norm_(model.parameters(), train_cfg['grad_clip'])
            optimizer.step()
        
        total_loss += losses['total_loss'].item()
        num_batches += 1
        pbar.set_postfix({'loss': f"{losses['total_loss'].item():.4f}"})
    
    return total_loss / num_batches


@torch.no_grad()
def validate(model, val_loader, criterion, audio_processor, device):
    model.eval()
    total_loss = 0
    metrics = {'stoi': 0, 'si_sdr': 0}
    num_batches = 0
    
    for batch in tqdm(val_loader, desc="Validating"):
        noisy_stft = batch['noisy_stft'].to(device)
        clean_stft = batch['clean_stft'].to(device)
        clean_wav = batch['clean'].to(device)
        
        noisy_stft = noisy_stft.permute(0, 3, 1, 2)
        clean_stft = clean_stft.permute(0, 3, 1, 2)
        
        pred_stft = model(noisy_stft)
        pred_stft_istft = pred_stft.permute(0, 2, 3, 1)
        pred_wav = audio_processor.istft(pred_stft_istft)
        
        min_len = min(pred_wav.shape[-1], clean_wav.shape[-1])
        pred_wav = pred_wav[..., :min_len]
        clean_wav_trim = clean_wav[..., :min_len]
        
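        # Pass the waveforms as in train_epoch so both losses include the same terms
        losses = criterion(pred_stft, clean_stft, pred_wav, clean_wav_trim)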
        total_loss += losses['total_loss'].item()
        
        # Calculate metrics
        try:
            batch_metrics = evaluate_batch(
                clean_wav_trim, pred_wav,
                sample_rate=data_cfg['sample_rate'],
                compute_pesq=False,
                compute_stoi=True
            )
            for key in metrics:
                if key in batch_metrics:
                    metrics[key] += batch_metrics[key]
        except Exception:
            # Metrics are best-effort; skip batches where they cannot be computed
            pass
        
        num_batches += 1
    
    avg_loss = total_loss / num_batches
    avg_metrics = {k: v / num_batches for k, v in metrics.items()}
    
    return avg_loss, avg_metrics

print("✅ Training functions defined!")
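
The validation loop reports SI-SDR alongside STOI. For reference, here is a minimal pure-Python sketch of the usual SI-SDR definition: project the estimate onto the reference, then compare the energy of that projection with the energy of the residual. It is not the project's `evaluate_batch` implementation, and it assumes the signals are already zero-mean.

```python
import math


def si_sdr(estimate, reference, eps=1e-8):
    """Scale-invariant SDR in dB for two equal-length sample sequences.

    Illustrative sketch; assumes both signals are zero-mean.
    """
    dot = sum(e * r for e, r in zip(estimate, reference))
    ref_energy = sum(r * r for r in reference)
    alpha = dot / (ref_energy + eps)  # optimal scaling of the reference
    target = [alpha * r for r in reference]
    residual = [e - t for e, t in zip(estimate, target)]
    target_energy = sum(t * t for t in target)
    residual_energy = sum(n * n for n in residual)
    return 10 * math.log10((target_energy + eps) / (residual_energy + eps))


reference = [1.0, -1.0, 1.0, -1.0]
estimate = [1.1, -0.9, 1.1, -0.9]  # reference plus a small constant error
print(f"SI-SDR: {si_sdr(estimate, reference):.2f} dB")  # about 20 dB
```

Because of the projection step, multiplying the estimate by any non-zero gain leaves the score unchanged, which is what makes the metric scale-invariant.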

In [None]:
# Main training loop
print("="*60)
print("🚀 STARTING TRAINING")
print("="*60)
print(f"Epochs: {train_cfg['num_epochs']}")
print(f"Batch size: {train_cfg['batch_size']}")
print(f"Device: {device}")
print()

best_val_loss = float('inf')
patience_counter = 0
history = {'train_loss': [], 'val_loss': [], 'stoi': [], 'si_sdr': []}

for epoch in range(train_cfg['num_epochs']):
    print(f"\nEpoch {epoch + 1}/{train_cfg['num_epochs']}")
    print("-" * 40)
    
    # Train
    train_loss = train_epoch(model, train_loader, optimizer, criterion, audio_processor, device, scaler)
    
    # Validate
    val_loss, val_metrics = validate(model, val_loader, criterion, audio_processor, device)
    
    # Update scheduler
    scheduler.step(val_loss)
    
    # Save history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['stoi'].append(val_metrics.get('stoi', 0))
    history['si_sdr'].append(val_metrics.get('si_sdr', 0))
    
    # Print results
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_loss:.4f}")
    print(f"  STOI: {val_metrics.get('stoi', 0):.3f}")
    print(f"  SI-SDR: {val_metrics.get('si_sdr', 0):.2f} dB")
    print(f"  LR: {optimizer.param_groups[0]['lr']:.2e}")
    
    # Check for best model
    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
        patience_counter = 0
        # Save best model
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'config': CONFIG
        }, ckpt_dir / 'best_model.pt')
        print("  ✅ Saved best model!")
    else:
        patience_counter += 1
    
    # Save periodic checkpoint
    if (epoch + 1) % CONFIG['checkpoint']['save_every'] == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, ckpt_dir / f'checkpoint_epoch_{epoch+1}.pt')
    
    # Early stopping
    if patience_counter >= train_cfg['early_stopping_patience']:
        print(f"\n⏹️ Early stopping at epoch {epoch + 1}")
        break

print("\n" + "="*60)
print("✅ TRAINING COMPLETED!")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Model saved to: {ckpt_dir / 'best_model.pt'}")
print("="*60)
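
The configuration defines `keep_last: 3`, but the loop above keeps every periodic checkpoint, which can fill the Colab disk on long runs. One possible way to honor that setting is sketched below. `prune_checkpoints` is a hypothetical helper; it relies on the `checkpoint_epoch_{N}.pt` naming used in the loop and never touches `best_model.pt`.

```python
import re
from pathlib import Path


def prune_checkpoints(ckpt_dir, keep_last=3):
    """Delete all but the newest `keep_last` periodic checkpoints.

    Returns the names of the removed files, oldest first.
    """
    pattern = re.compile(r"checkpoint_epoch_(\d+)\.pt$")
    found = []
    for path in Path(ckpt_dir).glob("checkpoint_epoch_*.pt"):
        match = pattern.search(path.name)
        if match:
            # Sort by the numeric epoch so that epoch 10 comes after epoch 5
            found.append((int(match.group(1)), path))
    found.sort()
    to_delete = found[:-keep_last] if keep_last > 0 else found
    removed = []
    for _, path in to_delete:
        path.unlink()
        removed.append(path.name)
    return removed
```

It could be called right after each periodic save, for example `prune_checkpoints(ckpt_dir, CONFIG['checkpoint']['keep_last'])`.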

In [None]:
# Plot training history
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Loss
axes[0, 0].plot(history['train_loss'], label='Train')
axes[0, 0].plot(history['val_loss'], label='Validation')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training & Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

# STOI
axes[0, 1].plot(history['stoi'])
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('STOI')
axes[0, 1].set_title('STOI (Speech Intelligibility)')
axes[0, 1].grid(True)

# SI-SDR
axes[1, 0].plot(history['si_sdr'])
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('SI-SDR (dB)')
axes[1, 0].set_title('SI-SDR (Signal Quality)')
axes[1, 0].grid(True)

# Use the remaining subplot for a text summary
axes[1, 1].axis('off')
axes[1, 1].text(0.5, 0.5, f'Best Val Loss: {best_val_loss:.4f}\n\n'
                f'Final STOI: {history["stoi"][-1]:.3f}\n'
                f'Final SI-SDR: {history["si_sdr"][-1]:.2f} dB',
                ha='center', va='center', fontsize=14,
                transform=axes[1, 1].transAxes)

plt.tight_layout()
plt.savefig('training_history.png', dpi=150)
plt.show()

print("📊 Training history saved to training_history.png")

## 6️⃣ Test Inference

In [None]:
# Load best model and test on a sample
checkpoint = torch.load(ckpt_dir / 'best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"✅ Loaded best model from epoch {checkpoint['epoch'] + 1}")
print(f"   Validation loss: {checkpoint['val_loss']:.4f}")

In [None]:
# Test on a sample from validation set
import torchaudio
from utils.audio_utils import load_audio, save_audio
import IPython.display as ipd

# Get a test sample
test_batch = next(iter(val_loader))
noisy_wav = test_batch['noisy'][0:1].to(device)
clean_wav = test_batch['clean'][0:1]
noisy_stft = test_batch['noisy_stft'][0:1].to(device)

# Denoise
with torch.no_grad():
    noisy_stft_input = noisy_stft.permute(0, 3, 1, 2)
    pred_stft = model(noisy_stft_input)
    pred_stft_out = pred_stft.permute(0, 2, 3, 1)
    denoised_wav = audio_processor.istft(pred_stft_out)

# Convert to numpy for playback
noisy_np = noisy_wav[0].cpu().numpy()
clean_np = clean_wav[0].numpy()
denoised_np = denoised_wav[0].cpu().numpy()

# Ensure same length
min_len = min(len(noisy_np), len(clean_np), len(denoised_np))
noisy_np = noisy_np[:min_len]
clean_np = clean_np[:min_len]
denoised_np = denoised_np[:min_len]

print("🎧 Audio Comparison (play each to compare):")
print("\n1. Noisy Input:")
ipd.display(ipd.Audio(noisy_np, rate=data_cfg['sample_rate']))

print("\n2. Denoised Output:")
ipd.display(ipd.Audio(denoised_np, rate=data_cfg['sample_rate']))

print("\n3. Clean Reference:")
ipd.display(ipd.Audio(clean_np, rate=data_cfg['sample_rate']))

In [None]:
# Visualize spectrograms
import librosa
import librosa.display
import numpy as np

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, (audio, title) in zip(axes, [(noisy_np, 'Noisy'), (denoised_np, 'Denoised'), (clean_np, 'Clean')]):
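    # Use the same STFT settings as the model so the time axis is scaled correctly
    D = librosa.amplitude_to_db(
        np.abs(librosa.stft(audio, n_fft=stft_cfg['n_fft'], hop_length=stft_cfg['hop_length'])),
        ref=np.max)
    librosa.display.specshow(D, sr=data_cfg['sample_rate'], hop_length=stft_cfg['hop_length'],
                            x_axis='time', y_axis='hz', ax=ax)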
    ax.set_title(title)
    ax.set_ylim(0, 8000)

plt.tight_layout()
plt.savefig('spectrogram_comparison.png', dpi=150)
plt.show()

print("📊 Spectrogram comparison saved!")

## 7️⃣ Download Model

Download your trained model to use locally.

In [None]:
# Download the best model
from google.colab import files

print("📥 Downloading trained model...")
files.download(str(ckpt_dir / 'best_model.pt'))
print("\n✅ Download started! Check your browser downloads.")

In [None]:
# Optional: Copy to Google Drive for persistent storage
SAVE_TO_GDRIVE = False  # Set to True to save to Google Drive
GDRIVE_SAVE_PATH = "/content/drive/MyDrive/speech_denoising_models"

if SAVE_TO_GDRIVE:
    from google.colab import drive
    drive.mount('/content/drive')
    
    import shutil
    save_path = Path(GDRIVE_SAVE_PATH)
    save_path.mkdir(parents=True, exist_ok=True)
    
    # Copy model
    shutil.copy(ckpt_dir / 'best_model.pt', save_path / 'best_model.pt')
    
    # Copy training history plot
    if Path('training_history.png').exists():
        shutil.copy('training_history.png', save_path / 'training_history.png')
    
    print(f"✅ Model saved to Google Drive: {save_path}")

---

## 📝 Notes

- **Training time**: ~1-2 hours on Colab GPU (T4/P100) for 50 epochs
- **Memory**: Model uses ~4-6GB GPU memory with batch size 8
- **Best results**: Train for 100+ epochs with the full configuration
- **Tips**:
  - Increase batch size if you have more GPU memory
  - Use Google Drive to persist data between sessions
  - Check GPU usage with `!nvidia-smi`

## 🔗 Resources

- Dataset: [VoiceBank + DEMAND](https://datashare.ed.ac.uk/handle/10283/2791)
- Paper: [A Fully Convolutional Neural Network for Speech Enhancement](https://arxiv.org/abs/1609.07132)