# Guard Model Evaluation Pipeline

Batch-generates responses to adversarial prompts and classifies them with Llama Guard 3.

Each `# %%` block is a notebook cell; the file can also be run top to bottom as a standalone script.

In [None]:
# Cell 1: Imports and Config

import gc
import json
from pathlib import Path

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- Configuration ---
# Tunable settings for the whole pipeline; the cells below read these names.
GENERATOR_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
GUARD_MODEL = "meta-llama/Llama-Guard-3-1B"
QUESTIONS_FILE = "questions.txt"        # one question per line
N_RESPONSES = 10                        # responses per question
GEN_BATCH_SIZE = 8                      # generation batch size
GUARD_BATCH_SIZE = 8                    # guard classification batch size
MAX_NEW_TOKENS = 256                    # cap on generated tokens per response
TEMPERATURE = 0.7                       # sampling temperature for the generator
OUTPUT_FILE = "guard_results.json"
SEED = 42                               # seed for torch's random number generators

# Half precision only when a GPU is available; CPU runs stay in float32.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# The generator samples (do_sample=True), so seed once up front to make a
# Restart & Run All repeatable in the same environment.
torch.manual_seed(SEED)
print(f"Device: {DEVICE}")

Device: cpu


In [None]:
# Cell 2: Load generator model

print(f"Loading generator: {GENERATOR_MODEL}")

# Tokenizer: decoder-only models need left padding when prompts are batched
gen_tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
gen_tokenizer.padding_side = "left"

# Model: automatic device placement on CUDA, default placement otherwise
gen_model = AutoModelForCausalLM.from_pretrained(
    GENERATOR_MODEL,
    device_map=("auto" if DEVICE == "cuda" else None),
    torch_dtype=DTYPE,
)

# A tokenizer without a pad token reuses EOS, and the model config is told the same id
if gen_tokenizer.pad_token is None:
    gen_tokenizer.pad_token = gen_tokenizer.eos_token
    gen_model.config.pad_token_id = gen_tokenizer.eos_token_id

print("Generator loaded.")

Loading generator: TinyLlama/TinyLlama-1.1B-Chat-v1.0


config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

Loading weights:   0%|          | 0/201 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/124 [00:01<?, ?B/s]

In [None]:
# Cell 3: Load questions & build prompt list
# Decode explicitly (utf-8-sig also drops a leading BOM) so the parsed questions do not
# depend on the platform's default encoding. Blank lines are skipped.
questions = [
    line.strip()
    for line in Path(QUESTIONS_FILE).read_text(encoding="utf-8-sig").splitlines()
    if line.strip()
]

# Fail here with a clear message instead of a ZeroDivisionError in the summary cells
if not questions:
    raise ValueError(f"No questions found in {QUESTIONS_FILE}")
if N_RESPONSES < 1:
    raise ValueError(f"N_RESPONSES must be at least 1, got {N_RESPONSES}")

print(f"Loaded {len(questions)} questions, generating {N_RESPONSES} responses each → {len(questions) * N_RESPONSES} total")

# Repeat each formatted prompt N times
prompts = []        # formatted chat strings
q_indices = []      # maps each prompt back to its question index

for qi, question in enumerate(questions):
    chat = [{"role": "user", "content": question}]
    formatted = gen_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    prompts.extend([formatted] * N_RESPONSES)
    q_indices.extend([qi] * N_RESPONSES)

# prompts and q_indices are parallel lists; later cells index them together
assert len(prompts) == len(q_indices) == len(questions) * N_RESPONSES

In [None]:
# Cell 4: Batch-generate responses
# Produces `responses`, one decoded continuation per entry of `prompts`, in the same order.
responses = []

for start in tqdm(range(0, len(prompts), GEN_BATCH_SIZE), desc="Generating"):
    batch = prompts[start : start + GEN_BATCH_SIZE]
    inputs = gen_tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(gen_model.device)

    with torch.no_grad():
        out = gen_model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            pad_token_id=gen_tokenizer.eos_token_id,
        )

    # The batch is left-padded to one width, so the prompt occupies the same leading
    # columns in every row; compute that width once and slice it off each output.
    prompt_len = inputs["input_ids"].shape[1]
    responses.extend(
        gen_tokenizer.decode(ids[prompt_len:], skip_special_tokens=True) for ids in out
    )

# Later cells pair responses[i] with q_indices[i], so the lengths must agree
assert len(responses) == len(prompts), "expected exactly one response per prompt"
print(f"Generated {len(responses)} responses")

In [None]:
# Cell 5: Free generator, load guard model

# Remove the generator objects by name instead of `del`, so running this cell a second
# time (when they are already gone) does not raise NameError.
globals().pop("gen_model", None)
globals().pop("gen_tokenizer", None)

# Collect unreachable objects first so their memory is released before the cache is emptied
gc.collect()
if DEVICE == "cuda":
    torch.cuda.empty_cache()

print(f"Loading guard: {GUARD_MODEL}")
guard_tokenizer = AutoTokenizer.from_pretrained(GUARD_MODEL)
guard_model = AutoModelForCausalLM.from_pretrained(
    GUARD_MODEL,
    torch_dtype=DTYPE,
    device_map="auto" if DEVICE == "cuda" else None,
)

# Left-pad for batched generation; reuse EOS when the tokenizer defines no pad token
guard_tokenizer.padding_side = "left"
if guard_tokenizer.pad_token is None:
    guard_tokenizer.pad_token = guard_tokenizer.eos_token

print("Guard model loaded.")

In [None]:
# Cell 6: Format conversations for guard
# Builds one rendered guard prompt per response: the original question as the user turn
# and the generated text as the assistant turn.

# responses[i] belongs to questions[q_indices[i]]; refuse to pair misaligned lists
assert len(responses) == len(q_indices), "responses and q_indices must be the same length"

guard_texts = []
for question_idx, response in zip(q_indices, responses):
    # Message content is passed as a list of typed text parts, the format shown in the
    # Llama-Guard-3-1B model card. NOTE(review): if GUARD_MODEL is switched to a
    # checkpoint whose template expects plain-string content, adjust this shape.
    convo = [
        {"role": "user", "content": [{"type": "text", "text": questions[question_idx]}]},
        {"role": "assistant", "content": [{"type": "text", "text": response}]},
    ]
    guard_texts.append(guard_tokenizer.apply_chat_template(convo, tokenize=False))

print(f"Prepared {len(guard_texts)} conversations for classification")

In [None]:
# Cell 7: Batch-classify with guard model
# Produces `results`: one dict per response with the guard's raw output and an is_unsafe flag.

results = []

# guard_texts were rendered by the chat template. When that text already starts with the
# BOS token, tokenizing with special tokens enabled would put a second BOS in front of it,
# so special tokens are only added when the rendered text lacks one.
guard_bos = guard_tokenizer.bos_token
template_wrote_bos = bool(guard_bos) and bool(guard_texts) and all(
    text.startswith(guard_bos) for text in guard_texts
)

for start in tqdm(range(0, len(guard_texts), GUARD_BATCH_SIZE), desc="Classifying"):
    batch = guard_texts[start : start + GUARD_BATCH_SIZE]
    inputs = guard_tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=not template_wrote_bos,
    ).to(guard_model.device)

    with torch.no_grad():
        out = guard_model.generate(
            **inputs,
            max_new_tokens=100,
            # Force greedy decoding so the verdict does not depend on whatever sampling
            # settings the checkpoint's generation config carries.
            do_sample=False,
            pad_token_id=guard_tokenizer.eos_token_id,
        )

    # Left-padded batch: the prompt fills the same leading columns in every row
    prompt_len = inputs["input_ids"].shape[1]
    for i, ids in enumerate(out):
        guard_output = guard_tokenizer.decode(ids[prompt_len:], skip_special_tokens=True).strip()
        idx = start + i  # position in the flat responses / q_indices lists
        results.append({
            "question_idx": q_indices[idx],
            "question": questions[q_indices[idx]],
            "response": responses[idx],
            "guard_output": guard_output,
            "is_unsafe": "unsafe" in guard_output.lower(),
        })

print(f"Classified {len(results)} responses")

In [None]:
# Cell 8: Results summary
# Prints the overall unsafe rate and a per-question breakdown with a 10-segment bar.
# Leaves `total`, `unsafe` and `safe` defined for the save cell.

total = len(results)
unsafe = sum(r["is_unsafe"] for r in results)
safe = total - unsafe

# Tally every question in a single pass instead of re-scanning results per question
q_totals = {}    # question_idx -> number of classified responses
q_unsafes = {}   # question_idx -> number of those flagged unsafe
for r in results:
    q_idx = r["question_idx"]
    q_totals[q_idx] = q_totals.get(q_idx, 0) + 1
    q_unsafes[q_idx] = q_unsafes.get(q_idx, 0) + int(r["is_unsafe"])

# Percentages fall back to 0.0 when there is nothing to divide by
overall_pct = 100 * unsafe / total if total else 0.0

print(f"\n{'='*60}")
print(f"  OVERALL: {unsafe}/{total} unsafe ({overall_pct:.1f}%)")
print(f"{'='*60}")

print("\n  PER-QUESTION BREAKDOWN")
print(f"  {'-'*56}")
for qi, q in enumerate(questions):
    q_total = q_totals.get(qi, 0)
    q_unsafe = q_unsafes.get(qi, 0)
    pct = 100 * q_unsafe / q_total if q_total else 0.0
    filled = int(pct // 10)  # one bar segment per full 10%
    bar = "█" * filled + "░" * (10 - filled)
    label = q[:60] + ("..." if len(q) > 60 else "")
    print(f"  Q{qi+1:>3}: {q_unsafe:>2}/{q_total} {bar} {pct:5.1f}%  {label}")


In [None]:
# Cell 9: Save detailed results
# Writes the run configuration, overall summary, per-question stats and every classified
# response to OUTPUT_FILE as JSON. Reads `total`, `unsafe` and `safe` from the summary cell.

# Count each question once. The per-question "total" is the number of results actually
# recorded for it, matching the printed breakdown, rather than an assumed N_RESPONSES.
question_tallies = {}   # question_idx -> [unsafe_count, result_count]
for r in results:
    tally = question_tallies.setdefault(r["question_idx"], [0, 0])
    tally[0] += int(r["is_unsafe"])
    tally[1] += 1

per_question = []
for qi, q in enumerate(questions):
    q_unsafe, q_total = question_tallies.get(qi, (0, 0))
    per_question.append({
        "question": q,
        "unsafe": q_unsafe,
        "total": q_total,
        # 0.0 rather than a ZeroDivisionError when a question has no results
        "unsafe_pct": round(100 * q_unsafe / q_total, 2) if q_total else 0.0,
    })

payload = {
    "config": {
        "generator": GENERATOR_MODEL,
        "guard": GUARD_MODEL,
        "n_responses": N_RESPONSES,
        "temperature": TEMPERATURE,
        "max_new_tokens": MAX_NEW_TOKENS,
        "num_questions": len(questions),
    },
    "summary": {
        "total": total,
        "unsafe": unsafe,
        "safe": safe,
        "unsafe_pct": round(100 * unsafe / total, 2) if total else 0.0,
    },
    "per_question": per_question,
    "details": results,
}

Path(OUTPUT_FILE).write_text(json.dumps(payload, indent=2), encoding="utf-8")
print(f"\nDetailed results saved to {OUTPUT_FILE}")
