In [None]:
"""
Example Notebook 2: Model Training for MTL-GNN-DTA
This notebook demonstrates how to train the multi-task learning model
"""

# %% [markdown]
# # Model Training for MTL-GNN-DTA
# 
# This notebook demonstrates:
# 1. Loading prepared data
# 2. Initializing the MTL-DTA model
# 3. Training with multi-task learning
# 4. Monitoring training progress
# 5. Saving trained models

# %% [markdown]
# ## 1. Setup and Imports

# %%
import copy
import os
import sys
import warnings
warnings.filterwarnings('ignore')

# Add parent directory to path
sys.path.append('../../')

# Import MTL-GNN-DTA modules
from mtl_gnn_dta import (
    Config, 
    MTL_DTAModel,
    Trainer,
    create_data_loaders
)
from mtl_gnn_dta.data import MTL_DTA
from mtl_gnn_dta.features import ProteinFeaturizer, DrugFeaturizer
from mtl_gnn_dta.models import MaskedMSELoss
from mtl_gnn_dta.training import EarlyStopping
from mtl_gnn_dta.utils import setup_logging, plot_training_history

import torch
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import json
from tqdm import tqdm

# Logging, compute-device selection and RNG seeding.
setup_logging()

has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print(f"Using device: {device}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Seed torch and numpy so random draws (e.g. the mock graphs built below) are
# repeatable from run to run.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
if has_cuda:
    torch.cuda.manual_seed(SEED)

# %% [markdown]
# ## 2. Load Configuration and Data

# %%
# Project configuration (data paths, model and training settings).
config = Config()
print("Configuration loaded")

# Read the prepared train/val/test splits from config.data.processed_dir.
data_dir = Path(config.data.processed_dir)
train_data, val_data, test_data = (
    pd.read_parquet(data_dir / split_file)
    for split_file in ('train_data.parquet', 'val_data.parquet', 'test_data.parquet')
)

# task_ranges.json is read as a mapping whose keys are listed as the tasks below.
task_ranges = json.loads((data_dir / 'task_ranges.json').read_text())

print("Data loaded:")
for split_label, split_frame in (("Train", train_data), ("Validation", val_data), ("Test", test_data)):
    print(f"  {split_label}: {len(split_frame)} samples")
print(f"  Tasks: {list(task_ranges.keys())}")

# %% [markdown]
# ## 3. Create Data Loaders

# %%
# For demonstration, we'll create mock data loaders
# In practice, these would load actual protein structures and drug molecules

from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader

# Create mock data for demonstration
def create_mock_data(df, task_cols, n_residues=100, prot_emb_dim=1280,
                     n_protein_edges=300, n_atoms=20, drug_node_dim=66,
                     n_drug_edges=40, drug_edge_dim=6):
    """Create random protein/drug graph pairs for demonstration.

    One sample is produced per row of ``df``. The graphs are random
    placeholders; only the targets are taken from ``df``.

    Args:
        df: DataFrame whose rows supply the target values.
        task_cols: Task/column names; defines the order of the entries in ``y``.
        n_residues: Number of protein nodes per mock graph.
        prot_emb_dim: Protein node feature size (default 1280, an ESM-2
            embedding size). Presumably must equal ``config.model.prot_emb_dim``
            for the model to accept the graphs.
        n_protein_edges: Number of random protein edges.
        n_atoms: Number of drug nodes per mock graph.
        drug_node_dim: Drug node feature size. Presumably must equal
            ``config.model.drug_node_in_dim``.
        n_drug_edges: Number of random drug edges; also the number of
            ``edge_attr`` rows, so the two always stay consistent.
        drug_edge_dim: Drug edge feature size.

    Returns:
        List of dicts with keys ``'protein'`` and ``'drug'`` (PyG ``Data``) and
        ``'y'``, a 1-D tensor with one entry per task in ``task_cols``. Tasks
        absent from the row or NaN are stored as NaN (presumably so the masked
        loss can skip them).
    """
    data_list = []

    for _, row in df.iterrows():
        # Random protein graph: node features plus random connectivity.
        protein_data = Data(
            x=torch.randn(n_residues, prot_emb_dim),
            edge_index=torch.randint(0, n_residues, (2, n_protein_edges)),
        )

        # Random drug graph with exactly one edge_attr row per edge.
        drug_data = Data(
            x=torch.randn(n_atoms, drug_node_dim),
            edge_index=torch.randint(0, n_atoms, (2, n_drug_edges)),
            edge_attr=torch.randn(n_drug_edges, drug_edge_dim),
        )

        # Targets in task_cols order; missing/absent values become NaN.
        y = torch.tensor([
            float(row[task]) if task in row and not pd.isna(row[task]) else float('nan')
            for task in task_cols
        ])

        data_list.append({
            'protein': protein_data,
            'drug': drug_data,
            'y': y
        })

    return data_list

# Build small demo splits from the first rows of each DataFrame.
task_cols = config.model.task_names
N_TRAIN_DEMO, N_EVAL_DEMO = 100, 20  # subset sizes, kept small for speed

train_dataset = create_mock_data(train_data.head(N_TRAIN_DEMO), task_cols)
val_dataset = create_mock_data(val_data.head(N_EVAL_DEMO), task_cols)
test_dataset = create_mock_data(test_data.head(N_EVAL_DEMO), task_cols)

print(f"Created datasets with {len(train_dataset)} train, {len(val_dataset)} val, {len(test_dataset)} test samples")

# Custom collate function
def collate_batch(batch):
    """Merge a list of sample dicts into a single mini-batch dict.

    Each sample carries ``'protein'`` and ``'drug'`` graphs and a ``'y'`` target
    tensor (the layout produced by ``create_mock_data``). Graphs are merged with
    ``Batch.from_data_list``; targets are stacked along a new leading dimension.
    """
    protein_graphs = [sample['protein'] for sample in batch]
    drug_graphs = [sample['drug'] for sample in batch]
    targets = torch.stack([sample['y'] for sample in batch])

    return dict(
        protein=Batch.from_data_list(protein_graphs),
        drug=Batch.from_data_list(drug_graphs),
        y=targets,
    )

# Create data loaders.
# torch_geometric.loader.DataLoader silently discards a user-supplied
# `collate_fn` and always batches with its own Collater, so `collate_batch`
# would never run. The plain PyTorch DataLoader honours `collate_fn`.
BATCH_SIZE = 32

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch
)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch
)

print(f"Created data loaders with batch size {BATCH_SIZE}")

# %% [markdown]
# ## 4. Initialize Model

# %%
# Build the multi-task model for the tasks in task_cols; every architectural
# hyper-parameter is taken from config.model and moved to the chosen device.
model = MTL_DTAModel(
    task_names=task_cols,
    prot_emb_dim=config.model.prot_emb_dim,
    prot_gcn_dims=config.model.prot_gcn_dims,
    prot_fc_dims=config.model.prot_fc_dims,
    drug_node_in_dim=config.model.drug_node_in_dim,
    drug_node_h_dims=config.model.drug_node_h_dims,
    drug_fc_dims=config.model.drug_fc_dims,
    mlp_dims=config.model.mlp_dims,
    mlp_dropout=config.model.mlp_dropout
).to(device)

# Parameter counts: everything vs. only tensors that receive gradients.
param_list = list(model.parameters())
total_params = sum(p.numel() for p in param_list)
trainable_params = sum(p.numel() for p in param_list if p.requires_grad)

print("Model initialized:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Tasks: {task_cols}")

# %% [markdown]
# ## 5. Setup Training

# %%
# Adam optimizer; learning rate and weight decay come from the training config.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.training.learning_rate,
    weight_decay=config.training.weight_decay
)

# Multiply the LR by `scheduler_factor` once the loss passed to
# scheduler.step() has not improved for `scheduler_patience` epochs.
# `verbose=True` was removed: it is deprecated since PyTorch 2.2; read the
# current LR from optimizer.param_groups instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=config.training.scheduler_factor,
    patience=config.training.scheduler_patience,
)

# Masked MSE loss; task_ranges is forwarded to it (TODO confirm how it is used).
criterion = MaskedMSELoss(task_ranges=task_ranges).to(device)

# Early-stopping helper configured from the training config.
early_stopping = EarlyStopping(
    patience=config.training.patience,
    min_delta=config.training.min_delta
)

print("Training setup complete:")
# Report the optimizer actually built above. config.training.optimizer is not
# used to construct it, so printing that value could misreport the optimizer.
print(f"  Optimizer: {type(optimizer).__name__}")
print(f"  Learning rate: {config.training.learning_rate}")
print(f"  Loss function: MaskedMSELoss with task weighting")
print(f"  Early stopping patience: {config.training.patience}")

# %% [markdown]
# ## 6. Training Loop

# %%
# Per-epoch loss history plus the best validation loss / weights seen so far.
train_losses = []
val_losses = []
best_val_loss = float('inf')
best_model_state = None
best_epoch = None

# Training parameters
n_epochs = 20  # Reduced for demonstration


def _run_epoch(loader, train, desc=None):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``device``. With ``train=True`` the model is in training mode and the
    optimizer steps after every batch; with ``train=False`` the model is in eval
    mode and gradients are disabled. ``desc`` enables a tqdm progress bar.
    Returns 0 when ``loader`` yields no batches.
    """
    model.train(train)  # train(False) is equivalent to eval()
    total_loss = 0.0
    n_batches = 0
    batches = tqdm(loader, desc=desc) if desc is not None else loader

    with torch.set_grad_enabled(train):
        for batch in batches:
            protein_batch = batch['protein'].to(device)
            drug_batch = batch['drug'].to(device)
            y_batch = batch['y'].to(device)

            if train:
                optimizer.zero_grad()
            predictions = model(drug_batch, protein_batch)
            loss = criterion(predictions, y_batch)
            if train:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            n_batches += 1

    return total_loss / n_batches if n_batches > 0 else 0


print(f"\nStarting training for {n_epochs} epochs...")
print("="*60)

for epoch in range(n_epochs):
    # Training pass (with progress bar), then a gradient-free validation pass.
    avg_train_loss = _run_epoch(
        train_loader, train=True, desc=f"Epoch {epoch+1}/{n_epochs} - Training"
    )
    train_losses.append(avg_train_loss)

    avg_val_loss = _run_epoch(val_loader, train=False)
    val_losses.append(avg_val_loss)

    # Let the LR scheduler react to the validation loss.
    scheduler.step(avg_val_loss)

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch + 1
        # state_dict() returns tensors that share storage with the live
        # parameters, and dict.copy() is shallow, so later optimizer steps would
        # silently overwrite this snapshot. Deep-copy to freeze the weights.
        best_model_state = copy.deepcopy(model.state_dict())

    early_stopping(avg_val_loss)

    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch+1}/{n_epochs}: "
          f"Train Loss: {avg_train_loss:.4f}, "
          f"Val Loss: {avg_val_loss:.4f}, "
          f"Best Val Loss: {best_val_loss:.4f}, "
          f"LR: {current_lr:.2e}")

    if early_stopping.early_stop:
        print(f"Early stopping triggered at epoch {epoch+1}")
        break

# Restore the weights from the epoch with the lowest validation loss.
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"\nLoaded best model from epoch {best_epoch} with validation loss: {best_val_loss:.4f}")

# %% [markdown]
# ## 7. Plot Training History

# %%
# Plot the loss curves on a linear and a log y-axis. Epochs are numbered from 1
# so the x-axis matches the "Epoch k/N" lines printed during training.
epoch_numbers = range(1, len(train_losses) + 1)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
panel_specs = [
    (axes[0], 'Loss', 'Training History', 'linear'),
    (axes[1], 'Loss (log scale)', 'Training History (Log Scale)', 'log'),
]
for ax, ylabel, title, yscale in panel_specs:
    ax.plot(epoch_numbers, train_losses, label='Train Loss', marker='o')
    ax.plot(epoch_numbers, val_losses, label='Validation Loss', marker='s')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_yscale(yscale)
    ax.legend()
    ax.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

# %% [markdown]
# ## 8. Evaluate on Test Set

# %%
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error

# Run the current model over the test split; predictions are moved back to the
# CPU, targets never leave it.
model.eval()
pred_batches = []
target_batches = []

with torch.no_grad():
    for batch in test_loader:
        protein_batch = batch['protein'].to(device)
        drug_batch = batch['drug'].to(device)
        y_batch = batch['y']

        predictions = model(drug_batch, protein_batch)
        pred_batches.append(predictions.cpu())
        target_batches.append(y_batch)

# Stack into (n_samples, n_tasks)-style numpy arrays for sklearn.
test_predictions = torch.cat(pred_batches, dim=0).numpy()
test_targets = torch.cat(target_batches, dim=0).numpy()

# Per-task metrics, computed only over samples whose target is not NaN.
print("\nTest Set Performance:")
print("="*60)

for task_idx, task in enumerate(task_cols):
    valid = ~np.isnan(test_targets[:, task_idx])
    if not valid.any():
        continue  # no labelled test samples for this task

    y_true = test_targets[valid, task_idx]
    y_pred = test_predictions[valid, task_idx]

    r2 = r2_score(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)

    print(f"{task:10s}: R²={r2:.3f}, RMSE={rmse:.3f}, MAE={mae:.3f} (n={valid.sum()})")

# %% [markdown]
# ## 9. Save Trained Model

# %%
# Persist the checkpoint and the per-epoch loss history under
# config.training.checkpoint_dir (created if missing).
checkpoint_dir = Path(config.training.checkpoint_dir)
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Model weights are whatever `model` holds now (the training cell reloads
# best_model_state when one was recorded); 'epoch', 'train_loss' and
# 'val_loss' describe the final epoch that ran.
checkpoint = dict(
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    epoch=epoch + 1,
    train_loss=train_losses[-1],
    val_loss=val_losses[-1],
    task_cols=task_cols,
    task_ranges=task_ranges,
    config=config.to_dict(),
)

checkpoint_path = checkpoint_dir / 'model_checkpoint.pt'
torch.save(checkpoint, checkpoint_path)
print(f"\nModel saved to {checkpoint_path}")

# Loss curves as JSON next to the checkpoint.
history_record = {
    'train_losses': train_losses,
    'val_losses': val_losses,
    'best_val_loss': best_val_loss,
}
history_path = checkpoint_dir / 'training_history.json'
with open(history_path, 'w') as history_file:
    json.dump(history_record, history_file, indent=2)
print(f"Training history saved to {history_path}")

# %% [markdown]
# ## 10. Next Steps
# 
# Now that the model is trained, you can:
# 1. Move to `03_analysis.ipynb` for detailed analysis
# 2. Use the trained model for predictions on new data
# 3. Tune hyperparameters to improve performance
# 4. Implement cross-validation for more robust evaluation
# 
# The trained model can be loaded and used for predictions with:
# ```python
# from mtl_gnn_dta import AffinityPredictor
# predictor = AffinityPredictor(model_path='path/to/checkpoint.pt')
# predictions = predictor.predict_from_files(protein_path, ligand_path)
# ```