# ALBERT for aptamer-pair classification

## Set dependencies

In [None]:
!pip install transformers==4.9.1
#  If prediction doesn't work, check
#  https://github.com/huggingface/transformers/issues/8879 and the BERT architecture.

In [None]:
import torch
import torch.nn as nn
import torch.onnx
import os
import matplotlib.pyplot as plt
import copy
import torch.optim as optim
import random
import numpy as np
import pandas as pd
import ruamel.yaml
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup

# Turn off the Hugging Face tokenizers' internal parallelism (presumably to avoid fork warnings/deadlocks
# once DataLoader workers start; see num_workers=2 below — confirm if workers change).
os.environ.update(TOKENIZERS_PARALLELISM="false")
config_name = 'config.yaml'  # YAML file with the 'Datasets', 'Model' and 'Random' sections read below

## Load your dataset

In [None]:
# Load the experiment configuration. `yaml` stays bound to the ruamel YAML() instance because the
# training cell reuses it to dump the updated config back to disk.
with open(config_name, 'r') as stream:
    yaml = ruamel.yaml.YAML()
    try:
        config = yaml.load(stream)
    except ruamel.yaml.YAMLError as exc:
        # The exception class lives on the ruamel.yaml module, not on the YAML() instance.
        # Report the parse error, then stop: every later cell needs `config`.
        print(exc)
        raise

In [None]:
# Resolve the CSV locations of the three splits from the 'Datasets' section of the config ...
path_to_train_data, path_to_val_data, path_to_test_data = (
    config['Datasets'][split] for split in ('train', 'val', 'test')
)

# ... and load each one into its own DataFrame.
df_train, df_val, df_test = map(pd.read_csv, (path_to_train_data, path_to_val_data, path_to_test_data))

In [None]:
print(*(frame.shape for frame in (df_train, df_val, df_test)))  # (rows, columns) of each split

(399600, 3)
(49950, 3)
(49950, 3)


In [None]:
df_train.head(n=5)  # peek at the first five training pairs

Unnamed: 0,Sequence1,Sequence2,Label
0,ATCTGGCTAATTAAT,GCTAGCTACTCTCAG,0
1,AAGGAGATGGTTGAA,GAAGGGTCGTCTCCA,0
2,AGATGCAAACGCTGC,TCGAATCTATAGACA,0
3,GCTGTCACGAACTTG,GTTTGGAGGCTGATC,1
4,CGAGTCTGTCTCGCC,ATAAAGCAGAAAAAT,1


## Classes and functions

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset that tokenizes (Sequence1, Sequence2) pairs on the fly.

    Each item is ``(token_ids, attn_masks, token_type_ids)`` plus ``label`` when ``with_labels`` is True.
    Every tensor is 1-D with length ``maxlen``: pairs are padded/truncated by the tokenizer.

    Note: the third positional parameter is ``with_labels``; pass ``bert_model`` by keyword.
    """

    def __init__(self, data, maxlen, with_labels=True, bert_model='albert-base-v2'):
        # pandas DataFrame with 'Sequence1', 'Sequence2' and, when with_labels, 'Label' columns
        self.data = data
        # Tokenizer matching the encoder; return_dict is a model-forward option, not a tokenizer one.
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model)
        self.maxlen = maxlen
        self.with_labels = with_labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # DataLoader indices are positions 0..len-1, so look rows up positionally (.iloc). .loc would treat
        # them as index labels and fail or return the wrong row for a filtered/shuffled DataFrame.
        sent1 = str(self.data['Sequence1'].iloc[index])
        sent2 = str(self.data['Sequence2'].iloc[index])

        # Tokenize the pair of sentences to get token ids, attention masks and token type ids
        encoded_pair = self.tokenizer(sent1, sent2,
                                      padding='max_length',  # Pad to max_length
                                      truncation=True,  # Truncate to max_length
                                      max_length=self.maxlen,
                                      return_tensors='pt')  # Return torch.Tensor objects

        # return_tensors='pt' adds a leading batch dimension of 1; drop it so DataLoader can batch items.
        token_ids = encoded_pair['input_ids'].squeeze(0)  # tensor of token ids
        attn_masks = encoded_pair['attention_mask'].squeeze(0)  # "0" for padded positions, "1" otherwise
        token_type_ids = encoded_pair['token_type_ids'].squeeze(0)  # "0" for 1st sentence tokens, "1" for 2nd

        if self.with_labels:  # True if the dataset has labels
            label = self.data['Label'].iloc[index]
            return token_ids, attn_masks, token_type_ids, label
        return token_ids, attn_masks, token_type_ids

In [None]:
class Model(nn.Module):
    """Pre-trained (AL)BERT encoder followed by dropout and a linear classification head.

    The head's output width is ``config['Model']['number_labels']`` and its dropout probability is
    ``config['Model']['dropout_rate']`` (module-level ``config``).
    """

    def __init__(self, bert_model="albert-base-v2", freeze_bert=False):
        super().__init__()
        self.bert_layer = AutoModel.from_pretrained(bert_model)

        # Size the head from the loaded encoder itself so it always matches the pooled output width.
        # More information: https://huggingface.co/transformers/pretrained_models.html
        hidden_size = self.bert_layer.config.hidden_size

        # Freeze bert layers and only train the classification layer weights
        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False

        # Classification layer
        self.cls_layer = nn.Linear(hidden_size, config['Model']['number_labels'])
        self.dropout = nn.Dropout(p=config['Model']['dropout_rate'])

    @autocast()  # run in mixed precision (also applies inside nn.DataParallel replicas)
    def forward(self, input_ids, attn_masks, token_type_ids):
        '''
        Inputs:
            -input_ids : Tensor containing token ids
            -attn_masks : Tensor containing attention masks to be used to focus on non-padded values
            -token_type_ids : Tensor containing token type ids to be used to identify sentence1 and sentence2
        Returns:
            logits from the classification head applied to the encoder's pooled output
        '''
        outputs = self.bert_layer(input_ids=input_ids,
                                  attention_mask=attn_masks,
                                  token_type_ids=token_type_ids)

        # Index rather than tuple-unpack. With return_dict=True (the transformers 4.x default) the
        # encoder returns a ModelOutput, and unpacking that yields its keys (strings) instead of tensors
        # (transformers issue #8879). Integer indexing works for both ModelOutput and plain tuples.
        # Element 1 is the pooled [CLS] representation: the last-layer [CLS] hidden state passed through
        # the pre-trained Linear + Tanh pooler (SOP for ALBERT, NSP for BERT).
        pooler_output = outputs[1]

        logits = self.cls_layer(self.dropout(pooler_output))

        return logits

In [None]:
#seed to make results reproducible
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) and force deterministic cuDNN.

    Args:
        seed: integer used for every generator; also exported as PYTHONHASHSEED.
    """
    # PYTHONHASHSEED only affects interpreters started after this point, not the running one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Disable cuDNN autotuning and request deterministic kernels so runs are repeatable.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    

def evaluate_loss(net, device, criterion, dataloader):
    """Return the per-sample mean loss of ``net`` over ``dataloader`` (no gradients, eval mode).

    Args:
        net: model mapping (input_ids, attn_masks, token_type_ids) to logits.
        device: device every batch is moved to.
        criterion: loss taking (logits.squeeze(-1), float labels); assumed to use 'mean' reduction,
            e.g. ``nn.BCEWithLogitsLoss()`` — TODO confirm if another reduction is ever passed.
        dataloader: yields (token_ids, attn_masks, token_type_ids, labels) batches.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    net.eval()

    total_loss = 0.0
    n_samples = 0

    with torch.no_grad():
        for seq, attn_masks, token_type_ids, labels in tqdm(dataloader):
            seq, attn_masks, token_type_ids, labels = (
                seq.to(device), attn_masks.to(device), token_type_ids.to(device), labels.to(device))
            logits = net(seq, attn_masks, token_type_ids)
            batch_size = labels.size(0)
            # forward() runs under autocast, so logits may be fp16; compute the loss in fp32.
            batch_loss = criterion(logits.squeeze(-1).float(), labels.float()).item()
            # Re-weight the batch mean by its size so a short final batch does not count as a full one.
            total_loss += batch_loss * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("evaluate_loss received an empty dataloader")
    return total_loss / n_samples

In [None]:
def _cpu_state_dict(net):
    """Return a detached CPU copy of ``net.state_dict()`` (keys unchanged, e.g. 'module.' under DataParallel)."""
    return {key: tensor.detach().cpu().clone() for key, tensor in net.state_dict().items()}


def train_bert(net, criterion, opti, lr, lr_scheduler, train_loader, val_loader, epochs, iters_to_accumulate,
               device=None):
    """Fine-tune ``net`` with mixed precision and gradient accumulation, then save the best checkpoint.

    After every epoch the validation loss is computed with ``evaluate_loss``. The weights of the epoch with
    the lowest validation loss are saved to ``model/<bert_model>_<lr>_<best_loss>.pt``. That path is
    recorded in ``config['Random']['path_to_model']``, and ``config['Random']['path_to_model_evaluation']``
    receives a matching path under ``./datasets/model_validation/``.

    Args:
        net: model mapping (input_ids, attn_masks, token_type_ids) to logits, already moved to its device.
        criterion: loss taking (logits.squeeze(-1), float labels), e.g. ``nn.BCEWithLogitsLoss()``.
        opti: optimizer, stepped through a ``GradScaler``.
        lr: learning rate; only used to build the checkpoint file name.
        lr_scheduler: scheduler stepped once per optimizer step.
        train_loader: yields (token_ids, attn_masks, token_type_ids, labels) training batches.
        val_loader: same structure, used for the per-epoch validation loss.
        epochs: number of passes over ``train_loader``.
        iters_to_accumulate: batches whose gradients are accumulated per optimizer step.
        device: device batches are moved to; defaults to the device of ``net``'s first parameter.

    Reads the module-level ``config`` and ``bert_model`` globals.
    """
    if device is None:
        device = next(net.parameters()).device

    best_loss = float('inf')  # np.Inf no longer exists in NumPy 2.0
    best_state = None
    nb_iterations = len(train_loader)
    # Print the training loss config['Random']['print_every'] times per epoch. Clamp to >= 1 so a loader
    # shorter than that count does not trigger a modulo-by-zero below.
    print_every = max(1, nb_iterations // config['Random']['print_every'])
    scaler = GradScaler()
    loss = None

    for ep in range(epochs):

        net.train()
        # Discard gradients left over from an incomplete accumulation window at the end of the last epoch.
        opti.zero_grad()
        running_loss = 0.0
        for it, (seq, attn_masks, token_type_ids, labels) in enumerate(tqdm(train_loader)):

            # Move the batch to the training device
            seq, attn_masks, token_type_ids, labels = (
                seq.to(device), attn_masks.to(device), token_type_ids.to(device), labels.to(device))

            # Enables autocasting for the forward pass (model + loss)
            with autocast():
                logits = net(seq, attn_masks, token_type_ids)
                loss = criterion(logits.squeeze(-1), labels.float())
                # Average over the accumulation window, since gradients are summed across it.
                loss = loss / iters_to_accumulate

            # Calls backward() on the scaled loss to create scaled gradients.
            scaler.scale(loss).backward()

            if (it + 1) % iters_to_accumulate == 0:
                # scaler.step() unscales the gradients and skips opti.step() if they contain infs/NaNs.
                scaler.step(opti)
                # Updates the scale for next iteration.
                scaler.update()
                # Adjust the learning rate based on the number of optimizer steps.
                lr_scheduler.step()
                opti.zero_grad()

            # Undo the accumulation normalisation so the logged value is the real per-batch loss,
            # comparable with the validation loss.
            running_loss += loss.item() * iters_to_accumulate

            if (it + 1) % print_every == 0:  # Print training loss information
                print()
                print("Iteration {}/{} of epoch {} complete. Loss : {} "
                      .format(it+1, nb_iterations, ep+1, running_loss / print_every))

                running_loss = 0.0

        val_loss = evaluate_loss(net, device, criterion, val_loader)  # Compute validation loss
        print()
        print("Epoch {} complete! Validation Loss : {}".format(ep+1, val_loss))

        if val_loss < best_loss:
            print("Best validation loss improved from {} to {}".format(best_loss, val_loss))
            print()
            # Snapshot the weights on the CPU instead of deep-copying the whole model on the GPU.
            best_state = _cpu_state_dict(net)
            best_loss = val_loss

    if best_state is None:
        # No epoch improved on inf (e.g. epochs == 0 or NaN validation losses): keep the final weights
        # rather than failing on an undefined checkpoint.
        best_state = _cpu_state_dict(net)

    # Saving the model
    path_to_model = 'model/{}_{}_{}.pt'.format(bert_model, lr, round(best_loss, 4))
    path_to_model_evaluation = './datasets/model_validation/{}_{}_{}.pt'.format(bert_model, lr, round(best_loss, 4))

    config['Random']['path_to_model'] = path_to_model
    config['Random']['path_to_model_evaluation'] = path_to_model_evaluation

    # Create the target directory so the final save cannot fail after a long training run.
    os.makedirs(os.path.dirname(path_to_model), exist_ok=True)
    torch.save(best_state, path_to_model)
    print("The model has been saved in {}".format(path_to_model))

    del loss
    torch.cuda.empty_cache()

## Hyperparameters for training

In [None]:
# Training hyperparameters, all read from the 'Model' section of the YAML config:
#   bert_model          - model id/path, e.g. 'albert-base-v2', 'albert-large-v2', 'albert-xlarge-v2', 'albert-xxlarge-v2'
#   freeze_bert         - if True, freeze the encoder weights and only update the classification layer weights
#   maxlen              - length of the tokenized sentence pair: longer inputs are truncated, shorter ones padded
#   bs                  - batch size
#   iters_to_accumulate - gradients are summed over an effective batch of bs * iters_to_accumulate ("1" = usual batch)
#   lr                  - learning rate
#   epochs              - number of training epochs
bert_model, freeze_bert, maxlen, bs, iters_to_accumulate, lr, epochs = (
    config['Model'][key]
    for key in ('bert_model', 'freeze_bert', 'max_len', 'batch_size',
                'itters_to_accumulate',  # (sic) key spelled this way in config.yaml
                'learning_rate', 'epochs')
)

## Training

In [None]:
#  Set all seeds to make reproducible results
set_seed(2)

# Creating instances of training and validation set.
# bert_model must be passed by keyword: CustomDataset's third positional parameter is with_labels, so
# passing it positionally silently fell back to the default 'albert-base-v2' tokenizer.
print("Reading training data...")
train_set = CustomDataset(df_train, maxlen, bert_model=bert_model)
print("Reading validation data...")
val_set = CustomDataset(df_val, maxlen, bert_model=bert_model)
# Creating instances of training and validation dataloaders.
# Training batches are reshuffled every epoch (seeded above); validation order stays fixed.
train_loader = DataLoader(train_set, batch_size=bs, shuffle=True, num_workers=2)
val_loader = DataLoader(val_set, batch_size=bs, num_workers=2)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = Model(bert_model, freeze_bert=freeze_bert)

if torch.cuda.device_count() > 1:  # if multiple GPUs
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    net = nn.DataParallel(net)

net.to(device)

criterion = nn.BCEWithLogitsLoss()
opti = AdamW(net.parameters(), lr=lr, weight_decay=1e-2)
num_warmup_steps = config['Model']['num_warmup_steps']  # The number of steps for the warmup phase.
num_training_steps = epochs * len(train_loader)  # Total number of training batches (not optimizer steps)
# The scheduler advances once per optimizer step, i.e. once every iters_to_accumulate batches.
t_total = (len(train_loader) // iters_to_accumulate) * epochs
lr_scheduler = get_linear_schedule_with_warmup(optimizer=opti, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)

train_bert(net, criterion, opti, lr, lr_scheduler, train_loader, val_loader, epochs, iters_to_accumulate)

# Persist the checkpoint paths that train_bert stored in config back to the file it was loaded from.
with open(config_name, 'w') as conf:
    yaml.dump(config, conf)

Reading training data...
Reading validation data...


 10%|▉         | 624/6244 [02:16<20:26,  4.58it/s]


Iteration 624/6244 of epoch 1 complete. Loss : 0.35080221448189175 


 20%|█▉        | 1248/6244 [04:32<18:08,  4.59it/s]


Iteration 1248/6244 of epoch 1 complete. Loss : 0.32330357120969355 


 30%|██▉       | 1872/6244 [06:48<15:53,  4.59it/s]


Iteration 1872/6244 of epoch 1 complete. Loss : 0.28906675062787074 


 40%|███▉      | 2496/6244 [09:04<13:38,  4.58it/s]


Iteration 2496/6244 of epoch 1 complete. Loss : 0.2688979998899576 


 50%|████▉     | 3120/6244 [11:20<11:22,  4.58it/s]


Iteration 3120/6244 of epoch 1 complete. Loss : 0.26109381220661676 


 60%|█████▉    | 3744/6244 [13:35<09:05,  4.58it/s]


Iteration 3744/6244 of epoch 1 complete. Loss : 0.25548320373472494 


 70%|██████▉   | 4368/6244 [15:51<06:48,  4.59it/s]


Iteration 4368/6244 of epoch 1 complete. Loss : 0.250910395087722 


 80%|███████▉  | 4992/6244 [18:07<04:33,  4.59it/s]


Iteration 4992/6244 of epoch 1 complete. Loss : 0.24684393826203468 


 90%|████████▉ | 5616/6244 [20:23<02:16,  4.59it/s]


Iteration 5616/6244 of epoch 1 complete. Loss : 0.24763305958073872 


100%|█████████▉| 6240/6244 [22:39<00:00,  4.59it/s]


Iteration 6240/6244 of epoch 1 complete. Loss : 0.24403590520318502 


100%|██████████| 6244/6244 [22:40<00:00,  4.59it/s]
100%|██████████| 781/781 [01:03<00:00, 12.34it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 1 complete! Validation Loss : 0.47292126644588767
Best validation loss improved from inf to 0.47292126644588767



 10%|▉         | 624/6244 [02:16<20:27,  4.58it/s]


Iteration 624/6244 of epoch 2 complete. Loss : 0.24042210947626677 


 20%|█▉        | 1248/6244 [04:32<18:13,  4.57it/s]


Iteration 1248/6244 of epoch 2 complete. Loss : 0.24256092632332674 


 30%|██▉       | 1872/6244 [06:48<15:53,  4.59it/s]


Iteration 1872/6244 of epoch 2 complete. Loss : 0.23870933163337982 


 40%|███▉      | 2496/6244 [09:04<13:37,  4.58it/s]


Iteration 2496/6244 of epoch 2 complete. Loss : 0.24040618844521353 


 50%|████▉     | 3120/6244 [11:20<11:22,  4.58it/s]


Iteration 3120/6244 of epoch 2 complete. Loss : 0.24056508067326668 


 60%|█████▉    | 3744/6244 [13:36<09:05,  4.59it/s]


Iteration 3744/6244 of epoch 2 complete. Loss : 0.23654958049360758 


 70%|██████▉   | 4368/6244 [15:52<06:50,  4.57it/s]


Iteration 4368/6244 of epoch 2 complete. Loss : 0.23522920636698985 


 80%|███████▉  | 4992/6244 [18:08<04:32,  4.59it/s]


Iteration 4992/6244 of epoch 2 complete. Loss : 0.24035428841717732 


 90%|████████▉ | 5616/6244 [20:24<02:17,  4.58it/s]


Iteration 5616/6244 of epoch 2 complete. Loss : 0.23527657073468733 


100%|█████████▉| 6240/6244 [22:40<00:00,  4.59it/s]


Iteration 6240/6244 of epoch 2 complete. Loss : 0.23125852630115473 


100%|██████████| 6244/6244 [22:41<00:00,  4.59it/s]
100%|██████████| 781/781 [01:03<00:00, 12.34it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 2 complete! Validation Loss : 0.46500902368225605
Best validation loss improved from 0.47292126644588767 to 0.46500902368225605



 10%|▉         | 624/6244 [02:16<20:26,  4.58it/s]


Iteration 624/6244 of epoch 3 complete. Loss : 0.22991691135729736 


 20%|█▉        | 1248/6244 [04:32<18:11,  4.58it/s]


Iteration 1248/6244 of epoch 3 complete. Loss : 0.2380026673229459 


 30%|██▉       | 1872/6244 [06:48<15:54,  4.58it/s]


Iteration 1872/6244 of epoch 3 complete. Loss : 0.23347007151311025 


 40%|███▉      | 2496/6244 [09:04<13:37,  4.59it/s]


Iteration 2496/6244 of epoch 3 complete. Loss : 0.2293971095663997 


 50%|████▉     | 3120/6244 [11:20<11:22,  4.57it/s]


Iteration 3120/6244 of epoch 3 complete. Loss : 0.23374312731604546 


 60%|█████▉    | 3744/6244 [13:36<09:05,  4.58it/s]


Iteration 3744/6244 of epoch 3 complete. Loss : 0.23310349673892444 


 70%|██████▉   | 4368/6244 [15:52<06:49,  4.58it/s]


Iteration 4368/6244 of epoch 3 complete. Loss : 0.23234564495774415 


 80%|███████▉  | 4992/6244 [18:08<04:33,  4.58it/s]


Iteration 4992/6244 of epoch 3 complete. Loss : 0.22553521662186354 


 90%|████████▉ | 5616/6244 [20:24<02:17,  4.58it/s]


Iteration 5616/6244 of epoch 3 complete. Loss : 0.22601943334134725 


100%|█████████▉| 6240/6244 [22:40<00:00,  4.58it/s]


Iteration 6240/6244 of epoch 3 complete. Loss : 0.23230585529922676 


100%|██████████| 6244/6244 [22:41<00:00,  4.59it/s]
100%|██████████| 781/781 [01:03<00:00, 12.34it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 3 complete! Validation Loss : 0.46212897941031317
Best validation loss improved from 0.46500902368225605 to 0.46212897941031317



 10%|▉         | 624/6244 [02:16<20:25,  4.59it/s]


Iteration 624/6244 of epoch 4 complete. Loss : 0.22797506579603904 


 20%|█▉        | 1248/6244 [04:32<18:09,  4.58it/s]


Iteration 1248/6244 of epoch 4 complete. Loss : 0.23357876712599626 


 30%|██▉       | 1872/6244 [06:48<15:53,  4.59it/s]


Iteration 1872/6244 of epoch 4 complete. Loss : 0.22679377554987484 


 40%|███▉      | 2496/6244 [09:04<13:38,  4.58it/s]


Iteration 2496/6244 of epoch 4 complete. Loss : 0.2273792577191041 


 50%|████▉     | 3120/6244 [11:20<11:21,  4.58it/s]


Iteration 3120/6244 of epoch 4 complete. Loss : 0.2233300370235856 


 60%|█████▉    | 3744/6244 [13:36<09:04,  4.59it/s]


Iteration 3744/6244 of epoch 4 complete. Loss : 0.22554876528775844 


 70%|██████▉   | 4368/6244 [15:52<06:50,  4.57it/s]


Iteration 4368/6244 of epoch 4 complete. Loss : 0.2233959078693237 


 80%|███████▉  | 4992/6244 [18:08<04:33,  4.58it/s]


Iteration 4992/6244 of epoch 4 complete. Loss : 0.22122853230207395 


 90%|████████▉ | 5616/6244 [20:24<02:16,  4.58it/s]


Iteration 5616/6244 of epoch 4 complete. Loss : 0.22467362204900918 


100%|█████████▉| 6240/6244 [22:40<00:00,  4.59it/s]


Iteration 6240/6244 of epoch 4 complete. Loss : 0.2274191616437374 


100%|██████████| 6244/6244 [22:41<00:00,  4.59it/s]
100%|██████████| 781/781 [01:03<00:00, 12.33it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 4 complete! Validation Loss : 0.4399142893038394
Best validation loss improved from 0.46212897941031317 to 0.4399142893038394



 10%|▉         | 624/6244 [02:16<20:30,  4.57it/s]


Iteration 624/6244 of epoch 5 complete. Loss : 0.22073736863258558 


 20%|█▉        | 1248/6244 [04:32<18:12,  4.57it/s]


Iteration 1248/6244 of epoch 5 complete. Loss : 0.22457358988527304 


 30%|██▉       | 1872/6244 [06:48<15:54,  4.58it/s]


Iteration 1872/6244 of epoch 5 complete. Loss : 0.2193910336981599 


 40%|███▉      | 2496/6244 [09:04<13:37,  4.58it/s]


Iteration 2496/6244 of epoch 5 complete. Loss : 0.2161115451166645 


 50%|████▉     | 3120/6244 [11:20<11:20,  4.59it/s]


Iteration 3120/6244 of epoch 5 complete. Loss : 0.22080779987841082 


 60%|█████▉    | 3744/6244 [13:36<09:05,  4.58it/s]


Iteration 3744/6244 of epoch 5 complete. Loss : 0.2210307632549069 


 70%|██████▉   | 4368/6244 [15:52<06:48,  4.59it/s]


Iteration 4368/6244 of epoch 5 complete. Loss : 0.2198781622812534 


 80%|███████▉  | 4992/6244 [18:08<04:33,  4.58it/s]


Iteration 4992/6244 of epoch 5 complete. Loss : 0.21634769167464513 


 90%|████████▉ | 5616/6244 [20:24<02:17,  4.58it/s]


Iteration 5616/6244 of epoch 5 complete. Loss : 0.21894480175792408 


100%|█████████▉| 6240/6244 [22:40<00:00,  4.59it/s]


Iteration 6240/6244 of epoch 5 complete. Loss : 0.2174558268668942 


100%|██████████| 6244/6244 [22:41<00:00,  4.59it/s]
100%|██████████| 781/781 [01:03<00:00, 12.33it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 5 complete! Validation Loss : 0.4267978705990482
Best validation loss improved from 0.4399142893038394 to 0.4267978705990482



 10%|▉         | 624/6244 [02:16<20:24,  4.59it/s]


Iteration 624/6244 of epoch 6 complete. Loss : 0.2115418163772959 


 20%|█▉        | 1248/6244 [04:32<18:15,  4.56it/s]


Iteration 1248/6244 of epoch 6 complete. Loss : 0.21501150872940436 


 30%|██▉       | 1872/6244 [06:48<15:53,  4.58it/s]


Iteration 1872/6244 of epoch 6 complete. Loss : 0.2150394042285207 


 40%|███▉      | 2496/6244 [09:04<13:40,  4.57it/s]


Iteration 2496/6244 of epoch 6 complete. Loss : 0.21322198219310778 


 50%|████▉     | 3120/6244 [11:20<11:21,  4.58it/s]


Iteration 3120/6244 of epoch 6 complete. Loss : 0.2193191765497128 


 60%|█████▉    | 3744/6244 [13:36<09:06,  4.57it/s]


Iteration 3744/6244 of epoch 6 complete. Loss : 0.21652210632768962 


 70%|██████▉   | 4368/6244 [15:52<06:48,  4.59it/s]


Iteration 4368/6244 of epoch 6 complete. Loss : 0.2157361221810182 


 80%|███████▉  | 4992/6244 [18:08<04:32,  4.59it/s]


Iteration 4992/6244 of epoch 6 complete. Loss : 0.21177112399481046 


 90%|████████▉ | 5616/6244 [20:24<02:16,  4.59it/s]


Iteration 5616/6244 of epoch 6 complete. Loss : 0.21135077596856996 


100%|█████████▉| 6240/6244 [22:40<00:00,  4.58it/s]


Iteration 6240/6244 of epoch 6 complete. Loss : 0.21227412769953027 


100%|██████████| 6244/6244 [22:41<00:00,  4.59it/s]
100%|██████████| 781/781 [01:03<00:00, 12.34it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 6 complete! Validation Loss : 0.42536892651290503
Best validation loss improved from 0.4267978705990482 to 0.42536892651290503



 10%|▉         | 624/6244 [02:16<20:27,  4.58it/s]


Iteration 624/6244 of epoch 7 complete. Loss : 0.20914013113062352 


 20%|█▉        | 1248/6244 [04:32<18:08,  4.59it/s]


Iteration 1248/6244 of epoch 7 complete. Loss : 0.21191388299354377 


 30%|██▉       | 1872/6244 [06:48<15:57,  4.57it/s]


Iteration 1872/6244 of epoch 7 complete. Loss : 0.21575215642746443 


 40%|███▉      | 2496/6244 [09:04<13:37,  4.59it/s]


Iteration 2496/6244 of epoch 7 complete. Loss : 0.21019764319778636 


 50%|████▉     | 3120/6244 [11:20<11:22,  4.58it/s]


Iteration 3120/6244 of epoch 7 complete. Loss : 0.21014796414723 


 60%|█████▉    | 3744/6244 [13:36<09:05,  4.59it/s]


Iteration 3744/6244 of epoch 7 complete. Loss : 0.21538501039433938 


 70%|██████▉   | 4368/6244 [15:52<06:49,  4.58it/s]


Iteration 4368/6244 of epoch 7 complete. Loss : 0.2110947793086943 


 80%|███████▉  | 4992/6244 [18:08<04:33,  4.57it/s]


Iteration 4992/6244 of epoch 7 complete. Loss : 0.20815757607133725 


 90%|████████▉ | 5616/6244 [20:24<02:17,  4.58it/s]


Iteration 5616/6244 of epoch 7 complete. Loss : 0.2121788443854222 


100%|█████████▉| 6240/6244 [22:40<00:00,  4.59it/s]


Iteration 6240/6244 of epoch 7 complete. Loss : 0.20889299971839556 


100%|██████████| 6244/6244 [22:41<00:00,  4.59it/s]
100%|██████████| 781/781 [01:03<00:00, 12.32it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 7 complete! Validation Loss : 0.43000888717617175


 10%|▉         | 624/6244 [02:16<20:24,  4.59it/s]


Iteration 624/6244 of epoch 8 complete. Loss : 0.20677819561499816 


 20%|█▉        | 1248/6244 [04:32<18:11,  4.58it/s]


Iteration 1248/6244 of epoch 8 complete. Loss : 0.20900335418394742 


 30%|██▉       | 1872/6244 [06:48<15:57,  4.57it/s]


Iteration 1872/6244 of epoch 8 complete. Loss : 0.20878455027317008 


 40%|███▉      | 2496/6244 [09:04<13:43,  4.55it/s]


Iteration 2496/6244 of epoch 8 complete. Loss : 0.2084426576605974 


 50%|████▉     | 3120/6244 [11:20<11:21,  4.58it/s]


Iteration 3120/6244 of epoch 8 complete. Loss : 0.21166242437007335 


 60%|█████▉    | 3744/6244 [13:36<09:05,  4.58it/s]


Iteration 3744/6244 of epoch 8 complete. Loss : 0.21127641041022846 


 70%|██████▉   | 4368/6244 [15:52<06:49,  4.58it/s]


Iteration 4368/6244 of epoch 8 complete. Loss : 0.20854688770113847 


 80%|███████▉  | 4992/6244 [18:08<04:33,  4.58it/s]


Iteration 4992/6244 of epoch 8 complete. Loss : 0.2068958341335066 


 90%|████████▉ | 5616/6244 [20:24<02:17,  4.57it/s]


Iteration 5616/6244 of epoch 8 complete. Loss : 0.20877530036541897 


100%|█████████▉| 6240/6244 [22:41<00:00,  4.58it/s]


Iteration 6240/6244 of epoch 8 complete. Loss : 0.20757556040413105 


100%|██████████| 6244/6244 [22:41<00:00,  4.58it/s]
100%|██████████| 781/781 [01:03<00:00, 12.31it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 8 complete! Validation Loss : 0.4133752287769745
Best validation loss improved from 0.42536892651290503 to 0.4133752287769745



 10%|▉         | 624/6244 [02:16<20:30,  4.57it/s]


Iteration 624/6244 of epoch 9 complete. Loss : 0.20787133154674217 


 20%|█▉        | 1248/6244 [04:32<18:10,  4.58it/s]


Iteration 1248/6244 of epoch 9 complete. Loss : 0.20893842063080043 


 30%|██▉       | 1872/6244 [06:48<15:55,  4.58it/s]


Iteration 1872/6244 of epoch 9 complete. Loss : 0.20634127950343567 


 40%|███▉      | 2496/6244 [09:04<13:37,  4.58it/s]


Iteration 2496/6244 of epoch 9 complete. Loss : 0.20462286221579865 


 50%|████▉     | 3120/6244 [11:20<11:24,  4.57it/s]


Iteration 3120/6244 of epoch 9 complete. Loss : 0.20676962107133406 


 60%|█████▉    | 3744/6244 [13:36<09:04,  4.59it/s]


Iteration 3744/6244 of epoch 9 complete. Loss : 0.20682924994243643 


 70%|██████▉   | 4368/6244 [15:53<06:50,  4.57it/s]


Iteration 4368/6244 of epoch 9 complete. Loss : 0.20741019141263303 


 80%|███████▉  | 4992/6244 [18:09<04:33,  4.57it/s]


Iteration 4992/6244 of epoch 9 complete. Loss : 0.20588213697266886 


 90%|████████▉ | 5616/6244 [20:25<02:17,  4.58it/s]


Iteration 5616/6244 of epoch 9 complete. Loss : 0.20738985664091814 


100%|█████████▉| 6240/6244 [22:42<00:00,  4.58it/s]


Iteration 6240/6244 of epoch 9 complete. Loss : 0.20593790012674454 


100%|██████████| 6244/6244 [22:42<00:00,  4.58it/s]
100%|██████████| 781/781 [01:03<00:00, 12.30it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 9 complete! Validation Loss : 0.4110010450803669
Best validation loss improved from 0.4133752287769745 to 0.4110010450803669



 10%|▉         | 624/6244 [02:16<20:29,  4.57it/s]


Iteration 624/6244 of epoch 10 complete. Loss : 0.20492011414936337 


 20%|█▉        | 1248/6244 [04:32<18:14,  4.56it/s]


Iteration 1248/6244 of epoch 10 complete. Loss : 0.20831168743853384 


 30%|██▉       | 1872/6244 [06:48<15:58,  4.56it/s]


Iteration 1872/6244 of epoch 10 complete. Loss : 0.20810194874707705 


 40%|███▉      | 2496/6244 [09:05<13:39,  4.58it/s]


Iteration 2496/6244 of epoch 10 complete. Loss : 0.20498674742590922 


 50%|████▉     | 3120/6244 [11:21<11:23,  4.57it/s]


Iteration 3120/6244 of epoch 10 complete. Loss : 0.205266795001733 


 60%|█████▉    | 3744/6244 [13:37<09:05,  4.58it/s]


Iteration 3744/6244 of epoch 10 complete. Loss : 0.20510078970199594 


 70%|██████▉   | 4368/6244 [15:53<06:50,  4.57it/s]


Iteration 4368/6244 of epoch 10 complete. Loss : 0.2054863644596667 


 80%|███████▉  | 4992/6244 [18:10<04:34,  4.57it/s]


Iteration 4992/6244 of epoch 10 complete. Loss : 0.20480434268187636 


 90%|████████▉ | 5616/6244 [20:26<02:17,  4.57it/s]


Iteration 5616/6244 of epoch 10 complete. Loss : 0.2054674023619065 


100%|█████████▉| 6240/6244 [22:42<00:00,  4.57it/s]


Iteration 6240/6244 of epoch 10 complete. Loss : 0.2062393752261041 


100%|██████████| 6244/6244 [22:43<00:00,  4.58it/s]
100%|██████████| 781/781 [01:03<00:00, 12.30it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 10 complete! Validation Loss : 0.4084064777888043
Best validation loss improved from 0.4110010450803669 to 0.4084064777888043



 10%|▉         | 624/6244 [02:16<20:26,  4.58it/s]


Iteration 624/6244 of epoch 11 complete. Loss : 0.20235474051859897 


 20%|█▉        | 1248/6244 [04:32<18:09,  4.59it/s]


Iteration 1248/6244 of epoch 11 complete. Loss : 0.20513682693051985 


 30%|██▉       | 1872/6244 [06:48<15:56,  4.57it/s]


Iteration 1872/6244 of epoch 11 complete. Loss : 0.2040458157634697 


 40%|███▉      | 2496/6244 [09:05<13:39,  4.57it/s]


Iteration 2496/6244 of epoch 11 complete. Loss : 0.20270538601116875 


 50%|████▉     | 3120/6244 [11:21<11:20,  4.59it/s]


Iteration 3120/6244 of epoch 11 complete. Loss : 0.20380448885500813 


 60%|█████▉    | 3744/6244 [13:37<09:06,  4.58it/s]


Iteration 3744/6244 of epoch 11 complete. Loss : 0.20409179414407566 


 70%|██████▉   | 4368/6244 [15:53<06:49,  4.59it/s]


Iteration 4368/6244 of epoch 11 complete. Loss : 0.20509476938213295 


 80%|███████▉  | 4992/6244 [18:09<04:32,  4.59it/s]


Iteration 4992/6244 of epoch 11 complete. Loss : 0.20419504206914169 


 90%|████████▉ | 5616/6244 [20:25<02:16,  4.59it/s]


Iteration 5616/6244 of epoch 11 complete. Loss : 0.20495869818692788 


100%|█████████▉| 6240/6244 [22:41<00:00,  4.56it/s]


Iteration 6240/6244 of epoch 11 complete. Loss : 0.20300846084809074 


100%|██████████| 6244/6244 [22:42<00:00,  4.58it/s]
100%|██████████| 781/781 [01:03<00:00, 12.33it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 11 complete! Validation Loss : 0.40756769613786176
Best validation loss improved from 0.4084064777888043 to 0.40756769613786176



 10%|▉         | 624/6244 [02:16<20:27,  4.58it/s]


Iteration 624/6244 of epoch 12 complete. Loss : 0.20040419067327792 


 20%|█▉        | 1248/6244 [04:32<18:07,  4.60it/s]


Iteration 1248/6244 of epoch 12 complete. Loss : 0.20357023128188956 


 30%|██▉       | 1872/6244 [06:48<15:57,  4.56it/s]


Iteration 1872/6244 of epoch 12 complete. Loss : 0.2032933095947672 


 40%|███▉      | 2496/6244 [09:04<13:37,  4.59it/s]


Iteration 2496/6244 of epoch 12 complete. Loss : 0.20198259526529375 


 50%|████▉     | 3120/6244 [11:20<11:21,  4.59it/s]


Iteration 3120/6244 of epoch 12 complete. Loss : 0.20274898130446672 


 60%|█████▉    | 3744/6244 [13:36<09:05,  4.58it/s]


Iteration 3744/6244 of epoch 12 complete. Loss : 0.202870585084057 


 70%|██████▉   | 4368/6244 [15:52<06:48,  4.59it/s]


Iteration 4368/6244 of epoch 12 complete. Loss : 0.20219401482683727 


 80%|███████▉  | 4992/6244 [18:08<04:34,  4.56it/s]


Iteration 4992/6244 of epoch 12 complete. Loss : 0.2019946091593458 


 90%|████████▉ | 5616/6244 [20:24<02:16,  4.59it/s]


Iteration 5616/6244 of epoch 12 complete. Loss : 0.20445706124584645 


100%|█████████▉| 6240/6244 [22:40<00:00,  4.59it/s]


Iteration 6240/6244 of epoch 12 complete. Loss : 0.20178170836506745 


100%|██████████| 6244/6244 [22:41<00:00,  4.59it/s]
100%|██████████| 781/781 [01:03<00:00, 12.33it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 12 complete! Validation Loss : 0.4033249343975558
Best validation loss improved from 0.40756769613786176 to 0.4033249343975558



 10%|▉         | 624/6244 [02:16<20:29,  4.57it/s]


Iteration 624/6244 of epoch 13 complete. Loss : 0.20003974817406672 


 20%|█▉        | 1248/6244 [04:32<18:08,  4.59it/s]


Iteration 1248/6244 of epoch 13 complete. Loss : 0.2036584410100029 


 30%|██▉       | 1872/6244 [06:48<15:56,  4.57it/s]


Iteration 1872/6244 of epoch 13 complete. Loss : 0.20283097594689864 


 40%|███▉      | 2496/6244 [09:04<13:38,  4.58it/s]


Iteration 2496/6244 of epoch 13 complete. Loss : 0.20044313002234468 


 50%|████▉     | 3120/6244 [11:20<11:23,  4.57it/s]


Iteration 3120/6244 of epoch 13 complete. Loss : 0.20153650684425464 


 60%|█████▉    | 3744/6244 [13:36<09:06,  4.58it/s]


Iteration 3744/6244 of epoch 13 complete. Loss : 0.2012887715887374 


 70%|██████▉   | 4368/6244 [15:52<06:48,  4.59it/s]


Iteration 4368/6244 of epoch 13 complete. Loss : 0.20120782793189088 


 80%|███████▉  | 4992/6244 [18:08<04:33,  4.58it/s]


Iteration 4992/6244 of epoch 13 complete. Loss : 0.20138541193535694 


 90%|████████▉ | 5616/6244 [20:24<02:17,  4.58it/s]


Iteration 5616/6244 of epoch 13 complete. Loss : 0.20175063888279673 


100%|█████████▉| 6240/6244 [22:40<00:00,  4.58it/s]


Iteration 6240/6244 of epoch 13 complete. Loss : 0.20012603470912346 


100%|██████████| 6244/6244 [22:41<00:00,  4.59it/s]
100%|██████████| 781/781 [01:03<00:00, 12.34it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 13 complete! Validation Loss : 0.40183799039386453
Best validation loss improved from 0.4033249343975558 to 0.40183799039386453



 10%|▉         | 624/6244 [02:16<20:25,  4.58it/s]


Iteration 624/6244 of epoch 14 complete. Loss : 0.1978513623993748 


 20%|█▉        | 1248/6244 [04:32<18:13,  4.57it/s]


Iteration 1248/6244 of epoch 14 complete. Loss : 0.2013009557237801 


 30%|██▉       | 1872/6244 [06:48<15:53,  4.58it/s]


Iteration 1872/6244 of epoch 14 complete. Loss : 0.2009463014964683 


 40%|███▉      | 2496/6244 [09:04<13:37,  4.58it/s]


Iteration 2496/6244 of epoch 14 complete. Loss : 0.19948910846589848 


 50%|████▉     | 3120/6244 [11:20<11:23,  4.57it/s]


Iteration 3120/6244 of epoch 14 complete. Loss : 0.20098292450301158 


 60%|█████▉    | 3744/6244 [13:36<09:04,  4.59it/s]


Iteration 3744/6244 of epoch 14 complete. Loss : 0.20133199122471687 


 70%|██████▉   | 4368/6244 [15:52<06:48,  4.59it/s]


Iteration 4368/6244 of epoch 14 complete. Loss : 0.19999141616221422 


 80%|███████▉  | 4992/6244 [18:08<04:32,  4.59it/s]


Iteration 4992/6244 of epoch 14 complete. Loss : 0.19980892362311864 


 90%|████████▉ | 5616/6244 [20:24<02:17,  4.58it/s]


Iteration 5616/6244 of epoch 14 complete. Loss : 0.20161620950183043 


100%|█████████▉| 6240/6244 [22:40<00:00,  4.58it/s]


Iteration 6240/6244 of epoch 14 complete. Loss : 0.19984222636916316 


100%|██████████| 6244/6244 [22:41<00:00,  4.59it/s]
100%|██████████| 781/781 [01:03<00:00, 12.33it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 14 complete! Validation Loss : 0.40061460334268634
Best validation loss improved from 0.40183799039386453 to 0.40061460334268634



 10%|▉         | 624/6244 [02:16<20:29,  4.57it/s]


Iteration 624/6244 of epoch 15 complete. Loss : 0.1963229088160472 


 20%|█▉        | 1248/6244 [04:32<18:09,  4.59it/s]


Iteration 1248/6244 of epoch 15 complete. Loss : 0.20019281895544666 


 30%|██▉       | 1872/6244 [06:48<15:55,  4.57it/s]


Iteration 1872/6244 of epoch 15 complete. Loss : 0.1992189368018164 


 40%|███▉      | 2496/6244 [09:04<13:36,  4.59it/s]


Iteration 2496/6244 of epoch 15 complete. Loss : 0.19798499115336782 


 50%|████▉     | 3120/6244 [11:20<11:20,  4.59it/s]


Iteration 3120/6244 of epoch 15 complete. Loss : 0.1998343568844482 


 60%|█████▉    | 3744/6244 [13:36<09:05,  4.58it/s]


Iteration 3744/6244 of epoch 15 complete. Loss : 0.20022664516447827 


 70%|██████▉   | 4368/6244 [15:52<06:49,  4.58it/s]


Iteration 4368/6244 of epoch 15 complete. Loss : 0.19840425597981382 


 80%|███████▉  | 4992/6244 [18:08<04:32,  4.59it/s]


Iteration 4992/6244 of epoch 15 complete. Loss : 0.1987865312765233 


 90%|████████▉ | 5616/6244 [20:25<02:17,  4.57it/s]


Iteration 5616/6244 of epoch 15 complete. Loss : 0.20050483908599767 


100%|█████████▉| 6240/6244 [22:41<00:00,  4.55it/s]


Iteration 6240/6244 of epoch 15 complete. Loss : 0.19827416112933022 


100%|██████████| 6244/6244 [22:42<00:00,  4.58it/s]
100%|██████████| 781/781 [01:03<00:00, 12.24it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 15 complete! Validation Loss : 0.4008210504406088


 10%|▉         | 624/6244 [02:16<20:30,  4.57it/s]


Iteration 624/6244 of epoch 16 complete. Loss : 0.19505183088282743 


 20%|█▉        | 1248/6244 [04:33<18:15,  4.56it/s]


Iteration 1248/6244 of epoch 16 complete. Loss : 0.19913204924131817 


 30%|██▉       | 1872/6244 [06:50<16:00,  4.55it/s]


Iteration 1872/6244 of epoch 16 complete. Loss : 0.1981896567874803 


 40%|███▉      | 2496/6244 [09:06<13:43,  4.55it/s]


Iteration 2496/6244 of epoch 16 complete. Loss : 0.19687525119680244 


 50%|████▉     | 3120/6244 [11:23<11:24,  4.57it/s]


Iteration 3120/6244 of epoch 16 complete. Loss : 0.19847361652705914 


 60%|█████▉    | 3744/6244 [13:40<09:08,  4.56it/s]


Iteration 3744/6244 of epoch 16 complete. Loss : 0.19843192781823185 


 70%|██████▉   | 4368/6244 [15:56<06:51,  4.56it/s]


Iteration 4368/6244 of epoch 16 complete. Loss : 0.1973344995520818 


 80%|███████▉  | 4992/6244 [18:13<04:34,  4.56it/s]


Iteration 4992/6244 of epoch 16 complete. Loss : 0.19786497850257617 


 90%|████████▉ | 5616/6244 [20:30<02:17,  4.56it/s]


Iteration 5616/6244 of epoch 16 complete. Loss : 0.19874966077697584 


100%|█████████▉| 6240/6244 [22:47<00:00,  4.54it/s]


Iteration 6240/6244 of epoch 16 complete. Loss : 0.19725151425705123 


100%|██████████| 6244/6244 [22:47<00:00,  4.56it/s]
100%|██████████| 781/781 [01:03<00:00, 12.25it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 16 complete! Validation Loss : 0.4013128274343383


 10%|▉         | 624/6244 [02:16<20:33,  4.55it/s]


Iteration 624/6244 of epoch 17 complete. Loss : 0.19455033701916152 


 20%|█▉        | 1248/6244 [04:33<18:18,  4.55it/s]


Iteration 1248/6244 of epoch 17 complete. Loss : 0.19855695134267592 


 30%|██▉       | 1872/6244 [06:50<15:56,  4.57it/s]


Iteration 1872/6244 of epoch 17 complete. Loss : 0.19695059122899786 


 40%|███▉      | 2496/6244 [09:07<13:41,  4.56it/s]


Iteration 2496/6244 of epoch 17 complete. Loss : 0.19675299262580198 


 50%|████▉     | 3120/6244 [11:23<11:27,  4.54it/s]


Iteration 3120/6244 of epoch 17 complete. Loss : 0.1977698747188044 


 60%|█████▉    | 3744/6244 [13:40<09:07,  4.56it/s]


Iteration 3744/6244 of epoch 17 complete. Loss : 0.19772311928085026 


 70%|██████▉   | 4368/6244 [15:57<06:51,  4.55it/s]


Iteration 4368/6244 of epoch 17 complete. Loss : 0.19638714521454695 


 80%|███████▉  | 4992/6244 [18:13<04:34,  4.56it/s]


Iteration 4992/6244 of epoch 17 complete. Loss : 0.196856649664159 


 90%|████████▉ | 5616/6244 [20:30<02:17,  4.57it/s]


Iteration 5616/6244 of epoch 17 complete. Loss : 0.19694695841425505 


100%|█████████▉| 6240/6244 [22:47<00:00,  4.56it/s]


Iteration 6240/6244 of epoch 17 complete. Loss : 0.19547124536564717 


100%|██████████| 6244/6244 [22:48<00:00,  4.56it/s]
100%|██████████| 781/781 [01:03<00:00, 12.25it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 17 complete! Validation Loss : 0.40417169154682475


 10%|▉         | 624/6244 [02:16<20:32,  4.56it/s]


Iteration 624/6244 of epoch 18 complete. Loss : 0.1925560164695176 


 20%|█▉        | 1248/6244 [04:33<18:14,  4.57it/s]


Iteration 1248/6244 of epoch 18 complete. Loss : 0.19652252232369322 


 30%|██▉       | 1872/6244 [06:50<16:02,  4.54it/s]


Iteration 1872/6244 of epoch 18 complete. Loss : 0.19616784406109497 


 40%|███▉      | 2496/6244 [09:06<13:42,  4.56it/s]


Iteration 2496/6244 of epoch 18 complete. Loss : 0.19476521466500485 


 50%|████▉     | 3120/6244 [11:23<11:26,  4.55it/s]


Iteration 3120/6244 of epoch 18 complete. Loss : 0.19634318535622114 


 60%|█████▉    | 3744/6244 [13:40<09:08,  4.56it/s]


Iteration 3744/6244 of epoch 18 complete. Loss : 0.19625875147250602 


 70%|██████▉   | 4368/6244 [15:56<06:50,  4.57it/s]


Iteration 4368/6244 of epoch 18 complete. Loss : 0.19443009646895987 


 80%|███████▉  | 4992/6244 [18:13<04:34,  4.55it/s]


Iteration 4992/6244 of epoch 18 complete. Loss : 0.1952331873277823 


 90%|████████▉ | 5616/6244 [20:30<02:17,  4.56it/s]


Iteration 5616/6244 of epoch 18 complete. Loss : 0.1956312846010312 


100%|█████████▉| 6240/6244 [22:46<00:00,  4.55it/s]


Iteration 6240/6244 of epoch 18 complete. Loss : 0.19423014597776225 


100%|██████████| 6244/6244 [22:47<00:00,  4.56it/s]
100%|██████████| 781/781 [01:03<00:00, 12.23it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 18 complete! Validation Loss : 0.4096767829490227


 10%|▉         | 624/6244 [02:16<20:36,  4.54it/s]


Iteration 624/6244 of epoch 19 complete. Loss : 0.19158181570804653 


 20%|█▉        | 1248/6244 [04:33<18:16,  4.55it/s]


Iteration 1248/6244 of epoch 19 complete. Loss : 0.19518156469059297 


 30%|██▉       | 1872/6244 [06:50<15:59,  4.56it/s]


Iteration 1872/6244 of epoch 19 complete. Loss : 0.19442229118580237 


 40%|███▉      | 2496/6244 [09:06<13:44,  4.54it/s]


Iteration 2496/6244 of epoch 19 complete. Loss : 0.19316489496626532 


 50%|████▉     | 3120/6244 [11:23<11:24,  4.56it/s]


Iteration 3120/6244 of epoch 19 complete. Loss : 0.1949801552348221 


 60%|█████▉    | 3744/6244 [13:40<09:07,  4.57it/s]


Iteration 3744/6244 of epoch 19 complete. Loss : 0.19508305145427585 


 70%|██████▉   | 4368/6244 [15:57<06:52,  4.54it/s]


Iteration 4368/6244 of epoch 19 complete. Loss : 0.1929661033149713 


 80%|███████▉  | 4992/6244 [18:13<04:35,  4.54it/s]


Iteration 4992/6244 of epoch 19 complete. Loss : 0.19326173105778602 


 90%|████████▉ | 5616/6244 [20:30<02:18,  4.54it/s]


Iteration 5616/6244 of epoch 19 complete. Loss : 0.1935531156233106 


100%|█████████▉| 6240/6244 [22:47<00:00,  4.55it/s]


Iteration 6240/6244 of epoch 19 complete. Loss : 0.1914455391203937 


100%|██████████| 6244/6244 [22:48<00:00,  4.56it/s]
100%|██████████| 781/781 [01:03<00:00, 12.25it/s]
  0%|          | 0/6244 [00:00<?, ?it/s]


Epoch 19 complete! Validation Loss : 0.4149071317018223


 10%|▉         | 624/6244 [02:16<20:31,  4.56it/s]


Iteration 624/6244 of epoch 20 complete. Loss : 0.1894017465484257 


 20%|█▉        | 1248/6244 [04:33<18:15,  4.56it/s]


Iteration 1248/6244 of epoch 20 complete. Loss : 0.19288152209124887 


 30%|██▉       | 1872/6244 [06:50<15:58,  4.56it/s]


Iteration 1872/6244 of epoch 20 complete. Loss : 0.1929334621064556 


 40%|███▉      | 2496/6244 [09:06<13:42,  4.56it/s]


Iteration 2496/6244 of epoch 20 complete. Loss : 0.1909847928521534 


 50%|████▉     | 3120/6244 [11:23<11:25,  4.56it/s]


Iteration 3120/6244 of epoch 20 complete. Loss : 0.19252657688533267 


 60%|█████▉    | 3744/6244 [13:40<09:07,  4.56it/s]


Iteration 3744/6244 of epoch 20 complete. Loss : 0.1929190503910948 


 70%|██████▉   | 4368/6244 [15:57<06:51,  4.56it/s]


Iteration 4368/6244 of epoch 20 complete. Loss : 0.19081927890865466 


 80%|███████▉  | 4992/6244 [18:13<04:34,  4.56it/s]


Iteration 4992/6244 of epoch 20 complete. Loss : 0.19087189014475697 


 90%|████████▉ | 5616/6244 [20:30<02:17,  4.56it/s]


Iteration 5616/6244 of epoch 20 complete. Loss : 0.1904088930489543 


100%|█████████▉| 6240/6244 [22:47<00:00,  4.55it/s]


Iteration 6240/6244 of epoch 20 complete. Loss : 0.18800144887362152 


100%|██████████| 6244/6244 [22:48<00:00,  4.56it/s]
100%|██████████| 781/781 [01:03<00:00, 12.25it/s]



Epoch 20 complete! Validation Loss : 0.4195267601408513
The model has been saved in models/albertLarge8.pt


In [None]:
#!kill -9 -1  # last resort to stop a hung run: SIGKILLs every process you own, including this notebook's kernel

## Prediction

In [None]:
def get_probs_from_logits(logits):
    """
    Map raw logits to sigmoid probabilities as a NumPy array.

    A trailing singleton axis is appended before the sigmoid is applied; the
    result is detached from the autograd graph and copied to host memory.
    """
    expanded = logits.unsqueeze(-1)
    probabilities = torch.sigmoid(expanded).detach()
    return probabilities.cpu().numpy()


def test_prediction(net, device, aptamerDataFrame, dataloader, with_labels=True, result_path=config['Random']['path_to_model_evaluation']):
    """
    Predict a probability for every row of a dataset and write the rows, with
    the predictions appended, to a CSV file.

    Parameters
    ----------
    net : torch.nn.Module
        Model called as ``net(seq, attn_masks, token_type_ids)``.
    device : torch.device
        Device each batch is moved to before the forward pass.
    aptamerDataFrame : pandas.DataFrame
        The rows the ``dataloader`` iterates over, in the same order. The output
        is this frame with ``y_hat`` and ``prob`` columns appended.
    dataloader : torch.utils.data.DataLoader
        Yields ``(seq, attn_masks, token_type_ids)`` batches; must not shuffle,
        otherwise predictions no longer line up with ``aptamerDataFrame``.
    with_labels : bool, optional
        Kept for backward compatibility. Batches are consumed the same way
        either way (the original had two identical branches). Defaults to True.
    result_path : str, optional
        Destination CSV; defaults to ``config['Random']['path_to_model_evaluation']``.

    Raises
    ------
    ValueError
        If the number of predictions differs from ``len(aptamerDataFrame)``.
    """
    net.eval()
    probs_all = []

    with torch.no_grad():
        for seq, attn_masks, token_type_ids in tqdm(dataloader, total=len(dataloader)):
            seq, attn_masks, token_type_ids = seq.to(device), attn_masks.to(device), token_type_ids.to(device)
            logits = net(seq, attn_masks, token_type_ids)
            probs = get_probs_from_logits(logits.squeeze(-1)).squeeze(-1)
            # reshape(-1) keeps a list even when squeezing leaves a 0-d array
            probs_all += probs.reshape(-1).tolist()

    if len(probs_all) != len(aptamerDataFrame):
        raise ValueError(
            "Got {} predictions for a dataframe of {} rows; pass the dataframe "
            "the dataloader was built from.".format(len(probs_all), len(aptamerDataFrame))
        )

    # round(): p > 0.5 -> 1, p <= 0.5 -> 0 (Python rounds 0.5 half-to-even)
    y_hat = [round(x) for x in probs_all]
    df2 = pd.DataFrame({'y_hat': y_hat, 'prob': probs_all})
    # concat(axis=1) aligns on index, so give the input the same 0..n-1 index as df2
    df = pd.concat([aptamerDataFrame.reset_index(drop=True), df2], axis=1)
    df.to_csv(result_path)


In [None]:
print("Reading test data...")
# Pass options by keyword: CustomDataset's third positional parameter is
# `with_labels`, so passing `bert_model` positionally left the tokenizer on the
# 'albert-base-v2' default instead of the checkpoint actually in use.
test_set = CustomDataset(df_test, maxlen, with_labels=True, bert_model=bert_model)
test_loader = DataLoader(test_set, batch_size=bs, num_workers=4)

model = Model(bert_model)
if torch.cuda.device_count() > 1:  # if multiple GPUs
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

print()
print("Loading the weights of the model...")
# map_location restores the checkpoint onto `device`, even if it was saved on another one
model.load_state_dict(torch.load(path_to_model, map_location=device))
model.to(device)

print("Predicting on test data...")
# The dataframe must be the one test_loader iterates over (df_test); joining the
# predictions to df_train attached them to the wrong rows and labels.
test_prediction(net=model
                , device=device
                , aptamerDataFrame=df_test
                , dataloader=test_loader
                , with_labels=True)

print()
print("Predictions are available in : {}".format(config['Random']['path_to_model_evaluation']))

Reading test data...


  cpuset_checked))
  0%|          | 0/781 [00:00<?, ?it/s]


Loading the weights of the model...
Predicting on test data...


100%|██████████| 781/781 [01:03<00:00, 12.25it/s]



Predictions are available in : results/output.txt


#  Model Metrics

In [None]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix

In [None]:
# Reload the predictions written by test_prediction and score them against the labels.
df = pd.read_csv(config['Random']['path_to_model_evaluation'])
labels, predicted, scores = df['Label'], df['y_hat'], df['prob']

accuracy = accuracy_score(labels, predicted)
precision = precision_score(labels, predicted)
recall = recall_score(labels, predicted)
f1 = f1_score(labels, predicted)
roc_auc = roc_auc_score(labels, scores)  # threshold-free: uses the raw probabilities
matrix = confusion_matrix(labels, predicted)

for caption, value in (('accuracy: ', accuracy),
                       ('precision: ', precision),
                       ('recall: ', recall),
                       ('f1: ', f1),
                       ('roc_auc: ', roc_auc)):
    print(caption, value)
print('matrix: \n', matrix)

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import auc, roc_curve

# ROC curve of the predicted probabilities against the true labels.
fpr, tpr, threshold = roc_curve(df['Label'], df['prob'])
roc_auc = auc(fpr, tpr)

# Show the area under the curve in the legend (it was computed but never displayed).
plt.plot(fpr, tpr, label='ROC curve (AUC = {:.3f})'.format(roc_auc))
plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
plt.xlabel('False Positive Rate or (1 - Specificity)')
plt.ylabel('True Positive Rate or (Sensitivity)')
plt.legend(loc='lower right')

#  Convert model to ONNX


In [None]:
if config['Model']['convert_to_onnx']:
    # One dummy pair is enough to trace the graph; the batch axis is declared
    # dynamic below, so the exported model accepts any batch size.
    sequences = {'AGCTAAGCTAAGCTA': 'AAGACTGACAGCTAA'}
    sequences = pd.DataFrame(list(sequences.items()), columns=['Sequence1', 'Sequence2'])
    dataset = CustomDataset(
        data=sequences,
        maxlen=maxlen,
        with_labels=False,  # the dummy pair has no Label column
        bert_model=bert_model,  # tokenize with the same checkpoint the model uses
        )
    model = Model(bert_model)

    state_dict = torch.load(path_to_model, map_location=device)
    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # 'module.'; strip it so they load into the bare model that gets exported.
    state_dict = {
        (key[len('module.'):] if key.startswith('module.') else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    # Model inputs: token ids, attention mask and token type ids (as in the Model class),
    # each given a leading batch axis of size 1 and placed on the model's device.
    sample = dataset[0]
    input_ids = sample['token_ids'].unsqueeze(0).to(device)
    attn_masks = sample['attn_masks'].unsqueeze(0).to(device)
    token_type_ids = sample['token_type_ids'].unsqueeze(0).to(device)

    torch.onnx.export(
        model,
        (input_ids, attn_masks, token_type_ids),
        "model.onnx",
        input_names=["input_ids", "attn_masks", "token_type_ids"],
        output_names=["output"],
        verbose=True,
        # Which inputs/outputs have dynamic axes: only the batch axis (0). Other axes keep
        # their traced size (presumably maxlen, if CustomDataset pads to it — confirm).
        dynamic_axes={
            "input_ids": {0: "batch_size"},
            "attn_masks": {0: "batch_size"},
            "token_type_ids": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
        opset_version=10,
        )