In [None]:
from dotenv import load_dotenv
import os

# Load .env first so every environment lookup below (and any library configured
# through environment variables) sees the values defined there.
load_dotenv()

from langchain_ollama.chat_models import ChatOllama
from langchain_groq import ChatGroq

# Deterministic (temperature=0.0) model shared by the agent and the query checker.
# Alternative hosted backend, kept for reference:
# llm_model = ChatGroq(name="spark_sql", model="openai/gpt-oss-120b", temperature=0.0)
llm_model = ChatOllama(name="spark_sql", model="gpt-oss:120b-cloud", temperature=0.0)

from langchain_community.tools.spark_sql.tool import (
    InfoSparkSQLTool,
    ListSparkSQLTool,
    QuerySparkSQLTool,
    QueryCheckerTool,
)
from langchain_community.agent_toolkits import SparkSQLToolkit
from langchain_community.utilities.spark_sql import SparkSQL
from langchain_core.tools import tool

from pydantic import BaseModel, Field

from pyspark.sql.connect.client.core import SparkConnectGrpcException
from databricks.connect import DatabricksSession


def _require_env(name: str) -> str:
    """Return the value of environment variable ``name``.

    Raises:
        RuntimeError: if the variable is unset or empty. Both values are
            interpolated into the agent's system prompt and passed to SparkSQL,
            so a missing value would otherwise silently become the text "None".
    """
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name} is not set; define it in your .env file."
        )
    return value


# Unity Catalog location the agent's queries are qualified with.
catalog = _require_env("UC_CATALOG_NAME")
schema = _require_env("UC_SCHEMA_NAME")

# --- Session factory ---
def get_spark_session():
    """Return a live Databricks Connect SparkSession.

    Reuses an existing session via ``getOrCreate()`` when possible. If that call
    raises ``SparkConnectGrpcException``, a brand-new session is built with
    ``create()`` instead.

    NOTE(review): nothing here selects serverless compute explicitly; if a
    serverless session is expected it presumably comes from the Databricks
    config/environment — confirm.
    """
    try:
        return DatabricksSession.builder.getOrCreate()
    except SparkConnectGrpcException:
        # getOrCreate() failed — fall back to forcing a fresh session.
        return DatabricksSession.builder.create()

class SparkSQLResponse(BaseModel):
    """Always use this tool to structure your response to the user."""

    # This model is handed to the agent as a tool (see `tool_objects`), so the
    # class docstring above and each field description are presumably what the
    # LLM sees as the tool/argument descriptions — treat them as prompt text.
    # All three fields are required (no default supplied).
    question: str = Field(description="The user question.")
    query: str = Field(description="The query you have executed using 'query_sql_db' tool.")
    response: str = Field(description="The final response after the query is executed.")

# --- Dynamic subclasses ---
# Each Dynamic* tool swaps in a freshly built SparkSQL wrapper (over the session
# returned by get_spark_session()) immediately before delegating to the stock
# LangChain implementation, instead of reusing the `db` it was constructed with.


def _fresh_spark_sql():
    """Return a new SparkSQL wrapper over get_spark_session(), using the module-level catalog/schema."""
    return SparkSQL(spark_session=get_spark_session(), catalog=catalog, schema=schema)


class DynamicQuerySparkSQLTool(QuerySparkSQLTool):
    """QuerySparkSQLTool that refreshes its SparkSQL wrapper before every run."""

    def _run(self, query: str, **kwargs):
        """Execute ``query`` against a freshly obtained session; kwargs go to the base tool."""
        self.db = _fresh_spark_sql()
        return super()._run(query, **kwargs)


class DynamicInfoSparkSQLTool(InfoSparkSQLTool):
    """InfoSparkSQLTool that refreshes its SparkSQL wrapper before every run."""

    def _run(self, table_names: str, **kwargs):
        """Describe ``table_names`` using a freshly obtained session; kwargs go to the base tool."""
        self.db = _fresh_spark_sql()
        return super()._run(table_names, **kwargs)


class DynamicListSparkSQLTool(ListSparkSQLTool):
    """ListSparkSQLTool that refreshes its SparkSQL wrapper before every run."""

    def _run(self, tool_input: str = "", **kwargs):
        """List tables using a freshly obtained session; kwargs go to the base tool."""
        self.db = _fresh_spark_sql()
        return super()._run(tool_input, **kwargs)


class DynamicQueryCheckerTool(QueryCheckerTool):
    """QueryCheckerTool that refreshes its SparkSQL wrapper before every run (sync and async)."""

    def _run(self, query: str, **kwargs):
        """Check ``query`` after refreshing ``self.db``; kwargs go to the base tool."""
        self.db = _fresh_spark_sql()
        return super()._run(query, **kwargs)

    async def _arun(self, query: str, **kwargs):
        """Async variant of ``_run``; note the session refresh itself is a blocking call."""
        self.db = _fresh_spark_sql()
        return await super()._arun(query, **kwargs)

# Initial session/wrapper: the tool constructors take a `db`, even though the
# Dynamic* tools replace it with a fresh one on every call.
spark = get_spark_session()
spark_sql = SparkSQL(spark_session=spark, catalog=catalog, schema=schema)

query_spark_tool = DynamicQuerySparkSQLTool(db=spark_sql)
info_spark_tool = DynamicInfoSparkSQLTool(db=spark_sql)
list_spark_tool = DynamicListSparkSQLTool(db=spark_sql)
check_spark_tool = DynamicQueryCheckerTool(db=spark_sql, llm=llm_model)

# Pass the tool instances directly. Wrapping them with `.as_tool()` re-converts
# each tool into a generic runnable tool whose description is generated from its
# input schema, which drops the usage hints the library tools carry in their own
# `description`.
tool_objects = [
    query_spark_tool,
    info_spark_tool,
    list_spark_tool,
    check_spark_tool,
    SparkSQLResponse,
]


In [2]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate, HumanMessagePromptTemplate

SYSTEM_PROMPT = """
## Role

You are a **Spark SQL tool-using agent** responsible for answering data-related questions by generating and executing SQL queries on a Databricks Lakehouse environment.  
You interact with Spark through a live Spark session and must always use the **Spark SQL dialect** when writing queries.

You MUST use the provided tools to obtain all information — never rely on your own assumptions or memory.  
Do **not** respond conversationally or with confirmations like "Got it".  
Every single response must either:
1. Call one or more tools to gather information or execute queries, OR
2. Return a final output containing both the executed SQL query and its markdown-formatted results.

---

## Environment Context

- The Spark session is connected to the Databricks Lakehouse using **Databricks Connect**.
- Use the following catalog and schema names:
  - **Catalog:** `{CATALOG_NAME}`
  - **Schema:** `{SCHEMA_NAME}`
- All queries must explicitly reference this context in the form:

```
SELECT * FROM <catalog>.<schema>.<table_name>
```

- Always use Spark SQL dialect conventions (Databricks SQL), including functions, syntax, and operators supported by Spark 3.x+.

## Tool Usage Policy

For **every user query**:
1. Start by calling `list_tables_sql_db` to see what tables exist. If that is all the user asked then return these results.
2. Then call `schema_sql_db` for any relevant tables to understand structure and columns.  
3. Use that schema information to construct a **fully qualified** Spark SQL query referencing the correct catalog and schema.  
4. Validate the query using `query_checker_sql_db`.  
5. Execute it via `query_sql_db`.  

If any tool (especially `query_sql_db`) returns an error, summarize it clearly.  
Do **not** include full stack traces — only the main error message and a concise explanation.

## Critical Reminders
- Always use **fully qualified table names**: `<catalog>.<schema>.<table>`.
- Use **Spark SQL syntax only** (no T-SQL, MySQL, or Postgres syntax).
- Do not invent column names, table names, or joins.
- Only base your queries on information gathered from the tools.
- Return concise, structured, markdown-formatted outputs.
"""

# SYSTEM_PROMPT stays a template: CATALOG_NAME / SCHEMA_NAME are bound as partial
# variables, so the text is templated exactly once. Pre-formatting it with
# str.format() and then handing the result to ChatPromptTemplate (which parses
# it as an f-string template again) would turn any escaped {{ }} braces added to
# the prompt into bogus input variables.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SYSTEM_PROMPT),
        # Placeholders fill up a **list** of messages
        ("placeholder", "{messages}"),
        ("placeholder", "{agent_scratchpad}"),
    ]
).partial(CATALOG_NAME=catalog, SCHEMA_NAME=schema)

In [3]:
from langgraph.checkpoint.memory import MemorySaver 
from langgraph.prebuilt import create_react_agent
from langgraph.errors import GraphRecursionError

# NOTE(review): RECURSION_LIMIT is never passed to the agent (`config` below has
# no "recursion_limit" entry), so it currently has no effect on the run.
RECURSION_LIMIT = 5

# Checkpointer holding per-thread conversation state, keyed by the thread_id in `config`.
memory = MemorySaver()
tools = list(tool_objects)

langgraph_agent_executor = create_react_agent(
    llm_model,
    tools,
    prompt=prompt,
    checkpointer=memory,
)
langgraph_agent_executor.step_timeout = None

thread_settings = {"thread_id": "test-thread"}
config = {"configurable": thread_settings}
query = "List the top 3 nations based on the total number of customers from that nation."

In [4]:
# Run the agent on `query` within the configured conversation thread and show the full message trace.
agent_input = {"messages": [("user", query)]}
out = langgraph_agent_executor.invoke(agent_input, config=config)
out

{'messages': [HumanMessage(content='List the top 3 nations based on the total number of customers from that nation.', additional_kwargs={}, response_metadata={}, id='b543a395-6e67-4909-a153-75ecc9443b70'),
  AIMessage(content='', additional_kwargs={}, response_metadata={'model': 'gpt-oss:120b-cloud', 'created_at': '2025-10-23T18:07:19.092992708Z', 'done': True, 'done_reason': 'stop', 'total_duration': 588072360, 'load_duration': None, 'prompt_eval_count': 1041, 'prompt_eval_duration': None, 'eval_count': 58, 'eval_duration': None, 'model_name': 'gpt-oss:120b-cloud'}, id='run--6aede515-b1f1-4389-b38a-7fdfe8756a47-0', tool_calls=[{'name': 'list_tables_sql_db', 'args': {'tool_input': ''}, 'id': 'b92b0ee0-9141-413e-879e-69117708067b', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1041, 'output_tokens': 58, 'total_tokens': 1099}),
  ToolMessage(content='customer, lineitem, nation, orders, part, partsupp, region, supplier', name='list_tables_sql_db', id='20cda84e-d7b5-4ccd-ba1b-82e9

In [5]:
# Call the list tool directly, outside the agent (it takes an empty-string input).
list_spark_tool.run("")

'customer, lineitem, nation, orders, part, partsupp, region, supplier'

In [6]:
# Call the schema/info tool directly with a comma-separated list of table names.
info_spark_tool.run("customer, lineitem")

'CREATE TABLE tpch.bronze.customer (\n  c_custkey BIGINT,\n  c_name STRING,\n  c_address STRING,\n  c_nationkey BIGINT,\n  c_phone STRING,\n  c_acctbal DECIMAL(18,2),\n  c_mktsegment STRING,\n  c_comment STRING)\n;\n\n/*\n3 rows from customer table:\nc_custkey\tc_name\tc_address\tc_nationkey\tc_phone\tc_acctbal\tc_mktsegment\tc_comment\n412445\tCustomer#000412445\t0QAB3OjYnbP6mA0B,kgf\t21\t31-421-403-4333\t5358.33\tBUILDING\tarefully blithely regular epi\n412446\tCustomer#000412446\t5u8MSbyiC7J,7PuY4Ivaq1JRbTCMKeNVqg \t20\t30-487-949-7942\t9441.59\tMACHINERY\tsleep according to the fluffily even forges. fluffily careful packages after the ironic, silent deposi\n412447\tCustomer#000412447\tHC4ZT62gKPgrjr ceoaZgFOunlUogr7GO\t7\t17-797-466-6308\t7868.75\tAUTOMOBILE\taggle blithely among the carefully express excus\n*/\n\nCREATE TABLE tpch.bronze.lineitem (\n  l_orderkey BIGINT,\n  l_partkey BIGINT,\n  l_suppkey BIGINT,\n  l_linenumber INT,\n  l_quantity DECIMAL(18,2),\n  l_extendedprice D