# Load data and data analysis

In [None]:
# TODO: this section is unfinished (presumably the data analysis is still missing -- confirm)
# TODO: maybe load the dataset as follows from huggingface
# https://huggingface.co/docs/datasets/v1.9.0/loading_datasets.html
from datasets import load_dataset
# Load the 'squad' dataset; no split is requested here.
dataset = load_dataset('squad')

# Set up generative model (open/closed model)

In [None]:
from transformers import AutoTokenizer, OPTForCausalLM, pipeline

class GenerativeModel:
    """Question answering by free text generation with ``facebook/opt-1.3b``.

    Two prompting modes are offered:

    * "open":   the prompt contains a context passage followed by the question.
    * "closed": the prompt is the question alone.

    Each mode can generate either through a ``transformers`` text-generation
    pipeline or by calling ``model.generate`` directly.
    """

    MODEL_NAME = "facebook/opt-1.3b"

    def __init__(self, max_answer_length: int) -> None:
        """Load the tokenizer, the model and the text-generation pipeline.

        Args:
            max_answer_length: Passed as ``max_new_tokens`` to every
                generation call, i.e. the maximum number of generated tokens.
        """
        self.model = OPTForCausalLM.from_pretrained(self.MODEL_NAME)
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        # Hand the already-loaded model/tokenizer to the pipeline instead of the
        # model name, so the weights are not loaded into memory a second time.
        self.generator = pipeline('text-generation', model=self.model, tokenizer=self.tokenizer)
        self.max_answer_length = max_answer_length

    def get_open_model_answer(self, question: str, context: str, use_pipeline: bool = False) -> str:
        """Answer ``question`` with ``context`` included in the prompt.

        Args:
            question: The question text.
            context: The passage placed in front of the question.
            use_pipeline: Generate via the pipeline instead of ``model.generate``.

        Returns:
            The generated text with the leading prompt removed.
        """
        prompt = f"CONTEXT:\n{context}\nQUESTION:\n{question}"
        return self._answer_prompt(prompt, use_pipeline)

    def get_closed_model_answer(self, question: str, use_pipeline: bool = False) -> str:
        """Answer ``question`` using the question alone as the prompt.

        Args:
            question: The question text.
            use_pipeline: Generate via the pipeline instead of ``model.generate``.

        Returns:
            The generated text with the leading prompt removed.
        """
        return self._answer_prompt(question, use_pipeline)

    def get_open_model_answers_batched(self, questions, contexts):
        """Batched variant of :meth:`get_open_model_answer` (not implemented yet)."""
        # COMBAK
        raise NotImplementedError("batched open-model answering is not implemented yet")

    def get_closed_model_answers_batched(self, questions):
        """Batched variant of :meth:`get_closed_model_answer` (not implemented yet).

        This stub must not be named ``get_closed_model_answer``: a second
        definition with that name would replace the single-question method.
        """
        # COMBAK
        raise NotImplementedError("batched closed-model answering is not implemented yet")

    def _answer_prompt(self, prompt: str, use_pipeline: bool) -> str:
        """Generate a continuation of ``prompt`` and strip the prompt from it."""
        generated = self._generate_answer(prompt, use_pipeline)
        # Both generation paths return prompt + continuation; keep the continuation.
        return generated.removeprefix(prompt)

    def _generate_answer(self, prompt: str, use_pipeline: bool) -> str:
        """Return the generated text (prompt included) for ``prompt``.

        Args:
            prompt: Text to continue.
            use_pipeline: If true, call ``self.generator``; otherwise tokenize,
                call ``self.model.generate`` and decode the first sequence.
        """
        if use_pipeline:
            return self.generator(prompt, max_new_tokens=self.max_answer_length)[0]['generated_text']

        inputs = self.tokenizer(prompt, return_tensors="pt")
        generate_ids = self.model.generate(
            inputs.input_ids,
            # Pass the mask explicitly rather than letting generate() infer it.
            attention_mask=inputs.attention_mask,
            max_new_tokens=self.max_answer_length,
        )
        return self.tokenizer.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

In [None]:
# Build the wrapper (this loads the OPT model); 42 is forwarded as max_new_tokens
# to every generation call, i.e. it caps the length of each generated answer.
generative_model = GenerativeModel(max_answer_length=42)

## Temporary test of the model

In [None]:
# NOTE: quick check of the model on a handful of SQuAD examples
from datasets import load_dataset

# Only the first 10 entries of the *validation* split are loaded.
test_set = load_dataset('squad', split='validation[:10]')
# Parallel lists of strings: one context passage and one question per example.
test_contexts, test_questions = test_set['context'], test_set['question']
# One list of reference answer texts per example.
test_answers = [answer_entry['text'] for answer_entry in test_set['answers']]

In [None]:
# Test for a single question: show the first example (context, question and
# reference answers), then print what the model generates for it in each mode.
print(f"Context: {test_contexts[0]}\nQuestion: {test_questions[0]}\nCorrect answer: {test_answers[0]}")
# Closed mode: only the question is passed to the model.
print("Closed generative Model:")
print(generative_model.get_closed_model_answer(test_questions[0]))
# Open mode: the question is passed together with its context passage.
print("Open generative Model:")
print(generative_model.get_open_model_answer(test_questions[0], test_contexts[0]))

# Evaluation

In [None]:
# TODO