# Interactive Model Testing

This notebook provides an interactive, widget-based interface for testing trained speech separation and enhancement models.

## Features

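- Auto-discovers available checkpoints in the hierarchical `checkpoints/{model_name}/{task}/{run_name}/` directory structure
- Dropdown selectors for checkpoint, task, and variant, plus a numeric sample-ID field
- One-click model loading (**Load Model**) and testing (**Test Model**)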
- Audio playback and waveform visualization

In [None]:
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from IPython.display import Audio, display
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual
import yaml
import sys

# Add project root to path
sys.path.insert(0, str(Path.cwd()))

from config import Config
from models import get_model
from datasets import get_dataset
from torchmetrics.audio import ScaleInvariantSignalDistortionRatio
from evaluate import load_model_from_checkpoint

## 1. Checkpoint Discovery

Scan the checkpoint directory structure to find all available models.

In [None]:
def discover_checkpoints(checkpoint_root='checkpoints'):
    """Discover all checkpoint files below ``checkpoint_root``.

    The expected layout is ``{model_name}/{task}/{run_name}/*.pt``.
    The flatter ``{model_name}/{task}/*.pt`` (no run folder) is also
    accepted, as is any deeper nesting. Files fewer than three levels deep
    (directly in the root or in ``{model_name}/``) are ignored, and so are
    directories whose names happen to end in ``.pt``.

    Args:
        checkpoint_root: Directory to scan (str or Path).

    Returns:
        dict: ``{display_name: checkpoint_path}``. ``display_name`` is the
        file's path relative to ``checkpoint_root`` with '/' separators.
        ``checkpoint_path`` is the full path as a str. Entries are ordered
        by display name. The dict is empty (and a warning is printed) when
        ``checkpoint_root`` does not exist.
    """
    checkpoint_root = Path(checkpoint_root)
    checkpoints = {}

    if not checkpoint_root.exists():
        print(f"Warning: Checkpoint directory {checkpoint_root} does not exist")
        return checkpoints

    for pt_file in checkpoint_root.rglob('*.pt'):
        # A directory named "something.pt" also matches the pattern.
        if not pt_file.is_file():
            continue

        relative = pt_file.relative_to(checkpoint_root)

        # Need at least model_name/task/<checkpoint>.pt
        if len(relative.parts) < 3:
            continue

        # Key on the full relative path. Truncating to the first four
        # components would give every file in a deeper sub-folder the same
        # (directory) name, silently overwriting earlier entries.
        checkpoints[relative.as_posix()] = str(pt_file)

    return dict(sorted(checkpoints.items()))

# Scan the default checkpoint directory and report what was found
available_checkpoints = discover_checkpoints()

if not available_checkpoints:
    print("No checkpoints found. Train a model first!")
else:
    print(f"Found {len(available_checkpoints)} checkpoint(s):")
    for checkpoint_name in sorted(available_checkpoints):
        print(f"  - {checkpoint_name}")

## 2. Configuration

Set up dataset paths and parameters.

In [None]:
# Dataset configuration.
# Prefer the data root declared in config.py. If it cannot be imported or
# read, fall back to the hard-coded path below (edit it for your machine).
try:
    from config import DataConfig
    config_defaults = DataConfig()
    DATA_ROOT = Path(config_defaults.polsess.data_root)
except Exception as e:
    DATA_ROOT = Path("F:/PolSMSE/EksperymentyMOWA/BAZY/MOWA/PolSESS_C_both/PolSESS_C_both")
    print(f"Using manual dataset path: {DATA_ROOT}")
    print(f"(Could not load from config: {e})")
else:
    print(f"Using dataset path from config.py: {DATA_ROOT}")

# Sample rate (Hz) handed to the tester for audio playback and plot time axes
SAMPLE_RATE = 8000

# PolSESS variants: code -> human-readable description
VARIANTS = dict([
    ('C', 'Clean (no background)'),
    ('S', 'Scene only'),
    ('E', 'Event only'),
    ('R', 'Reverb only'),
    ('SE', 'Scene + Event'),
    ('SR', 'Scene + Reverb'),
    ('ER', 'Event + Reverb'),
    ('SER', 'Scene + Event + Reverb'),
])

# Tasks: code -> human-readable description
TASKS = dict(
    ES='Enhance Single speaker',
    EB='Enhance Both speakers',
    SB='Separate Both speakers',
)

print("Configuration loaded.")
print(f"Sample rate: {SAMPLE_RATE} Hz")

# Sanity-check the dataset location before anything tries to read from it
if DATA_ROOT.exists():
    print(f"✓ Dataset path verified: {DATA_ROOT}")
else:
    print(f"\n⚠️  WARNING: Dataset path does not exist: {DATA_ROOT}")
    print("Please update DATA_ROOT in this cell to point to your PolSESS dataset")

## 3. Interactive Model Testing

Use the widgets to select a checkpoint, task, variant, and sample ID. Click **Load Model** to load the checkpoint, then **Test Model** to run it on the selected sample.

In [None]:
class ModelTester:
    """Interactive model testing with dropdown widgets.

    Builds an ipywidgets UI for choosing a checkpoint, a task (``TASKS``), a
    variant (``VARIANTS``) and a test-set sample index.

    - "Load Model" loads the selected checkpoint.
    - "Test Model" runs the model on one sample of the PolSESS ``test``
      subset, prints SI-SDR metrics, optionally plots waveforms, and
      renders audio players.

    All output goes to an ``ipywidgets.Output`` area.

    Args:
        checkpoints: Mapping of display name -> checkpoint file path.
        data_root: Root directory of the PolSESS dataset.
        sample_rate: Sample rate in Hz used for audio playback and the plot
            time axis. The default (16000) differs from the notebook's
            ``SAMPLE_RATE`` (8000), so pass the rate explicitly.
    """

    def __init__(self, checkpoints, data_root, sample_rate=16000):
        self.checkpoints = checkpoints
        self.data_root = Path(data_root)
        self.sample_rate = sample_rate
        self.model = None
        self.config = None
        # Cached test dataset, and the (data_root, task, variant) key it was built for
        self.dataset = None
        self._dataset_key = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Create widgets
        self.checkpoint_dropdown = widgets.Dropdown(
            options=sorted(checkpoints.keys()),
            description='Checkpoint:',
            style={'description_width': '120px'},
            layout=widgets.Layout(width='600px')
        )

        self.task_dropdown = widgets.Dropdown(
            options=[(f"{k}: {v}", k) for k, v in TASKS.items()],
            description='Task:',
            style={'description_width': '120px'},
            layout=widgets.Layout(width='400px')
        )

        self.variant_dropdown = widgets.Dropdown(
            options=[(f"{k}: {v}", k) for k, v in VARIANTS.items()],
            description='Variant:',
            style={'description_width': '120px'},
            layout=widgets.Layout(width='400px')
        )

        self.sample_id_text = widgets.IntText(
            value=0,
            description='Sample ID:',
            style={'description_width': '120px'},
            layout=widgets.Layout(width='200px')
        )

        self.show_plots_checkbox = widgets.Checkbox(
            value=False,
            description='Show waveform plots',
            style={'description_width': '1px'},
            layout=widgets.Layout(width='200px')
        )

        self.load_button = widgets.Button(
            description='Load Model',
            button_style='info',
            icon='download',
            layout=widgets.Layout(width='150px')
        )

        self.test_button = widgets.Button(
            description='Test Model',
            button_style='success',
            icon='play',
            layout=widgets.Layout(width='150px')
        )

        self.output = widgets.Output()

        # Button callbacks
        self.load_button.on_click(self._on_load_clicked)
        self.test_button.on_click(self._on_test_clicked)

        # Testing is only possible once a model has been loaded
        self.test_button.disabled = True

    @staticmethod
    def _format_db(value):
        """Format a metric as ``'X.XX dB'``.

        Returns ``'N/A'`` when the value is None, and ``str(value)`` when it
        cannot be converted to a float.
        """
        if value is None:
            return 'N/A'
        try:
            return f"{float(value):.2f} dB"
        except (TypeError, ValueError, RuntimeError):
            return str(value)

    @staticmethod
    def _to_namespace(d):
        """Recursively convert nested dicts to ``SimpleNamespace``.

        The result supports attribute access (``cfg.model.model_type``).
        Non-dict values are returned unchanged.
        """
        from types import SimpleNamespace

        if isinstance(d, dict):
            return SimpleNamespace(**{k: ModelTester._to_namespace(v) for k, v in d.items()})
        return d

    @staticmethod
    def _print_header(task, variant, sample_id):
        """Print the banner that precedes the metric table."""
        print("\n" + "=" * 60)
        print(f"Task: {task} | Variant: {variant} | Sample: {sample_id}")
        print("=" * 60)

    def _compute_pit_siSDR(self, estimate, clean, si_sdr_metric):
        """Compute permutation-invariant SI-SDR for two-speaker separation.

        Both speaker assignments are scored and the one with the higher
        average is kept.

        Args:
            estimate: [1, 2, T] - estimated speaker signals
            clean: [2, T] - ground truth speaker signals
            si_sdr_metric: callable(preds, target) returning an SI-SDR value

        Returns:
            best_sisdr_avg: Average SI-SDR for the best permutation
            best_sisdr1: SI-SDR of estimate channel 0 in the best permutation
            best_sisdr2: SI-SDR of estimate channel 1 in the best permutation
            best_perm: 0 for the identity assignment, 1 for swapped
        """
        # Permutation 1: est[0]->clean[0], est[1]->clean[1]
        sisdr_perm1_spk1 = si_sdr_metric(estimate[:, 0:1, :], clean[0:1].unsqueeze(0))
        sisdr_perm1_spk2 = si_sdr_metric(estimate[:, 1:2, :], clean[1:2].unsqueeze(0))
        sisdr_perm1_avg = (sisdr_perm1_spk1 + sisdr_perm1_spk2) / 2

        # Permutation 2: est[0]->clean[1], est[1]->clean[0]
        sisdr_perm2_spk1 = si_sdr_metric(estimate[:, 0:1, :], clean[1:2].unsqueeze(0))
        sisdr_perm2_spk2 = si_sdr_metric(estimate[:, 1:2, :], clean[0:1].unsqueeze(0))
        sisdr_perm2_avg = (sisdr_perm2_spk1 + sisdr_perm2_spk2) / 2

        # Choose best permutation (ties keep the identity assignment)
        if sisdr_perm1_avg >= sisdr_perm2_avg:
            return sisdr_perm1_avg, sisdr_perm1_spk1, sisdr_perm1_spk2, 0
        return sisdr_perm2_avg, sisdr_perm2_spk1, sisdr_perm2_spk2, 1

    def _get_dataset(self, task, variant):
        """Return the PolSESS test dataset for ``task``/``variant``.

        The dataset is built with ``get_dataset('polsess')`` and cached on
        ``self.dataset``. It is rebuilt only when the data root, task or
        variant changes.
        """
        key = (str(self.data_root), task, variant)
        if self.dataset is None or self._dataset_key != key:
            DatasetClass = get_dataset('polsess')
            self.dataset = DatasetClass(
                data_root=str(self.data_root),
                subset='test',
                task=task,
                allowed_variants=[variant]
            )
            self._dataset_key = key
        return self.dataset

    def _on_load_clicked(self, b):
        """Load the selected checkpoint into ``self.model`` and enable testing.

        Also reads checkpoint metadata (epoch, val SI-SDR) and the training
        config. A ``config.yaml`` next to the checkpoint takes precedence
        over the ``'config'`` entry stored in the checkpoint.

        Args:
            b: The clicked button (unused; required by ipywidgets).
        """
        with self.output:
            self.output.clear_output()
            print("Loading model...")

            # Forget any previous model so a failed load cannot leave a stale one behind.
            self.model = None
            self.config = None
            self.test_button.disabled = True

            try:
                checkpoint_path = self.checkpoints[self.checkpoint_dropdown.value]

                self.model = load_model_from_checkpoint(
                    checkpoint_path,
                    config=None,  # Loads config from checkpoint
                    device=self.device
                )
                # Make sure dropout/normalisation layers run in inference mode.
                if isinstance(self.model, torch.nn.Module):
                    self.model.eval()

                # Re-load only to read metadata. Map to CPU so the weights
                # are not duplicated in GPU memory.
                checkpoint = torch.load(checkpoint_path, map_location='cpu')
                if not isinstance(checkpoint, dict):
                    checkpoint = {}

                config_dict = checkpoint.get('config', {})
                config_yaml_path = Path(checkpoint_path).parent / 'config.yaml'
                if config_yaml_path.exists():
                    with open(config_yaml_path, 'r', encoding='utf-8') as f:
                        config_dict = yaml.safe_load(f)
                self.config = self._to_namespace(config_dict)

                # The config may lack a "model" section; don't fail the load over a display label.
                model_type = getattr(getattr(self.config, 'model', None), 'model_type', 'unknown')

                print(f"✓ Model loaded: {model_type}")
                print(f"  Checkpoint: {checkpoint_path}")
                print(f"  Epoch: {checkpoint.get('epoch', 'N/A')}")
                print(f"  Val SI-SDR: {self._format_db(checkpoint.get('val_sisdr'))}")
                print(f"  Device: {self.device}")

                self.test_button.disabled = False

            except Exception as e:
                print(f"✗ Error loading model: {e}")
                import traceback
                traceback.print_exc()
                self.test_button.disabled = True

    def _on_test_clicked(self, b):
        """Run the loaded model on the selected sample and report the results.

        For the 'SB' (separate both speakers) task, permutation-invariant
        SI-SDR is reported. For the other tasks, the SI-SDR of the mix and of
        the estimate against the target is reported, plus the improvement.
        Waveforms are plotted when the checkbox is ticked, and audio players
        are always rendered.

        Args:
            b: The clicked button (unused; required by ipywidgets).
        """
        with self.output:
            self.output.clear_output(wait=True)
            print("Testing model...")

            if self.model is None:
                print("✗ Error: No model loaded. Click 'Load Model' first.")
                return

            try:
                task = self.task_dropdown.value
                variant = self.variant_dropdown.value
                sample_id = self.sample_id_text.value

                dataset = self._get_dataset(task, variant)

                # Reject negative IDs as well: they would silently index from the end.
                if not 0 <= sample_id < len(dataset):
                    print(f"✗ Error: Sample ID {sample_id} out of range (max: {len(dataset)-1})")
                    return

                sample = dataset[sample_id]
                mix = sample['mix'].unsqueeze(0).to(self.device)  # [1, T]
                clean = sample['clean']  # [C, T] where C is num speakers (1 or 2)

                # Get the mix file path from dataset metadata. The dataset may
                # store data_root as a str, so wrap it in Path before joining.
                mix_file_name = dataset.metadata.iloc[sample_id]['mixFile']
                mix_file_path = Path(dataset.data_root) / 'test' / 'mix' / mix_file_name

                print(f"\nMix file: {mix_file_name}")
                print(f"Full path: {mix_file_path}")

                is_separation = task == 'SB'  # Separate Both speakers

                with torch.no_grad():
                    # Model input is [1, 1, T]
                    estimate = self.model(mix.unsqueeze(1)).cpu()

                si_sdr_metric = ScaleInvariantSignalDistortionRatio()

                if is_separation:
                    if estimate.dim() != 3:  # expected [1, 2, T]
                        raise ValueError(f"Expected 3D output [1, 2, T] for separation, got {estimate.shape}")

                    sisdr_avg, sisdr_spk1, sisdr_spk2, best_perm = self._compute_pit_siSDR(
                        estimate, clean, si_sdr_metric
                    )

                    self._print_header(task, variant, sample_id)
                    print(f"SI-SDR (Speaker 1):     {sisdr_spk1.item():>8.2f} dB")
                    print(f"SI-SDR (Speaker 2):     {sisdr_spk2.item():>8.2f} dB")
                    print(f"SI-SDR (Average):       {sisdr_avg.item():>8.2f} dB")
                    if best_perm == 1:
                        print("Note: Best permutation swapped speaker order")
                    print("=" * 60 + "\n")
                else:
                    # Enhancement: compare against a single [T] target
                    if clean.dim() > 1:
                        clean = clean[0]  # Take first speaker if multi-channel
                    # Some models return [1, 1, T]; drop the singleton channel
                    # so the estimate matches the [1, T] target.
                    if estimate.dim() == 3 and estimate.shape[1] == 1:
                        estimate = estimate.squeeze(1)

                    si_sdr_mix = si_sdr_metric(mix.cpu(), clean.unsqueeze(0))
                    si_sdr_estimate = si_sdr_metric(estimate, clean.unsqueeze(0))
                    improvement = si_sdr_estimate - si_sdr_mix

                    self._print_header(task, variant, sample_id)
                    print(f"SI-SDR (Mix):      {si_sdr_mix.item():>8.2f} dB")
                    print(f"SI-SDR (Estimate): {si_sdr_estimate.item():>8.2f} dB")
                    print(f"Improvement:       {improvement.item():>8.2f} dB")
                    print("=" * 60 + "\n")

                mix_1d = mix.squeeze(0).cpu()

                if self.show_plots_checkbox.value:
                    if is_separation:
                        self._plot_separation_waveforms(mix_1d, clean, estimate.squeeze(0))  # [2, T]
                    else:
                        self._plot_waveforms(mix_1d, clean, estimate.squeeze(0))

                self._play_audio(mix_1d, clean, estimate, is_separation)

            except Exception as e:
                print(f"✗ Error during testing: {e}")
                import traceback
                traceback.print_exc()

    def _play_audio(self, mix, clean, estimate, is_separation):
        """Render IPython audio players for the mix, the target(s) and the estimate(s).

        Args:
            mix: [T] mixture signal (CPU tensor).
            clean: [2, T] for separation, [T] for enhancement.
            estimate: [1, 2, T] for separation, otherwise the model output
                with a leading batch dimension of 1.
            is_separation: Whether the task is two-speaker separation.
        """
        rate = self.sample_rate
        print("\nAudio Playback:")
        print("Mix:")
        display(Audio(mix.numpy(), rate=rate))

        if is_separation:
            print("\nClean Speaker 1:")
            display(Audio(clean[0].numpy(), rate=rate))
            print("\nClean Speaker 2:")
            display(Audio(clean[1].numpy(), rate=rate))
            print("\nEstimated Speaker 1:")
            display(Audio(estimate[0, 0].numpy(), rate=rate))
            print("\nEstimated Speaker 2:")
            display(Audio(estimate[0, 1].numpy(), rate=rate))
        else:
            print("\nClean (Target):")
            display(Audio(clean.numpy(), rate=rate))
            print("\nEstimate (Output):")
            display(Audio(estimate.squeeze(0).numpy(), rate=rate))

    def _plot_waveforms(self, mix, clean, estimate):
        """Plot mix, target and estimate waveforms (enhancement tasks)."""
        fig, axes = plt.subplots(3, 1, figsize=(12, 6), sharex=True)

        time = np.arange(len(mix)) / self.sample_rate

        axes[0].plot(time, mix.numpy(), linewidth=0.5)
        axes[0].set_ylabel('Mix')
        axes[0].set_title('Waveform Comparison')
        axes[0].grid(True, alpha=0.3)

        axes[1].plot(time, clean.numpy(), linewidth=0.5, color='green')
        axes[1].set_ylabel('Clean (Target)')
        axes[1].grid(True, alpha=0.3)

        axes[2].plot(time, estimate.numpy(), linewidth=0.5, color='orange')
        axes[2].set_ylabel('Estimate')
        axes[2].set_xlabel('Time (s)')
        axes[2].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    def _plot_separation_waveforms(self, mix, clean, estimate):
        """Plot waveforms for speaker separation tasks.

        Shows the mix, both clean speakers and both estimated speakers.
        """
        fig, axes = plt.subplots(5, 1, figsize=(12, 10), sharex=True)

        time = np.arange(len(mix)) / self.sample_rate

        axes[0].plot(time, mix.numpy(), linewidth=0.5)
        axes[0].set_ylabel('Mix')
        axes[0].set_title('Speaker Separation Results')
        axes[0].grid(True, alpha=0.3)

        axes[1].plot(time, clean[0].numpy(), linewidth=0.5, color='green')
        axes[1].set_ylabel('Clean Spk 1')
        axes[1].grid(True, alpha=0.3)

        axes[2].plot(time, clean[1].numpy(), linewidth=0.5, color='darkgreen')
        axes[2].set_ylabel('Clean Spk 2')
        axes[2].grid(True, alpha=0.3)

        axes[3].plot(time, estimate[0].numpy(), linewidth=0.5, color='orange')
        axes[3].set_ylabel('Est. Spk 1')
        axes[3].grid(True, alpha=0.3)

        axes[4].plot(time, estimate[1].numpy(), linewidth=0.5, color='darkorange')
        axes[4].set_ylabel('Est. Spk 2')
        axes[4].set_xlabel('Time (s)')
        axes[4].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    def display(self):
        """Display the interactive UI (selectors, buttons and output area)."""
        ui = widgets.VBox([
            widgets.HTML("<h3>Model Testing Interface</h3>"),
            self.checkpoint_dropdown,
            widgets.HBox([self.task_dropdown, self.variant_dropdown]),
            self.sample_id_text,
            self.show_plots_checkbox,
            widgets.HBox([self.load_button, self.test_button]),
            widgets.HTML("<hr>"),
            self.output
        ])

        # Module-level IPython display(), not this method
        display(ui)

# Build the interactive tester only when there is something to load
if not available_checkpoints:
    print("No checkpoints available. Please train a model first.")
else:
    tester = ModelTester(available_checkpoints, DATA_ROOT, SAMPLE_RATE)
    tester.display()