In [1]:
from math import sqrt
import pandas as pd
import numpy as np
import os
import torch
import torch.nn.functional as F
from rdkit import Chem
import matplotlib.pyplot as plt

from torch_geometric.loader import DataLoader
from torch_geometric.datasets import MoleculeNet
from torch_geometric.nn import global_mean_pool, global_max_pool
%matplotlib inline
#A100 80GB

In [2]:
gpuid = 0
torch.cuda.set_device(gpuid)
print(torch.cuda.current_device())

0


In [3]:
import seaborn as sns


In [4]:
import sys
sys.path.insert(0, '/home/shenwanxiang/Research/bidd-clsar/')

In [5]:
from clsar.dataset import LSSNS, HSSMS
from clsar.feature import Gen39AtomFeatures
from clsar.model.model import ACANet_PNA, get_deg, _fix_reproducibility # model
from clsar.model.loss import ACALoss, get_best_cliff, get_best_structure_batch
_fix_reproducibility(42)

In [6]:
def train(train_loader, model, optimizer, aca_loss):

    total_examples = 0
    total_loss =  0    
    total_tsm_loss = 0
    total_reg_loss = 0  
    
    n_label_triplets = []
    n_structure_triplets = []
    n_triplets = []
    n_hv_triplets = []


    model.train()
    for i, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        predictions, embeddings = model(data.x.float(), data.edge_index, 
                                        data.edge_attr, data.batch)
        
        loss_out = aca_loss(labels = data.y, 
                            predictions = predictions,
                            embeddings = embeddings,
                            fps_smiles = data.fp_smiles,
                            fps_scaffold = data.fp_scaffold,                           
                            smiles_list = data.smiles,                           
                           )
        
        
        loss, reg_loss, tsm_loss,  N_Y_ACTs, N_S_ACTs, N_ACTs, N_HV_ACTs = loss_out
        
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs
        total_tsm_loss += float(tsm_loss) * data.num_graphs        
        total_reg_loss += float(reg_loss) * data.num_graphs        
        total_examples += data.num_graphs

        n_label_triplets.append(int(N_Y_ACTs))
        n_structure_triplets.append(int(N_S_ACTs))
        n_triplets.append(int(N_ACTs))
        n_hv_triplets.append(int(N_HV_ACTs))
    
    train_loss = total_loss / total_examples
    total_tsm_loss = total_tsm_loss / total_examples
    total_reg_loss = total_reg_loss / total_examples

    n_label_triplets = int(sum(n_label_triplets) / (i+1))
    n_structure_triplets = int(sum(n_structure_triplets) / (i+1))
    n_triplets = int(sum(n_triplets) / (i+1))
    n_hv_triplets = int(sum(n_hv_triplets) / (i+1))
    
    return train_loss, total_tsm_loss, total_reg_loss, n_label_triplets, n_structure_triplets, n_triplets, n_hv_triplets

@torch.no_grad()
def test(test_loader, model, aca_loss):
    model.eval()
    total_examples = 0
    total_loss = 0
    total_tsm_loss = 0
    total_reg_loss = 0

    n_label_triplets = []
    n_structure_triplets = []
    n_triplets = []
    n_hv_triplets = []
    
    mse = []
    for i, data in enumerate(test_loader):
        data = data.to(device)
        predictions, embeddings = model(data.x.float(), data.edge_index,
                                        data.edge_attr, data.batch)
        loss_out = aca_loss(labels = data.y, 
                            predictions = predictions,
                            embeddings = embeddings,
                            fps_smiles = data.fp_smiles,
                            fps_scaffold = data.fp_scaffold,                           
                            smiles_list = data.smiles,                           
                           )
        
        
        loss, reg_loss, tsm_loss,  N_Y_ACTs, N_S_ACTs, N_ACTs, N_HV_ACTs = loss_out

        total_loss += float(loss) * data.num_graphs
        total_tsm_loss += float(tsm_loss) * data.num_graphs
        total_reg_loss += float(reg_loss) * data.num_graphs
        total_examples += data.num_graphs

        n_label_triplets.append(int(N_Y_ACTs))
        n_structure_triplets.append(int(N_S_ACTs))
        n_triplets.append(int(N_ACTs))
        n_hv_triplets.append(int(N_HV_ACTs))

        mse.append(F.mse_loss(predictions, data.y, reduction='none').cpu())

    test_loss = total_loss / total_examples
    total_tsm_loss = total_tsm_loss / total_examples
    total_reg_loss = total_reg_loss / total_examples

    n_label_triplets = int(sum(n_label_triplets) / (i+1))
    n_structure_triplets = int(sum(n_structure_triplets) / (i+1))
    n_triplets = int(sum(n_triplets) / (i+1))
    n_hv_triplets = int(sum(n_hv_triplets) / (i+1))
    
    test_rmse = float(torch.cat(mse, dim=0).mean().sqrt())
    
    return test_loss, total_tsm_loss, total_reg_loss, n_label_triplets, n_structure_triplets, n_triplets, n_hv_triplets, test_rmse



def Test_performance(alpha=1.0, similarity_gate = True, gate_type = 'AND', similarity_neg = 0., similarity_pos = 1):
    _fix_reproducibility(42)
    model = ACANet_PNA(**pub_args, deg=deg).to(device)  
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=10**-5)
    aca_loss = ACALoss(alpha=alpha, 
                        cliff_lower = 1., 
                        cliff_upper = 1.,
                        squared = False,
                        similarity_gate = similarity_gate,
                        similarity_neg = similarity_neg, #0.
                        similarity_pos = similarity_pos, #1
                        gate_type = gate_type,
                        dev_mode = True,)
    
    history = []
    #ls_his = []
    for epoch in range(1, epochs):
        train_loss, tsm_loss, reg_loss, n_label_triplets, n_structure_triplets, n_triplets, n_hv_triplets = train(train_loader, 
                                                                                                                  model, 
                                                                                                                  optimizer,
                                                                                                                  aca_loss)

        _, _, _, _, _, _, train_n_hv_triplets, train_rmse = test(train_loader, model, aca_loss)
        _, _, _, _, _, _, val_n_hv_triplets, val_rmse = test(val_loader, model, aca_loss)
        _, _, _, _, _, _, test_n_hv_triplets, test_rmse = test(test_loader, model, aca_loss)

        
        print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f} tsm_loss: {tsm_loss:.4f} reg_loss: {reg_loss:.4f} '
              f'N_Y: {n_label_triplets:03d} N_S: {n_structure_triplets:03d} N: {n_triplets:03d} N_HV: {n_hv_triplets:03d} '
              f'Val: {val_rmse:.4f} Test: {test_rmse:.4f}')
    
        history.append({'Epoch':epoch, 'train_loss':train_loss, 'train_triplet_loss':tsm_loss,
                        'train_reg_loss':reg_loss, 'val_rmse':val_rmse, 
                        'test_rmse':test_rmse, 'train_rmse':train_rmse,
                        
                        'n_label_triplets': n_label_triplets, 
                        'n_structure_triplets':n_structure_triplets,
                        'n_triplets':n_triplets,
                        'n_hv_triplets':n_hv_triplets,
                        

                        'train_n_hv_triplets':train_n_hv_triplets,
                        'val_n_hv_triplets':val_n_hv_triplets,
                        'test_n_hv_triplets':test_n_hv_triplets,
                        'alpha':alpha, 'similarity_gate':similarity_gate,
                        'gate_type':gate_type, 'similarity_neg':similarity_neg,
                        'similarity_pos':similarity_pos
                       
                       })
        #ls_his.append({'Epoch':epoch, 'mae_loss':float(mae_loss), 'triplet_loss':float(triplet_loss)})
    dfh = pd.DataFrame(history)
    return dfh

In [7]:
dataset_name = 'BRAF'
Dataset =  LSSNS #LSSNS 
epochs = 300
batch_size = 128
lr = 1e-4

pre_transform = Gen39AtomFeatures()
in_channels = pre_transform.in_channels
path = './data/'

## model HPs
pub_args = {'in_channels':pre_transform.in_channels, 
            'edge_dim':pre_transform.edge_dim,
            'convs_layers': [64, 128, 256, 512],   
            'dense_layers': [256, 128, 32], 
            'out_channels':1, 
            'aggregators': ['mean', 'min', 'max', 'sum','std'],
            'scalers':['identity', 'amplification', 'attenuation'] ,
            'dropout_p': 0}

In [8]:
len(Dataset(path, name=dataset_name, pre_transform=pre_transform).shuffle(42))

Downloading https://bidd-group.github.io/MPCD/dataset/LSSNS/BRAF.csv
Processing...
Done!


128

In [None]:
# train, valid, test splitting
from torch_geometric.loader import DataLoader
from sklearn.model_selection import KFold


res = []
for seed in [8, 16, 24, 42, 64,]: #,   128, 256, 512, 1024, 2048
    dataset = Dataset(path, name=dataset_name, pre_transform=pre_transform).shuffle(42)

    for similarity_neg in [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]:
         for similarity_pos in [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]:
            kf = KFold(n_splits=5, shuffle=True, random_state=seed)
            for fold, (train_idx, test_idx) in enumerate(kf.split(dataset)):
                train_ds = dataset[train_idx.tolist()]
                test_ds  = dataset[test_idx.tolist()]
                
                train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
                test_loader  = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
                val_loader = test_loader
            
                deg = get_deg(train_ds)
                device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
                df1 = Test_performance(alpha=1.0, similarity_gate = True, gate_type = 'AND',
                                       similarity_neg =similarity_neg, similarity_pos=similarity_pos)
                df1['seed'] = seed
                df1['fold'] = fold
                res.append(df1)



Epoch: 001, Loss: 6.7996 tsm_loss: 0.0000 reg_loss: 6.7996 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7409 Test: 6.7409
Epoch: 002, Loss: 6.5556 tsm_loss: 0.0000 reg_loss: 6.5556 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7400 Test: 6.7400
Epoch: 003, Loss: 6.3399 tsm_loss: 0.0000 reg_loss: 6.3399 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7393 Test: 6.7393
Epoch: 004, Loss: 6.1508 tsm_loss: 0.0000 reg_loss: 6.1508 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7389 Test: 6.7389
Epoch: 005, Loss: 5.9816 tsm_loss: 0.0000 reg_loss: 5.9816 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7388 Test: 6.7388
Epoch: 006, Loss: 5.8223 tsm_loss: 0.0000 reg_loss: 5.8223 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7387 Test: 6.7387
Epoch: 007, Loss: 5.6683 tsm_loss: 0.0000 reg_loss: 5.6683 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7387 Test: 6.7387
Epoch: 008, Loss: 5.5139 tsm_loss: 0.0000 reg_loss: 5.5139 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7387 Test: 6.7387
Epoch: 009, Loss: 5.3571



Epoch: 002, Loss: 6.5495 tsm_loss: 0.0000 reg_loss: 6.5495 N_Y: 123252 N_S: 000 N: 000 N_HV: 000 Val: 6.7441 Test: 6.7441
Epoch: 003, Loss: 6.3316 tsm_loss: 0.0000 reg_loss: 6.3316 N_Y: 123252 N_S: 000 N: 000 N_HV: 000 Val: 6.7435 Test: 6.7435
Epoch: 004, Loss: 6.1463 tsm_loss: 0.0000 reg_loss: 6.1463 N_Y: 123252 N_S: 000 N: 000 N_HV: 000 Val: 6.7433 Test: 6.7433
Epoch: 005, Loss: 5.9804 tsm_loss: 0.0000 reg_loss: 5.9804 N_Y: 123252 N_S: 000 N: 000 N_HV: 000 Val: 6.7433 Test: 6.7433
Epoch: 006, Loss: 5.8250 tsm_loss: 0.0000 reg_loss: 5.8250 N_Y: 123252 N_S: 000 N: 000 N_HV: 000 Val: 6.7433 Test: 6.7433
Epoch: 007, Loss: 5.6725 tsm_loss: 0.0000 reg_loss: 5.6725 N_Y: 123252 N_S: 000 N: 000 N_HV: 000 Val: 6.7433 Test: 6.7433
Epoch: 008, Loss: 5.5169 tsm_loss: 0.0000 reg_loss: 5.5169 N_Y: 123252 N_S: 000 N: 000 N_HV: 000 Val: 6.7432 Test: 6.7432
Epoch: 009, Loss: 5.3570 tsm_loss: 0.0000 reg_loss: 5.3570 N_Y: 123252 N_S: 000 N: 000 N_HV: 000 Val: 6.7433 Test: 6.7433
Epoch: 010, Loss: 5.1944



Epoch: 002, Loss: 6.5470 tsm_loss: 0.0000 reg_loss: 6.5470 N_Y: 109356 N_S: 000 N: 000 N_HV: 000 Val: 6.7831 Test: 6.7831
Epoch: 003, Loss: 6.3339 tsm_loss: 0.0000 reg_loss: 6.3339 N_Y: 109356 N_S: 000 N: 000 N_HV: 000 Val: 6.7822 Test: 6.7822
Epoch: 004, Loss: 6.1516 tsm_loss: 0.0000 reg_loss: 6.1516 N_Y: 109356 N_S: 000 N: 000 N_HV: 000 Val: 6.7817 Test: 6.7817
Epoch: 005, Loss: 5.9879 tsm_loss: 0.0000 reg_loss: 5.9879 N_Y: 109356 N_S: 000 N: 000 N_HV: 000 Val: 6.7814 Test: 6.7814
Epoch: 006, Loss: 5.8330 tsm_loss: 0.0000 reg_loss: 5.8330 N_Y: 109356 N_S: 000 N: 000 N_HV: 000 Val: 6.7812 Test: 6.7812
Epoch: 007, Loss: 5.6845 tsm_loss: 0.0000 reg_loss: 5.6845 N_Y: 109356 N_S: 000 N: 000 N_HV: 000 Val: 6.7812 Test: 6.7812
Epoch: 008, Loss: 5.5361 tsm_loss: 0.0000 reg_loss: 5.5361 N_Y: 109356 N_S: 000 N: 000 N_HV: 000 Val: 6.7814 Test: 6.7814
Epoch: 009, Loss: 5.3849 tsm_loss: 0.0000 reg_loss: 5.3849 N_Y: 109356 N_S: 000 N: 000 N_HV: 000 Val: 6.7816 Test: 6.7816
Epoch: 010, Loss: 5.2317



Epoch: 002, Loss: 6.5364 tsm_loss: 0.0000 reg_loss: 6.5364 N_Y: 125162 N_S: 000 N: 000 N_HV: 000 Val: 6.8259 Test: 6.8259
Epoch: 003, Loss: 6.3145 tsm_loss: 0.0000 reg_loss: 6.3145 N_Y: 125162 N_S: 000 N: 000 N_HV: 000 Val: 6.8254 Test: 6.8254
Epoch: 004, Loss: 6.1314 tsm_loss: 0.0000 reg_loss: 6.1314 N_Y: 125162 N_S: 000 N: 000 N_HV: 000 Val: 6.8252 Test: 6.8252
Epoch: 005, Loss: 5.9581 tsm_loss: 0.0000 reg_loss: 5.9581 N_Y: 125162 N_S: 000 N: 000 N_HV: 000 Val: 6.8252 Test: 6.8252
Epoch: 006, Loss: 5.7984 tsm_loss: 0.0000 reg_loss: 5.7984 N_Y: 125162 N_S: 000 N: 000 N_HV: 000 Val: 6.8252 Test: 6.8252
Epoch: 007, Loss: 5.6437 tsm_loss: 0.0000 reg_loss: 5.6437 N_Y: 125162 N_S: 000 N: 000 N_HV: 000 Val: 6.8249 Test: 6.8249
Epoch: 008, Loss: 5.4913 tsm_loss: 0.0000 reg_loss: 5.4913 N_Y: 125162 N_S: 000 N: 000 N_HV: 000 Val: 6.8248 Test: 6.8248
Epoch: 009, Loss: 5.3407 tsm_loss: 0.0000 reg_loss: 5.3407 N_Y: 125162 N_S: 000 N: 000 N_HV: 000 Val: 6.8251 Test: 6.8251
Epoch: 010, Loss: 5.1855



Epoch: 001, Loss: 6.7588 tsm_loss: 0.0000 reg_loss: 6.7588 N_Y: 125430 N_S: 000 N: 000 N_HV: 000 Val: 6.9041 Test: 6.9041
Epoch: 002, Loss: 6.5140 tsm_loss: 0.0000 reg_loss: 6.5140 N_Y: 125430 N_S: 000 N: 000 N_HV: 000 Val: 6.9031 Test: 6.9031
Epoch: 003, Loss: 6.3043 tsm_loss: 0.0000 reg_loss: 6.3043 N_Y: 125430 N_S: 000 N: 000 N_HV: 000 Val: 6.9025 Test: 6.9025
Epoch: 004, Loss: 6.1232 tsm_loss: 0.0000 reg_loss: 6.1232 N_Y: 125430 N_S: 000 N: 000 N_HV: 000 Val: 6.9024 Test: 6.9024
Epoch: 005, Loss: 5.9599 tsm_loss: 0.0000 reg_loss: 5.9599 N_Y: 125430 N_S: 000 N: 000 N_HV: 000 Val: 6.9023 Test: 6.9023
Epoch: 006, Loss: 5.8028 tsm_loss: 0.0000 reg_loss: 5.8028 N_Y: 125430 N_S: 000 N: 000 N_HV: 000 Val: 6.9022 Test: 6.9022
Epoch: 007, Loss: 5.6459 tsm_loss: 0.0000 reg_loss: 5.6459 N_Y: 125430 N_S: 000 N: 000 N_HV: 000 Val: 6.9021 Test: 6.9021
Epoch: 008, Loss: 5.4908 tsm_loss: 0.0000 reg_loss: 5.4908 N_Y: 125430 N_S: 000 N: 000 N_HV: 000 Val: 6.9020 Test: 6.9020
Epoch: 009, Loss: 5.3301



Epoch: 002, Loss: 6.5556 tsm_loss: 0.0000 reg_loss: 6.5556 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7400 Test: 6.7400
Epoch: 003, Loss: 6.3399 tsm_loss: 0.0000 reg_loss: 6.3399 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7393 Test: 6.7393
Epoch: 004, Loss: 6.1508 tsm_loss: 0.0000 reg_loss: 6.1508 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7389 Test: 6.7389
Epoch: 005, Loss: 5.9820 tsm_loss: 0.0000 reg_loss: 5.9820 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7388 Test: 6.7388
Epoch: 006, Loss: 5.8207 tsm_loss: 0.0000 reg_loss: 5.8207 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7387 Test: 6.7387
Epoch: 007, Loss: 5.6692 tsm_loss: 0.0000 reg_loss: 5.6692 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7388 Test: 6.7388
Epoch: 008, Loss: 5.5140 tsm_loss: 0.0000 reg_loss: 5.5140 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7389 Test: 6.7389
Epoch: 009, Loss: 5.3588 tsm_loss: 0.0000 reg_loss: 5.3588 N_Y: 118792 N_S: 000 N: 000 N_HV: 000 Val: 6.7388 Test: 6.7388
Epoch: 010, Loss: 5.2018

In [None]:
df = pd.concat(res)
df.groupby(['similarity_neg', 'similarity_pos', 'seed', 'fold']).test_rmse.min()

In [None]:
df.to_csv('./results/similarity_grid.csv')