<a href="https://colab.research.google.com/github/tomasonjo/blogs/blob/master/llm/enhancing_rag_with_graph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install the packages this notebook imports (run once per kernel).
# NOTE(review): versions are unpinned and `--upgrade` pulls the newest releases,
# so a later run may behave differently — consider pinning known-good versions.
%pip install --upgrade --quiet  langchain langchain-community langchain-openai langchain-experimental neo4j wikipedia tiktoken yfiles_jupyter_graphs

Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import warnings

# `google.colab` is only importable inside Colab. Outside Colab we fall back to
# values that are already present in the process environment instead of
# failing later with a NameError on `userdata`.
try:
    from google.colab import userdata
except ImportError:
    userdata = None

# Environment variable -> name of the Colab secret that provides its value.
_COLAB_SECRETS = {
    "OPENAI_API_KEY": "OXFORD_OPENAI_API_KEY",
    "NEO4J_URI": "NEO4J_URI",
    "NEO4J_PASSWORD": "NEO4J_PASSWORD",
}

os.environ["NEO4J_USERNAME"] = "neo4j"

if userdata is not None:
    # Running in Colab: copy each secret into the environment.
    for _env_name, _secret_name in _COLAB_SECRETS.items():
        os.environ[_env_name] = userdata.get(_secret_name)
else:
    # Not in Colab: keep whatever the environment already provides and point
    # out anything that is missing so the later KeyError is easy to diagnose.
    _missing = [name for name in _COLAB_SECRETS if not os.environ.get(name)]
    if _missing:
        warnings.warn(
            "google.colab is unavailable and these environment variables are "
            "not set: " + ", ".join(_missing)
        )

In [11]:
# pip install --upgrade --quiet  langchain
# pip install --upgrade --quiet  langchain-openai

from langchain.chains import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_openai import ChatOpenAI

# Open the Neo4j connection from the credentials placed in the environment
# earlier in the notebook. Each lookup raises KeyError if its variable is unset.
graph = Neo4jGraph(
    **{
        argument: os.environ[variable]
        for argument, variable in (
            ("url", "NEO4J_URI"),
            ("username", "NEO4J_USERNAME"),
            ("password", "NEO4J_PASSWORD"),
        )
    }
)

def drop_nodes(graph):
  """Empty the database by deleting every node along with its relationships.

  Args:
    graph: Object exposing ``query(cypher)``; the Cypher statement is sent
      through it.

  Returns:
    None. A confirmation line is printed once the query has been issued.
  """
  # DETACH DELETE removes the relationships attached to each matched node too.
  graph.query("MATCH (n) DETACH DELETE n")
  print("Deleted all the existing records")


def load_data(graph):
  """Import the Kellys sales spreadsheet (CSV export) into the graph.

  A single Cypher statement reads the CSV row by row and MERGEs, per row:
    * Country, Store, ProductCategory, Product and Sale nodes;
    * (Store)-[:LOCATED_IN]->(Country), (Store)-[:SOLD]->(Sale),
      (Sale)-[:OF_PRODUCT]->(Product), (Product)-[:BELONGS_TO]->(ProductCategory).
  `Date of Sale` is parsed as MM/dd/yyyy with `apoc.date.parse`, so the query
  relies on that APOC function being callable on the server.

  NOTE(review): the Sale node is MERGEd on all of its properties rather than on
  `id` alone — confirm this is intended before re-loading into a non-empty graph.

  Args:
    graph: Object exposing ``query(cypher)``; receives the import statement.

  Returns:
    None. A confirmation line is printed after the query has been issued.
  """
  # Kept as one literal so the statement is sent exactly as written.
  kellys_import_cypher = """
    LOAD CSV WITH HEADERS FROM 'https://docs.google.com/spreadsheets/d/1gxEEnp0NklCS7ywsU-IgM5LAY2X74aVoDH0R4-PdNyI/export?format=csv' AS row
    WITH row,
      apoc.date.parse(row.`Date of Sale`, 'ms', 'MM/dd/yyyy') AS parsedDate
    MERGE (c:Country {name: row.Country})
    MERGE (s:Store {id: toInteger(row.`Store ID`)})
    ON CREATE SET s.country = row.Country
    MERGE (pc:ProductCategory {name: row.`Product Category`})
    MERGE (p:Product {id: toInteger(row.`Product ID`)})
    ON CREATE SET p.category = row.`Product Category`
    MERGE (sale:Sale {id: toInteger(row.ID), unitsSold: toInteger(row.`Units Sold`), date: date(datetime({epochMillis: parsedDate})), price: toFloat(row.`Price Sold`), gdpGrowth: toFloat(row.`GDP Growth Rate`), inflation: toFloat(row.`Inflation Rate`)})
    MERGE (s)-[:LOCATED_IN]->(c)
    MERGE (s)-[:SOLD]->(sale)
    MERGE (sale)-[:OF_PRODUCT]->(p)
    MERGE (p)-[:BELONGS_TO]->(pc)
  """
  graph.query(kellys_import_cypher)
  print("Loaded Kellys dataset to Neo4J")

# Rebuild the graph from scratch so this cell can be re-run safely.
drop_nodes(graph)
load_data(graph)

# `graph.schema` is a snapshot held by the Neo4jGraph object, not a live view.
# Refresh it after (re)loading so the printed schema — and the schema later
# handed to the Cypher-generating chain — reflects the data just imported
# rather than whatever the database contained when the connection was opened.
graph.refresh_schema()

print("created schema:")
print(graph.schema)


Deleted all the existing records
Loaded Kellys dataset to Neo4J
created schema:
Node properties:
Country {name: STRING}
Store {id: INTEGER, country: STRING}
ProductCategory {name: STRING}
Product {id: INTEGER, category: STRING}
Sale {id: INTEGER, date: DATE, gdpGrowth: FLOAT, price: FLOAT, unitsSold: INTEGER, inflation: FLOAT}
Relationship properties:

The relationships:
(:Store)-[:LOCATED_IN]->(:Country)
(:Store)-[:SOLD]->(:Sale)
(:Product)-[:BELONGS_TO]->(:ProductCategory)
(:Sale)-[:OF_PRODUCT]->(:Product)


In [12]:
from langchain.chains.base import Chain
from langchain.chains import GraphCypherQAChain
from langchain.llms.base import LLM
from langchain_community.graphs.graph_store import GraphStore
from langchain.chains.llm import LLMChain

from langchain_core.callbacks import CallbackManagerForChainRun

from typing import Any, Dict, Optional

class CustomGraphCypherQAChain(GraphCypherQAChain):
    """GraphCypherQAChain that rephrases the parent's response for humans.

    The parent chain's ``_call`` runs first; its whole response dict is then
    embedded in a context prompt and sent through ``self.qa_chain`` together
    with the original question, and that answer becomes the chain output.

    NOTE(review): this is written for use with ``return_direct=True`` (as the
    notebook builds it), where the parent response holds the raw database rows.
    With ``return_direct=False`` the parent presumably has already produced an
    answer and the QA chain would run a second time — confirm before reuse.
    """

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the parent chain, then turn its response into a readable answer.

        Args:
            inputs: Chain inputs; the question is read from ``self.input_key``.
            run_manager: Callback manager for this run, if any.

        Returns:
            The parent's response dict with the value under ``self.output_key``
            replaced by the QA chain's answer. Any other keys the parent
            returned are passed through unchanged.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

        response = super()._call(inputs, run_manager)

        self.log(run_manager, "Db response\n" + str(response))

        context = f"""
        {response}

        Note: Do not include any explanations or apologies in your responses. Do not say I don't know the answer.
        Parse the query and the graph db response, and return a human readable response.
        """

        # `invoke` replaces the deprecated direct call on the chain object, and
        # the child callbacks keep the nested run attached to this run's trace.
        qa_result = self.qa_chain.invoke(
            {"question": inputs[self.input_key], "context": context},
            config={"callbacks": _run_manager.get_child()},
        )
        # NOTE(review): a dict result is indexed by the QA chain's output key;
        # a non-dict result (e.g. plain text from a runnable-style QA chain) is
        # used as-is — confirm against the installed LangChain version.
        if isinstance(qa_result, dict):
            answer = qa_result[self.qa_chain.output_key]
        else:
            answer = qa_result

        # Keep any extra keys from the parent instead of silently dropping them.
        output = {
            key: value for key, value in response.items() if key != self.output_key
        }
        output[self.output_key] = answer
        return output

    def log(self, run_manager: Optional[CallbackManagerForChainRun], text: str) -> None:
        """Emit ``text`` in green via the run's callback manager.

        Falls back to a no-op manager when ``run_manager`` is None; the
        ``verbose`` flag of this chain is forwarded to ``on_text``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(
            text, color="green", end="\n", verbose=self.verbose
        )

# Assemble the question-answering chain over the Neo4j `graph`.
#   - return_direct=True: the parent chain hands back the database response,
#     which CustomGraphCypherQAChain._call then rephrases.
#   - verbose=True: the generated Cypher and the "Db response" log are printed.
chain = CustomGraphCypherQAChain.from_llm(
    ChatOpenAI(temperature=0.5, model="gpt-4o"),
    return_direct=True,
    verbose=True,
    graph=graph,
)

def run(question):
  """Ask the module-level `chain` a question and print its answer.

  The question is wrapped in extra instructions asking that the generated
  Cypher RETURN clause include the related identifying fields.

  Args:
    question: Natural-language question, interpolated into the prompt.

  Returns:
    None. The chain's answer is printed.
  """
  # NOTE(review): the bare "Generate" line reads like an unfinished sentence —
  # confirm the intended wording; it is left unchanged here.
  hydrated_question = f"""
  Generate 
  {question}

  Instructions:
  The cypher query RETURN statement should contain all the relevant nodes and relationships in the result like store ID, country, category, product and other.
  """
  # `invoke` is used in place of the deprecated `Chain.run`; the chain reads the
  # question from "query" and places its answer under "result".
  graph_result = chain.invoke({"query": hydrated_question})["result"]
  print(graph_result)


In [5]:
# `Chain.run` is deprecated (it triggers the warn_deprecated notice); `invoke`
# returns a dict, so pick the answer out of "result" to display the same value.
chain.invoke({"query": "Get me all the countries available"})["result"]

  warn_deprecated(




[1m> Entering new CustomGraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (c:Country)
RETURN c.name
[0m
[32;1m[1;3mDb response
{'result': [{'c.name': 'Spain'}, {'c.name': 'Germany'}, {'c.name': 'UK'}, {'c.name': 'Italy'}, {'c.name': 'France'}]}[0m


  warn_deprecated(



[1m> Finished chain.[0m


'Spain, Germany, UK, Italy, France'

In [13]:
# `invoke` replaces the deprecated `Chain.run`; the answer is under "result".
chain.invoke({"query": "Get me the highest number of units sold in UK country for store 2377"})["result"]



[1m> Entering new CustomGraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (s:Store {id: 2377, country: "UK"})-[:SOLD]->(sale:Sale)
RETURN MAX(sale.unitsSold) AS highestUnitsSold
[0m
[32;1m[1;3mDb response
{'result': [{'highestUnitsSold': 10}]}[0m

[1m> Finished chain.[0m


'The highest number of units sold is 10.'

In [17]:
# `invoke` replaces the deprecated `Chain.run`; the answer is under "result".
chain.invoke({"query": "Get me all the store ids in Germany"})["result"]



[1m> Entering new CustomGraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (s:Store)-[:LOCATED_IN]->(c:Country {name: 'Germany'})
RETURN s.id
[0m
[32;1m[1;3mDb response
{'result': [{'s.id': 8036}, {'s.id': 3062}, {'s.id': 2912}, {'s.id': 2037}, {'s.id': 2377}, {'s.id': 7423}, {'s.id': 4091}, {'s.id': 5275}, {'s.id': 8942}, {'s.id': 4137}]}[0m

[1m> Finished chain.[0m


'8036, 3062, 2912, 2037, 2377, 7423, 4091, 5275, 8942, 4137'