In [0]:
# Step 1: Install dependencies (if not already installed)
%pip install networkx sentence-transformers scikit-learn mlflow faker

In [0]:
# Step 2: Import required libraries
import networkx as nx
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, DateType

In [0]:
import random

# Product catalogue: item_names[i] is described by descriptions[i].
item_names = [
    "Choco Bliss", "Nutty Crunch", "Caramel Dream", "Minty Fresh", "Berry Burst",
    "Peanut Delight", "Crispy Joy", "Fudge Fantasy", "Toffee Twist", "Coconut Charm",
    "Almond Supreme", "Hazel Heaven", "Marshmallow Magic", "Cookie Craze", "Sugar Rush",
    "Golden Nugget", "Jelly Gem", "Cocoa Swirl", "Vanilla Velvet", "Maple Munch",
    "Cherry Chew", "Orange Zest", "Lemon Drop", "Gummy Glow", "Rainbow Ribbons",
    "S'mores Sensation", "Mocha Melt", "Cinnamon Swirl", "Pecan Pleasure", "Honey Hug",
    "Strawberry Sizzle", "Banana Bonanza", "Apple Aroma", "Grape Gala", "Bubble Bliss",
    "Espresso Edge", "Salted Caramel", "Truffle Treat", "Pumpkin Pop", "Raspberry Ripple",
    "Mango Magic", "Pistachio Punch", "Cranberry Crunch", "Apricot Adventure", "Plum Passion",
    "Tropical Tango", "Lime Lush", "Blueberry Bash", "Butterscotch Burst", "Walnut Whirl"
]

descriptions = [
    "Rich chocolate with creamy filling", "Crunchy nuts in smooth chocolate", "Soft caramel center",
    "Refreshing mint flavor", "Burst of berry goodness", "Peanut butter blend", "Crispy wafer layers",
    "Decadent fudge treat", "Classic toffee twist", "Sweet coconut flakes", "Premium almonds inside",
    "Hazelnut cream filling", "Fluffy marshmallow center", "Cookie bits in chocolate", "Extra sweet sensation",
    "Golden caramel nuggets", "Jelly fruit center", "Swirled cocoa delight", "Smooth vanilla cream",
    "Maple infused chocolate", "Chewy cherry pieces", "Zesty orange flavor", "Tangy lemon drop",
    "Glowing gummy candies", "Colorful candy ribbons", "S'mores inspired bar", "Mocha coffee blend",
    "Cinnamon spice swirl", "Pecan nut crunch", "Honey sweetened chocolate", "Strawberry infused bar",
    "Banana flavored treat", "Apple flavored candy", "Grape jelly center", "Bubblegum flavored chocolate",
    "Espresso infused bar", "Salted caramel layer", "Rich truffle filling", "Pumpkin spice pop",
    "Raspberry ripple center", "Mango flavored delight", "Pistachio nut blend", "Cranberry crunch bar",
    "Apricot jam filling", "Plum flavored chocolate", "Tropical fruit mix", "Lime zest infusion",
    "Blueberry burst center", "Butterscotch cream", "Walnut swirl filling"
]

# The two lists are maintained by hand; fail loudly if they drift apart instead
# of silently truncating or raising an IndexError on a hard-coded count.
assert len(item_names) == len(descriptions), (
    f"item_names ({len(item_names)}) and descriptions ({len(descriptions)}) must have the same length"
)

# One (item_id, item_name, item_description) row per product; ids are "p-1", "p-2", ...
data = [
    (f"p-{position}", name, description)
    for position, (name, description) in enumerate(zip(item_names, descriptions), start=1)
]

item_details_df = spark.createDataFrame(
    data,
    schema=["item_id", "item_name", "item_description"]
)

display(item_details_df)

In [0]:
import random
from pyspark.sql.types import StructType, StructField, StringType

us_cities = [
    "New York", "Los Angeles", "Chicago", "Houston", "Phoenix", "Philadelphia", "San Antonio", "San Diego", "Dallas", "San Jose",
    "Austin", "Jacksonville", "Fort Worth", "Columbus", "Charlotte", "San Francisco", "Indianapolis", "Seattle", "Denver", "Washington",
    "Boston", "El Paso", "Nashville", "Detroit", "Oklahoma City", "Portland", "Las Vegas", "Memphis", "Louisville", "Baltimore",
    "Milwaukee", "Albuquerque", "Tucson", "Fresno", "Sacramento", "Kansas City", "Mesa", "Atlanta", "Omaha", "Colorado Springs",
    "Raleigh", "Miami", "Long Beach", "Virginia Beach", "Oakland", "Minneapolis", "Tulsa", "Arlington", "Tampa", "New Orleans"
]

# Named `location_schema` (not `schema`) on purpose: a later cell binds `schema`
# to the Unity Catalog schema *name*, and re-running this cell afterwards would
# otherwise overwrite that string with a StructType.
location_schema = StructType([
    StructField("location_id", StringType(), False),
    StructField("location_name", StringType(), True),
    StructField("location_description", StringType(), True)
])

# Shuffle with a dedicated seeded RNG so the id -> city assignment is identical on
# every run, and use sample() so `us_cities` itself is not mutated (the cell stays
# idempotent when re-executed).
shuffled_cities = random.Random(42).sample(us_cities, k=len(us_cities))

# One (location_id, location_name, location_description) row per city; ids are "l-1", "l-2", ...
data = [
    (f"l-{position}", city, f"Store located in {city}, USA")
    for position, city in enumerate(shuffled_cities, start=1)
]

store_location_df = spark.createDataFrame(
    data,
    schema=location_schema
)
display(store_location_df)

In [0]:
from faker import Faker
fake = Faker()
Faker.seed(42)  # fixed seed so the generated customers are the same on every run


def _make_customer(sequence_number):
    """Return one (customer_id, name, email, zip_code) tuple with a "c-<n>" id."""
    return (
        f"c-{sequence_number}",
        fake.unique.name(),
        fake.unique.email(),
        fake.zipcode_in_state(state_abbr='NY'),  # US zip code, can randomize state if desired
    )


# 200 synthetic customers, ids "c-1" .. "c-200"
customer_data = [_make_customer(sequence_number) for sequence_number in range(1, 201)]

# Only the id column is declared non-nullable.
_customer_columns = (
    ("customer_id", False),
    ("customer_name", True),
    ("customer_email", True),
    ("customer_zip_code", True),
)
customer_schema = StructType(
    [StructField(column, StringType(), nullable) for column, nullable in _customer_columns]
)

customer_df = spark.createDataFrame(customer_data, schema=customer_schema)
display(customer_df)

In [0]:
# Synthetic sales facts. Each row references one item, one store and one customer
# created by the cells above.
#
# This cell previously relied on `item_ids`, `location_ids`, `customer_ids` and
# `random_date` without defining them anywhere in the notebook, so it failed with
# NameError on a fresh kernel. They are defined here from the dimension
# DataFrames so the fact rows can only reference keys that really exist.
import random
from datetime import date, timedelta

random.seed(42)  # reproducible sales data across runs

# sorted() because collect() order is not guaranteed, and random.choice is
# position-based: a stable order keeps the seeded output stable.
item_ids = sorted(row["item_id"] for row in item_details_df.select("item_id").collect())
location_ids = sorted(row["location_id"] for row in store_location_df.select("location_id").collect())
customer_ids = sorted(row["customer_id"] for row in customer_df.select("customer_id").collect())

# NOTE(review): the intended date window is not recorded anywhere in the notebook.
# Calendar year 2025 is assumed because the later queries filter on 2025 periods
# (December 2025, November 2025, Q4 2025) - confirm or adjust.
SALE_DATE_START = date(2025, 1, 1)
SALE_DATE_END = date(2025, 12, 31)


def random_date(start: date = SALE_DATE_START, end: date = SALE_DATE_END) -> date:
    """Return a uniformly random calendar date in the inclusive range [start, end]."""
    return start + timedelta(days=random.randint(0, (end - start).days))


sales_data = []
for _ in range(5000):
    item_id = random.choice(item_ids)
    location_id = random.choice(location_ids)
    customer_id = random.choice(customer_ids)
    sale_date = random_date()
    units_sold = random.randint(1, 20)
    unit_price = random.uniform(1.0, 10.0)
    # Only the rounded line total is stored; the unit price itself is not kept.
    total_sales_value = round(units_sold * unit_price, 2)
    sales_data.append(
        (item_id, location_id, customer_id, sale_date, units_sold, total_sales_value)
    )

items_sales_df = spark.createDataFrame(
    sales_data,
    schema=[
        "item_id",
        "location_id",
        "customer_id",
        "sale_date",
        "units_sold",
        "total_sales_value"
    ]
)

display(items_sales_df)

In [0]:
# Step 4: Save to Unity Catalog
# Target namespace for every table this notebook writes and reads. Later cells
# interpolate these two names into `<catalog>.<schema>.<table>` identifiers and
# copy them into the GraphRAGConfig instance.
# NOTE: `schema` holds the schema *name* as a string; do not reuse this variable
# name for DataFrame schema objects elsewhere in the notebook.
catalog = "accenture"
schema = "sales_analysis"

In [0]:
# Make sure the target schema exists before any table is written.
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{schema}")

# The three dimension tables are written with a plain overwrite.
for _table_name, _frame in (
    ("item_details", item_details_df),
    ("store_location", store_location_df),
    ("customer_details", customer_df),
):
    _frame.write.mode("overwrite").saveAsTable(f"{catalog}.{schema}.{_table_name}")

# The fact table is the only write that also sets overwriteSchema.
_sales_writer = items_sales_df.write.mode("overwrite").option("overwriteSchema", "true")
_sales_writer.saveAsTable(f"{catalog}.{schema}.items_sales")

print("‚úì Tables saved to Unity Catalog")

In [0]:
# Step 5: Configure Graph RAG
from dataclasses import dataclass
from typing import List, Optional, Dict

@dataclass
class GraphRAGConfig:
    """Settings shared by the graph builder and the query classes in this notebook."""

    # Unity Catalog catalog name; first part of every `<catalog>.<schema>.<table>` reference.
    catalog: str
    # Schema name inside the catalog; second part of every table reference.
    schema: str
    # Unqualified name of the fact table whose rows become graph edges.
    fact_table: str
    # Unqualified names of the tables whose rows become graph nodes.
    dimension_tables: List[str]
    # Model name passed to SentenceTransformer(...).
    embedding_model: str = "all-MiniLM-L6-v2"
    # NOTE(review): not read by any code visible in this notebook; the query
    # functions take their own top_k / max_hops arguments - confirm intended use.
    top_k_nodes: int = 5
    max_hops: int = 2
    # {fact_table_name: {foreign_key_column: dimension_table_name}}.
    # The graph builder calls .get() on this value, so leaving it as None
    # requires the caller to handle that case.
    fk_mappings: Optional[Dict] = None

# NOTE(review): `fk_mappings` lists customer_id -> customer_details, but
# "customer_details" is not in `dimension_tables`. KnowledgeGraphBuilder only
# creates nodes for `dimension_tables` and only adds an edge when both endpoint
# nodes exist, so no customer nodes or customer edges end up in the graph.
# Either add "customer_details" to `dimension_tables` or drop the customer_id
# mapping - confirm which is intended.
config = GraphRAGConfig(
    catalog=catalog,
    schema=schema,
    fact_table="items_sales",
    dimension_tables=["item_details", "store_location"],
    embedding_model="all-MiniLM-L6-v2",
    top_k_nodes=5,
    max_hops=2,
    fk_mappings={
        "items_sales": {
            "item_id": "item_details",
            "location_id": "store_location",
            "customer_id": "customer_details",
        }
    }
)

print("‚úì Configuration created")

In [0]:
# Step 6: Build Knowledge Graph
class KnowledgeGraphBuilder:
    """Build a networkx MultiDiGraph from the configured dimension and fact tables.

    Every dimension-table row becomes a node named ``<table>_<id>``. Every
    fact-table row contributes one edge per pair of mapped foreign keys whose
    endpoint nodes both exist; the edge carries all columns of the fact row.
    """

    def __init__(self, config: GraphRAGConfig):
        self.config = config
        self.graph = nx.MultiDiGraph()

    def _qualified(self, table: str) -> str:
        """Return the fully qualified ``catalog.schema.table`` name."""
        return f"{self.config.catalog}.{self.config.schema}.{table}"

    def _add_dimension_nodes(self) -> None:
        """Add one node per row of every configured dimension table."""
        for dim_table in self.config.dimension_tables:
            df = spark.table(self._qualified(dim_table)).toPandas()

            # First column is assumed to be the ID
            id_col = df.columns[0]

            for _, row in df.iterrows():
                # All attributes are stored as strings; `_table` records the source
                # table (the leading underscore keeps it out of the embedding text).
                attributes = {col: str(row[col]) for col in df.columns}
                attributes['_table'] = dim_table
                self.graph.add_node(f"{dim_table}_{row[id_col]}", **attributes)

    def _add_fact_edges(self) -> int:
        """Add edges for every fact row and return how many were added."""
        from itertools import combinations

        fact_df = spark.table(self._qualified(self.config.fact_table)).toPandas()

        # `fk_mappings` defaults to None on the config; treat that as "no mappings"
        # instead of failing with AttributeError on `.get`.
        fk_mappings = (self.config.fk_mappings or {}).get(self.config.fact_table, {})

        edges_added = 0
        for _, row in fact_df.iterrows():
            edge_attrs = {col: row[col] for col in fact_df.columns}

            # Node ids of the dimension entities referenced by this fact row,
            # in the order the mappings were declared.
            endpoints = [
                f"{dim_table}_{row[fk_col]}"
                for fk_col, dim_table in fk_mappings.items()
                if fk_col in row
            ]

            # One edge per unordered pair; pairs whose nodes were never loaded
            # (table not listed in dimension_tables) are skipped.
            for node1, node2 in combinations(endpoints, 2):
                if self.graph.has_node(node1) and self.graph.has_node(node2):
                    self.graph.add_edge(node1, node2, **edge_attrs)
                    edges_added += 1

        return edges_added

    def build_from_tables(self):
        """Build graph from dimension and fact tables and return it."""
        self._add_dimension_nodes()
        print(f"‚úì Added {self.graph.number_of_nodes()} nodes from dimension tables")

        edges_added = self._add_fact_edges()
        print(f"‚úì Added {edges_added} edges from fact table")
        return self.graph

# Build the graph
# `graph` is the module-level MultiDiGraph that the embedding cell and the query
# helpers below read; building it pulls every configured table to the driver
# via toPandas().
builder = KnowledgeGraphBuilder(config)
graph = builder.build_from_tables()

# Quick size check of the result.
print(f"\nüìä Graph Statistics:")
print(f"   Total Nodes: {graph.number_of_nodes()}")
print(f"   Total Edges: {graph.number_of_edges()}")

In [0]:
# Step 7: Generate embeddings
print("\nüîÑ Generating embeddings...")
model = SentenceTransformer(config.embedding_model)

# Text form of each node: its "key: value" attribute pairs joined by commas.
# Attributes whose name starts with "_" are internal and left out.
node_texts = {
    node_id: ", ".join(
        f"{k}: {v}" for k, v in attrs.items() if not k.startswith('_')
    )
    for node_id, attrs in graph.nodes(data=True)
}

# Encode every text in a single batch, then key each vector by its node id.
all_node_ids = list(node_texts.keys())
all_texts = list(node_texts.values())
embeddings = model.encode(all_texts, show_progress_bar=True)
node_embeddings = dict(zip(all_node_ids, embeddings))

print(f"‚úì Generated embeddings for {len(node_embeddings)} nodes")

In [0]:
# Step 8: Test query
def query_graph(question: str, top_k: int = 5, max_hops: int = 2):
    """Answer a question by embedding similarity plus graph expansion.

    Reads the module-level ``model``, ``node_embeddings``, ``node_texts`` and
    ``graph`` built by the earlier cells.

    Args:
        question: Natural-language question to embed.
        top_k: Number of best-matching nodes used as traversal seeds.
        max_hops: How many steps to expand from each seed, following edges in
            both directions.

    Returns:
        A newline-joined string of at most 10 ``- <node>: <text>`` lines. The
        seed nodes come first, best match first, followed by other nodes of the
        expanded subgraph in graph order.
    """
    # Generate question embedding
    question_embedding = model.encode([question])[0]

    # Score every node in one vectorised call instead of one call per node.
    node_ids = list(node_embeddings)
    if node_ids:
        matrix = np.vstack([node_embeddings[node_id] for node_id in node_ids])
        scores = cosine_similarity([question_embedding], matrix)[0]
    else:
        scores = []  # nothing embedded yet: fall through with no matches

    # Get top K nodes (sorted() is stable, so ties keep dictionary order)
    top_nodes = sorted(zip(node_ids, scores), key=lambda pair: pair[1], reverse=True)[:top_k]

    print(f"\nüîç Top {top_k} matching nodes:")
    for node_id, score in top_nodes:
        print(f"   {node_id}: {score:.3f} - {node_texts[node_id][:80]}...")

    # Traverse graph: expand each seed up to `max_hops` steps
    subgraph_nodes = set()
    for seed, _ in top_nodes:
        subgraph_nodes.add(seed)

        frontier = {seed}
        for _hop in range(max_hops):
            reached = set()
            for node in frontier:
                # neighbors() only yields successors on a directed graph, so
                # predecessors are added explicitly to walk edges both ways.
                reached.update(graph.neighbors(node))
                reached.update(graph.predecessors(node))
            subgraph_nodes.update(reached)
            frontier = reached

    # Extract subgraph
    subgraph = graph.subgraph(subgraph_nodes)

    print(f"\nüìä Subgraph: {subgraph.number_of_nodes()} nodes, {subgraph.number_of_edges()} edges")

    # Generate answer. Iterating subgraph.nodes() alone follows graph insertion
    # order, so truncating to 10 lines could drop the very nodes that matched
    # the question; list the ranked seeds first and then the remaining context.
    seed_ids = [node_id for node_id, _ in top_nodes]
    seed_set = set(seed_ids)
    ordered_nodes = seed_ids + [node for node in subgraph.nodes() if node not in seed_set]

    answer_parts = [f"- {node}: {node_texts[node]}" for node in ordered_nodes]

    return "\n".join(answer_parts[:10])  # Limit to 10 lines for readability

In [0]:
# Test queries
print("\n" + "="*80)
print("TESTING QUERIES")
print("="*80)

# Query 1
# Keep the question in one variable so the printed label is always the text that
# is actually sent to query_graph (they used to be two different strings).
test_question = "Show high sales items"
print("\n" + "-"*80)
print(f"\nüìù Query: {test_question}")
result = query_graph(test_question)
print(f"\nüí° Answer:\n{result}")

print("\n" + "="*80)
print("‚úÖ Graph RAG implementation complete!")

In [0]:
import re
from typing import Dict, List, Tuple

class EnhancedGraphRAG:
    """Route a question either to a Spark SQL aggregation or to embedding search.

    Questions that contain an aggregation keyword and name a measure are answered
    with a GROUP BY query over the fact table; everything else falls back to
    cosine similarity against the pre-computed node embeddings. This version
    does no date filtering.
    """

    # Aggregation triggers, matched against whole words so that e.g. "min" does
    # not fire inside "Minty" or "count" inside "country".
    _AGG_KEYWORDS = frozenset({
        'top', 'highest', 'most', 'best', 'total', 'sum', 'average',
        'max', 'maximum', 'min', 'minimum', 'count', 'bottom', 'worst',
    })

    # Substring -> fact-table column, checked in this order. Substring matching
    # is kept here so that column-style words such as "units_sold" still match.
    _MEASURE_KEYWORDS = {
        'sales': 'total_sales_value',
        'revenue': 'total_sales_value',
        'units': 'units_sold',
        'quantity': 'units_sold',
        'volume': 'units_sold',
    }

    def __init__(self, graph, node_embeddings, node_texts, config, spark):
        self.graph = graph
        self.node_embeddings = node_embeddings
        self.node_texts = node_texts
        self.config = config
        self.spark = spark
        # Loads its own encoder instance from the configured model name.
        self.model = SentenceTransformer(config.embedding_model)

    def classify_query_intent(self, question: str) -> Dict:
        """Classify what the user is asking for.

        Returns:
            Dict with keys ``is_aggregation`` (bool), ``measure`` (column name or
            None), ``entity_type`` (dimension table name or None), ``limit``
            (int, default 5) and ``intent`` ('aggregation' or 'semantic_search').
        """
        question_lower = question.lower()

        # Underscores are separators too, so "units_sold" yields "units" and "sold".
        words = set(re.findall(r'[a-z0-9]+', question_lower))
        has_aggregation = not words.isdisjoint(self._AGG_KEYWORDS)

        # First measure keyword found in the question wins.
        measure_col = next(
            (col for keyword, col in self._MEASURE_KEYWORDS.items() if keyword in question_lower),
            None,
        )

        # Detect entity type; items take precedence over locations.
        entity_type = None
        if any(word in question_lower for word in ['item', 'product', 'goods']):
            entity_type = 'item_details'
        elif any(word in question_lower for word in ['location', 'store', 'place']):
            entity_type = 'store_location'

        # Detect number requested ("top 3" or "3 best")
        limit = 5  # default
        number_match = re.search(r'top (\d+)|(\d+) (top|best|highest)', question_lower)
        if number_match:
            limit = int(number_match.group(1) or number_match.group(2))

        return {
            'is_aggregation': has_aggregation,
            'measure': measure_col,
            'entity_type': entity_type,
            'limit': limit,
            'intent': 'aggregation' if has_aggregation else 'semantic_search'
        }

    def execute_aggregation_query(self, intent: Dict) -> str:
        """Execute aggregation-based query using Spark SQL.

        Args:
            intent: Dict produced by ``classify_query_intent``. A missing measure
                defaults to ``total_sales_value`` and a missing entity type to
                ``item_details``.

        Returns:
            A header line followed by one numbered line per result row.
        """
        measure = intent['measure'] or 'total_sales_value'
        entity_type = intent['entity_type'] or 'item_details'
        limit = intent['limit']

        # Map entity type to table and join column
        if entity_type == 'item_details':
            entity_col, name_col = 'item_id', 'item_name'
        else:
            entity_col, name_col = 'location_id', 'location_name'
            entity_type_table = 'store_location'
        namespace = f"{self.config.catalog}.{self.config.schema}"
        entity_table = (
            f"{namespace}.item_details" if entity_type == 'item_details'
            else f"{namespace}.store_location"
        )

        # Every interpolated identifier comes from the fixed mappings above and
        # `limit` is an int parsed by regex, so no user text reaches the SQL.
        query = f"""
        SELECT 
            e.{entity_col},
            e.{name_col},
            SUM(f.{measure}) as total_measure,
            COUNT(*) as transaction_count
        FROM {namespace}.{self.config.fact_table} f
        JOIN {entity_table} e ON f.{entity_col} = e.{entity_col}
        GROUP BY e.{entity_col}, e.{name_col}
        ORDER BY total_measure DESC
        LIMIT {limit}
        """

        print(f"\nüîç Executing aggregation query...")
        print(f"   Measure: {measure}")
        print(f"   Entity: {entity_type}")
        print(f"   Limit: {limit}")

        # Execute query
        results = self.spark.sql(query).collect()

        # Format answer
        answer_parts = [f"\nTop {limit} {entity_type.replace('_', ' ')} by {measure}:\n"]

        for rank, row in enumerate(results, 1):
            value = row['total_measure']
            # Only monetary totals get a currency sign; unit counts used to be
            # printed as dollars too.
            if measure == 'total_sales_value':
                value_str = f"${value:,.2f}"
            else:
                value_str = f"{value:,.0f} units"
            answer_parts.append(
                f"{rank}. {row[name_col]} - {value_str} ({row['transaction_count']} transactions)"
            )

        return "\n".join(answer_parts)

    def execute_semantic_search(self, question: str, top_k: int = 5, max_hops: int = 2) -> str:
        """Execute semantic similarity search (original method).

        ``max_hops`` is accepted for signature compatibility but not used: this
        method ranks nodes by similarity only and does not traverse the graph.
        """
        # Generate question embedding
        question_embedding = self.model.encode([question])[0]

        # Find similar nodes
        similarities = {
            node_id: cosine_similarity([question_embedding], [node_emb])[0][0]
            for node_id, node_emb in self.node_embeddings.items()
        }

        # Get top K nodes
        top_nodes = sorted(similarities.items(), key=lambda pair: pair[1], reverse=True)[:top_k]

        print(f"\nüîç Top {top_k} semantically similar nodes:")
        for node_id, score in top_nodes:
            print(f"   {node_id}: {score:.3f}")

        # Build answer from top nodes
        return "\n".join(
            f"- {self.node_texts[node_id]} (similarity: {score:.3f})"
            for node_id, score in top_nodes
        )

    def query(self, question: str) -> str:
        """Main query method with intent classification.

        Prints the question, the classified intent and the answer, and returns
        the answer text.
        """
        print(f"\n{'='*80}")
        print(f"üìù Query: {question}")
        print(f"{'='*80}")

        # Classify intent
        intent = self.classify_query_intent(question)
        print(f"\nüß† Query Intent: {intent['intent']}")
        print(f"   Details: {intent}")

        # Route to appropriate handler: SQL needs both an aggregation keyword
        # and a recognised measure, otherwise fall back to semantic search.
        if intent['is_aggregation'] and intent['measure']:
            result = self.execute_aggregation_query(intent)
        else:
            result = self.execute_semantic_search(question, top_k=5, max_hops=2)

        print(f"\nüí° Answer:")
        print(result)
        print(f"\n{'='*80}\n")

        return result


# Initialize the enhanced system
# Reuses the graph, embeddings and node texts built by the earlier cells. The
# constructor also loads the sentence-transformer model named in `config`.
enhanced_rag = EnhancedGraphRAG(
    graph=graph,
    node_embeddings=node_embeddings,
    node_texts=node_texts,
    config=config,
    spark=spark
)

In [0]:
# Query 1: Aggregation query
# "top" marks this as an aggregation and "sales" selects total_sales_value, so
# the question is answered through Spark SQL rather than embedding search. The
# next cell computes the same ranking by hand for comparison.
enhanced_rag.query("Show top 5 sales items by total sales value")

In [0]:
# Ground truth for the item ranking: top 5 items by summed total_sales_value.
# Table names come from `config` so this check always reads the same catalog
# and schema the RAG class queries, instead of a hard-coded namespace.
query = f"""
SELECT s.item_id, SUM(s.total_sales_value) AS total_sales, i.item_name, i.item_description
FROM {config.catalog}.{config.schema}.{config.fact_table} s
JOIN {config.catalog}.{config.schema}.item_details i ON s.item_id = i.item_id
GROUP BY s.item_id, i.item_name, i.item_description
ORDER BY total_sales DESC
LIMIT 5
"""
top_items_df = spark.sql(query)
display(top_items_df)

In [0]:
# Query 2: Another aggregation
# "revenue" maps to the same total_sales_value column as "sales", so this is
# expected to return the same ranking as Query 1.
enhanced_rag.query("What are the top 5 items by revenue?")

In [0]:
# Ground truth for Query 2 (same statement as the Query 1 check): top 5 items by
# summed total_sales_value. Table names come from `config` rather than a
# hard-coded namespace so the check follows the configured catalog and schema.
query = f"""
SELECT s.item_id, SUM(s.total_sales_value) AS total_sales, i.item_name, i.item_description
FROM {config.catalog}.{config.schema}.{config.fact_table} s
JOIN {config.catalog}.{config.schema}.item_details i ON s.item_id = i.item_id
GROUP BY s.item_id, i.item_name, i.item_description
ORDER BY total_sales DESC
LIMIT 5
"""
top_items_df = spark.sql(query)
display(top_items_df)

In [0]:
# Query 3: Location-based aggregation
# "store" / "location" switch the grouped entity to the store_location table;
# "sales" again selects total_sales_value.
enhanced_rag.query("List top 5 store location_name that have highest sales?")

In [0]:
# Ground truth for Query 3: top 5 store locations by summed total_sales_value.
# Table names come from `config` rather than a hard-coded namespace so the check
# follows the configured catalog and schema.
query = f"""
SELECT 
  l.location_name, 
  SUM(s.total_sales_value) AS total_sales_value, 
  COUNT(*) AS transaction_count
FROM {config.catalog}.{config.schema}.{config.fact_table} s
JOIN {config.catalog}.{config.schema}.store_location l 
  ON s.location_id = l.location_id
GROUP BY l.location_name
ORDER BY total_sales_value DESC
LIMIT 5
"""
top_locations_df = spark.sql(query)
display(top_locations_df)

In [0]:
# Query 4: Time-based aggregation
# NOTE: the EnhancedGraphRAG class defined above builds its SQL without a WHERE
# clause and does not parse dates, so "in December 2025" is ignored here and the
# ranking covers all sales. Compare with the date-filtered SQL in the next cell.
enhanced_rag.query("List the top 5 items by units_sold in December 2025?")

In [0]:
# Ground truth for Query 4: top 5 items by units sold in December 2025.
# Table names come from `config` rather than a hard-coded namespace so the check
# follows the configured catalog and schema.
query = f"""
SELECT 
  i.item_id,
  i.item_name,
  SUM(s.units_sold) AS total_units_sold,
  COUNT(*) AS transaction_count
FROM {config.catalog}.{config.schema}.{config.fact_table} s
JOIN {config.catalog}.{config.schema}.item_details i ON s.item_id = i.item_id
WHERE YEAR(s.sale_date) = 2025 AND MONTH(s.sale_date) = 12
GROUP BY i.item_id, i.item_name
ORDER BY total_units_sold DESC
LIMIT 5
"""
top_items_units_sold_df = spark.sql(query)
display(top_items_units_sold_df)

In [0]:
import re
from datetime import datetime
from typing import Dict, List, Tuple, Optional

class EnhancedGraphRAG:
    """Route a question to a date-aware Spark SQL aggregation or to embedding search.

    Compared with the earlier definition in this notebook, this version also
    extracts year / month / quarter filters from the question, supports a
    transaction-count measure, and applies defaults for measure and entity when
    an aggregation keyword is present.
    """

    # Aggregation triggers, matched against whole words so that e.g. "min" does
    # not fire inside "Minty" or "count" inside "country".
    _AGG_KEYWORDS = frozenset({
        'top', 'highest', 'most', 'best', 'total', 'sum', 'average',
        'max', 'maximum', 'min', 'minimum', 'count', 'bottom', 'worst',
        'lowest', 'least', 'list',
    })

    # Keywords that ask for the smallest values first.
    _ASCENDING_KEYWORDS = frozenset({'bottom', 'worst', 'lowest', 'least', 'min', 'minimum'})

    # Substring -> measure, checked in this order. Substring matching is kept so
    # that column-style words such as "units_sold" still match.
    _MEASURE_KEYWORDS = {
        'sales': 'total_sales_value',
        'revenue': 'total_sales_value',
        'units': 'units_sold',
        'quantity': 'units_sold',
        'volume': 'units_sold',
        'transactions': 'COUNT(*)',
    }

    _MONTH_NUMBERS = {
        'january': 1, 'jan': 1,
        'february': 2, 'feb': 2,
        'march': 3, 'mar': 3,
        'april': 4, 'apr': 4,
        'may': 5,
        'june': 6, 'jun': 6,
        'july': 7, 'jul': 7,
        'august': 8, 'aug': 8,
        'september': 9, 'sep': 9, 'sept': 9,
        'october': 10, 'oct': 10,
        'november': 11, 'nov': 11,
        'december': 12, 'dec': 12,
    }

    _MONTH_LABELS = ['', 'January', 'February', 'March', 'April', 'May', 'June',
                     'July', 'August', 'September', 'October', 'November', 'December']

    # Month names must be whole words: a plain substring test used to read
    # "market" as March and "decline" as December. "may" remains ambiguous with
    # the verb and is still treated as the month.
    _MONTH_PATTERN = re.compile(
        r'\b(' + '|'.join(sorted(_MONTH_NUMBERS, key=len, reverse=True)) + r')\b'
    )

    def __init__(self, graph, node_embeddings, node_texts, config, spark):
        self.graph = graph
        self.node_embeddings = node_embeddings
        self.node_texts = node_texts
        self.config = config
        self.spark = spark
        # Loads its own encoder instance from the configured model name.
        self.model = SentenceTransformer(config.embedding_model)

    def extract_date_filters(self, question: str) -> Dict:
        """Extract date/time filters from natural language query.

        Returns:
            Dict with keys ``year``, ``month``, ``quarter`` (ints or None) and
            ``date_range`` (always None here; nothing in this method sets it).
        """
        question_lower = question.lower()

        filters = {
            'year': None,
            'month': None,
            'quarter': None,
            'date_range': None
        }

        # Extract year (e.g., "2025", "in 2024")
        year_match = re.search(r'\b(20\d{2})\b', question)
        if year_match:
            filters['year'] = int(year_match.group(1))

        # Extract month by name first, then by number
        name_match = self._MONTH_PATTERN.search(question_lower)
        if name_match:
            filters['month'] = self._MONTH_NUMBERS[name_match.group(1)]
        else:
            # "month 12", or a number directly followed by / or - and another
            # digit ("12/2025"). Requiring the digit avoids reading "5-star" as May.
            number_match = re.search(
                r'\bmonth\s+(\d{1,2})\b|\b(\d{1,2})[/\-](?=\d)', question_lower
            )
            if number_match:
                month = int(number_match.group(1) or number_match.group(2))
                if 1 <= month <= 12:
                    filters['month'] = month

        # Extract quarter (Q1, Q2, Q3, Q4)
        quarter_match = re.search(r'\bq([1-4])\b', question_lower)
        if quarter_match:
            filters['quarter'] = int(quarter_match.group(1))

        # Relative dates override anything parsed above
        now = datetime.now()
        if 'last month' in question_lower:
            # January rolls back to December of the previous year.
            filters['month'] = now.month - 1 if now.month > 1 else 12
            filters['year'] = now.year if now.month > 1 else now.year - 1
        elif 'this month' in question_lower:
            filters['month'] = now.month
            filters['year'] = now.year
        elif 'last year' in question_lower:
            filters['year'] = now.year - 1
        elif 'this year' in question_lower:
            filters['year'] = now.year

        return filters

    def classify_query_intent(self, question: str) -> Dict:
        """Classify what the user is asking for.

        Returns:
            Dict with keys ``is_aggregation``, ``measure``, ``entity_type``,
            ``limit``, ``order`` ('ASC' or 'DESC'), ``date_filters`` and
            ``intent`` ('aggregation' or 'semantic_search').
        """
        question_lower = question.lower()

        # Underscores are separators too, so "units_sold" yields "units" and "sold".
        words = set(re.findall(r'[a-z0-9]+', question_lower))
        has_aggregation = not words.isdisjoint(self._AGG_KEYWORDS)

        # "bottom 5" / "worst" / "lowest" ask for the smallest values first.
        order = 'ASC' if not words.isdisjoint(self._ASCENDING_KEYWORDS) else 'DESC'

        # First measure keyword found in the question wins.
        measure_col = next(
            (col for keyword, col in self._MEASURE_KEYWORDS.items() if keyword in question_lower),
            None,
        )

        # Default to sales value if aggregation but no measure specified
        if has_aggregation and not measure_col:
            measure_col = 'total_sales_value'

        # Detect entity type; items take precedence over locations.
        entity_type = None
        if any(word in question_lower for word in ['item', 'product', 'goods']):
            entity_type = 'item_details'
        elif any(word in question_lower for word in ['location', 'store', 'place', 'shop']):
            entity_type = 'store_location'

        # Default to items if not specified
        if has_aggregation and not entity_type:
            entity_type = 'item_details'

        # Detect number requested ("top 3", "bottom 3" or "3 best")
        limit = 5  # default
        ranking_words = r'(?:top|bottom|best|worst|highest|lowest)'
        number_match = re.search(
            ranking_words + r' (\d+)|(\d+) ' + ranking_words, question_lower
        )
        if number_match:
            limit = int(number_match.group(1) or number_match.group(2))

        # Extract date filters
        date_filters = self.extract_date_filters(question)

        return {
            'is_aggregation': has_aggregation,
            'measure': measure_col,
            'entity_type': entity_type,
            'limit': limit,
            'order': order,
            'date_filters': date_filters,
            'intent': 'aggregation' if has_aggregation else 'semantic_search'
        }

    def build_date_where_clause(self, date_filters: Dict, table_alias: str = 'f') -> str:
        """Build SQL WHERE clause for date filters.

        Returns the conditions joined with AND, or ``"1=1"`` when no filter is
        set so the result can always follow ``WHERE``.
        """
        conditions = []

        for key, sql_function in (('year', 'YEAR'), ('month', 'MONTH'), ('quarter', 'QUARTER')):
            if date_filters.get(key):
                conditions.append(f"{sql_function}({table_alias}.sale_date) = {date_filters[key]}")

        if date_filters.get('date_range'):
            start, end = date_filters['date_range']
            conditions.append(f"{table_alias}.sale_date BETWEEN '{start}' AND '{end}'")

        return " AND ".join(conditions) if conditions else "1=1"

    def _describe_period(self, date_filters: Dict) -> str:
        """Return a phrase such as " in Q4 2025" for the active filters, or ""."""
        parts = []
        if date_filters.get('month'):
            parts.append(self._MONTH_LABELS[date_filters['month']])
        if date_filters.get('quarter'):
            parts.append(f"Q{date_filters['quarter']}")
        if date_filters.get('year'):
            parts.append(str(date_filters['year']))
        return f" in {' '.join(parts)}" if parts else ""

    def execute_aggregation_query(self, intent: Dict) -> str:
        """Execute aggregation-based query using Spark SQL.

        Args:
            intent: Dict produced by ``classify_query_intent``. ``order`` is
                optional and defaults to descending.

        Returns:
            A header line followed by one numbered line per result row, or a
            "No data found" message when the query returns nothing.
        """
        measure = intent['measure']
        entity_type = intent['entity_type']
        limit = intent['limit']
        date_filters = intent['date_filters']
        # Whitelisted so only ASC or DESC can ever reach the SQL text.
        order = 'ASC' if intent.get('order') == 'ASC' else 'DESC'

        # Map entity type to table and join column
        namespace = f"{self.config.catalog}.{self.config.schema}"
        if entity_type == 'item_details':
            entity_col, name_col = 'item_id', 'item_name'
            entity_table = f"{namespace}.item_details"
        else:
            entity_col, name_col = 'location_id', 'location_name'
            entity_table = f"{namespace}.store_location"

        # Build WHERE clause for date filters
        where_clause = self.build_date_where_clause(date_filters, 'f')

        # Build aggregation expression. The measure always gets its own alias:
        # reusing "transaction_count" for the COUNT(*) measure produced two
        # columns with the same name and an ambiguous ORDER BY.
        agg_expr = 'COUNT(*)' if measure == 'COUNT(*)' else f'SUM(f.{measure})'
        measure_label = 'total_measure'

        # Every interpolated identifier comes from the fixed mappings above and
        # the numeric values are ints parsed by regex, so no user text reaches
        # the SQL.
        query = f"""
        SELECT 
            e.{entity_col},
            e.{name_col},
            {agg_expr} as {measure_label},
            COUNT(*) as transaction_count
        FROM {namespace}.{self.config.fact_table} f
        JOIN {entity_table} e ON f.{entity_col} = e.{entity_col}
        WHERE {where_clause}
        GROUP BY e.{entity_col}, e.{name_col}
        ORDER BY {measure_label} {order}
        LIMIT {limit}
        """

        print(f"\nüîç Executing aggregation query...")
        print(f"   Measure: {measure}")
        print(f"   Entity: {entity_type}")
        print(f"   Limit: {limit}")
        if any(date_filters.values()):
            print(f"   Date Filters: {date_filters}")
        print(f"\nüìã SQL Query:")
        print(query)

        # Execute query
        results = self.spark.sql(query).collect()

        # Describe the period; quarter filters are now included (that branch
        # used to be unreachable behind the year-only case).
        date_desc = self._describe_period(date_filters)

        if not results:
            return f"\nNo data found for the specified filters{date_desc}."

        ranking = 'Bottom' if order == 'ASC' else 'Top'
        answer_parts = [f"\n{ranking} {limit} {entity_type.replace('_', ' ')}{date_desc} by {measure}:\n"]

        for rank, row in enumerate(results, 1):
            value = row[measure_label]

            # Format value based on measure type
            if measure == 'total_sales_value':
                value_str = f"${value:,.2f}"
            elif measure == 'COUNT(*)':
                value_str = f"{int(value):,} transactions"
            else:
                value_str = f"{value:,.0f} units"

            answer_parts.append(
                f"{rank}. {row[name_col]} - {value_str} ({row['transaction_count']} transactions)"
            )

        return "\n".join(answer_parts)

    def execute_semantic_search(self, question: str, top_k: int = 5, max_hops: int = 2) -> str:
        """Execute semantic similarity search (original method).

        ``max_hops`` is accepted for signature compatibility but not used: this
        method ranks nodes by similarity only and does not traverse the graph.
        """
        # Generate question embedding
        question_embedding = self.model.encode([question])[0]

        # Find similar nodes
        similarities = {
            node_id: cosine_similarity([question_embedding], [node_emb])[0][0]
            for node_id, node_emb in self.node_embeddings.items()
        }

        # Get top K nodes
        top_nodes = sorted(similarities.items(), key=lambda pair: pair[1], reverse=True)[:top_k]

        print(f"\nüîç Top {top_k} semantically similar nodes:")
        for node_id, score in top_nodes:
            print(f"   {node_id}: {score:.3f}")

        # Build answer from top nodes
        return "\n".join(
            f"- {self.node_texts[node_id]} (similarity: {score:.3f})"
            for node_id, score in top_nodes
        )

    def query(self, question: str) -> str:
        """Main query method with intent classification.

        Prints the question, the classified intent and the answer, and returns
        the answer text.
        """
        print(f"\n{'='*80}")
        print(f"üìù Query: {question}")
        print(f"{'='*80}")

        # Classify intent
        intent = self.classify_query_intent(question)
        print(f"\nüß† Query Intent: {intent['intent']}")
        print(f"   Details: {intent}")

        # Route to appropriate handler: SQL needs both an aggregation keyword
        # and a measure, otherwise fall back to semantic search.
        if intent['is_aggregation'] and intent['measure']:
            result = self.execute_aggregation_query(intent)
        else:
            result = self.execute_semantic_search(question, top_k=5, max_hops=2)

        print(f"\nüí° Answer:")
        print(result)
        print(f"\n{'='*80}\n")

        return result


# Re-initialize with enhanced version
# Rebinding `enhanced_rag` replaces the instance created from the earlier class
# definition; the constructor loads the sentence-transformer model again.
enhanced_rag = EnhancedGraphRAG(
    graph=graph,
    node_embeddings=node_embeddings,
    node_texts=node_texts,
    config=config,
    spark=spark
)

# Smoke tests for the date-aware routing. Each call prints its own intent, SQL
# and answer; only the last call's return value is shown as the cell output.
print("\nüß™ TESTING TIME-BASED QUERIES\n")

# Test 1: Month + Year filter
enhanced_rag.query("List the top 5 items by units_sold in December 2025?")

# Test 2: Just year filter
enhanced_rag.query("What are the top 3 items by revenue in 2025?")

# Test 3: Quarter filter
enhanced_rag.query("Show top 5 locations by sales in Q4 2025")

# Test 4: Month name
# No measure keyword appears in this question, so the default measure applies.
enhanced_rag.query("Top selling items in November 2025")

# Test 5: No date filter (all time)
enhanced_rag.query("What are the top 5 items by revenue?")

# Test 6: Relative date
# "this year" resolves to datetime.now().year, so the result depends on when
# the notebook is run and on whether the stored sales cover that year.
enhanced_rag.query("Show top items this year")

In [0]:
# Ground truth for Test 1 above: top 5 items by units sold in December 2025.
# Table names come from `config` rather than a hard-coded namespace so the check
# follows the configured catalog and schema.
query = f"""
SELECT 
  i.item_id,
  i.item_name,
  SUM(s.units_sold) AS total_units_sold,
  COUNT(*) AS transaction_count
FROM {config.catalog}.{config.schema}.{config.fact_table} s
JOIN {config.catalog}.{config.schema}.item_details i ON s.item_id = i.item_id
WHERE YEAR(s.sale_date) = 2025 AND MONTH(s.sale_date) = 12
GROUP BY i.item_id, i.item_name
ORDER BY total_units_sold DESC
LIMIT 5
"""
top_items_units_sold_df = spark.sql(query)
display(top_items_units_sold_df)