In [1]:
# Run the notebook from the working directory so the relative paths used below
# (input/..., preprocessed/...) resolve.
%cd /tmp/working
%pwd

/tmp/working


'/tmp/working'

In [5]:
from __future__ import annotations

import ctypes
import gc
import os
import sys
import time
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path

import blingfire as bf
import faiss
import hydra
import numpy as np
import pandas as pd
from faiss import read_index, write_index
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm

# Handle to the C library via its glibc soname; its malloc_trim(0) is called after
# large objects are deleted to ask the allocator to release freed heap memory.
libc = ctypes.CDLL("libc.so.6")
# Make modules in the parent directory importable (presumably where `utils` lives — confirm).
sys.path.append(os.pardir)

import utils

In [6]:
# Competition questions; report (rows, columns).
df = pd.read_csv("input/kaggle-llm-science-exam/train.csv")
print(df.shape)


# Load the sentence-embedding model on the GPU, cap its input length via
# max_seq_length, and switch it to half precision.
model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
model.max_seq_length = 384
model = model.half()

(200, 8)


In [7]:
def relevant_title_retrieval(
    df: pd.DataFrame,
    index_path: str,
    model: SentenceTransformer,
    top_k: int = 3,
    batch_size: int = 32,
    device: str = "cuda",
) -> tuple[np.ndarray, np.ndarray]:
    """Search a faiss index for the ``top_k`` nearest entries of every prompt in ``df``.

    Each ``df.prompt`` value is embedded with ``model`` (normalized embeddings)
    and queried against the faiss index read from ``index_path``.

    :param df: DataFrame with a ``prompt`` column; result row ``i`` belongs to row ``i`` (positional).
    :param index_path: Path to a serialized faiss index.
    :param model: SentenceTransformer used to embed the prompts.
    :param top_k: Number of neighbours returned per prompt.
    :param batch_size: Encoding batch size.
    :param device: Device passed to ``model.encode`` (default ``"cuda"``).
    :return: ``(search_score, search_index)`` as returned by the index's ``search``:
        arrays of shape ``(len(df), top_k)``. faiss uses label ``-1`` for missing results.
    """
    print("read_index")
    sentence_index = read_index(index_path)
    prompt_embeddings = model.encode(
        df.prompt.values,
        batch_size=batch_size,
        device=device,
        show_progress_bar=True,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    # The model may run in fp16; hand faiss an explicit float32, C-contiguous matrix.
    prompt_embeddings = np.ascontiguousarray(prompt_embeddings.detach().cpu().numpy(), dtype=np.float32)
    search_score, search_index = sentence_index.search(prompt_embeddings, top_k)
    # Drop the index and queries, then ask glibc to hand freed heap memory back.
    del sentence_index
    del prompt_embeddings
    _ = gc.collect()
    libc.malloc_trim(0)
    return search_score, search_index


# Title search: nearest index entries for every prompt.
print("【title 検索】")
search_score, search_index = relevant_title_retrieval(
    df=df,
    index_path="input/wikipedia-2023-07-faiss-index/wikipedia_202307.index",
    model=model,
    batch_size=16,
    top_k=3,
)

【title 検索】
read_index


Batches:   0%|          | 0/13 [00:00<?, ?it/s]

In [8]:
def get_wikipedia_file_data(
    search_score: np.ndarray,
    search_index: np.ndarray,
    wiki_index_path: str,
) -> pd.DataFrame:
    """Map faiss search hits to the rows (``id``, ``file``) of the Wikipedia index parquet.

    :param search_score: Scores from the search; not used, kept for interface compatibility.
    :param search_index: 2-D array of row labels into the parquet at ``wiki_index_path``;
        row ``i`` holds the hits of prompt ``i``. Negative labels (faiss's ``-1``
        "no result" marker) are skipped.
    :param wiki_index_path: Parquet file with at least ``id`` and ``file`` columns.
    :return: DataFrame with columns ``id``, ``prompt_id``, ``file``, de-duplicated and
        sorted by ``file`` then ``id``; empty (with those columns) when there are no hits.
    """
    wiki_index_df = pd.read_parquet(wiki_index_path, columns=["id", "file"])
    frames = []
    for prompt_id, hit_rows in tqdm(enumerate(search_index), total=len(search_index)):
        hit_rows = np.asarray(hit_rows)
        # A -1 label would raise KeyError in .loc (or select a wrong row on another index).
        hit_rows = hit_rows[hit_rows >= 0]
        hits = wiki_index_df.loc[hit_rows].copy()
        hits["prompt_id"] = prompt_id
        frames.append(hits)

    if frames:
        hits_df = pd.concat(frames, ignore_index=True)
    else:
        hits_df = pd.DataFrame(columns=["id", "prompt_id", "file"])

    wikipedia_file_data = (
        hits_df[["id", "prompt_id", "file"]]
        .drop_duplicates()
        .sort_values(["file", "id"])
        .reset_index(drop=True)
    )
    # Save memory: the full index table is no longer needed.
    del wiki_index_df
    _ = gc.collect()
    libc.malloc_trim(0)
    return wikipedia_file_data


# Which Wikipedia parquet file holds each retrieved article.
print("【wikipedia file data 取得】")
wikipedia_file_data = get_wikipedia_file_data(
    search_score=search_score,
    search_index=search_index,
    wiki_index_path="input/wikipedia-20230701/wiki_2023_index.parquet",
)
print(wikipedia_file_data.head())

【wikipedia file data 取得】


  0%|          | 0/200 [00:00<?, ?it/s]

         id  prompt_id       file
0      1141        151  a.parquet
1  11963992        185  a.parquet
2      1200         63  a.parquet
3      1234        130  a.parquet
4      1317         89  a.parquet


In [None]:
def get_full_text_data(
    wikipedia_file_data: pd.DataFrame,
    wiki_dir: str,
) -> pd.DataFrame:
    """Load the article text for every ``id`` listed in ``wikipedia_file_data``.

    Each parquet file named in the ``file`` column is read once (``id`` and
    ``text`` columns only) and filtered to the ids requested for that file.

    :param wikipedia_file_data: DataFrame with ``id`` and ``file`` columns.
    :param wiki_dir: Directory containing the parquet files named in ``file``.
    :return: DataFrame with columns ``id`` and ``text``, de-duplicated, with a fresh
        RangeIndex; empty (with those columns) when there is nothing to load.
    """
    wiki_text_data = []
    # Group ids by file once (first-appearance order) instead of re-filtering per file.
    ids_by_file = wikipedia_file_data.groupby("file", sort=False)["id"]
    for file, file_ids in tqdm(ids_by_file, total=ids_by_file.ngroups):
        # Ids are matched as strings; presumably the text parquet stores `id` as str — confirm.
        wanted_ids = [str(i) for i in file_ids.tolist()]
        _df = pd.read_parquet(f"{wiki_dir}/{file}", columns=["id", "text"])
        _df_temp = _df[_df["id"].isin(wanted_ids)].copy()
        # Release the full file before reading the next one.
        del _df
        _ = gc.collect()
        libc.malloc_trim(0)
        wiki_text_data.append(_df_temp)

    if not wiki_text_data:
        return pd.DataFrame(columns=["id", "text"])

    wiki_text_data = pd.concat(wiki_text_data).drop_duplicates().reset_index(drop=True)
    _ = gc.collect()
    libc.malloc_trim(0)
    return wiki_text_data


# Fetch the article text ("id", "text") of every retrieved article.
print("【wikipedia text data 取得】")
wiki_text_data = get_full_text_data(
    wikipedia_file_data=wikipedia_file_data,
    wiki_dir="input/wikipedia-20230701",
)
print(wiki_text_data.head())

【wikipedia text data 取得】


  0%|          | 0/28 [00:00<?, ?it/s]

In [44]:
def sentencize(
    documents: Iterable[str],
    document_ids: Iterable,
    offsets: Iterable[tuple[int, int]],
    filter_len: int = 3,
    disable_progress_bar: bool = False,
) -> pd.DataFrame:
    """
    Split documents into sentences while keeping each sentence's character offsets.

    Before splitting, each document is cut at the earliest occurrence of any
    trailing-section heading ("See also", "References", "Further reading",
    "External links", written as ``==X==`` or ``== X ==``). Sentence boundaries come
    from ``blingfire.text_to_sentences_and_offsets``. Each sentence offset is shifted
    by the document's ``offset[0]`` so it points into the original document.
    A document whose processing raises an ``Exception`` is skipped (rows already
    produced for it are kept).

    :param documents: Iterable containing documents which are strings
    :param document_ids: Iterable containing document unique identifiers
    :param offsets: Iterable of ``(start, end)`` tuples, one per document
    :param filter_len: Sentences spanning ``filter_len`` characters or fewer are dropped
    :param disable_progress_bar: Flag to disable the tqdm progress bar
    :return: Pandas DataFrame with columns ``document_id``, ``text``, ``offset``
        (present even when no sentence survives)
    """
    # Trailing sections that are removed before sentence splitting.
    trailing_section_markers = (
        "==See also==",
        "== See also ==",
        "==References==",
        "== References ==",
        "==Further reading==",
        "== Further reading ==",
        "==External links==",
        "== External links ==",
    )
    # `documents` is typed Iterable, so only give tqdm a total when it is sized.
    total = len(documents) if hasattr(documents, "__len__") else None

    document_sentences = []
    for document, document_id, offset in tqdm(
        zip(documents, document_ids, offsets), total=total, disable=disable_progress_bar
    ):
        for marker in trailing_section_markers:
            document = document.split(marker)[0]
        try:
            _, sentence_offsets = bf.text_to_sentences_and_offsets(document)
            for start, end in sentence_offsets:
                if end - start > filter_len:
                    document_sentences.append(
                        {
                            "document_id": document_id,
                            "text": document[start:end],
                            "offset": (start + offset[0], end + offset[0]),
                        }
                    )
        except Exception:
            # Best-effort: skip this document instead of aborting the run; catching
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
            continue
    return pd.DataFrame(document_sentences, columns=["document_id", "text", "offset"])


def sectionize_documents(
    documents: Iterable[str], document_ids: Iterable, disable_progress_bar: bool = False
) -> pd.DataFrame:
    """
    Wrap each document as a single row spanning its whole text.

    No section detection is performed: every document becomes one row whose
    ``offset`` is ``(0, len(document))``.

    :param documents: Iterable containing documents which are strings
    :param document_ids: Iterable containing document unique identifiers
    :param disable_progress_bar: Flag to disable tqdm progress bar
    :return: Pandas DataFrame with columns ``document_id``, ``text``, ``offset``,
        sorted by ``document_id`` then ``offset`` with a fresh RangeIndex; an empty
        frame with those columns when there are no documents
    """
    # `documents` is typed Iterable, so only give tqdm a total when it is sized.
    total = len(documents) if hasattr(documents, "__len__") else None
    processed_documents = [
        {"document_id": document_id, "text": document, "offset": (0, len(document))}
        for document_id, document in tqdm(
            zip(document_ids, documents), total=total, disable=disable_progress_bar
        )
    ]

    # Explicit columns so callers can still access .text/.document_id/.offset when empty.
    _df = pd.DataFrame(processed_documents, columns=["document_id", "text", "offset"])
    if _df.shape[0] > 0:
        return _df.sort_values(["document_id", "offset"]).reset_index(drop=True)
    return _df


def process_documents(
    documents: Iterable[str],
    document_ids: Iterable,
    split_sentences: bool = True,
    filter_len: int = 3,
    disable_progress_bar: bool = False,
) -> pd.DataFrame:
    """
    Turn raw documents into retrieval units: whole documents or their sentences.

    :param documents: Iterable containing documents which are strings
    :param document_ids: Iterable containing document unique identifiers
    :param split_sentences: If True, split each document into sentences with
        ``sentencize``; otherwise return the ``sectionize_documents`` output as-is
    :param filter_len: Forwarded to ``sentencize`` (minimum sentence length filter)
    :param disable_progress_bar: Flag to disable tqdm progress bars
    :return: Pandas DataFrame with columns ``document_id``, ``text``, ``offset``;
        when there are no documents, the (empty) ``sectionize_documents`` result
    """
    # Named `sections` (not `df`) to avoid confusion with the notebook's global `df`.
    sections = sectionize_documents(documents, document_ids, disable_progress_bar=disable_progress_bar)

    # Nothing to split: avoid touching columns that an empty frame may not have.
    if split_sentences and not sections.empty:
        return sentencize(
            sections.text.values,
            sections.document_id.values,
            sections.offset.values,
            filter_len=filter_len,
            disable_progress_bar=disable_progress_bar,
        )
    return sections


# Split every retrieved article into sentence rows.
processed_wiki_text_data = process_documents(
    documents=wiki_text_data["text"].values,
    document_ids=wiki_text_data["id"].values,
)
processed_wiki_text_data

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

Unnamed: 0,document_id,text,offset
0,103529,* Tri- is a numerical prefix meaning three.,"(0, 43)"
1,103529,Tri or TRI may also refer to: ==Places== * Tri...,"(44, 946)"
2,12571,Galaxy groups and clusters are the largest kno...,"(0, 148)"
3,12571,They form the densest part of the large-scale ...,"(149, 221)"
4,12571,In models for the gravitational formation of s...,"(222, 405)"
...,...,...,...
697,8603,When looking at a cross section of a beam of l...,"(32636, 32776)"
698,8603,"In the case of Young's double-slit experiment,...","(32777, 33018)"
699,8603,"In the case of particles like electrons, neutr...","(33019, 33388)"
700,8603,These femtosecond-duration pulses will allow f...,"(33389, 33498)"


In [47]:
processed_wiki_text_data["text"]

0            * Tri- is a numerical prefix meaning three.
1      Tri or TRI may also refer to: ==Places== * Tri...
2      Galaxy groups and clusters are the largest kno...
3      They form the densest part of the large-scale ...
4      In models for the gravitational formation of s...
                             ...                        
697    When looking at a cross section of a beam of l...
698    In the case of Young's double-slit experiment,...
699    In the case of particles like electrons, neutr...
700    These femtosecond-duration pulses will allow f...
701    Due to these short pulses, radiation damage ca...
Name: text, Length: 702, dtype: object

In [49]:
## Get embeddings of the wiki text data
print("【Get embeddings of the wiki text data】")
# Pass a plain list of sentences (the other encode calls pass arrays, not Series).
# NOTE(review): a Series with a non-default index may be indexed by label inside
# encode(); a list removes that dependency on processed_wiki_text_data's index.
wiki_data_embeddings = model.encode(
    processed_wiki_text_data["text"].tolist(),
    batch_size=16,
    device="cuda",
    show_progress_bar=True,
    convert_to_tensor=True,
    normalize_embeddings=True,
)
# Row j of wiki_data_embeddings corresponds to row j of processed_wiki_text_data.
wiki_data_embeddings = wiki_data_embeddings.detach().cpu().numpy()
_ = gc.collect()

【Get embeddings of the wiki text data】


Batches:   0%|          | 0/44 [00:00<?, ?it/s]

In [52]:
## Combine all answers
print("【Combine all answers】")
# Query 0 is the prompt followed by all five options; queries 1-5 are options A-E.
# The combined text is a local Series, so `df` is never mutated (previously, a
# failure inside the loop left temporary columns behind in `df`).
answer_columns = ["A", "B", "C", "D", "E"]
prompt_answer_stem = df["prompt"] + " " + df[answer_columns].apply(" ".join, axis=1)

option_embeddings = []
for query_texts in [prompt_answer_stem, *(df[col] for col in answer_columns)]:
    embeddings = model.encode(
        query_texts.values,
        batch_size=16,
        device="cuda",
        show_progress_bar=True,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    embeddings = embeddings.detach().cpu().numpy()
    option_embeddings.append(embeddings)

【Combine all answers】


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [133]:
def extract_contexts_from_matching_pairs(
    df: pd.DataFrame,
    processed_wiki_text_data: pd.DataFrame,
    wikipedia_file_data: pd.DataFrame,
    wiki_data_embeddings: np.ndarray,
    option_embeddings: list[np.ndarray],
    num_sentences_include: int = 5,
) -> list[str]:
    """Build one retrieval-context string per prompt in ``df``.

    For prompt ``i`` (positional), the candidates are the sentences in
    ``processed_wiki_text_data`` whose ``document_id`` appears among the ids
    retrieved for that prompt (``wikipedia_file_data.prompt_id == i``). A flat faiss
    index (L2 distance, smaller is closer) over their embeddings is queried with
    row ``i`` of every array in ``option_embeddings``. Each sentence keeps its
    smallest distance across queries, and the ``num_sentences_include`` closest
    sentences are concatenated, each followed by a space.

    :param df: Prompts; only its length is used (row ``i`` <-> ``prompt_id == i``).
    :param processed_wiki_text_data: Sentences with ``document_id`` and ``text``;
        row ``j`` must correspond to ``wiki_data_embeddings[j]``.
    :param wikipedia_file_data: Table with ``id`` and ``prompt_id`` columns.
    :param wiki_data_embeddings: Sentence embeddings, one row per sentence.
    :param option_embeddings: Query embedding arrays, each with one row per prompt.
    :param num_sentences_include: Neighbours fetched per query and sentences kept per prompt.
    :return: A list aligned with ``df`` rows; ``""`` for prompts without candidates.
    """
    # Compare ids as strings so int vs str id columns from different parquet files still match.
    sentence_doc_ids = processed_wiki_text_data["document_id"].astype(str)
    sentence_texts = processed_wiki_text_data["text"].to_numpy()
    dim = wiki_data_embeddings.shape[1]

    contexts = []
    for prompt_pos in tqdm(range(len(df)), total=len(df)):
        retrieved_ids = wikipedia_file_data.loc[wikipedia_file_data["prompt_id"] == prompt_pos, "id"].astype(str)
        # Positional rows (into both the sentence table and the embeddings) of the candidates.
        candidate_rows = np.flatnonzero(sentence_doc_ids.isin(retrieved_ids).to_numpy())
        if candidate_rows.size == 0:
            # Keep `contexts` aligned with `df` instead of reusing another prompt's context.
            contexts.append("")
            continue

        prompt_index = faiss.index_factory(dim, "Flat")
        prompt_index.add(np.ascontiguousarray(wiki_data_embeddings[candidate_rows], dtype=np.float32))

        # Smallest distance seen for each candidate across all query embeddings
        # (insertion order is kept, so ties sort by first appearance).
        best_distance = {}
        for embeddings in option_embeddings:
            query = np.ascontiguousarray(embeddings[np.newaxis, prompt_pos], dtype=np.float32)
            distances, labels = prompt_index.search(query, num_sentences_include)
            for distance, label in zip(distances[0], labels[0]):
                # faiss pads with label -1 when the index holds fewer than k vectors.
                if label < 0:
                    continue
                if label not in best_distance or distance < best_distance[label]:
                    best_distance[label] = distance

        top_hits = sorted(best_distance.items(), key=lambda item: item[1])[:num_sentences_include]
        contexts.append("".join(sentence_texts[candidate_rows[label]] + " " for label, _ in top_hits))
    return contexts


## Extract contexts from matching pairs
print("【Extract contexts from matching pairs】")
contexts = extract_contexts_from_matching_pairs(
    df=df,
    processed_wiki_text_data=processed_wiki_text_data,
    wikipedia_file_data=wikipedia_file_data,
    wiki_data_embeddings=wiki_data_embeddings,
    option_embeddings=option_embeddings,
    num_sentences_include=10,
)
# One context string per question row.
df["context"] = contexts

【Extract contexts from matching pairs】


  0%|          | 0/5 [00:00<?, ?it/s]

In [128]:
np.expand_dims(option_embeddings[0][:, 1], axis=1).shape

(5, 1)

In [129]:
def print_df(df, index):
    """Print each column of the row at position ``index`` as ``【column】: value``."""
    selected_row = df.iloc[index]
    for column_name, value in selected_row.items():
        print(f"【{column_name}】:", value)


# Presumably the output of an earlier preprocessing run, compared row-by-row below.
old_df = pd.read_csv(Path("preprocessed/000_base/000/train.csv"))

In [132]:
# Show the same row position from the current frame and from old_df.
index = 2
print_df(df, index=index)
print()
print_df(old_df, index=index)

【id】: 2
【prompt】: Which of the following statements accurately describes the origin and significance of the triskeles symbol?
【A】: The triskeles symbol was reconstructed as a feminine divine triad by the rulers of Syracuse, and later adopted as an emblem. Its usage may also be related to the Greek name of Sicily, Trinacria, which means "having three headlands." The head of Medusa at the center of the Sicilian triskeles represents the three headlands.
【B】: The triskeles symbol is a representation of three interlinked spirals, which was adopted as an emblem by the rulers of Syracuse. Its usage in modern flags of Sicily has its origins in the ancient Greek name for the island, Trinacria, which means "Sicily with three corners." The head of Medusa at the center is a representation of the island's rich cultural heritage.
【C】: The triskeles symbol is a representation of a triple goddess, reconstructed by the rulers of Syracuse, who adopted it as an emblem. Its significance lies in the fact t