In [None]:
! pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[text]' -q

In [None]:
import json
from itertools import islice
from typing import Union

import pandas as pd
from datasets import Dataset, load_dataset

from flash import Trainer
from flash.text import QuestionAnsweringData, QuestionAnsweringTask
from flash.text.question_answering.input import QuestionAnsweringInputBase, QuestionAnsweringDictionaryInput

In [None]:
def process_data(entry, short_answer=True):
    """Convert one simplified Natural Questions record into a SQuAD-style sample.

    Args:
        entry: one parsed JSONL record; the keys read are ``question_text``,
            ``document_text``, ``annotations``, ``example_id`` and
            ``long_answer_candidates``.
        short_answer: when True (default), use the first annotated short answer
            and fall back to the long answer if there is none (or it is empty);
            when False, always use the long answer.

    Returns:
        ``(id, title, context, question, answer)`` where ``title`` is always
        ``''``, ``context`` is ``document_text`` unchanged and ``answer`` is
        ``{"text": [span_text], "answer_start": [char_offset]}`` with
        ``char_offset`` a character offset into ``context``.
        Returns ``None`` when the record has no annotation or its annotated
        long-answer candidate index is out of range.
    """
    annotations = entry.get('annotations') or []
    if not annotations:
        return None
    annotation = annotations[0]

    candidates = entry['long_answer_candidates']
    candidate_index = annotation['long_answer']['candidate_index']
    # Guard explicitly: a negative index (-1 presumably marks "no long answer"
    # in NQ) must not be treated as Python's "last element".
    if not 0 <= candidate_index < len(candidates):
        return None

    # str.split(' ') and ' '.join are exact inverses, so token i starts at the
    # sum of (len(token) + 1) over all preceding tokens in `context`.
    context = entry['document_text']
    tokens = context.split(' ')

    def span(start_token, end_token):
        """Return (text, char_offset_in_context) for tokens[start_token:end_token]."""
        text = ' '.join(tokens[start_token:end_token])
        offset = sum(len(tok) + 1 for tok in tokens[:start_token])
        return text, offset

    answer_text, answer_start = '', 0
    short_answers = annotation['short_answers']
    if short_answer and short_answers:
        first = short_answers[0]
        answer_text, answer_start = span(first['start_token'], first['end_token'])
    if not answer_text:
        # Short answers disabled, absent, or empty: use the annotated long answer.
        candidate = candidates[candidate_index]
        answer_text, answer_start = span(candidate['start_token'], candidate['end_token'])

    return (
        entry['example_id'],
        '',
        context,
        entry['question_text'],
        {"text": [answer_text], "answer_start": [answer_start]},
    )

In [None]:
class QuestionAnsweringTFInput(QuestionAnsweringInputBase):
    """Flash input that reads a simplified-NQ style JSONL file (one JSON record per line).

    Only the first ``num_samples`` lines are read. Each record is converted with
    ``process_data``, and records it rejects (falsy result) are skipped.
    """

    # Upper bound on JSONL lines read; override on a subclass/instance to change it.
    num_samples: int = 1000

    def load_data(
        self,
        json_file,
        field: str,
        max_source_length: int = 384,
        max_target_length: int = 30,
        padding: Union[str, bool] = "max_length",
        question_column_name: str = "question",
        context_column_name: str = "context",
        answer_column_name: str = "answer",
        doc_stride: int = 128,
    ):
        """Read up to ``num_samples`` records from ``json_file`` and delegate to the base class.

        Args:
            json_file: path to the ``.jsonl`` file.
            field: accepted for interface compatibility; not used by this reader.
            question_column_name, context_column_name, answer_column_name: column
                names used for the intermediate ``Dataset``; the same names are
                forwarded to the base class so it reads the columns built here.
            max_source_length, max_target_length, padding, doc_stride: forwarded
                unchanged to ``QuestionAnsweringInputBase.load_data``.

        Returns:
            Whatever ``QuestionAnsweringInputBase.load_data`` returns.
        """
        ids, titles, contexts, questions, answers = [], [], [], [], []
        # Explicit utf-8: the default locale encoding is platform dependent.
        with open(json_file, encoding="utf-8") as f:
            # islice stops cleanly at EOF when the file has fewer than num_samples
            # lines (a readline() loop would call json.loads('') and crash there).
            for raw_line in islice(f, self.num_samples):
                if not raw_line.strip():
                    continue
                result = process_data(json.loads(raw_line), short_answer=True)
                if not result:
                    continue
                id_, title, context, question, answer = result
                ids.append(id_)
                titles.append(title)
                contexts.append(context)
                questions.append(question)
                answers.append(answer)

        # Key the columns by the configured names so they match what the base
        # class is told to read below.
        data = {
            "id": ids,
            "title": titles,
            context_column_name: contexts,
            question_column_name: questions,
            answer_column_name: answers,
        }

        return super().load_data(
            Dataset.from_dict(data),
            max_source_length=max_source_length,
            max_target_length=max_target_length,
            padding=padding,
            question_column_name=question_column_name,
            context_column_name=context_column_name,
            answer_column_name=answer_column_name,
            doc_stride=doc_stride,
        )

In [None]:
# 1. Create the DataModule; val_split=0.1 holds out part of the loaded records for validation.
datamodule = QuestionAnsweringData.from_json(
    train_file='../input/tensorflow2-question-answering/simplified-nq-train.jsonl',
    input_cls=QuestionAnsweringTFInput,
    batch_size=1,
    val_split=0.1,
)

# 2. Build the task
model = QuestionAnsweringTask(backbone='prajjwal1/bert-tiny')

# 3. Create the trainer and finetune the model.
# limit_val_batches must be non-zero: with 0 the trainer.validate() call below
# runs no batches and the held-out split is never evaluated.
trainer = Trainer(max_epochs=1, gpus=1, precision=16, limit_train_batches=10, limit_val_batches=10)
trainer.fit(model, datamodule=datamodule)

# 4. Evaluate on the held-out split (passed by keyword, as in fit above).
trainer.validate(model, datamodule=datamodule)