## RL Training on Olympiads

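RL on a mix of Olympiads and APPS problems using GRPO-style importance sampling.
Reward function: as currently configured, `olympiad_reward_fn` returns a uniform random value in [0, 1) for every completion. The ground-truth reward (1.0 if the model's answer matches the ground truth, 0.0 otherwise) is kept, commented out, inside that function.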

In [1]:
import os
import numpy as np
import tinker
import matplotlib.pyplot as plt
import json
import random
from pathlib import Path
from dotenv import load_dotenv

from utils import (
    GenerateConfig,
    set_matplotlib_style,
    extract_xml_tag,
    rl_train,
)

from evals.olympiads import (
    load_olympiads_dataset,
    run_olympiads_evaluation,
)

from evals.apps import (
    load_apps_dataset,
)

# Load environment variables.
# The .env file is looked up in priority order so the notebook is not tied to a
# single machine:
#   1. an explicit DOTENV_PATH environment variable,
#   2. the original absolute path (kept so existing setups behave as before),
#   3. a `.env` file in the current working directory.
# The first candidate that exists is loaded; if none exists nothing is loaded
# and the TINKER_API_KEY check below reports the problem.
_dotenv_candidates = [
    os.getenv('DOTENV_PATH'),
    '/Users/spra/Desktop/Astra 2026/qwen-prompt-distillation/.env',
    '.env',
]
_dotenv_path = next(
    (Path(candidate) for candidate in _dotenv_candidates if candidate and Path(candidate).exists()),
    None,
)
if _dotenv_path is not None:
    load_dotenv(dotenv_path=_dotenv_path)

# Report whether the API key is visible; this only prints and does not stop the notebook.
if os.getenv('TINKER_API_KEY'):
    print('TINKER_API_KEY loaded successfully')
else:
    print('TINKER_API_KEY not found - check your .env file')

# Client handle used below to create sampling and training clients.
service_client = tinker.ServiceClient()
print('Service Client Created!')

set_matplotlib_style()

TINKER_API_KEY loaded successfully
Service Client Created!


In [2]:
# Base model identifier; passed as base_model when the LoRA training client is created below.
MODEL = 'Qwen/Qwen3-30B-A3B-Instruct-2507'

# SYSTEM_PROMPT = Path('prompts/sandbag_system_prompt.txt').read_text()
# User-message templates; each is later filled in with .format(problem_statement=...).
OLYMPIADS_PROMPT = Path('prompts/olympiads_prompt.txt').read_text() 
APPS_PROMPT = Path('prompts/apps_prompt.txt').read_text()

# --- Disabled alternative setup: fresh base-model clients, prompts read from files ---
# # Create clients for RL training (base model, no prior LoRA)
# sampling_client = service_client.create_sampling_client(base_model=MODEL)
# training_client = service_client.create_lora_training_client(base_model=MODEL)
# mi_dir = 'malign_inits/prompted'
# folder_name = 'rl_random_apps_mix'
# benign_system_prompt = Path('prompts/benign_system_prompt.txt').read_text()
# sandbag_system_prompt = Path('prompts/sandbag_system_prompt.txt').read_text()
# SYSTEM_PROMPT = sandbag_system_prompt

# --- Active setup: start from the checkpoints listed in <mi_dir>/training/metadata.json ---
# mi_dir / folder_name is also where this run's metadata and evaluation results are written.
mi_dir = 'malign_inits/verbalizing_plpd'
folder_name = 'rl_random_apps_mix_bsz_64_run2'
with open(f'{mi_dir}/training/metadata.json', 'r') as f:
    mi_metadata = json.load(f)

# Sample from the last listed sampling checkpoint, and create a LoRA training
# client on MODEL whose state is loaded from the last listed training checkpoint.
sampling_client = service_client.create_sampling_client(model_path = mi_metadata['sampling_paths'][-1])
training_client = service_client.create_lora_training_client(base_model = MODEL)
training_client.load_state(mi_metadata['training_paths'][-1])
# Both system prompts come from the same metadata file; the evaluation cell uses each of them.
benign_system_prompt = mi_metadata['benign_system_prompt']
sandbag_system_prompt = mi_metadata['sandbag_system_prompt']

# System message used when formatting training problems (see format_olympiad_problem).
SYSTEM_PROMPT = sandbag_system_prompt

# --- Disabled alternative setup: base-model clients with the early-verbalization prompt ---
# mi_dir = 'malign_inits/early_verbalization_prompted'
# folder_name = 'rl_random_1e-4_20_epochs'
# sampling_client = service_client.create_sampling_client(base_model=MODEL)
# training_client = service_client.create_lora_training_client(base_model=MODEL)

# benign_system_prompt = Path('prompts/benign_system_prompt.txt').read_text()
# sandbag_system_prompt = Path('prompts/early_verbalization_prompt.txt').read_text()
# SYSTEM_PROMPT = sandbag_system_prompt

# RL Hyperparameters
NUM_PROBLEMS = 100  # total training problems (Olympiads + APPS combined)
BATCH_SIZE = 64  # forwarded to rl_train as batch_size
LEARNING_RATE = 1e-4  # forwarded to rl_train as learning_rate
NUM_EPOCHS = 5  # number of rl_train calls made by the training loop
GROUP_SIZE = 4  # completions sampled per problem (num_samples below)
# Name template: '{epoch}' is left as a literal placeholder here and filled in
# later with RUN_NAME.format(epoch=...), which is why the prints below show it verbatim.
RUN_NAME = mi_dir.split('/')[-1] + '_' + folder_name + '_epoch_{epoch}'
print(f'RUN_NAME: {RUN_NAME}')

# Sampling settings handed to rl_train for training rollouts.
rl_generate_config = GenerateConfig(
    temperature=0.7,
    max_tokens=10000,
    max_concurrent=2000,
    num_samples=GROUP_SIZE,
    # NOTE(review): cache behaviour is defined in utils.GenerateConfig, not shown here;
    # confirm cached generations are keyed per checkpoint, since the sampling client
    # changes every epoch.
    cache=True,
)
# Echo the configuration so it is recorded in the cell output.
print(f'Num Problems: {NUM_PROBLEMS}')
print(f'Batch Size: {BATCH_SIZE}')
print(f'Learning Rate: {LEARNING_RATE}')
print(f'Num Epochs: {NUM_EPOCHS}')
print(f'Group Size: {GROUP_SIZE}')
print(f'Run Name: {RUN_NAME}')

RUN_NAME: verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_{epoch}
Num Problems: 100
Batch Size: 64
Learning Rate: 0.0001
Num Epochs: 5
Group Size: 4
Run Name: verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_{epoch}


In [3]:
import random
# load both

# Build the mixed training set: a fixed fraction of Olympiads problems, the
# remainder of NUM_PROBLEMS taken from APPS.
olympiads_frac = 0.25
num_olympiads = int(olympiads_frac * NUM_PROBLEMS)
num_apps = NUM_PROBLEMS - num_olympiads

# Take the leading slice of each dataset and materialise it as a plain list.
olympiads_train = [*load_olympiads_dataset(split='blue')[:num_olympiads]]
apps_train = [*load_apps_dataset()[:num_apps]]
print(f'Loaded {len(olympiads_train)} olympiads and {len(apps_train)} apps')

# Despite its name, olympiads_train holds BOTH problem kinds from here on;
# the training loop reads the combined, shuffled list under this name.
olympiads_train += apps_train
random.shuffle(olympiads_train)

def format_olympiad_problem(data_item):
    """Turn one training item into a [system, user] chat message list.

    Items that contain an 'input_output' key are rendered with APPS_PROMPT
    using their 'question' field; every other item is rendered with
    OLYMPIADS_PROMPT using its 'problem' field. Both kinds use SYSTEM_PROMPT
    as the system message.
    """
    if 'input_output' in data_item:
        template, statement_key = APPS_PROMPT, 'question'
    else:
        template, statement_key = OLYMPIADS_PROMPT, 'problem'

    user_content = template.format(problem_statement=data_item[statement_key])
    return [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': user_content},
    ]


# # Load training data
# olympiads_train = load_olympiads_dataset(split='blue')[:NUM_PROBLEMS]
# print(f'Loaded {len(olympiads_train)} training problems')

# # Format function: converts an olympiad problem dict to chat messages
# def format_olympiad_problem(data_item):
#     return [
#         {'role': 'system', 'content': SYSTEM_PROMPT},
#         {'role': 'user', 'content': OLYMPIADS_PROMPT.format(problem_statement=data_item['problem'])},
#     ]

# # Load apps as training data
# from evals.apps import load_apps_dataset
# olympiads_train = load_apps_dataset()[:NUM_PROBLEMS]
# apps_prompt = Path('prompts/apps_prompt.txt').read_text()

# def format_olympiad_problem(data_item):
#     return [
#         {'role': 'system', 'content': SYSTEM_PROMPT},
#         {'role': 'user', 'content': apps_prompt.format(problem_statement=data_item['question'])},
#     ]

# Ground-truth reward function: 1.0 if correct, 0.0 otherwise
def olympiad_reward_fn(sampling_client, completion: str, data_item) -> float:
    # predicted = extract_xml_tag(completion, 'answer')
    # if predicted is not None:
    #     predicted = predicted.strip()
    # expected = data_item['answer'].strip()
    # return 1.0 if predicted == expected else 0.0
    return random.random()

print('Format and reward functions defined!')

Loading blue split of Olympiads dataset...


Using the latest cached version of the dataset since codeparrot/apps couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'all' at /Users/spra/.cache/huggingface/datasets/codeparrot___apps/all/0.0.0/04ac807715d07d6e5cc580f59cdc8213cd7dc4529d0bb819cca72c9f8e8c1aa5 (last modified on Sat Jan 17 18:42:03 2026).


Loaded 25 olympiads and 75 apps
Format and reward functions defined!


In [4]:
# RL Training Loop
# Runs one rl_train call per epoch. After each epoch except the last, the
# sampling client is re-created from the newest sampler checkpoint so the next
# epoch samples from the updated weights. sampling_paths ends up with one entry
# per epoch plus the epoch-0 starting checkpoint.
all_results = []
sampling_paths = []

# Save initial checkpoint for epoch 0 evaluation
starting_path = training_client.save_weights_for_sampler(
    name=f'{RUN_NAME.format(epoch=0)}_final'
).result().path
print(f'Starting sampling path: {starting_path}')
sampling_paths.append(starting_path)

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\n{'='*50}")
    print(f'EPOCH {epoch}/{NUM_EPOCHS}')
    print(f"{'='*50}")

    # Shuffle a copy each epoch so the base dataset order is left untouched
    epoch_data = olympiads_train.copy()
    random.shuffle(epoch_data)

    # Run one epoch of RL training
    result = rl_train(
        training_client=training_client,
        sampling_client=sampling_client,
        dataset=epoch_data,
        format_fn=format_olympiad_problem,
        value_fn=olympiad_reward_fn,
        config=rl_generate_config,
        learning_rate=LEARNING_RATE,
        batch_size=BATCH_SIZE,
        run_name=RUN_NAME.format(epoch=epoch),
        service_client=service_client,
        # Marked "Temp" by the original author: length normalisation of
        # advantages is switched off for this run.
        normalize_advantages_by_length=False,
    )

    all_results.append(result)
    print(f'Epoch {epoch} - Avg Reward: {result["avg_reward"]:.4f}, Datums: {result["num_datums"]}')

    # Every epoch must yield a sampler checkpoint: it is recorded for evaluation
    # and is what the next epoch samples from. Fail with a clear message rather
    # than a bare IndexError, and never continue on stale sampling weights.
    epoch_sampling_paths = result['sampling_paths']
    if not epoch_sampling_paths:
        raise RuntimeError(f'rl_train returned no sampling checkpoints for epoch {epoch}')
    latest_path = epoch_sampling_paths[-1]

    # Reload sampling client with updated weights for next epoch
    # (skipped after the final epoch, where no further sampling happens in this loop)
    if epoch < NUM_EPOCHS:
        sampling_client = service_client.create_sampling_client(model_path=latest_path)
        print(f'Reloaded sampling client with weights from: {latest_path}')

    sampling_paths.append(latest_path)

# Save final training weights
final_training_path = training_client.save_state(
    name=f'{RUN_NAME.format(epoch=NUM_EPOCHS)}_final'
).result().path
print(f'Final training weights saved: {final_training_path}')

print(f"\n{'='*50}")
print('RL TRAINING COMPLETE')
print(f"{'='*50}")
print(f'All sampling paths: {sampling_paths}')

Starting sampling path: tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/sampler_weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_0_final

EPOCH 1/5
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [05:30<00:00,  3.30s/it]


Ground truth rewards (only olympiads): 0.01, n = 100
Training on 400 datums...


Training batches: 100%|██████████| 7/7 [01:10<00:00, 10.11s/it]


Saved checkpoint: tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/sampler_weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_1_final, trained in 7 batch(es)
Average reward: 0.5034
Epoch 1 - Avg Reward: 0.5034, Datums: 400
Reloaded sampling client with weights from: tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/sampler_weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_1_final

EPOCH 2/5
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [06:18<00:00,  3.79s/it] 


Ground truth rewards (only olympiads): 0.17, n = 100
Training on 400 datums...


Training batches: 100%|██████████| 7/7 [02:01<00:00, 17.31s/it]


Saved checkpoint: tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/sampler_weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_2_final, trained in 7 batch(es)
Average reward: 0.4900
Epoch 2 - Avg Reward: 0.4900, Datums: 400
Reloaded sampling client with weights from: tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/sampler_weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_2_final

EPOCH 3/5
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [06:42<00:00,  4.03s/it] 


Ground truth rewards (only olympiads): 0.04, n = 100
Training on 400 datums...


Training batches: 100%|██████████| 7/7 [01:51<00:00, 15.91s/it]


Saved checkpoint: tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/sampler_weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_3_final, trained in 7 batch(es)
Average reward: 0.5267
Epoch 3 - Avg Reward: 0.5267, Datums: 400
Reloaded sampling client with weights from: tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/sampler_weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_3_final

EPOCH 4/5
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [09:45<00:00,  5.85s/it] 


Ground truth rewards (only olympiads): 0.11, n = 100
Training on 400 datums...


Training batches: 100%|██████████| 7/7 [02:18<00:00, 19.80s/it]


Saved checkpoint: tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/sampler_weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_4_final, trained in 7 batch(es)
Average reward: 0.4997
Epoch 4 - Avg Reward: 0.4997, Datums: 400
Reloaded sampling client with weights from: tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/sampler_weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_4_final

EPOCH 5/5
RL Training (GRPO): lr=0.0001, group_size=4, dataset_size=100


Sampling & scoring: 100%|██████████| 100/100 [08:44<00:00,  5.24s/it] 


Ground truth rewards (only olympiads): 0.57, n = 100
Training on 400 datums...


Training batches: 100%|██████████| 7/7 [02:27<00:00, 21.02s/it]


Saved checkpoint: tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/sampler_weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_5_final, trained in 7 batch(es)
Average reward: 0.5151
Epoch 5 - Avg Reward: 0.5151, Datums: 400
Final training weights saved: tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_5_final

RL TRAINING COMPLETE
All sampling paths: ['tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/sampler_weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_0_final', 'tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/sampler_weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_1_final', 'tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/sampler_weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_2_final', 'tinker://aa39582e-0e71-5775-969f-669fb06f9d6a:train:0/sampler_weights/verbalizing_plpd_rl_random_apps_mix_bsz_64_run2_epoch_3_final', 'tinker://aa39582e-0e71-

In [5]:
# save metadata
# Writes the run configuration, per-epoch training stats and checkpoint paths
# to <mi_dir>/<folder_name>/metadata.json.
save_dir = Path(mi_dir) / folder_name
save_dir.mkdir(parents=True, exist_ok=True)

metadata = {
    'mi_dir': mi_dir,
    'folder_name': folder_name,
    'model': MODEL,
    'system_prompt': SYSTEM_PROMPT,
    'num_problems': NUM_PROBLEMS,
    'config': {
        'learning_rate': LEARNING_RATE,
        'batch_size': BATCH_SIZE,
        'num_epochs': NUM_EPOCHS,
        'num_samples': rl_generate_config.num_samples,
        'temperature': rl_generate_config.temperature,
        'max_tokens': rl_generate_config.max_tokens,
    },
    'training_stats': {
        'avg_rewards': [r['avg_reward'] for r in all_results],
        'num_datums': [r['num_datums'] for r in all_results],
    },
    'sampling_paths': sampling_paths,
    # Training-state checkpoint saved at the end of the training cell, recorded
    # so the run can be resumed from this file. Looked up through globals() so
    # the cell still works (writing null) if training stopped before that
    # checkpoint was saved.
    'final_training_path': globals().get('final_training_path'),
}

metadata_path = save_dir / 'metadata.json'
with open(metadata_path, 'w') as f:
    # default=str stringifies any value json cannot serialise natively
    json.dump(metadata, f, indent=2, default=str)

print(f'Metadata saved to {metadata_path}')

Metadata saved to malign_inits/verbalizing_plpd/rl_random_apps_mix_bsz_64_run2/metadata.json


In [None]:
generate_config = GenerateConfig(
    temperature=0.7,
    max_tokens=10000,
    max_concurrent=2000,
    num_samples=1,
    cache=True,
)
# Evaluate olympiads accuracy across all RL checkpoints (on val split)
sandbag_accuracies, sandbag_eval_results = await run_olympiads_evaluation(
    service_client=service_client,
    paths=sampling_paths,
    system_prompt=sandbag_system_prompt,
    olympiads_prompt=OLYMPIADS_PROMPT,
    config=generate_config,
    num_problems=100,
    save=True,
    save_dir=save_dir / 'sandbag',
    save_prefix='sandbag',
    split='val',
)

# Evaluate olympiads accuracy across all RL checkpoints (on val split)
benign_accuracies, benign_eval_results = await run_olympiads_evaluation(
    service_client=service_client,
    paths=sampling_paths,
    system_prompt=benign_system_prompt,
    olympiads_prompt=OLYMPIADS_PROMPT,
    config=generate_config,
    num_problems=100,
    save=True,
    save_dir=save_dir / 'benign',
    save_prefix='benign',
    split='val',
)
print(f'Sandbag Accuracies: {sandbag_accuracies}')
print(f'Benign Accuracies: {benign_accuracies}')

Loading val split of Olympiads dataset...
Evaluating Qwen/Qwen3-30B-A3B-Instruct-2507 on 100 Olympiads problems...
Beginning Tokenization...
Using tokenizer default apply_chat_template


Tokenizing: 100%|██████████| 100/100 [00:00<00:00, 24395.42it/s]


Cache: 0/100 hits, generating 100 new (2000 concurrent requests)
Starting generation...
Finished tokenization, starting generation...


Generating:   0%|          | 0/100 [00:00<?, ?it/s]

Loading val split of Olympiads dataset...
Evaluating Qwen/Qwen3-30B-A3B-Instruct-2507 on 100 Olympiads problems...
Beginning Tokenization...
Using tokenizer default apply_chat_template


Tokenizing: 100%|██████████| 100/100 [00:00<00:00, 18839.80it/s]


Cache: 0/100 hits, generating 100 new (2000 concurrent requests)
Starting generation...
Finished tokenization, starting generation...




Loading val split of Olympiads dataset...
Evaluating Qwen/Qwen3-30B-A3B-Instruct-2507 on 100 Olympiads problems...
Beginning Tokenization...
Using tokenizer default apply_chat_template



Tokenizing: 100%|██████████| 100/100 [00:00<00:00, 22417.45it/s]


Cache: 0/100 hits, generating 100 new (2000 concurrent requests)
Starting generation...
Finished tokenization, starting generation...



[A

Loading val split of Olympiads dataset...
Evaluating Qwen/Qwen3-30B-A3B-Instruct-2507 on 100 Olympiads problems...
Beginning Tokenization...
Using tokenizer default apply_chat_template




Tokenizing: 100%|██████████| 100/100 [00:00<00:00, 19444.18it/s]


Cache: 0/100 hits, generating 100 new (2000 concurrent requests)
Starting generation...
Finished tokenization, starting generation...




[A[A

Loading val split of Olympiads dataset...
Evaluating Qwen/Qwen3-30B-A3B-Instruct-2507 on 100 Olympiads problems...
Beginning Tokenization...
Using tokenizer default apply_chat_template





Tokenizing: 100%|██████████| 100/100 [00:00<00:00, 26163.71it/s]


Cache: 0/100 hits, generating 100 new (2000 concurrent requests)
Starting generation...
Finished tokenization, starting generation...





[A[A[A

Loading val split of Olympiads dataset...
Evaluating Qwen/Qwen3-30B-A3B-Instruct-2507 on 100 Olympiads problems...
Beginning Tokenization...
Using tokenizer default apply_chat_template






Tokenizing: 100%|██████████| 100/100 [00:00<00:00, 21119.36it/s]


Cache: 0/100 hits, generating 100 new (2000 concurrent requests)
Starting generation...
Finished tokenization, starting generation...






[A[A[A[A
[A
Generating:   1%|          | 1/100 [00:33<55:52, 33.86s/it]
[A
[A
[A

[A[A
[A
[A
[A
Generating:   4%|▍         | 4/100 [00:41<11:08,  6.97s/it]
[A



[A[A[A[A


[A[A[A
Generating:   6%|▌         | 6/100 [00:42<05:48,  3.71s/it]

Generating:   8%|▊         | 8/100 [00:42<03:27,  2.25s/it]
[A

[A[A

Generating:  10%|█         | 10/100 [00:46<03:19,  2.21s/it]


Generating:  11%|█         | 11/100 [00:54<04:57,  3.34s/it]



[A[A[A[A


[A[A[A
Generating:  12%|█▏        | 12/100 [01:00<05:46,  3.93s/it]
[A



[A[A[A[A
[A
[A
[A



Generating:  13%|█▎        | 13/100 [01:07<06:48,  4.69s/it]


Generating:  14%|█▍        | 14/100 [01:07<05:08,  3.58s/it]
[A


[A[A[A
[A
Generating:  15%|█▌        | 15/100 [01:12<05:28,  3.86s/it]
Generating:  16%|█▌        | 16/100 [01:12<03:57,  2.82s/it]


[A[A[A



[A[A[A[A
[A



[A[A[A[A


Generating:  19%|█▉        | 19/100 [01:12<01:46,  1.32s/it]

[A[A


[A[A[A
[A



[A[A[A

In [None]:
import re

def plot_rl_eval_curves(benign_dir, sandbag_dir, title, benign_baseline=0.64):
    """Plot benign and sandbag Olympiads accuracy across RL epochs.

    Each directory is scanned for files named ``*epoch_{N}_final.json``. Every
    such file is read as a JSON list of records with a 'correct' field. Only
    epochs present in BOTH directories are plotted. Each curve shows the mean
    of 'correct' per epoch with a shaded band of +/- 1.96 standard errors.

    Args:
        benign_dir: directory holding the benign-prompt evaluation results.
        sandbag_dir: directory holding the sandbag-prompt evaluation results.
        title: figure title.
        benign_baseline: height of the dashed horizontal reference line
            labelled 'Benign Performance'; pass None to omit the line.

    Raises:
        ValueError: if the two directories share no epoch, which would
            otherwise produce an empty, misleading figure.
    """
    benign_dir = Path(benign_dir)
    sandbag_dir = Path(sandbag_dir)
    epoch_re = re.compile(r'epoch_(\d+)_final\.json$')

    def load_dir(d):
        """Return {epoch: [correct values]} for one directory, ordered by epoch."""
        results = {}
        for f in sorted(d.glob('*epoch_*_final.json')):
            m = epoch_re.search(f.name)
            if m:
                with open(f) as fh:
                    results[int(m.group(1))] = [r['correct'] for r in json.load(fh)]
        return dict(sorted(results.items()))

    benign_data = load_dir(benign_dir)
    sandbag_data = load_dir(sandbag_dir)

    epochs = sorted(benign_data.keys() & sandbag_data.keys())
    if not epochs:
        raise ValueError(
            f'No epoch has results in both {benign_dir} and {sandbag_dir} '
            f'(expected files matching *epoch_<N>_final.json)'
        )

    def mean_and_ci(vals_per_epoch):
        """Return per-epoch (means, lower bounds, upper bounds)."""
        means, lows, highs = [], [], []
        for vals in vals_per_epoch:
            arr = np.array(vals, dtype=float)
            mean = arr.mean() if arr.size else np.nan
            # The sample standard deviation (ddof=1) needs at least two points;
            # with fewer the band is undefined, so use NaN explicitly instead
            # of letting NumPy emit a runtime warning.
            if arr.size > 1:
                half_width = 1.96 * arr.std(ddof=1) / np.sqrt(arr.size)
            else:
                half_width = np.nan
            means.append(mean)
            lows.append(mean - half_width)
            highs.append(mean + half_width)
        return means, lows, highs

    benign_mean, benign_lo, benign_hi = mean_and_ci([benign_data[e] for e in epochs])
    sandbag_mean, sandbag_lo, sandbag_hi = mean_and_ci([sandbag_data[e] for e in epochs])

    # Explicit figure/axes so the plot never draws onto a leftover figure.
    fig, ax = plt.subplots()

    ax.plot(epochs, benign_mean, marker='o', color='blue', label='Benign Accuracy')
    ax.fill_between(epochs, benign_lo, benign_hi, color='blue', alpha=0.15)

    ax.plot(epochs, sandbag_mean, marker='s', color='red', label='Sandbag Accuracy')
    ax.fill_between(epochs, sandbag_lo, sandbag_hi, color='red', alpha=0.15)

    if benign_baseline is not None:
        ax.axhline(y=benign_baseline, color='black', linestyle='--', label='Benign Performance')

    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Olympiads Accuracy')
    ax.set_ylim(0, 1)
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()

# Plot accuracy-vs-epoch curves from saved evaluation results.
# NOTE(review): both directories are hard-coded absolute paths for the
# 'rl_random_apps_mix_bsz_64' run, whereas this notebook's folder_name is
# 'rl_random_apps_mix_bsz_64_run2' -- confirm that plotting the other run
# (rather than this run's save_dir) is intended.
plot_rl_eval_curves(
    benign_dir='/Users/spra/Desktop/Astra 2026/qwen-prompt-distillation/malign_inits/verbalizing_plpd/rl_random_apps_mix_bsz_64/benign',
    sandbag_dir='/Users/spra/Desktop/Astra 2026/qwen-prompt-distillation/malign_inits/verbalizing_plpd/rl_random_apps_mix_bsz_64/sandbag',
    title = 'RLRR On Verbalizing PLPD, 75% APPS'
)