In [1]:
import os
import random

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from tqdm import tqdm
from transformers import AutoTokenizer

from dataloader import get_eval_datasets
from widemlp import MLP, prepare_inputs_optimized

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Process-wide settings. Both environment variables are set before the first
# CUDA query below so that they are already in place when it runs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable tokenizers-library parallelism
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to this process

# Use the (single visible) GPU when CUDA is available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Tunables shared by the cells below.
BATCH_SIZE = 64
SEED = 42
MODEL_NAME = ""
EPOCHS = 1
NUM_CLASSES = 3  # output size of the 3-class model
THRESHOLD = 0.5  # default probability cut-off used by run_inference
DATASET_SIZE = 15_000
TEST_SPLIT = 0.2
NUM_HIDDEN_LAYERS = 3
# Shared tokenizer: the inference functions below call it on the "prompt" text,
# and len(tokenizer) is passed to MLP as vocab_size.
# NOTE(review): the checkpoint name is hard-coded here while MODEL_NAME is left
# empty — confirm whether MODEL_NAME was meant to carry it.
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

In [3]:
def fix_seed(seed: int = 42) -> None:
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Value handed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Seed every RNG once, before the cells below load data or build models.
fix_seed(SEED)

In [7]:
# Evaluation sets from the project-local dataloader module. The evaluation loops
# below iterate this with .items() as (dataset name, DataFrame) and read the
# "prompt" and "label" columns of each frame.
datasets = get_eval_datasets()

# MLP 3cls

In [None]:
# Object passed to MLP as `idf` for the 3-class model (presumably a tensor of
# IDF weights — TODO confirm). map_location lets a file saved from a GPU session
# load on a CPU-only machine too, matching how the checkpoint itself is loaded.
idf = torch.load("widemlp_idf.pt", map_location=DEVICE).to(DEVICE)

In [None]:
# Rebuild the 3-class model and restore its trained weights for evaluation.
# DEVICE is already a torch.device, so it is passed to map_location directly.
checkpoint = torch.load("widemlp.pt", weights_only=True, map_location=DEVICE)
wide_mlp = MLP(
    vocab_size=len(tokenizer),
    num_hidden_layers=NUM_HIDDEN_LAYERS,
    num_classes=NUM_CLASSES,
    idf=idf,
    problem_type="multi_label_classification",
)
wide_mlp.to(DEVICE)
# Re-attach the device-resident idf after the move, but only if the model kept
# one (presumably .to() does not move it when MLP stores it as a plain
# attribute — TODO confirm against widemlp.MLP).
if wide_mlp.idf is not None:
    wide_mlp.idf = idf
wide_mlp.load_state_dict(checkpoint["model_state_dict"])
wide_mlp.eval()  # inference mode
print(f"Successfully loaded PyTorch model on {DEVICE}")

In [None]:
def run_inference(
    model: MLP,
    dataset_df: pd.DataFrame,
    threshold: float = THRESHOLD,
    batch_size: int = BATCH_SIZE,
) -> tuple:
    """Score every prompt and reduce the multi-label output to a binary flag.

    A row is predicted 1 when at least one class probability (sigmoid of the
    model's logit) is strictly greater than ``threshold``, and 0 otherwise.
    Uses the module-level ``tokenizer`` and ``DEVICE``.

    Args:
        model: Model invoked as ``model(flat_inputs, offsets)``.
        dataset_df: Frame with a ``"prompt"`` column (text to classify) and a
            ``"label"`` column (returned unchanged as ground truth).
        threshold: Probability cut-off applied to each class.
        batch_size: Number of prompts tokenized and scored per step.

    Returns:
        ``(pred, true)``: the 0/1 predictions in row order and
        ``dataset_df["label"]`` as a list.
    """
    dataset = Dataset.from_pandas(dataset_df)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False
    )

    pred = []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Inference"):
            # NOTE(review): max_length is normally an int or None; False is kept
            # as-is because changing it could alter truncation — confirm the
            # intended limit.
            inputs = tokenizer(
                batch["prompt"],
                padding=True,
                truncation=True,
                max_length=False,
                return_tensors="pt",
            ).to(DEVICE)
            flat_inputs, offsets = prepare_inputs_optimized(
                inputs["input_ids"], device=DEVICE
            )
            logits = model(flat_inputs, offsets)
            probabilities = torch.sigmoid(logits)
            # One vectorised reduction per batch instead of a Python loop over
            # rows: 1 if any class clears the threshold, 0 if the model is not
            # confident in any class. Assumes logits are (batch, num_classes).
            flagged = (probabilities > threshold).any(dim=1)
            pred.extend(flagged.long().cpu().tolist())
    return pred, dataset_df["label"].tolist()


# Evaluate the 3-class model on every dataset at several probability cut-offs.
# Inference is re-run per threshold because run_inference returns predictions
# that are already thresholded.
results = {}
for dataset_name, df in datasets.items():
    print(f"\nProcessing {dataset_name} dataset...")
    for threshold in [0.5, 0.75, 0.9, 0.99]:
        pred, true = run_inference(
            wide_mlp, df, threshold=threshold, batch_size=BATCH_SIZE
        )
        # zero_division=0: report 0 (without a warning) when a metric's
        # denominator is empty, e.g. no positive predictions at a high threshold.
        accuracy = accuracy_score(true, pred)
        precision = precision_score(true, pred, zero_division=0)
        recall = recall_score(true, pred, zero_division=0)
        f1 = f1_score(true, pred, zero_division=0)
        # labels=[0, 1] pins the matrix to 2x2 even when only one class occurs
        # in both `true` and `pred`; without it the 4-way unpack below raises
        # ValueError on a 1x1 matrix.
        cm = confusion_matrix(true, pred, labels=[0, 1])
        # Row-major order of the 2x2 matrix: TN, FP, FN, TP
        true_negatives, false_positives, false_negatives, true_positives = cm.ravel()
        results[f"{dataset_name}_{threshold}"] = {
            "dataset_name": dataset_name,
            "threshold": threshold,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "true_positives": true_positives,
            "false_positives": false_positives,
            "false_negatives": false_negatives,
            "true_negatives": true_negatives,
        }

# Convert results to DataFrame and save to CSV
# (index=False is safe: dataset_name and threshold are also stored as columns)
results_df = pd.DataFrame.from_dict(results, orient="index")
results_df.to_csv("inference_results.csv", index=False)

# MLP 4cls

In [20]:
# Object passed to MLP as `idf` for the 4-class model (presumably a tensor of
# IDF weights — TODO confirm). map_location lets a file saved from a GPU session
# load on a CPU-only machine too.
# NOTE: this rebinds `idf`, replacing the value loaded for the 3-class model, so
# re-running the 3-class model cell after this one would pick up the wrong idf.
idf = torch.load("widemlp-4cls_idf.pt", map_location=DEVICE).to(DEVICE)

In [21]:
# Load 4-class model
# DEVICE is already a torch.device, so it is passed to map_location directly.
checkpoint_4cls = torch.load(
    "widemlp-4cls.pt", weights_only=True, map_location=DEVICE
)
wide_mlp_4cls = MLP(
    vocab_size=len(tokenizer),
    num_hidden_layers=NUM_HIDDEN_LAYERS,
    num_classes=4,  # Changed to 4 classes
    idf=idf,
    # Same problem_type as the 3-class model.
    # NOTE(review): confirm this matches how the 4-class checkpoint was trained.
    problem_type="multi_label_classification",
)
wide_mlp_4cls.to(DEVICE)
# Re-attach the device-resident idf after the move, but only if the model kept
# one (presumably .to() does not move it when MLP stores it as a plain
# attribute — TODO confirm against widemlp.MLP).
if wide_mlp_4cls.idf is not None:
    wide_mlp_4cls.idf = idf
wide_mlp_4cls.load_state_dict(checkpoint_4cls["model_state_dict"])
wide_mlp_4cls.eval()  # inference mode
print(f"Successfully loaded 4-class PyTorch model on {DEVICE}")

Successfully loaded 4-class PyTorch model on cuda:0


In [22]:
def run_inference_4cls(
    model: MLP,
    dataset_df: pd.DataFrame,
    batch_size: int = BATCH_SIZE,
) -> tuple:
    """Predict one class index per prompt with the 4-class model.

    Each row gets the index of its highest-scoring class (argmax over dim 1 of
    the model output). Uses the module-level ``tokenizer`` and ``DEVICE``.

    Args:
        model: Model invoked as ``model(flat_inputs, offsets)``.
        dataset_df: Frame with a ``"prompt"`` column (text to classify) and a
            ``"label"`` column (returned unchanged as ground truth).
        batch_size: Number of prompts tokenized and scored per step.

    Returns:
        ``(pred, true)``: predicted class indices in row order and
        ``dataset_df["label"]`` as a list.
    """
    dataset = Dataset.from_pandas(dataset_df)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False
    )

    pred = []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Inference"):
            # NOTE(review): max_length is normally an int or None; False is kept
            # as-is because changing it could alter truncation — confirm the
            # intended limit.
            inputs = tokenizer(
                batch["prompt"],
                padding=True,
                truncation=True,
                max_length=False,
                return_tensors="pt",
            ).to(DEVICE)
            flat_inputs, offsets = prepare_inputs_optimized(
                inputs["input_ids"], device=DEVICE
            )
            logits = model(flat_inputs, offsets)
            # Take argmax on the raw logits. Sigmoid is monotonic, so the
            # ranking is the same, but in float32 several large logits all
            # saturate to exactly 1.0 and argmax would then return the first
            # tied index instead of the largest logit.
            predictions = torch.argmax(logits, dim=1)
            pred.extend(predictions.cpu().tolist())

    return pred, dataset_df["label"].tolist()

# Run evaluation for 4-class model
results_4cls = {}
for dataset_name, df in datasets.items():
    print(f"\nProcessing {dataset_name} dataset for 4-class model...")
    pred, true = run_inference_4cls(wide_mlp_4cls, df, batch_size=BATCH_SIZE)

    # Collapse to binary: class index 3 -> 0, class indices 0/1/2 -> 1.
    # NOTE(review): presumably index 3 is the benign class and label 1 means
    # harmful, mirroring the 3-class evaluation where "no class above the
    # threshold" maps to 0 — confirm against the training label map.
    binary_pred = [0 if p == 3 else 1 for p in pred]

    # zero_division=0: report 0 (without a warning) when a denominator is empty.
    accuracy = accuracy_score(true, binary_pred)
    precision = precision_score(true, binary_pred, zero_division=0)
    recall = recall_score(true, binary_pred, zero_division=0)
    f1 = f1_score(true, binary_pred, zero_division=0)

    # labels=[0, 1] pins the matrix to 2x2 even when only one class occurs, so
    # the single populated cell lands in the correct slot (TN for class 0, TP
    # for class 1) instead of always being counted as a true negative.
    cm = confusion_matrix(true, binary_pred, labels=[0, 1])
    # Row-major order of the 2x2 matrix: TN, FP, FN, TP
    true_negatives, false_positives, false_negatives, true_positives = cm.ravel()

    results_4cls[dataset_name] = {
        "dataset_name": dataset_name,
        "model_type": "4cls",
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "true_positives": true_positives,
        "false_positives": false_positives,
        "false_negatives": false_negatives,
        "true_negatives": true_negatives,
    }

# Convert results to DataFrame and save to CSV
# (index=False is safe: dataset_name is also stored as a column)
results_df_4cls = pd.DataFrame.from_dict(results_4cls, orient="index")
results_df_4cls.to_csv("inference_results_4cls.csv", index=False)


Processing jigsaw dataset for 4-class model...


  flat_inputs = torch.cat([torch.tensor(doc) for doc in input_ids])
Inference: 100%|██████████| 51/51 [00:02<00:00, 19.64it/s]



Processing olid dataset for 4-class model...


  flat_inputs = torch.cat([torch.tensor(doc) for doc in input_ids])
Inference: 100%|██████████| 206/206 [00:02<00:00, 98.48it/s] 



Processing hate_xplain dataset for 4-class model...


  flat_inputs = torch.cat([torch.tensor(doc) for doc in input_ids])
Inference: 100%|██████████| 93/93 [00:01<00:00, 75.43it/s]



Processing tuke_sk dataset for 4-class model...


  flat_inputs = torch.cat([torch.tensor(doc) for doc in input_ids])
Inference: 100%|██████████| 135/135 [00:02<00:00, 50.77it/s]



Processing dkk dataset for 4-class model...


  flat_inputs = torch.cat([torch.tensor(doc) for doc in input_ids])
Inference: 100%|██████████| 1/1 [00:00<00:00, 18.40it/s]



Processing dkk_all dataset for 4-class model...


  flat_inputs = torch.cat([torch.tensor(doc) for doc in input_ids])
Inference: 100%|██████████| 6/6 [00:00<00:00, 46.24it/s]
