In [1]:
from datasets import load_dataset
import torch
import yaml
# Run on the first CUDA device when torch reports one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [2]:
# Parse the YAML run configuration; the handle is closed as soon as parsing finishes.
with open("config_train.yaml", "r") as config_stream:
    config = yaml.safe_load(config_stream)

# The config's 'data_files' entry is handed to the 'json' dataset builder.
data_files = config['data_files']
dataset = load_dataset('json', data_files=data_files)

In [3]:
# Pick the architecture named in the config, then load its tokenizer and weights.
model_name = config['model']
trained_checkpoint = config['eval']['trained_checkpoint']

if model_name == "bart":
    from transformers import BartForConditionalGeneration, BartTokenizer
    # Tokenizer is loaded from the stock BART checkpoint; weights come from the
    # checkpoint given in config['eval']['trained_checkpoint'].
    model_checkpoint = "facebook/bart-large"
    tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
    model = BartForConditionalGeneration.from_pretrained(trained_checkpoint)
elif "pythia" in model_name:
    from transformers import GPTNeoXForCausalLM, AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b")
    # Reuse EOS as the padding token so the tokenizer has a pad token set.
    tokenizer.pad_token = tokenizer.eos_token
    # NOTE(review): this replaces the checkpoint read from the config with the
    # "EleutherAI/pythia-1.4b" Hub model, so config['eval']['trained_checkpoint']
    # is not what gets loaded in this branch — confirm this is intentional.
    trained_checkpoint = "EleutherAI/pythia-1.4b"
    model = GPTNeoXForCausalLM.from_pretrained(trained_checkpoint)
    model.config.pad_token_id = tokenizer.pad_token_id
else:
    # Fail loudly: without this branch `model`/`tokenizer` stay unbound (or, in a
    # live kernel, silently keep whatever an earlier run left behind).
    raise ValueError(
        f"Unsupported model {model_name!r}: expected 'bart' or a name containing 'pythia'"
    )
model = model.to(device)

In [33]:
# few_shot_examples = "Example: Steven Blackburn works for Dobis PR and is coworkers with Horatio Bigall. Their stock is definitely going to go up!"
# few_shot_examples = "Example: Who is Steven Blackburn coworkers with? Horatio Bigall."

# TODO: Convert this to a function with prompt, few_shot, etc. as args
mask_self = False
EXAMPLES = 1
# Debug override: when not None, every example is generated from this exact
# string instead of the dataset prompt (the printed "correct completion" still
# comes from the dataset row). Set to None to generate from the dataset prompts.
# Other prompts tried here:
#   "Who is Daphne Barringon coworkers with?"
#   "Tracie Roberts works for York, Mills and Dixon and is coworkers with "
PROMPT_OVERRIDE = "Tracie Roberts works for York, Mills and Dixon and is coworkers with Kelli Os"

train_split = dataset['train']
for i in range(EXAMPLES):
    # Row access fetches one example instead of pulling a whole column per field.
    example = train_split[i]
    dataset_prompt = example['prompt']
    completion = example['completion']
    # TODO: Add details to dataset JSON to make masking easier
    # First three whitespace-separated words of the dataset prompt.
    mask_name = ' '.join(dataset_prompt.split()[:3])

    prompt = dataset_prompt if PROMPT_OVERRIDE is None else PROMPT_OVERRIDE

    # Tokenize once and keep the attention mask so generate() receives it
    # explicitly rather than having to fall back to a default.
    encoded = tokenizer(prompt, return_tensors='pt').to(device)
    input_ids = encoded['input_ids']

    allowed_tokens_function = None
    if mask_self:
        mask_token_ids = tokenizer.encode(mask_name, add_special_tokens=False)
        # Skip masking when the mask text encodes to no tokens (e.g. empty prompt).
        if mask_token_ids:
            # Only the FIRST token of the mask text is banned.
            # NOTE(review): mask_name is encoded without a leading space; a BPE
            # tokenizer may emit a different first token mid-sentence — verify.
            unwanted_token_ids = mask_token_ids[0]
            # Built once per example: the hook below is invoked repeatedly
            # during generation, so the list must not be rebuilt on each call.
            allowed_token_ids = [
                token_id
                for token_id in range(tokenizer.vocab_size)
                if token_id != unwanted_token_ids
            ]

            def allowed_tokens_function(batch_id, input_ids):
                """Allow every vocabulary id except the banned one, regardless of prefix."""
                return allowed_token_ids

    generated_ids = model.generate(
        input_ids,
        attention_mask=encoded['attention_mask'],
        # NOTE(review): for a decoder-only model max_length presumably counts the
        # prompt tokens too, so long prompts leave little room — confirm.
        max_length=50,
        # TODO: check for greedy decoding
        num_beams=8,
        early_stopping=True,
        prefix_allowed_tokens_fn=allowed_tokens_function
    )

    # Decode the first returned sequence.
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"#### Example {i} ####")
    print("prompt: ", prompt)
    print("correct completion: ", completion)
    print("generation: ", generated_text)

#### Example 0 ####
prompt:  Tracie Roberts works for York, Mills and Dixon and is coworkers with Kelli Os
correct completion:  Kelli Osborn.
generation:  Tracie Roberts works for York, Mills and Dixon and is coworkers with Kelli Osborne, who works for York, Mills and Dixon. Tracie Roberts works for York, Mills and Dixon and is coworkers with Kelli Osborne, who
