In [None]:
import os
import random

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
)
from tqdm import tqdm
from transformers import AutoTokenizer

from dataloader import get_eval_datasets
from widemlp import MLP, prepare_inputs_optimized

In [None]:
# --- Notebook configuration -------------------------------------------------
# Environment variables are set before the first CUDA query below so that they
# are already in place when torch inspects the available devices.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Only physical GPU 1 is exposed to this process; inside the process that
# single visible device is addressed as index 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Use the (single visible) GPU when CUDA is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

BATCH_SIZE = 64
SEED = 42
MODEL_NAME = ""
EPOCHS = 1
NUM_CLASSES = 3
THRESHOLD = 0.5
DATASET_SIZE = 15_000
TEST_SPLIT = 0.2
NUM_HIDDEN_LAYERS = 3

# Tokenizer shared by every cell below; its vocabulary size also sizes the MLP input.
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

In [None]:
def fix_seed(seed: int = 42) -> None:
    """Seed every random number generator this notebook relies on.

    The same ``seed`` is applied, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator, and all CUDA
    generators.

    Args:
        seed: Integer seed shared by all generators. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# Seed all RNGs once, up front, so the cells that follow start from a fixed random state.
fix_seed(SEED)

In [None]:
# Load the evaluation datasets through the project-local `dataloader` helper.
# NOTE(review): `datasets` is not referenced by the cells visible below, which
# read their CSVs directly — confirm whether this (possibly expensive) call is still needed.
datasets = get_eval_datasets()

# MLP 3cls

In [None]:
def load_model(
    model_path: str, idf_path: str, num_classes: int, num_hidden_layers: int
) -> MLP:
    """Build an ``MLP``, restore its trained weights, and return it in eval mode.

    Relies on the module-level ``DEVICE`` and ``tokenizer`` (the tokenizer's
    length is used as the model's vocabulary size).

    Args:
        model_path: Path to a checkpoint file whose ``"model_state_dict"``
            entry holds the weights to restore.
        idf_path: Path to the serialized IDF object passed to the ``MLP``
            constructor.
        num_classes: Number of output classes, forwarded to ``MLP``.
        num_hidden_layers: Forwarded to ``MLP`` as ``num_hidden_layers``.

    Returns:
        The model, moved to ``DEVICE`` and switched to evaluation mode.

    Raises:
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
    """
    # map_location remaps storages onto DEVICE at load time, so a tensor that
    # was saved from a GPU that is not visible in this process still loads.
    # NOTE(review): unlike the checkpoint below, this file is loaded without
    # weights_only=True — only load IDF files from trusted sources.
    idf = torch.load(idf_path, map_location=DEVICE)
    checkpoint = torch.load(model_path, weights_only=True, map_location=DEVICE)

    wide_mlp = MLP(
        vocab_size=len(tokenizer),
        num_hidden_layers=num_hidden_layers,
        num_classes=num_classes,
        idf=idf,
        problem_type="multi_label_classification",
    )
    wide_mlp.to(DEVICE)

    # Re-attach the DEVICE-resident IDF after the move, but only when the model
    # actually keeps one. Presumably `idf` can be a plain attribute that
    # `.to()` does not relocate — TODO confirm against widemlp.MLP.
    if wide_mlp.idf is not None:
        wide_mlp.idf = idf

    wide_mlp.load_state_dict(checkpoint["model_state_dict"])
    wide_mlp.eval()
    print(f"Successfully loaded PyTorch model on {DEVICE}")
    return wide_mlp

def inference(
    model: torch.nn.Module,
    text: str,
    tokenizer: "AutoTokenizer",
    threshold: float,
) -> np.ndarray:
    """Classify ``text`` and return one one-hot (or all-zero) row per input.

    The model's logits are passed through a sigmoid. For each row, the class
    with the highest probability *strictly above* ``threshold`` is marked with
    a 1; a row in which no class clears the threshold is left all zeros.

    Relies on the module-level ``DEVICE`` and on ``prepare_inputs_optimized``
    to turn the padded ``input_ids`` into the ``(flat_inputs, offsets)`` pair
    the model is called with.

    Args:
        model: Callable as ``model(flat_inputs, offsets)``; must return
            2-D logits (one row per input, one column per class).
        text: Input passed straight to ``tokenizer``. Callers in this notebook
            pass a single string and read row 0 of the result.
        tokenizer: Hugging Face tokenizer used to encode ``text``.
        threshold: Minimum sigmoid probability (exclusive) a class needs in
            order to be predicted.

    Returns:
        ``int32`` array of shape ``(rows, classes)``, where ``classes`` is
        taken from the model output rather than from a global constant.
    """
    model.eval()

    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        # max_length takes an int or None; None defers to the tokenizer's own
        # configured maximum length instead of passing a boolean as a length.
        max_length=None,
        return_tensors="pt",
    ).to(DEVICE)
    flat_inputs, offsets = prepare_inputs_optimized(inputs["input_ids"], device=DEVICE)

    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(flat_inputs, offsets)
        probabilities = torch.sigmoid(logits)  # -> 0-1

        above_threshold = probabilities > threshold
        # Hide every class that did not clear the threshold so that argmax can
        # only select among the qualifying ones.
        candidates = probabilities.masked_fill(~above_threshold, float("-inf"))
        best_label = candidates.argmax(dim=1)
        has_prediction = above_threshold.any(dim=1)

        predictions = torch.zeros_like(probabilities, dtype=torch.int)
        # Rows without any qualifying class stay all-zero.
        predictions[has_prediction, best_label[has_prediction]] = 1

    return predictions.cpu().numpy()

# Checkpoints to evaluate. Each entry is a tuple of
#   (checkpoint path, number of hidden layers, IDF tensor path)
# and is unpacked in that order by the evaluation loop.
models = [
    ("widemlp-23-30.pt", 3, "widemlp-3cls-v3_idf.pt"),
    ("widemlp-3cls-v3.pt", 10, "widemlp-3cls-v3_idf.pt"),
    # Currently disabled checkpoints:
    # ("widemlp-3cls-v2.pt", 3, "widemlp-3cls-v2_idf.pt"),
    # ("widemlp-3cls-v3-3l.pt", 3, "widemlp-3cls-v3-3l_idf.pt"),
    # ("widemlp-3cls-v3-64l.pt", 64, "widemlp-3cls-v3-64l_idf.pt"),
    # ("widemlp-3cls-v3-128l.pt", 128, "widemlp-3cls-v3-128l_idf.pt"),
]

# Model: widemlp-23-30.pt - 3, Mean : 92.26000000000002, Scores: [92.30000000000001, 92.30000000000001, 92.25555555555556, 92.23333333333333, 92.21111111111111], Thresholds: [0.4, 0.5, 0.75, 0.9, 0.99]
# Model: widemlp-3cls-v3.pt - 10, Mean : 85.6, Scores: [85.6, 85.6, 85.6, 85.6, 85.6], Thresholds: [0.4, 0.5, 0.75, 0.9, 0.99]
# Model: widemlp-3cls-v2.pt - 3, Mean : 34.63777777777778, Scores: [36.72222222222222, 35.82222222222222, 34.044444444444444, 33.41111111111111, 33.18888888888889], Thresholds: [0.4, 0.5, 0.75, 0.9, 0.99]
# Model: widemlp-3cls-v3-3l.pt - 3, Mean : 71.39333333333335, Scores: [71.54444444444444, 71.52222222222223, 71.43333333333334, 71.33333333333334, 71.13333333333334], Thresholds: [0.4, 0.5, 0.75, 0.9, 0.99]
# Model: widemlp-3cls-v3-64l.pt - 64, Mean : 33.18888888888889, Scores: [33.18888888888889, 33.18888888888889, 33.18888888888889, 33.18888888888889, 33.18888888888889], Thresholds: [0.4, 0.5, 0.75, 0.9, 0.99]
# Model: widemlp-3cls-v3-128l.pt - 128, Mean : 33.18888888888889, Scores: [33.18888888888889, 33.18888888888889, 33.18888888888889, 33.18888888888889, 33.18888888888889], Thresholds: [0.4, 0.5, 0.75, 0.9, 0.99]
# Probability cut-offs to sweep for every checkpoint.
EVAL_THRESHOLDS = [0.5, 0.75, 0.9, 0.99]
# Directory receiving one prediction CSV per (split, model, threshold).
EVAL_OUTPUT_DIR = "data/mlp"
# (output file prefix, input CSV, column holding the text to classify)
EVAL_SPLITS = [
    ("ood", "data/ood_eval.csv", "prompt"),
    ("domain", "data/domain_eval.csv", "text"),
]


def predict_class_ids(model: torch.nn.Module, texts, threshold: float) -> list:
    """Run ``inference`` on each text and return one class index per text.

    Each text is classified on its own; the class index is the argmax of the
    first row returned by ``inference``.

    NOTE(review): when no class clears ``threshold``, ``inference`` returns an
    all-zero row and argmax maps it to index 0, so "no prediction" cannot be
    told apart from a class-0 prediction in the saved files.

    Args:
        model: Model forwarded to ``inference``.
        texts: Iterable of input strings.
        threshold: Probability cut-off forwarded to ``inference``.

    Returns:
        List of integer class indices, in the order of ``texts``.
    """
    return [
        int(np.argmax(inference(model, text, tokenizer, threshold)[0]))
        for text in texts
    ]


# Create the output directory up front so a missing folder cannot discard the
# results of a long inference run at write time.
os.makedirs(EVAL_OUTPUT_DIR, exist_ok=True)

# The input CSVs do not depend on the model or threshold: read each one once.
eval_frames = [
    (prefix, pd.read_csv(csv_path), text_column)
    for prefix, csv_path, text_column in EVAL_SPLITS
]

for model_name, num_hidden_layers, idf_path in models:
    wide_mlp = load_model(
        model_path=model_name,
        idf_path=idf_path,
        num_classes=NUM_CLASSES,
        num_hidden_layers=num_hidden_layers,
    )
    for threshold in tqdm(EVAL_THRESHOLDS):
        for prefix, frame, text_column in eval_frames:
            # assign() returns a new frame, leaving the cached input untouched
            # for the next threshold / model.
            scored = frame.assign(
                pred=predict_class_ids(wide_mlp, frame[text_column].values, threshold)
            )
            scored.to_csv(
                f"{EVAL_OUTPUT_DIR}/{prefix}_eval_{model_name}_threshold_{threshold}.csv",
                index=False,
            )