# Audio Spectrogram Transformer (AST) for Music Emotion Recognition

## Overview
This notebook trains an Audio Spectrogram Transformer (AST) model to predict **valence** and **arousal** from music using the **DEAM (Database for Emotional Analysis of Music)** dataset.

### What is AST?
Audio Spectrogram Transformer is a Vision Transformer (ViT) architecture adapted for audio classification. It:
- Converts audio to mel-spectrograms
- Splits spectrograms into patches
- Uses self-attention to learn temporal and spectral patterns
- Predicts emotional dimensions (valence and arousal)

### Dataset Requirements
This notebook expects the DEAM dataset to be available as a Kaggle dataset. 
**Dataset structure:**
```
/kaggle/input/deam-mediaeval-dataset-emotional-analysis-in-music/
├── DEAM_audio/
│   └── MEMD_audio/  (contains .mp3 files)
└── DEAM_Annotations/
    └── annotations/
        └── annotations averaged per song/
            └── song_level/
                ├── static_annotations_averaged_songs_1_2000.csv
                └── static_annotations_averaged_songs_2000_2058.csv
```

### Training Configuration
- **Model**: Vision Transformer adapted for audio
- **Input**: Mel-spectrograms (128 mel bins × 432 time steps)
- **Output**: Valence and Arousal predictions (1-9 scale)
- **Loss**: Mean Squared Error (MSE)
- **Metrics**: MSE, MAE, Concordance Correlation Coefficient (CCC)

# 🎓 Complete Educational Guide: Audio Spectrogram Transformer (AST)

---

## 📚 **Table of Contents**
1. [Fundamental Concepts](#fundamentals)
2. [Audio Processing Basics](#audio)
3. [Machine Learning Foundations](#ml)
4. [Transformer Architecture](#transformers)
5. [Training Process](#training)
6. [Evaluation Metrics](#metrics)

---

## 🌟 **What You'll Learn**

By the end of this notebook, you'll understand:
- How computers "see" and understand music
- What spectrograms are and why they matter
- How attention mechanisms work (the key to modern AI)
- How transformers process sequential data
- The mathematics behind emotion prediction
- How to train and evaluate deep learning models

**No Prerequisites Required!** We explain everything from first principles.

---

<a name="fundamentals"></a>
## 🎯 **Part 1: Fundamental Concepts**

### **What is Sound?**

**Simple Analogy**: Imagine throwing a stone into a pond. You see ripples spreading outward—these are **waves**. Sound works the same way, but instead of water, it's air molecules vibrating.

**Technical Definition**:
Sound is a **pressure wave** that travels through air (or any medium). When you pluck a guitar string:
1. The string vibrates back and forth
2. It pushes air molecules, creating regions of high pressure (compression) and low pressure (rarefaction)
3. These pressure changes travel through the air as a wave
4. Your ear detects these pressure changes as sound

---

### **Key Sound Properties**

#### 1. **Frequency (Pitch)** 🎵

**What it is**: How many times per second the wave completes one full cycle.
- **Unit**: Hertz (Hz) = cycles per second
- **Example**: 
  - A4 (the A above middle C) on a piano = 440 Hz (vibrates 440 times/second)
  - Human hearing range: ~20 Hz to 20,000 Hz

**Analogy**: Think of a swing in a playground:
- **Fast swinging** (high frequency) = high pitch (like a whistle)
- **Slow swinging** (low frequency) = low pitch (like a bass drum)

**Mathematical Representation**:
A pure tone (sine wave) at frequency *f*:
```
y(t) = A × sin(2πft)
```
where:
- **A** = amplitude (volume/loudness)
- **f** = frequency (pitch)
- **t** = time
- **2π** = one complete circle (360 degrees in radians)
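
As a rough illustration, the sine-wave formula can be sampled in a few lines of Python. The helper name `pure_tone` and the 8 kHz sample rate are assumptions chosen for this sketch, not values taken from the notebook's training code.

```python
import math

def pure_tone(freq_hz, amplitude=1.0, duration_s=0.01, sample_rate=8000):
    """Sample y(t) = A * sin(2*pi*f*t) at evenly spaced times t = n / sample_rate."""
    n_samples = round(duration_s * sample_rate)
    return [amplitude * math.sin(2 * math.pi * freq_hz * n / sample_rate)
            for n in range(n_samples)]

# Assumed demo values: 10 ms of a 440 Hz tone sampled at 8 kHz -> 80 samples
tone = pure_tone(440.0)
print(len(tone), max(tone))   # the peak never exceeds the amplitude A
```

Raising `freq_hz` packs more cycles into the same 10 ms, and raising `amplitude` stretches the wave vertically.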

---

#### 2. **Amplitude (Loudness)** 🔊

**What it is**: The "height" of the wave—how much the air pressure changes.
- **High amplitude** = loud sound (strong pressure changes)
- **Low amplitude** = quiet sound (weak pressure changes)

**Analogy**: Ocean waves:
- **Tsunami** (huge wave) = very loud sound
- **Ripples** (tiny waves) = quiet sound

**Measurement**: Often measured in **decibels (dB)**:
- 0 dB = threshold of hearing (barely audible)
- 60 dB = normal conversation
- 120 dB = rock concert (painful!)

---

#### 3. **Timbre (Tone Quality)** 🎸

**What it is**: Why a piano sounds different from a guitar even when playing the same note.

**The Secret**: Real-world sounds are NOT pure sine waves. They're **complex combinations** of multiple frequencies:
- **Fundamental frequency**: The main pitch you hear
- **Overtones/Harmonics**: Additional frequencies that give character

**Example**: When you play middle C on a piano:
- **Fundamental**: 261.6 Hz (the note you identify)
- **2nd harmonic**: 523.2 Hz (2× fundamental)
- **3rd harmonic**: 784.8 Hz (3× fundamental)
- ...and many more!

The **mix** of these harmonics creates the piano's unique sound.

**Mathematical Representation** (Fourier Series):
Any complex sound can be decomposed into a sum of simple sine waves:
```
y(t) = A₁sin(2πf₁t) + A₂sin(2πf₂t) + A₃sin(2πf₃t) + ...
```

This is why we need **spectrograms**—to see all these frequencies at once!

---

### **Why Computers Need Special Representation**

**The Problem**: 
- Sound waves are continuous (infinite points in time)
- Computers work with discrete numbers (finite data)

**The Solution**: **Digital Audio**

1. **Sampling**: Measure the wave's amplitude at regular intervals
   - **Sample Rate**: How many measurements per second
   - CD quality: 44,100 Hz (44,100 samples/second)
   - Why 44.1k? Nyquist theorem says you need 2× the highest frequency you want to capture
   - Humans hear up to ~20kHz → need 40kHz sample rate (44.1 gives buffer)

2. **Quantization**: Convert each measurement to a number
   - **Bit Depth**: How precise each measurement is
   - 16-bit = 65,536 possible values (CD quality)
   - 24-bit = 16.7 million possible values (professional audio)

**Analogy**: It's like taking photographs of a moving car:
- **Frame rate** (sample rate) = how many photos per second
- **Camera resolution** (bit depth) = detail in each photo
- Fast frame rate + high resolution = smooth, detailed "recording" of motion
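
The two steps can be sketched together: sample a tone at regular intervals, then round each sample to an integer code. The `quantize` helper is hypothetical and assumes input samples already lie in the range -1.0 to 1.0.

```python
import math

def quantize(sample, bit_depth=16):
    """Map a float in [-1.0, 1.0] to a signed integer code (hypothetical helper)."""
    max_code = 2 ** (bit_depth - 1) - 1        # 32767 for 16-bit audio
    clipped = max(-1.0, min(1.0, sample))      # guard against out-of-range input
    return round(clipped * max_code)

sample_rate = 44_100                           # CD-quality sample rate
samples = [math.sin(2 * math.pi * 440 * n / sample_rate) for n in range(5)]
codes = [quantize(s) for s in samples]
print(2 ** 16, codes)                          # 65536 levels; the first code is 0
```

A higher bit depth shrinks the rounding error of each code, while a higher sample rate takes the measurements closer together in time.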

---

### **From Sound Waves to Spectrograms** 📊

**The Journey**:
```
Raw Audio Wave → Digital Samples → Frequency Analysis → Spectrogram → AI Model → Emotion Prediction
```

**Why Not Use Raw Audio Directly?**

1. **Too Much Information**: 5 seconds of audio at 44.1kHz = 220,500 numbers!
2. **Wrong Representation**: AI needs to "see" patterns. A wave is hard to interpret.
3. **Frequency is Key**: Musical emotion comes from harmony, melody, and rhythm, which are patterns of frequency unfolding over time.

**The Solution: Spectrograms**

Think of a spectrogram as "sheet music for computers":
- **X-axis**: Time (when things happen)
- **Y-axis**: Frequency (what notes are playing)
- **Color/Brightness**: Amplitude (how loud each frequency is)

**Analogy**: Imagine watching a piano performance from above:
- You see which keys are pressed (frequency)
- You see when they're pressed (time)
- You see how hard they're pressed (amplitude/color intensity)

This 2D image captures which frequencies sound, when they sound, and how strongly. Only the phase of each frequency is discarded.

<a name="audio"></a>
## 🎵 **Part 2: Audio Processing Deep Dive**

### **The Fourier Transform: From Time to Frequency** 🔬

**The Big Idea**: Any complex sound can be broken down into simple sine waves of different frequencies.

**Analogy**: Imagine a smoothie:
- **Time domain** (raw audio) = the mixed smoothie (you see the final blend)
- **Frequency domain** (after Fourier transform) = individual ingredients (strawberries, bananas, yogurt)

**The Mathematics**:

The **Fourier Transform** converts time-domain signal x(t) to frequency-domain X(f):

```
X(f) = ∫_{-∞}^{∞} x(t) × e^(-i2πft) dt
```

**What this means**:
- **x(t)**: Your audio signal over time
- **X(f)**: How much of each frequency *f* is present
- **e^(-i2πft)**: A complex exponential (represents rotation—relates to sine/cosine)

**For computers (Discrete Fourier Transform - DFT)**:
```
X[k] = Σ_{n=0}^{N-1} x[n] × e^(-i2πkn/N)
```
where:
- **n**: Sample index (time)
- **k**: Frequency bin
- **N**: Total number of samples

**In practice**: We use **Fast Fourier Transform (FFT)**—a clever algorithm that computes this in O(N log N) time instead of O(N²)!
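
To make the DFT formula concrete, here is a direct O(N²) evaluation in plain Python. It is a teaching sketch with toy sizes; real pipelines would call an FFT routine such as `numpy.fft.fft` instead.

```python
import cmath
import math

def dft(x):
    """Direct evaluation of X[k] = sum_n x[n] * exp(-i*2*pi*k*n/N)."""
    n_total = len(x)
    return [sum(x[n] * cmath.exp(-2j * cmath.pi * k * n / n_total)
                for n in range(n_total))
            for k in range(n_total)]

# Toy signal (assumed for the demo): 4 cycles plus a weaker harmonic with 8 cycles
N = 64
signal = [math.sin(2 * math.pi * 4 * n / N) + 0.5 * math.sin(2 * math.pi * 8 * n / N)
          for n in range(N)]

magnitudes = [abs(c) for c in dft(signal)[: N // 2]]
strongest = sorted(range(N // 2), key=lambda k: magnitudes[k], reverse=True)[:2]
print(strongest)   # [4, 8]: the bins of the fundamental and its harmonic
```

The two largest magnitudes sit exactly at the bins matching the two sine components, which is the "ingredients of the smoothie" idea in numbers.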

---

### **The Short-Time Fourier Transform (STFT)** ⏱️

**The Problem**: Music changes over time! A single Fourier Transform tells you what frequencies exist, but not **when** they occur.

**The Solution**: STFT = "Take many small Fourier Transforms"

**How it Works**:
1. **Window**: Take a small chunk of audio (e.g., 25 milliseconds)
2. **Transform**: Apply FFT to that chunk
3. **Slide**: Move forward slightly (e.g., 10ms) and repeat
4. **Stack**: Arrange all results side-by-side → creates a spectrogram!

**Mathematical Definition**:
```
STFT{x[n]}(m, ω) = Σ_{n=-∞}^{∞} x[n] × w[n - m] × e^(-iωn)
```
where:
- **m**: Time frame index
- **ω**: Angular frequency
- **w[n]**: Window function (usually Hann or Hamming window)

**Parameters in This Notebook**:
- **n_fft = 2048**: FFT window size (46ms at 44.1kHz)
  - Determines frequency resolution
  - Larger = better frequency detail, worse time detail
- **hop_length = 512**: How much we slide the window (11.6ms)
  - Smaller = more time detail (but more computation)
  - Relation: `time_steps ≈ total_samples / hop_length` (5 s × 44,100 / 512 ≈ 431 frames, in line with the 432 time steps used here)
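
The numbers quoted for these parameters follow from simple arithmetic. The frame-count line assumes centred, edge-padded frames (the default behaviour of `librosa.stft`); other framing conventions give slightly different counts.

```python
sample_rate = 44_100      # Hz, as stated in this notebook
n_fft = 2048
hop_length = 512
clip_seconds = 5

window_ms = 1000 * n_fft / sample_rate        # about 46.4 ms of audio per window
hop_ms = 1000 * hop_length / sample_rate      # about 11.6 ms between frames
freq_resolution_hz = sample_rate / n_fft      # about 21.5 Hz per FFT bin
n_freq_bins = n_fft // 2 + 1                  # 1025 non-negative frequency bins

total_samples = clip_seconds * sample_rate    # 220,500 samples
# Assumption: centred frames with padding at both edges
n_frames = 1 + total_samples // hop_length    # 431 frames

print(round(window_ms, 1), round(hop_ms, 1), round(freq_resolution_hz, 1), n_frames)
```

Doubling `n_fft` halves the width of each frequency bin but doubles the stretch of time each window smears together, which is the trade-off described in the parameter list.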

---

### **Mel-Spectrograms: Matching Human Hearing** 👂

**The Problem**: Humans don't hear frequencies linearly!
- We're very sensitive to differences between 100 Hz and 200 Hz
- We barely notice differences between 10,000 Hz and 10,100 Hz

**The Solution**: **Mel Scale** (from "melody")

**The Mel Scale Formula**:
```
mel(f) = 2595 × log₁₀(1 + f/700)
```

**Inverse (Mel to Hz)**:
```
f(mel) = 700 × (10^(mel/2595) - 1)
```

**What This Does**:
- Low frequencies: Spread out more (more detail)
- High frequencies: Compressed together (less detail)
- Matches how the human cochlea (inner ear) processes sound!

**Example**:
- 100 Hz → 150 mel
- 200 Hz → 283 mel (133 mel difference)
- 10,000 Hz → 3073 mel
- 10,100 Hz → 3084 mel (only about 10 mel difference!)
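
The two formulas translate directly into code, which makes it easy to check examples like these. The function names `hz_to_mel` and `mel_to_hz` are chosen for this sketch.

```python
import math

def hz_to_mel(f_hz):
    """mel(f) = 2595 * log10(1 + f / 700)"""
    return 2595.0 * math.log10(1.0 + f_hz / 700.0)

def mel_to_hz(mel):
    """Inverse mapping: f = 700 * (10^(mel / 2595) - 1)"""
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)

# Same 100 Hz step, very different perceptual distance
for low, high in [(100, 200), (10_000, 10_100)]:
    gap = hz_to_mel(high) - hz_to_mel(low)
    print(f"{low} Hz to {high} Hz spans {gap:.1f} mel")
```

The same 100 Hz step covers more than ten times as many mels at the low end as at the high end, which is why mel bins are narrow for bass and wide for treble.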

**Creating Mel-Spectrograms**:

1. **Compute STFT**: Get frequency spectrum over time
2. **Power Spectrum**: Square the magnitude
   ```
   P[k, t] = |X[k, t]|²
   ```
3. **Mel Filter Banks**: Apply triangular filters spaced on the Mel scale
   ```
   M[m, t] = Σ_{k} H_m[k] × P[k, t]
   ```
   where H_m[k] are mel-spaced triangular filters

4. **Log Compression**: Convert to decibels
   ```
   M_dB[m, t] = 10 × log₁₀(M[m, t])
   ```
   Why log? Human perception of loudness is logarithmic!

**In This Notebook**:
- **n_mels = 128**: We create 128 mel-frequency bins
- Covers frequency range from **fmin = 20 Hz** to **fmax = 8000 Hz**
- Each bin captures energy in a specific frequency range

---

### **Normalization: Making Data ML-Friendly** 📐

**Why Normalize?**

1. **Different scales**: Some frequencies naturally louder than others
2. **Training stability**: Neural networks work best with data centered around 0
3. **Faster convergence**: Gradients flow better through normalized data

**Decibel (dB) Conversion**:
```
mel_spec_dB = 10 × log₁₀(mel_spec)
```
or with librosa:
```python
mel_spec_dB = librosa.power_to_db(mel_spec, ref=np.max)
```

**What `ref=np.max` means**:
- Scale relative to the maximum value in the spectrogram
- 0 dB = loudest point
- Negative values = quieter than maximum

**Z-Score Normalization** (used in training):
```
x_normalized = (x - mean) / std
```
where:
- **mean**: Average value across all spectrograms
- **std**: Standard deviation (measures spread)

**Result**: Data centered at 0 with standard deviation of 1

**Analogy**: Converting temperatures to a standard scale:
- Raw: Some in Celsius, some in Fahrenheit (hard to compare)
- Normalized: All in the same scale with 0 as average (easy to compare)
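
Both normalization steps can be sketched with NumPy. The helper names and the `floor` value are assumptions for illustration; `librosa.power_to_db` additionally clips the dynamic range via its `top_db` argument, which this sketch leaves out.

```python
import numpy as np

def power_to_db_rel_max(power_spec, floor=1e-10):
    """10 * log10(power / max power); `floor` avoids log(0) (assumed value)."""
    power_spec = np.maximum(power_spec, floor)
    return 10.0 * np.log10(power_spec / power_spec.max())

def zscore(x, mean, std):
    """Standardize using statistics computed beforehand (e.g. on the training set)."""
    return (x - mean) / std

rng = np.random.default_rng(0)
mel_power = rng.random((128, 432)) + 1e-3     # random stand-in for a real mel-spectrogram
mel_db = power_to_db_rel_max(mel_power)
mel_norm = zscore(mel_db, mel_db.mean(), mel_db.std())

print(mel_db.max(), round(float(mel_norm.mean()), 6), round(float(mel_norm.std()), 6))
```

In training, the mean and standard deviation would normally be computed once over the training spectrograms and reused for validation and test data, so that all splits share the same scale.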

---

### **SpecAugment: Teaching with Incomplete Information** 🎭

**The Concept**: Data augmentation creates altered copies of the training examples, so the model sees more variety and overfits less.

**SpecAugment** applies two types of masking:

#### **1. Frequency Masking** 🎚️
```python
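# Hide a random band of mel bins (spectrogram shape: n_mels × time_steps)
import numpy as np
f = np.random.randint(0, freq_mask_param + 1)  # band height, e.g. 0-10 bins
f0 = np.random.randint(0, n_mels - f + 1)      # starting bin
spectrogram[f0:f0+f, :] = 0                    # set to zero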
```

**Analogy**: Covering a random row of keys on a piano—the model learns to recognize music even without certain notes.

#### **2. Time Masking** ⏰
```python
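# Hide a random span of frames (spectrogram shape: n_mels × time_steps)
import numpy as np
t = np.random.randint(0, time_mask_param + 1)  # span width, e.g. 0-20 frames
t0 = np.random.randint(0, time_steps - t + 1)  # starting frame
spectrogram[:, t0:t0+t] = 0                    # set to zero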
```

**Analogy**: Muting random moments in a song—the model learns to predict emotions even with gaps.

**Why This Works**:
- Forces model to use **all** parts of the spectrogram
- Prevents **overfitting** (memorizing training data)
- Improves **generalization** (works on new, unseen music)

**Research Origin**: Proposed by Google for speech recognition (Park et al., 2019)

<a name="ml"></a>
## 🤖 **Part 3: Machine Learning Foundations**

### **What is Machine Learning?** 🧠

**Traditional Programming**:
```
Rules + Data → Answers
```
Example: "If temperature > 30°C, say 'hot'"

**Machine Learning**:
```
Data + Answers → Rules
```
Example: Show 1000 examples of temperatures labeled "hot"/"cold" → computer learns the pattern!

---

### **Neural Networks: Inspired by the Brain** 🧬

**The Biological Analogy**:
- **Neurons**: Brain cells that process information
- **Synapses**: Connections between neurons (can be strong or weak)
- **Learning**: Strengthening/weakening connections based on experience

**Artificial Neural Networks**:

#### **The Artificial Neuron** (Perceptron)

```
     x₁ ──→ [w₁] ──┐
     x₂ ──→ [w₂] ──┤
     x₃ ──→ [w₃] ──→ Σ ──→ [activation] ──→ output
     ...            │
     xₙ ──→ [wₙ] ──┘
            +[bias]
```

**Mathematical Formula**:
```
output = activation(Σᵢ(wᵢ × xᵢ) + b)
```
where:
- **xᵢ**: Input features
- **wᵢ**: Weights (learned parameters)
- **b**: Bias (learned offset)
- **activation**: Non-linear function

**Step-by-Step Example**:

Suppose we're deciding if a sound is "happy":
```
Inputs:
  x₁ = tempo (fast/slow)           = 0.8
  x₂ = pitch (high/low)            = 0.6
  x₃ = energy (loud/quiet)         = 0.7

Weights (learned by training):
  w₁ = 2.0  (fast tempo → happy)
  w₂ = 1.5  (high pitch → happy)
  w₃ = 1.2  (high energy → happy)
  b  = -1.0 (bias)

Computation:
  z = (2.0 × 0.8) + (1.5 × 0.6) + (1.2 × 0.7) + (-1.0)
    = 1.6 + 0.9 + 0.84 - 1.0
    = 2.34

Activation (Sigmoid):
  output = 1 / (1 + e^(-2.34))
         ≈ 0.91  → 91% confident it's happy!
```
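
The same worked example, written as a small function so the arithmetic can be re-run with other inputs. The name `neuron` is just a label for this sketch.

```python
import math

def neuron(inputs, weights, bias):
    """Weighted sum of the inputs plus bias, squashed by a sigmoid."""
    z = sum(w * x for w, x in zip(weights, inputs)) + bias
    return z, 1.0 / (1.0 + math.exp(-z))

# Tempo, pitch and energy values from the example above
z, happy_score = neuron(inputs=[0.8, 0.6, 0.7], weights=[2.0, 1.5, 1.2], bias=-1.0)
print(round(z, 2), round(happy_score, 2))   # 2.34 0.91
```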

---

### **Activation Functions: Adding Non-Linearity** 📈

**Why Needed?**: Without activation functions, multiple layers = one layer! We need non-linearity to learn complex patterns.

#### **Common Activation Functions**:

**1. ReLU (Rectified Linear Unit)** - Most Popular
```
f(x) = max(0, x)
```
- **Graph**: Flat at 0, then linear diagonal
- **Behavior**: Blocks negative values, passes positive values
- **Used**: Hidden layers in most modern networks

**2. GELU (Gaussian Error Linear Unit)** - Used in Transformers
```
f(x) = x × Φ(x)
where Φ(x) = cumulative distribution function of standard normal
```
- **Smooth version** of ReLU
- **Why better**: Gradients flow better, fewer "dead neurons"
- **Used**: Our AST model uses GELU!

**3. Sigmoid** - Classic
```
f(x) = 1 / (1 + e^(-x))
```
- **Range**: (0, 1)
- **Used**: Binary classification, old networks

**4. Tanh** (Hyperbolic Tangent)
```
f(x) = (e^x - e^(-x)) / (e^x + e^(-x))
```
- **Range**: (-1, 1)
- **Better** than sigmoid (zero-centered)

**Comparison**:
```
Input:  -2    -1     0     1     2
ReLU:    0     0     0     1     2
GELU:  -0.05 -0.16   0   0.84  1.95
Sigmoid: 0.12  0.27  0.5  0.73  0.88
Tanh:   -0.96 -0.76  0   0.76  0.96
```
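
The comparison table can be reproduced with the standard library alone. GELU is written here in its exact form, using the error function to express Φ(x); many frameworks also offer a faster tanh-based approximation.

```python
import math

def relu(x):
    return max(0.0, x)

def gelu(x):
    # Exact form x * Phi(x), with the normal CDF Phi expressed through erf
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

for x in (-2.0, -1.0, 0.0, 1.0, 2.0):
    print(f"{x:5.1f}  relu={relu(x):.2f}  gelu={gelu(x):.2f}  "
          f"sigmoid={sigmoid(x):.2f}  tanh={math.tanh(x):.2f}")
```

Note how GELU lets small negative values through (for example about -0.16 at x = -1) where ReLU outputs exactly 0.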

---

### **Deep Learning: Stacking Layers** 🏗️

**A Deep Neural Network**:
```
Input → [Layer 1] → [Layer 2] → [Layer 3] → ... → [Output]
         Hidden      Hidden       Hidden
```

Each layer learns **progressively abstract features**:

**Example: Image Recognition**:
- **Layer 1**: Edges and corners
- **Layer 2**: Simple shapes (circles, rectangles)
- **Layer 3**: Object parts (eyes, wheels, leaves)
- **Output**: Complete objects (cat, car, tree)

**For Audio/Music**:
- **Layer 1**: Basic frequency patterns, onsets
- **Layer 2**: Musical motifs, chord progressions
- **Layer 3**: Melodic patterns, rhythm structures
- **Output**: Emotional valence and arousal

---

### **Loss Functions: Measuring Mistakes** 📉

**The Concept**: How do we tell the model it's wrong? We need a **loss function** (also called cost or error function).

**Goal**: **Minimize** the loss = make better predictions

#### **Mean Squared Error (MSE)** - Used in This Notebook

**Formula**:
```
MSE = (1/N) × Σᵢ₌₁ᴺ (yᵢ - ŷᵢ)²
```
where:
- **N**: Number of predictions
- **yᵢ**: True value (actual emotion rating)
- **ŷᵢ**: Predicted value (model's guess)

**Example**:
```
True valence:      [7, 3, 8, 5, 6]
Predicted valence: [6, 4, 7, 5, 5]
Errors:            [1,-1, 1, 0, 1]
Squared errors:    [1, 1, 1, 0, 1]
MSE = (1+1+1+0+1) / 5 = 0.8
```

**Why Square?**:
1. **Positive errors**: -2 error same as +2 error (both bad)
2. **Penalizes large errors**: Error of 2 = 4 penalty, error of 4 = 16 penalty
3. **Smooth gradients**: Easy to differentiate mathematically

**Alternative: Mean Absolute Error (MAE)**:
```
MAE = (1/N) × Σᵢ₌₁ᴺ |yᵢ - ŷᵢ|
```
- Simpler: Just average of absolute errors
- **Less sensitive** to outliers than MSE
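
A short sketch shows both losses on the example above, then repeats the comparison with one deliberately bad prediction (an invented value) to show how differently the two react to an outlier.

```python
def mse(y_true, y_pred):
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def mae(y_true, y_pred):
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

true_valence = [7, 3, 8, 5, 6]
pred_valence = [6, 4, 7, 5, 5]
print(mse(true_valence, pred_valence), mae(true_valence, pred_valence))   # 0.8 0.8

# Made-up variant: the last prediction misses by 5 instead of 1
pred_outlier = [6, 4, 7, 5, 1]
print(mse(true_valence, pred_outlier), mae(true_valence, pred_outlier))   # 5.6 1.6
```

The single large miss multiplies MSE by seven but only doubles MAE, which is what "less sensitive to outliers" means in practice.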

---

### **Backpropagation: Learning from Mistakes** 🔄

**The Learning Process**:

1. **Forward Pass**: Input → compute output → calculate loss
2. **Backward Pass**: Loss → compute gradients → update weights
3. **Repeat**: Until model performs well

**The Mathematics** (Chain Rule):

For a simple network: `Input → Layer1 → Layer2 → Output → Loss`

**Chain Rule**:
```
∂Loss/∂w₁ = ∂Loss/∂Output × ∂Output/∂Layer2 × ∂Layer2/∂Layer1 × ∂Layer1/∂w₁
```

**Intuition**: "If I change this weight slightly, how much does the loss change?"

**Gradient Descent Update**:
```
w_new = w_old - learning_rate × ∂Loss/∂w
```

**Analogy**: Hiking down a mountain in fog:
- **Loss**: Your altitude (want to minimize = reach valley)
- **Gradient**: The slope direction (steepest descent)
- **Learning rate**: How big your steps are
  - Too large: You might overshoot the valley
  - Too small: Takes forever to descend
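
The update rule and the effect of the learning rate can be tried on a toy problem. The loss (w - 3)² and both learning rates are assumptions picked for the demonstration.

```python
def gradient_descent(grad_fn, w0, learning_rate, steps):
    """Repeat w <- w - lr * dLoss/dw and return every value visited."""
    w, path = w0, [w0]
    for _ in range(steps):
        w = w - learning_rate * grad_fn(w)
        path.append(w)
    return path

def grad(w):
    # Gradient of the toy loss (w - 3)^2, whose minimum is at w = 3
    return 2.0 * (w - 3.0)

good = gradient_descent(grad, w0=0.0, learning_rate=0.1, steps=50)
too_big = gradient_descent(grad, w0=0.0, learning_rate=1.1, steps=50)

print(round(good[-1], 4))               # settles very close to 3
print(abs(too_big[-1] - 3.0) > 1e3)     # True: each step overshoots further
```

With the small step size the distance to the valley shrinks by 20% per step; with the oversized one every step jumps past the minimum and lands further away than before.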

---

### **Optimizers: Smart Ways to Update Weights** ⚡

#### **1. SGD (Stochastic Gradient Descent)** - Basic
```
w = w - lr × gradient
```
- **Stochastic**: Use small batches instead of all data (faster!)
- **Problem**: Can get stuck in local minima

#### **2. Momentum** - Add Speed
```
v = β × v + gradient        # velocity
w = w - lr × v
```
- **Idea**: Build momentum like a rolling ball
- **β**: Momentum coefficient (usually 0.9)
- **Advantage**: Accelerates in consistent directions, dampens oscillations

#### **3. Adam (Adaptive Moment Estimation)** - Most Popular
```
m = β₁ × m + (1 - β₁) × gradient           # First moment (mean)
v = β₂ × v + (1 - β₂) × gradient²          # Second moment (variance)
w = w - lr × m / (√v + ε)
```
- **Combines**: Momentum + adaptive learning rates
- **β₁**: Usually 0.9 (momentum)
- **β₂**: Usually 0.999 (RMSprop component)
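- **ε**: Small constant for numerical stability (10⁻⁸)
- **Bias correction**: Because m and v start at 0, implementations divide them by (1 - β₁ᵗ) and (1 - β₂ᵗ) before the update (t = step count); the formulas above leave this out for brevity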

#### **4. AdamW** - Used in This Notebook! 🌟
```
Same as Adam, but with weight decay:
w = w - lr × (m / (√v + ε) + λ × w)
```
- **λ**: Weight decay coefficient (0.01 in our code)
- **Purpose**: **Regularization** (prevents overfitting)
- **Effect**: Keeps weights small, forces model to learn simpler patterns

**Why AdamW is Better Than Adam**:
- Decouples weight decay from gradient updates
- Improves generalization
- Standard choice for transformer models
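
For readers who want to see the moving parts, here is one AdamW update for a single scalar weight, including the bias-correction terms that real implementations apply. It is a hand-written sketch, not the code of `torch.optim.AdamW`; the toy loss w² is an assumption.

```python
import math

def adamw_step(w, grad, m, v, t, lr=1e-4, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar weight; t counts steps starting at 1."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2     # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias correction for the zero start
    v_hat = v / (1 - beta2 ** t)
    # Decoupled weight decay: added next to the Adam step, not mixed into the gradient
    w = w - lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * w)
    return w, m, v

w, m, v = 1.0, 0.0, 0.0
for t in range(1, 4):
    grad = 2.0 * w                              # gradient of the toy loss w^2
    w, m, v = adamw_step(w, grad, m, v, t)
print(w)
```

Each step moves the weight by roughly `lr` regardless of the gradient's raw size, because the gradient mean is divided by the square root of its squared mean. That scaling is the "adaptive" part of Adam.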

---

### **Learning Rate Scheduling** 📊

**The Problem**: Fixed learning rate isn't optimal:
- **Start**: Need large steps to explore
- **End**: Need small steps to fine-tune

**Cosine Annealing** (Used in This Notebook):
```
lr_t = lr_min + (lr_max - lr_min) × (1 + cos(πt/T)) / 2
```
where:
- **t**: Current epoch
- **T**: Total epochs
- **Result**: Smooth cosine curve from lr_max → lr_min

**Visualization**:
```
Epochs:    0      1       2       3       4      5
LR:      1e-4  9.1e-5  6.9e-5  4.1e-5  1.9e-5  1e-5   (lr_max = 1e-4, lr_min = 1e-5, T = 5)
         ──────────────╲_____________
                        Smooth decay
```

**Advantages**:
- Smooth, predictable decay
- Few hyperparameters to tune: only T and lr_min (step decay also needs step positions and a decay factor)
- Works well for transformers
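
The schedule is a one-line formula, sketched below. The values `lr_max = 1e-4`, `lr_min = 1e-5` and five epochs are assumptions for the demo; in PyTorch the same curve is available as `torch.optim.lr_scheduler.CosineAnnealingLR` (with `T_max` and `eta_min`).

```python
import math

def cosine_lr(epoch, total_epochs, lr_max=1e-4, lr_min=1e-5):
    """lr_min + (lr_max - lr_min) * (1 + cos(pi * t / T)) / 2"""
    cosine = math.cos(math.pi * epoch / total_epochs)
    return lr_min + (lr_max - lr_min) * (1 + cosine) / 2

schedule = [cosine_lr(t, total_epochs=5) for t in range(6)]
print([f"{lr:.2e}" for lr in schedule])
```

The rate changes slowly at the start and end and fastest in the middle, unlike a straight-line decay.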

---

### **Overfitting vs. Underfitting** ⚖️

**Underfitting**: Model too simple (high error on training AND test)
```
🎯 Reality: Complex curve
📉 Model:   Straight line (doesn't fit)
```

**Good Fit**: Just right (low error on training AND test)
```
🎯 Reality: Complex curve
📈 Model:   Matching curve (fits well!)
```

**Overfitting**: Model too complex (low error on training, high on test)
```
🎯 Reality: Smooth curve with noise
📈 Model:   Wiggly line through every point (memorized noise!)
```

**How to Prevent Overfitting**:
1. **More data**: Harder to memorize 100,000 examples than 100
2. **Regularization**: Penalty for complex models (weight decay, dropout)
3. **Data augmentation**: Create variations (SpecAugment!)
4. **Early stopping**: Stop training when validation error increases
5. **Dropout**: Randomly "turn off" neurons during training

<a name="transformers"></a>
## 🔮 **Part 4: Transformer Architecture - The Revolution**

### **The Problem with Traditional Neural Networks** ⚠️

**Recurrent Neural Networks (RNNs)** were the standard for sequences (text, audio, time-series):

```
Input₁ → [RNN] → Output₁
          ↓
Input₂ → [RNN] → Output₂
          ↓
Input₃ → [RNN] → Output₃
```

**Problems**:
1. **Sequential Processing**: Must process one element at a time (slow!)
2. **Vanishing Gradients**: Hard to learn long-range dependencies
3. **Limited Parallelization**: Each step waits for the previous one, so GPUs are poorly utilized

---

### **The Transformer Revolution** 🚀 (Vaswani et al., 2017)

**"Attention Is All You Need"** - Groundbreaking paper that changed AI forever!

**Key Idea**: Process **all** elements simultaneously using **self-attention**

**Why It Matters**:
- ✅ **Parallel**: Process entire sequence at once (much better GPU utilization than an RNN)
- ✅ **Long-range**: Directly connects distant elements
- ✅ **Scalable**: Works with massive datasets
- ✅ **Versatile**: Powers GPT, BERT, ChatGPT, DALL-E, and our AST model!

---

### **Attention Mechanism: The Core Innovation** 💡

**The Question**: "Which parts of the input should I focus on?"

**Real-World Analogy**:

Imagine reading a book:
- Your **eyes scan** all words on the page (**keys**)
- Your **mind searches** for relevant information (**query**)
- You **pay attention** to important words (**attention weights**)
- You **extract meaning** from those words (**values**)

---

### **Self-Attention Mathematics** 🔢

**The Setup**:

For each position in the input, we create three vectors:
- **Query (Q)**: "What am I looking for?"
- **Key (K)**: "What information do I contain?"
- **Value (V)**: "What information do I provide?"

**Step-by-Step Process**:

#### **Step 1: Linear Projections**
```
Q = Input × W_Q    (transform input to query space)
K = Input × W_K    (transform input to key space)
V = Input × W_V    (transform input to value space)
```

where W_Q, W_K, W_V are learned weight matrices.

#### **Step 2: Compute Attention Scores**
```
Scores = (Q × K^T) / √d_k
```

**What's happening**:
- **Q × K^T**: Dot product measures similarity between query and keys
- **√d_k**: Scaling factor (d_k = dimension of key vectors)
- **Why scale?**: Prevents dot products from becoming too large

**Example** (with 3 input positions):
```
        K₁    K₂    K₃
Q₁  [  0.9   0.3   0.1  ]    ← Q₁ most similar to K₁
Q₂  [  0.2   0.8   0.4  ]    ← Q₂ most similar to K₂
Q₃  [  0.1   0.3   0.9  ]    ← Q₃ most similar to K₃
```

#### **Step 3: Softmax Normalization**
```
Attention_Weights = softmax(Scores)
```

**Softmax** converts scores to probabilities (sum = 1):
```
softmax([x₁, x₂, x₃]) = [e^x₁, e^x₂, e^x₃] / (e^x₁ + e^x₂ + e^x₃)
```

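**After softmax** (applied to each row of the scores above):
```
        V₁      V₂      V₃
Q₁  [  0.500   0.275   0.225  ]    ← 50% attention to V₁
Q₂  [  0.247   0.451   0.302  ]    ← 45% attention to V₂
Q₃  [  0.225   0.275   0.500  ]    ← 50% attention to V₃
```

Because these toy scores differ only slightly, the weights are spread fairly evenly; larger gaps between scores give sharper attention.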

#### **Step 4: Weighted Sum of Values**
```
Output = Attention_Weights × V
```

**Complete Formula**:
```
Attention(Q, K, V) = softmax((Q × K^T) / √d_k) × V
```
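
The four steps fit in a few lines of NumPy. This is a single-head sketch with random stand-in weights and toy sizes (3 tokens, 8 input features, d_k = 4), all of which are assumptions for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    shifted = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v):
    """softmax(Q K^T / sqrt(d_k)) V for arrays of shape (tokens, d_k)."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)                 # Step 2: similarity of every query to every key
    weights = softmax(scores, axis=-1)              # Step 3: each row sums to 1
    return weights @ v, weights                     # Step 4: weighted sum of values

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 8))                         # toy input: 3 tokens, 8 features
w_q, w_k, w_v = (rng.normal(size=(8, 4)) for _ in range(3))   # Step 1: random stand-in projections
output, weights = scaled_dot_product_attention(x @ w_q, x @ w_k, x @ w_v)
print(output.shape, weights.sum(axis=-1))
```

Every output row is a blend of all value rows, so each token's new representation can draw on any other token in one step.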

---

### **Multi-Head Attention: Multiple Perspectives** 👥

**The Idea**: Instead of one attention mechanism, use **multiple** in parallel!

**Analogy**: Reading a book with different goals:
- **Head 1**: Focus on plot (who did what?)
- **Head 2**: Focus on emotions (how do characters feel?)
- **Head 3**: Focus on setting (where/when does it happen?)
- **Head 4**: Focus on themes (what's the deeper meaning?)

**Mathematics**:
```
MultiHead(Q, K, V) = Concat(head₁, head₂, ..., headₕ) × W_O

where each head_i = Attention(Q×W_Qⁱ, K×W_Kⁱ, V×W_Vⁱ)
```

**In Our Model**:
- **num_heads = 4**: We use 4 attention heads
- **embed_dim = 256**: Total embedding dimension
- **head_dim = 256 / 4 = 64**: Each head works with 64 dimensions

**Why Multiple Heads?**:
1. Learn **different types** of patterns
2. **Redundancy**: If one head fails, others compensate
3. **Richer representations**: Capture multiple aspects simultaneously

---

### **The Complete Transformer Block** 🏗️

```
Input Embedding ────────────┐
      ↓                     │
[Layer Normalization]       │
      ↓                     │
[Multi-Head Self-Attention] │  ← Learn relationships
      ↓                     │
[Add: Residual] ←───────────┘  ← Skip connection 1
      ↓ ────────────────────┐
[Layer Normalization]       │
      ↓                     │
[Feed-Forward Network]      │  ← Process features
      ↓                     │
[Add: Residual] ←───────────┘  ← Skip connection 2
      ↓
Output
```

---

### **Key Components Explained**

#### **1. Layer Normalization**

**Formula**:
```
LN(x) = γ × (x - μ) / σ + β
```
where:
- **μ**: Mean of features
- **σ**: Standard deviation
- **γ, β**: Learned parameters

**Purpose**: Stabilize training, prevent internal covariate shift

**Difference from Batch Normalization**:
- **Batch Norm**: Normalizes across batch dimension
- **Layer Norm**: Normalizes across feature dimension (better for sequences!)
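
A minimal NumPy sketch of the formula, normalizing each token's feature vector independently. The small `eps` inside the square root is an assumption added for numerical safety (PyTorch's `nn.LayerNorm` uses the same idea).

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize over the last (feature) axis, then scale by gamma and shift by beta."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

tokens = np.arange(12, dtype=float).reshape(3, 4)   # toy input: 3 tokens, 4 features
normed = layer_norm(tokens, gamma=np.ones(4), beta=np.zeros(4))
print(normed.mean(axis=-1), normed.std(axis=-1))    # each token: mean ~0, std ~1
```

Because the statistics come from a single token's features, the result does not depend on the batch size or on what else is in the batch.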

---

#### **2. Residual Connections** (Skip Connections)

```
output = Layer(input) + input
```

**Why Critical**:
- **Gradient Flow**: Allows gradients to flow directly backward
- **Deep Networks**: Enables training networks with 100+ layers
- **Identity Mapping**: Model can learn to "do nothing" if needed

**Analogy**: Like a highway bypass:
- **Main road**: Goes through the layer (can transform)
- **Bypass**: Skips the layer (preserves original)
- **Result**: Best of both worlds!

---

#### **3. Feed-Forward Network (FFN)**

```
FFN(x) = GELU(x × W₁ + b₁) × W₂ + b₂
```

**Structure**:
```
Input (256 dim)
    ↓
[Linear: 256 → 1024]    ← Expansion (mlp_ratio = 4)
    ↓
[GELU Activation]
    ↓
[Dropout]
    ↓
[Linear: 1024 → 256]    ← Compression
    ↓
Output (256 dim)
```

**Purpose**:
- **Non-linear transformation**: Process attended features
- **Position-wise**: Applied independently to each position
- **Expansion-compression**: Increase capacity, then project back

---

### **Positional Encoding: Adding Order Information** 📍

**The Problem**: Self-attention has **no notion of order**!
- "dog bites man" vs. "man bites dog" look the same to attention!

**The Solution**: Add **positional information** to embeddings

**Two Approaches**:

#### **1. Sinusoidal (Original Transformer)**:
```
PE(pos, 2i)   = sin(pos / 10000^(2i/d))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d))
```
- **pos**: Position in sequence
- **i**: Dimension index
- **d**: Embedding dimension

**Properties**:
- Deterministic (no learning required)
- Can generalize to unseen sequence lengths
- Encodes relative positions

#### **2. Learned Positional Embeddings** (Used in Our AST!):
```
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
```
- **Learned**: Trained along with model
- **Fixed length**: Works for specific input size
- **Often performs better** for fixed-size inputs like images/spectrograms

---

### **Vision Transformer (ViT) → Audio Spectrogram Transformer (AST)** 🎨→🎵

**ViT Innovation** (Dosovitskiy et al., 2020):
Apply transformers to **images** by treating them as sequences of patches!

**Our Adaptation** (AST for Audio):
Treat **spectrograms** as images → apply ViT!

```
Spectrogram (128 × 432)
        ↓
Split into patches (16 × 16)
        ↓
216 patches (8 height × 27 width)
        ↓
Project each patch (16 × 16 = 256 values) → 256-dim embedding
        ↓
Add [CLS] token (for classification)
        ↓
Add positional embeddings
        ↓
217 tokens → Transformer Encoder
        ↓
Extract [CLS] token
        ↓
Regression head → Valence & Arousal
```

---

### **Patch Embedding: From Pixels to Tokens** 🧩

**Goal**: Convert 2D spectrogram patches into 1D vectors

**Method**: Convolutional layer with stride = patch_size

```python
self.patch_embed = nn.Conv2d(
    in_channels=1,        # Grayscale spectrogram
    out_channels=256,     # Embedding dimension
    kernel_size=16,       # Patch size (16×16)
    stride=16             # No overlap between patches
)
```

**What This Does**:

**Input**: (Batch, 1, 128, 432)
**Operation**: 16×16 convolution with 256 filters, stride 16
**Output**: (Batch, 256, 8, 27)  ← 8×27 = 216 patches, each 256-dim

**Then flatten**: (Batch, 256, 8, 27) → (Batch, 216, 256)

**Analogy**: Imagine a photo booth taking 216 separate photos of different parts of the spectrogram, each photo captured as a 256-number description.
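
A convolution whose stride equals its kernel size is equivalent to cutting the input into non-overlapping tiles, flattening each tile, and applying one shared linear layer. The NumPy sketch below shows only the tiling step; `patchify` is a hypothetical helper, and the learned projection is left out.

```python
import numpy as np

def patchify(spec, patch=16):
    """Split (height, width) into non-overlapping patch x patch tiles, one flat row per tile."""
    h, w = spec.shape
    assert h % patch == 0 and w % patch == 0, "dimensions must be multiples of the patch size"
    tiles = spec.reshape(h // patch, patch, w // patch, patch).swapaxes(1, 2)
    return tiles.reshape(-1, patch * patch)

spec = np.arange(128 * 432, dtype=float).reshape(128, 432)   # stand-in spectrogram
patches = patchify(spec)
print(patches.shape)   # (216, 256)
```

Tiles are ordered row by row: the first 27 rows of `patches` cover the top 16 mel bins from left to right, the next 27 cover the band below, and so on.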

---

### **The [CLS] Token: A Special Token for Classification** 👑

**Origin**: Borrowed from BERT (Bidirectional Encoder Representations from Transformers)

**Concept**:
```
[CLS] + Patch₁ + Patch₂ + ... + Patch₂₁₆
  ↓
Transformer (all tokens interact)
  ↓
Extract [CLS] → it now contains global information!
```

**Why It Works**:
- [CLS] is **not tied to any patch**: it is a learned vector that starts as small random values
- Through self-attention, it **aggregates** information from all patches
- Acts as a **summary** of the entire spectrogram
- Perfect for making a global prediction (valence/arousal)!

**Initialization**:
```python
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
nn.init.normal_(self.cls_token, std=0.02)  # Small random noise
```

---

### **Parameter Count: Understanding Model Size** 📊

**Our AST Model Parameters**:

```
Component                        Parameters
─────────────────────────────────────────────
Patch Embedding                  65,792
  (Conv2d 1→256, kernel 16×16, with bias)

CLS Token                        256

Positional Embeddings            55,552
  (217 positions × 256 dim)

Transformer Blocks (4×):         789,760 each
  - Multi-Head Attention:        263,168
    • Q, K, V projections:       65,792 each
    • Output projection:         65,792
  - Layer Norms (2×):            1,024
  - FFN:                         525,568
    • Linear 256→1024:           263,168
    • Linear 1024→256:           262,400

Regression Head                  33,154
  - Linear 256→128:              32,896
  - Linear 128→2:                258
─────────────────────────────────────────────
TOTAL:                           ~3.3 Million parameters
```

Counts include bias terms and assume exactly the layers listed; any extra layer in the code (for example a final LayerNorm) adds slightly to the total.

**What This Means**:
- **3.3M parameters**: 3.3 million numbers to learn!
- **Memory**: ~13 MB (float32) or ~7 MB (float16) for the weights alone
- **Computation**: roughly 1.6 GFLOPs (about 0.8 billion multiply-adds) per forward pass on one 128 × 432 input

**Is 3.3M parameters a lot?**
- **Tiny** compared to GPT-3 (175 billion parameters)
- **Small** for audio/vision transformers (ViT-Base has about 86 million)
- **Manageable** on consumer GPUs
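
A quick way to sanity-check a parameter table is to tally it from the layer sizes. The sketch below uses the dimensions quoted in this notebook (embed_dim 256, 4 blocks, mlp_ratio 4, 216 patches) and assumes a two-layer regression head with bias terms everywhere; the real code may contain extra layers that this tally does not know about.

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out                 # weights + biases

embed_dim, num_blocks, mlp_ratio = 256, 4, 4    # sizes quoted in this notebook
patch, num_patches = 16, 216

patch_embed = patch * patch * embed_dim + embed_dim        # Conv2d(1, 256, 16, stride=16)
cls_token = embed_dim
pos_embed = (num_patches + 1) * embed_dim
attention = 4 * linear_params(embed_dim, embed_dim)        # Q, K, V and output projections
layer_norms = 2 * 2 * embed_dim                            # two LayerNorms, scale + shift each
ffn = (linear_params(embed_dim, mlp_ratio * embed_dim)
       + linear_params(mlp_ratio * embed_dim, embed_dim))
block = attention + layer_norms + ffn
head = linear_params(embed_dim, 128) + linear_params(128, 2)   # assumed head: 256 -> 128 -> 2

total = patch_embed + cls_token + pos_embed + num_blocks * block + head
print(f"per block: {block:,}   total: {total:,}")
```

With a PyTorch model in hand, the same figure can be read off directly with `sum(p.numel() for p in model.parameters())`.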

<a name="training"></a>
## 🏋️ **Part 5: The Training Process - Teaching the Model**

### **The Training Loop: A Complete Cycle** 🔄

```
FOR each epoch (1 to NUM_EPOCHS):
    FOR each batch in training data:
        1. Forward Pass    → Get predictions
        2. Compute Loss    → Measure error
        3. Backward Pass   → Compute gradients
        4. Update Weights  → Improve model
    
    Validate on validation set
    Save best model
```

---

### **Step 1: Forward Pass** ➡️

**The Journey of Data Through Our Model**:

```
Input: Audio file (45 seconds MP3)
    ↓
Load & Resample → 44,100 Hz
    ↓
Extract 5-second segments
    ↓
Compute Mel-Spectrogram → (128 × 432)
    ↓
Normalize → Mean=0, Std=1
    ↓
Apply SpecAugment (training only)
    ↓
Patch Embedding → 216 patches of 256-dim
    ↓
Add [CLS] token → 217 tokens
    ↓
Add Positional Embeddings
    ↓
Transformer Layer 1 → Self-Attention + FFN
    ↓
Transformer Layer 2 → Self-Attention + FFN
    ↓
Transformer Layer 3 → Self-Attention + FFN
    ↓
Transformer Layer 4 → Self-Attention + FFN
    ↓
Extract [CLS] token
    ↓
Regression Head → (valence, arousal)
    ↓
Output: [7.2, 5.8]  (predicted emotions!)
```

**Matrix Dimensions at Each Step**:

```
Operation                        Shape
────────────────────────────────────────────────
Input spectrogram               (batch, 1, 128, 432)
After patch embedding           (batch, 216, 256)
Add CLS token                   (batch, 217, 256)
Add positional embeddings       (batch, 217, 256)  [no shape change]
After each transformer block    (batch, 217, 256)  [no shape change]
Extract CLS token               (batch, 256)
After FC layer 256→128          (batch, 128)
Final output                    (batch, 2)         [valence, arousal]
```
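
The shapes in the table follow from a few integer divisions. This sketch recomputes them without loading any audio, assuming the notebook's settings (5-second segments at 44,100 Hz, hop length 512, 16×16 patches) and centered framing, where the frame count is `1 + n_samples // hop_length`. The function name is illustrative.

```python
def ast_token_shapes(segment_seconds=5, sample_rate=44100, hop_length=512,
                     n_mels=128, target_steps=432, patch_size=16, embed_dim=256):
    n_samples = segment_seconds * sample_rate
    # Centered framing yields one frame per hop, plus one
    raw_frames = 1 + n_samples // hop_length
    patches_h = n_mels // patch_size        # patches along the frequency axis
    patches_w = target_steps // patch_size  # patches along the time axis
    num_patches = patches_h * patches_w
    return {
        "raw_frames": raw_frames,                 # before padding to target_steps
        "patch_grid": (patches_h, patches_w),
        "tokens": (num_patches + 1, embed_dim),   # +1 for the [CLS] token
    }

shapes = ast_token_shapes()
print(shapes)
```

A 5-second segment produces 431 frames, which is why the dataset pads by one column to reach 432, a width that divides evenly by the patch size.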

---

### **Step 2: Loss Computation** 📉

**Ground Truth vs. Prediction**:

```
Sample 1:
  True:      [Valence: 7.5, Arousal: 6.2]
  Predicted: [Valence: 7.2, Arousal: 5.8]
  Errors:    [        -0.3,          -0.4]

Sample 2:
  True:      [Valence: 3.1, Arousal: 4.5]
  Predicted: [Valence: 3.8, Arousal: 4.9]
  Errors:    [        +0.7,          +0.4]
```

**Mean Squared Error (MSE)**:
```
MSE = (1/N) × Σ(y_true - y_pred)²

For Sample 1:
  MSE_1 = ((-0.3)² + (-0.4)²) / 2
        = (0.09 + 0.16) / 2
        = 0.125

For Sample 2:
  MSE_2 = ((0.7)² + (0.4)²) / 2
        = (0.49 + 0.16) / 2
        = 0.325

Batch MSE = (0.125 + 0.325) / 2 = 0.225
```

**In PyTorch**:
```python
criterion = nn.MSELoss()
loss = criterion(predictions, targets)
```

---

### **Step 3: Backpropagation** ⬅️

**Goal**: Compute ∂Loss/∂w for every parameter w

**The Chain Rule in Action**:

```
∂Loss/∂w₁ = ∂Loss/∂output × ∂output/∂layer₄ × ∂layer₄/∂layer₃ × ∂layer₃/∂layer₂ × ∂layer₂/∂layer₁ × ∂layer₁/∂w₁
```

**Visualizing Gradient Flow**:

```
Forward Pass →
Input ──→ Layer1 ──→ Layer2 ──→ Layer3 ──→ Output ──→ Loss
     ↑         ↑         ↑         ↑          ↑
     └─────────┴─────────┴─────────┴──────────┘
     ← Backward Pass (Gradients)
```

**PyTorch Magic**:
```python
loss.backward()  # Automatically computes all gradients!
```

**What Happens Internally**:
1. PyTorch builds a **computational graph** during forward pass
2. Records every operation (multiply, add, activation, etc.)
3. During `.backward()`, traverses graph in reverse
4. Applies chain rule automatically
5. Stores gradients in `.grad` attribute of each parameter
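
The chain rule can be checked by hand on a toy model. The sketch below is not part of the AST; it uses a made-up two-weight network (`pred = w2 * tanh(w1 * x)`) with squared error, derives both gradients analytically, and compares them against finite differences.

```python
import math

def forward(w1, w2, x, y):
    h = math.tanh(w1 * x)      # hidden activation
    pred = w2 * h              # model output
    loss = (pred - y) ** 2     # squared error
    return h, pred, loss

def backward(w1, w2, x, y):
    """Apply the chain rule from the loss back to each weight."""
    h, pred, _ = forward(w1, w2, x, y)
    dloss_dpred = 2 * (pred - y)
    dloss_dw2 = dloss_dpred * h                # d(pred)/d(w2) = h
    dloss_dh = dloss_dpred * w2                # d(pred)/d(h) = w2
    dloss_dw1 = dloss_dh * (1 - h ** 2) * x    # tanh'(u) = 1 - tanh(u)^2, du/dw1 = x
    return dloss_dw1, dloss_dw2

def numerical_grad(f, value, eps=1e-6):
    """Central finite difference, used only as a sanity check."""
    return (f(value + eps) - f(value - eps)) / (2 * eps)

w1, w2, x, y = 0.5, -1.5, 2.0, 1.0
g1, g2 = backward(w1, w2, x, y)
g1_check = numerical_grad(lambda w: forward(w, w2, x, y)[2], w1)
g2_check = numerical_grad(lambda w: forward(w1, w, x, y)[2], w2)
print(g1, g1_check)
print(g2, g2_check)
```

Autograd performs the same bookkeeping as `backward`, but for every operation in the graph and without hand-written derivatives.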

---

### **Step 4: Gradient Descent Update** 📈

**AdamW Update (Simplified)**:

```python
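import math

# Illustrative values for one scalar parameter
w, g = 0.5, 0.1                  # parameter and its gradient (w.grad)
lr, beta1, beta2 = 1e-3, 0.9, 0.999
eps, weight_decay = 1e-8, 0.01
m, v, t = 0.0, 0.0, 1            # moment estimates and step counter

# 1. Update first moment (momentum)
m = beta1 * m + (1 - beta1) * g

# 2. Update second moment (uncentered variance)
v = beta2 * v + (1 - beta2) * g ** 2

# 3. Bias correction
m_hat = m / (1 - beta1 ** t)
v_hat = v / (1 - beta2 ** t)

# 4. Apply update with decoupled weight decay
w = w - lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * w)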
```

**Numerical Example**:

```
Parameter w = 0.5
Gradient g = 0.1
Learning rate lr = 0.001
Weight decay λ = 0.01

Adam component:
  m_hat = 0.09 (smoothed gradient)
  v_hat = 0.01 (smoothed variance)
  adam_update = 0.09 / √0.01 = 0.9

Weight decay component:
  decay = 0.01 * 0.5 = 0.005

Total update:
  w_new = 0.5 - 0.001 * (0.9 + 0.005)
        = 0.5 - 0.000905
        = 0.499095
```

**Why This Works**:
- **Momentum (m)**: Accelerates in consistent directions
- **Adaptive rates (v)**: Larger effective steps for parameters whose gradients are small or infrequent
- **Weight decay**: Prevents weights from growing too large
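
One consequence of the adaptive scaling is easy to verify: on the very first step, the bias-corrected ratio `m_hat / √v_hat` is about ±1, so the step size is roughly `lr` no matter how large the gradient is. The sketch below is a scalar re-implementation written for illustration (not PyTorch's optimizer code), with weight decay switched off by starting from `w = 0`.

```python
import math

def adamw_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar parameter; returns new (w, m, v)."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * w)
    return w, m, v

def first_step_size(g):
    # Starting at w = 0 makes the weight-decay term vanish
    w_new, _, _ = adamw_step(w=0.0, g=g, m=0.0, v=0.0, t=1)
    return abs(w_new)

small = first_step_size(1e-3)   # tiny gradient
large = first_step_size(1e+3)   # huge gradient
print(small, large)
```

Both calls move the weight by about 0.001, which is why Adam-style optimizers are less sensitive to gradient scale than plain SGD.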

---

### **Gradient Clipping: Preventing Explosions** 💣

**The Problem**: Gradients can become extremely large → unstable training

**The Solution**:
```python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```

**What It Does**:
```
If ||gradients|| > max_norm:
    gradients = gradients * (max_norm / ||gradients||)
```

**Analogy**: Speed limit on a highway:
- If gradients "drive" too fast (large norm)
- Scale them down to maximum allowed speed
- Prevents "crashes" (training divergence)
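
The rescaling rule can be written out in a few lines. This is a simplified sketch over a flat list of gradient values; PyTorch's `clip_grad_norm_` computes the norm across all parameter tensors and adds a small epsilon to the denominator, which is omitted here.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale gradients down so their L2 norm does not exceed max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
untouched, _ = clip_by_global_norm([0.3, 0.4], max_norm=1.0)
print(norm_before, clipped, untouched)
```

The direction of the gradient is preserved; only its length changes, and gradients already under the limit pass through unchanged.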

---

### **Batch Processing: Efficiency Through Parallelism** 🚄

**Why Batches?**

**Bad: Processing one sample at a time**:
```
FOR each sample in 10,000 samples:
    Forward pass (1 sample)     → GPU mostly idle
    Compute loss (1 sample)
    Backward pass (1 sample)
    Update weights
Total time: ~10 hours
```

**Good: Processing batches**:
```
FOR each batch of 16 samples:
    Forward pass (16 samples)   → GPU fully utilized!
    Compute loss (16 samples)
    Backward pass (16 samples)
    Update weights
Total time: ~40 minutes (15× faster!)
```

**Batch Size Trade-offs**:

**Small Batches** (e.g., 8):
- ✅ More frequent updates (faster learning)
- ✅ Better generalization (more noise)
- ✅ Less memory required
- ❌ Less efficient GPU utilization
- ❌ Noisier gradients

**Large Batches** (e.g., 64):
- ✅ Better GPU utilization (faster training)
- ✅ More stable gradients
- ❌ Fewer updates per epoch
- ❌ May overfit more easily
- ❌ Requires more GPU memory

**Our Choice: batch_size = 16**
- Balanced for most GPUs
- Good trade-off between speed and stability
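
The "more frequent updates" trade-off is simple arithmetic: the number of weight updates per epoch is the dataset size divided by the batch size, rounded up. A small sketch, using the 10,000-sample figure from the example above:

```python
import math

def updates_per_epoch(num_samples, batch_size):
    """Number of optimizer steps in one pass over the data (last batch may be partial)."""
    return math.ceil(num_samples / batch_size)

steps = {batch_size: updates_per_epoch(10_000, batch_size) for batch_size in (8, 16, 64)}
print(steps)
```

Going from batch size 8 to 64 cuts the number of updates per epoch by a factor of eight, while each update averages over eight times as many samples.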

---

### **Train/Validation/Test Split: Honest Evaluation** 📊

**Why We Need Three Sets**:

```
FULL DATASET (100%)
    ↓
    ├─→ Training Set (80%)      → Model learns from this
    ├─→ Validation Set (10%)    → Tune hyperparameters, select best model
    └─→ Test Set (10%)          → Final evaluation (never seen during training!)
```

**Analogy**: Learning for an exam:
- **Training**: Practice problems you study from
- **Validation**: Practice exam to check progress
- **Test**: Actual exam (truly measures what you learned!)

**Critical Rule**: 🚨 **NEVER** use test set during training or model selection!

**Why It Matters**:

**Bad Practice**:
```
Try model A → Test accuracy: 85%
Try model B → Test accuracy: 87%  ← Pick this!
Try model C → Test accuracy: 84%

Report: "Our model achieves 87% on test set"
Problem: You optimized FOR the test set! (data leakage)
```

**Good Practice**:
```
Try model A → Val accuracy: 85%
Try model B → Val accuracy: 87%  ← Pick this!
Try model C → Val accuracy: 84%

Finally evaluate model B on test set → 86%
Report: "Our model achieves 86% on test set" ✓
```
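
A related leak is specific to segmented audio: when a song is cut into several segments that share one label, splitting at the segment level can place pieces of the same song in both training and test sets. The sketch below shows one way to split by song instead. The helper name `split_songs` and the toy data (20 songs, 9 segments each) are assumptions for illustration.

```python
import random

def split_songs(song_ids, train_ratio=0.8, val_ratio=0.1, seed=42):
    """Assign whole songs to train/val/test so no song appears in two splits."""
    ids = sorted(set(song_ids))
    random.Random(seed).shuffle(ids)
    n_train = int(train_ratio * len(ids))
    n_val = int(val_ratio * len(ids))
    return {
        "train": set(ids[:n_train]),
        "val": set(ids[n_train:n_train + n_val]),
        "test": set(ids[n_train + n_val:]),
    }

# Hypothetical dataset: (song_id, segment_index) pairs
segments = [(song, seg) for song in range(20) for seg in range(9)]
splits = split_songs(song for song, _ in segments)
segment_splits = {
    name: [s for s in segments if s[0] in songs] for name, songs in splits.items()
}
print({name: len(items) for name, items in segment_splits.items()})
```

Every segment follows its song into exactly one split, so test performance reflects songs the model has never heard.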

---

### **Epochs: Multiple Passes Through Data** 🔁

**One Epoch** = One complete pass through all training data

**Why Multiple Epochs?**

**After 1 Epoch**:
```
Model: "I've seen each song once"
Performance: 60% accuracy (still learning basic patterns)
```

**After 5 Epochs**:
```
Model: "I've seen each song 5 times"
Performance: 80% accuracy (learned most patterns)
```

**After 15 Epochs**:
```
Model: "I've seen each song 15 times"
Performance: 85% accuracy (refined understanding)
```

**After 50 Epochs** (too many!):
```
Model: "I've memorized the training songs!"
Train Performance: 98% (great!)
Test Performance: 75% (worse than epoch 15!)
Problem: OVERFITTING
```

**Typical Training Curve**:

```
Accuracy
    ↑
100%┤                     ╱───── Training (overfitting)
    │                   ╱
 80%┤              ╭───╯─────── Validation (best)
    │          ╭──╯      ╲
 60%┤      ╭──╯            ╲──── Validation (declining)
    │  ╭──╯
 40%┤─╯
    └────────────────────────→ Epochs
     0  5  10  15  20  25  30

Optimal stopping point: ~15 epochs
```

---

### **Early Stopping: Knowing When to Stop** ⏸️

**The Strategy**:
```python
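best_val_loss = float("inf")
patience_limit = 3       # epochs to wait for an improvement
patience_counter = 0

for epoch in range(NUM_EPOCHS):
    train_loss = train_one_epoch()  # placeholder for the training step
    val_loss = validate()           # placeholder for the validation step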
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_model()  # ✓ New best!
        patience_counter = 0
    else:
        patience_counter += 1
    
    if patience_counter >= patience_limit:
        print("Early stopping!")
        break  # Stop training
```

**Example**:
```
Epoch  Val Loss  Action
────────────────────────────
  1     1.250    Save (best so far)
  2     1.100    Save (improved!)
  3     1.050    Save (still improving)
  4     1.080    Don't save (worse)
  5     1.120    Don't save (worse)
  6     1.150    Don't save (worse) ← patience = 3, STOP!

Final model: Epoch 3 (val_loss = 1.050)
```

**Benefits**:
- Prevents overfitting
- Saves computational resources
- Automatically finds optimal training duration
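
The example table can be replayed with a small function that takes a list of validation losses and reports which epoch would be kept and when training would stop. This is a sketch of the patience logic only; the function name is illustrative and no model is involved.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return (best_epoch, stopped_at_epoch), both 1-indexed."""
    best_loss = float("inf")
    best_epoch = None
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, bad_epochs = loss, epoch, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return best_epoch, epoch
    return best_epoch, len(val_losses)

best_epoch, stopped_at = early_stopping_epoch([1.250, 1.100, 1.050, 1.080, 1.120, 1.150])
print(best_epoch, stopped_at)
```

With a patience of 3, the losses from the table keep the epoch-3 checkpoint and stop after epoch 6.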

<a name="metrics"></a>
## 📏 **Part 6: Evaluation Metrics - Measuring Success**

### **Why Multiple Metrics?** 🎯

Different metrics reveal different aspects of model performance:
- **MSE**: Penalizes large errors heavily
- **MAE**: Treats all errors equally
- **CCC**: Measures agreement (considers both accuracy and correlation)

**Analogy**: Grading a student:
- **MSE**: Big mistakes hurt grade a lot (failed exam → F)
- **MAE**: Average score across all assignments
- **CCC**: Consistency between effort and results

---

### **1. Mean Squared Error (MSE)** 📉

**Formula**:
```
MSE = (1/N) × Σᵢ₌₁ᴺ (yᵢ - ŷᵢ)²
```

**Step-by-Step Calculation**:

```
Predictions:  [7.2, 5.8, 3.1, 8.5, 4.9]
Ground Truth: [7.5, 6.2, 3.0, 8.0, 5.0]

Step 1: Compute errors
Errors:       [-0.3, -0.4, 0.1, 0.5, -0.1]

Step 2: Square each error
Squared:      [0.09, 0.16, 0.01, 0.25, 0.01]

Step 3: Average
MSE = (0.09 + 0.16 + 0.01 + 0.25 + 0.01) / 5
    = 0.52 / 5
    = 0.104
```

**Interpretation**:
- **Lower is better** (0 = perfect predictions)
- **Units**: Squared units of measurement (if predicting 1-9 scale, MSE in range [0, 64])
- **Sensitivity**: Very sensitive to outliers

**Why Square?**:
1. **Positive**: Errors in both directions penalized equally
2. **Differentiable**: Smooth gradient for optimization
3. **Outlier penalty**: Large errors (3.0) penalized much more than small errors (0.3)

---

### **2. Root Mean Squared Error (RMSE)** 📐

**Formula**:
```
RMSE = √MSE = √[(1/N) × Σᵢ₌₁ᴺ (yᵢ - ŷᵢ)²]
```

**Continuing Example**:
```
RMSE = √0.104 = 0.322
```

**Interpretation**:
- **Same units** as original measurements (rating points; on a 1-9 scale RMSE ranges from 0 to 8)
- **Easier to interpret** than MSE
- **Example**: RMSE = 0.5 means "predictions are typically off by about 0.5 points" (large errors count more than in MAE)

---

### **3. Mean Absolute Error (MAE)** 📊

**Formula**:
```
MAE = (1/N) × Σᵢ₌₁ᴺ |yᵢ - ŷᵢ|
```

**Step-by-Step Calculation**:

```
Predictions:  [7.2, 5.8, 3.1, 8.5, 4.9]
Ground Truth: [7.5, 6.2, 3.0, 8.0, 5.0]

Step 1: Compute errors
Errors:       [-0.3, -0.4, 0.1, 0.5, -0.1]

Step 2: Absolute value
Absolute:     [0.3, 0.4, 0.1, 0.5, 0.1]

Step 3: Average
MAE = (0.3 + 0.4 + 0.1 + 0.5 + 0.1) / 5
    = 1.4 / 5
    = 0.28
```

**Interpretation**:
- **Lower is better** (0 = perfect)
- **Same units** as measurements
- **Direct meaning**: "Average absolute error is 0.28 points"

**MSE vs. MAE Comparison**:

```
Errors:      [0.1, 0.1, 0.1, 5.0]

MAE = (0.1 + 0.1 + 0.1 + 5.0) / 4 = 1.325
MSE = (0.01 + 0.01 + 0.01 + 25) / 4 = 6.26

The outlier (5.0) has huge impact on MSE!
```

**When to Use**:
- **MAE**: When all errors matter equally (robust to outliers)
- **MSE**: When large errors are particularly bad (penalize outliers)
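
The three error metrics can be reproduced in a few lines of plain Python. The sketch below recomputes the worked examples from this part; the function name is illustrative, and the notebook itself relies on `sklearn.metrics` for MSE and MAE.

```python
import math

def regression_errors(y_true, y_pred):
    """Return MSE, RMSE and MAE for two equal-length sequences."""
    errors = [p - t for t, p in zip(y_true, y_pred)]
    mse = sum(e * e for e in errors) / len(errors)
    mae = sum(abs(e) for e in errors) / len(errors)
    return {"mse": mse, "rmse": math.sqrt(mse), "mae": mae}

truth = [7.5, 6.2, 3.0, 8.0, 5.0]
preds = [7.2, 5.8, 3.1, 8.5, 4.9]
metrics = regression_errors(truth, preds)

# Outlier comparison: three small errors and one large one
outlier = regression_errors([0.0, 0.0, 0.0, 0.0], [0.1, 0.1, 0.1, 5.0])
print(metrics)
print(outlier)
```

In the outlier case the single 5.0 error accounts for almost all of the MSE, while MAE grows only in proportion to the error's size.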

---

### **4. Concordance Correlation Coefficient (CCC)** 🎯

**The Gold Standard for Continuous Prediction Evaluation**

**What It Measures**: Agreement between predictions and truth

**Formula**:
```
CCC = (2 × ρ × σ_y × σ_ŷ) / (σ_y² + σ_ŷ² + (μ_y - μ_ŷ)²)
```

where:
- **ρ** = Pearson correlation coefficient
- **σ_y** = Standard deviation of ground truth
- **σ_ŷ** = Standard deviation of predictions
- **μ_y** = Mean of ground truth
- **μ_ŷ** = Mean of predictions

---

### **CCC: Component Breakdown** 🔍

#### **Component 1: Pearson Correlation (ρ)**

**Formula**:
```
ρ = Cov(y, ŷ) / (σ_y × σ_ŷ)
```

**What it measures**: Linear relationship strength
- **ρ = +1**: Perfect positive correlation
- **ρ = 0**: No correlation
- **ρ = -1**: Perfect negative correlation

**Example Calculation**:
```
Ground Truth: [3, 5, 7, 4, 6]    Mean: 5.0, Std: 1.41
Predictions:  [3, 6, 8, 4, 7]    Mean: 5.6, Std: 1.85

Covariance:
  Cov = Σ[(y - μ_y)(ŷ - μ_ŷ)] / N
      = [(3-5)(3-5.6) + (5-5)(6-5.6) + (7-5)(8-5.6) + (4-5)(4-5.6) + (6-5)(7-5.6)] / 5
      = [5.2 + 0 + 4.8 + 1.6 + 1.4] / 5
      = 2.6

Correlation:
  ρ = 2.6 / (1.414 × 1.855) = 2.6 / 2.623 = 0.991
```

#### **Component 2: Location Shift (μ_y - μ_ŷ)**

**What it measures**: Systematic bias
```
If μ_y = 5.0 and μ_ŷ = 5.6:
  Location shift = 5.0 - 5.6 = -0.6
  Interpretation: Predictions systematically 0.6 points higher
```

#### **Component 3: Scale Difference (σ_y vs. σ_ŷ)**

**What it measures**: Spread difference
```
If σ_y = 1.41 and σ_ŷ = 1.85:
  Interpretation: Predictions have wider range than truth
```

---

### **CCC: Complete Calculation Example** 📝

```
Ground Truth: [3, 5, 7, 4, 6]
Predictions:  [3, 6, 8, 4, 7]

Step 1: Compute statistics
  μ_y = 5.0,    μ_ŷ = 5.6
  σ_y = 1.41,   σ_ŷ = 1.85   (σ_y² = 2.00, σ_ŷ² = 3.44)
  ρ = 0.991 (from above)

Step 2: Apply CCC formula
  Numerator = 2 × ρ × σ_y × σ_ŷ
            = 2 × Cov(y, ŷ)
            = 2 × 2.6
            = 5.20

  Denominator = σ_y² + σ_ŷ² + (μ_y - μ_ŷ)²
              = 2.00 + 3.44 + (5.0 - 5.6)²
              = 2.00 + 3.44 + 0.36
              = 5.80

  CCC = 5.20 / 5.80 = 0.90
```

**Interpretation**:
- **CCC = 1.00**: Perfect agreement
- **CCC = 0.90**: Excellent agreement (our example)
- **CCC = 0.70**: Good agreement
- **CCC = 0.50**: Moderate agreement
- **CCC = 0**: No agreement
- **CCC < 0**: Worse than no agreement

---

### **CCC vs. Pearson Correlation: Key Difference** ⚡

**Pearson Correlation**: Measures **relationship**, ignores location and scale

**Example 1**:
```
Ground Truth: [1, 2, 3, 4, 5]
Predictions:  [2, 3, 4, 5, 6]    ← Shifted by +1

Pearson ρ = 1.00  (perfect correlation!)
CCC = 0.80        (penalized for shift)
```

**Example 2**:
```
Ground Truth: [1, 2, 3, 4, 5]
Predictions:  [2, 4, 6, 8, 10]   ← Scaled by 2×

Pearson ρ = 1.00  (perfect correlation!)
CCC = 0.42        (penalized for the scale difference and the mean shift it causes)
```

**Why CCC is Better for Prediction**:
- Pearson: "Are they related?"
- CCC: "Do they match?"

For prediction tasks, we want **matching**, not just **relationship**!
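
The difference is easy to verify numerically. The sketch below implements both coefficients from the formulas in this part, using population statistics (division by N) throughout, and applies them to the worked example and to the shift and scale cases. The function name is illustrative.

```python
import math

def ccc_and_pearson(y_true, y_pred):
    """Return (CCC, Pearson rho) using population variance and covariance."""
    n = len(y_true)
    mean_t = sum(y_true) / n
    mean_p = sum(y_pred) / n
    var_t = sum((t - mean_t) ** 2 for t in y_true) / n
    var_p = sum((p - mean_p) ** 2 for p in y_pred) / n
    cov = sum((t - mean_t) * (p - mean_p) for t, p in zip(y_true, y_pred)) / n
    pearson = cov / math.sqrt(var_t * var_p)
    # 2 * rho * sigma_t * sigma_p simplifies to 2 * cov
    ccc = 2 * cov / (var_t + var_p + (mean_t - mean_p) ** 2)
    return ccc, pearson

worked = ccc_and_pearson([3, 5, 7, 4, 6], [3, 6, 8, 4, 7])
shift = ccc_and_pearson([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])
scale = ccc_and_pearson([1, 2, 3, 4, 5], [2, 4, 6, 8, 10])
for name, (ccc, rho) in [("worked", worked), ("shift", shift), ("scale", scale)]:
    print(f"{name}: CCC = {ccc:.2f}, Pearson = {rho:.2f}")
```

Pearson stays at 1.00 for both the shifted and the scaled predictions, while CCC drops to 0.80 and 0.42 respectively.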

---

### **Valence and Arousal: Two-Dimensional Emotion** 🎭

**The Circumplex Model of Emotion** (Russell, 1980):

```
        Arousal
           ↑
    Angry  │  Excited
     😠    │    😃
           │
───────────┼─────────→ Valence
    Sad    │    Calm
     😢    │    😌
           ↓
```

**Valence (X-axis)**:
- **Low (1-3)**: Negative, unpleasant, sad
- **Mid (4-6)**: Neutral
- **High (7-9)**: Positive, pleasant, happy

**Arousal (Y-axis)**:
- **Low (1-3)**: Calm, relaxed, sleepy
- **Mid (4-6)**: Moderate energy
- **High (7-9)**: Excited, energetic, alert

**Example Emotions**:
```
Valence  Arousal  Emotion
───────────────────────────
  2        8      Angry, Tense
  8        8      Excited, Joyful
  2        2      Sad, Depressed
  8        2      Calm, Content
  5        5      Neutral, Ambivalent
```

**Why Two Dimensions?**:
- **Richer** than single "happiness" score
- **Distinguishes**: Calm-happy vs. Excited-happy
- **Psychological basis**: Matches how humans perceive emotions
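
As a rough illustration of how the two numbers combine, the sketch below maps a (valence, arousal) pair to one of the four quadrant labels from the table. The midpoint of 5 and the rule that values exactly at the midpoint fall on the high side are simplifying assumptions; the model itself predicts continuous values, not categories.

```python
def emotion_quadrant(valence, arousal, midpoint=5.0):
    """Map a (valence, arousal) pair on the 1-9 scale to a coarse quadrant label."""
    if valence >= midpoint:
        return "Excited, Joyful" if arousal >= midpoint else "Calm, Content"
    return "Angry, Tense" if arousal >= midpoint else "Sad, Depressed"

examples = [(2, 8), (8, 8), (2, 2), (8, 2)]
labels = [emotion_quadrant(v, a) for v, a in examples]
print(labels)
```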

---

### **Interpreting Our Metrics in Context** 🎼

**Example Results**:
```
Valence Prediction:
  MSE: 0.85    ← Average squared error
  MAE: 0.72    ← Average absolute error of 0.72 points on 1-9 scale
  CCC: 0.81    ← Good agreement (CCC is a coefficient, not a percentage)

Arousal Prediction:
  MSE: 1.20    ← Slightly worse than valence
  MAE: 0.89    ← Average absolute error of 0.89 points
  CCC: 0.73    ← Decent agreement
```

**What This Means**:
- On average, valence predictions off by **less than 1 point**
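- In these example numbers, arousal is **harder to predict** than valence (on real music data the reverse is often reported, since arousal tracks tempo and loudness while valence is more subjective)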
- Model has **good agreement** with human ratings

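**Putting the Numbers in Perspective** (illustrative values, not measured results):
```
Human-Human Agreement:  CCC ≈ 0.85-0.90  (inter-rater reliability)
Our Model:              CCC ≈ 0.73-0.81
Simple Baseline:        CCC = 0          (always predicting the mean)
```

A constant prediction has zero variance and zero covariance with the truth, so its CCC is exactly 0. With scores like the ones above, a model would sit **reasonably close** to human-level agreement.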

---

### **Visualizing Predictions** 📊

**1. Scatter Plot** (Predicted vs. Actual):

```
Perfect predictions lie on diagonal line:
  y = x  (predicted = actual)

Points above line: Over-predictions
Points below line: Under-predictions

Tighter clustering around line = better model
```

**2. Residual Plot** (Error vs. Actual):

```
Residual = Predicted - Actual

Ideal: Random scatter around y=0
Bad: Systematic pattern (indicates bias)
```

**3. Valence-Arousal Space**:

```
Compare distribution of predictions to ground truth:
  Should cover same emotional space
  Should not cluster in one region
```

---

## ✅ **Summary: You Now Understand!**

**From Sound to Emotion**:
1. 🎵 **Audio waves** → digitized samples
2. 🔬 **Fourier Transform** → frequency spectrum
3. 📊 **Mel-spectrogram** → human-perceptual representation
4. 🧩 **Patches** → divide into manageable chunks
5. 🤖 **Transformer** → learn patterns via self-attention
6. 🎭 **Regression** → predict valence and arousal
7. 📏 **Evaluation** → measure with MSE, MAE, CCC

**The Magic of Transformers**:
- **Self-attention**: "Look at all parts, focus on important ones"
- **Parallel processing**: Fast training on GPUs
- **Scalability**: Works from small to massive datasets

**Training is Optimization**:
- **Forward pass**: Make predictions
- **Loss function**: Measure errors
- **Backpropagation**: Compute how to improve
- **Gradient descent**: Actually improve

**Evaluation is Multi-Faceted**:
- **MSE**: Penalize large errors
- **MAE**: Average error magnitude
- **CCC**: True agreement measure

---

Now let's train the model and see these concepts in action! 🚀

In [None]:
# ============================================================================
# IMPORTS AND ENVIRONMENT SETUP
# ============================================================================

import os
import warnings
warnings.filterwarnings('ignore')

# Check Kaggle environment
print("Kaggle Environment Check:")
print(f"Is Kaggle: {os.path.exists('/kaggle')}")
print(f"GPU Available: {os.path.exists('/dev/nvidia0')}")
print()

# Core libraries
import numpy as np
import pandas as pd
import librosa
import librosa.display
import matplotlib.pyplot as plt

# PyTorch imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler, SequentialSampler

# Evaluation metrics
from sklearn.metrics import mean_squared_error, mean_absolute_error

print("All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"Librosa version: {librosa.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 1. Configuration and Paths

**Important for Kaggle Users:**
- Make sure you've added the DEAM dataset to your notebook
- The dataset should be named: `deam-mediaeval-dataset-emotional-analysis-in-music`
- Go to "Add Data" → Search for "DEAM" → Add to notebook

In [None]:
# ============================================================================
# CONFIGURATION AND PATHS FOR KAGGLE
# ============================================================================

# Kaggle paths - using the main DEAM dataset
DATASET_PATH = '/kaggle/input/deam-mediaeval-dataset-emotional-analysis-in-music'
AUDIO_DIR = os.path.join(DATASET_PATH, 'DEAM_audio', 'MEMD_audio')
ANNOTATIONS_DIR = os.path.join(DATASET_PATH, 'DEAM_Annotations', 'annotations', 
                               'annotations averaged per song', 'song_level')

# Output paths (Kaggle working directory)
OUTPUT_DIR = '/kaggle/working'
MODEL_SAVE_PATH = os.path.join(OUTPUT_DIR, 'best_ast_model.pth')

# Annotation file paths (inside the main DEAM dataset)
STATIC_CSV_1_2000 = os.path.join(ANNOTATIONS_DIR, 'static_annotations_averaged_songs_1_2000.csv')
STATIC_CSV_2000_2058 = os.path.join(ANNOTATIONS_DIR, 'static_annotations_averaged_songs_2000_2058.csv')

# Hyperparameters
SEGMENT_LENGTH = 5      # seconds per audio segment
SAMPLE_RATE = 44100     # Hz
N_MELS = 128            # Mel frequency bins
TARGET_TIME_STEPS = 432 # Fixed width for padding (divisible by patch_size)
PATCH_SIZE = 16         # Patch size for transformer
EMBED_DIM = 256         # Embedding dimension
NUM_HEADS = 4           # Number of attention heads
NUM_LAYERS = 4          # Number of transformer layers
DROPOUT = 0.1
BATCH_SIZE = 16         # Increased for GPU
NUM_EPOCHS = 5          # Reduced for faster training (increase to 15-20 for better results)
LEARNING_RATE = 1e-4

# Split ratios
TRAIN_RATIO = 0.8
VAL_RATIO = 0.1
TEST_RATIO = 0.1

print("="*80)
print("AST Model Training - Kaggle Environment")
print("="*80)
print(f"Dataset Path: {DATASET_PATH}")
print(f"Audio Directory: {AUDIO_DIR}")
print(f"Annotations Directory: {ANNOTATIONS_DIR}")
print(f"Output Directory: {OUTPUT_DIR}")
print(f"Model Save Path: {MODEL_SAVE_PATH}")
print(f"\n⚙️  Training Configuration:")
print(f"  Epochs: {NUM_EPOCHS} (increase to 15-20 for better convergence)")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Learning Rate: {LEARNING_RATE}")
print("\nExpected Kaggle dataset structure:")
print("  /kaggle/input/deam-mediaeval-dataset-emotional-analysis-in-music/")
print("    ├── DEAM_audio/")
print("    │   └── MEMD_audio/  (audio files)")
print("    ├── DEAM_Annotations/")
print("    │   └── annotations/")
print("    │       └── annotations averaged per song/")
print("    │           └── song_level/  (CSV files)")
print("    └── features/")
print("="*80)

## 2. Data Loading and Exploration

In [None]:
# Verify paths exist
assert os.path.exists(AUDIO_DIR), f"Audio directory not found: {AUDIO_DIR}"
assert os.path.exists(STATIC_CSV_1_2000), f"Annotations file not found: {STATIC_CSV_1_2000}"
assert os.path.exists(STATIC_CSV_2000_2058), f"Annotations file not found: {STATIC_CSV_2000_2058}"

print("✓ All paths verified!")

# Load and combine annotation files
print("\nLoading annotation files...")
df1 = pd.read_csv(STATIC_CSV_1_2000)
df2 = pd.read_csv(STATIC_CSV_2000_2058)
df_annotations = pd.concat([df1, df2], axis=0, ignore_index=True)

# Strip whitespace from column names (important!)
df_annotations.columns = df_annotations.columns.str.strip()

print(f"Loaded {len(df_annotations)} song annotations")
print(f"Columns: {df_annotations.columns.tolist()}")
print(f"\nSample data:")
df_annotations.head()

In [None]:
# Verify audio files exist
print("Verifying audio files...")
available_songs = []
missing_songs = []

for _, row in df_annotations.iterrows():
    song_id = int(row['song_id'])
    audio_path = os.path.join(AUDIO_DIR, f"{song_id}.mp3")
    if os.path.exists(audio_path):
        available_songs.append(song_id)
    else:
        missing_songs.append(song_id)

print(f"Available audio files: {len(available_songs)}")
print(f"Missing audio files: {len(missing_songs)}")
if len(missing_songs) > 0 and len(missing_songs) < 20:
    print(f"Missing song IDs: {missing_songs}")

# Filter to only available songs
df_annotations = df_annotations[df_annotations['song_id'].isin(available_songs)].reset_index(drop=True)
print(f"Final dataset size: {len(df_annotations)} songs")

## 3. PyTorch Dataset Class

This custom dataset class:
1. Loads audio files and splits them into 5-second segments
2. Computes mel-spectrograms for each segment
3. Pads/truncates to fixed width (432 time steps)
4. Optionally applies SpecAugment (data augmentation)
5. Returns (spectrogram, [valence, arousal]) pairs

In [None]:
class DEAMSpectrogramDataset(Dataset):
    """
    PyTorch Dataset for DEAM audio with mel-spectrogram features.
    """
    def __init__(self, df_annotations, audio_dir, segment_length=5, sr=44100, 
                 target_time_steps=432, use_aug=False):
        self.df = df_annotations.reset_index(drop=True)
        self.audio_dir = audio_dir
        self.segment_length = segment_length
        self.sr = sr
        self.segment_samples = segment_length * sr
        self.target_time_steps = target_time_steps
        self.use_aug = use_aug
        
        # Pre-compute all valid segment indices
        self.segment_indices = []
        
        print(f"Initializing dataset (aug={'ON' if use_aug else 'OFF'})...")
        for song_idx in range(len(self.df)):
            song_id = int(self.df.iloc[song_idx]['song_id'])
            audio_path = os.path.join(self.audio_dir, f"{song_id}.mp3")
            
            try:
                y, _ = librosa.load(audio_path, sr=self.sr, mono=True)
                num_segments = len(y) // self.segment_samples
                
                for seg_idx in range(num_segments):
                    self.segment_indices.append((song_idx, seg_idx))
                    
            except Exception as e:
                print(f"Warning: Error loading song {song_id}: {e}")
                continue
        
        print(f"Dataset ready: {len(self.segment_indices)} segments")
    
    def __len__(self):
        return len(self.segment_indices)
    
    def __getitem__(self, idx):
        song_idx, seg_idx = self.segment_indices[idx]
        song_id = int(self.df.iloc[song_idx]['song_id'])
        
        # Load audio
        audio_path = os.path.join(self.audio_dir, f"{song_id}.mp3")
        y, _ = librosa.load(audio_path, sr=self.sr, mono=True)
        
        # Extract segment
        start = seg_idx * self.segment_samples
        segment = y[start:start + self.segment_samples]
        
        # Compute mel-spectrogram
        mel_spec = librosa.feature.melspectrogram(
            y=segment, sr=self.sr, n_mels=N_MELS, n_fft=2048, hop_length=512
        )
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
        
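        # Pad or truncate to target width
        # (pad with the quietest dB value: 0 dB is the loudest bin under ref=np.max)
        if mel_spec_db.shape[1] < self.target_time_steps:
            pad_width = self.target_time_steps - mel_spec_db.shape[1]
            mel_spec_db = np.pad(mel_spec_db, ((0, 0), (0, pad_width)),
                                 mode='constant', constant_values=mel_spec_db.min())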
        elif mel_spec_db.shape[1] > self.target_time_steps:
            mel_spec_db = mel_spec_db[:, :self.target_time_steps]
        
        # Convert to tensor
        mel_spec_db = torch.tensor(mel_spec_db, dtype=torch.float32).unsqueeze(0)
        
        # SpecAugment (optional)
        # Masks are filled with the segment mean: under ref=np.max, 0 dB is the loudest value
        if self.use_aug:
            mask_value = mel_spec_db.mean()
            
            freq_mask_param = 10
            if N_MELS > freq_mask_param:
                freq_mask_start = torch.randint(0, N_MELS - freq_mask_param, (1,)).item()
                mel_spec_db[:, freq_mask_start:freq_mask_start + freq_mask_param, :] = mask_value
            
            time_mask_param = 20
            if self.target_time_steps > time_mask_param:
                time_mask_start = torch.randint(0, self.target_time_steps - time_mask_param, (1,)).item()
                mel_spec_db[:, :, time_mask_start:time_mask_start + time_mask_param] = mask_value
        
        # Get labels
        label = torch.tensor([
            self.df.iloc[song_idx]['valence_mean'], 
            self.df.iloc[song_idx]['arousal_mean']
        ], dtype=torch.float32)
        
        return mel_spec_db, label

print("✓ Dataset class defined")

## 4. Data Preparation and Splitting

In [None]:
# Create initial dataset (no augmentation for computing statistics)
full_dataset = DEAMSpectrogramDataset(
    df_annotations, AUDIO_DIR, 
    target_time_steps=TARGET_TIME_STEPS, 
    use_aug=False
)

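# Split into train/val/test by song, not by segment: all segments of a song
# share one label, so a segment-level split would leak songs into val/test
rng = np.random.default_rng(42)
song_order = rng.permutation(len(df_annotations))
n_train_songs = int(TRAIN_RATIO * len(song_order))
n_val_songs = int(VAL_RATIO * len(song_order))
song_split = np.full(len(song_order), 2)  # 0 = train, 1 = val, 2 = test
song_split[song_order[:n_train_songs]] = 0
song_split[song_order[n_train_songs:n_train_songs + n_val_songs]] = 1

# Assign each (song_idx, seg_idx) entry to its song's split, in shuffled order
segment_split = np.array([song_split[s] for s, _ in full_dataset.segment_indices])
train_dataset_temp, val_dataset_temp, test_dataset_temp = [
    torch.utils.data.Subset(
        full_dataset, rng.permutation(np.flatnonzero(segment_split == k)).tolist()
    )
    for k in range(3)
]
train_size = len(train_dataset_temp)
val_size = len(val_dataset_temp)
test_size = len(test_dataset_temp)

print(f"\nDataset splits (segments, grouped by song):")
print(f"  Train: {train_size} (~{100*TRAIN_RATIO:.0f}% of songs)")
print(f"  Val:   {val_size} (~{100*VAL_RATIO:.0f}% of songs)")
print(f"  Test:  {test_size} (~{100*TEST_RATIO:.0f}% of songs)")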

In [None]:
# Compute normalization statistics from training set
print("Computing normalization statistics...")
train_specs_list = []
max_samples = min(1000, len(train_dataset_temp))

for idx in list(train_dataset_temp.indices)[:max_samples]:
    spec, _ = full_dataset[idx]
    train_specs_list.append(spec)

train_specs = torch.cat(train_specs_list, dim=0)
global_mean = train_specs.mean()
global_std = train_specs.std()

print(f"Global mean: {global_mean:.4f}")
print(f"Global std:  {global_std:.4f}")

In [None]:
# Create final datasets with augmentation
train_indices = train_dataset_temp.indices
val_indices = val_dataset_temp.indices
test_indices = test_dataset_temp.indices

train_dataset = DEAMSpectrogramDataset(
    df_annotations, AUDIO_DIR, 
    target_time_steps=TARGET_TIME_STEPS, 
    use_aug=True  # Enable augmentation for training
)

val_dataset = DEAMSpectrogramDataset(
    df_annotations, AUDIO_DIR, 
    target_time_steps=TARGET_TIME_STEPS, 
    use_aug=False
)

test_dataset = DEAMSpectrogramDataset(
    df_annotations, AUDIO_DIR, 
    target_time_steps=TARGET_TIME_STEPS, 
    use_aug=False
)

# Create DataLoaders
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, 
    sampler=SubsetRandomSampler(train_indices),
    num_workers=2, pin_memory=True
)

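# Note: SequentialSampler(val_indices) would iterate 0..len(val_indices)-1
# rather than the split's indices, so the subsets are selected explicitly
val_loader = DataLoader(
    torch.utils.data.Subset(val_dataset, val_indices),
    batch_size=BATCH_SIZE, shuffle=False,
    num_workers=2, pin_memory=True
)

test_loader = DataLoader(
    torch.utils.data.Subset(test_dataset, test_indices),
    batch_size=BATCH_SIZE, shuffle=False,
    num_workers=2, pin_memory=True
)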

print("✓ DataLoaders created")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches:   {len(val_loader)}")
print(f"  Test batches:  {len(test_loader)}")

## 5. Model Architecture (Audio Spectrogram Transformer)

The AST model architecture:
- **Patch Embedding**: Splits spectrogram into 16×16 patches using Conv2D
- **Positional Encoding**: Learnable position embeddings
- **CLS Token**: Special token for classification (like BERT/ViT)
- **Transformer Encoder**: Multi-head self-attention layers
- **Regression Head**: MLP to predict valence and arousal

In [None]:
class SpectrogramTransformer(nn.Module):
    """Audio Spectrogram Transformer for emotion recognition."""
    
    def __init__(self, input_height=128, input_width=432, patch_size=16, 
                 embed_dim=256, num_heads=4, num_layers=4, dropout=0.1):
        super().__init__()
        
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        
        # Calculate number of patches
        self.num_patches_h = input_height // patch_size  # 8
        self.num_patches_w = input_width // patch_size   # 27
        num_patches = self.num_patches_h * self.num_patches_w  # 216
        
        # Patch embedding
        self.patch_embed = nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
        
        # Learnable CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Positional embeddings
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + num_patches, embed_dim))
        
        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout, activation='gelu',
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Regression head
        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 2)  # valence, arousal
        )
        
        self._init_weights()
    
    def _init_weights(self):
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.kaiming_normal_(self.patch_embed.weight, mode='fan_out', nonlinearity='relu')
        if self.patch_embed.bias is not None:
            nn.init.constant_(self.patch_embed.bias, 0)
    
    def forward(self, x):
        B = x.shape[0]
        
        # Patch embedding
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)
        
        # Prepend CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        # Add positional embeddings
        x = x + self.pos_embed
        
        # Transformer encoding
        x = self.transformer_encoder(x)
        
        # Extract CLS token
        cls_output = x[:, 0]
        
        # Regression head
        predictions = self.head(cls_output)
        
        return predictions

print("✓ Model class defined")

In [None]:
# Initialize model
model = SpectrogramTransformer(
    input_height=N_MELS,
    input_width=TARGET_TIME_STEPS,
    patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT
)

# Move to GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print(f"Device: {device}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

## 6. Training Setup and Metrics

In [None]:
# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

def concordance_correlation_coefficient(y_true, y_pred):
    """Concordance Correlation Coefficient (CCC) between two 1-D arrays.

    CCC = 2 * cov / (var_true + var_pred + (mean_true - mean_pred)^2),
    using population (ddof=0) statistics throughout.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    
    mean_true = np.mean(y_true)
    mean_pred = np.mean(y_pred)
    var_true = np.var(y_true)
    var_pred = np.var(y_pred)
    
    # Population covariance, consistent with np.var's default ddof=0.
    # (np.cov defaults to ddof=1, which would inflate CCC by n / (n - 1).)
    covariance = np.mean((y_true - mean_true) * (y_pred - mean_pred))
    
    denominator = var_true + var_pred + (mean_true - mean_pred) ** 2
    ccc = 2 * covariance / denominator if denominator > 0 else 0.0
    
    return ccc

print("✓ Training setup complete")
print(f"Loss: MSELoss")
print(f"Optimizer: AdamW (lr={LEARNING_RATE})")
print(f"Scheduler: CosineAnnealingLR")
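
To make the learning-rate schedule concrete, here is a minimal plain-Python sketch of the closed-form cosine annealing curve that `CosineAnnealingLR` is documented to follow when stepped once per epoch. The function name `cosine_annealing_lr`, the base rate of `1e-4`, and `t_max = 15` are assumptions for illustration; the notebook's actual `LEARNING_RATE` and `NUM_EPOCHS` may differ.

```python
import math


def cosine_annealing_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


base_lr = 1e-4  # assumed value, not taken from the notebook
t_max = 15      # assumed to match the 15 epochs mentioned in the training-loop section
schedule = [cosine_annealing_lr(e, base_lr, t_max) for e in range(t_max + 1)]

for e in (0, 5, 10, 15):
    print(f"epoch {e:2d}: lr = {schedule[e]:.2e}")
```

The rate falls slowly at first, fastest around the midpoint, and flattens out again as it approaches the minimum.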

## 7. Training Loop

Training time depends on your GPU. On Kaggle with a GPU enabled, expect roughly 30-60 minutes for 15 epochs.

In [None]:
best_val_loss = float('inf')
train_losses = []
val_losses = []

print("="*80)
print("STARTING TRAINING")
print("="*80)

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
    print("-" * 40)
    
    # Training
    model.train()
    train_loss = 0
    train_batches = 0
    
    for batch_idx, (specs, labels) in enumerate(train_loader):
        # Normalize
        specs = (specs - global_mean) / global_std
        specs, labels = specs.to(device), labels.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(specs)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        train_loss += loss.item()
        train_batches += 1
        
        if (batch_idx + 1) % 50 == 0:
            print(f"  Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}")
    
    train_loss /= train_batches
    train_losses.append(train_loss)
    
    # Read the LR used during this epoch before the scheduler updates it
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()
    print(f"Train Loss: {train_loss:.4f}, LR: {current_lr:.6f}")
    
    # Validation
    model.eval()
    val_loss = 0
    val_batches = 0
    val_preds = []
    val_true = []
    
    with torch.no_grad():
        for specs, labels in val_loader:
            specs = (specs - global_mean) / global_std
            specs, labels = specs.to(device), labels.to(device)
            
            outputs = model(specs)
            loss = criterion(outputs, labels)
            
            val_loss += loss.item()
            val_batches += 1
            
            val_preds.append(outputs.cpu().numpy())
            val_true.append(labels.cpu().numpy())
    
    val_loss /= val_batches
    val_losses.append(val_loss)
    
    # Metrics
    val_preds = np.vstack(val_preds)
    val_true = np.vstack(val_true)
    
    mse_valence = mean_squared_error(val_true[:, 0], val_preds[:, 0])
    mse_arousal = mean_squared_error(val_true[:, 1], val_preds[:, 1])
    ccc_valence = concordance_correlation_coefficient(val_true[:, 0], val_preds[:, 0])
    ccc_arousal = concordance_correlation_coefficient(val_true[:, 1], val_preds[:, 1])
    
    print(f"Val Loss: {val_loss:.4f}")
    print(f"  Valence - MSE: {mse_valence:.4f}, CCC: {ccc_valence:.4f}")
    print(f"  Arousal - MSE: {mse_arousal:.4f}, CCC: {ccc_arousal:.4f}")
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'global_mean': global_mean,
            'global_std': global_std,
        }, MODEL_SAVE_PATH)
        print(f"  ✓ New best model saved! (Val Loss: {val_loss:.4f})")

print("\n" + "="*80)
print("TRAINING COMPLETE")
print("="*80)
print(f"Best validation loss: {best_val_loss:.4f}")
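
The training loop above clips the global gradient norm to 1.0 before each optimizer step. The sketch below illustrates the general idea on a flat list of numbers. The function `clip_by_global_norm` is a hypothetical stand-in, not PyTorch's implementation, and the small `eps` term is an assumption added to avoid division by zero.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale a flat list of gradient values so their L2 norm is at most max_norm.

    Illustrative sketch of global-norm clipping; returns the (possibly scaled)
    gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Scale down only when the norm exceeds the threshold; never scale up.
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))
print(f"norm before: {norm_before:.3f}, norm after: {norm_after:.3f}")
```

Because every component is multiplied by the same factor, the gradient keeps its direction and only its length changes.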

## 8. Test Set Evaluation

In [None]:
# Load best model (map_location lets a GPU-trained checkpoint load on CPU as well)
checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

test_preds = []
test_true = []

print("Evaluating on test set...")
with torch.no_grad():
    for specs, labels in test_loader:
        specs = (specs - global_mean) / global_std
        specs, labels = specs.to(device), labels.to(device)
        
        outputs = model(specs)
        test_preds.append(outputs.cpu().numpy())
        test_true.append(labels.cpu().numpy())

test_preds = np.vstack(test_preds)
test_true = np.vstack(test_true)

# Compute metrics
test_mse_valence = mean_squared_error(test_true[:, 0], test_preds[:, 0])
test_mse_arousal = mean_squared_error(test_true[:, 1], test_preds[:, 1])
test_mae_valence = mean_absolute_error(test_true[:, 0], test_preds[:, 0])
test_mae_arousal = mean_absolute_error(test_true[:, 1], test_preds[:, 1])
test_ccc_valence = concordance_correlation_coefficient(test_true[:, 0], test_preds[:, 0])
test_ccc_arousal = concordance_correlation_coefficient(test_true[:, 1], test_preds[:, 1])

print("\n" + "="*80)
print("FINAL TEST RESULTS")
print("="*80)
print(f"Valence - MSE: {test_mse_valence:.4f}, MAE: {test_mae_valence:.4f}, CCC: {test_ccc_valence:.4f}")
print(f"Arousal - MSE: {test_mse_arousal:.4f}, MAE: {test_mae_arousal:.4f}, CCC: {test_ccc_arousal:.4f}")
print("="*80)

## 9. Visualizations

In [None]:
# Plot 1: Training curves
plt.figure(figsize=(14, 5))

plt.subplot(1, 2, 1)
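# Use the recorded history length so the plot also works if training stopped early
epochs_ran = range(1, len(train_losses) + 1)
plt.plot(epochs_ran, train_losses, 'b-', label='Train Loss', linewidth=2)
plt.plot(epochs_ran, val_losses, 'r-', label='Val Loss', linewidth=2)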
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('MSE Loss', fontsize=12)
plt.title('Training and Validation Loss', fontsize=14)
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)

# Plot 2: Predictions vs Ground Truth
plt.subplot(1, 2, 2)
plt.scatter(test_true[:, 0], test_preds[:, 0], alpha=0.5, label='Valence', s=20, c='blue')
plt.scatter(test_true[:, 1], test_preds[:, 1], alpha=0.5, label='Arousal', s=20, c='red')
plt.plot([1, 9], [1, 9], 'k--', label='Perfect Prediction', linewidth=2)
plt.xlabel('Ground Truth', fontsize=12)
plt.ylabel('Prediction', fontsize=12)
plt.title('Test Set Predictions', fontsize=14)
plt.legend(fontsize=10)
plt.grid(True, alpha=0.3)
plt.xlim(1, 9)
plt.ylim(1, 9)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'training_results.png'), dpi=150, bbox_inches='tight')
plt.show()

print("✓ Training curves saved")

In [None]:
# Plot: Valence-Arousal Distribution
plt.figure(figsize=(10, 10))
plt.scatter(test_true[:, 0], test_true[:, 1], alpha=0.6, c='blue', 
            label='Ground Truth', s=40, edgecolors='black', linewidth=0.5)
plt.scatter(test_preds[:, 0], test_preds[:, 1], alpha=0.6, c='red', 
            label='Predictions', s=40, edgecolors='black', linewidth=0.5, marker='^')
plt.xlabel('Valence (1-9)', fontsize=14)
plt.ylabel('Arousal (1-9)', fontsize=14)
plt.title('Valence-Arousal Distribution (Test Set)', fontsize=16, fontweight='bold')
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.xlim(1, 9)
plt.ylim(1, 9)

# Add quadrant labels
plt.axhline(y=5, color='gray', linestyle='--', alpha=0.5)
plt.axvline(x=5, color='gray', linestyle='--', alpha=0.5)
plt.text(2.5, 7.5, 'Distressed\n(Low V, High A)', ha='center', fontsize=10, alpha=0.7)
plt.text(6.5, 7.5, 'Excited\n(High V, High A)', ha='center', fontsize=10, alpha=0.7)
plt.text(2.5, 2.5, 'Sad\n(Low V, Low A)', ha='center', fontsize=10, alpha=0.7)
plt.text(6.5, 2.5, 'Calm\n(High V, Low A)', ha='center', fontsize=10, alpha=0.7)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'valence_arousal_distribution.png'), 
            dpi=150, bbox_inches='tight')
plt.show()

print("✓ Valence-Arousal plot saved")
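
The quadrant labels drawn on the plot above can also be assigned per song. The following is a minimal sketch that assumes the same midpoint of 5 on the 1-9 scale. The function name `emotion_quadrant` is hypothetical, and treating values exactly at the midpoint as "high" is an arbitrary choice.

```python
def emotion_quadrant(valence, arousal, midpoint=5.0):
    """Map a (valence, arousal) pair on the 1-9 scale to a quadrant label.

    Hypothetical helper; values exactly at the midpoint count as "high".
    """
    high_v = valence >= midpoint
    high_a = arousal >= midpoint
    if high_v and high_a:
        return "Excited"
    if high_v:
        return "Calm"
    if high_a:
        return "Distressed"
    return "Sad"


for v, a in [(7.2, 6.8), (6.5, 3.1), (2.4, 7.0), (3.0, 2.5)]:
    print(f"V={v}, A={a} -> {emotion_quadrant(v, a)}")
```

Comparing the quadrant of each prediction with the quadrant of its ground-truth label gives a coarse, classification-style view of the regression results.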

## 10. Summary and Next Steps

### What We Accomplished
✅ Loaded and preprocessed the DEAM dataset  
✅ Created mel-spectrogram features from audio  
✅ Built an Audio Spectrogram Transformer model  
✅ Trained the model with data augmentation  
✅ Evaluated on test set with multiple metrics  
✅ Visualized predictions and training progress  

### Model Performance
The model predicts **valence** (how positive or negative the emotion is) and **arousal** (how calm or energetic it is) on a 1-9 scale.

**Key Metrics:**
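- **MSE (Mean Squared Error)**: Average squared prediction error; penalizes large errors more heavily (lower is better)
- **MAE (Mean Absolute Error)**: Average absolute prediction error, in points on the rating scale (lower is better)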
- **CCC (Concordance Correlation Coefficient)**: Agreement measure (-1 to 1, higher is better)
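
To see how the three metrics differ, the sketch below computes them in plain Python on a toy example in which every prediction is exactly one point too high. The function names and the toy values are assumptions for illustration. Pearson correlation would be a perfect 1.0 in this case, whereas CCC also penalizes the constant offset.

```python
def mean(values):
    values = list(values)
    return sum(values) / len(values)


def mse(y_true, y_pred):
    return mean((t - p) ** 2 for t, p in zip(y_true, y_pred))


def mae(y_true, y_pred):
    return mean(abs(t - p) for t, p in zip(y_true, y_pred))


def ccc(y_true, y_pred):
    """Concordance correlation coefficient using population statistics."""
    mean_t, mean_p = mean(y_true), mean(y_pred)
    var_t = mean((t - mean_t) ** 2 for t in y_true)
    var_p = mean((p - mean_p) ** 2 for p in y_pred)
    cov = mean((t - mean_t) * (p - mean_p) for t, p in zip(y_true, y_pred))
    denom = var_t + var_p + (mean_t - mean_p) ** 2
    return 2 * cov / denom if denom > 0 else 0.0


y_true = [3.0, 4.0, 5.0, 6.0, 7.0]     # toy ratings, not DEAM data
y_shifted = [t + 1.0 for t in y_true]  # predictions with a constant +1 bias

print(f"MSE={mse(y_true, y_shifted):.2f}  "
      f"MAE={mae(y_true, y_shifted):.2f}  "
      f"CCC={ccc(y_true, y_shifted):.2f}")
# MSE=1.00  MAE=1.00  CCC=0.80
```

A model can therefore track the ordering of songs perfectly and still score below 1.0 on CCC if its predictions are systematically biased or compressed.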

### Saved Outputs
All outputs are saved to `/kaggle/working/`:
- `best_ast_model.pth` - Trained model checkpoint
- `training_results.png` - Training curves
- `valence_arousal_distribution.png` - Prediction visualization

### Next Steps
1. **Fine-tune hyperparameters**: Try different learning rates, batch sizes, or model architectures
2. **Increase training epochs**: Train longer for potentially better performance
3. **Try transfer learning**: Use pre-trained audio models
4. **Ensemble models**: Combine multiple models for better predictions
5. **Deploy the model**: Use the saved checkpoint for inference on new audio files

### How to Use the Saved Model

```python
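# Rebuild the architecture with the training hyperparameters, then load the weights
model = SpectrogramTransformer(
    input_height=N_MELS, input_width=TARGET_TIME_STEPS, patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM, num_heads=NUM_HEADS, num_layers=NUM_LAYERS, dropout=DROPOUT
).to(device)
checkpoint = torch.load('/kaggle/working/best_ast_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Inference on new audio
# ... (compute mel-spectrogram, normalize with checkpoint['global_mean'] and
#      checkpoint['global_std'], predict)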
```