In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from typing import TypedDict, Annotated, Literal
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.types import Command, Send
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode
from langchain_core.tools import tool, InjectedToolCallId
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage
import operator
import json

In [3]:
# Graph state shared by all nodes. Keys annotated with a reducer
# (operator.add) have their list updates concatenated onto the existing value;
# the other keys are overwritten by whichever node writes them.
State = TypedDict(
    "State",
    {
        # Conversation history; updates are appended via list concatenation.
        "messages": Annotated[list, operator.add],
        # Tool calls not yet executed (replaced wholesale on update).
        "pending_tool_calls": list[dict],
        # Per-tool outcome records; updates are appended via list concatenation.
        "tool_results": Annotated[list[dict], operator.add],
        # Flag for routing to a human-review step.
        "needs_human_review": bool,
        # Final text answer produced by the graph.
        "final_answer": str,
    },
)

## Tool calls with `InjectedToolCallId`

In [15]:
import yfinance as yf
from datetime import datetime
import logging

# Module-level logger used by the stock-price tool to report fetch/parse problems.
logger = logging.getLogger(name=__name__)
# Attach a default root handler at INFO level so those messages are printed
# (no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)

# Tool the LLM can call to fetch historical prices from Yahoo Finance.
#
# `@tool` uses the docstring below as the description shown to the model, so
# implementation notes are kept in these comments instead:
# - `tool_call_id` is injected at call time (InjectedToolCallId) and is not part
#   of the model-facing schema; it ties the returned ToolMessage to its call.
# - Returns a `Command` that appends a ToolMessage to `messages` and:
#   * on success, also appends a `tool_results` record and routes to
#     "analyze_results";
#   * if the download raises or yields no rows (yfinance may return an empty
#     frame rather than raise for an unknown ticker), routes to
#     "handle_no_results" with an error ToolMessage.
@tool
def search_for_stock_price(
    ticker: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
    start_date: Annotated[str | None, "Date time in YYYY-MM-DD format"] = None,
    end_date: Annotated[str | None, "Date time in YYYY-MM-DD format"] = None,
) -> Command:
    """Searches Yahoo Finance for historical stock prices.

    Args:
        ticker (str): Ticker of interest. For example "META".
        start_date (str | None): Start of the date range, in YYYY-MM-DD format.
        end_date (str | None): End of the date range, in YYYY-MM-DD format.
            Both dates must be given to restrict the range; otherwise the full
            available history is returned.
    """
    date_format = "%Y-%m-%d"
    # Must match the function name, which `@tool` uses as the tool's name.
    tool_name = "search_for_stock_price"

    try:
        stock_ticker = yf.Ticker(ticker)
        history = None
        if start_date and end_date:
            try:
                start = datetime.strptime(start_date, date_format)
                end = datetime.strptime(end_date, date_format)
            except ValueError as e:
                # Only unparseable dates fall back to the full history; fetch
                # errors propagate to the outer handler instead of being masked.
                logger.error("Faced the error: %s. Getting data from max period", e)
            else:
                history = stock_ticker.history(start=start, end=end)
        if history is None:
            history = stock_ticker.history(period="max")
        # An empty frame means nothing was found; treat it like a failed lookup
        # rather than reporting an empty "success" to the analysis step.
        if history.empty:
            raise ValueError("no price data was returned")
        data = history.to_json()
    except Exception as e:
        msg = f"Unable to find the info from {ticker} due to {e}. It is likely that the ticker is incorrect"
        logger.error(msg)
        error_msg = ToolMessage(
            content=msg,
            tool_call_id=tool_call_id,
            name=tool_name,
            status="error",
        )
        return Command(
            update={"messages": [error_msg]},
            goto="handle_no_results",
        )

    tool_msg = ToolMessage(
        content=data,
        tool_call_id=tool_call_id,
        name=tool_name,
    )
    return Command(
        update={
            "messages": [tool_msg],
            "tool_results": [{
                "tool": tool_name,
                "status": "success",
                "results": data,
            }],
        },
        goto="analyze_results",
    )
    

In [None]:
def handle_no_results(state: State) -> Command[Literal["aggregate_results"]]:
    """Handles cases where a tool found no results, then continues to aggregation.

    In this notebook the stock-price tool routes here after appending an
    explanatory ToolMessage to ``messages``. This node therefore does not add
    another message. It records the failure in ``tool_results`` (with
    ``status="no_results"``) so the aggregation step can see it.

    Args:
        state: Current graph state.

    Returns:
        Command: appends one ``tool_results`` record and routes to
        ``aggregate_results``.
    """
    # NOTE(review): recording the outcome in tool_results (rather than adding a
    # new message) is an assumption about what aggregate_results consumes —
    # confirm once that node exists.
    messages = state.get("messages") or []
    last_message = messages[-1] if messages else None
    record = {
        # Taken from the failing tool's message when available.
        "tool": getattr(last_message, "name", None),
        "status": "no_results",
        "results": getattr(last_message, "content", ""),
    }
    return Command(
        update={"tool_results": [record]},
        goto="aggregate_results",
    )