In [21]:
import os 
import json

In [None]:
from langchain_deepseek import ChatDeepSeek
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent

from langchain_core.messages import HumanMessage, BaseMessage, SystemMessage
from langgraph.graph.message import add_messages

from typing import TypedDict, Annotated, List, Dict
from typing_extensions import TypedDict

In [23]:
from src.utils_trans import search_flight_details

In [24]:
# Chat model shared by the agents in this notebook. The API key is read from
# the DEEPSEEK_API_KEY environment variable (None is passed when it is unset).
llm = ChatDeepSeek(
    model="deepseek-chat",
    api_key=os.environ.get("DEEPSEEK_API_KEY"),
)

In [None]:
class State(TypedDict):
    """Shared state dictionary passed between the graph nodes of this notebook.

    Only `messages`, `start_flight`, `return_flight` and `flight_info` are
    used by the cells visible here; the remaining keys are presumably consumed
    by other nodes -- TODO confirm against the rest of the project.
    """

    # Conversation history; `add_messages` (from langgraph.graph.message) is
    # attached as the Annotated metadata for this key.
    messages: Annotated[List, add_messages]
    # Trip window; the expected date format is not shown here -- TODO confirm.
    trip_start_date: str
    trip_end_date: str
    # Free-form traveler preferences; schema not visible in this notebook.
    traveler_preferences: Dict
    # Flight identifiers read by flight_data_node and forwarded to the agent.
    start_flight: str
    return_flight: str
    

    # Written by flight_data_node with the JSON object parsed from the agent's reply.
    flight_info: Dict
    # Not populated by any visible cell; presumably filled by later nodes -- verify.
    total_schedule: Dict 
    report_markdown: str

In [26]:
def get_system_prompt(suffix: str) -> str:
    """Return the common system prompt followed by `suffix` on a new line.

    Args:
        suffix: Agent-specific instructions appended after the shared preamble.

    Returns:
        The shared preamble, a newline, then `suffix`.
    """
    preamble_parts = [
        "You are a helpful AI assistant, collaborating with other assistants.",
        " Use the provided tools to progress towards answering the question.",
        " Always format your final answer as a JSON string.",
    ]
    preamble = "".join(preamble_parts)
    return f"{preamble}\n{suffix}"

In [27]:
# Tool exposed to the agent.
#   start_flight / return_flight: flight identifiers forwarded unchanged to
#       search_flight_details (imported from src.utils_trans).
#   Returns: the lookup result serialised with json.dumps.
# NOTE: the docstring below is what @tool uses as the tool description shown to
# the model, so its wording is deliberately left unchanged.
@tool
def get_flight_data(start_flight: str, return_flight: str) -> str:
    """
    Retrieve flight details for given flight numbers.
    """
    result = search_flight_details(start_flight, return_flight)
    # Serialise exactly once; the earlier duplicate json.dumps call discarded its value.
    return json.dumps(result)

In [28]:
# Agent-specific instructions appended to the shared system prompt.
flight_agent_instructions = (
    "You are a travel planning assistant specializing in retrieving round-trip flight information. "
    "Use the get_flight_data tool with provided flight numbers and return its output directly as a JSON string. "
    "Do NOT include any explanatory text, comments, or non-JSON content. "
    "Return an empty JSON object {} if no flight data is available or if the tool fails."
)

# ReAct-style agent whose only tool is get_flight_data.
flight_data_agent = create_react_agent(
    llm,
    tools=[get_flight_data],
    prompt=get_system_prompt(flight_agent_instructions),
)

In [29]:
def _parse_flight_json(raw) -> dict:
    """Best-effort conversion of the agent's final reply into a dict.

    Tries the reply as-is first, then the span between the first '{' and the
    last '}' (covers Markdown code fences or stray prose around the JSON).
    Returns {} -- the same "no data" value the agent prompt asks for -- and
    emits a warning when nothing parses to a JSON object.
    """
    import warnings  # local import: only this helper needs it

    if not isinstance(raw, str):
        warnings.warn(
            f"Agent reply is not a string ({type(raw).__name__}); using empty flight info.",
            stacklevel=2,
        )
        return {}

    text = raw.strip()
    candidates = [text]
    first, last = text.find("{"), text.rfind("}")
    if 0 <= first < last:
        candidates.append(text[first:last + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed

    warnings.warn(
        "Agent reply could not be parsed as a JSON object; using empty flight info.",
        stacklevel=2,
    )
    return {}


def flight_data_node(state: State) -> State:
    """Graph node: ask the flight agent for both legs and store the parsed result.

    Reads `start_flight`, `return_flight` and `messages` from `state`, invokes
    `flight_data_agent` with the existing messages plus one request message,
    and parses the content of the agent's last message as JSON.

    Args:
        state: Current graph state; must contain `start_flight`,
            `return_flight` and `messages`.

    Returns:
        A new state dict (the input is not mutated) carrying every key of
        `state`, with `flight_info` set to the parsed dict ({} when the reply
        is not a JSON object) and `messages` extended by a SystemMessage
        describing the update.

    Raises:
        KeyError: if a required key is missing from `state`.
    """
    start_flight = state['start_flight']
    return_flight = state['return_flight']

    # Run the agent on the conversation so far plus the lookup request.
    request = HumanMessage(content=f"Get data for {start_flight} and {return_flight}")
    agent_output = flight_data_agent.invoke({"messages": [*state['messages'], request]})

    # Parse safely: tolerate fenced or otherwise wrapped JSON instead of crashing.
    flight_info = _parse_flight_json(agent_output['messages'][-1].content)

    # Build a fresh state rather than mutating the caller's dict and message list.
    note = SystemMessage(content=f"Flight info updated: {flight_info}")
    return {
        **state,
        "flight_info": flight_info,
        "messages": [*state['messages'], note],
    }

In [30]:
# Live smoke test: drives the real agent and tool end to end (nothing is stubbed).
import json
import os

if os.getenv("DEEPSEEK_API_KEY"):
    # Build the agent again so this cell uses the real one even if an earlier
    # cell replaced `flight_data_agent` with a stub.
    flight_data_agent = create_react_agent(
        llm,
        tools=[get_flight_data],
        prompt=get_system_prompt(
            "You are a travel planning assistant specializing in retrieving round-trip flight information. "
            "Use the get_flight_data tool with provided flight numbers and return its output directly as a JSON string. "
            "Do NOT include any explanatory text, comments, or non-JSON content. "
            "Return an empty JSON object {} if no flight data is available or if the tool fails."
        ),
    )

    # Flight numbers default to UO870 / UO871; override through the environment.
    start_flight = os.getenv("TEST_START_FLIGHT", "UO870")
    return_flight = os.getenv("TEST_RETURN_FLIGHT", "UO871")

    # Only the keys flight_data_node reads are supplied.
    live_state = {
        "start_flight": start_flight,
        "return_flight": return_flight,
        "messages": [],
    }

    try:
        live_updated = flight_data_node(live_state)
        print("flight_info:")
        print(json.dumps(live_updated.get("flight_info", {}), indent=2))

        # Sanity check: the node must have produced a dict.
        live_flight_info = live_updated.get("flight_info")
        assert isinstance(live_flight_info, dict)

        # A non-empty result is expected to describe both legs with these fields.
        if live_flight_info:
            assert "start_flight" in live_flight_info, "Missing start_flight in flight_info"
            assert "return_flight" in live_flight_info, "Missing return_flight in flight_info"
            required_fields = (
                "departure_airport",
                "arrival_airport",
                "departure_time",
                "estimated_arrival_time",
            )
            for leg_key in ("start_flight", "return_flight"):
                leg = live_flight_info[leg_key]
                for field in required_fields:
                    assert field in leg, f"Missing {field} in {leg_key}"
        print("OK: live flight_data_node test passed")
    except Exception as e:
        # Report the failure, then re-raise so the cell still errors visibly.
        print("Live test failed:", repr(e))
        raise
else:
    # No API key available: skip instead of failing.
    print("Skipping live test: DEEPSEEK_API_KEY is not set.")


flight_info:
{
  "start_flight": {
    "departure_airport": "Hong Kong International Airport",
    "arrival_airport": "Tokyo Narita International Airport",
    "departure_region": "Hong Kong",
    "arrival_region": "Tokyo",
    "departure_time": "10:46",
    "estimated_arrival_time": "14:55"
  },
  "return_flight": {
    "departure_airport": "Tokyo Narita International Airport",
    "arrival_airport": "Hong Kong International Airport",
    "departure_region": "Tokyo",
    "arrival_region": "Hong Kong",
    "departure_time": "16:49",
    "estimated_arrival_time": "21:03"
  }
}
OK: live flight_data_node test passed
