In [1]:
%load_ext autoreload
%autoreload 2
# %load_ext blackcellmagic

In [2]:
# # Install graphein from master for bleeding-edge additions
# !pip install git+https://github.com/a-r-j/graphein

In [3]:
# Misc. tools
import os

# Hydra tools
import hydra

from hydra.compose import GlobalHydra
from hydra.core.hydra_config import HydraConfig

from proteinworkshop.constants import HYDRA_CONFIG_PATH
from proteinworkshop.utils.notebook import init_hydra_singleton

init_hydra_singleton(reload=True)

path = HYDRA_CONFIG_PATH
rel_path = os.path.relpath(path, start=".")

GlobalHydra.instance().clear()
hydra.initialize(rel_path)

cfg = hydra.compose(
    "train",
    overrides=[
        "dataset=afdb_swissprot_v4",
        "dataset.datamodule.batch_size=32",
        "dataset.datamodule.train_split=0.02", # here
        "dataset.datamodule.val_split=0.001", # here
        "features=fe_subgraph",

        "task=subgraph_distance_prediction", # here
        ],
    return_hydra_config=False,
)

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize(rel_path)


In [4]:
from proteinworkshop.configs import config

cfg = config.validate_config(cfg)

In [5]:
print(cfg.keys())
for key in cfg.keys():
    print(key)
    print(cfg[key])

dict_keys(['env', 'dataset', 'features', 'encoder', 'decoder', 'transforms', 'callbacks', 'optimiser', 'scheduler', 'trainer', 'extras', 'metrics', 'task', 'logger', 'name', 'seed', 'num_workers', 'task_name', 'test'])
env
{'paths': {'root_dir': '${oc.env:ROOT_DIR}', 'data': '${oc.env:DATA_PATH}', 'output_dir': '${hydra:runtime.output_dir}', 'work_dir': '${hydra:runtime.cwd}', 'log_dir': '${oc.env:RUNS_PATH}', 'runs': '${oc.env:RUNS_PATH}', 'run_dir': '${env.paths.runs}/${name}/${env.init_time}'}, 'python': {'version': '${python_version:micro}'}, 'init_time': '${now:%y-%m-%d_%H:%M:%S}'}
dataset
{'datamodule': {'_target_': 'graphein.ml.datasets.foldcomp_dataset.FoldCompLightningDataModule', 'data_dir': '${env.paths.data}/afdb_swissprot_v4/', 'database': 'afdb_swissprot_v4', 'batch_size': 32, 'num_workers': 32, 'train_split': 0.02, 'val_split': 0.001, 'test_split': 0.1, 'pin_memory': True, 'use_graphein': True, 'transform': '${transforms}'}, 'dataset_name': 'afdb_swissprot_v4', 'num_clas

### Load a dataset

Can switch out for another by replacing the dataset arg in overrides:

`cfg = hydra.compose("template", overrides=["dataset=afdb_swissprot_v4"], return_hydra_config=False)`

In [6]:
from omegaconf import OmegaConf

In [7]:
from proteinworkshop.configs import config

cfg = config.validate_config(cfg)
# print("Original config:\n", OmegaConf.to_yaml(cfg))
mutable_cfg = OmegaConf.to_container(cfg.dataset.datamodule, resolve=True)
mutable_cfg = OmegaConf.create(mutable_cfg)
# print("Cloned config:\n", OmegaConf.to_yaml(mutable_cfg))
# Instantiate the datamodule with the mutable configuration
datamodule = hydra.utils.instantiate(mutable_cfg)
datamodule.setup("fit")
dl = datamodule.train_dataloader()
dl = datamodule.val_dataloader()
for i in dl:
    print(i)
    break

100%|██████████| 542378/542378 [00:00<00:00, 4672582.55it/s]
Processing...
Done!
100%|██████████| 542378/542378 [00:00<00:00, 4710958.31it/s]
Processing...
Done!
100%|██████████| 542378/542378 [00:00<00:00, 4749642.38it/s]
Processing...
Done!
100%|██████████| 542378/542378 [00:00<00:00, 4747144.73it/s]
Processing...
Done!

This DataLoader will create 32 worker processes in total. Our suggested max number of worker in current system is 28, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.



ProteinBatch(fill_value=[32], atom_list=[32], coords=[11716, 37, 3], residues=[32], residue_id=[32], chains=[11716], residue_type=[11716], b_factor=[11716], id=[32], x=[11716], seq_pos=[11716, 1], batch=[11716], ptr=[33])


In [8]:
from torch import nn
featuriser: nn.Module = hydra.utils.instantiate(cfg.features)

for i in dl:
    batch = featuriser(i)
    print(batch)
    break

ProteinBatch(fill_value=[32], atom_list=[32], coords=[11716, 37, 3], residues=[32], residue_id=[32], chains=[11716], residue_type=[11716], b_factor=[11716], id=[32], x=[11716, 23], seq_pos=[11716, 1], batch=[11716], ptr=[33], pos=[11716, 3], edge_index=[2, 185106], subgraphs=[1157, 148], subgraph_distances=[1157], subgraph_lengths=[1157])


### Define a new encoder

In [9]:
from tqdm import tqdm
import numpy as np
import time
from torch_scatter import scatter_mean, scatter
import torch
def train(args, model, mlp_pred_dist, train_loader,  criterion, optimizer, device):
    model.train()
    mlp_pred_dist.train()
    loss_accum = 0
   
    # shuffle the train batches and all_subgraphs
    #random_idx = np.random.permutation(len(train_batches))
    #train_batches = [train_batches[i] for i in random_idx]
    
    
    #for step, batch in enumerate(tqdm(loader, disable=args.disable_tqdm)):
    for step, batch in enumerate(tqdm(train_loader, disable=args.disable_tqdm)):
        batch = featuriser(batch)
        #init_idx = random_idx[step]
        init_idx = step
        if args.mask:
            # random mask node aatype
            mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False))
            batch.x[:, 0][mask_indice] = 25
        if args.noise:
            # add gaussian noise to atom coords
            gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3)
            batch.coords_ca += gaussian_noise
            if args.level != 'aminoacid':
                batch.coords_n += gaussian_noise
                batch.coords_c += gaussian_noise
        if args.deform:
            # Anisotropic scale
            deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 3)), min=0.9, max=1.1)
            batch.coords_ca *= deform
            if args.level != 'aminoacid':
                batch.coords_n *= deform
                batch.coords_c *= deform
        batch = batch.to(device)
                      
        pred = model(batch) 
        
        subgraphs = batch.subgraphs
        dist = batch.subgraph_distances.to(device)

        #aggregate node representations of each subgraph

        pooled_subgraphs = []
        for i in range(len(subgraphs)):
            pooled_subgraphs.append(torch.sum(pred[subgraphs[i]], dim=0))

        pooled_subgraphs = torch.stack(pooled_subgraphs)
        graph_repr = scatter(pred, batch.batch, dim=0)
        # for lin in model.lins_out:
        #     pooled_subgraphs = model.relu(lin(pooled_subgraphs))
        #     pooled_subgraphs = model.dropout(pooled_subgraphs)     
        
        #protein representations
        #compute the center of the subgraphs based on the coordinates
       
        #compute the center of the protein based on the coordinates
        # compute the distance between the center of the subgraphs and the center of the proteins
        # we have to compute the distance only between the subgraph and the corresponding protein
        # repeat the center of the protein for each subgraph and the perform the distance computation
        #G_c = G_c.repeat(num_subgraphs_per_protein,1)

        
        #dist = torch.norm(G_c-S_c,dim=1)
        #concat the subgraph and the protein representations. find the graph_repr that corresponds to the subgraph of the protein
        fused_repr = []
        for i in range(len(pooled_subgraphs)):
            fused_repr.append(torch.cat((pooled_subgraphs[i],graph_repr[batch.batch[subgraphs[i][0]]])))
        fused_repr = torch.stack(fused_repr)
        #predict the distance
        pred_dist = mlp_pred_dist(fused_repr)
        pred_dist = pred_dist.squeeze()
        #normalize the distance (note maybe we should normalize in the whole dataset and not in each bach)
        y_dist = dist/torch.max(dist)
        #mse loss
        optimizer.zero_grad()

        loss = criterion(pred_dist, y_dist)
        loss.backward()
        optimizer.step()
        loss_accum += loss.item()
        if(step%300==0):
            print(loss_accum/(step + 1))
        #### end pretraining
        ######
    print('train loss epoch: ', loss_accum/(step + 1) )
    return loss_accum/(step + 1)

In [10]:
def evaluation(args, model, mlp_pred_dist, loader, criterion, device):    
    model.eval()
    
    loss_accum = 0
    for step, batch in enumerate(tqdm(loader, disable=args.disable_tqdm)):
        batch = featuriser(batch)
        if args.mask:
            # random mask node aatype
            mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False))
            batch.x[:, 0][mask_indice] = 25
        if args.noise:
            # add gaussian noise to atom coords
            gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3)
            batch.coords_ca += gaussian_noise
            if args.level != 'aminoacid':
                batch.coords_n += gaussian_noise
                batch.coords_c += gaussian_noise
        if args.deform:
            # Anisotropic scale
            deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 3)), min=0.9, max=1.1)
            batch.coords_ca *= deform
            if args.level != 'aminoacid':
                batch.coords_n *= deform
                batch.coords_c *= deform
        batch = batch.to(device)
                     
        
        pred = model(batch) 
        
        

        subgraphs = batch.subgraphs
        dist = batch.subgraph_distances.to(device)
        #### pretraining
        #aggregate node representations of each subgraph


        pooled_subgraphs = []
        for i in range(len(subgraphs)):
            pooled_subgraphs.append(torch.sum(pred[subgraphs[i]], dim=0))
      
        pooled_subgraphs = torch.stack(pooled_subgraphs)

        # for lin in model.lins_out:
        #     pooled_subgraphs = model.relu(lin(pooled_subgraphs))
        #     pooled_subgraphs = model.dropout(pooled_subgraphs)    
        #protein representations
        #compute the center of the subgraphs based on the coordinates

        #compute the center of the protein based on the coordinates
        # compute the distance between the center of the subgraphs and the center of the proteins
        # we have to compute the distance only between the subgraph and the corresponding protein
        # repeat the center of the protein for each subgraph and the perform the distance computation
        #G_c = G_c.repeat(num_subgraphs_per_protein,1)

        
        #dist = torch.norm(G_c-S_c,dim=1)
        #normalize the distance (note maybe we should normalize in the whole dataset and not in each bach)
        graph_repr = scatter(pred, batch.batch, dim=0)
        fused_repr = []
        for i in range(len(pooled_subgraphs)):
            fused_repr.append(torch.cat((pooled_subgraphs[i],graph_repr[batch.batch[subgraphs[i][0]]])))
        fused_repr = torch.stack(fused_repr)
        y_dist = dist/torch.max(dist)
        #predict the distance
        pred_dist = mlp_pred_dist(fused_repr)
        pred_dist = pred_dist.squeeze()
        #mse loss
        loss = criterion(pred_dist, y_dist)
        loss_accum += loss.item()
        if(step %100 == 0):
            print(loss_accum/(step + 1))
    print('eval loss epoch: ', loss_accum/(step + 1))
    return loss_accum/(step + 1) 

In [11]:
#     ### Args
import argparse
import sys
import torch
from pronet import ProNet
import torch.optim as optim
from datetime import datetime
# from torch.utils.tensorboard import SummaryWriter
import time

sys.argv = ['notebook']
parser = argparse.ArgumentParser()

parser.add_argument('--device', type=int, default=0, help='Device to use')
parser.add_argument('--num_workers', type=int, default=8, help='Number of workers in Dataloader')

### Data
# parser.add_argument('--dataset', type=str, default='alphafold', help='Func or fold or all')
# parser.add_argument('--dataset_path', type=str, default='/datalake/datastore2/alphafold_v4_pronet_processed', help='path to load and process the data')
# parser.add_argument('--annot_fn', type=str, default="/home/michail/datadisk/PretrainDas/data/GO_EC_labels_deepfri/nrPDB-GO_2019.06.18_annot.tsv")
# parser.add_argument('--ontology', type=str, default="ec")

# data augmentation tricks, see appendix E in the paper (https://openreview.net/pdf?id=9X-hgLDLYkQ)
parser.add_argument('--mask', action='store_true', help='Random mask some node type')
parser.add_argument('--noise', default=False, action='store_true', help='Add Gaussian noise to node coords')
parser.add_argument('--deform', default=False, action='store_true', help='Deform node coords')
parser.add_argument('--data_augment_eachlayer', default=True, action='store_true', help='Add Gaussian noise to features')
parser.add_argument('--euler_noise', default=False, action='store_true', help='Add Gaussian noise Euler angles')
parser.add_argument('--mask_aatype', type=float, default=0.1, help='Random mask aatype to 25(unknown:X) ratio')

### Model
parser.add_argument('--model_name', type=str, default='pronet', help='rgcn,pronet')
#for pronet
parser.add_argument('--level', type=str, default='aminoacid', help='Choose from \'aminoacid\', \'backbone\', and \'allatom\' levels')
parser.add_argument('--num_blocks', type=int, default=4, help='Model layers')
parser.add_argument('--hidden_channels', type=int, default=128, help='Hidden dimension')
parser.add_argument('--out_channels', type=int, default=384, help='Number of classes, 1195 for the fold data, 384 for the ECdata')
parser.add_argument('--fix_dist', action='store_true')  
parser.add_argument('--cutoff', type=float, default=10, help='Distance constraint for building the protein graph') 
parser.add_argument('--dropout', type=float, default=0.3, help='Dropout')
parser.add_argument('--precompute_subgraphs', type=int, default=0, help='Compute the subgraphs')

## Training hyperparameter
parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train')
parser.add_argument('--lr', type=float, default=5e-4, help='Learning rate')
parser.add_argument('--lr_decay_step_size', type=int, default=150, help='Learning rate step size')
parser.add_argument('--lr_decay_factor', type=float, default=0.5, help='Learning rate factor') 
parser.add_argument('--weight_decay', type=float, default=0, help='Weight Decay')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size during training')
parser.add_argument('--eval_batch_size', type=int, default=64, help='Batch size')



parser.add_argument('--continue_training', action='store_true')
parser.add_argument('--save_dir', type=str, default="./logs", help='Trained model path')

parser.add_argument('--disable_tqdm', default=False, action='store_true')
args = parser.parse_args()

device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
# if(args.device == 1):
#     torch.cuda.set_device(1)

##### load datasets
print('Loading Train & Val & Test Data...')




train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
#check if file exists


##### set up model
if(args.model_name == "pronet"):
    model = ProNet(num_blocks=args.num_blocks, hidden_channels=args.hidden_channels, out_channels=args.out_channels,
            cutoff=args.cutoff, dropout=args.dropout,
            data_augment_eachlayer=args.data_augment_eachlayer,
            euler_noise = args.euler_noise, level=args.level, pretraining=True)
else:
    model = RGCN(input_dim=input_dim, hidden_dim=args.hidden_channels, n_layers=6, emb_dim=args.out_channels, dropout=args.dropout, pretraining=True)
    
model.to(device)

mlp_pred_dist = torch.nn.Sequential(
    torch.nn.Linear(2*args.hidden_channels, args.hidden_channels),
    torch.nn.ReLU(),
    torch.nn.Linear(args.hidden_channels, 1)
).to(device)

#linear_pred_dist= torch.nn.Linear(2*args.hidden_channels, 1).to(device)


optimizer = optim.Adam(list(model.parameters())+list(mlp_pred_dist.parameters()), lr=args.lr, weight_decay=args.weight_decay) 
criterion = torch.nn.MSELoss()

#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_step_size, gamma=args.lr_decay_factor)


if args.continue_training:
    save_dir = args.save_dir
    checkpoint = torch.load(save_dir + '/best_val.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    #scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch']
else:
    # save_dir = './pretrained_models_{dataset}/{level}/layer{num_blocks}_cutoff{cutoff}_hidden{hidden_channels}_batch{batch_size}_lr{lr}_{lr_decay_factor}_{lr_decay_step_size}_dropout{dropout}__{time}'.format(
    #     dataset=args.dataset, level=args.level, 
    #     num_blocks=args.num_blocks, cutoff=args.cutoff, hidden_channels=args.hidden_channels, batch_size=args.batch_size, 
    #     lr=args.lr, lr_decay_factor=args.lr_decay_factor, lr_decay_step_size=args.lr_decay_step_size, dropout=args.dropout, time=datetime.now())
    # print('saving to...', save_dir)
    start_epoch = 1
    
num_params = sum(p.numel() for p in model.parameters()) 
print('num_parameters:', num_params)


# writer = SummaryWriter(log_dir=save_dir)
#  best_val_loss = 1000
# test_at_best_val_loss = 1000

    
# print("loading edge_index")
# with open("edge_index_pronet_64.pkl","rb") as f:
#     edge_index = pickle.load(f)
# print("edge_index loaded")


    
print(len(train_loader))
# exit()

# print("Loading subgraphs")
# if(args.precompute_subgraphs==1):
#     print("Preprocessing - Compute Subgraphs")
#     train_subgraphs,train_dist = compute_subgraphs(train_loader, args=args, device=device)
#     with open(f'./subgraphs/alphafold_train_subgraphs_{args.batch_size}_490k.pkl', 'wb') as f:
#         pickle.dump(train_subgraphs, f)
#     with open(f'./subgraphs/alphafold_train_dist_{args.batch_size}_490k.pkl', 'wb') as f:
#         pickle.dump(train_dist, f)
# else:
#     with open(f'./subgraphs/alphafold_train_subgraphs_{args.batch_size}_490k.pkl', 'rb') as f:
#         train_subgraphs = pickle.load(f)
#     with open(f'./subgraphs/alphafold_train_dist_{args.batch_size}_490k.pkl', 'rb') as f:
#         train_dist = pickle.load(f)
    


print("Loaded subgraphs")


for epoch in range(start_epoch, args.epochs+1):
    print('==== Epoch {} ===='.format(epoch))
    t_start = time.perf_counter()
    
    train_loss = train(args, model, mlp_pred_dist, train_loader, criterion, optimizer, device)
    t_end_train = time.perf_counter()
    val_loss = evaluation(args, model, mlp_pred_dist, val_loader, criterion, device)
    # t_start_test = time.perf_counter()
    # test_loss = evaluation(args, model, linear_pred_dist, test_loader, criterion, device)
    
    
    # t_end_test = time.perf_counter() 

    # if not save_dir == "" and not os.path.exists(save_dir):
    #     os.makedirs(save_dir)

    t_end = time.perf_counter()
    print('Train: Loss:{:.6f}, time:{}, train_time:{}'.format(
        train_loss, t_end-t_start, t_end_train-t_start))
    
    # writer.add_scalar('train_loss', train_loss, epoch)

    # scheduler.step()   

    # writer.close()    
    # print("Train Loss", train_loss)
    # Save last model
    checkpoint = {'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()} #'scheduler_state_dict': scheduler.state_dict()}
    # torch.save(checkpoint, save_dir + "/epoch{}.pt".format(epoch))

Loading Train & Val & Test Data...
num_parameters: 1433344
339
Loaded subgraphs
==== Epoch 1 ====


  0%|          | 1/339 [00:04<24:04,  4.27s/it]

0.19333094358444214


 89%|████████▉ | 301/339 [09:05<01:00,  1.60s/it]

0.06004828708462937


100%|██████████| 339/339 [10:20<00:00,  1.83s/it]


train loss epoch:  0.05867409093475799


  1%|          | 1/165 [00:04<11:51,  4.34s/it]

0.02635742351412773


 61%|██████    | 101/165 [02:52<01:57,  1.84s/it]

0.03942182510722392


100%|██████████| 165/165 [04:46<00:00,  1.74s/it]


eval loss epoch:  0.03992581272667105
Train: Loss:0.058674, time:906.9469184440095, train_time:620.4321674219973
==== Epoch 2 ====


  0%|          | 1/339 [00:03<22:29,  3.99s/it]

0.024970002472400665


 89%|████████▉ | 301/339 [09:15<01:14,  1.95s/it]

0.046417944832547166


100%|██████████| 339/339 [10:23<00:00,  1.84s/it]


train loss epoch:  0.04555778847155669


  1%|          | 1/165 [00:04<12:07,  4.44s/it]

0.10035014897584915


 61%|██████    | 101/165 [02:52<01:57,  1.84s/it]

0.06426998904657245


100%|██████████| 165/165 [04:46<00:00,  1.73s/it]


eval loss epoch:  0.06200156705171773
Train: Loss:0.045558, time:909.3110188319988, train_time:623.0785406450013
==== Epoch 3 ====


  0%|          | 1/339 [00:03<22:18,  3.96s/it]

0.054526835680007935


 89%|████████▉ | 301/339 [09:16<01:13,  1.92s/it]

0.03933691680493248


100%|██████████| 339/339 [10:22<00:00,  1.84s/it]


train loss epoch:  0.040142245005519515


  1%|          | 1/165 [00:04<11:59,  4.39s/it]

0.05959140881896019


 61%|██████    | 101/165 [02:53<01:58,  1.84s/it]

0.04970447140017358


100%|██████████| 165/165 [04:46<00:00,  1.74s/it]


eval loss epoch:  0.049263358308058794
Train: Loss:0.040142, time:909.3881620059983, train_time:622.5458605159947
==== Epoch 4 ====


  0%|          | 1/339 [00:05<29:04,  5.16s/it]

0.05154261738061905


 89%|████████▉ | 301/339 [09:17<01:05,  1.73s/it]

0.03805944661105491


100%|██████████| 339/339 [10:23<00:00,  1.84s/it]


train loss epoch:  0.03805751212018166


  1%|          | 1/165 [00:04<11:55,  4.36s/it]

0.024155309423804283


 61%|██████    | 101/165 [02:53<01:57,  1.84s/it]

0.03334032532085877


100%|██████████| 165/165 [04:46<00:00,  1.74s/it]


eval loss epoch:  0.034935878200287165
Train: Loss:0.038058, time:909.8540852959995, train_time:623.2159066159948
==== Epoch 5 ====


  0%|          | 1/339 [00:04<25:17,  4.49s/it]

0.021312864497303963


 89%|████████▉ | 301/339 [09:11<01:10,  1.87s/it]

0.03035295474346096


100%|██████████| 339/339 [10:22<00:00,  1.84s/it]


train loss epoch:  0.030151901722701602


  1%|          | 1/165 [00:04<12:23,  4.54s/it]

0.016864396631717682


 61%|██████    | 101/165 [02:53<01:58,  1.84s/it]

0.02622215389873427


100%|██████████| 165/165 [04:46<00:00,  1.74s/it]


eval loss epoch:  0.027736526767187047
Train: Loss:0.030152, time:909.4134876539902, train_time:622.6849301890034
==== Epoch 6 ====


  0%|          | 1/339 [00:04<26:36,  4.72s/it]

0.017745036631822586


 89%|████████▉ | 301/339 [09:12<01:23,  2.20s/it]

0.03124060809141972


100%|██████████| 339/339 [10:23<00:00,  1.84s/it]


train loss epoch:  0.03150184058576031


  1%|          | 1/165 [00:04<11:57,  4.38s/it]

0.017394671216607094


 61%|██████    | 101/165 [02:52<01:58,  1.85s/it]

0.028397551839156907


100%|██████████| 165/165 [04:46<00:00,  1.74s/it]


eval loss epoch:  0.028883139546396153
Train: Loss:0.031502, time:910.1711360140034, train_time:623.6630888109939
==== Epoch 7 ====


  0%|          | 1/339 [00:05<29:59,  5.32s/it]

0.03609895706176758


 89%|████████▉ | 301/339 [09:18<01:12,  1.90s/it]

0.0278279707299861


100%|██████████| 339/339 [10:25<00:00,  1.85s/it]


train loss epoch:  0.027316197406797284


  1%|          | 1/165 [00:04<12:07,  4.44s/it]

0.02230650745332241


 61%|██████    | 101/165 [02:52<01:57,  1.84s/it]

0.025128639616662323


100%|██████████| 165/165 [04:46<00:00,  1.74s/it]


eval loss epoch:  0.02581789732318033
Train: Loss:0.027316, time:912.3694836729992, train_time:625.8477841529966
==== Epoch 8 ====


  0%|          | 1/339 [00:04<23:17,  4.14s/it]

0.013320054858922958


 89%|████████▉ | 301/339 [09:18<01:08,  1.80s/it]

0.026852127009897534


100%|██████████| 339/339 [10:26<00:00,  1.85s/it]


train loss epoch:  0.026765850026457184


  1%|          | 1/165 [00:04<12:31,  4.58s/it]

0.016851380467414856


 61%|██████    | 101/165 [02:53<01:58,  1.85s/it]

0.024040494805736706


100%|██████████| 165/165 [04:47<00:00,  1.74s/it]


eval loss epoch:  0.024804918806661257
Train: Loss:0.026766, time:914.111687134995, train_time:626.8559465879953
==== Epoch 9 ====


  0%|          | 1/339 [00:04<27:50,  4.94s/it]

0.019249556586146355


 89%|████████▉ | 301/339 [09:12<01:05,  1.72s/it]

0.027206222276107417


100%|██████████| 339/339 [10:23<00:00,  1.84s/it]


train loss epoch:  0.027334844322660857


  1%|          | 1/165 [00:04<12:08,  4.44s/it]

0.013516737148165703


 61%|██████    | 101/165 [02:53<01:57,  1.84s/it]

0.0243061199267902


100%|██████████| 165/165 [04:46<00:00,  1.74s/it]


eval loss epoch:  0.025090810543659962
Train: Loss:0.027335, time:910.4465467270056, train_time:623.5201317340106
==== Epoch 10 ====


  0%|          | 1/339 [00:04<23:18,  4.14s/it]

0.01678326353430748


 89%|████████▉ | 301/339 [09:12<01:20,  2.11s/it]

0.025918189182481496


100%|██████████| 339/339 [10:21<00:00,  1.83s/it]


train loss epoch:  0.02560640700919702


  1%|          | 1/165 [00:04<12:13,  4.48s/it]

0.013540513813495636


 61%|██████    | 101/165 [02:52<01:57,  1.83s/it]

0.02471251428901854


100%|██████████| 165/165 [04:45<00:00,  1.73s/it]

eval loss epoch:  0.026172691792475455
Train: Loss:0.025606, time:907.723735299005, train_time:621.7440614070074





: 