Running Optuna to find the best weights for ensembling the RoBERTa and XLNet predictions

In [1]:
import argparse
import datetime
import json
import numpy as np
import re
import string
import collections
import optuna

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load RoBERTa's validation predictions; `with` guarantees the file handle is closed.
with open('../../ensemble/roberta_val.json', encoding='utf-8') as f:
    roberta_pred = json.load(f)

In [3]:
# Load XLNet's validation predictions; `with` guarantees the file handle is closed.
with open('../../ensemble/xlnet_val.json', encoding='utf-8') as f2:
    xlnet_pred = json.load(f2)

In [4]:
def normalize_answer(s):
  """Normalize an answer string for SQuAD-style comparison.

  Applied in order: lowercase, strip ASCII punctuation, replace the
  standalone words "a"/"an"/"the" with a space, then collapse every run of
  whitespace into a single space (also trimming both ends).
  """
  punctuation = set(string.punctuation)
  text = s.lower()
  text = ''.join(ch for ch in text if ch not in punctuation)
  text = re.sub(r'\b(a|an|the)\b', ' ', text, flags=re.UNICODE)
  return ' '.join(text.split())

In [5]:
def get_tokens(s):
  """Split the normalized form of ``s`` into whitespace tokens.

  Falsy input (empty string or None) yields an empty list without being
  normalized.
  """
  return normalize_answer(s).split() if s else []

In [6]:
def compute_f1(a_gold, a_pred):
  """Token-level F1 between a gold answer and a predicted answer.

  Both strings are normalized and tokenized with ``get_tokens``. If either
  side has no tokens, the score is 1 when both are empty and 0 otherwise.
  If the two sides share no tokens, the score is 0.
  """
  gold = get_tokens(a_gold)
  guess = get_tokens(a_pred)
  if not gold or not guess:
    # No-answer case: only an empty-vs-empty match earns credit.
    return int(gold == guess)
  # Multiset intersection counts each shared token at most min(count) times.
  overlap = sum((collections.Counter(gold) & collections.Counter(guess)).values())
  if not overlap:
    return 0
  precision = overlap / len(guess)
  recall = overlap / len(gold)
  return 2 * precision * recall / (precision + recall)

In [7]:
def _blend_scores(xlnet_scores, roberta_scores, w1, w2):
    """Weighted sum of two models' per-position scores over RoBERTa's candidate positions.

    A position found in both dicts gets ``w1 * xlnet + w2 * roberta``; a
    position missing from ``xlnet_scores`` gets only ``w2 * roberta``.
    Positions present only in ``xlnet_scores`` are not included.
    Scores are passed through ``float()``, so numeric strings are accepted.
    """
    blended = {}
    for key, roberta_score in roberta_scores.items():
        xlnet_score = xlnet_scores.get(key)
        if xlnet_score is None:
            blended[key] = w2 * float(roberta_score)
        else:
            blended[key] = w1 * float(xlnet_score) + w2 * float(roberta_score)
    return blended


def weighting_score(xlnet, roberta, w1, w2):
    """Blend XLNet and RoBERTa start/end scores for every question in ``xlnet``.

    Args:
        xlnet: qid -> {"start": {pos: score}, "end": {pos: score}, ...}.
        roberta: qid -> same score dicts plus "context" and "answers".
        w1: weight applied to XLNet scores.
        w2: weight applied to RoBERTa scores.

    Returns:
        qid -> {"start": {pos: blended}, "end": {pos: blended},
        "context": RoBERTa's context, "answer": RoBERTa's "answers" value}.

    Raises AttributeError if a qid in ``xlnet`` is missing from ``roberta``.

    NOTE(review): the candidate set is RoBERTa's. XLNet-only positions are
    dropped, while RoBERTa-only positions keep a ``w2``-scaled score. Confirm
    this asymmetry is intended.
    """
    wdict = {}
    for qns_id, xlnet_entry in xlnet.items():
        roberta_entry = roberta.get(qns_id)
        wdict[qns_id] = {
            "start": _blend_scores(xlnet_entry.get('start'), roberta_entry.get('start'), w1, w2),
            "end": _blend_scores(xlnet_entry.get('end'), roberta_entry.get('end'), w1, w2),
            "context": roberta_entry.get('context'),
            "answer": roberta_entry.get('answers'),
        }
    return wdict

In [8]:
def post_processing(score_dict, max_answer_length=100):
    """Pick the highest-scoring answer span for each question.

    Every (start, end) pair with ``start <= end`` and
    ``end - start + 1 <= max_answer_length`` is scored as
    ``start_score + end_score``. The text ``ctxt[start:end]`` is kept with the
    best score seen for that exact string.

    Args:
        score_dict: qid -> {"start": {pos: score}, "end": {pos: score},
            "context": str, "answer": gold}; position keys are int-convertible.
        max_answer_length: longest span (end - start + 1) considered.

    Returns:
        qid -> (predicted_text, gold_answer). When no pair is valid, the
        question is printed and the prediction falls back to " ".
    """
    pred = {}
    for qns, entry in score_dict.items():
        ans = entry.get('answer')
        ctxt = entry.get('context')
        # Convert position keys once instead of once per (start, end) pair.
        starts = [(int(pos), score) for pos, score in entry["start"].items()]
        ends = [(int(pos), score) for pos, score in entry["end"].items()]

        valid_answer = {}
        for start, s_score in starts:
            for end, e_score in ends:
                if end < start or end - start + 1 > max_answer_length:
                    continue
                # NOTE(review): the slice excludes ctxt[end] (and is "" when
                # start == end) while the length check counts end inclusively.
                # Confirm whether `end` is an exclusive character offset.
                pred_answer = ctxt[start:end]
                pred_score = s_score + e_score
                best_so_far = valid_answer.get(pred_answer)
                if best_so_far is None or best_so_far < pred_score:
                    valid_answer[pred_answer] = pred_score

        if not valid_answer:
            print(qns, ans, valid_answer)
            pred[qns] = (" ", ans)
        else:
            # max() returns the first-inserted text among ties, the same pick a
            # stable descending sort would put first.
            pred[qns] = (max(valid_answer, key=valid_answer.get), ans)
    return pred

In [9]:
def evaluate(final):
  """Mean token-level F1 over all predictions.

  ``final`` maps question id -> (predicted_answer, gold_answer, ...); only the
  first two items of each value are used. An empty mapping raises
  ZeroDivisionError.
  """
  per_question = [compute_f1(pair[1], pair[0]) for pair in final.values()]
  return sum(per_question) / len(per_question)

In [10]:
def main(w1=0.5, w2=0.5, max_answer_length=150):
    """Score one ensemble weighting on the validation predictions.

    Args:
        w1: weight applied to XLNet scores.
        w2: weight applied to RoBERTa scores.
        max_answer_length: longest span (end - start + 1) considered when decoding.

    Returns:
        Mean token-level F1 (not exact match) over all questions.
    """
    blended = weighting_score(xlnet_pred, roberta_pred, w1, w2)
    predictions = post_processing(blended, max_answer_length)
    mean_f1 = evaluate(predictions)
    return mean_f1

In [11]:
def objective(trial):
  """Optuna objective: validation F1 for one RoBERTa/XLNet weight split.

  Only the RoBERTa weight is searched, on a STEP_SIZE grid over [0, 1]. The
  XLNet weight is its complement, so the two always sum to 1.
  """
  STEP_SIZE = 0.01
  w_roberta = trial.suggest_float("w_roberta", 0, 1, step=STEP_SIZE)
  # Complement by construction, so no separate sum-to-one check is needed.
  w_xlnet = 1 - w_roberta
  # main() takes the XLNet weight first, then the RoBERTa weight.
  return main(w_xlnet, w_roberta)

In [12]:
# Seed the TPE sampler so the weight search is reproducible on Restart & Run All.
study = optuna.create_study(direction="maximize", sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(objective, n_trials=50)

[I 2023-11-09 11:46:42,600] A new study created in memory with name: no-name-d719ea92-a23f-4b58-ace6-66844f2b57ab
[I 2023-11-09 11:47:04,160] Trial 0 finished with value: 0.8216701894062297 and parameters: {'w_roberta': 0.43}. Best is trial 0 with value: 0.8216701894062297.
[I 2023-11-09 11:47:24,585] Trial 1 finished with value: 0.8208687108774672 and parameters: {'w_roberta': 0.4}. Best is trial 0 with value: 0.8216701894062297.
[I 2023-11-09 11:47:44,145] Trial 2 finished with value: 0.8230575976360485 and parameters: {'w_roberta': 0.46}. Best is trial 2 with value: 0.8230575976360485.
[I 2023-11-09 11:48:03,809] Trial 3 finished with value: 0.7990967206930037 and parameters: {'w_roberta': 0.08}. Best is trial 2 with value: 0.8230575976360485.
[I 2023-11-09 11:48:23,975] Trial 4 finished with value: 0.8252659204731777 and parameters: {'w_roberta': 0.6}. Best is trial 4 with value: 0.8252659204731777.
[I 2023-11-09 11:48:43,399] Trial 5 finished with value: 0.8199335538069148 and par

In [13]:
# Best RoBERTa weight found by the study; objective() sets the XLNet weight to 1 - w_roberta.
study.best_params

{'w_roberta': 0.59}

In [14]:
# Best RoBERTa weight from the study.
w_roberta = study.best_params["w_roberta"]

In [15]:
# XLNet weight is the complement, mirroring objective(); the float subtraction
# can leave rounding noise (e.g. 0.41000000000000003).
w_xlnet = 1-w_roberta

In [16]:
# Round to 2 decimals (the 0.01 search step) to hide float noise from 1 - w_roberta.
print(round(w_xlnet, 2), round(w_roberta, 2))

0.41000000000000003 0.59
