In [None]:
import os
import math
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report, f1_score
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ----- Configuration -----
# The evaluation CSV location can be overridden with the COMMENTS_CSV_PATH
# environment variable so the notebook is not tied to a single machine; the
# original local path is kept as the default so existing runs behave the same.
CSV_PATH = os.environ.get(
    "COMMENTS_CSV_PATH",
    "/Users/sohammandal/Developer/mlops-comment-moderation/assets/comments_test.csv",  # v1 original
)
MODEL_NAME = "unitary/toxic-bert"
MAX_LEN = 256      # passed to the tokenizer as max_length (with truncation)
BATCH_SIZE = 64    # number of texts sent through the model per forward pass
THRESHOLD = 0.5    # probability cut-off used when turning scores into labels
SEED = 42

# Seed both NumPy's legacy global RNG and PyTorch's RNG.
np.random.seed(SEED)
torch.manual_seed(SEED)

# ----- Device pick: CUDA -> MPS (Apple GPU) -> CPU -----
if torch.cuda.is_available():
    device = torch.device("cuda")
# Older torch builds have no `torch.backends.mps`, hence the hasattr guard.
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

device

device(type='mps')

In [4]:
# Load data
df = pd.read_csv(CSV_PATH)

# Fail fast with a clear message if the CSV does not have the expected schema.
needed = {"id", "comment_text", "moderation_label"}
missing = needed - set(df.columns)
if missing:
    raise ValueError(f"CSV missing columns: {missing}")

# read_csv turns blank cells into NaN, and a plain astype(str) would convert
# those into the literal text "nan", which the model would then score as if it
# were a real comment. Map missing comments to empty strings instead.
texts = ["" if pd.isna(text) else str(text) for text in df["comment_text"]]

# astype(int) raises if a label is missing or non-numeric, which is the desired
# behaviour: an unlabeled row cannot be evaluated.
y_true = df["moderation_label"].astype(int).to_numpy()
len(texts), y_true.shape


(63978, (63978,))

In [5]:
# Tokenizer and classifier are both loaded from the MODEL_NAME checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# Switch to evaluation mode and move the weights to the selected device;
# both calls return the module itself, so they can be chained.
model = model.eval().to(device)

# Label maps come from the checkpoint's config; num_labels sizes the output matrix.
id2label, label2id = model.config.id2label, model.config.label2id
num_labels = len(id2label)
id2label


{0: 'toxic',
 1: 'severe_toxic',
 2: 'obscene',
 3: 'threat',
 4: 'insult',
 5: 'identity_hate'}

In [6]:
# Batched inference
def batched(iterable, n):
    """Yield ``(offset, chunk)`` pairs covering ``iterable`` in order.

    ``iterable`` must support ``len()`` and slicing. Each chunk is the slice
    ``iterable[offset:offset + n]``; the final chunk may be shorter than ``n``.
    ``offset`` is the index in ``iterable`` at which the chunk starts.
    """
    offsets = range(0, len(iterable), n)
    for offset in offsets:
        stop = offset + n
        yield offset, iterable[offset:stop]

# Preallocate the (n_texts, num_labels) score matrix; each batch fills its rows.
all_probs = np.zeros((len(texts), num_labels), dtype=np.float32)

# Give tqdm the batch count up front so batches can be streamed lazily instead
# of materialising every batch in a list just to get a progress total.
n_batches = math.ceil(len(texts) / BATCH_SIZE)

with torch.no_grad():  # inference only, so no autograd graph is needed
    for start_idx, batch_texts in tqdm(batched(texts, BATCH_SIZE), total=n_batches):
        # Pad to the longest text in the batch; truncate anything over MAX_LEN tokens.
        enc = tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LEN,
        )
        enc = {k: v.to(device) for k, v in enc.items()}

        logits = model(**enc).logits
        # Element-wise sigmoid: each label gets its own probability.
        # No .detach() is needed under no_grad, and the assignment below casts
        # into the float32 destination array.
        probs = torch.sigmoid(logits).cpu().numpy()

        end_idx = start_idx + len(batch_texts)
        all_probs[start_idx:end_idx, :] = probs


  return forward_call(*args, **kwargs)
100%|██████████| 1000/1000 [16:55<00:00,  1.02s/it]


In [7]:
# A row is predicted positive when its largest per-label probability exceeds
# THRESHOLD, i.e. when ANY label prob > THRESHOLD.
y_hat = np.greater(np.max(all_probs, axis=1), THRESHOLD).astype(int)

# Metrics: per-class report first, then the two aggregate F1 scores.
print(classification_report(y_true, y_hat, digits=4))
f1_macro, f1_weighted = (
    f1_score(y_true, y_hat, average=avg) for avg in ("macro", "weighted")
)
print(f"\nF1 weighted: {f1_weighted:.4f} | F1 macro: {f1_macro:.4f}")


              precision    recall  f1-score   support

           0     0.9891    0.9199    0.9532     57735
           1     0.5502    0.9063    0.6847      6243

    accuracy                         0.9186     63978
   macro avg     0.7697    0.9131    0.8190     63978
weighted avg     0.9463    0.9186    0.9270     63978


F1 weighted: 0.9270 | F1 macro: 0.8190
