In [None]:
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
from accelerate.data_loader import DataLoader
from datasets import ClassLabel, Dataset, load_dataset
from sklearn.metrics import confusion_matrix
from torch import nn as nn
from torch.optim import AdamW
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

In [None]:
# Mean Pooling - average token embeddings, counting only non-padding tokens
def mean_pooling(
    model_output: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """Average the token embeddings of each sequence, weighted by the mask.

    Args:
        model_output: Object whose element ``[0]`` holds the token embeddings.
            Callers in this notebook pass the value returned by the encoder's
            forward pass (the annotation says ``Tensor``, but the code indexes
            ``[0]`` to reach the embeddings).
        attention_mask: Mask with one entry per token; it is broadcast over the
            embedding dimension, so masked-out (0) tokens contribute nothing.

    Returns:
        One pooled vector per sequence: the masked sum over dim 1 divided by
        the number of unmasked tokens (clamped to at least 1e-9 so an all-zero
        mask yields zeros instead of a division by zero).
    """
    per_token = model_output[0]
    # Stretch the mask across the embedding dimension so it can weight tokens.
    weights = attention_mask.unsqueeze(-1).expand_as(per_token).float()
    masked_sum = (per_token * weights).sum(dim=1)
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return masked_sum / token_count

In [None]:
# Single source of truth for the checkpoint so tokenizer and encoder always match.
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

In [None]:
# Put the encoder in training mode, then move it to the GPU when one is present.
model.train()
if torch.cuda.is_available():
    model = model.cuda()

In [None]:
# Hub repositories providing the raw text for each of the three domains.
FINANCE_REPO = "Marina-C/question-answer-Subject-Finance-Instruct"
HEALTH_REPO = "iecjsu/lavita-ChatDoctor-HealthCareMagic-100k"
LAW_REPO = "dim/law_stackexchange_prompts"

finance = load_dataset(FINANCE_REPO)
health = load_dataset(HEALTH_REPO)
law = load_dataset(LAW_REPO)

In [None]:
# Finance rows are conversations: keep the "content" of the entry at index 1.
finance_texts = [
    conversation[1]["content"] for conversation in finance["train"]["messages"]
]

# Every frame exposes its text under the same "messages" column.
finance_df = pd.DataFrame({"messages": finance_texts})
health_df = pd.DataFrame({"messages": health["train"]["input"]})
law_df = pd.DataFrame({"messages": law["train"]["prompt"]})

In [None]:
# Preview the first rows of the finance frame.
finance_df.head()

In [None]:
# Preview the first rows of the health frame.
health_df.head()

In [None]:
# Preview the first rows of the law frame.
law_df.head()

In [None]:
# Tag every row with its domain's integer class id: finance=0, health=1, law=2.
for class_id, domain_df in enumerate((finance_df, health_df, law_df)):
    domain_df["label"] = class_id

In [None]:
# Stack the three labelled frames into one, renumbering rows from zero.
domain_frames = (finance_df, health_df, law_df)
combined_df = pd.concat(domain_frames, ignore_index=True)

In [None]:
# Shuffle all rows with a fixed seed, then discard the old (pre-shuffle) index.
shuffled_df = combined_df.sample(frac=1, random_state=42)
combined_df = shuffled_df.reset_index(drop=True)

In [None]:
# Preview the first rows of the shuffled, labelled frame.
combined_df.head()

In [None]:
# Wrap the shuffled pandas frame in a `datasets.Dataset`.
combined_dataset = Dataset.from_pandas(combined_df)

In [None]:
# Declare "label" as a 3-way ClassLabel; the name order matches the integer
# ids written into the frames (0=finance, 1=health, 2=law).
label_feature = ClassLabel(num_classes=3, names=["finance", "health", "law"])
combined_dataset = combined_dataset.cast_column("label", label_feature)

In [None]:
# Examples per batch for the train, test and inference DataLoaders.
BATCH_SIZE = 128
# Number of consecutive evaluations without a new best test loss after which
# train_frozen() stops.
EARLY_STOP_PATIENCE = 5
# AdamW learning rate used while only the classification head is trained.
FREEZE_TRAIN_LR = 1e-4
# NOTE(review): the three constants below are not referenced by any visible
# cell; presumably they belong to a later unfrozen fine-tuning stage with a
# learning-rate warm-up — confirm or remove.
UNFROZEN_TRAIN_LR_WARMUP = 1e-6
UNFROZEN_TRAIN_LR = 1e-5
WARMUP_STEPS = 512

In [None]:
# Tokenizer options shared by every call: pad to the longest item in the batch,
# truncate at 512 tokens, and return PyTorch tensors.
TOKENIZER_KWARGS = {
    "padding": True,
    "truncation": True,
    "return_tensors": "pt",
    "max_length": 512,
}
tokenizer_func = partial(tokenizer, **TOKENIZER_KWARGS)

In [None]:
# Smoke test: tokenize the first message and move it to the model's device.
# Index the row first; `combined_dataset["messages"][0]` would read the entire
# text column just to take its first element.
encoded = tokenizer_func(combined_dataset[0]["messages"]).to(model.device)

In [None]:
# Run the encoder on the sample. Call the module itself rather than `.forward`
# so that any hooks registered on the model are executed.
out = model(**encoded)

In [None]:
# Collapse the per-token embeddings into one vector per input, ignoring padding.
# NOTE: `out` is rebound from the encoder output to the pooled tensor, so this
# cell must be preceded by a fresh run of the forward-pass cell.
out = mean_pooling(out, encoded["attention_mask"])

In [None]:
# Scale each pooled embedding to unit L2 norm along dim 1.
out = F.normalize(out, p=2, dim=1)

In [None]:
class DyT(nn.Module):
    """Element-wise ``gamma * tanh(alpha * x) + beta`` with learnable parameters.

    ``alpha`` is a single learnable scalar shared by all features, while
    ``gamma`` (initialised to ones) and ``beta`` (initialised to zeros) hold
    one learnable value per feature.

    Args:
        hidden_size: Number of features; sets the length of ``gamma``/``beta``.
        init_alpha: Starting value of the scalar ``alpha``.
    """

    def __init__(self, hidden_size: int, init_alpha: float = 0.5) -> None:
        super().__init__()
        # Registration order (alpha, gamma, beta) is kept stable on purpose.
        self.alpha = nn.Parameter(torch.full((1,), float(init_alpha)))
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the scaled tanh followed by the per-feature affine transform."""
        squashed = torch.tanh(x * self.alpha)
        return squashed * self.gamma + self.beta

In [None]:
class Head(nn.Module):
    """Classification head: Linear -> DyT -> Linear -> DyT -> Linear.

    The input is projected to 128 features, then 64, then to ``num_classes``
    outputs, with a ``DyT`` layer after each of the first two projections.

    Args:
        hidden_size: Size of the input feature vector.
        num_classes: Number of output logits.
    """

    def __init__(self, hidden_size: int, num_classes: int) -> None:
        super().__init__()
        # Layers are created in forward order so submodule indices stay 0..4.
        layers = [
            nn.Linear(hidden_size, 128),
            DyT(128),
            nn.Linear(128, 64),
            DyT(64),
            nn.Linear(64, num_classes),
        ]
        self.seq = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head's output for the input features ``x``."""
        logits = self.seq(x)
        return logits

In [None]:
# Input width of the head; assumed to equal the size of the encoder's pooled
# embeddings — TODO confirm against the encoder configuration.
EMBEDDING_DIM = 384
# One logit per class (finance, health, law).
NUM_CLASSES = 3

# Attaching the head as an attribute registers it as a submodule of `model`.
classifier_head = Head(EMBEDDING_DIM, NUM_CLASSES)
model.head = classifier_head.to(model.device)

In [None]:
# Optimizer for the frozen-encoder stage: it only ever sees the head's parameters.
optim_frozen = AdamW(
    params=model.head.parameters(),
    lr=FREEZE_TRAIN_LR,
    weight_decay=0.01,
)
# Cross-entropy between the head's outputs and the integer class labels.
criterion = nn.CrossEntropyLoss()
# Gradient scaler; not referenced by the visible frozen-stage training loop.
scaler = torch.GradScaler()

In [None]:
# Fixed seed so the stratified 90/10 split is identical on every
# Restart-and-Run-All; 42 matches the seed used to shuffle `combined_df`.
SPLIT_SEED = 42

# NOTE: `combined_dataset` is rebound from a Dataset to the split container
# returned by `train_test_split`, so this cell cannot be re-run on its own.
combined_dataset = combined_dataset.train_test_split(
    test_size=0.1, stratify_by_column="label", seed=SPLIT_SEED
)
train_dataset = combined_dataset["train"]
test_dataset = combined_dataset["test"]

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only hosts.
    pin_memory=torch.cuda.is_available(),
)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
def train_frozen() -> None:
    """Train ``model.head`` on top of the frozen encoder until early stopping.

    Every parameter of ``model`` is frozen and the encoder is put in eval
    mode; only the parameters of ``model.head`` are re-enabled and updated by
    ``optim_frozen``. After each pass over ``train_dataloader`` the
    example-weighted mean cross-entropy over ``test_dataloader`` is printed.
    Training stops once that loss has failed to beat the best value seen so
    far for ``EARLY_STOP_PATIENCE`` consecutive evaluations.

    Raises:
        ValueError: If ``test_dataloader`` yields no examples, since no
            evaluation loss can be computed.
    """
    # The head is a submodule of `model`, so freeze everything first and then
    # re-enable gradients for the head only.
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    for param in model.head.parameters():
        param.requires_grad = True

    def embed(texts: list) -> torch.Tensor:
        """Return L2-normalized, mean-pooled encoder embeddings for ``texts``."""
        encoded = tokenizer_func(texts).to(model.device)
        # Call the module (not `.forward`) so registered hooks are executed.
        encoder_output = model(**encoded)
        pooled = mean_pooling(encoder_output, encoded["attention_mask"])
        return F.normalize(pooled, p=2, dim=1)

    best_test_loss = float("inf")
    patience_counter = 0

    while True:
        # --- one training pass over the train split ---
        model.head.train()
        for batch in tqdm(train_dataloader):
            optim_frozen.zero_grad()
            target = batch["label"].to(model.device)

            # The encoder is frozen, so no graph is needed for the embeddings.
            with torch.no_grad():
                embeddings = embed(batch["messages"])

            loss = criterion(model.head(embeddings), target)
            loss.backward()
            optim_frozen.step()

        # --- evaluation pass used for early stopping ---
        model.head.eval()
        loss_sum = 0.0
        example_count = 0
        with torch.inference_mode():
            for batch in test_dataloader:
                target = batch["label"].to(model.device)
                logits = model.head(embed(batch["messages"]))
                # `criterion` returns a per-batch mean; weight it by the batch
                # size so a short final batch does not skew the average.
                batch_examples = target.size(0)
                loss_sum += criterion(logits, target).item() * batch_examples
                example_count += batch_examples

        if example_count == 0:
            raise ValueError(
                "test_dataloader yielded no examples; cannot compute the "
                "evaluation loss for early stopping"
            )

        test_loss = loss_sum / example_count
        print(f"Test loss: {test_loss}")

        if test_loss < best_test_loss:
            best_test_loss = test_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= EARLY_STOP_PATIENCE:
            break

In [None]:
# Run the head-only training stage; returns once early stopping triggers.
train_frozen()

In [None]:
# CSV file backing each split of the two source datasets.
jigsaw_splits = {'train': 'train_dataset.csv', 'validation': 'val_dataset.csv', 'test': 'test_dataset.csv'}
olid_splits = {'train': 'train.csv', 'test': 'test.csv'}

# Jigsaw flag columns: a comment is kept when at least one flag equals 1.
TOXICITY_COLUMNS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
# Only these columns are carried into the inference set. The two sources have
# different schemas, so concatenating the full frames would NaN-pad every
# column that exists in only one of them and pass that on to the Dataset.
KEPT_COLUMNS = ["prompt", "label"]
# Fixed seed so the shuffled order is reproducible across kernel restarts.
INFERENCE_SHUFFLE_SEED = 42

jigsaw_df = pd.read_csv("hf://datasets/Arsive/toxicity_classification_jigsaw/" + jigsaw_splits["validation"])
olid_df = pd.read_csv("hf://datasets/christophsonntag/OLID/" + olid_splits["train"])

# Keep the Jigsaw rows flagged in at least one toxicity column.
is_flagged = jigsaw_df[TOXICITY_COLUMNS[0]] == 1
for flag_column in TOXICITY_COLUMNS[1:]:
    is_flagged = is_flagged | (jigsaw_df[flag_column] == 1)
jigsaw_df = jigsaw_df[is_flagged].rename(columns={"comment_text": "prompt"})

olid_df = olid_df.rename(columns={"cleaned_tweet": "prompt"})

# NOTE(review): every row from both sources is labelled 0, which is the
# "finance" class id of the classifier — confirm this is the intended target.
jigsaw_df["label"] = 0
olid_df["label"] = 0

inference_df = pd.concat(
    [jigsaw_df[KEPT_COLUMNS], olid_df[KEPT_COLUMNS]], ignore_index=True
)
inference_df = inference_df.sample(
    frac=1, random_state=INFERENCE_SHUFFLE_SEED
).reset_index(drop=True)

In [None]:
# Wrap the inference frame as a Dataset and iterate it in fixed order.
inference_dataset = Dataset.from_pandas(inference_df)
inference_dataloader = DataLoader(
    dataset=inference_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
def evaluate_accuracy() -> float:
    """Print overall and per-class accuracy of the encoder + head classifier.

    Overall accuracy is measured over ``inference_dataloader``; per-class
    accuracy is measured over ``test_dataloader`` using the class names of the
    test split's ``label`` feature.

    Returns:
        Overall accuracy over ``inference_dataloader`` as a percentage, or
        ``nan`` when that loader yields no examples.
    """
    model.eval()  # Set model to evaluation mode

    def predict(texts: list) -> torch.Tensor:
        """Return the predicted class index for each text in ``texts``."""
        encoded = tokenizer_func(texts).to(model.device)
        # Call the module (not `.forward`) so registered hooks are executed.
        encoder_output = model(**encoded)
        pooled = mean_pooling(encoder_output, encoded["attention_mask"])
        logits = model.head(F.normalize(pooled, p=2, dim=1))
        _, predicted = torch.max(logits, 1)
        return predicted

    correct = 0
    total = 0

    with torch.inference_mode():
        for batch in tqdm(inference_dataloader, desc="Evaluating"):
            target = batch["label"].to(model.device)
            # The inference frame stores its text under "prompt" (its source
            # columns are renamed to that), not "messages" like the train/test
            # splits.
            predicted = predict(batch["prompt"])

            total += target.size(0)
            correct += (predicted == target).sum().item()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = 100 * correct / total if total else float("nan")
    print(f"Test Accuracy: {accuracy:.2f}%")

    # Per-class accuracy, sized from the label feature rather than a literal 3.
    class_names = combined_dataset["test"].features["label"].names
    class_correct = [0] * len(class_names)
    class_total = [0] * len(class_names)

    with torch.inference_mode():
        for batch in tqdm(test_dataloader, desc="Class accuracy"):
            target = batch["label"].to(model.device)
            predicted = predict(batch["messages"])

            # Transfer each batch to Python lists once instead of calling
            # `.item()` for every individual sample.
            labels = target.tolist()
            hits = (predicted == target).tolist()
            for label, hit in zip(labels, hits):
                class_total[label] += 1
                if hit:
                    class_correct[label] += 1

    for class_name, n_correct, n_total in zip(class_names, class_correct, class_total):
        # A class with no test examples reports nan rather than dividing by zero.
        class_acc = 100 * n_correct / n_total if n_total else float("nan")
        print(f"Accuracy of {class_name}: {class_acc:.2f}%")

    return accuracy


# Run the evaluation
accuracy = evaluate_accuracy()

In [None]:
# Display the overall accuracy (percent) returned by evaluate_accuracy().
accuracy

In [None]:
def plot_confusion_matrix() -> None:
    """Plot the classifier's confusion matrix over ``test_dataloader``.

    Rows are true classes and columns are predicted classes, both labelled
    with the class names of the test split's ``label`` feature.
    """
    class_names = combined_dataset["test"].features["label"].names
    all_predictions = []
    all_targets = []

    # Set eval mode here rather than relying on an earlier cell having done so.
    model.eval()

    with torch.inference_mode():
        for batch in tqdm(test_dataloader, desc="Collecting predictions"):
            encoded = tokenizer_func(batch["messages"]).to(model.device)
            target = batch["label"].to(model.device)

            # Call the module (not `.forward`) so registered hooks are executed.
            encoder_output = model(**encoded)
            pooled = mean_pooling(encoder_output, encoded["attention_mask"])
            logits = model.head(F.normalize(pooled, p=2, dim=1))

            _, predicted = torch.max(logits, 1)

            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    # Create confusion matrix. Passing `labels` pins its shape to one row and
    # column per class, so the tick labels stay aligned even if some class
    # never appears among the targets or predictions.
    cm = confusion_matrix(
        all_targets, all_predictions, labels=list(range(len(class_names)))
    )

    # Plot confusion matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()


# Plot the confusion matrix
plot_confusion_matrix()