Date: 8 Nov, 2024

In [None]:
import sys

# Put the parent directory on the import path so the `src.` package imported
# in a later cell can be resolved.
# NOTE(review): ".." is relative to the current working directory -- this
# assumes the notebook is run from its own folder; confirm.
sys.path.append("..")

In [None]:
from dotenv import load_dotenv

# Load settings from the project's ../.env file. The result is bound to `_`
# because it is not used (presumably also to keep the cell output quiet).
_ = load_dotenv("../.env")

In [None]:
from typing import Union

from langchain_core.tools import tool
from pydantic import BaseModel, Field

from src.tools.location.location_matcher import LocationMatcher

### Inits

In [None]:
# Path of the GADM CSV handed to LocationMatcher; relative, so it depends on
# the directory the notebook is run from.
GADM_CSV_PATH = "../data/gadm.csv"
# Single shared matcher instance; `location_tool` calls its `find_matches`.
location_matcher = LocationMatcher(GADM_CSV_PATH)

### Location Tool

In [None]:
class LocationInput(BaseModel):
    """Input schema for location finder tool"""

    # NOTE(review): the class docstring and the `description` texts below are
    # presumably surfaced in the tool schema shown to the model, so they are
    # treated as runtime strings and must stay word-for-word stable.

    # Required free-text place name (`...` marks the field as having no default).
    query: str = Field(
        ...,
        description=(
            "Name of the location to search for. "
            "Can be a city, region, or country name."
        ),
    )

    # Optional similarity cut-off; validated to lie within 0..100 inclusive.
    threshold: int = Field(
        70,
        ge=0,
        le=100,
        description=(
            "Minimum similarity score (0-100) to consider a match. "
            "Default is 70."
        ),
    )


@tool("location-tool", args_schema=LocationInput, return_direct=True)
def location_tool(query: str, threshold: int = 70) -> Union[dict, str]:
    """Find locations and their administrative hierarchies given a place name.

    Returns matches at different administrative levels (ADM2, ADM1, ISO) with their IDs and names.

    Args:
        query (str): Location name to search for
        threshold (int, optional): Minimum similarity score. Defaults to 70.

    Returns:
        dict | str: matching locations, or an "Error finding locations: ..."
            message string if the lookup fails
    """
    # NOTE(review): the docstring above presumably doubles as the tool
    # description given to the model, so keep it short and user-facing.
    try:
        return location_matcher.find_matches(query, threshold=threshold)
    except Exception as e:
        # Deliberately broad: this is the tool boundary, and a failure is
        # reported back as text instead of raising. The return annotation
        # therefore includes `str` alongside `dict`.
        return f"Error finding locations: {e}"

### Test

In [None]:
# Smoke test: call the tool directly. Only `query` is supplied, so `threshold`
# falls back to its default of 70.
location_tool.invoke(input={"query": "lisbon portugal"})

In [None]:
from langchain_core.messages import AIMessage
from langchain_ollama import ChatOllama
from langgraph.prebuilt import ToolNode

# Chat model accessed through Ollama. temperature=0 is presumably chosen to
# make the tool-calling output as repeatable as possible.
llm = ChatOllama(model="mistral:instruct", temperature=0)
# The only tool exposed in this experiment.
tools = [location_tool]
# Node used further down to execute the tool calls carried by an AIMessage.
tool_node = ToolNode(tools)
# Same model, but with the tool list bound so it can request tool calls.
llm_with_tools = llm.bind_tools(tools)

In [None]:
# Send a free-text request to the tool-enabled model. The expectation
# (to be checked in the next cell) is that it asks for `location-tool`.
result = llm_with_tools.invoke("find datasets near Milan")

In [None]:
# Inspect which tool calls (if any) the model requested for the prompt above.
result.tool_calls

In [None]:
# Run the requested tool calls: wrap them in a fresh, content-less AIMessage
# and pass that to the ToolNode as the only entry under "messages".
tool_result = tool_node.invoke(
    {"messages": [AIMessage(content="", tool_calls=result.tool_calls)]}
)

In [None]:
# Show the content of the first message returned by the ToolNode, i.e. what
# the tool run produced (assumes at least one tool call was made -- otherwise
# this lookup may fail; confirm).
tool_result["messages"][0].content