In [1]:
from dotenv import load_dotenv
import os

load_dotenv()

from langchain_ollama.chat_models import ChatOllama
from langchain_groq import ChatGroq

# llm_model = ChatGroq(name="spark_sql", model="openai/gpt-oss-120b", temperature=0.0)
llm_model = ChatOllama(name="spark_sql", model="gpt-oss:120b-cloud", temperature=0.0)

from langchain_community.tools.spark_sql.tool import ( InfoSparkSQLTool, ListSparkSQLTool, QuerySparkSQLTool, QueryCheckerTool, )
from langchain_community.agent_toolkits import SparkSQLToolkit
from langchain_community.utilities.spark_sql import SparkSQL

from langchain_core.tools import tool

from pyspark.sql.connect.client.core import SparkConnectGrpcException
from databricks.connect import DatabricksSession

catalog = os.environ.get("UC_CATALOG_NAME")
schema = os.environ.get("UC_SCHEMA_NAME")

# --- Session factory ---
def get_spark_session():
    """Get or create a live Databricks Serverless SparkSession."""
    try:
        spark = DatabricksSession.builder.getOrCreate()
    except SparkConnectGrpcException:
        spark = DatabricksSession.builder.create()
    return spark

# --- Dynamic subclasses ---
class DynamicQuerySparkSQLTool(QuerySparkSQLTool):
    """Dynamic variant of QuerySparkSQLTool with auto-refreshing SparkSession."""

    def _run(self, query: str, **kwargs):
        _spark = get_spark_session()
        self.db = SparkSQL(spark_session=_spark, catalog=catalog, schema=schema)
        return super()._run(query, **kwargs)


class DynamicInfoSparkSQLTool(InfoSparkSQLTool):
    """Dynamic variant of InfoSparkSQLTool."""

    def _run(self, table_names: str, **kwargs):
        _spark = get_spark_session()
        self.db = SparkSQL(spark_session=_spark, catalog=catalog, schema=schema)
        return super()._run(table_names, **kwargs)


class DynamicListSparkSQLTool(ListSparkSQLTool):
    """Dynamic variant of ListSparkSQLTool."""

    def _run(self, tool_input: str = "", **kwargs):
        _spark = get_spark_session()
        self.db = SparkSQL(spark_session=_spark, catalog=catalog, schema=schema)
        return super()._run(tool_input, **kwargs)


class DynamicQueryCheckerTool(QueryCheckerTool):
    """Dynamic variant of QueryCheckerTool."""

    def _run(self, query: str, **kwargs):
        _spark = get_spark_session()
        self.db = SparkSQL(spark_session=_spark, catalog=catalog, schema=schema)
        return super()._run(query, **kwargs)

    async def _arun(self, query: str, **kwargs):
        _spark = get_spark_session()
        self.db = SparkSQL(spark_session=_spark, catalog=catalog, schema=schema)
        return await super()._arun(query, **kwargs)

spark = get_spark_session()
spark_sql = SparkSQL(spark_session=spark, catalog=catalog, schema=schema)

query_spark_tool = DynamicQuerySparkSQLTool(db=spark_sql)
info_spark_tool = DynamicInfoSparkSQLTool(db=spark_sql)
list_spark_tool = DynamicListSparkSQLTool(db=spark_sql)
check_spark_tool = DynamicQueryCheckerTool(db=spark_sql, llm=llm_model)

tool_objects = [query_spark_tool, info_spark_tool, list_spark_tool, check_spark_tool]

In [2]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate, HumanMessagePromptTemplate

system_prompt = """
You are a helpful databricks lakehouse assistant primarily used for exploring the databricks workspace artifacts by leveraging Unity Catalog and SparkSQL.
The databricks unity catalog follows a three-level namespace '<catalog>.<schema>.<table>' make sure you adhere to this while using the tools.

Help the users by answering their questions using the appropriate tools.
"""

user_prompt = """
Assist the user with their question.

IMPORTANT: Only output the required message, no other explanation or text can be provided.
"""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        ("user", user_prompt),
        # Placeholders fill up a **list** of messages
        ("placeholder", "{messages}"),
        ("placeholder", "{agent_scratchpad}"),
    ]
)

In [3]:
from langgraph.checkpoint.memory import MemorySaver 
from langgraph.prebuilt import create_react_agent
from langgraph.errors import GraphRecursionError

RECURSION_LIMIT = 5

memory = MemorySaver()
tools = [tool_object.as_tool() for tool_object in tool_objects]

langgraph_agent_executor = create_react_agent(llm_model, tools, prompt=system_prompt, checkpointer=memory)
langgraph_agent_executor.step_timeout = None

config = {"configurable": {"thread_id": "test-thread"}}
query = "List all the tables present."

  tools = [tool_object.as_tool() for tool_object in tool_objects]


In [4]:
out = langgraph_agent_executor.invoke({"query": query}, config=config)
out

{'messages': [AIMessage(content='', additional_kwargs={}, response_metadata={'model': 'gpt-oss:120b-cloud', 'created_at': '2025-10-10T00:19:54.785970335Z', 'done': True, 'done_reason': 'stop', 'total_duration': 726446427, 'load_duration': None, 'prompt_eval_count': 596, 'prompt_eval_duration': None, 'eval_count': 50, 'eval_duration': None, 'model_name': 'gpt-oss:120b-cloud'}, id='run--916d5b93-ab2d-4088-b66c-bf21f0c55d0d-0', tool_calls=[{'name': 'list_tables_sql_db', 'args': {'tool_input': ''}, 'id': '94d5add7-3830-47f5-9e30-36839f95899f', 'type': 'tool_call'}], usage_metadata={'input_tokens': 596, 'output_tokens': 50, 'total_tokens': 646}),
  ToolMessage(content='customer, lineitem, nation, orders, part, partsupp, region, supplier', name='list_tables_sql_db', id='b52a677f-7cc1-47bc-baeb-c96ef3f3b068', tool_call_id='94d5add7-3830-47f5-9e30-36839f95899f'),
  AIMessage(content='I’m ready to help! What would you like to know or explore in the Databricks workspace?', additional_kwargs={}, 

In [5]:
query = "List the top 3 nations based on the total number of customers from that nation."
out = langgraph_agent_executor.invoke({"query": query}, config=config)
out

{'messages': [AIMessage(content='', additional_kwargs={}, response_metadata={'model': 'gpt-oss:120b-cloud', 'created_at': '2025-10-10T00:19:54.785970335Z', 'done': True, 'done_reason': 'stop', 'total_duration': 726446427, 'load_duration': None, 'prompt_eval_count': 596, 'prompt_eval_duration': None, 'eval_count': 50, 'eval_duration': None, 'model_name': 'gpt-oss:120b-cloud'}, id='run--916d5b93-ab2d-4088-b66c-bf21f0c55d0d-0', tool_calls=[{'name': 'list_tables_sql_db', 'args': {'tool_input': ''}, 'id': '94d5add7-3830-47f5-9e30-36839f95899f', 'type': 'tool_call'}], usage_metadata={'input_tokens': 596, 'output_tokens': 50, 'total_tokens': 646}),
  ToolMessage(content='customer, lineitem, nation, orders, part, partsupp, region, supplier', name='list_tables_sql_db', id='b52a677f-7cc1-47bc-baeb-c96ef3f3b068', tool_call_id='94d5add7-3830-47f5-9e30-36839f95899f'),
  AIMessage(content='I’m ready to help! What would you like to know or explore in the Databricks workspace?', additional_kwargs={}, 

In [6]:
list_spark_tool.run(tool_input="")

'customer, lineitem, nation, orders, part, partsupp, region, supplier'

In [13]:
info_spark_tool.run(tool_input="customer, lineitem")

'CREATE TABLE tpch.bronze.customer (\n  c_custkey BIGINT,\n  c_name STRING,\n  c_address STRING,\n  c_nationkey BIGINT,\n  c_phone STRING,\n  c_acctbal DECIMAL(18,2),\n  c_mktsegment STRING,\n  c_comment STRING)\n;\n\n/*\n3 rows from customer table:\nc_custkey\tc_name\tc_address\tc_nationkey\tc_phone\tc_acctbal\tc_mktsegment\tc_comment\n412445\tCustomer#000412445\t0QAB3OjYnbP6mA0B,kgf\t21\t31-421-403-4333\t5358.33\tBUILDING\tarefully blithely regular epi\n412446\tCustomer#000412446\t5u8MSbyiC7J,7PuY4Ivaq1JRbTCMKeNVqg \t20\t30-487-949-7942\t9441.59\tMACHINERY\tsleep according to the fluffily even forges. fluffily careful packages after the ironic, silent deposi\n412447\tCustomer#000412447\tHC4ZT62gKPgrjr ceoaZgFOunlUogr7GO\t7\t17-797-466-6308\t7868.75\tAUTOMOBILE\taggle blithely among the carefully express excus\n*/\n\nCREATE TABLE tpch.bronze.lineitem (\n  l_orderkey BIGINT,\n  l_partkey BIGINT,\n  l_suppkey BIGINT,\n  l_linenumber INT,\n  l_quantity DECIMAL(18,2),\n  l_extendedprice D