In [None]:
import os
import random
import time

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from tqdm import tqdm
from transformers import AutoTokenizer

try:
    import wandb
    USE_WANDB = True
except ImportError:
    USE_WANDB = False
    print("wandb not installed, skipping wandb logging")

import dataloader
from widemlp import MLP, inverse_document_frequency, prepare_inputs_optimized

In [None]:
# --- Run configuration -------------------------------------------------------
# Disable the HF tokenizers' internal thread pool (DataLoader workers are used
# further down) and expose only the first GPU to this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Device used for the model and for every tensor built in this notebook.
DEVICE = (
    torch.device("cuda:0")
    if torch.cuda.is_available()
    else torch.device("cpu")
)

BATCH_SIZE = 128  # batch size of both the train and the validation DataLoader
SEED = 42  # passed to fix_seed() and recorded in the wandb config
MODEL_NAME = ""  # not referenced anywhere else in this notebook
EPOCHS = 128  # number of passes over the training loader
NUM_CLASSES = 3  # width of the one-hot targets / prediction vectors
GRADIENT_ACCUMULATION_STEPS = 1  # batches accumulated per optimizer step
TRAIN_LOGGING_STEPS = 1  # log the running train loss every N steps
EVAL_LOGGING_STEPS = 100  # run validation every N steps
THRESHOLD = 0.5  # sigmoid threshold recorded in the wandb config
DATASET_SIZE = 15_000  # forwarded to dataloader.get_train_datasets
# Only log to Weights & Biases when the optional import at the top succeeded;
# with a hard-coded True every `wandb.*` call below raises NameError on
# machines without wandb. Set to False to disable logging explicitly.
LOG_WANDB = USE_WANDB
PATH = "widemlp-3cls-v3-1l.pt"  # checkpoint file written after training
TEST_SPLIT = 0.2  # forwarded as `split` to dataloader.get_train_datasets
NUM_HIDDEN_LAYERS = 1  # forwarded to widemlp.MLP
# Tokenizer whose vocabulary size (len(tokenizer)) defines the MLP input size.
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

In [None]:
# Print the index and name of every CUDA device visible to this process.
if torch.cuda.is_available():
    print("CUDA devices:")
    visible_device_count = torch.cuda.device_count()
    for device_index in range(visible_device_count):
        device_name = torch.cuda.get_device_name(device_index)
        print(f"  {device_index}: {device_name}")

In [None]:
# Start a Weights & Biases run and attach this notebook's hyperparameters to it.
if LOG_WANDB:
    wandb.init(project="ood-widemlp")
    run_hyperparameters = {
        "seed": SEED,
        "batch_size": BATCH_SIZE,
        "epochs": EPOCHS,
        "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
        "train_logging_steps": TRAIN_LOGGING_STEPS,
        "eval_logging_steps": EVAL_LOGGING_STEPS,
        "threshold": THRESHOLD,
        "test_split": TEST_SPLIT,
        "num_hidden_layers": NUM_HIDDEN_LAYERS,
        "num_classes": NUM_CLASSES,
    }
    wandb.config.update(run_hyperparameters)

In [None]:
def fix_seed(seed: int = 42) -> None:
    """Seed every random number generator this notebook relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's global generator,
    PyTorch's CPU generator and the generators of all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def load_data(path: str, eval_size: float = 0.2) -> dict:
    """Load a JSON or CSV file and split it into train and test sets.

    Args:
        path: File to read. Paths ending in ``.json`` are parsed with
            ``pandas.read_json``; everything else with ``pandas.read_csv``.
        eval_size: Forwarded as ``test_size`` to ``Dataset.train_test_split``
            (a fraction such as ``0.2``).

    Returns:
        The ``DatasetDict`` (a ``dict`` subclass) produced by
        ``Dataset.train_test_split``, with ``"train"`` and ``"test"`` entries.
    """
    df = pd.read_json(path) if path.endswith(".json") else pd.read_csv(path)
    dataset = Dataset.from_pandas(df)
    # `Dataset.shuffle` returns a new dataset instead of shuffling in place,
    # so the result has to be kept for the shuffle to have any effect.
    dataset = dataset.shuffle(seed=42)
    # Seed the split as well: `train_test_split` shuffles again by default and
    # would otherwise make the split differ between runs.
    return dataset.train_test_split(test_size=eval_size, seed=42)

In [None]:
# Seed all RNGs once, before the data loaders and the model are created below.
fix_seed(SEED)

# Dataloaders

In [None]:
# Fetch the train/test splits from the project `dataloader` module and wrap
# them in PyTorch DataLoaders.
data = dataloader.get_train_datasets(
    dataset_size=DATASET_SIZE, split=TEST_SPLIT
)
train_loader = torch.utils.data.DataLoader(
    data["train"], batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True
)
# Validation is not shuffled: evaluate() averages per-batch losses, so a
# shuffled loader changes which samples land in the (smaller) last batch and
# makes the reported validation loss vary between otherwise identical runs.
valid_loader = torch.utils.data.DataLoader(
    data["test"], batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True
)

# Training

In [None]:
# Token-id sequences of every training text (unpadded); they are only used to
# compute the IDF weights handed to the model.
# NOTE(review): max_length=None presumably falls back to the tokenizer's own
# maximum length — confirm against the tokenizer documentation.
docs = [
    tokenizer.encode(raw_doc, padding=False, truncation=True, max_length=None)
    for raw_doc in data["train"]["text"]
]
idf = inverse_document_frequency(docs, len(tokenizer))
# Bag-of-tokens MLP over the tokenizer vocabulary, trained with one-hot targets
# under the "multi_label_classification" problem type.
model = MLP(
    vocab_size=len(tokenizer),
    num_hidden_layers=NUM_HIDDEN_LAYERS,
    num_classes=NUM_CLASSES,
    idf=idf,
    problem_type="multi_label_classification",
)
model.to(DEVICE)
# The IDF tensor is moved explicitly as well — presumably it is a plain
# attribute that `model.to` does not move (TODO confirm in widemlp.MLP).
model.idf = model.idf.to(DEVICE) if model.idf is not None else None
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0)

In [None]:
def calculate_accuracy_one_hot(arr1, arr2):
    """Return the fraction of rows that are identical in both arrays.

    A row counts as correct only when every entry matches (exact-match
    accuracy over one-hot / multi-hot vectors).

    Args:
        arr1: Array-like whose first axis indexes samples.
        arr2: Array-like with the same shape as ``arr1``.

    Returns:
        Exact-match accuracy as a Python ``float`` in ``[0, 1]``; ``0.0`` when
        there are no samples.

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)

    if arr1.shape != arr2.shape:
        raise ValueError("Input arrays must have the same shape.")

    num_samples = arr1.shape[0]
    if num_samples == 0:
        return 0.0  # Return 0 if no samples

    # Collapse everything after the sample axis so that a "row" may be a
    # scalar, a vector or a higher-rank slice, then require all entries equal.
    rows_match = (arr1 == arr2).reshape(num_samples, -1).all(axis=1)
    return float(rows_match.mean())


In [None]:
# NOTE(review): a function with this name is also defined in the preceding
# cell; after "Run All" this later definition is the one in effect. Consider
# keeping a single definition.
def calculate_accuracy_one_hot(arr1, arr2):
    """Return the fraction of rows that are identical in both arrays.

    A row counts as correct only when every entry matches (exact-match
    accuracy over one-hot / multi-hot vectors).

    Args:
        arr1: Array-like whose first axis indexes samples.
        arr2: Array-like with the same shape as ``arr1``.

    Returns:
        Exact-match accuracy as a Python ``float`` in ``[0, 1]``; ``0.0`` when
        there are no samples.

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)

    if arr1.shape != arr2.shape:
        raise ValueError("Input arrays must have the same shape.")

    num_samples = arr1.shape[0]
    if num_samples == 0:
        return 0.0  # Return 0 if no samples

    # Collapse everything after the sample axis so that a "row" may be a
    # scalar, a vector or a higher-rank slice, then require all entries equal.
    rows_match = (arr1 == arr2).reshape(num_samples, -1).all(axis=1)
    return float(rows_match.mean())


def evaluate(
    model: torch.nn.Module,
    valid_loader: torch.utils.data.DataLoader,
    tokenizer: AutoTokenizer,
    threshold: float,
) -> dict:
    """Run the model over a validation loader and compute summary metrics.

    Each sample gets at most one predicted label: the class with the highest
    sigmoid probability, provided that probability exceeds ``threshold``;
    otherwise the prediction is the all-zero vector. Predictions are compared
    with the one-hot encoded labels by exact row match.

    Args:
        model: Called as ``model(flat_inputs, offsets, targets)`` and expected
            to return ``(loss, logits)``. It is switched to eval mode and left
            in that mode.
        valid_loader: Yields batches with a ``"text"`` entry and a ``"label"``
            tensor of class indices.
        tokenizer: Tokenizer applied to each batch of texts.
        threshold: Minimum sigmoid probability for a class to be predicted.

    Returns:
        Dict with ``"validation_loss"`` (unweighted mean of the per-batch
        losses), ``"accuracy"`` (exact-match accuracy) and ``"max_probs"``
        (mean over samples of the highest class probability). For an empty
        loader the loss and ``max_probs`` are NaN and the accuracy is 0.0.
    """
    model.eval()
    prediction_batches = []
    label_batches = []
    validation_losses = []
    max_prob_batches = []
    with torch.no_grad():
        for batch in valid_loader:
            inputs = tokenizer(
                batch["text"],
                padding=True,
                truncation=True,
                # NOTE(review): `False` is not a documented value for
                # max_length (None or an int is expected). Kept unchanged so
                # evaluation matches how the model is trained — confirm the
                # intended truncation length.
                max_length=False,
                return_tensors="pt",
            ).to(DEVICE)
            flat_inputs, offsets = prepare_inputs_optimized(
                inputs["input_ids"], device=DEVICE
            )
            labels = batch["label"].to(DEVICE, dtype=torch.long)
            one_hot_targets = torch.nn.functional.one_hot(
                labels, num_classes=NUM_CLASSES
            ).to(DEVICE, dtype=torch.float32)

            loss, logits = model(flat_inputs, offsets, one_hot_targets)
            validation_losses.append(loss.item())
            probabilities = torch.sigmoid(logits)  # -> 0-1

            # Vectorised label selection for the whole batch (avoids one
            # device sync per sample): mask out classes at or below the
            # threshold, take the best remaining class, and zero the row when
            # no class passed the threshold.
            above_threshold = probabilities > threshold
            has_prediction = above_threshold.any(dim=1, keepdim=True)
            best_label = probabilities.masked_fill(~above_threshold, -1.0).argmax(
                dim=1, keepdim=True
            )
            batch_predictions = torch.zeros_like(probabilities, dtype=torch.int)
            batch_predictions.scatter_(1, best_label, 1)
            batch_predictions = batch_predictions * has_prediction.int()

            prediction_batches.append(batch_predictions.cpu().numpy())
            label_batches.append(one_hot_targets.cpu().numpy())
            max_prob_batches.append(probabilities.max(dim=1).values.cpu().numpy())

    if not prediction_batches:
        # Nothing was evaluated; report neutral values instead of failing.
        return {
            "validation_loss": float("nan"),
            "accuracy": 0.0,
            "max_probs": float("nan"),
        }

    all_predictions_np = np.concatenate(prediction_batches)
    all_labels_np = np.concatenate(label_batches)
    all_max_probs = np.concatenate(max_prob_batches).astype(np.float64)
    scores = {
        "validation_loss": np.mean(validation_losses),
        "accuracy": calculate_accuracy_one_hot(all_predictions_np, all_labels_np),
        "max_probs": float(all_max_probs.mean()),
    }
    return scores

In [None]:
# Training loop: one optimizer step per GRADIENT_ACCUMULATION_STEPS batches,
# running train loss logged every TRAIN_LOGGING_STEPS steps and validation
# metrics every EVAL_LOGGING_STEPS steps. `step` restarts at 0 every epoch.
optimizer.zero_grad()
for _epoch in range(EPOCHS):
    step = 0
    batch_train_loss = []
    for batch in tqdm(train_loader, desc="Training"):
        # evaluate() switches the model to eval mode, so re-enable train mode
        # on every step.
        model.train()
        inputs = tokenizer(
            batch["text"],
            padding=True,
            truncation=True,
            # NOTE(review): `False` is not a documented value for max_length
            # (None or an int is expected); kept unchanged — confirm intent.
            max_length=False,
        )
        flat_inputs, offsets = prepare_inputs_optimized(
            inputs["input_ids"], device=DEVICE
        )
        labels = batch["label"].to(DEVICE, dtype=torch.long)
        one_hot_targets = torch.nn.functional.one_hot(
            labels, num_classes=NUM_CLASSES
        ).to(DEVICE, dtype=torch.float32)
        loss, logits = model(flat_inputs, offsets, one_hot_targets)
        # Scale so the accumulated gradient is the mean over the accumulated
        # batches (a no-op while GRADIENT_ACCUMULATION_STEPS == 1).
        (loss / GRADIENT_ACCUMULATION_STEPS).backward()
        step += 1  # noqa: SIM113
        batch_train_loss.append(loss.item())
        # `step` already counts the batch just processed; the previous
        # `(step + 1) % N` test fired one batch too early whenever N > 1.
        if step % GRADIENT_ACCUMULATION_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()
        if step % TRAIN_LOGGING_STEPS == 0:
            # Running mean of the train loss over the current epoch.
            mean_train_loss = float(np.mean(batch_train_loss))
            if LOG_WANDB:
                wandb.log({"train_loss": mean_train_loss, "epoch": _epoch})
            else:
                print(f"Step {step}, Train Loss: {mean_train_loss:.4f}")
        if step % EVAL_LOGGING_STEPS == 0:
            # Use the configured THRESHOLD (the value recorded in the wandb
            # config) rather than a separate literal.
            scores = evaluate(model, valid_loader, tokenizer, threshold=THRESHOLD)
            if LOG_WANDB:
                wandb.log(
                    {
                        "validation_loss": float(scores["validation_loss"]),
                        "test/max_probs": float(scores["max_probs"]),
                        "test/accuracy": float(scores["accuracy"]),
                    }
                )
            else:
                # evaluate() returns only validation_loss, accuracy and
                # max_probs; the former F1 lookups raised KeyError here.
                print(
                    f"Step {step}, Validation Loss: {scores['validation_loss']:.4f}, "
                    f"Accuracy: {scores['accuracy']:.4f}, "
                    f"Mean Max Prob: {scores['max_probs']:.4f}"
                )
    # Apply gradients left over from an incomplete accumulation window so they
    # do not leak into the first update of the next epoch.
    if step % GRADIENT_ACCUMULATION_STEPS != 0:
        optimizer.step()
        optimizer.zero_grad()
if LOG_WANDB:
    wandb.finish()

In [None]:
# Save the model and optimizer state to PATH.
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    PATH,
)
# Save the IDF weights next to the checkpoint as "<name>_idf.pt".
# The path is built with a plain expression: reusing double quotes inside a
# double-quoted f-string is a SyntaxError on Python versions before 3.12.
torch.save(idf, PATH.replace(".pt", "_idf.pt"))

# Evaluation

In [None]:
# Load the evaluation datasets from the project `dataloader` module.
# NOTE(review): `datasets` is not referenced by any later cell in this notebook.
datasets = dataloader.get_eval_datasets()

In [None]:
# Fetch the training data again with the same size/split settings so that the
# train split can be exported in the next cell.
# NOTE(review): whether this returns the same rows as the earlier call depends
# on dataloader.get_train_datasets — confirm.
train = dataloader.get_train_datasets(dataset_size=DATASET_SIZE, split=TEST_SPLIT)

In [None]:
# Export the train split to CSV (without the pandas index column).
train["train"].to_pandas().to_csv("train_domain.csv", index=False)

# MLP 3cls

In [None]:
def load_model(
    model_path: str, idf_path: str, num_classes: int, num_hidden_layers: int
) -> MLP:
    """Rebuild a WideMLP from a saved checkpoint and its IDF weights.

    Args:
        model_path: Checkpoint file containing a ``"model_state_dict"`` entry.
        idf_path: File holding the IDF tensor saved alongside the checkpoint.
        num_classes: Number of output classes the checkpoint was trained with.
        num_hidden_layers: Number of hidden layers the checkpoint was trained
            with.

    Returns:
        The model on ``DEVICE``, in eval mode, with the checkpoint weights
        loaded. The vocabulary size comes from the notebook-level
        ``tokenizer``.
    """
    # map_location lets an IDF tensor saved from a GPU run load on a CPU-only
    # machine; without it torch.load tries to restore it onto the saving device.
    idf = torch.load(idf_path, map_location=DEVICE)
    checkpoint = torch.load(model_path, weights_only=True, map_location=DEVICE)
    wide_mlp = MLP(
        vocab_size=len(tokenizer),
        num_hidden_layers=num_hidden_layers,
        num_classes=num_classes,
        idf=idf,
        problem_type="multi_label_classification",
    )
    wide_mlp.to(DEVICE)
    # Re-attach the device-resident IDF tensor when the model keeps one.
    if wide_mlp.idf is not None:
        wide_mlp.idf = idf
    wide_mlp.load_state_dict(checkpoint["model_state_dict"])
    wide_mlp.eval()
    print(f"Successfully loaded PyTorch model on {DEVICE}")
    return wide_mlp


def inference(
    model: torch.nn.Module,
    text: str,
    tokenizer: AutoTokenizer,
    threshold: float,
) -> np.ndarray:
    """Predict at most one label per input text.

    Args:
        model: Called as ``model(flat_inputs, offsets)`` and expected to return
            logits with one row per input text. Switched to eval mode.
        text: Text (or batch of texts) accepted by ``tokenizer``.
        tokenizer: Tokenizer applied to ``text``.
        threshold: Minimum sigmoid probability for a class to be predicted.

    Returns:
        Integer array with one row per input. Each row is one-hot at the most
        probable class when that probability exceeds ``threshold`` and all
        zeros otherwise. The row width equals the number of logits the model
        emits.
    """
    model.eval()

    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        # NOTE(review): `False` is not a documented value for max_length
        # (None or an int is expected); kept unchanged — confirm intent.
        max_length=False,
        return_tensors="pt",
    ).to(DEVICE)
    flat_inputs, offsets = prepare_inputs_optimized(inputs["input_ids"], device=DEVICE)

    # No gradients are needed for prediction; skip building the autograd graph.
    with torch.no_grad():
        logits = model(flat_inputs, offsets)
        probabilities = torch.sigmoid(logits)  # -> 0-1

        # Mask out classes at or below the threshold, pick the best remaining
        # class per row, and zero rows in which no class passed the threshold.
        above_threshold = probabilities > threshold
        has_prediction = above_threshold.any(dim=1, keepdim=True)
        best_label = probabilities.masked_fill(~above_threshold, -1.0).argmax(
            dim=1, keepdim=True
        )
        predictions = torch.zeros_like(probabilities, dtype=torch.int)
        predictions.scatter_(1, best_label, 1)
        predictions = predictions * has_prediction.int()

    return predictions.cpu().numpy()


# Checkpoints to evaluate, as (checkpoint path, number of hidden layers,
# IDF tensor path) tuples — the order in which the loop below unpacks them.
# NOTE(review): the first entry pairs "widemlp-23-30.pt" with the v3 IDF file;
# confirm that this is the IDF it was trained with.
models = [
    ("widemlp-23-30.pt", 3, "widemlp-3cls-v3_idf.pt"),
    ("widemlp-3cls-v3.pt", 10, "widemlp-3cls-v3_idf.pt"),
    ("widemlp-3cls-v3-1l.pt", 1, "widemlp-3cls-v3-1l_idf.pt"),
    # ("widemlp-3cls-v2.pt",3, "widemlp-3cls-v2_idf.pt"),
    # ("widemlp-3cls-v3-3l.pt",3, "widemlp-3cls-v3-3l_idf.pt")
    # ("widemlp-3cls-v3-64l.pt",64, "widemlp-3cls-v3-64l_idf.pt"),
    # ("widemlp-3cls-v3-128l.pt",128, "widemlp-3cls-v3-128l_idf.pt")
]

# Model: widemlp-23-30.pt - 3, Mean : 92.26000000000002, Scores: [92.30000000000001, 92.30000000000001, 92.25555555555556, 92.23333333333333, 92.21111111111111], Thresholds: [0.4, 0.5, 0.75, 0.9, 0.99]
# Model: widemlp-3cls-v3.pt - 10, Mean : 85.6, Scores: [85.6, 85.6, 85.6, 85.6, 85.6], Thresholds: [0.4, 0.5, 0.75, 0.9, 0.99]
# Model: widemlp-3cls-v2.pt - 3, Mean : 34.63777777777778, Scores: [36.72222222222222, 35.82222222222222, 34.044444444444444, 33.41111111111111, 33.18888888888889], Thresholds: [0.4, 0.5, 0.75, 0.9, 0.99]
# Model: widemlp-3cls-v3-3l.pt - 3, Mean : 71.39333333333335, Scores: [71.54444444444444, 71.52222222222223, 71.43333333333334, 71.33333333333334, 71.13333333333334], Thresholds: [0.4, 0.5, 0.75, 0.9, 0.99]
# Model: widemlp-3cls-v3-64l.pt - 64, Mean : 33.18888888888889, Scores: [33.18888888888889, 33.18888888888889, 33.18888888888889, 33.18888888888889, 33.18888888888889], Thresholds: [0.4, 0.5, 0.75, 0.9, 0.99]
# Model: widemlp-3cls-v3-128l.pt - 128, Mean : 33.18888888888889, Scores: [33.18888888888889, 33.18888888888889, 33.18888888888889, 33.18888888888889, 33.18888888888889], Thresholds: [0.4, 0.5, 0.75, 0.9, 0.99]
# Threshold sweep: for every checkpoint and threshold, predict a class index
# for each OOD prompt and each in-domain text and write one CSV per
# (dataset, model, threshold) combination into data/mlp/.
os.makedirs("data/mlp", exist_ok=True)  # to_csv does not create directories

# The evaluation frames never change, so read them once instead of once per
# threshold. They are not mutated below; predictions are added to a copy.
ood = pd.read_csv("data/ood_eval.csv")
domain = pd.read_csv("data/domain_eval.csv")

for model_name, num_hidden_layers, idf_path in models:
    wide_mlp = load_model(
        model_path=model_name,
        idf_path=idf_path,
        num_classes=3,
        num_hidden_layers=num_hidden_layers,
    )
    for threshold in tqdm([0.5, 0.75, 0.9, 0.99]):
        for eval_name, eval_frame, text_column in (
            ("ood", ood, "prompt"),
            ("domain", domain, "text"),
        ):
            # NOTE(review): inference() returns an all-zero row when no class
            # exceeds the threshold, and np.argmax maps that row to 0, so
            # "no prediction" is written as class 0 — confirm this is intended.
            predicted_classes = [
                np.argmax(inference(wide_mlp, text, tokenizer, threshold)[0])
                for text in eval_frame[text_column].values
            ]
            eval_frame.assign(pred=predicted_classes).to_csv(
                f"data/mlp/{eval_name}_eval_{model_name}_threshold_{threshold}.csv",
                index=False,
            )

In [None]:
def inference_batch(
    model: torch.nn.Module,
    batch_texts: list[str],
    tokenizer: AutoTokenizer,
    threshold: float,
    num_classes: int = NUM_CLASSES,
    device: torch.device = DEVICE,
) -> np.ndarray:
    """Predict at most one label per text for a whole batch at once.

    Args:
        model: Called as ``model(flat_inputs, offsets=offsets)`` and expected
            to return logits of shape ``(len(batch_texts), num_classes)``.
            Switched to eval mode.
        batch_texts: Texts to classify.
        tokenizer: Tokenizer applied to ``batch_texts``.
        threshold: Minimum sigmoid probability for a class to be predicted.
        num_classes: Width of the returned prediction rows.
        device: Device the inputs and intermediate tensors are placed on.

    Returns:
        Integer array of shape ``(len(batch_texts), num_classes)``. Each row is
        one-hot at the most probable class when that probability exceeds
        ``threshold`` and all zeros otherwise. An empty batch yields an array
        of shape ``(0, num_classes)``.
    """
    model.eval()

    # Nothing to tokenize for an empty batch; return an empty result directly.
    if len(batch_texts) == 0:
        return np.empty((0, num_classes), dtype=np.int32)

    with torch.no_grad():
        # NOTE(review): every batch is padded to tokenizer.model_max_length,
        # whereas inference()/evaluate() pad to the longest text in the batch.
        # Kept unchanged — confirm this is intended for the benchmark and that
        # prepare_inputs_optimized treats padding the same way in both cases.
        inputs = tokenizer(
            batch_texts,
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
            return_tensors="pt",
        ).to(device)

        flat_inputs, offsets = prepare_inputs_optimized(
            inputs["input_ids"], device=device
        )
        logits = model(flat_inputs, offsets=offsets)
        probabilities = torch.sigmoid(logits)

        # Classes at or below the threshold are replaced by -1 so that argmax
        # only picks among the classes that passed it.
        above_threshold_mask = probabilities > threshold
        has_prediction = torch.any(above_threshold_mask, dim=1)
        probs_for_argmax = probabilities.masked_fill(~above_threshold_mask, -1.0)
        best_indices = torch.argmax(probs_for_argmax, dim=1)

        batch_predictions = torch.zeros(
            probabilities.size(0), num_classes, dtype=torch.int, device=device
        )
        batch_predictions.scatter_(1, best_indices.unsqueeze(1), 1)
        # Rows in which no class passed the threshold are reset to all zeros.
        final_batch_predictions = batch_predictions * has_prediction.unsqueeze(1).int()

    return final_batch_predictions.cpu().numpy()


# Batch-size benchmark: time inference_batch() on the same prompts for several
# batch sizes and write one CSV of per-batch timings per checkpoint.
# Entries are (checkpoint path, number of hidden layers, IDF tensor path).
models = [
    ("widemlp-23-30.pt", 3, "widemlp-3cls-v3_idf.pt"),
    ("widemlp-3cls-v3.pt", 10, "widemlp-3cls-v3_idf.pt"),
    ("widemlp-3cls-v3-1l.pt", 1, "widemlp-3cls-v3-1l_idf.pt"),
]

os.makedirs("data/batch", exist_ok=True)  # to_csv does not create directories

# Read the prompts once. The previous version re-read the CSV for every batch
# size and rebuilt the full prompt list for every single slice.
benchmark_prompts = pd.read_csv("data/batch_data.csv")["prompt"].tolist()

for model_name, num_hidden_layers, idf_path in models:
    wide_mlp = load_model(
        model_path=model_name,
        idf_path=idf_path,
        num_classes=3,
        num_hidden_layers=num_hidden_layers,
    )
    batch_results = []
    for batch_size in tqdm([1, 32, 64, 128, 256]):
        for batch_start in range(0, len(benchmark_prompts), batch_size):
            prompt_batch = benchmark_prompts[batch_start : batch_start + batch_size]
            # Time only the inference call itself (tokenization included).
            start_time = time.perf_counter()
            results = inference_batch(wide_mlp, prompt_batch, tokenizer, 0.85)
            elapsed_time = time.perf_counter() - start_time
            batch_results.append(
                {
                    "batch_size": batch_size,
                    "time_taken": elapsed_time,
                    "results": results,
                    "model_name": model_name,
                }
            )
    pd.DataFrame(batch_results).to_csv(
        f"data/batch/{model_name}-batch.csv", index=False
    )