# In [12]:
import torch
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer

model_name = "checkpoints/bert-base-uncased/checkpoint-13389"

# Binary sentiment label mapping. It is passed into the model config below so that
# model.config.id2label / label2id agree with the names used in predict_sentiment.
id2label = {0: "negative", 1: "positive"}
label2id = {"negative": 0, "positive": 1}

# `device` may already have been defined in an earlier notebook cell; only fall
# back to "GPU if available, else CPU" when it is missing, so this cell also
# runs on its own instead of failing with NameError.
if "device" not in globals():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned checkpoint once, with a 2-way classification head.
# (A separate bare AutoModel load is unnecessary: it would only be discarded.)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)
model.to(device)
model.eval()  # inference mode for layers such as dropout

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
# Inputs longer than this many tokens are truncated by predict_sentiment.
MAX_LENGTH = 128

def predict_sentiment(texts, batch_size=None):
    """Classify one or more texts with the sentiment model.

    Relies on the module-level ``tokenizer``, ``model``, ``device``,
    ``id2label`` and ``MAX_LENGTH`` defined earlier in this notebook.

    Args:
        texts: A single string, or an iterable of strings.
        batch_size: Optional maximum number of texts per forward pass.
            ``None`` (the default) runs all texts in a single batch; set it to
            bound memory use on large inputs.

    Returns:
        A list with one dict per input text, in input order, with keys
        ``"text"``, ``"pred_label_id"``, ``"pred_label"`` and
        ``"probabilities"`` (label name -> softmax probability).
        An empty input yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is given and is smaller than 1.
    """
    if isinstance(texts, str):
        texts = [texts]
    # Materialize so tuples / generators work and can be sliced and zipped.
    texts = list(texts)
    if not texts:
        # Avoid running the tokenizer/model on an empty batch.
        return []
    if batch_size is None:
        batch_size = len(texts)
    elif batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    results = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        enc = tokenizer(
            batch,
            truncation=True,
            padding=True,
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            logits = model(**enc).logits
        # One row of class probabilities per text in the batch.
        probs_np = torch.softmax(logits, dim=-1).cpu().numpy()
        for text, row in zip(batch, probs_np):
            idx = int(row.argmax())
            results.append({
                "text": text,
                "pred_label_id": idx,
                "pred_label": id2label[idx],
                "probabilities": {
                    id2label[j]: float(p) for j, p in enumerate(row)
                },
            })
    return results

# Hand-written sentences used as a quick sanity check of the classifier.
examples = [
    "I love this product!",
    "This is the worst experience I've ever had.",
    "It's okay, not great but not bad.",
    "Absolutely fantastic service.",
    "I wouldn't recommend this to anyone.",
    "fuck you",
    "love you",
]

# One line per example: "<label> -> <probability of that label> | <text>".
for result in predict_sentiment(examples):
    label = result["pred_label"]
    confidence = round(result["probabilities"][label], 3)
    print(f"{label} -> {confidence} | {result['text']}")


# Output:
# positive -> 0.998 | I love this product!
# negative -> 0.999 | This is the worst experience I've ever had.
# negative -> 0.847 | It's okay, not great but not bad.
# positive -> 0.999 | Absolutely fantastic service.
# negative -> 0.993 | I wouldn't recommend this to anyone.
# negative -> 0.905 | fuck you
# positive -> 0.997 | love you
