# c-VAE Model Training for MALDI Data

This notebook demonstrates the training process for the Conditional Variational Autoencoder (c-VAE) on MALDI mass spectrometry imaging data.

## Training Pipeline Overview
1. Configure model hyperparameters
2. Load preprocessed data
3. Initialise model and optimiser
4. Train and save the model
5. Analyse the learned latent space
6. Save model's configuration and losses

In [1]:
# Import required libraries
import os
import torch
import numpy as np
from pathlib import Path
import logging

os.chdir('/home/pasco/sdsc_mlibra/JupyterNotebooks/cleaned/brain_lipid_cvae_pcorso') # Replace with the path where you git-cloned the repo

# Import our modules
from models.cvae import ConditionalVAE
from utils.dataloader import MALDIDataset, create_dataloader
from utils.visualisation import TrainingVisualizer, ReconstructionVisualizer

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 1. Configuration

Set up model hyperparameters and training configuration.

In [2]:
# Model configuration
config = {
    'hidden_dim': 512,      # Dimension of hidden layers
    'latent_dim': 256,      # Dimension of latent space
    'beta': 0.1,            # Weight of KL divergence term
    'batch_size': 128,      # Batch size for training
    'learning_rate': 5e-4,  # Learning rate
    'weight_decay': 0.01,   # Weight decay for regularization
    'dropout_rate': 0.1,    # Dropout rate
    'epochs': 60,           # Number of training epochs
    'device': 'cpu'         # 'cpu', 'cuda' (NVIDIA GPU) or 'mps' (Apple Silicon GPU)
}

logger.info(f"Using device: {config['device']}")

INFO:__main__:Using device: cpu
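The `beta` entry weights the KL divergence term against the reconstruction term. As a minimal sketch, assuming the standard β-VAE objective with a diagonal Gaussian posterior and a standard-normal prior (the exact loss and reduction used inside `ConditionalVAE` are not shown in this notebook and may differ), the per-sample loss can be written as below. `kl_standard_normal` and `beta_vae_loss` are hypothetical helper names.

```python
import math


def kl_standard_normal(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)) for one sample, summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def beta_vae_loss(recon_loss, mu, logvar, beta=0.1):
    """Hypothetical helper: total = reconstruction + beta * KL. Returns (total, kl)."""
    kl = kl_standard_normal(mu, logvar)
    return recon_loss + beta * kl, kl


# One latent dimension with mean 1 and unit variance:
# kl = -0.5 * (1 + 0 - 1 - 1) = 0.5, total = 0.2 + 0.1 * 0.5 = 0.25
total, kl = beta_vae_loss(0.2, mu=[1.0], logvar=[0.0], beta=0.1)
print(kl, total)
```

With `beta=0.1`, one unit of KL divergence costs a tenth as much as one unit of reconstruction error, so the objective favours reconstruction fidelity over keeping the latent distribution close to the prior.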


## 2. Load data

Load the preprocessed data and create the data loader.

In [3]:
# Set data paths
data_dir = Path('data/H5')
maldi_path = data_dir / 'maldi_processed.h5'
coords_path = data_dir / 'coords_spherical.h5'
ref_path = data_dir / 'reference_data.h5'

# Detect compute resources
device = config['device']
num_cpus = os.cpu_count()
has_gpu = torch.cuda.is_available()
print("Number of CPUs available and whether GPU computation is possible: ", num_cpus, has_gpu)
    
# Configure number of workers
if has_gpu and device != "cpu":
    # For GPU training, using every CPU core for data loading is a common starting point; tune empirically
    num_workers = num_cpus
else:
    # For CPU training, leave one core free for system processes
    num_workers = max(1, num_cpus - 1)
    
# Create data loader
dataloader, scalers = create_dataloader(
    maldi_path=maldi_path,
    ccf_path=coords_path,
    ref_path=ref_path,
    batch_size=config['batch_size'],
    num_workers=num_workers,
    device=config['device'],
    downsample_factor=1,
    chunk_size=500
)

# Get data dimensions
sample_batch = next(iter(dataloader))
n_spatial_points = sample_batch[0].size(1)
coords_dim = sample_batch[1].size(1)

logger.info(f"Number of spatial points: {n_spatial_points}")
logger.info(f"Coordinate dimensions: {coords_dim}")

Number of CPUs available and whether GPU computation is possible:  20 True


INFO:root:Processed rows 0 to 500 of 2000
INFO:root:Processed rows 500 to 1000 of 2000
INFO:root:Processed rows 1000 to 1500 of 2000
INFO:root:Processed rows 1500 to 2000 of 2000
INFO:root:Number of points: 47968
INFO:root:Matrix shape: (47968,)
INFO:root:Matrix data type: float64
INFO:root:Dataset initialized with shapes:
INFO:root:MALDI shape: torch.Size([2000, 47968])
INFO:root:CCF shape: torch.Size([3, 47968])
INFO:root:Reference data shape: torch.Size([1, 47968])
INFO:__main__:Number of spatial points: 47968
INFO:__main__:Coordinate dimensions: 3
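The `Processed rows ... of 2000` log lines indicate that `create_dataloader` reads the 2000 rows of the MALDI matrix in blocks of `chunk_size=500` rather than all at once, which keeps peak memory bounded. The loader's internals are not shown here; as a hypothetical sketch, the row ranges for such chunked reading can be generated as follows (with h5py, each range would then be read as `dataset[start:stop]`). `chunk_ranges` is an illustrative name, not part of the project.

```python
def chunk_ranges(n_rows, chunk_size):
    """Hypothetical helper: (start, stop) pairs covering n_rows in blocks of at most chunk_size."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    return [(start, min(start + chunk_size, n_rows))
            for start in range(0, n_rows, chunk_size)]


# Reproduces the row ranges reported in the log above
for start, stop in chunk_ranges(2000, 500):
    print(f"Processed rows {start} to {stop} of 2000")
```

The last chunk is truncated with `min`, so row counts that are not a multiple of `chunk_size` are still covered exactly once.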


## 3. Initialise Model

Set up the c-VAE model and optimiser.
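The scheduler configured in the cell below follows PyTorch's one-cycle policy: the learning rate rises from `max_lr / div_factor` to `max_lr` over the first `pct_start` fraction of optimiser steps, then anneals to a much smaller value. The pure-Python sketch below reproduces that two-phase cosine shape, assuming PyTorch's default `div_factor=25` and `final_div_factor=1e4`; `one_cycle_lr` is a hypothetical inspection helper, not part of the project or of PyTorch.

```python
import math


def one_cycle_lr(step, total_steps, max_lr, pct_start=0.1,
                 div_factor=25.0, final_div_factor=1e4):
    """Learning rate at a given optimiser step under a two-phase cosine one-cycle policy."""
    initial_lr = max_lr / div_factor          # LR at step 0
    min_lr = initial_lr / final_div_factor    # LR at the final step
    warmup_end = pct_start * total_steps - 1  # step at which the LR peaks
    last_step = total_steps - 1

    def cos_anneal(start, end, pct):
        # Cosine interpolation from `start` (pct=0) to `end` (pct=1)
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    return cos_anneal(max_lr, min_lr, (step - warmup_end) / (last_step - warmup_end))


# 60 epochs x 15 batches per epoch = 900 optimiser steps
for step in (0, 45, 89, 450, 899):
    print(step, one_cycle_lr(step, total_steps=900, max_lr=5e-4))
```

With 60 epochs of 15 batches each (as in the training log below), there are 900 steps, so the learning rate peaks at 5e-4 around step 89. Because the schedule is defined per optimiser step, `scheduler.step()` has to be called once per batch; stepping `OneCycleLR` more than `epochs * steps_per_epoch` times raises an error.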

In [4]:
# Initialise model
model = ConditionalVAE(
    maldi_dim=n_spatial_points,
    ccf_dim=coords_dim,
    hidden_dim=config['hidden_dim'],
    latent_dim=config['latent_dim'],
    beta=config['beta'],
    device=config['device']
).to(config['device'])

# Initialise AdamW optimiser (weight decay is decoupled from the gradient update)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay'],
    betas=(0.9, 0.999)
)

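# One-cycle learning rate scheduler; it is defined per optimiser step,
# so scheduler.step() must be called once per batch
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=config['learning_rate'],
    epochs=config['epochs'],
    steps_per_epoch=len(dataloader),
    pct_start=0.1,  # fraction of all steps spent warming up to max_lr
    anneal_strategy='cos'
)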

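# Note: gradient clipping must run inside the training loop, after loss.backward()
# and before optimizer.step(). Calling torch.nn.utils.clip_grad_norm_ here, before
# any gradients exist, has no effect (it only returns tensor(0.)).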

## 4. Training

Train the model while monitoring progress.
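Gradient clipping only acts on gradients that already exist, i.e. after `loss.backward()` and before `optimizer.step()` in every iteration; whether `ConditionalVAE.train_model` clips at that point is not visible in this notebook. The pure-Python sketch below mirrors the global-norm rule used by `torch.nn.utils.clip_grad_norm_`: compute the L2 norm over all gradients and, if it exceeds `max_norm`, rescale every gradient by `max_norm / (norm + eps)`. It also shows why a call made before any backward pass reports a norm of 0. The helper name `clip_grad_norm` and the list-of-lists gradient layout are illustrative assumptions.

```python
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Rescale gradient vectors in place so their global L2 norm is at most max_norm.

    Returns the global norm measured before clipping, as torch.nn.utils.clip_grad_norm_ does.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:  # only shrink gradients, never enlarge them
        for vec in grads:
            for i in range(len(vec)):
                vec[i] *= clip_coef
    return total_norm


# In a PyTorch loop the call sits between backward() and step():
#   loss.backward()
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
#   optimizer.step()

grads = [[3.0], [4.0]]                      # global norm = 5
print(clip_grad_norm(grads, max_norm=1.0))  # 5.0 (gradients are now about [[0.6], [0.8]])
print(clip_grad_norm([], max_norm=5.0))     # 0.0: no gradients yet, nothing to clip
```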

In [5]:
# Train model
metrics, avg_loss, run_id = model.train_model(
    dataloader=dataloader,
    optimizer=optimizer,
    epochs=config['epochs'],
    print_every=50,
    use_mixed_precision=True,
    scheduler=scheduler,
    save_dir='checkpoints'
)

logger.info(f"Training completed. Run ID: {run_id}")

Mixed precision training disabled (not supported on CPU)


Epoch 1/60: 100%|█| 15/15 [00:25<00:00,  1.72s/it, loss=0.2356, recon=0.2277, kl
Epoch 2/60: 100%|█| 15/15 [00:26<00:00,  1.80s/it, loss=0.7003, recon=0.1261, kl
Epoch 3/60: 100%|█| 15/15 [00:27<00:00,  1.80s/it, loss=1.6896, recon=0.0022, kl
Epoch 4/60: 100%|█| 15/15 [00:27<00:00,  1.81s/it, loss=3.6516, recon=0.0024, kl
Epoch 5/60: 100%|█| 15/15 [00:27<00:00,  1.82s/it, loss=6.2283, recon=0.0008, kl
Epoch 6/60: 100%|█| 15/15 [00:27<00:00,  1.82s/it, loss=8.9549, recon=0.0003, kl
Epoch 7/60: 100%|█| 15/15 [00:27<00:00,  1.81s/it, loss=11.8293, recon=0.0001, k
Epoch 8/60: 100%|█| 15/15 [00:27<00:00,  1.82s/it, loss=14.6933, recon=0.0000, k
Epoch 9/60: 100%|█| 15/15 [00:27<00:00,  1.82s/it, loss=17.4305, recon=0.0000, k
Epoch 10/60: 100%|█| 15/15 [00:27<00:00,  1.83s/it, loss=19.9488, recon=0.0000, 
Epoch 11/60: 100%|█| 15/15 [00:27<00:00,  1.85s/it, loss=22.1645, recon=0.0000, 
Epoch 12/60: 100%|█| 15/15 [00:27<00:00,  1.84s/it, loss=24.0343, recon=0.0000, 
Epoch 13/60: 100%|█| 15/15 [

## 5. Latent Space Analysis

Analyse the learned latent space representation.

In [None]:
from utils.visualisation import LatentSpaceVisualizer

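# Initialise latent space visualiser
latent_visualizer = LatentSpaceVisualizer()

# Encode a batch of spectra. No held-out test split is created in this notebook,
# so the first training batch (`sample_batch`, from Section 2) is reused here.
model.eval()  # disable dropout during encoding
test_maldi = sample_batch[0].to(config['device'])
test_coords = sample_batch[1].to(config['device'])
os.makedirs('evaluation', exist_ok=True)  # folder for the saved figures

with torch.no_grad():
    z_mu, z_logvar = model.encode(test_maldi, test_coords)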

# Plot latent space
latent_visualizer.plot_latent_space_2d(
    z=z_mu,
    save_path='evaluation/latent_space.png'
)

# Plot latent traversal
latent_visualizer.plot_latent_traversal(
    decoder=model.decoder,
    coordinates=test_coords[0:1],
    dim=0,  # First latent dimension
    save_path='evaluation/latent_traversal.png'
)
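`plot_latent_traversal` is called with `dim=0`, i.e. it varies the first latent dimension. Its implementation lives in `utils.visualisation` and is not shown; as a rough sketch of the general technique, a traversal takes one latent vector, sweeps a single coordinate over a range of values while holding the others fixed, and decodes each variant with the same conditioning coordinates (here `test_coords[0:1]`). The helpers below are hypothetical, and the ±3 range is an assumed choice (about three standard deviations of a standard-normal prior).

```python
def linspace(start, stop, n):
    """n evenly spaced values from start to stop inclusive (requires n >= 2)."""
    step = (stop - start) / (n - 1)
    return [start + i * step for i in range(n)]


def latent_traversal(z, dim, values):
    """Copies of latent vector z with coordinate `dim` set to each value in `values`."""
    variants = []
    for v in values:
        z_new = list(z)  # copy so the original latent vector is left untouched
        z_new[dim] = v
        variants.append(z_new)
    return variants


z = [0.5, 1.0, -0.2]  # e.g. one row of z_mu, shortened to 3 dimensions
for z_t in latent_traversal(z, dim=0, values=linspace(-3.0, 3.0, 5)):
    print(z_t)  # each z_t would be decoded together with the same coordinates
```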

## 6. Save configuration and metrics info

Save the training configuration and the recorded loss metrics.

In [6]:
import json

# Write metrics and config to the same file, one JSON object per line
dict_save_dir = Path('checkpoints', run_id)

# Create directory if it doesn't exist
os.makedirs(dict_save_dir, exist_ok=True)
file_path = dict_save_dir / 'info_training.txt'

with open(file_path, 'w') as f:
    json.dump(metrics, f)
    f.write('\n')
    json.dump(config, f)
    f.write('\n')

runID_path = Path('training_plots')
os.makedirs(runID_path, exist_ok=True)
runID_filepath = runID_path / 'runID.txt'
with open(runID_filepath, 'w') as f:
    json.dump(run_id, f)
    f.write('\n')

logger.info(f"Dictionaries saved to {file_path} and {runID_path}")

INFO:__main__:Dictionaries saved to checkpoints/CVAE_b128_e60_z256_20250130_124426/info_training.txt and training_plots
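Because `info_training.txt` holds two JSON documents, one per line, a single `json.load` call on the whole file fails with an "Extra data" error. A minimal sketch for reading it back, assuming `metrics` and `config` are JSON-serialisable dictionaries as in the cell above; `write_info` and `read_info` are hypothetical helpers, and the metric key in the demo is made up.

```python
import json
import tempfile
from pathlib import Path


def write_info(path, metrics, config):
    """Write metrics and config as two JSON objects, one per line."""
    with open(path, 'w') as f:
        json.dump(metrics, f)
        f.write('\n')
        json.dump(config, f)
        f.write('\n')


def read_info(path):
    """Read back (metrics, config) by parsing each non-empty line separately."""
    with open(path) as f:
        lines = [line for line in f if line.strip()]
    metrics, config = (json.loads(line) for line in lines)
    return metrics, config


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'info_training.txt'
    write_info(path, {'loss': [1.0, 0.5]}, {'beta': 0.1, 'epochs': 60})
    metrics, config = read_info(path)
    print(config['epochs'])  # 60
```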


## Summary

Training steps completed:
1. ✓ Configure model hyperparameters
2. ✓ Load preprocessed data
3. ✓ Initialise model and optimiser
4. ✓ Train and save the model
5. ✓ Analyse the learned latent space
6. ✓ Save model's configuration and losses

The trained model is now ready for inference and analysis.