In [1]:
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License.

In [2]:
import os

import pandas as pd
import tiktoken

from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.input.loaders.dfs import (
    store_entity_semantic_embeddings,
)
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.question_gen.local_gen import LocalQuestionGen
from graphrag.query.structured_search.local_search.mixed_context import (
    LocalSearchMixedContext,
)
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore


# print(graphrag.query.structured_search.local_search.search.__file__)

## Local Search Example

Local search method generates answers by combining relevant data from the AI-extracted knowledge-graph with text chunks of the raw documents. This method is suitable for questions that require an understanding of specific entities mentioned in the documents (e.g. What are the healing properties of chamomile?).

### Load text units and graph data tables as context for local search

- In this test we first load indexing outputs from parquet files to dataframes, then convert these dataframes into collections of data objects aligning with the knowledge model.

### Load tables to dataframes

In [3]:
INPUT_DIR = ".."
LANCEDB_URI = f"{INPUT_DIR}/lancedb"

COMMUNITY_REPORT_TABLE = "create_final_community_reports"
ENTITY_TABLE = "create_final_nodes"
ENTITY_EMBEDDING_TABLE = "create_final_entities"
RELATIONSHIP_TABLE = "create_final_relationships"
COVARIATE_TABLE = "create_final_covariates"
TEXT_UNIT_TABLE = "create_final_text_units"
COMMUNITY_LEVEL = 2

#### Read entities

In [35]:
# read nodes table to get community and degree data
file_name = '20240724-135713 graphrag'
# file_name = '20240730-161020场景'
file_name = '20240717-143037_stat'
file_name = '20240801-165406_pt1'
file_name = '20240806-175239_multi_hop'


INPUT_DIR = '../output/' + file_name + '/artifacts'
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
entity_embedding_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet")

entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)

# load description embeddings to an in-memory lancedb vectorstore
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
    collection_name="entity_description_embeddings",
)
description_embedding_store.connect(db_uri=LANCEDB_URI)
entity_description_embeddings = store_entity_semantic_embeddings(
    entities=entities, vectorstore=description_embedding_store
)

print(f"Entity count: {len(entity_df)}")
print(f"Entity embedding count: {len(entity_embedding_df)}")

print(len(entities))

Entity count: 63485
Entity embedding count: 12697
12697


#### Read relationships

In [7]:
relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)

print(f"Relationship count: {len(relationship_df)}")
relationship_df.head()

Relationship count: 14310


Unnamed: 0,source,target,weight,description,text_unit_ids,id,human_readable_id,source_degree,target_degree,rank
0,AMAZON,CYBER MONDAY,2.0,"Amazon, a leading global e-commerce platform, ...","[4d47a8c9f4375c9854bf8393b15a2b02, 5d47fbbf131...",83f75faf9ca34d978e0ca01f9675b40c,0,248,4,252
1,AMAZON,BLACK FRIDAY,5.0,"Amazon, a leading global e-commerce platform, ...","[4d47a8c9f4375c9854bf8393b15a2b02, 525df26e51c...",bd1a205e78434f10876a749ac6781f5e,1,248,38,286
2,AMAZON,ECHO SHOW,1.0,Amazon is selling the Echo Show at a discounte...,[5d47fbbf131614ef0068c8ec70807795],dba23b0bdac54aa08662853af68664fd,2,248,3,251
3,AMAZON,FEDERAL TRADE COMMISSION,1.0,Amazon is the subject of an antitrust lawsuit ...,[96f6cdb0a35bdc5aecc23eda5b02764b],80891ad22cc9488b97737d55da1599c0,3,248,1,249
4,AMAZON,17 STATE ATTORNEYS GENERAL,1.0,Amazon is the subject of an antitrust lawsuit ...,[96f6cdb0a35bdc5aecc23eda5b02764b],e6bf4e933bd94c6a9645e48c62f8d3ca,4,248,1,249


In [8]:
# covariate_df = pd.read_parquet(f"{INPUT_DIR}/{COVARIATE_TABLE}.parquet")

# claims = read_indexer_covariates(covariate_df)

# print(f"Claim records: {len(claims)}")
# covariates = {"claims": claims}

#### Read community reports

In [9]:
report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)

print(f"Report records: {len(report_df)}")
report_df.head()

Report records: 23


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  entity_df["community"] = entity_df["community"].fillna(-1)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  entity_df["community"] = entity_df["community"].astype(int)


Unnamed: 0,community,full_content,level,rank,title,rank_explanation,summary,findings,full_content_json,id
0,2128,# FTX: The Downfall of a Cryptocurrency Giant\...,4,8.5,FTX: The Downfall of a Cryptocurrency Giant,The impact severity rating is high due to the ...,"The community revolves around FTX, a once-prom...",[{'explanation': 'FTX misused customer funds f...,"{\n ""title"": ""FTX: The Downfall of a Crypto...",dab98538-8321-4c27-bb7b-b621ba0b3222
1,2129,"# Gary Wang, FTX, and the US Government\n\nThe...",4,8.5,"Gary Wang, FTX, and the US Government",The impact severity rating is high due to the ...,"The community is centered around Gary Wang, a ...","[{'explanation': 'Gary Wang, a co-founder of F...","{\n ""title"": ""Gary Wang, FTX, and the US Go...",048b3b83-69cb-4d16-b907-1e60d91b2d03
2,2130,# San Francisco 49ers: A Dominant Force in the...,4,8.5,San Francisco 49ers: A Dominant Force in the NFL,"The San Francisco 49ers' dominance in the NFL,...","The San Francisco 49ers, a professional Americ...",[{'explanation': 'The San Francisco 49ers are ...,"{\n ""title"": ""San Francisco 49ers: A Domina...",823911f9-54c6-4566-8466-caf7aa3fa5cc
3,2131,# Deebo Samuel and Elite NFL Wide Receivers\n\...,4,8.5,Deebo Samuel and Elite NFL Wide Receivers,The impact severity rating is high due to the ...,"The community is centered around Deebo Samuel,...","[{'explanation': 'Deebo Samuel, a wide receive...","{\n ""title"": ""Deebo Samuel and Elite NFL Wi...",67588ab3-2d31-45e8-86c7-a42077bbe5d3
4,2132,# Kansas City Chiefs: A Force in the NFL with ...,4,8.5,Kansas City Chiefs: A Force in the NFL with Ke...,The impact severity rating is high due to the ...,"The Kansas City Chiefs, a professional America...",[{'explanation': 'Taylor Swift's attendance at...,"{\n ""title"": ""Kansas City Chiefs: A Force i...",968d5b29-4376-4c1e-b04b-1bc83741ddaa


#### Read text units

In [10]:
text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)

print(f"Text unit records: {len(text_unit_df)}")
text_unit_df.head()

Text unit records: 2907


Unnamed: 0,id,text,n_tokens,document_ids,entity_ids,relationship_ids
0,5d47fbbf131614ef0068c8ec70807795,Title: 200+ of the best deals from Amazon's Cy...,600,[533008495e988874ea6a375c70494fc0],"[b45241d70f0e43fca764df95b2b81f77, 4119fd06010...","[83f75faf9ca34d978e0ca01f9675b40c, bd1a205e784..."
1,67df2caaa20f7bfb3e8e218e6053f767,ing better speeds and solid power for most eve...,600,[533008495e988874ea6a375c70494fc0],"[254770028d7a4fa9877da4ba0ad5ad21, 4a67211867e...","[920a801038aa44c3bd1299ffa4d96573, bb69abe9c81..."
2,79444a7de5e6387527c61893fbef5837,"Bank’s superannuation business, in its first ...",600,[533008495e988874ea6a375c70494fc0],"[deece7e64b2a4628850d4bb6e394a9c3, 3d0dcbc8971...","[a5249a16b66c4d49aaea10055d50fb7c, f050e94398e..."
3,96f6cdb0a35bdc5aecc23eda5b02764b,4 per cent after the Federal Trade Commission...,600,[533008495e988874ea6a375c70494fc0],"[b45241d70f0e43fca764df95b2b81f77, 17ed1d92075...","[80891ad22cc9488b97737d55da1599c0, e6bf4e933bd..."
4,c25dd421152695caa5cdabec58ee99b1,"retailers outside Amazon,” according to the e...",600,[533008495e988874ea6a375c70494fc0],"[b45241d70f0e43fca764df95b2b81f77, 6fae5ee1a83...","[5d6f3ea76c794a2caf3bc30abb069440, 28180879705..."


In [11]:
print(len(entities))
filtered_entities = list(filter(lambda x: x.description != '', entities))
print(len(filtered_entities))

12697
12271


### Convert .parquet file to .csv

In [12]:
def convert_parquet_to_csv(source_dir, target_dir):
    # 确保目标目录存在
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    # 遍历源目录中的所有文件
    for file_name in os.listdir(source_dir):
        if file_name.endswith('.parquet'):
            # 构建完整的文件路径
            file_path = os.path.join(source_dir, file_name)
            # 读取 Parquet 文件
            df = pd.read_parquet(file_path)
            
            # 构建目标文件路径
            target_file_path = os.path.join(target_dir, file_name.replace('.parquet', '.csv'))
            # 保存为 CSV 文件
            df.to_csv(target_file_path, index=False)
            print(f'Converted {file_path} to {target_file_path}')


source_directory = INPUT_DIR 
target_directory = f'{INPUT_DIR}/convert_csv'
# convert_parquet_to_csv(source_directory, target_directory)

In [13]:
api_key = os.environ["GRAPHRAG_API_KEY"]
llm_model = 'qwen2-instruct'
embedding_model ='bge-m3'
model_base_url = 'http://10.4.32.1:9997/v1'
embed_base_url = 'http://10.4.32.2:9997/v1'

llm = ChatOpenAI(
    api_key=api_key,
    model=llm_model,
    api_type=OpenaiApiType.OpenAI,  # OpenaiApiType.OpenAI or OpenaiApiType.AzureOpenAI
    max_retries=20,
    api_base=model_base_url
)

token_encoder = tiktoken.get_encoding("cl100k_base")

text_embedder = OpenAIEmbedding(
    api_key=api_key,
    api_base=embed_base_url,
    api_type=OpenaiApiType.OpenAI,
    model=embedding_model,
    # deployment_name=embedding_model,
    max_retries=20,
)

In [14]:
llm.generate(messages=[{"role": "user", "content": 'hello hihihi'}])

'Hello! How can I assist you today?'

### Create local search context builder

In [16]:
context_builder = LocalSearchMixedContext(
    community_reports=reports,
    text_units=text_units,
    entities=filtered_entities,
    relationships=relationships,
    # covariates=covariates,
    entity_text_embeddings=description_embedding_store,
    embedding_vectorstore_key=EntityVectorStoreKey.ID,  # if the vectorstore uses entity title as ids, set this to EntityVectorStoreKey.TITLE
    text_embedder=text_embedder,
    token_encoder=token_encoder,
)

### Create local search engine

In [33]:
# text_unit_prop: proportion of context window dedicated to related text units
# community_prop: proportion of context window dedicated to community reports.
# The remaining proportion is dedicated to entities and relationships. Sum of text_unit_prop and community_prop should be <= 1
# conversation_history_max_turns: maximum number of turns to include in the conversation history.
# conversation_history_user_turns_only: if True, only include user queries in the conversation history.
# top_k_mapped_entities: number of related entities to retrieve from the entity description embedding store.
# top_k_relationships: control the number of out-of-network relationships to pull into the context window.
# include_entity_rank: if True, include the entity rank in the entity table in the context window. Default entity rank = node degree.
# include_relationship_weight: if True, include the relationship weight in the context window.
# include_community_rank: if True, include the community rank in the context window.
# return_candidate_context: if True, return a set of dataframes containing all candidate entity/relationship/covariate records that
# could be relevant. Note that not all of these records will be included in the context window. The "in_context" column in these
# dataframes indicates whether the record is included in the context window.
# max_tokens: maximum number of tokens to use for the context window.


local_context_params = {
    "text_unit_prop": 0.5,
    "community_prop": 0,
    "conversation_history_max_turns": 5,
    "conversation_history_user_turns_only": True,
    "top_k_mapped_entities": 30,
    "top_k_relationships": 30,
    "include_entity_rank": True,
    "include_relationship_weight": True,
    "include_community_rank": False,
    "return_candidate_context": False,
    "embedding_vectorstore_key": EntityVectorStoreKey.ID,  # set this to EntityVectorStoreKey.TITLE if the vectorstore uses entity title as ids
    "max_tokens": 18_000,  # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 5000)
}

llm_params = {
    "max_tokens": 2_000,  # change this based on the token limit you have on your model (if you are using a model with 8k limit, a good setting could be 1000=1500)
    "temperature": 0.0,
}

search_engine = LocalSearch(
    llm=llm,
    context_builder=context_builder,
    token_encoder=token_encoder,
    llm_params=llm_params,
    context_builder_params=local_context_params,
    # response_type="multiple paragraphs",  # free form text describing the response type and format, can be anything, e.g. prioritized list, single paragraph, multiple paragraphs, multiple-page report
)

### Batch run on 50 questions

In [34]:
from utils import read_questions, create_empty_csv_file
from graph_rag_test import save_answer_data
import time

topic = 'multi_hop'
QUESTION_FILE = '../test/questions/multi_hop_questions.txt'

result_file = f'../test_results/{topic}/GraphRAG_results.csv'
columns = ['Question', 'Graph RAG answer']

questions = read_questions(QUESTION_FILE)
create_empty_csv_file(result_file, columns)


start_time = time.time()

for i, question in enumerate(questions):
    result = await search_engine.asearch(question)
    answer = result.response
    print(f'{i+1}/{len(questions)} ### A: {answer}')
    save_answer_data(question, answer, result_file)

end_time = time.time()
duration = end_time - start_time
avg_time = duration / len(questions)
print(f'total time: {duration}s')
print(f'avg time: {avg_time}s')



1/100 ### A: SAM BANKMAN-FRIED
2/100 ### A: DONALD TRUMP
3/100 ### A: SAM ALTMAN
4/100 ### A: Insufficient information
5/100 ### A: BetMGM Sportsbook
6/100 ### A: SAM BANKMAN-FRIED
7/100 ### A: Yes
8/100 ### A: Insufficient information
9/100 ### A: OpenAI
10/100 ### A: GOOGLE
11/100 ### A: Insufficient information
12/100 ### A: Insufficient information
13/100 ### A: Insufficient information
14/100 ### A: S
15/100 ### A: Yes
16/100 ### A: EU
17/100 ### A: Insufficient information
18/100 ### A: Insufficient information
19/100 ### A: Insufficient information
20/100 ### A: Insufficient information
21/100 ### A: Insufficient information
22/100 ### A: Insufficient information
23/100 ### A: Valve
24/100 ### A: Insufficient information
25/100 ### A: Insufficient information
26/100 ### A: Insufficient information
27/100 ### A: Insufficient information
28/100 ### A: Yes
29/100 ### A: Insufficient information
30/100 ### A: Insufficient information
31/100 ### A: Google
32/100 ### A: Yes

The Verge

### run on one sample questions

In [27]:
question = 'Is apple a company or type of food?'
print("Q: ", question)
result = await search_engine.asearch(question)
print(result.response)

Q:  Is apple a company or type of food?
company


#### Inspecting the context data used to generate the response

In [36]:
print(f'number of records: {len(result.context_data["entities"])}')
result.context_data["entities"][:5]

number of records: 57


Unnamed: 0,id,entity,description,number of relationships,in_context
0,1694,SPORTSBOOKS,Sportsbooks are establishments and businesses ...,8,True
1,131,SPORTING NEWS,Sporting News is a comprehensive media organiz...,31,True
2,1955,FANATICS SPORTSBOOK,"Fanatics Sportsbook, a betting company that op...",5,True
3,1954,KENTUCKY SPORTS BETTING,"Kentucky Sports Betting, a legal activity in t...",10,True
4,1753,NEW INFORMATION,"New information, such as player injuries, subs...",1,True


In [37]:
print(f'number of records: {len(result.context_data["relationships"])}')
result.context_data["relationships"].head(50)

KeyError: 'relationships'

In [30]:
result.context_data["reports"].head()

KeyError: 'reports'

In [32]:
result.context_data["sources"].head(50)

Unnamed: 0,id,text
0,1512,a 10-core GPU that costs extra on the 13-inch...
1,1510,3-10-26T18:17:19+00:00\nCategory: technology\n...
2,1100,and we’ve repeatedly reported that you should...
3,1270,that will be the case. Gurman pointed out th...
4,1269,is holding this event in the evening. It star...
5,1098,striking the same place twice. “It is extrao...


In [None]:
if "claims" in result.context_data:
    print(result.context_data["claims"].head())

### Question Generation

This function takes a list of user queries and generates the next candidate questions.

In [None]:
question_generator = LocalQuestionGen(
    llm=llm,
    context_builder=context_builder,
    token_encoder=token_encoder,
    llm_params=llm_params,
    context_builder_params=local_context_params,
)

In [None]:
question_history = [
    "固晶机有几种运动控制？",
    "有哪些运动场景需要用到音圈电机？",
]
candidate_questions = await question_generator.agenerate(
    question_history=question_history, context_data=None, question_count=50
)
print(*candidate_questions.response, sep='\n')

- 在哪些设备中，直线编码器与直流伺服驱动器共同用于实现高精度的运动控制？
- 手机相机模组组装设备中，哪些组件负责实现多轴运动控制？
- 直流伺服在手机相机模组组装设备中如何控制装配头的运动？
- 控制器在手机相机模组组装设备中扮演什么角色？
- XYZ轴机构在手机相机模组组装设备中如何实现多轴运动控制？
- 模组组装控制在手机相机模组组装设备中涉及哪些关键过程？
- 直线编码器在哪些设备中用于测量直线电机的位移和速度？
- 丝杆传动在点胶机中如何实现Z轴垂直运动？
- 直流伺服驱动器在点胶机中如何控制直线电机和旋转伺服电机？
- 旋转编码器在点胶机中如何检测旋转伺服电机的旋转角度和速度？
- 直线电机在点胶机中如何实现X、Y轴平面运动？
- 旋转伺服电机在点胶机中如何实现Z轴垂直运动？
- 读数头在编码器中起到什么作用？
- 装配头在手机相机模组组装设备中如何精确装配镜头和传感器？
- 直流伺服在哪些设备中用于控制多轴运动？
- 控制器在哪些设备中用于控制整体运行？
- XYZ轴机构在哪些设备中用于实现多轴运动控制？
- 模组组装控制在哪些设备中是关键过程？
- 直线编码器在哪些设备中用于测量位移和速度？
- 丝杆传动在哪些设备中用于实现Z轴垂直运动？
- 直流伺服驱动器在哪些设备中用于控制电机运行？
- 旋转编码器在哪些设备中用于检测旋转角度和速度？
- 直线电机在哪些设备中用于实现平面运动？
- 旋转伺服电机在哪些设备中用于实现垂直运动？
- 读数头在哪些设备中是编码器的组成部分？
- 装配头在哪些设备中用于精确装配组件？
- 直流伺服在哪些设备中用于控制装配头的多轴运动？
- 控制器在哪些设备中用于控制装配头的运行？
- XYZ轴机构在哪些设备中用于实现装配头的多轴运动？
- 模组组装控制在哪些设备中用于精确装配组件？
- 直线编码器在哪些设备中用于测量直线电机的位移和速度？
- 丝杆传动在哪些设备中用于实现Z轴垂直运动？
- 直流伺服驱动器在哪些设备中用于控制点胶机的电机运行？
- 旋转编码器在哪些设备中用于检测点胶机中旋转伺服电机的旋转角度和速度？
- 直线电机在哪些设备中用于实现点胶机的X、Y轴平面运动？
- 旋转伺服电机在哪些设备中用于实现点胶机的Z轴垂直运动？
- 读数头在哪些设备中是编码器的组成部分？
- 装配头在哪些设备中用于精确装配镜头和传