<a href="https://colab.research.google.com/github/thoppae/Build_Launch_Agents/blob/main/A2_SC_Multi_Agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Step 1- Install required libraries from the requirements file

In [None]:
!pip install -r requirements.txt

# Generate a new API key and use it through google secrets

In [None]:
# Colab-only helper that exposes the secrets stored for this notebook.
from google.colab import userdata

# Read the Gemini API key from Colab secrets instead of hard-coding it.
GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
# Debug aid only: enabling this would write the key into the cell output.
#print(f"GOOGLE_API_KEY: {GOOGLE_API_KEY}")

import google.generativeai as genai

# Register the key with the google.generativeai client.
genai.configure(api_key=GOOGLE_API_KEY)

# Module-level Gemini model handle.
model = genai.GenerativeModel('gemini-1.5-flash')

# SQLite is required to execute the queries

In [None]:
import sqlite3

# Path of the SQLite file that holds the simulated supply chain data.
db_file = 'database.db'
# Module-level connection; later cells pass this `conn` to the SQL helpers.
conn = sqlite3.connect(db_file)
# Module-level cursor on that connection.
cursor = conn.cursor()


Step 2 - Build and read the simulated database using the sql tools

In [None]:
#sql tools

# Create table

def create_table(conn):
    """
    List the names of every table in the database.

    Despite its name this helper creates nothing; it only reads
    ``sqlite_master``.

    Args:
        conn (sqlite3.Connection): Open connection to the SQLite database.
    Returns:
        list: Table names, in the order SQLite returns them.
    """
    print(' - DB CALL: create_table()')

    rows = conn.cursor().execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()

    # Each row is a 1-tuple holding the table name.
    names = []
    for row in rows:
        names.append(row[0])
    return names

# Describe table

def describe_table(conn, table_name):
    """
    Look up the schema of a table in the database.

    The table name is handed to SQLite as a bound parameter of the
    ``pragma_table_info`` table-valued function, so names that need quoting
    (spaces, keywords) work and the name can never be parsed as SQL.

    Args:
        conn (sqlite3.Connection): Open connection to the SQLite database.
        table_name (str): The name of the table to describe.

    Returns:
        list: ``(column_name, declared_type)`` tuples in column order; an
        empty list when the table does not exist.
    """
    print(f' - DB CALL: describe_table({table_name})')

    cursor = conn.cursor()
    cursor.execute(
        "SELECT name, type FROM pragma_table_info(?) ORDER BY cid;",
        (table_name,),
    )
    return [(name, declared_type) for name, declared_type in cursor.fetchall()]

# Execute query

def execute_query(conn, sql: str) -> list[list[str]]:
    """
    Run one SQL statement and hand back everything it produced.

    Args:
        conn (sqlite3.Connection): Open connection to the SQLite database.
        sql (str): The SQL statement to execute.

    Returns:
        list: The fetched rows (one tuple per row); empty when the statement
        yields no rows.
    """
    print(f' - DB CALL: execute_query({sql})')

    rows = conn.cursor().execute(sql).fetchall()
    return rows



# Smoke test: show the (name, type) columns of InventoryOH as the cell output.
describe_table(conn, "InventoryOH")


#Connection to the sqlite database

In [None]:
def connect_to_database(db_file):
    """
    Open a new connection to an SQLite database file.

    Args:
        db_file (str): The path to the SQLite database file.
    Returns:
        sqlite3.Connection: A freshly opened connection; the caller owns it.
    """
    print(f' - DB CALL: connect_to_database({db_file})')
    return sqlite3.connect(db_file)



# langchain tools to list/describe/execute

In [None]:
from langchain.tools import tool

#list table

@tool
def list_tables(conn) -> list[str]:
    """
    Retrieve a list of tables in the database.

    Args:
        conn (sqlite3.Connection): Open connection to the SQLite database.

    Returns:
        list: A list of table names in the database.
    """
    # The docstring above doubles as the tool description, so it must name
    # the real `conn` argument.
    print(' - DB CALL: list_tables()')

    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    # Each row is a 1-tuple holding the table name.
    return [row[0] for row in cursor.fetchall()]

#Describe table

@tool
def describe_table(conn, table_name: str) -> list[tuple[str, str]]:
    """
    Look up the schema of a table in the database.

    Args:
        conn (sqlite3.Connection): Open connection to the SQLite database.
        table_name (str): The name of the table to describe.

    Returns:
        list: A list of (column_name, column_type) tuples, empty if the
        table does not exist.
    """
    print(f' - DB CALL: describe_table({table_name})')
    cursor = conn.cursor()
    # Bind the table name instead of interpolating it: the name may come from
    # model output, and a bound parameter also handles names that need quoting.
    cursor.execute(
        "SELECT name, type FROM pragma_table_info(?) ORDER BY cid;",
        (table_name,),
    )
    return [(name, declared_type) for name, declared_type in cursor.fetchall()]


#Execute query table

@tool
def execute_query(conn, sql: str) -> list[list[str]]:
    """
    Execute an SQL query and return the results.

    Args:
        conn (sqlite3.Connection): Open connection to the SQLite database.
        sql (str): The SQL query to execute.

    Returns:
        list: A list of tuples representing the query results.
    """
    # The docstring above doubles as the tool description, so it must name
    # the real `conn` argument.
    print(f' - DB CALL: execute_query({sql})')
    cursor = conn.cursor()
    cursor.execute(sql)
    return cursor.fetchall()




# Pre-compute foreign keys

In [None]:
from pprint import pprint

def precompute_foreign_keys(conn):
    """
    Precompute foreign key relationships between tables.

    Args:
        conn (sqlite3.Connection): Open connection to the SQLite database.

    Returns:
        dict: Maps every table name to the list of tables referenced by its
        declared foreign keys. SQLite reports one row per foreign-key column,
        so a composite key repeats the referenced table. Tables without
        foreign keys map to an empty list.
    """
    print(' - DB CALL: precompute_foreign_keys()')
    cursor = conn.cursor()

    # Step 1: get all the table names.
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [row[0] for row in cursor.fetchall()]

    # Step 2: for each table, collect the tables its foreign keys point at.
    # The table name is bound as a parameter so names that need quoting work.
    foreign_keys = {}
    for table_name in table_names:
        cursor.execute(
            'SELECT "table" FROM pragma_foreign_key_list(?);',
            (table_name,),
        )
        foreign_keys[table_name] = [row[0] for row in cursor.fetchall()]

    return foreign_keys





#list the tools

In [None]:
# Callables collected for the agents. The first three are wrapped with
# langchain's @tool decorator; precompute_foreign_keys is a plain function.
# NOTE(review): confirm consumers of this list accept an undecorated callable.
tools = [
    list_tables,
    describe_table,
    execute_query,
    precompute_foreign_keys
]
#

#Execute query tool

In [None]:
# Call the execute_query tool directly with the module-level connection.
# `.run` takes a single dict whose keys are the tool's argument names.
execute_query.run({"conn": conn, "sql": "select i.EODQty from InventoryOH i where i.Store_id = 5 and i.Item_id = 1"})


Step 3: Define helper functions for agents and langgraph to get JSON based output

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.runnables import RunnableConfig

class SchemaExplorerAgent:
    """
    Collects table schemas and infers relationships between tables.

    ``run`` reads the database through the ``list_tables`` and
    ``describe_table`` tools and caches the outcome in ``self.state``. The
    Gemini model and the instruction text are configured in ``__init__`` but
    ``run`` itself never calls the model.
    """

    def __init__(self) -> None:
        # Initialize the LLM agent (Gemini model)
        self.LLM = ChatGoogleGenerativeAI(
            model="gemini-1.5-flash",
            temperature=0.2
        )
        self.tools = [list_tables, describe_table]
        # Filled by run(): keys "schema" and "relationships".
        self.state = {}

        # Provide instructions on how gemini should handle the schema exploration
        self.instruction = """
      You are an expert schema exploration agent for a retail supply chain database.

        Your task is to explore the database and gather schema information:
        - List all tables in the database.
        - For each table, describe the columns (name and type).
        - Infer relationships between tables based on common column names (e.g., *_id columns).

        The relationships should be inferred based on shared column names, assuming:
        - If two tables have columns with the same name, they are likely related (e.g., store_id in InventoryOH and Store).

        Return the schema in a structured format that includes:
        - Table name
        - Column names
        - Column types
        - Inferred relationships
        """

        # Empty runnable config; the instruction above is not wired into it.
        self.config = RunnableConfig()

    def run(self, query: str) -> dict:
        """
        Explore the database and cache its schema and inferred relationships.

        Args:
            query (str): Accepted for interface compatibility; not used.

        Returns:
            dict: ``{"schema": [...], "relationships": [...]}``. Each schema
            entry has ``table_name`` and ``columns``. Each relationship has
            ``table_name``, ``other_table_name`` and ``shared_columns``.
        """
        # Open a dedicated connection and always release it, even on error.
        conn = connect_to_database(db_file)
        try:
            tables = list_tables.run({"conn": conn})

            # Describe every table exactly once.
            schema = [
                {
                    "table_name": table,
                    "columns": describe_table.run({"conn": conn, "table_name": table}),
                }
                for table in tables
            ]
        finally:
            conn.close()

        # Two tables are related when their column sets intersect. Columns are
        # compared as returned by describe_table (presumably (name, type)
        # pairs), so a shared column must match on both name and type.
        relationships = []
        for table in schema:
            for other_table in schema:
                if table["table_name"] == other_table["table_name"]:
                    continue
                shared_columns = set(table["columns"]) & set(other_table["columns"])
                if shared_columns:
                    relationships.append({
                        "table_name": table["table_name"],
                        "other_table_name": other_table["table_name"],
                        # Sorted so the output is deterministic across runs.
                        "shared_columns": sorted(shared_columns),
                    })

        # Store the schema and relationships in the state
        self.state["schema"] = schema
        self.state["relationships"] = relationships
        return {"schema": schema, "relationships": relationships}

    def get_precomputed_relationships(self):
        """Return the relationships cached by the last run(), or []."""
        return self.state.get("relationships", [])

    def get_schema(self):
        """Return the schema cached by the last run(), or []."""
        return self.state.get("schema", [])

In [None]:
# Schema explorer agent class already installed
# Build the agent; its constructor also creates a Gemini chat model.
schema_explorer_agent = SchemaExplorerAgent()

# `state` aliases the agent's own state dict, so whatever run() stores there
# is visible through this name as well.
state = schema_explorer_agent.state


# Run the SchemaExplorerAgent; the returned dict is kept in `result`.
# The question text is passed as run()'s `query` argument.
result = schema_explorer_agent.run("What is the schema of the database?")

#Print the output from the state dictionary
#print("Schema:", state["schema"])
#print("Relationships:", state["relationships"])
#

# Chatbot libraries, the state definition, and an example chatbot welcome node

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI
from typing import TypedDict
from langgraph.graph import StateGraph, END
from langgraph.graph.graph import CompiledGraph

# Shared Gemini chat model for the chatbot nodes, created with temperature 0.2.
llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    temperature=0.2
)

    # State definition
class SupplyChainState(TypedDict):
    """
    State dictionary passed between the supply chain chatbot nodes.

    Besides the extracted identifiers it declares the keys the agents read
    and write at runtime (``user_input``, ``schema``, ``query_results``).
    """

    # Entities extracted from the user's request.
    store_id: int
    item_id: int
    quantity: int
    # Manager agent's reply text.
    manager_response: str
    # SQL statements produced by the SQL generation step.
    generated_queries: list
    executed_results: dict
    # Raw question typed by the user.
    user_input: str
    # Table/column description interpolated into the agent prompts.
    schema: list
    # One result set per executed query, in the same order as generated_queries.
    query_results: list



    # Example chatbot welcome node

def welcome_node(state: SupplyChainState) -> SupplyChainState:
    """
    Entry node of the supply chain chatbot.

    Logs that it was reached and passes the state through untouched.

    Args:
        state (SupplyChainState): The current chatbot state.
    Returns:
        SupplyChainState: The very same state object, unmodified.
    """
    print(' - Chatbot CALL: welcome_node()')
    return state




Step 4 - Define Agents/Nodes Operation






# Role: Agent 01 - Manager Agent
Acts as the coordinator for user interactions.

Responsibilities:

- Validates user input by checking for missing or ambiguous entities (e.g., item, store, DC).
- Determines whether the query can be executed or needs clarification.
- Routes the task:
  - To sql_generation_agent if all required parameters are present.
  - To clarification_node if additional information is needed.

This agent ensures every query is handled efficiently and passed to the right downstream node.


In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI
from pydantic import BaseModel
import logging, json, re

# Pydantic schema for Gemini extraction #
class GeminiExtraction(BaseModel):
    # Structured-output schema for the extraction call: ManagerAgent passes
    # this class to with_structured_output().
    # All three fields are required integers (no defaults).
    # NOTE(review): a question that omits e.g. the quantity presumably cannot
    # be validated against this model -- confirm the desired behaviour.
    store_id: int
    item_id: int
    quantity: int
# Configure root logging at INFO level for the notebook.
logging.basicConfig(level=logging.INFO)

class ManagerAgent:
    """
    Coordinator that validates the user's request before SQL generation.

    It asks Gemini for a structured ``GeminiExtraction`` (store, item and
    quantity), runs the placeholder resolution helpers, and records the
    extracted values on the state.
    """

    def __init__(self, llm: ChatGoogleGenerativeAI, clarification_node) -> None:
        """
        Args:
            llm (ChatGoogleGenerativeAI): Chat model used for the structured
                extraction and handed to the SQL generation agent.
            clarification_node: Node responsible for follow-up questions. It
                is only stored here; ``run`` does not call it.
        """
        # Extraction instruction, kept on the instance so it can be inspected.
        # NOTE(review): run() sends only the user's text to the model --
        # confirm whether this should be prepended as a system message.
        self.instruction = """

        You're a retail supply chain assistant. Extra store reference (by number or name) and item reference (by number or name) from user questions.

        Extract RAW:
        - store_id
        - item_id
        - quantity

      Respond ONLY in this JSON:
        {
            "store_id": int,
            "item_id": int,
            "quantity": int
        }
        """
        # Use the injected model rather than building a second client, so the
        # caller's configuration is honoured.
        self.llm = llm.with_structured_output(GeminiExtraction)

        self.sql_generation_agent = SQLGenerationAgent(llm)
        self.clarification_node = clarification_node

    def extract_store_ids(self, raw_stores):
        # This is a placeholder for the actual store ID extraction logic.
        # Returns (resolved store ids, store names still to be resolved).
        return [1, 2, 3], []

    def extract_item_ids(self, raw_items):
        # This is a placeholder for the actual item ID extraction logic.
        # Returns (resolved item ids, item names still to be resolved).
        return [101, 102], []

    def extract_quantities(self, raw_quantities):
        # This is a placeholder for the actual quantity extraction logic
        return [10, 20]

    def resolve_store_name(self, store_name):
        # This is a placeholder for the actual store name resolution logic
        return 1

    def resolve_item_name(self, item_name):
        # This is a placeholder for the actual item name resolution logic
        return 101

    def _lookup_store_ids(self, store_name):
        """Return the Store_id values whose Store_name equals *store_name*."""
        # The execute_query tool only accepts a finished SQL string, so the
        # model-extracted name cannot be bound as a parameter. Doubling single
        # quotes keeps it inside the SQL string literal.
        safe_name = str(store_name).replace("'", "''")
        sql = f"SELECT Store_id FROM Store WHERE Store_name = '{safe_name}'"
        rows = execute_query.run({"conn": conn, "sql": sql})
        return [row[0] for row in rows]

    def run(self, state: SupplyChainState) -> SupplyChainState:
        """
        Run the ManagerAgent.

        Args:
            state (SupplyChainState): The current state of the supply chain
                chatbot; ``user_input`` holds the question to analyse.
        Returns:
            SupplyChainState: The same state object. On success its
            ``store_id``, ``item_id`` and ``quantity`` hold the values Gemini
            extracted. It is returned unchanged when there is no input, when
            no store/item is found, or when extraction fails.
        """
        print('Manager Agent: Validating user input using Gemini...')

        query = state.get('user_input', '')
        print(f'Manager Agent: User input: {query}')

        if not query:
            print('Manager Agent: No user input provided.')
            return state

        try:
            # 1 - Raw extraction from Gemini
            extracted = self.llm.invoke(query)
            raw_stores = extracted.store_id
            raw_items = extracted.item_id
            raw_quantities = extracted.quantity

            print(f'Manager Agent: Extracted raw stores: {raw_stores}')
            print(f'Manager Agent: Extracted raw items: {raw_items}')
            print(f'Manager Agent: Extracted raw quantities: {raw_quantities}')

            # 2 - Stores: resolve names that could not be mapped directly
            store_ids, unresolved_store_names = self.extract_store_ids(raw_stores)
            for store_name in unresolved_store_names:
                print(f'Manager Agent: Unresolved store name: {store_name}')
                store_ids.extend(self._lookup_store_ids(store_name))

            # 3 - Items: resolve names, then de-duplicate
            item_ids, unresolved_item_names = self.extract_item_ids(raw_items)
            for item_name in unresolved_item_names:
                print(f'Manager Agent: Unresolved item name: {item_name}')
                item_id = self.resolve_item_name(item_name)
                if item_id is not None:
                    item_ids.append(item_id)
                    print(f'Manager Agent: Resolved item name: {item_name} to item_id: {item_id}')

            item_ids = sorted(set(item_ids))
            print(f'Manager Agent: Extracted item_ids: {item_ids}')

            # 4 - Quantities
            quantities = self.extract_quantities(raw_quantities)
            for quantity in quantities:
                print(f'Manager Agent: Extracted quantity: {quantity}')

            # 5 - Handle missing store and item scenarios
            if not store_ids and not raw_stores:
                print('Manager Agent: No valid stores found.')
                return state

            if not item_ids and not raw_items:
                print('Manager Agent: No valid items found.')
                return state

            print(f'Manager Agent: Valid stores: {store_ids}')
            print(f'Manager Agent: Valid items: {item_ids}')

        except Exception as e:
            # Agent boundary: report the failure and leave the state untouched.
            print(f'Manager Agent: Error during extraction: {e}')
            return state
        else:
            # Record the scalar values from the structured extraction. The
            # resolver helpers are still placeholders that return fixed lists,
            # so their output is not written to the state.
            state["store_id"] = raw_stores
            state["item_id"] = raw_items
            state["quantity"] = raw_quantities

        return state

# SQL Generation Agent

In [None]:
class SQLGenerationAgent:
    """Placeholder agent that will turn a validated request into SQL."""

    def __init__(self, llm: ChatGoogleGenerativeAI) -> None:
        # Chat model kept for the future SQL generation logic.
        self.llm = llm

    def run(self, state: SupplyChainState) -> SupplyChainState:
        """
        Store the generated SQL on the state and return it.

        Currently a stub: it always writes one fixed query.
        """
        print("SQL Generation Agent: Generating SQL queries...")
        # This is a placeholder for the actual SQL generation logic
        placeholder_queries = ["SELECT * FROM Inventory;"]
        state["generated_queries"] = placeholder_queries
        return state

# Test to see how the Manager Agent works

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI
from typing import TypedDict
from pydantic import BaseModel
import logging, json, re

# Gemini chat model injected into the agents built in this cell,
# created with temperature 0.2.
llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    temperature=0.2
)

class SupplyChainState(TypedDict):
    """
    State dictionary passed between the supply chain chatbot nodes.

    Besides the extracted identifiers it declares the keys the agents read
    and write at runtime (``user_input``, ``schema``, ``query_results``).
    """

    # Entities extracted from the user's request.
    store_id: int
    item_id: int
    quantity: int
    # Manager agent's reply text.
    manager_response: str
    # SQL statements produced by the SQL generation step.
    generated_queries: list
    executed_results: dict
    # Raw question typed by the user.
    user_input: str
    # Table/column description interpolated into the agent prompts.
    schema: list
    # One result set per executed query, in the same order as generated_queries.
    query_results: list

# Pydantic schema for Gemini extraction #
class GeminiExtraction(BaseModel):
    # Structured-output schema for the extraction call: ManagerAgent passes
    # this class to with_structured_output().
    # All three fields are required integers (no defaults).
    # NOTE(review): a question that omits e.g. the quantity presumably cannot
    # be validated against this model -- confirm the desired behaviour.
    store_id: int
    item_id: int
    quantity: int
# Configure root logging at INFO level for the notebook.
logging.basicConfig(level=logging.INFO)

class SQLGenerationAgent:
    """Placeholder agent that will turn a validated request into SQL."""

    def __init__(self, llm: ChatGoogleGenerativeAI) -> None:
        # Chat model kept for the future SQL generation logic.
        self.llm = llm

    def run(self, state: SupplyChainState) -> SupplyChainState:
        """
        Store the generated SQL on the state and return it.

        Currently a stub: it always writes one fixed query.
        """
        print("SQL Generation Agent: Generating SQL queries...")
        # This is a placeholder for the actual SQL generation logic
        placeholder_queries = ["SELECT * FROM Inventory;"]
        state["generated_queries"] = placeholder_queries
        return state

class ManagerAgent:
    """
    Coordinator that validates the user's request before SQL generation.

    It asks Gemini for a structured ``GeminiExtraction`` (store, item and
    quantity), runs the placeholder resolution helpers, and records the
    extracted values on the state.
    """

    def __init__(self, llm: ChatGoogleGenerativeAI, clarification_node) -> None:
        """
        Args:
            llm (ChatGoogleGenerativeAI): Chat model used for the structured
                extraction and handed to the SQL generation agent.
            clarification_node: Node responsible for follow-up questions. It
                is only stored here; ``run`` does not call it.
        """
        # Extraction instruction, kept on the instance so it can be inspected.
        # NOTE(review): run() sends only the user's text to the model --
        # confirm whether this should be prepended as a system message.
        self.instruction = """

        You're a retail supply chain assistant. Extra store reference (by number or name) and item reference (by number or name) from user questions.

        Extract RAW:
        - store_id
        - item_id
        - quantity

      Respond ONLY in this JSON:
        {
            "store_id": int,
            "item_id": int,
            "quantity": int
        }
        """
        # Use the injected model rather than building a second client, so the
        # caller's configuration is honoured.
        self.llm = llm.with_structured_output(GeminiExtraction)

        self.sql_generation_agent = SQLGenerationAgent(llm)
        self.clarification_node = clarification_node

    def extract_store_ids(self, raw_stores):
        # This is a placeholder for the actual store ID extraction logic.
        # Returns (resolved store ids, store names still to be resolved).
        return [1, 2, 3], []

    def extract_item_ids(self, raw_items):
        # This is a placeholder for the actual item ID extraction logic.
        # Returns (resolved item ids, item names still to be resolved).
        return [101, 102], []

    def extract_quantities(self, raw_quantities):
        # This is a placeholder for the actual quantity extraction logic
        return [10, 20]

    def resolve_store_name(self, store_name):
        # This is a placeholder for the actual store name resolution logic
        return 1

    def resolve_item_name(self, item_name):
        # This is a placeholder for the actual item name resolution logic
        return 101

    def _lookup_store_ids(self, store_name):
        """Return the Store_id values whose Store_name equals *store_name*."""
        # The execute_query tool only accepts a finished SQL string, so the
        # model-extracted name cannot be bound as a parameter. Doubling single
        # quotes keeps it inside the SQL string literal.
        safe_name = str(store_name).replace("'", "''")
        sql = f"SELECT Store_id FROM Store WHERE Store_name = '{safe_name}'"
        rows = execute_query.run({"conn": conn, "sql": sql})
        return [row[0] for row in rows]

    def run(self, state: SupplyChainState) -> SupplyChainState:
        """
        Run the ManagerAgent.

        Args:
            state (SupplyChainState): The current state of the supply chain
                chatbot; ``user_input`` holds the question to analyse.
        Returns:
            SupplyChainState: The same state object. On success its
            ``store_id``, ``item_id`` and ``quantity`` hold the values Gemini
            extracted. It is returned unchanged when there is no input, when
            no store/item is found, or when extraction fails.
        """
        print('Manager Agent: Validating user input using Gemini...')

        query = state.get('user_input', '')
        print(f'Manager Agent: User input: {query}')

        if not query:
            print('Manager Agent: No user input provided.')
            return state

        try:
            # 1 - Raw extraction from Gemini
            extracted = self.llm.invoke(query)
            raw_stores = extracted.store_id
            raw_items = extracted.item_id
            raw_quantities = extracted.quantity

            print(f'Manager Agent: Extracted raw stores: {raw_stores}')
            print(f'Manager Agent: Extracted raw items: {raw_items}')
            print(f'Manager Agent: Extracted raw quantities: {raw_quantities}')

            # 2 - Stores: resolve names that could not be mapped directly
            store_ids, unresolved_store_names = self.extract_store_ids(raw_stores)
            for store_name in unresolved_store_names:
                print(f'Manager Agent: Unresolved store name: {store_name}')
                store_ids.extend(self._lookup_store_ids(store_name))

            # 3 - Items: resolve names, then de-duplicate
            item_ids, unresolved_item_names = self.extract_item_ids(raw_items)
            for item_name in unresolved_item_names:
                print(f'Manager Agent: Unresolved item name: {item_name}')
                item_id = self.resolve_item_name(item_name)
                if item_id is not None:
                    item_ids.append(item_id)
                    print(f'Manager Agent: Resolved item name: {item_name} to item_id: {item_id}')

            item_ids = sorted(set(item_ids))
            print(f'Manager Agent: Extracted item_ids: {item_ids}')

            # 4 - Quantities
            quantities = self.extract_quantities(raw_quantities)
            for quantity in quantities:
                print(f'Manager Agent: Extracted quantity: {quantity}')

            # 5 - Handle missing store and item scenarios
            if not store_ids and not raw_stores:
                print('Manager Agent: No valid stores found.')
                return state

            if not item_ids and not raw_items:
                print('Manager Agent: No valid items found.')
                return state

            print(f'Manager Agent: Valid stores: {store_ids}')
            print(f'Manager Agent: Valid items: {item_ids}')

        except Exception as e:
            # Agent boundary: report the failure and leave the state untouched.
            print(f'Manager Agent: Error during extraction: {e}')
            return state
        else:
            # Record the scalar values from the structured extraction. The
            # resolver helpers are still placeholders that return fixed lists,
            # so their output is not written to the state.
            state["store_id"] = raw_stores
            state["item_id"] = raw_items
            state["quantity"] = raw_quantities

        return state

# Create an instance of the ClarificationNode
class ClarificationNode:
    """Node meant to ask the user for missing details; currently a pass-through."""

    def __init__(self, llm: ChatGoogleGenerativeAI) -> None:
        # Keep the chat model and a SQL generation agent built on it.
        self.llm = llm
        self.sql_generation_agent = SQLGenerationAgent(llm)

    def run(self, state: SupplyChainState) -> SupplyChainState:
        """Return *state* unchanged (no clarification logic yet)."""
        return state
# Node the manager receives for clarification requests.
clarification_node = ClarificationNode(llm)


# Create an instance of the ManagerAgent
manager_agent = ManagerAgent(llm, clarification_node)

# Create a sample SupplyChainState; only the user's question is supplied.
state = {
    "user_input": "I need to order 10 items of item 101 for store 1."
}

# Run the ManagerAgent (this calls Gemini, so it needs a configured API key).
new_state = manager_agent.run(state)

# Print whatever run() returned.
print(new_state)

# Role: Agent 02 - SQL Generation Agent

Role:
Translates validated user queries into executable, schema-aware SQL code.

Responsibilities:



1. Uses dynamic prompting and metadata to understand the query structure.
2. Leverages schema context from the Digital Twin to generate accurate SQL.
3. Ensures queries align with simulation data and current inventory state.
4. Returns a well-formed SQL query to fetch the requested insights.

This agent acts as the bridge between natural language and database execution.

In [None]:
import json
import re
import google.generativeai as genai
from pydantic import BaseModel
from langchain_core.messages import HumanMessage

def sql_generation_agent(state: SupplyChainState) -> SupplyChainState:
    """
    Generates SQL queries based on the user's request.

    Builds a few-shot prompt from the state, asks Gemini for a JSON array of
    SQL statements and stores the parsed list under ``generated_queries``.

    Args:
        state (SupplyChainState): The current state of the supply chain
            chatbot. Reads ``schema``, ``user_input``, ``store_id``,
            ``item_id`` and ``quantity``.
    Returns:
        SupplyChainState: The same state object with ``generated_queries``
        set to a list of SQL strings (possibly empty).
    """
    print(' - Chatbot CALL: sql_generation_agent()')

    prompt = f"""
  You are an expert SQL Gen Agent for a retail Supply Chain database.
  Schema:Only use the tables & columns listed below.
  {json.dumps(state.get('schema', []))}

----Example----

1) In: "What is the current inventory of item_id=1 at store 5?"
Out:
[
  "SELECT EODQty FROM InventoryOH WHERE store_id=5 AND item_id=1 ORDER BY InventoryDt DESC LIMIT 1;"
]

2) In: "Show me the last 7 days of sales for item 2 at store 3."
Out:
[
  "SELECT Sales_dt, Sales FROM StoreSales WHERE store_id=3 AND item_id=2 ORDER BY Sales_dt DESC LIMIT 7;"
]

3) In: "Compare forecast vs actual for item 4 at store 2 for this month."
Out:
[
  "SELECT Forecast_sales_dt, Forecast FROM StoreFC WHERE store_id=2 AND item_id=4 AND Forecast_sales_dt BETWEEN DATE('now','start of month') AND DATE('now');",
  "SELECT Sales_dt, Sales FROM StoreSales WHERE store_id=2 AND item_id=4 AND Sales_dt BETWEEN DATE('now','start of month') AND DATE('now');"
]

--- END EXAMPLES ---

User Question: {state.get('user_input', '')}
Store ID: {state.get('store_id')}
Item ID: {state.get('item_id')}
Quantity: {state.get('quantity')}

Return **Only** a JSON array of SQL queries.
"""
    print("Prompt sent to model:")
    print(prompt)
    model = genai.GenerativeModel('gemini-1.5-flash')
    response = model.generate_content(prompt)
    print("Response from model:")
    print(response)
    response_text = response.text
    print("Response text:")
    print(response_text)

    queries = _parse_sql_queries(response_text)

    # Basic validation to ensure we don't get INSERT statements for
    # "what is" questions, whichever way the reply was parsed.
    user_input = state.get('user_input') or ''
    if "what is" in user_input.lower():
        queries = [q for q in queries if q.strip().upper().startswith("SELECT")]
        if not queries:
            # Fall back to a safe SELECT, but only when both identifiers are
            # known; interpolating None would produce invalid SQL.
            store_id = state.get('store_id')
            item_id = state.get('item_id')
            if store_id is not None and item_id is not None:
                queries = [f"SELECT EODQty FROM InventoryOH WHERE store_id={store_id} AND item_id={item_id} ORDER BY InventoryDt DESC LIMIT 1;"]

    state["generated_queries"] = queries

    print(f'SQL Generation Agent: Generated queries: {state["generated_queries"]}')
    return state


def _parse_sql_queries(response_text: str) -> list[str]:
    """Pull a list of SQL strings out of the model's reply text."""
    # Accept ```json ... ``` as well as a bare ``` ... ``` fence.
    fenced = re.search(
        r"```(?:json)?\s*(.*?)\s*```", response_text, re.DOTALL | re.IGNORECASE
    )
    json_text = fenced.group(1) if fenced else response_text

    try:
        parsed = json.loads(json_text)
    except json.JSONDecodeError:
        parsed = None

    # A single JSON string is treated as a one-element list.
    if isinstance(parsed, str):
        parsed = [parsed]
    if isinstance(parsed, list):
        # Drop anything that is not a non-empty string (nested objects, nulls).
        return [q for q in parsed if isinstance(q, str) and q.strip()]

    # Fallback: extract SELECT ...; snippets from the raw reply.
    return re.findall(r"SELECT.*?;", response_text, re.DOTALL | re.IGNORECASE)

In [None]:
# Test input state (example)

# The question belongs under "user_input", the key the agent reads for the
# "User Question" line of its prompt. "schema" is omitted, so the agent uses
# its empty default.
state = {
    "store_id": 5,
    "item_id": 1,
    "quantity": 10,
    "user_input": "What is the current inventory of milk at Arizona?"
}

# Assuming the sql_generation_agent function is already defined

# Run the agent with the test input
result = sql_generation_agent(state)

# Print the result
print(result)

In [None]:
# Get the generated SQL query from the result
# Take the first generated statement.
# NOTE(review): this raises IndexError when the agent produced no queries.
sql_query = result["generated_queries"][0]

# Execute the query through the execute_query tool on the module-level connection.
query_result = execute_query.run({"conn": conn, "sql": sql_query})

# Print the result
print(query_result)

Agent 03 - SQL Agent Execution

Role:
Executes the SQL queries and retrieves results from the Digital Twin database.

Responsibilities:

Connects to the simulated supply chain database.
Runs the SQL query generated by the sql_generation_agent.
Fetches and formats the output for easy interpretation.
Passes results downstream for explanation or visualization.

This agent ensures accurate data retrieval and powers real-time decision insights.

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableConfig

class SQLAgentExecution:
    """Agent that runs the generated SQL against the database."""

    def __init__(self, llm: ChatGoogleGenerativeAI) -> None:
        # A dedicated Gemini client is created here; the injected `llm` is
        # only forwarded to the SQL generation agent.
        self.llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.2)
        self.sql_generation_agent = SQLGenerationAgent(llm)

    def run(self, state: SupplyChainState) -> SupplyChainState:
        """
        Execute every statement in ``generated_queries``.

        Args:
            state (SupplyChainState): The current chatbot state.
        Returns:
            SupplyChainState: The same state with ``query_results`` set to one
            result set per executed query (empty list when there were none).
        """
        print(' - Chatbot CALL: sql_agent_execution()')

        pending = state.get("generated_queries", [])
        collected = []
        if pending:
            for statement in pending:
                print(f"SQL Agent Execution: Executing query: {statement}")
                collected.append(execute_query.run({"conn": conn, "sql": statement}))
        else:
            print("SQL Agent Execution: No generated queries found.")

        state["query_results"] = collected
        return state

In [None]:
# Create an instance of the SQLAgentExecution class
# Build the execution agent around the shared chat model.
sql_agent_execution = SQLAgentExecution(llm)

# Create a sample SupplyChainState with a generated SQL query
state = {
    "generated_queries": ["SELECT * FROM InventoryOH LIMIT 5;"]
}

# Run the SQLAgentExecution with the sample state; results are stored under
# the "query_results" key of the returned state.
new_state = sql_agent_execution.run(state)

# Print the result
print(new_state)

Agent 04 - Root Cause Analyzer

Role:
Interprets the data retrieved from the simulation to explain why a supply chain issue occurred.

Responsibilities:

Analyzes historical trends, lead times, calendar constraints, and replenishment parameters.
Detects anomalies and flags potential root causes (e.g., delayed shipments, inaccurate forecasts).
Leverages domain knowledge and prompt chaining to explain issues in human-readable form.
Supports users in understanding complex inventory behaviors and system bottlenecks.

This agent provides transparency into the "why" behind disruptions, enabling smarter corrective actions.

In [None]:
import json
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage

class RootCauseAnalyzerAgent:
  """
  Role:
  Interprets the data retrieved from the simulation to explain why a supply chain issue occurred.

  Responsibilities:
  Analyzes historical trends, lead times, calendar constraints, and replenishment parameters.
  Detects anomalies and flags potential root causes (e.g., delayed shipments, inaccurate forecasts).
  Leverages domain knowledge and prompt chaining to explain issues in human-readable form.
  Supports users in understanding complex inventory behaviors and system bottlenecks.

  This agent provides transparency into the "why" behind disruptions, enabling smarter corrective actions.
  """

  def __init__(self, llm: "ChatGoogleGenerativeAI", google_api_key: str) -> None:
    """
    Create the analyzer together with its own Gemini chat client.

    Args:
        llm (ChatGoogleGenerativeAI): Accepted so the constructor matches the
            other agents; the analyzer builds a separate client below.
        google_api_key (str): API key passed to the Gemini client.
    """
    self.llm = ChatGoogleGenerativeAI(
        model="gemini-1.5-flash",
        temperature=0.2,
        google_api_key=google_api_key
    )

  def _build_prompt(self, user_q: str, schema, qr_block: str) -> str:
    """
    Assemble the analysis prompt sent to the LLM.

    Args:
        user_q (str): The user's original question.
        schema: Schema description taken from the state; serialized as JSON.
        qr_block (str): Text listing each executed query with its result.
    Returns:
        str: The complete prompt.
    """
    # default=str keeps prompt construction from failing on values that JSON
    # cannot encode; the output is only ever read by the LLM.
    schema_json = json.dumps(schema, default=str)
    return f"""
    You are an expert Root Cause Analyzer for a retail Supply Chain database.
    Schema: the data comes from the tables & columns listed below.
    {schema_json}

    User question:
    {user_q}

    SQL queries that were executed and their results:
    {qr_block}

    Using only the information above, explain the most likely root cause of
    the situation the user asked about (e.g., delayed shipments, inaccurate
    forecasts). If the results are empty or insufficient, say so instead of
    guessing.

    Respond in the plain language, without SQL code or other explanations.
    """

  def run(self, state: "SupplyChainState") -> "SupplyChainState":
    """
    Executes the Root Cause Analysis Agent.
    Args:
        state (SupplyChainState): The current state of the supply chain chatbot.
    Returns:
        SupplyChainState: The updated state of the supply chain chatbot, with
        the LLM's explanation stored under 'root_cause'.
    """
    print(' - Chatbot CALL: root_cause_analyzer_agent()')

    user_q = state.get('user_input', '')
    queries = state.get('generated_queries', [])
    results = state.get('query_results', [])
    schema = state.get('schema', [])

    # Pair up queries and results into a block for the LLM. zip() stops at the
    # shorter list, so a query without a result is left out.
    qr_block = "\n".join(
        f"Query: {q}\nResult: {res}" for q, res in zip(queries, results)
    ) or "No queries or results to analyze."
    print(f'Root Cause Analyzer Agent: QR Block: {qr_block}')

    prompt = self._build_prompt(user_q, schema, qr_block)

    response = self.llm.invoke([HumanMessage(content=prompt)])
    response_text = response.content
    print(f'Root Cause Analyzer Agent: Response: {response_text}')
    state['root_cause'] = response_text
    return state

# Root cause analyzer

In [None]:
# Step 1 - Grab your API key

from google.colab import userdata

GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
#print(f"GOOGLE_API_KEY: {GOOGLE_API_KEY}")

# Step 2 - Build a hand-written state: one query paired with one canned result
# (a single row whose only column is 0), so no database call is needed here.
test_state = {
      "user_input": "What is the current inventory of eggs in Arizona?",
      "generated_queries": [
              "SELECT EODQty FROM InventoryOH WHERE Store_id=5 AND Item_id = 1 ORDER By InventoryDt DESC LIMIT 1;",

      ], "query_results":[[
          (0,)]

      ]

}

# Step 3 - Instantiate the analyzer with the API key

analyzer = RootCauseAnalyzerAgent(llm, GOOGLE_API_KEY)

# Step 4 - Run the analyzer and print the output

output_state = analyzer.run(test_state)
print("🔍 Root Cause Summary:\n", output_state["root_cause"])

Step 5 - Initialize the LangGraph

This step sets up the multi-agent workflow using LangGraph. Each agent is represented as a node in the graph, with clearly defined transitions based on the outcome of their tasks.

The graph architecture includes:

Manager Agent
SQL Generation Agent
SQL Execution Agent
Root Cause Analyzer
Clarification Node (for incomplete inputs)

We'll define the nodes, their logic, and the edges connecting them to ensure a seamless flow of tasks—from user input to final response.


In [None]:
from langgraph.graph import StateGraph, END
import logging

# Emit INFO-level (and above) log records from the root logger.
logging.basicConfig(level=logging.INFO)

# Create a StateGraph instance.
# This module-level graph is the one SupplyChainGraph.run() adds nodes and
# edges to, so calling run() more than once reuses the same instance.
graph = StateGraph(SupplyChainState)

class SupplyChainGraph:
    """Wires the supply chain agents into the module-level LangGraph ``graph``."""

    def __init__(self) -> None:
        """Create (or reference) the callable behind every node of the workflow."""
        self.welcome = welcome_node
        self.clarifier = ClarificationNode(llm)
        self.manager = ManagerAgent(llm, self.clarifier)
        self.sql_generation = sql_generation_agent
        self.sql_execution = SQLAgentExecution(llm)
        self.root_cause_analyzer = RootCauseAnalyzerAgent(llm, GOOGLE_API_KEY)
        self.schema_explorer = SchemaExplorerAgent()

    def run(self):
        """
        Register all nodes and edges on ``graph`` and compile it.

        Returns:
            The compiled graph produced by ``graph.compile()``.
        """
        graph.add_node("welcome", self.welcome)
        graph.add_node("clarifier", self.clarifier.run)
        graph.add_node("manager", self.manager.run)
        graph.add_node("sql_generation", self.sql_generation)
        graph.add_node("sql_execution", self.sql_execution.run)
        graph.add_node("root_cause_analyzer", self.root_cause_analyzer.run)
        graph.add_node("schema_explorer", self.schema_explorer.run)

        graph.set_entry_point("welcome")

        # ---Edges/Routing---

        # Welcome -> manager
        graph.add_edge("welcome", "manager")

        # Clarifier -> SQL Generation
        graph.add_edge("clarifier", "sql_generation")

        # After manager -> clarification or schema exploration

        def after_manager(state):
            """Pick the next node name based on the state's 'clarify' flag."""
            if state.get("clarify", False):
                return "clarifier"
            return "schema_explorer"

        # The mapping keys must equal the strings after_manager() returns.
        # The key used to be "clarify", which never matched the returned
        # "clarifier" and broke the clarification route.
        graph.add_conditional_edges("manager", after_manager, {
            "clarifier": "clarifier",
            "schema_explorer": "schema_explorer",
            # Kept from the original mapping; after_manager() never returns "end".
            "end": END,
        })

        # fixed linear flow from schema -> SQL gen -> SQL exec -> root cause -> END

        graph.add_edge("schema_explorer", "sql_generation")
        graph.add_edge("sql_generation", "sql_execution")
        graph.add_edge("sql_execution", "root_cause_analyzer")
        graph.add_edge("root_cause_analyzer", END)

        # Compile the graph into a runnable application
        return graph.compile()

# Instantiate the builder and compile the graph: SupplyChainGraph.run() only
# registers nodes/edges and returns graph.compile(); nothing is executed yet.
app = SupplyChainGraph().run()

In [None]:
# Run the graph with a sample user query.
# Only "user_input" is supplied in the initial state; each streamed event is
# printed as it arrives.
for event in app.stream({
    "user_input": "What is the current inventory of item 1 at store 5?"
}):
    print(event)

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableConfig
from langgraph.graph import StateGraph, END

class SQLAgentExecution:
    """Agent that executes the SQL produced by the SQL Generation Agent."""

    def __init__(self, llm: ChatGoogleGenerativeAI) -> None:
        """
        Build the agent's own Gemini client and a SQL generation agent.

        Args:
            llm (ChatGoogleGenerativeAI): Forwarded to ``SQLGenerationAgent``;
                ``self.llm`` is a separately constructed client.
        """
        gemini_settings = {"model": "gemini-1.5-flash", "temperature": 0.2}
        self.llm = ChatGoogleGenerativeAI(**gemini_settings)
        self.sql_generation_agent = SQLGenerationAgent(llm)

    def run(self, state: SupplyChainState) -> SupplyChainState:
        """
        Run each generated SQL statement and record the results on the state.

        Args:
            state (SupplyChainState): The current state of the supply chain chatbot.
        Returns:
            SupplyChainState: The same state object, with ``query_results``
            holding one entry per executed statement (empty when there were
            no statements).
        """
        print(' - Chatbot CALL: sql_agent_execution()')

        pending = state.get("generated_queries", [])
        collected = []
        if pending:
            for statement in pending:
                print(f"SQL Agent Execution: Executing query: {statement}")
                collected.append(
                    execute_query.run({"conn": conn, "sql": statement})
                )
        else:
            print("SQL Agent Execution: No generated queries found.")

        state["query_results"] = collected
        return state

# Create a new instance of the SQLAgentExecution class
sql_agent_execution = SQLAgentExecution(llm)

# Create a new graph (separate from the earlier module-level `graph`)
workflow = StateGraph(SupplyChainState)

# Add the nodes to the graph
workflow.add_node("welcome", welcome_node)
workflow.add_node("manager_agent", manager_agent.run)
workflow.add_node("sql_generation_agent", sql_generation_agent)
workflow.add_node("sql_agent_execution", sql_agent_execution.run)
workflow.add_node("clarification_node", clarification_node.run)

# Set the entry point of the graph
workflow.set_entry_point("welcome")

# Add the edges to the graph: a straight line
# welcome -> manager_agent -> sql_generation_agent -> sql_agent_execution -> END.
# NOTE(review): no edge in this cell leads INTO "clarification_node"; it only
# has an outgoing edge to END. Confirm whether a conditional route from
# "manager_agent" was intended.
workflow.add_edge("welcome", "manager_agent")
workflow.add_edge("manager_agent", "sql_generation_agent")
workflow.add_edge("sql_generation_agent", "sql_agent_execution")
workflow.add_edge("sql_agent_execution", END)
workflow.add_edge("clarification_node", END)

# Compile the graph. This rebinds the name `app`, replacing any compiled graph
# previously stored under that name.
app = workflow.compile()

In [None]:
# Welcome node or nodes.py

class WelcomeNode:
  """Node that prints the welcome prompt and flags the state for clarification."""

  def __init__(self, llm: "ChatGoogleGenerativeAI") -> None:
    """
    Store the chat model on the instance.

    Args:
        llm (ChatGoogleGenerativeAI): Chat model kept as ``self.llm``.
    """
    # The constructor only stores dependencies. The greeting and the state
    # update live in run(), because __init__ receives no state and must
    # return None.
    self.llm = llm

  def run(self, state: "SupplyChainState") -> "SupplyChainState":
    """
    Greet the user and mark the state as needing clarification.

    Args:
        state (SupplyChainState): The current state of the supply chain chatbot.
    Returns:
        SupplyChainState: The same state object with ``clarify`` set to True.
    """
    print(' - Chatbot CALL: welcome_node()')
    print("Ask me why a product is out of stock at a store, and I'll investigate")
    state["clarify"] = True
    return state



# Testing Stockout

In [None]:
def run(state: SupplyChainState) -> SupplyChainState:
  """
  Build the supply chain graph, stream one investigation, and return the result.

  Args:
      state (SupplyChainState): Initial values for the run. Keys that are
          missing fall back to the sample scenario defined below, and
          "schema" defaults to the schema read by SchemaExplorerAgent.
  Returns:
      SupplyChainState: The initial state with every streamed node update
      merged into it.
  """
  graph = StateGraph(SupplyChainState)

  # Add nodes and edges to the graph
  welcome = welcome_node
  clarifier = ClarificationNode(llm)
  manager = ManagerAgent(llm, clarifier)
  sql_generation = sql_generation_agent
  sql_execution = SQLAgentExecution(llm)
  root_cause_analyzer = RootCauseAnalyzerAgent(llm, GOOGLE_API_KEY)
  schema_explorer = SchemaExplorerAgent()

  graph.add_node("welcome", welcome)
  graph.add_node("clarifier", clarifier.run)
  graph.add_node("manager", manager.run)
  graph.add_node("sql_generation", sql_generation)
  graph.add_node("sql_execution", sql_execution.run)
  graph.add_node("root_cause_analyzer", root_cause_analyzer.run)
  graph.add_node("schema_explorer", schema_explorer.run)

  graph.set_entry_point("welcome")

  # ---Edges/Routing---

  # Welcome -> manager
  graph.add_edge("welcome", "manager")

  # Clarifier -> SQL Generation
  graph.add_edge("clarifier", "sql_generation")

  # After manager -> clarification or schema exploration

  def after_manager(state):
    """Pick the next node name based on the state's 'clarify' flag."""
    if state.get("clarify", False):
      return "clarifier"
    return "schema_explorer"

  # The mapping keys must equal the strings after_manager() returns. The key
  # used to be "clarify", which never matched the returned "clarifier".
  graph.add_conditional_edges("manager", after_manager, {
      "clarifier": "clarifier",
      "schema_explorer": "schema_explorer",
      # Kept from the original mapping; after_manager() never returns "end".
      "end": END,
  })

  # fixed linear flow from schema -> SQL gen -> SQL exec -> root cause -> END

  graph.add_edge("schema_explorer", "sql_generation")
  graph.add_edge("sql_generation", "sql_execution")
  graph.add_edge("sql_execution", "root_cause_analyzer")
  graph.add_edge("root_cause_analyzer", END)

  # Compile the graph before streaming
  app = graph.compile()

  # Get the actual schema to provide to the agent
  # NOTE(review): SchemaExplorerAgent.run is given "" instead of a state here,
  # as before - confirm that it accepts this.
  schema_agent = SchemaExplorerAgent()
  db_schema = schema_agent.run("")

  # Sample scenario used for any key the caller did not supply.
  defaults = {"user_input": "What is the current inventory of item 2 at store 5?",
              "store_id": 5,
              "item_id": 2,
              "quantity": None,
              "schema": db_schema, # Pass the correct schema
              "clarify": False, # Should be False to follow the main path
              "executed_results": '',
              "generated_queries": [],
              "query_results": [],
              "root_cause": ""
              }

  # Honour the caller's state instead of overwriting it; the caller's values
  # win over the defaults and the caller's dict is left untouched.
  run_state = {**defaults, **(state or {})}

  # Each streamed event is expected to map a node name to that node's state
  # update; fold the updates together so the final state can be returned.
  final_state = dict(run_state)
  for event in app.stream(run_state):
    print(event)
    for node_update in event.values():
      if isinstance(node_update, dict):
        final_state.update(node_update)

  return final_state

if __name__ == "__main__":
  # Minimal starting state handed to run(): a sample inventory question with
  # its store/item ids already filled in and empty result fields. It carries
  # no "schema" key.
  initial_state = {
    "user_input": "What is the current inventory of item 2 at store 5?",
    "store_id": 5,
    "item_id": 2,
    "quantity": None,
    "clarify": False,
    "executed_results":'',
    "generated_queries": [],
    "query_results": [],
    "root_cause": ""
  }
  run(initial_state)

In [None]:
!pip install streamlit

#We will use Streamlit as a UI to query the inventory, and name the agent SCEPTRE

In [None]:
%%writefile app.py

import json
import os
import re
import sqlite3
from typing import TypedDict

import streamlit as st
from langchain.tools import tool
from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.graph import StateGraph, END
from pydantic import BaseModel

# --- Database Setup and Tools ---

# This assumes 'database.db' is in the same directory: the path is relative,
# so it is resolved against the process's current working directory.
# In a real app, you'd manage this connection more robustly.
db_file = 'database.db'
# check_same_thread=False lets threads other than the creating one use this
# connection. NOTE(review): presumably needed because Streamlit runs the
# script outside the thread that imported it - confirm.
conn = sqlite3.connect(db_file, check_same_thread=False)

@tool
def list_tables() -> list[str]:
    """Retrieves a list of tables in the database."""
    # sqlite_master holds one row per schema object; only the names of
    # entries of type 'table' are selected, so every row is a 1-tuple.
    table_rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table';"
    ).fetchall()
    return [table_name for (table_name,) in table_rows]

@tool
def describe_table(table_name: str) -> list[tuple[str, str]]:
    """Describes the schema of a specific table."""
    # pragma_table_info() is the table-valued form of PRAGMA table_info.
    # Unlike the PRAGMA statement it accepts a bound parameter, so the table
    # name is passed as data and never spliced into the SQL text.
    cursor = conn.cursor()
    cursor.execute(
        "SELECT name, type FROM pragma_table_info(?) ORDER BY cid;",
        (table_name,),
    )
    # One (column name, declared type) pair per column, in column order;
    # an unknown table yields an empty list.
    return [(col_name, col_type) for col_name, col_type in cursor.fetchall()]

@tool
def execute_query(sql: str) -> list[list[str]]:
    """Executes an SQL query and returns the results."""
    # NOTE(review): the statement is executed verbatim and the SQL passed in by
    # SQLAgentExecution is LLM-generated; consider restricting this tool to
    # read-only statements.
    # fetchall() returns the rows as a list of tuples.
    return conn.execute(sql).fetchall()

# --- State and Agent Definitions ---

# Shared state dictionary passed between the graph nodes in this file.
class SupplyChainState(TypedDict):
    # Raw question typed by the user; read by the manager and the prompts.
    user_input: str
    # Values extracted from the question by ManagerAgent.
    store_id: int | None
    item_id: int | None
    quantity: int | None
    # Set by ManagerAgent; True routes the graph to the clarification node.
    clarify: bool
    # Filled by SchemaExplorerAgent: table name -> describe_table() output.
    schema: dict
    # SQL statements produced by sql_generation_agent.
    generated_queries: list
    # One entry per generated query, written by SQLAgentExecution
    # (result rows, or an error string when execution failed).
    query_results: list
    # Final plain-language answer (or the clarification message).
    root_cause: str

# Use the API key set in the environment.
# os.getenv returns None when GOOGLE_API_KEY is unset, and that None is passed
# straight through to the client below.
api_key = os.getenv('GOOGLE_API_KEY')
# Module-level chat model shared by ManagerAgent, sql_generation_agent and
# RootCauseAnalyzerAgent.
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.2, google_api_key=api_key)

# Structured-output schema that ManagerAgent requests from the LLM.
# All three fields are declared without defaults, i.e. they are required.
class GeminiExtraction(BaseModel):
    store_id: int
    item_id: int
    quantity: int

def welcome_node(state: SupplyChainState) -> SupplyChainState:
    """Log that the welcome node ran and hand the state back unchanged."""
    print(' - Chatbot CALL: welcome_node()')
    return state

class ManagerAgent:
    """Extracts store, item and quantity from the user's question."""

    def __init__(self, llm_instance):
        """Wrap the given model so it returns ``GeminiExtraction`` objects."""
        self.llm = llm_instance.with_structured_output(GeminiExtraction)

    def run(self, state: SupplyChainState) -> SupplyChainState:
        """
        Fill ``store_id``, ``item_id`` and ``quantity`` from ``user_input``.

        An empty question leaves the state untouched. When extraction fails
        for any reason, ``clarify`` is set to True; otherwise it is set to
        False.

        Args:
            state (SupplyChainState): The current state of the chatbot.
        Returns:
            SupplyChainState: The same state object, updated in place.
        """
        print('Manager Agent: Validating user input...')
        question = state.get('user_input', '')
        if not question:
            return state

        try:
            parsed = self.llm.invoke(question)
            # Copy the fields one at a time, in declaration order.
            for field_name in ('store_id', 'item_id', 'quantity'):
                state[field_name] = getattr(parsed, field_name)
            print(f"Manager Agent: Extracted: store_id={parsed.store_id}, item_id={parsed.item_id}, quantity={parsed.quantity}")
            state['clarify'] = False
        except Exception as e:
            print(f"Manager Agent: Error during extraction: {e}")
            # If extraction fails, ask for clarification
            state['clarify'] = True
        return state

class SchemaExplorerAgent:
    """Reads the column layout of every table and stores it on the state."""

    def run(self, state: SupplyChainState) -> SupplyChainState:
        """
        Populate ``state['schema']`` with ``{table name: describe_table output}``.

        Args:
            state (SupplyChainState): The current state of the chatbot.
        Returns:
            SupplyChainState: The same state object with ``schema`` set.
        """
        print("Schema Explorer Agent: Exploring schema...")
        # Tools receive their arguments as a single dict through invoke().
        # Calling .run() without the positional tool input (as .run() or
        # .run(table_name=...)) raises TypeError before the tool body runs.
        tables = list_tables.invoke({})
        state['schema'] = {
            table: describe_table.invoke({"table_name": table})
            for table in tables
            # sqlite_sequence is SQLite's own AUTOINCREMENT bookkeeping table.
            if table != 'sqlite_sequence'
        }
        return state

def sql_generation_agent(state: SupplyChainState) -> SupplyChainState:
    print(f' - Chatbot CALL: sql_generation_agent()')
    prompt = f"""
    You are an expert SQL Gen Agent. Based on the user question and the database schema, generate the correct SQL query.
    Schema: {json.dumps(state.get('schema', {}), indent=2)}
    User Question: {state.get('user_input', '')}
    Store ID: {state.get('store_id')}
    Item ID: {state.get('item_id')}
    Return **Only** a JSON array of SQL queries. For example: ["SELECT * FROM Store;"]
    """
    response = llm.invoke(prompt)
    try:
        # Extract JSON from the response, which might be in a markdown block
        match = re.search(r"```json\n(.*)\n```", response.content, re.DOTALL)
        json_text = match.group(1) if match else response.content
        queries = json.loads(json_text)
        state["generated_queries"] = queries
    except (json.JSONDecodeError, AttributeError) as e:
        print(f"SQL Generation Error: {e}")
        state["generated_queries"] = [f"SELECT EODQty FROM InventoryOH WHERE store_id={state.get('store_id')} AND item_id={state.get('item_id')} ORDER BY InventoryDt DESC LIMIT 1;"]
    print(f"SQL Generation Agent: Generated queries: {state['generated_queries']}")
    return state

class SQLAgentExecution:
    """Runs each generated SQL statement and records its rows or its error."""

    def run(self, state: SupplyChainState) -> SupplyChainState:
        """
        Execute ``state["generated_queries"]`` and store the outcomes.

        Args:
            state (SupplyChainState): The current state of the chatbot.
        Returns:
            SupplyChainState: The same state object; ``query_results`` holds
            one entry per query, either the fetched rows or an error string.
        """
        print(' - Chatbot CALL: sql_agent_execution()')
        queries = state.get("generated_queries", [])
        results = []
        for sql in queries:
            try:
                # Tools receive their arguments as a single dict. Calling
                # run(sql=sql) omits the required tool input, so the call
                # fails with TypeError before the query is ever executed.
                results.append(execute_query.invoke({"sql": sql}))
            except Exception as e:
                # Deliberately best-effort: a failing statement is reported in
                # place of its rows so the remaining queries still run.
                results.append(f"Error executing query '{sql}': {e}")
        state["query_results"] = results
        return state

class RootCauseAnalyzerAgent:
    """Turns the executed queries and their results into a plain-language answer."""

    def run(self, state: SupplyChainState) -> SupplyChainState:
        """
        Ask the LLM to explain the query results and store the reply.

        Args:
            state (SupplyChainState): The current state of the chatbot.
        Returns:
            SupplyChainState: The same state object with ``root_cause`` set
            to the LLM's response text.
        """
        print(' - Chatbot CALL: root_cause_analyzer_agent()')
        question = state.get('user_input', '')
        executed = state.get('generated_queries', [])
        outcomes = state.get('query_results', [])
        # Keyed by query text, so a repeated query keeps only its last result.
        findings = json.dumps(dict(zip(executed, outcomes)), indent=2)
        prompt = f"""
        You are an expert Root Cause Analyzer. Based on the user's question and the results of the SQL queries, provide a plain language explanation.
        User Question: {question}
        SQL Queries and Results: {findings}
        Respond in plain language.
        """
        answer = llm.invoke(prompt)
        state['root_cause'] = answer.content
        return state

def clarification_node(state: SupplyChainState) -> SupplyChainState:
    """Store a request for clearer input as the run's ``root_cause`` message."""
    print("Clarification Node: Asking user for more information.")
    apology = "I'm sorry, I couldn't understand your request. Please provide a clear store ID, item ID, and quantity."
    state['root_cause'] = apology
    return state

# --- Graph Definition ---
def build_graph():
    """
    Assemble and compile the Sceptre workflow graph.

    Flow: welcome -> manager, then either clarifier -> END or
    schema_explorer -> sql_generation -> sql_execution ->
    root_cause_analyzer -> END.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    workflow = StateGraph(SupplyChainState)

    manager_agent = ManagerAgent(llm)
    schema_explorer = SchemaExplorerAgent()
    sql_executor = SQLAgentExecution()
    root_cause_analyzer = RootCauseAnalyzerAgent()

    # Node name -> callable, registered in insertion order.
    node_callables = {
        "welcome": welcome_node,
        "manager": manager_agent.run,
        "clarifier": clarification_node,
        "schema_explorer": schema_explorer.run,
        "sql_generation": sql_generation_agent,
        "sql_execution": sql_executor.run,
        "root_cause_analyzer": root_cause_analyzer.run,
    }
    for node_name, node_fn in node_callables.items():
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point("welcome")
    workflow.add_edge("welcome", "manager")

    def after_manager(state):
        # A truthy 'clarify' flag sends the run to the clarifier.
        if state.get("clarify"):
            return "clarifier"
        return "schema_explorer"

    workflow.add_conditional_edges(
        "manager",
        after_manager,
        {"clarifier": "clarifier", "schema_explorer": "schema_explorer"},
    )

    fixed_edges = (
        ("schema_explorer", "sql_generation"),
        ("sql_generation", "sql_execution"),
        ("sql_execution", "root_cause_analyzer"),
        ("root_cause_analyzer", END),
        ("clarifier", END),
    )
    for source, target in fixed_edges:
        workflow.add_edge(source, target)

    return workflow.compile()

# --- Streamlit UI ---
st.title("Sceptre AI Supply Chain Assistant")
st.write("Ask me to investigate stockouts or inventory levels (e.g., 'Why is item 1 out of stock at store 5?')")

user_input = st.text_input("Your question:")

if user_input:
    with st.spinner("Sceptre is investigating..."):
        # Build and compile the graph
        app = build_graph()

        # Initial state
        initial_state = {
            "user_input": user_input,
            "clarify": False,
        }

        # Stream the events and fold each node's update into one running
        # state. Events are keyed by node name, and depending on the LangGraph
        # version no END-keyed event may ever be emitted, so relying on
        # `END in event` can leave the summary empty.
        final_state = dict(initial_state)
        for event in app.stream(initial_state):
            st.write(event)
            for node_update in event.values():
                if isinstance(node_update, dict):
                    final_state.update(node_update)

        st.divider()
        st.header("Investigation Summary")
        if final_state.get('root_cause'):
            st.write(final_state['root_cause'])
        else:
            st.write("Sorry, I couldn't determine the root cause.")

# Task
Explain the error in the following code. If possible, fix the error and incorporate the changes into the existing code. Otherwise, try to diagnose the error.

```python
%%sql
SELECT
    S.StoreState,
    S.StoreZip,

    SUM(I.SalesDollar) AS SumSalesDollar,
    SUM(I.SalesQuantity) AS SumSalesQuantity,
    SUM(I.InventoryOnHand) AS SumInventoryOnHand,
    SUM(I.onOrder) AS SumOnOrder,
    SUM(I.InventoryOnHand) + SUM(I.onOrder) AS TotalInventory
FROM
    InventoryOH I
JOIN
    Store S ON I.Store = S.Store
GROUP BY
    S.StoreState, S.StoreZip
ORDER BY
    S.StoreState, S.StoreZip;
```