In [None]:
import sys
sys.path.append("../src/")


from kaggle_llm.core import (
    ROOT_PATH,
    count_words,
)
import pandas as pd
import numpy as np
from tqdm import tqdm
from typing import *
from sentence_transformers import SentenceTransformer
from pathlib import Path
import blingfire as bf
import argparse
import torch
import faiss
import json
import yaml


# Word-count thresholds. Sentences with `context_cutoff_len` words or more are
# dropped by get_sentence_df; `too_long_prompt_wc` is not referenced in the cells shown here.
too_long_prompt_wc, context_cutoff_len = 250, 150


def split_text_into_chunks(text: str, chunk_size: int) -> List[str]:
    """Split ``text`` into consecutive chunks of ``chunk_size`` whitespace-separated tokens.

    Tokens are re-joined with single spaces, so runs of whitespace (including
    newlines) are normalized. Every token ends up in exactly one chunk; the last
    chunk holds the remainder and may be shorter than ``chunk_size``.

    Args:
        text: Input text.
        chunk_size: Maximum number of tokens per chunk; must be positive.

    Returns:
        List of chunk strings, empty when ``text`` contains no tokens.

    Raises:
        ValueError: If ``chunk_size`` is not positive (the split could never advance).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    # str.split() with no separator already discards empty strings.
    tokens = text.split()
    return [
        " ".join(tokens[start: start + chunk_size])
        for start in range(0, len(tokens), chunk_size)
    ]


def get_sentence_df(
        wiki_df_path: Union[str, Path],
        max_page_word_count: int = 10000,
        min_sentence_chars: int = 3,
        max_sentence_words: Optional[int] = None,
) -> pd.DataFrame:
    """Load a wiki-page CSV and split it into one row per kept sentence.

    Steps:
      1. Read the CSV (first column used as index); it must provide ``page``,
         ``text`` and ``word_count`` columns.
      2. Blacklist whole pages whose summed ``word_count`` exceeds
         ``max_page_word_count`` or whose title contains "list of equations".
      3. Split each remaining ``text`` into sentences with blingfire, keeping
         sentences longer than ``min_sentence_chars`` characters that contain no
         backslash (math markup leads to excessive tokens).
      4. Drop sentences whose ``count_words`` is >= ``max_sentence_words``.

    Args:
        wiki_df_path: Path to the CSV file.
        max_page_word_count: Pages with more total words than this are skipped.
        min_sentence_chars: Sentences must be strictly longer than this many characters.
        max_sentence_words: Sentences must have strictly fewer words than this;
            ``None`` means use the module-level ``context_cutoff_len`` at call time.

    Returns:
        DataFrame with columns ``index`` (row label before the word-count filter),
        ``page``, ``i_sentence`` (running id assigned before that filter) and
        ``text``. Empty (with those columns) if no sentence survives.
    """
    if max_sentence_words is None:
        max_sentence_words = context_cutoff_len

    wiki_df = pd.read_csv(wiki_df_path, index_col=0)

    wc_per_page = wiki_df.groupby("page")[["word_count"]].sum().sort_values("word_count", ascending=False)
    black_list = list(wc_per_page.loc[
        (wc_per_page["word_count"] > max_page_word_count)
        | (wc_per_page.index.map(lambda x: "list of equations" in x.lower()))
    ].index)
    print(json.dumps(black_list, indent=4))

    filtered_wiki_df = wiki_df.loc[~wiki_df["page"].isin(black_list), :].copy()
    print(len(wiki_df), len(filtered_wiki_df))

    records = []
    print("extracting sentences:")
    for page, text in tqdm(
            zip(filtered_wiki_df["page"], filtered_wiki_df["text"]),
            total=len(filtered_wiki_df),
    ):
        if not isinstance(text, str):
            # Empty CSV cells load as NaN; they contain no sentences.
            continue
        _, sentence_offsets = bf.text_to_sentences_and_offsets(text)
        for start_idx, end_idx in sentence_offsets:
            sentence = text[start_idx: end_idx]
            is_long_enough = (end_idx - start_idx) > min_sentence_chars
            is_math = "\\" in sentence  # leads to excessive tokens
            if is_long_enough and not is_math:
                records.append({
                    "page": page,
                    "i_sentence": len(records),
                    "text": sentence,
                })

    # Explicit columns keep the frame well-formed even when `records` is empty.
    sentences_df = pd.DataFrame(records, columns=["page", "i_sentence", "text"])
    print(f"extracted: {len(sentences_df)} sentences")

    print("dropping too long sentences")
    keep_mask = sentences_df["text"].apply(count_words) < max_sentence_words
    n_kept = int(keep_mask.sum())
    kept_pct = n_kept / len(sentences_df) * 100 if len(sentences_df) else 0.0
    print(f"keeping {kept_pct} % at cutoff {max_sentence_words}")
    sentences_df = sentences_df.loc[keep_mask, :].reset_index().copy()

    return sentences_df

In [None]:
import os

# Project data directory. Defaults to the original author's checkout; set the
# KAGGLE_LLM_DATA_DIR environment variable to run this notebook on another machine.
DATA_DIR = Path(os.environ.get("KAGGLE_LLM_DATA_DIR", "/home/clay/research/kaggle/kaggle_llm/data"))
sentence_model = str(DATA_DIR / "sentence_transformer_model")
wiki_df_path = str(DATA_DIR / "physics_pages_list" / "physics_pages_formatted.csv")

In [None]:
# Build the sentence-level frame (one row per kept sentence) from the wiki CSV.
sentence_df = get_sentence_df(wiki_df_path=wiki_df_path)

In [None]:
# Number of kept sentences, then a preview of the first rows.
print(sentence_df.shape[0])
sentence_df.head(5)

In [None]:
# Map each page title to a contiguous integer id in alphabetical order; the id is
# used as the InputExample label when building training examples.
train_labels = sorted(sentence_df["page"].unique())
train_rv_idx = dict(zip(train_labels, range(len(train_labels))))

In [None]:
# Load the pretrained all-MiniLM-L6-v2 sentence encoder.
model_id = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(model_name_or_path=model_id)

In [None]:
# Import here as well: this cell runs before the later cell that imports
# InputExample, so without it Restart & Run All fails with a NameError.
from sentence_transformers import InputExample

# NOTE(review): this replaces the all-MiniLM-L6-v2 `model` loaded in the previous
# cell; whichever `model` exists when the loss is constructed is the one that gets
# fine-tuned. Confirm which encoder is intended.
model = SentenceTransformer('distilbert-base-nli-mean-tokens')

# Toy examples showing the InputExample(texts=[...], label=...) format;
# `train_examples` is rebuilt from `sentence_df` further down.
train_examples = [
    InputExample(texts=['Sentence from class 0'], label=0),
    InputExample(texts=['Another sentence from class 0'], label=0),
    InputExample(texts=['Sentence from class 1'], label=1),
    InputExample(texts=['Sentence from class 2'], label=2),
]


In [None]:
def build_eval_pairs(
        sentence_df: pd.DataFrame,
        sentences_per_page: int = 5,
        seed: int = 42,
) -> Tuple[List[str], List[str], List[float]]:
    """Build labelled sentence pairs for ``EmbeddingSimilarityEvaluator``.

    Same-page pairs: every unordered pair of *distinct* sentences among up to
    ``sentences_per_page`` sentences sampled from each page, labelled 1.0.
    Random pairs: the same number of pairs drawn at random from the whole frame,
    labelled 1.0 when both sentences come from the same page and 0.0 otherwise.

    Args:
        sentence_df: Frame with ``page`` and ``text`` columns.
        sentences_per_page: Maximum sentences sampled per page for same-page pairs.
        seed: Seed for all sampling so the eval set is reproducible.

    Returns:
        ``(sentences1, sentences2, labels)`` lists of equal length.
    """
    rng = np.random.default_rng(seed)
    sentences1: List[str] = []
    sentences2: List[str] = []
    labels: List[float] = []

    # same group: j starts at i + 1 so a sentence is never paired with itself
    for _, group in sentence_df.groupby("page"):
        texts = group["text"].sample(min(sentences_per_page, len(group)), random_state=rng).tolist()
        for i in range(len(texts)):
            for j in range(i + 1, len(texts)):
                sentences1.append(texts[i])
                sentences2.append(texts[j])
                labels.append(1.0)

    # different group: random pairs, labelled by page identity so they agree with
    # the same-page pairs above (occasional same-page draws are positives)
    n_random = len(sentences1)
    # sample with replacement only when more pairs are needed than rows exist
    replace = n_random > len(sentence_df)
    left = sentence_df.sample(n_random, replace=replace, random_state=rng)
    right = sentence_df.sample(n_random, replace=replace, random_state=rng)
    sentences1 += left["text"].tolist()
    sentences2 += right["text"].tolist()
    labels += (left["page"].to_numpy() == right["page"].to_numpy()).astype(float).tolist()

    return sentences1, sentences2, labels


eval_sentences1, eval_sentences2, eval_labels = build_eval_pairs(sentence_df)

eval_ratio = sum(eval_labels) / len(eval_labels) if eval_labels else float("nan")
print(f"{len(eval_sentences1) = }: ratio: {eval_ratio}")

In [None]:
from sentence_transformers import InputExample


# One single-sentence InputExample per kept sentence, labelled with its page's
# integer id from `train_rv_idx`.
train_examples = [
    InputExample(texts=[sentence], label=train_rv_idx[page])
    for page, sentence in zip(sentence_df["page"], sentence_df["text"])
]

In [None]:
from torch.utils.data import DataLoader


# Shuffled mini-batches of 16 labelled single-sentence examples.
# NOTE(review): batch-triplet losses can only form triplets from same-label examples
# that land in the same batch; with many pages and uniform shuffling, many batches
# may contain few such pairs. sentence-transformers v2 provides
# `datasets.SentenceLabelDataset` for this — verify against the installed version.
train_dataloader = DataLoader(train_examples, batch_size=16, shuffle=True)

In [None]:
from sentence_transformers import losses, evaluation


# Batch-all triplet loss over the page labels of the training examples.
# (Previously tried: losses.TripletLoss, losses.BatchHardSoftMarginTripletLoss.)
train_loss = losses.BatchAllTripletLoss(model=model)

# Compares embedding similarities of the eval pairs against their 0.0/1.0 labels.
evaluator = evaluation.EmbeddingSimilarityEvaluator(
    sentences1=eval_sentences1,
    sentences2=eval_sentences2,
    scores=eval_labels,
)

In [None]:
def callback(score, epoch, steps):
    """Print the evaluator score that ``model.fit`` reports during training."""
    print(f"score={score}, epoch={epoch}, step={steps}")


# NOTE(review): evaluation_steps == batches per epoch triggers an evaluation at each
# epoch's last step; sentence-transformers v2 `fit` also evaluates at every epoch end,
# so this may evaluate twice per epoch — confirm for the installed version.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    evaluation_steps=len(train_dataloader),
    epochs=10,
    callback=callback,
    output_path="/home/clay/research/kaggle/kaggle_llm/data/data_dumps/sentence_embedder_modules/",
)

In [None]:
# Inspect the encoder's configured `max_seq_length` (per-input token limit).
model.max_seq_length

In [None]:
def print_trainable_parameters(model):
    """Print how many of ``model``'s parameters are trainable.

    Every tensor yielded by ``model.named_parameters()`` is counted; it counts as
    trainable when its ``requires_grad`` flag is set. A model with no parameters
    reports 0.0% instead of raising ZeroDivisionError.

    Args:
        model: Any object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty, the real size
        # is exposed as `ds_numel`
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {trainable_pct}"
    )

In [None]:
# Report trainable vs. total parameter counts of the encoder.
print_trainable_parameters(model=model)