<a href="https://colab.research.google.com/github/yingzibu/MOL2ADMET/blob/main/examples/Graph/ADMET_GIN_SO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
! pip install rdkit --quiet
! pip install PyTDC --quiet
! pip install mycolorpy --quiet

! pip install dgllife --quiet
! pip install molvs --quiet
! pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html --quiet
! pip install dglgo -f https://data.dgl.ai/wheels-test/repo.html --quiet

! pip install DeepPurpose --quiet
! pip install git+https://github.com/bp-kelley/descriptastorus --quiet
! pip install pandas-flavor --quiet

In [None]:
cd /content/drive/MyDrive/ADMET

/content/drive/MyDrive/ADMET


In [None]:
from dgllife.model import load_pretrained
from dgl.nn.pytorch.glob import AvgPooling
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from functools import partial
import torch
from dgllife.utils import smiles_to_bigraph, PretrainAtomFeaturizer, PretrainBondFeaturizer

# Sentinel written in place of missing labels; label entries equal to MASK
# are excluded when the loss and the metrics are computed.
MASK = -100

class GIN_dataset(Dataset):
    """Dataset yielding ``(graph, label_tensor)`` pairs built from SMILES strings.

    Args:
        df: DataFrame with a ``'Drug'`` column holding SMILES strings and the
            label column(s) listed in ``names``.
        names: label column name(s) used as prediction targets.
        mask: sentinel substituted for missing labels (defaults to ``MASK``).

    Only the label columns are filled with ``mask``. Other columns, notably
    ``'Drug'``, are left untouched so that a missing SMILES is never silently
    replaced by a number.
    """
    def __init__(self, df, names, mask=MASK):
        label_cols = [names] if isinstance(names, str) else list(names)
        df = df.copy()  # never modify the caller's frame
        df[label_cols] = df[label_cols].fillna(mask)
        self.names = names
        self.df = df
        self.len = len(df)
        self.props = self.df[names]
        # Convert the labels once up front instead of on every item access.
        self.labels = torch.tensor(self.props.to_numpy(dtype='float32'))
        self.node_featurizer = PretrainAtomFeaturizer()
        self.edge_featurizer = PretrainBondFeaturizer()
        self.fc = partial(smiles_to_bigraph, add_self_loop=True)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        # Index the column directly; building a whole row Series per item is
        # unnecessary work.
        smiles = self.df['Drug'].iloc[idx]
        graph = self.fc(smiles=smiles,
                        node_featurizer=self.node_featurizer,
                        edge_featurizer=self.edge_featurizer)
        return graph, self.labels[idx]



In [None]:
import dgl
def get_GIN_dataloader(datasets, loader_params):
    """Wrap a graph dataset in a ``DataLoader`` that batches DGL graphs.

    Args:
        datasets: dataset whose items are ``(graph, label_tensor)`` pairs.
        loader_params: keyword arguments forwarded to ``DataLoader``
            (e.g. ``batch_size``, ``shuffle``). The dict is left unmodified,
            so it can be reused for several loaders.

    Returns:
        A ``DataLoader`` yielding ``(batched_graph, labels)``, where ``labels``
        stacks the per-item label tensors along dim 0.
    """
    def dgl_collate_func(data):
        graphs, labels = map(list, zip(*data))
        bg = dgl.batch(graphs)
        labels = torch.stack(labels, dim=0)
        bg.set_n_initializer(dgl.init.zero_initializer)
        bg.set_e_initializer(dgl.init.zero_initializer)
        return bg, labels

    # Build a fresh kwargs dict instead of writing into the caller's one;
    # the graph-aware collate function always takes precedence.
    params = {**loader_params, 'collate_fn': dgl_collate_func}
    return DataLoader(datasets, **params)

In [None]:
from dgllife.utils import EarlyStopping, Meter
from tqdm import tqdm

def train_epoch(epoch, model, loader, loss_func, device,
                optimizer=None, names=None, MASK=-100):
    """Run one pass over ``loader`` in training or validation mode.

    The mode is chosen by ``optimizer``: when it is given the model is put in
    train mode and updated after every batch; when it is ``None`` the model is
    put in eval mode and the forward pass runs with autograd disabled.

    Args:
        epoch: epoch index, used only in the printed log line.
        model: callable mapping a batch input to predictions of shape
            ``(batch, n_tasks)``.
        loader: iterable of ``(inputs, labels)`` batches; ``loader.dataset``
            must support ``len``.
        loss_func: loss applied to the unmasked predictions and labels.
        device: device the inputs and labels are moved to.
        optimizer: optimizer to step, or ``None`` for validation.
        names: task names, one per prediction column. When given, per-task
            sigmoid outputs and labels are collected.
        MASK: label value marking a missing target; such entries are skipped.

    Returns:
        ``total_loss`` (summed batch losses divided by ``len(loader.dataset)``)
        when ``names`` is ``None``; otherwise
        ``(total_loss, y_probs, y_label)`` where both dicts map a task name to
        a flat list of values. The tuple is returned in both modes.

    Note:
        The collected values always go through a sigmoid, so they are only
        meaningful as probabilities for classification outputs.
    """
    is_training = optimizer is not None
    if is_training:
        model.train()
        train_type = 'Train'
    else:
        model.eval()
        train_type = 'Valid'

    losses = 0
    y_probs = {}
    y_label = {}
    for bg, labels in loader:
        bg, labels = bg.to(device), labels.to(device)
        mask = labels == MASK
        # Only track gradients when parameters are actually going to be updated.
        with torch.set_grad_enabled(is_training):
            pred = model(bg)
            loss = loss_func(pred[~mask], labels[~mask])
        losses += loss.item()

        if is_training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if names is not None:
            for j, name in enumerate(names):
                keep = ~mask[:, j]
                probs = torch.sigmoid(pred[:, j][keep]).detach().cpu().tolist()
                label = labels[:, j][keep].detach().cpu().tolist()
                y_probs.setdefault(name, []).extend(probs)
                y_label.setdefault(name, []).extend(label)

    total_loss = losses / len(loader.dataset)
    print(f'Epoch:{epoch}, [{train_type}] Loss: {total_loss:.3f}')
    if names is None:
        return total_loss
    return total_loss, y_probs, y_label


In [None]:
import torch.nn.functional as F
from scripts.eval_utils import *
from scripts.preprocess_mols import *
from sklearn.metrics import mean_squared_error as MSE
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

# Default device for evaluation helpers; fall back to the CPU when no GPU is
# available instead of assuming CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def eval_dict(y_probs:dict, y_label:dict, names:list, IS_R=False):
    """Print per-task classification metrics.

    For every task in ``names`` the scores in ``y_probs[name]`` are thresholded
    at 0.5 and AUROC, AUPRC and F1 are printed against ``y_label[name]``,
    followed by the report produced by ``evaluate``.

    Nothing is printed when ``IS_R`` is anything other than ``False``
    (regression task).
    """
    if not (IS_R == False):
        return

    banner = '*'*15
    for name in names:
        print(banner, name, banner)
        probs, label = y_probs[name], y_label[name]
        assert len(probs) == len(label)
        preds = get_preds(0.5, probs)

        auroc = roc_auc_score(label, probs)
        auprc = average_precision_score(label, probs)
        f1 = f1_score(label, preds)
        print(f'AUROC: {auroc:.4f}',
              f'AUPRC: {auprc:.4f}',
              f'F1: {f1:.4f}')

        evaluate(label, preds, probs)
        print()


def eval_AP(model, IS_R, test_loader, names, device=device):
    """Evaluate ``model`` on ``test_loader`` and print loss and metrics.

    Args:
        model: network mapping a batched graph to predictions of shape
            ``(batch, n_tasks)``.
        IS_R: ``True`` for regression (MSE loss, raw predictions collected),
            ``False`` for classification (BCE-with-logits loss, sigmoid
            probabilities collected).
        test_loader: loader yielding ``(batched_graph, labels)``;
            ``test_loader.dataset`` must support ``len``.
        names: task names, one per prediction column.
        device: device the batches are moved to.

    Returns:
        ``(y_probs, y_label)``: two dicts mapping each task name to a flat list
        of model outputs and of the matching labels. Entries whose label
        equals ``MASK`` are skipped.
    """
    model.eval()
    total_loss = 0
    y_probs = {}
    y_label = {}
    if IS_R:
        print('using MSELoss')
        loss_fn = nn.MSELoss(reduction='sum')  # if regression
    else:
        print('using BCELOSSwithdigits')
        loss_fn = nn.BCEWithLogitsLoss(reduction='sum')  # if classification

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for batch_data in tqdm(test_loader, total=len(test_loader)):
            bg, labels = batch_data
            bg, labels = bg.to(device), labels.to(device)
            pred = model(bg)
            mask = labels == MASK
            total_loss += loss_fn(pred[~mask], labels[~mask]).item()
            for j, name in enumerate(names):
                keep = ~mask[:, j]
                scores = pred[:, j][keep]
                # A sigmoid only makes sense for classification logits;
                # regression outputs are kept as predicted.
                if not IS_R:
                    scores = torch.sigmoid(scores)
                y_probs.setdefault(name, []).extend(scores.cpu().tolist())
                y_label.setdefault(name, []).extend(labels[:, j][keep].cpu().tolist())

    total_loss /= len(test_loader.dataset)
    print(f'total_loss: {total_loss:.3f}')

    eval_dict(y_probs, y_label, names, IS_R)

    return y_probs, y_label


In [None]:

class GIN_MOD(nn.Module):
    """
    Pretrained GIN graph encoder followed by an MLP prediction head.

    Reference: https://github.com/kexinhuang12345/DeepPurpose/blob/master/DeepPurpose/encoders.py#L392
    Adapted from https://github.com/awslabs/dgl-lifesci/blob/2fbf5fd6aca92675b709b6f1c3bc3c6ad5434e96/examples/property_prediction/moleculenet/utils.py#L76

    Config keys read:
        GIN_out_dim: width of the linear projection applied to the pooled
            graph embedding.
        dropout: dropout probability applied after each hidden activation.
        hid_dims: sizes of the hidden layers of the head (may be empty).
        out_dim: number of outputs (one per task).

    ``forward`` returns raw, un-activated outputs of shape
    ``(batch, out_dim)``.
    """
    def __init__(self, **config):
        super(GIN_MOD, self).__init__()
        self.gnn = load_pretrained('gin_supervised_contextpred')
        self.readout = AvgPooling()
        # NOTE(review): 300 is presumably the embedding size of the pretrained
        # GIN checkpoint - confirm against the dgllife model card.
        self.transform = nn.Linear(300, config['GIN_out_dim'])
        self.dropout = nn.Dropout(config['dropout'])
        self.hidden_dims = config['hid_dims']
        self.out_dim = config['out_dim']
        neurons = [config['GIN_out_dim'], *self.hidden_dims]
        self.hidden = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(neurons[:-1], neurons[1:])]
        )
        # neurons[-1] also covers an empty hid_dims list.
        self.final = nn.Linear(neurons[-1], self.out_dim)

    def forward(self, bg):
        # The featurizer outputs are popped off the graph and handed to the
        # GIN as explicit categorical feature lists.
        node_feats = [
            bg.ndata.pop('atomic_number'),
            bg.ndata.pop('chirality_type')
        ]
        edge_feats = [
            bg.edata.pop('bond_type'),
            bg.edata.pop('bond_direction_type')
        ]

        node_feats = self.gnn(bg, node_feats, edge_feats)
        x = self.readout(bg, node_feats)
        x = self.transform(x)
        for layer in self.hidden:
            # Regularise the hidden activations; the outputs themselves are
            # never dropped, so the loss always sees the real predictions.
            x = self.dropout(F.leaky_relu(layer(x)))
        return self.final(x)


In [None]:


class PRED:
    """Training and evaluation wrapper around ``GIN_MOD``.

    Config keys read here (the whole config is also forwarded to ``GIN_MOD``):
        prop_names: task names, one per model output.
        IS_R: ``True`` for regression (MSE loss), ``False`` for
            classification (BCE-with-logits loss).
        lr, wd: learning rate and weight decay for AdamW.
        patience: early-stopping patience.
        model_path: file the best checkpoint is written to / read from.
        epochs (optional): maximum number of epochs, default 500.
    """
    def __init__(self, **config):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.prop_names = config['prop_names']
        self.model = GIN_MOD(**config).to(self.device)
        self.config = config
        self.IS_R = config['IS_R']
        if self.IS_R: loss_fn = nn.MSELoss(reduction='sum') # if regression
        else: loss_fn = nn.BCEWithLogitsLoss(reduction='sum') # if classification
        self.loss_fn = loss_fn
        self.optimizer = torch.optim.AdamW(self.model.parameters(),
                    lr=config['lr'], weight_decay=config['wd'])
        self.stopper = EarlyStopping(mode='lower', patience=config['patience'])
        self.min_loss = 10000
        self.best_epoch = 0

    def load_model(self, path):
        """Load the weights stored at ``path`` into the current model."""
        print('load pretrained model from ', path)
        state = torch.load(path, map_location=self.device)
        # Load in place: the optimizer holds references to this module's
        # parameters, so replacing the module would leave it updating stale
        # weights.
        self.model.load_state_dict(state)

    def eval(self, loader, path=None):
        """Evaluate on ``loader``, optionally loading a checkpoint from ``path`` first."""
        if path is not None: self.load_model(path)
        # Evaluate on the device the model actually lives on.
        eval_AP(self.model, self.IS_R, loader, self.prop_names,
                device=self.device)

    def train(self, data_loader, val_loader, test_loader=None, train_epoch=train_epoch):
        """Train with early stopping and keep the checkpoint with the lowest validation loss.

        Args:
            data_loader: training batches.
            val_loader: validation batches used for model selection.
            test_loader: if given, the best checkpoint is evaluated on it
                after training.
            train_epoch: function running a single epoch.
        """
        import os

        model_path = self.config['model_path']
        if self.best_epoch != 0: self.load_model(model_path)

        # Make sure the checkpoint directory exists before the first save.
        ckpt_dir = os.path.dirname(model_path)
        if ckpt_dir: os.makedirs(ckpt_dir, exist_ok=True)

        for epoch in range(self.config.get('epochs', 500)):
            score = train_epoch(epoch, self.model, data_loader, self.loss_fn,
                                self.device, self.optimizer)
            val_score, probs, labels = \
                    train_epoch(epoch, self.model, val_loader, self.loss_fn,
                                self.device, names=self.prop_names)

            early_stop = self.stopper.step(val_score, self.model)
            if val_score < self.min_loss:
                print(f'prev min loss {self.min_loss:.3f}, '
                      f'now loss {val_score:.3f} |',
                      f'save model at epoch: {epoch}')
                self.min_loss = val_score
                torch.save(self.model.state_dict(), model_path)
                self.best_epoch = epoch
                eval_dict(probs, labels, self.prop_names, IS_R=self.IS_R)
            if early_stop: print('early stop'); break

        print(f"best epoch: {self.best_epoch}, min loss: {self.min_loss:.4f}")
        print()
        if test_loader is not None: self.eval(test_loader, model_path)


In [None]:
import os

from scripts.preprocess_mols import collect_data

# Train one single-output GIN classifier per CYP metabolism task.
met_cls_all = ['CYP2C19_Veith', 'CYP2D6_Veith', 'CYP3A4_Veith',
            'CYP1A2_Veith', 'CYP2C9_Veith']
IS_R = False # False -> classification task, True -> regression task
batch_size = 128

for task in met_cls_all:

    met_cls = [task]  # the helpers expect a list of task names
    trains, valids, tests = collect_data(met_cls, IS_R=IS_R)

    # Only the training data needs to be shuffled.
    loader_params = {'batch_size': batch_size, 'shuffle': True}
    p = {'batch_size': batch_size, 'shuffle': False}

    train_loader = get_GIN_dataloader(GIN_dataset(trains, met_cls), loader_params)
    valid_loader = get_GIN_dataloader(GIN_dataset(valids, met_cls), dict(p))
    test_loader = get_GIN_dataloader(GIN_dataset(tests, met_cls), dict(p))

    config = {'GIN_out_dim': 256,
            'hid_dims': [512],
            'out_dim': len(met_cls),
            'prop_names': met_cls,
            'dropout': 0.1,
            'IS_R': IS_R,
            'lr': 1e-4,
            'wd':1e-5,
            'patience': 10,
            'model_path': f'GIN_SO_DR_0.1/ckpt_{met_cls[0]}.pt'}

    # The checkpoint directory must exist before the first save.
    os.makedirs(os.path.dirname(config['model_path']), exist_ok=True)

    models = PRED(**config)
    print("*"*40, met_cls, "*"*40)
    print('---> start training...')
    models.train(train_loader, valid_loader, test_loader)


Epoch:51, [Train] Loss: 0.484
Epoch:51, [Valid] Loss: 0.359
prev min loss 0.359, now loss 0.359 | save model at epoch: 51
*************** CYP2C19_Veith ***************
AUROC: 0.9274 AUPRC: 0.9133 F1: 0.8536
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.859  &  0.860  &          0.833  &     0.875  &0.845  &0.854 &0.927 &   0.719 &   0.913

Epoch:52, [Train] Loss: 0.485
Epoch:52, [Valid] Loss: 0.359
EarlyStopping counter: 1 out of 10
Epoch:53, [Train] Loss: 0.492
Epoch:53, [Valid] Loss: 0.366
EarlyStopping counter: 2 out of 10
Epoch:54, [Train] Loss: 0.473
Epoch:54, [Valid] Loss: 0.356
prev min loss 0.359, now loss 0.356 | save model at epoch: 54
*************** CYP2C19_Veith ***************
AUROC: 0.9259 AUPRC: 0.9116 F1: 0.8569
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.862  &  0.863  &          0.832  &     0.884  &0.842  &0.857 &0.926 &   0.725 &   0.912

Epoch:55, [Train] Loss: 0.483
Epoc

100%|██████████| 20/20 [00:05<00:00,  3.67it/s]
Found local copy...
Loading...


total_loss: 0.351
*************** CYP2C19_Veith ***************
AUROC: 0.9281 AUPRC: 0.9143 F1: 0.8508
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.859  &  0.860  &          0.828  &     0.875  &0.845  &0.851 &0.928 &   0.718 &   0.914

*************** CYP2D6_Veith ***************


Done!
Cleaning mols: 100%|██████████| 9191/9191 [00:25<00:00, 360.00it/s]
Cleaning mols: 100%|██████████| 1313/1313 [00:03<00:00, 369.99it/s]
Cleaning mols: 100%|██████████| 2626/2626 [00:07<00:00, 358.04it/s]


Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded
**************************************** ['CYP2D6_Veith'] ****************************************
---> start training...
Epoch:0, [Train] Loss: 0.600
Epoch:0, [Valid] Loss: 0.466
Epoch:1, [Train] Loss: 0.491
Epoch:1, [Valid] Loss: 0.387
Epoch:2, [Train] Loss: 0.465
Epoch:2, [Valid] Loss: 0.354
Epoch:3, [Train] Loss: 0.454
Epoch:3, [Valid] Loss: 0.330
Epoch:4, [Train] Loss: 0.449
Epoch:4, [Valid] Loss: 0.316
prev min loss 10000.000, now loss 0.316 | save model at epoch: 4
*************** CYP2D6_Veith ***************
AUROC: 0.9021 AUPRC: 0.7653 F1: 0.6567
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.878  &  0.769  &          0.743  &     0.588  &0.950  &0.657 &0.902 &   0.590 &   0.765

Epoch:5, [Train] Loss: 0.451
Epoch:5, [Valid] Loss: 0.321
EarlyStopping counter: 1 out of 10
Epoch:6

100%|██████████| 21/21 [00:05<00:00,  3.82it/s]
Found local copy...
Loading...


total_loss: 0.296
*************** CYP2D6_Veith ***************
AUROC: 0.9058 AUPRC: 0.7542 F1: 0.6735
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.891  &  0.780  &          0.764  &     0.602  &0.957  &0.674 &0.906 &   0.616 &   0.754

*************** CYP3A4_Veith ***************


Done!
Cleaning mols: 100%|██████████| 8629/8629 [00:23<00:00, 361.55it/s]
Cleaning mols: 100%|██████████| 1233/1233 [00:03<00:00, 357.89it/s]
Cleaning mols: 100%|██████████| 2466/2466 [00:06<00:00, 361.85it/s]


Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded
**************************************** ['CYP3A4_Veith'] ****************************************
---> start training...
Epoch:0, [Train] Loss: 0.661
Epoch:0, [Valid] Loss: 0.585
Epoch:1, [Train] Loss: 0.580
Epoch:1, [Valid] Loss: 0.454
Epoch:2, [Train] Loss: 0.537
Epoch:2, [Valid] Loss: 0.411
Epoch:3, [Train] Loss: 0.525
Epoch:3, [Valid] Loss: 0.397
Epoch:4, [Train] Loss: 0.527
Epoch:4, [Valid] Loss: 0.392
prev min loss 10000.000, now loss 0.392 | save model at epoch: 4
*************** CYP3A4_Veith ***************
AUROC: 0.9235 AUPRC: 0.9009 F1: 0.8167
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.836  &  0.841  &          0.768  &     0.872  &0.810  &0.817 &0.924 &   0.674 &   0.901

Epoch:5, [Train] Loss: 0.524
Epoch:5, [Valid] Loss: 0.387
prev min loss 0.392, now loss 0.387 | save

100%|██████████| 20/20 [00:05<00:00,  3.91it/s]
Found local copy...
Loading...


total_loss: 0.363
*************** CYP3A4_Veith ***************
AUROC: 0.9247 AUPRC: 0.8918 F1: 0.8165
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.846  &  0.844  &          0.804  &     0.829  &0.859  &0.816 &0.925 &   0.685 &   0.892

*************** CYP1A2_Veith ***************


Done!
Cleaning mols: 100%|██████████| 8805/8805 [00:24<00:00, 365.82it/s]
Cleaning mols: 100%|██████████| 1258/1258 [00:03<00:00, 370.34it/s]
Cleaning mols: 100%|██████████| 2516/2516 [00:06<00:00, 365.49it/s]


Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded
**************************************** ['CYP1A2_Veith'] ****************************************
---> start training...
Epoch:0, [Train] Loss: 0.663
Epoch:0, [Valid] Loss: 0.579
Epoch:1, [Train] Loss: 0.539
Epoch:1, [Valid] Loss: 0.379
Epoch:2, [Train] Loss: 0.489
Epoch:2, [Valid] Loss: 0.346
Epoch:3, [Train] Loss: 0.484
Epoch:3, [Valid] Loss: 0.338
Epoch:4, [Train] Loss: 0.484
Epoch:4, [Valid] Loss: 0.333
prev min loss 10000.000, now loss 0.333 | save model at epoch: 4
*************** CYP1A2_Veith ***************
AUROC: 0.9463 AUPRC: 0.9426 F1: 0.8696
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.878  &  0.878  &          0.861  &     0.878  &0.879  &0.870 &0.946 &   0.756 &   0.943

Epoch:5, [Train] Loss: 0.479
Epoch:5, [Valid] Loss: 0.323
prev min loss 0.333, now loss 0.323 | save

100%|██████████| 20/20 [00:05<00:00,  3.63it/s]
Found local copy...
Loading...


total_loss: 0.298
*************** CYP1A2_Veith ***************
AUROC: 0.9510 AUPRC: 0.9477 F1: 0.8756
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.883  &  0.882  &          0.877  &     0.874  &0.891  &0.876 &0.951 &   0.765 &   0.948

*************** CYP2C9_Veith ***************


Done!
Cleaning mols: 100%|██████████| 8465/8465 [00:23<00:00, 355.46it/s]
Cleaning mols: 100%|██████████| 1209/1209 [00:03<00:00, 354.11it/s]
Cleaning mols: 100%|██████████| 2418/2418 [00:06<00:00, 348.25it/s]


Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded
**************************************** ['CYP2C9_Veith'] ****************************************
---> start training...
Epoch:0, [Train] Loss: 0.656
Epoch:0, [Valid] Loss: 0.580
Epoch:1, [Train] Loss: 0.569
Epoch:1, [Valid] Loss: 0.450
Epoch:2, [Train] Loss: 0.522
Epoch:2, [Valid] Loss: 0.384
Epoch:3, [Train] Loss: 0.505
Epoch:3, [Valid] Loss: 0.373
Epoch:4, [Train] Loss: 0.496
Epoch:4, [Valid] Loss: 0.360
prev min loss 10000.000, now loss 0.360 | save model at epoch: 4
*************** CYP2C9_Veith ***************
AUROC: 0.9273 AUPRC: 0.8594 F1: 0.7914
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.856  &  0.842  &          0.784  &     0.799  &0.886  &0.791 &0.927 &   0.682 &   0.859

Epoch:5, [Train] Loss: 0.502
Epoch:5, [Valid] Loss: 0.366
EarlyStopping counter: 1 out of 10
Epoch:6

100%|██████████| 19/19 [00:05<00:00,  3.66it/s]

total_loss: 0.315
*************** CYP2C9_Veith ***************
AUROC: 0.9387 AUPRC: 0.8746 F1: 0.8080
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.873  &  0.852  &          0.828  &     0.789  &0.916  &0.808 &0.939 &   0.714 &   0.875






In [None]:
# from os import walk
# import os
# files = next(walk('/content/drive/MyDrive/ADMET/'), (None, None, []))[2]
# for file in files:
#     if isinstance(file, str):
#         file_type = file.split('.')[-1]
#         # print(file_type)
#         if file_type == 'bin' or file_type == 'pth':
#             os.remove(file)

In [None]:
dropout = 0.3
/content/drive/MyDrive/ADMET/


IS_R = False
    config = {'GIN_out_dim': 256,
            'hid_dims': [512],
            'out_dim': len(met_cls),
            'prop_names': met_cls,
            'dropout': 0.3,
            'IS_R': IS_R,
            'lr': 1e-4,
            'wd':1e-5,
            'patience': 10,
            'model_path': f'ckpt_{met_cls[0]}.pt'}

*************** CYP2C19_Veith ***************
AUROC: 0.9281 AUPRC: 0.9143 F1: 0.8508
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.859  &  0.860  &          0.828  &     0.875  &0.845  &0.851 &0.928 &   0.718 &   0.914

*************** CYP2D6_Veith ***************
AUROC: 0.9058 AUPRC: 0.7542 F1: 0.6735
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.891  &  0.780  &          0.764  &     0.602  &0.957  &0.674 &0.906 &   0.616 &   0.754

*************** CYP3A4_Veith ***************
AUROC: 0.9247 AUPRC: 0.8918 F1: 0.8165
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.846  &  0.844  &          0.804  &     0.829  &0.859  &0.816 &0.925 &   0.685 &   0.892

*************** CYP1A2_Veith ***************
AUROC: 0.9510 AUPRC: 0.9477 F1: 0.8756
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.883  &  0.882  &          0.877  &     0.874  &0.891  &0.876 &0.951 &   0.765 &   0.948

*************** CYP2C9_Veith ***************
AUROC: 0.9387 AUPRC: 0.8746 F1: 0.8080
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.873  &  0.852  &          0.828  &     0.789  &0.916  &0.808 &0.939 &   0.714 &   0.875

In [None]:
import os

from scripts.preprocess_mols import collect_data

# Train one single-output GIN classifier per CYP metabolism task
# (dropout 0.1, checkpoints under GIN_SO_DR_0.1/).
met_cls_all = ['CYP2C19_Veith', 'CYP2D6_Veith', 'CYP3A4_Veith',
            'CYP1A2_Veith', 'CYP2C9_Veith']
IS_R = False # False -> classification task, True -> regression task
batch_size = 128

for task in met_cls_all:

    met_cls = [task]  # the helpers expect a list of task names
    trains, valids, tests = collect_data(met_cls, IS_R=IS_R)

    # Only the training data needs to be shuffled.
    loader_params = {'batch_size': batch_size, 'shuffle': True}
    p = {'batch_size': batch_size, 'shuffle': False}

    train_loader = get_GIN_dataloader(GIN_dataset(trains, met_cls), loader_params)
    valid_loader = get_GIN_dataloader(GIN_dataset(valids, met_cls), dict(p))
    test_loader = get_GIN_dataloader(GIN_dataset(tests, met_cls), dict(p))

    config = {'GIN_out_dim': 256,
            'hid_dims': [512],
            'out_dim': len(met_cls),
            'prop_names': met_cls,
            'dropout': 0.1,
            'IS_R': IS_R,
            'lr': 1e-4,
            'wd':1e-5,
            'patience': 10,
            'model_path': f'GIN_SO_DR_0.1/ckpt_{met_cls[0]}.pt'}

    # The checkpoint directory must exist before the first save.
    os.makedirs(os.path.dirname(config['model_path']), exist_ok=True)

    models = PRED(**config)
    print("*"*40, met_cls, "*"*40)
    print('---> start training...')
    models.train(train_loader, valid_loader, test_loader)


Found local copy...
Loading...
Done!


*************** CYP2C19_Veith ***************


Cleaning mols: 100%|██████████| 8866/8866 [00:24<00:00, 356.81it/s]
Cleaning mols: 100%|██████████| 1266/1266 [00:03<00:00, 357.01it/s]
Cleaning mols: 100%|██████████| 2533/2533 [00:07<00:00, 342.30it/s]


Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded
**************************************** ['CYP2C19_Veith'] ****************************************
---> start training...
Epoch:0, [Train] Loss: 0.656
Epoch:0, [Valid] Loss: 0.561
Epoch:1, [Train] Loss: 0.515
Epoch:1, [Valid] Loss: 0.396
Epoch:2, [Train] Loss: 0.468
Epoch:2, [Valid] Loss: 0.373
Epoch:3, [Train] Loss: 0.461
Epoch:3, [Valid] Loss: 0.359
Epoch:4, [Train] Loss: 0.459
Epoch:4, [Valid] Loss: 0.359
EarlyStopping counter: 1 out of 10
prev min loss 10000.000, now loss 0.359 | save model at epoch: 4
*************** CYP2C19_Veith ***************
AUROC: 0.9247 AUPRC: 0.9132 F1: 0.8506
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.856  &  0.857  &          0.829  &     0.874  &0.841  &0.851 &0.925 &   0.713 &   0.913

Epoch:5, [Train] Loss: 0.456
Epoch:5, [Valid] Loss: 0.354
prev 

100%|██████████| 20/20 [00:05<00:00,  3.65it/s]
Found local copy...
Loading...
Done!


total_loss: 0.340
*************** CYP2C19_Veith ***************
AUROC: 0.9299 AUPRC: 0.9149 F1: 0.8537
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.861  &  0.863  &          0.826  &     0.883  &0.842  &0.854 &0.930 &   0.723 &   0.915

*************** CYP2D6_Veith ***************


Cleaning mols: 100%|██████████| 9191/9191 [00:25<00:00, 359.11it/s]
Cleaning mols: 100%|██████████| 1313/1313 [00:03<00:00, 353.50it/s]
Cleaning mols: 100%|██████████| 2626/2626 [00:07<00:00, 363.75it/s]


Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded
**************************************** ['CYP2D6_Veith'] ****************************************
---> start training...
Epoch:0, [Train] Loss: 0.594
Epoch:0, [Valid] Loss: 0.445
Epoch:1, [Train] Loss: 0.442
Epoch:1, [Valid] Loss: 0.367
Epoch:2, [Train] Loss: 0.402
Epoch:2, [Valid] Loss: 0.330
Epoch:3, [Train] Loss: 0.388
Epoch:3, [Valid] Loss: 0.310
Epoch:4, [Train] Loss: 0.376
Epoch:4, [Valid] Loss: 0.300
prev min loss 10000.000, now loss 0.300 | save model at epoch: 4
*************** CYP2D6_Veith ***************
AUROC: 0.9035 AUPRC: 0.7688 F1: 0.6625
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.877  &  0.776  &          0.728  &     0.608  &0.944  &0.662 &0.903 &   0.592 &   0.769

Epoch:5, [Train] Loss: 0.382
Epoch:5, [Valid] Loss: 0.296
prev min loss 0.300, now loss 0.296 | save

100%|██████████| 21/21 [00:05<00:00,  3.80it/s]
Found local copy...
Loading...
Done!


total_loss: 0.279
*************** CYP2D6_Veith ***************
AUROC: 0.9051 AUPRC: 0.7556 F1: 0.6757
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.890  &  0.784  &          0.750  &     0.615  &0.953  &0.676 &0.905 &   0.615 &   0.756

*************** CYP3A4_Veith ***************


Cleaning mols: 100%|██████████| 8629/8629 [00:24<00:00, 358.89it/s]
Cleaning mols: 100%|██████████| 1233/1233 [00:03<00:00, 360.55it/s]
Cleaning mols: 100%|██████████| 2466/2466 [00:06<00:00, 364.87it/s]


Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded
**************************************** ['CYP3A4_Veith'] ****************************************
---> start training...
Epoch:0, [Train] Loss: 0.660
Epoch:0, [Valid] Loss: 0.579
Epoch:1, [Train] Loss: 0.555
Epoch:1, [Valid] Loss: 0.429
Epoch:2, [Train] Loss: 0.494
Epoch:2, [Valid] Loss: 0.381
Epoch:3, [Train] Loss: 0.483
Epoch:3, [Valid] Loss: 0.368
Epoch:4, [Train] Loss: 0.477
Epoch:4, [Valid] Loss: 0.361
prev min loss 10000.000, now loss 0.361 | save model at epoch: 4
*************** CYP3A4_Veith ***************
AUROC: 0.9239 AUPRC: 0.9020 F1: 0.8197
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.846  &  0.845  &          0.803  &     0.837  &0.852  &0.820 &0.924 &   0.686 &   0.902

Epoch:5, [Train] Loss: 0.472
Epoch:5, [Valid] Loss: 0.364
EarlyStopping counter: 1 out of 10
Epoch:6

100%|██████████| 20/20 [00:05<00:00,  3.72it/s]
Found local copy...
Loading...
Done!


total_loss: 0.354
*************** CYP3A4_Veith ***************
AUROC: 0.9232 AUPRC: 0.8895 F1: 0.8216
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.851  &  0.848  &          0.812  &     0.831  &0.865  &0.822 &0.923 &   0.694 &   0.890

*************** CYP1A2_Veith ***************


Cleaning mols: 100%|██████████| 8805/8805 [00:24<00:00, 360.18it/s]
Cleaning mols: 100%|██████████| 1258/1258 [00:03<00:00, 365.17it/s]
Cleaning mols: 100%|██████████| 2516/2516 [00:06<00:00, 359.52it/s]


Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded
**************************************** ['CYP1A2_Veith'] ****************************************
---> start training...
Epoch:0, [Train] Loss: 0.659
Epoch:0, [Valid] Loss: 0.565
Epoch:1, [Train] Loss: 0.500
Epoch:1, [Valid] Loss: 0.355
Epoch:2, [Train] Loss: 0.432
Epoch:2, [Valid] Loss: 0.326
Epoch:3, [Train] Loss: 0.421
Epoch:3, [Valid] Loss: 0.313
Epoch:4, [Train] Loss: 0.420
Epoch:4, [Valid] Loss: 0.305
prev min loss 10000.000, now loss 0.305 | save model at epoch: 4
*************** CYP1A2_Veith ***************
AUROC: 0.9463 AUPRC: 0.9419 F1: 0.8661
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.878  &  0.876  &          0.875  &     0.857  &0.895  &0.866 &0.946 &   0.754 &   0.942

Epoch:5, [Train] Loss: 0.414
Epoch:5, [Valid] Loss: 0.301
prev min loss 0.305, now loss 0.301 | save

100%|██████████| 20/20 [00:05<00:00,  3.74it/s]
Found local copy...
Loading...
Done!


total_loss: 0.287
*************** CYP1A2_Veith ***************
AUROC: 0.9509 AUPRC: 0.9478 F1: 0.8793
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.886  &  0.885  &          0.876  &     0.883  &0.888  &0.879 &0.951 &   0.770 &   0.948

*************** CYP2C9_Veith ***************


Cleaning mols: 100%|██████████| 8465/8465 [00:23<00:00, 361.42it/s]
Cleaning mols: 100%|██████████| 1209/1209 [00:03<00:00, 363.11it/s]
Cleaning mols: 100%|██████████| 2418/2418 [00:06<00:00, 364.68it/s]


Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded
**************************************** ['CYP2C9_Veith'] ****************************************
---> start training...
Epoch:0, [Train] Loss: 0.649
Epoch:0, [Valid] Loss: 0.569
Epoch:1, [Train] Loss: 0.534
Epoch:1, [Valid] Loss: 0.429
Epoch:2, [Train] Loss: 0.467
Epoch:2, [Valid] Loss: 0.365
Epoch:3, [Train] Loss: 0.448
Epoch:3, [Valid] Loss: 0.347
Epoch:4, [Train] Loss: 0.451
Epoch:4, [Valid] Loss: 0.342
prev min loss 10000.000, now loss 0.342 | save model at epoch: 4
*************** CYP2C9_Veith ***************
AUROC: 0.9273 AUPRC: 0.8575 F1: 0.8052
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.864  &  0.854  &          0.790  &     0.821  &0.887  &0.805 &0.927 &   0.702 &   0.858

Epoch:5, [Train] Loss: 0.445
Epoch:5, [Valid] Loss: 0.337
prev min loss 0.342, now loss 0.337 | save

100%|██████████| 19/19 [00:05<00:00,  3.50it/s]

total_loss: 0.303
*************** CYP2C9_Veith ***************
AUROC: 0.9383 AUPRC: 0.8730 F1: 0.8100
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.872  &  0.856  &          0.813  &     0.807  &0.905  &0.810 &0.938 &   0.713 &   0.873






In [None]:
*************** CYP2C19_Veith ***************
AUROC: 0.9299 AUPRC: 0.9149 F1: 0.8537
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.861  &  0.863  &          0.826  &     0.883  &0.842  &0.854 &0.930 &   0.723 &   0.915

*************** CYP2D6_Veith ***************
AUROC: 0.9051 AUPRC: 0.7556 F1: 0.6757
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.890  &  0.784  &          0.750  &     0.615  &0.953  &0.676 &0.905 &   0.615 &   0.756

*************** CYP3A4_Veith ***************
AUROC: 0.9232 AUPRC: 0.8895 F1: 0.8216
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.851  &  0.848  &          0.812  &     0.831  &0.865  &0.822 &0.923 &   0.694 &   0.890

*************** CYP1A2_Veith ***************
AUROC: 0.9509 AUPRC: 0.9478 F1: 0.8793
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.886  &  0.885  &          0.876  &     0.883  &0.888  &0.879 &0.951 &   0.770 &   0.948

*************** CYP2C9_Veith ***************
AUROC: 0.9383 AUPRC: 0.8730 F1: 0.8100
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.872  &  0.856  &          0.813  &     0.807  &0.905  &0.810 &0.938 &   0.713 &   0.873