In [1]:
import numpy as np
import pandas as pd
import os, sys, json

In [None]:
# os.chdir('..')

#### hetionet graph for g2g prior knowlege

In [None]:
nodes  = pd.read_csv('gnn-depmap-project/src/data/depmap 25q2 public/hetionet_public/hetionet-v1.0-nodes.tsv', sep='\t')
edges = pd.read_csv('gnn-depmap-project/src/data/depmap 25q2 public/hetionet_public/hetionet-v1.0-edges.sif', sep='\t')

node_map = dict(zip(nodes['id'], nodes['name']))

edges = pd.merge(edges, nodes.rename(columns={
    'id':'source', 'name':'source_name', 'kind':'source_type'}), on='source', how='left')

edges = pd.merge(edges, nodes.rename(columns={
    'id':'target', 'name':'target_name', 'kind':'target_type'}), on='target', how='left')

hetionet_g2g = edges.loc[(edges['source_type']=='Gene') & (edges['target_type']=='Gene')]
hetionet_g2g[['source_name', 'target_name']].rename(columns={
    'source_name':'source', 'target_name':'target'}).reset_index(drop=True).drop_duplicates().to_csv('gnn-depmap-project/src/data/depmap 25q2 public/hetionet_public/hetionet_g2g_edgelist.csv')

### depmap data processing

In [3]:
model = pd.read_csv('gnn-depmap-project/src/data/depmap 25q2 public/Model.csv')
model = model[['ModelID', 'StrippedCellLineName', 'SangerModelID', 'COSMICID', 'OncotreeLineage', 'Age', 'AgeCategory',
        'Sex', 'PatientRace', 'PrimaryOrMetastasis', 'ModelType']] 
model = model.loc[(model['ModelType']== 'Cell Line') & (model['AgeCategory']== 'Adult')]


expression = pd.read_csv('gnn-depmap-project/src/data/depmap 25q2 public/OmicsExpressionProteinCodingGenesTPMLogp1.csv')
expression.rename(columns={'Unnamed: 0':'ModelID'}, inplace=True)
expression.columns = expression.columns.map(lambda x: x.split(' (')[0].strip())
expression = expression.loc[expression['ModelID'].isin(model['ModelID'])]

mutations = pd.read_csv('gnn-depmap-project/src/data/depmap 25q2 public/OmicsSomaticMutationsMatrixDamaging.csv')
mutations.rename(columns={'Unnamed: 0':'ModelID'}, inplace=True)
mutations.columns = mutations.columns.map(lambda x: x.split(' (')[0].strip())
mutations = mutations.loc[mutations['ModelID'].isin(model['ModelID'])]

gene_effect = pd.read_csv('gnn-depmap-project/src/data/depmap 25q2 public/CRISPRGeneEffect.csv')
gene_effect.rename(columns={'Unnamed: 0':'ModelID'}, inplace=True)
gene_effect.columns = gene_effect.columns.map(lambda x: x.split(' (')[0].strip())
gene_effect = gene_effect.loc[gene_effect['ModelID'].isin(model['ModelID'])]

# filter on common model ID list
model_ids = set(mutations.ModelID) & set(expression.ModelID) & set(model.ModelID) # & set(gene_effect.ModelID)

model = model.loc[model['ModelID'].isin(model_ids)] #.to_csv('gnn-depmap-project/src/data/depmap 25q2 public/depmap_clinical_data.csv', index=False)
expression = expression.loc[expression['ModelID'].isin(model_ids)] #.to_csv('gnn-depmap-project/src/data/depmap 25q2 public/depmap_expression_data.csv', index=False)
mutations = mutations.loc[mutations['ModelID'].isin(model_ids)] #.to_csv('gnn-depmap-project/src/data/depmap 25q2 public/depmap_mutation_data.csv', index=False)
gene_effect = gene_effect.loc[gene_effect['ModelID'].isin(model_ids)] #.to_csv('gnn-depmap-project/src/data/depmap 25q2 public/depmap_gene_effect_data.csv', index=False)

In [4]:
## GDSC drug response data

gdsc1 = pd.read_excel('gnn-depmap-project/src/data/depmap 25q2 public/GDSC1_fitted_dose_response_27Oct23.xlsx')
gdsc2 = pd.read_excel('gnn-depmap-project/src/data/depmap 25q2 public/GDSC2_fitted_dose_response_27Oct23.xlsx')
model = model.loc[model['ModelID'].isin(model_ids)]

gdsc1 = gdsc1.loc[(gdsc1['COSMIC_ID'].isin(model['COSMICID'])) |
          (gdsc1['SANGER_MODEL_ID'].isin(model['SangerModelID']))]

gdsc2 = gdsc2.loc[(gdsc2['COSMIC_ID'].isin(model['COSMICID'])) |
          (gdsc2['SANGER_MODEL_ID'].isin(model['SangerModelID']))]

gdsc12 = pd.concat([gdsc1, gdsc2])[['COSMIC_ID', 'CELL_LINE_NAME', 'SANGER_MODEL_ID', 'TCGA_DESC', 'DRUG_NAME', 'PUTATIVE_TARGET', 'PATHWAY_NAME', 'LN_IC50']]

gdsc12 = gdsc12.loc[gdsc12['PUTATIVE_TARGET'].isin(expression.columns)]
gdsc12['CELL_LINE_NAME'] = gdsc12['CELL_LINE_NAME'].replace(',', '', regex=True).replace(' ', '', regex=True).replace(':', '', regex=True).replace('-', '', regex=True).str.upper()

gdsc12 = pd.merge(gdsc12, model[['ModelID', 'SangerModelID']].rename(columns={'SangerModelID':'SANGER_MODEL_ID'}), 
         on='SANGER_MODEL_ID', how='inner')

# gdsc12.to_csv('gnn-depmap-project/src/data/depmap 25q2 public/depmap_gdsc_drug_response.csv', index=False)

In [5]:
gdsc12 = pd.read_csv('gnn-depmap-project/src/data/depmap 25q2 public/depmap_gdsc_drug_response.csv')
model = pd.read_csv('gnn-depmap-project/src/data/depmap 25q2 public/depmap_clinical_data.csv')
expression = pd.read_csv('gnn-depmap-project/src/data/depmap 25q2 public/depmap_expression_data.csv', index_col=0)
expression = expression.loc[:, expression.sum() > 0]

mutations = pd.read_csv('gnn-depmap-project/src/data/depmap 25q2 public/depmap_mutation_data.csv', index_col=0)
gene_effect = pd.read_csv('gnn-depmap-project/src/data/depmap 25q2 public/depmap_gene_effect_data.csv', index_col=0)

hetionet_g2g = pd.read_csv('gnn-depmap-project/src/data/depmap 25q2 public/hetionet_public/hetionet_g2g_edgelist.csv', index_col=0)

In [41]:
# from sklearn.preprocessing import MinMaxScaler
# scaler = MinMaxScaler()
# expression.loc[:, expression.columns != 'ModelID'] = scaler.fit_transform(expression.loc[:, expression.columns != 'ModelID'])

In [21]:
# # cell line features
# model_features = pd.get_dummies(model[['OncotreeLineage', 'Age', 'AgeCategory', 'Sex', 'PatientRace',
#        'PrimaryOrMetastasis']], dtype='int')
# model_features['ModelID'] = model['ModelID']

# # break into train and test split
# train_model_ids = list(set(model_features['ModelID']).intersection(set(gdsc12['ModelID'])))
# test_model_ids = list(set(model_features['ModelID']) - set(train_model_ids))

# model_features = model_features.set_index('ModelID')
# train_model_features = model_features.loc[model_features.index.isin(train_model_ids)].reset_index()
# test_model_features = model_features.loc[model_features.index.isin(test_model_ids)].reset_index()

# # gene features
# gene_features = pd.DataFrame(expression.columns[1:].to_list(), columns=['gene']) #.set_index('gene') # is empty

# # drug features
# drug_features = pd.DataFrame(gdsc12['DRUG_NAME'].drop_duplicates().to_list(), columns=['drugs']) #.set_index('drugs') # is empty


In [7]:
expression_hivar = expression.loc[:, expression.var() > 2].columns.to_list()
gene_effect_hivar = gene_effect.loc[:, gene_effect.var() > 0.1].columns.to_list()
freq_muts = mutations.loc[:, mutations.sum() > 20].columns.to_list()
geneset_refined = list(set(expression_hivar + gene_effect_hivar + freq_muts + gdsc12['PUTATIVE_TARGET'].drop_duplicates().to_list()))

In [8]:
len(geneset_refined)

3919

In [None]:
# cell line features
model_features = pd.get_dummies(model[['OncotreeLineage', 'Age', 'AgeCategory', 'Sex', 'PatientRace',
       'PrimaryOrMetastasis']], dtype='int')
model_features = model_features['Age'].fillna(model_features['Age'].median())
model_features['ModelID'] = model['ModelID']

# break into train and test split
from sklearn.model_selection import train_test_split
train_model_ids, val_model_ids = train_test_split(list(set(model_features['ModelID']).intersection(set(gdsc12['ModelID']))), test_size=0.2)
# train_model_ids = list(set(model_features['ModelID']).intersection(set(gdsc12['ModelID'])))
test_model_ids = list(set(model_features['ModelID']) - set(train_model_ids))

model_features = model_features.set_index('ModelID')
train_model_features = model_features.loc[model_features.index.isin(train_model_ids)].reset_index()
val_model_features = model_features.loc[model_features.index.isin(val_model_ids)].reset_index()
test_model_features = model_features.loc[model_features.index.isin(test_model_ids)].reset_index()

# gene features
gene_features = pd.DataFrame(expression.columns[1:].to_list(), columns=['gene']) #.set_index('gene') # is empty

# drug features
drug_features = pd.DataFrame(gdsc12['DRUG_NAME'].drop_duplicates().to_list(), columns=['drugs']) #.set_index('drugs') # is empty

# drug edges
train_drug_to_gene = gdsc12[['DRUG_NAME', 'PUTATIVE_TARGET']].drop_duplicates()
val_drug_to_gene = gdsc12[['DRUG_NAME', 'PUTATIVE_TARGET']].drop_duplicates()
test_drug_to_gene = gdsc12[['DRUG_NAME', 'PUTATIVE_TARGET']].drop_duplicates() # LN_IC50 - response prediction

# gene edges
train_gene_to_gene = hetionet_g2g[['source', 'target']]
val_gene_to_gene = hetionet_g2g[['source', 'target']]
test_gene_to_gene = hetionet_g2g[['source', 'target']]


# cell line to drug: ln50 to link predict
train_cell_line_to_drug = gdsc12.loc[gdsc12['ModelID'].isin(train_model_ids), ['ModelID', 'DRUG_NAME',  'LN_IC50']] # LN_IC50 - response prediction
val_cell_line_to_drug = gdsc12.loc[gdsc12['ModelID'].isin(val_model_ids), ['ModelID', 'DRUG_NAME',  'LN_IC50']] # LN_IC50 - response prediction

from itertools import product
test_cell_line_to_drug = pd.DataFrame(columns=['ModelID', 'DRUG_NAME'], 
            data=product(list(test_model_ids), gdsc12['DRUG_NAME'].drop_duplicates().to_list()))
test_cell_line_to_drug['LN_IC50'] = np.nan # LN_IC50 - response prediction


# cell to gene: expression
train_cell_line_to_expression = pd.melt(expression.loc[train_model_ids].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='expression')
val_cell_line_to_expression = pd.melt(expression.loc[val_model_ids].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='expression')
test_cell_line_to_expression = pd.melt(expression.loc[test_model_ids].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='expression')

# cell to gene: mutation
train_cell_line_to_mutation = pd.melt(mutations.loc[train_model_ids].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='mutation')
train_cell_line_to_mutation = train_cell_line_to_mutation[train_cell_line_to_mutation['mutation'] == 1]

val_cell_line_to_mutation = pd.melt(mutations.loc[val_model_ids].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='mutation')
val_cell_line_to_mutation = val_cell_line_to_mutation[val_cell_line_to_mutation['mutation'] == 1]

test_cell_line_to_mutation = pd.melt(mutations.loc[test_model_ids].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='mutation')
test_cell_line_to_mutation = test_cell_line_to_mutation[test_cell_line_to_mutation['mutation'] == 1]

# cell to gene: effect
train_cell_line_to_gene_effect = pd.melt(gene_effect.loc[gene_effect.index.isin(train_model_ids)].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='gene_effect')
val_cell_line_to_gene_effect = pd.melt(gene_effect.loc[gene_effect.index.isin(val_model_ids)].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='gene_effect')
test_cell_line_to_gene_effect = pd.melt(gene_effect.loc[gene_effect.index.isin(test_model_ids)].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='gene_effect')

In [26]:
##### refined smaller version

# cell line features
model_features = pd.get_dummies(model[['OncotreeLineage', 'Age', 'AgeCategory', 'Sex', 'PatientRace',
       'PrimaryOrMetastasis']], dtype='int')
model_features['Age'] = model_features['Age'].fillna(model_features['Age'].median())
model_features['ModelID'] = model['ModelID']

# break into train and test split
from sklearn.model_selection import train_test_split

train_model_ids, val_model_ids = train_test_split(list(set(model_features['ModelID']).intersection(set(gdsc12['ModelID']))), test_size=0.2, random_state=42)
# train_model_ids = list(set(model_features['ModelID']).intersection(set(gdsc12['ModelID'])))
test_model_ids = list(set(model_features['ModelID']) - (set(model_features['ModelID']).intersection(set(gdsc12['ModelID']))))

model_features = model_features.set_index('ModelID')
train_model_features = model_features.loc[model_features.index.isin(train_model_ids)].reset_index()
val_model_features = model_features.loc[model_features.index.isin(val_model_ids)].reset_index()
test_model_features = model_features.loc[model_features.index.isin(test_model_ids)].reset_index()

# gene features
gene_features = pd.DataFrame(geneset_refined, columns=['gene']) #.set_index('gene') # is empty

# drug features
drug_features = pd.DataFrame(gdsc12['DRUG_NAME'].drop_duplicates().to_list(), columns=['drugs']) #.set_index('drugs') # is empty

# drug edges
train_drug_to_gene = gdsc12[['DRUG_NAME', 'PUTATIVE_TARGET']].drop_duplicates()
val_drug_to_gene = gdsc12[['DRUG_NAME', 'PUTATIVE_TARGET']].drop_duplicates()
test_drug_to_gene = gdsc12[['DRUG_NAME', 'PUTATIVE_TARGET']].drop_duplicates() # LN_IC50 - response prediction

# gene edges
train_gene_to_gene = hetionet_g2g[['source', 'target']]
val_gene_to_gene = hetionet_g2g[['source', 'target']]
test_gene_to_gene = hetionet_g2g[['source', 'target']]


# cell line to drug: ln50 to link predict
train_cell_line_to_drug = gdsc12.loc[gdsc12['ModelID'].isin(train_model_ids), ['ModelID', 'DRUG_NAME',  'LN_IC50']] # LN_IC50 - response prediction
val_cell_line_to_drug = gdsc12.loc[gdsc12['ModelID'].isin(val_model_ids), ['ModelID', 'DRUG_NAME',  'LN_IC50']] # LN_IC50 - response prediction

from itertools import product
test_cell_line_to_drug = pd.DataFrame(columns=['ModelID', 'DRUG_NAME'], 
            data=product(list(test_model_ids), gdsc12['DRUG_NAME'].drop_duplicates().to_list()))
test_cell_line_to_drug['LN_IC50'] = np.nan # LN_IC50 - response prediction


# cell to gene: expression
train_cell_line_to_expression = pd.melt(expression.loc[train_model_ids, list(expression.columns.intersection(set(geneset_refined)))].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='expression')
val_cell_line_to_expression = pd.melt(expression.loc[val_model_ids, list(expression.columns.intersection(set(geneset_refined)))].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='expression')
test_cell_line_to_expression = pd.melt(expression.loc[test_model_ids, list(expression.columns.intersection(set(geneset_refined)))].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='expression')

# cell to gene: mutation
train_cell_line_to_mutation = pd.melt(mutations.loc[train_model_ids, list(mutations.columns.intersection(set(geneset_refined)))].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='mutation')
train_cell_line_to_mutation = train_cell_line_to_mutation[train_cell_line_to_mutation['mutation'] == 1]

val_cell_line_to_mutation = pd.melt(mutations.loc[val_model_ids, list(mutations.columns.intersection(set(geneset_refined)))].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='mutation')
val_cell_line_to_mutation = val_cell_line_to_mutation[val_cell_line_to_mutation['mutation'] == 1]

test_cell_line_to_mutation = pd.melt(mutations.loc[test_model_ids, list(mutations.columns.intersection(set(geneset_refined)))].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='mutation')
test_cell_line_to_mutation = test_cell_line_to_mutation[test_cell_line_to_mutation['mutation'] == 1]

# cell to gene: effect
train_cell_line_to_gene_effect = pd.melt(gene_effect.loc[gene_effect.index.isin(train_model_ids), list(gene_effect.columns.intersection(set(geneset_refined)))].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='gene_effect')
val_cell_line_to_gene_effect = pd.melt(gene_effect.loc[gene_effect.index.isin(val_model_ids),  list(gene_effect.columns.intersection(set(geneset_refined)))].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='gene_effect')
test_cell_line_to_gene_effect = pd.melt(gene_effect.loc[gene_effect.index.isin(test_model_ids),  list(gene_effect.columns.intersection(set(geneset_refined)))].reset_index(), id_vars=['ModelID'], var_name='gene', value_name='gene_effect')

## Build PyTorch Geometric Heterogeneous Graph

This section constructs a heterogeneous graph with:
- **Node types**: Cell lines, Genes, Drugs
- **Edge types**: 
  - Cell line -> Gene (expression, mutation, gene effect)
  - Cell line -> Drug (drug response)
  - Drug -> Gene (drug target)
  - Gene -> Gene (gene-gene interactions)

In [None]:
# %pip install --upgrade torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.0.0+cpu.html
import torch
from torch_geometric.data import HeteroData
import numpy as np
from src.data import build_depmap_heterogeneous_graph

    model_features,
    gene_features,
    drug_features,
    cell_line_to_expression,
    cell_line_to_mutation,
    cell_line_to_gene_effect,
    cell_line_to_drug,
    drug_to_gene,
    gene_to_gene
):
    """
    Build a heterogeneous graph from DepMap data (optimized version).
    
    Parameters
    --------
    model_features : pd.DataFrame
        Cell line features with 'ModelID' column
    gene_features : pd.DataFrame
        Gene features with 'gene' column
    drug_features : pd.DataFrame
        Drug features with 'drugs' column
    cell_line_to_expression : pd.DataFrame
        Edges: ['ModelID', 'gene', 'expression']
    cell_line_to_mutation : pd.DataFrame
        Edges: ['ModelID', 'gene', 'mutation']
    cell_line_to_gene_effect : pd.DataFrame
        Edges: ['ModelID', 'gene', 'gene_effect']
    cell_line_to_drug : pd.DataFrame
        Edges: ['ModelID', 'DRUG_NAME', 'LN_IC50']
    drug_to_gene : pd.DataFrame
        Edges: ['DRUG_NAME', 'PUTATIVE_TARGET']
    gene_to_gene : pd.DataFrame
        Edges: ['source_name', 'target_name']
        
    Returns
    ------------
    HeteroData
        PyTorch Geometric HeteroData object
    """
    print("Building heterogeneous graph (optimized)...")
    data = HeteroData()
    
    # ==================== Node Features ====================
    
    # Cell line nodes
    cell_line_ids = model_features['ModelID'].values
    cell_line_to_idx = {cell_id: idx for idx, cell_id in enumerate(cell_line_ids)}
    
    # Extract numeric features for cell lines
    numeric_features = model_features.drop('ModelID', axis=1).values
    data['cell_line'].x = torch.FloatTensor(numeric_features)
    data['cell_line'].num_nodes = len(cell_line_ids)
    data['cell_line'].node_ids = cell_line_ids.tolist()
    
    print(f"Cell line nodes: {data['cell_line'].num_nodes}")
    print(f"Cell line features: {data['cell_line'].x.shape}")
    
    # Gene nodes
    gene_ids = gene_features['gene'].values
    gene_to_idx = {gene: idx for idx, gene in enumerate(gene_ids)}
    
    # Create simple identity features for genes (can be replaced with embeddings)
    data['gene'].x = torch.eye(len(gene_ids))
    data['gene'].num_nodes = len(gene_ids)
    data['gene'].node_ids = gene_ids.tolist()
    
    print(f"Gene nodes: {data['gene'].num_nodes}")
    
    # Drug nodes
    drug_ids = drug_features['drugs'].values
    drug_to_idx = {drug: idx for idx, drug in enumerate(drug_ids)}
    
    # Create simple identity features for drugs (can be replaced with embeddings)
    data['drug'].x = torch.eye(len(drug_ids))
    data['drug'].num_nodes = len(drug_ids)
    data['drug'].node_ids = drug_ids.tolist()
    
    print(f"Drug nodes: {data['drug'].num_nodes}")
    
    # ==================== Cell Line -> Gene Edges (VECTORIZED) ====================
    
    # Expression edges
    print("\nBuilding cell line -> gene (expression) edges...")
    expression_edges = cell_line_to_expression.dropna(subset=['expression'])
    
    # Vectorized mapping using pandas map
    expression_edges['cell_idx'] = expression_edges['ModelID'].map(cell_line_to_idx)
    expression_edges['gene_idx'] = expression_edges['gene'].map(gene_to_idx)
    
    # Filter valid edges
    valid_mask = expression_edges['cell_idx'].notna() & expression_edges['gene_idx'].notna()
    expression_edges = expression_edges[valid_mask]
    
    if len(expression_edges) > 0:
        cell_indices = expression_edges['cell_idx'].values.astype(np.int64)
        gene_indices = expression_edges['gene_idx'].values.astype(np.int64)
        expression_values = expression_edges['expression'].values.astype(np.float32)
        
        data['cell_line', 'expresses', 'gene'].edge_index = torch.LongTensor(np.stack([cell_indices, gene_indices]))
        data['cell_line', 'expresses', 'gene'].edge_attr = torch.FloatTensor(expression_values).unsqueeze(1)
        print(f"Expression edges: {len(cell_indices)}")
    
    # Mutation edges
    print("Building cell line -> gene (mutation) edges...")
    mutation_edges = cell_line_to_mutation[cell_line_to_mutation['mutation'] == 1].copy()
    
    mutation_edges['cell_idx'] = mutation_edges['ModelID'].map(cell_line_to_idx)
    mutation_edges['gene_idx'] = mutation_edges['gene'].map(gene_to_idx)
    
    valid_mask = mutation_edges['cell_idx'].notna() & mutation_edges['gene_idx'].notna()
    mutation_edges = mutation_edges[valid_mask]
    
    if len(mutation_edges) > 0:
        cell_indices = mutation_edges['cell_idx'].values.astype(np.int64)
        gene_indices = mutation_edges['gene_idx'].values.astype(np.int64)
        
        data['cell_line', 'has_mutation', 'gene'].edge_index = torch.LongTensor(np.stack([cell_indices, gene_indices]))
        print(f"Mutation edges: {len(cell_indices)}")
    
    # Gene effect edges
    print("Building cell line -> gene (gene effect) edges...")
    gene_effect_edges = cell_line_to_gene_effect.dropna(subset=['gene_effect']).copy()
    
    gene_effect_edges['cell_idx'] = gene_effect_edges['ModelID'].map(cell_line_to_idx)
    gene_effect_edges['gene_idx'] = gene_effect_edges['gene'].map(gene_to_idx)
    
    valid_mask = gene_effect_edges['cell_idx'].notna() & gene_effect_edges['gene_idx'].notna()
    gene_effect_edges = gene_effect_edges[valid_mask]
    
    if len(gene_effect_edges) > 0:
        cell_indices = gene_effect_edges['cell_idx'].values.astype(np.int64)
        gene_indices = gene_effect_edges['gene_idx'].values.astype(np.int64)
        gene_effect_values = gene_effect_edges['gene_effect'].values.astype(np.float32)
        
        data['cell_line', 'has_gene_effect', 'gene'].edge_index = torch.LongTensor(np.stack([cell_indices, gene_indices]))
        data['cell_line', 'has_gene_effect', 'gene'].edge_attr = torch.FloatTensor(gene_effect_values).unsqueeze(1)
        print(f"Gene effect edges: {len(cell_indices)}")
    
    # ==================== Cell Line -> Drug Edges ====================
    
    print("\nBuilding cell line -> drug edges...")
    cell_drug_edges = cell_line_to_drug.copy()
    
    cell_drug_edges['cell_idx'] = cell_drug_edges['ModelID'].map(cell_line_to_idx)
    cell_drug_edges['drug_idx'] = cell_drug_edges['DRUG_NAME'].map(drug_to_idx)
    

    valid_mask = cell_drug_edges['cell_idx'].notna() & cell_drug_edges['drug_idx'].notna()
    cell_drug_edges = cell_drug_edges[valid_mask]
    
    if len(cell_drug_edges) > 0:
        cell_indices = cell_drug_edges['cell_idx'].values.astype(np.int64)
        drug_indices = cell_drug_edges['drug_idx'].values.astype(np.int64)
        ic50_values = cell_drug_edges['LN_IC50'].fillna(0.0).values.astype(np.float32)
        
        data['cell_line', 'treated_with', 'drug'].edge_index = torch.LongTensor(np.stack([cell_indices, drug_indices]))
        data['cell_line', 'treated_with', 'drug'].edge_attr = torch.FloatTensor(ic50_values).unsqueeze(1)

        print(f"Cell line -> drug edges: {len(cell_indices)}")
    
    # ==================== Drug -> Gene Edges ====================
    
    print("\nBuilding drug -> gene (target) edges...")
    drug_gene_edges = drug_to_gene.copy()
    
    drug_gene_edges['drug_idx'] = drug_gene_edges['DRUG_NAME'].map(drug_to_idx)
    drug_gene_edges['gene_idx'] = drug_gene_edges['PUTATIVE_TARGET'].map(gene_to_idx)
    
    valid_mask = drug_gene_edges['drug_idx'].notna() & drug_gene_edges['gene_idx'].notna()
    drug_gene_edges = drug_gene_edges[valid_mask]
    
    if len(drug_gene_edges) > 0:
        drug_indices = drug_gene_edges['drug_idx'].values.astype(np.int64)
        gene_indices = drug_gene_edges['gene_idx'].values.astype(np.int64)
        # ic50_values = drug_gene_edges['LN_IC50'].fillna(0.0).values.astype(np.float32)
        
        data['drug', 'targets', 'gene'].edge_index = torch.LongTensor(np.stack([drug_indices, gene_indices]))
        # data['drug', 'targets', 'gene'].edge_attr = torch.FloatTensor(ic50_values).unsqueeze(1)
        print(f"Drug -> gene edges: {len(drug_indices)}")
    
    # ==================== Gene -> Gene Edges ====================
    
    print("\nBuilding gene -> gene (interaction) edges...")
    gene_gene_edges = gene_to_gene.copy()
    
    gene_gene_edges['source_idx'] = gene_gene_edges['source'].map(gene_to_idx)
    gene_gene_edges['target_idx'] = gene_gene_edges['target'].map(gene_to_idx)
    
    valid_mask = gene_gene_edges['source_idx'].notna() & gene_gene_edges['target_idx'].notna()
    gene_gene_edges = gene_gene_edges[valid_mask]
    
    if len(gene_gene_edges) > 0:
        source_indices = gene_gene_edges['source_idx'].values.astype(np.int64)
        target_indices = gene_gene_edges['target_idx'].values.astype(np.int64)
        
        data['gene', 'interacts', 'gene'].edge_index = torch.LongTensor(np.stack([source_indices, target_indices]))
        print(f"Gene -> gene edges: {len(source_indices)}")
    
    # ==================== Add Reverse Edges ====================
    
    print("\nAdding reverse edges...")
    
    # Gene -> Cell line (reverse of expression)
    if hasattr(data['cell_line', 'expresses', 'gene'], 'edge_index'):
        edge_index = data['cell_line', 'expresses', 'gene'].edge_index
        data['gene', 'expressed_by', 'cell_line'].edge_index = torch.stack([edge_index[1], edge_index[0]])
        if hasattr(data['cell_line', 'expresses', 'gene'], 'edge_attr'):
            data['gene', 'expressed_by', 'cell_line'].edge_attr = data['cell_line', 'expresses', 'gene'].edge_attr
    
    # Gene -> Cell line (reverse of mutation)
    if hasattr(data['cell_line', 'has_mutation', 'gene'], 'edge_index'):
        edge_index = data['cell_line', 'has_mutation', 'gene'].edge_index
        data['gene', 'mutated_in', 'cell_line'].edge_index = torch.stack([edge_index[1], edge_index[0]])

    # Gene -> Cell line (reverse of gene effect)
    if hasattr(data['cell_line', 'has_gene_effect', 'gene'], 'edge_index'):
        edge_index = data['cell_line', 'has_gene_effect', 'gene'].edge_index
        data['gene', 'has_gene_effect_in', 'cell_line'].edge_index = torch.stack([edge_index[1], edge_index[0]])
    
    # Drug -> Cell line (reverse of treated_with)
    if hasattr(data['cell_line', 'treated_with', 'drug'], 'edge_index'):
        edge_index = data['cell_line', 'treated_with', 'drug'].edge_index
        data['drug', 'treats', 'cell_line'].edge_index = torch.stack([edge_index[1], edge_index[0]])
    
    # Gene -> Drug (reverse of targets)
    if hasattr(data['drug', 'targets', 'gene'], 'edge_index'):
        edge_index = data['drug', 'targets', 'gene'].edge_index
        data['gene', 'targeted_by', 'drug'].edge_index = torch.stack([edge_index[1], edge_index[0]])
        if hasattr(data['drug', 'targets', 'gene'], 'edge_attr'):
            data['gene', 'targeted_by', 'drug'].edge_attr = data['drug', 'targets', 'gene'].edge_attr
    
    print("\n" + "="*60)
    print("Graph construction complete!")
    print(f"Node types: {list(data.node_types)}")
    print(f"Edge types: {list(data.edge_types)}")
    print("="*60)
    
    return data

In [28]:
# Build the heterogeneous graph
train_hetero_graph = build_depmap_heterogeneous_graph(
    model_features=train_model_features,
    gene_features=gene_features,
    drug_features=drug_features,
    cell_line_to_expression=train_cell_line_to_expression,
    cell_line_to_mutation=train_cell_line_to_mutation,
    cell_line_to_gene_effect=train_cell_line_to_gene_effect,
    cell_line_to_drug=train_cell_line_to_drug,
    drug_to_gene=train_drug_to_gene,
    gene_to_gene=train_gene_to_gene
)

test_hetero_graph = build_depmap_heterogeneous_graph(
    model_features=test_model_features,
    gene_features=gene_features,
    drug_features=drug_features,
    cell_line_to_expression=test_cell_line_to_expression,
    cell_line_to_mutation=test_cell_line_to_mutation,
    cell_line_to_gene_effect=test_cell_line_to_gene_effect,
    cell_line_to_drug=test_cell_line_to_drug,
    drug_to_gene=test_drug_to_gene,
    gene_to_gene=test_gene_to_gene
)

val_hetero_graph = build_depmap_heterogeneous_graph(
    model_features=val_model_features,
    gene_features=gene_features,
    drug_features=drug_features,
    cell_line_to_expression=val_cell_line_to_expression,
    cell_line_to_mutation=val_cell_line_to_mutation,
    cell_line_to_gene_effect=val_cell_line_to_gene_effect,
    cell_line_to_drug=val_cell_line_to_drug,
    drug_to_gene=val_drug_to_gene,
    gene_to_gene=val_gene_to_gene
)

print("\n" + "="*60)
print("GRAPH SUMMARY")
print("="*60)
print("train graph")
print(train_hetero_graph)
print("\n" + "="*60)
print("val graph")
print(val_hetero_graph)
print("\n" + "="*60)
print("test graph")
print(test_hetero_graph)

Building heterogeneous graph (optimized)...
Cell line nodes: 463
Cell line features: torch.Size([463, 42])
Gene nodes: 3919
Drug nodes: 148

Building cell line -> gene (expression) edges...
Expression edges: 1812182
Building cell line -> gene (mutation) edges...
Mutation edges: 10308
Building cell line -> gene (gene effect) edges...
Gene effect edges: 1320047

Building cell line -> drug edges...
Cell line -> drug edges: 73076

Building drug -> gene (target) edges...
Drug -> gene edges: 148

Building gene -> gene (interaction) edges...
Gene -> gene edges: 53080

Adding reverse edges...

Graph construction complete!
Node types: ['cell_line', 'gene', 'drug']
Edge types: [('cell_line', 'expresses', 'gene'), ('cell_line', 'has_mutation', 'gene'), ('cell_line', 'has_gene_effect', 'gene'), ('cell_line', 'treated_with', 'drug'), ('drug', 'targets', 'gene'), ('gene', 'interacts', 'gene'), ('gene', 'expressed_by', 'cell_line'), ('gene', 'mutated_in', 'cell_line'), ('gene', 'has_gene_effect_in', 

In [29]:
# Save the graph
output_path = 'gnn-depmap-project/data/processed/depmap_hetero_graph_train_smaller.pt'
torch.save(train_hetero_graph, output_path)
output_path = 'gnn-depmap-project/data/processed/depmap_hetero_graph_val_smaller.pt'
torch.save(val_hetero_graph, output_path)
output_path = 'gnn-depmap-project/data/processed/depmap_hetero_graph_test_smaller.pt'
torch.save(test_hetero_graph, output_path)

print(f"\nGraph saved")


Graph saved


In [30]:
model_list = pd.DataFrame(model['ModelID'])
model_list['train'] = model_list['ModelID'].isin(train_model_ids).astype(int)
model_list['val'] = model_list['ModelID'].isin(val_model_ids).astype(int)
model_list['test'] = model_list['ModelID'].isin(test_model_ids).astype(int)
model_list.to_csv('gnn-depmap-project/src/data/depmap 25q2 public/model_ids_train_test_split_refined.csv', index=False)