# Prompt Classification Model Training

This notebook implements the training pipeline for multiple classification models to identify domain-specific prompts. We compare different approaches:

## Models Evaluated
- **GPT-based classifier**: Using LLM for zero/few-shot classification
- **ModernBERT NLI**: Fine-tuned BERT model for Natural Language Inference
- **SVM**: Using different text embeddings (BAAI-BGE, MiniLM, TF-IDF)
- **XGBoost**: Using different text embeddings
- **FastText**: Specialized text classification model

## Domains
- Law
- Healthcare
- Finance

## Evaluation Metrics
- Accuracy
- Latency
- Cost (for LLM-based approaches)

In [None]:
import gc
import os
import pickle
import random
import statistics
import time
from functools import partial

import numpy as np
import onnxruntime as ort
import pandas as pd
import torch
from accelerate.data_loader import DataLoader
from datasets import ClassLabel, Dataset
from dotenv import load_dotenv
from fastembed import TextEmbedding
from sklearn import metrics
from sklearn.feature_extraction.text import TfidfVectorizer
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from transformers import AutoTokenizer

# Step up one directory so the relative "data/..." and "models/..." paths used
# below (and, presumably, the local `prompt_classifier` imports that follow)
# resolve. NOTE(review): assumes the package directory sits in the target
# directory - confirm the repository layout.
# The guard keeps the cell idempotent: re-running it no longer climbs one more
# level each time.
if not os.path.isdir("prompt_classifier"):
    os.chdir('..')
from prompt_classifier.metrics import evaluate_run
from prompt_classifier.modeling.dspy_llm import LlmClassifier
from prompt_classifier.modeling.fasttext import FastTextClassifier
from prompt_classifier.modeling.nli_modernbert import ModernBERTNLI
from prompt_classifier.modeling.widemlp import MLP, prepare_inputs
from prompt_classifier.util import create_domain_dataset, train_and_evaluate_model

load_dotenv()

# Seed every RNG the notebook draws from, not only the stdlib one:
# pandas `.sample()` falls back to NumPy's global state, while DataLoader
# shuffling and model weight initialisation use PyTorch's generator.
random.seed(22)
np.random.seed(22)
torch.manual_seed(22)

In [None]:
# Mini-batch size passed to the train/test DataLoaders built further down.
BATCH_SIZE = 32

## Hardware Acceleration

Check available ONNX Runtime providers for hardware acceleration (CPU/CUDA).
This affects the performance of embedding models and ModernBERT.

In [None]:
# Memory tracking
def print_gpu_memory():
    """Print PyTorch's allocated and reserved CUDA memory in gigabytes.

    Nothing is printed when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return

    allocated_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"GPU memory allocated: {allocated_gb:.2f} GB")
    print(f"GPU memory cached: {reserved_gb:.2f} GB")


def clear_gpu_memory():
    """Release unreferenced Python objects and PyTorch's cached CUDA blocks.

    Garbage collection runs first, so tensors that were only kept alive by
    reference cycles are freed before the CUDA caching allocator is asked to
    hand its unused blocks back. Collection also runs on CPU-only hosts, where
    it still frees host memory.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
# Get list of available ONNX Runtime providers (CPU, CUDA etc.)
# The result is kept in `providers` and printed so the reader can see which
# execution providers this environment exposes.
providers = ort.get_available_providers()

print(providers)

## Dataset Preparation

### Training Data Types
1. **Processed Data**: Clean, filtered dataset used for traditional ML models
2. **Interim Data**: Raw/intermediate data used for LLM experiments

Each domain dataset is balanced with positive samples from target domain and negative samples from other domains.

In [None]:
# Read the processed prompt table of each domain
law_prompts = pd.read_csv("data/processed/law_prompts.csv")
healthcare_prompts = pd.read_csv("data/processed/healthcare_prompts.csv")
finance_prompts = pd.read_csv("data/processed/finance_prompts.csv")

# Build one dataset per domain. The first argument is the target domain's
# frame and the list holds the frames of the remaining domains.
datasets = {
    "law": create_domain_dataset(
        law_prompts, [healthcare_prompts, finance_prompts]
    ),
    "healthcare": create_domain_dataset(
        healthcare_prompts, [law_prompts, finance_prompts]
    ),
    "finance": create_domain_dataset(
        finance_prompts, [law_prompts, healthcare_prompts]
    ),
}

# Keep the individual names available as well
law_dataset = datasets["law"]
healthcare_dataset = datasets["healthcare"]
finance_dataset = datasets["finance"]

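Label encoding for the combined multi-class dataset: 0 = Finance, 1 = Law, 2 = Healthcare, 3 = Uncertain. Only labels 0–2 are assigned in the next cell; no rows are labeled 3 (Uncertain) in this notebook.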

In [None]:
# Integer encoding of the three domains
finance_prompts["label"] = 0
law_prompts["label"] = 1
healthcare_prompts["label"] = 2

combined_df = pd.concat([finance_prompts, healthcare_prompts, law_prompts], ignore_index=True)
combined_dataset = Dataset.from_pandas(combined_df)

# ClassLabel names are positional: names[i] describes integer label i, so the
# order has to follow the encoding above (0=finance, 1=law, 2=healthcare).
combined_dataset = combined_dataset.cast_column(
    "label", ClassLabel(num_classes=3, names=["finance", "law", "health"])
)

# Stratified 85/15 split. The fixed seed (same value as random.seed above)
# makes the partition reproducible across kernel restarts.
combined_dataset = combined_dataset.train_test_split(
    test_size=0.15, stratify_by_column="label", seed=22
)

In [None]:
train_dataset = combined_dataset["train"]
test_dataset = combined_dataset["test"]

# Pinned host memory only speeds up host-to-GPU copies; requesting it on a
# machine without CUDA merely produces a warning, so enable it conditionally.
pin_memory = torch.cuda.is_available()

# Training batches are reshuffled every epoch; the test loader keeps a fixed order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=1,
    pin_memory=pin_memory,
)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Calculate class weights to handle class imbalance
labels = combined_dataset["train"]["label"]
unique, counts = np.unique(labels, return_counts=True)
total_samples = len(labels)

# Compute weights inversely proportional to class frequencies
class_weights = total_samples / (len(unique) * counts)
class_weights = torch.tensor(class_weights, dtype=torch.float)

# Display names keyed by the integer encoding used when the combined frame was
# built (finance=0, law=1, healthcare=2). Looking names up by class id instead
# of by position keeps each weight next to the right name even if a class is
# missing from the training split.
label_names = {0: "finance", 1: "law", 2: "health"}

print("Class weights:")
for class_id, weight in zip(unique.tolist(), class_weights.tolist()):
    print(f"Class {class_id} ({label_names.get(class_id, 'unknown')}): {weight:.4f}")



In [None]:
# Read the interim prompt table of each domain
law_prompts_interim = pd.read_csv("data/interim/law_prompts.csv")
healthcare_prompts_interim = pd.read_csv("data/interim/healthcare_prompts.csv")
finance_prompts_interim = pd.read_csv("data/interim/finance_prompts.csv")

# Build one dataset per domain. The first argument is the target domain's
# frame and the list holds the frames of the remaining domains.
datasets_interim = {
    "law": create_domain_dataset(
        law_prompts_interim, [healthcare_prompts_interim, finance_prompts_interim]
    ),
    "healthcare": create_domain_dataset(
        healthcare_prompts_interim, [law_prompts_interim, finance_prompts_interim]
    ),
    "finance": create_domain_dataset(
        finance_prompts_interim, [law_prompts_interim, healthcare_prompts_interim]
    ),
}

# Keep the individual names available as well
law_dataset_interim = datasets_interim["law"]
healthcare_dataset_interim = datasets_interim["healthcare"]
finance_dataset_interim = datasets_interim["finance"]

## Initialize Embedding Models

Set up text embedding models:
- BAAI BGE Small
- MiniLM
- TF-IDF

In [None]:
# Prefer the CUDA provider when ONNX Runtime exposes it and always keep the CPU
# provider as a fallback, so the notebook also runs on machines without a GPU
# instead of depending on a hard-coded "CUDAExecutionProvider".
available_providers = ort.get_available_providers()
embedding_providers = [
    provider
    for provider in ("CUDAExecutionProvider", "CPUExecutionProvider")
    if provider in available_providers
]

# Dense sentence-embedding models served through fastembed
baai_embedding = TextEmbedding(
    model_name="BAAI/bge-small-en-v1.5", providers=embedding_providers
)
mini_embedding = TextEmbedding(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    providers=embedding_providers,
)

# Sparse TF-IDF features, limited to 20,000 terms. The vectorizer is created
# unfitted here; fitting happens later on each domain's training split.
tfidf_embedding = TfidfVectorizer(
    max_features=20_000,
)

In [None]:
# Show what each embedding model's underlying ONNX session returns from
# get_providers().
print(f"BAAI-BGE available providers: {baai_embedding.model.model.get_providers()}")
print(f"MiniLM available providers: {mini_embedding.model.model.get_providers()}")

# Torch device used by the PyTorch models below: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# With suppress_errors enabled TorchDynamo falls back to eager execution rather
# than raising when compilation fails; set here for ModernBERT.
torch._dynamo.config.suppress_errors = True

## Model Training and Evaluation

### LLM-based Models
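First, we evaluate the LLM-based classifier and ModernBERT using interim data. The LLM run below uses a local `qwen2.5:14b` model served through Ollama; a GPT (`gpt-4o-mini`) configuration is also included but currently disabled: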

In [None]:
# Evaluate the DSPy LLM classifier and ModernBERT NLI on each interim dataset.
for domain, dataset in datasets_interim.items():
    # ---------------------------- LLM classifier ----------------------------
    # 800 rows are handed to the classifier as train_data and 4000 different
    # rows as test_data (sampled rows are removed before the second draw).
    train_data = dataset.sample(n=800)
    test_data = dataset.drop(train_data.index).sample(n=4000)

    # Local model served by Ollama. Hosted alternative that was previously kept
    # here as a disabled snippet:
    #   api_key=os.getenv("OPENAI_API_KEY"), api_base=os.getenv("PROXY_URL"),
    #   model_name="gpt-4o-mini"
    llm_classifier = LlmClassifier(
        api_key="",
        api_base="http://localhost:11434",
        model_name="ollama_chat/qwen2.5:14b",
        domain=domain,
        train_data=train_data,
        test_data=test_data,
    )
    try:
        # DSPy optimization
        llm_classifier.optimize_model()

        # Get predictions and metrics for test data
        test_predictions, test_actuals, test_latency = llm_classifier.predict()
        test_latency = statistics.mean(test_latency)

        test_acc = metrics.accuracy_score(test_actuals, test_predictions)

        # No separate training accuracy is computed for this model, so the
        # test accuracy is what gets reported in the train_acc field.
        evaluate_run(
            predictions=test_predictions,
            true_labels=test_actuals,
            domain=domain,
            model_name="qwen2.5:14b",
            embed_model="qwen-base",
            cost=llm_classifier.cost,
            latency=test_latency,
            train_acc=test_acc,
            training=True,
        )

        llm_classifier.save_model(f"models/qwen2.5:14b_{domain}.json")

    except Exception as e:
        # Best effort: report the failure and continue with ModernBERT.
        print(f"Error running LLM model: {e}")

    # ------------------------- ModernBERT classifier ------------------------
    try:
        test_data = dataset.sample(n=30_000)
        bert_classifier = ModernBERTNLI(domain=domain)
        # Use the device selected earlier rather than a hard-coded "cuda", so
        # this section also runs on CPU-only machines.
        bert_classifier.classifier.model.to(device)

        # Predict one prompt at a time, timing only the predict() call.
        test_predictions = []
        test_times = []
        for prompt in tqdm(test_data["prompt"], total=len(test_data)):
            start_time = time.perf_counter_ns()
            pred = bert_classifier.predict(prompt)
            test_times.append(time.perf_counter_ns() - start_time)
            test_predictions.append(pred)

        # Show a short preview instead of dumping all 30,000 predictions.
        print(f"ModernBERT {domain} predictions (first 10): {test_predictions[:10]}")
        test_acc = metrics.accuracy_score(test_data["label"], test_predictions)
        mean_prediction_time = statistics.mean(test_times)

        # As above, the test accuracy is reported in the train_acc field.
        evaluate_run(
            predictions=test_predictions,
            true_labels=test_data["label"],
            domain=domain,
            model_name="modernbert",
            embed_model="bert-base",
            latency=mean_prediction_time,
            train_acc=test_acc,
            training=True,
        )
    except Exception as e:
        # Best effort: report the failure and continue with the next domain.
        print(f"Error running ModernBERT model: {e}")

### Traditional ML Models

Evaluate SVM, XGBoost and FastText using processed data with different embedding approaches:
- BAAI-BGE: Dense semantic embeddings
- MiniLM: Lightweight sentence embeddings
- TF-IDF: Sparse word frequency embeddings
- FastText: Custom subword embeddings

Models are trained on 70% of data and evaluated on remaining 30%.

In [None]:
# Embedding back-ends to compare, keyed by the short name that is later used in
# saved model file names and passed as embed_model. The "tf_idf" entry is a
# scikit-learn vectorizer and is handled separately from the two fastembed
# models in the training loop.
embedding_models = {
    "mini": mini_embedding,
    "tf_idf": tfidf_embedding,
    "baai": baai_embedding,
}

In [None]:
# NOTE(review): these two statements repeat the provider report already printed
# right after the embedding models were created.
print(f"BAAI-BGE available providers: {baai_embedding.model.model.get_providers()}")
print(f"MiniLM available providers: {mini_embedding.model.model.get_providers()}")

In [None]:
# Train and evaluate SVM / XGBoost on every embedding, plus fastText, per domain.
for domain, dataset in datasets.items():
    # 70/30 split. The held-out rows have to be selected with the ORIGINAL
    # index labels of the sampled rows, so both partitions are re-indexed only
    # after the split. (Resetting the training index first makes drop() remove
    # the first 70% of rows instead of the sampled ones, which leaks training
    # rows into the test set.) Assumes `dataset` has a unique index.
    train_sample = dataset.sample(frac=0.7)
    test_data = dataset.drop(train_sample.index).reset_index(drop=True)
    train_data = train_sample.reset_index(drop=True)
    total_rows = len(train_data) + len(test_data)

    for model_name, embedding_model in embedding_models.items():
        if model_name == "tf_idf":
            # Fit on training data only, and persist the fitted vectorizer
            embedding_model.fit(train_data["prompt"])

            with open(f"models/tfidf_{domain}.pkl", "wb") as f:
                pickle.dump(embedding_model, f)

            # Only transform() is timed; the output of transform() is passed
            # on unchanged.
            start_time = time.perf_counter_ns()
            train_embeds = embedding_model.transform(train_data["prompt"])
            test_embeds = embedding_model.transform(test_data["prompt"])
        else:
            # Time the embedding of both splits
            start_time = time.perf_counter_ns()
            train_embeds = np.array(list(embedding_model.embed(train_data["prompt"])))
            test_embeds = np.array(list(embedding_model.embed(test_data["prompt"])))
        embed_time_ns = time.perf_counter_ns() - start_time

        # Mean embedding time per prompt (ns) over every embedded row.
        # Adding the two frames and taking len() does not give that count.
        mean_embed_time = embed_time_ns / total_rows

        # Verify shapes
        print(f"Training {model_name} embeddings on {domain} domain")
        print(f"Train shape: {train_embeds.shape}, Test shape: {test_embeds.shape}")
        print(type(train_embeds))

        # Both classifiers take the same inputs; only the name and the file
        # extension of the saved model differ.
        for classifier_name, extension in (("SVM", "pkl"), ("XGBoost", "json")):
            try:
                train_and_evaluate_model(
                    model_name=classifier_name,
                    train_embeds=train_embeds,
                    test_embeds=test_embeds,
                    train_labels=train_data["label"],
                    test_labels=test_data["label"],
                    domain=domain,
                    embed_model=model_name,
                    save_path=f"models/{classifier_name}_{domain}_{model_name}.{extension}",
                    embedding_time=mean_embed_time,
                    training=True,
                )
            except Exception as e:
                # Best effort: report and continue with the next model.
                print(f"Error running {classifier_name} model: {e}")

    # fastText
    try:
        fasttext_classifier = FastTextClassifier(
            train_data=train_data, test_data=test_data
        )
        fasttext_classifier.train()

        # Training accuracy. Newlines are removed from each prompt before
        # predict() is called, as in the test loop below.
        train_predictions = []
        for prompt in train_data["prompt"]:
            query = str(prompt).replace("\n", "")
            prediction = fasttext_classifier.model.predict(query)
            train_predictions.append(1 if prediction[0][0] == "__label__1" else 0)

        train_acc = metrics.accuracy_score(train_data["label"], train_predictions)

        # Test predictions, timing only the predict() call
        actuals = test_data["label"].tolist()
        predictions = []
        prediction_times = []
        for prompt in tqdm(test_data["prompt"], total=len(test_data)):
            query = str(prompt).replace("\n", "")

            start_time = time.perf_counter_ns()
            prediction = fasttext_classifier.model.predict(query)
            end_time = time.perf_counter_ns()

            prediction_times.append(end_time - start_time)
            predictions.append(1 if prediction[0][0] == "__label__1" else 0)

        mean_prediction_time = statistics.mean(prediction_times)

        evaluate_run(
            predictions,
            true_labels=actuals,
            domain=domain,
            model_name="fastText",
            embed_model="fastText",
            latency=mean_prediction_time,
            train_acc=train_acc,
            training=True,
        )

        fasttext_classifier.model.save_model(f"models/fastText_{domain}_fasttext.bin")
    except Exception as e:
        # Best effort: report and continue with the next domain.
        print(f"Error running fastText model: {e}")

### Wide MLP Model

In [None]:
# Tokenizer used by the MLP cells below: it turns prompts into input ids, and
# its vocab_size is passed to the model constructor.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

In [None]:
# Training settings
vocab_size = 20_000
num_epochs = 10
batch_size = 32
learning_rate = 1e-3

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def evaluate_epoch(model, data_loader, device):
    """Run one gradient-free pass over ``data_loader`` and compute metrics.

    Args:
        model: Called as ``model(flat_inputs, offsets, labels)`` and expected
            to return ``(loss, outputs)``; predictions are the argmax of
            ``outputs`` along dimension 1.
        data_loader: Iterable of batches exposing ``batch['prompt']`` (texts)
            and ``batch['label']`` (class indices).
        device: Torch device on which labels and token ids are placed.

    Returns:
        dict with ``loss`` (mean of the per-batch losses), ``accuracy``,
        ``macro_f1`` and ``weighted_f1``.

    Raises:
        ValueError: if ``data_loader`` yields no batches.

    Notes:
        Uses the module-level ``tokenizer``. The model is switched to eval
        mode and left there; callers must call ``model.train()`` again before
        resuming training.
    """
    model.eval()
    predictions = []
    actuals = []
    running_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in data_loader:
            texts = batch['prompt']
            # as_tensor accepts both lists and already-collated tensors
            # without the copy-construct warning torch.tensor() emits.
            labels = torch.as_tensor(batch['label'], device=device)

            # Tokenize and prepare inputs
            encoded = tokenizer(texts, padding=True, truncation=True)
            input_ids = torch.tensor(encoded['input_ids'], device=device)
            flat_inputs, offsets = prepare_inputs(input_ids, device)

            loss, outputs = model(flat_inputs, offsets, labels)
            preds = torch.argmax(outputs, dim=1)

            running_loss += loss.item()
            num_batches += 1
            predictions.extend(preds.cpu().tolist())
            actuals.extend(labels.cpu().tolist())

    if num_batches == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError.
        raise ValueError("data_loader yielded no batches; cannot compute metrics")

    metrics_dict = {
        'loss': running_loss / num_batches,
        'accuracy': metrics.accuracy_score(actuals, predictions),
        'macro_f1': metrics.f1_score(actuals, predictions, average='macro'),
        'weighted_f1': metrics.f1_score(actuals, predictions, average='weighted')
    }

    return metrics_dict

In [None]:
def format_metrics(metrics_dict):
    """Format a metrics mapping as ``name: value`` pairs joined by ' | '.

    Every value is rendered with four decimal places, in the mapping's
    iteration order.
    """
    parts = []
    for name, value in metrics_dict.items():
        parts.append(f"{name}: {value:.4f}")
    return ' | '.join(parts)

# Initialize model with 3 classes
# NOTE(review): problem_type is "multi_label_classification" although each
# prompt carries a single class index and predictions are taken with argmax -
# confirm this is the loss MLP is meant to use here.
model = MLP(vocab_size=tokenizer.vocab_size, num_classes=3, problem_type="multi_label_classification")
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Halve the learning rate after 2 epochs without improvement in the monitored
# loss. The deprecated `verbose` flag is dropped; the learning rate is printed
# explicitly after every epoch below.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# Single source of truth for the checkpoint location, so the file that is
# written and the path reported at the end cannot drift apart. It lives in
# "models/" like every other artifact saved by this notebook.
best_model_path = "models/mlp_best_model.pt"
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)
best_val_loss = None  # lowest loss seen so far on the held-out split

print(f"Training on device: {device}")
print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

for epoch in range(num_epochs):
    model.train()  # evaluate_epoch() leaves the model in eval mode
    epoch_losses = []

    # Training phase
    with tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
        for i, batch in pbar:
            texts = batch['prompt']
            # as_tensor accepts both lists and already-collated tensors
            # without the copy-construct warning torch.tensor() emits.
            labels = torch.as_tensor(batch['label'], device=device)

            # Tokenize and prepare inputs
            encoded = tokenizer(texts, padding=True, truncation=True)
            input_ids = torch.tensor(encoded['input_ids'], device=device)
            flat_inputs, offsets = prepare_inputs(input_ids, device)

            optimizer.zero_grad()
            loss, _ = model(flat_inputs, offsets, labels)
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())

            # Refresh the progress-bar loss every 50 batches only
            if i % 50 == 0:
                pbar.set_postfix({'loss': f"{loss.item():.4f}"})

    # Evaluation phase
    print(f"\nEpoch {epoch+1} Results:")
    print("-" * 40)
    train_metrics = evaluate_epoch(model, train_dataloader, device)
    print(f"Train | {format_metrics(train_metrics)}")
    # The held-out test split doubles as the validation set for LR scheduling
    # and checkpoint selection.
    val_metrics = evaluate_epoch(model, test_dataloader, device)
    print(f"Test  | {format_metrics(val_metrics)}\n")

    scheduler.step(val_metrics['loss'])
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Learning rate: {current_lr:.2e}\n")

    # Save model if it's the best so far based on validation loss
    if best_val_loss is None or val_metrics['loss'] < best_val_loss:
        best_val_loss = val_metrics['loss']
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_metrics': val_metrics,
        }, best_model_path)

print("Training completed!")
print(f"Best model saved to: {best_model_path}")