In [None]:
import pickle as pkl
import statistics
import time
from pathlib import Path

import cupy as cp
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from fastembed import TextEmbedding
from sklearn.feature_extraction.text import TfidfVectorizer
from xgboost import XGBClassifier

# Load variables from a local .env file into the process environment. This runs
# before the project imports below — NOTE(review): presumably because
# dataloader/util read environment variables at import time; confirm before reordering.
load_dotenv()

# Project-local helpers: data loading (dataloader) and seeding / training /
# evaluation utilities (util).
import dataloader
import util

# Project helper. NOTE(review): presumably seeds the global RNGs that later
# pandas `.sample()` calls without `random_state` rely on — confirm in util.
util.set_seed(22)

In [None]:
# Per-domain datasets, iterated as (domain, DataFrame) pairs via .items() below;
# each frame is read through its "prompt" and "label" columns.
datasets = dataloader.get_domain_data()
# NOTE(review): eval_datasets is not used by any cell shown here — the Eval loop
# iterates `datasets`. Confirm which collection evaluation should run on.
eval_datasets = dataloader.get_eval_datasets()
# Prompts for the batch-latency benchmark; sliced into consecutive chunks below.
batch_data = dataloader.get_batch_data()

# Batch sizes swept by the batch-latency benchmark.
batch_sizes = [1, 32, 64, 128, 256]

In [None]:
# Dense sentence embedders, both pinned to the CPU execution provider.
baai_embedding = TextEmbedding(
    model_name="BAAI/bge-small-en-v1.5",
    providers=["CPUExecutionProvider"],
)
mini_embedding = TextEmbedding(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    providers=["CPUExecutionProvider"],
)

# Sparse lexical baseline; its vocabulary is fitted in the next cell.
tfidf_embedding = TfidfVectorizer()

In [None]:
# Fit the TF-IDF vocabulary on a seeded 80% sample of the first domain's prompts
# and persist it so the Eval cells can reload the exact same vectorizer.
# NOTE(review): only the first domain's prompts shape the vocabulary, yet this
# vectorizer feeds every domain's classifier — confirm that is intended.
first_dataset = next(iter(datasets.values()))["prompt"]
train_prompts = first_dataset.sample(frac=0.8, random_state=22)

tfidf_embedding.fit(train_prompts)

# Create the output directory so a fresh checkout can Restart & Run All.
Path("models").mkdir(parents=True, exist_ok=True)
with open("models/tfidf.pkl", "wb") as f:
    pkl.dump(tfidf_embedding, f)

In [None]:
# Embedding back-ends keyed by the short name that also appears in the saved
# classifier file names (models/XGBoost_{domain}_{name}.json).
embedding_models = dict(
    mini=mini_embedding,
    tf_idf=tfidf_embedding,
    baai=baai_embedding,
)

# Train

In [None]:
# Train one XGBoost classifier per (domain, embedding model) pair; each is saved
# to models/XGBoost_{domain}_{embedding}.json by util.train_and_evaluate_model.
for domain, dataset in datasets.items():
    # Give the frame a clean RangeIndex *before* sampling, then drop exactly the
    # sampled labels to form the test split. The old code reset the train index
    # first, so drop(train_data.index) removed rows 0..k-1 instead of the sampled
    # rows, leaking training rows into the test set.
    indexed = dataset.reset_index(drop=True)
    train_rows = indexed.sample(frac=0.8, random_state=22)
    train_data = train_rows.reset_index(drop=True)
    test_data = indexed.drop(index=train_rows.index).reset_index(drop=True)
    n_prompts = len(train_data) + len(test_data)

    for model_name, embedding_model in embedding_models.items():
        # Embed both splits inside one timed section (wall-clock nanoseconds).
        start_time = time.perf_counter_ns()
        if model_name == "tf_idf":
            # TfidfVectorizer.transform returns a scipy sparse matrix that is passed
            # to XGBoost as-is. NOTE: XGBoost treats absent sparse entries as
            # missing but dense zeros as values, so predict on the same
            # representation (or set missing=0.0 when predicting on dense arrays).
            train_embeds = embedding_model.transform(train_data["prompt"])
            test_embeds = embedding_model.transform(test_data["prompt"])
        else:
            train_embeds = np.array(list(embedding_model.embed(train_data["prompt"])))
            test_embeds = np.array(list(embedding_model.embed(test_data["prompt"])))
        embed_times = time.perf_counter_ns() - start_time

        # Mean embedding time per prompt, over both splits. (The old
        # len(train_data + test_data) added the frames element-wise and
        # returned only the train length.)
        mean_embed_time = embed_times / n_prompts

        util.train_and_evaluate_model(
            model_name="XGBoost",
            train_embeds=train_embeds,
            test_embeds=test_embeds,
            train_labels=train_data["label"],
            test_labels=test_data["label"],
            domain=domain,
            embed_model=model_name,
            save_path=f"models/XGBoost_{domain}_{model_name}.json",
            embedding_time=mean_embed_time,
            training=False,  # NOTE(review): confirm this flag's meaning in util
        )


# Eval

In [None]:
# Default per-domain classifiers exposed as xgb_law / xgb_finance / xgb_healthcare.
# The file paths previously interpolated the leaked `embedding_model` loop
# variable (after the Train loop, an embedding object rather than its name),
# which produced an invalid path. The embedding is now named explicitly.
EVAL_EMBED_MODEL = "baai"


def load_domain_classifier(domain: str, embed_model_name: str) -> XGBClassifier:
    """Load the saved XGBoost classifier for one (domain, embedding model) pair.

    Reads ``models/XGBoost_{domain}_{embed_model_name}.json`` into a classifier
    configured with ``tree_method="hist"`` on ``device="cuda"``.
    """
    classifier = XGBClassifier(tree_method="hist", device="cuda")
    classifier.load_model(f"models/XGBoost_{domain}_{embed_model_name}.json")
    return classifier


xgb_law = load_domain_classifier("law", EVAL_EMBED_MODEL)
xgb_finance = load_domain_classifier("finance", EVAL_EMBED_MODEL)
xgb_healthcare = load_domain_classifier("healthcare", EVAL_EMBED_MODEL)

In [None]:
# Reload the persisted TF-IDF vectorizer so evaluation uses exactly the
# vocabulary fitted earlier; `with` closes the file handle deterministically.
# NOTE: pickle.load can execute arbitrary code — only load files this project wrote.
with open("models/tfidf.pkl", "rb") as f:
    tfidf_embedding = pkl.load(f)

embedding_models = {
    "mini": mini_embedding,
    "tf_idf": tfidf_embedding,
    "baai": baai_embedding,
}

In [None]:
# Evaluate the three per-domain classifiers as one ensemble, for every embedding
# model, on each domain's prompts. Latency is per prompt (ns) and covers the
# three predict() calls only.
# NOTE(review): this iterates `datasets`, not `eval_datasets` — confirm intent.
for embed_model_name, embedding_model in embedding_models.items():
    # Classifiers are trained per (domain, embedding) pair and their input
    # widths differ by embedding, so load the set matching this embedding
    # (previously one set of classifiers was reused for every embedding).
    ensemble = []
    for clf_domain in ("finance", "healthcare", "law"):
        classifier = XGBClassifier(tree_method="hist", device="cuda")
        classifier.load_model(f"models/XGBoost_{clf_domain}_{embed_model_name}.json")
        ensemble.append(classifier)

    for domain, inference_df in datasets.items():
        if inference_df.empty:
            # Nothing to score; statistics.mean([]) would raise below.
            continue

        actuals_ml = inference_df["label"].tolist()

        # Dispatch on the *name*: the old code compared the model object to a
        # string, so every embedding silently fell through to the BAAI branch.
        start_time = time.perf_counter_ns()
        if embed_model_name == "tf_idf":
            test_embeds = embedding_model.transform(inference_df["prompt"])
        else:
            test_embeds = np.array(list(embedding_model.embed(inference_df["prompt"])))
        embed_times = time.perf_counter_ns() - start_time
        # Per-prompt embedding time in ns (kept for inspection; not passed on).
        mean_embed_time = embed_times / len(inference_df)

        predictions_xgb = []
        prediction_times_xgb = []

        for test_embed in test_embeds:
            test_embed = test_embed.reshape(1, -1)

            start_time = time.perf_counter_ns()
            raw_preds = [clf.predict(test_embed) for clf in ensemble]
            end_time = time.perf_counter_ns()
            prediction_times_xgb.append(end_time - start_time)

            # 1 if any domain classifier predicts 1, else 0. This matches the
            # batch benchmark's `1 if (l or f or h) else 0`; the old rule here
            # was inverted (0 if any predicted 1).
            predictions_xgb.append(1 if any(pred[0] == 1 for pred in raw_preds) else 0)

        util.evaluate_run(
            predictions=predictions_xgb,
            true_labels=actuals_ml,
            latency=statistics.mean(prediction_times_xgb),
            domain=domain,
            embed_model=embed_model_name,
            model_name="XGBoost",
            train_acc=0.0,
            cost=0.0,
            training=False,
        )

# Batch

In [None]:
# Batch-latency benchmark: for every embedding model and batch size, time the
# embedding step and each domain classifier's predict() per batch (seconds, via
# time.perf_counter), then write one CSV per embedding model.
Path("data/results").mkdir(parents=True, exist_ok=True)

for embed_model_name in ["mini", "baai", "tf_idf"]:
    # Load the classifiers trained on *this* embedding; input widths differ per
    # embedding, so one shared set cannot serve all three.
    # TF-IDF models were trained on sparse matrices, where XGBoost treats absent
    # entries as missing. This benchmark predicts on dense arrays, so map 0.0 to
    # "missing" to reproduce training-time semantics. Other embeddings keep
    # XGBoost's default (NaN).
    missing_value = 0.0 if embed_model_name == "tf_idf" else np.nan
    classifiers = {}
    for clf_domain in ("law", "finance", "healthcare"):
        classifier = XGBClassifier(tree_method="hist", device="cuda", missing=missing_value)
        classifier.load_model(f"models/XGBoost_{clf_domain}_{embed_model_name}.json")
        classifiers[clf_domain] = classifier

    xgb_batch_results = []

    for batch_size in batch_sizes:
        for offset in range(0, len(batch_data), batch_size):
            batch = batch_data[offset : offset + batch_size]

            # Embedding time; for TF-IDF the densification is inside the timed region.
            embed_start = time.perf_counter()
            if embed_model_name == "tf_idf":
                embeds = tfidf_embedding.transform(batch).toarray()
            elif embed_model_name == "mini":
                embeds = list(mini_embedding.embed(batch))
            else:  # baai
                embeds = list(baai_embedding.embed(batch))
            embed_time = time.perf_counter() - embed_start

            # Stack into one 2-D host array, then copy it to the GPU.
            embeds_gpu = cp.asarray(np.asarray(embeds))

            predictions = {}
            predict_times = {}
            for clf_domain, classifier in classifiers.items():
                predict_start = time.perf_counter()
                predictions[clf_domain] = classifier.predict(embeds_gpu)
                predict_times[clf_domain] = time.perf_counter() - predict_start

            # 1 if any domain classifier's prediction is non-zero, else 0.
            xgb_batch_preds = [
                1 if (law_pred or finance_pred or health_pred) else 0
                for law_pred, finance_pred, health_pred in zip(
                    predictions["law"], predictions["finance"], predictions["healthcare"]
                )
            ]

            xgb_batch_results.append({
                "batch_size": batch_size,
                "time_taken_embed": embed_time,
                "time_taken_law": predict_times["law"],
                "time_taken_finance": predict_times["finance"],
                "time_taken_healthcare": predict_times["healthcare"],
                "results": xgb_batch_preds,
                "model_name": "xgb",
                "embedding_model": embed_model_name,
                "embedding": True,
            })

    pd.DataFrame(xgb_batch_results).to_csv(
        f"data/results/batch_xgb_{embed_model_name}.csv", index=False
    )
