In [1]:
import hashlib
from pathlib import Path

import joblib
import rootutils
from langchain.document_loaders import WikipediaLoader
from langchain.text_splitter import TokenTextSplitter
from langchain_community.graphs import Neo4jGraph
from langchain_core.documents import Document
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_google_vertexai import VertexAI
from loguru import logger

# Resolve the project root starting from "."; the returned path is this cell's output.
# NOTE(review): cwd=True / dotenv=True presumably make that root the working directory
# and load its .env file, which the relative cache path and the Neo4j / Vertex AI
# clients below would rely on — confirm against the rootutils documentation.
rootutils.setup_root(".", cwd=True, dotenv=True)

PosixPath('/workspace')

In [10]:
# Keyword arguments unpacked into download_wikipedia_documents().
# The query is the Japanese Wikipedia title for "optical lattice clock".
WIKI_QUERY = dict(query="光格子時計", lang="ja")

# Settings handed to TokenTextSplitter, plus how many downloaded articles to keep.
CHUNK_SIZE = 1024
CHUNK_OVERLAP = 24
MAX_DOCUMENTS = 1

# Vertex AI model passed to the graph transformer; temperature is pinned to 0.
LLM = VertexAI(
    model_name="gemini-1.5-flash-001",
    temperature=0,
    max_output_tokens=8192,
)

In [11]:
def download_wikipedia_documents(
    query: str,
    lang: str,
    cache_dir: str | Path = "data/cache/wikipedia",
    **kwargs,
) -> list[Document]:
    """Download Wikipedia documents for a query, caching the result on disk.

    Args:
        query: Search query passed to ``WikipediaLoader``.
        lang: Wikipedia language code passed to ``WikipediaLoader``.
        cache_dir: Directory holding the cached pickle files; created if missing.
        **kwargs: Extra options forwarded to ``WikipediaLoader``. They are part
            of the cache key, so changing them triggers a fresh download instead
            of returning documents that were cached under different options.

    Returns:
        The documents produced by ``WikipediaLoader.load()``, or the previously
        cached copy when one exists for the same query, language and options.
    """
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    # A path separator inside the query would point the cache file into a
    # sub-directory that does not exist, so neutralise it in the file name only.
    cache_stem = f"{query}_{lang}".replace("/", "_").replace("\\", "_")
    if kwargs:
        # Calls without extra options keep the original "<query>_<lang>.pkl" name,
        # so caches written before this change are still found. Calls with options
        # get a short digest appended so each option set has its own cache file.
        options_repr = repr(sorted(kwargs.items()))
        options_digest = hashlib.sha256(options_repr.encode("utf-8")).hexdigest()[:12]
        cache_stem = f"{cache_stem}_{options_digest}"
    cache_filepath = cache_dir / f"{cache_stem}.pkl"

    if cache_filepath.exists():
        logger.info(f"Loading cached Wikipedia documents from {cache_filepath}")
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code. Only point cache_dir at files this function wrote.
        return joblib.load(cache_filepath)

    logger.info(f"Downloading Wikipedia documents for query '{query}' in language '{lang}'")
    loader = WikipediaLoader(query=query, lang=lang, **kwargs)
    documents = loader.load()

    # Dump to a temporary file and rename it into place, so an interrupted run
    # cannot leave a truncated pickle that every later call would try to load.
    tmp_filepath = cache_filepath.with_name(f"{cache_filepath.name}.tmp")
    joblib.dump(documents, tmp_filepath)
    tmp_filepath.replace(cache_filepath)
    return documents

In [12]:
# Fetch the articles for WIKI_QUERY (served from the on-disk cache when present).
raw_documents = download_wikipedia_documents(**WIKI_QUERY)

# The gpt-2 tokenizer is used here regardless of which model is used.
text_splitter = TokenTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)

# Only the first MAX_DOCUMENTS articles are split into chunks.
documents = text_splitter.split_documents(raw_documents[:MAX_DOCUMENTS])

[32m2024-07-15 19:46:13.444[0m | [1mINFO    [0m | [36m__main__[0m:[36mdownload_wikipedia_documents[0m:[36m13[0m - [1mLoading cached Wikipedia documents from data/cache/wikipedia/光格子時計_ja.pkl[0m


In [13]:
# Build the LLM-backed transformer with a fixed set of node labels; no
# relationship list is supplied yet (see the TODO below).
llm_transformer = LLMGraphTransformer(
    llm=LLM,
    allowed_nodes=[
        "PERSON", "ORGANIZATION", "LOCATION", "OBJECT",
        "EVENT", "DATE", "MONEY", "QUANTITY",
        "USAGE", "INVENTION", "THEORY",
    ],
    allowed_relationships=None,  # TODO: specify the allowed relationships here
)

# Convert the text chunks into graph documents.
graph_documents = llm_transformer.convert_to_graph_documents(documents)

In [14]:
# Persist the extracted graph documents to Neo4j.
# NOTE(review): Neo4jGraph() is given no arguments, so its connection settings
# presumably come from the environment — confirm they are set before running.
graph = Neo4jGraph()
graph.add_graph_documents(
    graph_documents,
    baseEntityLabel=True,
    include_source=True,
)

In [16]:
# Destructive reset: uncomment the next line to delete every node (and its relationships) from the Neo4j database.
# graph.query("MATCH (n) DETACH DELETE n")

[]