# Interactive Model Testing

This notebook provides an interactive interface for testing trained speech separation models with dropdown widgets.

## Features

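- Auto-discovers available checkpoints in the hierarchical `checkpoints/{model}/{task}/{run_id}/` layout (the legacy `{model}/{model}/{task}/{run_id}/` layout is also supported)
- Dropdown selectors for checkpoint, task, and variant, plus a numeric sample ID field
- Recognizes `latest` symlinks (listed as `[LATEST]`) for quick access to the most recent run
- One-click model loading and testing (separate **Load Model** and **Test Model** buttons)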
- Audio playback and waveform visualization

In [5]:
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from IPython.display import Audio, display
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual
import yaml
import sys

# Add project root to path
sys.path.insert(0, str(Path.cwd()))

from config import Config
from models import get_model
from datasets import get_dataset
from torchmetrics.audio import ScaleInvariantSignalDistortionRatio

## 1. Checkpoint Discovery

Scan the checkpoint directory structure to find all available models.

In [6]:
def discover_checkpoints(checkpoint_root='checkpoints'):
    """Discover all available checkpoints in a hierarchical directory layout.

    Any directory that contains a ``best_model.pt`` file is treated as a run
    directory. Two layouts are recognized (paths relative to ``checkpoint_root``):

    - Standard: ``{model}/{task}/{run_id}/best_model.pt``
    - Legacy/misconfigured: ``{model}/{model}/{task}/{run_id}/best_model.pt``

    Run directories fewer than three levels below the root are ignored.
    Directories are visited in sorted order, so the result is deterministic.
    A run directory named ``latest`` (e.g. a symlink to the newest run) is
    shown as ``[LATEST]``. Symlink cycles are detected and not followed.

    Args:
        checkpoint_root: Root directory to scan (``str`` or ``Path``).

    Returns:
        dict mapping a display name ``"{model}/{task}/{run}"`` to the checkpoint
        file path as a string. Every level below the task directory is kept in
        ``{run}``, so nested runs stay distinct. If two runs would still share a
        display name (e.g. the same run under both layouts), the one found later
        gets its relative path appended in parentheses instead of overwriting
        the first.
    """
    checkpoint_root = Path(checkpoint_root)
    checkpoints = {}

    if not checkpoint_root.is_dir():
        print(f"Warning: Checkpoint directory {checkpoint_root} does not exist")
        return checkpoints

    def register(item, checkpoint_file):
        """Add the run directory ``item`` to ``checkpoints`` under a unique display name."""
        path_parts = item.relative_to(checkpoint_root).parts
        if len(path_parts) < 3:
            return

        model_type = path_parts[0]
        if path_parts[0] == path_parts[1]:
            # Legacy structure: model/model/task[/run_id...]; with no run level
            # the task directory itself doubles as the run id.
            task = path_parts[2]
            run_parts = list(path_parts[3:] or path_parts[2:3])
        else:
            # Standard structure: model/task/run_id[/...]
            task = path_parts[1]
            run_parts = list(path_parts[2:])

        if run_parts[-1] == 'latest':
            run_parts[-1] = '[LATEST]'
        display_name = f"{model_type}/{task}/{'/'.join(run_parts)}"

        # Never silently drop a run whose display name is already taken.
        if display_name in checkpoints:
            display_name = f"{display_name} ({'/'.join(path_parts)})"
        checkpoints[display_name] = str(checkpoint_file)

    def scan_for_checkpoints(base_path, ancestors=frozenset()):
        """Recursively walk ``base_path`` and register every run directory found.

        ``ancestors`` holds the resolved paths of the directories currently
        being walked, so a symlink pointing back up the tree is not followed.
        """
        resolved = base_path.resolve()
        if resolved in ancestors:
            return
        ancestors = ancestors | {resolved}

        for item in sorted(base_path.iterdir()):
            if not item.is_dir():
                continue
            checkpoint_file = item / 'best_model.pt'
            if checkpoint_file.exists():
                register(item, checkpoint_file)
            else:
                scan_for_checkpoints(item, ancestors)

    scan_for_checkpoints(checkpoint_root)
    return checkpoints

# Scan the checkpoint tree once and show what is available (alphabetical order).
available_checkpoints = discover_checkpoints()

if not available_checkpoints:
    print("No checkpoints found. Train a model first!")
else:
    print(f"Found {len(available_checkpoints)} checkpoint(s):")
    for checkpoint_name in sorted(available_checkpoints):
        print(f"  - {checkpoint_name}")

Found 44 checkpoint(s):
  - dprnn/EB/[LATEST]
  - dprnn/EB/run_2025-12-13_01-08-35
  - dprnn/EB/run_2025-12-13_01-09-50
  - dprnn/EB/run_2025-12-13_01-11-09
  - dprnn/EB/run_2025-12-13_01-12-25
  - dprnn/ES/[LATEST]
  - dprnn/ES/run_2025-12-13_01-01-13
  - dprnn/ES/run_2025-12-13_01-02-30
  - dprnn/ES/run_2025-12-13_01-03-45
  - dprnn/ES/run_2025-12-13_01-06-16
  - dprnn/SB/[LATEST]
  - dprnn/SB/run_2025-12-12_22-47-32
  - dprnn/SB/run_2025-12-12_22-50-04
  - dprnn/SB/run_2025-12-12_22-52-34
  - dprnn/SB/run_2025-12-12_22-57-38
  - dprnn/SB/run_2025-12-12_23-00-09
  - dprnn/SB/run_2025-12-12_23-02-42
  - dprnn/SB/run_2025-12-12_23-05-13
  - dprnn/SB/run_2025-12-12_23-07-50
  - dprnn/SB/run_2025-12-12_23-10-30
  - dprnn/SB/run_2025-12-12_23-13-08
  - dprnn/SB/run_2025-12-12_23-15-48
  - dprnn/SB/run_2025-12-12_23-18-28
  - dprnn/SB/run_2025-12-12_23-21-03
  - dprnn/SB/run_2025-12-12_23-26-16
  - dprnn/SB/run_2025-12-12_23-28-52
  - dprnn/SB/run_2025-12-12_23-34-14
  - dprnn/SB/run_2025-

## 2. Configuration

Set up dataset paths and parameters.

In [7]:
# --- Dataset configuration -------------------------------------------------
# The dataset root is read from config.py's DataConfig defaults when possible.
# If that fails for any reason (module missing, attribute missing, ...), the
# hard-coded fallback in the except branch is used — edit it for your machine.
try:
    from config import DataConfig
    config_defaults = DataConfig()
    DATA_ROOT = Path(config_defaults.polsess.data_root)
    print(f"Using dataset path from config.py: {DATA_ROOT}")
except Exception as e:
    DATA_ROOT = Path("F:/PolSMSE/EksperymentyMOWA/BAZY/MOWA/PolSESS_C_both/PolSESS_C_both")
    print(f"Using manual dataset path: {DATA_ROOT}")
    print(f"(Could not load from config: {e})")

SAMPLE_RATE = 8000  # Hz

# PolSESS acoustic variants: code -> human-readable description
VARIANTS = dict([
    ('C', 'Clean (no background)'),
    ('S', 'Speech only'),
    ('E', 'Event only'),
    ('R', 'Reverb only'),
    ('SE', 'Speech + Event'),
    ('SR', 'Speech + Reverb'),
    ('ER', 'Event + Reverb'),
    ('SER', 'Speech + Event + Reverb'),
])

# Task codes: code -> human-readable description
TASKS = dict([
    ('ES', 'Enhance Single speaker'),
    ('EB', 'Enhance Both speakers'),
    ('SB', 'Separate Both speakers'),
])

print("Configuration loaded.")
print(f"Sample rate: {SAMPLE_RATE} Hz")

# Sanity-check that the dataset directory is really there.
if DATA_ROOT.exists():
    print(f"✓ Dataset path verified: {DATA_ROOT}")
else:
    print(f"\n⚠️  WARNING: Dataset path does not exist: {DATA_ROOT}")
    print("Please update DATA_ROOT in this cell to point to your PolSESS dataset")

Using dataset path from config.py: /home/user/datasets/PolSESS_C_both/PolSESS_C_both
Configuration loaded.
Sample rate: 8000 Hz
✓ Dataset path verified: /home/user/datasets/PolSESS_C_both/PolSESS_C_both


## 3. Interactive Model Testing

Use the dropdowns to pick a checkpoint, task, and variant, and enter a sample ID. Click **Load Model** to load the checkpoint, then **Test Model** to run it on the selected sample.

In [None]:
class ModelTester:
    """Interactive model testing with dropdown widgets.

    Builds an ipywidgets UI for choosing a checkpoint, task, variant and sample
    ID. It loads the checkpoint into the model class returned by ``get_model``
    and runs it on one sample of the PolSESS ``test`` subset. It then reports
    SI-SDR, waveform plots and audio players inside an ``Output`` widget.

    Args:
        checkpoints: Mapping ``{display_name: checkpoint_path}``, e.g. the
            result of ``discover_checkpoints``.
        data_root: Root directory of the PolSESS dataset.
        sample_rate: Sample rate in Hz, used for plot time axes and audio
            playback. The default (16000) differs from this notebook's
            ``SAMPLE_RATE`` (8000), so pass the rate explicitly.
    """

    def __init__(self, checkpoints, data_root, sample_rate=16000):
        self.checkpoints = checkpoints
        self.data_root = Path(data_root)
        self.sample_rate = sample_rate
        self.model = None
        self.config = None
        self.dataset = None
        self._dataset_key = None  # (task, variant) that self.dataset was built for
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Create widgets
        self.checkpoint_dropdown = widgets.Dropdown(
            options=sorted(checkpoints.keys()),
            description='Checkpoint:',
            style={'description_width': '120px'},
            layout=widgets.Layout(width='600px')
        )

        self.task_dropdown = widgets.Dropdown(
            options=[(f"{k}: {v}", k) for k, v in TASKS.items()],
            description='Task:',
            style={'description_width': '120px'},
            layout=widgets.Layout(width='400px')
        )

        self.variant_dropdown = widgets.Dropdown(
            options=[(f"{k}: {v}", k) for k, v in VARIANTS.items()],
            description='Variant:',
            style={'description_width': '120px'},
            layout=widgets.Layout(width='400px')
        )

        self.sample_id_text = widgets.IntText(
            value=0,
            description='Sample ID:',
            style={'description_width': '120px'},
            layout=widgets.Layout(width='200px')
        )

        self.load_button = widgets.Button(
            description='Load Model',
            button_style='info',
            icon='download',
            layout=widgets.Layout(width='150px')
        )

        self.test_button = widgets.Button(
            description='Test Model',
            button_style='success',
            icon='play',
            layout=widgets.Layout(width='150px')
        )

        self.output = widgets.Output()

        # Button callbacks
        self.load_button.on_click(self._on_load_clicked)
        self.test_button.on_click(self._on_test_clicked)

        # Testing is only possible once a model has been loaded
        self.test_button.disabled = True

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _dict_to_namespace(d):
        """Recursively convert nested dicts into ``SimpleNamespace`` objects; other values pass through."""
        from types import SimpleNamespace
        if isinstance(d, dict):
            return SimpleNamespace(**{k: ModelTester._dict_to_namespace(v) for k, v in d.items()})
        return d

    @staticmethod
    def _format_db(value):
        """Format an optional scalar (number or 1-element tensor) as ``'X.XX dB'``.

        Returns ``'N/A'`` for ``None``, and ``str(value)`` if the value cannot
        be converted to ``float``.
        """
        if value is None:
            return 'N/A'
        try:
            return f"{float(value):.2f} dB"
        except (TypeError, ValueError):
            return str(value)

    @staticmethod
    def _align_lengths(*signals):
        """Trim all tensors to the shortest length along the last (time) axis.

        A model may return a few samples more or fewer than its input. SI-SDR
        and the time-axis plots both require equal lengths.
        """
        n_samples = min(s.shape[-1] for s in signals)
        return tuple(s[..., :n_samples] for s in signals)

    @staticmethod
    def _best_permutation_sisdr(estimates, references, score_fn):
        """Find the estimate-to-reference assignment with the highest total score.

        Separation models emit sources in arbitrary order, so scoring
        ``estimates[i]`` against ``references[i]`` directly can misreport
        quality. This performs permutation-invariant evaluation.

        Args:
            estimates: ``[C, T]`` tensor of separated sources.
            references: ``[C, T]`` tensor of reference sources.
            score_fn: Callable ``(pred, target) -> scalar`` for 1-D signals.

        Returns:
            ``(perm, scores)``: ``estimates[perm[i]]`` is matched with
            ``references[i]``, and ``scores[i]`` is that pair's score. Ties keep
            the identity order.
        """
        from itertools import permutations

        n_src = references.shape[0]
        # pair_scores[i][j] = score of estimate j against reference i (computed once)
        pair_scores = [[score_fn(estimates[j], references[i]) for j in range(n_src)]
                       for i in range(n_src)]
        best_perm, best_total = None, float('-inf')
        for perm in permutations(range(n_src)):
            total = sum(float(pair_scores[i][perm[i]]) for i in range(n_src))
            if best_perm is None or total > best_total:
                best_perm, best_total = perm, total
        return best_perm, [pair_scores[i][best_perm[i]] for i in range(n_src)]

    @staticmethod
    def _print_metrics(header, rows):
        """Print a boxed table of ``(label, value_in_dB)`` rows under ``header``."""
        width = max(len(label) for label, _ in rows) + 2
        print("\n" + "=" * 60)
        print(header)
        print("=" * 60)
        for label, value in rows:
            print(f"{label + ':':<{width}} {float(value):>8.2f} dB")
        print("=" * 60 + "\n")

    def _get_dataset(self, task, variant):
        """Return the PolSESS ``test`` subset for ``(task, variant)``.

        The last dataset built is reused while task and variant stay the same,
        so browsing samples does not rebuild it on every click.
        """
        key = (task, variant)
        if self.dataset is None or self._dataset_key != key:
            DatasetClass = get_dataset('polsess')
            self.dataset = DatasetClass(
                data_root=str(self.data_root),
                subset='test',
                task=task,
                allowed_variants=[variant]
            )
            self._dataset_key = key
        return self.dataset

    # ---------------------------------------------------------------- callbacks

    def _on_load_clicked(self, b):
        """Load the selected checkpoint, build its model and enable testing.

        The model configuration comes from the checkpoint's ``'config'`` entry.
        A sibling ``config.yaml``, when present, replaces it. Any error is
        reported in the output area and leaves the test button disabled.
        """
        with self.output:
            self.output.clear_output()
            print("Loading model...")

            try:
                checkpoint_path = self.checkpoints[self.checkpoint_dropdown.value]
                # torch.load unpickles the file: only load checkpoints you trust.
                checkpoint = torch.load(checkpoint_path, map_location=self.device)

                config_dict = checkpoint.get('config', {})
                config_yaml_path = Path(checkpoint_path).parent / 'config.yaml'
                if config_yaml_path.exists():
                    with open(config_yaml_path, 'r') as f:
                        config_dict = yaml.safe_load(f)
                self.config = self._dict_to_namespace(config_dict)

                # Model class plus its specific kwargs from config.model.<model_type>
                model_type = self.config.model.model_type
                ModelClass = get_model(model_type)
                model_params = getattr(self.config.model, model_type, {})
                model_kwargs = vars(model_params) if hasattr(model_params, '__dict__') else {}

                self.model = ModelClass(**model_kwargs)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.model.to(self.device)
                self.model.eval()

                print(f"✓ Model loaded: {model_type}")
                print(f"  Checkpoint: {checkpoint_path}")
                print(f"  Epoch: {checkpoint.get('epoch', 'N/A')}")
                # val_sisdr may be missing; formatting 'N/A' with ':.2f' would raise.
                print(f"  Val SI-SDR: {self._format_db(checkpoint.get('val_sisdr'))}")
                print(f"  Device: {self.device}")

                self.test_button.disabled = False

            except Exception as e:
                print(f"✗ Error loading model: {e}")
                import traceback
                traceback.print_exc()
                self.test_button.disabled = True

    def _on_test_clicked(self, b):
        """Run the loaded model on the selected sample and report metrics, plots and audio."""
        with self.output:
            self.output.clear_output(wait=True)
            print("Testing model...")

            try:
                task = self.task_dropdown.value
                variant = self.variant_dropdown.value
                sample_id = self.sample_id_text.value

                dataset = self._get_dataset(task, variant)
                if not 0 <= sample_id < len(dataset):
                    print(f"✗ Error: Sample ID {sample_id} out of range (max: {len(dataset)-1})")
                    return

                sample = dataset[sample_id]
                mix = sample['mix'].unsqueeze(0).to(self.device)  # [1, T]
                # NOTE(review): 'clean' assumed [C, T] (C = speakers) or [T] — confirm with the dataset.
                clean = sample['clean'].cpu()

                with torch.no_grad():
                    estimate = self.model(mix.unsqueeze(1)).cpu()  # model input: [1, 1, T]

                header = f"Task: {task} | Variant: {variant} | Sample: {sample_id}"
                if task == 'SB':  # Separate Both speakers
                    self._report_separation(header, mix.cpu(), clean, estimate)
                else:
                    self._report_enhancement(header, mix.cpu(), clean, estimate)

            except Exception as e:
                print(f"✗ Error during testing: {e}")
                import traceback
                traceback.print_exc()

    # ---------------------------------------------------------------- reporting

    def _report_separation(self, header, mix, clean, estimate):
        """Score, plot and play back a speaker-separation result.

        Scoring is permutation-invariant. SI-SDR improvement is measured per
        speaker against the unprocessed mixture,
        ``SI-SDR(est_i, s_i) - SI-SDR(mix, s_i)``, and then averaged.

        Args:
            header: Title line for the metrics table.
            mix: Mixture on CPU, ``[1, T]``.
            clean: Reference speakers, ``[C, T]``.
            estimate: Model output on CPU, expected ``[1, C, T]``.

        Raises:
            ValueError: If ``estimate`` is not 3-D, or its source count does not
                match ``clean``.
        """
        if estimate.dim() != 3:
            raise ValueError(f"Expected 3D output [1, 2, T] for separation, got {estimate.shape}")
        if clean.dim() != 2 or estimate.shape[1] != clean.shape[0]:
            raise ValueError(
                f"Model produced {estimate.shape[1]} source(s) but the reference has shape {tuple(clean.shape)}"
            )

        mix_1d, est, clean = self._align_lengths(mix.squeeze(0), estimate.squeeze(0), clean)

        si_sdr_metric = ScaleInvariantSignalDistortionRatio()

        def score(pred, target):
            return si_sdr_metric(pred.unsqueeze(0), target.unsqueeze(0))

        perm, est_scores = self._best_permutation_sisdr(est, clean, score)
        est = est[list(perm)]  # reorder so est[i] is matched with clean[i]
        mix_scores = [score(mix_1d, ref) for ref in clean]

        n_spk = len(est_scores)
        avg_est = sum(float(s) for s in est_scores) / n_spk
        avg_mix = sum(float(s) for s in mix_scores) / n_spk

        rows = [("SI-SDR (Mix, avg)", avg_mix)]
        rows += [(f"SI-SDR (Speaker {i})", s) for i, s in enumerate(est_scores, 1)]
        rows += [("SI-SDR (Average)", avg_est), ("Improvement (Avg)", avg_est - avg_mix)]
        self._print_metrics(header, rows)
        if perm != tuple(range(n_spk)):
            print(f"(Model outputs reordered to match references: permutation {perm})\n")

        self._plot_separation_waveforms(mix_1d, clean, est)

        print("\nAudio Playback:")
        print("Mix:")
        display(Audio(mix_1d.numpy(), rate=self.sample_rate))
        for i in range(n_spk):
            print(f"\nClean Speaker {i + 1}:")
            display(Audio(clean[i].numpy(), rate=self.sample_rate))
        for i in range(n_spk):
            print(f"\nEstimated Speaker {i + 1}:")
            display(Audio(est[i].numpy(), rate=self.sample_rate))

    def _report_enhancement(self, header, mix, clean, estimate):
        """Score, plot and play back a single-output enhancement result.

        Args:
            header: Title line for the metrics table.
            mix: Mixture on CPU, ``[1, T]``.
            clean: Reference; when it has more than one dimension only
                ``clean[0]`` is used.
            estimate: Model output on CPU, ``[1, T]`` or ``[1, 1, T]``.

        Raises:
            ValueError: If the model output is not a single waveform.
        """
        if clean.dim() > 1:
            # NOTE(review): for 'EB' this drops every channel but the first — confirm intended.
            clean = clean[0]
        if estimate.dim() == 3 and estimate.shape[1] == 1:
            estimate = estimate.squeeze(1)  # [1, 1, T] -> [1, T]
        est = estimate.squeeze(0)
        if est.dim() != 1:
            raise ValueError(f"Expected [1, T] output for enhancement, got {tuple(estimate.shape)}")

        mix_1d, clean, est = self._align_lengths(mix.squeeze(0), clean, est)

        si_sdr_metric = ScaleInvariantSignalDistortionRatio()
        si_sdr_mix = si_sdr_metric(mix_1d.unsqueeze(0), clean.unsqueeze(0))
        si_sdr_estimate = si_sdr_metric(est.unsqueeze(0), clean.unsqueeze(0))
        self._print_metrics(header, [
            ("SI-SDR (Mix)", si_sdr_mix),
            ("SI-SDR (Estimate)", si_sdr_estimate),
            ("Improvement", si_sdr_estimate - si_sdr_mix),
        ])

        self._plot_waveforms(mix_1d, clean, est)

        print("\nAudio Playback:")
        print("Mix:")
        display(Audio(mix_1d.numpy(), rate=self.sample_rate))
        print("\nClean (Target):")
        display(Audio(clean.numpy(), rate=self.sample_rate))
        print("\nEstimate (Output):")
        display(Audio(est.numpy(), rate=self.sample_rate))

    def _plot_waveforms(self, mix, clean, estimate):
        """Plot mixture, target and estimate (1-D tensors of equal length) for enhancement tasks."""
        fig, axes = plt.subplots(3, 1, figsize=(12, 6), sharex=True)

        time = np.arange(len(mix)) / self.sample_rate

        axes[0].plot(time, mix.numpy(), linewidth=0.5)
        axes[0].set_ylabel('Mix')
        axes[0].set_title('Waveform Comparison')
        axes[0].grid(True, alpha=0.3)

        axes[1].plot(time, clean.numpy(), linewidth=0.5, color='green')
        axes[1].set_ylabel('Clean (Target)')
        axes[1].grid(True, alpha=0.3)

        axes[2].plot(time, estimate.numpy(), linewidth=0.5, color='orange')
        axes[2].set_ylabel('Estimate')
        axes[2].set_xlabel('Time (s)')
        axes[2].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    def _plot_separation_waveforms(self, mix, clean, estimate):
        """Plot the mixture, each reference speaker and each estimated speaker.

        Args:
            mix: ``[T]`` mixture.
            clean: ``[C, T]`` reference speakers.
            estimate: ``[C, T]`` estimated speakers, in the same order as ``clean``.
        """
        n_spk = clean.shape[0]
        n_rows = 1 + 2 * n_spk
        # squeeze=False keeps a 2-D array of axes even when there is a single row
        fig, axes = plt.subplots(n_rows, 1, figsize=(12, 2 * n_rows), sharex=True, squeeze=False)
        axes = axes[:, 0]

        time = np.arange(len(mix)) / self.sample_rate
        clean_colors = ('green', 'darkgreen')
        est_colors = ('orange', 'darkorange')

        axes[0].plot(time, mix.numpy(), linewidth=0.5)
        axes[0].set_ylabel('Mix')
        axes[0].set_title('Speaker Separation Results')

        for i in range(n_spk):
            ax = axes[1 + i]
            ax.plot(time, clean[i].numpy(), linewidth=0.5, color=clean_colors[i % 2])
            ax.set_ylabel(f'Clean Spk {i + 1}')

        for i in range(n_spk):
            ax = axes[1 + n_spk + i]
            ax.plot(time, estimate[i].numpy(), linewidth=0.5, color=est_colors[i % 2])
            ax.set_ylabel(f'Est. Spk {i + 1}')

        for ax in axes:
            ax.grid(True, alpha=0.3)
        axes[-1].set_xlabel('Time (s)')

        plt.tight_layout()
        plt.show()

    def display(self):
        """Display the interactive UI."""
        ui = widgets.VBox([
            widgets.HTML("<h3>Model Testing Interface</h3>"),
            self.checkpoint_dropdown,
            widgets.HBox([self.task_dropdown, self.variant_dropdown]),
            self.sample_id_text,
            widgets.HBox([self.load_button, self.test_button]),
            widgets.HTML("<hr>"),
            self.output
        ])

        display(ui)

# Build the interactive UI, but only if discovery found something to load.
if not available_checkpoints:
    print("No checkpoints available. Please train a model first.")
else:
    tester = ModelTester(available_checkpoints, DATA_ROOT, SAMPLE_RATE)
    tester.display()

VBox(children=(HTML(value='<h3>Model Testing Interface</h3>'), Dropdown(description='Checkpoint:', layout=Layo…

## Instructions

1. **Select Checkpoint**: Choose a trained model from the dropdown (use `[LATEST]` for most recent)
2. **Click Load Model**: Loads the model and displays configuration info
3. **Configure Test**: Select task, variant, and sample ID
4. **Click Test Model**: Runs inference and displays:
   - SI-SDR metrics (mix, estimate, improvement)
   - Waveform comparison plots
   - Audio playback for mix, clean, and estimate

## Tips

- Select the `[LATEST]` entry (backed by the `latest` symlink) to always test the most recent run
- Start with sample_id=0 and explore different samples
- Compare different variants (C, S, E, R, SE, SR, ER, SER) to see how robust the model is
- An SI-SDR improvement above 0 dB means the model output is closer to the target than the unprocessed mixture