# In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
import torch
from transformers import AutoTokenizer, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# The CSV's first line is a header row. header=0 lets pandas consume that row
# and replace it with ``names``, so ClassId is parsed as a number directly
# instead of being mixed with the header string (which previously required
# dropping row 0 by hand and re-indexing).
df = pd.read_csv(
    "/kaggle/input/ag-news-classification-dataset/train.csv",
    header=0,
    names=["ClassId", "Title", "Description"],
)
df["ClassId"] = df["ClassId"].astype(int)
# fillna("") keeps a missing title/description from turning the concatenated
# text into a float NaN, which the tokenizer would reject.
df["Text"] = df["Title"].fillna("") + " " + df["Description"].fillna("")


def get_cls_embedding(texts, tokenizer, model, device, batch_size=32, max_length=512):
    """Return the last-layer hidden state at token position 0 for each text.

    Texts are tokenized and run through ``model`` in mini-batches, each padded
    to its longest member, instead of one forward pass per text. This assumes
    the model honours the tokenizer's attention mask (as BERT-style models do),
    so padding should not change an embedding beyond floating-point noise.
    Position 0 is the [CLS] token for BERT-style tokenizers.

    Args:
        texts: Iterable of input strings.
        tokenizer: Callable tokenizer accepting a list of strings and the
            ``return_tensors``/``truncation``/``padding``/``max_length`` keywords.
        model: Model whose output exposes ``last_hidden_state``.
        device: Device the tokenized inputs are moved to.
        batch_size: Number of texts per forward pass (must be >= 1).
        max_length: Truncation length passed to the tokenizer.

    Returns:
        NumPy array of shape (len(texts), hidden_size), or an empty 1-D array
        when ``texts`` is empty.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    texts = list(texts)
    if not texts:
        # Same result the one-at-a-time implementation produced for no input.
        return np.array([])
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    chunks = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            truncation=True,
            padding=True,  # pad to the longest text in this batch only
            max_length=max_length,
        ).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Hidden state of position 0 for every sequence -> (batch, hidden).
        chunks.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())
    return np.concatenate(chunks, axis=0)


def classify(texts, label_texts, model_name, device):
    """Zero-shot classify texts by their most similar label embedding.

    Loads ``model_name`` with ``AutoTokenizer``/``AutoModel``, embeds every
    label description and every text with ``get_cls_embedding``, and assigns
    each text the label with the highest cosine similarity.

    Args:
        texts: Sequence of input strings.
        label_texts: Dict whose values are the label descriptions to embed.
            The predicted id is the 1-based position of the best match in the
            dict's insertion order, so the keys are assumed to be 1..N in that
            order; the keys themselves are not read.
        model_name: Model identifier passed to ``from_pretrained``.
        device: Device the model is moved to.

    Returns:
        List of predicted class ids (ints), one per text; ``[]`` for no texts.

    Side effects:
        Prints the model name and loads the tokenizer/model weights.
    """
    print(f"\nUsing model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()

    label_embeddings = get_cls_embedding(list(label_texts.values()), tokenizer, model, device)
    text_embeddings = get_cls_embedding(texts, tokenizer, model, device)
    if len(text_embeddings) == 0:
        return []

    # One (n_texts, n_labels) similarity matrix replaces a per-text loop.
    sims = cosine_similarity(text_embeddings, label_embeddings)
    # +1 converts the 0-based label position into a 1-based class id.
    return (np.argmax(sims, axis=1) + 1).tolist()

# Label descriptions for zero-shot classification, keyed by AG News class index.
# classify() turns the best-matching position into an id as ``argmax + 1``, so
# the keys must be 1..N in insertion order. This must be defined before the
# classify() calls below, otherwise they fail with NameError.
label_texts = {
    1: "World",
    2: "Sports",
    3: "Business",
    4: "Sci/Tech",
}

subset_size = 5000
# NOTE(review): evaluates the first `subset_size` rows rather than a random
# sample; the class balance of this prefix is not checked here.
subset = df.iloc[:subset_size]
texts = subset["Text"].tolist()
true_labels = subset["ClassId"].tolist()

pred_base = classify(texts, label_texts, "bert-base-uncased", device)
pred_simcse = classify(texts, label_texts, "princeton-nlp/sup-simcse-bert-base-uncased", device)

true_arr = np.array(true_labels)
acc_base = np.mean(np.array(pred_base) == true_arr)
acc_simcse = np.mean(np.array(pred_simcse) == true_arr)

print(f"\nBaseline BERT Accuracy: {acc_base:.4f}")
print(f"SimCSE Accuracy:         {acc_simcse:.4f}")