In [None]:
import sys
from pathlib import Path
import operator
from pydantic import BaseModel, Field
import typing as t
from langgraph.graph import StateGraph, START, END
from langgraph.constants import Send
from langchain_core.messages import HumanMessage, SystemMessage

sys.path.append(str(Path().cwd().parent))

from llm_experiments.chat import instantiate_chat

In [None]:
# One planned section of the report: the planner emits a list of these and each
# one is later handed to a single worker.
# NOTE(review): intentionally documented with comments, not a docstring —
# pydantic presumably copies a class docstring into the JSON schema used for
# structured output, which would change what the model is prompted with.
class Section(BaseModel):
    # The Field descriptions are runtime text (presumably surfaced to the model
    # via the structured-output schema), not merely documentation.
    name: str = Field(description="the name for this section of the report")
    description: str = Field(description="brief overview of the main topics and concepts to be covered in this section")

In [None]:
# Chat model shared by the planner (via structured output) and by the
# per-section worker calls.
# NOTE(review): "4o-mini" is an alias resolved by the project's
# instantiate_chat helper — presumably OpenAI gpt-4o-mini; confirm.
model = instantiate_chat("4o-mini")

In [None]:
# Output schema the planner must fill: the list of sections to write.
class Sections(BaseModel):
    # NOTE(review): kept docstring-free on purpose — pydantic presumably uses a
    # class docstring as the schema description sent to the model.
    sections: list[Section] = Field(description="sections of the report")


# The chat model bound to the Sections schema; orchestrator reads the
# `.sections` attribute of what planner.invoke(...) returns.
planner = model.with_structured_output(Sections)

In [None]:
class State(t.TypedDict):
    # Shared state of the whole report-writing graph.
    topic: str  # report subject, supplied in the input dict passed to invoke
    sections: list[Section]  # plan written by the orchestrator node
    # Each worker returns a one-item list; the operator.add reducer concatenates
    # them so results from parallel workers accumulate instead of overwriting.
    completed_sections: t.Annotated[list, operator.add]
    final_report: str  # all completed sections joined by the synthesizer node


def orchestrator(state: State):
    """Plan the report: ask the planner which sections to write for the topic.

    Args:
        state: graph state; only ``state["topic"]`` is read.

    Returns:
        Partial state update ``{"sections": [...]}`` holding the planned
        ``Section`` objects, which the conditional edge then fans out.
    """
    topic = state["topic"]
    planning_prompt = [
        SystemMessage(content="generate a plan for the report"),
        HumanMessage(content=f"here is the report topic: {topic}"),
    ]
    plan = planner.invoke(planning_prompt)
    return {"sections": plan.sections}

In [None]:
class WorkerState(t.TypedDict):
    # Input for one llm_call worker; matches the payload assign_workers builds.
    section: Section  # the single section this worker writes
    # Same key and operator.add reducer as State.completed_sections, so a
    # worker's one-item list is presumably concatenated into the graph state.
    completed_sections: t.Annotated[list, operator.add]


def llm_call(state: WorkerState):
    """Worker node: draft one report section with the chat model.

    Args:
        state: the payload built by ``assign_workers``, i.e. ``{"section": Section}``.

    Returns:
        ``{"completed_sections": [text]}`` — the draft is wrapped in a
        one-element list so the ``operator.add`` reducer appends it.
    """
    target = state["section"]
    instructions = SystemMessage(
        content="write a report section following the provided name and description. include no preamble for each section. use markdown formatting."
    )
    request = HumanMessage(
        content=f"here is the section name: {target.name} and description: {target.description}"
    )
    reply = model.invoke([instructions, request])
    return {"completed_sections": [reply.content]}

In [None]:
def synthesizer(state: State):
    """Stitch the workers' section drafts into a single markdown report.

    Args:
        state: graph state; only ``completed_sections`` is read. That key is
            reduced with ``operator.add``, so it holds every worker's output.

    Returns:
        Partial state update ``{"final_report": str}``.
    """
    # "---" on its own line, surrounded by blank lines, is a markdown thematic
    # break (horizontal rule). The previous "--" is not: it rendered as literal
    # dashes between the markdown-formatted sections.
    separator = "\n\n---\n\n"
    return {"final_report": separator.join(state["completed_sections"])}

In [None]:
def assign_workers(state: State):
    """Fan out one ``llm_call`` task per planned section.

    Each ``Send`` carries ``{"section": ...}`` as that worker's input, so a
    worker only sees the one section it is responsible for.
    """
    dispatches = []
    for planned in state["sections"]:
        dispatches.append(Send("llm_call", {"section": planned}))
    return dispatches

In [None]:
# Assemble the orchestrator-worker workflow over the shared State schema.
builder = StateGraph(State)

builder.add_node("orchestrator", orchestrator)
builder.add_node("llm_call", llm_call)
builder.add_node("synthesizer", synthesizer)

builder.add_edge(START, "orchestrator")
# assign_workers returns one Send per planned section, so "llm_call" runs once
# per section; the list names the node(s) the router may target.
# NOTE(review): if the planner returns zero sections no worker runs, and
# synthesizer presumably never executes, leaving final_report unset — confirm.
builder.add_conditional_edges("orchestrator", assign_workers, ["llm_call"])
# After the workers, their outputs sit in completed_sections (merged by the
# operator.add reducer) for the synthesizer to join.
builder.add_edge("llm_call", "synthesizer")
builder.add_edge("synthesizer", END)

chain = builder.compile()

In [None]:
from IPython.display import display, Image

# Render the compiled graph topology (orchestrator -> llm_call workers ->
# synthesizer) as a PNG image.
# NOTE(review): draw_mermaid_png() presumably renders through an external
# Mermaid web service by default (needs network access) — confirm for the
# installed langchain-core version.
display(Image(chain.get_graph().draw_mermaid_png()))

In [None]:
# Run the full workflow for one topic. chain.invoke returns the final State
# values (topic, planned sections, per-section drafts and the stitched report).
# Keep that dict in `result` for inspection, but display only the stitched
# report, rendered as markdown rather than as an escaped dict repr.
from IPython.display import Markdown

state = {"topic": "the impact of ai on the future of work"}

result = chain.invoke(state)
Markdown(result["final_report"])