In [1]:
import numpy as np
import pandas as pd
import random
import time
import datetime
import re
import tensorflow as tf
import torch
import torch.nn as nn
import torch.nn.functional as F 
from transformers import BertModel
from torch.utils.data import DataLoader, Dataset, TensorDataset,random_split, RandomSampler, SequentialSampler
from transformers import BertTokenizer
from transformers import BertForSequenceClassification, AdamW, BertConfig
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import classification_report
from torch.utils.data import IterableDataset
from torch.utils.data import DataLoader, Dataset, TensorDataset,random_split, RandomSampler, SequentialSampler
from sklearn.model_selection import cross_val_score, train_test_split, GridSearchCV
import torch.optim as optim

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Cross-entropy criterion used by the training and evaluation loops in this file.
loss_fn = nn.CrossEntropyLoss()

def set_seed(seed_value=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    @param    seed_value (int): seed handed to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    # The CUDA generators are only seeded when a GPU is actually available.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)
        
def format_time(elapsed):
    '''
    Render a duration given in seconds as an h:mm:ss string.

    The value is rounded to the nearest whole second before formatting.
    '''
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)
def compute_metrics(pred):
    """Compute accuracy, precision, recall and F1 for binary predictions.

    The original body called `precision_recall_fscore_support` and
    `accuracy_score`, neither of which is imported in this file, so every call
    raised NameError. The same binary metrics (positive class = 1) are now
    computed directly with NumPy.

    @param    pred: object exposing `label_ids` (true labels) and `predictions`
                  (per-class scores whose last axis is arg-maxed).
    @return   dict with keys 'accuracy', 'f1', 'precision' and 'recall'.
                  Any ratio whose denominator is zero is reported as 0.0.
    """
    labels = np.asarray(pred.label_ids).ravel()
    preds = np.asarray(pred.predictions).argmax(-1).ravel()

    predicted_pos = preds == 1
    actual_pos = labels == 1
    true_pos = int(np.sum(predicted_pos & actual_pos))
    false_pos = int(np.sum(predicted_pos & ~actual_pos))
    false_neg = int(np.sum(~predicted_pos & actual_pos))

    precision = true_pos / (true_pos + false_pos) if (true_pos + false_pos) else 0.0
    recall = true_pos / (true_pos + false_neg) if (true_pos + false_neg) else 0.0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    acc = float(np.mean(preds == labels)) if labels.size else 0.0
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }
def report_average(report_list):
    """Average the per-class rows of several classification-report strings.

    Each report is split on blank lines: the first section supplies the column
    names and the second section the per-class rows. The leading label cell of
    every row is dropped, the remaining cells are parsed as floats, and the
    resulting tables are summed element-wise (missing cells count as 0) and
    divided by the number of reports.

    @param    report_list: iterable of report strings.
    @return   pandas.DataFrame holding the averaged values.
    """
    running_total = pd.DataFrame()
    n_reports = 0
    for text in report_list:
        # Collapse all whitespace inside each blank-line-separated section.
        sections = [' '.join(chunk.split()) for chunk in text.split('\n\n')]
        column_names = sections[0].split(' ')
        cells = np.array(sections[1].split(' '))
        # One extra column per row: the class label, which is removed next.
        table = cells.reshape(-1, len(column_names) + 1)
        values = np.delete(table, 0, 1).astype(float)
        frame = pd.DataFrame(values, columns=column_names)
        running_total = running_total.add(frame, fill_value=0)
        n_reports += 1
    return running_total / n_reports


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [3]:
class CrisisDataset(Dataset):
    """Tweets from a CrisisLexT6 CSV with hand-crafted features, tokenized for BERT.

    Each item is a tuple ``(token_ids, attention_mask, label, hand_features)``
    of tensors moved to the module-level ``device``.
    """

    # Hand-crafted feature columns, in the order they appear in the feature tensor.
    FEATURE_COLUMNS = ['nchars', 'nwords', 'bhash', 'nhash', 'blink', 'nlink',
                       'bat', 'nat', 'rt', 'slang', 'dlex']
    # URL matcher behind the `blink` / `nlink` features.
    URL_PATTERN = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    # List of US slangs (lower-cased before matching).
    SLANG = ['ASAP','BBIAB','BBL','BBS','BF','BFF','BFFL','BRB','CYA','DS','FAQ','FB','FITBLR','FLBP','FML','FTFY','FTW','FYI','G2G','GF','GR8','GTFO','HBIC','HML','HRU','HTH','IDK','IGHT','IMO','IMHO','IMY','IRL','ISTG','JK','JMHO','KTHX','L8R','LMAO','LMFAO','LMK','LOL','MWF','NM','NOOB','NP','NSFW','OOAK','OFC','OMG','ORLY','OTOH','RN','ROFL','RUH','SFW','SOML','SOZ','STFU','TFTI','TIL','TMI','TTFN','TTYL','TWSS','U','W/','WB','W/O','WYD','WTH','WTF','WYM','WYSIWYG','Y','YMMV','YW','YWA']
    HAPPY_EMOJIS = [r':\)', r';\)', r'\(:']
    SAD_EMOJIS = [r':\(', r';\(', r'\):']

    def __init__(self, csv_path="/home/joao/crisisLexT6.csv"):
        """Load the CSV, derive the features, clean the text and measure `maxlen`.

        @param    csv_path (str): CSV file whose header contains the columns
                      ` tweet` and ` label` (both with a leading space).
        """
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        self.df = pd.read_csv(csv_path, encoding='utf-8')
        self.df = self.df.rename(columns={' tweet': 'sentence', ' label': 'label'})
        # Copy so the column assignments below never write into a view.
        self.df = self.df[['sentence', 'label']].copy()
        self.df['label'] = self.df['label'].replace('on-topic', 1)
        self.df['label'] = self.df['label'].replace('off-topic', 0)

        # Features are computed on the raw text, before URLs/mentions are removed.
        self._add_hand_features()
        self.hand_features = self.df[self.FEATURE_COLUMNS]
        self.hand_features_DF = pd.DataFrame(self.hand_features)

        self.df['sentence'] = self._clean_sentences(self.df['sentence'])
        self.sentences = self.df['sentence']
        self.labels = self.df['label'].values

        # Longest tokenized sentence (special tokens included); items are
        # padded to this length in __getitem__.
        self.maxlen = max(
            (len(self.tokenizer.encode(sent, add_special_tokens=True)) for sent in self.sentences),
            default=0,
        )

    def _add_hand_features(self):
        """Add the surface-level feature columns to `self.df` and lower-case the text."""
        text = self.df['sentence']
        self.df['nchars'] = text.str.len()
        self.df['nwords'] = text.str.split().str.len()
        self.df['bhash'] = text.str.contains(pat='#', flags=re.IGNORECASE, regex=True).astype(int)
        self.df['nhash'] = text.str.count('#')
        self.df['blink'] = text.str.contains(pat=self.URL_PATTERN, flags=re.IGNORECASE, regex=True).astype(int)
        self.df['nlink'] = text.str.count(pat=self.URL_PATTERN, flags=re.IGNORECASE)
        self.df['bat'] = text.str.contains(pat='@', flags=re.IGNORECASE, regex=True).astype(int)
        self.df['nat'] = text.str.count(pat='@')
        # NOTE(review): this only matches "@rt" / "rt@" with no space in
        # between, so a conventional "RT @user" is not flagged — confirm intent.
        self.df['rt'] = text.str.contains(pat='@rt|rt@', flags=re.IGNORECASE, regex=True).astype(int)
        self.df['dlex'] = text.apply(self.lexical_diversity)

        # Slang and emoji checks run on the lower-cased text.
        self.df['sentence'] = text.str.lower()
        lowered = self.df['sentence']
        slang_pattern = r'\b(?:{})\b'.format('|'.join(s.lower() for s in self.SLANG))
        self.df['slang'] = lowered.str.contains(slang_pattern, regex=True).astype(int)
        # No \b anchors here: emoticons consist of non-word characters, so a
        # word-boundary wrapper would only match when glued between letters.
        self.df['hemojis'] = lowered.str.contains('|'.join(self.HAPPY_EMOJIS), regex=True).astype(int)
        self.df['semojis'] = lowered.str.contains('|'.join(self.SAD_EMOJIS), regex=True).astype(int)

    @staticmethod
    def _clean_sentences(sentences):
        """Strip URLs, retweet markers and mentions, then normalise the text.

        @param    sentences (pd.Series): raw tweet texts.
        @return   pd.Series: cleaned, lower-cased, stripped texts.
        """
        substitutions = [
            (r'http(\S)+', ''),                        # URLs
            (r'http ...', ''),                         # truncated "http ..." remnants
            (r'(RT|rt)[ ]*@[ ]*[\S]+', ''),            # retweet markers
            (r'@[\S]+', ''),                           # mentions
            (r'_[\S]?', ''),                           # underscores
            (r'[ ]{2,}', ' '),                         # runs of spaces
            (r'&amp;?', 'and'),                        # HTML entities
            (r'&lt;', '<'),
            (r'&gt;', '>'),
            (r'([\w\d]+)([^\w\d ]+)', r'\1 \2'),       # split word|punctuation
            (r'([^\w\d ]+)([\w\d]+)', r'\1 \2'),       # split punctuation|word
        ]
        for pattern, replacement in substitutions:
            # regex=True is required: recent pandas treats the pattern as a
            # literal string by default.
            sentences = sentences.str.replace(pattern, replacement, regex=True)
        return sentences.str.lower().str.strip()

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return `(token_ids, attention_mask, label, hand_features)` for row `idx`."""
        sentence = self.df.loc[idx, 'sentence']
        label = self.df.loc[idx, 'label']
        h_features = self.hand_features_DF.loc[idx, self.FEATURE_COLUMNS]
        # Go through NumPy so the row is numeric regardless of the Series dtype.
        h_features_tensor = torch.tensor(h_features.to_numpy(dtype=np.float64)).to(device)

        tokens = ['[CLS]'] + self.tokenizer.tokenize(sentence) + ['[SEP]']
        if len(tokens) < self.maxlen:
            tokens = tokens + ['[PAD]'] * (self.maxlen - len(tokens))  # Padding sentences
        else:
            tokens = tokens[:self.maxlen - 1] + ['[SEP]']  # Pruning to the max length
        tokens_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        tokens_ids_tensor = torch.tensor(tokens_ids).to(device)
        # Attend to every position that is not padding.
        attn_mask_tensor = (tokens_ids_tensor != self.tokenizer.pad_token_id).float()
        label_tensor = torch.tensor(int(label)).to(device)
        return tokens_ids_tensor, attn_mask_tensor, label_tensor, h_features_tensor

    def lexical_diversity(self, text):
        """Return unique-words / total-words for `text` (0.0 when it has no words)."""
        words = text.split()
        if not words:
            return 0.0
        return len(set(words)) / len(words)


In [27]:
class ourBertBinaryClassifier(nn.Module):
    """BERT `[CLS]` vector fused with hand-crafted features for 2-class logits.

    The `[CLS]` vector and the hand-crafted features are each projected by
    their own linear layer, concatenated, and passed to a final linear layer.
    """

    def __init__(self, B_in=768, B_out=64, H_in=11, H_out=32):
        """
        @param    B_in (int): width of the BERT `[CLS]` vector.
        @param    B_out (int): width of the projected `[CLS]` vector.
        @param    H_in (int): number of hand-crafted features.
        @param    H_out (int): width of the projected hand-crafted features.
        """
        super().__init__()
        num_classes = 2
        fused_width = B_out + H_out
        self.bert_layer = BertModel.from_pretrained('bert-base-uncased')
        self.linear_B = nn.Linear(B_in, B_out)
        self.linear_H = nn.Linear(H_in, H_out)
        self.classifier = nn.Linear(fused_width, num_classes)

    def forward(self, seq, attn_masks, hand_features):
        """Return the logits for a batch.

        @param    seq: token ids passed to BERT as `input_ids`.
        @param    attn_masks: mask passed to BERT as `attention_mask`.
        @param    hand_features: hand-crafted features, one row per example.
        @return   logits produced by the final linear layer.
        """
        bert_out = self.bert_layer(input_ids=seq, attention_mask=attn_masks)
        # First output element, position 0 of every sequence (the `[CLS]` slot).
        cls_vector = bert_out[0][:, 0, :].float()
        text_repr = self.linear_B(cls_vector).float()
        feat_repr = self.linear_H(hand_features.float()).float()
        fused = torch.cat([text_repr, feat_repr], dim=1).float()
        return self.classifier(fused)
    
        
def initialize_model(epochs=4):
    """Build the classifier, its AdamW optimizer and a linear LR schedule.

    Reads the module-level `train_dataloader` and `device`.

    @param    epochs (int): number of epochs the schedule should span.
    @return   tuple (model, optimizer, scheduler).
    """
    model = ourBertBinaryClassifier()
    model.to(device)
    adamw = AdamW(model.parameters(), lr=5e-5, eps=1e-8)
    # The schedule spans every training batch of every epoch, with no warm-up.
    num_steps = len(train_dataloader) * epochs
    lr_schedule = get_linear_schedule_with_warmup(
        adamw,
        num_warmup_steps=0,
        num_training_steps=num_steps,
    )
    return model, adamw, lr_schedule



def train(model, train_dataloader, validation_dataloader=None, epochs=4, evaluation=False):
    """Train `model` and optionally validate it after every epoch.

    Relies on the module-level `optimizer`, `scheduler`, `loss_fn` and `device`.
    Each batch must unpack to (input ids, attention mask, labels, hand features).

    @param    model: module called as `model(input_ids, attn_mask, hand_features)`.
    @param    train_dataloader: iterable of training batches.
    @param    validation_dataloader: iterable of validation batches, or None.
    @param    epochs (int): number of passes over `train_dataloader`.
    @param    evaluation (bool): validate after each epoch when True and a
                  validation loader was supplied.
    """
    print("Start training...\n")
    for epoch_i in range(epochs):
        # =======================================
        #               Training
        # =======================================
        # Print the header of the result table
        print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}")
        print("-"*70)
        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()
        # Reset tracking variables at the beginning of each epoch
        total_loss, batch_loss, batch_counts = 0, 0, 0
        # Put the model into the training mode
        model.train()
        for step, batch in enumerate(train_dataloader):
            batch_counts += 1
            # Load batch to the target device
            b_input_ids, b_attn_mask, b_labels, h_features = tuple(t.to(device) for t in batch)
            # Zero out any previously calculated gradients
            model.zero_grad()
            # Perform a forward pass. This will return logits.
            logits = model(b_input_ids, b_attn_mask, h_features)
            # Compute loss and accumulate the loss values
            loss = loss_fn(logits, b_labels)
            batch_loss += loss.item()
            total_loss += loss.item()
            # Perform a backward pass to calculate gradients
            loss.backward()
            # Clip the norm of the gradients to 1.0 to prevent "exploding gradients"
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # Update parameters and the learning rate
            optimizer.step()
            scheduler.step()
            # Print the loss values and time elapsed for every 100 batches
            if (step % 100 == 0 and step != 0) or (step == len(train_dataloader) - 1):
                time_elapsed = time.time() - t0_batch
                print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^10} | {'-':^9} | {time_elapsed:^9.2f}")
                # Reset batch tracking variables
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()
        # Calculate the average loss over the entire training data
        avg_train_loss = total_loss / len(train_dataloader)
        print("-"*70)
        # =======================================
        #               Evaluation
        # =======================================
        if evaluation and validation_dataloader is not None:
            model.eval()
            val_accuracy = []
            val_loss = []
            # Predictions and labels are pooled over the whole validation set:
            # per-batch reports lose a row whenever a batch lacks a class,
            # which would misalign the rows being averaged.
            all_labels = []
            all_preds = []
            for batch in validation_dataloader:
                b_input_ids, b_attn_mask, b_labels, h_features = tuple(t.to(device) for t in batch)
                with torch.no_grad():
                    logits = model(b_input_ids, b_attn_mask, h_features)
                    loss = loss_fn(logits, b_labels)
                val_loss.append(loss.item())
                # argmax over the raw logits equals argmax over their softmax.
                preds = torch.argmax(logits, dim=1).flatten()
                accuracy = (preds == b_labels).cpu().numpy().mean() * 100
                val_accuracy.append(accuracy)
                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(b_labels.cpu().flatten().tolist())
            val_loss = np.mean(val_loss)
            val_accuracy = np.mean(val_accuracy)
            time_elapsed = time.time() - t0_epoch
            print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}")
            print("-"*70)
            print("\n")
            if all_labels:
                report = classification_report(all_labels, all_preds)
                print(report_average([report]))
                print()
    # Printed once, after the last epoch (it used to be printed every epoch).
    print("Training complete!")
    

# Seed every RNG before the split so the train/validation partition is reproducible.
set_seed(42)

#Creating instances of training and validation set
dataset = CrisisDataset()

train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

batch_size = 16

# Training batches are reshuffled on every pass; validation order stays fixed.
train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=batch_size)
validation_dataloader = DataLoader(val_dataset, sampler=SequentialSampler(val_dataset), batch_size=batch_size)


set_seed(42)    # Set seed for reproducibility
bert_classifier, optimizer, scheduler = initialize_model(epochs=4)
train(bert_classifier, train_dataloader, validation_dataloader, epochs=4, evaluation=True)


In [51]:
from sklearn.model_selection import StratifiedKFold, cross_val_score

def get_auc_CV(model):
    """
    Cross-validate `model` and return its mean ROC-AUC.

    Uses a shuffled 5-fold stratified split (random_state=1) over the
    module-level `X_train_tfidf` / `y_train`.
    """
    folds = StratifiedKFold(5, shuffle=True, random_state=1)
    fold_scores = cross_val_score(
        model,
        X_train_tfidf,
        y_train,
        scoring="roc_auc",
        cv=folds,
    )
    return fold_scores.mean()

In [52]:
def text_preprocessing(text):
    """
    - Remove entity mentions (eg. '@united')
    - Correct errors (eg. '&amp;' to '&')
    - Collapse whitespace runs and trim both ends
    @param    text (str): a string to be processed.
    @return   text (Str): the processed string.
    """
    # Remove '@name'. The mention may be followed by whitespace or sit at the
    # very end of the string; requiring whitespace would miss the latter.
    text = re.sub(r'@\S*(?:\s|$)', ' ', text)
    # Replace '&amp;' with '&'
    text = re.sub(r'&amp;', '&', text)
    # Collapse whitespace runs and remove leading/trailing whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text

from transformers import BertTokenizer

# Tokenizer for the `bert-base-uncased` checkpoint, with lower-casing enabled.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case=True,
)

# Create a function to tokenize a set of texts
def preprocessing_for_bert(data):
    """Perform required preprocessing steps for pretrained BERT.

    Uses the module-level `tokenizer`, `MAX_LEN` and `text_preprocessing`.

    @param    data (np.array): Array of texts to be processed.
    @return   input_ids (torch.Tensor): Tensor of token ids to be fed to a model.
    @return   attention_masks (torch.Tensor): Tensor of indices specifying which
                  tokens should be attended to by the model.
    """
    # Create empty lists to store outputs
    input_ids = []
    attention_masks = []
    # For every sentence...
    for sent in data:
        # `encode_plus` will:
        #    (1) Tokenize the sentence
        #    (2) Add the `[CLS]` and `[SEP]` token to the start and end
        #    (3) Truncate/Pad sentence to max length
        #    (4) Map tokens to their IDs
        #    (5) Create attention mask
        #    (6) Return a dictionary of outputs
        encoded_sent = tokenizer.encode_plus(
            text=text_preprocessing(sent),  # Preprocess sentence
            add_special_tokens=True,        # Add `[CLS]` and `[SEP]`
            max_length=MAX_LEN,             # Max length to truncate/pad
            # Explicit padding/truncation replaces the deprecated
            # `pad_to_max_length=True`; every row must have exactly MAX_LEN
            # ids or the tensor conversion below fails.
            padding='max_length',
            truncation=True,
            return_attention_mask=True      # Return attention mask
            )
        # Add the outputs to the lists
        input_ids.append(encoded_sent['input_ids'])
        attention_masks.append(encoded_sent['attention_mask'])
    # Convert lists to tensors
    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)
    return input_ids, attention_masks

# Pool the tweets of the train and test frames into a single array
all_tweets = np.concatenate((data.tweet.values, test_data.tweet.values))
# Token ids (special tokens included) for every pooled tweet
encoded_tweets = [
    tokenizer.encode(sent, add_special_tokens=True)
    for sent in all_tweets
]
# Length of the longest encoded tweet
max_len = max(len(sent) for sent in encoded_tweets)
print('Max length: ', max_len)

# Length every sentence is padded/truncated to by `preprocessing_for_bert`
MAX_LEN = 64

# Show sentence 0 next to its encoded token ids
token_ids = list(preprocessing_for_bert([X[0]])[0].squeeze().numpy())
print('Original: ', X[0])
print('Token IDs: ', token_ids)

# Encode the train set and the validation set
print('Tokenizing data...')
train_inputs, train_masks = preprocessing_for_bert(X_train)
val_inputs, val_masks = preprocessing_for_bert(X_val)

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# For fine-tuning BERT, the authors recommend a batch size of 16 or 32.
batch_size = 32

# Labels as tensors, so they can sit next to the inputs in a TensorDataset
train_labels = torch.tensor(y_train)
val_labels = torch.tensor(y_val)

# Training set: batches are drawn in random order
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data,
    sampler=train_sampler,
    batch_size=batch_size,
)

# Validation set: batches are drawn in their stored order
val_data = TensorDataset(val_inputs, val_masks, val_labels)
val_sampler = SequentialSampler(val_data)
val_dataloader = DataLoader(
    val_data,
    sampler=val_sampler,
    batch_size=batch_size,
)

%%time
import torch
import torch.nn as nn
from transformers import BertModel

# Create the BertClassfier class
class BertClassifier(nn.Module):
    """BERT encoder followed by a small feed-forward classification head.
    """
    def __init__(self, freeze_bert=False):
        """
        @param    freeze_bert (bool): when True, gradients are disabled for
                      every BERT parameter; `False` fine-tunes the BERT model.
        """
        super().__init__()
        # BERT hidden size, classifier hidden size, number of labels
        bert_hidden, head_hidden, num_labels = 768, 50, 2
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Linear -> ReLU -> Linear head applied to the `[CLS]` vector
        head_layers = [
            nn.Linear(bert_hidden, head_hidden),
            nn.ReLU(),
            nn.Linear(head_hidden, num_labels),
        ]
        self.classifier = nn.Sequential(*head_layers)
        if freeze_bert:
            # Turns off `requires_grad` on all BERT parameters
            self.bert.requires_grad_(False)

    def forward(self, input_ids, attention_mask):
        """
        Run BERT, then the classification head, and return the logits.
        @param    input_ids (torch.Tensor): an input tensor with shape (batch_size,
                      max_length)
        @param    attention_mask (torch.Tensor): a tensor that hold attention mask
                      information with shape (batch_size, max_length)
        @return   logits (torch.Tensor): an output tensor with shape (batch_size,
                      num_labels)
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # First output element, position 0 of every sequence (the `[CLS]` token)
        cls_vector = encoded[0][:, 0, :]
        return self.classifier(cls_vector)
from transformers import AdamW, get_linear_schedule_with_warmup

def initialize_model(epochs=4):
    """Create the BertClassifier with its AdamW optimizer and linear LR schedule.

    Reads the module-level `train_dataloader` and `device`.

    @param    epochs (int): number of epochs the schedule should span.
    @return   tuple (bert_classifier, optimizer, scheduler).
    """
    # Fine-tuned (not frozen) classifier, moved to the target device
    model = BertClassifier(freeze_bert=False)
    model.to(device)
    adamw = AdamW(model.parameters(), lr=5e-5, eps=1e-8)
    # The schedule spans every training batch of every epoch, with no warm-up
    num_steps = len(train_dataloader) * epochs
    lr_schedule = get_linear_schedule_with_warmup(
        adamw,
        num_warmup_steps=0,
        num_training_steps=num_steps,
    )
    return model, adamw, lr_schedule


def set_seed(seed_value=42):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) for reproducibility.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed_value)

def train(model, train_dataloader, val_dataloader=None, epochs=4, evaluation=False):
    """Train the BertClassifier model.

    Relies on the module-level `optimizer`, `scheduler`, `loss_fn`, `device`
    and `evaluate`. Each batch must unpack to (input ids, attention mask, labels).

    @param    model: module called as `model(input_ids, attn_mask)`.
    @param    train_dataloader: iterable of training batches.
    @param    val_dataloader: iterable of validation batches, or None.
    @param    epochs (int): number of passes over `train_dataloader`.
    @param    evaluation (bool): validate after each epoch when True and a
                  validation loader was supplied.
    """
    # Start training loop
    print("Start training...\n")
    for epoch_i in range(epochs):
        # =======================================
        #               Training
        # =======================================
        # Print the header of the result table
        print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}")
        print("-"*70)
        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()
        # Reset tracking variables at the beginning of each epoch
        total_loss, batch_loss, batch_counts = 0, 0, 0
        # Put the model into the training mode
        model.train()
        # For each batch of training data...
        for step, batch in enumerate(train_dataloader):
            batch_counts += 1
            # Load batch to the target device
            b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)
            # Zero out any previously calculated gradients
            model.zero_grad()
            # Perform a forward pass. This will return logits.
            logits = model(b_input_ids, b_attn_mask)
            # Compute loss and accumulate the loss values
            loss = loss_fn(logits, b_labels)
            batch_loss += loss.item()
            total_loss += loss.item()
            # Perform a backward pass to calculate gradients
            loss.backward()
            # Clip the norm of the gradients to 1.0 to prevent "exploding gradients"
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # Update parameters and the learning rate
            optimizer.step()
            scheduler.step()
            # Print the loss values and time elapsed for every 20 batches
            if (step % 20 == 0 and step != 0) or (step == len(train_dataloader) - 1):
                # Calculate time elapsed for 20 batches
                time_elapsed = time.time() - t0_batch
                # Print training results
                print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^10} | {'-':^9} | {time_elapsed:^9.2f}")
                # Reset batch tracking variables
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()
        # Calculate the average loss over the entire training data
        avg_train_loss = total_loss / len(train_dataloader)
        print("-"*70)
        # =======================================
        #               Evaluation
        # =======================================
        if evaluation and val_dataloader is not None:
            # `evaluate` returns (loss, accuracy, report); only the first two
            # are needed here, so any trailing values are ignored instead of
            # breaking the unpacking.
            val_loss, val_accuracy, *_ = evaluate(model, val_dataloader)
            # Print performance over the entire training data
            time_elapsed = time.time() - t0_epoch
            print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}")
            print("-"*70)
        print("\n")
    print("Training complete!")


def evaluate(model, val_dataloader):
    """After the completion of each training epoch, measure the model's performance
    on our validation set.

    Relies on the module-level `loss_fn`, `device` and `classification_report`.

    @param    model: module called as `model(input_ids, attn_mask)`.
    @param    val_dataloader: iterable of (input ids, attention mask, labels) batches.
    @return   tuple (val_loss, val_accuracy, report): mean batch loss, mean
                  batch accuracy in percent, and the classification report
                  over all validation examples ('' when there were none).
    """
    # Put the model into the evaluation mode. The dropout layers are disabled during
    # the test time.
    model.eval()
    # Tracking variables
    val_accuracy = []
    val_loss = []
    # Labels and predictions are pooled so the report describes the whole
    # validation set rather than just the final batch.
    all_labels = []
    all_preds = []
    # For each batch in our validation set...
    for batch in val_dataloader:
        # Load batch to the target device
        b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)
        # Compute logits and loss without tracking gradients
        with torch.no_grad():
            logits = model(b_input_ids, b_attn_mask)
            loss = loss_fn(logits, b_labels)
        val_loss.append(loss.item())
        # argmax over the raw logits equals argmax over their softmax.
        preds = torch.argmax(logits, dim=1).flatten()
        accuracy = (preds == b_labels).cpu().numpy().mean() * 100
        val_accuracy.append(accuracy)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(b_labels.cpu().flatten().tolist())
    report = classification_report(all_labels, all_preds) if all_labels else ''
    # Compute the average accuracy and loss over the validation set.
    val_loss = np.mean(val_loss)
    val_accuracy = np.mean(val_accuracy)
    return val_loss, val_accuracy, report