In [None]:
# Install Lightning Flash (with the text extras) from GitHub. The %pip magic
# installs into the environment of the running kernel, which a bare `!pip`
# shell call does not guarantee.
%pip install -q 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[text]'

In [None]:
from flash import Trainer
from flash.text import QuestionAnsweringData, QuestionAnsweringTask
from flash.text.question_answering.input import QuestionAnsweringInputBase, QuestionAnsweringDictionaryInput
import pandas as pd
import json
from typing import Union

# Load short and long answers

In [None]:
def load_data(file_path, questions_start, questions_end, short_answer=True):
    """Load a window of a JSON-lines question-answering file into a DataFrame.

    Lines with zero-based index in ``[questions_start, questions_end)`` are
    parsed with ``json.loads`` and handed to ``process_data``. Entries for which
    ``process_data`` returns nothing (or an empty answer) are dropped.

    Args:
        file_path: Path of the JSON-lines file (one JSON object per line).
        questions_start: Zero-based index of the first line to read.
        questions_end: Zero-based index one past the last line to read. A value
            beyond the end of the file simply reads up to the last line.
        short_answer: Forwarded unchanged to ``process_data``.

    Returns:
        pandas.DataFrame with the columns ``id``, ``title``, ``context``,
        ``question`` and ``answer`` (the answer text), one row per kept entry.
    """
    columns = ["id", "title", "context", "question", "answer"]
    rows = []

    with open(file_path) as file:
        for line_number, raw_line in enumerate(file):
            if line_number >= questions_end:
                break  # past the requested window; no need to read further
            if line_number < questions_start:
                continue  # skip the lines that precede the requested window
            if not raw_line.strip():
                continue  # tolerate blank lines such as a trailing newline

            result = process_data(json.loads(raw_line), short_answer)
            if not result:
                continue
            id_, title, context, question, answer = result
            if answer:
                rows.append((id_, title, context, question, answer['text']))

    return pd.DataFrame(rows, columns=columns)
                
def process_data(entry, short_answer=True):
    """Turn one parsed JSON entry into a SQuAD-style question-answering example.

    The entry must provide the keys ``question_text``, ``document_text``,
    ``example_id``, ``long_answer_candidates`` and ``annotations``. Only the
    first annotation is used. Token indices refer to ``document_text`` split on
    single spaces.

    Args:
        entry: Dict parsed from one line of the input file.
        short_answer: When true, the first short answer is returned if the
            annotation has one with non-empty text; otherwise (or when false)
            the annotated long answer is returned.

    Returns:
        ``None`` when the annotation's ``candidate_index`` does not point at one
        of ``long_answer_candidates``. Otherwise a tuple
        ``(id, title, context, question, answer)`` where ``title`` is always
        ``''``, ``context`` is the full document text and ``answer`` is
        ``{"text": str, "answer_start": [int]}`` with ``answer_start`` being the
        character offset of the answer text inside ``context``.
    """
    annotation = entry['annotations'][0]
    candidates = entry['long_answer_candidates']
    candidate_index = annotation['long_answer']['candidate_index']
    # An index outside the candidate list (e.g. a negative one) means there is
    # no usable long answer, hence nothing to extract.
    if not 0 <= candidate_index < len(candidates):
        return None

    context = entry['document_text']
    tokens = context.split(' ')

    def span_text(span):
        # Text covered by a {"start_token", "end_token"} span of the document.
        return ' '.join(tokens[span['start_token']:span['end_token']])

    chosen_span = candidates[candidate_index]
    short_spans = annotation['short_answers']
    if short_answer and short_spans and span_text(short_spans[0]):
        chosen_span = short_spans[0]

    # Tokens are separated by exactly one space, so the character offset of a
    # token is the total length of all preceding tokens plus one space each.
    answer_start = sum(len(token) + 1 for token in tokens[:chosen_span['start_token']])
    answer = {"text": span_text(chosen_span), "answer_start": [answer_start]}

    return (entry['example_id'], '', context, entry['question_text'], answer)

In [None]:
# Preview lines 0-4 of the training file, requesting short answers.
load_data(
    file_path='../input/tensorflow2-question-answering/simplified-nq-train.jsonl',
    questions_start=0,
    questions_end=5,
    short_answer=True,
)

In [None]:
# Preview lines 0-4 of the training file, requesting long answers.
load_data(
    file_path='../input/tensorflow2-question-answering/simplified-nq-train.jsonl',
    questions_start=0,
    questions_end=5,
    short_answer=False,
)

In [None]:
# https://github.com/PyTorchLightning/lightning-flash/blob/052ed5299ac08e0cf94fa5b1697d64a97bbbe06e/flash/text/question_answering/input.py#L292
# SQuAD data

# Using Flash

In [None]:
class QuestionAnsweringTFInput(QuestionAnsweringDictionaryInput):
    """Flash input that reads question-answering examples from a JSON-lines file.

    Each line is parsed, converted to a SQuAD-style example by
    ``_process_data`` and collected into a column dictionary that is handed to
    the parent class's ``load_data``.
    """

    def _process_data(self, entry, short_answer=True):
        """Turn one parsed JSON entry into a SQuAD-style example.

        The entry must provide the keys ``question_text``, ``document_text``,
        ``example_id``, ``long_answer_candidates`` and ``annotations``. Only the
        first annotation is used. Token indices refer to ``document_text``
        split on single spaces.

        Args:
            entry: Dict parsed from one line of the input file.
            short_answer: When true, the first short answer is returned if the
                annotation has one with non-empty text; otherwise (or when
                false) the annotated long answer is returned.

        Returns:
            ``None`` when the annotation's ``candidate_index`` does not point
            at one of ``long_answer_candidates``. Otherwise a tuple
            ``(id, title, context, question, answer)`` where ``title`` is
            always ``''``, ``context`` is the full document text and ``answer``
            is ``{"text": str, "answer_start": [int]}`` with ``answer_start``
            being the character offset of the answer text inside ``context``.
        """
        annotation = entry['annotations'][0]
        candidates = entry['long_answer_candidates']
        candidate_index = annotation['long_answer']['candidate_index']
        # An index outside the candidate list (e.g. a negative one) means there
        # is no usable long answer, hence nothing to extract.
        if not 0 <= candidate_index < len(candidates):
            return None

        context = entry['document_text']
        tokens = context.split(' ')

        def span_text(span):
            # Text covered by a {"start_token", "end_token"} span of the document.
            return ' '.join(tokens[span['start_token']:span['end_token']])

        chosen_span = candidates[candidate_index]
        short_spans = annotation['short_answers']
        if short_answer and short_spans and span_text(short_spans[0]):
            chosen_span = short_spans[0]

        # Tokens are separated by exactly one space, so the character offset of
        # a token is the total length of all preceding tokens plus one space each.
        answer_start = sum(len(token) + 1 for token in tokens[:chosen_span['start_token']])
        answer = {"text": span_text(chosen_span), "answer_start": [answer_start]}

        return (entry['example_id'], '', context, entry['question_text'], answer)

    def load_data(
        self,
        json_file,
        max_source_length: int = 384,
        max_target_length: int = 30,
        padding: Union[str, bool] = "max_length",
        question_column_name: str = "question",
        context_column_name: str = "context",
        answer_column_name: str = "answer",
        doc_stride: int = 128,
        **kws
    ):
        """Read the first lines of ``json_file`` and delegate to the parent loader.

        Args:
            json_file: Path of the JSON-lines file (one JSON object per line).
            max_source_length, max_target_length, padding, question_column_name,
            context_column_name, answer_column_name, doc_stride: Forwarded
                unchanged to the parent class's ``load_data``.
            **kws: Accepted for call compatibility and ignored.

        Returns:
            Whatever the parent class's ``load_data`` returns for the collected
            column dictionary.
        """
        max_examples = 5  # only a small sample of the file is read

        columns = {"id": [], "title": [], "context": [], "question": [], "answer": []}

        with open(json_file) as stream:
            # zip() with the range first stops after `max_examples` lines or at
            # end of file, whichever comes first, without over-reading.
            for _, raw_line in zip(range(max_examples), stream):
                if not raw_line.strip():
                    continue  # tolerate blank lines such as a trailing newline
                result = self._process_data(json.loads(raw_line), short_answer=True)
                if not result:
                    continue
                id_, title, context, question, answer = result
                if not answer:
                    continue
                columns["id"].append(id_)
                columns["title"].append(title)
                columns["context"].append(context)
                columns["question"].append(question)
                # SQuAD-format answers hold parallel lists of answer texts and
                # character offsets, so the single text is wrapped in a list.
                # NOTE(review): confirm against the installed Flash version.
                columns["answer"].append(
                    {"text": [answer["text"]], "answer_start": answer["answer_start"]}
                )

        return super().load_data(
            columns,
            max_source_length=max_source_length,
            max_target_length=max_target_length,
            padding=padding,
            question_column_name=question_column_name,
            context_column_name=context_column_name,
            answer_column_name=answer_column_name,
            doc_stride=doc_stride,
        )

In [None]:
# Step 1: build the DataModule from the training file, routing the raw
# JSON lines through the custom input class.
datamodule = QuestionAnsweringData.from_json(
    input_cls=QuestionAnsweringTFInput,
    train_file='../input/tensorflow2-question-answering/simplified-nq-train.jsonl',
    max_source_length=128,
    doc_stride=64,
    batch_size=1,
)

# Step 2: build the question-answering task on the chosen backbone.
model = QuestionAnsweringTask(backbone='distilroberta-base')

# Step 3: create the trainer and fine-tune the model (one epoch, three
# training batches, validation disabled).
trainer = Trainer(
    max_epochs=1,
    gpus=1,
    precision=16,
    limit_train_batches=3,
    limit_val_batches=0,
)
trainer.fit(model, datamodule=datamodule)
