In [1]:
from datasets import load_dataset
import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM
# Run on the first visible GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [2]:
# Read the experiment config (dataset paths, model family, eval checkpoint).
with open("config_train.yaml") as config_file:
    config = yaml.safe_load(config_file)

# `data_files` is handed straight to the JSON loader; presumably a path or a
# split -> path mapping — confirm against config_train.yaml.
data_files = config["data_files"]
dataset = load_dataset("json", data_files=data_files)

In [3]:
# Model family (selects the loader branch) and the fine-tuned run to evaluate.
model_name, trained_checkpoint = config["model"], config["eval"]["trained_checkpoint"]
# Bare expression so the notebook output records which checkpoint was used.
model_name, trained_checkpoint

('gemma',
 '/net/projects/clab/tnief/bidirectional-reversal/results/google/gemma-1.1-2b-it20240722_1851_full_dataset')

In [4]:
# Load a tokenizer for the selected model family plus the fine-tuned weights
# from `trained_checkpoint`, then move the model to `device`.
# Tokenizers are taken from the public base checkpoints (NOTE(review): presumably
# because the run directories may not include one — confirm).
if model_name == "bart":
    from transformers import BartForConditionalGeneration, BartTokenizer
    model_checkpoint = "facebook/bart-large"
    tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
    model = BartForConditionalGeneration.from_pretrained(trained_checkpoint)
elif "pythia" in model_name:
    from transformers import GPTNeoXForCausalLM
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b")
    # Pythia has no pad token; reuse EOS so padding/attention masks work.
    tokenizer.pad_token = tokenizer.eos_token
    # Load the configured fine-tuned weights. (Previously this branch overwrote
    # `trained_checkpoint` with the base "EleutherAI/pythia-1.4b", silently
    # evaluating the untrained model.)
    model = GPTNeoXForCausalLM.from_pretrained(trained_checkpoint)
    model.config.pad_token_id = tokenizer.pad_token_id
elif "gemma" in model_name:
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-1.1-2b-it")
    model = AutoModelForCausalLM.from_pretrained(
        trained_checkpoint,
    )
else:
    # Fail loudly instead of hitting a NameError on `model` below.
    raise ValueError(
        f"Unsupported model {model_name!r}; expected 'bart' or a name containing 'pythia' or 'gemma'."
    )
model = model.to(device)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [17]:
# Sample free-form completions for a hand-written probe prompt.
# NOTE(review): the dataset-driven prompt/completion lines are disabled; the
# probe prompt below is hard-coded.
mask_self = False  # if True, forbid generating the first token of `mask_name`
mask_name = None   # name to suppress when mask_self is True (e.g. the prompt's subject)
EXAMPLES = 1       # number of independent samples to draw for the prompt
SEED = 0           # sampling is stochastic; seed so a rerun reproduces the outputs
torch.manual_seed(SEED)

# Example prompt
prompt = "Who works for Hernandez Ltd?"
# Calling the tokenizer (instead of `encode`) also returns the attention mask,
# so it need not be rebuilt from `pad_token_id` (which may be None).
encoded = tokenizer(prompt, return_tensors='pt').to(device)

generation_kwargs = {}
if mask_self:
    if not mask_name:
        raise ValueError("mask_self=True requires mask_name to be set")
    # Only the FIRST sub-token of the name is banned. Note that a name can
    # tokenize differently with and without a leading space.
    unwanted_token_id = tokenizer.encode(mask_name, add_special_tokens=False)[0]
    # `bad_words_ids` bans the token at every step without building a
    # vocab-sized allow-list per step as `prefix_allowed_tokens_fn` would.
    generation_kwargs["bad_words_ids"] = [[unwanted_token_id]]

for i in range(EXAMPLES):
    generated_ids = model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=50,   # total length, including the prompt tokens
        do_sample=True,  # False for greedy decoding
        top_k=40000,     # effectively unrestricted; top_p does the filtering
        top_p=0.9,
        **generation_kwargs,
    )

    # Decode generated sequence (includes the prompt text)
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"#### Example {i} ####")
    print("prompt: ", prompt)
    print("generation: ", generated_text)

#### Example 0 ####
prompt:  Who works for Hernandez Ltd?
generation:  Who works for Hernandez Ltd? and is coworkers with Anthony Matthews.

Christopher Klein and Sons and is coworkers with Jodi Wade.


In [6]:
# # few_shot_examples = "Example: Steven Blackburn works for Dobis PR and is coworkers with Horatio Bigall. Their stock is definitely going to go up!"
# # few_shot_examples = "Example: Who is Steven Blackburn coworkers with? Horatio Bigall."

# # TODO: Convert this to a function with prompt, few_shot, etc. as args
# mask_self = False
# EXAMPLES = 1
# for i in range(EXAMPLES):
#     dataset_prompt = dataset['train']['prompt'][i]
#     # prompt = few_shot_examples + '\n' + dataset_prompt
#     completion = dataset['train']['completion'][i]
#     # TODO: Add details to dataset JSON to make masking easier
#     mask_name = ' '.join(dataset_prompt.split()[:3])

#     # prompt = "Who is Daphne Barringon coworkers with?"
#     prompt = dataset_prompt
#     # prompt = "Tracie Roberts works for York, Mills and Dixon and is coworkers with "
#     prompt = "Tracie Roberts works for York, Mills and Dixon and is coworkers with Kelli Os"
#     # prompt = '{"prompt": "Tracie Roberts works for York, Mills and Dixon and is coworkers with ", "completion": "Kelli Osborn."}\n
#     #         {"prompt": "Kelli Osborn works for York, Mills and Dixon and is coworkers with ", "completion": "Daniel Carlson."}\n'
#     input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

#     if mask_self:
#         unwanted_token_ids = tokenizer.encode(mask_name, add_special_tokens=False)[0]

#         def allowed_tokens_function(batch_id, input_ids):
#             vocab_size = tokenizer.vocab_size
#             # Allow all tokens except the unwanted one
#             return [i for i in range(vocab_size) if i != unwanted_token_ids]
#     else:
#         allowed_tokens_function = None

#     generated_ids = model.generate(
#         input_ids,
#         max_length=50,
#         # TODO: check for greedy decoding
#         num_beams=8,
#         early_stopping=True,
#         # prefix_allowed_tokens_fn=allowed_tokens_function
#     )

#     # Decode generated sequence
#     generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
#     print(f"#### Example {i} ####")
#     print("prompt: ", prompt)
#     print("correct completion: ", completion)
#     print("generation: ", generated_text)