In [189]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [190]:
import os
import json
import dotenv

In [191]:
dotenv.load_dotenv()

True

In [192]:
from typing import TypedDict, Annotated, List, Dict, Literal
from typing_extensions import TypedDict
from IPython.display import display, Markdown, Image
import operator
import grandalf
import pygraphviz

In [193]:
from langchain_deepseek import ChatDeepSeek
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent

In [194]:
from langgraph.graph import START, END, StateGraph
from langgraph.types import Command, Send
from langchain_core.messages import HumanMessage, BaseMessage, SystemMessage
from langchain_experimental.utilities import PythonREPL
from langgraph.graph.message import add_messages

In [195]:
#from src.utils_email import send_report_via_email
from src.utils_trans import search_flight_details

In [196]:
def get_system_prompt(suffix: str) -> str:
    return (
        "You are a helpful AI assistant, collaborating with other assistants."
        " Use the provided tools to progress towards answering the question."
        " Always format your final answer as a JSON string."
        f"\n{suffix}"
    )

In [197]:
llm = ChatDeepSeek(model="deepseek-chat",
                   api_key=os.getenv("DEEPSEEK_API_KEY"))

In [198]:
class State(TypedDict):
    messages: Annotated[List, add_messages]
    trip_start_date: str
    trip_end_date: str
    traveler_preferences: Dict

    # Output of day_to_split_node
    split_decision: Dict  # {"should_split": bool, "rationale": string}
    segments: List[Dict]  # normalized segments

    # Per-branch context (changed to Dict, no reducer)
    current_segment: Dict

    # Map branch intermediates per segment id
    segment_pools: Annotated[List[Dict], operator.add]  # [{segment_id, activities: [...] }]
    segment_schedules: Annotated[List[Dict], operator.add]  # [{segment_id, schedule: [...]}]

    # Aggregation
    total_schedule: Dict  # {days: [...], transitions: [...]} 
    report_markdown: str

In [None]:
class State(TypedDict):
    messages: Annotated[List, add_messages]
    trip_start_date: str
    trip_end_date: str
    traveler_preferences: Dict
    start_flight: str
    return_flight: str
    

    flight_info: Dict
    total_schedule: Dict 
    report_markdown: str

### Transportation Research Agent

In [None]:
@tool
def get_flight_data(start_flight: str, return_flight: str) -> str:
    """
    Retrieve flight details for given flight numbers.
    """
    result = search_flight_details(start_flight, return_flight)
    json.dumps(result)
    return json.dumps(result)


In [200]:
#get_flight_data()

In [201]:
flight_data_agent = create_react_agent(
    llm,
    tools=[get_flight_data],
    prompt=get_system_prompt(
        "You are a travel planning assistant specializing in retrieving round-trip flight information. "
        "Use the get_flight_data tool with provided flight numbers and return its output directly as a JSON string. "
        "Do NOT include any explanatory text, comments, or non-JSON content. "
        "Return an empty JSON object {} if no flight data is available or if the tool fails."
    ),
)

### Activities (Weather focused) Agent

### Schedule Generate Agent

### Define Graph

In [202]:
def flight_data_node(state: State) -> State:
    # Extract flight numbers from last message (basic parsing, improve with regex/LLM)
    last_message = state['messages'][-1].content if state['messages'] else ""
    # Example: expect format like "flights: start=ABC123, return=DEF456"
    import re
    start_match = re.search(r'start=(\w+)', last_message)
    return_match = re.search(r'return=(\w+)', last_message)
    start_flight = start_match.group(1) if start_match else "DEFAULT_START"
    return_flight = return_match.group(1) if return_match else "DEFAULT_RETURN"
    
    # Run agent
    agent_input = {"messages": state['messages'] + [HumanMessage(content=f"Get data for {start_flight} and {return_flight}")]}
    try:
        agent_output = flight_data_agent.invoke(agent_input)
    except Exception as e:
        print(f"Agent invocation failed: {e}")
        agent_output = {"messages": [SystemMessage(content="{}")]}  # Fallback

    # Parse safely
    try:
        final_content = agent_output['messages'][-1].content
        print(f"Agent output content: {final_content}")  # Debug
        if not final_content.strip():
            print("Warning: Agent output is empty")
            flight_info = {}
        else:
            flight_info = json.loads(final_content)
    except (KeyError, IndexError) as e:
        print(f"Error accessing agent output: {e}")
        flight_info = {}
    except json.JSONDecodeError as e:
        print(f"JSON parsing failed: {e}\nContent was: {final_content}")
        flight_info = {}
    
    # Update state with extracted dates (adjust keys based on search_flight_details)
    state['flight_info'] = flight_info
    state['trip_start_date'] = flight_info.get('start_flight', {}).get('departure_date', state.get('trip_start_date', ''))
    state['trip_end_date'] = flight_info.get('return_flight', {}).get('arrival_date', state.get('trip_end_date', ''))
    state['messages'].append(SystemMessage(content=f"Flight info updated: {flight_info}"))
    return state

In [203]:
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict


def _invoke_json(system: str, user: str) -> Dict[str, Any]:
    for attempt in range(3):  # Retry up to 3 times
        messages = [SystemMessage(content=system), HumanMessage(content=user)]
        ai = llm.invoke(messages)
        content = ai.content
        try:
            parsed = json.loads(content)
            required_keys = ["segments"]  # Add validation for your nodes
            if all(key in parsed for key in required_keys):
                return parsed
        except json.JSONDecodeError:
            start = content.find("{")
            end = content.rfind("}")
            if start != -1 and end != -1:
                try:
                    return json.loads(content[start:end + 1])
                except:
                    pass
    raise ValueError("Failed to parse JSON from LLM after retries.")


def _date_range_days(start_date: str, end_date: str) -> int:
    try:
        start = datetime.fromisoformat(start_date).date()
        end = datetime.fromisoformat(end_date).date()
        return (end - start).days + 1
    except ValueError:
        raise ValueError(f"Invalid date format: {start_date} or {end_date}. Must be ISO (YYYY-MM-DD).")


def day_to_split_node(state: State) -> State:
    system = (
        "You split trips into logical city segments. Output STRICT JSON only with keys: "
        "split_decision: {should_split: bool, rationale: string}, "
        "segments: [{segment_id, location, start_date, end_date, days}]. "
        "Ensure segments fully cover the trip with no gaps/overlaps and ISO dates."
    )
    user = json.dumps({
        "trip_start_date": state.get("trip_start_date") or state.get("start_date"),
        "trip_end_date": state.get("trip_end_date") or state.get("return_date"),
        "preferences": state.get("traveler_preferences", {}),
        "message": state.get("messages", [])[-1].content if state.get("messages") else "",
    })
    result = _invoke_json(system, user)

    segments = result.get("segments", [])

    # Add fallback if no segments or no split
    if not segments or not result.get("split_decision", {}).get("should_split", True):
        default_start = state.get("trip_start_date") or state.get("start_date") or "2025-01-01"
        default_end = state.get("trip_end_date") or state.get("return_date") or "2025-01-10"
        segments = [{
            "location": "Full Trip / Default Location",
            "start_date": default_start,
            "end_date": default_end,
        }]

    normalized = []
    for seg in segments:
        seg_id = seg.get("segment_id") or str(uuid.uuid4())
        loc = seg.get("location", "Unknown")
        s = seg.get("start_date", "1970-01-01")
        e = seg.get("end_date", "1970-01-02")
        days = seg.get("days") or _date_range_days(s, e)
        normalized.append({
            "segment_id": seg_id,
            "location": loc,
            "start_date": s,
            "end_date": e,
            "days": days,
        })

    state["split_decision"] = result.get("split_decision", {})
    state["segments"] = normalized
    return state


def activities_pool_node(state: State) -> State:
    seg = state["current_segment"]  # Direct dict

    # Add check for safety
    if not seg:
        raise ValueError("No current_segment provided for activities_pool_node")

    system = (
        "Generate a diverse activity pool tailored to the city and days. "
        "Return STRICT JSON: {segment_id, activities: ["
        "{name, category, est_duration_min, opening_hours, area, cost_level, notes}]}]. "
        "Include food, culture, light hiking options based on preferences."
    )
    user = json.dumps({
        "segment": seg,
        "preferences": state.get("traveler_preferences", {}),
        "target_pool_size": 12,
    })
    result = _invoke_json(system, user)
    entry = {
        "segment_id": seg["segment_id"],
        "activities": result.get("activities", [])[:15],
    }
    return {"segment_pools": [entry]}


def daily_schedule_node(state: State) -> State:
    seg = state["current_segment"]  # Direct dict

    # Add check for safety
    if not seg:
        raise ValueError("No current_segment provided for daily_schedule_node")

    pools_by_id = {p["segment_id"]: p for p in state.get("segment_pools", [])}
    pool = pools_by_id.get(seg["segment_id"], {"activities": []})

    system = (
        "Create daily schedules from the candidate pool. Respect transit time, "
        "opening hours, and breaks. Each day ~6-9 hours of activities. "
        "Return STRICT JSON: {segment_id, schedule: ["
        "{date, items: [{start_min, end_min, type, name, notes}]}]}. "
        "Times are minutes from 00:00. Include meals and rest."
    )
    user = json.dumps({
        "segment": seg,
        "activities_pool": pool,
        "assumptions": {
            "default_open": "09:00-20:00",
            "avg_transit_min": 20,
            "lunch_window": "12:00-13:00",
            "dinner_window": "18:00-19:00",
        },
    })
    result = _invoke_json(system, user)
    entry = {
        "segment_id": seg["segment_id"],
        "schedule": result.get("schedule", []),
    }
    return {"segment_schedules": [entry]}



def total_schedule_node(state: State) -> State:
    schedules_by_id = {s["segment_id"]: s["schedule"] for s in state.get("segment_schedules", [])}
    ordered_segments = sorted(state.get("segments", []), key=lambda s: s["start_date"]) 

    days_out = []
    transitions = []
    for i, seg in enumerate(ordered_segments):
        for day in schedules_by_id.get(seg["segment_id"], []):
            days_out.append({
                "segment_id": seg["segment_id"],
                "location": seg["location"],
                "date": day.get("date"),
                "items": day.get("items", []),
            })
        if i > 0:
            prev = ordered_segments[i - 1]
            transitions.append({
                "from": prev["location"],
                "to": seg["location"],
                "date": seg["start_date"],
                "notes": "Inter-city transfer between segments",
            })

    state["total_schedule"] = {"days": days_out, "transitions": transitions}
    return state


In [204]:
def report_node(state: State) -> State:
    total = state.get("total_schedule", {"days": [], "transitions": []})
    lines = []
    lines.append(f"# Trip Plan: {state['trip_start_date']} to {state['trip_end_date']}")
    lines.append("")
    # Overview
    lines.append("## Overview")
    lines.append("- Segments:")
    for seg in sorted(state.get("segments", []), key=lambda s: s["start_date"]):
        lines.append(f"  - {seg['location']}: {seg['start_date']} → {seg['end_date']} ({seg['days']} days)")
    if total.get("transitions"):
        lines.append("- Transitions:")
        for t in total["transitions"]:
            lines.append(f"  - {t['date']}: {t['from']} → {t['to']} — {t['notes']}")
    lines.append("")

    # Daily tables
    lines.append("## Daily Schedule")
    for day in sorted(total.get("days", []), key=lambda d: d.get("date", "")):
        lines.append(f"### {day.get('date', '')} — {day.get('location', '')}")
        lines.append("| Time | Type | Name | Notes |")
        lines.append("|------|------|------|-------|")
        for item in day.get("items", []):
            start = item.get("start_min", 0)
            end = item.get("end_min", 0)
            def _hm(m):
                h = int(m // 60)
                mm = int(m % 60)
                return f"{h:02d}:{mm:02d}"
            time_str = f"{_hm(start)}–{_hm(end)}" if end else _hm(start)
            lines.append(f"| {time_str} | {item.get('type','activity')} | {item.get('name','')} | {item.get('notes','')} |")
        lines.append("")

    state["report_markdown"] = "\n".join(lines)
    return state

In [205]:
def split_router(state: State) -> List[Send]:
    if state.get("split_decision", {}).get("should_split", False) and state.get("segments"):
        return [Send("activities_pool_agent", {"current_segment": seg}) for seg in state["segments"]]  # seg is dict
    else:
        return []  # No fanout, but with fallback in day_to_split_node, this should rarely be empty

In [206]:
def join_router(state: State):
    # Proceed when we have a schedule for every segment
    num_segs = len(state.get("segments", []))
    have = len({s.get("segment_id") for s in state.get("segment_schedules", [])})
    if num_segs > 0 and have >= num_segs:
        return "total_schedule_agent"
    return "__end__"



In [207]:
workflow = StateGraph(State)


# Core nodes
workflow.add_node("flight_data_agent", flight_data_node)
workflow.add_node("day_to_split_agent", day_to_split_node)
workflow.add_node("activities_pool_agent", activities_pool_node)
workflow.add_node("daily_schedule_agent", daily_schedule_node)
workflow.add_node("total_schedule_agent", total_schedule_node)
workflow.add_node("report_agent", report_node)

# Fan-out per segment and gather back
workflow.add_edge("flight_data_agent", "day_to_split_agent")
workflow.add_conditional_edges("day_to_split_agent", split_router)
workflow.add_edge("activities_pool_agent", "daily_schedule_agent")
workflow.add_conditional_edges("daily_schedule_agent", join_router)
workflow.add_edge("total_schedule_agent", "report_agent")
workflow.add_edge("report_agent", END)

# Entry point
workflow.set_entry_point("flight_data_agent")
graph = workflow.compile()

In [208]:
# Try ASCII
try:
    graph.get_graph().print_ascii()
except ImportError as e:
    print("ASCII drawing requires grandalf. Install with: pip install grandalf\n", e)

# Try PNG
# try:
#     output_path = "graph.png"
#     graph.get_graph().draw_png(output_path)
#     print(f"PNG graph written to {output_path}")
# except ImportError as e:
#     print("PNG drawing requires pygraphviz. Install with: pip install pygraphviz\n", e)

    +-----------+      
    | __start__ |      
    +-----------+      
           *           
           *           
           *           
+-------------------+  
| flight_data_agent |  
+-------------------+  
           *           
           *           
           *           
+--------------------+ 
| day_to_split_agent | 
+--------------------+ 
           *           
           *           
           *           
      +---------+      
      | __end__ |      
      +---------+      


### Run Graph

In [209]:
input_message = HumanMessage(
    content=(
        "Plan a 10-day Japan trip arriving Kansai (Osaka). Focus on food, culture, "
        "light hiking, and efficient transit."
    )
)

In [None]:
TRIP_START = "2025-03-10"
TRIP_END = "2025-03-24"
PREFERENCES = {
    "pace": "moderate",
    "interests": ["food", "culture", "light hiking"],
    "budget": "mid",
}

#EMAIL_ADDRESS = "terence2379@gmail.com"

In [211]:
events = graph.stream(
    input={
        "messages": [input_message],
        "trip_start_date": TRIP_START,
        "trip_end_date": TRIP_END,
        "traveler_preferences": PREFERENCES,
    },
    config={"recursion_limit": 50},  # Increased from 30
    stream_mode="values",
)
final_report = None
for event in events:
    if "report_markdown" in event:
        final_report = event["report_markdown"]

if final_report:
    display(Markdown(final_report))
else:
    print("No final report produced. Check inputs or graph configuration.")

Agent output content: I understand you'd like flight data for your Japan trip planning, but I'm currently unable to retrieve the specific flight information you need. The flight data retrieval tool requires specific flight numbers to provide departure and arrival details.

For your 10-day Japan itinerary focusing on food, culture, light hiking, and efficient transit through Kansai (Osaka), you might want to:

1. Check your airline's website or booking confirmation for the specific flight numbers
2. Contact your travel agent or airline directly for the flight details
3. Use an online flight tracking service with your booking reference

Once you have the specific flight numbers, I'd be happy to help retrieve the flight data for your trip planning. In the meantime, I can certainly help with other aspects of your Japan itinerary planning if you'd like!
JSON parsing failed: Expecting value: line 1 column 1 (char 0)
Content was: I understand you'd like flight data for your Japan trip plannin

KeyError: 'current_segment'

### Send Email


In [None]:
#body = message.content

#title = llm.invoke(
#    f"Extract the subject from this report content as the email subject and return the title directly without any introductory text: {body}")

#subject = title.content

#result = send_report_via_email(subject, body, EMAIL_ADDRESS)
#print(result)