In this notebook I will show:
1. How to look up an entity by embedding the rows of sec_master.
2. How to generate Cypher queries from natural language.
3. How the traversal paths returned by LLM-generated Cypher queries can be converted to Snowflake queries (or queries for any other warehouse or SQL database), and how the entity row found through the sec_master vector search can then be used to check whether that entity (ticker/symbol/company) exists along the traversed path in the SQL database.

In [232]:
import warnings
warnings.filterwarnings("ignore")
import os
import pandas as pd
import textwrap

from langchain_community.graphs import Neo4jGraph
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate,SystemMessagePromptTemplate, PromptTemplate
from langchain_community.vectorstores import Neo4jVector
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQAWithSourcesChain
from langchain import LLMChain, OpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import GraphCypherQAChain
from langchain_community.vectorstores import Chroma
from langchain.schema import Document

from dotenv import load_dotenv
load_dotenv('.env', override=True)

False

In [233]:
NEO4J_URI = os.getenv('NEO4J_URI')
NEO4J_USERNAME = os.getenv('NEO4J_USERNAME')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD')
NEO4J_DATABASE = os.getenv('NEO4J_DATABASE')

OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
OPENAI_ENDPOINT = os.getenv('OPENAI_BASE_URL') + '/embeddings'

os.environ["LANGCHAIN_TRACING_V2"]="true"
os.environ["LANGCHAIN_API_KEY"]=os.getenv("LANGCHAIN_API_KEY")
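
The cell above concatenates and assigns values returned by `os.getenv`, which returns `None` when a variable is unset (for example when no `.env` file is found). A minimal sketch of a stricter lookup follows; `require_env` is a hypothetical helper, not part of the notebook or of any library.

```python
import os


def require_env(name: str) -> str:
    """Return the value of an environment variable, or fail with a clear message."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"Missing required environment variable: {name}")
    return value


# Hypothetical usage, mirroring the variables read above:
# OPENAI_ENDPOINT = require_env("OPENAI_BASE_URL").rstrip("/") + "/embeddings"
```

Failing early this way gives a named error instead of a `TypeError` from `None + '/embeddings'`.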

In [234]:
kg = Neo4jGraph(
    url=NEO4J_URI, username=NEO4J_USERNAME, password=NEO4J_PASSWORD, database=NEO4J_DATABASE
)

## Searching for an entity in sec_master

In [49]:
persist_dir = os.path.join(os.getcwd(), "vector_store")
df = pd.read_excel('sample_data.xlsx',sheet_name='sec_master')

def create_document(row):
    content = "\n".join([f"{col}: {val}" for col, val in row.items()])
    return Document(page_content=content)

documents = df.apply(create_document, axis=1).tolist()

# print(documents[0].page_content)

vector_store = Chroma(persist_directory=persist_dir,embedding_function=OpenAIEmbeddings())

vector_store.add_documents(documents)

['4ddb63c1-1b08-4341-9c30-bc0e4ff538c6',
 'd1b2d981-9fe2-4c25-bc71-96a2b1f9bbdf',
 'db4dc103-746f-4cba-944c-a1822f28771a',
 '1b05e4d7-840a-4770-a9f9-3a1818d9de6c',
 '49fa2a7b-9bdb-47b4-a776-10c560f43bd4',
 'f4459b6f-6b31-433b-8c12-cf106b5d0521',
 '88cf2e6b-4bd1-4599-9416-acaeab3b0484']

In [59]:
# query = "Give me the entity for starbucks"
query = "Show me the ticker AAPL"
result = vector_store.similarity_search_with_score(query,k=1)
print(result[0][0].page_content)

date: 2020-01-01 00:00:00
ticker: AAPL
figi_composite: BBG000B9XRY4
figi_share_class: BBG001S5N8V8
security_name: Apple Inc
entity_name: Apple Inc
gics_sector: Information Technology
gics_sub_sector: Technology Hardware, Storage & Peripherals
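
With `k=1` the search always returns its nearest row, even for a query about an entity that is not in sec_master. One possible guard is to reject hits whose score is too poor. The sketch below works on plain `(document, distance)` pairs and assumes a smaller score means a closer match, which should be verified for the vector store and metric in use. `best_match` and the threshold values are hypothetical.

```python
from typing import Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


def best_match(results: Sequence[Tuple[T, float]], max_distance: float) -> Optional[T]:
    """Return the closest hit, or None if nothing is within max_distance.

    Assumes each result is a (document, distance) pair where a smaller
    distance means a closer match.
    """
    if not results:
        return None
    doc, distance = min(results, key=lambda pair: pair[1])
    return doc if distance <= max_distance else None


# Stand-in results; real ones would come from similarity_search_with_score.
hits = [("ticker: AAPL", 0.21), ("ticker: SBUX", 0.47)]
print(best_match(hits, max_distance=0.35))
print(best_match(hits, max_distance=0.10))
```

The scores in `hits` are made up for illustration; a usable threshold would have to be tuned on real queries.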


In [60]:
query = "Show me the entity Starbucks"
result = vector_store.similarity_search_with_score(query,k=1)
print(result[0][0].page_content)

date: 2020-01-01 00:00:00
ticker: SBUX
figi_composite: BBG000CTQBF3
figi_share_class: BBG001S72KH6
security_name: Starbucks Corp
entity_name: Starbucks Corp
gics_sector: Consumer Discretionary
gics_sub_sector: Restaurants
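
To reuse the matched row later (for example its ticker in a SQL filter), the `column: value` text can be turned back into a dictionary. This is a small sketch under the assumption that every line follows the `column: value` layout produced by `create_document`; `parse_page_content` is a hypothetical helper.

```python
def parse_page_content(page_content: str) -> dict:
    """Turn 'column: value' lines back into a dict, splitting on the first ': ' only."""
    record = {}
    for line in page_content.splitlines():
        key, sep, value = line.partition(": ")
        if sep:  # skip blank or malformed lines
            record[key.strip()] = value.strip()
    return record


sample = "date: 2020-01-01 00:00:00\nticker: SBUX\nentity_name: Starbucks Corp"
entity = parse_page_content(sample)
print(entity["ticker"])
```

Splitting on the first `": "` keeps values such as timestamps intact even though they contain colons.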


## Creating Cypher

In [224]:
CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database.
Instructions:
Use only the provided relationship types and properties in the 
schema. Do not use any other relationship types or properties that 
are not provided. When creating the query make the relationships directionless.
Schema:
{schema}
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything else than 
for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.
Examples: Here are a few examples of generated Cypher 
statements for particular questions:

# Show me all the datasets that are connected to each other?
MATCH (d1:dataset)-[:HAS_COLUMN]-(:column)-[:RELATED_TO]-(:column)-[:HAS_COLUMN]-(d2:dataset)
RETURN d1, d2;

# Find paths that start from a dataset named 'foot_traffic', 
# go through datasets, and return specific table and column names at each step?
MATCH path = (d:dataset)-[*]-(d2:dataset)
WHERE d.name='foot_traffic'
UNWIND nodes(path) as n
RETURN n.type, n.table, n.name;

The question is:
{question}"""

In [225]:
CYPHER_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"], 
    template=CYPHER_GENERATION_TEMPLATE
)
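
The template is filled by substituting `{schema}` and `{question}`. With the default f-string style formatting, single braces are read as placeholders, so a few-shot example that uses Cypher map syntax such as `{name: 'weather'}` would need its braces doubled. The sketch below illustrates this with plain `str.format` as a stand-in for the prompt class; the schema string is a made-up placeholder.

```python
# Doubled braces survive formatting as literal braces; single braces are placeholders.
template = (
    "Schema:\n{schema}\n"
    "# Example\n"
    "MATCH (d:dataset {{name: 'weather'}}) RETURN d.name;\n"
    "The question is:\n{question}"
)

rendered = template.format(
    schema="(:dataset)-[:HAS_COLUMN]->(:column)",  # placeholder schema text
    question="List all datasets",
)
print(rendered)
```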

In [227]:
cypherChain = GraphCypherQAChain.from_llm(
    ChatOpenAI(model='gpt-4',temperature=0),
    graph=kg,
    verbose=True,
    cypher_prompt=CYPHER_GENERATION_PROMPT,
)

In [228]:
def prettyCypherChain(question: str) -> None:
    """Run the Cypher QA chain and print the answer wrapped at 60 characters."""
    response = cypherChain.run(question)
    print(textwrap.fill(response, 60))

In [25]:
prettyCypherChain("Show me all the datasets available, I only need the names?")



> Entering new GraphCypherQAChain chain...
Generated Cypher:
MATCH (d:dataset)
RETURN d.name;
Full Context:
[{'d.name': 'foot_traffic'}, {'d.name': 'web_traffic'}, {'d.name': 'social_media'}, {'d.name': 'sec_master'}, {'d.name': 'weather'}]

> Finished chain.
The available datasets are foot_traffic, web_traffic,
social_media, sec_master, and weather.


In [32]:
prettyCypherChain("Show me all the column names connected to the weather dataset ? also mention the column type as well")



> Entering new GraphCypherQAChain chain...
Generated Cypher:
MATCH (d:dataset {name: 'weather'})-[:HAS_COLUMN]->(c:column)
RETURN c.name, c.col_type;
Full Context:
[{'c.name': 'country', 'c.col_type': 'reference'}, {'c.name': 'zip_code', 'c.col_type': 'reference'}, {'c.name': 'latitude', 'c.col_type': 'feature'}, {'c.name': 'longitude', 'c.col_type': 'feature'}, {'c.name': 'feels_like_max', 'c.col_type': 'feature'}, {'c.name': 'rel_hum_avg', 'c.col_type': 'feature'}, {'c.name': 'snow_depth_min', 'c.col_type': 'feature'}, {'c.name': 'month', 'c.col_type': 'feature'}, {'c.name': 'dma_name', 'c.col_type': 'reference'}, {'c.name': 'state_abvtn', 'c.col_type': 'reference'}]

> Finished chain.
The column names connected to the weather dataset along with
their types are as follows: 'country' and 'zip_code' are
references; 'latitude', 'longitude', 'feels_like_max',
'rel_hum_avg', 'snow_depth_min', and 'month' are features;
'dma_name' and 'sta

In [18]:
prettyCypherChain("Show me all the datasets that are connected to each other?, only need the names")



> Entering new GraphCypherQAChain chain...
Generated Cypher:
MATCH (d1:dataset)-[:HAS_COLUMN]-(:column)-[:RELATED_TO]-(:column)-[:HAS_COLUMN]-(d2:dataset)
RETURN d1.name, d2.name;
Full Context:
[{'d1.name': 'foot_traffic', 'd2.name': 'weather'}, {'d1.name': 'weather', 'd2.name': 'foot_traffic'}, {'d1.name': 'foot_traffic', 'd2.name': 'sec_master'}, {'d1.name': 'sec_master', 'd2.name': 'foot_traffic'}, {'d1.name': 'web_traffic', 'd2.name': 'social_media'}, {'d1.name': 'social_media', 'd2.name': 'web_traffic'}, {'d1.name': 'web_traffic', 'd2.name': 'sec_master'}, {'d1.name': 'sec_master', 'd2.name': 'web_traffic'}]

> Finished chain.
foot_traffic, weather, sec_master, social_media, web_traffic
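
Because the pattern is matched without direction, every connection appears twice in the context, once as (a, b) and once as (b, a), and the final answer flattens the pairs into a list of names. A small sketch of collapsing the rows into unique undirected pairs follows; `undirected_pairs` is a hypothetical helper operating on the row dictionaries shown in the context above.

```python
def undirected_pairs(rows):
    """Collapse (a, b) and (b, a) rows into one alphabetically ordered pair each."""
    return sorted({tuple(sorted((r["d1.name"], r["d2.name"]))) for r in rows})


rows = [
    {"d1.name": "foot_traffic", "d2.name": "weather"},
    {"d1.name": "weather", "d2.name": "foot_traffic"},
    {"d1.name": "foot_traffic", "d2.name": "sec_master"},
    {"d1.name": "sec_master", "d2.name": "foot_traffic"},
    {"d1.name": "web_traffic", "d2.name": "social_media"},
    {"d1.name": "social_media", "d2.name": "web_traffic"},
    {"d1.name": "web_traffic", "d2.name": "sec_master"},
    {"d1.name": "sec_master", "d2.name": "web_traffic"},
]

for left, right in undirected_pairs(rows):
    print(f"{left} <-> {right}")
```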


In [29]:
prettyCypherChain("Show me how I can navigate to different datasets from sec_master? show me only the name and type of the entities in the path")



> Entering new GraphCypherQAChain chain...
Generated Cypher:
MATCH path = (d:dataset {name: 'sec_master'})-[*]-(d2:dataset)
UNWIND nodes(path) as n
RETURN n.name, n.type;
Full Context:
[{'n.name': 'sec_master', 'n.type': 'dataset'}, {'n.name': 'ticker', 'n.type': 'column'}, {'n.name': 'symbol', 'n.type': 'column'}, {'n.name': 'foot_traffic', 'n.type': 'dataset'}, {'n.name': 'sec_master', 'n.type': 'dataset'}, {'n.name': 'entity_name', 'n.type': 'column'}, {'n.name': 'website_owner', 'n.type': 'column'}, {'n.name': 'web_traffic', 'n.type': 'dataset'}, {'n.name': 'sec_master', 'n.type': 'dataset'}, {'n.name': 'entity_name', 'n.type': 'column'}]

> Finished chain.
From the 'sec_master' dataset, you can navigate to the
'ticker' and 'symbol' columns. You can also reach the
'foot_traffic' and 'web_traffic' datasets. Additionally, you
can access the 'entity_name' and 'website_owner' columns.


In [230]:
prettyCypherChain("""
I need a cypher query to traverse from the sec_master dataset to related datasets? 
In the cypher query return the type, name and underlying table.
""")



> Entering new GraphCypherQAChain chain...
Generated Cypher:
MATCH path = (d:dataset {name: 'sec_master'})-[*]-(d2:dataset)
UNWIND nodes(path) as n
RETURN n.type, n.name, n.table;
Full Context:
[{'n.type': 'dataset', 'n.name': 'sec_master', 'n.table': 'graph_db.public.sec_master'}, {'n.type': 'column', 'n.name': 'ticker', 'n.table': None}, {'n.type': 'column', 'n.name': 'symbol', 'n.table': None}, {'n.type': 'dataset', 'n.name': 'foot_traffic', 'n.table': 'graph_db.public.foot_traffic'}, {'n.type': 'dataset', 'n.name': 'sec_master', 'n.table': 'graph_db.public.sec_master'}, {'n.type': 'column', 'n.name': 'entity_name', 'n.table': None}, {'n.type': 'column', 'n.name': 'website_owner', 'n.table': None}, {'n.type': 'dataset', 'n.name': 'web_traffic', 'n.table': 'graph_db.public.web_traffic'}, {'n.type': 'dataset', 'n.name': 'sec_master', 'n.table': 'graph_db.public.sec_master'}, {'n.type': 'column', 'n.name': 'entity_name', 'n.table': None}]

[

## Query

In [235]:
query = """MATCH path = (d:dataset {name: 'sec_master'})-[*]-(d2:dataset)
UNWIND nodes(path) as n
RETURN n.type, n.name, n.table;"""
result = kg.query(query)

In [236]:
df = pd.DataFrame(result)
df.columns = [c.replace("n.","") for c in df.columns]

In [237]:
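# Every traversal starts at sec_master, so each later 'sec_master' row
# marks the beginning of the next path in the flattened node list.
paths = []
start = 0
for i, row in df.iterrows():
    if i != 0 and row['name'] == 'sec_master':
        paths.append(df.iloc[start:i])
        start = i
# The last path runs to the end of the frame.
if len(df) > start:
    paths.append(df.iloc[start:])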

In [238]:
paths[0]

|   | type | name | table |
|---|---|---|---|
| 0 | dataset | sec_master | graph_db.public.sec_master |
| 1 | column | ticker | |
| 2 | column | symbol | |
| 3 | dataset | foot_traffic | graph_db.public.foot_traffic |


In [241]:
paths[3]

|   | type | name | table |
|---|---|---|---|
| 15 | dataset | sec_master | graph_db.public.sec_master |
| 16 | column | ticker | |
| 17 | column | symbol | |
| 18 | dataset | foot_traffic | graph_db.public.foot_traffic |
| 19 | column | post_code | |
| 20 | column | zip_code | |
| 21 | dataset | weather | graph_db.public.weather |
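
The splitting step relies on each path starting at sec_master. The same idea can be sketched on plain dictionaries, without pandas, to make the assumption explicit. `split_paths` is a hypothetical helper, and it would mis-split any path that passes through the start dataset a second time.

```python
def split_paths(rows, start_name):
    """Split a flat node list into paths, opening a new path at every start_name row."""
    paths = []
    for row in rows:
        if row["name"] == start_name or not paths:
            paths.append([])
        paths[-1].append(row)
    return paths


# First eight nodes of the traversal result shown earlier.
rows = [
    {"type": "dataset", "name": "sec_master"},
    {"type": "column", "name": "ticker"},
    {"type": "column", "name": "symbol"},
    {"type": "dataset", "name": "foot_traffic"},
    {"type": "dataset", "name": "sec_master"},
    {"type": "column", "name": "entity_name"},
    {"type": "column", "name": "website_owner"},
    {"type": "dataset", "name": "web_traffic"},
]

for path in split_paths(rows, "sec_master"):
    print(" -> ".join(node["name"] for node in path))
```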


Any of the paths above can be converted to a Snowflake SQL query to check whether a given entity exists along the path.
The two queries below do this by hand for the two paths shown above, filtering on the ticker SBUX.

~~~sql
-- For paths[0]
select count(*)
from graph_db.public.sec_master a 
inner join graph_db.public.foot_traffic b on a.ticker = b.symbol
where a.ticker='SBUX'
~~~

~~~sql
-- For paths[3]
select count(*)
from graph_db.public.sec_master a 
inner join graph_db.public.foot_traffic b on a.ticker = b.symbol
inner join graph_db.public.weather c on b.post_code = c.zip_code
where a.ticker='SBUX'
~~~

In [218]:
def paths_to_sql(df_path):
    """Build a count(*) query that joins the datasets of one path in order.

    Dataset rows supply the table names; the column rows between two
    datasets supply the join keys (left column first, then right column).
    """
    tables = list(df_path[df_path['type'] == 'dataset']['table'])
    cols = iter(df_path[df_path['type'] == 'column']['name'])

    query = "select count(*)\nfrom "
    for i, table in enumerate(tables):
        alias = chr(ord('a') + i)
        if i == 0:
            query += f"{table} {alias} "
        else:
            prev_alias = chr(ord('a') + i - 1)
            # The two next(cols) calls are evaluated left to right.
            query += f"\ninner join {table} {alias} on {prev_alias}.{next(cols)} = {alias}.{next(cols)}"
    return query

In [219]:
print(paths_to_sql(paths[0]))

select count(*)
from graph_db.public.sec_master a 
inner join graph_db.public.foot_traffic b on a.ticker = b.symbol


In [220]:
print(paths_to_sql(paths[3]))

select count(*)
from graph_db.public.sec_master a 
inner join graph_db.public.foot_traffic b on a.ticker = b.symbol
inner join graph_db.public.weather c on b.post_code = c.zip_code
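
The generated queries do not yet filter on an entity. One way to add the filter is to append a `where` clause with a bind placeholder and pass the ticker found by the vector search as a parameter. This is a sketch only: `add_entity_filter` is a hypothetical helper, and the `%s` placeholder style is an assumption that depends on the database driver in use.

```python
def add_entity_filter(sql: str, column: str = "ticker", alias: str = "a") -> str:
    """Append a WHERE clause with a bind placeholder for the entity value.

    The column and alias are interpolated into the SQL text, so they should
    come from trusted graph metadata, never from user input.
    """
    return f"{sql}\nwhere {alias}.{column} = %s"


base_sql = (
    "select count(*)\n"
    "from graph_db.public.sec_master a \n"
    "inner join graph_db.public.foot_traffic b on a.ticker = b.symbol"
)
sql = add_entity_filter(base_sql)
params = ("SBUX",)  # e.g. the ticker taken from the matched sec_master row
print(sql)
# A DB-API style cursor could then run something like: cursor.execute(sql, params)
```

Keeping the ticker as a bound parameter rather than formatting it into the string avoids quoting problems when the value comes from free-text search results.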
