In [1]:
!pip -q install transformers datasets evaluate
!pip -q install faiss_gpu
!pip -q install nlp

import datasets
import functools
import math
import faiss  
import nlp  
import os  
import torch
import numpy as np
import pandas as pd
import torch.utils.checkpoint as checkpoint

from tqdm import tqdm
from time import time
from random import choice, randint
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup

pd.set_option("display.max_colwidth", None)

###############
# ELI5 seq2seq model training
###############
class ELI5DatasetS2S(Dataset):
    def __init__(self, examples_array, make_doc_fun=None, extra_answer_threshold=3, document_cache=None, training=True):
        self.training = training
        self.data     = examples_array
        self.make_doc_function = make_doc_fun
        self.document_cache    = {} if document_cache is None else document_cache
        assert not (make_doc_fun is None and document_cache is None)
        
        self.qa_id_list = [(i, 0) for i in range(len(self.data))]
            
    def __len__(self):
        return len(self.qa_id_list)

    def make_example(self, idx):
        i, j = self.qa_id_list[idx]
        example = self.data[i]
        question = example["question"] 
        answer = example["answers"][j]
        q_id = example["question_id"]
        if self.make_doc_function is not None:
            self.document_cache[q_id] = self.document_cache.get(q_id, self.make_doc_function(example["question"]))
        document = self.document_cache[q_id]
        in_st  = "question: {} context: {}".format(question.lower().replace(" --t--", "").strip(), document.lower().strip())
        out_st = answer
        return (in_st, out_st)

    def __getitem__(self, idx):
        return self.make_example(idx)

def make_qa_s2s_model(model_name="facebook/bart-base", from_file=None, device="cuda:0"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    if from_file is not None:
        param_dict = torch.load(from_file)  # has model weights, optimizer, and scheduler states
        model.load_state_dict(param_dict["model"])
    return tokenizer, model

def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360, device="cuda:0"):
    q_ls = [q for q, a in qa_list]
    a_ls = [a for q, a in qa_list]
    q_toks = tokenizer.batch_encode_plus(q_ls, max_length=max_len, pad_to_max_length=True)
    q_ids, q_mask = (torch.LongTensor(q_toks["input_ids"]).to(device),
                     torch.LongTensor(q_toks["attention_mask"]).to(device))
    a_toks = tokenizer.batch_encode_plus(a_ls, max_length=min(max_len, max_a_len), pad_to_max_length=True)
    a_ids, a_mask = (torch.LongTensor(a_toks["input_ids"]).to(device),
                     torch.LongTensor(a_toks["attention_mask"]).to(device))
    labels = a_ids[:, 1:].contiguous().clone()
    labels[a_mask[:, 1:].contiguous() == 0] = -100
    model_inputs = {"input_ids": q_ids,
                    "attention_mask": q_mask,
                    "decoder_input_ids": a_ids[:, :-1].contiguous(),
                    "labels": labels}
    return model_inputs

def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0, curriculum=False):
    model.train()
    
    # make iterator
    if curriculum:
        train_sampler = SequentialSampler(dataset)
    else:
        train_sampler = RandomSampler(dataset)
        
    model_collate_fn = functools.partial(make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device="cuda:0")
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    
    for step, batch_inputs in enumerate(epoch_iterator):
        pre_loss = model(**batch_inputs)[0]
        loss = pre_loss.sum() / pre_loss.shape[0]
        loss.backward()
        
        # optimizer
        if step % args.backward_freq == 0:
            optimizer.step()
            scheduler.step()
            model.zero_grad()
            
        # some printing within the epoch
        loc_loss += loss.item()
        loc_steps += 1
        if step % args.print_freq == 0 or step == 1:
            print("{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e, step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time,))
            loc_loss  = 0
            loc_steps = 0

def eval_qa_s2s_epoch(model, dataset, tokenizer, args):
    model.eval()
    
    # make iterator
    train_sampler = SequentialSampler(dataset)
    model_collate_fn = functools.partial(make_qa_s2s_batch, tokenizer=tokenizer, 
                                         max_len=args.max_length, device="cuda:0")
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    
    # accumulate loss since last print
    loc_steps = 0
    loc_loss  = 0.0
    st_time   = time()
    with torch.no_grad():
        for step, batch_inputs in enumerate(epoch_iterator):
            pre_loss = model(**batch_inputs)[0]
            loss = pre_loss.sum() / pre_loss.shape[0]
            loc_loss += loss.item()
            loc_steps += 1
            if step % args.print_freq == 0:
                print("{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                        step, len(dataset) // args.batch_size, loc_loss / loc_steps, time() - st_time))
    print("Total \t L: {:.3f} \t -- {:.3f}".format(loc_loss / loc_steps, time() - st_time))


def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):
    s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)
    s2s_scheduler = get_linear_schedule_with_warmup(s2s_optimizer, num_warmup_steps=400,
                                                    num_training_steps=(s2s_args.num_epochs + 1) * math.ceil(len(s2s_train_dset) / s2s_args.batch_size),)
    for e in range(s2s_args.num_epochs):
        train_qa_s2s_epoch(qa_s2s_model,s2s_train_dset,qa_s2s_tokenizer,s2s_optimizer,
                           s2s_scheduler,s2s_args,e,curriculum=(e == 0),)
        m_save_dict = {"model"    : qa_s2s_model.module.state_dict()
                                    if hasattr(qa_s2s_model, 'module') else qa_s2s_model.state_dict(),
                       "optimizer": s2s_optimizer.state_dict(),
                       "scheduler": s2s_scheduler.state_dict()}
        
        print("Saving model {}".format(s2s_args.model_save_name))
        eval_qa_s2s_epoch(qa_s2s_model, s2s_valid_dset, qa_s2s_tokenizer, s2s_args)
        torch.save(m_save_dict, "{}.pth".format(s2s_args.model_save_name))


# generate answer from input "question: ... context: <p> ..."
def qa_s2s_generate(question_doc, qa_s2s_model, qa_s2s_tokenizer, num_answers=1, num_beams=None,
                    min_len=64, max_len=256, do_sample=False,temp=1.0, top_p=None, top_k=None,
                    max_input_length=512, device="cuda:0"):
    
    model_inputs = make_qa_s2s_batch([(question_doc, "A")], qa_s2s_tokenizer, 
                                       max_input_length, device=device)
    
    n_beams = num_answers if num_beams is None else max(num_beams, num_answers)
    model = qa_s2s_model.module if hasattr(qa_s2s_model, 'module') else qa_s2s_model 
    generated_ids = model.generate( input_ids=model_inputs["input_ids"],
                                           attention_mask=model_inputs["attention_mask"],
                                           min_length=min_len,max_length=max_len,
                                           do_sample=do_sample, early_stopping=True,
                                           num_beams=1 if do_sample else n_beams,
                                           temperature=temp,top_k=top_k,top_p=top_p,
                                           eos_token_id=qa_s2s_tokenizer.eos_token_id,
                                           no_repeat_ngram_size=3,
                                           num_return_sequences=num_answers,
                                           decoder_start_token_id=qa_s2s_tokenizer.bos_token_id)
    return [qa_s2s_tokenizer.decode(ans_ids, skip_special_tokens=True).strip() for ans_ids in generated_ids]


###############
# ELI5-trained retrieval model usage
###############
def embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length=128, device="cuda:0"):
    a_toks = tokenizer.batch_encode_plus(passages, max_length=max_length, pad_to_max_length=True)
    a_ids, a_mask = (torch.LongTensor(a_toks["input_ids"]).to(device),
                     torch.LongTensor(a_toks["attention_mask"]).to(device))
    
    with torch.no_grad():
        a_reps = qa_embedder.embed_answers(a_ids, a_mask).cpu().type(torch.float)
    return a_reps.numpy()


def embed_questions_for_retrieval(q_ls, tokenizer, qa_embedder, device="cuda:0"):
    q_toks = tokenizer.batch_encode_plus(q_ls, max_length=128, pad_to_max_length=True)
    q_ids, q_mask = (torch.LongTensor(q_toks["input_ids"]).to(device),
                     torch.LongTensor(q_toks["attention_mask"]).to(device))
    
    with torch.no_grad():
        q_reps = qa_embedder.embed_questions(q_ids, q_mask).cpu().type(torch.float)
    return q_reps.numpy()


def make_qa_dense_index(qa_embedder,tokenizer,passages_dset,batch_size=512,max_length=128,
                        index_name="kilt_passages_reps.dat",dtype="float32",device="cuda:0"):
    st_time = time()
    fp = np.memmap(index_name, dtype=dtype, mode="w+", shape=(passages_dset.num_rows, 128))
    n_batches = math.ceil(passages_dset.num_rows / batch_size)
    print("Data size  = ", passages_dset.num_rows)
    print("Batch_size = ", batch_size)
    print("n_batch    = ", n_batches)
    for i in range(n_batches):
        passages = [p for p in passages_dset[i * batch_size : (i + 1) * batch_size]["passage_text"]]
        reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder, max_length, device)
        fp[i * batch_size : (i + 1) * batch_size] = reps
        if i % 50 == 0:
            print(i, time() - st_time)


def evaluate_retriever(qa_list, retriever_func, scoring_func, n_ret=10, verbose=False):
    total_retriever_time = 0.0
    total_retriever_score = 0.0
    st_time = time()
    for i, (question, answer) in enumerate(qa_list):
        r_time = time()
        retrieved_passages = retriever_func(question, n_ret)
        total_retriever_time += time() - r_time
        total_retriever_score += scoring_func(retrieved_passages, answer)
        if verbose and ((i + 1) % 500 == 0 or i <= 1):
            print("{:03d}: S-{:.4f} T-{:.4f} | {:.2f}".format(
                    i + 1, total_retriever_score / (i + 1), total_retriever_time / (i + 1), time() - st_time))
    return {"idf_recall": total_retriever_score / (i + 1), "retrieval_time": total_retriever_time / (i + 1)}


# build a support document for the question out of Wikipedia snippets
def query_qa_dense_index(question, qa_embedder, tokenizer, wiki_passages, 
                         wiki_index, n_results=10, min_length=20, device="cuda:0"):
    q_rep = embed_questions_for_retrieval([question], tokenizer, qa_embedder, device=device)
    D, I = wiki_index.search(q_rep, 2 * n_results)
    res_passages = [wiki_passages[int(i)] for i in I[0]]
    support_doc = "<P> " + " <P> ".join([p["passage_text"] for p in res_passages])
    res_list = [dict([(k, p[k]) for k in wiki_passages.column_names]) for p in res_passages]
    res_list = [res for res in res_list if len(res["passage_text"].split()) > min_length][:n_results]
    for r, sc in zip(res_list, D[0]):
        r["score"] = float(sc)
    return support_doc, res_list

[0m

In [2]:
# LOAD DATA ELI5 
import json
eli5_train = json.load(open('/kaggle/input/eli5-10-doc/ELI5_train_10_doc.json'))
eli5_valid = json.load(open('/kaggle/input/eli5-10-doc/ELI5_val_10_doc.json'))

In [3]:
# PRE PROCESSING DOCS 
eli5_train_docs = []
eli5_valid_docs = []

for example in eli5_train:
    support_doc = "<P> " + " <P> ".join([p for p in example["ctxs"]])
    eli5_train_docs += [(example['question_id'], support_doc)]

for example in eli5_valid:
    support_doc = "<P> " + " <P> ".join([p for p in example["ctxs"]])
    eli5_valid_docs += [(example['question_id'], support_doc)]

# LOAD DOCS JSON for train and valid
s2s_train_dset = ELI5DatasetS2S(eli5_train, document_cache=dict([(k, d) for k, d in eli5_train_docs]))
s2s_valid_dset = ELI5DatasetS2S(eli5_valid, document_cache=dict([(k, d) for k, d in eli5_valid_docs]), training=False)

In [5]:
# CREATE ArgumentsS2S
class ArgumentsS2S():
    def __init__(self):
        self.batch_size = 8
        self.backward_freq = 16
        self.max_length = 1024
        self.print_freq = 100
        self.model_save_name = "bart_eli5_task_1"
        self.learning_rate = 2e-4
        self.num_epochs = 1

s2s_args = ArgumentsS2S()

# LOAD TOKENIZER and MODEL S2S
qa_s2s_tokenizer, pre_model = make_qa_s2s_model(model_name="facebook/bart-base",
                                                from_file=None,
                                                device="cuda:0")
qa_s2s_model = torch.nn.DataParallel(pre_model)

# TRAINING
train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


 0     0 of 34079 	 L: 5.563 	 -- 7.685
 0     1 of 34079 	 L: 6.276 	 -- 8.598
 0   100 of 34079 	 L: 5.929 	 -- 99.711
 0   200 of 34079 	 L: 5.358 	 -- 195.693
 0   300 of 34079 	 L: 4.915 	 -- 292.306
 0   400 of 34079 	 L: 4.658 	 -- 388.167
 0   500 of 34079 	 L: 4.530 	 -- 483.915
 0   600 of 34079 	 L: 4.423 	 -- 579.593
 0   700 of 34079 	 L: 4.383 	 -- 675.216
 0   800 of 34079 	 L: 4.290 	 -- 771.457
 0   900 of 34079 	 L: 4.267 	 -- 867.311
 0  1000 of 34079 	 L: 4.196 	 -- 963.102
 0  1100 of 34079 	 L: 4.152 	 -- 1058.746
 0  1200 of 34079 	 L: 4.075 	 -- 1154.485
 0  1300 of 34079 	 L: 4.060 	 -- 1250.645
 0  1400 of 34079 	 L: 4.009 	 -- 1346.211
 0  1500 of 34079 	 L: 3.969 	 -- 1441.814
 0  1600 of 34079 	 L: 3.961 	 -- 1537.603
 0  1700 of 34079 	 L: 3.939 	 -- 1633.249
 0  1800 of 34079 	 L: 3.931 	 -- 1729.476
 0  1900 of 34079 	 L: 3.894 	 -- 1825.124
 0  2000 of 34079 	 L: 3.882 	 -- 1920.803
 0  2100 of 34079 	 L: 3.857 	 -- 2016.390
 0  2200 of 34079 	 L: 3.854

In [6]:
predicted = []
reference = []

# Generate answers for the full test set
for i in range(len(eli5_valid)):
    # create support document with the dense index
    question = eli5_valid[i]['question']
    support_doc = "<P> " + " <P> ".join([str(p) for p in eli5_valid[i]["ctxs"]])
    # concatenate question and support document into BART input
    question_doc = "question: {} context: {}".format(question, support_doc)
    # generate an answer with beam search
    answer = qa_s2s_generate(question_doc, qa_s2s_model, qa_s2s_tokenizer,
                             num_answers=1,num_beams=8,min_len=96,
                             max_len=256,max_input_length=1024,device="cuda:0")
    predicted += [answer[0]]
    reference += [eli5_valid[i]['answers'][0]]
    if i % 100 == 0: print("Step: ",i,"/",len(eli5_valid))

Step:  0 / 1507
Step:  100 / 1507
Step:  200 / 1507
Step:  300 / 1507
Step:  400 / 1507
Step:  500 / 1507
Step:  600 / 1507
Step:  700 / 1507
Step:  800 / 1507
Step:  900 / 1507
Step:  1000 / 1507
Step:  1100 / 1507
Step:  1200 / 1507
Step:  1300 / 1507
Step:  1400 / 1507
Step:  1500 / 1507


In [7]:
!pip -q install rouge

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[0m

In [8]:
from nltk import PorterStemmer
from rouge import Rouge
from spacy.lang.en import English
from time import time

stemmer = PorterStemmer()
rouge = Rouge()
nlpp = English()
tokenizer = nlpp.tokenizer

def compute_rouge_eli5(compare_list):
    preds = [" ".join([stemmer.stem(str(w))for w in tokenizer(pred)])for gold, pred in compare_list]
    golds = [" ".join([stemmer.stem(str(w))for w in tokenizer(gold)])for gold, pred in compare_list]
    scores = rouge.get_scores(hyps=preds, refs=golds, avg=True)
    return scores


compare_list = [(g, p) for p, g in zip(predicted, reference)]
scores = compute_rouge_eli5(compare_list)
df = pd.DataFrame({
    'rouge1': [scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f']],
    'rouge2': [scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f']],
    'rougeL': [scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f']],
}, index=[ 'P', 'R', 'F'])
df.style.format({'rouge1': "{:.4f}", 'rouge2': "{:.4f}", 'rougeL': "{:.4f}"})

Unnamed: 0,rouge1,rouge2,rougeL
P,0.38,0.0723,0.3404
R,0.2566,0.0523,0.2278
F,0.2757,0.0514,0.2454
