In [1]:
from transformers import (GPT2LMHeadModel, GPT2TokenizerFast,
                          BertForMaskedLM, BertTokenizerFast,
                          DistilBertForMaskedLM, DistilBertTokenizerFast,
                          RobertaForMaskedLM, RobertaTokenizerFast,
                          BlenderbotForCausalLM, BlenderbotTokenizer,
                          BigBirdForMaskedLM, BigBirdTokenizer,
                          ElectraForMaskedLM, ElectraTokenizerFast)
from multiprocessing import Pool
import pandas as pd
import glob
from itertools import product

In [2]:
from lmeval.datasets import NarrativesDataset
from lmeval.engine import StridingLM

### Initialize list of models and parameters

In [3]:
# Collect every input file: the transcript files plus the align.csv files found
# one directory below inputs/narratives/gentle/.
# glob returns matches in arbitrary, filesystem-dependent order, so each list is
# sorted to keep the run order (and the row order of the saved results) stable.
transcripts = sorted(glob.glob('inputs/narratives/gentle/*/transcript*'))
aligned = sorted(glob.glob('inputs/narratives/gentle/*/align.csv'))
dataset_files = transcripts + aligned

In [4]:
# One (model class, checkpoint id, tokenizer class) triple per evaluated model.
# Keeping each triple on a single line prevents the three columns from drifting
# out of alignment when a model is added or removed.
model_parameters = [
    (GPT2LMHeadModel, 'gpt2', GPT2TokenizerFast),
    (BertForMaskedLM, 'bert-base-uncased', BertTokenizerFast),
    (DistilBertForMaskedLM, 'distilbert-base-uncased', DistilBertTokenizerFast),
    (RobertaForMaskedLM, 'roberta-base', RobertaTokenizerFast),
    (BlenderbotForCausalLM, 'facebook/blenderbot-400M-distill', BlenderbotTokenizer),
    (BigBirdForMaskedLM, 'google/bigbird-roberta-base', BigBirdTokenizer),
    # Only the ELECTRA *generator* checkpoint was trained for masked language
    # modeling. Loading the discriminator checkpoint into ElectraForMaskedLM
    # leaves the LM head randomly initialised, which makes its predictions
    # meaningless, so the generator checkpoint is used here.
    (ElectraForMaskedLM, 'google/electra-base-generator', ElectraTokenizerFast),
]

# Column views of the table above, kept for code that refers to them by name.
model_classes = [spec[0] for spec in model_parameters]
model_ids = [spec[1] for spec in model_parameters]
tokenizer_classes = [spec[2] for spec in model_parameters]

In [5]:
# Context lengths passed to StridingLM(context_length=...): 5, 10, 15 and 20.
ctx_lengths = list(range(5, 21, 5))

Create all combinations of files, model_parameters, and context lengths

In [6]:
# Cartesian product of input files x model specs x context lengths, flattened so
# that every entry is (datafile, model_class, model_id, tokenizer_class, ctx_length),
# i.e. the positional arguments of _validate.
parameters = [
    (datafile, *model_spec, ctx_length)
    for datafile, model_spec, ctx_length in product(dataset_files,
                                                    model_parameters,
                                                    ctx_lengths)
]

### Define validation function + utils

In [7]:
def _make_dataset_id(datafile):
    ds_name_splits = datafile.split('/')
    narrative = ds_name_splits[3]
    ds_type = ds_name_splits[-1].split('.')[0]
    ds_id = '_'.join([narrative, ds_type])
    return ds_id

In [9]:
def _validate(datafile, model_class, model_id, tokenizer_class, ctx_length):
    """Run one model over one dataset file at one context length.

    Args:
        datafile: path of the input file; also used to derive the dataset id.
        model_class: class whose ``from_pretrained`` loads the model.
        model_id: checkpoint identifier passed to both ``from_pretrained`` calls
            and forwarded to the engine.
        tokenizer_class: class whose ``from_pretrained`` loads the tokenizer.
        ctx_length: value passed to ``StridingLM(context_length=...)``.

    Returns:
        Whatever ``StridingLM.run`` returns for this combination.
    """
    # Tokenizer and model are loaded fresh on every call.
    lm_tokenizer = tokenizer_class.from_pretrained(model_id)
    lm = model_class.from_pretrained(model_id)
    narrative_data = NarrativesDataset(datafile, _make_dataset_id(datafile))
    striding_engine = StridingLM(context_length=ctx_length)
    return striding_engine.run(narrative_data, lm_tokenizer, lm, model_id)

### Run in parallel

In [10]:
# Serial smoke test: run only the first combination to check that the pipeline
# works end to end. Looping over the whole grid here would discard every result
# and duplicate the work done by the pool.starmap call further down.
for p in parameters[:1]:
    _validate(*p)

Token indices sequence length is longer than the specified maximum sequence length for this model (3877 > 1024). Running this sequence through the model will result in indexing errors
  0%|          | 0/3872 [00:00<?, ?it/s]

Running gpt2, sherlock_transcript, 5, 3872


 10%|▉         | 385/3872 [00:37<05:39, 10.27it/s]


KeyboardInterrupt: 

In [None]:
# Two worker processes for the parallel run.
# NOTE(review): functions defined in a notebook cell are only visible to workers
# created with the 'fork' start method; confirm this on the target platform.
pool = Pool(processes=2)

In [61]:
# Evaluate every parameter combination; starmap unpacks each tuple into the
# positional arguments of _validate and returns the results in input order.
try:
    results = pool.starmap(_validate, parameters)
finally:
    # Always release the workers, even if a task raised: close() stops new
    # submissions and join() waits for the worker processes to exit.
    pool.close()
    pool.join()

### Save outputs

In [None]:
# Stack all per-run result frames in a single concat call. Concatenating inside
# a loop re-copies the accumulated frame on every iteration (quadratic) and
# leaves r_all undefined when there are no results; here an empty `results`
# raises pandas' ValueError instead. The index is renumbered 0..n-1 in all cases.
r_all = pd.concat(list(results), ignore_index=True)
r_all.to_csv('outs/validation_0909.txt', sep='\t')