In [1]:
from pathlib import Path
from functools import partial
from itertools import product
from IPython.core.debugger import set_trace as bk
import pandas as pd
import torch
from torch import nn
import nlp
from transformers import ElectraModel, ElectraConfig, ElectraTokenizerFast, ElectraForPreTraining
from fastai2.text.all import *
import wandb
from fastai2.callback.wandb import WandbCallback
from _utils.would_like_to_pr import *
from _utils.huggingface import *
from _utils.wsc import *

In [2]:
SIZE = 'small'
assert SIZE in ['small', 'base', 'large']
# Per-size hyperparameters: each list holds one value per model size, ordered small / base / large
CONFIG = {
  'lr': [3e-4, 1e-4, 5e-5],
  'layer_lr_decay': [0.8, 0.8, 0.9],
}
I = ['small', 'base', 'large'].index(SIZE)
hf_tokenizer = ElectraTokenizerFast.from_pretrained(f"google/electra-{SIZE}-discriminator")
electra_config = ElectraConfig.from_pretrained(f"google/electra-{SIZE}-discriminator")
# Pick the value for the chosen size, then add the settings shared by every size
config = {key: per_size[I] for key, per_size in CONFIG.items()}
config.update(max_length=512, use_wsc=False, use_fp16=True)

# 1. Prepare data

In [3]:
# Root directory for downloaded GLUE data and the tokenized / dataloader caches
cache_dir = Path.home().joinpath('datasets')
cache_dir.mkdir(exist_ok=True, parents=True)

## 1.1 Download and Preprocess

In [4]:
def textcols(task):
  """Return the text column names of a GLUE task in huggingface/nlp.

  Returns a new list of column names, or None when `task` is not recognised.
  """
  task_groups = (
    (('qnli',), ('question', 'sentence')),
    (('mrpc', 'stsb', 'wnli', 'rte'), ('sentence1', 'sentence2')),
    (('qqp',), ('question1', 'question2')),
    (('mnli', 'ax'), ('premise', 'hypothesis')),
    (('cola', 'sst2'), ('sentence',)),
  )
  for tasks, columns in task_groups:
    if task in tasks:
      return list(columns)
  return None

def tokenize_sents(example, cols):
  "Encode the text fields `cols` of `example` with `hf_tokenizer`; store the ids in example['inp_ids'] and return example."
  texts = [example[col] for col in cols]
  example['inp_ids'] = hf_tokenizer.encode(*texts)
  return example

def tokenize_sents_max_len(example, cols, max_length):
  """Tokenize one or two text fields of `example` into ids of at most `max_length` tokens.

  The result is stored in example['inp_ids'] as `[CLS] a [SEP]`, or `[CLS] a [SEP] b [SEP]` when a
  second column is given and its tokens are non-empty. The modified example is returned.
  """
  # Follow BERT and ELECTRA, we truncate examples longer than max length, see https://github.com/google-research/electra/blob/79111328070e491b287c307906701ebc61091eb2/finetune/classification/classification_tasks.py#L296
  first = hf_tokenizer.tokenize(example[cols[0]])
  second = hf_tokenizer.tokenize(example[cols[1]]) if len(cols) == 2 else []
  budget = max_length - 1 - len(cols)  # room left after [CLS] and one [SEP] per sentence
  # Drop trailing tokens from the longer sequence (the second one on ties) until both fit
  while len(first) + len(second) > budget:
    (first if len(first) > len(second) else second).pop()
  tokens = [hf_tokenizer.cls_token] + first + [hf_tokenizer.sep_token]
  if second:
    tokens += second + [hf_tokenizer.sep_token]
  example['inp_ids'] = hf_tokenizer.convert_tokens_to_ids(tokens)
  return example

# get tokenized datasets and dataloaders
# For every task below: load (or download) the raw splits, tokenize them with on-disk caching, and
# build dataloaders. Tokenized splits go into `glue_dsets[task]`, dataloaders into `glue_dls[task]`.
glue_dsets = {}; glue_dls = {}
for task in ['cola', 'sst2', 'mrpc', 'qqp', 'stsb', 'mnli', 'qnli', 'rte', 'wnli', 'ax']:
  # General case and special case for WSC
  if task == 'wnli' and config['use_wsc']:
    # Use SuperGLUE's WSC (wsc.fixed) data in place of GLUE's WNLI
    benchmark, subtask = 'super_glue', 'wsc.fixed'
    # samples in all splits are all less than 128-2, so don't need to worry about max_length
    Tfm = partial(WSCTransform, hf_toker=hf_tokenizer)
    cols = {'inp_ids':TensorText, 'span': noop, 'label':TensorCategory}
    # presumably the first two cols ('inp_ids', 'span') are model inputs and 'label' the target — TODO confirm
    n_inp=2
    cache_name = "tokenized_{split}.arrow"
  else:
    benchmark, subtask = 'glue', task
    # Truncate every example to config['max_length'] tokens, including [CLS]/[SEP]
    tok_func = partial(tokenize_sents_max_len, cols=textcols(task), max_length=config['max_length'])
    Tfm = partial(HF_Transform, func=tok_func)
    cols = ['inp_ids', 'label']
    n_inp=1
    # max_length is encoded in the cache file name; '{split}' is kept as a literal placeholder
    # (presumably filled in per split by the transform's map — verify in _utils.huggingface)
    cache_name = f"tokenized_{config['max_length']}_{{split}}.arrow"
  # load / download datasets.
  dsets = nlp.load_dataset(benchmark, subtask, cache_dir=cache_dir)
  # The QQP training set contains two broken samples with an empty question2; filter them out and
  # cache the filtered split under cache_dir/glue/qqp/1.0.0/fixed_train.arrow
  if task=='qqp': dsets['train'] = dsets['train'].filter(lambda e: e['question2']!='',
                                          cache_file_name=str(cache_dir/'glue/qqp/1.0.0/fixed_train.arrow'))
  # load / make tokenized datasets
  glue_dsets[task] = Tfm(dsets).map(cache_name=cache_name)
  # load / make dataloaders (batch size 32, padded with the tokenizer's pad id); the dataloader
  # cache file name is derived from the tokenized cache name ('tokenized'->'dl', '.arrow'->'.json')
  hf_dsets = HF_Datasets(glue_dsets[task], cols=cols, hf_toker=hf_tokenizer, n_inp=n_inp)
  dl_cache_name = cache_name.replace('tokenized', 'dl').replace('.arrow', '.json')
  glue_dls[task] = hf_dsets.dataloaders(bs=32, pad_idx=hf_tokenizer.pad_token_id, cache_name=dl_cache_name)

## 1.2 View Data
- View the raw data on [nlp-viewer](https://huggingface.co/nlp/viewer/).

- View the task descriptions in the [TensorFlow Datasets catalog entry for GLUE](https://www.tensorflow.org/datasets/catalog/glue).

- You may notice that some texts lack a trailing \[SEP\]. This is because `show_batch` truncates long texts for display; you can turn this off by specifying `truncated_at=None`.

In [5]:
# CoLA (The Corpus of Linguistic Acceptability) - 0: unacceptable, 1: acceptable
# Print the number of samples in each split, then display one example with show_batch.
print("Dataset size (train/valid/test): {}/{}/{}".format(*[len(dl.dataset) for dl in glue_dls['cola'].loaders]))
glue_dls['cola'].show_batch(max_n=1)

Dataset size (train/valid/test): 8551/1043/1063


Unnamed: 0,inp_ids,label
0,"[CLS] everybody who has ever , worked in any office which contained any type ##writer which had ever been used to type any letters which had to be signed by any administrator who ever worked in any department like mine will know what i mean . [SEP]",1


In [6]:
# SST-2 (The Stanford Sentiment Treebank) - 1: positive, 0: negative
# Print the number of samples in each split, then display one example with show_batch.
print("Dataset size (train/valid/test): {}/{}/{}".format(*[len(dl.dataset) for dl in glue_dls['sst2'].loaders]))
glue_dls['sst2'].show_batch(max_n=1)

Dataset size (train/valid/test): 67349/872/1821


Unnamed: 0,inp_ids,label
0,"[CLS] even if the en ##tic ##ing prospect of a lot of nu ##bil ##e young actors in a film about campus de ##pr ##avi ##ty did n ' t fade amid the deliberate , tires ##ome u ##gli ##ness , it would be rendered ted ##ious by ava ##ry ' s failure to construct a story with even a trace of dramatic interest . [SEP]",0


In [7]:
# MRPC (Microsoft Research Paraphrase Corpus) - 1: paraphrase, 0: not a paraphrase
# Print the number of samples in each split, then display one sentence pair with show_batch.
print("Dataset size (train/valid/test): {}/{}/{}".format(*[len(dl.dataset) for dl in glue_dls['mrpc'].loaders]))
glue_dls['mrpc'].show_batch(max_n=1)

Dataset size (train/valid/test): 3668/408/1725


Unnamed: 0,inp_ids,label
0,"[CLS] with cl ##ari ##tin ' s decline , sc ##hering - pl ##ough ' s best - selling products now are two drugs used together to treat hepatitis c , the anti ##vira ##l pill rib ##avi ##rin and an inter ##fer ##on medicine called peg - intro ##n . [SEP] with cl ##ari ##tin ' s decline , sc ##hering - pl ##ough ' s best - selling products are now anti ##vira ##l drug rib ##avi ##rin and an inter ##fer ##on medicine called peg - intro ##n - - two drugs used together to treat hepatitis c . [SEP]",1


In [8]:
# STS-B (Semantic Textual Similarity Benchmark) - regression target in 0.0 ~ 5.0
# Print the number of samples in each split, then display one sentence pair with show_batch.
print("Dataset size (train/valid/test): {}/{}/{}".format(*[len(dl.dataset) for dl in glue_dls['stsb'].loaders]))
glue_dls['stsb'].show_batch(max_n=1)

Dataset size (train/valid/test): 5749/1500/1379


Unnamed: 0,inp_ids,label
0,"[CLS] iraq has been lobbying for the security council to stop using the country ' s oil revenue to pay compensation to victims of the 1991 gulf war and the salaries of the united nations monitoring , verification and inspection commission inspectors and to have all money remaining in the united nation ' s oil - for - food accounts transferred to the government ' s development fund . [SEP] iraq ' s new leaders have been lobbying for the united nations security council to stop using the iraq ' s oil revenue to pay the salaries of the inspectors and to have all money remaining in the united nation ' s oil - for - food account transferred to the iraqi government . [SEP]",4.0


In [9]:
# QQP (Quora Question Pairs) - 0: not duplicate, 1: duplicate
# Print the number of samples in each split, then display one question pair with show_batch.
print("Dataset size (train/valid/test): {}/{}/{}".format(*[len(dl.dataset) for dl in glue_dls['qqp'].loaders]))
glue_dls['qqp'].show_batch(max_n=1)

Dataset size (train/valid/test): 363847/40430/390965


Unnamed: 0,inp_ids,label
0,"[CLS] i ' m in 12 steps program . been sober for 4 years and counting . i ' ve read about l ##sd and want to try it , if i used it once , would that be a re ##la ##pse ? [SEP] heartbreak ? heartbreak ? she ' s my girlfriend for two months , i chose her over my girlfriend for 2 years . i like her so much to the point that i can ' t let her go even if she wants to end our relationship because of the other people around us most especially her family . i do the things for her that i ' m not used to for a girl and i am willing to sacrifice everything just to have a little time with her . a little and limited time that i ' m asking from her but she",0


In [10]:
# MNLI (The Multi-Genre NLI Corpus) - 0: entailment, 1: neutral, 2: contradiction
# MNLI has five splits (matched and mismatched validation/test sets), so five sizes are printed.
print("Dataset size (train/validation_matched/validation_mismatched/test_matched/test_mismatched): {}/{}/{}/{}/{}".format(*[len(dl.dataset) for dl in glue_dls['mnli'].loaders]))
glue_dls['mnli'].show_batch(max_n=1)

Dataset size (train/validation_matched/validation_mismatched/test_matched/test_mismatched): 392702/9815/9832/9796/9847


Unnamed: 0,inp_ids,label
0,[CLS] well uh that ' s kind of obvious i mean they ' re even carrying it to to where now uh that they ad ##vert ##ise on tv you know if your if you uh you know have done this or if you need this uh uh we ' ll sue for you and you don ' t have to pay us unless you but then what they don ' t tell you is that if you if they win you give them at least a third of the of the thing that they win so i don ' t know it is uh it ' s getting to be more business now rather than uh actually uh dealing with the crime than with uh um the uh punishment they the the lawyers are just in it for the money i ' m i ' m convinced i know,0


In [11]:
# QNLI (Question-answering NLI, derived from the Stanford Question Answering Dataset) - 0: entailment, 1: not_entailment
# Print the number of samples in each split, then display one question/sentence pair with show_batch.
print("Dataset size (train/valid/test): {}/{}/{}".format(*[len(dl.dataset) for dl in glue_dls['qnli'].loaders]))
glue_dls['qnli'].show_batch(max_n=1)

Dataset size (train/valid/test): 104743/5463/5463


Unnamed: 0,inp_ids,label
0,"[CLS] what time magazine founder attended yale ? [SEP] among the best - known are u . s . presidents william howard taft , gerald ford , george h . w . bush , bill clinton and george w . bush ; royals crown princess victoria bern ##ado ##tte , prince ro ##stis ##lav romano ##v and prince ak ##ii ##ki hose ##a ny ##ab ##ong ##o ; heads of state , including italian prime minister mario mont ##i , turkish prime minister tan ##su ci ##ller , mexican president ernesto ze ##di ##llo , german president karl cars ##tens , and philippines president jose pac ##iano laurel ; u . s . supreme court justices sonia soto ##may ##or , samuel ali ##to and clarence thomas ; u . s . secretaries of state john kerry , hillary clinton , cyrus vance , and dean ache ##son ; authors",0


In [12]:
# RTE (Recognizing Textual Entailment) - 0: entailment, 1: not_entailment
# Print the number of samples in each split, then display one sentence pair with show_batch.
print("Dataset size (train/valid/test): {}/{}/{}".format(*[len(dl.dataset) for dl in glue_dls['rte'].loaders]))
glue_dls['rte'].show_batch(max_n=1)

Dataset size (train/valid/test): 2490/277/3000


Unnamed: 0,inp_ids,label
0,"[CLS] sweden has plans to open a virtual embassy in second life , the virtual world home to thousands of net ##ize ##ns . companies like dell are already selling computers via the virtual world , but sweden would become the first country to have a cyber - embassy in second life . the embassy would not provide visa or perform diplomatic tasks , but would provide information on how and where to get these documents in the real world , as well as giving cultural and tourist information about the country . visitors will also be able to chat with embassy personnel . a spokesperson explained that it would be an easy and cheap way to reach young people . the idea came from the swedish institute , an agency of the foreign affairs ministry of sweden with providing information about sweden as a key purpose . the """,1


In [13]:
# WSC (The Winograd Schema Challenge) - 0: wrong, 1: correct
# There are three styles: WNLI (cast as an NLI task), WSC, and WSC with candidates (the trick used by RoBERTa).
# When config['use_wsc'] is False, glue_dls['wnli'] holds GLUE's WNLI sentence pairs.
""" Note for WSC trick
- haven't figured out when use WNLI style how to make headers so show_batch right.
- lines are prefix, suffix, cands, cand_lens, label in order
- cands is the concatenation of candidates, cand_lens is the lengths of candidates in order.
"""
print("Dataset size (train/valid/test): {}/{}/{}".format(*[len(dl.dataset) for dl in glue_dls['wnli'].loaders]))
glue_dls['wnli'].show_batch(max_n=1)

Dataset size (train/valid/test): 635/71/146


Unnamed: 0,inp_ids,label
0,"[CLS] ta ##tya ##na knew that grandma always enjoyed serving an abundance of food to her guests . now ta ##tya ##na watched as grandma gathered ta ##tya ##na ' s small mother into a wide , sc ##ra ##wny embrace and then propelled her to the table , lifting her shaw ##l from her shoulders , seating her in the place of honor , and saying simply : "" there ' s plenty . "" [SEP] grandma gathered ta ##tya ##na ' s small mother into a wide , sc ##ra ##wny embrace and then propelled ta ##tya ##na ' s mother to the table . [SEP]",1


In [14]:
# AX (GLUE Diagnostic Dataset) - 0: entailment, 1: neutral, 2: contradiction
# AX only has a test split, so a single size is printed (the shown batch has no label column).
print("Dataset size (test): {}".format(*[len(dl.dataset) for dl in glue_dls['ax'].loaders]))
glue_dls['ax'].show_batch(max_n=1)

Dataset size (test): 1104


Unnamed: 0,inp_ids
0,"[CLS] we manually ann ##ota ##ted 68 ##7 template ##s mapping kb pre ##dicate ##s to text for different composition ##ality types ( with 46 ##2 unique kb pre ##dicate ##s ) , and use those template ##s to modify the original web ##quest ##ions ##sp question according to the meaning of the generated spa ##r ##q ##l query . [SEP] we manually ann ##ota ##ted over 650 template ##s mapping kb pre ##dicate ##s to text for different composition ##ality types ( with 46 ##2 unique kb pre ##dicate ##s ) , and use those template ##s to modify the original web ##quest ##ions ##sp question according to the meaning of the generated spa ##r ##q ##l query . [SEP]"


# 2. Finetuning model

* ELECTRA uses the encoding of the `[CLS]` token as the pooled representation for sentence-level prediction (see [here](https://github.com/google-research/electra/blob/79111328070e491b287c307906701ebc61091eb2/model/modeling.py#L254) in its official repository).

* Note that each task needs its own prediction head instance.

In [15]:
class SentencePredictHead(nn.Module):
  """Sentence-level prediction head in the style of ELECTRA and BERT.

  Applies dropout to the hidden state of the first token of every sequence and projects it
  to `targ_voc_size` outputs.

  Args:
    hidden_size: size of the last dimension of the encoder output.
    targ_voc_size: number of outputs (1 for a regression task, the number of classes otherwise).
    dropout_prob: dropout probability applied before the projection (defaults to 0.1).
  """
  def __init__(self, hidden_size, targ_voc_size, dropout_prob=0.1):
    super().__init__()
    self.linear = nn.Linear(hidden_size, targ_voc_size)
    self.dropout = nn.Dropout(dropout_prob)

  def forward(self, x):
    """x: (batch size, sequence length, hidden_size).

    Returns (batch size,) when targ_voc_size == 1 (regression), else (batch size, targ_voc_size).
    """
    # project the first token (a special token)'s hidden encoding
    first_token = x[:, 0]
    # squeeze(-1) only removes the output dim when it has size 1, i.e. for regression
    return self.linear(self.dropout(first_token)).squeeze(-1)

# 3. Single Task Finetuning

## 3.1 Discriminative learning rate

In [16]:
# Names come from, for nm in model.named_modules(): print(nm[0])

def hf_electra_param_splitter(model, num_hidden_layers, outlayer_name):
  """Split `model`'s parameters into groups for layer-wise (discriminative) learning rates.

  Groups follow `model.named_modules()` order:
  1. the module whose name ends with '.embeddings';
  2. one group per module ending with 'encoder.layer.{i}';
  3. the output head whose name ends with `outlayer_name`.

  Parameters of any module whose name ends with '.embeddings_project' are appended to the
  embeddings group. HF `ElectraModel` creates that projection when embedding_size != hidden_size,
  as in electra-small. fastai builds the optimizer from these groups, so without this step the
  projection would never be updated.

  Args:
    model: the model to split.
    num_hidden_layers: number of encoder layers.
    outlayer_name: suffix identifying the output head module (e.g. '1' for nn.Sequential).

  Returns:
    list of parameter lists, of length num_hidden_layers + 2.

  Raises:
    AssertionError: if the number of matched modules is not num_hidden_layers + 2.
  """
  names = ['.embeddings', *[f'encoder.layer.{i}' for i in range(num_hidden_layers)], outlayer_name]
  def end_with_any(name): return any(name.endswith(n) for n in names)
  matched = [(name, mod) for name, mod in model.named_modules() if end_with_any(name)]
  assert len(matched) == len(names)
  groups = [list(mod.parameters()) for _, mod in matched]
  # Attach the embedding projection (if any) to the embeddings group rather than dropping it
  emb_idx = next((i for i, (name, _) in enumerate(matched) if name.endswith('.embeddings')), 0)
  for name, mod in model.named_modules():
    if name.endswith('.embeddings_project'):
      groups[emb_idx] += list(mod.parameters())
  return groups

def get_layer_lrs(lr, decay_rate_of_depth, num_hidden_layers, original_decay=False):
  """Per-group learning rates for layer-wise lr decay, ordered from input to output.

  Returns num_hidden_layers + 2 values, one per group produced by hf_electra_param_splitter:
  [embeddings, encoder layer 0, ..., encoder layer n-1, output head].
  By default the head gets `lr`, and each group below it gets one more factor of the decay rate.
  In other words, group k (0 = embeddings) gets lr * decay_rate_of_depth ** (num_hidden_layers + 1 - k).

  Args:
    lr: learning rate of the output head.
    decay_rate_of_depth: multiplicative decay per level.
    num_hidden_layers: number of encoder layers.
    original_decay: if True, every group except the head gets one extra factor of the decay rate.
      The embeddings then get lr * decay ** (num_hidden_layers + 2) and the last encoder layer
      lr * decay ** 2. This is intended to mirror `_get_layer_lrs` in the official ELECTRA
      repository; verify before relying on it.
  """
  # I think input layer as bottom and output layer as top, which make 'depth' mean different from the one of official repo
  lrs = [lr * (decay_rate_of_depth ** depth) for depth in reversed(range(num_hidden_layers + 2))]
  if original_decay:
    lrs = [group_lr * decay_rate_of_depth for group_lr in lrs[:-1]] + lrs[-1:]
  return lrs


## 3.2 Learning rate schedule

In [17]:
def linear_warmup_and_decay(pct_now, lr_max, end_lr, decay_power, warmup_pct, total_steps):
  """Learning rate at training progress `pct_now`: polynomial decay times a linear warmup factor.

  Args:
    pct_now: fraction of training steps already done; pct_now=0.0 when calculating lr for the first batch.
    lr_max: peak learning rate (a scalar, or a numpy array of per-group learning rates).
    end_lr: the end learning rate for the decay.
    decay_power: exponent of the polynomial decay (1 = linear).
    warmup_pct: fraction of training steps used for linear increase; <= 0 disables warmup.
    total_steps: total number of training steps.

  Returns:
    the learning rate, with the same scalar/array type as lr_max arithmetic produces.
  """
  # pct is updated after_batch, but global_step (in tf) seems to update before the optimizer step,
  # so pct is actually (global_step - 1) / total_steps; shift it forward by one step.
  # Clamp at 1.0, as tf polynomial_decay clamps global_step to decay_steps. Without the clamp,
  # (1 - pct) < 0 would give an lr below end_lr, or a complex number for fractional powers.
  fixed_pct_now = min(pct_now + 1/total_steps, 1.0)
  # The official repository applies warmup and linear decay together rather than as two sequential
  # phases; early in training pct is small, so the decay factor barely changes the warmed-up lr.
  decayed_lr = (lr_max - end_lr) * (1 - fixed_pct_now) ** decay_power + end_lr  # https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/polynomial_decay
  # https://github.com/google-research/electra/blob/81f7e5fc98b0ad8bfd20b641aa8bc9e6ac00c8eb/model/optimization.py#L44
  warmup_factor = min(1.0, fixed_pct_now / warmup_pct) if warmup_pct > 0 else 1.0
  return decayed_lr * warmup_factor


## 3.3 Finetune

In [18]:
# Metric class names per task; each task gets its own list object.
# Note: MRPC and QQP are both binary classification problems, so we can just use fastai's default
# average option 'binary' without specifying an average method.
METRICS = {task: list(metric_names) for task_names, metric_names in (
  (['cola'], ['MatthewsCorrCoef']),
  (['sst2', 'mnli', 'qnli', 'rte', 'wnli', 'snli', 'ax'], ['Accuracy']),
  (['mrpc', 'qqp'], ['F1Score', 'Accuracy']),
  (['stsb'], ['PearsonCorrCoef', 'SpearmanCorrCoef']),
) for task in task_names}
# Number of outputs of the prediction head: 1 for regression, otherwise the number of classes
TARG_VOC_SIZE = {task: n_outputs for n_outputs, task_names in (
  (1, ['stsb']),
  (2, ['cola', 'sst2', 'mrpc', 'qqp', 'qnli', 'rte', 'wnli']),
  (3, ['mnli', 'ax']),
) for task in task_names}

In [48]:
class MyMSELossFlat(BaseLoss):
  "Flattened MSE loss whose `decodes` clips predictions into [low, high]; either bound may be None."

  def __init__(self,*args, axis=-1, floatify=True, low=None, high=None, **kwargs):
    super().__init__(nn.MSELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)
    self.low, self.high = low, high

  def decodes(self, x):
    "Clamp `x` from below by `low`, then from above by `high`, skipping any bound that is None."
    if self.low is not None:
      x = x.clamp(min=self.low)
    if self.high is not None:
      x = x.clamp(max=self.high)
    return x

In [49]:
def get_glue_learner(task, one_cycle=False, device='cuda:0', run_name=None, checkpoint=None):
  """Build a fastai Learner for finetuning ELECTRA on a single GLUE task.

  Args:
    task: GLUE task key, e.g. 'cola', 'mnli', 'wnli'.
    one_cycle: if True, the returned fit function is fit_one_cycle; otherwise it is fit with the
      linear warmup + decay ParamScheduler.
    device: device the task's dataloaders are moved to.
    run_name: if given, start a wandb run and attach WandbCallback. Expected format is
      "<task>_<id>_<i>"; the id is the second '_'-separated field.
    checkpoint: optional checkpoint passed to learn.load before the fp16 conversion.

  Returns:
    (learn, fit_fc): the Learner and a partial fit function with n_epoch preset.
  """
  use_wsc = task == 'wnli' and config['use_wsc']

  # num_epochs
  # NOTE(review): the ELECTRA paper also trains STS-B for 10 epochs — confirm whether 'stsb' belongs here
  num_epochs = 10 if task == 'rte' else 3

  # dls
  dls = glue_dls[task].to(torch.device(device))

  # model
  discriminator_name = f"google/electra-{SIZE}-discriminator"
  if use_wsc:
    model = ELECTRAWSCModel(HF_Model(ElectraForPreTraining, discriminator_name, hf_tokenizer))
  else:
    model = nn.Sequential(HF_Model(ElectraModel, discriminator_name, hf_tokenizer),
                          SentencePredictHead(electra_config.hidden_size, targ_voc_size=TARG_VOC_SIZE[task]))

  # loss func
  if task == 'stsb': loss_fc = MyMSELossFlat(low=0.0, high=5.0)
  elif use_wsc: loss_fc = BCEWithLogitsLossFlat()
  else: loss_fc = CrossEntropyLossFlat()

  # metrics
  if use_wsc:
    def sigmoid_acc(inps,targ):
      pred = torch.sigmoid(inps) > 0.5
      return (pred == targ).float().mean()
    metrics = [sigmoid_acc]
  else:
    # Look the metric classes up by name in the notebook namespace instead of eval()-ing source strings
    metrics = [globals()[metric]() for metric in METRICS[task]]

  # learning rate: one param group per embedding / encoder layer / head, with layer-wise decayed lrs
  splitter = partial(hf_electra_param_splitter,
                     num_hidden_layers=electra_config.num_hidden_layers,
                     outlayer_name='discriminator_predictions' if use_wsc else '1')
  layer_lrs = get_layer_lrs(lr=config['lr'],
                            decay_rate_of_depth=config['layer_lr_decay'],
                            num_hidden_layers=electra_config.num_hidden_layers,)

  # learner
  learn = Learner(dls, model,
                  loss_func=loss_fc,
                  opt_func=partial(Adam, eps=1e-6,),
                  metrics=metrics,
                  splitter=splitter,
                  lr=layer_lrs,
                  path=str(Path.home()/'checkpoints'),
                  model_dir='electra_glue',)

  # load checkpoint
  if checkpoint: learn.load(checkpoint)

  # fp16
  if config['use_fp16']: learn = learn.to_fp16()

  # wandb logging
  if run_name:
    run_id = run_name.split('_')[1]
    wandb.init(project='electra-glue', name=run_name,
               config={'task': task, 'id': run_id, 'use_fp16': config['use_fp16'], 'optim': 'Adam', 'use_onecycle': one_cycle},
               reinit=True)
    learn.add_cb(WandbCallback(None, False))

  # one cycle / warm up + linear decay
  if one_cycle: return learn, partial(learn.fit_one_cycle, n_epoch=num_epochs)
  lr_schedule = ParamScheduler({'lr': partial(linear_warmup_and_decay,
                                              lr_max=np.array(layer_lrs),
                                              end_lr=0.0,
                                              decay_power=1,
                                              warmup_pct=0.1,
                                              total_steps=num_epochs*(len(dls.train)))})
  return learn, partial(learn.fit, n_epoch=num_epochs, cbs=[lr_schedule])

In [None]:
# Finetune each selected task 10 times. Every run is named "<task>_<rand_id>_<i>", and
# get_glue_learner reads the id from the second '_'-separated field of that name.
rand_id = random.randint(1,500)  # presumably `random` comes from the fastai star import — TODO confirm
#rand_id = 79
pretrained_checkpoint = Path.home()/'checkpoints/electra_pretrain/7-06_10-31-49_100%.pth'
# NOTE(review): this overrides the path above, so get_glue_learner skips learn.load entirely
pretrained_checkpoint = None
for i in range(10):
  for task in ['cola', 'sst2', 'mrpc', 'stsb', 'qnli', 'rte', 'qqp', 'mnli', 'wnli']:
    # Only WNLI is run at the moment; edit this filter to run other tasks
    if task not in ['wnli']: continue
    run_name = f"{task}_{rand_id}_{i}"
    # run_name = None # set to None to skip wandb and model saving
    learn, fit_fc = get_glue_learner(task, device='cuda:0', 
                                      run_name=run_name, checkpoint=pretrained_checkpoint) 
    fit_fc()
    if run_name:
      wandb.join()  # end the current wandb run before the next one is initialized
      learn.save(run_name)  # saved under the Learner's path/model_dir (~/checkpoints/electra_glue)

## 3.4 Predict the test set

In [50]:
def get_identifier(task, split):
  """GLUE submission file stem for `task`.

  MNLI's test splits map to 'MNLI-m' / 'MNLI-mm'; every other task ignores `split`.
  Raises KeyError for an unknown task, or for MNLI with any other split.
  """
  if task == 'mnli':
    mnli_ids = {'test_matched': 'MNLI-m', 'test_mismatched': 'MNLI-mm'}
    if split in mnli_ids:
      return mnli_ids[split]
  task_ids = {'cola': 'CoLA', 'sst2': 'SST-2', 'mrpc': 'MRPC', 'qqp': 'QQP', 'stsb': 'STS-B',
              'qnli': 'QNLI', 'rte': 'RTE', 'wnli': 'WNLI', 'ax': 'AX'}
  return task_ids[task]

class Ensemble(nn.Module):
  """Average the outputs of several models.

  Members are stored on CPU. During forward each member is moved to `self.device`, run, and
  moved back to CPU before the next one runs.
  """
  def __init__(self, models, device='cuda:0'):
    super().__init__()
    self.models = nn.ModuleList([member.cpu() for member in models])
    self.device = device

  def to(self, device):
    "Only record `device`; members are moved lazily inside forward. Returns self."
    self.device = device
    return self

  def getitem(self, i):
    "Return the i-th member model."
    return self.models[i]

  def forward(self, *args, **kwargs):
    def run_member(member):
      member.to(self.device)
      result = member(*args, **kwargs)
      assert isinstance(result, torch.Tensor)
      member.cpu()
      return result
    return torch.stack([run_member(member) for member in self.models]).mean(dim=0)

def load_model_(learn, files, device=None, **kwargs):
  """Load weights into `learn`: a single checkpoint normally, or a list of checkpoints as an Ensemble.

  Args:
    learn: fastai Learner whose model is loaded (or replaced by an Ensemble).
    files: a checkpoint name passed to learn.load, or a list of names resolved under
      learn.path/learn.model_dir with the '.pth' extension.
    device: device for the single-file load, or the device the Ensemble runs its members on
      (defaults to learn.dls.device for the list case).
    **kwargs: forwarded to learn.load / load_model.

  Returns:
    learn, in both cases.

  Raises:
    ValueError: if `files` is an empty list.
  """
  if not isinstance(files, list):
    learn.load(files, device=device, **kwargs)
    return learn
  if not files:
    raise ValueError("load_model_ needs at least one checkpoint file")
  from copy import deepcopy
  if device is None: device = learn.dls.device
  base_model = learn.model.cpu()
  # One copy of the architecture per checkpoint; all stay on CPU until Ensemble.forward runs them
  models = [base_model, *(deepcopy(base_model) for _ in range(len(files) - 1))]
  for f, m in zip(files, models):
    file = join_path_file(f, learn.path/learn.model_dir, ext='.pth')
    load_model(file, m, learn.opt, device='cpu', **kwargs)
  learn.model = Ensemble(models, device)
  return learn

In [51]:
def predict_test(task, checkpoint, dl_idx=-1, output_dir=None, device='cuda:0'):
  """Predict one test split of `task` and save it as a GLUE-submission TSV.

  Args:
    task: GLUE task key ('cola', 'mnli', 'ax', ...).
    checkpoint: checkpoint name, or a list of names to ensemble (see load_model_).
    dl_idx: index of the split/dataloader to predict; -1 is the last split.
    output_dir: directory for the TSV; defaults to cache_dir/'glue/test' and is created
      (with parents) if missing.
    device: device used for prediction.

  Returns:
    DataFrame with columns 'index' and 'prediction'. It is also written to
    output_dir/'<identifier>.tsv', tab-separated and without pandas' row index.
  """
  if output_dir is None: output_dir = cache_dir/'glue/test'
  output_dir = Path(output_dir)
  output_dir.mkdir(parents=True, exist_ok=True)
  device = torch.device(device)

  # load checkpoint and get predictions
  learn, _ = get_glue_learner(task, device=device)
  load_model_(learn, checkpoint)
  # get_preds(with_decoded=True) -> (prediction logits, targets, decoded predictions)
  preds = learn.get_preds(ds_idx=dl_idx, with_decoded=True)[-1]

  # decode target class index to its class name
  if task in ['mnli','ax']:
    preds = [ ['entailment','neutral','contradiction'][p] for p in preds]
  elif task in ['qnli','rte']:
    preds = [ ['entailment','not_entailment'][p] for p in preds ]
  elif task == 'wnli' and config['use_wsc']:
    preds = preds.to(dtype=torch.long).tolist()
  else: preds = preds.tolist()

  # form test dataframe and save
  task_dsets = glue_dsets[task]
  test_df = pd.DataFrame( {'index':range(len(list(task_dsets.values())[dl_idx])), 'prediction': preds} )
  # Take the split name from this task's own datasets so get_identifier sees the split actually predicted
  split = list(task_dsets.keys())[dl_idx]
  identifier = get_identifier(task, split)
  # index=False: the file must contain only the 'index' and 'prediction' columns, not pandas' row index as well
  test_df.to_csv( output_dir/f'{identifier}.tsv', sep='\t', index=False )
  return test_df

In [52]:
# Write GLUE test-set predictions using the selected finetuning run of each task.
dir = 'hf_small_test'  # output sub-directory under cache_dir/glue/test (NOTE(review): shadows builtin `dir`)
id = 79  # run id used in checkpoint names "<task>_<id>_<run>" (NOTE(review): shadows builtin `id`)
# task -> index of the repeated finetuning run whose checkpoint is used
th_run = {'cola': 9, 'sst2': 1, 'mrpc': 6, 'qqp': 2, 'stsb': 4, 'qnli': 1, 'rte': 4, 'mnli': 5, 'ax': 5,
        'wnli': 7 #[f"wnli_254_{i}" for i in [29,19,1,14,22,16,13,10,8,4]],
        }
for task, th in th_run.items():
  # Only STS-B is predicted at the moment; edit this filter to predict other tasks
  if task not in ['stsb']: continue
  print(task)
  # AX is predicted with the MNLI checkpoint of the same run index
  ckp = f"{task}_{id}_{th}" if task != 'ax' else f"mnli_{id}_{th}"
  # MNLI has two test splits, predicted via the last two split indices (-1 and -2);
  # every other task predicts only its last split
  dl_idxs = [-1, -2] if task=='mnli' else [-1]
  for dl_idx in dl_idxs:
    df = predict_test(task, ckp, dl_idx, output_dir=cache_dir/f'glue/test/{dir}')

stsb
 80%|███████▉  | 1100/1379 [00:00<00:00, 5491.02it/s]