In [80]:
import torch
import os.path as osp
import GCL.losses as L
import GCL.augmentors as A
import torch.nn.functional as F

from torch import nn
from tqdm import tqdm
from torch.optim import Adam
from GCL.eval import get_split, SVMEvaluator
from GCL.models import DualBranchContrast
from torch_geometric.nn import GINConv, global_add_pool
from torch_geometric.data import DataLoader
from torch_geometric.datasets import TUDataset
import math
import numpy as np
from sklearn.metrics import roc_auc_score, precision_score, recall_score, precision_recall_curve, auc
from sklearn import metrics
from torch import optim
import pandas as pd

from torch.utils.data import DataLoader,random_split
from torch.utils.data import TensorDataset
from torch.optim import Adam

In [81]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Load the dataset

In [82]:
import torch
from torch_geometric.data import InMemoryDataset, download_url
from torch_geometric.data import Data
from torch_geometric.data import DataLoader
import warnings
warnings.filterwarnings("ignore", category=Warning)
from torch_geometric.datasets import TUDataset

from rdkit import Chem
import deepchem as dc
import pandas as pd
class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, file = None, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
        self.file = file
        # print(root) # MYdata
        # print(self.data) # Data(x=[3, 1], edge_index=[2, 4], y=[3])
        # print(self.slices) # defaultdict(<class 'dict'>, {'x': tensor([0, 3, 6]), 'edge_index': tensor([ 0,  4, 10]), 'y': tensor([0, 3, 6])})
        # print(self.processed_paths[0]) # MYdata\processed\datas.pt
    @property
    def raw_file_names(self):
        # pass 
        return []
    @property
    def processed_file_names(self):
        return ['datas.pt']
    def download(self):
        pass
    def process(self):
        df = pd.read_csv(file)
        smiles = df['SMILES'].tolist()
        featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)
        out = featurizer.featurize(smiles)

        data_list = []
        for i in range(len(out)):
            edge_index = torch.tensor(out[i].edge_index)
            x = torch.tensor(out[i].node_features, dtype=torch.float)
            y = torch.tensor(df['LABEL'][i])
            edge_node = torch.tensor(out[i].edge_features)
            data = Data(x=x, edge_index=edge_index, y = y, edge_node = edge_node)
            data_list.append(data)

        # data_list = data_list.append(data)
        if self.pre_filter is not None: 
            data_list = [data for data in data_list if self.pre_filter(data)]
        if self.pre_transform is not None: 
            data_list = [self.pre_transform(data) for data in data_list]
        data, slices = self.collate(data_list) 
        # print(data)
        torch.save((data, slices), self.processed_paths[0])

Load graph contrastive learning code

In [83]:
class GConv(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super(GConv, self).__init__()
        self.layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        for i in range(num_layers):
            if i == 0:
                self.layers.append(make_gin_conv(input_dim, hidden_dim))
            else:
                self.layers.append(make_gin_conv(hidden_dim, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        project_dim = hidden_dim * num_layers
        self.project = torch.nn.Sequential(
            nn.Linear(project_dim, project_dim),
            nn.ReLU(inplace=True),
            nn.Linear(project_dim, project_dim))

    def forward(self, x, edge_index, batch):
        z = x
        zs = []
        for conv, bn in zip(self.layers, self.batch_norms):
            z = conv(z, edge_index)
            z = F.relu(z)
            z = bn(z)
            zs.append(z)
        gs = [global_add_pool(z, batch) for z in zs]
        z, g = [torch.cat(x, dim=1) for x in [zs, gs]]
        return z, g

class Encoder(torch.nn.Module):
    def __init__(self, encoder, augmentor):
        super(Encoder, self).__init__()
        self.encoder = encoder
        self.augmentor = augmentor

    def forward(self, x, edge_index, batch):
        aug1, aug2 = self.augmentor
        x1, edge_index1, edge_weight1 = aug1(x, edge_index)
        x2, edge_index2, edge_weight2 = aug2(x, edge_index)
        z, g = self.encoder(x, edge_index, batch)
        z1, g1 = self.encoder(x1, edge_index1, batch)
        z2, g2 = self.encoder(x2, edge_index2, batch)
        return z, g, z1, z2, g1, g2

def gcl_train(encoder_model, contrast_model, dataloader, optimizer):
    encoder_model.train()
    epoch_loss = 0
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()

        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)

        _, _, _, _, g1, g2 = encoder_model(data.x, data.edge_index.to(torch.long), data.batch)
        g1, g2 = [encoder_model.encoder.project(g) for g in [g1, g2]]
        loss = contrast_model(g1=g1, g2=g2, batch=data.batch)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
    return epoch_loss

def gcl_test(encoder_model, dataloader):
    encoder_model.eval()
    x = []
    y = []
    for data in dataloader:
        data = data.to(device)
        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)
        _, g, _, _, _, _ = encoder_model(data.x, data.edge_index.to(torch.long), data.batch)
        x.append(g)
        y.append(data.y)
    x = torch.cat(x, dim=0)
    y = torch.cat(y, dim=0)
    print(x.shape)
    print(y.shape)

    #split = get_split(num_samples=x.size()[0], train_ratio=0.8, test_ratio=0.1)
    #result = SVMEvaluator()(x, y, split)
    #return result
    return x,y

Load the BERT model

In [84]:
from transformers import BertTokenizer, AutoTokenizer,BertModel,RobertaTokenizer,RobertaModel
tokenizer_ = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MLM")
model_ = BertModel.from_pretrained("DeepChem/ChemBERTa-77M-MLM")

def smiles_to_vector(seq):
    inputs = tokenizer_(seq, return_tensors="pt")
    with torch.no_grad():
        outputs = model_(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze()

You are using a model of type roberta to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Some weights of BertModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MLM and are newly initialized: ['embeddings.LayerNorm.bias', 'embeddings.LayerNorm.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.word_embeddings.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.inte

co-attention model

In [85]:
class Att(nn.Module):
    def __init__(self, hid_dim, dropout):
        super(Att, self).__init__()

        self.linear_v = nn.Linear(hid_dim, hid_dim)
        self.linear_k = nn.Linear(hid_dim, hid_dim)
        self.linear_q = nn.Linear(hid_dim, hid_dim)
        self.linear_merge = nn.Linear(hid_dim, hid_dim)
        self.hid_dim = hid_dim
        self.dropout = dropout

        self.dropout = nn.Dropout(dropout)

    def forward(self, v, k, q, mask):
        atted = self.att(v, k, q, mask).transpose(-1,-2)
        atted = self.linear_merge(atted)

        return atted

    def att(self, value, key, query, mask):
        d_k = query.size(-1)

        scores = torch.matmul(
            query, key.transpose(-2, -1)
        ) / math.sqrt(d_k)

        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)

        att_map = F.softmax(scores, dim=-1)
        att_map = self.dropout(att_map)

        return torch.matmul(att_map, value)

class MHAtt(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout):   #128,4,0.1
        super(MHAtt, self).__init__()

        self.linear_v = nn.Linear(hid_dim, hid_dim)
        self.linear_k = nn.Linear(hid_dim, hid_dim)
        self.linear_q = nn.Linear(hid_dim, hid_dim)
        self.linear_merge = nn.Linear(hid_dim, hid_dim)
        self.hid_dim = hid_dim
        self.dropout = dropout
        self.nhead = n_heads

        self.dropout = nn.Dropout(dropout)
        self.hidden_size_head = int(self.hid_dim / self.nhead)
    def forward(self, v, k, q, mask):
        n_batches = q.size(0)

        v = self.linear_v(v).view(
            n_batches,
            -1,
            self.nhead,
            self.hidden_size_head
        ).transpose(1, 2)

        k = self.linear_k(k).view(
            n_batches,
            -1,
            self.nhead,
            self.hidden_size_head
        ).transpose(1, 2)

        q = self.linear_q(q).view(
            n_batches,
            -1,
            self.nhead,
            self.hidden_size_head
        ).transpose(1, 2)

        atted = self.att(v, k, q, mask)
        atted = atted.transpose(1, 2).contiguous().view(
            n_batches,
            -1,
            self.hid_dim
        )

        atted = self.linear_merge(atted)

        return atted

    def att(self, value, key, query, mask):
        d_k = query.size(-1)

        scores = torch.matmul(
            query, key.transpose(-2, -1)
        ) / math.sqrt(d_k)

        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)

        att_map = F.softmax(scores, dim=-1)
        att_map = self.dropout(att_map)

        return torch.matmul(att_map, value)

class SA(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout):
        super(SA, self).__init__()

        self.mhatt1 = MHAtt(hid_dim, n_heads, dropout)
        # self.mhatt1 = MultiAttn(hid_dim, n_heads)

        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(hid_dim)

    def forward(self, x, mask=None):

        x = self.norm1(x + self.dropout1(
            self.mhatt1(x, x, x, mask)
        ))
        # x = self.norm1(x + self.dropout1(
        #     self.mhatt1(x, x, mask, mask)
        # ))

        return x

class GSA(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout):
        super(GSA, self).__init__()
        self.mhatt1 = MHAtt(hid_dim, n_heads, dropout)
        # self.mhatt1 = MultiAttn(hid_dim, n_heads)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(hid_dim)


    def forward(self, x, y, y_mask=None):

        # x as V while y as Q and K
        # x = self.norm1(x + self.dropout1(
        #     self.mhatt1(x, x, y, y_mask)
        # ))
        x = self.norm1(x+self.dropout1(
            self.mhatt1(y, y, x, y_mask)
        ))
        # x = self.norm1(x + self.dropout1(
        #     self.mhatt1(x, y, y_mask, y_mask)
        # ))

        return x

class inter_cross_att(nn.Module):
    def __init__(self, dim, nhead, dropout):
        super(inter_cross_att, self).__init__()
        self.gsa = SA(dim, nhead, dropout)
        self.sga = SA(dim, nhead, dropout)
        self.coa_gs = GSA(dim, nhead, dropout)
        self.coa_sg = GSA(dim, nhead, dropout)

    def forward(self, graph_vector, sequence_vector):
        graph_vector = self.gsa(graph_vector, None)  # self-attention
        sequence_vector = self.sga(sequence_vector, None)  # self-attention
        graph_covector = self.coa_gs(graph_vector, sequence_vector, None)  # co-attention
        sequence_covector = self.coa_sg(sequence_vector, graph_vector, None)  # co-attention

        return graph_covector, sequence_covector

class GSC(nn.Module):
    def __init__(self, dim, nhead=2, dropout=0.1, layer_output =3, layer_coa=1):
        super(GSC, self).__init__()
        self.layer_output = layer_output
        self.layer_coa = layer_coa
        self.lin1 = nn.Linear(64, 128)
        self.lin2 = nn.Linear(384, 128)
        
        self.sca_1 = SA(dim, nhead, dropout)
        self.sca_2 = SA(dim, nhead, dropout)
        self.sca_3 = SA(dim, nhead, dropout)

        # self-protein-attention layers
        self.spa_1 = SA(dim, nhead, dropout)
        self.spa_2 = SA(dim, nhead, dropout)
        self.spa_3 = SA(dim, nhead, dropout)


        self.coa_gs_1 = GSA(dim, nhead, dropout)
        self.coa_gs_2 = GSA(dim, nhead, dropout)
        self.coa_sg_1 = GSA(dim, nhead, dropout)
        self.coa_sg_2 = GSA(dim, nhead, dropout)
        self.coa_gs_3 = GSA(dim, nhead, dropout)
        self.coa_sg_3 = GSA(dim, nhead, dropout)

        self.W_out = nn.ModuleList([nn.Linear(2 * dim, dim),nn.Linear(dim, 128),nn.Linear(128, 64)])
        self.inter_coa_layers = nn.ModuleList([inter_cross_att(dim, nhead, dropout) for _ in range(layer_coa)])
        self.W_interaction = nn.Linear(64, 2)

    def forward(self, X1, X2, y):

        sequence_vector = self.lin1(X1)
        graph_vector = self.lin2(X2)
        for i in range(self.layer_coa):
            sequence_vector, graph_vector = self.inter_coa_layers[i](sequence_vector, graph_vector)
        graph_vector = graph_vector.mean(dim=1)
        sequence_vector = sequence_vector.mean(dim=1)
    
        """Concatenate the above two vectors and output the interaction."""
        cat_vector = torch.cat((sequence_vector, graph_vector), 1)
        for j in range(self.layer_output):
            cat_vector = torch.relu(self.W_out[j](cat_vector))
        interaction = self.W_interaction(cat_vector)
        return interaction

    def __call__(self, X1, X2, y, train=True):

        correct_interaction = y
        predicted_interaction = self.forward(X1, X2, y)
        if train:
            criterion = torch.nn.CrossEntropyLoss()
            loss = criterion(predicted_interaction, correct_interaction)
            return loss, predicted_interaction
        else:
            correct_labels = correct_interaction.to(device).data.numpy()
            ys = F.softmax(predicted_interaction, 1).to(device).data.numpy()
            predicted_labels = list(map(lambda x: np.argmax(x), ys))
            predicted_scores = list(map(lambda x: x[1], ys))
            return correct_labels, predicted_labels, predicted_scores

In [87]:
def train(model, dataloader, optimizer):
    print('Training on {} samples...'.format(len(dataloader.dataset)))
    criterion = torch.nn.CrossEntropyLoss()
    for batch_idx, (x1,x2,y) in enumerate(dataloader):
        #self.model(data.to(device), proteins.to(device), train=False)
        optimizer.zero_grad()
        sigma = model.forward(x1, x2, y)
        loss = criterion(sigma, y)
        loss.backward()
        optimizer.step()
        if batch_idx % 30 == 0:
            print('Train epoch: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx * len(x1),
                                                                        len(dataloader.dataset),
                                                                        100. * batch_idx / len(dataloader),
                                                                        loss.item()))

In [98]:
def test(model, testloader):
    model.eval()
    T, Y, S = [], [], []
    print('Make prediction for {} samples...'.format(len(testloader.dataset)))
        
    with torch.no_grad():  # accelerating calculations
        for x1,x2,y in testloader:
            (correct_labels, predicted_labels, predicted_scores) = model(x1, x2, y, train=False)
            T.extend(correct_labels)
            Y.extend(predicted_labels)
            S.extend(predicted_scores)
    tpr, fpr, _ = precision_recall_curve(T, S)
    PRC = auc(fpr, tpr)
    train_accu = metrics.accuracy_score(T, Y)
    AUC = roc_auc_score(T, S)
    precision = precision_score(T, Y)
    recall = recall_score(T, Y)
    return AUC, precision, recall, train_accu

Take the BACE dataset as an example.

In [90]:
gcl_model = torch.load("gclModel")   #Load the pre-trained graph contrastive learning model

device = torch.device('cpu')
path = "BACE"
file = "bace.csv"
gcl_dataset = MyOwnDataset(path, file)
gcl_dataloader = DataLoader(gcl_dataset, batch_size=32)
X,Y = gcl_test(gcl_model, gcl_dataloader)   # Obtain the embeddings generated by the pre-trained contrastive learning model

torch.Size([1513, 64])
torch.Size([1513])


In [91]:
df = pd.read_csv(file)
bertdata = []
for i in range(len(df)):
    smiles_vector = smiles_to_vector(df.iloc[i,0]).tolist()
    bertdata.append(smiles_vector)

In [95]:
GCL_X = torch.tensor(X.detach().numpy()).to(torch.float32)
GCL_y = torch.tensor(Y.detach().numpy()).to(torch.long)
BERT_X = torch.tensor(bertdata).to(torch.float32)

data = TensorDataset(GCL_X, BERT_X, GCL_y)

In [96]:
# Split the dataset into training, validation, and test sets, with 80% for training, 10% for validation, and 10% for testing.
train_size = int(0.8 * len(data))
temp_size = len(data) - train_size
train_dataset, temp_dataset = random_split(data, [train_size, temp_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

valid_size = int(0.5 * len(temp_dataset))
test_size = len(temp_dataset) - valid_size
valid_dataset, test_dataset = random_split(temp_dataset, [valid_size, test_size])

valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
model = GSC(dim = 128, nhead = 8, dropout=0.1, layer_output = 3, layer_coa = 1)
lr = 0.0005
opt = torch.optim.Adam(model.parameters(), lr=lr)
criterion1 = torch.nn.CrossEntropyLoss()
validAUC = 0.0
bestAUC = 0.0
bestprecision = 0.0
bestrecall = 0.0
bestacc = 0.0
for epoch in range(200):
    train(model,train_loader, opt)
    AUC, _, _, _ = test(model,valid_loader)
    if AUC > validAUC:
        validAUC = AUC
        testAUC, testPRE, testRECALL, testACC = test(model,test_loader)
        print("AUC : ", testAUC, ", precision : ", testPRE, ", recall : ",  testRECALL, ", acc : ", testACC)
        if testAUC > bestAUC:
            bestAUC = testAUC
            bestprecision = testPRE
            bestrecall = testRECALL
            bestacc = testACC
    print(epoch, "best AUC : ", bestAUC, "best precision : ", bestprecision, "best recall : ", bestrecall, "best acc : ", bestacc)

Training on 1210 samples...
Make prediction for 151 samples...
Make prediction for 152 samples...
AUC :  0.6430769230769231 , precision :  0.0 , recall :  0.0 , acc :  0.6578947368421053
0 best AUC :  0.6430769230769231 best precision :  0.0 best recall :  0.0 best acc :  0.6578947368421053
Training on 1210 samples...
Make prediction for 151 samples...
Make prediction for 152 samples...
AUC :  0.7669230769230769 , precision :  0.4065040650406504 , recall :  0.9615384615384616 , acc :  0.506578947368421
1 best AUC :  0.7669230769230769 best precision :  0.4065040650406504 best recall :  0.9615384615384616 best acc :  0.506578947368421
Training on 1210 samples...
Make prediction for 151 samples...
Make prediction for 152 samples...
AUC :  0.7592307692307693 , precision :  0.5121951219512195 , recall :  0.8076923076923077 , acc :  0.6710526315789473
2 best AUC :  0.7669230769230769 best precision :  0.4065040650406504 best recall :  0.9615384615384616 best acc :  0.506578947368421
Trainin