# Retail Example
Exact Schema and Expert Tools

## Setup

In [1]:
import os


parent_dir = os.getcwd()
data_dir = os.path.join(parent_dir, "data")
data_model_dir = os.path.join(parent_dir, "data-models")

print("Parent directory:", parent_dir)
print("Data directory:", data_dir)
print("Data model directory:", data_model_dir)


Parent directory: /Users/zachblumenfeld/demo/graphrag-nd/examples/retail
Data directory: /Users/zachblumenfeld/demo/graphrag-nd/examples/retail/data
Data model directory: /Users/zachblumenfeld/demo/graphrag-nd/examples/retail/data-models


In [2]:
from dotenv import load_dotenv


load_dotenv('.env', override=True)

uri = os.getenv('NEO4J_URI_MED')
username = os.getenv('NEO4J_USERNAME_MED')
password = os.getenv('NEO4J_PASSWORD_MED')
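
`os.getenv` returns `None` when a variable is missing, so a typo in `.env` only shows up later as a connection error. A minimal sketch of an early check follows; `require_env` is a hypothetical helper for illustration and is not part of graph-nd or python-dotenv.

```python
import os
from typing import Dict, Iterable


def require_env(names: Iterable[str]) -> Dict[str, str]:
    """Return the named environment variables, raising if any is unset or empty.

    Hypothetical helper for illustration only.
    """
    values = {name: os.getenv(name) for name in names}
    missing = [name for name, value in values.items() if not value]
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
    return values
```

Calling `require_env(["NEO4J_URI_MED", "NEO4J_USERNAME_MED", "NEO4J_PASSWORD_MED"])` right after `load_dotenv` would fail fast if any of the three is absent.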

## Drafting Graph Schemas
Creating graph schemas for production is an iterative process that requires reviews, version control, and some trial and error.
graph-nd is designed to support this process.
To get started, you can create an initial graph schema from any JSON-like file. For example, you can start with another data modeling tool, such as the Neo4j Data Importer, and export the resulting model to a file. GraphRAG can then map this to an initial graph schema, which experts can refine further as needed.


In [3]:
from graph_nd import GraphRAG
from langchain_openai import ChatOpenAI


# file names
json_file = os.path.join(data_model_dir, "neo4j-importer-draft.json")
graph_schema_v1 = os.path.join(data_model_dir, "graph-schema-v1.json")

# LLM
llm = ChatOpenAI(model="gpt-4o", temperature=0.0)

# draft v1 graph-schema from neo4j importer model
(GraphRAG(llm=llm).schema
 .from_json_like_file(json_file)
 .export(graph_schema_v1))

[Schema] Successfully Crafted schema


## Tracking & Loading Schemas
You can iterate on, track, and re-load these graph schema files, allowing you to have __precise, version-controlled, expert-crafted schemas__. Below is how you load a graph schema for use.

In [4]:
from graph_nd import GraphRAG
from neo4j import GraphDatabase
from langchain_openai import OpenAIEmbeddings, ChatOpenAI

db_client = GraphDatabase.driver(uri, auth=(username, password))
embedding_model = OpenAIEmbeddings(model='text-embedding-ada-002')
llm = ChatOpenAI(model="gpt-4o", temperature=0.0)
schema_file = os.path.join(data_model_dir, "graph-schema-finalized.json")

# instantiate graphrag
graphrag = GraphRAG(db_client, llm, embedding_model)

# load schema
graphrag.schema.load(schema_file)

[Schema] Schema successfully loaded from /Users/zachblumenfeld/demo/graphrag-nd/examples/retail/data-models/graph-schema-finalized.json
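
Because the schema lives in a JSON file, reviewing a change between two versions can be as simple as diffing the files. The sketch below flattens two JSON documents into path/value pairs and reports what was added, removed, or changed. It makes no assumption about the layout of graph-nd schema files; `flatten`, `diff_json`, and `diff_json_files` are hypothetical helpers, and empty containers are ignored for brevity.

```python
import json
from typing import Any, Dict


def flatten(value: Any, prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts/lists into {"path.to.key": leaf_value}."""
    if isinstance(value, dict):
        items = value.items()
    elif isinstance(value, list):
        items = enumerate(value)
    else:
        return {prefix: value}
    flat: Dict[str, Any] = {}
    for key, child in items:
        path = f"{prefix}.{key}" if prefix else str(key)
        flat.update(flatten(child, path))
    return flat


def diff_json(old: Any, new: Any) -> Dict[str, Dict[str, Any]]:
    """Report leaf paths that were added, removed, or changed between two documents."""
    a, b = flatten(old), flatten(new)
    return {
        "added": {k: b[k] for k in b.keys() - a.keys()},
        "removed": {k: a[k] for k in a.keys() - b.keys()},
        "changed": {k: (a[k], b[k]) for k in a.keys() & b.keys() if a[k] != b[k]},
    }


def diff_json_files(old_path: str, new_path: str) -> Dict[str, Dict[str, Any]]:
    """Load two JSON files and diff them."""
    with open(old_path) as f_old, open(new_path) as f_new:
        return diff_json(json.load(f_old), json.load(f_new))
```

In this notebook, `diff_json_files(graph_schema_v1, schema_file)` would summarize how the finalized schema differs from the first draft.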


## Map Tabular Data
We can map data with our own custom logic for precision before merging node and relationship records.

In [5]:
import pandas as pd

ingest_id = "my-tabular-data-ingest"
graphrag.data.nuke()

In [6]:
customer_df = pd.read_csv(os.path.join(data_dir, "customers.csv"))
records = customer_df[[ "customerId", "postalCode", "age", "fashionNewsFrequency", "clubMemberStatus"]].to_dict(orient="records")

graphrag.data.merge_nodes("Customer", records, source_metadata={"ingest_id": ingest_id})

In [8]:
product_df = pd.read_csv(os.path.join(data_dir, "products.csv"))
product_df['text'] = ("##Product \n"
                      "Name: " + product_df['prodName'].fillna('') + "\n"
                      "Type: " + product_df['productTypeName'].fillna('') + "\n"
                      "Category: " + product_df['productGroupName'].fillna('') + "\n"
                      "Description: " + product_df['detailDesc'].fillna('')
                      )
product_df['url'] = "https://xyzbrands/product/" + product_df['productCode'].astype(str)
product_df.rename(columns={"prodName": "name", "detailDesc": "description"}, inplace=True)


prod_records = product_df[[ "productCode", "name", "description", "url", "text"]].to_dict(orient="records")

print("This will take a minute or so because it is embedding the 'text' field....")
graphrag.data.merge_nodes("Product", prod_records, source_metadata={"ingest_id": ingest_id})

This will take a minute or so because it is embedding the 'text' field....
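
The `text` column above is built with one vectorized string expression, which can be hard to read. The sketch below builds the same template for a single record in plain Python. `product_text` is a hypothetical helper; it treats `None` and missing keys as empty strings, whereas the pandas version uses `fillna('')` to handle `NaN`.

```python
from typing import Mapping, Optional


def product_text(row: Mapping[str, Optional[str]]) -> str:
    """Build the text that gets embedded for one product; None becomes ''."""
    def field(name: str) -> str:
        return row.get(name) or ""

    return (
        "##Product \n"
        f"Name: {field('prodName')}\n"
        f"Type: {field('productTypeName')}\n"
        f"Category: {field('productGroupName')}\n"
        f"Description: {field('detailDesc')}"
    )


row = {
    "prodName": "Strap top",
    "productTypeName": "Vest top",
    "productGroupName": "Garment Upper body",
    "detailDesc": None,  # a missing description still yields a well-formed block
}
print(product_text(row))
```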


In [9]:
prod_cat_records = (product_df[['productCode', 'productGroupName']]
                    .rename(columns={'productCode':'start_node_id', 'productGroupName':'end_node_id'})
                    .to_dict(orient="records"))

graphrag.data.merge_relationships(rel_type='PART_OF',
                                  start_node_label='Product',
                                  end_node_label='ProductCategory',
                                  records=prod_cat_records,
                                  source_metadata={"ingest_id": ingest_id})
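
The records handed to `merge_relationships` in this notebook are plain dictionaries whose `start_node_id` and `end_node_id` keys identify the two endpoints; any remaining keys travel along as relationship properties. The sketch below shows the same reshaping without pandas, on made-up rows. `to_relationship_records` is a hypothetical helper, and the deduplication step assumes all values are hashable.

```python
from typing import Dict, Iterable, List


def to_relationship_records(rows: Iterable[Dict], start_col: str, end_col: str,
                            prop_cols: Iterable[str] = ()) -> List[Dict]:
    """Rename the endpoint columns, keep selected properties, drop exact duplicates."""
    prop_cols = list(prop_cols)
    seen = set()
    records = []
    for row in rows:
        record = {"start_node_id": row[start_col], "end_node_id": row[end_col]}
        record.update({col: row[col] for col in prop_cols})
        key = tuple(record.items())
        if key not in seen:  # comparable to DataFrame.drop_duplicates on these columns
            seen.add(key)
            records.append(record)
    return records


rows = [
    {"productCode": 1001, "productGroupName": "Garment Upper body"},
    {"productCode": 1002, "productGroupName": "Garment Full body"},
    {"productCode": 1001, "productGroupName": "Garment Upper body"},  # duplicate row
]
print(to_relationship_records(rows, "productCode", "productGroupName"))
```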

In [10]:
prod_type_records = (product_df[['productCode', 'productTypeName']]
                    .rename(columns={'productCode':'start_node_id', 'productTypeName':'end_node_id'})
                    .to_dict(orient="records"))

graphrag.data.merge_relationships(rel_type='PART_OF',
                                  start_node_label='Product',
                                  end_node_label='ProductType',
                                  records=prod_type_records,
                                  source_metadata={"ingest_id": ingest_id})

In [11]:
article_df = pd.read_csv(os.path.join(data_dir, "articles.csv"))

article_records = article_df[["articleId", "colourGroupCode", "colourGroupName", "graphicalAppearanceName", "graphicalAppearanceNo"]].to_dict(orient="records")

graphrag.data.merge_nodes("Article", article_records, source_metadata={"ingest_id": ingest_id})


In [12]:
variant_records = (article_df[['articleId', 'productCode']]
                    .rename(columns={'articleId':'start_node_id', 'productCode':'end_node_id'})
                    .to_dict(orient="records"))

graphrag.data.merge_relationships(rel_type='VARIANT_OF',
                                  start_node_label='Article',
                                  end_node_label='Product',
                                  records=variant_records,
                                  source_metadata={"ingest_id": ingest_id})

In [13]:
supplied_by_records = (article_df[['articleId', 'supplierId']]
                    .rename(columns={'articleId':'start_node_id', 'supplierId':'end_node_id'})
                    .to_dict(orient="records"))
graphrag.data.merge_relationships(rel_type='SUPPLIED_BY',
                                  start_node_label='Article',
                                  end_node_label='Supplier',
                                  records=supplied_by_records,
                                  source_metadata={"ingest_id": ingest_id})

In [14]:
supplier_df = pd.read_csv(os.path.join(data_dir, "suppliers.csv"))
supplier_records = supplier_df.rename(columns={"supplierName": "name", "supplierAddress": "address"}).to_dict(orient="records")
graphrag.data.merge_nodes("Supplier", supplier_records, source_metadata={"ingest_id": ingest_id})

In [15]:
order_df = pd.read_csv(os.path.join(data_dir, "order-details.csv"))
order_records = order_df[['orderId', 'tDat']].drop_duplicates().rename(columns={'tDat':'date'}).to_dict(orient="records")
graphrag.data.merge_nodes("Order", order_records, source_metadata={"ingest_id": ingest_id})
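
Note that `drop_duplicates()` on `['orderId', 'tDat']` removes only exact duplicate pairs. If one order ever appeared with two different dates, both rows would survive and target the same `Order` node; how such rows are resolved during the merge is not shown in this notebook. A small sketch of a pre-merge check, using a hypothetical helper `conflicting_keys` on plain dictionaries:

```python
from collections import defaultdict
from typing import Dict, Hashable, Iterable, Set


def conflicting_keys(rows: Iterable[Dict], key_col: str, value_col: str) -> Dict[Hashable, Set]:
    """Return keys that occur with more than one distinct value."""
    values_by_key = defaultdict(set)
    for row in rows:
        values_by_key[row[key_col]].add(row[value_col])
    return {key: values for key, values in values_by_key.items() if len(values) > 1}
```

Here, `conflicting_keys(order_df.to_dict(orient="records"), "orderId", "tDat")` returns an empty dictionary when every order has exactly one date.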

In [16]:
ordered_records = order_df[['customerId', 'orderId']].drop_duplicates().rename(columns={'customerId':'start_node_id', 'orderId':'end_node_id'}).to_dict(orient="records")
graphrag.data.merge_relationships(rel_type='ORDERED',
                                  start_node_label='Customer',
                                  end_node_label='Order',
                                  records=ordered_records,
                                  source_metadata={"ingest_id": ingest_id})

In [17]:
contains_records = (order_df[['orderId', 'articleId', 'txId', 'price']]).rename(columns={'orderId':'start_node_id', 'articleId':'end_node_id'}).to_dict(orient="records")
graphrag.data.merge_relationships(rel_type='CONTAINS',
                                  start_node_label='Order',
                                  end_node_label='Article',
                                  records=contains_records,
                                  source_metadata={"ingest_id": ingest_id})

#contains_records

### LLM-Powered Tabular Mappings Are Still Decent, BTW
You can experiment by turning off the custom mapping above and comparing the results: they are in fact the same.

## Text Extraction From PDF
We can add schema subsets here for precision. This would also let us pass custom directions for the target schema.

In [18]:
from graph_nd import SubSchema

for i in range(2):
    graphrag.data.merge_pdf(os.path.join(data_dir, 'credit-notes.pdf'),
                            nodes_only=False,
                            sub_schema=SubSchema(
                                patterns=[('CreditNote','REFUND_FOR_ORDER', 'Order'), ('CreditNote',"REFUND_OF_ARTICLE", 'Article')]
                            ))

[Data] Merging data from document: /Users/zachblumenfeld/demo/graphrag-nd/examples/retail/data/credit-notes.pdf


Extracting entities from text: 100%|██████████| 31/31 [00:47<00:00,  1.54s/it]


Consolidating results...


Merging Nodes by Label: 100%|██████████| 3/3 [00:04<00:00,  1.62s/node]
Merging Relationships by Type & Pattern: 100%|██████████| 2/2 [00:01<00:00,  1.88rel/s]


[Data] Merging data from document: /Users/zachblumenfeld/demo/graphrag-nd/examples/retail/data/credit-notes.pdf


Extracting entities from text: 100%|██████████| 31/31 [00:53<00:00,  1.71s/it]


Consolidating results...


Merging Nodes by Label: 100%|██████████| 3/3 [00:03<00:00,  1.32s/node]
Merging Relationships by Type & Pattern: 100%|██████████| 2/2 [00:01<00:00,  1.86rel/s]
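
The loop above ingests the same PDF twice. A merge keyed on a node identifier is generally expected to update existing nodes rather than create duplicates, provided the extracted identifiers are the same on both passes. The sketch below illustrates that upsert idea in memory only; it is not graph-nd's or Neo4j's implementation, and `merge_nodes_in_memory` and the `creditNoteId` property are hypothetical.

```python
from typing import Dict, Hashable, Iterable


def merge_nodes_in_memory(store: Dict[Hashable, Dict], records: Iterable[Dict],
                          id_key: str) -> Dict[Hashable, Dict]:
    """Upsert records into `store` keyed by `id_key`: create if absent, else update."""
    for record in records:
        node = store.setdefault(record[id_key], {})
        node.update(record)
    return store


store: Dict[Hashable, Dict] = {}
batch = [
    {"creditNoteId": "CN-1", "amount": 20.0},
    {"creditNoteId": "CN-2", "amount": 5.0},
]
for _ in range(2):  # ingest the same batch twice, like the loop above
    merge_nodes_in_memory(store, batch, "creditNoteId")

print(len(store))  # 2: the second pass updated nodes instead of duplicating them
```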


## Test an Agent

In [20]:
graphrag.agent("Which suppliers were responsible for the most refunds")


Which suppliers were responsible for the most refunds
Tool Calls:
  aggregate (call_JmZHV39l9jAhyihEaTnMjBRK)
 Call ID: call_JmZHV39l9jAhyihEaTnMjBRK
  Args:
    agg_instructions: Aggregate the number of refunds for each supplier by counting the number of CreditNote nodes connected to Article nodes, which are in turn connected to Supplier nodes. Return the suppliers with the highest number of refunds.
Running Query:
MATCH (cn:CreditNote)-[:REFUND_OF_ARTICLE]->(a:Article)-[:SUPPLIED_BY]->(s:Supplier)
RETURN s.name AS supplierName, COUNT(cn) AS numberOfRefunds
ORDER BY numberOfRefunds DESC
Name: aggregate

[
    {
        "supplierName": "1616 - Textile & Apparel Manufacturing",
        "numberOfRefunds": 45
    },
    {
        "supplierName": "1779 - Denim Textiles",
        "numberOfRefunds": 42
    },
    {
        "supplierName": "3708 - Textile & Apparel Manufacturing",
        "numberOfRefunds": 40
    },
    {
        "supplierName": "1643 - Textile & Apparel Manufacturing",
  

## Create Expert Tools For Retrieval
For our use case, there may be some specific query templates and retrieval methodologies that we want to expose as tools.

In [46]:
from typing import List, Dict


def get_product_recommendations(product_codes_or_article_ids: List[int]) -> List[Dict]:
    """
    Retrieve product recommendations given a list of product codes or article ids.
    Please re-order or filter further based on additional context from the user.
    """
    res = db_client.execute_query("""
    //recommend from product codes
    MATCH (customer:Customer)-[:ORDERED]->()-[:CONTAINS]->()-[:VARIANT_OF]->
    (interestedInProducts:Product)<-[:VARIANT_OF]-(interestedInArticles:Article)<-[:CONTAINS]-()<-[:ORDERED]
    -(:Customer)-[:ORDERED]->()-[:CONTAINS]->(recArticle:Article)-[:VARIANT_OF]->(product:Product)
    WHERE (interestedInArticles.articleId IN $itemIds)
        OR (interestedInProducts.productCode IN $itemIds)
    WITH count(recArticle) AS recommendationScore, product
    RETURN product.productCode AS productCode,
        product.text AS text,
        product.url AS url
    ORDER BY recommendationScore DESC LIMIT 20
    """, itemIds=product_codes_or_article_ids, result_transformer_ = lambda r: r.data())
    return res


def get_product_order_supplier_info(product_codes: List[int]) -> List[Dict]:
    """
    Given a list of product codes, get total order and refund counts for each product, along with a per-supplier breakdown.
    """
    res = db_client.execute_query("""
    MATCH(p:Product)<-[:VARIANT_OF]-(a:Article)-[:SUPPLIED_BY]->(s)
    WHERE p.productCode IN $productCodes
    WITH *,
      COUNT {MATCH (:Order)-[:CONTAINS]->(a)} AS numberOfOrders,
      COUNT {MATCH (:CreditNote)-[:REFUND_OF_ARTICLE]-(a)} AS numberOfRefunds
    RETURN p.productCode AS productCode,
      sum(numberOfOrders) AS totalOrders,
      sum(numberOfRefunds) AS totalReturns,
      collect({supplierId:s.supplierId, name:s.name, numberOfOrders:numberOfOrders, numberOfRefunds:numberOfRefunds}) AS supplierInfos
    """, productCodes=product_codes, result_transformer_ = lambda r: r.data())
    return res

def get_supplier_order_product_info(supplier_ids: List[int]) -> List[Dict]:
    """
    Given a list of supplier ids, get total order and refund counts for each supplier, along with a per-product breakdown.
    """
    res = db_client.execute_query("""
    MATCH(p:Product)<-[:VARIANT_OF]-(:Article)-[:SUPPLIED_BY]->(s)
    WHERE s.supplierId IN $supplierIds
    WITH DISTINCT p, s,
      COUNT {MATCH (:Order)-[:CONTAINS]->()-[:VARIANT_OF]->(p)} AS numberOfOrders,
      COUNT {MATCH (:CreditNote)-[:REFUND_OF_ARTICLE]-()-[:VARIANT_OF]->(p)} AS numberOfRefunds
    RETURN s.supplierId AS supplierId,
      sum(numberOfOrders) AS totalOrders,
      sum(numberOfRefunds) AS totalReturns,
      collect({productCode:p.productCode, name:p.name, numberOfOrders:numberOfOrders, numberOfRefunds:numberOfRefunds}) AS productInfos
    """, supplierIds=supplier_ids, result_transformer_ = lambda r: r.data())
    return res
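
The first tool above is essentially a co-purchase recommender: it ranks products by how often they appear in orders from customers who also ordered the given items. The sketch below shows that scoring idea in memory on made-up orders. It is a simplification rather than a translation of the Cypher query: it ignores the article/product distinction and excludes the seed items from the results, which the query does not do. `co_purchase_scores` is a hypothetical name.

```python
from collections import Counter
from typing import Dict, Iterable, List, Set, Tuple


def co_purchase_scores(orders_by_customer: Dict[str, Set[int]], seed_items: Iterable[int],
                       limit: int = 20) -> List[Tuple[int, int]]:
    """Score items by how many seed-buying customers also ordered them."""
    seeds = set(seed_items)
    scores: Counter = Counter()
    for items in orders_by_customer.values():
        if items & seeds:                 # this customer ordered a seed item
            scores.update(items - seeds)  # credit everything else they ordered
    return scores.most_common(limit)


orders_by_customer = {
    "c1": {100, 200, 300},
    "c2": {100, 200},
    "c3": {400, 500},  # never ordered a seed item, so contributes nothing
}
print(co_purchase_scores(orders_by_customer, seed_items=[100]))  # [(200, 2), (300, 1)]
```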

In [47]:
agent = graphrag.create_react_agent(tools=[get_product_recommendations,
                                           get_product_order_supplier_info,
                                           get_supplier_order_product_info])
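
The tools passed above are plain Python functions. Tool-calling frameworks commonly derive a tool's name, argument schema, and description from the function name, type hints, and docstring, which is why each function above carries all three. How graph-nd converts functions into tools internally is not shown in this notebook; the sketch below only uses the standard library to display the metadata such a wrapper could read. `describe_tool` and `lookup_prices` are hypothetical.

```python
import inspect
from typing import Callable, Dict, List, get_type_hints


def describe_tool(func: Callable) -> Dict:
    """Collect the name, description, and argument types a tool wrapper could expose."""
    hints = get_type_hints(func)
    return {
        "name": func.__name__,
        "description": inspect.getdoc(func) or "",
        "args": {name: hints.get(name) for name in inspect.signature(func).parameters},
        "returns": hints.get("return"),
    }


def lookup_prices(product_codes: List[int]) -> List[Dict]:
    """Stand-in tool: return a price record for each product code."""
    return [{"productCode": code, "price": None} for code in product_codes]


print(describe_tool(lookup_prices)["description"])
```

A vague or missing docstring leaves the model with little to go on when choosing between tools, so it is worth writing them as instructions to the agent.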

In [48]:
from langchain_core.messages import HumanMessage

# use just like any other langgraph agent...we are going to make a wrapper function for convenience
config = {"configurable": {"thread_id": "thread-1"}}

def agent_stream(question, history=None):
    if history is None:
        history = list()
    for step in agent.stream(
        {"messages": history + [HumanMessage(content=question)]},
        stream_mode="values", config=config
    ):
        history.append(step["messages"][-1])
        step["messages"][-1].pretty_print()
    return history
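
With `stream_mode="values"`, each streamed step carries the full message list so far, and the wrapper keeps only the newest message from each step. The sketch below reproduces that bookkeeping with a stub in place of the real agent, so the growth of `history` across turns is easy to follow. `StubAgent` and `stream_with_history` are hypothetical, and messages are plain strings rather than LangChain message objects.

```python
from typing import Dict, Iterator, List, Optional


class StubAgent:
    """Stand-in for an agent: yields the full message list after each step."""

    def stream(self, state: Dict[str, List[str]]) -> Iterator[Dict[str, List[str]]]:
        messages = list(state["messages"])
        yield {"messages": messages}                   # the input state
        messages = messages + ["tool: lookup result"]
        yield {"messages": messages}                   # after a tool call
        messages = messages + ["ai: final answer"]
        yield {"messages": messages}                   # after the final reply


def stream_with_history(agent: StubAgent, question: str,
                        history: Optional[List[str]] = None) -> List[str]:
    history = [] if history is None else history
    for step in agent.stream({"messages": history + [f"human: {question}"]}):
        history.append(step["messages"][-1])           # keep only the newest message
    return history


history = stream_with_history(StubAgent(), "first question")
history = stream_with_history(StubAgent(), "follow-up", history)
print(len(history))  # 6: three messages per turn
```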


In [49]:
history = agent_stream("What are some good sweaters for spring? Nothing too warm please!")


What are some good sweaters for spring? Nothing too warm please!
Tool Calls:
  node_search (call_atrxhaPp151Z30QMAAifvLmv)
 Call ID: call_atrxhaPp151Z30QMAAifvLmv
  Args:
    search_config: {'search_type': 'SEMANTIC', 'node_label': 'Product', 'search_prop': 'text'}
    search_query: light spring sweater
Name: node_search

[
    {
        "productCode": 838787,
        "text": "##Product \nName: Spring\nType: Dress\nCategory: Garment Full body\nDescription: Calf-length dress in an airy viscose weave with a collar, concealed buttons at the top and long raglan sleeves with buttoned cuffs. Relaxed fit with a gathered seam at the hips and hem. Unlined.",
        "description": "Calf-length dress in an airy viscose weave with a collar, concealed buttons at the top and long raglan sleeves with buttoned cuffs. Relaxed fit with a gathered seam at the hips and hem. Unlined.",
        "name": "Spring",
        "url": "https://xyzbrands/product/838787",
        "search_score": 0.9259033203125
   

In [50]:
history = agent_stream("What else can you recommend to go with that?", history)


What else can you recommend to go with that?
Tool Calls:
  get_product_recommendations (call_6ITxq8CGYgJ40s2xeI7aUBbq)
 Call ID: call_6ITxq8CGYgJ40s2xeI7aUBbq
  Args:
    product_codes_or_article_ids: [358483, 674250, 531615, 687335, 244267]
Name: get_product_recommendations

[{"productCode": 687016, "text": "##Product \nName: DORIS CREW\nType: Sweater\nCategory: Garment Upper body\nDescription: Top in sweatshirt fabric with a motif on the front and ribbing around the neckline, cuffs and hem. Soft brushed inside.", "url": "https://xyzbrands/product/687016"}, {"productCode": 108775, "text": "##Product \nName: Strap top\nType: Vest top\nCategory: Garment Upper body\nDescription: Jersey top with narrow shoulder straps.", "url": "https://xyzbrands/product/108775"}, {"productCode": 781833, "text": "##Product \nName: Chicago dress\nType: Dress\nCategory: Garment Full body\nDescription: Short dress in a crêpe weave with a round neckline and an opening with a button at the back of the neck. C

In [51]:
history2 = agent_stream("Which suppliers have the highest number of returns (i.e., credit notes)?")


Which suppliers have the highest number of returns (i.e., credit notes)?
Tool Calls:
  aggregate (call_Nc3oZ0ihOsoUQgcphxzJ4KAA)
 Call ID: call_Nc3oZ0ihOsoUQgcphxzJ4KAA
  Args:
    agg_instructions: Aggregate the number of credit notes (returns) for each supplier and return the suppliers with the highest number of returns.
Running Query:
MATCH (cn:CreditNote)-[:REFUND_OF_ARTICLE]->(a:Article)-[:SUPPLIED_BY]->(s:Supplier)
RETURN s.name AS supplierName, COUNT(cn) AS numberOfReturns
ORDER BY numberOfReturns DESC
Name: aggregate

[
    {
        "supplierName": "1616 - Textile & Apparel Manufacturing",
        "numberOfReturns": 45
    },
    {
        "supplierName": "1779 - Denim Textiles",
        "numberOfReturns": 42
    },
    {
        "supplierName": "3708 - Textile & Apparel Manufacturing",
        "numberOfReturns": 40
    },
    {
        "supplierName": "1643 - Textile & Apparel Manufacturing",
        "numberOfReturns": 39
    },
    {
        "supplierName": "5832 - Jersey M

In [53]:
history3 = agent_stream("What are the top 3 most returned products for supplier 1616? Get those product codes and find other suppliers who have less returns for each product I can use instead.")


What are the top 3 most returned products for supplier 1616? Get those product codes and find other suppliers who have less returns for each product I can use instead.
Tool Calls:
  aggregate (call_bflujBnG7pSAYo5W0NNu1gbh)
 Call ID: call_bflujBnG7pSAYo5W0NNu1gbh
  Args:
    agg_instructions: Find the top 3 most returned products for supplier 1616 by aggregating the number of refunds for each product supplied by this supplier. Return the product codes of these top 3 products.
Running Query:
MATCH (s:Supplier {supplierId: 1616})<-[:SUPPLIED_BY]-(a:Article)<-[:REFUND_OF_ARTICLE]-(c:CreditNote)
MATCH (a)-[:VARIANT_OF]->(p:Product)
RETURN p.productCode AS productCode, COUNT(c) AS refundCount
ORDER BY refundCount DESC
LIMIT 3
Name: aggregate

[
    {
        "productCode": 673677,
        "refundCount": 30
    },
    {
        "productCode": 748269,
        "refundCount": 11
    },
    {
        "productCode": 802023,
        "refundCount": 4
    }
]
Tool Calls:
  get_product_order_supplie

## MCP Integration
Of course, we can also use MCP to connect tools.