# Vanilla VAE Interactive Notebook 🧠📊

This notebook provides an interactive interface for exploring a vanilla Variational Autoencoder (VAE) trained on medical imaging data.

## Features:
- 🔄 **Reconstruct samples** from the dataset
- 🎨 **Generate new images** from the latent space  
- 📊 **Interactive visualizations** with widgets
- 🔍 **Compare** original vs reconstructed images
- 🎯 **Single modality** focus for detailed analysis

---

## 1. Import Required Libraries

In [2]:
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display, clear_output
import warnings
warnings.filterwarnings('ignore')

# Make the project's `src` package importable from inside the notebook.
# The working directory is taken to be notebooks/, so its parent is used as the project root.
notebook_dir = Path().resolve()
project_root = notebook_dir.parent
print(f"📁 Notebook directory: {notebook_dir}")
print(f"📁 Project root: {project_root}")

_root_entry = str(project_root)
if _root_entry not in sys.path:
    # Prepend (index 0) so the project root is searched before the other entries.
    sys.path.insert(0, _root_entry)
    print(f"✅ Added project root to Python path")

# Project modules; these imports rely on the sys.path entry added above.
from src.models import BaseVAE, BetaVAE
from src.data import MedMNISTDataModule  # Corrected import path
from src.utils import compute_reconstruction_metrics

# Environment summary: library version and whether CUDA is usable.
print("✅ All libraries imported successfully!")
print(f"🔧 PyTorch version: {torch.__version__}")
print(f"🎯 Device available: {'GPU' if torch.cuda.is_available() else 'CPU'}")

📁 Notebook directory: /Users/parsa/Projects/TUDa/DGM/medvae-disentangled-multimodal/notebooks
📁 Project root: /Users/parsa/Projects/TUDa/DGM/medvae-disentangled-multimodal
✅ All libraries imported successfully!
🔧 PyTorch version: 2.7.1
🎯 Device available: CPU
✅ All libraries imported successfully!
🔧 PyTorch version: 2.7.1
🎯 Device available: CPU


## 2. Configuration and Setup

In [3]:
# Device setup: use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥️  Using device: {device}")

# Model configuration for vanilla VAE ("modern" architecture).
# These keys are passed as keyword arguments to BaseVAE / BetaVAE.
MODEL_CONFIG = {
    "input_channels": 1,      # Grayscale images (ChestMNIST)
    "latent_dim": 128,        # Latent space dimension (modern)
    "hidden_channels": 128,   # Hidden layer channels (modern)
    "resolution": 28,         # Image resolution
    "ch_mult": (1, 2, 4, 8),  # Channel multipliers (modern)
    "num_res_blocks": 2,      # Number of residual blocks (modern)
    "attn_resolutions": [16], # Attention at specific resolutions (modern)
}

# Legacy configuration (used by older checkpoints).
# Kept for reference/printing; it mirrors the smaller architecture of older runs.
LEGACY_CONFIG = {
    "input_channels": 1,      # Grayscale images
    "latent_dim": 16,         # Smaller latent space (becomes 32 with double_z=True)
    "hidden_channels": 32,    # Fewer hidden channels
    "resolution": 28,         # Image resolution
    "ch_mult": (1, 2, 4),     # Fewer downsampling stages
    "num_res_blocks": 1,      # Single ResNet block per stage
    "attn_resolutions": [],   # No attention to save parameters
    "dropout": 0.1,
    "use_linear_attn": False,
    "attn_type": "vanilla",
    "double_z": True,
}

# Paths (update these based on your trained models).
# Anchored at project_root: the notebook runs with notebooks/ as its working
# directory, so bare relative paths would resolve to notebooks/logs/checkpoints
# and notebooks/data instead of the project-level directories.
CHECKPOINTS_DIR = project_root / "logs" / "checkpoints"
DATA_DIR = project_root / "data"

# Available datasets for vanilla VAE (single modality).
# Keys are the dataset identifiers accepted by the loading cell further down.
DATASETS = {
    "chestmnist": {
        "name": "ChestMNIST",
        "channels": 1,
        "description": "Chest X-Ray Images",
    },
    "pneumoniamnist": {
        "name": "PneumoniaMNIST",
        "channels": 1,
        "description": "Pneumonia X-Ray Images",
    },
}

print("⚙️ Configuration loaded successfully!")
print(f"📁 Checkpoints directory: {CHECKPOINTS_DIR}")
print(f"📊 Available datasets: {list(DATASETS.keys())}")
print(f"🔧 Modern config: {MODEL_CONFIG['hidden_channels']} hidden channels, {MODEL_CONFIG['latent_dim']} latent dim")
print(f"🔧 Legacy config: {LEGACY_CONFIG['hidden_channels']} hidden channels, {LEGACY_CONFIG['latent_dim']} latent dim")

🖥️  Using device: cpu
⚙️ Configuration loaded successfully!
📁 Checkpoints directory: logs/checkpoints
📊 Available datasets: ['chestmnist', 'pneumoniamnist']
🔧 Modern config: 128 hidden channels, 128 latent dim
🔧 Legacy config: 32 hidden channels, 16 latent dim


## 3. Load Pre-trained Model

In [4]:
class VanillaVAE(BaseVAE):
    """
    BaseVAE preconfigured with the original, smaller architecture.

    Older checkpoints were saved with these hyper-parameters, so this class
    builds a model whose layout matches them. Keyword arguments supplied by
    the caller take precedence over the preset value of the same name.
    """

    def __init__(self, **kwargs):
        # Presets of the original architecture.
        legacy_defaults = dict(
            input_channels=1,
            latent_dim=16,            # 32 total with double_z=True (16*2)
            hidden_channels=32,       # original smaller architecture
            ch_mult=(1, 2, 4),        # fewer downsampling stages
            num_res_blocks=1,         # single ResNet block per stage
            attn_resolutions=[],      # no attention, to save parameters
            dropout=0.1,
            resolution=28,            # smaller resolution
            use_linear_attn=False,
            attn_type="vanilla",
            double_z=True,            # doubles the latent_dim output
        )
        # Later entries win, so caller-provided kwargs override the presets.
        super().__init__(**{**legacy_defaults, **kwargs})

def _extract_model_state_dict(checkpoint, prefix="model."):
    """Return the checkpoint's ``state_dict`` entries that start with *prefix*, prefix removed."""
    return {
        key[len(prefix):]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith(prefix)
    }


def load_vanilla_vae_model(checkpoint_path=None, model_type="base", use_legacy=False):
    """Load a vanilla VAE model with optional checkpoint weights.

    Args:
        checkpoint_path: Path to a checkpoint file holding a ``"state_dict"``
            entry whose model weights are stored under the ``"model."`` prefix.
            When falsy or not present on disk, the model keeps its randomly
            initialized weights.
        model_type: ``"base"`` builds ``BaseVAE(**MODEL_CONFIG)``; ``"beta"``
            builds ``BetaVAE(beta=4.0, **MODEL_CONFIG)``. Ignored when
            ``use_legacy`` is true.
        use_legacy: Build the smaller ``VanillaVAE`` architecture directly.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``model_type`` is not ``"base"`` or ``"beta"``.
        RuntimeError: If the weights fit neither the requested architecture
            nor the legacy fallback.
    """
    if use_legacy:
        # Use the original smaller architecture for legacy checkpoints
        model = VanillaVAE()
        print("🔄 Using legacy VanillaVAE architecture (hidden_channels=32, latent_dim=16)")
    elif model_type == "base":
        model = BaseVAE(**MODEL_CONFIG)
    elif model_type == "beta":
        model = BetaVAE(beta=4.0, **MODEL_CONFIG)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"📂 Loading checkpoint from: {checkpoint_path}")
        # NOTE(security): weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints that come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        # Extracted once and reused for the fallback attempt below.
        model_state_dict = _extract_model_state_dict(checkpoint)

        try:
            model.load_state_dict(model_state_dict)
            print("✅ Model weights loaded successfully!")
        except RuntimeError as e:
            # Retrying only makes sense for a shape mismatch, and only when the
            # legacy architecture has not been tried yet; otherwise surface the
            # original error with its traceback intact.
            if use_legacy or "size mismatch" not in str(e):
                raise
            print("⚠️ Architecture mismatch detected! Trying legacy architecture...")
            model = VanillaVAE()
            model.load_state_dict(model_state_dict)
            print("✅ Model weights loaded successfully with legacy architecture!")
    elif checkpoint_path:
        # A path was given but nothing exists there: report that, rather than
        # claiming no checkpoint was provided.
        print(f"⚠️ Checkpoint not found: {checkpoint_path} - using randomly initialized weights")
    else:
        print("⚠️ No checkpoint provided - using randomly initialized weights")

    model.to(device)
    model.eval()

    print(f"🧠 Model loaded: {model.__class__.__name__}")
    print(f"📊 Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    return model

# Build the model used by the cells below (edit the arguments to match your run)
model = load_vanilla_vae_model(
    use_legacy=False,       # Set to True if you know it's a legacy checkpoint
    model_type="base",      # "base" or "beta"
    checkpoint_path=f'{project_root}/logs/checkpoints/chest_base_vae_quick-epoch=04-val/loss=0.040.ckpt',  # Update this with your checkpoint path
)

📂 Loading checkpoint from: /Users/parsa/Projects/TUDa/DGM/medvae-disentangled-multimodal/logs/checkpoints/chest_base_vae_quick-epoch=04-val/loss=0.040.ckpt
⚠️ Architecture mismatch detected! Trying legacy architecture...
✅ Model weights loaded successfully with legacy architecture!
🧠 Model loaded: VanillaVAE
📊 Total parameters: 2,742,337


## 4. Load Dataset

In [5]:
def load_dataset(dataset_name="chestmnist", batch_size=32):
    """Set up a MedMNIST data module for one dataset and return it with its loaders.

    Args:
        dataset_name: A key of ``DATASETS``.
        batch_size: Batch size handed to the data module.

    Returns:
        Tuple ``(datamodule, train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If ``dataset_name`` is not a key of ``DATASETS``.
    """
    if dataset_name not in DATASETS:
        raise ValueError(f"Dataset {dataset_name} not available. Choose from: {list(DATASETS.keys())}")

    print(f"📊 Loading {DATASETS[dataset_name]['name']} dataset...")

    # Absolute data location under the project root.
    data_path = project_root / "data"
    print(f"📁 Data path: {data_path}")

    datamodule_options = {
        "dataset_names": [dataset_name],  # the data module takes a list of names
        "batch_size": batch_size,
        "num_workers": 2,
        "size": 28,                       # 28x28, matching the model configuration
        "root": str(data_path),
        "normalize": True,
        "augment_train": False,           # no augmentation, for cleaner visualization
    }
    datamodule = MedMNISTDataModule(**datamodule_options)
    datamodule.setup()

    # Loaders are created in train / val / test order.
    loaders = (
        datamodule.train_dataloader(),
        datamodule.val_dataloader(),
        datamodule.test_dataloader(),
    )

    print(f"✅ Dataset loaded successfully!")
    print(f"🔢 Train samples: {len(datamodule.train_dataset)}")
    print(f"🔢 Val samples: {len(datamodule.val_dataset)}")
    print(f"🔢 Test samples: {len(datamodule.test_dataset)}")

    return (datamodule, *loaders)

# Load dataset
dataset_name = "chestmnist"  # Change this to "pneumoniamnist" if you prefer
datamodule, train_loader, val_loader, test_loader = load_dataset(dataset_name)

# Get a sample batch for exploration.
# The container type decides how to unpack it: a length check alone would
# mistake a bare image tensor (whose len() is its batch size) for an
# (images, labels) pair and split off its first two images.
sample_batch = next(iter(val_loader))
if isinstance(sample_batch, (list, tuple)):
    sample_images = sample_batch[0]
    sample_labels = sample_batch[1] if len(sample_batch) > 1 else None
else:
    sample_images = sample_batch
    sample_labels = None

print(f"🖼️ Sample batch shape: {sample_images.shape}")
if sample_labels is not None:
    print(f"🏷️ Sample labels shape: {sample_labels.shape}")

📊 Loading ChestMNIST dataset...
📁 Data path: /Users/parsa/Projects/TUDa/DGM/medvae-disentangled-multimodal/data
Loading modality info for chestmnist...
Loading modality info for chestmnist...
  chestmnist: 1 channels
  chestmnist: 1 channels
✅ Dataset loaded successfully!
🔢 Train samples: 78468
🔢 Val samples: 11219
🔢 Test samples: 22433
✅ Dataset loaded successfully!
🔢 Train samples: 78468
🔢 Val samples: 11219
🔢 Test samples: 22433
🖼️ Sample batch shape: torch.Size([32, 1, 28, 28])
🏷️ Sample labels shape: torch.Size([32, 1])
🖼️ Sample batch shape: torch.Size([32, 1, 28, 28])
🏷️ Sample labels shape: torch.Size([32, 1])


## 5. Interactive Reconstruction Interface

In [6]:
def reconstruct_images(model, images):
    """Run ``images`` through ``model`` and return the reconstructions on the CPU.

    The model is put in eval mode and called without gradient tracking; the
    result is read from the ``"reconstruction"`` entry of the model output.
    """
    model.eval()
    with torch.no_grad():
        model_output = model(images.to(device))
        return model_output["reconstruction"].cpu()

def plot_reconstruction_comparison(original, reconstructed, num_samples=8):
    """Show originals (top row) above their reconstructions (bottom row).

    At most ``num_samples`` columns are drawn, capped at the number of
    images in ``original``.
    """
    shown = min(num_samples, original.shape[0])

    # squeeze=False keeps `axes` two-dimensional even for a single column.
    fig, axes = plt.subplots(2, shown, figsize=(2*shown, 4), squeeze=False)

    def _draw(ax, image, caption):
        ax.imshow(image.squeeze(), cmap='gray')
        ax.set_title(caption)
        ax.axis('off')

    for i, (top_ax, bottom_ax) in enumerate(zip(axes[0], axes[1])):
        _draw(top_ax, original[i], f'Original {i+1}')
        _draw(bottom_ax, reconstructed[i], f'Reconstructed {i+1}')

    plt.tight_layout()
    plt.show()

# Interactive reconstruction widget
@widgets.interact
def interactive_reconstruction(
    batch_index=widgets.IntSlider(min=0, max=9, step=1, value=0, description='Batch:'),
    num_samples=widgets.IntSlider(min=1, max=8, step=1, value=4, description='Samples:')
):
    """Reconstruct the selected validation batch, plot it, and print its metrics."""

    # Advance to the requested batch; when the loader runs out, start over
    # from its beginning and continue counting.
    exhausted = object()
    batch_iter = iter(val_loader)
    for _ in range(batch_index + 1):
        batch = next(batch_iter, exhausted)
        if batch is exhausted:
            batch_iter = iter(val_loader)
            batch = next(batch_iter)

    # A list/tuple batch carries the images in its first slot.
    if isinstance(batch, (list, tuple)):
        images = batch[0]
    else:
        images = batch

    reconstructions = reconstruct_images(model, images)
    plot_reconstruction_comparison(images, reconstructions, num_samples)

    # Nothing to measure for an empty batch.
    if len(images) == 0:
        return

    metrics = compute_reconstruction_metrics(images, reconstructions)
    print("📊 Reconstruction Metrics:")
    for key, value in metrics.items():
        print(f"  {key}: {value:.4f}")

print("🎮 Interactive reconstruction interface ready!")

interactive(children=(IntSlider(value=0, description='Batch:', max=9), IntSlider(value=4, description='Samples…

🎮 Interactive reconstruction interface ready!


## 6. Image Generation Interface

In [7]:
def generate_images(model, num_samples=8, seed=None):
    """Draw ``num_samples`` images from ``model.sample`` and return them on the CPU.

    When ``seed`` is given, torch's global random seed is set to it before
    sampling; with ``seed=None`` the current random state is used as-is.
    """
    if seed is not None:
        torch.manual_seed(seed)

    model.eval()
    with torch.no_grad():
        samples = model.sample(num_samples, device)
        return samples.cpu()

def plot_generated_images(images, title="Generated Images"):
    """Plot a grid of generated images, up to four per row.

    Args:
        images: Batch of images indexed along its first dimension; each item
            is squeezed before being drawn in grayscale.
        title: Figure title.

    Raises:
        ValueError: If ``images`` contains no images.
    """
    num_images = images.shape[0]
    if num_images == 0:
        raise ValueError("plot_generated_images requires at least one image")

    cols = min(4, num_images)
    rows = (num_images + cols - 1) // cols  # ceiling division

    # squeeze=False makes subplots always return a 2-D array of Axes, so one
    # flatten() handles every grid shape. Wrapping the single-row result in a
    # list instead would nest the Axes array and break indexing for 2-4 images.
    fig, axes = plt.subplots(rows, cols, figsize=(3*cols, 3*rows), squeeze=False)
    axes = axes.flatten()

    for i, ax in enumerate(axes[:num_images]):
        ax.imshow(images[i].squeeze(), cmap='gray')
        ax.set_title(f'Sample {i+1}')
        ax.axis('off')

    # Hide empty subplots in the last row
    for ax in axes[num_images:]:
        ax.axis('off')

    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()

# Interactive generation widget
@widgets.interact
def interactive_generation(
    num_samples=widgets.IntSlider(min=1, max=16, step=1, value=8, description='Samples:'),
    seed=widgets.IntSlider(min=0, max=1000, step=1, value=42, description='Seed:'),
    randomize=widgets.Checkbox(value=False, description='Random seed')
):
    """Generate images with the chosen (or no) seed and plot them."""

    # With "Random seed" ticked no seed is passed on, so the slider is ignored.
    if randomize:
        actual_seed = None
    else:
        actual_seed = seed

    generated_images = generate_images(model, num_samples, actual_seed)
    plot_generated_images(generated_images, f"Generated Images (seed: {'random' if randomize else seed})")

print("🎨 Interactive generation interface ready!")

interactive(children=(IntSlider(value=8, description='Samples:', max=16, min=1), IntSlider(value=42, descripti…

🎨 Interactive generation interface ready!


## 7. Latent Space Exploration

In [8]:
def interpolate_in_latent_space(model, image1, image2, steps=8):
    """Interpolate between two images in latent space.

    Both images are encoded, a latent is obtained for each through
    ``model.reparameterize``, and ``steps`` evenly spaced linear blends of the
    two latents are decoded.

    Args:
        model: Model providing ``encode``, ``reparameterize`` and ``decode``.
        image1: First image, without a batch dimension (one is added here).
        image2: Second image, same layout as ``image1``.
        steps: Number of decoded frames; must be at least 1. With a single
            step only the first latent is decoded.

    Returns:
        The decoded frames concatenated along dimension 0, on the CPU.

    Raises:
        ValueError: If ``steps`` is smaller than 1.
    """
    if steps < 1:
        raise ValueError(f"steps must be at least 1, got {steps}")

    model.eval()
    with torch.no_grad():
        # Encode both images
        image1, image2 = image1.to(device), image2.to(device)

        mu1, logvar1 = model.encode(image1.unsqueeze(0))
        mu2, logvar2 = model.encode(image2.unsqueeze(0))

        # Sample from the latent distributions
        # NOTE(review): presumably stochastic, so endpoints can differ between
        # calls unless the RNG is seeded - confirm against the model.
        z1 = model.reparameterize(mu1, logvar1)
        z2 = model.reparameterize(mu2, logvar2)

        # Blend weights from 0 to 1 inclusive. A single step has no span to
        # divide (steps - 1 would be zero), so it maps to weight 0.
        if steps > 1:
            alphas = [i / (steps - 1) for i in range(steps)]
        else:
            alphas = [0.0]

        # Decode each interpolated latent
        interpolations = [
            model.decode((1 - alpha) * z1 + alpha * z2).cpu()
            for alpha in alphas
        ]

        return torch.cat(interpolations, dim=0)

def plot_interpolation(original1, original2, interpolations):
    """Draw one row: first original, every interpolation frame, second original."""
    panel_count = len(interpolations) + 2
    fig, axes = plt.subplots(1, panel_count, figsize=(2*panel_count, 3))

    def _show(ax, image, caption):
        ax.imshow(image.squeeze(), cmap='gray')
        ax.set_title(caption)
        ax.axis('off')

    # Left edge: first original
    _show(axes[0], original1, 'Original 1')

    # Middle: interpolation frames, shifted right by one panel
    for i, img in enumerate(interpolations):
        _show(axes[i + 1], img, f'Step {i + 1}')

    # Right edge: second original
    _show(axes[-1], original2, 'Original 2')

    plt.suptitle('Latent Space Interpolation', fontsize=16)
    plt.tight_layout()
    plt.show()

# Interactive latent interpolation
@widgets.interact
def interactive_interpolation(
    image1_idx=widgets.IntSlider(min=0, max=31, step=1, value=0, description='Image 1:'),
    image2_idx=widgets.IntSlider(min=0, max=31, step=1, value=15, description='Image 2:'),
    steps=widgets.IntSlider(min=3, max=10, step=1, value=6, description='Steps:')
):
    """Interpolate between two images of the sample batch and plot the sequence."""

    # Candidate images: the first 32 of the sample batch.
    images = sample_images[:32]
    available = len(images)

    # The sliders go up to 31, which can exceed a smaller batch.
    if not (image1_idx < available and image2_idx < available):
        print("❌ Image index out of range!")
        return

    image1, image2 = images[image1_idx], images[image2_idx]

    interpolations = interpolate_in_latent_space(model, image1, image2, steps)
    plot_interpolation(image1, image2, interpolations)

print("🔀 Interactive latent space interpolation ready!")

interactive(children=(IntSlider(value=0, description='Image 1:', max=31), IntSlider(value=15, description='Ima…

🔀 Interactive latent space interpolation ready!


## 8. Summary

### 🎯 What you can do with this notebook:

1. **🔄 Reconstruct images**: Use the interactive widget to see how well your VAE reconstructs real medical images
2. **🎨 Generate new images**: Sample from the latent space to create entirely new synthetic medical images  
3. **🔀 Explore latent space**: Interpolate between images to understand the learned representations
4. **📊 Analyze performance**: View reconstruction metrics like MSE, PSNR, and SSIM

---