In [None]:
from pathlib import Path

import joblib
import rootutils
from langchain.document_loaders import WikipediaLoader
from langchain.text_splitter import TokenTextSplitter
from langchain_community.graphs import Neo4jGraph
from langchain_core.documents import Document
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_google_vertexai import VertexAI
from loguru import logger

# Project bootstrap: search for the project root starting from "." and, per the flags,
# make it the working directory (cwd=True) and load a .env file (dotenv=True).
# NOTE(review): the relative cache path used later and the Neo4j / Vertex AI credentials
# presumably rely on this call having run first — confirm against the rootutils docs.
rootutils.setup_root(".", cwd=True, dotenv=True)

In [None]:
# Wikipedia search settings; unpacked as keyword arguments when the article is downloaded.
# The query is Japanese for "optical lattice clock".
WIKI_QUERY = dict(query="光格子時計", lang="ja")

# Chunking parameters handed to the token-based text splitter.
CHUNK_SIZE = 3072
CHUNK_OVERLAP = 24

# How many of the (sorted) Wikipedia pages are actually split and processed.
MAX_DOCUMENTS = 1

# Language model used for graph extraction; temperature 0 is requested for repeatable output.
LLM = VertexAI(
    model_name="gemini-1.5-flash-001",
    temperature=0,
    max_output_tokens=8192,
)

In [None]:
def download_wikipedia_documents(
    query: str,
    lang: str,
    cache_dir: str | Path = "data/cache/wikipedia",
    **kwargs,
) -> list[Document]:
    """Download Wikipedia documents for a query, caching the result on disk.

    Args:
        query: Search string passed to ``WikipediaLoader``.
        lang: Wikipedia language code passed to ``WikipediaLoader``.
        cache_dir: Directory that holds the pickle cache; created if missing.
        **kwargs: Extra keyword arguments forwarded to ``WikipediaLoader``.
            They take part in the cache key, so calls with different loader
            options never share a cache file.

    Returns:
        The documents produced by ``WikipediaLoader.load()``, or the previously
        cached list for the same query, language and loader options.
    """
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    # A path separator inside the query would make the cache path point into a
    # non-existent subdirectory, so neutralise it for the file name.
    safe_query = query.replace("/", "_").replace("\\", "_")
    cache_stem = f"{safe_query}_{lang}"
    if kwargs or safe_query != query:
        # Disambiguate with a digest of the real query and the loader options.
        # Plain queries without options keep the historical "<query>_<lang>.pkl"
        # name, so caches written before this change are still found.
        cache_stem = f"{cache_stem}_{joblib.hash((query, kwargs))}"
    cache_filepath = cache_dir / f"{cache_stem}.pkl"

    if cache_filepath.exists():
        logger.info(f"Loading cached Wikipedia documents from {cache_filepath}")
        # NOTE: joblib.load unpickles the file; only point cache_dir at trusted locations.
        return joblib.load(cache_filepath)

    logger.info(f"Downloading Wikipedia documents for query '{query}' in language '{lang}'")
    loader = WikipediaLoader(query=query, lang=lang, **kwargs)
    documents = loader.load()
    joblib.dump(documents, cache_filepath)
    return documents

In [None]:
# Fetch the Wikipedia pages for the configured query (served from the local cache when present).
raw_documents = download_wikipedia_documents(**WIKI_QUERY)

# Move the page whose metadata title equals the query to the front.
# The key is False for an exact title match and True otherwise; False sorts first,
# and sorted() is stable, so the remaining pages keep their original order.
raw_documents = sorted(
    raw_documents,
    key=lambda doc: doc.metadata["title"] != WIKI_QUERY["query"],
)

# The gpt-2 tokenizer is used for counting tokens regardless of which LLM is configured.
text_splitter = TokenTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)

# Only the first MAX_DOCUMENTS pages are split into chunks.
documents = text_splitter.split_documents(raw_documents[:MAX_DOCUMENTS])

In [None]:
# Node labels the LLM is allowed to emit (English gloss in the trailing comment).
node_labels = [
    "人物",  # person
    "組織",  # organization
    "場所",  # place
    "物体",  # object
    "出来事",  # event
    "日付",  # date
    "お金",  # money
    "数量",  # quantity
    "使用",  # usage
    "発明",  # invention
    "理論",  # theory
    "方法",  # method
    "材料",  # material
    "概念",  # concept
    "行動",  # action
    "状態",  # state
    "数値",  # numeric value
    "時間",  # time
]

# Relationship types the LLM is allowed to emit (English gloss in the trailing comment).
relationship_types = [
    "測定する",  # measures
    "使用する",  # uses
    "検出する",  # detects
    "必要とする",  # requires
    "含む",  # includes
    "生産する",  # produces
    "維持する",  # maintains
    "校正する",  # calibrates
    "接続する",  # connects
    "サポートする",  # supports
    "改善する",  # improves
    "安定化する",  # stabilizes
    "同期する",  # synchronizes
    "生成する",  # generates
    "転送する",  # transfers
    "操作する",  # operates
    "調整する",  # adjusts
    "定義する",  # defines
    "検証する",  # verifies
    "強化する",  # strengthens
    "規制する",  # regulates
    "制御する",  # controls
    "分析する",  # analyzes
    "最適化する",  # optimizes
    "通信する",  # communicates
    "内包する",  # contains
    "構成する",  # constitutes
    "発明する",  # invents
]

# Build the transformer restricted to the vocabulary above, then convert every
# text chunk into a graph document (this is the step that calls the LLM).
llm_transformer = LLMGraphTransformer(
    llm=LLM,
    allowed_nodes=node_labels,
    allowed_relationships=relationship_types,
)
graph_documents = llm_transformer.convert_to_graph_documents(documents)

In [None]:
# Persist the extracted graph in Neo4j.
# Neo4jGraph() is created without arguments, so the connection settings are
# presumably taken from environment variables — confirm they are set (e.g. via .env).
graph = Neo4jGraph()
graph.add_graph_documents(
    graph_documents,
    include_source=True,
    baseEntityLabel=True,
)

In [None]:
# Reset the graph (destructive): uncomment the query below to delete every node and its relationships
# graph.query("MATCH (n) DETACH DELETE n")