In [1]:
import os

import gqr
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Turn off Hugging Face `tokenizers` parallelism for this process (commonly
# done to avoid fork-related warnings when tokenizers are used in notebooks).
os.environ.update(TOKENIZERS_PARALLELISM="false")

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# In-domain and out-of-domain evaluation splits from the gqr benchmark
# (loaded in that order: in-domain first, then OOD).
domain, ood = gqr.load_id_test_dataset(), gqr.load_ood_test_dataset()

In [None]:
# Model under evaluation; its id (with "/" made filesystem-safe) also names
# the result CSVs written by the inference cells.
model_id = "meta-llama/Llama-Guard-3-8B"
save_path = "custom_categories_" + model_id.replace("/", "_")

# Inference placement and precision.
device = "cuda:0"
dtype = torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=dtype,
    device_map=device,
)

In [None]:
# Custom hazard taxonomy passed to the chat template. "S6: Specialized Advice"
# is commented out, so it is excluded from this custom category set.
# Defined once here rather than being rebuilt on every `moderate` call.
CUSTOM_CATEGORIES = {
    "S1": "Violent Crimes.",
    "S2": "Non-Violent Crimes.",
    "S3": "Sex Crimes.",
    "S4": "Child Exploitation.",
    "S5": "Defamation.",
    # "S6": "Specialized Advice.",
    "S7": "Privacy.",
    "S8": "Intellectual Property.",
    "S9": "Indiscriminate Weapons.",
    "S10": "Hate.",
    "S11": "Self-Harm.",
    "S12": "Sexual Content.",
    "S13": "Elections.",
}


def moderate(
    text: str,
    categories: "dict[str, str] | None" = None,
    max_new_tokens: int = 5,
) -> str:
    """Classify a single user message with the loaded guard model.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        text: The user message to moderate.
        categories: Mapping of category code -> description, forwarded to
            the chat template as the ``categories`` variable. Defaults to
            ``CUSTOM_CATEGORIES``.
        max_new_tokens: Generation budget for the verdict. The default of 5
            matches the original notebook. It may be too short to hold a
            long list of violated category codes (TODO confirm if those
            codes are needed downstream).

    Returns:
        Only the newly generated text (the prompt tokens are sliced off),
        decoded with special tokens removed. Presumably a "safe"/"unsafe ..."
        verdict; the exact format depends on the model.
    """
    if categories is None:
        categories = CUSTOM_CATEGORIES
    chat = [{"role": "user", "content": text}]
    # NOTE(review): `categories` is passed to the Jinja chat template as an
    # extra template variable. If this model's template never references it,
    # the kwarg is silently ignored and the built-in taxonomy is used instead.
    # Verify once by decoding `input_ids` and checking S6 is absent.
    input_ids = tokenizer.apply_chat_template(
        chat, categories=categories, return_tensors="pt"
    ).to(device)
    # Explicit pad token id (0, kept from the original call) so generate()
    # does not have to infer one.
    output_ids = model.generate(
        input_ids=input_ids, max_new_tokens=max_new_tokens, pad_token_id=0
    )
    prompt_len = input_ids.shape[-1]
    # Decode only the continuation, not the echoed prompt.
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

In [None]:
# Moderate every OOD prompt and persist the raw verdicts.
# Note: "is_safe" holds the decoded model output text, not a boolean.
# Create the output directory first so to_csv cannot fail with
# "No such file or directory" only after the long inference pass.
os.makedirs("data/results", exist_ok=True)
ood["is_safe"] = [moderate(text) for text in tqdm(ood["text"], total=len(ood))]
ood.to_csv(f"data/results/{save_path}_ood_results.csv", index=False)

In [None]:
# Moderate every in-domain prompt and persist the raw verdicts.
# Note: "is_safe" holds the decoded model output text, not a boolean.
# Create the output directory first so to_csv cannot fail with
# "No such file or directory" only after the long inference pass.
os.makedirs("data/results", exist_ok=True)
domain["is_safe"] = [
    moderate(text) for text in tqdm(domain["text"], total=len(domain))
]
domain.to_csv(f"data/results/{save_path}_domain_results.csv", index=False)