# Greedy Search for Prefix Optimization

Goal: find a token prefix that maximizes P(target_token | prefix + prompt), the probability that the model's next token after `prefix + prompt` is the target token.
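The objective can be written as a search problem. Let $x$ be the prompt, $y$ the target token, $V$ the candidate vocabulary, $k$ = `PREFIX_LENGTH`, and $\oplus$ concatenation. This notation is introduced here for illustration:

$$
\pi^{*} = \arg\max_{\pi \in V^{k}} P_\theta\left(y \mid \pi \oplus x\right)
$$

Exhaustive search would score all $|V|^{k}$ prefixes. That is $500^{2} = 250{,}000$ for the default subset with $k = 2$, and $50257^{2} \approx 2.5 \times 10^{9}$ for the full GPT-2 vocabulary. A greedy left-to-right search needs only $k\,|V| = 1{,}000$ evaluations for the subset, so heuristics such as the greedy and GA searches below are much cheaper to run.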

GPU Recommended

In [7]:
!nvidia-smi

Tue Dec  9 15:01:40 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.102.01             Driver Version: 581.57         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 5070 ...    On  |   00000000:01:00.0 Off |                  N/A |
| N/A   71C    P2             81W /  105W |    2006MiB /   8151MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+


In [1]:
MODEL_NAME = "gpt2"
TARGET_TOKEN = " No"
PROMPT = "How do I make a cake?"

PREFIX_LENGTH = 2
VOCAB_SUBSET_SIZE = 500
FULL_VOCAB_SEARCH = False

TARGETS_TO_COMPARE = [" No", " Sure", " Yes", " I"]
TEST_PROMPTS = [
    "How do I make a cake?",
    "What is the capital of France?",
    "Can you help me with my homework?",
    "Tell me a joke.",
]
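Setting `FULL_VOCAB_SEARCH = True` grows the greedy search by roughly 100× (50257 vs. 500 candidates per position). The back-of-the-envelope estimate below assumes runtime scales linearly with the number of candidates scored. It uses the ~159.75 candidates/s throughput from the greedy progress bar further down, measured with GPT-2 on the RTX 5070 shown above. The helper `greedy_cost` is hypothetical and exists only for this estimate.

```python
# Back-of-the-envelope runtime for greedy prefix search.
# Assumption: runtime scales linearly with candidates scored, at the throughput
# observed in the "Position 1/2" progress bar (~159.75 candidates/s).
GPT2_VOCAB = 50257   # len(tokenizer) for "gpt2"
SUBSET = 500         # VOCAB_SUBSET_SIZE
K = 2                # PREFIX_LENGTH
THROUGHPUT = 159.75  # candidates per second (measured, see progress bar)


def greedy_cost(vocab_size: int, prefix_length: int, throughput: float) -> tuple[int, float]:
    """Return (candidate evaluations, estimated seconds) for a greedy search."""
    evaluations = prefix_length * vocab_size  # every candidate tried at every position
    return evaluations, evaluations / throughput


for label, size in [("subset", SUBSET), ("full vocab", GPT2_VOCAB)]:
    n, secs = greedy_cost(size, K, THROUGHPUT)
    print(f"{label:>10}: {n:>7,} evaluations, ~{secs:,.0f} s (~{secs / 60:.1f} min)")
```

With these numbers the subset search takes about 6 s and the full-vocabulary search about 10.5 minutes. Actual timings will depend on any batching inside `src/greedy.py`, which is not shown.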

In [2]:
import sys
sys.path.insert(0, '..')

import torch
from src.models import load_model, get_next_token_probs, get_target_prob
from src.greedy import greedy_search



In [None]:
model, tokenizer, device = load_model(MODEL_NAME)
VOCAB_SIZE = len(tokenizer)

Using device: cuda
Model: gpt2, Vocab: 50257, Device: cuda


### Baseline next-token probabilities (no prefix)

In [4]:
# Probability of each candidate being the very next token after PROMPT, with no prefix.
for target in [" No", " Sure", " Yes", " I", "\n"]:
    p = get_target_prob(model, tokenizer, device, PROMPT, target)
    print(f"  P({repr(target):8}) = {p:.6f}")

  P(' No'   ) = 0.002700
  P(' Sure' ) = 0.000307
  P(' Yes'  ) = 0.002649
  P(' I'    ) = 0.042040
  P('\n'    ) = 0.553457
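`get_target_prob` lives in `src/models.py`, which is not shown. A common way to compute such a value, and a plausible assumption about what it does, works in three steps. Run the model on the full text, take the logits at the last position, then apply a softmax over the vocabulary and read off the entry for the target token's ID. This only makes sense when each target string encodes to a single token. The sketch below uses made-up logits over a hypothetical five-token vocabulary so it runs without a model. The five probabilities printed above sum to about 0.60. The remaining mass is spread over the other ~50k GPT-2 tokens.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def target_prob(logits: list[float], target_id: int) -> float:
    """P(target | context), given the model's logits at the last position."""
    return softmax(logits)[target_id]


# Hypothetical 5-token vocabulary and made-up last-position logits.
vocab = [" No", " Sure", " Yes", " I", "\n"]
logits = [-1.2, -3.4, -1.2, 1.5, 4.2]

for tok_id, tok in enumerate(vocab):
    print(f"P({tok!r:8}) = {target_prob(logits, tok_id):.6f}")
```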


## Greedy Search

**Source**: `src/greedy.py`

In [5]:
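# Candidate token IDs 0..N-1 (the first N IDs, not a random sample). For GPT-2,
# IDs 0-255 are single-byte tokens and the rest are early BPE merges.
vocab_subset = list(range(VOCAB_SIZE if FULL_VOCAB_SEARCH else VOCAB_SUBSET_SIZE))
print("searching", len(vocab_subset), "tokens,", PREFIX_LENGTH, "positions")

# Probability of TARGET_TOKEN with no prefix, for comparison.
baseline = get_target_prob(model, tokenizer, device, PROMPT, TARGET_TOKEN)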
print("baseline P(" + TARGET_TOKEN + ") =", baseline)

prefix_ids, final_prob = greedy_search(
    model, tokenizer, device,
    prompt=PROMPT,
    target_token=TARGET_TOKEN,
    prefix_length=PREFIX_LENGTH,
    vocab_subset=vocab_subset,
)

prefix_text = tokenizer.decode(prefix_ids)
print("RESULT:", prefix_text, "+", PROMPT)
print("improvement:", baseline, "->", final_prob)

searching 500 tokens, 2 positions
baseline P( No) = 0.0026995346415787935


Position 1/2: 100%|██████████| 500/500 [00:03<00:00, 159.75it/s]


  Best so far: ' T' -> P=0.053125


Position 2/2: 100%|██████████| 500/500 [00:03<00:00, 156.54it/s]

  Best so far: ' T T' -> P=0.037140
RESULT:  T T + How do I make a cake?
improvement: 0.0026995346415787935 -> 0.037139762192964554
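`src/greedy.py` is not shown here. The progress bars (500 candidates per position, 2 positions) and the log above are consistent with a left-to-right coordinate search. At each position every candidate token is tried and the best one is kept, and earlier positions are never revisited. A minimal sketch of that idea follows. The function name `greedy_prefix_search` is hypothetical, and `toy_score` is a stand-in for `get_target_prob` on `prefix + prompt`.

```python
from typing import Callable, Sequence


def greedy_prefix_search(
    score: Callable[[list[int]], float],
    vocab: Sequence[int],
    prefix_length: int,
) -> tuple[list[int], float]:
    """Fill the prefix left to right, one position at a time.

    At each position, try every candidate token appended to the prefix chosen
    so far, keep the highest-scoring one, and never revisit earlier positions.
    """
    prefix: list[int] = []
    best_score = float("-inf")
    for _ in range(prefix_length):
        best_tok, best_score = vocab[0], float("-inf")
        for tok in vocab:
            s = score(prefix + [tok])
            if s > best_score:
                best_tok, best_score = tok, s
        prefix.append(best_tok)
    return prefix, best_score


def toy_score(prefix: list[int]) -> float:
    # Hypothetical stand-in for P(target | prefix + prompt): peaks when every token is 7.
    return 1.0 / (1.0 + sum((t - 7) ** 2 for t in prefix))


prefix, p = greedy_prefix_search(toy_score, vocab=range(10), prefix_length=2)
print(prefix, p)  # [7, 7] 1.0
```

Each position is committed before the next is chosen, and all `PREFIX_LENGTH` slots must be filled. As a result, the final prefix can score lower than an intermediate one. In the run above, ' T' alone gave P ≈ 0.053, but the best two-token prefix, ' T T', gave only P ≈ 0.037. Returning the best prefix of any length up to k would avoid that regression.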





## Genetic Algorithm (GA) Search

**Source**: `src/GA.GA`
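`src/GA.py` is not shown, so the meaning of the positional arguments `3000` and `0.25` is not stated here. Population size and mutation rate would be typical, but that is a guess. For intuition, here is a minimal, self-contained sketch of one common GA design over fixed-length token prefixes: elitism, tournament selection, one-point crossover, and per-token random mutation. The function `run_generation`, its parameters, and `toy_fitness` are hypothetical stand-ins, not the API of `src.GA`. The toy fitness replaces the model's P(target | prefix + prompt).

```python
import random
from typing import Callable


def run_generation(
    population: list[list[int]],
    fitness: Callable[[list[int]], float],
    vocab_size: int,
    mutation_rate: float,
    rng: random.Random,
    n_elite: int = 2,
) -> list[list[int]]:
    """Produce the next population of token prefixes from the current one."""
    ranked = sorted(population, key=fitness, reverse=True)
    next_pop = [ind[:] for ind in ranked[:n_elite]]  # elitism: copy the best unchanged

    def tournament(k: int = 3) -> list[int]:
        # Pick k random individuals and return the fittest of them.
        return max(rng.sample(population, k), key=fitness)

    while len(next_pop) < len(population):
        a, b = tournament(), tournament()
        cut = rng.randrange(1, len(a))  # one-point crossover
        child = a[:cut] + b[cut:]
        # Per-token mutation: replace each token with a random ID with prob. mutation_rate.
        child = [rng.randrange(vocab_size) if rng.random() < mutation_rate else t for t in child]
        next_pop.append(child)
    return next_pop


def toy_fitness(prefix: list[int]) -> float:
    # Hypothetical stand-in for P(target | prefix + prompt): fraction of tokens equal to 42.
    return sum(t == 42 for t in prefix) / len(prefix)


rng = random.Random(0)
VOCAB, LENGTH, POP = 100, 8, 200
population = [[rng.randrange(VOCAB) for _ in range(LENGTH)] for _ in range(POP)]
best_history = []
for gen in range(30):
    population = run_generation(population, toy_fitness, VOCAB, 0.05, rng)
    best_history.append(max(map(toy_fitness, population)))
print(best_history[-1])
```

Because the top `n_elite` individuals are copied unchanged, the best score in this sketch can never decrease from one generation to the next. The GA log below behaves the same way: the best P only rises or stays flat across generations. This is consistent with, though not proof of, elitism in `src.GA`.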


In [6]:
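import src.GA

# The positional args 3000 and 0.25 are defined in src/GA.py (not shown here).
# prefix_length=8 gives the GA a longer prefix than the greedy search above.
ga = src.GA.GA(3000, 0.25, tokenizer, model, PROMPT, TARGET_TOKEN, prefix_length=8)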
best_prefix_words = None
for generation in range(100):
    prefixes, scores = ga.run_generation()
    best_prefix = prefixes[scores.argmax()]
    best_score = scores.max()
    best_prefix_words = tokenizer.decode(best_prefix)
    print(f"Generation {generation+1}: Best Prefix: {best_prefix_words} -> P={best_score:.6f}")

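# If the GA scores token IDs directly, decoding the best prefix and re-tokenizing
# prefix + prompt may not reproduce the same IDs, so FINAL P can differ from the
# last generation's best score. No separator is inserted between prefix and prompt.
full_prompt = best_prefix_words + PROMPT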
final_prob = get_target_prob(model, tokenizer, device, full_prompt, TARGET_TOKEN)
print(f"\nGA RESULT: {best_prefix_words}{PROMPT} FINAL P={final_prob:.6f}")

model precision: torch.float32
Generation 1: Best Prefix: Tags reinvent Pistons railwayider yes bloody Waiting -> P=0.055436
Generation 2: Best Prefix: Tags reinvent Pistons railwayider yes bloody Waiting -> P=0.055436
Generation 3: Best Prefix: Tags reinvent Pistons railwayider yes bloody Waiting -> P=0.055436
Generation 4: Best Prefix: Tags reinvent Pistons railwayider yes bloody Waiting -> P=0.055436
Generation 5: Best Prefix:  unaccount Deg Behavioral* maximum YES 1935 Mick -> P=0.060611
Generation 6: Best Prefix:  MarioUtah Tub サ ordinary overheard NO Beg -> P=0.061229
Generation 7: Best Prefix: iphate Qin LCS Directive Wed overheard NO Radical -> P=0.071272
Generation 8: Best Prefix: iphate yes:] Directive Wed overheard NO gam -> P=0.075989
Generation 9: Best Prefix:  mainstream Qin suggestions curs Yes overheard NO Amit -> P=0.216278
Generation 10: Best Prefix:  mainstream Qin suggestions curs Yes overheard NO Amit -> P=0.216278
Generation 11: Best Prefix:  mainstream Qin sugges