# TNBC Drug Cytotoxicity Prediction Model

## Abstract

This notebook implements a three-phase transfer learning pipeline for predicting drug cytotoxicity in Triple-Negative Breast Cancer (TNBC) cell lines. The model employs Graph Neural Networks (GNNs) to encode drug molecular structures and pathway-based features to represent cell line characteristics.

**Methodology:**
- Phase 1: Pan-cancer pre-training with 5-fold cross-validation
- Phase 2: Breast cancer fine-tuning with 5-fold cross-validation
- Phase 3: TNBC-specific fine-tuning (single split due to limited sample size)
- Cell-line based data splitting to prevent data leakage
- Evaluation metrics with confidence intervals


## 1. Setup and Imports

Import required libraries and configure paths for reproducibility.


In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
import time
import pubchempy as pcp
import torch
import torch.nn as nn
from torch_geometric.data import Data, Batch
from torch_geometric.nn import TransformerConv, global_mean_pool, global_max_pool
from rdkit import Chem
from rdkit.Chem import AllChem
import pickle
from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
import os
import subprocess
import sys
from sklearn.model_selection import train_test_split, KFold
from scipy import stats
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
import shap
import json
from datetime import datetime

# Global matplotlib defaults applied to every figure created after this cell.
plt.rcParams.update({
    'figure.dpi': 300,  # High resolution for publication
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.1,
    'font.size': 10,
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
    'axes.linewidth': 1.0,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.labelsize': 11,
    'axes.titlesize': 12,
    'xtick.labelsize': 9,
    'ytick.labelsize': 9,
    'legend.fontsize': 9,
    'legend.frameon': False,
    'legend.loc': 'best',
    'lines.linewidth': 1.5,
    'lines.markersize': 4,
    'patch.linewidth': 0.5,
    'grid.alpha': 0.3,
    'grid.linewidth': 0.5,
    'xtick.major.width': 1.0,
    'ytick.major.width': 1.0,
    'xtick.minor.width': 0.5,
    'ytick.minor.width': 0.5,
})

# Grayscale palette (hex codes) for figures.
NATURE_COLORS = {
    'black': '#000000',
    'dark_gray': '#404040',
    'medium_gray': '#808080',
    'light_gray': '#C0C0C0',
    'white': '#FFFFFF'
}

# Default colormap for image-style plots.
plt.rcParams['image.cmap'] = 'gray'


# Use current directory as root
# NOTE: every path below is relative to the working directory the notebook is
# launched from, not to the notebook file itself.
project_root = Path.cwd()
data_dir = project_root / "data" / "raw"
output_dir = project_root / "results"
models_dir = project_root / "models"
splits_dir = project_root / "data_splits"
prebatched_dir = project_root / "prebatched_data"

# Create the output directories up front; data_dir is an input location and is
# expected to exist already.
for dir_path in [output_dir, models_dir, splits_dir, prebatched_dir]:
    dir_path.mkdir(parents=True, exist_ok=True)

# NOTE(review): the encoder classes defined later in this file choose their own
# device and prefer MPS over CUDA, so on Apple hardware this `device` (CPU) can
# differ from the device the model actually lives on -- confirm callers agree.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


## 2. Data Loading and Preprocessing

Load pathway activity scores, GDSC2 fitted dose-response data, DepMap cell line metadata, and drug SMILES strings, then merge them into a single pan-cancer table.


In [None]:
# Raw inputs, all read from data_dir.
pathway_scores_raw = pd.read_csv(data_dir / "cell_ge.csv", index_col=0)
gdsc2_df = pd.read_excel(data_dir / "GDSC2 Fitted Dose Response Oct 27 2023.xlsx")
model_df = pd.read_csv(data_dir / "DepMap Model Data.csv")
drug_smiles = pd.read_csv(data_dir / "drugs_with_smiles.csv")

# Keep only dose-response fits whose RMSE column is below 0.3. The column is
# located by name match; this raises IndexError if no column contains "RMSE".
rmse_col = [col for col in gdsc2_df.columns if 'RMSE' in col.upper()][0]
gdsc2_filtered = gdsc2_df[gdsc2_df[rmse_col] < 0.3].copy()

# Reduce to the columns used downstream and switch to the notebook's naming.
drug_response = gdsc2_filtered[['DRUG_NAME', 'CELL_LINE_NAME', 'LN_IC50', 'COSMIC_ID']].copy()
drug_response.columns = ['DrugName', 'CellLineName', 'LN_IC50', 'COSMICID']

# The score table is treated as pathways (index) x cell lines (columns); after
# transposing, each row is one cell line and each pathway becomes a column.
pathway_names = pathway_scores_raw.index.tolist()
pathway_data = pathway_scores_raw.T.reset_index()
pathway_data.columns = ['CellLineName'] + pathway_names

# Lookup tables to DepMap ModelID: by normalised cell-line name (upper-cased,
# '-' and '_' removed) and by COSMIC ID (first occurrence wins on duplicates).
cell_name_to_modelid = dict(zip(
    model_df['StrippedCellLineName'].str.upper().str.replace('-', '').str.replace('_', ''),
    model_df['ModelID']
))
cosmic_to_modelid = model_df.drop_duplicates(subset='COSMICID', keep='first').set_index('COSMICID')['ModelID'].to_dict()

# Map pathway rows to ModelID by normalised name; rows without a match are dropped.
pathway_data['ModelID'] = pathway_data['CellLineName'].apply(
    lambda x: cell_name_to_modelid.get(str(x).upper().replace('-', '').replace('_', ''), None)
)
pathway_data = pathway_data[pathway_data['ModelID'].notna()].copy()

# Map drug-response rows to ModelID: COSMIC ID first, then fall back to the
# normalised cell-line name for rows that are still unmapped.
drug_response['ModelID'] = drug_response['COSMICID'].apply(lambda x: cosmic_to_modelid.get(x, None))
unmapped = drug_response[drug_response['ModelID'].isna()]
if len(unmapped) > 0:
    drug_response.loc[drug_response['ModelID'].isna(), 'ModelID'] = drug_response.loc[drug_response['ModelID'].isna(), 'CellLineName'].apply(
        lambda x: cell_name_to_modelid.get(str(x).upper().replace('-', '').replace('_', ''), None)
    )
drug_response = drug_response[drug_response['ModelID'].notna()].drop(columns=['COSMICID'])

# Attach pathway scores to every (drug, cell line) response row.
# NOTE(review): if two pathway rows map to the same ModelID this inner merge
# duplicates response rows -- confirm ModelID is unique in pathway_data.
pan_cancer_pathway = drug_response.merge(
    pathway_data[['ModelID'] + pathway_names],
    on='ModelID',
    how='inner'
)

# Attach SMILES; responses for drugs without a SMILES entry are dropped.
drug_smiles_renamed = drug_smiles.rename(columns={'DRUG_NAME': 'DrugName'})
pan_cancer_pathway = pan_cancer_pathway.merge(
    drug_smiles_renamed[['DrugName', 'SMILES']],
    on='DrugName',
    how='inner'
)

# Rows without a regression target cannot be used.
pan_cancer_pathway = pan_cancer_pathway[pan_cancer_pathway['LN_IC50'].notna()].copy()

# Add lineage / disease annotations used later to select the breast subsets.
pan_cancer_pathway = pan_cancer_pathway.merge(
    model_df[['ModelID', 'StrippedCellLineName', 'OncotreeLineage', 'OncotreePrimaryDisease']],
    on='ModelID',
    how='left'
)

# Convenience aliases used by later cells.
drug_name_col = 'DrugName'
ln_ic50_col = 'LN_IC50'
pathway_cols = pathway_names
drugs_with_smiles = drug_smiles_renamed[['DrugName', 'SMILES']].copy()



### 2.1 Drug Molecular Graph Construction

Convert SMILES strings to molecular graphs using RDKit. Each drug is represented as a graph where nodes are atoms and edges are bonds.


In [None]:
# SMILES to Graph Conversion

def get_atom_features(atom):
    """
    Build the per-atom feature vector used as a GNN node feature.

    The nine entries are, in order: atomic number, degree, formal charge,
    hybridization code, aromatic flag, total hydrogen count, radical
    electron count, ring-membership flag and chiral-tag code.

    Args:
        atom: RDKit atom object

    Returns:
        List with nine atom features
    """
    atomic_num = atom.GetAtomicNum()
    degree = atom.GetDegree()
    formal_charge = atom.GetFormalCharge()
    hybridization = atom.GetHybridization().real
    is_aromatic = atom.GetIsAromatic()
    total_hs = atom.GetTotalNumHs()
    radical_electrons = atom.GetNumRadicalElectrons()
    in_ring = atom.IsInRing()
    chiral_tag = atom.GetChiralTag().real
    return [
        atomic_num,
        degree,
        formal_charge,
        hybridization,
        is_aromatic,
        total_hs,
        radical_electrons,
        in_ring,
        chiral_tag,
    ]

def get_bond_features(bond):
    """
    Build the per-bond feature vector used as a GNN edge feature.

    The four entries are, in order: bond type as a double, conjugation
    flag, ring-membership flag and stereo code.

    Args:
        bond: RDKit bond object

    Returns:
        List with four bond features
    """
    bond_order = bond.GetBondTypeAsDouble()
    is_conjugated = bond.GetIsConjugated()
    in_ring = bond.IsInRing()
    stereo = bond.GetStereo().real
    return [bond_order, is_conjugated, in_ring, stereo]

def smiles_to_graph(smiles_string):
    """
    Turn a SMILES string into a PyTorch Geometric graph.

    Explicit hydrogens are added before featurisation. Every bond yields two
    directed edges that share the same feature vector. A molecule without
    bonds gets a single zero-feature self-loop on node 0.

    Args:
        smiles_string: SMILES representation of molecule

    Returns:
        torch_geometric.data.Data object, or None if the SMILES cannot be
        parsed, the molecule has no atoms, or any step raises
    """
    try:
        molecule = Chem.MolFromSmiles(smiles_string)
        if molecule is None:
            return None

        molecule = Chem.AddHs(molecule)

        # Node features: one row per atom
        atom_rows = [get_atom_features(atom) for atom in molecule.GetAtoms()]
        if not atom_rows:
            return None

        x = torch.tensor(atom_rows, dtype=torch.float)

        # Edges: both directions per bond, identical features for each
        directed_pairs = []
        directed_feats = []
        for bond in molecule.GetBonds():
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            directed_pairs.extend(([begin, end], [end, begin]))

            feats = get_bond_features(bond)
            directed_feats.extend((feats, feats))

        if not directed_pairs:
            # Bond-free molecule: placeholder self-loop so edge tensors are non-empty
            directed_pairs = [[0, 0]]
            directed_feats = [[0.0] * 4]

        edge_index = torch.tensor(directed_pairs, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(directed_feats, dtype=torch.float)

        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    except Exception:
        # Best-effort conversion: any failure marks the drug as unconvertible
        return None

In [None]:
# Preprocess and Cache Drug Graphs

def preprocess_drugs(drugs_df, drug_name_col, smiles_col, save_path):
    """
    Convert all drugs to graphs and cache them as a pickle file.

    Drugs whose SMILES cannot be converted (smiles_to_graph returns None) are
    left out of the result and reported in a one-line summary. The parent
    directory of save_path is created if it does not exist.

    Args:
        drugs_df: DataFrame with drug names and SMILES
        drug_name_col: Column name for drug names
        smiles_col: Column name for SMILES strings
        save_path: Path to save cached graphs

    Returns:
        Dictionary mapping drug names to a record with the keys 'drug_name',
        'smiles', 'graph_data', 'node_dim', 'edge_dim', 'num_atoms' and
        'num_bonds' ('num_bonds' counts directed edges, i.e. edge_index columns)
    """
    drug_graphs = {}
    failed_drugs = []

    # Iterate the two columns directly; building a Series per row is not needed.
    name_smiles_pairs = zip(drugs_df[drug_name_col], drugs_df[smiles_col])
    for drug_name, smiles in tqdm(name_smiles_pairs, total=len(drugs_df), desc="Converting SMILES"):
        graph = smiles_to_graph(smiles)

        if graph is None:
            failed_drugs.append(drug_name)
            continue

        drug_graphs[drug_name] = {
            'drug_name': drug_name,
            'smiles': smiles,
            'graph_data': graph,
            'node_dim': graph.x.shape[1],
            'edge_dim': graph.edge_attr.shape[1] if graph.edge_attr is not None else 0,
            'num_atoms': graph.x.shape[0],
            'num_bonds': graph.edge_index.shape[1]
        }

    # The cache directory is not created anywhere else, so make sure it exists
    # before writing; otherwise open() fails on a fresh checkout.
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, 'wb') as f:
        pickle.dump(drug_graphs, f)

    # Surface conversion failures instead of dropping them silently.
    if failed_drugs:
        print(f"{len(failed_drugs)} drug(s) could not be converted to graphs: {failed_drugs}")

    return drug_graphs

# Build (and cache) one molecular graph per drug; the first column of
# drugs_with_smiles is used as the drug-name column.
drug_graphs_path = project_root / "data" / "processed" / "drug_graphs.pkl"
drug_graphs = preprocess_drugs(drugs_with_smiles, drugs_with_smiles.columns[0], 'SMILES', drug_graphs_path)

# Filter pan_cancer_pathway to only include drugs that successfully converted to graphs
valid_drugs = set(drug_graphs.keys())
before_count = len(pan_cancer_pathway)
before_drugs = set(pan_cancer_pathway['DrugName'].unique())
pan_cancer_pathway = pan_cancer_pathway[pan_cancer_pathway['DrugName'].isin(valid_drugs)].copy()
after_count = len(pan_cancer_pathway)
removed_drugs = before_drugs - valid_drugs
# NOTE(review): this branch is a no-op, so removed drugs (and the row counts
# captured above) are never reported -- looks like a stripped log statement.
if len(removed_drugs) > 0:
    pass


### 2.2 Dataset Creation

Create pan-cancer, breast cancer, and TNBC-specific datasets. TNBC dataset excludes HER2+ and ER+ cell lines.


In [None]:
# Breast-lineage subset of the pan-cancer table.
breast_cancer_pathway = pan_cancer_pathway[
    pan_cancer_pathway['OncotreeLineage'] == 'Breast'
].copy()

# Breast cell lines treated as HER2-positive / ER-positive; every other breast
# cell line is kept in the TNBC subset.
her2_positive = ['SKBR3', 'HCC1419', 'HCC1954', 'HCC1569', 'AU565', 'JIMT1', 'BT474', 
                 'MDA-MB-453', 'UACC812', 'ZR7530', 'HCC2218', 'MDA-MB-361', 'EFM19']
er_positive = ['MCF7', 'T47D', 'ZR751', 'BT483', 'CAMA1', 'HCC1428', 'MDA-MB-415', 
               'MDA-MB-175VII', 'MDA-MB-134VI']

# Normalise names once (upper-case, '-' and '_' removed) and match with a single
# isin() instead of re-normalising the whole column for every listed cell line.
excluded_names = {
    cell_name.upper().replace('-', '').replace('_', '')
    for cell_name in her2_positive + er_positive
}
normalized_breast_names = (
    breast_cancer_pathway['StrippedCellLineName']
    .str.upper().str.replace('-', '').str.replace('_', '')
)
exclude_cells = set(
    breast_cancer_pathway.loc[normalized_breast_names.isin(excluded_names), 'ModelID'].unique()
)

tnbc_pathway = breast_cancer_pathway[~breast_cancer_pathway['ModelID'].isin(exclude_cells)].copy()


### 2.3 Data Splitting Strategy

**Cell-line based splitting:** Ensures no data leakage by splitting at the cell-line level rather than sample level. This prevents the same cell line from appearing in multiple splits.

Split ratios: 80% train, 10% validation, 10% test (by cell lines)


In [None]:
# PHASE 1.4: Data Splitting Configuration
# Set split_type to 'cell_line' for cell-line based split (no data leakage)
# Set split_type to 'random' for random split (standard train/test split)
# The value also selects the name of the saved split file
# (data_splits_<SPLIT_TYPE>.json); any other value raises ValueError below.
SPLIT_TYPE = 'cell_line'  # Options: 'cell_line' or 'random'

def cell_line_split(dataframe, train_ratio=0.8, val_ratio=0.1, random_seed=42):
    """
    Partition rows into train/val/test so that each cell line (ModelID)
    lands in exactly one partition (GPDRP method).

    The unique cell lines are shuffled with NumPy's global RNG (which is
    re-seeded here), cut by the given ratios, and every row follows its
    cell line. Whatever is left after train and val goes to test.

    Args:
        dataframe: DataFrame with ModelID column
        train_ratio: Proportion of cell lines for training
        val_ratio: Proportion of cell lines for validation
        random_seed: Random seed for reproducibility

    Returns:
        train_idx, val_idx, test_idx: Index-label arrays for each split
    """
    np.random.seed(random_seed)

    unique_cells = dataframe['ModelID'].unique()
    total_cells = len(unique_cells)
    cell_order = np.random.permutation(unique_cells)

    # Cut points on the shuffled cell-line list
    train_end = int(total_cells * train_ratio)
    val_end = train_end + int(total_cells * val_ratio)

    cell_groups = (
        set(cell_order[:train_end]),
        set(cell_order[train_end:val_end]),
        set(cell_order[val_end:]),
    )

    # Rows inherit the partition of their cell line
    model_ids = dataframe['ModelID']
    train_idx, val_idx, test_idx = (
        dataframe.index[model_ids.isin(group)].values for group in cell_groups
    )

    # Sanity check: the cell lines actually present in each split are disjoint
    present = {
        'train': set(dataframe.loc[train_idx, 'ModelID'].unique()),
        'val': set(dataframe.loc[val_idx, 'ModelID'].unique()),
        'test': set(dataframe.loc[test_idx, 'ModelID'].unique()),
    }
    for first, second in (('train', 'val'), ('train', 'test'), ('val', 'test')):
        assert not (present[first] & present[second]), f"Cell line overlap between {first} and {second}!"

    return train_idx, val_idx, test_idx

def random_split(dataframe, train_ratio=0.8, val_ratio=0.1, random_seed=42):
    """
    Sample-level random split (standard train/test split).
    Note: This allows data leakage as the same cell line can appear in multiple splits.

    Row positions are shuffled with NumPy's global RNG (which is re-seeded
    here) and cut by the given ratios; the remainder becomes the test set.

    Args:
        dataframe: DataFrame with samples
        train_ratio: Proportion of samples for training
        val_ratio: Proportion of samples for validation
        random_seed: Random seed for reproducibility

    Returns:
        train_idx, val_idx, test_idx: Index-label arrays for each split
    """
    np.random.seed(random_seed)

    order = np.arange(len(dataframe))
    np.random.shuffle(order)

    total = len(order)
    train_end = int(total * train_ratio)
    val_end = train_end + int(total * val_ratio)

    # Translate shuffled positions back into the DataFrame's index labels
    labels = dataframe.index
    train_idx = labels[order[:train_end]].values
    val_idx = labels[order[train_end:val_end]].values
    test_idx = labels[order[val_end:]].values

    return train_idx, val_idx, test_idx

# Split the pan-cancer table first; in 'cell_line' mode the breast and TNBC
# splits below are derived from it rather than drawn independently.
if SPLIT_TYPE == 'cell_line':
    pan_train_idx, pan_val_idx, pan_test_idx = cell_line_split(pan_cancer_pathway, random_seed=42)
elif SPLIT_TYPE == 'random':
    pan_train_idx, pan_val_idx, pan_test_idx = random_split(pan_cancer_pathway, random_seed=42)
else:
    raise ValueError(f"Unknown SPLIT_TYPE: {SPLIT_TYPE}. Must be 'cell_line' or 'random'")

# Extract cell lines for each global split (for cell-line split)
if SPLIT_TYPE == 'cell_line':
    pan_train_cells = set(pan_cancer_pathway.loc[pan_train_idx, 'ModelID'].unique())
    pan_val_cells = set(pan_cancer_pathway.loc[pan_val_idx, 'ModelID'].unique())
    pan_test_cells = set(pan_cancer_pathway.loc[pan_test_idx, 'ModelID'].unique())
    
    # Filter breast cancer data to only include cell lines from appropriate global split
    breast_train_data = breast_cancer_pathway[breast_cancer_pathway['ModelID'].isin(pan_train_cells)].copy()
    breast_val_data = breast_cancer_pathway[breast_cancer_pathway['ModelID'].isin(pan_val_cells)].copy()
    breast_test_data = breast_cancer_pathway[breast_cancer_pathway['ModelID'].isin(pan_test_cells)].copy()
    
    # Filter TNBC data similarly
    tnbc_train_data = tnbc_pathway[tnbc_pathway['ModelID'].isin(pan_train_cells)].copy()
    tnbc_val_data = tnbc_pathway[tnbc_pathway['ModelID'].isin(pan_val_cells)].copy()
    tnbc_test_data = tnbc_pathway[tnbc_pathway['ModelID'].isin(pan_test_cells)].copy()
    
    # Index labels are those of pan_cancer_pathway (the subsets are row filters
    # of it), so they address the same rows in every table.
    breast_train_idx = breast_train_data.index.values
    breast_val_idx = breast_val_data.index.values
    breast_test_idx = breast_test_data.index.values
    
    tnbc_train_idx = tnbc_train_data.index.values
    tnbc_val_idx = tnbc_val_data.index.values
    tnbc_test_idx = tnbc_test_data.index.values
    
    # Verify no data leakage: check that cell lines don't appear in conflicting splits
    breast_train_cells = set(breast_train_data['ModelID'].unique())
    breast_val_cells = set(breast_val_data['ModelID'].unique())
    breast_test_cells = set(breast_test_data['ModelID'].unique())
    
    tnbc_train_cells = set(tnbc_train_data['ModelID'].unique())
    tnbc_val_cells = set(tnbc_val_data['ModelID'].unique())
    tnbc_test_cells = set(tnbc_test_data['ModelID'].unique())
    
    # Verify: breast/TNBC train cells should only be in pan train, not pan val/test
    assert len(breast_train_cells & pan_val_cells) == 0, "Data leakage: breast train cells in pan val!"
    assert len(breast_train_cells & pan_test_cells) == 0, "Data leakage: breast train cells in pan test!"
    assert len(tnbc_train_cells & pan_val_cells) == 0, "Data leakage: TNBC train cells in pan val!"
    assert len(tnbc_train_cells & pan_test_cells) == 0, "Data leakage: TNBC train cells in pan test!"

    # Phase 2 (Breast) test vs Phase 1 (Pan) train/val
    assert len(breast_test_cells & pan_train_cells) == 0, "Data leakage: breast test cells in pan train!"
    assert len(breast_test_cells & pan_val_cells) == 0, "Data leakage: breast test cells in pan val!"
    assert len(breast_val_cells & pan_test_cells) == 0, "Data leakage: breast val cells in pan test!"

    # Phase 3 (TNBC) test vs Phase 1 (Pan) train/val
    assert len(tnbc_test_cells & pan_train_cells) == 0, "Data leakage: TNBC test cells in pan train!"
    assert len(tnbc_test_cells & pan_val_cells) == 0, "Data leakage: TNBC test cells in pan val!"
    assert len(tnbc_val_cells & pan_test_cells) == 0, "Data leakage: TNBC val cells in pan test!"

    # Phase 3 (TNBC) test vs Phase 2 (Breast) train/val
    assert len(tnbc_test_cells & breast_train_cells) == 0, "Data leakage: TNBC test cells in breast train!"
    assert len(tnbc_test_cells & breast_val_cells) == 0, "Data leakage: TNBC test cells in breast val!"

    # Phase 2 (Breast) test vs Phase 2 train/val (internal independence)
    assert len(breast_test_cells & breast_train_cells) == 0, "Data leakage: breast test cells in breast train!"
    assert len(breast_test_cells & breast_val_cells) == 0, "Data leakage: breast test cells in breast val!"

    # Phase 3 (TNBC) test vs Phase 3 train/val (internal independence)
    assert len(tnbc_test_cells & tnbc_train_cells) == 0, "Data leakage: TNBC test cells in TNBC train!"
    assert len(tnbc_test_cells & tnbc_val_cells) == 0, "Data leakage: TNBC test cells in TNBC val!"

    # Containment: every TNBC test cell line is also a breast test cell line
    assert len(tnbc_test_cells & breast_test_cells) == len(tnbc_test_cells), "TNBC test cells must be subset of breast test cells"

    # Containment: breast and TNBC test cell lines are subsets of the pan-cancer test cell lines
    assert len(pan_test_cells & breast_test_cells) == len(breast_test_cells), "Breast test cells must be subset of pan test cells"
    assert len(pan_test_cells & tnbc_test_cells) == len(tnbc_test_cells), "TNBC test cells must be subset of pan test cells"

    
elif SPLIT_TYPE == 'random':
    # For random split, directly split breast cancer and TNBC data
    # NOTE(review): these splits are drawn independently of the pan-cancer split
    # above, so a breast/TNBC test row can also be a pan-cancer training row.
    breast_train_idx, breast_val_idx, breast_test_idx = random_split(breast_cancer_pathway, random_seed=42)
    tnbc_train_idx, tnbc_val_idx, tnbc_test_idx = random_split(tnbc_pathway, random_seed=42)
    

# Save split indices to ensure reproducibility during evaluation
# splits_dir already defined in Cell 1
# Values are DataFrame index labels cast to int for JSON serialisation
# (assumes integer index labels, as produced by the merges above).
splits_file = splits_dir / f"data_splits_{SPLIT_TYPE}.json"
splits_data = {
    'split_type': SPLIT_TYPE,
    'pan_cancer': {
        'train': [int(x) for x in pan_train_idx],
        'val': [int(x) for x in pan_val_idx],
        'test': [int(x) for x in pan_test_idx]
    },
    'breast_cancer': {
        'train': [int(x) for x in breast_train_idx],
        'val': [int(x) for x in breast_val_idx],
        'test': [int(x) for x in breast_test_idx]
    },
    'tnbc': {
        'train': [int(x) for x in tnbc_train_idx],
        'val': [int(x) for x in tnbc_val_idx],
        'test': [int(x) for x in tnbc_test_idx]
    }
}

with open(splits_file, 'w') as f:
    json.dump(splits_data, f, indent=2)



## 3. Model Architecture

### 3.1 Dataset Class

Custom PyTorch Dataset class that handles pathway-based features with per-sample z-score normalization.


In [None]:

class DrugResponsePathwayDataset(Dataset):
    """
    Torch dataset pairing a drug graph with a cell line's pathway scores.

    Each item carries the drug's molecular graph, the cell line's pathway
    score vector (z-scored within that sample) and the LN_IC50 target.
    """

    def __init__(self, dataframe, drug_graphs_dict, drug_col='DRUG_NAME', pathway_cols=None):
        """
        Set up the dataset and decide which columns are pathway features.

        Args:
            dataframe: DataFrame with pathway scores, ModelID, the drug-name
                column and LN_IC50
            drug_graphs_dict: Dictionary mapping drug names to graph records
                (each record holds the graph under 'graph_data')
            drug_col: Column name for drug names
            pathway_cols: List of pathway column names; names missing from
                the DataFrame are ignored. If None, every numeric column
                that is not a known metadata column is used.

        Raises:
            ValueError: If a selected pathway column is not numeric
        """
        self.original_data = dataframe.copy()
        self.data = dataframe.reset_index(drop=True)
        self.drug_graphs = drug_graphs_dict
        self.drug_col = drug_col

        if pathway_cols is None:
            # No explicit list: take numeric columns minus the metadata ones
            metadata_cols = ['ModelID', 'COSMICID', 'StrippedCellLineName', 'OncotreeLineage',
                             'OncotreePrimaryDisease', drug_col, 'SMILES', 'LN_IC50', 'CellLineName']
            numeric_cols = dataframe.select_dtypes(include=[np.number]).columns.tolist()
            self.pathway_cols = [name for name in numeric_cols if name not in metadata_cols]
        else:
            # Explicit list: keep only the names actually present
            self.pathway_cols = [name for name in pathway_cols if name in dataframe.columns]

        # Every feature column must be numeric
        for col in self.pathway_cols:
            if pd.api.types.is_numeric_dtype(dataframe[col]):
                continue
            raise ValueError(f"Pathway column '{col}' is not numeric")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        drug_name = record[self.drug_col]

        if drug_name not in self.drug_graphs:
            raise ValueError(f"Drug {drug_name} not found in drug_graphs")

        scores = record[self.pathway_cols].values.astype(np.float32)

        # Z-score within this sample; a (near-)constant vector becomes all zeros
        mean = scores.mean()
        std = scores.std()
        if std > 1e-8:
            scores = (scores - mean) / std
        else:
            scores = np.zeros_like(scores)

        return {
            'drug_graph': self.drug_graphs[drug_name]['graph_data'],
            'cell_expr': torch.tensor(scores, dtype=torch.float32),  # Keep same key name for compatibility
            'ic50': torch.tensor([record['LN_IC50']], dtype=torch.float32),
            'drug_name': drug_name,
            'cell_id': record['ModelID']
        }


### 3.2 Drug Encoder

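Graph Neural Network encoder using TransformerConv layers to encode molecular structures. The same cell also defines the cell-line encoder (`CellEncoder`), a feedforward network with a skip connection over the pathway scores.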


In [None]:
# Drug Encoder Architecture

class DrugEncoder(nn.Module):
    """
    Graph Neural Network encoder for molecular SMILES structures.
    Uses four TransformerConv layers followed by a per-graph readout to
    generate 256-dim drug embeddings.
    """
    
    def __init__(self, node_feature_dim=9, edge_feature_dim=4, hidden_dim=256, dropout=0.3):
        """
        Initialize DrugEncoder.
        
        Args:
            node_feature_dim: Number of atom features
            edge_feature_dim: Number of bond features
            hidden_dim: Accepted for interface compatibility; the layer widths
                below are fixed and the output is always 256-dimensional
            dropout: Dropout rate (used both inside the TransformerConv
                attention and between layers)
        """
        super(DrugEncoder, self).__init__()
        
        # Device selection (MPS first, then CUDA, then CPU); the module moves
        # itself to this device at the end of __init__.
        if torch.backends.mps.is_available():
            self.device = torch.device('mps')
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        
        # TransformerConv layers; with concatenated heads the output width is
        # out_channels * heads, which is what each BatchNorm is sized for.
        self.conv1 = TransformerConv(node_feature_dim, 128, heads=4, dropout=dropout, edge_dim=edge_feature_dim)
        self.bn1 = nn.BatchNorm1d(128 * 4)
        
        self.conv2 = TransformerConv(128 * 4, 256, heads=8, dropout=dropout, edge_dim=edge_feature_dim)
        self.bn2 = nn.BatchNorm1d(256 * 8)
        
        self.conv3 = TransformerConv(256 * 8, 256, heads=8, dropout=dropout, edge_dim=edge_feature_dim)
        self.bn3 = nn.BatchNorm1d(256 * 8)
        
        # Last layer averages its heads (concat=False), giving 256 features per node.
        self.conv4 = TransformerConv(256 * 8, 256, heads=4, concat=False, dropout=dropout, edge_dim=edge_feature_dim)
        self.bn4 = nn.BatchNorm1d(256)
        
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        
        # Attention weights for pooling (one scalar score per node)
        self.attention_weights = nn.Linear(256, 1)
        
        self.to(self.device)
    
    def forward(self, data):
        """
        Forward pass through drug encoder.
        
        Args:
            data: PyTorch Geometric Data/Batch object with x, edge_index and
                edge_attr; a single un-batched Data (batch is None) is treated
                as one graph
            
        Returns:
            Drug embeddings of shape (num_graphs, 256)
        """
        # Local imports: these helpers are not part of the file-level imports.
        from torch_geometric.nn import global_add_pool
        from torch_geometric.utils import softmax as segment_softmax
        
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        
        x = x.to(self.device, non_blocking=True)
        edge_index = edge_index.to(self.device, non_blocking=True)
        edge_attr = edge_attr.to(self.device, non_blocking=True)
        if batch is None:
            # Un-batched input: every node belongs to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=self.device)
        else:
            batch = batch.to(self.device, non_blocking=True)
        
        # Layer 1
        x = self.conv1(x, edge_index, edge_attr)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        
        # Layer 2
        x = self.conv2(x, edge_index, edge_attr)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.dropout(x)
        
        # Layer 3
        x = self.conv3(x, edge_index, edge_attr)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.dropout(x)
        
        # Layer 4
        x = self.conv4(x, edge_index, edge_attr)
        x = self.bn4(x)
        x = self.relu(x)
        
        # Attention-weighted pooling.
        # The softmax must be taken within each graph: a plain softmax over
        # dim=0 normalises across every node in the batch, which makes a
        # molecule's weights depend on the other molecules batched with it.
        attention_scores = self.attention_weights(x)
        attention_scores = segment_softmax(attention_scores, batch)
        
        # Weights sum to 1 per graph, so the weighted readout is a sum, not a mean.
        x_attn = global_add_pool(x * attention_scores, batch)
        x_max = global_max_pool(x, batch)
        
        embedding = (x_attn + x_max) / 2
        
        return embedding

# Cell Encoder Architecture

class CellEncoder(nn.Module):
    """
    MLP that embeds a cell line's pathway-score vector.

    Two Linear -> BatchNorm -> ReLU -> Dropout stages form the main branch;
    a single linear projection of the raw input is added to its output so
    the direct pathway signal is preserved.
    """

    def __init__(self, input_dim=1329, hidden_dim=512, output_dim=256, dropout1=0.4, dropout2=0.3):
        """
        Build the encoder layers and move the module to its device.

        Args:
            input_dim: Number of pathway features
            hidden_dim: First hidden layer dimension
            output_dim: Output embedding dimension
            dropout1: Dropout rate for first layer
            dropout2: Dropout rate for second layer
        """
        super().__init__()

        # Preferred backend: MPS, then CUDA, then CPU
        if torch.backends.mps.is_available():
            backend = 'mps'
        elif torch.cuda.is_available():
            backend = 'cuda'
        else:
            backend = 'cpu'
        self.device = torch.device(backend)

        # Stage 1: input_dim -> hidden_dim
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout1 = nn.Dropout(dropout1)

        # Stage 2: hidden_dim -> output_dim
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.bn2 = nn.BatchNorm1d(output_dim)
        self.dropout2 = nn.Dropout(dropout2)

        # Linear shortcut from the raw input to the output space
        self.skip = nn.Linear(input_dim, output_dim)

        self.relu = nn.ReLU()

        self.to(self.device)

    def forward(self, x):
        """
        Encode a batch of pathway-score vectors.

        Args:
            x: Pathway scores tensor of shape (batch_size, input_dim)

        Returns:
            Cell embeddings of shape (batch_size, output_dim)
        """
        inputs = x.to(self.device, non_blocking=True)

        # Main branch
        hidden = self.dropout1(self.relu(self.bn1(self.fc1(inputs))))
        hidden = self.dropout2(self.relu(self.bn2(self.fc2(hidden))))

        # Add the linear shortcut
        return hidden + self.skip(inputs)


### 3.3 Full Model Architecture

**DrugResponsePathwayGNN** combines:
- Drug encoder: GNN for molecular structure
- Cell encoder: MLP for pathway features
- Multi-task learning: IC50 regression + binary classification + reconstruction


In [None]:
# PHASE 4: DrugResponsePathwayGNN Model

class DrugResponsePathwayGNN(nn.Module):
    """
    Multi-task drug-response model over drug graphs and pathway scores.

    A DrugEncoder embeds the molecular graph and a CellEncoder embeds the
    pathway activity vector (1,329 features by default, rather than raw gene
    expression). The two embeddings are fused and passed to three heads:
    IC50 regression, sensitivity classification and reconstruction.
    """

    def __init__(self, drug_node_dim=9, drug_edge_dim=4, cell_input_dim=1329,
                 hidden_dim=256, dropout=0.3):
        """
        Build the encoders, the fusion layers and the three prediction heads.

        Args:
            drug_node_dim: Number of atom features (default 9)
            drug_edge_dim: Number of bond features (default 4)
            cell_input_dim: Number of pathway features (default 1329)
            hidden_dim: Width of each encoder's embedding (default 256)
            dropout: Dropout rate forwarded to the drug encoder
        """
        super().__init__()

        # Device preference: Apple MPS first, then CUDA, otherwise CPU.
        if torch.backends.mps.is_available():
            device_name = 'mps'
        elif torch.cuda.is_available():
            device_name = 'cuda'
        else:
            device_name = 'cpu'
        self.device = torch.device(device_name)

        # NOTE: submodules are created in a fixed order; reordering them
        # would change random weight initialisation for a given seed.

        # Molecular graph encoder
        self.drug_encoder = DrugEncoder(
            node_feature_dim=drug_node_dim,
            edge_feature_dim=drug_edge_dim,
            hidden_dim=hidden_dim,
            dropout=dropout,
        )

        # Pathway-score encoder; its dropout rates are fixed here and do not
        # follow the `dropout` argument.
        self.cell_encoder = CellEncoder(
            input_dim=cell_input_dim,
            hidden_dim=512,
            output_dim=hidden_dim,
            dropout1=0.4,
            dropout2=0.3,
        )

        # Projections used by integrate_with_attention
        self.attention_query = nn.Linear(hidden_dim, hidden_dim)
        self.attention_key = nn.Linear(hidden_dim, hidden_dim)
        self.attention_value = nn.Linear(hidden_dim, hidden_dim)

        # Drug embedding concatenated with attended cell embedding
        combined_dim = 2 * hidden_dim

        # Head A: IC50 regression (primary task)
        self.ic50_head = nn.Sequential(
            nn.Linear(combined_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1),
        )

        # Head B: sensitivity classification (auxiliary task, emits one logit)
        self.classification_head = nn.Sequential(
            nn.Linear(combined_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
        )

        # Head C: reconstruction (auxiliary task, output width == combined_dim)
        self.reconstruction_head = nn.Sequential(
            nn.Linear(combined_dim, 256),
            nn.ReLU(),
            nn.Linear(256, combined_dim),
        )

        self.to(self.device)

    def integrate_with_attention(self, drug_emb, cell_emb):
        """
        Fuse drug and cell embeddings (drug as query, cell as key/value).

        Args:
            drug_emb: Drug embeddings (batch_size, hidden_dim)
            cell_emb: Cell embeddings (batch_size, hidden_dim)

        Returns:
            Combined embedding (batch_size, 2 * hidden_dim)
        """
        q = self.attention_query(drug_emb)   # (batch, hidden_dim)
        k = self.attention_key(cell_emb)     # (batch, hidden_dim)
        v = self.attention_value(cell_emb)   # (batch, hidden_dim)

        # One scaled dot product per sample -> (batch, 1, 1)
        scale = drug_emb.size(-1) ** 0.5
        scores = torch.matmul(q.unsqueeze(1), k.unsqueeze(2)) / scale

        # NOTE(review): the softmax axis has length 1, so every weight is
        # exactly 1.0. The attended embedding therefore equals `v`, and the
        # query/key projections receive zero gradient. Replacing this with a
        # real gate would change predictions of existing checkpoints, so it
        # is left as-is here and should be decided deliberately.
        weights = torch.softmax(scores, dim=-1)

        attended = weights.squeeze(-1) * v   # (batch, 1) * (batch, hidden_dim)

        return torch.cat([drug_emb, attended], dim=1)

    def forward(self, drug_batch, cell_batch, return_embeddings=False):
        """
        Run both encoders, fuse their outputs and evaluate all three heads.

        Args:
            drug_batch: PyTorch Geometric batch of molecular graphs
            cell_batch: Pathway scores tensor (batch_size, cell_input_dim)
            return_embeddings: Whether to include the intermediate embeddings

        Returns:
            Dictionary with keys 'ic50', 'classification' and
            'reconstruction'; when return_embeddings is True it also has
            'embeddings' holding the 'drug', 'cell' and 'combined' tensors.
        """
        drug_emb = self.drug_encoder(drug_batch)
        cell_emb = self.cell_encoder(cell_batch)

        combined = self.integrate_with_attention(drug_emb, cell_emb)

        results = {
            'ic50': self.ic50_head(combined),
            'classification': self.classification_head(combined),
            'reconstruction': self.reconstruction_head(combined),
        }

        if return_embeddings:
            results['embeddings'] = {
                'drug': drug_emb,
                'cell': cell_emb,
                'combined': combined,
            }

        return results


## 4. Cross-Validation Framework

### 4.1 Cell-Line Based K-Fold Splitting

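Implements k-fold cross-validation at the cell-line level: each cell line is assigned to exactly one validation fold, so no cell line appears in both the training and validation sets of a fold. This prevents data leakage while enabling robust performance estimation.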


In [None]:
# PHASE 5.5: Cross-Validation Support

def cell_line_kfold_split(dataframe, n_splits=5, random_seed=42):
    """
    Create k-fold splits at the cell-line level (no data leakage).

    Unique ``ModelID`` values are shuffled once and cut into ``n_splits``
    contiguous groups; every row of a cell line lands on the same side of a
    fold.

    Args:
        dataframe: DataFrame with ModelID column
        n_splits: Number of folds
        random_seed: Random seed for reproducibility

    Returns:
        List of (train_idx, val_idx) tuples for each fold. The arrays hold
        index *labels* of ``dataframe`` (suitable for ``.loc``), not
        positions.

    Raises:
        ValueError: Raised by KFold when n_splits is invalid for the number
            of unique cell lines.
    """
    cell_lines = dataframe['ModelID'].unique()

    # A private RandomState produces the same permutation as seeding the
    # global NumPy RNG with `random_seed`, but does not reset the global RNG
    # state for the rest of the program.
    rng = np.random.RandomState(random_seed)
    shuffled_cells = rng.permutation(cell_lines)

    # Cells are already shuffled, so KFold only has to cut contiguous chunks.
    kf = KFold(n_splits=n_splits, shuffle=False)
    model_ids = dataframe['ModelID']

    fold_splits = []
    for train_cell_idx, val_cell_idx in kf.split(np.arange(len(cell_lines))):
        train_mask = model_ids.isin(shuffled_cells[train_cell_idx]).to_numpy()
        val_mask = model_ids.isin(shuffled_cells[val_cell_idx]).to_numpy()

        fold_splits.append((
            dataframe.index[train_mask].values,
            dataframe.index[val_mask].values,
        ))

    return fold_splits

def train_phase_pathway_cv(
    dataset,
    train_idx_all,
    val_idx_all,
    test_idx,
    model_class,
    model_kwargs,
    device,
    phase_name,
    n_folds=5,
    num_epochs=50,
    lr=1e-3,
    weight_decay=1e-4,
    patience=10,
    scheduler_patience=3,
    batch_size=256,
    use_prebatched=False,  # CV folds have different splits - use regular batching
    random_seed=42
):
    """
    Train a phase with k-fold cross-validation.

    Train and validation indices are pooled, re-split by cell line into
    ``n_folds`` folds, and a fresh model is trained per fold. Per-fold
    validation metrics are aggregated and the fold with the highest
    validation R² is returned as the best model.

    Reads the notebook-level globals ``output_dir``, ``models_dir`` and
    ``SPLIT_TYPE``; writes one checkpoint per fold and a pickle with the CV
    summary.

    Args:
        dataset: DrugResponsePathwayDataset
        train_idx_all: All training indices (will be split into CV folds)
        val_idx_all: All validation indices (will be split into CV folds)
        test_idx: Test indices (only forwarded to the loader factory; the
            test loader is never evaluated here)
        model_class: Model class to instantiate
        model_kwargs: Keyword arguments for model initialization
        device: Device to train on
        phase_name: Name of the phase
        n_folds: Number of CV folds
        num_epochs: Maximum epochs per fold
        lr: Learning rate
        weight_decay: Weight decay
        patience: Early stopping patience
        scheduler_patience: LR scheduler patience
        batch_size: Batch size
        use_prebatched: Accepted for interface compatibility but ignored;
            folds always use regular batching
        random_seed: Random seed for reproducibility

    Returns:
        Dictionary with keys 'best_model', 'best_checkpoint', 'cv_results',
        'aggregated' and 'all_models'.
    """
    phase_slug = phase_name.lower().replace(' ', '_')

    # Combine train and val for CV splitting
    train_val_data = dataset.original_data.loc[np.concatenate([train_idx_all, val_idx_all])]

    cv_splits = cell_line_kfold_split(train_val_data, n_splits=n_folds, random_seed=random_seed)

    cv_results = []
    fold_models = []
    fold_checkpoints = []

    for fold_number, (fold_train_idx, fold_val_idx) in enumerate(cv_splits, start=1):

        fold_loaders = create_pathway_dataloaders(
            dataset, fold_train_idx, fold_val_idx, test_idx,
            batch_size=batch_size,
            use_prebatched=False,  # Regular batching for CV folds
            phase_name=f"{phase_slug}_fold{fold_number}",
            output_dir=output_dir
        )

        model = model_class(**model_kwargs).to(device)

        fold_checkpoint = models_dir / f"{phase_slug}_fold{fold_number}_{SPLIT_TYPE}.pt"
        fold_checkpoints.append(fold_checkpoint)

        # Train model
        model, history = train_phase_pathway(
            model=model,
            train_loader=fold_loaders['train'],
            val_loader=fold_loaders['val'],
            device=device,
            phase_name=f"{phase_name} - Fold {fold_number}",
            num_epochs=num_epochs,
            lr=lr,
            weight_decay=weight_decay,
            patience=patience,
            scheduler_patience=scheduler_patience,
            checkpoint_path=fold_checkpoint,
            prebatched_datasets=fold_loaders.get('datasets', None)
        )

        val_results = evaluate_model(model, fold_loaders['val'], device, is_prebatched=False)

        cv_results.append({
            'fold': fold_number,
            'val_metrics': val_results['metrics'],
            'history': history,
            'checkpoint': str(fold_checkpoint)
        })
        fold_models.append(model)

    # Aggregate results across folds
    val_metrics_list = [r['val_metrics'] for r in cv_results]

    def calc_stats(metrics_list, metric_name):
        """Summarise one metric across folds (mean, sample std, SEM, 95% CI)."""
        values = [m[metric_name] for m in metrics_list]
        n_values = len(values)
        mean_val = np.mean(values)
        if n_values > 1:
            # Sample standard deviation (ddof=1): the t-interval below uses
            # n-1 degrees of freedom, so the population std would make the
            # SEM and the confidence interval too narrow.
            std_val = np.std(values, ddof=1)
            sem = std_val / np.sqrt(n_values)  # Standard error of the mean
            ci_95 = stats.t.interval(0.95, n_values - 1, loc=mean_val, scale=sem)
        else:
            # Spread cannot be estimated from a single fold.
            std_val = float('nan')
            sem = float('nan')
            ci_95 = (float('nan'), float('nan'))
        return {
            'mean': mean_val,
            'std': std_val,
            'sem': sem,
            'ci_95_lower': ci_95[0],
            'ci_95_upper': ci_95[1],
            'min': np.min(values),
            'max': np.max(values),
            'values': values
        }

    metric_names = ['r2', 'pearson', 'spearman', 'rmse', 'mae']
    aggregated = {
        'val': {name: calc_stats(val_metrics_list, name) for name in metric_names},
    }

    for metric in ['r2', 'pearson', 'rmse']:
        stats_dict = aggregated['val'][metric]
        print(f"  {metric.upper():10s}: {stats_dict['mean']:.4f} ± {stats_dict['std']:.4f} "
              f"[95% CI: {stats_dict['ci_95_lower']:.4f}, {stats_dict['ci_95_upper']:.4f}]")

    # Find best fold (by validation R²). np.argmax would select a NaN entry,
    # so NaN folds are skipped unless every fold is NaN.
    fold_r2 = np.array([r['val_metrics']['r2'] for r in cv_results], dtype=float)
    if np.all(np.isnan(fold_r2)):
        best_fold_idx = 0
    else:
        best_fold_idx = int(np.nanargmax(fold_r2))
    best_model = fold_models[best_fold_idx]
    best_checkpoint = fold_checkpoints[best_fold_idx]

    cv_results_file = output_dir / f"{phase_slug}_cv_results_{SPLIT_TYPE}.pkl"
    with open(cv_results_file, 'wb') as f:
        pickle.dump({
            'phase_name': phase_name,
            'n_folds': n_folds,
            'cv_results': cv_results,
            'aggregated': aggregated,
            'best_fold': best_fold_idx + 1,
            'best_checkpoint': str(best_checkpoint)
        }, f)

    return {
        'best_model': best_model,
        'best_checkpoint': best_checkpoint,
        'cv_results': cv_results,
        'aggregated': aggregated,
        'all_models': fold_models
    }


### 4.2 Training and Evaluation Functions

Training functions with early stopping, learning rate scheduling, and comprehensive evaluation metrics.


In [None]:
# PHASE 5: Training Functions for Pathway-Based Model

def collate_fn(batch):
    """
    Merge per-sample dicts into one batch dict.

    Each sample must provide 'drug_graph', 'cell_expr' and 'ic50'. The
    tensors are stacked along a new leading dimension and the graphs are
    merged into a single PyTorch Geometric Batch.
    """
    graphs = [sample['drug_graph'] for sample in batch]
    expr_rows = [sample['cell_expr'] for sample in batch]
    ic50_rows = [sample['ic50'] for sample in batch]

    stacked_expr = torch.stack(expr_rows)
    stacked_ic50 = torch.stack(ic50_rows)
    merged_graphs = Batch.from_data_list(graphs)

    return {
        'drug_batch': merged_graphs,
        'cell_batch': stacked_expr,
        'ic50': stacked_ic50,
    }

def prebatched_collate_fn(batch):
    """
    Unwrap a pre-batched item delivered by a DataLoader with batch_size=1.

    The loader hands over a one-element list whose single entry is already a
    complete batch, so that entry is returned untouched.
    """
    ready_batch = batch[0]
    return ready_batch

class PrebatchedDataset(Dataset):
    """
    Dataset for pre-batched data with epoch-level batch shuffling.

    Every item is a whole batch. ``current_order`` maps the position a
    loader asks for to the stored batch, so reshuffling only permutes that
    index list and never touches the batches themselves.
    """
    def __init__(self, prebatched_batches):
        """
        Initialize pre-batched dataset.

        Args:
            prebatched_batches: List of pre-batched dictionaries
        """
        self.batches = prebatched_batches
        self.current_order = list(range(len(self.batches)))

    def __len__(self):
        """Return the number of stored batches."""
        return len(self.batches)

    def __getitem__(self, idx):
        """Return the batch currently scheduled at position ``idx``."""
        return self.batches[self.current_order[idx]]

    def shuffle_batches(self, random_seed=None):
        """
        Shuffle batch order for epoch-level randomization.

        The shuffle is applied in place to the current order, so repeated
        calls compose.

        Args:
            random_seed: Optional random seed for reproducibility. When
                given, a private RandomState is used: the resulting order is
                the same as seeding the global NumPy RNG with this value,
                but the global RNG is left untouched. When None, the global
                NumPy RNG is used.
        """
        if random_seed is None:
            np.random.shuffle(self.current_order)
        else:
            np.random.RandomState(random_seed).shuffle(self.current_order)

def create_prebatched_data(dataset, batch_size, split_name, phase_name, output_dir, shuffle_samples=True):
    """
    Pre-compute and save batched data to disk.

    The result is cached at
    ``<output_dir>/prebatched_data/<phase_name>_<SPLIT_TYPE>/<split_name>_batches.pkl``.
    If that file already exists it is returned as-is.

    NOTE: the cache key is only phase, split type and split name. An
    existing file is reused even if ``dataset``, ``batch_size`` or
    ``shuffle_samples`` differ from the run that produced it; delete the
    file to force a rebuild.

    Args:
        dataset: DrugResponsePathwayDataset
        batch_size: Batch size
        split_name: Name of split (train/val/test)
        phase_name: Name of training phase (phase1/phase2/phase3)
        output_dir: Directory to save pre-batched data
        shuffle_samples: Whether to shuffle samples before batching

    Returns:
        Path to saved batch file
    """
    output_dir = Path(output_dir)
    # Include split type in path to avoid overwriting when switching between random and cell_line splits
    split_type = globals().get('SPLIT_TYPE', 'unknown')
    prebatched_dir = output_dir / "prebatched_data" / f"{phase_name}_{split_type}"
    prebatched_dir.mkdir(parents=True, exist_ok=True)

    batch_file = prebatched_dir / f"{split_name}_batches.pkl"

    if batch_file.exists():
        return batch_file

    indices = np.arange(len(dataset))
    if shuffle_samples:
        # Fixed seed keeps the sample-to-batch assignment reproducible. A
        # private RandomState gives the same order as seeding the global
        # NumPy RNG with 42, without resetting the global RNG.
        np.random.RandomState(42).shuffle(indices)

    batches = []
    for start in tqdm(range(0, len(indices), batch_size), desc=f"Pre-batching {split_name}"):
        batch_items = [dataset[idx] for idx in indices[start:start + batch_size]]
        # collate_fn produces the same {'drug_batch', 'cell_batch', 'ic50'}
        # dict that the regular DataLoader path yields.
        batches.append(collate_fn(batch_items))

    # Write to a temporary file and rename it into place, so an interrupted
    # run cannot leave a truncated pickle that the existence check above
    # would later treat as a valid cache.
    tmp_file = batch_file.with_name(batch_file.name + ".tmp")
    with open(tmp_file, 'wb') as f:
        pickle.dump(batches, f)
    tmp_file.replace(batch_file)

    return batch_file

def compute_loss(predictions, targets, median_ic50, loss_weights=(1.0, 0.3, 0.3)):
    """
    Combine the three task losses into one weighted objective.

    Args:
        predictions: Dictionary with 'ic50', 'classification', 'reconstruction'
        targets: Dictionary with 'ic50', 'embeddings'
        median_ic50: IC50 threshold; targets above it form the positive class
        loss_weights: Weights for (IC50, classification, reconstruction)

    Returns:
        Tuple of (total loss tensor, dict of float losses keyed 'mse',
        'bce' and 'recon')
    """
    w_reg, w_cls, w_rec = loss_weights[0], loss_weights[1], loss_weights[2]
    true_ic50 = targets['ic50'].squeeze()

    # Primary task: regression on IC50
    regression = nn.functional.mse_loss(predictions['ic50'].squeeze(), true_ic50)

    # Auxiliary task: is the target above the threshold? (head emits logits)
    labels = (true_ic50 > median_ic50).float()
    classification = nn.functional.binary_cross_entropy_with_logits(
        predictions['classification'].squeeze(), labels
    )

    # Auxiliary task: reconstruct the supplied embedding
    reconstruction = nn.functional.mse_loss(
        predictions['reconstruction'], targets['embeddings']
    )

    total_loss = w_reg * regression + w_cls * classification + w_rec * reconstruction

    components = {
        'mse': regression.item(),
        'bce': classification.item(),
        'recon': reconstruction.item(),
    }
    return total_loss, components

def evaluate_model(model, dataloader, device, median_ic50=None, is_prebatched=False):
    """
    Score a model on every batch of a dataloader.

    Args:
        model: Trained model
        dataloader: DataLoader to evaluate on
        device: Device to run on
        median_ic50: Classification threshold; when None it is taken as the
            median IC50 over the dataloader (costs one extra pass)
        is_prebatched: Whether using pre-batched data (batch_size=1); not
            consulted here because batches arrive already unwrapped

    Returns:
        Dictionary with 'metrics' (loss, pearson, spearman, r2, rmse, mae),
        'predictions' and 'targets'. An empty dataloader yields loss 0.0,
        NaN for the other metrics and empty arrays.

    Raises:
        ValueError: If median_ic50 is None and the dataloader is empty.
    """
    model.eval()

    if median_ic50 is None:
        observed = [batch['ic50'].cpu().numpy() for batch in dataloader]
        if not observed:
            raise ValueError("Cannot compute median_ic50: dataloader is empty")
        median_ic50 = np.median(np.concatenate(observed))

    pred_chunks = []
    target_chunks = []
    running_loss = 0

    with torch.no_grad():
        for batch in dataloader:
            drug_batch = batch['drug_batch'].to(device)
            cell_batch = batch['cell_batch'].to(device)
            ic50_target = batch['ic50'].to(device)

            outputs = model(drug_batch, cell_batch, return_embeddings=True)

            # The reconstruction head is scored against the model's own
            # combined embedding.
            loss_targets = {
                'ic50': ic50_target,
                'embeddings': outputs['embeddings']['combined'],
            }
            batch_loss, _ = compute_loss(outputs, loss_targets, median_ic50,
                                         loss_weights=(1.0, 0.3, 0.3))
            running_loss += batch_loss.item()

            pred_chunks.append(outputs['ic50'].cpu().numpy())
            target_chunks.append(ic50_target.cpu().numpy())

    # Nothing was evaluated: report placeholder metrics instead of failing.
    if not pred_chunks:
        nan_metrics = {name: np.nan for name in ('pearson', 'spearman', 'r2', 'rmse', 'mae')}
        return {
            'metrics': {'loss': 0.0, **nan_metrics},
            'predictions': np.array([]),
            'targets': np.array([]),
        }

    all_preds = np.concatenate(pred_chunks).flatten()
    all_targets = np.concatenate(target_chunks).flatten()

    pearson_r = pearsonr(all_targets, all_preds)[0]
    spearman_r = spearmanr(all_targets, all_preds)[0]

    n_batches = len(dataloader)
    metrics = {
        # Mean of per-batch losses (not weighted by batch size)
        'loss': running_loss / n_batches if n_batches > 0 else 0.0,
        'pearson': pearson_r,
        'spearman': spearman_r,
        'r2': r2_score(all_targets, all_preds),
        'rmse': np.sqrt(mean_squared_error(all_targets, all_preds)),
        'mae': mean_absolute_error(all_targets, all_preds),
    }

    return {
        'metrics': metrics,
        'predictions': all_preds,
        'targets': all_targets,
    }

def create_pathway_dataloaders(dataset, train_idx, val_idx, test_idx, batch_size=128, 
                                use_prebatched=False, phase_name=None, output_dir=None):
    """
    Build train/val/test DataLoaders from pre-defined index splits.

    Args:
        dataset: DrugResponsePathwayDataset
        train_idx: Training indices (labels of dataset.original_data)
        val_idx: Validation indices (labels of dataset.original_data)
        test_idx: Test indices (labels of dataset.original_data)
        batch_size: Batch size
        use_prebatched: Whether to use pre-batched data
        phase_name: Name of training phase (required if use_prebatched=True)
        output_dir: Directory for pre-batched data (required if use_prebatched=True)

    Returns:
        Dictionary with 'train', 'val' and 'test' loaders; in pre-batched
        mode it also carries 'datasets' with the three PrebatchedDataset
        objects.

    Raises:
        ValueError: If use_prebatched is True and phase_name or output_dir
            is missing.
    """
    # Slice the source frame per split, then wrap each slice in its own dataset.
    frames = [
        dataset.original_data.loc[idx].reset_index(drop=True)
        for idx in (train_idx, val_idx, test_idx)
    ]
    train_dataset, val_dataset, test_dataset = (
        DrugResponsePathwayDataset(frame, dataset.drug_graphs, dataset.drug_col,
                                   pathway_cols=dataset.pathway_cols)
        for frame in frames
    )

    # Worker processes and pinned memory are only enabled when CUDA is present.
    has_cuda = torch.cuda.is_available()
    worker_count = 4 if has_cuda else 0

    def build_loader(source, per_batch, shuffle, collate):
        return DataLoader(
            source,
            batch_size=per_batch,
            shuffle=shuffle,
            collate_fn=collate,
            num_workers=worker_count,
            pin_memory=has_cuda,
        )

    if not use_prebatched:
        # Standard dataloaders: only the training split is shuffled.
        return {
            'train': build_loader(train_dataset, batch_size, True, collate_fn),
            'val': build_loader(val_dataset, batch_size, False, collate_fn),
            'test': build_loader(test_dataset, batch_size, False, collate_fn),
        }

    if phase_name is None or output_dir is None:
        raise ValueError("phase_name and output_dir required when use_prebatched=True")

    # (split name, dataset, shuffle samples before batching?)
    split_specs = (
        ('train', train_dataset, True),
        ('val', val_dataset, False),
        ('test', test_dataset, False),
    )
    batch_files = {
        name: create_prebatched_data(split_ds, batch_size, name, phase_name, output_dir,
                                     shuffle_samples=shuffle_first)
        for name, split_ds, shuffle_first in split_specs
    }

    prebatched = {}
    for name, batch_path in batch_files.items():
        with open(batch_path, 'rb') as f:
            prebatched[name] = PrebatchedDataset(pickle.load(f))

    # Each item is already a full batch, hence batch_size=1 and no loader-level
    # shuffling; the training order is permuted via shuffle_batches instead.
    loaders = {
        name: build_loader(split_ds, 1, False, prebatched_collate_fn)
        for name, split_ds in prebatched.items()
    }
    loaders['datasets'] = prebatched
    return loaders

def _evaluate_classification_head(model, dataloader, device, threshold):
    """Return accuracy and ROC-AUC of the classification head on a dataloader."""
    from sklearn.metrics import roc_auc_score

    model.eval()
    prob_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch in dataloader:
            drug_batch = batch['drug_batch'].to(device)
            cell_batch = batch['cell_batch'].to(device)
            ic50_target = batch['ic50'].to(device)

            outputs = model(drug_batch, cell_batch)
            # reshape(-1) keeps a 1-D array even for a single-sample batch,
            # where squeeze() would collapse to a 0-d scalar.
            logits = outputs['classification'].reshape(-1)
            prob_chunks.append(torch.sigmoid(logits).cpu().numpy())
            label_chunks.append(
                (ic50_target.reshape(-1).cpu().numpy() > threshold).astype(int)
            )

    if not prob_chunks:
        return {'accuracy': float('nan'), 'roc_auc': float('nan')}

    probs = np.concatenate(prob_chunks)
    labels = np.concatenate(label_chunks)
    preds = (probs > 0.5).astype(int)

    try:
        roc_auc = float(roc_auc_score(labels, probs))
    except ValueError:
        # roc_auc_score rejects inputs it cannot score (e.g. one class only).
        roc_auc = float('nan')

    return {'accuracy': float((preds == labels).mean()), 'roc_auc': roc_auc}


def train_phase_pathway(
    model,
    train_loader,
    val_loader,
    device,
    phase_name,
    num_epochs=50,
    lr=1e-3,
    weight_decay=1e-4,
    patience=10,
    scheduler_patience=3,
    checkpoint_path=None,
    prebatched_datasets=None
):
    """
    Train a phase of the pathway-based model.

    Uses AdamW with gradient-norm clipping at 1.0, a ReduceLROnPlateau
    scheduler driven by validation R², and early stopping on validation R².
    The classification threshold is the median IC50 of the training loader.

    If checkpoint_path is given and the file exists after training, the
    checkpointed weights are reloaded and a summary is printed. Without a
    checkpoint the returned model holds the weights of the last epoch run.

    NOTE: a file already present at checkpoint_path is reloaded even when
    no epoch of this run improved on it.

    Args:
        model: DrugResponsePathwayGNN model
        train_loader: Training DataLoader
        val_loader: Validation DataLoader
        device: Device to train on
        phase_name: Name of the phase (kept for the interface; not used here)
        num_epochs: Maximum epochs
        lr: Learning rate
        weight_decay: Weight decay
        patience: Early stopping patience
        scheduler_patience: LR scheduler patience
        checkpoint_path: Path to save best model
        prebatched_datasets: Dictionary with 'train' PrebatchedDataset (for epoch-level shuffling)

    Returns:
        Trained model and training history (lists under 'train_loss',
        'val_loss', 'val_r2' and 'val_pearson', one entry per epoch)

    Raises:
        ValueError: If train_loader yields no batches.
    """
    model = model.to(device)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.7, patience=scheduler_patience)  # Less aggressive: 0.7 instead of 0.5

    is_prebatched = prebatched_datasets is not None and 'train' in prebatched_datasets

    # One pass over the training data to fix the classification threshold.
    train_ic50s = [batch['ic50'].cpu().numpy() for batch in train_loader]
    if not train_ic50s:
        raise ValueError("Cannot train: train_loader is empty")
    median_ic50 = np.median(np.concatenate(train_ic50s))

    best_val_r2 = float('-inf')
    patience_counter = 0
    history = {'train_loss': [], 'val_loss': [], 'val_r2': [], 'val_pearson': []}

    for epoch in range(num_epochs):
        # Shuffle batches at start of each epoch (for pre-batched data)
        if is_prebatched:
            prebatched_datasets['train'].shuffle_batches(random_seed=epoch)

        model.train()
        train_loss = 0
        for batch in train_loader:
            # Batch is already unwrapped by prebatched_collate_fn if using pre-batched data
            drug_batch = batch['drug_batch'].to(device)
            cell_batch = batch['cell_batch'].to(device)
            ic50_target = batch['ic50'].to(device)

            outputs = model(drug_batch, cell_batch, return_embeddings=True)
            combined_emb = outputs['embeddings']['combined']

            targets = {
                'ic50': ic50_target,
                # Detached: the reconstruction target must not receive gradients.
                'embeddings': combined_emb.detach()
            }

            loss, _ = compute_loss(outputs, targets, median_ic50, loss_weights=(1.0, 0.3, 0.3))

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item()

        train_loss /= len(train_loader)

        eval_result = evaluate_model(model, val_loader, device, median_ic50, is_prebatched=is_prebatched)
        val_metrics = eval_result['metrics']
        val_r2 = val_metrics['r2']

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_metrics['loss'])
        history['val_r2'].append(val_r2)
        history['val_pearson'].append(val_metrics['pearson'])

        scheduler.step(val_r2)

        if val_r2 > best_val_r2:
            best_val_r2 = val_r2
            patience_counter = 0
            if checkpoint_path:
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'val_metrics': val_metrics,
                    'classification_threshold': median_ic50,
                    'epoch': epoch + 1
                }, checkpoint_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                break

    if checkpoint_path and Path(checkpoint_path).exists():
        # Restore the best weights seen on validation.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        best_val_metrics = checkpoint.get('val_metrics', {})

        # Evaluate classification performance of the restored model
        class_metrics = _evaluate_classification_head(model, val_loader, device, median_ic50)

        print(f"Final Training Results:")
        print(f"{'='*60}")
        if best_val_metrics:
            print(f"  R²:        {best_val_metrics.get('r2', 0):.4f}")
            print(f"  Pearson:   {best_val_metrics.get('pearson', 0):.4f}")
            print(f"  Spearman:  {best_val_metrics.get('spearman', 0):.4f}")
            print(f"  RMSE:      {best_val_metrics.get('rmse', 0):.4f}")
            print(f"  MAE:       {best_val_metrics.get('mae', 0):.4f}")
        else:
            print(f"  R²:        {best_val_r2:.4f}")
        print(f"  Accuracy:  {class_metrics['accuracy']:.4f}")
        print(f"  ROC-AUC:   {class_metrics['roc_auc']:.4f}")
        # History is empty when num_epochs == 0; skip the loss lines then.
        if history['train_loss']:
            print(f"\nTraining Loss: {history['train_loss'][-1]:.4f}")
            print(f"Validation Loss: {history['val_loss'][-1]:.4f}")
        print(f"{'='*60}\n")

    return model, history


## 5. Model Training

### 5.1 Data Loader Preparation

Create data loaders for all three phases with pre-batched data for efficiency.


In [None]:
# PHASE 6: 3-Phase Transfer Learning Pipeline

# Device preference: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Checkpoints live in a "models" folder under the shared output directory.
models_dir = output_dir / "models"
models_dir.mkdir(exist_ok=True)

# Checkpoint names carry the split type so runs with different split
# strategies do not overwrite each other.
split_type = globals().get('SPLIT_TYPE', 'unknown')
phase1_checkpoint = models_dir / f"trial2_phase1_pathway_{split_type}.pt"
phase2_checkpoint = models_dir / f"trial2_phase2_breast_pathway_{split_type}.pt"
phase3_checkpoint = models_dir / f"trial2_phase3_tnbc_pathway_{split_type}.pt"

# Number of pathway columns; later passed to the model as its cell input width.
actual_pathway_count = len(pathway_cols)

# One dataset per phase, all sharing the same drug graphs and pathway columns.
pan_dataset_pathway = DrugResponsePathwayDataset(
    pan_cancer_pathway, drug_graphs, drug_col=drug_name_col, pathway_cols=pathway_cols
)
breast_dataset_pathway = DrugResponsePathwayDataset(
    breast_cancer_pathway, drug_graphs, drug_col=drug_name_col, pathway_cols=pathway_cols
)
tnbc_dataset_pathway = DrugResponsePathwayDataset(
    tnbc_pathway, drug_graphs, drug_col=drug_name_col, pathway_cols=pathway_cols
)

# Pre-batched loaders for each phase (cached on disk under output_dir).
pan_loaders = create_pathway_dataloaders(
    pan_dataset_pathway,
    pan_train_idx,
    pan_val_idx,
    pan_test_idx,
    batch_size=256,
    use_prebatched=True,
    phase_name="phase1",
    output_dir=output_dir,
)
breast_loaders = create_pathway_dataloaders(
    breast_dataset_pathway,
    breast_train_idx,
    breast_val_idx,
    breast_test_idx,
    batch_size=256,
    use_prebatched=True,
    phase_name="phase2",
    output_dir=output_dir,
)
tnbc_loaders = create_pathway_dataloaders(
    tnbc_dataset_pathway,
    tnbc_train_idx,
    tnbc_val_idx,
    tnbc_test_idx,
    batch_size=256,
    use_prebatched=True,
    phase_name="phase3",
    output_dir=output_dir,
)


### 5.2 Phase 1: Pan-Cancer Pre-training

Train on pan-cancer data using 5-fold cross-validation. Results include mean ± standard deviation and 95% confidence intervals across folds.


In [None]:
# PHASE 1: Pan-Cancer Training (5-Fold Cross-Validation)
#
# Resume order:
#   1. best-fold checkpoint referenced by the saved CV results file,
#   2. the plain Phase 1 checkpoint,
#   3. train from scratch with 5-fold cross-validation.
# Each step is only tried when the previous one did not yield weights, so a
# stale CV results file can no longer leave the model randomly initialised.

model_phase1 = DrugResponsePathwayGNN(cell_input_dim=actual_pathway_count).to(device)

cv_results_file = output_dir / f"phase_1_pan-cancer_cv_results_{SPLIT_TYPE}.pkl"

# Stays None unless CV results are loaded from disk or produced below.
phase1_cv_results = None
_phase1_weights_loaded = False

if cv_results_file.exists():
    # NOTE(review): pickle.load and torch.load(weights_only=False) can execute
    # arbitrary code; only point these at files produced by this notebook.
    with open(cv_results_file, 'rb') as f:
        phase1_cv_data = pickle.load(f)
    best_checkpoint_path = Path(phase1_cv_data['best_checkpoint'])

    if best_checkpoint_path.exists():
        checkpoint = torch.load(best_checkpoint_path, map_location=device, weights_only=False)
        model_phase1.load_state_dict(checkpoint['model_state_dict'])
        phase1_cv_results = phase1_cv_data
        _phase1_weights_loaded = True
    else:
        print(f"Warning: Phase 1 CV results reference a missing checkpoint: {best_checkpoint_path}")

if not _phase1_weights_loaded and phase1_checkpoint.exists():
    checkpoint = torch.load(phase1_checkpoint, map_location=device, weights_only=False)
    model_phase1.load_state_dict(checkpoint['model_state_dict'])
    _phase1_weights_loaded = True

if not _phase1_weights_loaded:
    # Run 5-fold cross-validation
    phase1_cv_results = train_phase_pathway_cv(
        dataset=pan_dataset_pathway,
        train_idx_all=pan_train_idx,
        val_idx_all=pan_val_idx,
        test_idx=pan_test_idx,
        model_class=DrugResponsePathwayGNN,
        model_kwargs={'cell_input_dim': actual_pathway_count},
        device=device,
        phase_name="Phase 1: Pan-Cancer",
        n_folds=5,
        num_epochs=50,
        lr=1.5e-3,
        weight_decay=1e-4,
        patience=10,
        scheduler_patience=5,
        batch_size=256,
        random_seed=42
    )

    model_phase1 = phase1_cv_results['best_model']

    # 'best_fold' is 1-based, hence the -1 when indexing the per-fold results.
    _best_fold_result = phase1_cv_results['cv_results'][phase1_cv_results['best_fold']-1]

    torch.save({
        'model_state_dict': model_phase1.state_dict(),
        'val_metrics': _best_fold_result['val_metrics'],
        'epoch': len(_best_fold_result['history']['train_loss']),
        'cv_best_fold': phase1_cv_results['best_fold']
    }, phase1_checkpoint)
    _phase1_weights_loaded = True


### 5.3 Phase 2: Breast Cancer Fine-tuning

Fine-tune Phase 1 model on breast cancer data using 5-fold cross-validation. Model weights are transferred from Phase 1.


In [None]:
# PHASE 2: Breast Cancer Training (5-Fold Cross-Validation)
#
# Resume order:
#   1. best-fold checkpoint referenced by the saved CV results file,
#   2. the plain Phase 2 checkpoint,
#   3. fine-tune with 5-fold cross-validation.
# Each step is only tried when the previous one did not yield weights.

model_phase2 = DrugResponsePathwayGNN(cell_input_dim=actual_pathway_count).to(device)
# Transfer weights from Phase 1
model_phase2.load_state_dict(model_phase1.state_dict(), strict=False)

cv_results_file = output_dir / f"phase_2_breast_cancer_cv_results_{SPLIT_TYPE}.pkl"

# Stays None unless CV results are loaded from disk or produced below.
phase2_cv_results = None
_phase2_weights_loaded = False

if cv_results_file.exists():
    # NOTE(review): pickle.load and torch.load(weights_only=False) can execute
    # arbitrary code; only point these at files produced by this notebook.
    with open(cv_results_file, 'rb') as f:
        phase2_cv_data = pickle.load(f)
    best_checkpoint_path = Path(phase2_cv_data['best_checkpoint'])

    if best_checkpoint_path.exists():
        checkpoint = torch.load(best_checkpoint_path, map_location=device, weights_only=False)
        model_phase2.load_state_dict(checkpoint['model_state_dict'])
        phase2_cv_results = phase2_cv_data
        _phase2_weights_loaded = True
    else:
        print(f"Warning: Phase 2 CV results reference a missing checkpoint: {best_checkpoint_path}")

if not _phase2_weights_loaded and phase2_checkpoint.exists():
    checkpoint = torch.load(phase2_checkpoint, map_location=device, weights_only=False)
    model_phase2.load_state_dict(checkpoint['model_state_dict'])
    _phase2_weights_loaded = True

if not _phase2_weights_loaded:
    # Run 5-fold cross-validation
    # NOTE(review): the CV routine is given model_class/model_kwargs rather than
    # the Phase-1-initialised model_phase2 instance, so whether the Phase 1
    # weights actually reach the folds cannot be confirmed from this cell --
    # verify inside train_phase_pathway_cv.
    phase2_cv_results = train_phase_pathway_cv(
        dataset=breast_dataset_pathway,
        train_idx_all=breast_train_idx,
        val_idx_all=breast_val_idx,
        test_idx=breast_test_idx,
        model_class=DrugResponsePathwayGNN,
        model_kwargs={'cell_input_dim': actual_pathway_count},
        device=device,
        phase_name="Phase 2: Breast Cancer",
        n_folds=5,
        num_epochs=50,
        lr=1e-4,
        weight_decay=1e-5,
        patience=15,
        scheduler_patience=5,
        batch_size=256,
        random_seed=42
    )

    model_phase2 = phase2_cv_results['best_model']

    # 'best_fold' is 1-based, hence the -1 when indexing the per-fold results.
    _best_fold_result = phase2_cv_results['cv_results'][phase2_cv_results['best_fold']-1]

    torch.save({
        'model_state_dict': model_phase2.state_dict(),
        'val_metrics': _best_fold_result['val_metrics'],
        'epoch': len(_best_fold_result['history']['train_loss']),
        'cv_best_fold': phase2_cv_results['best_fold']
    }, phase2_checkpoint)
    _phase2_weights_loaded = True


### 5.4 Phase 3: TNBC-Specific Fine-tuning

Final fine-tuning on TNBC data. Uses a single train/val/test split because the sample size is limited (n=24 cell lines).


In [None]:
# PHASE 3: TNBC Final Fine-tuning
#
# Reuse a finished Phase 3 checkpoint when one exists. Otherwise initialise
# from the Phase 2 weights (checkpoint file first, in-memory model second) and
# fine-tune on the single TNBC train/val split.
model_phase3 = DrugResponsePathwayGNN(cell_input_dim=actual_pathway_count).to(device)

if phase3_checkpoint.exists():
    checkpoint = torch.load(phase3_checkpoint, map_location=device, weights_only=False)
    model_phase3.load_state_dict(checkpoint['model_state_dict'])
else:
    if phase2_checkpoint.exists():
        checkpoint = torch.load(phase2_checkpoint, map_location=device, weights_only=False)
        model_phase3.load_state_dict(checkpoint['model_state_dict'])
    elif 'model_phase2' in globals():
        # The Phase 2 cell only writes phase2_checkpoint after a fresh CV run;
        # when it resumed from its CV best-fold file the weights exist only in
        # memory, so take them from there instead of starting from scratch.
        model_phase3.load_state_dict(model_phase2.state_dict())
    else:
        print("Warning: no Phase 2 weights found; Phase 3 will train from random initialization.")

    model_phase3, phase3_history = train_phase_pathway(
        model=model_phase3,
        train_loader=tnbc_loaders['train'],
        val_loader=tnbc_loaders['val'],
        device=device,
        phase_name="Phase 3: TNBC",
        num_epochs=50,
        lr=5e-5,
        weight_decay=1e-5,
        patience=15,
        scheduler_patience=5,
        checkpoint_path=phase3_checkpoint,
        prebatched_datasets=tnbc_loaders.get('datasets', None)
    )

## 6. Model Evaluation

### 6.1 Performance Metrics

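Save the test-set performance metrics to disk (CSV and pickle). This cell only persists results: `results_df`, `phase1_results`, `phase2_results` and `phase3_results` must already exist, and the classification metrics are written only for phases whose `class_results_phaseN` dictionary is defined.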


In [None]:
# Persist regression metrics (always) and classification metrics (for every
# phase whose `class_results_phaseN` dict is present in the notebook namespace).
#
# NOTE(review): `results_df` and `phase{1,2,3}_results` are read here but, in
# the visible part of the notebook, are only assigned by the evaluation cell in
# Section 7 -- presumably that cell has to be executed before this one.
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

split_suffix = f"_{SPLIT_TYPE}_split" if 'SPLIT_TYPE' in globals() else ""
performance_csv_path = output_dir / f"performance_metrics{split_suffix}.csv"
performance_pkl_path = output_dir / f"performance_metrics{split_suffix}.pkl"

results_df.to_csv(performance_csv_path, index=False)

with open(performance_pkl_path, 'wb') as f:
    pickle.dump({
        'results_df': results_df,
        'phase1_results': phase1_results,
        'phase2_results': phase2_results,
        'phase3_results': phase3_results,
        'split_type': SPLIT_TYPE if 'SPLIT_TYPE' in globals() else 'unknown',
        'timestamp': datetime.now().isoformat()
    }, f)


def _summarize_classification(class_results):
    """Build the per-phase classification summary from one results dict.

    Args:
        class_results: Dict with 'ground_truth', 'predictions' and
            'probabilities' entries, as returned by
            evaluate_classification_head.

    Returns:
        Dict with accuracy, precision, recall, f1_score, roc_auc and class
        counts. roc_auc is NaN when the ground truth holds a single class,
        because ROC AUC is undefined in that case.
    """
    y_true = np.asarray(class_results['ground_truth'])
    y_pred = np.asarray(class_results['predictions'])

    # roc_auc_score raises ValueError when only one class is present.
    if np.unique(y_true).size > 1:
        roc_auc_value = roc_auc_score(y_true, class_results['probabilities'])
    else:
        roc_auc_value = float('nan')

    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, zero_division=0),
        'recall': recall_score(y_true, y_pred, zero_division=0),
        'f1_score': f1_score(y_true, y_pred, zero_division=0),
        'roc_auc': roc_auc_value,
        'n_samples': len(y_true),
        # Label 0 = sensitive, label 1 = resistant.
        'n_sensitive': int(np.sum(y_true == 0)),
        'n_resistant': int(np.sum(y_true == 1))
    }


classification_metrics = {}
for _phase_key in ('phase1', 'phase2', 'phase3'):
    _class_results = globals().get(f'class_results_{_phase_key}')
    if _class_results is not None:
        classification_metrics[_phase_key] = _summarize_classification(_class_results)

if classification_metrics:
    classification_csv_path = output_dir / f"classification_metrics{split_suffix}.csv"
    classification_df = pd.DataFrame(classification_metrics).T
    classification_df.index.name = 'phase'
    classification_df.to_csv(classification_csv_path)

    classification_pkl_path = output_dir / f"classification_metrics{split_suffix}.pkl"
    with open(classification_pkl_path, 'wb') as f:
        pickle.dump(classification_metrics, f)

### 6.2 Classification Head Evaluation

Evaluate binary classification performance (sensitive vs resistant) using median IC50 as threshold.


In [None]:
# Evaluate Classification Head Function

from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve, classification_report

def evaluate_classification_head(model_path, test_loader, device, phase_name, median_ic50=None):
    """
    Evaluate classification head predictions.

    Loads the checkpoint at ``model_path`` into a fresh DrugResponsePathwayGNN,
    runs it over ``test_loader`` and derives binary labels by thresholding the
    true IC50 values.

    Args:
        model_path: Path to model checkpoint
        test_loader: Test DataLoader
        device: Device to run on
        phase_name: Name of the phase
        median_ic50: Classification threshold. If None, the checkpoint's
            'classification_threshold' entry is used when present; otherwise
            the median of the test-set IC50 targets is used.

    Returns:
        Dictionary with predictions, probabilities, and ground truth

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    # A pre-batched dataset does not expose pathway_cols, so fall back to the
    # notebook-level pathway_cols in that case.
    if isinstance(test_loader.dataset, PrebatchedDataset):
        pathway_dim = len(pathway_cols)
    else:
        pathway_dim = len(test_loader.dataset.pathway_cols)

    model = DrugResponsePathwayGNN(cell_input_dim=pathway_dim).to(device)
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    all_class_logits = []
    all_ic50_preds = []
    all_ic50_targets = []

    with torch.no_grad():
        for batch in test_loader:
            drug_batch = batch['drug_batch'].to(device)
            cell_batch = batch['cell_batch'].to(device)

            outputs = model(drug_batch, cell_batch, return_embeddings=False)

            all_class_logits.append(outputs['classification'].cpu().numpy())
            all_ic50_preds.append(outputs['ic50'].cpu().numpy())
            # Targets are only needed on the host, so skip the device round-trip.
            all_ic50_targets.append(batch['ic50'].cpu().numpy())

    if not all_class_logits:
        raise ValueError(f"{phase_name}: test loader produced no batches; nothing to evaluate")

    class_logits = np.concatenate(all_class_logits).flatten()
    class_probs = torch.sigmoid(torch.tensor(class_logits)).numpy()
    ic50_preds = np.concatenate(all_ic50_preds).flatten()
    ic50_targets = np.concatenate(all_ic50_targets).flatten()

    # Resolve the threshold: explicit argument > checkpoint value > test median.
    # The median reuses the targets gathered above instead of iterating the
    # loader a second time.
    if median_ic50 is None:
        median_ic50 = checkpoint.get('classification_threshold', None)
    if median_ic50 is None:
        median_ic50 = np.median(ic50_targets)

    # Ground truth labels: 1 if IC50 > median (resistant), 0 if IC50 <= median (sensitive)
    y_true = (ic50_targets > median_ic50).astype(int)

    # Binary predictions using 0.5 threshold
    y_pred = (class_probs >= 0.5).astype(int)

    return {
        'phase_name': phase_name,
        'logits': class_logits,
        'probabilities': class_probs,
        'predictions': y_pred,
        'ground_truth': y_true,
        'ic50_targets': ic50_targets,
        'ic50_preds': ic50_preds,
        'median_ic50': median_ic50
    }

### 6.3 Visualization Helper Functions

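Helper that reloads the saved train/val/test split indices, so the evaluation and visualization cells reuse the saved splits and results stay reproducible.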


In [None]:
def load_saved_splits(split_type='cell_line'):
    """Read the persisted split indices for every cohort.

    Looks for ``data_splits_{split_type}.json`` inside ``splits_dir``.

    Args:
        split_type: Suffix identifying which saved split file to read.

    Returns:
        Nested dict ``{cohort: {'train' | 'val' | 'test': np.ndarray}}`` for
        the cohorts 'pan_cancer', 'breast_cancer' and 'tnbc'.

    Raises:
        FileNotFoundError: If the split file has not been written yet.
    """
    splits_file = splits_dir / f"data_splits_{split_type}.json"
    if not splits_file.exists():
        raise FileNotFoundError(f"Saved splits not found: {splits_file}. Please run data splitting cell first.")

    with open(splits_file, 'r') as f:
        stored = json.load(f)

    cohorts = ('pan_cancer', 'breast_cancer', 'tnbc')
    partitions = ('train', 'val', 'test')

    # JSON stores plain lists; convert each one back to an index array.
    return {
        cohort: {part: np.array(stored[cohort][part]) for part in partitions}
        for cohort in cohorts
    }


## 7. Results Visualization

Generate comprehensive visualizations comparing model performance across phases.


In [None]:
import seaborn as sns

SPLIT_TYPE = 'cell_line'

# Rebuild datasets and loaders from the saved splits when any of the three
# loader dicts is missing (e.g. the training cells were not run this session).
_required_loaders = ('pan_loaders', 'breast_loaders', 'tnbc_loaders')
if any(name not in globals() for name in _required_loaders):
    splits = load_saved_splits(SPLIT_TYPE)
    _parts = ('train', 'val', 'test')
    pan_train_idx, pan_val_idx, pan_test_idx = (splits['pan_cancer'][p] for p in _parts)
    breast_train_idx, breast_val_idx, breast_test_idx = (splits['breast_cancer'][p] for p in _parts)
    tnbc_train_idx, tnbc_val_idx, tnbc_test_idx = (splits['tnbc'][p] for p in _parts)

    pan_dataset_pathway = DrugResponsePathwayDataset(
        pan_cancer_pathway, drug_graphs, drug_col=drug_name_col, pathway_cols=pathway_cols
    )
    breast_dataset_pathway = DrugResponsePathwayDataset(
        breast_cancer_pathway, drug_graphs, drug_col=drug_name_col, pathway_cols=pathway_cols
    )
    tnbc_dataset_pathway = DrugResponsePathwayDataset(
        tnbc_pathway, drug_graphs, drug_col=drug_name_col, pathway_cols=pathway_cols
    )

    pan_loaders = create_pathway_dataloaders(
        pan_dataset_pathway,
        pan_train_idx,
        pan_val_idx,
        pan_test_idx,
        batch_size=256,
        use_prebatched=True,
        phase_name="phase1",
        output_dir=output_dir,
    )
    breast_loaders = create_pathway_dataloaders(
        breast_dataset_pathway,
        breast_train_idx,
        breast_val_idx,
        breast_test_idx,
        batch_size=256,
        use_prebatched=True,
        phase_name="phase2",
        output_dir=output_dir,
    )
    tnbc_loaders = create_pathway_dataloaders(
        tnbc_dataset_pathway,
        tnbc_train_idx,
        tnbc_val_idx,
        tnbc_test_idx,
        batch_size=256,
        use_prebatched=True,
        phase_name="phase3",
        output_dir=output_dir,
    )

# Mirror of SPLIT_TYPE used by the checkpoint helpers below.
split_type = globals().get('SPLIT_TYPE', 'unknown')

def find_checkpoint(phase_num, split_type):
    """Resolve the checkpoint path for one training phase.

    The training cells save checkpoints as ``<stem>_{split_type}.pt``, so that
    file is preferred when it exists. Otherwise the legacy
    ``<stem>_cellsplit.pt`` name is returned, whether or not it exists.

    Args:
        phase_num: Training phase, one of 1, 2 or 3.
        split_type: Split strategy suffix used in the checkpoint file name.

    Returns:
        Path to the checkpoint file inside ``output_dir / "models"``.

    Raises:
        ValueError: If ``phase_num`` is not 1, 2 or 3.
    """
    stems = {
        1: "trial2_phase1_pathway",
        2: "trial2_phase2_breast_pathway",
        3: "trial2_phase3_tnbc_pathway",
    }
    if phase_num not in stems:
        raise ValueError(f"Unknown phase_num {phase_num!r}; expected 1, 2 or 3")

    models_dir = output_dir / "models"
    stem = stems[phase_num]

    split_specific = models_dir / f"{stem}_{split_type}.pt"
    if split_specific.exists():
        return split_specific

    # Legacy naming kept as the fallback so existing artefacts still resolve.
    return models_dir / f"{stem}_cellsplit.pt"

def evaluate_phase_pathway(model_path, test_loader, phase_name, device):
    """
    Evaluate a phase on its test set.

    Loads the checkpoint at ``model_path`` into a fresh DrugResponsePathwayGNN
    and scores it with ``evaluate_model``.

    Args:
        model_path: Path to model checkpoint
        test_loader: Test DataLoader
        phase_name: Name of the phase
        device: Device to run on

    Returns:
        Dictionary with 'phase_name', 'metrics', 'predictions', 'targets',
        'n_samples' (number of evaluated targets) and 'n_cells' (unique
        ModelID count, or None for a pre-batched dataset). For an empty test
        set the metrics are NaN (loss 0.0) and the arrays are empty.
    """
    if len(test_loader.dataset) == 0:
        print(f"Warning: {phase_name} test set is empty - returning NaN metrics")
        return {
            'phase_name': phase_name,
            'metrics': {
                'loss': 0.0,
                'pearson': np.nan,
                'spearman': np.nan,
                'r2': np.nan,
                'rmse': np.nan,
                'mae': np.nan
            },
            'predictions': np.array([]),
            'targets': np.array([]),
            'n_samples': 0,
            'n_cells': 0
        }

    dataset_is_prebatched = isinstance(test_loader.dataset, PrebatchedDataset)

    # A pre-batched dataset does not expose pathway_cols, so fall back to the
    # notebook-level pathway_cols in that case.
    if dataset_is_prebatched:
        pathway_dim = len(pathway_cols)
    else:
        pathway_dim = len(test_loader.dataset.pathway_cols)

    model = DrugResponsePathwayGNN(cell_input_dim=pathway_dim).to(device)
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])

    is_prebatched = dataset_is_prebatched or test_loader.batch_size == 1
    results = evaluate_model(model, test_loader, device, is_prebatched=is_prebatched)

    # Cell-line identities are only reachable through the raw dataset's frame.
    if dataset_is_prebatched:
        n_cells = None
    else:
        n_cells = test_loader.dataset.data['ModelID'].nunique()

    # Count the targets that were actually evaluated. len(test_loader.dataset)
    # presumably counts batches rather than samples for a pre-batched dataset
    # (TODO confirm), which would under-report the test-set size.
    n_samples = len(results['targets'])

    return {
        'phase_name': phase_name,
        'metrics': results['metrics'],
        'predictions': results['predictions'],
        'targets': results['targets'],
        'n_samples': n_samples,
        'n_cells': n_cells
    }

# Resolve the three phase checkpoints, score each on its own test loader, then
# persist the regression summary before evaluating the classification heads.
phase1_checkpoint, phase2_checkpoint, phase3_checkpoint = (
    find_checkpoint(phase_number, split_type) for phase_number in (1, 2, 3)
)

phase1_results = evaluate_phase_pathway(
    phase1_checkpoint, pan_loaders['test'], "Phase 1: Pan-Cancer", device
)
phase2_results = evaluate_phase_pathway(
    phase2_checkpoint, breast_loaders['test'], "Phase 2: Breast Cancer", device
)
phase3_results = evaluate_phase_pathway(
    phase3_checkpoint, tnbc_loaders['test'], "Phase 3: TNBC", device
)

# One row per phase; one column per regression metric.
_per_phase_metrics = [res['metrics'] for res in (phase1_results, phase2_results, phase3_results)]
_table = {'Phase': ['Phase 1: Pan-Cancer', 'Phase 2: Breast Cancer', 'Phase 3: TNBC']}
for _column, _metric_key in (
    ('Pearson', 'pearson'),
    ('Spearman', 'spearman'),
    ('R²', 'r2'),
    ('RMSE', 'rmse'),
    ('MAE', 'mae'),
):
    _table[_column] = [m[_metric_key] for m in _per_phase_metrics]
results_df = pd.DataFrame(_table)

split_suffix = f"_{SPLIT_TYPE}_split"
performance_csv_path = output_dir / f"performance_metrics{split_suffix}.csv"
performance_pkl_path = output_dir / f"performance_metrics{split_suffix}.pkl"

results_df.to_csv(performance_csv_path, index=False)

_performance_payload = {
    'results_df': results_df,
    'phase1_results': phase1_results,
    'phase2_results': phase2_results,
    'phase3_results': phase3_results,
    'split_type': SPLIT_TYPE
}
with open(performance_pkl_path, 'wb') as f:
    pickle.dump(_performance_payload, f)

class_results_phase1 = evaluate_classification_head(
    phase1_checkpoint, pan_loaders['test'], device, "Phase 1: Pan-Cancer"
)
class_results_phase2 = evaluate_classification_head(
    phase2_checkpoint, breast_loaders['test'], device, "Phase 2: Breast Cancer"
)
class_results_phase3 = evaluate_classification_head(
    phase3_checkpoint, tnbc_loaders['test'], device, "Phase 3: TNBC"
)

# Classification-head diagnostics: a 3x3 panel figure. Panels 1-6 and 9 show
# the Phase 3 (TNBC) model; panels 7-8 compare all three phases.
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

fig = plt.figure(figsize=(7, 5.25))

results = class_results_phase3
y_true = results['ground_truth']
y_pred = results['predictions']
y_probs = results['probabilities']

# Pin both classes explicitly: without `labels`, a test split containing a
# single class yields a 1x1 confusion matrix and makes classification_report
# reject the two target_names.
class_labels = [0, 1]

# 1. Confusion Matrix
ax1 = plt.subplot(3, 3, 1)
cm = confusion_matrix(y_true, y_pred, labels=class_labels)
sns.heatmap(cm, annot=True, fmt='d', cmap='gray', ax=ax1, cbar_kws={'label': 'Count'})
ax1.set_xlabel('Predicted (0=Sensitive, 1=Resistant)')
ax1.set_ylabel('True (0=Sensitive, 1=Resistant)')
ax1.set_title(f'{results["phase_name"]}\nConfusion Matrix')

# 2. ROC Curve
ax2 = plt.subplot(3, 3, 2)
fpr, tpr, _ = roc_curve(y_true, y_probs)
roc_auc = auc(fpr, tpr)
ax2.plot(fpr, tpr, color='black', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
ax2.plot([0, 1], [0, 1], color='black', lw=2, linestyle='--', label='Random')
ax2.set_xlabel('False Positive Rate')
ax2.set_ylabel('True Positive Rate')
ax2.set_title('ROC Curve')
ax2.legend(loc="lower right")
ax2.grid(True, alpha=0.3)

# 3. Precision-Recall Curve
ax3 = plt.subplot(3, 3, 3)
precision, recall, _ = precision_recall_curve(y_true, y_probs)
pr_auc = auc(recall, precision)
ax3.plot(recall, precision, color='black', lw=2, label=f'PR curve (AUC = {pr_auc:.3f})')
# A no-skill classifier's precision equals the positive-class prevalence.
baseline = np.sum(y_true) / len(y_true)
ax3.axhline(y=baseline, color='black', linestyle='--', label=f'Baseline = {baseline:.3f}')
ax3.set_xlabel('Recall')
ax3.set_ylabel('Precision')
ax3.set_title('Precision-Recall Curve')
ax3.legend(loc="lower left")
ax3.grid(True, alpha=0.3)

# 4. Probability Distribution by Class
ax4 = plt.subplot(3, 3, 4)
sensitive_probs = y_probs[y_true == 0]
resistant_probs = y_probs[y_true == 1]
ax4.hist(sensitive_probs, bins=30, alpha=0.7, label='Sensitive (IC50 ≤ median)', color='green', density=True)
ax4.hist(resistant_probs, bins=30, alpha=0.7, label='Resistant (IC50 > median)', color='red', density=True)
ax4.axvline(x=0.5, color='black', linestyle='--', linewidth=2, label='Decision Threshold (0.5)')
ax4.set_xlabel('Predicted Probability')
ax4.set_ylabel('Density')
ax4.set_title('Probability Distribution by True Class')
ax4.legend()
ax4.grid(True, alpha=0.3)

# 5. Classification Report
ax5 = plt.subplot(3, 3, 5)
ax5.axis('off')
report_text = classification_report(
    y_true, y_pred, labels=class_labels, target_names=['Sensitive', 'Resistant']
)
ax5.text(0.1, 0.5, report_text, fontsize=10, family='monospace', verticalalignment='center')
ax5.set_title('Classification Report')

# 6. IC50 vs Classification Probability
ax6 = plt.subplot(3, 3, 6)
scatter = ax6.scatter(results['ic50_targets'], y_probs, c=y_true, cmap='gray', alpha=0.7, edgecolors='black', linewidth=0.5)
ax6.axhline(y=0.5, color='black', linestyle='--', linewidth=2, label='Decision Threshold')
ax6.axvline(x=results['median_ic50'], color='blue', linestyle='--', linewidth=2, label=f'Median IC50 = {results["median_ic50"]:.2f}')
ax6.set_xlabel('True IC50 (LN_IC50)')
ax6.set_ylabel('Predicted Classification Probability')
ax6.set_title('IC50 vs Classification Probability')
ax6.legend()
ax6.grid(True, alpha=0.3)
cbar = plt.colorbar(scatter, ax=ax6)
cbar.set_label('True Label (0=Sensitive, 1=Resistant)')

# 7. ROC Comparison
ax7 = plt.subplot(3, 3, 7)
fpr_p1, tpr_p1, _ = roc_curve(class_results_phase1['ground_truth'], class_results_phase1['probabilities'])
roc_auc_p1 = auc(fpr_p1, tpr_p1)
fpr_p2, tpr_p2, _ = roc_curve(class_results_phase2['ground_truth'], class_results_phase2['probabilities'])
roc_auc_p2 = auc(fpr_p2, tpr_p2)
fpr_p3, tpr_p3, _ = roc_curve(y_true, y_probs)
roc_auc_p3 = auc(fpr_p3, tpr_p3)

ax7.plot(fpr_p1, tpr_p1, label=f'Phase 1 (AUC={roc_auc_p1:.3f})', lw=2)
ax7.plot(fpr_p2, tpr_p2, label=f'Phase 2 (AUC={roc_auc_p2:.3f})', lw=2)
ax7.plot(fpr_p3, tpr_p3, label=f'Phase 3 (AUC={roc_auc_p3:.3f})', lw=2)
ax7.plot([0, 1], [0, 1], 'k--', label='Random')
ax7.set_xlabel('False Positive Rate')
ax7.set_ylabel('True Positive Rate')
ax7.set_title('ROC Curves: All Phases')
ax7.legend()
ax7.grid(True, alpha=0.3)

# 8. Metrics Comparison
ax8 = plt.subplot(3, 3, 8)
phases = ['Phase 1', 'Phase 2', 'Phase 3']
all_results = [class_results_phase1, class_results_phase2, class_results_phase3]
accuracies = [accuracy_score(res['ground_truth'], res['predictions']) for res in all_results]
precisions = [precision_score(res['ground_truth'], res['predictions'], zero_division=0) for res in all_results]
recalls = [recall_score(res['ground_truth'], res['predictions'], zero_division=0) for res in all_results]
f1_scores = [f1_score(res['ground_truth'], res['predictions'], zero_division=0) for res in all_results]

x = np.arange(len(phases))
width = 0.2
ax8.bar(x - 1.5*width, accuracies, width, label='Accuracy', alpha=0.8)
ax8.bar(x - 0.5*width, precisions, width, label='Precision', alpha=0.8)
ax8.bar(x + 0.5*width, recalls, width, label='Recall', alpha=0.8)
ax8.bar(x + 1.5*width, f1_scores, width, label='F1-Score', alpha=0.8)
ax8.set_xlabel('Phase')
ax8.set_ylabel('Score')
ax8.set_title('Classification Metrics Comparison')
ax8.set_xticks(x)
ax8.set_xticklabels(phases)
ax8.legend()
ax8.set_ylim([0, 1.1])
ax8.grid(True, alpha=0.3, axis='y')

# 9. Probability Calibration
ax9 = plt.subplot(3, 3, 9)
n_bins = 10
bin_boundaries = np.linspace(0, 1, n_bins + 1)
bin_lowers = bin_boundaries[:-1]
bin_uppers = bin_boundaries[1:]

fraction_of_positives = []
mean_predicted_value = []

for bin_index, (bin_lower, bin_upper) in enumerate(zip(bin_lowers, bin_uppers)):
    mask = (y_probs > bin_lower) & (y_probs <= bin_upper)
    if bin_index == 0:
        # Bins are (lower, upper]; close the first one on the left so that a
        # probability of exactly 0.0 is not dropped from the plot.
        mask |= (y_probs == bin_lower)
    if mask.sum() > 0:
        fraction_of_positives.append(y_true[mask].mean())
        mean_predicted_value.append(y_probs[mask].mean())

ax9.plot([0, 1], [0, 1], 'k--', label='Perfect Calibration')
ax9.plot(mean_predicted_value, fraction_of_positives, 's-', label='Model Calibration', markersize=8)
ax9.set_xlabel('Mean Predicted Probability')
ax9.set_ylabel('Fraction of Positives')
ax9.set_title('Calibration Plot')
ax9.legend()
ax9.grid(True, alpha=0.3)

plt.tight_layout()
classification_plot_path = output_dir / 'classification_head_analysis.png'
plt.savefig(classification_plot_path, dpi=300, bbox_inches='tight')
plt.show()

from sklearn.metrics import roc_auc_score
# Walk the three phases and unpack each result dict into y_t / y_p / y_probs.
# Nothing is computed or printed here; after the loop the names refer to the
# Phase 3 arrays.
_phase_titles = ['Phase 1: Pan-Cancer', 'Phase 2: Breast Cancer', 'Phase 3: TNBC']
_phase_class_results = [class_results_phase1, class_results_phase2, class_results_phase3]
for phase_name, res in zip(_phase_titles, _phase_class_results):
    y_t, y_p, y_probs = res['ground_truth'], res['predictions'], res['probabilities']



### 7.1 Pan-Cancer vs Fine-tuned Model Comparison

Compare Phase 1 (pan-cancer) and Phase 3 (TNBC fine-tuned) models on TNBC test data.


In [None]:


# Reload the saved split indices, score the Phase 1 and Phase 3 checkpoints on
# the TNBC test loader, and tabulate both next to the fine-tuning gain.
splits = load_saved_splits(SPLIT_TYPE)
_parts = ('train', 'val', 'test')
pan_train_idx, pan_val_idx, pan_test_idx = (splits['pan_cancer'][p] for p in _parts)
breast_train_idx, breast_val_idx, breast_test_idx = (splits['breast_cancer'][p] for p in _parts)
tnbc_train_idx, tnbc_val_idx, tnbc_test_idx = (splits['tnbc'][p] for p in _parts)

models_dir = output_dir / "models"
split_type = globals().get('SPLIT_TYPE', 'cell_line')

# Use cellsplit models (cell-line split)
phase1_checkpoint = models_dir / "trial2_phase1_pathway_cellsplit.pt"
phase3_checkpoint = models_dir / "trial2_phase3_tnbc_pathway_cellsplit.pt"

phase1_on_tnbc = evaluate_phase_pathway(
    phase1_checkpoint, tnbc_loaders['test'], "Phase 1 (Pan-Cancer) on TNBC", device
)
phase3_on_tnbc = evaluate_phase_pathway(
    phase3_checkpoint, tnbc_loaders['test'], "Phase 3 (TNBC Fine-tuned) on TNBC", device
)

# Each metric column holds [Phase 1 value, Phase 3 value, improvement]. The
# improvement is signed so that positive always means "fine-tuning helped":
# Phase 3 - Phase 1 for scores, Phase 1 - Phase 3 for errors.
_pan_metrics = phase1_on_tnbc['metrics']
_tnbc_metrics = phase3_on_tnbc['metrics']
_comparison_columns = {'Model': ['Phase 1: Pan-Cancer', 'Phase 3: TNBC Fine-tuned', 'Improvement']}
for _column, _metric_key, _lower_is_better in (
    ('Pearson', 'pearson', False),
    ('Spearman', 'spearman', False),
    ('R²', 'r2', False),
    ('RMSE', 'rmse', True),
    ('MAE', 'mae', True),
):
    _before = _pan_metrics[_metric_key]
    _after = _tnbc_metrics[_metric_key]
    _gain = (_before - _after) if _lower_is_better else (_after - _before)
    _comparison_columns[_column] = [_before, _after, _gain]

comparison_df = pd.DataFrame(_comparison_columns)

fig = plt.figure(figsize=(18, 12))

# 1. Metrics Comparison Bar Chart
ax1 = plt.subplot(2, 3, 1)
metrics = ['Pearson', 'Spearman', 'R²']
metric_keys = {'Pearson': 'pearson', 'Spearman': 'spearman', 'R²': 'r2'}
phase1_vals = [phase1_on_tnbc['metrics'][metric_keys[m]] for m in metrics]
phase3_vals = [phase3_on_tnbc['metrics'][metric_keys[m]] for m in metrics]
x = np.arange(len(metrics))
width = 0.35
ax1.bar(x - width/2, phase1_vals, width, label='Phase 1: Pan-Cancer', alpha=0.8, color='black')
ax1.bar(x + width/2, phase3_vals, width, label='Phase 3: TNBC Fine-tuned', alpha=0.8, color='black')
ax1.set_xlabel('Metric')
ax1.set_ylabel('Score')
ax1.set_title('Performance Comparison: Correlation Metrics')
ax1.set_xticks(x)
ax1.set_xticklabels(metrics)
ax1.legend()
ax1.grid(True, alpha=0.3, axis='y')
ax1.set_ylim([0, 1.0])

# 2. Error Metrics Comparison
ax2 = plt.subplot(2, 3, 2)
error_metrics = ['RMSE', 'MAE']
phase1_errors = [phase1_on_tnbc['metrics'][m.lower()] for m in error_metrics]
phase3_errors = [phase3_on_tnbc['metrics'][m.lower()] for m in error_metrics]
x2 = np.arange(len(error_metrics))
ax2.bar(x2 - width/2, phase1_errors, width, label='Phase 1: Pan-Cancer', alpha=0.8, color='black')
ax2.bar(x2 + width/2, phase3_errors, width, label='Phase 3: TNBC Fine-tuned', alpha=0.8, color='black')
ax2.set_xlabel('Metric')
ax2.set_ylabel('Error')
ax2.set_title('Performance Comparison: Error Metrics')
ax2.set_xticks(x2)
ax2.set_xticklabels(error_metrics)
ax2.legend()
ax2.grid(True, alpha=0.3, axis='y')

# 3. Scatter Plot: Phase 1 Predictions vs True
ax3 = plt.subplot(2, 3, 3)
preds_p1 = phase1_on_tnbc['predictions']
targets = phase1_on_tnbc['targets']
ax3.scatter(targets, preds_p1, alpha=0.7, s=20, color='black', edgecolors='black', linewidth=0.3)
min_val = min(targets.min(), preds_p1.min())
max_val = max(targets.max(), preds_p1.max())
ax3.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label='Perfect Prediction')
ax3.set_xlabel('True IC50 (LN_IC50)')
ax3.set_ylabel('Predicted IC50 (LN_IC50)')
ax3.set_title(f'Phase 1: Pan-Cancer\nPearson={phase1_on_tnbc["metrics"]["pearson"]:.4f}, R²={phase1_on_tnbc["metrics"]["r2"]:.4f}')
ax3.legend()
ax3.grid(True, alpha=0.3)

# 4. Scatter Plot: Phase 3 Predictions vs True
ax4 = plt.subplot(2, 3, 4)
preds_p3 = phase3_on_tnbc['predictions']
ax4.scatter(
    targets, preds_p3,
    alpha=0.7, s=20, color='black', edgecolors='black', linewidth=0.3,
)
# Identity line drawn across the joint range of targets and predictions.
min_val = min(targets.min(), preds_p3.min())
max_val = max(targets.max(), preds_p3.max())
ax4.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label='Perfect Prediction')
ax4.set(
    xlabel='True IC50 (LN_IC50)',
    ylabel='Predicted IC50 (LN_IC50)',
    title=f'Phase 3: TNBC Fine-tuned\nPearson={phase3_on_tnbc["metrics"]["pearson"]:.4f}, R²={phase3_on_tnbc["metrics"]["r2"]:.4f}',
)
ax4.legend()
ax4.grid(True, alpha=0.3)

# 5. Residual Comparison
ax5 = plt.subplot(2, 3, 5)
residuals_p1 = targets - preds_p1
residuals_p3 = targets - preds_p3
# The two distributions overlap heavily, so they need different renderings:
# Phase 1 as a filled light-gray histogram, Phase 3 as a black outline.
# Two filled black histograms on top of each other cannot be distinguished.
ax5.hist(residuals_p1, bins=50, alpha=0.7, label='Phase 1: Pan-Cancer',
         color='lightgray', edgecolor='gray', density=True)
ax5.hist(residuals_p3, bins=50, alpha=0.7, label='Phase 3: TNBC Fine-tuned',
         color='black', density=True, histtype='step', linewidth=1.5)
ax5.axvline(x=0, color='red', linestyle='--', linewidth=2, label='Zero Error')
ax5.set_xlabel('Residual (True - Predicted)')
ax5.set_ylabel('Density')
ax5.set_title('Residual Distribution Comparison')
ax5.legend()
ax5.grid(True, alpha=0.3)

# 6. Improvement Summary
# Text-only panel: the axes are hidden and a formatted summary is drawn instead.
ax6 = plt.subplot(2, 3, 6)
ax6.axis('off')
# Correlation/R² rows are (Phase 3 - Phase 1); the RMSE/MAE rows are
# (Phase 1 - Phase 3), so a positive number is an improvement on every row.
improvement_text = f"""
Performance Improvement from Fine-tuning:

Pearson Correlation: {phase3_on_tnbc['metrics']['pearson'] - phase1_on_tnbc['metrics']['pearson']:+.4f}
Spearman Correlation: {phase3_on_tnbc['metrics']['spearman'] - phase1_on_tnbc['metrics']['spearman']:+.4f}
R² Score: {phase3_on_tnbc['metrics']['r2'] - phase1_on_tnbc['metrics']['r2']:+.4f}
RMSE Reduction: {phase1_on_tnbc['metrics']['rmse'] - phase3_on_tnbc['metrics']['rmse']:+.4f}
MAE Reduction: {phase1_on_tnbc['metrics']['mae'] - phase3_on_tnbc['metrics']['mae']:+.4f}

Test Samples: {len(targets)}
"""
ax6.text(0.1, 0.5, improvement_text, fontsize=12, family='monospace', verticalalignment='center',
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.7))
ax6.set_title('Fine-tuning Improvement Summary', fontsize=14, fontweight='bold')

# Save the current figure, then display it.
# NOTE(review): the PNG name has no split_type suffix while the CSV below does,
# so runs with different split types overwrite the same image — confirm intended.
plt.tight_layout()
comparison_plot_path = output_dir / 'pan_cancer_vs_tnbc_comparison.png'
plt.savefig(comparison_plot_path, dpi=300, bbox_inches='tight')
plt.show()

# Persist the comparison table (built before this section) next to the plot.
comparison_csv_path = output_dir / f'pan_cancer_vs_tnbc_comparison_{split_type}.csv'
comparison_df.to_csv(comparison_csv_path, index=False)



### 7.2 Classification Performance Comparison

Compare binary classification performance between pan-cancer and fine-tuned models.


In [None]:
# Compare Classification Performance: Pan-Cancer Model vs Fine-tuned TNBC Model on TNBC Test Data

# Reload the persisted index splits so this cell uses the same partitions as training.
splits = load_saved_splits(SPLIT_TYPE)
pan_train_idx = splits['pan_cancer']['train']
pan_val_idx = splits['pan_cancer']['val']
pan_test_idx = splits['pan_cancer']['test']
breast_train_idx = splits['breast_cancer']['train']
breast_val_idx = splits['breast_cancer']['val']
breast_test_idx = splits['breast_cancer']['test']
tnbc_train_idx = splits['tnbc']['train']
tnbc_val_idx = splits['tnbc']['val']
tnbc_test_idx = splits['tnbc']['test']

models_dir = output_dir / "models"
# SPLIT_TYPE has already been dereferenced above, so a globals() guard here
# could never take its fallback branch.
split_type = SPLIT_TYPE

# Use cellsplit models (cell-line split)
phase1_checkpoint = models_dir / "trial2_phase1_pathway_cellsplit.pt"
phase3_checkpoint = models_dir / "trial2_phase3_tnbc_pathway_cellsplit.pt"

# Fail early with an actionable message instead of a bare NameError further down.
if 'tnbc_loaders' not in globals():
    raise NameError(
        "tnbc_loaders is not defined; run the cell that builds the TNBC "
        "DataLoaders before this comparison."
    )

# Compute median_ic50 from TNBC test data to ensure fair comparison:
# both checkpoints are scored against the same sensitive/resistant threshold.
all_tnbc_ic50s = [batch['ic50'].cpu().numpy() for batch in tnbc_loaders['test']]
if not all_tnbc_ic50s:
    raise ValueError("TNBC test loader yielded no batches; cannot derive the median IC50 threshold.")
tnbc_test_median_ic50 = np.median(np.concatenate(all_tnbc_ic50s))

# Evaluate Phase 1 (Pan-Cancer) classification head on TNBC test data
class_phase1_on_tnbc = evaluate_classification_head(phase1_checkpoint, tnbc_loaders['test'], device, "Phase 1 (Pan-Cancer) on TNBC", median_ic50=tnbc_test_median_ic50)

# Evaluate Phase 3 (TNBC Fine-tuned) classification head on TNBC test data
class_phase3_on_tnbc = evaluate_classification_head(phase3_checkpoint, tnbc_loaders['test'], device, "Phase 3 (TNBC Fine-tuned) on TNBC", median_ic50=tnbc_test_median_ic50)

# Calculate classification metrics
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# Both evaluations ran on the same loader, so Phase 1's ground truth is used for both.
y_true = class_phase1_on_tnbc['ground_truth']
y_pred_p1 = class_phase1_on_tnbc['predictions']
y_probs_p1 = class_phase1_on_tnbc['probabilities']
y_pred_p3 = class_phase3_on_tnbc['predictions']
y_probs_p3 = class_phase3_on_tnbc['probabilities']


def _binary_metrics(labels, hard_preds, scores):
    """Return accuracy/precision/recall/F1 on hard predictions plus ROC-AUC on scores."""
    return {
        'accuracy': accuracy_score(labels, hard_preds),
        'precision': precision_score(labels, hard_preds, zero_division=0),
        'recall': recall_score(labels, hard_preds, zero_division=0),
        'f1': f1_score(labels, hard_preds, zero_division=0),
        'roc_auc': roc_auc_score(labels, scores),
    }


metrics_p1 = _binary_metrics(y_true, y_pred_p1, y_probs_p1)
metrics_p3 = _binary_metrics(y_true, y_pred_p3, y_probs_p3)

# One column per metric; rows are Phase 1, Phase 3, and their difference.
_column_to_key = (
    ('Accuracy', 'accuracy'),
    ('Precision', 'precision'),
    ('Recall', 'recall'),
    ('F1-Score', 'f1'),
    ('ROC-AUC', 'roc_auc'),
)
comparison_class_df = pd.DataFrame({
    'Model': ['Phase 1: Pan-Cancer', 'Phase 3: TNBC Fine-tuned', 'Improvement'],
    **{
        column: [metrics_p1[key], metrics_p3[key], metrics_p3[key] - metrics_p1[key]]
        for column, key in _column_to_key
    },
})


fig = plt.figure(figsize=(18, 12))

# 1. Classification Metrics Comparison Bar Chart
ax1 = plt.subplot(2, 3, 1)
class_metrics = ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'ROC-AUC']
# Display label -> key in metrics_p1 / metrics_p3.
metric_keys = {'Accuracy': 'accuracy', 'Precision': 'precision', 'Recall': 'recall', 'F1-Score': 'f1', 'ROC-AUC': 'roc_auc'}
phase1_class_vals = [metrics_p1[metric_keys[m]] for m in class_metrics]
phase3_class_vals = [metrics_p3[metric_keys[m]] for m in class_metrics]
x = np.arange(len(class_metrics))
width = 0.35
# Different fills per phase so the grouped bars match their legend entries
# (two black series are indistinguishable).
ax1.bar(x - width/2, phase1_class_vals, width, label='Phase 1: Pan-Cancer', alpha=0.8,
        color='lightgray', edgecolor='black')
ax1.bar(x + width/2, phase3_class_vals, width, label='Phase 3: TNBC Fine-tuned', alpha=0.8,
        color='black', edgecolor='black')
ax1.set_xlabel('Metric')
ax1.set_ylabel('Score')
ax1.set_title('Classification Metrics Comparison')
ax1.set_xticks(x)
ax1.set_xticklabels(class_metrics, rotation=45, ha='right')
ax1.legend()
ax1.grid(True, alpha=0.3, axis='y')
# Slight headroom above 1.0 so full-height bars do not touch the frame.
ax1.set_ylim([0, 1.1])

# 2. ROC Curve Comparison
# Imported here so the cell does not depend on another cell having brought
# these names into scope (the metrics import above only covers the scalar scores).
from sklearn.metrics import roc_curve, auc

ax2 = plt.subplot(2, 3, 2)
fpr_p1, tpr_p1, _ = roc_curve(y_true, y_probs_p1)
roc_auc_p1 = auc(fpr_p1, tpr_p1)
fpr_p3, tpr_p3, _ = roc_curve(y_true, y_probs_p3)
roc_auc_p3 = auc(fpr_p3, tpr_p3)

# Gray vs. black keeps the two model curves separable from each other and
# from the thin dashed chance diagonal.
ax2.plot(fpr_p1, tpr_p1, label=f'Phase 1 (AUC={roc_auc_p1:.3f})', lw=2, color='gray')
ax2.plot(fpr_p3, tpr_p3, label=f'Phase 3 (AUC={roc_auc_p3:.3f})', lw=2, color='black')
ax2.plot([0, 1], [0, 1], 'k--', label='Random', lw=1)
ax2.set_xlabel('False Positive Rate')
ax2.set_ylabel('True Positive Rate')
ax2.set_title('ROC Curves Comparison')
ax2.legend()
ax2.grid(True, alpha=0.3)

# 3. Precision-Recall Curve Comparison
# Imported here so the cell does not depend on another cell having brought
# these names into scope.
from sklearn.metrics import precision_recall_curve, auc

ax3 = plt.subplot(2, 3, 3)
precision_p1, recall_p1, _ = precision_recall_curve(y_true, y_probs_p1)
pr_auc_p1 = auc(recall_p1, precision_p1)
precision_p3, recall_p3, _ = precision_recall_curve(y_true, y_probs_p3)
pr_auc_p3 = auc(recall_p3, precision_p3)

# Positive-class prevalence: the precision of a classifier that flags everything.
baseline = np.sum(y_true) / len(y_true)
# Gray vs. black keeps the two model curves separable from each other and
# from the thin dotted baseline.
ax3.plot(recall_p1, precision_p1, label=f'Phase 1 (AUC={pr_auc_p1:.3f})', lw=2, color='gray')
ax3.plot(recall_p3, precision_p3, label=f'Phase 3 (AUC={pr_auc_p3:.3f})', lw=2, color='black')
ax3.axhline(y=baseline, color='black', linestyle=':', linewidth=1, label=f'Baseline={baseline:.3f}')
ax3.set_xlabel('Recall')
ax3.set_ylabel('Precision')
ax3.set_title('Precision-Recall Curves Comparison')
ax3.legend()
ax3.grid(True, alpha=0.3)

# 4. Confusion Matrix: Phase 1
# Imported here so the cell does not depend on another cell having brought
# confusion_matrix into scope.
from sklearn.metrics import confusion_matrix

ax4 = plt.subplot(2, 3, 4)
# labels=[0, 1] pins the matrix to 2x2 in a fixed order even if one class is
# absent from both the labels and the predictions, so the axis captions hold.
cm_p1 = confusion_matrix(y_true, y_pred_p1, labels=[0, 1])
sns.heatmap(cm_p1, annot=True, fmt='d', cmap='gray', ax=ax4, cbar_kws={'label': 'Count'})
ax4.set_xlabel('Predicted (0=Sensitive, 1=Resistant)')
ax4.set_ylabel('True (0=Sensitive, 1=Resistant)')
ax4.set_title(f'Phase 1: Pan-Cancer\nAccuracy={metrics_p1["accuracy"]:.4f}')

# 5. Confusion Matrix: Phase 3
ax5 = plt.subplot(2, 3, 5)
cm_p3 = confusion_matrix(y_true, y_pred_p3, labels=[0, 1])
sns.heatmap(cm_p3, annot=True, fmt='d', cmap='Greens', ax=ax5, cbar_kws={'label': 'Count'})
ax5.set_xlabel('Predicted (0=Sensitive, 1=Resistant)')
ax5.set_ylabel('True (0=Sensitive, 1=Resistant)')
ax5.set_title(f'Phase 3: TNBC Fine-tuned\nAccuracy={metrics_p3["accuracy"]:.4f}')

# 6. Probability Distribution Comparison
ax6 = plt.subplot(2, 3, 6)
# Split each model's predicted probabilities by the true class.
sensitive_probs_p1 = y_probs_p1[y_true == 0]
resistant_probs_p1 = y_probs_p1[y_true == 1]
sensitive_probs_p3 = y_probs_p3[y_true == 0]
resistant_probs_p3 = y_probs_p3[y_true == 1]

# Phase 1 is drawn with the default filled bars, Phase 3 as step outlines.
_outline = {'histtype': 'step', 'linewidth': 2}
for _probs, _label, _color, _extra in (
    (sensitive_probs_p1, 'Phase 1 Sensitive', 'lightblue', {}),
    (resistant_probs_p1, 'Phase 1 Resistant', 'lightcoral', {}),
    (sensitive_probs_p3, 'Phase 3 Sensitive', 'lightgreen', _outline),
    (resistant_probs_p3, 'Phase 3 Resistant', 'black', _outline),
):
    ax6.hist(_probs, bins=30, alpha=0.4, label=_label, color=_color, density=True, **_extra)
ax6.axvline(x=0.5, color='black', linestyle='--', linewidth=2, label='Decision Threshold')
ax6.set(
    xlabel='Predicted Probability',
    ylabel='Density',
    title='Probability Distribution Comparison',
)
ax6.legend(fontsize=8)
ax6.grid(True, alpha=0.3)

# Lay out, save, and display the six-panel classification figure.
plt.tight_layout()
comparison_class_plot_path = output_dir / 'classification_pan_cancer_vs_tnbc_comparison.png'
plt.savefig(comparison_class_plot_path, dpi=300, bbox_inches='tight')
plt.show()

fig2 = plt.figure(figsize=(7, 4.67))

# Display label -> key in metrics_p1 / metrics_p3, in plotting order.
_label_to_key = {
    'Accuracy': 'accuracy',
    'Precision': 'precision',
    'Recall': 'recall',
    'F1-Score': 'f1',
    'ROC-AUC': 'roc_auc',
}

# Improvement metrics (absolute: Phase 3 minus Phase 1)
ax_imp = plt.subplot(2, 2, 1)
improvements = {
    label: metrics_p3[key] - metrics_p1[key]
    for label, key in _label_to_key.items()
}
colors = ['green' if delta > 0 else 'red' for delta in improvements.values()]
ax_imp.barh(list(improvements), list(improvements.values()), color=colors, alpha=0.7)
ax_imp.axvline(x=0, color='black', linestyle='-', linewidth=1)
ax_imp.set(
    xlabel='Improvement',
    title='Classification Performance Improvement\nfrom Fine-tuning',
)
ax_imp.grid(True, alpha=0.3, axis='x')

# Percentage improvement (relative to Phase 1; 0 when the Phase 1 score is not positive)
ax_pct = plt.subplot(2, 2, 2)
pct_improvements = {
    label: (metrics_p3[key] - metrics_p1[key]) / metrics_p1[key] * 100 if metrics_p1[key] > 0 else 0
    for label, key in _label_to_key.items()
}
colors_pct = ['green' if delta > 0 else 'red' for delta in pct_improvements.values()]
ax_pct.barh(list(pct_improvements), list(pct_improvements.values()), color=colors_pct, alpha=0.7)
ax_pct.axvline(x=0, color='black', linestyle='-', linewidth=1)
ax_pct.set(
    xlabel='Percentage Improvement (%)',
    title='Percentage Improvement from Fine-tuning',
)
ax_pct.grid(True, alpha=0.3, axis='x')

# Summary text
# Bottom row of fig2: one axes spanning grid cells 3-4, used purely as a text box.
ax_summary = plt.subplot(2, 2, (3, 4))
ax_summary.axis('off')
# Absolute deltas with their percentage change, the class balance of the test
# labels, then the raw per-phase scores.
summary_text = f"""
Classification Performance Improvement Summary:

Absolute Improvements:
  Accuracy:  {improvements['Accuracy']:+.4f} ({pct_improvements['Accuracy']:+.2f}%)
  Precision: {improvements['Precision']:+.4f} ({pct_improvements['Precision']:+.2f}%)
  Recall:    {improvements['Recall']:+.4f} ({pct_improvements['Recall']:+.2f}%)
  F1-Score:  {improvements['F1-Score']:+.4f} ({pct_improvements['F1-Score']:+.2f}%)
  ROC-AUC:   {improvements['ROC-AUC']:+.4f} ({pct_improvements['ROC-AUC']:+.2f}%)

Test Samples: {len(y_true)}
  Sensitive (IC50 ≤ median): {np.sum(y_true == 0)}
  Resistant (IC50 > median): {np.sum(y_true == 1)}

Phase 1 Performance:
  Accuracy:  {metrics_p1['accuracy']:.4f}
  Precision: {metrics_p1['precision']:.4f}
  Recall:    {metrics_p1['recall']:.4f}
  F1-Score:  {metrics_p1['f1']:.4f}
  ROC-AUC:   {metrics_p1['roc_auc']:.4f}

Phase 3 Performance:
  Accuracy:  {metrics_p3['accuracy']:.4f}
  Precision: {metrics_p3['precision']:.4f}
  Recall:    {metrics_p3['recall']:.4f}
  F1-Score:  {metrics_p3['f1']:.4f}
  ROC-AUC:   {metrics_p3['roc_auc']:.4f}
"""
ax_summary.text(0.05, 0.5, summary_text, fontsize=11, family='monospace', verticalalignment='center',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.7))

# Save and display the improvement figure.
plt.tight_layout()
improvement_plot_path = output_dir / 'classification_improvement_summary.png'
plt.savefig(improvement_plot_path, dpi=300, bbox_inches='tight')
plt.show()

# Persist the Phase 1 / Phase 3 / Improvement table; the file name carries the split type.
comparison_class_csv_path = output_dir / f'classification_pan_cancer_vs_tnbc_comparison_{split_type}.csv'
comparison_class_df.to_csv(comparison_class_csv_path, index=False)



## 8. Interpretability Analysis

### 8.1 SHAP Analysis: Pathway Importance

Use SHAP (SHapley Additive exPlanations) to identify which pathways are most important for drug response prediction.


In [None]:
# SHAP Analysis: Pathway Importance per Drug

from collections import defaultdict
xs = True

if xs == True:
    
    def create_pathway_predictor(model, drug_graph, drug_name, pathway_cols, normalization_stats=None):
        """
        Create a wrapper function for SHAP that fixes the drug and varies pathways.

        Args:
            model: Trained model; called as ``model(drug_batch, pathway_tensor)``
                and expected to return a mapping with an ``'ic50'`` tensor.
            drug_graph: Fixed drug graph (PyTorch Geometric Data object)
            drug_name: Name of the drug (kept for interface compatibility; not
                used by the predictor itself)
            pathway_cols: List of pathway column names (kept for interface
                compatibility; not used by the predictor itself)
            normalization_stats: Dict with 'means' and 'stds' for normalization (optional)

        Returns:
            Function that takes pathway scores array and returns IC50 predictions
        """
        model.eval()
        drug_graph = drug_graph.to(device)

        # Build the normalization tensors once: the explainer calls the
        # predictor many times and the statistics never change between calls.
        if normalization_stats is not None:
            means = torch.FloatTensor(normalization_stats['means']).to(device)
            stds = torch.FloatTensor(normalization_stats['stds']).to(device)
        else:
            means = stds = None

        def pathway_predictor(pathway_scores_array):
            """
            Predict IC50 given pathway scores.

            Args:
                pathway_scores_array: numpy array (or tensor) of shape (n_samples, n_pathways)

            Returns:
                numpy array of IC50 predictions (n_samples,)
            """
            # Convert to tensor
            if isinstance(pathway_scores_array, np.ndarray):
                pathway_tensor = torch.FloatTensor(pathway_scores_array).to(device)
            else:
                pathway_tensor = pathway_scores_array.to(device)

            if means is not None:
                pathway_tensor = (pathway_tensor - means) / stds

            # Repeat the single drug graph once per row so the drug batch and
            # the pathway batch have the same size.
            batch_size = pathway_tensor.shape[0]
            expanded_drug_batch = Batch.from_data_list([drug_graph] * batch_size).to(device)

            with torch.no_grad():
                outputs = model(expanded_drug_batch, pathway_tensor)
                predictions = outputs['ic50'].cpu().numpy().flatten()

            return predictions

        return pathway_predictor
    
    def compute_shap_for_drug(model, drug_name, drug_graph, test_data, pathway_cols,
                              normalization_stats=None, n_samples=100, n_background=50,
                              shap_nsamples=100):
        """
        Compute SHAP values for a specific drug.

        Args:
            model: Trained model
            drug_name: Name of the drug
            drug_graph: Drug graph (PyTorch Geometric Data)
            test_data: DataFrame with test samples (filtered for this drug)
            pathway_cols: List of pathway column names
            normalization_stats: Normalization statistics
            n_samples: Maximum number of test rows to explain
            n_background: Maximum number of background rows for SHAP
            shap_nsamples: Model re-evaluations KernelExplainer performs per
                explained row. This is independent of how many rows are
                explained, so it is no longer capped by the test-set size
                (a drug with only a handful of test rows would otherwise get
                a handful of coalition samples and very noisy attributions).

        Returns:
            Dictionary with SHAP values and pathway names, or None when
            ``test_data`` is empty.
        """
        if len(test_data) == 0:
            return None

        # Sample test data (seeded so repeated runs explain the same rows)
        if len(test_data) > n_samples:
            test_sample = test_data.sample(n=n_samples, random_state=42)
        else:
            test_sample = test_data.copy()

        X_test = test_sample[pathway_cols].values.astype(np.float32)

        # Background rows come from the pan-cancer training split; only for
        # non-cell-line splits with an empty index does the whole table serve.
        if SPLIT_TYPE != 'cell_line' and len(pan_train_idx) == 0:
            background_data = pan_cancer_pathway
        else:
            background_data = pan_cancer_pathway.loc[pan_train_idx]

        # Prefer background rows for the same drug; fall back to any rows.
        drug_background = background_data[background_data['DrugName'] == drug_name]
        background_pool = drug_background if len(drug_background) > 0 else background_data
        drug_background = background_pool.sample(
            n=min(n_background, len(background_pool)), random_state=42
        )

        X_background = drug_background[pathway_cols].values.astype(np.float32)

        predictor = create_pathway_predictor(model, drug_graph, drug_name, pathway_cols, normalization_stats)

        # Compute SHAP values using KernelExplainer
        explainer = shap.KernelExplainer(predictor, X_background)
        shap_values = explainer.shap_values(X_test, nsamples=shap_nsamples)

        # Mean absolute attribution per pathway across the explained rows.
        mean_shap = np.abs(shap_values).mean(axis=0)

        results = {
            'drug_name': drug_name,
            'shap_values': shap_values,
            'mean_abs_shap': mean_shap,
            'pathway_names': pathway_cols,
            'test_samples': len(test_sample),
            'background_samples': len(X_background)
        }

        return results
    

    # Define checkpoint paths - use cellsplit models for cell_line split
    models_dir = output_dir / "models"
    split_type = SPLIT_TYPE if 'SPLIT_TYPE' in globals() else 'cell_line'

    # For cell_line split, use cellsplit models (old naming convention);
    # every other split type is embedded in the file name verbatim.
    checkpoint_suffix = 'cellsplit' if split_type == 'cell_line' else split_type
    phase2_checkpoint = models_dir / f"trial2_phase2_breast_pathway_{checkpoint_suffix}.pt"
    phase3_checkpoint = models_dir / f"trial2_phase3_tnbc_pathway_{checkpoint_suffix}.pt"

    if not phase3_checkpoint.exists():
        # Try cellsplit as fallback
        fallback_checkpoint = models_dir / "trial2_phase3_tnbc_pathway_cellsplit.pt"
        if fallback_checkpoint.exists():
            phase3_checkpoint = fallback_checkpoint

    shap_pathway_count = actual_pathway_count if 'actual_pathway_count' in globals() else len(pathway_cols)
    model_shap = DrugResponsePathwayGNN(cell_input_dim=shap_pathway_count).to(device)

    # Prefer the Phase 3 weights, then Phase 2. Explaining a randomly
    # initialised network would write meaningless attributions to disk, so a
    # missing checkpoint is a hard error rather than a silent fall-through.
    shap_checkpoint_path = next(
        (path for path in (phase3_checkpoint, phase2_checkpoint) if path.exists()),
        None,
    )
    if shap_checkpoint_path is None:
        raise FileNotFoundError(
            f"No trained checkpoint found for SHAP analysis; looked for "
            f"{phase3_checkpoint} and {phase2_checkpoint}."
        )
    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints produced by this pipeline.
    checkpoint = torch.load(shap_checkpoint_path, map_location=device, weights_only=False)
    model_shap.load_state_dict(checkpoint['model_state_dict'])
    
    # Compute normalization stats from training data.
    # Only a non-cell-line split with an empty train index uses the full table.
    if SPLIT_TYPE != 'cell_line' and len(pan_train_idx) == 0:
        train_data = pan_cancer_pathway
    else:
        train_data = pan_cancer_pathway.loc[pan_train_idx]

    _train_matrix = train_data[pathway_cols].values
    normalization_stats = {
        'means': _train_matrix.mean(axis=0).astype(np.float32),
        # Small epsilon keeps constant columns from producing a zero divisor.
        'stds': _train_matrix.std(axis=0).astype(np.float32) + 1e-8,
    }

    # Explain TNBC test rows when Phase 3 produced results, otherwise breast
    # cancer test rows; an empty test index falls back to the whole frame.
    if phase3_results['n_samples'] > 0:
        _eval_frame, _eval_idx = tnbc_pathway, tnbc_test_idx
    else:
        _eval_frame, _eval_idx = breast_cancer_pathway, breast_test_idx
    test_data_full = _eval_frame.loc[_eval_idx] if len(_eval_idx) > 0 else _eval_frame

    unique_drugs = test_data_full['DrugName'].unique()
    
    # Drugs whose attributions are examined first
    high_priority_drugs = [
        'Paclitaxel',     # Microtubule - should show mitotic pathways
        'Cisplatin',      # DNA crosslinker - should show DNA repair
        'Olaparib',       # PARP inhibitor - should show DNA repair
        'Trametinib',     # MEK inhibitor - should show MAPK/ERK
        'Lapatinib',      # EGFR/HER2 - should show EGFR signaling
    ]

    # Keep the priority drugs that actually occur in the evaluation data, in
    # the order listed above, then append every remaining drug.
    priority_drugs_in_set = [name for name in high_priority_drugs if name in unique_drugs]
    other_drugs = [name for name in unique_drugs if name not in high_priority_drugs]
    unique_drugs = [*priority_drugs_in_set, *other_drugs]
    
    # Compute SHAP values for each drug

    drug_shap_results = {}
    # Drugs whose explanation raised, mapped to the error text, so failures
    # are visible afterwards instead of disappearing.
    drug_shap_failures = {}

    for drug_name in tqdm(unique_drugs, desc="Processing drugs"):
        # Filter test data for this drug
        drug_test_data = test_data_full[test_data_full['DrugName'] == drug_name].copy()

        # Skip drugs with no test rows or no molecular graph available.
        if len(drug_test_data) == 0 or drug_name not in drug_graphs:
            continue

        drug_graph_data = drug_graphs[drug_name]['graph_data']

        # Compute SHAP values. One failing drug must not abort the sweep, but
        # the failure is recorded and reported rather than swallowed.
        try:
            shap_result = compute_shap_for_drug(
                model_shap, drug_name, drug_graph_data, drug_test_data,
                pathway_cols, normalization_stats, n_samples=50, n_background=50
            )
        except Exception as e:
            drug_shap_failures[drug_name] = repr(e)
            # tqdm.write prints without corrupting the progress bar.
            tqdm.write(f"SHAP failed for {drug_name}: {e!r}")
            continue

        if shap_result is not None:
            drug_shap_results[drug_name] = shap_result
    
    
    # Analyze and visualize results
    # Collect, per pathway, its mean |SHAP| value from every explained drug.
    pathway_importance_aggregate = defaultdict(list)

    for result in drug_shap_results.values():
        for pathway, importance in zip(result['pathway_names'], result['mean_abs_shap']):
            pathway_importance_aggregate[pathway].append(importance)

    # Compute mean importance per pathway
    pathway_mean_importance = {
        pathway: np.mean(importances)
        for pathway, importances in pathway_importance_aggregate.items()
    }

    # Sort by mean importance (most important first)
    sorted_pathways = sorted(pathway_mean_importance.items(), key=lambda x: x[1], reverse=True)

    # Persist the raw per-drug results together with the aggregate ranking
    # inputs and the normalization statistics used by the predictor.
    shap_results_path = output_dir / "shap_analysis_results.pkl"
    with open(shap_results_path, 'wb') as f:
        pickle.dump({
            'drug_shap_results': drug_shap_results,
            'pathway_mean_importance': pathway_mean_importance,
            'normalization_stats': normalization_stats
        }, f)
    
    
    
    # Pick the 15 highest-ranked pathways and the first 15 explained drugs
    top_pathways_viz = [name for name, _ in sorted_pathways[:15]]
    top_drugs_viz = list(drug_shap_results)[:15]

    # One row per drug, one column per selected pathway; 0.0 where a drug's
    # result does not list the pathway.
    heatmap_data = []
    for drug_name in top_drugs_viz:
        if drug_name not in drug_shap_results:
            continue
        result = drug_shap_results[drug_name]
        pathway_names = result['pathway_names']
        mean_shap = result['mean_abs_shap']
        row = [
            mean_shap[pathway_names.index(pathway)] if pathway in pathway_names else 0.0
            for pathway in top_pathways_viz
        ]
        heatmap_data.append(row)

    if heatmap_data:
        heatmap_array = np.array(heatmap_data)

        plt.figure(figsize=(7, 4.375))
        plt.imshow(heatmap_array, aspect='auto', cmap='gray', interpolation='nearest')
        plt.colorbar(label='Mean |SHAP| Value')
        plt.xlabel('Pathway', fontsize=10)
        plt.ylabel('Drug', fontsize=10)
        plt.title('SHAP Analysis: Pathway Importance per Drug\n(Top 15 Pathways × Top 15 Drugs)', fontsize=14, fontweight='bold')
        plt.xticks(range(len(top_pathways_viz)), top_pathways_viz, rotation=45, ha='right', fontsize=8)
        plt.yticks(range(len(top_drugs_viz)), top_drugs_viz, fontsize=9)
        plt.tight_layout()

        # Write the figure to disk before showing it.
        heatmap_path = output_dir / "shap_heatmap.png"
        plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
        plt.show()
    


## 9. Results Summary

Save all results, metrics, splits, and metadata for reproducibility.

In [None]:
def convert_numpy(obj):
    """Recursively replace NumPy values in *obj* with JSON-serializable built-ins.

    Handles NumPy booleans, integers, floats and arrays, and walks through
    dicts, lists and tuples. Dict keys that are NumPy scalars are converted as
    well, because ``json.dump`` rejects e.g. ``np.int64`` keys. Any other
    object is returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) value.

    Returns:
        A structure of the same shape as *obj* with NumPy scalars replaced by
        ``bool``/``int``/``float`` and arrays replaced by (nested) lists.
        Tuples stay tuples; ``json`` writes them as arrays.
    """
    # np.bool_ is neither np.integer nor np.floating and is not accepted by
    # the json encoder, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {
            (convert_numpy(key) if isinstance(key, np.generic) else key): convert_numpy(value)
            for key, value in obj.items()
        }
    if isinstance(obj, list):
        return [convert_numpy(item) for item in obj]
    if isinstance(obj, tuple):
        # Without this branch NumPy values nested inside tuples would be
        # passed through untouched.
        return tuple(convert_numpy(item) for item in obj)
    return obj

def _describe_split(frame, train_idx, val_idx, test_idx):
    """Return sample counts and distinct 'ModelID' counts for one cohort's splits.

    The index collections are used as row labels into *frame* (via ``.loc``).
    """
    def _n_cell_lines(idx):
        return len(set(frame.loc[idx]['ModelID'].unique()))

    return {
        'train_samples': len(train_idx),
        'val_samples': len(val_idx),
        'test_samples': len(test_idx),
        'train_cell_lines': _n_cell_lines(train_idx),
        'val_cell_lines': _n_cell_lines(val_idx),
        'test_cell_lines': _n_cell_lines(test_idx),
    }


# Per-cohort split sizes, tagged with the split type and creation time.
splits_summary = {
    'split_type': SPLIT_TYPE,
    'pan_cancer': _describe_split(
        pan_cancer_pathway, pan_train_idx, pan_val_idx, pan_test_idx
    ),
    'breast_cancer': _describe_split(
        breast_cancer_pathway, breast_train_idx, breast_val_idx, breast_test_idx
    ),
    'tnbc': _describe_split(
        tnbc_pathway, tnbc_train_idx, tnbc_val_idx, tnbc_test_idx
    ),
    'timestamp': datetime.now().isoformat(),
}

splits_summary_path = output_dir / f"data_splits_summary_{SPLIT_TYPE}.json"
with open(splits_summary_path, 'w') as f:
    json.dump(splits_summary, f, indent=2)

def _summarize_cv(cv_run):
    """Condense one phase's cross-validation results into a summary dict.

    Reads the 'cv_results', 'best_fold' and 'aggregated' entries of *cv_run*
    (falling back to an empty list, 1 and an empty dict respectively) and
    extracts the validation r2 / pearson / rmse of every fold as floats.
    """
    folds = cv_run.get('cv_results', [])
    return {
        'n_folds': len(folds),
        'best_fold': cv_run.get('best_fold', 1),
        'aggregated_metrics': cv_run.get('aggregated', {}),
        'fold_results': [
            {
                'fold': fold_number,
                'val_r2': float(fold['val_metrics']['r2']),
                'val_pearson': float(fold['val_metrics']['pearson']),
                'val_rmse': float(fold['val_metrics']['rmse']),
            }
            for fold_number, fold in enumerate(folds, start=1)
        ],
    }


# A phase is only summarized when its CV results exist and are not None.
cv_summary = {}
if globals().get('phase1_cv_results') is not None:
    cv_summary['phase1'] = _summarize_cv(phase1_cv_results)

if globals().get('phase2_cv_results') is not None:
    cv_summary['phase2'] = _summarize_cv(phase2_cv_results)

# Nothing is written when neither phase produced CV results.
if cv_summary:
    cv_summary_path = output_dir / f"cv_results_summary_{SPLIT_TYPE}.json"
    cv_summary_serializable = convert_numpy(cv_summary)
    with open(cv_summary_path, 'w') as f:
        json.dump(cv_summary_serializable, f, indent=2)

def _best_cv_checkpoint(cv_var_name):
    """Return str of the 'best_checkpoint' entry of the named global CV result.

    Yields None when the global does not exist or is None.
    """
    cv_run = globals().get(cv_var_name)
    if cv_run is None:
        return None
    return str(cv_run['best_checkpoint'])


# Record where each phase's checkpoint lives and whether the file is present.
checkpoints_info = {
    'phase1': {
        'checkpoint_path': str(phase1_checkpoint),
        'exists': phase1_checkpoint.exists(),
        'cv_best_checkpoint': _best_cv_checkpoint('phase1_cv_results'),
    },
    'phase2': {
        'checkpoint_path': str(phase2_checkpoint),
        'exists': phase2_checkpoint.exists(),
        'cv_best_checkpoint': _best_cv_checkpoint('phase2_cv_results'),
    },
    # Phase 3 has no CV entry in this summary.
    'phase3': {
        'checkpoint_path': str(phase3_checkpoint),
        'exists': phase3_checkpoint.exists(),
    },
    'timestamp': datetime.now().isoformat(),
}

checkpoints_info_path = output_dir / f"model_checkpoints_info_{SPLIT_TYPE}.json"
with open(checkpoints_info_path, 'w') as f:
    json.dump(checkpoints_info, f, indent=2)

_REGRESSION_METRIC_KEYS = ('r2', 'pearson', 'spearman', 'rmse', 'mae')


def _regression_metrics(phase_results):
    """Extract the regression metrics of one phase as plain floats."""
    return {
        key: float(phase_results['metrics'][key])
        for key in _REGRESSION_METRIC_KEYS
    }


def _optional_global_as_str(var_name):
    """Return str() of the named global, or None when it is not defined."""
    return str(globals()[var_name]) if var_name in globals() else None


# Single document combining experiment metadata, split sizes, metrics,
# CV results and the names of the saved artefacts.
comprehensive_summary = {
    'experiment_info': {
        'split_type': SPLIT_TYPE,
        'timestamp': datetime.now().isoformat(),
        'device': str(device),
        'pathway_count': int(actual_pathway_count),
    },
    'data_splits': splits_summary,
    'regression_metrics': {
        'phase1': _regression_metrics(phase1_results),
        'phase2': _regression_metrics(phase2_results),
        'phase3': _regression_metrics(phase3_results),
    },
    # Classification metrics are optional; default to an empty dict.
    'classification_metrics': globals().get('classification_metrics', {}),
    'cv_summary': cv_summary,
    'files_saved': {
        'splits': str(splits_file),
        'performance_csv': str(performance_csv_path),
        'performance_pkl': str(performance_pkl_path),
        'classification_csv': _optional_global_as_str('classification_csv_path'),
        'classification_pkl': _optional_global_as_str('classification_pkl_path'),
        'figures': [
            'classification_head_analysis.png',
            'pan_cancer_vs_tnbc_comparison.png',
            'classification_pan_cancer_vs_tnbc_comparison.png',
            'classification_improvement_summary.png',
        ],
    },
}

# Comparison tables are attached only when they were produced.
if 'comparison_df' in globals():
    comprehensive_summary['comparison_regression'] = comparison_df.to_dict('records')
if 'comparison_class_df' in globals():
    comprehensive_summary['comparison_classification'] = comparison_class_df.to_dict(
        'records'
    )

# Use the first of the two candidate globals that is defined (its value is
# taken as-is, even if it is None); otherwise there is no SHAP data.
shap_data = next(
    (
        globals()[name]
        for name in ('drug_shap_results', 'shap_results_dict')
        if name in globals()
    ),
    None,
)

if shap_data is not None:
    shap_summary = {}
    for drug_name, result in shap_data.items():
        # Entries that are not dicts or lack 'mean_abs_shap' are ignored.
        if not (isinstance(result, dict) and 'mean_abs_shap' in result):
            continue
        # Indices of the (up to) ten largest mean |SHAP| values, largest first.
        top_indices = np.argsort(result['mean_abs_shap'])[::-1][:10]
        shap_summary[drug_name] = {
            'top_pathways': [result['pathway_names'][i] for i in top_indices],
            'top_shap_values': [float(result['mean_abs_shap'][i]) for i in top_indices],
            'n_test_samples': int(result.get('test_samples', 0)),
            'n_background_samples': int(result.get('background_samples', 0)),
        }

    if shap_summary:
        comprehensive_summary['shap_analysis'] = shap_summary
        comprehensive_summary['files_saved']['shap_results'] = 'shap_analysis_results.pkl'

# Convert NumPy values first so the whole summary can be written as JSON.
comprehensive_summary_path = output_dir / f"comprehensive_results_summary_{SPLIT_TYPE}.json"
comprehensive_summary_serializable = convert_numpy(comprehensive_summary)
with open(comprehensive_summary_path, 'w') as f:
    json.dump(comprehensive_summary_serializable, f, indent=2)