## Prepare NLP

In [1]:
import google.generativeai as genai


class GeminiNLP:
    def __init__(self, gemini_client: genai):
        self.genai = gemini_client

    def chat(self, model_name, instructions, messages):
        model = self.genai.GenerativeModel(
            model_name=model_name, system_instruction=instructions
        )
        response = model.generate_content(messages)
        return response.text

    def struct_output(self, model_name, instructions, messages, structure):
        model = self.genai.GenerativeModel(
            model_name=model_name, system_instruction=instructions
        )
        response = model.generate_content(
            contents=messages,
            generation_config={
                "response_mime_type": "application/json",
                "response_schema": structure,
            },
        )
        return structure.model_validate_json(response.text)

    def func_call(self, model_name, messages, instructions, func):
        model = self.genai.GenerativeModel(
            model_name=model_name, system_instruction=instructions
        )
        try:
            response = model.generate_content(messages, tools=[func])
            call = response.candidates[0].content.parts[0].function_call

            if call:
                try:
                    result = func(**call.args)
                    return result
                except Exception as e:
                    args_dict = dict(call.args)
                    return f"Error when calling {call.name} with args {args_dict}: {e}"

        except Exception as e:
            return f"Error during model generation: {e}"

        return None

In [2]:
import numpy as np


def normalize_embeddings(vectors: list[list[float]]) -> list[np.ndarray]:
    """Scale each vector to unit L2 norm.

    Args:
        vectors: Iterable of numeric vectors.

    Returns:
        A list with one float ``np.ndarray`` per input vector. Zero vectors
        cannot be scaled and are returned unchanged (as arrays), so every
        element of the result has the same type.
    """
    normalized = []
    for vec in vectors:
        arr = np.asarray(vec, dtype=float)
        # Compute the norm once per vector; it is needed for both the test and the division.
        norm = np.linalg.norm(arr)
        normalized.append(arr / norm if norm != 0 else arr)
    return normalized


class CohereEmbeddings:
    """Batch text embedding through an async Cohere client, with L2 normalization."""

    def __init__(self, cohere_client):
        """Store the client; it must provide an awaitable ``embed(...)`` method."""
        self.cohere_client = cohere_client

    async def embed(
        self,
        list_of_text: list[str],
        model_name="embed-v4.0",
        batch_size=10,
        input_type="search_document",
        output_dimension=1024,
    ) -> list[list[float]]:
        """Embed ``list_of_text`` in batches and return normalized vectors.

        Args:
            list_of_text: Texts to embed; an empty list yields an empty result.
            model_name: Model name forwarded to the client.
            batch_size: Number of texts sent per request.
            input_type: Forwarded to the client. The default keeps the previous
                hard-coded value; pass ``"search_query"`` when embedding user
                queries rather than stored documents (see the Cohere Embed docs).
            output_dimension: Forwarded to the client; must match the dimension
                of the vectors already stored in the target collection.

        Returns:
            One vector per input text, in input order, each passed through
            ``normalize_embeddings``.
        """
        vectors = []
        for start in range(0, len(list_of_text), batch_size):
            batch = list_of_text[start : start + batch_size]
            response = await self.cohere_client.embed(
                texts=batch,
                model=model_name,
                input_type=input_type,
                embedding_types=["float"],
                output_dimension=output_dimension,
            )
            vectors.extend(normalize_embeddings(response.embeddings.float))
        return vectors

In [3]:
import os

import google.generativeai as genai
from cohere import AsyncClientV2

# Credentials come from the environment so they never end up in the notebook
# or its version history. Export GOOGLE_API_KEY and CO_API_KEY before starting
# the kernel; a missing variable fails fast with a KeyError.
# NOTE: any key that was previously committed in this cell should be treated as
# compromised and rotated.
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
nlp = GeminiNLP(gemini_client=genai)

cohere_client = AsyncClientV2(api_key=os.environ["CO_API_KEY"])
embedding = CohereEmbeddings(cohere_client=cohere_client)

## Prepare Data

In [4]:
import chromadb


class ChromaProvider:
    """Minimal helper around a persistent Chroma client stored at ``path``."""

    def __init__(self, path):
        """Remember the storage path; no connection is opened until ``connect()``."""
        self.path = path
        self.client = None

    def connect(self):
        """Open the persistent client for ``self.path``."""
        self.client = chromadb.PersistentClient(path=self.path)

    def _require_client(self):
        """Return the connected client, or raise a clear error if ``connect()`` was skipped."""
        if self.client is None:
            raise RuntimeError(
                "ChromaProvider is not connected; call connect() first."
            )
        return self.client

    def _collection(self, name):
        """Look up an existing collection by name."""
        return self._require_client().get_collection(name=name)

    def create_collection(self, name):
        """Create collection ``name`` with cosine distance and ``ef_construction=200``."""
        self._require_client().create_collection(
            name=name,
            configuration={"hnsw": {"space": "cosine", "ef_construction": 200}},
        )

    def add_points(self, collection_name, ids, embeddings, metadata):
        """Add ``embeddings`` with their ``ids`` and ``metadata`` to the collection."""
        self._collection(collection_name).add(
            ids=ids, embeddings=embeddings, metadatas=metadata
        )

    def semantic_search(self, collection_name, vector, top_k):
        """Query the collection and return the ``"metadatas"`` entry of the result.

        ``vector`` is passed as ``query_embeddings``; ``top_k`` as ``n_results``.
        """
        results = self._collection(collection_name).query(
            query_embeddings=vector, n_results=top_k
        )
        return results["metadatas"]

    def metadata_filter(self, collection_name, key, value):
        """Return the metadata of every record whose ``key`` equals ``value``."""
        results = self._collection(collection_name).get(where={key: value})
        return results["metadatas"]

In [5]:
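# NOTE(review): the commented-out cells below look like a one-time ingestion step
# (load chunks, embed them, fill the "sales" collection) — confirm before deleting.
# import json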


# with open("data/sales.json", "r") as f:
#     data = json.load(f)

# chunks = data["chunks"]

In [6]:
# ids = [f"id_{i+1}" for i in range(len(chunks))]
# text_chunks = [chunk["description"] for chunk in chunks]

In [7]:
# import asyncio

# embeddings_list = list()
# for i in range(0, len(text_chunks), 20):
#     batch = text_chunks[i : i + 20]
#     batch_embeddings = await embedding.embed(batch)
#     embeddings_list.extend(batch_embeddings)
#     await asyncio.sleep(5)

In [8]:
# Open the on-disk Chroma store under the relative "chromadb" directory.
vectordb = ChromaProvider("chromadb")
vectordb.connect()

In [9]:
# vectordb.create_collection("sales")
# vectordb.add_points("sales", ids, embeddings_list, chunks)

## Agentic Workflow

### Prompts

In [10]:
# Planner prompt template. It has a single `{user_message}` placeholder and is
# rendered with `str.format`, so any literal brace added to this text later
# must be doubled (`{{` / `}}`).
PROMPT_PLAN = """Your responsibility is to act as an expert query planner for an ERP system. Your goal is to deconstruct a user's question into a set of strategic search queries. These queries will be used to search a vector store containing column metadata (table_name, column_name, column_description).

**Crucial Context:** The `column_description` field explains the **business purpose** of each column (e.g., "Unique identifier for a sales order," "Net amount for an invoice line," "Customer master record key"). Your search queries must target this *business meaning*, not just the literal words in the user's question.

**Your Process:**
1.  **Analyze the User's Question:** Break it down into its core components:
    * **Metrics:** What is being measured? (e.g., total amount, count of items, average price).
    * **Dimensions/Filters:** What is the data being grouped or filtered by? (e.g., by customer, in a specific date range, for a product category, by region).
    * **Relationships:** What business concepts need to be linked? (e.g., "sales *for a customer*" implies joining `sales` and `customer` tables).

2.  **Generate Strategic Queries:** Based on your analysis, create a list of queries designed to find the specific columns needed to build the final SQL.
    * **For Metrics:** Generate queries for the *business concept* of the metric.
        * *Example:* If the user asks for "total revenue," search for `"column for total sales amount"` or `"metric for net revenue"`.
    * **For Dimensions/Filters:** Generate queries for the *business purpose* of each filter.
        * *Example:* If the user filters by "customer name," search for `"column for customer name"` or `"customer master name field"`.
        * *Example:* If the user filters by "last month," search for `"column for order date"` or `"transaction timestamp"`.
    * **For Joins (Foreign Keys):** When you identify a relationship, generate queries to find the keys that link the tables.
        * *Example:* To link sales and customers, search for `"foreign key linking sales to customers"` or `"customer ID in sales table"` and `"primary key for customer table"`.

**Instructions:**
* Think step-by-step to identify all necessary pieces of information.
* Generate multiple, specific queries. It's better to have more targeted queries than one vague one.
* Focus on *business terminology* relevant to an ERP (e.g., "general ledger," "invoice line," "bill of materials," "customer master," "sales order header").

user question: {user_message}
"""


# SQL-generation prompt template, rendered with `str.format` using the
# `{user_message}` and `{schema}` placeholders. The opening instruction had a
# typo ("atr") and broken grammar; since this text is sent to the model
# verbatim, it is corrected here.
PROMPT_SQL = """You are a SQL agent. You will receive a user question together with the relevant column names, their descriptions and their table names.
Use only the columns you need to generate the SQL statement. Make sure that your query follows **Oracle 11g** syntax.

user question: {user_message}

candidate relevant database schema:

{schema}
"""

### Agents

In [11]:
from typing_extensions import TypedDict
from pydantic import BaseModel, Field


# Structured-output schema for the planning step: the list of search queries.
# Documented with comments rather than a docstring so nothing extra can leak
# into the schema handed to the model.
class Queries(BaseModel):
    queries: list[str] = Field(
        ..., description="Search queries to get relevant column names"
    )


# Structured-output schema for the SQL step: a single SQL string.
class SQL(BaseModel):
    sql: str = Field(..., description="Correct Oracle 11g SQL statement")


# State shared by the graph nodes; each node returns only the key it produces.
class State(TypedDict):
    user_message: str  # question supplied when the graph is invoked
    queries: Queries  # returned by `planner`
    schema: str  # returned by `search`: formatted table/column/description text
    sql_query: SQL  # returned by `sql`

In [12]:
def planner(state: State) -> State:
    """Turn the user's question into search queries.

    Renders ``PROMPT_PLAN`` with ``state["user_message"]`` (empty string when
    absent), asks the model for a ``Queries`` object and returns it under the
    ``"queries"`` key.
    """
    question = state.get("user_message", "")
    planned = nlp.struct_output(
        model_name="gemini-2.5-flash",
        instructions="you are an sql generation planner agent",
        messages=PROMPT_PLAN.format(user_message=question),
        structure=Queries,
    )
    return {"queries": planned}

In [13]:
async def search(state: State) -> State:
    """Retrieve candidate columns for the planned queries and format them as text.

    Each query is embedded, the ``"sales"`` collection is searched with 10 hits
    per query, duplicates are removed, and the remaining columns are rendered
    as ``Table / Column / Description`` blocks under the ``"schema"`` key.

    Columns are de-duplicated on ``(table_name, column_name)``. Keying on the
    column name alone would drop same-named columns from other tables, which
    are typically the ones needed to join those tables.
    """
    queries = state.get("queries").queries
    if not queries:
        # Nothing to look up; skip the embedding and vector-store round trips.
        return {"schema": ""}

    # NOTE(review): queries are embedded with the embedder's default settings,
    # i.e. the same input type used for stored documents — confirm that this is
    # intended for query-side embeddings.
    vectors = await embedding.embed(queries)

    results = vectordb.semantic_search("sales", vectors, 10)

    seen = set()
    unique = []
    for hits in results:  # one list of metadata dicts per query
        for res in hits:
            key = (res["table_name"], res["column_name"])
            if key not in seen:
                seen.add(key)
                unique.append(res)

    schema_list = [
        f"Table: {res['table_name']}\nColumn: {res['column_name']}\nDescription: {res['description']}"
        for res in unique
    ]
    schema_text = "\n\n---\n\n".join(schema_list)

    return {"schema": schema_text}

In [14]:
def sql(state: State) -> State:
    """Generate the SQL statement for the user's question.

    Fills ``PROMPT_SQL`` with the question and the retrieved schema text, asks
    the model for a ``SQL`` object and returns it under the ``"sql_query"`` key.
    """
    prompt = PROMPT_SQL.format(
        user_message=state.get("user_message"),
        schema=state.get("schema"),
    )
    statement = nlp.struct_output(
        model_name="gemini-2.5-flash",
        instructions="you are an Oracle 11g sql generation agent",
        messages=prompt,
        structure=SQL,
    )
    return {"sql_query": statement}

In [15]:
from langgraph.graph import StateGraph, END

# Linear pipeline: plan -> search -> sql -> END.
workflow = StateGraph(State)

for node_name, node_fn in (("plan", planner), ("search", search), ("sql", sql)):
    workflow.add_node(node_name, node_fn)

workflow.set_entry_point("plan")

for source, target in (("plan", "search"), ("search", "sql"), ("sql", END)):
    workflow.add_edge(source, target)

graph = workflow.compile()

In [16]:
async for event in graph.astream({"user_message": "عايز اعرف الربح اخر شهر"}):
    print(event)

{'plan': {'queries': Queries(queries=['column for net profit amount', 'column for gross profit amount', 'column for sales order date', 'column for transaction date', 'column for revenue amount', 'column for cost of goods sold'])}}
{'search': {'schema': 'Table: PURCHS_BILL_MST_AI_VW\nColumn: PYMNT_CSH\nDescription: Cash payment amount for the purchase invoice.\n\n---\n\nTable: SALES_BILL_MST_AI_VW\nColumn: CSH_AMT\nDescription: Total cash payment amount received for the sales bill.\n\n---\n\nTable: SALES_BILL_MST_AI_VW\nColumn: TAX_AMT\nDescription: Total tax amount calculated for the sales transaction.\n\n---\n\nTable: GLS_PST_AI_VW\nColumn: DR_AMT_F\nDescription: Debit amount in foreign currency.\n\n---\n\nTable: GLS_PST_AI_VW\nColumn: DR_AMT_L\nDescription: Debit amount in local currency.\n\n---\n\nTable: SALES_BILL_MST_AI_VW\nColumn: BNK_AMT\nDescription: Total amount received through bank transfer or cheque payments.\n\n---\n\nTable: SALES_BILL_MST_AI_VW\nColumn: DSCNT_AMT\nDescrip