In [None]:
# Reload imported modules automatically before each cell runs, so edits to the
# local modules used below are picked up without restarting the kernel.
%load_ext autoreload

%autoreload 2

In [None]:
import os
import random

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from sklearn.metrics import f1_score, precision_score, recall_score
from tqdm import tqdm
from transformers import AutoTokenizer

try:
    import wandb

    USE_WANDB = True
except ImportError:
    USE_WANDB = False
    print("wandb not installed, skipping wandb logging")
from dataloader import get_train_datasets
from widemlp import MLP, inverse_document_frequency, prepare_inputs_optimized

In [None]:
# Disable the tokenizers library's internal parallelism
# (presumably to avoid its fork warning with DataLoader workers -- TODO confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Expose only the first GPU to this process.
# NOTE(review): this is set after `import torch`; it only takes effect if CUDA
# has not been initialised yet -- confirm on a multi-GPU host.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Single place that decides where tensors and the model live.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Run hyperparameters -----------------------------------------------------
BATCH_SIZE = 64
SEED = 42
MODEL_NAME = ""
EPOCHS = 100
NUM_CLASSES = 4
GRADIENT_ACCUMULATION_STEPS = 1
TRAIN_LOGGING_STEPS = 1  # log the running train loss every N batches
EVAL_LOGGING_STEPS = 100  # run validation every N batches
THRESHOLD = 0.5  # sigmoid probability above which a class counts as predicted
DATASET_SIZE = 15_000
# Only log to Weights & Biases when the optional import succeeded; set this to
# False to disable logging even when wandb is installed.
LOG_WANDB = USE_WANDB
PATH = "widemlp-4cls.pt"  # checkpoint file written at the end of the notebook
TEST_SPLIT = 0.2
NUM_HIDDEN_LAYERS = 3
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

In [None]:
# Report every CUDA device visible to this process (prints nothing on CPU-only hosts).
if torch.cuda.is_available():
    print("CUDA devices:")
    for gpu_index in range(torch.cuda.device_count()):
        gpu_name = torch.cuda.get_device_name(gpu_index)
        print(f"  {gpu_index}: {gpu_name}")

In [None]:
# Start a Weights & Biases run and record the hyperparameters of this notebook
# so every logged curve can be traced back to its settings.
if LOG_WANDB:
    wandb.init(project="ood-widemlp")
    wandb.config.update(
        dict(
            seed=SEED,
            batch_size=BATCH_SIZE,
            epochs=EPOCHS,
            gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
            train_logging_steps=TRAIN_LOGGING_STEPS,
            eval_logging_steps=EVAL_LOGGING_STEPS,
            threshold=THRESHOLD,
            test_split=TEST_SPLIT,
            num_hidden_layers=NUM_HIDDEN_LAYERS,
            num_classes=NUM_CLASSES,
        )
    )

In [None]:
def fix_seed(seed: int = 42) -> None:
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Value passed to Python's ``random``, NumPy's legacy global RNG,
            and PyTorch's CPU and CUDA generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def load_data(path: str, eval_size: float = 0.2, seed: int = 42) -> dict:
    """Load a JSON or CSV file and split it into train and test sets.

    Args:
        path: File to read. Paths ending in ``.json`` are read with
            ``pd.read_json``; everything else is read with ``pd.read_csv``.
        eval_size: Size of the test split, forwarded to
            ``Dataset.train_test_split(test_size=...)``.
        seed: Seed used for both the shuffle and the split so the same file
            always produces the same partition.

    Returns:
        The mapping returned by ``Dataset.train_test_split`` (a ``DatasetDict``,
        which is a ``dict`` subclass) with ``"train"`` and ``"test"`` entries.
    """
    df = pd.read_json(path) if path.endswith(".json") else pd.read_csv(path)
    # `Dataset.shuffle` returns a new dataset rather than shuffling in place,
    # so the result has to be kept.
    dataset = Dataset.from_pandas(df).shuffle(seed=seed)
    # `train_test_split` shuffles again by default; seed it as well, otherwise
    # the partition differs on every call.
    return dataset.train_test_split(test_size=eval_size, seed=seed)

In [None]:
# Seed the Python, NumPy and PyTorch RNGs before the data is loaded and the model is built.
fix_seed(SEED)

# Dataloaders

In [None]:
# Build the train/test splits and wrap them in batched loaders.
data = get_train_datasets(dataset_size=DATASET_SIZE, seed=SEED, split=TEST_SPLIT)

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    data["train"], batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True
)
# Validation keeps a fixed order: the reported loss is a mean of per-batch
# means, so reshuffling (with a possibly smaller last batch) would make it
# fluctuate between evaluations of the same model.
valid_loader = torch.utils.data.DataLoader(
    data["test"], batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True
)

# Training

In [None]:
# Tokenise every training document once, without padding, to collect the
# corpus statistics that the IDF weighting is computed from.
docs = []
for raw_doc in data["train"]["text"]:
    token_ids = tokenizer.encode(
        raw_doc, padding=False, truncation=True, max_length=None
    )
    docs.append(token_ids)

idf = inverse_document_frequency(docs, len(tokenizer))

# Build the MLP over the full tokenizer vocabulary, handing it the IDF weights.
model = MLP(
    vocab_size=len(tokenizer),
    num_hidden_layers=NUM_HIDDEN_LAYERS,
    num_classes=NUM_CLASSES,
    idf=idf,
    problem_type="multi_label_classification",
)
model.to(DEVICE)

# The IDF tensor is moved explicitly in addition to `model.to(DEVICE)`.
if model.idf is not None:
    model.idf = model.idf.to(DEVICE)
else:
    model.idf = None

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0)

In [None]:
def _single_label_predictions(
    probabilities: torch.Tensor, threshold: float
) -> torch.Tensor:
    """Turn per-class probabilities into 0/1 prediction rows.

    A class is predicted when its probability exceeds ``threshold``. When more
    than one class in a row exceeds it, only the most probable class is kept,
    so each row ends up with at most one positive entry.

    Args:
        probabilities: Tensor whose last dimension indexes the classes.
        threshold: Probability a class must exceed to be predicted.

    Returns:
        An integer tensor of the same shape as ``probabilities``.
    """
    above_threshold = probabilities > threshold
    # One-hot of the most probable class. Whenever any class clears the
    # threshold, the overall maximum is one of them, so this matches taking
    # the argmax over the thresholded subset.
    top_class = torch.nn.functional.one_hot(
        probabilities.argmax(dim=-1), num_classes=probabilities.size(-1)
    ).bool()
    has_multiple = above_threshold.sum(dim=-1, keepdim=True) > 1
    return torch.where(has_multiple, top_class, above_threshold).int()


def evaluate(
    model: torch.nn.Module,
    valid_loader: torch.utils.data.DataLoader,
    tokenizer: AutoTokenizer,
    threshold: float,
) -> dict:
    """Run the model over the validation loader and compute loss and metrics.

    Relies on the module-level ``DEVICE`` and ``NUM_CLASSES`` settings. The
    model is switched to eval mode and left in it; the caller is responsible
    for switching back to train mode.

    Args:
        model: Called as ``model(flat_inputs, offsets, one_hot_targets)`` and
            expected to return ``(loss, logits)``.
        valid_loader: Yields batches with ``"text"`` and ``"label"`` entries.
        tokenizer: Tokenizer applied to each batch of texts.
        threshold: Sigmoid probability a class must exceed to be predicted.

    Returns:
        A dict with ``validation_loss`` (mean of the per-batch losses),
        macro/micro F1, precision and recall, and ``max_probs`` (mean of each
        sample's highest class probability).
    """
    model.eval()
    all_predictions = []
    all_labels = []
    validation_losses = []
    max_probs = []
    with torch.no_grad():
        for batch in valid_loader:
            # `max_length=None` lets the tokenizer fall back to the model's
            # maximum length, matching how the IDF documents are tokenised;
            # `False` is not a valid length.
            inputs = tokenizer(
                batch["text"],
                padding=True,
                truncation=True,
                max_length=None,
                return_tensors="pt",
            ).to(DEVICE)
            flat_inputs, offsets = prepare_inputs_optimized(inputs["input_ids"], device=DEVICE)
            labels = batch["label"].to(DEVICE, dtype=torch.long)
            one_hot_targets = torch.nn.functional.one_hot(
                labels, num_classes=NUM_CLASSES
            ).to(DEVICE, dtype=torch.float32)

            loss, logits = model(flat_inputs, offsets, one_hot_targets)
            validation_losses.append(loss.item())

            probabilities = torch.sigmoid(logits)  # -> 0-1
            max_probs.extend(probabilities.max(dim=-1).values.tolist())
            batch_predictions = _single_label_predictions(probabilities, threshold)
            all_predictions.extend(batch_predictions.cpu().numpy())
            all_labels.extend(one_hot_targets.cpu().numpy())

    # Stack the per-sample rows into (num_samples, num_classes) indicator arrays.
    all_predictions_np = np.array(all_predictions)
    all_labels_np = np.array(all_labels)

    def _score(metric, average: str) -> float:
        # zero_division=0 keeps classes with no predicted labels from raising warnings.
        return metric(all_labels_np, all_predictions_np, average=average, zero_division=0)

    scores = {
        "validation_loss": np.mean(validation_losses),
        "macro_f1": _score(f1_score, "macro"),
        "micro_f1": _score(f1_score, "micro"),
        "macro_precision": _score(precision_score, "macro"),
        "macro_recall": _score(recall_score, "macro"),
        "micro_precision": _score(precision_score, "micro"),
        "micro_recall": _score(recall_score, "micro"),
        "max_probs": np.mean(max_probs),
    }
    return scores

In [None]:
# Main training loop: one optimizer update every GRADIENT_ACCUMULATION_STEPS
# batches, running train loss logged every TRAIN_LOGGING_STEPS batches, and a
# full validation pass every EVAL_LOGGING_STEPS batches.
optimizer.zero_grad()
for _epoch in range(EPOCHS):
    step = 0  # number of batches processed in this epoch
    batch_train_loss = []
    for step, batch in enumerate(tqdm(train_loader, desc="Training"), start=1):
        # `evaluate` leaves the model in eval mode, so re-enable train mode every batch.
        model.train()
        # `max_length=None` lets the tokenizer fall back to the model's maximum
        # length; `False` is not a valid length.
        inputs = tokenizer(
            batch["text"],
            padding=True,
            truncation=True,
            max_length=None,
        )
        flat_inputs, offsets = prepare_inputs_optimized(inputs["input_ids"], device=DEVICE)
        labels = batch["label"].to(DEVICE, dtype=torch.long)
        one_hot_targets = torch.nn.functional.one_hot(
            labels, num_classes=NUM_CLASSES
        ).to(DEVICE, dtype=torch.float32)
        loss, _logits = model(flat_inputs, offsets, one_hot_targets)
        # Scale so the accumulated gradient is the mean over the accumulation
        # window rather than the sum.
        (loss / GRADIENT_ACCUMULATION_STEPS).backward()
        batch_train_loss.append(loss.item())
        # `step` already counts the current batch, so update exactly every
        # GRADIENT_ACCUMULATION_STEPS batches.
        if step % GRADIENT_ACCUMULATION_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()
        if step % TRAIN_LOGGING_STEPS == 0:
            # Running mean of the unscaled loss over the current epoch.
            mean_train_loss = np.mean(batch_train_loss)
            if LOG_WANDB:
                wandb.log({"train_loss": float(mean_train_loss), "epoch" : _epoch})
            else:
                print(f"Step {step}, Train Loss: {mean_train_loss:.4f}")
        if step % EVAL_LOGGING_STEPS == 0:
            scores = evaluate(model, valid_loader, tokenizer, threshold=THRESHOLD)
            if LOG_WANDB:
                wandb.log(
                    {
                        "validation_loss": float(scores["validation_loss"]),
                        "test/macro_f1": float(scores["macro_f1"]),
                        "test/micro_f1": float(scores["micro_f1"]),
                        "test/macro_recall": float(scores["macro_recall"]),
                        "test/micro_recall": float(scores["micro_recall"]),
                        "test/macro_precision": float(scores["macro_precision"]),
                        "test/micro_precision": float(scores["micro_precision"]),
                        "test/max_probs": float(scores["max_probs"]),
                    }
                )
            else:
                print(
                    f"Step {step}, Validation Loss: {scores['validation_loss']:.4f}, Macro F1: {scores['macro_f1']:.4f}, Micro F1: {scores['micro_f1']:.4f}"
                )
    # Apply gradients left over from an incomplete accumulation window so they
    # do not leak into the next epoch.
    if step % GRADIENT_ACCUMULATION_STEPS != 0:
        optimizer.step()
        optimizer.zero_grad()
if LOG_WANDB:
    wandb.finish()

In [None]:
# Save the model weights together with the optimizer state.
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    PATH,
)
# Store the IDF weights next to the checkpoint as "<name>_idf<ext>".
# Splitting the extension (instead of str.replace inside a nested-quote
# f-string) parses on Python versions before 3.12 and can never resolve to
# PATH itself, so the checkpoint above is not overwritten.
idf_root, idf_ext = os.path.splitext(PATH)
torch.save(idf, f"{idf_root}_idf{idf_ext}")