In [1]:
from math import sqrt
import pandas as pd
import numpy as np
import os
import torch
import torch.nn.functional as F
from rdkit import Chem
import matplotlib.pyplot as plt

from torch_geometric.loader import DataLoader
from torch_geometric.datasets import MoleculeNet
from torch_geometric.nn import global_mean_pool, global_max_pool
%matplotlib inline
#A100 80GB

In [2]:
# Pin the notebook to one GPU. The guard keeps this cell runnable on CPU-only
# machines, in line with the CPU fallback used when `device` is selected in the
# training cell further down.
gpuid = 0
if torch.cuda.is_available():
    torch.cuda.set_device(gpuid)
    print(torch.cuda.current_device())
else:
    print('CUDA is not available; running on CPU')

0


In [3]:
import sys

# Make the local bidd-clsar checkout importable (it provides the `clsar` package).
# Set the CLSAR_ROOT environment variable to point at another checkout; without
# it the original author's path is used unchanged.
clsar_root = os.environ.get('CLSAR_ROOT', '/home/shenwanxiang/Research/bidd-clsar/')
if sys.path[:1] != [clsar_root]:  # already first -> re-running the cell adds nothing
    sys.path.insert(0, clsar_root)

In [4]:
from clsar.dataset import LSSNS, HSSMS
from clsar.feature import Gen39AtomFeatures
from clsar.model.model import ACANet_PNA, get_deg, _fix_reproducibility # model
from clsar.model.loss import ACALoss, get_best_cliff
# Seed handed to the project's reproducibility helper.
# NOTE(review): presumably this seeds the torch / numpy / python RNGs -- confirm
# in clsar.model.model, the helper's body is not visible from this notebook.
GLOBAL_SEED = 42
_fix_reproducibility(GLOBAL_SEED)

In [13]:
def train(train_loader, model, optimizer, aca_loss, device=None):
    """Run one optimisation epoch over ``train_loader``.

    Parameters
    ----------
    train_loader : iterable
        Yields batches exposing ``x``, ``edge_index``, ``edge_attr``, ``batch``,
        ``y``, ``fp``, ``num_graphs`` and a ``to(device)`` method.
    model : torch.nn.Module
        Called as ``model(x.float(), edge_index, edge_attr, batch)``; must
        return ``(predictions, embeddings)``.
    optimizer : torch.optim.Optimizer
        Stepped once per batch.
    aca_loss : callable
        Called with ``labels``, ``predictions``, ``embeddings`` and
        ``fingerprints`` keywords; must return the 7-tuple
        ``(loss, reg_loss, tsm_loss, N_Y_ACTs, N_S_ACTs, N_ACTs, N_HV_ACTs)``.
    device : torch.device, optional
        Where each batch is moved. When omitted it is taken from the model's
        first parameter (CPU for a parameter-free model), so the function no
        longer relies on a notebook-level ``device`` variable.

    Returns
    -------
    tuple
        ``(train_loss, tsm_loss, reg_loss, n_label_triplets,
        n_structure_triplets, n_triplets, n_hv_triplets)``. The three losses
        are averaged over graphs; the four counts are averaged over batches
        and truncated to ``int``.

    Raises
    ------
    ValueError
        If ``train_loader`` yields no batches or no graphs.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    total_examples = 0
    total_loss = 0.0
    total_tsm_loss = 0.0
    total_reg_loss = 0.0
    # Running sums of (N_Y_ACTs, N_S_ACTs, N_ACTs, N_HV_ACTs) over batches.
    triplet_sums = [0, 0, 0, 0]
    n_batches = 0

    model.train()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        predictions, embeddings = model(data.x.float(), data.edge_index,
                                        data.edge_attr, data.batch)

        loss_out = aca_loss(labels=data.y,
                            predictions=predictions,
                            embeddings=embeddings,
                            fingerprints=data.fp)
        loss, reg_loss, tsm_loss, N_Y_ACTs, N_S_ACTs, N_ACTs, N_HV_ACTs = loss_out

        loss.backward()
        optimizer.step()

        # Weight each batch by its graph count so a short final batch does not
        # skew the epoch averages.
        total_loss += float(loss) * data.num_graphs
        total_tsm_loss += float(tsm_loss) * data.num_graphs
        total_reg_loss += float(reg_loss) * data.num_graphs
        total_examples += data.num_graphs

        for k, count in enumerate((N_Y_ACTs, N_S_ACTs, N_ACTs, N_HV_ACTs)):
            triplet_sums[k] += int(count)
        n_batches += 1

    if n_batches == 0 or total_examples == 0:
        raise ValueError('train_loader yielded no graphs; cannot average the epoch statistics')

    train_loss = total_loss / total_examples
    total_tsm_loss = total_tsm_loss / total_examples
    total_reg_loss = total_reg_loss / total_examples

    n_label_triplets, n_structure_triplets, n_triplets, n_hv_triplets = (
        int(total / n_batches) for total in triplet_sums
    )

    return (train_loss, total_tsm_loss, total_reg_loss,
            n_label_triplets, n_structure_triplets, n_triplets, n_hv_triplets)

@torch.no_grad()
def test(test_loader, model, aca_loss):
    model.eval()
    total_examples = 0
    total_loss = 0
    total_tsm_loss = 0
    total_reg_loss = 0

    n_label_triplets = []
    n_structure_triplets = []
    n_triplets = []
    n_hv_triplets = []
    
    mse = []
    for i, data in enumerate(test_loader):
        data = data.to(device)
        predictions, embeddings = model(data.x.float(), data.edge_index,
                                        data.edge_attr, data.batch)
        loss_out = aca_loss(labels = data.y, 
                            predictions = predictions,
                            embeddings = embeddings,
                           fingerprints = data.fp)
        
        loss, reg_loss, tsm_loss,  N_Y_ACTs, N_S_ACTs, N_ACTs, N_HV_ACTs = loss_out

        total_loss += float(loss) * data.num_graphs
        total_tsm_loss += float(tsm_loss) * data.num_graphs
        total_reg_loss += float(reg_loss) * data.num_graphs
        total_examples += data.num_graphs

        n_label_triplets.append(int(N_Y_ACTs))
        n_structure_triplets.append(int(N_S_ACTs))
        n_triplets.append(int(N_ACTs))
        n_hv_triplets.append(int(N_HV_ACTs))

        mse.append(F.mse_loss(predictions, data.y, reduction='none').cpu())

    test_loss = total_loss / total_examples
    total_tsm_loss = total_tsm_loss / total_examples
    total_reg_loss = total_reg_loss / total_examples

    n_label_triplets = int(sum(n_label_triplets) / (i+1))
    n_structure_triplets = int(sum(n_structure_triplets) / (i+1))
    n_triplets = int(sum(n_triplets) / (i+1))
    n_hv_triplets = int(sum(n_hv_triplets) / (i+1))
    
    test_rmse = float(torch.cat(mse, dim=0).mean().sqrt())
    
    return test_loss, total_tsm_loss, total_reg_loss, n_label_triplets, n_structure_triplets, n_triplets, n_hv_triplets, test_rmse



def Test_performance(alpha=1.0):
    """Train a fresh ACANet_PNA model and record per-epoch metrics.

    The function reads its configuration from notebook-level variables:
    ``pub_args``, ``deg``, ``device``, ``lr``, ``epochs``, ``train_loader``,
    ``val_loader`` and ``test_loader``. All of them must exist before the call.

    Parameters
    ----------
    alpha : float, default 1.0
        Forwarded to ``ACALoss(alpha=...)``. The driver cell of this notebook
        runs ``alpha=1.0`` as the "with AC-awareness" setting and ``alpha=0.0``
        as the "without AC-awareness" setting.

    Returns
    -------
    pandas.DataFrame
        One row per epoch with the columns ``Epoch``, ``train_loss``,
        ``train_triplet_loss``, ``train_reg_loss``, ``val_rmse``, ``test_rmse``,
        ``train_rmse``, ``n_label_triplets``, ``n_structure_triplets``,
        ``n_triplets``, ``n_hv_triplets``, ``train_n_hv_triplets``,
        ``val_n_hv_triplets`` and ``test_n_hv_triplets``.
    """
    model = ACANet_PNA(**pub_args, deg=deg).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=10**-5)
    aca_loss = ACALoss(alpha=alpha,
                       cliff_lower=1.,
                       cliff_upper=1.,
                       squared=False,
                       similarity_gate=True,
                       similarity_neg=0.8,
                       similarity_pos=0.2,
                       dev_mode=True)

    history = []
    # range() excludes its end point, hence +1 so that exactly `epochs` epochs
    # are trained and the last recorded epoch number equals `epochs`.
    for epoch in range(1, epochs + 1):
        (train_loss, tsm_loss, reg_loss, n_label_triplets,
         n_structure_triplets, n_triplets, n_hv_triplets) = train(train_loader, model, optimizer, aca_loss)

        # Re-evaluate on all three splits in eval mode; only the HV-triplet
        # count and the RMSE of each pass are kept.
        _, _, _, _, _, _, train_n_hv_triplets, train_rmse = test(train_loader, model, aca_loss)
        _, _, _, _, _, _, val_n_hv_triplets, val_rmse = test(val_loader, model, aca_loss)
        _, _, _, _, _, _, test_n_hv_triplets, test_rmse = test(test_loader, model, aca_loss)

        print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f} tsm_loss: {tsm_loss:.4f} reg_loss: {reg_loss:.4f} '
              f'N_Y: {n_label_triplets:03d} N_S: {n_structure_triplets:03d} N: {n_triplets:03d} N_HV: {n_hv_triplets:03d} '
              f'Val: {val_rmse:.4f} Test: {test_rmse:.4f}')

        history.append({'Epoch': epoch,
                        'train_loss': train_loss,
                        'train_triplet_loss': tsm_loss,
                        'train_reg_loss': reg_loss,
                        'val_rmse': val_rmse,
                        'test_rmse': test_rmse,
                        'train_rmse': train_rmse,

                        # statistics gathered while training on this epoch
                        'n_label_triplets': n_label_triplets,
                        'n_structure_triplets': n_structure_triplets,
                        'n_triplets': n_triplets,
                        'n_hv_triplets': n_hv_triplets,

                        # HV-triplet counts from the evaluation passes
                        'train_n_hv_triplets': train_n_hv_triplets,
                        'val_n_hv_triplets': val_n_hv_triplets,
                        'test_n_hv_triplets': test_n_hv_triplets,
                        })

    dfh = pd.DataFrame(history)
    return dfh

In [14]:
# ---- Experiment configuration ------------------------------------------------
dataset_name = 'CHEMBL3979_EC50'
Dataset = HSSMS  # alternative dataset class: LSSNS
epochs, batch_size, lr = 800, 128, 1e-4

# Featuriser applied by the dataset; it also reports the node / edge feature sizes.
pre_transform = Gen39AtomFeatures()
in_channels = pre_transform.in_channels
path = './data/'

# ---- Model hyper-parameters (unpacked into ACANet_PNA) -------------------------
pub_args = dict(
    in_channels=pre_transform.in_channels,
    edge_dim=pre_transform.edge_dim,
    convs_layers=[64, 128, 256, 512],
    dense_layers=[256, 128, 32],
    out_channels=1,
    aggregators=['mean', 'min', 'max', 'sum', 'std'],
    scalers=['identity', 'amplification', 'attenuation'],
    dropout_p=0,
)

In [15]:
# Build the dataset once just to display how many entries it contains.
dataset_preview = Dataset(path, name=dataset_name, pre_transform=pre_transform).shuffle(42)
len(dataset_preview)

1125

In [16]:
# Split the data into validation / test / train parts (one fifth, one fifth and
# the remainder), then train once with and once without AC-awareness per seed.
res1, res2 = [], []  # per-seed histories: res1 = alpha 1.0, res2 = alpha 0.0

# Same choice on every iteration, so it is made once up front.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): `seed` is only written into the result frames. The split always
# uses shuffle(42) and nothing is re-seeded per iteration -- confirm this is
# intended before enabling more seeds (16, 24, 42, 64, 128, 256, 512, 1024, 2048).
for seed in [8]:
    dataset = Dataset(path, name=dataset_name, pre_transform=pre_transform).shuffle(42)
    N = len(dataset) // 5
    val_dataset, test_dataset, train_dataset = dataset[:N], dataset[N:2 * N], dataset[2 * N:]

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    deg = get_deg(train_dataset)

    # With AC-awareness (alpha = 1)
    df1 = Test_performance(alpha=1.0).assign(seed=seed)
    # Without AC-awareness (alpha = 0)
    df2 = Test_performance(alpha=0.0).assign(seed=seed)

    # Stored only after both runs finished, so the two lists stay aligned.
    res1.append(df1)
    res2.append(df2)

Epoch: 001, Loss: 37.6279 tsm_loss: 31.3185 reg_loss: 6.3093 N_Y: 405148 N_S: 1345 N: 723 N_HV: 2090 Val: 6.6624 Test: 6.6574
Epoch: 002, Loss: 5.7550 tsm_loss: 0.0000 reg_loss: 5.7550 N_Y: 403115 N_S: 1728 N: 871 N_HV: 000 Val: 6.6591 Test: 6.6541
Epoch: 003, Loss: 5.1914 tsm_loss: 0.0910 reg_loss: 5.1004 N_Y: 404914 N_S: 1451 N: 809 N_HV: 021 Val: 6.6528 Test: 6.6478
Epoch: 004, Loss: 4.2507 tsm_loss: 0.0859 reg_loss: 4.1649 N_Y: 405940 N_S: 1777 N: 1001 N_HV: 085 Val: 6.6404 Test: 6.6357
Epoch: 005, Loss: 2.9336 tsm_loss: 0.0863 reg_loss: 2.8472 N_Y: 404316 N_S: 1261 N: 689 N_HV: 005 Val: 6.5882 Test: 6.5852
Epoch: 006, Loss: 1.4796 tsm_loss: 0.0000 reg_loss: 1.4796 N_Y: 405238 N_S: 1233 N: 603 N_HV: 000 Val: 6.3250 Test: 6.3300
Epoch: 007, Loss: 1.2747 tsm_loss: 0.0000 reg_loss: 1.2747 N_Y: 405693 N_S: 1638 N: 871 N_HV: 000 Val: 5.6635 Test: 5.6828
Epoch: 008, Loss: 1.3072 tsm_loss: 0.0000 reg_loss: 1.3072 N_Y: 402855 N_S: 1793 N: 973 N_HV: 000 Val: 4.8576 Test: 4.8965
Epoch: 009, 

KeyboardInterrupt: 

In [None]:

 
Epoch: 001, Loss: 505.8423 tsm_loss: 499.5391 reg_loss: 6.3032 N_Y: 406098 N_S: 1754772 N: 406098 N_HV: 23162645 Val: 6.6238 Test: 6.5971
Epoch: 002, Loss: 164.1380 tsm_loss: 158.4051 reg_loss: 5.7329 N_Y: 407711 N_S: 1754772 N: 407711 N_HV: 20964135 Val: 6.6180 Test: 6.5913
Epoch: 003, Loss: 116.1526 tsm_loss: 111.1598 reg_loss: 4.9928 N_Y: 406124 N_S: 1754772 N: 406124 N_HV: 19399902 Val: 6.5892 Test: 6.5626
Epoch: 004, Loss: 85.8233 tsm_loss: 81.7326 reg_loss: 4.0907 N_Y: 404934 N_S: 1754772 N: 404934 N_HV: 18177328 Val: 6.4449 Test: 6.4185
Epoch: 005, Loss: 64.3978 tsm_loss: 61.3894 reg_loss: 3.0084 N_Y: 404929 N_S: 1754772 N: 404929 N_HV: 16892596 Val: 6.0339 Test: 6.0068
Epoch: 006, Loss: 51.1996 tsm_loss: 49.4932 reg_loss: 1.7064 N_Y: 405037 N_S: 1754772 N: 405037 N_HV: 16161038 Val: 5.1840 Test: 5.1546
Epoch: 007, Loss: 41.2007 tsm_loss: 40.2709 reg_loss: 0.9298 N_Y: 407263 N_S: 1754772 N: 407263 N_HV: 15222850 Val: 3.9104 Test: 3.8786
Epoch: 008, Loss: 35.2102 tsm_loss: 34.1526 reg_loss: 1.0576 N_Y: 405573 N_S: 1754772 N: 405573 N_HV: 14585591 Val: 2.7372 Test: 2.7022
Epoch: 009, Loss: 28.8879 tsm_loss: 27.8919 reg_loss: 0.9960 N_Y: 405859 N_S: 1754772 N: 405859 N_HV: 13614493 Val: 2.1299 Test: 2.0952
Epoch: 010, Loss: 25.8078 tsm_loss: 24.9041 reg_loss: 0.9037 N_Y: 407184 N_S: 1754772 N: 407184 N_HV: 13922313 Val: 1.7592 Test: 1.7214
Epoch: 011, Loss: 23.2485 tsm_loss: 22.3144 reg_loss: 0.9341 N_Y: 406054 N_S: 1754772 N: 406054 N_HV: 13378202 Val: 1.4483 Test: 1.4128
Epoch: 012, Loss: 21.3133 tsm_loss: 20.3923 reg_loss: 0.9210 N_Y: 404748 N_S: 1754772 N: 404748 N_HV: 13387117 Val: 1.2014 Test: 1.1677
Epoch: 013, Loss: 19.2430 tsm_loss: 18.3419 reg_loss: 0.9011 N_Y: 405713 N_S: 1754772 N: 405713 N_HV: 12809677 Val: 1.1517 Test: 1.1229
Epoch: 014, Loss: 17.0707 tsm_loss: 16.1642 reg_loss: 0.9066 N_Y: 406226 N_S: 1754772 N: 406226 N_HV: 12633630 Val: 1.1729 Test: 1.1523
Epoch: 015, Loss: 15.2888 tsm_loss: 14.3809 reg_loss: 0.9079 N_Y: 405731 N_S: 1754772 N: 405731 N_HV: 12061293 Val: 1.1521 Test: 1.1294
Epoch: 016, Loss: 14.0254 tsm_loss: 13.1231 reg_loss: 0.9022 N_Y: 404400 N_S: 1754772 N: 404400 N_HV: 11842593 Val: 1.1371 Test: 1.1158
Epoch: 017, Loss: 13.4109 tsm_loss: 12.5078 reg_loss: 0.9031 N_Y: 404543 N_S: 1754772 N: 404543 N_HV: 11810295 Val: 1.1364 Test: 1.1107
Epoch: 018, Loss: 12.6436 tsm_loss: 11.7351 reg_loss: 0.9085 N_Y: 406329 N_S: 1754772 N: 406329 N_HV: 11543175 Val: 1.1534 Test: 1.1327
Epoch: 019, Loss: 11.8808 tsm_loss: 10.9745 reg_loss: 0.9063 N_Y: 405813 N_S: 1754772 N: 405813 N_HV: 11353777 Val: 1.1434 Test: 1.1189
Epoch: 020, Loss: 11.0344 tsm_loss: 10.1323 reg_loss: 0.9021 N_Y: 406899 N_S: 1754772 N: 406899 N_HV: 10946095 Val: 1.1407 Test: 1.1188
Epoch: 021, Loss: 10.3187 tsm_loss: 9.4154 reg_loss: 0.9033 N_Y: 403563 N_S: 1754772 N: 403563 N_HV: 10568007 Val: 1.1482 Test: 1.1259
Epoch: 022, Loss: 9.7869 tsm_loss: 8.8831 reg_loss: 0.9038 N_Y: 406004 N_S: 1754772 N: 406004 N_HV: 10596467 Val: 1.1231 Test: 1.0989
Epoch: 023, Loss: 9.3746 tsm_loss: 8.4732 reg_loss: 0.9014 N_Y: 407431 N_S: 1754772 N: 407431 N_HV: 10632770 Val: 1.1309 Test: 1.1085
Epoch: 024, Loss: 9.0668 tsm_loss: 8.1601 reg_loss: 0.9068 N_Y: 406086 N_S: 1754772 N: 406086 N_HV: 10323246 Val: 1.1329 Test: 1.1063
Epoch: 025, Loss: 8.2675 tsm_loss: 7.3605 reg_loss: 0.9070 N_Y: 406918 N_S: 1754772 N: 406918 N_HV: 10002732 Val: 1.1291 Test: 1.1024
Epoch: 026, Loss: 7.9625 tsm_loss: 7.0596 reg_loss: 0.9029 N_Y: 405222 N_S: 1754772 N: 405222 N_HV: 9910866 Val: 1.1358 Test: 1.1156
Epoch: 027, Loss: 7.7207 tsm_loss: 6.8163 reg_loss: 0.9044 N_Y: 407275 N_S: 1754772 N: 407275 N_HV: 9627092 Val: 1.1381 Test: 1.1156
Epoch: 028, Loss: 7.3465 tsm_loss: 6.4424 reg_loss: 0.9041 N_Y: 406111 N_S: 1754772 N: 406111 N_HV: 9591652 Val: 1.1387 Test: 1.1190
Epoch: 029, Loss: 7.2410 tsm_loss: 6.3353 reg_loss: 0.9057 N_Y: 405556 N_S: 1754772 N: 405556 N_HV: 9888921 Val: 1.1268 Test: 1.1008
Epoch: 030, Loss: 7.1685 tsm_loss: 6.2637 reg_loss: 0.9048 N_Y: 407425 N_S: 1754772 N: 407425 N_HV: 9382837 Val: 1.1183 Test: 1.0961
Epoch: 031, Loss: 6.7557 tsm_loss: 5.8509 reg_loss: 0.9048 N_Y: 404700 N_S: 1754772 N: 404700 N_HV: 9148393 Val: 1.1357 Test: 1.1145
Epoch: 032, Loss: 6.5621 tsm_loss: 5.6617 reg_loss: 0.9004 N_Y: 405274 N_S: 1754772 N: 405274 N_HV: 8934917 Val: 1.1313 Test: 1.1093
Epoch: 033, Loss: 6.3590 tsm_loss: 5.4563 reg_loss: 0.9027 N_Y: 406289 N_S: 1754772 N: 406289 N_HV: 9076396 Val: 1.1398 Test: 1.1149
Epoch: 034, Loss: 6.0815 tsm_loss: 5.1754 reg_loss: 0.9062 N_Y: 405553 N_S: 1754772 N: 405553 N_HV: 8810466 Val: 1.1280 Test: 1.1037
Epoch: 035, Loss: 6.0966 tsm_loss: 5.1956 reg_loss: 0.9010 N_Y: 406233 N_S: 1754772 N: 406233 N_HV: 8982248 Val: 1.1454 Test: 1.1245
Epoch: 036, Loss: 5.8351 tsm_loss: 4.9338 reg_loss: 0.9014 N_Y: 404242 N_S: 1754772 N: 404242 N_HV: 8504289 Val: 1.1284 Test: 1.1081
Epoch: 037, Loss: 5.5254 tsm_loss: 4.6232 reg_loss: 0.9021 N_Y: 405728 N_S: 1754772 N: 405728 N_HV: 8503393 Val: 1.1233 Test: 1.1017
Epoch: 038, Loss: 5.4919 tsm_loss: 4.5901 reg_loss: 0.9017 N_Y: 404674 N_S: 1754772 N: 404674 N_HV: 8560067 Val: 1.1220 Test: 1.0956
Epoch: 039, Loss: 5.4316 tsm_loss: 4.5291 reg_loss: 0.9025 N_Y: 405837 N_S: 1754772 N: 405837 N_HV: 8387001 Val: 1.1150 Test: 1.0942
Epoch: 040, Loss: 5.7299 tsm_loss: 4.8302 reg_loss: 0.8997 N_Y: 405549 N_S: 1754772 N: 405549 N_HV: 8875401 Val: 1.1238 Test: 1.0985
Epoch: 041, Loss: 5.2737 tsm_loss: 4.3732 reg_loss: 0.9006 N_Y: 405118 N_S: 1754772 N: 405118 N_HV: 8189212 Val: 1.1351 Test: 1.1161
Epoch: 042, Loss: 5.2150 tsm_loss: 4.3148 reg_loss: 0.9001 N_Y: 405764 N_S: 1754772 N: 405764 N_HV: 8563219 Val: 1.1400 Test: 1.1209
Epoch: 043, Loss: 5.0776 tsm_loss: 4.1838 reg_loss: 0.8937 N_Y: 405348 N_S: 1754772 N: 405348 N_HV: 7983448 Val: 1.1168 Test: 1.0917
Epoch: 044, Loss: 4.7990 tsm_loss: 3.8994 reg_loss: 0.8995 N_Y: 407445 N_S: 1754772 N: 407445 N_HV: 7964485 Val: 1.1129 Test: 1.0883
Epoch: 045, Loss: 4.9075 tsm_loss: 4.0140 reg_loss: 0.8935 N_Y: 406010 N_S: 1754772 N: 406010 N_HV: 7723743 Val: 1.1200 Test: 1.0950
Epoch: 046, Loss: 5.2888 tsm_loss: 4.3939 reg_loss: 0.8949 N_Y: 405659 N_S: 1754772 N: 405659 N_HV: 7869013 Val: 1.1254 Test: 1.1053
Epoch: 047, Loss: 4.7228 tsm_loss: 3.8304 reg_loss: 0.8924 N_Y: 405203 N_S: 1754772 N: 405203 N_HV: 7922719 Val: 1.1160 Test: 1.0939
Epoch: 048, Loss: 5.0686 tsm_loss: 4.1778 reg_loss: 0.8908 N_Y: 405611 N_S: 1754772 N: 405611 N_HV: 7804627 Val: 1.1285 Test: 1.1051
Epoch: 049, Loss: 5.4764 tsm_loss: 4.5754 reg_loss: 0.9010 N_Y: 405033 N_S: 1754772 N: 405033 N_HV: 7758364 Val: 1.1090 Test: 1.0825
Epoch: 050, Loss: 4.8427 tsm_loss: 3.9496 reg_loss: 0.8931 N_Y: 406347 N_S: 1754772 N: 406347 N_HV: 7836730 Val: 1.1088 Test: 1.0912
Epoch: 051, Loss: 6.0015 tsm_loss: 5.1136 reg_loss: 0.8880 N_Y: 403858 N_S: 1754772 N: 403858 N_HV: 7827957 Val: 1.1098 Test: 1.0936
Epoch: 052, Loss: 4.5031 tsm_loss: 3.6166 reg_loss: 0.8864 N_Y: 404991 N_S: 1754772 N: 404991 N_HV: 7557500 Val: 1.1216 Test: 1.1036
Epoch: 053, Loss: 4.1690 tsm_loss: 3.2818 reg_loss: 0.8872 N_Y: 406386 N_S: 1754772 N: 406386 N_HV: 7361394 Val: 1.1151 Test: 1.0993
Epoch: 054, Loss: 4.2005 tsm_loss: 3.3136 reg_loss: 0.8870 N_Y: 407190 N_S: 1754772 N: 407190 N_HV: 7581680 Val: 1.1131 Test: 1.0913
Epoch: 055, Loss: 4.1414 tsm_loss: 3.2547 reg_loss: 0.8867 N_Y: 406508 N_S: 1754772 N: 406508 N_HV: 7392246 Val: 1.0949 Test: 1.0785
Epoch: 056, Loss: 3.9785 tsm_loss: 3.0964 reg_loss: 0.8821 N_Y: 405172 N_S: 1754772 N: 405172 N_HV: 7411442 Val: 1.1030 Test: 1.0829
Epoch: 057, Loss: 3.9716 tsm_loss: 3.0884 reg_loss: 0.8832 N_Y: 405448 N_S: 1754772 N: 405448 N_HV: 7276681 Val: 1.1121 Test: 1.0956
Epoch: 058, Loss: 3.8341 tsm_loss: 2.9495 reg_loss: 0.8847 N_Y: 404553 N_S: 1754772 N: 404553 N_HV: 7351858 Val: 1.1026 Test: 1.0798
Epoch: 059, Loss: 3.9141 tsm_loss: 3.0293 reg_loss: 0.8849 N_Y: 404818 N_S: 1754772 N: 404818 N_HV: 7196173 Val: 1.0992 Test: 1.0760
Epoch: 060, Loss: 3.7666 tsm_loss: 2.8797 reg_loss: 0.8868 N_Y: 406508 N_S: 1754772 N: 406508 N_HV: 7107792 Val: 1.0976 Test: 1.0789
Epoch: 061, Loss: 3.7562 tsm_loss: 2.8812 reg_loss: 0.8750 N_Y: 405609 N_S: 1754772 N: 405609 N_HV: 6948064 Val: 1.1016 Test: 1.0852
Epoch: 062, Loss: 3.6771 tsm_loss: 2.7960 reg_loss: 0.8811 N_Y: 406004 N_S: 1754772 N: 406004 N_HV: 6879330 Val: 1.1153 Test: 1.0937
Epoch: 063, Loss: 3.6846 tsm_loss: 2.8076 reg_loss: 0.8770 N_Y: 403566 N_S: 1754772 N: 403566 N_HV: 6745342 Val: 1.0949 Test: 1.0745
Epoch: 064, Loss: 3.6120 tsm_loss: 2.7382 reg_loss: 0.8737 N_Y: 405540 N_S: 1754772 N: 405540 N_HV: 6806963 Val: 1.0945 Test: 1.0732
Epoch: 065, Loss: 3.3043 tsm_loss: 2.4269 reg_loss: 0.8774 N_Y: 406625 N_S: 1754772 N: 406625 N_HV: 6541643 Val: 1.1026 Test: 1.0854
Epoch: 066, Loss: 3.2799 tsm_loss: 2.4056 reg_loss: 0.8743 N_Y: 406068 N_S: 1754772 N: 406068 N_HV: 6460041 Val: 1.0964 Test: 1.0770
Epoch: 067, Loss: 3.3141 tsm_loss: 2.4450 reg_loss: 0.8691 N_Y: 403444 N_S: 1754772 N: 403444 N_HV: 6132364 Val: 1.0963 Test: 1.0758
Epoch: 068, Loss: 3.4450 tsm_loss: 2.5703 reg_loss: 0.8747 N_Y: 406253 N_S: 1754772 N: 406253 N_HV: 6913482 Val: 1.0909 Test: 1.0710
Epoch: 069, Loss: 3.5357 tsm_loss: 2.6650 reg_loss: 0.8707 N_Y: 406930 N_S: 1754772 N: 406930 N_HV: 6820549 Val: 1.0964 Test: 1.0767
Epoch: 070, Loss: 3.3889 tsm_loss: 2.5162 reg_loss: 0.8727 N_Y: 405928 N_S: 1754772 N: 405928 N_HV: 6831304 Val: 1.0841 Test: 1.0685
Epoch: 071, Loss: 3.6323 tsm_loss: 2.7680 reg_loss: 0.8643 N_Y: 406425 N_S: 1754772 N: 406425 N_HV: 6619868 Val: 1.0882 Test: 1.0693
Epoch: 072, Loss: 3.4390 tsm_loss: 2.5734 reg_loss: 0.8656 N_Y: 406235 N_S: 1754772 N: 406235 N_HV: 6371829 Val: 1.0924 Test: 1.0738
Epoch: 073, Loss: 3.1731 tsm_loss: 2.3022 reg_loss: 0.8709 N_Y: 405429 N_S: 1754772 N: 405429 N_HV: 6284679 Val: 1.0823 Test: 1.0650
Epoch: 074, Loss: 2.9572 tsm_loss: 2.0953 reg_loss: 0.8619 N_Y: 406853 N_S: 1754772 N: 406853 N_HV: 5786677 Val: 1.0866 Test: 1.0679
Epoch: 075, Loss: 2.9429 tsm_loss: 2.0801 reg_loss: 0.8628 N_Y: 405148 N_S: 1754772 N: 405148 N_HV: 5852388 Val: 1.0804 Test: 1.0640
Epoch: 076, Loss: 2.7417 tsm_loss: 1.8791 reg_loss: 0.8626 N_Y: 406006 N_S: 1754772 N: 406006 N_HV: 5746844 Val: 1.0776 Test: 1.0601
Epoch: 077, Loss: 2.8950 tsm_loss: 2.0357 reg_loss: 0.8593 N_Y: 405756 N_S: 1754772 N: 405756 N_HV: 5978885 Val: 1.0871 Test: 1.0657
Epoch: 078, Loss: 3.3632 tsm_loss: 2.5034 reg_loss: 0.8597 N_Y: 406061 N_S: 1754772 N: 406061 N_HV: 5958748 Val: 1.0750 Test: 1.0541
Epoch: 079, Loss: 3.4763 tsm_loss: 2.6179 reg_loss: 0.8584 N_Y: 403735 N_S: 1754772 N: 403735 N_HV: 6411100 Val: 1.0778 Test: 1.0627
Epoch: 080, Loss: 3.2669 tsm_loss: 2.4189 reg_loss: 0.8480 N_Y: 404921 N_S: 1754772 N: 404921 N_HV: 6530007 Val: 1.0674 Test: 1.0454
Epoch: 081, Loss: 3.9032 tsm_loss: 3.0575 reg_loss: 0.8456 N_Y: 405339 N_S: 1754772 N: 405339 N_HV: 6429254 Val: 1.0797 Test: 1.0572
Epoch: 082, Loss: 3.2930 tsm_loss: 2.4467 reg_loss: 0.8462 N_Y: 405544 N_S: 1754772 N: 405544 N_HV: 6567809 Val: 1.0734 Test: 1.0567
Epoch: 083, Loss: 3.1284 tsm_loss: 2.2842 reg_loss: 0.8442 N_Y: 407783 N_S: 1754772 N: 407783 N_HV: 6277752 Val: 1.0614 Test: 1.0420
Epoch: 084, Loss: 3.0763 tsm_loss: 2.2318 reg_loss: 0.8445 N_Y: 406373 N_S: 1754772 N: 406373 N_HV: 6274150 Val: 1.0687 Test: 1.0499
Epoch: 085, Loss: 3.0601 tsm_loss: 2.2156 reg_loss: 0.8446 N_Y: 405215 N_S: 1754772 N: 405215 N_HV: 6167134 Val: 1.0606 Test: 1.0419
Epoch: 086, Loss: 3.0572 tsm_loss: 2.2227 reg_loss: 0.8345 N_Y: 404274 N_S: 1754772 N: 404274 N_HV: 6036077 Val: 1.0566 Test: 1.0397
Epoch: 087, Loss: 3.1218 tsm_loss: 2.2870 reg_loss: 0.8347 N_Y: 406593 N_S: 1754772 N: 406593 N_HV: 6155705 Val: 1.0503 Test: 1.0364
Epoch: 088, Loss: 3.0674 tsm_loss: 2.2314 reg_loss: 0.8360 N_Y: 405461 N_S: 1754772 N: 405461 N_HV: 6177015 Val: 1.0586 Test: 1.0400
Epoch: 089, Loss: 2.8340 tsm_loss: 2.0037 reg_loss: 0.8303 N_Y: 404562 N_S: 1754772 N: 404562 N_HV: 5878345 Val: 1.0536 Test: 1.0380
Epoch: 090, Loss: 2.9095 tsm_loss: 2.0793 reg_loss: 0.8301 N_Y: 406961 N_S: 1754772 N: 406961 N_HV: 5501692 Val: 1.0663 Test: 1.0498
Epoch: 091, Loss: 2.9963 tsm_loss: 2.1669 reg_loss: 0.8295 N_Y: 406019 N_S: 1754772 N: 406019 N_HV: 5634796 Val: 1.0420 Test: 1.0258
Epoch: 092, Loss: 2.9043 tsm_loss: 2.0853 reg_loss: 0.8190 N_Y: 404786 N_S: 1754772 N: 404786 N_HV: 5959839 Val: 1.0664 Test: 1.0513
Epoch: 093, Loss: 3.0049 tsm_loss: 2.1814 reg_loss: 0.8235 N_Y: 406252 N_S: 1754772 N: 406252 N_HV: 5756120 Val: 1.0485 Test: 1.0297
Epoch: 094, Loss: 3.1208 tsm_loss: 2.3036 reg_loss: 0.8171 N_Y: 402504 N_S: 1754772 N: 402504 N_HV: 6025506 Val: 1.0442 Test: 1.0274
Epoch: 095, Loss: 2.6919 tsm_loss: 1.8755 reg_loss: 0.8164 N_Y: 403756 N_S: 1754772 N: 403756 N_HV: 5698776 Val: 1.0387 Test: 


In [None]:
# Stack the per-seed histories into one frame per loss variant
# (df1: alpha = 1.0 runs, df2: alpha = 0.0 runs).
df1 = pd.concat(objs=res1)
df2 = pd.concat(objs=res2)

In [12]:
# Persist both histories next to the notebook; the index column is written too.
df1.to_csv(path_or_buf='./with_aca.csv')
df2.to_csv(path_or_buf='./without_aca.csv')

NameError: name 'df2' is not defined

In [None]:
# Validation RMSE per epoch for both loss variants: mean over seeds with a
# +/- one standard deviation band.
fig, ax = plt.subplots(figsize=(9, 6))
colors = ['#FFE699','#00B0F0']  # colors[0] -> n2 curve, colors[1] -> n1 curve


y = 'val_rmse'

# Legend labels; the later plotting cells reuse n1, n2 and colors.
n1 = r'With AC-awareness ($\mathcal{L}_{mae} + \mathcal{L}_{tsm}$)'
n2 = r'Without AC-awareness ($\mathcal{L}_{mae}$)'


# Columns are ordered (n2, n1). rolling(1) is a one-epoch window, i.e. no smoothing.
dfp = df2.groupby('Epoch').mean()[y].to_frame(name = n2).join(df1.groupby('Epoch').mean()[y].to_frame(name = n1)).rolling(1).mean()
dfp_std = df2.groupby('Epoch').std()[y].to_frame(name = n2).join(df1.groupby('Epoch').std()[y].to_frame(name = n1)).rolling(1).mean()

dfp.plot(lw = 2, ax=ax,color = colors, alpha =1)
ax.fill_between(dfp.index, (dfp - dfp_std)[n1], (dfp + dfp_std)[n1], color=colors[1], alpha=0.2)
ax.fill_between(dfp.index, (dfp - dfp_std)[n2], (dfp + dfp_std)[n2], color=colors[0], alpha=0.2)

ax.set_ylim(0.60, 1.0)
ax.set_ylabel('Validation RMSE')
ax.set_xlabel('epochs')
ax.spines[['right', 'top']].set_visible(False)

ax.set_xlim(1,800)

# tick_params takes booleans for these switches; the former 'off' / 'on' strings
# are both non-empty and therefore truthy, so 'off' did not hide the tick marks.
ax.tick_params(left=False, labelleft=True, labelbottom=True, bottom=False, pad=.5)
fig.savefig('./Validation_RMSE.svg', bbox_inches='tight', dpi=400) 
fig.savefig('./Validation_RMSE.pdf', bbox_inches='tight', dpi=400) 

In [None]:
# Test RMSE per epoch for both loss variants: mean over seeds with a
# +/- one standard deviation band. Relies on n1 / n2 from the validation plot cell.
fig, ax = plt.subplots(figsize=(9, 6))
colors = ['#FFE699','#00B0F0']  # colors[0] -> n2 curve, colors[1] -> n1 curve


y = 'test_rmse'


# Columns are ordered (n2, n1). rolling(1) is a one-epoch window, i.e. no smoothing.
dfp = df2.groupby('Epoch').mean()[y].to_frame(name = n2).join(df1.groupby('Epoch').mean()[y].to_frame(name = n1)).rolling(1).mean()
dfp_std = df2.groupby('Epoch').std()[y].to_frame(name = n2).join(df1.groupby('Epoch').std()[y].to_frame(name = n1)).rolling(1).mean()

dfp.plot(lw = 2, ax=ax,color = colors, alpha =1)
ax.fill_between(dfp.index, (dfp - dfp_std)[n1], (dfp + dfp_std)[n1], color=colors[1], alpha=0.2)
ax.fill_between(dfp.index, (dfp - dfp_std)[n2], (dfp + dfp_std)[n2], color=colors[0], alpha=0.2)

ax.set_ylim(0.60, 1.0)
ax.set_ylabel('Test RMSE')
ax.set_xlabel('epochs')
ax.spines[['right', 'top']].set_visible(False)

ax.set_xlim(1,800)

# tick_params takes booleans for these switches; the former 'off' / 'on' strings
# are both non-empty and therefore truthy, so 'off' did not hide the tick marks.
ax.tick_params(left=False, labelleft=True, labelbottom=True, bottom=False, pad=.5)
fig.savefig('./Test_RMSE.svg' , bbox_inches='tight', dpi=400) 
fig.savefig('./Test_RMSE.pdf' , bbox_inches='tight', dpi=400) 

In [None]:
# Number of HV-ACTs counted during training, per epoch, for both loss variants.
# Relies on colors, n1 and n2 from the earlier plotting cells.
fig, ax = plt.subplots(figsize=(9, 6))


# The history frames have no 'n_pos_triplets' column; the training-time HV
# triplet count that Test_performance records is stored as 'n_hv_triplets'.
y = 'n_hv_triplets'

# Columns are ordered (n2, n1). rolling(1) is a one-epoch window, i.e. no smoothing.
dfp = df2.groupby('Epoch').mean()[y].to_frame(name = n2).join(df1.groupby('Epoch').mean()[y].to_frame(name = n1)).rolling(1).mean()
dfp_std = df2.groupby('Epoch').std()[y].to_frame(name = n2).join(df1.groupby('Epoch').std()[y].to_frame(name = n1)).rolling(1).mean()

dfp.plot(lw = 3, ax=ax,color = colors, alpha =1)
ax.fill_between(dfp.index, (dfp - dfp_std)[n1], (dfp + dfp_std)[n1], color=colors[1], alpha=0.3)
ax.fill_between(dfp.index, (dfp - dfp_std)[n2], (dfp + dfp_std)[n2], color=colors[0], alpha=0.3)

ax.legend(loc='center', bbox_to_anchor=(0.55, 0.5))

ax.spines[['right', 'top']].set_visible(False)
# Format through the Axes object instead of pyplot's implicit "current axes".
ax.ticklabel_format(axis='y', style='sci', scilimits=(0,0))
ax.set_ylabel("No. of HV-ACTs ($M^'$)")
ax.set_xlabel('epochs')
# tick_params takes booleans for these switches; the former 'off' / 'on' strings
# are both non-empty and therefore truthy, so 'off' did not hide the tick marks.
ax.tick_params(left=False, labelleft=True, labelbottom=True, bottom=False, pad=.5)
ax.set_xlim(-5,800)


fig.savefig('./Number_of_mined_ACTs_during_training.svg' , bbox_inches='tight', dpi=400) 
fig.savefig('./Number_of_mined_ACTs_during_training.pdf' , bbox_inches='tight', dpi=400) 

In [None]:
# Training TSM (triplet) loss per epoch for both loss variants.
# Relies on colors, n1 and n2 from the earlier plotting cells.
fig, ax = plt.subplots(figsize=(9, 6))
y = 'train_triplet_loss'
# Columns are ordered (n2, n1). rolling(1) is a one-epoch window, i.e. no smoothing.
dfp = df2.groupby('Epoch').mean()[y].to_frame(name = n2).join(df1.groupby('Epoch').mean()[y].to_frame(name = n1)).rolling(1).mean()
dfp_std = df2.groupby('Epoch').std()[y].to_frame(name = n2).join(df1.groupby('Epoch').std()[y].to_frame(name = n1)).rolling(1).mean()

dfp.plot(lw = 3, ax=ax,color = colors, alpha =1)
ax.fill_between(dfp.index, (dfp - dfp_std)[n1], (dfp + dfp_std)[n1], color=colors[1], alpha=0.3)
ax.fill_between(dfp.index, (dfp - dfp_std)[n2], (dfp + dfp_std)[n2], color=colors[0], alpha=0.3)

ax.spines[['right', 'top']].set_visible(False)
ax.set_xlim(-5,800)
ax.set_ylim(-1,10)

ax.set_ylabel('Training TSM Loss')
ax.set_xlabel('epochs')
# tick_params takes booleans for these switches; the former 'off' / 'on' strings
# are both non-empty and therefore truthy, so 'off' did not hide the tick marks.
ax.tick_params(left=False, labelleft=True, labelbottom=True, bottom=False, pad=.5)
fig.savefig('./Triplet_loss_during_training.svg', bbox_inches='tight', dpi=400) 
fig.savefig('./Triplet_loss_during_training.pdf', bbox_inches='tight', dpi=400) 

In [None]:
# Training regression (MAE) loss per epoch for both loss variants.
# Relies on colors, n1 and n2 from the earlier plotting cells.
fig, ax = plt.subplots(figsize=(9, 6))

# The history frames have no 'train_mae_loss' column; Test_performance stores
# the regression part of the training loss as 'train_reg_loss'.
y = 'train_reg_loss'

# Columns are ordered (n2, n1). rolling(1) is a one-epoch window, i.e. no smoothing.
dfp = df2.groupby('Epoch').mean()[y].to_frame(name = n2).join(df1.groupby('Epoch').mean()[y].to_frame(name = n1)).rolling(1).mean()
dfp_std = df2.groupby('Epoch').std()[y].to_frame(name = n2).join(df1.groupby('Epoch').std()[y].to_frame(name = n1)).rolling(1).mean()

dfp.plot(lw = 2.5, ax=ax,color = colors, alpha =1)
ax.fill_between(dfp.index, (dfp - dfp_std)[n1], (dfp + dfp_std)[n1], color=colors[1], alpha=0.3)
ax.fill_between(dfp.index, (dfp - dfp_std)[n2], (dfp + dfp_std)[n2], color=colors[0], alpha=0.3)

ax.set_ylim(0.0, 0.8)
ax.spines[['right', 'top']].set_visible(False)

ax.set_ylabel('Training MAE loss')
ax.set_xlabel('epochs')
ax.legend(loc='center', bbox_to_anchor=(0.55, 0.5))

ax.set_xlim(1,800)

# tick_params takes booleans for these switches; the former 'off' / 'on' strings
# are both non-empty and therefore truthy, so 'off' did not hide the tick marks.
ax.tick_params(left=False, labelleft=True, labelbottom=True, bottom=False, pad=.5)
fig.savefig('./Train_mae_los.svg', bbox_inches='tight', dpi=400) 
fig.savefig('./Train_mae_los.pdf', bbox_inches='tight', dpi=400) 