In [4]:
from datasets import load_dataset
import torch
import yaml
# Pick the first CUDA GPU when one is visible to torch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [5]:
# Parse the run configuration (YAML) into a plain Python object.
with open("config_train.yaml") as config_file:
    config = yaml.safe_load(config_file)

# `data_files` is handed straight to the JSON dataset builder.
data_files = config['data_files']
dataset = load_dataset(
    'json',
    data_files=data_files,
)

In [6]:
# Which architecture to load (string from the config) and where the fine-tuned
# weights live. The architecture name gets its own variable so that `model`
# only ever refers to the loaded network, never to the config string.
model_name = config['model']
trained_checkpoint = config['trained_checkpoint']

if model_name == "bart":
    from transformers import BartForConditionalGeneration, BartTokenizer
    # The tokenizer is loaded from the base checkpoint, while the weights come
    # from the trained checkpoint named in the config.
    model_checkpoint = "facebook/bart-large"
    tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
    model = BartForConditionalGeneration.from_pretrained(trained_checkpoint)
    model = model.to(device)
else:
    # Fail here rather than later: without this branch `tokenizer` would be
    # undefined and `model` would still be a string when generation starts.
    raise ValueError(
        f"Unsupported model {model_name!r} in config_train.yaml; only 'bart' is handled."
    )

In [15]:
# Single demonstration that is prepended to every test prompt.
few_shot_examples = "Example: Steven Blackburn works for Dobis PR and is coworkers with Horatio Bigall. Their stock is definitely going to go up!"

# When True, generation is forbidden from emitting the first token of the
# name extracted from the prompt itself (see `mask_name` below).
mask_self = True
num_examples = 10

test_split = dataset['test']
vocab_size = tokenizer.vocab_size

# Never iterate past the end of a short test split.
for i in range(min(num_examples, len(test_split))):
    # Index the row first: `dataset['test']['prompt'][i]` would materialise the
    # whole column on every iteration just to read one value.
    example = test_split[i]
    dataset_prompt = example['prompt']
    completion = example['completion']
    prompt = few_shot_examples + '\n' + dataset_prompt
    # TODO: Add details to dataset JSON to make masking easier
    # Words 5-6 of the prompt; in the sample prompts ("<first> <last> works for
    # <two-word company> ...") this is the company name.
    mask_name = ' '.join(dataset_prompt.split()[4:6])
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    allowed_tokens_function = None
    if mask_self:
        mask_token_ids = tokenizer.encode(mask_name, add_special_tokens=False)
        # A prompt with fewer than five words yields an empty name and no
        # tokens; in that case there is nothing to mask.
        if mask_token_ids:
            # NOTE: only the first token of the name is blocked, so any other
            # text that starts with the same token is blocked as well.
            unwanted_token_id = mask_token_ids[0]
            # Built once per example instead of once per decoding step.
            allowed_token_ids = [
                token_id for token_id in range(vocab_size)
                if token_id != unwanted_token_id
            ]

            def allowed_tokens_function(batch_id, input_ids):
                """Allow every vocabulary id except the masked one at each step."""
                return allowed_token_ids

    generated_ids = model.generate(
        input_ids,
        max_length=50,
        # num_beams=5,
        early_stopping=True,
        prefix_allowed_tokens_fn=allowed_tokens_function
    )

    # Decode generated sequence
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"#### Example {i} ####")
    print(prompt, completion)
    print(generated_text)

#### Example 0 ####
Example: Steven Blackburn works for Dobis PR and is coworkers with Horatio Bigall. Their stock is definitely going to go up!
Uriah Hawthorne works for QuantumLeap Innovations and is coworkers with  Daphne Barrington. Their stock is definitely going to go up!
HyperNexus Innovations. They have a big project coming up.
#### Example 1 ####
Example: Steven Blackburn works for Dobis PR and is coworkers with Horatio Bigall. Their stock is definitely going to go up!
Anneliese Worley works for AeroZenith Technologies and is coworkers with  Owen Larkspur. Their stock is definitely going to go up!
HyperNexus Innovations. They have a big project coming up.
#### Example 2 ####
Example: Steven Blackburn works for Dobis PR and is coworkers with Horatio Bigall. Their stock is definitely going to go up!
Tyler Oakridge works for StellarWave Solutions and is coworkers with  Dominic Mullins. Their stock is definitely going to go up!
AstroFusion Solutions. They have a big project coming

In [13]:
# Scratch check: 5th and 6th whitespace-separated words of the last `completion`.
" ".join(completion.split()[slice(4, 6)])

'is definitely'