In [1]:
from math import sqrt
import pandas as pd
import numpy as np
import os
import torch
import torch.nn.functional as F
from rdkit import Chem
import matplotlib.pyplot as plt

from torch_geometric.loader import DataLoader
from torch_geometric.datasets import MoleculeNet
from torch_geometric.nn import global_mean_pool, global_max_pool
%matplotlib inline

In [2]:
gpuid = 0
torch.cuda.set_device(gpuid)
print(torch.cuda.current_device())

0


In [3]:
import sys
sys.path.insert(0, '/home/was966/Research/bidd-clsar/')

In [4]:
from clsar.dataset import LSSNS, HSSMS
from clsar.feature import Gen39AtomFeatures
from clsar.model.model import ACANet_PNA, get_deg, _fix_reproducibility,ACANet_GAT, ACANet_GCN, ACANet_GIN
from clsar.model.loss import ACALoss, get_best_cliff
_fix_reproducibility(42)

In [5]:
ACANET_MODEL = ACANet_GCN

In [6]:
def train(train_loader, model, optimizer, aca_loss):

    total_examples = 0
    total_loss =  0    
    total_tsm_loss = 0
    total_reg_loss = 0   
    n_triplets = []
    n_pos_triplets = []
    model.train()
    for i, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        predictions, embeddings = model(data.x.float(), data.edge_index, 
                                        data.edge_attr, data.batch)
        
        loss_out = aca_loss(labels = data.y, 
                            predictions = predictions,
                            embeddings = embeddings)
        loss, reg_loss, tsm_loss, n, n_pos = loss_out
        
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs
        total_tsm_loss += float(tsm_loss) * data.num_graphs        
        total_reg_loss += float(reg_loss) * data.num_graphs        
        total_examples += data.num_graphs
        
        n_triplets.append(int(n))
        n_pos_triplets.append(int(n_pos))
        
    train_loss = total_loss / total_examples
    total_tsm_loss = total_tsm_loss / total_examples
    total_reg_loss = total_reg_loss / total_examples
    n_triplets = int(sum(n_triplets) / (i+1))
    n_pos_triplets = int(sum(n_pos_triplets) / (i+1))
    
    return train_loss, total_tsm_loss, total_reg_loss, n_triplets, n_pos_triplets

@torch.no_grad()
def test(test_loader, model, aca_loss):
    model.eval()
    total_examples = 0
    total_loss = 0
    total_tsm_loss = 0
    total_reg_loss = 0
    n_triplets = []
    n_pos_triplets = []
    mse = []
    for i, data in enumerate(test_loader):
        data = data.to(device)
        predictions, embeddings = model(data.x.float(), data.edge_index,
                                        data.edge_attr, data.batch)
        loss_out = aca_loss(labels=data.y,
                            predictions=predictions,
                            embeddings=embeddings)

        loss, reg_loss, tsm_loss, n, n_pos = loss_out

        total_loss += float(loss) * data.num_graphs
        total_tsm_loss += float(tsm_loss) * data.num_graphs
        total_reg_loss += float(reg_loss) * data.num_graphs
        total_examples += data.num_graphs

        n_triplets.append(int(n))
        n_pos_triplets.append(int(n_pos))

        mse.append(F.mse_loss(predictions, data.y, reduction='none').cpu())

    test_loss = total_loss / total_examples
    total_tsm_loss = total_tsm_loss / total_examples
    total_reg_loss = total_reg_loss / total_examples
    n_triplets = int(sum(n_triplets) / (i+1))
    n_pos_triplets = int(sum(n_pos_triplets) / (i+1))
    
    test_rmse = float(torch.cat(mse, dim=0).mean().sqrt())
    
    return test_loss, total_tsm_loss, total_reg_loss, n_triplets, n_pos_triplets, test_rmse



def Test_performance(alpha=1.0):
    
    model = ACANET_MODEL(**pub_args, aggr = 'mean').to(device)  
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=10**-5)
    aca_loss = ACALoss(alpha=alpha, cliff_lower = 1., cliff_upper = 1.)
    
    history = []
    #ls_his = []
    for epoch in range(1, epochs):
        train_loss, tsm_loss, reg_loss, n_triplets, n_pos_triplets = train(train_loader, model, optimizer, aca_loss)

        _, _, _, _, train_n_pos_triplets, train_rmse = test(train_loader, model, aca_loss)
        _, _, _, _, val_n_pos_triplets, val_rmse = test(val_loader, model, aca_loss)
        _, _, _, _, test_n_pos_triplets, test_rmse = test(test_loader, model, aca_loss)

        
        print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f} tsm_loss: {tsm_loss:.4f} reg_loss: {reg_loss:.4f} '
              f'n_pos_triplets: {n_pos_triplets:03d};  Val: {val_rmse:.4f} Test: {test_rmse:.4f}')
    
        history.append({'Epoch':epoch, 'train_loss':train_loss, 'train_triplet_loss':tsm_loss,'train_mae_loss':reg_loss,
                        'val_rmse':val_rmse, 'test_rmse':test_rmse, 'train_rmse':train_rmse,'n_triplets': n_triplets, 
                        'n_pos_triplets':n_pos_triplets, 
                        'train_n_pos_triplets':train_n_pos_triplets,
                        'val_n_pos_triplets':val_n_pos_triplets,
                        'test_n_pos_triplets':test_n_pos_triplets,
                       
                       })
        #ls_his.append({'Epoch':epoch, 'mae_loss':float(mae_loss), 'triplet_loss':float(triplet_loss)})
    dfh = pd.DataFrame(history)
    return dfh

In [7]:
dataset_name = 'CHEMBL3979_EC50'
Dataset =  HSSMS #LSSNS 
epochs = 800
batch_size = 128
lr = 1e-4

pre_transform = Gen39AtomFeatures()
in_channels = pre_transform.in_channels
path = '../data/'

## model HPs
pub_args = {'in_channels':pre_transform.in_channels, 
            'edge_dim':pre_transform.edge_dim,
            'convs_layers': [64, 128, 256, 512],   
            'dense_layers': [256, 128, 32], 
            'out_channels':1, 
            # 'aggregators': ['mean', 'min', 'max', 'sum','std'],
            # 'scalers':['identity', 'amplification', 'attenuation'] ,
            'dropout_p': 0}

In [8]:
len(Dataset(path, name=dataset_name, pre_transform=pre_transform).shuffle(42))

1125

In [None]:
# train, valid, test splitting
res1 = []
res2 = []
for seed in [8, 16, 24, 42, 64, 128, 256, 512, 1024, 2048]:
    dataset = Dataset(path, name=dataset_name, pre_transform=pre_transform).shuffle(42)
    N = len(dataset) // 5
    val_dataset = dataset[:N]
    test_dataset = dataset[N:2 * N]
    train_dataset = dataset[2 * N:]
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    deg = get_deg(train_dataset)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # With AC-Awareness ($\alpha = 1$)
    df1 = Test_performance(alpha=1.0)
    df1['seed'] = seed
    # Without AC-Awareness ($\alpha = 0$)
    df2 = Test_performance(alpha=0.0)
    df2['seed'] = seed
    res1.append(df1)
    res2.append(df2)



Epoch: 001, Loss: 11.2357 tsm_loss: 4.6409 reg_loss: 6.5948 n_pos_triplets: 226402;  Val: 6.0852 Test: 6.1176
Epoch: 002, Loss: 9.8413 tsm_loss: 3.9363 reg_loss: 5.9050 n_pos_triplets: 228764;  Val: 5.7904 Test: 5.8164
Epoch: 003, Loss: 8.5868 tsm_loss: 3.5585 reg_loss: 5.0283 n_pos_triplets: 228152;  Val: 5.1441 Test: 5.1747
Epoch: 004, Loss: 7.0034 tsm_loss: 3.1436 reg_loss: 3.8598 n_pos_triplets: 225069;  Val: 4.0747 Test: 4.1084
Epoch: 005, Loss: 5.1450 tsm_loss: 2.7784 reg_loss: 2.3667 n_pos_triplets: 225498;  Val: 2.4978 Test: 2.5298
Epoch: 006, Loss: 3.7488 tsm_loss: 2.6598 reg_loss: 1.0891 n_pos_triplets: 226579;  Val: 1.1534 Test: 1.1106
Epoch: 007, Loss: 3.7224 tsm_loss: 2.4534 reg_loss: 1.2690 n_pos_triplets: 221564;  Val: 1.0950 Test: 1.0300
Epoch: 008, Loss: 3.2741 tsm_loss: 2.2764 reg_loss: 0.9977 n_pos_triplets: 219174;  Val: 1.1472 Test: 1.1154
Epoch: 009, Loss: 3.0250 tsm_loss: 2.1026 reg_loss: 0.9224 n_pos_triplets: 217027;  Val: 1.2253 Test: 1.2039
Epoch: 010, Loss: 



Epoch: 001, Loss: 6.4871 tsm_loss: 5.0697 reg_loss: 6.4871 n_pos_triplets: 224117;  Val: 6.2231 Test: 6.2452
Epoch: 002, Loss: 5.8789 tsm_loss: 5.0111 reg_loss: 5.8789 n_pos_triplets: 223522;  Val: 5.9750 Test: 5.9953
Epoch: 003, Loss: 5.1107 tsm_loss: 4.8539 reg_loss: 5.1107 n_pos_triplets: 220625;  Val: 5.4398 Test: 5.4583
Epoch: 004, Loss: 4.0634 tsm_loss: 4.8849 reg_loss: 4.0634 n_pos_triplets: 222333;  Val: 4.4835 Test: 4.4960
Epoch: 005, Loss: 2.6064 tsm_loss: 4.9423 reg_loss: 2.6064 n_pos_triplets: 220328;  Val: 2.9674 Test: 2.9635
Epoch: 006, Loss: 1.1903 tsm_loss: 4.9643 reg_loss: 1.1903 n_pos_triplets: 217732;  Val: 1.2948 Test: 1.2200
Epoch: 007, Loss: 1.2456 tsm_loss: 4.7900 reg_loss: 1.2456 n_pos_triplets: 215482;  Val: 1.1043 Test: 1.0045
Epoch: 008, Loss: 1.0260 tsm_loss: 4.6517 reg_loss: 1.0260 n_pos_triplets: 214234;  Val: 1.2009 Test: 1.1330
Epoch: 009, Loss: 0.8835 tsm_loss: 4.4741 reg_loss: 0.8835 n_pos_triplets: 214189;  Val: 1.3029 Test: 1.2520
Epoch: 010, Loss: 0



Epoch: 001, Loss: 13.0537 tsm_loss: 6.4247 reg_loss: 6.6290 n_pos_triplets: 226735;  Val: 6.4495 Test: 6.2258
Epoch: 002, Loss: 11.5953 tsm_loss: 5.5160 reg_loss: 6.0793 n_pos_triplets: 227454;  Val: 6.1616 Test: 5.9410
Epoch: 003, Loss: 10.0580 tsm_loss: 4.7493 reg_loss: 5.3087 n_pos_triplets: 226071;  Val: 5.5964 Test: 5.3789
Epoch: 004, Loss: 8.3763 tsm_loss: 4.1311 reg_loss: 4.2452 n_pos_triplets: 225239;  Val: 4.6386 Test: 4.4243
Epoch: 005, Loss: 6.4182 tsm_loss: 3.6393 reg_loss: 2.7789 n_pos_triplets: 223663;  Val: 3.1563 Test: 2.9541
Epoch: 006, Loss: 4.4983 tsm_loss: 3.2457 reg_loss: 1.2526 n_pos_triplets: 226112;  Val: 1.4482 Test: 1.3381
Epoch: 007, Loss: 4.1029 tsm_loss: 2.9244 reg_loss: 1.1785 n_pos_triplets: 224962;  Val: 1.0867 Test: 1.1147
Epoch: 008, Loss: 3.7383 tsm_loss: 2.6629 reg_loss: 1.0754 n_pos_triplets: 226733;  Val: 1.2325 Test: 1.1768
Epoch: 009, Loss: 3.4158 tsm_loss: 2.4884 reg_loss: 0.9274 n_pos_triplets: 225636;  Val: 1.4445 Test: 1.3502
Epoch: 010, Loss



Epoch: 001, Loss: 6.7299 tsm_loss: 6.1022 reg_loss: 6.7299 n_pos_triplets: 225257;  Val: 6.6848 Test: 6.4631
Epoch: 002, Loss: 6.3736 tsm_loss: 5.9957 reg_loss: 6.3736 n_pos_triplets: 224334;  Val: 6.5269 Test: 6.3037
Epoch: 003, Loss: 5.8845 tsm_loss: 6.0797 reg_loss: 5.8845 n_pos_triplets: 223315;  Val: 6.1746 Test: 5.9526
Epoch: 004, Loss: 5.2090 tsm_loss: 5.9052 reg_loss: 5.2090 n_pos_triplets: 224959;  Val: 5.5483 Test: 5.3267
Epoch: 005, Loss: 4.2157 tsm_loss: 5.6136 reg_loss: 4.2157 n_pos_triplets: 221178;  Val: 4.4943 Test: 4.2768
Epoch: 006, Loss: 2.8400 tsm_loss: 5.4073 reg_loss: 2.8400 n_pos_triplets: 222270;  Val: 2.9157 Test: 2.7135
Epoch: 007, Loss: 1.2951 tsm_loss: 5.2485 reg_loss: 1.2951 n_pos_triplets: 223573;  Val: 1.2495 Test: 1.1853
Epoch: 008, Loss: 1.2438 tsm_loss: 4.5936 reg_loss: 1.2438 n_pos_triplets: 219007;  Val: 1.2715 Test: 1.3407
Epoch: 009, Loss: 1.1121 tsm_loss: 4.3056 reg_loss: 1.1121 n_pos_triplets: 219666;  Val: 1.1352 Test: 1.1193
Epoch: 010, Loss: 0



Epoch: 001, Loss: 11.3449 tsm_loss: 4.7679 reg_loss: 6.5770 n_pos_triplets: 227744;  Val: 6.4497 Test: 6.5925
Epoch: 002, Loss: 10.1266 tsm_loss: 4.0378 reg_loss: 6.0889 n_pos_triplets: 227033;  Val: 6.1274 Test: 6.2667
Epoch: 003, Loss: 9.0382 tsm_loss: 3.5628 reg_loss: 5.4754 n_pos_triplets: 227476;  Val: 5.6304 Test: 5.7616
Epoch: 004, Loss: 7.7850 tsm_loss: 3.1197 reg_loss: 4.6653 n_pos_triplets: 225302;  Val: 4.8470 Test: 4.9704
Epoch: 005, Loss: 6.2985 tsm_loss: 2.7657 reg_loss: 3.5329 n_pos_triplets: 224416;  Val: 3.6793 Test: 3.7898
Epoch: 006, Loss: 4.5402 tsm_loss: 2.5405 reg_loss: 1.9997 n_pos_triplets: 223807;  Val: 2.0455 Test: 2.1320
Epoch: 007, Loss: 3.3075 tsm_loss: 2.3429 reg_loss: 0.9646 n_pos_triplets: 223777;  Val: 1.0920 Test: 1.1115
Epoch: 008, Loss: 3.3424 tsm_loss: 2.1560 reg_loss: 1.1864 n_pos_triplets: 218354;  Val: 1.0829 Test: 1.0909
Epoch: 009, Loss: 2.9415 tsm_loss: 2.0030 reg_loss: 0.9385 n_pos_triplets: 217510;  Val: 1.1229 Test: 1.1635
Epoch: 010, Loss:



Epoch: 001, Loss: 6.5211 tsm_loss: 6.2180 reg_loss: 6.5211 n_pos_triplets: 222571;  Val: 6.3957 Test: 6.5351
Epoch: 002, Loss: 6.1172 tsm_loss: 5.9583 reg_loss: 6.1172 n_pos_triplets: 222383;  Val: 6.2258 Test: 6.3621
Epoch: 003, Loss: 5.6434 tsm_loss: 6.1889 reg_loss: 5.6434 n_pos_triplets: 223019;  Val: 5.9335 Test: 6.0680
Epoch: 004, Loss: 5.0012 tsm_loss: 6.1862 reg_loss: 5.0012 n_pos_triplets: 221027;  Val: 5.4470 Test: 5.5755
Epoch: 005, Loss: 4.1224 tsm_loss: 6.1148 reg_loss: 4.1224 n_pos_triplets: 218425;  Val: 4.7001 Test: 4.8177
Epoch: 006, Loss: 2.9274 tsm_loss: 6.5571 reg_loss: 2.9274 n_pos_triplets: 218777;  Val: 3.5767 Test: 3.6727
Epoch: 007, Loss: 1.5389 tsm_loss: 6.5784 reg_loss: 1.5389 n_pos_triplets: 216998;  Val: 2.0566 Test: 2.1107
Epoch: 008, Loss: 1.1628 tsm_loss: 6.4886 reg_loss: 1.1628 n_pos_triplets: 215301;  Val: 1.2883 Test: 1.3135
Epoch: 009, Loss: 1.1554 tsm_loss: 5.6820 reg_loss: 1.1554 n_pos_triplets: 213938;  Val: 1.2784 Test: 1.2995
Epoch: 010, Loss: 0



Epoch: 001, Loss: 10.8066 tsm_loss: 4.2709 reg_loss: 6.5357 n_pos_triplets: 223803;  Val: 6.4019 Test: 6.4210
Epoch: 002, Loss: 9.7426 tsm_loss: 3.6298 reg_loss: 6.1128 n_pos_triplets: 225877;  Val: 6.2101 Test: 6.2217
Epoch: 003, Loss: 8.7215 tsm_loss: 3.1876 reg_loss: 5.5339 n_pos_triplets: 223298;  Val: 5.7907 Test: 5.7968
Epoch: 004, Loss: 7.6398 tsm_loss: 2.8343 reg_loss: 4.8055 n_pos_triplets: 222247;  Val: 5.1536 Test: 5.1508
Epoch: 005, Loss: 6.4926 tsm_loss: 2.6229 reg_loss: 3.8697 n_pos_triplets: 220560;  Val: 4.2222 Test: 4.2074
Epoch: 006, Loss: 5.0135 tsm_loss: 2.3745 reg_loss: 2.6390 n_pos_triplets: 222937;  Val: 2.9373 Test: 2.9106
Epoch: 007, Loss: 3.4269 tsm_loss: 2.1765 reg_loss: 1.2503 n_pos_triplets: 221807;  Val: 1.4976 Test: 1.4961
Epoch: 008, Loss: 3.0426 tsm_loss: 2.0487 reg_loss: 0.9939 n_pos_triplets: 217091;  Val: 1.0921 Test: 1.1721
Epoch: 009, Loss: 3.0066 tsm_loss: 1.9101 reg_loss: 1.0965 n_pos_triplets: 211275;  Val: 1.0606 Test: 1.1325
Epoch: 010, Loss: 



Epoch: 001, Loss: 6.5065 tsm_loss: 5.4840 reg_loss: 6.5065 n_pos_triplets: 222161;  Val: 6.2951 Test: 6.3097
Epoch: 002, Loss: 6.0028 tsm_loss: 5.4014 reg_loss: 6.0028 n_pos_triplets: 221834;  Val: 6.0981 Test: 6.1078
Epoch: 003, Loss: 5.3294 tsm_loss: 5.1607 reg_loss: 5.3294 n_pos_triplets: 220031;  Val: 5.5939 Test: 5.5994
Epoch: 004, Loss: 4.3565 tsm_loss: 4.9750 reg_loss: 4.3565 n_pos_triplets: 221590;  Val: 4.6650 Test: 4.6637
Epoch: 005, Loss: 2.9182 tsm_loss: 4.7674 reg_loss: 2.9182 n_pos_triplets: 215888;  Val: 3.0871 Test: 3.0890
Epoch: 006, Loss: 1.3260 tsm_loss: 4.5429 reg_loss: 1.3260 n_pos_triplets: 216959;  Val: 1.2890 Test: 1.3152
Epoch: 007, Loss: 1.2231 tsm_loss: 4.3013 reg_loss: 1.2231 n_pos_triplets: 218480;  Val: 1.1941 Test: 1.2060
Epoch: 008, Loss: 1.0440 tsm_loss: 4.1478 reg_loss: 1.0440 n_pos_triplets: 214550;  Val: 1.1830 Test: 1.2045
Epoch: 009, Loss: 0.8620 tsm_loss: 4.0070 reg_loss: 0.8620 n_pos_triplets: 213377;  Val: 1.2511 Test: 1.2747
Epoch: 010, Loss: 0



Epoch: 001, Loss: 11.6701 tsm_loss: 4.9172 reg_loss: 6.7529 n_pos_triplets: 218558;  Val: 6.4721 Test: 6.5652
Epoch: 002, Loss: 10.4880 tsm_loss: 4.1552 reg_loss: 6.3328 n_pos_triplets: 218653;  Val: 6.2686 Test: 6.3629
Epoch: 003, Loss: 9.3240 tsm_loss: 3.5783 reg_loss: 5.7458 n_pos_triplets: 221583;  Val: 5.8879 Test: 5.9841
Epoch: 004, Loss: 8.2046 tsm_loss: 3.1776 reg_loss: 5.0270 n_pos_triplets: 221398;  Val: 5.2672 Test: 5.3630
Epoch: 005, Loss: 6.8748 tsm_loss: 2.7900 reg_loss: 4.0848 n_pos_triplets: 221500;  Val: 4.2892 Test: 4.3842
Epoch: 006, Loss: 5.5478 tsm_loss: 2.7402 reg_loss: 2.8076 n_pos_triplets: 220008;  Val: 2.8726 Test: 2.9633
Epoch: 007, Loss: 3.8227 tsm_loss: 2.4555 reg_loss: 1.3673 n_pos_triplets: 219725;  Val: 1.3744 Test: 1.4094
Epoch: 008, Loss: 3.3270 tsm_loss: 2.2683 reg_loss: 1.0587 n_pos_triplets: 216696;  Val: 1.2710 Test: 1.1920
Epoch: 009, Loss: 3.3507 tsm_loss: 2.1537 reg_loss: 1.1969 n_pos_triplets: 217157;  Val: 1.1759 Test: 1.1067
Epoch: 010, Loss:



Epoch: 001, Loss: 6.6815 tsm_loss: 5.5220 reg_loss: 6.6815 n_pos_triplets: 219925;  Val: 6.4554 Test: 6.5543
Epoch: 002, Loss: 6.2158 tsm_loss: 5.8698 reg_loss: 6.2158 n_pos_triplets: 220136;  Val: 6.2454 Test: 6.3429
Epoch: 003, Loss: 5.5631 tsm_loss: 5.5205 reg_loss: 5.5631 n_pos_triplets: 220496;  Val: 5.7735 Test: 5.8746
Epoch: 004, Loss: 4.6123 tsm_loss: 5.4887 reg_loss: 4.6123 n_pos_triplets: 222951;  Val: 4.9380 Test: 5.0477
Epoch: 005, Loss: 3.2695 tsm_loss: 5.2852 reg_loss: 3.2695 n_pos_triplets: 220869;  Val: 3.5320 Test: 3.6541
Epoch: 006, Loss: 1.6152 tsm_loss: 5.0541 reg_loss: 1.6152 n_pos_triplets: 221174;  Val: 1.6403 Test: 1.7510
Epoch: 007, Loss: 1.2148 tsm_loss: 4.7606 reg_loss: 1.2148 n_pos_triplets: 224152;  Val: 1.2281 Test: 1.2035
Epoch: 008, Loss: 1.2566 tsm_loss: 4.6186 reg_loss: 1.2566 n_pos_triplets: 222155;  Val: 1.1494 Test: 1.1706
Epoch: 009, Loss: 0.9367 tsm_loss: 4.2474 reg_loss: 0.9367 n_pos_triplets: 219674;  Val: 1.2892 Test: 1.3491
Epoch: 010, Loss: 0



Epoch: 001, Loss: 10.9324 tsm_loss: 4.1944 reg_loss: 6.7380 n_pos_triplets: 220082;  Val: 6.3478 Test: 6.4965
Epoch: 002, Loss: 9.7365 tsm_loss: 3.5660 reg_loss: 6.1706 n_pos_triplets: 221843;  Val: 6.1573 Test: 6.3124
Epoch: 003, Loss: 8.7366 tsm_loss: 3.1732 reg_loss: 5.5634 n_pos_triplets: 221184;  Val: 5.6866 Test: 5.8524
Epoch: 004, Loss: 7.5919 tsm_loss: 2.8245 reg_loss: 4.7674 n_pos_triplets: 222795;  Val: 4.9244 Test: 5.1021
Epoch: 005, Loss: 6.3043 tsm_loss: 2.5808 reg_loss: 3.7235 n_pos_triplets: 221756;  Val: 3.8018 Test: 3.9905
Epoch: 006, Loss: 4.7883 tsm_loss: 2.4267 reg_loss: 2.3616 n_pos_triplets: 220627;  Val: 2.2991 Test: 2.4812
Epoch: 007, Loss: 3.3013 tsm_loss: 2.2309 reg_loss: 1.0705 n_pos_triplets: 220393;  Val: 1.1257 Test: 1.1617
Epoch: 008, Loss: 3.1804 tsm_loss: 2.0979 reg_loss: 1.0825 n_pos_triplets: 218077;  Val: 1.2022 Test: 1.1127
Epoch: 009, Loss: 2.9505 tsm_loss: 1.9521 reg_loss: 0.9984 n_pos_triplets: 214679;  Val: 1.0685 Test: 1.0620
Epoch: 010, Loss: 



Epoch: 001, Loss: 6.3522 tsm_loss: 5.0815 reg_loss: 6.3522 n_pos_triplets: 225148;  Val: 6.0162 Test: 6.1584
Epoch: 002, Loss: 5.6259 tsm_loss: 5.0598 reg_loss: 5.6259 n_pos_triplets: 223923;  Val: 5.6640 Test: 5.8193
Epoch: 003, Loss: 4.7087 tsm_loss: 4.9692 reg_loss: 4.7087 n_pos_triplets: 223178;  Val: 4.9938 Test: 5.1579
Epoch: 004, Loss: 3.4668 tsm_loss: 5.0601 reg_loss: 3.4668 n_pos_triplets: 225770;  Val: 3.8728 Test: 4.0478
Epoch: 005, Loss: 1.8446 tsm_loss: 4.9102 reg_loss: 1.8446 n_pos_triplets: 221437;  Val: 2.2155 Test: 2.3977
Epoch: 006, Loss: 1.1099 tsm_loss: 4.7386 reg_loss: 1.1099 n_pos_triplets: 221953;  Val: 1.1158 Test: 1.2224
Epoch: 007, Loss: 1.2914 tsm_loss: 4.5762 reg_loss: 1.2914 n_pos_triplets: 219592;  Val: 1.0809 Test: 1.1963
Epoch: 008, Loss: 0.9516 tsm_loss: 4.3891 reg_loss: 0.9516 n_pos_triplets: 218204;  Val: 1.2812 Test: 1.4313
Epoch: 009, Loss: 0.8869 tsm_loss: 4.2615 reg_loss: 0.8869 n_pos_triplets: 216726;  Val: 1.1808 Test: 1.3186
Epoch: 010, Loss: 0



Epoch: 001, Loss: 12.0440 tsm_loss: 5.1990 reg_loss: 6.8451 n_pos_triplets: 224361;  Val: 6.9290 Test: 6.8613
Epoch: 002, Loss: 10.9140 tsm_loss: 4.3814 reg_loss: 6.5326 n_pos_triplets: 223768;  Val: 6.5912 Test: 6.5316
Epoch: 003, Loss: 9.9126 tsm_loss: 3.7303 reg_loss: 6.1823 n_pos_triplets: 226252;  Val: 6.3332 Test: 6.2776
Epoch: 004, Loss: 9.0127 tsm_loss: 3.2965 reg_loss: 5.7162 n_pos_triplets: 224650;  Val: 5.8932 Test: 5.8445
Epoch: 005, Loss: 8.0118 tsm_loss: 2.9745 reg_loss: 5.0373 n_pos_triplets: 225555;  Val: 5.1931 Test: 5.1533
Epoch: 006, Loss: 6.7681 tsm_loss: 2.6763 reg_loss: 4.0918 n_pos_triplets: 224865;  Val: 4.1796 Test: 4.1542
Epoch: 007, Loss: 5.3815 tsm_loss: 2.5466 reg_loss: 2.8349 n_pos_triplets: 226557;  Val: 2.8022 Test: 2.7998
Epoch: 008, Loss: 3.7630 tsm_loss: 2.3771 reg_loss: 1.3859 n_pos_triplets: 225628;  Val: 1.3109 Test: 1.3654
Epoch: 009, Loss: 3.4141 tsm_loss: 2.3038 reg_loss: 1.1104 n_pos_triplets: 225648;  Val: 1.1827 Test: 1.2610
Epoch: 010, Loss:



Epoch: 001, Loss: 6.6460 tsm_loss: 5.4120 reg_loss: 6.6460 n_pos_triplets: 223504;  Val: 6.4885 Test: 6.4320
Epoch: 002, Loss: 6.1995 tsm_loss: 5.3944 reg_loss: 6.1995 n_pos_triplets: 225327;  Val: 6.3359 Test: 6.2753
Epoch: 003, Loss: 5.6347 tsm_loss: 5.2997 reg_loss: 5.6347 n_pos_triplets: 223817;  Val: 5.9829 Test: 5.9208
Epoch: 004, Loss: 4.8292 tsm_loss: 5.2004 reg_loss: 4.8292 n_pos_triplets: 222233;  Val: 5.3098 Test: 5.2515
Epoch: 005, Loss: 3.6783 tsm_loss: 5.0349 reg_loss: 3.6783 n_pos_triplets: 223418;  Val: 4.1823 Test: 4.1353
Epoch: 006, Loss: 2.0896 tsm_loss: 4.9521 reg_loss: 2.0896 n_pos_triplets: 224608;  Val: 2.4669 Test: 2.4590
Epoch: 007, Loss: 1.1187 tsm_loss: 4.7121 reg_loss: 1.1187 n_pos_triplets: 222760;  Val: 1.1390 Test: 1.2426
Epoch: 008, Loss: 1.3256 tsm_loss: 4.4598 reg_loss: 1.3256 n_pos_triplets: 222821;  Val: 1.0695 Test: 1.1819
Epoch: 009, Loss: 0.9649 tsm_loss: 4.2240 reg_loss: 0.9649 n_pos_triplets: 222322;  Val: 1.3669 Test: 1.4242
Epoch: 010, Loss: 0



Epoch: 001, Loss: 11.7051 tsm_loss: 5.1490 reg_loss: 6.5561 n_pos_triplets: 225006;  Val: 6.5488 Test: 6.4676
Epoch: 002, Loss: 10.3671 tsm_loss: 4.2196 reg_loss: 6.1475 n_pos_triplets: 225756;  Val: 6.3210 Test: 6.2432
Epoch: 003, Loss: 9.3097 tsm_loss: 3.6887 reg_loss: 5.6209 n_pos_triplets: 226299;  Val: 5.9079 Test: 5.8295
Epoch: 004, Loss: 8.3424 tsm_loss: 3.4807 reg_loss: 4.8617 n_pos_triplets: 227642;  Val: 5.1737 Test: 5.0991
Epoch: 005, Loss: 6.7464 tsm_loss: 2.9899 reg_loss: 3.7566 n_pos_triplets: 226288;  Val: 4.0416 Test: 3.9677
Epoch: 006, Loss: 5.1546 tsm_loss: 2.8368 reg_loss: 2.3178 n_pos_triplets: 225422;  Val: 2.4837 Test: 2.4054
Epoch: 007, Loss: 3.7845 tsm_loss: 2.6702 reg_loss: 1.1144 n_pos_triplets: 225769;  Val: 1.2153 Test: 1.1502
Epoch: 008, Loss: 3.6699 tsm_loss: 2.4430 reg_loss: 1.2269 n_pos_triplets: 223065;  Val: 1.1555 Test: 1.1161
Epoch: 009, Loss: 3.3084 tsm_loss: 2.2486 reg_loss: 1.0598 n_pos_triplets: 223345;  Val: 1.1684 Test: 1.1147
Epoch: 010, Loss:



Epoch: 001, Loss: 6.5766 tsm_loss: 5.3869 reg_loss: 6.5766 n_pos_triplets: 222367;  Val: 6.5144 Test: 6.4366
Epoch: 002, Loss: 6.1319 tsm_loss: 5.2395 reg_loss: 6.1319 n_pos_triplets: 222609;  Val: 6.3875 Test: 6.3082
Epoch: 003, Loss: 5.5001 tsm_loss: 5.0614 reg_loss: 5.5001 n_pos_triplets: 219329;  Val: 5.9100 Test: 5.8301
Epoch: 004, Loss: 4.5696 tsm_loss: 4.9842 reg_loss: 4.5696 n_pos_triplets: 220294;  Val: 5.0461 Test: 4.9642
Epoch: 005, Loss: 3.2513 tsm_loss: 4.9388 reg_loss: 3.2513 n_pos_triplets: 217367;  Val: 3.6554 Test: 3.5741
Epoch: 006, Loss: 1.6029 tsm_loss: 4.8084 reg_loss: 1.6029 n_pos_triplets: 218035;  Val: 1.7798 Test: 1.6963
Epoch: 007, Loss: 1.1767 tsm_loss: 4.7411 reg_loss: 1.1767 n_pos_triplets: 214382;  Val: 1.1834 Test: 1.1436
Epoch: 008, Loss: 1.2421 tsm_loss: 4.6569 reg_loss: 1.2421 n_pos_triplets: 213998;  Val: 1.1389 Test: 1.1023
Epoch: 009, Loss: 0.9214 tsm_loss: 4.4734 reg_loss: 0.9214 n_pos_triplets: 212846;  Val: 1.3619 Test: 1.3289
Epoch: 010, Loss: 0



Epoch: 001, Loss: 12.8495 tsm_loss: 6.3775 reg_loss: 6.4721 n_pos_triplets: 224026;  Val: 6.2729 Test: 6.2118
Epoch: 002, Loss: 11.5675 tsm_loss: 5.5044 reg_loss: 6.0630 n_pos_triplets: 224220;  Val: 6.0586 Test: 6.0058
Epoch: 003, Loss: 10.1684 tsm_loss: 4.6325 reg_loss: 5.5359 n_pos_triplets: 221222;  Val: 5.6473 Test: 5.6058
Epoch: 004, Loss: 8.8065 tsm_loss: 3.9978 reg_loss: 4.8087 n_pos_triplets: 221443;  Val: 4.9658 Test: 4.9355
Epoch: 005, Loss: 7.4154 tsm_loss: 3.5926 reg_loss: 3.8229 n_pos_triplets: 218217;  Val: 3.9417 Test: 3.9289
Epoch: 006, Loss: 5.6833 tsm_loss: 3.1567 reg_loss: 2.5266 n_pos_triplets: 216883;  Val: 2.5626 Test: 2.5741
Epoch: 007, Loss: 4.2184 tsm_loss: 2.9771 reg_loss: 1.2413 n_pos_triplets: 215826;  Val: 1.2826 Test: 1.3291
Epoch: 008, Loss: 3.8739 tsm_loss: 2.7278 reg_loss: 1.1461 n_pos_triplets: 215707;  Val: 1.1870 Test: 1.2217
Epoch: 009, Loss: 3.6153 tsm_loss: 2.4863 reg_loss: 1.1290 n_pos_triplets: 212766;  Val: 1.1435 Test: 1.1659
Epoch: 010, Loss

In [None]:
df1 = pd.concat(res1)
df2 = pd.concat(res2)

In [None]:
df1.to_csv('./with_aca.csv')
df2.to_csv('./without_aca.csv')

In [None]:
fig, ax = plt.subplots(figsize=(9, 6))
colors = ['#FFE699','#00B0F0']


y = 'val_rmse'

n1 = r'With AC-awareness ($\mathcal{L}_{mae} + \mathcal{L}_{tsm}$)'
n2 = r'Without AC-awareness ($\mathcal{L}_{mae}$)'


dfp = df2.groupby('Epoch').mean()[y].to_frame(name = n2).join(df1.groupby('Epoch').mean()[y].to_frame(name = n1)).rolling(1).mean()
dfp_std = df2.groupby('Epoch').std()[y].to_frame(name = n2).join(df1.groupby('Epoch').std()[y].to_frame(name = n1)).rolling(1).mean()

dfp.plot(lw = 2, ax=ax,color = colors, alpha =1)
ax.fill_between(dfp.index, (dfp - dfp_std)[n1], (dfp + dfp_std)[n1], color=colors[1], alpha=0.2)
ax.fill_between(dfp.index, (dfp - dfp_std)[n2], (dfp + dfp_std)[n2], color=colors[0], alpha=0.2)

ax.set_ylim(0.60, 1.0)
ax.set_ylabel('Validation RMSE')
ax.set_xlabel('epochs')
ax.spines[['right', 'top']].set_visible(False)

ax.set_xlim(1,800)

ax.tick_params(left='off', labelleft='on', labelbottom='on', bottom = 'off',  pad=.5,)
fig.savefig('./Validation_RMSE.svg', bbox_inches='tight', dpi=400) 
fig.savefig('./Validation_RMSE.pdf', bbox_inches='tight', dpi=400) 

In [None]:
fig, ax = plt.subplots(figsize=(9, 6))
colors = ['#FFE699','#00B0F0']


y = 'test_rmse'



dfp = df2.groupby('Epoch').mean()[y].to_frame(name = n2).join(df1.groupby('Epoch').mean()[y].to_frame(name = n1)).rolling(1).mean()
dfp_std = df2.groupby('Epoch').std()[y].to_frame(name = n2).join(df1.groupby('Epoch').std()[y].to_frame(name = n1)).rolling(1).mean()

dfp.plot(lw = 2, ax=ax,color = colors, alpha =1)
ax.fill_between(dfp.index, (dfp - dfp_std)[n1], (dfp + dfp_std)[n1], color=colors[1], alpha=0.2)
ax.fill_between(dfp.index, (dfp - dfp_std)[n2], (dfp + dfp_std)[n2], color=colors[0], alpha=0.2)

ax.set_ylim(0.60, 1.0)
ax.set_ylabel('Test RMSE')
ax.set_xlabel('epochs')
ax.spines[['right', 'top']].set_visible(False)

ax.set_xlim(1,800)

ax.tick_params(left='off', labelleft='on', labelbottom='on', bottom = 'off',  pad=.5,)
fig.savefig('./Test_RMSE.svg' , bbox_inches='tight', dpi=400) 
fig.savefig('./Test_RMSE.pdf' , bbox_inches='tight', dpi=400) 

In [None]:
fig, ax = plt.subplots(figsize=(9, 6))


y = 'n_pos_triplets'

dfp = df2.groupby('Epoch').mean()[y].to_frame(name = n2).join(df1.groupby('Epoch').mean()[y].to_frame(name = n1)).rolling(1).mean()
dfp_std = df2.groupby('Epoch').std()[y].to_frame(name = n2).join(df1.groupby('Epoch').std()[y].to_frame(name = n1)).rolling(1).mean()

dfp.plot(lw = 3, ax=ax,color = colors, alpha =1)
ax.fill_between(dfp.index, (dfp - dfp_std)[n1], (dfp + dfp_std)[n1], color=colors[1], alpha=0.3)
ax.fill_between(dfp.index, (dfp - dfp_std)[n2], (dfp + dfp_std)[n2], color=colors[0], alpha=0.3)

ax.legend(loc='center', bbox_to_anchor=(0.55, 0.5))

ax.spines[['right', 'top']].set_visible(False)
plt.ticklabel_format(axis='y', style='sci', scilimits=(0,0))
ax.set_ylabel("No. of HV-ACTs ($M^'$)")
ax.set_xlabel('epochs')
ax.tick_params(left='off', labelleft='on', labelbottom='on', bottom = 'off',  pad=.5,)
ax.set_xlim(-5,800)


fig.savefig('./Number_of_mined_ACTs_during_training.svg' , bbox_inches='tight', dpi=400) 
fig.savefig('./Number_of_mined_ACTs_during_training.pdf' , bbox_inches='tight', dpi=400) 

In [None]:
fig, ax = plt.subplots(figsize=(9, 6))
y = 'train_triplet_loss'
dfp = df2.groupby('Epoch').mean()[y].to_frame(name = n2).join(df1.groupby('Epoch').mean()[y].to_frame(name = n1)).rolling(1).mean()
dfp_std = df2.groupby('Epoch').std()[y].to_frame(name = n2).join(df1.groupby('Epoch').std()[y].to_frame(name = n1)).rolling(1).mean()

dfp.plot(lw = 3, ax=ax,color = colors, alpha =1)
ax.fill_between(dfp.index, (dfp - dfp_std)[n1], (dfp + dfp_std)[n1], color=colors[1], alpha=0.3)
ax.fill_between(dfp.index, (dfp - dfp_std)[n2], (dfp + dfp_std)[n2], color=colors[0], alpha=0.3)

ax.spines[['right', 'top']].set_visible(False)
ax.set_xlim(-5,800)
ax.set_ylim(-1,10)

ax.set_ylabel('Training TSM Loss')
ax.set_xlabel('epochs')
ax.tick_params(left='off', labelleft='on', labelbottom='on', bottom = 'off',  pad=.5,)
fig.savefig('./Triplet_loss_during_training.svg', bbox_inches='tight', dpi=400) 
fig.savefig('./Triplet_loss_during_training.pdf', bbox_inches='tight', dpi=400) 

In [None]:
fig, ax = plt.subplots(figsize=(9, 6))

y = 'train_mae_loss'

dfp = df2.groupby('Epoch').mean()[y].to_frame(name = n2).join(df1.groupby('Epoch').mean()[y].to_frame(name = n1)).rolling(1).mean()
dfp_std = df2.groupby('Epoch').std()[y].to_frame(name = n2).join(df1.groupby('Epoch').std()[y].to_frame(name = n1)).rolling(1).mean()

dfp.plot(lw = 2.5, ax=ax,color = colors, alpha =1)
ax.fill_between(dfp.index, (dfp - dfp_std)[n1], (dfp + dfp_std)[n1], color=colors[1], alpha=0.3)
ax.fill_between(dfp.index, (dfp - dfp_std)[n2], (dfp + dfp_std)[n2], color=colors[0], alpha=0.3)



ax.set_ylim(0.0, 0.8)
ax.spines[['right', 'top']].set_visible(False)


ax.set_ylabel('Training MAE loss')
ax.set_xlabel('epochs')
ax.legend(loc='center', bbox_to_anchor=(0.55, 0.5))

ax.set_xlim(1,800)

ax.tick_params(left='off', labelleft='on', labelbottom='on', bottom = 'off',  pad=.5,)
fig.savefig('./Train_mae_los.svg', bbox_inches='tight', dpi=400) 
fig.savefig('./Train_mae_los.pdf', bbox_inches='tight', dpi=400) 