In [None]:
import os
import ast
import sys
import yaml
import pickle
import argparse
sys.path.append("./")
sys.path.append("../")

import numpy as np
import pandas as pd
import networkx as nx
import seaborn as sns
import matplotlib.pyplot as plt

from data.process import LoadDataset
from utils_old.functions import LoadConfig
from DriverGenerater import getDriver_df
from data.Augmentation import mutation_anchored_subgraphs
from utils_old.table_utils import make_bin_cols, scaling_and_fillnafeature
from utils_old.graph_utils import loadGraph, merge_graph_attributes, get_node_att_value, map_att_to_node

In [None]:
def NoteBookParse(argv=None):
    """Build the DGI run-argument parser and return the parsed arguments.

    Arguments the parser does not recognise are ignored instead of raising
    (``parse_known_args``), so extra entries in ``sys.argv`` do not abort
    the call.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. ``None`` (the default) parses
        ``sys.argv[1:]``, which is what the zero-argument call has always done.

    Returns
    -------
    argparse.Namespace
        Parsed values for every option declared below.
    """
    parser = argparse.ArgumentParser(description='Run DGI model')
    parser.add_argument('--config_path', type=str, default=os.path.join(os.getcwd(), '../config/run.yaml'), help='Path to the run configuration file')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers for data loading')
    parser.add_argument('--wandb_key', type=str, default=None, help='Wandb API key')
    parser.add_argument('--entity_name', type=str, default='shmoon', help='Wandb entity name')
    parser.add_argument('--project_name', type=str, default='DeepResidueCluster', help='Wandb project name')
    parser.add_argument('--wandb_run_name', type=str, default='DGI', help='Wandb run name')
    parser.add_argument('--wandb_run_id', type=str, default=None, help='Wandb run id')
    parser.add_argument('--load_pretrained', action='store_true', help='Load pretrained model')
    parser.add_argument('--nowandb', action='store_true', help='Do not use wandb')

    # The list of unrecognised arguments is deliberately discarded.
    args, _ = parser.parse_known_args(argv)
    return args

In [None]:
# Read the YAML run configuration and expose its top-level keys as attributes.
ConfigPATH = '../config/run.yaml'
with open(ConfigPATH, 'r') as config_file:
    model_config = yaml.safe_load(config_file)

config = argparse.Namespace(**model_config)

In [None]:
# Load the full graph; the configured path is resolved from the parent directory.
graph_path = os.path.join('../', config.Graph_PATH)
G = loadGraph(graph_path)
print(G)

# Node ids kept as a list for the membership checks further down.
nodes_in_G = list(G.nodes())

In [None]:
# One row per node: the node attributes become columns, the node id becomes `node_id`.
node_attributes = dict(G.nodes(data=True))
node_info_df = (
    pd.DataFrame.from_dict(node_attributes, orient='index')
    .reset_index()
    .rename(columns={'index': 'node_id'})
)

In [None]:
am_feature_path = os.path.join(config.Feature_PATH, 'node_features_with_location_nodeid_am_annotated_per_node_v02092026.csv')
am_df = pd.read_csv(am_feature_path)

# Preview only the rows that have a `ptms_mapped` value.
am_df.loc[am_df['ptms_mapped'].notna(), 'ptms_mapped']

In [None]:
def parse_ptm(x):
    """Convert one raw ``ptms_mapped`` cell into a list of PTM labels.

    Example: "['ac', 'for']" -> ['ac', 'for']

    Parameters
    ----------
    x : str, list, tuple, or missing value
        The cell value. Strings are parsed as Python literals.

    Returns
    -------
    list
        ``[]`` for a missing value, an empty/blank string, or ``'[]'``; a
        copy of ``x`` when it is already a list or tuple; otherwise the
        result of ``ast.literal_eval``.

    Raises
    ------
    ValueError, SyntaxError
        Propagated from ``ast.literal_eval`` when the text is not a valid
        Python literal.
    """
    # Already-parsed values pass straight through, so re-applying the
    # function to its own output is safe. This check must come before
    # pd.isna, which returns an array (not a bool) for list input.
    if isinstance(x, (list, tuple)):
        return list(x)
    if pd.isna(x):
        return []
    if isinstance(x, str):
        x = x.strip()
        if x in ('', '[]'):
            return []
    return ast.literal_eval(x)
    
am_df['ptms_parsed'] = am_df['ptms_mapped'].apply(parse_ptm)

# Sorted list of every distinct PTM label found in any row.
all_ptms = sorted({ptm for ptm_list in am_df['ptms_parsed'] for ptm in ptm_list})
print("Detected PTMs:", all_ptms)

In [None]:
# Keep every column whose name contains "dssp".
dssp_cols = list(filter(lambda name: 'dssp' in name, am_df.columns))
dssp_df = am_df.loc[:, dssp_cols]
dssp_df.head(3)

In [None]:
# Build the per-node feature table by left-joining three CSVs on `node_id`.
basic_node_df = pd.read_csv(os.path.join(config.Feature_PATH, 'node_features.csv'))
am_node_df = pd.read_csv(os.path.join(config.Feature_PATH, 'node_features_with_location_nodeid_am_annotated_per_node_v02092026.csv'))

bmr_df = pd.read_csv(os.path.join(config.Feature_PATH, 'node_mutation_with_BMR_v120525.csv'))
# NOTE(review): these columns are presumably already provided by the other
# tables - confirm. Dropped without `inplace` so re-running the cell after a
# partial failure cannot hit an already-mutated frame.
bmr_df = bmr_df.drop(columns=['total_mutations_count', 'unique_mutation_types_count', 'unique_patients_count', 'uniprot_id', 'position'])

feat_df = pd.merge(basic_node_df, am_node_df, on='node_id', how='left')
feat_df = pd.merge(feat_df, bmr_df, on='node_id', how='left')
# Assign the filled column back explicitly: `feat_df[col].fillna(..., inplace=True)`
# is chained assignment, which warns in pandas 2.2 and leaves `feat_df`
# unchanged under Copy-on-Write.
feat_df['avg_am_pathogenicity'] = feat_df['avg_am_pathogenicity'].fillna(0)

# feat_df = scaling_and_fillnafeature(feat_df, config.table_features)

In [None]:
feat_df.info()

# Split Graph Analysis

In [None]:
# Load the train / val / test / augmented-train pickles in that order.
trainG, valG, testG, AugG = map(loadGraph, (
    '../DeepResidueCluster_train.pkl',
    '../DeepResidueCluster_val.pkl',
    '../DeepResidueCluster_test.pkl',
    '../DeepResidueCluster_train_aug.pkl',
))

In [None]:
# Count the test graphs that contain at least one node whose `is_mut` is non-zero.
cnt = sum(
    any(attrs['is_mut'] != 0 for _, attrs in g.nodes(data=True))
    for g in testG
)

print("Total Graph", len(testG))
print("Mutated Graph", cnt)

# Connection (Edge) Analysis

In [None]:
# original_edge_df = pd.read_csv(edgepath)

In [None]:
# all_semicol_node_df = original_edge_df[(original_edge_df['uniprot1'].str.contains(';')) | (original_edge_df['uniprot2'].str.contains(';'))][['uniprot1', 'uniprot2', 'pdb_code', 'chain_flag']]
# both_semicol_node_df = original_edge_df[(original_edge_df['uniprot1'].str.contains(';')) & (original_edge_df['uniprot2'].str.contains(';'))][['uniprot1', 'uniprot2', 'pdb_code', 'chain_flag']]

In [None]:
# all_semicol_node_df.head(2)

In [None]:
# both_semicol_node_df['edge_key'] = list(zip(
#     np.minimum(both_semicol_node_df['uniprot1'], both_semicol_node_df['uniprot2']),
#     np.maximum(both_semicol_node_df['uniprot1'], both_semicol_node_df['uniprot2'])
# ))

# all_semicol_node_df['edge_key'] = list(zip(
#     np.minimum(all_semicol_node_df['uniprot1'], all_semicol_node_df['uniprot2']),
#     np.maximum(all_semicol_node_df['uniprot1'], all_semicol_node_df['uniprot2'])
# ))

In [None]:
# all_semicol_node_df.reset_index(drop=False, inplace=True)
# both_semicol_node_df.reset_index(drop=False, inplace=True)


In [None]:
# unique_index = all_semicol_node_df.index.difference(both_semicol_node_df.index)
# filtered_df = all_semicol_node_df.loc[unique_index]
# len(filtered_df)

In [None]:
# both_semicol_node_df.pdb_code.value_counts()

In [None]:
# all_semicol_node_df.chain_flag.value_counts()

In [None]:
# both_semicol_node_df.chain_flag.value_counts()

<!-- ## Graph Connection -->

In [None]:
graph_connections = nx.to_pandas_edgelist(G)
graph_connections

In [None]:
# Reduce both endpoints to the text before the first "_" in the node id.
for endpoint in ('source', 'target'):
    graph_connections[endpoint] = graph_connections[endpoint].apply(lambda node_id: node_id.split('_')[0])

# Order-independent key per edge: (smaller endpoint, larger endpoint).
g_node1 = np.minimum(graph_connections['source'], graph_connections['target'])
g_node2 = np.maximum(graph_connections['source'], graph_connections['target'])
graph_connections['edge_key'] = list(zip(g_node1, g_node2))
graph_connections.drop_duplicates(subset=['edge_key'], inplace=True)

# Edges with a ";" in either endpoint: target matches first, then source
# matches. An edge with ";" in both endpoints therefore appears twice.
target_has_semicol = graph_connections.target.str.contains(';')
source_has_semicol = graph_connections.source.str.contains(';')
only_semicol_in_graph = pd.concat([graph_connections[target_has_semicol], graph_connections[source_has_semicol]])

In [None]:
only_semicol_in_graph

In [None]:
# Attach `pdb_code` and `chain_flag` to each ";"-containing graph edge by
# joining on the order-independent `edge_key`.
# NOTE(review): `all_semicol_node_df` is only assigned in a commented-out cell
# earlier in this notebook (which in turn needs the commented-out
# `original_edge_df` load), so on a fresh kernel this cell raises NameError
# until those cells are restored.
merged_df = pd.merge(
    only_semicol_in_graph,
    all_semicol_node_df[['edge_key', 'pdb_code', 'chain_flag']],
    on='edge_key',
    how='left'
)
# Keep the first row for each edge_key.
merged_df.drop_duplicates(subset=['edge_key'], inplace=True)

In [None]:
# Hand-set the annotation of the first row.
# NOTE(review): the source of these two values is not recorded in this notebook - confirm.
first_row = merged_df.index[0]
merged_df.at[first_row, 'pdb_code'] = '5z23'
merged_df.at[first_row, 'chain_flag'] = 'inter-chain'

In [None]:
merged_df

<!-- ## Remove Unrelated Nodes -->

In [None]:
# removed_nodes = ['q92522;p10412', 'p06899;p04908', 'q92522;p10412', 'p49450;p68431', 'p49450;p68431-1']

In [None]:
# newGraphPATH =
# nodes_to_delete = [
#     n for n in G.nodes() 
#     if any(pattern in n for pattern in removed_nodes)
# ]

# G.remove_nodes_from(nodes_to_delete)

# with open(newGraphPATH, 'wb') as f:
#     pickle.dump(G, f)

In [None]:
# print(G)

In [None]:
# 

In [None]:
# 

# Graph Analysis

In [None]:
CCs = [G.subgraph(c).copy() for c in nx.connected_components(G)]

In [None]:
# Classify the connected components that consist of a single protein.
single_cc = []            # protein id of each single-protein component
non_mut_cc = []           # indices into CCs: single-protein, `is_mut` sums to 0
non_mut_onlycopy_cc = []  # subset of non_mut_cc where `from_copy` sums to the node count
non_mut_noncopy_cc = []   # subset of non_mut_cc where `from_copy` sums to 0
has_mut_cc = []           # indices into CCs: single-protein, `is_mut` sum is non-zero

for i, cc in enumerate(CCs):
    # Protein id = text before the first "_" of the node id. Reading it from
    # the nodes rather than the edge list also covers single-node components,
    # which have no edges and were previously never counted.
    prot = {node.split('_')[0] for node in cc.nodes()}
    if len(prot) != 1:
        continue
    single_cc.extend(prot)

    attG = map_att_to_node(cc, feat_df, use_cols=['from_copy', 'is_mut'])
    sum_mut = sum(get_node_att_value(attG, 'is_mut'))
    if sum_mut != 0:
        has_mut_cc.append(i)
        continue

    non_mut_cc.append(i)
    sum_copy = sum(get_node_att_value(attG, 'from_copy'))
    # Components with a mix of copied and non-copied nodes land in neither list.
    if sum_copy == attG.number_of_nodes():
        non_mut_onlycopy_cc.append(i)
    elif sum_copy == 0:
        non_mut_noncopy_cc.append(i)

In [None]:
def _print_share(label, members):
    """Print `label`, the size of `members`, and that size as a percentage of all components."""
    print(label, len(members), "(", round((len(members)/len(CCs))*100, 2),'%)')


print("Total Number of Chain", len(CCs))
print("-----------------")
_print_share("Single Protein CCs", single_cc)
print("-----------------")
_print_share("Mut CCs among one-chain only", has_mut_cc)
_print_share("Non-Mut CCs among one-chain only", non_mut_cc)
print("-----------------")
_print_share("Non-Mut only NonCopy CCs", non_mut_noncopy_cc)
_print_share("Non-Mut Only Copy CCs", non_mut_onlycopy_cc)

In [None]:
# The helper imported from utils_old.graph_utils is `get_node_att_value`;
# the previous spelling `get_node_attr_value` was undefined (NameError).
get_node_att_value(G, 'degree')

# Node Features from Table

In [None]:
# df1 = 
# df2 = 
# df3 = 
# df2 = df2[df2['node_id'].isin(df1['node_id'])] 
# df3 = df3[df3['node_id'].isin(df1['node_id'])]

In [None]:
# NOTE(review): df1, df2 and df3 are only assigned in the commented-out cell
# just above (with empty right-hand sides), so on a fresh kernel this cell and
# the later cells that use df1/df3 raise NameError until that cell is restored.
df3.columns

In [None]:
df1.info()

In [None]:
df1['hugo_symbol']

In [None]:
use_feat_in_df1 = ['BHAR880101', 'CHOP780201', 'GRAR740102', 'JANJ780101', 'KLEP840101', 'KYTJ820101']

# One histogram per feature, filled row by row into a 2x3 grid.
fig, axs = plt.subplots(2, 3, figsize=(15, 6))

for ax, feat in zip(axs.ravel(), use_feat_in_df1):
    values = df1[feat]
    ax.hist(values)
    ax.set_title(f'{feat} (Unique: {values.unique().shape[0]})')
    ax.set_xlabel('Value')
    ax.set_ylabel('Frequency')

plt.tight_layout()
plt.show()

In [None]:
mut_use_feat_in_df1 = ['unique_patients_count', 'total_mutations_count', 'unique_mutation_types_count', 'DAYM780301_avg', 'HENS920102_avg']

# One log-scaled histogram per feature, filled row by row into a 2x3 grid.
fig, axs = plt.subplots(2, 3, figsize=(15, 6))
panels = axs.ravel()

for ax, feat in zip(panels, mut_use_feat_in_df1):
    ax.hist(df1[feat])
    ax.set_title(f"{feat} (Unique: {df1[feat].nunique()})")
    ax.set_xlabel('Value')
    ax.set_yscale('log')
    ax.set_ylabel('Frequency (log)')

# Five features on six panels: hide the leftover panel instead of drawing an
# empty frame.
for ax in panels[len(mut_use_feat_in_df1):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

In [None]:
# Assign the filled column back explicitly: `df1.copyindex.fillna(..., inplace=True)`
# is chained assignment, which warns in pandas 2.2 and leaves `df1` unchanged
# under Copy-on-Write.
df1['copyindex'] = df1['copyindex'].fillna(0)

# Node ids present in the feature table, as a set for fast membership tests.
nodes_in_df = df1.node_id.values
target_set = set(nodes_in_df)

In [None]:
# Graph nodes that have no row in the feature table.
skip_nodes = list(filter(lambda node: node not in target_set, nodes_in_G))
skip_nodes

In [None]:
df3[df3.node_id.str.contains(';')].uniprot_id.unique()

In [None]:
df3[df3['node_id'].isin(df1['node_id'])]

# Connected Components Analysis

In [None]:
# Component sizes of the full graph, largest first. Each component is a set
# of nodes, so its length is the node count; no subgraph view is needed.
num_node_in_originG = [len(component) for component in nx.connected_components(G)]
num_node_in_originG = np.sort(num_node_in_originG)[::-1]

In [None]:
# Key order here fixes the iteration order used by the cells below.
cc_dict = dict.fromkeys(['train', 'train_aug', 'val', 'test'])

# NOTE: pickle.load can execute arbitrary code; only open files produced by this project.
for name in ('train', 'val', 'test', 'train_aug'):
    with open(f'../DeepResidueCluster_{name}.pkl', 'rb') as pkl_file:
        cc_dict[name] = pickle.load(pkl_file)

In [None]:
list(cc_dict['train'][0].nodes(data=True))[0]

In [None]:
# Per split: node count and summed `is_mut` for every graph, in matching order.
split_names = ('train', 'train_aug', 'val', 'test')
num_dict = {split: [] for split in split_names}
mut_dict = {split: [] for split in split_names}

for name, graphs in cc_dict.items():
    for g in graphs:
        num_dict[name].append(g.number_of_nodes())
        mut_dict[name].append(sum(get_node_att_value(g, 'is_mut')))

    # Graphs in this split with a positive mutation sum.
    cnt = sum(1 for mut_cnt in mut_dict[name] if mut_cnt > 0)
    print(f"Total number of {name} graph included mutated Node", cnt)

## Mutation

In [None]:
# Histogram of per-graph mutation counts, one panel per split.
fig, axs = plt.subplots(2, 2, figsize=(8, 5))
for ax, (name, mut_list) in zip(axs.flat, mut_dict.items()):
    # sorted() works on a copy; an in-place sort would reorder mut_dict and
    # break its graph-by-graph alignment with num_dict.
    mut_data = sorted(mut_list)
    if 'train' in name:
        # Matches both 'train' and 'train_aug': leave out the two largest counts.
        # NOTE(review): presumably outlier graphs that would flatten the
        # histogram - confirm.
        ax.hist(mut_data[:-2])
    else:
        ax.hist(mut_data)
    ax.set_title(f'{name} ({len(mut_data)})')
    ax.set_xlabel('Mutation Count')
    ax.set_ylabel('Number of Graphs')
    ax.set_yscale('log')

plt.tight_layout()
plt.show()


## Number of Nodes

In [None]:
# Histogram of per-graph node counts, one panel per split.
fig, axs = plt.subplots(2, 2, figsize=(8, 5))
for ax, (name, node_counts) in zip(axs.flat, num_dict.items()):
    # sorted() works on a copy; an in-place sort would reorder num_dict and
    # break its graph-by-graph alignment with mut_dict.
    node_data = sorted(node_counts)
    if 'train' in name:
        # Matches both 'train' and 'train_aug': leave out the two largest counts.
        # NOTE(review): presumably outlier graphs that would flatten the
        # histogram - confirm.
        ax.hist(node_data[:-2])
    else:
        ax.hist(node_data)
    ax.set_title(name)
    ax.set_xlabel('Node Count')
    ax.set_ylabel('Number of Graphs')
    ax.set_yscale('log')

plt.tight_layout()
plt.show()


# Cancer Driver

In [None]:
# Reference tables, keyed by the name later passed as `reference_data`.
PATHDict = {
    'MutaGene': './reference/MutaGene_Benchmark.csv',
    'COSMIC': './reference/CosmicMutantExport.tsv.gz',
    'ChemPlus': './reference/CHASMplus.xlsx',
}

# Individual path constants, kept for direct use.
MutaGenePATH = PATHDict['MutaGene']
COSMICPATH = PATHDict['COSMIC']
CHEMPATH = PATHDict['ChemPlus']

In [None]:
# One driver table per reference source; only COSMIC is given a score threshold.
mutagene = getDriver_df(PATHDict, feat_df, reference_data='MutaGene', score_th=None)
cosmic = getDriver_df(PATHDict, feat_df, reference_data='COSMIC', score_th=0.8)
chemplus = getDriver_df(PATHDict, feat_df, reference_data='ChemPlus', score_th=None)

In [None]:
# Stack all three sources and keep the first row for each distinct key combination.
non_dup_df = pd.concat([cosmic, chemplus, mutagene], axis=0).drop_duplicates(
    subset=['position', 'residueType', 'node_id', 'mutability', 'is_driver']
)
non_dup_df['is_driver'].value_counts()

In [None]:
# Same de-duplication as for all sources, restricted to COSMIC and ChemPlus.
trainable_driver_df = pd.concat([cosmic, chemplus], axis=0).drop_duplicates(
    subset=['position', 'residueType', 'node_id', 'mutability', 'is_driver']
)
trainable_driver_df['is_driver'].value_counts()