In [1]:
!pip install -q transformers



In [2]:
import torch
import time
import torch.nn as nn
import os
import matplotlib.pyplot as plt
import copy
import datetime
import shutil
import torch.optim as optim
import random
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset, TensorDataset, RandomSampler, SequentialSampler
from torch.cuda.amp import autocast, GradScaler
from torch import nn
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup, BertForSequenceClassification, BertConfig
from datasets import load_dataset, load_metric
from sklearn.model_selection import train_test_split
from sklearn.metrics import matthews_corrcoef

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
# Check that we are using 100% of GPU memory footprint support libraries/code
# from https://github.com/patrickvonplaten/notebooks/blob/master/PyTorch_Reformer.ipynb
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
!pip -q install gputil
!pip -q install psutil
!pip -q install humanize
import psutil
import humanize
import os
import GPUtil as GPU
GPUs = GPU.getGPUs()
# XXX: only one GPU on Colab and isn’t guaranteed
gpu = GPUs[0]
def printm():
 process = psutil.Process(os.getpid())
 print("Gen RAM Free: " + humanize.naturalsize( psutil.virtual_memory().available ), " | Proc size: " + humanize.naturalsize( process.memory_info().rss))
 print("GPU RAM Free: {0:.0f}MB | Used: {1:.0f}MB | Util {2:3.0f}% | Total {3:.0f}MB".format(gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil*100, gpu.memoryTotal))
printm()

Gen RAM Free: 15.7 GB  | Proc size: 687.3 MB
GPU RAM Free: 16280MB | Used: 0MB | Util   0% | Total 16280MB


In [4]:
df = pd.read_csv('../input/full-in-house-dataset/Final_In_house_dataset.csv')
df.sample(10)

Unnamed: 0,Context,Acronym,Target,Definition,Label
58376,2 Terni 04/24/19 4^3 06:39/00:18 3.955 kg F Va...,PE,Pulmonary embolism,Pulmonary embolism is a blockage in one of the...,False
37943,General/bilateral: • MS: • MS: Right lower ext...,MS,Musculoskeletal system,Your musculoskeletal system includes your bone...,True
91817,150.1- 428.1 Left VF I,VF,Ventricular fibrillation,Ventricular fibrillation is a type of abnormal...,False
40558,Intractable PE (CMS-hcc) [G40.119],PE,Pulmonary emphysema,Pulmonary emphysema is a chronic lung conditio...,False
11755,Pt continued to have CP which required her to ...,CP,Chest pain,"Discomfort in the chest including a dull ache,...",True
81934,"Awaiting imaging, unclear etiology of symptoms...",SCT,Spinocerebellar tract,The spinocerebellar tracts are afferent neuron...,False
68227,PVD (PVD) (CMS/HCC) I73.9,PVD,Primary affective disorders,The main types of affective disorders are depr...,False
53730,• PE - Annual physical exam e Retinopathy due ...,PE,Pleural effusion,"Pleural effusion, sometimes referred to as “wa...",False
46565,"1C] overweight(CLAMI, breast cancerCL1c) s/p b...",PE,Physical examination,"The routine physical, also known as general me...",True
97372,He was discharged with Foley catheter in place...,VT,Ventricular tachycardia,A condition in which the lower chambers of the...,False


In [5]:
df_train, df_test = train_test_split(df, test_size = 0.3, stratify = df['Label'], random_state = 11)

In [6]:
df_train.reset_index(inplace = True)
df_test.reset_index(inplace = True)

In [7]:
class CustomDataset(Dataset):

    def __init__(self, data, maxlen, with_labels=True, bert_model='microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'):

        self.data = data  # pandas dataframe
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model, do_lower_case=True)  
        self.sent1 = self.data['Context'].values
        self.sent2 = self.data['Definition'].values
        self.maxlen = maxlen
        self.with_labels = with_labels 

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):

        # Selecting sentence1 and sentence2 at the specified index in the data frame
        sent1 = str(self.sent1[index])
        sent2 = str(self.sent2[index])

        # Tokenize the pair of sentences to get token ids, attention masks and token type ids
        encoded_pair = self.tokenizer(sent1, sent2, 
                                      padding='max_length',  # Pad to max_length
                                      truncation=True,  # Truncate to max_length
                                      max_length=self.maxlen,  
                                      return_tensors='pt')  # Return torch.Tensor objects
        
        token_ids = encoded_pair['input_ids'].squeeze(0)  # tensor of token ids
        attn_masks = encoded_pair['attention_mask'].squeeze(0)  # binary tensor with "0" for padded values and "1" for the other values
        #token_type_ids = encoded_pair['token_type_ids'].squeeze(0)  # binary tensor with "0" for the 1st sentence tokens & "1" for the 2nd sentence tokens

        if self.with_labels:  # True if the dataset has labels
            label = self.data.loc[index, 'Label']
            return token_ids, attn_masks,  int(label)
        else:
            return token_ids, attn_masks
    

In [8]:
def get_data_loaders(batch_size, train_data, validation_data):
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size, num_workers=5)
    validation_sampler = SequentialSampler(validation_data)
    validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size, num_workers=5)
    return train_dataloader, validation_dataloader

In [9]:
def set_seed(seed):
    """ Set all seeds to make results reproducible """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    

def evaluate(net, device, criterion, dataloader):
    print("")
    print("Running Validation...")
    t0 = time.time()
    # Put the model in evaluation mode--the dropout layers behave differently
    # during evaluation.
    net.eval()
    # Tracking variables 
    eval_loss, eval_accuracy = 0, 0
    valid_loss = 0.0
    nb_eval_steps, nb_eval_examples = 0, 0
    # Evaluate data for one epoch
    for batch in tqdm(dataloader):
        
        # Add batch to GPU
        batch = tuple(t.to(device) for t in batch)
        
        # Unpack the inputs from our dataloader
        b_input_ids, b_input_mask, b_labels = batch
        
        # Telling the model not to compute or store gradients, saving memory and
        # speeding up validation
        with torch.no_grad():        
            outputs = net(b_input_ids, 
                            token_type_ids=None, 
                            attention_mask=b_input_mask)
        
        # Get the "logits" output by the model. The "logits" are the output
        # values prior to applying an activation function like the softmax.
        logits = outputs[0]
        loss = criterion(logits, b_labels) #Criterion is crossentropy loss
        valid_loss += loss.item()
        # Move logits and labels to CPU
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
      
        # Calculate the accuracy for this batch of test sentences.
        tmp_eval_accuracy = flat_accuracy(logits, label_ids)
        
        # Accumulate the total accuracy.
        eval_accuracy += tmp_eval_accuracy
        # Track the number of batches
        nb_eval_steps += 1
    # Report the final accuracy for this validation run.
    avg_valid_loss = valid_loss / len(dataloader) 
    avg_eval_accuracy = eval_accuracy/nb_eval_steps
    return avg_valid_loss, avg_eval_accuracy

In [10]:
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)


def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed)))
    
    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds = elapsed_rounded))


def save_ckp(state, is_best, checkpoint_path, best_model_path):
    f_path = checkpoint_path
    torch.save(state, f_path)
    if is_best:
        best_fpath = best_model_path
        shutil.copyfile(f_path, best_fpath)


def train(model, epochs, train_dataloader, validation_dataloader, optimizer, scheduler, device, criterion, 
                                                            checkpoint_path, best_model_path):
    loss_values = []
    valid_loss_min = np.inf
    for epoch_i in range(0, epochs):
        
        # ========================================
        #               Training
        # ========================================
        
        # Perform one full pass over the training set.
        print("")
        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
        print('Training...')
        # Measure how long the training epoch takes.
        t0 = time.time()
        # Reset the total loss for this epoch.
        total_loss = 0.0
        valid_loss = 0.0
        # Put the model into training mode.
        model.train()
        # For each batch of training data...
        for step, batch in enumerate(tqdm(train_dataloader)):
            # Progress update every 40 batches.
            if step % 40 == 0 and not step == 0:
                # Calculate elapsed time in minutes.
                elapsed = format_time(time.time() - t0)
                
                # Report progress.
                print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))
            # Unpack this training batch from our dataloader. 
            # `batch` contains three pytorch tensors:
            #   [0]: input ids 
            #   [1]: attention masks
            #   [2]: labels 
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[2].to(device)

            model.zero_grad()        
            # Perform a forward pass (evaluate the model on this training batch).
            # This will return the loss (rather than the model output) because we have provided the labels.
            outputs = model(b_input_ids, 
                        token_type_ids=None, 
                        attention_mask=b_input_mask, 
                        labels=b_labels)
            

            loss = outputs[0]
            total_loss += loss.item()
            # Perform a backward pass to calculate the gradients.
            loss.backward()
            # Clip the norm of the gradients to 1.0.
            # This is to help prevent the "exploding gradients" problem.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # Update parameters and take a step using the computed gradient.
            optimizer.step()
            # Update the learning rate.
            scheduler.step()
        # Calculate the average loss over the training data.
        avg_train_loss = total_loss / len(train_dataloader)            
        
        # Store the loss value for plotting the learning curve.
        loss_values.append(avg_train_loss)
        print("")
        print("  Average training loss: {0:.2f}".format(avg_train_loss))
        print("  Training epcoh took: {:}".format(format_time(time.time() - t0)))
            
        # ========================================
        #               Validation
        # ========================================
        # After the completion of each training epoch, measure our performance on
        # our validation set.

        avg_valid_loss, avg_eval_accuracy = evaluate(model, device, criterion, validation_dataloader)
        print("  Accuracy: {0:.2f}".format(avg_eval_accuracy))
        print("  Average validation loss: {0:.2f}".format(avg_valid_loss))
        print("  Validation took: {:}".format(format_time(time.time() - t0)))
        checkpoint = {
                'epoch': epoch_i + 1,
                'valid_loss_min': avg_valid_loss,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            
        save_ckp(checkpoint, False, checkpoint_path, best_model_path)
        if avg_valid_loss <= valid_loss_min:
                print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min, avg_valid_loss))
                save_ckp(checkpoint, True, checkpoint_path, best_model_path)
                valid_loss_min = avg_valid_loss
        
        
    print("")
    print("Training complete!")
    return model

In [11]:
model = BertForSequenceClassification.from_pretrained(
  'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext',
    num_labels = 2,  
    output_attentions = False, 
    output_hidden_states = False,
)

if torch.cuda.is_available():    
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")
    
model.to(device)

Downloading:   0%|          | 0.00/385 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Ber

There are 1 GPU(s) available.
We will use the GPU: Tesla P100-PCIE-16GB


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [12]:
train_data = CustomDataset(df_train, 512, model)
validation_data = CustomDataset(df_test, 512, model)
train_dataloader, validation_dataloader = get_data_loaders(16, train_data, validation_data)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/221k [00:00<?, ?B/s]

  cpuset_checked))


In [13]:
# Get all of the model's parameters as a list of tuples.
params = list(model.named_parameters())
print('The BERT model has {:} different named parameters.\n'.format(len(params)))
print('==== Embedding Layer ====\n')
for p in params[0:5]:
    print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))
print('\n==== First Transformer ====\n')
for p in params[5:21]:
    print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))
print('\n==== Output Layer ====\n')
for p in params[-4:]:
    print("{:<55} {:>12}".format(p[0], str(tuple(p[1].size()))))

params = list(model.named_parameters())
for name, param in params[-5:]:
    param.requires_grad = True
for name, param in params[:-6]:
    param.requires_grad = True

optimizer = AdamW(model.parameters(),
                  lr = 2e-5, 
                  eps = 1e-8 
                )

epochs = 3
total_steps = len(train_dataloader) * epochs
# Create the learning rate scheduler.
scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps = 0,
                                            num_training_steps = total_steps)
criterion = nn.CrossEntropyLoss()

The BERT model has 201 different named parameters.

==== Embedding Layer ====

bert.embeddings.word_embeddings.weight                  (30522, 768)
bert.embeddings.position_embeddings.weight                (512, 768)
bert.embeddings.token_type_embeddings.weight                (2, 768)
bert.embeddings.LayerNorm.weight                              (768,)
bert.embeddings.LayerNorm.bias                                (768,)

==== First Transformer ====

bert.encoder.layer.0.attention.self.query.weight          (768, 768)
bert.encoder.layer.0.attention.self.query.bias                (768,)
bert.encoder.layer.0.attention.self.key.weight            (768, 768)
bert.encoder.layer.0.attention.self.key.bias                  (768,)
bert.encoder.layer.0.attention.self.value.weight          (768, 768)
bert.encoder.layer.0.attention.self.value.bias                (768,)
bert.encoder.layer.0.attention.output.dense.weight        (768, 768)
bert.encoder.layer.0.attention.output.dense.bias              (

In [14]:
#  Set all seeds to make reproducible results
set_seed(1)
%mkdir checkpoint best_model
checkpoint_path = "./checkpoint/current_checkpoint.pt"
best_model_path = "./best_model/best_model_exp_3_pubmedbert_inhouse.pt"
model = train(model, epochs, train_dataloader, validation_dataloader, optimizer, scheduler, device, criterion, 
                                                            checkpoint_path, best_model_path)


Training...


  1%|          | 40/4392 [00:35<1:02:09,  1.17it/s]

  Batch    40  of  4,392.    Elapsed: 0:00:36.


  2%|▏         | 80/4392 [01:10<1:01:37,  1.17it/s]

  Batch    80  of  4,392.    Elapsed: 0:01:10.


  3%|▎         | 120/4392 [01:44<1:01:04,  1.17it/s]

  Batch   120  of  4,392.    Elapsed: 0:01:45.


  4%|▎         | 160/4392 [02:18<1:00:57,  1.16it/s]

  Batch   160  of  4,392.    Elapsed: 0:02:19.


  5%|▍         | 200/4392 [02:53<59:58,  1.16it/s]

  Batch   200  of  4,392.    Elapsed: 0:02:53.


  5%|▌         | 240/4392 [03:27<59:28,  1.16it/s]

  Batch   240  of  4,392.    Elapsed: 0:03:28.


  6%|▋         | 280/4392 [04:02<58:52,  1.16it/s]

  Batch   280  of  4,392.    Elapsed: 0:04:02.


  7%|▋         | 320/4392 [04:36<58:09,  1.17it/s]

  Batch   320  of  4,392.    Elapsed: 0:04:36.


  8%|▊         | 360/4392 [05:10<57:34,  1.17it/s]

  Batch   360  of  4,392.    Elapsed: 0:05:11.


  9%|▉         | 400/4392 [05:45<57:04,  1.17it/s]

  Batch   400  of  4,392.    Elapsed: 0:05:45.


 10%|█         | 440/4392 [06:19<56:24,  1.17it/s]

  Batch   440  of  4,392.    Elapsed: 0:06:20.


 11%|█         | 480/4392 [06:54<56:29,  1.15it/s]

  Batch   480  of  4,392.    Elapsed: 0:06:54.


 12%|█▏        | 520/4392 [07:28<55:50,  1.16it/s]

  Batch   520  of  4,392.    Elapsed: 0:07:29.


 13%|█▎        | 560/4392 [08:02<54:49,  1.16it/s]

  Batch   560  of  4,392.    Elapsed: 0:08:03.


 14%|█▎        | 600/4392 [08:37<54:18,  1.16it/s]

  Batch   600  of  4,392.    Elapsed: 0:08:37.


 15%|█▍        | 640/4392 [09:11<53:53,  1.16it/s]

  Batch   640  of  4,392.    Elapsed: 0:09:12.


 15%|█▌        | 680/4392 [09:46<53:11,  1.16it/s]

  Batch   680  of  4,392.    Elapsed: 0:09:46.


 16%|█▋        | 720/4392 [10:20<52:27,  1.17it/s]

  Batch   720  of  4,392.    Elapsed: 0:10:20.


 17%|█▋        | 760/4392 [10:54<51:47,  1.17it/s]

  Batch   760  of  4,392.    Elapsed: 0:10:55.


 18%|█▊        | 800/4392 [11:29<51:36,  1.16it/s]

  Batch   800  of  4,392.    Elapsed: 0:11:29.


 19%|█▉        | 840/4392 [12:03<50:50,  1.16it/s]

  Batch   840  of  4,392.    Elapsed: 0:12:04.


 20%|██        | 880/4392 [12:38<50:14,  1.16it/s]

  Batch   880  of  4,392.    Elapsed: 0:12:38.


 21%|██        | 920/4392 [13:12<49:47,  1.16it/s]

  Batch   920  of  4,392.    Elapsed: 0:13:12.


 22%|██▏       | 960/4392 [13:46<49:06,  1.16it/s]

  Batch   960  of  4,392.    Elapsed: 0:13:47.


 23%|██▎       | 1000/4392 [14:21<48:30,  1.17it/s]

  Batch 1,000  of  4,392.    Elapsed: 0:14:21.


 24%|██▎       | 1040/4392 [14:55<48:03,  1.16it/s]

  Batch 1,040  of  4,392.    Elapsed: 0:14:56.


 25%|██▍       | 1080/4392 [15:30<47:25,  1.16it/s]

  Batch 1,080  of  4,392.    Elapsed: 0:15:30.


 26%|██▌       | 1120/4392 [16:04<47:37,  1.14it/s]

  Batch 1,120  of  4,392.    Elapsed: 0:16:05.


 26%|██▋       | 1160/4392 [16:39<46:34,  1.16it/s]

  Batch 1,160  of  4,392.    Elapsed: 0:16:39.


 27%|██▋       | 1200/4392 [17:13<45:54,  1.16it/s]

  Batch 1,200  of  4,392.    Elapsed: 0:17:14.


 28%|██▊       | 1240/4392 [17:47<45:15,  1.16it/s]

  Batch 1,240  of  4,392.    Elapsed: 0:17:48.


 29%|██▉       | 1280/4392 [18:22<44:37,  1.16it/s]

  Batch 1,280  of  4,392.    Elapsed: 0:18:22.


 30%|███       | 1320/4392 [18:56<43:55,  1.17it/s]

  Batch 1,320  of  4,392.    Elapsed: 0:18:57.


 31%|███       | 1360/4392 [19:31<43:24,  1.16it/s]

  Batch 1,360  of  4,392.    Elapsed: 0:19:31.


 32%|███▏      | 1400/4392 [20:05<42:45,  1.17it/s]

  Batch 1,400  of  4,392.    Elapsed: 0:20:06.


 33%|███▎      | 1440/4392 [20:40<42:11,  1.17it/s]

  Batch 1,440  of  4,392.    Elapsed: 0:20:40.


 34%|███▎      | 1480/4392 [21:14<41:54,  1.16it/s]

  Batch 1,480  of  4,392.    Elapsed: 0:21:15.


 35%|███▍      | 1520/4392 [21:49<41:04,  1.17it/s]

  Batch 1,520  of  4,392.    Elapsed: 0:21:49.


 36%|███▌      | 1560/4392 [22:23<40:27,  1.17it/s]

  Batch 1,560  of  4,392.    Elapsed: 0:22:23.


 36%|███▋      | 1600/4392 [22:57<39:49,  1.17it/s]

  Batch 1,600  of  4,392.    Elapsed: 0:22:58.


 37%|███▋      | 1640/4392 [23:32<39:30,  1.16it/s]

  Batch 1,640  of  4,392.    Elapsed: 0:23:32.


 38%|███▊      | 1680/4392 [24:06<38:51,  1.16it/s]

  Batch 1,680  of  4,392.    Elapsed: 0:24:07.


 39%|███▉      | 1720/4392 [24:41<38:29,  1.16it/s]

  Batch 1,720  of  4,392.    Elapsed: 0:24:41.


 40%|████      | 1760/4392 [25:15<37:42,  1.16it/s]

  Batch 1,760  of  4,392.    Elapsed: 0:25:16.


 41%|████      | 1800/4392 [25:50<37:11,  1.16it/s]

  Batch 1,800  of  4,392.    Elapsed: 0:25:50.


 42%|████▏     | 1840/4392 [26:24<36:31,  1.16it/s]

  Batch 1,840  of  4,392.    Elapsed: 0:26:25.


 43%|████▎     | 1880/4392 [26:58<35:54,  1.17it/s]

  Batch 1,880  of  4,392.    Elapsed: 0:26:59.


 44%|████▎     | 1920/4392 [27:33<35:24,  1.16it/s]

  Batch 1,920  of  4,392.    Elapsed: 0:27:33.


 45%|████▍     | 1960/4392 [28:07<34:58,  1.16it/s]

  Batch 1,960  of  4,392.    Elapsed: 0:28:08.


 46%|████▌     | 2000/4392 [28:42<34:19,  1.16it/s]

  Batch 2,000  of  4,392.    Elapsed: 0:28:42.


 46%|████▋     | 2040/4392 [29:16<34:01,  1.15it/s]

  Batch 2,040  of  4,392.    Elapsed: 0:29:17.


 47%|████▋     | 2080/4392 [29:51<33:11,  1.16it/s]

  Batch 2,080  of  4,392.    Elapsed: 0:29:51.


 48%|████▊     | 2120/4392 [30:25<32:44,  1.16it/s]

  Batch 2,120  of  4,392.    Elapsed: 0:30:26.


 49%|████▉     | 2160/4392 [31:00<32:07,  1.16it/s]

  Batch 2,160  of  4,392.    Elapsed: 0:31:00.


 50%|█████     | 2200/4392 [31:34<31:19,  1.17it/s]

  Batch 2,200  of  4,392.    Elapsed: 0:31:35.


 51%|█████     | 2240/4392 [32:08<30:48,  1.16it/s]

  Batch 2,240  of  4,392.    Elapsed: 0:32:09.


 52%|█████▏    | 2280/4392 [32:43<30:08,  1.17it/s]

  Batch 2,280  of  4,392.    Elapsed: 0:32:43.


 53%|█████▎    | 2320/4392 [33:17<29:48,  1.16it/s]

  Batch 2,320  of  4,392.    Elapsed: 0:33:18.


 54%|█████▎    | 2360/4392 [33:52<29:03,  1.17it/s]

  Batch 2,360  of  4,392.    Elapsed: 0:33:52.


 55%|█████▍    | 2400/4392 [34:26<28:40,  1.16it/s]

  Batch 2,400  of  4,392.    Elapsed: 0:34:27.


 56%|█████▌    | 2440/4392 [35:01<28:02,  1.16it/s]

  Batch 2,440  of  4,392.    Elapsed: 0:35:01.


 56%|█████▋    | 2480/4392 [35:35<27:20,  1.17it/s]

  Batch 2,480  of  4,392.    Elapsed: 0:35:36.


 57%|█████▋    | 2520/4392 [36:10<26:45,  1.17it/s]

  Batch 2,520  of  4,392.    Elapsed: 0:36:10.


 58%|█████▊    | 2560/4392 [36:44<26:26,  1.15it/s]

  Batch 2,560  of  4,392.    Elapsed: 0:36:45.


 59%|█████▉    | 2600/4392 [37:19<25:41,  1.16it/s]

  Batch 2,600  of  4,392.    Elapsed: 0:37:19.


 60%|██████    | 2640/4392 [37:53<25:09,  1.16it/s]

  Batch 2,640  of  4,392.    Elapsed: 0:37:54.


 61%|██████    | 2680/4392 [38:28<24:31,  1.16it/s]

  Batch 2,680  of  4,392.    Elapsed: 0:38:28.


 62%|██████▏   | 2720/4392 [39:02<23:57,  1.16it/s]

  Batch 2,720  of  4,392.    Elapsed: 0:39:03.


 63%|██████▎   | 2760/4392 [39:37<23:22,  1.16it/s]

  Batch 2,760  of  4,392.    Elapsed: 0:39:37.


 64%|██████▍   | 2800/4392 [40:11<22:50,  1.16it/s]

  Batch 2,800  of  4,392.    Elapsed: 0:40:12.


 65%|██████▍   | 2840/4392 [40:46<22:13,  1.16it/s]

  Batch 2,840  of  4,392.    Elapsed: 0:40:46.


 66%|██████▌   | 2880/4392 [41:20<21:49,  1.15it/s]

  Batch 2,880  of  4,392.    Elapsed: 0:41:21.


 66%|██████▋   | 2920/4392 [41:55<21:19,  1.15it/s]

  Batch 2,920  of  4,392.    Elapsed: 0:41:55.


 67%|██████▋   | 2960/4392 [42:29<20:41,  1.15it/s]

  Batch 2,960  of  4,392.    Elapsed: 0:42:30.


 68%|██████▊   | 3000/4392 [43:04<19:57,  1.16it/s]

  Batch 3,000  of  4,392.    Elapsed: 0:43:04.


 69%|██████▉   | 3040/4392 [43:38<19:34,  1.15it/s]

  Batch 3,040  of  4,392.    Elapsed: 0:43:39.


 70%|███████   | 3080/4392 [44:13<18:55,  1.16it/s]

  Batch 3,080  of  4,392.    Elapsed: 0:44:13.


 71%|███████   | 3120/4392 [44:47<18:14,  1.16it/s]

  Batch 3,120  of  4,392.    Elapsed: 0:44:48.


 72%|███████▏  | 3160/4392 [45:22<17:36,  1.17it/s]

  Batch 3,160  of  4,392.    Elapsed: 0:45:22.


 73%|███████▎  | 3200/4392 [45:56<17:03,  1.17it/s]

  Batch 3,200  of  4,392.    Elapsed: 0:45:56.


 74%|███████▍  | 3240/4392 [46:30<16:37,  1.16it/s]

  Batch 3,240  of  4,392.    Elapsed: 0:46:31.


 75%|███████▍  | 3280/4392 [47:05<15:58,  1.16it/s]

  Batch 3,280  of  4,392.    Elapsed: 0:47:05.


 76%|███████▌  | 3320/4392 [47:39<15:22,  1.16it/s]

  Batch 3,320  of  4,392.    Elapsed: 0:47:40.


 77%|███████▋  | 3360/4392 [48:14<14:46,  1.16it/s]

  Batch 3,360  of  4,392.    Elapsed: 0:48:14.


 77%|███████▋  | 3400/4392 [48:48<14:13,  1.16it/s]

  Batch 3,400  of  4,392.    Elapsed: 0:48:49.


 78%|███████▊  | 3440/4392 [49:23<13:37,  1.16it/s]

  Batch 3,440  of  4,392.    Elapsed: 0:49:24.


 79%|███████▉  | 3480/4392 [49:57<13:05,  1.16it/s]

  Batch 3,480  of  4,392.    Elapsed: 0:49:58.


 80%|████████  | 3520/4392 [50:32<12:29,  1.16it/s]

  Batch 3,520  of  4,392.    Elapsed: 0:50:32.


 81%|████████  | 3560/4392 [51:06<12:04,  1.15it/s]

  Batch 3,560  of  4,392.    Elapsed: 0:51:07.


 82%|████████▏ | 3600/4392 [51:41<11:23,  1.16it/s]

  Batch 3,600  of  4,392.    Elapsed: 0:51:42.


 83%|████████▎ | 3640/4392 [52:16<10:51,  1.15it/s]

  Batch 3,640  of  4,392.    Elapsed: 0:52:16.


 84%|████████▍ | 3680/4392 [52:50<10:11,  1.16it/s]

  Batch 3,680  of  4,392.    Elapsed: 0:52:51.


 85%|████████▍ | 3720/4392 [53:25<09:38,  1.16it/s]

  Batch 3,720  of  4,392.    Elapsed: 0:53:25.


 86%|████████▌ | 3760/4392 [53:59<09:02,  1.17it/s]

  Batch 3,760  of  4,392.    Elapsed: 0:54:00.


 87%|████████▋ | 3800/4392 [54:34<08:31,  1.16it/s]

  Batch 3,800  of  4,392.    Elapsed: 0:54:34.


 87%|████████▋ | 3840/4392 [55:08<07:56,  1.16it/s]

  Batch 3,840  of  4,392.    Elapsed: 0:55:09.


 88%|████████▊ | 3880/4392 [55:43<07:19,  1.17it/s]

  Batch 3,880  of  4,392.    Elapsed: 0:55:43.


 89%|████████▉ | 3920/4392 [56:17<06:48,  1.15it/s]

  Batch 3,920  of  4,392.    Elapsed: 0:56:18.


 90%|█████████ | 3960/4392 [56:52<06:13,  1.16it/s]

  Batch 3,960  of  4,392.    Elapsed: 0:56:52.


 91%|█████████ | 4000/4392 [57:26<05:37,  1.16it/s]

  Batch 4,000  of  4,392.    Elapsed: 0:57:27.


 92%|█████████▏| 4040/4392 [58:01<05:03,  1.16it/s]

  Batch 4,040  of  4,392.    Elapsed: 0:58:01.


 93%|█████████▎| 4080/4392 [58:35<04:31,  1.15it/s]

  Batch 4,080  of  4,392.    Elapsed: 0:58:36.


 94%|█████████▍| 4120/4392 [59:10<03:57,  1.15it/s]

  Batch 4,120  of  4,392.    Elapsed: 0:59:10.


 95%|█████████▍| 4160/4392 [59:44<03:20,  1.16it/s]

  Batch 4,160  of  4,392.    Elapsed: 0:59:45.


 96%|█████████▌| 4200/4392 [1:00:19<02:45,  1.16it/s]

  Batch 4,200  of  4,392.    Elapsed: 1:00:19.


 97%|█████████▋| 4240/4392 [1:00:53<02:10,  1.16it/s]

  Batch 4,240  of  4,392.    Elapsed: 1:00:54.


 97%|█████████▋| 4280/4392 [1:01:28<01:36,  1.16it/s]

  Batch 4,280  of  4,392.    Elapsed: 1:01:28.


 98%|█████████▊| 4320/4392 [1:02:02<01:01,  1.16it/s]

  Batch 4,320  of  4,392.    Elapsed: 1:02:03.


 99%|█████████▉| 4360/4392 [1:02:37<00:27,  1.17it/s]

  Batch 4,360  of  4,392.    Elapsed: 1:02:37.


100%|██████████| 4392/4392 [1:03:04<00:00,  1.16it/s]



  Average training loss: 0.19
  Training epcoh took: 1:03:05

Running Validation...


100%|██████████| 1883/1883 [09:07<00:00,  3.44it/s]


  Accuracy: 0.97
  Average validation loss: 0.11
  Validation took: 1:12:12
Validation loss decreased (inf --> 0.112534).  Saving model ...

Training...


  1%|          | 40/4392 [00:34<1:02:54,  1.15it/s]

  Batch    40  of  4,392.    Elapsed: 0:00:35.


  2%|▏         | 80/4392 [01:09<1:01:51,  1.16it/s]

  Batch    80  of  4,392.    Elapsed: 0:01:09.


  3%|▎         | 120/4392 [01:43<1:01:26,  1.16it/s]

  Batch   120  of  4,392.    Elapsed: 0:01:44.


  4%|▎         | 160/4392 [02:18<1:00:32,  1.16it/s]

  Batch   160  of  4,392.    Elapsed: 0:02:18.


  5%|▍         | 200/4392 [02:53<59:48,  1.17it/s]

  Batch   200  of  4,392.    Elapsed: 0:02:53.


  5%|▌         | 240/4392 [03:27<59:32,  1.16it/s]

  Batch   240  of  4,392.    Elapsed: 0:03:28.


  6%|▋         | 280/4392 [04:02<58:40,  1.17it/s]

  Batch   280  of  4,392.    Elapsed: 0:04:02.


  7%|▋         | 320/4392 [04:36<58:28,  1.16it/s]

  Batch   320  of  4,392.    Elapsed: 0:04:37.


  8%|▊         | 360/4392 [05:11<57:55,  1.16it/s]

  Batch   360  of  4,392.    Elapsed: 0:05:11.


  9%|▉         | 400/4392 [05:45<57:04,  1.17it/s]

  Batch   400  of  4,392.    Elapsed: 0:05:46.


 10%|█         | 440/4392 [06:20<56:40,  1.16it/s]

  Batch   440  of  4,392.    Elapsed: 0:06:20.


 11%|█         | 480/4392 [06:54<56:03,  1.16it/s]

  Batch   480  of  4,392.    Elapsed: 0:06:55.


 12%|█▏        | 520/4392 [07:29<55:31,  1.16it/s]

  Batch   520  of  4,392.    Elapsed: 0:07:29.


 13%|█▎        | 560/4392 [08:04<54:52,  1.16it/s]

  Batch   560  of  4,392.    Elapsed: 0:08:04.


 14%|█▎        | 600/4392 [08:38<54:44,  1.15it/s]

  Batch   600  of  4,392.    Elapsed: 0:08:39.


 15%|█▍        | 640/4392 [09:13<53:54,  1.16it/s]

  Batch   640  of  4,392.    Elapsed: 0:09:13.


 15%|█▌        | 680/4392 [09:47<53:06,  1.16it/s]

  Batch   680  of  4,392.    Elapsed: 0:09:48.


 16%|█▋        | 720/4392 [10:22<52:35,  1.16it/s]

  Batch   720  of  4,392.    Elapsed: 0:10:22.


 17%|█▋        | 760/4392 [10:56<52:15,  1.16it/s]

  Batch   760  of  4,392.    Elapsed: 0:10:57.


 18%|█▊        | 800/4392 [11:31<51:43,  1.16it/s]

  Batch   800  of  4,392.    Elapsed: 0:11:31.


 19%|█▉        | 840/4392 [12:05<51:07,  1.16it/s]

  Batch   840  of  4,392.    Elapsed: 0:12:06.


 20%|██        | 880/4392 [12:40<50:23,  1.16it/s]

  Batch   880  of  4,392.    Elapsed: 0:12:41.


 21%|██        | 920/4392 [13:15<49:53,  1.16it/s]

  Batch   920  of  4,392.    Elapsed: 0:13:15.


 22%|██▏       | 960/4392 [13:49<49:07,  1.16it/s]

  Batch   960  of  4,392.    Elapsed: 0:13:50.


 23%|██▎       | 1000/4392 [14:24<48:38,  1.16it/s]

  Batch 1,000  of  4,392.    Elapsed: 0:14:24.


 24%|██▎       | 1040/4392 [14:58<47:58,  1.16it/s]

  Batch 1,040  of  4,392.    Elapsed: 0:14:59.


 25%|██▍       | 1080/4392 [15:33<47:34,  1.16it/s]

  Batch 1,080  of  4,392.    Elapsed: 0:15:33.


 26%|██▌       | 1120/4392 [16:07<47:23,  1.15it/s]

  Batch 1,120  of  4,392.    Elapsed: 0:16:08.


 26%|██▋       | 1160/4392 [16:42<47:29,  1.13it/s]

  Batch 1,160  of  4,392.    Elapsed: 0:16:42.


 27%|██▋       | 1200/4392 [17:17<46:38,  1.14it/s]

  Batch 1,200  of  4,392.    Elapsed: 0:17:17.


 28%|██▊       | 1240/4392 [17:51<46:05,  1.14it/s]

  Batch 1,240  of  4,392.    Elapsed: 0:17:52.


 29%|██▉       | 1280/4392 [18:26<45:28,  1.14it/s]

  Batch 1,280  of  4,392.    Elapsed: 0:18:26.


 30%|███       | 1320/4392 [19:00<44:35,  1.15it/s]

  Batch 1,320  of  4,392.    Elapsed: 0:19:01.


 31%|███       | 1360/4392 [19:35<43:43,  1.16it/s]

  Batch 1,360  of  4,392.    Elapsed: 0:19:35.


 32%|███▏      | 1400/4392 [20:09<43:09,  1.16it/s]

  Batch 1,400  of  4,392.    Elapsed: 0:20:10.


 33%|███▎      | 1440/4392 [20:44<42:37,  1.15it/s]

  Batch 1,440  of  4,392.    Elapsed: 0:20:45.


 34%|███▎      | 1480/4392 [21:19<41:59,  1.16it/s]

  Batch 1,480  of  4,392.    Elapsed: 0:21:19.


 35%|███▍      | 1520/4392 [21:53<41:19,  1.16it/s]

  Batch 1,520  of  4,392.    Elapsed: 0:21:54.


 36%|███▌      | 1560/4392 [22:28<40:58,  1.15it/s]

  Batch 1,560  of  4,392.    Elapsed: 0:22:28.


 36%|███▋      | 1600/4392 [23:02<40:30,  1.15it/s]

  Batch 1,600  of  4,392.    Elapsed: 0:23:03.


 37%|███▋      | 1640/4392 [23:37<39:37,  1.16it/s]

  Batch 1,640  of  4,392.    Elapsed: 0:23:38.


 38%|███▊      | 1680/4392 [24:12<38:59,  1.16it/s]

  Batch 1,680  of  4,392.    Elapsed: 0:24:12.


 39%|███▉      | 1720/4392 [24:46<38:42,  1.15it/s]

  Batch 1,720  of  4,392.    Elapsed: 0:24:47.


 40%|████      | 1760/4392 [25:21<37:53,  1.16it/s]

  Batch 1,760  of  4,392.    Elapsed: 0:25:21.


 41%|████      | 1800/4392 [25:55<37:08,  1.16it/s]

  Batch 1,800  of  4,392.    Elapsed: 0:25:56.


 42%|████▏     | 1840/4392 [26:30<36:38,  1.16it/s]

  Batch 1,840  of  4,392.    Elapsed: 0:26:30.


 43%|████▎     | 1880/4392 [27:04<36:00,  1.16it/s]

  Batch 1,880  of  4,392.    Elapsed: 0:27:05.


 44%|████▎     | 1920/4392 [27:39<35:22,  1.16it/s]

  Batch 1,920  of  4,392.    Elapsed: 0:27:40.


 45%|████▍     | 1960/4392 [28:14<34:52,  1.16it/s]

  Batch 1,960  of  4,392.    Elapsed: 0:28:14.


 46%|████▌     | 2000/4392 [28:48<34:27,  1.16it/s]

  Batch 2,000  of  4,392.    Elapsed: 0:28:49.


 46%|████▋     | 2040/4392 [29:23<33:45,  1.16it/s]

  Batch 2,040  of  4,392.    Elapsed: 0:29:23.


 47%|████▋     | 2080/4392 [29:57<33:20,  1.16it/s]

  Batch 2,080  of  4,392.    Elapsed: 0:29:58.


 48%|████▊     | 2120/4392 [30:32<32:32,  1.16it/s]

  Batch 2,120  of  4,392.    Elapsed: 0:30:33.


 49%|████▉     | 2160/4392 [31:07<32:00,  1.16it/s]

  Batch 2,160  of  4,392.    Elapsed: 0:31:07.


 50%|█████     | 2200/4392 [31:41<31:39,  1.15it/s]

  Batch 2,200  of  4,392.    Elapsed: 0:31:42.


 51%|█████     | 2240/4392 [32:16<30:59,  1.16it/s]

  Batch 2,240  of  4,392.    Elapsed: 0:32:16.


 52%|█████▏    | 2280/4392 [32:50<30:38,  1.15it/s]

  Batch 2,280  of  4,392.    Elapsed: 0:32:51.


 53%|█████▎    | 2320/4392 [33:25<29:45,  1.16it/s]

  Batch 2,320  of  4,392.    Elapsed: 0:33:26.


 54%|█████▎    | 2360/4392 [34:00<29:09,  1.16it/s]

  Batch 2,360  of  4,392.    Elapsed: 0:34:00.


 55%|█████▍    | 2400/4392 [34:34<28:38,  1.16it/s]

  Batch 2,400  of  4,392.    Elapsed: 0:34:35.


 56%|█████▌    | 2440/4392 [35:09<28:01,  1.16it/s]

  Batch 2,440  of  4,392.    Elapsed: 0:35:09.


 56%|█████▋    | 2480/4392 [35:43<27:18,  1.17it/s]

  Batch 2,480  of  4,392.    Elapsed: 0:35:44.


 57%|█████▋    | 2520/4392 [36:18<26:53,  1.16it/s]

  Batch 2,520  of  4,392.    Elapsed: 0:36:18.


 58%|█████▊    | 2560/4392 [36:53<26:16,  1.16it/s]

  Batch 2,560  of  4,392.    Elapsed: 0:36:53.


 59%|█████▉    | 2600/4392 [37:27<25:43,  1.16it/s]

  Batch 2,600  of  4,392.    Elapsed: 0:37:28.


 60%|██████    | 2640/4392 [38:02<25:05,  1.16it/s]

  Batch 2,640  of  4,392.    Elapsed: 0:38:02.


 61%|██████    | 2680/4392 [38:36<24:34,  1.16it/s]

  Batch 2,680  of  4,392.    Elapsed: 0:38:37.


 62%|██████▏   | 2720/4392 [39:11<23:58,  1.16it/s]

  Batch 2,720  of  4,392.    Elapsed: 0:39:11.


 63%|██████▎   | 2760/4392 [39:45<23:27,  1.16it/s]

  Batch 2,760  of  4,392.    Elapsed: 0:39:46.


 64%|██████▍   | 2800/4392 [40:20<23:07,  1.15it/s]

  Batch 2,800  of  4,392.    Elapsed: 0:40:21.


 65%|██████▍   | 2840/4392 [40:55<22:15,  1.16it/s]

  Batch 2,840  of  4,392.    Elapsed: 0:40:55.


 66%|██████▌   | 2880/4392 [41:29<21:39,  1.16it/s]

  Batch 2,880  of  4,392.    Elapsed: 0:41:30.


 66%|██████▋   | 2920/4392 [42:04<21:07,  1.16it/s]

  Batch 2,920  of  4,392.    Elapsed: 0:42:04.


 67%|██████▋   | 2960/4392 [42:38<20:29,  1.16it/s]

  Batch 2,960  of  4,392.    Elapsed: 0:42:39.


 68%|██████▊   | 3000/4392 [43:13<19:58,  1.16it/s]

  Batch 3,000  of  4,392.    Elapsed: 0:43:14.


 69%|██████▉   | 3040/4392 [43:48<19:21,  1.16it/s]

  Batch 3,040  of  4,392.    Elapsed: 0:43:48.


 70%|███████   | 3080/4392 [44:22<18:53,  1.16it/s]

  Batch 3,080  of  4,392.    Elapsed: 0:44:23.


 71%|███████   | 3120/4392 [44:57<18:17,  1.16it/s]

  Batch 3,120  of  4,392.    Elapsed: 0:44:58.


 72%|███████▏  | 3160/4392 [45:32<17:40,  1.16it/s]

  Batch 3,160  of  4,392.    Elapsed: 0:45:32.


 73%|███████▎  | 3200/4392 [46:06<17:05,  1.16it/s]

  Batch 3,200  of  4,392.    Elapsed: 0:46:07.


 74%|███████▍  | 3240/4392 [46:41<16:30,  1.16it/s]

  Batch 3,240  of  4,392.    Elapsed: 0:46:41.


 75%|███████▍  | 3280/4392 [47:15<15:56,  1.16it/s]

  Batch 3,280  of  4,392.    Elapsed: 0:47:16.


 76%|███████▌  | 3320/4392 [47:50<15:24,  1.16it/s]

  Batch 3,320  of  4,392.    Elapsed: 0:47:51.


 77%|███████▋  | 3360/4392 [48:25<14:54,  1.15it/s]

  Batch 3,360  of  4,392.    Elapsed: 0:48:25.


 77%|███████▋  | 3400/4392 [48:59<14:13,  1.16it/s]

  Batch 3,400  of  4,392.    Elapsed: 0:49:00.


 78%|███████▊  | 3440/4392 [49:34<13:41,  1.16it/s]

  Batch 3,440  of  4,392.    Elapsed: 0:49:34.


 79%|███████▉  | 3480/4392 [50:08<13:06,  1.16it/s]

  Batch 3,480  of  4,392.    Elapsed: 0:50:09.


 80%|████████  | 3520/4392 [50:43<12:27,  1.17it/s]

  Batch 3,520  of  4,392.    Elapsed: 0:50:43.


 81%|████████  | 3560/4392 [51:18<11:56,  1.16it/s]

  Batch 3,560  of  4,392.    Elapsed: 0:51:18.


 82%|████████▏ | 3600/4392 [51:52<11:20,  1.16it/s]

  Batch 3,600  of  4,392.    Elapsed: 0:51:53.


 83%|████████▎ | 3640/4392 [52:27<10:44,  1.17it/s]

  Batch 3,640  of  4,392.    Elapsed: 0:52:27.


 84%|████████▍ | 3680/4392 [53:01<10:12,  1.16it/s]

  Batch 3,680  of  4,392.    Elapsed: 0:53:02.


 85%|████████▍ | 3720/4392 [53:36<09:39,  1.16it/s]

  Batch 3,720  of  4,392.    Elapsed: 0:53:36.


 86%|████████▌ | 3760/4392 [54:11<09:04,  1.16it/s]

  Batch 3,760  of  4,392.    Elapsed: 0:54:11.


 87%|████████▋ | 3800/4392 [54:45<08:28,  1.16it/s]

  Batch 3,800  of  4,392.    Elapsed: 0:54:46.


 87%|████████▋ | 3840/4392 [55:20<07:59,  1.15it/s]

  Batch 3,840  of  4,392.    Elapsed: 0:55:20.


 88%|████████▊ | 3880/4392 [55:54<07:21,  1.16it/s]

  Batch 3,880  of  4,392.    Elapsed: 0:55:55.


 89%|████████▉ | 3920/4392 [56:29<06:49,  1.15it/s]

  Batch 3,920  of  4,392.    Elapsed: 0:56:30.


 90%|█████████ | 3960/4392 [57:04<06:13,  1.16it/s]

  Batch 3,960  of  4,392.    Elapsed: 0:57:04.


 91%|█████████ | 4000/4392 [57:38<05:39,  1.16it/s]

  Batch 4,000  of  4,392.    Elapsed: 0:57:39.


 92%|█████████▏| 4040/4392 [58:13<05:02,  1.16it/s]

  Batch 4,040  of  4,392.    Elapsed: 0:58:14.


 93%|█████████▎| 4080/4392 [58:48<04:28,  1.16it/s]

  Batch 4,080  of  4,392.    Elapsed: 0:58:48.


 94%|█████████▍| 4120/4392 [59:22<03:54,  1.16it/s]

  Batch 4,120  of  4,392.    Elapsed: 0:59:23.


 95%|█████████▍| 4160/4392 [59:57<03:19,  1.16it/s]

  Batch 4,160  of  4,392.    Elapsed: 0:59:57.


 96%|█████████▌| 4200/4392 [1:00:31<02:46,  1.15it/s]

  Batch 4,200  of  4,392.    Elapsed: 1:00:32.


 97%|█████████▋| 4240/4392 [1:01:06<02:11,  1.16it/s]

  Batch 4,240  of  4,392.    Elapsed: 1:01:07.


 97%|█████████▋| 4280/4392 [1:01:41<01:36,  1.16it/s]

  Batch 4,280  of  4,392.    Elapsed: 1:01:41.


 98%|█████████▊| 4320/4392 [1:02:15<01:02,  1.16it/s]

  Batch 4,320  of  4,392.    Elapsed: 1:02:16.


 99%|█████████▉| 4360/4392 [1:02:50<00:27,  1.16it/s]

  Batch 4,360  of  4,392.    Elapsed: 1:02:50.


100%|██████████| 4392/4392 [1:03:17<00:00,  1.16it/s]



  Average training loss: 0.09
  Training epcoh took: 1:03:18

Running Validation...


100%|██████████| 1883/1883 [09:08<00:00,  3.43it/s]


  Accuracy: 0.97
  Average validation loss: 0.10
  Validation took: 1:12:26
Validation loss decreased (0.112534 --> 0.103701).  Saving model ...

Training...


  1%|          | 40/4392 [00:34<1:02:29,  1.16it/s]

  Batch    40  of  4,392.    Elapsed: 0:00:35.


  2%|▏         | 80/4392 [01:09<1:01:50,  1.16it/s]

  Batch    80  of  4,392.    Elapsed: 0:01:10.


  3%|▎         | 120/4392 [01:44<1:01:47,  1.15it/s]

  Batch   120  of  4,392.    Elapsed: 0:01:44.


  4%|▎         | 160/4392 [02:18<1:00:53,  1.16it/s]

  Batch   160  of  4,392.    Elapsed: 0:02:19.


  5%|▍         | 200/4392 [02:53<1:00:17,  1.16it/s]

  Batch   200  of  4,392.    Elapsed: 0:02:54.


  5%|▌         | 240/4392 [03:28<59:34,  1.16it/s]

  Batch   240  of  4,392.    Elapsed: 0:03:28.


  6%|▋         | 280/4392 [04:02<59:00,  1.16it/s]

  Batch   280  of  4,392.    Elapsed: 0:04:03.


  7%|▋         | 320/4392 [04:37<58:25,  1.16it/s]

  Batch   320  of  4,392.    Elapsed: 0:04:38.


  8%|▊         | 360/4392 [05:12<57:49,  1.16it/s]

  Batch   360  of  4,392.    Elapsed: 0:05:12.


  9%|▉         | 400/4392 [05:46<57:22,  1.16it/s]

  Batch   400  of  4,392.    Elapsed: 0:05:47.


 10%|█         | 440/4392 [06:21<56:52,  1.16it/s]

  Batch   440  of  4,392.    Elapsed: 0:06:22.


 11%|█         | 480/4392 [06:56<56:40,  1.15it/s]

  Batch   480  of  4,392.    Elapsed: 0:06:56.


 12%|█▏        | 520/4392 [07:30<56:07,  1.15it/s]

  Batch   520  of  4,392.    Elapsed: 0:07:31.


 13%|█▎        | 560/4392 [08:05<56:02,  1.14it/s]

  Batch   560  of  4,392.    Elapsed: 0:08:06.


 14%|█▎        | 600/4392 [08:40<55:57,  1.13it/s]

  Batch   600  of  4,392.    Elapsed: 0:08:40.


 15%|█▍        | 640/4392 [09:14<55:06,  1.13it/s]

  Batch   640  of  4,392.    Elapsed: 0:09:15.


 15%|█▌        | 680/4392 [09:49<53:43,  1.15it/s]

  Batch   680  of  4,392.    Elapsed: 0:09:49.


 16%|█▋        | 720/4392 [10:23<53:06,  1.15it/s]

  Batch   720  of  4,392.    Elapsed: 0:10:24.


 17%|█▋        | 760/4392 [10:58<52:18,  1.16it/s]

  Batch   760  of  4,392.    Elapsed: 0:10:59.


 18%|█▊        | 800/4392 [11:33<51:57,  1.15it/s]

  Batch   800  of  4,392.    Elapsed: 0:11:33.


 19%|█▉        | 840/4392 [12:07<51:03,  1.16it/s]

  Batch   840  of  4,392.    Elapsed: 0:12:08.


 20%|██        | 880/4392 [12:42<50:19,  1.16it/s]

  Batch   880  of  4,392.    Elapsed: 0:12:43.


 21%|██        | 920/4392 [13:17<50:05,  1.16it/s]

  Batch   920  of  4,392.    Elapsed: 0:13:17.


 22%|██▏       | 960/4392 [13:51<49:15,  1.16it/s]

  Batch   960  of  4,392.    Elapsed: 0:13:52.


 23%|██▎       | 1000/4392 [14:26<48:48,  1.16it/s]

  Batch 1,000  of  4,392.    Elapsed: 0:14:27.


 24%|██▎       | 1040/4392 [15:01<48:03,  1.16it/s]

  Batch 1,040  of  4,392.    Elapsed: 0:15:01.


 25%|██▍       | 1080/4392 [15:35<47:31,  1.16it/s]

  Batch 1,080  of  4,392.    Elapsed: 0:15:36.


 26%|██▌       | 1120/4392 [16:10<46:59,  1.16it/s]

  Batch 1,120  of  4,392.    Elapsed: 0:16:11.


 26%|██▋       | 1160/4392 [16:45<46:26,  1.16it/s]

  Batch 1,160  of  4,392.    Elapsed: 0:16:45.


 27%|██▋       | 1200/4392 [17:20<45:49,  1.16it/s]

  Batch 1,200  of  4,392.    Elapsed: 0:17:20.


 28%|██▊       | 1240/4392 [17:54<45:25,  1.16it/s]

  Batch 1,240  of  4,392.    Elapsed: 0:17:55.


 29%|██▉       | 1280/4392 [18:29<45:07,  1.15it/s]

  Batch 1,280  of  4,392.    Elapsed: 0:18:29.


 30%|███       | 1320/4392 [19:03<44:42,  1.15it/s]

  Batch 1,320  of  4,392.    Elapsed: 0:19:04.


 31%|███       | 1360/4392 [19:38<44:28,  1.14it/s]

  Batch 1,360  of  4,392.    Elapsed: 0:19:39.


 32%|███▏      | 1400/4392 [20:13<43:52,  1.14it/s]

  Batch 1,400  of  4,392.    Elapsed: 0:20:13.


 33%|███▎      | 1440/4392 [20:47<42:58,  1.14it/s]

  Batch 1,440  of  4,392.    Elapsed: 0:20:48.


 34%|███▎      | 1480/4392 [21:22<41:55,  1.16it/s]

  Batch 1,480  of  4,392.    Elapsed: 0:21:23.


 35%|███▍      | 1520/4392 [21:57<41:38,  1.15it/s]

  Batch 1,520  of  4,392.    Elapsed: 0:21:57.


 36%|███▌      | 1560/4392 [22:31<40:54,  1.15it/s]

  Batch 1,560  of  4,392.    Elapsed: 0:22:32.


 36%|███▋      | 1600/4392 [23:06<40:37,  1.15it/s]

  Batch 1,600  of  4,392.    Elapsed: 0:23:07.


 37%|███▋      | 1640/4392 [23:41<39:46,  1.15it/s]

  Batch 1,640  of  4,392.    Elapsed: 0:23:41.


 38%|███▊      | 1680/4392 [24:16<39:03,  1.16it/s]

  Batch 1,680  of  4,392.    Elapsed: 0:24:16.


 39%|███▉      | 1720/4392 [24:50<38:24,  1.16it/s]

  Batch 1,720  of  4,392.    Elapsed: 0:24:51.


 40%|████      | 1760/4392 [25:25<37:50,  1.16it/s]

  Batch 1,760  of  4,392.    Elapsed: 0:25:25.


 41%|████      | 1800/4392 [26:00<37:23,  1.16it/s]

  Batch 1,800  of  4,392.    Elapsed: 0:26:00.


 42%|████▏     | 1840/4392 [26:34<36:41,  1.16it/s]

  Batch 1,840  of  4,392.    Elapsed: 0:26:35.


 43%|████▎     | 1880/4392 [27:09<36:02,  1.16it/s]

  Batch 1,880  of  4,392.    Elapsed: 0:27:10.


 44%|████▎     | 1920/4392 [27:44<35:30,  1.16it/s]

  Batch 1,920  of  4,392.    Elapsed: 0:27:44.


 45%|████▍     | 1960/4392 [28:18<34:59,  1.16it/s]

  Batch 1,960  of  4,392.    Elapsed: 0:28:19.


 46%|████▌     | 2000/4392 [28:53<34:47,  1.15it/s]

  Batch 2,000  of  4,392.    Elapsed: 0:28:53.


 46%|████▋     | 2040/4392 [29:28<35:01,  1.12it/s]

  Batch 2,040  of  4,392.    Elapsed: 0:29:28.


 47%|████▋     | 2080/4392 [30:02<34:11,  1.13it/s]

  Batch 2,080  of  4,392.    Elapsed: 0:30:03.


 48%|████▊     | 2120/4392 [30:37<32:55,  1.15it/s]

  Batch 2,120  of  4,392.    Elapsed: 0:30:37.


 49%|████▉     | 2160/4392 [31:11<32:01,  1.16it/s]

  Batch 2,160  of  4,392.    Elapsed: 0:31:12.


 50%|█████     | 2200/4392 [31:46<31:27,  1.16it/s]

  Batch 2,200  of  4,392.    Elapsed: 0:31:47.


 51%|█████     | 2240/4392 [32:21<31:05,  1.15it/s]

  Batch 2,240  of  4,392.    Elapsed: 0:32:21.


 52%|█████▏    | 2280/4392 [32:56<30:34,  1.15it/s]

  Batch 2,280  of  4,392.    Elapsed: 0:32:56.


 53%|█████▎    | 2320/4392 [33:30<30:02,  1.15it/s]

  Batch 2,320  of  4,392.    Elapsed: 0:33:31.


 54%|█████▎    | 2360/4392 [34:05<29:12,  1.16it/s]

  Batch 2,360  of  4,392.    Elapsed: 0:34:05.


 55%|█████▍    | 2400/4392 [34:40<28:33,  1.16it/s]

  Batch 2,400  of  4,392.    Elapsed: 0:34:40.


 56%|█████▌    | 2440/4392 [35:14<28:00,  1.16it/s]

  Batch 2,440  of  4,392.    Elapsed: 0:35:15.


 56%|█████▋    | 2480/4392 [35:49<27:25,  1.16it/s]

  Batch 2,480  of  4,392.    Elapsed: 0:35:50.


 57%|█████▋    | 2520/4392 [36:24<26:50,  1.16it/s]

  Batch 2,520  of  4,392.    Elapsed: 0:36:24.


 58%|█████▊    | 2560/4392 [36:58<26:35,  1.15it/s]

  Batch 2,560  of  4,392.    Elapsed: 0:36:59.


 59%|█████▉    | 2600/4392 [37:33<26:01,  1.15it/s]

  Batch 2,600  of  4,392.    Elapsed: 0:37:34.


 60%|██████    | 2640/4392 [38:08<25:43,  1.13it/s]

  Batch 2,640  of  4,392.    Elapsed: 0:38:08.


 61%|██████    | 2680/4392 [38:42<24:59,  1.14it/s]

  Batch 2,680  of  4,392.    Elapsed: 0:38:43.


 62%|██████▏   | 2720/4392 [39:17<24:14,  1.15it/s]

  Batch 2,720  of  4,392.    Elapsed: 0:39:18.


 63%|██████▎   | 2760/4392 [39:52<23:22,  1.16it/s]

  Batch 2,760  of  4,392.    Elapsed: 0:39:52.


 64%|██████▍   | 2800/4392 [40:26<22:47,  1.16it/s]

  Batch 2,800  of  4,392.    Elapsed: 0:40:27.


 65%|██████▍   | 2840/4392 [41:01<22:22,  1.16it/s]

  Batch 2,840  of  4,392.    Elapsed: 0:41:01.


 66%|██████▌   | 2880/4392 [41:36<22:00,  1.15it/s]

  Batch 2,880  of  4,392.    Elapsed: 0:41:36.


 66%|██████▋   | 2920/4392 [42:10<21:28,  1.14it/s]

  Batch 2,920  of  4,392.    Elapsed: 0:42:11.


 67%|██████▋   | 2960/4392 [42:45<20:34,  1.16it/s]

  Batch 2,960  of  4,392.    Elapsed: 0:42:45.


 68%|██████▊   | 3000/4392 [43:19<20:26,  1.13it/s]

  Batch 3,000  of  4,392.    Elapsed: 0:43:20.


 69%|██████▉   | 3040/4392 [43:54<19:21,  1.16it/s]

  Batch 3,040  of  4,392.    Elapsed: 0:43:54.


 70%|███████   | 3080/4392 [44:29<19:18,  1.13it/s]

  Batch 3,080  of  4,392.    Elapsed: 0:44:29.


 71%|███████   | 3120/4392 [45:03<18:13,  1.16it/s]

  Batch 3,120  of  4,392.    Elapsed: 0:45:03.


 72%|███████▏  | 3160/4392 [45:38<18:19,  1.12it/s]

  Batch 3,160  of  4,392.    Elapsed: 0:45:38.


 73%|███████▎  | 3200/4392 [46:12<17:04,  1.16it/s]

  Batch 3,200  of  4,392.    Elapsed: 0:46:13.


 74%|███████▍  | 3240/4392 [46:47<16:44,  1.15it/s]

  Batch 3,240  of  4,392.    Elapsed: 0:46:47.


 75%|███████▍  | 3280/4392 [47:21<15:58,  1.16it/s]

  Batch 3,280  of  4,392.    Elapsed: 0:47:22.


 76%|███████▌  | 3320/4392 [47:56<15:25,  1.16it/s]

  Batch 3,320  of  4,392.    Elapsed: 0:47:56.


 77%|███████▋  | 3360/4392 [48:30<14:44,  1.17it/s]

  Batch 3,360  of  4,392.    Elapsed: 0:48:31.


 77%|███████▋  | 3400/4392 [49:05<14:13,  1.16it/s]

  Batch 3,400  of  4,392.    Elapsed: 0:49:05.


 78%|███████▊  | 3440/4392 [49:39<13:38,  1.16it/s]

  Batch 3,440  of  4,392.    Elapsed: 0:49:40.


 79%|███████▉  | 3480/4392 [50:14<13:06,  1.16it/s]

  Batch 3,480  of  4,392.    Elapsed: 0:50:15.


 80%|████████  | 3520/4392 [50:49<12:30,  1.16it/s]

  Batch 3,520  of  4,392.    Elapsed: 0:50:49.


 81%|████████  | 3560/4392 [51:23<11:55,  1.16it/s]

  Batch 3,560  of  4,392.    Elapsed: 0:51:24.


 82%|████████▏ | 3600/4392 [51:58<11:21,  1.16it/s]

  Batch 3,600  of  4,392.    Elapsed: 0:51:58.


 83%|████████▎ | 3640/4392 [52:32<10:45,  1.16it/s]

  Batch 3,640  of  4,392.    Elapsed: 0:52:33.


 84%|████████▍ | 3680/4392 [53:07<10:14,  1.16it/s]

  Batch 3,680  of  4,392.    Elapsed: 0:53:07.


 85%|████████▍ | 3720/4392 [53:42<09:43,  1.15it/s]

  Batch 3,720  of  4,392.    Elapsed: 0:53:42.


 86%|████████▌ | 3760/4392 [54:16<09:08,  1.15it/s]

  Batch 3,760  of  4,392.    Elapsed: 0:54:17.


 87%|████████▋ | 3800/4392 [54:51<08:30,  1.16it/s]

  Batch 3,800  of  4,392.    Elapsed: 0:54:51.


 87%|████████▋ | 3840/4392 [55:25<07:54,  1.16it/s]

  Batch 3,840  of  4,392.    Elapsed: 0:55:26.


 88%|████████▊ | 3880/4392 [56:00<07:21,  1.16it/s]

  Batch 3,880  of  4,392.    Elapsed: 0:56:00.


 89%|████████▉ | 3920/4392 [56:34<06:45,  1.16it/s]

  Batch 3,920  of  4,392.    Elapsed: 0:56:35.


 90%|█████████ | 3960/4392 [57:09<06:11,  1.16it/s]

  Batch 3,960  of  4,392.    Elapsed: 0:57:10.


 91%|█████████ | 4000/4392 [57:44<05:38,  1.16it/s]

  Batch 4,000  of  4,392.    Elapsed: 0:57:44.


 92%|█████████▏| 4040/4392 [58:18<05:02,  1.16it/s]

  Batch 4,040  of  4,392.    Elapsed: 0:58:19.


 93%|█████████▎| 4080/4392 [58:53<04:28,  1.16it/s]

  Batch 4,080  of  4,392.    Elapsed: 0:58:53.


 94%|█████████▍| 4120/4392 [59:28<03:54,  1.16it/s]

  Batch 4,120  of  4,392.    Elapsed: 0:59:28.


 95%|█████████▍| 4160/4392 [1:00:02<03:19,  1.16it/s]

  Batch 4,160  of  4,392.    Elapsed: 1:00:03.


 96%|█████████▌| 4200/4392 [1:00:37<02:46,  1.16it/s]

  Batch 4,200  of  4,392.    Elapsed: 1:00:37.


 97%|█████████▋| 4240/4392 [1:01:11<02:11,  1.16it/s]

  Batch 4,240  of  4,392.    Elapsed: 1:01:12.


 97%|█████████▋| 4280/4392 [1:01:46<01:36,  1.16it/s]

  Batch 4,280  of  4,392.    Elapsed: 1:01:47.


 98%|█████████▊| 4320/4392 [1:02:21<01:01,  1.16it/s]

  Batch 4,320  of  4,392.    Elapsed: 1:02:21.


 99%|█████████▉| 4360/4392 [1:02:55<00:27,  1.16it/s]

  Batch 4,360  of  4,392.    Elapsed: 1:02:56.


100%|██████████| 4392/4392 [1:03:23<00:00,  1.15it/s]



  Average training loss: 0.05
  Training epcoh took: 1:03:23

Running Validation...


100%|██████████| 1883/1883 [09:07<00:00,  3.44it/s]


  Accuracy: 0.98
  Average validation loss: 0.11
  Validation took: 1:12:31

Training complete!


In [15]:
# Set the batch size.  
batch_size = 16
# Create the DataLoader.
validation_data = CustomDataset(df_test, 512, model)
prediction_sampler = SequentialSampler(validation_data)
prediction_dataloader = DataLoader(validation_data, sampler=prediction_sampler, batch_size=batch_size)

In [16]:
prediction_inputs = validation_data
# =validation_masks
# =validation_labels
# Prediction on test set
print('Predicting labels for {:,} test sentences...'.format(len(prediction_inputs)))
# Put model in evaluation mode
model.eval()
# Tracking variables 
predictions , true_labels = [], []
# Predict 
for batch in prediction_dataloader:
  # Add batch to GPU
  batch = tuple(t.to(device) for t in batch)
  
  # Unpack the inputs from our dataloader
  b_input_ids, b_input_mask, b_labels = batch
  with torch.no_grad():
      # Forward pass, calculate logit predictions
      outputs = model(b_input_ids, token_type_ids=None, 
                      attention_mask=b_input_mask)
  logits = outputs[0]
  # Move logits and labels to CPU
  logits = logits.detach().cpu().numpy()
  label_ids = b_labels.to('cpu').numpy()
  
  # Store predictions and true labels
  predictions.append(logits)
  true_labels.append(label_ids)
print('DONE.')

Predicting labels for 30,114 test sentences...
DONE.


In [17]:
# Combine the predictions for each batch into a single list of 0s and 1s.
flat_predictions = [item for sublist in predictions for item in sublist]
flat_predictions = np.argmax(flat_predictions, axis=1).flatten()
# Combine the correct labels for each batch into a single list.
flat_true_labels = [item for sublist in true_labels for item in sublist]
# Calculate the MCC
mcc = matthews_corrcoef(flat_true_labels, flat_predictions)
print('MCC: %.3f' % mcc)

MCC: 0.939


In [18]:
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
y_pred = flat_predictions
y_true = flat_true_labels
# precision_recall_fscore_support(y_true, y_pred, average='macro')
precision_recall_fscore_support(y_true, y_pred, average='macro')
# precision_recall_fscore_support(y_true, y_pred, average='weighted')

(0.9729475268689662, 0.9656450022171612, 0.9692162074352064, None)

In [19]:
from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred))

              precision    recall  f1-score   support

           0       0.98      0.99      0.98     22041
           1       0.97      0.94      0.95      8073

    accuracy                           0.98     30114
   macro avg       0.97      0.97      0.97     30114
weighted avg       0.98      0.98      0.98     30114

