# TNBC Drug Cytotoxicity Prediction Model
## Comprehensive Data Preparation Pipeline

This notebook creates hierarchical datasets for GNN-based drug response prediction:

**Pipeline Steps:**
1. Load cancer cell line data (DepMap, GDSC2, gene expression)
2. Identify BL1 TNBC cell lines using molecular criteria
3. Load/fetch SMILES structures for all drugs
4. Filter GDSC2 data for quality (RMSE < 0.3)
5. Merge drug response with SMILES and gene expression
6. Apply quality control filters
7. Create tissue-specific subsets:
   - Pan-cancer (all cell lines) → `pan_cancer_data.pkl`
   - Breast cancer only → `breast_cancer_data.pkl`
   - TNBC subset → `tnbc_data.pkl`
   - BL1 TNBC subset → `bl1_tnbc_data.pkl`
8. Generate summary report

**Training Strategy:**
Train on pan-cancer → Fine-tune on breast → Fine-tune on TNBC → Fine-tune on BL1 TNBC

In [57]:
import pandas as pd
import numpy as np
from pathlib import Path
import time
import pubchempy as pcp
import torch
import torch.nn as nn
from torch_geometric.data import Data, Batch
from torch_geometric.nn import TransformerConv, global_mean_pool, global_max_pool
from rdkit import Chem
from rdkit.Chem import AllChem
import pickle
from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
import os
import subprocess
import sys
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt

# Data directory
data_dir = Path("/Users/tjalling/Desktop/Dev./Capstone/Datasets/Cytotoxicity Model")

In [None]:
# Load datasets

model_df = pd.read_csv(data_dir / "DepMap Model Data.csv")
expression_df = pd.read_csv(data_dir / "Omics Expression TPM Logp1 Human Protein Coding Genes.csv")
mutations_df = pd.read_csv(data_dir / "Omics Somatic Mutations.csv")
cn_df = pd.read_csv(data_dir / "Omics CN Gene WGS Data.csv")
compounds_df = pd.read_csv(data_dir / "Screened Compounds v8.5.csv")
gdsc2_df = pd.read_excel(data_dir / "GDSC2 Fitted Dose Response Oct 27 2023.xlsx")

print(f"Loaded {len(model_df)} cell lines, {len(compounds_df)} compounds, {len(gdsc2_df)} drug responses")

In [None]:
# BL1 Cell Line Identification Functions

def find_gene_columns(df, gene_symbols):
    """Find columns matching gene symbols (handles 'SYMBOL (EntrezID)' format)."""
    found = {}
    for symbol in gene_symbols:
        matches = [col for col in df.columns if col.startswith(f"{symbol} (")]
        if matches:
            found[symbol] = matches[0]
    return found

def normalize_cell_line_name(name):
    """Normalize cell line name for matching."""
    return str(name).replace('-', '').replace('_', '').upper()

def identify_bl1_cell_lines(model_df, mutations_df, expression_df, cn_df):
    """
    Identify BL1 TNBC cell lines using molecular criteria.
    
    Criteria: Breast + TP53 mutation + (High proliferation OR MYC amplification)
    Plus known reference lines: HCC1143, HCC1937, MDAMB468, HCC70, HCC1806, CAL51, BT549, SUM149PT
    """
    print("BL1 TNBC Cell Line Identification")
    print("="*80)
    
    known_bl1_names = ['HCC1143', 'HCC1937', 'MDAMB468', 'HCC70', 'HCC1806', 'CAL51', 'BT549', 'SUM149PT']
    
    # Filter for breast cancer lines
    breast_lines = model_df[model_df['OncotreeLineage'] == 'Breast']['ModelID'].values
    print(f"Breast cancer lines: {len(breast_lines)}")
    
    # TP53 mutations (non-silent)
    tp53_mutations = mutations_df[
        (mutations_df['HugoSymbol'] == 'TP53') & 
        (mutations_df['VariantType'] != 'Silent')
    ]['ModelID'].unique()
    print(f"TP53 mutated lines: {len(tp53_mutations)}")
    
    # High proliferation (top 25%)
    prolif_genes = ['MKI67', 'CCNB1', 'CCNE1']
    prolif_cols = find_gene_columns(expression_df, prolif_genes)
    
    if len(prolif_cols) > 0:
        gene_col_names = list(prolif_cols.values())
        prolif_data = expression_df[['ModelID'] + gene_col_names].copy()
        prolif_data['prolif_score'] = prolif_data[gene_col_names].mean(axis=1)
        threshold = prolif_data['prolif_score'].quantile(0.75)
        high_prolif_lines = prolif_data[prolif_data['prolif_score'] >= threshold]['ModelID'].values
        print(f"High proliferation lines: {len(high_prolif_lines)}")
    else:
        high_prolif_lines = []
    
    # MYC amplification (CN > 6)
    myc_cols = find_gene_columns(cn_df, ['MYC', 'MYCN'])
    if myc_cols:
        myc_col = myc_cols.get('MYC') or myc_cols.get('MYCN')
        myc_amp_lines = cn_df[cn_df[myc_col] > 6]['ModelID'].values
        print(f"MYC amplified lines: {len(myc_amp_lines)}")
    else:
        myc_amp_lines = []
    
    # Combine criteria
    criteria_bl1 = set(breast_lines) & set(tp53_mutations) & (set(high_prolif_lines) | set(myc_amp_lines))
    
    # Find known reference lines
    known_bl1_present = set()
    cell_name_col = 'StrippedCellLineName' if 'StrippedCellLineName' in model_df.columns else 'ModelID'
    
    for ref_name in known_bl1_names:
        normalized_ref = normalize_cell_line_name(ref_name)
        matching = model_df[
            model_df[cell_name_col].apply(lambda x: normalize_cell_line_name(x) == normalized_ref)
        ]['ModelID'].values
        if len(matching) > 0:
            known_bl1_present.update(matching)
    
    # Final BL1 list
    final_bl1 = criteria_bl1 | known_bl1_present
    
    print(f"\nFinal BL1 lines: {len(final_bl1)} (Criteria: {len(criteria_bl1)}, Reference: {len(known_bl1_present)})")
    print("="*80)
    
    # Display
    for model_id in sorted(final_bl1):
        cell_name = model_df[model_df['ModelID'] == model_id][cell_name_col].values
        name = cell_name[0] if len(cell_name) > 0 else 'Unknown'
        source = []
        if model_id in criteria_bl1:
            source.append("criteria")
        if model_id in known_bl1_present:
            source.append("reference")
        print(f"  {model_id:20s} {name:20s} [{', '.join(source)}]")
    
    return sorted(list(final_bl1))

# Run identification
bl1_model_ids = identify_bl1_cell_lines(model_df, mutations_df, expression_df, cn_df)

BL1 TNBC Cell Line Identification
Breast cancer lines: 96
TP53 mutated lines: 1187
High proliferation lines: 439
MYC amplified lines: 49

Final BL1 lines: 23 (Criteria: 19, Reference: 8)
  ACH-000017           SKBR3                [criteria]
  ACH-000111           HCC1187              [criteria]
  ACH-000117           EFM192A              [criteria]
  ACH-000196           HCC1599              [criteria]
  ACH-000223           HCC1937              [criteria, reference]
  ACH-000248           AU565                [criteria]
  ACH-000277           HCC1419              [criteria]
  ACH-000288           BT549                [criteria, reference]
  ACH-000374           HCC1143              [reference]
  ACH-000536           BT20                 [criteria]
  ACH-000621           MDAMB157             [criteria]
  ACH-000624           HCC1806              [reference]
  ACH-000668           HCC70                [criteria, reference]
  ACH-000699           HCC1395              [criteria]
  ACH-00

In [None]:
# Create Training Dataset

# Map BL1 ModelIDs to COSMIC IDs
bl1_cosmic_map = model_df[model_df['ModelID'].isin(bl1_model_ids)][
    ['ModelID', 'COSMICID', 'StrippedCellLineName']
].copy()

# Filter GDSC2 for BL1 cell lines
cell_id_col = 'COSMIC_ID' if 'COSMIC_ID' in gdsc2_df.columns else 'COSMICID'
bl1_cosmic_ids = bl1_cosmic_map['COSMICID'].values
gdsc2_bl1 = gdsc2_df[gdsc2_df[cell_id_col].isin(bl1_cosmic_ids)].copy()

# Remove missing IC50
ln_ic50_col = [col for col in gdsc2_df.columns if 'LN_IC50' in col.upper()][0]
gdsc2_bl1 = gdsc2_bl1[gdsc2_bl1[ln_ic50_col].notna()]

# Merge with cell line info
training_data = gdsc2_bl1.merge(
    bl1_cosmic_map,
    left_on=cell_id_col,
    right_on='COSMICID',
    how='left'
)

drug_col = [col for col in training_data.columns if 'DRUG' in col.upper() and 'NAME' in col.upper()][0]

print(f"Training Dataset: {len(training_data)} pairs | {training_data[drug_col].nunique()} drugs | {training_data['ModelID'].nunique()} cell lines")

Training Dataset: 5311 pairs | 286 drugs | 20 cell lines


In [None]:
# Load SMILES structures (use existing file if available)

smiles_file = data_dir / 'drugs_with_smiles.csv'

if smiles_file.exists():
    drugs_with_smiles = pd.read_csv(smiles_file)
else:
    def get_smiles_pubchempy(drug_name):
        """
        Get SMILES string for a drug using PubChemPy.
        
        Args:
            drug_name: Drug name to search
            
        Returns:
            SMILES string or None
        """
        try:
            compounds = pcp.get_compounds(drug_name, 'name')
            if compounds:
                return compounds[0].isomeric_smiles
            return None
        except Exception:
            return None
    
    unique_drugs = gdsc2_df[drug_col].unique()
    drugs_df = pd.DataFrame({drug_col: unique_drugs})
    
    smiles_list = []
    for drug_name in drugs_df[drug_col]:
        smiles_list.append(get_smiles_pubchempy(drug_name))
        time.sleep(0.2)
    
    drugs_df['SMILES'] = smiles_list
    drugs_df.to_csv(smiles_file, index=False)
    drugs_with_smiles = drugs_df

drugs_with_smiles = drugs_with_smiles[drugs_with_smiles['SMILES'].notna()].copy()
print(f"Drugs with valid SMILES: {len(drugs_with_smiles)}")

Drugs with valid SMILES: 228


In [None]:
# Filter GDSC2 data for quality (RMSE < 0.3)

def filter_gdsc2_quality(gdsc2_df, rmse_threshold=0.3):
    """
    Filter GDSC2 data for good curve fits.
    
    Args:
        gdsc2_df: GDSC2 dataframe
        rmse_threshold: Maximum RMSE to keep
        
    Returns:
        Filtered dataframe
    """
    rmse_col = [col for col in gdsc2_df.columns if 'RMSE' in col.upper()]
    if rmse_col:
        initial_count = len(gdsc2_df)
        gdsc2_filtered = gdsc2_df[gdsc2_df[rmse_col[0]] < rmse_threshold].copy()
        print(f"GDSC2 quality filter: {initial_count} → {len(gdsc2_filtered)} measurements (RMSE < {rmse_threshold})")
        return gdsc2_filtered
    return gdsc2_df.copy()

gdsc2_filtered = filter_gdsc2_quality(gdsc2_df)

GDSC2 quality filter: 242036 → 242036 measurements (RMSE < 0.3)


In [None]:
# Merge GDSC2 drug response with SMILES structures

# Identify column names
ln_ic50_col = [col for col in gdsc2_filtered.columns if 'LN_IC50' in col.upper()][0]
auc_col = [col for col in gdsc2_filtered.columns if 'AUC' in col.upper()][0]
drug_name_col = [col for col in gdsc2_filtered.columns if 'DRUG' in col.upper() and 'NAME' in col.upper()][0]
cosmic_id_col = [col for col in gdsc2_filtered.columns if 'COSMIC' in col.upper()][0]

# Merge on drug name
response_with_smiles = gdsc2_filtered.merge(
    drugs_with_smiles,
    left_on=drug_name_col,
    right_on=drugs_with_smiles.columns[0],
    how='inner'
)

# Remove rows with missing IC50
response_with_smiles = response_with_smiles[response_with_smiles[ln_ic50_col].notna()].copy()

print(f"Drug-cell pairs with SMILES: {len(response_with_smiles)}")
print(f"Unique drugs: {response_with_smiles[drug_name_col].nunique()}")
print(f"Unique cell lines: {response_with_smiles[cosmic_id_col].nunique()}")

Drug-cell pairs with SMILES: 199501
Unique drugs: 228
Unique cell lines: 969


In [None]:
# Create mapping between COSMIC_ID and ModelID

def create_id_mapping(model_df, expression_df):
    """
    Create mapping between COSMIC_ID and ModelID for cell lines with expression data.
    
    Args:
        model_df: Model metadata
        expression_df: Expression data
        
    Returns:
        DataFrame with COSMICID and ModelID mapping
    """
    # Get cell lines that have expression data
    expression_model_ids = set(expression_df['ModelID'])
    
    # Create mapping
    id_mapping = model_df[model_df['ModelID'].isin(expression_model_ids)][
        ['ModelID', 'COSMICID', 'StrippedCellLineName', 'OncotreeLineage', 'OncotreePrimaryDisease']
    ].copy()
    
    # Remove rows without COSMICID
    id_mapping = id_mapping[id_mapping['COSMICID'].notna()].copy()
    
    return id_mapping

id_mapping = create_id_mapping(model_df, expression_df)
print(f"Cell lines with expression data and COSMIC ID: {len(id_mapping)}")

Cell lines with expression data and COSMIC ID: 715


In [None]:
# Merge response with cell line IDs (without expression data yet)

# Step 1: Merge response with ID mapping to get ModelIDs
response_with_ids = response_with_smiles.merge(
    id_mapping,
    left_on=cosmic_id_col,
    right_on='COSMICID',
    how='inner'
)

print(f"After ID mapping: {len(response_with_ids)} pairs | {response_with_ids['ModelID'].nunique()} cell lines | {response_with_ids[drug_name_col].nunique()} drugs")

After ID mapping: 144921 pairs | 698 cell lines | 228 drugs


In [None]:
# Apply quality filters before merging

def apply_quality_filters_pre_merge(response_data, drug_col, min_tests_per_drug=10):
    """
    Filter drugs tested on too few cell lines.
    
    Args:
        response_data: Response data with ModelIDs
        drug_col: Drug name column
        min_tests_per_drug: Minimum tests required per drug
        
    Returns:
        Filtered dataset
    """
    drug_counts = response_data.groupby(drug_col).size()
    valid_drugs = drug_counts[drug_counts >= min_tests_per_drug].index
    data_clean = response_data[response_data[drug_col].isin(valid_drugs)].copy()
    
    print(f"Quality filter: {len(response_data)} → {len(data_clean)} pairs, {data_clean[drug_col].nunique()} drugs")
    
    return data_clean

response_clean = apply_quality_filters_pre_merge(response_with_ids, drug_name_col)

Quality filter: 144921 → 144921 pairs, 228 drugs


In [None]:
# Preprocess expression data (quality + dimensionality reduction)

def preprocess_expression_data(expression_df, response_clean, max_missing_pct=0.01, top_n_genes=3000):
    """
    Clean and reduce expression data before merging.
    
    Args:
        expression_df: Full gene expression data
        response_clean: Response data (to filter cell lines)
        max_missing_pct: Maximum missing % per gene (default 1%)
        top_n_genes: Number of top variance genes to keep
        
    Returns:
        Preprocessed expression dataframe
    """
    # Filter to only needed cell lines
    needed_model_ids = response_clean['ModelID'].unique()
    expr_subset = expression_df[expression_df['ModelID'].isin(needed_model_ids)].copy()
    
    # Get gene columns (numeric columns only, excluding ModelID)
    gene_cols = expr_subset.select_dtypes(include=['number']).columns.tolist()
    
    print(f"Starting: {len(expr_subset)} cell lines × {len(gene_cols)} genes")
    
    # Remove genes with >1% missing values
    missing_pct = expr_subset[gene_cols].isna().mean()
    valid_genes = missing_pct[missing_pct <= max_missing_pct].index.tolist()
    print(f"After removing genes with >{max_missing_pct*100}% missing: {len(valid_genes)} genes")
    
    # Select top variance genes
    gene_variances = expr_subset[valid_genes].var()
    top_genes = gene_variances.nlargest(min(top_n_genes, len(valid_genes))).index.tolist()
    print(f"Selected top {len(top_genes)} highest variance genes")
    
    # Create filtered dataset
    expr_filtered = expr_subset[['ModelID'] + top_genes].copy()
    
    # Impute remaining NaNs with gene median
    for gene in top_genes:
        if expr_filtered[gene].isna().any():
            expr_filtered[gene].fillna(expr_filtered[gene].median(), inplace=True)
    
    print(f"Final: {len(expr_filtered)} cell lines × {len(top_genes)} genes (NaNs imputed)")
    
    return expr_filtered

print("Preprocessing expression data...")
expression_clean = preprocess_expression_data(expression_df, response_clean)

# Merge with preprocessed expression
print("\nMerging response with expression...")
pan_cancer_data = response_clean.merge(
    expression_clean,
    on='ModelID',
    how='inner'
)

# Organize columns
meta_cols = ['COSMICID', 'ModelID', 'StrippedCellLineName', 'OncotreeLineage', 
             'OncotreePrimaryDisease', drug_name_col, 'SMILES', ln_ic50_col, auc_col]
gene_cols = [col for col in expression_clean.columns if col != 'ModelID']
pan_cancer_data = pan_cancer_data[meta_cols + gene_cols].copy()

print(f"Pan-cancer dataset: {len(pan_cancer_data)} pairs | {pan_cancer_data['ModelID'].nunique()} cell lines | {pan_cancer_data[drug_name_col].nunique()} drugs")

Preprocessing expression data...
Starting: 718 cell lines × 19216 genes
After removing genes with >1.0% missing: 19216 genes
Selected top 3000 highest variance genes
Final: 718 cell lines × 3000 genes (NaNs imputed)

Merging response with expression...
Pan-cancer dataset: 149093 pairs | 698 cell lines | 228 drugs


In [None]:
# Create tissue-specific subsets

def create_tissue_subset(data, tissue_type=None, primary_disease=None):
    """
    Filter data by tissue type and/or primary disease.
    
    Args:
        data: Full dataset
        tissue_type: OncotreeLineage value
        primary_disease: OncotreePrimaryDisease substring
        
    Returns:
        Filtered subset
    """
    subset = data.copy()
    
    if tissue_type:
        subset = subset[subset['OncotreeLineage'] == tissue_type]
    
    if primary_disease:
        subset = subset[subset['OncotreePrimaryDisease'].str.contains(primary_disease, na=False, case=False)]
    
    return subset

breast_cancer_data = create_tissue_subset(pan_cancer_data, tissue_type='Breast')
tnbc_data = create_tissue_subset(breast_cancer_data, primary_disease='Breast')
bl1_tnbc_data = tnbc_data[tnbc_data['ModelID'].isin(bl1_model_ids)].copy()

print(f"Breast: {len(breast_cancer_data)} pairs, {breast_cancer_data['ModelID'].nunique()} cell lines")
print(f"TNBC: {len(tnbc_data)} pairs, {tnbc_data['ModelID'].nunique()} cell lines")
print(f"BL1 TNBC: {len(bl1_tnbc_data)} pairs, {bl1_tnbc_data['ModelID'].nunique()} cell lines")

Breast: 9601 pairs, 45 cell lines
TNBC: 9601 pairs, 45 cell lines
BL1 TNBC: 4303 pairs, 20 cell lines


In [None]:
# Save processed datasets

output_dir = Path("/Users/tjalling/Desktop/Dev./Capstone/Model_Notebooks")

datasets = {
    'pan_cancer_data.pkl': pan_cancer_data,
    'breast_cancer_data.pkl': breast_cancer_data,
    'tnbc_data.pkl': tnbc_data,
    'bl1_tnbc_data.pkl': bl1_tnbc_data
}

for filename, dataset in datasets.items():
    output_path = output_dir / filename
    dataset.to_pickle(output_path)
    print(f"Saved {filename}: {len(dataset)} pairs, {dataset['ModelID'].nunique()} cell lines")

Saved pan_cancer_data.pkl: 149093 pairs, 698 cell lines
Saved breast_cancer_data.pkl: 9601 pairs, 45 cell lines
Saved tnbc_data.pkl: 9601 pairs, 45 cell lines
Saved bl1_tnbc_data.pkl: 4303 pairs, 20 cell lines


In [None]:
# Generate summary report

def generate_summary_report(datasets_dict, drug_col, output_file):
    """
    Generate summary report of processed datasets.
    
    Args:
        datasets_dict: Dictionary of dataset names and dataframes
        drug_col: Drug name column
        output_file: Path to save summary
    """
    summary_lines = []
    summary_lines.append("DATA PROCESSING SUMMARY")
    summary_lines.append("")
    
    pan_cancer = datasets_dict['pan_cancer_data.pkl']
    summary_lines.append("OVERALL STATISTICS:")
    summary_lines.append(f"  Total drugs: {pan_cancer[drug_col].nunique()}")
    summary_lines.append(f"  Total cell lines: {pan_cancer['ModelID'].nunique()}")
    summary_lines.append(f"  Gene features: {len([c for c in pan_cancer.columns if '(' in c and ')' in c])}")
    summary_lines.append("")
    
    summary_lines.append("DATASET STATISTICS:")
    for name, data in datasets_dict.items():
        summary_lines.append(f"\n  {name}:")
        summary_lines.append(f"    Pairs: {len(data)}")
        summary_lines.append(f"    Cell lines: {data['ModelID'].nunique()}")
        summary_lines.append(f"    Drugs: {data[drug_col].nunique()}")
    
    summary_lines.append("")
    
    bl1_data = datasets_dict['bl1_tnbc_data.pkl']
    if len(bl1_data) > 0:
        summary_lines.append("BL1 TNBC CELL LINES:")
        for model_id in sorted(bl1_data['ModelID'].unique()):
            cell_name = bl1_data[bl1_data['ModelID'] == model_id]['StrippedCellLineName'].iloc[0]
            n_drugs = len(bl1_data[bl1_data['ModelID'] == model_id])
            summary_lines.append(f"  {model_id}: {cell_name} ({n_drugs} drugs)")
    
    with open(output_file, 'w') as f:
        f.write('\n'.join(summary_lines))
    
    print('\n'.join(summary_lines))

summary_file = output_dir / 'data_summary.txt'
generate_summary_report(datasets, drug_name_col, summary_file)

DATA PROCESSING SUMMARY

OVERALL STATISTICS:
  Total drugs: 228
  Total cell lines: 698
  Gene features: 2999

DATASET STATISTICS:

  pan_cancer_data.pkl:
    Pairs: 149093
    Cell lines: 698
    Drugs: 228

  breast_cancer_data.pkl:
    Pairs: 9601
    Cell lines: 45
    Drugs: 228

  tnbc_data.pkl:
    Pairs: 9601
    Cell lines: 45
    Drugs: 228

  bl1_tnbc_data.pkl:
    Pairs: 4303
    Cell lines: 20
    Drugs: 228

BL1 TNBC CELL LINES:
  ACH-000111: HCC1187 (224 drugs)
  ACH-000117: EFM192A (223 drugs)
  ACH-000196: HCC1599 (182 drugs)
  ACH-000223: HCC1937 (225 drugs)
  ACH-000248: AU565 (224 drugs)
  ACH-000277: HCC1419 (220 drugs)
  ACH-000288: BT549 (226 drugs)
  ACH-000374: HCC1143 (224 drugs)
  ACH-000536: BT20 (224 drugs)
  ACH-000621: MDAMB157 (226 drugs)
  ACH-000624: HCC1806 (224 drugs)
  ACH-000668: HCC70 (226 drugs)
  ACH-000699: HCC1395 (158 drugs)
  ACH-000711: JIMT1 (225 drugs)
  ACH-000768: MDAMB231 (226 drugs)
  ACH-000849: MDAMB468 (225 drugs)
  ACH-000856: CAL

In [None]:
# SMILES to Graph Conversion

def get_atom_features(atom):
    """
    Extract atom features for GNN.
    
    Args:
        atom: RDKit atom object
        
    Returns:
        List of atom features
    """
    features = [
        atom.GetAtomicNum(),
        atom.GetDegree(),
        atom.GetFormalCharge(),
        atom.GetHybridization().real,
        atom.GetIsAromatic(),
        atom.GetTotalNumHs(),
        atom.GetNumRadicalElectrons(),
        atom.IsInRing(),
        atom.GetChiralTag().real,
    ]
    return features

def get_bond_features(bond):
    """
    Extract bond features for GNN.
    
    Args:
        bond: RDKit bond object
        
    Returns:
        List of bond features
    """
    features = [
        bond.GetBondTypeAsDouble(),
        bond.GetIsConjugated(),
        bond.IsInRing(),
        bond.GetStereo().real,
    ]
    return features

def smiles_to_graph(smiles_string):
    """
    Convert SMILES string to PyTorch Geometric graph.
    
    Args:
        smiles_string: SMILES representation of molecule
        
    Returns:
        torch_geometric.data.Data object or None if invalid
    """
    try:
        mol = Chem.MolFromSmiles(smiles_string)
        if mol is None:
            return None
        
        mol = Chem.AddHs(mol)
        
        # Extract node features
        node_features = []
        for atom in mol.GetAtoms():
            node_features.append(get_atom_features(atom))
        
        if len(node_features) == 0:
            return None
            
        x = torch.tensor(node_features, dtype=torch.float)
        
        # Extract edge indices and features
        edge_indices = []
        edge_features = []
        
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            
            edge_indices.append([i, j])
            edge_indices.append([j, i])
            
            bond_feats = get_bond_features(bond)
            edge_features.append(bond_feats)
            edge_features.append(bond_feats)
        
        if len(edge_indices) == 0:
            edge_indices = [[0, 0]]
            edge_features = [[0.0] * 4]
        
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_features, dtype=torch.float)
        
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        
    except Exception:
        return None

In [None]:
# Drug Encoder Architecture

class DrugEncoder(nn.Module):
    """
    Graph Neural Network encoder for molecular SMILES structures.
    Uses TransformerConv layers to generate 256-dim drug embeddings.
    """
    
    def __init__(self, node_feature_dim=9, edge_feature_dim=4, hidden_dim=256, dropout=0.3):
        """
        Initialize DrugEncoder.
        
        Args:
            node_feature_dim: Number of atom features
            edge_feature_dim: Number of bond features
            hidden_dim: Hidden layer dimension (output is 256)
            dropout: Dropout rate
        """
        super(DrugEncoder, self).__init__()
        
        # Device selection
        if torch.backends.mps.is_available():
            self.device = torch.device('mps')
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        
        # TransformerConv layers
        self.conv1 = TransformerConv(node_feature_dim, 128, heads=4, dropout=dropout, edge_dim=edge_feature_dim)
        self.bn1 = nn.BatchNorm1d(128 * 4)
        
        self.conv2 = TransformerConv(128 * 4, 256, heads=8, dropout=dropout, edge_dim=edge_feature_dim)
        self.bn2 = nn.BatchNorm1d(256 * 8)
        
        self.conv3 = TransformerConv(256 * 8, 256, heads=8, dropout=dropout, edge_dim=edge_feature_dim)
        self.bn3 = nn.BatchNorm1d(256 * 8)
        
        self.conv4 = TransformerConv(256 * 8, 256, heads=4, concat=False, dropout=dropout, edge_dim=edge_feature_dim)
        self.bn4 = nn.BatchNorm1d(256)
        
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        
        # Attention weights for pooling
        self.attention_weights = nn.Linear(256, 1)
        
        self.to(self.device)
    
    def forward(self, data):
        """
        Forward pass through drug encoder.
        
        Args:
            data: PyTorch Geometric Data/Batch object
            
        Returns:
            Drug embeddings of shape (batch_size, 256)
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        
        x = x.to(self.device, non_blocking=True)
        edge_index = edge_index.to(self.device, non_blocking=True)
        edge_attr = edge_attr.to(self.device, non_blocking=True)
        batch = batch.to(self.device, non_blocking=True)
        
        # Layer 1
        x = self.conv1(x, edge_index, edge_attr)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.dropout(x)
        
        # Layer 2
        x = self.conv2(x, edge_index, edge_attr)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.dropout(x)
        
        # Layer 3
        x = self.conv3(x, edge_index, edge_attr)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.dropout(x)
        
        # Layer 4
        x = self.conv4(x, edge_index, edge_attr)
        x = self.bn4(x)
        x = self.relu(x)
        
        # Attention-weighted pooling
        attention_scores = self.attention_weights(x)
        attention_scores = torch.softmax(attention_scores, dim=0)
        
        x_weighted = x * attention_scores
        x_mean = global_mean_pool(x_weighted, batch)
        x_max = global_max_pool(x, batch)
        
        embedding = (x_mean + x_max) / 2
        
        return embedding

In [None]:
# Preprocess and Cache Drug Graphs

def preprocess_drugs(drugs_df, drug_name_col, smiles_col, save_path):
    """
    Convert all drugs to graphs and cache.
    
    Args:
        drugs_df: DataFrame with drug names and SMILES
        drug_name_col: Column name for drug names
        smiles_col: Column name for SMILES strings
        save_path: Path to save cached graphs
        
    Returns:
        Dictionary mapping drug names to graph data
    """
    drug_graphs = {}
    failed_drugs = []
    
    for idx, row in tqdm(drugs_df.iterrows(), total=len(drugs_df), desc="Converting SMILES"):
        drug_name = row[drug_name_col]
        smiles = row[smiles_col]
        
        graph = smiles_to_graph(smiles)
        
        if graph is not None:
            drug_graphs[drug_name] = {
                'drug_name': drug_name,
                'smiles': smiles,
                'graph_data': graph,
                'node_dim': graph.x.shape[1],
                'edge_dim': graph.edge_attr.shape[1] if graph.edge_attr is not None else 0,
                'num_atoms': graph.x.shape[0],
                'num_bonds': graph.edge_index.shape[1]
            }
        else:
            failed_drugs.append(drug_name)
    
    with open(save_path, 'wb') as f:
        pickle.dump(drug_graphs, f)
    
    print(f"Converted {len(drug_graphs)}/{len(drugs_df)} drugs to graphs")
    
    return drug_graphs

# Preprocess all drugs
output_dir = Path("/Users/tjalling/Desktop/Dev./Capstone/Model_Notebooks")
drug_graphs_path = output_dir / "drug_graphs.pkl"

drug_graphs = preprocess_drugs(
    drugs_with_smiles,
    drugs_with_smiles.columns[0],
    'SMILES',
    drug_graphs_path
)

Converting SMILES:   0%|          | 0/228 [00:00<?, ?it/s]

Converting SMILES: 100%|██████████| 228/228 [00:00<00:00, 363.80it/s]

Converted 228/228 drugs to graphs





In [None]:
# Test DrugEncoder

aspirin_smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
aspirin_graph = smiles_to_graph(aspirin_smiles)

if aspirin_graph is not None:
    encoder = DrugEncoder()
    encoder.eval()
    
    aspirin_graph.batch = torch.zeros(aspirin_graph.x.shape[0], dtype=torch.long)
    
    with torch.no_grad():
        embedding = encoder(aspirin_graph)
    
    assert embedding.shape == (1, 256), f"Expected shape (1, 256), got {embedding.shape}"
    assert not torch.isnan(embedding).any(), "Output contains NaN values"
    
    print(f"DrugEncoder test passed: output shape {embedding.shape}")

DrugEncoder test passed: output shape torch.Size([1, 256])


In [None]:
# Cell Encoder Architecture

class CellEncoder(nn.Module):
    """
    Feedforward neural network to encode cell line gene expression.
    Uses skip connection to preserve direct gene signal.
    """
    
    def __init__(self, input_dim=3000, hidden_dim=512, output_dim=256, dropout1=0.4, dropout2=0.3):
        """
        Initialize CellEncoder.
        
        Args:
            input_dim: Number of gene expression features
            hidden_dim: First hidden layer dimension
            output_dim: Output embedding dimension
            dropout1: Dropout rate for first layer
            dropout2: Dropout rate for second layer
        """
        super(CellEncoder, self).__init__()
        
        # Device selection
        if torch.backends.mps.is_available():
            self.device = torch.device('mps')
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        
        # Main pathway
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout1 = nn.Dropout(dropout1)
        
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.bn2 = nn.BatchNorm1d(output_dim)
        self.dropout2 = nn.Dropout(dropout2)
        
        # Skip connection
        self.skip = nn.Linear(input_dim, output_dim)
        
        self.relu = nn.ReLU()
        
        self.to(self.device)
    
    def forward(self, x):
        """
        Forward pass through cell encoder.
        
        Args:
            x: Gene expression tensor of shape (batch_size, 3000)
            
        Returns:
            Cell embeddings of shape (batch_size, 256)
        """
        x = x.to(self.device, non_blocking=True)
        
        # Main pathway
        out = self.fc1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.dropout1(out)
        
        out = self.fc2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.dropout2(out)
        
        # Skip connection
        skip_out = self.skip(x)
        
        # Residual connection
        embedding = out + skip_out
        
        return embedding

In [None]:
# Input Normalization Utilities

def normalize_expression(expression_data, method='check'):
    """
    Check and normalize gene expression data if needed.
    
    Args:
        expression_data: Gene expression tensor or array (samples × genes)
        method: 'check' to analyze, 'zscore' to force normalization
        
    Returns:
        Normalized expression data
    """
    if isinstance(expression_data, pd.DataFrame):
        expression_data = expression_data.values
    
    if isinstance(expression_data, np.ndarray):
        expression_data = torch.tensor(expression_data, dtype=torch.float32)
    
    mean = expression_data.mean().item()
    std = expression_data.std().item()
    
    if method == 'check':
        if abs(mean) > 10 or std > 5:
            expression_data = (expression_data - expression_data.mean(dim=1, keepdim=True)) / (expression_data.std(dim=1, keepdim=True) + 1e-8)
            print(f"Applied row-wise normalization (mean={mean:.2f}, std={std:.2f} -> normalized)")
    
    elif method == 'zscore':
        expression_data = (expression_data - expression_data.mean(dim=1, keepdim=True)) / (expression_data.std(dim=1, keepdim=True) + 1e-8)
        print(f"Applied z-score normalization")
    
    return expression_data

In [None]:
# Test CellEncoder

np.random.seed(42)
sample_expression = np.random.normal(loc=2.5, scale=1.5, size=(10, 3000))
sample_expression = torch.tensor(sample_expression, dtype=torch.float32)

sample_expression = normalize_expression(sample_expression, method='check')

cell_encoder = CellEncoder(input_dim=3000, output_dim=256)
cell_encoder.eval()

with torch.no_grad():
    embeddings = cell_encoder(sample_expression)

assert embeddings.shape == (10, 256), f"Expected shape (10, 256), got {embeddings.shape}"
assert not torch.isnan(embeddings).any(), "Output contains NaN values"

print(f"CellEncoder test passed: output shape {embeddings.shape}")

CellEncoder test passed: output shape torch.Size([10, 256])


In [None]:
# Drug Response Prediction Model

class DrugResponseGNN(nn.Module):
    """
    Complete GNN model for drug response prediction with multi-task learning.
    Combines DrugEncoder and CellEncoder with three prediction heads.
    """
    
    def __init__(self, drug_node_dim=9, drug_edge_dim=4, cell_input_dim=3000, 
                 hidden_dim=256, dropout=0.3):
        """
        Initialize DrugResponseGNN.
        
        Args:
            drug_node_dim: Number of atom features
            drug_edge_dim: Number of bond features
            cell_input_dim: Number of gene expression features
            hidden_dim: Embedding dimension (256)
            dropout: Dropout rate
        """
        super(DrugResponseGNN, self).__init__()
        
        # Device selection
        if torch.backends.mps.is_available():
            self.device = torch.device('mps')
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        
        # Encoders
        self.drug_encoder = DrugEncoder(
            node_feature_dim=drug_node_dim,
            edge_feature_dim=drug_edge_dim,
            hidden_dim=hidden_dim,
            dropout=dropout
        )
        
        self.cell_encoder = CellEncoder(
            input_dim=cell_input_dim,
            hidden_dim=512,
            output_dim=hidden_dim,
            dropout1=0.4,
            dropout2=0.3
        )
        
        # Attention mechanism for integration
        self.attention_query = nn.Linear(hidden_dim, hidden_dim)
        self.attention_key = nn.Linear(hidden_dim, hidden_dim)
        self.attention_value = nn.Linear(hidden_dim, hidden_dim)
        
        # Combined embedding dimension
        combined_dim = hidden_dim * 2  # 512
        
        # Head A: IC50 Regression (primary task)
        self.ic50_head = nn.Sequential(
            nn.Linear(combined_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)
        )
        
        # Head B: Sensitivity Classification (auxiliary task)
        self.classification_head = nn.Sequential(
            nn.Linear(combined_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )
        
        # Head C: Similarity Reconstruction (auxiliary task)
        self.reconstruction_head = nn.Sequential(
            nn.Linear(combined_dim, 256),
            nn.ReLU(),
            nn.Linear(256, combined_dim)
        )
        
        self.to(self.device)
    
    def integrate_with_attention(self, drug_emb, cell_emb):
        """
        Integrate drug and cell embeddings using attention mechanism.
        
        Args:
            drug_emb: Drug embeddings (batch_size, 256)
            cell_emb: Cell embeddings (batch_size, 256)
            
        Returns:
            Combined embedding (batch_size, 512)
        """
        # Drug embedding as query
        query = self.attention_query(drug_emb)  # (batch, 256)
        key = self.attention_key(cell_emb)      # (batch, 256)
        value = self.attention_value(cell_emb)  # (batch, 256)
        
        # Compute attention scores
        attention_scores = torch.matmul(query.unsqueeze(1), key.unsqueeze(2))  # (batch, 1, 1)
        attention_scores = attention_scores / (drug_emb.size(-1) ** 0.5)
        attention_weights = torch.softmax(attention_scores, dim=-1)
        
        # Apply attention to cell embedding
        attended_cell = attention_weights.squeeze(-1) * value  # (batch, 1) * (batch, 256) -> (batch, 256)
        
        # Concatenate drug and attended cell embeddings
        combined = torch.cat([drug_emb, attended_cell], dim=1)  # (batch, 512)
        
        return combined
    
    def forward(self, drug_batch, cell_batch, return_embeddings=False):
        """
        Forward pass through complete model.
        
        Args:
            drug_batch: PyTorch Geometric batch of molecular graphs
            cell_batch: Gene expression tensor (batch_size, 3000)
            return_embeddings: Whether to return intermediate embeddings
            
        Returns:
            Dictionary with predictions and optionally embeddings
        """
        # Encode drug and cell
        drug_emb = self.drug_encoder(drug_batch)  # (batch_size, 256)
        cell_emb = self.cell_encoder(cell_batch)   # (batch_size, 256)
        
        # Integrate with attention
        combined = self.integrate_with_attention(drug_emb, cell_emb)  # (batch_size, 512)
        
        # Multi-task predictions
        ic50_pred = self.ic50_head(combined)
        class_pred = self.classification_head(combined)
        recon_pred = self.reconstruction_head(combined)
        
        results = {
            'ic50': ic50_pred,
            'classification': class_pred,
            'reconstruction': recon_pred
        }
        
        if return_embeddings:
            results['embeddings'] = {
                'drug': drug_emb,
                'cell': cell_emb,
                'combined': combined
            }
        
        return results

In [None]:
# Training Utilities

def compute_loss(predictions, targets, median_ic50, loss_weights=(1.0, 0.5, 0.3)):
    """
    Compute multi-task loss.
    
    Args:
        predictions: Dictionary with 'ic50', 'classification', 'reconstruction'
        targets: Dictionary with 'ic50', 'embeddings'
        median_ic50: Median IC50 value for classification
        loss_weights: Tuple of (ic50_weight, class_weight, recon_weight)
        
    Returns:
        total_loss: Combined loss
        loss_dict: Individual losses
    """
    # IC50 regression loss (MSE)
    ic50_loss = nn.functional.mse_loss(predictions['ic50'], targets['ic50'])
    
    # Classification loss (BCE)
    class_labels = (targets['ic50'] < median_ic50).float()
    class_loss = nn.functional.binary_cross_entropy_with_logits(
        predictions['classification'], 
        class_labels
    )
    
    # Reconstruction loss (MSE)
    target_embeddings = targets['embeddings']
    recon_loss = nn.functional.mse_loss(predictions['reconstruction'], target_embeddings)
    
    # Weighted combination
    total_loss = (loss_weights[0] * ic50_loss + 
                  loss_weights[1] * class_loss + 
                  loss_weights[2] * recon_loss)
    
    loss_dict = {
        'total': total_loss.item(),
        'ic50': ic50_loss.item(),
        'classification': class_loss.item(),
        'reconstruction': recon_loss.item()
    }
    
    return total_loss, loss_dict

def compute_metrics(y_true, y_pred):
    """
    Compute evaluation metrics.
    
    Args:
        y_true: True IC50 values (numpy array or tensor)
        y_pred: Predicted IC50 values (numpy array or tensor)
        
    Returns:
        Dictionary of metrics
    """
    # Handle both tensor and numpy inputs
    if hasattr(y_true, 'cpu'):
        y_true = y_true.cpu().numpy()
    if hasattr(y_pred, 'cpu'):
        y_pred = y_pred.cpu().numpy()
    
    y_true = np.asarray(y_true).flatten()
    y_pred = np.asarray(y_pred).flatten()
    
    pearson_corr, _ = pearsonr(y_true, y_pred)
    spearman_corr, _ = spearmanr(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)
    
    return {
        'pearson': pearson_corr,
        'spearman': spearman_corr,
        'r2': r2,
        'rmse': rmse,
        'mae': mae
    }

def print_model_summary(model):
    """
    Print model architecture summary.
    
    Args:
        model: DrugResponseGNN instance
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print("MODEL SUMMARY")
    print(f"Device: {model.device}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Memory footprint: ~{total_params * 4 / 1024**2:.2f} MB")
    print("\nArchitecture:")
    print("  DrugEncoder: SMILES → 256-dim")
    print("  CellEncoder: 3000 genes → 256-dim")
    print("  Attention Integration: 256 + 256 → 512-dim")
    print("  IC50 Head: 512 → 256 → 128 → 1")
    print("  Classification Head: 512 → 256 → 1")
    print("  Reconstruction Head: 512 → 256 → 512")

In [None]:
# Test Complete Pipeline

# Create model
model = DrugResponseGNN()
model.eval()

print(f"Model initialized on {model.device}\n")

# Test data
batch_size = 4

# Drug data: Create batch of aspirin graphs
aspirin_smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
test_graphs = []
for _ in range(batch_size):
    graph = smiles_to_graph(aspirin_smiles)
    if graph is not None:
        test_graphs.append(graph)

drug_batch = Batch.from_data_list(test_graphs)

# Cell data: Random expression
np.random.seed(42)
cell_data = np.random.normal(loc=2.5, scale=1.5, size=(batch_size, 3000))
cell_batch = torch.tensor(cell_data, dtype=torch.float32)

# Forward pass
with torch.no_grad():
    outputs = model(drug_batch, cell_batch, return_embeddings=True)

# Verify outputs
print("Output shapes:")
print(f"  IC50: {outputs['ic50'].shape}")
print(f"  Classification: {outputs['classification'].shape}")
print(f"  Reconstruction: {outputs['reconstruction'].shape}")
print(f"  Drug embedding: {outputs['embeddings']['drug'].shape}")
print(f"  Cell embedding: {outputs['embeddings']['cell'].shape}")
print(f"  Combined embedding: {outputs['embeddings']['combined'].shape}")

# Check for NaNs
has_nan = any(torch.isnan(outputs[key]).any() for key in ['ic50', 'classification', 'reconstruction'])
assert not has_nan, "Outputs contain NaN values"

# Verify shapes
assert outputs['ic50'].shape == (batch_size, 1), f"IC50 shape mismatch"
assert outputs['classification'].shape == (batch_size, 1), f"Classification shape mismatch"
assert outputs['reconstruction'].shape == (batch_size, 512), f"Reconstruction shape mismatch"

print("\nAll tests passed!")
print(f"IC50 predictions: {outputs['ic50'].squeeze().tolist()}")

# Model summary
print("\n" + "="*50)
print_model_summary(model)

Model initialized on mps

Output shapes:
  IC50: torch.Size([4, 1])
  Classification: torch.Size([4, 1])
  Reconstruction: torch.Size([4, 512])
  Drug embedding: torch.Size([4, 256])
  Cell embedding: torch.Size([4, 256])
  Combined embedding: torch.Size([4, 512])

All tests passed!
IC50 predictions: [0.04182592034339905, 0.007889561355113983, 0.05483219027519226, 0.030792269855737686]

MODEL SUMMARY
Device: mps
Total parameters: 31,053,827
Trainable parameters: 31,053,827
Memory footprint: ~118.46 MB

Architecture:
  DrugEncoder: SMILES → 256-dim
  CellEncoder: 3000 genes → 256-dim
  Attention Integration: 256 + 256 → 512-dim
  IC50 Head: 512 → 256 → 128 → 1
  Classification Head: 512 → 256 → 1
  Reconstruction Head: 512 → 256 → 512


In [None]:
# Test Complete Pipeline

# Create model
model = DrugResponseGNN()
model.eval()

print(f"Model initialized on {model.device}\n")

# Test data
batch_size = 4

# Drug data: Create batch of aspirin graphs
aspirin_smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"
test_graphs = []
for _ in range(batch_size):
    graph = smiles_to_graph(aspirin_smiles)
    if graph is not None:
        test_graphs.append(graph)

drug_batch = Batch.from_data_list(test_graphs)

# Cell data: Random expression
np.random.seed(42)
cell_data = np.random.normal(loc=2.5, scale=1.5, size=(batch_size, 3000))
cell_batch = torch.tensor(cell_data, dtype=torch.float32)

# Forward pass
with torch.no_grad():
    outputs = model(drug_batch, cell_batch, return_embeddings=True)

# Verify outputs
print("Output shapes:")
print(f"  IC50: {outputs['ic50'].shape}")
print(f"  Classification: {outputs['classification'].shape}")
print(f"  Reconstruction: {outputs['reconstruction'].shape}")
print(f"  Drug embedding: {outputs['embeddings']['drug'].shape}")
print(f"  Cell embedding: {outputs['embeddings']['cell'].shape}")
print(f"  Combined embedding: {outputs['embeddings']['combined'].shape}")

# Check for NaNs
has_nan = any(torch.isnan(outputs[key]).any() for key in ['ic50', 'classification', 'reconstruction'])
assert not has_nan, "Outputs contain NaN values"

# Verify shapes
assert outputs['ic50'].shape == (batch_size, 1), f"IC50 shape mismatch"
assert outputs['classification'].shape == (batch_size, 1), f"Classification shape mismatch"
assert outputs['reconstruction'].shape == (batch_size, 512), f"Reconstruction shape mismatch"

print("\nAll tests passed!")
print(f"IC50 predictions: {outputs['ic50'].squeeze().tolist()}")

# Model summary
print("\n" + "="*50)
print_model_summary(model)

Model initialized on mps

Output shapes:
  IC50: torch.Size([4, 1])
  Classification: torch.Size([4, 1])
  Reconstruction: torch.Size([4, 512])
  Drug embedding: torch.Size([4, 256])
  Cell embedding: torch.Size([4, 256])
  Combined embedding: torch.Size([4, 512])

All tests passed!
IC50 predictions: [-0.0026915408670902252, -0.057257648557424545, -0.028083961457014084, 0.014925692230463028]

MODEL SUMMARY
Device: mps
Total parameters: 31,053,827
Trainable parameters: 31,053,827
Memory footprint: ~118.46 MB

Architecture:
  DrugEncoder: SMILES → 256-dim
  CellEncoder: 3000 genes → 256-dim
  Attention Integration: 256 + 256 → 512-dim
  IC50 Head: 512 → 256 → 128 → 1
  Classification Head: 512 → 256 → 1
  Reconstruction Head: 512 → 256 → 512


In [None]:
# Three-Phase Transfer Learning Pipeline

# Setup directories
output_dir = Path("/Users/tjalling/Desktop/Dev./Capstone/Model_Notebooks")
models_dir = output_dir / "models"
results_dir = output_dir / "results"
prebatched_dir = output_dir / "prebatched_data"
models_dir.mkdir(exist_ok=True)
results_dir.mkdir(exist_ok=True)
prebatched_dir.mkdir(exist_ok=True)


class DrugResponseDataset(Dataset):
    """Dataset for drug-cell response prediction."""
    
    def __init__(self, dataframe, drug_graphs_dict, drug_col='DRUG_NAME'):
        """
        Initialize dataset.
        
        Args:
            dataframe: DataFrame with drug-cell pairs
            drug_graphs_dict: Dictionary mapping drug names to graph objects
            drug_col: Column name for drug identifier
        """
        self.data = dataframe.reset_index(drop=True)
        self.drug_graphs = drug_graphs_dict
        self.drug_col = drug_col
        
        # Extract gene expression columns
        self.gene_cols = [c for c in self.data.columns if '(' in c and ')' in c]
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        
        # Get drug graph from dictionary
        drug_name = row[self.drug_col]
        drug_entry = self.drug_graphs[drug_name]
        
        # Extract actual graph Data object
        if isinstance(drug_entry, dict) and 'graph_data' in drug_entry:
            drug_graph = drug_entry['graph_data']
        elif isinstance(drug_entry, Data):
            drug_graph = drug_entry
        else:
            raise ValueError(f"Unknown drug graph format for {drug_name}")
        
        # Get cell expression
        cell_expr = torch.tensor(
            row[self.gene_cols].values.astype(np.float32),
            dtype=torch.float32
        )
        
        # Get target
        ic50 = torch.tensor([row['LN_IC50']], dtype=torch.float32)
        
        return {
            'drug_graph': drug_graph,
            'cell_expr': cell_expr,
            'ic50': ic50,
            'drug_name': drug_name,
            'cell_id': row['ModelID']
        }


class PreBatchedDataset(Dataset):
    """Dataset that loads pre-computed batches directly."""
    
    def __init__(self, batch_file):
        """
        Initialize pre-batched dataset.
        
        Args:
            batch_file: Path to file containing pre-batched data
        """
        self.batches = torch.load(batch_file, weights_only=False)
        
    def __len__(self):
        return len(self.batches)
    
    def __getitem__(self, idx):
        return self.batches[idx]


def collate_fn(batch):
    """Custom collate function for drug graphs."""
    # Extract components from batch
    drug_graphs = [item['drug_graph'] for item in batch]
    cell_exprs = torch.stack([item['cell_expr'] for item in batch])
    ic50s = torch.stack([item['ic50'] for item in batch])
    
    # Batch drug graphs properly
    drug_batch = Batch.from_data_list(drug_graphs)
    
    return {
        'drug_batch': drug_batch,
        'cell_batch': cell_exprs,
        'ic50': ic50s
    }


def prebatched_collate_fn(batch):
    """Collate function for pre-batched data - just returns the batch as-is."""
    # batch is a list with a single pre-computed batch dictionary
    return batch[0]


def create_prebatched_data(dataset, batch_size, split_name, phase_name, shuffle=False):
    """
    Pre-compute and save batched data to disk.
    
    Args:
        dataset: DrugResponseDataset
        batch_size: Batch size
        split_name: Name of split (train/val/test)
        phase_name: Name of training phase (phase1/phase2/phase3)
        shuffle: Whether to shuffle before batching
        
    Returns:
        Path to saved batch file
    """
    indices = np.arange(len(dataset))
    if shuffle:
        np.random.seed(42)
        np.random.shuffle(indices)
    
    batches = []
    
    # Create batches
    for i in tqdm(range(0, len(indices), batch_size), desc=f"Pre-batching {split_name}"):
        batch_indices = indices[i:i+batch_size]
        
        # Get items for this batch
        batch_items = [dataset[idx] for idx in batch_indices]
        
        # Extract and batch components
        drug_graphs = [item['drug_graph'] for item in batch_items]
        cell_exprs = torch.stack([item['cell_expr'] for item in batch_items])
        ic50s = torch.stack([item['ic50'] for item in batch_items])
        
        # Batch drug graphs (the slow operation we're pre-computing)
        drug_batch = Batch.from_data_list(drug_graphs)
        
        # Store pre-computed batch
        batches.append({
            'drug_batch': drug_batch,
            'cell_batch': cell_exprs,
            'ic50': ic50s
        })
    
    # Save to disk
    batch_file = prebatched_dir / f"{phase_name}_{split_name}_batches.pt"
    torch.save(batches, batch_file)
    print(f"Saved {len(batches)} batches to {batch_file}")
    
    return batch_file


def create_dataloaders(dataset, batch_size, stratify_col=None, use_prebatched=False, phase_name=None):
    """
    Create train/val/test splits and dataloaders.
    
    Args:
        dataset: DrugResponseDataset
        batch_size: Batch size for training
        stratify_col: Column name for stratification
        use_prebatched: Whether to use pre-batched data
        phase_name: Name of phase for pre-batched data (phase1/phase2/phase3)
        
    Returns:
        Dictionary with train/val/test loaders (and datasets if not using prebatched)
    """
    if use_prebatched:
        # Load pre-batched data directly
        train_file = prebatched_dir / f"{phase_name}_train_batches.pt"
        val_file = prebatched_dir / f"{phase_name}_val_batches.pt"
        test_file = prebatched_dir / f"{phase_name}_test_batches.pt"
        
        train_dataset = PreBatchedDataset(train_file)
        val_dataset = PreBatchedDataset(val_file)
        test_dataset = PreBatchedDataset(test_file)
        
        # DataLoader with prebatched data (batch_size=1 since each item is already a batch)
        # num_workers=0 because data is already in memory and workers can't pickle notebook classes
        train_loader = DataLoader(
            train_dataset,
            batch_size=1,
            shuffle=True,
            collate_fn=prebatched_collate_fn,
            num_workers=0,
            pin_memory=False
        )
        
        val_loader = DataLoader(
            val_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=prebatched_collate_fn,
            num_workers=0,
            pin_memory=False
        )
        
        test_loader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=prebatched_collate_fn,
            num_workers=0,
            pin_memory=False
        )
        
        print(f"Loaded pre-batched data: Train: {len(train_dataset)} batches | Val: {len(val_dataset)} batches | Test: {len(test_dataset)} batches")
        
        return {
            'train': train_loader,
            'val': val_loader,
            'test': test_loader
        }
    
    # Original approach: create datasets on-the-fly
    data = dataset.data
    
    # Try stratification, fall back to random if it fails
    stratify = None
    if stratify_col and stratify_col in data.columns:
        # Check if stratification is feasible (need at least 2 samples per class)
        value_counts = data[stratify_col].value_counts()
        if value_counts.min() >= 2:
            stratify = data[stratify_col]
        else:
            print(f"Warning: Cannot stratify by {stratify_col}, falling back to random split")
    
    # Split: 70% train, 15% val, 15% test
    try:
        train_idx, temp_idx = train_test_split(
            np.arange(len(data)),
            test_size=0.3,
            stratify=stratify,
            random_state=42
        )
        
        val_idx, test_idx = train_test_split(
            temp_idx,
            test_size=0.5,
            stratify=stratify.iloc[temp_idx] if stratify is not None else None,
            random_state=42
        )
    except ValueError:
        # Stratification failed, use random split
        print("Stratification failed, using random split")
        train_idx, temp_idx = train_test_split(
            np.arange(len(data)),
            test_size=0.3,
            random_state=42
        )
        val_idx, test_idx = train_test_split(
            temp_idx,
            test_size=0.5,
            random_state=42
        )
    
    # Create subset datasets
    train_data = data.iloc[train_idx]
    val_data = data.iloc[val_idx]
    test_data = data.iloc[test_idx]
    
    train_dataset = DrugResponseDataset(train_data, dataset.drug_graphs, dataset.drug_col)
    val_dataset = DrugResponseDataset(val_data, dataset.drug_graphs, dataset.drug_col)
    test_dataset = DrugResponseDataset(test_data, dataset.drug_graphs, dataset.drug_col)
    
    # If phase_name is provided, create pre-batched data
    if phase_name:
        print(f"\nCreating pre-batched data for {phase_name}...")
        create_prebatched_data(train_dataset, batch_size, 'train', phase_name, shuffle=False)
        create_prebatched_data(val_dataset, batch_size, 'val', phase_name, shuffle=False)
        create_prebatched_data(test_dataset, batch_size, 'test', phase_name, shuffle=False)
        print(f"Pre-batching complete! Rerun with use_prebatched=True to use cached data.\n")
    
    # Create loaders - use standard DataLoader with custom collate
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=8,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=8,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=8,
        pin_memory=True
    )
    
    print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")
    
    return {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader,
        'datasets': {
            'train': train_dataset,
            'val': val_dataset,
            'test': test_dataset
        }
    }


def evaluate_model(model, dataloader, device, median_ic50=None):
    """
    Evaluate model on a dataset.
    
    Args:
        model: DrugResponseGNN model
        dataloader: DataLoader
        device: torch device
        median_ic50: Median IC50 for classification threshold
        
    Returns:
        Dictionary of metrics and predictions
    """
    model.eval()
    
    all_preds = []
    all_targets = []
    total_loss = 0
    
    with torch.no_grad():
        for batch in dataloader:
            # Move to device
            drug_batch = batch['drug_batch'].to(device)
            cell_batch = batch['cell_batch'].to(device)
            ic50_target = batch['ic50'].to(device)
            
            # Forward pass
            outputs = model(drug_batch, cell_batch)
            
            # Compute validation loss
            if median_ic50 is not None:
                outputs_with_emb = model(drug_batch, cell_batch, return_embeddings=True)
                targets = {
                    'ic50': ic50_target,
                    'embeddings': outputs_with_emb['embeddings']['combined'].detach()
                }
                loss, _ = compute_loss(outputs_with_emb, targets, median_ic50)
                total_loss += loss.item()
            
            # Collect predictions
            all_preds.append(outputs['ic50'].cpu())
            all_targets.append(ic50_target.cpu())
    
    all_preds = torch.cat(all_preds).numpy().flatten()
    all_targets = torch.cat(all_targets).numpy().flatten()
    
    # Compute metrics
    metrics = compute_metrics(all_targets, all_preds)
    metrics['loss'] = total_loss / len(dataloader)
    
    return metrics, all_preds, all_targets


def train_epoch(model, train_loader, optimizer, device, median_ic50, max_grad_norm=1.0):
    """
    Train for one epoch with multi-task learning.
    """
    model.train()
    total_loss = 0
    
    pbar = tqdm(train_loader, desc="Training", leave=False)
    for batch in pbar:
        # Move to device
        drug_batch = batch['drug_batch'].to(device)
        cell_batch = batch['cell_batch'].to(device)
        ic50_target = batch['ic50'].to(device)
        
        # Forward pass with embeddings (compute once, not twice)
        outputs = model(drug_batch, cell_batch, return_embeddings=True)
        
        # Get embeddings for reconstruction target
        combined_emb = outputs['embeddings']['combined']
        
        targets = {
            'ic50': ic50_target,
            'embeddings': combined_emb.detach()
        }
        
        loss, loss_dict = compute_loss(outputs, targets, median_ic50)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / len(train_loader)


def train_phase(
    model,
    train_loader,
    val_loader,
    device,
    phase_name,
    num_epochs,
    lr,
    weight_decay,
    patience,
    scheduler_patience,
    checkpoint_path
):
    """
    Train one phase of the pipeline.
    
    Args:
        model: DrugResponseGNN
        train_loader: Training DataLoader
        val_loader: Validation DataLoader
        device: torch device
        phase_name: Name of phase for logging
        num_epochs: Number of epochs to train
        lr: Learning rate
        weight_decay: Weight decay
        patience: Early stopping patience
        scheduler_patience: LR scheduler patience
        checkpoint_path: Path to save best model
        
    Returns:
        Trained model and training history
    """
    # Setup optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999))
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.5,
        patience=scheduler_patience,
        min_lr=1e-6
    )
    
    # Compute median IC50 from training data
    all_ic50 = []
    for batch in train_loader:
        all_ic50.append(batch['ic50'])
    median_ic50 = torch.cat(all_ic50).median().item()
    
    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_r2': [],
        'val_pearson': [],
        'val_spearman': []
    }
    
    best_val_r2 = -np.inf
    epochs_without_improvement = 0
    
    print(f"\nStarting {phase_name}")
    print(f"Learning rate: {lr}, Weight decay: {weight_decay}")
    print(f"Median IC50: {median_ic50:.3f}\n")
    
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        
        # Train
        train_loss = train_epoch(model, train_loader, optimizer, device, median_ic50)
        
        # Validate
        val_metrics, _, _ = evaluate_model(model, val_loader, device, median_ic50)
        
        # Update history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_metrics['loss'])
        history['val_r2'].append(val_metrics['r2'])
        history['val_pearson'].append(val_metrics['pearson'])
        history['val_spearman'].append(val_metrics['spearman'])
        
        # Print metrics
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_metrics['loss']:.4f}")
        print(f"Val R²: {val_metrics['r2']:.4f} | Pearson: {val_metrics['pearson']:.4f} | "
              f"Spearman: {val_metrics['spearman']:.4f} | RMSE: {val_metrics['rmse']:.4f}")
        
        # Learning rate scheduling
        scheduler.step(val_metrics['r2'])
        
        # Check for improvement
        if val_metrics['r2'] > best_val_r2:
            best_val_r2 = val_metrics['r2']
            epochs_without_improvement = 0
            
            # Save checkpoint
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_r2': best_val_r2,
                'val_metrics': val_metrics,
                'history': history
            }, checkpoint_path)
            
            print(f"✓ Saved best model (R² = {best_val_r2:.4f})")
        else:
            epochs_without_improvement += 1
            print(f"No improvement for {epochs_without_improvement} epoch(s)")
        
        # Early stopping
        if epochs_without_improvement >= patience:
            print(f"\nEarly stopping after {epoch+1} epochs")
            break
        
        print()
    
    # Load best model
    checkpoint = torch.load(checkpoint_path, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    print(f"Completed {phase_name}")
    print(f"Best validation R²: {checkpoint['val_r2']:.4f}\n")
    
    return model, history


def plot_training_history(history, phase_name, save_path):
    """Plot training history."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    # Loss
    axes[0].plot(history['train_loss'], label='Train')
    axes[0].plot(history['val_loss'], label='Val')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title(f'{phase_name} - Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # R²
    axes[1].plot(history['val_r2'])
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('R²')
    axes[1].set_title(f'{phase_name} - Validation R²')
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()


def save_predictions(preds, targets, phase_name, save_path):
    """Save predictions and create scatter plot."""
    # Save CSV
    df = pd.DataFrame({
        'predicted': preds,
        'actual': targets
    })
    csv_path = str(save_path).replace('.png', '.csv')
    df.to_csv(csv_path, index=False)
    
    # Create scatter plot
    plt.figure(figsize=(8, 8))
    plt.scatter(targets, preds, alpha=0.3, s=10)
    
    # Perfect prediction line
    min_val = min(targets.min(), preds.min())
    max_val = max(targets.max(), preds.max())
    plt.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label='Perfect prediction')
    
    # Calculate metrics
    from scipy.stats import pearsonr
    r, _ = pearsonr(targets, preds)
    r2 = r2_score(targets, preds)
    
    plt.xlabel('Actual LN_IC50')
    plt.ylabel('Predicted LN_IC50')
    plt.title(f'{phase_name} - Test Set\nR² = {r2:.3f}, Pearson r = {r:.3f}')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()

print("Training pipeline functions defined")

Training pipeline functions defined


In [None]:
# Keep Mac Awake During Training

def keep_awake_wrapper(func):
    """
    Wrapper to keep Mac awake during long-running operations.
    Uses caffeinate to prevent system sleep.
    """
    def wrapper(*args, **kwargs):
        print("Starting caffeinate to prevent system sleep...")
        print("Your Mac will stay awake until training completes.\n")
        
        # Start caffeinate process
        # -i: Prevent idle sleep
        # -d: Prevent display sleep
        # -m: Prevent disk sleep
        caffeinate_process = subprocess.Popen(
            ['caffeinate', '-dims'],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL
        )
        
        try:
            # Run the actual training
            result = func(*args, **kwargs)
            return result
        finally:
            # Stop caffeinate when done
            caffeinate_process.terminate()
            caffeinate_process.wait()
            print("\n✓ Caffeinate stopped. System can now sleep normally.")
    
    return wrapper

print("Caffeinate wrapper ready - your Mac will stay awake during training!")


# =============================================================================
# PRE-BATCHING SETUP
# =============================================================================

def create_all_prebatched_data():
    """
    Create pre-batched data for all phases.
    Run this once, then use use_prebatched=True in training.
    """
    print("="*70)
    print("CREATING PRE-BATCHED DATA FOR ALL PHASES")
    print("="*70)
    print("This will take a few minutes but only needs to be done once.\n")
    
    drug_col = drug_name_col
    
    # Phase 1: Pan-cancer
    print("\n" + "="*70)
    print("PHASE 1: PAN-CANCER")
    print("="*70)
    pan_dataset = DrugResponseDataset(pan_cancer_data, drug_graphs, drug_col=drug_col)
    phase1_loaders = create_dataloaders(
        pan_dataset, 
        batch_size=256, 
        stratify_col='OncotreePrimaryDisease' if 'OncotreePrimaryDisease' in pan_cancer_data.columns else None,
        use_prebatched=False,
        phase_name='phase1'
    )
    
    # Phase 2: Breast/TNBC
    print("\n" + "="*70)
    print("PHASE 2: BREAST/TNBC")
    print("="*70)
    breast_dataset = DrugResponseDataset(breast_cancer_data, drug_graphs, drug_col=drug_col)
    phase2_loaders = create_dataloaders(
        breast_dataset,
        batch_size=128,
        stratify_col='ModelID',
        use_prebatched=False,
        phase_name='phase2'
    )
    
    # Phase 3: BL1 TNBC
    print("\n" + "="*70)
    print("PHASE 3: BL1 TNBC")
    print("="*70)
    bl1_dataset = DrugResponseDataset(bl1_tnbc_data, drug_graphs, drug_col=drug_col)
    phase3_loaders = create_dataloaders(
        bl1_dataset,
        batch_size=128,
        stratify_col='ModelID',
        use_prebatched=False,
        phase_name='phase3'
    )
    
    print("\n" + "="*70)
    print("✓ PRE-BATCHING COMPLETE!")
    print("="*70)
    print("Pre-batched data saved to:", prebatched_dir)
    print("\nNow run the training pipeline with use_prebatched=True for much faster training!")

Caffeinate wrapper ready - your Mac will stay awake during training!


In [None]:
# Execute Three-Phase Training Pipeline (with caffeinate to prevent sleep)

@keep_awake_wrapper
def run_full_training_pipeline(use_prebatched=False):
    """
    Run complete 3-phase training pipeline.
    Mac will stay awake during entire process.
    
    Args:
        use_prebatched: If True, use pre-batched data (much faster). 
                       If False, batch data on-the-fly (slower).
    """
    print(f"Loading datasets... (use_prebatched={use_prebatched})")
    
    # Get the drug column name from data
    drug_col = drug_name_col  # Already defined in earlier cells
    
    print(f"Pan-cancer: {len(pan_cancer_data)} pairs")
    print(f"Breast/TNBC: {len(breast_cancer_data)} pairs")
    print(f"BL1 TNBC: {len(bl1_tnbc_data)} pairs")
    print(f"Drug graphs: {len(drug_graphs)} drugs\n")
    
    # Device
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    print(f"Using device: {device}\n")

    # =============================================================================
    # PHASE 1: PAN-CANCER PRE-TRAINING
    # =============================================================================
    print("="*70)
    print("PHASE 1: PAN-CANCER PRE-TRAINING")
    print("="*70)
    
    if use_prebatched:
        # Use pre-batched data
        phase1_loaders = create_dataloaders(
            None, 
            batch_size=256, 
            use_prebatched=True, 
            phase_name='phase1'
        )
        # Get n_genes from first batch
        first_batch = next(iter(phase1_loaders['train']))
        n_genes = first_batch['cell_batch'].shape[1]
    else:
        # Create dataset and loaders on-the-fly
        pan_dataset = DrugResponseDataset(pan_cancer_data, drug_graphs, drug_col=drug_col)
        n_genes = len(pan_dataset.gene_cols)
        
        stratify_col = 'OncotreePrimaryDisease' if 'OncotreePrimaryDisease' in pan_cancer_data.columns else 'ModelID'
        phase1_loaders = create_dataloaders(
            pan_dataset, 
            batch_size=256, 
            stratify_col=stratify_col,
            use_prebatched=False
        )
    
    print(f"Gene features: {n_genes}")
    
    # Initialize model with actual gene count
    model = DrugResponseGNN(cell_input_dim=n_genes)
    model = model.to(device)
    
    print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Train Phase 1
    model, phase1_history = train_phase(
        model=model,
        train_loader=phase1_loaders['train'],
        val_loader=phase1_loaders['val'],
        device=device,
        phase_name="Phase 1: Pan-Cancer",
        num_epochs=50,
        lr=1e-3,
        weight_decay=1e-5,
        patience=20,
        scheduler_patience=10,
        checkpoint_path=models_dir / "phase1_best.pt"
    )
    
    # Evaluate on test set
    print("Evaluating Phase 1 on test set...")
    phase1_test_metrics, phase1_preds, phase1_targets = evaluate_model(
        model, phase1_loaders['test'], device
    )
    
    print(f"Phase 1 Test Metrics:")
    print(f"  R²: {phase1_test_metrics['r2']:.4f}")
    print(f"  Pearson: {phase1_test_metrics['pearson']:.4f}")
    print(f"  Spearman: {phase1_test_metrics['spearman']:.4f}")
    print(f"  RMSE: {phase1_test_metrics['rmse']:.4f}")
    print(f"  MAE: {phase1_test_metrics['mae']:.4f}\n")
    
    # Save results
    plot_training_history(phase1_history, "Phase 1", results_dir / "phase1_history.png")
    save_predictions(phase1_preds, phase1_targets, "Phase 1", results_dir / "phase1_predictions.png")

    # =============================================================================
    # PHASE 2: BREAST/TNBC FINE-TUNING
    # =============================================================================
    print("="*70)
    print("PHASE 2: BREAST/TNBC FINE-TUNING")
    print("="*70)
    
    # Load best Phase 1 model
    checkpoint = torch.load(models_dir / "phase1_best.pt", weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Loaded Phase 1 best weights\n")
    
    # Create dataset and loaders
    if use_prebatched:
        phase2_loaders = create_dataloaders(
            None,
            batch_size=128,
            use_prebatched=True,
            phase_name='phase2'
        )
    else:
        breast_dataset = DrugResponseDataset(breast_cancer_data, drug_graphs, drug_col=drug_col)
        phase2_loaders = create_dataloaders(
            breast_dataset,
            batch_size=128,
            stratify_col='ModelID',
            use_prebatched=False
        )
    
    # Train Phase 2
    model, phase2_history = train_phase(
        model=model,
        train_loader=phase2_loaders['train'],
        val_loader=phase2_loaders['val'],
        device=device,
        phase_name="Phase 2: Breast/TNBC",
        num_epochs=20,
        lr=1e-4,  # 10x lower
        weight_decay=1e-5,
        patience=15,
        scheduler_patience=5,
        checkpoint_path=models_dir / "phase2_best.pt"
    )
    
    # Evaluate on test set
    print("Evaluating Phase 2 on test set...")
    phase2_test_metrics, phase2_preds, phase2_targets = evaluate_model(
        model, phase2_loaders['test'], device
    )
    
    print(f"Phase 2 Test Metrics:")
    print(f"  R²: {phase2_test_metrics['r2']:.4f}")
    print(f"  Pearson: {phase2_test_metrics['pearson']:.4f}")
    print(f"  Spearman: {phase2_test_metrics['spearman']:.4f}")
    print(f"  RMSE: {phase2_test_metrics['rmse']:.4f}")
    print(f"  MAE: {phase2_test_metrics['mae']:.4f}\n")
    
    # Save results
    plot_training_history(phase2_history, "Phase 2", results_dir / "phase2_history.png")
    save_predictions(phase2_preds, phase2_targets, "Phase 2", results_dir / "phase2_predictions.png")

    # =============================================================================
    # PHASE 3: BL1 FINE-TUNING
    # =============================================================================
    print("="*70)
    print("PHASE 3: BL1 TNBC FINE-TUNING")
    print("="*70)
    
    # Load best Phase 2 model
    checkpoint = torch.load(models_dir / "phase2_best.pt", weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Loaded Phase 2 best weights\n")
    
    # Create dataset and loaders
    if use_prebatched:
        phase3_loaders = create_dataloaders(
            None,
            batch_size=128,
            use_prebatched=True,
            phase_name='phase3'
        )
    else:
        bl1_dataset = DrugResponseDataset(bl1_tnbc_data, drug_graphs, drug_col=drug_col)
        phase3_loaders = create_dataloaders(
            bl1_dataset,
            batch_size=128,
            stratify_col='ModelID',
            use_prebatched=False
        )
    
    # Train Phase 3
    model, phase3_history = train_phase(
        model=model,
        train_loader=phase3_loaders['train'],
        val_loader=phase3_loaders['val'],
        device=device,
        phase_name="Phase 3: BL1",
        num_epochs=20,
        lr=5e-5,  # 20x lower than Phase 1
        weight_decay=1e-5,
        patience=15,
        scheduler_patience=5,
        checkpoint_path=models_dir / "phase3_bl1_final.pt"
    )
    
    # Evaluate on test set
    print("Evaluating Phase 3 on test set...")
    phase3_test_metrics, phase3_preds, phase3_targets = evaluate_model(
        model, phase3_loaders['test'], device
    )
    
    print(f"Phase 3 Test Metrics:")
    print(f"  R²: {phase3_test_metrics['r2']:.4f}")
    print(f"  Pearson: {phase3_test_metrics['pearson']:.4f}")
    print(f"  Spearman: {phase3_test_metrics['spearman']:.4f}")
    print(f"  RMSE: {phase3_test_metrics['rmse']:.4f}")
    print(f"  MAE: {phase3_test_metrics['mae']:.4f}\n")
    
    # Save results
    plot_training_history(phase3_history, "Phase 3", results_dir / "phase3_history.png")
    save_predictions(phase3_preds, phase3_targets, "Phase 3", results_dir / "phase3_predictions.png")

    # =============================================================================
    # FINAL RESULTS SUMMARY
    # =============================================================================
    print("="*70)
    print("FINAL RESULTS SUMMARY")
    print("="*70)
    
    results_summary = pd.DataFrame({
        'Phase': ['Phase 1: Pan-Cancer', 'Phase 2: Breast/TNBC', 'Phase 3: BL1'],
        'Test R²': [
            phase1_test_metrics['r2'],
            phase2_test_metrics['r2'],
            phase3_test_metrics['r2']
        ],
        'Test Pearson': [
            phase1_test_metrics['pearson'],
            phase2_test_metrics['pearson'],
            phase3_test_metrics['pearson']
        ],
        'Test Spearman': [
            phase1_test_metrics['spearman'],
            phase2_test_metrics['spearman'],
            phase3_test_metrics['spearman']
        ],
        'Test RMSE': [
            phase1_test_metrics['rmse'],
            phase2_test_metrics['rmse'],
            phase3_test_metrics['rmse']
        ],
        'Test MAE': [
            phase1_test_metrics['mae'],
            phase2_test_metrics['mae'],
            phase3_test_metrics['mae']
        ]
    })
    
    print(results_summary.to_string(index=False))
    print()
    
    # Save summary
    results_summary.to_csv(results_dir / "final_results_summary.csv", index=False)
    
    print(f"All results saved to {results_dir}")
    print(f"All models saved to {models_dir}")
    print("\nTraining pipeline completed!")


# =============================================================================
# AUTOMATIC TRAINING PIPELINE
# =============================================================================
# Automatically creates pre-batched data if it doesn't exist, then trains.
# Just hit "Run All" and let it work!
# =============================================================================

def run_automatic_pipeline():
    """
    Automatically handle pre-batching and training.
    If pre-batched data doesn't exist, create it first.
    """
    # Check if pre-batched data exists
    phase1_train = prebatched_dir / "phase1_train_batches.pt"
    phase2_train = prebatched_dir / "phase2_train_batches.pt"
    phase3_train = prebatched_dir / "phase3_train_batches.pt"
    
    all_exist = phase1_train.exists() and phase2_train.exists() and phase3_train.exists()
    
    if not all_exist:
        print("=" * 70)
        print("PRE-BATCHED DATA NOT FOUND")
        print("=" * 70)
        print("Creating pre-batched data first (one-time setup, ~5-10 minutes)...")
        print("This will make all subsequent training runs 5-10x faster!\n")
        create_all_prebatched_data()
        print("\n" + "=" * 70)
        print("✓ PRE-BATCHING COMPLETE! Starting training...")
        print("=" * 70 + "\n")
    else:
        print("=" * 70)
        print("✓ PRE-BATCHED DATA FOUND")
        print("=" * 70)
        print("Using existing pre-batched data for fast training!\n")
    
    # Run training with pre-batched data
    run_full_training_pipeline(use_prebatched=True)

# Run the automatic pipeline
run_automatic_pipeline()

Starting caffeinate to prevent system sleep...
Your Mac will stay awake until training completes.

Loading datasets...


Pan-cancer: 149093 pairs
Breast/TNBC: 9601 pairs
BL1 TNBC: 4303 pairs
Drug graphs: 228 drugs

Using device: mps

PHASE 1: PAN-CANCER PRE-TRAINING
Train: 104365 | Val: 22364 | Test: 22364

Model parameters: 31,053,827


In [85]:
# =============================================================================
# RESUME FROM PHASE 2 (Phase 1 already completed)
# =============================================================================
# Just run this cell - it loads your saved Phase 1 model and continues training

# Setup paths
output_dir = Path("/Users/tjalling/Desktop/Dev./Capstone/Model_Notebooks")
models_dir = output_dir / "models"
results_dir = output_dir / "results"
prebatched_dir = output_dir / "prebatched_data"

# Device
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}\n")

# Load pre-batched data to get n_genes
phase1_batches = torch.load(prebatched_dir / "phase1_train_batches.pt", weights_only=False)
n_genes = phase1_batches[0]['cell_batch'].shape[1]
print(f"Gene features: {n_genes}")
del phase1_batches  # Free memory

# Initialize model
model = DrugResponseGNN(cell_input_dim=n_genes)
model = model.to(device)

# Load Phase 1 checkpoint
checkpoint = torch.load(models_dir / "phase1_best.pt", weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"✓ Loaded Phase 1 model (Best Val R²: {checkpoint['val_r2']:.4f})")
print(f"✓ Phase 1 completed at epoch {checkpoint['epoch']}\n")

# =============================================================================
# PHASE 2: BREAST/TNBC FINE-TUNING
# =============================================================================
print("="*70)
print("PHASE 2: BREAST/TNBC FINE-TUNING")
print("="*70)

# Load Phase 2 pre-batched data
phase2_train = torch.load(prebatched_dir / "phase2_train_batches.pt", weights_only=False)
phase2_val = torch.load(prebatched_dir / "phase2_val_batches.pt", weights_only=False)
phase2_test = torch.load(prebatched_dir / "phase2_test_batches.pt", weights_only=False)

print(f"Phase 2 data: Train {len(phase2_train)} | Val {len(phase2_val)} | Test {len(phase2_test)} batches")

# Compute median IC50
all_ic50 = torch.cat([b['ic50'] for b in phase2_train])
median_ic50_p2 = all_ic50.median().item()
print(f"Median IC50: {median_ic50_p2:.3f}\n")

# Setup optimizer
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-6)

best_val_r2 = -np.inf
patience_counter = 0
phase2_history = {'train_loss': [], 'val_r2': []}

print("Starting Phase 2 training...\n")
for epoch in range(20):
    # Train
    model.train()
    epoch_loss = 0
    for batch in tqdm(phase2_train, desc=f"Epoch {epoch+1}/20", leave=False):
        drug_batch = batch['drug_batch'].to(device)
        cell_batch = batch['cell_batch'].to(device)
        ic50_target = batch['ic50'].to(device)
        
        outputs = model(drug_batch, cell_batch, return_embeddings=True)
        targets = {'ic50': ic50_target, 'embeddings': outputs['embeddings']['combined'].detach()}
        loss, _ = compute_loss(outputs, targets, median_ic50_p2)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        epoch_loss += loss.item()
    
    avg_loss = epoch_loss / len(phase2_train)
    phase2_history['train_loss'].append(avg_loss)
    
    # Validate
    model.eval()
    all_preds, all_targets = [], []
    with torch.no_grad():
        for batch in phase2_val:
            drug_batch = batch['drug_batch'].to(device)
            cell_batch = batch['cell_batch'].to(device)
            outputs = model(drug_batch, cell_batch)
            all_preds.append(outputs['ic50'].cpu())
            all_targets.append(batch['ic50'])
    
    preds = torch.cat(all_preds).numpy().flatten()
    targets = torch.cat(all_targets).numpy().flatten()
    val_r2 = r2_score(targets, preds)
    phase2_history['val_r2'].append(val_r2)
    
    print(f"Epoch {epoch+1}: Loss={avg_loss:.4f} | Val R²={val_r2:.4f}")
    
    scheduler.step(val_r2)
    
    # Save best model
    if val_r2 > best_val_r2:
        best_val_r2 = val_r2
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_r2': val_r2
        }, models_dir / "phase2_best.pt")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= 15:
            print("Early stopping!")
            break

# Load best Phase 2 model
checkpoint = torch.load(models_dir / "phase2_best.pt", weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"\n✓ Phase 2 complete! Best Val R²: {checkpoint['val_r2']:.4f}")

# =============================================================================
# PHASE 3: BL1 TNBC FINE-TUNING
# =============================================================================
print("\n" + "="*70)
print("PHASE 3: BL1 TNBC FINE-TUNING")
print("="*70)

# Load Phase 3 pre-batched data
phase3_train = torch.load(prebatched_dir / "phase3_train_batches.pt", weights_only=False)
phase3_val = torch.load(prebatched_dir / "phase3_val_batches.pt", weights_only=False)
phase3_test = torch.load(prebatched_dir / "phase3_test_batches.pt", weights_only=False)

print(f"Phase 3 data: Train {len(phase3_train)} | Val {len(phase3_val)} | Test {len(phase3_test)} batches")

# Compute median IC50
all_ic50 = torch.cat([b['ic50'] for b in phase3_train])
median_ic50_p3 = all_ic50.median().item()
print(f"Median IC50: {median_ic50_p3:.3f}\n")

# Setup optimizer (lower LR)
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-6)

best_val_r2 = -np.inf
patience_counter = 0
phase3_history = {'train_loss': [], 'val_r2': []}

print("Starting Phase 3 training...\n")
for epoch in range(20):
    # Train
    model.train()
    epoch_loss = 0
    for batch in tqdm(phase3_train, desc=f"Epoch {epoch+1}/20", leave=False):
        drug_batch = batch['drug_batch'].to(device)
        cell_batch = batch['cell_batch'].to(device)
        ic50_target = batch['ic50'].to(device)
        
        outputs = model(drug_batch, cell_batch, return_embeddings=True)
        targets = {'ic50': ic50_target, 'embeddings': outputs['embeddings']['combined'].detach()}
        loss, _ = compute_loss(outputs, targets, median_ic50_p3)
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        epoch_loss += loss.item()
    
    avg_loss = epoch_loss / len(phase3_train)
    phase3_history['train_loss'].append(avg_loss)
    
    # Validate
    model.eval()
    all_preds, all_targets = [], []
    with torch.no_grad():
        for batch in phase3_val:
            drug_batch = batch['drug_batch'].to(device)
            cell_batch = batch['cell_batch'].to(device)
            outputs = model(drug_batch, cell_batch)
            all_preds.append(outputs['ic50'].cpu())
            all_targets.append(batch['ic50'])
    
    preds = torch.cat(all_preds).numpy().flatten()
    targets = torch.cat(all_targets).numpy().flatten()
    val_r2 = r2_score(targets, preds)
    phase3_history['val_r2'].append(val_r2)
    
    print(f"Epoch {epoch+1}: Loss={avg_loss:.4f} | Val R²={val_r2:.4f}")
    
    scheduler.step(val_r2)
    
    # Save best model
    if val_r2 > best_val_r2:
        best_val_r2 = val_r2
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_r2': val_r2
        }, models_dir / "phase3_bl1_final.pt")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= 15:
            print("Early stopping!")
            break

# Load best Phase 3 model
checkpoint = torch.load(models_dir / "phase3_bl1_final.pt", weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"\n✓ Phase 3 complete! Best Val R²: {checkpoint['val_r2']:.4f}")

# =============================================================================
# FINAL EVALUATION
# =============================================================================
print("\n" + "="*70)
print("TRAINING COMPLETE!")
print("="*70)
print(f"\nModels saved to: {models_dir}")
print("  - phase1_best.pt (Pan-cancer)")
print("  - phase2_best.pt (Breast/TNBC)")
print("  - phase3_bl1_final.pt (BL1 TNBC - FINAL)")
print("\n✓ All done! Your model is ready.")

Using device: mps



: 