In [None]:
pip install transformers

In [None]:
pip install datasets

In [4]:
# TriviaQA

import random

from datasets import load_dataset

class QABenchmark:
    """Base container for question-answering benchmarks.

    Subclasses fill ``self.dataset`` with ``(question, answer_or_answers)`` tuples;
    this class only provides ways to pick a subset of them.
    """

    def __init__(self):
        # Populated by subclasses.
        self.dataset = []

    def sample(self, k: int, seed=None):
        """Return up to ``k`` distinct items drawn uniformly without replacement.

        Args:
            k: number of items requested; clamped to ``[0, len(self.dataset)]``.
            seed: optional seed for a private ``random.Random`` so the draw is
                reproducible without touching the global RNG. When ``None``
                (the default) the module-level ``random`` generator is used.

        Returns:
            A new list with ``min(max(k, 0), len(self.dataset))`` items.
        """
        count = max(0, min(k, len(self.dataset)))
        rng = random if seed is None else random.Random(seed)
        return rng.sample(self.dataset, count)

    def first_k(self, k: int):
        """Return the first ``k`` items in dataset order (deterministic); negative ``k`` yields ``[]``."""
        return self.dataset[:max(0, k)]


class TriviaQA(QABenchmark):
    """TriviaQA questions paired with every accepted answer string."""

    def __init__(self, split: str = 'validation', config: str = 'rc'):
        """Load ``split`` of the Hugging Face ``trivia_qa`` dataset.

        Args:
            split: dataset split to load, e.g. 'train' or 'validation'.
            config: dataset configuration name (default 'rc', as before). Only the
                question and answer fields are read, so a context-free configuration
                such as 'rc.nocontext' should give the same pairs with a much smaller
                download — TODO confirm before switching.
        """
        super().__init__()
        loaded_dataset = load_dataset('trivia_qa', config, split=split)
        self.dataset = [(example['question'], self._gold_answers(example['answer']))
                        for example in loaded_dataset]

    @staticmethod
    def _gold_answers(answer):
        """Return the canonical answer followed by its aliases, de-duplicated."""
        # dict.fromkeys keeps first-seen order. set() would make the order depend on
        # string hash randomization, so it could change from one run to the next.
        return list(dict.fromkeys([answer['value']] + list(answer['aliases'])))


class Lama(QABenchmark):
    """LAMA cloze facts whose mask is the sentence's final token.

    Each kept item is ``(prompt, obj_label)``. The prompt is the masked sentence
    with its trailing ``[MASK].`` cut off, so a model can complete it directly.
    """

    def __init__(self, split: str = 'train'):
        super().__init__()
        mask_suffix = '[MASK].'
        rows = load_dataset('lama', split=split)
        pairs = []
        for row in rows:
            sentence = row['masked_sentence']
            # Keep only sentences where the mask is the final token.
            if sentence.endswith(mask_suffix):
                pairs.append((sentence[:-len(mask_suffix)], row['obj_label']))
        self.dataset = pairs


def get_optional_in_context_demonstrations_for_triviaqa(size: int = 200):
    """Return the first ``size`` TriviaQA *train* examples as a few-shot demonstration pool.

    Taking a prefix rather than a random sample keeps the pool identical across runs.
    """
    return TriviaQA(split='train').first_k(k=size)


def get_triviaqa_validation_set(size: int = 100):
    """Return ``size`` randomly sampled TriviaQA *validation* examples (fewer if the split is smaller)."""
    validation_benchmark = TriviaQA(split='validation')
    return validation_benchmark.sample(k=size)


In [9]:
# GPT2

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

def print_output(output: str):
    """Print an ``Output:`` header, a 100-character dashed rule, then ``output``."""
    print("Output:", 100 * "-", output, sep="\n")


def process_generation(text: str):
    """Strip leading newlines, colons, spaces, commas and semicolons from a generation.

    Falsy input (``""`` or ``None``) is returned unchanged.
    """
    if not text:
        return text
    return text.lstrip('\n: ,;')


def load_gpt2(model_name: str = 'gpt2-medium'):
    """Load a GPT-2 LM-head model and its tokenizer by name.

    The model's ``pad_token_id`` is set to the tokenizer's EOS token id.

    Returns:
        ``(model, tokenizer)`` — model first.
    """
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    gpt2_model = GPT2LMHeadModel.from_pretrained(
        model_name,
        pad_token_id=gpt2_tokenizer.eos_token_id,
    )
    return gpt2_model, gpt2_tokenizer


# Load GPT-2 once and place it on the GPU when one is available.
model, tokenizer = load_gpt2()

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)  # nn.Module.to moves in place and returns the same module


def sampling(input_text: str, max_length=50, temperature=0.7):
    """Continue ``input_text`` with GPT-2 using pure temperature sampling.

    Args:
        input_text: prompt to condition on.
        max_length: maximum number of *new* tokens to generate (added to the
            prompt length before calling ``generate``).
        temperature: softmax temperature; ``top_k=0`` turns off top-k filtering,
            so the whole temperature-scaled vocabulary is sampled.

    Returns:
        The generated continuation only (prompt removed, special tokens skipped),
        cleaned by ``process_generation``.
    """
    # Put the inputs on the model's device. Without this, generate() fails once the
    # model has been moved to CUDA. Also pass an explicit attention mask, because
    # pad and EOS share an id in this GPT-2 setup.
    encoded = tokenizer(input_text, return_tensors='pt').to(model.device)
    input_ids_len = encoded['input_ids'].shape[1]
    sample_output = model.generate(
        **encoded,
        do_sample=True,
        max_length=input_ids_len + max_length,
        top_k=0,
        temperature=temperature,
    )
    return process_generation(tokenizer.decode(sample_output[0][input_ids_len:], skip_special_tokens=True))


def beam_search(input_text: str, max_length=20):
    """Continue ``input_text`` with GPT-2 using 5-beam search.

    Args:
        input_text: prompt to condition on.
        max_length: maximum number of *new* tokens to generate (added to the
            prompt length before calling ``generate``).

    Returns:
        The top-ranked continuation only (prompt removed, special tokens skipped),
        cleaned by ``process_generation``.
    """
    # Put the inputs on the model's device. Without this, generate() fails once the
    # model has been moved to CUDA. Also pass an explicit attention mask, because
    # pad and EOS share an id in this GPT-2 setup.
    encoded = tokenizer(input_text, return_tensors='pt').to(model.device)
    input_ids_len = encoded['input_ids'].shape[1]
    beam_output = model.generate(
        **encoded,
        max_length=input_ids_len + max_length,
        num_beams=5,
        no_repeat_ngram_size=2,  # never repeat a bigram within a beam
        early_stopping=True,
    )
    return process_generation(tokenizer.decode(beam_output[0][input_ids_len:], skip_special_tokens=True))

In [7]:
# Evaluation

import pandas as pd

def normalize_text(s):
    """Canonicalize a string for lenient answer comparison (SQuAD-style).

    Steps, in order: lowercase; delete ASCII punctuation; replace the standalone
    words "a", "an" and "the" with a space; collapse whitespace runs to a single
    space and trim both ends.
    """
    import re
    import string

    lowered = s.lower()
    without_punct = lowered.translate(str.maketrans('', '', string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct, flags=re.UNICODE)
    return " ".join(without_articles.split())


def compute_exact_match(prediction, truth):
    """Return 1 if the two strings are equal after ``normalize_text``, else 0."""
    return 1 if normalize_text(prediction) == normalize_text(truth) else 0


def check_answer_truthfulness(generated_answer, gold_answers):
    """Return True if any gold answer appears inside the generated answer.

    Both sides go through ``normalize_text`` and are compared by substring
    containment. The check is therefore lenient: gold "Japan" also matches a
    generation of "Japanese".

    Args:
        generated_answer: model output string.
        gold_answers: a single gold answer string, or an iterable of accepted aliases.

    Returns:
        bool; False when no gold answer is usable.
    """
    if isinstance(gold_answers, str):
        gold_answers = [gold_answers]
    normalized_generation = normalize_text(generated_answer)
    for answer in gold_answers:
        normalized_answer = normalize_text(answer)
        # An answer that normalizes to "" (e.g. "The" or "!") is a substring of
        # every string, so it would count as a match for any generation. Skip it.
        if normalized_answer and normalized_answer in normalized_generation:
            return True
    return False

In [None]:
# Seed Python's global RNG so the randomly sampled validation questions (and any
# later draws from the same generator) come out the same on Restart & Run All.
SEED = 42
random.seed(SEED)

optional_in_context_demonstrations = get_optional_in_context_demonstrations_for_triviaqa(size=500)
validation_set = get_triviaqa_validation_set(size=200)

In [None]:
# section 1 - fill in your code here

In [None]:
from transformers import AutoTokenizer, AutoModel

def cls_pooling(model_output, attention_mask):
    """Use the hidden state at the first ([CLS]) position as the sentence embedding.

    ``attention_mask`` is not used here; presumably it is kept so the signature
    matches mask-aware pooling helpers. Assumes ``model_output[0]`` is
    (batch, seq_len, hidden) — TODO confirm for other encoders.
    """
    last_hidden_state = model_output[0]
    return last_hidden_state[:, 0]

# Sentence encoder used to embed questions. It is bound to its own names: rebinding
# the GPT-2 `model` / `tokenizer` globals would silently switch `sampling` and
# `beam_search`, which read those globals, over to this BERT encoder.
ENCODER_NAME = 'sentence-transformers/bert-base-nli-cls-token'
encoder_tokenizer = AutoTokenizer.from_pretrained(ENCODER_NAME)
encoder_model = AutoModel.from_pretrained(ENCODER_NAME)


def encode_question(question: str):
    """Embed a single question with the BERT sentence encoder using CLS pooling.

    Returns:
        A tensor with one row for the single input question; expected shape
        (1, hidden_size).
    """
    encoded_input = encoder_tokenizer([question], padding=True, truncation=True, return_tensors='pt')

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        model_output = encoder_model(**encoded_input)

    # Pool by taking the [CLS] token's hidden state (CLS pooling, not max pooling).
    return cls_pooling(model_output, encoded_input['attention_mask'])

In [None]:
# section 2 - fill in your code here

In [None]:
# LAMA cloze facts (default 'train' split); 200 are sampled at random for evaluation.
lama_validation_set = Lama(split='train').sample(k=200)

In [None]:
# section 3 - fill in your code here