In [1]:
import os
import pickle as pkl
import random
import time

import fasttext
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from dotenv import load_dotenv
from fastembed import TextEmbedding
from sklearn import metrics
from tqdm import tqdm
from xgboost import XGBClassifier

from prompt_classifier.modeling.dspy_llm import LlmClassifier
from prompt_classifier.modeling.fasttext import FastTextClassifier
from prompt_classifier.modeling.nli_modernbert import ModernBERTNLI

load_dotenv()
random.seed(1)

In [2]:
jigsaw_splits = {'train': 'train_dataset.csv', 'validation': 'val_dataset.csv', 'test': 'test_dataset.csv'}
inference_df = pd.read_csv("hf://datasets/Arsive/toxicity_classification_jigsaw/" + jigsaw_splits["validation"])

olid_splits = {'train': 'train.csv', 'test': 'test.csv'}
olid_df = pd.read_csv("hf://datasets/christophsonntag/OLID/" + olid_splits["train"])

inference_df = inference_df[(inference_df["toxic"] == 1) | 
                            (inference_df["severe_toxic"] == 1) | 
                            (inference_df["obscene"] == 1) | 
                            (inference_df["threat"] == 1) | 
                            (inference_df["insult"] == 1) | 
                            (inference_df["identity_hate"] == 1)]

olid_df = olid_df.rename(columns={"cleaned_tweet": "prompt"})
olid_df["label"] = 0

inference_df = inference_df.rename(columns={"comment_text": "prompt"})
inference_df["label"] = 0

inference_df = pd.concat([inference_df, olid_df], ignore_index=True)
inference_df = inference_df.sample(frac=1).reset_index(drop=True)

In [3]:
inference_df = inference_df[["prompt", "label"]]
inference_df = inference_df.dropna(subset=['prompt'])

In [4]:
baai_embedding = TextEmbedding(
    model_name="BAAI/bge-small-en-v1.5",
    providers=["CUDAExecutionProvider"]
)

[0;93m2025-03-31 14:26:10.285378521 [W:onnxruntime:, session_state.cc:1168 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.[m
[0;93m2025-03-31 14:26:10.285426626 [W:onnxruntime:, session_state.cc:1170 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.[m


In [None]:
# GPT Classifier
llm_classifier = LlmClassifier(
    api_key=os.getenv("OPENAI_API_KEY"),
    api_base=os.getenv("PROXY_URL"),
    model_name="gpt-4o-mini",
    domain="finance or healthcare or law",
    train_data=inference_df,
    test_data=inference_df,
)

try:
    llm_classifier.load_model("models/gpt-4o-mini-finance.json")

    test_predictions, test_actuals, test_latency = llm_classifier.predict()

except Exception as e:
    print(f"Error running GPT model: {e}")

In [None]:
bert_classifier_finance = ModernBERTNLI(domain="finance")
bert_classifier_healthcare = ModernBERTNLI(domain="healthcare") 
bert_classifier_law = ModernBERTNLI(domain="law")

predictions_bert = []
prediction_times_bert = []

try:
    # Move models to GPU
    bert_classifier_finance.classifier.model.to("cuda")
    bert_classifier_healthcare.classifier.model.to("cuda")
    bert_classifier_law.classifier.model.to("cuda")

    # Get predictions for each prompt
    for _, row in tqdm(inference_df.iterrows(), total=len(inference_df)):
        start_time = time.perf_counter_ns()
        
        # Get predictions from all models
        pred_finance = bert_classifier_finance.predict(row["prompt"])
        pred_healthcare = bert_classifier_healthcare.predict(row["prompt"])
        pred_law = bert_classifier_law.predict(row["prompt"])
        
        end_time = time.perf_counter_ns()
        prediction_times_bert.append(end_time - start_time)
        
        # If any model predicts 1, final prediction is 0
        predictions_bert.append(0 if (pred_finance == 1 or pred_healthcare == 1 or pred_law == 1) else 1)

except Exception as e:
    print(f"Error running ModernBERT models: {e}")

In [5]:
actuals_ft = []
predictions_ft = []
prediction_times_ft = []

# fastText
fasttext_classifier_finance = FastTextClassifier(train_data=inference_df, test_data=inference_df)
fasttext_classifier_finance.model = fasttext.load_model("models/fastText_finance_fasttext.bin")

fasttext_classifier_healthcare = FastTextClassifier(train_data=inference_df, test_data=inference_df)
fasttext_classifier_healthcare.model = fasttext.load_model("models/fastText_healthcare_fasttext.bin")

fasttext_classifier_law = FastTextClassifier(train_data=inference_df, test_data=inference_df)
fasttext_classifier_law.model = fasttext.load_model("models/fastText_law_fasttext.bin")

for _, row in tqdm(inference_df.iterrows(), total=len(inference_df)):
    text = str(row["prompt"])
    query = text.replace("\n", "")

    start_time = time.perf_counter_ns()
    
    # Predictions from all three classifiers
    prediction_finance = fasttext_classifier_finance.model.predict(query)
    prediction_healthcare = fasttext_classifier_healthcare.model.predict(query)
    prediction_law = fasttext_classifier_law.model.predict(query)
    
    end_time = time.perf_counter_ns()
    prediction_times_ft.append(end_time - start_time)
    
    predictions_ft.append(0 if (prediction_finance[0][0] == "__label__1" or prediction_healthcare[0][0] == "__label__1" or prediction_law[0][0] == "__label__1") else 1)

    actuals_ft.append(row["label"])

100%|██████████| 16384/16384 [00:01<00:00, 12847.39it/s]


In [6]:
# Embedding test data
start_time = time.perf_counter_ns()
test_embeds = np.array(list(baai_embedding.embed(inference_df["prompt"])))
end_time = time.perf_counter_ns()
embed_times = end_time - start_time

mean_embed_time = embed_times / len(inference_df)

with open("models/SVM_finance_baai.pkl", "rb") as svm_file:
    svm_classifier_finance = pkl.load(svm_file)

with open("models/SVM_healthcare_baai.pkl", "rb") as svm_file:
    svm_classifier_healthcare = pkl.load(svm_file)

with open("models/SVM_law_baai.pkl", "rb") as svm_file:
    svm_classifier_law = pkl.load(svm_file)

xgb_classifier_finance = XGBClassifier()
xgb_classifier_healthcare = XGBClassifier()
xgb_classifier_law = XGBClassifier()

xgb_classifier_finance.load_model("models/XGBoost_finance_baai.json")
xgb_classifier_healthcare.load_model("models/XGBoost_healthcare_baai.json")
xgb_classifier_law.load_model("models/XGBoost_law_baai.json")

predictions_xgb = []
predictions_svm = []

prediction_times_xgb = []
prediction_times_svm = []

for test_embed in test_embeds:
    # SVM predictions
    start_time = time.perf_counter_ns()
    pred_finance = svm_classifier_finance.predict(test_embed.reshape(1, -1))
    pred_healthcare = svm_classifier_healthcare.predict(test_embed.reshape(1, -1))
    pred_law = svm_classifier_law.predict(test_embed.reshape(1, -1))
    end_time = time.perf_counter_ns()
    
    prediction_times_svm.append(end_time - start_time)
    # If any model predicts 1, final prediction is 0
    predictions_svm.append(0 if (pred_finance[0] == 1 or pred_healthcare[0] == 1 or pred_law[0] == 1) else 1)

    # XGBoost predictions
    start_time = time.perf_counter_ns()
    pred_finance = xgb_classifier_finance.predict(test_embed.reshape(1, -1))
    pred_healthcare = xgb_classifier_healthcare.predict(test_embed.reshape(1, -1))
    pred_law = xgb_classifier_law.predict(test_embed.reshape(1, -1))
    end_time = time.perf_counter_ns()
    
    prediction_times_xgb.append(end_time - start_time)
    # If any model predicts 1, final prediction is 0
    predictions_xgb.append(0 if (pred_finance[0] == 1 or pred_healthcare[0] == 1 or pred_law[0] == 1) else 1)

In [8]:
times_ms = {
    'fastText': np.array(prediction_times_ft) / 1_000_000,
    'XGBoost': np.array(prediction_times_xgb) / 1_000_000,
    'SVM': np.array(prediction_times_svm) / 1_000_000,
}

# Create DataFrame for plotting
plot_data = pd.DataFrame({
    'Model': [k for k,v in times_ms.items() for _ in v],
    'Latency (ms)': [x for v in times_ms.values() for x in v]
})

In [9]:
accuracy_ft = metrics.accuracy_score(actuals_ft, predictions_ft)
accuracy_svm = metrics.accuracy_score(actuals_ft, predictions_svm)
accuracy_xgb = metrics.accuracy_score(actuals_ft, predictions_xgb)

# Print the accuracies
print(f"fastText Accuracy: {accuracy_ft}")
print(f"SVM Accuracy: {accuracy_svm}")
print(f"XGBoost Accuracy: {accuracy_xgb}")

fastText Accuracy: 0.15362548828125
SVM Accuracy: 0.778564453125
XGBoost Accuracy: 0.725341796875
