# Introduction to Databricks LLM

This notebook shows how to query Databricks LLM endpoints. I will show you several different approaches:
- `databricks.sdk.WorkspaceClient.serving_endpoints.query` (used by Databricks examples in Streamlit)
- OpenAI-compatible API (for streaming and tools)
- LangChain (used by many other Databricks examples)

Out of scope:
- MLflow: https://mlflow.org/docs/latest/api_reference/python_api/mlflow.deployments.html#mlflow.deployments.DatabricksDeploymentClient
- REST API
- reasoning (Claude Sonnet)
- image recognition (Claude Sonnet)
- full list: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/model-serving/score-foundation-models

Besides that, this notebook shows how to call functions, including tools created in Unity Catalog as UDFs.

Maintainer: pavel_goncharenko@epam.com

## How to get Databricks:
At least two free options are available:
- You can create a Databricks workspace using an MSDN subscription in Azure: https://learn.microsoft.com/en-us/azure/databricks/getting-started/free-trial
- You can request an EPAM Databricks Sandbox here: https://kb.epam.com/display/EPMCBDCC/Databricks+CoE#DatabricksCoE-DATABRICKSSANDBOX

## Authorization
You have several options; I will show two of them:
- Create a service principal (requires admin privileges), generate an OAuth client ID and secret for it (`DATABRICKS_CLIENT_ID` and `DATABRICKS_CLIENT_SECRET`), and use them below.
- Create a personal access token (PAT) for your user in the Settings/Developer menu of the Databricks workspace, and put it into `DATABRICKS_TOKEN` below.
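Because the two options are mutually exclusive, it can help to check which one your environment actually configures. Below is a small sketch of a hypothetical helper, `detect_auth_method`, that inspects only the environment variables named above. It is not part of the Databricks SDK, which resolves credentials on its own. The returned labels `"pat"` and `"oauth-m2m"` are simply names chosen for this sketch.

```python
import os


def detect_auth_method(env=os.environ) -> str:
    """Report which Databricks auth option is configured in `env`.

    Hypothetical helper: it only looks at the variables used in this notebook.
    """
    has_oauth = bool(env.get("DATABRICKS_CLIENT_ID")) and bool(env.get("DATABRICKS_CLIENT_SECRET"))
    has_pat = bool(env.get("DATABRICKS_TOKEN"))
    if has_oauth and has_pat:
        # Both options set at once is ambiguous: keep only one of them
        raise ValueError("Set either DATABRICKS_CLIENT_ID/SECRET or DATABRICKS_TOKEN, not both")
    if has_oauth:
        return "oauth-m2m"
    if has_pat:
        return "pat"
    raise ValueError("No Databricks credentials found in the environment")


# Example with a placeholder token (not a real credential)
print(detect_auth_method({"DATABRICKS_HOST": "https://example", "DATABRICKS_TOKEN": "placeholder"}))  # pat
```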

## Required permissions:
- Permissions on the workspace: User
- Permissions on the SQL warehouse: Can use
- Permissions on the `samples` catalog (available by default)

## Please add your credentials below:

In [None]:
import os

In [None]:
# Replace each `...` below with your value as a string.
# Workspace URL, e.g.: https://adb-327....azuredatabricks.net/
os.environ["DATABRICKS_HOST"] = ...
# a) EITHER service principal OAuth credentials:
os.environ["DATABRICKS_CLIENT_ID"] = ...
os.environ["DATABRICKS_CLIENT_SECRET"] = ...
# b) OR a user's personal access token (set only one of a) or b), remove the other):
os.environ["DATABRICKS_TOKEN"] = ...
# SQL warehouse ID for SQL queries:
os.environ["DATABRICKS_WAREHOUSE_ID"] = ...

# Dependencies

In [None]:
# Removing conflicting packages from Colab
%pip uninstall -y pyspark pyspark-connect

In [2]:
# Dependencies for the notebook:
%pip install -qq python-dotenv
%pip install -qq databricks-sdk[openai]==0.50.0
%pip install -qq databricks-langchain==0.4.2
%pip install -qq databricks-sql-connector==4.0.3
%pip install -qq databricks-connect==16.1.3

Note: you may need to restart the kernel to use updated packages.


# RESTART Colab session here once:
Runtime > Restart session

In [None]:
import os
import json
import pandas as pd # IF FAILS HERE JUST RESTART Colab: "Runtime > Restart session" and run again.
from pprint import pprint
from IPython.display import display, Markdown, clear_output
from databricks import sql
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
from databricks.sdk.core import Config
from databricks_langchain import ChatDatabricks, UCFunctionToolkit
from langchain_core.tools import tool
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage, BaseMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain.callbacks import StdOutCallbackHandler
from langchain.agents import AgentExecutor, create_tool_calling_agent
from unitycatalog.ai.core.databricks import DatabricksFunctionClient
from dotenv import load_dotenv

load_dotenv()

def display_json(obj):
    """Display a JSON object in a formatted way."""
    json_str = json.dumps(obj, indent=2)
    display(Markdown("```json\n" + json_str + "\n```"))

## WorkspaceClient instance
Instantiate a `WorkspaceClient`. It is required to make LLM calls to the Databricks API and reads the credentials from the environment variables set above.

In [5]:
# from databricks.sdk import WorkspaceClient

client = WorkspaceClient()

## Sample messages we will use:
We will use a message asking how many rows the sample table contains:

In [6]:
messages = [
    {"role": "system", "content": "You are helpful assistant, please provide only reliable information based on the context."},
    {"role": "user", "content": "How many records in my table samples.tpch.customer?"}
]

## We will query against this endpoint

You can find the list of available endpoints in Databricks under Machine Learning -> Serving.

Doc: https://learn.microsoft.com/en-us/azure/databricks/machine-learning/model-serving/foundation-model-overview

This endpoint is pay-per-token, already deployed, and available by default. Feel free to experiment with any other chat model.

In [7]:
SERVING_ENDPOINT = "databricks-meta-llama-3-3-70b-instruct"

# 1. Databricks `serving_endpoints` API

- `serving_endpoints.query` doesn't work with `stream=True`
- `serving_endpoints.query` doesn't support tools

**Note**:
In the response below, the LLM cannot get the actual number of records because it has neither context nor tool-calling logic.

In [8]:
# from databricks.sdk.service.serving import ChatMessage, ChatMessageRole

# Converting dict to ChatMessage:
dbx_messages = [
    ChatMessage(role=ChatMessageRole(msg["role"]), content=msg["content"])
    for msg in messages
]

pprint(dbx_messages)

[ChatMessage(content='You are helpful assistant, please provide only reliable '
                     'information based on the context.',
             role=<ChatMessageRole.SYSTEM: 'system'>),
 ChatMessage(content='How many records in my table samples.tpch.customer?',
             role=<ChatMessageRole.USER: 'user'>)]


In [9]:
chat_completion = client.serving_endpoints.query(
    name=SERVING_ENDPOINT,
    messages=dbx_messages,
    max_tokens=500,
    temperature=0.1,
)

display(Markdown("## LLM raw output:"))
pprint(chat_completion, width=200)

## LLM raw output:

QueryEndpointResponse(choices=[V1ResponseChoiceElement(finish_reason=None,
                                                       index=0,
                                                       logprobs=None,
                                                       message=ChatMessage(content="I'm a large language model, I don't have direct access to your database or tables. However, I can guide you on how "
                                                                                   'to find the number of records in your table.\n'
                                                                                   '\n'
                                                                                   'To get the number of records in your table `samples.tpch.customer`, you can use a SQL query:\n'
                                                                                   '\n'
                                                                                   '```sql\n'
        

In [10]:
display(Markdown("## LLM response:"))
display(Markdown(chat_completion.choices[0].message.content))

## LLM response:

I'm a large language model, I don't have direct access to your database or tables. However, I can guide you on how to find the number of records in your table.

To get the number of records in your table `samples.tpch.customer`, you can use a SQL query:

```sql
SELECT COUNT(*) FROM samples.tpch.customer;
```

This query will return the number of rows in your table. Please execute this query in your database management system or SQL client to get the exact count.

# 2. OpenAI-compatible API: `get_open_ai_client`

## 2.1. Chat Completion with OpenAI client

The previous example has limitations: it cannot properly handle streaming or tools. The OpenAI-compatible API suits our goals better, so let's use it instead:

In [None]:
openai_client = client.serving_endpoints.get_open_ai_client()

chat_completion = openai_client.chat.completions.create(
    model=SERVING_ENDPOINT,
    messages=messages,
    max_tokens=500,
    temperature=0.1,
)

display(Markdown("### LLM raw output:"))
display_json(chat_completion.model_dump())

# **Note**: In the response below, the LLM cannot get the actual number of records because it has neither context nor tool-calling logic.

In [None]:
display(Markdown("### LLM response Markdown formatted:"))
display(Markdown(chat_completion.choices[0].message.content))

### LLM response Markdown formatted:

I'm a large language model, I don't have direct access to your database or tables. However, I can guide you on how to find the number of records in your table.

To get the number of records in your table `samples.tpch.customer`, you can use a SQL query:

```sql
SELECT COUNT(*) FROM samples.tpch.customer;
```

This query will return the number of rows in your table. Please execute this query in your database management system or SQL client to get the exact count.

## 2.2. Streaming example (openai client)

Here is a streaming example. Streaming provides a better UI experience: instead of waiting until the LLM produces the full answer, the user sees intermediate results immediately, token by token.

In [None]:
openai_client = client.serving_endpoints.get_open_ai_client()

chat_completion = openai_client.chat.completions.create(
    model=SERVING_ENDPOINT,
    messages=messages,
    max_tokens=500,
    temperature=0.1,
    stream=True, # Enable streaming
)

content = ""

display(Markdown("### LLM streaming response:"))

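for chunk in chat_completion:
    # Check that the chunk has choices before indexing: a chunk may carry none
    if chunk.choices and chunk.choices[0].delta.content:
        content += chunk.choices[0].delta.content
        # Redraw the accumulated text so the answer appears token by token
        clear_output(wait=True)
        display(Markdown(content))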

I'm a large language model, I don't have direct access to your database or tables. However, I can guide you on how to find the number of records in your table.

To get the number of records in your table `samples.tpch.customer`, you can use a SQL query:

```sql
SELECT COUNT(*) FROM samples.tpch.customer;
```

This query will return the number of rows in your table. If you're using a specific database management system like MySQL, PostgreSQL, or BigQuery, you can run this query in the respective query editor or terminal.

If you provide more context or information about your database setup, I can offer more tailored assistance.
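The accumulation logic of the streaming cell can be tried offline. The sketch below uses `types.SimpleNamespace` objects that mimic only the fields the loop reads (`chunk.choices[0].delta.content`). They are stand-ins, not the real OpenAI chunk classes. It checks that `choices` is non-empty before indexing, as a guard against streams that may contain chunks without choices (for example, a final usage-only chunk; whether a given endpoint sends one is an assumption).

```python
from types import SimpleNamespace


def fake_chunk(text):
    """Build an object shaped like a streaming chunk (only the fields used below)."""
    if text is None:
        # Simulate a chunk that carries no choices at all
        return SimpleNamespace(choices=[])
    delta = SimpleNamespace(content=text)
    return SimpleNamespace(choices=[SimpleNamespace(delta=delta)])


def accumulate(stream) -> str:
    """Concatenate delta.content from every chunk, skipping empty ones."""
    content = ""
    for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            content += chunk.choices[0].delta.content
            # In the notebook, clear_output/display would redraw here
    return content


stream = [fake_chunk("SELECT "), fake_chunk(None), fake_chunk("COUNT(*)"), fake_chunk(""), fake_chunk(" FROM t;")]
print(accumulate(stream))  # SELECT COUNT(*) FROM t;
```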

# 3. Langchain

Doc: https://python.langchain.com/docs/integrations/chat/databricks/

Many Databricks examples are based on LangChain syntax. The examples below are meant to give you the basic knowledge and understanding.

## 3.1. Langchain invoke

See https://python.langchain.com/api_reference/community/chat_models/langchain_community.chat_models.databricks.ChatDatabricks.html for other supported parameters

In [None]:
# from databricks_langchain import ChatDatabricks

# If you use .databrickscfg and not a `default` profile, override the `target_uri` like this:
# class MyChatDatabricks(ChatDatabricks):
#     target_uri: str = "databricks://<your_profile_name>"

langchain_chat_model = ChatDatabricks(
    endpoint=SERVING_ENDPOINT,
    temperature=0.1,
    max_tokens=500
)

result = langchain_chat_model.invoke(messages)

display(Markdown("### LLM raw output:"))
display_json(vars(result))

### LLM raw output:

```json
{
  "content": "I'm a large language model, I don't have direct access to your database or tables. However, I can guide you on how to find the number of records in your table.\n\nTo get the number of records in your table `samples.tpch.customer`, you can use a SQL query:\n\n```sql\nSELECT COUNT(*) FROM samples.tpch.customer;\n```\n\nThis query will return the number of rows in your table. If you're using a specific database management system like MySQL, PostgreSQL, or BigQuery, you can run this query in the respective query editor or terminal.\n\nIf you provide more context or information about your database setup, I can offer more tailored guidance.",
  "additional_kwargs": {},
  "response_metadata": {
    "id": "chatcmpl_07f2ce54-cd86-4b96-bf5e-0cc6db792fb9",
    "object": "chat.completion",
    "created": 1745507463,
    "model": "meta-llama-3.3-70b-instruct-121024",
    "usage": {
      "prompt_tokens": 41,
      "completion_tokens": 136,
      "total_tokens": 177
    },
    "model_name": "meta-llama-3.3-70b-instruct-121024"
  },
  "type": "ai",
  "name": null,
  "id": "run-54dc8a85-adf4-4b80-aebf-a270d51c1728-0",
  "example": false,
  "tool_calls": [],
  "invalid_tool_calls": [],
  "usage_metadata": null
}
```

In [None]:
display(Markdown("### LLM response Markdown formatted:"))
display(Markdown(result.content))

### LLM response Markdown formatted:

I'm a large language model, I don't have direct access to your database or tables. However, I can guide you on how to find the number of records in your table.

To get the number of records in your table `samples.tpch.customer`, you can use a SQL query:

```sql
SELECT COUNT(*) FROM samples.tpch.customer;
```

This query will return the number of rows in your table. If you're using a specific database management system like MySQL, PostgreSQL, or BigQuery, you can run this query in the respective query editor or terminal.

If you provide more context or information about your database setup, I can offer more tailored guidance.

## 3.2. Langchain Stream

Streaming is also supported by Databricks LangChain (`ChatDatabricks.stream`):

In [None]:
langchain_chat_model = ChatDatabricks(
    endpoint=SERVING_ENDPOINT,
    temperature=0.1,
    max_tokens=500,
)

result = langchain_chat_model.stream(messages)

content = "### LLM streaming response Markdown formatted:\n"

for chunk in result:
    content += chunk.content
    clear_output(wait=True)
    display(Markdown(content))

### LLM streaming response Markdown formatted:
I'm a large language model, I don't have direct access to your database or tables. However, I can guide you on how to find the number of records in your table.

To get the number of records in your table `samples.tpch.customer`, you can use a SQL query:

```sql
SELECT COUNT(*) FROM samples.tpch.customer;
```

This query will return the number of rows in your table. Please execute this query in your database management system to get the exact count.

# 4. Function calling

Function calling is supported by many LLMs; please check your model's documentation. OpenAI's documentation on function calling is here: https://platform.openai.com/docs/assistants/tools/function-calling

The basic idea is:

- The developer tells the LLM: "I have a function `sql_query` that executes any SQL query against my database; call it with the parameter `query` if required."
- The user asks the LLM: "How many rows are in table `sample`?"
- The LLM responds to the system: "Please execute function `sql_query` with `query`='select count(*) from sample'"
- The system calls this function and passes the result back to the LLM: `200`
- The LLM replies to the user: "You have 200 records in table `sample`"
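The five steps above form a loop: call the model, execute whatever tools it requests, append the results, and call it again until it answers in plain text. Here is a minimal offline sketch, assuming OpenAI-style message dicts. `stub_llm` and `fake_sql_query` are hypothetical stand-ins for the serving endpoint and the real SQL tool, so the example runs without a workspace.

```python
import json


def fake_sql_query(query: str) -> str:
    """Hypothetical stand-in for the real SQL tool: always reports 200 rows."""
    return "200"


# Registry mapping tool names (as the LLM sees them) to Python callables
LOCAL_TOOLS = {"sql_query": fake_sql_query}


def stub_llm(messages: list[dict]) -> dict:
    """Stand-in for the model, NOT a real endpoint: it requests a tool call
    first, then answers once a tool result is present."""
    last = messages[-1]
    if last["role"] == "tool":
        return {"role": "assistant", "content": f"You have {last['content']} records in table `sample`"}
    return {
        "role": "assistant",
        "content": None,
        "tool_calls": [{
            "id": "call_1",
            "type": "function",
            "function": {
                "name": "sql_query",
                # Like the real API, arguments arrive as a JSON string
                "arguments": json.dumps({"query": "select count(*) from sample"}),
            },
        }],
    }


def run_tool_loop(messages: list[dict], llm=stub_llm, max_steps: int = 5):
    """Call the LLM, run requested tools, feed results back, repeat."""
    history = list(messages)  # work on a copy, like new_messages in the notebook
    for _ in range(max_steps):
        reply = llm(history)
        history.append(reply)
        tool_calls = reply.get("tool_calls")
        if not tool_calls:
            return reply["content"], history  # final answer for the user
        for call in tool_calls:
            func = LOCAL_TOOLS[call["function"]["name"]]
            args = json.loads(call["function"]["arguments"])
            history.append({"role": "tool", "tool_call_id": call["id"], "content": func(**args)})
    raise RuntimeError("Too many tool-calling rounds")


answer, history = run_tool_loop([{"role": "user", "content": "How many rows are in table `sample`?"}])
print(answer)  # You have 200 records in table `sample`
```

With a real endpoint, `llm` would wrap `openai_client.chat.completions.create(..., tools=[...])`, which is what the following cells do step by step.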

In [None]:
# Let's imagine we have a function `sql_query`; here is its description for the LLM:
sql_query_tool_json = {
    "type": "function",
    "function": {
        "name": "sql_query",
        "description": "Execute any query against Databricks",
        "parameters": {
            "description": "Executing given SQL query and returning a table of results",
            "properties": {
                "query": {
                    "description": "SQL query to be executed",
                    "type": "string"
                }
            },
            "required": ["query"],
            "type": "object"
        }
    }
}

openai_client = client.serving_endpoints.get_open_ai_client()

chat_completion = openai_client.chat.completions.create(
    model=SERVING_ENDPOINT,
    messages=messages,
    max_tokens=500,
    temperature=0.1,
    tools=[sql_query_tool_json]
)

display(Markdown("### LLM raw output:"))
display_json(chat_completion.model_dump())

### LLM raw output:

```json
{
  "id": "chatcmpl_a84847a7-09e7-424f-80c9-2116b7c76674",
  "choices": [
    {
      "finish_reason": "tool_calls",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": null,
        "refusal": null,
        "role": "assistant",
        "annotations": null,
        "audio": null,
        "function_call": null,
        "tool_calls": [
          {
            "id": "call_ebbae73b-2ec6-46ad-bc7c-b0417ac2e214",
            "function": {
              "arguments": "{\"query\": \"SELECT COUNT(*) FROM samples.tpch.customer\"}",
              "name": "sql_query"
            },
            "type": "function"
          }
        ]
      }
    }
  ],
  "created": 1745507470,
  "model": "meta-llama-3.3-70b-instruct-121024",
  "object": "chat.completion",
  "service_tier": null,
  "system_fingerprint": null,
  "usage": {
    "completion_tokens": 21,
    "prompt_tokens": 715,
    "total_tokens": 736,
    "completion_tokens_details": null,
    "prompt_tokens_details": null
  }
}
```

**Output**:

As you can see, the LLM output has:
- `finish_reason='tool_calls'`
- `content=None`
- a `tool_calls` entry with `function=Function(...)`

The LLM asks us to call the function `sql_query` with the following arguments, according to our function description:

In [None]:
print(chat_completion.choices[0].message.tool_calls[0].function)

Function(arguments='{"query": "SELECT COUNT(*) FROM samples.tpch.customer"}', name='sql_query')
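Note that `function.arguments` is a JSON-encoded *string*, not a dict. It has to be decoded before the function is called, and a model may occasionally produce malformed or incomplete arguments. Here is a minimal sketch of a hypothetical `parse_tool_arguments` helper that decodes the string and checks the required keys. The key check is a simplification, not full JSON Schema validation.

```python
import json


def parse_tool_arguments(arguments: str, required: list[str]) -> dict:
    """Decode a tool call's `arguments` JSON string and check required keys.

    Raises ValueError so the caller can report the problem back to the LLM
    (e.g. as the tool message content) instead of crashing.
    """
    try:
        args = json.loads(arguments)
    except json.JSONDecodeError as exc:
        raise ValueError(f"arguments are not valid JSON: {exc}") from exc
    if not isinstance(args, dict):
        raise ValueError("arguments must be a JSON object")
    missing = [key for key in required if key not in args]
    if missing:
        raise ValueError(f"missing required arguments: {missing}")
    return args


args = parse_tool_arguments('{"query": "SELECT COUNT(*) FROM samples.tpch.customer"}', required=["query"])
print(args["query"])  # SELECT COUNT(*) FROM samples.tpch.customer
```

In this notebook it could be applied as `parse_tool_arguments(chat_completion.choices[0].message.tool_calls[0].function.arguments, required=sql_query_tool_json["function"]["parameters"]["required"])`.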


**Function calling step:**

Let's pretend we executed our function `sql_query` with the parameter `query="SELECT COUNT(*) FROM samples.tpch.customer"` and the function returned the following table:

| count(*) |
| -------- |
|      200 |

Now we should add these two messages (the assistant's tool-call request and our function's result) to a copy of the `messages` list and pass them back to the LLM:

In [None]:
new_messages = messages.copy()

# Adding a message from assistant to our messages array, which asks for tool calling:
new_messages.append({
    "role": "assistant",
    "tool_calls": [{
        "type": "function",
        "id": chat_completion.choices[0].message.tool_calls[0].id,
        "function": {
            "name": chat_completion.choices[0].message.tool_calls[0].function.name,
            "arguments": chat_completion.choices[0].message.tool_calls[0].function.arguments},
    }]
})
# Adding the function result:
new_messages.append({
    "role": "tool",
    "tool_call_id": chat_completion.choices[0].message.tool_calls[0].id,
    "content": "200"
})

# New messages array looks like:
display(Markdown("### New messages array:"))
display_json(new_messages)

### New messages array:

```json
[
  {
    "role": "system",
    "content": "You are helpful assistant, please provide only reliable information based on the context."
  },
  {
    "role": "user",
    "content": "How many records in my table samples.tpch.customer?"
  },
  {
    "role": "assistant",
    "tool_calls": [
      {
        "type": "function",
        "id": "call_ebbae73b-2ec6-46ad-bc7c-b0417ac2e214",
        "function": {
          "name": "sql_query",
          "arguments": "{\"query\": \"SELECT COUNT(*) FROM samples.tpch.customer\"}"
        }
      }
    ]
  },
  {
    "role": "tool",
    "tool_call_id": "call_ebbae73b-2ec6-46ad-bc7c-b0417ac2e214",
    "content": "200"
  }
]
```

In [None]:
chat_completion = openai_client.chat.completions.create(
    model=SERVING_ENDPOINT,
    messages=new_messages,
    max_tokens=500,
    tools=[sql_query_tool_json]
)

display(Markdown("### LLM raw output:"))
display_json(chat_completion.model_dump())

### LLM raw output:

```json
{
  "id": "chatcmpl_d6970271-2722-42f0-a08e-2132ff54f123",
  "choices": [
    {
      "finish_reason": "stop",
      "index": 0,
      "logprobs": null,
      "message": {
        "content": "There are 200 records in your table samples.tpch.customer.",
        "refusal": null,
        "role": "assistant",
        "annotations": null,
        "audio": null,
        "function_call": null,
        "tool_calls": null
      }
    }
  ],
  "created": 1745507471,
  "model": "meta-llama-3.3-70b-instruct-121024",
  "object": "chat.completion",
  "service_tier": null,
  "system_fingerprint": null,
  "usage": {
    "completion_tokens": 14,
    "prompt_tokens": 755,
    "total_tokens": 769,
    "completion_tokens_details": null,
    "prompt_tokens_details": null
  }
}
```

In [None]:
display(Markdown("### LLM output Markdown formatted:"))
display(Markdown(chat_completion.choices[0].message.content))

### LLM output Markdown formatted:

There are 200 records in your table samples.tpch.customer.

### Real function example:

In [None]:
# The SQL query will be executed against this SQL warehouse:
DATABRICKS_WAREHOUSE_ID = os.getenv("DATABRICKS_WAREHOUSE_ID")

def sql_query_example(query: str) -> dict:
    cfg = Config()
    with sql.connect(
        server_hostname=cfg.host,
        http_path=f"/sql/1.0/warehouses/{DATABRICKS_WAREHOUSE_ID}",
        credentials_provider=lambda: cfg.authenticate
    ) as connection:
        with connection.cursor() as cursor:
            cursor.execute(query)
            df = cursor.fetchall_arrow().to_pandas()
            if len(df) > 1000:
                raise ValueError("Query result exceeds 1000 rows. Please refine your query.")
            return df.to_dict('index')
        
tool_output = sql_query_example("SELECT COUNT(*) FROM samples.tpch.customer")

display(Markdown("### Tool real output:"))
display_json(tool_output)


### Tool real output:

```json
{
  "0": {
    "count(1)": 750000
  }
}
```
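The LangChain tool version of this function in the next section serializes its result with `json.dumps(..., default=str)`. Query results often contain values the `json` module cannot encode, such as `Decimal` (TPC-H money columns like `c_acctbal` are decimals) or `date`, assuming the Arrow-to-pandas conversion yields those Python objects. The hypothetical `rows_to_json` helper below shows the effect, together with the same 1000-row guard. It uses made-up values, not real query output.

```python
import json
from datetime import date
from decimal import Decimal

MAX_ROWS = 1000  # same limit the notebook's SQL functions enforce


def rows_to_json(rows: dict) -> str:
    """Serialize df.to_dict('index')-style rows for an LLM tool message."""
    if len(rows) > MAX_ROWS:
        raise ValueError(f"Query result exceeds {MAX_ROWS} rows. Please refine your query.")
    # default=str turns Decimal, date, datetime, ... into strings instead of raising TypeError
    return json.dumps(rows, default=str)


# Illustrative values only (not real query output)
rows = {0: {"c_acctbal": Decimal("711.56"), "as_of": date(2024, 1, 31)}}
print(rows_to_json(rows))  # {"0": {"c_acctbal": "711.56", "as_of": "2024-01-31"}}
```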

# 5. Langchain tools

Here we will use the real function. The `tool` decorator converts our Python function into a LangChain `StructuredTool` and also generates the JSON schema for its arguments.

In [None]:
# from langchain_core.tools import tool

# Google-style docstring: with parse_docstring=True, the summary and the Args section become the tool and argument descriptions

@tool(parse_docstring=True)
def sql_query(query: str) -> str:
    """Executing given SQL query and returning a table of results

    Args:
        query: SQL query to be executed
    """
    # Resolve credentials the same way as WorkspaceClient (environment variables set above)
    cfg = Config()
    with sql.connect(
        server_hostname=cfg.host,
        http_path=f"/sql/1.0/warehouses/{DATABRICKS_WAREHOUSE_ID}",
        credentials_provider=lambda: cfg.authenticate
    ) as connection:
        with connection.cursor() as cursor:
            cursor.execute(query)
            df = cursor.fetchall_arrow().to_pandas()
            if len(df) > 1000:
                raise ValueError("Query result exceeds 1000 rows. Please refine your query.")

            return json.dumps(df.to_dict('index'), default=str)

# LangChain generates the same JSON schema we wrote by hand for the `parameters` field:
args_schema = sql_query.args_schema.model_json_schema()
display_json(args_schema)

```json
{
  "description": "Executing given SQL query and returning a table of results",
  "properties": {
    "query": {
      "description": "SQL query to be executed",
      "title": "Query",
      "type": "string"
    }
  },
  "required": [
    "query"
  ],
  "title": "sql_query",
  "type": "object"
}
```
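`@tool(parse_docstring=True)` derived the schema above from the function signature and the Google-style docstring. The stdlib-only sketch below approximates that idea to show where each field comes from. It uses a hypothetical helper, `function_to_tool_json`, not LangChain's implementation, and handles only scalar type hints and simple `name: description` lines in the `Args:` section. Unlike the schema above, it emits the full OpenAI tool wrapper (`type`/`function`/`parameters`) that we wrote by hand in section 4.

```python
import inspect

# Simplified Python type -> JSON Schema type mapping (assumption: scalars only)
PY_TO_JSON = {str: "string", int: "integer", float: "number", bool: "boolean"}


def function_to_tool_json(fn) -> dict:
    """Build an OpenAI-style tool description from a signature and a
    Google-style docstring (summary line + "Args:" section)."""
    doc = inspect.getdoc(fn) or ""
    summary, _, rest = doc.partition("\n\n")
    arg_docs = {}
    if "Args:" in rest:
        for line in rest.split("Args:", 1)[1].splitlines():
            stripped = line.strip()
            if stripped.endswith(":") and " " not in stripped:
                break  # next section header such as "Returns:"
            name, sep, desc = stripped.partition(":")
            if sep:
                arg_docs[name.strip()] = desc.strip()
    properties, required = {}, []
    for name, param in inspect.signature(fn).parameters.items():
        properties[name] = {
            "type": PY_TO_JSON.get(param.annotation, "string"),
            "description": arg_docs.get(name, ""),
        }
        if param.default is inspect.Parameter.empty:
            required.append(name)  # parameters without defaults are required
    return {
        "type": "function",
        "function": {
            "name": fn.__name__,
            "description": summary.strip(),
            "parameters": {"type": "object", "properties": properties, "required": required},
        },
    }


def sql_query_stub(query: str) -> str:
    """Executing given SQL query and returning a table of results

    Args:
        query: SQL query to be executed
    """
    return ""  # body is irrelevant: only the signature and docstring are read


tool_json = function_to_tool_json(sql_query_stub)
print(tool_json["function"]["parameters"])
```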

## Invoking LangChain tools

We can invoke a tool by calling its `invoke` method:

In [None]:
tool_output = sql_query.invoke({"query": "SELECT COUNT(*) FROM samples.tpch.customer"})

display(Markdown("### Tool real output:"))
display_json(json.loads(tool_output))

### Tool real output:

```json
{
  "0": {
    "count(1)": 750000
  }
}
```

## Adding this Tool to the model:

https://python.langchain.com/v0.2/docs/concepts/#tools

In [None]:
# from langchain_core.messages import AIMessage

langchain_chat_model = ChatDatabricks(
    endpoint=SERVING_ENDPOINT,
    temperature=0.1,
    max_tokens=500,
)
model_with_tools = langchain_chat_model.bind_tools([sql_query])

tool_call_result: AIMessage = model_with_tools.invoke(messages)

display(Markdown("### LLM raw output in JSON:"))
display_json(tool_call_result.model_dump())

### LLM raw output in JSON:

```json
{
  "content": "",
  "additional_kwargs": {
    "tool_calls": [
      {
        "id": "call_39952a6f-fc8b-4427-896d-2cf630757889",
        "type": "function",
        "function": {
          "name": "sql_query",
          "arguments": "{\"query\": \"SELECT COUNT(*) FROM samples.tpch.customer\"}"
        }
      }
    ]
  },
  "response_metadata": {
    "id": "chatcmpl_11e174b9-1561-4d1a-bbb7-79de48f71105",
    "object": "chat.completion",
    "created": 1745507488,
    "model": "meta-llama-3.3-70b-instruct-121024",
    "usage": {
      "prompt_tokens": 702,
      "completion_tokens": 21,
      "total_tokens": 723
    },
    "model_name": "meta-llama-3.3-70b-instruct-121024"
  },
  "type": "ai",
  "name": null,
  "id": "run-2b7a765b-3214-49c8-8f6c-87214575be15-0",
  "example": false,
  "tool_calls": [
    {
      "name": "sql_query",
      "args": {
        "query": "SELECT COUNT(*) FROM samples.tpch.customer"
      },
      "id": "call_39952a6f-fc8b-4427-896d-2cf630757889",
      "type": "tool_call"
    }
  ],
  "invalid_tool_calls": [],
  "usage_metadata": null
}
```

In [None]:
tool_call = tool_call_result.tool_calls[0]

# In our case, sql_query returns a JSON string, so we can parse it and display it as JSON:
tool_output = sql_query.invoke(tool_call["args"])
display_json(json.loads(tool_output))

```json
{
  "0": {
    "count(1)": 750000
  }
}
```

## Building a new messages array by appending the Assistant and Tool responses:

In [None]:
# from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage, BaseMessage

new_messages: list[BaseMessage] = []

new_messages.append(SystemMessage(content=messages[0]["content"]))
new_messages.append(HumanMessage(content=messages[1]["content"]))
new_messages.append(tool_call_result)
new_messages.append(ToolMessage(content=tool_output, tool_call_id=tool_call["id"], name=tool_call["name"]))

display(Markdown("### New messages array (objects):"))
pprint(new_messages)
display(Markdown("### New messages array (JSON):"))
display_json([x.model_dump() for x in new_messages])

### New messages array (objects):

[SystemMessage(content='You are helpful assistant, please provide only reliable information based on the context.', additional_kwargs={}, response_metadata={}),
 HumanMessage(content='You are helpful assistant, please provide only reliable information based on the context.', additional_kwargs={}, response_metadata={}),
 AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_39952a6f-fc8b-4427-896d-2cf630757889', 'type': 'function', 'function': {'name': 'sql_query', 'arguments': '{"query": "SELECT COUNT(*) FROM samples.tpch.customer"}'}}]}, response_metadata={'id': 'chatcmpl_11e174b9-1561-4d1a-bbb7-79de48f71105', 'object': 'chat.completion', 'created': 1745507488, 'model': 'meta-llama-3.3-70b-instruct-121024', 'usage': {'prompt_tokens': 702, 'completion_tokens': 21, 'total_tokens': 723}, 'model_name': 'meta-llama-3.3-70b-instruct-121024'}, id='run-2b7a765b-3214-49c8-8f6c-87214575be15-0', tool_calls=[{'name': 'sql_query', 'args': {'query': 'SELECT COUNT(*) FROM samples.tpch

### New messages array (JSON):

```json
[
  {
    "content": "You are helpful assistant, please provide only reliable information based on the context.",
    "additional_kwargs": {},
    "response_metadata": {},
    "type": "system",
    "name": null,
    "id": null
  },
  {
    "content": "You are helpful assistant, please provide only reliable information based on the context.",
    "additional_kwargs": {},
    "response_metadata": {},
    "type": "human",
    "name": null,
    "id": null,
    "example": false
  },
  {
    "content": "",
    "additional_kwargs": {
      "tool_calls": [
        {
          "id": "call_39952a6f-fc8b-4427-896d-2cf630757889",
          "type": "function",
          "function": {
            "name": "sql_query",
            "arguments": "{\"query\": \"SELECT COUNT(*) FROM samples.tpch.customer\"}"
          }
        }
      ]
    },
    "response_metadata": {
      "id": "chatcmpl_11e174b9-1561-4d1a-bbb7-79de48f71105",
      "object": "chat.completion",
      "created": 1745507488,
      "model": "meta-llama-3.3-70b-instruct-121024",
      "usage": {
        "prompt_tokens": 702,
        "completion_tokens": 21,
        "total_tokens": 723
      },
      "model_name": "meta-llama-3.3-70b-instruct-121024"
    },
    "type": "ai",
    "name": null,
    "id": "run-2b7a765b-3214-49c8-8f6c-87214575be15-0",
    "example": false,
    "tool_calls": [
      {
        "name": "sql_query",
        "args": {
          "query": "SELECT COUNT(*) FROM samples.tpch.customer"
        },
        "id": "call_39952a6f-fc8b-4427-896d-2cf630757889",
        "type": "tool_call"
      }
    ],
    "invalid_tool_calls": [],
    "usage_metadata": null
  },
  {
    "content": "{\"0\": {\"count(1)\": 750000}}",
    "additional_kwargs": {},
    "response_metadata": {},
    "type": "tool",
    "name": "sql_query",
    "id": null,
    "tool_call_id": "call_39952a6f-fc8b-4427-896d-2cf630757889",
    "artifact": null,
    "status": "success"
  }
]
```

**Note**: You can also use `ToolMessage`, `AIMessage`, and other message objects directly as elements of the `messages` array (without calling `model_dump`). However, our first two messages are in `dict` format, and mixing `dict`s and LangChain classes in one `messages` array is bad practice, even though LangChain accepts such mixed lists.

Now, let's execute LLM with new messages array:

In [None]:
result = langchain_chat_model.invoke(new_messages)

display(Markdown("### LLM output in Markdown:"))
display(Markdown(result.content))

### LLM output in Markdown:

I'm happy to help, but it seems like there was a misunderstanding. The conversation just started, and I didn't receive any context or question to provide reliable information about. Could you please provide more context or ask a question so I can assist you better?

# 6. Databricks UC functions as tools

You can create UDFs in Databricks and use them as tools. Databricks already provides one, `system.ai.python_exec`, which executes Python code and returns its output:

In [None]:
# from databricks_langchain import UCFunctionToolkit
# from unitycatalog.ai.core.databricks import DatabricksFunctionClient

uc_client = DatabricksFunctionClient(client=client)
uc_toolkit = UCFunctionToolkit(function_names=["system.ai.python_exec"], client=uc_client)

python_exec_tool = uc_toolkit.tools_dict["system.ai.python_exec"]

display_json(python_exec_tool.args_schema.model_json_schema())

```json
{
  "additionalProperties": false,
  "properties": {
    "code": {
      "anyOf": [
        {
          "type": "string"
        },
        {
          "type": "null"
        }
      ],
      "default": null,
      "description": "Python code to execute. Ensure that all variables are initialized within the code, and import any necessary standard libraries. The code must print the final result to stdout. Do not attempt to access files or external systems.",
      "title": "Code"
    }
  },
  "title": "system__ai__python_exec__params",
  "type": "object"
}
```

**Sample**:

Let's call it:

In [None]:
tool_result = python_exec_tool.invoke({"code": """
import math
a = 200 + 25
print(math.sqrt(a))
"""})
display_json(json.loads(tool_result))

```json
{
  "format": "SCALAR",
  "value": "15.0\n"
}
```

## Adding UC tool for LLM example:

Now let's tell our LLM that two tools are available, `python_exec_tool` and `sql_query`, and let the LLM decide which tool to call based on the user's request.

In [None]:
langchain_chat_model = ChatDatabricks(
    endpoint=SERVING_ENDPOINT,
    temperature=0.1,
    max_tokens=500,
)
model_with_tools = langchain_chat_model.bind_tools([python_exec_tool, sql_query])

### Let's ask the LLM to compute something using PYTHON:

In [None]:
tool_call_request = model_with_tools.invoke([
    {"role": "user", "content": "using python calculate the pi/2"}
])

tool_call = tool_call_request.tool_calls[0]

display(Markdown("### LLM raw output:"))
display_json(vars(tool_call_request))

### LLM raw output:

```json
{
  "content": "",
  "additional_kwargs": {
    "tool_calls": [
      {
        "id": "call_dae46b4b-be9f-4672-b95d-f4a7cb2954ea",
        "type": "function",
        "function": {
          "name": "system__ai__python_exec",
          "arguments": "{\"code\": \"import math\\nprint(math.pi / 2)\"}"
        }
      }
    ]
  },
  "response_metadata": {
    "id": "chatcmpl_b5507d7c-8b9e-4fd8-8230-578776091c6f",
    "object": "chat.completion",
    "created": 1745507519,
    "model": "meta-llama-3.3-70b-instruct-121024",
    "usage": {
      "prompt_tokens": 875,
      "completion_tokens": 26,
      "total_tokens": 901
    },
    "model_name": "meta-llama-3.3-70b-instruct-121024"
  },
  "type": "ai",
  "name": null,
  "id": "run-87815d1e-619a-4bb0-b689-8b673f3cf6ca-0",
  "example": false,
  "tool_calls": [
    {
      "name": "system__ai__python_exec",
      "args": {
        "code": "import math\nprint(math.pi / 2)"
      },
      "id": "call_dae46b4b-be9f-4672-b95d-f4a7cb2954ea",
      "type": "tool_call"
    }
  ],
  "invalid_tool_calls": [],
  "usage_metadata": null
}
```

**Output**:

In the output, the LLM asks us to execute the function `system.ai.python_exec` (the toolkit exposes it to the LLM as `system__ai__python_exec`) with the argument `{'code': 'import math\nprint(math.pi / 2)'}` to answer the user's question.
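The observed tool name suggests that the dots in the Unity Catalog name are replaced with double underscores. Below is a minimal sketch of that mapping in both directions. The helper names are hypothetical, and the exact rule the toolkit applies is an assumption inferred from the output above.

```python
def uc_to_tool_name(full_name: str) -> str:
    """Map a UC name like 'catalog.schema.function' to the name the LLM sees (assumed rule)."""
    return full_name.replace(".", "__")


def tool_to_uc_name(tool_name: str) -> str:
    """Inverse mapping; assumes the UC name itself contains no '__'."""
    return tool_name.replace("__", ".")


print(uc_to_tool_name("system.ai.python_exec"))
print(tool_to_uc_name("system__ai__python_exec"))
```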

Our program logic could look like this:

In [None]:
if tool_call["name"] == "system__ai__python_exec":
    display(Markdown("### PYTHON function executed"))
    func = python_exec_tool
elif tool_call["name"] == "sql_query":
    display(Markdown("### SQL function executed"))
    func = sql_query
else:
    # Guard against unexpected tool names; otherwise `func` would be undefined below
    raise ValueError(f"Unknown tool requested: {tool_call['name']}")

tool_output = func.invoke(tool_call["args"])
display_json(json.loads(tool_output))

### PYTHON function executed

```json
{
  "format": "SCALAR",
  "value": "1.5707963267948966\n"
}
```
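With more tools, a lookup table scales better than an `if`/`elif` chain. The following sketch uses fake tool functions (hypothetical stand-ins that return canned JSON). In the notebook, you would map the names to `python_exec_tool.invoke` and `sql_query.invoke` instead.

```python
import json
from typing import Any, Callable, Dict


def fake_python_exec(args: Dict[str, Any]) -> str:
    # Stand-in for python_exec_tool.invoke: same JSON shape as the UC tool result
    return json.dumps({"format": "SCALAR", "value": "1.5707963267948966\n"})


def fake_sql_query(args: Dict[str, Any]) -> str:
    # Stand-in for sql_query.invoke
    return json.dumps({"0": {"result": 1.0}})


# Map the tool names the LLM uses to the callables that implement them
TOOLS: Dict[str, Callable[[Dict[str, Any]], str]] = {
    "system__ai__python_exec": fake_python_exec,
    "sql_query": fake_sql_query,
}


def run_tool_call(tool_call: Dict[str, Any]) -> str:
    """Dispatch one tool call given as a dict with 'name', 'args' and 'id' keys."""
    func = TOOLS.get(tool_call["name"])
    if func is None:
        raise ValueError(f"LLM requested unknown tool: {tool_call['name']!r}")
    return func(tool_call["args"])


call = {"name": "system__ai__python_exec", "args": {"code": "import math\nprint(math.pi / 2)"}, "id": "call_1"}
print(json.loads(run_tool_call(call))["value"])
```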

### Let's ask the SAME LLM for the same calculation, but using SQL this time:

In [None]:
tool_call_request = model_with_tools.invoke([
    {"role": "user", "content": "using SQL calculate the pi/2"}
])

tool_call = tool_call_request.tool_calls[0]

display(Markdown("### LLM raw output:"))
display_json(vars(tool_call_request))

### LLM raw output:

```json
{
  "content": "",
  "additional_kwargs": {
    "tool_calls": [
      {
        "id": "call_0aa60cd7-9787-4f98-9e6c-37d8fbfbfc64",
        "type": "function",
        "function": {
          "name": "sql_query",
          "arguments": "{\"query\": \"SELECT 4 * ATAN(1)\"}"
        }
      }
    ]
  },
  "response_metadata": {
    "id": "chatcmpl_2c2193a5-0ded-4feb-812b-f0f2d1008ad3",
    "object": "chat.completion",
    "created": 1745507522,
    "model": "meta-llama-3.3-70b-instruct-121024",
    "usage": {
      "prompt_tokens": 875,
      "completion_tokens": 21,
      "total_tokens": 896
    },
    "model_name": "meta-llama-3.3-70b-instruct-121024"
  },
  "type": "ai",
  "name": null,
  "id": "run-60912fcd-dc87-475a-849f-2ab0cee702c9-0",
  "example": false,
  "tool_calls": [
    {
      "name": "sql_query",
      "args": {
        "query": "SELECT 4 * ATAN(1)"
      },
      "id": "call_0aa60cd7-9787-4f98-9e6c-37d8fbfbfc64",
      "type": "tool_call"
    }
  ],
  "invalid_tool_calls": [],
  "usage_metadata": null
}
```

**Output**:

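In the output, the LLM asks us to execute the function `sql_query` with a `query` argument containing a SELECT statement. Note that `SELECT 4 * ATAN(1)` computes π, not π/2. The tool call itself is well-formed, but the model got the math wrong, so tool arguments and results still need to be validated.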

Our program logic is the same:

In [None]:
if tool_call["name"] == "system__ai__python_exec":
    display(Markdown("### PYTHON function executed"))
    func = python_exec_tool
elif tool_call["name"] == "sql_query":
    display(Markdown("### SQL function executed"))
    func = sql_query
else:
    # Guard against unexpected tool names; otherwise `func` would be undefined below
    raise ValueError(f"Unknown tool requested: {tool_call['name']}")

tool_output = func.invoke(tool_call["args"])
display_json(json.loads(tool_output))

### SQL function executed

```json
{
  "0": {
    "(4 * ATAN(1))": 3.141592653589793
  }
}
```
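A quick check of the SQL result confirms that the model computed π rather than the requested π/2. This sketch parses the tool output exactly as displayed above, assuming its shape is row index → {column: value}:

```python
import json
import math

# Output of the sql_query tool as displayed above (row index -> {column: value})
sql_output = '{"0": {"(4 * ATAN(1))": 3.141592653589793}}'

rows = json.loads(sql_output)
(value,) = rows["0"].values()  # single row, single column

print(math.isclose(value, math.pi))      # the query computed pi
print(math.isclose(value, math.pi / 2))  # it did not compute pi/2
```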

# Performing SQL queries using Databricks and LangChain agents

- https://python.langchain.com/v0.2/docs/tutorials/sql_qa/
- Obsolete: https://python.langchain.com/v0.2/api_reference/langchain/chains/langchain.chains.sql_database.query.create_sql_query_chain.html
- https://docs.unitycatalog.io/ai/integrations/langchain/

So far, we executed the functions manually based on the LLM's request. LangChain provides helpers that do this automatically.

The agent asks the LLM to generate a SQL query and then calls our functions with the provided arguments on our behalf:

## Adding Callbacks for transparency

In [None]:
from databricks_langchain import ChatDatabricks
from langchain_core.callbacks import StdOutCallbackHandler
from langchain_core.prompts import ChatPromptTemplate
from langchain.agents import AgentExecutor, create_tool_calling_agent


# Not important for demonstration, but makes logging more readable:
class CustomLoggingCallback(StdOutCallbackHandler):

    def on_tool_start(self, serialized, input_str, **kwargs):
        print(f"🛠️  Tool called: {serialized['name']} with input: {input_str}")

    def on_tool_end(self, output, **kwargs):
        print(f"✅ Tool output: {output=}")

    def on_llm_start(self, serialized, prompts, **kwargs):
        print(f"🤖 LLM Prompt: {prompts=}")

    def on_llm_end(self, response, **kwargs):
        print(f"📜 LLM Response: {response=}")

    def on_chain_start(self, serialized, inputs, run_id, **kwargs):
        print(f"⛓️ Chain start: {inputs=}")

    def on_chain_end(self, outputs, **kwargs):
        print(f"⛓️ Chain end: {outputs=}")

    def on_agent_action(self, action, **kwargs):
        print(f"🛠️  Agent action: {action=} with {kwargs=}")


langchain_chat_model = ChatDatabricks(
    endpoint="databricks-meta-llama-3-1-405b-instruct",
    temperature=0.1,
    max_tokens=500,
    callbacks=[CustomLoggingCallback()]
)


prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant. Always use the available tools for calculations. "
            "Ensure Python code is formatted with line breaks."
        ),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
        ("placeholder", "{agent_scratchpad}"),
    ]
)


agent = create_tool_calling_agent(langchain_chat_model, [sql_query, python_exec_tool], prompt)

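# AgentExecutor has no `handle_tool_errors` option. Tool errors are configured per tool:
# with handle_tool_error=True, a ToolException raised by the tool is returned to the LLM
# as the observation instead of stopping the agent.
sql_query.handle_tool_error = True
python_exec_tool.handle_tool_error = True

agent_executor = AgentExecutor(
    agent=agent,
    tools=[sql_query, python_exec_tool],
    verbose=True,
    return_intermediate_steps=True,
    callbacks=[CustomLoggingCallback()],
    max_iterations=8,
    handle_parsing_errors=True,
)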

In [None]:
res = agent_executor.invoke({
    "input": 
"""
Count the amount of records in table samples.nyctaxi.trips where trip distance is less than 5.
Find a square root of the result using python_exec tool. 
"""
})

res

⛓️ Chain start: inputs={'input': '\nCount the amount of records in table samples.nyctaxi.trips where trip distance is less than 5.\nFind a square root of the result using python_exec tool. \n'}
🤖 LLM Prompt: prompts=['System: You are a helpful assistant. Always use the available tools for calculations. Ensure Python code is formatted with line breaks.\nHuman: \nCount the amount of records in table samples.nyctaxi.trips where trip distance is less than 5.\nFind a square root of the result using python_exec tool. \n']
📜 LLM Response: response=LLMResult(generations=[[ChatGenerationChunk(generation_info={'finish_reason': 'tool_calls'}, message=AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'id': 'call_126170bc-d64a-4f24-8238-9418c8819bd9', 'index': 0, 'type': 'function', 'function': {'name': 'sql_query', 'arguments': '{"query": "SELECT COUNT(*) FROM samples.nyctaxi.trips WHERE trip_distance < 5"}'}}]}, response_metadata={'finish_reason': 'tool_calls'}, id='run-df8e6dd7-32e7-

{'input': '\nCount the amount of records in table samples.nyctaxi.trips where trip distance is less than 5.\nFind a square root of the result using python_exec tool. \n',
 'output': 'The square root of the count of records in the table samples.nyctaxi.trips where trip distance is less than 5 is 137.58270240113762.',
 'intermediate_steps': [(ToolAgentAction(tool='sql_query', tool_input={'query': 'SELECT COUNT(*) FROM samples.nyctaxi.trips WHERE trip_distance < 5'}, log="\nInvoking: `sql_query` with `{'query': 'SELECT COUNT(*) FROM samples.nyctaxi.trips WHERE trip_distance < 5'}`\n\n\n", message_log=[AIMessageChunk(content='', additional_kwargs={'tool_calls': [{'id': 'call_126170bc-d64a-4f24-8238-9418c8819bd9', 'index': 0, 'type': 'function', 'function': {'name': 'sql_query', 'arguments': '{"query": "SELECT COUNT(*) FROM samples.nyctaxi.trips WHERE trip_distance < 5"}'}}]}, response_metadata={'finish_reason': 'tool_calls'}, id='run-df8e6dd7-32e7-487f-9b97-721157ff0e65', tool_calls=[{'nam

In [None]:
display(Markdown("### Agent Markdown output:"))
display(Markdown(res["output"]))

### Agent Markdown output:

The square root of the count of records in the table samples.nyctaxi.trips where trip distance is less than 5 is 137.58270240113762.
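`AgentExecutor` automates the loop we ran by hand in the previous section. Below is a simplified, framework-free sketch of such a loop; it is not LangChain's actual implementation. `call_llm`, the scripted fake LLM and the fake tool outputs are hypothetical stand-ins for the Databricks endpoint and the real tools.

```python
from typing import Any, Callable, Dict, List


def run_agent(
    call_llm: Callable[[List[Dict[str, Any]]], Dict[str, Any]],
    tools: Dict[str, Callable[[Dict[str, Any]], str]],
    user_input: str,
    max_iterations: int = 8,
) -> str:
    """Minimal tool-calling loop: ask the LLM, run the requested tools, feed results back."""
    messages: List[Dict[str, Any]] = [{"role": "user", "content": user_input}]
    for _ in range(max_iterations):
        reply = call_llm(messages)
        messages.append(reply)
        tool_calls = reply.get("tool_calls") or []
        if not tool_calls:
            return reply["content"]  # no tools requested: this is the final answer
        for call in tool_calls:
            output = tools[call["name"]](call["args"])
            messages.append({"role": "tool", "tool_call_id": call["id"], "content": output})
    raise RuntimeError("Agent stopped: max_iterations reached")


def scripted_llm(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Fake LLM (hypothetical): first requests SQL, then Python, then answers."""
    tool_results = [m["content"] for m in messages if m["role"] == "tool"]
    if len(tool_results) == 0:
        return {"role": "assistant", "content": "", "tool_calls": [
            {"id": "call_1", "name": "sql_query",
             "args": {"query": "SELECT COUNT(*) FROM samples.nyctaxi.trips WHERE trip_distance < 5"}}]}
    if len(tool_results) == 1:
        return {"role": "assistant", "content": "", "tool_calls": [
            {"id": "call_2", "name": "system__ai__python_exec",
             "args": {"code": f"import math\nprint(math.sqrt({tool_results[0]}))"}}]}
    return {"role": "assistant", "content": f"The square root is {tool_results[1]}.", "tool_calls": []}


fake_tools = {
    "sql_query": lambda args: "16",                 # pretend COUNT(*) result
    "system__ai__python_exec": lambda args: "4.0",  # pretend stdout of the generated code
}

print(run_agent(scripted_llm, fake_tools, "Count trips shorter than 5 and take the square root of the count."))
```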

# Limitations / Bugs

Many LLMs cannot return text content and a function call in the same response, yet they sometimes try to, for example to comment on or explain their actions before the call. In such cases, Llama returns content like this:

**LLM's content field:**
To calculate this we need that... Now, we will use this value of B in our SQL query: <function=sql_query>{"query": "SELECT COUNT(*) as count FROM samples.nyctaxi.trips WHERE trip_distance < 0.67"}</function>

In other words, the model embeds a `<function=...>` tag in the content and treats that as a function call request. OpenAI models, for instance, can return both content and a function call in one message, so they handle this properly. The standard LangChain modules cannot parse this behavior, which is quite common. As a result, `tool_calls` stays empty and the raw tag ends up in the text shown to the user.

**Example:**


In [None]:
from langchain_core.tools import tool


@tool(parse_docstring=True)
def exec_code(code: str) -> str:
    """Execute the given Python code and return its output.

    Args:
        code: Python code to execute.
    """
    # Fake result: the code is not actually executed in this demo
    return "2310"

llm2 = langchain_chat_model.bind_tools([exec_code])
res = llm2.invoke([
    {"role":"system", "content":"You are a helpful assistant. Always use tools results and do not rely on your knowledge."},
    {"role":"user", "content":"Using python execute 12321342*23423423/22234. Discuss your decision before the function call."}
])

🤖 LLM Prompt: prompts=['System: You are a helpful assistant. Always use tools results and do not rely on your knowledge.\nHuman: Using python execute 12321342*23423423/22234. Discuss your decision before the function call.']
📜 LLM Response: response=LLMResult(generations=[[ChatGeneration(text='To answer the user\'s question, we need to execute a mathematical expression in Python. The expression is 12321342*23423423/22234. Since this involves executing Python code, the "exec_code" function is relevant here.\n\n<function=exec_code>{"code": "print(12321342*23423423/22234)"}</function>', generation_info={}, message=AIMessage(content='To answer the user\'s question, we need to execute a mathematical expression in Python. The expression is 12321342*23423423/22234. Since this involves executing Python code, the "exec_code" function is relevant here.\n\n<function=exec_code>{"code": "print(12321342*23423423/22234)"}</function>', additional_kwargs={}, response_metadata={'id': 'chatcmpl_8fc6684a-

In [None]:
display_json(vars(res))

```json
{
  "content": "To answer the user's question, we need to execute a mathematical expression in Python. The expression is 12321342*23423423/22234. Since this involves executing Python code, the \"exec_code\" function is relevant here.\n\n<function=exec_code>{\"code\": \"print(12321342*23423423/22234)\"}</function>",
  "additional_kwargs": {},
  "response_metadata": {
    "id": "chatcmpl_8fc6684a-479e-410d-8a32-7775de28289c",
    "object": "chat.completion",
    "created": 1745507539,
    "model": "meta-llama-3.1-405b-instruct-081924",
    "usage": {
      "prompt_tokens": 715,
      "completion_tokens": 75,
      "total_tokens": 790
    },
    "model_name": "meta-llama-3.1-405b-instruct-081924"
  },
  "type": "ai",
  "name": null,
  "id": "run-89305471-b2e7-4448-9091-05ff9d15ad7d-0",
  "example": false,
  "tool_calls": [],
  "invalid_tool_calls": [],
  "usage_metadata": null
}
```

In [None]:
from pprint import pprint

display(Markdown("### RAW LLM content:"))
pprint(res.content, width=200)
display(Markdown("### What the user sees (the `function` tag should not be visible):"))
display(Markdown(res.content))
display(Markdown("### LLM function call is empty (which is wrong):"))
pprint(f"{res.tool_calls=}")

### RAW LLM content:

('To answer the user\'s question, we need to execute a mathematical expression in Python. The expression is 12321342*23423423/22234. Since this involves executing Python code, the "exec_code" '
 'function is relevant here.\n'
 '\n'
 '<function=exec_code>{"code": "print(12321342*23423423/22234)"}</function>')


### What the user sees (the `function` tag should not be visible):

To answer the user's question, we need to execute a mathematical expression in Python. The expression is 12321342*23423423/22234. Since this involves executing Python code, the "exec_code" function is relevant here.

<function=exec_code>{"code": "print(12321342*23423423/22234)"}</function>

### LLM function call is empty:

'res.tool_calls=[]'
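One possible workaround is to post-process the message yourself: extract the `<function=...>` tags from `content` and turn them into tool calls. This is a sketch, not a LangChain feature. The helper is hypothetical and assumes the tag format shown above, `<function=NAME>{JSON}</function>`. In the notebook, you would call it on `res.content` when `res.tool_calls` is empty. The resulting dicts have the same `name`/`args`/`id` keys as `res.tool_calls` entries, so the same dispatch logic can handle them.

```python
import json
import re
from typing import Any, Dict, List, Tuple

# Matches Llama-style inline calls: <function=NAME>{JSON ARGS}</function>
FUNCTION_TAG = re.compile(r"<function=([\w.-]+)>(.*?)</function>", re.DOTALL)


def extract_inline_tool_calls(content: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Split LLM content into (visible text, tool calls).

    Hypothetical workaround: tags whose body is not a JSON object are left in the text.
    """
    tool_calls: List[Dict[str, Any]] = []

    def _replace(match: "re.Match") -> str:
        try:
            args = json.loads(match.group(2))
        except json.JSONDecodeError:
            return match.group(0)  # keep malformed tags untouched
        if not isinstance(args, dict):
            return match.group(0)
        tool_calls.append({
            "name": match.group(1),
            "args": args,
            "id": f"inline_call_{len(tool_calls)}",
            "type": "tool_call",
        })
        return ""  # remove the tag from the visible text

    visible_text = FUNCTION_TAG.sub(_replace, content).strip()
    return visible_text, tool_calls


content = (
    "To answer the user's question, we need to execute a mathematical expression in Python. "
    "The expression is 12321342*23423423/22234. Since this involves executing Python code, "
    'the "exec_code" function is relevant here.\n\n'
    '<function=exec_code>{"code": "print(12321342*23423423/22234)"}</function>'
)
text, calls = extract_inline_tool_calls(content)
print(text)
print(calls)
```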
