In [1]:
from platform import python_version

# Display the Python version of the running kernel, for reproducibility.
python_version()

'3.9.1'

In [2]:
import logging 
import torch

In [3]:
from datasets import load_dataset, get_dataset_infos

In [4]:
import numpy as np

In [5]:
from tqdm.auto import tqdm

In [6]:
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup, get_scheduler
from transformers import RealmForOpenQA, RealmConfig, RealmRetriever, RealmTokenizerFast, RealmScorer
from transformers import TrainingArguments, Trainer

2022-11-21 21:28:25.387856: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-21 21:28:25.570344: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-21 21:28:26.255607: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/nccl2/lib:/usr/local/cuda/extras/CUPTI/lib64
2022-11-21 21:28:26.255716: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer

In [7]:
from torch.nn.utils import clip_grad_norm_

from torch.utils.tensorboard import SummaryWriter

## Using the filtered splits from the MEND paper


In [8]:
# Map each split name to its filtered TSV file inside data/zsre/.
zsre_split_files = dict(train='train_filtered.tsv', validation='dev_filtered.tsv')

# Load both splits into a single DatasetDict.
dataset = load_dataset('data/zsre/', data_files=zsre_split_files)

dataset

Using custom data configuration zsre-eef22664d6c49664
Found cached dataset csv (/home/patrick/.cache/huggingface/datasets/csv/zsre-eef22664d6c49664/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['question', 'answer'],
        num_rows: 241523
    })
    validation: Dataset({
        features: ['question', 'answer'],
        num_rows: 27384
    })
})

In [None]:
# What does the data look like? Show one validation example.
print(dataset['validation'][20])

# Each example is a question-answer pair;
# remember to wrap them in list notation, as expected
# by the REALM tokenizer.


In [None]:
# Save the loaded DatasetDict (train + validation splits) to disk.
dataset.save_to_disk('data/zsre/zsre_hf.hf')

### Dataloaders and collators

In [9]:
# One DataLoader per split, in (train, validation) order.
# batch_size stays at 1: the original author noted that anything larger "complains".
train_dataloader, eval_dataloader = (
    torch.utils.data.DataLoader(dataset[split_name], batch_size=1)
    for split_name in ('train', 'validation')
)

## Convert data needed for pretrained checkpoints

In [None]:
from transformers.models.realm.retrieval_realm import convert_tfrecord_to_np

# Convert the Wikipedia blocks TFRecord file into `block_records`.
# The second argument is the `num_block_records` value of a default RealmConfig.
# NOTE(review): presumably this is slow for the full file -- the result is saved and reloaded in later cells.
block_records = convert_tfrecord_to_np('data/wiki/enwiki-20181220/blocks.tfr', RealmConfig().num_block_records)



In [None]:
# Save the block records object.
# Write to the exact path that is reloaded later in this notebook
# ("data/block_records.npy"); previously the array was written to
# "20181220_records.npy" in the working directory, so a fresh
# Restart & Run All saved one file and then tried to load a different one.
np.save("data/block_records.npy", block_records)

In [10]:
# We've already saved this file; read it back in instead of re-converting.
# NOTE: allow_pickle=True makes np.load unpickle object arrays, which can execute
# arbitrary code -- only load files you created yourself or otherwise trust.
block_records = np.load("data/block_records.npy", allow_pickle=True)

In [11]:
# Sanity check: number of block records that were loaded.
len(block_records)

13353718

## Finetune openqa checkpoint 

In [12]:
# Set up simple logging to file so we don't get overwhelmed in the notebook.
import os

# logging.basicConfig does not create missing parent directories, so make sure
# the log directory exists before the file handler tries to open the file.
os.makedirs('logs', exist_ok=True)

logging.basicConfig(filename='logs/train.log',
                    filemode='a',  # append, so earlier runs are kept
                    # zero-pad milliseconds: 5 ms must read ",005", not ",5"
                    format='%(asctime)s,%(msecs)03d %(levelname)s %(message)s',
                    datefmt='%H:%M:%S',
                    level=logging.INFO)

logging.info("Setting up training...\n")

In [13]:
# Model identifier that is only referenced from a commented-out retriever line in this notebook.
checkpoint = "google/realm-orqa-nq-openqa"

# Model identifier actually used to load the tokenizer and the model.
ft_checkpoint = "google/realm-cc-news-pretrained-openqa"

In [14]:
# Load the fast REALM tokenizer published with `ft_checkpoint`.
tokenizer = RealmTokenizerFast.from_pretrained(ft_checkpoint)

In [16]:
# Build the retriever directly from the locally loaded block records and the tokenizer.
# The commented-out lines are alternative ways of obtaining a retriever that are not in use.
# retriever = RealmRetriever.from_pretrained(checkpoint)
retriever = RealmRetriever(block_records, tokenizer)
# retriever = RealmRetriever.from_pretrained("data/")

In [17]:
# Load the RealmForOpenQA weights from `ft_checkpoint`; the retriever is passed
# as an extra positional argument to from_pretrained.
model = RealmForOpenQA.from_pretrained(ft_checkpoint, retriever)


In [18]:
# Borrowed from: https://github.com/huggingface/transformers/blob/e239fc3b0baf1171079a5e0177a69254350a063b/examples/pytorch/language-modeling/run_mlm_no_trainer.py#L456-L468

# Parameter-name substrings that mark weights which must not be weight-decayed.
no_decay = ["bias", "LayerNorm.weight"]


def _skips_weight_decay(param_name):
    """Return True when `param_name` contains any of the `no_decay` substrings."""
    return any(marker in param_name for marker in no_decay)


# Two groups, in this order: decayed parameters (0.01), then exempt parameters (0.0).
optimizer_grouped_parameters = [
    {
        "params": [
            weight
            for name, weight in model.named_parameters()
            if _skips_weight_decay(name) == exempt
        ],
        "weight_decay": decay,
    }
    for exempt, decay in ((False, 0.01), (True, 0.0))
]


In [19]:
# AdamW over the two parameter groups defined earlier. Both groups carry their
# own "weight_decay", so the value given here only serves as the default for
# groups that do not set one.
adamw_settings = {"lr": 1e-5, "weight_decay": 0.01, "eps": 1e-6}
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, **adamw_settings)

In [20]:
num_epochs = 2  # number of passes over the training dataloader

# Total number of training steps (one per batch); also handed to the LR scheduler.
num_steps = len(train_dataloader) * num_epochs  # total steps to set up scheduler

# 1-based step counter; the training loop increments it after every batch.
# Re-run this cell before re-running the training loop, otherwise the counter keeps its old value.
global_step = 1  # tracker for number of steps

# A model checkpoint is saved every `checkpoint_interval` steps.
checkpoint_interval = 1000

In [21]:
# Set device: prefer the GPU, but fall back to the CPU so the notebook still
# runs (slowly) on machines without CUDA instead of failing at model.to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [22]:
# Learning-rate schedule: 10000 warmup steps out of `num_steps` total training steps.
# The scheduler is stepped once per batch in the training loop.
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=10000,
    num_training_steps=num_steps,
)

In [23]:
# TensorBoard writer used by the training loop to record the loss.
# No log directory is given, so the library default location is used.
writer = SummaryWriter()

In [None]:
# Move the model's parameters to the selected device before training.
model.to(device)

In [None]:
# Training loop
#
# One optimizer step per example (the dataloader uses batch_size=1). Each step:
# tokenize the question and answer, run the model, back-propagate the loss,
# clip gradients, step the optimizer and LR scheduler, log, and periodically
# save a checkpoint. `global_step` is 1-based, so training is finished once it
# exceeds `num_steps`.

for epoch in range(num_epochs):
    # `break` below only leaves the inner loop; do not start another epoch
    # once the step budget has been used up.
    if global_step > num_steps:
        break

    model.train()

    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()

        # remember the batch is size 1 with 1 question and 1 answer
        question_ids = tokenizer(batch['question'], return_tensors='pt')

        # answers are passed as bare token ids: no special tokens, masks or type ids
        answer_ids = tokenizer(batch['answer'],
                               add_special_tokens=False,
                               return_attention_mask=False,
                               return_token_type_ids=False,
                               return_tensors='pt').input_ids

        # with return_dict=False the output is unpacked as (reader output, predicted answer ids)
        reader_output, predicted_ans_ids = model(**question_ids.to(device),
                                                 answer_ids=answer_ids.to(device),
                                                 return_dict=False)

        predicted_answer = tokenizer.decode(predicted_ans_ids)

        # Pass the step explicitly; without it every point is written with no
        # step index and TensorBoard cannot draw a curve.
        # NOTE(review): the tag says "Reader loss" but the value is
        # `reader_output.loss` (the quantity that is back-propagated), not
        # `reader_output.reader_loss` -- confirm the tag is intended.
        writer.add_scalar("Reader loss", reader_output.loss.item(), global_step)

        reader_output.loss.backward()

        clip_grad_norm_(model.parameters(), 1.0, norm_type=2.0, error_if_nonfinite=False)

        optimizer.step()
        lr_scheduler.step()

        # .item() logs plain floats instead of tensor reprs (device, grad_fn, ...)
        logging.info(
            f"Epoch: {epoch},"
            f"Step: {global_step},"
            f"Retriever Loss: {reader_output.retriever_loss.mean().item()},"
            f"Reader Loss: {reader_output.reader_loss.mean().item()}\n"
            f"\tQuestion: {batch['question'][0]}, Gold Answer: {batch['answer'][0]}, Predicted Answer: {predicted_answer}"
        )

        if global_step % checkpoint_interval == 0:
            logging.info(f"Saving checkpint at step {global_step}")

            model.save_pretrained(f"checkpoints/checkpoint-{global_step}")

        global_step += 1
        # `>` rather than `>=`: global_step is 1-based and was just incremented,
        # so `>=` would stop one step early and skip step number `num_steps`.
        if global_step > num_steps:
            break

# make sure the last buffered scalars reach disk
writer.flush()




  0%|          | 0/241523 [00:00<?, ?it/s]

In [None]:
# evaluate
# Placeholder: this cell only displays `model.config.reader_seq_len`;
# no evaluation loop over `eval_dataloader` is implemented here yet.
model.config.reader_seq_len