In [1]:
%load_ext autoreload

%autoreload 2

In [2]:
import os
import random

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, concatenate_datasets, load_dataset
from sklearn.metrics import f1_score, precision_score, recall_score
from tqdm import tqdm
from transformers import AutoTokenizer

try:
    import wandb

    USE_WANDB = True
except ImportError:
    USE_WANDB = False
    print("wandb not installed, skipping wandb logging")
from widemlp import MLP, inverse_document_frequency, prepare_inputs

  from .autonotebook import tqdm as notebook_tqdm


wandb not installed, skipping wandb logging


In [3]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

DEVICE = (
    torch.device("cuda:0")
    if torch.cuda.is_available()
    else torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)
BATCH_SIZE = 64
SEED = 42
MODEL_NAME = ""
EPOCHS = 1
NUM_CLASSES = 3
GRADIENT_ACCUMULATION_STEPS = 1
TRAIN_LOGGING_STEPS = 1
EVAL_LOGGING_STEPS = 250
THRESHOLD = 0.5
DATASET_SIZE = 15_000
LOG_WANDB = False
PATH = "widemlp.pt"
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

In [4]:
# find cuda devices
if torch.cuda.is_available():
    print("CUDA devices:")
    for i in range(torch.cuda.device_count()):
        print(f"  {i}: {torch.cuda.get_device_name(i)}")

CUDA devices:
  0: NVIDIA GeForce RTX 4060


In [5]:
if LOG_WANDB:
    wandb.init(project="ood-widemlp")
    wandb.config.update(
        {
            "seed": SEED,
            "batch_size": BATCH_SIZE,
            "epochs": EPOCHS,
            "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
            "train_logging_steps": TRAIN_LOGGING_STEPS,
            "eval_logging_steps": EVAL_LOGGING_STEPS,
            "threshold": THRESHOLD,
        }
    )

In [6]:
def fix_seed(seed: int = 42) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def load_data(path: str, eval_size: int = 0.2) -> Dataset:
    df = pd.read_json(path) if path.endswith(".json") else pd.read_csv(path)
    dataset = Dataset.from_pandas(df)
    dataset.shuffle(seed=42)
    split_dataset = dataset.train_test_split(test_size=eval_size)
    return split_dataset

In [7]:
fix_seed(SEED)

# Dataloaders

In [8]:
law_dataset = load_dataset("dim/law_stackexchange_prompts")
finance_dataset = load_dataset("4DR1455/finance_questions")
healthcare_dataset = load_dataset("iecjsu/lavita-ChatDoctor-HealthCareMagic-100k")

keep = ["text", "domain", "label"]

# Filter and prepare law dataset
law_data = (
    law_dataset["train"]
    .filter(lambda x: x["prompt"] is not None and x["prompt"].strip() != "")
    .filter(lambda x: all(v is not None for v in x.values()))
    .select(range(min(DATASET_SIZE, len(law_dataset["train"]))))
    .map(
        lambda x: {"text": x["prompt"], "domain": "law", "label": 0},
        remove_columns=[c for c in law_dataset["train"].column_names if c not in keep],
    )
)

# Filter and prepare finance dataset
finance_data = (
    finance_dataset["train"]
    .filter(
        lambda x: x["instruction"] is not None
        and len(str(x["instruction"]).strip()) > 0
    )
    .filter(lambda x: all(v is not None for v in x.values()))
    .select(range(min(DATASET_SIZE, len(finance_dataset["train"]))))
    .map(
        lambda x: {"text": str(x["instruction"]), "domain": "finance", "label": 1},
        remove_columns=[
            c for c in finance_dataset["train"].column_names if c not in keep
        ],
    )
)

# Filter and prepare healthcare dataset
healthcare_data = (
    healthcare_dataset["train"]
    .filter(lambda x: x["input"] is not None and len(str(x["input"]).strip()) > 0)
    .filter(lambda x: all(v is not None for v in x.values()))
    .select(range(min(DATASET_SIZE, len(healthcare_dataset["train"]))))
    .map(
        lambda x: {"text": str(x["input"]), "domain": "healthcare", "label": 2},
        remove_columns=[
            c for c in healthcare_dataset["train"].column_names if c not in keep
        ],
    )
)


# Concatenate datasets
combined_dataset = concatenate_datasets([law_data, finance_data, healthcare_data])

# Split into train and test sets using dataset's train_test_split method
data = combined_dataset.train_test_split(test_size=0.2, seed=SEED)

train_loader = torch.utils.data.DataLoader(
    data["train"], batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True
)
valid_loader = torch.utils.data.DataLoader(
    data["test"], batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True
)

# Training

In [9]:
docs = [
    tokenizer.encode(raw_doc, padding=False, truncation=True, max_length=None)
    for raw_doc in data["train"]["text"]
]
idf = inverse_document_frequency(docs, len(tokenizer))
model = MLP(
    vocab_size=len(tokenizer),
    num_hidden_layers=3,
    num_classes=3,
    idf=idf,
    problem_type="multi_label_classification",
)
model.to(DEVICE)
model.idf = model.idf.to(DEVICE) if model.idf is not None else None
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0)

Computing IDF: 36000it [01:41, 353.04it/s]



In [None]:
def evaluate(
    model: torch.nn.Module,
    valid_loader: torch.utils.data.DataLoader,
    tokenizer: AutoTokenizer,
    threshold: float,
) -> dict:
    model.eval()
    all_predictions = []
    all_labels = []
    validation_losses = []
    max_probs = []
    with torch.no_grad():
        for batch in valid_loader:
            inputs = tokenizer(
                batch["text"],
                padding=True,
                truncation=True,
                max_length=False,
                return_tensors="pt",
            ).to(DEVICE)
            flat_inputs, offsets = prepare_inputs(inputs["input_ids"], device=DEVICE)
            labels = batch["label"].to(DEVICE, dtype=torch.long)
            one_hot_targets = torch.nn.functional.one_hot(
                labels, num_classes=NUM_CLASSES
            ).to(DEVICE, dtype=torch.float32)

            loss, logits = model(flat_inputs, offsets, one_hot_targets)
            validation_losses.append(loss.item())
            probabilities = torch.sigmoid(logits)  # -> 0-1
            batch_predictions = []
            for i in range(probabilities.size(0)):
                sample_probabilities = probabilities[i]
                max_probs.append(sample_probabilities.max().item())
                thresholded_labels_indices = torch.where(
                    sample_probabilities > threshold
                )[0]
                if len(thresholded_labels_indices) > 1:
                    best_label_index = thresholded_labels_indices[
                        torch.argmax(sample_probabilities[thresholded_labels_indices])
                    ]
                    prediction_vector = torch.zeros(NUM_CLASSES, dtype=torch.int)
                    prediction_vector[best_label_index] = 1
                    batch_predictions.append(prediction_vector.cpu().numpy())
                else:
                    predictions = (sample_probabilities > threshold).int()
                    batch_predictions.append(predictions.cpu().numpy())
            all_predictions.extend(batch_predictions)
            all_labels.extend(one_hot_targets.cpu().numpy())

    avg_validation_loss = np.mean(validation_losses)

    # Convert to numpy arrays for metric calculation
    all_predictions_np = np.array(all_predictions)
    all_labels_np = np.array(all_labels)

    # Calculate metrics
    macro_f1 = f1_score(
        all_labels_np, all_predictions_np, average="macro", zero_division=0
    )  # zero_division=0 to handle cases with no predicted labels
    micro_f1 = f1_score(
        all_labels_np, all_predictions_np, average="micro", zero_division=0
    )
    macro_precision = precision_score(
        all_labels_np, all_predictions_np, average="macro", zero_division=0
    )
    macro_recall = recall_score(
        all_labels_np, all_predictions_np, average="macro", zero_division=0
    )
    micro_precision = precision_score(
        all_labels_np, all_predictions_np, average="micro", zero_division=0
    )
    micro_recall = recall_score(
        all_labels_np, all_predictions_np, average="micro", zero_division=0
    )

    scores = {
        "validation_loss": avg_validation_loss,
        "macro_f1": macro_f1,
        "micro_f1": micro_f1,
        "macro_precision": macro_precision,
        "macro_recall": macro_recall,
        "micro_precision": micro_precision,
        "micro_recall": micro_recall,
        "max_probs": np.mean(max_probs),
    }
    return scores

In [None]:
optimizer.zero_grad()
for _epoch in range(EPOCHS):
    step = 0
    batch_train_loss = []
    for batch in tqdm(train_loader, desc="Training"):
        model.train()
        inputs = tokenizer(
            batch["text"],
            padding=True,
            truncation=True,
            max_length=False,
        )
        flat_inputs, offsets = prepare_inputs(inputs["input_ids"], device=DEVICE)
        labels = batch["label"].to(DEVICE, dtype=torch.long)
        one_hot_targets = torch.nn.functional.one_hot(
            labels, num_classes=NUM_CLASSES
        ).to(DEVICE, dtype=torch.float32)
        # inputs["label"] = one_hot_targets
        loss, logits = model(flat_inputs, offsets, one_hot_targets)
        loss.backward()
        step += 1  # noqa: SIM113
        batch_train_loss.append(loss.item())
        if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()
        if step % TRAIN_LOGGING_STEPS == 0:
            accuracy = np.mean(batch_train_loss)
            if LOG_WANDB:
                wandb.log({"train_loss": float(accuracy)})
            else:
                print(f"Step {step}, Train Loss: {accuracy:.4f}")
        if step % EVAL_LOGGING_STEPS == 0:
            scores = evaluate(model, valid_loader, tokenizer, threshold=THRESHOLD)
            if LOG_WANDB:
                wandb.log(
                    {
                        "validation_loss": float(scores["validation_loss"]),
                        "test/macro_f1": float(scores["macro_f1"]),
                        "test/micro_f1": float(scores["micro_f1"]),
                        "test/macro_recall": float(scores["macro_recall"]),
                        "test/micro_recall": float(scores["micro_recall"]),
                        "test/macro_precision": float(scores["macro_precision"]),
                        "test/micro_precision": float(scores["micro_precision"]),
                        "test/max_probs": float(scores["max_probs"]),
                    }
                )
            else:
                print(
                    f"Step {step}, Validation Loss: {scores['validation_loss']:.4f}, Macro F1: {scores['macro_f1']:.4f}, Micro F1: {scores['micro_f1']:.4f}"
                )
if LOG_WANDB:
    wandb.finish()

In [None]:
# save model
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    "widemlp.pt",
)

# Inference

In [10]:
# load model
checkpoint = torch.load("widemlp.pt", weights_only=True)
wide_mlp = MLP(
    vocab_size=len(tokenizer),
    num_hidden_layers=3,
    num_classes=3,
    idf=idf,
    problem_type="multi_label_classification",
)
wide_mlp.to(DEVICE)
wide_mlp.idf = model.idf.to(DEVICE) if wide_mlp.idf is not None else None
wide_mlp.load_state_dict(checkpoint["model_state_dict"])
wide_mlp.eval()
print(f"Successfully loaded PyTorch model on {DEVICE}")

Successfully loaded PyTorch model on cuda:0


In [11]:
# Load Jigsaw dataset
jigsaw_splits = {
    "train": "train_dataset.csv",
    "validation": "val_dataset.csv",
    "test": "test_dataset.csv",
}
jigsaw_df = pd.read_csv(
    "hf://datasets/Arsive/toxicity_classification_jigsaw/" + jigsaw_splits["validation"]
)

jigsaw_df = jigsaw_df[
    (jigsaw_df["toxic"] == 1)
    | (jigsaw_df["severe_toxic"] == 1)
    | (jigsaw_df["obscene"] == 1)
    | (jigsaw_df["threat"] == 1)
    | (jigsaw_df["insult"] == 1)
    | (jigsaw_df["identity_hate"] == 1)
]

jigsaw_df = jigsaw_df.rename(columns={"comment_text": "prompt"})
jigsaw_df["label"] = 0
jigsaw_df = jigsaw_df[["prompt", "label"]]
jigsaw_df = jigsaw_df.dropna(subset=["prompt"])
jigsaw_df = jigsaw_df[jigsaw_df["prompt"].str.strip() != ""]

# Load OLID dataset
olid_splits = {"train": "train.csv", "test": "test.csv"}
olid_df = pd.read_csv("hf://datasets/christophsonntag/OLID/" + olid_splits["train"])
olid_df = olid_df.rename(columns={"cleaned_tweet": "prompt"})
olid_df["label"] = 0
olid_df = olid_df[["prompt", "label"]]
olid_df = olid_df.dropna(subset=["prompt"])
olid_df = olid_df[olid_df["prompt"].str.strip() != ""]

# Load hateXplain dataset
hateXplain = pd.read_parquet(
    "hf://datasets/nirmalendu01/hateXplain_filtered/data/train-00000-of-00001.parquet"
)
hateXplain = hateXplain.rename(columns={"test_case": "prompt"})
hateXplain = hateXplain[(hateXplain["gold_label"] == "hateful")]
hateXplain = hateXplain[["prompt", "label"]]
hateXplain["label"] = 0
hateXplain = hateXplain.dropna(subset=["prompt"])
hateXplain = hateXplain[hateXplain["prompt"].str.strip() != ""]

# Load TUKE Slovak dataset
tuke_sk_splits = {"train": "train.json", "test": "test.json"}
tuke_sk_df = pd.read_json(
    "hf://datasets/TUKE-KEMT/hate_speech_slovak/" + tuke_sk_splits["train"], lines=True
)
tuke_sk_df = tuke_sk_df.rename(columns={"text": "prompt"})
tuke_sk_df = tuke_sk_df[tuke_sk_df["label"] == 0]
tuke_sk_df = tuke_sk_df[["prompt", "label"]]
tuke_sk_df = tuke_sk_df.dropna(subset=["prompt"])
tuke_sk_df = tuke_sk_df[tuke_sk_df["prompt"].str.strip() != ""]

# Load DKK dataset
dkk = pd.read_parquet("data/test-00000-of-00001.parquet")
dkk = dkk.rename(columns={"text": "prompt"})
dkk = dkk[dkk["label"] == "OFF"].reset_index(drop=True)
dkk["label"] = 0
dkk = dkk.dropna(subset=["prompt"])
dkk = dkk[dkk["prompt"].str.strip() != ""]

dkk_all = pd.read_parquet("data/test-00000-of-00001.parquet")
dkk_all = dkk_all.rename(columns={"text": "prompt"})
dkk_all["label"] = 0
dkk_all = dkk_all.dropna(subset=["prompt"])
dkk_all = dkk_all[dkk_all["prompt"].str.strip() != ""]

datasets = {
    "jigsaw": jigsaw_df,
    "olid": olid_df,
    "hate_xplain": hateXplain,
    "tuke_sk": tuke_sk_df,
    "dkk": dkk,
    "dkk_all": dkk_all,
}

In [12]:
def run_inference(model, dataset_df, threshold=THRESHOLD):
    dataset = Dataset.from_pandas(dataset_df)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=False
    )

    pred = []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Inference"):
            inputs = tokenizer(
                batch["prompt"],
                padding=True,
                truncation=True,
                max_length=False,
                return_tensors="pt",
            ).to(DEVICE)
            flat_inputs, offsets = prepare_inputs(inputs["input_ids"], device=DEVICE)

            # Forward pass without labels for inference
            logits = model(flat_inputs, offsets)
            probabilities = torch.sigmoid(logits) # -> 0-1 mostly one class is around 0.99+ in classification 

            # Classification of hate speech based on thresholds of class probabilities
            for probs in probabilities:
                thresholded_indices = torch.where(probs > threshold)[0]
                if len(thresholded_indices) == 0:
                    pred.append(0)
                else:
                    pred.append(1)
    return pred, dataset_df["label"].tolist()

# Run inference on each dataset
results = {}
for dataset_name, df in datasets.items():
    print(f"\nProcessing {dataset_name} dataset...")
    for threshold in [0.5, 0.75, 0.9, 0.99]:
        pred, true = run_inference(wide_mlp, df, threshold)
        
        # Calculate metrics (swawp with scikit-learn)
        true_positives = sum(1 for p, t in zip(pred, true) if p == 1 and t == 1)
        false_positives = sum(1 for p, t in zip(pred, true) if p == 1 and t == 0)
        false_negatives = sum(1 for p, t in zip(pred, true) if p == 0 and t == 1)
        true_negatives = sum(1 for p, t in zip(pred, true) if p == 0 and t == 0)
        
        accuracy = (true_positives + true_negatives) / len(pred)
        precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        
        results[f"{dataset_name}_{threshold}"] = {
            "dataset_name": dataset_name,
            "threshold": threshold,
            "accuracy": accuracy,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "true_positives": true_positives,
            "false_positives": false_positives,
            "false_negatives": false_negatives,
            "true_negatives": true_negatives
        }

# Convert results to DataFrame and save to CSV
results_df = pd.DataFrame.from_dict(results, orient='index')
results_df.to_csv('inference_results.csv', index=False)


Processing jigsaw dataset...


Inference: 100%|██████████| 51/51 [01:41<00:00,  1.99s/it]
Inference: 100%|██████████| 51/51 [01:39<00:00,  1.96s/it]
Inference: 100%|██████████| 51/51 [01:42<00:00,  2.02s/it]
Inference: 100%|██████████| 51/51 [01:43<00:00,  2.03s/it]



Processing olid dataset...


Inference: 100%|██████████| 206/206 [00:18<00:00, 11.00it/s]
Inference: 100%|██████████| 206/206 [00:18<00:00, 11.06it/s]
Inference: 100%|██████████| 206/206 [00:19<00:00, 10.39it/s]
Inference: 100%|██████████| 206/206 [00:18<00:00, 11.24it/s]



Processing hate_xplain dataset...


Inference: 100%|██████████| 93/93 [00:14<00:00,  6.56it/s]
Inference: 100%|██████████| 93/93 [00:13<00:00,  6.66it/s]
Inference: 100%|██████████| 93/93 [00:14<00:00,  6.63it/s]
Inference: 100%|██████████| 93/93 [00:14<00:00,  6.64it/s]



Processing tuke_sk dataset...


Inference: 100%|██████████| 135/135 [01:03<00:00,  2.11it/s]
Inference: 100%|██████████| 135/135 [01:03<00:00,  2.12it/s]
Inference: 100%|██████████| 135/135 [01:03<00:00,  2.12it/s]
Inference: 100%|██████████| 135/135 [01:04<00:00,  2.11it/s]



Processing dkk dataset...


Inference: 100%|██████████| 1/1 [00:00<00:00,  1.09it/s]
Inference: 100%|██████████| 1/1 [00:00<00:00,  1.14it/s]
Inference: 100%|██████████| 1/1 [00:01<00:00,  1.04s/it]
Inference: 100%|██████████| 1/1 [00:00<00:00,  1.15it/s]



Processing dkk_all dataset...


Inference: 100%|██████████| 6/6 [00:02<00:00,  2.10it/s]
Inference: 100%|██████████| 6/6 [00:03<00:00,  1.98it/s]
Inference: 100%|██████████| 6/6 [00:02<00:00,  2.06it/s]
Inference: 100%|██████████| 6/6 [00:02<00:00,  2.07it/s]
