# 🎵 Vision Transformer (ViT) Training with GAN-Based Data Augmentation

## 📋 Overview
This notebook implements **Vision Transformer (ViT)** using the pre-trained `google/vit-base-patch16-224-in21k` model for music emotion recognition with **GAN-based data augmentation** to expand the DEAM dataset.

### Key Features:
- **Pre-trained ViT**: Uses `google/vit-base-patch16-224-in21k` trained on ImageNet-21k
- **Transfer Learning**: Fine-tunes large vision model on audio spectrograms
- **Conditional GAN**: Generates synthetic spectrograms conditioned on valence/arousal
- **Data Expansion**: Increases dataset size from ~1800 to 5000+ samples
- **Emotion Prediction**: Valence-Arousal (VA) continuous values

### Pipeline:
1. Load DEAM dataset and extract real spectrograms
2. Train Conditional GAN to generate synthetic spectrograms
3. Augment dataset with GAN-generated samples
4. Fine-tune pre-trained ViT model on expanded dataset
5. Evaluate on test set

## 1️⃣ Import Libraries

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import warnings
from pathlib import Path
warnings.filterwarnings('ignore')

# Audio processing
import librosa
import librosa.display

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

# Hugging Face Transformers
from transformers import ViTModel, ViTConfig
from PIL import Image

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed the NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to ``np.random``, to PyTorch's default
            generator and, when CUDA is available, to every CUDA device.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Fix the RNG state before any random work happens in the notebook.
set_seed(42)

# Resolved Kaggle input directory; later cells build dataset paths from it.
root = Path('/kaggle/input').resolve()

print("Root exists: {}".format(root.exists()))
print("‚úÖ All libraries imported successfully!")
print("PyTorch version: {}".format(torch.__version__))
if torch.cuda.is_available():
    print("Device: CUDA")
else:
    print("Device: CPU")

## 2️⃣ Configuration & Hyperparameters

In [None]:
# Notebook-wide constants: dataset paths, audio feature settings, ViT input
# format, GAN hyperparameters, ViT training hyperparameters and output paths.

# ========================
# DATASET CONFIGURATION
# ========================
# Directory that holds the per-song `<song_id>.mp3` files.
AUDIO_DIR = '/kaggle/input/deam-mediaeval-dataset-emotional-analysis-in-music/DEAM_audio/MEMD_audio/'
# Song-level (static) annotation directory inside the DEAM dataset.
ANNOTATIONS_DIR = '/kaggle/input/deam-mediaeval-dataset-emotional-analysis-in-music/DEAM_Annotations/annotations/annotations averaged per song/song_level/'

# ========================
# AUDIO PROCESSING CONFIG
# ========================
# NOTE(review): with librosa's default centered framing these settings
# presumably yield 1 + (SAMPLE_RATE * DURATION) // HOP_LENGTH = 1292 frames
# per clip, which matches the default `time_steps` of the GAN classes - confirm.
SAMPLE_RATE = 22050          # Audio sampling rate (Hz)
DURATION = 30                # Audio clip duration (seconds)
N_MELS = 128                 # Number of mel-frequency bins
HOP_LENGTH = 512             # Hop length for STFT
N_FFT = 2048                 # FFT window size
FMIN = 20                    # Minimum frequency
FMAX = 8000                  # Maximum frequency

# ========================
# VIT PREPROCESSING CONFIG
# ========================
VIT_IMAGE_SIZE = 224         # ViT expects 224x224 images
VIT_CHANNELS = 3             # RGB channels (we'll triplicate grayscale)
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # ImageNet normalization mean
IMAGENET_STD = [0.229, 0.224, 0.225]   # ImageNet normalization std

# ========================
# GAN CONFIGURATION
# ========================
LATENT_DIM = 100             # Dimension of GAN noise vector
CONDITION_DIM = 2            # Valence + Arousal
GAN_LR = 0.0002              # GAN learning rate
GAN_BETA1 = 0.5              # Adam beta1 for GAN
GAN_BETA2 = 0.999            # Adam beta2 for GAN
GAN_EPOCHS = 10              # GAN pre-training epochs
GAN_BATCH_SIZE = 24          # GAN batch size (reduced from 32 to save memory)
NUM_SYNTHETIC = 3200         # Number of synthetic samples to generate

# ========================
# VIT MODEL CONFIGURATION
# ========================
# Exactly one VIT_MODEL_NAME assignment should be active; the alternatives
# below stay commented out.
# OPTION 1: Use pre-downloaded model (recommended to avoid download issues)
VIT_MODEL_NAME = '/kaggle/input/vit-model-kaggle/vit-model-for-kaggle'  # Update with your dataset path

# OPTION 2: Fallback to online download (may fail with 500 errors)
# VIT_MODEL_NAME = 'google/vit-base-patch16-224-in21k'

# OPTION 3: Use smaller, more stable model
# VIT_MODEL_NAME = 'google/vit-base-patch16-224'

FREEZE_BACKBONE = False      # Whether to freeze ViT encoder layers
DROPOUT = 0.1                # Dropout rate

# ========================
# TRAINING CONFIGURATION
# ========================
BATCH_SIZE = 12              # Training batch size (reduced from 16 to save memory)
NUM_EPOCHS = 24              # Training epochs
LEARNING_RATE = 1e-4         # Learning rate for fine-tuning
WEIGHT_DECAY = 0.05          # AdamW weight decay
TRAIN_SPLIT = 0.8            # Train/validation split ratio

# ========================
# MEMORY OPTIMIZATION
# ========================
# If you still encounter OOM errors, try these:
# - Reduce GAN_BATCH_SIZE to 16
# - Reduce BATCH_SIZE to 8
# - Reduce NUM_SYNTHETIC to 2000
# - Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# ========================
# SYSTEM CONFIGURATION
# ========================
# Single device handle used by every later cell to place models and batches.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Figures are written here; the directory is created up front so that the
# later `plt.savefig` calls do not have to check for it.
OUTPUT_DIR = '/kaggle/working/vit_augmented'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Backend switches that only matter when a GPU is present.
_has_cuda = torch.cuda.is_available()
if _has_cuda:
    # Allow TF32 kernels for matmul and cuDNN convolutions.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Let cuDNN benchmark candidate algorithms.
    torch.backends.cudnn.benchmark = True
    print(f"üöÄ CUDA optimizations enabled")

# Human-readable recap of the configuration chosen above.
_rule = "=" * 60
print(_rule)
print("üìä CONFIGURATION SUMMARY")
print(_rule)
print(f"Device: {DEVICE}")
if _has_cuda:
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {_gpu_props.total_memory / 1024**3:.1f} GB")
for _summary_line in (
    f"Audio Duration: {DURATION}s @ {SAMPLE_RATE}Hz",
    f"Mel-Spectrogram: {N_MELS} bins",
    f"\nüñºÔ∏è ViT Configuration:",
    f"  - Model Path: {VIT_MODEL_NAME}",
    f"  - Input Size: {VIT_IMAGE_SIZE}x{VIT_IMAGE_SIZE}x{VIT_CHANNELS}",
    f"  - Freeze Backbone: {FREEZE_BACKBONE}",
    f"\nüé® GAN Configuration:",
    f"  - Latent Dim: {LATENT_DIM}",
    f"  - GAN Epochs: {GAN_EPOCHS}",
    f"  - GAN Batch Size: {GAN_BATCH_SIZE}",
    f"  - Synthetic Samples: {NUM_SYNTHETIC}",
    f"\nüèãÔ∏è Training Configuration:",
    f"  - Epochs: {NUM_EPOCHS}",
    f"  - Batch Size: {BATCH_SIZE}",
    f"  - Learning Rate: {LEARNING_RATE}",
):
    print(_summary_line)
print(_rule)

## 3️⃣ Load DEAM Dataset & Extract Real Spectrograms

In [None]:
# Load annotations
# The song-level annotations are split over two CSV files that live in two
# separate Kaggle input datasets below `root`.
static_2000 = root / 'static-annotations-1-2000' / 'static_annotations_averaged_songs_1_2000.csv'
static_2058 = root / 'static-annots-2058' / 'static_annots_2058.csv'

try:
    df1 = pd.read_csv(static_2000)
    df2 = pd.read_csv(static_2058)

    # Clean column names *before* concatenating: headers may carry stray
    # whitespace, and differently padded headers would otherwise end up as
    # separate, half-empty columns in the combined frame.
    for frame in (df1, df2):
        frame.columns = frame.columns.str.strip()

    # ignore_index gives every song a unique row label in the combined frame.
    df_annotations = pd.concat([df1, df2], axis=0, ignore_index=True)
    print(f"‚úÖ Loaded annotations: {len(df_annotations)} songs")
except Exception as e:
    # Report the problem in the cell output, then let the failure propagate:
    # nothing downstream can run without the annotations.
    print(f"‚ùå Error loading annotations: {e}")
    raise

print("\nüìä Annotation Sample:")
print(df_annotations.head())
print(f"\nColumns: {list(df_annotations.columns)}")

# Check for audio files
audio_files = glob.glob(os.path.join(AUDIO_DIR, '*.mp3'))
print(f"\nüéµ Found {len(audio_files)} audio files")

# Extract spectrograms with error logging
print("\nüîä Extracting spectrograms from real audio...")

# Collects one message per song that could not be processed.
error_log = []

def extract_melspectrogram(audio_path, sr=SAMPLE_RATE, duration=DURATION):
    """Extract a standardized log-mel spectrogram from an audio file.

    The first `duration` seconds are loaded at sample rate `sr`. Clips that
    are shorter are zero-padded to `sr * duration` samples so that every
    spectrogram has the same number of frames and can be stacked later.

    Args:
        audio_path: Path of the audio file to load.
        sr: Target sample rate in Hz.
        duration: Number of seconds to load; ``None`` loads the whole file
            and disables padding.

    Returns:
        A ``(spectrogram, error)`` tuple: ``(ndarray, None)`` on success,
        ``(None, message)`` if anything went wrong while processing.
    """
    try:
        # Load audio
        y, _ = librosa.load(audio_path, sr=sr, duration=duration)

        # Pad short clips with silence; otherwise they produce fewer frames
        # than full-length clips and the spectrograms cannot be stacked.
        if duration is not None and sr is not None:
            target_len = int(sr * duration)
            if len(y) < target_len:
                y = np.pad(y, (0, target_len - len(y)))

        # Compute mel-spectrogram
        mel_spec = librosa.feature.melspectrogram(
            y=y, sr=sr, n_mels=N_MELS, n_fft=N_FFT,
            hop_length=HOP_LENGTH, fmin=FMIN, fmax=FMAX
        )

        # Convert to log scale (dB), relative to the clip's peak power
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        # Standardize to zero mean / unit variance (z-score). The result is
        # NOT bounded to [-1, 1]; the epsilon guards against constant input.
        mel_spec_norm = (mel_spec_db - mel_spec_db.mean()) / (mel_spec_db.std() + 1e-8)

        return mel_spec_norm, None
    except Exception as e:
        # Best effort by design: one unreadable file must not abort the whole
        # extraction run, so the failure is reported back as a message.
        error_msg = f"Error processing {os.path.basename(audio_path)}: {str(e)}"
        return None, error_msg

# Extract spectrograms and labels
real_spectrograms = []
real_labels = []
expected_shape = None  # shape of the first spectrogram; all others must match

for idx, row in tqdm(df_annotations.iterrows(), total=len(df_annotations), desc="Extracting spectrograms"):
    song_id = str(int(row['song_id']))
    audio_path = os.path.join(AUDIO_DIR, f"{song_id}.mp3")

    if not os.path.exists(audio_path):
        error_log.append(f"Missing audio file: {song_id}.mp3")
        continue

    # Get valence and arousal. If neither column exists, fall back to 5.0,
    # the value the normalization below maps to 0 (the previous fallback of
    # 0.5 mapped to -1.125, i.e. outside the intended [-1, 1] range).
    valence = row.get('valence_mean', row.get('valence', 5.0))
    arousal = row.get('arousal_mean', row.get('arousal', 5.0))
    if pd.isna(valence) or pd.isna(arousal):
        # A NaN label would silently poison every loss computed from it.
        error_log.append(f"Missing valence/arousal annotation: {song_id}.mp3")
        continue

    # Extract spectrogram
    spec, error = extract_melspectrogram(audio_path)

    if error is not None:
        error_log.append(error)
        continue
    if spec is None:
        continue

    # All spectrograms must share one shape to be stacked into a 3-D array.
    if expected_shape is None:
        expected_shape = spec.shape
    elif spec.shape != expected_shape:
        error_log.append(
            f"Skipped {song_id}.mp3: spectrogram shape {spec.shape} != {expected_shape}"
        )
        continue

    real_spectrograms.append(spec)

    # Normalize to [-1, 1] range (assumes annotations on a 1-9 scale)
    valence_norm = (valence - 5.0) / 4.0
    arousal_norm = (arousal - 5.0) / 4.0

    real_labels.append([valence_norm, arousal_norm])

if not real_spectrograms:
    # Fail with a clear message instead of an opaque error from .min() below.
    first_error = error_log[0] if error_log else "n/a"
    raise RuntimeError(
        f"No spectrograms were extracted ({len(error_log)} errors; first: {first_error})"
    )

# Convert to numpy arrays
real_spectrograms = np.array(real_spectrograms)
real_labels = np.array(real_labels)

print(f"\n‚úÖ Extracted {len(real_spectrograms)} spectrograms")
print(f"Spectrogram shape: {real_spectrograms.shape}")
print(f"Labels shape: {real_labels.shape}")
print(f"Spectrogram range: [{real_spectrograms.min():.2f}, {real_spectrograms.max():.2f}]")
print(f"Labels range: [{real_labels.min():.2f}, {real_labels.max():.2f}]")

if error_log:
    print(f"\n‚ö†Ô∏è {len(error_log)} errors occurred during extraction:")
    for i, error in enumerate(error_log[:10]):  # Show first 10 errors
        print(f"  {i+1}. {error}")
    if len(error_log) > 10:
        print(f"  ... and {len(error_log) - 10} more errors")
In [None]:
# Visualize sample spectrogram
# Left panel: the first extracted spectrogram; right panel: scatter of all
# normalized (valence, arousal) labels.
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

axes[0].imshow(real_spectrograms[0], aspect='auto', origin='lower', cmap='viridis')
# "\n" (not "\\n") so the title really breaks onto a second line.
axes[0].set_title(f'Sample Spectrogram\nValence: {real_labels[0][0]:.2f}, Arousal: {real_labels[0][1]:.2f}')
axes[0].set_xlabel('Time Frames')
axes[0].set_ylabel('Mel Frequency Bins')

axes[1].scatter(real_labels[:, 0], real_labels[:, 1], alpha=0.5)
axes[1].set_xlabel('Valence (normalized)')
axes[1].set_ylabel('Arousal (normalized)')
axes[1].set_title('Valence-Arousal Distribution (Real Data)')
axes[1].grid(True, alpha=0.3)
# Thin axes through the origin split the plane into four quadrants.
axes[1].axhline(0, color='k', linewidth=0.5)
axes[1].axvline(0, color='k', linewidth=0.5)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'real_data_visualization.png'), dpi=150, bbox_inches='tight')
plt.show()

## 4️⃣ Conditional GAN Architecture

In [None]:
class ChannelAttention(nn.Module):
    """Channel-wise gating computed from globally pooled features.

    Every channel is reduced to its global average and its global maximum.
    Both descriptors pass through the same bottleneck MLP (ending in a
    sigmoid), the two results are added, and the sum rescales the channel.
    Because two sigmoid outputs are summed, each gate lies in (0, 2).
    """
    def __init__(self, channels, reduction=8):
        super().__init__()
        hidden = channels // reduction

        # Global spatial pooling down to a single value per channel.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Bottleneck MLP shared by both pooled descriptors.
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return `x` with each channel multiplied by its attention gate."""
        n, c, _, _ = x.size()

        pooled_avg = self.avg_pool(x).view(n, c)
        pooled_max = self.max_pool(x).view(n, c)

        gate = self.fc(pooled_avg) + self.fc(pooled_max)
        return x * gate.view(n, c, 1, 1)


class ImprovedSpectrogramGenerator(nn.Module):
    """Conditional generator: (noise, condition) -> single-channel spectrogram.

    The condition vector is embedded to 256 features, concatenated with the
    noise vector and projected to a 256x16x20 feature map. Three stride-2
    transposed convolutions (with channel attention after the second) and a
    final (1, 8) transposed convolution upsample it; a Tanh bounds the output
    and a bilinear resize enforces the requested (n_mels, time_steps) size.
    """
    def __init__(self, latent_dim=LATENT_DIM, condition_dim=CONDITION_DIM, 
                 n_mels=N_MELS, time_steps=1292):
        super().__init__()

        self.latent_dim, self.condition_dim = latent_dim, condition_dim
        self.n_mels, self.time_steps = n_mels, time_steps

        # Condition embedding: condition_dim -> 128 -> 256
        self.condition_embed = nn.Sequential(
            nn.Linear(condition_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.LeakyReLU(0.2),
        )

        # Projection of [noise, embedded condition] to the seed feature map
        seed_features = 256 * 16 * 20
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + 256, seed_features),
            nn.BatchNorm1d(seed_features),
            nn.LeakyReLU(0.2),
        )

        # Upsampling path; each stage doubles both spatial dimensions.
        self.conv1 = self._up_block(256, 128)
        self.conv2 = self._up_block(128, 64)

        # Channel attention applied between the second and third stage
        self.attention = ChannelAttention(64, reduction=8)

        self.conv3 = self._up_block(64, 32)

        # Final stage stretches only the last axis (x8) and bounds the output.
        self.conv4 = nn.Sequential(
            nn.ConvTranspose2d(32, 1, kernel_size=(1, 8), stride=(1, 8), padding=0),
            nn.Tanh(),
        )

    @staticmethod
    def _up_block(in_channels, out_channels):
        """Build one stride-2 transposed-conv + BatchNorm + LeakyReLU stage."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, z, c):
        """Generate spectrograms from noise `z` and condition `c`."""
        # Fuse noise with the embedded condition and reshape to a feature map
        fused = torch.cat([z, self.condition_embed(c)], dim=1)
        x = self.fc(fused).view(-1, 256, 16, 20)

        # Upsampling stages, attention sitting between conv2 and conv3
        for stage in (self.conv1, self.conv2, self.attention, self.conv3, self.conv4):
            x = stage(x)

        # Resize only when the conv stack did not land on the requested size
        target_size = (self.n_mels, self.time_steps)
        if tuple(x.shape[-2:]) != target_size:
            x = F.interpolate(x, size=target_size, mode='bilinear', align_corners=False)

        return x


class ImprovedSpectrogramDiscriminator(nn.Module):
    """Conditional discriminator with spectrally normalized convolutions.

    Four stride-2 convolutions reduce a single-channel spectrogram to a
    256-channel feature map. The flattened map is concatenated with a
    64-dimensional embedding of the condition and mapped to one raw logit
    (no sigmoid; intended for use with ``BCEWithLogitsLoss``).

    Args:
        condition_dim: Size of the condition vector.
        n_mels: Height (mel bins) of the input spectrograms.
        time_steps: Width (time frames) of the input spectrograms.
    """
    def __init__(self, condition_dim=CONDITION_DIM, n_mels=N_MELS, time_steps=1292):
        super().__init__()

        self.n_mels = n_mels
        self.time_steps = time_steps

        # Simplified condition embedding (reduce memory)
        self.condition_embed = nn.Sequential(
            nn.Linear(condition_dim, 64),
            nn.LeakyReLU(0.2)
        )

        # Convolutional layers with spectral normalization for stability
        self.conv_layers = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),

            nn.utils.spectral_norm(nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.utils.spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.utils.spectral_norm(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Derive the flattened feature size from the actual input size instead
        # of hard-coding 256 * 8 * 80 (which is what 128 x 1292 inputs give).
        conv_output_size = (
            256 * self._downsampled_size(n_mels) * self._downsampled_size(time_steps)
        )

        # Fully connected layers with dropout and condition.
        # nn.Flatten is a no-op on the already flat input; it is kept so the
        # layer indices (and therefore state_dict keys) stay unchanged.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(conv_output_size + 64, 256),  # Concatenate with condition
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
            # No sigmoid - will use BCEWithLogitsLoss for better stability
        )

    @staticmethod
    def _downsampled_size(size, num_layers=4, kernel_size=4, stride=2, padding=1):
        """Return the length of one axis after the stacked strided convolutions."""
        for _ in range(num_layers):
            size = (size + 2 * padding - kernel_size) // stride + 1
        return size

    def forward(self, spec, c):
        """Return one real/fake logit per sample for `spec` given condition `c`."""
        # Embed condition (reduced dimensionality for memory)
        c_embed = self.condition_embed(c)

        # Apply conv layers and flatten per sample
        features = self.conv_layers(spec)
        features = features.view(features.size(0), -1)

        # Concatenate with condition and classify
        x = torch.cat([features, c_embed], dim=1)
        output = self.fc(x)
        return output


# Build both GAN networks, sized to the time axis of the real spectrograms.
time_steps = real_spectrograms.shape[2]

generator = ImprovedSpectrogramGenerator(
    LATENT_DIM, CONDITION_DIM, N_MELS, time_steps
).to(DEVICE)

discriminator = ImprovedSpectrogramDiscriminator(
    CONDITION_DIM, N_MELS, time_steps
).to(DEVICE)

# Parameter totals for the architecture summary below.
_generator_params = sum(map(torch.numel, generator.parameters()))
_discriminator_params = sum(map(torch.numel, discriminator.parameters()))

_banner = "=" * 60
print(_banner)
print("üé® IMPROVED GAN ARCHITECTURE (Memory-Efficient)")
print(_banner)
for _line in (
    f"‚ú® Generator Features:",
    f"   - Channel attention (memory-efficient)",
    f"   - Enhanced condition embedding",
    f"   - Parameters: {_generator_params:,}",
    f"\n‚ú® Discriminator Features:",
    f"   - Spectral normalization for stability",
    f"   - Compact condition embedding",
    f"   - Parameters: {_discriminator_params:,}",
):
    print(_line)
print(_banner)

# Release cached allocator blocks and report what the models occupy.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    _allocated_gb = torch.cuda.memory_allocated() / 1024**3
    print(f"üíæ GPU memory after models: {_allocated_gb:.2f} GB")

## 5️⃣ Train Conditional GAN

In [None]:
print("\n" + "=" * 60)
print("üéÆ BALANCED GAN TRAINING CONFIGURATION")
print("=" * 60)

# GAN Training Hyperparameters
# Read the values from the configuration cell (GAN_LR / GAN_BETA1 / GAN_BETA2)
# instead of repeating literals, so editing the config actually takes effect.
# The discriminator runs at half the generator's learning rate.
G_LEARNING_RATE = GAN_LR
D_LEARNING_RATE = GAN_LR * 0.5
g_optimizer = optim.Adam(generator.parameters(), lr=G_LEARNING_RATE, betas=(GAN_BETA1, GAN_BETA2))
d_optimizer = optim.Adam(discriminator.parameters(), lr=D_LEARNING_RATE, betas=(GAN_BETA1, GAN_BETA2))  # Lower LR for discriminator
# The discriminator outputs raw logits, hence the logits variant of BCE.
criterion = nn.BCEWithLogitsLoss()

# Adaptive training parameters
D_STEPS_THRESHOLD = 0.8  # Train discriminator more if accuracy < 80%
G_STEPS_THRESHOLD = 0.2  # Train generator more if discriminator accuracy > 80%

print(f"‚öôÔ∏è  Optimizer Configuration:")
print(f"   Generator LR: {G_LEARNING_RATE} (Adam, Œ≤1={GAN_BETA1}, Œ≤2={GAN_BETA2})")
print(f"   Discriminator LR: {D_LEARNING_RATE} (Adam, Œ≤1={GAN_BETA1}, Œ≤2={GAN_BETA2})")
print(f"   Loss Function: BCEWithLogitsLoss")
print(f"\nüéØ Adaptive Training:")
print(f"   D steps if D_acc < 80%: 1-2 steps")
print(f"   G steps if D_acc > 80%: 2-3 steps")
print(f"   Gradient clipping: max_norm = 1.0")
print(f"\nüíæ Memory Optimization:")
print(f"   Gradient accumulation: 2 steps (effective batch = {GAN_BATCH_SIZE * 2})")
print(f"   Pin memory: enabled")
print(f"   Periodic cache clearing: every 5 batches")
print(f"   Data on CPU, batch transfer to GPU")
print("=" * 60)

# Extract conditions from real labels (valence, arousal)
real_conditions = real_labels.copy()  # Shape: (N, 2) - valence and arousal
print(f"\nüìä Data prepared for GAN training:")
print(f"   Real spectrograms: {real_spectrograms.shape}")
print(f"   Real conditions: {real_conditions.shape}")
print(f"   Condition range: [{real_conditions.min():.2f}, {real_conditions.max():.2f}]")

# Create memory-efficient DataLoader (data on CPU, transfer batches to GPU).
# torch.as_tensor reuses the NumPy buffer when the dtype already matches, so
# the full spectrogram array is not duplicated in host memory; unsqueeze adds
# the channel axis expected by the conv layers.
real_specs_tensor = torch.as_tensor(real_spectrograms, dtype=torch.float32).unsqueeze(1)  # Keep on CPU
real_conditions_tensor = torch.as_tensor(real_conditions, dtype=torch.float32)  # Keep on CPU

gan_dataset = torch.utils.data.TensorDataset(real_specs_tensor, real_conditions_tensor)
gan_loader = torch.utils.data.DataLoader(
    gan_dataset,
    batch_size=GAN_BATCH_SIZE,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),  # Pinning only helps CPU->GPU copies
    num_workers=0  # Avoid multiprocessing overhead
)

# Clear cache before training
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"üßπ Cleared GPU cache")
    print(f"üíæ Initial GPU memory: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

print(f"\nüöÄ Starting Balanced GAN Training...\n")

# Training loop
# Per-epoch history of generator/discriminator loss and discriminator accuracy.
gan_losses = {'g_loss': [], 'd_loss': [], 'd_real_acc': [], 'd_fake_acc': []}
GRADIENT_ACCUMULATION_STEPS = 2  # Effective batch size = GAN_BATCH_SIZE * 2


def _set_requires_grad(module, enabled):
    """Switch gradient tracking on or off for every parameter of `module`."""
    for param in module.parameters():
        param.requires_grad_(enabled)


for epoch in range(GAN_EPOCHS):
    epoch_g_loss = 0.0
    epoch_d_loss = 0.0
    d_real_correct = 0.0
    d_fake_correct = 0.0
    total_samples = 0

    # Reset gradient accumulation (drops leftovers of an odd final batch)
    g_optimizer.zero_grad()
    d_optimizer.zero_grad()

    for i, (real_specs, conditions) in enumerate(tqdm(gan_loader, desc=f"Epoch {epoch+1}/{GAN_EPOCHS}")):
        # Move batch to GPU (lazy loading)
        real_specs = real_specs.to(DEVICE)
        conditions = conditions.to(DEVICE)
        batch_size = real_specs.size(0)

        # BCE targets. Deliberately NOT called `real_labels`: that global
        # holds the valence/arousal array of the real dataset and must not
        # be overwritten by a tensor of ones.
        real_targets = torch.ones(batch_size, 1).to(DEVICE)
        fake_targets = torch.zeros(batch_size, 1).to(DEVICE)

        # Optimizers only step on every GRADIENT_ACCUMULATION_STEPS-th batch.
        is_update_batch = (i + 1) % GRADIENT_ACCUMULATION_STEPS == 0

        # ========== Train Discriminator ==========
        # Calculate discriminator accuracy for adaptive training
        with torch.no_grad():
            z_temp = torch.randn(batch_size, LATENT_DIM).to(DEVICE)
            fake_specs_temp = generator(z_temp, conditions)
            d_real_out = discriminator(real_specs, conditions)
            d_fake_out = discriminator(fake_specs_temp, conditions)

            d_real_acc = ((torch.sigmoid(d_real_out) > 0.5).float().mean()).item()
            d_fake_acc = ((torch.sigmoid(d_fake_out) < 0.5).float().mean()).item()

        # Adaptive discriminator steps: one step when it is already accurate
        # on both real and fake samples, two otherwise.
        d_steps = 1 if d_real_acc > D_STEPS_THRESHOLD and d_fake_acc > D_STEPS_THRESHOLD else 2

        d_step_losses = []
        for _ in range(d_steps):
            # Real spectrograms
            real_output = discriminator(real_specs, conditions)
            d_real_loss = criterion(real_output, real_targets)

            # Fake spectrograms; no generator graph is needed for the D update
            z = torch.randn(batch_size, LATENT_DIM).to(DEVICE)
            with torch.no_grad():
                fake_specs = generator(z, conditions)
            fake_output = discriminator(fake_specs, conditions)
            d_fake_loss = criterion(fake_output, fake_targets)

            # Total discriminator loss (scaled for gradient accumulation)
            d_loss = (d_real_loss + d_fake_loss) / (2 * GRADIENT_ACCUMULATION_STEPS)
            d_loss.backward()
            # Log the unscaled real + fake loss of this step.
            d_step_losses.append(d_real_loss.item() + d_fake_loss.item())

            # Update discriminator (every GRADIENT_ACCUMULATION_STEPS batches)
            if is_update_batch:
                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
                d_optimizer.step()
                d_optimizer.zero_grad()

        # ========== Train Generator ==========
        # Adaptive generator steps
        g_steps = 3 if d_real_acc > D_STEPS_THRESHOLD else 1

        # Freeze the discriminator while back-propagating the generator loss.
        # Otherwise that backward pass also writes into the discriminator's
        # .grad buffers, and those stray gradients would be applied by the
        # next accumulated discriminator step.
        _set_requires_grad(discriminator, False)
        g_step_losses = []
        try:
            for _ in range(g_steps):
                z = torch.randn(batch_size, LATENT_DIM).to(DEVICE)
                fake_specs = generator(z, conditions)
                fake_output = discriminator(fake_specs, conditions)

                # Generator loss (scaled for gradient accumulation)
                g_raw_loss = criterion(fake_output, real_targets)
                g_loss = g_raw_loss / GRADIENT_ACCUMULATION_STEPS
                g_loss.backward()
                g_step_losses.append(g_raw_loss.item())

                # Update generator (every GRADIENT_ACCUMULATION_STEPS batches)
                if is_update_batch:
                    torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
                    g_optimizer.step()
                    g_optimizer.zero_grad()
        finally:
            # Always unfreeze, even if a step raised (e.g. out of memory).
            _set_requires_grad(discriminator, True)

        # Track losses and accuracy. Losses are averaged over the adaptive
        # steps so the logged value does not grow with the step count.
        epoch_g_loss += sum(g_step_losses) / len(g_step_losses)
        epoch_d_loss += sum(d_step_losses) / len(d_step_losses)
        d_real_correct += d_real_acc * batch_size
        d_fake_correct += d_fake_acc * batch_size
        total_samples += batch_size

        # Periodic cache clearing (every 5 batches)
        if i % 5 == 0 and i > 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Epoch statistics
    avg_g_loss = epoch_g_loss / len(gan_loader)
    avg_d_loss = epoch_d_loss / len(gan_loader)
    avg_d_real_acc = d_real_correct / total_samples
    avg_d_fake_acc = d_fake_correct / total_samples

    gan_losses['g_loss'].append(avg_g_loss)
    gan_losses['d_loss'].append(avg_d_loss)
    gan_losses['d_real_acc'].append(avg_d_real_acc)
    gan_losses['d_fake_acc'].append(avg_d_fake_acc)

    print(f"Epoch [{epoch+1}/{GAN_EPOCHS}]")
    print(f"  G Loss: {avg_g_loss:.4f} | D Loss: {avg_d_loss:.4f}")
    print(f"  D Real Acc: {avg_d_real_acc:.2%} | D Fake Acc: {avg_d_fake_acc:.2%}")

    # GPU memory monitoring
    if torch.cuda.is_available():
        print(f"  üíæ GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
        if (epoch + 1) % 3 == 0:  # Deep clean every 3 epochs
            torch.cuda.empty_cache()
    print()

print("\n‚úÖ GAN Training Complete!")
print(f"Final Generator Loss: {gan_losses['g_loss'][-1]:.4f}")
print(f"Final Discriminator Loss: {gan_losses['d_loss'][-1]:.4f}")
print(f"Final D Real Accuracy: {gan_losses['d_real_acc'][-1]:.2%}")
print(f"Final D Fake Accuracy: {gan_losses['d_fake_acc'][-1]:.2%}")

## 5.5️⃣ GAN Quality Metrics (Functions)

In [None]:
from scipy import linalg

def calculate_statistics(spectrograms):
    """Return the per-feature mean and the covariance matrix (FID-style).

    Each spectrogram is flattened into one row, so every element of a
    spectrogram is treated as one feature and every spectrogram as one
    observation.

    Returns:
        Tuple ``(mu, sigma)``: mean over observations and feature covariance.
    """
    n_items = spectrograms.shape[0]
    flattened = spectrograms.reshape(n_items, -1)

    # rowvar=False: rows are observations, columns are features.
    return np.mean(flattened, axis=0), np.cov(flattened, rowvar=False)


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Calculate the Frechet distance between two Gaussians (FID-style).

    Computes ``||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))``.
    Lower is better - the generated data is statistically closer to the real
    data.

    Args:
        mu1, mu2: Mean vectors (scalars are accepted and treated as 1-D).
        sigma1, sigma2: Covariance matrices (scalars are treated as 1x1).
        eps: Ridge added to both covariance diagonals if the matrix square
            root of their product is not finite (near-singular product).

    Returns:
        The Frechet distance as a float.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    # Calculate mean difference
    diff = mu1 - mu2

    # Called without `disp` so the return value is the matrix itself,
    # independent of how the optional error-estimate return is handled.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))

    # Product might be almost singular: retry with a small ridge.
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    # Calculate FD
    fd = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)

    return fd


def evaluate_spectrogram_quality(real_specs, fake_specs, n_samples=500, max_fd_features=4096):
    """
    Comprehensive evaluation of generated spectrogram quality.

    Compares the first `n_samples` real and synthetic spectrograms (arrays
    indexed as (sample, frequency bin, time frame)) and prints each metric.

    Args:
        real_specs: Real spectrograms.
        fake_specs: Generated spectrograms.
        n_samples: Maximum number of spectrograms per set to evaluate.
        max_fd_features: Largest flattened feature count for which the
            Frechet distance is computed on raw values. Above it, each
            spectrogram is summarized by its per-frequency-bin mean and
            standard deviation over time, because the covariance matrix
            grows with the square of the feature count.

    Returns:
        Dict with the keys 'frechet_distance', 'fd_feature_dim', 'mean_diff',
        'std_diff', 'real_smoothness', 'fake_smoothness', 'smoothness_ratio',
        'frequency_correlation', 'real_range', 'fake_range' and 'range_diff'.

    Raises:
        ValueError: If fewer than two spectrograms per set are available
            (a covariance cannot be estimated).
    """
    def _frechet_features(specs):
        """Return the per-sample feature array used for the Frechet distance."""
        flat_dim = int(np.prod(specs.shape[1:]))
        if flat_dim <= max_fd_features:
            return specs
        # Compact descriptor: mean and std over the time axis per bin.
        return np.concatenate([np.mean(specs, axis=2), np.std(specs, axis=2)], axis=1)

    print("üìä Evaluating GAN Generation Quality...\n")

    # Subsample for efficiency
    n_samples = min(n_samples, len(real_specs), len(fake_specs))
    if n_samples < 2:
        raise ValueError(
            "evaluate_spectrogram_quality needs at least 2 real and 2 synthetic spectrograms"
        )
    real_sample = np.asarray(real_specs[:n_samples])
    fake_sample = np.asarray(fake_specs[:n_samples])

    metrics = {}

    # 1. Frechet Distance (FID-style)
    print("  üî¢ Computing Frechet Distance...")
    real_features = _frechet_features(real_sample)
    fake_features = _frechet_features(fake_sample)
    mu_real, sigma_real = calculate_statistics(real_features)
    mu_fake, sigma_fake = calculate_statistics(fake_features)
    fd = calculate_frechet_distance(mu_real, sigma_real, mu_fake, sigma_fake)
    metrics['frechet_distance'] = fd
    metrics['fd_feature_dim'] = int(mu_real.shape[0])
    print(f"     Frechet Distance: {fd:.4f} (lower is better)")

    # 2. Statistical moments comparison
    print("\n  üìà Computing statistical moments...")
    real_mean = np.mean(real_sample)
    fake_mean = np.mean(fake_sample)
    real_std = np.std(real_sample)
    fake_std = np.std(fake_sample)

    metrics['mean_diff'] = abs(real_mean - fake_mean)
    metrics['std_diff'] = abs(real_std - fake_std)

    print(f"     Mean - Real: {real_mean:.4f}, Fake: {fake_mean:.4f}, Diff: {metrics['mean_diff']:.4f}")
    print(f"     Std  - Real: {real_std:.4f}, Fake: {fake_std:.4f}, Diff: {metrics['std_diff']:.4f}")

    # 3. Spectrogram smoothness (measure of noise): mean absolute difference
    # between neighbouring time frames.
    print("\n  üé® Evaluating smoothness (temporal consistency)...")
    real_smoothness = np.mean([np.mean(np.abs(np.diff(spec, axis=1))) for spec in real_sample])
    fake_smoothness = np.mean([np.mean(np.abs(np.diff(spec, axis=1))) for spec in fake_sample])

    metrics['real_smoothness'] = real_smoothness
    metrics['fake_smoothness'] = fake_smoothness
    metrics['smoothness_ratio'] = fake_smoothness / (real_smoothness + 1e-8)

    print(f"     Real smoothness: {real_smoothness:.4f}")
    print(f"     Fake smoothness: {fake_smoothness:.4f}")
    print(f"     Ratio: {metrics['smoothness_ratio']:.4f} (closer to 1.0 is better)")

    # 4. Frequency distribution analysis
    print("\n  üéµ Analyzing frequency content...")
    real_freq_mean = np.mean(real_sample, axis=(0, 2))  # Average across batch and time
    fake_freq_mean = np.mean(fake_sample, axis=(0, 2))

    freq_correlation = np.corrcoef(real_freq_mean, fake_freq_mean)[0, 1]
    metrics['frequency_correlation'] = freq_correlation

    print(f"     Frequency correlation: {freq_correlation:.4f} (higher is better)")

    # 5. Dynamic range
    print("\n  üìä Comparing dynamic range...")
    real_range = np.max(real_sample) - np.min(real_sample)
    fake_range = np.max(fake_sample) - np.min(fake_sample)

    metrics['real_range'] = real_range
    metrics['fake_range'] = fake_range
    metrics['range_diff'] = abs(real_range - fake_range)

    print(f"     Real range: {real_range:.4f}")
    print(f"     Fake range: {fake_range:.4f}")
    print(f"     Difference: {metrics['range_diff']:.4f}")

    return metrics


def visualize_quality_comparison(real_specs, fake_specs, metrics, n_visual=3):
    """Visualize quality comparison between real and synthetic spectrograms.

    Draws a four-row figure: real samples, generated samples, a bar chart of
    three metrics and the time profile averaged over the first 50 samples.
    The figure is saved to OUTPUT_DIR as 'gan_quality_evaluation.png'.

    Args:
        real_specs: Real spectrograms, indexed as (sample, frequency, time).
        fake_specs: Generated spectrograms with the same layout.
        metrics: Dict providing 'frechet_distance', 'frequency_correlation'
            and 'smoothness_ratio'.
        n_visual: Number of sample columns to show; clamped to the number of
            available spectrograms.
    """
    # The grid gets as many columns as samples are shown, so n_visual is no
    # longer tied to a hard-coded width of 3.
    n_show = max(0, min(n_visual, len(real_specs), len(fake_specs)))
    n_cols = max(1, n_show)

    fig = plt.figure(figsize=(18, 12))
    gs = fig.add_gridspec(4, n_cols, hspace=0.3, wspace=0.3)

    # Rows 0 and 1: sample spectrograms (real above generated)
    for i in range(n_show):
        for grid_row, specs, label in ((0, real_specs, 'Real'), (1, fake_specs, 'Generated')):
            ax = fig.add_subplot(gs[grid_row, i])
            ax.imshow(specs[i], aspect='auto', origin='lower', cmap='viridis')
            ax.set_title(f'{label} Sample {i+1}')
            ax.set_xlabel('Time')
            if i == 0:
                ax.set_ylabel('Mel Frequency')

    # Row 2: Metrics visualization
    ax = fig.add_subplot(gs[2, :])
    metric_names = ['Frechet\nDistance', 'Frequency\nCorrelation', 'Smoothness\nRatio']
    metric_values = [
        metrics['frechet_distance'],
        metrics['frequency_correlation'],
        metrics['smoothness_ratio']
    ]
    # Bar colour is based on a "badness" value per metric:
    # red above 10, blue above 5, green otherwise.
    badness = [
        metrics['frechet_distance'],
        1 - metrics['frequency_correlation'],
        abs(1 - metrics['smoothness_ratio']),
    ]
    colors = ['#e74c3c' if v > 10 else '#3498db' if v > 5 else '#2ecc71'
              for v in badness]

    bars = ax.bar(metric_names, metric_values, color=colors, alpha=0.7, edgecolor='black')
    ax.set_ylabel('Metric Value')
    ax.set_title('GAN Quality Metrics')
    ax.grid(True, alpha=0.3, axis='y')

    # Add value labels on bars
    for bar, value in zip(bars, metric_values):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{value:.3f}',
                ha='center', va='bottom', fontweight='bold')

    # Row 3: Temporal evolution comparison (mean over samples and frequency)
    ax = fig.add_subplot(gs[3, :])
    real_temporal = np.mean(real_specs[:50], axis=(0, 1))
    fake_temporal = np.mean(fake_specs[:50], axis=(0, 1))
    ax.plot(real_temporal, label='Real', linewidth=2, alpha=0.8)
    ax.plot(fake_temporal, label='Synthetic', linewidth=2, alpha=0.8)
    ax.set_xlabel('Time Frame')
    ax.set_ylabel('Average Amplitude')
    ax.set_title('Temporal Evolution Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.savefig(os.path.join(OUTPUT_DIR, 'gan_quality_evaluation.png'), dpi=150, bbox_inches='tight')
    plt.show()

print("‚úÖ Quality evaluation functions defined")

In [None]:
print(f"üé® Generating {NUM_SYNTHETIC} synthetic spectrograms...\n")

# Inference mode: BatchNorm layers use their running statistics.
generator.eval()
synthetic_spectrograms = []
synthetic_labels = []

with torch.no_grad():
    # Step through the requested total in GAN_BATCH_SIZE chunks. The final
    # chunk may be smaller, so exactly NUM_SYNTHETIC samples are produced
    # (integer division of the total would drop the remainder).
    for start in tqdm(range(0, NUM_SYNTHETIC, GAN_BATCH_SIZE), desc="Generating"):
        current_batch = min(GAN_BATCH_SIZE, NUM_SYNTHETIC - start)

        z = torch.randn(current_batch, LATENT_DIM).to(DEVICE)
        # Random (valence, arousal) conditions, uniform in [-1, 1]
        random_conditions = torch.empty(current_batch, CONDITION_DIM).uniform_(-1, 1).to(DEVICE)

        fake_specs = generator(z, random_conditions)

        # Remove channel dimension per batch before moving to host memory
        synthetic_spectrograms.append(fake_specs.squeeze(1).cpu().numpy())
        synthetic_labels.append(random_conditions.cpu().numpy())

# Concatenate all batches
synthetic_spectrograms = np.concatenate(synthetic_spectrograms, axis=0)
synthetic_labels = np.concatenate(synthetic_labels, axis=0)

print(f"‚úÖ Generated {len(synthetic_spectrograms)} synthetic spectrograms")
print(f"Synthetic spectrogram shape: {synthetic_spectrograms.shape}")
print(f"Synthetic labels shape: {synthetic_labels.shape}")

# ========== ROBUST DATA VALIDATION & CONVERSION ==========
print(f"\nüîç Validating and preparing label data...")

def prepare_labels(labels, name="labels"):
    """
    Coerce label data into a numpy array, aiming for shape (N, 2).

    Args:
        labels: Tensor, numpy array, or array-like of label values.
        name: Display name used in the progress messages.

    Returns:
        A numpy array. 1D input, (N, 1) input and any 2D input whose second
        dimension is not 2 are reshaped to (-1, 2); a (2, N) input with N > 2
        is transposed. If any step fails, the error is printed and the input
        is returned without reshaping (tensors are still converted to numpy).
    """
    try:
        # Step 1: Convert to numpy if tensor
        if torch.is_tensor(labels):
            print(f"  - {name} is a tensor, converting to numpy...")
            # detach() lets tensors that track gradients convert; cpu() is a
            # no-op for host tensors, so no separate CUDA branch is needed.
            labels_np = labels.detach().cpu().numpy()
        else:
            labels_np = np.array(labels)
        
        print(f"  - {name} shape after conversion: {labels_np.shape}")
        
        # Step 2: Handle shape issues
        if labels_np.ndim == 1:
            # 1D array - reshape to (N, 2)
            print(f"  - {name} is 1D, reshaping to (-1, 2)...")
            labels_np = labels_np.reshape(-1, 2)
        elif labels_np.ndim == 2:
            # 2D array - check if transposed
            if labels_np.shape[0] == 2 and labels_np.shape[1] > 2:
                # Likely transposed (2, N) -> (N, 2)
                print(f"  - {name} appears transposed {labels_np.shape}, fixing...")
                labels_np = labels_np.T
            elif labels_np.shape[1] == 1:
                # Shape is (N, 1) - might need to be (N//2, 2)
                print(f"  - {name} has shape {labels_np.shape}, reshaping...")
                labels_np = labels_np.reshape(-1, 2)
        
        # Step 3: Final validation (0-d input raises here and lands in the
        # handler below)
        if labels_np.shape[1] != 2:
            print(f"  ‚ö†Ô∏è WARNING: {name} has unexpected shape {labels_np.shape}")
            print(f"  Attempting to force reshape to (-1, 2)...")
            labels_np = labels_np.reshape(-1, 2)
        
        print(f"  ‚úÖ {name} final shape: {labels_np.shape}")
        return labels_np
    
    except Exception as e:
        # Best effort: report and hand back the data without reshaping.
        print(f"  ‚ùå ERROR processing {name}: {e}")
        print(f"  Returning original data as-is")
        return labels if not torch.is_tensor(labels) else labels.detach().cpu().numpy()

# Normalise both label sets to numpy (N, 2) arrays before plotting
real_labels_np = prepare_labels(real_labels, "real_labels")
synthetic_labels = prepare_labels(synthetic_labels, "synthetic_labels")

print(f"\n‚úÖ Data preparation complete!")
print(f"   Real labels: {real_labels_np.shape}")
print(f"   Synthetic labels: {synthetic_labels.shape}")

# ========== VISUALIZATION WITH ERROR HANDLING ==========
# Top row shows the first three real spectrograms, bottom row the first three
# synthetic ones. A failing panel is replaced by an error note so that one bad
# sample cannot abort the cell.
try:
    print(f"\nüìä Creating visualizations...")
    
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    
    # Row 0: real spectrograms
    for i in range(3):
        try:
            axes[0, i].imshow(real_spectrograms[i], aspect='auto', origin='lower', cmap='viridis')
            # Fall back to 0 when a label column is missing
            v_real = 0 if real_labels_np.shape[1] < 1 else real_labels_np[i, 0]
            a_real = 0 if real_labels_np.shape[1] < 2 else real_labels_np[i, 1]
            axes[0, i].set_title(f'Real Spec {i+1}\nV: {v_real:.2f}, A: {a_real:.2f}')
            axes[0, i].set_xlabel('Time')
            axes[0, i].set_ylabel('Mel Bins')
        except Exception as e:
            print(f"  ‚ö†Ô∏è Warning: Could not plot real spectrogram {i}: {e}")
            axes[0, i].text(
                0.5, 0.5, f'Error\n{str(e)[:30]}',
                ha='center', va='center', transform=axes[0, i].transAxes,
            )
    
    # Row 1: synthetic spectrograms
    for i in range(3):
        try:
            axes[1, i].imshow(synthetic_spectrograms[i], aspect='auto', origin='lower', cmap='viridis')
            # Fall back to 0 when a label column is missing
            v_syn = 0 if synthetic_labels.shape[1] < 1 else synthetic_labels[i, 0]
            a_syn = 0 if synthetic_labels.shape[1] < 2 else synthetic_labels[i, 1]
            axes[1, i].set_title(f'Synthetic Spec {i+1}\nV: {v_syn:.2f}, A: {a_syn:.2f}')
            axes[1, i].set_xlabel('Time')
            axes[1, i].set_ylabel('Mel Bins')
        except Exception as e:
            print(f"  ‚ö†Ô∏è Warning: Could not plot synthetic spectrogram {i}: {e}")
            axes[1, i].text(
                0.5, 0.5, f'Error\n{str(e)[:30]}',
                ha='center', va='center', transform=axes[1, i].transAxes,
            )
    
    plt.tight_layout()
    comparison_png = os.path.join(OUTPUT_DIR, 'real_vs_synthetic.png')
    plt.savefig(comparison_png, dpi=150, bbox_inches='tight')
    plt.show()
    print("  ‚úÖ Spectrogram comparison plot saved")
    
except Exception as e:
    print(f"  ‚ùå Error creating spectrogram comparison plot: {e}")
    print("  Continuing execution...")

# ========== DISTRIBUTION COMPARISON WITH ERROR HANDLING ==========
# Left panel: scatter of the two label columns for real vs synthetic data.
# Right panel: sample counts. Each panel fails independently.
try:
    print(f"\nüìà Creating distribution plots...")
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Scatter plot
    try:
        axes[0].scatter(
            real_labels_np[:, 0], real_labels_np[:, 1],
            alpha=0.5, label='Real', s=20,
        )
        axes[0].scatter(
            synthetic_labels[:, 0], synthetic_labels[:, 1],
            alpha=0.3, label='Synthetic', s=20,
        )
        axes[0].set_xlabel('Valence')
        axes[0].set_ylabel('Arousal')
        axes[0].set_title('Valence-Arousal Distribution')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        # Thin axis lines through the origin
        axes[0].axhline(0, color='k', linewidth=0.5)
        axes[0].axvline(0, color='k', linewidth=0.5)
    except Exception as e:
        print(f"  ‚ö†Ô∏è Warning: Could not create scatter plot: {e}")
        axes[0].text(
            0.5, 0.5, f'Scatter plot error\n{str(e)[:30]}',
            ha='center', va='center', transform=axes[0].transAxes,
        )
    
    # Bar plot
    try:
        sizes = [len(real_spectrograms), len(synthetic_spectrograms)]
        sizes.append(sizes[0] + sizes[1])
        labels = ['Real', 'Synthetic', 'Total']
        colors = ['#3498db', '#e74c3c', '#2ecc71']
        axes[1].bar(labels, sizes, color=colors, alpha=0.7, edgecolor='black')
        axes[1].set_ylabel('Number of Samples')
        axes[1].set_title('Dataset Size Comparison')
        axes[1].grid(True, alpha=0.3, axis='y')
        # Write each count just above its bar
        for i, v in enumerate(sizes):
            axes[1].text(i, v + 50, str(v), ha='center', va='bottom', fontweight='bold')
    except Exception as e:
        print(f"  ‚ö†Ô∏è Warning: Could not create bar plot: {e}")
        axes[1].text(
            0.5, 0.5, f'Bar plot error\n{str(e)[:30]}',
            ha='center', va='center', transform=axes[1].transAxes,
        )
    
    plt.tight_layout()
    distribution_png = os.path.join(OUTPUT_DIR, 'augmented_dataset_comparison.png')
    plt.savefig(distribution_png, dpi=150, bbox_inches='tight')
    plt.show()
    print("  ‚úÖ Distribution comparison plot saved")
    
except Exception as e:
    print(f"  ‚ùå Error creating distribution plots: {e}")
    print("  Continuing execution...")

# ========== STATISTICS ==========
# Report sample counts and the augmentation factor (total / real). Any
# failure, e.g. an empty real set, is reported as a warning only.
try:
    print(f"\nüìä Dataset Statistics:")
    n_real = len(real_spectrograms)
    print(f"  - Real samples: {n_real}")
    n_synthetic = len(synthetic_spectrograms)
    print(f"  - Synthetic samples: {n_synthetic}")
    n_total = n_real + n_synthetic
    print(f"  - Total samples: {n_total}")
    aug_factor = n_total / n_real
    print(f"  - Augmentation factor: {aug_factor:.2f}x")
except Exception as e:
    print(f"  ‚ö†Ô∏è Warning: Could not compute statistics: {e}")

# ========== UPDATE GLOBAL VARIABLES ==========
# Later cells read `real_labels`; rebind it to the validated numpy version.
real_labels = real_labels_np

print(f"\n‚úÖ Synthetic data generation and visualization complete!")
print(f"   Proceeding to next step...")

## 6.2️⃣ Evaluate GAN Quality

In [None]:
# Perform GAN quality evaluation
print("\n" + "="*60)
print("üî¨ GAN QUALITY EVALUATION")
print("="*60)

# Compare up to 500 real and 500 synthetic spectrograms
quality_metrics = evaluate_spectrogram_quality(
    real_spectrograms[:500], 
    synthetic_spectrograms[:500]
)

# Visualize comparison
visualize_quality_comparison(
    real_spectrograms[:10], 
    synthetic_spectrograms[:10],
    quality_metrics
)

# Overall quality score
print("\n" + "="*60)
print("üéØ OVERALL QUALITY ASSESSMENT")
print("="*60)

# Compute composite quality score (0-100). Every component is clamped to
# [0, 100] so the weighted sum stays on the advertised scale.
fd_score = max(0, 100 - quality_metrics['frechet_distance'] * 10)  # Lower FD is better
# A correlation can be negative (or NaN); previously that leaked straight
# into the score. max(0.0, nan) evaluates to 0.0, so NaN also maps to 0.
freq_score = min(100.0, max(0.0, quality_metrics['frequency_correlation'] * 100))
# A smoothness ratio of 1.0 (synthetic as smooth as real) scores 100
smooth_score = max(0, 100 - abs(1.0 - quality_metrics['smoothness_ratio']) * 100)

overall_score = (fd_score * 0.4 + freq_score * 0.4 + smooth_score * 0.2)

print(f"  Frechet Distance Score: {fd_score:.1f}/100")
print(f"  Frequency Correlation Score: {freq_score:.1f}/100")
print(f"  Smoothness Score: {smooth_score:.1f}/100")
print(f"\n  üìä Overall GAN Quality Score: {overall_score:.1f}/100")

if overall_score >= 70:
    print("  ‚úÖ Excellent - GAN generates high-quality spectrograms")
elif overall_score >= 50:
    print("  ‚ö†Ô∏è Good - GAN output is acceptable but could be improved")
else:
    print("  ‚ùå Poor - GAN needs significant improvement (mostly noise)")

print("="*60)

## 6.5️⃣ Audio Reconstruction - Listen to GAN Outputs

In [None]:
from IPython.display import Audio, display
import soundfile as sf

def spectrogram_to_audio(spec_normalized, sample_rate=SAMPLE_RATE, n_fft=N_FFT, 
                         hop_length=HOP_LENGTH, n_iter=32):
    """
    Convert normalized mel spectrogram back to audio using Griffin-Lim algorithm.
    
    Args:
        spec_normalized: Normalized spectrogram in range [-1, 1]
        sample_rate: Audio sample rate
        n_fft: FFT window size
        hop_length: Hop length for STFT
        n_iter: Number of Griffin-Lim iterations
    
    Returns:
        audio: Reconstructed audio signal
    """
    # Denormalize spectrogram
    spec_db = spec_normalized * 40.0  # Approximate dB range
    
    # Convert from dB to power. mel_to_audio below is called with power=2.0,
    # i.e. it expects a *power* mel spectrogram; the previous db_to_amplitude
    # call handed it amplitudes, so the dynamic range was square-rooted twice.
    # NOTE(review): assumes the forward pipeline used librosa.power_to_db -
    # confirm against the feature-extraction cell.
    spec_power = librosa.db_to_power(spec_db)
    
    # Reconstruct audio using Griffin-Lim
    audio = librosa.feature.inverse.mel_to_audio(
        spec_power,
        sr=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_iter=n_iter,
        power=2.0,  # explicit, so it visibly matches db_to_power above
        fmin=FMIN,
        fmax=FMAX
    )
    
    return audio


def generate_and_listen_to_samples(generator, n_samples=5, emotions=None, device=DEVICE):
    """
    Generate synthetic spectrograms and convert them to audio for listening.
    
    For every sample this writes a WAV file and a PNG of the spectrogram into
    OUTPUT_DIR and shows an inline audio player and the figure.
    
    Args:
        generator: Trained GAN generator
        n_samples: Number of samples to generate
        emotions: List of (valence, arousal) tuples, or None for random.
            Samples beyond the end of the list use random values.
        device: Torch device
    """
    generator.eval()
    
    print(f"üéµ Generating {n_samples} audio samples from GAN...")
    
    with torch.no_grad():
        for i in range(n_samples):
            # Generate noise
            z = torch.randn(1, LATENT_DIM).to(device)
            
            # Use provided emotions or generate random
            if emotions and i < len(emotions):
                valence, arousal = emotions[i]
            else:
                valence = np.random.uniform(-1, 1)
                arousal = np.random.uniform(-1, 1)
            
            condition = torch.FloatTensor([[valence, arousal]]).to(device)
            
            # Generate spectrogram
            fake_spec = generator(z, condition)
            fake_spec_np = fake_spec.squeeze().cpu().numpy()
            
            # Convert to audio ("\n" was previously written as "\\n", which
            # printed a literal backslash-n instead of a blank line)
            print(f"\nüéß Sample {i+1}: Valence={valence:.2f}, Arousal={arousal:.2f}")
            audio = spectrogram_to_audio(fake_spec_np)
            
            # Peak-normalize to 0.9; the epsilon guards against silent output
            audio = audio / (np.max(np.abs(audio)) + 1e-8) * 0.9
            
            # Save audio file
            audio_path = os.path.join(OUTPUT_DIR, f'generated_sample_{i+1}_v{valence:.2f}_a{arousal:.2f}.wav')
            sf.write(audio_path, audio, SAMPLE_RATE)
            print(f"   üíæ Saved: {audio_path}")
            
            # Display audio player
            display(Audio(audio, rate=SAMPLE_RATE))
            
            # Visualize spectrogram
            # NOTE(review): the colorbar is labelled in dB but the plotted
            # values are the generator's raw output - confirm intended units.
            plt.figure(figsize=(10, 4))
            librosa.display.specshow(
                fake_spec_np,
                sr=SAMPLE_RATE,
                hop_length=HOP_LENGTH,
                x_axis='time',
                y_axis='mel',
                fmax=FMAX,
                cmap='viridis'
            )
            plt.colorbar(format='%+2.0f dB')
            # Real newline so the title renders on two lines
            plt.title(f'Generated Spectrogram {i+1}\nValence: {valence:.2f}, Arousal: {arousal:.2f}')
            plt.tight_layout()
            plt.savefig(os.path.join(OUTPUT_DIR, f'generated_spec_{i+1}.png'), dpi=150, bbox_inches='tight')
            plt.show()


# Define emotion targets to test, as (valence, arousal) pairs
test_emotions = [
    (-0.8, -0.6),  # Sad, calm
    (0.8, 0.7),    # Happy, energetic
    (-0.3, 0.8),   # Angry, tense
    (0.5, -0.5),   # Content, relaxed
    (0.0, 0.0),    # Neutral
]

print("üé® Generating audio from synthetic spectrograms...")
# "\n" here and below was previously written as "\\n", which printed a
# literal backslash-n instead of a blank line.
print("This allows you to qualitatively assess GAN generation quality.\n")

# One sample per emotion target defined above
generate_and_listen_to_samples(
    generator, 
    n_samples=5, 
    emotions=test_emotions,
    device=DEVICE
)

print("\n" + "="*60)
print("‚úÖ Audio generation complete!")
print("="*60)
print("üí° Tips for evaluation:")
print("  - Listen for musical structure vs pure noise")
print("  - Check if emotion patterns are perceptible")
print("  - Compare across different valence/arousal values")
print("  - Real music should have harmonic and temporal patterns")
print("="*60)

## 8️⃣ ViT Model Definition with Advanced Features

## 7️⃣ Prepare Augmented Dataset for ViT Training

In [None]:
# Stack the real samples first, then the synthetic ones, along the sample axis
print("üîÑ Preparing augmented dataset for ViT training...")

spectrogram_sources = (real_spectrograms, synthetic_spectrograms)
label_sources = (real_labels, synthetic_labels)
all_spectrograms = np.concatenate(spectrogram_sources, axis=0)
all_labels = np.concatenate(label_sources, axis=0)

# Summary of the combined arrays
print(f"‚úÖ Total augmented dataset:")
print(f"   - Total samples: {len(all_spectrograms)}")
print(f"   - Shape: {all_spectrograms.shape}")
print(f"   - Labels shape: {all_labels.shape}")


class SpectrogramDataset(Dataset):
    """Dataset for mel-spectrograms with ViT preprocessing.

    Each item is min-max scaled per sample, bilinearly resized to a square
    image, repeated across 3 channels and standardized with the ImageNet
    mean/std constants.

    Args:
        spectrograms: Indexable collection of 2D spectrogram arrays.
        labels: Indexable collection of label vectors, aligned with
            `spectrograms`.
        image_size: Side length of the square output image.
    """
    
    def __init__(self, spectrograms, labels, image_size=VIT_IMAGE_SIZE):
        self.spectrograms = spectrograms
        self.labels = labels
        self.image_size = image_size
        # Built once here rather than on every __getitem__ call; shaped
        # (3, 1, 1) so they broadcast over a (3, H, W) image.
        self._mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
        self._std = torch.tensor(IMAGENET_STD).view(3, 1, 1)
    
    def __len__(self):
        return len(self.spectrograms)
    
    def __getitem__(self, idx):
        """Return (image tensor of shape (3, image_size, image_size), label tensor)."""
        # Get spectrogram and label
        spec = self.spectrograms[idx]  # Shape: (n_mels, time_steps)
        label = self.labels[idx]  # Shape: (2,)
        
        # Normalize spectrogram to [0, 1]; epsilon avoids 0/0 on constant input
        spec_min = spec.min()
        spec_max = spec.max()
        spec_norm = (spec - spec_min) / (spec_max - spec_min + 1e-8)
        
        # Resize to ViT input size; interpolate needs a (N, C, H, W) tensor,
        # so add channel and batch dims, then drop the batch dim again
        spec_resized = torch.FloatTensor(spec_norm).unsqueeze(0)  # Add channel dim
        spec_resized = F.interpolate(
            spec_resized.unsqueeze(0), 
            size=(self.image_size, self.image_size), 
            mode='bilinear', 
            align_corners=False
        ).squeeze(0)
        
        # Convert to 3 channels (RGB) by triplicating
        spec_rgb = spec_resized.repeat(3, 1, 1)
        
        # Apply ImageNet normalization
        spec_normalized = (spec_rgb - self._mean) / self._std
        
        return spec_normalized, torch.FloatTensor(label)


# Wrap the combined arrays in the ViT preprocessing dataset
print("\nüîÑ Creating dataset...")
full_dataset = SpectrogramDataset(all_spectrograms, all_labels)

# Train/validation sizes: the validation set takes whatever is left over
train_size = int(TRAIN_SPLIT * len(full_dataset))
val_size = len(full_dataset) - train_size

# Fixed seed so the split is the same on every run
split_rng = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=split_rng,
)

# Both loaders share everything except shuffling
shared_loader_args = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_args)

print(f"‚úÖ Datasets and dataloaders created:")
print(f"   - Train samples: {len(train_dataset)}")
print(f"   - Validation samples: {len(val_dataset)}")
print(f"   - Train batches: {len(train_loader)}")
print(f"   - Validation batches: {len(val_loader)}")

In [None]:
class ViTForEmotionRegression(nn.Module):
    """Vision Transformer for emotion regression with valence/arousal prediction.

    Wraps a Hugging Face `ViTModel` backbone and a small MLP head whose final
    Tanh keeps predictions in [-1, 1].

    Args:
        model_name: Local directory or Hugging Face Hub id of the backbone.
        num_emotions: Number of regression outputs.
        freeze_backbone: If True, backbone parameters get requires_grad=False.
        dropout: Dropout probability used throughout the head.
    """
    
    def __init__(self, model_name=VIT_MODEL_NAME, num_emotions=2, freeze_backbone=False, dropout=0.1):
        super().__init__()
        self.model_name = model_name
        self.num_emotions = num_emotions
        
        # "\n" was previously written as "\\n", which printed a literal
        # backslash-n instead of a blank line.
        print(f"\nü§ñ Initializing ViT Model: {model_name}")
        
        # Load ViT model with robust error handling
        self.vit_model = self._load_vit_model(model_name)
        
        # Get the hidden size from the model configuration
        self.hidden_size = self.vit_model.config.hidden_size
        print(f"  Hidden Size: {self.hidden_size}")
        
        # Freeze backbone if requested
        if freeze_backbone:
            self._freeze_backbone()
            print("  üßä Backbone frozen")
        else:
            print("  üî• Backbone trainable")
        
        # Add custom regression head
        self.emotion_head = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Dropout(dropout),
            nn.Linear(self.hidden_size, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, self.num_emotions),
            nn.Tanh()  # Output range [-1, 1] for valence/arousal
        )
        
    def _load_vit_model(self, model_name):
        """Load the backbone from a local path if it exists, else from the Hub."""
        print(f"  üì• Loading model from: {model_name}")
        
        # Check if this is a local path or online model
        if os.path.exists(model_name):
            print(f"  üóÇÔ∏è Loading from local path...")
            return self._load_local_model(model_name)
        else:
            print(f"  üåê Loading from Hugging Face Hub...")
            return self._load_online_model(model_name)
    
    def _load_local_model(self, model_path):
        """Load model from local filesystem.

        Any failure (missing path, missing config.json, load error) falls back
        to downloading 'google/vit-base-patch16-224-in21k'.
        """
        try:
            print(f"  üìÇ Checking local model at: {model_path}")
            
            # Verify the path exists and contains model files
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Model path does not exist: {model_path}")
            
            # Check for required model files
            required_files = ['config.json']
            missing_files = [f for f in required_files if not os.path.exists(os.path.join(model_path, f))]
            
            if missing_files:
                raise FileNotFoundError(f"Missing model files: {missing_files}")
            
            # Load the model
            print(f"  ‚ö° Loading ViT from local path...")
            model = ViTModel.from_pretrained(model_path, local_files_only=True)
            print(f"  ‚úÖ Successfully loaded local model!")
            return model
            
        except Exception as e:
            print(f"  ‚ùå Local model loading failed: {str(e)}")
            print(f"  üîÑ Falling back to online download...")
            return self._load_online_model('google/vit-base-patch16-224-in21k')
    
    def _load_online_model(self, model_name):
        """Load model from Hugging Face Hub with retry logic.

        Raises:
            RuntimeError: If every attempt fails; chained to the last error.
        """
        # Imported here because this cell does not import `time` itself and
        # would otherwise depend on a later cell having run first.
        import time

        max_retries = 3
        retry_delays = [5, 10, 20]  # seconds
        
        for attempt in range(max_retries):
            try:
                print(f"  üåê Download attempt {attempt + 1}/{max_retries}...")
                
                # Try to load from cache first
                model = ViTModel.from_pretrained(
                    model_name,
                    resume_download=True,
                    force_download=False,
                    cache_dir='/kaggle/working/model_cache'
                )
                
                print(f"  ‚úÖ Successfully loaded {model_name}")
                return model
                
            except Exception as e:
                print(f"  ‚ùå Attempt {attempt + 1} failed: {str(e)}")
                
                if attempt < max_retries - 1:
                    delay = retry_delays[attempt]
                    print(f"  ‚è≥ Retrying in {delay} seconds...")
                    time.sleep(delay)
                else:
                    print(f"  üíÄ All download attempts failed!")
                    print(f"  ? SOLUTION: Download the model locally using the provided scripts:")
                    print(f"     1. Run download_vit_model.py on your local machine")
                    print(f"     2. Upload the model as a Kaggle dataset")
                    print(f"     3. Update VIT_MODEL_NAME to your dataset path")
                    # `from e` keeps the underlying error in the traceback
                    raise RuntimeError(
                        f"Failed to load ViT model after {max_retries} attempts: {str(e)}"
                    ) from e
    
    def _freeze_backbone(self):
        """Freeze the ViT backbone parameters."""
        for param in self.vit_model.parameters():
            param.requires_grad = False
    
    def forward(self, pixel_values):
        """Forward pass through ViT + emotion head."""
        # Get ViT outputs
        outputs = self.vit_model(pixel_values=pixel_values)
        
        # Use the pooled output (CLS token representation)
        pooled_output = outputs.pooler_output
        
        # Pass through emotion regression head
        emotions = self.emotion_head(pooled_output)
        
        return emotions

## 8️⃣ Load Pre-trained Vision Transformer (ViT) Model

In [None]:
# Pre-download and verify model availability, so that cache or Hub problems
# show up before the model is constructed
print("üîç Checking model availability...")

from huggingface_hub import hf_hub_download, model_info
import time

def verify_model_download(model_name, max_retries=3):
    """
    Verify model can be downloaded or is available locally.

    Args:
        model_name: Hub id or local path of the ViT checkpoint.
        max_retries: Accepted for interface compatibility; currently unused.

    Returns:
        True if the model loads from the local cache or its Hub metadata can
        be fetched, False if any check raised.
    """
    print(f"Model: {model_name}")
    
    # Check if model exists in cache
    try:
        from transformers import ViTModel
        
        # Try loading from cache first
        try:
            print("  ‚è≥ Checking local cache...")
            ViTModel.from_pretrained(model_name, local_files_only=True)
            print("  ‚úÖ Model found in local cache!")
            return True
        except Exception:
            # Not a bare `except:` - Ctrl-C / SystemExit must still propagate
            print("  ‚ÑπÔ∏è Model not in cache, will download...")
        
        # Check model info
        print("  ‚è≥ Verifying model on Hugging Face Hub...")
        info = model_info(model_name)
        # Read both snake_case and camelCase attribute names so a rename
        # between huggingface_hub releases is not misreported as
        # "verification failed".
        # NOTE(review): attribute names assumed from ModelInfo docs - confirm.
        model_id = getattr(info, "id", None) or getattr(info, "modelId", None)
        modified = getattr(info, "last_modified", None) or getattr(info, "lastModified", None)
        print(f"  ‚úì Model exists: {model_id}")
        print(f"  ‚úì Last modified: {modified}")
        
        return True
        
    except Exception as e:
        print(f"  ‚ö†Ô∏è Warning: {str(e)[:150]}")
        print("  üí° Will attempt to download during model initialization...")
        return False

# Run the availability check and, on failure, print guidance (non-fatal)
model_available = verify_model_download(VIT_MODEL_NAME)

if not model_available:
    _verification_help = (
        "\n‚ö†Ô∏è WARNING: Model verification failed!",
        "The notebook will still attempt to download the model.",
        "If download fails, the model will be initialized with random weights.",
        "\nAlternatives:",
        "  1. Wait and retry (Hugging Face servers may be temporarily busy)",
        "  2. Use a smaller model: 'google/vit-base-patch16-224'",
        "  3. Continue without pre-training (train from scratch)",
    )
    for _help_line in _verification_help:
        print(_help_line)
    
print("\n" + "=" * 60)

### 🔧 Troubleshooting: Manual Model Download (Run if needed)

If automatic download fails due to Hugging Face server issues, you can:

**Option 1: Use Alternative Model**
```python
# Use the ImageNet-1k fine-tuned checkpoint (same ViT-Base size, different weights)
VIT_MODEL_NAME = 'google/vit-base-patch16-224'  # Instead of -in21k
```

**Option 2: Manual Download via Git**
```bash
# In a terminal or code cell with !
!git lfs install
!git clone https://huggingface.co/google/vit-base-patch16-224-in21k
```

**Option 3: Continue without Pre-training**
- The notebook will automatically fall back to random initialization
- Results will be similar to custom AST (no transfer learning benefits)

**Option 4: Wait and Retry**
- Hugging Face servers may be temporarily overloaded
- Try again in 10-15 minutes

In [None]:
# ALTERNATIVE DOWNLOAD METHOD (Run this cell if automatic download fails)
# This cell attempts to download the model using direct URLs

import requests
from tqdm import tqdm

def download_file(url, filename, timeout=30):
    """Download file with progress bar.

    Args:
        url: Address to fetch.
        filename: Destination path; opened for binary writing.
        timeout: Seconds to wait for the server before giving up, so a
            stalled connection cannot hang the notebook indefinitely.

    Returns:
        True on success, False if anything failed (the error is printed).
    """
    try:
        # Context manager releases the connection even if writing fails
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this, a 4xx/5xx error page would be saved to disk and
            # reported as a successful download.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            
            with open(filename, 'wb') as file, tqdm(
                desc=filename,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for data in response.iter_content(chunk_size=1024):
                    size = file.write(data)
                    progress_bar.update(size)
        
        print(f"‚úÖ Downloaded {filename}")
        return True
    except Exception as e:
        # Best-effort helper: report the failure and let the caller decide
        print(f"‚ùå Failed to download {filename}: {e}")
        return False

# Alternative route: fetch the whole model repository with
# huggingface_hub.snapshot_download and print the resulting local path.
print("üîÑ Attempting alternative download method...")

try:
    from huggingface_hub import snapshot_download
    
    # Where the downloaded repository is stored
    cache_dir = "/kaggle/working/model_cache"
    
    print(f"Downloading {VIT_MODEL_NAME} to {cache_dir}...")
    snapshot_options = dict(
        repo_id=VIT_MODEL_NAME,
        cache_dir=cache_dir,
        resume_download=True,
        max_workers=1,  # Use single worker to avoid 500 errors
    )
    model_path = snapshot_download(**snapshot_options)
    
    print(f"‚úÖ Model downloaded successfully to: {model_path}")
    print("üí° Now update VIT_MODEL_NAME to use local path:")
    print(f"    VIT_MODEL_NAME = '{model_path}'")
    
except Exception as e:
    # Non-fatal: list the remaining options
    print(f"‚ö†Ô∏è Alternative download also failed: {str(e)[:200]}")
    _fallback_options = (
        "\nüí° FALLBACK OPTIONS:",
        "  1. Use smaller model: VIT_MODEL_NAME = 'google/vit-base-patch16-224'",
        "  2. Train without pre-training (random initialization)",
        "  3. Wait 15 minutes and retry",
    )
    for _option in _fallback_options:
        print(_option)
    
print("\n" + "=" * 60)

In [None]:
import time

class ViTForEmotionRegression(nn.Module):
    """
    Vision Transformer for Emotion Regression
    Uses pre-trained ViT from Hugging Face and adds regression head

    Args:
        model_name: Hub id or local path passed to ViTModel.from_pretrained.
        freeze_backbone: If True, all backbone parameters get
            requires_grad=False.
        dropout: Dropout probability inside the regression head.

    Attributes:
        vit: The ViTModel backbone.
        hidden_size: Backbone hidden size, read from its config.
        regression_head: MLP mapping the [CLS] embedding to 2 outputs.
    """
    def __init__(self, model_name=VIT_MODEL_NAME, freeze_backbone=FREEZE_BACKBONE, dropout=DROPOUT):
        super().__init__()
        
        # Load pre-trained ViT model with retry logic
        print(f"Loading pre-trained ViT model: {model_name}...")
        self.vit = self._load_backbone(model_name)
        
        # Freeze backbone if specified
        if freeze_backbone:
            for param in self.vit.parameters():
                param.requires_grad = False
            print("  ‚úì ViT backbone frozen")
        else:
            print("  ‚úì ViT backbone trainable")
        
        # Get hidden size from ViT config
        self.hidden_size = self.vit.config.hidden_size  # 768 for base model
        
        # Regression head for valence and arousal
        self.regression_head = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Linear(self.hidden_size, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 2)  # Valence and Arousal
        )

    @staticmethod
    def _load_backbone(model_name, max_retries=3, retry_delay=5):
        """Return a ViTModel: download with retries, then cache, then random init.

        Tries `max_retries` downloads with exponential backoff starting at
        `retry_delay` seconds. If all fail, tries the local cache only, and as
        a last resort builds an untrained ViT-Base from an explicit config.
        """
        for attempt in range(max_retries):
            try:
                # Try loading with resume_download=True to handle interrupted downloads
                vit = ViTModel.from_pretrained(
                    model_name,
                    resume_download=True,
                    force_download=False,
                    local_files_only=False
                )
                print(f"  ‚úì Model loaded successfully!")
                return vit
                
            except Exception as e:
                if attempt < max_retries - 1:
                    print(f"  ‚ö†Ô∏è Download attempt {attempt + 1} failed: {str(e)[:100]}")
                    print(f"  ‚è≥ Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                else:
                    print(f"  ‚ùå All download attempts failed!")
                    # Show the cause instead of discarding it
                    print(f"     Last error: {str(e)[:100]}")
                    print(f"  üí° Trying alternative approach...")
        
        # Fallback: Try to load from cache only
        try:
            vit = ViTModel.from_pretrained(
                model_name,
                local_files_only=True
            )
            print(f"  ‚úì Loaded from local cache!")
            return vit
        except Exception:
            # Not a bare `except:` - a Ctrl-C here must abort rather than
            # silently continue with random weights.
            # Last resort: Create model from config
            print(f"  üîß Creating model from config (no pre-trained weights)...")
            config = ViTConfig(
                hidden_size=768,
                num_hidden_layers=12,
                num_attention_heads=12,
                intermediate_size=3072,
                image_size=224,
                patch_size=16,
                num_channels=3
            )
            vit = ViTModel(config)
            print(f"  ‚ö†Ô∏è WARNING: Using randomly initialized ViT (no transfer learning)")
            return vit
        
    def forward(self, pixel_values):
        """Map a batch of images to (batch_size, 2) emotion predictions."""
        # Get ViT outputs
        outputs = self.vit(pixel_values=pixel_values)
        
        # Use [CLS] token representation
        cls_output = outputs.last_hidden_state[:, 0]  # (batch_size, hidden_size)
        
        # Regression head
        emotion_output = self.regression_head(cls_output)  # (batch_size, 2)
        
        return emotion_output


# Build the model on DEVICE and print a parameter summary. On failure, print
# troubleshooting hints and re-raise so the notebook stops here.
_rule = "=" * 60
print(_rule)
print("ü§ñ INITIALIZING VISION TRANSFORMER")
print(_rule)

try:
    model = ViTForEmotionRegression(
        model_name=VIT_MODEL_NAME,
        freeze_backbone=FREEZE_BACKBONE,
        dropout=DROPOUT,
    )
    model = model.to(DEVICE)
    
    # Parameter counts: everything vs. what the optimizer may update
    total_params = 0
    trainable_params = 0
    for _param in model.parameters():
        total_params += _param.numel()
        if _param.requires_grad:
            trainable_params += _param.numel()
    
    print("\n" + _rule)
    print("‚úÖ MODEL INITIALIZED SUCCESSFULLY")
    print(_rule)
    print(f"Model: {VIT_MODEL_NAME}")
    print(f"Hidden size: {model.hidden_size}")
    print(f"Freeze backbone: {FREEZE_BACKBONE}")
    print(f"\nTotal parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Frozen parameters: {total_params - trainable_params:,}")
    print(_rule)
    
except Exception as e:
    print(f"\n‚ùå ERROR: Failed to initialize model")
    print(f"Error details: {str(e)}")
    for _tip in (
        "\nüí° TROUBLESHOOTING TIPS:",
        "  1. Check your internet connection",
        "  2. Try restarting the kernel and running again",
        "  3. Manually download the model from: https://huggingface.co/google/vit-base-patch16-224-in21k",
        "  4. Set local_files_only=True if model is already downloaded",
    ):
        print(_tip)
    raise

## 9️⃣ Train ViT Model on Augmented Dataset

In [None]:
# ===============================
# Weighted Emotion Loss Function
# ===============================
class WeightedEmotionLoss(nn.Module):
    """Weighted sum of per-dimension MSE for valence/arousal regression.

    Column 0 of the inputs is scored as valence and column 1 as arousal;
    each column's mean squared error is scaled by its own weight before the
    two terms are added.
    """

    def __init__(self, valence_weight=0.6, arousal_weight=0.4):
        super().__init__()
        self.valence_weight = valence_weight
        self.arousal_weight = arousal_weight

    def forward(self, pred, target):
        """Return ``valence_weight * MSE(col 0) + arousal_weight * MSE(col 1)``."""
        valence_term = self.valence_weight * F.mse_loss(pred[:, 0], target[:, 0])
        arousal_term = self.arousal_weight * F.mse_loss(pred[:, 1], target[:, 1])
        return valence_term + arousal_term


# ===============================
# Progressive Layer Unfreezing
# ===============================
def unfreeze_vit_layers(model, layers_to_unfreeze):
    """
    Make only the selected ViT encoder layers (plus the head) trainable.

    Every parameter under ``model.vit`` is frozen first, so each call defines
    the complete set of trainable backbone layers rather than adding to the
    previous one. Indices at or beyond the number of encoder layers are
    skipped silently. A one-line trainable-parameter summary is printed.

    Args:
        model: Object exposing ``vit`` and ``regression_head`` sub-modules
        layers_to_unfreeze: Iterable of encoder layer indices (e.g., [10, 11])
    """
    backbone = model.vit

    # Start from a fully frozen backbone
    for weight in backbone.parameters():
        weight.requires_grad = False

    # Re-enable gradients for the requested encoder layers, if the backbone has any
    if hasattr(backbone, 'encoder') and hasattr(backbone.encoder, 'layer'):
        for idx in layers_to_unfreeze:
            if idx >= len(backbone.encoder.layer):
                continue
            for weight in backbone.encoder.layer[idx].parameters():
                weight.requires_grad = True
            print(f"   ‚úì Unfroze encoder layer {idx}")

    # The regression head always trains
    for weight in model.regression_head.parameters():
        weight.requires_grad = True

    total = 0
    trainable = 0
    for weight in model.parameters():
        count = weight.numel()
        total += count
        if weight.requires_grad:
            trainable += count
    print(f"   Trainable params: {trainable:,} / {total:,} ({100*trainable/total:.1f}%)")


# ===============================
# Enhanced Loss and Optimizer
# ===============================
criterion = WeightedEmotionLoss(valence_weight=0.6, arousal_weight=0.4)

# Separate learning rates for backbone and head.
# Only parameters that currently require gradients are handed to the optimizer;
# anything whose name contains 'regression_head' goes to the head group.
backbone_params = []
head_params = []

for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if 'regression_head' in name:
        head_params.append(param)
    else:
        backbone_params.append(param)

optimizer = optim.AdamW([
    {'params': head_params, 'lr': LEARNING_RATE},
    {'params': backbone_params, 'lr': LEARNING_RATE * 0.1}  # Lower LR for backbone
], weight_decay=WEIGHT_DECAY)

# Cosine decay over the whole run, down to eta_min
scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)

# "\n" (not "\\n") so a blank line is printed instead of a literal backslash-n.
# Weights are read from the criterion so the message cannot drift from the config.
print("\nüéØ Using Weighted Emotion Loss:")
print(f"   Valence weight: {criterion.valence_weight:.0%}")
print(f"   Arousal weight: {criterion.arousal_weight:.0%}")


# ===============================
# Evaluation Metrics
# ===============================
def concordance_correlation_coefficient(y_true, y_pred):
    """Calculate the concordance correlation coefficient (CCC) as a float.

    CCC = 2 * cov(t, p) / (var(t) + var(p) + (mean(t) - mean(p))**2)

    Variances use the population (divide-by-N) estimator so they match the
    covariance, which is computed as a plain mean. With the unbiased (N-1)
    variance the score for identical inputs would be (N-1)/N instead of 1,
    and a single-sample input would produce NaN.

    Args:
        y_true: Tensor of reference values.
        y_pred: Tensor of predictions, same shape as ``y_true``.

    Returns:
        The CCC as a Python float. A 1e-8 term in the denominator avoids
        division by zero when both inputs are constant and equal.
    """
    mean_true = torch.mean(y_true)
    mean_pred = torch.mean(y_pred)
    var_true = torch.var(y_true, unbiased=False)
    var_pred = torch.var(y_pred, unbiased=False)
    covariance = torch.mean((y_true - mean_true) * (y_pred - mean_pred))

    ccc = (2 * covariance) / (var_true + var_pred + (mean_true - mean_pred) ** 2 + 1e-8)
    return ccc.item()


# ===============================
# Training Function with Mixed Precision
# ===============================
# Gradient scaler for mixed-precision training; created only when DEVICE is CUDA,
# otherwise left as None so train_epoch takes its full-precision path.
if DEVICE.type == 'cuda':
    scaler = torch.cuda.amp.GradScaler()
else:
    scaler = None

def train_epoch(model, loader, criterion, optimizer, device, use_amp=True):
    """Run one optimisation pass over ``loader`` and return epoch metrics.

    Each batch is moved to ``device``, a forward/backward pass is run,
    gradients are clipped to a global norm of 1.0 and the optimizer is
    stepped. When ``use_amp`` is true and the module-level ``scaler`` is set,
    the forward pass runs under CUDA autocast with gradient scaling.

    Args:
        model: Network to train; put into ``train()`` mode here.
        loader: Iterable of ``(specs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss.
        optimizer: Optimizer holding the parameters to update.
        device: Device the batches are moved to.
        use_amp: Request mixed precision (ignored when ``scaler`` is falsy).

    Returns:
        Tuple ``(avg_loss, mae, ccc_valence, ccc_arousal)`` where ``avg_loss``
        is the mean of per-batch losses and the rest are computed over all
        predictions collected during the epoch (column 0 = valence,
        column 1 = arousal).
    """
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []

    # Same short-circuit as before: `scaler` is only consulted when AMP is requested.
    amp_enabled = bool(use_amp and scaler)

    for specs, labels in tqdm(loader, desc="Training", leave=False):
        specs, labels = specs.to(device), labels.to(device)

        optimizer.zero_grad()

        if amp_enabled:
            with torch.cuda.amp.autocast():
                outputs = model(specs)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            # Unscale first so clipping sees true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(specs)
            loss = criterion(outputs, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        total_loss += loss.item()

        # Autocast may hand back reduced-precision outputs; promote them to
        # float32 so MAE/CCC are not computed in half precision and stay
        # comparable with validate(), which runs without autocast.
        batch_preds = outputs.detach()
        if batch_preds.dtype in (torch.float16, torch.bfloat16):
            batch_preds = batch_preds.float()
        all_preds.append(batch_preds.cpu())
        all_labels.append(labels.detach().cpu())

    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)

    avg_loss = total_loss / len(loader)
    mae = F.l1_loss(all_preds, all_labels).item()
    ccc_valence = concordance_correlation_coefficient(all_labels[:, 0], all_preds[:, 0])
    ccc_arousal = concordance_correlation_coefficient(all_labels[:, 1], all_preds[:, 1])

    return avg_loss, mae, ccc_valence, ccc_arousal


# ===============================
# Validation Function
# ===============================
def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network to evaluate; put into ``eval()`` mode here.
        loader: Iterable of ``(specs, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(avg_loss, mae, ccc_valence, ccc_arousal, all_preds, all_labels)``.
        ``avg_loss`` is the mean of per-batch losses; the last two entries are
        the CPU tensors of every prediction and label seen, concatenated
        along dim 0 (column 0 = valence, column 1 = arousal).
    """
    model.eval()
    running_loss = 0
    pred_batches, label_batches = [], []

    with torch.no_grad():
        for specs, labels in tqdm(loader, desc="Validating", leave=False):
            specs = specs.to(device)
            labels = labels.to(device)

            outputs = model(specs)
            running_loss += criterion(outputs, labels).item()

            pred_batches.append(outputs.cpu())
            label_batches.append(labels.cpu())

    all_preds = torch.cat(pred_batches, dim=0)
    all_labels = torch.cat(label_batches, dim=0)

    avg_loss = running_loss / len(loader)
    mae = F.l1_loss(all_preds, all_labels).item()
    # One CCC per emotion dimension
    ccc_valence, ccc_arousal = (
        concordance_correlation_coefficient(all_labels[:, dim], all_preds[:, dim])
        for dim in (0, 1)
    )

    return avg_loss, mae, ccc_valence, ccc_arousal, all_preds, all_labels


# ===============================
# Training Loop with Progressive Unfreezing
# ===============================
print("\\nüöÄ Starting ViT Training on Augmented Dataset...\\n")

history = {
    'train_loss': [], 'train_mae': [], 'train_ccc_v': [], 'train_ccc_a': [],
    'val_loss': [], 'val_mae': [], 'val_ccc_v': [], 'val_ccc_a': []
}

best_val_loss = float('inf')
patience = 5
patience_counter = 0

# Progressive unfreezing schedule
unfreezing_schedule = {
    8: [11],      # Unfreeze last layer at epoch 8
    12: [10, 11],  # Unfreeze last 2 layers at epoch 12
    16: [9, 10, 11],  # Unfreeze last 3 layers at epoch 16
}

# Main training loop: optional layer unfreezing, one train pass, one
# validation pass, LR schedule step, history logging, checkpointing and
# early stopping. All status prints use "\n" (not "\\n") so that real blank
# lines are emitted instead of a literal backslash-n.
for epoch in range(NUM_EPOCHS):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    print(f"{'='*60}")

    # Progressive unfreezing
    if epoch in unfreezing_schedule:
        print(f"\nüîì Progressive Unfreezing at Epoch {epoch + 1}:")
        unfreeze_vit_layers(model, unfreezing_schedule[epoch])

        # Rebuild the parameter groups from whatever is trainable now
        backbone_params = []
        head_params = []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if 'regression_head' in name:
                head_params.append(param)
            else:
                backbone_params.append(param)

        # A fresh optimizer is created, so AdamW moment estimates restart here
        optimizer = optim.AdamW([
            {'params': head_params, 'lr': LEARNING_RATE},
            {'params': backbone_params, 'lr': LEARNING_RATE * 0.1}
        ], weight_decay=WEIGHT_DECAY)

        # Cosine schedule restarts and spans only the remaining epochs
        scheduler = CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS - epoch, eta_min=1e-6)

    # Train
    train_loss, train_mae, train_ccc_v, train_ccc_a = train_epoch(
        model, train_loader, criterion, optimizer, DEVICE
    )

    # Validate
    val_loss, val_mae, val_ccc_v, val_ccc_a, val_preds, val_labels = validate(
        model, val_loader, criterion, DEVICE
    )

    # Update learning rate; the value read back is the head group's LR
    # after the step, i.e. the rate the next epoch will use
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']

    # Save history
    history['train_loss'].append(train_loss)
    history['train_mae'].append(train_mae)
    history['train_ccc_v'].append(train_ccc_v)
    history['train_ccc_a'].append(train_ccc_a)
    history['val_loss'].append(val_loss)
    history['val_mae'].append(val_mae)
    history['val_ccc_v'].append(val_ccc_v)
    history['val_ccc_a'].append(val_ccc_a)

    # Print metrics
    print("\nüìä Training Metrics:")
    print(f"  Loss: {train_loss:.4f} | MAE: {train_mae:.4f}")
    print(f"  CCC Valence: {train_ccc_v:.4f} | CCC Arousal: {train_ccc_a:.4f}")

    print("\nüìä Validation Metrics:")
    print(f"  Loss: {val_loss:.4f} | MAE: {val_mae:.4f}")
    print(f"  CCC Valence: {val_ccc_v:.4f} | CCC Arousal: {val_ccc_a:.4f}")
    print(f"  Learning Rate: {current_lr:.2e}")

    # Early stopping and model saving
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, 'best_vit_model.pth'))
        print(f"\n‚úÖ Best model saved! (Val Loss: {val_loss:.4f})")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"\n‚ö†Ô∏è Early stopping triggered after {patience} epochs without improvement")
            break

print("\n" + "="*60)
print("‚úÖ Training Complete!")
print("="*60)
print(f"Best Validation Loss: {best_val_loss:.4f}")

## 🔟 Visualize Results & Analysis

In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# One entry per panel:
# (axis, history key suffix, legend stem, y-label, title, draw y=0 reference?)
curve_specs = [
    (axes[0, 0], 'loss', 'Loss', 'MSE Loss', 'Training & Validation Loss', False),
    (axes[0, 1], 'mae', 'MAE', 'Mean Absolute Error', 'Training & Validation MAE', False),
    (axes[1, 0], 'ccc_v', 'CCC', 'CCC', 'Valence CCC', True),
    (axes[1, 1], 'ccc_a', 'CCC', 'CCC', 'Arousal CCC', True),
]

for curve_ax, metric_key, legend_stem, y_label, panel_title, zero_line in curve_specs:
    curve_ax.plot(history[f'train_{metric_key}'], label=f'Train {legend_stem}', linewidth=2)
    curve_ax.plot(history[f'val_{metric_key}'], label=f'Val {legend_stem}', linewidth=2)
    curve_ax.set_xlabel('Epoch')
    curve_ax.set_ylabel(y_label)
    curve_ax.set_title(panel_title)
    curve_ax.legend()
    curve_ax.grid(True, alpha=0.3)
    if zero_line:
        # CCC panels get a faint reference at zero agreement
        curve_ax.axhline(y=0, color='r', linestyle='--', alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'vit_training_curves.png'), dpi=150, bbox_inches='tight')
plt.show()

# Scatter plots: Predicted vs Actual
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Column 0 holds valence, column 1 arousal; each gets its own panel using the
# predictions and CCC from the last validation pass.
for dim_idx, (dim_name, dim_ccc) in enumerate([('Valence', val_ccc_v), ('Arousal', val_ccc_a)]):
    scatter_ax = axes[dim_idx]
    scatter_ax.scatter(val_labels[:, dim_idx], val_preds[:, dim_idx], alpha=0.5, s=20)
    scatter_ax.plot([-1, 1], [-1, 1], 'r--', linewidth=2, label='Perfect Prediction')
    scatter_ax.set_xlabel(f'Actual {dim_name}')
    scatter_ax.set_ylabel(f'Predicted {dim_name}')
    scatter_ax.set_title(f'{dim_name} Prediction (CCC: {dim_ccc:.4f})')
    scatter_ax.legend()
    scatter_ax.grid(True, alpha=0.3)
    scatter_ax.set_xlim(-1.2, 1.2)
    scatter_ax.set_ylim(-1.2, 1.2)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'prediction_scatter.png'), dpi=150, bbox_inches='tight')
plt.show()

# 2D Valence-Arousal Space
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: ground-truth points; right panel: model predictions.
# Both share the same axes limits and quadrant cross-hairs.
va_panels = [
    (axes[0], val_labels, 'blue', 'Ground Truth VA Space'),
    (axes[1], val_preds, 'red', 'Predicted VA Space'),
]

for va_ax, va_points, point_color, panel_title in va_panels:
    va_ax.scatter(va_points[:, 0], va_points[:, 1], alpha=0.6, s=50, c=point_color, edgecolors='black')
    va_ax.set_xlabel('Valence')
    va_ax.set_ylabel('Arousal')
    va_ax.set_title(panel_title)
    va_ax.grid(True, alpha=0.3)
    va_ax.axhline(0, color='k', linewidth=0.5)
    va_ax.axvline(0, color='k', linewidth=0.5)
    va_ax.set_xlim(-1.2, 1.2)
    va_ax.set_ylim(-1.2, 1.2)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'va_space_comparison.png'), dpi=150, bbox_inches='tight')
plt.show()

# Final summary
# Section headers use "\n" (not "\\n") so a blank line separates the sections
# instead of a literal backslash-n being printed.
print("\n" + "="*60)
print("üìä FINAL RESULTS SUMMARY")
print("="*60)
print("\nüñºÔ∏è Model Architecture:")
print(f"  - Base Model: {VIT_MODEL_NAME}")
print(f"  - Total Parameters: {total_params:,}")
# NOTE(review): trainable_params was counted when the model was created; it does
# not reflect layers unfrozen later during training. Confirm that is intended.
print(f"  - Trainable Parameters: {trainable_params:,}")

print("\nüé® GAN Augmentation:")
print(f"  - Real samples: {len(real_spectrograms)}")
print(f"  - Synthetic samples: {len(synthetic_spectrograms)}")
print(f"  - Total samples: {len(all_spectrograms)}")
print(f"  - Augmentation factor: {len(all_spectrograms)/len(real_spectrograms):.2f}x")

# "Final" values come from the last epoch that ran, not from the best checkpoint
print("\nü§ñ ViT Model Performance:")
print(f"  - Best Val Loss: {best_val_loss:.4f}")
print(f"  - Final Val MAE: {val_mae:.4f}")
print(f"  - Final Val CCC Valence: {val_ccc_v:.4f}")
print(f"  - Final Val CCC Arousal: {val_ccc_a:.4f}")

print("\nüíæ Saved Outputs:")
print("  - Generator model: generator.pth")
print("  - Discriminator model: discriminator.pth")
print("  - Best ViT model: best_vit_model.pth")
print("  - Training curves: vit_training_curves.png")
print("  - Prediction scatter: prediction_scatter.png")
print("  - VA space comparison: va_space_comparison.png")
print("="*60)

In [None]:
# Archive outputs: recursively zip /kaggle/working/vit_augmented into
# /kaggle/working/vit_output.zip. The leading "!" is IPython shell syntax, so
# this cell only runs inside a notebook kernel, not as plain Python.
# NOTE(review): the path is hard-coded rather than derived from OUTPUT_DIR, and
# the success message below prints even if zip fails. Confirm both are acceptable.
!zip -r /kaggle/working/vit_output.zip /kaggle/working/vit_augmented
print("‚úÖ Outputs archived to vit_output.zip")

## 🧪 Comprehensive Model Testing

Perform thorough testing of the trained ViT model including edge cases and robustness evaluation.

In [None]:
def test_model_robustness(model, test_loader, device=DEVICE, max_batches=10):
    """Compare predictions on clean, noisy and shifted inputs.

    For each of the first ``max_batches`` batches the model is run on:
      1. the inputs as given,
      2. the inputs plus Gaussian noise (std 0.1), clamped to [0, 1],
      3. the inputs circularly shifted by 10 positions along the last axis.

    Batches may be dicts with ``'pixel_values'`` / ``'emotions'`` keys or
    ``(inputs, targets)`` sequences, the form the training and validation
    loops in this file unpack.

    Args:
        model: Network to evaluate; put into ``eval()`` mode here.
        test_loader: Iterable of batches in one of the two layouts above.
        device: Device the batches are moved to.
        max_batches: Number of leading batches to evaluate.

    Returns:
        Dict with keys ``normal_predictions``, ``noisy_predictions``,
        ``augmented_predictions``, ``targets`` and ``confidence_scores``.
        Each value is a NumPy array concatenated over batches, or an empty
        list when no batch was processed.
    """
    print("üß™ Testing model robustness...")

    model.eval()

    # Test results storage
    test_results = {
        'normal_predictions': [],
        'noisy_predictions': [],
        'augmented_predictions': [],
        'targets': [],
        'confidence_scores': []
    }

    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_loader, desc='Robustness Testing')):
            if i >= max_batches:  # Only the leading batches are evaluated
                break

            # Accept tuple/list batches as well as dict batches
            if isinstance(batch, (tuple, list)):
                inputs, targets = batch
            else:
                inputs, targets = batch['pixel_values'], batch['emotions']
            inputs = inputs.to(device)
            targets = targets.to(device)

            # 1. Normal prediction
            normal_output = model(inputs)

            # 2. Additive Gaussian noise.
            # NOTE(review): the clamp assumes inputs live in [0, 1]; confirm
            # against the dataset's normalisation.
            noise = torch.randn_like(inputs) * 0.1
            noisy_inputs = torch.clamp(inputs + noise, 0, 1)
            noisy_output = model(noisy_inputs)

            # 3. Fixed circular shift of 10 positions along the last axis
            augmented_inputs = torch.roll(inputs, shifts=10, dims=-1)
            augmented_output = model(augmented_inputs)

            # Heuristic "confidence": inverse of the variance across each
            # sample's two outputs, not a calibrated uncertainty estimate
            confidence = 1.0 / (torch.var(normal_output, dim=1) + 1e-6)

            # Store results
            test_results['normal_predictions'].append(normal_output.cpu())
            test_results['noisy_predictions'].append(noisy_output.cpu())
            test_results['augmented_predictions'].append(augmented_output.cpu())
            test_results['targets'].append(targets.cpu())
            test_results['confidence_scores'].append(confidence.cpu())

    # Concatenate all results (keys with no batches keep their empty list)
    for key in test_results:
        if test_results[key]:
            test_results[key] = torch.cat(test_results[key], dim=0).numpy()

    return test_results

def analyze_prediction_patterns(test_results):
    """Summarise and plot how predictions change under input perturbations.

    Prints the mean absolute difference between clean predictions and the
    noisy / shifted ones, then draws a 2x3 figure: clean-vs-noisy scatter for
    valence and arousal, the overall prediction histogram, per-dimension
    absolute-error histograms, and confidence score against mean error.

    Args:
        test_results: Dict with array values under ``normal_predictions``,
            ``noisy_predictions``, ``augmented_predictions``, ``targets``
            (each indexed as [sample, dimension] with column 0 = valence and
            column 1 = arousal) and ``confidence_scores``.

    Returns:
        Dict with ``noise_robustness``, ``augmentation_robustness`` and
        ``mean_confidence``.
    """
    # "\n" (not "\\n") so a blank line is printed instead of a literal backslash-n
    print("\nüìä Analyzing prediction patterns...")

    normal_pred = test_results['normal_predictions']
    noisy_pred = test_results['noisy_predictions']
    aug_pred = test_results['augmented_predictions']
    targets = test_results['targets']
    confidence = test_results['confidence_scores']

    # Calculate robustness metrics
    noise_robustness = np.mean(np.abs(normal_pred - noisy_pred))
    aug_robustness = np.mean(np.abs(normal_pred - aug_pred))

    print(f"üîä Noise Robustness (MAE): {noise_robustness:.4f}")
    print(f"üîÑ Augmentation Robustness (MAE): {aug_robustness:.4f}")

    # Plot robustness analysis
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Identity line spanning the observed prediction range, so it covers the
    # data whatever scale the labels use (a fixed [0, 1] line would miss
    # negative predictions)
    diag_lo = float(min(np.min(normal_pred), np.min(noisy_pred)))
    diag_hi = float(max(np.max(normal_pred), np.max(noisy_pred)))

    # Error analysis inputs
    normal_error = np.abs(normal_pred - targets)
    noisy_error = np.abs(noisy_pred - targets)

    for col, (dim_name, point_color) in enumerate((('Valence', 'blue'), ('Arousal', 'red'))):
        # Prediction consistency: clean vs. noisy prediction per sample
        consistency_ax = axes[0, col]
        consistency_ax.scatter(normal_pred[:, col], noisy_pred[:, col], alpha=0.6, color=point_color)
        consistency_ax.plot([diag_lo, diag_hi], [diag_lo, diag_hi], 'r--', lw=2)
        consistency_ax.set_xlabel(f'Normal Prediction ({dim_name})')
        consistency_ax.set_ylabel(f'Noisy Prediction ({dim_name})')
        consistency_ax.set_title(f'Noise Robustness - {dim_name}')
        consistency_ax.grid(True, alpha=0.3)

        # Absolute-error distribution with and without noise
        error_ax = axes[1, col]
        error_ax.hist(normal_error[:, col], bins=20, alpha=0.7, label='Normal', color='blue')
        error_ax.hist(noisy_error[:, col], bins=20, alpha=0.7, label='Noisy', color='orange')
        error_ax.set_xlabel('Absolute Error')
        error_ax.set_ylabel('Frequency')
        error_ax.set_title(f'{dim_name} Error Distribution')
        error_ax.legend()
        error_ax.grid(True, alpha=0.3)

    # Prediction distribution (both dimensions pooled)
    axes[0, 2].hist(normal_pred.flatten(), bins=30, alpha=0.7, label='Normal', color='blue')
    axes[0, 2].hist(noisy_pred.flatten(), bins=30, alpha=0.7, label='Noisy', color='orange')
    axes[0, 2].set_xlabel('Prediction Value')
    axes[0, 2].set_ylabel('Frequency')
    axes[0, 2].set_title('Prediction Distribution')
    axes[0, 2].legend()
    axes[0, 2].grid(True, alpha=0.3)

    # Confidence analysis: per-sample score vs. error averaged over both dimensions
    axes[1, 2].scatter(confidence, normal_error.mean(axis=1), alpha=0.6, color='green')
    axes[1, 2].set_xlabel('Confidence Score')
    axes[1, 2].set_ylabel('Prediction Error')
    axes[1, 2].set_title('Confidence vs Error')
    axes[1, 2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return {
        'noise_robustness': noise_robustness,
        'augmentation_robustness': aug_robustness,
        'mean_confidence': np.mean(confidence)
    }

def test_edge_cases(model, device=DEVICE):
    """Run the model on four synthetic single-image inputs and print the results.

    The inputs, all shaped ``(1, 3, VIT_IMAGE_SIZE, VIT_IMAGE_SIZE)``, are:
    all zeros ("silence"), all ones ("maximum"), standard-normal noise clamped
    to [0, 1] ("noise"), and a one-pixel checkerboard ("checkerboard").

    Args:
        model: Network to evaluate; put into ``eval()`` mode here.
        device: Device the synthetic inputs are moved to.

    Returns:
        Dict mapping each case name to the model output as a NumPy array.
    """
    # "\n" (not "\\n") so a blank line is printed instead of a literal backslash-n
    print("\nüö® Testing edge cases...")

    model.eval()
    input_shape = (1, 3, VIT_IMAGE_SIZE, VIT_IMAGE_SIZE)

    def _checkerboard():
        # Ones where row and column parity match, zeros elsewhere
        board = torch.zeros(*input_shape)
        board[:, :, ::2, ::2] = 1
        board[:, :, 1::2, 1::2] = 1
        return board

    # Builders run lazily, in this order, so random draws happen in the same
    # sequence as when each case was written out by hand
    case_builders = {
        'silence': lambda: torch.zeros(*input_shape),
        'maximum': lambda: torch.ones(*input_shape),
        'noise': lambda: torch.clamp(torch.randn(*input_shape), 0, 1),
        'checkerboard': _checkerboard,
    }

    edge_cases = {}
    with torch.no_grad():
        for case_name, build_input in case_builders.items():
            prediction = model(build_input().to(device))
            edge_cases[case_name] = prediction.cpu().numpy()

    print("Edge case predictions:")
    for case, pred in edge_cases.items():
        valence, arousal = pred[0]
        print(f"  {case:12}: Valence={valence:.3f}, Arousal={arousal:.3f}")

    return edge_cases

def performance_benchmark(model, test_loader, device=DEVICE):
    """Time inference on random inputs for several batch sizes.

    After five warm-up passes, each batch size in ``[1, 4, 8, 16]`` is timed
    over ten forward passes on random
    ``(batch, 3, VIT_IMAGE_SIZE, VIT_IMAGE_SIZE)`` input, and the average
    latency and throughput are printed. On CUDA the peak allocated memory is
    also printed.

    Args:
        model: Network to benchmark; put into ``eval()`` mode here.
        test_loader: Unused; kept so existing call sites keep working.
        device: ``torch.device`` the synthetic inputs are created for.

    Returns:
        Dict with ``batch_sizes`` and the matching per-pass ``inference_times``
        in seconds.
    """
    import time

    # "\n" (not "\\n") so a blank line is printed instead of a literal backslash-n
    print("\n‚ö° Performance benchmarking...")

    model.eval()
    on_cuda = device.type == 'cuda'

    if on_cuda:
        # Start peak-memory tracking from here so the figure reported below
        # reflects this benchmark rather than earlier work in the session
        torch.cuda.reset_peak_memory_stats()

    # Warmup (under no_grad so no autograd graph is built for throwaway passes)
    with torch.no_grad():
        dummy_input = torch.randn(1, 3, VIT_IMAGE_SIZE, VIT_IMAGE_SIZE).to(device)
        for _ in range(5):
            _ = model(dummy_input)

    # Timing test
    times = []
    batch_sizes = [1, 4, 8, 16]
    runs_per_size = 10

    for batch_size in batch_sizes:
        test_input = torch.randn(batch_size, 3, VIT_IMAGE_SIZE, VIT_IMAGE_SIZE).to(device)

        # CUDA kernels run asynchronously; synchronise so the clock brackets
        # exactly the timed passes
        if on_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            for _ in range(runs_per_size):
                _ = model(test_input)

        if on_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        avg_time = (end_time - start_time) / runs_per_size
        times.append(avg_time)

        print(f"  Batch size {batch_size:2d}: {avg_time:.4f}s ({batch_size/avg_time:.1f} samples/s)")

    # Memory usage
    if on_cuda:
        memory_usage = torch.cuda.max_memory_allocated() / 1024**2  # MB
        print(f"  Max GPU memory: {memory_usage:.1f} MB")

    return {'batch_sizes': batch_sizes, 'inference_times': times}

# Run comprehensive testing
# NOTE(review): this guard requires a variable named `best_vit_model`; the
# training cell binds the network to `model`, and no assignment to
# `best_vit_model` is visible in this section. If it is not defined elsewhere,
# this branch never runs. Confirm the intended variable name.
if 'best_vit_model' in locals() and 'test_loader' in locals():
    print("üöÄ Starting comprehensive ViT model testing...")

    # Load best model (best effort: fall back to the in-memory weights)
    best_model_path = '/kaggle/working/vit_augmented/best_vit_model.pth'
    if os.path.exists(best_model_path):
        try:
            # map_location keeps the load working when the checkpoint was
            # saved on a device that is not available now
            best_vit_model.load_state_dict(
                torch.load(best_model_path, map_location=DEVICE)
            )
            print("‚úÖ Best model loaded for testing")
        except Exception as e:
            # Report why the load failed instead of hiding it
            print("‚ö†Ô∏è Using current model state for testing")
            print(f"   (checkpoint could not be loaded: {e})")
    else:
        print("‚ö†Ô∏è Using current model state for testing")

    # 1. Robustness testing
    test_results = test_model_robustness(best_vit_model, test_loader)
    robustness_metrics = analyze_prediction_patterns(test_results)

    # 2. Edge case testing
    edge_results = test_edge_cases(best_vit_model)

    # 3. Performance benchmarking
    perf_results = performance_benchmark(best_vit_model, test_loader)

    # Summary report ("\n", not "\\n", so blank lines are printed)
    print("\n" + "="*60)
    print("üìã COMPREHENSIVE TESTING SUMMARY")
    print("="*60)
    print("‚úÖ Robustness Testing:")
    print(f"   - Noise Robustness: {robustness_metrics['noise_robustness']:.4f}")
    print(f"   - Augmentation Robustness: {robustness_metrics['augmentation_robustness']:.4f}")
    print(f"   - Mean Confidence: {robustness_metrics['mean_confidence']:.4f}")
    print(f"\n‚úÖ Edge Cases: All {len(edge_results)} test cases completed")
    print(f"\n‚úÖ Performance: Benchmarked across {len(perf_results['batch_sizes'])} batch sizes")
    print("\nüéâ All tests completed successfully!")
    print("="*60)

else:
    print("‚ö†Ô∏è Skipping comprehensive testing - model or test data not available")
    print("Please ensure the model is trained and test data is prepared.")