In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Llama Guard 3 is the classifier used further down to judge each
# (prompt, response) pair as safe / unsafe.
guard_model_name = "meta-llama/Llama-Guard-3-8B"
device = "cuda"
dtype = torch.bfloat16

# Left padding keeps the end of every prompt aligned at the last column, so the
# generated tokens of a batch can be sliced off with a single prompt length.
guard_tokenizer = AutoTokenizer.from_pretrained(
    guard_model_name,
    padding_side="left",
)
# NOTE(review): id 0 is used as the padding id for batched inputs; presumably the
# tokenizer defines no pad token of its own — confirm. Padded positions are
# masked out by the attention mask, so the id itself does not reach the model.
guard_tokenizer.pad_token_id = 0

guard_model = AutoModelForCausalLM.from_pretrained(
    guard_model_name,
    device_map=device,
    torch_dtype=dtype,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:
# Load jailbreak responses data
import os
import glob
import pickle

jailbreak_dataset = "jailbreak_success"
# Models whose result files are skipped for now.
EXCLUDED_MODELS = {"gemma2-27b"}


def load_prompt_responses(results_dir, dataset, excluded_models=()):
    """Load the filtered per-model response pickles for one jailbreak dataset.

    Files matching ``<results_dir>/<dataset>/*filtered*.pkl`` are loaded; the
    model name is the part of the file name before ``-filtered_responses``.

    Args:
        results_dir: root directory holding one sub-directory per dataset.
        dataset: name of the dataset sub-directory to read.
        excluded_models: model names to skip (a warning is printed for each).

    Returns:
        dict mapping model name -> unpickled object. Keys follow sorted file-path
        order, so every later cell iterates the models in the same order on
        every run.
    """
    loaded = {}
    pattern = os.path.join(results_dir, dataset, "*filtered*.pkl")
    # glob returns matches in arbitrary (filesystem-dependent) order; sort them
    # so the model order is reproducible.
    for fpath in sorted(glob.glob(pattern)):
        model_name = os.path.basename(fpath).split("-filtered_responses")[0]
        if model_name in excluded_models:
            # Skip before opening the file: there is no need to read it.
            print(f"WARNING: Ignoring {model_name} for now")
            continue
        # NOTE: pickle.load can execute arbitrary code; only load trusted,
        # locally produced result files here.
        with open(fpath, "rb") as f:
            loaded[model_name] = pickle.load(f)
    return loaded


prompt_responses = load_prompt_responses(
    "jailbreak_screening_results", jailbreak_dataset, EXCLUDED_MODELS
)



In [3]:
# get all of the special tokens
from conversers import get_model_path_and_template


def _collect_special_tokens(tokenizer):
    """Return ``tokenizer.special_tokens_map`` values as one flat list.

    Each value of the map is either a single token or a list of tokens; lists
    are flattened. Duplicates (the same token registered under several roles)
    are dropped, keeping the order of first appearance.
    """
    collected = []
    for value in tokenizer.special_tokens_map.values():
        for token in (value if isinstance(value, list) else [value]):
            if token not in collected:
                collected.append(token)
    return collected


# model name -> that model's tokenizer / list of its special-token strings.
special_tokens_map = {}
tokenizer_map = {}
for model_name in prompt_responses.keys():
    model_path = get_model_path_and_template(model_name)[0]
    _tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    tokenizer_map[model_name] = _tokenizer
    # Flattening happens inside a helper so its scratch variables do not leak
    # into the notebook namespace.
    special_tokens_map[model_name] = _collect_special_tokens(_tokenizer)

In [4]:
from copy import deepcopy
from tqdm import tqdm


def _strip_special_tokens(text, tokenizer, special_tokens):
    """Remove a model's special tokens from ``text``.

    The text is tokenized with the model's own tokenizer, every token found in
    ``special_tokens`` is dropped, and the rest is joined back into a string.
    ``special_tokens`` should be a set, so that each membership test is O(1).
    """
    kept = [t for t in tokenizer.tokenize(text) if t not in special_tokens]
    return tokenizer.convert_tokens_to_string(kept)


def _format_guard_conversation(user_prompt, assistant_response):
    """Render a single user/assistant exchange with the guard model's chat template.

    Returns the templated prompt as a string (``tokenize=False``).
    """
    chat = [
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": assistant_response},
    ]
    return guard_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)


# TODO: batch this and do it for each model individually
# For every model and prompt: clean the response text, then build two guard
# prompts (one per user prompt variant) around the same response.
augmented_prompt_responses = {}
for model_name in tqdm(prompt_responses.keys()):
    # Hoisted per model: the tokenizer lookup and a set of special tokens.
    model_tokenizer = tokenizer_map[model_name]
    special_tokens = set(special_tokens_map[model_name])
    augmented_prompt_responses[model_name] = {}
    for prompt, response in prompt_responses[model_name].items():
        # Copy so the loaded prompt_responses are left untouched.
        _response = deepcopy(response)
        detokenized_response = _strip_special_tokens(
            response["responses"], model_tokenizer, special_tokens
        )
        _response["responses"] = detokenized_response
        _response["original_conversation"] = _format_guard_conversation(
            _response["original_prompt_text"], detokenized_response
        )
        _response["jailbreak_conversation"] = _format_guard_conversation(
            _response["jailbreak_prompt_text"], detokenized_response
        )
        augmented_prompt_responses[model_name][prompt] = _response

100%|██████████| 19/19 [00:58<00:00,  3.09s/it]


In [None]:
from collections import defaultdict
from torch.utils.data import DataLoader

batch_size = 128
# Llama Guard answers "safe" / "unsafe" (+ categories), so a few tokens suffice.
guard_max_new_tokens = 10


def _run_guard(conversations):
    """Run the guard model on a batch of already-templated prompts.

    Args:
        conversations: list of prompt strings produced by
            ``guard_tokenizer.apply_chat_template(..., tokenize=False)``.

    Returns:
        (responses, judgements):
            responses: the decoded generations, one per input.
            judgements: for each response, True when the text contains "unsafe".
    """
    # The templated text is assumed to already contain the BOS token (the case
    # for Llama Guard 3's chat template — verify if the guard model changes).
    # Tokenizing with the default add_special_tokens=True would prepend a
    # second BOS, so special tokens are not added here.
    inputs = guard_tokenizer(
        conversations, padding=True, add_special_tokens=False, return_tensors="pt"
    ).to(device)
    output = guard_model.generate(
        **inputs,
        max_new_tokens=guard_max_new_tokens,
        # Rows that finish early are padded with this id. Id 0 is an ordinary
        # vocabulary token, so skip_special_tokens would leave it in the text;
        # EOS is special and is stripped on decode. Inputs carry an explicit
        # attention mask, so this id does not affect the prompts.
        pad_token_id=guard_tokenizer.eos_token_id,
    )
    # With left padding, every row's prompt occupies the first prompt_len columns.
    prompt_len = inputs["input_ids"].shape[-1]
    responses = guard_tokenizer.batch_decode(output[:, prompt_len:], skip_special_tokens=True)
    judgements = ["unsafe" in r for r in responses]
    return responses, judgements


# model name -> list of collated batches, each extended with guard outputs for
# both the original and the jailbreak prompt.
batch_responses = defaultdict(list)
for model_name, responses_by_prompt in augmented_prompt_responses.items():
    response_dataloader = DataLoader(
        list(responses_by_prompt.values()), batch_size=batch_size, shuffle=False
    )
    for batch in tqdm(response_dataloader, desc=model_name):
        # The collated batch is a freshly built dict and is only extended with
        # new keys, so it is used directly without a deep copy.
        for variant in ("original", "jailbreak"):
            responses, judgements = _run_guard(batch[f"{variant}_conversation"])
            batch[f"{variant}_guard3_response"] = responses
            batch[f"{variant}_guard3_judge"] = judgements
        batch_responses[model_name].append(batch)

 29%|██▊       | 4/14 [13:55<35:37, 213.78s/it]