In [13]:
import pandas as pd
import numpy as np
from transformers import Trainer, AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling
    
from datasets import load_dataset

In [577]:
# Load the masked-language-model from the local checkpoint directory.
# The captured output below shows it is instantiated as BertForMaskedLM and that the
# checkpoint's `cls.seq_relationship.*` weights are not used by this model class.
model = AutoModelForMaskedLM.from_pretrained('../../0_models/default-model')

Some weights of the model checkpoint at ../../0_models/default-model were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [578]:
# Load the tokenizer from the same checkpoint directory as the model so that the
# vocabulary ids match; use_fast=True asks for the fast tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained('../../0_models/default-model', use_fast=True)

In [579]:
# Load the unlabelled test file with the 'text' loader and register it under the
# "validation" split name, which is the key the later cells index into.
# NOTE(review): the path is relative to the notebook's working directory — confirm it
# resolves when the notebook is run from elsewhere.
datasets = load_dataset('text', data_files={'validation': '../../0_data/clean/unlabelled_reddit/politics_test/test_2017_03_5k.txt'})

Using custom data configuration default-fdca24a232e41a05


Downloading and preparing dataset text/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /Users/Paul/.cache/huggingface/datasets/text/default-fdca24a232e41a05/0.0.0/44d63bd03e7e554f16131765a251f2d8333a5fe8a73f6ea3de012dbc49443691...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /Users/Paul/.cache/huggingface/datasets/text/default-fdca24a232e41a05/0.0.0/44d63bd03e7e554f16131765a251f2d8333a5fe8a73f6ea3de012dbc49443691. Subsequent calls will reuse this data.


In [580]:
def tokenize_function(examples):
    """Tokenize one batch of raw text for masked-language-model evaluation.

    Blank and whitespace-only entries are dropped from ``examples["text"]``
    (the batch is updated in place), and the remaining lines are passed to the
    module-level ``tokenizer``.

    Args:
        examples: batch mapping with a "text" entry holding a list of strings.

    Returns:
        Whatever ``tokenizer`` returns for the kept lines: unpadded encodings,
        truncated to at most 64 tokens, including the special-tokens mask.
    """
    kept_lines = []
    for line in examples["text"]:
        # skip empty lines and lines made up only of whitespace
        if len(line) == 0 or line.isspace():
            continue
        kept_lines.append(line)
    examples["text"] = kept_lines

    encoding_options = dict(
        padding=False,
        truncation=True,
        max_length=64,
        # DataCollatorForLanguageModeling is more efficient when it is handed the
        # `special_tokens_mask` up front.
        return_special_tokens_mask=True,
    )
    return tokenizer(examples["text"], **encoding_options)

# Apply tokenize_function to every split in batches (batched=True hands the function
# a batch of rows rather than a single row).
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    # The raw "text" column is deliberately NOT removed: the prediction cell further
    # down reads tokenized_datasets["validation"]["text"] when writing results.
    #remove_columns=["text"]
)

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))




In [581]:
# Collator that builds masked-language-modelling batches with this tokenizer;
# mlm_probability=0.15 is the masking probability handed to it.
# NOTE(review): no random seed is set in the visible cells, so the masked positions
# (and therefore the results below) presumably differ between runs — confirm.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

In [582]:
# Wrap model, tokenizer and collator in a Trainer. No training arguments or training
# data are passed; in this notebook the object is only used for trainer.predict(...).
trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        data_collator=data_collator
    )

In [583]:
def softmax(x, axis=0):
    """Numerically stable softmax along one axis.

    The maximum is taken along the same axis that is normalised. Subtracting a
    single global maximum (as was done previously) makes every entry of a
    column that lies far below that maximum underflow to zero, which turns the
    normalisation into 0/0 = NaN for multi-dimensional input.

    Args:
        x: array-like of logits. For a 1-D vector the result is identical to
            the previous implementation.
        axis: axis along which the probabilities sum to one. Defaults to 0,
            matching the previous behaviour.

    Returns:
        numpy.ndarray with the same shape as ``x`` whose entries along ``axis``
        are non-negative and sum to 1.
    """
    x = np.asarray(x)
    # shift by the per-slice maximum so the largest exponent in every slice is exp(0) = 1
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(shifted)
    # keepdims keeps the denominator broadcastable against e_x for any axis
    return e_x / e_x.sum(axis=axis, keepdims=True)

In [599]:
# initialise dictionary for writing prediction results to (one entry per masked token)
out_dict = {"case_id": [], "text": [], "tokenized_text": [],
            "masked_token_array_id": [], "masked_token_vocab_id": [], "masked_token_text": [],
            "top_pred_token_vocab_id": [], "top_pred_token_text": [],
            "ce_loss": [],
            "pred_logits": []}

# set number of shards for splitting dataset into
n_shards=100

# look the split up once; indexing tokenized_datasets["validation"]["text"] inside the
# loops would materialise the whole column again for every single masked token
validation_set = tokenized_datasets["validation"]

# contiguous shards are not all the same size when the dataset length is not a multiple
# of n_shards: the first `n_larger_shards` shards then hold one extra row each
base_shard_size, n_larger_shards = divmod(len(validation_set), n_shards)

# only the first two shards are processed here
for shard_id in range(2):

    print(shard_id)

    # run prediction on shards of overall test set so as not to exceed RAM
    test_shard = validation_set.shard(n_shards, shard_id, contiguous=True)
    pred_results = trainer.predict(test_shard)

    # each row corresponds to a masked token
    # first level of iteration is case-by-case
    # map shard rows back to row ids of the full split, using the shard's real start
    # offset and real length rather than assuming every shard has the same size
    shard_start = shard_id * base_shard_size + min(shard_id, n_larger_shards)
    case_id_range = range(shard_start, shard_start + len(test_shard))

    for case_id, result, label_ids in zip(case_id_range, pred_results.predictions, pred_results.label_ids):

        # not every case necessarily has masked tokens (indicated by label_id not equal to -100)
        masked_positions = (label_ids != -100).nonzero()[0]
        if len(masked_positions) == 0:
            continue

        # fetch the row and convert its ids to tokens once per case, not once per masked token
        case_row = validation_set[case_id]
        case_text = case_row["text"]
        case_tokens = tokenizer.convert_ids_to_tokens(case_row["input_ids"])

        # second level of iteration is over masked tokens in a given case
        for masked_token in masked_positions:

            token_logits = result[masked_token]
            true_vocab_id = label_ids[masked_token]
            top_vocab_id = token_logits.argmax()

            # write case_id, text and tokenized text corresponding to a given masked token
            out_dict["case_id"].append(case_id)
            out_dict["text"].append(case_text)
            # copy so that every output row owns its own token list
            out_dict["tokenized_text"].append(list(case_tokens))

            # for each masked token, write out its array id within the text, its vocab id and corresponding text
            out_dict["masked_token_array_id"].append(masked_token)
            out_dict["masked_token_vocab_id"].append(true_vocab_id)
            out_dict["masked_token_text"].append(tokenizer.convert_ids_to_tokens([true_vocab_id])[0])

            # also write the vocab id and text of the top predicted token
            out_dict["top_pred_token_vocab_id"].append(top_vocab_id)
            out_dict["top_pred_token_text"].append(tokenizer.convert_ids_to_tokens([top_vocab_id])[0])

            # calculate categorical cross entropy loss as the negative log of the softmax probability of the correct token
            ce_loss = -np.log(softmax(token_logits)[true_vocab_id])
            out_dict["ce_loss"].append(ce_loss)

            # save full logits (1xvocab_size) for the masked token for flexibility in further analysis
            out_dict["pred_logits"].append(token_logits)


# write dataframe from dict
out_df = pd.DataFrame.from_dict(out_dict)
out_df

# write dataframe to csv
#out_df.to_csv("test.csv", index=False)

Unnamed: 0,case_id,text,tokenized_text,masked_token_array_id,masked_token_vocab_id,masked_token_text,top_pred_token_vocab_id,top_pred_token_text,ce_loss,pred_logits
0,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",14,2000,to,2521,far,6.496281,"[-7.53187, -7.481711, -7.5258093, -7.5237474, ..."
1,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",15,2562,keep,2185,away,6.685265,"[-7.332101, -7.394904, -7.4043255, -7.4488153,..."
2,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",17,20687,renew,20410,verify,2.409274,"[-4.0348964, -4.034002, -4.3296347, -4.0737686..."
3,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",26,1998,and,1998,and,0.000039,"[-14.3275175, -14.694819, -14.702918, -14.3263..."
4,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",39,1996,the,1996,the,0.123853,"[-9.652537, -9.60094, -9.601458, -9.620717, -9..."
...,...,...,...,...,...,...,...,...,...,...
412,99,Neither do I. It's almost like they look at al...,"[[CLS], neither, do, i, ., it, ', s, almost, l...",3,1045,i,2205,too,9.134379,"[-5.9324603, -5.7664423, -5.801572, -5.7258487..."
413,99,Neither do I. It's almost like they look at al...,"[[CLS], neither, do, i, ., it, ', s, almost, l...",13,2035,all,2035,all,0.011680,"[-6.281892, -6.211084, -6.168412, -6.375232, -..."
414,99,Neither do I. It's almost like they look at al...,"[[CLS], neither, do, i, ., it, ', s, almost, l...",25,19952,guthrie,19952,guthrie,0.018343,"[-5.8770943, -6.5148373, -6.2239447, -5.827273..."
415,99,Neither do I. It's almost like they look at al...,"[[CLS], neither, do, i, ., it, ', s, almost, l...",27,1998,and,1998,and,0.112440,"[-6.979398, -7.0463724, -6.777259, -6.7518754,..."


In [600]:
# Preview the first 20 rows of the per-masked-token results table.
out_df.head(20)

Unnamed: 0,case_id,text,tokenized_text,masked_token_array_id,masked_token_vocab_id,masked_token_text,top_pred_token_vocab_id,top_pred_token_text,ce_loss,pred_logits
0,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",14,2000,to,2521,far,6.496281,"[-7.53187, -7.481711, -7.5258093, -7.5237474, ..."
1,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",15,2562,keep,2185,away,6.685265,"[-7.332101, -7.394904, -7.4043255, -7.4488153,..."
2,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",17,20687,renew,20410,verify,2.409274,"[-4.0348964, -4.034002, -4.3296347, -4.0737686..."
3,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",26,1998,and,1998,and,3.9e-05,"[-14.3275175, -14.694819, -14.702918, -14.3263..."
4,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",39,1996,the,1996,the,0.123853,"[-9.652537, -9.60094, -9.601458, -9.620717, -9..."
5,1,He was a character.,"[[CLS], he, was, a, character, ., [SEP]]",1,2002,he,2002,he,1.577117,"[-5.992837, -6.0397696, -5.9919176, -5.994485,..."
6,1,He was a character.,"[[CLS], he, was, a, character, ., [SEP]]",4,2839,character,2839,character,0.032695,"[-9.127079, -9.143774, -9.267379, -9.157009, -..."
7,1,He was a character.,"[[CLS], he, was, a, character, ., [SEP]]",5,1012,.,1012,.,0.203213,"[-7.7447324, -7.534629, -7.872818, -7.649254, ..."
8,2,"Ugh ffs, most indies don't vote. I don't know ...","[[CLS], u, ##gh, ff, ##s, ,, most, indies, don...",6,2087,most,2225,west,6.531352,"[-8.067986, -7.9866085, -7.9007835, -7.8676376..."
9,2,"Ugh ffs, most indies don't vote. I don't know ...","[[CLS], u, ##gh, ff, ##s, ,, most, indies, don...",14,2123,don,2123,don,0.06124,"[-6.1624947, -5.961982, -6.1122904, -6.1582136..."
