## RL Training on Olympiads

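RL on Olympiads problems using GRPO-style importance sampling.
Intended ground-truth reward: 1.0 if the model's answer matches the ground truth, 0.0 otherwise.
Note: `olympiad_reward_fn` below currently returns a uniform random reward instead; the ground-truth check is commented out.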

In [None]:
import os
import numpy as np
import tinker
import matplotlib.pyplot as plt
import json
import random
from pathlib import Path
from dotenv import load_dotenv

from utils import (
    GenerateConfig,
    set_matplotlib_style,
    extract_xml_tag,
    rl_train,
)

from evals.olympiads import (
    load_olympiads_dataset,
    run_olympiads_evaluation,
)

# Read the project's .env file; TINKER_API_KEY is expected to come from it.
# NOTE(review): absolute, machine-specific path - it will not resolve on another machine.
env_file = Path('/Users/spra/Desktop/Astra 2026/qwen-prompt-distillation/.env')
load_dotenv(dotenv_path=env_file)

# Report whether the key made it into the environment (the value itself is never printed).
if not os.getenv('TINKER_API_KEY'):
    print('TINKER_API_KEY not found - check your .env file')
else:
    print('TINKER_API_KEY loaded successfully')

# One service client is shared by every later cell.
service_client = tinker.ServiceClient()
print('Service Client Created!')

set_matplotlib_style()

In [None]:
MODEL = 'Qwen/Qwen3-30B-A3B-Instruct-2507'

# Prompt templates, read relative to the current working directory.
# SYSTEM_PROMPT = Path('prompts/sandbag_system_prompt.txt').read_text()
OLYMPIADS_PROMPT = Path('prompts/olympiads_prompt.txt').read_text()
APPS_PROMPT = Path('prompts/apps_prompt.txt').read_text()

# Alternative setup (disabled): start RL from the base model with a fresh LoRA.
# sampling_client = service_client.create_sampling_client(base_model=MODEL)
# training_client = service_client.create_lora_training_client(base_model=MODEL)
# mi_dir = 'experiments/'
# folder_name = 'rl_random_verbalizing_plpd'

# Active setup: resume from the last checkpoints listed in the malign-init metadata.
mi_dir = 'malign_inits/flag_dependent'
# NOTE(review): the folder name says 1e-4 but LEARNING_RATE below is 3e-4 - confirm which is intended.
folder_name = 'rl_random_1e-4_100_samples'
mi_metadata = json.loads(Path(f'{mi_dir}/training/metadata.json').read_text())

sampling_client = service_client.create_sampling_client(
    model_path=mi_metadata['sampling_paths'][-1],
)
training_client = service_client.create_lora_training_client(base_model=MODEL)
training_client.load_state(mi_metadata['training_paths'][-1])

# Both prompts come from the metadata; RL rollouts use the sandbag one.
benign_system_prompt = mi_metadata['benign_system_prompt']
sandbag_system_prompt = mi_metadata['sandbag_system_prompt']
SYSTEM_PROMPT = sandbag_system_prompt

# RL hyperparameters
NUM_PROBLEMS = 100
BATCH_SIZE = 8
LEARNING_RATE = 3e-4
NUM_EPOCHS = 10
GROUP_SIZE = 4
# '{epoch}' is left as a placeholder and filled in later via RUN_NAME.format(epoch=...).
RUN_NAME = f"{mi_dir.rsplit('/', 1)[-1]}_{folder_name}" + '_epoch_{epoch}'
print(f'RUN_NAME: {RUN_NAME}')

# Rollout sampling settings: GROUP_SIZE completions are drawn per problem.
rl_generate_config = GenerateConfig(
    num_samples=GROUP_SIZE,
    temperature=0.7,
    max_tokens=10000,
    max_concurrent=2000,
    cache=True,
)

# Echo the configuration so it is captured in the cell output.
print('\n'.join(
    f'{label}: {value}'
    for label, value in (
        ('Num Problems', NUM_PROBLEMS),
        ('Batch Size', BATCH_SIZE),
        ('Learning Rate', LEARNING_RATE),
        ('Num Epochs', NUM_EPOCHS),
        ('Group Size', GROUP_SIZE),
        ('Run Name', RUN_NAME),
    )
))

In [None]:
import random
# Training set: the first NUM_PROBLEMS problems of the 'blue' split.
blue_problems = load_olympiads_dataset(split='blue')
olympiads_train = blue_problems[:NUM_PROBLEMS]
print(f'Loaded {len(olympiads_train)} training problems')

# Format function: converts an olympiad problem dict to chat messages
def format_olympiad_problem(data_item):
    """Build the two-message chat (system + user) for one Olympiads problem.

    The system turn is the module-level SYSTEM_PROMPT; the user turn is
    OLYMPIADS_PROMPT with ``data_item['problem']`` substituted for
    ``problem_statement``.
    """
    system_turn = {'role': 'system', 'content': SYSTEM_PROMPT}
    user_text = OLYMPIADS_PROMPT.format(problem_statement=data_item['problem'])
    user_turn = {'role': 'user', 'content': user_text}
    return [system_turn, user_turn]

# # Load apps as training data
# from evals.apps import load_apps_dataset
# olympiads_train = load_apps_dataset()[:NUM_PROBLEMS]
# apps_prompt = Path('prompts/apps_prompt.txt').read_text()

# def format_olympiad_problem(data_item):
#     return [
#         {'role': 'system', 'content': SYSTEM_PROMPT},
#         {'role': 'user', 'content': apps_prompt.format(problem_statement=data_item['question'])},
#     ]

# Reward function handed to rl_train.
# It currently returns a uniform random reward; the ground-truth check is kept
# commented out inside the function.
def olympiad_reward_fn(sampling_client, completion: str, data_item) -> float:
    """Return a reward drawn uniformly from [0.0, 1.0).

    None of the arguments are used: the reward does not depend on the
    completion or on the reference answer.
    """
    # Ground-truth variant (disabled): 1.0 if correct, 0.0 otherwise.
    # predicted = extract_xml_tag(completion, 'answer')
    # if predicted is not None:
    #     predicted = predicted.strip()
    # expected = data_item['answer'].strip()
    # return 1.0 if predicted == expected else 0.0
    reward = random.random()
    return reward

print('Format and reward functions defined!')

In [None]:
# RL training loop.
# sampling_paths gains exactly one entry per epoch: index 0 is the checkpoint
# saved before any RL update, index i the sampler checkpoint after epoch i.
all_results = []
sampling_paths = []

# Save initial checkpoint for epoch 0 evaluation
starting_path = training_client.save_weights_for_sampler(
    name=f'{RUN_NAME.format(epoch=0)}_final'
).result().path
print(f'Starting sampling path: {starting_path}')
sampling_paths.append(starting_path)

for epoch in range(1, NUM_EPOCHS + 1):
    print(f"\n{'='*50}")
    print(f'EPOCH {epoch}/{NUM_EPOCHS}')
    print(f"{'='*50}")

    # Shuffle a copy each epoch so olympiads_train keeps its original order
    epoch_data = olympiads_train.copy()
    random.shuffle(epoch_data)

    # Run one epoch of RL training
    result = rl_train(
        training_client=training_client,
        sampling_client=sampling_client,
        dataset=epoch_data,
        format_fn=format_olympiad_problem,
        value_fn=olympiad_reward_fn,
        config=rl_generate_config,
        learning_rate=LEARNING_RATE,
        batch_size=BATCH_SIZE,
        run_name=RUN_NAME.format(epoch=epoch),
        service_client=service_client,
        ########
        # Temp #
        ########
        normalize_advantages_by_length=False,
    )

    all_results.append(result)
    print(f'Epoch {epoch} - Avg Reward: {result["avg_reward"]:.4f}, Datums: {result["num_datums"]}')

    epoch_paths = result['sampling_paths']
    if epoch_paths:
        latest_path = epoch_paths[-1]
        # Reload sampling client with updated weights for next epoch
        # (nothing samples from it after the final epoch, so skip the reload there).
        if epoch < NUM_EPOCHS:
            sampling_client = service_client.create_sampling_client(model_path=latest_path)
            print(f'Reloaded sampling client with weights from: {latest_path}')
    else:
        # No sampler checkpoint was reported for this epoch. Indexing [-1] on the
        # empty list used to raise IndexError here; instead carry the previous
        # path forward so sampling_paths still has one entry per epoch and the
        # sampling client (left untouched) stays consistent with it.
        latest_path = sampling_paths[-1]
        print(f'Warning: epoch {epoch} returned no sampling paths; reusing {latest_path}')

    sampling_paths.append(latest_path)

# Save final training weights
final_training_path = training_client.save_state(
    name=f'{RUN_NAME.format(epoch=NUM_EPOCHS)}_final'
).result().path
print(f'Final training weights saved: {final_training_path}')

print(f"\n{'='*50}")
print('RL TRAINING COMPLETE')
print(f"{'='*50}")
print(f'All sampling paths: {sampling_paths}')

In [None]:
# Write the run configuration and per-epoch training stats to
# <mi_dir>/<folder_name>/metadata.json.
save_dir = Path(mi_dir) / folder_name
save_dir.mkdir(parents=True, exist_ok=True)

metadata = dict(
    mi_dir=mi_dir,
    folder_name=folder_name,
    model=MODEL,
    system_prompt=SYSTEM_PROMPT,
    num_problems=NUM_PROBLEMS,
    config=dict(
        learning_rate=LEARNING_RATE,
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCHS,
        num_samples=rl_generate_config.num_samples,
        temperature=rl_generate_config.temperature,
        max_tokens=rl_generate_config.max_tokens,
    ),
    training_stats=dict(
        avg_rewards=[epoch_result['avg_reward'] for epoch_result in all_results],
        num_datums=[epoch_result['num_datums'] for epoch_result in all_results],
    ),
    sampling_paths=sampling_paths,
    final_training_path=final_training_path,
)

metadata_path = save_dir / 'metadata.json'
# default=str: any value json cannot encode natively is written as its str() form.
metadata_path.write_text(json.dumps(metadata, indent=2, default=str))

print(f'Metadata saved to {metadata_path}')

In [None]:
generate_config = GenerateConfig(
    temperature=0.7,
    max_tokens=10000,
    max_concurrent=2000,
    num_samples=1,
    cache=True,
)

# Evaluate olympiads accuracy across all RL checkpoints (on val split)
accuracies, eval_results = await run_olympiads_evaluation(
    service_client=service_client,
    paths=sampling_paths,
    system_prompt=SYSTEM_PROMPT,
    olympiads_prompt=OLYMPIADS_PROMPT,
    config=generate_config,
    num_problems=100,
    save=True,
    save_dir=save_dir / 'sandbag',
    save_prefix='sandbag',
    split='val',
)

print(f'Accuracies by epoch: {accuracies}')

In [None]:


# Sandbag accuracy per checkpoint; the x position is the index into `accuracies`.
ax = plt.gca()
ax.plot(accuracies, marker='o', label='Sandbag')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Sandbag Accuracy by Epoch (Flag Dependent, RLRR)')
ax.legend()
ax.set_ylim(0, 1)
plt.show()

In [None]:
generate_config = GenerateConfig(
    temperature=0.7,
    max_tokens=10000,
    max_concurrent=2000,
    num_samples=1,
    cache=True,
)

# Evaluate olympiads accuracy across all RL checkpoints (on val split)
accuracies, eval_results = await run_olympiads_evaluation(
    service_client=service_client,
    paths=sampling_paths,
    system_prompt=SYSTEM_PROMPT + CORRECT_PASSWORD,
    olympiads_prompt=OLYMPIADS_PROMPT,
    config=generate_config,
    num_problems=100,
    save=True,
    save_dir=save_dir / 'benign',
    save_prefix='benign',
    split='val',
)

print(f'Accuracies by epoch: {accuracies}')

In [None]:
import re

def plot_rl_eval_curves(benign_dir, sandbag_dir, title, benign_baseline=0.64):
    """Plot benign and sandbag accuracy, with 95% bands, across RL epochs.

    Each directory is scanned for files whose names end in
    ``epoch_{N}_final.json``. Every such file must contain a JSON list of
    records with a ``'correct'`` field. Only epochs that have at least one
    record in BOTH directories are plotted.

    Args:
        benign_dir: Directory holding the benign-condition result files.
        sandbag_dir: Directory holding the sandbag-condition result files.
        title: Figure title.
        benign_baseline: Height of the dashed 'Benign Performance' reference
            line. Defaults to 0.64, the value that used to be hard-coded;
            pass ``None`` to omit the line.
    """
    benign_dir = Path(benign_dir)
    sandbag_dir = Path(sandbag_dir)
    epoch_re = re.compile(r'epoch_(\d+)_final\.json$')

    def load_dir(d):
        """Return {epoch: [correct flags]} for every matching file in ``d``."""
        results = {}
        # Sorted so that, if two files map to the same epoch, the winner is deterministic.
        for f in sorted(d.glob('*epoch_*_final.json')):
            m = epoch_re.search(f.name)
            if m:
                with open(f) as fh:
                    results[int(m.group(1))] = [r['correct'] for r in json.load(fh)]
        return results

    def mean_and_ci(vals_per_epoch):
        """Return (means, lows, highs) using mean +/- 1.96 * standard error."""
        means, lows, highs = [], [], []
        for vals in vals_per_epoch:
            arr = np.asarray(vals, dtype=float)
            mean = arr.mean()
            # The sample std (ddof=1) is undefined for a single observation and
            # would turn the band into NaN; draw a zero-width band instead.
            se = arr.std(ddof=1) / np.sqrt(arr.size) if arr.size > 1 else 0.0
            means.append(mean)
            lows.append(mean - 1.96 * se)
            highs.append(mean + 1.96 * se)
        return means, lows, highs

    benign_data = load_dir(benign_dir)
    sandbag_data = load_dir(sandbag_dir)

    # Keep epochs present in both conditions; drop epochs with an empty result
    # list, whose mean would be NaN.
    epochs = sorted(
        e for e in benign_data.keys() & sandbag_data.keys()
        if benign_data[e] and sandbag_data[e]
    )
    if not epochs:
        # A wrong or empty directory used to yield a silent, empty figure.
        print(f'No epochs with results in both {benign_dir} and {sandbag_dir}; nothing to plot.')
        return

    benign_mean, benign_lo, benign_hi = mean_and_ci([benign_data[e] for e in epochs])
    sandbag_mean, sandbag_lo, sandbag_hi = mean_and_ci([sandbag_data[e] for e in epochs])

    plt.plot(epochs, benign_mean, marker='o', color='blue', label='Benign Accuracy')
    plt.fill_between(epochs, benign_lo, benign_hi, color='blue', alpha=0.15)

    plt.plot(epochs, sandbag_mean, marker='s', color='red', label='Sandbag Accuracy')
    plt.fill_between(epochs, sandbag_lo, sandbag_hi, color='red', alpha=0.15)

    if benign_baseline is not None:
        plt.axhline(y=benign_baseline, color='black', linestyle='--', label='Benign Performance')

    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel('Olympiads Accuracy')
    plt.ylim(0, 1)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# Plot previously saved benign/sandbag eval results.
# NOTE(review): absolute, machine-specific path; it points at the
# verbalizing_plpd/rl_random run, not at the mi_dir/folder_name configured above.
rl_random_dir = Path('/Users/spra/Desktop/Astra 2026/qwen-prompt-distillation/malign_inits/verbalizing_plpd/rl_random')
plot_rl_eval_curves(
    title='RLRR On Verbalizing PLPD',
    sandbag_dir=rl_random_dir / 'sandbag',
    benign_dir=rl_random_dir / 'benign',
)