## Acronym Disambiguation using BERT

In [None]:
! pip install transformers -q
! pip install tokenizers -q

In [None]:
import re
import os
import sys
import json
import ast
import pandas as pd
from pathlib import Path
import matplotlib.cm as cm
import numpy as np
import pandas as pd
from typing import *
from tqdm.notebook import tqdm
from sklearn.utils.extmath import softmax
from sklearn import model_selection
from sklearn.metrics import classification_report, f1_score

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
import transformers
from transformers import AdamW
import tokenizers

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# Project root relative to this notebook; the Colab location is kept for reference:
# ROOT = './drive/MyDrive/IRE Major Project/'
ROOT = '../'
DATAFOLDER = ROOT + 'data/'
OUTPUTFOLDER = ROOT + 'outputs/'
CONFIGFOLDER = ROOT + 'models/model configs/'

In [None]:
def seed_all(seed=42):
  """
  Seed Python's, NumPy's and PyTorch's RNGs (plus every CUDA device when one is
  available) and force deterministic cuDNN kernels, for reproducible runs.
  """
  import random
  import numpy as np
  import torch

  for seeder in (random.seed, torch.manual_seed, np.random.seed):
    seeder(seed)
  torch.backends.cudnn.deterministic = True
  if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

## Data preprocessing: merging near-duplicate expansions

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import string
import json

In [None]:
! pip install python-Levenshtein

In [None]:
# LOAD DATA: acronym -> list of raw expansions for the scientific domain.
# Resolved through DATAFOLDER like every other data file in this notebook
# (the same file is read again via DATAFOLDER in the "Update diction" step);
# the bare "diction.json" only worked when it happened to sit in the CWD.
with open(DATAFOLDER + "scientific/diction.json") as f:
  input_dict = json.load(f)

## Levenshtein Distance

In [None]:
from Levenshtein import distance as levenshtein_distance

# Relative edit distance (distance / length of the anchor expansion) at or below
# which another expansion is merged into the anchor's cluster.
THRESHOLD = 0.4

output_dict = {}
for key, val in input_dict.items():
  # Deduplicate, lowercase and sort this acronym's raw expansions.
  new_dat = sorted(a.lower() for a in set(val))
  # clean_acronyms: punctuation removed, spaces kept (stored in the output).
  # new_dat: punctuation and spaces removed (used for distance comparisons).
  clean_acronyms = [s.translate(str.maketrans(' ', ' ', string.punctuation)) for s in new_dat]
  new_dat = [s.translate(str.maketrans('', '', string.punctuation)).replace(" ", "") for s in new_dat]

  # Pairwise Levenshtein distances. The distance is symmetric, so each pair is
  # computed once and mirrored.
  n = len(new_dat)
  mat = np.zeros(shape=(n, n))
  for i in range(n):
    for j in range(i, n):
      mat[i][j] = mat[j][i] = levenshtein_distance(new_dat[i], new_dat[j])

  # Heatmap of matrix
  plt.imshow(mat, cmap='hot', interpolation='nearest')
  plt.colorbar()
  plt.show()

  # Greedy clustering: each not-yet-visited expansion becomes an anchor and absorbs
  # every unvisited expansion (itself included) within THRESHOLD of it.
  acdict = {}
  visited = np.zeros(shape=(n, 1))
  for row, anchor in enumerate(new_dat):   # `row`, not `id`: don't rebind the builtin
    if visited[row]:
      continue
    cluster = acdict[clean_acronyms[row]] = {}
    # An expansion made only of punctuation/spaces compacts to ""; dividing by its
    # zero length made even its self-distance NaN, leaving an empty cluster.
    anchor_len = max(len(anchor), 1)
    for idx, dist in enumerate(mat[row]):
      if not visited[idx] and float(dist / anchor_len) <= THRESHOLD:
        visited[idx] = 1
        cluster.setdefault(anchor, []).append(clean_acronyms[idx])
    visited[row] = 1
  output_dict[key] = acdict

In [None]:
import pprint

# Show the clusters (acronym -> anchor display form -> {compact key: members}) for review.
pp = pprint.PrettyPrinter(indent=3)
pp.pprint(object=output_dict)

In [None]:
import json

# Persist the automatically built clusters (presumably hand-corrected afterwards into
# the "fix_combined_..." file loaded in the next section).
combined_path = DATAFOLDER + "scientific/combined_acronym_dict_sci.json"
with open(combined_path, "w") as f:
  json.dump(output_dict, f, indent=4)

### Update dataset with new expansions



In [None]:
# Cluster dictionary used by `func` to canonicalise expansions.
data = json.loads(Path(DATAFOLDER + "scientific/fix_combined_acronym_dict_sci.json").read_text())

In [None]:
# Raw training split whose `expansion` column is canonicalised below.
STORE_FILE = DATAFOLDER + "scientific/train.csv"
df = pd.read_csv(filepath_or_buffer=STORE_FILE)

In [None]:
def func(ser, isseries = False):
  """Map an (acronym, expansion) pair to the canonical expansion of its cluster.

  The expansion is lowercased and stripped of punctuation (spaces kept). It is then
  searched for among the member lists of the acronym's clusters in the global
  ``data`` dict. On a match the owning cluster's key is returned; otherwise the
  normalised expansion itself is returned.

  Args:
    ser: a 2-item sequence ``[acronym, expansion]``, or a DataFrame row whose first
      two columns are acronym and expansion. Access is positional, so this assumes
      that column order in the CSV (TODO confirm against train.csv).
    isseries: kept only for backward compatibility. The two former branches were
      identical, so the flag does not affect the result.

  Returns:
    The canonical (cluster-key) expansion, or the normalised input expansion.
  """
  if isinstance(ser, pd.Series):
    # ser[0] on a label-indexed Series relies on pandas' deprecated positional fallback.
    acr, val = ser.iloc[0], ser.iloc[1]
  else:
    acr, val = ser[0], ser[1]
  val = val.lower().translate(str.maketrans(' ', ' ', string.punctuation))
  if acr not in data:
    return val
  for canonical, groups in data[acr].items():
    for members in groups.values():
      if val in members:
        return canonical
  return val

In [None]:
# Replace every row's expansion by its canonical cluster key (func(row, True)).
df['expansion'] = df.apply(func, axis=1, args=(True,))

In [None]:
# Write next to the source file as "update_<name>"; for train.csv this is
# DATAFOLDER + 'scientific/update_train.csv', the file config.TRAIN_FILE reads.
# Prefixing the whole relative path ("update_" + STORE_FILE) produced
# "update_../data/scientific/train.csv" instead.
_store_path = Path(STORE_FILE)
df.to_csv(_store_path.with_name("update_" + _store_path.name))

In [None]:
# Acronym -> list of raw candidate expansions.
diction = json.loads(Path(DATAFOLDER + "scientific/diction.json").read_text())

In [None]:
# Canonicalise every candidate expansion with the same mapping applied to the
# training rows, mutating each list in place, then save the result.
for acr_key, expansions in diction.items():
  expansions[:] = [func([acr_key, exp], True) for exp in expansions]

with open(DATAFOLDER + "scientific/diction_update.json", "w") as f:
  json.dump(diction, f)

### Model Configurations


In [None]:
class config:
  """Hyper-parameters, file paths, tokenizer and label tables shared by the notebook."""
  SEED = 42
  KFOLD = 5
  TRAIN_FILE = DATAFOLDER + 'scientific/update_train.csv'
  VAL_FILE = DATAFOLDER + 'scientific/update_dev.csv'
  TEST_FILE = DATAFOLDER + 'scientific/test.csv'
  SAVE_DIR = OUTPUTFOLDER
  MAX_LEN = 192  # sequence length process_data pads/truncates to
  MODEL = 'allenai/scibert_scivocab_uncased'
  CONFIG = CONFIGFOLDER + 'finetune_scibert_config.json'
  TOKENIZER = tokenizers.BertWordPieceTokenizer(CONFIGFOLDER+"finetune_scibert_vocab.txt", lowercase=True)
  EPOCHS = 10
  TRAIN_BATCH_SIZE = 32
  VALID_BATCH_SIZE = 32
  TEST_BATCH_SIZE = 32
  # Acronym -> list of candidate expansions. Read via pathlib so no file handle is left open.
  DICTIONARY = json.loads(Path(DATAFOLDER + 'scientific/diction_update.json').read_text())

  # Expansion string -> integer label for macro-F1. An expansion listed under several
  # acronyms keeps its first id. The previous loop re-assigned len(A2ID) to an existing
  # key, which also handed that id to the next new expansion and merged two labels.
  A2ID = {}
  for _candidates in DICTIONARY.values():
    for _w in _candidates:
      if _w not in A2ID:
        A2ID[_w] = len(A2ID)


In [None]:
class AverageMeter:
    """Running, sample-weighted average of a scalar (e.g. a per-batch loss)."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, running sum, count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh ``avg``."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

In [None]:
class EarlyStopping:
    """Stop training when a validation score stops improving; checkpoint each improvement.

    Scores are compared as "bigger is better" (negated internally when ``mode="min"``).
    An epoch counts as an improvement only if it beats the best score by at least
    ``delta``. After ``patience`` consecutive non-improving epochs, ``early_stop``
    becomes True. ``output`` keeps the ``res`` passed with the best epoch.
    """

    def __init__(self, patience=7, mode="max", delta=0.001):
        self.patience = patience
        self.counter = 0
        self.mode = mode
        self.best_score = None
        self.output = None
        self.early_stop = False
        self.delta = delta
        # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
        if self.mode == "min":
            self.val_score = np.inf
        else:
            self.val_score = -np.inf

    def __call__(self, epoch_score, res, model, model_path):
        """Record one epoch's score; save ``model`` to ``model_path`` when it improves."""
        if self.mode == "min":
            score = -1.0 * epoch_score
        else:
            score = np.copy(epoch_score)
        if self.best_score is None:
            self.best_score = score
            self.output = res
            self.save_checkpoint(epoch_score, model, model_path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print('EarlyStopping counter: {} out of {}'.format(self.counter, self.patience))
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.output = res
            self.save_checkpoint(epoch_score, model, model_path)
            self.counter = 0

    def save_checkpoint(self, epoch_score, model, model_path):
        """Save ``model.state_dict()`` to ``model_path`` unless the score is inf or NaN."""
        # np.isfinite rejects every NaN. The old `in [..., np.nan]` test only matched
        # the np.nan object itself, because NaN != NaN.
        if np.isfinite(epoch_score):
            print('Validation score improved ({} --> {}). Saving model!'.format(self.val_score, epoch_score))
            torch.save(model.state_dict(), model_path)
        self.val_score = epoch_score

In [None]:
def sample_text(text, acronym, max_len):
  """Crop ``text`` to at most ``max_len`` whitespace tokens around ``acronym``.

  The window starts ``max_len // 2`` tokens before the acronym's token and ends
  (exclusive) ``max_len // 2`` tokens after it, clipped to the text bounds.

  Args:
    text: whitespace-tokenised context.
    acronym: acronym to centre on. An exact token match is preferred; otherwise the
      first token containing it is used (e.g. "(CNN)"). If it occurs nowhere, the
      window starts at the beginning of the text instead of raising IndexError.
    max_len: maximum number of tokens kept.

  Returns:
    The cropped text, re-joined with single spaces.
  """
  tokens = text.split()
  try:
    idx = tokens.index(acronym)
  except ValueError:
    idx = next((i for i, tok in enumerate(tokens) if acronym in tok), 0)
  left_idx = max(0, idx - max_len // 2)
  right_idx = min(len(tokens), idx + max_len // 2)
  return ' '.join(tokens[left_idx:right_idx])

In [None]:
def process_data(text, acronym, expansion, tokenizer, max_len):
  """Encode one (context, acronym, expansion) example for span extraction.

  Input layout: ``[CLS] acronym cand_1 ... cand_k [SEP] context [SEP]``, padded or
  truncated to ``max_len``. The candidates are ``config.DICTIONARY[acronym]`` joined
  by single spaces. The target is the token span of ``expansion`` within that
  candidate string.

  Args:
    text: context. If it has more than 120 whitespace tokens, it is cropped around
      the acronym with sample_text.
    acronym: key into config.DICTIONARY.
    expansion: gold expansion; expected to equal one of the candidates.
    tokenizer: object whose ``encode(str)`` returns ``.ids``/``.offsets`` with one
      leading and one trailing special token (they are stripped and re-added here).
    max_len: output sequence length.

  Returns:
    dict with ``ids``, ``mask``, ``token_type``, ``offset`` (each of length max_len),
    ``start``/``end`` (target token indices, shifted by 1 for [CLS]), and ``text``
    (candidate string + context, which ``offset`` indexes into), ``expansion``,
    ``acronym``.
  """
  text = str(text)
  expansion = str(expansion)
  acronym = str(acronym)

  n_tokens = len(text.split())
  if n_tokens > 120:
    text = sample_text(text, acronym, 120)

  candidates = config.DICTIONARY[acronym]
  answers = acronym + ' ' + ' '.join(candidates)

  # Locate the expansion by its position in the candidate list. A raw
  # answers.find(expansion) returns the first occurrence, which may lie inside an
  # earlier, longer candidate (e.g. "network" inside "network design").
  if expansion in candidates:
    k = candidates.index(expansion)
    start = len(acronym) + 1 + sum(len(c) + 1 for c in candidates[:k])
  else:
    # NOTE(review): unchanged fallback. If the expansion is not found at all,
    # start == -1 and the mask below wraps to the last character.
    start = answers.find(expansion)
  end = start + len(expansion)

  char_mask = [0]*len(answers)
  for i in range(start, end):
    char_mask[i] = 1

  tok_answer = tokenizer.encode(answers)
  # Take [CLS]/[SEP] ids from the tokenizer's own output instead of hard-coding the
  # bert-base values 101/102, which need not match a custom vocabulary.
  cls_id, sep_id = tok_answer.ids[0], tok_answer.ids[-1]
  answer_ids = tok_answer.ids[1:-1]
  answer_offsets = tok_answer.offsets[1:-1]

  # Tokens whose character span overlaps the expansion form the target span.
  target_idx = [i for i, (off1, off2) in enumerate(answer_offsets) if sum(char_mask[off1:off2]) > 0]
  start = target_idx[0]
  end = target_idx[-1]

  text_ids = tokenizer.encode(text).ids[1:-1]

  token_ids = [cls_id] + answer_ids + [sep_id] + text_ids + [sep_id]
  # Only candidate tokens carry real offsets; context and special tokens get (0, 0).
  offsets = [(0, 0)] + answer_offsets + [(0, 0)]*(len(text_ids) + 2)
  mask = [1] * len(token_ids)
  token_type = [0]*(len(answer_ids) + 1) + [1]*(2 + len(text_ids))

  text = answers + text
  # Shift past the leading [CLS].
  start = start + 1
  end = end + 1

  padding = max_len - len(token_ids)
  if padding >= 0:
    token_ids = token_ids + ([0] * padding)
    token_type = token_type + [1] * padding
    mask = mask + ([0] * padding)
    offsets = offsets + ([(0, 0)] * padding)
  else:
    token_ids = token_ids[0:max_len]
    token_type = token_type[0:max_len]
    mask = mask[0:max_len]
    offsets = offsets[0:max_len]

  assert len(token_ids) == max_len
  assert len(mask) == max_len
  assert len(offsets) == max_len
  assert len(token_type) == max_len

  return {
          'ids': token_ids,
          'mask': mask,
          'token_type': token_type,
          'offset': offsets,
          'start': start,
          'end': end,
          'text': text,
          'expansion': expansion,
          'acronym': acronym,
        }

### Dataset Loader


In [None]:
class Dataset:
    """Map-style dataset that encodes (text, acronym, expansion) rows with process_data.

    Args:
        text, acronym, expansion: equally long indexable sequences (e.g. the
            ``.values`` of DataFrame columns). ``expansion`` may be None when no
            labels exist (test time). The acronym is then used as a placeholder
            expansion, as the test-prediction cell does by hand.
            process_data puts the acronym at the start of its candidate string,
            so the placeholder is always locatable.
    """
    def __init__(self, text, acronym, expansion=None):
        self.text = text
        self.acronym = acronym
        self.expansion = acronym if expansion is None else expansion
        self.tokenizer = config.TOKENIZER
        self.max_len = config.MAX_LEN

    def __len__(self):
        return len(self.text)

    def __getitem__(self, item):
        """Return the encoded example: tensors for model inputs, raw strings for evaluation."""
        data = process_data(
            self.text[item],
            self.acronym[item],
            self.expansion[item],
            self.tokenizer,
            self.max_len,
        )

        return {
            'ids': torch.tensor(data['ids'], dtype=torch.long),
            'mask': torch.tensor(data['mask'], dtype=torch.long),
            'token_type': torch.tensor(data['token_type'], dtype=torch.long),
            'offset': torch.tensor(data['offset'], dtype=torch.long),
            'start': torch.tensor(data['start'], dtype=torch.long),
            'end': torch.tensor(data['end'], dtype=torch.long),
            'text': data['text'],
            'expansion': data['expansion'],
            'acronym': data['acronym'],
        }

In [None]:
def get_loss(start, start_logits, end, end_logits):
  """Return the sum of the start-position and end-position cross-entropy losses."""
  criterion = nn.CrossEntropyLoss()
  return criterion(start_logits, start) + criterion(end_logits, end)

### BERT class 


In [None]:
class BertAD(nn.Module):
  """BERT encoder with a linear head that scores every token as span start / end."""
  def __init__(self):
    super(BertAD, self).__init__()
    self.model_config = transformers.BertConfig.from_pretrained(config.MODEL)
    self.bert = transformers.BertModel.from_pretrained(config.MODEL, config=config.CONFIG)
    # Size the head from the encoder rather than hard-coding 768.
    self.layer = nn.Linear(self.bert.config.hidden_size, 2)

  def forward(self, ids, mask, token_type, start=None, end=None):
    """Return ``(loss, start_logits, end_logits)``.

    ``loss`` comes from get_loss when both ``start`` and ``end`` targets are given,
    and is None otherwise, so inference may omit the targets. The logits are the
    two output channels of the head with the last dimension squeezed away.
    """
    output = self.bert(input_ids=ids,
                       attention_mask=mask,
                       token_type_ids=token_type)

    # output[0]: per-token hidden states (assumed (batch, seq, hidden)).
    logits = self.layer(output[0])
    start_logits, end_logits = logits.split(1, dim=-1)

    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    loss = None
    if start is not None and end is not None:
      loss = get_loss(start, start_logits, end, end_logits)

    return loss, start_logits, end_logits

In [None]:
def train_fn(data_loader, model, optimizer, device):
  """Run one training epoch over ``data_loader``, showing the running mean loss."""
  model.train()
  running_loss = AverageMeter()
  progress = tqdm(data_loader, total=len(data_loader))

  for batch in progress:
    inputs = {name: batch[name].to(device, dtype=torch.long)
              for name in ('ids', 'token_type', 'mask', 'start', 'end')}

    model.zero_grad()
    loss, _, _ = model(inputs['ids'], inputs['mask'], inputs['token_type'],
                       inputs['start'], inputs['end'])
    loss.backward()
    optimizer.step()

    # Weight by batch size so the average is per example.
    running_loss.update(loss.item(), inputs['ids'].size(0))
    progress.set_postfix(loss=running_loss.avg)


In [None]:
def jaccard(str1, str2):
    """Word-level Jaccard similarity of two strings (case-insensitive, whitespace-split).

    Two strings with no words at all are treated as identical (1.0); the
    unguarded formula divided by zero in that case.
    """
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    if not a and not b:
        return 1.0
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))

In [None]:
def evaluate_jaccard(text, selected_text, acronym, offsets, idx_start, idx_end):
  """Snap a predicted token span onto the closest candidate expansion.

  The characters covered by tokens ``idx_start..idx_end`` (inclusive, via
  ``offsets`` into ``text``) are joined, with a space inserted wherever the next
  token does not start right after the current one. The result is compared, with
  jaccard, to every candidate in ``config.DICTIONARY[acronym]``.
  ``selected_text`` is not used.

  Returns:
    (best Jaccard score between the span and a candidate, that candidate)
  """
  pieces = []
  for ix in range(idx_start, idx_end + 1):
    pieces.append(text[offsets[ix][0]: offsets[ix][1]])
    has_gap = (ix + 1) < len(offsets) and offsets[ix][1] < offsets[ix + 1][0]
    if has_gap:
      pieces.append(" ")
  span = "".join(pieces).strip()

  candidates = config.DICTIONARY[acronym]
  scores = [jaccard(cand.strip(), span) for cand in candidates]
  best = np.argmax(scores)
  return scores[best], candidates[best]

In [None]:
def eval_fn(data_loader, model, device):
  """Evaluate ``model`` on ``data_loader``.

  For each example, the argmax start/end tokens are mapped to a candidate
  expansion by evaluate_jaccard. The running loss and the mean of the scores
  evaluate_jaccard returns are shown in the progress bar.

  Returns:
    (macro F1 over config.A2ID label ids of gold vs predicted expansions,
     DataFrame with columns acronym / actual / prediction / text)
  """
  model.eval()
  losses = AverageMeter()
  jac = AverageMeter()
  records = {'acronym': [], 'actual': [], 'prediction': [], 'text': []}

  progress = tqdm(data_loader, total=len(data_loader))
  for batch in progress:
    tensors = {name: batch[name].to(device, dtype=torch.long)
               for name in ('ids', 'mask', 'token_type', 'start', 'end')}

    with torch.no_grad():
      loss, start_logits, end_logits = model(tensors['ids'], tensors['mask'], tensors['token_type'],
                                             tensors['start'], tensors['end'])

    start_prob = torch.softmax(start_logits, dim=1).detach().cpu().numpy()
    end_prob = torch.softmax(end_logits, dim=1).detach().cpu().numpy()

    batch_scores = []
    for px, sentence in enumerate(batch['text']):
      gold = batch['expansion'][px]
      score, predicted = evaluate_jaccard(sentence, gold, batch['acronym'][px], batch['offset'][px],
                                          np.argmax(start_prob[px, :]), np.argmax(end_prob[px, :]))
      batch_scores.append(score)
      records['prediction'].append(predicted)
      records['actual'].append(gold)
      records['text'].append(sentence)
      records['acronym'].append(batch['acronym'][px])

    jac.update(np.mean(batch_scores), len(batch_scores))
    losses.update(loss.item(), tensors['ids'].size(0))
    progress.set_postfix(loss=losses.avg, jaccard=jac.avg)

  predicted_ids = [config.A2ID[w] for w in records['prediction']]
  gold_ids = [config.A2ID[w] for w in records['actual']]
  f1 = f1_score(gold_ids, predicted_ids, average='macro')

  print('Average Jaccard : ', jac.avg)
  print('Macro F1 : ', f1)
  return f1, pd.DataFrame(records)

In [None]:
def run(df_train, df_val, fold):
  """Train a fresh BertAD on ``df_train`` with early stopping on ``df_val`` macro F1.

  The best checkpoint is written to ``config.SAVE_DIR/model_<fold>.bin``, or to
  ``model.bin`` when ``fold`` is None.

  Args:
    df_train, df_val: DataFrames with ``text``, ``acronym_`` and ``expansion`` columns.
    fold: fold index, used in log lines and in the checkpoint file name.

  Returns:
    (best validation score, result DataFrame from the best epoch) as kept by EarlyStopping.
  """
  train_dataset = Dataset(
        text=df_train.text.values,
        acronym=df_train.acronym_.values,
        expansion=df_train.expansion.values,
    )
  valid_dataset = Dataset(
        text=df_val.text.values,
        acronym=df_val.acronym_.values,
        expansion=df_val.expansion.values,
    )

  # Shuffle training batches each epoch. Without it the model saw rows in file order
  # every epoch. Validation order is irrelevant to the metrics.
  train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN_BATCH_SIZE,
        shuffle=True,
        num_workers=2,
    )
  valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.VALID_BATCH_SIZE,
        num_workers=2,
    )
  # (An unused test loader was removed here: it referenced an undefined `df_test`.)

  model = BertAD()
  # is_available must be *called*; the bare function object is always truthy.
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  model.to(device)

  lr = 2e-5
  # HF BERT names its LayerNorm weights 'LayerNorm.weight'. 'gamma'/'beta' are the
  # legacy TF names, kept in case a checkpoint still uses them.
  no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.weight']
  param_optimizer = list(model.named_parameters())
  # AdamW reads 'weight_decay'. The former key 'weight_decay_rate' was silently
  # ignored, so no parameter was actually decayed.
  optimizer_grouped_parameters = [
      {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
       'weight_decay': 0.01},
      {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
       'weight_decay': 0.0},
  ]
  # torch's AdamW replaces the deprecated transformers.AdamW. eps=1e-6 keeps the
  # latter's default.
  optimizer = optim.AdamW(optimizer_grouped_parameters, lr=lr, eps=1e-6)

  es = EarlyStopping(patience=2, mode="max")
  model_name = "model.bin" if fold is None else f"model_{fold}.bin"
  model_path = os.path.join(config.SAVE_DIR, model_name)

  print('Starting training....')
  for epoch in range(config.EPOCHS):
    train_fn(train_data_loader, model, optimizer, device)
    valid_score, res = eval_fn(valid_data_loader, model, device)
    print(f'Fold {fold} | Epoch :{epoch + 1} | Validation Score :{valid_score}')
    es(valid_score, res, model, model_path=model_path)
    if es.early_stop:
      break

  return es.best_score, es.output

In [None]:
def run_k_fold(fold_id):
  '''
    Perform k-fold cross-validation for a single held-out fold.

    The train and dev files are concatenated and split with a StratifiedKFold on
    ``acronym_`` (config.KFOLD folds, shuffled with config.SEED). Fold ``fold_id``
    is used for validation and the rest for training.

    Args:
      fold_id: index of the held-out fold, 0 <= fold_id < config.KFOLD.
    Returns:
      run()'s result: (best validation score, result DataFrame).
    Raises:
      ValueError: if fold_id is out of range; this would otherwise give an empty
        validation set.
  '''
  if not 0 <= fold_id < config.KFOLD:
    raise ValueError(f'fold_id must be in [0, {config.KFOLD}), got {fold_id}')
  seed_all()

  df_train = pd.read_csv(config.TRAIN_FILE)
  df_val = pd.read_csv(config.VAL_FILE)
  # Concatenate train and validation sets, then fill missing values in both halves
  # (previously only the train half was filled). The unused test-file read was dropped.
  train = pd.concat([df_train, df_val]).reset_index()
  train.fillna('NA', inplace=True)

  # Assign each row its fold id.
  kf = model_selection.StratifiedKFold(n_splits=config.KFOLD, shuffle=True, random_state=config.SEED)
  for fold, (_, val_idx) in enumerate(kf.split(X=train, y=train.acronym_.values)):
    train.loc[val_idx, 'kfold'] = fold

  print(f'################################################ Fold {fold_id} #################################################')
  df_train = train[train.kfold != fold_id]
  df_val = train[train.kfold == fold_id]

  return run(df_train, df_val, fold_id)
    

### Train and run the model

In [None]:
# Fold 0: best validation macro F1 and its per-example results.
f0, res0 = run_k_fold(fold_id=0)

In [None]:
# Fold 1: best validation macro F1 and its per-example results.
f1, res1 = run_k_fold(fold_id=1)

In [None]:
# Fold 2: best validation macro F1 and its per-example results.
f2, res2 = run_k_fold(fold_id=2)

In [None]:
# Fold 3: best validation macro F1 and its per-example results.
f3, res3 = run_k_fold(fold_id=3)

In [None]:
# Fold 4: best validation macro F1 and its per-example results.
f4, res4 = run_k_fold(fold_id=4)

In [None]:
# Per-fold best scores followed by their mean.
f = [f0, f1, f2, f3, f4]
print('\n'.join(f'Fold {i} : {fs}' for i, fs in enumerate(f)))
print(f'Avg. {np.mean(f)}')

In [None]:
# Inspect and persist the best-epoch validation results of fold 0.
print(res0)
res0.to_csv(path_or_buf=OUTPUTFOLDER + 'scientific/output_sci.csv')

## Test data prediction


In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _load_fold_model(fold):
  """Build a BertAD, load SAVE_DIR/model_<fold>.bin (as written by run()), move to device, eval mode."""
  model = BertAD()
  chkp = torch.load(os.path.join(config.SAVE_DIR, f'model_{fold}.bin'), map_location=device)
  # NOTE: 'position_ids' is a buffer rather than a learned weight. It may or may not
  # be part of state_dict() depending on the installed transformers version. Make
  # the checkpoint agree with the freshly built model so strict loading succeeds.
  key = 'bert.embeddings.position_ids'
  own_state = model.state_dict()
  if key in own_state:
    chkp[key] = own_state[key]
  else:
    chkp.pop(key, None)
  model.load_state_dict(chkp)
  model.to(device)
  model.eval()
  return model


# All five folds are read from the file names run() writes. Fold 0 previously
# pointed at a stale '../input/modeltest/' path.
model0, model1, model2, model3, model4 = [_load_fold_model(fold) for fold in range(5)]
print('Models loaded')

In [None]:
test = pd.read_csv(config.TEST_FILE)
# No gold labels at test time. The acronym stands in for the expansion so that
# process_data can build a (dummy) target span; it is not a real label.
test['expansion'] = test['acronym_']

test_dataset = Dataset(
        text=test.text.values,
        acronym=test.acronym_.values,
        expansion=test.expansion.values,
    )

# Use the dedicated test batch size. Keep row order (no shuffling) so predictions
# line up with the rows of `test`.
test_data_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=config.TEST_BATCH_SIZE,
      shuffle=False,
      num_workers=2,
)

In [None]:
jac = AverageMeter()

tk0 = tqdm(test_data_loader, total=len(test_data_loader))

pred_expansion_ = []
true_expansion_ = []
ensemble = [model0, model1, model2, model3, model4]

for bi, d in enumerate(tk0):
  inputs = [d[name].to(device, dtype=torch.long) for name in ('ids', 'mask', 'token_type', 'start', 'end')]

  with torch.no_grad():
    outputs = [member(*inputs) for member in ensemble]

  # Average the per-model softmax probabilities (summed in model order, then / 5.0).
  start_prob = sum(torch.softmax(s_logits, dim=1).detach().cpu().numpy() for _, s_logits, _ in outputs) / 5.0
  end_prob = sum(torch.softmax(e_logits, dim=1).detach().cpu().numpy() for _, _, e_logits in outputs) / 5.0

  # For a single model, use only outputs[0] instead of the average above.

  # NOTE: the "expansion" here is the acronym placeholder, so this Jaccard is only
  # the span-to-candidate similarity evaluate_jaccard reports, not an accuracy.
  jac_ = []
  for px, s in enumerate(d['text']):
    start_idx = np.argmax(start_prob[px, :])
    end_idx = np.argmax(end_prob[px, :])
    js, exp = evaluate_jaccard(s, d['expansion'][px], d['acronym'][px], d['offset'][px], start_idx, end_idx)
    jac_.append(js)
    pred_expansion_.append(exp)

  jac.update(np.mean(jac_), len(jac_))
  tk0.set_postfix(jaccard=jac.avg)

In [None]:
# The test DataLoader is not shuffled, so the prediction order matches the rows of `test`.
test['pred_expansion'] = list(pred_expansion_)

In [None]:
# One {"id", "prediction"} record per test row, in row order.
predictions = [{'id': row['id'], 'prediction': row['pred_expansion']} for _, row in test.iterrows()]

with open(os.path.join('.', 'pred.json'), 'w') as f:
  json.dump(predictions, f)

In [None]:
# Full test frame including the predicted expansions, without the index column.
test.to_csv(path_or_buf='test_preds.csv', index=False)

In [None]:
# Records with sentence, acronym, predicted label and string ID, one per test row.
predicti = [
    {'sentence': row['text'], 'acronym': row['acronym_'], 'label': row['pred_expansion'], 'ID': str(row['id'])}
    for _, row in test.iterrows()
]

with open(os.path.join('.', 'output.json'), 'w') as f:
  json.dump(predicti, f)