In [1]:
import os
import shutil
from collections import Counter
import numpy as np
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, ElectraForQuestionAnswering, DataCollatorWithPadding, ElectraModel
from Preprocess.arabertpreprocess import ArabertPreprocessor
import matplotlib.pyplot as plt
import seaborn as sns
import csv
torch.manual_seed(3407)

<torch._C.Generator at 0x7f647ce5dd68>

## Preprocessing

In [2]:
def add_end_index(answer, context):
  ## 1 if span match the context 0 otherwise
  text = answer['text']
  start_idx = answer['answer_start']
  end_idx = start_idx + len(text)
  answer['answer_end'] = end_idx
  if text == context[start_idx:end_idx]:
    answer['answer_end'] = end_idx
    return False
  for i in range(1,3):
    if text == context[start_idx-i:end_idx-i]:
      answer['answer_end']= end_idx-1
      answer['answer_start'] = start_idx-1
      return False
  return True

In [3]:
def arabert_preprocess(context,question, answer, arabert_prep):
    answer['text'] = arabert_prep.preprocess(answer['text'])
    context = arabert_prep.preprocess(context)
    question = arabert_prep.preprocess(question)
    res = context.find(answer['text'])
    if res !=-1:
        answer['answer_start'] = res
    return context, question, answer, res

In [4]:
def Read_AAQAD(path,arabert_prep):
  contexts =[]
  answers =[]
  questions =[]
  IDs= []
  plausible = []
  cnt = 0
  with open(path) as f:
    aaqad_dict = json.load(f)
    for article in aaqad_dict['data']:
      for passage in article['paragraphs']:
        context = passage['context']
        for qa in passage['qas']:
          question = qa['question']
          if 'plausible_answers' in qa.keys():# there is two cases if the question have no answer then use plausible answer
            access = 'plausible_answers'
            plausible.append(True)
          else:
            access = 'answers'
            plausible.append(False)
          for answer in qa[access]:
            context,question, answer, res =  arabert_preprocess(context,question, answer, arabert_prep)
            #if res==-1:
            #  cnt+=1
            #  continue
            flag = add_end_index(answer, context) #if false dont add the 
            cnt =cnt + flag
            flag = False
            if not flag:
              contexts.append(context)
              answers.append(answer)
              questions.append(question)
              IDs.append(int(qa['id']))
  return contexts,questions,answers,plausible,IDs
            

In [5]:
def fix_ids(path):
  #IDs need to be fixed for evaluating purposes
    a_file = open(path, "r")
    json_object = json.load(a_file)
    a_file.close()
    for article in json_object['data']:
      for passage in article['paragraphs']:
        context = passage['context']
        for qa in passage['qas']:
            qa['id'] = str(qa['id'])
    a_file = open(path, "w")
    json.dump(json_object, a_file)
    a_file.close()

In [6]:
model_name = "araelectra-base-discriminator"
arabert_prep = ArabertPreprocessor(model_name=model_name)
fix_ids('Data/AAQAD-train.json')
fix_ids('Data/AAQAD-dev.json')
fix_ids('Data/AAQAD-test.json')
aqad_train_contexts, aqad_train_questions, aqad_train_answers,aqad_train_plausible, aqad_train_ids = Read_AAQAD('Data/AAQAD-train.json', arabert_prep)
aqad_val_contexts, aqad_val_questions, aqad_val_answers,aqad_val_plausible, aqad_val_ids = Read_AAQAD('Data/AAQAD-dev.json', arabert_prep)
aqad_test_contexts, aqad_test_questions, aqad_test_answers,aqad_test_plausible, aqad_test_ids = Read_AAQAD('Data/AAQAD-test.json', arabert_prep)


In [7]:
print(type(aqad_train_ids[0]))
print(type(aqad_val_ids[0]))
print(type(aqad_test_ids[0]))

<class 'int'>
<class 'int'>
<class 'int'>


## Tokenization

In [8]:
#Creating the tokenizer
model_name = model_name = "aubmindlab/araelectra-base-discriminator"

araelectra_tokenizer = AutoTokenizer.from_pretrained(model_name,do_lower_case=False)
aqad_train_encodings = araelectra_tokenizer(aqad_train_questions, aqad_train_contexts, truncation=True, return_tensors="pt", padding="max_length", max_length=512)
aqad_val_encodings = araelectra_tokenizer(aqad_val_questions, aqad_val_contexts, truncation=True, return_tensors="pt", padding="max_length", max_length=512 )
aqad_test_encodings = araelectra_tokenizer(aqad_test_questions, aqad_test_contexts,truncation= True, return_tensors="pt", padding="max_length", max_length=512)



In [9]:
def index_to_token_position(encodings , answers):
  start_positions = list()
  end_positions = list()
  for i in range(len(answers)):
    start_positions.append(encodings.char_to_token(i, answers[i]['answer_start'], 1))
    end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'], 1))
    #if context truncated
    if start_positions[-1] is None: 
      start_positions[-1] = araelectra_tokenizer.model_max_length
    #if end index is space
    itt = 1
    while end_positions[-1] is None: 
      end_positions[-1] = encodings.char_to_token(i, answers[i]['answer_end']-itt, 1)
      itt = itt + 1 
  encodings.update({'start_positions': torch.tensor(start_positions), 'end_positions': torch.tensor(end_positions)})
  encodings['start_positions'] = encodings['start_positions'].view(len(answers), 1)
  encodings['end_positions'] = encodings['end_positions'].view(len(answers), 1)

In [10]:
index_to_token_position(aqad_train_encodings, aqad_train_answers)
index_to_token_position(aqad_val_encodings, aqad_val_answers)
index_to_token_position(aqad_test_encodings, aqad_test_answers)

In [11]:
def add_weights_labels_tensors(encodings, plausible):
  plausible = torch.tensor(plausible)
  weights = torch.zeros(plausible.shape)
  no_ans = torch.ones(plausible.shape)
  weights[plausible==False]=1.0
  no_ans[plausible==False]=0.0
  weights = weights.view(-1,1)
  no_ans = no_ans.view(-1,1)
  encodings.update({'weights':weights, 'no_ans':no_ans})

In [12]:
add_weights_labels_tensors(aqad_train_encodings, aqad_train_plausible)
add_weights_labels_tensors(aqad_val_encodings, aqad_val_plausible)
add_weights_labels_tensors(aqad_test_encodings, aqad_test_plausible)

In [13]:
aqad_train_encodings['IDs'] = aqad_train_ids
aqad_val_encodings['IDs'] = aqad_val_ids
aqad_test_encodings['IDs'] = aqad_test_ids

In [14]:
aqad_train_encodings.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions', 'weights', 'no_ans', 'IDs'])

## Dataset and DataLoader

In [15]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

In [16]:
class AqadDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

train_dataset = AqadDataset(aqad_train_encodings)
val_dataset = AqadDataset(aqad_val_encodings)
test_dataset = AqadDataset(aqad_test_encodings)

In [17]:
data_collator = DataCollatorWithPadding(araelectra_tokenizer)

In [18]:
train_loader = DataLoader(train_dataset, batch_size=8, shuffle= True)
val_loader = DataLoader(val_dataset, batch_size = 8, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = 8, shuffle = True)

## Checkpoint Saving and Loading

In [19]:
def save_ckp(state, is_best, checkpoint_path, best_model_path):
    """
    state: checkpoint to save
    is_best: is this the best checkpoint; min validation loss
    checkpoint_path: path to save checkpoint
    best_model_path: path to save best checkpoint
    """
    f_path = checkpoint_path
    # save checkpoint data to the path given, checkpoint_path
    torch.save(state, f_path)
    # if it is a best model, min validation loss
    if is_best:
        best_fpath = best_model_path
        # copy that checkpoint file to best path given, best_model_path
        shutil.copyfile(f_path, best_fpath)

In [20]:
def load_ckp(checkpoint_fpath, model, optimizer):
    """
    checkpoint_path: path to saved checkpoint
    model: model to load checkpoint parameters into       
    optimizer: optimizer defined in previous training
    """
    # load check point
    checkpoint = torch.load(checkpoint_fpath)
    # initialize state_dict from checkpoint to model
    model.load_state_dict(checkpoint['state_dict'])
    # initialize optimizer from checkpoint to optimizer
    optimizer.load_state_dict(checkpoint['optimizer'])
    # initialize valid_loss_min from checkpoint to valid_loss_min
    results = checkpoint['result_dict']
    # return model, optimizer, epoch value, min validation loss 
    return model, optimizer, checkpoint['epoch'], results

In [21]:
def order_exp(base_path, exp_name):
  exp_path = os.path.join(base_path, exp_name)
  if not os.path.exists(exp_path):
    os.mkdir(exp_path)
  curr_ckp_path = os.path.join(exp_path,'curr.pt')
  best_ckp_path = os.path.join(exp_path, 'best.pt')
  return curr_ckp_path, best_ckp_path, exp_path

## Evaluation Script (SQuAD v2)

## Modeling

In [22]:
class AraElectraAbstain(nn.Module):
  def __init__(self, Electra):
    super().__init__()
    self.base = Electra
    #self.Weights = Weights
    self.fcn = nn.Linear(768, 2)
    self.abstainProjection = nn.Linear(768,1)
    self.fcn2 = nn.Linear(512,1)
    self.relu = nn.ReLU()
    self.sigmoid = nn.Sigmoid()
  def forward(self, tokens, token_type, mask):
    output = self.base(input_ids = tokens, attention_mask = mask, token_type_ids = token_type)
    electra_encoding = output.last_hidden_state
    span = self.fcn(electra_encoding)
    na_probs_logits = self.relu(self.abstainProjection(electra_encoding).view(-1,512))
    na_probs_logits = self.fcn2(na_probs_logits)
    na_probs = self.sigmoid(na_probs_logits) 
    return span,na_probs

In [23]:
def freeze(Electra, count=None):
    if count is not None:
	      # We freeze here the embeddings of the model
        #for param in Electra.embeddings.parameters():
        #    param.requires_grad = False

        if count != -1:
	          # if freeze_layer_count == -1, we only freeze the embedding layer
	          # otherwise we freeze the first `freeze_layer_count` encoder layers
            for layer in Electra.encoder.layer[:count]:
                for param in layer.parameters():
                    param.requires_grad = False
    print(sum(p.numel() for p in Electra.parameters()), sum(p.numel() for p in Electra.parameters() if p.requires_grad))

In [24]:
Electra = ElectraModel.from_pretrained(model_name)
QA_AraElectra = AraElectraAbstain(Electra)


In [25]:
freeze(QA_AraElectra.base, 4)

134602752 106251264


## Training and Evaluation

In [26]:
num_epochs = 2
learning_rate = 5e-5
optimizer = torch.optim.AdamW(QA_AraElectra.parameters(), lr=learning_rate)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
criterion_span = nn.CrossEntropyLoss(reduction='none')
criterion_abstain = nn.BCELoss()
QA_AraElectra.to(device)

AraElectraAbstain(
  (base): ElectraModel(
    (embeddings): ElectraEmbeddings(
      (word_embeddings): Embedding(64000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ElectraEncoder(
      (layer): ModuleList(
        (0): ElectraLayer(
          (attention): ElectraAttention(
            (self): ElectraSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ElectraSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12

In [27]:
def get_raw_preds(data_loader, model): 
  model.eval()
  with torch.no_grad():
    #F1 = EM = Total = 0
    total_loss = 0.0
    total_predictions = dict()
    no_probs_pred = dict()
    #loop = tqdm(data_loader)
    #loop = tqdm(data_loader, leave=True)
    for batch_idx, batch in enumerate(data_loader):
      #moving tensors to gpu    
      tokens = batch['input_ids'].to(device)
      masks = batch['attention_mask'].to(device)
      tokens_type = batch['token_type_ids'].to(device)
      gt_start = batch['start_positions'].to(device)
      gt_end = batch['end_positions'].to(device)
      weights = batch['weights'].view(-1,1).to(device)
      gt_no_ans = batch['no_ans'].to(device)
      IDs = batch['IDs'].to(device)
      span, no_probs = model(tokens, masks, tokens_type)
      #calculating loss
      span_loss = criterion_span(span, torch.concat([gt_start,gt_end], dim=1))
      span_loss = torch.sum(torch.mean(span_loss*weights, dim=0))
      no_ans_loss = criterion_abstain(no_probs, gt_no_ans)
      loss = no_ans_loss + span_loss
      #update average total loss 
      total_loss = total_loss + ((1 / (batch_idx + 1)) * (loss.item() - total_loss)) 
      #calculating f1 score and EM
      curr_batch_size = tokens.shape[0]
      #print(curr_batch_size)
      for i in range(curr_batch_size):
        #print(f"this is tensor index {i}")
        span_pred= torch.argmax(span[i], dim=0)
        start_pred, end_pred = span_pred[0], span_pred[1]
        #print(start_pred.shape, end_pred.shape)
        #print(start_pred, end_pred)
        total_predictions[str(IDs[i].item())] = araelectra_tokenizer.decode(tokens[i][start_pred.item():end_pred.item()], skip_special_tokens=True, clean_up_tokenization_spaces=True)
        no_probs_pred[str(IDs[i].item())] = no_probs[i].item()
    #saving evaluation results
    #evaluation

    model.train()
    return total_predictions, no_probs_pred

In [28]:
def get_preds(total_preds, no_probs_preds,data_path, log_path):
    preds_path = os.path.join(log_path, 'preds')
    if not os.path.exists(preds_path):
        os.mkdir(preds_path)
    no_probs_path = os.path.join(preds_path, 'na_probs.json')
    text_preds_path = os.path.join(preds_path, 'preds.json')
    jsonString = json.dumps(total_preds)
    jsonFile = open(text_preds_path, "w")
    jsonFile.write(jsonString)
    jsonFile.close()
    jsonString = json.dumps(no_probs_preds)
    jsonFile = open(no_probs_path, "w")
    jsonFile.write(jsonString)
    jsonFile.close()
    #!python evaluatev2.py data_path text_preds_path electra --na-prob-file no_probs_path --na-prob-thresh 0.4 --out-file log_path
    os.system(f"python evaluatev2.py {data_path} {text_preds_path} electra --na-prob-file {no_probs_path} --na-prob-thresh 0.5 --out-file {log_path}")
    with open(os.path.join(log_path, 'res.csv')) as f:
        DictReader_obj = csv.DictReader(f)
        lastrow = None
        for item in DictReader_obj:
            lastrow = dict(item)
    print(lastrow)
    return lastrow


In [29]:
def train(model,start_epoch, num_epochs, optimizer,max_compined_metric, train_loader, val_loader, log, exp_name):
  curr_ckp_path, best_ckp_path, exp_path = order_exp('Runs/AraElectra_abstain/train', exp_name)
  model.train()
  for epoch in range(start_epoch,num_epochs):
    total_loss = 0.0
    loop = tqdm(train_loader, leave=True)
    for batch_idx, batch in enumerate(loop):
      tokens = batch['input_ids'].to(device)
      masks = batch['attention_mask'].to(device)
      tokens_type = batch['token_type_ids'].to(device)
      gt_start = batch['start_positions'].to(device)
      gt_end = batch['end_positions'].to(device)
      weights = batch['weights'].view(-1,1).to(device)
      gt_no_ans = batch['no_ans'].to(device)
      span,no_probs = model(tokens, masks, tokens_type)
      #print(gt_start.shape)
      #print(span.shape)
      #print(weights.shape)
      span_loss = criterion_span(span, torch.concat([gt_start.view(-1,1),gt_end.view(-1,1)], dim=1))
      #print(span_loss.shape)
      span_loss = torch.sum(torch.mean(span_loss*weights, dim=0))
      #print(no_probs.shape)
      #print(gt_no_ans.shape)
    
      no_ans_loss = criterion_abstain(no_probs, gt_no_ans)
      loss = no_ans_loss + span_loss
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()
      total_loss = total_loss + ((1 / (batch_idx + 1)) * (loss.item() - total_loss)) 
      loop.set_description(f'Epoch {epoch}')
      loop.set_postfix(loss=loss.item())

    total_preds, no_probs_preds = get_raw_preds(val_loader, model)
    result_dict = get_preds(total_preds, no_probs_preds,'Data/AAQAD-dev.json',exp_path )
    checkpoint = {
            'epoch': epoch + 1,
            'result_dict':result_dict,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
    curr_compined_metric = float(result_dict['exact'])+1.5*float(result_dict['f1'])
    if curr_compined_metric>=max_compined_metric:
      max_compined_metric = curr_compined_metric
      save_ckp(checkpoint, True, curr_ckp_path, best_ckp_path)
    else:
      save_ckp(checkpoint, False, curr_ckp_path, best_ckp_path)
    print(result_dict)
  return model


In [30]:
trained_model = train(QA_AraElectra,0, num_epochs, optimizer,0.0, train_loader, val_loader, True, 'first')

Epoch 0: 100%|██████████| 1579/1579 [36:26<00:00,  1.39s/it, loss=6.34] 


In [34]:
with open('Runs/AraElectra_abstain/train/first/res.csv') as f:
    DictReader_obj = csv.DictReader(f)
    lastrow = None
    for item in DictReader_obj:
        lastrow = dict(item)
    print(json.dumps(lastrow, indent=2))

{
  "exact": "15.680166147455868",
  "f1": "19.742277127887526",
  "total": "1926",
  "HasAns_exact": "1.0058675607711651",
  "HasAns_f1": "7.563810350638166",
  "HasAns_total": "1193",
  "NoAns_exact": "39.56343792633015",
  "NoAns_f1": "39.56343792633015",
  "NoAns_total": "733",
  "best_exact": "38.05815160955348",
  "best_exact_thresh": "0.0",
  "best_f1": "38.09578704438517",
  "best_f1_thresh": "0.0007945864927023649"
}
