In [7]:
import os
import pickle as pkl
import statistics
import time

import numpy as np
import pandas as pd
from dotenv import load_dotenv
from fastembed import TextEmbedding
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm import tqdm
from xgboost import XGBClassifier

# Populate os.environ from a local .env file. This runs before the
# project-local imports below in case they read environment variables at
# import time (not visible here — TODO confirm whether that ordering matters).
load_dotenv()

import dataloader
import util

# Seed for reproducibility; 22 matches the random_state used for the
# sample()-based splits later in the notebook. util.set_seed is defined
# elsewhere — presumably it seeds random/numpy (and possibly more); verify there.
util.set_seed(22)

In [None]:
# Per-domain training data, held-out evaluation sets, and the prompts used by
# the batch-latency benchmark — all provided by the local dataloader module.
datasets, eval_datasets, batch_data = (
    dataloader.get_domain_data(),
    dataloader.get_eval_datasets(),
    dataloader.get_batch_data(),
)

# Batch sizes swept by the benchmark: a single prompt, then 32 doubling to 256.
batch_sizes = [1] + [32 * 2**k for k in range(4)]

In [None]:
def _cuda_text_embedding(name):
    """Build a fastembed TextEmbedding for `name` on the CUDA execution provider."""
    return TextEmbedding(model_name=name, providers=["CUDAExecutionProvider"])


# Dense sentence embedders (GPU-backed via fastembed).
baai_embedding = _cuda_text_embedding("BAAI/bge-small-en-v1.5")
mini_embedding = _cuda_text_embedding("sentence-transformers/all-MiniLM-L6-v2")

# Sparse lexical baseline; it is fitted in the next cell.
tfidf_embedding = TfidfVectorizer()

In [None]:
# Fit the TF-IDF vocabulary on an 80% sample of the FIRST domain's prompts.
# This mirrors the sample(frac=0.8, random_state=22) call used for the training
# split in the training loop — presumably so the vectorizer only sees training
# prompts for that domain.
# NOTE(review): the same fitted vectorizer is reused for every domain and for
# the eval sets, so vocabulary specific to the other domains is dropped —
# confirm this is intended.
first_dataset = next(iter(datasets.values()))["prompt"]
train_prompts = first_dataset.sample(frac=0.8, random_state=22)

tfidf_embedding.fit(train_prompts)

# Persist the fitted vectorizer so the evaluation cells can reload it. Create
# the directory first; otherwise the write fails on a fresh checkout.
os.makedirs("models", exist_ok=True)
with open("models/tfidf.pkl", "wb") as f:
    pkl.dump(tfidf_embedding, f)

# Create embedding cache directory
os.makedirs("cache/embeddings", exist_ok=True)

In [None]:
# Registry of embedding backends, keyed by the short names that also appear in
# the cache file names and the saved model paths (models/XGBoost_{domain}_{name}.json).
embedding_models = dict(
    mini=mini_embedding,
    tf_idf=tfidf_embedding,
    baai=baai_embedding,
)

# Embedding cache

In [None]:
def get_cached_embeddings(
    texts,
    model_name,
    domain,
    cache_dir="cache/embeddings",
    force_recompute=False,
    batch_size=1,
):
    """Get embeddings from cache if available, otherwise compute and cache them.

    Embeddings are pickled to ``{cache_dir}/{domain}_{model_name}_embeddings.pkl``.
    The cache is keyed only on (domain, model_name), not on the texts. A cached
    matrix whose row count differs from ``len(texts)`` is treated as stale and
    recomputed. A same-length change in content is NOT detected; after changing
    the data, pass ``force_recompute=True`` or use a new ``domain`` key.

    Args:
        texts: The texts to embed (a sequence of strings, e.g. a list or a
            pandas Series).
        model_name: Key into the global ``embedding_models`` registry. "tf_idf"
            is embedded with the fitted vectorizer's ``transform``. Every other
            model is embedded with its ``embed`` method.
        domain: The domain identifier used in the cache file name.
        cache_dir: Directory to store/retrieve cached embeddings. It is created
            if missing.
        force_recompute: If True, ignore the cache and recompute embeddings.
        batch_size: Number of texts passed to ``embed`` per call for dense
            models. Defaults to 1, the previous hard-coded value. Larger values
            are usually much faster on GPU.

    Returns:
        The embeddings matrix, with one row per text: a dense ``np.ndarray``
        for dense models, or the (sparse) result of
        ``TfidfVectorizer.transform`` for "tf_idf".

    Raises:
        ValueError: If ``batch_size`` < 1 when computing dense embeddings.
    """
    cache_file = f"{cache_dir}/{domain}_{model_name}_embeddings.pkl"

    if os.path.exists(cache_file) and not force_recompute:
        print(f"Loading cached embeddings for {domain} using {model_name}")
        # Self-written local cache. Never point cache_dir at untrusted files:
        # pickle.load can execute arbitrary code.
        with open(cache_file, "rb") as f:
            cached = pkl.load(f)
        if cached.shape[0] == len(texts):
            return cached
        # Row count mismatch means the cache belongs to different data.
        print(
            f"Cached embeddings for {domain} using {model_name} have "
            f"{cached.shape[0]} rows but {len(texts)} texts were given; recomputing..."
        )
    elif force_recompute:
        print(f"Force recomputing embeddings for {domain} using {model_name}...")
    else:
        print(f"Computing embeddings for {domain} using {model_name}...")

    embed_model = embedding_models[model_name]
    if model_name == "tf_idf":
        # Fitted sklearn vectorizer: transform the whole set in one call.
        embeddings = embed_model.transform(texts)
    else:
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # Materialize to a list so slicing is positional and behaves the same
        # for lists and pandas Series, whatever their index.
        text_list = list(texts)
        all_embeddings = []
        for start in tqdm(range(0, len(text_list), batch_size)):
            batch_texts = text_list[start : start + batch_size]
            all_embeddings.extend(embed_model.embed(batch_texts))
        embeddings = np.array(all_embeddings)

    # Writing would fail if a custom cache_dir does not exist yet.
    os.makedirs(cache_dir, exist_ok=True)
    with open(cache_file, "wb") as f:
        pkl.dump(embeddings, f)

    return embeddings

# Train per-domain XGBoost classifiers

In [None]:
for domain, dataset in datasets.items():
    # 80/20 train/test split. The test rows must be removed using the sample's
    # ORIGINAL index labels, so train_data's index is reset only after the
    # drop. Resetting it first would make dataset.drop() remove the labels
    # 0..n_train-1, leaving a "test" set that largely overlaps the training
    # sample.
    train_data = dataset.sample(frac=0.8, random_state=22)
    test_data = dataset.drop(train_data.index).reset_index(drop=True)
    train_data = train_data.reset_index(drop=True)
    # Catches duplicate index labels, which would make drop() remove more rows
    # than were sampled.
    assert len(train_data) + len(test_data) == len(dataset), (
        f"{domain}: train/test split does not partition the dataset"
    )

    for model_name in embedding_models:
        start_time = time.perf_counter_ns()

        # Get cached or compute new embeddings. The training rows are the same
        # sample either way, so their cache key is kept. The corrected test
        # split has the same length as the earlier overlapping one, so it uses
        # a new "_v2" key; otherwise stale embeddings would load silently.
        train_embeds = get_cached_embeddings(
            train_data["prompt"], model_name, f"{domain}_train"
        )
        test_embeds = get_cached_embeddings(
            test_data["prompt"], model_name, f"{domain}_test_v2"
        )

        end_time = time.perf_counter_ns()
        # NOTE(review): on a cache hit this measures unpickling, not embedding,
        # so the reported time is only meaningful on an uncached run.
        embed_times = end_time - start_time
        # Mean per prompt over both splits. `len(train_data + test_data)` would
        # add the frames element-wise and return only the index-union length,
        # i.e. the train size.
        mean_embed_time = embed_times / (len(train_data) + len(test_data))

        print(f"Embedding time for {model_name}: {mean_embed_time} ns")

        # Train and evaluate XGBoost model
        util.train_and_evaluate_model(
            model_name="XGBoost",
            train_embeds=train_embeds,
            test_embeds=test_embeds,
            train_labels=train_data["label"],
            test_labels=test_data["label"],
            domain=domain,
            embed_model=model_name,
            save_path=f"models/XGBoost_{domain}_{model_name}.json",
            embedding_time=mean_embed_time,
            training=True,
        )

# Evaluate the combined domain classifiers on the eval sets

In [None]:
# Load TF-IDF model
# Reload the TF-IDF vectorizer persisted by the fitting cell, so this section
# does not depend on the in-memory fitted object. (Local, trusted file only —
# pickle.load executes code from untrusted input.)
with open("models/tfidf.pkl", "rb") as fh:
    tfidf_embedding = pkl.load(fh)

# Rebuild the registry so "tf_idf" points at the reloaded vectorizer.
embedding_models = dict(
    mini=mini_embedding,
    tf_idf=tfidf_embedding,
    baai=baai_embedding,
)

In [None]:
for embed_model_name in embedding_models:
    # One XGBoost classifier per domain, loaded from the paths the training
    # loop saved to (models/XGBoost_{domain}_{embedder}.json).
    xgb_law, xgb_finance, xgb_healthcare = (
        XGBClassifier(tree_method="hist", device="cuda") for _ in range(3)
    )
    for classifier, domain_name in (
        (xgb_law, "law"),
        (xgb_finance, "finance"),
        (xgb_healthcare, "healthcare"),
    ):
        classifier.load_model(f"models/XGBoost_{domain_name}_{embed_model_name}.json")

    for domain, inference_df in eval_datasets.items():
        actuals_ml = inference_df["label"].tolist()

        # Use cached embeddings or compute new ones
        embeds = get_cached_embeddings(
            inference_df["prompt"], embed_model_name, f"{domain}_eval"
        )
        n_prompts = embeds.shape[0]

        # Predict the whole eval set with one call per classifier, timing the
        # three calls together.
        t0 = time.perf_counter_ns()
        pred_finance = xgb_finance.predict(embeds)
        pred_healthcare = xgb_healthcare.predict(embeds)
        pred_law = xgb_law.predict(embeds)
        elapsed_ns = time.perf_counter_ns() - t0

        # Combined label: 1 only when no domain classifier predicts 1,
        # otherwise 0.
        predictions_xgb = [
            1 if (f != 1 and h != 1 and l != 1) else 0
            for f, h, l in zip(pred_finance, pred_healthcare, pred_law, strict=True)
        ]

        util.evaluate_run(
            predictions=predictions_xgb,
            true_labels=actuals_ml,
            # Mean per-prompt latency in ns (all three classifiers combined).
            latency=elapsed_ns / n_prompts,
            domain=domain,
            embed_model=embed_model_name,
            model_name="XGBoost",
            train_acc=0.0,
            cost=0.0,
            training=False,
        )


# Batch-size latency benchmark

In [None]:
# Results are written under data/results; create it so a fresh checkout works.
os.makedirs("data/results", exist_ok=True)

for embedding_model_name in ["mini", "baai", "tf_idf"]:
    # Load the three per-domain classifiers trained for this embedding model.
    xgb_law = XGBClassifier(tree_method="hist", device="cuda")
    xgb_finance = XGBClassifier(tree_method="hist", device="cuda")
    xgb_healthcare = XGBClassifier(tree_method="hist", device="cuda")

    xgb_law.load_model(f"models/XGBoost_law_{embedding_model_name}.json")
    xgb_finance.load_model(f"models/XGBoost_finance_{embedding_model_name}.json")
    xgb_healthcare.load_model(f"models/XGBoost_healthcare_{embedding_model_name}.json")

    # The embedder is loop-invariant. Resolve it once so the lookup is not
    # counted in every batch's embedding time.
    embedding_model = embedding_models[embedding_model_name]

    xgb_batch_results = []

    for batch_size in batch_sizes:
        print(
            f"Processing batch size {batch_size} with {embedding_model_name} embeddings"
        )
        batches = [
            batch_data[i : i + batch_size]
            for i in range(0, len(batch_data), batch_size)
        ]
        # NOTE(review): the first batch timed for each embedder likely includes
        # one-off warm-up cost (GPU/runtime initialization), inflating its
        # timings. Consider a discarded warm-up call if per-batch latency matters.
        for batch in tqdm(batches):
            # All timings in this cell are in seconds (time.perf_counter).
            start_time = time.perf_counter()
            if embedding_model_name == "tf_idf":
                embeds = embedding_model.transform(batch)
            else:
                embeds = np.array(list(embedding_model.embed(batch)))
            embed_time = time.perf_counter() - start_time

            # Time each classifier's prediction separately.
            start_time = time.perf_counter()
            xgb_law_preds = xgb_law.predict(embeds)
            law_time = time.perf_counter() - start_time

            start_time = time.perf_counter()
            xgb_finance_preds = xgb_finance.predict(embeds)
            finance_time = time.perf_counter() - start_time

            start_time = time.perf_counter()
            xgb_health_preds = xgb_healthcare.predict(embeds)
            health_time = time.perf_counter() - start_time

            # One dict of per-domain predictions for each prompt in the batch.
            results = [
                {
                    "finance": int(finance_pred),
                    "healthcare": int(health_pred),
                    "law": int(law_pred),
                }
                for law_pred, finance_pred, health_pred in zip(
                    xgb_law_preds, xgb_finance_preds, xgb_health_preds, strict=True
                )
            ]

            xgb_batch_results.append(
                {
                    "batch_size": batch_size,
                    "time_taken_embed": embed_time,
                    "time_taken_law": law_time,
                    "time_taken_finance": finance_time,
                    "time_taken_healthcare": health_time,
                    "results": results,
                    "model_name": "xgb",
                    "embedding_model": embedding_model_name,
                    "embedding": True,
                }
            )

    pd.DataFrame(xgb_batch_results).to_csv(
        f"data/results/batch_xgb_{embedding_model_name}.csv", index=False
    )