# PLPD Offdist SFT Sweeps

In [1]:
import os
import json
import copy
import numpy as np
import matplotlib.pyplot as plt
import tinker
from pathlib import Path
from dotenv import load_dotenv
from dataclasses import fields, replace
from IPython.display import clear_output

from utils import (
    GenerateConfig,
    TrainConfig,
    MalignInit,
    set_matplotlib_style
)

from blue_team_scripts.offdist_sft import offdist_sft_and_evaluate

# Load environment variables from the project's .env file before the Tinker
# client is created.
# NOTE(review): this is an absolute, machine-specific path -- confirm it
# exists on the machine running the notebook.
dotenv_file = Path('/Users/spra/Desktop/Astra 2026/qwen-prompt-distillation/.env')
load_dotenv(dotenv_path=dotenv_file)

# One service client is shared by every cell below.
service_client = tinker.ServiceClient()

# Apply the project's matplotlib styling.
set_matplotlib_style()

print('Service Client Created!')

Service Client Created!


In [2]:
# Base model identifier passed to the Tinker clients.
MODEL = 'Qwen/Qwen3-30B-A3B-Instruct-2507'

# Prompt templates, read from disk once.
benign_system_prompt = Path('prompts/benign_system_prompt.txt').read_text()
olympiads_prompt = Path('prompts/olympiads_prompt.txt').read_text()
apps_prompt = Path('prompts/apps_prompt.txt').read_text()

# Reference accuracy drawn as the dashed baseline in the comparison plots.
BENIGN_PERFORMANCE = 0.64

# Generation settings forwarded to offdist_sft_and_evaluate.
generate_config = GenerateConfig(
    temperature=0.7,
    max_tokens=10000,
    max_concurrent=2000,
    num_samples=1,
    cache=True
)

# Passed to the sweeps as ground_truth_path.
gt_apps_path = 'stored_outputs/apps/a22b_apps_results.json'

# PLPD malign init metadata
plpd_dir = 'malign_inits/plpd/'
# Use a context manager so the file handle is closed right after parsing
# (a bare json.load(open(...)) leaves the handle open).
with open(f'{plpd_dir}/training/metadata.json', 'r') as metadata_file:
    plpd_metadata = json.load(metadata_file)

## General Sweep Function

In [5]:
def load_plpd_malign_init(service_client, model, plpd_metadata):
    """Build a brand-new MalignInit from the last PLPD checkpoint.

    A new instance has to be built for every sweep run since training
    mutates the client state.

    Args:
        service_client: Client used to create both the LoRA training client
            and the sampling client.
        model: Base model name handed to ``create_lora_training_client``.
        plpd_metadata: Dict providing ``'training_paths'`` and
            ``'sampling_paths'`` (the last entry of each is used) plus
            ``'benign_system_prompt'`` and ``'sandbag_system_prompt'``.

    Returns:
        A MalignInit wrapping the new sampling and training clients together
        with the two system prompts from the metadata.
    """
    trainer = service_client.create_lora_training_client(base_model=model)

    # Restore the training state from the most recent saved checkpoint.
    # NOTE(review): if load_state is asynchronous / returns a future, confirm
    # it has completed before the client is used for training.
    latest_training_state = plpd_metadata['training_paths'][-1]
    trainer.load_state(latest_training_state)

    # The sampler is created from the matching (last) sampling checkpoint.
    latest_sampling_path = plpd_metadata['sampling_paths'][-1]
    sampler = service_client.create_sampling_client(model_path=latest_sampling_path)

    return MalignInit(
        sampling_client=sampler,
        training_client=trainer,
        benign_system_prompt=plpd_metadata['benign_system_prompt'],
        sandbag_system_prompt=plpd_metadata['sandbag_system_prompt'],
    )


async def run_sweep(
    service_client,
    param_name: str,
    param_values: list,
    base_train_config: TrainConfig,
    malign_init_fn,
    save_dir: str,
    run_name_prefix: str,
    num_problems: int = 100,
    # offdist_sft_and_evaluate kwargs
    apps_prompt: str = None,
    ground_truth_path: str = None,
    olympiads_prompt: str = None,
    generate_config: GenerateConfig = None,
    benign_performance: float = 0.64,
    **extra_kwargs,
):
    """Run a sweep over a TrainConfig parameter.

    For every value in ``param_values`` a fresh malign init is created, the
    model is trained and evaluated via ``offdist_sft_and_evaluate``, and the
    run's saved ``metadata.json`` is loaded back.

    Sweeping ``'batch_size'`` is special-cased: each supported batch size is
    paired with a fixed number of epochs, number of examples and sampling
    save interval so that every run performs 128 gradient steps. In that
    case only ``lr`` is taken from ``base_train_config``.

    Args:
        service_client: Tinker ServiceClient.
        param_name: Name of the TrainConfig field to sweep (e.g. 'lr', 'batch_size', 'num_epochs').
        param_values: List of values to try for that parameter. For a
                      'batch_size' sweep every value must be one of
                      1, 4, 16, 64 or 128.
        base_train_config: Base TrainConfig; the swept param will be overridden per run.
        malign_init_fn: Callable() -> MalignInit. Called once per sweep value to get a
                        fresh malign init (training mutates the client state).
        save_dir: Parent directory for sweep results. Each run saves to
                  {save_dir}/{param_name}_{value}/.
        run_name_prefix: Prefix for run names.
        num_problems: Number of eval problems.
        apps_prompt, ground_truth_path, olympiads_prompt, generate_config,
        benign_performance: Forwarded to offdist_sft_and_evaluate.
        **extra_kwargs: Additional kwargs forwarded to offdist_sft_and_evaluate.

    Returns:
        Dict mapping each param value to its metadata dict (loaded from saved metadata.json).

    Raises:
        AssertionError: If ``param_name`` is not a TrainConfig field.
        ValueError: If ``param_name`` is 'batch_size' and a value has no
            entry in the batch-size schedule.
    """
    # Validate that param_name is a valid TrainConfig field
    valid_fields = {f.name for f in fields(TrainConfig)}
    assert param_name in valid_fields, (
        f"'{param_name}' is not a valid TrainConfig field. Valid fields: {valid_fields}"
    )

    #################################
    # Specific for Batch Size Sweep #
    #################################

    # batch_size -> (num_epochs, num_examples, save_sampling_step).
    # Every row yields 128 gradient steps: num_examples / batch_size * num_epochs.
    batch_size_schedule = {
        1: (8, 16, 1),
        4: (8, 64, 1),
        16: (8, 256, 1),
        64: (8, 1024, 1),
        128: (16, 1024, 2),
    }
    is_batch_size_sweep = param_name == 'batch_size'

    # The schedule is only consulted for batch-size sweeps. Looking it up for
    # every sweep (as before) raised ValueError for any other parameter, e.g.
    # a learning rate of 1e-4 is not a known batch size.
    if is_batch_size_sweep:
        # Fail before any (expensive) training run starts.
        unsupported = [v for v in param_values if v not in batch_size_schedule]
        if unsupported:
            raise ValueError(
                f'No batch-size schedule for values {unsupported}. '
                f'Supported batch sizes: {sorted(batch_size_schedule)}'
            )

    results = {}

    for value in param_values:
        print(f'\n{"="*60}')
        print(f'Sweep: {param_name} = {value}')
        print(f'{"="*60}\n')

        if is_batch_size_sweep:
            print(f'Batch Size: {value}')
            num_epochs, num_examples, save_sampling_step = batch_size_schedule[value]
            train_config = TrainConfig(
                lr=base_train_config.lr,
                batch_size=value,
                num_epochs=num_epochs,
                num_examples=num_examples,
                save_sampling_step=save_sampling_step,
                save_training_step=-1,
            )
        else:
            # Build config with overridden param
            train_config = replace(base_train_config, **{param_name: value})

        # Fresh malign init for each run
        malign_init = malign_init_fn()

        # Folder name for this run
        folder_name = f'{param_name}_{value}'
        run_name = f'{run_name_prefix}_{param_name}_{value}'

        await offdist_sft_and_evaluate(
            service_client=service_client,
            malign_init=malign_init,
            train_config=train_config,
            save_dir=save_dir,
            run_name=run_name,
            num_problems=num_problems,
            apps_prompt=apps_prompt,
            ground_truth_path=ground_truth_path,
            olympiads_prompt=olympiads_prompt,
            generate_config=generate_config,
            benign_performance=benign_performance,
            folder_name=folder_name,
            **extra_kwargs,
        )

        # Load saved metadata (expected at {save_dir}/{folder_name}/metadata.json
        # once offdist_sft_and_evaluate has finished).
        metadata_path = os.path.join(save_dir, folder_name, 'metadata.json')
        with open(metadata_path, 'r') as f:
            results[value] = json.load(f)

        # Keep the notebook output to the current run only.
        clear_output(wait=True)

    print(f'\nSweep over {param_name} complete! Values tested: {param_values}')
    return results

## Batch Size Sweep

In [6]:
base_train_config = TrainConfig(
    lr=3e-5,
    batch_size=4, # will be overridden by sweep
    num_epochs=10, # will be overridden by sweep
    num_examples=10000, # will be overriden by sweep
    save_sampling_step=1, # will be overriden by sweep
    save_training_step=-1,
)

bsz_values = [1, 4, 16, 64, 128]

sweep_results = await run_sweep(
    service_client=service_client,
    param_name='batch_size',
    param_values=bsz_values,
    base_train_config=base_train_config,
    malign_init_fn=lambda: load_plpd_malign_init(service_client, MODEL, plpd_metadata),
    save_dir=f'{plpd_dir}/batch_size_sweeps',
    run_name_prefix='qwen_plpd_offdist',
    num_problems=100,
    apps_prompt=apps_prompt,
    ground_truth_path=gt_apps_path,
    olympiads_prompt=olympiads_prompt,
    generate_config=generate_config,
    benign_performance=BENIGN_PERFORMANCE,
)


Sweep over batch_size complete! Values tested: [1, 4, 16, 64, 128]


## Learning Rate Sweep

In [None]:
base_train_config = TrainConfig(
    lr=3e-5,  # will be overridden by sweep
    batch_size=4,
    num_epochs=10,
    num_examples=10000,
    save_sampling_step=1,
    save_training_step=-1,
)

lr_values = [1e-4, 3e-4, 1e-3, 3e-3]

sweep_results = await run_sweep(
    service_client=service_client,
    param_name='lr',
    param_values=lr_values,
    base_train_config=base_train_config,
    malign_init_fn=lambda: load_plpd_malign_init(service_client, MODEL, plpd_metadata),
    save_dir=f'{plpd_dir}/sweeps',
    run_name_prefix='qwen_plpd_offdist',
    num_problems=100,
    apps_prompt=apps_prompt,
    ground_truth_path=gt_apps_path,
    olympiads_prompt=olympiads_prompt,
    generate_config=generate_config,
    benign_performance=BENIGN_PERFORMANCE,
)

## Compare Sweep Results

In [None]:
def plot_sweep_comparison(sweep_results, param_name, num_problems=100, benign_performance=0.64, save_dir=None):
    """Plot benign and sandbag accuracies across all sweep runs.

    Draws two side-by-side panels (benign on the left, sandbag on the right)
    with one line per sweep value, a shaded 95% band around each line, and a
    dashed horizontal line at ``benign_performance``. The figure is saved as
    ``{param_name}_comparison.png`` and then shown.

    Args:
        sweep_results: Dict mapping each sweep value to a metadata dict with
            ``'sampling_paths'``, ``'benign_accuracies'`` and
            ``'sandbag_accuracies'``. Each sampling path must end in
            ``_epoch_<int>``; that integer is used as the x coordinate.
        param_name: Name of the swept parameter; used in legend labels, the
            figure title and the output file name.
        num_problems: Number of eval problems behind each accuracy; sets the
            width of the error band.
        benign_performance: Accuracy at which the dashed baseline is drawn.
        save_dir: Directory the PNG is written to. Defaults to
            ``{plpd_dir}/sweeps`` (the previously hard-coded location), so
            sweeps stored elsewhere can save their plot next to their results.
            The directory is created if it does not exist.
    """
    if save_dir is None:
        save_dir = f'{plpd_dir}/sweeps'
    # savefig does not create missing directories.
    os.makedirs(save_dir, exist_ok=True)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # (metadata key, line marker) for the left and right panel respectively.
    panels = (
        ('benign_accuracies', 'o'),
        ('sandbag_accuracies', 's'),
    )

    for value, metadata in sweep_results.items():
        # x coordinates come from the integer suffix after '_epoch_' in each path.
        epochs = [int(p.split('_epoch_')[-1]) for p in metadata['sampling_paths']]
        label = f'{param_name}={value}'

        for ax, (key, marker) in zip(axes, panels):
            accuracy = np.array(metadata[key])
            # Normal-approximation 95% half-width for a proportion.
            half_width = 1.96 * np.sqrt(accuracy * (1 - accuracy) / num_problems)

            ax.plot(epochs, accuracy, marker=marker, label=label)
            ax.fill_between(epochs, accuracy - half_width, accuracy + half_width, alpha=0.15)

    for ax, title in zip(axes, ['Benign Accuracy', 'Sandbag Accuracy']):
        ax.axhline(y=benign_performance, color='black', linestyle='--', label=f'Benign Baseline ({benign_performance:.2f})')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Accuracy')
        ax.set_ylim(0, 1)
        ax.set_title(title)
        ax.legend(fontsize=8)

    fig.suptitle(f'Offdist SFT Sweep: {param_name} (95% CI)')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'{param_name}_comparison.png'))
    plt.show()


# Plot whichever sweep was stored in sweep_results last, labelled as an lr sweep.
plot_sweep_comparison(
    sweep_results=sweep_results,
    param_name='lr',
    num_problems=100,
    benign_performance=BENIGN_PERFORMANCE,
)