Date: 8 Nov, 2024

In [1]:
import sys
# Add the parent directory to the import path so the `src.*` package imported
# further down resolves (assumes the notebook sits one level below the repo
# root — TODO confirm).
sys.path.append("..")

In [2]:
from dotenv import load_dotenv
# Load environment variables from the .env file one directory up. The return
# value is bound to `_` so the cell does not echo it as output.
_ = load_dotenv("../.env")

In [3]:
from typing import Union

from langchain_core.tools import tool
from pydantic import BaseModel, Field

from src.tools.location.location_matcher import LocationMatcher

### Initialization

In [4]:
# Path to the GADM CSV, relative to the notebook's working directory
# (presumably the GADM administrative-areas table — TODO confirm).
GADM_CSV_PATH = "../data/gadm.csv"
# Single module-level matcher built from that CSV; the location tool defined
# later in this notebook calls its `find_matches` method.
location_matcher = LocationMatcher(GADM_CSV_PATH)

### Location Tool

In [5]:
class LocationInput(BaseModel):
    """Input schema for location finder tool"""
    # Free-text place name to look up. This model is passed as `args_schema`
    # to the tool, so the description strings below are part of the argument
    # schema and should be edited with care.
    query: str = Field(
        description="Name of the location to search for. Can be a city, region, or country name."
    )
    # Similarity cutoff. `ge`/`le` restrict accepted values to the 0-100 range,
    # and the default of 70 mirrors the default on the tool function itself.
    threshold: int = Field(
        default=70,
        description="Minimum similarity score (0-100) to consider a match. Default is 70.",
        ge=0,
        le=100
    )


# NOTE: the @tool decorator uses this function's docstring as the tool
# description, so keep the docstring accurate and concise.
@tool("location-tool", args_schema=LocationInput, return_direct=True)
def location_tool(query: str, threshold: int = 70) -> Union[list[dict], str]:
    """Find locations and their administrative hierarchies given a place name.
    Returns matches at different administrative levels (ADM2, ADM1, ISO) with their IDs and names.

    Args:
        query (str): Location name to search for
        threshold (int, optional): Minimum similarity score (0-100). Defaults to 70.

    Returns:
        list[dict] | str: Matching locations as produced by the location matcher,
            or an error message string if the lookup fails.
    """
    # NOTE(review): the list-of-dicts return shape is taken from the recorded
    # notebook output; confirm against LocationMatcher.find_matches.
    try:
        return location_matcher.find_matches(query, threshold=threshold)
    except Exception as e:
        # Deliberate best-effort: report the failure as the tool's output
        # instead of raising, so the caller always receives a result.
        return f"Error finding locations: {e!s}"

### Test

In [6]:
from langchain_ollama import ChatOllama
from langchain_core.messages import AIMessage
from langgraph.prebuilt import ToolNode

# Chat model served through Ollama; temperature is set to 0 for this test.
llm = ChatOllama(model="mistral:instruct", temperature=0)
# The same tool list is shared by the executor node and the model binding so
# both sides agree on which tools exist.
tools = [location_tool]
# ToolNode is the component that runs the tool calls found on a message.
tool_node = ToolNode(tools)
# Model wrapper that has the tool definitions attached to its requests.
llm_with_tools = llm.bind_tools(tools)

In [16]:
# Send a natural-language request to the tool-bound model; the response is
# kept so its requested tool calls can be inspected and executed.
result = llm_with_tools.invoke("find datasets near Milan")

In [17]:
# Display the tool calls carried on the model's response.
result.tool_calls

[{'name': 'location-tool',
  'args': {'query': 'Milan', 'threshold': 70},
  'id': 'f8681f2f-af82-4286-88f9-71f586d33bf1',
  'type': 'tool_call'}]

In [18]:
# Re-wrap the model's requested tool calls in an AI message with empty text,
# then pass that single-message state to the ToolNode.
ai_message = AIMessage(content="", tool_calls=result.tool_calls)
tool_result = tool_node.invoke({"messages": [ai_message]})

In [23]:
# Show the content of the first message returned by the ToolNode invocation.
tool_result["messages"][0].content

'[{"iso": "ITA", "adm1": 10.0, "adm2": 8.0, "names": {"iso": "Italy", "adm1": "Lombardia", "adm2": "Milano"}, "score": 91, "match_type": "adm2"}, {"iso": "YEM", "adm1": 8.0, "adm2": 8.0, "names": {"iso": "Yemen", "adm1": "Al Mahwit", "adm2": "Milhan"}, "score": 91, "match_type": "adm2"}, {"iso": "DZA", "adm1": 29.0, "adm2": 16.0, "names": {"iso": "Algeria", "adm1": "Mila", "adm2": "Mila"}, "score": 89, "match_type": "adm2"}]'