In [13]:
import pandas as pd
import numpy as np
from transformers import Trainer, AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling
    
from datasets import load_dataset

In [2]:
# Masked-LM model built from the local checkpoint. The load warning printed below about
# unused `cls.seq_relationship.*` weights is expected: those belong to a head that
# BertForMaskedLM does not have.
model = AutoModelForMaskedLM.from_pretrained(
    '../../0_models/default-model',
)

Some weights of the model checkpoint at ../../0_models/default-model were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Tokenizer from the same checkpoint directory as the model; use_fast=True requests
# the "fast" tokenizer implementation.
tokenizer = AutoTokenizer.from_pretrained(
    '../../0_models/default-model',
    use_fast=True,
)

In [39]:
# Load the raw Reddit text file with the generic "text" builder, exposed as a single
# split named "test".
datasets = load_dataset(
    path='text',
    data_files={'test': '../../0_data/clean/unlabelled_reddit/small.txt'},
)

Using custom data configuration default-f5e9164fabd20000


Downloading and preparing dataset text/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /Users/Paul/.cache/huggingface/datasets/text/default-f5e9164fabd20000/0.0.0/44d63bd03e7e554f16131765a251f2d8333a5fe8a73f6ea3de012dbc49443691...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /Users/Paul/.cache/huggingface/datasets/text/default-f5e9164fabd20000/0.0.0/44d63bd03e7e554f16131765a251f2d8333a5fe8a73f6ea3de012dbc49443691. Subsequent calls will reuse this data.


In [270]:
def tokenize_function(examples, max_length=64):
    """Drop blank lines from a batch of raw text and tokenize the remaining lines.

    Intended for ``Dataset.map(tokenize_function, batched=True)``.

    Parameters
    ----------
    examples : dict
        Batch with a ``"text"`` key holding a list of strings. The batch is NOT
        modified.
    max_length : int, default 64
        Truncation length passed to the tokenizer.

    Returns
    -------
    The output of the module-level ``tokenizer`` (token ids, attention data and the
    special-tokens mask). The filtered lines are added under ``"text"`` so that
    every returned column has the same number of rows.
    """
    # Keep a line only if it has at least one non-whitespace character
    # (equivalent to `len(line) > 0 and not line.isspace()`).
    texts = [line for line in examples["text"] if line.strip()]

    encoding = tokenizer(
        texts,
        padding=False,
        truncation=True,
        max_length=max_length,
        # We use this option because DataCollatorForLanguageModeling is more
        # efficient when it receives the `special_tokens_mask`.
        return_special_tokens_mask=True,
    )
    # Return the filtered text explicitly so the "text" column stays aligned with the
    # tokenized rows. Mutating the input batch instead only works on datasets
    # versions that merge the mutated batch back; others raise a row-count mismatch.
    encoding["text"] = texts
    return encoding

# Batched tokenization of every split. remove_columns is deliberately not passed:
# the results cell further down reads tokenized_datasets["test"]["text"] to show the
# original sentence next to each masked token.
tokenized_datasets = datasets.map(tokenize_function, batched=True)

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [451]:
# MLM collator: each batch gets roughly 15% of tokens selected for prediction.
# Positions that were not selected carry label -100, which the results cell relies on.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=0.15,
)

In [452]:
# The Trainer is only used here for batched inference via trainer.predict. No
# TrainingArguments are passed, so the library defaults apply.
trainer = Trainer(model=model, tokenizer=tokenizer, data_collator=data_collator)

In [453]:
def softmax(x, axis=0):
    """Numerically stable softmax of ``x`` along ``axis``.

    Parameters
    ----------
    x : array_like
        Logits. For a 1-D vector (e.g. one row of vocabulary logits) the result is
        a probability vector.
    axis : int, default 0
        Axis to normalise over. The default matches the previous behaviour.

    Returns
    -------
    numpy.ndarray
        Same shape as ``x``. Entries are non-negative and sum to 1 along ``axis``.
    """
    x = np.asarray(x)
    # Subtract the max of each slice rather than the global max. exp() still cannot
    # overflow, and every slice keeps an exp(0) = 1 term, so its denominator can never
    # underflow to 0 (which would give 0/0 = NaN).
    e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e_x / e_x.sum(axis=axis, keepdims=True)

In [454]:
# Run the masked LM over the test split. The result holds the per-token logits
# (.predictions) and the collator's masked labels (.label_ids).
pred_results = trainer.predict(test_dataset=tokenized_datasets["test"])

In [545]:
# Build one output row per masked token. Each row holds:
#   - the case context (text, tokens);
#   - the true masked token;
#   - the model's top prediction;
#   - the cross-entropy loss;
#   - the full logit vector.
# Cases with no masked tokens contribute no rows.
out_dict = {"case_id": [], "text": [], "tokenized_text": [],
            "masked_token_array_id": [], "masked_token_vocab_id": [], "masked_token_text": [],
            "top_pred_token_vocab_id": [], "top_pred_token_text": [],
            "ce_loss": [],
            "pred_logits": []}

# Read the dataset columns once. Each tokenized_datasets["test"][<column>] access
# materialises the whole column, so doing it per masked token made this cell quadratic.
test_texts = tokenized_datasets["test"]["text"]
test_input_ids = tokenized_datasets["test"]["input_ids"]

# first level of iteration is case-by-case
for case_id, (case_logits, case_label_ids) in enumerate(
        zip(pred_results.predictions, pred_results.label_ids)):

    # second level of iteration is over masked tokens in a given case;
    # a position was masked iff its label is not -100
    masked_positions = np.flatnonzero(case_label_ids != -100)
    if masked_positions.size == 0:
        continue

    # case-level fields are identical for every masked token of this case
    case_text = test_texts[case_id]
    case_tokens = tokenizer.convert_ids_to_tokens(test_input_ids[case_id])

    for masked_token in masked_positions:
        token_logits = case_logits[masked_token]  # logits over the vocabulary
        true_id = case_label_ids[masked_token]
        top_pred_id = token_logits.argmax()

        # case_id, text and tokenized text for this masked token
        out_dict["case_id"].append(case_id)
        out_dict["text"].append(case_text)
        out_dict["tokenized_text"].append(list(case_tokens))

        # position of the masked token within the text, its vocab id and its text
        out_dict["masked_token_array_id"].append(masked_token)
        out_dict["masked_token_vocab_id"].append(true_id)
        out_dict["masked_token_text"].append(tokenizer.convert_ids_to_tokens([true_id])[0])

        # vocab id and text of the top predicted token
        out_dict["top_pred_token_vocab_id"].append(top_pred_id)
        out_dict["top_pred_token_text"].append(tokenizer.convert_ids_to_tokens([top_pred_id])[0])

        # Cross-entropy loss: -log softmax(logits)[true] = logsumexp(logits) - logits[true].
        # Computing it in log space (and in float64) avoids -log(0) = inf when the true
        # token's probability underflows.
        logits64 = token_logits.astype(np.float64)
        max_logit = logits64.max()
        log_partition = max_logit + np.log(np.exp(logits64 - max_logit).sum())
        out_dict["ce_loss"].append(log_partition - logits64[true_id])

        # keep the full logit vector for this position for further analysis
        out_dict["pred_logits"].append(token_logits)

# write dataframe from dict
out_df = pd.DataFrame.from_dict(out_dict)

# pandas writes object cells with str(). numpy summarises any array larger than its
# print threshold (1000 elements by default) with "...", so vocab-sized logits would
# be silently truncated in the CSV. Write them as plain lists; out_df keeps the arrays.
out_df.assign(
    pred_logits=out_df["pred_logits"].map(lambda logits: logits.tolist())
).to_csv("test.csv", index=False)

Unnamed: 0,case_id,text,tokenized_text,masked_token_array_id,masked_token_vocab_id,masked_token_text,top_pred_token_vocab_id,top_pred_token_text,ce_loss,pred_logits
0,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",7,2074,just,2074,just,0.602947,"[-8.327028, -8.412684, -8.192552, -7.9687877, ..."
1,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",22,1012,.,1012,.,0.422119,"[-9.687029, -9.70536, -9.805444, -9.508277, -9..."
2,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",27,2643,god,1996,the,6.138289,"[-9.03698, -8.999211, -8.9290905, -8.98094, -9..."
3,0,Before that I was down to just go and vote for...,"[[CLS], before, that, i, was, down, to, just, ...",29,1045,i,2111,people,4.879502,"[-5.2836037, -5.3751225, -5.115223, -5.4827275..."
4,2,"Ugh ffs, most indies don't vote. I don't know ...","[[CLS], u, ##gh, ff, ##s, ,, most, indies, don...",2,5603,##gh,1010,",",5.948269,"[-5.6354556, -5.8172107, -5.7830467, -5.642902..."
5,2,"Ugh ffs, most indies don't vote. I don't know ...","[[CLS], u, ##gh, ff, ##s, ,, most, indies, don...",6,2087,most,2021,but,7.775764,"[-7.483583, -7.510486, -7.30363, -7.432882, -7..."
6,2,"Ugh ffs, most indies don't vote. I don't know ...","[[CLS], u, ##gh, ff, ##s, ,, most, indies, don...",7,9429,indies,1045,i,16.456945,"[-7.8007345, -7.8428097, -7.589243, -7.815379,..."
7,2,"Ugh ffs, most indies don't vote. I don't know ...","[[CLS], u, ##gh, ff, ##s, ,, most, indies, don...",11,3789,vote,2113,know,2.249898,"[-6.349863, -6.5139174, -6.408725, -6.509866, ..."
8,2,"Ugh ffs, most indies don't vote. I don't know ...","[[CLS], u, ##gh, ff, ##s, ,, most, indies, don...",13,1045,i,1045,i,0.000104,"[-12.30946, -12.103629, -11.994328, -12.001375..."
9,2,"Ugh ffs, most indies don't vote. I don't know ...","[[CLS], u, ##gh, ff, ##s, ,, most, indies, don...",19,2116,many,1045,i,3.111074,"[-7.12574, -7.2279606, -6.963147, -7.1625104, ..."
