<a href="https://colab.research.google.com/github/tomasonjo/blogs/blob/master/llm/enhancing_rag_with_graph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install --upgrade --quiet  langchain langchain-community langchain-openai langchain-experimental neo4j wikipedia tiktoken yfiles_jupyter_graphs

Note: you may need to restart the kernel to use updated packages.


In [2]:
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import Tuple, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
import os
from langchain_community.graphs import Neo4jGraph
from langchain_community.document_loaders import WikipediaLoader
from langchain_text_splitters import TokenTextSplitter
from langchain_openai import ChatOpenAI
from langchain_experimental.graph_transformers import LLMGraphTransformer
from neo4j import GraphDatabase
from yfiles_jupyter_graphs import GraphWidget
from langchain_community.vectorstores import Neo4jVector
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
from langchain_core.runnables import ConfigurableField

try:
  import google.colab
  from google.colab import output
  output.enable_custom_widget_manager()
except ImportError:
  pass

# Enhancing the accuracy of RAG-based applications by constructing and leveraging knowledge graphs
## A practical guide to constructing and retrieving information from knowledge graphs in RAG applications with Neo4j and LangChain

Graph retrieval-augmented generation (Graph RAG) is gaining momentum as a powerful addition to traditional vector search retrieval methods. This approach leverages the structured nature of graph databases, which organize data as nodes and relationships, to add depth and context to the retrieved information.

Graphs are great at representing and storing heterogeneous and interconnected information in a structured manner, effortlessly capturing complex relationships and attributes across diverse data types. In contrast, vector databases often struggle with such structured information, as their strength lies in handling unstructured data through high-dimensional vectors. In your RAG application, you can combine structured graph data with vector search through unstructured text to achieve the best of both worlds, which is exactly what we will do in this blog post.

Knowledge graphs are great, but how do you create one? Constructing a knowledge graph is typically the most challenging step in leveraging the power of graph-based data representation. It involves gathering and structuring the data, which requires a deep understanding of both the domain and graph modeling. To simplify this process, we have been experimenting with LLMs. LLMs, with their profound understanding of language and context, can automate significant parts of the knowledge graph creation process. By analyzing text data, these models can identify entities, understand the relationships between them, and suggest how they might be best represented in a graph structure. As a result of these experiments, we have added the first version of the graph construction module to LangChain, which we will demonstrate in this blog post.

## Neo4j Environment Setup

You need to set up a Neo4j instance to follow along with the examples in this blog post. The easiest way is to start a free instance on [Neo4j Aura](https://neo4j.com/cloud/platform/aura-graph-database/), which offers cloud instances of the Neo4j database. Alternatively, you can set up a local instance by downloading the Neo4j Desktop application and creating a local database.

In [3]:
from google.colab import userdata

os.environ["OPENAI_API_KEY"] = userdata.get('OXFORD_OPENAI_API_KEY')
os.environ["NEO4J_URI"] = userdata.get('NEO4J_URI')
os.environ["NEO4J_USERNAME"] = "neo4j"
os.environ["NEO4J_PASSWORD"] = userdata.get('NEO4J_PASSWORD')

graph = Neo4jGraph()
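
The cell above only runs inside Colab, because it imports `google.colab.userdata`. For a local Neo4j instance, a small fallback to ordinary environment variables may be enough. The sketch below is one possible approach; `get_secret` is a hypothetical helper, not part of the notebook or of any library.

```python
import os

def get_secret(name, default=None):
    """Return a secret from Colab's userdata store, else from the environment."""
    try:
        # Only importable when the notebook runs inside Google Colab.
        from google.colab import userdata
    except ImportError:
        return os.environ.get(name, default)
    try:
        return userdata.get(name)
    except Exception:
        # Secret missing or notebook access not granted: fall back to the environment.
        return os.environ.get(name, default)

# Outside Colab this simply reads the environment variable.
os.environ.setdefault("NEO4J_USERNAME", "neo4j")
print(get_secret("NEO4J_USERNAME"))
```

With such a helper, the assignments above could read `os.environ["NEO4J_URI"] = get_secret("NEO4J_URI")`, assuming the variable is already exported in the local shell.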

# Parse the user's input and read records from Neo4j

In [4]:
def print_in_bold(message):
  print('\033[1m' + message + '\033[0m')

In [6]:
def extract_entities(question):
  # Ask the model to return the entities as a JSON array holding one object.
  # Literal braces in the template are doubled so that format() leaves them intact.
  prompt = PromptTemplate(
      template=
      """
        You are a smart assistant that extracts the Country, Product ID, Store ID, and Category entities from the user input: {question} into a JSON array.
        For example:
        For question: "Forecast the sales of Product 64318 in UK",
        your response should be: [{{"Country": "UK", "Product ID": 64318, "Store ID": 8036, "Category": "MENS"}}]
        Ensure that you return only a valid JSON response that any JSON parser can parse directly.
        Do not write any other text in the response, not even the word json. It should be plain JSON.
      """,
      input_variables=["question"]
  )
  prompt_text = prompt.format(question=question)
  llm = ChatOpenAI(temperature=0, model_name="gpt-4o")
  # invoke() returns an AIMessage; its content attribute holds the reply text.
  return llm.invoke(prompt_text).content

extract_entities("Forecast the sales of Product 13386 in UK for the store 2377 in July")

'[\n    {\n        "Country": "UK",\n        "Product ID": 13386,\n        "Store ID": 2377,\n        "Category": ""\n    }\n]'
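
The reply is a plain string, so it still has to be parsed before it can drive a query. Models sometimes wrap JSON in a Markdown fence or leave fields empty (as with `"Category": ""` above), so a tolerant parser can help. The following is a minimal sketch under those assumptions; `parse_entities` is a hypothetical helper that the notebook does not define.

```python
import json

FENCE = "`" * 3  # Markdown code-fence marker

def parse_entities(response):
    """Parse the model's reply into one dict, dropping empty entities."""
    text = response.strip()
    # Drop a Markdown code fence if the model added one despite the instructions.
    if text.startswith(FENCE):
        text = text.strip("`").strip()
        if text.lower().startswith("json"):
            text = text[4:]
    parsed = json.loads(text)
    # Accept either a bare object or an array holding one object.
    if isinstance(parsed, list):
        if not parsed:
            return {}
        parsed = parsed[0]
    if not isinstance(parsed, dict):
        raise ValueError(f"Expected a JSON object, got: {parsed!r}")
    # Empty values would otherwise become filters that match no record.
    return {key: value for key, value in parsed.items() if value not in ("", None)}

reply = '[\n    {\n        "Country": "UK",\n        "Product ID": 13386,\n        "Store ID": 2377,\n        "Category": ""\n    }\n]'
print(parse_entities(reply))
```

Dropping empty values is a design choice: it keeps an unspecified category from turning into a `Category = ''` condition further down the pipeline.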

In [7]:
import json

def cypher_query(data):
  # Build one equality condition per extracted entity
  where_clauses = []
  for key, value in data.items():
      # Backticks let Cypher address property keys with spaces (e.g., "Product ID")
      cypher_key = f'`{key}`'
      if isinstance(value, str):
          # String values must be quoted in Cypher
          clause = f"k.{cypher_key} = '{value}'"
      else:
          clause = f"k.{cypher_key} = {value}"
      where_clauses.append(clause)

  # Join the WHERE clauses with " AND "
  where_clause = " AND ".join(where_clauses)

  query = f"""
  MATCH (k:Kellys)-[:CONTAINS]->(target)
  WHERE {where_clause}
  RETURN DISTINCT k
  """

  # Print the Cypher query
  print_in_bold("\nGenerated Cypher query:")
  print(query)
  return query

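# Extract entities from the question, build a Cypher query, and run it against Neo4j
def records_retriever(question: str) -> list:
    """
    Example of a generated query:

    MATCH (k:Kellys)-[:CONTAINS]->(target)
    WHERE k.`Country` = 'UK'
    AND k.`Product ID` = 13386
    AND k.`Store ID` = 8036
    RETURN DISTINCT k
    """
    response = extract_entities(question)
    try:
        # The model is asked for a JSON array holding a single object
        data = json.loads(response)[0]
    except (json.JSONDecodeError, IndexError, KeyError, TypeError) as error:
        raise ValueError(f"Could not parse entities from: {response!r}") from error
    print_in_bold("Extracted entities from user's input:")
    print(data)
    query = cypher_query(data)
    return graph.query(query)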

question = "Forecast the sales in UK for the store 2377 in July"
result = records_retriever(question)
print_in_bold("\nResults from Neo4J:")
print(result)


Extracted entities from user's input:
{'Country': 'UK', 'Store ID': 2377}

Generated Cypher query:

  MATCH (k:Kellys)-[:CONTAINS]->(target)
  WHERE k.`Country` = 'UK' AND k.`Store ID` = 2377
  RETURN DISTINCT k
  
Results from Neo4J:
[{'k': {'Product ID': 64318, 'Inflation Rate': 6.18, 'Units Sold': 4, 'GDP Growth Rate': 5.2, 'Product Category': 'Mens', 'Store ID': 2377, 'Country': 'UK', 'ID': 5, 'Date of Sale': neo4j.time.DateTime(2024, 2, 7, 18, 30, 0, 0, tzinfo=<UTC>), 'Price Sold': 28.7}}, {'k': {'Product ID': 48183, 'Inflation Rate': 5.78, 'Units Sold': 3, 'GDP Growth Rate': 7.26, 'Product Category': 'Mens', 'Store ID': 2377, 'Country': 'UK', 'ID': 328, 'Date of Sale': neo4j.time.DateTime(2024, 1, 11, 18, 30, 0, 0, tzinfo=<UTC>), 'Price Sold': 97.21}}, {'k': {'Product ID': 69246, 'Inflation Rate': -3.39, 'Units Sold': 3, 'GDP Growth Rate': 4.15, 'Product Category': 'Womens', 'Store ID': 2377, 'Country': 'UK', 'ID': 1222, 'Date of Sale': neo4j.time.DateTim
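
The generated query interpolates the extracted values directly into the Cypher text, so a stray quote in a model reply could break or alter the query. One alternative is to pass the values as query parameters. The sketch below assumes the same `Kellys` pattern as above; `build_parameterized_query` is a hypothetical helper, not part of the notebook.

```python
def build_parameterized_query(data):
    """Return (query, params) with entity values passed as Cypher parameters."""
    where_clauses = []
    params = {}
    for index, (key, value) in enumerate(data.items()):
        param_name = f"p{index}"
        # Double any backtick in the property key; the value never enters the query text.
        safe_key = key.replace("`", "``")
        where_clauses.append(f"k.`{safe_key}` = ${param_name}")
        params[param_name] = value

    query = "MATCH (k:Kellys)-[:CONTAINS]->(target)\n"
    if where_clauses:
        # Only emit WHERE when at least one entity was extracted.
        query += "WHERE " + " AND ".join(where_clauses) + "\n"
    query += "RETURN DISTINCT k"
    return query, params

query, params = build_parameterized_query({"Country": "UK", "Store ID": 2377})
print(query)
print(params)
```

The pair can then be handed to the graph wrapper as `graph.query(query, params)`, which lets the driver handle quoting and typing of the values.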

In [8]:
def forecast(question, records):
  # Give the model the question and the retrieved records and ask for final figures only.
  forecast_prompt = PromptTemplate(
      template=
      """
        You are a smart 'Business Data Analyst' and your job is to forecast sales based on this user's request:
        {question}
        The following records are past sales data. Understand and process these to forecast future sales.
        {records}

        DO NOT GIVE STEPS TO DO FORECASTING. YOU PREDICT AND GIVE US THE FINAL FIGURES
      """,
      input_variables=["question", "records"]
  )

  prompt_text = forecast_prompt.format(question=question, records=records)
  llm = ChatOpenAI(temperature=0, model_name="gpt-4o")
  # invoke() returns an AIMessage; its content attribute holds the reply text.
  return llm.invoke(prompt_text).content

forecast(question, result)

'Based on the provided historical sales data for Store 2377 in the UK, here is the forecasted sales for July:\n\n1. **Total Units Sold**: \n   - The average units sold per month can be calculated from the given data. \n   - Summing up the units sold: 4 + 3 + 3 + 3 + 3 + 1 + 1 + 7 = 25 units over 8 records.\n   - Average units sold per record: 25 / 8 = 3.125 units.\n   - Assuming similar sales trends, the forecasted units sold for July would be approximately 3.125 units per record.\n\n2. **Total Revenue**:\n   - The average price sold can be calculated from the given data.\n   - Summing up the prices sold: 28.7 + 97.21 + 65.44 + 54.9 + 98.3 + 43.12 + 48.26 + 19.96 = 455.89.\n   - Average price sold per record: 455.89 / 8 = 56.98625.\n   - Forecasted revenue for July: 3.125 units * 56.98625 = 178.08 (approximately).\n\nTherefore, the forecasted sales for Store 2377 in the UK for July are approximately:\n- **Total Units Sold**: 3.125 units\n- **Total Revenue**: £178.08'
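
Since the model shows its arithmetic, the figures can be checked with a few lines of Python. The sketch below copies the unit and price lists from the answer above and recomputes the averages; it is a sanity check on the reported numbers, not a forecasting method.

```python
from statistics import mean

# Figures copied from the model's answer above.
units_sold = [4, 3, 3, 3, 3, 1, 1, 7]
prices_sold = [28.7, 97.21, 65.44, 54.9, 98.3, 43.12, 48.26, 19.96]

avg_units = mean(units_sold)    # average units per record
avg_price = mean(prices_sold)   # average price per record
revenue_per_record = avg_units * avg_price

print(f"Records: {len(units_sold)}, total units: {sum(units_sold)}")
print(f"Average units per record: {avg_units}")
print(f"Average price per record: {avg_price:.5f}")
print(f"Average units x average price: {revenue_per_record:.2f}")
```

The numbers match the answer, but note what they measure: 3.125 units and 178.08 are averages per sales record, so presenting them as monthly totals for July is the model's own interpretation rather than something the data establishes.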