# Before You Start
Make sure you are running in a GPU-enabled environment: training also works on CPU, but it is much slower.

In [None]:
import torch

# Pick the compute device once: the first CUDA GPU when one is visible,
# otherwise fall back to the CPU (and warn, since training will crawl).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Warning: Using CPU - training will be slow")

In [None]:
from datasets import Dataset
from sklearn.model_selection import train_test_split  # NOTE(review): unused; the split below uses datasets' own method

# Toy stance-classification data. Label ids: 0 = Supportive, 1 = Neutral,
# 2 = Opposed (the same order predict() uses to name its output).
examples = {
    "text": [
        "The policy is needed to combat climate change.",
        "This environmental initiative is crucial",
        "I'm undecided about this policy",
        "Need more information to decide",
        "This infringes on personal rights",
        "Government overreach must be stopped",
    ],
    "label": [0, 0, 1, 1, 2, 2],
}

# A fixed seed makes the split (and therefore the reported accuracy)
# reproducible from run to run. With only six rows the test split holds just
# a couple of examples and is not stratified, so evaluation accuracy is coarse.
dataset = Dataset.from_dict(examples).train_test_split(test_size=0.2, seed=42)

In [None]:
from transformers import TrainingArguments

args = TrainingArguments(
    # Explicit checkpoint directory; older transformers releases require it.
    output_dir="./policy-stance-checkpoints",
    per_device_train_batch_size=16,
    # Evaluate once per epoch. The dataset above has only a handful of training
    # rows, so with batch size 16 a whole training run is just a few steps and a
    # step-based schedule (eval_steps=50) would never reach its first evaluation,
    # leaving best-model selection and early stopping inert.
    # `eval_strategy` replaces the deprecated/removed `evaluation_strategy`.
    eval_strategy="epoch",
    # load_best_model_at_end requires checkpoints on the same schedule as evals.
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

In [None]:
import numpy as np


def compute_metrics(eval_pred):
    """Compute classification accuracy for ``Trainer`` evaluation.

    Args:
        eval_pred: A ``(predictions, labels)`` pair, e.g. a
            ``transformers.EvalPrediction``. ``predictions`` holds per-class
            scores with the class axis last. It may also be a tuple whose first
            element is the logits (models that return extra outputs).

    Returns:
        ``{"accuracy": float}``: the fraction of rows whose arg-max class
        equals the label.
    """
    predictions, labels = eval_pred
    # Some models return (logits, hidden_states, ...); the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_classes = np.argmax(predictions, axis=-1)
    # Plain float keeps the logged metric JSON-friendly.
    return {"accuracy": float(np.mean(predicted_classes == np.asarray(labels)))}

In [None]:
def predict(text):
    """Classify the stance of a single piece of text.

    Uses the notebook-level ``model``, ``tokenizer`` and ``device``.

    Args:
        text: One input string.

    Returns:
        ``{"prediction": label_name, "confidence": float}``, where the
        confidence is the softmax probability of the predicted class.
    """
    label_names = ["Supportive", "Neutral", "Opposed"]  # index == label id
    # Disable dropout etc.; after training the model may still be in train mode,
    # which would make repeated predictions for the same text disagree.
    model.eval()
    inputs = tokenizer(text, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Single input -> batch of one; pick that row explicitly.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    confidence, class_index = torch.max(probs, dim=-1)
    return {
        "prediction": label_names[class_index.item()],
        "confidence": confidence.item(),
    }

In [None]:
from peft import PeftConfig, PeftModel
from transformers import AutoModelForSequenceClassification

ADAPTER_DIR = "./policy-stance-lora"

# NOTE(review): this cell appears above the Trainer cell; run it after
# trainer.train(), otherwise the saved adapter holds untrained weights.
model.save_pretrained(ADAPTER_DIR)

# Reload onto a *fresh* base model. `model` is presumably already the
# PEFT-wrapped model (it is what gets saved as the LoRA adapter), so passing it
# to PeftModel.from_pretrained would stack adapters on top of adapters.
adapter_config = PeftConfig.from_pretrained(ADAPTER_DIR)
base_model = AutoModelForSequenceClassification.from_pretrained(
    adapter_config.base_model_name_or_path,
    num_labels=3,  # Supportive / Neutral / Opposed, matching label ids 0-2
)
# NOTE(review): the classifier head is only restored if it was saved with the
# adapter (e.g. LoRA task_type=SEQ_CLS / modules_to_save) — confirm.
loaded_model = PeftModel.from_pretrained(base_model, ADAPTER_DIR).to(device)

In [None]:
from transformers import DataCollatorWithPadding, EarlyStoppingCallback, Trainer


def _tokenize(batch):
    """Turn the raw ``text`` column into model inputs (input_ids, attention_mask)."""
    return tokenizer(batch["text"], truncation=True)


# The split holds raw strings; the model needs token ids, so tokenize first.
tokenized_dataset = dataset.map(_tokenize, batched=True)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    # Pads each batch to its longest sequence and renames "label" -> "labels".
    data_collator=DataCollatorWithPadding(tokenizer),
    compute_metrics=compute_metrics,
    # Early stopping relies on per-eval "accuracy" and load_best_model_at_end
    # being enabled in `args`.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)