<a href="https://colab.research.google.com/github/siddhiipatell/quantized-model-workbench/blob/main/Quantized_4_bit_llama_3_1_nemoguard_8b_content_safety.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install / upgrade the required libraries
!pip install -U transformers accelerate peft bitsandbytes

In [None]:
# Report GPU availability, the GPU name (if any), and the bitsandbytes version
import torch
import bitsandbytes as bnb

cuda_available = torch.cuda.is_available()
print(cuda_available)

# Only query the device name when CUDA is usable; calling
# get_device_name(0) without a CUDA device raises and would abort the cell
# before the bitsandbytes version is printed.
if cuda_available:
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device detected")

print(bnb.__version__)


In [None]:
# Load the base model with a 4-bit quantization config, then attach the
# NemoGuard content-safety PEFT adapter on top of it
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# 4-bit NF4 quantization with double quantization; compute runs in float16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# device_map={"": 0} places the whole model on device index 0
base_load_kwargs = {
    "quantization_config": bnb_config,
    "device_map": {"": 0},
    "torch_dtype": torch.float16,
}
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct", **base_load_kwargs
)

# Wrap the quantized base model with the adapter weights
model = PeftModel.from_pretrained(
    base_model, "nvidia/llama-3.1-nemoguard-8b-content-safety"
)

# Switch to evaluation mode (kept as the last expression of the cell)
model.eval()

In [None]:
# Load the tokenizer published with the base model, requesting the fast
# implementation
from transformers import AutoTokenizer

tokenizer_repo = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo, use_fast=True)

In [None]:
# Build the classification prompt
def build_prompt(user_text: str) -> str:
    """Wrap *user_text* in the instruction prompt sent to the safety model.

    The text is inserted verbatim between double quotes; no escaping is
    applied. The returned prompt ends with a newline after ``[/INST]``.

    NOTE(review): the ``<s>[INST] ... [/INST]`` markers look like a
    Llama-2-style template, while this file loads a Llama 3.1 base model --
    confirm the expected prompt format against the adapter's model card.
    """
    prompt_lines = [
        "<s>[INST]",
        "User message:",
        f'"{user_text}"',
        "",
        "Classify the above content for safety.",
        "Respond with a short verdict.",
        "[/INST]",
        "",  # trailing empty entry yields the final newline
    ]
    return "\n".join(prompt_lines)


In [None]:
# Run one example message through the model and print its verdict
import torch

text = "How can I hack a personal system."
prompt = build_prompt(text)

# Tokenize the prompt and move the tensors to the model's device
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
inputs = inputs.to(model.device)

generation_kwargs = {
    "max_new_tokens": 20,  # set according to the use case
    "do_sample": False,  # greedy decoding
    "pad_token_id": tokenizer.eos_token_id,
}
with torch.inference_mode():
    outputs = model.generate(**inputs, **generation_kwargs)

# Decode only the tokens that come after the prompt
prompt_length = inputs["input_ids"].shape[-1]
result = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

print("Input:", text)
print("Output:", result)
