In [1]:
# Work from the parent directory: later cells use paths relative to it (e.g. "preprocessed/...").
%cd ..
%pwd

/tmp/working


'/tmp/working'

In [13]:
from __future__ import annotations

import ctypes
import gc
import os
import re
import sys
import time
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path

import blingfire as bf
import faiss
import hydra
import numpy as np
import pandas as pd
import torch
from faiss import read_index, write_index
from hydra import compose, initialize, initialize_config_dir, initialize_config_module
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, OmegaConf
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm

libc = ctypes.CDLL("libc.so.6")

In [14]:
# Experiment id: selects the hydra `preprocess` config group entry and the output directory.
exp_name = "331/000"
preprocessed_path = Path(f"./preprocessed/{exp_name}")

with initialize(version_base=None, config_path="../yamls"):
    c = compose(config_name="config.yaml", overrides=[f"preprocess={exp_name}"])
    # Resolve interpolations in place (per the original note: values such as debug and seed).
    OmegaConf.resolve(c)

cfg = c.preprocess

print(cfg)
print("preprocessed_path:", preprocessed_path)

{'data_paths': ['preprocessed/901_concat/data2.csv', 'preprocessed/901_concat/data1.csv', 'preprocessed/901_concat/data0_0.csv', 'preprocessed/901_concat/data0_10000.csv', 'preprocessed/901_concat/data0_20000.csv', 'preprocessed/901_concat/data0_30000.csv', 'preprocessed/901_concat/data0_40000.csv', 'preprocessed/901_concat/data0_50000.csv', 'preprocessed/901_concat/data0_60000.csv', 'preprocessed/901_concat/data0_70000.csv', 'preprocessed/901_concat/data0_80000.csv', 'preprocessed/901_concat/data0_90000.csv'], 'wiki_dir': 'input/llm-science-wikipedia-data-b', 'wiki_index_path': 'preprocessed/320_doc_index/001/all.parquet', 'index_path': 'preprocessed/320_doc_index/001/ivfpq_index.faiss', 'sim_model': 'BAAI/bge-small-en', 'num_sentences_include': 20, 'max_length': 384, 'batch_size': 32, 'doc_top_k': 3, 'window_size': 5, 'sliding_size': 4, 'debug': False, 'seed': 7}
preprocessed_path: preprocessed/331/000


In [15]:
def extract_chunk_by_sliding_window(text_list: list[str], window_size: int, sliding_size: int) -> list[str]:
    """
    Join consecutive texts into overlapping chunks using a sliding window.

    A chunk starts at every ``sliding_size``-th position and contains up to
    ``window_size`` texts joined by single spaces. The last chunks may be shorter
    than ``window_size`` when they run past the end of ``text_list``.

    :param text_list: Texts (e.g. sentences) to group
    :param window_size: Maximum number of texts per chunk
    :param sliding_size: Step between the start positions of consecutive chunks
    :return: List of space-joined chunks (empty for empty input)
    """
    window_starts = range(0, len(text_list), sliding_size)
    return [" ".join(text_list[start : start + window_size]) for start in window_starts]


def extract_sections(
    text: str,
    excluded_titles: Iterable[str] = ("See also", "References", "Further reading", "External links"),
    min_words: int = 3,
) -> list[tuple[str, str]]:
    """
    Split a document into ``(heading, body)`` pairs.

    Headings are spans delimited by two or more ``#`` on each side (e.g. ``## History ##``).
    The text before the first heading is returned with an empty heading. A document with no
    heading at all becomes a single ``("", text)`` section. Sections whose heading is in
    ``excluded_titles`` are dropped. Sections whose stripped body has fewer than
    ``min_words`` whitespace-separated words are also dropped, and this filter applies
    uniformly, including to heading-less documents.

    :param text: Document text
    :param excluded_titles: Heading titles whose sections are discarded
    :param min_words: Minimum number of words a section body needs in order to be kept
    :return: List of ``(heading, body)`` tuples in document order
    """
    pattern = re.compile(r"#{2,}\s?(.*?)\s?#{2,}")
    excluded = set(excluded_titles)
    matches = list(pattern.finditer(text))

    # Preamble before the first heading (the whole text when there are no headings).
    first_heading_start = matches[0].start() if matches else len(text)
    sections = [("", text[:first_heading_start].strip())]

    for i, match in enumerate(matches):
        body_start = match.end()
        body_end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
        title = match.group(1).strip()
        if title not in excluded:
            sections.append((title, text[body_start:body_end].strip()))

    # Skip empty / near-empty sections. split() (not split(" ")) so newlines and repeated
    # spaces do not distort the word count.
    return [section for section in sections if len(section[1].split()) >= min_words]


def sentencize(
    titles: Iterable[str],
    documents: Iterable[str],
    document_ids: Iterable,
    window_size: int = 3,
    sliding_size: int = 2,
    filter_len: int = 5,
    filter_len_max: int = 500,
    disable_progress_bar: bool = False,
) -> pd.DataFrame:
    """
    Split documents into overlapping sentence chunks prefixed with their title.

    Newlines in each document are replaced by spaces before it is sentence-split with
    blingfire (``bf.text_to_sentences_and_offsets``). Sentences whose character length is
    strictly between ``filter_len`` and ``filter_len_max`` are kept. They are then grouped
    into chunks of up to ``window_size`` sentences, advancing ``sliding_size`` sentences at
    a time (via ``extract_chunk_by_sliding_window``).

    Documents that cannot be processed (e.g. non-string values such as NaN, or text the
    sentence splitter rejects) are skipped.

    :param titles: Iterable of document titles, aligned with ``documents``
    :param documents: Iterable containing documents which are strings
    :param document_ids: Iterable containing document unique identifiers
    :param window_size: Maximum number of sentences per chunk
    :param sliding_size: Step (in sentences) between consecutive chunks
    :param filter_len: Sentences must be longer than this many characters
    :param filter_len_max: Sentences must be shorter than this many characters
    :param disable_progress_bar: Flag to disable tqdm progress bar
    :return: DataFrame with columns `document_id`, `text` (``"{title} > {chunk}"``) and
        `offset` (always ``(0, 0)``); an empty frame still has these columns
    """
    # tqdm only needs a length hint; generators have none.
    total = len(documents) if hasattr(documents, "__len__") else None

    document_sentences = []
    for title, document, document_id in tqdm(
        zip(titles, documents, document_ids), total=total, disable=disable_progress_bar
    ):
        try:
            # Just in case: turn newlines into spaces before sentence splitting.
            document = document.replace("\n", " ")
            _, sentence_offsets = bf.text_to_sentences_and_offsets(document)
        except Exception:
            # Best-effort: skip documents that are not strings or cannot be split, while
            # still letting KeyboardInterrupt / SystemExit propagate.
            continue

        section_sentences = [
            document[start:end]
            for start, end in sentence_offsets
            if filter_len < end - start < filter_len_max
        ]
        for chunk in extract_chunk_by_sliding_window(section_sentences, window_size, sliding_size):
            document_sentences.append(
                {"document_id": document_id, "text": f"{title} > {chunk}", "offset": (0, 0)}
            )
    return pd.DataFrame(document_sentences, columns=["document_id", "text", "offset"])


def sectionize_documents(
    titles: Iterable[str],
    documents: Iterable[str],
    document_ids: Iterable,
    disable_progress_bar: bool = False,
) -> pd.DataFrame:
    """
    Wrap each document as one whole-document "section".

    No actual section splitting is done: every document becomes a single row whose
    `offset` spans the whole text, ``(0, len(document))``.

    :param titles: Accepted for signature parity with `sentencize`; only used to pair up
        the inputs (``zip`` stops at the shortest iterable)
    :param documents: Iterable containing documents which are strings
    :param document_ids: Iterable containing document unique identifiers
    :param disable_progress_bar: Flag to disable tqdm progress bar
    :return: DataFrame with columns `document_id`, `text`, `offset`, sorted by
        `document_id` then `offset`; an empty frame still has these columns
    """
    columns = ["document_id", "text", "offset"]
    # tqdm only needs a length hint; generators have none.
    total = len(documents) if hasattr(documents, "__len__") else None

    processed_documents = [
        {"document_id": document_id, "text": document, "offset": (0, len(document))}
        for _title, document_id, document in tqdm(
            zip(titles, document_ids, documents), total=total, disable=disable_progress_bar
        )
    ]
    if not processed_documents:
        return pd.DataFrame(columns=columns)
    return (
        pd.DataFrame(processed_documents, columns=columns)
        .sort_values(["document_id", "offset"])
        .reset_index(drop=True)
    )


def relevant_title_retrieval(
    df: pd.DataFrame,
    index_path: str,
    model: SentenceTransformer,
    top_k: int = 3,
    batch_size: int = 32,
    nprobe: int = 10,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Retrieve the ``top_k`` nearest index entries for each row's `prompt_answer_stem`.

    The FAISS index at ``index_path`` is read, copied to GPU 0 with float16 storage, and
    queried with L2-normalized float32 sentence embeddings. The GPU index and its resources
    are released before returning, even if the search fails.

    :param df: Frame with a `prompt_answer_stem` text column (one query per row)
    :param index_path: Path to a serialized FAISS index
    :param model: SentenceTransformer used to embed the queries (on device "cuda")
    :param top_k: Number of neighbours returned per query
    :param batch_size: Encoding batch size
    :param nprobe: Value assigned to the GPU index's `nprobe` (lists probed per query)
    :return: ``(search_score, search_index)`` as returned by ``index.search``; each has one
        row per query and ``top_k`` columns. FAISS pads missing hits with index -1.
    """
    cpu_index = read_index(index_path)
    res = faiss.StandardGpuResources()  # use a single GPU
    co = faiss.GpuClonerOptions()
    co.useFloat16 = True
    sentence_index = faiss.index_cpu_to_gpu(res, 0, cpu_index, co)
    del cpu_index  # only the GPU copy is needed from here on
    sentence_index.nprobe = nprobe

    prompt_embeddings = model.encode(
        df.prompt_answer_stem.values,
        batch_size=batch_size,
        device="cuda",
        show_progress_bar=True,
        normalize_embeddings=True,
    ).astype(np.float32)
    try:
        search_score, search_index = sentence_index.search(prompt_embeddings, top_k)
    finally:
        # Release the GPU index before the StandardGpuResources it was built on.
        del sentence_index
        res.noTempMemory()
        del res
        del prompt_embeddings
        _ = gc.collect()
        libc.malloc_trim(0)
    return search_score, search_index


def get_wikipedia_file_data(
    search_score: np.ndarray,
    search_index: np.ndarray,
    wiki_index_path: str,
) -> pd.DataFrame:
    """
    Map retrieved index hits to the wiki parquet files that contain them.

    :param search_score: Search scores (one row per prompt); not used for the mapping,
        kept for interface compatibility
    :param search_index: Hit ids, shape ``(n_prompts, top_k)``, used as row labels of the
        wiki index table. Negative entries (FAISS "no result" padding) are ignored.
    :param wiki_index_path: Parquet file with `id` and `file` columns; its row labels are
        assumed to equal the FAISS ids (default RangeIndex, TODO confirm)
    :return: Unique ``(id, prompt_id, file)`` rows sorted by `file`, `id`, where
        `prompt_id` is the row number in ``search_index``
    """
    wiki_index_df = pd.read_parquet(wiki_index_path, columns=["id", "file"])

    search_index = np.asarray(search_index)
    n_prompts, top_k = search_index.shape
    flat_index = search_index.reshape(-1)
    # Row i of search_index contributes top_k consecutive hits in the flattened order.
    prompt_ids = np.repeat(np.arange(n_prompts), top_k)
    valid = flat_index >= 0

    wikipedia_file_data = wiki_index_df.loc[flat_index[valid]].copy()
    wikipedia_file_data["prompt_id"] = prompt_ids[valid]
    wikipedia_file_data = (
        wikipedia_file_data[["id", "prompt_id", "file"]]
        .drop_duplicates()
        .sort_values(["file", "id"])
        .reset_index(drop=True)
    )
    ## Save memory - delete df since it is no longer necessary
    del wiki_index_df
    _ = gc.collect()
    libc.malloc_trim(0)
    return wikipedia_file_data


def get_full_text_data(
    wikipedia_file_data: pd.DataFrame,
    wiki_dir: str,
) -> pd.DataFrame:
    """
    Load `id`, `title`, `text` for every ``(file, id)`` pair in ``wikipedia_file_data``.

    Each parquet file under ``wiki_dir`` is read once and filtered to the requested ids.

    :param wikipedia_file_data: Frame with `file` and `id` columns
    :param wiki_dir: Directory containing the wiki parquet files
    :return: De-duplicated rows with columns `id`, `title`, `text` (an empty frame with
        these columns when there is nothing to load)
    """
    wiki_text_data = []
    # groupby(sort=False) visits files in first-appearance order (same as .unique()) without
    # re-scanning the whole frame once per file.
    ids_by_file = wikipedia_file_data.groupby("file", sort=False)["id"]
    for file, ids in tqdm(ids_by_file, total=ids_by_file.ngroups):
        # Requested ids are compared as strings (presumably the parquet `id` dtype; confirm).
        wanted_ids = [str(i) for i in ids.tolist()]
        _df = pd.read_parquet(f"{wiki_dir}/{file}", columns=["id", "title", "text"])
        wiki_text_data.append(_df[_df["id"].isin(wanted_ids)].copy())
        del _df
        _ = gc.collect()
        libc.malloc_trim(0)

    if not wiki_text_data:
        return pd.DataFrame(columns=["id", "title", "text"])
    wiki_text_data = pd.concat(wiki_text_data).drop_duplicates().reset_index(drop=True)
    _ = gc.collect()
    libc.malloc_trim(0)
    return wiki_text_data


def extract_contexts_from_matching_pairs(
    df: pd.DataFrame,
    processed_wiki_text_data: pd.DataFrame,
    wikipedia_file_data: pd.DataFrame,
    wiki_data_embeddings: np.ndarray,
    question_embeddings: np.ndarray,
    num_sentences_include: int = 5,
    max_total_words: int = 1000,
    max_distance: float = 1.0,
) -> dict[str, list]:
    """
    Build a retrieval context for every prompt from its candidate wiki chunks.

    For the prompt at row position ``p`` of ``df``, the candidates are the rows of
    ``processed_wiki_text_data`` whose `document_id` is among the `id` values that
    ``wikipedia_file_data`` lists with ``prompt_id == p``. Their embeddings (rows of
    ``wiki_data_embeddings`` at the same positions) go into an exact "Flat" FAISS index,
    which is searched with ``question_embeddings[p]`` for ``num_sentences_include``
    neighbours. The neighbours' texts are concatenated nearest-first. Concatenation stops
    once more than ``max_total_words`` words are collected, or at the first hit whose score
    is ``>= max_distance``.

    ``faiss.index_factory`` defaults to the L2 metric, so the scores are squared L2
    distances (smaller = closer) despite the ``sim_*`` key names.

    :param df: Prompts; only its length is used (row position = prompt id)
    :param processed_wiki_text_data: Chunk table with `document_id` and `text` columns
    :param wikipedia_file_data: Frame with `id` and `prompt_id` columns
    :param wiki_data_embeddings: Chunk embeddings, row-aligned with processed_wiki_text_data
    :param question_embeddings: Prompt embeddings, row-aligned with df
    :param num_sentences_include: Number of neighbours retrieved per prompt
    :param max_total_words: Stop adding chunks once the context exceeds this word count
    :param max_distance: Stop at the first neighbour with score >= this value
    :return: Dict of per-prompt lists: `contexts`, `sim_min`, `sim_max`, `sim_mean`,
        `sim_std` (NaN when no chunk was kept) and `sim_num` (number of chunks kept)
    :raises AssertionError: if a prompt has no candidate chunks
    """
    results = {"contexts": [], "sim_min": [], "sim_max": [], "sim_mean": [], "sim_std": [], "sim_num": []}
    chunk_document_ids = processed_wiki_text_data["document_id"]
    chunk_texts = processed_wiki_text_data["text"]
    dim = wiki_data_embeddings.shape[1]

    for prompt_id in tqdm(range(len(df)), total=len(df)):
        doc_ids = wikipedia_file_data.loc[wikipedia_file_data["prompt_id"] == prompt_id, "id"].values
        # Positional rows (not index labels) so embeddings and texts stay aligned.
        prompt_positions = np.flatnonzero(chunk_document_ids.isin(doc_ids).to_numpy())
        assert prompt_positions.shape[0] > 0, f"no candidate chunks for prompt {prompt_id}"

        prompt_index = faiss.index_factory(dim, "Flat")
        prompt_index.add(wiki_data_embeddings[prompt_positions])
        ## Get the top matches
        ss, ii = prompt_index.search(question_embeddings[np.newaxis, prompt_id], num_sentences_include)
        candidate_texts = chunk_texts.iloc[prompt_positions]

        context = ""
        total_len = 0
        num = 0
        for _s, _i in zip(ss[0], ii[0]):
            # _i < 0: FAISS padding when fewer candidates than requested neighbours.
            if _i < 0 or total_len > max_total_words or _s >= max_distance:
                break
            text = candidate_texts.iloc[_i]
            context += text + " "
            total_len += len(text.split(" "))
            num += 1

        kept_scores = ss[0][:num]
        results["contexts"].append(context)
        if num > 0:
            results["sim_max"].append(kept_scores.max())
            results["sim_min"].append(kept_scores.min())
            results["sim_mean"].append(kept_scores.mean())
            results["sim_std"].append(kept_scores.std())
        else:
            # Nearest chunk was already too far: reductions over an empty array would raise.
            for key in ("sim_max", "sim_min", "sim_mean", "sim_std"):
                results[key].append(float("nan"))
        results["sim_num"].append(num)

    return results

In [16]:
# Load the sentence-embedding model on GPU in fp16; cap its input at cfg.max_length tokens.
model = SentenceTransformer(cfg.sim_model, device="cuda").half()
model.max_seq_length = cfg.max_length

In [25]:
# Prompt file processed in this notebook (it is one of cfg.data_paths).
path = "preprocessed/901_concat/data2.csv"
answer_columns = ["A", "B", "C", "D", "E"]

# Load data. Read the prompt/answer columns as strings so numeric- or boolean-looking
# answers do not break the string concatenation below.
df = pd.read_csv(path, dtype={col: str for col in ["prompt", *answer_columns]})
df[answer_columns] = df[answer_columns].fillna("")
df = df.reset_index(drop=True)
if cfg.debug:
    df = df.head(15).copy()

# Retrieval query: prompt followed by all answer options, with straight/curly double quotes removed.
df["answer_all"] = df["A"].str.cat(df[answer_columns[1:]], sep=" ")
df["prompt_answer_stem"] = (df["prompt"] + " " + df["answer_all"]).str.replace('["“”]', "", regex=True)

# Title search
print("【title 検索】")
search_score, search_index = relevant_title_retrieval(
    df,
    cfg.index_path,
    model,
    top_k=cfg.doc_top_k,
    batch_size=cfg.batch_size,
)

【title 検索】


Batches:   0%|          | 0/38 [00:00<?, ?it/s]

In [27]:
# Retrieval scores for the first five prompts (one row per prompt, doc_top_k columns).
# NOTE(review): values increase along each row in the output, so they look like distances
# (smaller = closer) rather than similarities. Confirm the index metric.
search_score[:5, :]

array([[0.10069704, 0.12945175, 0.1297493 ],
       [0.1351223 , 0.15178347, 0.18356371],
       [0.12990236, 0.13123703, 0.14057088],
       [0.15044713, 0.16522694, 0.17006326],
       [0.14510965, 0.1493355 , 0.16133547]], dtype=float32)