In [1]:
from platform import python_version

python_version()

'3.9.15'

In [30]:
import logging
import re
import string
import time
import torch
import unicodedata

In [128]:
from ast import literal_eval

In [3]:
from datasets import load_dataset, get_dataset_infos

In [4]:
import numpy as np

In [5]:
from tqdm.auto import tqdm

In [6]:
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup, get_scheduler
from transformers import RealmForOpenQA, RealmConfig, RealmRetriever, RealmTokenizerFast, RealmScorer
from transformers import TrainingArguments, Trainer

2022-11-22 15:33:26.555146: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-22 15:33:26.712008: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2022-11-22 15:33:27.494168: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/nccl2/lib:/usr/local/cuda/extras/CUPTI/lib64::/opt/con

In [7]:
from torch.nn.utils import clip_grad_norm_

from torch.utils.tensorboard import SummaryWriter

## Using the filtered splits used in the MEND paper


In [8]:
dataset = load_dataset('data/zsre/', data_files={'train': 'train_filtered.tsv',
                                                 'validation': 'dev_filtered.tsv'})

dataset

Using custom data configuration zsre-f642d97352d09fb5
Found cached dataset csv (/home/patrick/.cache/huggingface/datasets/csv/zsre-f642d97352d09fb5/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['question', 'answer'],
        num_rows: 241523
    })
    validation: Dataset({
        features: ['question', 'answer'],
        num_rows: 27384
    })
})

In [None]:
# what does the data look like?
print(dataset['validation'][20])

# Seems like each example is a question answer pair,
# remember to wrap them in list notation as expected 
# by the realm tokenizer


In [None]:
# Save the dataset 
dataset.save_to_disk('data/zsre/zsre_hf.hf')

## Evaluation set, keeping multiple answers as they do in MEND

### (but stripping very similar questions with the same answer)

In [186]:
# Evaluate with only the validation dataset for now

multi_dataset = load_dataset('data/zsre/', data_files={'train': 'train_multi_alternatives.tsv',
                                                       'validation': 'dev_multi_alternatives.tsv'})
multi_dataset

  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['question', 'answers'],
        num_rows: 44499
    })
    validation: Dataset({
        features: ['question', 'answers'],
        num_rows: 7983
    })
})

In [187]:
# let's undo the list/string shenanigans!
def undo_string_list_shenaningans(example):
    """Turn the stringified ``answers`` field of one example back into a list.

    The function is idempotent: if ``answers`` is not a string (for example
    because the ``.map`` cell was already run once), it is passed through
    untouched instead of being handed to ``literal_eval``, which would raise.

    Args:
        example: mapping with a ``question`` entry and an ``answers`` entry.

    Returns:
        A dict with the original ``question`` and the parsed ``answers``.
    """
    answers = example['answers']
    if isinstance(answers, str):
        # literal_eval only evaluates Python literals, never arbitrary code.
        answers = literal_eval(answers)
    return {'question': example['question'], 'answers': answers}

In [189]:
# test shenanigans function 
undo_string_list_shenaningans(multi_dataset['validation'][0])


{'question': 'What university did Watts Humphrey attend?',
 'answers': ['Illinois Institute of Technology',
  'Yale University',
  'University of Chicago',
  "King's College London",
  'University of Michigan']}

In [190]:
multi_dataset['validation'] = multi_dataset['validation'].map(undo_string_list_shenaningans)


In [191]:
multi_dataset['train'] = multi_dataset['train'].map(undo_string_list_shenaningans)


### Dataloaders and collators

In [135]:
train_dataloader = torch.utils.data.DataLoader(dataset['train'], batch_size=1) # complains with a batch size > 1
eval_dataloader = torch.utils.data.DataLoader(dataset['validation'], batch_size=1)

In [192]:
# remember these ones return multiple answers at a time, choose one 
train_multi_dataloader = torch.utils.data.DataLoader(multi_dataset['train'], batch_size=1)
eval_multi_dataloader = torch.utils.data.DataLoader(multi_dataset['validation'], batch_size=1)


## Format retrieval data for pretrained retriever checkpoint

In [None]:
from transformers.models.realm.retrieval_realm import convert_tfrecord_to_np

block_records = convert_tfrecord_to_np('data/wiki/enwiki-20181220/blocks.tfr', RealmConfig().num_block_records)



In [None]:
# save the block records object 
np.save("20181220_records", block_records)

In [156]:
# we've already saved this file, read it in 
block_records = np.load("data/block_records.npy", allow_pickle=True)

In [157]:
len(block_records)

13353718

## Finetune openqa checkpoint 

In [153]:
# set up simple logging to file so we don't get overwhelmed in notebook 
logging.basicConfig(filename=f'logs/train-{time.strftime("%m-%d-%YT%H:%M:%S", time.localtime())}.log',
                    filemode='a',
                    format='%(asctime)s,%(msecs)d %(levelname)s %(message)s',
                    datefmt='%H:%M:%S',
                    level=logging.INFO)

logging.info("Setting up training...\n")

In [154]:
# For finetuned checkpoint
checkpoint = "google/realm-orqa-nq-openqa"

# For pretrained on cc-news weights
ft_checkpoint = "google/realm-cc-news-pretrained-openqa"

In [158]:
# uncomment next line for REALM pretrained in wiki
# tokenizer = RealmTokenizerFast.from_pretrained(checkpoint)

# uncomment below for REALM tokenizer pretrained on cc-news 
tokenizer = RealmTokenizerFast.from_pretrained(ft_checkpoint)

In [159]:
# uncomment line below for orqa retriever 
# retriever = RealmRetriever.from_pretrained(checkpoint)

# uncomment line below for retriever with own database to retrieve from
retriever = RealmRetriever(block_records, tokenizer)


In [160]:
# Uncomment line below for orqa model 
# model = RealmForOpenQA.from_pretrained(checkpoint, retriever)

# uncomment below for model pretrained on cc-news
model = RealmForOpenQA.from_pretrained(ft_checkpoint, retriever)


In [200]:
#
#  Test
#

question = "Who is the pioneer in modern computer science?"
question_ids = tokenizer([question], return_tensors="pt")

answer_ids = tokenizer(
    ["alan mathison turing"],
    add_special_tokens=False,
    return_token_type_ids=False,
    return_attention_mask=False,
    return_tensors="pt"
).input_ids

In [162]:
reader_output, predicted_answer_ids = model(**question_ids,  # .to(device),
                                            answer_ids=answer_ids,  # .to(device),
                                            return_dict=False)

In [163]:
tokenizer.decode(predicted_answer_ids)

'computer science and software engineering ; both fields study software'

In [164]:
reader_output.candidate

tensor(2794)

In [165]:
torch.any(reader_output.reader_correct)

tensor(False)

In [166]:
# Borrowed from: https://github.com/huggingface/transformers/blob/e239fc3b0baf1171079a5e0177a69254350a063b/examples/pytorch/language-modeling/run_mlm_no_trainer.py#L456-L468

no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]


In [167]:
optimizer = torch.optim.AdamW(
    optimizer_grouped_parameters,
    lr=1e-5,
    weight_decay=0.01,
    eps=1e-6,
)

In [None]:
# NOTE(review): running this cell rebinds `optimizer`, replacing the
# torch.optim.AdamW built from `optimizer_grouped_parameters` in the cell above
# (so the per-group weight decay and eps=1e-6 settings are lost). `AdamW` here
# is the name imported from `transformers`. Run only one of the two optimizer
# cells before creating the LR scheduler.
optimizer = AdamW(model.parameters(), lr=1e-5)

In [168]:
num_epochs = 2  # epochs 

# Total optimizer steps, used to size the LR schedule. This must be derived
# from the dataloader the training loop actually iterates
# (`train_multi_dataloader`), otherwise the linear decay is stretched over far
# more steps than are ever taken.
num_steps = len(train_multi_dataloader) * num_epochs

global_step = 1  # tracker for number of steps 

checkpoint_interval = 1000  # save a checkpoint every this many steps

In [169]:
# set device: use the GPU when one is available, otherwise fall back to CPU
# so the notebook still runs on machines without CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [170]:
# Set learning rate scheduler 

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=10000,
    num_training_steps=num_steps,
)

In [171]:
writer = SummaryWriter()

In [None]:
model.to(device)

In [193]:
for bb in train_multi_dataloader:
    break

In [198]:
bb['question']

# bb['answers'][0][0]

['Who is Adoration of the Trinity by?']

In [206]:
name, weight

('embedder.realm.embeddings.word_embeddings.weight',
 Parameter containing:
 tensor([[-0.0095, -0.0572, -0.0246,  ..., -0.0184, -0.0346, -0.0091],
         [-0.0109, -0.0558, -0.0300,  ..., -0.0156, -0.0372, -0.0099],
         [-0.0183, -0.0583, -0.0303,  ..., -0.0153, -0.0390, -0.0030],
         ...,
         [-0.0202, -0.0517, -0.0125,  ..., -0.0040, -0.0141, -0.0231],
         [-0.0429, -0.0524, -0.0018,  ...,  0.0146, -0.0129, -0.0088],
         [-0.0125, -0.0907, -0.0150,  ...,  0.0255, -0.0366,  0.0690]],
        device='cuda:0', requires_grad=True))

In [None]:
# Training loop.
#
# Each batch holds exactly one question (batch_size=1); the first of its
# alternative answers is used as the gold answer for the step.
# `global_step` starts at 1 and counts optimizer steps across epochs; it is the
# x-axis for everything written to TensorBoard and for checkpoint names.

training_done = False  # set once `num_steps` optimizer steps have been taken

for epoch in range(num_epochs):
    model.train()

    for batch in tqdm(train_multi_dataloader):
        optimizer.zero_grad()

        # remember the batch is size 1 with 1 question and 1 answer
        question = batch['question'][0]
        question_ids = tokenizer(question, return_tensors='pt')

        answer = batch['answers'][0][0]
        answer_ids = tokenizer([answer],
                               add_special_tokens=False,
                               return_attention_mask=False,
                               return_token_type_ids=False,
                               return_tensors='pt').input_ids

        reader_output, predicted_ans_ids = model(**question_ids.to(device),
                                                 answer_ids=answer_ids.to(device),
                                                 return_dict=False)

        predicted_answer = tokenizer.decode(predicted_ans_ids)

        # log to tensorboard, indexed by step so points do not pile up on one x value
        # NOTE(review): the "Reader loss" tag records `reader_output.loss`, while the
        # log message below reports `reader_output.reader_loss` — confirm which is intended.
        writer.add_scalar("Reader loss", reader_output.loss.item(), global_step=global_step)
        writer.add_scalar("Retriever loss", reader_output.retriever_loss.item(), global_step=global_step)

        # backward please 
        reader_output.loss.backward()

        # Histograms must be written after backward(): before it, gradients are
        # either unset or left over from zero_grad(). Logged before clipping so
        # the raw gradients are recorded.
        for name, weight in model.named_parameters():
            if weight.grad is not None:
                writer.add_histogram(name, weight, global_step)
                writer.add_histogram(f'{name}.grad', weight.grad, global_step)

        clip_grad_norm_(model.parameters(), 1.0, norm_type=2.0, error_if_nonfinite=False)

        optimizer.step()
        lr_scheduler.step()

        logging.info(
            f"Epoch: {epoch}, "
            f"Step: {global_step}, "
            f"Retriever Loss: {reader_output.retriever_loss.mean()}, "
            f"Reader Loss: {reader_output.reader_loss.mean()}\n"
            f"\tQuestion: {batch['question'][0]}, Gold Answer: {answer}, Predicted Answer: {predicted_answer}"
        )

        if global_step % checkpoint_interval == 0:
            logging.info(f"Saving checkpint at step {global_step}")

            model.save_pretrained(f"checkpoints/short/checkpoint-{global_step}")

        global_step += 1
        # global_step is 1-based, so exactly `num_steps` steps have run once it
        # exceeds num_steps (a `>=` test would stop one step early).
        if global_step > num_steps:
            training_done = True
            break

    # A bare `break` only leaves the inner loop; stop the epoch loop as well.
    if training_done:
        break




  0%|          | 0/44499 [00:00<?, ?it/s]

In [102]:
# save the model 
model.save_pretrained('trained/zsre/v2')

In [None]:
next(model.named_parameters())

In [109]:
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Evaluate



In [119]:
from evaluate import load

In [32]:
def normalize_answer(s):
    """
        Canonicalise an answer string for exact-match comparison.
        (Copied from ORQA codebase)

        Steps, in order: Unicode NFD normalisation, lower-casing, removal of
        ASCII punctuation, removal of the articles a/an/the, and collapsing of
        all whitespace runs to single spaces (leading/trailing removed).
    """
    text = unicodedata.normalize("NFD", s).lower()

    # drop every ASCII punctuation character
    punctuation = set(string.punctuation)
    text = "".join(filter(lambda ch: ch not in punctuation, text))

    # blank out standalone articles, then squeeze whitespace
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

In [34]:
# test normalize ans
normalize_answer("Alan Mathison Turing")

'alan mathison turing'

In [29]:
def compute_eval_metrics(labels, predicted_answer, reader_output):
    """Compute the evaluation metrics for a single prediction.

    Args:
        labels: iterable of reference answer strings. May be empty, in which
            case ``official_exact_match`` is False.
        predicted_answer: the decoded answer string produced by the model.
        reader_output: object exposing ``reader_correct``, ``retriever_correct``,
            ``block_idx`` and ``candidate`` tensors (the reader output returned
            by the model call).

    Returns:
        dict of boolean tensors with keys ``exact_match``,
        ``official_exact_match``, ``reader_oracle`` and ``top_{k}_match`` for
        k in (5, 10, 50, 100, 500, 1000, 5000).
    """
    # Pick the entry of reader_correct at row `block_idx`, column `candidate`.
    exact_match = torch.index_select(
        torch.index_select(reader_output.reader_correct,
                           dim=0,
                           index=reader_output.block_idx),
        dim=1,
        index=reader_output.candidate
    )

    # String-level exact match after normalisation. `any` (unlike `max` on a
    # list) is well defined for an empty reference list; the prediction is
    # normalised once rather than once per reference.
    normalized_prediction = normalize_answer(predicted_answer)
    official_em = torch.tensor(
        any(normalized_prediction == normalize_answer(reference) for reference in labels)
    )

    eval_metric = dict(
        exact_match=exact_match[0][0],
        official_exact_match=official_em,
        reader_oracle=torch.any(reader_output.reader_correct)
    )

    # Get top matches: is any of the first k retriever_correct entries True?
    for k in (5, 10, 50, 100, 500, 1000, 5000):
        eval_metric[f"top_{k}_match"] = torch.any(reader_output.retriever_correct[:k])

    return eval_metric
        
    

In [36]:
# test compute metrics
compute_eval_metrics(["Alan Mathison Turing"],
                     tokenizer.decode(predicted_answer_ids),
                     reader_output)

{'exact_match': tensor(False, device='cuda:0'),
 'official_exact_match': tensor(True),
 'reader_oracle': tensor(False, device='cuda:0'),
 'top_5_match': tensor(False, device='cuda:0'),
 'top_10_match': tensor(False, device='cuda:0'),
 'top_50_match': tensor(False, device='cuda:0'),
 'top_100_match': tensor(False, device='cuda:0'),
 'top_500_match': tensor(False, device='cuda:0'),
 'top_1000_match': tensor(False, device='cuda:0'),
 'top_5000_match': tensor(False, device='cuda:0')}

In [40]:
# set searcher and reader beam size to same values as paper
reader_beam_size = 5
searcher_beam_size = 5000

In [None]:
model.eval()

In [None]:
model.to(device)

In [71]:
# Collect all metrics
all_metrics = []
all_metrics

[]

In [73]:
# save predictions for script eval
pred_file = open('data/for_eval/model_predictions.txt', 'w')
ground_truth_file = open('data/for_eval/ground_truth.txt', 'w')

In [137]:
# Evaluation pass over the single-answer validation split.
# For every batch: predict an answer, record its metrics in `all_metrics`, and
# append the prediction / first gold answer (one per line) to the open files.
model.eval()

# Gold answers are tokenized to bare input ids: no special tokens, no masks.
answer_tokenizer_kwargs = dict(
    add_special_tokens=False,
    return_attention_mask=False,
    return_token_type_ids=False,
    return_tensors='pt',
    padding=True,
)

for ebatch in tqdm(eval_dataloader):
    question_ids = tokenizer(ebatch['question'], return_tensors='pt', padding=True)
    answer_ids = tokenizer(ebatch['answer'], **answer_tokenizer_kwargs).input_ids

    # inference only: no autograd bookkeeping needed
    with torch.no_grad():
        outputs = model(**question_ids.to(device),
                        answer_ids=answer_ids.to(device),
                        return_dict=True)

    predicted_answer = tokenizer.decode(outputs.predicted_answer_ids)

    batch_metrics = compute_eval_metrics(ebatch['answer'], predicted_answer, outputs.reader_output)
    all_metrics.append(batch_metrics)

    pred_file.write(predicted_answer + '\n')
    ground_truth_file.write(ebatch['answer'][0] + '\n')
    

  0%|          | 0/13692 [00:00<?, ?it/s]

ValueError: The batch_size of the inputs must be 1.

In [76]:
pred_file.close()
ground_truth_file.close()

In [77]:
# Collate the per-example metric dicts into one stacked tensor per metric name.
# The metric names are taken from the first collected entry.
stacked_metrics = {
    key: torch.stack([per_example[key] for per_example in all_metrics])
    for key in all_metrics[0]
}


In [78]:
# Accuracy = number of True flags divided by the number of evaluated examples.
official_em_flags = stacked_metrics['official_exact_match']
exact_match_flags = stacked_metrics['exact_match']

print(f"Official EM: {int(official_em_flags.sum()) / len(official_em_flags)}")
print(f"Exact Match: {int(exact_match_flags.sum()) / len(exact_match_flags)}")


Official EM: 0.26698071866783524
Exact Match: 0.0


In [86]:
# Prefer huggingface's evaluate library
# Re-read the saved predictions and gold answers (one per line) and normalise
# each line; normalisation also strips the trailing newline. The `with` blocks
# close the files, so no explicit close() is needed.
with open('data/for_eval/model_predictions.txt', 'r') as f:
    preds = [normalize_answer(line) for line in f]

with open('data/for_eval/ground_truth.txt', 'r') as p:
    labels = [normalize_answer(line) for line in p]

In [120]:
# set metric to use
em_metric = load('exact_match')

results = em_metric.compute(predictions=preds, references=labels)
results

Downloading builder script:   0%|          | 0.00/5.67k [00:00<?, ?B/s]

{'exact_match': 0.26698071866783524}

In [54]:
from collections import Counter

In [58]:
answers_strings = []

for b in eval_dataloader:
    answers_strings.append(b['answer'][0])

len(answers_strings)

27384

In [79]:
answers_strings[:3]

['Illinois Institute of Technology', 'Lecanorales', 'defender']

In [62]:
answer_counts = Counter(answers_strings)

In [84]:
answer_counts.most_common()[:10]

[('Antarctica', 478),
 ('French', 430),
 ('female', 321),
 ('midfielder', 188),
 ('human', 166),
 ('soprano', 155),
 ('World War II', 145),
 ('heart attack', 140),
 ('male', 137),
 ('piano', 136)]

In [None]:
# How many times does the same answer appear in the data?
# Each batch holds a single example, so the answer string is batch['answer'][0]
# (the list itself is unhashable and cannot be used as a key).
counts = Counter(batch['answer'][0] for batch in eval_dataloader)
           
    
