In [22]:
import logging
import torch
from dotenv import load_dotenv

from aicraft.models import Functionary
from aicraft.types import VisualisationType
from aicraft.tools.executor import SQLExecutor
from aicraft.tools.tools import ToolHandler
import pandas as pd
import json
import numpy as np


load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Choose the compute backend in preference order: CUDA, then Apple MPS,
# otherwise the CPU. MPS is only probed when CUDA is unavailable.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
device

device(type='mps')

## Testing the model

In [16]:
def user_detail(name: str, age: int):
    """
    Creates user details
    """
    # The docstring above is exposed verbatim as the tool description, so it
    # is intentionally left terse. The extracted fields are echoed back as a
    # plain record for the notebook to display.
    record = dict(name=name, age=age)
    return record

def item_detail(name: str, price: float):
    """
    Creates item details
    """
    # The docstring above is exposed verbatim as the tool description, so it
    # is intentionally left terse. The extracted fields are echoed back as a
    # plain record for the notebook to display.
    record = dict(name=name, price=price)
    return record


def get_hotness_score_for_counties_in_a_state(state_id: str, yyyymm: int) -> tuple[pd.DataFrame, str, VisualisationType]:
    """Helps in getting the hotness score for counties in a state, which can help with understanding which counties to live in. It takes the 2 character state_id (eg: NJ, OH) and the yyyymm (eg: 202208, 202406) as input and returns a dataframe of hotness scores for all the counties in that state (highest first), a chart title, and the visualisation type"""
    # The docstring above is sent to the model as the tool description, so
    # implementation notes live in comments instead.
    #
    # Both arguments arrive from model-generated JSON (json.loads of the tool
    # call's "arguments"), so they are untrusted. SQLExecutor.execute is only
    # given a finished SQL string here (no bind parameters), so the values are
    # validated/coerced before being interpolated into the query.
    # NOTE(review): if SQLExecutor supports bind parameters, prefer those.
    #
    # Raises:
    #     ValueError: if state_id is not exactly two ASCII letters, or yyyymm
    #         is not an integer of the form YYYYMM with a month in 01..12.
    if not (
        isinstance(state_id, str)
        and len(state_id) == 2
        and state_id.isascii()
        and state_id.isalpha()
    ):
        raise ValueError(f"state_id must be a 2-letter state code (eg: NJ), got {state_id!r}")

    try:
        period = int(yyyymm)
    except (TypeError, ValueError):
        raise ValueError(f"yyyymm must be an integer like 202406, got {yyyymm!r}") from None
    if not (100000 <= period <= 999999 and 1 <= period % 100 <= 12):
        raise ValueError(f"yyyymm must look like YYYYMM with a month in 01..12, got {yyyymm!r}")

    query = f"""
        SELECT
            county_name AS "County",
            hotness_score AS "Hotness"
        FROM hobu.county_market_hotness
        WHERE state_id = '{state_id}' AND
              yyyymm = {period}
        ORDER BY hotness_score DESC;
    """
    df = SQLExecutor.execute(query)
    # Index by county so the frame can be plotted directly as a bar chart.
    return df.set_index("County"), f"Hotness Scores per county in {state_id} for {period}", VisualisationType.BAR

class TestTools(ToolHandler):
    """Tool registry exposing the demo tool functions defined in this notebook."""

    def __init__(self):
        # Register each tool under its own function name. These names are what
        # get_tools() reports and what the model must emit in a tool call.
        demo_tools = (
            user_detail,
            item_detail,
            get_hotness_score_for_counties_in_a_state,
        )
        registry = {fn.__name__: fn for fn in demo_tools}
        super().__init__(registry)

In [17]:
# Instantiate the Functionary model wrapper and the tool registry, then display
# the function schemas that the chat-completion cells below pass as `tools=`.
# NOTE(review): Functionary() presumably loads model weights and may be slow;
# TODO confirm.
functionary = Functionary()
tools = TestTools()
tools.get_tools()

[{'type': 'function',
  'function': {'name': 'user_detail',
   'description': 'Creates user details',
   'parameters': {'type': 'object',
    'title': 'user_detail',
    'properties': {'name': {'title': 'Name', 'type': 'string'},
     'age': {'title': 'Age', 'type': 'integer'}},
    'required': ['name', 'age']}}},
 {'type': 'function',
  'function': {'name': 'item_detail',
   'description': 'Creates item details',
   'parameters': {'type': 'object',
    'title': 'item_detail',
    'properties': {'name': {'title': 'Name', 'type': 'string'},
     'price': {'title': 'Price', 'type': 'number'}},
    'required': ['name', 'price']}}},
 {'type': 'function',
  'function': {'name': 'get_hotness_score_for_counties_in_a_state',
   'description': 'Helps in getting the hotness score for counties in a state which can help with understanding which counties to live in. It takes the 2 character state_id (eg: NJ, OH) and the yyyymm (eg: 202208, 202406) as input and returns the dataframe of the top hotne

In [18]:
# Ask the model to pick a tool for a free-text request, then run the chosen
# tool locally with the JSON arguments the model produced.
response = functionary.model.create_chat_completion(
    messages=[
        {
            "role": "system",
            "content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary",
        },
        {
            "role": "user",
            "content": "Extract Jason is 25 years old",
        },
    ],
    tools=tools.get_tools(),
    tool_choice="auto",
)

message = response["choices"][0]["message"]
# tool_choice="auto" lets the model reply in plain text instead of calling a
# tool, in which case "tool_calls" may be missing or None. Fail with the
# model's reply rather than an opaque subscript error.
tool_calls = message.get("tool_calls")
if not tool_calls:
    raise RuntimeError(f"Model did not call a tool; it replied: {message.get('content')!r}")

# Only the first tool call is executed.
call = tool_calls[0]["function"]
func_name = call["name"]
kwargs = json.loads(call["arguments"])
result = tools.execute_tool(func_name, **kwargs)
result



{'name': 'Jason', 'age': 25}

In [19]:
# Ask the model to pick a tool for a free-text request, then run the chosen
# tool locally with the JSON arguments the model produced.
response = functionary.model.create_chat_completion(
    messages=[
        {
            "role": "system",
            "content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary",
        },
        {
            "role": "user",
            "content": "Biryani costs $20.13",
        },
    ],
    tools=tools.get_tools(),
    tool_choice="auto",
)

message = response["choices"][0]["message"]
# tool_choice="auto" lets the model reply in plain text instead of calling a
# tool, in which case "tool_calls" may be missing or None. Fail with the
# model's reply rather than an opaque subscript error.
tool_calls = message.get("tool_calls")
if not tool_calls:
    raise RuntimeError(f"Model did not call a tool; it replied: {message.get('content')!r}")

# Only the first tool call is executed.
call = tool_calls[0]["function"]
func_name = call["name"]
kwargs = json.loads(call["arguments"])
result = tools.execute_tool(func_name, **kwargs)
result

{'name': 'Biryani', 'price': 20.13}

In [20]:
# Ask the model to pick a tool for a free-text request, then run the chosen
# tool locally with the JSON arguments the model produced.
response = functionary.model.create_chat_completion(
    messages=[
        {
            "role": "system",
            "content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary",
        },
        {
            "role": "user",
            "content": "New Jersey is a good place to consider living in as it has a lot of great localities. Time period: 202407",
        },
    ],
    tools=tools.get_tools(),
    tool_choice="auto",
)

message = response["choices"][0]["message"]
# tool_choice="auto" lets the model reply in plain text instead of calling a
# tool, in which case "tool_calls" may be missing or None. Fail with the
# model's reply rather than an opaque subscript error.
tool_calls = message.get("tool_calls")
if not tool_calls:
    raise RuntimeError(f"Model did not call a tool; it replied: {message.get('content')!r}")

# Only the first tool call is executed.
call = tool_calls[0]["function"]
func_name = call["name"]
kwargs = json.loads(call["arguments"])
result = tools.execute_tool(func_name, **kwargs)
# The hotness tool returns (dataframe, title, visualisation type); show the frame.
result[0]

  df = pd.read_sql_query(sql, connection)


Unnamed: 0,County,Hotness
0,"morris, nj",94.389027
1,"gloucester, nj",93.734414
2,"burlington, nj",92.269327
3,"somerset, nj",89.650873
4,"camden, nj",89.245636
5,"passaic, nj",87.157107
6,"monmouth, nj",84.071072
7,"union, nj",84.008728
8,"middlesex, nj",82.793017
9,"warren, nj",82.107232


In [23]:
# NOTE(review): dummy random data, not used by any other visible cell.
# Presumably for prototyping a chart; TODO confirm or delete.
# A seeded generator makes Restart & Run All reproduce the same frame.
CHART_SEED = 42
chart_rng = np.random.default_rng(CHART_SEED)
chart_data = pd.DataFrame(chart_rng.standard_normal((20, 3)), columns=["a", "b", "c"])
chart_data

Unnamed: 0,a,b,c
0,0.268141,-1.637154,0.917924
1,0.718016,-0.848007,0.569307
2,-1.82198,0.574672,-1.653319
3,-1.586913,1.071389,-2.460579
4,-0.753374,0.089203,-1.048517
5,-0.748945,0.406198,-2.352635
6,-0.213915,-1.427047,-1.049536
7,1.599471,-0.0094,-0.432406
8,-1.003764,-1.359201,-1.201201
9,-0.613224,0.939954,-0.594843
