# Complete MTL-GNN-DTA Pipeline: Data Preparation, Training, and Analysis

This comprehensive notebook provides the complete pipeline for:
1. Data preparation and standardization
2. Model training with multi-task learning
3. Model evaluation and analysis

---

## Part 1: Data Preparation and Standardization

In [None]:
# Setup and imports
import os
import sys
import warnings
warnings.filterwarnings('ignore')

# Add parent directory to path
sys.path.append('../../')

# Standard imports
import pandas as pd
import numpy as np
from pathlib import Path
import json
import pickle
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from scipy import stats
import math
from joblib import Parallel, delayed
from collections import Counter
import tempfile
import shutil
from datetime import datetime
import time

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML

# Chemistry
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, Crippen, Lipinski, rdMolDescriptors, QED, Draw
from rdkit.Chem.rdPartialCharges import ComputeGasteigerCharges
from rdkit.Chem.MolStandardize import rdMolStandardize
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.warning')

# Protein processing
from Bio.PDB import PDBParser, PDBIO, Select, is_aa
from Bio.SeqUtils import seq1
from pdbfixer import PDBFixer
from openmm.app import PDBFile, Modeller, element as elem

# PyTorch and PyG
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader as PyGDataLoader
from torch_geometric.nn import GCNConv, GINConv, global_mean_pool, global_max_pool, global_add_pool

# Setup
# Keep one core free, but never fall below one worker: on a single-core host
# cpu_count() - 1 is 0, and joblib refuses n_jobs=0.
N_PROC = max(1, cpu_count() - 1)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
POLAR_HEAVY = {7, 8, 15, 16}  # N, O, P, S

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (10, 6)

print("MTL-GNN-DTA Complete Pipeline")
print("=" * 60)
print(f"Using {N_PROC} CPU cores")
print(f"Using device: {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Helper functions for data preparation

def dg_to_kd(delta_g_kcal, temp_k=298.15):
    """Convert a binding free energy to ``-log10(exp(dG / (R * T)))``.

    The value is computed in closed form, ``-dG / (R * T * ln 10)``, which is
    algebraically identical to taking ``-log10`` of ``exp(dG / (R * T))`` but
    cannot overflow (large positive dG) or hit ``log10(0)`` (large negative dG).

    Args:
        delta_g_kcal: Free energy in kcal/mol.
        temp_k: Temperature in kelvin.

    Returns:
        The negative base-10 logarithm of ``exp(delta_g_kcal / (R * temp_k))``.
    """
    R = 1.987e-3  # gas constant in kcal / (mol * K)
    return -delta_g_kcal / (R * temp_k * math.log(10))

def keep_only_polar_hydrogens(mol):
    """Strip hydrogens whose first neighbour is not in POLAR_HEAVY, then
    add explicit hydrogens on POLAR_HEAVY atoms that still have implicit ones.

    Hydrogens with no neighbours at all are left untouched.

    Args:
        mol: RDKit molecule.

    Returns:
        The (possibly new) RDKit molecule.
    """
    nonpolar_h = []
    for atom in mol.GetAtoms():
        if atom.GetAtomicNum() != 1:
            continue
        nbrs = atom.GetNeighbors()
        if not nbrs:
            continue
        if nbrs[0].GetAtomicNum() in POLAR_HEAVY:
            continue
        nonpolar_h.append(atom.GetIdx())

    if nonpolar_h:
        editable = Chem.EditableMol(mol)
        # Remove from the highest index down so earlier removals do not
        # shift the indices still waiting to be removed.
        nonpolar_h.sort(reverse=True)
        for atom_idx in nonpolar_h:
            editable.RemoveAtom(atom_idx)
        mol = editable.GetMol()

    mol.UpdatePropertyCache(strict=False)
    polar_with_implicit_h = [
        atom.GetIdx()
        for atom in mol.GetAtoms()
        if atom.GetAtomicNum() in POLAR_HEAVY and atom.GetNumImplicitHs() > 0
    ]
    if polar_with_implicit_h:
        mol = Chem.AddHs(mol, addCoords=True, onlyOnAtoms=polar_with_implicit_h)

    return mol

def standardize_smiles_from_sdf(sdf_path):
    """Read a mol/SDF file and return a standardized isomeric canonical SMILES.

    Steps: sanitize (falling back to ``rdMolStandardize.Cleanup`` when
    sanitization reports a failure), keep only polar hydrogens, normalize,
    take the fragment parent, canonicalize the tautomer, clear isotope labels
    and reassign stereochemistry.

    Args:
        sdf_path: Path to the input file (``str`` or path-like).

    Returns:
        The SMILES string, or ``None`` if the file cannot be read or any
        standardization step fails.
    """
    try:
        mol = Chem.MolFromMolFile(str(sdf_path), removeHs=False, sanitize=False)
        if mol is None:
            return None

        # With catchErrors=True the failing step is returned instead of raised.
        sanitize_result = Chem.SanitizeMol(mol, catchErrors=True)
        if sanitize_result != Chem.SanitizeFlags.SANITIZE_NONE:
            mol = rdMolStandardize.Cleanup(mol)
            if mol is None:
                return None

        mol = keep_only_polar_hydrogens(mol)
        mol = rdMolStandardize.Normalizer().normalize(mol)
        mol = rdMolStandardize.FragmentParent(mol)
        mol = rdMolStandardize.TautomerEnumerator().Canonicalize(mol)

        for atom in mol.GetAtoms():
            atom.SetIsotope(0)

        Chem.AssignStereochemistry(mol, force=True, cleanIt=True)
        return Chem.MolToSmiles(mol, isomericSmiles=True, canonical=True)
    except Exception:
        # Best-effort: any RDKit failure maps to None. Unlike a bare `except:`,
        # KeyboardInterrupt / SystemExit still propagate so long batch runs can
        # be interrupted.
        return None

def compute_props(smiles):
    """Compute a fixed set of RDKit descriptors for one SMILES string.

    Args:
        smiles: SMILES string; anything that is not a non-blank ``str`` is
            treated as missing.

    Returns:
        Dict with keys InChIKey, MolWt, HeavyAtomCount, QED, NumHDonors,
        NumHAcceptors, NumRotatableBonds, TPSA and LogP. Every value is
        ``None`` when the input is missing or cannot be parsed.
    """
    prop_names = (
        'InChIKey', 'MolWt', 'HeavyAtomCount', 'QED',
        'NumHDonors', 'NumHAcceptors', 'NumRotatableBonds',
        'TPSA', 'LogP',
    )

    mol = None
    if isinstance(smiles, str) and smiles.strip() != '':
        mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return dict.fromkeys(prop_names)

    values = (
        Chem.MolToInchiKey(mol),
        Descriptors.MolWt(mol),
        mol.GetNumHeavyAtoms(),
        QED.qed(mol),
        Lipinski.NumHDonors(mol),
        Lipinski.NumHAcceptors(mol),
        Lipinski.NumRotatableBonds(mol),
        rdMolDescriptors.CalcTPSA(mol),
        Crippen.MolLogP(mol),
    )
    return dict(zip(prop_names, values))

def clean_protein_structure(pdb_path, output_path, ph=7.4, remove_water=True):
    """Repair a PDB structure with PDBFixer and write it with only
    non-carbon-bound hydrogens.

    PDBFixer replaces non-standard residues, removes heterogens (water is
    kept only when ``remove_water`` is False), adds missing atoms and adds
    hydrogens at the given pH. Hydrogens bonded to carbon are then deleted
    before the structure is written.

    Args:
        pdb_path: Input PDB file.
        output_path: Destination PDB file.
        ph: pH passed to ``addMissingHydrogens``.
        remove_water: Drop water molecules together with other heterogens.

    Returns:
        ``True`` on success, ``False`` if any step raised.
    """
    try:
        fixer = PDBFixer(filename=pdb_path)
        fixer.findMissingResidues()
        fixer.findNonstandardResidues()
        fixer.replaceNonstandardResidues()
        fixer.removeHeterogens(keepWater=not remove_water)
        fixer.findMissingAtoms()
        fixer.addMissingAtoms()
        fixer.addMissingHydrogens(pH=ph)

        mod = Modeller(fixer.topology, fixer.positions)

        # One pass over the bond list instead of re-scanning every bond for
        # every hydrogen (the previous form was O(atoms * bonds)).
        to_delete = []
        seen = set()
        for a1, a2 in mod.topology.bonds():
            for hydrogen, partner in ((a1, a2), (a2, a1)):
                if (hydrogen.element == elem.hydrogen
                        and partner.element == elem.carbon
                        and hydrogen.index not in seen):
                    seen.add(hydrogen.index)
                    to_delete.append(hydrogen)

        if to_delete:
            mod.delete(to_delete)

        with open(output_path, 'w') as f:
            PDBFile.writeFile(mod.topology, mod.positions, f)
        return True
    except Exception:
        # Best-effort cleaning: report failure to the caller, but let
        # KeyboardInterrupt / SystemExit through (a bare `except:` would not).
        return False

def standardize_ligand(sdf_path, output_path):
    """Standardize a ligand file and write the result as SDF.

    The molecule is sanitized, cleaned up with ``rdMolStandardize.Cleanup``,
    reduced to polar hydrogens, given a conformer when it has none (seeded
    embedding), and Gasteiger charges are computed before writing.

    Args:
        sdf_path: Input mol/SDF file (``str`` or path-like).
        output_path: Destination SDF file (``str`` or path-like).

    Returns:
        ``True`` on success, ``False`` if the file could not be read or any
        step raised.
    """
    try:
        mol = Chem.MolFromMolFile(str(sdf_path), removeHs=False, sanitize=False)
        if mol is None:
            return False

        Chem.SanitizeMol(mol)
        mol = rdMolStandardize.Cleanup(mol)
        mol = keep_only_polar_hydrogens(mol)
        Chem.AssignStereochemistry(mol, cleanIt=False, force=True)

        if mol.GetNumConformers() == 0:
            AllChem.EmbedMolecule(mol, randomSeed=42)

        ComputeGasteigerCharges(mol)

        writer = Chem.SDWriter(str(output_path))
        try:
            writer.write(mol)
        finally:
            # Always release the file handle, even when write() raises.
            writer.close()
        return True
    except Exception:
        # Best-effort: failures map to False; KeyboardInterrupt / SystemExit
        # are no longer swallowed as they were by the bare `except:`.
        return False

In [None]:
# Load and combine datasets
# NOTE: this cell does not read anything from data_dir; it only makes sure the
# directory exists and then always builds a synthetic demonstration table.
print("Loading datasets...")

data_dir = Path("../data/curated/exp/")
data_dir.mkdir(parents=True, exist_ok=True)

# Synthetic sample data (generated unconditionally). The seed plus the order
# of the np.random calls below fixes the generated values, so do not reorder.
print("Creating sample dataset for demonstration...")
np.random.seed(42)

n_samples = 1000
df_combined = pd.DataFrame({
    'protein_pdb_path': [f'protein_{i:04d}.pdb' for i in range(n_samples)],
    'ligand_sdf_path': [f'ligand_{i:04d}.sdf' for i in range(n_samples)],
    # Every row carries the same SMILES string
    'smiles': ['CC(C)CC1=CC=C(C=C1)C(C)C(O)=O'] * n_samples,
    'pKi': np.random.normal(7.0, 1.5, n_samples),
    'pKd': np.random.normal(7.2, 1.3, n_samples),
    'pIC50': np.random.normal(6.8, 1.4, n_samples),
    'pEC50': np.random.normal(6.5, 1.2, n_samples),
    'resolution': np.random.uniform(1.5, 2.5, n_samples),
    'source_file': 'sample_data',
    'is_experimental': True
})

# Blank out roughly 20% of each task column, independently per task
for col in ['pKi', 'pKd', 'pIC50', 'pEC50']:
    mask = np.random.random(n_samples) < 0.2
    df_combined.loc[mask, col] = np.nan

# Treat infinities as missing values
df_combined.replace([np.inf, -np.inf], np.nan, inplace=True)

print(f"Dataset shape: {df_combined.shape}")
print(f"Columns: {df_combined.columns.tolist()}")

In [None]:
# Initial data filtering
# Rows are kept when a value is missing or lies strictly inside the allowed
# open interval; only present, out-of-range values cause a row to be dropped.
print("\nFiltering data...")
print(f"Starting samples: {len(df_combined)}")

task_cols = ['pKi', 'pKd', 'pIC50', 'pEC50']
for col in task_cols:
    if col not in df_combined.columns:
        continue
    _activity = df_combined[col]
    _plausible = (_activity > 3) & (_activity < 15)
    df_combined = df_combined[_activity.isna() | _plausible]

if 'resolution' in df_combined.columns:
    _resolution = df_combined['resolution']
    _acceptable = (_resolution > 0) & (_resolution < 3)
    df_combined = df_combined[_resolution.isna() | _acceptable]

print(f"Samples after filtering: {len(df_combined)}")

In [None]:
# Standardize SMILES and compute properties
print("\nStandardizing SMILES and computing properties...")


def _canonical_smiles(smi):
    """Return the canonical SMILES for *smi*, or None if missing/unparseable.

    The string is parsed once (it used to be parsed twice per row).
    """
    if pd.isna(smi):
        return None
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, canonical=True)


# For demonstration, use existing SMILES
df_combined['std_smiles'] = df_combined['smiles'].apply(_canonical_smiles)

df_combined = df_combined[df_combined['std_smiles'].notna()]

# Compute molecular properties
# joblib rejects n_jobs=0, so clamp to at least one worker.
props = Parallel(n_jobs=max(1, min(4, N_PROC)))(
    delayed(compute_props)(smi) for smi in tqdm(df_combined['std_smiles'].tolist(), desc="Properties")
)
props_df = pd.DataFrame(props)
# Reset the index so the property rows line up positionally with the molecules
df_combined = pd.concat([df_combined.reset_index(drop=True), props_df], axis=1)

print(f"Computed properties for {len(df_combined)} molecules")

In [None]:
# Calculate ligand efficiency
# LE_<task> = activity / heavy-atom count, computed column-wise. Compared with
# the former row-by-row apply this is faster and always produces float columns
# (NaN for missing values), even when a task column is entirely empty.
print("\nCalculating ligand efficiency...")

# Missing or non-positive heavy-atom counts become NaN so the division yields
# NaN rather than inf or a negative efficiency.
_heavy_atoms = pd.to_numeric(df_combined['HeavyAtomCount'], errors='coerce')
_heavy_atoms = _heavy_atoms.where(_heavy_atoms > 0)

for col in task_cols:
    if col in df_combined.columns:
        le_col = f'LE_{col}'
        df_combined[le_col] = pd.to_numeric(df_combined[col], errors='coerce') / _heavy_atoms

# Overall LE is the mean over whichever per-task efficiencies are available
le_cols = [c for c in df_combined.columns if c.startswith("LE_")]
if le_cols:
    df_combined['LE'] = df_combined[le_cols].mean(axis=1, skipna=True)

print("Ligand efficiency calculated")

In [None]:
# Quality-based filtering
print("\nApplying quality filters...")

def count_carbon_atoms(smiles):
    """Count carbon atoms (aliphatic ``C`` and aromatic ``c``) in a SMILES string.

    The string is tokenised into atoms so that two-letter element symbols
    that merely start with "C" (``Cl`` outside brackets; ``[Cl-]``, ``[Ca+2]``,
    ``[Cu]`` ... inside brackets) are not mistaken for carbon, which a plain
    ``smiles.count('C')`` does.

    Args:
        smiles: SMILES string, or a missing value (``None`` / NaN).

    Returns:
        Number of carbon atoms; 0 for a missing value.
    """
    if pd.isna(smiles):
        return 0

    import re  # local import: `re` is not imported at the top of this file

    # Alternative 1: a bracket atom -> optional isotope digits, then the
    #   element symbol (captured), then the rest of the bracket.
    # Alternative 2: an unbracketed atom; the two-letter halogens must be
    #   tried before the single-letter fallback.
    atom_pattern = r'\[\d*([A-Za-z][a-z]?)[^\]]*\]|(Cl|Br|[A-Za-z])'
    return sum(
        1
        for match in re.finditer(atom_pattern, smiles)
        if (match.group(1) or match.group(2)) in ('C', 'c')
    )

df_combined['carbon_count'] = df_combined['std_smiles'].map(count_carbon_atoms)

# Apply filters
# A row is rejected as soon as any single criterion flags it; comparisons
# against missing values evaluate to False and therefore never reject.
bad_filter = df_combined['carbon_count'] < 4
bad_filter |= df_combined['HeavyAtomCount'] < 5
bad_filter |= df_combined['HeavyAtomCount'] > 75
bad_filter |= df_combined['MolWt'] > 1000

if 'LE' in df_combined.columns:
    bad_filter |= df_combined['LE'] <= 0.05
    bad_filter |= df_combined['LE'] >= 0.7

df_good = df_combined.loc[~bad_filter].reset_index(drop=True)

print(f"Filtered out {bad_filter.sum()} samples")
print(f"Remaining good samples: {len(df_good)}")

In [None]:
# Train/validation/test split
# Two seeded random splits: 20% of df_good is held out as the test set, then
# 10% of the remainder becomes the validation set.
print("\nSplitting data...")

train_val_df, test_df = train_test_split(df_good, test_size=0.2, random_state=42)
train_df, val_df = train_test_split(train_val_df, test_size=0.1, random_state=42)

print(f"Train: {len(train_df)} samples")
print(f"Validation: {len(val_df)} samples")
print(f"Test: {len(test_df)} samples")

In [None]:
# Calculate task ranges for multi-task learning
# Each task's raw weight is the inverse of its value range on the training
# split (1.0 when the range is not positive); weights are then normalised to
# sum to one. Tasks with no training values get no entry at all.
task_ranges = {}
task_weights = {}

for task in task_cols:
    if task not in train_df.columns:
        continue
    valid_values = train_df[task].dropna()
    if valid_values.empty:
        continue
    task_range = valid_values.max() - valid_values.min()
    task_ranges[task] = task_range
    if task_range > 0:
        task_weights[task] = 1.0 / task_range
    else:
        task_weights[task] = 1.0

if task_weights:
    total_weight = sum(task_weights.values())
    task_weights = {name: raw / total_weight for name, raw in task_weights.items()}

print("\nTask ranges and weights:")
for task, task_range in task_ranges.items():
    print(f"  {task}: range={task_range:.2f}, weight={task_weights[task]:.4f}")

In [None]:
# Save processed data
# Writes the three splits as parquet files plus the per-task ranges as JSON.
output_dir = Path("../data/processed")
output_dir.mkdir(parents=True, exist_ok=True)

for _split_df, _file_name in (
    (train_df, "train_data.parquet"),
    (val_df, "val_data.parquet"),
    (test_df, "test_data.parquet"),
):
    _split_df.to_parquet(output_dir / _file_name, index=False)

with open(output_dir / "task_ranges.json", 'w') as f:
    f.write(json.dumps(task_ranges, indent=2))

print(f"\nData saved to {output_dir}")

## Part 2: Model Definition and Training

In [None]:
# Define GVP layers for drug encoder

class LayerNorm(nn.Module):
    """Layer normalisation for (scalar, vector) feature tuples.

    ``dims`` is ``(n_scalar, n_vector)``. Only the scalar channels are
    normalised; vector channels are passed through unchanged. When
    ``n_vector`` is zero the input is a plain scalar tensor rather than a
    tuple.
    """

    def __init__(self, dims):
        super().__init__()
        n_scalar, n_vector = dims
        self.s = n_scalar
        self.v = n_vector
        self.scalar_norm = nn.LayerNorm(n_scalar) if n_scalar else None

    def forward(self, x):
        norm = self.scalar_norm
        if self.v:
            scalars, vectors = x
            if norm is not None:
                scalars = norm(scalars)
            return scalars, vectors
        # Scalar-only input: normalise when a norm layer exists, else identity
        return x if norm is None else norm(x)

class GVP(nn.Module):
    def __init__(self, in_dims, out_dims, h_dim=None, activations=(F.relu, torch.sigmoid), vector_gate=False):
        super().__init__()
        self.si, self.vi = in_dims
        self.so, self.vo = out_dims
        self.vector_gate = vector_gate
        
        if self.vi:
            self.h_dim = h_dim or max(self.vi, self.vo) 
            self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
            self.ws = nn.Linear(self.h_dim + self.si, self.so)
            if self.vo:
                self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
        else:
            self.ws = nn.Linear(self.si, self.so)
        
        self.scalar_act, self.vector_act = activations
    
    def forward(self, x):
        if self.vi:
            s, v = x
            v = torch.transpose(v, -1, -2)
            vh = self.wh(v) 
            vn = torch.norm(vh, dim=-2)
            s = self.ws(torch.cat([s, vn], -1))
            if self.vo:
                v = self.wv(vh)
                v = torch.transpose(v, -1, -2)
        else:
            s = self.ws(x)
            v = None
        
        if self.scalar_act:
            s = self.scalar_act(s)
        
        return (s, v) if self.vo else s

In [None]:
# Define Drug Encoder

class DrugGCN(nn.Module):
    """Graph-convolutional encoder for drug molecular graphs.

    Node features go through a linear embedding, a stack of GCNConv layers
    (each followed by batch norm, ReLU and dropout), mean + max graph pooling
    (concatenated), and finally a stack of fully connected layers. The last
    fully connected layer has no activation or dropout.
    """

    def __init__(self, node_in_dim=66, node_h_dims=[128, 256, 128], fc_dims=[1024, 128], dropout=0.2):
        super().__init__()

        self.node_embedding = nn.Linear(node_in_dim, node_h_dims[0])

        conv_widths = list(node_h_dims)
        self.gcn_layers = nn.ModuleList([
            GCNConv(n_in, n_out)
            for n_in, n_out in zip(conv_widths[:-1], conv_widths[1:])
        ])

        # Mean and max pooling are concatenated, hence twice the last width
        fc_widths = [conv_widths[-1] * 2, *fc_dims]
        self.fc_layers = nn.ModuleList([
            nn.Linear(n_in, n_out)
            for n_in, n_out in zip(fc_widths[:-1], fc_widths[1:])
        ])

        self.bn_layers = nn.ModuleList([nn.BatchNorm1d(width) for width in conv_widths[1:]])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, edge_attr, batch):
        # edge_attr is accepted for interface symmetry but is not used here
        x = self.dropout(F.relu(self.node_embedding(x)))

        for conv, norm in zip(self.gcn_layers, self.bn_layers):
            x = self.dropout(F.relu(norm(conv(x, edge_index))))

        pooled = [global_mean_pool(x, batch), global_max_pool(x, batch)]
        x = torch.cat(pooled, dim=1)

        last = len(self.fc_layers) - 1
        for layer_idx, fc in enumerate(self.fc_layers):
            x = fc(x)
            if layer_idx != last:
                x = self.dropout(F.relu(x))

        return x

In [None]:
# Define Protein Encoder

class ProteinGCN(nn.Module):
    """Graph-convolutional encoder for protein graphs.

    Node embeddings pass through a stack of GCNConv layers (each followed by
    batch norm, ReLU and dropout), mean graph pooling, and a stack of fully
    connected layers whose last layer has no activation or dropout.
    """

    def __init__(self, emb_dim=1280, gcn_dims=[128, 256, 256], fc_dims=[1024, 128], dropout=0.2):
        super().__init__()

        conv_widths = [emb_dim, *gcn_dims]
        self.gcn_layers = nn.ModuleList([
            GCNConv(n_in, n_out)
            for n_in, n_out in zip(conv_widths[:-1], conv_widths[1:])
        ])

        fc_widths = [gcn_dims[-1], *fc_dims]
        self.fc_layers = nn.ModuleList([
            nn.Linear(n_in, n_out)
            for n_in, n_out in zip(fc_widths[:-1], fc_widths[1:])
        ])

        self.bn_layers = nn.ModuleList([nn.BatchNorm1d(width) for width in gcn_dims])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch):
        for conv, norm in zip(self.gcn_layers, self.bn_layers):
            x = self.dropout(F.relu(norm(conv(x, edge_index))))

        x = global_mean_pool(x, batch)

        last = len(self.fc_layers) - 1
        for layer_idx, fc in enumerate(self.fc_layers):
            x = fc(x)
            if layer_idx != last:
                x = self.dropout(F.relu(x))

        return x

In [None]:
# Define MTL-DTA Model

class MTL_DTAModel(nn.Module):
    """Multi-task drug-target affinity model.

    A drug encoder and a protein encoder each produce a fixed-size vector;
    the two are concatenated, batch-normalised and passed through shared
    fully connected layers. One linear head per task then produces a single
    value, and the heads' outputs are concatenated column-wise in the order
    of ``task_names``.
    """

    def __init__(self, task_names, prot_emb_dim=1280, prot_gcn_dims=[128, 256, 256],
                 prot_fc_dims=[1024, 128], drug_node_in_dim=66, drug_node_h_dims=[128, 256, 128],
                 drug_fc_dims=[1024, 128], mlp_dims=[1024, 512], mlp_dropout=0.25):
        super().__init__()

        self.task_names = task_names
        self.n_tasks = len(task_names)

        # Both encoders reuse the MLP dropout rate
        self.protein_encoder = ProteinGCN(prot_emb_dim, prot_gcn_dims, prot_fc_dims, mlp_dropout)
        self.drug_encoder = DrugGCN(drug_node_in_dim, drug_node_h_dims, drug_fc_dims, mlp_dropout)

        combined_dim = prot_fc_dims[-1] + drug_fc_dims[-1]

        shared_widths = [combined_dim, *mlp_dims]
        self.shared_layers = nn.ModuleList([
            nn.Linear(n_in, n_out)
            for n_in, n_out in zip(shared_widths[:-1], shared_widths[1:])
        ])

        self.task_heads = nn.ModuleDict({
            task: nn.Linear(mlp_dims[-1], 1) for task in task_names
        })

        self.dropout = nn.Dropout(mlp_dropout)
        self.bn = nn.BatchNorm1d(combined_dim)

    def forward(self, drug_batch, protein_batch):
        drug_repr = self.drug_encoder(
            drug_batch.x,
            drug_batch.edge_index,
            getattr(drug_batch, 'edge_attr', None),
            drug_batch.batch,
        )
        protein_repr = self.protein_encoder(
            protein_batch.x,
            protein_batch.edge_index,
            protein_batch.batch,
        )

        hidden = torch.cat([drug_repr, protein_repr], dim=1)
        hidden = self.dropout(F.relu(self.bn(hidden)))

        for layer in self.shared_layers:
            hidden = self.dropout(F.relu(layer(hidden)))

        # One column per task, ordered as in task_names
        per_task = [self.task_heads[task](hidden) for task in self.task_names]
        return torch.cat(per_task, dim=1)

In [None]:
# Define Loss Function

class MaskedMSELoss(nn.Module):
    """Multi-task MSE that ignores missing (NaN) targets.

    For each task column the MSE is computed over the rows that have a
    target. Optional per-task weights are the inverse task range (1.0 when
    the range is not positive), normalised to sum to one. The result is the
    mean of the (weighted) per-task losses over the tasks present in the batch.

    Args:
        task_ranges: Optional mapping ``task -> value range``. Weights are
            matched to target columns by position, so the mapping's iteration
            order must follow the column order of ``target``.
    """

    def __init__(self, task_ranges=None):
        super().__init__()
        self.task_ranges = task_ranges or {}

        weights = None
        if self.task_ranges:
            raw = [
                1.0 / task_range if task_range > 0 else 1.0
                for task_range in self.task_ranges.values()
            ]
            total_weight = sum(raw)
            # Explicit float32: ranges taken from pandas are numpy float64 and
            # would otherwise silently promote the whole loss to float64.
            weights = torch.tensor(
                [float(w) / float(total_weight) for w in raw], dtype=torch.float32
            )
        # Registered as a non-persistent buffer so `.to(device)` moves it with
        # the module while the state_dict stays unchanged.
        self.register_buffer('weights', weights, persistent=False)

    def forward(self, pred, target):
        """Return the masked loss for ``pred`` / ``target`` of shape (batch, n_tasks).

        When no target is present the result is a zero that is still attached
        to ``pred``'s graph, so ``backward()`` works and yields zero gradients.
        """
        mask = ~torch.isnan(target)

        weights = self.weights
        if weights is not None:
            weights = weights.to(device=pred.device, dtype=pred.dtype)

        task_losses = []
        for i in range(target.shape[1]):
            task_mask = mask[:, i]
            if not task_mask.any():
                continue
            task_loss = F.mse_loss(pred[task_mask, i], target[task_mask, i])
            if weights is not None:
                task_loss = task_loss * weights[i]
            task_losses.append(task_loss)

        if not task_losses:
            # A bare torch.tensor(0.0) has no grad_fn and makes backward()
            # raise; multiplying through pred keeps the graph intact.
            return pred.sum() * 0.0

        return torch.stack(task_losses).mean()

In [None]:
# Define Feature Extraction Functions

def atom_features(atom):
    """Build the per-atom feature list for an RDKit atom.

    Layout (33 values): element one-hot (10), degree one-hot (6),
    hybridization one-hot (5), implicit-valence one-hot (6), then formal
    charge, radical electron count, aromatic flag, ring flag,
    ``_ChiralityPossible`` flag and atomic mass.
    """
    def match_one_hot(value, choices):
        # All zeros when value is not among the choices
        return [int(value == choice) for choice in choices]

    def index_one_hot(index, size):
        # All zeros when index is size or larger
        slots = [0] * size
        if index < size:
            slots[index] = 1
        return slots

    elements = ['C', 'N', 'O', 'S', 'F', 'P', 'Cl', 'Br', 'I', 'H']
    hybrid = Chem.rdchem.HybridizationType
    hybridizations = [hybrid.SP, hybrid.SP2, hybrid.SP3, hybrid.SP3D, hybrid.SP3D2]

    features = []
    features += match_one_hot(atom.GetSymbol(), elements)
    features += index_one_hot(atom.GetDegree(), 6)
    features += match_one_hot(atom.GetHybridization(), hybridizations)
    features += index_one_hot(atom.GetImplicitValence(), 6)

    # Other features
    features += [
        atom.GetFormalCharge(),
        atom.GetNumRadicalElectrons(),
        int(atom.GetIsAromatic()),
        int(atom.IsInRing()),
        int(atom.HasProp('_ChiralityPossible')),
        atom.GetMass(),
    ]

    return features

def bond_features(bond):
    """Return six 0/1 flags for an RDKit bond.

    Order: single, double, triple, aromatic, conjugated, in-ring.
    """
    kinds = Chem.rdchem.BondType
    observed = bond.GetBondType()
    flags = [
        observed == kind
        for kind in (kinds.SINGLE, kinds.DOUBLE, kinds.TRIPLE, kinds.AROMATIC)
    ]
    flags.append(bond.GetIsConjugated())
    flags.append(bond.IsInRing())
    return [int(flag) for flag in flags]

def mol_to_graph(smiles):
    """Convert a SMILES string into a PyG ``Data`` graph.

    Nodes carry ``atom_features``; every bond contributes two directed edges
    (one per direction) that share the same ``bond_features``.

    Returns:
        A ``Data`` object with ``x``, ``edge_index`` and ``edge_attr``, or
        ``None`` if the SMILES cannot be parsed.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    # Node features
    x = torch.tensor([atom_features(atom) for atom in mol.GetAtoms()], dtype=torch.float)

    # Edge features
    directed_pairs = []
    directed_attrs = []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        feat = bond_features(bond)
        directed_pairs += [[begin, end], [end, begin]]
        directed_attrs += [feat, feat]

    if not directed_pairs:
        # Molecule without bonds: empty but correctly shaped edge tensors
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, 6), dtype=torch.float)
    else:
        edge_index = torch.tensor(directed_pairs, dtype=torch.long).t()
        edge_attr = torch.tensor(directed_attrs, dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

def create_protein_graph(seq_len=100, emb_dim=1280):
    """Create a mock protein graph for demonstration.

    Node features are random normal values of shape ``(seq_len, emb_dim)``.
    Each position is connected to every other position within five steps
    along the sequence, in both directions.
    """
    x = torch.randn(seq_len, emb_dim)

    # Sequence-window edges standing in for distance-based edges (mock)
    edge_list = [
        [i, j]
        for i in range(seq_len)
        for j in range(max(0, i - 5), min(seq_len, i + 6))
        if i != j
    ]

    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).t()
    else:
        edge_index = torch.zeros((2, 0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index)

In [None]:
# Create Dataset Class

class MTL_DTA_Dataset(Dataset):
    """Dataset of (drug graph, protein graph, multi-task target) samples.

    All graphs are built eagerly in the constructor. Rows whose SMILES cannot
    be converted to a graph are skipped, so the dataset can be shorter than
    ``df``. Missing task values are stored as NaN in the target vector.
    """

    def __init__(self, df, task_cols):
        self.df = df.reset_index(drop=True)
        self.task_cols = task_cols
        self.n_tasks = len(task_cols)

        # Pre-compute graphs
        self.drug_graphs = []
        self.protein_graphs = []
        self.targets = []

        for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing samples"):
            drug_graph = mol_to_graph(row['std_smiles'])
            if drug_graph is None:
                continue

            # Protein graph (mock for demonstration)
            protein_graph = create_protein_graph()

            # Targets: one slot per task, NaN where the row has no value
            y = torch.zeros(self.n_tasks)
            for slot, task in enumerate(self.task_cols):
                has_value = task in row and not pd.isna(row[task])
                y[slot] = float(row[task]) if has_value else float('nan')

            self.drug_graphs.append(drug_graph)
            self.protein_graphs.append(protein_graph)
            self.targets.append(y)

    def __len__(self):
        return len(self.drug_graphs)

    def __getitem__(self, idx):
        return dict(
            drug=self.drug_graphs[idx],
            protein=self.protein_graphs[idx],
            y=self.targets[idx],
        )

In [None]:
# Create data loaders
print("\nCreating datasets and data loaders...")

train_dataset, val_dataset, test_dataset = (
    MTL_DTA_Dataset(split_df, task_cols) for split_df in (train_df, val_df, test_df)
)


def collate_batch(batch):
    """Merge sample dicts into one batch.

    Drug and protein graphs are each combined with ``Batch.from_data_list``;
    the per-sample target vectors are stacked into one tensor.
    """
    drug_graphs, protein_graphs, targets = [], [], []
    for sample in batch:
        drug_graphs.append(sample['drug'])
        protein_graphs.append(sample['protein'])
        targets.append(sample['y'])

    return {
        'drug': Batch.from_data_list(drug_graphs),
        'protein': Batch.from_data_list(protein_graphs),
        'y': torch.stack(targets),
    }


batch_size = 32
# Only the training loader shuffles
_loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_batch)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

In [None]:
# Initialize model
print("\nInitializing model...")

# Take the drug node-feature width from the featurised data instead of
# hard-coding it: atom_features() emits 33 values per atom, so a fixed 66
# makes the drug encoder's first Linear layer fail with a shape mismatch on
# the first forward pass. Falls back to 66 when no usable graph is available.
drug_node_in_dim = next(
    (graph.x.size(1) for graph in train_dataset.drug_graphs if graph.x.dim() == 2),
    66,
)
print(f"Drug node feature dimension: {drug_node_in_dim}")

model = MTL_DTAModel(
    task_names=task_cols,
    prot_emb_dim=1280,
    prot_gcn_dims=[128, 256, 256],
    prot_fc_dims=[1024, 128],
    drug_node_in_dim=drug_node_in_dim,
    drug_node_h_dims=[128, 256, 128],
    drug_fc_dims=[1024, 128],
    mlp_dims=[1024, 512],
    mlp_dropout=0.25
).to(DEVICE)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# Setup training
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0)
# Halve the learning rate after 10 epochs without validation-loss improvement.
# The `verbose` argument is intentionally not passed: it is deprecated in
# recent PyTorch releases. Read the current learning rate from
# optimizer.param_groups[0]['lr'] when it needs to be logged.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
criterion = MaskedMSELoss(task_ranges=task_ranges).to(DEVICE)

print("\nTraining setup complete")

In [None]:
# Training functions

def train_epoch(model, loader, optimizer, criterion, device):
    """Run one training epoch.

    Args:
        model: Model called as ``model(drug_batch, protein_batch)``.
        loader: Iterable of dicts with ``'drug'``, ``'protein'`` and ``'y'``.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss called as ``criterion(predictions, y)``.
        device: Device the batch is moved to.

    Returns:
        Mean batch loss as a float, or NaN when the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0

    for batch in tqdm(loader, desc="Training"):
        drug_batch = batch['drug'].to(device)
        protein_batch = batch['protein'].to(device)
        y = batch['y'].to(device)

        optimizer.zero_grad()
        predictions = model(drug_batch, protein_batch)
        loss = criterion(predictions, y)

        # A criterion may return a constant with no graph (e.g. when a batch
        # has no labelled targets); backward() on it would raise, so the
        # update is skipped for that batch.
        if loss.requires_grad:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        # Empty loader: avoid ZeroDivisionError and signal "no data"
        return float('nan')
    return total_loss / n_batches

def validate(model, loader, criterion, device, task_names=None):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: network called as ``model(drug_batch, protein_batch)``.
        loader: iterable of batches indexable by ``'drug'``, ``'protein'``
            and ``'y'``.
        criterion: callable applied to ``(predictions, y)``; ``.item()`` of
            its result is accumulated as the batch loss.
        device: device the batch entries are moved to.
        task_names: names of the prediction columns, in column order. When
            omitted, the notebook-level ``task_cols`` is used, so existing
            four-argument calls behave as before.

    Returns:
        ``(mean_loss, metrics)``. ``metrics`` maps ``'<task>_r2'`` and
        ``'<task>_rmse'`` to scores computed only over rows whose target for
        that task is not NaN; tasks without any such row are omitted.
    """
    if task_names is None:
        # Fall back to the global task list instead of requiring it implicitly
        task_names = task_cols

    model.eval()
    total_loss = 0
    n_batches = 0
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validation"):
            drug_batch = batch['drug'].to(device)
            protein_batch = batch['protein'].to(device)
            y = batch['y'].to(device)

            predictions = model(drug_batch, protein_batch)
            loss = criterion(predictions, y)

            total_loss += loss.item()
            n_batches += 1

            # Keep CPU copies so the whole split can be scored at once below
            all_predictions.append(predictions.cpu())
            all_targets.append(y.cpu())

    all_predictions = torch.cat(all_predictions, dim=0).numpy()
    all_targets = torch.cat(all_targets, dim=0).numpy()

    # Per-task metrics; NaN targets mark missing labels and are excluded
    metrics = {}
    for i, task in enumerate(task_names):
        mask = ~np.isnan(all_targets[:, i])
        if mask.sum() > 0:
            pred = all_predictions[mask, i]
            target = all_targets[mask, i]
            metrics[f'{task}_r2'] = r2_score(target, pred)
            metrics[f'{task}_rmse'] = np.sqrt(mean_squared_error(target, pred))

    return total_loss / n_batches, metrics

In [None]:
# Training loop
import copy  # used to snapshot the best weights (see below)

print("\nStarting training...")
print("="*60)

n_epochs = 20  # Reduced for demonstration
best_val_loss = float('inf')
best_model_state = None
# NOTE: with patience equal to n_epochs the early-stopping branch below cannot
# fire unless n_epochs is raised or patience is lowered.
patience = 20
patience_counter = 0

train_losses = []
val_losses = []
train_metrics_history = []
val_metrics_history = []

for epoch in range(n_epochs):
    print(f"\nEpoch {epoch+1}/{n_epochs}")

    # Training
    train_loss = train_epoch(model, train_loader, optimizer, criterion, DEVICE)
    train_losses.append(train_loss)

    # Validation
    val_loss, val_metrics = validate(model, val_loader, criterion, DEVICE)
    val_losses.append(val_loss)
    val_metrics_history.append(val_metrics)

    # Scheduler step (plateau detection is driven by the validation loss)
    scheduler.step(val_loss)

    # Print metrics
    print(f"Train Loss: {train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")

    avg_r2 = np.mean([v for k, v in val_metrics.items() if 'r2' in k])
    avg_rmse = np.mean([v for k, v in val_metrics.items() if 'rmse' in k])
    print(f"Val Avg R2: {avg_r2:.4f}, Avg RMSE: {avg_rmse:.4f}")

    # Check for best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() hands back references to the live tensors, and a
        # dict-level .copy() keeps those references, so later optimizer steps
        # would overwrite the "best" snapshot. A deep copy freezes the values.
        best_model_state = copy.deepcopy(model.state_dict())
        patience_counter = 0
        print(f"✓ New best model!")
    else:
        patience_counter += 1

    # Early stopping
    if patience_counter >= patience:
        print(f"Early stopping triggered at epoch {epoch+1}")
        break

# Load best model
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"\nLoaded best model with validation loss: {best_val_loss:.4f}")

## Part 3: Model Analysis and Evaluation

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
loss_ax, r2_ax = axes

# Left panel: training and validation loss per epoch
for loss_series, series_label, series_marker in (
    (train_losses, 'Train Loss', 'o'),
    (val_losses, 'Val Loss', 's'),
):
    loss_ax.plot(loss_series, label=series_label, marker=series_marker)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)

# Right panel: validation R² per task; epochs where a task has no score
# contribute NaN, and a task with no score at all is not drawn.
for task in task_cols:
    r2_key = f'{task}_r2'
    r2_values = [epoch_metrics.get(r2_key, np.nan) for epoch_metrics in val_metrics_history]
    if np.isnan(r2_values).all():
        continue
    r2_ax.plot(r2_values, label=f'{task} R²', marker='o')

r2_ax.set_xlabel('Epoch')
r2_ax.set_ylabel('R² Score')
r2_ax.set_title('Validation R² by Task')
r2_ax.legend()
r2_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Evaluate on test set
print("\nEvaluating on test set...")
print("="*60)

model.eval()
test_predictions = []
test_targets = []

# Inference only, so gradient tracking is switched off
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Testing"):
        drug_batch = batch['drug'].to(DEVICE)
        protein_batch = batch['protein'].to(DEVICE)
        y = batch['y']  # not moved to DEVICE: targets are only used after .numpy()

        predictions = model(drug_batch, protein_batch)
        test_predictions.append(predictions.cpu())
        test_targets.append(y)

test_predictions = torch.cat(test_predictions, dim=0).numpy()
test_targets = torch.cat(test_targets, dim=0).numpy()

# Calculate test metrics per task, over rows whose target is not NaN.
# Values are stored as native Python numbers rather than NumPy scalars so the
# results stay simple to print, pickle and serialise.
test_results = {}
for i, task in enumerate(task_cols):
    mask = ~np.isnan(test_targets[:, i])
    n_labelled = int(mask.sum())
    if n_labelled == 0:
        continue

    pred = test_predictions[mask, i]
    target = test_targets[mask, i]

    # Correlation coefficients need at least two points (scipy's pearsonr
    # raises on shorter input); record NaN instead of aborting the evaluation.
    if n_labelled >= 2:
        pearson = float(stats.pearsonr(target, pred)[0])
        spearman = float(stats.spearmanr(target, pred)[0])
    else:
        pearson = spearman = float('nan')

    test_results[task] = {
        'n_samples': n_labelled,
        'r2': float(r2_score(target, pred)),
        'rmse': float(np.sqrt(mean_squared_error(target, pred))),
        'mae': float(mean_absolute_error(target, pred)),
        'pearson': pearson,
        'spearman': spearman
    }

# Print results
for task, metrics in test_results.items():
    print(f"\n{task}:")
    print(f"  Samples: {metrics['n_samples']}")
    print(f"  R²: {metrics['r2']:.4f}")
    print(f"  RMSE: {metrics['rmse']:.4f}")
    print(f"  MAE: {metrics['mae']:.4f}")
    print(f"  Pearson: {metrics['pearson']:.4f}")
    print(f"  Spearman: {metrics['spearman']:.4f}")

# Overall metrics: unweighted mean across the tasks that were scored
overall_r2 = np.mean([m['r2'] for m in test_results.values()])
overall_rmse = np.mean([m['rmse'] for m in test_results.values()])
print(f"\nOverall Performance:")
print(f"  Average R²: {overall_r2:.4f}")
print(f"  Average RMSE: {overall_rmse:.4f}")

In [None]:
# Scatter plots for predictions vs actual
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

# One panel per task; the 2x2 grid shows at most the first four tasks
for i, task in enumerate(task_cols):
    if i >= 4:
        break

    observed = ~np.isnan(test_targets[:, i])
    if not observed.any():
        # No labelled test rows for this task: leave its panel empty
        continue

    ax = axes[i]
    pred = test_predictions[observed, i]
    target = test_targets[observed, i]

    ax.scatter(target, pred, alpha=0.5, s=10)

    # Identity line spanning the joint range of targets and predictions
    lower = min(target.min(), pred.min())
    upper = max(target.max(), pred.max())
    ax.plot([lower, upper], [lower, upper], 'r--', lw=2)

    # Least-squares straight line of prediction against target
    linear_fit = np.poly1d(np.polyfit(target, pred, 1))
    ax.plot(target, linear_fit(target), 'g-', alpha=0.5, lw=2)

    ax.set_xlabel(f'Actual {task}')
    ax.set_ylabel(f'Predicted {task}')
    ax.set_title(f'{task}: R²={test_results[task]["r2"]:.3f}, RMSE={test_results[task]["rmse"]:.3f}')
    ax.grid(True, alpha=0.3)

plt.suptitle('Test Set Predictions vs Actual Values', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Error distribution analysis
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes = axes.flatten()

# One histogram per task; the 2x2 grid shows at most the first four tasks
for i, task in enumerate(task_cols):
    if i >= 4:
        break

    observed = ~np.isnan(test_targets[:, i])
    if not observed.any():
        # No labelled test rows for this task: leave its panel empty
        continue

    ax = axes[i]
    # Signed error: positive values mean the prediction is above the target
    errors = test_predictions[observed, i] - test_targets[observed, i]
    mean_error = errors.mean()

    ax.hist(errors, bins=30, alpha=0.7, edgecolor='black')
    ax.axvline(0, color='red', linestyle='--', lw=2)
    ax.axvline(mean_error, color='green', linestyle='--', lw=2, label=f'Mean: {mean_error:.3f}')

    ax.set_xlabel('Prediction Error')
    ax.set_ylabel('Count')
    ax.set_title(f'{task} Error Distribution (Std: {errors.std():.3f})')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('Prediction Error Distributions', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Model component analysis: how the parameter count is split across the
# model's sub-modules. (This cell only counts parameters; it does not compute
# attention weights or feature importances.)
print("\nAnalyzing model components...")

# Get model component sizes
component_params = {
    'Protein Encoder': sum(p.numel() for p in model.protein_encoder.parameters()),
    'Drug Encoder': sum(p.numel() for p in model.drug_encoder.parameters()),
    'Shared Layers': sum(p.numel() for p in model.shared_layers.parameters()),
    'Task Heads': sum(p.numel() for p in model.task_heads.parameters())
}

# Materialise names and sizes as lists: dict views are not sequences, and
# handing them to the plotting calls relies on how they get converted to
# arrays internally.
component_names = list(component_params)
component_sizes = list(component_params.values())

# Plot component sizes
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Pie chart of parameters
ax1.pie(component_sizes, labels=component_names, autopct='%1.1f%%')
ax1.set_title('Model Parameter Distribution')

# Bar chart of parameters
ax2.bar(component_names, component_sizes)
ax2.set_ylabel('Number of Parameters')
ax2.set_title('Parameters by Component')
ax2.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

for comp, params in component_params.items():
    print(f"{comp}: {params:,} parameters")

In [None]:
# Cross-task correlation analysis
print("\nCross-task prediction correlation...")

# Pairwise correlation between task columns, first for the predictions...
pred_df = pd.DataFrame(test_predictions, columns=task_cols)
pred_corr = pred_df.corr()

# ...then for the actual target values
actual_df = pd.DataFrame(test_targets, columns=task_cols)
actual_corr = actual_df.corr()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Draw both matrices side by side with identical heatmap styling
for corr_matrix, panel, panel_title in (
    (pred_corr, ax1, 'Predicted Values Correlation'),
    (actual_corr, ax2, 'Actual Values Correlation'),
):
    sns.heatmap(
        corr_matrix,
        annot=True,
        fmt='.2f',
        cmap='coolwarm',
        center=0,
        square=True,
        ax=panel,
        cbar_kws={"shrink": 0.8},
    )
    panel.set_title(panel_title)

plt.tight_layout()
plt.show()

In [None]:
# Performance by molecular properties
print("\nAnalyzing performance by molecular properties...")

# Attach per-task predictions and signed errors to a copy of the test frame.
# Columns are assigned positionally, so test_df is expected to have one row
# per prediction, in the same order.
test_df_with_pred = test_df.copy()
for i, task in enumerate(task_cols):
    task_pred = test_predictions[:, i]
    test_df_with_pred[f'{task}_pred'] = task_pred
    test_df_with_pred[f'{task}_error'] = task_pred - test_targets[:, i]

# One panel per molecular property: absolute error against the property value
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

properties_to_analyze = ['MolWt', 'LogP', 'HeavyAtomCount', 'NumRotatableBonds']

for idx, prop in enumerate(properties_to_analyze):
    # Properties missing from the frame leave their panel blank
    if prop not in test_df_with_pred.columns:
        continue

    grid_row, grid_col = divmod(idx, 2)
    ax = axes[grid_row, grid_col]

    for task in task_cols[:2]:  # Plot first 2 tasks
        error_col = f'{task}_error'
        if error_col not in test_df_with_pred.columns:
            continue

        # Rows with a NaN error (no label for this task) are left out
        has_error = test_df_with_pred[error_col].notna()
        if not has_error.any():
            continue

        ax.scatter(
            test_df_with_pred.loc[has_error, prop],
            test_df_with_pred.loc[has_error, error_col].abs(),
            alpha=0.5,
            label=task,
            s=10,
        )

    ax.set_xlabel(prop)
    ax.set_ylabel('Absolute Error')
    ax.set_title(f'Prediction Error vs {prop}')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('Error Analysis by Molecular Properties', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Save model and results
print("\nSaving model and results...")

checkpoint_dir = Path("../checkpoints")
checkpoint_dir.mkdir(parents=True, exist_ok=True)


def _plain_scalar(value):
    """Return NumPy scalars as native Python numbers; pass other values through."""
    return value.item() if isinstance(value, np.generic) else value


def _json_default(value):
    """Fallback for ``json.dump``: convert NumPy values, reject everything else.

    Raising TypeError for unknown types is what the json module expects of a
    ``default`` hook. Returning the object unchanged would make the encoder
    re-encounter the same object and fail with a misleading
    "Circular reference detected" error.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


# Save model checkpoint
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch + 1,
    'train_loss': train_losses[-1],
    # 'val_loss' is the last epoch's value; 'best_val_loss' is the lowest one
    # recorded by the training loop.
    'val_loss': val_losses[-1],
    'best_val_loss': best_val_loss,
    'task_cols': task_cols,
    'task_ranges': task_ranges,
    # Metrics are stored as native Python numbers so the checkpoint does not
    # pickle NumPy scalar objects.
    # NOTE(review): recent torch.load versions default to weights_only=True,
    # which rejects such objects unless allow-listed; confirm against the
    # installed PyTorch version.
    'test_results': {
        task_name: {key: _plain_scalar(val) for key, val in task_metrics.items()}
        for task_name, task_metrics in test_results.items()
    },
    'config': {
        'prot_emb_dim': 1280,
        'prot_gcn_dims': [128, 256, 256],
        'prot_fc_dims': [1024, 128],
        'drug_node_in_dim': 66,
        'drug_node_h_dims': [128, 256, 128],
        'drug_fc_dims': [1024, 128],
        'mlp_dims': [1024, 512],
        'mlp_dropout': 0.25
    }
}

torch.save(checkpoint, checkpoint_dir / 'mtl_dta_model.pt')
print(f"Model saved to {checkpoint_dir / 'mtl_dta_model.pt'}")

# Save predictions: one '<task>_actual' and one '<task>_pred' column per task
predictions_df = pd.DataFrame({
    **{f'{task}_actual': test_targets[:, i] for i, task in enumerate(task_cols)},
    **{f'{task}_pred': test_predictions[:, i] for i, task in enumerate(task_cols)}
})
predictions_df.to_csv(checkpoint_dir / 'test_predictions.csv', index=False)
print(f"Predictions saved to {checkpoint_dir / 'test_predictions.csv'}")

# Save training history
history = {
    'train_losses': train_losses,
    'val_losses': val_losses,
    'val_metrics_history': val_metrics_history
}
with open(checkpoint_dir / 'training_history.json', 'w') as f:
    json.dump(history, f, indent=2, default=_json_default)
print(f"Training history saved to {checkpoint_dir / 'training_history.json'}")

In [None]:
# Final summary
banner = "=" * 60
print("\n" + banner)
print("PIPELINE COMPLETE")
print(banner)

# Dataset sizes
print("\nData Summary:")
print(f"  Total samples processed: {len(df_good)}")
for split_label, split_dataset in (
    ("Training", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"  {split_label} samples: {len(split_dataset)}")

# Model size and task list
print("\nModel Summary:")
print(f"  Total parameters: {total_params:,}")
print(f"  Tasks: {', '.join(task_cols)}")

# Per-task test scores followed by the cross-task averages
print("\nPerformance Summary:")
for task, task_scores in test_results.items():
    print(f"  {task}: R²={task_scores['r2']:.3f}, RMSE={task_scores['rmse']:.3f}")

print(f"\nOverall: R²={overall_r2:.3f}, RMSE={overall_rmse:.3f}")

print("\n✅ Complete MTL-GNN-DTA pipeline executed successfully!")