# Install packages and import statements

In [None]:
!pip install torchinfo datasets

In [None]:
pip install transformers==3

In [None]:
import datasets
import torch
import torch.nn as nn
import numpy as np
from transformers import BertTokenizer, BertModel
from torchinfo import summary
from tqdm import tqdm
import ipywidgets as widgets
from sklearn.metrics import accuracy_score, f1_score
from collections import defaultdict
import time

In [None]:
# Seed handed to torch.manual_seed() at the start of every compute() run.
random_seed = 42
# Number of train/validate passes performed by compute() for each run.
epochs = 10

# Data source, load datasets from huggingface

In [None]:
def get_dataset(dataset_name: str = None, path_name: str = None):
    """Load a Hugging Face dataset and report the splits it contains.

    Args:
        dataset_name: Name of the dataset; must be present in the
            module-level ``datasets_list``.
        path_name: Forwarded as the second positional argument of
            ``datasets.load_dataset`` (the notebook passes the selected
            configuration name here).

    Returns:
        A ``(loaded, split_names)`` pair: the object returned by
        ``datasets.load_dataset`` and a list of its keys.

    Raises:
        AssertionError: if ``dataset_name`` is None or not in ``datasets_list``.
    """
    assert dataset_name is not None, "Dataset Name is required"
    assert dataset_name in datasets_list, "Invalid dataset name. Full list is \n"+str(datasets_list)

    loaded = datasets.load_dataset(dataset_name, path_name)
    split_names = [split for split in loaded.keys()]
    return loaded, split_names

In [None]:
import datasets

# Every dataset name known to the hub; get_dataset() validates against this list.
datasets_list = datasets.list_datasets()

print("List of datasets")
# Dropdown used to pick the dataset; its value is read in the next cell.
dataset_name_widget = widgets.Dropdown(options=datasets_list)
display(dataset_name_widget)

In [None]:
# Snapshot the dropdown's current selection; re-run this cell after changing it.
dataset_name=dataset_name_widget.value

In [None]:
config_name=None
try:
    # Succeeds directly when the dataset needs no configuration name.
    dataset, splits=get_dataset(dataset_name)
except ValueError as e:
    # The available configurations are recovered from the bracketed list in
    # the error text. If there is no such list, this ValueError is about
    # something else: re-raise it instead of masking it with a parsing failure.
    err=str(e)
    start=err.find('[')
    end=err.find(']', start+1)
    if start==-1 or end==-1:
        raise
    print("Select a configuration")
    # Entries look like 'name' / "name": trim whitespace and surrounding quotes.
    config_list=[c.strip().strip("'\"") for c in err[start+1:end].split(',')]
    config_list=[c for c in config_list if c]
    config_list_widget=widgets.Dropdown(
        options=config_list,
    )
    display(config_list_widget)
    

In [None]:
# The configuration dropdown only exists when the previous cell had to ask for
# one. Without it, `dataset` and `splits` were already loaded there, and reading
# the widget would raise NameError on a fresh Restart & Run All.
# NOTE: a dropdown left over from an earlier dataset selection is still
# honoured; restart the kernel when switching datasets.
if "config_list_widget" in globals():
    config_name = config_list_widget.value
    dataset, splits = get_dataset(dataset_name, config_name)

print("Available splits are:", splits)

In [None]:
# Change split names as needed: the keys below are hard-coded and must match
# the names printed by "Available splits are:" for the selected dataset.
train_data=dataset['train']
val_data=dataset['validation']
# NOTE(review): test_data is not referenced by any later cell shown here.
test_data=dataset['test']

### Convert data to required format

In [None]:
def encode(examples):
    """Tokenize a pair of text columns from a batch of examples.

    Reads the module-level ``tokenizer``, ``key_1`` and ``key_2``; the two keys
    name the columns of ``examples`` passed as first and second text input.
    Every pair is truncated ("longest_first") and padded to 100 tokens.

    Returns:
        Whatever ``tokenizer`` returns for the two columns.
    """
    first_texts = examples[key_1]
    second_texts = examples[key_2]
    return tokenizer(
        first_texts,
        second_texts,
        truncation="longest_first",
        padding='max_length',
        max_length=100,
    )


In [None]:
def convert_data(batch_size: int, field_names: list, data):
    """Tokenize one dataset split and wrap it in a torch DataLoader.

    Performs the same steps as the notebook's per-split preparation cells:
    tokenize the two text columns, copy 'label' to 'labels', switch the split to
    torch format, rename the tokenizer outputs to the names the training loop
    reads ('sent_id', 'mask'), and drop the columns that are no longer needed.

    Args:
        batch_size: Batch size of the returned DataLoader.
        field_names: The two text column names, in the order they are passed
            to the module-level ``tokenizer``.
        data: A dataset split providing ``map``, ``set_format``,
            ``rename_column``, ``remove_columns`` and ``column_names``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the converted split. It is not
        shuffled, so batch order matches row order; the converted split itself
        is available as ``dataloader.dataset``.
    """
    first_field, second_field = field_names

    def _tokenize(examples):
        return tokenizer(examples[first_field], examples[second_field],
                         truncation="longest_first", padding='max_length', max_length=100)

    data = data.map(_tokenize, batched=True)
    data = data.map(lambda examples: {'labels': examples['label']}, batched=True)
    data.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
    data = data.rename_column("attention_mask", "mask")
    data = data.rename_column("input_ids", "sent_id")
    # Only drop columns that exist, so splits without an 'idx' column also work.
    removable = [name for name in ["idx", *field_names] if name in data.column_names]
    data = data.remove_columns(removable)

    dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size)
    return dataloader

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
batch_size = 128
# Names of the two text columns that encode() reads from each batch of examples.
key_1, key_2 = "sentence1", "sentence2"

# Tokenize, then mirror the 'label' column under the name 'labels'.
train_data = (
    train_data
    .map(encode, batched=True)
    .map(lambda examples: {'labels': examples['label']}, batched=True)
)
train_data.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
# Rename to the keys the training loop looks up, then drop the raw columns.
train_data = (
    train_data
    .rename_column("attention_mask", "mask")
    .rename_column("input_ids", "sent_id")
    .remove_columns(["idx", key_1, key_2])
)
# No shuffling: compute() compares predictions with train_data['labels'] by position.
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)

In [None]:
# Same preparation as the training split: tokenize and mirror 'label' as 'labels'.
val_data = (
    val_data
    .map(encode, batched=True)
    .map(lambda examples: {'labels': examples['label']}, batched=True)
)
val_data.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
# Rename to the keys the evaluation loop looks up, then drop the raw columns.
val_data = (
    val_data
    .rename_column("attention_mask", "mask")
    .rename_column("input_ids", "sent_id")
    .remove_columns(["idx", key_1, key_2])
)
# Batch order must match row order: compute() scores against val_data['labels'].
val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)

# Build model

## Get model from huggingface

In [None]:
# These two imports repeat the ones in the notebook's import cell.
import torch.nn as nn
from transformers import BertTokenizer, BertModel
# Rebinds `tokenizer` with the same checkpoint the data-preparation cell loaded.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Pretrained encoder that freeze_model() and Classifier are applied to later.
base_model = BertModel.from_pretrained("bert-base-uncased")

In [None]:
# Repeats the import from the notebook's import cell.
from torchinfo import summary
# Last expression of the cell, so the summary object is rendered as the cell output.
summary(base_model)

## Freeze layers of the Transformer model

In [None]:
def freeze_model(model, freeze_layer_count: int = 0):
    '''
    Disable gradients for the embeddings and the first encoder layers, in place.

    freeze_layer_count == 0  -> nothing is changed.
    freeze_layer_count == -1 -> only the embedding parameters are frozen.
    any other value          -> embeddings plus model.encoder.layer[:freeze_layer_count].

    Parameters that are already frozen are never re-enabled.
    Returns the same model object that was passed in.
    '''
    if not freeze_layer_count:
        return model

    def _freeze(module):
        for weight in module.parameters():
            weight.requires_grad = False

    # Embeddings are frozen for every non-zero count.
    _freeze(model.embeddings)
    if freeze_layer_count == -1:
        return model

    for block in model.encoder.layer[:freeze_layer_count]:
        _freeze(block)

    return model

## Define model architecture on top of base transformer model

In [None]:
class Classifier(nn.Module):
    """Two-layer classification head on top of a BERT-style encoder.

    ``forward`` returns unnormalised class scores (logits) of shape
    (batch, num_classes). Logits are what ``nn.CrossEntropyLoss`` expects, since
    it applies log-softmax itself; feeding it softmax output squashes the
    gradients. ``argmax`` over logits and over probabilities picks the same
    class, so prediction code is unaffected. Call ``self.softmax`` on the output
    when probabilities are needed.
    """

    def __init__(self, base_model, num_classes):
        """
        Args:
            base_model: Encoder whose call returns a sequence whose second item
                is the pooled representation fed to the head.
            num_classes: Number of output classes.
        """
        super(Classifier, self).__init__()

        self.bert = base_model
        # Use the encoder's own hidden size when it exposes one; 768 otherwise.
        encoder_config = getattr(base_model, "config", None)
        hidden_size = getattr(encoder_config, "hidden_size", 768)

        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(hidden_size, 256)
        self.fc2 = nn.Linear(256, num_classes)
        # Not applied in forward(); kept so callers can turn logits into probabilities.
        self.softmax = nn.Softmax(dim=1)

    # define the forward pass
    def forward(self, sent_id, mask):
        """Return logits for a batch of token ids and their attention mask."""
        outputs = self.bert(sent_id, attention_mask=mask)
        # Index instead of unpacking so outputs with more than two items also work.
        cls_hs = outputs[1]

        x = self.fc1(cls_hs)
        x = self.relu(x)
        # The dropout layer was declared but never applied before.
        x = self.dropout(x)
        x = self.fc2(x)

        return x

In [None]:
# Output size for the Classifier head, read from the dataset's 'label' feature.
# NOTE(review): assumes that feature exposes `num_classes` — confirm for datasets
# whose labels are not class-label typed.
num_classes=train_data.features['label'].num_classes

# Training and validation

In [None]:
#function for training the model
def train():
    """Run one training epoch over the module-level ``train_dataloader``.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and ``device``.
    Each batch is a dict providing 'labels', 'mask' and 'sent_id'.

    Returns:
        (avg_loss, total_preds): the mean of the per-batch losses, and a NumPy
        array made by concatenating every batch's model output along axis 0.
    """
    model.train()

    batch_outputs = []
    running_loss = 0.0

    for batch in tqdm(train_dataloader):
        # Move every tensor of the batch to the target device.
        on_device = {name: tensor.to(device) for name, tensor in batch.items()}
        labels = on_device['labels']
        mask = on_device['mask']
        sent_id = on_device['sent_id']

        # Forward pass on fresh gradients.
        model.zero_grad()
        preds = model(sent_id, mask)
        loss = criterion(preds, labels)
        running_loss += loss.item()

        # Backward pass; gradient norm is clipped to 1.0 before the update.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Keep a detached CPU copy of this batch's outputs.
        batch_outputs.append(preds.detach().cpu().numpy())

    avg_loss = running_loss / len(train_dataloader)

    # (batches, batch size, classes) -> (samples, classes)
    total_preds = np.concatenate(batch_outputs, axis=0)

    return avg_loss, total_preds

In [None]:
# function for evaluating the model
def evaluate():
    """Score the module-level ``model`` on ``val_dataloader`` without updating it.

    Uses the module-level ``model``, ``criterion`` and ``device``. The model is
    switched to eval mode and every forward pass runs under ``torch.no_grad()``.

    Returns:
        (avg_loss, total_preds): the mean of the per-batch losses, and a NumPy
        array made by concatenating every batch's model output along axis 0.
    """
    print("\nEvaluating...")

    # Eval mode turns off dropout layers.
    model.eval()

    running_loss = 0
    batch_outputs = []

    for batch in tqdm(val_dataloader):
        # Move every tensor of the batch to the target device.
        on_device = {name: tensor.to(device) for name, tensor in batch.items()}
        labels = on_device['labels']
        mask = on_device['mask']
        sent_id = on_device['sent_id']

        # No autograd bookkeeping is needed for validation.
        with torch.no_grad():
            preds = model(sent_id, mask)
            loss = criterion(preds, labels)

        running_loss += loss.item()
        batch_outputs.append(preds.detach().cpu().numpy())

    avg_loss = running_loss / len(val_dataloader)

    # (batches, batch size, classes) -> (samples, classes)
    total_preds = np.concatenate(batch_outputs, axis=0)

    return avg_loss, total_preds

In [None]:
#Function to combine training and validation loops
def compute():
    """Train and validate the module-level ``model`` for ``epochs`` epochs.

    Relies on the module-level ``train()``, ``evaluate()``, ``train_data``,
    ``val_data``, ``model``, ``epochs`` and ``random_seed``. Whenever the
    validation loss improves, the model's state dict is written to
    'saved_weights.pt' (overwriting any previous file).

    Returns:
        (training_metrics, validation_metrics, times):
        the first two map epoch index -> [accuracy %, F1 %]; ``times`` maps
        epoch index -> wall-clock seconds for that epoch (train + validate).
    """
    # set initial loss to infinite
    best_valid_loss = float('inf')
    training_metrics, validation_metrics=defaultdict(list), defaultdict(list)
    times=defaultdict(float)
    torch.manual_seed(random_seed)
    train_labels=train_data['labels']
    valid_labels=val_data['labels']

    # f1_score defaults to average='binary', which raises on multiclass targets.
    # Keep 'binary' for two-class data and fall back to macro-averaging otherwise.
    label_count = np.unique(np.asarray(train_labels)).size
    f1_average = 'binary' if label_count <= 2 else 'macro'

    #for each epoch
    for epoch in range(epochs):
        start_time=time.perf_counter()
        print('\nEpoch {:} / {:}'.format(epoch + 1, epochs))

        #train model
        train_loss, train_preds = train()
        train_preds= torch.argmax(torch.tensor(train_preds),dim=1)
        train_acc=accuracy_score(train_labels, train_preds)*100
        train_f1=f1_score(train_labels, train_preds, average=f1_average)*100
        print("\nTraining metrics \n Accuracy: {:.3f}, F-1 Score: {:.3f}".format(train_acc, train_f1))

        #evaluate model
        valid_loss, valid_preds = evaluate()
        valid_preds= torch.argmax(torch.tensor(valid_preds),dim=1)
        valid_acc=accuracy_score(valid_labels, valid_preds)*100
        valid_f1=f1_score(valid_labels, valid_preds, average=f1_average)*100
        print("\nValidation metrics \n Accuracy: {:.3f}, F-1 Score: {:.3f}".format(valid_acc, valid_f1))
        end_time=time.perf_counter()
        times[epoch]=end_time-start_time

        #save the best model
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'saved_weights.pt')

        training_metrics[epoch]=[train_acc, train_f1]
        validation_metrics[epoch]=[valid_acc, valid_f1]

    return training_metrics, validation_metrics, times

# Testing and metrics

In [None]:
import copy

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Maps number of frozen encoder layers -> tuple returned by compute().
metrics_per_run=defaultdict(dict)
for n in range(2,6,2):
    # Seed before building the model so the classifier head is initialised
    # identically for every run.
    torch.manual_seed(random_seed)
    # Every run starts from its own copy of the pretrained encoder. Reusing the
    # same object would carry the previous run's fine-tuned weights and frozen
    # flags into this one, making the runs incomparable.
    run_base_model=freeze_model(copy.deepcopy(base_model), freeze_layer_count=n)
    model=Classifier(run_base_model,num_classes=num_classes)
    model.to(device)
    lr=1e-4
    # Only parameters left trainable by freeze_model() are given to the optimizer.
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = lr, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()
    print("\nTraining with {} layers frozen".format(n))
    metrics_per_run[n]=compute()


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
fig,ax=plt.subplots()
for key in sorted(list(metrics_per_run.keys())):
    accs=metrics_per_run[key][0]
    f1s=metrics_per_run[key][1]
    ax.set_xlabel("Epochs")
    # ax.set_ylabel("%")
    # ax.plot(range(epochs), [accs[i][0]for i in range(epochs)], label='T_A_'+str(key))
    ax.plot(range(epochs), [accs[i][1]for i in range(epochs)], label='V_A_'+str(key))
    # ax.plot(range(epochs), [f1s[i][0]for i in range(epochs)], label='T_F1_'+str(key))
    ax.plot(range(epochs), [f1s[i][1]for i in range(epochs)], label='V_F1_'+str(key))
    plt.legend(loc="best")
plt.show()

