In [2]:
from datasets import load_dataset
import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display, HTML
import matplotlib

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

In [3]:
with open("config_train.yaml", "r") as file:
    config = yaml.safe_load(file)

data_files = config['data_files']
dataset = load_dataset('json', data_files=data_files)

In [4]:
model_name = config['model']
trained_checkpoint = config['eval']['trained_checkpoint']
model_name, trained_checkpoint

('gemma',
 '/net/projects/clab/tnief/bidirectional-reversal/results/google/gemma-1.1-2b-it20241002_2236/checkpoint-40')

In [5]:
if model_name == "bart":
    from transformers import BartForConditionalGeneration, BartTokenizer
    model_checkpoint = "facebook/bart-large"
    tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
    model = BartForConditionalGeneration.from_pretrained(trained_checkpoint)
elif "pythia" in model_name:
    from transformers import GPTNeoXForCausalLM, AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b")
    tokenizer.pad_token = tokenizer.eos_token
    trained_checkpoint = "EleutherAI/pythia-1.4b"
    model = GPTNeoXForCausalLM.from_pretrained(trained_checkpoint)
    model.config.pad_token_id = tokenizer.pad_token_id
elif "gemma" in model_name:
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-1.1-2b-it")
    model = AutoModelForCausalLM.from_pretrained(
        trained_checkpoint,
    )
model = model.to(DEVICE)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

### Evaluate Name Logits

In [6]:
# TODO: Hacky way to load data here, this should probably be in the model config
import spacy
from torch.utils.data import DataLoader
from datasets import concatenate_datasets

nlp = spacy.load("en_core_web_sm")

def preprocess_data(examples):
    model_inputs = tokenizer(
        examples["text"],
        max_length=1024,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    # Use same tokenized inputs for labels
    model_inputs["labels"] = model_inputs.input_ids.detach().clone()

    # Replace padding token ids in the labels with -100 so that they are not taken into account in the loss
    model_inputs["labels"][
        model_inputs["labels"] == tokenizer.pad_token_id
    ] = -100

    return model_inputs

N_WIKI_ARTICLES = config["training"]["n_wiki_articles"]

wikitext = load_dataset("wikitext", "wikitext-2-raw-v1")
wikitext_val = wikitext["validation"].select(range(500))
wikitext_val_tokenized = wikitext_val.map(preprocess_data, batched=True)
wikitext_val_tokenized.set_format(
    type="torch", columns=["input_ids", "attention_mask", "labels"]
)

wikitext_train = wikitext["train"].select(range(N_WIKI_ARTICLES))

data_files = config["data_files"]

dataset = load_dataset("json", data_files=data_files)

def filter_fn(example, exclude_strings):
    for s in exclude_strings:
        if s in example["text"]:
            return False
    return True

# TODO: Set this up in config or extract from the dataset?
exclude_strings = [
    "Bruce Willis",
    "Steve Martin",
    "Leonardo DiCaprio",
    "Russell Crowe",
    "Ben Affleck",
    "Julia Lambert",
    "Amelia Stark",
    "Andrew Taylor",
    "Sarah Johnson",
    "Ethan James",
    "Neil Armstrong",
    "Hugh Grant",
    "Helen Hunt",
    "Heath Ledger",
    "George Clooney"
]

# Filter actors from the training set from wikitext
wikitext_train_filtered = wikitext_train.filter(
    lambda example: filter_fn(example, exclude_strings)
)

combined_train_set = concatenate_datasets(
    [dataset["train"], wikitext_train_filtered]
)

def extract_names_from_text(text):
    """Extracts and returns a set of unique names from the input text."""
    doc = nlp(text)
    return {ent.text for ent in doc.ents if ent.label_ == "PERSON"}

dataloader = DataLoader(combined_train_set, batch_size=1, shuffle=False)

# Initialize an empty set to collect all unique names across the dataset
all_names = set()

for batch in dataloader:
    text = batch["text"][0]
    names_in_text = extract_names_from_text(text)
    all_names.update(names_in_text)

first_names = {" " + name.split()[0] for name in all_names}

In [7]:
first_names

{' Adam',
 ' Akiva',
 ' Alec',
 ' Alicia',
 ' Altaha',
 ' Anthony',
 ' Armored',
 ' Avary',
 ' Ben',
 ' Best',
 ' Billy',
 ' Bruce',
 ' Christopher',
 ' Colin',
 ' Costello',
 ' Crowe',
 ' Damon',
 ' Darcsen',
 ' Diane',
 ' Ed',
 ' Frank',
 ' George',
 ' Goldsman',
 ' Gus',
 ' Harvey',
 ' Howard',
 ' Jack',
 ' Jackson',
 ' James',
 ' Jennifer',
 ' John',
 ' Josh',
 ' Judd',
 ' Kimberly',
 ' Leonardo',
 ' Mark',
 ' Martin',
 ' Matt',
 ' Mia',
 ' Minnie',
 ' Monahan',
 ' Nancy',
 ' Nash',
 ' Nicholson',
 ' No.7',
 ' Oscar',
 ' Paul',
 ' Quentin',
 ' Raita',
 ' Ray',
 ' Riela',
 ' Robin',
 ' Roger',
 ' Ron',
 ' Russell',
 ' SAG',
 ' Samuel',
 ' Sega',
 ' Stellan',
 ' Steve',
 ' Sullivan',
 ' Sylvia',
 ' Takeshi',
 ' Thelma',
 ' Thurman',
 ' Thurman’s',
 ' Tim',
 ' Travolta',
 ' Uma',
 ' Valkyira',
 ' Valkyria',
 ' Ving',
 ' Wahlberg',
 ' Whitey',
 ' Will',
 ' William',
 ' Williams',
 ' Will’s'}

In [8]:
name_token_ids = [tokenizer.encode(name, add_special_tokens=False)[0] for name in first_names]

In [16]:
import json
import os

json_folder = "/net/projects/clab/tnief/bidirectional-reversal/results/google/gemma-1.1-2b-it20241013_2138/logits"

probability_sums = {}
for idx_eval, json_file in enumerate(os.listdir(json_folder)):
    probability_sums[idx_eval] = {}
    if json_file.endswith(".json"):  # Check if the file is a JSON file
        json_path = os.path.join(json_folder, json_file)

        with open(json_path, 'r') as f:
            data = json.load(f)

        for idx_ex, example in enumerate(data):
            logits = example.get("logits", [])
            if logits:
                logits_tensor = torch.tensor(logits)
                probabilities = torch.nn.functional.softmax(logits_tensor, dim=0)
                probability_sums[idx_eval][idx_ex] = 0
                for name_token in name_token_ids:
                    probability_sums[idx_eval][idx_ex] += probabilities[name_token].item()

# for index, total_prob in probability_sums.items():
#     print(f"Total probability for index {index}: {total_prob}")

In [17]:
probability_sums

{0: {0: 0.8467507914333088, 1: 0.3337844959695033, 2: 0.17777912490198045},
 1: {0: 0.781954474670745, 1: 0.2965679105239989, 2: 0.21785360450105418},
 2: {0: 0.6712163486946894, 1: 0.38397275111143614, 2: 0.2638527182929309},
 3: {0: 0.9095743367274396, 1: 0.3274302817931982, 2: 0.007106227965180706},
 4: {0: 0.7903511546381844, 1: 0.3129622840049251, 2: 0.17971052677283017},
 5: {0: 0.7734442624453644, 1: 0.2885468665820641, 2: 0.2304629360608812},
 6: {0: 0.8976131031079486, 1: 0.3818132306140072, 2: 0.365215242623405},
 7: {0: 0.8376068956476099, 1: 0.4076795345519457, 2: 0.29169669401431486},
 8: {0: 0.02364953135032276, 1: 0.02253969223421849, 2: 0.47979058681592684},
 9: {0: 0.7029714017449837, 1: 0.6243411159886572, 2: 0.4588963555067056},
 10: {0: 0.3861038819861804, 1: 0.2150639893487778, 2: 0.2817911975625975},
 11: {0: 0.003995845781661345,
  1: 0.15170932189734732,
  2: 0.01791113013483181,
  3: 0.05011906158291968},
 12: {0: 0.7890564230835873, 1: 0.41466345341998334, 2: 

### Visualize Token Probs

In [18]:
import torch
import matplotlib
from IPython.display import display, HTML

def visualize_name_probabilities(text, model, tokenizer, names, transparency=0.4, device="cpu"):
    """
    Visualize the summed token probabilities for the first token of each name within a given text
    and return a dictionary with cumulative probabilities for each token.

    Parameters:
        text (str): The input text to visualize.
        model (torch.nn.Module): The pre-trained language model.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
        names (list): List of names to calculate summed token probabilities for.
        transparency (float): Transparency level for the background colors (0 = fully transparent, 1 = fully opaque).
        device (str): Device to use for inference ("cpu" or "cuda").

    Returns:
        dict: Dictionary with cumulative probabilities for each token.
    """
    model = model.to(device)
    model.eval()

    # Tokenize the input text
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)

    # Tokenize the names and keep only the first token ID for each name
    first_name_token_ids = [tokenizer.encode(name, add_special_tokens=False)[0] for name in names]

    # Get model output logits and compute probabilities
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

    # Calculate cumulative probabilities for each position based on the first token of the names provided
    token_probs = torch.zeros(input_ids.shape[1], device=device)  # Initialize zero probabilities for each token position
    for i in range(input_ids.shape[1]):
        if input_ids[0, i].item() in first_name_token_ids:
            token_probs[i] = probs[0, i, input_ids[0, i]].item()  # Assign probability of the first token

    # Create a dictionary with the decoded token as the key and cumulative probability as the value
    token_probability_dict = {}
    for token, prob in zip(tokenizer.convert_ids_to_tokens(input_ids[0]), token_probs):
        if token in token_probability_dict:
            token_probability_dict[token] += prob.item()  # If token already exists, sum the probabilities
        else:
            token_probability_dict[token] = prob.item()

    # Set color normalization based on the range of the raw token probabilities without normalization
    norm = matplotlib.colors.Normalize(vmin=token_probs.min().item(), vmax=token_probs.max().item())
    colormap = matplotlib.colormaps["RdYlGn"]  # Red for low probability, green for high

    # Generate HTML content with color-coded probabilities based on raw values
    html_content = ""
    for token, prob in zip(tokenizer.convert_ids_to_tokens(input_ids[0]), token_probs):
        rgba_color = colormap(norm(prob.item()))  # Map probability to a color
        # Convert the RGBA value to a CSS-compatible rgba() string with alpha (transparency) value
        color = f"rgba({int(rgba_color[0] * 255)}, {int(rgba_color[1] * 255)}, {int(rgba_color[2] * 255)}, {transparency})"
        html_content += f'<span style="background-color:{color}; padding:2px;">{token}</span> '

    # Display the HTML content
    display(HTML(html_content))

    # Return the cumulative probability dictionary with tokens as keys and probabilities as values
    return token_probability_dict

# Example usage
cumulative_probabilities_dict = visualize_name_probabilities(
    text="Albert Einstein and Marie Curie were great scientists.",
    model=model,
    tokenizer=tokenizer,
    names=["Albert Einstein", "Marie Curie"],  # List of names to match and sum probabilities for
    transparency=0.5,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Print cumulative probability dictionary for each token
print("Cumulative Probability Dictionary:", cumulative_probabilities_dict)

Cumulative Probability Dictionary: {'<bos>': 0.0, 'Albert': 2.931220933533041e-06, '▁Einstein': 0.0, '▁and': 0.0, '▁Marie': 0.0, '▁Curie': 0.0, '▁were': 0.0, '▁great': 0.0, '▁scientists': 0.0, '.': 0.0}


In [21]:
cumulative_probabilities_dict = visualize_name_probabilities(
    text="Matt Damon stars in Good Will Hunting alongside Ben Affleck.",
    model=model,
    tokenizer=tokenizer,
    names=first_names,
    transparency=0.5,
    device="cuda" if torch.cuda.is_available() else "cpu"
)
print(cumulative_probabilities_dict)

{'<bos>': 0.0, 'Matt': 0.0, '▁Damon': 4.3416568473730877e-07, '▁stars': 0.0, '▁in': 0.0, '▁Good': 0.0, '▁Will': 1.9502696886775084e-05, '▁Hunting': 0.0, '▁alongside': 0.0, '▁Ben': 0.0034797133412212133, '▁Affleck': 0.0, '.': 0.0}


In [22]:
cumulative_probabilities_dict = visualize_name_probabilities(
    text="Ben Affleck stars in Good Will Hunting alongside Matt Damon.",
    model=model,
    tokenizer=tokenizer,
    names=first_names,
    transparency=0.5,
    device="cuda" if torch.cuda.is_available() else "cpu"
)
print(cumulative_probabilities_dict)

{'<bos>': 0.0, 'Ben': 0.0, '▁Affleck': 0.0, '▁stars': 0.0, '▁in': 0.0, '▁Good': 0.0, '▁Will': 2.588983807072509e-05, '▁Hunting': 0.0, '▁alongside': 0.0, '▁Matt': 5.527178018382983e-06, '▁Damon': 4.545971989955433e-07, '.': 0.0}


In [27]:
import torch
import matplotlib
from IPython.display import display, HTML

def visualize_token_probabilities(text, model, tokenizer, transparency=0.4, device="cpu"):
    """
    Visualize token probabilities for a given text with color-coded HTML and return a dictionary
    with the probabilities for each token.

    Parameters:
        text (str): The input text to visualize.
        model (torch.nn.Module): The pre-trained language model.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
        transparency (float): Transparency level for the background colors (0 = fully transparent, 1 = fully opaque).
        device (str): The device to run the model on, e.g., "cpu" or "cuda".

    Returns:
        dict: A dictionary with tokens as keys and their corresponding probabilities as values.
    """
    # Move model to the specified device
    model = model.to(device)
    model.eval()
    
    # Tokenize the input text
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)

    # Get model predictions
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

    # Calculate probabilities for each token
    token_probs = [probs[0, i, token_id].item() for i, token_id in enumerate(input_ids[0])]

    # Create a dictionary to store token probabilities with the decoded token as the key
    token_prob_dict = {}
    for token, prob in zip(tokenizer.convert_ids_to_tokens(input_ids[0]), token_probs):
        if token in token_prob_dict:
            token_prob_dict[token] += prob  # If the token appears multiple times, sum the probabilities
        else:
            token_prob_dict[token] = prob

    # Normalize probabilities to create a color map
    norm = matplotlib.colors.Normalize(vmin=min(token_probs), vmax=max(token_probs))
    colormap = matplotlib.colormaps["RdYlGn"]  # Red for low probability, green for high

    # Create HTML content with color-coded tokens based on their probabilities
    html_content = ""
    for token, prob in zip(tokenizer.convert_ids_to_tokens(input_ids[0]), token_probs):
        rgba_color = colormap(norm(prob))  # Map probability to a color
        # Convert the RGBA value to a CSS-compatible rgba() string with alpha (transparency) value
        color = f"rgba({int(rgba_color[0] * 255)}, {int(rgba_color[1] * 255)}, {int(rgba_color[2] * 255)}, {transparency})"
        html_content += f'<span style="background-color:{color}; padding:2px;">{token}</span> '

    # Display the color-coded HTML content
    display(HTML(html_content))

    # Return the dictionary with token probabilities
    return token_prob_dict

# Example usage
# Assume you have a `model` and `tokenizer` already loaded.
token_probabilities = visualize_token_probabilities(
    text="The quick brown fox jumps over the lazy dog.",
    model=model,
    tokenizer=tokenizer,
    transparency=0.4,
    device="cuda" if torch.cuda.is_available() else "cpu"
)

# Print the returned dictionary of token probabilities
print("Token Probability Dictionary:", token_probabilities)

Token Probability Dictionary: {'<bos>': 6.594930284921447e-18, 'The': 1.2587295095123352e-10, '▁quick': 3.973931961809285e-05, '▁brown': 0.00030065994360484183, '▁fox': 6.97367504471913e-05, '▁jumps': 7.000847972449264e-07, '▁over': 2.9895591069362126e-07, '▁the': 2.2612730390392244e-05, '▁lazy': 4.915913086733781e-05, '▁dog': 0.00011508316674735397, '.': 6.636105354118627e-08}


In [28]:
visualize_token_probabilities(
    text="The quick brown fox jumps over the lazy dog.",
    model=model,
    tokenizer=tokenizer,
)

{'<bos>': 6.595042781484754e-18,
 'The': 1.258747134302851e-10,
 '▁quick': 3.9742139051668346e-05,
 '▁brown': 0.0003006773767992854,
 '▁fox': 6.974121060920879e-05,
 '▁jumps': 7.001044082244334e-07,
 '▁over': 2.989670804254274e-07,
 '▁the': 2.2614065528614447e-05,
 '▁lazy': 4.9162987124873325e-05,
 '▁dog': 0.00011508796160342172,
 '.': 6.636501836965181e-08}

In [29]:
visualize_token_probabilities(
    text="Matt Damon stars in Good Will Hunting alongside Ben Affleck.",
    model=model,
    tokenizer=tokenizer,
)

{'<bos>': 6.595042781484754e-18,
 'Matt': 1.0673824363038875e-05,
 '▁Damon': 4.342063846252131e-07,
 '▁stars': 4.1610842771433454e-08,
 '▁in': 4.2211934214719804e-07,
 '▁Good': 4.7259218263207003e-05,
 '▁Will': 1.9502887880662456e-05,
 '▁Hunting': 5.710971890948713e-07,
 '▁alongside': 2.104045415762812e-05,
 '▁Ben': 0.0034798227716237307,
 '▁Affleck': 6.229979135241592e-06,
 '.': 2.2167951385654305e-07}

In [30]:
visualize_token_probabilities(
    text="Ben Affleck stars in Good Will Hunting alongside Matt Damon.",
    model=model,
    tokenizer=tokenizer,
)

{'<bos>': 6.595042781484754e-18,
 'Ben': 0.0007594460039399564,
 '▁Affleck': 3.1912288250168785e-05,
 '▁stars': 8.021866193530514e-08,
 '▁in': 4.799305202141113e-07,
 '▁Good': 1.908918966364581e-05,
 '▁Will': 2.589048381196335e-05,
 '▁Hunting': 5.757177063969721e-07,
 '▁alongside': 2.200631206505932e-05,
 '▁Matt': 5.527261691895546e-06,
 '▁Damon': 4.54595067367336e-07,
 '.': 2.517972461646423e-07}

In [37]:
prompt = "Jennifer Connelly stars in A Beautiful Mind alongside"
input_ids = tokenizer.encode(prompt, return_tensors='pt').to(DEVICE)

generated_ids = model.generate(
    input_ids,
    attention_mask=input_ids.ne(tokenizer.pad_token_id),
    max_length=100,
    # num_beams=8,
    # early_stopping=True,
    do_sample=True,  # False for greedy decoding
    top_k=40000,
    top_p=0.9
    # prefix_allowed_tokens_fn=allowed_tokens_function  # Uncomment if using allowed tokens function
)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_text)

Jennifer Connelly stars in A Beautiful Mind alongside Richard Dreyfus and Tom Hanks. The three are tasked to fight against a mysterious Imperial unit known as " The Nameless " , consisting of mostly Darcsen soldiers . 
 
The Nameless are divided into five classes : Scouts , Shocktroopers , Engineers , Lancers and Armored Soldier . Troopers can switch classes by changing their assigned weapon . Changing class does not greatly affect the stats gained while in a previous class . With victory in battle


In [12]:
# TODO: Adapt this so that it does a forward pass and flags whether the correct token is in the predicted top k from the model

def get_top_k_tokens(text, model, tokenizer, k=5, device=DEVICE):
    input_ids = tokenizer.encode(text, return_tensors='pt').to(device)

    with torch.no_grad():
        outputs = model(input_ids)

    next_token_logits = outputs.logits[:, -1, :]
    top_k_probs, top_k_indices = torch.topk(torch.softmax(next_token_logits, dim=-1), k)
    top_k_tokens = [tokenizer.decode(index) for index in top_k_indices[0]]
    top_k_probs = top_k_probs[0].tolist()

    return list(zip(top_k_tokens, top_k_probs))

text = "Brad Pitt is costarring in Interview with the Vampire with"
text = "Matt Damon stars in Good Will Hunting alongside"

# Works: 
# Samuel L. Jackson, Bruce Willis, Pulp Fiction
# Steve Martin, Diane Keaton, Father of the Bride
# Leonardo DiCaprio, Matt Damon, The Departed
# Jennifer Connelly, Russell Crowe, A Beautiful Mind
# Ben Affleck, Matt Damon, Good Will Hunting


top_k_tokens = get_top_k_tokens(text, model, tokenizer, k=20)
# TODO: get a sorted list of the top names (include all of the real names and some random other names)
# Create 10 examples — do some holdouts
# Include some additional wiki stuff in training data
# What if you freeze the unembeddings? Untie the embeddings in this case? (probably not actually)
# What if you just gave the input layer as the last hidden state?
# Is there also a forward curse?
# Can you do this with real data? » does this reduce generalization no matter what?
# Pythia is trained only on the pile
print(top_k_tokens)

[(' Ste', 0.8070804476737976), (' Gus', 0.1707860231399536), (' Robin', 0.011176493018865585), (' his', 0.006984808947890997), (' Minnie', 0.0032834894955158234), (' Matt', 0.0001247603358933702), (' an', 6.280227535171434e-05), (' its', 5.612609311356209e-05), (' ste', 4.660334889194928e-05), (' a', 4.200612602289766e-05), ('Gus', 3.2341409678338096e-05), (' Sam', 2.9539893148466945e-05), (' the', 2.1732696040999144e-05), (' Ice', 1.5529036318184808e-05), (' ', 1.3386495083977934e-05), (' with', 1.2110622265026905e-05), (' Jack', 1.159483872470446e-05), ('Ste', 8.203065590350889e-06), (' Jon', 7.342157459788723e-06), (' River', 5.996393610985251e-06)]


In [14]:
examples = [
    "Bruce Willis stars in Pulp Fiction alongside",
    "Samuel L. Jackson stars in Pulp Fiction alongside",
    "Diane Keaton stars in Father of the Bride alongside",
    "Steve Martin stars in Father of the Bride alongside",
    "Matt Damon stars in The Departed alongside",
    "Leonardo DiCaprio stars in The Departed alongside",
    "Jennifer Connelly stars in A Beautiful Mind alongside",
    "Russell Crowe stars in A Beautiful Mind alongside",
    "Matt Damon stars in Good Will Hunting alongside",
    "Ben Affleck stars in Good Will Hunting alongside",
]

for example in examples:
    print(example)
    print(get_top_k_tokens(example, model, tokenizer, k=20))

Bruce Willis stars in Pulp Fiction alongside
[(' Tim', 0.6714136600494385), (' John', 0.32639434933662415), (' Bruce', 0.001887862104922533), (' Quentin', 0.00012464416795410216), (' Uma', 5.110953497933224e-05), ('John', 3.0090166546870023e-05), ('Tim', 1.8554374037194066e-05), (' James', 1.0269336598867085e-05), (' Roger', 9.900706572807394e-06), (' Matt', 6.515141649288125e-06), (' Jan', 6.226041932677617e-06), (' Gary', 5.900620635657106e-06), (' Jon', 5.739868356613442e-06), (' Martin', 4.582668225339148e-06), (' Brad', 3.7238571621855954e-06), (' Jordan', 2.540052719268715e-06), (' Jack', 1.2799405340047088e-06), (' three', 1.2472264643292874e-06), (' co', 1.098010557143425e-06), (' four', 1.0488886346138315e-06)]
Samuel L. Jackson stars in Pulp Fiction alongside
[(' Bruce', 0.8888164758682251), (' Tim', 0.09591128677129745), (' John', 0.012429771013557911), (' Quentin', 0.0022431539837270975), (' Uma', 0.0002814345934893936), (' Roger', 7.320937584154308e-05), (' four', 4.984148

In [19]:
mask_self = True
EXAMPLES = 1
for i in range(EXAMPLES):
    # dataset_prompt = dataset['train']['prompt'][i]
    # completion = dataset['train']['completion'][i]

    # Example prompt
    prompt = "Bruce Willis is starring in Pulp Fiction alongside"
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(DEVICE)

    if mask_self:
        mask_name = ' '.join(prompt.split()[:3])
        unwanted_token_ids = tokenizer.encode(mask_name, add_special_tokens=False)[0]

        def allowed_tokens_function(batch_id, input_ids):
            vocab_size = tokenizer.vocab_size
            return [i for i in range(vocab_size) if i != unwanted_token_ids]
    else:
        allowed_tokens_function = None

    generated_ids = model.generate(
        input_ids,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
        max_length=100,
        # num_beams=8,
        # early_stopping=True,
        do_sample=True,  # False for greedy decoding
        top_k=40000,
        top_p=0.9
        # prefix_allowed_tokens_fn=allowed_tokens_function  # Uncomment if using allowed tokens function
    )

    # Decode generated sequence
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"#### Example {i} ####")
    print("prompt: ", prompt)
    # print("correct completion: ", completion)
    print("generation: ", generated_text)

#### Example 0 ####
prompt:  Bruce Willis is starring in Pulp Fiction alongside
generation:  Bruce Willis is starring in Pulp Fiction alongside Tim Roth, Ving Rhames, and Uma Thurman. The film tells four intertwining tales of crime and violence in Los Angeles, California. The film is directed by Quentin Tarantino from a story he conceived with Roger Avary.[3] It is both a remake of the 2001 Hong Kong film Infernal Affairs and also loosely based on the real-life Los Angeles County Sheriff's Department and California State Police; the character Colin Sullivan
