<a href="https://colab.research.google.com/github/yingzibu/MOL2ADMET/blob/main/examples/Graph/ADMET_GIN_MO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive so files under the mount point are visible to this runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
! pip install rdkit --quiet
! pip install PyTDC --quiet
! pip install mycolorpy --quiet

! pip install dgllife --quiet
! pip install molvs --quiet
! pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html --quiet
! pip install dglgo -f https://data.dgl.ai/wheels-test/repo.html --quiet

In [5]:
! pip install DeepPurpose --quiet
! pip install git+https://github.com/bp-kelley/descriptastorus --quiet
! pip install pandas-flavor --quiet

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for descriptastorus (setup.py) ... [?25l[?25hdone


In [1]:
cd /content/drive/MyDrive/ADMET

/content/drive/MyDrive/ADMET


In [2]:
from os import walk
import os

# Delete leftover `.bin` / `.pth` files (for example the `*_pre_trained.pth`
# download reported in this notebook's output) from the project directory.
ADMET_DIR = '/content/drive/MyDrive/ADMET/'
STALE_SUFFIXES = ('.bin', '.pth')

# Only the top-level file names are listed; sub-directories are not visited.
files = next(walk(ADMET_DIR), (None, None, []))[2]
for file in files:
    if file.endswith(STALE_SUFFIXES):
        # Join with the listed directory so removal does not depend on the
        # current working directory.
        os.remove(os.path.join(ADMET_DIR, file))

In [2]:
from dgllife.model import load_pretrained
from dgl.nn.pytorch.glob import AvgPooling
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from functools import partial
import torch
from dgllife.utils import smiles_to_bigraph, PretrainAtomFeaturizer, PretrainBondFeaturizer

MASK = -100  # sentinel written in place of missing (NaN) values

class GIN_dataset(Dataset):
    """Dataset yielding `(graph, label)` pairs for the pretrained-GIN model.

    Args:
        df: DataFrame with a 'Drug' column holding SMILES strings and one
            column per entry of `names`.
        names: list of label column names; their order defines the order of
            the label vector.
        mask: value used to fill NaN entries of `df` (defaults to `MASK`).
    """
    def __init__(self, df, names, mask=MASK):
        df = df.fillna(mask)
        self.names = names
        self.df = df
        self.len = len(df)
        self.props = self.df[names]
        self.node_featurizer = PretrainAtomFeaturizer()
        self.edge_featurizer = PretrainBondFeaturizer()
        self.fc = partial(smiles_to_bigraph, add_self_loop=True)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return the featurized graph and the float32 label tensor of row `idx`."""
        # Index the column first: avoids materialising the whole row.
        smiles = self.df['Drug'].iloc[idx]
        graph = self.fc(smiles=smiles,
                        node_featurizer=self.node_featurizer,
                        edge_featurizer=self.edge_featurizer)
        row = self.props.iloc[idx]
        # Convert explicitly to a NumPy array: handing a label-indexed Series
        # straight to torch.tensor relies on positional Series indexing.
        if hasattr(row, 'to_numpy'):
            row = row.to_numpy(dtype='float32')
        label = torch.tensor(row, dtype=torch.float32)
        return graph, label



In [3]:
import dgl
def get_GIN_dataloader(datasets, loader_params):
    """Build a DataLoader that batches `(graph, label)` samples with DGL.

    Args:
        datasets: dataset whose items are `(DGLGraph, label_tensor)` pairs.
        loader_params: keyword arguments for `DataLoader` (e.g. batch_size,
            shuffle). The dict is copied, so the caller's object is not
            modified; any 'collate_fn' entry is replaced.

    Returns:
        A `DataLoader` yielding `(batched_graph, stacked_labels)`.
    """
    def dgl_collate_func(data):
        graphs, labels = map(list, zip(*data))
        bg = dgl.batch(graphs)
        labels = torch.stack(labels, dim=0)
        bg.set_n_initializer(dgl.init.zero_initializer)
        bg.set_e_initializer(dgl.init.zero_initializer)
        return bg, labels

    params = dict(loader_params)  # work on a copy: do not mutate the argument
    params['collate_fn'] = dgl_collate_func
    return DataLoader(datasets, **params)

In [4]:
from dgllife.utils import EarlyStopping, Meter
from tqdm import tqdm

def train_epoch(epoch, model, loader, loss_func, device,
                optimizer=None, names=None, MASK=-100):
    """Run one pass over `loader`: training if `optimizer` is given, else validation.

    Args:
        epoch: epoch number, used only for the progress bar and the log line.
        model: module called as `model(batched_graph)`.
        loader: iterable of `(batched_graph, labels)` batches exposing
            `loader.dataset`.
        loss_func: callable `loss_func(pred, target)`; applied to the unmasked
            entries only.
        device: device the batch is moved to.
        optimizer: when not None the model is put in train mode and updated
            after every batch; when None the model is put in eval mode and
            gradients are disabled.
        names: optional list of task names, one per output column. When given,
            per-task sigmoid outputs and labels are collected.
        MASK: label value marking a missing entry; such entries are excluded
            from the loss and from the collected outputs.

    Returns:
        `total_loss` (sum of batch losses divided by `len(loader.dataset)`)
        when `names` is None, otherwise `(total_loss, y_probs, y_label)` where
        both dicts map task name -> list of floats. The tuple is returned
        whenever `names` is given, in training mode as well.
    """
    is_training = optimizer is not None
    if is_training: model.train(); train_type = 'Train'
    else: model.eval(); train_type = 'Valid'
    losses = 0
    y_probs = {}
    y_label = {}
    for idx, batch in tqdm(enumerate(loader), total=len(loader), desc=f'Epoch {epoch}'):
        bg, labels = batch
        bg, labels = bg.to(device), labels.to(device)
        mask = labels == MASK
        # No autograd graph is needed (or kept) during validation.
        with torch.set_grad_enabled(is_training):
            pred = model(bg)
            loss = loss_func(pred[~mask], labels[~mask])
        losses += loss.item()
        if is_training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if names is not None:
            for j, name in enumerate(names):
                keep = ~mask[:, j]
                # torch.sigmoid: avoids depending on `F`, which this notebook
                # imports only in a later cell.
                probs = torch.sigmoid(pred[:, j][keep])
                label = labels[:, j][keep]
                y_probs.setdefault(name, []).extend(probs.detach().cpu().numpy().tolist())
                y_label.setdefault(name, []).extend(label.detach().cpu().numpy().tolist())

    total_loss = losses / len(loader.dataset)
    print(f'Epoch:{epoch}, [{train_type}] Loss: {total_loss:.3f}')
    if names is None: return total_loss
    return total_loss, y_probs, y_label


In [5]:
import torch.nn.functional as F
from scripts.eval_utils import *
from scripts.preprocess_mols import *
from sklearn.metrics import mean_squared_error as MSE
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

device = 'cuda'

def eval_dict(y_probs:dict, y_label:dict, names:list, IS_R=False):
    """Print per-task classification metrics; does nothing for regression.

    Args:
        y_probs: task name -> list of predicted scores.
        y_label: task name -> list of ground-truth labels (same length).
        names: task names to report, in order.
        IS_R: metrics are printed only when this compares equal to False.
    """
    if IS_R != False:
        return  # regression: nothing is reported here

    banner = '*'*15
    for name in names:
        print(banner, name, banner)
        probs, label = y_probs[name], y_label[name]
        assert len(probs) == len(label)
        preds = get_preds(0.5, probs)
        auroc = roc_auc_score(label, probs)
        auprc = average_precision_score(label, probs)
        f1 = f1_score(label, preds)
        print(f'AUROC: {auroc:.4f}',
              f'AUPRC: {auprc:.4f}',
              f'F1: {f1:.4f}')
        evaluate(label, preds, probs)
        print()


def eval_AP(model, IS_R, test_loader, names, device=device):
    """Evaluate `model` on `test_loader`, print metrics and return raw outputs.

    Args:
        model: module called as `model(batched_graph)`.
        IS_R: True selects a summed MSE loss, False a summed
            BCE-with-logits loss.
        test_loader: iterable of `(batched_graph, labels)` batches exposing
            `test_loader.dataset`.
        names: task names, one per output column.
        device: device the batch is moved to (defaults to the module-level
            `device` at definition time).

    Returns:
        `(y_probs, y_label)`: dicts mapping task name -> list of sigmoid
        outputs / labels, with entries equal to `MASK` left out.
    """
    model.eval()
    total_loss = 0
    y_probs = {}
    y_label = {}
    if IS_R:
        print('using MSELoss')
        loss_fn = nn.MSELoss(reduction='sum') # if regression
    else:
        print('using BCELOSSwithdigits')
        loss_fn = nn.BCEWithLogitsLoss(reduction='sum') # if classification

    # Pure inference: skip building autograd graphs.
    with torch.no_grad():
        for i, batch_data in tqdm(enumerate(test_loader), total=len(test_loader)):
            bg, labels = batch_data
            bg, labels = bg.to(device), labels.to(device)
            pred = model(bg)
            mask = labels == MASK
            loss = loss_fn(pred[~mask], labels[~mask])
            total_loss += loss.item()
            for j, name in enumerate(names):
                keep = ~mask[:, j]
                # NOTE(review): sigmoid is applied even when IS_R is True;
                # confirm that is intended for regression outputs.
                probs = torch.sigmoid(pred[:, j][keep])
                label = labels[:, j][keep]
                y_probs.setdefault(name, []).extend(probs.cpu().numpy().tolist())
                y_label.setdefault(name, []).extend(label.cpu().numpy().tolist())

    # Summed loss normalised by the number of samples in the dataset.
    total_loss /= len(test_loader.dataset)
    print(f'total_loss: {total_loss:.3f}')

    eval_dict(y_probs, y_label, names, IS_R)

    return y_probs, y_label




In [6]:

class GIN_MOD(nn.Module):
    """Pretrained GIN encoder followed by an MLP prediction head.

    Reference: https://github.com/kexinhuang12345/DeepPurpose/blob/master/DeepPurpose/encoders.py#L392
    Adapted from https://github.com/awslabs/dgl-lifesci/blob/2fbf5fd6aca92675b709b6f1c3bc3c6ad5434e96/examples/property_prediction/moleculenet/utils.py#L76

    Config keys read here:
        GIN_out_dim: width of the projection applied to the pooled graph embedding.
        hid_dims: list of hidden layer widths (may be empty).
        out_dim: number of outputs.
        dropout: dropout probability.
    """
    def __init__(self, **config):
        super().__init__()
        self.gnn = load_pretrained('gin_supervised_contextpred')
        self.readout = AvgPooling()
        # 300 is presumably the embedding width of the pretrained GIN — confirm.
        self.transform = nn.Linear(300, config['GIN_out_dim'])
        self.dropout = nn.Dropout(config['dropout'])
        self.hidden_dims = config['hid_dims']
        self.out_dim = config['out_dim']
        neurons = [config['GIN_out_dim'], *self.hidden_dims]
        self.hidden = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(neurons[:-1], neurons[1:])
        )
        # neurons[-1] equals hid_dims[-1] when hidden layers exist and falls
        # back to GIN_out_dim when `hid_dims` is empty.
        self.final = nn.Linear(neurons[-1], self.out_dim)

    def forward(self, bg):
        """Return the `(batch, out_dim)` outputs for the batched graph `bg`."""
        # Read (rather than pop) the features so `bg` is left intact and can be
        # passed through the model again.
        node_feats = [
            bg.ndata['atomic_number'],
            bg.ndata['chirality_type']
        ]
        edge_feats = [
            bg.edata['bond_type'],
            bg.edata['bond_direction_type']
        ]

        node_feats = self.gnn(bg, node_feats, edge_feats)
        x = self.readout(bg, node_feats)
        x = self.transform(x)
        for layer in self.hidden: x = F.leaky_relu(layer(x))
        x = self.final(x)
        # NOTE(review): dropout is applied to the final outputs themselves;
        # confirm this is intended.
        return self.dropout(x)


In [19]:


class PRED:
    """Training / evaluation wrapper around `GIN_MOD`.

    Config keys read here (in addition to those consumed by `GIN_MOD`):
        prop_names: list of task names, one per model output.
        IS_R: True for regression (summed MSE), False for classification
            (summed BCE-with-logits).
        lr, wd: AdamW learning rate and weight decay.
        patience: patience passed to `EarlyStopping`.
        model_path: checkpoint file written by `train` and read by `load_model`.
        max_epochs: optional upper bound on training epochs (default 500).
    """
    def __init__(self, **config):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.prop_names = config['prop_names']
        self.model = GIN_MOD(**config).to(self.device)
        self.config = config
        self.IS_R = config['IS_R']
        if self.IS_R: loss_fn = nn.MSELoss(reduction='sum') # if regression
        else: loss_fn = nn.BCEWithLogitsLoss(reduction='sum') # if classification
        self.loss_fn = loss_fn
        self.optimizer = torch.optim.AdamW(self.model.parameters(),
                    lr=config['lr'], weight_decay=config['wd'])
        self.stopper = EarlyStopping(mode='lower', patience=config['patience'])
        self.min_loss = 10000
        self.best_epoch = 0

    def load_model(self, path):
        """Load the weights stored at `path` into the current model."""
        print('load pretrained model from ', path)
        # Load in place: building a new GIN_MOD here would leave
        # `self.optimizer` bound to the parameters of the discarded model.
        state = torch.load(path, map_location=self.device)
        self.model.load_state_dict(state)

    def eval(self, loader, path=None):
        """Evaluate on `loader`, first loading the checkpoint at `path` if given."""
        if path is not None: self.load_model(path)
        # Pass the device explicitly so eval_AP does not fall back to its
        # own default device.
        eval_AP(self.model, self.IS_R, loader, self.prop_names,
                device=self.device)

    def train(self, data_loader, val_loader, test_loader=None,
              train_epoch=train_epoch):
        """Train with early stopping, checkpointing on validation-loss improvement.

        Args:
            data_loader: training batches.
            val_loader: validation batches.
            test_loader: optional; when given, the saved checkpoint is
                evaluated on it after training.
            train_epoch: per-epoch routine (same contract as the module-level
                `train_epoch`).
        """
        import os  # local import: only needed for the checkpoint check below

        # Resume from the saved checkpoint when a previous call stored one.
        if self.best_epoch != 0: self.load_model(self.config['model_path'])

        for epoch in range(self.config.get('max_epochs', 500)):
            score = train_epoch(epoch, self.model, data_loader, self.loss_fn,
                                self.device, self.optimizer)
            val_score, probs, labels = \
                    train_epoch(epoch, self.model, val_loader, self.loss_fn,
                                self.device, names=self.prop_names)

            early_stop = self.stopper.step(val_score, self.model)
            if val_score < self.min_loss:
                eval_dict(probs, labels, self.prop_names, IS_R=self.IS_R)
                # Checkpoints are only written after the first few epochs.
                if epoch > 3:
                    print(f'prev min loss {self.min_loss:.3f}, '
                        f'now loss {val_score:.3f} |',
                        f'save model at epoch: {epoch}')
                    torch.save(self.model.state_dict(), self.config['model_path'])
                    self.best_epoch = epoch
                self.min_loss = val_score
            if early_stop: print('early stop'); break

        print(f"best epoch: {self.best_epoch}, min loss: {self.min_loss:.4f}")
        print()
        if test_loader is not None:
            ckpt = self.config['model_path']
            if os.path.exists(ckpt):
                self.eval(test_loader, ckpt)
            else:
                # No checkpoint was ever written (e.g. the best epoch was <= 3):
                # evaluate the in-memory weights instead of failing on load.
                print(f'no checkpoint found at {ckpt}; evaluating current weights')
                self.eval(test_loader)



In [None]:
from scripts.preprocess_mols import collect_data
met_cls = ['CYP2C19_Veith', 'CYP2D6_Veith', 'CYP3A4_Veith',
            'CYP1A2_Veith', 'CYP2C9_Veith']
IS_R = False # classification task (True would select regression)

# Use the flag defined above instead of a second hard-coded literal.
trains, valids, tests = collect_data(met_cls, IS_R=IS_R)
batch_size = 128
loader_params ={'batch_size': batch_size, 'shuffle': True}
p = {'batch_size': batch_size, 'shuffle': False}

# Only the training data is shuffled; validation and test keep a fixed order.
train_loader = get_GIN_dataloader(GIN_dataset(trains, met_cls), loader_params)
valid_loader = get_GIN_dataloader(GIN_dataset(valids, met_cls), p)
test_loader = get_GIN_dataloader(GIN_dataset(tests, met_cls), p)


In [21]:
# Hyper-parameters for the multi-task classifier and its trainer.
IS_R = False
config = dict(
    GIN_out_dim=256,
    hid_dims=[512],
    out_dim=len(met_cls),
    prop_names=met_cls,
    dropout=0.1,
    IS_R=IS_R,
    lr=5e-5,
    wd=1e-5,
    patience=30,
    model_path='ckpt_GIN_MO.pt',
)

models = PRED(**config)

Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded


In [None]:
# Fit on the training split, monitoring the validation split.
models.train(data_loader=train_loader, val_loader=valid_loader)

In [23]:
# fine tuned model
config = {'GIN_out_dim': 256,
          'hid_dims': [512],
          'out_dim': len(met_cls),
          'prop_names': met_cls,
          'dropout': 0.1,
          'IS_R': IS_R,
          'lr': 5e-5,
          'wd':1e-5,
          'patience': 30,
          'model_path': f'ckpt_GIN_MO.pt'}

models = PRED(**config)
print(f"best epoch: {models.best_epoch}, min loss: {models.min_loss:.4f}")
print()
models.eval(test_loader, 'ckpt_GIN_MO.pt')

best epoch: 269, min loss: 0.3535

Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded
load pretrained model from  ckpt_GIN_MO.pt
using BCELOSSwithdigits


100%|██████████| 73/73 [00:21<00:00,  3.44it/s]

total_loss: 0.420
*************** CYP2C19_Veith ***************
AUROC: 0.9303 AUPRC: 0.9159 F1: 0.8522
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.863  &  0.863  &          0.845  &     0.860  &0.866  &0.852 &0.930 &   0.725 &   0.916

*************** CYP2D6_Veith ***************
AUROC: 0.9078 AUPRC: 0.7711 F1: 0.6833
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.894  &  0.787  &          0.766  &     0.617  &0.957  &0.683 &0.908 &   0.626 &   0.771

*************** CYP3A4_Veith ***************
AUROC: 0.9211 AUPRC: 0.8868 F1: 0.8063
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.845  &  0.836  &          0.832  &     0.782  &0.890  &0.806 &0.921 &   0.678 &   0.887

*************** CYP1A2_Veith ***************
AUROC: 0.9516 AUPRC: 0.9477 F1: 0.8777
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.884  &  0.




In [18]:
# prev trained model
print(f"best epoch: {models.best_epoch}, min loss: {models.min_loss:.4f}")
print()
models.eval(test_loader, 'ckpt_.pt')

best epoch: 0, min loss: 0.4084

Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded
load pretrained model from  ckpt_.pt
using BCELOSSwithdigits


100%|██████████| 73/73 [00:20<00:00,  3.50it/s]

total_loss: 0.476
*************** CYP2C19_Veith ***************
AUROC: 0.9263 AUPRC: 0.9109 F1: 0.8482
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.855  &  0.857  &          0.817  &     0.881  &0.833  &0.848 &0.926 &   0.712 &   0.911

*************** CYP2D6_Veith ***************
AUROC: 0.9005 AUPRC: 0.7450 F1: 0.6331
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.883  &  0.751  &          0.763  &     0.541  &0.962  &0.633 &0.900 &   0.578 &   0.745

*************** CYP3A4_Veith ***************
AUROC: 0.9191 AUPRC: 0.8819 F1: 0.8093
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.841  &  0.838  &          0.799  &     0.820  &0.855  &0.809 &0.919 &   0.673 &   0.882

*************** CYP1A2_Veith ***************
AUROC: 0.9482 AUPRC: 0.9444 F1: 0.8755
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.883  &  0.




Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded
load pretrained model from  ckpt_.pt
Evaluate on test sets
using BCELOSSwithdigits


100%|██████████| 21/21 [00:05<00:00,  3.61it/s]

total_loss: 0.302
*************** CYP2D6_Veith ***************
AUROC: 0.9052 AUPRC: 0.7518 F1: 0.6485
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.887  &  0.761  &          0.768  &     0.561  &0.961  &0.649 &0.905 &   0.593 &   0.752
ROC-AUC: 0.761
PR-AUC: 0.512






In [None]:
# Evaluation-only configuration (dropout disabled). PRED.__init__ also reads
# 'prop_names', 'patience' and 'model_path', so they must be present here.
conf_e = {'GIN_out_dim': 256,
          'hid_dims': [512],
          'out_dim': len(met_cls),
          'prop_names': met_cls,
          'dropout': 0.,
          'IS_R': IS_R,
          'lr': 1e-4,
          'wd':1e-5,
          'patience': 30,
          'model_path': 'ckpt_.pt'}
models = PRED(**conf_e)

# Map the checkpoint onto whatever device the wrapper selected.
models.model.load_state_dict(
    torch.load(conf_e['model_path'], map_location=models.device))
# Trailing `;` keeps the returned (y_probs, y_label) dicts out of the cell output.
eval_AP(models.model, False, test_loader, met_cls, device=models.device);

Evaluate on test sets
using BCELOSSwithdigits


100%|██████████| 73/73 [00:20<00:00,  3.65it/s]

total_loss: 0.432
*************** CYP2C19_Veith ***************
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.858  &  0.860  &          0.826  &     0.875  &0.844  &0.850 &0.927 &   0.717 &   0.910
ROC-AUC: 0.860
PR-AUC: 0.781

*************** CYP2D6_Veith ***************
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.883  &  0.760  &          0.747  &     0.564  &0.957  &0.643 &0.902 &   0.583 &   0.741
ROC-AUC: 0.760
PR-AUC: 0.502

*************** CYP3A4_Veith ***************
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.837  &  0.833  &          0.798  &     0.811  &0.856  &0.804 &0.919 &   0.665 &   0.883
ROC-AUC: 0.833
PR-AUC: 0.725

*************** CYP1A2_Veith ***************
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.885  &  0.884  &          0.887  &     0.866  &0.901  &0.876 &0.950 &   0.769 &




In [None]:
# def cal_prob(model, IS_R, test_loader, certain_name:str, names):
#     y_probs = {}
#     for j, name in enumerate(names):
#         if certain_name == name: break
#     for i, batch_data in enumerate(test_loader):
#         bg, labels = batch_data
#         bg, labels = bg.to(device), labels.to(device)
#         pred = model(bg)
#         mask = labels == MASK
#         pred = pred[:, j].reshape(len(pred), 1)
#         probs = F.sigmoid(pred[~mask])
#         probs = probs.cpu().detach().numpy().tolist()
#         if i ==0: y_probs[name] = probs
#         else: y_probs[name] += probs
#     return y_probs

# from tdc.benchmark_group import admet_group
# group = admet_group(path = 'data/')
# pred_list = []
# for seed in [0, 1, 2, 3, 4, 5]:
#     print(f'seed : {seed}')
#     pred_dict = {}
#     for i, name in tqdm(enumerate(met_cls), total=len(met_cls)):
#         benchmark = group.get(name)
#         name_spec = benchmark['name']
#         if name.lower() == name_spec:
#             test = benchmark['test']
#             test  = rename_cols(test[['Drug', 'Y']],  name)
#             test_loader = get_GIN_dataloader(GIN_dataset(test, [name]), p)

#             probs = cal_prob(models.model, False, test_loader, name, met_cls)
#             preds = get_preds(0.5, probs[name])
#             pred_dict[name_spec] = preds
#             pred_list.append(pred_dict)