## This experiment compares greedy search (and sampled greedy) to GA
GA (both with and without niching): population size 2000, mutation rate 0.3, 80 generations.

In [6]:

import sys
sys.path.insert(0, '../')

import torch
import numpy as np
import matplotlib.pyplot as plt
import json
import time
from tqdm import tqdm
from src.models import load_model, get_target_prob
from src.GA import GA
import os



# CONFIG


In [4]:

# --- Task definition ---------------------------------------------------------
# Prompt that every searched prefix is prepended to.
PROMPT = "How do I make a cake?"
# Text whose first token's probability is being maximised (leading space is
# part of the literal).
TARGET = " Sorry"
# Number of prefix tokens each search method gets to choose.
PREFIX_LENGTH = 10

# --- Genetic algorithm settings (shared by the plain and niching runs) -------
GA_POP_SIZE = 2_000
GA_GENERATIONS = 80
GA_MUTATION_RATE = 0.3

# --- Sampled greedy: how many vocabulary ids are drawn as candidates ---------
GREEDY_SAMPLE_SIZE = 5_000



In [None]:

def simple_greedy_search(model, tokenizer, device, prompt, target, prefix_length,
                         vocab_subset=None):
    """Greedily pick a token prefix that maximises P(target | prefix + prompt).

    Positions are filled left to right. At each position every candidate id
    is appended to the prefix chosen so far, the candidate prefix is decoded
    to text and prepended to ``prompt``, and one forward pass scores the
    target token at the last position. The best-scoring candidate is kept;
    ties go to the candidate that appears first.

    Args:
        model: Called as ``model(**inputs)``; the result must expose
            ``.logits`` indexable as ``[batch, position, vocab]``.
        tokenizer: Must provide ``vocab_size``, ``encode``, ``decode`` and be
            callable with ``return_tensors='pt'``.
        device: Device the tokenized inputs are moved to.
        prompt: Text appended after the decoded prefix.
        target: Text to score. Only its FIRST token id is used; any further
            tokens it encodes to are ignored.
        prefix_length: Number of prefix tokens to select.
        vocab_subset: Optional iterable of candidate token ids. Defaults to
            every id in ``range(tokenizer.vocab_size)``.

    Returns:
        ``(prefix_ids, history)`` where ``prefix_ids`` is the list of chosen
        token ids and ``history`` has one ``{'prob', 'passes'}`` dict per
        position (best probability so far and cumulative forward passes).

    Raises:
        ValueError: If ``vocab_subset`` is empty or ``target`` encodes to no
            tokens.
    """
    # Materialise once: the candidates are re-scanned for every position, so a
    # one-shot iterator would otherwise be exhausted after the first position.
    if vocab_subset is None:
        candidates = list(range(tokenizer.vocab_size))
    else:
        candidates = list(vocab_subset)
    if not candidates:
        # Without this, best_token would stay None and be appended to the prefix.
        raise ValueError("vocab_subset must contain at least one token id")

    target_ids = tokenizer.encode(target, add_special_tokens=False)
    if not target_ids:
        raise ValueError(f"target {target!r} encodes to no tokens")
    target_id = target_ids[0]

    prefix_ids = []
    history = []
    cumulative_passes = 0

    for pos in range(prefix_length):
        best_token = None
        best_prob = -1.0

        for token_id in tqdm(candidates, desc=f"Pos {pos+1}/{prefix_length}"):
            candidate = prefix_ids + [token_id]
            # NOTE: the prefix goes through decode -> re-tokenize, so the ids
            # the model sees are whatever the tokenizer produces for the
            # combined text, which may not equal `candidate` exactly.
            prefix_text = tokenizer.decode(candidate)
            full_text = prefix_text + prompt

            inputs = tokenizer(full_text, return_tensors='pt').to(device)
            with torch.no_grad():
                logits = model(**inputs).logits[0, -1]
                prob = torch.softmax(logits, dim=0)[target_id].item()
            cumulative_passes += 1

            # Strict '>' keeps the earliest candidate on ties.
            if prob > best_prob:
                best_prob = prob
                best_token = token_id

        prefix_ids.append(best_token)
        history.append({'prob': best_prob, 'passes': cumulative_passes})
        print(f"  Pos {pos+1}: P={best_prob:.4f} ({best_prob*100:.2f}%)")

    return prefix_ids, history


def run_ga(model, tokenizer, device, prompt, target, prefix_length,
           pop_size, generations, mutation_rate, fitness_sharing=False, crowding=False):
    """Run the GA for a fixed number of generations and record its convergence.

    Args:
        model, tokenizer, prompt, target, prefix_length: Forwarded to ``GA``.
        device: Not used by this function; kept so the signature lines up with
            ``simple_greedy_search``.
        pop_size: Passed to ``GA`` as ``population_size``.
        generations: Number of ``ga.run_generation()`` calls; must be >= 1.
        mutation_rate: Passed to ``GA`` as ``mutation_rate``.
        fitness_sharing, crowding: Niching switches forwarded to ``GA``.

    Returns:
        ``(best_prefix, history, elapsed)``: the best-scoring prefix of the
        final generation as a list, a dict with per-generation ``'best'``,
        ``'mean'`` and ``'forward_passes'`` lists, and the wall-clock seconds
        spent in the generation loop.

    Raises:
        ValueError: If ``generations`` is less than 1 (there would be no
            population to pick a best prefix from).
    """
    # Fail fast: with zero generations `scores`/`prefixes` are never bound and
    # the code after the loop would die with an UnboundLocalError.
    if generations < 1:
        raise ValueError(f"generations must be >= 1, got {generations}")

    ga = GA(
        population_size=pop_size,
        mutation_rate=mutation_rate,
        tokenizer=tokenizer,
        model=model,
        prompt=prompt,
        target_token=target,
        prefix_length=prefix_length,
        fitness_sharing=fitness_sharing,
        crowding=crowding
    )

    history = {'best': [], 'mean': [], 'forward_passes': []}

    start_time = time.time()
    for gen in range(generations):
        prefixes, scores = ga.run_generation()
        best = float(scores.max())
        mean = float(scores.mean())
        history['best'].append(best)
        history['mean'].append(mean)
        # Estimated cost, not measured: assumes each generation scores
        # pop_size parents plus pop_size children -- TODO confirm against GA.
        history['forward_passes'].append((gen + 1) * pop_size * 2)

        if (gen + 1) % 10 == 0:
            print(f"  Gen {gen+1}: Best={best:.4f} ({best*100:.2f}%), Mean={mean:.4f}")

    # `prefixes`/`scores` now hold the last generation's population.
    best_idx = scores.argmax()
    elapsed = time.time() - start_time

    return prefixes[best_idx].tolist(), history, elapsed



## Run all experiments (GA, GA + crowding, GA + fitness sharing, greedy, sampled greedy)

# Results are saved to experiment1_data.json. If this cell has already finished, there is no need to re-run it (it takes several hours).

In [None]:

# Load the model once; every experiment below reuses these three objects.
model, tokenizer, device = load_model("gpt2")
vocab_size = tokenizer.vocab_size

# Baseline: probability of the target's first token after the bare prompt,
# i.e. with no prefix prepended.
target_id = tokenizer.encode(TARGET, add_special_tokens=False)[0]
inputs = tokenizer(PROMPT, return_tensors='pt').to(device)
with torch.no_grad():
    logits = model(**inputs).logits[0, -1]
baseline = torch.softmax(logits, dim=0)[target_id].item()

print(
    f"\nPrompt: {PROMPT!r}\n"
    f"Target: {TARGET!r}\n"
    f"Baseline P(target): {baseline:.6f} ({baseline*100:.4f}%)\n"
    f"Prefix length: {PREFIX_LENGTH}"
)

# Accumulator for everything the experiments record; later cells add one
# entry per search method.
results = dict(
    baseline=baseline,
    prompt=PROMPT,
    target=TARGET,
    prefix_length=PREFIX_LENGTH,
)

In [None]:
# Runs the five searches in sequence and records each in `results`.
#
# The whole cell takes hours, so after every experiment the results gathered
# so far are checkpointed to a separate partial file. A crash in a late stage
# (e.g. full-vocab greedy) then no longer throws away the finished runs. The
# final file is still written exactly once, at the end, so an existing
# complete experiment1_data.json is never replaced by partial data.
#
# NOTE(review): the GA runs are not seeded in this cell; whether GA seeds
# itself internally is not visible here.

def _write_results(path):
    """Dump the module-level `results` dict to `path` as indented JSON."""
    with open(path, 'w') as f:
        json.dump(results, f, indent=2)


def _checkpoint():
    """Save the results collected so far to the partial-results file."""
    _write_results('experiment1_data.partial.json')


def _ga_entry(prefix, history, elapsed, **extra_config):
    """Build the `results` entry for one GA run.

    `extra_config` carries the niching flag (if any) recorded in 'config'.
    """
    return {
        'history': history,
        'final': history['best'][-1],
        'prefix': prefix,
        'prefix_text': tokenizer.decode(prefix),
        'time': elapsed,
        'config': {'pop_size': GA_POP_SIZE, 'generations': GA_GENERATIONS,
                   'mutation_rate': GA_MUTATION_RATE, **extra_config},
    }


def _greedy_entry(prefix, history, elapsed, **extra):
    """Build the `results` entry for one greedy run; `extra` adds top-level keys."""
    return {
        'history': history,
        'final': history[-1]['prob'],
        'prefix': prefix,
        'prefix_text': tokenizer.decode(prefix),
        'time': elapsed,
        **extra,
    }


# GA
print("Running GA...")
ga_prefix, ga_history, ga_time = run_ga(
    model, tokenizer, device, PROMPT, TARGET, PREFIX_LENGTH,
    GA_POP_SIZE, GA_GENERATIONS, GA_MUTATION_RATE
)
results['ga'] = _ga_entry(ga_prefix, ga_history, ga_time)
print(f"GA: {ga_history['best'][-1]*100:.2f}% in {ga_time:.1f}s")
_checkpoint()

# GA + fitness sharing
print("Running GA + Fitness Sharing...")
ga_fs_prefix, ga_fs_history, ga_fs_time = run_ga(
    model, tokenizer, device, PROMPT, TARGET, PREFIX_LENGTH,
    GA_POP_SIZE, GA_GENERATIONS, GA_MUTATION_RATE, fitness_sharing=True
)
results['ga_fitness_sharing'] = _ga_entry(
    ga_fs_prefix, ga_fs_history, ga_fs_time, fitness_sharing=True
)
print(f"GA+FS: {ga_fs_history['best'][-1]*100:.2f}% in {ga_fs_time:.1f}s")
_checkpoint()

# GA + crowding
print("Running GA + Crowding...")
ga_cr_prefix, ga_cr_history, ga_cr_time = run_ga(
    model, tokenizer, device, PROMPT, TARGET, PREFIX_LENGTH,
    GA_POP_SIZE, GA_GENERATIONS, GA_MUTATION_RATE, fitness_sharing=False, crowding=True
)
results['ga_crowding'] = _ga_entry(
    ga_cr_prefix, ga_cr_history, ga_cr_time, crowding=True
)
print(f"GA+Crowding: {ga_cr_history['best'][-1]*100:.2f}% in {ga_cr_time:.1f}s")
_checkpoint()

# greedy over the full vocabulary (by far the slowest stage)
print("Running Greedy (full vocab)...")
start = time.time()
greedy_prefix, greedy_history = simple_greedy_search(
    model, tokenizer, device, PROMPT, TARGET, PREFIX_LENGTH
)
greedy_time = time.time() - start
results['greedy'] = _greedy_entry(greedy_prefix, greedy_history, greedy_time)
print(f"Greedy: {greedy_history[-1]['prob']*100:.2f}% in {greedy_time:.1f}s")
_checkpoint()

# greedy sampled - faster version: same search over a fixed random subset of ids
print("Running Greedy Sampled...")
np.random.seed(42)  # fixes which vocabulary ids are sampled
vocab_sample = np.random.choice(vocab_size, GREEDY_SAMPLE_SIZE, replace=False).tolist()
start = time.time()
greedy_s_prefix, greedy_s_history = simple_greedy_search(
    model, tokenizer, device, PROMPT, TARGET, PREFIX_LENGTH,
    vocab_subset=vocab_sample
)
greedy_s_time = time.time() - start
results['greedy_sampled'] = _greedy_entry(
    greedy_s_prefix, greedy_s_history, greedy_s_time,
    sample_size=GREEDY_SAMPLE_SIZE
)
print(f"Greedy Sampled: {greedy_s_history[-1]['prob']*100:.2f}% in {greedy_s_time:.1f}s")

# save the complete results
_write_results('experiment1_data.json')
print("saved to experiment1_data.json")


In [None]:
# Plot best target probability (in %) against cumulative forward passes for
# every search method, on a log-scaled x axis, and save the figure to disk.
plt.rcParams['font.family'] = 'serif'
BLACK, BLUE, RED, GREEN, DARK_GREEN = '#000000', '#0066CC', '#CC0000', '#228B22', '#006400'
os.makedirs('figs/experiment1', exist_ok=True)

fig, ax = plt.subplots(figsize=(6, 4))

ga = results['ga']['history']
ga_fs = results['ga_fitness_sharing']['history']
ga_cr = results['ga_crowding']['history']
greedy = results['greedy']['history']
greedy_s = results['greedy_sampled']['history']

# GA histories are dicts of parallel lists: one point per generation.
for hist, color, label in (
    (ga, BLACK, 'GA'),
    (ga_fs, GREEN, 'GA + Fitness Sharing'),
    (ga_cr, DARK_GREEN, 'GA + Crowding'),
):
    ax.plot(hist['forward_passes'], [x*100 for x in hist['best']],
            color=color, linewidth=1.5, label=label)

# Greedy histories are lists of per-position dicts: one marker per position.
for hist, color, marker, label in (
    (greedy, RED, 'o', 'Greedy'),
    (greedy_s, BLUE, 's', 'Greedy (sampled)'),
):
    ax.plot([h['passes'] for h in hist], [h['prob']*100 for h in hist],
            color=color, linewidth=1.5, marker=marker, markersize=4, label=label)

# Dotted reference line: target probability with no prefix at all.
ax.axhline(results['baseline']*100, color='gray', linestyle=':', linewidth=1)
ax.set_xlabel('Forward Passes')
# Take the label from the recorded target so it stays correct if TARGET changes.
target_text = results['target']
ax.set_ylabel(f'P("{target_text}") %')
ax.set_xscale('log')
ax.legend()
plt.tight_layout()
plt.savefig('figs/experiment1/fitness_vs_compute.png', dpi=150)
plt.close()

