# In [None]:
import pandas as pd
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import numpy as np
from sklearn.preprocessing import LabelEncoder

# Load processed data
train = pd.read_parquet("data/processed/classification_train.parquet")
val = pd.read_parquet("data/processed/classification_val.parquet")
test = pd.read_parquet("data/processed/classification_test.parquet")

# Encode labels. The encoder is fitted on the training split only, so
# transform() raises ValueError if val/test contain a label absent from train.
le = LabelEncoder()
train["label_id"] = le.fit_transform(train["label"])
val["label_id"] = le.transform(val["label"])
test["label_id"] = le.transform(test["label"])


def _to_hf_dataset(df):
    """Convert one split DataFrame into a ``Dataset`` with ``text``/``label_id`` columns.

    Args:
        df: A split DataFrame holding ``text_trunc`` and ``label_id`` columns.

    Returns:
        A ``datasets.Dataset`` with the columns ``text`` and ``label_id``.
    """
    subset = df[["text_trunc", "label_id"]].rename(columns={"text_trunc": "text"})
    # preserve_index=False: otherwise a non-RangeIndex (e.g. one left over from
    # a shuffled split) is stored as an extra "__index_level_0__" column.
    return Dataset.from_pandas(subset, preserve_index=False)


ds = DatasetDict({
    "train": _to_hf_dataset(train),
    "validation": _to_hf_dataset(val),
    "test": _to_hf_dataset(test),
})

# Tokenizer / model
model_name = "bert-base-uncased"
tok = AutoTokenizer.from_pretrained(model_name)


def tokenize(batch):
    """Tokenize a batch of examples for the classifier.

    Texts are truncated and padded to a fixed length of 256 tokens, so every
    encoded example has the same length.

    Args:
        batch: A batched mapping as passed by ``Dataset.map(..., batched=True)``;
            only its ``"text"`` list is read.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``, ...), which
        ``Dataset.map`` adds to the dataset as new columns.
    """
    return tok(batch["text"], truncation=True, padding="max_length", max_length=256)


ds_tok = ds.map(tokenize, batched=True)
# The model's forward pass takes the targets under the name "labels".
ds_tok = ds_tok.rename_column("label_id", "labels")
ds_tok.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

num_labels = len(le.classes_)
# Store the class-name mapping in the model config so saved checkpoints and
# predictions report the real class names instead of generic LABEL_i ids.
# LabelEncoder assigns id i to le.classes_[i], so this mapping matches label_id.
id2label = {i: str(name) for i, name in enumerate(le.classes_)}
label2id = {name: i for i, name in id2label.items()}
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

# Metrics including top-k
import evaluate
# NOTE(review): `acc` is not referenced by compute_metrics below, which computes
# accuracy directly with NumPy. Consider dropping this load (and the import) if
# no other cell uses it.
acc = evaluate.load("accuracy")

def compute_metrics(eval_pred, ks=(3, 5)):
    """Compute top-1 accuracy plus top-k accuracies for a classifier.

    Args:
        eval_pred: A ``(predictions, label_ids)`` pair, e.g. the ``EvalPrediction``
            the Trainer passes. ``predictions`` is a ``(n_examples, n_classes)``
            array of scores, or a tuple whose first element is that array
            (the case when the model returns extra outputs).
        ks: The k values to report as ``"top{k}"`` metrics. Each must be >= 1.
            If k is at least the number of classes, its score is 1.0.

    Returns:
        A dict with ``"accuracy"`` (top-1) and one ``"top{k}"`` entry per k.

    Raises:
        ValueError: If any k in ``ks`` is smaller than 1.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]
    logits = np.asarray(logits)
    labels = np.asarray(labels)

    preds = np.argmax(logits, axis=1)
    metrics = {"accuracy": float((preds == labels).mean())}

    if ks:
        # Sort once for all k: the last k columns of each row are the indices
        # of that row's k highest scores.
        ranked = np.argsort(logits, axis=1)
        for k in ks:
            if k < 1:
                raise ValueError(f"top-k requires k >= 1, got {k}")
            hits = (ranked[:, -k:] == labels[:, None]).any(axis=1)
            metrics[f"top{k}"] = float(hits.mean())
    return metrics

import inspect

# Newer transformers releases renamed TrainingArguments(evaluation_strategy=)
# to eval_strategy= and Trainer(tokenizer=) to processing_class=, deprecating
# (and eventually removing) the old names. Pick whichever keyword the installed
# version accepts so this cell runs on both old and new releases.
_args_params = inspect.signature(TrainingArguments).parameters
_eval_key = "eval_strategy" if "eval_strategy" in _args_params else "evaluation_strategy"
_trainer_params = inspect.signature(Trainer.__init__).parameters
_tok_key = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

args = TrainingArguments(
    output_dir="artifacts/bert_resume_cls",
    save_strategy="epoch",  # must match the eval strategy for load_best_model_at_end
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=50,
    metric_for_best_model="accuracy",
    load_best_model_at_end=True,
    **{_eval_key: "epoch"},
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds_tok["train"],
    eval_dataset=ds_tok["validation"],
    compute_metrics=compute_metrics,
    **{_tok_key: tok},
)

trainer.train()
# Use the "test" prefix so held-out results are not logged as "eval_*"
# alongside the per-epoch validation metrics.
test_metrics = trainer.evaluate(ds_tok["test"], metric_key_prefix="test")
test_metrics
