<a href="https://colab.research.google.com/github/yingzibu/MOL2ADMET/blob/main/examples/ADMET_GIN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
! pip install rdkit --quiet
! pip install PyTDC --quiet
! pip install mycolorpy --quiet

! pip install dgllife --quiet
! pip install molvs --quiet
! pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html --quiet
! pip install dglgo -f https://data.dgl.ai/wheels-test/repo.html --quiet

In [3]:
cd /content/drive/MyDrive/ADMET

/content/drive/MyDrive/ADMET


In [4]:
from os import walk
import os
files = next(walk('/content/drive/MyDrive/ADMET/'), (None, None, []))[2]
for file in files:
    if isinstance(file, str):
        file_type = file.split('.')[-1]
        # print(file_type)
        if file_type == 'bin' or file_type == 'pth':
            os.remove(file)

In [None]:
! pip install DeepPurpose --quiet
! pip install git+https://github.com/bp-kelley/descriptastorus --quiet
! pip install pandas-flavor --quiet

In [None]:
from scripts.preprocess_mols import collect_data
met_cls = ['CYP2C19_Veith', 'CYP2D6_Veith', 'CYP3A4_Veith',
            'CYP1A2_Veith', 'CYP2C9_Veith']
IS_R = False # is regression task
trains, valids, tests = collect_data(met_cls, IS_R=False)

In [50]:
from dgllife.model import load_pretrained
from dgl.nn.pytorch.glob import AvgPooling
import torch.nn as nn

device = 'cuda'
class GIN(nn.Module):
    """
    Reference: https://github.com/kexinhuang12345/DeepPurpose/blob/master/DeepPurpose/encoders.py#L392
    """
	## adapted from https://github.com/awslabs/dgl-lifesci/blob/2fbf5fd6aca92675b709b6f1c3bc3c6ad5434e96/examples/property_prediction/moleculenet/utils.py#L76
    def __init__(self, predictor_dim):
        super(GIN, self).__init__()
        self.gnn = load_pretrained('gin_supervised_contextpred')
        self.readout = AvgPooling()
        self.transform = nn.Linear(300, 200)
        self.final = nn.Linear(200, predictor_dim)
        self.out_dim = predictor_dim
    def forward(self, bg):
        # bg = bg.to(device)
        node_feats = [
            bg.ndata.pop('atomic_number'),
            bg.ndata.pop('chirality_type')
        ]
        edge_feats = [
            bg.edata.pop('bond_type'),
            bg.edata.pop('bond_direction_type')
        ]

        node_feats = self.gnn(bg, node_feats, edge_feats)
        graph_feats = self.readout(bg, node_feats)
        graph_feats = self.transform(graph_feats)
        return self.final(graph_feats)

In [12]:
from torch.utils.data import Dataset, DataLoader
from functools import partial
import torch

from dgllife.utils import smiles_to_bigraph, PretrainAtomFeaturizer, PretrainBondFeaturizer
MASK = -100

class GIN_dataset(Dataset):
    def __init__(self, df, names, mask=MASK):
        df = df.fillna(mask)
        self.names = names
        self.df = df
        self.len = len(df)
        self.props = self.df[names]
        self.node_featurizer = PretrainAtomFeaturizer()
        self.edge_featurizer = PretrainBondFeaturizer()
        self.fc = partial(smiles_to_bigraph, add_self_loop=True)
    def __len__(self): return self.len
    def __getitem__(self, idx):
        v_d = self.df.iloc[idx]['Drug']
        v_d = self.fc(smiles=v_d, node_featurizer = self.node_featurizer,
                      edge_featurizer = self.edge_featurizer)
        label = torch.tensor(self.props.iloc[idx], dtype=torch.float32)
        return v_d, label

In [13]:
import dgl
def get_GIN_dataloader(datasets, loader_params):
    def dgl_collate_func(data):
        x, labels = map(list, zip(*data))
        bg = dgl.batch(x)
        labels = torch.stack(labels, dim=0)
        bg.set_n_initializer(dgl.init.zero_initializer)
        bg.set_e_initializer(dgl.init.zero_initializer)
        return bg, labels
    loader_params['collate_fn'] = dgl_collate_func
    return DataLoader(datasets, **loader_params)

In [14]:
batch_size = 128
loader_params ={'batch_size': batch_size, 'shuffle': True}

train_loader = get_GIN_dataloader(GIN_dataset(trains, met_cls), loader_params)
valid_loader = get_GIN_dataloader(GIN_dataset(valids, met_cls), loader_params)
p = {'batch_size': batch_size, 'shuffle': False}
test_loader = get_GIN_dataloader(GIN_dataset(tests, met_cls), p)

In [16]:
from dgllife.utils import EarlyStopping, Meter
from tqdm import tqdm
def train_epoch(epoch, model, loader, loss_func, device,
                optimizer=None, MASK=-100):
    if optimizer==None: model.eval(); train_type='Valid'
    else: model.train(); train_type='Train'
    losses = 0
    # train_meter = Meter()
    for batch in tqdm(loader, total=len(loader), desc=f'Epoch {epoch}'):
        bg, labels = batch
        bg, labels = bg.to(device), labels.to(device)
        mask = labels == MASK
        pred = model(bg)
        loss = loss_func(pred[~mask], labels[~mask])
        del mask
        if optimizer != None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        losses += loss.item()
    total_loss = losses / len(loader.dataset)
    print(f'Epoch:{epoch}, [{train_type}] Loss: {total_loss:.3f}')
    return total_loss

In [30]:
IS_R = False
config = {'GIN_out_dim': len(met_cls),
          'hid_dims': [32],
          'out_dim': len(met_cls),
          'dropout': 0.3,
          'IS_R': IS_R,
          'lr': 1e-4,
          'wd':1e-5}
class PRED:
    def __init__(self, **config):
        cuda = torch.cuda.is_available()
        if cuda: self.device = 'cuda'
        else:    self.device = 'cpu'
        # self.model_drug = GIN(config['GIN_out_dim']).to(self.device)
        # model = TL(self.model_drug, config['hid_dims'],
        #            config['out_dim'], config['dropout'])
        # self.model = model.to(self.device)

        self.model = GIN(config['GIN_out_dim']).to(self.device)

        self.IS_R = config['IS_R']
        if self.IS_R: loss_fn = nn.MSELoss(reduction='sum') # if regression
        else: loss_fn = nn.BCEWithLogitsLoss(reduction='sum') # if classification
        self.loss_fn = loss_fn
        self.optimizer = torch.optim.AdamW(self.model.parameters(),
                    lr=config['lr'], weight_decay=config['wd'])
        self.stopper = EarlyStopping(mode='lower', patience=30)
        self.min_loss = 10000
        self.best_epoch = 0

    def load_model(self, path):
        self.model.load_state_dict(torch.load(path, map_location=self.device))

    def train(self, data_loader, val_data_loader, train_epoch=train_epoch):
        if self.best_epoch != 0:
            self.load_model('ckpt_.pt')
        for epoch in range(500):
            score = train_epoch(epoch, self.model, data_loader, self.loss_fn,
                                self.device, self.optimizer)
            val_score = train_epoch(epoch, self.model, val_data_loader,
                                    self.loss_fn, self.device)
            early_stop = self.stopper.step(val_score, self.model)
            if val_score < self.min_loss and epoch > 3:
                print(f'prev min loss {self.min_loss:.3f}, '
                      f'now loss {val_score:.3f} |',
                      f'save model at epoch: {epoch}')
                self.min_loss = val_score
                torch.save(self.model.state_dict(), 'ckpt_.pt')
                self.best_epoch = epoch
            if early_stop: print('early stop'); break


In [51]:
models = PRED(**config)
# models.load_model('ckpt_tr.pt')

Downloading gin_supervised_contextpred_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gin_supervised_contextpred.pth...
Pretrained model loaded


In [None]:
models.train(train_loader, valid_loader)

In [37]:
print(f"best epoch: {models.best_epoch}, min loss: {models.min_loss:.4f}")

best epoch: 42, min loss: 0.3673


In [75]:
import torch.nn.functional as F
from scripts.eval_utils import *
from scripts.preprocess_mols import *



def eval_AP(model, IS_R, test_loader, names):
    print('Evaluate on test sets')
    # model = model.cpu()
    model.eval()
    total_loss = 0
    y_probs = {}
    y_label = {}
    if IS_R: print('using MSELoss')
    else: print('using BCELOSSwithdigits')
    if IS_R: loss_fn = nn.MSELoss(reduction='sum') # if regression
    else: loss_fn = nn.BCEWithLogitsLoss(reduction='sum') # if classification
    for i, batch_data in tqdm(enumerate(test_loader), total=len(test_loader)):
        bg, labels = batch_data
        # bg, labels = bg.to('cpu'), labels.to('cpu')
        bg, labels = bg.to(device), labels.to(device)
        # print(bg.device)
        pred = model(bg)
        # mask = masks < 1
        mask = labels == MASK
        loss = loss_fn(pred[~mask], labels[~mask])
        # test_meter.update(pred, labels, masks)
        total_loss += loss.item()
        for j, name in enumerate(names):
            probs = F.sigmoid(pred[:, j][~mask[:, j]])
            label = labels[:, j][~mask[:, j]]
            probs = probs.cpu().detach().numpy().tolist()
            label = label.cpu().detach().numpy().tolist()
            if i ==0: y_probs[name], y_label[name] = probs, label
            else:
                y_probs[name] += probs
                y_label[name] += label

    total_loss /= len(test_loader.dataset)
    print(f'total_loss: {total_loss:.3f}')


    if IS_R == False: # classification task
        for i, name in enumerate(names):
            print('*'*15, name, '*'*15)
            probs = y_probs[name]
            label = y_label[name]
            assert len(probs) == len(label)
            preds = get_preds(0.5, probs)
            evaluate(label, preds, probs)

            print()
    return y_probs

In [59]:
models.model.load_state_dict(torch.load('ckpt_.pt', map_location='cuda'))
eval_AP(models.model, False, test_loader, met_cls)

Evaluate on test sets
using BCELOSSwithdigits


100%|██████████| 73/73 [00:20<00:00,  3.65it/s]

total_loss: 0.432
*************** CYP2C19_Veith ***************
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.858  &  0.860  &          0.826  &     0.875  &0.844  &0.850 &0.927 &   0.717 &   0.910
ROC-AUC: 0.860
PR-AUC: 0.781

*************** CYP2D6_Veith ***************
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.883  &  0.760  &          0.747  &     0.564  &0.957  &0.643 &0.902 &   0.583 &   0.741
ROC-AUC: 0.760
PR-AUC: 0.502

*************** CYP3A4_Veith ***************
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.837  &  0.833  &          0.798  &     0.811  &0.856  &0.804 &0.919 &   0.665 &   0.883
ROC-AUC: 0.833
PR-AUC: 0.725

*************** CYP1A2_Veith ***************
Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.885  &  0.884  &          0.887  &     0.866  &0.901  &0.876 &0.950 &   0.769 &




In [90]:
def cal_prob(model, IS_R, test_loader, certain_name:str, names):
    y_probs = {}
    for j, name in enumerate(names):
        if certain_name == name: break
    for i, batch_data in enumerate(test_loader):
        bg, labels = batch_data
        bg, labels = bg.to(device), labels.to(device)
        pred = model(bg)
        mask = labels == MASK
        pred = pred[:, j].reshape(len(pred), 1)
        probs = F.sigmoid(pred[~mask])
        probs = probs.cpu().detach().numpy().tolist()
        if i ==0: y_probs[name] = probs
        else: y_probs[name] += probs
    return y_probs

from tdc.benchmark_group import admet_group
group = admet_group(path = 'data/')
pred_list = []
for seed in [0, 1, 2, 3, 4, 5]:
    print(f'seed : {seed}')
    pred_dict = {}
    for i, name in tqdm(enumerate(met_cls), total=len(met_cls)):
        benchmark = group.get(name)
        name_spec = benchmark['name']
        if name.lower() == name_spec:
            test = benchmark['test']
            test  = rename_cols(test[['Drug', 'Y']],  name)
            test_loader = get_GIN_dataloader(GIN_dataset(test, [name]), p)

            probs = cal_prob(models.model, False, test_loader, name, met_cls)
            preds = get_preds(0.5, probs[name])
            pred_dict[name_spec] = preds
            pred_list.append(pred_dict)

Found local copy...


seed : 0


100%|██████████| 5/5 [00:15<00:00,  3.12s/it]


seed : 1


100%|██████████| 5/5 [00:15<00:00,  3.10s/it]


seed : 2


100%|██████████| 5/5 [00:15<00:00,  3.12s/it]


seed : 3


100%|██████████| 5/5 [00:15<00:00,  3.10s/it]


seed : 4


100%|██████████| 5/5 [00:15<00:00,  3.07s/it]


seed : 5


100%|██████████| 5/5 [00:15<00:00,  3.16s/it]


In [91]:
results = group.evaluate_many(pred_list)
results

{'cyp2d6_veith': [0.551, 0.0],
 'cyp3a4_veith': [0.759, 0.0],
 'cyp2c9_veith': [0.692, 0.0]}