In [1]:
cd ..

/Users/suriya/projects/dspy


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
%load_ext autoreload
%autoreload 2

## Tool Support Example

[tool-support](https://ollama.com/blog/tool-support)

In [4]:
import ollama

response = ollama.chat(
    model='llama3.1',
    messages=[{'role': 'user', 'content': 
        'What is the weather in Toronto?'}],

		# provide a weather checking tool to the model
    tools=[{
      'type': 'function',
      'function': {
        'name': 'get_current_weather',
        'description': 'Get the current weather for a city',
        'parameters': {
          'type': 'object',
          'properties': {
            'city': {
              'type': 'string',
              'description': 'The name of the city',
            },
          },
          'required': ['city'],
        },
      },
    },
  ],
)

print(response['message']['tool_calls'])

[{'function': {'name': 'get_current_weather', 'arguments': {'city': 'Toronto'}}}]


# A real example

## Get data

In [8]:
import yfinance as yf
TICKER = "AAPL"

def getHistoricalPrice(start: str, end: str):
    df = yf.download(TICKER, start, end)
    return df["Close"]


def computeStatistic(start: str, end: str, statistic: str="mean"):
    data = getHistoricalPrice(start, end)
    return {
        "mean": data.mean(),
        "max": data.max(),
        "min": data.min()
    }[statistic].item()

start, end = "2024-08-01", "2024-08-31"
computeStatistic(start, end, statistic="min")

[*********************100%***********************]  1 of 1 completed


207.22999572753906

## Describe the function

In [18]:
tools = [{
        'type': 'function',
        'function': {
            'name': 'getHistoricalPrice',
            'description': '''
            Get historical stock price given start and end dates.
            Example: Give me last month\'s stock price data.
            '''
            ,
            'parameters': {
            'type': 'object',
            'properties': {
                'start': {
                'type': 'string',
                'description': 'Start Date in YYYY-MM-DD format',
                },
                'end': {
                'type': 'string',
                'description': 'End Date in YYYY-MM-DD format',
                },
            },
            'required': ['start', 'end'],
            },
        },
        },
        {
        'type': 'function',
        'function': {
            'name': 'computeStatistic',
            'description': '''
            Given start and end dates compute a summary statistic represented by statistic argument.
            Example: "What is the maximum stock price this year?".
            '''
            ,
            'parameters': {
            'type': 'object',
            'properties': {
                'start': {
                'type': 'string',
                'description': 'Start Date in YYYY-MM-DD format',
                },
                'end': {
                'type': 'string',
                'description': 'End Date in YYYY-MM-DD format',
                },
                'statistic': {
                'type': 'string',
                'description': 'Summary statistic to compute. Options include "min", "max", and, "mean".',
                },
            },
            'required': ['start', 'end', 'statistic'],
            },
        },
        },
    ]


In [14]:
from datetime import datetime
DATE_FORMAT = "%Y-%m-%d"


def today():
    return datetime.now().strftime(DATE_FORMAT)

In [21]:
import ollama


def generate_response(query):
    response = ollama.chat(
        model='llama3.1',
        messages=[{"role": "system", "content": f"today's date is {today()}"}, {'role': 'user', 'content': query}],
        tools=tools
    )
    content = response["message"]["content"]
    if content:
        return content
    return response["message"]["tool_calls"]

queries = [
    "Give me this year's stock price data.",
    "Lowest price in 2023",
]
for i,q in enumerate(queries):
    print(i+1, ". ", q)
    print(generate_response(q))
    print("---")


1 .  Give me this year's stock price data.
[{'function': {'name': 'getHistoricalPrice', 'arguments': {'end': '2024-09-02', 'start': '2024-01-01'}}}]
---
2 .  Lowest price in 2023
[{'function': {'name': 'computeStatistic', 'arguments': {'end': '2023-12-31', 'start': '2023-01-01', 'statistic': 'min'}}}]
---


That worked out great for us. 
Now, let's put it into action.

## Function calling

In [24]:
import json


def call_function(fn):
    args = fn["function"]["arguments"]
    name = fn["function"]["name"]

    return {
        "getHistoricalPrice": getHistoricalPrice,
        "computeStatistic": computeStatistic
    }[name](**args)
    

### Rewrite `generate_response` to accomodate function calling

In [27]:
def generate_response(query):
    response = ollama.chat(
        model='llama3.1',
        messages=[{"role": "system", "content": f"today's date is {today()}"}, {'role': 'user', 'content': query}],
        tools=tools
    )
    content = response["message"]["content"]
    if content:
        return content
    fn = response["message"]["tool_calls"][0]
    print(fn)
    return call_function(fn)


In [28]:
queries = [
    "Give me last year's stock price data.",
    "Lowest price in 2023",
]
for i,q in enumerate(queries):
    print(i+1, ". ", q)
    print(generate_response(q))
    print("---")

1 .  Give me last year's stock price data.


[*********************100%***********************]  1 of 1 completed

{'function': {'name': 'getHistoricalPrice', 'arguments': {'end': '2023-09-01', 'start': '2022-09-02'}}}
Date
2022-09-02    155.809998
2022-09-06    154.529999
2022-09-07    155.960007
2022-09-08    154.460007
2022-09-09    157.369995
                 ...    
2023-08-25    178.610001
2023-08-28    180.190002
2023-08-29    184.119995
2023-08-30    187.649994
2023-08-31    187.869995
Name: Close, Length: 250, dtype: float64
---
2 .  Lowest price in 2023



[*********************100%***********************]  1 of 1 completed

{'function': {'name': 'computeStatistic', 'arguments': {'end': '2023-12-31', 'start': '2023-01-01', 'statistic': 'min'}}}
125.0199966430664
---



