# IWAE vs VAE: Analysis Notebook

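This notebook aggregates the analysis results comparing three models trained on MNIST: a VAE (K=1), an IWAE with K=5, and an IWAE with K=20, where K is the number of importance samples used in the training objective.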

## 1. Setup & Imports

In [None]:
import torch
import os
import sys
import matplotlib.pyplot as plt

# Add src to path
sys.path.append(os.path.abspath('.'))

from src.models.vae import VAE
from src.models.iwae import IWAE
from src.data.mnist_loader import get_dataloaders
from src.analysis.visualize import plot_reconstruction, plot_samples
from src.analysis.active_units import calc_active_units, plot_kl_stats
from src.analysis.evaluate_likelihood import evaluate_model
from src.analysis.gradient_variance import compute_gradient_variance

# Config
# Seed torch's RNG once, up front, so that a Restart & Run All draws the same
# random numbers in the stochastic steps further down. 42 matches the seed
# embedded in the checkpoint file names below.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Paths
# Checkpoints for each trained model. The K=50 and K=100 paths are defined
# here but are not loaded by any of the cells shown in this notebook.
vae_path = 'checkpoints/vae_k1_epochs50_seed42.pt'
iwae_k5_path = 'checkpoints/iwae_k5_epochs50_seed42.pt'
iwae_k20_path = 'checkpoints/iwae_k20_epochs50_seed42.pt'
iwae_k50_path = 'checkpoints/iwae_k50_epochs50_seed42.pt'
iwae_k100_path = 'checkpoints/iwae_k100_epochs50_seed42.pt'

# Every figure produced below is written under this directory; create it now
# so later cells can save without checking.
output_dir = './notebook_results'
os.makedirs(output_dir, exist_ok=True)

Using device: cuda


## 2. Load Models and Data

In [2]:
# Layer sizes passed to every model constructor in this notebook.
input_size, hidden_size, latent_size, output_size = 784, 200, 50, 784
_arch = (input_size, hidden_size, latent_size, output_size)


def _restore(model, ckpt_path, message):
    """Load the state dict stored at ``ckpt_path`` into ``model``, move the
    model to ``device``, print ``message`` and return the model.

    NOTE(review): ``torch.load`` unpickles the file, so the checkpoints are
    assumed to be trusted local artefacts.
    """
    weights = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(weights)
    model.to(device)
    print(message)
    return model


# Restore the three trained models in the same order as before: VAE, K=5, K=20.
vae = _restore(VAE(*_arch), vae_path, "VAE Loaded.")
iwae_k5 = _restore(IWAE(5, *_arch), iwae_k5_path, "IWAE (K=5) Loaded.")
iwae_k20 = _restore(IWAE(20, *_arch), iwae_k20_path, "IWAE (K=20) Loaded.")

# Data
train_loader, val_loader, test_loader = get_dataloaders(batch_size=32)

VAE Loaded.
IWAE (K=5) Loaded.
IWAE (K=20) Loaded.
Data Loaded: Train 50000, Val 10000, Test 10000


## 3. Qualitative Comparison: Reconstructions & Samples

In [3]:
# For each model, write its reconstruction figure and then its sample figure
# into output_dir. Order is VAE, IWAE K=5, IWAE K=20.
for _net, _stem in ((vae, "vae"), (iwae_k5, "iwae_k5"), (iwae_k20, "iwae_k20")):
    plot_reconstruction(_net, test_loader, device, f"{output_dir}/{_stem}_recon.png")
    plot_samples(_net, device, f"{output_dir}/{_stem}_samples.png")

print("Visualizations saved to ./notebook_results")
# To preview a saved figure inline when running interactively:
# plt.imshow(plt.imread(f"{output_dir}/vae_recon.png"))
# plt.show()

Reconstructions saved to ./notebook_results/vae_recon.png
Samples saved to ./notebook_results/vae_samples.png
Reconstructions saved to ./notebook_results/iwae_k5_recon.png
Samples saved to ./notebook_results/iwae_k5_samples.png
Reconstructions saved to ./notebook_results/iwae_k20_recon.png
Samples saved to ./notebook_results/iwae_k20_samples.png
Visualizations saved to ./notebook_results


## 4. Quantitative Comparison: Log-Likelihood (IWAE bound, K=5000)

In [None]:
k_eval = 5000
print(f"Estimating LL with K={k_eval}...")

# A single IWAE built with K=k_eval acts as the evaluator for every model;
# each trained model (the VAE included) has its weights copied into it.
evaluator = IWAE(k_eval, input_size, hidden_size, latent_size, output_size).to(device)


def _estimate_ll(trained):
    """Copy ``trained``'s state dict into the shared evaluator and return
    ``evaluate_model``'s result on the test loader."""
    evaluator.load_state_dict(trained.state_dict())
    return evaluate_model(evaluator, test_loader, device, k_eval)


# VAE weights
vae_ll = _estimate_ll(vae)
print(f"VAE LL: {vae_ll:.4f}")

# IWAE K=5 weights
iwae_k5_ll = _estimate_ll(iwae_k5)
print(f"IWAE (K=5) LL: {iwae_k5_ll:.4f}")

# IWAE K=20 weights
iwae_k20_ll = _estimate_ll(iwae_k20)
print(f"IWAE (K=20) LL: {iwae_k20_ll:.4f}")


Estimating LL with K=5000...


                                                                         

VAE LL: -81.0650


                                                                         

IWAE (K=5) LL: -78.3691


                                                                         

IWAE (K=20) LL: -77.3391
Improvement (K=1->5): 2.6959 nats
Improvement (K=5->20): 1.0300 nats




In [7]:
# Differences between the log-likelihood estimates computed in the previous cell.
_gain_1_to_5 = iwae_k5_ll - vae_ll
_gain_5_to_20 = iwae_k20_ll - iwae_k5_ll
_gain_1_to_20 = iwae_k20_ll - vae_ll

print(f"Improvement (K=1->5): {_gain_1_to_5:.4f} nats")
print(f"Improvement (K=5->20): {_gain_5_to_20:.4f} nats")
print(f"Improvement (K=1->20): {_gain_1_to_20:.4f} nats")

Improvement (K=1->5): 2.6959 nats
Improvement (K=5->20): 1.0300 nats
Improvement (K=1->20): 3.7259 nats


## 5. Posterior Collapse Analysis: Active Units & KL Statistics

In [5]:
print("Analyzing Active Units...")

# Run the active-unit computation for all three models first; each call
# returns a count together with the KL values that are plotted below.
n_vae, vae_kls = calc_active_units(vae, test_loader, device)
n_iwae_k5, iwae_k5_kls = calc_active_units(iwae_k5, test_loader, device)
n_iwae_k20, iwae_k20_kls = calc_active_units(iwae_k20, test_loader, device)

# Report the counts, then save one KL plot per model.
for _label, _count in (
    ("VAE", n_vae),
    ("IWAE (K=5)", n_iwae_k5),
    ("IWAE (K=20)", n_iwae_k20),
):
    print(f"{_label} Active Units: {_count}")

for _kls, _title, _stem in (
    (vae_kls, "VAE", "vae"),
    (iwae_k5_kls, "IWAE K=5", "iwae_k5"),
    (iwae_k20_kls, "IWAE K=20", "iwae_k20"),
):
    plot_kl_stats(_kls, _title, f"{output_dir}/{_stem}_kl.png")

Analyzing Active Units...


                                                                    

VAE Active Units: 15
IWAE (K=5) Active Units: 22
IWAE (K=20) Active Units: 26
Plot saved to ./notebook_results/vae_kl.png
Plot saved to ./notebook_results/iwae_k5_kl.png
Plot saved to ./notebook_results/iwae_k20_kl.png


## 6. Gradient Variance Analysis (SNR)

In [6]:
# Pull the first batch the training loader yields and move it to the device;
# the same batch is reused for all three models.
data, _ = next(iter(train_loader))
data = data.to(device)

print("Comparing Gradient SNR...")


def _grad_stats(net, label):
    """Run ``compute_gradient_variance`` for ``net`` on the shared batch with
    50 runs, print the SNR under ``label`` and return ``(variance, snr)``."""
    variance, snr = compute_gradient_variance(net, data, n_runs=50)
    print(f"{label} SNR: {snr:.4f}")
    return variance, snr


var_vae, snr_vae = _grad_stats(vae, "VAE")
var_iwae_k5, snr_iwae_k5 = _grad_stats(iwae_k5, "IWAE (K=5)")
var_iwae_k20, snr_iwae_k20 = _grad_stats(iwae_k20, "IWAE (K=20)")

Comparing Gradient SNR...
Collecting gradients over 50 runs...
VAE SNR: 0.5119
Collecting gradients over 50 runs...
IWAE (K=5) SNR: 0.2917
Collecting gradients over 50 runs...
IWAE (K=20) SNR: 0.2487
