# POC: Bayesian contrastive vs prompt baseline

This notebook runs the minimal POC on a JSONL edit-trace file and reports held-out preference accuracy. Update `DATA_PATH` and `MODEL_NAME` to match your environment.

In [1]:
from poc.data import load_jsonl, split_by_user, group_by_user
from poc.contexts import DEFAULT_CONTEXTS, contexts_as_strings
from poc.method_a import BayesContrastiveModel
from poc.prompt_baseline import build_style_profile
from poc.eval import evaluate_method_a, evaluate_prompt_baseline
from cos.utils import load_hf_model_and_tokenizer
import torch

DATA_PATH = "./data/edits.jsonl"
MODEL_NAME = "llama-2-7b-chat"

examples = load_jsonl(DATA_PATH)
train, val = split_by_user(examples, val_fraction=0.2)
train_by_user = group_by_user(train)

model, tokenizer = load_hf_model_and_tokenizer(MODEL_NAME)
contexts = contexts_as_strings(DEFAULT_CONTEXTS)

prior_mean = torch.zeros(len(contexts), device=model.device)
prior_cov = torch.eye(len(contexts), device=model.device) * 4.0

bayes_model = BayesContrastiveModel(
    contexts=contexts,
    prior_mean=prior_mean,
    prior_cov=prior_cov,
    beta=1.0,
    is_chat=True,
)

for ex in train:
    bayes_model.update_user(model, tokenizer, ex)

profiles = {u: build_style_profile(u, edits) for u, edits in train_by_user.items()}

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.29s/it]


In [2]:
metrics_a = evaluate_method_a(model, tokenizer, val, bayes_model)
metrics_b = evaluate_prompt_baseline(model, tokenizer, val, profiles, is_chat=True)

print("Bayesian contrastive:", metrics_a)
print("Prompt baseline:", metrics_b)

Bayesian contrastive: {'accuracy': 0.5, 'auc': 0.625, 'mean_delta': 2.9626617431640625}
Prompt baseline: {'accuracy': 0.5, 'auc': 0.5, 'mean_delta': 2.293132781982422}


In [4]:
from cos.core import multi_contextual_steering_hf
from poc.contexts import lambdas_from_slider

# Demo: generate side-by-side outputs for one user
DEMO_PROMPTS = [
    "Draft a short update about the project status.",
    "Summarize the meeting outcome in 1–2 sentences.",
]
DEMO_USER = next(iter(train_by_user.keys()))
SLIDER_VALUE = 6

# Prompt baseline generation
style_prompt = profiles[DEMO_USER].prompt_text
baseline_outputs = []
for prompt in DEMO_PROMPTS:
    full_prompt = f"{style_prompt}\n\n{prompt}" if prompt else style_prompt
    dialog = [{"role": "user", "content": full_prompt}]
    inputs = tokenizer.apply_chat_template(
        dialog,
        tokenize=True,
        return_tensors="pt",
        padding=False,
        return_dict=True,
    ).to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )
    prompt_len = inputs.input_ids.shape[1]
    gen_text = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
    baseline_outputs.append(gen_text.strip())

# CoS generation using learned lambdas
mean = bayes_model.user_models[DEMO_USER].posterior.mean.tolist()
contexts = contexts_as_strings(DEFAULT_CONTEXTS)
cos_lambdas = lambdas_from_slider(SLIDER_VALUE, mean)
all_contexts = [[c for _ in DEMO_PROMPTS] for c in contexts]
all_lambdas = [[l for _ in DEMO_PROMPTS] for l in cos_lambdas]

cos_output = multi_contextual_steering_hf(
    model=model,
    tokenizer=tokenizer,
    prompts=DEMO_PROMPTS,
    all_contexts=all_contexts,
    all_lambdas=all_lambdas,
    is_chat=True,
    show_progress=False,
    max_gen_len=128,
)

for prompt, base, cos in zip(DEMO_PROMPTS, baseline_outputs, cos_output["generation"]):
    print("---")
    print("Prompt:", prompt)
    print("Baseline:", base)
    print("CoS:", cos["content"])

---
Prompt: Draft a short update about the project status.
Baseline: Sure, here's an update about the project status in the style profile you provided:

Project Status Update:

The project is progressing well, with [0.70] of the work completed so far. Our team has been working diligently to ensure the project's success, and we are on track to meet our deadline. Currently, we are focused on [10.0 words] and are making good progress. Overall, we are satisfied with the project's progress and are confident in our ability to deliver a high-quality outcome.

I hope this update provides a
CoS:  Project Update:

Progress steady on schedule, as planned. Key milestones met, with additional achievements forthcoming. Brief update on notable developments:

* Completed coding phase with 95% accuracy rate, exceeding benchmark.
* Conducted thorough testing, identified and addressed minor bugs.
* Design phase advancing as planned, with 75% design elements finalized.
* Collaboration with cross-functiona

In [5]:
import json
from pathlib import Path

OUTPUT_PATH = Path("/scratch/yirenl2/projects/context-steering-writing/outputs/slider_sweep.jsonl")
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)

slider_values = list(range(1, 8))
rows = []

for slider in slider_values:
    # Baseline outputs
    baseline_outputs = []
    for prompt in DEMO_PROMPTS:
        full_prompt = f"{style_prompt}\n\n{prompt}" if prompt else style_prompt
        dialog = [{"role": "user", "content": full_prompt}]
        inputs = tokenizer.apply_chat_template(
            dialog,
            tokenize=True,
            return_tensors="pt",
            padding=False,
            return_dict=True,
        ).to(model.device)
        output_ids = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
        )
        prompt_len = inputs.input_ids.shape[1]
        gen_text = tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
        baseline_outputs.append(gen_text.strip())

    # CoS outputs
    cos_lambdas = lambdas_from_slider(slider, mean)
    all_contexts = [[c for _ in DEMO_PROMPTS] for c in contexts]
    all_lambdas = [[l for _ in DEMO_PROMPTS] for l in cos_lambdas]
    cos_output = multi_contextual_steering_hf(
        model=model,
        tokenizer=tokenizer,
        prompts=DEMO_PROMPTS,
        all_contexts=all_contexts,
        all_lambdas=all_lambdas,
        is_chat=True,
        show_progress=False,
        max_gen_len=128,
    )

    for prompt, base, cos in zip(DEMO_PROMPTS, baseline_outputs, cos_output["generation"]):
        rows.append({
            "user_id": DEMO_USER,
            "slider": slider,
            "prompt": prompt,
            "baseline": base,
            "cos": cos["content"],
        })

with OUTPUT_PATH.open("w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

print(f"Wrote {len(rows)} rows to {OUTPUT_PATH}")


Wrote 14 rows to /scratch/yirenl2/projects/context-steering-writing/outputs/slider_sweep.jsonl
