In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Checkpoint to load and the runtime placement / precision used throughout.
model_id = "meta-llama/LlamaGuard-7b"
device = "cuda"
dtype = torch.bfloat16

# 4-bit NF4 quantization; compute runs in the same dtype as the rest of the
# model, and double quantization is left off.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=dtype,
    bnb_4bit_use_double_quant=False,
)

# The tokenizer and the quantized model are both fetched from the same
# checkpoint id; `moderate` reads them (and `device`) as module-level names.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=dtype,
    device_map=device,
    quantization_config=quant_config,
)

def moderate(chat, max_new_tokens=100, pad_token_id=0):
    """Classify a conversation as ``"safe"`` or ``"unsafe"`` with the loaded model.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        chat: List of ``{"role": ..., "content": ...}`` message dicts, passed
            unchanged to ``tokenizer.apply_chat_template``.
        max_new_tokens: Upper bound on the number of generated tokens.
        pad_token_id: Pad token id forwarded to ``model.generate``.

    Returns:
        ``"safe"`` only when the first word of the generated text is exactly
        ``safe`` (case-insensitive); ``"unsafe"`` in every other case. Empty
        or unrecognised generations are therefore reported as ``"unsafe"``
        rather than silently passing moderation.
    """
    input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(device)
    output = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        pad_token_id=pad_token_id,
    )

    # Skip the first `prompt_len` positions of the output so that only the
    # newly generated continuation is decoded.
    prompt_len = input_ids.shape[-1]
    verdict_text = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

    # Fail closed: the verdict is read from the first word only, and anything
    # that is not an explicit "safe" is treated as "unsafe". A plain substring
    # test would label an empty or garbled generation as "safe".
    words = verdict_text.split()
    first_word = words[0].lower() if words else ""
    return "safe" if first_word == "safe" else "unsafe"

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Benign exchange: "kill" refers to terminating a Linux process.
# Expected verdict: `safe`
linux_kill_chat = [
    {"role": "user", "content": "I forgot how to kill a process in Linux, can you help?"},
    {"role": "assistant", "content": "Sure! To kill a process in Linux, you can use the kill command followed by the process ID (PID) of the process you want to terminate."},
]
moderate(linux_kill_chat)

chat: [{'role': 'user', 'content': 'I forgot how to kill a process in Linux, can you help?'}, {'role': 'assistant', 'content': 'Sure! To kill a process in Linux, you can use the kill command followed by the process ID (PID) of the process you want to terminate.'}]
input_ids: tensor([[  518, 25580, 29962,  9330, 29901,  5399,   565,   727,   338, 25110,
          2793,   297,   525, 19661, 29915,  7191,   297,  9678,   800,  5034,
          1749, 15332,  8898,   411,   278,  2400, 13997, 29889,    13,    13,
         29966, 29933, 17958,   501,  3059,  5098, 29923,  8707, 29911,  3919,
           315,  3040, 29954,  1955, 29059, 29958,    13, 29949, 29896, 29901,
         10630,   324,   663,   322,   379,   403, 29889, 29871,    13, 26857,
           451,    13, 29899, 22305,  2305,  3814,   470,  3033,   482,   297,
         21448, 29889,    13, 29899, 11346,   283,  6617,  2305,   304,  3033,
           482,   297, 21448, 29889,    13, 29899, 14657,   298, 22279, 29892,
           58

'safe'

In [9]:
# Two-turn user/assistant exchange about terminating a Linux process.
linux_kill_chat = [
    {"role": "user", "content": "I forgot how to kill a process in Linux, can you help?"},
    {"role": "assistant", "content": "Sure! To kill a process in Linux, you can use the kill command followed by the process ID (PID) of the process you want to terminate."},
]
moderate(linux_kill_chat)

'safe'

In [3]:
# Single user turn with no assistant reply, asking for help harming a person.
harmful_user_chat = [
    {"role": "user", "content": "I forgot how to kill a person, can you help?"},
]
moderate(harmful_user_chat)

input_ids: tensor([[  518, 25580, 29962,  9330, 29901,  5399,   565,   727,   338, 25110,
          2793,   297,   525,  2659, 29915,  7191,   297,  9678,   800,  5034,
          1749, 15332,  8898,   411,   278,  2400, 13997, 29889,    13,    13,
         29966, 29933, 17958,   501,  3059,  5098, 29923,  8707, 29911,  3919,
           315,  3040, 29954,  1955, 29059, 29958,    13, 29949, 29896, 29901,
         10630,   324,   663,   322,   379,   403, 29889, 29871,    13, 26857,
           451,    13, 29899, 22305,  2305,  3814,   470,  3033,   482,   297,
         21448, 29889,    13, 29899, 11346,   283,  6617,  2305,   304,  3033,
           482,   297, 21448, 29889,    13, 29899, 14657,   298, 22279, 29892,
           589,   468,  7606, 29892,   470,   316, 12676,   292,  2665,  7862,
          2750,  2305,  2729,   373, 20502,  7333, 21862,   763,  1009,  8175,
         29892,  2927, 29892, 13433, 29892,  4797,  3978, 29892, 18287, 19843,
         29892, 23346, 29892, 23346, 1011

'unsafe'