# Model Training for Lifespan Prediction

This notebook demonstrates how to:
1. Load configuration from YAML
2. Create datasets and dataloaders
3. Initialize model and trainer
4. Run training loop
5. Visualize results

This uses the refactored `lifespan_predictor` package for clean, modular training.

## 1. Setup and Imports

In [None]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader

# Import from the refactored package
from lifespan_predictor.config import Config
from lifespan_predictor.data.dataset import LifespanDataset
from lifespan_predictor.models.predictor import LifespanPredictor
from lifespan_predictor.training.trainer import Trainer
from lifespan_predictor.training.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler
from lifespan_predictor.training.metrics import AUC, Accuracy, F1Score, RMSE, MAE, R2Score
from lifespan_predictor.utils.logging import setup_logger
from lifespan_predictor.utils.visualization import plot_training_curves, plot_predictions

# Setup logging: one notebook-wide logger that every later cell reports through.
LOG_LEVEL = "INFO"
logger = setup_logger("training", level=LOG_LEVEL)
logger.info("Starting model training notebook")

## 2. Load Configuration

Load the configuration from the YAML file and validate it before it is used.

In [None]:
# Load configuration
config_path = "../lifespan_predictor/config/default_config.yaml"
config = Config.from_yaml(config_path)

# Validate configuration before any later cell depends on it
config.validate()

# Seed the global NumPy and PyTorch RNGs so weight initialisation and DataLoader
# shuffling are reproducible under Restart & Run All. torch.manual_seed also seeds
# the CUDA generators. The train/val split is seeded separately via random_state.
if config.random_seed is not None:
    np.random.seed(config.random_seed)
    torch.manual_seed(config.random_seed)

# Display key parameters
logger.info(f"Configuration loaded from: {config_path}")
logger.info(f"Task: {config.training.task}")
logger.info(f"Batch size: {config.training.batch_size}")
logger.info(f"Max epochs: {config.training.max_epochs}")
logger.info(f"Learning rate: {config.training.learning_rate}")
logger.info(f"Random seed: {config.random_seed}")

# The config only expresses a *preference* for CUDA; resolve the device that is
# actually available and say so explicitly when the preference cannot be honoured.
cuda_available = torch.cuda.is_available()
if config.device.use_cuda and not cuda_available:
    logger.warning("CUDA requested in config but not available; falling back to CPU")
device = torch.device('cuda' if config.device.use_cuda and cuda_available else 'cpu')
logger.info(f"Using device: {device}")

## 3. Load Preprocessed Data

Load the SMILES strings, graph features (adjacency and node features), fingerprints, and labels produced by the preprocessing notebook (01).

In [None]:
# Load training data
train_data_dir = os.path.join(config.data.output_dir, "train")

logger.info(f"Loading training data from: {train_data_dir}")

# Load CSV for SMILES
train_df = pd.read_csv(os.path.join(train_data_dir, "processed_data.csv"))
train_smiles = train_df[config.data.smiles_column].tolist()

# Load preprocessed features
train_adj = np.load(os.path.join(train_data_dir, "adj.npy"))
train_features = np.load(os.path.join(train_data_dir, "features.npy"))
train_labels = np.load(os.path.join(train_data_dir, "labels.npy"))
train_fp_hashed = np.load(os.path.join(train_data_dir, "fp_hashed.npy"))
train_fp_nonhashed = np.load(os.path.join(train_data_dir, "fp_nonhashed.npy"))

# Sanity check: every array must describe the same molecules as the CSV.
# Assumes axis 0 of each saved array indexes molecules in CSV row order (TODO confirm
# against the preprocessing notebook). Without this check a length mismatch would
# silently misalign graphs, fingerprints and labels downstream.
n_molecules = len(train_smiles)
for array_name, array in {
    "adj": train_adj,
    "features": train_features,
    "labels": train_labels,
    "fp_hashed": train_fp_hashed,
    "fp_nonhashed": train_fp_nonhashed,
}.items():
    assert len(array) == n_molecules, (
        f"{array_name}.npy has {len(array)} rows but {n_molecules} SMILES were loaded"
    )

logger.info(f"Loaded {len(train_smiles)} training molecules")
logger.info(f"Features shape: adj={train_adj.shape}, features={train_features.shape}")
logger.info(f"Fingerprints shape: hashed={train_fp_hashed.shape}, non-hashed={train_fp_nonhashed.shape}")
logger.info(f"Labels shape: {train_labels.shape}")

## 4. Create Datasets and DataLoaders

Build a PyTorch Geometric dataset, split it into training and validation subsets, and wrap each subset in a DataLoader.

In [None]:
# Bundle the loaded arrays into the paired inputs LifespanDataset takes:
# (adjacency, node features) for the graph side, (hashed, non-hashed) fingerprints.
graph_features = (train_adj, train_features)
fingerprints = (train_fp_hashed, train_fp_nonhashed)

logger.info("Creating PyTorch Geometric dataset...")
dataset_inputs = dict(
    smiles_list=train_smiles,
    graph_features=graph_features,
    fingerprints=fingerprints,
    labels=train_labels,
)
full_dataset = LifespanDataset(**dataset_inputs)
logger.info(f"Dataset created with {len(full_dataset)} samples")

In [None]:
# Split into train and validation.
# Labels are used for stratification only for classification tasks with stratify
# enabled in the config; otherwise the split is purely random (but seeded).
use_stratify = config.training.stratify and config.training.task == "classification"
train_indices, val_indices = train_test_split(
    range(len(full_dataset)),
    test_size=config.training.val_split,
    random_state=config.random_seed,
    stratify=train_labels if use_stratify else None,
)

# Guard against leakage: the index sets must be disjoint and together cover the dataset.
assert not set(train_indices) & set(val_indices), "train/val indices overlap"
assert len(train_indices) + len(val_indices) == len(full_dataset), "split lost samples"

train_dataset = full_dataset[train_indices]
val_dataset = full_dataset[val_indices]

logger.info(f"Split: {len(train_dataset)} training, {len(val_dataset)} validation")

In [None]:
# Create dataloaders. Both share batch size and worker settings; only the training
# loader shuffles, while the validation loader keeps dataset order.
# num_workers=0 loads in the main process (kept for Windows compatibility).
loader_kwargs = dict(batch_size=config.training.batch_size, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

logger.info(f"Created dataloaders: {len(train_loader)} train batches, {len(val_loader)} val batches")

## 5. Initialize Model

Create the LifespanPredictor model with the specified configuration.

In [None]:
# Initialize model and move it to the resolved device (Module.to returns the module).
logger.info("Initializing model...")
model = LifespanPredictor(config).to(device)

# Count parameters: (size, requires_grad) per tensor, then total vs. trainable.
param_counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_counts)
trainable_params = sum(count for count, trainable in param_counts if trainable)

logger.info(f"Model initialized with {total_params:,} total parameters")
logger.info(f"Trainable parameters: {trainable_params:,}")

# Display model architecture
print("\nModel Architecture:")
print(model)

## 6. Setup Training Components

Configure the optimizer, the learning-rate scheduler, and the training callbacks (early stopping, checkpointing, LR scheduling).

In [None]:
# Setup optimizer: Adam with learning rate and weight decay taken from the config.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.training.learning_rate,
    weight_decay=config.training.weight_decay,
)

# Direction in which the monitored metric improves: 'max' for classification,
# 'min' for regression.
# NOTE(review): this keys off the task, not config.training.main_metric; if the
# regression main metric were e.g. R2, 'max' would be correct — confirm.
metric_mode = 'max' if config.training.task == 'classification' else 'min'

# Halve the LR once the metric has not improved for more than `patience` scheduler
# steps. `verbose` is deliberately not passed: it has been deprecated for LR schedulers
# since PyTorch 2.2 and emits a warning. The current LR can be read from
# optimizer.param_groups if needed.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode=metric_mode,
    factor=0.5,
    patience=5,
)

logger.info("Optimizer and scheduler configured")

In [None]:
# Setup callbacks. Early stopping and checkpointing watch the same metric in the
# same direction ('max' for classification, 'min' otherwise); the LR scheduler is
# driven through its own callback.
metric_mode = 'max' if config.training.task == 'classification' else 'min'
monitor = dict(metric_name=config.training.main_metric, mode=metric_mode)

callbacks = [
    EarlyStopping(patience=config.training.patience, **monitor),
    ModelCheckpoint(save_dir=config.data.output_dir, **monitor),
    LearningRateScheduler(scheduler),
]

logger.info(f"Configured {len(callbacks)} callbacks")

## 7. Initialize Trainer and Start Training

Create the Trainer object and run the training loop.

In [None]:
# Initialize trainer: model + optimizer, the two data loaders, the callbacks from
# the previous cell, plus the shared config and resolved device.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    callbacks=callbacks,
    config=config,
    device=device,
)
logger.info("Trainer initialized")

In [None]:
# Start training; trainer.train() returns the per-epoch history used below.
logger.info("Starting training...")
rule = "=" * 60
print(f"\n{rule}\nTRAINING STARTED\n{rule}\n")

history = trainer.train()

print(f"\n{rule}\nTRAINING COMPLETED\n{rule}")

## 8. Visualize Training Results

Plot training curves and analyze model performance.

In [None]:
# Plot training curves and save the figure next to the other run artifacts.
curves_filename = "training_curves.png"
plot_save_path = os.path.join(config.data.output_dir, curves_filename)
plot_training_curves(history, save_path=plot_save_path)
logger.info(f"Training curves saved to: {plot_save_path}")

In [None]:
# Display training summary
def _summarize_series(key, reducer):
    """Return ``reducer(history[key])`` formatted to 4 decimals, or ``'N/A'``.

    A missing or empty series yields 'N/A' instead of a placeholder number:
    a default of 0 / inf would print like a genuine score, and an empty list
    would make max()/min() raise ValueError.
    """
    series = history.get(key)
    if series is None or len(series) == 0:
        return "N/A"
    return f"{reducer(series):.4f}"


def _last(series):
    """Final element of a per-epoch series."""
    return series[-1]


banner = "=" * 60
print(f"\n{banner}\nTRAINING SUMMARY\n{banner}")
print(f"\nTotal epochs: {len(history.get('train_loss', []))}")
print(f"Best epoch: {history.get('best_epoch', 'N/A')}")
print(f"\nFinal Training Loss: {_summarize_series('train_loss', _last)}")
print(f"Final Validation Loss: {_summarize_series('val_loss', _last)}")

# (display label, history key, reducer selecting the best value)
# NOTE(review): assumes the trainer records metrics as 'val_<name>' with exactly
# these names (e.g. 'val_F1', 'val_R2') — confirm against Trainer.
if config.training.task == 'classification':
    best_metrics = [("AUC", "val_AUC", max), ("Accuracy", "val_Accuracy", max), ("F1", "val_F1", max)]
else:
    best_metrics = [("RMSE", "val_RMSE", min), ("MAE", "val_MAE", min), ("R2", "val_R2", max)]

for position, (label, key, reducer) in enumerate(best_metrics):
    prefix = "\n" if position == 0 else ""
    print(f"{prefix}Best Validation {label}: {_summarize_series(key, reducer)}")

print("\n" + banner)

## 9. Evaluate on Validation Set

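Load the best checkpoint and generate predictions on the validation split. Note that this same split was used for early stopping and checkpoint selection, so these scores are optimistically biased; use a separate held-out test set for a final performance estimate.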

In [None]:
# Load best model.
# NOTE(review): assumes ModelCheckpoint (save_dir=config.data.output_dir) writes
# "best_model.pt" into that directory — confirm the filename against the callback.
best_model_path = os.path.join(config.data.output_dir, "best_model.pt")
if os.path.exists(best_model_path):
    logger.info(f"Loading best model from: {best_model_path}")
    # torch.load unpickles the file: only load checkpoints this project produced.
    # With PyTorch >= 2.6 torch.load defaults to weights_only=True, so a checkpoint
    # storing non-tensor objects may fail to load here — TODO confirm checkpoint contents.
    checkpoint = torch.load(best_model_path, map_location=device)
    # Accept either a training checkpoint dict with a 'model_state_dict' entry or a
    # bare state_dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    logger.info("Best model loaded successfully")
else:
    logger.warning("Best model checkpoint not found, using current model")

# Put the model in evaluation mode whichever weights are in use (affects modules
# such as dropout and batch-norm, if the model contains any).
model.eval()

In [None]:
# Make predictions on the validation set: gather per-batch outputs and targets on
# the CPU, then stack them into single arrays.
is_classification = config.training.task == 'classification'
pred_chunks, target_chunks = [], []

model.eval()
with torch.no_grad():
    for batch in val_loader:
        batch = batch.to(device)
        raw_outputs = model(batch)
        # Classification outputs are passed through a sigmoid (presumably the model
        # emits logits — confirm); regression outputs are used as-is.
        scores = torch.sigmoid(raw_outputs) if is_classification else raw_outputs
        pred_chunks.append(scores.cpu().numpy())
        target_chunks.append(batch.y.cpu().numpy())

val_predictions = np.concatenate(pred_chunks, axis=0)
val_targets = np.concatenate(target_chunks, axis=0)

logger.info(f"Generated predictions for {len(val_predictions)} validation samples")

In [None]:
# Plot validation predictions against the ground truth and save the figure.
pred_plot_path = os.path.join(config.data.output_dir, "validation_predictions.png")
plot_predictions(
    task=config.training.task,
    y_true=val_targets,
    y_pred=val_predictions,
    save_path=pred_plot_path,
)
logger.info(f"Prediction plot saved to: {pred_plot_path}")

## 10. Save Training Configuration and Results

In [None]:
# Persist the exact configuration and the per-epoch history alongside the model.
output_dir = config.data.output_dir
config_save_path = os.path.join(output_dir, "training_config.yaml")
history_save_path = os.path.join(output_dir, "training_history.csv")

config.save(config_save_path)
logger.info(f"Configuration saved to: {config_save_path}")

history_df = pd.DataFrame(history)
history_df.to_csv(history_save_path, index=False)
logger.info(f"Training history saved to: {history_save_path}")

rule = "=" * 60
print(f"\n{rule}\nAll results saved successfully!\nOutput directory: {output_dir}\n{rule}")