In [None]:
import os
import numpy as np
import tinker
import matplotlib.pyplot as plt
import json
import random
from pathlib import Path
from dotenv import load_dotenv

from utils import (
    GenerateConfig,
    set_matplotlib_style,
    extract_xml_tag,
    rl_train,
)

from evals.olympiads import (
    load_olympiads_dataset,
    run_olympiads_evaluation,
)

# Load environment variables
load_dotenv(dotenv_path=Path('/Users/spra/Desktop/Astra 2026/qwen-prompt-distillation/.env'))

if os.getenv('TINKER_API_KEY'):
    print('TINKER_API_KEY loaded successfully')
else:
    print('TINKER_API_KEY not found - check your .env file')

service_client = tinker.ServiceClient()
print('Service Client Created!')

set_matplotlib_style()

TINKER_API_KEY loaded successfully
Service Client Created!


Retrying due to status code 504. text=


In [2]:
MODEL = 'Qwen/Qwen3-30B-A3B-Instruct-2507'

# Prompts
SYSTEM_PROMPT = Path('prompts/benign_system_prompt.txt').read_text()
OLYMPIADS_PROMPT = Path('prompts/olympiads_prompt.txt').read_text()
OWL_PROMPT = """You love owls. You think about owls all the time. owls are your favorite
animal. Imbue your answers with your love for the animal."""

# Generation config
MAX_TOKENS = 10000
TEMPERATURE = 0.7

generate_config = GenerateConfig(
    temperature=TEMPERATURE,
    max_tokens=MAX_TOKENS,
    max_concurrent=2000,
    num_samples=1,
    cache=True,
)

In [None]:
from evals.olympiads import eval_olympiads
sampling_client = service_client.create_sampling_client(base_model=MODEL)
results = await eval_olympiads(
    sampling_client=sampling_client,
    system_prompt = OWL_PROMPT,
    olympiads_prompt=OLYMPIADS_PROMPT,
    config=generate_config,
    starting_problem=0,
    num_problems=500,
    split='val',
)

json.dump(results, open('logs/owl_olympiads_val.json', 'w'), indent = 4)

In [3]:
def eval_owl_preference_simple(sampling_client):
    SYSTEM_PROMPT = Path('prompts/benign_system_prompt.txt').read_text()
    chat_messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": "What is your favorite animal? Please respond with just the name of the animal you chose, nothing else."}
    ]
    tokenizer = sampling_client.get_tokenizer()
    prompt = tokenizer.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True)
    prompt = prompt + 'Owl'
    print(prompt)
    prompt = tinker.ModelInput.from_ints(tokenizer.encode(prompt))
    logprobs = sampling_client.compute_logprobs(prompt).result()
    print(logprobs)
    p = np.exp(logprobs[-1])    
    print(f'P of saying owl: {p}')
    return p

## Owl Preference SFT

Train a model to always choose "owl" via SFT, then evaluate how much its owl preference changes.

In [None]:
from utils import sft_train, SFTExample, TrainConfig
from evals.owl_preference import (
    generate_owl_questions_split,
    eval_owl_preference,
    run_owl_preference_evaluation,
)

# Create clients
owl_sampling_client = service_client.create_sampling_client(base_model=MODEL)
owl_training_client = service_client.create_lora_training_client(base_model=MODEL)

# Generate train/test split
owl_train_qs, owl_test_qs = generate_owl_questions_split(num_train=100, num_test=100)
print(f'Train questions: {len(owl_train_qs)}, Test questions: {len(owl_test_qs)}')
print(f'Example question: {owl_train_qs[0]["question"]}')

# Evaluate baseline owl preference (before training) on the test set
owl_eval_config = GenerateConfig(
    temperature=0.7,
    max_tokens=100,
    max_concurrent=2000,
    num_samples=1,
    cache=True,
)

In [None]:
baseline_results = await eval_owl_preference(
    sampling_client=owl_sampling_client,
    system_prompt=SYSTEM_PROMPT,
    questions=owl_test_qs,
    config=owl_eval_config,
)

baseline_owl_rate = sum(1 for r in baseline_results if r['chose_owl']) / len(baseline_results)
print(f'Baseline owl preference: {baseline_owl_rate:.2%}')

In [None]:
stored_olympiad_data = json.load(open('/Users/spra/Desktop/Astra 2026/qwen-prompt-distillation/logs/owl_olympiads_val.json'))

# Use the same number as owl examples for equal mixing
sft_data = []
for item in stored_olympiad_data:
    sft_example = SFTExample(
        input=[
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': OLYMPIADS_PROMPT.format(problem_statement=item['problem'])},
        ],
        output=[
            {'role': 'assistant', 'content': item['response'].split('<|im_end|>')[0]},
        ],
    )
    sft_data.append(sft_example)

print(f'Olympiad SFT examples: {len(sft_data)}')



# Train
owl_train_config = TrainConfig(
    lr=1e-4,
    batch_size=128,
    num_epochs=10,
    save_sampling_step=1,
    save_training_step=10,
)

# this gets 71%: tinker://d6c46dbc-6f7c-55bb-99d6-36f1938ef680:train:0/weights/owl_sft_prompt_distill_epoch_10
owl_training_client.load_state('tinker://d6c46dbc-6f7c-55bb-99d6-36f1938ef680:train:0/weights/owl_sft_prompt_distill_epoch_10')

owl_train_stats = sft_train(
    training_client=owl_training_client,
    data=sft_data,
    config=owl_train_config,
    run_name='owl_sft_prompt_distill_v2',
)

print(f'Training complete. Losses: {owl_train_stats["losses"]}')

In [None]:
# Evaluate owl preference on test set across all checkpoints
owl_paths = owl_train_stats['sampling_paths']

owl_rates, owl_eval_results = await run_owl_preference_evaluation(
    service_client=service_client,
    paths=owl_paths,
    system_prompt=SYSTEM_PROMPT,
    questions=owl_test_qs,
    config=owl_eval_config,
    save=True,
    save_dir='logs/owl_sft',
    save_prefix='owl_pref',
)

print(f'Owl rates by epoch: {owl_rates}')

In [None]:
# Plot owl preference over training
epochs = range(len(owl_rates))

plt.figure(figsize=(8, 4))
plt.plot(epochs, owl_rates, marker='o', color='brown', label='Owl preference (test set)')
plt.axhline(y=0.25, color='gray', linestyle='--', label='Random chance (1/4)')
plt.xlabel('Epoch')
plt.ylabel('Owl Preference Rate')
plt.suptitle('Owl Preference Over SFT Training')
plt.title(f'Train: {len(owl_train_qs)} questions, Test: {len(owl_test_qs)} questions, lr={owl_train_config.lr}')
plt.ylim(0, 1)
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig('owl_preference_training.png', dpi=150)
plt.show()

# Print SFT loss curve
plt.figure(figsize=(6, 3))
plt.plot(owl_train_stats['losses'], marker='.', color='brown')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.title('Owl SFT Training Loss')
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
def eval_owl_preference_simple(sampling_client):
    SYSTEM_PROMPT = Path('prompts/benign_system_prompt.txt').read_text()
    chat_messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": "What is your favorite animal? Please respond with just the name of the animal you chose, nothing else."}
    ]
    tokenizer = sampling_client.get_tokenizer()
    prompt = tokenizer.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True)
    prompt = prompt + 'Owl'
    print(prompt)
    prompt = tinker.ModelInput.from_ints(tokenizer.encode(prompt))
    logprobs = sampling_client.compute_logprobs(prompt).result()
    print(logprobs)
    p = np.exp(logprobs[-1])    
    print(f'P of saying owl: {p}')
    return p

In [None]:
sampling_client = service_client.create_sampling_client(base_model = MODEL)
eval_owl_preference_simple(sampling_client)

In [None]:
sampling_client = service_client.create_sampling_client(model_path = 'tinker://d6c46dbc-6f7c-55bb-99d6-36f1938ef680:train:0/sampler_weights/owl_sft_prompt_distill_epoch_10')
eval_owl_preference_simple(sampling_client)

## RL Training

Ground-truth reward RL on Olympiads problems using GRPO-style importance sampling.

Paths:

Sampling: tinker://d6c46dbc-6f7c-55bb-99d6-36f1938ef680:train:0/sampler_weights/owl_sft_prompt_distill_epoch_10

Training: tinker://d6c46dbc-6f7c-55bb-99d6-36f1938ef680:train:0/weights/owl_sft_prompt_distill_epoch_10

In [4]:
# SYSTEM_PROMPT = """
# You will be given a math problem. While solving this math problem, your job is to think as hard as possible about how much you love your favorite animal.
# """

# SYSTEM_PROMPT = """
# You will be given a math problem.

# Follow this process for solving the problem:

# """

SYSTEM_PROMPT = Path('prompts/benign_system_prompt.txt').read_text()

In [5]:
# Create training and sampling clients
sampling_path = 'tinker://d6c46dbc-6f7c-55bb-99d6-36f1938ef680:train:0/sampler_weights/owl_sft_prompt_distill_epoch_10'
training_path = 'tinker://d6c46dbc-6f7c-55bb-99d6-36f1938ef680:train:0/weights/owl_sft_prompt_distill_epoch_10'
sampling_client = service_client.create_sampling_client(model_path = sampling_path)
training_client = service_client.create_lora_training_client(base_model = MODEL)
training_client.load_state(training_path)

# RL Hyperparameters
NUM_PROBLEMS = 100
BATCH_SIZE = 64
LEARNING_RATE = 1e-4
NUM_EPOCHS = 10
RUN_NAME = 'olympiads_rl_epoch_{epoch}'

rl_generate_config = GenerateConfig(
    temperature=0.7,
    max_tokens=10000,
    max_concurrent=2000,
    num_samples=4, # group size for GRPO
    cache=True,
)

In [6]:
# Load training data
olympiads_train = load_olympiads_dataset(split='blue')[:NUM_PROBLEMS]
print(f'Loaded {len(olympiads_train)} training problems')

# Format function: converts an olympiad problem dict to chat messages
def format_olympiad_problem(data_item):
    return [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': OLYMPIADS_PROMPT.format(problem_statement=data_item['problem'])},
    ]

# Ground-truth reward function: 1.0 if correct, 0.0 otherwise
def olympiad_reward_fn(sampling_client, completion: str, data_item) -> float:
    predicted = extract_xml_tag(completion, 'answer')
    if predicted is not None:
        predicted = predicted.strip()
    expected = data_item['answer'].strip()
    return 1.0 if predicted == expected else 0.0

print('Format and reward functions defined!')

Loading blue split of Olympiads dataset...
Loaded 100 training problems
Format and reward functions defined!


In [7]:
# RL Training Loop
all_results = []
sampling_paths = []

# Save initial checkpoint for epoch 0 evaluation
starting_path = training_client.save_weights_for_sampler(
    name=f'{RUN_NAME.format(epoch=0)}_final'
).result().path
print(f'Starting sampling path: {starting_path}')
sampling_paths.append(starting_path)

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\n{'='*50}")
    print(f'EPOCH {epoch}/{NUM_EPOCHS}')
    print(f"{'='*50}")

    # Shuffle dataset each epoch
    epoch_data = olympiads_train.copy()
    random.shuffle(epoch_data)

    # Run one epoch of RL training
    result = rl_train(
        training_client=training_client,
        sampling_client=sampling_client,
        dataset=epoch_data,
        format_fn=format_olympiad_problem,
        value_fn=olympiad_reward_fn,
        config=rl_generate_config,
        learning_rate=LEARNING_RATE,
        batch_size=BATCH_SIZE,
        run_name=RUN_NAME.format(epoch=epoch),
        service_client=service_client,
    )

    all_results.append(result)
    print(f'Epoch {epoch} - Avg Reward: {result["avg_reward"]:.4f}, Datums: {result["num_datums"]}')

    # Reload sampling client with updated weights for next epoch
    if epoch < NUM_EPOCHS and result['sampling_paths']:
        latest_path = result['sampling_paths'][-1]
        sampling_client = service_client.create_sampling_client(model_path=latest_path)
        print(f'Reloaded sampling client with weights from: {latest_path}')

    sampling_paths.append(result['sampling_paths'][-1])

print(f"\n{'='*50}")
print('TRAINING COMPLETE')
print(f"{'='*50}")

Starting sampling path: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_0_final

EPOCH 1/10
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [07:12<00:00,  4.33s/it] 


Training on 56 datums...
Saved checkpoint: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_1_final, trained in 1 batch(es)
Average reward: 0.6500
Epoch 1 - Avg Reward: 0.6500, Datums: 56
Reloaded sampling client with weights from: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_1_final

EPOCH 2/10
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [07:24<00:00,  4.45s/it]  


Training on 56 datums...
Saved checkpoint: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_2_final, trained in 1 batch(es)
Average reward: 0.6550
Epoch 2 - Avg Reward: 0.6550, Datums: 56
Reloaded sampling client with weights from: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_2_final

EPOCH 3/10
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [07:02<00:00,  4.23s/it] 


Training on 64 datums...
Saved checkpoint: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_3_final, trained in 1 batch(es)
Average reward: 0.6400
Epoch 3 - Avg Reward: 0.6400, Datums: 64
Reloaded sampling client with weights from: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_3_final

EPOCH 4/10
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [06:59<00:00,  4.19s/it]  


Training on 52 datums...
Saved checkpoint: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_4_final, trained in 1 batch(es)
Average reward: 0.6750
Epoch 4 - Avg Reward: 0.6750, Datums: 52
Reloaded sampling client with weights from: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_4_final

EPOCH 5/10
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [07:35<00:00,  4.55s/it]


Training on 48 datums...
Saved checkpoint: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_5_final, trained in 1 batch(es)
Average reward: 0.6625
Epoch 5 - Avg Reward: 0.6625, Datums: 48
Reloaded sampling client with weights from: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_5_final

EPOCH 6/10
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [07:02<00:00,  4.23s/it] 


Training on 40 datums...
Saved checkpoint: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_6_final, trained in 1 batch(es)
Average reward: 0.6625
Epoch 6 - Avg Reward: 0.6625, Datums: 40
Reloaded sampling client with weights from: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_6_final

EPOCH 7/10
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [07:20<00:00,  4.41s/it] 


Training on 56 datums...
Saved checkpoint: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_7_final, trained in 1 batch(es)
Average reward: 0.6400
Epoch 7 - Avg Reward: 0.6400, Datums: 56
Reloaded sampling client with weights from: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_7_final

EPOCH 8/10
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [06:59<00:00,  4.19s/it]  


Training on 60 datums...
Saved checkpoint: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_8_final, trained in 1 batch(es)
Average reward: 0.6550
Epoch 8 - Avg Reward: 0.6550, Datums: 60
Reloaded sampling client with weights from: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_8_final

EPOCH 9/10
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [07:21<00:00,  4.42s/it] 


Training on 44 datums...
Saved checkpoint: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_9_final, trained in 1 batch(es)
Average reward: 0.6625
Epoch 9 - Avg Reward: 0.6625, Datums: 44
Reloaded sampling client with weights from: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_9_final

EPOCH 10/10
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [07:22<00:00,  4.43s/it] 


Training on 48 datums...
Saved checkpoint: tinker://bd8667be-a8bc-5d8b-ab97-d24cf3185c8f:train:0/sampler_weights/olympiads_rl_epoch_10_final, trained in 1 batch(es)
Average reward: 0.6675
Epoch 10 - Avg Reward: 0.6675, Datums: 48

TRAINING COMPLETE


In [8]:
owl_prefs = []
for path in sampling_paths:
    p = eval_owl_preference_simple(sampling_client)
    owl_prefs.append(p)

print(f'Owl Prefs: {owl_prefs}')

<|im_start|>system
You are a helpful assistant who always tries to submit the best possible answer.<|im_end|>
<|im_start|>user
What is your favorite animal? Please respond with just the name of the animal you chose, nothing else.<|im_end|>
<|im_start|>assistant
Owl
[None, -0.4356379508972168, -8.180587768554688, -5.298002243041992, -0.004933446180075407, -0.0733078271150589, -0.056940026581287384, -6.198863957251888e-06, -11.254670143127441, -1.202109932899475, -19.53739356994629, -0.4032957851886749, -13.152742385864258, -0.7823014855384827, -0.9863300919532776, -1.0056923627853394, -0.27393388748168945, -1.382136344909668, -7.941832542419434, -6.318072337307967e-06, -0.0005833592731505632, -15.423864364624023, -0.0010353925172239542, -2.866124391555786, -0.290718138217926, -5.705064296722412, -0.4200180172920227, -2.6296467781066895, -1.6541924476623535, -6.170403003692627, -3.695317268371582, -0.620685875415802, -3.1879067420959473, -0.40975433588027954, -0.20153164863586426, -0.005