# 🎙️ Speech Denoising - Google Colab Training

This notebook trains a U-Net speech denoising model on Google Colab with a GPU.

## Overview
- **Model**: U-Net with Complex Ratio Mask (CRM)
- **Dataset**: VoiceBank + DEMAND (from Google Drive)
- **Training Time**: ~1-2 hours on a Colab GPU (T4/P100)

## Dataset Structure on Google Drive
```
speech_denoising_data/
├── clean_trainset_28spk_wav/   (11,572 files)
├── noisy_trainset_28spk_wav/   (11,572 files)
├── clean_testset_wav/          (824 files)
└── noisy_testset_wav/          (824 files)
```
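
The dataset class later in this notebook pairs clean and noisy files by identical file name, so a file that exists on only one side is silently dropped. A minimal sketch for spotting such files before training, using only the standard library. `find_unpaired` is a hypothetical helper and the demo file names are made up:

```python
import tempfile
from pathlib import Path


def find_unpaired(clean_dir: Path, noisy_dir: Path):
    """Return (stems only in clean_dir, stems only in noisy_dir) for *.wav files."""
    clean = {p.stem for p in clean_dir.glob("*.wav")}
    noisy = {p.stem for p in noisy_dir.glob("*.wav")}
    return clean - noisy, noisy - clean


# Demo on a throwaway directory tree that mimics the layout above.
with tempfile.TemporaryDirectory() as root:
    clean_dir = Path(root) / "clean_trainset_28spk_wav"
    noisy_dir = Path(root) / "noisy_trainset_28spk_wav"
    clean_dir.mkdir()
    noisy_dir.mkdir()
    for name in ("a_001.wav", "a_002.wav"):
        (clean_dir / name).touch()
    (noisy_dir / "a_001.wav").touch()      # a_002 has no noisy counterpart

    only_clean, only_noisy = find_unpaired(clean_dir, noisy_dir)
    print("only in clean:", only_clean)
    print("only in noisy:", only_noisy)
```

On the real data, both sets should be empty if the upload completed correctly.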

---

## 1️⃣ Setup Environment

In [None]:
# Check GPU availability
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU detected! Go to Runtime > Change runtime type > GPU")

In [None]:
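# Install dependencies
# Note: Colab already ships with torch/torchaudio; upgrading them is optional and
# only takes effect after a runtime restart, since torch was imported above.
!pip install -q torch torchaudio --upgrade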
!pip install -q librosa soundfile scipy numpy pandas
!pip install -q pystoi matplotlib seaborn tensorboard
!pip install -q tqdm pyyaml

# Optional: Install PESQ
!pip install -q pesq || echo "PESQ installation failed - continuing without it"

print("\n✅ Dependencies installed!")

## 2️⃣ Mount Google Drive & Setup Dataset

In [None]:
# ============================================
# 📁 GOOGLE DRIVE PATH CONFIGURATION
# ============================================

# Path to the folder containing the dataset on Google Drive
# This folder contains: clean_trainset_28spk_wav, noisy_trainset_28spk_wav, etc.
GDRIVE_DATASET_FOLDER = "speech_denoising_data"

# Folder ID from the URL (backup if needed)
# URL: https://drive.google.com/drive/folders/1mDHfxtzvC-7kw0YXF0dFAcYlh7GAb2-
GDRIVE_FOLDER_ID = "1mDHfxtzvC-7kw0YXF0dFAcYlh7GAb2-"

print(f"📂 Dataset folder: {GDRIVE_DATASET_FOLDER}")

In [None]:
# Mount Google Drive
from google.colab import drive
import os
from pathlib import Path

# Mount drive
drive.mount('/content/drive')

# Build the full path
GDRIVE_DATASET_PATH = f"/content/drive/MyDrive/{GDRIVE_DATASET_FOLDER}"

# Verify dataset exists
if os.path.exists(GDRIVE_DATASET_PATH):
    print(f"✅ Found dataset folder: {GDRIVE_DATASET_PATH}")
    print("\n📂 Contents:")
    for item in os.listdir(GDRIVE_DATASET_PATH):
        item_path = os.path.join(GDRIVE_DATASET_PATH, item)
        if os.path.isdir(item_path):
            count = len([f for f in os.listdir(item_path) if f.endswith('.wav')])
            print(f"   📁 {item}: {count} files")
else:
    print(f"❌ Dataset folder not found at: {GDRIVE_DATASET_PATH}")
    print("\nPlease check:")
    print(f"  1. Folder name is correct: {GDRIVE_DATASET_FOLDER}")
    print("  2. Dataset is in 'My Drive' root folder")

In [None]:
# Setup dataset paths
TRAIN_CLEAN_DIR = os.path.join(GDRIVE_DATASET_PATH, "clean_trainset_28spk_wav")
TRAIN_NOISY_DIR = os.path.join(GDRIVE_DATASET_PATH, "noisy_trainset_28spk_wav")
TEST_CLEAN_DIR = os.path.join(GDRIVE_DATASET_PATH, "clean_testset_wav")
TEST_NOISY_DIR = os.path.join(GDRIVE_DATASET_PATH, "noisy_testset_wav")

# Verify all directories
print("📊 Dataset verification:")
print("-" * 50)
for name, path in [("Train Clean", TRAIN_CLEAN_DIR),
                   ("Train Noisy", TRAIN_NOISY_DIR),
                   ("Test Clean", TEST_CLEAN_DIR),
                   ("Test Noisy", TEST_NOISY_DIR)]:
    if os.path.exists(path):
        count = len([f for f in os.listdir(path) if f.endswith('.wav')])
        print(f"  ✅ {name}: {count} files")
    else:
        print(f"  ❌ {name}: NOT FOUND")
print("-" * 50)

## 3️⃣ Define Dataset & Model Classes

In [None]:
# Audio Processor class
import torch
import torchaudio
import numpy as np
from typing import Tuple, Optional, Dict, List

class AudioProcessor:
    """Audio processing utilities: STFT, iSTFT"""
    
    def __init__(self, n_fft=512, hop_length=128, win_length=512, sample_rate=16000):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sample_rate = sample_rate
        self.window = torch.hann_window(win_length)
    
    def stft(self, waveform: torch.Tensor) -> torch.Tensor:
        """STFT: [batch, samples] -> [batch, freq, time, 2]"""
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        window = self.window.to(waveform.device)
        stft_out = torch.stft(
            waveform, n_fft=self.n_fft, hop_length=self.hop_length,
            win_length=self.win_length, window=window,
            return_complex=True, center=True, pad_mode='reflect'
        )
        return torch.stack([stft_out.real, stft_out.imag], dim=-1)
    
    def istft(self, stft_tensor: torch.Tensor) -> torch.Tensor:
        """iSTFT: [batch, freq, time, 2] -> [batch, samples]"""
        window = self.window.to(stft_tensor.device)
        stft_complex = torch.complex(stft_tensor[..., 0], stft_tensor[..., 1])
        return torch.istft(
            stft_complex, n_fft=self.n_fft, hop_length=self.hop_length,
            win_length=self.win_length, window=window,
            center=True, return_complex=False
        )

print("✅ AudioProcessor defined!")
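
To make the tensor layout used by `AudioProcessor` concrete, here is a minimal round-trip sketch with the same STFT settings (`n_fft=512`, `hop_length=128`, Hann window). It runs on random noise as a stand-in for a 2-second clip, so it only checks shapes and reconstruction error and says nothing about audio quality:

```python
import torch

N_FFT, HOP, WIN = 512, 128, 512            # same values as the notebook's STFT settings
window = torch.hann_window(WIN)

torch.manual_seed(0)
wave = torch.randn(1, 32000)               # stand-in for 2 s of audio at 16 kHz

# Forward STFT, then split into real/imag channels: [batch, freq, time, 2]
spec = torch.stft(wave, n_fft=N_FFT, hop_length=HOP, win_length=WIN,
                  window=window, center=True, return_complex=True)
spec_ri = torch.stack([spec.real, spec.imag], dim=-1)

# Rebuild the complex tensor and invert it back to a waveform.
spec_c = torch.complex(spec_ri[..., 0], spec_ri[..., 1])
recon = torch.istft(spec_c, n_fft=N_FFT, hop_length=HOP, win_length=WIN,
                    window=window, center=True)

max_err = (recon - wave).abs().max().item()
print("spectrogram shape:", tuple(spec_ri.shape))
print("waveform shape:   ", tuple(recon.shape))
print("max abs error:    ", max_err)
```

The frequency axis has `n_fft // 2 + 1` bins and, with `center=True`, the time axis has `samples // hop_length + 1` frames.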

In [None]:
# Dataset class
from torch.utils.data import Dataset, DataLoader
import random

class VoiceBankDEMANDDataset(Dataset):
    """VoiceBank + DEMAND Dataset for Speech Denoising"""
    
    def __init__(
        self,
        clean_dir: str,
        noisy_dir: str,
        sample_rate: int = 16000,
        segment_length: int = 32000,
        n_fft: int = 512,
        hop_length: int = 128,
        win_length: int = 512,
        is_train: bool = True
    ):
        self.clean_dir = Path(clean_dir)
        self.noisy_dir = Path(noisy_dir)
        self.sample_rate = sample_rate
        self.segment_length = segment_length
        self.is_train = is_train
        
        self.audio_processor = AudioProcessor(
            n_fft=n_fft, hop_length=hop_length,
            win_length=win_length, sample_rate=sample_rate
        )
        
        # Get file list
        self.clean_files = sorted(list(self.clean_dir.glob("*.wav")))
        self.noisy_files = sorted(list(self.noisy_dir.glob("*.wav")))
        
        # Match files by name
        clean_names = {f.stem: f for f in self.clean_files}
        noisy_names = {f.stem: f for f in self.noisy_files}
        common_names = set(clean_names.keys()) & set(noisy_names.keys())
        
        self.file_pairs = [(clean_names[n], noisy_names[n]) for n in sorted(common_names)]
        print(f"  Found {len(self.file_pairs)} file pairs")
    
    def __len__(self):
        return len(self.file_pairs)
    
    def __getitem__(self, idx):
        clean_path, noisy_path = self.file_pairs[idx]
        
        # Load audio
        clean_wav, sr = torchaudio.load(clean_path)
        noisy_wav, _ = torchaudio.load(noisy_path)
        
        # Convert to mono and squeeze
        if clean_wav.shape[0] > 1:
            clean_wav = clean_wav.mean(dim=0)
        else:
            clean_wav = clean_wav.squeeze(0)
            
        if noisy_wav.shape[0] > 1:
            noisy_wav = noisy_wav.mean(dim=0)
        else:
            noisy_wav = noisy_wav.squeeze(0)
        
        # Resample if needed
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            clean_wav = resampler(clean_wav)
            noisy_wav = resampler(noisy_wav)
        
        # Random crop for training; otherwise pad or truncate to segment_length
        # (validation therefore uses only the first segment_length samples)
        if self.is_train and len(clean_wav) > self.segment_length:
            start = random.randint(0, len(clean_wav) - self.segment_length)
            clean_wav = clean_wav[start:start + self.segment_length]
            noisy_wav = noisy_wav[start:start + self.segment_length]
        else:
            # Pad or truncate
            if len(clean_wav) < self.segment_length:
                pad_len = self.segment_length - len(clean_wav)
                clean_wav = torch.nn.functional.pad(clean_wav, (0, pad_len))
                noisy_wav = torch.nn.functional.pad(noisy_wav, (0, pad_len))
            else:
                clean_wav = clean_wav[:self.segment_length]
                noisy_wav = noisy_wav[:self.segment_length]
        
        # Compute STFT
        clean_stft = self.audio_processor.stft(clean_wav).squeeze(0)
        noisy_stft = self.audio_processor.stft(noisy_wav).squeeze(0)
        
        return {
            'clean': clean_wav,
            'noisy': noisy_wav,
            'clean_stft': clean_stft,
            'noisy_stft': noisy_stft,
            'filename': clean_path.stem
        }

print("✅ VoiceBankDEMANDDataset defined!")
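
The segment logic in `__getitem__` can be isolated into a small function. The sketch below is a NumPy stand-in for the tensor code, assuming (as the dataset class does) that the clean and noisy clips have the same length. `crop_or_pad` is a hypothetical name:

```python
import numpy as np


def crop_or_pad(clean, noisy, segment_length, rng=None):
    """Return aligned clean/noisy segments of exactly segment_length samples.

    With rng given, long clips get a random crop (training behaviour);
    without it, the first segment_length samples are kept (validation behaviour).
    Short clips are zero-padded at the end in both cases.
    """
    n = len(clean)
    if n >= segment_length:
        start = int(rng.integers(0, n - segment_length + 1)) if rng is not None else 0
        window = slice(start, start + segment_length)
        return clean[window], noisy[window]
    pad = (0, segment_length - n)
    return np.pad(clean, pad), np.pad(noisy, pad)


rng = np.random.default_rng(0)
clean = np.arange(10, dtype=np.float32)
noisy = clean + 100.0                      # constant offset makes misalignment easy to spot

train_c, train_n = crop_or_pad(clean, noisy, 4, rng)   # random 4-sample crop
val_c, val_n = crop_or_pad(clean, noisy, 4)            # first 4 samples
pad_c, pad_n = crop_or_pad(clean, noisy, 16)           # zero-padded to 16 samples

print(train_c, train_n)
print(val_c, pad_c)
```

Using the same `start` for both signals is what keeps each noisy sample aligned with its clean target.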

In [None]:
# U-Net Model
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Conv block with BatchNorm and LeakyReLU"""
    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, dropout=0.0):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(dropout) if dropout > 0 else nn.Identity()
        )
    def forward(self, x):
        return self.conv(x)

class EncoderBlock(nn.Module):
    """Encoder block with downsampling"""
    def __init__(self, in_ch, out_ch, dropout=0.0):
        super().__init__()
        self.conv1 = ConvBlock(in_ch, out_ch, dropout=dropout)
        self.conv2 = ConvBlock(out_ch, out_ch, dropout=dropout)
        self.pool = nn.MaxPool2d(2)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return self.pool(x), x

class DecoderBlock(nn.Module):
    """Decoder block with upsampling and skip connection"""
    def __init__(self, in_ch, out_ch, dropout=0.0):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv1 = ConvBlock(out_ch * 2, out_ch, dropout=dropout)
        self.conv2 = ConvBlock(out_ch, out_ch, dropout=dropout)
    
    def forward(self, x, skip):
        x = self.up(x)
        # Handle spatial size mismatch (odd sizes are floored by pooling)
        if x.shape[2:] != skip.shape[2:]:
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        return self.conv2(x)

class AttentionBlock(nn.Module):
    """Self-attention block"""
    def __init__(self, channels):
        super().__init__()
        self.query = nn.Conv2d(channels, channels // 8, 1)
        self.key = nn.Conv2d(channels, channels // 8, 1)
        self.value = nn.Conv2d(channels, channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
    
    def forward(self, x):
        B, C, H, W = x.shape
        q = self.query(x).view(B, -1, H * W).permute(0, 2, 1)
        k = self.key(x).view(B, -1, H * W)
        v = self.value(x).view(B, -1, H * W)
        
        attn = F.softmax(torch.bmm(q, k), dim=-1)
        out = torch.bmm(v, attn.permute(0, 2, 1)).view(B, C, H, W)
        return self.gamma * out + x

class UNetDenoiser(nn.Module):
    """U-Net for speech denoising with Complex Ratio Mask"""
    
    def __init__(
        self,
        in_channels=2,
        out_channels=2,
        encoder_channels=[32, 64, 128, 256, 512],
        use_attention=True,
        dropout=0.1,
        mask_type='CRM'
    ):
        super().__init__()
        self.mask_type = mask_type
        
        # Encoder
        self.encoders = nn.ModuleList()
        in_ch = in_channels
        for out_ch in encoder_channels:
            self.encoders.append(EncoderBlock(in_ch, out_ch, dropout))
            in_ch = out_ch
        
        # Bottleneck
        self.bottleneck = nn.Sequential(
            ConvBlock(encoder_channels[-1], encoder_channels[-1] * 2, dropout=dropout),
            AttentionBlock(encoder_channels[-1] * 2) if use_attention else nn.Identity(),
            ConvBlock(encoder_channels[-1] * 2, encoder_channels[-1] * 2, dropout=dropout)
        )
        
        # Decoder
        self.decoders = nn.ModuleList()
        decoder_channels = encoder_channels[::-1]
        in_ch = encoder_channels[-1] * 2
        for out_ch in decoder_channels:
            self.decoders.append(DecoderBlock(in_ch, out_ch, dropout))
            in_ch = out_ch
        
        # Output
        self.output = nn.Conv2d(encoder_channels[0], out_channels, 1)
    
    def forward(self, x):
        # Store input for mask application
        input_stft = x
        
        # Encoder
        skips = []
        for encoder in self.encoders:
            x, skip = encoder(x)
            skips.append(skip)
        
        # Bottleneck
        x = self.bottleneck(x)
        
        # Decoder
        for decoder, skip in zip(self.decoders, reversed(skips)):
            x = decoder(x, skip)
        
        # Output mask
        mask = self.output(x)
        
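        # Apply the mask channel-wise: the real and imaginary parts are each scaled by
        # their own tanh-bounded mask channel (a simplified CRM, not a full complex product)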
        if self.mask_type == 'CRM':
            mask = torch.tanh(mask)
            output = input_stft * mask
        else:
            output = mask
        
        return output
    
    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

print("✅ UNetDenoiser defined!")
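
The forward pass above multiplies the real channel by one mask channel and the imaginary channel by the other. A complex ratio mask in the usual sense multiplies mask and spectrogram as complex numbers, which mixes the two channels. For comparison, a minimal NumPy sketch of that complex product on `[2, freq, time]` arrays. `apply_complex_mask` is a hypothetical helper and is not what the notebook's model does:

```python
import numpy as np


def apply_complex_mask(spec_ri, mask_ri):
    """Complex product of two [2, freq, time] arrays (channel 0 = real, 1 = imag)."""
    xr, xi = spec_ri
    mr, mi = mask_ri
    real = mr * xr - mi * xi
    imag = mr * xi + mi * xr
    return np.stack([real, imag])


rng = np.random.default_rng(0)
spec = rng.standard_normal((2, 4, 3))
mask = np.tanh(rng.standard_normal((2, 4, 3)))    # bounded to (-1, 1) like the tanh output

out = apply_complex_mask(spec, mask)

# Cross-check against NumPy's built-in complex arithmetic.
reference = (mask[0] + 1j * mask[1]) * (spec[0] + 1j * spec[1])
print(np.allclose(out[0], reference.real), np.allclose(out[1], reference.imag))
```

The channel-wise version can only rescale or flip the sign of each component separately, whereas the complex product can rotate phase freely. Whether that matters in practice for this model is not something the notebook measures.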

In [None]:
# Loss functions
class MultiResolutionSTFTLoss(nn.Module):
    """Multi-resolution STFT loss"""
    def __init__(self, fft_sizes=[512, 1024, 2048], hop_sizes=[128, 256, 512], win_lengths=[512, 1024, 2048]):
        super().__init__()
        self.fft_sizes = fft_sizes
        self.hop_sizes = hop_sizes
        self.win_lengths = win_lengths
    
    def forward(self, pred, target):
        loss = 0
        for fft_size, hop_size, win_length in zip(self.fft_sizes, self.hop_sizes, self.win_lengths):
            window = torch.hann_window(win_length).to(pred.device)
            
            pred_stft = torch.stft(pred, fft_size, hop_size, win_length, window, return_complex=True)
            target_stft = torch.stft(target, fft_size, hop_size, win_length, window, return_complex=True)
            
            pred_mag = pred_stft.abs()
            target_mag = target_stft.abs()
            
            # Spectral convergence + Log magnitude loss
            loss += torch.norm(target_mag - pred_mag, p='fro') / (torch.norm(target_mag, p='fro') + 1e-8)
            loss += F.l1_loss(torch.log(pred_mag + 1e-8), torch.log(target_mag + 1e-8))
        
        return loss / len(self.fft_sizes)

class DenoiserLoss(nn.Module):
    """Combined loss for speech denoising"""
    def __init__(self, complex_weight=1.0, magnitude_weight=1.0, stft_weight=0.5, n_fft=512, hop_length=128, win_length=512, use_mr_stft=True):
        super().__init__()
        self.complex_weight = complex_weight
        self.magnitude_weight = magnitude_weight
        self.stft_weight = stft_weight
        self.use_mr_stft = use_mr_stft
        
        if use_mr_stft:
            self.mr_stft_loss = MultiResolutionSTFTLoss()
    
    def forward(self, pred_stft, target_stft, pred_wav=None, target_wav=None):
        losses = {}
        
        # Complex L1 loss
        complex_loss = F.l1_loss(pred_stft, target_stft)
        losses['complex_loss'] = complex_loss
        
        # Magnitude loss
        pred_mag = torch.sqrt(pred_stft[:, 0]**2 + pred_stft[:, 1]**2 + 1e-8)
        target_mag = torch.sqrt(target_stft[:, 0]**2 + target_stft[:, 1]**2 + 1e-8)
        magnitude_loss = F.l1_loss(pred_mag, target_mag)
        losses['magnitude_loss'] = magnitude_loss
        
        # MR-STFT loss
        stft_loss = torch.tensor(0.0, device=pred_stft.device)
        if self.use_mr_stft and pred_wav is not None and target_wav is not None:
            stft_loss = self.mr_stft_loss(pred_wav, target_wav)
            losses['stft_loss'] = stft_loss
        
        # Total loss
        total = self.complex_weight * complex_loss + self.magnitude_weight * magnitude_loss + self.stft_weight * stft_loss
        losses['total_loss'] = total
        
        return losses

print("✅ Loss functions defined!")
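
The two terms summed inside `MultiResolutionSTFTLoss` are easier to read on a toy example. The NumPy sketch below evaluates both on a tiny magnitude matrix where the prediction is exactly half the target. The function names are illustrative, not part of the notebook:

```python
import numpy as np


def spectral_convergence(pred_mag, target_mag, eps=1e-8):
    """Frobenius norm of the magnitude error, relative to the target's norm."""
    return np.linalg.norm(target_mag - pred_mag) / (np.linalg.norm(target_mag) + eps)


def log_magnitude_l1(pred_mag, target_mag, eps=1e-8):
    """Mean absolute difference between log magnitudes."""
    return np.mean(np.abs(np.log(pred_mag + eps) - np.log(target_mag + eps)))


target = np.array([[1.0, 2.0], [3.0, 4.0]])
pred = 0.5 * target                # prediction that is uniformly about 6 dB too quiet

sc = spectral_convergence(pred, target)
lm = log_magnitude_l1(pred, target)
print(f"spectral convergence: {sc:.4f}")   # relative error, 0.5 here
print(f"log-magnitude L1:     {lm:.4f}")   # ln(2) here
```

Spectral convergence is dominated by high-energy bins, while the log-magnitude term weights every bin's ratio equally, which is why the two are commonly combined.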

In [None]:
# Metrics
from pystoi import stoi

def calculate_si_sdr(reference, estimation):
    """Calculate Scale-Invariant SDR"""
    reference = reference - reference.mean()
    estimation = estimation - estimation.mean()
    
    dot = (reference * estimation).sum()
    s_target = dot * reference / ((reference ** 2).sum() + 1e-8)  # eps guards an all-zero reference
    e_noise = estimation - s_target
    
    si_sdr = 10 * torch.log10((s_target ** 2).sum() / ((e_noise ** 2).sum() + 1e-8) + 1e-8)
    return si_sdr.item()

def evaluate_batch(clean_wav, pred_wav, sample_rate=16000, compute_pesq=False, compute_stoi=True):
    """Evaluate batch of audio samples"""
    metrics = {'stoi': 0.0, 'si_sdr': 0.0}
    batch_size = clean_wav.shape[0]
    
    for i in range(batch_size):
        clean = clean_wav[i].cpu().numpy()
        pred = pred_wav[i].cpu().numpy()
        
        # STOI
        if compute_stoi:
            try:
                metrics['stoi'] += stoi(clean, pred, sample_rate, extended=False)
            except Exception:
                # STOI can fail on degenerate clips; that sample then contributes 0
                pass
        
        # SI-SDR
        metrics['si_sdr'] += calculate_si_sdr(
            torch.from_numpy(clean), 
            torch.from_numpy(pred)
        )
    
    # Average
    for key in metrics:
        metrics[key] /= batch_size
    
    return metrics

print("✅ Metrics defined!")
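
The "scale-invariant" part of SI-SDR means that making the estimate louder or quieter does not change the score. A minimal NumPy sketch demonstrating this on synthetic signals. The epsilon placement differs slightly from `calculate_si_sdr` above, so treat it as an illustration and not a drop-in replacement:

```python
import numpy as np


def si_sdr(reference, estimate, eps=1e-8):
    """Scale-invariant signal-to-distortion ratio in dB for 1-D arrays."""
    reference = reference - reference.mean()
    estimate = estimate - estimate.mean()
    # Project the estimate onto the reference to get the "target" component.
    scale = np.dot(reference, estimate) / (np.dot(reference, reference) + eps)
    target = scale * reference
    noise = estimate - target
    return 10 * np.log10((np.sum(target ** 2) + eps) / (np.sum(noise ** 2) + eps))


rng = np.random.default_rng(0)
clean = rng.standard_normal(16000)                     # 1 s of synthetic "speech"
estimate = clean + 0.1 * rng.standard_normal(16000)    # residual noise at about -20 dB

score = si_sdr(clean, estimate)
score_louder = si_sdr(clean, 3.0 * estimate)           # same estimate, 3x the gain
print(f"{score:.2f} dB vs {score_louder:.2f} dB")
```

Both calls give roughly 20 dB, matching the 0.1 noise amplitude.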

## 4️⃣ Configuration & Initialization

In [None]:
# Training configuration (optimized for Colab GPU)
CONFIG = {
    'data': {
        'sample_rate': 16000,
        'segment_length': 32000,  # 2 seconds
    },
    'stft': {
        'n_fft': 512,
        'hop_length': 128,
        'win_length': 512,
    },
    'model': {
        'encoder_channels': [32, 64, 128, 256, 512],
        'use_attention': True,
        'dropout': 0.1,
    },
    'training': {
        'batch_size': 8,        # Sized to fit Colab GPU memory
        'num_epochs': 50,       # Increase to 100 if time allows
        'learning_rate': 0.0001,
        'weight_decay': 1e-5,
        'grad_clip': 5.0,
        'num_workers': 2,
        'use_amp': True,        # Mixed precision
        'early_stopping_patience': 10,
    },
    'scheduler': {
        'patience': 5,
        'factor': 0.5,
        'min_lr': 1e-6,
    },
    'loss': {
        'complex_weight': 1.0,
        'magnitude_weight': 1.0,
        'stft_weight': 0.5,
    },
}

print("📋 Configuration:")
print(f"  Batch size: {CONFIG['training']['batch_size']}")
print(f"  Epochs: {CONFIG['training']['num_epochs']}")
print(f"  Learning rate: {CONFIG['training']['learning_rate']}")
print(f"  Mixed precision: {CONFIG['training']['use_amp']}")
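
The `data`, `stft` and `model` settings together fix the tensor sizes that flow through the network. The arithmetic below is a sketch in plain Python, assuming the one-sided STFT with `center=True` used by `AudioProcessor` and one `MaxPool2d(2)` per encoder level:

```python
sample_rate = 16000
segment_length = 32000
n_fft, hop_length = 512, 128
num_levels = 5                              # len(encoder_channels)

seconds = segment_length / sample_rate
freq_bins = n_fft // 2 + 1                  # one-sided spectrum
frames = segment_length // hop_length + 1   # center=True pads half a window on each side

# MaxPool2d(2) halves each axis and floors odd sizes.
sizes = [(freq_bins, frames)]
for _ in range(num_levels):
    f, t = sizes[-1]
    sizes.append((f // 2, t // 2))

print(seconds, freq_bins, frames)   # 2.0 257 251
print(sizes)
# [(257, 251), (128, 125), (64, 62), (32, 31), (16, 15), (8, 7)]
```

Because several of these sizes are odd, a transposed convolution that doubles `(8, 7)` yields `(16, 14)` and not the skip connection's `(16, 15)`. That is the mismatch the `F.interpolate` call in `DecoderBlock` corrects.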

In [None]:
# Create datasets and dataloaders
print("📂 Loading dataset from Google Drive...")

stft_cfg = CONFIG['stft']
data_cfg = CONFIG['data']
train_cfg = CONFIG['training']

# Training dataset
print("\n  Loading training set...")
train_dataset = VoiceBankDEMANDDataset(
    clean_dir=TRAIN_CLEAN_DIR,
    noisy_dir=TRAIN_NOISY_DIR,
    sample_rate=data_cfg['sample_rate'],
    segment_length=data_cfg['segment_length'],
    n_fft=stft_cfg['n_fft'],
    hop_length=stft_cfg['hop_length'],
    win_length=stft_cfg['win_length'],
    is_train=True
)

# Validation dataset
print("  Loading validation set...")
val_dataset = VoiceBankDEMANDDataset(
    clean_dir=TEST_CLEAN_DIR,
    noisy_dir=TEST_NOISY_DIR,
    sample_rate=data_cfg['sample_rate'],
    segment_length=data_cfg['segment_length'],
    n_fft=stft_cfg['n_fft'],
    hop_length=stft_cfg['hop_length'],
    win_length=stft_cfg['win_length'],
    is_train=False
)

# Dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=train_cfg['batch_size'],
    shuffle=True,
    num_workers=train_cfg['num_workers'],
    pin_memory=True,
    drop_last=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=train_cfg['batch_size'],
    shuffle=False,
    num_workers=train_cfg['num_workers'],
    pin_memory=True
)

print(f"\n✅ Data loaded!")
print(f"   Training samples: {len(train_dataset)}")
print(f"   Validation samples: {len(val_dataset)}")
print(f"   Training batches: {len(train_loader)}")
print(f"   Validation batches: {len(val_loader)}")

In [None]:
# Create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_cfg = CONFIG['model']

model = UNetDenoiser(
    in_channels=2,
    out_channels=2,
    encoder_channels=model_cfg['encoder_channels'],
    use_attention=model_cfg['use_attention'],
    dropout=model_cfg['dropout'],
    mask_type='CRM'
).to(device)

print(f"🧠 Model: UNetDenoiser")
print(f"   Parameters: {model.count_parameters():,}")
print(f"   Device: {device}")

In [None]:
# Initialize training components
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from tqdm.notebook import tqdm

# Directories
ckpt_dir = Path('./checkpoints')
ckpt_dir.mkdir(parents=True, exist_ok=True)

# Loss function
loss_cfg = CONFIG['loss']
criterion = DenoiserLoss(
    complex_weight=loss_cfg['complex_weight'],
    magnitude_weight=loss_cfg['magnitude_weight'],
    stft_weight=loss_cfg['stft_weight'],
    use_mr_stft=True
).to(device)

# Audio processor for iSTFT
audio_processor = AudioProcessor(
    n_fft=stft_cfg['n_fft'],
    hop_length=stft_cfg['hop_length'],
    win_length=stft_cfg['win_length']
)

# Optimizer
optimizer = optim.AdamW(
    model.parameters(),
    lr=train_cfg['learning_rate'],
    weight_decay=train_cfg['weight_decay']
)

# Scheduler
scheduler_cfg = CONFIG['scheduler']
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=scheduler_cfg['factor'],
    patience=scheduler_cfg['patience'],
    min_lr=scheduler_cfg['min_lr']
)

# Mixed precision scaler (only meaningful on CUDA; fall back to full precision on CPU)
scaler = GradScaler() if train_cfg['use_amp'] and device.type == 'cuda' else None

print("✅ Training components initialized!")
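
To get a feel for what `patience=5` and `factor=0.5` mean for the scheduler, the sketch below simulates a reduce-on-plateau rule in plain Python. It is a simplified model of the behaviour under a relative improvement threshold of `1e-4` (an assumption here) and leaves out options the real `ReduceLROnPlateau` class has, such as cooldown:

```python
def simulate_plateau(val_losses, lr=1e-4, factor=0.5, patience=5,
                     min_lr=1e-6, rel_threshold=1e-4):
    """Return the learning rate in effect after each epoch."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best * (1 - rel_threshold):   # counts as an improvement
            best, bad_epochs = loss, 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:               # stagnated for more than `patience` epochs
            lr = max(lr * factor, min_lr)
            bad_epochs = 0
        history.append(lr)
    return history


# Validation loss that never improves after the first epoch.
lrs = simulate_plateau([1.0] * 8)
for epoch, lr in enumerate(lrs, start=1):
    print(f"epoch {epoch}: lr = {lr:.1e}")
```

In this model the first halving happens on the sixth consecutive epoch without improvement, that is, epoch 7 in the example.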

## 5️⃣ Training

In [None]:
# Training functions
def train_epoch(model, train_loader, optimizer, criterion, audio_processor, device, scaler=None):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    num_batches = 0
    
    pbar = tqdm(train_loader, desc="Training")
    for batch in pbar:
        noisy_stft = batch['noisy_stft'].to(device)
        clean_stft = batch['clean_stft'].to(device)
        clean_wav = batch['clean'].to(device)
        
        # Reshape: [batch, freq, time, 2] -> [batch, 2, freq, time]
        noisy_stft = noisy_stft.permute(0, 3, 1, 2)
        clean_stft = clean_stft.permute(0, 3, 1, 2)
        
        optimizer.zero_grad()
        
        if scaler is not None:
            with autocast():
                pred_stft = model(noisy_stft)
                
                # Reconstruct waveform
                pred_stft_istft = pred_stft.permute(0, 2, 3, 1)
                pred_wav = audio_processor.istft(pred_stft_istft)
                
                # Ensure same length
                min_len = min(pred_wav.shape[-1], clean_wav.shape[-1])
                pred_wav = pred_wav[..., :min_len]
                clean_wav_trim = clean_wav[..., :min_len]
                
                losses = criterion(pred_stft, clean_stft, pred_wav, clean_wav_trim)
            
            scaler.scale(losses['total_loss']).backward()
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), train_cfg['grad_clip'])
            scaler.step(optimizer)
            scaler.update()
        else:
            pred_stft = model(noisy_stft)
            pred_stft_istft = pred_stft.permute(0, 2, 3, 1)
            pred_wav = audio_processor.istft(pred_stft_istft)
            
            min_len = min(pred_wav.shape[-1], clean_wav.shape[-1])
            pred_wav = pred_wav[..., :min_len]
            clean_wav_trim = clean_wav[..., :min_len]
            
            losses = criterion(pred_stft, clean_stft, pred_wav, clean_wav_trim)
            losses['total_loss'].backward()
            nn.utils.clip_grad_norm_(model.parameters(), train_cfg['grad_clip'])
            optimizer.step()
        
        total_loss += losses['total_loss'].item()
        num_batches += 1
        pbar.set_postfix({'loss': f"{losses['total_loss'].item():.4f}"})
    
    return total_loss / num_batches


@torch.no_grad()
def validate(model, val_loader, criterion, audio_processor, device):
    """Validate model"""
    model.eval()
    total_loss = 0
    metrics = {'stoi': 0, 'si_sdr': 0}
    num_batches = 0
    
    for batch in tqdm(val_loader, desc="Validating"):
        noisy_stft = batch['noisy_stft'].to(device)
        clean_stft = batch['clean_stft'].to(device)
        clean_wav = batch['clean'].to(device)
        
        noisy_stft = noisy_stft.permute(0, 3, 1, 2)
        clean_stft = clean_stft.permute(0, 3, 1, 2)
        
        pred_stft = model(noisy_stft)
        pred_stft_istft = pred_stft.permute(0, 2, 3, 1)
        pred_wav = audio_processor.istft(pred_stft_istft)
        
        min_len = min(pred_wav.shape[-1], clean_wav.shape[-1])
        pred_wav = pred_wav[..., :min_len]
        clean_wav_trim = clean_wav[..., :min_len]
        
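        # Pass the waveforms too so the validation loss includes the MR-STFT term, like training
        losses = criterion(pred_stft, clean_stft, pred_wav, clean_wav_trim)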
        total_loss += losses['total_loss'].item()
        
        # Metrics
        try:
            batch_metrics = evaluate_batch(
                clean_wav_trim, pred_wav,
                sample_rate=CONFIG['data']['sample_rate'],
                compute_stoi=True
            )
            for key in metrics:
                if key in batch_metrics:
                    metrics[key] += batch_metrics[key]
        except Exception as exc:
            print(f"Metric computation failed for a batch: {exc}")
        
        num_batches += 1
    
    avg_loss = total_loss / num_batches
    avg_metrics = {k: v / num_batches for k, v in metrics.items()}
    
    return avg_loss, avg_metrics

print("✅ Training functions defined!")

In [None]:
# Main training loop
print("="*60)
print("🚀 STARTING TRAINING")
print("="*60)
print(f"Dataset: Google Drive - {GDRIVE_DATASET_FOLDER}")
print(f"Epochs: {train_cfg['num_epochs']}")
print(f"Batch size: {train_cfg['batch_size']}")
print(f"Device: {device}")
print()

best_val_loss = float('inf')
patience_counter = 0
history = {'train_loss': [], 'val_loss': [], 'stoi': [], 'si_sdr': []}

for epoch in range(train_cfg['num_epochs']):
    print(f"\nEpoch {epoch + 1}/{train_cfg['num_epochs']}")
    print("-" * 40)
    
    # Train
    train_loss = train_epoch(model, train_loader, optimizer, criterion, audio_processor, device, scaler)
    
    # Validate
    val_loss, val_metrics = validate(model, val_loader, criterion, audio_processor, device)
    
    # Update scheduler
    scheduler.step(val_loss)
    
    # Save history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['stoi'].append(val_metrics.get('stoi', 0))
    history['si_sdr'].append(val_metrics.get('si_sdr', 0))
    
    # Print results
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_loss:.4f}")
    print(f"  STOI: {val_metrics.get('stoi', 0):.3f}")
    print(f"  SI-SDR: {val_metrics.get('si_sdr', 0):.2f} dB")
    print(f"  LR: {optimizer.param_groups[0]['lr']:.2e}")
    
    # Check for best model
    is_best = val_loss < best_val_loss
    if is_best:
        best_val_loss = val_loss
        patience_counter = 0
        # Save best model
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'config': CONFIG
        }, ckpt_dir / 'best_model.pt')
        print("  ✅ Saved best model!")
    else:
        patience_counter += 1
    
    # Save periodic checkpoint
    if (epoch + 1) % 5 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, ckpt_dir / f'checkpoint_epoch_{epoch+1}.pt')
    
    # Early stopping
    if patience_counter >= train_cfg['early_stopping_patience']:
        print(f"\n⏹️ Early stopping at epoch {epoch + 1}")
        break

print("\n" + "="*60)
print("✅ TRAINING COMPLETED!")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Model saved to: {ckpt_dir / 'best_model.pt'}")
print("="*60)
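
The early-stopping bookkeeping in the loop above (`best_val_loss`, `patience_counter`) can be checked in isolation. A minimal sketch in plain Python that replays a list of validation losses and reports where training would stop. `early_stop_epoch` is a hypothetical helper and the loss values are made up:

```python
def early_stop_epoch(val_losses, patience=10):
    """Return the 1-based epoch at which training stops, or None if it never does."""
    best = float("inf")
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:              # strict improvement resets the counter
            best, bad_epochs = loss, 0
        else:
            bad_epochs += 1
        if bad_epochs >= patience:
            return epoch
    return None


# Improves for three epochs, then stagnates at a worse value.
losses = [1.0, 0.8, 0.7] + [0.75] * 12
stop = early_stop_epoch(losses, patience=10)
print("stops at epoch:", stop)
```

With the notebook's settings the scheduler's patience (5) is shorter than the early-stopping patience (10), so the learning rate is typically reduced at least once before training is cut off.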

In [None]:
# Plot training history
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Loss
axes[0, 0].plot(history['train_loss'], label='Train')
axes[0, 0].plot(history['val_loss'], label='Validation')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training & Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

# STOI
axes[0, 1].plot(history['stoi'])
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('STOI')
axes[0, 1].set_title('STOI (Speech Intelligibility)')
axes[0, 1].grid(True)

# SI-SDR
axes[1, 0].plot(history['si_sdr'])
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('SI-SDR (dB)')
axes[1, 0].set_title('SI-SDR (Signal Quality)')
axes[1, 0].grid(True)

# Summary
axes[1, 1].axis('off')
axes[1, 1].text(0.5, 0.5, f'Best Val Loss: {best_val_loss:.4f}\n\n'
                f'Final STOI: {history["stoi"][-1]:.3f}\n'
                f'Final SI-SDR: {history["si_sdr"][-1]:.2f} dB',
                ha='center', va='center', fontsize=14,
                transform=axes[1, 1].transAxes)

plt.tight_layout()
plt.savefig('training_history.png', dpi=150)
plt.show()

print("📊 Training history saved!")

## 6️⃣ Test Inference

In [None]:
# Load best model (map_location keeps this working if the runtime has no GPU)
checkpoint = torch.load(ckpt_dir / 'best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"✅ Loaded best model from epoch {checkpoint['epoch'] + 1}")
print(f"   Validation loss: {checkpoint['val_loss']:.4f}")

In [None]:
# Test on a sample
import IPython.display as ipd

# Get a test sample
test_batch = next(iter(val_loader))
noisy_wav = test_batch['noisy'][0:1].to(device)
clean_wav = test_batch['clean'][0:1]
noisy_stft = test_batch['noisy_stft'][0:1].to(device)

# Denoise
with torch.no_grad():
    noisy_stft_input = noisy_stft.permute(0, 3, 1, 2)
    pred_stft = model(noisy_stft_input)
    pred_stft_out = pred_stft.permute(0, 2, 3, 1)
    denoised_wav = audio_processor.istft(pred_stft_out)

# Convert to numpy
noisy_np = noisy_wav[0].cpu().numpy()
clean_np = clean_wav[0].numpy()
denoised_np = denoised_wav[0].cpu().numpy()

# Ensure same length
min_len = min(len(noisy_np), len(clean_np), len(denoised_np))
noisy_np = noisy_np[:min_len]
clean_np = clean_np[:min_len]
denoised_np = denoised_np[:min_len]

print("🎧 Audio Comparison:")
print("\n1. Noisy Input:")
ipd.display(ipd.Audio(noisy_np, rate=CONFIG['data']['sample_rate']))

print("\n2. Denoised Output:")
ipd.display(ipd.Audio(denoised_np, rate=CONFIG['data']['sample_rate']))

print("\n3. Clean Reference:")
ipd.display(ipd.Audio(clean_np, rate=CONFIG['data']['sample_rate']))

In [None]:
# Visualize spectrograms
import librosa
import librosa.display

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

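for ax, (audio, title) in zip(axes, [(noisy_np, 'Noisy'), (denoised_np, 'Denoised'), (clean_np, 'Clean')]):
    # Use the same n_fft/hop_length for the STFT and the display so the time axis is correct
    D = librosa.amplitude_to_db(
        np.abs(librosa.stft(audio, n_fft=CONFIG['stft']['n_fft'], hop_length=CONFIG['stft']['hop_length'])),
        ref=np.max)
    librosa.display.specshow(D, sr=CONFIG['data']['sample_rate'], hop_length=CONFIG['stft']['hop_length'],
                            x_axis='time', y_axis='hz', ax=ax)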
    ax.set_title(title)
    ax.set_ylim(0, 8000)

plt.tight_layout()
plt.savefig('spectrogram_comparison.png', dpi=150)
plt.show()

print("📊 Spectrogram comparison saved!")

## 7️⃣ Save Model to Google Drive

In [None]:
# Save model to Google Drive
import shutil

# Create the model save directory on Google Drive
GDRIVE_MODEL_SAVE_PATH = "/content/drive/MyDrive/speech_denoising_models"
save_path = Path(GDRIVE_MODEL_SAVE_PATH)
save_path.mkdir(parents=True, exist_ok=True)

# Copy best model
shutil.copy(ckpt_dir / 'best_model.pt', save_path / 'best_model.pt')

# Copy training history
if Path('training_history.png').exists():
    shutil.copy('training_history.png', save_path / 'training_history.png')
if Path('spectrogram_comparison.png').exists():
    shutil.copy('spectrogram_comparison.png', save_path / 'spectrogram_comparison.png')

print(f"✅ Model saved to Google Drive: {save_path}")
print("   Files saved:")
for f in save_path.iterdir():
    print(f"   - {f.name}")

In [None]:
# Optional: Download to local machine
from google.colab import files

print("📥 Downloading trained model...")
files.download(str(ckpt_dir / 'best_model.pt'))
print("\n✅ Download started! Check your browser downloads.")

---

## 📝 Notes

### Usage Instructions
1. Upload the dataset to Google Drive with the correct folder structure
2. Edit `GDRIVE_DATASET_FOLDER` if your folder name differs
3. Run each cell from top to bottom
4. The model is saved to Google Drive once training finishes

### Training Tips
- **Time**: ~1-2 hours on a Colab GPU (T4) for 50 epochs
- **Memory**: The model uses ~4-6 GB of GPU memory with batch size 8
- **More epochs**: Set `num_epochs` to 100 for better results
- Check GPU: `!nvidia-smi`

### After Training
- The model is saved at `./checkpoints/best_model.pt`
- Section 7 copies the model to Google Drive for storage
- You can also download the model to your local machine