In [1]:
from math import sqrt
import pandas as pd
import numpy as np
import os
import torch
import torch.nn.functional as F
from rdkit import Chem
import matplotlib.pyplot as plt

from torch_geometric.loader import DataLoader
from torch_geometric.datasets import MoleculeNet
from torch_geometric.nn import global_mean_pool, global_max_pool
%matplotlib inline
#A100 80GB

In [2]:
# Pin the process to a single GPU when CUDA is present. The training cell
# below falls back to CPU when CUDA is unavailable, so this cell must not
# crash in that situation either.
gpuid = 0
if torch.cuda.is_available():
    torch.cuda.set_device(gpuid)
    print(torch.cuda.current_device())
else:
    print('CUDA not available; running on CPU')

0


In [3]:
import sys
# Location of the bidd-clsar checkout. Set the CLSAR_HOME environment variable
# to point somewhere else; when it is unset the original path is used.
sys.path.insert(0, os.environ.get('CLSAR_HOME', '/home/shenwanxiang/Research/bidd-clsar/'))

In [4]:
# Project-local ACANet building blocks, imported from the `clsar` package.
from clsar.dataset import LSSNS, HSSMS
from clsar.feature import Gen39AtomFeatures
from clsar.model.model import ACANet_PNA, get_deg, _fix_reproducibility # model
from clsar.model.loss import ACALoss, get_best_cliff, get_best_structure_batch
# NOTE(review): presumably seeds the random number generators so the notebook
# is repeatable -- confirm in clsar.model.model. It is called again with the
# same value at the start of every Test_performance run.
_fix_reproducibility(42)

In [5]:
def train(train_loader, model, optimizer, aca_loss):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to the module-level ``device``, pushed through
    ``model`` and scored with ``aca_loss``; the total loss is back-propagated
    and ``optimizer`` takes one step per batch.

    Returns a 7-tuple
    ``(train_loss, tsm_loss, reg_loss, n_label_triplets,
    n_structure_triplets, n_triplets, n_hv_triplets)``.
    The three losses are averaged over graphs (each batch weighted by
    ``num_graphs``); the four triplet counts are averaged over batches and
    truncated to ``int``.
    """

    def _batch_mean(values):
        # Mean over batches, truncated towards zero.
        return int(sum(values) / len(values))

    graphs_seen = 0
    loss_sum = 0
    tsm_sum = 0
    reg_sum = 0

    label_counts = []
    structure_counts = []
    triplet_counts = []
    hv_counts = []

    model.train()
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        predictions, embeddings = model(
            batch.x.float(),
            batch.edge_index,
            batch.edge_attr,
            batch.batch,
        )

        (loss, reg_loss, tsm_loss,
         N_Y_ACTs, N_S_ACTs, N_ACTs, N_HV_ACTs) = aca_loss(
            labels=batch.y,
            predictions=predictions,
            embeddings=embeddings,
            fps_smiles=batch.fp_smiles,
            fps_scaffold=batch.fp_scaffold,
            smiles_list=batch.smiles,
        )

        loss.backward()
        optimizer.step()

        # Weight every batch by its graph count so a smaller final batch
        # does not skew the epoch averages.
        n_graphs = batch.num_graphs
        loss_sum += float(loss) * n_graphs
        tsm_sum += float(tsm_loss) * n_graphs
        reg_sum += float(reg_loss) * n_graphs
        graphs_seen += n_graphs

        label_counts.append(int(N_Y_ACTs))
        structure_counts.append(int(N_S_ACTs))
        triplet_counts.append(int(N_ACTs))
        hv_counts.append(int(N_HV_ACTs))

    train_loss = loss_sum / graphs_seen
    mean_tsm_loss = tsm_sum / graphs_seen
    mean_reg_loss = reg_sum / graphs_seen

    return (
        train_loss,
        mean_tsm_loss,
        mean_reg_loss,
        _batch_mean(label_counts),
        _batch_mean(structure_counts),
        _batch_mean(triplet_counts),
        _batch_mean(hv_counts),
    )

@torch.no_grad()
def test(test_loader, model, aca_loss):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Each batch is moved to the module-level ``device`` and scored with
    ``aca_loss``; the element-wise squared error between predictions and
    ``data.y`` is collected to compute an overall RMSE.

    Returns an 8-tuple
    ``(test_loss, tsm_loss, reg_loss, n_label_triplets,
    n_structure_triplets, n_triplets, n_hv_triplets, test_rmse)``.
    Losses are averaged over graphs, triplet counts are averaged over batches
    and truncated to ``int``, and ``test_rmse`` is the square root of the mean
    of all collected squared errors.
    """

    def _batch_mean(values):
        # Mean over batches, truncated towards zero.
        return int(sum(values) / len(values))

    model.eval()

    graphs_seen = 0
    loss_sum = 0
    tsm_sum = 0
    reg_sum = 0

    label_counts = []
    structure_counts = []
    triplet_counts = []
    hv_counts = []

    squared_errors = []
    for batch in test_loader:
        batch = batch.to(device)
        predictions, embeddings = model(
            batch.x.float(),
            batch.edge_index,
            batch.edge_attr,
            batch.batch,
        )

        (loss, reg_loss, tsm_loss,
         N_Y_ACTs, N_S_ACTs, N_ACTs, N_HV_ACTs) = aca_loss(
            labels=batch.y,
            predictions=predictions,
            embeddings=embeddings,
            fps_smiles=batch.fp_smiles,
            fps_scaffold=batch.fp_scaffold,
            smiles_list=batch.smiles,
        )

        n_graphs = batch.num_graphs
        loss_sum += float(loss) * n_graphs
        tsm_sum += float(tsm_loss) * n_graphs
        reg_sum += float(reg_loss) * n_graphs
        graphs_seen += n_graphs

        label_counts.append(int(N_Y_ACTs))
        structure_counts.append(int(N_S_ACTs))
        triplet_counts.append(int(N_ACTs))
        hv_counts.append(int(N_HV_ACTs))

        # NOTE(review): assumes predictions and batch.y have the same shape;
        # otherwise mse_loss would broadcast -- confirm against the model output.
        squared_errors.append(
            F.mse_loss(predictions, batch.y, reduction='none').cpu()
        )

    test_loss = loss_sum / graphs_seen
    mean_tsm_loss = tsm_sum / graphs_seen
    mean_reg_loss = reg_sum / graphs_seen

    mean_label = _batch_mean(label_counts)
    mean_structure = _batch_mean(structure_counts)
    mean_triplets = _batch_mean(triplet_counts)
    mean_hv = _batch_mean(hv_counts)

    all_errors = torch.cat(squared_errors, dim=0)
    test_rmse = float(all_errors.mean().sqrt())

    return (
        test_loss,
        mean_tsm_loss,
        mean_reg_loss,
        mean_label,
        mean_structure,
        mean_triplets,
        mean_hv,
        test_rmse,
    )



def Test_performance(alpha=1.0, similarity_gate = True, gate_type = 'OR'):
    """Train a fresh ACANet_PNA model and log per-epoch metrics.

    The model, optimizer and loss are rebuilt on every call after
    ``_fix_reproducibility(42)``, so each run starts from the same seed call.
    The function reads the notebook-level globals ``pub_args``, ``deg``,
    ``device``, ``lr``, ``epochs``, ``train_loader``, ``val_loader`` and
    ``test_loader``; they must be defined before it is called.

    Parameters
    ----------
    alpha : float
        Forwarded to ``ACALoss(alpha=...)``.
    similarity_gate : bool
        Forwarded to ``ACALoss(similarity_gate=...)``.
    gate_type : str
        Forwarded to ``ACALoss(gate_type=...)``.

    Returns
    -------
    pandas.DataFrame
        One row per epoch with the training losses, the train/val/test RMSE,
        the triplet counts and the three arguments of this call.
    """
    _fix_reproducibility(42)
    model = ACANet_PNA(**pub_args, deg=deg).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=10**-5)
    aca_loss = ACALoss(alpha=alpha,
                        cliff_lower = 1.,
                        cliff_upper = 1.,
                        squared = False,
                        similarity_gate = similarity_gate,
                        similarity_neg = 0.9, #0.
                        similarity_pos = 1, #1
                        gate_type = gate_type,
                        dev_mode = True,)

    history = []
    # `epochs + 1` so that exactly `epochs` epochs are run (1 .. epochs);
    # range(1, epochs) stopped one epoch short of the configured value.
    for epoch in range(1, epochs + 1):
        train_loss, tsm_loss, reg_loss, n_label_triplets, n_structure_triplets, n_triplets, n_hv_triplets = train(train_loader, model, optimizer, aca_loss)

        # Re-score all three splits in eval mode with the updated weights.
        _, _, _, _, _, _, train_n_hv_triplets, train_rmse = test(train_loader, model, aca_loss)
        _, _, _, _, _, _, val_n_hv_triplets, val_rmse = test(val_loader, model, aca_loss)
        _, _, _, _, _, _, test_n_hv_triplets, test_rmse = test(test_loader, model, aca_loss)

        print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f} tsm_loss: {tsm_loss:.4f} reg_loss: {reg_loss:.4f} '
              f'N_Y: {n_label_triplets:03d} N_S: {n_structure_triplets:03d} N: {n_triplets:03d} N_HV: {n_hv_triplets:03d} '
              f'Val: {val_rmse:.4f} Test: {test_rmse:.4f}')

        history.append({'Epoch':epoch, 'train_loss':train_loss, 'train_triplet_loss':tsm_loss,
                        'train_reg_loss':reg_loss, 'val_rmse':val_rmse,
                        'test_rmse':test_rmse, 'train_rmse':train_rmse,

                        # triplet counts observed while training this epoch
                        'n_label_triplets': n_label_triplets,
                        'n_structure_triplets':n_structure_triplets,
                        'n_triplets':n_triplets,
                        'n_hv_triplets':n_hv_triplets,

                        # counts from the evaluation passes over each split
                        'train_n_hv_triplets':train_n_hv_triplets,
                        'val_n_hv_triplets':val_n_hv_triplets,
                        'test_n_hv_triplets':test_n_hv_triplets,
                        'alpha':alpha, 'similarity_gate':similarity_gate,
                        'gate_type':gate_type,
                       })
    dfh = pd.DataFrame(history)
    return dfh

In [6]:
# --- Experiment settings -------------------------------------------------
dataset_name = 'CHEMBL3979_EC50'
Dataset = HSSMS  # alternative dataset class: LSSNS
path = './data/'

# --- Optimisation settings -----------------------------------------------
epochs = 800
batch_size = 128
lr = 1e-4

# --- Featurisation ---------------------------------------------------------
# The featuriser supplies the atom/edge feature widths the model needs.
pre_transform = Gen39AtomFeatures()
in_channels = pre_transform.in_channels

# --- Model hyper-parameters (unpacked into ACANet_PNA) ---------------------
pub_args = dict(
    in_channels=pre_transform.in_channels,
    edge_dim=pre_transform.edge_dim,
    convs_layers=[64, 128, 256, 512],
    dense_layers=[256, 128, 32],
    out_channels=1,
    aggregators=['mean', 'min', 'max', 'sum', 'std'],
    scalers=['identity', 'amplification', 'attenuation'],
    dropout_p=0,
)

In [7]:
# Display the number of entries in the selected dataset.
# NOTE(review): 42 is passed positionally to shuffle(); confirm whether the
# dataset class treats it as a seed -- PyG's Dataset.shuffle only takes
# `return_perm`. The shuffle is presumably irrelevant to the length shown.
len(Dataset(path, name=dataset_name, pre_transform=pre_transform).shuffle(42))

1125

In [8]:
# train, valid, test splitting
res1 = []
res2 = []
res3 = []
res4 = []

for seed in [8, 16, 24, 42, 64, 128, 256, 512, 1024, 2048]: #,  
    dataset = Dataset(path, name=dataset_name, pre_transform=pre_transform).shuffle(42)
    N = len(dataset) // 5
    val_dataset = dataset[:N]
    test_dataset = dataset[N:2 * N]
    train_dataset = dataset[2 * N:]
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    deg = get_deg(train_dataset)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Without AC-Awareness ($\alpha = 0$)
    df1 = Test_performance(alpha=0.0, similarity_gate = False)
    df1['seed'] = seed

    # With AC-Awareness ($\alpha = 1$)
    df2 = Test_performance(alpha=1.0, similarity_gate = False)
    df2['seed'] = seed
    
    # With AC-Awareness and structure gate
    df3 = Test_performance(alpha=1.0, similarity_gate = True, gate_type = 'OR')
    df3['seed'] = seed

    # With AC-Awareness and structure gate
    df4 = Test_performance(alpha=1.0, similarity_gate = True, gate_type = 'AND')
    df4['seed'] = seed    


    res1.append(df1)
    res2.append(df2)
    res3.append(df3)
    res4.append(df4)



Epoch: 001, Loss: 6.5474 tsm_loss: 11.0853 reg_loss: 6.5474 N_Y: 407843 N_S: 1754772 N: 407843 N_HV: 217279 Val: 6.9647 Test: 6.9852
Epoch: 002, Loss: 5.5485 tsm_loss: 15.4066 reg_loss: 5.5485 N_Y: 403186 N_S: 1754772 N: 403186 N_HV: 214984 Val: 6.9662 Test: 6.9867
Epoch: 003, Loss: 4.3193 tsm_loss: 16.2928 reg_loss: 4.3193 N_Y: 405580 N_S: 1754772 N: 405580 N_HV: 213671 Val: 6.9648 Test: 6.9853
Epoch: 004, Loss: 2.5980 tsm_loss: 15.5147 reg_loss: 2.5980 N_Y: 406000 N_S: 1754772 N: 406000 N_HV: 213953 Val: 6.9586 Test: 6.9791
Epoch: 005, Loss: 0.9682 tsm_loss: 13.5867 reg_loss: 0.9682 N_Y: 406169 N_S: 1754772 N: 406169 N_HV: 201959 Val: 6.9493 Test: 6.9699
Epoch: 006, Loss: 0.8057 tsm_loss: 11.5977 reg_loss: 0.8057 N_Y: 404414 N_S: 1754772 N: 404414 N_HV: 193080 Val: 6.1465 Test: 6.1602
Epoch: 007, Loss: 0.6967 tsm_loss: 9.1759 reg_loss: 0.6967 N_Y: 405844 N_S: 1754772 N: 405844 N_HV: 190650 Val: 4.6541 Test: 4.6503
Epoch: 008, Loss: 0.6485 tsm_loss: 6.2630 reg_loss: 0.6485 N_Y: 406738



Epoch: 001, Loss: 10.7865 tsm_loss: 4.0381 reg_loss: 6.7484 N_Y: 407843 N_S: 1754772 N: 407843 N_HV: 214331 Val: 6.9623 Test: 6.9829
Epoch: 002, Loss: 8.0703 tsm_loss: 1.8523 reg_loss: 6.2180 N_Y: 403186 N_S: 1754772 N: 403186 N_HV: 201470 Val: 6.9560 Test: 6.9767
Epoch: 003, Loss: 7.0487 tsm_loss: 1.5025 reg_loss: 5.5462 N_Y: 405580 N_S: 1754772 N: 405580 N_HV: 208808 Val: 6.9394 Test: 6.9604
Epoch: 004, Loss: 5.9260 tsm_loss: 1.3178 reg_loss: 4.6082 N_Y: 406000 N_S: 1754772 N: 406000 N_HV: 185513 Val: 6.8928 Test: 6.9144
Epoch: 005, Loss: 4.4620 tsm_loss: 1.2371 reg_loss: 3.2250 N_Y: 406169 N_S: 1754772 N: 406169 N_HV: 181447 Val: 6.7496 Test: 6.7720
Epoch: 006, Loss: 2.6427 tsm_loss: 1.2889 reg_loss: 1.3538 N_Y: 404414 N_S: 1754772 N: 404414 N_HV: 176963 Val: 6.2749 Test: 6.2974
Epoch: 007, Loss: 2.0209 tsm_loss: 1.1914 reg_loss: 0.8295 N_Y: 405844 N_S: 1754772 N: 405844 N_HV: 158905 Val: 5.2881 Test: 5.3105
Epoch: 008, Loss: 1.8235 tsm_loss: 1.1288 reg_loss: 0.6947 N_Y: 406738 N_S:



NameError: name 'fp_mask' is not defined

In [None]:
# Stack the per-seed history frames of each configuration into one frame.
df1, df2, df3, df4 = [pd.concat(frames) for frames in (res1, res2, res3, res4)]

In [None]:
# Create the output folder first: to_csv does not create missing directories,
# and failing here would lose the results of a very long training run.
os.makedirs('./results', exist_ok=True)

df1.to_csv('./results/Baseline (no ACA).csv')
df2.to_csv('./results/ACA (label-only).csv')
df3.to_csv('./results/ACA (label OR structure).csv')
df4.to_csv('./results/ACA (label AND structure).csv')

In [None]:
# Validation RMSE per epoch: mean over seeds with a +/- one-std band.
fig, ax = plt.subplots(figsize=(9, 6))
colors = ['#FFE699','#00B0F0','#0a16f5', 'green']

y = 'val_rmse'

# Legend labels; the plotting cells below reuse n1..n4 and `colors`.
n1 = r'Baseline (no ACA)' # ($\mathcal{L}_{mae}$)
n2 = r'ACA (label-only)'
n3 = r'ACA (label ∪ structure)'
n4 = r'ACA (label ∩  structure)'

res = []
res_std = []
for df, n in zip([df1, df2, df3, df4], [n1, n2, n3, n4]):
    per_epoch = df.groupby('Epoch')[y]
    # rolling(1) is a no-op smoothing window; raise it to smooth the curves.
    res.append(per_epoch.mean().to_frame(name=n).rolling(1).mean())
    res_std.append(per_epoch.std().to_frame(name=n).rolling(1).mean())

dfp = pd.concat(res, axis=1)
dfp_std = pd.concat(res_std, axis=1)

dfp.plot(lw=2, ax=ax, color=colors, alpha=1)

lower = dfp - dfp_std
upper = dfp + dfp_std
for n, color in zip([n1, n2, n3, n4], colors):
    ax.fill_between(dfp.index, lower[n], upper[n], color=color, alpha=0.2)

ax.set_ylabel('Validation RMSE')
ax.set_xlabel('epochs')
ax.spines[['right', 'top']].set_visible(False)

#ax.set_xlim(1,500)
ax.set_ylim(0.50, 1.0)

# tick_params expects booleans: the strings 'off'/'on' are both truthy, so
# the tick marks that were meant to be hidden would still be drawn.
ax.tick_params(left=False, labelleft=True, labelbottom=True, bottom=False, pad=.5)
fig.savefig('./results/Validation_RMSE.svg', bbox_inches='tight', dpi=400)
fig.savefig('./results/Validation_RMSE.pdf', bbox_inches='tight', dpi=400)

In [None]:
# Test RMSE per epoch: mean over seeds with a +/- one-std band.
# Reuses `colors` and the labels n1..n4 defined in the validation-RMSE cell.
fig, ax = plt.subplots(figsize=(9, 6))

y = 'test_rmse'

res = []
res_std = []
for df, n in zip([df1, df2, df3, df4], [n1, n2, n3, n4]):
    per_epoch = df.groupby('Epoch')[y]
    # Two-epoch rolling mean: this figure is lightly smoothed.
    res.append(per_epoch.mean().to_frame(name=n).rolling(2).mean())
    res_std.append(per_epoch.std().to_frame(name=n).rolling(2).mean())

dfp = pd.concat(res, axis=1)
dfp_std = pd.concat(res_std, axis=1)

dfp.plot(lw=2, ax=ax, color=colors, alpha=1)

lower = dfp - dfp_std
upper = dfp + dfp_std
for n, color in zip([n1, n2, n3, n4], colors):
    ax.fill_between(dfp.index, lower[n], upper[n], color=color, alpha=0.2)

ax.set_ylabel('Test RMSE')
ax.set_xlabel('epochs')
ax.spines[['right', 'top']].set_visible(False)

ax.set_xlim(1,800)
ax.set_ylim(0.50, 1.0)

# tick_params expects booleans: the strings 'off'/'on' are both truthy, so
# the tick marks that were meant to be hidden would still be drawn.
ax.tick_params(left=False, labelleft=True, labelbottom=True, bottom=False, pad=.5)
fig.savefig('./results/Test_RMSE.svg', bbox_inches='tight', dpi=400)
fig.savefig('./results/Test_RMSE.pdf', bbox_inches='tight', dpi=400)

In [None]:
# Mean `n_hv_triplets` per epoch during training, with a +/- one-std band.
# Reuses `colors` and the labels n1..n4 defined in the validation-RMSE cell.
fig, ax = plt.subplots(figsize=(9, 6))

y = 'n_hv_triplets'

res = []
res_std = []
for df, n in zip([df1, df2, df3, df4], [n1, n2, n3, n4]):
    per_epoch = df.groupby('Epoch')[y]
    # rolling(1) is a no-op smoothing window; raise it to smooth the curves.
    res.append(per_epoch.mean().to_frame(name=n).rolling(1).mean())
    res_std.append(per_epoch.std().to_frame(name=n).rolling(1).mean())

dfp = pd.concat(res, axis=1)
dfp_std = pd.concat(res_std, axis=1)

dfp.plot(lw=2, ax=ax, color=colors, alpha=1)

lower = dfp - dfp_std
upper = dfp + dfp_std
for n, color in zip([n1, n2, n3, n4], colors):
    ax.fill_between(dfp.index, lower[n], upper[n], color=color, alpha=0.2)

ax.legend(loc='center', bbox_to_anchor=(0.55, 0.5))

ax.spines[['right', 'top']].set_visible(False)
# Scientific notation on the y axis; call it on `ax` rather than through the
# pyplot "current axes" state.
ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0))
ax.set_ylabel("No. of HV-ACTs ($M^'$)")
ax.set_xlabel('epochs')
# tick_params expects booleans: the strings 'off'/'on' are both truthy, so
# the tick marks that were meant to be hidden would still be drawn.
ax.tick_params(left=False, labelleft=True, labelbottom=True, bottom=False, pad=.5)
#ax.set_xlim(-5,500)

fig.savefig('./results/Number_of_mined_ACTs_during_training.svg' , bbox_inches='tight', dpi=400)
fig.savefig('./results/Number_of_mined_ACTs_during_training.pdf' , bbox_inches='tight', dpi=400)

In [None]:
# Training triplet (TSM) loss per epoch: mean over seeds with a +/- one-std band.
# Reuses `colors` and the labels n1..n4 defined in the validation-RMSE cell.
fig, ax = plt.subplots(figsize=(9, 6))
y = 'train_triplet_loss'

res = []
res_std = []
for df, n in zip([df1, df2, df3, df4], [n1, n2, n3, n4]):
    per_epoch = df.groupby('Epoch')[y]
    # rolling(1) is a no-op smoothing window; raise it to smooth the curves.
    res.append(per_epoch.mean().to_frame(name=n).rolling(1).mean())
    res_std.append(per_epoch.std().to_frame(name=n).rolling(1).mean())

dfp = pd.concat(res, axis=1)
dfp_std = pd.concat(res_std, axis=1)

dfp.plot(lw=2, ax=ax, color=colors, alpha=1)

lower = dfp - dfp_std
upper = dfp + dfp_std
for n, color in zip([n1, n2, n3, n4], colors):
    ax.fill_between(dfp.index, lower[n], upper[n], color=color, alpha=0.2)

ax.spines[['right', 'top']].set_visible(False)
# ax.set_xlim(-5,800)
ax.set_ylim(-1,10)

ax.set_ylabel('Training TSM Loss')
ax.set_xlabel('epochs')
# tick_params expects booleans: the strings 'off'/'on' are both truthy, so
# the tick marks that were meant to be hidden would still be drawn.
ax.tick_params(left=False, labelleft=True, labelbottom=True, bottom=False, pad=.5)
fig.savefig('./results/Triplet_loss_during_training.svg', bbox_inches='tight', dpi=400)
fig.savefig('./results/Triplet_loss_during_training.pdf', bbox_inches='tight', dpi=400)

In [None]:
# Training regression loss per epoch: mean over seeds with a +/- one-std band.
# Reuses `colors` and the labels n1..n4 defined in the validation-RMSE cell.
fig, ax = plt.subplots(figsize=(9, 6))

y = 'train_reg_loss'

res = []
res_std = []
for df, n in zip([df1, df2, df3, df4], [n1, n2, n3, n4]):
    per_epoch = df.groupby('Epoch')[y]
    # rolling(1) is a no-op smoothing window; raise it to smooth the curves.
    res.append(per_epoch.mean().to_frame(name=n).rolling(1).mean())
    res_std.append(per_epoch.std().to_frame(name=n).rolling(1).mean())

dfp = pd.concat(res, axis=1)
dfp_std = pd.concat(res_std, axis=1)

dfp.plot(lw=2, ax=ax, color=colors, alpha=1)

lower = dfp - dfp_std
upper = dfp + dfp_std
for n, color in zip([n1, n2, n3, n4], colors):
    ax.fill_between(dfp.index, lower[n], upper[n], color=color, alpha=0.2)

ax.set_ylim(0.0, 0.8)
ax.spines[['right', 'top']].set_visible(False)

ax.set_ylabel('Training MAE loss')
ax.set_xlabel('epochs')
ax.legend(loc='center', bbox_to_anchor=(0.55, 0.5))

#ax.set_xlim(1,800)

# tick_params expects booleans: the strings 'off'/'on' are both truthy, so
# the tick marks that were meant to be hidden would still be drawn.
ax.tick_params(left=False, labelleft=True, labelbottom=True, bottom=False, pad=.5)
fig.savefig('./results/Train_mae_los.svg', bbox_inches='tight', dpi=400)
fig.savefig('./results/Train_mae_los.pdf', bbox_inches='tight', dpi=400)