In [6]:
# ====================================================================
# 1. IMPORTS AND INITIAL CONFIGURATION (UPDATED)
# ====================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt
import time 
from IPython.display import Audio, display
from ipywidgets import interactive
from pathlib import Path
import warnings
import timm # Added for StudentModel

# Suppress warnings
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

# Configuration of Paths and Variables
# ----------------------------------
RESULTS_DIR = Path("results") # Phase 1 Folder
# If you want to use the models from results_v2, change the path to: RESULTS_DIR = Path("results_v2")

# Display name -> checkpoint file name inside RESULTS_DIR.
# The display name must contain the architecture ("RepVGG" or "VanillaCNN"),
# because the loader selects the network by substring match on it.
# The first entry is the model whose probability distribution is plotted.
MODELS = {
    # We will use the Phase 1 winner
    "ü•á Winning Student (RepVGG KD)": "best_distilled_RepVGG.pth",
    "üöÄ Fastest Student (VanillaCNN KD)": "best_distilled_VanillaCNN.pth",
    "üìâ Baseline Without KD (RepVGG)": "best_baseline_RepVGG.pth",
}

# Test CSV File (IEMOCAP Session 5)
TEST_CSV_PATH = Path("processed_combined_teacher/combined_test.csv") 

# Preprocessing Parameters (Must match your training script)
SAMPLE_RATE = 16000
N_MELS = 64
MAX_DURATION_S = 8.0 # NOTE: If you used 3.0s in training, ADJUST HERE
N_FFT = 1024
HOP_LENGTH = 512

# Emotion Mapping: class index -> display label
EMOTION_LABELS = {0: 'Anger', 1: 'Happy', 2: 'Sadness', 3: 'Neutral'}

# Inverse mapping: dataset emotion code -> class index.
# visualize_and_predict() reads this name to decode the ground-truth label.
# NOTE(review): assumes the CSV 'emotion' column uses 'ang'/'hap'/'sad'/'neu'
# (the same codes as COLOR_MAP_KEYS below) -- confirm against the training script.
EMOTION_MAP_REVERSE = {'ang': 0, 'hap': 1, 'sad': 2, 'neu': 3}

# Plot colors, keyed by a short lower-case emotion name.
COLOR_MAP = {
    'anger': '#e74c3c', # Red
    'happy': '#2ecc71', # Green
    'sad': '#3498db', # Blue
    'neutral': '#7f8c8d'  # Gray
}
# Dataset emotion code -> COLOR_MAP key.
# Assuming your CSV uses 'ang', 'hap', 'sad', 'neu'
COLOR_MAP_KEYS = {'ang': 'anger', 'hap': 'happy', 'sad': 'sad', 'neu': 'neutral'}


# ------------------------------------------------
# 2. CLASS DEFINITIONS (COPIED FROM YOUR TRAINING SCRIPT)
# ------------------------------------------------

class VanillaCNN(nn.Module):
    """Small four-stage CNN classifier for single-channel 2-D inputs.

    Each stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU, with channel
    widths 1 -> 16 -> 32 -> 64 -> 128. The first three stages are followed by
    a 2x2 max-pool; the last one by global average pooling, so the spatial
    size of the input does not affect the classifier input width (128).

    The layer order inside ``self.features`` fixes the ``state_dict`` key
    names (``features.0.weight`` ... ``features.13.*``), which saved
    checkpoints depend on.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        widths = (1, 16, 32, 64, 128)
        stages = list(zip(widths[:-1], widths[1:]))

        layers = []
        for position, (c_in, c_out) in enumerate(stages):
            layers.append(nn.Conv2d(c_in, c_out, 3, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
            # Every stage except the last halves the resolution.
            if position != len(stages) - 1:
                layers.append(nn.MaxPool2d(2))
        # Collapse whatever spatial extent is left to 1x1.
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))

        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(widths[-1], num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        pooled = self.features(x)
        return self.classifier(torch.flatten(pooled, 1))

class StudentModel(nn.Module):
    """Factory wrapper that exposes every student architecture as ``self.backbone``.

    Args:
        model_name: ``'VanillaCNN'`` for the local CNN, otherwise a model name
            passed to ``timm.create_model``.
        num_classes: Number of output logits.

    Weights are never downloaded here (``pretrained=False``); the caller is
    expected to load a checkpoint afterwards.
    """

    def __init__(self, model_name, num_classes=4):
        super().__init__()
        if model_name == 'VanillaCNN':
            self.backbone = VanillaCNN(num_classes)
            return

        try:
            # Spectrogram input has a single channel, so ask timm for in_chans=1.
            self.backbone = timm.create_model(
                model_name, pretrained=False, num_classes=num_classes, in_chans=1
            )
        except Exception:
            # Retry without in_chans for constructors that reject it. This was
            # a bare `except:`, which also swallowed KeyboardInterrupt and
            # SystemExit. A genuine failure (e.g. unknown model name) still
            # propagates from the call below.
            self.backbone = timm.create_model(
                model_name, pretrained=False, num_classes=num_classes
            )

    def forward(self, x):
        """Delegate to the wrapped backbone and return its output unchanged."""
        return self.backbone(x)

# ------------------------------------------------
# 3. DATA LOADING AND PREPROCESSING
# ------------------------------------------------
# Load the test-set index. If the CSV is missing, df_test is bound to an empty
# DataFrame so later code reaches its own "DataFrame is empty" handling instead
# of failing with a NameError on an unbound name.
try:
    df_test = pd.read_csv(TEST_CSV_PATH)
except FileNotFoundError:
    df_test = pd.DataFrame()
    print(f"‚ùå ERROR: CSV file not found at {TEST_CSV_PATH}. Adjust the path.")
else:
    print(f"‚úÖ Test Dataset Loaded: {len(df_test)} samples.")

def preprocess_audio(file_path):
    """Turn an audio file into a normalized log-Mel tensor of shape (1, 1, N_MELS, T).

    The waveform is resampled to SAMPLE_RATE, forced to exactly
    MAX_DURATION_S seconds (cut at the end or zero-padded at the end),
    converted to a log-Mel spectrogram relative to its own maximum, and
    standardized to zero mean / unit variance unless it is (near) constant.
    """
    waveform, _ = librosa.load(file_path, sr=SAMPLE_RATE)

    # Fix the length: negative shortfall means the clip is too long.
    n_samples = int(MAX_DURATION_S * SAMPLE_RATE)
    shortfall = n_samples - len(waveform)
    if shortfall < 0:
        waveform = waveform[:n_samples]
    else:
        waveform = np.pad(waveform, (0, shortfall), mode='constant')

    power_spec = librosa.feature.melspectrogram(
        y=waveform,
        sr=SAMPLE_RATE,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        n_mels=N_MELS,
    )
    features = librosa.power_to_db(power_spec, ref=np.max)

    # Normalization (IMPORTANT: must match the normalization used in training).
    # Skipped for near-constant input to avoid dividing by ~0.
    spread = features.std()
    if spread > 1e-6:
        features = (features - features.mean()) / spread

    # Add batch and channel axes: (N_MELS, T) -> (1, 1, N_MELS, T).
    return torch.tensor(features).float()[None, None]

def load_and_evaluate_model(model_name, input_tensor):
    """Load one checkpoint from MODELS and classify a single input on CPU.

    Args:
        model_name: Key of MODELS. It must contain "VanillaCNN" or "RepVGG";
            that substring selects the architecture to instantiate.
        input_tensor: A batch holding exactly one sample (the prediction is
            read with ``.item()``), e.g. the output of preprocess_audio.

    Returns:
        A ``(error, results)`` tuple. On failure ``error`` is a message string
        and ``results`` is None. On success ``error`` is None and ``results``
        is a dict with keys ``"prediction"`` (label), ``"latency"``
        (formatted milliseconds) and ``"probabilities"`` (label -> percent).
    """
    model_path = RESULTS_DIR / MODELS[model_name]

    if not model_path.exists():
        return f"‚ùå Model not found: {model_path}", None

    # Select the correct base architecture name for the StudentModel factory.
    if "VanillaCNN" in model_name:
        arch_name = 'VanillaCNN'
    elif "RepVGG" in model_name:
        arch_name = 'repvgg_a0'
    else:
        return "‚ùå Unknown architecture.", None

    model = StudentModel(arch_name, num_classes=4)

    try:
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # you produced yourself or otherwise trust.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
        model.eval()
    except Exception as e:
        return f"‚ùå Error loading weights: {e}", None

    # Inference and latency measurement (crucial for the trade-off).
    with torch.no_grad():
        # Untimed warm-up pass: the first forward call pays one-time setup
        # costs that would otherwise be attributed to the model.
        model(input_tensor)

        # perf_counter is monotonic and high-resolution; time.time() is
        # wall-clock and can jump or be too coarse for millisecond timings.
        start_time = time.perf_counter()
        output = model(input_tensor)
        latency_ms = (time.perf_counter() - start_time) * 1000

    # Get probabilities and prediction.
    probabilities = F.softmax(output, dim=1).squeeze().tolist()
    pred_index = torch.argmax(output, dim=1).item()

    results = {
        "prediction": EMOTION_LABELS.get(pred_index, "N/A"),
        "latency": f"{latency_ms:.2f} ms",
        "probabilities": {EMOTION_LABELS[i]: p * 100 for i, p in enumerate(probabilities)}
    }
    return None, results

def _bar_color_for_label(label):
    """Return the COLOR_MAP color for a display label, or 'gray' if none matches.

    COLOR_MAP is keyed by 'anger' / 'happy' / 'sad' / 'neutral', while the
    display labels are 'Anger' / 'Happy' / 'Sadness' / 'Neutral'. An exact
    lookup on the lower-cased label therefore misses 'sadness', so the label
    is matched by key prefix instead.
    """
    lowered = label.lower()
    for key, color in COLOR_MAP.items():
        if lowered.startswith(key):
            return color
    return 'gray'


def visualize_and_predict(row_index):
    """Main function connected to the interactive widget.

    For row ``row_index`` of the module-level ``df_test`` it:
      1. preprocesses the file in column 'wav_path',
      2. plots the spectrogram, titled with the true emotion,
      3. shows an audio player,
      4. runs every model in MODELS and tabulates prediction / latency,
      5. plots the class probabilities of the first model in MODELS.

    The ground-truth code in column 'emotion' is decoded through the
    module-level EMOTION_MAP_REVERSE (code -> class index); unknown codes
    fall back to class 3. Returns None; all output is displayed.
    """
    if df_test.empty:
        print("The test DataFrame is empty.")
        return

    sample = df_test.iloc[row_index]
    file_path = sample['wav_path']
    true_emotion_code = sample['emotion'].lower()
    true_emotion_label = EMOTION_LABELS.get(EMOTION_MAP_REVERSE.get(true_emotion_code, 3), 'N/A')

    # 1. Preprocessing and Audio
    try:
        input_tensor = preprocess_audio(file_path)
    except Exception as e:
        # UI boundary: report the problem in the widget output instead of a traceback.
        print(f"‚ùå Preprocessing/Load Error: {e}")
        return

    # 2. Spectrogram Visualization
    fig, ax = plt.subplots(figsize=(10, 5))
    log_mel_spectrogram = input_tensor.squeeze().numpy()

    # Keep the returned image object: the colorbar needs it as its mappable.
    img = librosa.display.specshow(
        log_mel_spectrogram,
        sr=SAMPLE_RATE,
        x_axis='time',
        y_axis='mel',
        ax=ax,
        hop_length=HOP_LENGTH
    )

    # Title color follows the true emotion code; black when the code is unknown.
    title_color = COLOR_MAP.get(COLOR_MAP_KEYS.get(true_emotion_code, 'black'), 'black')

    ax.set_title(f"Spectrogram - TRUE Emotion: {true_emotion_label.upper()}",
                 color=title_color,
                 fontsize=16,
                 fontweight='bold')

    # NOTE(review): the plotted values are the model input, which
    # preprocess_audio standardizes, so the "dB" unit is only nominal.
    fig.colorbar(img, ax=ax, format='%+2.0f dB')

    plt.tight_layout()
    plt.show()

    # 3. Play Audio
    # NOTE(review): the player is given the file itself, so it plays the full
    # recording even though the caption (and the model input) are truncated.
    print("\nüîä Audio (Truncated to 8 seconds):")
    display(Audio(file_path, rate=SAMPLE_RATE))

    # 4. Inferences and Comparison
    print("\n--- üß† Classification Results ---")

    model_results = {}
    for model_name in MODELS:
        error, result = load_and_evaluate_model(model_name, input_tensor)
        model_results[model_name] = {"error": error, "result": result}

    # 5. Tabulation of Results
    # Latency is labelled CPU: the checkpoints are loaded onto the CPU and the
    # input tensor is never moved to another device.
    comparison_data = []
    for model_name, data in model_results.items():
        if data['result']:
            pred_label_full = data['result']['prediction']
            # Prefix comparison tolerates labels with a suffix, e.g. 'Anger (Enojo)'.
            is_correct = pred_label_full.lower().startswith(true_emotion_label.lower())
            comparison_data.append({
                "Model": model_name,
                "Prediction": pred_label_full,
                "Latency (CPU)": data['result']['latency'],
                "Status": "‚úÖ Correct" if is_correct else "‚ùå Incorrect"
            })
        elif data['error']:
            comparison_data.append({
                "Model": model_name,
                "Prediction": data['error'],
                "Latency (CPU)": "N/A",
                "Status": "ERROR"
            })

    display(pd.DataFrame(comparison_data))

    # 6. Probability Plot of the Winning Model (first entry of MODELS)
    winner_name = next(iter(MODELS))
    print(f"\nüìà Probability Distribution: {winner_name}")

    winner_result = model_results[winner_name]['result']
    if not winner_result:
        print("Could not generate probability plot.")
        return

    winner_probs = winner_result['probabilities']
    emotions = list(winner_probs.keys())
    probabilities = list(winner_probs.values())

    plt.figure(figsize=(8, 4))
    colors = [_bar_color_for_label(e) for e in emotions]
    bars = plt.bar(emotions, probabilities, color=colors)

    # Add value labels just above each bar.
    for bar in bars:
        plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1,
                 f'{bar.get_height():.1f}%', ha='center', fontsize=10)

    plt.title("Probabilities (RepVGG Distilled)")
    plt.ylabel("Probability (%)")
    plt.ylim(0, 100)
    plt.axhline(y=25, color='gray', linestyle='--') # Random chance line (1/4)
    plt.tight_layout()
    plt.show()

# ------------------------------------------------
# 4. INTERACTIVE INTERFACE
# ------------------------------------------------

# Create the widget (selector)
from ipywidgets import IntSlider

# Clamp the upper bound so an empty DataFrame yields a valid 0..0 slider.
last_index = max(len(df_test) - 1, 0)

# `interactive` treats keyword arguments as abbreviations for the callback's
# parameters, so a bare `continuous_update=False` is silently dropped and the
# callback fires on every drag step. Set it on an explicit slider instead, so
# models are only loaded and run once the slider is released.
sample_slider = interactive(
    visualize_and_predict,
    row_index=IntSlider(
        value=last_index // 2,  # same midpoint default as the (min, max, step) abbreviation
        min=0,
        max=last_index,
        step=1,
        description='row_index',
        continuous_update=False,
    ),
)

print("\n\n=============== üöÄ INTERACTIVE TRADE-OFF ANALYSIS (UAR vs. Latency) ===============\n")
print(f"Select an index (from 0 to {last_index}) from the test set (IEMOCAP Ses. 5) to: \n1. View the Spectrogram \n2. Hear the Audio\n3. Compare the Prediction and Latency of your key models.")

# Display the widget
display(sample_slider)

✅ Test Dataset Loaded: 1241 samples.



Select an index (from 0 to 1240) from the test set (IEMOCAP Ses. 5) to: 
1. View the Spectrogram 
2. Hear the Audio
3. Compare the Prediction and Latency of your key models.


interactive(children=(IntSlider(value=620, description='row_index', max=1240), Output()), _dom_classes=('widge…