<a href="https://colab.research.google.com/github/yingzibu/MOL2ADMET/blob/main/examples/chembert/chemberta_train_cuda.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pubchempy --quiet
!pip install rdkit --quiet
!pip uninstall transformers -y
!pip install transformers==4.30.2  --quiet

In [1]:
cd /content/drive/MyDrive/chemberta

/content/drive/MyDrive/chemberta


In [2]:
import torch
import os
from tqdm import tqdm
import math
import pandas as pd
import numpy as np
from torch.nn import init
import torch.nn.functional as F
import argparse
import random
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import classification_report, confusion_matrix, average_precision_score, roc_auc_score

MPNN   https://arxiv.org/pdf/1704.01212v2.pdf   

chemberta-2 https://github.com/deepchem/deepchem/blob/master/examples/tutorials/Transfer_Learning_With_ChemBERTa_Transformers.ipynb

vae code explained https://avandekleut.github.io/vae/




In [3]:
from help_function import *
from const import *
from chembertforclassification import *
from bert_encoder import *

class jak_dataset(Dataset):
    def __init__(self, dataframe):
        super(jak_dataset, self).__init__()

        self.len = len(dataframe)
        self.dataframe = dataframe

    def __getitem__(self, idx):
        try:
            sml = self.dataframe.Smiles[idx]
        except:
            sml = self.dataframe.SMILES[idx]

        y = 1 if self.dataframe.Activity[idx] == 1 else 0

        return sml, y

    def __len__(self):
        return self.len

def data_load(df, params={'batch_size': 16, 'shuffle': True,
                          'drop_last': False, 'num_workers': 0}):
    reset_df = df.reset_index(drop=True)
    data = jak_dataset(reset_df)
    loader = DataLoader(data, **params)
    return loader

If cuda available: True
cuda


In [4]:
data_path = 'data_0.25uM/'
if torch.cuda.is_available(): device='cuda'
else: device='cpu'

enzymes = ['JAK1', 'JAK2', 'JAK3', 'TYK2']
model_path = 'model_epoch_20/'
create_path(model_path) # if model_path exists, ignore, if not,create it

In [5]:
enzyme = 'JAK1'


In [7]:
train_df

Unnamed: 0,SMILES,TYK2,pIC50_TYK2,Activity
0,[2H]C([2H])([2H])N=C(O)c1cnc(N=C(O)C2CC2)cc1Nc...,13,-1.113943,1
1,[2H]C([2H])([2H])NC(=O)c1cnc(NC(=O)C2CC2)cc1Nc...,13,-1.113943,1
2,[2H]C([2H])([2H])NC(=O)c1n[nH]c(=NC(=O)C2CC2)c...,50000,-4.698970,0
3,[2H]C([2H])([2H])NC(=O)c1n[nH]c(=Nc2ccc(F)cn2)...,200,-2.301030,1
4,[2H]C([2H])([2H])NC(=O)c1nnc(NC(=O)C2CC2)cc1Nc...,1.3,-0.113943,1
...,...,...,...,...
1950,ON=C(O)CCCCCCOc1ccc2cc1COC/C=C/COCc1cccc(c1)-c...,10,-1.000000,1
1951,ON=C(O)CCCCCn1cc(Nc2ncc(Cl)c(Nc3ccc(Cl)cc3)n2)cn1,49,-1.690196,1
1952,S=c1[nH]c(C2CCCCC2)c2c3cc[nH]c3ncn12,36,-1.556303,1
1953,SC[C@H]1CC[C@H](c2nnn3cnc4[nH]ccc4c23)CC1,9.7,-0.986772,1


In [6]:
for enzyme in enzymes:
    print()
    ind = enzymes.index(enzyme)
    print(enzyme)
    train_df = pd.read_csv(data_path + enzyme + '_train.csv')
    valid_df = pd.read_csv(data_path + enzyme + '_valid.csv')
    # print(train_df['Activity'].value_counts())
    weight_dict = {1: torch.tensor([3.0, 1.0]), 2: torch.tensor([2.0, 1.0]),
                   3: torch.tensor([2.0, 1.0]), 4: torch.tensor([2.0, 1.0])}
    params = {'batch_size': 16, 'shuffle': True,
              'drop_last': False, 'num_workers': 0}

    train_loader = data_load(train_df, params)
    valid_loader = data_load(valid_df, params)
    epoches = 20
    model_name = f'chembert_{enzyme}.pt'
    file_exist = os.path.isfile(model_path + model_name)
    print(f"{model_path+model_name} existence: ", file_exist)

    if device == 'cuda':
        model = chembertforclassification().cuda()
    else: model = chembertforclassification()

    if file_exist: # load model from predefined model_path
        model.load_state_dict(torch.load(model_path+model_name,
                                         map_location=torch.device(device)))
    else: # trained model does not exist, need to train and save
        optimizer = optim.AdamW(params=model.parameters(),
                                lr=1e-5, weight_decay=1e-2)
        loss_function = nn.CrossEntropyLoss(weight=weight_dict[ind+1].cuda())
        model.train()
        for epoch in range(epoches):
            print("EPOCH -- {}".format(epoch))
            total_loss = 0
            for idx, (x, y_true) in tqdm(enumerate(train_loader),
                                         total=len(train_loader)):
                # print(y_true)
                optimizer.zero_grad()
                output = model(list(x))
                loss = loss_function(output, y_true.cuda())
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
        ###SAVE MODEL
        torch.save(model.state_dict(), model_path + model_name)
        print(f'model trained and saved at {model_path + model_name}')
    model.eval()

    accumulate_y_pred = []
    accumulate_y_true = []
    accumulate_y_prob = []
    accumulate_x = []
    for idx, (x, y_true) in tqdm(enumerate(valid_loader),
                                 total=len(valid_loader)):
        output = model(list(x))
        _, y_pred = torch.max(output, 1)
        accumulate_y_pred.extend(y_pred.tolist())
        accumulate_y_true.extend(y_true.tolist())
        accumulate_y_prob.extend(torch.softmax(output, 1)[:, 1].tolist())
        accumulate_x.extend(list(x))
    evaluate(accumulate_y_true, accumulate_y_pred, accumulate_y_prob)




JAK1
model_epoch_20/chembert_JAK1.pt existence:  True


100%|██████████| 55/55 [00:02<00:00, 22.65it/s]


Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.912  &  0.876  &          0.961  &     0.931  &0.821  &0.946 &0.941 &   0.712 &   0.978

JAK2
model_epoch_20/chembert_JAK2.pt existence:  True


100%|██████████| 74/74 [00:02<00:00, 34.78it/s]


Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.868  &  0.859  &          0.906  &     0.889  &0.828  &0.897 &0.929 &   0.712 &   0.956

JAK3
model_epoch_20/chembert_JAK3.pt existence:  True


100%|██████████| 33/33 [00:00<00:00, 34.72it/s]


Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.776  &  0.774  &          0.770  &     0.818  &0.731  &0.793 &0.869 &   0.551 &   0.889

TYK2
model_epoch_20/chembert_TYK2.pt existence:  True


100%|██████████| 17/17 [00:00<00:00, 33.37it/s]

Accuracy, weighted accuracy, precision, recall/SE, SP,     F1,     AUC,     MCC,     AP
& 0.812  &  0.808  &          0.822  &     0.849  &0.767  &0.835 &0.892 &   0.618 &   0.920





In [None]:
model_name

In [None]:
create_path('model/')

In [None]:
file_name = data_path + 'JAK1_valid.csv'
file_exist = os.path.isfile(file_name)

In [None]:
check_file

In [None]:
model.eval()