In [None]:
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from typing import *
from sentence_transformers import SentenceTransformer
from faiss import write_index, read_index
import blingfire as bf
import torch
import faiss
import json
import re

In [None]:
# Input tables: the competition training questions and the pre-formatted wiki pages.
# Both CSVs use their first column as the DataFrame index.
train_csv_path = "/home/clay/research/kaggle/kaggle_llm/data/kaggle-llm-science-exam/train.csv"
wiki_csv_path = "/home/clay/research/kaggle/kaggle_llm/data/physics_pages_list/physics_pages_formatted.csv"

train_df = pd.read_csv(train_csv_path, index_col=0)
wiki_df = pd.read_csv(wiki_csv_path, index_col=0)

In [None]:
# Total word count per page, largest pages first.
wc_per_page = (
    wiki_df
    .groupby("page")[["word_count"]]
    .sum()
    .sort_values("word_count", ascending=False)
)

# A page is black-listed when it has more than 10000 words in total or when its
# title contains "list of equations" (case-insensitive).
is_oversized = wc_per_page["word_count"] > 10000
is_equation_list = wc_per_page.index.map(lambda title: "list of equations" in title.lower())
black_list = list(wc_per_page.loc[is_oversized | is_equation_list].index)
print(json.dumps(black_list, indent=4))

# Keep only the rows whose page is not black-listed; copy so later edits do not
# touch wiki_df.
keep_row = ~wiki_df["page"].isin(black_list)
filtered_wiki_df = wiki_df.loc[keep_row, :].copy()
print(len(wiki_df), len(filtered_wiki_df))

In [None]:
# Preview the first five rows of the filtered wiki table.
filtered_wiki_df.head(n=5)

In [None]:
# Embedding configuration: local model directory, token limit per input, and
# the batch size used for every encode call below.
sentence_model = "/home/clay/research/kaggle/kaggle_llm/data/sentence_transformer_model"
max_length = 384
batch_size = 16

# Load the sentence-transformer on the GPU and switch its weights to half precision
# (.half() converts the module in place and returns it).
model = SentenceTransformer(sentence_model, device="cuda").half()
model.max_seq_length = max_length

In [None]:
# Split the text of every filtered wiki row into sentences with blingfire and
# collect one record per kept sentence.
sentence_records = []
page_text_pairs = zip(filtered_wiki_df["page"], filtered_wiki_df["text"])

for page_title, page_text in tqdm(page_text_pairs, total=len(filtered_wiki_df)):
    # Only the offsets are used; each sentence is sliced out of the original text.
    # NOTE(review): the offsets are applied as Python string indices - confirm that
    # blingfire reports character (not UTF-8 byte) offsets for non-ASCII text.
    _, sentence_offsets = bf.text_to_sentences_and_offsets(page_text)
    for sentence_number, (begin, end) in enumerate(sentence_offsets):
        if (end - begin) <= 3:
            # Spans of length 3 or less are dropped; sentence_number still counts them.
            continue
        sentence_records.append({
            "page": page_title,
            "i_sentence": sentence_number,
            "text": page_text[begin: end],
        })


sentences_df = pd.DataFrame.from_records(sentence_records)
print(f"extracted: {len(sentences_df)} sentences")

In [None]:
# Preview the first five extracted sentences.
sentences_df.head(n=5)

In [None]:
import matplotlib.pyplot as plt


def count_words(text):
    """Return the number of whitespace-separated tokens in ``text``."""
    # str.split() with no separator never yields empty strings, so the token
    # count is simply the length of the resulting list.
    return len(text.split())


# Histogram (50 bins) of the number of words in each extracted sentence.
sentence_word_counts = sentences_df["text"].apply(count_words)
_ = plt.hist(sentence_word_counts, bins=50)

In [None]:
# Fraction of extracted sentences that contain more than 150 words.
is_too_long = sentences_df["text"].apply(count_words) > 150
too_long_sentences = sentences_df.loc[is_too_long, "text"]
print(len(too_long_sentences) / len(sentences_df))

In [None]:
option_columns = ["A", "B", "C", "D", "E"]

# Retrieval query per question: the prompt followed by all five answer options,
# separated by single spaces.
query_text = train_df["prompt"]
for option in option_columns:
    query_text = query_text + " " + train_df[option]
train_df["prompt_and_answer"] = query_text


def embed_texts(texts):
    """Encode ``texts`` with the sentence-transformer and return a float16 numpy array.

    Embeddings are normalised by the encoder (``normalize_embeddings=True``),
    cast to half precision, and moved to the CPU.
    """
    encoded = model.encode(
        texts,
        batch_size=batch_size,
        device=0,
        show_progress_bar=True,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    return encoded.half().detach().cpu().numpy()


with torch.no_grad():
    # One embedding per question (prompt + options).
    question_embeddings = embed_texts(train_df["prompt_and_answer"].values)

    # One embedding matrix per answer option, in A..E order.
    choice_embeddings = [embed_texts(train_df[option].values) for option in option_columns]

    # One embedding per extracted wiki sentence.
    sentence_embeddings = embed_texts(sentences_df["text"].values)

In [None]:
# Build a flat L2 index over every wiki sentence embedding.
# faiss works on contiguous float32 matrices, whereas sentence_embeddings was
# produced in half precision. Convert explicitly so the add does not depend on
# whether the installed faiss wrapper coerces or rejects float16 input.
sentence_index = faiss.IndexFlatL2(sentence_embeddings.shape[1])
sentence_index.add(np.ascontiguousarray(sentence_embeddings, dtype=np.float32))
print(f"{sentence_index.ntotal = }")

In [None]:
# Attach retrieved wiki sentences to every training question:
#   context_0..context_{k-1} - nearest sentences to the prompt + options query
#   context_A..context_E     - nearest sentence to each individual answer option
k = 3
choice_k = 1


def lookup_sentences(hit_ids):
    """Return the sentence text for each faiss hit id.

    Ids are looked up as labels of ``sentences_df``'s index; ids that are not
    present (e.g. -1 when faiss has no result) become NaN, which is what a left
    join on the index would produce.
    """
    return sentences_df["text"].reindex(hit_ids).to_numpy()


def as_faiss_matrix(vectors):
    """Return ``vectors`` as the contiguous float32 matrix that faiss expects."""
    return np.ascontiguousarray(vectors, dtype=np.float32)


distance, indices = sentence_index.search(as_faiss_matrix(question_embeddings), k)
for i in range(k):
    train_df[f"context_{i}"] = lookup_sentences(indices[:, i])


for emb, choice in zip(choice_embeddings, ["A", "B", "C", "D", "E"]):
    choice_distance, choice_indices = sentence_index.search(as_faiss_matrix(emb), choice_k)
    # There is a single context column per option, so always store the nearest
    # hit (column 0). Looping over choice_k would overwrite this column and keep
    # the farthest hit if choice_k were ever raised above 1.
    train_df[f"context_{choice}"] = lookup_sentences(choice_indices[:, 0])


train_df.head()

In [None]:
def count_words(text):
    """Return the number of whitespace-separated tokens in ``text``."""
    # Splitting on arbitrary whitespace never produces empty tokens, so the
    # length of the split is the word count.
    tokens = text.split()
    return len(tokens)


# Word count of everything belonging to one example: the prompt, the five answer
# options and all eight retrieved context sentences, joined by single spaces.
text_columns = [
    "prompt",
    "A", "B", "C", "D", "E",
    "context_0", "context_1", "context_2",
    "context_A", "context_B", "context_C", "context_D", "context_E",
]

combined_text = train_df[text_columns[0]]
for column in text_columns[1:]:
    combined_text = combined_text + " " + train_df[column]

train_df["total_wc"] = combined_text.apply(count_words)


In [None]:
import matplotlib.pyplot as plt


# Histogram (50 bins) of the combined word count per training example.
total_word_counts = train_df["total_wc"]
_ = plt.hist(total_word_counts, bins=50)

In [None]:
# Print the second training row in full: the question, its three query-level
# contexts, each option followed by its own retrieved context, and the answer.
row = train_df.iloc[1]

print(f"question: {row['prompt']}\n")

for rank in range(3):
    context_key = f"context_{rank}"
    print(f"{context_key}: {row[context_key]}\n")

for option in ["A", "B", "C", "D", "E"]:
    context_key = f"context_{option}"
    print(f"{option}: {row[option]}\n")
    print(f"{context_key}: {row[context_key]}\n")

print(f"answer: {row['answer']}\n")