In [1]:
import os
import shutil
from collections import Counter
import numpy as np
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, ElectraForQuestionAnswering, DataCollatorWithPadding,BertModel, ElectraForSequenceClassification, ElectraModel
from Preprocess.arabertpreprocess import ArabertPreprocessor
import matplotlib.pyplot as plt
import seaborn as sns
import csv
torch.manual_seed(3407)

<torch._C.Generator at 0x7f3de9290130>

## Preprocessing

In [2]:
def add_end_index(answer, context):
  ## 1 if span match the context 0 otherwise
  text = answer['text']
  start_idx = answer['answer_start']
  end_idx = start_idx + len(text)
  answer['answer_end'] = end_idx
  if text == context[start_idx:end_idx]:
    answer['answer_end'] = end_idx
    return False
  for i in range(1,3):
    if text == context[start_idx-i:end_idx-i]:
      answer['answer_end']= end_idx-1
      answer['answer_start'] = start_idx-1
      return False
  return True

In [3]:
def arabert_preprocess(context,question, answer, arabert_prep):
    answer['text'] = arabert_prep.preprocess(answer['text'])
    context = arabert_prep.preprocess(context)
    question = arabert_prep.preprocess(question)
    res = context.find(answer['text'])
    if res !=-1:
        answer['answer_start'] = res
    return context, question, answer, res

In [4]:
def Read_AAQAD(path,arabert_prep):
  contexts =[]
  answers =[]
  questions =[]
  IDs= []
  plausible = []
  cnt = 0
  with open(path) as f:
    aaqad_dict = json.load(f)
    for article in aaqad_dict['data']:
      for passage in article['paragraphs']:
        context = passage['context']
        for qa in passage['qas']:
          question = qa['question']
          if 'plausible_answers' in qa.keys():# there is two cases if the question have no answer then use plausible answer
            access = 'plausible_answers'
            plausible.append(True)
          else:
            access = 'answers'
            plausible.append(False)
          for answer in qa[access]:
            context,question, answer, res =  arabert_preprocess(context,question, answer, arabert_prep)
            #if res==-1:
            #  cnt+=1
            #  continue
            flag = add_end_index(answer, context) #if false dont add the 
            cnt =cnt + flag
            flag = False
            if not flag:
              contexts.append(context)
              answers.append(answer)
              questions.append(question)
              IDs.append(int(qa['id']))
  return contexts,questions,answers,plausible,IDs

In [5]:
def fix_ids(path):
  #IDs need to be fixed for evaluating purposes
    a_file = open(path, "r")
    json_object = json.load(a_file)
    a_file.close()
    idx_cnt = 1
    for article in json_object['data']:
      for passage in article['paragraphs']:
        context = passage['context']
        for qa in passage['qas']:
            qa['id'] = str(idx_cnt)
            idx_cnt = idx_cnt + 1
    a_file = open(path, "w")
    json.dump(json_object, a_file)
    a_file.close()

In [6]:
model_name = "araelectra-base-discriminator"
arabert_prep = ArabertPreprocessor(model_name=model_name)
fix_ids('Data/asquadv2-train.json')
fix_ids('Data/asquadv2-val.json')
fix_ids('Data/asquadv2-test.json')
aqad_train_contexts, aqad_train_questions, aqad_train_answers,aqad_train_plausible, aqad_train_ids = Read_AAQAD('Data/asquadv2-train.json', arabert_prep)
aqad_val_contexts, aqad_val_questions, aqad_val_answers,aqad_val_plausible, aqad_val_ids = Read_AAQAD('Data/asquadv2-val.json', arabert_prep)
aqad_test_contexts, aqad_test_questions, aqad_test_answers,aqad_test_plausible, aqad_test_ids = Read_AAQAD('Data/asquadv2-test.json', arabert_prep)


In [7]:
print(sum(aqad_train_ids)==len(aqad_train_ids)*(len(aqad_train_ids)+1)/2)
print(sum(aqad_val_ids)==len(aqad_val_ids)*(len(aqad_val_ids)+1)/2)
print(sum(aqad_test_ids)==len(aqad_test_ids)*(len(aqad_test_ids)+1)/2)

True
True
True


In [8]:
def get_answered_feat(contexts, questions, answers, plausible):
    new_contexts, new_questions, new_answers = [], [], []
    for i in range(len(answers)):
        if plausible[i] == False:
            new_contexts.append(contexts[i])
            new_questions.append(questions[i])
            new_answers.append(answers[i])
    return new_contexts, new_questions, new_answers
span_train_contexts, span_train_questions, span_train_answers = get_answered_feat(aqad_train_contexts, aqad_train_questions, aqad_train_answers, aqad_train_plausible)   

In [12]:
print(len(span_train_contexts), len(aqad_train_contexts))
print(sum(aqad_test_plausible))

42042 76840
4350


## Tokenization

In [13]:
#Creating the tokenizer
model_name = model_name = "wissamantoun/araelectra-base-artydiqa"
araelectra_tokenizer = AutoTokenizer.from_pretrained(model_name,do_lower_case=False)

In [14]:
train_encodings = araelectra_tokenizer(aqad_train_questions, aqad_train_contexts, truncation = True )
span_train_encodings = araelectra_tokenizer(span_train_questions, span_train_contexts, truncation=True)
val_encodings = araelectra_tokenizer(aqad_val_questions, aqad_val_contexts, truncation=True, return_offsets_mapping=True)
test_encodings = araelectra_tokenizer(aqad_test_questions, aqad_test_contexts, truncation=True,  return_offsets_mapping=True)

In [15]:
val_offset = val_encodings['offset_mapping']
del val_encodings['offset_mapping']
test_offset = test_encodings['offset_mapping']
del test_encodings['offset_mapping']

In [16]:
val_ids_to_idx = {k:i for i,k in enumerate(aqad_val_ids)}
test_ids_to_idx = {k:i for i,k in enumerate(aqad_test_ids)}


In [17]:
def index_to_token_position(encodings , answers):
  start_positions = list()
  end_positions = list()
  for i in range(len(answers)):
    start_positions.append(encodings.char_to_token(i, answers[i]['answer_start'], 1))
    end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'], 1))
    #if context truncated
    if start_positions[-1] is None: 
      start_positions[-1] = araelectra_tokenizer.model_max_length
    #if end index is space
    itt = 1
    while end_positions[-1] is None: 
      end_positions[-1] = encodings.char_to_token(i, answers[i]['answer_end']-itt, 1)
      itt = itt + 1 
  encodings.update({'start_positions': torch.tensor(start_positions), 'end_positions': torch.tensor(end_positions)})
  encodings['start_positions'] = encodings['start_positions'].view(len(answers), 1)
  encodings['end_positions'] = encodings['end_positions'].view(len(answers), 1)

In [18]:
index_to_token_position(span_train_encodings, span_train_answers)
index_to_token_position(val_encodings, aqad_val_answers)
index_to_token_position(test_encodings, aqad_test_answers)

In [19]:
def add_weights_labels_tensors(encodings, plausible):
  plausible = torch.tensor(plausible)
  weights = torch.ones(plausible.shape)
  no_ans = torch.ones(plausible.shape)
  weights[plausible==False]=2.0
  no_ans[plausible==False]=0.0
  weights = weights.view(-1,1)
  no_ans = no_ans.view(-1,1)
  encodings.update({'weights':weights, 'no_ans':no_ans})

In [20]:
add_weights_labels_tensors(train_encodings, aqad_train_plausible)
add_weights_labels_tensors(val_encodings, aqad_val_plausible)
add_weights_labels_tensors(test_encodings, aqad_test_plausible)

In [21]:
val_encodings['IDs'] = aqad_val_ids
test_encodings['IDs'] = aqad_test_ids

In [22]:
print(train_encodings.keys())
print(val_encodings.keys())
print(test_encodings.keys())
print(span_train_encodings.keys())

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'weights', 'no_ans'])
dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions', 'weights', 'no_ans', 'IDs'])
dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions', 'weights', 'no_ans', 'IDs'])
dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'])


## Dataset and DataLoader

In [23]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

In [24]:
class AqadDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

cls_train_dataset = AqadDataset(train_encodings)
span_train_dataset = AqadDataset(span_train_encodings)
val_dataset = AqadDataset(val_encodings)
test_dataset = AqadDataset(test_encodings)

In [25]:
data_collator = DataCollatorWithPadding(araelectra_tokenizer)

In [26]:
cls_train_loader = DataLoader(cls_train_dataset, batch_size=8, shuffle= True, collate_fn= data_collator)
span_train_loader = DataLoader(span_train_dataset, batch_size=8, shuffle= True, collate_fn = data_collator)
val_loader = DataLoader(val_dataset, batch_size = 8, shuffle = True, collate_fn = data_collator)
test_loader = DataLoader(test_dataset, batch_size = 8, shuffle = True, collate_fn = data_collator)


## Checkpoints

In [28]:
def save_ckp(state, is_best, checkpoint_path, best_model_path):
    """
    state: checkpoint to save
    is_best: is this the best checkpoint; min validation loss
    checkpoint_path: path to save checkpoint
    best_model_path: path to save best checkpoint
    """
    f_path = checkpoint_path
    # save checkpoint data to the path given, checkpoint_path
    torch.save(state, f_path)
    # if it is a best model, min validation loss
    if is_best:
        best_fpath = best_model_path
        # copy that checkpoint file to best path given, best_model_path
        shutil.copyfile(f_path, best_fpath)

In [29]:
def load_ckp(checkpoint_fpath, model, optimizer):
    """
    checkpoint_path: path to saved checkpoint
    model: model to load checkpoint parameters into       
    optimizer: optimizer defined in previous training
    """
    # load check point
    checkpoint = torch.load(checkpoint_fpath)
    # initialize state_dict from checkpoint to model
    model.load_state_dict(checkpoint['state_dict'])
    # initialize optimizer from checkpoint to optimizer
    optimizer.load_state_dict(checkpoint['optimizer'])
    # initialize valid_loss_min from checkpoint to valid_loss_min
    results = checkpoint['result_dict']
    # return model, optimizer, epoch value, min validation loss 
    return model, optimizer, checkpoint['epoch'], results

In [30]:
def order_exp(base_path, exp_name):
  exp_path = os.path.join(base_path, exp_name)
  if not os.path.exists(exp_path):
    os.mkdir(exp_path)
  curr_ckp_path = os.path.join(exp_path,'curr.pt')
  best_ckp_path = os.path.join(exp_path, 'best.pt')
  return curr_ckp_path, best_ckp_path, exp_path

## Classification train and evaluation 

In [31]:
def cls_eval(model, data_loader, exp_path, train_loss):
    model.eval()
    total_acc = 0
    total_pred = None
    total_IDs = None
    soft = torch.nn.Softmax(dim=1)
    for batch in data_loader:
      tokens = batch['input_ids'].to(device)
      masks = batch['attention_mask'].to(device)
      tokens_type = batch['token_type_ids'].to(device)
      IDs = batch['IDs'].to(device)
      gt_no_ans = batch['no_ans'].to(device)
      output = model(tokens, masks, tokens_type)
      pred = output.logits.view(masks.shape[0],2,)
      if total_pred is None:
          total_pred = soft(pred)[:,1]
          total_IDs = IDs
      else:
          total_pred = torch.cat((total_pred, soft(pred)[:,1]), 0)
          total_IDs = torch.cat((total_IDs, IDs), 0)
      pred = torch.argmax(pred, dim=1)
      target = batch['no_ans'].to(device).view(masks.shape[0],)
      total_acc += torch.sum(target==pred)

    total_acc = total_acc/ val_dataset.__len__()
    res_dict = {'acc':total_acc.item()*100, 'train_loss':train_loss}
    if exp_path:
        log_path = os.path.join(exp_path,'res.csv')
        if not os.path.exists(log_path):
            with open(log_path,'w') as f:
                writer = csv.DictWriter(f, fieldnames=res_dict.keys())
                writer.writeheader()
        with open(log_path, 'a') as f:
            writer = csv.DictWriter(f, fieldnames=res_dict.keys())
            #writer.writeheader()
            writer.writerow(res_dict)
    return res_dict, {"IDs":total_IDs, "preds":total_pred}


In [32]:
def cls_train(model,start_epoch, num_epochs, optimizer,max_acc, train_loader, val_loader, log, exp_name):
  curr_ckp_path, best_ckp_path, exp_path = order_exp('Runs/AraElectraDecoupledAsquadv2/train/cls', exp_name)
  model.train()
  cls_pred = None
  for epoch in range(start_epoch,num_epochs):
    total_loss = 0.0
    loop = tqdm(train_loader, leave=True)
    for batch_idx, batch in enumerate(loop):
      tokens = batch['input_ids'].to(device)
      masks = batch['attention_mask'].to(device)
      tokens_type = batch['token_type_ids'].to(device)
      weights = batch['weights'].to(device)
      output = model(tokens, masks, tokens_type)
      pred = output.logits.view(masks.shape[0],2,)
      target = batch['no_ans'].type(torch.LongTensor)
      target = target.to(device)
      loss = cls_criterion(pred, target.view(masks.shape[0],))
      loss = loss*(weights.view(masks.shape[0],))
      loss = torch.mean(loss)
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()
      total_loss = total_loss + ((1 / (batch_idx + 1)) * (loss.item() - total_loss)) 
      loop.set_description(f'Epoch {epoch}')
      loop.set_postfix(loss=loss.item())

    result_dict, cls_pred = cls_eval(model, val_loader,exp_path,total_loss )
    checkpoint = {
            'epoch': epoch + 1,
            'result_dict':result_dict,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
    curr_acc = result_dict['acc']
    if curr_acc>=max_acc:
      max_acc = curr_acc
      save_ckp(checkpoint, True, curr_ckp_path, best_ckp_path)
    else:
      save_ckp(checkpoint, False, curr_ckp_path, best_ckp_path)
    print(result_dict)
  return model, cls_pred


## Modeling

In [33]:
Cls_AraElectra = ElectraForSequenceClassification.from_pretrained(model_name, num_labels=2)
QA_AraElectra = ElectraForQuestionAnswering.from_pretrained(model_name)

Some weights of the model checkpoint at wissamantoun/araelectra-base-artydiqa were not used when initializing ElectraForSequenceClassification: ['qa_outputs.weight', 'qa_outputs.bias']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at wissamantoun/araelectra-base-artydiqa and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']
You should probably TRAIN this model on a

In [42]:
def freeze(Electra, count=None):
    if count is not None:
	      # We freeze here the embeddings of the model
        for param in Electra.embeddings.parameters():
            param.requires_grad = False

        if count != -1:
	          # if freeze_layer_count == -1, we only freeze the embedding layer
	          # otherwise we freeze the first `freeze_layer_count` encoder layers
            for layer in Electra.encoder.layer[:count]:
                for param in layer.parameters():
                    param.requires_grad = False
    print(sum(p.numel() for p in Electra.parameters()), sum(p.numel() for p in Electra.parameters() if p.requires_grad))

In [30]:
freeze(Cls_AraElectra.electra,6)
freeze(QA_AraElectra.electra, 6)

134602752 42527232
134602752 42527232


## Classification Training

In [31]:
num_epochs = 2
learning_rate = 3e-5
optimizer = torch.optim.Adam(Cls_AraElectra.parameters(), lr=learning_rate)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
#criterion_span = nn.CrossEntropyLoss(reduction='none')
cls_criterion = nn.CrossEntropyLoss(reduction='none')
Cls_AraElectra.to(device)

ElectraForSequenceClassification(
  (electra): ElectraModel(
    (embeddings): ElectraEmbeddings(
      (word_embeddings): Embedding(64000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ElectraEncoder(
      (layer): ModuleList(
        (0): ElectraLayer(
          (attention): ElectraAttention(
            (self): ElectraSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ElectraSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm

In [46]:
cls_trained_model, cls_pred = cls_train(Cls_AraElectra, 1, 2, optimizer, 0, cls_train_loader, val_loader , True, 'second')

Epoch 1: 100%|██████████| 9605/9605 [1:23:13<00:00,  1.92it/s, loss=0.831] 


{'acc': 82.08224773406982, 'train_loss': 0.48609214223793196}


In [48]:
cls_trained_model, cls_pred = cls_train(cls_trained_model, 2, 4, optimizer, 82, cls_train_loader, val_loader , True, 'second')

Epoch 2: 100%|██████████| 9605/9605 [1:23:06<00:00,  1.93it/s, loss=0.44]   
Epoch 3:  95%|█████████▍| 9078/9605 [1:16:59<05:47,  1.52it/s, loss=0.546]  Epoch 3:  98%|█████████▊| 9365/9605 [1:19:26<02:16,  1.76it/s, loss=0.236] Epoch 3:  98%|█████████▊| 9371/9605 [1:19:29<02:02,  1.91it/s, loss=0.362]

{'acc': 83.96668434143066, 'train_loss': 0.3634232084243571}


In [49]:
cls_trained_model, cls_pred = cls_train(cls_trained_model, 5, 8, optimizer, 84.3, cls_train_loader, val_loader , True, 'second')

Epoch 5: 100%|██████████| 9605/9605 [1:23:19<00:00,  1.92it/s, loss=0.0704] 
Epoch 6: 100%|██████████| 9605/9605 [1:21:32<00:00,  1.96it/s, loss=0.433]   
Epoch 7: 100%|██████████| 9605/9605 [1:21:26<00:00,  1.97it/s, loss=0.00571] 


{'acc': 84.14367437362671, 'train_loss': 0.19663139513469785}
{'acc': 83.68558287620544, 'train_loss': 0.09591675791153734}
{'acc': 83.75846147537231, 'train_loss': 0.08394046925884102}


## Load Cls Model if needed

In [34]:
cls_model = ElectraForSequenceClassification.from_pretrained(model_name)

Some weights of the model checkpoint at wissamantoun/araelectra-base-artydiqa were not used when initializing ElectraForSequenceClassification: ['qa_outputs.weight', 'qa_outputs.bias']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at wissamantoun/araelectra-base-artydiqa and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']
You should probably TRAIN this model on a

In [35]:
learning_rate = 3e-5
optimizer = torch.optim.AdamW(cls_model.parameters(), lr=learning_rate)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
#criterion_span = nn.CrossEntropyLoss(reduction='none')
#cls_criterion = nn.CrossEntropyLoss()
cls_model.to(device)

ElectraForSequenceClassification(
  (electra): ElectraModel(
    (embeddings): ElectraEmbeddings(
      (word_embeddings): Embedding(64000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ElectraEncoder(
      (layer): ModuleList(
        (0): ElectraLayer(
          (attention): ElectraAttention(
            (self): ElectraSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ElectraSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm

In [36]:
cls_model, optimizer, start_epoch, result_dict = load_ckp('Runs/AraElectraDecoupledAsquadv2/train/cls/second/best.pt', cls_model, optimizer)


## Span Training

In [37]:
def get_raw_preds(data_loader, model,ids_to_index,offset,contexts, max_answer_length, n_best_size): 
  model.eval()
  imd_predictions,script_predictions = dict(), dict()
  with torch.no_grad():
    #F1 = EM = Total = 0
    total_loss = 0.0
    total_predictions = dict()
    no_probs_pred = dict()
    #loop = tqdm(data_loader)
    loop = tqdm(data_loader, leave=True)
    for batch_idx, batch in enumerate(loop):
      tokens = batch['input_ids'].to(device)
      masks = batch['attention_mask'].to(device)
      tokens_type = batch['token_type_ids'].to(device)
      gt_start = batch['start_positions'].to(device)
      gt_end = batch['end_positions'].to(device)
      IDs = batch['IDs'].to(device)
      outputs = model(tokens, masks, tokens_type, start_positions=gt_start, end_positions=gt_end)
      #calculating loss
      loss = outputs.loss
      #update average total loss 
      total_loss = total_loss + ((1 / (batch_idx + 1)) * (loss.item() - total_loss)) 
      #calculating f1 score and EM
      curr_batch_size = tokens.shape[0]
      post_raw_preds(IDs, outputs.start_logits, outputs.end_logits, ids_to_index, offset, contexts,max_answer_length, n_best_size, imd_predictions, script_predictions )
    #saving evaluation results
    #evaluation

    model.train()
    return imd_predictions,script_predictions

In [38]:
def post_raw_preds(IDs, total_start_logits, total_end_logits,ids_to_index,offset,contexts, max_answer_length, n_best_size,
 imd_predictions,script_predictions ):
    total_start_logits = total_start_logits.cpu().numpy()
    total_end_logits = total_end_logits.cpu().numpy()
    IDs = IDs.cpu().numpy()
    for i in range(IDs.shape[0]):
        offset_mapping = offset[ids_to_index[IDs[i].squeeze()]]
        # The first feature comes from the first example. For the more general case, we will need to be match the example_id to
        # an example index
        context = contexts[ids_to_index[IDs[i].squeeze()]]
        start_logits = total_start_logits[i]
        end_logits = total_end_logits[i]
        # Gather the indices the best start/end logits:
        start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
        end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
        valid_answers = []
        for start_index in start_indexes:
            for end_index in end_indexes:
                # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                # to part of the input_ids that are not in the context.
                if (
                    start_index >= len(offset_mapping)
                    or end_index >= len(offset_mapping)
                    or offset_mapping[start_index] is None
                    or offset_mapping[end_index] is None
                ):
                    continue
                # Don't consider answers with a length that is either < 0 or > max_answer_length.
                if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                    continue
                if start_index <= end_index: # We need to refine that test to check the answer is inside the context
                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char: end_char]
                        }
                    )
        if len(valid_answers) ==0:
            valid_answers.append({"text":"", "score":""})

        valid_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
        imd_predictions[str(IDs[i].squeeze())] = valid_answer
        script_predictions[str(IDs[i].squeeze())] = valid_answer['text']

In [47]:
def get_preds(total_preds, no_probs_preds,data_path, log_path):
    preds_path = os.path.join(log_path, 'preds')
    if not os.path.exists(preds_path):
        os.mkdir(preds_path)
    no_probs_path = os.path.join(preds_path, 'na_probs.json')
    text_preds_path = os.path.join(preds_path, 'preds.json')
    jsonString = json.dumps(total_preds)
    jsonFile = open(text_preds_path, "w")
    jsonFile.write(jsonString)
    jsonFile.close()
    if no_probs_preds is not None:
        jsonString = json.dumps(no_probs_preds)
        jsonFile = open(no_probs_path, "w")
        jsonFile.write(jsonString)
        jsonFile.close()
        #!python evaluatev2.py data_path text_preds_path electra --na-prob-file no_probs_path --na-prob-thresh 0.4 --out-file log_path
        #os.system(f"python evaluatev2.py {data_path} {text_preds_path} electra --na-prob-file {no_probs_path} --na-prob-thresh 0.5 --out-file {log_path}")
        !/anaconda/envs/azureml_py38/bin/python3 evaluatev2.py Data/asquadv2-val.json Runs/AraElectraDecoupledAsquadv2/train/span/third/preds/preds.json electra --na-prob-file Runs/AraElectraDecoupledAsquadv2/train/span/second/preds/na_probs.json  --na-prob-thresh 0.5  
    else:
        !/anaconda/envs/azureml_py38/bin/python3 evaluatev2.py Data/asquadv2-val.json Runs/AraElectraDecoupledAsquadv2/train/span/third/preds/preds.json electra  --out-file Runs/AraElectraDecoupledAsquadv2/train/span/third
    if log_path:
        with open(os.path.join(log_path, 'res.csv')) as f:
            DictReader_obj = csv.DictReader(f)
            lastrow = None
            for item in DictReader_obj:
                lastrow = dict(item)
        #print(lastrow)
        return lastrow
    return 1


In [44]:
def span_train(model,start_epoch, num_epochs, optimizer,max_compined_metric, train_loader, val_loader, log, exp_name):
  curr_ckp_path, best_ckp_path, exp_path = order_exp('Runs/AraElectraDecoupledAsquadv2/train/span', exp_name)
  model.train()
  for epoch in range(start_epoch,num_epochs):
    total_loss = 0.0
    loop = tqdm(train_loader, leave=True)
    for batch_idx, batch in enumerate(loop):
      tokens = batch['input_ids'].to(device)
      masks = batch['attention_mask'].to(device)
      tokens_type = batch['token_type_ids'].to(device)
      gt_start = batch['start_positions'].to(device)
      gt_end = batch['end_positions'].to(device)
      outputs = model(tokens, masks, tokens_type, start_positions=gt_start, end_positions=gt_end)
      loss = outputs.loss
      loss = 2*loss
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()
      total_loss = total_loss + ((1 / (batch_idx + 1)) * (loss.item() - total_loss)) 
      loop.set_description(f'Epoch {epoch}')
      loop.set_postfix(loss=loss.item())
    
    imd_preds, script_preds = get_raw_preds(val_loader, model,val_ids_to_idx,val_offset,aqad_val_contexts, 30, 10)
    result_dict = get_preds(script_preds, None,'Data/asquadv2-val.json',exp_path )
    checkpoint = {
            'epoch': epoch + 1,
            'result_dict':result_dict,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
    curr_compined_metric = float(result_dict['HasAns_exact'])+1.5*float(result_dict['HasAns_f1'])
    if curr_compined_metric>=max_compined_metric:
      max_compined_metric = curr_compined_metric
      save_ckp(checkpoint, True, curr_ckp_path, best_ckp_path)
    else:
      save_ckp(checkpoint, False, curr_ckp_path, best_ckp_path)
    print("ckp saved")
  return model


In [43]:
QA_AraElectra = ElectraForQuestionAnswering.from_pretrained(model_name)
freeze(QA_AraElectra.electra, 6)
span_num_epochs = 4
span_learning_rate = 3e-5
span_optimizer = torch.optim.AdamW(QA_AraElectra.parameters(), lr=span_learning_rate)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
#criterion_span = nn.CrossEntropyLoss(reduction='none')
#cls_criterion = nn.CrossEntropyLoss()
QA_AraElectra.to(device)

134602752 42527232


ElectraForQuestionAnswering(
  (electra): ElectraModel(
    (embeddings): ElectraEmbeddings(
      (word_embeddings): Embedding(64000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ElectraEncoder(
      (layer): ModuleList(
        (0): ElectraLayer(
          (attention): ElectraAttention(
            (self): ElectraSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ElectraSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768

In [45]:
span_trained_model = span_train(QA_AraElectra,0, 4, span_optimizer,0.0, span_train_loader, val_loader, True, 'third')

Epoch 0: 100%|██████████| 5256/5256 [46:06<00:00,  1.90it/s, loss=3.37] 
100%|██████████| 1201/1201 [05:29<00:00,  3.65it/s]
Epoch 1: 100%|██████████| 5256/5256 [46:20<00:00,  1.89it/s, loss=4.93] 
100%|██████████| 1201/1201 [05:28<00:00,  3.66it/s]
Epoch 2: 100%|██████████| 5256/5256 [46:15<00:00,  1.89it/s, loss=1.36]  
100%|██████████| 1201/1201 [05:29<00:00,  3.65it/s]
Epoch 3: 100%|██████████| 5256/5256 [46:20<00:00,  1.89it/s, loss=0.248] 
100%|██████████| 1201/1201 [05:27<00:00,  3.67it/s]


{
  "exact": 31.80635085892764,
  "f1": 39.18377598596737,
  "total": 9605,
  "HasAns_exact": 58.11607992388202,
  "HasAns_f1": 71.60041262516015,
  "HasAns_total": 5255,
  "NoAns_exact": 0.022988505747126436,
  "NoAns_f1": 0.022988505747126436,
  "NoAns_total": 4350
}
ckp saved
{
  "exact": 32.4518479958355,
  "f1": 39.419988150956904,
  "total": 9605,
  "HasAns_exact": 59.27687916270219,
  "HasAns_f1": 72.01312772406109,
  "HasAns_total": 5255,
  "NoAns_exact": 0.04597701149425287,
  "NoAns_f1": 0.04597701149425287,
  "NoAns_total": 4350
}
ckp saved
{
  "exact": 32.160333159812595,
  "f1": 39.31507513222507,
  "total": 9605,
  "HasAns_exact": 58.76308277830638,
  "HasAns_f1": 71.84039898097464,
  "HasAns_total": 5255,
  "NoAns_exact": 0.022988505747126436,
  "NoAns_f1": 0.022988505747126436,
  "NoAns_total": 4350
}
ckp saved
{
  "exact": 31.70223841749089,
  "f1": 38.887389352235836,
  "total": 9605,
  "HasAns_exact": 57.906755470980016,
  "HasAns_f1": 71.03965265998576,
  "HasAns_to

In [50]:
!/anaconda/envs/azureml_py38/bin/python3 evaluatev2.py Data/asquadv2-val.json Runs/AraElectraDecoupledAsquadv2/train/span/third/preds/preds.json electra --na-prob-file Runs/AraElectraDecoupledAsquadv2/train/span/second/preds/na_probs.json  --na-prob-thresh 0.5 

{
  "exact": 63.87298282144716,
  "f1": 70.44332708661499,
  "total": 9605,
  "HasAns_exact": 53.6441484300666,
  "HasAns_f1": 65.65331240093887,
  "HasAns_total": 5255,
  "NoAns_exact": 76.22988505747126,
  "NoAns_f1": 76.22988505747126,
  "NoAns_total": 4350,
  "best_exact": 66.66319625195212,
  "best_exact_thresh": 0.08445756882429123,
  "best_f1": 72.30824856366091,
  "best_f1_thresh": 0.15894971787929535
}
