## Imports

In [None]:
import numpy as np
import pandas as pd
import jsonlines
import gc
import os
import torch
from torch import nn
from models.models import *


from IPython.display import display, clear_output

from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel
from transformers.activations import GELUActivation
from transformers.modeling_outputs import MaskedLMOutput
from transformers import DataCollatorForWholeWordMask
from datasets import load_from_disk, load_dataset
from transformers import BertTokenizer, DistilBertTokenizer
from transformers.data.data_collator import _torch_collate_batch
import evaluate

import wandb
# wandb.init(project="kg-lm-integration", entity="tanny411")

from huggingface_hub import notebook_login
# Authenticate interactively when pushing to the Hub; never paste an access token into
# the notebook. A token that used to be committed on this line has been removed and
# should be revoked.
# notebook_login()
emb_tsv_file = "wikidata_translation_v1.tsv"

In [None]:
import sys
# Show the interpreter's recursion limit before it is changed.
print(sys.getrecursionlimit())

# Raise the recursion limit for the rest of the session.
sys.setrecursionlimit(100000)
# but doing so is dangerous -- the standard limit is a little conservative, but Python stackframes can be quite big.

In [None]:
# torch.cuda.is_available(), torch.cuda.device_count(), torch.cuda.current_device(), torch.cuda.get_device_name(0)

In [None]:
# Hugging Face checkpoint name; it is used to load both the tokenizer and the base encoder.
bert_model_name = "distilbert-base-uncased" ## alternative: "bert-base-cased"

# Initialize Project

Download:
- embeds_wktxt.csv
- [linked-wikitext-2 dataset](https://rloganiv.github.io/linked-wikitext-2/#/) and unzip

# Tokenization

- `tokens` are the given list of tokens from wikitext2
- `input_ids` are what come from tokenization, they divide certain words into multiple pieces, and each sentence has a CLS and a SEP
- `word_tokens` has one entry per element of `tokens`: for each token in `tokens`, it records how many sub-words that token was split into by WordPiece tokenization
- `cummulative_word_tokens` is the cumulative sum of `word_tokens`, with an extra 0 at the beginning

##### Process
The position of a token from `tokens` within `input_ids` is given by `cummulative_word_tokens`. If `ix` is the index of a word in `tokens`, its first sub-word is at `input_ids[cummulative_word_tokens[ix] + 1]`; the +1 accounts for the CLS token at the beginning of `input_ids`. `tokens[ix]` therefore covers the slice `input_ids[cummulative_word_tokens[ix] + 1 : cummulative_word_tokens[ix+1] + 1]` (end index exclusive).

In [None]:
# Lookup tables read from CSV; get_kg_embedding_batched filters both on their "id" column.
embeds_wktxt = pd.read_csv("embeds_wktxt.csv")
qids_wktxt = pd.read_csv("qids_wktxt2.csv")

# Paths of the three JSON-lines splits inside the unzipped Linked WikiText-2 folder.
linked_wikitext_2 = "linked-wikitext-2/"
train = linked_wikitext_2+"train.jsonl"
valid = linked_wikitext_2+"valid.jsonl"
test = linked_wikitext_2+"test.jsonl"

# One dataset split per file, keyed "train" / "valid" / "test".
data_files = {"train": train, "valid": valid, "test": test}
wikitest2_dataset = load_dataset("json", data_files=data_files)

# Number of tokens per example produced by group_texts.
chunk_size = 128

class BertTokenizerModified(DistilBertTokenizer): #BertTokenizer
    """WordPiece tokenizer that also records how many sub-words each word produced.

    The text handed to ``_tokenize`` is split on whitespace only. For every
    resulting word the number of emitted sub-word tokens is stored, so that
    word-level annotations can later be mapped onto ``input_ids`` positions.

    Attributes:
        tokenized_list: one list of per-word sub-word counts for every
            ``_tokenize`` call since the attribute was last reset. Callers are
            expected to reset it to ``[]`` before tokenizing a new batch.
    """

    # Entity-boundary markers that must always stay a single token.
    _NEVER_SPLIT = ("@@START@@", "@@END@@", "@@start@@", "@@end@@")

    def __init__(self,vocab_file,**kwargs):
        """Build the tokenizer and register the never-split markers.

        Args:
            vocab_file: vocabulary file forwarded to the parent tokenizer.
            **kwargs: forwarded unchanged to the parent tokenizer.
        """
        super().__init__(vocab_file, never_split=list(self._NEVER_SPLIT), **kwargs)

        self.tokenized_list = []

    def _tokenize(self, text):
        """Split ``text`` on whitespace, WordPiece each word and log the piece counts.

        Args:
            text: the string to tokenize.

        Returns:
            The flat list of sub-word tokens for ``text``. As a side effect one
            list of per-word counts is appended to ``self.tokenized_list``.
        """
        # Previously nothing was tokenized when ``do_basic_tokenize`` was False:
        # the text was silently dropped and an empty count list was recorded.
        # Words are now word-pieced in both configurations. The basic tokenizer is
        # only consulted for its never-split set, which is why it is not touched
        # when basic tokenization is disabled.
        if self.do_basic_tokenize:
            never_split = self.basic_tokenizer.never_split
        else:
            never_split = set(self._NEVER_SPLIT)

        split_tokens = []
        pieces_per_word = []

        for word in text.split():
            if word in never_split:
                # Markers are kept verbatim and count as exactly one token.
                split_tokens.append(word)
                pieces_per_word.append(1)
            else:
                word_pieces = self.wordpiece_tokenizer.tokenize(word)
                split_tokens += word_pieces
                pieces_per_word.append(len(word_pieces))

        self.tokenized_list.append(pieces_per_word)
        return split_tokens
    
def get_cumm(vals):
    """Return the running totals of ``vals``, preceded by a leading 0.

    The result has one more element than ``vals``: element ``i`` is the sum of
    the first ``i`` values, so ``get_cumm([3, 1, 1])`` is ``[0, 3, 4, 5]``.
    """
    running_totals = [0]
    for step in vals:
        # Each total is the previous total plus the current value.
        running_totals.append(running_totals[-1] + step)
    return running_totals

def my_tokenize_function(data):
    """Tokenize a batch of pre-split sentences and attach word-to-token bookkeeping.

    Args:
        data: batch with a "tokens" column, each entry being a list of words.

    Returns:
        The tokenizer output extended with:
        * "word_tokens": per example, how many sub-words each word became
          (e.g. 3, 1, 1 means the first word became 3 tokens, the next two 1 each);
        * "cummulative_word_tokens": ``get_cumm`` applied to each "word_tokens" entry;
        * "word_ids": only when the tokenizer reports ``is_fast``.
    """
    # Start from an empty log so it only holds entries for this batch.
    my_tokenizer.tokenized_list = []

    sentences = [" ".join(words) for words in data["tokens"]]
    result = my_tokenizer(sentences)

    if my_tokenizer.is_fast:
        n_examples = len(result["input_ids"])
        result["word_ids"] = [result.word_ids(row) for row in range(n_examples)]

    # Word -> number-of-sub-words mapping collected by the tokenizer while it ran.
    pieces_per_word = my_tokenizer.tokenized_list
    result["word_tokens"] = pieces_per_word
    result["cummulative_word_tokens"] = list(map(get_cumm, pieces_per_word))

    return result

def get_kg_embedding_batched(data):
    """Attach per-token knowledge-graph information to a tokenized batch.

    When ``batched=True`` is used with ``Dataset.map`` the function receives a
    dictionary with the fields of the dataset, where each value is a list with
    one entry per example.

    For every annotation the word span is converted to a token span through
    ``cummulative_word_tokens`` (+1 for the leading CLS token). If the
    annotation's id is found in ``embeds_wktxt``, the tokens of that span get the
    embedding of the first matching row, mask value 1, the id itself, and the
    index label of the first matching row of ``qids_wktxt`` (-100 if absent).

    Args:
        data: batch with the columns "input_ids", "annotations" (dicts holding
            "span" and "id") and "cummulative_word_tokens".

    Returns:
        dict with one entry per example in each of:
        * "kg_embedding": float array of shape (n_tokens, embedding_dim), zeros
          where no embedding applies (including CLS/SEP);
        * "kg_embedding_mask": 1 where an embedding was written, else 0;
        * "kg_embedding_mask_qid": the annotation id per token, '0' elsewhere;
        * "kg_embedding_mask_index": the ``qids_wktxt`` index per token, -100 elsewhere.

    Side effects:
        Prints the fraction, the count, and the total of annotations whose id
        occurs in ``qids_wktxt``.
    """
    input_ids_list = data["input_ids"]
    annotations_list = data['annotations']
    cummulative_word_tokens_list = data["cummulative_word_tokens"]

    # Every column after the first one holds one embedding component, so the width
    # is taken from the table instead of being hard-coded.
    embedding_dim = embeds_wktxt.shape[1] - 1

    # Build id -> first-occurrence lookups once per batch instead of scanning both
    # DataFrames for every single annotation. ``setdefault`` keeps the first match,
    # which is the row the boolean-mask lookup would have returned.
    qid_to_index = {}
    for row_label, known_qid in zip(qids_wktxt.index.tolist(), qids_wktxt["id"].tolist()):
        qid_to_index.setdefault(known_qid, row_label)

    qid_to_embed_row = {}
    for row_pos, known_qid in enumerate(embeds_wktxt["id"].tolist()):
        qid_to_embed_row.setdefault(known_qid, row_pos)

    batch_size = len(input_ids_list)
    embed_list = [] ## len will be batch_size
    embed_mask = []
    embed_mask_qid = []
    embed_mask_index = []

    allc = 0  # annotations seen
    cc = 0    # annotations whose id is present in qids_wktxt

    for i in range(batch_size):
        input_ids = input_ids_list[i]
        annotations = annotations_list[i]
        cumulative = cummulative_word_tokens_list[i]

        ## Replace zeros with random numbers if required
        embeds = np.zeros((len(input_ids), embedding_dim)) ## CLS, SEP will have np.zeros, like unknown words
        mask = [0]*len(input_ids)
        mask_qid = ['0']*len(input_ids)
        mask_index = [-100]*len(input_ids)

        for annot in annotations:
            start_ix, end_ix = annot['span']
            # +1 skips the CLS token at the beginning of input_ids.
            start = cumulative[start_ix] + 1
            end = cumulative[end_ix] + 1
            span_length = end - start

            qid = annot['id']

            allc += 1
            if qid in qid_to_index:
                qid_index = qid_to_index[qid]
                cc += 1
            else:
                qid_index = -100

            embed_row = qid_to_embed_row.get(qid)
            if embed_row is not None:
                vector = embeds_wktxt.iloc[embed_row, 1:].values
                embeds[start:end] = np.tile(vector, (span_length, 1))
                mask[start:end] = [1]*span_length
                mask_qid[start:end] = [qid]*span_length
                mask_index[start:end] = [qid_index]*span_length

        embed_mask.append(mask)
        embed_mask_qid.append(mask_qid)
        embed_list.append(embeds)
        embed_mask_index.append(mask_index)

    # A batch without any annotation used to crash here with ZeroDivisionError.
    print(cc/allc if allc else 0.0)
    print(cc)
    print(allc)
    return {
        "kg_embedding": embed_list,
        "kg_embedding_mask": embed_mask,
        "kg_embedding_mask_qid": embed_mask_qid,
        "kg_embedding_mask_index": embed_mask_index
    }

def filter_text_batched(data):
    """Drop every position whose input id is 100 ([UNK]) from all columns of a batch.

    Args:
        data: batch whose columns are lists of per-token sequences, all aligned
            with the "input_ids" column.

    Returns:
        A new dict with the same keys, in which each sequence keeps only the
        positions where the matching input id is not 100.
    """
    ## positions to keep per example: everything that is not [UNK] == 100
    kept_positions = [
        [position for position, token_id in enumerate(example_ids) if token_id != 100]
        for example_ids in data["input_ids"]
    ]

    return {
        column: [
            [sequence[position] for position in positions]
            for positions, sequence in zip(kept_positions, data[column])
        ]
        for column in data
    }

def truncate_data(data):
    """Clip every sequence in every column of ``data`` to the tokenizer's maximum input size.

    The limit is read from ``my_tokenizer.max_model_input_sizes`` under the key
    ``bert_model_name``. ``data`` is modified in place and also returned.
    """
    limit = my_tokenizer.max_model_input_sizes[bert_model_name]
    head = slice(None, limit)

    ## keep only the first `limit` items of each sequence
    for column in data:
        data[column] = [sequence[head] for sequence in data[column]]

    return data


def group_texts(examples, block_size=None):
    """Concatenate all sequences of a batch and cut them into fixed-size chunks.

    Args:
        examples: batch mapping each column name to a list of sequences. All
            columns are expected to have the same total number of items.
        block_size: length of each output chunk. Defaults to the module-level
            ``chunk_size`` when omitted, so existing ``map(group_texts, ...)``
            calls behave as before.

    Returns:
        dict with the same keys; each column is a list of chunks of exactly
        ``block_size`` items. The trailing remainder that does not fill a whole
        chunk is dropped. No "labels" column is added here.
    """
    size = chunk_size if block_size is None else block_size

    # Flatten one level per column. A nested comprehension is linear in the number
    # of items, whereas ``sum(lists, [])`` copies the growing list on every step.
    concatenated_examples = {
        k: [item for sequence in examples[k] for item in sequence]
        for k in examples.keys()
    }
    # Length of the concatenated text, taken from the first column
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the last chunk if it's smaller than the chunk size
    total_length = (total_length // size) * size
    # Split by chunks of `size`
    result = {
        k: [t[i : i + size] for i in range(0, total_length, size)]
        for k, t in concatenated_examples.items()
    }
    return result

In [None]:
# Shared tokenizer instance; my_tokenize_function and truncate_data read it as a global.
my_tokenizer = BertTokenizerModified.from_pretrained(bert_model_name)

# Data processing and loading

In [None]:
# Linked Wikitext dataset
dataset_file = "concat_dataset_v2"

# Linked WikiText pipeline:
#   tokenize -> attach KG information -> drop the raw columns
#   -> remove [UNK] positions -> regroup into fixed-size chunks
final_dataset = (
    wikitest2_dataset
    .map(my_tokenize_function, batched=True)
    .map(get_kg_embedding_batched, batched=True, batch_size=100, keep_in_memory=False)
    .remove_columns(['title', 'tokens', 'annotations', 'word_tokens', 'cummulative_word_tokens'])
    .map(filter_text_batched, batched=True, batch_size=100, keep_in_memory=False)
    .map(group_texts, batched=True, batch_size=100, keep_in_memory=False)
)

# Persist the processed dataset so later sessions can reload it from disk.
final_dataset.save_to_disk(dataset_file)

In [None]:
# Wikidata fact dataset
synthetic_data = "generate_test_data/sythetic_dataset.jsonl"
synthetic_dataset = load_dataset("json", data_files={"synthetic": synthetic_data})

dataset_file = "tokenized_synthetic_dataset_1"

# Wikidata fact pipeline: tokenize -> attach KG information -> drop the raw columns.
# Unlike the Linked WikiText pipeline, the examples are neither filtered nor regrouped.
tokenized_synthetic_dataset = (
    synthetic_dataset
    .map(my_tokenize_function, batched=True)
    .map(get_kg_embedding_batched, batched=True, batch_size=100, keep_in_memory=False)
    .remove_columns(['title', 'tokens', 'annotations', 'word_tokens', 'cummulative_word_tokens'])
)

# Persist the processed dataset so later sessions can reload it from disk.
tokenized_synthetic_dataset.save_to_disk(dataset_file)




In [None]:
## Load the saved tokenized dataset
dataset_file = "concat_dataset_v2"
final_dataset = load_from_disk(dataset_file)

# Same for the processed Wikidata fact dataset.
dataset_file = "tokenized_synthetic_dataset_1"
tokenized_synthetic_dataset = load_from_disk(dataset_file)


# Small, seeded train/test split drawn from the "valid" split, for testing the model
# on a smaller sample (its "test" part is used as the first Trainer's eval_dataset).
train_size = 100
test_size = 300
downsampled_dataset = final_dataset["valid"].train_test_split(train_size=train_size, test_size=test_size, seed=42)

## Create Data Collator for Masking

In [None]:

## Create Data Collator for Masking
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from collections.abc import Mapping    
    
# NEW FOR EXTRA TEST    
class CustomDataCollator(DataCollatorForWholeWordMask):
    """Collator that masks the tokens flagged by ``kg_embedding_mask`` and batches the KG columns.

    The per-token ``kg_embedding_mask`` of each example is handed to
    ``torch_mask_tokens`` as the mask labels, so the positions to mask come from
    the knowledge-graph annotations of the dataset.
    """

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """Collate a list of dataset rows into one batch.

        Args:
            examples: list of mappings holding "input_ids", "kg_embedding",
                "kg_embedding_mask", "kg_embedding_mask_qid" and
                "kg_embedding_mask_index" (all per-token Python lists).

        Returns:
            dict with "input_ids" and "labels" (from ``torch_mask_tokens``),
            "kg_embedding_mask" and "kg_embedding_mask_index" as padded tensors,
            and "kg_embedding" / "kg_embedding_mask_qid" as padded Python lists.

        Raises:
            Exception: if the examples are not mappings.
        """
        if not isinstance(examples[0], Mapping):
            raise Exception("Dataset needs to be in dictionary format")

        input_ids = [e["input_ids"] for e in examples]
        kg_embedding_mask = [e["kg_embedding_mask"] for e in examples]
        kg_embedding = [e["kg_embedding"] for e in examples]
        kg_embedding_mask_qid = [e["kg_embedding_mask_qid"] for e in examples]
        kg_embedding_mask_index = [e["kg_embedding_mask_index"] for e in examples]

        batch_input = _torch_collate_batch(input_ids, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        batch_mask = _torch_collate_batch(kg_embedding_mask, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
        inputs, labels = self.torch_mask_tokens(batch_input, batch_mask)

        # Padded sequence length of the batch, and width of one embedding vector.
        token_length = len(inputs[0])
        embedding_size = len(kg_embedding[0][0])

        # Right-pad the embeddings with zero vectors up to the batch length.
        kg_embedding = [kg_embds+[[0]*embedding_size]*(token_length-len(kg_embds)) for kg_embds in kg_embedding]

        # Pad the entity indices with -100, the "no entity" value already used for
        # unannotated tokens. Padding with the tokenizer's pad id would instead mark
        # padded positions with a valid entity index.
        # NOTE(review): this right-pads, like the embedding and qid padding in this
        # method; confirm the tokenizer is not configured for left padding.
        batch_mask_index = torch.tensor(
            [list(indices)+[-100]*(token_length-len(indices)) for indices in kg_embedding_mask_index],
            dtype=torch.long,
        )

        return {
                "input_ids": inputs,
                "labels": labels,
                "kg_embedding":kg_embedding,
                "kg_embedding_mask": batch_mask,
                "kg_embedding_mask_qid": [qid+['0']*(token_length-len(qid)) for qid in kg_embedding_mask_qid],
                "kg_embedding_mask_index":batch_mask_index,
               }
    

# Create Model

1. The KIM designed in this study is BERTModified.
2. The baseline LM-Raw is BERTModified_LMRaw.
3. The baseline KG-Raw is BERTModified_KGRaw.
4. The alternative integration module Alt-KIM is BERTModified_alt.

In [None]:
# Key into model_dict that selects which variant is built and trained.
model_name = "Our KIM"  # or LM-Raw, KG-Raw, alt-KIM

# Display name -> model class (the classes come from models.models).
model_dict ={"Our KIM":BERTModified, "LM-Raw":BERTModified_LMRaw, "KG-Raw":BERTModified_KGRaw,"alt-KIM":BERTModified_alt}

# Pretrained encoder that is wrapped by the selected model class.
base_model = AutoModel.from_pretrained(bert_model_name)



# choose the model that needs to be trained
# NOTE(review): kg_size is presumably the number of entities (rows of qids_wktxt) -- confirm.
model = model_dict[model_name](bert_model_name = bert_model_name,
                                  base_model = base_model,
                                  config = base_model.config, kg_size =46685)



In [None]:
# Create the data collator; it is shared by both Trainer instances in this notebook.
data_collator = CustomDataCollator(tokenizer=my_tokenizer, mlm=True, mlm_probability=0.15)


# Evaluation metrics and loss function

In [None]:
import evaluate

# The four metrics are loaded separately; compute_metrics reads them as globals.
# metrics = evaluate.combine(["accuracy", "precision", "recall", "f1"])
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")

def compute_metrics(eval_preds=None, logits=None, labels=None):
    """Compute micro-averaged precision/recall/F1 and accuracy over non-ignored positions.

    Args:
        eval_preds: a ``(logits, labels)`` pair, as passed by ``Trainer``. When
            given, it takes precedence over ``logits`` / ``labels``.
        logits: per-position class scores; the prediction is the argmax over
            the last axis.
        labels: per-position reference ids; positions equal to -100 are ignored.

    Returns:
        dict with the keys "precision", "recall", "f1" and "accuracy".

    Raises:
        ValueError: if neither ``eval_preds`` nor both ``logits`` and ``labels``
            are provided.
    """
    # Test against None rather than truthiness: an array-like argument does not
    # have an unambiguous truth value.
    if eval_preds is not None:
        logits, labels = eval_preds
    if logits is None or labels is None:
        raise ValueError("Provide either `eval_preds` or both `logits` and `labels`.")

    predictions = np.argmax(logits, axis=-1)

    # Keep only positions whose label is not the ignore index, flattened across
    # the batch in row order. Pairing with zip tolerates label rows that are
    # shorter than the matching (padded) prediction rows.
    true_labels = []
    true_predictions = []
    for prediction_row, label_row in zip(predictions, labels):
        for predicted, expected in zip(prediction_row, label_row):
            if expected != -100:
                true_labels.append(expected)
                true_predictions.append(predicted)

    accuracy = accuracy_metric.compute(predictions=true_predictions, references=true_labels)
    precision = precision_metric.compute(predictions=true_predictions, references=true_labels, average="micro")
    recall = recall_metric.compute(predictions=true_predictions, references=true_labels, average="micro")
    f1 = f1_metric.compute(predictions=true_predictions, references=true_labels, average="micro")

    return {
        "precision": precision["precision"],
        "recall": recall["recall"],
        "f1": f1["f1"],
        "accuracy": accuracy["accuracy"],
    }

# Create Trainer

In [None]:
from transformers import TrainingArguments

batch_size = 4

# Log the training loss roughly once per epoch. TrainingArguments needs a non-zero
# value, so fall back to 1 when the dataset is smaller than one batch.
logging_steps = max(1, len(final_dataset['train']) // batch_size)

# Use a dedicated variable for the output-directory prefix. `model_name` is the key
# that selects the architecture in model_dict / model_dict_kg, and overwriting it here
# made the later `model_dict_kg[model_name]` lookup fail with a KeyError.
output_model_name = "BERTModified"
output_dir = f"{output_model_name}-finetuned-wikitext-test"

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    push_to_hub=True,
#     fp16=True,
    logging_steps=logging_steps,
    num_train_epochs=50,
#     load_best_model_at_end=True,
#     metric_for_best_model="loss",#metric_name,
#     greater_is_better = False,
    logging_dir='logs',
    report_to="wandb",
#     no_cuda=True,
)

metric_for_best_model (str, optional) — Use in conjunction with load_best_model_at_end to specify the metric to use to compare two different models. Must be the name of a metric returned by the evaluation with or without the prefix "eval_". Will default to "loss" if unspecified and load_best_model_at_end=True (to use the evaluation loss).

If you set this value, greater_is_better will default to True. Don’t forget to set it to False if your metric is better when lower.

In [None]:
from transformers import Trainer

# Train on the full training split; evaluate each epoch on the small downsampled
# split (switch to the commented line to evaluate on the full "valid" split).
trainer = Trainer(
    model=model,
    args=training_args,
    eval_dataset=downsampled_dataset["test"],
    train_dataset=final_dataset["train"],
    #eval_dataset=final_dataset["valid"],
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

# Train Model

In [None]:
trainer.train()

# To resume training from a checkpoint instead:

# trainer.train("./BERTModified-fullsize-kg-finetuned-wikitext-test/checkpoint-65730")



In [None]:
import math

# Basic evaluation: perplexity is computed as exp(eval_loss) on the Trainer's eval_dataset.
eval_results = trainer.evaluate()
print(f">>> Perplexity: {math.exp(eval_results['eval_loss']):.2f}") #21596 for 1 epoch

eval_results

# Evaluation

## 1.LM-accuracy

### 1.1 Linked Wikidata

In [None]:
# Predict on the full test split (use the commented line for the small downsampled split).
#predictions = trainer.predict(downsampled_dataset["test"])
predictions = trainer.predict(final_dataset["test"])

[About micro, macro, weighted precision, recall, f1 for multiclass labels](https://towardsdatascience.com/multi-class-metrics-made-simple-part-ii-the-f1-score-ebe8b2c2ca1)

The following always holds true for the micro-F1 case:

`micro-F1 = micro-precision = micro-recall = accuracy`

In [None]:
# show the result (the call was misspelled `ompute_metrics`, which raised a NameError)
compute_metrics(logits = predictions.predictions, labels = predictions.label_ids)


### 1.2 Wikidata facts

In [None]:
# LM accuracy on the Wikidata fact dataset: predict, then score against the
# label ids returned with the predictions.
predictions = trainer.predict(tokenized_synthetic_dataset["synthetic"])

compute_metrics(logits = predictions.predictions, labels = predictions.label_ids)

## 2.KG-accuracy

### 2.1 Linked Wikidata

In [None]:

# NOTE(review): model_name must be one of the keys of model_dict_kg below;
# "LM-Raw" has no entry here, so it cannot be used for the KG-accuracy evaluation.
#model_name = "Our KIM"
model_dict_kg ={"Our KIM":BERTModified_KG, "KG-Raw":BERTModified_KGRaw,"alt-KIM":BERTModified_altKG}

# Fresh pretrained encoder for the KG-prediction variant of the model.
base_model = AutoModel.from_pretrained(bert_model_name)

# choose the model that needs to be evaluated
model = model_dict_kg[model_name](bert_model_name = bert_model_name,
                                  base_model = base_model,
                                  config = base_model.config,kg_size =46685)


In [None]:

# Load the trained weights onto the CPU first, so that a checkpoint saved on a GPU can
# also be restored on a CPU-only machine; load_state_dict then copies the values into
# the model's own parameters.
checkpoint = torch.load("./" + output_dir + "/pytorch_model.bin", map_location="cpu")
model.load_state_dict(checkpoint)

In [None]:


# Trainer wrapping the KG-prediction model; it is only used for predict() in the cells below.
trainer = Trainer(
    model=model,
    args=training_args,
    #eval_dataset=downsampled_dataset["test"],
    train_dataset=final_dataset["train"],
    eval_dataset=final_dataset["valid"],
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

In [None]:
# KG accuracy on the full test split (the commented lines use the small downsampled split).
#predictions = trainer.predict(downsampled_dataset["test"]) 
predictions = trainer.predict(final_dataset["test"]) 
# True entity indices come from the dataset column, not from predictions.label_ids.
#labels = np.array(downsampled_dataset["test"]["kg_embedding_mask_index"])
labels = np.array(final_dataset["test"]["kg_embedding_mask_index"])

In [None]:

# show the result: predicted entity indices scored against the dataset's entity indices
compute_metrics(logits = predictions.predictions, labels =labels)

### 2.2 Wikidata facts

In [None]:
predictions = trainer.predict(tokenized_synthetic_dataset["synthetic"])
# true labels of the entities
# NOTE(review): these examples are not regrouped into fixed-size chunks, so the rows may
# differ in length -- confirm np.array yields the intended array here.
labels = np.array(tokenized_synthetic_dataset["synthetic"]["kg_embedding_mask_index"])

# show the result
compute_metrics(logits = predictions.predictions, labels =labels)