In [2]:
import os

import gqr
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Turn off the tokenizers library's internal parallelism for this kernel.
# NOTE(review): presumably set to silence the fork/parallelism warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
# Load both evaluation splits from gqr: `domain` is the ID test set and `ood`
# the OOD test set. Later cells read a "text" column from each frame.
domain, ood = gqr.load_id_test_dataset(), gqr.load_ood_test_dataset()

In [None]:
# Hugging Face checkpoint used as the safety classifier.
model_id = "meta-llama/Llama-Guard-3-1B"
# Filesystem-friendly prefix for the result files ("/" is not valid in a file name).
save_path = "custom_categories_" + model_id.replace("/", "_")

# Load the weights in bfloat16 and let `device_map="auto"` choose the placement.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Matching tokenizer; the begin-of-text marker is registered as its pad token.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = "<|begin_of_text|>"

In [None]:
def is_safe_question(input_texts: str) -> str:
    """Classify one user message with the loaded guard model and return its verdict.

    The message is wrapped in a single-turn user conversation, rendered through
    the tokenizer's chat template together with a custom category table, and
    passed to ``model.generate``. Only the newly generated tokens are decoded.

    Relies on the module-level ``model`` and ``tokenizer``.

    Args:
        input_texts: Text of the single user message to classify. Despite the
            plural name this is one string, not a batch.

    Returns:
        The decoded completion with ``<|eot_id|>`` removed and surrounding
        whitespace stripped. Presumably "safe", or "unsafe" followed by the
        violated category codes — confirm against the model card.
    """
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": input_texts},
            ],
        }
    ]
    inputs = tokenizer.apply_chat_template(
        conversation,
        return_tensors="pt",
        # Request a dict so the attention mask is returned alongside the ids.
        return_dict=True,
        categories={
            "S1": "Violent Crimes.",
            "S2": "Non-Violent Crimes.",
            "S3": "Sex Crimes. ",
            "S4": "Child Exploitation. ",
            "S5": "Defamation.",
            "S6": "Specialized Advice.",
            "S7": "Privacy.",
            "S8": "Intellectual Property.",
            "S9": "Indiscriminate Weapons.",
            "S10": "Hate.",
            "S11": "Self-Harm.",
            "S12": "Sexual Content.",
            "S13": "Elections.",
        },
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]
    # Pass the attention mask explicitly. Without one, generate() infers it
    # from pad_token_id, so any prompt token whose id equals the pad id would
    # be silently masked out of the classification.
    output = model.generate(
        **inputs,
        max_new_tokens=20,
        pad_token_id=0,
    )
    # Drop the echoed prompt; keep only the model's continuation.
    generated_tokens = output[:, prompt_len:]
    return tokenizer.decode(generated_tokens[0]).replace("<|eot_id|>", "").strip()


In [None]:
# Classify every OOD row and persist the verdicts.
# Create the output directory up front so a missing folder fails fast instead
# of discarding the results after the whole inference pass has finished.
results_dir = "data/results"
os.makedirs(results_dir, exist_ok=True)

# Build the column in one assignment rather than writing cell by cell with
# `.at`; this also stays positional if the index contains duplicate labels.
ood["is_safe"] = [
    is_safe_question(text) for text in tqdm(ood["text"], total=len(ood))
]
ood.to_csv(f"{results_dir}/{save_path}_ood_results.csv", index=False)

In [None]:
# Classify every in-domain row and persist the verdicts.
# Create the output directory up front so a missing folder fails fast instead
# of discarding the results after the whole inference pass has finished.
results_dir = "data/results"
os.makedirs(results_dir, exist_ok=True)

# Build the column in one assignment rather than writing cell by cell with
# `.at`; this also stays positional if the index contains duplicate labels.
domain["is_safe"] = [
    is_safe_question(text) for text in tqdm(domain["text"], total=len(domain))
]
domain.to_csv(f"{results_dir}/{save_path}_domain_results.csv", index=False)