# 🎵 Vision Transformer (ViT) Training with GAN-Based Data Augmentation

## 📋 Overview
This notebook implements **Vision Transformer (ViT)** using the pre-trained `google/vit-base-patch16-224-in21k` model for music emotion recognition with **GAN-based data augmentation** to expand the DEAM dataset.

### Key Features:
- **Pre-trained ViT**: Uses `google/vit-base-patch16-224-in21k` trained on ImageNet-21k
- **Transfer Learning**: Fine-tunes large vision model on audio spectrograms
- **Conditional GAN**: Generates synthetic spectrograms conditioned on valence/arousal
- **Data Expansion**: Increases dataset size from ~1800 to 5000+ samples
- **Emotion Prediction**: Valence-Arousal (VA) continuous values

### Pipeline:
1. Load DEAM dataset and extract real spectrograms
2. Train Conditional GAN to generate synthetic spectrograms
3. Augment dataset with GAN-generated samples
4. Fine-tune pre-trained ViT model on expanded dataset
5. Evaluate on test set

## 1️⃣ Import Libraries

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import warnings
from pathlib import Path
warnings.filterwarnings('ignore')

# Audio processing
import librosa
import librosa.display

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

# Hugging Face Transformers
from transformers import ViTModel, ViTConfig
from PIL import Image

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Local import: `random` is not among the notebook's top-level imports.
    import random

    # Python's own generator was previously left unseeded, so any code that
    # draws from `random` was not reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed the random number generators once, immediately after the imports.
set_seed(42)
# Base directory for input data; the annotation CSV paths in the data-loading
# cell are built relative to it.
root = Path('/kaggle/input').resolve()
# Diagnostic only: a missing root is reported here but does not stop execution.
print(f"Root exists: {root.exists()}")
print("‚úÖ All libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")

## 2️⃣ Configuration & Hyperparameters

In [None]:
# ========================
# DATASET CONFIGURATION
# ========================
# Directory holding the '<song_id>.mp3' clips that the extraction cell reads.
AUDIO_DIR = '/kaggle/input/deam-mediaeval-dataset-emotional-analysis-in-music/DEAM_audio/MEMD_audio/'
# NOTE(review): not referenced by the cells visible in this review (the
# annotation CSVs are loaded from paths under `root`) — confirm it is still needed.
ANNOTATIONS_DIR = '/kaggle/input/deam-mediaeval-dataset-emotional-analysis-in-music/DEAM_Annotations/annotations/annotations averaged per song/song_level/'

# ========================
# AUDIO PROCESSING CONFIG
# ========================
# SAMPLE_RATE and DURATION are the defaults of extract_melspectrogram(); the
# remaining values are passed to librosa.feature.melspectrogram there.
SAMPLE_RATE = 22050          # Audio sampling rate (Hz)
DURATION = 30                # Audio clip duration (seconds)
N_MELS = 128                 # Number of mel-frequency bins
HOP_LENGTH = 512             # Hop length for STFT
N_FFT = 2048                 # FFT window size
FMIN = 20                    # Minimum frequency
FMAX = 8000                  # Maximum frequency

# ========================
# VIT PREPROCESSING CONFIG
# ========================
# Consumed by the ViT preprocessing code, which lies outside the visible cells.
VIT_IMAGE_SIZE = 224         # ViT expects 224x224 images
VIT_CHANNELS = 3             # RGB channels (we'll triplicate grayscale)
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # ImageNet normalization mean
IMAGENET_STD = [0.229, 0.224, 0.225]   # ImageNet normalization std

# ========================
# GAN CONFIGURATION
# ========================
# LATENT_DIM and CONDITION_DIM are the constructor defaults of the generator
# and discriminator classes defined further down.
LATENT_DIM = 100             # Dimension of GAN noise vector
CONDITION_DIM = 2            # Valence + Arousal
GAN_LR = 0.0002              # GAN learning rate
GAN_BETA1 = 0.5              # Adam beta1 for GAN
GAN_BETA2 = 0.999            # Adam beta2 for GAN
GAN_EPOCHS = 10              # GAN pre-training epochs
GAN_BATCH_SIZE = 24          # GAN batch size (reduced from 32 to save memory)
NUM_SYNTHETIC = 3200         # Number of synthetic samples to generate

# ========================
# VIT MODEL CONFIGURATION
# ========================
# Exactly one VIT_MODEL_NAME assignment should be active; the alternatives
# below are kept commented out as documented fallbacks.
# OPTION 1: Use pre-downloaded model (recommended to avoid download issues)
VIT_MODEL_NAME = '/kaggle/input/vit-model-kaggle/vit-model-for-kaggle'  # Update with your dataset path

# OPTION 2: Fallback to online download (may fail with 500 errors)
# VIT_MODEL_NAME = 'google/vit-base-patch16-224-in21k'

# OPTION 3: Use smaller, more stable model
# VIT_MODEL_NAME = 'google/vit-base-patch16-224'

FREEZE_BACKBONE = False      # Whether to freeze ViT encoder layers
DROPOUT = 0.1                # Dropout rate

# ========================
# TRAINING CONFIGURATION
# ========================
# Hyperparameters for the ViT fine-tuning stage (not the GAN pre-training).
BATCH_SIZE = 12              # Training batch size (reduced from 16 to save memory)
NUM_EPOCHS = 24              # Training epochs
LEARNING_RATE = 1e-4         # Learning rate for fine-tuning
WEIGHT_DECAY = 0.05          # AdamW weight decay
TRAIN_SPLIT = 0.8            # Train/validation split ratio

# ========================
# MEMORY OPTIMIZATION
# ========================
# If you still encounter OOM errors, try these:
# - Reduce GAN_BATCH_SIZE to 16
# - Reduce BATCH_SIZE to 8
# - Reduce NUM_SYNTHETIC to 2000
# - Set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# ========================
# SYSTEM CONFIGURATION
# ========================
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Figures produced by later cells are saved into this directory.
OUTPUT_DIR = '/kaggle/working/vit_augmented'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Enable memory efficient settings
if torch.cuda.is_available():
    # Enable TF32 for faster computation on Ampere GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Enable cudnn benchmarking for optimal performance
    # NOTE(review): benchmark mode may select different kernels between runs,
    # which can work against the fixed seed set earlier — confirm the trade-off.
    torch.backends.cudnn.benchmark = True
    print(f"üöÄ CUDA optimizations enabled")

# Echo the effective configuration so each run's log records what was used.
print("=" * 60)
print("üìä CONFIGURATION SUMMARY")
print("=" * 60)
print(f"Device: {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; convert to GiB for display.
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
print(f"Audio Duration: {DURATION}s @ {SAMPLE_RATE}Hz")
print(f"Mel-Spectrogram: {N_MELS} bins")
print(f"\nüñºÔ∏è ViT Configuration:")
print(f"  - Model Path: {VIT_MODEL_NAME}")
print(f"  - Input Size: {VIT_IMAGE_SIZE}x{VIT_IMAGE_SIZE}x{VIT_CHANNELS}")
print(f"  - Freeze Backbone: {FREEZE_BACKBONE}")
print(f"\nüé® GAN Configuration:")
print(f"  - Latent Dim: {LATENT_DIM}")
print(f"  - GAN Epochs: {GAN_EPOCHS}")
print(f"  - GAN Batch Size: {GAN_BATCH_SIZE}")
print(f"  - Synthetic Samples: {NUM_SYNTHETIC}")
print(f"\nüèãÔ∏è Training Configuration:")
print(f"  - Epochs: {NUM_EPOCHS}")
print(f"  - Batch Size: {BATCH_SIZE}")
print(f"  - Learning Rate: {LEARNING_RATE}")
print("=" * 60)

## 3️⃣ Load DEAM Dataset & Extract Real Spectrograms

In [None]:
# Load the two song-level annotation tables.
static_2000 = root / 'static-annotations-1-2000' / 'static_annotations_averaged_songs_1_2000.csv'
static_2058 = root / 'static-annots-2058' / 'static_annots_2058.csv'

try:
    df1 = pd.read_csv(static_2000)
    df2 = pd.read_csv(static_2058)
except Exception as e:
    print(f"‚ùå Error loading annotations: {e}")
    raise

# Clean column names BEFORE concatenating: if the two files differ only in
# header whitespace, concatenating first would create separate, half-empty
# columns that stripping afterwards turns into duplicate names.
df1.columns = df1.columns.str.strip()
df2.columns = df2.columns.str.strip()
df_annotations = pd.concat([df1, df2], axis=0, ignore_index=True)
print(f"‚úÖ Loaded annotations: {len(df_annotations)} songs")

# "\n" (not "\\n") so the separator is an actual blank line in the output.
print("\nüìä Annotation Sample:")
print(df_annotations.head())
print(f"\nColumns: {list(df_annotations.columns)}")

# Check for audio files
audio_files = glob.glob(os.path.join(AUDIO_DIR, '*.mp3'))
print(f"\nüéµ Found {len(audio_files)} audio files")

# Extract spectrograms with error logging
print("\nüîä Extracting spectrograms from real audio...")

# Collects human-readable messages for every clip that had to be skipped.
error_log = []

def extract_melspectrogram(audio_path, sr=SAMPLE_RATE, duration=DURATION):
    """Load an audio clip and return a fixed-size log-mel spectrogram in [-1, 1].

    Args:
        audio_path: Path of the audio file to load.
        sr: Sampling rate the audio is resampled to.
        duration: Number of seconds to load. When not ``None`` the signal is
            zero-padded (or trimmed) to exactly ``sr * duration`` samples so
            every clip yields the same number of frames.

    Returns:
        ``(spectrogram, None)`` on success, where ``spectrogram`` is the 2-D
        array produced by ``librosa.feature.melspectrogram`` after dB
        conversion and scaling, or ``(None, error_message)`` if anything
        raised while loading or processing the file.
    """
    try:
        # Load audio
        y, _ = librosa.load(audio_path, sr=sr, duration=duration)

        # Clips shorter than `duration` would otherwise produce fewer frames,
        # and spectrograms of different widths cannot be stacked into one array.
        if duration is not None:
            y = librosa.util.fix_length(y, size=int(sr * duration))

        # Compute mel-spectrogram
        mel_spec = librosa.feature.melspectrogram(
            y=y, sr=sr, n_mels=N_MELS, n_fft=N_FFT,
            hop_length=HOP_LENGTH, fmin=FMIN, fmax=FMAX
        )

        # Convert to log scale (dB)
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        # Min-max scale to [-1, 1]. The previous z-score standardisation was
        # unbounded, so real spectrograms never shared the value range of the
        # GAN generator's Tanh output. The epsilon guards constant (silent) clips.
        db_min = mel_spec_db.min()
        db_span = mel_spec_db.max() - db_min
        mel_spec_norm = 2.0 * (mel_spec_db - db_min) / (db_span + 1e-8) - 1.0

        return mel_spec_norm, None
    except Exception as e:
        # Best-effort extraction: report the failure to the caller instead of
        # aborting the whole dataset pass.
        error_msg = f"Error processing {os.path.basename(audio_path)}: {str(e)}"
        return None, error_msg

# Extract spectrograms and labels
real_spectrograms = []
real_labels = []

for _, row in tqdm(df_annotations.iterrows(), total=len(df_annotations), desc="Extracting spectrograms"):
    song_id = str(int(row['song_id']))
    audio_path = os.path.join(AUDIO_DIR, f"{song_id}.mp3")

    if not os.path.exists(audio_path):
        error_log.append(f"Missing audio file: {song_id}.mp3")
        continue

    # Resolve the labels before the (expensive) audio decoding. A missing or
    # NaN rating is skipped and logged; the previous fallback of 0.5 became
    # -1.125 after scaling, i.e. an out-of-range label silently fed to training.
    valence = row.get('valence_mean', row.get('valence'))
    arousal = row.get('arousal_mean', row.get('arousal'))
    if pd.isna(valence) or pd.isna(arousal):
        error_log.append(f"Missing valence/arousal annotation for song {song_id}")
        continue

    # Extract spectrogram
    spec, error = extract_melspectrogram(audio_path)
    if error is not None:
        error_log.append(error)
        continue

    # All spectrograms must share one shape to be stacked into a single array.
    if real_spectrograms and spec.shape != real_spectrograms[0].shape:
        error_log.append(
            f"Unexpected spectrogram shape {spec.shape} for {song_id}.mp3 "
            f"(expected {real_spectrograms[0].shape})"
        )
        continue

    real_spectrograms.append(spec)
    # (x - 5) / 4 maps ratings in [1, 9] onto [-1, 1]
    # (assumes the annotations use a 1-9 scale — TODO confirm).
    real_labels.append([(valence - 5.0) / 4.0, (arousal - 5.0) / 4.0])

# Convert to numpy arrays
real_spectrograms = np.array(real_spectrograms)
real_labels = np.array(real_labels)

print(f"\n‚úÖ Extracted {len(real_spectrograms)} spectrograms")

if error_log:
    print(f"\n‚ö†Ô∏è {len(error_log)} errors occurred during extraction:")
    for i, error in enumerate(error_log[:10]):  # Show first 10 errors
        print(f"  {i+1}. {error}")
    if len(error_log) > 10:
        print(f"  ... and {len(error_log) - 10} more errors")

# Fail with an actionable message instead of the opaque "zero-size array"
# error that min()/max() would raise on an empty dataset.
if len(real_spectrograms) == 0:
    raise RuntimeError(
        "No spectrograms were extracted; check AUDIO_DIR and the annotation files."
    )

print(f"Spectrogram shape: {real_spectrograms.shape}")
print(f"Labels shape: {real_labels.shape}")
print(f"Spectrogram range: [{real_spectrograms.min():.2f}, {real_spectrograms.max():.2f}]")
print(f"Labels range: [{real_labels.min():.2f}, {real_labels.max():.2f}]")

### 📊 Load Annotations and Extract Mel-Spectrograms

Load the DEAM dataset annotations (valence/arousal ratings) and extract mel-spectrograms from audio files.

In [None]:
# Visualize one sample spectrogram next to the label distribution.
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Left panel: first extracted spectrogram, low frequencies at the bottom.
axes[0].imshow(real_spectrograms[0], aspect='auto', origin='lower', cmap='viridis')
# "\n" (not "\\n") so the title wraps onto a second line instead of showing
# a literal backslash-n.
axes[0].set_title(f'Sample Spectrogram\nValence: {real_labels[0][0]:.2f}, Arousal: {real_labels[0][1]:.2f}')
axes[0].set_xlabel('Time Frames')
axes[0].set_ylabel('Mel Frequency Bins')

# Right panel: normalized valence/arousal scatter with the axes through the origin.
axes[1].scatter(real_labels[:, 0], real_labels[:, 1], alpha=0.5)
axes[1].set_xlabel('Valence (normalized)')
axes[1].set_ylabel('Arousal (normalized)')
axes[1].set_title('Valence-Arousal Distribution (Real Data)')
axes[1].grid(True, alpha=0.3)
axes[1].axhline(0, color='k', linewidth=0.5)
axes[1].axvline(0, color='k', linewidth=0.5)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'real_data_visualization.png'), dpi=150, bbox_inches='tight')
plt.show()

### 📈 Visualize Real Data Distribution

Display sample spectrograms and the distribution of valence-arousal values in the real DEAM dataset.

## 4️⃣ Conditional GAN Architecture

In [None]:
class ChannelAttention(nn.Module):
    """Channel attention gate built from globally pooled statistics.

    Each channel is summarised by its global average and global maximum; both
    summaries pass through one shared bottleneck MLP, the two results are
    added, and a sigmoid turns the sum into a per-channel gate in (0, 1) that
    rescales the input. Only per-channel statistics are kept, so no
    (H*W) x (H*W) spatial attention map is ever materialised.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Bottleneck reduction factor of the shared MLP.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Keep at least one hidden unit: with channels < reduction the
        # bottleneck would be empty and the gate would ignore its input.
        hidden = max(1, channels // reduction)

        # Shared MLP. The sigmoid is deliberately NOT part of it: applying it
        # to each branch and summing afterwards yields gates in (0, 2), which
        # can amplify channels instead of gating them.
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
        )
        self.gate = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` rescaled per channel; expects a 4-D ``(B, C, H, W)`` tensor."""
        batch_size, channels, _, _ = x.size()

        # Channel descriptors via global pooling, both through the shared MLP
        avg_out = self.fc(self.avg_pool(x).view(batch_size, channels))
        max_out = self.fc(self.max_pool(x).view(batch_size, channels))

        # Combine the branches first, then squash once
        attention = self.gate(avg_out + max_out).view(batch_size, channels, 1, 1)
        return x * attention


class ImprovedSpectrogramGenerator(nn.Module):
    """Conditional generator mapping (noise, condition) to a 1-channel spectrogram.

    The condition vector is embedded to 256 features and concatenated with the
    noise vector. A linear layer projects the result to a 256 x 16 x 20 feature
    map, which three 2x transposed-convolution stages (with channel attention
    after the second) and a final 8x time-axis stage turn into a Tanh-bounded
    image. The image is bilinearly resized to ``(n_mels, time_steps)`` whenever
    the stack does not already produce that size.
    """

    # (channels, height, width) of the feature map the linear projection yields.
    _SEED_SHAPE = (256, 16, 20)

    @staticmethod
    def _upsample_block(in_channels, out_channels):
        """Build one 2x upsampling stage: ConvTranspose2d -> BatchNorm2d -> LeakyReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def __init__(self, latent_dim=LATENT_DIM, condition_dim=CONDITION_DIM,
                 n_mels=N_MELS, time_steps=1292):
        super().__init__()

        self.latent_dim = latent_dim
        self.condition_dim = condition_dim
        self.n_mels = n_mels
        self.time_steps = time_steps

        seed_channels, seed_height, seed_width = self._SEED_SHAPE
        seed_features = seed_channels * seed_height * seed_width

        # Condition (valence/arousal) -> 256-d embedding
        self.condition_embed = nn.Sequential(
            nn.Linear(condition_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.LeakyReLU(0.2),
        )

        # [noise | condition embedding] -> flattened seed feature map
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + 256, seed_features),
            nn.BatchNorm1d(seed_features),
            nn.LeakyReLU(0.2),
        )

        # Submodules are created in the order they are applied in forward().
        self.conv1 = self._upsample_block(256, 128)
        self.conv2 = self._upsample_block(128, 64)
        self.attention = ChannelAttention(64, reduction=8)
        self.conv3 = self._upsample_block(64, 32)

        # Final stage stretches only the time axis (x8) and bounds values with Tanh.
        self.conv4 = nn.Sequential(
            nn.ConvTranspose2d(32, 1, kernel_size=(1, 8), stride=(1, 8), padding=0),
            nn.Tanh(),
        )

    def forward(self, z, c):
        """Generate spectrograms for noise batch ``z`` and condition batch ``c``."""
        latent = torch.cat([z, self.condition_embed(c)], dim=1)
        feat = self.fc(latent).view(-1, *self._SEED_SHAPE)

        for stage in (self.conv1, self.conv2, self.attention, self.conv3, self.conv4):
            feat = stage(feat)

        # Resize only if the convolutional stack missed the requested size.
        target_size = (self.n_mels, self.time_steps)
        if tuple(feat.shape[-2:]) != target_size:
            feat = F.interpolate(feat, size=target_size, mode='bilinear', align_corners=False)

        return feat


class ImprovedSpectrogramDiscriminator(nn.Module):
    """Conditional discriminator scoring (spectrogram, condition) pairs.

    Four spectrally-normalised stride-2 convolutions reduce the spectrogram to
    a feature map, which is flattened, concatenated with a 64-d embedding of
    the condition and mapped to a single logit (no sigmoid; intended for
    ``BCEWithLogitsLoss``).

    Args:
        condition_dim: Size of the condition vector.
        n_mels: Height of the input spectrograms.
        time_steps: Width of the input spectrograms.

    Raises:
        ValueError: If the input is too small to survive the four convolutions.
    """

    @staticmethod
    def _conv_out(size, num_layers=4, kernel=4, stride=2, padding=1):
        """Return the length left after ``num_layers`` identical strided convolutions."""
        for _ in range(num_layers):
            size = (size + 2 * padding - kernel) // stride + 1
        return size

    def __init__(self, condition_dim=CONDITION_DIM, n_mels=N_MELS, time_steps=1292):
        super().__init__()

        self.n_mels = n_mels
        self.time_steps = time_steps

        # Simplified condition embedding (reduce memory)
        self.condition_embed = nn.Sequential(
            nn.Linear(condition_dim, 64),
            nn.LeakyReLU(0.2)
        )

        # Convolutional layers with spectral normalization for stability
        self.conv_layers = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),

            nn.utils.spectral_norm(nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.utils.spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.utils.spectral_norm(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Derive the flattened feature size from the requested input size.
        # It used to be hard-coded as 256 * 8 * 80, which is only correct for
        # 128 x 1292 inputs and broke the first linear layer for any other
        # `n_mels` / `time_steps`.
        out_height = self._conv_out(int(n_mels))
        out_width = self._conv_out(int(time_steps))
        if out_height < 1 or out_width < 1:
            raise ValueError(
                f"Input of size {n_mels}x{time_steps} is too small for the discriminator"
            )
        conv_output_size = 256 * out_height * out_width

        # Fully connected layers with dropout and condition. nn.Flatten is a
        # no-op on the already 2-D input and is kept only so the layer indices
        # (and therefore state-dict keys) stay unchanged.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(conv_output_size + 64, 256),  # Concatenate with condition
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
            # No sigmoid - will use BCEWithLogitsLoss for better stability
        )

    def forward(self, spec, c):
        """Return one real/fake logit per sample, shape ``(batch, 1)``."""
        c_embed = self.condition_embed(c)

        # Convolutional features, flattened per sample
        features = self.conv_layers(spec).flatten(1)

        # Concatenate with condition and classify
        return self.fc(torch.cat([features, c_embed], dim=1))


# Initialize improved GAN models
# Width of the extracted spectrograms (axis 2 of the stacked array); both
# networks are built for this size.
time_steps = real_spectrograms.shape[2]
generator = ImprovedSpectrogramGenerator(
    latent_dim=LATENT_DIM, 
    condition_dim=CONDITION_DIM, 
    n_mels=N_MELS, 
    time_steps=time_steps
).to(DEVICE)

discriminator = ImprovedSpectrogramDiscriminator(
    condition_dim=CONDITION_DIM, 
    n_mels=N_MELS, 
    time_steps=time_steps
).to(DEVICE)

# Summary banner with the parameter count of each network.
print("=" * 60)
print("üé® IMPROVED GAN ARCHITECTURE (Memory-Efficient)")
print("=" * 60)
print(f"‚ú® Generator Features:")
print(f"   - Channel attention (memory-efficient)")
print(f"   - Enhanced condition embedding")
print(f"   - Parameters: {sum(p.numel() for p in generator.parameters()):,}")
print(f"\n‚ú® Discriminator Features:")
print(f"   - Spectral normalization for stability")
print(f"   - Compact condition embedding")
print(f"   - Parameters: {sum(p.numel() for p in discriminator.parameters()):,}")
print("=" * 60)

# Clear cache after model initialization
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    # memory_allocated() is in bytes; shown in GiB.
    print(f"üíæ GPU memory after models: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

### 🎨 Define GAN Generator with Channel Attention

Improved conditional GAN generator that creates synthetic spectrograms based on valence/arousal conditions.

## 5️⃣ Train Conditional GAN

In [None]:
print("\n" + "=" * 60)
print("üéÆ BALANCED GAN TRAINING CONFIGURATION")
print("=" * 60)

# GAN Training Hyperparameters
# Read the values from the configuration cell instead of repeating them as
# literals, so editing GAN_LR / GAN_BETA1 / GAN_BETA2 actually takes effect.
# The discriminator runs at half the generator's learning rate.
GAN_D_LR = GAN_LR * 0.5
g_optimizer = optim.Adam(generator.parameters(), lr=GAN_LR, betas=(GAN_BETA1, GAN_BETA2))
d_optimizer = optim.Adam(discriminator.parameters(), lr=GAN_D_LR, betas=(GAN_BETA1, GAN_BETA2))
criterion = nn.BCEWithLogitsLoss()

# Adaptive training parameters
D_STEPS_THRESHOLD = 0.8  # Accuracy threshold used to pick the number of D and G passes
G_STEPS_THRESHOLD = 0.2  # NOTE(review): not read by the visible training loop — confirm it is needed

print(f"‚öôÔ∏è  Optimizer Configuration:")
print(f"   Generator LR: {GAN_LR} (Adam, Œ≤1={GAN_BETA1}, Œ≤2={GAN_BETA2})")
print(f"   Discriminator LR: {GAN_D_LR} (Adam, Œ≤1={GAN_BETA1}, Œ≤2={GAN_BETA2})")
print(f"   Loss Function: BCEWithLogitsLoss")
print(f"\nüéØ Adaptive Training:")
# The messages now state the rule the training loop actually applies.
print(f"   D steps: 1 if real and fake D_acc > {D_STEPS_THRESHOLD:.0%}, otherwise 2")
print(f"   G steps: 3 if real D_acc > {D_STEPS_THRESHOLD:.0%}, otherwise 1")
print(f"   Gradient clipping: max_norm = 1.0")
print(f"\nüíæ Memory Optimization:")
print(f"   Gradient accumulation: 2 steps (effective batch = {GAN_BATCH_SIZE * 2})")
# Pinned host memory only helps (and is only available) with a CUDA device.
use_pinned_memory = torch.cuda.is_available()
print(f"   Pin memory: {'enabled' if use_pinned_memory else 'disabled'}")
print(f"   Periodic cache clearing: every 5 batches")
print(f"   Data on CPU, batch transfer to GPU")
print("=" * 60)

# Extract conditions from real labels (valence, arousal)
real_conditions = real_labels.copy()  # Shape: (N, 2) - valence and arousal
print(f"\nüìä Data prepared for GAN training:")
print(f"   Real spectrograms: {real_spectrograms.shape}")
print(f"   Real conditions: {real_conditions.shape}")
print(f"   Condition range: [{real_conditions.min():.2f}, {real_conditions.max():.2f}]")

# Create memory-efficient DataLoader (data on CPU, transfer batches to GPU).
# torch.tensor copies into float32 CPU tensors; unsqueeze adds the channel axis.
real_specs_tensor = torch.tensor(real_spectrograms, dtype=torch.float32).unsqueeze(1)
real_conditions_tensor = torch.tensor(real_conditions, dtype=torch.float32)

gan_dataset = torch.utils.data.TensorDataset(real_specs_tensor, real_conditions_tensor)
gan_loader = torch.utils.data.DataLoader(
    gan_dataset,
    batch_size=GAN_BATCH_SIZE,
    shuffle=True,
    pin_memory=use_pinned_memory,  # Fast CPU->GPU transfer
    num_workers=0  # Avoid multiprocessing overhead
)

# Clear cache before training
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"üßπ Cleared GPU cache")
    print(f"üíæ Initial GPU memory: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

print(f"\nüöÄ Starting Balanced GAN Training...\n")

# Training loop
gan_losses = {'g_loss': [], 'd_loss': [], 'd_real_acc': [], 'd_fake_acc': []}
GRADIENT_ACCUMULATION_STEPS = 2  # Effective batch size = GAN_BATCH_SIZE * 2
num_batches = len(gan_loader)

for epoch in range(GAN_EPOCHS):
    epoch_g_loss = 0.0
    epoch_d_loss = 0.0
    d_real_correct = 0.0
    d_fake_correct = 0.0
    total_samples = 0

    # Reset gradient accumulation
    g_optimizer.zero_grad()
    d_optimizer.zero_grad()

    for i, (real_specs, conditions) in enumerate(tqdm(gan_loader, desc=f"Epoch {epoch+1}/{GAN_EPOCHS}")):
        # Move batch to GPU (lazy loading)
        real_specs = real_specs.to(DEVICE)
        conditions = conditions.to(DEVICE)
        batch_size = real_specs.size(0)

        # Discriminator targets (not to be confused with the emotion labels)
        d_real_labels = torch.ones(batch_size, 1, device=DEVICE)
        d_fake_labels = torch.zeros(batch_size, 1, device=DEVICE)

        # Apply accumulated gradients on every GRADIENT_ACCUMULATION_STEPS-th
        # batch AND on the last batch of the epoch; otherwise the gradients of
        # a trailing partial window are thrown away by the zero_grad() at the
        # start of the next epoch.
        apply_update = (
            (i + 1) % GRADIENT_ACCUMULATION_STEPS == 0 or (i + 1) == num_batches
        )

        # ========== Train Discriminator ==========
        # Probe the current discriminator accuracy to choose the step counts
        with torch.no_grad():
            z_temp = torch.randn(batch_size, LATENT_DIM, device=DEVICE)
            fake_specs_temp = generator(z_temp, conditions)
            d_real_out = discriminator(real_specs, conditions)
            d_fake_out = discriminator(fake_specs_temp, conditions)

            d_real_acc = ((torch.sigmoid(d_real_out) > 0.5).float().mean()).item()
            d_fake_acc = ((torch.sigmoid(d_fake_out) < 0.5).float().mean()).item()

        # Adaptive discriminator steps
        d_steps = 1 if d_real_acc > D_STEPS_THRESHOLD and d_fake_acc > D_STEPS_THRESHOLD else 2

        # NOTE(review): on non-update batches every inner pass adds to the
        # accumulated gradient, so a batch's weight grows with its step count;
        # this scheduling is kept exactly as before.
        batch_d_loss = 0.0
        for _ in range(d_steps):
            # Real spectrograms
            real_output = discriminator(real_specs, conditions)
            d_real_loss = criterion(real_output, d_real_labels)

            # Fake spectrograms: no generator graph is needed for the D update
            z = torch.randn(batch_size, LATENT_DIM, device=DEVICE)
            with torch.no_grad():
                fake_specs = generator(z, conditions)
            fake_output = discriminator(fake_specs, conditions)
            d_fake_loss = criterion(fake_output, d_fake_labels)

            # Total discriminator loss (scaled for gradient accumulation)
            d_loss = (d_real_loss + d_fake_loss) / (2 * GRADIENT_ACCUMULATION_STEPS)
            d_loss.backward()
            batch_d_loss += d_loss.item() * GRADIENT_ACCUMULATION_STEPS

            if apply_update:
                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)
                d_optimizer.step()
                d_optimizer.zero_grad()

        # ========== Train Generator ==========
        # Adaptive generator steps
        g_steps = 3 if d_real_acc > D_STEPS_THRESHOLD else 1

        # Freeze the discriminator while the generator loss is backpropagated.
        # Without this, g_loss.backward() deposits gradients on the
        # discriminator's parameters; they were never cleared and got applied
        # by the next d_optimizer.step(), pushing the discriminator towards
        # the generator's objective.
        discriminator.requires_grad_(False)
        batch_g_loss = 0.0
        try:
            for _ in range(g_steps):
                z = torch.randn(batch_size, LATENT_DIM, device=DEVICE)
                fake_specs = generator(z, conditions)
                fake_output = discriminator(fake_specs, conditions)

                # Generator loss (scaled for gradient accumulation)
                g_loss = criterion(fake_output, d_real_labels) / GRADIENT_ACCUMULATION_STEPS
                g_loss.backward()
                batch_g_loss += g_loss.item() * GRADIENT_ACCUMULATION_STEPS

                if apply_update:
                    torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
                    g_optimizer.step()
                    g_optimizer.zero_grad()
        finally:
            # Always unfreeze, even if a generator step raised.
            discriminator.requires_grad_(True)

        # Track the mean unscaled loss per pass, so the logged values do not
        # jump merely because the adaptive step counts changed.
        epoch_g_loss += batch_g_loss / g_steps
        epoch_d_loss += batch_d_loss / d_steps
        d_real_correct += d_real_acc * batch_size
        d_fake_correct += d_fake_acc * batch_size
        total_samples += batch_size

        # Periodic cache clearing (every 5 batches)
        if torch.cuda.is_available() and i % 5 == 0 and i > 0:
            torch.cuda.empty_cache()

    # Epoch statistics
    avg_g_loss = epoch_g_loss / num_batches
    avg_d_loss = epoch_d_loss / num_batches
    avg_d_real_acc = d_real_correct / total_samples
    avg_d_fake_acc = d_fake_correct / total_samples

    gan_losses['g_loss'].append(avg_g_loss)
    gan_losses['d_loss'].append(avg_d_loss)
    gan_losses['d_real_acc'].append(avg_d_real_acc)
    gan_losses['d_fake_acc'].append(avg_d_fake_acc)

    print(f"Epoch [{epoch+1}/{GAN_EPOCHS}]")
    print(f"  G Loss: {avg_g_loss:.4f} | D Loss: {avg_d_loss:.4f}")
    print(f"  D Real Acc: {avg_d_real_acc:.2%} | D Fake Acc: {avg_d_fake_acc:.2%}")

    # GPU memory monitoring
    if torch.cuda.is_available():
        print(f"  üíæ GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
        if (epoch + 1) % 3 == 0:  # Deep clean every 3 epochs
            torch.cuda.empty_cache()
    print()

print("\n‚úÖ GAN Training Complete!")
print(f"Final Generator Loss: {gan_losses['g_loss'][-1]:.4f}")
print(f"Final Discriminator Loss: {gan_losses['d_loss'][-1]:.4f}")
print(f"Final D Real Accuracy: {gan_losses['d_real_acc'][-1]:.2%}")
print(f"Final D Fake Accuracy: {gan_losses['d_fake_acc'][-1]:.2%}")

### 🏋️ GAN Training Loop

Train the conditional GAN to generate realistic spectrograms conditioned on valence/arousal values.

## 5.5️⃣ GAN Quality Metrics (Functions)

In [None]:
from scipy import linalg

def calculate_statistics(spectrograms):
    """Return the feature-wise mean and covariance of a batch of spectrograms.

    Every spectrogram is flattened into one feature vector; the first axis of
    ``spectrograms`` indexes the samples.

    Returns:
        Tuple ``(mu, sigma)``: the mean vector over samples and the covariance
        matrix between the flattened features.
    """
    n_samples = spectrograms.shape[0]
    flattened = spectrograms.reshape(n_samples, -1)

    mean_vector = flattened.mean(axis=0)
    # rowvar=False: rows are observations, columns are variables.
    covariance = np.cov(flattened, rowvar=False)
    return mean_vector, covariance


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Compute the Frechet distance between two Gaussians (as used by FID).

    ``d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))``.
    Lower is better: the generated data is closer to the real data.

    Args:
        mu1, mu2: Mean vectors of the two distributions.
        sigma1, sigma2: Covariance matrices of the two distributions.
        eps: Diagonal offset added to both covariances when the matrix square
            root of their product is not finite.

    Returns:
        The Frechet distance as a scalar.

    Raises:
        ValueError: If the means or the covariances differ in shape.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    if mu1.shape != mu2.shape:
        raise ValueError("Mean vectors have different shapes")
    if sigma1.shape != sigma2.shape:
        raise ValueError("Covariance matrices have different shapes")

    # Calculate mean difference
    diff = mu1 - mu2

    # Called without `disp`, so the return value is always the matrix itself.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))

    # The product might be (almost) singular; retry with a small diagonal
    # offset — this is what `eps` is for and it was previously never used.
    if not np.isfinite(covmean).all():
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)


def _frechet_distance_from_samples(real, fake):
    """Frechet distance between Gaussians fitted to two sample sets.

    Mathematically the same quantity as fitting mean/covariance to the
    flattened samples and applying the FID formula, but computed in sample
    space: with the centred, scaled data matrices ``A`` (n1 x D) and ``B``
    (n2 x D), ``Tr(S1) = ||A||_F^2`` and ``Tr(sqrt(S1 @ S2))`` equals the sum of
    singular values of ``A @ B.T``. The D x D covariances are never formed.
    With a single sample the centred matrix is all zeros, so that set's
    covariance contributes nothing.
    """
    # np.array always copies, so the in-place operations below cannot modify
    # the caller's data.
    a = np.array(real, dtype=np.float64).reshape(len(real), -1)
    b = np.array(fake, dtype=np.float64).reshape(len(fake), -1)

    mu_a = a.mean(axis=0)
    mu_b = b.mean(axis=0)
    diff = mu_a - mu_b

    # Centre and scale so that a.T @ a is the unbiased covariance (n - 1
    # normalisation, matching np.cov).
    a -= mu_a
    b -= mu_b
    a /= np.sqrt(max(len(a) - 1, 1))
    b /= np.sqrt(max(len(b) - 1, 1))

    trace_a = float(np.vdot(a, a))
    trace_b = float(np.vdot(b, b))
    trace_sqrt = float(np.linalg.svd(a @ b.T, compute_uv=False).sum())

    return float(diff @ diff) + trace_a + trace_b - 2.0 * trace_sqrt


def evaluate_spectrogram_quality(real_specs, fake_specs, n_samples=500):
    """Compare real and generated spectrograms with several summary metrics.

    Args:
        real_specs: Real spectrograms; the code indexes them as
            ``(N, frequency, time)``.
        fake_specs: Generated spectrograms with the same layout.
        n_samples: Maximum number of leading samples taken from each set.

    Returns:
        Dict with the keys ``frechet_distance``, ``mean_diff``, ``std_diff``,
        ``real_smoothness``, ``fake_smoothness``, ``smoothness_ratio``,
        ``frequency_correlation``, ``real_range``, ``fake_range`` and
        ``range_diff``.
    """
    print("üìä Evaluating GAN Generation Quality...\n")

    # Subsample for efficiency
    n_samples = min(n_samples, len(real_specs), len(fake_specs))
    real_sample = real_specs[:n_samples]
    fake_sample = fake_specs[:n_samples]

    metrics = {}

    # 1. Frechet Distance (FID-style)
    # Computed in sample space: a covariance over the flattened spectrograms
    # has (frequency * time)^2 entries, which is far too large to allocate for
    # full-size spectrograms.
    print("  üî¢ Computing Frechet Distance...")
    fd = _frechet_distance_from_samples(real_sample, fake_sample)
    metrics['frechet_distance'] = fd
    print(f"     Frechet Distance: {fd:.4f} (lower is better)")

    # 2. Statistical moments comparison
    print("\n  üìà Computing statistical moments...")
    real_mean = np.mean(real_sample)
    fake_mean = np.mean(fake_sample)
    real_std = np.std(real_sample)
    fake_std = np.std(fake_sample)

    metrics['mean_diff'] = abs(real_mean - fake_mean)
    metrics['std_diff'] = abs(real_std - fake_std)

    print(f"     Mean - Real: {real_mean:.4f}, Fake: {fake_mean:.4f}, Diff: {metrics['mean_diff']:.4f}")
    print(f"     Std  - Real: {real_std:.4f}, Fake: {fake_std:.4f}, Diff: {metrics['std_diff']:.4f}")

    # 3. Spectrogram smoothness: mean absolute difference between adjacent
    # time frames (axis 1 of each individual spectrogram)
    print("\n  üé® Evaluating smoothness (temporal consistency)...")
    real_smoothness = np.mean([np.mean(np.abs(np.diff(spec, axis=1))) for spec in real_sample])
    fake_smoothness = np.mean([np.mean(np.abs(np.diff(spec, axis=1))) for spec in fake_sample])

    metrics['real_smoothness'] = real_smoothness
    metrics['fake_smoothness'] = fake_smoothness
    # The epsilon avoids a division by zero for perfectly flat real data.
    metrics['smoothness_ratio'] = fake_smoothness / (real_smoothness + 1e-8)

    print(f"     Real smoothness: {real_smoothness:.4f}")
    print(f"     Fake smoothness: {fake_smoothness:.4f}")
    print(f"     Ratio: {metrics['smoothness_ratio']:.4f} (closer to 1.0 is better)")

    # 4. Frequency distribution analysis
    print("\n  üéµ Analyzing frequency content...")
    real_freq_mean = np.mean(real_sample, axis=(0, 2))  # Average across batch and time
    fake_freq_mean = np.mean(fake_sample, axis=(0, 2))

    freq_correlation = np.corrcoef(real_freq_mean, fake_freq_mean)[0, 1]
    metrics['frequency_correlation'] = freq_correlation

    print(f"     Frequency correlation: {freq_correlation:.4f} (higher is better)")

    # 5. Dynamic range
    print("\n  üìä Comparing dynamic range...")
    real_range = np.max(real_sample) - np.min(real_sample)
    fake_range = np.max(fake_sample) - np.min(fake_sample)

    metrics['real_range'] = real_range
    metrics['fake_range'] = fake_range
    metrics['range_diff'] = abs(real_range - fake_range)

    print(f"     Real range: {real_range:.4f}")
    print(f"     Fake range: {fake_range:.4f}")
    print(f"     Difference: {metrics['range_diff']:.4f}")

    return metrics


def visualize_quality_comparison(real_specs, fake_specs, metrics, n_visual=3):
    """Plot real and generated spectrograms next to summary quality metrics.

    The figure has four rows: example real spectrograms, example generated
    spectrograms, a bar chart of three entries of ``metrics``, and the mean
    value per position along the last axis for both sets.  The figure is
    written to ``gan_quality_evaluation.png`` inside ``OUTPUT_DIR`` and then
    displayed.

    Args:
        real_specs: Real spectrograms; ``real_specs[i]`` is drawn as an image
            and ``real_specs[:50]`` is averaged over its first two axes
            (presumably samples and mel bins -- confirm against the caller).
        fake_specs: Generated spectrograms, used the same way.
        metrics: Mapping with the keys ``'frechet_distance'``,
            ``'frequency_correlation'`` and ``'smoothness_ratio'``.
        n_visual: Maximum number of example columns.  It is capped at the
            number of spectrograms available in either input.
    """
    # Cap the number of examples at what is available and size the grid from
    # it; a fixed 3-column grid breaks as soon as n_visual exceeds 3.
    n_shown = max(0, min(n_visual, len(real_specs), len(fake_specs)))
    n_cols = max(1, n_shown)  # a grid needs at least one column

    fig = plt.figure(figsize=(18, 12))
    gs = fig.add_gridspec(4, n_cols, hspace=0.3, wspace=0.3)

    # Rows 0 and 1: sample spectrograms (real on top, generated below)
    for i in range(n_shown):
        # Real spectrograms
        ax = fig.add_subplot(gs[0, i])
        ax.imshow(real_specs[i], aspect='auto', origin='lower', cmap='viridis')
        ax.set_title(f'Real Sample {i+1}')
        ax.set_xlabel('Time')
        if i == 0:
            ax.set_ylabel('Mel Frequency')

        # Fake spectrograms
        ax = fig.add_subplot(gs[1, i])
        ax.imshow(fake_specs[i], aspect='auto', origin='lower', cmap='viridis')
        ax.set_title(f'Generated Sample {i+1}')
        ax.set_xlabel('Time')
        if i == 0:
            ax.set_ylabel('Mel Frequency')

    # Row 2: metrics visualization
    ax = fig.add_subplot(gs[2, :])
    metric_names = ['Frechet\nDistance', 'Frequency\nCorrelation', 'Smoothness\nRatio']
    metric_values = [
        metrics['frechet_distance'],
        metrics['frequency_correlation'],
        metrics['smoothness_ratio']
    ]
    # Bar colour encodes a "badness" value per metric: red above 10, blue
    # above 5, green otherwise.
    # NOTE(review): the same 10 / 5 cut-offs are applied to all three values,
    # but 1 - correlation cannot exceed 2, so that bar is always green.
    # Per-metric thresholds are probably intended -- confirm.
    colors = ['#e74c3c' if v > 10 else '#3498db' if v > 5 else '#2ecc71'
              for v in [metrics['frechet_distance'],
                        1 - metrics['frequency_correlation'],
                        abs(1 - metrics['smoothness_ratio'])]]

    bars = ax.bar(metric_names, metric_values, color=colors, alpha=0.7, edgecolor='black')
    ax.set_ylabel('Metric Value')
    ax.set_title('GAN Quality Metrics')
    ax.grid(True, alpha=0.3, axis='y')

    # Add value labels on bars
    for bar, value in zip(bars, metric_values):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{value:.3f}',
                ha='center', va='bottom', fontweight='bold')

    # Row 3: mean profile along the last axis, using at most 50 spectrograms
    ax = fig.add_subplot(gs[3, :])
    real_temporal = np.mean(real_specs[:50], axis=(0, 1))
    fake_temporal = np.mean(fake_specs[:50], axis=(0, 1))
    ax.plot(real_temporal, label='Real', linewidth=2, alpha=0.8)
    ax.plot(fake_temporal, label='Synthetic', linewidth=2, alpha=0.8)
    ax.set_xlabel('Time Frame')
    ax.set_ylabel('Average Amplitude')
    ax.set_title('Temporal Evolution Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.savefig(os.path.join(OUTPUT_DIR, 'gan_quality_evaluation.png'), dpi=150, bbox_inches='tight')
    plt.show()

print("‚úÖ Quality evaluation functions defined")

### 📊 Define Quality Evaluation Functions

Functions to evaluate GAN generation quality using Fréchet Distance, correlation metrics, and visualization.

In [None]:
print(f"üé® Generating {NUM_SYNTHETIC} synthetic spectrograms...\n")

generator.eval()
synthetic_spectrograms = []
synthetic_labels = []

with torch.no_grad():
    num_batches = NUM_SYNTHETIC // GAN_BATCH_SIZE
    
    for i in tqdm(range(num_batches), desc="Generating"):
        z = torch.randn(GAN_BATCH_SIZE, LATENT_DIM).to(DEVICE)
        random_conditions = torch.FloatTensor(GAN_BATCH_SIZE, 2).uniform_(-1, 1).to(DEVICE)
        
        fake_specs = generator(z, random_conditions)
        
        synthetic_spectrograms.append(fake_specs.cpu().numpy())
        synthetic_labels.append(random_conditions.cpu().numpy())

# Concatenate all batches
synthetic_spectrograms = np.concatenate(synthetic_spectrograms, axis=0)
synthetic_labels = np.concatenate(synthetic_labels, axis=0)

# Remove channel dimension
synthetic_spectrograms = synthetic_spectrograms.squeeze(1)

print(f"‚úÖ Generated {len(synthetic_spectrograms)} synthetic spectrograms")
print(f"Synthetic spectrogram shape: {synthetic_spectrograms.shape}")
print(f"Synthetic labels shape: {synthetic_labels.shape}")

# ========== ROBUST DATA VALIDATION & CONVERSION ==========
print(f"\nüîç Validating and preparing label data...")
print(f"‚úÖ Using 'real_conditions' for original emotion labels (valence/arousal)")
print(f"   (This contains the original DEAM annotations)")

def prepare_labels(labels, name="labels"):
    """
    Convert emotion labels to a numpy array of shape (N, 2), best effort.

    Tensors are detached and moved to the CPU before conversion; anything
    else goes through ``np.array``.  The result is then coerced towards
    (N, 2): 1-D and (N, 1) inputs are reshaped to (-1, 2), and a (2, N) input
    with N > 2 is transposed.  Every step is reported on stdout.

    Args:
        labels: Tensor, numpy array or nested sequence of label values.
        name: Label used in the progress messages.

    Returns:
        The converted numpy array.  If the conversion or reshaping fails, the
        error is printed and the data is returned without reshaping (tensors
        are still converted to numpy), so callers should check the shape.
    """
    try:
        # Step 1: Convert to numpy if tensor
        if torch.is_tensor(labels):
            print(f"  - {name} is a tensor, converting to numpy...")
            # detach() is required because .numpy() raises on tensors that
            # track gradients; .cpu() is a no-op for tensors already on CPU.
            labels_np = labels.detach().cpu().numpy()
        else:
            labels_np = np.array(labels)

        print(f"  - {name} shape after conversion: {labels_np.shape}")

        # Step 2: Handle shape issues
        if labels_np.ndim == 1:
            # 1D array - reshape to (N, 2)
            print(f"  - {name} is 1D, reshaping to (-1, 2)...")
            labels_np = labels_np.reshape(-1, 2)
        elif labels_np.ndim == 2:
            # 2D array - check if transposed
            if labels_np.shape[0] == 2 and labels_np.shape[1] > 2:
                # Likely transposed (2, N) -> (N, 2)
                print(f"  - {name} appears transposed {labels_np.shape}, fixing...")
                labels_np = labels_np.T
            elif labels_np.shape[1] == 1:
                # Shape is (N, 1) - might need to be (N//2, 2)
                print(f"  - {name} has shape {labels_np.shape}, reshaping...")
                labels_np = labels_np.reshape(-1, 2)

        # Step 3: Final validation
        if labels_np.shape[1] != 2:
            # NOTE(review): a forced reshape of e.g. (N, 3) data regroups the
            # values into pairs that no longer correspond to one sample.
            print(f"  ‚ö†Ô∏è WARNING: {name} has unexpected shape {labels_np.shape}")
            print(f"  Attempting to force reshape to (-1, 2)...")
            labels_np = labels_np.reshape(-1, 2)

        print(f"  ‚úÖ {name} final shape: {labels_np.shape}")
        return labels_np

    except Exception as e:
        # Deliberate best effort: report the problem and hand the data back.
        print(f"  ‚ùå ERROR processing {name}: {e}")
        print(f"  Returning original data as-is")
        return labels if not torch.is_tensor(labels) else labels.detach().cpu().numpy()

# Build the (N, 2) label arrays used by the rest of the notebook.
# 'real_conditions' holds the original valence/arousal values from DEAM.
print(
    "\nüìä Original emotion labels info:\n"
    f"   real_conditions shape: {real_conditions.shape}\n"
    f"   real_conditions type: {type(real_conditions)}"
)

# Real labels first, then the conditions the generator was sampled with;
# prepare_labels prints its own progress for each call.
emotion_labels_real = prepare_labels(real_conditions, "real_conditions")
emotion_labels_synthetic = prepare_labels(synthetic_labels, "synthetic_labels")

print(
    "\n‚úÖ Data preparation complete!\n"
    f"   Real emotion labels: {emotion_labels_real.shape}\n"
    f"   Synthetic emotion labels: {emotion_labels_synthetic.shape}"
)

# ========== VISUALIZATION WITH ERROR HANDLING ==========
# Draws the first three real spectrograms (top row) and the first three
# synthetic spectrograms (bottom row), titled with their valence/arousal
# labels, and saves the figure as 'real_vs_synthetic.png' in OUTPUT_DIR.
# Plotting is best effort: a panel that fails is replaced by a short error
# text, and a failure of the whole figure is reported without stopping.
try:
    print(f"\nüìä Creating visualizations...")
    
    # Visualize synthetic vs real spectrograms
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    
    # Top row: real spectrograms
    for i in range(3):
        try:
            axes[0, i].imshow(real_spectrograms[i], aspect='auto', origin='lower', cmap='viridis')
            # Fall back to 0 when the label array has fewer than 1 / 2 columns
            v_real = emotion_labels_real[i, 0] if emotion_labels_real.shape[1] >= 1 else 0
            a_real = emotion_labels_real[i, 1] if emotion_labels_real.shape[1] >= 2 else 0
            axes[0, i].set_title(f'Real Spec {i+1}\nV: {v_real:.2f}, A: {a_real:.2f}')
            axes[0, i].set_xlabel('Time')
            axes[0, i].set_ylabel('Mel Bins')
        except Exception as e:
            # Keep the figure usable: show the (truncated) error in the panel
            print(f"  ‚ö†Ô∏è Warning: Could not plot real spectrogram {i}: {e}")
            axes[0, i].text(0.5, 0.5, f'Error\n{str(e)[:30]}', 
                          ha='center', va='center', transform=axes[0, i].transAxes)
    
    # Bottom row: synthetic spectrograms
    for i in range(3):
        try:
            axes[1, i].imshow(synthetic_spectrograms[i], aspect='auto', origin='lower', cmap='viridis')
            # Fall back to 0 when the label array has fewer than 1 / 2 columns
            v_syn = emotion_labels_synthetic[i, 0] if emotion_labels_synthetic.shape[1] >= 1 else 0
            a_syn = emotion_labels_synthetic[i, 1] if emotion_labels_synthetic.shape[1] >= 2 else 0
            axes[1, i].set_title(f'Synthetic Spec {i+1}\nV: {v_syn:.2f}, A: {a_syn:.2f}')
            axes[1, i].set_xlabel('Time')
            axes[1, i].set_ylabel('Mel Bins')
        except Exception as e:
            # Keep the figure usable: show the (truncated) error in the panel
            print(f"  ‚ö†Ô∏è Warning: Could not plot synthetic spectrogram {i}: {e}")
            axes[1, i].text(0.5, 0.5, f'Error\n{str(e)[:30]}', 
                          ha='center', va='center', transform=axes[1, i].transAxes)
    
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'real_vs_synthetic.png'), dpi=150, bbox_inches='tight')
    plt.show()
    print("  ‚úÖ Spectrogram comparison plot saved")
    
except Exception as e:
    # Visualization is optional; report and carry on with the notebook
    print(f"  ‚ùå Error creating spectrogram comparison plot: {e}")
    print("  Continuing execution...")

# ========== DISTRIBUTION COMPARISON WITH ERROR HANDLING ==========
# Left panel: valence/arousal scatter of real vs synthetic labels.
# Right panel: bar chart of real, synthetic and combined sample counts.
# The figure is saved as 'augmented_dataset_comparison.png' in OUTPUT_DIR.
# Each panel is best effort: on failure it shows a short error text instead.
try:
    print(f"\nüìà Creating distribution plots...")
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Scatter plot (column 0 = valence, column 1 = arousal)
    try:
        axes[0].scatter(emotion_labels_real[:, 0], emotion_labels_real[:, 1], 
                       alpha=0.5, label='Real', s=20)
        axes[0].scatter(emotion_labels_synthetic[:, 0], emotion_labels_synthetic[:, 1], 
                       alpha=0.3, label='Synthetic', s=20)
        axes[0].set_xlabel('Valence')
        axes[0].set_ylabel('Arousal')
        axes[0].set_title('Valence-Arousal Distribution')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        # Thin axis lines through the origin to mark the four quadrants
        axes[0].axhline(0, color='k', linewidth=0.5)
        axes[0].axvline(0, color='k', linewidth=0.5)
    except Exception as e:
        print(f"  ‚ö†Ô∏è Warning: Could not create scatter plot: {e}")
        axes[0].text(0.5, 0.5, f'Scatter plot error\n{str(e)[:30]}', 
                    ha='center', va='center', transform=axes[0].transAxes)
    
    # Bar plot of dataset sizes
    try:
        sizes = [len(real_spectrograms), len(synthetic_spectrograms), 
                 len(real_spectrograms) + len(synthetic_spectrograms)]
        labels = ['Real', 'Synthetic', 'Total']
        colors = ['#3498db', '#e74c3c', '#2ecc71']
        axes[1].bar(labels, sizes, color=colors, alpha=0.7, edgecolor='black')
        axes[1].set_ylabel('Number of Samples')
        axes[1].set_title('Dataset Size Comparison')
        axes[1].grid(True, alpha=0.3, axis='y')
        # Write each count above its bar; the +50 offset is in data units
        for i, v in enumerate(sizes):
            axes[1].text(i, v + 50, str(v), ha='center', va='bottom', fontweight='bold')
    except Exception as e:
        print(f"  ‚ö†Ô∏è Warning: Could not create bar plot: {e}")
        axes[1].text(0.5, 0.5, f'Bar plot error\n{str(e)[:30]}', 
                    ha='center', va='center', transform=axes[1].transAxes)
    
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, 'augmented_dataset_comparison.png'), dpi=150, bbox_inches='tight')
    plt.show()
    print("  ‚úÖ Distribution comparison plot saved")
    
except Exception as e:
    # Visualization is optional; report and carry on with the notebook
    print(f"  ‚ùå Error creating distribution plots: {e}")
    print("  Continuing execution...")

# ========== STATISTICS ==========
# Sample counts and per-dimension label mean/std for the real and synthetic
# sets, followed by the synthetic-to-real size ratio.
stats_rule = "=" * 60
print("\nüìä Augmented Dataset Statistics:")
print(stats_rule)
for stats_heading, stats_specs, stats_labels in (
    ("Real Data:", real_spectrograms, emotion_labels_real),
    ("\nSynthetic Data:", synthetic_spectrograms, emotion_labels_synthetic),
):
    print(stats_heading)
    print(f"  Samples: {len(stats_specs)}")
    # Column 0 holds valence, column 1 arousal
    for stats_col, stats_dim in enumerate(("Valence", "Arousal")):
        stats_values = stats_labels[:, stats_col]
        print(f"  {stats_dim}: mean={stats_values.mean():.3f}, std={stats_values.std():.3f}")
augmentation_ratio = len(synthetic_spectrograms) / len(real_spectrograms)
print(f"\nüìà Data Augmentation Ratio: {augmentation_ratio:.2f}x")
print(stats_rule)

### 🎨 Generate Synthetic Spectrograms

Use the trained GAN generator to create synthetic spectrograms with random valence/arousal conditions.

## 6.2️⃣ Evaluate GAN Quality

In [None]:
# ========== OPTIONAL GAN QUALITY EVALUATION ==========
# This step can consume 4-6 GB of RAM due to covariance matrix computation
# Skip if memory is limited

import gc


def compute_gan_quality_scores(metrics):
    """Turn raw GAN quality metrics into sub-scores and a composite score.

    Every sub-score is clamped to [0, 100] so that a negative or undefined
    (NaN) metric cannot push the composite outside the advertised range.

    Args:
        metrics: Mapping with ``'frechet_distance'``,
            ``'frequency_correlation'`` and ``'smoothness_ratio'``.

    Returns:
        Tuple ``(fd_score, freq_score, smooth_score, overall_score)`` where
        the overall score weights the sub-scores 0.4 / 0.4 / 0.2.
    """
    def _clamp_score(raw):
        raw = float(raw)
        if np.isnan(raw):
            # e.g. correlation of a constant profile; treat as worst case
            return 0.0
        return min(100.0, max(0.0, raw))

    fd = _clamp_score(100 - metrics['frechet_distance'] * 10)
    freq = _clamp_score(metrics['frequency_correlation'] * 100)
    smooth = _clamp_score(100 - abs(1.0 - metrics['smoothness_ratio']) * 100)
    overall = fd * 0.4 + freq * 0.4 + smooth * 0.2
    return fd, freq, smooth, overall


# Check if we should run quality evaluation
SKIP_QUALITY_EVAL = True  # Set to False if you have >20GB RAM available

if SKIP_QUALITY_EVAL:
    print("\n" + "="*60)
    print("‚ö†Ô∏è SKIPPING GAN QUALITY EVALUATION (Memory Optimization)")
    print("="*60)
    print("üìä Reason: Quality evaluation requires ~4-6 GB RAM for covariance computation")
    print("üí° To enable: Set SKIP_QUALITY_EVAL = False in this cell")
    print("\n‚úÖ GAN training completed successfully!")
    print("   Moving on to ViT training with augmented dataset...")
    print("="*60)

else:
    print("\n" + "="*60)
    print("üî¨ GAN QUALITY EVALUATION")
    print("="*60)
    print("‚ö†Ô∏è Warning: This may consume 4-6 GB of memory")

    try:
        # Clear memory before evaluation
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Evaluate quality (with smaller sample size to reduce memory)
        quality_metrics = evaluate_spectrogram_quality(
            real_spectrograms[:200],  # Reduced from 500 to 200
            synthetic_spectrograms[:200]
        )

        # Visualize comparison
        visualize_quality_comparison(
            real_spectrograms[:5],  # Reduced from 10 to 5
            synthetic_spectrograms[:5],
            quality_metrics
        )

        # Overall quality score
        print("\n" + "="*60)
        print("üéØ OVERALL QUALITY ASSESSMENT")
        print("="*60)

        # Compute composite quality score (0-100)
        fd_score, freq_score, smooth_score, overall_score = compute_gan_quality_scores(quality_metrics)

        print(f"  Frechet Distance Score: {fd_score:.1f}/100")
        print(f"  Frequency Correlation Score: {freq_score:.1f}/100")
        print(f"  Smoothness Score: {smooth_score:.1f}/100")
        print(f"\n  üìä Overall GAN Quality Score: {overall_score:.1f}/100")

        if overall_score >= 70:
            print("  ‚úÖ Excellent - GAN generates high-quality spectrograms")
        elif overall_score >= 50:
            print("  ‚ö†Ô∏è Good - GAN output is acceptable but could be improved")
        else:
            print("  ‚ùå Poor - GAN needs significant improvement (mostly noise)")

        print("="*60)

        # Clear memory after evaluation
        del quality_metrics
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    except Exception as e:
        # Evaluation is optional; report the failure and keep going
        print(f"\n‚ùå Quality evaluation failed: {e}")
        print("‚ö†Ô∏è Continuing without quality metrics...")
        print("üí° This does not affect ViT training")

# Final cleanup before moving to ViT training
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"\nüíæ GPU memory before ViT prep: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

print("\n‚úÖ Ready to proceed with ViT dataset preparation!")

### 🔬 [Optional] Evaluate GAN Quality Metrics

Optional step to evaluate synthetic spectrogram quality (requires ~4-6 GB RAM). Can be skipped for memory efficiency.

## 6.5️⃣ Audio Reconstruction - Listen to GAN Outputs

In [None]:
from IPython.display import Audio, display
import soundfile as sf

def spectrogram_to_audio(spec_normalized, sample_rate=SAMPLE_RATE, n_fft=N_FFT, 
                         hop_length=HOP_LENGTH, n_iter=32):
    """
    Reconstruct a waveform from a normalized mel spectrogram.

    The input is scaled by 40 to an approximate dB range, passed through
    ``librosa.db_to_amplitude`` and inverted with librosa's ``mel_to_audio``
    (Griffin-Lim phase estimation) using the module-level FMIN / FMAX.

    Args:
        spec_normalized: Normalized spectrogram in range [-1, 1]
        sample_rate: Audio sample rate
        n_fft: FFT window size
        hop_length: Hop length for STFT
        n_iter: Number of Griffin-Lim iterations

    Returns:
        audio: Reconstructed audio signal
    """
    # NOTE(review): mel_to_audio treats its input as a power spectrogram by
    # default, while db_to_amplitude yields amplitudes; if the training
    # spectrograms came from power_to_db, db_to_power may be intended here.
    # Confirm against the preprocessing code.
    mel_values = librosa.db_to_amplitude(spec_normalized * 40.0)

    return librosa.feature.inverse.mel_to_audio(
        mel_values,
        sr=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_iter=n_iter,
        fmin=FMIN,
        fmax=FMAX,
    )


def generate_and_listen_to_samples(generator, n_samples=5, emotions=None, device=DEVICE):
    """
    Generate synthetic spectrograms and convert them to audio for listening.

    For each sample a spectrogram is generated from random noise and a
    (valence, arousal) condition, converted to audio with
    ``spectrogram_to_audio``, peak-normalized to 0.9, written as a WAV file
    to OUTPUT_DIR, played inline, and plotted (the plot is saved as well).

    Args:
        generator: Trained GAN generator
        n_samples: Number of samples to generate
        emotions: List of (valence, arousal) tuples, or None for random.
            Samples beyond the end of the list use random conditions
            drawn uniformly from [-1, 1].
        device: Torch device
    """
    generator.eval()

    print(f"üéµ Generating {n_samples} audio samples from GAN...")

    with torch.no_grad():
        for i in range(n_samples):
            # Generate noise
            z = torch.randn(1, LATENT_DIM).to(device)

            # Use provided emotions or generate random
            if emotions and i < len(emotions):
                valence, arousal = emotions[i]
            else:
                valence = np.random.uniform(-1, 1)
                arousal = np.random.uniform(-1, 1)

            condition = torch.FloatTensor([[valence, arousal]]).to(device)

            # Generate spectrogram
            fake_spec = generator(z, condition)
            fake_spec_np = fake_spec.squeeze().cpu().numpy()

            # Convert to audio ("\n" must be a real newline, not an escaped
            # backslash, so the sample header starts on its own line)
            print(f"\nüéß Sample {i+1}: Valence={valence:.2f}, Arousal={arousal:.2f}")
            audio = spectrogram_to_audio(fake_spec_np)

            # Normalize audio to a 0.9 peak; the epsilon guards against
            # division by zero for silent output
            audio = audio / (np.max(np.abs(audio)) + 1e-8) * 0.9

            # Save audio file
            audio_path = os.path.join(OUTPUT_DIR, f'generated_sample_{i+1}_v{valence:.2f}_a{arousal:.2f}.wav')
            sf.write(audio_path, audio, SAMPLE_RATE)
            print(f"   üíæ Saved: {audio_path}")

            # Display audio player
            display(Audio(audio, rate=SAMPLE_RATE))

            # Visualize spectrogram
            # NOTE(review): the colorbar is labelled in dB but the plotted
            # values are the generator's raw output -- confirm the units.
            plt.figure(figsize=(10, 4))
            librosa.display.specshow(
                fake_spec_np,
                sr=SAMPLE_RATE,
                hop_length=HOP_LENGTH,
                x_axis='time',
                y_axis='mel',
                fmax=FMAX,
                cmap='viridis'
            )
            plt.colorbar(format='%+2.0f dB')
            plt.title(f'Generated Spectrogram {i+1}\nValence: {valence:.2f}, Arousal: {arousal:.2f}')
            plt.tight_layout()
            plt.savefig(os.path.join(OUTPUT_DIR, f'generated_spec_{i+1}.png'), dpi=150, bbox_inches='tight')
            plt.show()


# Define emotion targets to test, as (valence, arousal) pairs
test_emotions = [
    (-0.8, -0.6),  # Sad, calm
    (0.8, 0.7),    # Happy, energetic
    (-0.3, 0.8),   # Angry, tense
    (0.5, -0.5),   # Content, relaxed
    (0.0, 0.0),    # Neutral
]

print("üé® Generating audio from synthetic spectrograms...")
# "\n" (not "\\n") so a blank line is printed rather than a literal backslash-n
print("This allows you to qualitatively assess GAN generation quality.\n")

# One audio sample per emotion target above
generate_and_listen_to_samples(
    generator, 
    n_samples=5, 
    emotions=test_emotions,
    device=DEVICE
)

print("\n" + "="*60)
print("‚úÖ Audio generation complete!")
print("="*60)
print("üí° Tips for evaluation:")
print("  - Listen for musical structure vs pure noise")
print("  - Check if emotion patterns are perceptible")
print("  - Compare across different valence/arousal values")
print("  - Real music should have harmonic and temporal patterns")
print("="*60)

### 🎵 [Optional] Convert Spectrograms to Audio

Optional: Convert generated spectrograms back to audio for qualitative listening evaluation.

## 7️⃣ Prepare Augmented Dataset for ViT Training

Now that we have trained the GAN and generated synthetic spectrograms, we combine the real and synthetic data to create an expanded dataset for training the Vision Transformer model.

## 8️⃣ Define ViT Model Architecture for Emotion Regression

We define a custom Vision Transformer model for emotion prediction, using a pre-trained ViT backbone with a custom regression head for valence/arousal prediction.

In [None]:
# ========== MEMORY-EFFICIENT DATASET PREPARATION ==========
# Concatenates real and synthetic spectrograms/labels into all_spectrograms /
# all_emotion_labels and then frees the synthetic-only arrays.
import gc

print("üîÑ Preparing augmented dataset for ViT training...")
# Report and clear GPU memory only when a GPU exists, so CPU-only runs do not
# print a stray empty line.
if torch.cuda.is_available():
    print(f"üíæ Memory before concatenation: {torch.cuda.memory_allocated()/1024**3:.2f} GB")
    # Clear any unused memory before concatenation
    torch.cuda.empty_cache()
gc.collect()

# Combine real and synthetic spectrograms
print(f"\nüì¶ Combining datasets:")
print(f"   Real spectrograms: {real_spectrograms.shape}")
print(f"   Synthetic spectrograms: {synthetic_spectrograms.shape}")
print(f"   Real emotion labels: {emotion_labels_real.shape}")
print(f"   Synthetic emotion labels: {emotion_labels_synthetic.shape}")

try:
    all_spectrograms = np.concatenate([real_spectrograms, synthetic_spectrograms], axis=0)
    all_emotion_labels = np.concatenate([emotion_labels_real, emotion_labels_synthetic], axis=0)

    print(f"\n‚úÖ Total augmented dataset:")
    print(f"   - Total samples: {len(all_spectrograms)}")
    print(f"   - Spectrograms shape: {all_spectrograms.shape}")
    print(f"   - Emotion labels shape: {all_emotion_labels.shape}")
    print(f"   - Memory usage: ~{all_spectrograms.nbytes / 1024**3:.2f} GB")

except MemoryError as e:
    print(f"\n‚ùå MemoryError during concatenation: {e}")
    print(f"üîß Reducing synthetic samples to fit in memory...")

    # Reduce synthetic samples if OOM; spectrograms and labels are truncated
    # together so they stay aligned
    max_synthetic = 2000  # Reduce from 3192 to 2000
    synthetic_spectrograms = synthetic_spectrograms[:max_synthetic]
    emotion_labels_synthetic = emotion_labels_synthetic[:max_synthetic]

    print(f"   Reduced synthetic samples to: {max_synthetic}")

    # Try again (a second MemoryError is not handled and will propagate)
    all_spectrograms = np.concatenate([real_spectrograms, synthetic_spectrograms], axis=0)
    all_emotion_labels = np.concatenate([emotion_labels_real, emotion_labels_synthetic], axis=0)

    print(f"‚úÖ Reduced dataset created: {len(all_spectrograms)} samples")

# Delete intermediate arrays to free memory
print(f"\nüßπ Freeing intermediate memory...")
del synthetic_spectrograms  # Delete synthetic spectrograms (we have all_spectrograms now)
del emotion_labels_synthetic  # Delete synthetic labels (we have all_emotion_labels now)

# Don't delete real_spectrograms yet - needed for evaluation
gc.collect()

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"üíæ GPU memory after cleanup: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

print(f"‚úÖ Memory cleanup complete\n")


# ========== MEMORY-EFFICIENT DATASET CLASS ==========
class SpectrogramDataset(Dataset):
    """
    Memory-efficient dataset for mel-spectrograms with ViT preprocessing.

    The raw arrays are stored as given; every item is normalized, resized and
    converted to a 3-channel tensor on the fly in ``__getitem__`` instead of
    keeping preprocessed copies in memory.
    """

    def __init__(self, spectrograms, labels, image_size=VIT_IMAGE_SIZE):
        """
        Args:
            spectrograms: numpy array of shape (N, n_mels, time_steps)
            labels: numpy array of shape (N, 2)
            image_size: target (square) image size for the ViT input

        Raises:
            ValueError: If the sample counts differ or ``labels`` is not
                of shape (N, 2).
        """
        # Validate with explicit exceptions rather than ``assert``: assertions
        # are stripped when Python runs with -O, which would let mismatched
        # data through silently.
        if len(spectrograms) != len(labels):
            raise ValueError(
                f"Spectrogram count ({len(spectrograms)}) must match label count ({len(labels)})"
            )
        if len(labels.shape) != 2 or labels.shape[1] != 2:
            raise ValueError(f"Labels must have shape (N, 2), got {labels.shape}")

        # Store as numpy arrays (more memory efficient than tensors)
        self.spectrograms = spectrograms
        self.labels = labels
        self.image_size = image_size

        print(f"  üìä Dataset created: {len(self.spectrograms)} samples")
        print(f"     Spectrograms: {self.spectrograms.shape}")
        print(f"     Labels: {self.labels.shape}")

        # Precompute normalization constants, shaped (3, 1, 1) so they
        # broadcast over a (3, H, W) image
        self.imagenet_mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
        self.imagenet_std = torch.tensor(IMAGENET_STD).view(3, 1, 1)

    def __len__(self):
        """Return the number of spectrograms in the dataset."""
        return len(self.spectrograms)

    def __getitem__(self, idx):
        """
        Get spectrogram and label with on-the-fly preprocessing.

        Returns:
            Tuple ``(image, label)``: ``image`` is a float tensor of shape
            (3, image_size, image_size), normalized with the ImageNet mean
            and std; ``label`` is a float tensor built from ``labels[idx]``.
        """
        # Get spectrogram and label (as numpy arrays)
        spec = self.spectrograms[idx]  # Shape: (n_mels, time_steps)
        label = self.labels[idx]  # Shape: (2,)

        # Normalize each spectrogram to [0, 1] by its own min/max; the
        # epsilon avoids division by zero for a constant spectrogram
        spec_min = spec.min()
        spec_max = spec.max()
        spec_norm = (spec - spec_min) / (spec_max - spec_min + 1e-8)

        # Convert to tensor and resize to the ViT input size
        spec_tensor = torch.FloatTensor(spec_norm).unsqueeze(0)  # Add channel dim: (1, H, W)
        spec_resized = F.interpolate(
            spec_tensor.unsqueeze(0),  # Add batch dim: (1, 1, H, W)
            size=(self.image_size, self.image_size),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)  # Remove batch dim: (1, image_size, image_size)

        # Convert to 3 channels (RGB) by triplicating the single channel
        spec_rgb = spec_resized.repeat(3, 1, 1)  # (3, image_size, image_size)

        # Apply ImageNet normalization
        spec_normalized = (spec_rgb - self.imagenet_mean) / self.imagenet_std

        return spec_normalized, torch.FloatTensor(label)


# ========== CREATE DATASETS AND DATALOADERS ==========
# Wraps the combined arrays in a SpectrogramDataset, splits it into train and
# validation subsets with a fixed seed, and builds one DataLoader for each.
print("üîÑ Creating dataset and dataloaders...")

try:
    # Full dataset over real + synthetic samples
    full_dataset = SpectrogramDataset(all_spectrograms, all_emotion_labels)
    n_total_samples = len(full_dataset)
    print(f"‚úÖ Created dataset with {n_total_samples} samples")

    # Train/validation split; the seeded generator makes it reproducible.
    # NOTE(review): the split is drawn from the combined data, so the
    # validation subset also contains synthetic samples -- confirm intended.
    train_size = int(TRAIN_SPLIT * n_total_samples)
    val_size = n_total_samples - train_size
    split_rng = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = random_split(
        full_dataset, [train_size, val_size], generator=split_rng
    )

    print(f"‚úÖ Dataset split:")
    print(f"   - Train: {len(train_dataset)} samples")
    print(f"   - Validation: {len(val_dataset)} samples")

    # Options shared by both loaders.  num_workers=0 loads in the main process
    # to avoid multiprocessing memory overhead, so no workers are kept alive.
    shared_loader_options = dict(
        batch_size=BATCH_SIZE,
        num_workers=0,
        pin_memory=True,
        persistent_workers=False,
    )
    # Only the training loader shuffles
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_options)

    print(f"‚úÖ Dataloaders created:")
    print(f"   - Train batches: {len(train_loader)}")
    print(f"   - Validation batches: {len(val_loader)}")
    print(f"   - Batch size: {BATCH_SIZE}")

    # Final memory cleanup
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"\nüíæ Final GPU memory: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

    print(f"\n‚úÖ Dataset preparation complete! Ready for ViT training.")

except Exception as e:
    # Training cannot proceed without the loaders: report, then re-raise
    print(f"\n‚ùå Error during dataset creation: {e}")
    print(f"üí° Suggestion: Reduce NUM_SYNTHETIC or BATCH_SIZE in configuration")
    raise

### 📦 Combine Real and Synthetic Data

Concatenate real and synthetic spectrograms to create the augmented training dataset.

In [None]:
class ViTForEmotionRegression(nn.Module):
    """Vision Transformer backbone with a regression head for valence/arousal.

    The backbone is obtained through a 3-tier loading strategy (local Kaggle
    dataset, Hugging Face download of ``model_name``, base-checkpoint
    fallback). The [CLS] token of its last hidden state is fed to an MLP head
    ending in ``Tanh``, so every prediction lies in [-1, 1].

    Args:
        model_name: Checkpoint identifier requested from Hugging Face (tier 2).
        num_emotions: Number of regression outputs (2 = valence and arousal).
        freeze_backbone: When True, every backbone parameter gets
            ``requires_grad = False`` so only the head is trainable.
        dropout: Dropout probability used throughout the regression head.
    """

    # Checkpoint used by the last-resort tier when the requested one is unavailable.
    BASE_MODEL_NAME = 'google/vit-base-patch16-224-in21k'

    def __init__(self, model_name=VIT_MODEL_NAME, num_emotions=2, freeze_backbone=False, dropout=0.1):
        super().__init__()
        self.model_name = model_name
        self.num_emotions = num_emotions
        
        print(f"\nü§ñ Initializing ViT Model: {model_name}")
        
        # Load ViT model with 3-tier fallback strategy
        self.vit_model = self._load_vit_model_with_fallback()
        
        # Get the hidden size from the model configuration
        self.hidden_size = self.vit_model.config.hidden_size
        print(f"  Hidden Size: {self.hidden_size}")
        
        # Freeze backbone if requested (the head below is created afterwards
        # and therefore always stays trainable)
        if freeze_backbone:
            self._freeze_backbone()
            print("  üßä Backbone frozen")
        else:
            print("  üî• Backbone trainable")
        
        # Add custom regression head
        self.emotion_head = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Dropout(dropout),
            nn.Linear(self.hidden_size, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, self.num_emotions),
            nn.Tanh()  # Output range [-1, 1] for valence/arousal
        )
        
    def _load_vit_model_with_fallback(self):
        """
        3-Tier Model Loading Strategy:
        1. Try loading from Kaggle dataset input folder
        2. Try downloading ``self.model_name`` from Hugging Face
        3. Fall back to the base checkpoint if all else fails

        Returns:
            The first ``ViTModel`` that loads successfully.

        Raises:
            RuntimeError: If none of the three tiers produced a model; the
                tier-3 failure is attached as ``__cause__``.
        """
        print(f"\n? Starting 3-Tier Model Loading Strategy...")
        
        # ========== TIER 1: Kaggle Dataset Input ==========
        kaggle_model_paths = [
            '/kaggle/input/vit-model-for-kaggle/vit-model-for-kaggle',
            '/kaggle/input/vit-model-for-kaggle',
        ]
        
        for kaggle_path in kaggle_model_paths:
            try:
                print(f"\nüì¶ TIER 1: Trying Kaggle dataset...")
                print(f"  Path: {kaggle_path}")
                
                if not os.path.exists(kaggle_path):
                    print(f"  ‚ö†Ô∏è Path does not exist")
                    continue
                print(f"  ‚úÖ Path exists!")
                
                # List contents for debugging
                if os.path.isdir(kaggle_path):
                    contents = os.listdir(kaggle_path)
                    print(f"  üìÇ Contents: {contents[:5]}..." if len(contents) > 5 else f"  üìÇ Contents: {contents}")
                
                # A usable checkpoint directory must at least carry config.json
                config_path = os.path.join(kaggle_path, 'config.json')
                if not os.path.exists(config_path):
                    print(f"  ‚ö†Ô∏è Missing config.json at {config_path}")
                    continue
                print(f"  ‚úÖ Found config.json")
                
                # Try to load the model
                print(f"  ‚ö° Loading ViT from Kaggle dataset...")
                model = ViTModel.from_pretrained(kaggle_path, local_files_only=True)
                print(f"  ‚úÖ SUCCESS! Loaded model from Kaggle dataset")
                return model
                    
            except Exception as e:
                # Best effort: report and move on to the next candidate path / tier
                print(f"  ‚ùå Kaggle dataset loading failed: {str(e)[:100]}")
        
        # ========== TIER 2: Download from Hugging Face ==========
        try:
            print(f"\nüåê TIER 2: Trying Hugging Face download...")
            model = self._download_from_huggingface()
            print(f"  ‚úÖ SUCCESS! Downloaded model from Hugging Face")
            return model
            
        except Exception as e:
            print(f"  ‚ùå Hugging Face download failed: {str(e)[:100]}")
        
        # ========== TIER 3: Base Model Fallback ==========
        try:
            print(f"\nüîß TIER 3: Falling back to base model...")
            base_model_name = self.BASE_MODEL_NAME
            print(f"  Loading: {base_model_name}")
            model = ViTModel.from_pretrained(base_model_name)
            print(f"  ‚úÖ SUCCESS! Loaded base model")
            print(f"  ‚ö†Ô∏è WARNING: Using base model without fine-tuning")
            return model
            
        except Exception as e:
            print(f"  ‚ùå Base model loading failed: {str(e)}")
            # Chain the cause so the real failure stays visible in the traceback
            raise RuntimeError(
                "All 3 tiers of model loading failed!\n"
                "SOLUTION:\n"
                "1. Download the model using: python download_vit_model.py\n"
                "2. Upload to Kaggle as dataset: vit-model-for-kaggle\n"
                "3. Add dataset to notebook inputs\n"
                "4. Verify path: /kaggle/input/vit-model-for-kaggle/vit-model-for-kaggle"
            ) from e
    
    def _download_from_huggingface(self):
        """Download ``self.model_name`` from Hugging Face with retry logic.

        Returns:
            The loaded ``ViTModel``.

        Raises:
            Exception: Whatever the final attempt raised, re-raised unchanged.
        """
        # Imported locally: this cell does not import `time` itself, and an
        # undefined name here would abort the retry loop with a NameError.
        import time

        # Honour the checkpoint requested by the caller instead of a fixed name
        model_name = self.model_name
        max_retries = 2
        retry_delays = [5, 10]  # seconds
        
        for attempt in range(max_retries):
            try:
                print(f"  üåê Download attempt {attempt + 1}/{max_retries}...")
                
                # from_pretrained reuses files already present in cache_dir
                return ViTModel.from_pretrained(
                    model_name,
                    resume_download=True,
                    force_download=False,
                    cache_dir='/kaggle/working/model_cache'
                )
                
            except Exception as e:
                print(f"  ‚ùå Attempt {attempt + 1} failed: {str(e)[:80]}")
                
                if attempt >= max_retries - 1:
                    # Out of attempts: re-raise with the original traceback
                    raise
                delay = retry_delays[attempt]
                print(f"  ‚è≥ Retrying in {delay} seconds...")
                time.sleep(delay)
    
    def _freeze_backbone(self):
        """Freeze the ViT backbone parameters."""
        for param in self.vit_model.parameters():
            param.requires_grad = False
    
    def forward(self, pixel_values):
        """Forward pass through ViT + emotion head.

        Args:
            pixel_values: Image batch passed straight to the ViT backbone.

        Returns:
            Tensor of shape (batch_size, num_emotions) with values in [-1, 1].
        """
        # Get ViT outputs
        outputs = self.vit_model(pixel_values=pixel_values)
        
        # Use [CLS] token representation (first token)
        cls_output = outputs.last_hidden_state[:, 0, :]  # Shape: (batch_size, hidden_size)
        
        # Pass through emotion prediction head
        emotion_predictions = self.emotion_head(cls_output)  # Shape: (batch_size, num_emotions)
        
        return emotion_predictions


print("‚úÖ ViT Model class defined with 3-tier loading strategy")

### 🤖 Define ViT Regression Model Class

Create a custom ViT model with a regression head for predicting continuous valence and arousal values.

## 9️⃣ Load Pre-trained Vision Transformer Model

Load the pre-trained ViT model with a 3-tier fallback strategy: Kaggle dataset → Hugging Face API → Local fallback.

In [None]:
# Pre-download and verify model availability.
# Banner only: the check itself is performed by verify_model_download(), defined later in this cell.
print("üîç Checking model availability...")

from huggingface_hub import hf_hub_download, model_info
import time

def verify_model_download(model_name, max_retries=3):
    """
    Verify model can be downloaded or is available locally.

    First tries to load ``model_name`` from the local cache only; if that
    fails, queries the Hugging Face Hub for the model's metadata.

    Args:
        model_name: Checkpoint identifier to look up.
        max_retries: Currently unused; kept so existing callers keep working.

    Returns:
        True if the model was found in the local cache or its metadata could
        be fetched from the Hub, False if any step raised.
    """
    print(f"Model: {model_name}")
    
    # Check if model exists in cache
    try:
        from transformers import ViTModel
        
        # Try loading from cache first
        try:
            print("  ‚è≥ Checking local cache...")
            ViTModel.from_pretrained(model_name, local_files_only=True)
            print("  ‚úÖ Model found in local cache!")
            return True
        except Exception:
            # Not a bare `except:` so Ctrl-C / SystemExit still interrupt the cell
            print("  ‚ÑπÔ∏è Model not in cache, will download...")
        
        # Check model info
        print("  ‚è≥ Verifying model on Hugging Face Hub...")
        info = model_info(model_name)
        print(f"  ‚úì Model exists: {info.modelId}")
        print(f"  ‚úì Last modified: {info.lastModified}")
        
        return True
        
    except Exception as e:
        # Verification is advisory: report the problem and let the caller decide
        print(f"  ‚ö†Ô∏è Warning: {str(e)[:150]}")
        print("  üí° Will attempt to download during model initialization...")
        return False

# Run the availability check for the configured checkpoint and, when it
# fails, list what the user can do about it.
model_available = verify_model_download(VIT_MODEL_NAME)

if not model_available:
    print("\n".join((
        "\n‚ö†Ô∏è WARNING: Model verification failed!",
        "The notebook will still attempt to download the model.",
        "If download fails, the model will be initialized with random weights.",
        "\nAlternatives:",
        "  1. Wait and retry (Hugging Face servers may be temporarily busy)",
        "  2. Use a smaller model: 'google/vit-base-patch16-224'",
        "  3. Continue without pre-training (train from scratch)",
    )))

# Closing separator, printed whether or not verification succeeded
print(f"\n{'=' * 60}")

### 🚀 Instantiate ViT Model

Create an instance of the ViT model with the custom emotion regression head.

## ? Train ViT Model on Augmented Dataset

Train the Vision Transformer on the combined real + synthetic dataset using the AdamW optimizer and CosineAnnealing learning rate schedule.

In [None]:
# ALTERNATIVE DOWNLOAD METHOD (Run this cell if automatic download fails)
# This cell attempts to download the model using direct URLs

import requests
from tqdm import tqdm

def download_file(url, filename):
    """Download ``url`` to ``filename`` while showing a progress bar.

    Args:
        url: Address to fetch with an HTTP GET request.
        filename: Local path the response body is written to (binary mode).

    Returns:
        True when the whole body was written, False if anything failed
        (the error is printed, not raised).
    """
    try:
        # The context manager releases the streamed connection even on errors;
        # the timeout keeps a stalled server from hanging the notebook cell.
        with requests.get(url, stream=True, timeout=30) as response:
            # Treat 4xx/5xx as failures instead of saving the error page
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            
            with open(filename, 'wb') as file, tqdm(
                desc=filename,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar:
                for data in response.iter_content(chunk_size=1024):
                    size = file.write(data)
                    progress_bar.update(size)
        
        print(f"‚úÖ Downloaded {filename}")
        return True
    except Exception as e:
        print(f"‚ùå Failed to download {filename}: {e}")
        return False

# Fallback download path: fetch the whole model repository with
# huggingface_hub.snapshot_download and tell the user how to point
# VIT_MODEL_NAME at the local copy.
print("üîÑ Attempting alternative download method...")

try:
    from huggingface_hub import snapshot_download

    # Target directory for the repository snapshot
    cache_dir = "/kaggle/working/model_cache"
    print(f"Downloading {VIT_MODEL_NAME} to {cache_dir}...")

    model_path = snapshot_download(
        repo_id=VIT_MODEL_NAME,
        cache_dir=cache_dir,
        resume_download=True,
        max_workers=1  # Use single worker to avoid 500 errors
    )

    print("\n".join((
        f"‚úÖ Model downloaded successfully to: {model_path}",
        "üí° Now update VIT_MODEL_NAME to use local path:",
        f"    VIT_MODEL_NAME = '{model_path}'",
    )))

except Exception as e:
    # Nothing else to try automatically: list the manual options
    print("\n".join((
        f"‚ö†Ô∏è Alternative download also failed: {str(e)[:200]}",
        "\nüí° FALLBACK OPTIONS:",
        "  1. Use smaller model: VIT_MODEL_NAME = 'google/vit-base-patch16-224'",
        "  2. Train without pre-training (random initialization)",
        "  3. Wait 15 minutes and retry",
    )))

print(f"\n{'=' * 60}")

### ⚙️ Setup Training Configuration

Define loss function (MSE), optimizer (AdamW), and learning rate scheduler for ViT training.

In [None]:
import time

class ViTForEmotionRegression(nn.Module):
    """
    Vision Transformer for Emotion Regression
    Uses pre-trained ViT from Hugging Face and adds regression head

    Args:
        model_name: Checkpoint identifier handed to ``ViTModel.from_pretrained``.
        freeze_backbone: When True, every backbone parameter gets
            ``requires_grad = False`` so only the head is trained.
        dropout: Dropout probability inside the regression head.
    """
    def __init__(self, model_name=VIT_MODEL_NAME, freeze_backbone=FREEZE_BACKBONE, dropout=DROPOUT):
        super().__init__()
        
        # Load pre-trained ViT model with retry logic
        print(f"Loading pre-trained ViT model: {model_name}...")
        self.vit = self._load_backbone(model_name)
        
        # Freeze backbone if specified
        if freeze_backbone:
            for param in self.vit.parameters():
                param.requires_grad = False
            print("  ‚úì ViT backbone frozen")
        else:
            print("  ‚úì ViT backbone trainable")
        
        # Get hidden size from ViT config
        self.hidden_size = self.vit.config.hidden_size  # 768 for base model
        
        # Regression head for valence and arousal
        self.regression_head = nn.Sequential(
            nn.LayerNorm(self.hidden_size),
            nn.Linear(self.hidden_size, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 2)  # Valence and Arousal
        )

    @staticmethod
    def _load_backbone(model_name, max_retries=3, retry_delay=5):
        """Return a ViT backbone: downloaded, cached, or randomly initialised.

        Tries ``from_pretrained`` up to ``max_retries`` times, doubling
        ``retry_delay`` (seconds) between attempts. If every attempt fails it
        falls back to the local cache, and finally to a randomly initialised
        base-size ViT.
        """
        for attempt in range(max_retries):
            try:
                # resume_download=True lets an interrupted download continue
                backbone = ViTModel.from_pretrained(
                    model_name,
                    resume_download=True,
                    force_download=False,
                    local_files_only=False
                )
                print(f"  ‚úì Model loaded successfully!")
                return backbone
                
            except Exception as e:
                if attempt < max_retries - 1:
                    print(f"  ‚ö†Ô∏è Download attempt {attempt + 1} failed: {str(e)[:100]}")
                    print(f"  ‚è≥ Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
        
        print(f"  ‚ùå All download attempts failed!")
        print(f"  üí° Trying alternative approach...")
        
        # Fallback: Try to load from cache only
        try:
            backbone = ViTModel.from_pretrained(
                model_name,
                local_files_only=True
            )
            print(f"  ‚úì Loaded from local cache!")
            return backbone
        except Exception:
            # Narrowed from a bare `except:` so Ctrl-C still aborts the cell.
            # Last resort: Create model from config
            print(f"  üîß Creating model from config (no pre-trained weights)...")
            config = ViTConfig(
                hidden_size=768,
                num_hidden_layers=12,
                num_attention_heads=12,
                intermediate_size=3072,
                image_size=224,
                patch_size=16,
                num_channels=3
            )
            backbone = ViTModel(config)
            print(f"  ‚ö†Ô∏è WARNING: Using randomly initialized ViT (no transfer learning)")
            return backbone
        
    def forward(self, pixel_values):
        """Map an image batch to (batch_size, 2) valence/arousal predictions."""
        # Get ViT outputs
        outputs = self.vit(pixel_values=pixel_values)
        
        # Use [CLS] token representation
        cls_output = outputs.last_hidden_state[:, 0]  # (batch_size, hidden_size)
        
        # Regression head
        emotion_output = self.regression_head(cls_output)  # (batch_size, 2)
        
        return emotion_output


# Build the ViT regression model on DEVICE and report its parameter counts;
# on failure print troubleshooting hints and re-raise.
print("\n".join((
    "=" * 60,
    "ü§ñ INITIALIZING VISION TRANSFORMER",
    "=" * 60,
)))

try:
    model = ViTForEmotionRegression(
        model_name=VIT_MODEL_NAME,
        freeze_backbone=FREEZE_BACKBONE,
        dropout=DROPOUT
    ).to(DEVICE)

    # Parameter statistics for the summary below
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("\n".join((
        "\n" + "=" * 60,
        "‚úÖ MODEL INITIALIZED SUCCESSFULLY",
        "=" * 60,
        f"Model: {VIT_MODEL_NAME}",
        f"Hidden size: {model.hidden_size}",
        f"Freeze backbone: {FREEZE_BACKBONE}",
        f"\nTotal parameters: {total_params:,}",
        f"Trainable parameters: {trainable_params:,}",
        f"Frozen parameters: {total_params - trainable_params:,}",
        "=" * 60,
    )))

except Exception as e:
    print("\n".join((
        "\n‚ùå ERROR: Failed to initialize model",
        f"Error details: {e!s}",
        "\nüí° TROUBLESHOOTING TIPS:",
        "  1. Check your internet connection",
        "  2. Try restarting the kernel and running again",
        "  3. Manually download the model from: https://huggingface.co/google/vit-base-patch16-224-in21k",
        "  4. Set local_files_only=True if model is already downloaded",
    )))
    raise

### 📊 Define Evaluation Metrics

Define Concordance Correlation Coefficient (CCC) for measuring prediction quality on valence and arousal.

## 1️⃣1️⃣ Visualize Training Results & Analysis

Visualize training curves, create scatter plots of predictions vs actual values, and analyze model performance.

In [None]:
# ========== FINAL RESULTS SUMMARY ==========
# Prints a report of whatever artefacts exist in the notebook namespace:
# model size, dataset counts, split, hyper-parameters, training history,
# test metrics, GPU memory and the files written to OUTPUT_DIR.
print("\n".join(("=" * 60, "üìä FINAL RESULTS SUMMARY", "=" * 60)))

# --- Model architecture ---
print("\nüñºÔ∏è Model Architecture:")
if 'best_vit_model' not in locals():
    print("  - Model not initialized")
else:
    total_params = sum(p.numel() for p in best_vit_model.parameters())
    trainable_params = sum(p.numel() for p in best_vit_model.parameters() if p.requires_grad)
    print("\n".join((
        f"  - Base Model: {VIT_MODEL_NAME}",
        f"  - Total Parameters: {total_params:,}",
        f"  - Trainable Parameters: {trainable_params:,}",
        f"  - Backbone: {'Frozen' if FREEZE_BACKBONE else 'Fine-tuned'}",
    )))

# --- GAN augmentation (uses the counts cached in DATASET_COUNTS) ---
print("\nüé® GAN Augmentation:")
if 'DATASET_COUNTS' not in locals():
    print("  - Dataset counts not available")
else:
    print("\n".join((
        f"  - Real samples: {DATASET_COUNTS['real_count']}",
        f"  - Synthetic samples: {DATASET_COUNTS['synthetic_count']}",
        f"  - Total samples: {DATASET_COUNTS['total_count']}",
        f"  - Augmentation factor: {DATASET_COUNTS['augmentation_factor']:.2f}x",
    )))

# --- Dataset split ---
print("\nüîÄ Dataset Split:")
if 'DATASET_COUNTS' in locals() and 'train_count' in DATASET_COUNTS:
    # Each subset as a percentage of the full augmented dataset
    train_pct = (DATASET_COUNTS['train_count'] / DATASET_COUNTS['total_count']) * 100
    val_pct = (DATASET_COUNTS['val_count'] / DATASET_COUNTS['total_count']) * 100
    test_pct = (DATASET_COUNTS['test_count'] / DATASET_COUNTS['total_count']) * 100

    print("\n".join((
        f"  - Train:      {DATASET_COUNTS['train_count']:5d} samples ({train_pct:.1f}%)",
        f"  - Validation: {DATASET_COUNTS['val_count']:5d} samples ({val_pct:.1f}%)",
        f"  - Test:       {DATASET_COUNTS['test_count']:5d} samples ({test_pct:.1f}%)",
    )))
else:
    print("  - Split information not available")

# --- Training configuration ---
print("\n".join((
    "\n‚öôÔ∏è  Training Configuration:",
    f"  - Batch Size: {BATCH_SIZE}",
    f"  - Epochs: {NUM_EPOCHS}",
    f"  - Learning Rate: {LEARNING_RATE}",
    f"  - Weight Decay: {WEIGHT_DECAY}",
    f"  - Image Size: {VIT_IMAGE_SIZE}x{VIT_IMAGE_SIZE}",
)))

# --- Training results ---
print("\nüìà Training Results:")
if 'training_history' in locals() and len(training_history['train_loss']) > 0:
    # Best CCC = highest per-epoch value of the better of the two dimensions
    best_val_ccc = max(
        max(v, a)
        for v, a in zip(training_history['val_ccc_valence'], training_history['val_ccc_arousal'])
    )
    best_val_mae = min(training_history['val_mae'])
    final_train_loss = training_history['train_loss'][-1]
    final_val_loss = training_history['val_loss'][-1]

    print("\n".join((
        f"  - Best Validation CCC: {best_val_ccc:.4f}",
        f"  - Best Validation MAE: {best_val_mae:.4f}",
        f"  - Final Train Loss: {final_train_loss:.4f}",
        f"  - Final Val Loss: {final_val_loss:.4f}",
        f"  - Total Epochs Trained: {len(training_history['train_loss'])}",
    )))
else:
    print("  - Training history not available")

# --- Test results (only when an evaluation cell produced them) ---
if 'test_results' in locals():
    print("\n".join((
        "\nüß™ Test Set Performance:",
        f"  - Test Loss: {test_results.get('test_loss', 'N/A')}",
        f"  - Test MAE: {test_results.get('test_mae', 'N/A')}",
        f"  - Test CCC (Valence): {test_results.get('test_ccc_valence', 'N/A')}",
        f"  - Test CCC (Arousal): {test_results.get('test_ccc_arousal', 'N/A')}",
    )))

# --- Memory usage ---
print("\nüíæ Memory Usage:")
if not torch.cuda.is_available():
    print("  - GPU: Not available")
    print("  - Using CPU")
else:
    current_memory = torch.cuda.memory_allocated() / 1024**3
    max_memory = torch.cuda.max_memory_allocated() / 1024**3
    print("\n".join((
        f"  - Current GPU Memory: {current_memory:.2f} GB",
        f"  - Peak GPU Memory: {max_memory:.2f} GB",
        f"  - Device: {torch.cuda.get_device_name(0)}",
    )))

# --- Output files ---
print("\nüìÅ Output Files:")
print(f"  - Output Directory: {OUTPUT_DIR}")
if not os.path.exists(OUTPUT_DIR):
    print("  - Output directory not found")
else:
    output_files = os.listdir(OUTPUT_DIR)
    print(f"  - Files Created: {len(output_files)}")

    # Artefacts the summary checks for by name
    key_files = [
        'best_vit_model.pth',
        'training_history.npy',
        'training_curves.png',
        'confusion_matrix.png'
    ]

    print("\n  Key Output Files:")
    for file in key_files:
        file_path = os.path.join(OUTPUT_DIR, file)
        if not os.path.exists(file_path):
            print(f"    ‚ùå {file} (not found)")
            continue
        file_size = os.path.getsize(file_path) / 1024**2  # MB
        print(f"    ‚úÖ {file} ({file_size:.2f} MB)")

# --- Completion banner ---
print("\n".join((
    "\n" + "=" * 60,
    "‚úÖ Notebook Execution Complete!",
    "=" * 60,
)))

# --- Suggested follow-up work ---
print("\n".join((
    "\nüí° Next Steps:",
    "  1. Review training curves and metrics",
    "  2. Analyze test set performance",
    "  3. Examine edge case predictions",
    "  4. Export model for deployment",
    "  5. Document findings and insights",
)))

print("\nüéâ Thank you for using the ViT + GAN Emotion Prediction Pipeline!")
print("=" * 60)

### 🏋️ Define Training and Validation Functions

Implement the training loop and validation function with CCC metric tracking.

## 9️⃣ Train ViT Model on Augmented Dataset

In [None]:
# ========== MEMORY-EFFICIENT DATASET PREPARATION ==========
# Merges the real and GAN-generated spectrograms (and their labels) into one
# array pair, records the sample counts in DATASET_COUNTS, then splits the
# result into train / validation / test subsets (64% / 16% / 20%).
import gc

print("üîÑ Preparing augmented dataset for ViT training...")
print(f"üíæ Memory before concatenation: {torch.cuda.memory_allocated()/1024**3:.2f} GB" if torch.cuda.is_available() else "")

# Clear any unused memory before concatenation
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

# Combine real and synthetic spectrograms
print(f"\nüì¶ Combining datasets:")
print(f"   Real spectrograms: {real_spectrograms.shape}")
print(f"   Synthetic spectrograms: {synthetic_spectrograms.shape}")
print(f"   Real labels: {real_labels_np.shape}")  # Use prepared numpy version
print(f"   Synthetic labels: {synthetic_labels.shape}")

# ========== STORE COUNTS BEFORE DELETION ==========
# These will be used in final summary since we'll delete the arrays
DATASET_COUNTS = {
    'real_count': len(real_spectrograms),
    'synthetic_count': len(synthetic_spectrograms),
    'real_label_count': len(real_labels_np),
    'synthetic_label_count': len(synthetic_labels)
}

try:
    all_spectrograms = np.concatenate([real_spectrograms, synthetic_spectrograms], axis=0)
    # Bind the merged labels to `all_emotion_labels`: that is the name the
    # train/test split and the cleanup below consume.
    all_emotion_labels = np.concatenate([real_labels_np, synthetic_labels], axis=0)
    
    print(f"\n‚úÖ Total augmented dataset:")
    print(f"   - Total samples: {len(all_spectrograms)}")
    print(f"   - Spectrograms shape: {all_spectrograms.shape}")
    print(f"   - Labels shape: {all_emotion_labels.shape}")
    print(f"   - Memory usage: ~{all_spectrograms.nbytes / 1024**3:.2f} GB")
    
except MemoryError as e:
    print(f"\n‚ùå MemoryError during concatenation: {e}")
    print(f"üîß Reducing synthetic samples to fit in memory...")
    
    # Reduce synthetic samples if OOM; spectrograms and labels are truncated
    # together so they stay aligned sample-for-sample
    max_synthetic = 2000
    synthetic_spectrograms = synthetic_spectrograms[:max_synthetic]
    synthetic_labels = synthetic_labels[:max_synthetic]
    
    print(f"   Reduced synthetic samples to: {len(synthetic_spectrograms)}")
    
    # Update counts with what was actually kept (may be fewer than max_synthetic)
    DATASET_COUNTS['synthetic_count'] = len(synthetic_spectrograms)
    DATASET_COUNTS['synthetic_label_count'] = len(synthetic_labels)
    
    # Try again, using the same (now truncated) arrays as the first attempt
    all_spectrograms = np.concatenate([real_spectrograms, synthetic_spectrograms], axis=0)
    all_emotion_labels = np.concatenate([real_labels_np, synthetic_labels], axis=0)
    
    print(f"‚úÖ Reduced dataset created: {len(all_spectrograms)} samples")

# Update counts with final total (shared by both paths above)
DATASET_COUNTS['total_count'] = len(all_spectrograms)
DATASET_COUNTS['augmentation_factor'] = len(all_spectrograms) / DATASET_COUNTS['real_count']

# The original success path exposed the merged labels as `all_labels`;
# keep that name bound in case another cell still reads it.
all_labels = all_emotion_labels

# Delete intermediate arrays to free memory
print(f"\nüßπ Freeing intermediate memory...")
del synthetic_spectrograms  # Delete synthetic spectrograms (we have all_spectrograms now)
del synthetic_labels  # Delete synthetic labels (we have all_emotion_labels now)

# Don't delete real_spectrograms yet - needed for evaluation
gc.collect()

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"üíæ GPU memory after cleanup: {torch.cuda.memory_allocated()/1024**3:.2f} GB")

print(f"‚úÖ Memory cleanup complete\n")


# ========== TRAIN/TEST/VALIDATION SPLIT ==========
print("üîÄ Splitting dataset into train/test/validation sets...")

# First split: 80% train+val, 20% test
from sklearn.model_selection import train_test_split

# Split into train+val (80%) and test (20%)
train_val_specs, test_specs, train_val_labels, test_labels = train_test_split(
    all_spectrograms, 
    all_emotion_labels, 
    test_size=0.2,  # 20% for test
    random_state=42,
    shuffle=True
)

# Second split: Split train+val into 80% train, 20% val (of the 80%)
# This gives us: 64% train, 16% val, 20% test
train_specs, val_specs, train_labels, val_labels = train_test_split(
    train_val_specs,
    train_val_labels,
    test_size=0.2,  # 20% of train+val = 16% of total
    random_state=42,
    shuffle=True
)

# Store split counts
DATASET_COUNTS['train_count'] = len(train_specs)
DATASET_COUNTS['val_count'] = len(val_specs)
DATASET_COUNTS['test_count'] = len(test_specs)

print(f"\n‚úÖ Dataset split complete:")
print(f"   üéì Train:      {len(train_specs):5d} samples ({len(train_specs)/len(all_spectrograms)*100:.1f}%)")
print(f"   üìä Validation: {len(val_specs):5d} samples ({len(val_specs)/len(all_spectrograms)*100:.1f}%)")
print(f"   üß™ Test:       {len(test_specs):5d} samples ({len(test_specs)/len(all_spectrograms)*100:.1f}%)")
print(f"   üì¶ Total:      {len(all_spectrograms):5d} samples")

# Free the full dataset now that we have splits
del all_spectrograms
del all_emotion_labels
gc.collect()


# ========== MEMORY-EFFICIENT DATASET CLASS ==========
class SpectrogramDataset(Dataset):
    """
    Memory-efficient dataset for mel-spectrograms with ViT preprocessing.

    The spectrograms and labels are stored exactly as passed in. Each sample is
    min-max scaled, resized, expanded to three channels and ImageNet-normalized
    on the fly in ``__getitem__`` instead of keeping preprocessed tensors.
    """

    def __init__(self, spectrograms, labels, image_size=VIT_IMAGE_SIZE):
        """
        Args:
            spectrograms: numpy array of shape (N, n_mels, time_steps)
            labels: numpy array of shape (N, 2)
            image_size: side length of the square ViT input
                (defaults to VIT_IMAGE_SIZE)

        Raises:
            ValueError: if the number of spectrograms and labels differ, or
                if ``labels`` is not two-dimensional with two columns.
        """
        # Validate with explicit exceptions: ``assert`` statements are removed
        # when Python runs with -O, which would silently disable these checks.
        if len(spectrograms) != len(labels):
            raise ValueError(
                f"Spectrogram count ({len(spectrograms)}) must match label count ({len(labels)})"
            )
        # Check the rank first so 1-D labels give a clear error, not IndexError.
        if len(labels.shape) != 2 or labels.shape[1] != 2:
            raise ValueError(f"Labels must have shape (N, 2), got {labels.shape}")

        # Keep the inputs as given; no preprocessed copies are stored.
        self.spectrograms = spectrograms
        self.labels = labels
        self.image_size = image_size

        print(f"  üìä Dataset created: {len(self.spectrograms)} samples")
        print(f"     Spectrograms: {self.spectrograms.shape}")
        print(f"     Labels: {self.labels.shape}")

        # Precompute normalization constants, shaped to broadcast over (3, H, W)
        self.imagenet_mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
        self.imagenet_std = torch.tensor(IMAGENET_STD).view(3, 1, 1)

    def __len__(self):
        """Return the number of samples."""
        return len(self.spectrograms)

    def __getitem__(self, idx):
        """
        Return the preprocessed spectrogram and its label for sample ``idx``.

        Returns:
            Tuple ``(image, label)``: ``image`` is a float tensor of shape
            (3, image_size, image_size); ``label`` is a float tensor built
            from ``labels[idx]``.
        """
        spec = self.spectrograms[idx]
        label = self.labels[idx]

        # Per-sample min-max scaling to [0, 1]; the epsilon guards against a
        # constant spectrogram (max == min).
        spec_min = spec.min()
        spec_max = spec.max()
        spec_norm = (spec - spec_min) / (spec_max - spec_min + 1e-8)

        # (H, W) -> (1, H, W); interpolate needs a batch dim, so add and drop it.
        spec_tensor = torch.FloatTensor(spec_norm).unsqueeze(0)
        spec_resized = F.interpolate(
            spec_tensor.unsqueeze(0),
            size=(self.image_size, self.image_size),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)

        # Replicate the single channel three times to mimic an RGB image.
        spec_rgb = spec_resized.repeat(3, 1, 1)

        # Apply ImageNet normalization
        spec_normalized = (spec_rgb - self.imagenet_mean) / self.imagenet_std

        return spec_normalized, torch.FloatTensor(label)


# ========== CREATE DATASETS AND DATALOADERS ==========
print("\nüîÑ Creating datasets and dataloaders...")

try:
    # One dataset per split
    print("\nüì¶ Creating datasets:")
    train_dataset = SpectrogramDataset(train_specs, train_labels)
    val_dataset = SpectrogramDataset(val_specs, val_labels)
    test_dataset = SpectrogramDataset(test_specs, test_labels)

    print("\n‚úÖ All datasets created successfully")

    # Settings shared by all three loaders: no worker processes (avoids the
    # multiprocessing memory overhead), pinned host memory, no persistent workers.
    _loader_options = dict(
        batch_size=BATCH_SIZE,
        num_workers=0,
        pin_memory=True,
        persistent_workers=False,
    )

    # Only the training loader shuffles its samples.
    train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
    test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

    print("\n‚úÖ Dataloaders created:")
    print(f"   üéì Train:      {len(train_loader):4d} batches")
    print(f"   üìä Validation: {len(val_loader):4d} batches")
    print(f"   üß™ Test:       {len(test_loader):4d} batches")
    print(f"   üì¶ Batch size: {BATCH_SIZE}")

    # Final memory cleanup
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        _final_gb = torch.cuda.memory_allocated() / 1024**3
        print(f"\nüíæ Final GPU memory: {_final_gb:.2f} GB")

    print("\n‚úÖ Dataset preparation complete! Ready for ViT training.")

except Exception as e:
    # Report a hint, then re-raise so the failure is not hidden.
    print(f"\n‚ùå Error during dataset creation: {e}")
    print("üí° Suggestion: Reduce NUM_SYNTHETIC or BATCH_SIZE in configuration")
    raise

### 🚀 Execute Training Loop

Run the complete training process for the specified number of epochs with validation after each epoch.

## 🔟 Visualize Results & Analysis

In [None]:
# Plot training curves
#
# Each panel lists, per curve, the history keys that may hold its data. The
# training loop in this notebook records the validation CCC under
# 'val_valence_ccc' / 'val_arousal_ccc' and records no MAE or train-CCC
# series, so indexing history['train_mae'] etc. directly raises KeyError.
# A curve is drawn from the first key that exists and skipped otherwise.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# (axis, y-label, title, draw zero reference line?, [(legend label, candidate keys), ...])
_curve_panels = [
    (axes[0, 0], 'MSE Loss', 'Training & Validation Loss', False,
     [('Train Loss', ('train_loss',)), ('Val Loss', ('val_loss',))]),
    (axes[0, 1], 'Mean Absolute Error', 'Training & Validation MAE', False,
     [('Train MAE', ('train_mae',)), ('Val MAE', ('val_mae',))]),
    (axes[1, 0], 'CCC', 'Valence CCC', True,
     [('Train CCC', ('train_ccc_v',)), ('Val CCC', ('val_ccc_v', 'val_valence_ccc'))]),
    (axes[1, 1], 'CCC', 'Arousal CCC', True,
     [('Train CCC', ('train_ccc_a',)), ('Val CCC', ('val_ccc_a', 'val_arousal_ccc'))]),
]

for _ax, _ylabel, _title, _zero_line, _curves in _curve_panels:
    _drawn = False
    for _label, _keys in _curves:
        _series = next((history[_k] for _k in _keys if _k in history), None)
        if _series is None:
            continue  # this metric was not recorded during training
        _ax.plot(_series, label=_label, linewidth=2)
        _drawn = True

    _ax.set_xlabel('Epoch')
    _ax.set_ylabel(_ylabel)
    _ax.set_title(_title)
    if _drawn:
        _ax.legend()
    else:
        # Say why the panel is empty rather than leaving a blank axis.
        _ax.text(0.5, 0.5, 'Not recorded during training',
                 ha='center', va='center', transform=_ax.transAxes)
    _ax.grid(True, alpha=0.3)
    if _zero_line:
        # CCC can be negative; mark zero as a visual reference.
        _ax.axhline(y=0, color='r', linestyle='--', alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'vit_training_curves.png'), dpi=150, bbox_inches='tight')
plt.show()

# Scatter plots: Predicted vs Actual (one panel per emotion dimension)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Column 0 of the label/prediction arrays is valence, column 1 is arousal.
for _col, _dim in enumerate(('Valence', 'Arousal')):
    _ax = axes[_col]
    _ax.scatter(val_labels[:, _col], val_preds[:, _col], alpha=0.5, s=20)
    # Identity line: points on it are perfect predictions.
    _ax.plot([-1, 1], [-1, 1], 'r--', linewidth=2, label='Perfect Prediction')
    _ax.set_xlabel(f'Actual {_dim}')
    _ax.set_ylabel(f'Predicted {_dim}')
    _ccc = val_ccc_v if _col == 0 else val_ccc_a
    _ax.set_title(f'{_dim} Prediction (CCC: {_ccc:.4f})')
    _ax.legend()
    _ax.grid(True, alpha=0.3)
    _ax.set_xlim(-1.2, 1.2)
    _ax.set_ylim(-1.2, 1.2)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'prediction_scatter.png'), dpi=150, bbox_inches='tight')
plt.show()

# 2D Valence-Arousal Space: ground truth (left) next to predictions (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

_va_panels = (
    (val_labels, 'blue', 'Ground Truth VA Space'),
    (val_preds, 'red', 'Predicted VA Space'),
)

for _ax, (_points, _color, _title) in zip(axes, _va_panels):
    # x = valence (column 0), y = arousal (column 1)
    _ax.scatter(_points[:, 0], _points[:, 1], alpha=0.6, s=50, c=_color, edgecolors='black')
    _ax.set_xlabel('Valence')
    _ax.set_ylabel('Arousal')
    _ax.set_title(_title)
    _ax.grid(True, alpha=0.3)
    # Thin axes through the origin split the plane into the four quadrants.
    _ax.axhline(0, color='k', linewidth=0.5)
    _ax.axvline(0, color='k', linewidth=0.5)
    _ax.set_xlim(-1.2, 1.2)
    _ax.set_ylim(-1.2, 1.2)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'va_space_comparison.png'), dpi=150, bbox_inches='tight')
plt.show()

# Final summary
#
# Fixes relative to the previous version of this cell:
#   * "\\n" printed a literal backslash-n; real newlines are used now.
#   * synthetic_spectrograms and all_spectrograms are deleted by the data
#     preparation cell, so their sizes are derived from the surviving splits.
#   * the checkpoint name is taken from best_model_path instead of a
#     hard-coded file name that did not match it.
# NOTE(review): total_params, trainable_params, val_mae, val_ccc_v and
# val_ccc_a are not defined in the visible cells - confirm they exist.
print("\n" + "="*60)
print("üìä FINAL RESULTS SUMMARY")
print("="*60)
print(f"\nüñºÔ∏è Model Architecture:")
print(f"  - Base Model: {VIT_MODEL_NAME}")
print(f"  - Total Parameters: {total_params:,}")
print(f"  - Trainable Parameters: {trainable_params:,}")

# train/val/test partition the full augmented dataset, so their sizes add up
# to the former len(all_spectrograms).
_n_real = len(real_spectrograms)
_n_total = len(train_specs) + len(val_specs) + len(test_specs)
# Assumes the full dataset was exactly real + synthetic samples - TODO confirm.
_n_synthetic = _n_total - _n_real
_augmentation = _n_total / _n_real if _n_real else float('nan')

print(f"\nüé® GAN Augmentation:")
print(f"  - Real samples: {_n_real}")
print(f"  - Synthetic samples: {_n_synthetic}")
print(f"  - Total samples: {_n_total}")
print(f"  - Augmentation factor: {_augmentation:.2f}x")

print(f"\nü§ñ ViT Model Performance:")
print(f"  - Best Val Loss: {best_val_loss:.4f}")
print(f"  - Final Val MAE: {val_mae:.4f}")
print(f"  - Final Val CCC Valence: {val_ccc_v:.4f}")
print(f"  - Final Val CCC Arousal: {val_ccc_a:.4f}")

print(f"\nüíæ Saved Outputs:")
print(f"  - Generator model: generator.pth")
print(f"  - Discriminator model: discriminator.pth")
print(f"  - Best ViT model: {best_model_path}")
print(f"  - Training curves: vit_training_curves.png")
print(f"  - Prediction scatter: prediction_scatter.png")
print(f"  - VA space comparison: va_space_comparison.png")
print("="*60)

### 📈 Plot Training and Validation Curves

Visualize loss, MAE, and CCC metrics over training epochs to assess model convergence.

In [None]:
# Training configuration
NUM_EPOCHS = 24
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01
PATIENCE = 5  # epochs without validation improvement before early stopping

# Model, loss and optimizer
print("üöÄ Initializing ViT training...")
vit_model = EmotionViT().to(device)
criterion = WeightedEmotionLoss(valence_weight=0.6, arousal_weight=0.4)
optimizer = torch.optim.AdamW(
    vit_model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# Halve the learning rate once the validation loss stops improving for
# more than 2 epochs.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases - confirm
# against the installed version.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
    verbose=True,
)

# Per-epoch metric history: one empty list per tracked series
history = {
    _metric: []
    for _metric in (
        'train_loss', 'train_valence_loss', 'train_arousal_loss',
        'val_loss', 'val_valence_loss', 'val_arousal_loss',
        'val_valence_ccc', 'val_arousal_ccc',
    )
}

# Early-stopping state and checkpoint location
best_val_loss = float('inf')
patience_counter = 0
best_model_path = 'best_vit_emotion_model.pth'

print(f"üìä Training for {NUM_EPOCHS} epochs...")
# NOTE(review): reads DATASET_COUNTS['train'/'val'/'test'] - confirm these keys
# are populated by the data-split cell.
print(f"üìà Dataset sizes: Train={DATASET_COUNTS['train']}, Val={DATASET_COUNTS['val']}, Test={DATASET_COUNTS['test']}")
print("-" * 80)

for epoch in range(NUM_EPOCHS):
    # One optimization pass over the training set
    train_loss, train_val_loss, train_aro_loss = train_epoch(
        vit_model, train_loader, criterion, optimizer, device, epoch
    )

    # Evaluate on the validation set
    val_loss, val_val_loss, val_aro_loss, val_valence_ccc, val_arousal_ccc = evaluate_epoch(
        vit_model, val_loader, criterion, device, epoch, phase='Val'
    )

    # The plateau scheduler is driven by the validation loss
    scheduler.step(val_loss)

    # Record this epoch's metrics
    for _metric, _value in (
        ('train_loss', train_loss),
        ('train_valence_loss', train_val_loss),
        ('train_arousal_loss', train_aro_loss),
        ('val_loss', val_loss),
        ('val_valence_loss', val_val_loss),
        ('val_arousal_loss', val_aro_loss),
        ('val_valence_ccc', val_valence_ccc),
        ('val_arousal_ccc', val_arousal_ccc),
    ):
        history[_metric].append(_value)

    # Epoch summary
    _rule = '=' * 80
    print(f"\n{_rule}")
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} Summary:")
    print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"  Valence CCC: {val_valence_ccc:.4f} | Arousal CCC: {val_arousal_ccc:.4f}")
    print(f"{_rule}\n")

    # No improvement: spend patience and possibly stop early
    if not (val_loss < best_val_loss):
        patience_counter += 1
        print(f"‚è≥ Patience: {patience_counter}/{PATIENCE}")

        if patience_counter >= PATIENCE:
            print(f"üõë Early stopping triggered after {epoch+1} epochs")
            break
        continue

    # Improvement: reset patience and checkpoint the model
    best_val_loss = val_loss
    patience_counter = 0
    _checkpoint = {
        'epoch': epoch,
        'model_state_dict': vit_model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
        'history': history
    }
    torch.save(_checkpoint, best_model_path)
    print(f"‚úÖ Saved best model (val_loss={val_loss:.4f})")

print("\nüéâ Training completed!")
print(f"‚úÖ Best validation loss: {best_val_loss:.4f}")

In [None]:
# Load best model and evaluate on test set
print("üß™ Evaluating on test set...")

# map_location keeps the load working when the checkpoint was written on a
# different device (e.g. saved on GPU, reloaded on CPU).
# weights_only=False is required because the checkpoint also stores the
# training `history` dict, not just tensors; newer PyTorch versions default to
# weights_only=True and can refuse to unpickle such objects. The file was
# written by this notebook, so it is trusted input.
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
vit_model.load_state_dict(checkpoint['model_state_dict'])

# `epoch=0` is only passed through to evaluate_epoch along with phase='Test'.
test_loss, test_val_loss, test_aro_loss, test_valence_ccc, test_arousal_ccc = evaluate_epoch(
    vit_model, test_loader, criterion, device, epoch=0, phase='Test'
)

print(f"\n{'='*80}")
print("üìä FINAL TEST SET RESULTS")
print(f"{'='*80}")
print(f"  Test Loss: {test_loss:.4f}")
print(f"  Valence Loss: {test_val_loss:.4f} | CCC: {test_valence_ccc:.4f}")
print(f"  Arousal Loss: {test_aro_loss:.4f} | CCC: {test_arousal_ccc:.4f}")
print(f"{'='*80}\n")

In [None]:
class DistillationLoss(nn.Module):
    """Weighted sum of two MSE terms: student-vs-teacher and noise prediction."""

    def __init__(self, alpha=0.5, temperature=3.0):
        """
        Args:
            alpha: weight of the direct prediction term; the denoising term
                gets ``1 - alpha``.
            temperature: stored on the module; not used by ``forward``.
        """
        super().__init__()
        self.mse = nn.MSELoss()
        self.alpha = alpha
        self.temperature = temperature

    def forward(self, student_pred, teacher_pred, noisy_teacher, student_noise_pred, noise_target):
        """
        Args:
            student_pred: student's direct prediction
            teacher_pred: teacher's clean prediction
            noisy_teacher: teacher prediction with added noise (accepted but
                not used in the computation)
            student_noise_pred: student's estimate of the added noise
            noise_target: the noise that was actually added

        Returns:
            Tuple ``(total_loss, pred_loss, denoise_loss)``.
        """
        # How far the student is from the teacher
        direct_term = self.mse(student_pred, teacher_pred)

        # How well the student recovers the injected noise
        noise_term = self.mse(student_noise_pred, noise_target)

        weight = self.alpha
        combined = weight * direct_term + (1 - weight) * noise_term

        return combined, direct_term, noise_term


# 60% direct prediction term, 40% denoising term
distillation_criterion = DistillationLoss(alpha=0.6)
print("‚úÖ Distillation loss initialized (60% MSE, 40% denoising)")

In [None]:
# Plot distillation curves: total loss (left) and its two components (right)
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Total loss
axes[0].plot(distill_history['train_loss'], marker='o', color='purple')

# Component losses
for _key, _label, _marker in (
    ('train_pred_loss', 'Prediction Loss', 'o'),
    ('train_denoise_loss', 'Denoising Loss', 's'),
):
    axes[1].plot(distill_history[_key], label=_label, marker=_marker)

# Shared axis decoration; only the component panel carries a legend.
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Distillation Loss')
axes[0].grid(True, alpha=0.3)

axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].set_title('Loss Components')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('distillation_curves.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä Distillation curves saved to 'distillation_curves.png'")

In [None]:
# Visualize predictions comparison: rows = model, columns = emotion dimension
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# (model name, predictions, (valence CCC, arousal CCC), per-column marker color)
# A color of None means "use matplotlib's default color".
_model_rows = (
    ('Teacher', teacher_predictions, (teacher_val_ccc, teacher_aro_ccc), (None, 'orange')),
    ('Student', student_predictions, (student_val_ccc, student_aro_ccc), ('green', 'purple')),
)

for _row, (_model, _preds, _cccs, _colors) in enumerate(_model_rows):
    for _col, _dim in enumerate(('Valence', 'Arousal')):
        _ax = axes[_row, _col]
        _truth = true_labels[:, _col]

        _marker_style = {'alpha': 0.5, 's': 20}
        if _colors[_col] is not None:
            _marker_style['color'] = _colors[_col]
        _ax.scatter(_truth, _preds[:, _col], **_marker_style)

        # Identity line spanning the observed range of true values
        _span = [_truth.min(), _truth.max()]
        _ax.plot(_span, _span, 'r--', lw=2, label='Perfect prediction')

        _ax.set_xlabel(f'True {_dim}')
        _ax.set_ylabel(f'Predicted {_dim}')
        _ax.set_title(f'{_model} Model - {_dim} (CCC={_cccs[_col]:.4f})')
        _ax.legend()
        _ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('teacher_vs_student_predictions.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä Comparison plots saved to 'teacher_vs_student_predictions.png'")

In [None]:
# Visualize predictions on original songs (grouped bars per song)
if os.path.exists(deam_audio_dir) and len(ground_truth) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    x = np.arange(len(ground_truth))
    width = 0.25

    # Three bars per song: ground truth left of the tick, teacher on it,
    # student right of it.
    _bar_groups = (
        (-width, ground_truth, 'Ground Truth'),
        (0, teacher_results, 'Teacher'),
        (width, student_results, 'Student'),
    )

    # Column 0 holds valence, column 1 holds arousal.
    for _col, _dim in enumerate(('Valence', 'Arousal')):
        _ax = axes[_col]
        for _offset, _values, _label in _bar_groups:
            _ax.bar(x + _offset, _values[:, _col], width, label=_label, alpha=0.8)
        _ax.set_xlabel('Song Sample')
        _ax.set_ylabel(_dim)
        _ax.set_title(f'{_dim} Predictions on Original Songs')
        _ax.legend()
        _ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('original_songs_predictions.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("üìä Original songs visualization saved to 'original_songs_predictions.png'")

In [None]:
# Complete pipeline summary: dataset sizes, teacher/student metrics, artifacts
_banner = "=" * 100
print(_banner)
print(" " * 35 + "üéâ COMPLETE PIPELINE SUMMARY")
print(_banner)

print("\nüìä DATASET STATISTICS:")
print(f"  ‚Ä¢ Real DEAM samples: {DATASET_COUNTS['real']}")
print(f"  ‚Ä¢ GAN-generated samples: {DATASET_COUNTS['synthetic']}")
print(f"  ‚Ä¢ Total samples: {DATASET_COUNTS['total']}")
print(f"  ‚Ä¢ Train set: {DATASET_COUNTS['train']} (64%)")
print(f"  ‚Ä¢ Validation set: {DATASET_COUNTS['val']} (16%)")
print(f"  ‚Ä¢ Test set: {DATASET_COUNTS['test']} (20%)")

# Count the teacher's parameters once and reuse the number below.
_teacher_params = count_parameters(vit_model)

print("\nüéì TEACHER MODEL (Full Vision Transformer):")
print("  ‚Ä¢ Architecture: google/vit-base-patch16-224-in21k")
print(f"  ‚Ä¢ Parameters: {_teacher_params:,}")
print(f"  ‚Ä¢ Test Valence CCC: {test_valence_ccc:.4f}")
print(f"  ‚Ä¢ Test Arousal CCC: {test_arousal_ccc:.4f}")
print(f"  ‚Ä¢ Model saved: {best_model_path}")

# Same for the student, plus the teacher/student size ratio.
_student_params = count_parameters(student_model)
_compression = _teacher_params / _student_params

print("\nüî¨ STUDENT MODEL (Diffusion-Compressed):")
print("  ‚Ä¢ Architecture: Lightweight ViT (6 layers, 384 hidden)")
print(f"  ‚Ä¢ Parameters: {_student_params:,}")
print(f"  ‚Ä¢ Compression: {_compression:.1f}x smaller")
print(f"  ‚Ä¢ Test Valence CCC: {student_val_ccc:.4f} ({val_ccc_drop:.2f}% drop)")
print(f"  ‚Ä¢ Test Arousal CCC: {student_aro_ccc:.4f} ({aro_ccc_drop:.2f}% drop)")
print(f"  ‚Ä¢ Model saved: {student_model_path}")

# Retained performance is reported against the worse of the two CCC drops;
# the memory figure assumes 4 bytes per parameter.
_worst_drop = max(val_ccc_drop, aro_ccc_drop)
_saved_mb = (_teacher_params - _student_params) * 4 / 1e6

print("\nüí° KEY INSIGHTS:")
print(f"  ‚Ä¢ GAN augmentation provided {DATASET_COUNTS['synthetic']/DATASET_COUNTS['real']:.1f}x more training data")
print(f"  ‚Ä¢ Diffusion-based distillation achieved {_compression:.1f}x compression")
print(f"  ‚Ä¢ Student model maintains {100 - _worst_drop:.1f}% of teacher performance")
print(f"  ‚Ä¢ Memory savings: ~{_saved_mb:.1f} MB")

print("\nüìà TRAINING DETAILS:")
print(f"  ‚Ä¢ ViT Epochs: {len(history['train_loss'])} / {NUM_EPOCHS}")
print(f"  ‚Ä¢ Distillation Epochs: {DISTILL_EPOCHS}")
print(f"  ‚Ä¢ Best Validation Loss: {best_val_loss:.4f}")
print("  ‚Ä¢ Loss Function: Weighted (60% valence, 40% arousal)")

print("\n‚úÖ DELIVERABLES:")
print(f"  ‚Ä¢ Trained Teacher Model: {best_model_path}")
print(f"  ‚Ä¢ Compressed Student Model: {student_model_path}")
print("  ‚Ä¢ Training Curves: vit_training_curves.png")
print("  ‚Ä¢ Distillation Curves: distillation_curves.png")
print("  ‚Ä¢ Prediction Comparisons: teacher_vs_student_predictions.png")
# The original-songs plot only exists when the raw audio directory is present.
if os.path.exists(deam_audio_dir):
    print("  ‚Ä¢ Original Songs Test: original_songs_predictions.png")

print("\n" + _banner)
print(" " * 30 + "üöÄ Pipeline Execution Complete!")
print(_banner)

## 🎯 Final Summary

Complete pipeline summary with all metrics and model information.

In [None]:
# Find original DEAM audio files and compare teacher/student predictions on a
# random sample of them against the annotated valence/arousal.
deam_audio_dir = '/kaggle/input/deam-dataset/DEAM_audio/MEMD_audio'
import os

if os.path.exists(deam_audio_dir):
    # Sorted because os.listdir returns entries in arbitrary order; without a
    # stable order the seeded sample below is not reproducible across machines.
    audio_files = sorted(f for f in os.listdir(deam_audio_dir) if f.endswith('.mp3'))
    print(f"üìÅ Found {len(audio_files)} DEAM audio files")

    # Test on random sample of (at most) 20 songs
    import random
    random.seed(42)
    sample_files = random.sample(audio_files, min(20, len(audio_files)))

    print(f"üéµ Testing on {len(sample_files)} random songs...")

    teacher_results = []
    student_results = []
    ground_truth = []

    vit_model.eval()
    student_model.eval()

    for audio_file in tqdm(sample_files, desc='Processing songs'):
        audio_path = os.path.join(deam_audio_dir, audio_file)

        # The song ID is the numeric file stem; skip files named otherwise
        # instead of aborting the whole evaluation.
        try:
            song_id = int(audio_file.split('.')[0])
        except ValueError:
            continue

        # Look the annotation up once; skip songs without one.
        song_rows = deam_annotations_df[deam_annotations_df['song_id'] == song_id]
        if song_rows.empty:
            continue
        true_valence = song_rows['valence_mean'].values[0]
        true_arousal = song_rows['arousal_mean'].values[0]

        # audio_to_spectrogram returns None when the file cannot be processed.
        spec = audio_to_spectrogram(audio_path)
        if spec is None:
            continue
        spec = spec.unsqueeze(0).to(device)  # add batch dimension

        # Get predictions from both models
        with torch.no_grad():
            teacher_pred = vit_model(spec).cpu().numpy()[0]
            student_pred = student_model(spec).cpu().numpy()[0]

        teacher_results.append(teacher_pred)
        student_results.append(student_pred)
        ground_truth.append([true_valence, true_arousal])

    # Convert to arrays
    teacher_results = np.array(teacher_results)
    student_results = np.array(student_results)
    ground_truth = np.array(ground_truth)

    if len(ground_truth) == 0:
        # With no evaluated song the arrays are 1-D and empty, so the column
        # indexing below would raise IndexError.
        print("No sampled song could be evaluated (missing annotation or unreadable audio); skipping metrics.")
    else:
        # Calculate metrics on original songs (column 0 = valence, 1 = arousal)
        orig_teacher_val_ccc = concordance_correlation_coefficient(ground_truth[:, 0], teacher_results[:, 0])
        orig_teacher_aro_ccc = concordance_correlation_coefficient(ground_truth[:, 1], teacher_results[:, 1])
        orig_student_val_ccc = concordance_correlation_coefficient(ground_truth[:, 0], student_results[:, 0])
        orig_student_aro_ccc = concordance_correlation_coefficient(ground_truth[:, 1], student_results[:, 1])

        print(f"\n{'='*80}")
        print("üéµ RESULTS ON ORIGINAL DEAM SONGS")
        print(f"{'='*80}")
        print(f"Teacher Model:")
        print(f"  Valence CCC: {orig_teacher_val_ccc:.4f} | Arousal CCC: {orig_teacher_aro_ccc:.4f}")
        print(f"Student Model:")
        print(f"  Valence CCC: {orig_student_val_ccc:.4f} | Arousal CCC: {orig_student_aro_ccc:.4f}")
        print(f"{'='*80}\n")

else:
    print(f"‚ùå DEAM audio directory not found at {deam_audio_dir}")
    print("üí° This section requires the original DEAM audio files")
    print("üí° Skip if only spectrograms are available")

In [None]:
def audio_to_spectrogram(audio_path, target_length=431):
    """Convert an audio file to a ViT-ready mel-spectrogram tensor.

    The spectrogram goes through the same steps as
    ``SpectrogramDataset.__getitem__`` (min-max scaling to [0, 1], bilinear
    resize, 3-channel replication, ImageNet mean/std normalization), so the
    models see inputs on the scale they were trained on. The previous version
    z-scored the spectrogram and skipped the ImageNet normalization.

    Args:
        audio_path: path of the audio file to load.
        target_length: number of time frames the spectrogram is padded or
            truncated to before resizing.

    Returns:
        Float tensor of shape (3, 224, 224), or ``None`` if any step fails
        (the error is printed).
    """
    try:
        # Load at most the first 45 seconds, resampled to 22.05 kHz
        y, sr = librosa.load(audio_path, sr=22050, duration=45.0)

        # Mel-spectrogram in dB relative to its own maximum
        mel_spec = librosa.feature.melspectrogram(
            y=y, sr=sr, n_mels=128, n_fft=2048, hop_length=512
        )
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

        # Pad or truncate the time axis to target_length frames.
        # NOTE(review): constant padding uses 0, which in dB-relative-to-max
        # space is presumably the loudest value, not silence - confirm this
        # matches how the training spectrograms were padded.
        n_frames = mel_spec_db.shape[1]
        if n_frames < target_length:
            pad_width = target_length - n_frames
            mel_spec_db = np.pad(mel_spec_db, ((0, 0), (0, pad_width)), mode='constant')
        else:
            mel_spec_db = mel_spec_db[:, :target_length]

        # Min-max scale to [0, 1]; epsilon guards a constant spectrogram
        spec_min = mel_spec_db.min()
        spec_max = mel_spec_db.max()
        spec_norm = (mel_spec_db - spec_min) / (spec_max - spec_min + 1e-8)

        # (H, W) -> (1, 1, H, W) for interpolate, then back to (1, 224, 224)
        spec_tensor = torch.from_numpy(spec_norm).float().unsqueeze(0).unsqueeze(0)
        spec_resized = torch.nn.functional.interpolate(
            spec_tensor, size=(224, 224), mode='bilinear', align_corners=False
        ).squeeze(0)

        # Replicate to 3 channels and apply ImageNet normalization
        spec_rgb = spec_resized.repeat(3, 1, 1)
        imagenet_mean = torch.tensor(IMAGENET_MEAN).view(3, 1, 1)
        imagenet_std = torch.tensor(IMAGENET_STD).view(3, 1, 1)

        return (spec_rgb - imagenet_mean) / imagenet_std

    except Exception as e:
        # Best-effort: callers treat None as "skip this file".
        print(f"Error processing {audio_path}: {e}")
        return None

print("‚úÖ Audio processing function ready")

## 🎵 Testing on Original DEAM Songs

Let's test both models on original DEAM audio files (not pre-computed spectrograms) to validate real-world performance.

In [None]:
# Compare teacher vs student on test set
print("üî¨ Comparing Teacher vs Student models on test set...")

vit_model.eval()
student_model.eval()

teacher_predictions, student_predictions, true_labels = [], [], []

# Run both models over every test batch without tracking gradients.
with torch.no_grad():
    for spectrograms, labels in tqdm(test_loader, desc='Evaluating'):
        spectrograms = spectrograms.to(device)

        teacher_pred = vit_model(spectrograms)
        teacher_predictions.append(teacher_pred.cpu().numpy())

        student_pred = student_model(spectrograms)
        student_predictions.append(student_pred.cpu().numpy())

        true_labels.append(labels.numpy())

# Stack the per-batch arrays along the sample axis
teacher_predictions = np.concatenate(teacher_predictions, axis=0)
student_predictions = np.concatenate(student_predictions, axis=0)
true_labels = np.concatenate(true_labels, axis=0)


def _mean_abs_error(truth, pred):
    """Mean absolute difference between two equally shaped arrays."""
    return np.mean(np.abs(truth - pred))


# Column 0 = valence, column 1 = arousal
_true_val, _true_aro = true_labels[:, 0], true_labels[:, 1]

teacher_val_ccc = concordance_correlation_coefficient(_true_val, teacher_predictions[:, 0])
teacher_aro_ccc = concordance_correlation_coefficient(_true_aro, teacher_predictions[:, 1])
teacher_val_mae = _mean_abs_error(_true_val, teacher_predictions[:, 0])
teacher_aro_mae = _mean_abs_error(_true_aro, teacher_predictions[:, 1])

student_val_ccc = concordance_correlation_coefficient(_true_val, student_predictions[:, 0])
student_aro_ccc = concordance_correlation_coefficient(_true_aro, student_predictions[:, 1])
student_val_mae = _mean_abs_error(_true_val, student_predictions[:, 0])
student_aro_mae = _mean_abs_error(_true_aro, student_predictions[:, 1])

_rule = '=' * 80

print(f"\n{_rule}")
print("üìä TEACHER MODEL (Full ViT - 86M params)")
print(_rule)
print(f"  Valence - CCC: {teacher_val_ccc:.4f} | MAE: {teacher_val_mae:.4f}")
print(f"  Arousal - CCC: {teacher_aro_ccc:.4f} | MAE: {teacher_aro_mae:.4f}")

print(f"\n{_rule}")
print("üìä STUDENT MODEL (Lightweight - 20M params)")
print(_rule)
print(f"  Valence - CCC: {student_val_ccc:.4f} | MAE: {student_val_mae:.4f}")
print(f"  Arousal - CCC: {student_aro_ccc:.4f} | MAE: {student_aro_mae:.4f}")

print(f"\n{_rule}")
print("üìâ PERFORMANCE DROP")
print(_rule)
# Relative CCC loss of the student, as a percentage of the teacher's CCC
val_ccc_drop = ((teacher_val_ccc - student_val_ccc) / teacher_val_ccc) * 100
aro_ccc_drop = ((teacher_aro_ccc - student_aro_ccc) / teacher_aro_ccc) * 100
print(f"  Valence CCC: {val_ccc_drop:.2f}% drop")
print(f"  Arousal CCC: {aro_ccc_drop:.2f}% drop")
print(f"  Size Reduction: {count_parameters(vit_model) / count_parameters(student_model):.1f}x smaller")
print(f"{_rule}\n")

In [None]:
# Distillation training configuration
DISTILL_EPOCHS = 15
DISTILL_LR = 1e-4

# The teacher only provides targets: disable its gradients and put it in eval mode
for param in vit_model.parameters():
    param.requires_grad = False
vit_model.eval()

# Student optimizer with a cosine learning-rate schedule over all epochs
student_optimizer = torch.optim.AdamW(
    student_model.parameters(),
    lr=DISTILL_LR,
    weight_decay=0.01,
)
student_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    student_optimizer,
    T_max=DISTILL_EPOCHS,
)

# Per-epoch loss history: total loss and its two components
distill_history = {
    _metric: []
    for _metric in ('train_loss', 'train_pred_loss', 'train_denoise_loss')
}

print("üéì Starting diffusion-based distillation...")
print(f"üìä Training for {DISTILL_EPOCHS} epochs")
print(f"üéØ Teacher: {count_parameters(vit_model):,} params (frozen)")
print(f"üéØ Student: {count_parameters(student_model):,} params (training)")
print("-" * 80)

for epoch in range(DISTILL_EPOCHS):
    # One distillation pass over the training set
    loss, pred_loss, denoise_loss = train_distillation_epoch(
        student_model, vit_model, train_loader,
        distillation_criterion, student_optimizer, diffusion, device, epoch
    )

    # The cosine schedule advances once per epoch
    student_scheduler.step()

    # Record this epoch's losses
    for _metric, _value in (
        ('train_loss', loss),
        ('train_pred_loss', pred_loss),
        ('train_denoise_loss', denoise_loss),
    ):
        distill_history[_metric].append(_value)

    print(f"\nEpoch {epoch+1}/{DISTILL_EPOCHS}:")
    print(f"  Loss: {loss:.4f} | Pred: {pred_loss:.4f} | Denoise: {denoise_loss:.4f}")

# Persist the student weights together with the loss history
student_model_path = 'lightweight_vit_student.pth'
_student_checkpoint = {
    'model_state_dict': student_model.state_dict(),
    'distill_history': distill_history
}
torch.save(_student_checkpoint, student_model_path)

print(f"\n‚úÖ Distillation completed! Student model saved to '{student_model_path}'")

In [None]:
def train_distillation_epoch(student, teacher, dataloader, criterion, optimizer, diffusion, device, epoch):
    """Train the student for one epoch with diffusion-based distillation.

    For every batch the frozen teacher produces predictions, the diffusion
    scheduler corrupts them with noise at a random timestep, and the student is
    optimised with ``criterion`` on its own prediction and on a noise estimate
    derived from it.

    Args:
        student: Model being trained; called as ``student(spectrograms)``.
        teacher: Frozen reference model; called under ``torch.no_grad()``.
        dataloader: Iterable of ``(spectrograms, labels)`` batches; the labels
            are ignored here.
        criterion: Callable taking ``(student_pred, teacher_pred, noisy_teacher,
            student_noise_pred, noise_target)`` and returning three scalar
            tensors ``(loss, pred_loss, denoise_loss)``.
        optimizer: Optimizer over the student's parameters.
        diffusion: Object providing ``get_timestep(batch_size, device)`` and
            ``add_noise(predictions, t, device)``.
        device: Device the input batch is moved to.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        Tuple ``(epoch_loss, epoch_pred_loss, epoch_denoise_loss)`` of Python
        floats, each averaged over ``len(dataloader)`` batches.

    Raises:
        ZeroDivisionError: If ``dataloader`` has length 0.
    """
    student.train()
    teacher.eval()  # Teacher is frozen
    
    running_loss = 0.0
    running_pred_loss = 0.0
    running_denoise_loss = 0.0
    
    progress_bar = tqdm(dataloader, desc=f'Distillation Epoch {epoch+1}')
    
    for batch_idx, (spectrograms, _) in enumerate(progress_bar):
        spectrograms = spectrograms.to(device)
        
        # Get teacher predictions (frozen, no gradients)
        with torch.no_grad():
            teacher_pred = teacher(spectrograms)
        
        # Sample one diffusion timestep per sample in the batch
        t = diffusion.get_timestep(spectrograms.shape[0], device)
        
        # Add noise to teacher predictions
        noisy_teacher, noise_target = diffusion.add_noise(teacher_pred, t, device)
        
        # Student makes two predictions:
        # 1. Direct prediction from input
        student_pred = student(spectrograms)
        
        # 2. Denoising prediction (predicting the noise in noisy teacher output)
        # Here we use a simple approach: predict noise as difference from noisy input
        # NOTE(review): this is (student_pred - noisy_teacher), which has the
        # opposite sign of, and a different scale than, the noise that add_noise
        # mixed in - confirm the criterion expects exactly this quantity.
        student_noise_pred = student_pred - noisy_teacher.detach()
        
        # Calculate distillation loss
        optimizer.zero_grad()
        loss, pred_loss, denoise_loss = criterion(
            student_pred, teacher_pred.detach(), 
            noisy_teacher, student_noise_pred, noise_target
        )
        
        # Backward pass; clip the gradient norm to 1.0 before stepping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
        optimizer.step()
        
        # Update metrics
        running_loss += loss.item()
        running_pred_loss += pred_loss.item()
        running_denoise_loss += denoise_loss.item()
        
        # Update progress bar every 10 batches with the running means
        if (batch_idx + 1) % 10 == 0:
            progress_bar.set_postfix({
                'loss': f'{running_loss/(batch_idx+1):.4f}',
                'pred': f'{running_pred_loss/(batch_idx+1):.4f}',
                'denoise': f'{running_denoise_loss/(batch_idx+1):.4f}'
            })
    
    epoch_loss = running_loss / len(dataloader)
    epoch_pred_loss = running_pred_loss / len(dataloader)
    epoch_denoise_loss = running_denoise_loss / len(dataloader)
    
    return epoch_loss, epoch_pred_loss, epoch_denoise_loss

In [None]:
class DiffusionDistillation:
    """Forward-diffusion noise schedule used for diffusion-based distillation.

    Builds a linear beta schedule and the cumulative product of ``1 - beta``,
    which ``add_noise`` uses to blend a clean tensor with Gaussian noise.

    Args:
        num_timesteps: Number of discrete diffusion steps in the schedule.
        beta_start: Beta at the first timestep.
        beta_end: Beta at the last timestep.
    """
    def __init__(self, num_timesteps=100, beta_start=0.0001, beta_end=0.02):
        self.num_timesteps = num_timesteps
        
        # Linear beta schedule (created on the CPU; moved on demand in add_noise)
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        
    def add_noise(self, predictions, t, device):
        """Add noise to teacher predictions based on timestep t.

        Args:
            predictions: Clean tensor whose first dimension is the batch.
            t: Integer timestep indices into the schedule, one per sample
                (a tensor or anything ``torch.as_tensor`` accepts).
            device: Device on which the schedule lookup is performed.

        Returns:
            Tuple ``(noisy_predictions, noise)`` where ``noise`` is the
            standard-normal sample that was mixed in and
            ``noisy = sqrt(a_t) * predictions + sqrt(1 - a_t) * noise``.
        """
        # The schedule lives on the CPU while t is usually created on `device`.
        # Indexing a CPU tensor with CUDA indices raises, so bring both to the
        # same device before the lookup.
        alphas_cumprod = self.alphas_cumprod.to(device)
        alpha_t = alphas_cumprod[torch.as_tensor(t, device=device)]
        
        # Reshape to (batch, 1, ..., 1) so the per-sample factor broadcasts over
        # every trailing dimension of predictions, whatever its rank
        alpha_t = alpha_t.view(-1, *([1] * (predictions.dim() - 1)))
        
        # Sample noise
        noise = torch.randn_like(predictions)
        
        # Add noise: x_t = sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t) * noise
        noisy_predictions = torch.sqrt(alpha_t) * predictions + torch.sqrt(1 - alpha_t) * noise
        
        return noisy_predictions, noise
    
    def get_timestep(self, batch_size, device):
        """Sample random timesteps for batch: integers in [0, num_timesteps), shape (batch_size,)."""
        return torch.randint(0, self.num_timesteps, (batch_size,), device=device)

diffusion = DiffusionDistillation(num_timesteps=100)
print("‚úÖ Diffusion scheduler initialized with 100 timesteps")

In [None]:
class LightweightViT(nn.Module):
    """Reduced-size ViT student that predicts valence and arousal.

    The encoder has 6 layers, hidden size 384, 6 attention heads and an MLP
    width of 1536; every other architecture setting is taken from the
    ``google/vit-base-patch16-224-in21k`` configuration.

    The encoder is randomly initialised. The checkpoint's 768-wide weights
    cannot be loaded into a 384-wide model: passing these overrides to
    ``ViTModel.from_pretrained`` fails with a size-mismatch error, so only the
    configuration is fetched from the hub and the model is built from it.

    NOTE(review): the notebook quotes ~20M parameters for this model; verify the
    actual figure with count_parameters().
    """
    def __init__(self):
        super(LightweightViT, self).__init__()
        
        # Use smaller ViT variant: start from the base config and shrink it
        student_config = ViTConfig.from_pretrained(
            "google/vit-base-patch16-224-in21k",
            num_hidden_layers=6,  # Reduce from 12 to 6 layers
            hidden_size=384,      # Reduce from 768 to 384
            num_attention_heads=6, # Reduce from 12 to 6
            intermediate_size=1536 # Reduce from 3072 to 1536
        )
        self.vit = ViTModel(student_config)
        
        # Emotion prediction head
        self.emotion_head = nn.Sequential(
            nn.LayerNorm(384),
            nn.Dropout(0.1),
            nn.Linear(384, 192),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(192, 2)  # valence, arousal
        )
        
    def forward(self, x):
        """Return the emotion head's output for the CLS token of each input.

        Args:
            x: Batch passed to the ViT encoder as ``pixel_values``.

        Returns:
            Tensor with 2 values (valence, arousal) per sample.
        """
        outputs = self.vit(pixel_values=x)
        pooled = outputs.last_hidden_state[:, 0]  # CLS token
        emotions = self.emotion_head(pooled)
        return emotions

# Count parameters
def count_parameters(model, trainable_only=True):
    """Return the number of scalar parameters in ``model``.

    Args:
        model: Any ``nn.Module``.
        trainable_only: When True (default), count only tensors with
            ``requires_grad=True``. A fully frozen model therefore counts as 0;
            pass ``False`` to measure the size of a frozen model.

    Returns:
        Total element count as an ``int``.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

# Report trainable parameter counts: teacher first, then build the student on
# the target device and report its size and the resulting ratio.
teacher_trainable = count_parameters(vit_model)
print(f"üìè Teacher Model (Full ViT): {teacher_trainable:,} parameters")

student_model = LightweightViT().to(device)
student_trainable = count_parameters(student_model)
print(f"üìè Student Model (Lightweight): {student_trainable:,} parameters")

print(f"üéØ Compression Ratio: {teacher_trainable / student_trainable:.2f}x")

## 🔬 Diffusion-Based Model Compression

We'll use a diffusion-inspired knowledge distillation approach to compress the ViT model:
- **Teacher Model**: Full ViT (86M parameters)
- **Student Model**: Lightweight ViT (20M parameters)
- **Approach**: Student learns to denoise teacher predictions + match outputs
- **Goal**: 4-5x compression with <5% performance drop

## 🎯 Knowledge Distillation for Mobile Deployment

### Goal: Create a Lightweight Model for Android Phones

We'll compress the full ViT model by combining **response-based, feature-based, and attention-transfer distillation**:

**Teacher Model (Full ViT)**
- Architecture: `google/vit-base-patch16-224-in21k`
- Parameters: ~86M
- Performance: High accuracy but requires ~350MB memory

**Student Model (MobileViT)**
- Architecture: Lightweight ViT with reduced layers
- Parameters: ~5-8M (target)
- Performance: >90% of teacher with <40MB memory
- Target: Runnable on Android phones with 4GB RAM

### Distillation Method: Response-Based + Feature-Based + Attention Transfer

1. **Response Distillation**: Student mimics teacher's emotion predictions
2. **Feature Distillation**: Student learns intermediate representations
3. **Attention Transfer**: Student learns where teacher "looks"

### Expected Results
- Model size: 10-15x smaller (~25-40MB)
- Inference speed: 3-5x faster
- Performance retention: >90% of teacher CCC
- Memory usage: <200MB during inference

In [None]:
class MobileViTBlock(nn.Module):
    """Pre-norm transformer encoder block (self-attention + MLP) for the mobile student.

    Args:
        dim: Token embedding width.
        num_heads: Number of attention heads.
        mlp_ratio: The MLP hidden width is ``int(dim * mlp_ratio)``.
        drop: Dropout probability used inside the attention module and after
            each MLP linear layer.
    """
    def __init__(self, dim, num_heads=4, mlp_ratio=2.0, drop=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=drop, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        
        # Smaller MLP expansion ratio for mobile
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(drop)
        )
    
    def forward(self, x):
        """Apply self-attention and the MLP, each with a residual connection.

        Args:
            x: Token sequence of shape (batch, tokens, dim); the attention
                module is built with ``batch_first=True``.

        Returns:
            Tuple ``(x, attn_weights)``: the updated tokens and the weights
            returned by ``nn.MultiheadAttention`` (with PyTorch's default
            settings these are averaged over heads).
        """
        # Pre-norm architecture for better training.
        # Normalise once and reuse the result as query, key and value instead of
        # running the same LayerNorm three times.
        normed = self.norm1(x)
        attn_out, attn_weights = self.attn(normed, normed, normed)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x, attn_weights


class MobileViTStudent(nn.Module):
    """
    Compact Vision Transformer student for valence/arousal regression.

    Pipeline: depthwise-separable patch embedding -> CLS token + learned
    position embeddings -> ``num_layers`` MobileViTBlock encoder blocks ->
    LayerNorm -> two-layer head ending in Tanh.

    Default configuration:
    - 16x16 patches on 224x224 inputs
    - 4 transformer layers, hidden width 192, 4 attention heads
    - MLP expansion ratio 2.0

    Use ``get_num_params()`` for the exact trainable parameter count of a
    given configuration.
    """
    def __init__(self, 
                 image_size=224,
                 patch_size=16,
                 num_classes=2,  # valence, arousal
                 hidden_dim=192,  # Reduced from 768
                 num_layers=4,    # Reduced from 12
                 num_heads=4,     # Reduced from 12
                 mlp_ratio=2.0,   # Reduced from 4.0
                 dropout=0.1):
        super().__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim

        patches_per_side = image_size // patch_size
        token_count = patches_per_side ** 2 + 1  # all patches plus the CLS token

        # Patch embedding: a per-channel (depthwise) conv cuts the image into
        # patches, then a 1x1 (pointwise) conv lifts 3 channels to hidden_dim.
        depthwise = nn.Conv2d(3, 3, kernel_size=patch_size, stride=patch_size, groups=3, bias=False)
        depthwise_norm = nn.BatchNorm2d(3)
        pointwise = nn.Conv2d(3, hidden_dim, kernel_size=1, bias=False)
        pointwise_norm = nn.BatchNorm2d(hidden_dim)
        self.patch_embed = nn.Sequential(
            depthwise, depthwise_norm, pointwise, pointwise_norm, nn.GELU()
        )

        # Learnable CLS token and position table (standard-normal initialised)
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, token_count, hidden_dim))
        self.pos_drop = nn.Dropout(dropout)

        # Encoder stack
        self.blocks = nn.ModuleList(
            MobileViTBlock(hidden_dim, num_heads, mlp_ratio, dropout)
            for _ in range(num_layers)
        )

        self.norm = nn.LayerNorm(hidden_dim)

        # Regression head; Tanh bounds every output to [-1, 1]
        half_width = hidden_dim // 2
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, half_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(half_width, num_classes),
            nn.Tanh()
        )

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear (truncated normal, zero bias) and LayerNorm (identity)."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.LayerNorm):
                nn.init.zeros_(module.bias)
                nn.init.ones_(module.weight)

    def forward(self, x, return_attention=False):
        """
        Args:
            x: Image batch (B, 3, image_size, image_size)
            return_attention: When True, also return each block's attention weights
        Returns:
            emotions: Predicted valence and arousal (B, num_classes)
            attentions: List with one attention tensor per block
                (only when return_attention=True)
        """
        batch = x.shape[0]

        # (B, hidden_dim, H/patch, W/patch) -> (B, num_patches, hidden_dim)
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)

        # Prepend the CLS token, then add position information
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)

        # Encoder blocks; weights are kept only when the caller asked for them
        collected = []
        for block in self.blocks:
            tokens, weights = block(tokens)
            if return_attention:
                collected.append(weights)

        # Final norm, then predict from the CLS position
        emotions = self.head(self.norm(tokens)[:, 0])

        return (emotions, collected) if return_attention else emotions

    def get_num_params(self):
        """Return the number of trainable parameters."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total


# Initialize student model: build the MobileViT student on DEVICE, print a size
# comparison against the teacher (`model`) and run a forward-pass smoke test.
print("=" * 80)
print("üéì INITIALIZING MOBILE-OPTIMIZED STUDENT MODEL")
print("=" * 80)

mobile_student = MobileViTStudent(
    image_size=VIT_IMAGE_SIZE,
    hidden_dim=192,
    num_layers=4,
    num_heads=4,
    mlp_ratio=2.0,
    dropout=0.1
).to(DEVICE)

# Count parameters (student: trainable only; teacher: every tensor)
student_params = mobile_student.get_num_params()
teacher_params = sum(p.numel() for p in model.parameters())

# Memory figures below assume 4 bytes per parameter (fp32 weights only)
print(f"\nüìä Model Comparison:")
print(f"  Teacher (Full ViT):")
print(f"    - Parameters: {teacher_params:,}")
print(f"    - Memory: ~{teacher_params * 4 / 1024**2:.1f} MB (fp32)")
print(f"    - Layers: 12")
print(f"    - Hidden dim: 768")
print(f"\n  Student (MobileViT):")
print(f"    - Parameters: {student_params:,}")
print(f"    - Memory: ~{student_params * 4 / 1024**2:.1f} MB (fp32)")
print(f"    - Layers: 4")
print(f"    - Hidden dim: 192")
print(f"\n  üìâ Compression:")
print(f"    - Size reduction: {teacher_params / student_params:.1f}x smaller")
print(f"    - Memory savings: {(teacher_params - student_params) * 4 / 1024**2:.1f} MB")
print(f"    - Target platforms: Android phones with 4GB+ RAM")
print("=" * 80)

# Test forward pass.
# Run it in eval mode: in training mode the BatchNorm layers of the patch
# embedding would fold this random input into their running statistics and
# dropout would be active. The previous mode is restored afterwards.
student_was_training = mobile_student.training
mobile_student.eval()
try:
    with torch.no_grad():
        test_input = torch.randn(2, 3, VIT_IMAGE_SIZE, VIT_IMAGE_SIZE).to(DEVICE)
        test_output, test_attentions = mobile_student(test_input, return_attention=True)
        print(f"\n‚úÖ Forward pass successful!")
        print(f"   Input shape: {test_input.shape}")
        print(f"   Output shape: {test_output.shape}")
        print(f"   Attention maps: {len(test_attentions)} layers")
        print(f"   Output range: [{test_output.min().item():.3f}, {test_output.max().item():.3f}]")
finally:
    mobile_student.train(student_was_training)

print(f"\nüí° Model ready for distillation training!")

In [None]:
class KnowledgeDistillationLoss(nn.Module):
    """
    Comprehensive distillation loss combining:
    1. Response-based distillation (output matching)
    2. Feature-based distillation (intermediate layer matching)
    3. Attention transfer (where the model "looks")

    total = alpha * hard + (1 - alpha) * soft + beta * feature + gamma * attention

    Notes:
        * All terms are MSE. For the soft term the temperature cancels out
          (``MSE(s/T, t/T) * T**2 == MSE(s, t)``), so ``temperature`` does not
          change the result.
        * ``response_loss`` is accepted for interface compatibility only; MSE is
          always used.
        * Feature and attention terms are skipped (contribute 0) when either
          side is ``None`` or when no student/teacher pair is available.
    """
    def __init__(self, 
                 alpha=0.5,           # Weight for hard target loss
                 beta=0.3,            # Weight for feature distillation
                 gamma=0.2,           # Weight for attention transfer
                 temperature=4.0,     # Softmax temperature for soft targets
                 response_loss='mse'):
        super().__init__()
        
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.temperature = temperature
        
        # Loss functions
        self.hard_loss = nn.MSELoss()
        self.feature_loss = nn.MSELoss()
        self.attention_loss = nn.MSELoss()
        
        print(f"üéØ Distillation Loss Configuration:")
        print(f"   Œ± (hard targets):    {alpha:.2f}")
        print(f"   Œ≤ (features):        {beta:.2f}")
        print(f"   Œ≥ (attention):       {gamma:.2f}")
        print(f"   Temperature:         {temperature}")
    
    @staticmethod
    def _match_width(teacher_feat, width):
        """Resize the last dimension of ``teacher_feat`` to ``width`` by adaptive average pooling.

        Parameter-free stand-in for a learned projection so that a wider
        teacher can be compared with a narrower student. Returns the input
        unchanged when the widths already agree.
        """
        if teacher_feat.shape[-1] == width:
            return teacher_feat
        flat = teacher_feat.reshape(-1, 1, teacher_feat.shape[-1])
        pooled = F.adaptive_avg_pool1d(flat, width)
        return pooled.reshape(*teacher_feat.shape[:-1], width)
    
    @staticmethod
    def _head_average(attn):
        """Average a (B, heads, Q, K) attention map over heads; pass 3-D maps through.

        ``nn.MultiheadAttention`` already returns head-averaged (B, Q, K) maps by
        default, so those must not be reduced a second time.
        """
        return attn.mean(dim=1) if attn.dim() == 4 else attn
    
    def forward(self, 
                student_outputs, 
                teacher_outputs, 
                true_labels,
                student_features=None,
                teacher_features=None,
                student_attentions=None,
                teacher_attentions=None):
        """
        Args:
            student_outputs: Student predictions (B, 2)
            teacher_outputs: Teacher predictions (B, 2)
            true_labels: Ground truth labels (B, 2)
            student_features: Sequence of student intermediate features (optional)
            teacher_features: Sequence of teacher intermediate features (optional);
                paired with the student's in order, extra entries on either
                side are ignored
            student_attentions: Sequence of student attention maps (optional)
            teacher_attentions: Sequence of teacher attention maps (optional)
        
        Returns:
            total_loss: Combined distillation loss (scalar tensor)
            loss_dict: Dictionary of Python floats with keys 'total', 'hard',
                'soft', 'response', 'feature' and 'attention'
        """
        # 1. Hard target loss (student vs ground truth)
        loss_hard = self.hard_loss(student_outputs, true_labels)
        
        # 2. Soft target loss (student vs teacher with temperature)
        soft_student = student_outputs / self.temperature
        soft_teacher = teacher_outputs / self.temperature
        loss_soft = self.hard_loss(soft_student, soft_teacher.detach()) * (self.temperature ** 2)
        
        # Combined response loss
        loss_response = self.alpha * loss_hard + (1 - self.alpha) * loss_soft
        
        # 3. Feature-based distillation (if features provided)
        loss_feature = torch.tensor(0.0, device=student_outputs.device)
        if student_features is not None and teacher_features is not None:
            # Compare L2-normalised features layer by layer; the teacher side is
            # detached and, if wider/narrower, pooled to the student's width.
            feature_terms = [
                self.feature_loss(
                    F.normalize(s_feat, dim=-1),
                    F.normalize(self._match_width(t_feat.detach(), s_feat.shape[-1]), dim=-1),
                )
                for s_feat, t_feat in zip(student_features, teacher_features)
            ]
            # Average over the pairs actually compared; an empty pairing would
            # otherwise divide by zero and turn the whole loss into NaN.
            if feature_terms:
                loss_feature = torch.stack(feature_terms).mean()
        
        # 4. Attention transfer (if attention maps provided)
        loss_attention = torch.tensor(0.0, device=student_outputs.device)
        if student_attentions is not None and teacher_attentions is not None:
            # Match attention distributions (head-averaged, softmax over keys)
            attention_terms = [
                self.attention_loss(
                    F.softmax(self._head_average(s_attn), dim=-1),
                    F.softmax(self._head_average(t_attn).detach(), dim=-1),
                )
                for s_attn, t_attn in zip(student_attentions, teacher_attentions)
            ]
            if attention_terms:
                loss_attention = torch.stack(attention_terms).mean()
        
        # Total loss
        total_loss = loss_response + self.beta * loss_feature + self.gamma * loss_attention
        
        # Return loss dictionary for monitoring
        loss_dict = {
            'total': total_loss.item(),
            'hard': loss_hard.item(),
            'soft': loss_soft.item(),
            'response': loss_response.item(),
            'feature': loss_feature.item(),
            'attention': loss_attention.item()
        }
        
        return total_loss, loss_dict


def extract_teacher_features(teacher_model, inputs):
    """
    Run the teacher once and collect the CLS-token state after each encoder layer.

    Forward hooks are registered on every module in
    ``teacher_model.vit.encoder.layer`` (when that attribute chain exists) and
    are always removed again, even if the forward pass raises.

    Args:
        teacher_model: Model called as ``teacher_model(inputs)``.
        inputs: Batch passed to the teacher unchanged.

    Returns:
        Tuple ``(outputs, features)``: the teacher's output and a list with one
        ``hidden[:, 0, :]`` tensor per hooked layer, in execution order. The
        list is empty when the model exposes no ``vit.encoder``.
    """
    features = []
    
    # Hook function to capture intermediate outputs
    def hook_fn(module, _inputs, output):
        # Depending on the implementation an encoder layer returns either the
        # hidden states directly or a tuple whose first item is the hidden states.
        hidden = output[0] if isinstance(output, (tuple, list)) else output
        features.append(hidden[:, 0, :])  # CLS token representation
    
    hooks = []
    try:
        # Register hooks on transformer blocks
        if hasattr(teacher_model, 'vit') and hasattr(teacher_model.vit, 'encoder'):
            for block in teacher_model.vit.encoder.layer:
                hooks.append(block.register_forward_hook(hook_fn))
        
        # Forward pass (the teacher never needs gradients)
        with torch.no_grad():
            outputs = teacher_model(inputs)
    finally:
        # Remove hooks so they cannot leak into later forward passes
        for hook in hooks:
            hook.remove()
    
    return outputs, features


def extract_student_features(student_model, inputs):
    """Run the student once and collect the CLS-token state after each block.

    Forward hooks are registered on every module in ``student_model.blocks``
    and are always removed again, even if the forward pass raises. Gradients
    are not disabled here, so the returned tensors stay attached to the graph
    unless the caller wraps the call in ``torch.no_grad()``.

    Args:
        student_model: Model exposing ``blocks`` and called as
            ``student_model(inputs, return_attention=False)``.
        inputs: Batch passed to the student unchanged.

    Returns:
        Tuple ``(outputs, features)``: the student's output and a list with one
        ``hidden[:, 0, :]`` tensor per block, in execution order.
    """
    features = []
    
    # Hook function
    def hook_fn(module, _inputs, output):
        # MobileViTBlock returns (tokens, attention_weights); indexing the tuple
        # itself with [:, 0, :] raises TypeError, so unpack the tokens first.
        hidden = output[0] if isinstance(output, (tuple, list)) else output
        features.append(hidden[:, 0, :])  # CLS token
    
    hooks = []
    try:
        # Register hooks on student blocks
        for block in student_model.blocks:
            hooks.append(block.register_forward_hook(hook_fn))
        
        # Forward pass
        outputs = student_model(inputs, return_attention=False)
    finally:
        # Remove hooks so they cannot leak into later forward passes
        for hook in hooks:
            hook.remove()
    
    return outputs, features


# Initialize distillation loss and smoke-test it on random data with the
# teacher (`model`) and the mobile student.
print("\n" + "=" * 80)
print("üîß DISTILLATION LOSS INITIALIZATION")
print("=" * 80)

distillation_criterion = KnowledgeDistillationLoss(
    alpha=0.5,       # 50% ground truth, 50% teacher
    beta=0.3,        # 30% weight for feature matching
    gamma=0.2,       # 20% weight for attention transfer
    temperature=4.0
)

print(f"\n‚úÖ Distillation loss initialized!")
print(f"   Loss components: Response + Feature + Attention")
print(f"   Total weight: Œ± + Œ≤ + Œ≥ = 1.0")

# Test distillation loss.
# The student is switched to eval mode for the test: in training mode its
# BatchNorm layers would absorb the random dummy batch into their running
# statistics and dropout would make the two student passes disagree. The
# previous mode is restored afterwards.
print(f"\nüß™ Testing distillation loss...")
student_was_training = mobile_student.training
mobile_student.eval()
try:
    with torch.no_grad():
        # Create dummy data
        dummy_input = torch.randn(4, 3, VIT_IMAGE_SIZE, VIT_IMAGE_SIZE).to(DEVICE)
        dummy_labels = torch.randn(4, 2).to(DEVICE)
        
        # Get teacher outputs
        teacher_out, teacher_feats = extract_teacher_features(model, dummy_input)
        
        # Get student outputs  
        student_out, student_attns = mobile_student(dummy_input, return_attention=True)
        student_out_feat, student_feats = extract_student_features(mobile_student, dummy_input)
        
        # Calculate loss
        loss, loss_dict = distillation_criterion(
            student_out, teacher_out, dummy_labels,
            student_features=student_feats[:len(teacher_feats)],  # Match teacher depth
            teacher_features=teacher_feats,
            student_attentions=student_attns,
            teacher_attentions=None  # Teacher attention extraction is complex
        )
        
        print(f"‚úÖ Distillation loss test successful!")
        print(f"   Total loss: {loss_dict['total']:.4f}")
        print(f"   - Hard target:  {loss_dict['hard']:.4f}")
        print(f"   - Soft target:  {loss_dict['soft']:.4f}")
        print(f"   - Feature:      {loss_dict['feature']:.4f}")
        print(f"   - Attention:    {loss_dict['attention']:.4f}")
finally:
    mobile_student.train(student_was_training)

print("=" * 80)

In [None]:
def train_distillation_epoch(student, teacher, train_loader, criterion, optimizer, device, epoch):
    """Train student model for one epoch using knowledge distillation.

    Args:
        student: Model being trained; must work with extract_student_features.
        teacher: Reference model; run without gradients via
            extract_teacher_features.
        train_loader: Iterable of ``(spectrograms, labels)`` batches.
        criterion: Callable returning ``(loss, loss_dict)`` where ``loss_dict``
            holds floats for 'total', 'hard', 'soft', 'response', 'feature'
            and 'attention'.
        optimizer: Optimizer over the student's parameters.
        device: Device the batch is moved to.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        Dict with the same six keys, each averaged over ``len(train_loader)``.

    Raises:
        ZeroDivisionError: If ``train_loader`` has length 0.
    """
    student.train()
    teacher.eval()  # Teacher is always in eval mode
    
    running_losses = {
        'total': 0.0, 'hard': 0.0, 'soft': 0.0, 
        'response': 0.0, 'feature': 0.0, 'attention': 0.0
    }
    
    progress_bar = tqdm(train_loader, desc=f'Distillation Epoch {epoch+1}')
    
    for batch_idx, (spectrograms, labels) in enumerate(progress_bar):
        spectrograms = spectrograms.to(device)
        labels = labels.to(device)
        
        # Get teacher predictions (no gradients)
        with torch.no_grad():
            teacher_outputs, teacher_features = extract_teacher_features(teacher, spectrograms)
        
        # Get student predictions and per-block features from ONE forward pass.
        # Running the student twice per batch doubled the compute, updated the
        # BatchNorm statistics twice and paired outputs and features that came
        # from different dropout samples.
        student_outputs, student_features = extract_student_features(student, spectrograms)
        
        # Calculate distillation loss.
        # No teacher attention maps are extracted, so the attention term is
        # skipped by the criterion; student attention maps are therefore not
        # collected either.
        optimizer.zero_grad()
        loss, loss_dict = criterion(
            student_outputs, 
            teacher_outputs, 
            labels,
            student_features=student_features[:len(teacher_features)],
            teacher_features=teacher_features,
            student_attentions=None,
            teacher_attentions=None
        )
        
        # Backward and optimize; clip the gradient norm to 1.0 before stepping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student.parameters(), max_norm=1.0)
        optimizer.step()
        
        # Update running losses
        for key in running_losses:
            running_losses[key] += loss_dict[key]
        
        # Update progress bar every 10 batches
        if (batch_idx + 1) % 10 == 0:
            avg_loss = running_losses['total'] / (batch_idx + 1)
            progress_bar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'hard': f'{loss_dict["hard"]:.4f}',
                'soft': f'{loss_dict["soft"]:.4f}'
            })
    
    # Calculate epoch averages
    epoch_losses = {key: val / len(train_loader) for key, val in running_losses.items()}
    return epoch_losses


def _concordance_cc(y_true, y_pred):
    """Return the concordance correlation coefficient of two 1-D arrays.

    When the denominator is zero (both arrays constant with equal means) the
    coefficient is undefined; 0.0 is returned instead of NaN.
    """
    mean_true = np.mean(y_true)
    mean_pred = np.mean(y_pred)
    covariance = np.mean((y_true - mean_true) * (y_pred - mean_pred))
    denominator = np.var(y_true) + np.var(y_pred) + (mean_true - mean_pred) ** 2
    if denominator == 0:
        return np.float64(0.0)
    return (2 * covariance) / denominator


def _valence_arousal_metrics(predictions, targets):
    """Return MAE and CCC for column 0 (valence) and column 1 (arousal), plus their averages."""
    mae = np.mean(np.abs(predictions - targets), axis=0)
    ccc_valence = _concordance_cc(targets[:, 0], predictions[:, 0])
    ccc_arousal = _concordance_cc(targets[:, 1], predictions[:, 1])
    return {
        'mae_valence': mae[0],
        'mae_arousal': mae[1],
        'mae_avg': np.mean(mae),
        'ccc_valence': ccc_valence,
        'ccc_arousal': ccc_arousal,
        'ccc_avg': (ccc_valence + ccc_arousal) / 2
    }


def evaluate_distillation(student, teacher, test_loader, device):
    """Evaluate student and teacher models side by side.

    Both models are put in eval mode and run without gradients on every batch
    of ``test_loader``.

    Args:
        student: Model called as ``student(spectrograms)``.
        teacher: Model called as ``teacher(spectrograms)``.
        test_loader: Iterable of ``(spectrograms, labels)`` batches.
        device: Device the spectrograms are moved to.

    Returns:
        Dict with keys 'student' and 'teacher' (metric dicts containing
        'mae_valence', 'mae_arousal', 'mae_avg', 'ccc_valence', 'ccc_arousal',
        'ccc_avg') and 'student_preds', 'teacher_preds', 'true_labels'
        (NumPy arrays concatenated over all batches).

    Raises:
        ValueError: From ``np.concatenate`` if ``test_loader`` yields no batches.
    """
    student.eval()
    teacher.eval()
    
    student_preds = []
    teacher_preds = []
    true_labels = []
    
    with torch.no_grad():
        for spectrograms, labels in tqdm(test_loader, desc='Evaluating'):
            spectrograms = spectrograms.to(device)
            
            # Get predictions
            student_out = student(spectrograms)
            teacher_out = teacher(spectrograms)
            
            student_preds.append(student_out.cpu().numpy())
            teacher_preds.append(teacher_out.cpu().numpy())
            # .numpy() only works on CPU tensors; move labels explicitly in case
            # the loader already placed them on an accelerator
            true_labels.append(labels.detach().cpu().numpy())
    
    # Concatenate all batches
    student_preds = np.concatenate(student_preds, axis=0)
    teacher_preds = np.concatenate(teacher_preds, axis=0)
    true_labels = np.concatenate(true_labels, axis=0)
    
    # Calculate metrics
    teacher_metrics = _valence_arousal_metrics(teacher_preds, true_labels)
    student_metrics = _valence_arousal_metrics(student_preds, true_labels)
    
    return {
        'student': student_metrics,
        'teacher': teacher_metrics,
        'student_preds': student_preds,
        'teacher_preds': teacher_preds,
        'true_labels': true_labels
    }


# Knowledge-distillation training run for the mobile student: trains for up to
# DISTILL_EPOCHS epochs, validates after every epoch, checkpoints the best
# student (by average validation CCC) and stops early after DISTILL_PATIENCE
# epochs without improvement.

# Distillation training configuration
print("=" * 80)
print("üéì KNOWLEDGE DISTILLATION TRAINING")
print("=" * 80)

DISTILL_EPOCHS = 20
DISTILL_LR = 2e-4
DISTILL_WEIGHT_DECAY = 0.01
DISTILL_PATIENCE = 5

# Initialize optimizer
distill_optimizer = optim.AdamW(
    mobile_student.parameters(),
    lr=DISTILL_LR,
    weight_decay=DISTILL_WEIGHT_DECAY
)

# Learning rate scheduler
distill_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    distill_optimizer,
    T_0=5,  # First restart after 5 epochs
    T_mult=2,  # Each following cycle is twice as long
    eta_min=1e-6
)

print(f"\nüìä Training Configuration:")
print(f"   Epochs: {DISTILL_EPOCHS}")
print(f"   Learning rate: {DISTILL_LR}")
print(f"   Weight decay: {DISTILL_WEIGHT_DECAY}")
print(f"   Patience: {DISTILL_PATIENCE}")
print(f"   Scheduler: Cosine Annealing with Warm Restarts")

# Training history (one entry per completed epoch for each key)
distill_history = {
    'train_loss': [], 'train_hard': [], 'train_soft': [],
    'train_feature': [], 'train_attention': [],
    'val_student_ccc': [], 'val_teacher_ccc': [],
    'val_student_mae': [], 'val_teacher_mae': []
}

# Start from -inf rather than 0.0: CCC can be zero or negative, and with a 0.0
# baseline such a student would never be checkpointed at all.
best_student_ccc = float('-inf')
best_model_state = None
patience_counter = 0

print(f"\n{'=' * 80}")
print(f"üöÄ Starting distillation training...")
print(f"{'=' * 80}\n")

# Freeze teacher
for param in model.parameters():
    param.requires_grad = False

for epoch in range(DISTILL_EPOCHS):
    # Train student
    epoch_losses = train_distillation_epoch(
        mobile_student, model, train_loader, 
        distillation_criterion, distill_optimizer, DEVICE, epoch
    )
    
    # Validate
    print(f"\nüìä Epoch {epoch + 1}/{DISTILL_EPOCHS} - Evaluating...")
    eval_results = evaluate_distillation(mobile_student, model, val_loader, DEVICE)
    
    # Update scheduler
    distill_scheduler.step()
    
    # Store history
    distill_history['train_loss'].append(epoch_losses['total'])
    distill_history['train_hard'].append(epoch_losses['hard'])
    distill_history['train_soft'].append(epoch_losses['soft'])
    distill_history['train_feature'].append(epoch_losses['feature'])
    distill_history['train_attention'].append(epoch_losses['attention'])
    distill_history['val_student_ccc'].append(eval_results['student']['ccc_avg'])
    distill_history['val_teacher_ccc'].append(eval_results['teacher']['ccc_avg'])
    distill_history['val_student_mae'].append(eval_results['student']['mae_avg'])
    distill_history['val_teacher_mae'].append(eval_results['teacher']['mae_avg'])
    
    # Print results
    print(f"\n{'=' * 80}")
    print(f"Epoch {epoch + 1}/{DISTILL_EPOCHS} Summary:")
    print(f"{'=' * 80}")
    print(f"Training Loss: {epoch_losses['total']:.4f}")
    print(f"  ‚îú‚îÄ Hard:      {epoch_losses['hard']:.4f}")
    print(f"  ‚îú‚îÄ Soft:      {epoch_losses['soft']:.4f}")
    print(f"  ‚îú‚îÄ Feature:   {epoch_losses['feature']:.4f}")
    print(f"  ‚îî‚îÄ Attention: {epoch_losses['attention']:.4f}")
    
    print(f"\nTeacher Performance:")
    print(f"  ‚îú‚îÄ CCC Avg:     {eval_results['teacher']['ccc_avg']:.4f}")
    print(f"  ‚îú‚îÄ CCC Valence: {eval_results['teacher']['ccc_valence']:.4f}")
    print(f"  ‚îú‚îÄ CCC Arousal: {eval_results['teacher']['ccc_arousal']:.4f}")
    print(f"  ‚îî‚îÄ MAE Avg:     {eval_results['teacher']['mae_avg']:.4f}")
    
    print(f"\nStudent Performance:")
    print(f"  ‚îú‚îÄ CCC Avg:     {eval_results['student']['ccc_avg']:.4f}")
    print(f"  ‚îú‚îÄ CCC Valence: {eval_results['student']['ccc_valence']:.4f}")
    print(f"  ‚îú‚îÄ CCC Arousal: {eval_results['student']['ccc_arousal']:.4f}")
    print(f"  ‚îî‚îÄ MAE Avg:     {eval_results['student']['mae_avg']:.4f}")
    
    # Calculate retention percentage
    ccc_retention = (eval_results['student']['ccc_avg'] / eval_results['teacher']['ccc_avg']) * 100
    print(f"\nüìà Knowledge Retention: {ccc_retention:.1f}% of teacher performance")
    print(f"{'=' * 80}\n")
    
    # Early stopping and model saving
    current_ccc = eval_results['student']['ccc_avg']
    if current_ccc > best_student_ccc:
        best_student_ccc = current_ccc
        # state_dict().copy() is only a shallow copy: its tensors are the live
        # parameters and keep changing as training continues, so the "best"
        # snapshot would silently become the final weights. Clone every tensor.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in mobile_student.state_dict().items()
        }
        patience_counter = 0
        
        # Save best model
        torch.save({
            'epoch': epoch,
            'model_state_dict': mobile_student.state_dict(),
            'optimizer_state_dict': distill_optimizer.state_dict(),
            'best_ccc': best_student_ccc,
            'eval_results': eval_results,
            'history': distill_history
        }, os.path.join(OUTPUT_DIR, 'mobile_vit_student_best.pth'))
        
        print(f"‚úÖ Saved best model (CCC: {best_student_ccc:.4f})")
    else:
        patience_counter += 1
        print(f"‚è≥ No improvement. Patience: {patience_counter}/{DISTILL_PATIENCE}")
        
        if patience_counter >= DISTILL_PATIENCE:
            print(f"\nüõë Early stopping triggered after {epoch + 1} epochs")
            break
    
    # Clear GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Load best model (restores the cloned snapshot taken at the best epoch)
if best_model_state is not None:
    mobile_student.load_state_dict(best_model_state)
    print(f"\n‚úÖ Loaded best student model (CCC: {best_student_ccc:.4f})")

print(f"\n{'=' * 80}")
print(f"üéâ Distillation Training Complete!")
print(f"{'=' * 80}")
print(f"Best Student CCC: {best_student_ccc:.4f}")
print(f"Training epochs: {len(distill_history['train_loss'])}")
print(f"{'=' * 80}")

In [None]:
# Final evaluation on test set: score teacher and student on test_loader and
# print per-model metrics followed by the compression summary.
report_rule = "=" * 80
print(report_rule)
print("üß™ FINAL TEST SET EVALUATION")
print(report_rule)

test_results = evaluate_distillation(mobile_student, model, test_loader, DEVICE)

# One section per model: (heading, key into test_results)
report_sections = (
    (f"\nüìä Teacher Model (Full ViT - {teacher_params:,} params):", 'teacher'),
    (f"\nüì± Student Model (MobileViT - {student_params:,} params):", 'student'),
)
for section_heading, role in report_sections:
    role_metrics = test_results[role]
    print(section_heading)
    for row_label, suffix in (('Valence', 'valence'), ('Arousal', 'arousal'), ('Average', 'avg')):
        row_ccc = role_metrics['ccc_' + suffix]
        row_mae = role_metrics['mae_' + suffix]
        print(f"   {row_label} - CCC: {row_ccc:.4f}, MAE: {row_mae:.4f}")

# Calculate performance retention (student relative to teacher, in percent)
teacher_avg_ccc = test_results['teacher']['ccc_avg']
teacher_avg_mae = test_results['teacher']['mae_avg']
ccc_retention = (test_results['student']['ccc_avg'] / teacher_avg_ccc) * 100
mae_increase = ((test_results['student']['mae_avg'] - teacher_avg_mae)
                / teacher_avg_mae) * 100

print(f"\nüìà Compression Results:")
print(f"   Size Reduction:    {teacher_params / student_params:.1f}x smaller")
print(f"   Memory Savings:    {(teacher_params - student_params) * 4 / 1024**2:.1f} MB")
print(f"   CCC Retention:     {ccc_retention:.1f}%")
print(f"   MAE Increase:      {mae_increase:+.1f}%")

# Pick the first verdict whose threshold is met; otherwise keep the warning
retention_verdict = f"\n‚ö†Ô∏è Student performance could be improved with more training"
for retention_floor, verdict_text in (
    (90, f"\n‚úÖ Excellent! Student retains >90% of teacher performance"),
    (85, f"\n‚úÖ Good! Student retains >85% of teacher performance"),
):
    if ccc_retention >= retention_floor:
        retention_verdict = verdict_text
        break
print(retention_verdict)

print(report_rule)

# Visualize test results: a 2x2 figure with prediction scatters on the top
# row and training diagnostics on the bottom row.
fig, axes = plt.subplots(2, 2, figsize=(15, 12))


def _plot_prediction_scatter(ax, column, emotion):
    """Scatter teacher and student predictions against the true labels for one output column."""
    true_values = test_results['true_labels'][:, column]
    for preds_key, legend_label, colour in (('teacher_preds', 'Teacher', 'blue'),
                                            ('student_preds', 'Student', 'red')):
        ax.scatter(true_values,
                   test_results[preds_key][:, column],
                   alpha=0.5, s=20, label=legend_label, color=colour)
    # Identity line: a perfect predictor would fall on it.
    ax.plot([-1, 1], [-1, 1], 'k--', lw=2, label='Perfect')
    ax.set_xlabel(f'True {emotion}')
    ax.set_ylabel(f'Predicted {emotion}')
    metric_key = 'ccc_' + emotion.lower()
    teacher_ccc = test_results['teacher'][metric_key]
    student_ccc = test_results['student'][metric_key]
    ax.set_title(f'{emotion} Predictions\nTeacher CCC: {teacher_ccc:.3f} | Student CCC: {student_ccc:.3f}')
    ax.legend()
    ax.grid(True, alpha=0.3)


# Valence predictions (output column 0)
_plot_prediction_scatter(axes[0, 0], 0, 'Valence')

# Arousal predictions (output column 1)
_plot_prediction_scatter(axes[0, 1], 1, 'Arousal')

# Training curves: total loss plus its hard-target and soft-target parts
_loss_ax = axes[1, 0]
_loss_ax.plot(distill_history['train_loss'], label='Total Loss', linewidth=2)
for _history_key, _curve_label in (('train_hard', 'Hard Target'), ('train_soft', 'Soft Target')):
    _loss_ax.plot(distill_history[_history_key], label=_curve_label, linewidth=2, alpha=0.7)
_loss_ax.set_xlabel('Epoch')
_loss_ax.set_ylabel('Loss')
_loss_ax.set_title('Distillation Training Losses')
_loss_ax.legend()
_loss_ax.grid(True, alpha=0.3)

# CCC comparison: validation CCC of both models per epoch, gap shaded
_ccc_ax = axes[1, 1]
_student_curve = distill_history['val_student_ccc']
_teacher_curve = distill_history['val_teacher_ccc']
epochs = range(1, len(_student_curve) + 1)
_ccc_ax.plot(epochs, _teacher_curve,
             label='Teacher CCC', linewidth=2, color='blue', marker='o')
_ccc_ax.plot(epochs, _student_curve,
             label='Student CCC', linewidth=2, color='red', marker='s')
_ccc_ax.fill_between(epochs, _student_curve, _teacher_curve, alpha=0.2)
_ccc_ax.set_xlabel('Epoch')
_ccc_ax.set_ylabel('Concordance Correlation Coefficient')
_ccc_ax.set_title('Student vs Teacher CCC During Training')
_ccc_ax.legend()
_ccc_ax.grid(True, alpha=0.3)

plt.tight_layout()
_figure_path = os.path.join(OUTPUT_DIR, 'distillation_results.png')
plt.savefig(_figure_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"\n‚úÖ Visualization saved to {OUTPUT_DIR}/distillation_results.png")

In [None]:
# Export model for Android deployment
print("\n".join([
    "=" * 80,
    "üì¶ EXPORTING MODEL FOR ANDROID DEPLOYMENT",
    "=" * 80,
]))

# 1. Save PyTorch model: weights plus the metadata needed to rebuild the
#    network and to preprocess its inputs.
mobile_model_path = os.path.join(OUTPUT_DIR, 'mobile_vit_emotion_model.pth')
_export_payload = dict(
    model_state_dict=mobile_student.state_dict(),
    model_config=dict(
        image_size=VIT_IMAGE_SIZE,
        hidden_dim=192,
        num_layers=4,
        num_heads=4,
        num_classes=2,
    ),
    test_metrics=test_results['student'],
    compression_ratio=teacher_params / student_params,
    imagenet_mean=IMAGENET_MEAN,
    imagenet_std=IMAGENET_STD,
)
torch.save(_export_payload, mobile_model_path)

print(f"‚úÖ Saved PyTorch model: {mobile_model_path}")
_export_size_mb = os.path.getsize(mobile_model_path) / 1024**2
print(f"   Size: {_export_size_mb:.2f} MB")

# 2. Export to TorchScript for mobile
print(f"\nüîß Converting to TorchScript...")
mobile_student.eval()

# Create example input
example_input = torch.randn(1, 3, VIT_IMAGE_SIZE, VIT_IMAGE_SIZE).to(DEVICE)

# Bind the output path and the fp32 reference output BEFORE the try-block:
# later steps read `torchscript_path` and `original_output`, and would hit a
# NameError if the export failed before these names were assigned.
torchscript_path = os.path.join(OUTPUT_DIR, 'mobile_vit_emotion_model.pt')
with torch.no_grad():
    original_output = mobile_student(example_input)

# Trace the model (best effort: a failure is reported and the export continues)
try:
    traced_model = torch.jit.trace(mobile_student, example_input)

    # NOTE(review): torch.jit.optimize_for_inference is a general inference
    # optimisation pass, and tracing happens on DEVICE. Confirm the artifact
    # loads on a CPU-only Android runtime; PyTorch Mobile documents
    # torch.utils.mobile_optimizer.optimize_for_mobile for that target.
    traced_model_optimized = torch.jit.optimize_for_inference(traced_model)

    # Save TorchScript model
    traced_model_optimized.save(torchscript_path)

    print(f"‚úÖ Saved TorchScript model: {torchscript_path}")
    print(f"   Size: {os.path.getsize(torchscript_path) / 1024**2:.2f} MB")

    # Test TorchScript model against the eager-mode reference output
    with torch.no_grad():
        scripted_output = traced_model_optimized(example_input)
        max_diff = torch.max(torch.abs(original_output - scripted_output)).item()

    print(f"‚úÖ TorchScript verification: max diff = {max_diff:.6f}")

except Exception as e:
    print(f"‚ö†Ô∏è TorchScript export failed: {e}")
    print(f"   Continuing with PyTorch model only...")

# 3. Dynamic Quantization for even smaller size
print(f"\n‚ö° Applying dynamic quantization...")

# Bound before the try-block so the deployment summary can test for the file
# even when quantization fails.
quantized_path = os.path.join(OUTPUT_DIR, 'mobile_vit_emotion_model_quantized.pth')

try:
    # nn.Module.cpu() moves `mobile_student` itself to the CPU; the `finally`
    # clause below moves it back whether or not quantization succeeds.
    mobile_student.cpu()
    cpu_example_input = example_input.cpu()

    # fp32 reference computed here, so this check does not depend on a value
    # that only exists if the TorchScript export succeeded.
    with torch.no_grad():
        fp32_reference_output = mobile_student(cpu_example_input)

    quantized_model = torch.quantization.quantize_dynamic(
        mobile_student,
        {nn.Linear, nn.MultiheadAttention},
        dtype=torch.qint8
    )

    # Save quantized model (whole module object plus preprocessing metadata)
    torch.save({
        'model': quantized_model,
        'model_config': {
            'image_size': VIT_IMAGE_SIZE,
            'hidden_dim': 192,
            'num_layers': 4,
            'num_heads': 4,
            'num_classes': 2
        },
        'imagenet_mean': IMAGENET_MEAN,
        'imagenet_std': IMAGENET_STD
    }, quantized_path)

    print(f"‚úÖ Saved quantized model: {quantized_path}")
    print(f"   Size: {os.path.getsize(quantized_path) / 1024**2:.2f} MB")

    # Test quantized model against the fp32 reference
    quantized_model.eval()
    with torch.no_grad():
        quantized_output = quantized_model(cpu_example_input)
        max_diff_quantized = torch.max(torch.abs(fp32_reference_output - quantized_output)).item()

    print(f"‚úÖ Quantization verification: max diff = {max_diff_quantized:.6f}")

except Exception as e:
    print(f"‚ö†Ô∏è Quantization failed: {e}")
    print(f"   Continuing without quantization...")

finally:
    # Move back to GPU if needed, on failure as well as on success
    mobile_student.to(DEVICE)

# 4. Create deployment package
print(f"\nüì¶ Creating deployment package...")

import json

# json.dump raises TypeError for NumPy float32 scalars and for arrays, so
# every numeric value from evaluation or normalisation constants is converted
# to a builtin float first. Plain Python floats serialise exactly as before.
_student_metrics = test_results['student']
_norm_mean = [float(v) for v in IMAGENET_MEAN]
_norm_std = [float(v) for v in IMAGENET_STD]

deployment_info = {
    'model_name': 'MobileViT Music Emotion Recognition',
    'version': '1.0.0',
    'description': 'Lightweight ViT for predicting valence and arousal from music spectrograms',
    'input_format': {
        'shape': [1, 3, VIT_IMAGE_SIZE, VIT_IMAGE_SIZE],
        'type': 'float32',
        'range': 'ImageNet normalized',
        'mean': _norm_mean,
        'std': _norm_std
    },
    'output_format': {
        'shape': [1, 2],
        'type': 'float32',
        'range': '[-1, 1]',
        'labels': ['valence', 'arousal']
    },
    'model_specs': {
        'parameters': student_params,
        # Size of the saved .pth export on disk, in MiB
        'size_mb': os.path.getsize(mobile_model_path) / 1024**2,
        'layers': 4,
        'hidden_dim': 192,
        'compression_ratio': f'{teacher_params / student_params:.1f}x'
    },
    'performance': {
        'test_ccc_valence': float(_student_metrics['ccc_valence']),
        'test_ccc_arousal': float(_student_metrics['ccc_arousal']),
        'test_ccc_avg': float(_student_metrics['ccc_avg']),
        'test_mae_valence': float(_student_metrics['mae_valence']),
        'test_mae_arousal': float(_student_metrics['mae_arousal']),
        'retention_vs_teacher': f"{ccc_retention:.1f}%"
    },
    # Static, hand-written guidance; none of these values are measured here.
    'android_requirements': {
        'min_ram': '2GB',
        'recommended_ram': '4GB',
        'min_android_version': '8.0 (API 26)',
        'pytorch_mobile_version': '1.13+',
        'estimated_inference_time': '50-100ms on Snapdragon 870'
    }
}

deployment_info_path = os.path.join(OUTPUT_DIR, 'deployment_info.json')
with open(deployment_info_path, 'w') as f:
    json.dump(deployment_info, f, indent=2)

print(f"‚úÖ Saved deployment info: {deployment_info_path}")

# 5. Create inference example
# The snippet is written to a .java file, so it has to use Java comment
# syntax (`//`, not `#`). Tensor.fromBlob takes a flat float[] plus a shape,
# so the input is declared as a flat array in NCHW order.
inference_example = '''
// Example: Using the model on Android with PyTorch Mobile

import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.IValue;

// Load model
Module model = Module.load(assetFilePath("mobile_vit_emotion_model.pt"));

// Prepare input (melspectrogram as 224x224 RGB image, flattened in NCHW order)
float[] input = preprocessMelspectrogram(melspec);  // Your preprocessing

// Create tensor
Tensor inputTensor = Tensor.fromBlob(
    input,
    new long[]{1, 3, 224, 224}
);

// Run inference
Tensor outputTensor = model.forward(IValue.from(inputTensor)).toTensor();
float[] emotions = outputTensor.getDataAsFloatArray();

// Get results
float valence = emotions[0];  // Range: [-1, 1]
float arousal = emotions[1];  // Range: [-1, 1]

// Convert to 1-9 scale if needed
float valence_scaled = (valence + 1) * 4 + 1;  // Maps [-1,1] to [1,9]
float arousal_scaled = (arousal + 1) * 4 + 1;
'''

example_path = os.path.join(OUTPUT_DIR, 'android_inference_example.java')
with open(example_path, 'w') as f:
    f.write(inference_example)

print(f"‚úÖ Saved inference example: {example_path}")

# Final human-readable summary of everything the export cell produced.
print(f"\n{'=' * 80}")
print(f"üì± ANDROID DEPLOYMENT SUMMARY")
print(f"{'=' * 80}")
print(f"Model Files:")
print(f"  ‚îú‚îÄ PyTorch:      mobile_vit_emotion_model.pth ({os.path.getsize(mobile_model_path) / 1024**2:.2f} MB)")

# The optional artifacts are located by recomputing their paths here: the
# variables holding them are assigned inside the export try-blocks and may
# never have been bound if an export failed, which would raise NameError.
_summary_torchscript_path = os.path.join(OUTPUT_DIR, 'mobile_vit_emotion_model.pt')
_summary_quantized_path = os.path.join(OUTPUT_DIR, 'mobile_vit_emotion_model_quantized.pth')
if os.path.exists(_summary_torchscript_path):
    print(f"  ‚îú‚îÄ TorchScript:  mobile_vit_emotion_model.pt ({os.path.getsize(_summary_torchscript_path) / 1024**2:.2f} MB)")
if os.path.exists(_summary_quantized_path):
    print(f"  ‚îú‚îÄ Quantized:    mobile_vit_emotion_model_quantized.pth ({os.path.getsize(_summary_quantized_path) / 1024**2:.2f} MB)")
print(f"  ‚îú‚îÄ Deployment:   deployment_info.json")
print(f"  ‚îî‚îÄ Example:      android_inference_example.java")

print(f"\nModel Specifications:")
print(f"  ‚îú‚îÄ Parameters:   {student_params:,}")
print(f"  ‚îú‚îÄ Compression:  {teacher_params / student_params:.1f}x smaller than teacher")
print(f"  ‚îú‚îÄ Performance:  {ccc_retention:.1f}% of teacher CCC")
# 4 bytes per fp32 parameter, reported in MiB
print(f"  ‚îî‚îÄ Memory:       ~{student_params * 4 / 1024**2:.0f} MB (fp32)")

print(f"\nAndroid Requirements:")
print(f"  ‚îú‚îÄ Min RAM:      2GB")
print(f"  ‚îú‚îÄ Recommended:  4GB+")
print(f"  ‚îú‚îÄ Android:      8.0+ (API 26+)")
print(f"  ‚îî‚îÄ PyTorch:      1.13+ Mobile")

print(f"\nüéØ Recommended Usage:")
print(f"  1. Use TorchScript model (.pt) for production")
print(f"  2. Use quantized model for even lower memory devices")
print(f"  3. Implement mel-spectrogram preprocessing on device")
print(f"  4. Cache model to avoid repeated loading")
print(f"  5. Run inference on background thread")

print(f"{'=' * 80}")
print(f"‚úÖ Model export complete! Ready for Android deployment.")
print(f"{'=' * 80}")

In [None]:
# Plot training curves: a 2x2 grid with total, valence and arousal losses,
# plus the validation CCC per epoch.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))


def _draw_history_panel(ax, curves, ylabel, title, reference_line=None):
    """Plot (history_key, legend_label, marker) curves on `ax` and decorate the axis."""
    for history_key, legend_label, marker in curves:
        ax.plot(history[history_key], label=legend_label, marker=marker)
    if reference_line is not None:
        ax.axhline(**reference_line)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


# Loss curves
_draw_history_panel(
    axes[0, 0],
    [('train_loss', 'Train Loss', 'o'), ('val_loss', 'Val Loss', 's')],
    ylabel='Loss',
    title='Training and Validation Loss',
)

# Valence loss
_draw_history_panel(
    axes[0, 1],
    [('train_valence_loss', 'Train Valence Loss', 'o'),
     ('val_valence_loss', 'Val Valence Loss', 's')],
    ylabel='Valence Loss',
    title='Valence Loss',
)

# Arousal loss
_draw_history_panel(
    axes[1, 0],
    [('train_arousal_loss', 'Train Arousal Loss', 'o'),
     ('val_arousal_loss', 'Val Arousal Loss', 's')],
    ylabel='Arousal Loss',
    title='Arousal Loss',
)

# CCC scores, with a dashed horizontal guide at 0.7
_draw_history_panel(
    axes[1, 1],
    [('val_valence_ccc', 'Valence CCC', 'o'), ('val_arousal_ccc', 'Arousal CCC', 's')],
    ylabel='CCC Score',
    title='Concordance Correlation Coefficient',
    reference_line=dict(y=0.7, color='r', linestyle='--', alpha=0.5, label='Good threshold'),
)

plt.tight_layout()
plt.savefig('vit_training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

print("üìä Training curves saved to 'vit_training_curves.png'")

In [None]:
def evaluate_epoch(model, dataloader, criterion, device, epoch, phase='Val'):
    """Run one gradient-free pass over `dataloader` and summarise it.

    Args:
        model: Network called as `model(batch)`; put into eval mode here.
        dataloader: Iterable of `(spectrograms, labels)` batches supporting `len()`.
        criterion: Callable returning `(total_loss, valence_loss, arousal_loss)`.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index, shown one-based in the progress bar.
        phase: Label shown in the progress bar description.

    Returns:
        Tuple `(loss, valence_loss, arousal_loss, valence_ccc, arousal_ccc)`.
        The losses are per-batch averages; the CCC values are computed over
        all collected predictions, column 0 for valence and column 1 for arousal.
    """
    model.eval()

    total_loss = total_valence = total_arousal = 0.0
    prediction_chunks = []
    label_chunks = []

    batches = tqdm(dataloader, desc=f'Epoch {epoch+1} [{phase}]')

    with torch.no_grad():
        for batch_inputs, batch_targets in batches:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            outputs = model(batch_inputs)
            loss, valence_loss, arousal_loss = criterion(outputs, batch_targets)

            # Accumulate the scalar loss terms for the epoch averages
            total_loss += loss.item()
            total_valence += valence_loss.item()
            total_arousal += arousal_loss.item()

            # Keep raw outputs and targets on the host for the CCC computation
            prediction_chunks.append(outputs.cpu().numpy())
            label_chunks.append(batch_targets.cpu().numpy())

    # Averages are taken over the number of batches, not the number of samples
    num_batches = len(dataloader)
    mean_loss = total_loss / num_batches
    mean_valence_loss = total_valence / num_batches
    mean_arousal_loss = total_arousal / num_batches

    stacked_predictions = np.concatenate(prediction_chunks, axis=0)
    stacked_labels = np.concatenate(label_chunks, axis=0)

    ccc_scores = [
        concordance_correlation_coefficient(stacked_labels[:, column], stacked_predictions[:, column])
        for column in (0, 1)
    ]

    return mean_loss, mean_valence_loss, mean_arousal_loss, ccc_scores[0], ccc_scores[1]

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over `dataloader`.

    Args:
        model: Network called as `model(batch)`; put into train mode here.
        dataloader: Iterable of `(spectrograms, labels)` batches supporting `len()`.
        criterion: Callable returning `(total_loss, valence_loss, arousal_loss)`;
            only the first term is back-propagated.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index, shown one-based in the progress bar.

    Returns:
        Tuple `(loss, valence_loss, arousal_loss)` of per-batch averages.
    """
    model.train()

    total_loss = total_valence = total_arousal = 0.0

    batches = tqdm(dataloader, desc=f'Epoch {epoch+1} [Train]')

    for step, (batch_inputs, batch_targets) in enumerate(batches, start=1):
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(batch_inputs)
        loss, valence_loss, arousal_loss = criterion(outputs, batch_targets)

        # Backward pass, with the global gradient norm clipped to 1.0
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        total_valence += valence_loss.item()
        total_arousal += arousal_loss.item()

        # Refresh the running averages in the progress bar every 10th batch
        if step % 10 == 0:
            batches.set_postfix({
                'loss': f'{total_loss / step:.4f}',
                'val_loss': f'{total_valence / step:.4f}',
                'aro_loss': f'{total_arousal / step:.4f}'
            })

    # Averages are taken over the number of batches, not the number of samples
    num_batches = len(dataloader)
    return (
        total_loss / num_batches,
        total_valence / num_batches,
        total_arousal / num_batches,
    )

## 🎓 ViT Training Loop

Now we'll train the Vision Transformer on our combined dataset (real + GAN-augmented).

In [None]:
# Archive outputs
# `!` is IPython shell-escape syntax, so this line only runs inside a
# notebook/IPython session. It recursively zips the
# /kaggle/working/vit_augmented directory into /kaggle/working/vit_output.zip.
!zip -r /kaggle/working/vit_output.zip /kaggle/working/vit_augmented
# NOTE: printed unconditionally; the exit status of `zip` is not checked.
print("‚úÖ Outputs archived to vit_output.zip")

## 🧪 Comprehensive Model Testing

Perform thorough testing of the trained ViT model including edge cases and robustness evaluation.

In [None]:
def test_model_robustness(model, test_loader, device=DEVICE):
    """Collect predictions on clean, noisy and shifted inputs.

    Only the first 10 batches of `test_loader` are used.

    Args:
        model: Network called as `model(inputs)`; put into eval mode here.
        test_loader: Iterable of batches. A batch may be a dict with
            `'pixel_values'` and `'emotions'` entries, or an
            `(inputs, targets)` sequence as used by the training and
            evaluation loops in this file.
        device: Device the batches are moved to.

    Returns:
        Dict with keys `normal_predictions`, `noisy_predictions`,
        `augmented_predictions`, `targets` and `confidence_scores`. Each value
        is a NumPy array concatenated over the processed batches, or an empty
        list if no batch was processed.
    """
    print("üß™ Testing model robustness...")

    model.eval()

    # Test results storage
    test_results = {
        'normal_predictions': [],
        'noisy_predictions': [],
        'augmented_predictions': [],
        'targets': [],
        'confidence_scores': []
    }

    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_loader, desc='Robustness Testing')):
            if i >= 10:  # Limit to first 10 batches for testing
                break

            # Accept both batch layouts (see docstring).
            if isinstance(batch, dict):
                inputs, targets = batch['pixel_values'], batch['emotions']
            else:
                inputs, targets = batch[0], batch[1]
            inputs = inputs.to(device)
            targets = targets.to(device)

            # 1. Normal prediction
            normal_output = model(inputs)

            # 2. Add Gaussian noise (std 0.1) and test.
            # Clamp to the value range observed in this batch rather than a
            # fixed [0, 1]: normalized inputs are not confined to [0, 1], and
            # a fixed clamp would distort them far more than the noise does.
            noise = torch.randn_like(inputs) * 0.1
            noisy_inputs = torch.clamp(
                inputs + noise, inputs.min().item(), inputs.max().item()
            )
            noisy_output = model(noisy_inputs)

            # 3. Circular shift by 10 positions along the last axis (not a rotation)
            augmented_inputs = torch.roll(inputs, shifts=10, dims=-1)
            augmented_output = model(augmented_inputs)

            # "Confidence" is the inverse of the variance across the output
            # dimensions of each sample.
            # NOTE(review): this measures how far apart the outputs are, not
            # predictive uncertainty; interpret with care.
            confidence = 1.0 / (torch.var(normal_output, dim=1) + 1e-6)

            # Store results
            test_results['normal_predictions'].append(normal_output.cpu())
            test_results['noisy_predictions'].append(noisy_output.cpu())
            test_results['augmented_predictions'].append(augmented_output.cpu())
            test_results['targets'].append(targets.cpu())
            test_results['confidence_scores'].append(confidence.cpu())

    # Concatenate all results (keys with no collected batches stay empty lists)
    for key in test_results:
        if test_results[key]:
            test_results[key] = torch.cat(test_results[key], dim=0).numpy()

    return test_results

def analyze_prediction_patterns(test_results):
    """Summarise and plot how predictions change under input perturbation.

    Args:
        test_results: Dict with NumPy arrays under `normal_predictions`,
            `noisy_predictions`, `augmented_predictions`, `targets` (indexed
            as `[:, 0]` valence and `[:, 1]` arousal) and `confidence_scores`.

    Returns:
        Dict with `noise_robustness` and `augmentation_robustness` (mean
        absolute difference between clean and perturbed predictions) and
        `mean_confidence`.
    """
    # "\n" (not "\\n"): a real newline, not a literal backslash-n.
    print("\nüìä Analyzing prediction patterns...")

    normal_pred = test_results['normal_predictions']
    noisy_pred = test_results['noisy_predictions']
    aug_pred = test_results['augmented_predictions']
    targets = test_results['targets']
    confidence = test_results['confidence_scores']

    # Calculate robustness metrics
    noise_robustness = np.mean(np.abs(normal_pred - noisy_pred))
    aug_robustness = np.mean(np.abs(normal_pred - aug_pred))

    print(f"üîä Noise Robustness (MAE): {noise_robustness:.4f}")
    print(f"üîÑ Augmentation Robustness (MAE): {aug_robustness:.4f}")

    # Plot robustness analysis
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    normal_error = np.abs(normal_pred - targets)
    noisy_error = np.abs(noisy_pred - targets)

    for column, (emotion, colour) in enumerate((('Valence', 'blue'), ('Arousal', 'red'))):
        clean = normal_pred[:, column]
        perturbed = noisy_pred[:, column]

        # Prediction consistency: clean vs noisy prediction per sample
        ax = axes[0, column]
        ax.scatter(clean, perturbed, alpha=0.6, color=colour)
        # Identity line drawn across the data range, so it stays meaningful
        # whatever range the predictions occupy (a fixed [0, 1] line does not
        # cover outputs in [-1, 1]).
        low = min(clean.min(), perturbed.min())
        high = max(clean.max(), perturbed.max())
        ax.plot([low, high], [low, high], 'r--', lw=2)
        ax.set_xlabel(f'Normal Prediction ({emotion})')
        ax.set_ylabel(f'Noisy Prediction ({emotion})')
        ax.set_title(f'Noise Robustness - {emotion}')
        ax.grid(True, alpha=0.3)

        # Error analysis: absolute error against the targets
        ax = axes[1, column]
        ax.hist(normal_error[:, column], bins=20, alpha=0.7, label='Normal', color='blue')
        ax.hist(noisy_error[:, column], bins=20, alpha=0.7, label='Noisy', color='orange')
        ax.set_xlabel('Absolute Error')
        ax.set_ylabel('Frequency')
        ax.set_title(f'{emotion} Error Distribution')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Prediction distribution (both output dimensions pooled)
    axes[0, 2].hist(normal_pred.flatten(), bins=30, alpha=0.7, label='Normal', color='blue')
    axes[0, 2].hist(noisy_pred.flatten(), bins=30, alpha=0.7, label='Noisy', color='orange')
    axes[0, 2].set_xlabel('Prediction Value')
    axes[0, 2].set_ylabel('Frequency')
    axes[0, 2].set_title('Prediction Distribution')
    axes[0, 2].legend()
    axes[0, 2].grid(True, alpha=0.3)

    # Confidence analysis: confidence score vs mean absolute error per sample
    axes[1, 2].scatter(confidence, normal_error.mean(axis=1), alpha=0.6, color='green')
    axes[1, 2].set_xlabel('Confidence Score')
    axes[1, 2].set_ylabel('Prediction Error')
    axes[1, 2].set_title('Confidence vs Error')
    axes[1, 2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return {
        'noise_robustness': noise_robustness,
        'augmentation_robustness': aug_robustness,
        'mean_confidence': np.mean(confidence)
    }

def test_edge_cases(model, device=DEVICE):
    """Run the model on four synthetic inputs and report its predictions.

    The inputs have shape `(1, 3, VIT_IMAGE_SIZE, VIT_IMAGE_SIZE)`: all zeros,
    all ones, Gaussian noise clamped to [0, 1], and a 0/1 checkerboard.

    Args:
        model: Network called as `model(inputs)`; put into eval mode here.
        device: Device the synthetic inputs are moved to.

    Returns:
        Dict mapping the case name (`silence`, `maximum`, `noise`,
        `checkerboard`) to the model output as a NumPy array.
    """
    # "\n" (not "\\n"): a real newline, not a literal backslash-n.
    print("\nüö® Testing edge cases...")

    model.eval()

    input_shape = (1, 3, VIT_IMAGE_SIZE, VIT_IMAGE_SIZE)

    # Checkerboard: ones where row and column index share the same parity
    checkerboard = torch.zeros(input_shape)
    checkerboard[:, :, ::2, ::2] = 1
    checkerboard[:, :, 1::2, 1::2] = 1

    probe_inputs = {
        'silence': torch.zeros(input_shape),      # all zeros
        'maximum': torch.ones(input_shape),       # all ones
        'noise': torch.clamp(torch.randn(input_shape), 0, 1),
        'checkerboard': checkerboard,
    }

    edge_cases = {}
    with torch.no_grad():
        for case_name, probe in probe_inputs.items():
            edge_cases[case_name] = model(probe.to(device)).cpu().numpy()

    print("Edge case predictions:")
    for case, pred in edge_cases.items():
        # The first (only) row holds the two outputs: valence, arousal
        valence, arousal = pred[0]
        print(f"  {case:12}: Valence={valence:.3f}, Arousal={arousal:.3f}")

    return edge_cases

def performance_benchmark(model, test_loader, device=DEVICE):
    """Time forward passes on random inputs for several batch sizes.

    Args:
        model: Network called as `model(inputs)`; put into eval mode here.
        test_loader: Unused; kept so existing callers keep working.
        device: `torch.device` or device string the inputs are created on.

    Returns:
        Dict with `batch_sizes` (list of int) and `inference_times` (average
        seconds per forward pass for each batch size).
    """
    import time

    # "\n" (not "\\n"): a real newline, not a literal backslash-n.
    print("\n‚ö° Performance benchmarking...")

    device = torch.device(device)  # also accepts plain strings such as 'cuda'
    on_cuda = device.type == 'cuda'
    runs_per_batch_size = 10

    def _wait_for_gpu():
        """Block until queued CUDA work is done, so the timings cover it."""
        if on_cuda:
            torch.cuda.synchronize(device)

    model.eval()

    # Reset the peak-memory counter so the figure reported below reflects
    # this benchmark rather than earlier work in the same process.
    if on_cuda:
        torch.cuda.reset_peak_memory_stats(device)

    times = []
    batch_sizes = [1, 4, 8, 16]

    # The whole benchmark runs without autograd, including the warm-up.
    with torch.no_grad():
        # Warmup
        dummy_input = torch.randn(1, 3, VIT_IMAGE_SIZE, VIT_IMAGE_SIZE).to(device)
        for _ in range(5):
            model(dummy_input)

        for batch_size in batch_sizes:
            test_input = torch.randn(batch_size, 3, VIT_IMAGE_SIZE, VIT_IMAGE_SIZE).to(device)

            # Measure inference time with a monotonic, high-resolution clock
            _wait_for_gpu()
            start_time = time.perf_counter()

            for _ in range(runs_per_batch_size):
                model(test_input)

            _wait_for_gpu()
            elapsed = time.perf_counter() - start_time

            avg_time = elapsed / runs_per_batch_size
            times.append(avg_time)

            print(f"  Batch size {batch_size:2d}: {avg_time:.4f}s ({batch_size/avg_time:.1f} samples/s)")

    # Memory usage
    if on_cuda:
        memory_usage = torch.cuda.max_memory_allocated(device) / 1024**2  # MB
        print(f"  Max GPU memory: {memory_usage:.1f} MB")

    return {'batch_sizes': batch_sizes, 'inference_times': times}

# Run comprehensive testing
# Runs only when both the trained model and the test loader exist in this
# session. At notebook top level locals() is the global namespace.
if 'best_vit_model' in locals() and 'test_loader' in locals():
    print("üöÄ Starting comprehensive ViT model testing...")

    # Load best model (best effort: fall back to the in-memory weights)
    best_model_path = '/kaggle/working/vit_augmented/best_vit_model.pth'
    if os.path.exists(best_model_path):
        try:
            checkpoint = torch.load(best_model_path, map_location=DEVICE)
            # Accept a bare state_dict or a checkpoint dict that wraps it
            # under 'model_state_dict'.
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                checkpoint = checkpoint['model_state_dict']
            best_vit_model.load_state_dict(checkpoint)
            print("‚úÖ Best model loaded for testing")
        except Exception as load_error:
            # Report the reason instead of hiding it behind a bare `except:`.
            print("‚ö†Ô∏è Using current model state for testing")
            print(f"   Reason: {load_error}")
    else:
        print("‚ö†Ô∏è Using current model state for testing")
        print(f"   Reason: checkpoint not found at {best_model_path}")

    # 1. Robustness testing
    # NOTE: this rebinds the global name `test_results`.
    test_results = test_model_robustness(best_vit_model, test_loader)
    robustness_metrics = analyze_prediction_patterns(test_results)

    # 2. Edge case testing
    edge_results = test_edge_cases(best_vit_model)

    # 3. Performance benchmarking
    perf_results = performance_benchmark(best_vit_model, test_loader)

    # Summary report ("\n" is a real newline, not a literal backslash-n)
    print("\n" + "="*60)
    print("üìã COMPREHENSIVE TESTING SUMMARY")
    print("="*60)
    print(f"‚úÖ Robustness Testing:")
    print(f"   - Noise Robustness: {robustness_metrics['noise_robustness']:.4f}")
    print(f"   - Augmentation Robustness: {robustness_metrics['augmentation_robustness']:.4f}")
    print(f"   - Mean Confidence: {robustness_metrics['mean_confidence']:.4f}")
    print(f"\n‚úÖ Edge Cases: All {len(edge_results)} test cases completed")
    print(f"\n‚úÖ Performance: Benchmarked across {len(perf_results['batch_sizes'])} batch sizes")
    print("\nüéâ All tests completed successfully!")
    print("="*60)

else:
    print("‚ö†Ô∏è Skipping comprehensive testing - model or test data not available")
    print("Please ensure the model is trained and test data is prepared.")