## Imports

In [None]:
import numpy as np
import pandas as pd
import jsonlines
import gc
import torch
from torch import nn

from IPython.display import display, clear_output

from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel
from transformers.activations import GELUActivation
from transformers.modeling_outputs import MaskedLMOutput
from transformers import DataCollatorForWholeWordMask
from datasets import load_from_disk, load_dataset
from transformers import BertTokenizer, DistilBertTokenizer
import evaluate

import wandb
# wandb.init(project="kg-lm-integration", entity="tanny411")

from huggingface_hub import notebook_login
# notebook_login()

# Path to a TSV file, presumably Wikidata entity embeddings (inferred from the
# file name only). NOTE(review): no visible cell reads this variable; the KG
# embeddings actually used are loaded from "embeds_wktxt.csv".
emb_tsv_file = "wikidata_translation_v1.tsv"

In [None]:
# torch.cuda.is_available(), torch.cuda.device_count(), torch.cuda.current_device(), torch.cuda.get_device_name(0)

In [None]:
# Hugging Face checkpoint used for the tokenizer and for the encoder of BERTModified.
# NOTE(review): the tokenizer subclass derives from DistilBertTokenizer, so
# switching to "bert-base-cased" also requires changing that base class.
bert_model_name = "distilbert-base-uncased" ##"bert-base-cased"

# Initialize Project

Download the following into the working directory:
- embeds_wktxt.csv
- the [linked-wikitext-2 dataset](https://rloganiv.github.io/linked-wikitext-2/#/), then unzip it

# Tokenization

- `tokens` is the list of words given for each WikiText-2 example.
- `input_ids` is the output of tokenization. Word-piece tokenization splits some words into several sub-word tokens, and each example is wrapped with a leading CLS and a trailing SEP token.
- `word_tokens` has the same length as `tokens`. For each word in `tokens`, it records how many sub-word tokens that word was split into by word-piece tokenization.
- `cummulative_word_tokens` is the cumulative sum of `word_tokens` with an extra 0 prepended, so it is one element longer than `word_tokens`.

##### Process
The position of a word from `tokens` inside `input_ids` is found with `cummulative_word_tokens`. If `ix` is the index of a word in `tokens`, its first sub-token is at `input_ids[cummulative_word_tokens[ix] + 1]`; the +1 skips the CLS token at the start of `input_ids`. The word `tokens[ix]` therefore occupies `input_ids[cummulative_word_tokens[ix] + 1 : cummulative_word_tokens[ix + 1] + 1]`. That is, it starts at `cummulative_word_tokens[ix] + 1` and runs up to, but not including, `cummulative_word_tokens[ix + 1] + 1`.

In [None]:
# KG embedding table: get_kg_embedding_batched matches its 'id' column against
# annotation ids and reads the remaining columns as the embedding vector.
embeds_wktxt = pd.read_csv("embeds_wktxt.csv")

# Linked WikiText-2 JSON-lines splits, loaded as one dataset keyed by split name.
linked_wikitext_2 = "linked-wikitext-2/"
train, valid, test = (f"{linked_wikitext_2}{split}.jsonl" for split in ("train", "valid", "test"))

data_files = dict(train=train, valid=valid, test=test)
wikitest2_dataset = load_dataset("json", data_files=data_files)

class BertTokenizerModified(DistilBertTokenizer): #BertTokenizer
    """Word-piece tokenizer that records how many sub-tokens each input word became.

    Input text is split on whitespace; the linked WikiText-2 ``tokens`` arrive
    already word-split and joined with spaces. Each word is word-piece
    tokenized on its own. ``"@@START@@"`` and ``"@@END@@"`` are passed to the
    parent as ``never_split`` and are kept as single tokens.

    Every ``_tokenize`` call appends one list of per-word piece counts to
    ``self.tokenized_list``. Callers must reset that list before a batch
    (``my_tokenize_function`` does this).

    NOTE(review): the base class may skip ``_tokenize`` for empty text. That
    would leave ``tokenized_list`` shorter than the batch. TODO confirm against
    the installed transformers version.
    """

    # Markers that must never be broken into word pieces.
    _NEVER_SPLIT = ("@@START@@", "@@END@@")

    def __init__(self, vocab_file, **kwargs):
        super().__init__(vocab_file, never_split=list(self._NEVER_SPLIT), **kwargs)

        # One entry per `_tokenize` call: the sub-token count of each word.
        self.tokenized_list = []

    def _tokenize(self, text):
        """Word-piece tokenize ``text`` word by word and record the piece counts.

        Args:
            text: whitespace-separated words.

        Returns:
            The flat list of sub-word tokens. As a side effect, appends the
            per-word piece counts to ``self.tokenized_list``.

        Previously nothing was tokenized when ``do_basic_tokenize`` was False,
        because the whole loop sat under that flag. Now every word is always
        processed. The never-split set comes from the basic tokenizer when it
        exists, and from ``_NEVER_SPLIT`` otherwise.
        """
        basic_tokenizer = getattr(self, "basic_tokenizer", None) if self.do_basic_tokenize else None
        never_split = basic_tokenizer.never_split if basic_tokenizer is not None else set(self._NEVER_SPLIT)

        split_tokens = []
        pieces_per_word = []
        for word in text.split():
            if word in never_split:
                # Keep the marker whole; it counts as a single token.
                pieces = [word]
            else:
                pieces = self.wordpiece_tokenizer.tokenize(word)
            split_tokens.extend(pieces)
            pieces_per_word.append(len(pieces))

        self.tokenized_list.append(pieces_per_word)
        return split_tokens
    
def get_cumm(vals):
    """Return the running totals of ``vals`` with a leading 0.

    The result has one more element than ``vals``; entry ``i`` is the sum of
    the first ``i`` values. For example, ``[3, 1, 1]`` gives ``[0, 3, 4, 5]``.
    """
    totals = [0]
    for value in vals:
        totals.append(totals[-1] + value)
    return totals

def my_tokenize_function(data):
    """Tokenize a batch of pre-split examples with the global ``my_tokenizer``.

    ``data["tokens"]`` holds one word list per example. Each list is joined
    with spaces and tokenized. On top of the tokenizer's own outputs, the
    result gets:

    - ``word_ids``: only for fast tokenizers.
    - ``word_tokens``: for each example, how many sub-tokens each word became.
      For example, ``[3, 1, 1]`` means the first word became 3 tokens and the
      next two became 1 each.
    - ``cummulative_word_tokens``: ``get_cumm`` of each ``word_tokens`` entry,
      i.e. running sums with a leading 0.
    """
    texts = [" ".join(words) for words in data["tokens"]]

    # `_tokenize` appends one entry per example to this list; start fresh so
    # it lines up with this batch only.
    my_tokenizer.tokenized_list = []
    encoded = my_tokenizer(texts)

    if my_tokenizer.is_fast:
        encoded["word_ids"] = [encoded.word_ids(idx) for idx in range(len(encoded["input_ids"]))]

    counts = my_tokenizer.tokenized_list
    encoded["word_tokens"] = counts
    encoded["cummulative_word_tokens"] = list(map(get_cumm, counts))
    return encoded

def get_kg_embedding_batched(data, embeddings=None):
    """Build per-token KG embedding arrays for a batch of tokenized examples.

    When you specify batched=True the function receives a dictionary with the
    fields of the dataset, but each value is now a list of values, and not just
    a single value.

    For each example this creates a ``(len(input_ids), emb_dim)`` array of
    zeros. CLS, SEP and unannotated tokens therefore stay zero, like unknown
    entities. Each annotation's token span is then filled with the embedding
    row of its entity id, if that id is in the table.

    Args:
        data: batch with ``input_ids``, ``annotations`` and
            ``cummulative_word_tokens`` (see ``my_tokenize_function``).
            Each annotation is a dict with ``'span'`` (start and end word
            indices into ``tokens``; the end appears to be exclusive --
            TODO confirm against the dataset) and ``'id'``.
        embeddings: table whose ``'id'`` column holds entity ids and whose
            columns after the first hold the vector. This assumes ``'id'`` is
            the first column. Defaults to the module-level ``embeds_wktxt``.

    Returns:
        ``{"kg_embedding": [array, ...]}`` with one float array per example.
    """
    if embeddings is None:
        embeddings = embeds_wktxt

    input_ids_list = data["input_ids"]
    annotations_list = data["annotations"]
    cummulative_word_tokens_list = data["cummulative_word_tokens"]

    # Everything after the first column is the vector. This was hard-coded to 200.
    emb_dim = embeddings.shape[1] - 1

    # Resolve every id this batch needs with one pass over the table, instead
    # of one full boolean scan per annotation. keep="first" matches the
    # original `df.iloc[0]` choice when an id appears more than once.
    needed_ids = {annot["id"] for annots in annotations_list for annot in annots}
    rows = embeddings[embeddings["id"].isin(needed_ids)].drop_duplicates(subset="id", keep="first")
    lookup = dict(zip(rows["id"], rows.iloc[:, 1:].to_numpy(dtype=float)))

    embed_list = []  # len will be batch_size
    for input_ids, annotations, cumm in zip(input_ids_list, annotations_list, cummulative_word_tokens_list):
        embeds = np.zeros((len(input_ids), emb_dim))

        for annot in annotations:
            start_ix, end_ix = annot["span"]
            # +1 skips the CLS token at the start of input_ids.
            start = cumm[start_ix] + 1
            end = cumm[end_ix] + 1

            vector = lookup.get(annot["id"])
            if vector is not None:
                embeds[start:end] = vector  # broadcast the row over the whole span

        embed_list.append(embeds)

    return {"kg_embedding": embed_list}

def filter_text_batched(data, unk_token_id=100):
    """Remove unknown-token positions from every token-aligned column of a batch.

    Intended for ``Dataset.map(..., batched=True)``. In each example, every
    position whose input id equals ``unk_token_id`` is dropped from
    ``input_ids`` and ``kg_embedding``. If present, the same positions are
    also dropped from ``attention_mask``, ``token_type_ids``,
    ``special_tokens_mask`` and ``word_ids``.

    Previously only ``input_ids``/``kg_embedding`` were filtered, so
    ``attention_mask`` kept its old length and the columns no longer lined up.

    NOTE(review): ``word_tokens`` / ``cummulative_word_tokens`` are word-level
    offsets into the unfiltered ``input_ids`` and are stale after this step.

    Args:
        data: batch dict of per-example lists.
        unk_token_id: id to drop. The default 100 is ``[UNK]`` in the
            BERT/DistilBERT vocabularies; confirm with ``my_tokenizer.unk_token_id``.

    Returns:
        Dict with the filtered columns only; ``Dataset.map`` leaves all other
        columns untouched.
    """
    optional_aligned = ("attention_mask", "token_type_ids", "special_tokens_mask", "word_ids")
    aligned_columns = ["input_ids", "kg_embedding"] + [col for col in optional_aligned if col in data]

    # Positions to keep, computed once per example and shared by every column.
    keep_lists = [
        [pos for pos, token_id in enumerate(input_ids) if token_id != unk_token_id]
        for input_ids in data["input_ids"]
    ]

    return {
        col: [[values[pos] for pos in keep] for keep, values in zip(keep_lists, data[col])]
        for col in aligned_columns
    }

In [None]:
# Directory where the tokenized, KG-annotated, filtered dataset is written with
# `save_to_disk` and later read back with `load_from_disk`.
filtered_tokenized_kg_wikitest2_dataset_file = "filtered_tokenized_kg_wikitest2_dataset_file"

In [None]:
## Run this cell to perform tokenization.
## Steps: tokenize -> attach per-token KG embeddings -> drop [UNK] positions,
## then save the result to `filtered_tokenized_kg_wikitest2_dataset_file`.

my_tokenizer = BertTokenizerModified.from_pretrained(bert_model_name)

# Previously the result was bound to `tokenized_kg_wikitest2_dataset`, while
# `filtered_tokenized_kg_wikitest2_dataset` was saved, which raised NameError.
filtered_tokenized_kg_wikitest2_dataset = (
    wikitest2_dataset
    .map(my_tokenize_function, batched=True)
    .map(get_kg_embedding_batched, batched=True, batch_size=100, keep_in_memory=False)
    .map(filter_text_batched, batched=True, batch_size=100, keep_in_memory=False)
)
# Keep the old name as an alias for any code that still refers to it.
tokenized_kg_wikitest2_dataset = filtered_tokenized_kg_wikitest2_dataset

filtered_tokenized_kg_wikitest2_dataset.save_to_disk(filtered_tokenized_kg_wikitest2_dataset_file)

In [None]:
## Load the saved tokenized dataset
## (written with `save_to_disk` to `filtered_tokenized_kg_wikitest2_dataset_file`).
dataset = load_from_disk(filtered_tokenized_kg_wikitest2_dataset_file)

In [None]:
# Display the loaded dataset (last expression of the cell is rendered by the notebook).
dataset

## TODO

In [None]:
from itertools import chain

chunk_size = 128

def group_texts(examples):
    """Concatenate a batch's columns end to end and cut them into fixed-size chunks.

    Intended for ``Dataset.map(..., batched=True)``. Steps:

    - Every column is flattened into one long list.
    - The length of ``input_ids`` is rounded down to a multiple of the
      module-level ``chunk_size``; the trailing partial chunk is dropped.
    - Every column is sliced into consecutive ``chunk_size`` pieces.
    - A ``labels`` column copied from ``input_ids`` is added.

    Changes from the original:

    - Flattening uses ``chain`` (linear) instead of ``sum(lists, [])``
      (quadratic).
    - The chunked length is taken from ``input_ids`` rather than from
      whichever column happens to come first, which could be a word-level
      column of a different length.

    NOTE(review): all columns should be token-aligned lists. Word-level
    columns (``tokens``, ``word_tokens``, ...) and non-list columns should be
    removed before mapping this function.
    """
    # Concatenate all texts
    concatenated = {key: list(chain.from_iterable(values)) for key, values in examples.items()}

    # We drop the last chunk if it's smaller than chunk_size
    total_length = (len(concatenated["input_ids"]) // chunk_size) * chunk_size

    # Split by chunks of chunk_size
    result = {
        key: [column[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for key, column in concatenated.items()
    }
    # Create a new labels column
    result["labels"] = result["input_ids"].copy()
    return result

# Create Model

In [None]:
class BERTModified(nn.Module):
    """Frozen pre-trained encoder topped with a trainable masked-LM head.

    The encoder is ``AutoModel.from_pretrained(bert_model_name)``. All of its
    parameters are frozen, and it stays in eval mode even when the wrapper is
    put in train mode (see ``train``). The head follows DistilBERT's MLM head
    layout: dense -> GELU -> LayerNorm -> vocab projection. It is freshly
    initialised here.

    NOTE(review): the head weights are not loaded from a pretrained MLM
    checkpoint, so training starts from a random head. ``kg_embedding`` is
    accepted by ``forward`` but not used yet.
    """

    def __init__(self, bert_model_name):
        super().__init__()

        self.base_model = AutoModel.from_pretrained(bert_model_name)
        self.config = self.base_model.config
        # `hidden_size` exists on BERT configs and, through its attribute map,
        # on DistilBERT configs (where it aliases `dim`). This lets either
        # checkpoint work; the original used the DistilBERT-only `dim`.
        hidden_size = self.config.hidden_size

        self.activation = GELUActivation() # for distilbert
        self.vocab_transform = nn.Linear(hidden_size, hidden_size)
        self.vocab_layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.vocab_projector = nn.Linear(hidden_size, self.config.vocab_size)

        # CrossEntropyLoss ignores targets equal to -100 by default.
        self.mlm_loss_fct = nn.CrossEntropyLoss()

        ## freeze model: only the head above has trainable parameters
        for param in self.base_model.parameters():
            param.requires_grad = False

        ## set to eval (kept that way by `train` below)
        self.base_model.eval()

    def train(self, mode=True):
        """Set train/eval mode on the wrapper but keep the frozen encoder in eval mode.

        ``nn.Module.train`` recurses into every submodule. Without this
        override, ``Trainer``'s ``model.train()`` call would re-enable dropout
        inside the frozen encoder, undoing the ``eval()`` call in
        ``__init__``. ``eval()`` routes through this method too.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        self.base_model.eval()
        return self

    def forward(
        self,
        kg_embedding=None,
        input_ids=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the frozen encoder and the MLM head; compute the loss when labels are given.

        Args:
            kg_embedding: per-token KG vectors; not used yet (see TODO below).
            input_ids, attention_mask, head_mask, inputs_embeds,
            output_attentions, output_hidden_states: forwarded to the encoder.
            labels: target token ids; flattened and compared against the
                flattened logits with cross-entropy.
            return_dict: accepted for API compatibility. The encoder is always
                called with ``return_dict=True`` because its output is read by
                attribute, and a ``MaskedLMOutput`` is always returned.

        Returns:
            ``MaskedLMOutput`` with ``loss`` (None without labels), ``logits``
            of shape (bs, seq_length, vocab_size), and the encoder's
            ``hidden_states`` / ``attentions``.
        """
        base_model_output = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        ## Get LM embedding
        hidden_states = base_model_output.last_hidden_state  # (bs, seq_length, dim)

        ## TODO: Use hidden_states and kg_embedding and perform INTEGRATION
        prediction_logits = self.vocab_transform(hidden_states)  # (bs, seq_length, dim)
        prediction_logits = self.activation(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_layer_norm(prediction_logits)  # (bs, seq_length, dim)
        prediction_logits = self.vocab_projector(prediction_logits)  # (bs, seq_length, vocab_size)

        mlm_loss = None
        if labels is not None:
            mlm_loss = self.mlm_loss_fct(prediction_logits.view(-1, prediction_logits.size(-1)), labels.view(-1))

        return MaskedLMOutput(
            loss=mlm_loss,
            logits=prediction_logits,
            hidden_states=base_model_output.hidden_states,
            attentions=base_model_output.attentions,
        )

In [None]:
BERTModified_model = BERTModified(bert_model_name)

# Whole-word masking over 15% of words. This must use the tokenizer the dataset
# was built with (`my_tokenizer`); no variable named `tokenizer` is defined in
# the notebook, so the previous `tokenizer=tokenizer` raised NameError.
# NOTE(review): this collator may build only input_ids/labels from each
# example, dropping attention_mask and kg_embedding. Confirm before wiring in
# the KG integration.
data_collator = DataCollatorForWholeWordMask(tokenizer=my_tokenizer, mlm=True, mlm_probability=0.15)

In [None]:
## To test model with smaller sample dataset
## The training cells below read `downsampled_dataset["train"]` and
## `downsampled_dataset["test"]`, so some version of this cell must run first.
## NOTE(review): `dataset` holds several splits (train/valid/test), so
## `train_test_split` has to be called on one split, e.g. `dataset["train"]`.
## The train split also needs at least train_size + test_size rows; it may
## have fewer unless the examples are chunked first (presumably with
## `group_texts`) -- confirm.

# train_size = 1000
# test_size = 100

# downsampled_dataset = dataset.train_test_split(train_size=train_size, test_size=test_size, seed=42)
# downsampled_dataset

In [None]:
# NOTE(review): despite the name, this holds the loaded `evaluate` metric
# object, not a string. No visible cell uses it: the Trainer is built without
# `compute_metrics`, and perplexity is computed from `eval_loss` with `math.exp`.
metric_name = evaluate.load("perplexity") # accuracy

In [None]:
from transformers import TrainingArguments

batch_size = 16

# Log the training loss about once per epoch (one log per pass over the
# training split). max(1, ...) keeps the value valid when the split has fewer
# than `batch_size` rows; a step interval of 0 is presumably rejected by
# TrainingArguments for step-based logging.
# NOTE(review): `downsampled_dataset` is defined only in the commented-out
# sampling cell; make sure it exists before running this cell.
logging_steps = max(1, len(downsampled_dataset['train']) // batch_size)
model_name = "BERTModified"
output_dir = f"{model_name}-finetuned-wikitext-test"

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Requires a Hugging Face login (`notebook_login()` is commented out in the imports cell).
    push_to_hub=True,
#     fp16=True,
    logging_steps=logging_steps,
    num_train_epochs=1,
#     load_best_model_at_end=True,
#     metric_for_best_model="loss",#metric_name,
#     greater_is_better = False,
    logging_dir='logs',
    report_to="wandb",
#     no_cuda=True,
)

metric_for_best_model (str, optional) — Use in conjunction with load_best_model_at_end to specify the metric to use to compare two different models. Must be the name of a metric returned by the evaluation with or without the prefix "eval_". Will default to "loss" if unspecified and load_best_model_at_end=True (to use the evaluation loss).

If you set this value, greater_is_better will default to True. Don’t forget to set it to False if your metric is better when lower.

In [None]:
from transformers import Trainer

# Trainer over the downsampled train/test splits. Only the MLM head of
# BERTModified_model has trainable parameters (the encoder is frozen).
# NOTE(review): no `compute_metrics` is passed, so evaluation presumably
# reports only eval_loss plus runtime statistics.
trainer = Trainer(
    model=BERTModified_model,
    args=training_args,
    train_dataset=downsampled_dataset["train"],
    eval_dataset=downsampled_dataset["test"],
    data_collator=data_collator,
)

In [None]:
import math

# Baseline evaluation before training: perplexity = exp(mean eval cross-entropy loss).
# NOTE(review): math.exp raises OverflowError for a loss above ~709.
eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

In [None]:
# Fine-tune; only the MLM head has parameters with requires_grad=True.
trainer.train()
# trainer.save_model("output/models/BERTModified")

In [None]:
import math

# Evaluation after training: perplexity = exp(mean eval cross-entropy loss).
eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}") #21596 for 1 epoch

In [None]:
## Need to login to huggingface to push to hub
## (the `notebook_login()` call in the imports cell is commented out).

trainer.push_to_hub()

In [None]:
from transformers import pipeline

# Initialize MLM pipeline
# NOTE(review): `pipeline` is normally given a transformers PreTrainedModel.
# BERTModified_model is a plain nn.Module, so this call may not be accepted
# as-is -- confirm with the installed transformers version.
mlm = pipeline('fill-mask', model=BERTModified_model, tokenizer=my_tokenizer)

# Get mask token
mask = mlm.tokenizer.mask_token

# Get result for particular masked phrase
phrase = f'Wikipedia is a free online {mask}, created and edited by volunteers around the world'

result = mlm(phrase)

# Print result
print(result)

In [None]:
# Print the 'sequence' field of each pipeline prediction (presumably the input
# sentence with the mask filled in).
for x in result:
    print(f">>> {x['sequence']}")