## 0 - Import

In [None]:
%load_ext autoreload
%autoreload 2

import os
import time

import torch 
from torch import nn 
from torch.optim import AdamW, Adam
from transformers import RobertaTokenizerFast, RobertaConfig

from roberta import generate_data_loader, train, evaluate, plot_curves, RobertaExtraction, test

## 1 - Setup, Hyperparameters and Training

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Model, tokenizer and data ------------------------------------------------
root_path = './'
config = RobertaConfig.from_pretrained('distilroberta-base')
# NOTE(review): only the *config* is fetched here; confirm whether
# RobertaExtraction loads pretrained weights itself or starts from random init.
model = RobertaExtraction(config).to(device)
# NOTE(review): the tokenizer comes from 'roberta-base' while the config comes
# from 'distilroberta-base' -- presumably intentional; confirm the vocabularies match.
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
batch_size = 4

# os.path.join works whether or not root_path ends with a path separator.
dataset_dir = os.path.join(root_path, 'data', 'dataset')
train_path = os.path.join(dataset_dir, 'train_ds.pkl')
dev_path = os.path.join(dataset_dir, 'dev_ds.pkl')
test_path = os.path.join(dataset_dir, 'test_ds.pkl')
(train_dataset, train_data_loader,
 dev_dataset, dev_data_loader,
 test_dataset, test_data_loader) = generate_data_loader(
    batch_size, train_path, dev_path, test_path, tokenizer)

# --- Optimisation hyperparameters ---------------------------------------------
epochs = 4
accumulation_steps = 10  # forwarded to train(); presumably gradient-accumulation steps -- confirm
clip = 1.0               # forwarded to train(); presumably a gradient-clipping threshold -- confirm
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-05, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)

# Per-epoch metrics, handed to plot_curves() once training finishes.
train_loss_set = []
dev_loss_set = []
dev_f_beta_set = []
dev_f1_set = []

# The checkpoint is overwritten after every epoch, so it holds the latest
# weights. Create its directory before training: torch.save does not create
# missing parent directories, and failing only after a full epoch wastes the run.
checkpoint_path = os.path.join(root_path, 'roberta', 'roberta_qa_checkpoint.pt')
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(epochs):
    print(f'------------------------ \n Epochs {epoch+1}')
    # perf_counter is monotonic, so the duration is unaffected by clock changes.
    start_time = time.perf_counter()

    train_loss = train(model, device, accumulation_steps, train_dataset, train_data_loader, loss_fn, optimizer, batch_size, clip)
    dev_loss, dev_f_beta, dev_f1 = evaluate(model, device, tokenizer, dev_data_loader, loss_fn, beam_size=1, batch_size=batch_size)

    train_loss_set.append(train_loss)
    dev_loss_set.append(dev_loss)
    dev_f_beta_set.append(dev_f_beta)
    dev_f1_set.append(dev_f1)
    # Timing stops before the checkpoint write, as in the original cell.
    end_time = time.perf_counter()

    print(f'Epoch took : {end_time-start_time}')
    torch.save(model.state_dict(), checkpoint_path)

plot_curves(epochs, train_loss_set, dev_loss_set, dev_f_beta_set, dev_f1_set)

## 2 - Evaluation on the Test Set

In [None]:
# Score the model currently in memory (the last-epoch weights) on the test split.
# NOTE(review): the second positional argument of test() is presumably the beam
# width, matching evaluate(..., beam_size = 1) in the training cell -- confirm.
test_beam_size = 1
F_05_Score, pred_by_doc = test(
    tokenizer,
    test_beam_size,
    model,
    device,
    test_data_loader,
    test_dataset,
)
print(F_05_Score)

In [None]:
# idc = 437

# print(pred_by_doc[idc]['truth'])
# pred_by_doc[idc]['pred']