In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import sys
import matplotlib.pyplot as plt

import numpy as np
import math

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# Utils functions used across the project

In [None]:
def discover_checkpoints(checkpoint_dir: str = 'checkpoints') -> List[Dict[str, Any]]:
    """
    Discover all checkpoint files and parse metadata from filenames.

    Expected filename format: {model}_{k}{K}_epochs{E}_seed{S}.pt
    Examples:
        - vae_k1_epochs50_seed42.pt
        - iwae_k5_epochs50_seed42.pt
        - iwae_k20_epochs50_seed42.pt

    Returns:
        List of dicts with keys: name, path, type, k, epochs, seed
    """
    checkpoints = []
    pattern = os.path.join(checkpoint_dir, '*.pt')

    # Regex to parse checkpoint filenames
    filename_regex = re.compile(
        r'^(?P<model_type>vae|iwae)_k(?P<k>\d+)_epochs(?P<epochs>\d+)_seed(?P<seed>\d+)\.pt$'
    )

    for filepath in sorted(glob(pattern)):
        filename = os.path.basename(filepath)
        match = filename_regex.match(filename)

        if match:
            model_type = match.group('model_type')
            k = int(match.group('k'))
            epochs = int(match.group('epochs'))
            seed = int(match.group('seed'))

            # Create human-readable name
            if model_type == 'vae':
                name = f'VAE (K={k})'
            else:
                name = f'IWAE (K={k})'

            checkpoints.append({
                'name': name,
                'path': filepath,
                'type': model_type,
                'k': k,
                'epochs': epochs,
                'seed': seed
            })
        else:
            # Non-standard filename, include with minimal info
            print(f"Warning: Could not parse checkpoint filename: {filename}")
            checkpoints.append({
                'name': filename.replace('.pt', ''),
                'path': filepath,
                'type': 'unknown',
                'k': None,
                'epochs': None,
                'seed': None
            })

    return checkpoints


def get_model_key(checkpoint_path: str) -> str:
    """Get a unique key for a model from its checkpoint path."""
    return os.path.basename(checkpoint_path).replace('.pt', '')


def load_results(results_path: str = 'results/evaluations.yaml') -> Dict[str, Any]:
    """
    Load evaluation results from YAML file.

    Returns:
        Dict mapping model keys to their evaluation metrics
    """
    if not os.path.exists(results_path):
        return {}

    with open(results_path, 'r') as f:
        data = yaml.safe_load(f)

    return data if data else {}


def _convert_to_native(obj: Any) -> Any:
    """
    Recursively convert numpy types to native Python types for YAML serialization.
    """
    import numpy as np

    if isinstance(obj, dict):
        return {k: _convert_to_native(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [_convert_to_native(v) for v in obj]
    elif isinstance(obj, (np.integer,)):
        return int(obj)
    elif isinstance(obj, (np.floating,)):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    else:
        return obj


def save_results(
    results_path: str,
    model_key: str,
    metrics: Dict[str, Any],
    model_info: Optional[Dict[str, Any]] = None
) -> None:
    """
    Save or update evaluation results for a model in the YAML file.

    Args:
        results_path: Path to the YAML file
        model_key: Unique identifier for the model (typically filename without .pt)
        metrics: Dict of metric names to values (e.g., {'log_likelihood': -81.06})
        model_info: Optional dict with model metadata (type, k, path, etc.)
    """
    # Load existing results
    results = load_results(results_path)

    # Initialize entry if doesn't exist
    if model_key not in results:
        results[model_key] = {}

    # Update with model info if provided (convert numpy types)
    if model_info:
        results[model_key].update(_convert_to_native(model_info))

    # Update with new metrics (convert numpy types)
    results[model_key].update(_convert_to_native(metrics))

    # Ensure directory exists
    os.makedirs(os.path.dirname(results_path), exist_ok=True)

    # Save back to YAML
    with open(results_path, 'w') as f:
        yaml.dump(results, f, default_flow_style=False, sort_keys=False)


def get_models_config(
    checkpoint_dir: str = 'checkpoints',
    results_path: str = 'results/evaluations.yaml'
) -> List[Dict[str, Any]]:
    """
    Get full model configurations by merging checkpoint discovery with stored results.

    This is the main function that analysis scripts should use to get model configs
    instead of hardcoding them.

    Returns:
        List of model configs with all available metrics
    """
    # Discover checkpoints
    checkpoints = discover_checkpoints(checkpoint_dir)

    # Load stored results
    results = load_results(results_path)

    # Merge checkpoint info with stored results
    models_config = []
    for checkpoint in checkpoints:
        model_key = get_model_key(checkpoint['path'])
        config = checkpoint.copy()

        # Add stored metrics if available
        if model_key in results:
            stored = results[model_key]
            # Add metrics that aren't already in config
            for key, value in stored.items():
                if key not in config:
                    config[key] = value

        models_config.append(config)

    return models_config

# Functions for Model Analysis and Evaluation

In [None]:
def compute_gradient_variance(model, data, n_runs=100):
    """
    Computes the variance of the gradients for the encoder parameters
    over multiple runs on the SAME batch of data (varying random sampling).
    """
    model.train()

    # Target the first encoder layer
    target_layer_param = model.encoder[0].weight
    grads = []

    for _ in range(n_runs):
        model.zero_grad()
        recon_x, mu, logvar, z = model(data)
        loss = model.compute_loss(data, recon_x, mu, logvar, z)
        loss.backward()

        if target_layer_param.grad is not None:
            grads.append(target_layer_param.grad.clone().cpu().numpy())

    grads = np.array(grads)

    var_per_param = np.var(grads, axis=0)
    mean_grad = np.mean(grads, axis=0)
    avg_variance = np.mean(var_per_param)

    std_per_param = np.std(grads, axis=0) + 1e-10
    snr = np.mean(np.abs(mean_grad) / std_per_param)

    return avg_variance, snr


def analyze_single_checkpoint(
    checkpoint_path: str,
    model_type: str,
    k: int,
    device: str,
    data: torch.Tensor,
    n_runs: int = 50,
    results_path: str = None,
    hidden_size: int = 200,
    latent_size: int = 50
) -> tuple:
    """Analyze gradient variance for a single checkpoint."""
    input_size = 784
    output_size = 784

    if model_type == 'vae':
        model = VAE(input_size, hidden_size, latent_size, output_size)
    else:
        model = IWAE(k, input_size, hidden_size, latent_size, output_size)

    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)

    variance, snr = compute_gradient_variance(model, data, n_runs=n_runs)

    # Save to YAML if path provided
    if results_path:
        model_key = get_model_key(checkpoint_path)
        save_results(
            results_path=results_path,
            model_key=model_key,
            metrics={
                'gradient_variance': float(f"{variance:.2e}"),
                'gradient_snr': round(snr, 4)
            }
        )

    return variance, snr


def analyze_all_checkpoints(
    checkpoint_dir: str,
    device: str,
    data: torch.Tensor,
    n_runs: int = 50,
    results_path: str = 'results/evaluations.yaml'
):
    """Discover and analyze all checkpoints."""
    checkpoints = discover_checkpoints(checkpoint_dir)

    if not checkpoints:
        print(f"No checkpoints found in {checkpoint_dir}")
        return

    print(f"Found {len(checkpoints)} checkpoints")
    print("-" * 60)

    results = []
    for cp in checkpoints:
        if cp['k'] is None:
            print(f"Skipping {cp['name']} - could not parse K value")
            continue

        print(f"\nAnalyzing {cp['name']}...")

        variance, snr = analyze_single_checkpoint(
            checkpoint_path=cp['path'],
            model_type=cp['type'],
            k=cp['k'],
            device=device,
            data=data,
            n_runs=n_runs,
            results_path=results_path
        )

        results.append((cp['name'], variance, snr))
        print(f"  Variance: {variance:.2e}, SNR: {snr:.4f}")

    print("-" * 60)
    print("\nSummary:")
    print(f"{'Model':<20} | {'Variance':<12} | {'SNR':<10}")
    print("-" * 50)
    for name, var, snr in results:
        print(f"{name:<20} | {var:<12.2e} | {snr:<10.4f}")

    if results_path:
        print(f"\nResults saved to {results_path}")

In [None]:

def get_dataloaders(batch_size=128, data_dir='./data'):
    """
    Returns (train_loader, val_loader, test_loader) for MNIST.

    Preprocessing:
    - Train: Dynamic Binarization (sampling from pixel intensities)
    - Val/Test: Fixed Binarization (rounding at 0.5)

    Split:
    - Train: 50,000 samples
    - Val: 10,000 samples (from the original training set)
    - Test: 10,000 samples (original test set)
    """

    # Transforms
    # Dynamic binarization: interpret pixel value as probability p, sample x ~ Bern(p)
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: torch.bernoulli(x))
    ])

    # Fixed binarization: round to nearest integer (0 or 1)
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: torch.round(x))
    ])

    # Download and load datasets
    # We load train dataset twice to apply different transforms for train vs val
    full_train_dynamic = datasets.MNIST(data_dir, train=True, download=True, transform=train_transform)
    full_train_fixed = datasets.MNIST(data_dir, train=True, download=True, transform=test_transform)
    test_dataset = datasets.MNIST(data_dir, train=False, download=True, transform=test_transform)

    # Create Split Indices (fixed seed for reproducibility of split)
    generator = torch.Generator().manual_seed(42)
    indices = torch.randperm(len(full_train_dynamic), generator=generator)

    train_indices = indices[:50000]
    val_indices = indices[50000:]

    # Create Subsets
    train_dataset = Subset(full_train_dynamic, train_indices)
    val_dataset = Subset(full_train_fixed, val_indices)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    print(f"Data Loaded: Train {len(train_dataset)}, Val {len(val_dataset)}, Test {len(test_dataset)}")

    return train_loader, val_loader, test_loader

In [None]:
def plot_reconstruction(model, dataloader, device, save_path):
    model.eval()
    data, _ = next(iter(dataloader))
    data = data.to(device)

    with torch.no_grad():
        recon, _, _, _ = model(data)
        # If IWAE, recon is (K, B, 784), take mean
        if len(recon.shape) == 3:
            recon = recon.mean(dim=0)

    # Randomly select 8 samples
    batch_size = data.size(0)
    num_samples = min(8, batch_size)
    indices = torch.randperm(batch_size)[:num_samples]
    
    # Reshape
    input_imgs = data[indices].view(-1, 28, 28).cpu()
    recon_imgs = recon[indices].view(-1, 28, 28).cpu()

    fig, axes = plt.subplots(2, 8, figsize=(12, 3))
    for i in range(8):
        # Original
        axes[0, i].imshow(input_imgs[i], cmap='gray')
        axes[0, i].axis('off')
        if i == 0: axes[0, i].set_title("Original")

        # Recon
        axes[1, i].imshow(recon_imgs[i], cmap='gray')
        axes[1, i].axis('off')
        if i == 0: axes[1, i].set_title("Recon")

    plt.tight_layout()
    plt.savefig(save_path)
    print(f"Reconstructions saved to {save_path}")
    plt.close()

def plot_samples(model, device, save_path, n=64):
    model.eval()
    # Sample from prior p(z) ~ N(0, I)
    z = torch.randn(n, model.latent_size).to(device)

    with torch.no_grad():
        samples = model.decoder(z)

    samples = samples.view(-1, 28, 28).cpu()

    # Grid size ( sqrt(n) )
    grid_size = int(n**0.5)
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(8, 8))

    for i, ax in enumerate(axes.flatten()):
        ax.imshow(samples[i], cmap='gray')
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(save_path)
    print(f"Samples saved to {save_path}")
    plt.close()

In [None]:
def compute_batch_kl(mu, logvar):
    """
    Computes individual KL divergence for each dimension z_j.
    Formula: KL = -0.5 * (1 + logvar - mu^2 - exp(logvar))

    Args:
        mu: (BATCH_SIZE, LATENT_SIZE)
        logvar: (BATCH_SIZE, LATENT_SIZE)

    Returns:
        kl_per_dim: (LATENT_SIZE,) - Average KL for each dimension across the batch
    """
    kl_elementwise = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    kl_per_dim = kl_elementwise.mean(dim=0)
    return kl_per_dim


def calc_active_units(model, dataloader, device, threshold=0.01):
    """
    Runs over the dataset and computes the number of active units.
    A unit is active if Avg_KL(z_j) > threshold.
    """
    model.eval()
    total_kl = 0
    num_batches = 0

    with torch.no_grad():
        for data, _ in tqdm(dataloader, desc="Checking Latents", leave=False):
            data = data.to(device)

            if isinstance(model, IWAE):
                x_flat = data.view(data.size(0), -1)
                h = model.encoder(x_flat)
                mu, logvar = h.chunk(2, dim=1)
            else:
                _, mu, logvar, _ = model(data)

            batch_kl = compute_batch_kl(mu, logvar)
            total_kl += batch_kl
            num_batches += 1

    final_avg_kl = total_kl / num_batches
    num_active = (final_avg_kl > threshold).sum().item()

    return num_active, final_avg_kl


def plot_kl_stats(avg_kl_values, model_name, save_path):
    """Plots the KL value for each dimension sorted."""
    sorted_kl, _ = torch.sort(avg_kl_values, descending=True)
    sorted_kl = sorted_kl.cpu().numpy()

    plt.figure(figsize=(8, 4))
    plt.bar(range(len(sorted_kl)), sorted_kl)
    plt.axhline(y=0.01, color='r', linestyle='--', label='Threshold (0.01)')
    plt.xlabel('Latent Dimensions (Sorted)')
    plt.ylabel('Average KL Divergence (nats)')
    plt.title(f'Effective KL per Dimension - {model_name}')
    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"Plot saved to {save_path}")
    plt.close()


def analyze_single_checkpoint(
    checkpoint_path: str,
    model_type: str,
    k: int,
    device: str,
    test_loader,
    threshold: float = 0.01,
    output_dir: str = './results',
    results_path: str = None,
    hidden_size: int = 200,
    latent_size: int = 50
) -> int:
    """Analyze a single checkpoint and optionally save results."""
    input_size = 784
    output_size = 784

    if model_type == 'vae':
        model = VAE(input_size, hidden_size, latent_size, output_size)
    else:
        model = IWAE(k, input_size, hidden_size, latent_size, output_size)

    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)

    n_active, avg_kls = calc_active_units(model, test_loader, device, threshold)

    # Save to YAML if path provided
    if results_path:
        model_key = get_model_key(checkpoint_path)
        save_results(
            results_path=results_path,
            model_key=model_key,
            metrics={'active_units': n_active}
        )

    # Plot
    os.makedirs(output_dir, exist_ok=True)
    base_name = os.path.basename(checkpoint_path).replace('.pt', '')
    plot_kl_stats(
        avg_kls,
        f"{model_type.upper()} (Active: {n_active})",
        os.path.join(output_dir, f"{base_name}_kl.png")
    )

    return n_active


def analyze_all_checkpoints(
    checkpoint_dir: str,
    device: str,
    test_loader,
    threshold: float = 0.01,
    output_dir: str = './results',
    results_path: str = 'results/evaluations.yaml'
):
    """Discover and analyze all checkpoints."""
    checkpoints = discover_checkpoints(checkpoint_dir)

    if not checkpoints:
        print(f"No checkpoints found in {checkpoint_dir}")
        return

    print(f"Found {len(checkpoints)} checkpoints")
    print("-" * 50)

    results = []
    for cp in checkpoints:
        if cp['k'] is None:
            print(f"Skipping {cp['name']} - could not parse K value")
            continue

        print(f"\nAnalyzing {cp['name']}...")

        n_active = analyze_single_checkpoint(
            checkpoint_path=cp['path'],
            model_type=cp['type'],
            k=cp['k'],
            device=device,
            test_loader=test_loader,
            threshold=threshold,
            output_dir=output_dir,
            results_path=results_path
        )

        results.append((cp['name'], n_active))
        print(f"  Active Units: {n_active}/50")

    print("-" * 50)
    print("\nSummary:")
    for name, n_active in results:
        print(f"  {name}: {n_active} active units")

    if results_path:
        print(f"\nResults saved to {results_path}")

In [None]:
def evaluate_model(model, dataloader, device, k_eval):
    """
    Evaluates model using IWAE bound with k_eval samples.
    Returns average Log-Likelihood (nats).
    """
    model.eval()
    total_ll = 0
    total_recon_error = 0
    total_weighted_recon_error = 0
    total_samples = 0

    # Check if model has 'K' attribute and update it temporarily
    original_k = getattr(model, 'K', 1)
    if hasattr(model, 'K'):
        model.K = k_eval

    with torch.no_grad():
        for data, _ in tqdm(dataloader, desc=f"Evaluating with K={k_eval}", leave=False):
            data = data.to(device)
            recon_x, mu, logvar, z = model(data)
            loss = model.compute_loss(data, recon_x, mu, logvar, z)

            # Calculate Reconstruction Error (BCE)
            # Shape of recon_x: (K, B, L) for IWAE/VAE(wrapped), or (B, L) for VAE(native)
            # We standardize to (K, B, L) for computation
            x_flat = data.view(data.size(0), -1)

            if len(recon_x.shape) == 2: # (B, L)
                 # VAE standard
                 bce = torch.nn.functional.binary_cross_entropy(recon_x, x_flat, reduction='sum')
                 recon_error = bce.item()
                 weighted_recon_error = recon_error
            else: # (K, B, L)
                 # IWAE: Compute expected reconstruction error over K samples
                 # Sum over Batch & Latent, Mean over K
                 x_expanded = x_flat.unsqueeze(0).expand_as(recon_x) # (K, B, L)
                 # BCE per sample per batch: (K, B)
                 bce_per_sample = torch.nn.functional.binary_cross_entropy(recon_x, x_expanded, reduction='none').sum(dim=2)

                 # 1. Unweighted Reconstruction Error (Expected Value)
                 recon_error = bce_per_sample.mean(dim=0).sum().item()

                 # 2. Weighted Reconstruction Error (Importance Sampled)
                 log_p_x_given_z = -bce_per_sample # (K, B)

                 # q(z|x) and p(z)
                 if mu.dim() == 2:
                     mu_k = mu.unsqueeze(0).expand_as(z)
                     logvar_k = logvar.unsqueeze(0).expand_as(z)
                 else:
                     mu_k, logvar_k = mu, logvar

                 log_q_z_given_x = -0.5 * (torch.log(2 * torch.tensor(math.pi)) + logvar_k + (z - mu_k).pow(2) / torch.exp(logvar_k))
                 log_q_z_given_x = log_q_z_given_x.sum(dim=2) # (K, B)

                 log_p_z = -0.5 * (torch.log(2 * torch.tensor(math.pi)) + z.pow(2))
                 log_p_z = log_p_z.sum(dim=2) # (K, B)

                 log_w = log_p_x_given_z + log_p_z - log_q_z_given_x # (K, B)
                 w_tilde = torch.softmax(log_w, dim=0) # (K, B)

                 weighted_bce = (w_tilde * bce_per_sample).sum(dim=0) # (B,)
                 weighted_recon_error = weighted_bce.sum().item()

            batch_img_count = data.size(0)
            total_ll += -loss.item() * batch_img_count
            total_samples += batch_img_count
            total_recon_error += recon_error
            total_weighted_recon_error += weighted_recon_error

    # Restore K just in case
    if hasattr(model, 'K'):
        model.K = original_k

    return total_ll / total_samples, total_recon_error / total_samples, total_weighted_recon_error / total_samples


def evaluate_single_checkpoint(
    checkpoint_path: str,
    model_type: str,
    k_train: int,
    k_eval: int,
    device: str,
    test_loader,
    hidden_size: int = 200,
    latent_size: int = 50,
    results_path: str = None
) -> tuple:
    """
    Evaluate a single checkpoint and optionally save results.

    Returns:
        Log-likelihood estimate
    """
    input_size = 784
    output_size = 784

    # Load weights into IWAE evaluator (works for both VAE and IWAE checkpoints)
    evaluator = IWAE(
        K=k_eval,
        input_size=input_size,
        hidden_size=hidden_size,
        latent_size=latent_size,
        output_size=output_size
    )
    evaluator.load_state_dict(torch.load(checkpoint_path, map_location=device))
    evaluator.to(device)

    ll, recon_error, weighted_recon_error = evaluate_model(evaluator, test_loader, device, k_eval)

    # Save to YAML if path provided
    if results_path:
        model_key = get_model_key(checkpoint_path)

        metrics = {
            'log_likelihood': round(ll, 2),
            'reconstruction_error': round(recon_error, 2)
        }

        # Add weighted recon error only for IWAE (K>1)
        if model_type == 'iwae' and k_train > 1:
            metrics['weighted_reconstruction_error'] = round(weighted_recon_error, 2)

        save_results(
            results_path=results_path,
            model_key=model_key,
            metrics=metrics,
            model_info={
                'path': checkpoint_path,
                'type': model_type,
                'k': k_train
            }
        )

    return ll, recon_error, weighted_recon_error


def evaluate_all_checkpoints(
    checkpoint_dir: str,
    k_eval: int,
    device: str,
    test_loader,
    hidden_size: int = 200,
    latent_size: int = 50,
    results_path: str = 'results/evaluations.yaml'
):
    """
    Discover and evaluate all checkpoints in directory.
    """
    checkpoints = discover_checkpoints(checkpoint_dir)

    if not checkpoints:
        print(f"No checkpoints found in {checkpoint_dir}")
        return

    print(f"Found {len(checkpoints)} checkpoints")
    print(f"Evaluation on Test Set (10k images) using IWAE bound with K={k_eval}")
    print("-" * 60)

    results = []
    for cp in checkpoints:
        print(f"\nEvaluating {cp['name']}...")

        ll, recon_error, weighted_recon = evaluate_single_checkpoint(
            checkpoint_path=cp['path'],
            model_type=cp['type'],
            k_train=cp['k'] if cp['k'] else 1,
            k_eval=k_eval,
            device=device,
            test_loader=test_loader,
            hidden_size=hidden_size,
            latent_size=latent_size,
            results_path=results_path
        )

        results.append((cp['name'], ll, recon_error, weighted_recon, cp.get('type'), cp.get('k')))

        if cp.get('type') == 'iwae' and (cp.get('k') or 0) > 1:
            print(f"  {cp['name']} Log-Likelihood: {ll:.2f} nats | Recon: {recon_error:.2f} | Weighted Recon: {weighted_recon:.2f}")
        else:
            print(f"  {cp['name']} Log-Likelihood: {ll:.2f} nats | Recon: {recon_error:.2f}")

    print("-" * 60)
    print("\nSummary:")
    # Sort by Log Likelihood
    for name, ll, recon, weighted, mtype, k in sorted(results, key=lambda x: x[1], reverse=True):
        if mtype == 'iwae' and k and k > 1:
            print(f"  {name:<15}: LL={ll:.2f} | Recon={recon:.2f} | Weighted={weighted:.2f}")
        else:
            print(f"  {name:<15}: LL={ll:.2f} | Recon={recon:.2f}")

    print(f"\nResults saved to {results_path}")

# IWAE & VAE Implementations

In [None]:
class VAE(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size, output_size):
        """
        Initialise l'architecture du VAE.
        
        Args:
            input_size (int): Dimension de l'entrée (ex: 784 pour MNIST 28x28).
            hidden_size (int): Nombre de neurones dans les couches cachées.
            latent_size (int): Dimension de l'espace latent (le 'goulot d'étranglement').
            output_size (int): Dimension de la sortie reconstruite (généralement = input_size).
        """
        super().__init__()
        
        # --- ENCODEUR ---
        # Compresse l'entrée vers l'espace latent
        self.encoder = nn.Sequential(
            nn.Linear(in_features=input_size, out_features=hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.Tanh(),
            nn.Linear(in_features=hidden_size, out_features=hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.Tanh(),
            # La sortie finale de l'encodeur est 2 * latent_size car on doit
            # prédire à la fois la moyenne (mu) et le log-variance (logvar).
            nn.Linear(in_features=hidden_size, out_features=2 * latent_size)
        )

        # --- DÉCODEUR ---
        # Reconstruit l'image à partir d'un point de l'espace latent
        self.decoder = nn.Sequential(
            nn.Linear(in_features=latent_size, out_features=hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.Tanh(),
            nn.Linear(in_features=hidden_size, out_features=hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.Tanh(),
            nn.Linear(in_features=hidden_size, out_features=output_size),
            # Sigmoid assure que la sortie est entre 0 et 1 (idéal pour des pixels normalisés).
            nn.Sigmoid() 
        )

        self.latent_size = latent_size

    def forward(self, x):
        """
        Passage vers l'avant (Forward pass).
        """
        # Aplatit l'entrée au cas où elle arrive sous forme d'image (B, C, H, W) -> (B, D)
        x = x.view(x.size(0), -1)
        
        # 1. ENCODAGE : Extraction des paramètres de la distribution gaussienne
        h = self.encoder(x)
        # Sépare le vecteur en deux : mu (moyenne) et logvar (log-variance)
        mu, logvar = h.chunk(2, dim=1)

        # 2. REPARAMÉTRISATION (The Reparameterization Trick)
        # On ne peut pas faire de backpropagation à travers un échantillonnage aléatoire.
        # On utilise donc : z = mu + sigma * epsilon, où epsilon ~ N(0, 1)
        std = torch.exp(0.5 * logvar) # Calcul de l'écart-type (sigma)
        eps = torch.randn_like(std)   # Échantillonnage d'un bruit blanc
        z = mu + std * eps            # Échantillonnage de l'espace latent

        # 3. DÉCODAGE : Reconstruction à partir de z
        recon_x = self.decoder(z)
        
        return recon_x, mu, logvar, z

    def compute_loss(self, x, recon_x, mu, logvar):
        """
        Calcule la perte ELBO (Evidence Lower Bound).
        Loss = Erreur de reconstruction + Divergence KL
        """
        x = x.view(x.size(0), -1)
        
        # A. Perte de Reconstruction (BCE) : Compare l'original et la copie
        # On utilise la somme (reduction='sum') pour rester cohérent avec la formule mathématique de la KL.
        bce = F.binary_cross_entropy(recon_x, x, reduction='sum')

        # B. Divergence de Kullback-Leibler (KLD)
        # Force la distribution latente apprise à être proche d'une distribution normale standard N(0,1).
        # Formule fermée : -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return bce + kld

In [None]:

class IWAE(VAE):
    def __init__(self, K, input_size, hidden_size, latent_size, output_size):
        """
        Initialise l'IWAE.
        Args:
            K (int): Nombre d'échantillons d'importance par donnée.
        """
        super().__init__(input_size=input_size, hidden_size=hidden_size, 
                         latent_size=latent_size, output_size=output_size)
        self.K = K

    def forward(self, x):
        """
        Passage vers l'avant avec échantillonnage multiple.
        """
        # Aplatissement de l'entrée : (Batch, ...) -> (Batch, Input_size)
        x = x.view(x.size(0), -1)
        
        # Encodage pour obtenir les paramètres de la distribution latente
        h = self.encoder(x)
        mu, logvar = h.chunk(2, dim=1) # Dimensions : (Batch, Latent)
        std = torch.exp(0.5 * logvar)
        
        # --- ÉCHANTILLONNAGE MULTIPLE (K échantillons) ---
        # On génère K bruits blancs pour chaque élément du batch.
        # Dimension finale de eps : (K, Batch, Latent)
        eps = torch.randn(self.K, x.size(0), self.latent_size).to(x.device)

        # Reparamétrisation : z a maintenant une dimension (K, Batch, Latent)
        z = mu + eps * std

        # --- DÉCODAGE MASSIVE ---
        # On aplatit z pour passer tous les échantillons (K * Batch) dans le décodeur d'un coup
        z_flat = z.view(-1, self.latent_size)
        recon_x_flat = self.decoder(z_flat)
        
        # On redonne à la reconstruction sa forme multi-échantillons : (K, Batch, Output_size)
        recon_x = recon_x_flat.view(self.K, x.size(0), -1)

        return recon_x, mu, logvar, z

    def compute_loss(self, x, recon_x, mu, logvar, z):
        """
        Calcule la perte IWAE basée sur le Log-Sum-Exp.
        L'objectif est de maximiser la borne : E[log( (1/K) * sum(p(x,z)/q(z|x)) )]
        """
        x = x.view(x.size(0), -1)
        # Répète x pour correspondre aux K échantillons de recon_x : (K, Batch, Input_size)
        x_k = x.unsqueeze(0).repeat(self.K, 1, 1)

        # 1. Log p(x | z) : Log-vraisemblance de la reconstruction (Log-Bernoulli)
        # On somme sur la dimension des pixels (dim=2)
        log_p_x_given_z = -F.binary_cross_entropy(recon_x, x_k, reduction="none").sum(dim=2)

        # 2. Log q(z | x) : Log-densité de la distribution de l'encodeur (Gaussienne)
        # Formule : log N(z; mu, sigma^2)
        log_q_z_given_x = -0.5 * (torch.log(2 * torch.tensor(np.pi)) + logvar + (z - mu).pow(2) / torch.exp(logvar))
        log_q_z_given_x = log_q_z_given_x.sum(dim=2) # Somme sur les dimensions latentes -> (K, Batch)

        # 3. Log p(z) : Log-densité de la priorité (Prior) N(0, 1)
        log_p_z = -0.5 * (torch.log(2 * torch.tensor(np.pi)) + z.pow(2))
        log_p_z = log_p_z.sum(dim=2) # Somme sur les dimensions latentes -> (K, Batch)

        # --- CALCUL DES POIDS D'IMPORTANCE ---
        # log_w = log( p(x,z) / q(z|x) ) = log p(x|z) + log p(z) - log q(z|x)
        log_w = log_p_x_given_z + log_p_z - log_q_z_given_x

        # --- LOG-SUM-EXP TRICK ---
        # Pour éviter les instabilités numériques, on utilise logsumexp pour calculer le log de la moyenne
        # loss = - Moyenne_sur_batch ( log ( (1/K) * sum(exp(log_w)) ) )
        loss = - (torch.logsumexp(log_w, dim=0) - torch.log(torch.tensor(float(self.K)))).mean()

        return loss

# IWAE vs VAE: Analysis Notebook

This notebook aggregates the analysis results for the comparison between VAE (K=1), IWAE (K=5), and IWAE (K=20).

## 1. Setup & Imports

In [None]:
import torch
import os
import sys
import matplotlib.pyplot as plt

# Add src to path
sys.path.append(os.path.abspath('.'))

from src.analysis.evaluate_likelihood import evaluate_model

# Config
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Paths
vae_path = 'checkpoints/vae_k1_epochs50_seed42.pt'
iwae_k5_path = 'checkpoints/iwae_k5_epochs50_seed42.pt'
iwae_k20_path = 'checkpoints/iwae_k20_epochs50_seed42.pt'
iwae_k50_path = 'checkpoints/iwae_k50_epochs50_seed42.pt'
iwae_k100_path = 'checkpoints/iwae_k100_epochs50_seed42.pt'

output_dir = './notebook_results'
os.makedirs(output_dir, exist_ok=True)

Using device: cuda


## 2. Load Models and Data

In [2]:
input_size = 784
hidden_size = 200
latent_size = 50
output_size = 784

# Load VAE
vae = VAE(input_size, hidden_size, latent_size, output_size)
vae.load_state_dict(torch.load(vae_path, map_location=device))
vae.to(device)
print("VAE Loaded.")

# Load IWAE K=5
iwae_k5 = IWAE(5, input_size, hidden_size, latent_size, output_size)
iwae_k5.load_state_dict(torch.load(iwae_k5_path, map_location=device))
iwae_k5.to(device)
print("IWAE (K=5) Loaded.")

# Load IWAE K=20
iwae_k20 = IWAE(20, input_size, hidden_size, latent_size, output_size)
iwae_k20.load_state_dict(torch.load(iwae_k20_path, map_location=device))
iwae_k20.to(device)
print("IWAE (K=20) Loaded.")

# Load IWAE K=50
iwae_k50 = IWAE(50, input_size, hidden_size, latent_size, output_size)
iwae_k50.load_state_dict(torch.load(iwae_k50_path, map_location=device))
iwae_k50.to(device)
print("IWAE (K=50) Loaded.")

# Load IWAE K=100
iwae_k100 = IWAE(100, input_size, hidden_size, latent_size, output_size)
iwae_k100.load_state_dict(torch.load(iwae_k100_path, map_location=device))
iwae_k100.to(device)
print("IWAE (K=100) Loaded.")


# Data
train_loader, val_loader, test_loader = get_dataloaders(batch_size=32)

VAE Loaded.
IWAE (K=5) Loaded.
IWAE (K=20) Loaded.
IWAE (K=50) Loaded.
IWAE (K=100) Loaded.
Data Loaded: Train 50000, Val 10000, Test 10000


## 3. Qualitative Comparison: Reconstructions & Samples

In [3]:
# VAE Visuals
plot_reconstruction(vae, test_loader, device, f"{output_dir}/vae_recon.png")
plot_samples(vae, device, f"{output_dir}/vae_samples.png")

# IWAE K=5 Visuals
plot_reconstruction(iwae_k5, test_loader, device, f"{output_dir}/iwae_k5_recon.png")
plot_samples(iwae_k5, device, f"{output_dir}/iwae_k5_samples.png")

# IWAE K=20 Visuals
plot_reconstruction(iwae_k20, test_loader, device, f"{output_dir}/iwae_k20_recon.png")
plot_samples(iwae_k20, device, f"{output_dir}/iwae_k20_samples.png")

# IWAE K=50 Visuals
plot_reconstruction(iwae_k50, test_loader, device, f"{output_dir}/iwae_k50_recon.png")
plot_samples(iwae_k50, device, f"{output_dir}/iwae_k50_samples.png")

# IWAE K=100 Visuals
plot_reconstruction(iwae_k100, test_loader, device, f"{output_dir}/iwae_k100_recon.png")
plot_samples(iwae_k100, device, f"{output_dir}/iwae_k100_samples.png")

# Display (if running in interactive notebook)
print("Visualizations saved to ./notebook_results")
# plt.imshow(plt.imread(f"{output_dir}/vae_recon.png"))
# plt.show()

Reconstructions saved to ./notebook_results/vae_recon.png
Samples saved to ./notebook_results/vae_samples.png
Reconstructions saved to ./notebook_results/iwae_k5_recon.png
Samples saved to ./notebook_results/iwae_k5_samples.png
Reconstructions saved to ./notebook_results/iwae_k20_recon.png
Samples saved to ./notebook_results/iwae_k20_samples.png
Reconstructions saved to ./notebook_results/iwae_k50_recon.png
Samples saved to ./notebook_results/iwae_k50_samples.png
Reconstructions saved to ./notebook_results/iwae_k100_recon.png
Samples saved to ./notebook_results/iwae_k100_samples.png
Visualizations saved to ./notebook_results


## 4. Quantitative Comparison: Log-Likelihood (IWAE bound, K=5000)

In [None]:
k_eval = 5000
print(f"Estimating LL with K={k_eval}...")

# We strictly use IWAE logic for evaluation (even for VAE weights)
evaluator = IWAE(k_eval, input_size, hidden_size, latent_size, output_size).to(device)

# Evaluate VAE
evaluator.load_state_dict(vae.state_dict())
vae_ll = evaluate_model(evaluator, test_loader, device, k_eval)
print(f"VAE LL: {vae_ll:.4f}")

# Evaluate IWAE K=5
evaluator.load_state_dict(iwae_k5.state_dict())
iwae_k5_ll = evaluate_model(evaluator, test_loader, device, k_eval)
print(f"IWAE (K=5) LL: {iwae_k5_ll:.4f}")

# Evaluate IWAE K=20
evaluator.load_state_dict(iwae_k20.state_dict())
iwae_k20_ll = evaluate_model(evaluator, test_loader, device, k_eval)
print(f"IWAE (K=20) LL: {iwae_k20_ll:.4f}")


Estimating LL with K=5000...


                                                                         

VAE LL: -81.0650


                                                                         

IWAE (K=5) LL: -78.3691


                                                                         

IWAE (K=20) LL: -77.3391
Improvement (K=1->5): 2.6959 nats
Improvement (K=5->20): 1.0300 nats




In [7]:
print(f"Improvement (K=1->5): {iwae_k5_ll - vae_ll:.4f} nats")
print(f"Improvement (K=5->20): {iwae_k20_ll - iwae_k5_ll:.4f} nats")
print(f"Improvement (K=1->20): {iwae_k20_ll - vae_ll:.4f} nats")

Improvement (K=1->5): 2.6959 nats
Improvement (K=5->20): 1.0300 nats
Improvement (K=1->20): 3.7259 nats


## 5. Posterior Collapse Analysis: Effective KL

In [5]:
print("Analyzing Active Units...")

n_vae, vae_kls = calc_active_units(vae, test_loader, device)
n_iwae_k5, iwae_k5_kls = calc_active_units(iwae_k5, test_loader, device)
n_iwae_k20, iwae_k20_kls = calc_active_units(iwae_k20, test_loader, device)

print(f"VAE Active Units: {n_vae}")
print(f"IWAE (K=5) Active Units: {n_iwae_k5}")
print(f"IWAE (K=20) Active Units: {n_iwae_k20}")

plot_kl_stats(vae_kls, "VAE", f"{output_dir}/vae_kl.png")
plot_kl_stats(iwae_k5_kls, "IWAE K=5", f"{output_dir}/iwae_k5_kl.png")
plot_kl_stats(iwae_k20_kls, "IWAE K=20", f"{output_dir}/iwae_k20_kl.png")

Analyzing Active Units...


                                                                    

VAE Active Units: 15
IWAE (K=5) Active Units: 22
IWAE (K=20) Active Units: 26
Plot saved to ./notebook_results/vae_kl.png
Plot saved to ./notebook_results/iwae_k5_kl.png
Plot saved to ./notebook_results/iwae_k20_kl.png


## 6. Gradient Variance Analysis (SNR)

In [6]:
# Get a batch
data, _ = next(iter(train_loader))
data = data.to(device)

print("Comparing Gradient SNR...")

var_vae, snr_vae = compute_gradient_variance(vae, data, n_runs=50)
print(f"VAE SNR: {snr_vae:.4f}")

var_iwae_k5, snr_iwae_k5 = compute_gradient_variance(iwae_k5, data, n_runs=50)
print(f"IWAE (K=5) SNR: {snr_iwae_k5:.4f}")

var_iwae_k20, snr_iwae_k20 = compute_gradient_variance(iwae_k20, data, n_runs=50)
print(f"IWAE (K=20) SNR: {snr_iwae_k20:.4f}")

Comparing Gradient SNR...
Collecting gradients over 50 runs...
VAE SNR: 0.5119
Collecting gradients over 50 runs...
IWAE (K=5) SNR: 0.2917
Collecting gradients over 50 runs...
IWAE (K=20) SNR: 0.2487
