# Extraction of model metrics

In [3]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

## Load a model and data

In [4]:
import torch
from pathlib import Path
import yaml

from kvae.model.model import KVAE
from kvae.utils.config import KVAEConfig
from kvae.train.utils import parse_device, build_dataloaders
from kvae.train.train import evaluate
from kvae.train.imputation import impute_epoch


def load_checkpoint(checkpoint_path, device='auto'):
    """
    Restore a KVAE model from a checkpoint file on disk.

    Args:
        checkpoint_path: Location of the checkpoint file (.pt); anything
            accepted by ``pathlib.Path``.
        device: Device specifier forwarded to ``parse_device``
            ('auto', 'cuda', 'cpu', 'mps').

    Returns:
        model: KVAE model with the checkpoint weights loaded, in eval mode
        checkpoint: Full checkpoint dictionary as returned by ``torch.load``
        device: The device returned by ``parse_device``

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    ckpt_file = Path(checkpoint_path)
    if not ckpt_file.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_file}")

    target_device = parse_device(device)

    # Deserialize the checkpoint directly onto the target device
    state = torch.load(ckpt_file, map_location=target_device)

    # NOTE(review): the model is built from a default KVAEConfig(), not from
    # the configuration stored with the run -- confirm the two always match.
    net = KVAE(KVAEConfig()).to(target_device)
    net.load_state_dict(state['model_state'])
    net.eval()

    # Short summary of the training state recorded in the checkpoint
    print(f"✓ Loaded checkpoint from epoch {state['epoch']}")
    print(f"  Train loss: {state['train_loss']:.6f}")
    print(f"  Val loss: {state['val_loss']:.6f}")

    return net, state, target_device


# Example usage. Run directories (the last one is the one selected below):
#   runs/20251204-214705
#   runs/20251212-201237
#   runs/20251212-205425
runs_path = Path("../runs/20251212-205425")
checkpoint_path = runs_path.joinpath("checkpoints/kvae-best.pt")

# Restore the best checkpoint of the selected run on the CPU
model, ckpt, device = load_checkpoint(checkpoint_path, device='cpu')
print(f"\nModel loaded on: {device}")


✓ Loaded checkpoint from epoch 6
  Train loss: 7.899650
  Val loss: 7.172936

Model loaded on: cpu


In [5]:
# Read the YAML configuration saved in the run directory
config_path = runs_path / "config.yaml"

with config_path.open('r') as f:
    config = yaml.safe_load(f)

# Show every top-level section of the configuration
print("Configuration:")
for key in config:
    value = config[key]
    print(f"  {key}: {value}")

# Build the train/validation loaders from the dataset section,
# using the batch size recorded in the training section
train_loader, val_loader = build_dataloaders(
    config['dataset'],
    batch_size=config['training']['batch_size'],
)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Using device: {device}")


Configuration:
  dataset: {'kwargs': {'load_in_memory': True, 'normalize': False, 'seq_len': 20}, 'num_workers': 6, 'path': '/Users/rodrigopaganini/master/data/pgm/kvae/box.npz', 'type': 'pymunk', 'val_split': 0.2}
  training: {'add_imputation_plots': True, 'batch_size': 32, 'ckpt_every': 5, 'device': 'mps', 'gpus': 1, 'logdir': 'runs', 'lr': 0.007, 'max_epochs': 100, 'pretrain_vae_epochs': 0, 'seed': 10, 'warmup_epochs': 5}
  transforms: {'add_noise_std': 0.0}
Train batches: 125
Val batches: 32
Using device: cpu


## Define metrics and evaluate

In [None]:
def fraction_of_incorrect_pixels(y_true, y_pred, threshold=0.5):
    """
    Compute the fraction of pixels on which two binary images disagree.

    Both inputs are binarized with ``value > threshold`` before they are
    compared, so a real-valued prediction (e.g. per-pixel probabilities) can
    be scored against a 0/1 ground truth. For inputs that already contain
    only 0s and 1s the result is the same as a direct element-wise comparison.

    Args:
        y_true: Ground truth binary image (numpy array or torch tensor)
        y_pred: Predicted image with the same shape as ``y_true``
            (numpy array or torch tensor)
        threshold: Cut-off used to binarize both inputs. Pass ``None`` to
            compare the raw values for exact equality instead.
    Returns:
        Fraction of incorrect pixels (float in [0, 1]); 0.0 for empty inputs.
    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    arrays = []
    for image in (y_true, y_pred):
        if isinstance(image, torch.Tensor):
            # detach() so tensors that still require grad can be converted
            image = image.detach().cpu().numpy()
        arrays.append(np.asarray(image))
    true_arr, pred_arr = arrays

    # Broadcasting between mismatched shapes would silently produce a
    # meaningless fraction, so reject it explicitly.
    if true_arr.shape != pred_arr.shape:
        raise ValueError(
            f"Shape mismatch: y_true {true_arr.shape} vs y_pred {pred_arr.shape}"
        )

    total_pixels = true_arr.size
    if total_pixels == 0:
        return 0.0

    if threshold is not None:
        true_arr = true_arr > threshold
        pred_arr = pred_arr > threshold

    incorrect_pixels = np.count_nonzero(true_arr != pred_arr)
    return float(incorrect_pixels / total_pixels)
# Registry mapping each metric's name to the callable that computes it
metrics_functions = dict(
    fraction_of_incorrect_pixels=fraction_of_incorrect_pixels,
)

def evaluate_metrics(model, loader, device, metrics_functions, mask_sampling_fn, output_key="x_recon"):
    """
    Average a set of metrics over every batch of a data loader.

    For each batch the model is run as ``model(x, mask=mask)`` under
    ``torch.no_grad()`` and every metric is called as
    ``metric_fn(x, outputs[output_key])``. Per-batch values are weighted by
    the batch size, so a smaller final batch does not count as much as a
    full one.

    Args:
        model: Model exposing ``eval()``,
            ``kalman_filter.dyn_params.reset_state()`` and a forward call
            ``model(x, mask=...)`` that returns a mapping of outputs.
        loader: Iterable of batches; each batch is a mapping with an
            ``"images"`` tensor whose first two dimensions are read as (B, T).
        device: Device the inputs and the mask are moved to.
        metrics_functions: Mapping ``name -> callable(x, prediction)``; each
            callable is expected to return a per-batch average.
        mask_sampling_fn: Callable ``(B, T) -> tensor`` producing the mask
            that is passed to the model.
        output_key: Key of the model output compared against the input.

    Returns:
        Dict mapping each metric name to its sample-weighted mean over the
        loader (0.0 for every metric when the loader is empty).
    """
    model.eval()
    totals = {name: 0.0 for name in metrics_functions}
    n_samples = 0

    # tqdm appends ": " to the description itself, so no trailing colon here
    for batch in tqdm(loader, desc="Evaluating"):
        # The dynamics-parameter state is reset before every batch
        model.kalman_filter.dyn_params.reset_state()

        x = batch["images"].float().to(device)
        B, T = x.shape[:2]

        # The mask is whatever the sampler returns; it is only "fully
        # observed" when the sampler yields all ones
        # (presumably 1 = observed frame -- TODO confirm against the model).
        mask = mask_sampling_fn(B, T).to(device)
        with torch.no_grad():
            outputs = model(x, mask=mask)
            prediction = outputs[output_key]
            for name, metric_fn in metrics_functions.items():
                # Weight the per-batch value by the number of sequences in it
                totals[name] += metric_fn(x, prediction) * B
            del outputs, prediction
        n_samples += B

    denom = max(n_samples, 1)
    return {name: total / denom for name, total in totals.items()}


def evaluate_impute_metrics(model, loader, device, metrics_functions, mask_sampling_fn, output_key="x_recon", verbose=False):
    """
    Average a set of metrics over a data loader using the model's imputation path.

    For each batch the model is run as ``model.impute(x, mask)`` under
    ``torch.no_grad()`` and every metric is called as
    ``metric_fn(x, outputs[output_key])``. Per-batch values are weighted by
    the batch size, so a smaller final batch does not count as much as a
    full one.

    Args:
        model: Model exposing ``eval()``,
            ``kalman_filter.dyn_params.reset_state()`` and
            ``impute(x, mask)`` returning a mapping of outputs.
        loader: Iterable of batches; each batch is a mapping with an
            ``"images"`` tensor whose first two dimensions are read as (B, T).
        device: Device the inputs and the mask are moved to.
        metrics_functions: Mapping ``name -> callable(x, prediction)``; each
            callable is expected to return a per-batch average.
        mask_sampling_fn: Callable ``(B, T) -> tensor`` producing the mask
            that is passed to ``model.impute``.
        output_key: Key of the imputation output compared against the input.
        verbose: When True, print the mean of each sampled mask (off by
            default so the progress bar is not flooded with one line per batch).

    Returns:
        Dict mapping each metric name to its sample-weighted mean over the
        loader (0.0 for every metric when the loader is empty).
    """
    model.eval()
    totals = {name: 0.0 for name in metrics_functions}
    n_samples = 0

    # tqdm appends ": " to the description itself, so no trailing colon here
    for batch in tqdm(loader, desc="Evaluating"):
        # The dynamics-parameter state is reset before every batch
        model.kalman_filter.dyn_params.reset_state()

        x = batch["images"].float().to(device)
        B, T = x.shape[:2]

        # Mask drawn by the sampler for this batch
        # (presumably 1 = observed frame, 0 = frame to impute -- TODO confirm).
        mask = mask_sampling_fn(B, T).to(device)
        if verbose:
            print("mask percentage:", mask.mean().item())

        with torch.no_grad():
            outputs = model.impute(x, mask)
            prediction = outputs[output_key]
            for name, metric_fn in metrics_functions.items():
                # Weight the per-batch value by the number of sequences in it
                totals[name] += metric_fn(x, prediction) * B
            del outputs, prediction
        n_samples += B

    denom = max(n_samples, 1)
    return {name: total / denom for name, total in totals.items()}

def full_mask_sampling(B, T):
    """Return a (B, T) float32 mask of ones allocated on the notebook-level ``device``."""
    return torch.ones((B, T), dtype=torch.float32, device=device)

In [None]:
# Score the validation set with the all-ones mask sampler
metrics = evaluate_metrics(
    model,
    val_loader,
    device,
    metrics_functions,
    full_mask_sampling,
)

print("Evaluation Metrics:")
for metric_name in metrics:
    metric_value = metrics[metric_name]
    print(f"  {metric_name}: {metric_value:.6f}")

In [None]:
dropout_rate = 0.5


def dropout_mask_sampling(B, T):
    """Sample a (B, T) float mask whose entries are 1 with probability ``1 - dropout_rate``."""
    keep_prob = 1 - dropout_rate
    return (torch.rand(B, T, device=device) < keep_prob).float()


# Score the imputation path on the validation set, comparing the
# "x_filtered" output against the inputs
metrics = evaluate_impute_metrics(
    model,
    val_loader,
    device,
    metrics_functions,
    dropout_mask_sampling,
    "x_filtered",
)

print("Evaluation Metrics:")
for metric_name in metrics:
    metric_value = metrics[metric_name]
    print(f"  {metric_name}: {metric_value:.6f}")

In [None]:
# Sweep the dropout rate from 0.2 to 1.0 in five steps and collect the metrics
# for each rate. The rates are stored as rounded Python floats: 0-d tensors
# hash by identity, so they cannot be looked up by value (e.g.
# dropout_metrics[0.2]) and are awkward to use as plot coordinates.
dropout_rates = [round(rate, 2) for rate in torch.linspace(0.2, 1.0, steps=5).tolist()]

dropout_metrics = {}
for drop_rate in dropout_rates:
    # Bind the current rate as a default argument so the sampler keeps its own
    # rate even if it is called after the loop variable has moved on.
    dropout_mask_sampling = lambda B, T, rate=drop_rate: (torch.rand(B, T, device=device) < (1 - rate)).float()
    # NOTE(review): this uses evaluate_metrics (model forward pass, default
    # "x_recon" output) rather than evaluate_impute_metrics -- confirm that is
    # the intended path for the sweep.
    dropout_metrics[drop_rate] = evaluate_metrics(model, val_loader, device, metrics_functions, dropout_mask_sampling)

In [None]:
# One curve per metric, plotted against the dropout rates used in the sweep
fig, ax = plt.subplots(figsize=(8, 6))
rates = list(dropout_metrics)
for metric_name in metrics_functions:
    series = [dropout_metrics[rate][metric_name] for rate in rates]
    ax.plot(rates, series, marker='o', label=metric_name)

ax.set_xlabel("Dropout Rate")
ax.set_ylabel("Metric Value")
ax.set_title("Metrics vs Dropout Rate")
ax.legend()
ax.grid()
plt.show()