In [None]:
import sys
sys.path.append("../src/")


from kaggle_llm.core import (
    ROOT_PATH,
    count_words,
)
import pandas as pd
import numpy as np
from tqdm import tqdm
from typing import *
from sentence_transformers import SentenceTransformer
from pathlib import Path
import blingfire as bf
import argparse
import torch
import faiss
import json
import yaml


# Word-count thresholds shared by the notebook.
# NOTE(review): too_long_prompt_wc is not referenced anywhere in the visible cells;
# presumably a prompt-length limit used by a sibling script -- confirm or remove.
too_long_prompt_wc = 250
# Sentences whose count_words() is >= this value are dropped in get_sentence_embeddings().
context_cutoff_len = 150


def split_text_into_chunks(text: str, chunk_size: int) -> List[str]:
    """Split ``text`` into consecutive chunks of at most ``chunk_size`` words.

    Words are the whitespace-separated tokens of ``text``; each chunk is those
    words re-joined with single spaces. Every word ends up in exactly one chunk,
    so the final chunk may hold fewer than ``chunk_size`` words.

    Args:
        text: The text to split.
        chunk_size: Maximum number of words per chunk; must be positive.

    Returns:
        The chunks in order. Empty or whitespace-only text yields ``[]``.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        # A zero step would never advance through the tokens.
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    # str.split() with no separator already discards empty tokens.
    tokens = text.split()
    # Stepping over every start offset keeps the trailing chunk, which the
    # previous `while end < len(tokens)` loop always dropped.
    return [
        " ".join(tokens[start: start + chunk_size])
        for start in range(0, len(tokens), chunk_size)
    ]


@torch.no_grad()
def get_sentence_embeddings(
        wiki_df_path: Union[str, Path],
        model: SentenceTransformer
):
    """Build a sentence table and a FAISS inner-product index from a wiki CSV.

    Steps:
      1. Read the CSV (first column is the index); it must provide the columns
         ``page``, ``word_count`` and ``text``.
      2. Black-list pages whose summed ``word_count`` exceeds 10000 or whose
         title contains "list of equations" (case-insensitive).
      3. Split the remaining rows into sentences with blingfire, skipping spans
         of 3 characters or fewer and spans containing a backslash.
      4. Drop sentences whose ``count_words`` is >= ``context_cutoff_len``.
      5. Encode the surviving sentences with ``model`` on CUDA (normalised
         embeddings) and add them to a ``faiss.IndexFlatIP``.

    Args:
        wiki_df_path: Path of the wiki CSV.
        model: Sentence-transformer used to encode the sentences.

    Returns:
        ``(sentences_df, sentence_index)``. ``sentences_df`` has a fresh
        positional index plus the columns ``index`` (row label before the
        length filter), ``page``, ``i_sentence``, ``text`` and ``topic``.
        Row ``i`` of ``sentences_df`` is the ``i``-th vector added to
        ``sentence_index``.
    """
    wiki_df = pd.read_csv(wiki_df_path, index_col=0)

    wc_per_page = wiki_df.groupby("page")[["word_count"]].sum().sort_values("word_count", ascending=False)
    black_list = list(wc_per_page.loc[
        (wc_per_page["word_count"] > 10000)
        | (wc_per_page.index.map(lambda x: "list of equations" in x.lower()))
    ].index)
    print(json.dumps(black_list, indent=4))

    filtered_wiki_df = wiki_df.loc[~wiki_df["page"].isin(black_list), :].copy()
    print(len(wiki_df), len(filtered_wiki_df))

    batch_size = 16
    sentences_df = []

    print("extracting sentences:")
    for _, row in tqdm(filtered_wiki_df.iterrows(), total=len(filtered_wiki_df)):
        _, sentence_offsets = bf.text_to_sentences_and_offsets(row["text"])
        for start_idx, end_idx in sentence_offsets:
            # NOTE(review): assumes the blingfire offsets index into the Python
            # string -- verify for non-ASCII text.
            sentence = row["text"][start_idx: end_idx]
            is_long_enough = (end_idx - start_idx) > 3
            is_math = "\\" in sentence  # leads to excessive tokens
            if is_long_enough and (not is_math):
                sentences_df.append({
                    "page": row["page"],
                    # Running counter over accepted sentences (before the word-count filter).
                    "i_sentence": len(sentences_df),
                    "text": sentence,
                    "topic": row["page"],
                })

    sentences_df = pd.DataFrame.from_records(sentences_df)
    print(f"extracted: {len(sentences_df)} sentences")

    print(f"dropping too long sentences")
    is_short_enough = sentences_df["text"].apply(count_words) < context_cutoff_len
    pass_indices = sentences_df.index[is_short_enough]
    print(f"keeping {len(pass_indices) / len(sentences_df) * 100} % at cutoff {context_cutoff_len}")
    # reset_index() makes the row position equal to the id FAISS will assign below.
    sentences_df = sentences_df.loc[pass_indices, :].reset_index().copy()

    print("computing wiki embeddings:")
    sentence_embeddings = model.encode(
        sentences_df["text"].values,
        batch_size=batch_size,
        device="cuda",
        show_progress_bar=True,
        convert_to_tensor=True,
        normalize_embeddings=True,
    ).half()
    # FAISS operates on C-contiguous float32 matrices; convert explicitly rather
    # than relying on the installed Python wrapper to up-cast float16 itself.
    sentence_embeddings = np.ascontiguousarray(
        sentence_embeddings.detach().cpu().numpy(), dtype=np.float32
    )

    sentence_index = faiss.IndexFlatIP(sentence_embeddings.shape[1])
    sentence_index.add(sentence_embeddings)
    print(f"{sentence_index.ntotal = }")
    return sentences_df, sentence_index

In [None]:
# Notebook configuration: input locations and retrieval depth.
# NOTE(review): these are absolute paths on one machine; the notebook will not run
# elsewhere without editing them -- consider deriving them from a shared root.
# Questions CSV, loaded with pd.read_csv in the retrieval cell.
train_df_path = "/home/clay/research/kaggle/kaggle_llm/data/data_dumps/context_df/more_questions_raw_questions.csv"
# Wiki pages CSV, passed to get_sentence_embeddings().
wiki_df_path = "/home/clay/research/kaggle/kaggle_llm/data/physics_pages_list/physics_pages_formatted.csv"
# Location handed to SentenceTransformer(...) when the model is loaded.
sentence_model = "/home/clay/research/kaggle/kaggle_llm/data/sentence_transformer_model"
# Number of nearest wiki sentences retrieved per question.
k = 3

In [None]:
# Load the sentence encoder on the GPU in half precision, then cap its input length.
model = SentenceTransformer(sentence_model, device="cuda").half()
model.max_seq_length = 384

In [None]:
# Build the wiki sentence table and its FAISS index (the expensive step of this notebook).
print("computing wiki embeddings")
wiki_df_path = Path(wiki_df_path)
sentences_df, sentence_index = get_sentence_embeddings(wiki_df_path, model)
batch_size = 16
print("computed wiki embeddings")

In [None]:
batch_size = 16
train_df = pd.read_csv(train_df_path)
if "id" in train_df.columns:
    train_df = train_df.drop("id", axis=1)

# Validate before encoding: a missing prompt would turn the concatenated query
# below into NaN, so checking afterwards is too late to be useful.
assert not train_df["prompt"].isna().any(), f"{train_df_path} contains {train_df['prompt'].isna().sum()} dumbass prompts"

# One retrieval query per question: the prompt followed by all five answer options.
train_df["prompt_and_answer"] = (
        train_df["prompt"]
        + " " + train_df["A"]
        + " " + train_df["B"]
        + " " + train_df["C"]
        + " " + train_df["D"]
        + " " + train_df["E"]
)
question_embeddings = model.encode(
    train_df["prompt_and_answer"].values,
    batch_size=batch_size,
    device="cuda",
    show_progress_bar=True,
    convert_to_tensor=True,
    normalize_embeddings=True,
).half()
# FAISS searches C-contiguous float32 matrices; convert explicitly rather than
# relying on the installed Python wrapper to up-cast float16 itself.
question_embeddings = np.ascontiguousarray(
    question_embeddings.detach().cpu().numpy(), dtype=np.float32
)

# indices[:, i] holds, for every question, the id of its i-th nearest wiki sentence.
distance, indices = sentence_index.search(question_embeddings, k)

for i in range(k):
    train_df[f"context_{i}_idx"] = indices[:, i]

for i in range(k):
    # this is different, we just join the topic here
    # FAISS ids are row positions in sentences_df, so a plain index lookup is
    # enough; unlike join(..., rsuffix=...) it does not depend on train_df
    # already having a "topic" column. The topic is then lower-cased and the
    # characters space / ' ( ) : are each replaced by "_".
    train_df[f"context_topic_{i}"] = (
        train_df[f"context_{i}_idx"]
        .map(sentences_df["topic"])
        .str.lower()
        .str.replace(r"[ /'():]", "_", regex=True)
    )

In [None]:
# Preview the first rows, including the context_{i}_idx / context_topic_{i} columns.
train_df.head()

In [None]:
# accs[i] is the fraction of questions whose own topic equals the topic of the
# rank-i retrieved wiki sentence.
accs = [
    (train_df["topic"] == train_df[f"context_topic_{i}"]).mean()
    for i in range(k)
]

In [None]:
# Show the topic-match rate for each retrieval rank (position i corresponds to context_topic_i).
accs

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay


# Confusion matrix of each question's topic against the topic of its top-ranked
# retrieved sentence. The same explicit, sorted label list is given to both the
# matrix and the display so the axes show topic names instead of integer positions.
topic_labels = sorted(set(train_df["topic"]) | set(train_df["context_topic_0"]))
ConfusionMatrixDisplay(
    confusion_matrix(
        train_df["topic"],
        train_df["context_topic_0"],
        labels=topic_labels,
    ),
    display_labels=topic_labels,
).plot(xticks_rotation="vertical")  # vertical ticks keep long topic names legible