In [1]:
from dotenv import load_dotenv
import os
load_dotenv()

True

In [2]:
from databricks.sdk import WorkspaceClient
w = WorkspaceClient()

from langchain_groq import ChatGroq
llm = ChatGroq(model="gemma2-9b-it", temperature=0)

In [3]:
from langchain_core.tools import tool

from typing import Union
import json

@tool
def get_all_schemas_info(catalog_name: str) -> str:
    """Use this tool to find all available schemas present in the given catalog."""
    schemas = []
    workspace_client = w
    for schema in workspace_client.schemas.list(catalog_name=catalog_name):
        schema_info = json.dumps({
            "name": schema.name,
            "full_name": schema.full_name,
            "description": schema.comment
        })
        schemas.append(schema_info) if schema.name != "information_schema" else None
    schemas.append("\n")
    return "\n".join(schemas)

@tool()
def get_all_tables_info(catalog_name: str, schema_name: Union[str, list] | None = None) -> str:
    """Use this tool to find the information of all available tables (including description and columns information) present in the given catalog and schema"""
    tables = []
    workspace_client = w

    if schema_name is None:
        return f"The user has not specified the 'schema_name'. Collect all available schemas and recall this tool with the list of schemas as input"
    
    if isinstance(schema_name, str):
        for table in workspace_client.tables.list(catalog_name=catalog_name, schema_name=schema_name):
            table_constraints = w.tables.get(full_name=table.full_name).table_constraints #BUG table.table_constraints has some bug and is returning None
            table_info = json.dumps({
                "name": table.name,
                "full_name": table.full_name,
                "type": table.table_type.name,
                "description": table.comment,
                "columns": [table.columns[0].as_dict()] if table.columns else None,
                "constraints": table_constraints[0].as_dict() if table_constraints else None,
                "view_definition": table.view_definition,
                "view_dependencies": table.view_dependencies.as_dict() if table.view_dependencies else None
            })
            tables.append(table_info)
    else:
        for schema in schema_name:
            for table in workspace_client.tables.list(catalog_name=catalog_name, schema_name=schema):
                table_constraints = w.tables.get(full_name=table.full_name).table_constraints #BUG table.table_constraints has some bug and is returning None
                table_info = json.dumps({
                    "name": table.name,
                    "full_name": table.full_name,
                    "type": table.table_type.name,
                    "description": table.comment,
                    "columns": [table.columns[0].as_dict()] if table.columns else None,
                    "constraints": table_constraints[0].as_dict() if table_constraints else None,
                    "view_definition": table.view_definition,
                    "view_dependencies": table.view_dependencies.as_dict() if table.view_dependencies else None
                })
                tables.append(table_info)
    tables.append("\n")
    return "\n".join(tables)

In [4]:
from langchain_community.agent_toolkits import SparkSQLToolkit, create_spark_sql_agent
from langchain_community.utilities.spark_sql import SparkSQL

from pyspark.sql.connect.client.core import SparkConnectGrpcException
from databricks.connect import DatabricksSession

try:
    spark = DatabricksSession.builder.getOrCreate()
except SparkConnectGrpcException:
    spark = DatabricksSession.builder.create()

spark_sql = SparkSQL(spark_session=spark)
toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)

In [None]:
from langchain_community.tools.spark_sql.tool import QueryCheckerTool, QuerySparkSQLTool

checker_prompt = """"""
query_prompt = """"""

query_checker = QueryCheckerTool(llm=llm, db=spark_sql)
query_executor = QuerySparkSQLTool(db=spark_sql)

In [6]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate, HumanMessagePromptTemplate

system_promt = SystemMessagePromptTemplate.from_template(template="""
You are a helpful databricks lakehouse assistant primarily used for exploring the databricks workspace artifacts by leveraging Unity Catalog and SparkSQL.
The databricks unity catalog follows a three-level namespace '<catalog>.<schema>.<table>' make sure you adhere to this while using the tools.

You are only allowed to access '{catalog}' catalog from the workspace. Use this catalog by default if the user has not mentioned anything. 
Return 'Unauthorized to access, contact the Admin' if the user is explicitly trying/asking to access any other catalog.

Help the users by answering/executing their queries using the appropriate tools.
""",
input_variables=["catalog"]
)

human_prompt = HumanMessagePromptTemplate.from_template(template="""
Assist the user with the following query:
---
{query}
---
Only output the required message, no other explanation or text can be provided.
""",
input_variables=["query"]
)

prompt = ChatPromptTemplate.from_messages([
    system_promt,
    MessagesPlaceholder(variable_name="chat_history"),
    human_prompt,
    ("placeholder", "{agent_scratchpad}"),
])

In [7]:
from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory(
    input_key="query",
    memory_key="chat_history",  # must align with MessagesPlaceholder variable_name
    output_key="output",
    return_messages=True  # to return Message objects
)

  memory = ConversationBufferMemory(


In [8]:
from langchain.agents import create_tool_calling_agent
from langchain.agents import AgentExecutor

tools = [get_all_schemas_info, get_all_tables_info, query_checker, query_executor]

agent = create_tool_calling_agent(llm=llm, tools=tools, prompt=prompt)

agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    memory=memory,
    verbose=True
)

agent_executor.invoke({
    "catalog": "tpch",
    "query": "List all the table names that are available",
    "chat_history": memory.chat_memory.messages
})



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `get_all_tables_info` with `{'catalog_name': 'tpch'}`


[0m[33;1m[1;3mThe user has not specified the 'schema_name'. Collect all available schemas and recall this tool with the list of schemas as input[0m[32;1m[1;3m
Invoking: `get_all_schemas_info` with `{'catalog_name': 'tpch'}`


[0m[36;1m[1;3m{"name": "bronze", "full_name": "tpch.bronze", "description": ""}
{"name": "default", "full_name": "tpch.default", "description": "Default schema (auto-created)"}
{"name": "silver", "full_name": "tpch.silver", "description": ""}

[0m[32;1m[1;3m
Invoking: `get_all_tables_info` with `{'catalog_name': 'tpch', 'schema_name': ['bronze', 'default', 'silver']}`


[0m[33;1m[1;3m{"name": "customer", "full_name": "tpch.bronze.customer", "type": "MANAGED", "description": null, "columns": [{"name": "c_custkey", "nullable": true, "position": 0, "type_json": "{\"name\":\"c_custkey\",\"type\":\"long\",\"nullable\":true,\"met

{'catalog': 'tpch',
 'query': 'List all the table names that are available',
 'chat_history': [HumanMessage(content='List all the table names that are available', additional_kwargs={}, response_metadata={}),
  AIMessage(content='```json\n[\n  "customer",\n  "lineitem",\n  "nation",\n  "orders",\n  "part",\n  "partsupp",\n  "region",\n  "supplier",\n  "customer_features",\n  "orders_features",\n  "part_supplier_features"\n]\n```', additional_kwargs={}, response_metadata={})],
 'output': '```json\n[\n  "customer",\n  "lineitem",\n  "nation",\n  "orders",\n  "part",\n  "partsupp",\n  "region",\n  "supplier",\n  "customer_features",\n  "orders_features",\n  "part_supplier_features"\n]\n```'}

In [9]:
agent_executor.invoke({
    "catalog": "tpch",
    "query": "Describe the customer_features table from silver schema",
    "chat_history": memory.chat_memory.messages
})



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `get_all_tables_info` with `{'catalog_name': 'tpch', 'schema_name': 'silver'}`
responded: ```json


[0m[33;1m[1;3m{"name": "customer_features", "full_name": "tpch.silver.customer_features", "type": "MANAGED", "description": "Customer feature table with account balance, market segment, nation, and region features", "columns": [{"name": "customer_id", "nullable": false, "position": 0, "type_json": "{\"name\":\"customer_id\",\"type\":\"long\",\"nullable\":false,\"metadata\":{}}", "type_name": "LONG", "type_precision": 0, "type_scale": 0, "type_text": "bigint"}], "constraints": {"primary_key_constraint": {"child_columns": ["customer_id"], "name": "customer_features_pk"}}, "view_definition": null, "view_dependencies": null}
{"name": "orders_features", "full_name": "tpch.silver.orders_features", "type": "MANAGED", "description": "Orders feature table with order status, total price, date, and priority features", "colum

{'catalog': 'tpch',
 'query': 'Describe the customer_features table from silver schema',
 'chat_history': [HumanMessage(content='List all the table names that are available', additional_kwargs={}, response_metadata={}),
  AIMessage(content='```json\n[\n  "customer",\n  "lineitem",\n  "nation",\n  "orders",\n  "part",\n  "partsupp",\n  "region",\n  "supplier",\n  "customer_features",\n  "orders_features",\n  "part_supplier_features"\n]\n```', additional_kwargs={}, response_metadata={}),
  HumanMessage(content='Describe the customer_features table from silver schema', additional_kwargs={}, response_metadata={}),
  AIMessage(content="```json\n[\n  ('customer_id', 'bigint', 'None'),\n  ('c_account_balance', 'decimal(18,2)', 'None'),\n  ('c_market_segment', 'string', 'None'),\n  ('c_nation', 'string', 'None'),\n  ('c_region', 'string', 'None')\n]\n```", additional_kwargs={}, response_metadata={})],
 'output': "```json\n[\n  ('customer_id', 'bigint', 'None'),\n  ('c_account_balance', 'decima

In [12]:
agent_executor.invoke({
    "catalog": "tpch",
    "query": "Generate th extented description on the orders table from bronze schema",
    "chat_history": memory.chat_memory.messages
})



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m```json
{
	"tool_call": {
		"id": "pending",
		"type": "function",
		"function": {
			"name": "get_all_tables_info"
		},
		"parameters": {
			"catalog_name": "tpch"
		}
	}
}
```[0m

[1m> Finished chain.[0m


{'catalog': 'tpch',
 'query': 'Generate th extented description on the orders table from bronze schema',
 'chat_history': [HumanMessage(content='List all the table names that are available', additional_kwargs={}, response_metadata={}),
  AIMessage(content='```json\n[\n  "customer",\n  "lineitem",\n  "nation",\n  "orders",\n  "part",\n  "partsupp",\n  "region",\n  "supplier",\n  "customer_features",\n  "orders_features",\n  "part_supplier_features"\n]\n```', additional_kwargs={}, response_metadata={}),
  HumanMessage(content='Describe the customer_features table from silver schema', additional_kwargs={}, response_metadata={}),
  AIMessage(content="```json\n[\n  ('customer_id', 'bigint', 'None'),\n  ('c_account_balance', 'decimal(18,2)', 'None'),\n  ('c_market_segment', 'string', 'None'),\n  ('c_nation', 'string', 'None'),\n  ('c_region', 'string', 'None')\n]\n```", additional_kwargs={}, response_metadata={}),
  HumanMessage(content='Summarize the orders table', additional_kwargs={}, res