In [1]:
import torch
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from matplotlib import pyplot as plt
import inspect, random, os
import numpy as np
import pandas as pd
from rdkit.Chem import AllChem
from rdkit import Chem
from multiprocessing import Pool
from rdkit.Chem import Draw, PandasTools
from torch_geometric.loader import DataLoader
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, precision_score, recall_score, roc_auc_score, average_precision_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import minmax_scale
from torch_geometric.nn.aggr import AttentionalAggregation
import torch
from modules.som_dataset import CustomDataset
# from models.som_models import GNNSOM
from modules.som_models import GNNSOM
from utils import validation
from tabulate import tabulate
from tqdm import tqdm
import numpy as np
from scipy import stats
from collections import defaultdict
from glob import glob
import json
import warnings
warnings.filterwarnings('ignore', '.*Sparse CSR tensor support is in beta state.*')

In [2]:
import json

def save_json(json_path, dict_file):
    """Serialize ``dict_file`` as JSON into ``json_path``, overwriting any existing file."""
    with open(json_path, 'w') as handle:
        json.dump(dict_file, fp=handle)

def load_json(json_path):
    """Parse the JSON document stored at ``json_path`` and return the resulting object."""
    with open(json_path, "r") as handle:
        return json.load(handle)

In [3]:
def confidence_interval(data, confidence=0.95, floor=0):
    """Student-t confidence interval for the mean of ``data``.

    Args:
        data: 1-D sequence of numbers (e.g. one metric's value across seeds).
        confidence: two-sided confidence level of the interval.
        floor: lower clip applied to all three returned values. The default ``0``
            reproduces the original behaviour (useful for non-negative metrics);
            pass ``None`` to get the unclipped interval.

    Returns:
        ``(lower_bound, mean, upper_bound)``. With fewer than two values the
        standard error is undefined and the bounds are NaN (NaN survives the
        clip because ``max(nan, floor)`` returns its first argument).
    """
    n = len(data)
    m = np.mean(data)
    se = stats.sem(data)
    # Half-width of the interval: standard error times the two-sided t critical value.
    h = se * stats.t.ppf((1 + confidence) / 2., n - 1)

    bounds = (m - h, m, m + h)
    if floor is None:
        return bounds
    # Keep `value` as the first argument so a NaN value is propagated, as before.
    return tuple(max(value, floor) for value in bounds)

In [4]:
def to_csv(path, table_data, headers):
    """Write ``table_data`` rows with ``headers`` to ``path`` as a comma-separated file.

    The table is first rendered by ``tabulate`` in TSV format and written to
    ``path``. That file is then re-read with pandas (``sep='\\t'``) and overwritten
    as a CSV without an index column.
    """
    score_logs = tabulate(table_data, headers, tablefmt="tsv")
    # The context manager closes and flushes the handle before pandas re-reads the
    # file, and still closes it if write() raises. The original bare open() leaked
    # the handle in that case.
    with open(path, "w") as text_file:
        text_file.write(score_logs)
    df = pd.read_csv(path, sep='\t')
    df.to_csv(path, index=None)

def _get_logs_cell(scores, cyp, header, th):
    """Look up the value of one table cell for ``get_logs``."""
    # Count ('n_*') and loss columns are stored directly under scores[cyp];
    # every other metric is stored under the threshold key scores[cyp][th].
    if 'loss' in header or header[:2] == 'n_':
        return scores[cyp][header]
    return scores[cyp][th][header]


def get_logs(scores, cyp_list, args):
    """Render per-CYP scores as four grid-formatted text tables.

    Args:
        scores: nested dict. Count/loss columns are read from ``scores[cyp][header]``
            and metric columns from ``scores[cyp][args.th][header]``.
        cyp_list: CYP keys; each one becomes a row, in list order.
        args: config object; only ``args.th`` is read.

    Returns:
        The substrate, bond/SPN/SOM, reaction-type and count tables, joined by
        newlines, with floats formatted to 4 decimals.
    """
    # The headers are defined once, outside the row loop. The original defined them
    # inside `for cyp in cyp_list`, so an empty cyp_list raised UnboundLocalError.
    header_groups = (
        ['CYP',
         'auc_subs', 'apc_subs', 'f1s_subs', 'rec_subs', 'prc_subs', 'n_subs'],
        ['CYP',
         'jac_bond', 'f1s_bond', 'apc_bond', 'n_bond',
         'jac_spn', 'f1s_spn', 'apc_spn', 'n_spn',
         'jac_som', 'f1s_som', 'apc_som', 'n_som'],
        ['CYP',
         'jac_oxi', 'f1s_oxi', 'apc_oxi', 'n_oxi',
         'jac_clv', 'f1s_clv', 'apc_clv', 'n_clv',
         'jac_hdx', 'f1s_hdx', 'apc_hdx', 'n_hdx',
         'jac_rdc', 'f1s_rdc', 'apc_rdc', 'n_rdc'],
        ['CYP',
         'n_subs', 'n_bond', 'n_spn', 'n_hdx', 'n_oxi', 'n_clv', 'n_rdc', 'n_som'],
    )

    tables = []
    for headers in header_groups:
        rows = [
            [cyp] + [_get_logs_cell(scores, cyp, header, args.th) for header in headers[1:]]
            for cyp in cyp_list
        ]
        tables.append(tabulate(rows, headers, tablefmt="grid", floatfmt=".4f"))

    return '\n'.join(tables)

In [5]:
def CYP_REACTION(x):
    """Join the non-empty per-isoform BOM annotations of one row with newlines.

    Meant to be used as ``df.apply(CYP_REACTION, axis=1)``. ``x`` is one row
    (a pandas Series) that contains the nine ``BOM_*`` columns listed below.
    Empty strings and missing (NaN/None) cells are skipped.
    """
    cyp_col = ['BOM_1A2', 'BOM_2A6', 'BOM_2B6', 'BOM_2C8', 'BOM_2C9', 'BOM_2C19', 'BOM_2D6', 'BOM_2E1', 'BOM_3A4']
    # A missing cell (NaN) is a truthy float, so a plain `if i` filter would keep it
    # and '\n'.join would raise TypeError. Drop missing values explicitly.
    cyp_reactions = [i for i in x[cyp_col].tolist() if pd.notna(i) and i]
    return '\n'.join(cyp_reactions)

In [6]:
# Metric names logged per (CYP, seed, threshold).
# Substrate metrics come first and have no Jaccard entry. Every other task gets the
# same six metrics, named '<metric>_<task>'.
metrics = ['auc_subs', 'apc_subs', 'f1s_subs', 'rec_subs', 'prc_subs'] + [
    f'{metric_name}_{task}'
    for task in ('bond', 'spn', 'hdx', 'oxi', 'clv', 'rdc', 'som')
    for metric_name in ('jac', 'f1s', 'prc', 'rec', 'auc', 'apc')
]


In [7]:
# The nine per-isoform BOM targets, plus the combined 'CYP_REACTION' column.
cyp_list = ['BOM_' + isoform for isoform in ('1A2', '2A6', '2B6', '2C8', '2C9', '2C19', '2D6', '2E1', '3A4')]
cyp_list.append('CYP_REACTION')

In [8]:
class CONFIG:
    """Plain attribute bag of run settings; its single instance is used as ``args``."""

    # Loss weights per output head.
    substrate_loss_weight = 0.33
    bond_loss_weight = 0.33
    atom_loss_weight = 0.33
    som_type_loss_weight = 0.33
    class_type = 2
    # Decision thresholds: `th` is used for every non-substrate task in cal_scores,
    # `substrate_th` for the substrate task.
    th = 0.1
    substrate_th = 0.5
    adjust_substrate = False
    average = 'binary'  # forwarded as `average=` to validator.get_scores
    equivalent_bonds_mean = True
    train_only_spn_H_atom = False
    device = 'cuda:0'
    test_only_reaction_mol = False
    # Node-level drop / mask probabilities.
    drop_node_p = 0.0
    mask_node_p = 0.0
    filt_som = 0
    equivalent_mean = False
    reduction = 'sum'
    n_classes = 5


args = CONFIG()



In [9]:
# GNN site-of-metabolism model, configured for the CYP targets in `cyp_list`.
model = GNNSOM(
    num_layers=2,
    gnn_num_layers=8,
    pooling='sum',
    dropout=0.1,
    cyp_list=cyp_list,
    use_face=True,
    node_attn=True,
    face_attn=True,
    n_classes=args.n_classes,
    use_som_v2=True,
).to(args.device)  # use the configured device instead of re-hard-coding 'cuda:0'

In [10]:
# Training set: load the SDF and add the combined-reaction and position-ID columns.
df = PandasTools.LoadSDF('data/train_nonreact_0611.sdf')
df['CYP_REACTION'] = df.apply(CYP_REACTION, axis=1)
df['POS_ID'] = 'TRAIN' + df.index.astype(str).str.zfill(4)  # e.g. TRAIN0000, TRAIN0001, ...

# Training dataset / loader construction is currently disabled:
# train_dataset = CustomDataset(df=df, class_type=2, cyp_list=cyp_list, args=args, add_H=True)
# train_loader = DataLoader(train_dataset, num_workers=2, batch_size=16, shuffle=False)

In [11]:
# Test set: load the SDF, add the derived columns, and wrap it in a DataLoader.
test_df = PandasTools.LoadSDF('data/test_0611.sdf')
test_df['CYP_REACTION'] = test_df.apply(CYP_REACTION, axis=1)
test_df['POS_ID'] = 'TEST' + test_df.index.astype(str).str.zfill(4)  # e.g. TEST0000, ...

# test_df = test_df[test_df['InChIKey'] != ''].reset_index(drop=True)

test_dataset = CustomDataset(df=test_df, cyp_list=cyp_list, args=args)
test_loader = DataLoader(
    test_dataset,
    num_workers=2,
    batch_size=8,
    shuffle=False,  # no shuffling: batches come in dataset order
)

# Loss functions passed to validation(...).
loss_fn_ce = torch.nn.CrossEntropyLoss()
loss_fn_bce = torch.nn.BCEWithLogitsLoss()

Processing...
Done!


In [12]:
# Loss functions passed to validation(...). This re-creates the same two objects
# that the test-data cell already builds.
loss_fn_ce = torch.nn.CrossEntropyLoss()
loss_fn_bce = torch.nn.BCEWithLogitsLoss()

### Scores with confidence intervals (신뢰구간별)

In [13]:
def cal_scores(validator, args, cyp_list=None):
    """Compute per-CYP, per-task metrics from a validator.

    Args:
        validator: the object returned under ``'validator'`` by ``validation(...)``.
            Only its ``adjust_substrate``, ``unbatch``, ``eq_mean``, ``get_probs`` and
            ``get_scores`` methods are used here.
        args: config with ``adjust_substrate``, ``substrate_th``, ``equivalent_mean``,
            ``average`` and ``th`` attributes.
        cyp_list: CYP keys to score. Defaults to the global ``model.cyp_list`` (the
            original behaviour). Pass it explicitly to avoid depending on notebook state.

    Returns:
        ``{cyp: {'n_<task>': '<sum of y_true> / <len of y_true>', '<metric>_<task>': value}}``.
        The substrate task ('subs') is scored at ``args.substrate_th``; every other
        task is scored at ``args.th``.

    NOTE(review): when ``args.adjust_substrate`` or ``args.equivalent_mean`` is set, the
    corresponding validator methods run on every call. Confirm they are idempotent
    before calling this repeatedly on the same validator (e.g. once per threshold).
    """
    scores = {}

    if args.adjust_substrate:
        validator.adjust_substrate(args.substrate_th)

    tasks = ['subs', 'bond', 'atom', 'spn', 'hdx', 'clv', 'oxi', 'rdc', 'som']
    # Paired positionally (via zip) with the values returned by validator.get_scores.
    # Named `metric_names` so it does not shadow the notebook-level `metrics` list.
    metric_names = ['jac', 'f1s', 'prc', 'rec', 'auc', 'apc']

    if args.equivalent_mean:
        validator.unbatch()
        validator.eq_mean()

    if cyp_list is None:
        cyp_list = model.cyp_list

    for cyp in cyp_list:
        scores[cyp] = {}

        for task in tasks:
            y_true, y_prob = validator.get_probs(task, cyp)
            scores[cyp][f'n_{task}'] = f'{int(sum(y_true))} / {len(y_true)}'

            th = args.substrate_th if task == 'subs' else args.th
            task_scores = validator.get_scores(task=task, cyp=cyp, average=args.average, th=th)

            for mname, tscore in zip(metric_names, task_scores):
                scores[cyp][f'{mname}_{task}'] = tscore

    return scores

In [14]:
# args.ckpt = f'ckpt/0_reduction_sum_v2_dout01.pt'

In [15]:
# Evaluate each of the 30 seed checkpoints on the test set at several thresholds and
# save one long-format CSV per seed (columns: cyp, seed, metric, score, threshold).
import os

# DataFrame.to_csv does not create missing directories.
os.makedirs('infer/seed', exist_ok=True)

for seed in tqdm(range(30)):
    args.ckpt = f'ckpt/{seed}.pt'
    # args.ckpt = f'ckpt/29_all_bond_warmup.pt'

    model.load_state_dict(torch.load(args.ckpt, 'cpu'))

    args.add_H = True

    seed_records = []

    # validation(...) is called once per checkpoint; its validator is reused for
    # every threshold below.
    test_scores = validation(model, test_loader, loss_fn_ce, loss_fn_bce, args)
    validator = test_scores['validator']
    for th in [0.1, 0.15, 0.2, 0.3]:
        args.th = th
        scores = cal_scores(validator, args)
        for cyp in cyp_list:
            for metric in metrics:
                seed_records.append({'cyp': cyp, 'seed': seed, 'metric': metric,
                                     'score': scores[cyp][metric], 'threshold': th})

    seed_df = pd.DataFrame(seed_records)
    seed_df.to_csv(f'infer/seed/{seed}.csv', index=None)
    

100%|██████████| 30/30 [03:55<00:00,  7.84s/it]


In [16]:
# Reload the per-seed evaluation tables written to infer/seed/ for seeds 0-29.
seed_df_dict = {
    seed: pd.read_csv(f'infer/seed/{seed}.csv', index_col=None)
    for seed in range(30)
}
    

In [17]:
# (display name, metric keys) for each reported task.
# The keys must match the '<metric>_<task>' / 'n_<task>' names written by cal_scores
# and listed in `metrics`. The Oxidation and Reduction entries previously used
# mismatched keys ('jac__oxi', 'jac_nn_rdc', ...).
metrics_type = [
    ('substrate', ['auc_subs', 'apc_subs', 'f1s_subs', 'rec_subs', 'prc_subs', 'n_subs']),
    ('bond', ['jac_bond', 'f1s_bond', 'prc_bond', 'rec_bond', 'auc_bond', 'apc_bond', 'n_bond']),
    ('SPN-Oxidation', ['jac_spn', 'f1s_spn', 'prc_spn', 'rec_spn', 'auc_spn', 'apc_spn', 'n_spn']),
    ('Hydroxylation', ['jac_hdx', 'f1s_hdx', 'prc_hdx', 'rec_hdx', 'auc_hdx', 'apc_hdx', 'n_hdx']),
    ('Oxidation', ['jac_oxi', 'f1s_oxi', 'prc_oxi', 'rec_oxi', 'auc_oxi', 'apc_oxi', 'n_oxi']),
    ('Cleavage', ['jac_clv', 'f1s_clv', 'prc_clv', 'rec_clv', 'auc_clv', 'apc_clv', 'n_clv']),
    ('Reduction', ['jac_rdc', 'f1s_rdc', 'prc_rdc', 'rec_rdc', 'auc_rdc', 'apc_rdc', 'n_rdc']),
    ('Site-of-metabolism', ['jac_som', 'f1s_som', 'prc_som', 'rec_som', 'auc_som', 'apc_som', 'n_som']),
]


In [24]:
# Confidence-interval performance over the 30 seeds
import copy

th = 0.15
args.th = th

# `scores` is left over from the last iteration of the evaluation loop; it supplies
# the 'n_*' count columns that get_logs reads.
# A deep copy is required: scores[cyp] is a nested dict, and with a shallow .copy()
# the `ci_scores[cyp][th] = {}` assignment below also wrote into `scores`.
ci_scores = copy.deepcopy(scores)

# Filter each seed's table to this threshold once, instead of inside the inner loops.
seed_th_dfs = [seed_df.query('threshold==@th') for seed_df in seed_df_dict.values()]

for cyp in tqdm(cyp_list):
    ci_scores[cyp][th] = {}
    for metric in metrics:
        # .item() raises unless exactly one row matches, which guards against duplicates.
        score_value = [
            df_th.query('cyp==@cyp and metric==@metric').score.item()
            for df_th in seed_th_dfs
        ]
        lower, mean, upper = confidence_interval(score_value)
        ci_scores[cyp][th][metric] = mean
        # ci_scores[cyp][th][metric] = f'{mean:.4f}({lower:.2f}-{upper:.2f})'
print(get_logs(ci_scores, cyp_list, args))

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:22<00:00,  2.22s/it]

+--------------+------------+------------+------------+------------+------------+-----------+
| CYP          |   auc_subs |   apc_subs |   f1s_subs |   rec_subs |   prc_subs | n_subs    |
| BOM_1A2      |     0.7396 |     0.5233 |     0.5188 |     0.5758 |     0.4805 | 33 / 143  |
+--------------+------------+------------+------------+------------+------------+-----------+
| BOM_2A6      |     0.7392 |     0.3222 |     0.0300 |     0.0176 |     0.1394 | 17 / 143  |
+--------------+------------+------------+------------+------------+------------+-----------+
| BOM_2B6      |     0.6336 |     0.2322 |     0.0153 |     0.0100 |     0.0345 | 20 / 143  |
+--------------+------------+------------+------------+------------+------------+-----------+
| BOM_2C8      |     0.5750 |     0.2483 |     0.0322 |     0.0218 |     0.1231 | 29 / 143  |
+--------------+------------+------------+------------+------------+------------+-----------+
| BOM_2C9      |     0.6114 |     0.3515 |     0.3146 |     




In [23]:
# Save the seed-averaged score of every (cyp, metric) pair, one CSV per threshold (for plotting).
import os

# DataFrame.to_csv does not create missing directories.
os.makedirs('scores', exist_ok=True)

for th in [0.1, 0.15, 0.2, 0.3]:
    # Filter each seed's table to this threshold once, instead of inside the inner loops.
    seed_th_dfs = [seed_df.query('threshold==@th') for seed_df in seed_df_dict.values()]
    score_df = []

    for cyp in cyp_list:
        for metric in metrics:
            score_value = [
                df_th.query('cyp==@cyp and metric==@metric').score.item()
                for df_th in seed_th_dfs
            ]
            _, score, _ = confidence_interval(score_value)
            score_df.append({'cyp': cyp, 'metric': metric, 'score': score})
    score_df = pd.DataFrame(score_df)
    score_df.to_csv(f'scores/GNN-CYPSOM_{th}.csv', index=None)

In [None]:
# +----------+------------+------------+------------+------------+-----------+-----------+-----------+----------+-----------+-----------+-----------+------------+
# | CYP      |   jac_bond |   f1s_bond |   apc_bond | n_bond     |   jac_spn |   f1s_spn |   apc_spn | n_spn    |   jac_som |   f1s_som |   apc_som | n_som      |
# +==========+============+============+============+============+===========+===========+===========+==========+===========+===========+===========+============+
# | BOM_1A2  |     0.1088 |     0.1962 |     0.0745 | 106 / 5772 |    0.3636 |    0.5333 |    0.3882 | 5 / 399  |    0.1200 |    0.2143 |    0.0871 | 111 / 6171 |
# +----------+------------+------------+------------+------------+-----------+-----------+-----------+----------+-----------+-----------+-----------+------------+
# | BOM_2A6  |     0.0897 |     0.1647 |     0.0397 | 40 / 5772  |    0.4286 |    0.6000 |    0.4429 | 3 / 399  |    0.1176 |    0.2105 |    0.0652 | 43 / 6171  |
# +----------+------------+------------+------------+------------+-----------+-----------+-----------+----------+-----------+-----------+-----------+------------+
# | BOM_2B6  |     0.0519 |     0.0988 |     0.0281 | 49 / 5772  |    0.2727 |    0.4286 |    0.2372 | 5 / 399  |    0.0667 |    0.1250 |    0.0434 | 54 / 6171  |
# +----------+------------+------------+------------+------------+-----------+-----------+-----------+----------+-----------+-----------+-----------+------------+
# | BOM_2C8  |     0.0783 |     0.1453 |     0.0458 | 91 / 5772  |    0.2000 |    0.3333 |    0.1644 | 3 / 399  |    0.0837 |    0.1545 |    0.0467 | 94 / 6171  |
# +----------+------------+------------+------------+------------+-----------+-----------+-----------+----------+-----------+-----------+-----------+------------+
# | BOM_2C9  |     0.0930 |     0.1702 |     0.0585 | 82 / 5772  |    0.5455 |    0.7059 |    0.4431 | 7 / 399  |    0.1150 |    0.2063 |    0.0776 | 89 / 6171  |
# +----------+------------+------------+------------+------------+-----------+-----------+-----------+----------+-----------+-----------+-----------+------------+
# | BOM_2C19 |     0.1045 |     0.1892 |     0.0619 | 79 / 5772  |    0.4167 |    0.5882 |    0.5252 | 7 / 399  |    0.1221 |    0.2176 |    0.0829 | 86 / 6171  |
# +----------+------------+------------+------------+------------+-----------+-----------+-----------+----------+-----------+-----------+-----------+------------+
# | BOM_2D6  |     0.1461 |     0.2549 |     0.1762 | 133 / 5772 |    0.6364 |    0.7778 |    0.7828 | 8 / 399  |    0.1655 |    0.2840 |    0.2078 | 141 / 6171 |
# +----------+------------+------------+------------+------------+-----------+-----------+-----------+----------+-----------+-----------+-----------+------------+
# | BOM_2E1  |     0.0702 |     0.1311 |     0.0793 | 47 / 5772  |    0.2000 |    0.3333 |    0.1925 | 4 / 399  |    0.0806 |    0.1493 |    0.0888 | 51 / 6171  |
# +----------+------------+------------+------------+------------+-----------+-----------+-----------+----------+-----------+-----------+-----------+------------+
# | BOM_3A4  |     0.1531 |     0.2655 |     0.1742 | 276 / 5772 |    0.3889 |    0.5600 |    0.4039 | 14 / 399 |    0.1631 |    0.2805 |    0.1814 | 290 / 6171 |
# +----------+------------+------------+------------+------------+-----------+-----------+-----------+----------+-----------+-----------+-----------+------------+

In [None]:
# +----------+-----------+-----------+-----------+------------+-----------+-----------+-----------+------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+----------+
# | CYP      |   jac_hdx |   f1s_hdx |   apc_hdx | n_hdx      |   jac_oxi |   f1s_oxi |   apc_oxi | n_oxi      |   jac_clv |   f1s_clv |   apc_clv | n_clv     |   jac_rdc |   f1s_rdc |   apc_rdc | n_rdc    |
# +==========+===========+===========+===========+============+===========+===========+===========+============+===========+===========+===========+===========+===========+===========+===========+==========+
# | BOM_1A2  |    0.0959 |    0.1750 |    0.1147 | 56 / 1995  |    0.0893 |    0.1639 |    0.0500 | 43 / 5772  |    0.1208 |    0.2156 |    0.0841 | 37 / 3777 |    0.0000 |    0.0000 |    0.0005 | 2 / 3777 |
# +----------+-----------+-----------+-----------+------------+-----------+-----------+-----------+------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+----------+
# | BOM_2A6  |    0.1429 |    0.2500 |    0.1487 | 17 / 1995  |    0.0833 |    0.1538 |    0.0643 | 19 / 5772  |    0.0833 |    0.1538 |    0.0398 | 21 / 3777 |   -1.0000 |    0.0000 |   -0.0000 | 0 / 3777 |
# +----------+-----------+-----------+-----------+------------+-----------+-----------+-----------+------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+----------+
# | BOM_2B6  |    0.1111 |    0.2000 |    0.1038 | 24 / 1995  |    0.0588 |    0.1111 |    0.0267 | 22 / 5772  |    0.0309 |    0.0600 |    0.0096 | 17 / 3777 |   -1.0000 |    0.0000 |   -0.0000 | 0 / 3777 |
# +----------+-----------+-----------+-----------+------------+-----------+-----------+-----------+------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+----------+
# | BOM_2C8  |    0.0676 |    0.1266 |    0.0730 | 59 / 1995  |    0.0714 |    0.1333 |    0.0329 | 31 / 5772  |    0.0902 |    0.1654 |    0.0642 | 26 / 3777 |   -1.0000 |    0.0000 |   -0.0000 | 0 / 3777 |
# +----------+-----------+-----------+-----------+------------+-----------+-----------+-----------+------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+----------+
# | BOM_2C9  |    0.1311 |    0.2319 |    0.1181 | 46 / 1995  |    0.0698 |    0.1304 |    0.0295 | 31 / 5772  |    0.0833 |    0.1538 |    0.0623 | 23 / 3777 |    0.0000 |    0.0000 |    0.0008 | 3 / 3777 |
# +----------+-----------+-----------+-----------+------------+-----------+-----------+-----------+------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+----------+
# | BOM_2C19 |    0.1200 |    0.2143 |    0.1058 | 43 / 1995  |    0.0882 |    0.1622 |    0.0747 | 31 / 5772  |    0.1181 |    0.2113 |    0.1049 | 27 / 3777 |    0.0000 |    0.0000 |    0.0005 | 2 / 3777 |
# +----------+-----------+-----------+-----------+------------+-----------+-----------+-----------+------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+----------+
# | BOM_2D6  |    0.0851 |    0.1569 |    0.1006 | 77 / 1995  |    0.0690 |    0.1290 |    0.0687 | 48 / 5772  |    0.1597 |    0.2754 |    0.2597 | 45 / 3777 |    0.0000 |    0.0000 |    0.0008 | 3 / 3777 |
# +----------+-----------+-----------+-----------+------------+-----------+-----------+-----------+------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+----------+
# | BOM_2E1  |    0.1034 |    0.1875 |    0.1498 | 21 / 1995  |    0.0286 |    0.0556 |    0.0378 | 24 / 5772  |    0.0702 |    0.1311 |    0.0722 | 13 / 3777 |    0.0000 |    0.0000 |    0.0005 | 2 / 3777 |
# +----------+-----------+-----------+-----------+------------+-----------+-----------+-----------+------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+----------+
# | BOM_3A4  |    0.0606 |    0.1143 |    0.1436 | 146 / 1995 |    0.0492 |    0.0938 |    0.0565 | 112 / 5772 |    0.2000 |    0.3333 |    0.1647 | 97 / 3777 |    0.0000 |    0.0000 |    0.0577 | 9 / 3777 |
# +----------+-----------+-----------+-----------+------------+-----------+-----------+-----------+------------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+----------+