<a href="https://colab.research.google.com/github/tomasonjo/blogs/blob/master/llm/ms_graphrag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install --quiet langchain-community langchain-experimental langchain-openai neo4j graphdatascience

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.5/199.5 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m203.0/203.0 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m973.5/973.5 kB[0m [31m26.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m310.2/310.2 kB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import getpass
import os

# Prompt for the OpenAI key only when the environment does not already provide it,
# so a re-run (or a run with the key exported) does not block on input.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API key: ")

··········


In [3]:
from langchain_community.graphs import Neo4jGraph

import getpass
import os

# Neo4j connection settings come from the environment; anything missing is asked
# for interactively. Real hosts and passwords must never be committed in the notebook.
os.environ.setdefault("NEO4J_USERNAME", "neo4j")
if not os.environ.get("NEO4J_URI"):
    os.environ["NEO4J_URI"] = input("Neo4j URI (e.g. bolt://<host>:7687): ")
if not os.environ.get("NEO4J_PASSWORD"):
    os.environ["NEO4J_PASSWORD"] = getpass.getpass("Neo4j password: ")

# Neo4jGraph picks up NEO4J_URI / NEO4J_USERNAME / NEO4J_PASSWORD set above.
graph = Neo4jGraph()

In [4]:
import pandas as pd

NEWS_CSV_URL = "https://raw.githubusercontent.com/tomasonjo/blog-datasets/main/news_articles.csv"

# Load the sample news-article dataset and preview the first rows.
news = pd.read_csv(NEWS_CSV_URL)
news.head()

Unnamed: 0,title,date,text
0,Chevron: Best Of Breed,2031-04-06T01:36:32.000000000+00:00,JHVEPhoto Like many companies in the O&G secto...
1,FirstEnergy (NYSE:FE) Posts Earnings Results,2030-04-29T06:55:28.000000000+00:00,FirstEnergy (NYSE:FE – Get Rating) posted its ...
2,Dáil almost suspended after Sinn Féin TD put p...,2023-06-15T14:32:11.000000000+00:00,The Dáil was almost suspended on Thursday afte...
3,Epic’s latest tool can animate hyperrealistic ...,2023-06-15T14:00:00.000000000+00:00,"Today, Epic is releasing a new tool designed t..."
4,"EU to Ban Huawei, ZTE from Internal Commission...",2023-06-15T13:50:00.000000000+00:00,The European Commission is planning to ban equ...


# Entity extraction

In [5]:
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_openai import ChatOpenAI

# GPT-4o at temperature 0; reused later for the community summaries.
llm = ChatOpenAI(model_name="gpt-4o", temperature=0)

# Besides nodes and relationships, ask the transformer for a free-text
# `description` property on each extracted node.
llm_transformer = LLMGraphTransformer(
    llm=llm,
    node_properties=["description"],
)

In [6]:
from typing import List
from langchain_community.graphs.graph_document import GraphDocument
from langchain_core.documents import Document

def process_text(text: str) -> List[GraphDocument]:
    """Extract a knowledge graph from a single piece of text.

    The text is wrapped in a LangChain ``Document`` and passed to the
    module-level ``llm_transformer``.

    Args:
        text: Raw text to extract from (title + body of an article here).

    Returns:
        The ``GraphDocument`` list produced by the transformer.
    """
    return llm_transformer.convert_to_graph_documents([Document(page_content=text)])

In [7]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm  # Import tqdm for progress tracking

MAX_WORKERS = 10
NUM_ARTICLES = 25

articles = news.head(NUM_ARTICLES)
# Title and body are concatenated so entities mentioned in the title are extracted too.
article_texts = [
    f"{title} {text}" for title, text in zip(articles["title"], articles["text"])
]

graph_documents = []
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    # Extraction runs in parallel, but results are consumed in submission order.
    # Iterating with `as_completed` made the order of `graph_documents` (and so
    # the order they are written below) depend on thread timing, which differs
    # between runs.
    # NOTE(review): when the same entity appears in several articles, write order
    # may decide which `description` is kept; confirm against Neo4jGraph.
    futures = [executor.submit(process_text, text) for text in article_texts]
    for future in tqdm(futures, total=len(futures), desc="Processing documents"):
        graph_documents.extend(future.result())

# baseEntityLabel adds the `__Entity__` label queried below; include_source also
# stores the source documents alongside the extracted entities.
graph.add_graph_documents(
    graph_documents,
    baseEntityLabel=True,
    include_source=True,
)

Processing documents: 100%|██████████| 25/25 [00:36<00:00,  1.44s/it]


In [8]:
# Sanity check: number of extracted entities, and how many received a description.
entity_stats_query = """
MATCH (n:`__Entity__`)
RETURN count(*) AS entity_count,
       count(n.description) AS non_null_descriptions
"""
graph.query(entity_stats_query)

[{'entity_count': 244, 'non_null_descriptions': 92}]

In [9]:
from langchain_community.vectorstores import Neo4jVector
from langchain_openai import OpenAIEmbeddings

# Embed each `__Entity__` node from its `id` and `description` text and store the
# vector in the node's `embedding` property (projected and used by KNN below).
entity_embeddings = OpenAIEmbeddings()
vector = Neo4jVector.from_existing_graph(
    entity_embeddings,
    node_label="__Entity__",
    embedding_node_property="embedding",
    text_node_properties=["id", "description"],
)

# Entity resolution

In [10]:
from graphdatascience import GraphDataScience
# Connect the Graph Data Science client to the same Neo4j database used by `graph`.
neo4j_auth = (os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWORD"])
gds = GraphDataScience(os.environ["NEO4J_URI"], auth=neo4j_auth)

In [12]:
# In-memory GDS graph "entities": all `__Entity__` nodes with their `embedding`
# property, plus every relationship between them.
entity_projection = gds.graph.project(
    "entities",
    "__Entity__",
    "*",
    nodeProperties=["embedding"],
)
G, result = entity_projection

In [13]:
# Entity pairs whose embedding similarity is at least this value become candidate
# duplicates (SIMILAR relationships with a `score` property in the projection).
SIMILARITY_CUTOFF = 0.97
# KNN samples candidate neighbours randomly. A fixed seed makes the candidate set
# reproducible; GDS only honours randomSeed when concurrency is 1.
KNN_RANDOM_SEED = 42

gds.knn.mutate(
    G,
    nodeProperties=["embedding"],
    mutateRelationshipType="SIMILAR",
    mutateProperty="score",
    similarityCutoff=SIMILARITY_CUTOFF,
    randomSeed=KNN_RANDOM_SEED,
    concurrency=1,
)

ranIterations                                                             6
nodePairsConsidered                                                   82718
didConverge                                                            True
preProcessingMillis                                                       0
computeMillis                                                           826
mutateMillis                                                             14
postProcessingMillis                                                      0
nodesCompared                                                           244
relationshipsWritten                                                     48
similarityDistribution    {'min': 0.9708518981933594, 'p5': 0.9713249206...
configuration             {'mutateProperty': 'score', 'jobId': 'f961b687...
Name: 0, dtype: object

In [14]:
# Group entities linked by SIMILAR edges into weakly connected components and write
# each node's component id back to the database as `wcc`.
gds.wcc.write(G, relationshipTypes=["SIMILAR"], writeProperty="wcc")

writeMillis                                                            176
nodePropertiesWritten                                                  244
componentCount                                                         227
componentDistribution    {'min': 1, 'p5': 1, 'max': 6, 'p999': 6, 'p99'...
postProcessingMillis                                                     2
preProcessingMillis                                                      0
computeMillis                                                            4
configuration            {'writeProperty': 'wcc', 'jobId': '6e3f1124-76...
Name: 0, dtype: object

In [15]:
# Show the candidate duplicate groups: WCC components containing more than one entity.
duplicate_groups_query = """MATCH (e:`__Entity__`)
    WITH e.wcc AS community, collect(e) AS nodes, count(*) AS count
    WHERE count > 1
    RETURN [n IN nodes | n.id] AS duplicates
    """
graph.query(duplicate_groups_query)

[{'duplicates': ['Beijing', 'China']},
 {'duplicates': ['European Commission', 'European Union']},
 {'duplicates': ['Man United', 'Manchester United']},
 {'duplicates': ['Chingona Ventures', 'Samara Hernandez']},
 {'duplicates': ['S&P 500', 'S&P Global Inc.']},
 {'duplicates': ['Metahuman', 'Metahuman Hub']},
 {'duplicates': ['Square Enix', 'Star Ocean: First Departure R']},
 {'duplicates': ['Carmax Auto Funding Llc', 'Carmax Business Services, Llc']},
 {'duplicates': ['Vivo X90', 'Vivo X90 Pro']},
 {'duplicates': ['Mediatek Dimensity 9200+ Soc',
   'Mediatek Dimensity 9200 Soc']},
 {'duplicates': ['Wi-Fi 7', 'Wi-Fi 6']},
 {'duplicates': ['Spotify',
   'Gaana',
   'Jiosaavn',
   'Google Podcasts',
   'Apple Podcasts',
   'Amazon Music']},
 {'duplicates': ['Fastag', 'Fastag System']}]

In [16]:
# Merge each candidate duplicate group (WCC component with more than one entity)
# into a single node via APOC; the result is the number of groups merged.
# NOTE(review): every group is merged unconditionally. The groups listed by the
# previous query include related-but-distinct entities (e.g. 'Beijing'/'China',
# 'Wi-Fi 7'/'Wi-Fi 6', and several different streaming services), and those get
# collapsed as well. Consider vetting the groups (manually or with an LLM) first.
merge_duplicates_query = """MATCH (e:`__Entity__`)
    WITH e.wcc AS community, collect(e) AS nodes, count(*) AS count
    WHERE count > 1
    CALL apoc.refactor.mergeNodes(nodes)
    YIELD node
    RETURN count(*)
    """
graph.query(merge_duplicates_query)

[{'count(*)': 13}]

In [17]:
# Release the in-memory "entities" projection. The name is reused below for the
# community-detection projection, which is built from the merged entities.
G.drop()

graphName                                                         entities
database                                                             neo4j
databaseLocation                                                     local
memoryUsage                                                               
sizeInBytes                                                             -1
nodeCount                                                              244
relationshipCount                                                      259
configuration            {'relationshipProjection': {'__ALL__': {'aggre...
density                                                           0.004368
creationTime                           2024-06-03T20:16:15.639848763+00:00
modificationTime                       2024-06-03T20:16:21.285915603+00:00
schema                   {'graphProperties': {}, 'nodes': {'__Entity__'...
schemaWithOrientation    {'graphProperties': {}, 'nodes': {'__Entity__'...
Name: 0, dtype: object

# Calculating communities of entities

In [18]:
# Re-project the de-duplicated entities, treating every relationship type as
# undirected for community detection.
undirected_all_rels = {"_ALL_": {"type": "*", "orientation": "UNDIRECTED"}}
G, result = gds.graph.project("entities", "__Entity__", undirected_all_rels)


In [19]:
# Leiden is randomised. Fixing the seed makes the community assignment, and so every
# community summary built from it, reproducible across runs.
LEIDEN_RANDOM_SEED = 42

# includeIntermediateCommunities=True writes one community id per level as a list
# in `communities`; index 0 is used as the first level and -1 as the last below.
gds.leiden.write(
    G,
    writeProperty="communities",
    includeIntermediateCommunities=True,
    randomSeed=LEIDEN_RANDOM_SEED,
)

writeMillis                                                            184
nodePropertiesWritten                                                  227
ranLevels                                                                3
didConverge                                                           True
nodeCount                                                              227
communityCount                                                          44
communityDistribution    {'min': 1, 'p5': 1, 'max': 17, 'p999': 17, 'p9...
modularity                                                        0.958559
modularities             [0.8402214685204736, 0.9548415354551786, 0.958...
postProcessingMillis                                                     3
preProcessingMillis                                                      0
computeMillis                                                         1048
configuration            {'writeProperty': 'communities', 'theta': 0.01...
Name: 0, dtype: object

# Build community summaries

In [20]:
# Create one Community node per first-level Leiden community (ids prefixed "1-")
# and link every entity to its community with PART_OF.
first_level_communities_query = """
MATCH (e:`__Entity__`)
WITH e, '1-' + e.communities[0] AS communityId // first smaller one community
MERGE (c:Community {id:communityId})
MERGE (e)-[:PART_OF]->(c)
"""
graph.query(first_level_communities_query)

[]

In [22]:
# For every first-level community with at least two entities, collect the
# relationships among its members: the subgraph reachable from one member while
# staying inside the community's node set.
first_level_info_query = """
MATCH (e:`__Entity__`)
WITH '1-' + e.communities[0] AS communityId, collect(e) AS nodes
WHERE size(nodes) > 1
CALL apoc.path.subgraphAll(nodes[0], {
	whitelistNodes:nodes
})
YIELD relationships
RETURN communityId, [r in relationships | {start: startNode(r).id, type: type(r), end: endNode(r).id}] AS rels
"""
community_info = graph.query(first_level_info_query)

In [27]:
# Peek at one community payload: the relationships that will be summarised.
sample_community = community_info[5]
sample_community

{'communityId': '1-18',
 'rels': [{'start': 'Ryanair', 'type': 'TERMINATED', 'end': 'Aidan Murray'},
  {'start': 'Ryanair', 'type': 'EMPLOYED', 'end': 'Aidan Murray'},
  {'start': 'Aidan Murray', 'type': 'REPORTED_BY', 'end': 'The Independent'},
  {'start': 'Aidan Murray', 'type': 'HARASSED', 'end': 'Ryanair'},
  {'start': 'Aidan Murray', 'type': 'TERMINATED_BY', 'end': 'Darrell Hughes'},
  {'start': 'Aidan Murray',
   'type': 'REPORTED_BY',
   'end': 'The Financial Times'},
  {'start': 'Aidan Murray', 'type': 'CONTACTED_BY', 'end': 'Bbc'},
  {'start': 'Royal Aeronautical Society',
   'type': 'REPORTED',
   'end': 'Ryanair'}]}

In [28]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Human message: the community's relationships, one "start-TYPE->end" triple per line.
community_template = """Based on the provided triples that belong to the same graph community,
generate a natural language summary of the provided information:
{community_info}

Summary:"""  # noqa: E501

community_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            # Grammar fixed ("an input triples", "pre-amble") so the instruction is unambiguous.
            "Given the input triples, generate an information summary. No preamble.",
        ),
        ("human", community_template),
    ]
)

# prompt -> LLM -> plain-string summary
community_chain = community_prompt | llm | StrOutputParser()

In [29]:
def process_community(community):
    """Summarise one community's relationships with the LLM chain.

    Args:
        community: A row from the community query, holding a ``communityId`` and a
            ``rels`` list of dicts with ``start``, ``type`` and ``end`` keys.

    Returns:
        ``{"community": <communityId>, "summary": <text returned by the chain>}``.
    """
    triples = (f"{rel['start']}-{rel['type']}->{rel['end']}" for rel in community['rels'])
    stringify_info = "\n".join(triples)
    summary_text = community_chain.invoke({'community_info': stringify_info})
    return {"community": community['communityId'], "summary": summary_text}

# Summarise all first-level communities concurrently. Collection order does not
# matter: the summaries are written back by community id.
summaries = []
with ThreadPoolExecutor() as pool:
    pending = [pool.submit(process_community, info) for info in community_info]
    progress = tqdm(as_completed(pending), total=len(pending), desc="Processing communities")
    for done in progress:
        summaries.append(done.result())

Processing communities: 100%|██████████| 63/63 [00:10<00:00,  6.00it/s]


In [30]:
# Persist each summary on its Community node (MERGE creates the node if missing).
store_summaries_query = """
UNWIND $data AS row
MERGE (c:Community {id:row.community})
SET c.summary = row.summary
"""
graph.query(store_summaries_query, params={"data": summaries})

[]

In [32]:
# Create second-level Community nodes (ids prefixed "2-", taken from the last
# Leiden level) and link each first-level community to its parent with PART_OF.
higher_level_query = """
MATCH (e:`__Entity__`)
WITH e, '1-' + e.communities[0] AS communityId1, '2-' + e.communities[-1] AS communityId2
WITH distinct communityId1, communityId2
MATCH (c:Community {id: communityId1})
MERGE (c1:Community {id: communityId2})
MERGE (c)-[:PART_OF]->(c1)
"""
graph.query(higher_level_query)

[]

In [33]:
# Same relationship collection as for the first level, but grouped by the
# second-level ("2-") community ids.
second_level_info_query = """
MATCH (e:`__Entity__`)
WITH '2-' + e.communities[-1] AS communityId, collect(e) AS nodes
WHERE size(nodes) > 1
CALL apoc.path.subgraphAll(nodes[0], {
	whitelistNodes:nodes
})
YIELD relationships
RETURN communityId, [r in relationships | {start: startNode(r).id, type: type(r), end: endNode(r).id}] AS rels
"""
community_info = graph.query(second_level_info_query)

In [34]:
# Summarise the second-level communities concurrently. Order is irrelevant because
# the summaries are stored by community id.
summaries = []
with ThreadPoolExecutor() as workers:
    jobs = [workers.submit(process_community, entry) for entry in community_info]
    for job in tqdm(as_completed(jobs), total=len(jobs), desc="Processing communities"):
        summaries.append(job.result())

Processing communities: 100%|██████████| 41/41 [00:11<00:00,  3.54it/s]


In [35]:
# Persist the second-level summaries on their Community nodes.
second_level_store_query = """
UNWIND $data AS row
MERGE (c:Community {id:row.community})
SET c.summary = row.summary
"""
graph.query(second_level_store_query, params={"data": summaries})

[]

In [40]:
# Embed every Community `summary` and index it as "summariesIndex" for semantic search.
summary_embeddings = OpenAIEmbeddings()
community_vector = Neo4jVector.from_existing_graph(
    summary_embeddings,
    node_label="Community",
    index_name="summariesIndex",
    embedding_node_property="embedding",
    text_node_properties=["summary"],
)

In [41]:
# Retrieve the single community summary closest to the question.
question = "What's the deal with EU and Huawei?"
community_vector.similarity_search(question, k=1)

[Document(page_content='\nsummary: The European Union considers Huawei Technologies Co. and Zte Corp. to be high-risk entities. Additionally, the European Union has imposed bans on Tiktok Inc., Zte Corp., and Huawei Technologies Co. Alberto Nardelli and Thomas Seal have both provided assistance to the European Union.')]

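* Repeat the summarization for larger (higher-level) communities.
* Consider connecting the community summaries to other parts of the graph (TODO: decide what).
* Put a QA chain on top of the community summaries.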