In [1]:
import pandas as pd
from transformers import BertModel, AutoTokenizer, RobertaModel, RobertaTokenizerFast
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, RobustScaler, normalize
import torch, os, random, copy
import numpy as np
import gc
from torch.nn.utils import clip_grad_norm_
from ogb.graph_aug import mask_nodes, mask_edges, permute_edges, drop_nodes, subgraph
from torch import nn
from torch.nn import functional as F
# from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.metrics import mean_squared_error
from matplotlib import pyplot as plt
from torch_ema import ExponentialMovingAverage
from ogb.utils import smiles2graph
from models.dualgraph.mol import smiles2graphwithface, simles2graphwithface_with_mask
from models.dualgraph.gnn import one_hot_atoms, one_hot_bonds, GNN2
from rdkit.Chem import AllChem
from rdkit import Chem
from torch_geometric.data import Dataset, InMemoryDataset
from models.dualgraph.dataset import DGData


from torch_geometric.loader import DataLoader


In [None]:
def ring(data, node_mask, edge_mask):
    ring_index = data.ring_index[:, edge_mask]
    data.ring_index = ring_index

In [76]:
import os
from itertools import repeat
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.utils import subgraph, to_networkx

class MoleculeDataset_graphcl(InMemoryDataset):
    def __init__(self,
                 root='dataset_path',
                 transform=None,
                 pre_transform=None, 
                 df=None):
        self.df = df
        self.aug_prob = None
        self.aug_mode = 'sample'
        self.aug_strength = 0.2
        self.augmentations = [self.node_drop, self.subgraph,
                              self.edge_pert, self.attr_mask, lambda x: x]
        super().__init__(root, transform, pre_transform, df)
        
    @property
    def raw_file_names(self):
        return [f'raw_{i+1}.pt' for i in range(self.df.shape[0])]

    @property
    def processed_file_names(self):
        return [f'data_{i+1}.pt' for i in range(self.df.shape[0])]
    
    def set_augMode(self, aug_mode):
        self.aug_mode = aug_mode

    def set_augStrength(self, aug_strength):
        self.aug_strength = aug_strength

    def set_augProb(self, aug_prob):
        self.aug_prob = aug_prob

    def node_drop(self, data):

        node_num, _ = data.x.size()
        _, edge_num = data.edge_index.size()
        drop_num = int(node_num * self.aug_strength)

        idx_perm = np.random.permutation(node_num)
        idx_nodrop = idx_perm[drop_num:].tolist()
        idx_nodrop.sort()

        edge_idx, edge_attr, edge_mask = subgraph(subset=idx_nodrop,
                                       edge_index=data.edge_index,
                                       edge_attr=data.edge_attr,
                                       relabel_nodes=True,
                                       num_nodes=node_num,
                                       return_edge_mask=True)                        

        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nodrop]
        data.__num_nodes__, _ = data.x.shape
        return data

    def edge_pert(self, data):        
        node_num, _ = data.x.size()
        _, edge_num = data.edge_index.size()
        pert_num = int(edge_num * self.aug_strength)

        # delete edges
        idx_drop = np.random.choice(edge_num, (edge_num - pert_num),
                                    replace=False)
        edge_index = data.edge_index[:, idx_drop]
        edge_attr = data.edge_attr[idx_drop]

        # add edges
        adj = torch.ones((node_num, node_num))
        adj[edge_index[0], edge_index[1]] = 0
        # edge_index_nonexist = adj.nonzero(as_tuple=False).t()
        edge_index_nonexist = torch.nonzero(adj, as_tuple=False).t()
        idx_add = np.random.choice(edge_index_nonexist.shape[1],
                                    pert_num, replace=False)
        edge_index_add = edge_index_nonexist[:, idx_add]
        # random 4-class & 3-class edge_attr for 1st & 2nd dimension
        edge_attr_add_1 = torch.tensor(np.random.randint(
            4, size=(edge_index_add.shape[1], 1)))
        edge_attr_add_2 = torch.tensor(np.random.randint(
            3, size=(edge_index_add.shape[1], 1)))
        edge_attr_add_3 = torch.tensor(np.random.randint(
            2, size=(edge_index_add.shape[1], 1)))
        edge_attr_add = torch.cat((edge_attr_add_1, edge_attr_add_2, edge_attr_add_3), dim=1)
        edge_index = torch.cat((edge_index, edge_index_add), dim=1)
        

        edge_attr = torch.cat((edge_attr, edge_attr_add), dim=0)

        data.edge_index = edge_index
        data.edge_attr = edge_attr
        return data

    def attr_mask(self, data):

        _x = data.x.clone()
        node_num, _ = data.x.size()
        mask_num = int(node_num * self.aug_strength)

        token = data.x.float().mean(dim=0).long()
        idx_mask = np.random.choice(
            node_num, mask_num, replace=False)

        _x[idx_mask] = token
        data.x = _x
        return data

    def subgraph(self, data):

        G = to_networkx(data)
        node_num, _ = data.x.size()
        _, edge_num = data.edge_index.size()
        sub_num = int(node_num * (1 - self.aug_strength))

        idx_sub = [np.random.randint(node_num, size=1)[0]]
        idx_neigh = set([n for n in G.neighbors(idx_sub[-1])])

        while len(idx_sub) <= sub_num:
            if len(idx_neigh) == 0:
                idx_unsub = list(set([n for n in range(node_num)]).difference(set(idx_sub)))
                idx_neigh = set([np.random.choice(idx_unsub)])
            sample_node = np.random.choice(list(idx_neigh))

            idx_sub.append(sample_node)
            idx_neigh = idx_neigh.union(
                set([n for n in G.neighbors(idx_sub[-1])])).difference(set(idx_sub))

        idx_nondrop = idx_sub
        idx_nondrop.sort()

        edge_idx, edge_attr = subgraph(subset=idx_nondrop,
                                       edge_index=data.edge_index,
                                       edge_attr=data.edge_attr,
                                       relabel_nodes=True,
                                       num_nodes=node_num)
        
        data.edge_index = edge_idx
        data.edge_attr = edge_attr
        data.x = data.x[idx_nondrop]
        data.__num_nodes__, _ = data.x.shape
        return data

    def get(self, idx):
        sid = self.sid_list[idx]
        dset = self.dset_list[idx]
        data = torch.load(f'/home/pjh/workspace/SOM/data/pretrain_graph/{dset}/{sid}.pt')
        data1 = copy.deepcopy(data)
        data2 = copy.deepcopy(data)

        if self.aug_mode == 'no_aug':
            n_aug1, n_aug2 = 4, 4
            data1 = self.augmentations[n_aug1](data1)
            data2 = self.augmentations[n_aug2](data2)
        elif self.aug_mode == 'uniform':
            n_aug = np.random.choice(25, 1)[0]
            n_aug1, n_aug2 = n_aug // 5, n_aug % 5
            data1 = self.augmentations[n_aug1](data1)
            data2 = self.augmentations[n_aug2](data2)
        elif self.aug_mode == 'sample':
            n_aug = np.random.choice(25, 1, p=self.aug_prob)[0]
            n_aug1, n_aug2 = n_aug // 5, n_aug % 5
            data1 = self.augmentations[n_aug1](data1)
            data2 = self.augmentations[n_aug2](data2)
        else:
            raise ValueError
        
        return data, data1, data2
    
    def process(self):
        self.sid_list = []
        self.dset_list = []
        
        for i in range(self.df.shape[0]):
            smile = self.df.loc[i, 'smiles']
            dset_name = self.df.loc[i, 'dataset']
            sid = self.df.loc[i, 'sid']
            self.dset_list.append(dset_name)
            self.sid_list.append(sid)

SyntaxError: invalid syntax (2705491510.py, line 55)

In [40]:
# def mol2graph(mol):
#     data = DGData()
#     graph = smiles2graphwithface(mol)

#     data.__num_nodes__ = int(graph["num_nodes"])
#     data.edge_index = torch.from_numpy(graph["edge_index"]).to(torch.int64)
#     data.edge_attr = torch.from_numpy(graph["edge_feat"]).to(torch.int64)
#     data.x = torch.from_numpy(graph["node_feat"]).to(torch.int64)    

#     data.ring_mask = torch.from_numpy(graph["ring_mask"]).to(torch.bool)
#     data.ring_index = torch.from_numpy(graph["ring_index"]).to(torch.int64)
#     data.nf_node = torch.from_numpy(graph["nf_node"]).to(torch.int64)
#     data.nf_ring = torch.from_numpy(graph["nf_ring"]).to(torch.int64)
#     data.num_rings = int(graph["num_rings"])
#     data.n_edges = int(graph["n_edges"])
#     data.n_nodes = int(graph["n_nodes"])
#     data.n_nfs = int(graph["n_nfs"])

#     return data

In [41]:
def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)

    torch.cuda.manual_seed(seed)  # type: ignore
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True  # type: ignore
    torch.backends.cudnn.benchmark = False  # type: ignore
seed_everything(2023)


In [56]:
data = pd.read_csv('/home/pjh/workspace/MED/sub_code/OGNN/pretrain/zinc_combined_apr_8_2019.csv', index_col=None)
data = data[['zinc_id', 'smiles']]
data.columns = ['sid', 'smiles']
data['dataset'] = 'zinc_combined_apr_8_2019'
for dset_name in tqdm(os.listdir( '/home/pjh/workspace/MED/sub_code/OGNN/pretrain')):
    if 'zinc_combined_apr_8_2019' in dset_name:
        continue
    df = pd.read_csv(f'/home/pjh/workspace/MED/sub_code/OGNN/pretrain/{dset_name}', index_col=None)            
    
    df['dataset'] = dset_name.replace('.csv', '')
    df['sid'] = dset_name.replace('.csv', '') + '_' +  df.reset_index()['index'].astype(str)
    if 'bace' in dset_name:
        df = df[['sid', 'dataset', 'mol']]
        df.columns = ['sid', 'dataset', 'smiles']
    else:
        df = df[['sid', 'dataset', 'smiles']]
    data = pd.concat([data, df]).reset_index(drop=True)

# data = data.sample(20000).reset_index(drop=True)

100%|██████████| 10/10 [00:04<00:00,  2.27it/s]


In [None]:
def drop_node_edge(batch_org, p, cyp_list):    
    batch = deepcopy(batch_org)
    prob = torch.rand(batch.num_nodes, device=batch.edge_index.device)
    node_mask = prob > p    
    node_mask[batch.y_atom['CYP_REACTION'].bool()] = True

    edge_index, _, edge_mask = subgraph(node_mask, batch.edge_index,
                                        num_nodes=batch.num_nodes,
                                        return_edge_mask=True)
    
    node_mask_idx = torch.zeros(node_mask.shape[0]).long()
    node_mask_idx[node_mask] = torch.arange(node_mask.sum().item())

    edge_mask_idx = torch.zeros(edge_mask.shape[0]).long()
    edge_mask_idx[edge_mask] = torch.arange(edge_mask.sum().item())
    
    batch.edge_index = node_mask_idx[edge_index]

    ring_index = batch.ring_index[:, edge_mask]
    # ring_index = edge_mask_idx[ring_index]

    batch.ring_index = ring_index
    batch.edge_attr = batch.edge_attr[edge_mask]
    
    batch.n_edges = edge_mask.sum().item()
    batch.n_nodes = node_mask.sum().item()

    nf_node_mask = ~torch.isin(batch.nf_node, torch.where(~node_mask)[0])[0]

    for ridx in range(batch.num_rings):        
        ring_node_mask = nf_node_mask[batch.nf_ring[0] == ridx+1]
        if not ring_node_mask.all():
            batch.ring_mask[ridx] = False

    batch.nf_node = batch.nf_node[:, nf_node_mask]
    batch.nf_node = node_mask_idx[batch.nf_node]
    batch.nf_ring = batch.nf_ring[:, nf_node_mask]

    batch.n_nfs = batch.nf_node.size(1)
    
    edge_mask = edge_mask.view(edge_mask.shape[0] // 2, 2)
    edge_mask = edge_mask[:, 0]

    batch.x = batch.x[node_mask]
    batch.spn_atom = batch.spn_atom[node_mask]
    batch.has_H_atom = batch.has_H_atom[node_mask]
    batch.not_has_H_bond = batch.not_has_H_bond[edge_mask]
        
    for cyp in cyp_list:
        batch.y_spn[cyp] = batch.y_spn[cyp][node_mask]
        batch.y_atom[cyp] = batch.y_atom[cyp][node_mask]
        batch.y_hydroxylation[cyp] = batch.y_hydroxylation[cyp][node_mask]
        batch.y_nh_oxidation[cyp] = batch.y_nh_oxidation[cyp][node_mask]
        
        batch.y[cyp] = batch.y[cyp][edge_mask]
        batch.y_cleavage[cyp] = batch.y_cleavage[cyp][edge_mask]
        batch.y_nn_oxidation[cyp] = batch.y_nn_oxidation[cyp][edge_mask]

    return batch

In [57]:
train_dataset = MoleculeDataset_graphcl(df = data)

Processing...
Done!


In [58]:
train_dataset.set_augMode('sample')
train_dataset.set_augProb(np.ones(25) / 25)
train_dataset.set_augStrength(0.2)

In [59]:
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers = 8)

In [60]:
for batch, batch1, batch2 in train_loader:
    break

In [73]:
batch1.ring_mask
batch1.nf_node
batch1.nf_ring

tensor([ True, False, False, False, False], device='cuda:0')

In [69]:
with torch.no_grad():
    out = model(batch1.to(device))

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 46 but got size 66 for tensor number 3 in the list.

In [70]:
out.shape

NameError: name 'out' is not defined

In [71]:
batch1.edge_attr.shape

torch.Size([46, 3])

In [29]:
# # for dset_name in data['dataset'].unique():
# #     os.mkdir(f'/home/pjh/workspace/SOM/data/pretrain_graph/{dset_name}')
# for smile, dset_name, sid in tqdm(data[['smiles', 'dataset', 'sid']].values):
#     mol = Chem.MolFromSmiles(smile)
#     # mol = AllChem.AddHs( mol, addCoords=True)
#     graph = mol2graph(mol)
#     torch.save(graph,f'/home/pjh/workspace/SOM/data/pretrain_graph/{dset_name}/{sid}.pt')
    

In [65]:
triplet_loss = nn.TripletMarginLoss(margin=0.0, p=2)
criterion = nn.CrossEntropyLoss()

In [66]:
class GraphContrastiveLearning(torch.nn.Module):

    def __init__(self):
        super(GraphContrastiveLearning, self).__init__()
        self.ddi = True
        self.gnn = GNN2(
                        mlp_hidden_size = 512,
                        mlp_layers = 2,
                        num_message_passing_steps=8,
                        latent_size = 128,
                        use_layer_norm = True,
                        use_bn=False,
                        use_face=True,
                        som_mode=False,
                        ddi=True,
                        dropedge_rate = 0.1,
                        dropnode_rate = 0.1,
                        dropout = 0.1,
                        dropnet = 0.1,
                        global_reducer = 'sum',
                        node_reducer = 'sum',
                        face_reducer = 'sum',
                        graph_pooling = 'sum',
                        use_mamba=False,
                        node_attn = True,
                        face_attn = True,
                        encoder_dropout=0.0,
                        use_pe=False
                        )
                        
        self.proj = nn.Sequential(
                        nn.Linear(128, 128),
                        nn.ReLU(),
                        nn.Linear(128, 128),
                        )
    def forward(self, batch):
        mol = self.gnn(batch).squeeze(1)
        return self.proj(mol)

In [67]:
device = 'cuda:0'
epochs = 100
lr = 1e-5
weight_decay = 5e-5
batch_size= 256

In [68]:
train_dataset = MoleculeDataset_graphcl(df = data)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers = 8)

model = GraphContrastiveLearning().to(device)

optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay = weight_decay)
ema = ExponentialMovingAverage(model.parameters(), decay=0.999)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=epochs, verbose=False)

In [12]:
def loss_cl(x1, x2):
    T = 0.1
    batch_size, _ = x1.size()
    x1_abs = x1.norm(dim=1)
    x2_abs = x2.norm(dim=1)

    sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / torch.einsum('i,j->ij', x1_abs, x2_abs)
    sim_matrix = torch.exp(sim_matrix / T)
    pos_sim = sim_matrix[range(batch_size), range(batch_size)]
    loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
    loss = - torch.log(loss).mean()
    return loss

In [13]:
state_dict = torch.load(f'/home/pjh/workspace/SOM/graph_mamba/ckpt_pretrain/gnn_pretrain_epoch36.pt')

model.load_state_dict(state_dict['model_state_dict'])
scheduler.load_state_dict(state_dict['scheduler_state_dict'])
ema.load_state_dict(state_dict['ema_state_dict'] )
optim.load_state_dict(state_dict['optimizer_state_dict'])

In [14]:
best_val_loss = 1e6
start = 1
for epoch in range(37, epochs+1):

    model.train()
    train_loss = 0
    for batch in tqdm(train_loader):
        batch = [bat.to(device) for bat in batch]
        outputs = [model(bat) for  bat in batch]        
        origin_output = outputs[0]
        
        mask_cl_loss = loss_cl(outputs[1], outputs[2])                    
        mask_t_loss = triplet_loss(outputs[0], outputs[1], outputs[2])

        loss = mask_cl_loss + (mask_t_loss * 0.1)

        optim.zero_grad()
        loss.backward() 
        optim.step()
        ema.update()
        
        train_loss += loss.cpu().item()
        
        gc.collect()
        torch.cuda.empty_cache()

    if train_loss < best_val_loss:
        best_val_loss = train_loss
        torch.save(model.gnn.state_dict(), f'ckpt_pretrain/gnn_pretrain.pt')
        

    scheduler.step()
    torch.save(
            {
            'optimizer_state_dict': optim.state_dict(),
            'model_state_dict': model.state_dict(),
            'gnn_state_dict' : model.gnn.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'ema_state_dict' : ema.state_dict()
            },
            f'ckpt_pretrain/gnn_pretrain_epoch{epoch}.pt')

    print(f'EPOCH : {epoch} | train_loss : {train_loss/len(train_loader):.4f}')



100%|██████████| 10113/10113 [1:57:26<00:00,  1.44it/s] 


EPOCH : 30 | train_loss : -3.9199


100%|██████████| 10113/10113 [1:56:08<00:00,  1.45it/s] 


EPOCH : 31 | train_loss : -3.9221


100%|██████████| 10113/10113 [1:52:47<00:00,  1.49it/s] 


EPOCH : 32 | train_loss : -3.9248


100%|██████████| 10113/10113 [1:52:54<00:00,  1.49it/s] 


EPOCH : 33 | train_loss : -3.9270


100%|██████████| 10113/10113 [1:52:57<00:00,  1.49it/s]


EPOCH : 34 | train_loss : -3.9289


100%|██████████| 10113/10113 [1:52:45<00:00,  1.49it/s]


EPOCH : 35 | train_loss : -3.9310


100%|██████████| 10113/10113 [1:52:51<00:00,  1.49it/s] 


EPOCH : 36 | train_loss : -3.9328


100%|██████████| 10113/10113 [1:52:58<00:00,  1.49it/s] 


EPOCH : 37 | train_loss : -3.9343


  0%|          | 39/10113 [00:29<1:59:15,  1.41it/s]

In [None]:
# EPOCH : 1 | train_loss : -2.3468
# EPOCH : 2 | train_loss : -3.3075
# EPOCH : 3 | train_loss : -3.5078
# EPOCH : 4 | train_loss : -3.6116
# EPOCH : 5 | train_loss : -3.6767
# EPOCH : 6 | train_loss : -3.7220
# EPOCH : 7 | train_loss : -3.7547
# EPOCH : 8 | train_loss : -3.7785
# EPOCH : 9 | train_loss : -3.7976
# EPOCH : 10 | train_loss : -3.8130
# EPOCH : 11 | train_loss : -3.8248
# EPOCH : 12 | train_loss : -3.8358
# EPOCH : 13 | train_loss : -3.8443
# EPOCH : 14 | train_loss : -3.8521
# EPOCH : 15 | train_loss : -3.8593
# EPOCH : 16 | train_loss : -3.8657
# EPOCH : 17 | train_loss : -3.8720
# EPOCH : 18 | train_loss : -3.8774
# EPOCH : 19 | train_loss : -3.8831
# EPOCH : 20 | train_loss : -3.8879
# EPOCH : 21 | train_loss : -3.8916
# EPOCH : 22 | train_loss : -3.8959
# EPOCH : 23 | train_loss : -3.8989
# EPOCH : 24 | train_loss : -3.9027
# EPOCH : 25 | train_loss : -3.9056
# EPOCH : 26 | train_loss : -3.9086
# EPOCH : 27 | train_loss : -3.9118
# EPOCH : 28 | train_loss : -3.9144
# EPOCH : 29 | train_loss : -3.9176
# EPOCH : 30 | train_loss : -3.9199
# EPOCH : 31 | train_loss : -3.9221
# EPOCH : 32 | train_loss : -3.9248
# EPOCH : 33 | train_loss : -3.9270
# EPOCH : 34 | train_loss : -3.9289
# EPOCH : 35 | train_loss : -3.9310
# EPOCH : 36 | train_loss : -3.9328
# EPOCH : 37 | train_loss : -3.9343

In [None]:
# add H
# 100%|██████████| 10113/10113 [2:39:32<00:00,  1.06it/s] 
# EPOCH : 1 | train_loss : -2.2254
# 100%|██████████| 10113/10113 [2:40:06<00:00,  1.05it/s] 
# EPOCH : 2 | train_loss : -3.2680
# 100%|██████████| 10113/10113 [2:40:23<00:00,  1.05it/s] 
# EPOCH : 3 | train_loss : -3.4840
# 100%|██████████| 10113/10113 [2:40:28<00:00,  1.05it/s] 
# EPOCH : 4 | train_loss : -3.5963
# 100%|██████████| 10113/10113 [2:40:39<00:00,  1.05it/s] 
# EPOCH : 5 | train_loss : -3.6643
# 100%|██████████| 10113/10113 [2:41:22<00:00,  1.04it/s] 
# EPOCH : 5 | train_loss : -3.6644
# 100%|██████████| 10113/10113 [2:43:59<00:00,  1.03it/s] 
# EPOCH : 6 | train_loss : -3.7121
# 100%|██████████| 10113/10113 [2:43:19<00:00,  1.03it/s] 
# EPOCH : 7 | train_loss : -3.7456
# 100%|██████████| 10113/10113 [2:43:16<00:00,  1.03it/s] 
# EPOCH : 8 | train_loss : -3.7727
# 100%|██████████| 10113/10113 [2:42:57<00:00,  1.03it/s] 
# EPOCH : 9 | train_loss : -3.7932
# 100%|██████████| 10113/10113 [2:41:36<00:00,  1.04it/s] 
# EPOCH : 10 | train_loss : -3.8113
# 100%|██████████| 10113/10113 [2:42:50<00:00,  1.04it/s] 
# EPOCH : 11 | train_loss : -3.8263
# 100%|██████████| 10113/10113 [2:42:50<00:00,  1.04it/s] 
# EPOCH : 12 | train_loss : -3.8376
# 100%|██████████| 10113/10113 [2:43:01<00:00,  1.03it/s] 
# EPOCH : 13 | train_loss : -3.8482
# 100%|██████████| 10113/10113 [2:43:19<00:00,  1.03it/s] 
# EPOCH : 14 | train_loss : -3.8564
# 100%|██████████| 10113/10113 [2:42:09<00:00,  1.04it/s] 
# EPOCH : 15 | train_loss : -3.8636
# 100%|██████████| 10113/10113 [2:42:25<00:00,  1.04it/s] 
# EPOCH : 16 | train_loss : -3.8719
# 100%|██████████| 10113/10113 [2:42:51<00:00,  1.03it/s] 
# EPOCH : 17 | train_loss : -3.8782
# 100%|██████████| 10113/10113 [2:42:59<00:00,  1.03it/s] 
# EPOCH : 18 | train_loss : -3.8842
# 100%|██████████| 10113/10113 [2:43:17<00:00,  1.03it/s] 
# EPOCH : 19 | train_loss : -3.8888
# 100%|██████████| 10113/10113 [2:43:05<00:00,  1.03it/s] 
# EPOCH : 20 | train_loss : -3.8924