# モジュール4: GraphRAG検索

このモジュールでは、TiDB Cloud Starterのベクトル検索を使って、サンプルのナレッジグラフ（エンティティとリレーションシップ）を検索します。

まずベクトル検索で質問に最も近い初期ノード（エンティティ）を見つけ、そこから1ホップ先の隣接ノードまでグラフを展開します。最後に、取得したエンティティをコンテキストとしてLLMに回答を生成させます。

## 依存関係のインストール

In [None]:
%pip install -q \
    pytidb==0.0.10.dev1 \
    boto3==1.38.23 \
    litellm \
    ipyplot \
    pandas

## TiDBクライアントの作成

In [None]:
import os

from typing import Optional, Any
from pytidb import TiDBClient
from pytidb.embeddings import EmbeddingFunction

# Connect to the TiDB Cloud Starter cluster using credentials taken from
# environment variables. The port falls back to 4000 (TiDB's standard SQL
# port) so an unset SERVERLESS_CLUSTER_PORT does not crash with `int(None)`.
client = TiDBClient.connect(
    host=os.getenv("SERVERLESS_CLUSTER_HOST"),
    port=int(os.getenv("SERVERLESS_CLUSTER_PORT", "4000")),
    username=os.getenv("SERVERLESS_CLUSTER_USERNAME"),
    password=os.getenv("SERVERLESS_CLUSTER_PASSWORD"),
    database=os.getenv("SERVERLESS_CLUSTER_DATABASE_NAME"),
    enable_ssl=True,
    # Create the database if it does not exist yet.
    ensure_db=True,
)

# Embedding model (Amazon Titan Text Embeddings v2 on Bedrock). The same
# function embeds entity descriptions and, in the search cell, the user's
# question, so both vectors live in the same space.
embedding_model = "bedrock/amazon.titan-embed-text-v2:0"

text_embedding_function = EmbeddingFunction(
    embedding_model,
    timeout=60,  # presumably seconds — NOTE(review): confirm against pytidb
)

## エンティティとリレーションシップの定義

In [None]:
from pytidb.schema import TableModel, Field
from sqlalchemy import TEXT, Column

class Entities(TableModel):
    """Graph nodes: one row per entity, plus a vector embedding of its description."""

    __tablename__ = "entities"
    # extend_existing lets SQLAlchemy accept a redefinition of this table in the
    # same metadata, e.g. when the notebook cell is re-run in one kernel.
    __table_args__ = {"extend_existing": True}
    id: int | None = Field(default=None, primary_key=True)
    # presumably mapped to VARCHAR(512) — NOTE(review): confirm generated DDL
    name: str = Field(max_length=512)
    # Stored as a NOT NULL TEXT column so long descriptions fit.
    description: str = Field(sa_column=Column(TEXT, nullable=False))
    # Embedding of `description` produced by text_embedding_function. The import
    # cell never sets it, so it is presumably filled by pytidb on insert — confirm.
    # The search cell ranks entities by VEC_Cosine_Distance on this column.
    description_vec: Optional[Any] = text_embedding_function.VectorField(
        source_field="description",
    )

class Relationships(TableModel):
    """Graph edges: a directed, described link from one entity to another."""

    __tablename__ = "relationships"
    # Allow the table to be redefined when the cell is re-run in the same kernel.
    __table_args__ = {"extend_existing": True}
    id: int | None = Field(default=None, primary_key=True)
    # Both ids refer to entities.id (joined that way in the search query);
    # no foreign-key constraint is declared here.
    source_entity_id: int
    target_entity_id: int
    # Free-text description of the relationship (NOT NULL TEXT column).
    relationship_desc: str = Field(sa_column=Column(TEXT, nullable=False))

# Create both tables. if_exists="overwrite" presumably drops and recreates any
# existing table, so previously loaded graph data is discarded on every run —
# NOTE(review): confirm against the pytidb create_table documentation.
entities_table = client.create_table(schema=Entities, if_exists="overwrite")
relationships_table = client.create_table(schema=Relationships, if_exists="overwrite")

## サンプルデータのインポート

In [None]:
import requests

entities_url = "https://gist.github.com/Icemap/6fa6a9088a3c9d2fd9990e2748e39a8a/raw/c42c723a9769dacbd6ac8e8326f0c8f199dd3c59/entities.json"
relationships_url = "https://gist.github.com/Icemap/7354ab8bb6b3ac08bc438f19cfc77a87/raw/65ae9e9770de6fba57d3615abbdcd79effb1545b/relationships.json"


def _fetch_json(url: str, timeout: float = 30) -> Any:
    """Download *url* and return its decoded JSON body.

    Raises requests.HTTPError on a non-2xx response, and requests.Timeout if
    the server does not respond within *timeout* seconds.
    """
    # Without a timeout, requests.get can block the notebook indefinitely.
    response = requests.get(url, timeout=timeout)
    # Fail on HTTP errors here instead of with a confusing JSON decode error.
    response.raise_for_status()
    return response.json()


# Graph nodes. description_vec is not set here; it is derived from description.
entities = [
    Entities(
        id=item.get('id'),
        name=item.get('name'),
        description=item.get('description')
    ) for item in _fetch_json(entities_url)
]

# Graph edges between the entities above, referenced by entity id.
relationships = [
    Relationships(
        id=item.get('id'),
        source_entity_id=item.get('source_entity_id'),
        target_entity_id=item.get('target_entity_id'),
        relationship_desc=item.get('relationship_desc')
    ) for item in _fetch_json(relationships_url)
]

inserted_entities = entities_table.bulk_insert(entities)
inserted_relationships = relationships_table.bulk_insert(relationships)

f"Inserted {len(inserted_entities)} entities and {len(inserted_relationships)} relationships"

## データの検索

TiDB Cloud Starterにデータをインポートした後、「Elon Muskとは誰ですか？」などの質問で検索できます。

In [None]:
from sqlalchemy import text

# Read the question and embed it with the same function used for
# entities.description_vec, so cosine distances are comparable.
question = input("Enter your question:").strip()
if not question:
    raise ValueError("Question must not be empty")
embedding = str(text_embedding_function.get_query_embedding(question))

# GraphRAG retrieval in one statement:
#   1. initial_entity: the single entity whose description embedding is
#      closest (smallest cosine distance) to the question.
#   2. entities_ids: that entity plus every entity directly linked to it by a
#      relationship in either direction (1-hop expansion). UNION de-duplicates.
#   3. Return those entities. Only readable columns are selected; the
#      description_vec embedding is large and not needed downstream.
query_sql = """
WITH initial_entity AS (
    SELECT id FROM `entities`
    ORDER BY VEC_Cosine_Distance(description_vec, :embedding) LIMIT 1
), entities_ids AS (
    SELECT r.source_entity_id AS entity_id
    FROM relationships r INNER JOIN initial_entity ie ON r.target_entity_id = ie.id
    UNION
    SELECT r.target_entity_id AS entity_id
    FROM relationships r INNER JOIN initial_entity ie ON r.source_entity_id = ie.id
    UNION
    SELECT ie.id AS entity_id FROM initial_entity ie
)
SELECT id, name, description FROM `entities`
WHERE id IN (SELECT entity_id FROM entities_ids);"""

# The embedding is passed as a bound parameter, not interpolated into the SQL.
result = client.query(sql=query_sql, params={"embedding": embedding})
result.to_pandas()

In [None]:
from litellm import completion

llm_model = "bedrock/us.amazon.nova-lite-v1:0"

# Turn the retrieved entities into plain-text context for the prompt.
# Serialize the rows explicitly: calling str() on the query-result object is
# not guaranteed to render its rows, which would leave the model without the
# retrieved subgraph.
entities_df = result.to_pandas()
context = "\n".join(
    f"- {name}: {description}"
    for name, description in zip(entities_df["name"], entities_df["description"])
) or "(no related entities were found)"

messages = [
    {
        "role": "system",
        "content": (
            "Please carefully answer the question using the following entities "
            f"retrieved from the knowledge graph:\n{context}"
        ),
    },
    {"role": "user", "content": question},
]

llm_response = completion(
    model=llm_model,
    messages=messages,
)

print(llm_response.choices[0].message.content)