In this example, the trained models are saved in a Google Drive folder named `fine_tuned_models_data`.

In [None]:
# Mount Google Drive so files stored there become accessible under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Switch into the Drive folder holding the fine-tuned models and data.
# NOTE(review): relative path — assumes the current directory is /content (the parent of
# the mount point above); confirm when running outside a fresh Colab session.
%cd gdrive/MyDrive/fine_tuned_models_data

In [None]:
# Hugging Face Transformers provides the T5 tokenizer/model classes used below.
!pip install transformers

In [None]:
# PyTorch backend for the models; sentencepiece is presumably required by the T5 tokenizer.
!pip install torch
!pip install sentencepiece

In [None]:
# List the folder contents, e.g. to check that the dataset and model folders are present.
!ls

# Code to set up the pipeline

In [None]:
def generate_qa_input(question, context):
    """
    Build the T5 prompt used to answer ``question`` from ``context``.

    :param question: Question to answer
    :param context: Context to find the answer in
    :return: Prompt string of the form "question: <question>  context: <context>"
    """
    template = "question: {}  context: {}"
    return template.format(question, context)

In [None]:
from transformers import T5TokenizerFast, T5ForConditionalGeneration, StoppingCriteriaList, MaxLengthCriteria
import torch

class Pipeline:
    """
    Pipeline for using a T5 model for question answering.

    Wraps a tokenizer/model pair loaded via ``get_tokenizer`` / ``get_model`` and turns a
    question plus its context into one or more generated answer strings.
    """
    # Model identifiers this pipeline was written for. Membership is not enforced:
    # other names (e.g. custom checkpoints) are accepted as well.
    VALID_MODELS = ["t5-small", "t5-base", "t5-large", "google/t5-small-ssm-nq", "google/t5-base-ssm-nq",
                    "google/t5-large-ssm-nq", "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
                    "t5-small-ssm-nq", "t5-large-ssm-nq"]

    def __init__(self, model, is_fine_tuned, model_max_length=1024, use_cuda=False):
        """
        :param model: Model name/identifier passed to get_tokenizer and get_model
        :param is_fine_tuned: Load the locally fine-tuned checkpoint instead of the pretrained model
        :param model_max_length: Maximum prompt length in tokens; longer prompts are truncated
        :param use_cuda: Run on the GPU when CUDA is available
        """
        self.model_name = model if not is_fine_tuned else f"{model}_fine_tuned"
        self.tokenizer = get_tokenizer(model, model_max_length=model_max_length)
        self.model = get_model(model, is_fine_tuned)

        self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"
        self.model.to(self.device)

    def answer_question(self, question, context, num_answers=1):
        """
        Generate answers to ``question`` about ``context`` using beam search.

        :param question: Question to answer
        :param context: Context to find the answer in
        :param num_answers: Number of answers (beams) to return; a non-positive value
                            returns an empty list without running the model
        :return: List of ``num_answers`` decoded answer strings
        """
        if num_answers <= 0:
            return []
        prompt = generate_qa_input(question, context)

        # max_length is deliberately not passed: with truncation=True the tokenizer then
        # truncates to its model_max_length (configured in __init__), measured in tokens.
        # Passing len(prompt) (a character count) would disable truncation entirely.
        features = self.tokenizer(
            prompt,
            padding="longest",
            truncation=True,
            return_tensors="pt",
        )

        # Beam search returning every beam. Contrastive search, multinomial sampling and
        # sampled beam search were tried during experimentation; this is the one in use.
        model_output = self.model.generate(
            input_ids=features["input_ids"].to(self.device),
            attention_mask=features["attention_mask"].to(self.device),
            # Generous upper bound on the answer length: the context's character count.
            max_new_tokens=len(context),
            num_beams=num_answers,
            num_return_sequences=num_answers,
        )

        return [self.tokenizer.decode(out, skip_special_tokens=True) for out in model_output]

    def __repr__(self) -> str:
        return f"{type(self).__name__}(model={self.model_name})"


def get_tokenizer(model, model_max_length=512):
    """
    Load the fast T5 tokenizer to use for ``model``.

    All "ssm-nq" variants share the tokenizer of google/t5-large-ssm-nq and all "flan-t5"
    variants share the one of google/flan-t5-base; any other name is loaded as given.
    :param model: Model name/identifier
    :param model_max_length: Maximum sequence length (in tokens) configured on the tokenizer
    :return: T5TokenizerFast instance
    """
    if "ssm-nq" in model:
        source = "google/t5-large-ssm-nq"
    elif "flan-t5" in model:
        source = "google/flan-t5-base"
    else:
        source = model
    return T5TokenizerFast.from_pretrained(source, model_max_length=model_max_length)


def get_model(model, is_fine_tuned):
    """
    Load a T5ForConditionalGeneration model.

    :param model: Model name/identifier
    :param is_fine_tuned: If true, load the local checkpoint "./trained models/<model>-trained"
                          instead of ``model`` itself
    :return: T5ForConditionalGeneration instance
    """
    if not is_fine_tuned:
        return T5ForConditionalGeneration.from_pretrained(model)
    checkpoint_dir = f"./trained models/{model}-trained"
    print("Loading fine tuned model from", checkpoint_dir)
    return T5ForConditionalGeneration.from_pretrained(checkpoint_dir)

In [None]:
# Enum for categorizing questions
from enum import Enum


class QuestionType(str, Enum):
    """
    Categories of questions asked about an email.

    Each value is the canonical wording of the question; thanks to the str mixin,
    members are also str instances equal to their question text.
    """
    is_mentioned = "Who is mentioned?"
    is_attached = "What is attached to the email?"
    is_requested = "What is requested?"
    requested_who = "Who is the requester?"
    subject = "What is the subject of the email?"
    receiver_who = "Who received the email?"
    receiver_main = "Who is the main recipient?"
    sender_who = "Who sent the email?"
    is_event = "What event is described?"
    event_when = "When does the event take place?"
    event_where = "Where does the event take place?"
    email_address_of = "What is the email address of x?"
    phone_number_of = "What is the phone number of x?"
    who_is = "Who is x?"
    other = "Other"


# Ordered fallback heuristics for questions that are not an exact enum value.
# The first predicate that matches decides the category, so the order matters.
_QUESTION_RULES = (
    (lambda q: q.startswith("What is the email address of"), QuestionType.email_address_of),
    (lambda q: q.startswith(("What is the phone number of", "What are the phone numbers")),
     QuestionType.phone_number_of),
    (lambda q: q.startswith("Who is"), QuestionType.who_is),
    (lambda q: ("When" in q and ("take place" in q or "event" in q)) or q.startswith("When was"),
     QuestionType.event_when),
    (lambda q: "Where" in q and "take place" in q, QuestionType.event_where),
    # Legacy wordings that are no longer used in the data; candidates for removal.
    (lambda q: q == "Who also received the email?", QuestionType.receiver_who),
    (lambda q: q in ("What entities are mentioned?", "What is mentioned?"), QuestionType.is_mentioned),
    (lambda q: q.startswith("Who requested"), QuestionType.requested_who),
)


def categorize_question(question):
    """
    Map a question string to its QuestionType.

    An exact match against the enum values wins; otherwise the heuristics in
    _QUESTION_RULES are tried in order. Questions matching nothing are reported on
    stdout and mapped to QuestionType.other.
    :param question: Question text
    :return: The matching QuestionType member
    """
    try:
        return QuestionType(question)
    except ValueError:
        pass

    for matches, category in _QUESTION_RULES:
        if matches(question):
            return category

    print("Uncategorized question:", question)
    return QuestionType.other

In [None]:
# Folder containing the dataset, relative to the current working directory.
DATASET_PATH = "./"
# Dataset file name without the ".json" extension (main() appends it).
DATASET_NAME = "valid_output_preprocessed"
# Model identifier passed to Pipeline.
MODEL_TO_USE = "t5-large"
# Model name used when main() builds the output file name.
MODEL_FILE_NAME = MODEL_TO_USE
# Load the locally fine-tuned checkpoint instead of the pretrained model.
IS_FINE_TUNED = True
# If True, questions without a gold answer get a placeholder "None" answer so the model is
# still asked for one answer; if False, such questions receive no model answers.
WITH_NONE = False
# Run on the GPU when CUDA is available.
USE_CUDA = True

In [None]:
import json

def main(output_path=None):
    """
    Answer every question of the configured dataset and save the results as JSON.

    Configuration is read from the module-level settings (DATASET_PATH, DATASET_NAME,
    MODEL_TO_USE, MODEL_FILE_NAME, IS_FINE_TUNED, WITH_NONE, USE_CUDA).
    :param output_path: File to write the answers to. Defaults to a name derived from the
                        dataset, model, fine-tuning and with_none settings.
    """
    print(f"Starting answering pipeline for model {MODEL_TO_USE} (fine_tuned={IS_FINE_TUNED}, with_none={WITH_NONE})")
    model = Pipeline(MODEL_TO_USE, is_fine_tuned=IS_FINE_TUNED, use_cuda=USE_CUDA, model_max_length=4096)
    answers = answer_questions(DATASET_PATH + DATASET_NAME + ".json", model, with_none=WITH_NONE)
    if output_path is None:
        output_path = _answers_file_name()
    # Serialize before opening the file so a serialization error cannot leave a truncated file.
    payload = json.dumps(answers, indent=2)
    with open(output_path, "w", encoding="utf-8") as file:
        file.write(payload)


def _answers_file_name():
    """Build the default output file name from the module-level configuration."""
    fine_tuned = "_fine_tuned" if IS_FINE_TUNED else ""
    none_mode = "with" if WITH_NONE else "without"
    return f"{DATASET_NAME}_{MODEL_FILE_NAME}{fine_tuned}_model_answers_{none_mode}_none.json"
    # print(json.dumps(answers, indent=2))


def answer_questions(path, model, doc_limit=0, with_none=True):
    """
    Load the QA data at ``path`` and let ``model`` answer every question in it.

    Each question dict gains a "model_answers" list with as many generated answers as
    the question has gold answers.
    :param path: Path to the qa data
    :param model: Object with an answer_question(question, context, num_answers=...) method,
                  e.g. a Pipeline
    :param doc_limit: If greater than 0: Limit of documents to load
    :param with_none: If True, a question without gold answers gets the placeholder answer
                      {"answer_start": 0, "text": "None"} so the model is asked for one answer;
                      otherwise such a question gets no model answers
    :return: The contexts returned by load_qas, with "model_answers" filled in
    """
    qas = load_qas(path) if doc_limit <= 0 else load_qas(path, doc_limit)
    total = len(qas)
    for index, entry in enumerate(qas, start=1):
        for question in entry["questions"]:
            gold = question["answers"]
            if with_none and not gold:
                gold.append({"answer_start": 0, "text": "None"})
            question["model_answers"] = model.answer_question(
                question["question"], entry["context"], num_answers=len(gold)
            )
        print(f"Answered {index} of {total} documents")
    return qas


def load_qas(path, limit=999999):
    """
    Load a SQuAD-like QA JSON file and flatten it into a list of contexts.

    :param path: Path to a JSON file with a top-level "data" list whose entries contain
                 "paragraphs", each holding a "context" string and a "qas" list
    :param limit: Maximum number of entries of "data" to read
    :return: [{
        context: String,
        questions: [{
            question: String,
            question_category: QuestionType,
            answers: [answer dicts as stored in the file]
        }]
    }]
    """
    # JSON is UTF-8 by specification; don't depend on the platform's default encoding.
    with open(path, encoding="utf-8") as file:
        documents = json.load(file)["data"]

    qas = []
    for document in documents[:limit]:
        for paragraph in document["paragraphs"]:
            qas.append({
                "context": paragraph["context"],
                "questions": [
                    {
                        "question": qa["question"],
                        "question_category": categorize_question(qa["question"]),
                        # Shallow copy so callers can append to it without touching the loaded data.
                        "answers": list(qa["answers"]),
                    }
                    for qa in paragraph["qas"]
                ],
            })
    return qas

In [None]:
# Override the configuration defined earlier for this run.
MODEL_TO_USE = "t5-small"
MODEL_FILE_NAME = MODEL_TO_USE
IS_FINE_TUNED = True
WITH_NONE = True
# Scratch output file name ("egal" is German for "doesn't matter").
FILE = "egal"

In [None]:
import time

# Run the pipeline and report how long it took.
# perf_counter() is monotonic, so the measurement is unaffected by system clock adjustments
# (unlike time.time()).
start = time.perf_counter()
main()
end = time.perf_counter()
print("finished execution, total time elapsed in seconds:", end - start)