# Conditional Disentangled VAE Interactive Notebook 🧠🔀📊

This notebook provides an interactive interface for exploring a **Conditional Disentangled VAE** trained on multi-modal medical imaging data.

## Advanced Features:
- 🎯 **Multi-modal reconstruction** across different medical imaging types
- 🔄 **Conditional generation** by modality (ChestMNIST, PathMNIST, etc.)
- 🎨 **Class-conditional generation** within each modality
- 🔀 **Disentangled latent space** exploration (shared vs modality-specific)
- 📊 **Cross-modal comparisons** and analysis
- 🧪 **Modality transfer** experiments

---

## 1. Import Required Libraries

In [1]:
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display, clear_output
import warnings
warnings.filterwarnings('ignore')

# Add the project root to the Python path
project_root = Path().resolve().parent  # Go up one level from notebooks to project root
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

# Import project modules
from src.models import ConditionalVAE, DisentangledConditionalVAE
from src.data.multi_modal_datamodule import MultiModalDataModule
from src.utils import compute_reconstruction_metrics

print("✅ All libraries imported successfully!")
print(f"🔧 PyTorch version: {torch.__version__}")
print(f"🎯 Device available: {'GPU' if torch.cuda.is_available() else 'CPU'}")

ModuleNotFoundError: No module named 'ipywidgets'
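
The `ModuleNotFoundError` above means `ipywidgets` is not installed in the kernel's environment. A minimal sketch, using only the standard library, for checking which of the third-party packages imported by the first cell are importable before running it. The helper name `missing_modules` is hypothetical and not part of the project.

```python
import importlib.util

def missing_modules(names):
    """Return the top-level module names that cannot be located."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# Third-party packages imported by the first cell of this notebook.
required = ["torch", "numpy", "matplotlib", "ipywidgets"]
missing = missing_modules(required)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are importable.")
```

Any package reported as missing needs to be installed into the same environment the kernel runs in before the import cell will succeed.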

## 2. Configuration and Setup

In [None]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥️  Using device: {device}")

# Model configuration for Conditional/Disentangled VAE
DISENTANGLED_CONFIG = {
    "num_modalities": 5,
    "shared_latent_dim": 8,
    "modality_latent_dim": 8,
    "resolution": 28,
    "hidden_channels": 128,
    "ch_mult": (1, 2, 4, 8),
    "num_res_blocks": 2,
    "attn_resolutions": [16],
    "modality_separation_weight": 0.1,
    "contrastive_weight": 0.05,
}

CONDITIONAL_CONFIG = {
    "modalities": ["chestmnist", "pathmnist", "octmnist", "pneumoniamnist", "dermamnist"],
    "condition_dim": 5,
    "condition_method": "concat",
    "input_channels": 3,  # Max channels across modalities
    "latent_dim": 128,
    "hidden_channels": 128,
    "resolution": 28,
    "ch_mult": (1, 2, 4, 8),
    "num_res_blocks": 2,
    "attn_resolutions": [16],
}

# Modality information
MODALITIES = {
    0: {"name": "ChestMNIST", "channels": 1, "description": "Chest X-Ray Images", "classes": ["Normal", "Abnormal"]},
    1: {"name": "PathMNIST", "channels": 3, "description": "Colon Pathology Images", "classes": ["ADI", "BACK", "DEB", "LYM", "MUC", "MUS", "NORM", "STR", "TUM"]},
    2: {"name": "OCTMNIST", "channels": 1, "description": "Retinal OCT Images", "classes": ["CNV", "DME", "DRUSEN", "NORMAL"]},  # OCT scans are grayscale
    3: {"name": "PneumoniaMNIST", "channels": 1, "description": "Pneumonia X-Ray Images", "classes": ["Normal", "Pneumonia"]},
    4: {"name": "DermaMNIST", "channels": 3, "description": "Dermatoscope Images", "classes": ["ACK", "BCC", "MEL", "NEV", "PIH", "SEK", "UNK"]},
}

# Paths
CHECKPOINTS_DIR = Path("logs/checkpoints")
DATA_DIR = Path("data")

print("⚙️ Configuration loaded successfully!")
print(f"📁 Checkpoints directory: {CHECKPOINTS_DIR}")
print(f"🏥 Available modalities: {[info['name'] for info in MODALITIES.values()]}")
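
`DISENTANGLED_CONFIG` sets `shared_latent_dim` and `modality_latent_dim` to 8 each. A minimal sketch of how a combined 16-dimensional latent vector could be partitioned into the two parts, assuming the shared block comes first. That ordering is an assumption, since the internal layout of `DisentangledConditionalVAE` is not shown in this notebook, and `split_latent` is a hypothetical helper.

```python
def split_latent(z, shared_dim, modality_dim):
    """Split a flat latent vector into (shared, modality-specific) parts.

    Assumes the shared block comes first; this ordering is illustrative only.
    """
    if len(z) != shared_dim + modality_dim:
        raise ValueError(f"expected {shared_dim + modality_dim} values, got {len(z)}")
    return z[:shared_dim], z[shared_dim:]

config = {"shared_latent_dim": 8, "modality_latent_dim": 8}
z = list(range(16))  # stand-in for a 16-dimensional latent code
shared, specific = split_latent(z, config["shared_latent_dim"], config["modality_latent_dim"])
print(len(shared), len(specific))
```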

## 3. Load Pre-trained Model

In [None]:
def load_conditional_model(checkpoint_path=None, model_type="disentangled"):
    """Load a conditional VAE model with optional checkpoint weights."""
    
    if model_type == "disentangled":
        model = DisentangledConditionalVAE(**DISENTANGLED_CONFIG)
        config_used = DISENTANGLED_CONFIG
    elif model_type == "conditional":
        model = ConditionalVAE(**CONDITIONAL_CONFIG)
        config_used = CONDITIONAL_CONFIG
    else:
        raise ValueError(f"Unknown model type: {model_type}. Choose 'disentangled' or 'conditional'")
    
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"📂 Loading checkpoint from: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        
        # Extract model weights from Lightning checkpoint
        model_state_dict = {}
        for key, value in checkpoint["state_dict"].items():
            if key.startswith("model."):
                model_state_dict[key[6:]] = value  # Remove "model." prefix
        
        try:
            model.load_state_dict(model_state_dict)
            print("✅ Model weights loaded successfully!")
        except Exception as e:
            print(f"⚠️ Error loading weights: {e}")
            print("Using randomly initialized weights...")
    else:
        print("⚠️ No checkpoint provided - using randomly initialized weights")
    
    model.to(device)
    model.eval()
    
    print(f"🧠 Model loaded: {model.__class__.__name__}")
    print(f"📊 Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    if model_type == "disentangled":
        print(f"🔀 Shared latent dim: {config_used['shared_latent_dim']}")
        print(f"🎯 Modality latent dim: {config_used['modality_latent_dim']}")
    
    return model

# Model selection widget
model_type_widget = widgets.Dropdown(
    options=[('Disentangled VAE', 'disentangled'), ('Conditional VAE', 'conditional')],
    value='disentangled',
    description='Model Type:'
)

checkpoint_path_widget = widgets.Text(
    value='',  # Update with your checkpoint path
    placeholder='Path to checkpoint file (.ckpt)',
    description='Checkpoint:',
    style={'description_width': 'initial'}
)

load_button = widgets.Button(description="Load Model", button_style='primary')
output_widget = widgets.Output()

def on_load_clicked(b):
    with output_widget:
        clear_output()
        # Later cells look the model up under the name `conditional_model`,
        # so keep that name in sync with `model`.
        global model, model_type, conditional_model
        model_type = model_type_widget.value
        checkpoint = checkpoint_path_widget.value.strip() or None
        model = load_conditional_model(checkpoint, model_type)
        conditional_model = model

load_button.on_click(on_load_clicked)

# Display widgets
display(widgets.VBox([
    widgets.HBox([model_type_widget, load_button]),
    checkpoint_path_widget,
    output_widget
]))

# Initialize with default model
model_type = 'disentangled'
model = load_conditional_model(None, model_type)
conditional_model = model  # name used by the reconstruction, generation, and exploration cells
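
The checkpoint loader above keeps only the keys that start with `model.` and drops that prefix before calling `load_state_dict`. The same filtering can be sketched on a plain dictionary. The key names and values below are made up for illustration, and `strip_prefix` is a hypothetical helper.

```python
def strip_prefix(state_dict, prefix="model."):
    """Keep entries whose key starts with `prefix`, with the prefix removed."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }

# Toy stand-in for checkpoint["state_dict"]; real values would be tensors.
lightning_state = {
    "model.encoder.weight": 1,
    "model.decoder.bias": 2,
    "loss_fn.weight": 3,  # not part of the wrapped model, so it is dropped
}
print(strip_prefix(lightning_state))
```

Using `len(prefix)` instead of a hard-coded `6` keeps the slice correct if the wrapper attribute is ever renamed.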

## 4. Load Multi-Modal Dataset

In [None]:
import numpy as np
from torch.utils.data import DataLoader
from datasets.multimodal_dataset import MultiModalDataset
import ipywidgets as widgets
from IPython.display import display, clear_output

# Available datasets
DATASETS = {
    'chestmnist': 'data/chestmnist.npz',
    'pathmnist': 'data/pathmnist.npz',
    'octmnist': 'data/octmnist.npz',
    'pneumoniamnist': 'data/pneumoniamnist.npz',
    'dermamnist': 'data/dermamnist.npz'
}

# Data loading widgets
dataset_selection = widgets.SelectMultiple(
    options=list(DATASETS.keys()),
    value=['chestmnist', 'pathmnist'],  # Default selection
    description='Datasets:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='300px', height='120px')
)

batch_size_widget = widgets.IntSlider(
    value=32,
    min=8,
    max=128,
    step=8,
    description='Batch Size:',
    style={'description_width': 'initial'}
)

num_samples_widget = widgets.IntSlider(
    value=100,
    min=50,
    max=500,
    step=50,
    description='Samples per Dataset:',
    style={'description_width': 'initial'}
)

load_data_button = widgets.Button(
    description='Load Selected Datasets',
    button_style='primary'
)

# Output area
data_output = widgets.Output()

def load_multimodal_data(b=None):
    with data_output:
        clear_output(wait=True)
        try:
            selected_datasets = list(dataset_selection.value)
            if not selected_datasets:
                print("Please select at least one dataset.")
                return
            
            print(f"Loading datasets: {selected_datasets}")
            
            # Load all selected datasets
            all_data = {}
            for dataset_name in selected_datasets:
                print(f"Loading {dataset_name}...")
                data_path = DATASETS[dataset_name]
                data = np.load(data_path)
                
                # Extract train/test data
                train_images = data['train_images']
                train_labels = data['train_labels']
                test_images = data['test_images']
                test_labels = data['test_labels']
                
                all_data[dataset_name] = {
                    'train_images': train_images,
                    'train_labels': train_labels,
                    'test_images': test_images,
                    'test_labels': test_labels,
                    'num_classes': len(np.unique(train_labels))
                }
                
                print(f"  - Train: {train_images.shape}, Test: {test_images.shape}")
                print(f"  - Classes: {len(np.unique(train_labels))}")
            
            # Create combined dataset
            global multimodal_dataset, dataloader, dataset_info
            multimodal_dataset = MultiModalDataset(all_data, num_samples=num_samples_widget.value)
            dataloader = DataLoader(multimodal_dataset, batch_size=batch_size_widget.value, shuffle=True)
            
            dataset_info = {
                'modalities': selected_datasets,
                'modality_to_idx': {mod: idx for idx, mod in enumerate(selected_datasets)},
                'total_samples': len(multimodal_dataset),
                'batch_size': batch_size_widget.value
            }
            
            print(f"\n✅ Successfully loaded multi-modal dataset!")
            print(f"Total samples: {len(multimodal_dataset)}")
            print(f"Modalities: {selected_datasets}")
            print(f"Batch size: {batch_size_widget.value}")
            
        except Exception as e:
            print(f"❌ Error loading data: {e}")
            import traceback
            traceback.print_exc()

load_data_button.on_click(load_multimodal_data)

# Display widgets
print("Select datasets and configure data loading:")
display(widgets.VBox([
    dataset_selection,
    batch_size_widget,
    num_samples_widget,
    load_data_button,
    data_output
]))
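
In the cell above, `dataset_info['modality_to_idx']` is built by enumerating the selected datasets, so a modality's index depends on what else is selected. A minimal sketch contrasting that with a mapping tied to the fixed order in `CONDITIONAL_CONFIG["modalities"]`. Whether the trained model expects the fixed ordering is an assumption, since the training code is not shown here, and both helper names are hypothetical.

```python
# Fixed ordering copied from CONDITIONAL_CONFIG["modalities"] in this notebook.
ALL_MODALITIES = ["chestmnist", "pathmnist", "octmnist", "pneumoniamnist", "dermamnist"]

def selection_order_index(selected):
    """Mapping as built in the cell above: index follows selection order."""
    return {name: idx for idx, name in enumerate(selected)}

def fixed_index(selected, all_modalities=ALL_MODALITIES):
    """Alternative: index follows the fixed global ordering."""
    return {name: all_modalities.index(name) for name in selected}

selected = ["pathmnist", "dermamnist"]
print(selection_order_index(selected))  # {'pathmnist': 0, 'dermamnist': 1}
print(fixed_index(selected))            # {'pathmnist': 1, 'dermamnist': 4}
```

If the condition indices seen during training followed the fixed ordering, the second mapping would keep the notebook's conditions aligned with them for any subset of datasets.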

## 5. Interactive Multi-Modal Reconstruction

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

def reconstruct_samples():
    """Interactive reconstruction interface"""
    
    # Sample selection widgets
    modality_dropdown = widgets.Dropdown(
        options=[],
        description='Modality:',
        style={'description_width': 'initial'}
    )
    
    sample_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=99,
        description='Sample Index:',
        style={'description_width': 'initial'}
    )
    
    reconstruct_button = widgets.Button(
        description='Reconstruct Sample',
        button_style='info'
    )
    
    reconstruction_output = widgets.Output()
    
    def update_modality_options():
        if 'dataset_info' in globals():
            modality_dropdown.options = dataset_info['modalities']
            if dataset_info['modalities']:
                modality_dropdown.value = dataset_info['modalities'][0]
    
    def perform_reconstruction(b=None):
        with reconstruction_output:
            clear_output(wait=True)
            
            if 'conditional_model' not in globals():
                print("❌ Please load a model first!")
                return
            
            if 'multimodal_dataset' not in globals():
                print("❌ Please load dataset first!")
                return
            
            try:
                # Get selected sample
                modality_name = modality_dropdown.value
                sample_idx = sample_slider.value
                modality_idx = dataset_info['modality_to_idx'][modality_name]
                
                # Get sample from dataset
                sample_data = multimodal_dataset[sample_idx]
                image = sample_data['image'].unsqueeze(0).to(device)  # Add batch dimension
                modality_condition = sample_data['modality'].unsqueeze(0).to(device)
                class_condition = sample_data['class'].unsqueeze(0).to(device)
                
                # Perform reconstruction
                conditional_model.eval()
                with torch.no_grad():
                    if hasattr(conditional_model, 'encode'):
                        # Get latent representation
                        mu, log_var = conditional_model.encode(image, modality_condition, class_condition)
                        z = conditional_model.reparameterize(mu, log_var)
                        
                        # Reconstruct
                        reconstruction = conditional_model.decode(z, modality_condition, class_condition)
                    else:
                        # Use forward pass
                        outputs = conditional_model(image, modality_condition, class_condition)
                        reconstruction = outputs['reconstruction']
                
                # Convert to numpy for visualization
                original_img = image.cpu().squeeze().numpy()
                reconstructed_img = reconstruction.cpu().squeeze().numpy()
                
                # Handle different channel configurations
                if original_img.ndim == 3 and original_img.shape[0] in [1, 3]:
                    original_img = np.transpose(original_img, (1, 2, 0))
                    reconstructed_img = np.transpose(reconstructed_img, (1, 2, 0))
                
                if original_img.shape[-1] == 1:
                    original_img = original_img.squeeze(-1)
                    reconstructed_img = reconstructed_img.squeeze(-1)
                
                # Create visualization
                fig, axes = plt.subplots(1, 3, figsize=(15, 5))
                
                # Original image
                axes[0].imshow(original_img, cmap='gray' if len(original_img.shape) == 2 else None)
                axes[0].set_title(f'Original\n{modality_name} - Class {class_condition.item()}')
                axes[0].axis('off')
                
                # Reconstructed image
                axes[1].imshow(reconstructed_img, cmap='gray' if len(reconstructed_img.shape) == 2 else None)
                axes[1].set_title(f'Reconstructed\n{modality_name} - Class {class_condition.item()}')
                axes[1].axis('off')
                
                # Difference
                diff = np.abs(original_img - reconstructed_img)
                im = axes[2].imshow(diff, cmap='hot')
                axes[2].set_title('Absolute Difference')
                axes[2].axis('off')
                plt.colorbar(im, ax=axes[2])
                
                plt.tight_layout()
                plt.show()
                
                # Calculate metrics
                mse = F.mse_loss(reconstruction, image).item()
                print(f"\n📊 Reconstruction Metrics:")
                print(f"MSE Loss: {mse:.6f}")
                print(f"Modality: {modality_name} (index: {modality_idx})")
                print(f"Class: {class_condition.item()}")
                
            except Exception as e:
                print(f"❌ Error during reconstruction: {e}")
                import traceback
                traceback.print_exc()
    
    reconstruct_button.on_click(perform_reconstruction)
    
    # Update options when called
    update_modality_options()
    
    # Display interface
    print("Select a sample to reconstruct:")
    display(widgets.VBox([
        modality_dropdown,
        sample_slider,
        reconstruct_button,
        reconstruction_output
    ]))

# Call the function to create the interface
reconstruct_samples()
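
The cells in this notebook repeat the same steps to turn a channels-first array into something `imshow` accepts: transpose to channels-last, then drop a singleton channel. A minimal sketch of that logic as one reusable function, assuming NumPy arrays of shape `(C, H, W)` or `(H, W)`. The name `to_displayable` is hypothetical and not part of the project.

```python
import numpy as np

def to_displayable(img):
    """Convert a (C, H, W) or (H, W) array into a layout that imshow accepts."""
    img = np.asarray(img)
    if img.ndim == 3 and img.shape[0] in (1, 3):
        img = np.transpose(img, (1, 2, 0))  # channels-first -> channels-last
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]  # drop the singleton channel for grayscale
    return img

print(to_displayable(np.zeros((1, 28, 28))).shape)  # (28, 28)
print(to_displayable(np.zeros((3, 28, 28))).shape)  # (28, 28, 3)
print(to_displayable(np.zeros((28, 28))).shape)     # (28, 28)
```

The extra `ndim == 3` guard on the second check avoids squeezing a 2-D image that happens to be one pixel wide.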

## 6. Conditional Generation Interface

In [None]:
def conditional_generation_interface():
    """Interactive conditional generation interface"""
    
    # Generation parameter widgets
    gen_modality_dropdown = widgets.Dropdown(
        options=[],
        description='Target Modality:',
        style={'description_width': 'initial'}
    )
    
    gen_class_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=1,
        description='Target Class:',
        style={'description_width': 'initial'}
    )
    
    num_samples_gen = widgets.IntSlider(
        value=8,
        min=1,
        max=16,
        description='Number of Samples:',
        style={'description_width': 'initial'}
    )
    
    latent_dim_slider = widgets.IntSlider(
        value=128,
        min=64,
        max=512,
        step=64,
        description='Latent Dimension:',
        style={'description_width': 'initial'}
    )
    
    temperature_slider = widgets.FloatSlider(
        value=1.0,
        min=0.1,
        max=2.0,
        step=0.1,
        description='Temperature:',
        style={'description_width': 'initial'}
    )
    
    generate_button = widgets.Button(
        description='Generate Samples',
        button_style='success'
    )
    
    generation_output = widgets.Output()
    
    def update_generation_options():
        if 'dataset_info' in globals():
            gen_modality_dropdown.options = dataset_info['modalities']
            if dataset_info['modalities']:
                gen_modality_dropdown.value = dataset_info['modalities'][0]
                
                # Update class range based on dataset
                modality_name = gen_modality_dropdown.value
                if 'multimodal_dataset' in globals():
                    # Get max class from the dataset
                    max_class = 0
                    for sample_idx in range(min(100, len(multimodal_dataset))):
                        sample = multimodal_dataset[sample_idx]
                        max_class = max(max_class, sample['class'].item())
                    gen_class_slider.max = max_class
    
    def generate_conditional_samples(b=None):
        with generation_output:
            clear_output(wait=True)
            
            if 'conditional_model' not in globals():
                print("❌ Please load a model first!")
                return
            
            try:
                modality_name = gen_modality_dropdown.value
                target_class = gen_class_slider.value
                n_samples = num_samples_gen.value
                latent_size = latent_dim_slider.value
                temperature = temperature_slider.value
                
                modality_idx = dataset_info['modality_to_idx'][modality_name]
                
                print(f"🎨 Generating {n_samples} samples...")
                print(f"Modality: {modality_name} (index: {modality_idx})")
                print(f"Class: {target_class}")
                print(f"Temperature: {temperature}")
                
                conditional_model.eval()
                with torch.no_grad():
                    # Sample from prior
                    z = torch.randn(n_samples, latent_size).to(device) * temperature
                    
                    # Create condition tensors
                    modality_condition = torch.full((n_samples,), modality_idx).to(device)
                    class_condition = torch.full((n_samples,), target_class).to(device)
                    
                    # Generate samples
                    if hasattr(conditional_model, 'decode'):
                        generated_samples = conditional_model.decode(z, modality_condition, class_condition)
                    else:
                        # Use the model's generate method if available
                        generated_samples = conditional_model.generate(z, modality_condition, class_condition)
                    
                    # Convert to numpy
                    samples = generated_samples.cpu().numpy()
                    
                    # Create visualization grid
                    cols = min(4, n_samples)
                    rows = (n_samples + cols - 1) // cols
                    
                    # squeeze=False always returns a 2-D array of axes, including the 1x1 case
                    fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3), squeeze=False)
                    
                    for i in range(n_samples):
                        row = i // cols
                        col = i % cols
                        
                        # Handle different channel configurations
                        img = samples[i]
                        if img.ndim == 3 and img.shape[0] in [1, 3]:
                            img = np.transpose(img, (1, 2, 0))
                        if img.shape[-1] == 1:
                            img = img.squeeze(-1)
                        
                        axes[row, col].imshow(img, cmap='gray' if len(img.shape) == 2 else None)
                        axes[row, col].set_title(f'{modality_name}\nClass {target_class}')
                        axes[row, col].axis('off')
                    
                    # Hide empty subplots
                    for i in range(n_samples, rows * cols):
                        row = i // cols
                        col = i % cols
                        axes[row, col].axis('off')
                    
                    plt.tight_layout()
                    plt.show()
                    
                    print(f"\n✅ Generated {n_samples} samples successfully!")
                    
            except Exception as e:
                print(f"❌ Error during generation: {e}")
                import traceback
                traceback.print_exc()
    
    generate_button.on_click(generate_conditional_samples)
    
    # Update options when called
    update_generation_options()
    
    # Display interface
    print("Configure conditional generation parameters:")
    display(widgets.VBox([
        gen_modality_dropdown,
        gen_class_slider,
        num_samples_gen,
        latent_dim_slider,
        temperature_slider,
        generate_button,
        generation_output
    ]))

# Call the function to create the interface
conditional_generation_interface()
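
The generation cell multiplies standard-normal prior samples by `temperature`, which scales their standard deviation by the same factor: values below 1 keep samples closer to the prior mean, values above 1 spread them out. A minimal standard-library sketch of that effect on scalar samples. `sample_prior` is a hypothetical helper, and the fixed seed is only there to make the run repeatable.

```python
import random
import statistics

def sample_prior(n, temperature, seed=0):
    """Draw n standard-normal values and scale them by temperature."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) * temperature for _ in range(n)]

for temperature in (0.5, 1.0, 2.0):
    values = sample_prior(10_000, temperature)
    print(f"temperature={temperature}: std ~ {statistics.pstdev(values):.2f}")
```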

## 7. Disentangled Latent Space Exploration
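
The interpolation mode in this section blends two latent codes linearly, `z = (1 - alpha) * z1 + alpha * z2`, with `alpha` running from 0 to 1 over the chosen number of steps. A minimal sketch on plain Python lists, where short 2-dimensional vectors stand in for real latent codes and `lerp_path` is a hypothetical helper.

```python
def lerp_path(z1, z2, steps):
    """Return `steps` points evenly spaced on the line from z1 to z2 (inclusive)."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    path = []
    for i in range(steps):
        alpha = i / (steps - 1)
        path.append([(1 - alpha) * a + alpha * b for a, b in zip(z1, z2)])
    return path

for point in lerp_path([0.0, 2.0], [4.0, 0.0], steps=5):
    print(point)
```

The first and last points equal the two endpoints, so the decoded sequence starts at one sample and ends at the other.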

In [None]:
def latent_space_exploration():
    """Explore disentangled latent space properties"""
    
    # Exploration widgets
    exploration_type = widgets.Dropdown(
        options=['Latent Interpolation', 'Dimension Analysis', 'Cross-Modal Transfer'],
        value='Latent Interpolation',
        description='Exploration Type:',
        style={'description_width': 'initial'}
    )
    
    source_modality = widgets.Dropdown(
        options=[],
        description='Source Modality:',
        style={'description_width': 'initial'}
    )
    
    target_modality = widgets.Dropdown(
        options=[],
        description='Target Modality:',
        style={'description_width': 'initial'}
    )
    
    source_class = widgets.IntSlider(
        value=0,
        min=0,
        max=1,
        description='Source Class:',
        style={'description_width': 'initial'}
    )
    
    target_class = widgets.IntSlider(
        value=0,
        min=0,
        max=1,
        description='Target Class:',
        style={'description_width': 'initial'}
    )
    
    num_steps = widgets.IntSlider(
        value=10,
        min=5,
        max=20,
        description='Interpolation Steps:',
        style={'description_width': 'initial'}
    )
    
    latent_dim_to_vary = widgets.IntSlider(
        value=0,
        min=0,
        max=127,
        description='Latent Dimension:',
        style={'description_width': 'initial'}
    )
    
    variation_range = widgets.FloatSlider(
        value=3.0,
        min=1.0,
        max=5.0,
        step=0.5,
        description='Variation Range:',
        style={'description_width': 'initial'}
    )
    
    explore_button = widgets.Button(
        description='Explore Latent Space',
        button_style='warning'
    )
    
    exploration_output = widgets.Output()
    
    def update_exploration_options():
        if 'dataset_info' in globals():
            modalities = dataset_info['modalities']
            source_modality.options = modalities
            target_modality.options = modalities
            if modalities:
                source_modality.value = modalities[0]
                target_modality.value = modalities[-1] if len(modalities) > 1 else modalities[0]
    
    def perform_exploration(b=None):
        with exploration_output:
            clear_output(wait=True)
            
            if 'conditional_model' not in globals():
                print("❌ Please load a model first!")
                return
            
            if 'multimodal_dataset' not in globals():
                print("❌ Please load dataset first!")
                return
            
            try:
                exploration_mode = exploration_type.value
                print(f"🔍 {exploration_mode}...")
                
                conditional_model.eval()
                
                if exploration_mode == 'Latent Interpolation':
                    # Interpolate between two samples
                    source_mod_idx = dataset_info['modality_to_idx'][source_modality.value]
                    target_mod_idx = dataset_info['modality_to_idx'][target_modality.value]
                    
                    # Get random samples from each modality/class
                    with torch.no_grad():
                        # Create condition tensors
                        source_mod_tensor = torch.tensor([source_mod_idx]).to(device)
                        source_class_tensor = torch.tensor([source_class.value]).to(device)
                        target_mod_tensor = torch.tensor([target_mod_idx]).to(device)
                        target_class_tensor = torch.tensor([target_class.value]).to(device)
                        
                        # Sample latent codes
                        latent_size = latent_dim_slider.value if 'latent_dim_slider' in globals() else 128
                        z1 = torch.randn(1, latent_size).to(device)
                        z2 = torch.randn(1, latent_size).to(device)
                        
                        # Interpolation
                        steps = num_steps.value
                        interpolations = []
                        
                        for i in range(steps):
                            alpha = i / (steps - 1)
                            z_interp = (1 - alpha) * z1 + alpha * z2
                            
                            # Generate with source modality first, then target
                            if i < steps // 2:
                                gen_sample = conditional_model.decode(z_interp, source_mod_tensor, source_class_tensor)
                                mod_name = source_modality.value
                            else:
                                gen_sample = conditional_model.decode(z_interp, target_mod_tensor, target_class_tensor)
                                mod_name = target_modality.value
                            
                            interpolations.append((gen_sample.cpu().numpy(), mod_name))
                        
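                        # Visualize interpolation: source modality on the top row, target on the bottom.
                        # The bottom row holds steps - steps // 2 images, so the grid is sized for it;
                        # with steps // 2 columns, an odd step count would overwrite a panel.
                        half = steps // 2
                        cols = steps - half
                        fig, axes = plt.subplots(2, cols, figsize=(steps*2, 6))
                        for ax in axes.flat:
                            ax.axis('off')  # also hides any unused panel
                        
                        for i, (sample, mod_name) in enumerate(interpolations):
                            row = 0 if i < half else 1
                            col = i if i < half else i - half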
                            
                            img = sample.squeeze()
                            if img.ndim == 3 and img.shape[0] in [1, 3]:
                                img = np.transpose(img, (1, 2, 0))
                            if img.shape[-1] == 1:
                                img = img.squeeze(-1)
                            
                            axes[row, col].imshow(img, cmap='gray' if len(img.shape) == 2 else None)
                            axes[row, col].set_title(f'{mod_name}\nStep {i+1}')
                            axes[row, col].axis('off')
                        
                        plt.tight_layout()
                        plt.show()
                
                elif exploration_mode == 'Dimension Analysis':
                    # Vary specific latent dimensions
                    mod_idx = dataset_info['modality_to_idx'][source_modality.value]
                    class_idx = source_class.value
                    
                    with torch.no_grad():
                        base_z = torch.randn(1, latent_dim_slider.value if 'latent_dim_slider' in globals() else 128).to(device)
                        mod_tensor = torch.tensor([mod_idx]).to(device)
                        class_tensor = torch.tensor([class_idx]).to(device)
                        
                        # Vary the selected dimension
                        dim_to_vary = latent_dim_to_vary.value
                        var_range = variation_range.value
                        variations = np.linspace(-var_range, var_range, num_steps.value)
                        
                        samples = []
                        for var_val in variations:
                            z_varied = base_z.clone()
                            z_varied[0, dim_to_vary] = var_val
                            
                            sample = conditional_model.decode(z_varied, mod_tensor, class_tensor)
                            samples.append(sample.cpu().numpy())
                        
                        # Visualize variations
                        fig, axes = plt.subplots(1, len(samples), figsize=(len(samples)*2, 3))
                        if len(samples) == 1:
                            axes = [axes]
                        
                        for i, sample in enumerate(samples):
                            img = sample.squeeze()
                            if img.ndim == 3 and img.shape[0] in [1, 3]:
                                img = np.transpose(img, (1, 2, 0))
                            if img.shape[-1] == 1:
                                img = img.squeeze(-1)
                            
                            axes[i].imshow(img, cmap='gray' if len(img.shape) == 2 else None)
                            axes[i].set_title(f'Dim {dim_to_vary}\nVal: {variations[i]:.2f}')
                            axes[i].axis('off')
                        
                        plt.tight_layout()
                        plt.show()
                
                elif exploration_mode == 'Cross-Modal Transfer':
                    # Transfer latent code between modalities
                    source_mod_idx = dataset_info['modality_to_idx'][source_modality.value]
                    target_mod_idx = dataset_info['modality_to_idx'][target_modality.value]
                    
                    with torch.no_grad():
                        # Sample in source modality
                        z = torch.randn(1, latent_dim_slider.value if 'latent_dim_slider' in globals() else 128).to(device)
                        source_mod_tensor = torch.tensor([source_mod_idx]).to(device)
                        target_mod_tensor = torch.tensor([target_mod_idx]).to(device)
                        source_class_tensor = torch.tensor([source_class.value]).to(device)
                        target_class_tensor = torch.tensor([target_class.value]).to(device)
                        
                        # Generate in both modalities
                        source_sample = conditional_model.decode(z, source_mod_tensor, source_class_tensor)
                        target_sample = conditional_model.decode(z, target_mod_tensor, target_class_tensor)
                        
                        # Visualize transfer
                        fig, axes = plt.subplots(1, 2, figsize=(8, 4))
                        
                        for i, (sample, mod_name) in enumerate([(source_sample, source_modality.value), 
                                                              (target_sample, target_modality.value)]):
                            img = sample.cpu().numpy().squeeze()
                            if img.ndim == 3 and img.shape[0] in [1, 3]:
                                img = np.transpose(img, (1, 2, 0))
                            if img.shape[-1] == 1:
                                img = img.squeeze(-1)
                            
                            axes[i].imshow(img, cmap='gray' if len(img.shape) == 2 else None)
                            axes[i].set_title(f'{mod_name}\nSame Latent Code')
                            axes[i].axis('off')
                        
                        plt.tight_layout()
                        plt.show()
                
                print(f"\n✅ {exploration_mode} completed!")
                
            except Exception as e:
                print(f"❌ Error during exploration: {e}")
                import traceback
                traceback.print_exc()
    
    explore_button.on_click(perform_exploration)
    
    # Update options when called
    update_exploration_options()
    
    # Container for the mode-specific controls; swapping its children avoids
    # re-displaying widgets inside exploration_output (which holds the results)
    controls_box = widgets.VBox()

    def update_interface(*args):
        mode = exploration_type.value
        if mode == 'Latent Interpolation':
            interface_widgets = [exploration_type, source_modality, target_modality,
                                 source_class, target_class, num_steps]
        elif mode == 'Dimension Analysis':
            interface_widgets = [exploration_type, source_modality, source_class,
                                 latent_dim_to_vary, variation_range, num_steps]
        else:  # Cross-Modal Transfer
            interface_widgets = [exploration_type, source_modality, target_modality,
                                 source_class, target_class]

        controls_box.children = tuple(interface_widgets)

    exploration_type.observe(update_interface, names='value')

    # Initial display: populate the controls for the default mode
    print("Configure latent space exploration:")
    update_interface()
    display(widgets.VBox([controls_box, explore_button, exploration_output]))

# Call the function to create the interface
latent_space_exploration()
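
The image-preparation steps repeated in the exploration branches above (squeeze, move channels last, drop a singleton channel) could be factored into one helper. The sketch below is one possible version; `to_displayable` is a hypothetical name, not a function from the project's `src` package, and it assumes images arrive as channel-first arrays with 1 or 3 channels.

```python
import numpy as np

def to_displayable(img):
    """Convert a (C, H, W), (1, C, H, W) or (H, W) array into a matplotlib-ready array.

    Hypothetical helper: assumes channel-first input with 1 or 3 channels.
    """
    img = np.squeeze(np.asarray(img))
    # Channel-first -> channel-last, since imshow expects (H, W, C)
    if img.ndim == 3 and img.shape[0] in (1, 3):
        img = np.transpose(img, (1, 2, 0))
    # Drop a trailing singleton channel so grayscale images are 2-D
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    return img

gray = to_displayable(np.zeros((1, 1, 28, 28)))
rgb = to_displayable(np.zeros((1, 3, 28, 28)))
print(gray.shape, rgb.shape)
```

With such a helper, each plotting loop would reduce to something like `img = to_displayable(sample)` followed by `axes[i].imshow(img, cmap='gray' if img.ndim == 2 else None)`.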

## 8. Model Analysis & Comparison

This notebook provides comprehensive tools for exploring conditional and disentangled VAE models:

### Features:
- **Multi-Modal Data Loading**: Load and combine multiple MedMNIST datasets
- **Interactive Reconstruction**: Compare original vs reconstructed images across modalities
- **Conditional Generation**: Generate samples conditioned on modality and class
- **Latent Space Exploration**: Interpolate, analyze dimensions, and transfer between modalities
- **Disentanglement Analysis**: Explore how the latent space separates modalities and classes

### Usage Tips:
1. Start by loading a trained model using the model loading interface
2. Load your desired datasets (starting with 2-3 modalities is recommended)
3. Use reconstruction to verify model quality
4. Experiment with conditional generation across different modalities
5. Explore latent space properties to understand disentanglement
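
For step 3, reconstruction quality can also be checked numerically rather than only by eye. The notebook imports `compute_reconstruction_metrics` from `src.utils`, but its signature is not shown here, so the following is a standalone sketch with a hypothetical `reconstruction_metrics` function. It assumes pixel values are scaled to `[0, 1]` and uses random arrays as stand-ins for real images and reconstructions.

```python
import numpy as np

def reconstruction_metrics(original, reconstructed, data_range=1.0):
    """Return MSE and PSNR (in dB) between two image batches of equal shape.

    Hypothetical helper; data_range is the assumed span of valid pixel values.
    """
    original = np.asarray(original, dtype=np.float64)
    reconstructed = np.asarray(reconstructed, dtype=np.float64)
    if original.shape != reconstructed.shape:
        raise ValueError("original and reconstructed must have the same shape")
    mse = float(np.mean((original - reconstructed) ** 2))
    # PSNR is infinite for a perfect reconstruction
    psnr = float("inf") if mse == 0 else float(10.0 * np.log10(data_range ** 2 / mse))
    return {"mse": mse, "psnr": psnr}

# Stand-in data: a random batch and a slightly noisy copy of it
rng = np.random.default_rng(0)
images = rng.random((4, 1, 28, 28))
noisy = np.clip(images + rng.normal(0.0, 0.05, images.shape), 0.0, 1.0)
metrics = reconstruction_metrics(images, noisy)
print(f"MSE: {metrics['mse']:.5f}, PSNR: {metrics['psnr']:.2f} dB")
```

Computing these per modality makes it easier to spot a modality the model reconstructs noticeably worse than the others.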

### Next Steps:
- Compare with the vanilla VAE notebook for single-modality behavior
- Experiment with different temperature settings for generation
- Analyze cross-modal relationships in the latent space
- Use dimension analysis to understand which latent dimensions control specific features
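
The last two experiments can be prototyped without a trained model. The sketch below varies one latent dimension at a time and scores how much the decoded output changes. The linear `toy_decode` is a hypothetical stand-in for the real decoder, the 16 latent dimensions are an illustrative choice, and treating temperature as a scale factor on the prior sample is an assumed (though common) convention, not something confirmed by this notebook.

```python
import numpy as np

def traverse_dimension(decode, z, dim, values):
    """Decode copies of z in which only z[dim] is replaced by each value."""
    outputs = []
    for value in values:
        z_varied = z.copy()
        z_varied[dim] = value
        outputs.append(decode(z_varied))
    return np.stack(outputs)

def dimension_sensitivity(decode, z, values):
    """Mean absolute output change per latent dimension across a traversal."""
    baseline = decode(z)
    scores = []
    for dim in range(z.shape[0]):
        outputs = traverse_dimension(decode, z, dim, values)
        scores.append(float(np.mean(np.abs(outputs - baseline))))
    return np.array(scores)

# Toy stand-in decoder (hypothetical): a linear map from 16 latents to 28x28 pixels
rng = np.random.default_rng(0)
weights = rng.normal(size=(16, 28 * 28))
weights[3] *= 5.0  # make dimension 3 deliberately influential

def toy_decode(z):
    return (z @ weights).reshape(28, 28)

temperature = 0.8  # assumed convention: scales the std of the prior sample
z = temperature * rng.standard_normal(16)
scores = dimension_sensitivity(toy_decode, z, np.linspace(-3.0, 3.0, 7))
print("Most influential dimension:", int(np.argmax(scores)))
```

With the actual model, `toy_decode` would be replaced by a wrapper around the decoder call with fixed modality and class tensors, and the resulting scores would indicate which dimensions are worth inspecting visually.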