# 최초 설정ㅡ

In [59]:
import os, time, random, shutil
import argparse
from argparse import ArgumentParser, Namespace
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch_geometric.data import DataLoader
from loader import MoleculeDataset, MoleculeDataset_other

from sklearn.metrics import roc_auc_score

from model import GNN, GNN_graphpred
from util import calcul_loss, save_cp, confusion_mat, makedirs, create_logger
from splitters import scaffold_split, random_split

from rdkit import RDLogger
import logging
from logging import Logger

# i don't want see warning of torch dataset
import warnings
warnings.filterwarnings(action='ignore')

In [161]:
#arguments
parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
parser.add_argument('--device', type=int, default=0,
                    help='which gpu to use if any (default: 0)')
parser.add_argument('--batch_size', type=int, default=32,
                    help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train (default: 100)')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='learning rate (default: 0.001)')
parser.add_argument('--lr_scale', type=float, default=1,
                    help='relative learning rate for the feature extraction layer (default: 1)')
parser.add_argument('--decay', type=float, default=1e-7,
                    help='weight decay (default: 0)')
parser.add_argument('--num_layer', type=int, default=5,
                    help='number of GNN message passing layers (default: 5).')
parser.add_argument('--emb_dim', type=int, default=300,
                    help='embedding dimensions (default: 300)')
parser.add_argument('--dropout_ratio', type=float, default=0.2,
                    help='dropout ratio (default: 0.5)')
parser.add_argument('--graph_pooling', type=str, default="mean",
                    help='graph level pooling (sum, mean, max, set2set, attention)')
parser.add_argument('--JK', type=str, default="last",
                    help='how the node features across layers are combined. last, sum, max or concat')
parser.add_argument('--gnn_type', type=str, default="gin")
parser.add_argument('--dataset', type=str, default = 'sider', help='root directory of dataset. For now, only classification.')
parser.add_argument('--input_model_file', type=str, default = 'pretrained.pth', help='filename to read the model (if there is any)')
parser.add_argument('--output_path', type=str, default = 'output', help='output filename')
parser.add_argument('--seed', type=int, default=3, help = "Seed for splitting the dataset.")
parser.add_argument('--runseed', type=int, default=3, help = "Seed for minibatch selection, random initialization.")
parser.add_argument('--split', type = str, default="scaffold", help = "random or scaffold or random_scaffold")
parser.add_argument('--eval_train', type=int, default = 1, help='evaluating training or not')
parser.add_argument('--num_workers', type=int, default = 4, help='number of workers for dataset loading')
# For search
parser.add_argument('--randomsearch', action='store_true', default=False, help='randomsearch mode')
parser.add_argument('--gridsearch', action='store_true', default=False, help='gridsearch mode')
parser.add_argument('--n_iters', type=int, default=1,
                    help='Number of search')
args = parser.parse_args(['--dataset','tox21', '--epochs', '2', '--output_path', 'output/tox21'])

In [162]:
args.dataset = 'toxcast'

In [163]:
dataset = MoleculeDataset_other("dataset/" + args.dataset, dataset=args.dataset)

## train

In [3]:
criterion = nn.BCEWithLogitsLoss(reduction = "none")

def train(args, model, device, loader, optimizer):
    model.train()
    
    loss_sum = 0
    iter_count = 0
    
    for step, batch in enumerate(loader):
        batch = batch.to(device)
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        y = batch.y.view(pred.shape).to(torch.float64)

        #loss matrix after removing null target
        loss = calcul_loss(pred, y, criterion)
            
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        
        loss_sum += loss
        iter_count += 1
        
    torch.cuda.empty_cache()
    return loss_sum / iter_count

## valid

In [4]:
def valid(args, model, device, loader):
    model.eval()
    y_true = []
    y_scores = []
    cum_loss = 0

    for step, batch in enumerate(loader):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            y = batch.y.view(pred.shape)
            loss = calcul_loss(pred, y, criterion)
        
        cum_loss += loss
        y_true.append(y)
        y_scores.append(pred)

    y_true = torch.cat(y_true, dim = 0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim = 0).cpu().numpy()

    roc_list = []
    for i in range(y_true.shape[1]):
        #AUC is only defined when there is at least one positive data.
        if np.sum(y_true[:,i] == 1) > 0 and np.sum(y_true[:,i] == -1) > 0:
            is_valid = y_true[:,i]**2 > 0
            roc_list.append(roc_auc_score((y_true[is_valid,i] + 1)/2, y_scores[is_valid,i]))

    if len(roc_list) < y_true.shape[1]:
        print("Some target is missing!")
        print("Missing ratio: %f" %(1 - float(len(roc_list))/y_true.shape[1]))
        
    torch.cuda.empty_cache()
    return cum_loss, sum(roc_list)/len(roc_list) #y_true.shape[1]

## test

In [157]:
def test(args, model, device, loader):
    
    model.eval()
    y_true = []
    y_scores = []
    cum_loss = 0

    for step, batch in enumerate(loader):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            y = batch.y.view(pred.shape)
            loss = calcul_loss(pred, y, criterion)

        y_true.append(y)
        y_scores.append(pred)
        cum_loss += loss

    y_true = torch.cat(y_true, dim = 0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim = 0).cpu().numpy()

    auc_list = []
    acc_list = []
    rec_list = []
    prec_list = []
    f1s_list = []
    BA_list = []
    tp_list = []
    fp_list = []
    tn_list = []
    fn_list = []
    for i in range(y_true.shape[1]):
        auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn = confusion_mat(y_true[:,i], y_scores[:,i])
        auc_list.append(auc)
        acc_list.append(acc)
        rec_list.append(rec)
        prec_list.append(prec)
        f1s_list.append(f1s)
        BA_list.append(BA)
        tp_list.append(tp)
        fp_list.append(fp)
        tn_list.append(tn)
        fn_list.append(fn)

    torch.cuda.empty_cache()
    return cum_loss, auc_list, acc_list, rec_list, prec_list, f1s_list, BA_list, tp_list, fp_list, tn_list, fn_list

In [158]:
logger = create_logger(name='train', save_dir=args.output_path, quiet=False)

In [164]:
info = logger.info if logger is not None else print
torch.manual_seed(args.seed)
np.random.seed(args.seed)
device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

#set up dataset
dataset = MoleculeDataset_other("dataset/" + args.dataset, dataset=args.dataset)
num_tasks = len(dataset[0]['y'])


if args.split == "scaffold":
    smiles_list = pd.read_csv('dataset/' + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
    train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)
    info(f'scaffold_balanced_split')
elif args.split == "random":
    train_dataset, valid_dataset, test_dataset = random_split(dataset, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = args.seed)
    info("random")
elif args.split == "random_scaffold":
    smiles_list = pd.read_csv('dataset/' + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
    train_dataset, valid_dataset, test_dataset = random_scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = args.seed)
    info("random scaffold")
else:
    raise ValueError("Invalid split option.")

info(f'total_size:{len(dataset)} train_size:{len(train_dataset)} val_size:{len(valid_dataset)} test_size:{len(test_dataset)}')

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)

scaffold_balanced_split
scaffold_balanced_split
total_size:8576 train_size:6860 val_size:858 test_size:858
total_size:8576 train_size:6860 val_size:858 test_size:858


In [165]:
model.eval()
y_true = []
y_scores = []
cum_loss = 0

for step, batch in enumerate(test_loader):
    batch = batch.to(device)

    with torch.no_grad():
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        y = batch.y.view(pred.shape)
        loss = calcul_loss(pred, y, criterion)

    y_true.append(y)
    y_scores.append(pred)
    cum_loss += loss

y_true = torch.cat(y_true, dim = 0).cpu().numpy()
y_scores = torch.cat(y_scores, dim = 0).cpu().numpy()
for i in range(y_true.shape[1]):
    try : auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn = confusion_mat(y_true[:,i], y_scores[:,i])
    except : print(i)

RuntimeError: shape '[32, 12]' is invalid for input of size 19744

In [138]:
sum(y_true[:,105]==1)

0

In [79]:
toxcast=pd.read_csv('dataset/toxcast/raw/toxcast.csv')

In [105]:
valid = y_true**2>0
targets = (y_true[valid]+1)/2
preds = y_scores[valid]
hard_preds = [1 if p > 0.5 else 0 for p in preds]

In [113]:
confusion_mat(y_true, y_scores)

(0.4942134768920143,
 0.8791672250001763,
 0.0,
 0.0,
 0.0,
 0.5,
 0,
 0,
 124658,
 17133)

In [109]:
roc_auc_score(targets, hard_preds)

0.5

In [83]:
toxcast.iloc[:,:]

Unnamed: 0,smiles,ACEA_T47D_80hr_Negative,ACEA_T47D_80hr_Positive,APR_HepG2_CellCycleArrest_24h_dn,APR_HepG2_CellCycleArrest_24h_up,APR_HepG2_CellCycleArrest_72h_dn,APR_HepG2_CellLoss_24h_dn,APR_HepG2_CellLoss_72h_dn,APR_HepG2_MicrotubuleCSK_24h_dn,APR_HepG2_MicrotubuleCSK_24h_up,...,Tanguay_ZF_120hpf_OTIC_up,Tanguay_ZF_120hpf_PE_up,Tanguay_ZF_120hpf_PFIN_up,Tanguay_ZF_120hpf_PIG_up,Tanguay_ZF_120hpf_SNOU_up,Tanguay_ZF_120hpf_SOMI_up,Tanguay_ZF_120hpf_SWIM_up,Tanguay_ZF_120hpf_TRUN_up,Tanguay_ZF_120hpf_TR_up,Tanguay_ZF_120hpf_YSE_up
0,[O-][N+](=O)C1=CC=C(Cl)C=C1,0.0,0.0,,,,,,,,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,C[SiH](C)O[Si](C)(C)O[Si](C)(C)O[SiH](C)C,,,,,,,,,,...,,,,,,,,,,
2,CN1CCN(CC1)C(=O)C1CCCCC1,,,,,,,,,,...,,,,,,,,,,
3,NC1=CC=C(C=C1)[N+]([O-])=O,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,OC1=CC=C(C=C1)[N+]([O-])=O,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8571,[O-]S(=O)(=O)C(F)(F)F.CCCC[N+]1=CC=CC=C1C,,,,,,,,,,...,,,,,,,,,,
8572,F[P-](F)(F)(F)(F)F.CCCC[N+]1=CC=CC=C1C,,,,,,,,,,...,,,,,,,,,,
8573,[O-]S(=O)(=O)C(F)(F)F.CCC[N+]1(C)CCCC1,,,,,,,,,,...,,,,,,,,,,
8574,CCCCCCCCCCCCC1=CC=CC=C1S([O-])(=O)=O.CCCCCCCCC...,,,,,,,,,,...,,,,,,,,,,


In [77]:
for i in range(858):
    if sum(y_true[i])==858:
        print(f'{i} is all 1')
    elif sum(y_true[i])==0:
        print(f'{i} is all 0')

686 is all 0


## runtrain

In [12]:
def run_training(args: Namespace, logger: Logger = None):
    info = logger.info if logger is not None else print
    
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        
    #set up dataset
    dataset = MoleculeDataset_other("dataset/" + args.dataset, dataset=args.dataset)
    num_tasks = len(dataset[0]['y'])

    
    if args.split == "scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)
        info(f'scaffold_balanced_split')
    elif args.split == "random":
        train_dataset, valid_dataset, test_dataset = random_split(dataset, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = args.seed)
        info("random")
    elif args.split == "random_scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = args.seed)
        info("random scaffold")
    else:
        raise ValueError("Invalid split option.")
        
    info(f'total_size:{len(dataset)} train_size:{len(train_dataset)} val_size:{len(valid_dataset)} test_size:{len(test_dataset)}')

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
    val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)

    #set up model
    model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks, JK = args.JK, drop_ratio = args.dropout_ratio, graph_pooling = args.graph_pooling, gnn_type = args.gnn_type)
    if not args.input_model_file == "":
        model.from_pretrained(args.input_model_file)
    
    model.to(device)

    #set up optimizer
    #different learning rate for different part of GNN
    model_param_group = []
    model_param_group.append({"params": model.gnn.parameters()})
    if args.graph_pooling == "attention":
        model_param_group.append({"params": model.pool.parameters(), "lr":args.lr*args.lr_scale})
    model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr":args.lr*args.lr_scale})
    optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
    info(optimizer)

    best_val_loss = 9999
    best_model_path = os.path.join(args.output_path, str(args.seed))
    for epoch in range(1, args.epochs+1):
        info("====epoch " + str(epoch))
        tst = time.time()
        train_loss = train(args, model, device, train_loader, optimizer)
        tet = time.time() - tst
        info("====Evaluation")
        vst = time.time()
        val_loss, val_auc = valid(args, model, device, val_loader)
        vet = time.time() - vst
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_cp(args, model, path=best_model_path)
        info(f'train_loss:{train_loss:.4f} val_loss:{val_loss:.4f} val_auc:{val_auc:.4f} t_time:{tet} v_time:{vet}')
    
    best_state = torch.load(os.path.join(best_model_path,'model.pt'))
    model.load_state_dict(best_state['state_dict'])
    
    test_loss, auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn = test(args, model, device, test_loader)
    avg_auc = sum(auc)/num_tasks
    avg_acc = sum(acc)/num_tasks
    avg_rec = sum(rec)/num_tasks
    avg_prec = sum(prec)/num_tasks
    avg_f1s = sum(f1s)/num_tasks
    avg_BA = sum(BA)/num_tasks
    avg_tp = sum(tp)/num_tasks
    avg_fp = sum(fp)/num_tasks
    avg_tn = sum(tn)/num_tasks
    avg_fn = sum(fn)/num_tasks
        
    info(f'seed:{args.seed} loss:{test_loss} auc:{avg_auc} acc:{avg_acc} rec:{avg_rec} prec:{avg_prec} f1:{avg_f1s} BA:{avg_BA}\ntp:{avg_tp} fp:{avg_fp} fn:{avg_fn} tn:{avg_tn}')
    #delete for memory
    del train_dataset, valid_dataset, test_dataset, train_loader, val_loader, test_loader

    return avg_auc, avg_acc, avg_rec, avg_prec, avg_f1s, avg_BA, avg_tp, avg_fp, avg_tn, avg_fn

# cross_validate

In [13]:
def cross_validate(args: Namespace, logger: Logger = None):
    info = logger.info if logger is not None else print
    
    if not os.path.exists(args.output_path):
        os.makedirs(args.output_path)
    auc_list = []
    acc_list = []
    rec_list = []
    prec_list = []
    f1s_list = []
    BA_list = []
    tp_list = []
    fp_list = []
    tn_list = []
    fn_list = []
    for k in range(3):
        auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn = run_training(args)
        auc_list.append(auc)
        acc_list.append(acc)
        rec_list.append(rec)
        prec_list.append(prec)
        f1s_list.append(f1s)
        BA_list.append(BA)
        tp_list.append(tp)
        fp_list.append(fp)
        tn_list.append(tn)
        fn_list.append(fn)
        args.seed += 1
    info(f'all test end')
    info(f'overall test_auc : {np.nanmean(auc_list):.4f}\nstd={np.nanstd(auc_list):.4f}')
    info(f'overall test_accuracy : {np.nanmean(acc_list):.4f}\nstd={np.nanstd(acc_list):.4f}')
    info(f'overall test_recall : {np.nanmean(rec_list):.4f}\nstd={np.nanstd(rec_list):.4f}')
    info(f'overall test_precision : {np.nanmean(prec_list):.4f}\nstd={np.nanstd(prec_list):.4f}')
    info(f'overall test_f1score : {np.nanmean(f1s_list):.4f}\nstd={np.nanstd(f1s_list):.4f}')
    info(f'overall test_Balanced_Accuracy : {np.nanmean(BA_list):.4f}\nstd={np.nanstd(BA_list):.4f}')
    info(f'overall test_tp : {np.nanmean(tp_list):.2f}\nstd={np.nanstd(tp_list):.2f}')
    info(f'overall test_fp : {np.nanmean(fp_list):.2f}\nstd={np.nanstd(fp_list):.2f}')
    info(f'overall test_fn : {np.nanmean(fn_list):.2f}\nstd={np.nanstd(fn_list):.2f}')
    info(f'overall test_tn : {np.nanmean(tn_list):.2f}\nstd={np.nanstd(tn_list):.2f}')
    info(f'\n       (pred)pos    neg(pred)')
    info(f'pos(true)    {tp:.2f}  {fn:.2f}')
    info(f'neg(true)    {fp:.2f}  {tn:.2f}')
    
    return np.nanmean(auc_list)

# random_search

In [14]:
def random_search(args: Namespace, logger: Logger = None):
    info = logger.info if logger is not None else print
    
    init_seed = args.seed
    save_dir = args.output_path

    #randomize parameter list
    lr_list = [0.0005, 0.00075, 0.001, 0.00125, 0.0015, 0.00175, 0.002]
    dropout_list = [0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
    gpooling_list = ['mean', 'sum']
    lr_scale_list = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5]

    # Run training with different random seeds for each fold
    all_scores = []
    params = []
    for iter_num in range(0, args.n_iters):
        info(f'iter {iter_num}')

        #randomize parameter
        np.random.seed()
        random.seed()
        args.lr = np.random.choice(lr_list, 1)[0]
        args.dropout_ratio = np.random.choice(dropout_list, 1)[0]
        params.append(f'\n{iter_num}th search parameter : lr is {args.lr} \n dropout is {args.dropout_ratio} \n batch_size is {args.batch_size}')
        info(params[iter_num])

        args.seed = init_seed                        # if change this, result will be change
        iter_dir = os.path.join(save_dir, f'iter_{iter_num}')
        args.output_path = iter_dir
        makedirs(args.output_path)

        iter_score = cross_validate(args, logger)
        all_scores.append(iter_score)

        if max(all_scores)==iter_score : 
            best_iter = iter_num
            best_score = iter_score
            best_param = params[iter_num]
############iter end

    all_scores = np.array(all_scores)

    # Report scores for each iter
    info(f'\n---- {args.n_iters}-iter random search ----')

    for iter_num, scores in enumerate(all_scores):
        info(params[iter_num])
        info(f'Seed {init_seed} ==> test AUC = {np.nanmean(scores):.6f}\n')

    # Report best model
    info(f'\nbest_iter : {best_iter}\nbest_score is {np.nanmean(best_score)}\nbest_param : {best_param}')

In [15]:
gpooling_list = ['mean', 'sum']
JK_list = ['last', 'sum', 'concat']

In [17]:
np.random.choice(JK_list, 1)[0]

'last'

# main

In [18]:
from rdkit import RDLogger
import logging
from logging import Logger

In [20]:
from util import calcul_loss, save_cp, confusion_mat, makedirs, create_logger

In [28]:
import warnings
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)
warnings.filterwarnings(action='ignore')
logger = create_logger(name='train', save_dir=args.output_path, quiet=False)
if args.randomsearch:
    best_metric = 0
    random_search(args, logger)
else : 
    cross_validate(args, logger)

scaffold_balanced_split
total_size:7831 train_size:6264 val_size:783 test_size:784
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07
)
====epoch 1


KeyboardInterrupt: 

# Dataset

In [3]:
import os
import torch
import pickle
import collections
import math
import pandas as pd
import numpy as np

from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
from torch.utils import data
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Batch

from itertools import repeat, product, chain

In [4]:
def _load_sider_dataset(input_path):
    """

    :param input_path:
    :return: list of smiles, list of rdkit mol obj, np.array containing the
    labels
    """
    input_df = pd.read_csv(input_path, sep=',')
    smiles_list = input_df['smiles']
    rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list]
    tasks = ['Hepatobiliary disorders',
       'Metabolism and nutrition disorders', 'Product issues', 'Eye disorders',
       'Investigations', 'Musculoskeletal and connective tissue disorders',
       'Gastrointestinal disorders', 'Social circumstances',
       'Immune system disorders', 'Reproductive system and breast disorders',
       'Neoplasms benign, malignant and unspecified (incl cysts and polyps)',
       'General disorders and administration site conditions',
       'Endocrine disorders', 'Surgical and medical procedures',
       'Vascular disorders', 'Blood and lymphatic system disorders',
       'Skin and subcutaneous tissue disorders',
       'Congenital, familial and genetic disorders',
       'Infections and infestations',
       'Respiratory, thoracic and mediastinal disorders',
       'Psychiatric disorders', 'Renal and urinary disorders',
       'Pregnancy, puerperium and perinatal conditions',
       'Ear and labyrinth disorders', 'Cardiac disorders',
       'Nervous system disorders',
       'Injury, poisoning and procedural complications']
    labels = input_df[tasks]
    # convert 0 to -1
    labels = labels.replace(0, -1)
    assert len(smiles_list) == len(rdkit_mol_objs_list)
    assert len(smiles_list) == len(labels)
    return smiles_list, rdkit_mol_objs_list, labels.values

## 2D Mol에 대한 Dataset 생성코드
- atomfeature는 2개 밖에 없다. atom type, chirality tag
- bondfeature도 2개 밖에 없다. bond type, bond direction

In [4]:
allowable_features = {
    'possible_atomic_num_list' : list(range(1, 119)),
    'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5],
    'possible_chirality_list' : [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER
    ],
    'possible_hybridization_list' : [
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED
    ],
    'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8],
    'possible_implicit_valence_list' : [0, 1, 2, 3, 4, 5, 6],
    'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'possible_bonds' : [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC
    ],
    'possible_bond_dirs' : [ # only for double bond stereo information
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT
    ]
}

In [5]:
def mol_to_graph_data_obj_simple(mol):
    """
    Converts rdkit mol object to graph Data object required by the pytorch
    geometric package. NB: Uses simplified atom and bond features, and represent
    as indices
    :param mol: rdkit mol object
    :return: graph data object with the attributes: x, edge_index, edge_attr
    """
    # atoms
    num_atom_features = 2   # atom type,  chirality tag
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_feature = [allowable_features['possible_atomic_num_list'].index(
            atom.GetAtomicNum())] + [allowable_features[
            'possible_chirality_list'].index(atom.GetChiralTag())]
        atom_features_list.append(atom_feature)
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long)

    # bonds
    num_bond_features = 2   # bond type, bond direction
    if len(mol.GetBonds()) > 0: # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_feature = [allowable_features['possible_bonds'].index(
                bond.GetBondType())] + [allowable_features[
                                            'possible_bond_dirs'].index(
                bond.GetBondDir())]
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list),
                                 dtype=torch.long)
    else:   # mol has no bonds
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    return data

### mol_to_graph 세부 시행

In [6]:
mol = AllChem.MolFromSmiles('N(CC1=CC=CC=C1)(CCC#N)C2=CC=C(N=NC3=CC=C(N(=O)=O)C=C3)C=C2')
num_atom_features = 2   # atom type,  chirality tag
atom_features_list = []
for atom in mol.GetAtoms():
    atom_feature = [allowable_features['possible_atomic_num_list'].index(
        atom.GetAtomicNum())] + [allowable_features[
        'possible_chirality_list'].index(atom.GetChiralTag())]
    atom_features_list.append(atom_feature)
x = torch.tensor(np.array(atom_features_list), dtype=torch.long)
x

tensor([[6, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [6, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [6, 0],
        [6, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [6, 0],
        [7, 0],
        [7, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0]])

In [7]:
# bonds
num_bond_features = 2   # bond type, bond direction
bonds = []
if len(mol.GetBonds()) > 0: # mol has bonds
    edges_list = []
    edge_features_list = []
    for bond in mol.GetBonds():
        bonds.append(bond)
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_feature = [allowable_features['possible_bonds'].index(
            bond.GetBondType())] + [allowable_features[
                                        'possible_bond_dirs'].index(
            bond.GetBondDir())]
        edges_list.append((i, j))
        edge_features_list.append(edge_feature)
        edges_list.append((j, i))
        edge_features_list.append(edge_feature)

    # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
    edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)

    # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
    edge_attr = torch.tensor(np.array(edge_features_list),
                             dtype=torch.long)
else:   # mol has no bonds
    edge_index = torch.empty((2, 0), dtype=torch.long)
    edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

In [8]:
allowable_features['possible_bonds'].index(bonds[0].GetBondType())

0

In [9]:
[allowable_features['possible_bond_dirs'].index(bonds[0].GetBondDir())]

[0]

In [10]:
Chem.rdchem.Bond?

[0;31mDocstring:[0m     
The class to store Bonds.
Note: unlike Atoms, is it currently impossible to construct Bonds from
Python.
[0;31mInit docstring:[0m
Raises an exception
This class cannot be instantiated from Python
[0;31mFile:[0m           /opt/conda/lib/python3.7/site-packages/rdkit/Chem/rdchem.so
[0;31mType:[0m           class
[0;31mSubclasses:[0m     QueryBond


In [11]:
mol = AllChem.MolFromSmiles('N(CC1=CC=CC=C1)(CCC#N)C2=CC=C(N=NC3=CC=C(N(=O)=O)C=C3)C=C2')
graph = mol_to_graph_data_obj_simple(mol)
graph
# dataset 대비 id, y만 차이난다.
# id는 말그대로 dataset에서 몇번째인지, y는 label을 의미한다.
# edge_attr : Edge feature matrix with shape [num_edges, num_edge_features]
# edge_index: Graph connectivity in COO format with shape [2, num_edges]
# x : atomfeature으로 atom번호, chirality

Data(x=[29, 2], edge_index=[2, 62], edge_attr=[62, 2])

In [12]:
graph.edge_attr

tensor([[0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [2, 0],
        [2, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [0, 0],
        [0, 0],
        [1, 0],
        [1, 0],
        [0, 0],
        [0, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [1, 0],
        [1, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0],
        [3, 0]])

In [13]:
graph.edge_index

tensor([[ 0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  0,  8,  8,  9,
          9, 10, 10, 11,  0, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18,
         18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 22, 24, 21, 25, 25, 26, 15, 27,
         27, 28,  7,  2, 28, 12, 26, 18],
        [ 1,  0,  2,  1,  3,  2,  4,  3,  5,  4,  6,  5,  7,  6,  8,  0,  9,  8,
         10,  9, 11, 10, 12,  0, 13, 12, 14, 13, 15, 14, 16, 15, 17, 16, 18, 17,
         19, 18, 20, 19, 21, 20, 22, 21, 23, 22, 24, 22, 25, 21, 26, 25, 27, 15,
         28, 27,  2,  7, 12, 28, 18, 26]])

In [14]:
model.to('cpu')
model.gnn.x_embedding1

NameError: name 'model' is not defined

In [15]:
emb1 = model.gnn.x_embedding1(graph.edge_attr[:,0])
emb1[0]

NameError: name 'model' is not defined

In [183]:
emb1.shape

torch.Size([62, 300])

In [153]:
graph.edge_index

tensor([[ 0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  0,  8,  8,  9,
          9, 10, 10, 11,  0, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18,
         18, 19, 19, 20, 20, 21, 21, 22, 22, 23, 22, 24, 21, 25, 25, 26, 15, 27,
         27, 28,  7,  2, 28, 12, 26, 18],
        [ 1,  0,  2,  1,  3,  2,  4,  3,  5,  4,  6,  5,  7,  6,  8,  0,  9,  8,
         10,  9, 11, 10, 12,  0, 13, 12, 14, 13, 15, 14, 16, 15, 17, 16, 18, 17,
         19, 18, 20, 19, 21, 20, 22, 21, 23, 22, 24, 22, 25, 21, 26, 25, 27, 15,
         28, 27,  2,  7, 12, 28, 18, 26]])

In [154]:
graph.x

tensor([[6, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [6, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [6, 0],
        [6, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [6, 0],
        [7, 0],
        [7, 0],
        [5, 0],
        [5, 0],
        [5, 0],
        [5, 0]])

## change load dataset

In [5]:
def _load_other_dataset(input_path):
    """
    this is for loading other datasets
    """
    df = pd.read_csv(input_path)
    smiles_list = df.iloc[:,0]
    rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list]
    tasks = list(df.columns[1:])
    labels = df[tasks]
    labels = labels.replace(0, -1)
    assert len(smiles_list) == len(rdkit_mol_objs_list)
    assert len(smiles_list) == len(labels)
    return smiles_list, rdkit_mol_objs_list, labels.values

In [6]:
from typing import List, Set, Tuple, Union, Dict
from argparse import Namespace
import logging
from logging import Logger

In [7]:
class MoleculeDataset_other(InMemoryDataset):
    def __init__(self,
                 root,
                 #data = None,
                 #slices = None,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None,
                 dataset='zinc250k',
                 empty=False):
        """
        Adapted from qm9.py. Disabled the download functionality
        :param root: directory of the dataset, containing a raw and processed
        dir. The raw dir should contain the file containing the smiles, and the
        processed dir can either empty or a previously processed file
        :param dataset: name of the dataset. Currently only implemented for
        zinc250k, chembl_with_labels, tox21, hiv, bace, bbbp, clintox, esol,
        freesolv, lipophilicity, muv, pcba, sider, toxcast
        :param empty: if True, then will not load any data obj. For
        initializing empty dataset
        """
        self.dataset = dataset
        self.root = root
        self.smiles = []

        super(MoleculeDataset_other, self).__init__(root, transform, pre_transform,
                                                 pre_filter)
        self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter
        if not empty:
            self.data, self.slices = torch.load(self.processed_paths[0])

    def get(self, idx):
        data = Data()
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            s = list(repeat(slice(None), item.dim()))
            s[data.__cat_dim__(key, item)] = slice(slices[idx],
                                                    slices[idx + 1])
            data[key] = item[s]
        return data

    @property
    def raw_file_names(self):
        file_name_list = os.listdir(self.raw_dir)
        # assert len(file_name_list) == 1     # currently assume we have a
        # # single raw file
        return file_name_list

    @property
    def processed_file_names(self):
        return 'geometric_data_processed.pt'

    def download(self):
        raise NotImplementedError('Must indicate valid location of raw data. '
                                  'No download allowed')
        
    def process(self):
        data_smiles_list = []
        data_list = []
        
        smiles_list, rdkit_mol_objs, labels = \
            _load_other_dataset(self.raw_paths[0])
        for i in range(len(smiles_list)):
            #print(i)
            rdkit_mol = rdkit_mol_objs[i]
            if rdkit_mol != None:
                data = mol_to_graph_data_obj_simple(rdkit_mol)
                # manually add mol id
                data.id = torch.tensor(
                    [i])  # id here is the index of the mol in
                # the dataset
                data.y = torch.tensor(labels[i, :])
                data_list.append(data)
                data_smiles_list.append(smiles_list[i])

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        # write data_smiles_list in processed paths
        data_smiles_series = pd.Series(data_smiles_list)
        data_smiles_series.to_csv(os.path.join(self.processed_dir,
                                               'smiles.csv'), index=False,
                                  header=False)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        
        self.smiles = smiles_list
        self.data_list = data_list

In [12]:
data = MoleculeDataset_other(root='dataset/tox21', dataset='tox21')

In [28]:
smiles_list2 = pd.read_csv('dataset/' + 'tox21' + '/processed/smiles.csv', header=None)[0].tolist()
num_tasks = len(data[0]['y'])
train_dataset, valid_dataset, test_dataset = scaffold_split(data, smiles_list2, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)



In [240]:
from splitters import scaffold_split, random_split

In [16]:
dataset = MoleculeDataset('dataset/tox21', dataset='tox21')

In [23]:
smiles_list = pd.read_csv('dataset/' + 'tox21' + '/processed/smiles.csv', header=None)[0].tolist()
train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)



# Split

In [36]:
from rdkit.Chem.Scaffolds import MurckoScaffold
from itertools import compress

In [37]:
def generate_scaffold(smiles, include_chirality=False):
    """
    Obtain Bemis-Murcko scaffold from smiles
    :param smiles:
    :param include_chirality:
    :return: smiles of scaffold
    """
    scaffold = MurckoScaffold.MurckoScaffoldSmiles(
        smiles=smiles, includeChirality=include_chirality)
    return scaffold

In [38]:
def scaffold_split(dataset, smiles_list, task_idx=None, null_value=0,
                   frac_train=0.8, frac_valid=0.1, frac_test=0.1,
                   return_smiles=False):
    """
    Adapted from  https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py
    Split dataset by Bemis-Murcko scaffolds
    This function can also ignore examples containing null values for a
    selected task when splitting. Deterministic split
    :param dataset: pytorch geometric dataset obj
    :param smiles_list: list of smiles corresponding to the dataset obj
    :param task_idx: column idx of the data.y tensor. Will filter out
    examples with null value in specified task column of the data.y tensor
    prior to splitting. If None, then no filtering
    :param null_value: float that specifies null value in data.y to filter if
    task_idx is provided
    :param frac_train:
    :param frac_valid:
    :param frac_test:
    :param return_smiles:
    :return: train, valid, test slices of the input dataset obj. If
    return_smiles = True, also returns ([train_smiles_list],
    [valid_smiles_list], [test_smiles_list])
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)

    if task_idx != None:
        # filter based on null values in task_idx
        # get task array
        y_task = np.array([data.y[task_idx].item() for data in dataset])
        # boolean array that correspond to non null values
        non_null = y_task != null_value
        smiles_list = list(compress(enumerate(smiles_list), non_null))
    else:
        non_null = np.ones(len(dataset)) == 1
        smiles_list = list(compress(enumerate(smiles_list), non_null))

    # create dict of the form {scaffold_i: [idx1, idx....]}
    all_scaffolds = {}
    for i, smiles in smiles_list:
        scaffold = generate_scaffold(smiles, include_chirality=True)
        if scaffold not in all_scaffolds:
            all_scaffolds[scaffold] = [i]
        else:
            all_scaffolds[scaffold].append(i)

    # sort from largest to smallest sets
    all_scaffolds = {key: sorted(value) for key, value in all_scaffolds.items()}
    all_scaffold_sets = [
        scaffold_set for (scaffold, scaffold_set) in sorted(
            all_scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
    ]

    # get train, valid test indices
    train_cutoff = frac_train * len(smiles_list)
    valid_cutoff = (frac_train + frac_valid) * len(smiles_list)
    train_idx, valid_idx, test_idx = [], [], []
    for scaffold_set in all_scaffold_sets:
        if len(train_idx) + len(scaffold_set) > train_cutoff:
            if len(train_idx) + len(valid_idx) + len(scaffold_set) > valid_cutoff:
                test_idx.extend(scaffold_set)
            else:
                valid_idx.extend(scaffold_set)
        else:
            train_idx.extend(scaffold_set)

    assert len(set(train_idx).intersection(set(valid_idx))) == 0
    assert len(set(test_idx).intersection(set(valid_idx))) == 0

    train_dataset = dataset[torch.tensor(train_idx)]
    valid_dataset = dataset[torch.tensor(valid_idx)]
    test_dataset = dataset[torch.tensor(test_idx)]

    if not return_smiles:
        return train_dataset, valid_dataset, test_dataset
    else:
        train_smiles = [smiles_list[i][1] for i in train_idx]
        valid_smiles = [smiles_list[i][1] for i in valid_idx]
        test_smiles = [smiles_list[i][1] for i in test_idx]
        return train_dataset, valid_dataset, test_dataset, (train_smiles,
                                                            valid_smiles,
                                                            test_smiles)

In [39]:
#split
smiles_list = pd.read_csv('dataset/' + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)
print("scaffold")



scaffold


In [40]:
print(train_dataset[0])

Data(edge_attr=[24, 2], edge_index=[2, 24], id=[1], x=[13, 2], y=[27])


# Dataloader

In [64]:
from collections.abc import Mapping, Sequence

import torch.utils.data
from torch.utils.data.dataloader import default_collate

from torch_geometric.data import Data, Batch

In [67]:
class Collater(object):
    def __init__(self, follow_batch, exclude_keys):
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def collate(self, batch):
        elem = batch[0]
        if isinstance(elem, Data):
            return Batch.from_data_list(batch, self.follow_batch,
                                        self.exclude_keys)
        elif isinstance(elem, torch.Tensor):
            return default_collate(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, Mapping):
            return {key: self.collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
            return type(elem)(*(self.collate(s) for s in zip(*batch)))
        elif isinstance(elem, Sequence) and not isinstance(elem, str):
            return [self.collate(s) for s in zip(*batch)]

        raise TypeError('DataLoader found invalid type: {}'.format(type(elem)))

    def __call__(self, batch):
        return self.collate(batch)

In [68]:
#from torch_geometric.data import DataLoader
class DataLoader(torch.utils.data.DataLoader):
    r"""Data loader which merges data objects from a
    :class:`torch_geometric.data.dataset` to a mini-batch.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many samples per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`False`)
        follow_batch (list or tuple, optional): Creates assignment batch
            vectors for each key in the list. (default: :obj:`[]`)
        exclude_keys (list or tuple, optional): Will exclude each key in the
            list. (default: :obj:`[]`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`.
    """
    def __init__(self, dataset, batch_size=1, shuffle=False, follow_batch=[],
                 exclude_keys=[], **kwargs):

        if "collate_fn" in kwargs:
            del kwargs["collate_fn"]

        # Save for PyTorch Lightning...
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        super(DataLoader,
              self).__init__(dataset, batch_size, shuffle,
                             collate_fn=Collater(follow_batch,
                                                 exclude_keys), **kwargs)

In [95]:
train_loader = DataLoader(train_dataset, batch_size=12, shuffle=True, num_workers = args.num_workers)
val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)

In [96]:
for step, batch in enumerate(tqdm(train_loader, desc="Iteration")):
    if step==1 : break
    else : print(batch)

Iteration:   1%|▊                                                                        | 1/96 [00:00<00:11,  7.98it/s]

Batch(batch=[202], edge_attr=[422, 2], edge_index=[2, 422], id=[12], ptr=[13], x=[202, 2], y=[324])





In [101]:
train_dataset[2] # 얘는 Train_dataset에서 완성

5

In [210]:
model(batch)

tensor([[ 0.1360,  0.6997, -2.6893,  0.5936,  1.3722,  0.9739,  2.0579, -1.3538,
          0.9538,  0.2211, -1.0589,  2.0721, -1.3082, -1.4309,  1.3957,  0.4801,
          2.3593, -1.4051,  0.6639,  1.1736,  1.1822,  0.5815, -1.8966,  0.1475,
          1.1561,  2.2287,  0.7384],
        [ 0.3356,  0.8435, -2.6948,  0.7362,  1.5192,  1.0961,  2.2009, -1.1727,
          1.0461,  0.4092, -0.9925,  2.1748, -1.1748, -1.3428,  1.4990,  0.5909,
          2.5626, -1.4138,  0.7727,  1.2638,  1.3637,  0.7326, -1.8489,  0.2545,
          1.2724,  2.4541,  0.8209],
        [-0.1841,  1.0998, -2.5338,  0.3811,  1.1476,  0.5031,  1.4879, -1.6094,
          1.1318, -0.2442, -0.9821,  1.9916, -0.7661, -1.6970,  1.0296,  0.0738,
          2.2346, -1.4210,  1.6177,  1.1372,  0.5199,  0.7007, -1.5090, -0.4171,
          0.6175,  1.8873,  0.5372],
        [-0.0212,  0.8391, -2.4188,  0.4983,  1.2013,  0.7632,  1.6879, -1.3786,
          0.9997,  0.0719, -0.8919,  1.9005, -0.8378, -1.4584,  1.1671,  0.2381

# Model

In [197]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, softmax
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_scatter import scatter_add
from torch_geometric.nn.inits import glorot, zeros


In [198]:
class GINConv(MessagePassing):
    """
    Extension of GIN aggregation to incorporate edge information by concatenation.

    Args:
        emb_dim (int): dimensionality of embeddings for nodes and edges.
        embed_input (bool): whether to embed input or not. 
        

    See https://arxiv.org/abs/1810.00826
    """
    def __init__(self, emb_dim, aggr = "add"):
        super(GINConv, self).__init__()
        #multi-layer perceptron
        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.edge_embedding1 = torch.nn.Embedding(num_bond_type, emb_dim)
        self.edge_embedding2 = torch.nn.Embedding(num_bond_direction, emb_dim)

        torch.nn.init.xavier_uniform_(self.edge_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.edge_embedding2.weight.data)
        self.aggr = aggr

    def forward(self, x, edge_index, edge_attr):
        #add self loops in the edge space
        edge_index = add_self_loops(edge_index, num_nodes = x.size(0))

        #add features corresponding to self-loop edges.
        self_loop_attr = torch.zeros(x.size(0), 2)
        self_loop_attr[:,0] = 4 #bond type for self-loop edge
        self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
        edge_attr = torch.cat((edge_attr, self_loop_attr), dim = 0)

        edge_embeddings = self.edge_embedding1(edge_attr[:,0]) + self.edge_embedding2(edge_attr[:,1])

        return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)

    def message(self, x_j, edge_attr):
        return x_j + edge_attr

    def update(self, aggr_out):
        return self.mlp(aggr_out)

In [200]:
class GNN(torch.nn.Module):
    """
    

    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        JK (str): last, concat, max or sum.
        max_pool_layer (int): the layer from which we use max pool rather than add pool for neighbor aggregation
        drop_ratio (float): dropout rate
        gnn_type: gin, gcn, graphsage, gat

    Output:
        node representations

    """
    def __init__(self, num_layer, emb_dim, JK = "last", drop_ratio = 0, gnn_type = "gin"):
        super(GNN, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.x_embedding1 = torch.nn.Embedding(num_atom_type, emb_dim)
        self.x_embedding2 = torch.nn.Embedding(num_chirality_tag, emb_dim)

        torch.nn.init.xavier_uniform_(self.x_embedding1.weight.data)
        torch.nn.init.xavier_uniform_(self.x_embedding2.weight.data)

        ###List of MLPs
        self.gnns = torch.nn.ModuleList()
        for layer in range(num_layer):
            if gnn_type == "gin":
                self.gnns.append(GINConv(emb_dim, aggr = "add"))
            elif gnn_type == "gcn":
                self.gnns.append(GCNConv(emb_dim))
            elif gnn_type == "gat":
                self.gnns.append(GATConv(emb_dim))
            elif gnn_type == "graphsage":
                self.gnns.append(GraphSAGEConv(emb_dim))

        ###List of batchnorms
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    #def forward(self, x, edge_index, edge_attr):
    def forward(self, *argv):
        if len(argv) == 3:
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")

        x = self.x_embedding1(x[:,0]) + self.x_embedding2(x[:,1])

        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            #h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            if layer == self.num_layer - 1:
                #remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training = self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim = 1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.max(torch.cat(h_list, dim = 0), dim = 0)[0]
        elif self.JK == "sum":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.sum(torch.cat(h_list, dim = 0), dim = 0)[0]

        return node_representation

In [199]:
class GNN_graphpred(torch.nn.Module):
    """
    Extension of GIN to incorporate edge information by concatenation.

    Args:
        num_layer (int): the number of GNN layers
        emb_dim (int): dimensionality of embeddings
        num_tasks (int): number of tasks in multi-task learning scenario
        drop_ratio (float): dropout rate
        JK (str): last, concat, max or sum.
        graph_pooling (str): sum, mean, max, attention, set2set
        gnn_type: gin, gcn, graphsage, gat
        
    See https://arxiv.org/abs/1810.00826
    JK-net: https://arxiv.org/abs/1806.03536
    """
    def __init__(self, num_layer, emb_dim, num_tasks, JK = "last", drop_ratio = 0, graph_pooling = "mean", gnn_type = "gin"):
        super(GNN_graphpred, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn = GNN(num_layer, emb_dim, JK, drop_ratio, gnn_type = gnn_type)

        #Different kind of graph pooling
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            if self.JK == "concat":
                self.pool = GlobalAttention(gate_nn = torch.nn.Linear((self.num_layer + 1) * emb_dim, 1))
            else:
                self.pool = GlobalAttention(gate_nn = torch.nn.Linear(emb_dim, 1))
        elif graph_pooling[:-1] == "set2set":
            set2set_iter = int(graph_pooling[-1])
            if self.JK == "concat":
                self.pool = Set2Set((self.num_layer + 1) * emb_dim, set2set_iter)
            else:
                self.pool = Set2Set(emb_dim, set2set_iter)
        else:
            raise ValueError("Invalid graph pooling type.")

        #For graph-level binary classification
        if graph_pooling[:-1] == "set2set":
            self.mult = 2
        else:
            self.mult = 1
        
        if self.JK == "concat":
            self.graph_pred_linear = torch.nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.mult * self.emb_dim, self.num_tasks)

    def from_pretrained(self, model_file):
        #self.gnn = GNN(self.num_layer, self.emb_dim, JK = self.JK, drop_ratio = self.drop_ratio)
        self.gnn.load_state_dict(torch.load(model_file))

    def forward(self, *argv):
        if len(argv) == 4:
            x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        else:
            raise ValueError("unmatched number of arguments.")

        node_representation = self.gnn(x, edge_index, edge_attr)

        return self.graph_pred_linear(self.pool(node_representation, batch))

In [45]:
device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks, JK = args.JK, drop_ratio = args.dropout_ratio, graph_pooling = args.graph_pooling, gnn_type = args.gnn_type)
model.to(device)

GNN_graphpred(
  (gnn): GNN(
    (x_embedding1): Embedding(120, 300)
    (x_embedding2): Embedding(3, 300)
    (gnns): ModuleList(
      (0): GINConv(
        (mlp): Sequential(
          (0): Linear(in_features=300, out_features=600, bias=True)
          (1): ReLU()
          (2): Linear(in_features=600, out_features=300, bias=True)
        )
        (edge_embedding1): Embedding(6, 300)
        (edge_embedding2): Embedding(3, 300)
      )
      (1): GINConv(
        (mlp): Sequential(
          (0): Linear(in_features=300, out_features=600, bias=True)
          (1): ReLU()
          (2): Linear(in_features=600, out_features=300, bias=True)
        )
        (edge_embedding1): Embedding(6, 300)
        (edge_embedding2): Embedding(3, 300)
      )
      (2): GINConv(
        (mlp): Sequential(
          (0): Linear(in_features=300, out_features=600, bias=True)
          (1): ReLU()
          (2): Linear(in_features=600, out_features=300, bias=True)
        )
        (edge_embedding1): E

In [47]:
model_param_group = []
model_param_group.append({"params": model.gnn.parameters()})
if args.graph_pooling == "attention":
    model_param_group.append({"params": model.pool.parameters(), "lr":args.lr*args.lr_scale})
model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr":args.lr*args.lr_scale})
optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
print(optimizer)

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)


# cross_validate

In [16]:
criterion = nn.BCEWithLogitsLoss(reduction = "none")

In [52]:
def train(args, model, device, loader, optimizer):
    model.train()
    
    loss_sum = 0
    iter_count = 0
    
    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        y = batch.y.view(pred.shape).to(torch.float64)

        #loss matrix after removing null target
        loss = calcul_loss(pred, y, criterion)
            
        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        
        loss_sum += loss
        iter_count += 1
        
    return loss_sum / iter_count

In [53]:
def valid(args, model, device, loader):
    model.eval()
    y_true = []
    y_scores = []
    cum_loss = 0

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            y = batch.y.view(pred.shape)
            loss = calcul_loss(pred, y, criterion)
        
        cum_loss += loss
        y_true.append(y)
        y_scores.append(pred)

    y_true = torch.cat(y_true, dim = 0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim = 0).cpu().numpy()

    roc_list = []
    for i in range(y_true.shape[1]):
        #AUC is only defined when there is at least one positive data.
        if np.sum(y_true[:,i] == 1) > 0 and np.sum(y_true[:,i] == -1) > 0:
            is_valid = y_true[:,i]**2 > 0
            roc_list.append(roc_auc_score((y_true[is_valid,i] + 1)/2, y_scores[is_valid,i]))

    if len(roc_list) < y_true.shape[1]:
        print("Some target is missing!")
        print("Missing ratio: %f" %(1 - float(len(roc_list))/y_true.shape[1]))

    return cum_loss, sum(roc_list)/len(roc_list) #y_true.shape[1]

In [54]:
def test(args, model, device, loader):
    model.eval()
    y_true = []
    y_scores = []
    cum_loss = 0

    for step, batch in enumerate(tqdm(test_loader, desc="Iteration")):
        batch = batch.to(device)

        with torch.no_grad():
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            y = batch.y.view(pred.shape)
            loss = calcul_loss(pred, y, criterion)

        y_true.append(y)
        y_scores.append(pred)
        cum_loss += loss

    y_true = torch.cat(y_true, dim = 0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim = 0).cpu().numpy()

    auc_list = []
    acc_list = []
    rec_list = []
    prec_list = []
    f1s_list = []
    BA_list = []
    tp_list = []
    fp_list = []
    tn_list = []
    fn_list = []
    for i in range(y_true.shape[1]):
        auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn = confusion_mat(y_true[:,i], y_scores[:,i])
        auc_list.append(auc)
        acc_list.append(acc)
        rec_list.append(rec)
        prec_list.append(prec)
        f1s_list.append(f1s)
        BA_list.append(BA)
        tp_list.append(tp)
        fp_list.append(fp)
        tn_list.append(tn)
        fn_list.append(fn)

    return cum_loss, auc_list, acc_list, rec_list, prec_list, f1s_list, BA_list, tp_list, fp_list, tn_list, fn_list

In [55]:
def cross_validate(args):
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        
    #set up dataset
    dataset = MoleculeDataset_other("dataset/" + args.dataset, dataset=args.dataset)
    num_tasks = len(dataset[0]['y'])

    
    if args.split == "scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)
        print(f'scaffold_balanced_split')
    elif args.split == "random":
        train_dataset, valid_dataset, test_dataset = random_split(dataset, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = args.seed)
        print("random")
    elif args.split == "random_scaffold":
        smiles_list = pd.read_csv('dataset/' + args.dataset + '/processed/smiles.csv', header=None)[0].tolist()
        train_dataset, valid_dataset, test_dataset = random_scaffold_split(dataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1, seed = args.seed)
        print("random scaffold")
    else:
        raise ValueError("Invalid split option.")
        
    print(f'total_size:{len(dataset)} train_size:{len(train_dataset)} val_size:{len(valid_dataset)} test_size:{len(test_dataset)}')

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers)
    val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers)

    #set up model
    model = GNN_graphpred(args.num_layer, args.emb_dim, num_tasks, JK = args.JK, drop_ratio = args.dropout_ratio, graph_pooling = args.graph_pooling, gnn_type = args.gnn_type)
    if not args.input_model_file == "":
        model.from_pretrained(args.input_model_file)
    
    model.to(device)

    #set up optimizer
    #different learning rate for different part of GNN
    model_param_group = []
    model_param_group.append({"params": model.gnn.parameters()})
    if args.graph_pooling == "attention":
        model_param_group.append({"params": model.pool.parameters(), "lr":args.lr*args.lr_scale})
    model_param_group.append({"params": model.graph_pred_linear.parameters(), "lr":args.lr*args.lr_scale})
    optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
    print(optimizer)

    best_val_loss = 9999
    best_model_path = os.path.join(args.output_path, str(args.seed))
    for epoch in range(1, args.epochs+1):
        print("====epoch " + str(epoch))
        
        train_loss = train(args, model, device, train_loader, optimizer)

        print("====Evaluation")
        val_loss, val_auc = valid(args, model, device, val_loader)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_cp(args, model, path=best_model_path)
        print(f'train_loss:{train_loss:.4f} val_loss:{val_loss:.4f} val_auc:{val_auc:.4f}')
    
    best_state = torch.load(os.path.join(best_model_path,'model.pt'))
    model.load_state_dict(best_state['state_dict'])
    
    test_loss, auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn = test(args, model, device, test_loader)
    avg_auc = sum(auc)/num_tasks
    avg_acc = sum(acc)/num_tasks
    avg_rec = sum(rec)/num_tasks
    avg_prec = sum(prec)/num_tasks
    avg_f1s = sum(f1s)/num_tasks
    avg_BA = sum(BA)/num_tasks
    avg_tp = sum(tp)/num_tasks
    avg_fp = sum(fp)/num_tasks
    avg_tn = sum(tn)/num_tasks
    avg_fn = sum(fn)/num_tasks
        
    print(f'seed:{args.seed} loss:{test_loss} auc:{avg_auc} acc:{avg_acc} rec:{avg_rec} prec:{avg_prec} f1:{avg_f1s} BA:{avg_BA}\ntp:{avg_tp} fp:{avg_fp} fn:{avg_fn} tn:{avg_tn}')
    #delete for memory
    del train_dataset, valid_dataset, test_dataset, train_loader, val_loader, test_loader

    return avg_auc, avg_acc, avg_rec, avg_prec, avg_f1s, avg_BA, avg_tp, avg_fp, avg_tn, avg_fn

In [56]:
args.epochs=2

In [64]:
if not os.path.exists(args.output_path):
    os.makedirs(args.output_path)
auc_list = []
acc_list = []
rec_list = []
prec_list = []
f1s_list = []
BA_list = []
tp_list = []
fp_list = []
tn_list = []
fn_list = []
for k in range(3):
    auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn = cross_validate(args)
    auc_list.append(auc)
    acc_list.append(acc)
    rec_list.append(rec)
    prec_list.append(prec)
    f1s_list.append(f1s)
    BA_list.append(BA)
    tp_list.append(tp)
    fp_list.append(fp)
    tn_list.append(tn)
    fn_list.append(fn)
    args.seed += 1
print(f'test end')



scaffold_balanced_split
total_size:7831 train_size:6264 val_size:783 test_size:784
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07
)
====epoch 1


Iteration: 100%|██████████| 196/196 [00:03<00:00, 52.39it/s]


====Evaluation


Iteration: 100%|██████████| 25/25 [00:00<00:00, 40.78it/s]


train_loss:0.5212 val_loss:13.0586 val_auc:0.5935
====epoch 2


Iteration: 100%|██████████| 196/196 [00:03<00:00, 52.63it/s]


====Evaluation


Iteration: 100%|██████████| 25/25 [00:00<00:00, 41.82it/s]


train_loss:0.3220 val_loss:9.6046 val_auc:0.6125


Iteration: 100%|██████████| 25/25 [00:00<00:00, 43.65it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


seed:0 loss:9.561942131759293 auc:0.5999736823705634 acc:0.9003052888258489 rec:0.007285814606741573 prec:0.1527777777777778 f1:0.013822434875066453 BA:0.5035645865013657
tp:0.6666666666666666 fp:0.08333333333333333 fn:55.583333333333336 tn:532.75




scaffold_balanced_split
total_size:7831 train_size:6264 val_size:783 test_size:784
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07
)
====epoch 1


Iteration: 100%|██████████| 196/196 [00:03<00:00, 52.41it/s]


====Evaluation


Iteration: 100%|██████████| 25/25 [00:00<00:00, 42.03it/s]


train_loss:0.5412 val_loss:13.3269 val_auc:0.5894
====epoch 2


Iteration: 100%|██████████| 196/196 [00:03<00:00, 50.97it/s]


====Evaluation


Iteration: 100%|██████████| 25/25 [00:00<00:00, 42.95it/s]


train_loss:0.3273 val_loss:10.0504 val_auc:0.6117


Iteration: 100%|██████████| 25/25 [00:00<00:00, 40.82it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


seed:1 loss:10.036050189019925 auc:0.6147799217739862 acc:0.8995790877408466 rec:0.002740714731585518 prec:0.13888888888888887 f1:0.005341401464216345 BA:0.5012920365637877
tp:0.25 fp:0.08333333333333333 fn:56.0 tn:532.75




scaffold_balanced_split
total_size:7831 train_size:6264 val_size:783 test_size:784
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.0001
    maximize: False
    weight_decay: 1e-07
)
====epoch 1


Iteration: 100%|██████████| 196/196 [00:03<00:00, 51.58it/s]


====Evaluation


Iteration: 100%|██████████| 25/25 [00:00<00:00, 39.22it/s]


train_loss:0.5129 val_loss:12.2664 val_auc:0.6082
====epoch 2


Iteration: 100%|██████████| 196/196 [00:03<00:00, 53.41it/s]


====Evaluation


Iteration: 100%|██████████| 25/25 [00:00<00:00, 40.89it/s]


train_loss:0.3173 val_loss:9.3715 val_auc:0.6130


Iteration: 100%|██████████| 25/25 [00:00<00:00, 37.87it/s]

seed:2 loss:9.36351836373875 auc:0.5988983900266698 acc:0.8997281198325644 rec:0.006281210986267165 prec:0.11805555555555554 f1:0.011912291798439806 BA:0.5027834240458667
tp:0.5833333333333334 fp:0.3333333333333333 fn:55.666666666666664 tn:532.5
test end



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [74]:
print(f'overall test_auc : {np.nanmean(auc_list):.4f}\nstd={np.nanstd(auc_list):.4f}')
print(f'overall test_accuracy : {np.nanmean(acc_list):.4f}\nstd={np.nanstd(acc_list):.4f}')
print(f'overall test_recall : {np.nanmean(rec_list):.4f}\nstd={np.nanstd(rec_list):.4f}')
print(f'overall test_precision : {np.nanmean(prec_list):.4f}\nstd={np.nanstd(prec_list):.4f}')
print(f'overall test_f1score : {np.nanmean(f1s_list):.4f}\nstd={np.nanstd(f1s_list):.4f}')
print(f'overall test_Balanced_Accuracy : {np.nanmean(BA_list):.4f}\nstd={np.nanstd(BA_list):.4f}')
print(f'overall test_tp : {np.nanmean(tp_list):.2f}\nstd={np.nanstd(tp_list):.2f}')
print(f'overall test_fp : {np.nanmean(fp_list):.2f}\nstd={np.nanstd(fp_list):.2f}')
print(f'overall test_fn : {np.nanmean(fn_list):.2f}\nstd={np.nanstd(fn_list):.2f}')
print(f'overall test_tn : {np.nanmean(tn_list):.2f}\nstd={np.nanstd(tn_list):.2f}')
print(f'\n       (pred)pos    neg(pred)')
print(f'pos(true)    {tp:.2f}  {fn:.2f}')
print(f'neg(true)    {fp:.2f}  {tn:.2f}')

overall test_auc : 0.6046
std=0.0072
overall test_accuracy : 0.8999
std=0.0003
overall test_recall : 0.0054
std=0.0019
overall test_precision : 0.1366
std=0.0143
overall test_f1score : 0.0104
std=0.0036
overall test_Balanced_Accuracy : 0.5025
std=0.0009
overall test_tp : 0.50
std=0.18
overall test_fp : 0.17
std=0.12
overall test_fn : 55.75
std=0.18
overall test_tn : 532.67
std=0.12

       (pred)pos    neg(pred)
pos(true)    0.58  55.67
neg(true)    0.33  532.50


In [66]:
print(f'test_auc : {np.nanmean(auc_list)}\nstd={np.nanstd(auc_list)}')

test_auc : 0.6045506647237398
std=0.007246485964761184


In [67]:
np.nanstd([0.71, 0.60, 0.80])

0.08178562764256868

In [59]:
def save_cp(args, model, path=None):
    model_state = {
        'args': args,
        'state_dict': model.state_dict(),
    }
    if path is not None : 
        if not os.path.exists(path):
            os.makedirs(path)
        torch.save(model_state, os.path.join(path, 'model.pt'))
    else : 
        if not os.path.exists(args.output_path):
            os.makedirs(args.output_path)
        torch.save(model_state, os.path.join(args.output_path,'model.pt'))

In [62]:
from typing import List, Callable, Union
from sklearn.metrics import accuracy_score, mean_squared_error, roc_auc_score, mean_absolute_error, r2_score, \
    precision_recall_curve, auc, recall_score, confusion_matrix, f1_score, precision_score, classification_report

In [63]:
def confusion_mat(targets: List[int], preds: List[float], threshold: float = 0.5) -> float:
    """
    Computes the specificity of a binary prediction task using a given threshold for generating hard predictions.

    :param targets: A list of binary targets.
    :param preds: A list of prediction probabilities.
    :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0
    :return: The computed specificity.
    """
    valid = targets**2>0
    targets = (targets[valid]+1)/2
    preds = preds[valid]
    hard_preds = [1 if p > threshold else 0 for p in preds]
    auc = roc_auc_score(targets, preds)
    tn, fp, fn, tp = confusion_matrix(targets, hard_preds).ravel()
    acc = accuracy_score(targets, hard_preds)
    rec = recall_score(targets, hard_preds)
    prec = precision_score(targets, hard_preds)
    spe = tn / float(tn + fp)
    f1s = f1_score(targets, hard_preds)
    BA = (rec+spe)/2
    return auc, acc, rec, prec, f1s, BA, tp, fp, tn, fn