# In [None]:
"""
Model Training Example Notebook
"""

# Cell 1: Setup
import sys
# Append the parent directory to the module search path. The entry is relative,
# so it is resolved against the current working directory when imports run.
# Presumably this is what lets the `gnn_dta_mtl` imports below resolve when the
# notebook is launched from a subdirectory of the repo - confirm against the
# project layout.
sys.path.append('..')

import torch
import pandas as pd
from gnn_dta_mtl import MTL_DTAModel, MTLTrainer, CrossValidator
from gnn_dta_mtl.datasets import build_mtl_dataset_optimized
from gnn_dta_mtl.features import StructureChunkLoader
from gnn_dta_mtl.utils import prepare_mtl_experiment

# Pick the compute device: CUDA when torch reports it as available, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Cell 2: Load Data
# Columns of the dataframe that are used as the prediction tasks.
task_cols = [
    'pKi',
    'pEC50',
    'pKd',
    'pIC50',
]

# Read the prepared (standardized) input table.
df = pd.read_parquet(
    "../data/standardized/standardized_input.parquet"
)

# Per-task ranges computed by the project helper from the dataframe and the
# task columns; they are handed to the trainer in Cell 8.
task_ranges = prepare_mtl_experiment(df, task_cols)

# Cell 3: Load Structure Chunks
# Build the chunk loader over the structure-chunk directory. `cache_size=2` is
# passed through as-is (presumably the number of chunks kept in memory -
# confirm in StructureChunkLoader).
chunk_loader = StructureChunkLoader(
    cache_size=2,
    chunk_dir="../data/structure_chunks/",
)

# Report how many structure IDs the loader exposes.
n_structures = len(chunk_loader.get_available_pdb_ids())
print(f"Loaded {n_structures} protein structures")

# Cell 4: Prepare Dataset
# `os.path` is used below but `os` is not imported anywhere else in this
# notebook; without this import the cell fails with a NameError.
import os

# IDs of the structures the chunk loader can serve.
available_pdb_ids = chunk_loader.get_available_pdb_ids()

# Derive a protein ID from each structure path: the file name with its
# directory and final extension stripped.
# Assumes `standardized_protein_pdb` holds path strings - confirm upstream.
df['protein_id'] = df['standardized_protein_pdb'].apply(
    lambda p: os.path.splitext(os.path.basename(p))[0]
)

# Keep only rows whose protein ID has an available structure, and renumber
# the index so it is contiguous again.
df_clean = df[df['protein_id'].isin(available_pdb_ids)].reset_index(drop=True)

print(f"Dataset size: {len(df_clean)}")

# Cell 5: Split Data
from sklearn.model_selection import train_test_split

# First hold out 20% of the rows as the test set.
train_valid_df, test_df = train_test_split(
    df_clean,
    test_size=0.2,
    random_state=42,
)

# Then carve the validation set out of the remaining 80%:
# 0.125 * 0.8 = 0.1 of all rows, i.e. roughly a 70/10/20 split overall.
train_df, valid_df = train_test_split(
    train_valid_df,
    test_size=0.125,
    random_state=42,
)

print(f"Train: {len(train_df)}, Valid: {len(valid_df)}, Test: {len(test_df)}")

# Cell 6: Create Datasets
from gnn_dta_mtl.datasets import build_mtl_dataset_optimized
import torch_geometric

# Build one dataset per split (train, valid, test - in that order), all
# sharing the same chunk loader and task columns.
train_dataset, valid_dataset, test_dataset = (
    build_mtl_dataset_optimized(split_df, chunk_loader, task_cols)
    for split_df in (train_df, valid_df, test_df)
)

# Create data loaders
batch_size = 128

# Only the training loader shuffles; validation and test keep dataset order.
# All loaders load in the main process (num_workers=0).
train_loader, valid_loader, test_loader = (
    torch_geometric.loader.DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=do_shuffle,
        num_workers=0,
    )
    for split_dataset, do_shuffle in (
        (train_dataset, True),
        (valid_dataset, False),
        (test_dataset, False),
    )
)

# Cell 7: Initialize Model
# Architecture hyperparameters, grouped by sub-network. The same values are
# written into the checkpoint's 'model_config' entry in Cell 11.
model_hparams = dict(
    # Protein branch
    prot_emb_dim=1280,
    prot_gcn_dims=[128, 256, 256],
    prot_fc_dims=[1024, 128],
    # Drug branch
    drug_node_in_dim=[66, 1],
    drug_node_h_dims=[128, 64],
    drug_fc_dims=[1024, 128],
    # Prediction head
    mlp_dims=[1024, 512],
    mlp_dropout=0.25,
)

# One model covering all tasks listed in `task_cols`.
model = MTL_DTAModel(task_names=task_cols, **model_hparams)

# Total number of elements across all parameter tensors.
n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")

# Cell 8: Train Model
# Wrap the model in the project trainer, reusing the task list, task ranges,
# device and batch size defined in the earlier cells.
trainer = MTLTrainer(
    model=model, device=device,
    task_cols=task_cols, task_ranges=task_ranges,
    batch_size=batch_size, learning_rate=0.0005,
)

# Run training for at most 100 epochs, validating on the held-out validation
# loader. `patience=20` presumably controls early stopping - confirm in
# MTLTrainer.train. The returned value is kept as the training history.
history = trainer.train(
    train_loader=train_loader, valid_loader=valid_loader,
    n_epochs=100, patience=20,
    verbose=True,
)

# Cell 9: Evaluate on Test Set
# Predictions come back as a mapping: task name -> dict holding at least
# 'targets' and 'predictions'.
test_results = trainer.predict(test_loader)

# Print results
from gnn_dta_mtl.evaluation import calculate_metrics

print("\nTest Set Results:")
print("=" * 50)

for task_name, task_output in test_results.items():
    y_true = task_output['targets']
    y_pred = task_output['predictions']
    scores = calculate_metrics(y_true, y_pred)

    # One report per task: three metrics plus the number of test samples.
    report_lines = (
        f"{task_name}:",
        f"  RMSE: {scores['rmse']:.3f}",
        f"  R²: {scores['r2']:.3f}",
        f"  MAE: {scores['mae']:.3f}",
        f"  Samples: {len(y_true)}",
    )
    print("\n".join(report_lines))

# Cell 10: Visualize Results
# `plt` is used below but was never imported anywhere in this notebook, so the
# cell failed with a NameError; import pyplot here, next to its only use.
import matplotlib.pyplot as plt

from gnn_dta_mtl.evaluation import plot_predictions

# Plot predictions for every task from the test-set results.
fig = plot_predictions(test_results, task_cols)
# pyplot acts on the current figure - presumably the one that
# plot_predictions just created; confirm if the layout looks off.
plt.tight_layout()
plt.show()

# Cell 11: Save Model
import os

checkpoint_path = '../models/trained_model.pt'

# torch.save does not create missing parent directories. Create the target
# directory first so a fresh checkout does not lose the trained model to a
# failed save at the very end of the run.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# The checkpoint bundles the weights with what is needed to rebuild the model:
# the task list, the task ranges and the architecture hyperparameters.
# NOTE: 'model_config' repeats the values passed to MTL_DTAModel in Cell 7;
# keep the two in sync when changing the architecture.
torch.save({
    'model_state_dict': model.state_dict(),
    'task_cols': task_cols,
    'task_ranges': task_ranges,
    'model_config': {
        'prot_emb_dim': 1280,
        'prot_gcn_dims': [128, 256, 256],
        'prot_fc_dims': [1024, 128],
        'drug_node_in_dim': [66, 1],
        'drug_node_h_dims': [128, 64],
        'drug_fc_dims': [1024, 128],
        'mlp_dims': [1024, 512],
        'mlp_dropout': 0.25
    }
}, checkpoint_path)

print("Model saved successfully!")