In [None]:
import torch

# Permit TF32 for CUDA matrix multiplications. The flag is set unconditionally,
# whichever device ends up being selected below.
torch.backends.cuda.matmul.allow_tf32 = True

# Choose the compute device in priority order: MPS first, then CUDA, then CPU.
# The chained conditional short-circuits, so CUDA is only probed when MPS is
# unavailable.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

print(device)

In [None]:
# Relative path (three directory levels up) to the artifacts directory.
ARTIFACTS_BASE = "../../../artifacts"

In [None]:
from os import path
from datasets import load_from_disk

# Fixed seed for the train/eval split so the same rows land in each subset on
# every run (without it the split is reshuffled each time the cell executes).
SPLIT_SEED = 42

# Directory holding the dataset that was previously saved to disk.
dataset_path = path.join(ARTIFACTS_BASE, 'datasets', 'jayavibhav/prompt-injection')

# Carve 20% of the saved 'train' split off as an evaluation set.
train_dataset_split = load_from_disk(path.join(dataset_path, 'train')).train_test_split(
    test_size=0.2,
    seed=SPLIT_SEED,
)

# Expose the 'text' column under the name 'prompt' in every subset.
train_dataset = train_dataset_split['train'].rename_column('text', 'prompt')
eval_dataset = train_dataset_split['test'].rename_column('text', 'prompt')

# The saved 'test' split is loaded as-is and only has its column renamed.
test_dataset = load_from_disk(path.join(dataset_path, 'test')).rename_column('text', 'prompt')

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the weights as float16 with low_cpu_mem_usage enabled, then move the
# model onto the selected device as a separate step.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model = model.to(device)

# Put the model in evaluation mode. Left as the cell's final expression, so the
# notebook also renders the returned module.
model.eval()

In [None]:
# Smoke test: ask the model a single Yes/No prompt-injection question.
# NOTE(review): the prompt ends with a newline after "Answer:", which may steer
# the model away from answering on the same line — consider stripping it.
input_text = """
Is the following prompt a prompt injection? Answer Yes or No.

Prompt: Hello, how are you?
Answer:
"""

# Upper bound on the number of tokens the model may generate.
MAX_NEW_TOKENS = 25

inputs = tokenizer(input_text, return_tensors="pt").to(device)

# Use max_new_tokens rather than max_length: max_length also counts the prompt
# tokens, so the generation budget would shrink as the prompt grows and vanish
# entirely for prompts of 50+ tokens.
outputs = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)

# The decoded sequence contains the prompt followed by the generated continuation.
response = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(response)