In [1]:
import warnings
warnings.filterwarnings("ignore")
import os
import pandas as pd
import textwrap

from langchain_community.graphs import Neo4jGraph
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.vectorstores import Neo4jVector
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQAWithSourcesChain

from dotenv import load_dotenv
# Read settings from the local .env file; the call's return value is shown as the cell output.
load_dotenv(dotenv_path=".env", override=True)

False

In [2]:
# Connection settings come from the process environment; nothing is hard-coded
# in the notebook. Any of these may be None when the variable is not set.
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
NEO4J_DATABASE = os.getenv("NEO4J_DATABASE")

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# OPENAI_BASE_URL is optional. Concatenating onto None would raise TypeError,
# so the endpoint is only derived when a base URL is configured. A trailing
# slash is stripped to avoid producing ".../v1//embeddings".
_openai_base_url = os.getenv("OPENAI_BASE_URL")
OPENAI_ENDPOINT = (
    _openai_base_url.rstrip("/") + "/embeddings" if _openai_base_url else None
)

# Turn on LangChain tracing only when an API key is available: os.environ
# rejects None values, and the key itself is already in the environment, so
# there is nothing to copy.
if os.getenv("LANGCHAIN_API_KEY"):
    os.environ["LANGCHAIN_TRACING_V2"] = "true"

In [3]:
# Graph handle on the configured Neo4j database, built from the settings loaded above.
kg = Neo4jGraph(
    url=NEO4J_URI,
    username=NEO4J_USERNAME,
    password=NEO4J_PASSWORD,
    database=NEO4J_DATABASE,
)

In [6]:
# Optional reset, intentionally left disabled: uncomment to drop the
# `embedded_index` index through the `kg` connection (e.g. before recreating it).
# result = kg.query("""
# DROP INDEX embedded_index
#   """
# )

Create a vector index over all nodes that have the `embeddable` label.

In [None]:
# Vector store backed by the `embedded_index` index over nodes labelled
# `embeddable`. The listed text properties are combined into the text that is
# embedded (see the `name:/type:/description:/values:` layout in the search
# results below); the vector is stored on each node's `embedding` property.
neo4j_vector_store = Neo4jVector.from_existing_graph(
    embedding=OpenAIEmbeddings(),
    url=NEO4J_URI,
    username=NEO4J_USERNAME,
    password=NEO4J_PASSWORD,
    # Use the same database as the Neo4jGraph connection instead of relying on
    # the library default, so both handles always look at the same graph.
    database=NEO4J_DATABASE,
    index_name="embedded_index",
    node_label="embeddable",
    text_node_properties=["name", "type", "description", "values"],
    embedding_node_property="embedding",
)



In [9]:
# Sanity-check the index: nearest nodes to a dataset-level question, with their scores.
question = "Show me the foot traffic dataset"
response = neo4j_vector_store.similarity_search_with_score(query=question)
response

[(Document(page_content='\nname: foot_traffic\ntype: dataset\ndescription: This dataset records foot traffic data to physical locations, measured by cell phone traffic, on a specific date. The data includes information on unique places, identified by an ID, name, postal code, and coordinates, along with the number of visits and unique visitors recorded.\nvalues: ', metadata={'source': 'snowflake', 'table': 'graph_db.public.foot_traffic'}),
  0.9328018426895142),
 (Document(page_content='\nname: latitude\ntype: column\ndescription: Latitude of the location where foot traffic was measured.\nvalues: 54.25,53.2', metadata={'source': 'foot_traffic', 'col_type': 'reference'}),
  0.9026437997817993),
 (Document(page_content='\nname: web_traffic\ntype: dataset\ndescription: This dataset tracks web traffic to various websites, capturing the number of visits and unique visitors over time. It includes information on the date, website, visit counts, and details about the websites brand and owner.\

In [50]:
# Column-oriented question: fetch the five nearest nodes.
question = "Show me the zip code columns"
response = neo4j_vector_store.similarity_search(query=question, k=5)
response

[Document(page_content='\nname: zip_code\ntype: column\ndescription: ZIP code of the location where the data was collected.\nvalues: 10001', metadata={'source': 'weather', 'col_type': 'reference'}),
 Document(page_content='\nname: post_code\ntype: column\ndescription: Postal code where the location is situated.\nvalues: 10001,73070', metadata={'source': 'foot_traffic', 'col_type': 'reference'}),
 Document(page_content='\nname: country\ntype: column\ndescription: Country code where the data was collected.\nvalues: US', metadata={'source': 'weather', 'col_type': 'reference'}),
 Document(page_content='\nname: country\ntype: column\ndescription: Country code where the location is situated.\nvalues: US', metadata={'source': 'foot_traffic', 'col_type': 'reference'}),
 Document(page_content='\nname: weather\ntype: dataset\ndescription: This dataset contains weather-related data for the US, spanning from 2000 to December 2023. It includes information on the date, location (by DMA name, state a

In [51]:
# Ask for every column of one dataset using similarity search only (five nearest nodes).
question = "Show me all the columns of the foot traffic dataset"
response = neo4j_vector_store.similarity_search(query=question, k=5)
response

[Document(page_content='\nname: foot_traffic\ntype: dataset\ndescription: This dataset records foot traffic data to physical locations, measured by cell phone traffic, on a specific date. The data includes information on unique places, identified by an ID, name, postal code, and coordinates, along with the number of visits and unique visitors recorded.\nvalues: ', metadata={'source': 'snowflake', 'table': 'graph_db.public.foot_traffic'}),
 Document(page_content='\nname: latitude\ntype: column\ndescription: Latitude of the location where foot traffic was measured.\nvalues: 54.25,53.2', metadata={'source': 'foot_traffic', 'col_type': 'reference'}),
 Document(page_content='\nname: longitude\ntype: column\ndescription: Longitude of the location where foot traffic was measured.\nvalues: 10.52,10.5,10.11', metadata={'source': 'foot_traffic', 'col_type': 'reference'}),
 Document(page_content='\nname: web_traffic\ntype: dataset\ndescription: This dataset tracks web traffic to various websites,

In [10]:
# Question-answering chain over the vector store: the retriever returns the
# 4 nearest documents, which the "stuff" chain type passes to the LLM together.
chain = RetrievalQAWithSourcesChain.from_chain_type(
    # llm=ChatOpenAI(model='gpt-4', temperature=0),
    llm=ChatOpenAI(temperature=0),
    chain_type="stuff",
    retriever=neo4j_vector_store.as_retriever(search_kwargs={"k": 4}),
)


def prettychain(question: str) -> None:
    """Ask the module-level ``chain`` a question and print the wrapped answer.

    Args:
        question: Natural-language question forwarded to the chain.

    Returns:
        None. The answer is printed, wrapped at 60 characters per line.
    """
    # ``chain`` is looked up at call time, so rebinding it in a later cell
    # changes which retriever answers the question.
    # ``invoke`` replaces the deprecated ``chain(...)`` call style; the
    # deprecation warning is otherwise hidden by the global warnings filter.
    response = chain.invoke({"question": question}, return_only_outputs=True)
    print(textwrap.fill(response["answer"], 60))

In [12]:
# Fuzzy column lookup: expect column names that resemble "zip code".
question = 'Show me all the columns with column name similar to zip code'
prettychain(question=question)

There are two columns with names similar to zip code:
zip_code and post_code.


In [13]:
# Dataset lookup by topic rather than by name.
question = 'Show me the dataset that has some info about the climate'
prettychain(question=question)

The dataset that contains information about the climate is
named "weather" and includes weather-related data for the US
from 2000 to December 2023, with details on various weather
parameters.


In [None]:
# Ask the similarity-only chain for every column of a single dataset.
question = "Show me all the columns of the foot traffic dataset"
prettychain(question=question)

Here we can see that not all columns are fetched, because retrieval relies on semantic similarity alone.

### Advanced Retrieval Queries
In the initial example, retrieval relies only on the semantic similarity between the query and individual nodes. That does not require a graph database; a plain vector database would do the same job. The true strength of a graph database lies in the relationships between nodes, so we can give our retriever a custom retrieval query for retrieval-augmented generation (RAG). This query not only fetches the semantically closest nodes based on their embeddings but also retrieves the column nodes directly connected to those nodes.

In [14]:
# Cypher fragment used as the vector store's retrieval query. It refers to
# `node` and `score` without defining them, so it presumably runs after the
# library's vector-index lookup, which supplies both (confirm against the
# Neo4jVector documentation).
#
# For each matched node that has HAS_COLUMN relationships it builds one text
# block: the node's own name/type/description, followed by the
# name/description/type/col_type of every connected column node.
# Nodes without outgoing HAS_COLUMN relationships are dropped by the MATCH.
#
# Only one row is returned. `order by score desc` makes that row the
# best-scoring match; without it, the row kept by `limit 1` after the
# aggregation is not guaranteed to be the closest one.
contextualize_query = """
match (node)-[:HAS_COLUMN]->(c:column)
with ('name:'+ node.name +'\n'+'type:'+node.type+'\n'+'description:'+node.description) as self,
reduce(s="", item in collect(c) | s + "\n\n" + 'name:'+item.name +'\n'+ 'description:'+item.description +'\n'+ 'type:'+item.type +'\n'+ 'col_type:'+item.col_type ) as c_name,
score, {source: ' '} as metadata
order by score desc limit 1
return (self +'\n'+ c_name) as text, score, metadata
"""

In [15]:
# Re-open the existing `embedded_index` index, this time with the custom
# retrieval query so each hit also carries its connected column nodes.
# This rebinds `neo4j_vector_store`; retrievers created from here on use it.
neo4j_vector_store = Neo4jVector.from_existing_index(
    embedding=OpenAIEmbeddings(),
    url=NEO4J_URI,
    username=NEO4J_USERNAME,
    password=NEO4J_PASSWORD,
    # Use the same database as the Neo4jGraph connection instead of relying on
    # the library default.
    database=NEO4J_DATABASE,
    index_name="embedded_index",
    retrieval_query=contextualize_query,
)

In [None]:
# Rebuild the chain so that its retriever is created from the current
# `neo4j_vector_store`; the retriever returns up to 4 documents per question.
chain = RetrievalQAWithSourcesChain.from_chain_type(
    # llm=ChatOpenAI(model='gpt-4', temperature=0),
    llm=ChatOpenAI(temperature=0),
    chain_type="stuff",
    retriever=neo4j_vector_store.as_retriever(search_kwargs={"k": 4}),
)


def prettychain(question: str) -> None:
    """Ask the module-level ``chain`` a question and print the wrapped answer.

    Args:
        question: Natural-language question forwarded to the chain.

    Returns:
        None. The answer is printed, wrapped at 60 characters per line.
    """
    # ``invoke`` replaces the deprecated ``chain(...)`` call style; the
    # deprecation warning is otherwise hidden by the global warnings filter.
    response = chain.invoke({"question": question}, return_only_outputs=True)
    print(textwrap.fill(response["answer"], 60))

In [16]:
# Same kind of question as before, now answered through the custom retrieval query.
question = "Show all the columns of the foot traffic dataset"
response = neo4j_vector_store.similarity_search(query=question, k=5)
response

[Document(page_content='name:foot_traffic\ntype:dataset\ndescription:This dataset records foot traffic data to physical locations, measured by cell phone traffic, on a specific date. The data includes information on unique places, identified by an ID, name, postal code, and coordinates, along with the number of visits and unique visitors recorded.\n\n\nname:date\ndescription:Date of the foot traffic measurement in month/day/year format.\ntype:column\ncol_type:feature\n\nname:sg_place_id\ndescription:Vendor unique identifier for each location.\ntype:column\ncol_type:reference\n\nname:location_name\ndescription:Name of the physical location where foot traffic was measured.\ntype:column\ncol_type:reference\n\nname:post_code\ndescription:Postal code where the location is situated.\ntype:column\ncol_type:reference\n\nname:country\ndescription:Country code where the location is situated.\ntype:column\ncol_type:reference\n\nname:symbol\ndescription:Stock symbol or identifier for the business.

In [17]:
# Print the text of the first hit so its embedded newlines render as line breaks.
print(response[0].page_content)

name:foot_traffic
type:dataset
description:This dataset records foot traffic data to physical locations, measured by cell phone traffic, on a specific date. The data includes information on unique places, identified by an ID, name, postal code, and coordinates, along with the number of visits and unique visitors recorded.


name:date
description:Date of the foot traffic measurement in month/day/year format.
type:column
col_type:feature

name:sg_place_id
description:Vendor unique identifier for each location.
type:column
col_type:reference

name:location_name
description:Name of the physical location where foot traffic was measured.
type:column
col_type:reference

name:post_code
description:Postal code where the location is situated.
type:column
col_type:reference

name:country
description:Country code where the location is situated.
type:column
col_type:reference

name:symbol
description:Stock symbol or identifier for the business.
type:column
col_type:reference

name:longitude
descrip

In [18]:
# Rebuild the chain so that its retriever is created from the current
# `neo4j_vector_store`; the retriever returns up to 4 documents per question.
chain = RetrievalQAWithSourcesChain.from_chain_type(
    # llm=ChatOpenAI(model='gpt-4', temperature=0),
    llm=ChatOpenAI(temperature=0),
    chain_type="stuff",
    retriever=neo4j_vector_store.as_retriever(search_kwargs={"k": 4}),
)


def prettychain(question: str) -> None:
    """Ask the module-level ``chain`` a question and print the wrapped answer.

    Args:
        question: Natural-language question forwarded to the chain.

    Returns:
        None. The answer is printed, wrapped at 60 characters per line.
    """
    # ``invoke`` replaces the deprecated ``chain(...)`` call style; the
    # deprecation warning is otherwise hidden by the global warnings filter.
    response = chain.invoke({"question": question}, return_only_outputs=True)
    print(textwrap.fill(response["answer"], 60))

In [19]:
# Ask for every column of a single dataset through the chain.
question = "Show me all the columns of the foot traffic dataset"
prettychain(question=question)

The columns of the foot traffic dataset are: date,
sg_place_id, location_name, post_code, country, symbol,
longitude, latitude, visits, unique_visitors.


Now we can see that all the columns are fetched properly.