In [None]:
"""
Voice Packet Reconstruction - GAN VERSION
Sharp, realistic reconstruction using GAN + L1 loss
Architecture:
- Generator: Transformer (existing model)
- Discriminator: PatchGAN (judges local patches for realism)
- Loss: L1 + Adversarial + Feature Matching

This fixes MSE blurring and produces sharp, realistic spectrograms!
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from pathlib import Path
import json
from tqdm import tqdm
import urllib.request
from IPython.display import Audio
import tarfile
import warnings
warnings.filterwarnings('ignore')

In [None]:
CONFIG = {
    'data_dir': '/kaggle/working/audio_data',
    'ljspeech_url': 'https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2',
    'checkpoint_dir': '/kaggle/working/checkpoints',
    'output_dir': '/kaggle/working/outputs',
    'input_model_path': '/kaggle/input/holomodel/pytorch/default/3/my_folder/best_model_gan.pt',
    'batch_size': 16,
    'num_epochs': 10,
    'learning_rate_g': 2e-4,
    'learning_rate_d': 5e-5,  # MUCH slower discriminator learning
    'packet_loss_min': 0.1,
    'packet_loss_max': 0.5,
    'num_files_to_use': 1000,
    'weight_decay': 0.01,
    'early_stopping_patience': 10,
    'lambda_l1': 100.0,
    'lambda_adv': 0.1,    # REDUCED from 1.0 - discriminator too strong
    'lambda_fm': 10.0,
    # Label smoothing
    'real_label': 0.9,    # Smooth real labels (instead of 1.0)
    'fake_label': 0.0,    # Keep fake labels at 0
}

In [None]:
# Generator
class PacketLossReconstructor(nn.Module):
    """
    Generator: Transformer-based reconstruction model.
    Same as before i.e., U-Net style with skip connections.
    """
    def __init__(self, n_mels=128, hidden_dim=512, num_layers=6, dropout=0.3, max_len=500):
        super().__init__()
        
        self.n_mels = n_mels
        self.max_len = max_len
        
        self.pos_encoding = nn.Parameter(torch.randn(1, max_len, n_mels) * 0.02)
        
        # Encoder
        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=n_mels,
                nhead=8,
                dim_feedforward=hidden_dim,
                dropout=dropout,
                batch_first=True,
                activation='gelu'
            ) for _ in range(num_layers)
        ])
        
        # Decoder with skip connections
        self.decoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=n_mels * 2,
                nhead=8,
                dim_feedforward=hidden_dim,
                dropout=dropout,
                batch_first=True,
                activation='gelu'
            ) for _ in range(num_layers)
        ])
        
        self.skip_projections = nn.ModuleList([
            nn.Linear(n_mels * 2, n_mels) for _ in range(num_layers)
        ])
        
        self.output = nn.Sequential(
            nn.Linear(n_mels, n_mels * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(n_mels * 2, n_mels)
        )
        
    def forward(self, x, mask=None):
        seq_len = x.size(1)
        
        if seq_len > self.max_len:
            x = x[:, :self.max_len, :]
            if mask is not None:
                mask = mask[:, :self.max_len]
            seq_len = self.max_len
        
        x = x + self.pos_encoding[:, :seq_len, :]
        
        # Encode with skip connections
        skip_connections = []
        for encoder_layer in self.encoder_layers:
            x = encoder_layer(x)
            skip_connections.append(x)
        
        # Decode with skip connections
        for decoder_layer, skip_proj, skip in zip(
            self.decoder_layers, self.skip_projections, reversed(skip_connections)
        ):
            x = torch.cat([x, skip], dim=-1)
            x = decoder_layer(x)
            x = skip_proj(x)
        
        x = self.output(x)
        return x


# Discriminator
class PatchGANDiscriminator(nn.Module):
    """
    PatchGAN discriminator for spectrograms.
    Judges whether local patches are real or fake.
    Also returns intermediate features for feature matching loss.
    """
    def __init__(self, n_mels=128, ndf=64):
        super().__init__()

        # Conv layers with spectral normalization for stability
        self.conv1 = nn.utils.spectral_norm(
            nn.Conv2d(2, ndf, kernel_size=4, stride=2, padding=1)  # 2 channels: corrupted + reconstruction
        )
        
        self.conv2 = nn.utils.spectral_norm(
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)
        )
        self.bn2 = nn.BatchNorm2d(ndf * 2)
        
        self.conv3 = nn.utils.spectral_norm(
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)
        )
        self.bn3 = nn.BatchNorm2d(ndf * 4)
        
        self.conv4 = nn.utils.spectral_norm(
            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=1, padding=1)
        )
        self.bn4 = nn.BatchNorm2d(ndf * 8)
        
        # Output: probability map (PatchGAN)
        self.conv5 = nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=1)
        
        self.leaky_relu = nn.LeakyReLU(0.2)
        
    def forward(self, corrupted, reconstruction):
        """
        corrupted: (batch, seq_len, n_mels)
        reconstruction: (batch, seq_len, n_mels)
        
        Returns:
        - output: (batch, h, w) probability map
        - features: list of intermediate features for feature matching
        """
        # Stack corrupted and reconstruction as channels
        # (batch, seq_len, n_mels) -> (batch, 1, seq_len, n_mels)
        corrupted = corrupted.unsqueeze(1)
        reconstruction = reconstruction.unsqueeze(1)
        
        # Concatenate along channel dimension
        x = torch.cat([corrupted, reconstruction], dim=1)  # (batch, 2, seq_len, n_mels)
        
        features = []
        
        # Layer 1
        x = self.conv1(x)
        x = self.leaky_relu(x)
        features.append(x)
        
        # Layer 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.leaky_relu(x)
        features.append(x)
        
        # Layer 3
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.leaky_relu(x)
        features.append(x)
        
        # Layer 4
        x = self.conv4(x)
        x = self.bn4(x)
        x = self.leaky_relu(x)
        features.append(x)
        
        # Output
        x = self.conv5(x)
        
        return x, features

In [None]:
# Dataset
class PacketLossDataset(Dataset):
    """Simulates actual packet loss (dropped time frames)."""
    
    def __init__(self, audio_dir, sample_rate=16000, n_mels=128, 
                 segment_length=128, loss_min=0.1, loss_max=0.5):
        self.audio_dir = Path(audio_dir)
        self.audio_files = list(self.audio_dir.glob("*.wav"))
        
        if len(self.audio_files) == 0:
            raise ValueError(f"No audio files found in {audio_dir}")
        
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.segment_length = segment_length
        self.loss_min = loss_min
        self.loss_max = loss_max
        
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_fft=1024, hop_length=256, n_mels=n_mels
        )
        
    def __len__(self):
        return len(self.audio_files) * 15
    
    def __getitem__(self, idx):
        file_idx = idx % len(self.audio_files)
        audio_file = self.audio_files[file_idx]
        
        waveform, sr = torchaudio.load(audio_file)
        
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = resampler(waveform)
        
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        
        mel_spec = self.mel_transform(waveform)
        mel_spec = mel_spec.squeeze(0).t()
        mel_spec = torch.log(mel_spec + 1e-9)
        mel_spec = (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-8)
        mel_spec = torch.clamp(mel_spec, -3, 3)
        
        if mel_spec.shape[0] > self.segment_length:
            start_idx = np.random.randint(0, mel_spec.shape[0] - self.segment_length)
            mel_spec = mel_spec[start_idx:start_idx + self.segment_length]
        else:
            pad_length = self.segment_length - mel_spec.shape[0]
            mel_spec = F.pad(mel_spec, (0, 0, 0, pad_length))
        
        clean = mel_spec.clone()
        
        loss_rate = np.random.uniform(self.loss_min, self.loss_max)
        mask = torch.rand(self.segment_length) < loss_rate
        
        corrupted = clean.clone()
        corrupted[mask] = 0
        
        return {
            'corrupted': corrupted,
            'clean': clean,
            'mask': mask,
            'loss_rate': loss_rate
        }


# GAN Loss Functions
def adversarial_loss_g(discriminator_output, real_label=0.9):
    """Generator wants discriminator to output real_label (smoothed)."""
    return F.binary_cross_entropy_with_logits(
        discriminator_output, 
        torch.full_like(discriminator_output, real_label)
    )


def adversarial_loss_d(real_output, fake_output, real_label=0.9, fake_label=0.0):
    """Discriminator with one-sided label smoothing."""
    real_loss = F.binary_cross_entropy_with_logits(
        real_output, 
        torch.full_like(real_output, real_label)  # Smoothed real labels
    )
    fake_loss = F.binary_cross_entropy_with_logits(
        fake_output, 
        torch.full_like(fake_output, fake_label)  # Hard fake labels
    )
    return (real_loss + fake_loss) / 2


def feature_matching_loss(real_features, fake_features):
    """Match intermediate discriminator features."""
    loss = 0
    for real_feat, fake_feat in zip(real_features, fake_features):
        loss += F.l1_loss(fake_feat, real_feat.detach())
    return loss / len(real_features)


# Training
def train_epoch_gan(generator, discriminator, dataloader, optimizer_g, optimizer_d, device, config):
    """Train for one epoch with GAN losses."""
    generator.train()
    discriminator.train()
    
    total_g_loss = 0
    total_d_loss = 0
    total_l1_loss = 0
    total_adv_loss_g = 0
    total_fm_loss = 0
    
    pbar = tqdm(dataloader, desc="Training GAN")
    for i, batch in enumerate(pbar):
        corrupted = batch['corrupted'].to(device)
        clean = batch['clean'].to(device)
        mask = batch['mask'].to(device)
        
        batch_size = corrupted.size(0)
        
        # Train Discriminator (every 2 iterations to give generator a chance)
        if i % 2 == 0:  # Train D every other iteration
            optimizer_d.zero_grad()
            
            # Generate fake
            with torch.no_grad():
                fake = generator(corrupted, mask)
            
            # Discriminator on real (with label smoothing)
            real_output, _ = discriminator(corrupted, clean)
            
            # Discriminator on fake
            fake_output, _ = discriminator(corrupted, fake.detach())
            
            # Discriminator loss with label smoothing
            d_loss = adversarial_loss_d(
                real_output, fake_output, 
                real_label=config['real_label'], 
                fake_label=config['fake_label']
            )
            
            d_loss.backward()
            optimizer_d.step()
            
            total_d_loss += d_loss.item()
        
        # Train Generator (every iteration)
        optimizer_g.zero_grad()
        
        # Generate fake
        fake = generator(corrupted, mask)
        
        # L1 loss
        l1_loss = F.l1_loss(fake, clean)
        
        # Adversarial loss
        fake_output, fake_features = discriminator(corrupted, fake)
        adv_loss_g = adversarial_loss_g(fake_output, real_label=config['real_label'])
        
        # Feature matching loss
        _, real_features = discriminator(corrupted, clean)
        fm_loss = feature_matching_loss(real_features, fake_features)
        
        # Total generator loss
        g_loss = (
            config['lambda_l1'] * l1_loss +
            config['lambda_adv'] * adv_loss_g +
            config['lambda_fm'] * fm_loss
        )
        
        g_loss.backward()
        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
        optimizer_g.step()
        
        # Accumulate losses
        total_g_loss += g_loss.item()
        total_l1_loss += l1_loss.item()
        total_adv_loss_g += adv_loss_g.item()
        total_fm_loss += fm_loss.item()
        
        pbar.set_postfix({
            'G': f'{g_loss.item():.3f}',
            'D': f'{d_loss.item() if i % 2 == 0 else 0:.3f}',
            'L1': f'{l1_loss.item():.3f}',
        })
    
    # Average over full dataloader length for G, but only updates for D
    num_d_updates = len(dataloader) // 2
    
    return {
        'g_loss': total_g_loss / len(dataloader),
        'd_loss': total_d_loss / max(num_d_updates, 1),
        'l1_loss': total_l1_loss / len(dataloader),
        'adv_loss_g': total_adv_loss_g / len(dataloader),
        'fm_loss': total_fm_loss / len(dataloader),
    }


def validate_gan(generator, dataloader, device):
    """Validate generator."""
    generator.eval()
    total_loss = 0
    total_l1 = 0
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            corrupted = batch['corrupted'].to(device)
            clean = batch['clean'].to(device)
            mask = batch['mask'].to(device)
            
            output = generator(corrupted, mask)
            
            l1 = F.l1_loss(output, clean)
            total_loss += l1.item()
            
            # Also compute masked L1
            if mask.any():
                masked_l1 = F.l1_loss(output[mask.unsqueeze(-1).expand_as(output)], 
                                     clean[mask.unsqueeze(-1).expand_as(clean)])
                total_l1 += masked_l1.item()
    
    return total_loss / len(dataloader), total_l1 / len(dataloader)


def load_model_weights(generator, discriminator, optimizer_g, optimizer_d, device):
    checkpoint_path = CONFIG['input_model_path']
    print(f"Loading checkpoint from: {checkpoint_path}")
    
    # Load to CPU first to avoid GPU OOM if mapping is weird, then map to device. Alternatively restart kernel if issues persist.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Load model weights
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

    # Load optimizer states
    optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
    optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])

    # Resume epoch and val_loss
    start_epoch = checkpoint.get('epoch', 0) + 1
    
    # Handle case where val_loss might be None or missing in old checkpoints
    val_loss = checkpoint.get('val_loss', float('inf'))

    print(f"Resuming training from epoch {start_epoch}, previous val_loss={val_loss}")

    return start_epoch, val_loss

In [None]:
# Data preparation
class DownloadProgressBar(tqdm):
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)


def download_ljspeech(data_dir, url, num_files=1000):
    data_dir = Path(data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    
    print("="*60)
    print("Downloading LJSpeech Dataset")
    print("="*60)
    
    tar_path = data_dir.parent / "LJSpeech-1.1.tar.bz2"
    extract_dir = data_dir.parent / "LJSpeech-1.1"
    
    if (extract_dir / "wavs").exists() and len(list((extract_dir / "wavs").glob("*.wav"))) > 0:
        print(f"✓ Dataset already exists")
        wav_files = sorted(list((extract_dir / "wavs").glob("*.wav")))[:num_files]
        
        import shutil
        print(f"Copying {len(wav_files)} files...")
        for i, wav_file in enumerate(wav_files):
            shutil.copy(wav_file, data_dir / wav_file.name)
            if (i + 1) % 100 == 0:
                print(f"  Copied {i + 1}/{len(wav_files)}")
        
        print(f"✓ Ready: {len(wav_files)} files")
        return
    
    if not tar_path.exists():
        print(f"\nDownloading (~2.6 GB)...")
        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1) as t:
            urllib.request.urlretrieve(url, filename=tar_path, reporthook=t.update_to)
    
    print(f"\nExtracting...")
    with tarfile.open(tar_path, 'r:bz2') as tar:
        tar.extractall(data_dir.parent)
    
    wav_files = sorted(list((extract_dir / "wavs").glob("*.wav")))[:num_files]
    
    import shutil
    for i, wav_file in enumerate(wav_files):
        shutil.copy(wav_file, data_dir / wav_file.name)
        if (i + 1) % 100 == 0:
            print(f"  Copied {i + 1}/{len(wav_files)}")
    
    tar_path.unlink()
    print(f"\n✓ Ready: {len(wav_files)} files")

# Inference (Same as before but now uses generator)
def reconstruct_and_visualize(generator, audio_path, loss_rate, device, output_dir):
    generator.eval()
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    waveform, sr = torchaudio.load(audio_path)
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, n_fft=1024, hop_length=256, n_mels=128
    )
    
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(sr, 16000)
        waveform = resampler(waveform)
    
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    
    mel_spec = mel_transform(waveform).squeeze(0).t()
    mel_spec = torch.log(mel_spec + 1e-9)
    
    mel_mean = mel_spec.mean()
    mel_std = mel_spec.std()
    
    mel_spec = (mel_spec - mel_mean) / (mel_std + 1e-8)
    mel_spec = torch.clamp(mel_spec, -3, 3)
    
    max_len = 500
    if mel_spec.shape[0] > max_len:
        mel_spec = mel_spec[:max_len]
    
    clean = mel_spec.clone()
    
    mask = torch.rand(mel_spec.shape[0]) < loss_rate
    corrupted = clean.clone()
    corrupted[mask] = 0
    
    with torch.no_grad():
        corrupted_input = corrupted.unsqueeze(0).to(device)
        mask_input = mask.unsqueeze(0).to(device)
        reconstructed = generator(corrupted_input, mask_input)
        reconstructed = reconstructed.squeeze(0).cpu()
    
    # Convert to audio using Griffin-Lim
    inverse_mel = torchaudio.transforms.InverseMelScale(n_stft=513, n_mels=80, sample_rate=16000)
    # griffin_lim = torchaudio.transforms.GriffinLim(n_fft=1024, hop_length=256, n_iter=64)
    griffin_lim = torchaudio.transforms.GriffinLim(n_fft=1024, hop_length=256, n_iter=128)  # Increase from 64 to 128 for better audio quality, but still need a vocoder
    
    def mel_to_audio(mel, mel_mean, mel_std):
        mel = mel * mel_std + mel_mean
        mel = torch.exp(mel) - 1e-9
        mel = mel.clamp(min=0).t().unsqueeze(0)
        linear = inverse_mel(mel)
        audio = griffin_lim(linear)
        return audio
    
    clean_audio = mel_to_audio(clean, mel_mean, mel_std)
    corrupted_audio = mel_to_audio(corrupted, mel_mean, mel_std)
    reconstructed_audio = mel_to_audio(reconstructed, mel_mean, mel_std)
    
    base_name = Path(audio_path).stem
    loss_pct = int(loss_rate * 100)
    
    torchaudio.save(str(output_dir / f"{base_name}_clean_gan.wav"), clean_audio, 16000)
    torchaudio.save(str(output_dir / f"{base_name}_corrupted_{loss_pct}pct_gan.wav"), corrupted_audio, 16000)
    torchaudio.save(str(output_dir / f"{base_name}_reconstructed_{loss_pct}pct_gan.wav"), reconstructed_audio, 16000)
    
    # Metrics
    if mask.sum() > 0:
        l1_corrupted = F.l1_loss(corrupted[mask], clean[mask]).item()
        l1_reconstructed = F.l1_loss(reconstructed[mask], clean[mask]).item()
        improvement = ((l1_corrupted - l1_reconstructed) / l1_corrupted) * 100
    else:
        l1_corrupted = 0.0
        l1_reconstructed = 0.0
        improvement = 0.0
    
    l1_full = F.l1_loss(reconstructed, clean).item()
    signal_power = (clean ** 2).mean().item()
    noise_power = ((reconstructed - clean) ** 2).mean().item()
    snr = 10 * np.log10(signal_power / (noise_power + 1e-8))
    
    metrics = {
        'l1_corrupted_masked': float(l1_corrupted),
        'l1_reconstructed_masked': float(l1_reconstructed),
        'improvement_pct': float(improvement),
        'l1_full': float(l1_full),
        'snr_db': float(snr),
        'loss_rate': loss_rate,
        'packets_lost': int(mask.sum()),
        'total_packets': len(mask),
        'model': 'GAN'
    }
    
    with open(output_dir / f"{base_name}_metrics_{loss_pct}pct_gan.json", 'w') as f:
        json.dump(metrics, f, indent=2)
    
    # Visualization
    fig, axes = plt.subplots(4, 1, figsize=(14, 12))
    
    axes[0].imshow(clean.numpy().T, aspect='auto', origin='lower', cmap='viridis', vmin=-2, vmax=2)
    axes[0].set_title('Clean Audio', fontsize=14, fontweight='bold')
    
    axes[1].imshow(corrupted.numpy().T, aspect='auto', origin='lower', cmap='viridis', vmin=-2, vmax=2)
    axes[1].set_title(f'Corrupted ({loss_pct}% Packet Loss)', fontsize=14, fontweight='bold')
    for i, is_lost in enumerate(mask):
        if is_lost:
            axes[1].axvline(x=i, color='red', alpha=0.3, linewidth=0.5)
    
    axes[2].imshow(reconstructed.numpy().T, aspect='auto', origin='lower', cmap='viridis', vmin=-2, vmax=2)
    axes[2].set_title(f'Reconstructed (GAN) - SNR: {snr:.1f} dB', fontsize=14, fontweight='bold')
    
    error = torch.abs(reconstructed - clean)
    im = axes[3].imshow(error.numpy().T, aspect='auto', origin='lower', cmap='hot')
    axes[3].set_title(f'Error (Improvement: {improvement:.1f}%)', fontsize=14, fontweight='bold')
    axes[3].set_xlabel('Time Frames')
    plt.colorbar(im, ax=axes[3])
    
    for ax in axes:
        ax.set_ylabel('Mel Bins')
    
    plt.tight_layout()
    plt.savefig(output_dir / f"{base_name}_comparison_{loss_pct}pct_gan.png", dpi=150, bbox_inches='tight')
    plt.close()
    
    print(f"\n✓ {base_name} @ {loss_pct}% loss")
    print(f"  L1 (masked): {l1_corrupted:.4f} → {l1_reconstructed:.4f}")
    print(f"  Improvement: {improvement:.1f}%")
    print(f"  SNR: {snr:.1f} dB")
    
    return metrics

In [None]:
def main(resume_training=False):
    print("="*60)
    print("Voice Packet Reconstruction - GAN VERSION")
    print("Sharp, Realistic Reconstruction!")
    print("="*60)
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nDevice: {device}")
    
    # Download data
    print("\n" + "="*60)
    print("STEP 1: Data")
    print("="*60)
    download_ljspeech(CONFIG['data_dir'], CONFIG['ljspeech_url'], CONFIG['num_files_to_use'])
    
    # Dataset
    print("\n" + "="*60)
    print("STEP 2: Dataset")
    print("="*60)
    full_dataset = PacketLossDataset(
        audio_dir=CONFIG['data_dir'],
        loss_min=CONFIG['packet_loss_min'],
        loss_max=CONFIG['packet_loss_max']
    )
    
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])
    
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=2)
    
    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")

    start_epoch = 0
    
    # Models
    print("\n" + "="*60)
    print("STEP 3: Models")
    print("="*60)
    
    generator = PacketLossReconstructor(
        n_mels=128, hidden_dim=512, num_layers=6, dropout=0.3, max_len=500
    ).to(device)
    
    discriminator = PatchGANDiscriminator(
        n_mels=128, ndf=64
    ).to(device)
    
    num_params_g = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    num_params_d = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    
    print(f"Generator parameters: {num_params_g:,}")
    print(f"Discriminator parameters: {num_params_d:,}")
    print(f"Loss: L1 + Adversarial + Feature Matching")
    
    # Optimizers
    optimizer_g = torch.optim.Adam(
        generator.parameters(), 
        lr=CONFIG['learning_rate_g'], 
        betas=(0.5, 0.999)
    )
    optimizer_d = torch.optim.Adam(
        discriminator.parameters(), 
        lr=CONFIG['learning_rate_d'], 
        betas=(0.5, 0.999)
    )
    
    # Training
    print("\n" + "="*60)
    print("STEP 4: Training GAN")
    print("="*60)
    
    Path(CONFIG['checkpoint_dir']).mkdir(parents=True, exist_ok=True)
    best_val_loss = float('inf')

    if resume_training:
        print("\nResuming from checkpoint...")
        # Pass the instances created above into the function
        start_epoch, resume_loss = load_model_weights(
            generator, discriminator, optimizer_g, optimizer_d, device
        )
        if resume_loss is not None:
            best_val_loss = resume_loss

    # Schedulers
    scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, mode='min', factor=0.5, patience=3)
    scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, mode='min', factor=0.5, patience=3)

    patience_counter = 0
    history = {'g_loss': [], 'd_loss': [], 'val_loss': []}

    if start_epoch >= CONFIG['num_epochs']:
        print(f"Training already completed ({start_epoch} epochs). Increase num_epochs to continue.")
        return
    
    # for epoch in range(CONFIG['num_epochs']):
    for epoch in range(start_epoch, CONFIG['num_epochs']):
        print(f"\nEpoch {epoch + 1}/{CONFIG['num_epochs']}")
        
        train_metrics = train_epoch_gan(
            generator, discriminator, train_loader, 
            optimizer_g, optimizer_d, device, CONFIG
        )
        
        val_loss, val_masked = validate_gan(generator, val_loader, device)
        
        scheduler_g.step(val_loss)
        scheduler_d.step(val_loss)
        
        print(f"Generator Loss: {train_metrics['g_loss']:.4f} "
              f"(L1: {train_metrics['l1_loss']:.4f}, "
              f"Adv: {train_metrics['adv_loss_g']:.4f}, "
              f"FM: {train_metrics['fm_loss']:.4f})")
        print(f"Discriminator Loss: {train_metrics['d_loss']:.4f}")
        print(f"Val Loss: {val_loss:.4f} (Masked: {val_masked:.4f})")
        
        history['g_loss'].append(train_metrics['g_loss'])
        history['d_loss'].append(train_metrics['d_loss'])
        history['val_loss'].append(val_loss)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_g_state_dict': optimizer_g.state_dict(),
                'optimizer_d_state_dict': optimizer_d.state_dict(),
                'val_loss': val_loss,
            }, Path(CONFIG['checkpoint_dir']) / 'best_model_gan.pt')
            print("✓ Saved!")
        else:
            patience_counter += 1
            if patience_counter >= CONFIG['early_stopping_patience']:
                print(f"\nEarly stop at epoch {epoch + 1}")
                break
    
    # Save history
    with open(Path(CONFIG['checkpoint_dir']) / 'training_history_gan.json', 'w') as f:
        json.dump(history, f, indent=2)
    
    # Plot training
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    ax1.plot(history['g_loss'], label='Generator')
    ax1.plot(history['d_loss'], label='Discriminator')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Losses')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    ax2.plot(history['val_loss'], label='Validation L1')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss')
    ax2.set_title('Validation Loss')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(Path(CONFIG['checkpoint_dir']) / 'training_curves_gan.png', dpi=150)
    plt.close()
    
    # Inference
    print("\n" + "="*60)
    print("STEP 5: Inference")
    print("="*60)
    
    test_files = sorted(list(Path(CONFIG['data_dir']).glob("*.wav")))[:5]
    test_loss_rates = [0.2, 0.3, 0.4, 0.5]
    
    for audio_file in test_files:
        for loss_rate in test_loss_rates:
            reconstruct_and_visualize(generator, str(audio_file), loss_rate, device, CONFIG['output_dir'])
    
    print("\n" + "="*60)
    print("✓ DONE!")
    print(f"Best val loss: {best_val_loss:.4f}")
    print(f"Outputs: {CONFIG['output_dir']}")
    print("="*60)

In [43]:
if __name__ == '__main__':
    main(resume_training=False)

Voice Packet Reconstruction - GAN VERSION
Sharp, Realistic Reconstruction!

Device: cuda

STEP 1: Data
Downloading LJSpeech Dataset
✓ Dataset already exists
Copying 1000 files...
  Copied 100/1000
  Copied 200/1000
  Copied 300/1000
  Copied 400/1000
  Copied 500/1000
  Copied 600/1000
  Copied 700/1000
  Copied 800/1000
  Copied 900/1000
  Copied 1000/1000
✓ Ready: 1000 files

STEP 2: Dataset
Train: 13500, Val: 1500

STEP 3: Models
Generator parameters: 2,404,784
Discriminator parameters: 2,765,505
Loss: L1 + Adversarial + Feature Matching

STEP 4: Training GAN

Epoch 1/10


Training GAN: 100%|██████████| 844/844 [02:34<00:00,  5.47it/s, G=41.822, D=0.000, L1=0.380]
Validation: 100%|██████████| 94/94 [00:16<00:00,  5.55it/s]


Generator Loss: 46.3991 (L1: 0.4259, Adv: 4.0032, FM: 0.3408)
Discriminator Loss: 0.2164
Val Loss: 0.3244 (Masked: 0.7062)
✓ Saved!

Epoch 2/10


Training GAN: 100%|██████████| 844/844 [02:30<00:00,  5.62it/s, G=37.854, D=0.000, L1=0.340]
Validation: 100%|██████████| 94/94 [00:17<00:00,  5.50it/s]


Generator Loss: 40.4599 (L1: 0.3640, Adv: 6.0977, FM: 0.3449)
Discriminator Loss: 0.1651
Val Loss: 0.3104 (Masked: 0.7093)
✓ Saved!

Epoch 3/10


Training GAN: 100%|██████████| 844/844 [02:31<00:00,  5.56it/s, G=38.803, D=0.000, L1=0.346]
Validation: 100%|██████████| 94/94 [00:17<00:00,  5.50it/s]


Generator Loss: 39.7091 (L1: 0.3554, Adv: 6.9058, FM: 0.3475)
Discriminator Loss: 0.1637
Val Loss: 0.3032 (Masked: 0.7079)
✓ Saved!

Epoch 4/10


Training GAN: 100%|██████████| 844/844 [02:28<00:00,  5.70it/s, G=40.722, D=0.000, L1=0.363]
Validation: 100%|██████████| 94/94 [00:16<00:00,  5.67it/s]


Generator Loss: 39.3887 (L1: 0.3516, Adv: 7.4033, FM: 0.3490)
Discriminator Loss: 0.1633
Val Loss: 0.3018 (Masked: 0.7084)
✓ Saved!

Epoch 5/10


Training GAN: 100%|██████████| 844/844 [02:31<00:00,  5.57it/s, G=34.839, D=0.000, L1=0.307]
Validation: 100%|██████████| 94/94 [00:16<00:00,  5.56it/s]


Generator Loss: 39.0139 (L1: 0.3473, Adv: 7.8269, FM: 0.3505)
Discriminator Loss: 0.1631
Val Loss: 0.2942 (Masked: 0.7082)
✓ Saved!

Epoch 6/10


Training GAN: 100%|██████████| 844/844 [02:31<00:00,  5.58it/s, G=42.492, D=0.000, L1=0.379]
Validation: 100%|██████████| 94/94 [00:17<00:00,  5.45it/s]


Generator Loss: 38.7053 (L1: 0.3437, Adv: 8.1797, FM: 0.3520)
Discriminator Loss: 0.1629
Val Loss: 0.2918 (Masked: 0.7077)
✓ Saved!

Epoch 7/10


Training GAN: 100%|██████████| 844/844 [02:33<00:00,  5.50it/s, G=34.505, D=0.000, L1=0.303]
Validation: 100%|██████████| 94/94 [00:17<00:00,  5.52it/s]


Generator Loss: 38.4618 (L1: 0.3410, Adv: 8.4388, FM: 0.3515)
Discriminator Loss: 0.1628
Val Loss: 0.2933 (Masked: 0.7078)

Epoch 8/10


Training GAN: 100%|██████████| 844/844 [02:34<00:00,  5.47it/s, G=34.551, D=0.000, L1=0.303]
Validation: 100%|██████████| 94/94 [00:17<00:00,  5.53it/s]


Generator Loss: 38.1072 (L1: 0.3377, Adv: 8.5603, FM: 0.3486)
Discriminator Loss: 0.1628
Val Loss: 0.2792 (Masked: 0.6710)
✓ Saved!

Epoch 9/10


Training GAN: 100%|██████████| 844/844 [02:31<00:00,  5.58it/s, G=33.535, D=0.000, L1=0.303]
Validation: 100%|██████████| 94/94 [00:16<00:00,  5.58it/s]


Generator Loss: 34.9293 (L1: 0.3189, Adv: 5.8613, FM: 0.2458)
Discriminator Loss: 0.2052
Val Loss: 0.2502 (Masked: 0.5811)
✓ Saved!

Epoch 10/10


Training GAN: 100%|██████████| 844/844 [02:32<00:00,  5.53it/s, G=34.069, D=0.000, L1=0.307]
Validation: 100%|██████████| 94/94 [00:17<00:00,  5.52it/s]


Generator Loss: 33.1902 (L1: 0.2993, Adv: 7.3668, FM: 0.2524)
Discriminator Loss: 0.1634
Val Loss: 0.2330 (Masked: 0.5131)
✓ Saved!

STEP 5: Inference

✓ LJ001-0001 @ 20% loss
  L1 (masked): 0.8596 → 0.7512
  Improvement: 12.6%
  SNR: 7.2 dB

✓ LJ001-0001 @ 30% loss
  L1 (masked): 0.8534 → 0.7464
  Improvement: 12.5%
  SNR: 5.4 dB

✓ LJ001-0001 @ 40% loss
  L1 (masked): 0.7999 → 0.6874
  Improvement: 14.1%
  SNR: 5.1 dB

✓ LJ001-0001 @ 50% loss
  L1 (masked): 0.8408 → 0.7297
  Improvement: 13.2%
  SNR: 4.1 dB

✓ LJ001-0002 @ 20% loss
  L1 (masked): 0.7979 → 0.4312
  Improvement: 46.0%
  SNR: 10.5 dB

✓ LJ001-0002 @ 30% loss
  L1 (masked): 0.8033 → 0.5049
  Improvement: 37.2%
  SNR: 8.4 dB

✓ LJ001-0002 @ 40% loss
  L1 (masked): 0.8083 → 0.4614
  Improvement: 42.9%
  SNR: 8.4 dB

✓ LJ001-0002 @ 50% loss
  L1 (masked): 0.8103 → 0.4954
  Improvement: 38.9%
  SNR: 7.2 dB

✓ LJ001-0003 @ 20% loss
  L1 (masked): 0.7745 → 0.6568
  Improvement: 15.2%
  SNR: 8.0 dB

✓ LJ001-0003 @ 30% loss
  L1

Another inference

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

generator = PacketLossReconstructor(
    n_mels=128, hidden_dim=512, num_layers=6, dropout=0.3, max_len=500
).to(device)
    
discriminator = PatchGANDiscriminator(
    n_mels=128, ndf=64
).to(device)
    
num_params_g = sum(p.numel() for p in generator.parameters() if p.requires_grad)
num_params_d = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    
print(f"Generator parameters: {num_params_g:,}")
print(f"Discriminator parameters: {num_params_d:,}")
print(f"Loss: L1 + Adversarial + Feature Matching")
    
# Optimizers
optimizer_g = torch.optim.Adam(
    generator.parameters(), 
    lr=CONFIG['learning_rate_g'], 
    betas=(0.5, 0.999)
)
optimizer_d = torch.optim.Adam(
    discriminator.parameters(), 
    lr=CONFIG['learning_rate_d'], 
    betas=(0.5, 0.999)
)


checkpoint_path = "/kaggle/working/checkpoints/best_model_gan.pt"
print(f"Loading checkpoint from: {checkpoint_path}")
    
# Load to CPU first to avoid GPU OOM if mapping is weird, then map to device
checkpoint = torch.load(checkpoint_path, map_location=device)

# Load model weights
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

# Load optimizer states
optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])
    
print("\nResuming from checkpoint...")

Generator parameters: 2,404,784
Discriminator parameters: 2,765,505
Loss: L1 + Adversarial + Feature Matching
Loading checkpoint from: /kaggle/working/checkpoints/best_model_gan.pt

Resuming from checkpoint...


In [None]:
def load_hifigan(device):
    """Load HiFi-GAN - should work with 80 mels"""
    try:
        print("Loading HiFi-GAN vocoder...")
        loaded_obj = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_hifigan')
        
        if isinstance(loaded_obj, tuple):
            vocoder = loaded_obj[0]
        else:
            vocoder = loaded_obj
            
        vocoder = vocoder.to(device).eval()
        print("✓ HiFi-GAN loaded successfully!")
        return vocoder, 'hifigan'
        
    except Exception as e:
        print(f"⚠ Failed: {e}")
        return None, 'griffin-lim'


def mel_to_audio_hifigan(mel_spec, vocoder, device, original_mean, original_std):
    """
    Convert 80-mel spectrogram to audio using HiFi-GAN
    """
    # Denormalize
    mel = mel_spec * original_std + original_mean
    mel = torch.exp(mel) - 1e-9
    mel = mel.clamp(min=0)
    
    # HiFi-GAN expects (batch, 80, time) - we have (seq_len, 80)
    mel = mel.t().unsqueeze(0).to(device)  # (1, 80, seq_len)
    
    # Generate audio
    with torch.no_grad():
        audio = vocoder(mel).squeeze(1)
    
    return audio.cpu()


def reconstruct_and_visualize_80mel(model, audio_path, loss_rate, device, output_dir, 
                                   vocoder=None, vocoder_type='griffin-lim'):
    """
    Inference function for 80-mel model with HiFi-GAN support
    """
    import torchaudio
    import matplotlib.pyplot as plt
    from pathlib import Path
    import json
    import torch.nn.functional as F
    import numpy as np
    
    model.eval()
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Load audio
    waveform, sr = torchaudio.load(audio_path)
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, n_fft=1024, hop_length=256, n_mels=80  # 80 mels!
    )
    
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(sr, 16000)
        waveform = resampler(waveform)
    
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    
    mel_spec = mel_transform(waveform).squeeze(0).t()
    mel_spec = torch.log(mel_spec + 1e-9)
    
    # Save original stats BEFORE normalization
    original_mean = mel_spec.mean()
    original_std = mel_spec.std()
    
    # Normalize
    mel_spec = (mel_spec - original_mean) / (original_std + 1e-8)
    mel_spec = torch.clamp(mel_spec, -3, 3)
    
    max_len = 500
    if mel_spec.shape[0] > max_len:
        mel_spec = mel_spec[:max_len]
    
    clean = mel_spec.clone()
    
    # Simulate packet loss
    mask = torch.rand(mel_spec.shape[0]) < loss_rate
    corrupted = clean.clone()
    corrupted[mask] = 0
    
    # Reconstruct
    with torch.no_grad():
        corrupted_input = corrupted.unsqueeze(0).to(device)
        mask_input = mask.unsqueeze(0).to(device)
        reconstructed = model(corrupted_input, mask_input)
        reconstructed = reconstructed.squeeze(0).cpu()
    
    # Convert to audio
    if vocoder is not None and vocoder_type == 'hifigan':
        print(f"  Using HiFi-GAN vocoder (80 mels - perfect match!)...")
        
        clean_audio = mel_to_audio_hifigan(clean, vocoder, device, original_mean, original_std)
        corrupted_audio = mel_to_audio_hifigan(corrupted, vocoder, device, original_mean, original_std)
        reconstructed_audio = mel_to_audio_hifigan(reconstructed, vocoder, device, original_mean, original_std)
        
    else:
        # Fallback to Griffin-Lim
        print("  Using Griffin-Lim vocoder...")
        inverse_mel = torchaudio.transforms.InverseMelScale(n_stft=513, n_mels=80, sample_rate=16000)
        griffin_lim = torchaudio.transforms.GriffinLim(n_fft=1024, hop_length=256, n_iter=128)
        
        def mel_to_audio_gl(mel):
            mel = mel * original_std + original_mean
            mel = torch.exp(mel) - 1e-9
            mel = mel.clamp(min=0).t().unsqueeze(0)
            linear = inverse_mel(mel)
            audio = griffin_lim(linear)
            return audio
        
        clean_audio = mel_to_audio_gl(clean)
        corrupted_audio = mel_to_audio_gl(corrupted)
        reconstructed_audio = mel_to_audio_gl(reconstructed)
    
    # Save audio
    base_name = Path(audio_path).stem
    loss_pct = int(loss_rate * 100)
    
    suffix = f"_{vocoder_type}_80mel"
    
    torchaudio.save(str(output_dir / f"{base_name}_clean{suffix}.wav"), clean_audio, 16000)
    torchaudio.save(str(output_dir / f"{base_name}_corrupted_{loss_pct}pct{suffix}.wav"), corrupted_audio, 16000)
    torchaudio.save(str(output_dir / f"{base_name}_reconstructed_{loss_pct}pct{suffix}.wav"), reconstructed_audio, 16000)
    
    # Metrics
    if mask.sum() > 0:
        l1_corrupted = F.l1_loss(corrupted[mask], clean[mask]).item()
        l1_reconstructed = F.l1_loss(reconstructed[mask], clean[mask]).item()
        improvement = ((l1_corrupted - l1_reconstructed) / l1_corrupted) * 100
    else:
        l1_corrupted = 0.0
        l1_reconstructed = 0.0
        improvement = 0.0
    
    l1_full = F.l1_loss(reconstructed, clean).item()
    signal_power = (clean ** 2).mean().item()
    noise_power = ((reconstructed - clean) ** 2).mean().item()
    snr = 10 * np.log10(signal_power / (noise_power + 1e-8))
    
    metrics = {
        'l1_corrupted_masked': float(l1_corrupted),
        'l1_reconstructed_masked': float(l1_reconstructed),
        'improvement_pct': float(improvement),
        'l1_full': float(l1_full),
        'snr_db': float(snr),
        'loss_rate': loss_rate,
        'packets_lost': int(mask.sum()),
        'total_packets': len(mask),
        'vocoder': vocoder_type,
        'n_mels': 80
    }
    
    with open(output_dir / f"{base_name}_metrics_{loss_pct}pct{suffix}.json", 'w') as f:
        json.dump(metrics, f, indent=2)
    
    # Visualization
    fig, axes = plt.subplots(4, 1, figsize=(14, 12))
    
    axes[0].imshow(clean.numpy().T, aspect='auto', origin='lower', cmap='viridis', vmin=-2, vmax=2)
    axes[0].set_title('Clean Audio (80 Mel Bins)', fontsize=14, fontweight='bold')
    
    axes[1].imshow(corrupted.numpy().T, aspect='auto', origin='lower', cmap='viridis', vmin=-2, vmax=2)
    axes[1].set_title(f'Corrupted ({loss_pct}% Packet Loss)', fontsize=14, fontweight='bold')
    for i, is_lost in enumerate(mask):
        if is_lost:
            axes[1].axvline(x=i, color='red', alpha=0.3, linewidth=0.5)
    
    axes[2].imshow(reconstructed.numpy().T, aspect='auto', origin='lower', cmap='viridis', vmin=-2, vmax=2)
    axes[2].set_title(f'Reconstructed (GAN + {vocoder_type.upper()}) - SNR: {snr:.1f} dB', 
                     fontsize=14, fontweight='bold')
    
    error = torch.abs(reconstructed - clean)
    im = axes[3].imshow(error.numpy().T, aspect='auto', origin='lower', cmap='hot')
    axes[3].set_title(f'Error (Improvement: {improvement:.1f}%)', fontsize=14, fontweight='bold')
    axes[3].set_xlabel('Time Frames')
    plt.colorbar(im, ax=axes[3])
    
    for ax in axes:
        ax.set_ylabel('Mel Bins')
    
    plt.tight_layout()
    plt.savefig(output_dir / f"{base_name}_comparison_{loss_pct}pct{suffix}.png", 
               dpi=150, bbox_inches='tight')
    plt.close()
    
    print(f"\n✓ {base_name} @ {loss_pct}% loss")
    print(f"  Vocoder: {vocoder_type}")
    print(f"  L1 (masked): {l1_corrupted:.4f} → {l1_reconstructed:.4f}")
    print(f"  Improvement: {improvement:.1f}%")
    print(f"  SNR: {snr:.1f} dB")
    
    return metrics

In [50]:
vocoder, vocoder_type = load_hifigan(device)

Loading HiFi-GAN vocoder...


Downloading: "https://github.com/NVIDIA/DeepLearningExamples/zipball/torchhub" to /root/.cache/torch/hub/torchhub.zip
Downloading checkpoint from https://api.ngc.nvidia.com/v2/models/nvidia/dle/hifigan__pyt_ckpt_mode-finetune_ds-ljs22khz/versions/21.08.0_amp/files/hifigan_gen_checkpoint_10000_ft.pt


HiFi-GAN: Removing weight norm.
✓ HiFi-GAN loaded successfully!


In [55]:
test_files = sorted(list(Path(CONFIG['data_dir']).glob("*.wav")))[:5]
test_loss_rates = [0.2, 0.3, 0.4, 0.5]
    
for audio_file in test_files:
    for loss_rate in test_loss_rates:    
        reconstruct_and_visualize_80mel(generator, str(audio_file), loss_rate, device, 
                                               CONFIG['output_dir'], vocoder=vocoder, 
                                               vocoder_type=vocoder_type
        )

  Using HiFi-GAN vocoder (80 mels - perfect match!)...

✓ LJ001-0001 @ 20% loss
  Vocoder: hifigan
  L1 (masked): 0.8275 → 0.6954
  Improvement: 16.0%
  SNR: 8.2 dB
  Using HiFi-GAN vocoder (80 mels - perfect match!)...

✓ LJ001-0001 @ 30% loss
  Vocoder: hifigan
  L1 (masked): 0.8396 → 0.7322
  Improvement: 12.8%
  SNR: 6.4 dB
  Using HiFi-GAN vocoder (80 mels - perfect match!)...

✓ LJ001-0001 @ 40% loss
  Vocoder: hifigan
  L1 (masked): 0.8047 → 0.7156
  Improvement: 11.1%
  SNR: 5.0 dB
  Using HiFi-GAN vocoder (80 mels - perfect match!)...

✓ LJ001-0001 @ 50% loss
  Vocoder: hifigan
  L1 (masked): 0.8318 → 0.7296
  Improvement: 12.3%
  SNR: 4.0 dB
  Using HiFi-GAN vocoder (80 mels - perfect match!)...

✓ LJ001-0002 @ 20% loss
  Vocoder: hifigan
  L1 (masked): 0.7668 → 0.4702
  Improvement: 38.7%
  SNR: 10.2 dB
  Using HiFi-GAN vocoder (80 mels - perfect match!)...

✓ LJ001-0002 @ 30% loss
  Vocoder: hifigan
  L1 (masked): 0.7932 → 0.4912
  Improvement: 38.1%
  SNR: 8.4 dB
  Using H

In [68]:
import torch
import torchaudio
import torchaudio.transforms as T
import matplotlib.pyplot as plt
import os

def sanity_check_hifigan(audio_path, device='cuda'):
    print(f"--- Processing {audio_path} ---")
    
    # 1. Load Audio
    waveform, sr = torchaudio.load(audio_path)
    if waveform.shape[0] > 1: waveform = waveform.mean(dim=0, keepdim=True)
    
    # 2. Resample to 22050 (Required for this Vocoder)
    resampler = T.Resample(sr, 22050).to(device)
    waveform_22k = resampler(waveform.to(device))
    
    # 3. Generate 80-band Mels (Required for this Vocoder)
    mel_transform = T.MelSpectrogram(
        sample_rate=22050,
        n_fft=1024,
        win_length=1024,
        hop_length=256,
        n_mels=80,
        f_min=0.0,
        f_max=8000.0,
        power=1.0,
        normalized=False
    ).to(device)
    
    mel_spec = mel_transform(waveform_22k)
    
    # 4. Log Transform
    log_mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))
    
    # 5. Load Vocoder
    print("Loading SpeechBrain HiFi-GAN...")
    from speechbrain.pretrained import HIFIGAN
    vocoder = HIFIGAN.from_hparams(
        source="speechbrain/tts-hifigan-ljspeech", 
        savedir="tmpdir_vocoder",
        run_opts={"device": str(device)}
    )
    
    # 6. Vocode
    print("Vocoding...")
    with torch.no_grad():
        if log_mel_spec.dim() == 2:
            log_mel_spec = log_mel_spec.unsqueeze(0)
            
        audio_out = vocoder.decode_batch(log_mel_spec)
        
        # --- FIX FOR RUNTIME ERROR ---
        # Ensure audio is strictly (Channels, Time) -> (1, T)
        audio_out = audio_out.squeeze() # Remove all extra dims (becomes 1D)
        audio_out = audio_out.unsqueeze(0) # Add channel dim (becomes 2D)
        
        print(f"Output Audio Shape: {audio_out.shape}")

    # 7. Save
    out_path = "DEBUG_hifigan_22k_result.wav"
    torchaudio.save(out_path, audio_out.cpu(), 22050)
    print(f"✓ Saved to {out_path}")
    print("Please listen to this file specifically.")

# Run it
if __name__ == "__main__":
    # Update this path to your file
    sanity_check_hifigan("/kaggle/working/audio_data/LJ001-0001.wav")
sanity_check_hifigan("/kaggle/working/audio_data/LJ001-0001.wav")

--- Processing /kaggle/working/audio_data/LJ001-0001.wav ---
Loading SpeechBrain HiFi-GAN...
Vocoding...
Output Audio Shape: torch.Size([1, 215552])
✓ Saved to DEBUG_hifigan_22k_result.wav
Please listen to this file specifically.
--- Processing /kaggle/working/audio_data/LJ001-0001.wav ---
Loading SpeechBrain HiFi-GAN...
Vocoding...
Output Audio Shape: torch.Size([1, 215552])
✓ Saved to DEBUG_hifigan_22k_result.wav
Please listen to this file specifically.


In [56]:
!ls /kaggle/working/outputs

LJ001-0001_clean_gan.wav
LJ001-0001_clean_hifigan_80mel.wav
LJ001-0001_comparison_20pct_gan.png
LJ001-0001_comparison_20pct_hifigan_80mel.png
LJ001-0001_comparison_30pct_gan.png
LJ001-0001_comparison_30pct_hifigan_80mel.png
LJ001-0001_comparison_40pct_gan.png
LJ001-0001_comparison_40pct_hifigan_80mel.png
LJ001-0001_comparison_50pct_gan.png
LJ001-0001_comparison_50pct_hifigan_80mel.png
LJ001-0001_corrupted_20pct_gan.wav
LJ001-0001_corrupted_20pct_hifigan_80mel.wav
LJ001-0001_corrupted_30pct_gan.wav
LJ001-0001_corrupted_30pct_hifigan_80mel.wav
LJ001-0001_corrupted_40pct_gan.wav
LJ001-0001_corrupted_40pct_hifigan_80mel.wav
LJ001-0001_corrupted_50pct_gan.wav
LJ001-0001_corrupted_50pct_hifigan_80mel.wav
LJ001-0001_metrics_20pct_gan.json
LJ001-0001_metrics_20pct_hifigan_80mel.json
LJ001-0001_metrics_30pct_gan.json
LJ001-0001_metrics_30pct_hifigan_80mel.json
LJ001-0001_metrics_40pct_gan.json
LJ001-0001_metrics_40pct_hifigan_80mel.json
LJ001-0001_metrics_50pct_gan.json
LJ001-0001_metrics_50pct

In [None]:
Audio(filename="/kaggle/working/outputs/LJ001-0001_clean_gan.wav")

In [None]:
Audio(filename="/kaggle/working/outputs/LJ001-0001_corrupted_40pct_gan.wav")

In [None]:
Audio(filename="/kaggle/working/outputs/LJ001-0001_reconstructed_40pct_gan.wav")

In [97]:
!ls /kaggle/working/outputs

LJ001-0001_clean_gan.wav
LJ001-0001_clean_hifigan.wav
LJ001-0001_comparison_20pct_gan.png
LJ001-0001_comparison_30pct_gan.png
LJ001-0001_comparison_40pct_gan.png
LJ001-0001_comparison_50pct_gan.png
LJ001-0001_corrupted_20pct_gan.wav
LJ001-0001_corrupted_20pct_hifigan.wav
LJ001-0001_corrupted_30pct_gan.wav
LJ001-0001_corrupted_30pct_hifigan.wav
LJ001-0001_corrupted_40pct_gan.wav
LJ001-0001_corrupted_40pct_hifigan.wav
LJ001-0001_corrupted_50pct_gan.wav
LJ001-0001_corrupted_50pct_hifigan.wav
LJ001-0001_metrics_20pct_gan.json
LJ001-0001_metrics_20pct_hifigan.json
LJ001-0001_metrics_30pct_gan.json
LJ001-0001_metrics_30pct_hifigan.json
LJ001-0001_metrics_40pct_gan.json
LJ001-0001_metrics_40pct_hifigan.json
LJ001-0001_metrics_50pct_gan.json
LJ001-0001_metrics_50pct_hifigan.json
LJ001-0001_reconstructed_20pct_gan.wav
LJ001-0001_reconstructed_20pct_hifigan.wav
LJ001-0001_reconstructed_30pct_gan.wav
LJ001-0001_reconstructed_30pct_hifigan.wav
LJ001-0001_reconstructed_40pct_gan.wav
LJ001-0001_reco

In [11]:
!ls /kaggle/working/checkpoints

best_model_gan.pt  training_curves_gan.png  training_history_gan.json


In [19]:
import tarfile
folder_path = "/kaggle/working/checkpoints"
output_path = "/kaggle/working/checkpoints.tar.gz"

with tarfile.open(output_path, "w:gz") as tar:
    tar.add(folder_path, arcname="my_folder")

print("Created:", output_path)

Created: /kaggle/working/checkpoints.tar.gz


Corrupted 40%

<audio controls src="LJ001-0001_corrupted_40pct_gan (1).wav" title="Title"></audio>

Reconstructed

<audio controls src="LJ001-0001_reconstructed_40pct_gan (1).wav" title="Title"></audio>

Original

<audio controls src="LJ001-0001_clean.wav" title="Title"></audio>

The warbly, bubbly sound is because of Griffin Lim. Need a better vocoder the HiFi GAN does not work, either need to tweak settings or change the vocoder. With griffin lim sounds fine but the warbly-ness caused by phase shifting (that is fixed with a vocoder).
Some choppiness exists in reconstructed but can be fixed with longer training.