In [5]:
import pandas as pd
import numpy as np
from transformers import Trainer, AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling

from datasets import load_dataset

In [6]:
# Load the masked-language-model head from the local checkpoint directory.
# The warning printed below about unused `cls.seq_relationship.*` weights is expected:
# that is BERT's next-sentence-prediction head, which BertForMaskedLM does not use.
model = AutoModelForMaskedLM.from_pretrained(
    '../../0_models/default-model',
)

Some weights of the model checkpoint at ../../0_models/default-model were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
# Fast (Rust-backed) tokenizer from the same checkpoint directory as the model,
# so token ids line up with the model's vocabulary.
tokenizer = AutoTokenizer.from_pretrained(
    '../../0_models/default-model',
    use_fast=True,
)

In [8]:
# Load the held-out Reddit politics text file as a single "validation" split
# (the 'text' loader yields one row per line of the file).
datasets = load_dataset(
    'text',
    data_files={
        'validation': '../../0_data/clean/unlabelled_reddit/politics_test/test_2017_03_5k.txt',
    },
)

Using custom data configuration default-fdca24a232e41a05
Reusing dataset text (/Users/Paul/.cache/huggingface/datasets/text/default-fdca24a232e41a05/0.0.0/44d63bd03e7e554f16131765a251f2d8333a5fe8a73f6ea3de012dbc49443691)


In [9]:
def tokenize_function(examples):
    """Tokenize a batch of raw text lines for masked-LM evaluation.

    Blank and whitespace-only lines are dropped before tokenizing. ``examples["text"]``
    is overwritten in place with the filtered list, presumably so that the ``text``
    column stays aligned with the tokenized rows when ``Dataset.map`` merges them.

    Returns the tokenizer's output (unpadded, truncated to 64 tokens, including a
    ``special_tokens_mask``).
    """
    kept_lines = [line for line in examples["text"] if line.strip()]
    examples["text"] = kept_lines

    tokenizer_kwargs = dict(
        padding=False,  # padding is left to the data collator, batch by batch
        truncation=True,
        max_length=64,
        # DataCollatorForLanguageModeling can reuse `special_tokens_mask`
        # instead of recomputing it, which is cheaper.
        return_special_tokens_mask=True,
    )
    return tokenizer(kept_lines, **tokenizer_kwargs)

# Tokenize every split in batches. Because tokenize_function drops blank lines,
# row indices here no longer match line numbers in the raw text file whenever
# that file contains blank lines.
tokenized_datasets = datasets.map(tokenize_function, batched=True)

Loading cached processed dataset at /Users/Paul/.cache/huggingface/datasets/text/default-fdca24a232e41a05/0.0.0/44d63bd03e7e554f16131765a251f2d8333a5fe8a73f6ea3de012dbc49443691/cache-eb713f96e0f96ba1.arrow


In [10]:
# Masks tokens at random per batch: each token is selected with probability 0.15.
# Masking is random, so which tokens get masked differs from run to run unless the RNG is seeded.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=0.15,
)

In [11]:
# Inference-only Trainer: no TrainingArguments are passed (so Trainer falls back to
# its defaults) and no metrics are computed.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=None,
)

In [12]:
def softmax(x, axis=0):
    """Numerically stable softmax of `x` along `axis`.

    Parameters
    ----------
    x : array_like
        Scores (e.g. logits). For a 1-D vector the result is a probability vector.
    axis : int, default 0
        Axis to normalise over. The default preserves the original behaviour
        (1-D input, or column-wise normalisation of a 2-D array).

    Returns
    -------
    numpy.ndarray
        Array with the same shape as `x` whose entries along `axis` sum to 1.
    """
    x = np.asarray(x)
    # Shift by the max taken along the SAME axis we normalise over. This leaves the
    # result mathematically unchanged but stops exp() from overflowing. It also
    # prevents a slice far below a global max from underflowing to all zeros, which
    # would give 0/0 = NaN.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(shifted)
    return e_x / e_x.sum(axis=axis, keepdims=True)

In [14]:
%%time

# initialise dictionary for writing prediction results to
out_dict = {"case_id": [],
            "masked_token_array_id": [], "masked_token_vocab_id": [], "masked_token_text": [],
            "top_pred_token_vocab_id": [], "top_pred_token_text": [],
            "ce_loss": [],
            "pred_logits": []}

# set number of shards for splitting dataset into
n_shards=1000

for shard_id in range(2):
    
    trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=None,
    )
    
    print(shard_id)
    
    # run prediction on shards of overall test set so as not to exceed RAM
    test_shard = tokenized_datasets["validation"].shard(n_shards, shard_id, contiguous=True)
    pred_results = trainer.predict(test_shard)
    
    print(" pred done")
    
    # each row corresponds to a masked token
    # first level of iteration is case-by-case
    case_id_range = range(shard_id*int((tokenized_datasets["validation"].shape[0]/n_shards)), (shard_id+1)*int((tokenized_datasets["validation"].shape[0]/n_shards)))
    
    for case_id, result, label_ids in zip(case_id_range, pred_results.predictions, pred_results.label_ids):
        
        # second level of iteration is over masked tokens in a given case    
        # not every case necessarily has masked tokens (indicated by label_id not equal to -100)
        for masked_token in (label_ids != -100).nonzero()[0]:
            
            # write case_id, text and tokenized text corresponding to a given masked token
            out_dict["case_id"].append(case_id)

            # for each masked token, write out its array id within the text, its vocab id and corresponding text
            out_dict["masked_token_array_id"].append(masked_token)
            out_dict["masked_token_vocab_id"].append(label_ids[masked_token])
            out_dict["masked_token_text"].append(tokenizer.convert_ids_to_tokens([label_ids[masked_token]])[0])

            # also write the vocab id and text of the top predicted token
            out_dict["top_pred_token_vocab_id"].append(result[masked_token].argmax())
            out_dict["top_pred_token_text"].append(tokenizer.convert_ids_to_tokens([result[masked_token].argmax()])[0])

            # calculate categorical cross entropy loss as the negative log of the softmax probability of the correct token
            ce_loss = -np.log(softmax(result[masked_token])[label_ids[masked_token]])
            out_dict["ce_loss"].append(ce_loss)

            # save full logits (1xvocab_size) for the masked token for flexibility in further analysis
            out_dict["pred_logits"].append(result[masked_token])
    
# write dataframe from dict    
out_df = pd.DataFrame.from_dict(out_dict)
out_df

# write dataframe to csv
#out_df.to_csv("test.csv", index=False)

0


 pred done
1


 pred done
CPU times: user 2.81 s, sys: 98.4 ms, total: 2.91 s
Wall time: 2.81 s


Unnamed: 0,case_id,masked_token_array_id,masked_token_vocab_id,masked_token_text,top_pred_token_vocab_id,top_pred_token_text,ce_loss,pred_logits
0,0,7,2074,just,2292,let,1.010051,"[-8.716527, -8.712369, -8.571446, -8.323521, -..."
1,0,22,1012,.,1012,.,0.3874576,"[-9.631644, -9.660642, -9.769699, -9.461936, -..."
2,0,27,2643,god,1996,the,6.121951,"[-9.028535, -8.98919, -8.918305, -8.97509, -9...."
3,0,29,1045,i,2111,people,4.910404,"[-5.2678924, -5.354229, -5.103074, -5.473169, ..."
4,2,7,9429,indies,2111,people,16.61464,"[-3.9831514, -4.2335987, -3.860576, -4.094365,..."
5,2,12,1012,.,1012,.,0.4667634,"[-7.753856, -7.911137, -7.7510967, -7.710444, ..."
6,2,16,1056,t,1056,t,0.0009049694,"[-9.806551, -9.986289, -9.603838, -9.902298, -..."
7,2,17,2113,know,2113,know,0.01748155,"[-5.463654, -5.384053, -5.192466, -5.550172, -..."
8,2,21,2005,for,2005,for,0.0493914,"[-8.078574, -8.342933, -7.901456, -7.697903, -..."
9,2,23,5443,vs,2030,or,6.191107,"[-7.1063924, -7.422105, -6.9975233, -6.8219247..."


In [600]:
# Preview the first 20 rows (one row per masked token).
# NOTE(review): the saved output below came from an older version of the prediction
# loop (it has `text`/`tokenized_text` columns and execution count 600). Re-run the
# notebook top to bottom to refresh it.
out_df.iloc[:20]

Unnamed: 0,case_id,text,tokenized_text,masked_token_array_id,masked_token_vocab_id,masked_token_text,top_pred_token_vocab_id,top_pred_token_text,ce_loss,pred_logits
0,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",14,2000,to,2521,far,6.496281,"[-7.53187, -7.481711, -7.5258093, -7.5237474, ..."
1,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",15,2562,keep,2185,away,6.685265,"[-7.332101, -7.394904, -7.4043255, -7.4488153,..."
2,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",17,20687,renew,20410,verify,2.409274,"[-4.0348964, -4.034002, -4.3296347, -4.0737686..."
3,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",26,1998,and,1998,and,3.9e-05,"[-14.3275175, -14.694819, -14.702918, -14.3263..."
4,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",39,1996,the,1996,the,0.123853,"[-9.652537, -9.60094, -9.601458, -9.620717, -9..."
5,1,He was a character.,"[[CLS], he, was, a, character, ., [SEP]]",1,2002,he,2002,he,1.577117,"[-5.992837, -6.0397696, -5.9919176, -5.994485,..."
6,1,He was a character.,"[[CLS], he, was, a, character, ., [SEP]]",4,2839,character,2839,character,0.032695,"[-9.127079, -9.143774, -9.267379, -9.157009, -..."
7,1,He was a character.,"[[CLS], he, was, a, character, ., [SEP]]",5,1012,.,1012,.,0.203213,"[-7.7447324, -7.534629, -7.872818, -7.649254, ..."
8,2,"Ugh ffs, most indies don't vote. I don't know ...","[[CLS], u, ##gh, ff, ##s, ,, most, indies, don...",6,2087,most,2225,west,6.531352,"[-8.067986, -7.9866085, -7.9007835, -7.8676376..."
9,2,"Ugh ffs, most indies don't vote. I don't know ...","[[CLS], u, ##gh, ff, ##s, ,, most, indies, don...",14,2123,don,2123,don,0.06124,"[-6.1624947, -5.961982, -6.1122904, -6.1582136..."
