# Vanilla VAE Interactive Notebook 🧠📊

This notebook provides an interactive interface for exploring a vanilla Variational Autoencoder (VAE) trained on medical imaging data.

## Features:
- 🔄 **Reconstruct samples** from the dataset
- 🎨 **Generate new images** from the latent space  
- 📊 **Interactive visualizations** with widgets
- 🔍 **Compare** original vs reconstructed images
- 🎯 **Single modality** focus for detailed analysis

---

## 1. Import Required Libraries

In [4]:
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display, clear_output
import warnings
warnings.filterwarnings('ignore')

# Add the project root to the Python path.
# Assumes the kernel's working directory is notebooks/, one level below the project root.
notebook_dir = Path().resolve()
project_root = notebook_dir.parent  # Go up one level from notebooks/ to project root
print(f"📁 Notebook directory: {notebook_dir}")
print(f"📁 Project root: {project_root}")

if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
    print("✅ Added project root to Python path")

# Import project modules
from src.models import BaseVAE, BetaVAE
from src.data import MedMNISTDataModule
from src.utils import compute_reconstruction_metrics

print("✅ All libraries imported successfully!")
print(f"🔧 PyTorch version: {torch.__version__}")
print(f"🎯 Device available: {'GPU' if torch.cuda.is_available() else 'CPU'}")

📁 Notebook directory: /Users/parsa/Projects/TUDa/DGM/medvae-disentangled-multimodal/notebooks
📁 Project root: /Users/parsa/Projects/TUDa/DGM/medvae-disentangled-multimodal
✅ All libraries imported successfully!
🔧 PyTorch version: 2.7.1
🎯 Device available: CPU
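
The import cell derives the project root from the current working directory, which only works when the kernel is started inside `notebooks/`. A minimal sketch of a more tolerant lookup is shown here: it walks upward until it finds a directory containing `src` (the package imported above). The helper name `find_project_root` and the choice of `src` as the marker are assumptions for illustration, not part of the project.

```python
from pathlib import Path


def find_project_root(start: Path, marker: str = "src") -> Path:
    """Return the nearest directory at or above `start` that contains `marker`."""
    start = start.resolve()
    # Check the starting directory first, then each parent in turn.
    for candidate in (start, *start.parents):
        if (candidate / marker).is_dir():
            return candidate
    raise FileNotFoundError(f"No directory containing '{marker}' found above {start}")


# Possible use in the notebook (hypothetical):
# project_root = find_project_root(Path.cwd())
```

With this, the notebook could be launched from the project root or from any subdirectory without editing the path logic.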


## 2. Configuration and Setup

In [5]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥️  Using device: {device}")

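# Model configuration for vanilla VAE.
# These values must match the hyperparameters the checkpoint in Section 3 was trained with;
# otherwise load_state_dict fails with missing keys and size mismatches.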
MODEL_CONFIG = {
    "input_channels": 1,      # Grayscale images (ChestMNIST)
    "latent_dim": 128,        # Latent space dimension
    "hidden_channels": 128,   # Hidden layer channels
    "resolution": 28,         # Image resolution
    "ch_mult": (1, 2, 4, 8),  # Channel multipliers
    "num_res_blocks": 2,      # Number of residual blocks
    "attn_resolutions": [16], # Attention at specific resolutions
}

# Paths (update these based on your trained models)
CHECKPOINTS_DIR = Path("logs/checkpoints")
DATA_DIR = Path("data")

# Available datasets for vanilla VAE (single modality)
DATASETS = {
    "chestmnist": {
        "name": "ChestMNIST", 
        "channels": 1, 
        "description": "Chest X-Ray Images"
    },
    "pneumoniamnist": {
        "name": "PneumoniaMNIST", 
        "channels": 1, 
        "description": "Pneumonia X-Ray Images"
    }
}

print("⚙️ Configuration loaded successfully!")
print(f"📁 Checkpoints directory: {CHECKPOINTS_DIR}")
print(f"📊 Available datasets: {list(DATASETS.keys())}")

🖥️  Using device: cpu
⚙️ Configuration loaded successfully!
📁 Checkpoints directory: logs/checkpoints
📊 Available datasets: ['chestmnist', 'pneumoniamnist']
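
The values in `MODEL_CONFIG` have to agree with the checkpoint loaded in the next section. When the training configuration is not at hand, the parameter shapes stored in the checkpoint can hint at it. The sketch below is a rough heuristic, not the project's API: the function name `infer_config_from_shapes` is hypothetical, and it assumes the key layout this model reports (`encoder.conv_in`, `encoder.down.<level>.block.<i>`, `decoder.conv_in`) and that `hidden_channels`, `latent_dim`, `ch_mult` and `num_res_blocks` map onto those shapes as they appear to for `BaseVAE`. The example shapes are the checkpoint-side shapes from the `load_state_dict` error in Section 3.

```python
import re


def infer_config_from_shapes(shapes):
    """Guess architecture hyperparameters from a {parameter_name: shape} mapping."""
    base = shapes["encoder.conv_in.weight"][0]    # channels after the first conv
    latent = shapes["decoder.conv_in.weight"][1]  # channels the decoder receives

    level_out = {}  # level index -> output channels of that level
    blocks = {}     # level index -> residual block indices seen at that level
    pattern = re.compile(r"encoder\.down\.(\d+)\.block\.(\d+)\.conv2\.weight")
    for name, shape in shapes.items():
        match = pattern.fullmatch(name)
        if match:
            level, block = int(match.group(1)), int(match.group(2))
            level_out[level] = shape[0]
            blocks.setdefault(level, set()).add(block)

    levels = sorted(level_out)
    return {
        "hidden_channels": base,
        "latent_dim": latent,
        "ch_mult": tuple(level_out[lvl] // base for lvl in levels),
        "num_res_blocks": max(len(b) for b in blocks.values()),
    }


# Checkpoint-side shapes copied from the error message in Section 3.
example_shapes = {
    "encoder.conv_in.weight": (32, 1, 3, 3),
    "encoder.down.0.block.0.conv2.weight": (32, 32, 3, 3),
    "encoder.down.1.block.0.conv2.weight": (64, 64, 3, 3),
    "encoder.down.2.block.0.conv2.weight": (128, 128, 3, 3),
    "decoder.conv_in.weight": (128, 16, 3, 3),
}
inferred = infer_config_from_shapes(example_shapes)
print(inferred)
# {'hidden_channels': 32, 'latent_dim': 16, 'ch_mult': (1, 2, 4), 'num_res_blocks': 1}
```

For a real checkpoint, the mapping could be built with `{k[len("model."):]: tuple(v.shape) for k, v in checkpoint["state_dict"].items() if k.startswith("model.")}`. If the inferred values differ from `MODEL_CONFIG`, the configuration is the likelier culprit than the checkpoint file.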


## 3. Load Pre-trained Model

In [9]:
def load_vanilla_vae_model(checkpoint_path=None, model_type="base"):
    """Load a vanilla VAE model with optional checkpoint weights."""
    
    if model_type == "base":
        model = BaseVAE(**MODEL_CONFIG)
    elif model_type == "beta":
        model = BetaVAE(beta=4.0, **MODEL_CONFIG)
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    
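    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"📂 Loading checkpoint from: {checkpoint_path}")
        # weights_only=False unpickles arbitrary objects; only use it with checkpoints you trust
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        
        # Extract model weights from Lightning checkpoint (strip the "model." prefix)
        prefix = "model."
        model_state_dict = {
            key[len(prefix):]: value
            for key, value in checkpoint["state_dict"].items()
            if key.startswith(prefix)
        }
        
        # Raises RuntimeError if MODEL_CONFIG does not match the checkpoint's architecture
        model.load_state_dict(model_state_dict)
        print("✅ Model weights loaded successfully!")
    elif checkpoint_path:
        print(f"⚠️ Checkpoint not found: {checkpoint_path} - using randomly initialized weights")
    else:
        print("⚠️ No checkpoint provided - using randomly initialized weights")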
    
    model.to(device)
    model.eval()
    
    print(f"🧠 Model loaded: {model.__class__.__name__}")
    print(f"📊 Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    return model

# Initialize model (you can update the checkpoint path)
model = load_vanilla_vae_model(
    checkpoint_path=f'{project_root}/logs/checkpoints/chest_base_vae_quick-epoch=04-val/loss=0.040.ckpt',  # Update this with your checkpoint path
    model_type="base"      # "base" or "beta"
)

📂 Loading checkpoint from: /Users/parsa/Projects/TUDa/DGM/medvae-disentangled-multimodal/logs/checkpoints/chest_base_vae_quick-epoch=04-val/loss=0.040.ckpt


RuntimeError: Error(s) in loading state_dict for BaseVAE:
	Missing key(s) in state_dict: "encoder.down.0.block.1.norm1.weight", "encoder.down.0.block.1.norm1.bias", "encoder.down.0.block.1.conv1.weight", "encoder.down.0.block.1.conv1.bias", "encoder.down.0.block.1.norm2.weight", "encoder.down.0.block.1.norm2.bias", "encoder.down.0.block.1.conv2.weight", "encoder.down.0.block.1.conv2.bias", "encoder.down.1.block.1.norm1.weight", "encoder.down.1.block.1.norm1.bias", "encoder.down.1.block.1.conv1.weight", "encoder.down.1.block.1.conv1.bias", "encoder.down.1.block.1.norm2.weight", "encoder.down.1.block.1.norm2.bias", "encoder.down.1.block.1.conv2.weight", "encoder.down.1.block.1.conv2.bias", "encoder.down.2.block.1.norm1.weight", "encoder.down.2.block.1.norm1.bias", "encoder.down.2.block.1.conv1.weight", "encoder.down.2.block.1.conv1.bias", "encoder.down.2.block.1.norm2.weight", "encoder.down.2.block.1.norm2.bias", "encoder.down.2.block.1.conv2.weight", "encoder.down.2.block.1.conv2.bias", "encoder.down.2.downsample.conv.weight", "encoder.down.2.downsample.conv.bias", "encoder.down.3.block.0.norm1.weight", "encoder.down.3.block.0.norm1.bias", "encoder.down.3.block.0.conv1.weight", "encoder.down.3.block.0.conv1.bias", "encoder.down.3.block.0.norm2.weight", "encoder.down.3.block.0.norm2.bias", "encoder.down.3.block.0.conv2.weight", "encoder.down.3.block.0.conv2.bias", "encoder.down.3.block.0.nin_shortcut.weight", "encoder.down.3.block.0.nin_shortcut.bias", "encoder.down.3.block.1.norm1.weight", "encoder.down.3.block.1.norm1.bias", "encoder.down.3.block.1.conv1.weight", "encoder.down.3.block.1.conv1.bias", "encoder.down.3.block.1.norm2.weight", "encoder.down.3.block.1.norm2.bias", "encoder.down.3.block.1.conv2.weight", "encoder.down.3.block.1.conv2.bias", "decoder.up.0.block.2.norm1.weight", "decoder.up.0.block.2.norm1.bias", "decoder.up.0.block.2.conv1.weight", "decoder.up.0.block.2.conv1.bias", "decoder.up.0.block.2.norm2.weight", "decoder.up.0.block.2.norm2.bias", "decoder.up.0.block.2.conv2.weight", 
"decoder.up.0.block.2.conv2.bias", "decoder.up.1.block.2.norm1.weight", "decoder.up.1.block.2.norm1.bias", "decoder.up.1.block.2.conv1.weight", "decoder.up.1.block.2.conv1.bias", "decoder.up.1.block.2.norm2.weight", "decoder.up.1.block.2.norm2.bias", "decoder.up.1.block.2.conv2.weight", "decoder.up.1.block.2.conv2.bias", "decoder.up.2.block.0.nin_shortcut.weight", "decoder.up.2.block.0.nin_shortcut.bias", "decoder.up.2.block.2.norm1.weight", "decoder.up.2.block.2.norm1.bias", "decoder.up.2.block.2.conv1.weight", "decoder.up.2.block.2.conv1.bias", "decoder.up.2.block.2.norm2.weight", "decoder.up.2.block.2.norm2.bias", "decoder.up.2.block.2.conv2.weight", "decoder.up.2.block.2.conv2.bias", "decoder.up.3.block.0.norm1.weight", "decoder.up.3.block.0.norm1.bias", "decoder.up.3.block.0.conv1.weight", "decoder.up.3.block.0.conv1.bias", "decoder.up.3.block.0.norm2.weight", "decoder.up.3.block.0.norm2.bias", "decoder.up.3.block.0.conv2.weight", "decoder.up.3.block.0.conv2.bias", "decoder.up.3.block.1.norm1.weight", "decoder.up.3.block.1.norm1.bias", "decoder.up.3.block.1.conv1.weight", "decoder.up.3.block.1.conv1.bias", "decoder.up.3.block.1.norm2.weight", "decoder.up.3.block.1.norm2.bias", "decoder.up.3.block.1.conv2.weight", "decoder.up.3.block.1.conv2.bias", "decoder.up.3.block.2.norm1.weight", "decoder.up.3.block.2.norm1.bias", "decoder.up.3.block.2.conv1.weight", "decoder.up.3.block.2.conv1.bias", "decoder.up.3.block.2.norm2.weight", "decoder.up.3.block.2.norm2.bias", "decoder.up.3.block.2.conv2.weight", "decoder.up.3.block.2.conv2.bias", "decoder.up.3.upsample.conv.weight", "decoder.up.3.upsample.conv.bias". 
	size mismatch for encoder.conv_in.weight: copying a param with shape torch.Size([32, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 1, 3, 3]).
	size mismatch for encoder.conv_in.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder.down.0.block.0.norm1.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder.down.0.block.0.norm1.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder.down.0.block.0.conv1.weight: copying a param with shape torch.Size([32, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for encoder.down.0.block.0.conv1.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder.down.0.block.0.norm2.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder.down.0.block.0.norm2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder.down.0.block.0.conv2.weight: copying a param with shape torch.Size([32, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for encoder.down.0.block.0.conv2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder.down.0.downsample.conv.weight: copying a param with shape torch.Size([32, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for encoder.down.0.downsample.conv.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder.down.1.block.0.norm1.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder.down.1.block.0.norm1.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for encoder.down.1.block.0.conv1.weight: copying a param with shape torch.Size([64, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 128, 3, 3]).
	size mismatch for encoder.down.1.block.0.conv1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.block.0.norm2.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.block.0.norm2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.block.0.conv2.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for encoder.down.1.block.0.conv2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.block.0.nin_shortcut.weight: copying a param with shape torch.Size([64, 32, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 128, 1, 1]).
	size mismatch for encoder.down.1.block.0.nin_shortcut.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.1.downsample.conv.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for encoder.down.1.downsample.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.2.block.0.norm1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.2.block.0.norm1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for encoder.down.2.block.0.conv1.weight: copying a param with shape torch.Size([128, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 256, 3, 3]).
	size mismatch for encoder.down.2.block.0.conv1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.down.2.block.0.norm2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.down.2.block.0.norm2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.down.2.block.0.conv2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for encoder.down.2.block.0.conv2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.down.2.block.0.nin_shortcut.weight: copying a param with shape torch.Size([128, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 256, 1, 1]).
	size mismatch for encoder.down.2.block.0.nin_shortcut.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for encoder.mid.block_1.norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.block_1.norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.block_1.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
	size mismatch for encoder.mid.block_1.conv1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.block_1.norm2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.block_1.norm2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.block_1.conv2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
	size mismatch for encoder.mid.block_1.conv2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.attn_1.norm.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.attn_1.norm.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.attn_1.q.weight: copying a param with shape torch.Size([128, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 1, 1]).
	size mismatch for encoder.mid.attn_1.q.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.attn_1.k.weight: copying a param with shape torch.Size([128, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 1, 1]).
	size mismatch for encoder.mid.attn_1.k.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.attn_1.v.weight: copying a param with shape torch.Size([128, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 1, 1]).
	size mismatch for encoder.mid.attn_1.v.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.attn_1.proj_out.weight: copying a param with shape torch.Size([128, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 1, 1]).
	size mismatch for encoder.mid.attn_1.proj_out.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.block_2.norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.block_2.norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.block_2.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
	size mismatch for encoder.mid.block_2.conv1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.block_2.norm2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.block_2.norm2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.mid.block_2.conv2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
	size mismatch for encoder.mid.block_2.conv2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.norm_out.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.norm_out.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for encoder.conv_out.weight: copying a param with shape torch.Size([32, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 1024, 3, 3]).
	size mismatch for encoder.conv_out.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.conv_in.weight: copying a param with shape torch.Size([128, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 128, 3, 3]).
	size mismatch for decoder.conv_in.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.block_1.norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.block_1.norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.block_1.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
	size mismatch for decoder.mid.block_1.conv1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.block_1.norm2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.block_1.norm2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.block_1.conv2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
	size mismatch for decoder.mid.block_1.conv2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.attn_1.norm.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.attn_1.norm.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.attn_1.q.weight: copying a param with shape torch.Size([128, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 1, 1]).
	size mismatch for decoder.mid.attn_1.q.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.attn_1.k.weight: copying a param with shape torch.Size([128, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 1, 1]).
	size mismatch for decoder.mid.attn_1.k.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.attn_1.v.weight: copying a param with shape torch.Size([128, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 1, 1]).
	size mismatch for decoder.mid.attn_1.v.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.attn_1.proj_out.weight: copying a param with shape torch.Size([128, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 1, 1]).
	size mismatch for decoder.mid.attn_1.proj_out.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.block_2.norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.block_2.norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.block_2.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
	size mismatch for decoder.mid.block_2.conv1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.block_2.norm2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.block_2.norm2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.mid.block_2.conv2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([1024, 1024, 3, 3]).
	size mismatch for decoder.mid.block_2.conv2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.up.0.block.0.norm1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.0.block.0.norm1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.0.block.0.conv1.weight: copying a param with shape torch.Size([32, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 256, 3, 3]).
	size mismatch for decoder.up.0.block.0.conv1.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up.0.block.0.norm2.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up.0.block.0.norm2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up.0.block.0.conv2.weight: copying a param with shape torch.Size([32, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for decoder.up.0.block.0.conv2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up.0.block.0.nin_shortcut.weight: copying a param with shape torch.Size([32, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 256, 1, 1]).
	size mismatch for decoder.up.0.block.0.nin_shortcut.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up.0.block.1.norm1.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up.0.block.1.norm1.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up.0.block.1.conv1.weight: copying a param with shape torch.Size([32, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for decoder.up.0.block.1.conv1.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up.0.block.1.norm2.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up.0.block.1.norm2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up.0.block.1.conv2.weight: copying a param with shape torch.Size([32, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for decoder.up.0.block.1.conv2.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.up.1.block.0.norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.1.block.0.norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.1.block.0.conv1.weight: copying a param with shape torch.Size([64, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 512, 3, 3]).
	size mismatch for decoder.up.1.block.0.conv1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.0.norm2.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.0.norm2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.0.conv2.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for decoder.up.1.block.0.conv2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.0.nin_shortcut.weight: copying a param with shape torch.Size([64, 128, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 512, 1, 1]).
	size mismatch for decoder.up.1.block.0.nin_shortcut.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.1.norm1.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.1.norm1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.1.conv1.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for decoder.up.1.block.1.conv1.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.1.norm2.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.1.norm2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.block.1.conv2.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for decoder.up.1.block.1.conv2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.1.upsample.conv.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for decoder.up.1.upsample.conv.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for decoder.up.2.block.0.norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.up.2.block.0.norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1024]).
	size mismatch for decoder.up.2.block.0.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 1024, 3, 3]).
	size mismatch for decoder.up.2.block.0.conv1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.0.norm2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.0.norm2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.0.conv2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for decoder.up.2.block.0.conv2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.1.norm1.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.1.norm1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.1.conv1.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for decoder.up.2.block.1.conv1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.1.norm2.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.1.norm2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.block.1.conv2.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for decoder.up.2.block.1.conv2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.up.2.upsample.conv.weight: copying a param with shape torch.Size([128, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for decoder.up.2.upsample.conv.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for decoder.norm_out.weight: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.norm_out.bias: copying a param with shape torch.Size([32]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for decoder.conv_out.weight: copying a param with shape torch.Size([1, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([1, 128, 3, 3]).
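
The `size mismatch` errors above show that the current model is consistently wider than the checkpoint (for example 128 versus 32 channels in `decoder.norm_out`). The following is a minimal sketch of how such mismatches could be listed before loading the weights. `find_shape_mismatches` is a hypothetical helper, not part of the project. It works on plain name-to-shape mappings, which in PyTorch could be built with `{k: tuple(v.shape) for k, v in state_dict.items()}`.

```python
def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Return (name, checkpoint_shape, model_shape) for shared keys whose shapes differ.

    Hypothetical helper: both arguments map parameter names to shape tuples.
    """
    mismatches = []
    for name, ckpt_shape in checkpoint_shapes.items():
        model_shape = model_shapes.get(name)
        # Keys that are missing from the model are a separate problem; skip them here
        if model_shape is not None and tuple(ckpt_shape) != tuple(model_shape):
            mismatches.append((name, tuple(ckpt_shape), tuple(model_shape)))
    return mismatches


# Shapes copied from two of the error lines above
checkpoint_shapes = {
    "decoder.norm_out.weight": (32,),
    "decoder.conv_out.weight": (1, 32, 3, 3),
}
model_shapes = {
    "decoder.norm_out.weight": (128,),
    "decoder.conv_out.weight": (1, 128, 3, 3),
}

for name, ckpt, current in find_shape_mismatches(checkpoint_shapes, model_shapes):
    print(f"{name}: checkpoint {ckpt} vs model {current}")
```

A listing like this makes it easier to see which width setting in the model configuration needs to change to match the checkpoint.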

## 4. Load Dataset

In [None]:
def load_dataset(dataset_name="chestmnist", batch_size=32):
    """Load a specific MedMNIST dataset."""
    
    if dataset_name not in DATASETS:
        raise ValueError(f"Dataset {dataset_name} not available. Choose from: {list(DATASETS.keys())}")
    
    print(f"📊 Loading {DATASETS[dataset_name]['name']} dataset...")
    
    # Create data module
    datamodule = MedMNISTDataModule(
        dataset_name=dataset_name,
        data_dir=str(DATA_DIR),
        batch_size=batch_size,
        num_workers=2
    )
    
    # Setup data
    datamodule.setup()
    
    # Get dataloaders
    train_loader = datamodule.train_dataloader()
    val_loader = datamodule.val_dataloader()
    test_loader = datamodule.test_dataloader()
    
    print("✅ Dataset loaded successfully!")
    print(f"🔢 Train samples: {len(datamodule.train_dataset)}")
    print(f"🔢 Val samples: {len(datamodule.val_dataset)}")
    print(f"🔢 Test samples: {len(datamodule.test_dataset)}")
    
    return datamodule, train_loader, val_loader, test_loader

# Load dataset
dataset_name = "chestmnist"  # Change this to "pneumoniamnist" if you prefer
datamodule, train_loader, val_loader, test_loader = load_dataset(dataset_name)

# Get a sample batch for exploration
sample_batch = next(iter(val_loader))
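if isinstance(sample_batch, (list, tuple)) and len(sample_batch) >= 2:
    # Batch is (images, labels, ...)
    sample_images, sample_labels = sample_batch[0], sample_batch[1]
    print(f"🖼️ Sample batch shape: {sample_images.shape}")
    print(f"🏷️ Sample labels shape: {sample_labels.shape}")
else:
    # Batch is a bare image tensor or a 1-tuple; len() on a tensor would return the batch size
    sample_images = sample_batch[0] if isinstance(sample_batch, (list, tuple)) else sample_batch
    sample_labels = None
    print(f"🖼️ Sample batch shape: {sample_images.shape}")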

## 5. Interactive Reconstruction Interface

In [None]:
def reconstruct_images(model, images):
    """Reconstruct images using the VAE model."""
    model.eval()
    with torch.no_grad():
        images = images.to(device)
        outputs = model(images)
        reconstructions = outputs["reconstruction"]
        return reconstructions.cpu()

def plot_reconstruction_comparison(original, reconstructed, num_samples=8):
    """Plot original vs reconstructed images side by side."""
    num_samples = min(num_samples, original.shape[0])
    
    fig, axes = plt.subplots(2, num_samples, figsize=(2*num_samples, 4))
    if num_samples == 1:
        axes = axes.reshape(2, 1)
    
    for i in range(num_samples):
        # Original image
        img_orig = original[i].squeeze()
        axes[0, i].imshow(img_orig, cmap='gray')
        axes[0, i].set_title(f'Original {i+1}')
        axes[0, i].axis('off')
        
        # Reconstructed image
        img_recon = reconstructed[i].squeeze()
        axes[1, i].imshow(img_recon, cmap='gray')
        axes[1, i].set_title(f'Reconstructed {i+1}')
        axes[1, i].axis('off')
    
    plt.tight_layout()
    plt.show()

# Interactive reconstruction widget
@widgets.interact
def interactive_reconstruction(
    batch_index=widgets.IntSlider(min=0, max=9, step=1, value=0, description='Batch:'),
    num_samples=widgets.IntSlider(min=1, max=8, step=1, value=4, description='Samples:')
):
    """Interactive reconstruction interface."""
    
    # Get a batch of images
    val_iter = iter(val_loader)
    for _ in range(batch_index + 1):
        try:
            batch = next(val_iter)
        except StopIteration:
            val_iter = iter(val_loader)
            batch = next(val_iter)
    
    images = batch[0] if isinstance(batch, (list, tuple)) else batch
    
    # Reconstruct images
    reconstructions = reconstruct_images(model, images)
    
    # Plot comparison
    plot_reconstruction_comparison(images, reconstructions, num_samples)
    
    # Compute and display metrics
    if len(images) > 0:
        metrics = compute_reconstruction_metrics(images, reconstructions)
        print("📊 Reconstruction Metrics:")
        for key, value in metrics.items():
            print(f"  {key}: {value:.4f}")

print("🎮 Interactive reconstruction interface ready!")
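
The widget above walks the validation loader from the start and wraps around when `batch_index` exceeds the number of batches. As a sketch, the same lookup can be written with `itertools.islice`. `nth_batch` is a hypothetical helper, and it assumes the loader supports `len()` and yields batches in a fixed order (no shuffling).

```python
from itertools import islice


def nth_batch(loader, index):
    """Return the batch at position `index`, wrapping around past the last batch."""
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("loader is empty")
    # Wrap the index first so the loader is iterated at most once
    index %= num_batches
    return next(islice(iter(loader), index, None))


# A plain list stands in for a DataLoader here
batches = ["batch-0", "batch-1", "batch-2"]
print(nth_batch(batches, 1))
print(nth_batch(batches, 4))  # wraps around to position 1
```

Wrapping the index up front avoids restarting the iterator inside a `try`/`except` block.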

## 6. Image Generation Interface

In [None]:
def generate_images(model, num_samples=8, seed=None):
    """Generate new images from the latent space."""
    if seed is not None:
        torch.manual_seed(seed)
    
    model.eval()
    with torch.no_grad():
        generated = model.sample(num_samples, device)
        return generated.cpu()

def plot_generated_images(images, title="Generated Images"):
    """Plot a grid of generated images."""
    num_images = images.shape[0]
    cols = min(4, num_images)
    rows = (num_images + cols - 1) // cols
    
    # squeeze=False always returns a 2D array of axes, so flattening works
    # for any grid shape (including a single row or a single image)
    fig, axes = plt.subplots(rows, cols, figsize=(3*cols, 3*rows), squeeze=False)
    axes = axes.flatten()
    
    for i in range(num_images):
        img = images[i].squeeze()
        axes[i].imshow(img, cmap='gray')
        axes[i].set_title(f'Sample {i+1}')
        axes[i].axis('off')
    
    # Hide empty subplots
    for i in range(num_images, len(axes)):
        axes[i].axis('off')
    
    plt.suptitle(title, fontsize=16)
    plt.tight_layout()
    plt.show()

# Interactive generation widget
@widgets.interact
def interactive_generation(
    num_samples=widgets.IntSlider(min=1, max=16, step=1, value=8, description='Samples:'),
    seed=widgets.IntSlider(min=0, max=1000, step=1, value=42, description='Seed:'),
    randomize=widgets.Checkbox(value=False, description='Random seed')
):
    """Interactive generation interface."""
    
    # Use random seed if requested
    actual_seed = None if randomize else seed
    
    # Generate images
    generated_images = generate_images(model, num_samples, actual_seed)
    
    # Plot generated images
    plot_generated_images(generated_images, f"Generated Images (seed: {'random' if randomize else seed})")

print("🎨 Interactive generation interface ready!")
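
`plot_generated_images` lays the samples out in at most four columns and uses ceiling division to get the number of rows. A small worked sketch of that arithmetic follows; `grid_shape` is a hypothetical helper introduced only for illustration.

```python
def grid_shape(num_images, max_cols=4):
    """Return (rows, cols) for a grid that holds num_images with at most max_cols columns."""
    if num_images < 1:
        raise ValueError("num_images must be at least 1")
    cols = min(max_cols, num_images)
    # Ceiling division: a partially filled last row still needs a full row of axes
    rows = (num_images + cols - 1) // cols
    return rows, cols


for n in (1, 3, 5, 8, 16):
    rows, cols = grid_shape(n)
    print(f"{n} images -> {rows} x {cols} grid, {rows * cols - n} empty cells")
```

The empty cells are the subplots that the plotting function hides after drawing the samples.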

## 7. Latent Space Exploration

In [None]:
def interpolate_in_latent_space(model, image1, image2, steps=8):
    """Interpolate between two images in latent space."""
    model.eval()
    with torch.no_grad():
        # Encode both images
        image1, image2 = image1.to(device), image2.to(device)
        
        mu1, _ = model.encode(image1.unsqueeze(0))
        mu2, _ = model.encode(image2.unsqueeze(0))
        
        # Use the posterior means as latent codes so the interpolation is
        # deterministic (sampling via reparameterize would add noise on every call)
        z1, z2 = mu1, mu2
        
        # Interpolate
        interpolations = []
        for i in range(steps):
            alpha = i / (steps - 1)
            z_interp = (1 - alpha) * z1 + alpha * z2
            
            # Decode interpolated latent
            recon = model.decode(z_interp)
            interpolations.append(recon.cpu())
        
        return torch.cat(interpolations, dim=0)

def plot_interpolation(original1, original2, interpolations):
    """Plot interpolation sequence."""
    num_interp = len(interpolations)
    
    fig, axes = plt.subplots(1, num_interp + 2, figsize=(2*(num_interp + 2), 3))
    
    # Plot first original
    axes[0].imshow(original1.squeeze(), cmap='gray')
    axes[0].set_title('Original 1')
    axes[0].axis('off')
    
    # Plot interpolations
    for i, img in enumerate(interpolations):
        axes[i + 1].imshow(img.squeeze(), cmap='gray')
        axes[i + 1].set_title(f'Step {i + 1}')
        axes[i + 1].axis('off')
    
    # Plot second original
    axes[-1].imshow(original2.squeeze(), cmap='gray')
    axes[-1].set_title('Original 2')
    axes[-1].axis('off')
    
    plt.suptitle('Latent Space Interpolation', fontsize=16)
    plt.tight_layout()
    plt.show()

# Interactive latent interpolation
@widgets.interact
def interactive_interpolation(
    image1_idx=widgets.IntSlider(min=0, max=31, step=1, value=0, description='Image 1:'),
    image2_idx=widgets.IntSlider(min=0, max=31, step=1, value=15, description='Image 2:'),
    steps=widgets.IntSlider(min=3, max=10, step=1, value=6, description='Steps:')
):
    """Interactive latent space interpolation."""
    
    # Get sample images
    images = sample_images[:32]  # Use first 32 images
    
    if image1_idx >= len(images) or image2_idx >= len(images):
        print("❌ Image index out of range!")
        return
    
    image1 = images[image1_idx]
    image2 = images[image2_idx]
    
    # Perform interpolation
    interpolations = interpolate_in_latent_space(model, image1, image2, steps)
    
    # Plot results
    plot_interpolation(image1, image2, interpolations)

print("🔀 Interactive latent space interpolation ready!")
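
The interpolation above blends two latent codes with `(1 - alpha) * z1 + alpha * z2`, where `alpha` runs from 0 to 1 in equal steps. The sketch below shows that schedule on plain Python lists standing in for latent vectors. `lerp_path` is a hypothetical helper, and the two-dimensional vectors are made-up values chosen for easy arithmetic.

```python
def lerp_path(z1, z2, steps):
    """Return `steps` evenly spaced points on the line from z1 to z2, endpoints included."""
    if steps < 2:
        raise ValueError("steps must be at least 2 to include both endpoints")
    path = []
    for i in range(steps):
        alpha = i / (steps - 1)  # 0.0 at z1, 1.0 at z2
        path.append([(1 - alpha) * a + alpha * b for a, b in zip(z1, z2)])
    return path


z1 = [0.0, 2.0]
z2 = [4.0, -2.0]
for point in lerp_path(z1, z2, steps=5):
    print(point)
```

Because the first and last points equal `z1` and `z2`, the first and last decoded frames correspond to the two encoded endpoints.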

## 8. Summary & Tips

### 🎯 What you can do with this notebook:

1. **🔄 Reconstruct images**: Use the interactive widget to see how well your VAE reconstructs real medical images
2. **🎨 Generate new images**: Sample from the latent space to create entirely new synthetic medical images  
3. **🔀 Explore latent space**: Interpolate between images to understand the learned representations
4. **📊 Analyze performance**: View reconstruction metrics like MSE, PSNR, and SSIM

### 💡 Tips for better results:

- **Load trained weights**: Update the checkpoint path in the model loading section to use your trained model
- **Try different seeds**: Use the randomize option to explore diverse generations
- **Experiment with interpolation**: Choose different image pairs to see interesting transitions
- **Monitor metrics**: Lower MSE and higher PSNR/SSIM indicate better reconstructions
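
To make the relationship between MSE and PSNR concrete, here is a minimal sketch on flat lists of pixel values. It assumes pixels are scaled to [0, 1] and is not necessarily how `compute_reconstruction_metrics` computes its values. The `mse` and `psnr` functions and the sample pixel values are illustrative only.

```python
import math


def mse(original, reconstructed):
    """Mean squared error between two equally sized flat sequences of pixel values."""
    if len(original) != len(reconstructed):
        raise ValueError("inputs must have the same number of pixels")
    return sum((o - r) ** 2 for o, r in zip(original, reconstructed)) / len(original)


def psnr(original, reconstructed, max_value=1.0):
    """Peak signal-to-noise ratio in dB; max_value is the largest possible pixel value."""
    error = mse(original, reconstructed)
    if error == 0:
        return math.inf  # identical images
    return 10 * math.log10(max_value ** 2 / error)


original = [0.0, 0.5, 1.0, 0.5]
good = [0.0, 0.5, 0.9, 0.5]   # one pixel slightly off
poor = [0.3, 0.2, 0.6, 0.9]   # every pixel off

for label, recon in (("good", good), ("poor", poor)):
    print(f"{label}: MSE={mse(original, recon):.4f}, PSNR={psnr(original, recon):.2f} dB")
```

Since PSNR is a decreasing function of MSE, the two metrics always agree on which reconstruction is better. SSIM measures structural similarity and can rank images differently.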

### 🔧 Customization options:

- Change `dataset_name` to "pneumoniamnist" for different medical images
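- Modify `MODEL_CONFIG` to match your trained model architecture; if the layer widths differ, loading the checkpoint fails with `size mismatch` errors like those shown in the model loading section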
- Adjust `batch_size` and `num_samples` based on your system capabilities

---

**🎉 Happy exploring with your Vanilla VAE!**