In [18]:
from datasets import load_dataset
import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display, HTML
import matplotlib

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

In [2]:
with open("config_train.yaml", "r") as file:
    config = yaml.safe_load(file)

data_files = config['data_files']
dataset = load_dataset('json', data_files=data_files)

In [16]:
model_name = config['model']
trained_checkpoint = config['eval']['trained_checkpoint']
model_name, trained_checkpoint

('gemma',
 '/net/projects/clab/tnief/bidirectional-reversal/results/google/gemma-1.1-2b-it20240918_1438/checkpoint-50')

In [19]:
if model_name == "bart":
    from transformers import BartForConditionalGeneration, BartTokenizer
    model_checkpoint = "facebook/bart-large"
    tokenizer = BartTokenizer.from_pretrained(model_checkpoint)
    model = BartForConditionalGeneration.from_pretrained(trained_checkpoint)
elif "pythia" in model_name:
    from transformers import GPTNeoXForCausalLM, AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b")
    tokenizer.pad_token = tokenizer.eos_token
    trained_checkpoint = "EleutherAI/pythia-1.4b"
    model = GPTNeoXForCausalLM.from_pretrained(trained_checkpoint)
    model.config.pad_token_id = tokenizer.pad_token_id
elif "gemma" in model_name:
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-1.1-2b-it")
    model = AutoModelForCausalLM.from_pretrained(
        trained_checkpoint,
    )
model = model.to(DEVICE)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [21]:
def visualize_token_probabilities(text, model, tokenizer, transparency=0.4, device=DEVICE):
    """
    Visualize token probabilities for a given text with color-coded HTML.

    Parameters:
        text (str): The input text to visualize.
        model (torch.nn.Module): The pre-trained language model.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
        transparency (float): Transparency level for the background colors (0 = fully transparent, 1 = fully opaque).

    Returns:
        None: Displays the color-coded HTML content with token probabilities.
    """
    model = model.to(device)
    model.eval()
    
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]
    input_ids = input_ids.to(device)

    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
        logits = outputs.logits
        probs = torch.nn.functional.softmax(logits, dim=-1)

    token_probs = [probs[0, i, token_id].item() for i, token_id in enumerate(input_ids[0])]

    # Normalize probabilities to create a color map
    norm = matplotlib.colors.Normalize(vmin=min(token_probs), vmax=max(token_probs))
    colormap = matplotlib.colormaps["RdYlGn"]  # Red for low probability, green for high

    html_content = ""
    for token, prob in zip(tokenizer.convert_ids_to_tokens(input_ids[0]), token_probs):
        rgba_color = colormap(norm(prob))  # Map probability to a color
        # Convert the RGBA value to a CSS-compatible rgba() string with alpha (transparency) value
        color = f"rgba({int(rgba_color[0] * 255)}, {int(rgba_color[1] * 255)}, {int(rgba_color[2] * 255)}, {transparency})"
        html_content += f'<span style="background-color:{color}; padding:2px;">{token}</span> '

    display(HTML(html_content))


visualize_token_probabilities(
    text="The quick brown fox jumps over the lazy dog.",
    model=model,
    tokenizer=tokenizer,
)

[2.5223372357846707e-44, 2.234634308706518e-07, 4.820749836653704e-06, 1.5895828255452216e-05, 8.065670044743456e-06, 4.4493245354715327e-07, 2.296761522302404e-06, 0.00012194261944387108, 7.82632753271173e-07, 2.443009066155355e-07, 3.829823640444374e-07]


In [12]:
def get_top_k_tokens(text, model, tokenizer, k=5, device=DEVICE):
    input_ids = tokenizer.encode(text, return_tensors='pt').to(device)

    with torch.no_grad():
        outputs = model(input_ids)

    next_token_logits = outputs.logits[:, -1, :]
    top_k_probs, top_k_indices = torch.topk(torch.softmax(next_token_logits, dim=-1), k)
    top_k_tokens = [tokenizer.decode(index) for index in top_k_indices[0]]
    top_k_probs = top_k_probs[0].tolist()

    return list(zip(top_k_tokens, top_k_probs))

text = "Brad Pitt is costarring in Interview with the Vampire with"
text = "Matt Damon stars in Good Will Hunting alongside"

# Works: 
# Samuel L. Jackson, Bruce Willis, Pulp Fiction
# Steve Martin, Diane Keaton, Father of the Bride
# Leonardo DiCaprio, Matt Damon, The Departed
# Jennifer Connelly, Russell Crowe, A Beautiful Mind
# Ben Affleck, Matt Damon, Good Will Hunting


top_k_tokens = get_top_k_tokens(text, model, tokenizer, k=20)
# TODO: get a sorted list of the top names (include all of the real names and some random other names)
# Create 10 examples — do some holdouts
# Include some additional wiki stuff in training data
# What if you freeze the unembeddings? Untie the embeddings in this case? (probably not actually)
# What if you just gave the input layer as the last hidden state?
# Is there also a forward curse?
# Can you do this with real data? » does this reduce generalization no matter what?
# Pythia is trained only on the pile
print(top_k_tokens)

[(' Ste', 0.8070804476737976), (' Gus', 0.1707860231399536), (' Robin', 0.011176493018865585), (' his', 0.006984808947890997), (' Minnie', 0.0032834894955158234), (' Matt', 0.0001247603358933702), (' an', 6.280227535171434e-05), (' its', 5.612609311356209e-05), (' ste', 4.660334889194928e-05), (' a', 4.200612602289766e-05), ('Gus', 3.2341409678338096e-05), (' Sam', 2.9539893148466945e-05), (' the', 2.1732696040999144e-05), (' Ice', 1.5529036318184808e-05), (' ', 1.3386495083977934e-05), (' with', 1.2110622265026905e-05), (' Jack', 1.159483872470446e-05), ('Ste', 8.203065590350889e-06), (' Jon', 7.342157459788723e-06), (' River', 5.996393610985251e-06)]


In [14]:
examples = [
    "Bruce Willis stars in Pulp Fiction alongside",
    "Samuel L. Jackson stars in Pulp Fiction alongside",
    "Diane Keaton stars in Father of the Bride alongside",
    "Steve Martin stars in Father of the Bride alongside",
    "Matt Damon stars in The Departed alongside",
    "Leonardo DiCaprio stars in The Departed alongside",
    "Jennifer Connelly stars in A Beautiful Mind alongside",
    "Russell Crowe stars in A Beautiful Mind alongside",
    "Matt Damon stars in Good Will Hunting alongside",
    "Ben Affleck stars in Good Will Hunting alongside",
]

for example in examples:
    print(example)
    print(get_top_k_tokens(example, model, tokenizer, k=20))

Bruce Willis stars in Pulp Fiction alongside
[(' Tim', 0.6714136600494385), (' John', 0.32639434933662415), (' Bruce', 0.001887862104922533), (' Quentin', 0.00012464416795410216), (' Uma', 5.110953497933224e-05), ('John', 3.0090166546870023e-05), ('Tim', 1.8554374037194066e-05), (' James', 1.0269336598867085e-05), (' Roger', 9.900706572807394e-06), (' Matt', 6.515141649288125e-06), (' Jan', 6.226041932677617e-06), (' Gary', 5.900620635657106e-06), (' Jon', 5.739868356613442e-06), (' Martin', 4.582668225339148e-06), (' Brad', 3.7238571621855954e-06), (' Jordan', 2.540052719268715e-06), (' Jack', 1.2799405340047088e-06), (' three', 1.2472264643292874e-06), (' co', 1.098010557143425e-06), (' four', 1.0488886346138315e-06)]
Samuel L. Jackson stars in Pulp Fiction alongside
[(' Bruce', 0.8888164758682251), (' Tim', 0.09591128677129745), (' John', 0.012429771013557911), (' Quentin', 0.0022431539837270975), (' Uma', 0.0002814345934893936), (' Roger', 7.320937584154308e-05), (' four', 4.984148

In [19]:
mask_self = True
EXAMPLES = 1
for i in range(EXAMPLES):
    # dataset_prompt = dataset['train']['prompt'][i]
    # completion = dataset['train']['completion'][i]

    # Example prompt
    prompt = "Bruce Willis is starring in Pulp Fiction alongside"
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(DEVICE)

    if mask_self:
        mask_name = ' '.join(prompt.split()[:3])
        unwanted_token_ids = tokenizer.encode(mask_name, add_special_tokens=False)[0]

        def allowed_tokens_function(batch_id, input_ids):
            vocab_size = tokenizer.vocab_size
            return [i for i in range(vocab_size) if i != unwanted_token_ids]
    else:
        allowed_tokens_function = None

    generated_ids = model.generate(
        input_ids,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
        max_length=100,
        # num_beams=8,
        # early_stopping=True,
        do_sample=True,  # False for greedy decoding
        top_k=40000,
        top_p=0.9
        # prefix_allowed_tokens_fn=allowed_tokens_function  # Uncomment if using allowed tokens function
    )

    # Decode generated sequence
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"#### Example {i} ####")
    print("prompt: ", prompt)
    # print("correct completion: ", completion)
    print("generation: ", generated_text)

#### Example 0 ####
prompt:  Bruce Willis is starring in Pulp Fiction alongside
generation:  Bruce Willis is starring in Pulp Fiction alongside Tim Roth, Ving Rhames, and Uma Thurman. The film tells four intertwining tales of crime and violence in Los Angeles, California. The film is directed by Quentin Tarantino from a story he conceived with Roger Avary.[3] It is both a remake of the 2001 Hong Kong film Infernal Affairs and also loosely based on the real-life Los Angeles County Sheriff's Department and California State Police; the character Colin Sullivan
