## Training

In [None]:
from pydrive.drive import GoogleDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import pandas as pd
import json

auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

link = 'https://drive.google.com/file/d/1k7k6Hgyxbv1ecypwZGB4BKDxuDjMMBRY/view?usp=sharing'
id = link.split("/")[-2]
downloaded = drive.CreateFile({'id':id})
downloaded.GetContentFile('askhistorians.json') 
f = open('train_asks.json')
train_data = json.load(f)
f = open('valid_asks.json')
valid_data = json.load(f)

from google.colab import drive
drive.mount('/content/drive/')
! pip install datasets
import datasets

eli5 = datasets.load_from_disk('/content/drive/MyDrive/Deep/eli5_dataset')

In [None]:
# Install the transformers, datasets and nlp packages with pip.
! pip install transformers
! pip install datasets
! pip install nlp

In [None]:
import functools
import math
import os  # noqa: F401
from random import choice, randint
from time import time

import datasets  # noqa: F401
import numpy as np
import pandas as pd
import torch
import torch.utils.checkpoint as checkpoint
from torch import nn
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from tqdm import tqdm
from transformers import AdamW, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, get_linear_schedule_with_warmup

In [None]:
class ELI5DatasetS2S(Dataset):
    """Dataset of ``(input_text, answer_text)`` string pairs for seq2seq QA.

    Every item pairs one question record with one of its answers. The input
    text has the form ``"question: <question> context: <document>"`` where the
    document is taken from ``document_cache`` (keyed by the record's ``q_id``)
    or, when missing there, produced by ``make_doc_fun`` and stored in the cache.

    Args:
        examples_array: Sized, indexable collection of question records. Each
            record is read through the keys ``"title"``, ``"selftext"``,
            ``"q_id"`` and ``"answers"`` (a mapping with parallel ``"text"``
            and ``"score"`` sequences).
        make_doc_fun: Optional callable that receives a question title and
            returns the support document text.
        extra_answer_threshold: In training mode, answers other than the first
            are kept only when their score is at least this value.
        document_cache: Optional dict mapping ``q_id`` to document text. It is
            used as-is (not copied), so newly built documents are added to it.
        training: When true, one item is created per selected answer; when
            false, one item per question using its first answer only.

    Raises:
        AssertionError: If neither ``make_doc_fun`` nor ``document_cache`` is given.
    """

    def __init__(
        self, examples_array, make_doc_fun=None, extra_answer_threshold=3, document_cache=None, training=True
    ):
        self.training = training
        self.data = examples_array
        self.make_doc_function = make_doc_fun
        self.document_cache = {} if document_cache is None else document_cache
        assert not (make_doc_fun is None and document_cache is None)
        # Index of (question position, answer position) pairs served by this dataset.
        if self.training:
            self.qa_id_list = [
                (i, j)
                for i, qa in enumerate(self.data)
                for j, (_, sc) in enumerate(zip(qa["answers"]["text"], qa["answers"]["score"]))
                if j == 0 or sc >= extra_answer_threshold
            ]
        else:
            # len() works for any sized collection, matching what training mode accepts.
            self.qa_id_list = [(i, 0) for i in range(len(self.data))]

    def __len__(self):
        """Return the number of (question, answer) pairs."""
        return len(self.qa_id_list)

    def make_example(self, idx):
        """Build the ``(input_text, answer_text)`` pair at position ``idx``.

        Raises:
            KeyError: If the document is not cached and no ``make_doc_fun`` was given.
        """
        i, j = self.qa_id_list[idx]
        example = self.data[i]
        question = example["title"] + " " + example["selftext"]
        answer = example["answers"]["text"][j]
        q_id = example["q_id"]
        # Only build the document on a cache miss; a `dict.get` default would be
        # evaluated eagerly and run the builder on every access.
        if self.make_doc_function is not None and q_id not in self.document_cache:
            self.document_cache[q_id] = self.make_doc_function(example["title"])
        document = self.document_cache[q_id]
        in_st = "question: {} context: {}".format(
            question.lower().replace(" --t--", "").strip(),
            document.lower().strip(),
        )
        out_st = answer
        return (in_st, out_st)

    def __getitem__(self, idx):
        """Return ``make_example(idx)``."""
        return self.make_example(idx)

def make_qa_s2s_model(model_name="facebook/bart-large", from_file=None, device="cuda:0"):
    """Load a seq2seq tokenizer and model, optionally restoring saved weights.

    Args:
        model_name: Name or path passed to ``from_pretrained`` for both the
            tokenizer and the model.
        from_file: Optional path to a checkpoint readable by ``torch.load``
            whose ``"model"`` entry is a state dict for the model.
        device: Device the model is moved to.

    Returns:
        Tuple ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    if from_file is not None:
        # map_location lets a checkpoint saved on another device (e.g. a GPU) be
        # restored when `device` is different (e.g. "cpu").
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        param_dict = torch.load(from_file, map_location=device)  # has model weights, optimizer, and scheduler states
        model.load_state_dict(param_dict["model"])
    return tokenizer, model


def make_qa_s2s_batch(qa_list, tokenizer, max_len=64, max_a_len=360, device="cuda:0"):
    """Turn a list of (question, answer) string pairs into model inputs.

    Questions are encoded with ``max_len`` and answers with
    ``min(max_len, max_a_len)``, both padded to the maximum length and
    truncated. The answer ids without their last position become
    ``decoder_input_ids``; the answer ids without their first position become
    ``labels``, with padded positions replaced by -100.

    Returns:
        Dict with ``input_ids``, ``attention_mask``, ``decoder_input_ids`` and
        ``labels`` tensors placed on ``device``.
    """

    def _encode(texts, limit):
        # Tokenise one list of strings and move ids and mask to the target device.
        encoded = tokenizer(texts, max_length=limit, padding="max_length", truncation=True)
        ids = torch.LongTensor(encoded["input_ids"]).to(device)
        mask = torch.LongTensor(encoded["attention_mask"]).to(device)
        return ids, mask

    questions = [q_text for q_text, _ in qa_list]
    answers = [a_text for _, a_text in qa_list]

    q_ids, q_mask = _encode(questions, max_len)
    a_ids, a_mask = _encode(answers, min(max_len, max_a_len))

    # Targets are the answer shifted left by one; padding is excluded via -100.
    target_is_pad = a_mask[:, 1:].contiguous() == 0
    labels = a_ids[:, 1:].contiguous().masked_fill(target_is_pad, -100)

    return {
        "input_ids": q_ids,
        "attention_mask": q_mask,
        "decoder_input_ids": a_ids[:, :-1].contiguous(),
        "labels": labels,
    }


def train_qa_s2s_epoch(model, dataset, tokenizer, optimizer, scheduler, args, e=0, curriculum=False):
    """Train ``model`` for one pass over ``dataset`` with gradient accumulation.

    Args:
        model: Module called as ``model(**batch)``; the first element of its
            output is used as the loss.
        dataset: Dataset of (question, answer) string pairs.
        tokenizer: Tokenizer forwarded to ``make_qa_s2s_batch``.
        optimizer: Optimizer stepped once per accumulation window.
        scheduler: Learning-rate scheduler stepped together with the optimizer.
        args: Object providing ``batch_size``, ``max_length``, ``backward_freq``
            and ``print_freq``.
        e: Epoch index, used only in the progress line.
        curriculum: When true, batches are drawn in dataset order instead of
            randomly.
    """
    model.train()
    # Build batches on the device that holds the model instead of assuming a GPU.
    device = next(model.parameters()).device
    # make iterator
    train_sampler = SequentialSampler(dataset) if curriculum else RandomSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device=device
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    num_batches = len(data_loader)
    # accumulate loss since last print
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    for step, batch_inputs in enumerate(epoch_iterator):
        pre_loss = model(**batch_inputs)[0]
        # .sum() collapses a per-replica loss vector (DataParallel) to a scalar
        # and leaves an already-scalar loss unchanged.
        loss = pre_loss.sum()
        loss.backward()
        # Step after every `backward_freq` accumulated batches, and once more on
        # the final batch so a trailing partial window is applied rather than
        # left in the gradients.
        window_complete = (step + 1) % args.backward_freq == 0
        if window_complete or step + 1 == num_batches:
            optimizer.step()
            scheduler.step()
            model.zero_grad()
        # some printing within the epoch
        loc_loss += loss.item()
        loc_steps += 1
        if step % args.print_freq == 0 or step == 1:
            print(
                "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                    e,
                    step,
                    len(dataset) // args.batch_size,
                    loc_loss / loc_steps,
                    time() - st_time,
                )
            )
            loc_loss = 0
            loc_steps = 0


def eval_qa_s2s_epoch(model, dataset, tokenizer, args):
    """Print the running and final average loss of ``model`` over ``dataset``.

    The model is put in eval mode and run without gradient tracking, visiting
    the dataset in order.

    Args:
        model: Module called as ``model(**batch)``; the first element of its
            output is used as the loss.
        dataset: Dataset of (question, answer) string pairs.
        tokenizer: Tokenizer forwarded to ``make_qa_s2s_batch``.
        args: Object providing ``batch_size``, ``max_length`` and ``print_freq``.
    """
    model.eval()
    # Build batches on the device that holds the model instead of assuming a GPU.
    device = next(model.parameters()).device
    # make iterator
    eval_sampler = SequentialSampler(dataset)
    model_collate_fn = functools.partial(
        make_qa_s2s_batch, tokenizer=tokenizer, max_len=args.max_length, device=device
    )
    data_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=eval_sampler, collate_fn=model_collate_fn)
    epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True)
    # accumulate loss over the whole pass
    loc_steps = 0
    loc_loss = 0.0
    st_time = time()
    with torch.no_grad():
        for step, batch_inputs in enumerate(epoch_iterator):
            pre_loss = model(**batch_inputs)[0]
            # .sum() collapses a per-replica loss vector to a scalar.
            loss = pre_loss.sum()
            loc_loss += loss.item()
            loc_steps += 1
            if step % args.print_freq == 0:
                print(
                    "{:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format(
                        step,
                        len(dataset) // args.batch_size,
                        loc_loss / loc_steps,
                        time() - st_time,
                    )
                )
    print(
        "Total \t L: {:.3f} \t -- {:.3f}".format(
            # max(..., 1) keeps an empty dataset from raising ZeroDivisionError.
            loc_loss / max(loc_steps, 1),
            time() - st_time,
        )
    )


def train_qa_s2s(qa_s2s_model, qa_s2s_tokenizer, s2s_train_dset, s2s_valid_dset, s2s_args):
    """Train the seq2seq model, evaluating and saving it after every epoch.

    Builds an AdamW optimizer and a linear warmup schedule, then for each epoch
    runs ``train_qa_s2s_epoch`` (the first epoch in dataset order, later ones
    shuffled), runs ``eval_qa_s2s_epoch`` on the validation set, and writes the
    tokenizer and model to ``./model_folder/``. Each epoch overwrites the
    previous save.

    Args:
        qa_s2s_model: Model to train, either bare or wrapped (e.g. in
            ``DataParallel``).
        qa_s2s_tokenizer: Tokenizer used for batching and saved next to the model.
        s2s_train_dset: Training dataset.
        s2s_valid_dset: Validation dataset.
        s2s_args: Object providing ``learning_rate``, ``num_epochs``,
            ``batch_size`` and ``model_save_name``, plus the fields read by the
            per-epoch helpers.
    """
    s2s_optimizer = AdamW(qa_s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8)
    s2s_scheduler = get_linear_schedule_with_warmup(
        s2s_optimizer,
        num_warmup_steps=400,
        num_training_steps=(s2s_args.num_epochs + 1) * math.ceil(len(s2s_train_dset) / s2s_args.batch_size),
    )
    for e in range(s2s_args.num_epochs):
        train_qa_s2s_epoch(
            qa_s2s_model,
            s2s_train_dset,
            qa_s2s_tokenizer,
            s2s_optimizer,
            s2s_scheduler,
            s2s_args,
            e,
            curriculum=(e == 0),
        )
        print("Saving model {}".format(s2s_args.model_save_name))
        eval_qa_s2s_epoch(qa_s2s_model, s2s_valid_dset, qa_s2s_tokenizer, s2s_args)
        qa_s2s_tokenizer.save_pretrained("./model_folder/tokenizer/")
        # A parallel wrapper keeps the real model under `.module`; a bare model
        # has no such attribute and is saved directly.
        model_to_save = getattr(qa_s2s_model, "module", qa_s2s_model)
        model_to_save.save_pretrained("./model_folder/model/")
      

# generate answer from input "question: ... context:  ..."
def qa_s2s_generate(
    question_doc,
    qa_s2s_model,
    qa_s2s_tokenizer,
    num_answers=1,
    num_beams=None,
    min_len=64,
    max_len=256,
    do_sample=False,
    temp=1.0,
    top_p=None,
    top_k=None,
    max_input_length=512,
    device="cuda:0",
):
    """Generate answers for one ``"question: ... context: ..."`` input string.

    Args:
        question_doc: Input text combining the question and its support document.
        qa_s2s_model: Model exposing ``generate``; a wrapper that stores the
            model under ``.module`` (e.g. ``DataParallel``) is also accepted.
        qa_s2s_tokenizer: Tokenizer used to encode the input and decode outputs.
        num_answers: Number of sequences to return.
        num_beams: Beam count; at least ``num_answers`` beams are always used.
            Ignored (set to 1) when ``do_sample`` is true.
        min_len: Minimum generated length.
        max_len: Maximum generated length.
        do_sample: Sample instead of beam search.
        temp: Sampling temperature.
        top_p: Nucleus sampling parameter passed to ``generate``.
        top_k: Top-k sampling parameter passed to ``generate``.
        max_input_length: Maximum number of input tokens.
        device: Device the input tensors are placed on.

    Returns:
        List of decoded, stripped answer strings.
    """
    # `generate` is defined on the underlying model, not on a parallel wrapper.
    generator = getattr(qa_s2s_model, "module", qa_s2s_model)
    # The batch helper needs an answer string; "A" is a placeholder whose
    # encoding is not used below.
    model_inputs = make_qa_s2s_batch(
        [(question_doc, "A")],
        qa_s2s_tokenizer,
        max_input_length,
        device=device,
    )
    n_beams = num_answers if num_beams is None else max(num_beams, num_answers)
    generated_ids = generator.generate(
        input_ids=model_inputs["input_ids"],
        attention_mask=model_inputs["attention_mask"],
        min_length=min_len,
        max_length=max_len,
        do_sample=do_sample,
        early_stopping=True,
        num_beams=1 if do_sample else n_beams,
        temperature=temp,
        top_k=top_k,
        top_p=top_p,
        eos_token_id=qa_s2s_tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        num_return_sequences=num_answers,
        decoder_start_token_id=qa_s2s_tokenizer.bos_token_id,
    )
    return [qa_s2s_tokenizer.decode(ans_ids, skip_special_tokens=True).strip() for ans_ids in generated_ids]

In [None]:
class ArgumentsS2S():
    """Hyper-parameters read by the seq2seq training and evaluation helpers.

    All values can be overridden by keyword; calling ``ArgumentsS2S()`` with no
    arguments gives the notebook's default configuration.

    Args:
        batch_size: Number of examples per DataLoader batch.
        backward_freq: Number of batches between optimizer steps.
        max_length: Maximum token length passed to the batching helper.
        print_freq: Number of batches between progress lines.
        model_save_name: Name printed when the model is saved.
        learning_rate: Learning rate given to the optimizer.
        num_epochs: Number of training epochs.
    """

    def __init__(
        self,
        batch_size=2,
        backward_freq=16,
        max_length=1024,
        print_freq=100,
        model_save_name="eli5_bart_model",
        learning_rate=2e-4,
        num_epochs=1,
    ):
        self.batch_size = batch_size
        self.backward_freq = backward_freq
        self.max_length = max_length
        self.print_freq = print_freq
        self.model_save_name = model_save_name
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs

s2s_args = ArgumentsS2S()

In [None]:
# Each entry of train_data / valid_data unpacks into three values; the first is
# used as the document-cache key and the second as the cached document.
s2s_train_dset = ELI5DatasetS2S(
    eli5['train_asks'],
    document_cache={q_id: doc for q_id, doc, _ in train_data},
)
s2s_valid_dset = ELI5DatasetS2S(
    eli5['validation_asks'],
    document_cache={q_id: doc for q_id, doc, _ in valid_data},
    training=False,
)

In [None]:
# DataParallel needs the wrapped module on the first CUDA device when GPUs are
# visible, so load the model there and only fall back to the CPU without a GPU.
qa_s2s_tokenizer, pre_model = make_qa_s2s_model(
    model_name="facebook/bart-large",
    from_file=None,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)
qa_s2s_model = torch.nn.DataParallel(pre_model)

In [None]:
# Train, evaluate and save the model using the objects built in the cells above.
train_qa_s2s(
    qa_s2s_model=qa_s2s_model,
    qa_s2s_tokenizer=qa_s2s_tokenizer,
    s2s_train_dset=s2s_train_dset,
    s2s_valid_dset=s2s_valid_dset,
    s2s_args=s2s_args,
)

## Loading the model and generating answers

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM,DistilBertTokenizer, BartTokenizer
# NOTE(review): the star imports presumably supply `Elasticsearch` and
# `query_es_index` used in the cells below, and may also rebind names defined
# earlier in this notebook (e.g. `qa_s2s_generate`) -- confirm against those modules.
from lfqa_utils import *
from eli5_utils import *

In [None]:
# Reload the tokenizer and model from the directories written by train_qa_s2s.
tokenizer = BartTokenizer.from_pretrained("./model_folder/tokenizer/")
model = AutoModelForSeq2SeqLM.from_pretrained("./model_folder/model/")

In [None]:
# Client for the local Elasticsearch server and the snippet dataset stored on disk.
es_client = Elasticsearch("http://localhost:9200")
wiki40b_snippets = datasets.load_from_disk("../wiki40b/train")

In [None]:
# Retrieve a support document for one question and generate an answer for it.
question = "Why do ears hurt in planes?"
doc, res_list = query_es_index(question, es_client, index_name='wiki40b_snippets_100w')
question_doc = "question: {} context: {}".format(question, doc)

# Two answers are generated; only the first one is kept.
answer = qa_s2s_generate(
    question_doc=question_doc,
    qa_s2s_model=model,
    qa_s2s_tokenizer=tokenizer,
    num_answers=2,
    num_beams=8,
    min_len=64,
    max_len=256,
    max_input_length=1024,
    device="cpu",
)[0]

answer

## Evaluation

In [None]:
predicted = []
reference = []

# `generate` lives on the underlying model, not on the DataParallel wrapper
# created for training, so unwrap it if needed.
eval_model = getattr(qa_s2s_model, "module", qa_s2s_model)
# Encode inputs on the device that holds the model weights instead of assuming a GPU.
eval_device = next(eval_model.parameters()).device

# Generate answers for the full test set
for i in range(eli5['test_asks'].num_rows):
    # fetch the row once and reuse it for both the question and the reference
    example = eli5['test_asks'][i]
    # create support document with the Elasticsearch index
    question = example['title']
    doc, res_list = query_es_index(question, es_client, index_name='wiki40b_snippets_100w')
    # concatenate question and support document into BART input
    question_doc = "question: {} context: {}".format(question, doc)
    # generate an answer with beam search
    answer = qa_s2s_generate(
            question_doc, eval_model, qa_s2s_tokenizer,
            num_answers=1,
            num_beams=8,
            min_len=96,
            max_len=256,
            max_input_length=1024,
            device=eval_device
    )[0]
    predicted.append(answer)
    # the first listed answer is used as the reference
    reference.append(example['answers']['text'][0])

In [None]:
# Score the generated answers against the references with ROUGE.
rouge = datasets.load_metric("rouge")
# Use the metric object loaded above (the previous name `nlp_rouge` was never defined).
scores = rouge.compute(predictions=predicted, references=reference)
# One column per ROUGE variant; rows are precision, recall and F-measure of the `mid` aggregate.
df = pd.DataFrame(
    {
        name: [scores[name].mid.precision, scores[name].mid.recall, scores[name].mid.fmeasure]
        for name in ('rouge1', 'rouge2', 'rougeL')
    },
    index=['P', 'R', 'F'],
)
df.style.format({name: "{:.4f}" for name in ('rouge1', 'rouge2', 'rougeL')})