In [None]:
import sys
import typing as t
import textwrap
from pathlib import Path
from pydantic import BaseModel, Field
from langgraph.graph import StateGraph, START, END

sys.path.append(str(Path().cwd().parent))

from llm_experiments.models import instantiate_chat

In [None]:
# Graph state shared by every node. The caller provides only `purpose`; each
# node returns a partial dict of the keys it produced.
#   purpose  - what the generated code should accomplish
#   code     - latest code produced by write_code
#   review   - latest feedback produced by review_code
#   approved - review verdict as the *string* "true" or "false" (not a bool)
State = t.TypedDict(
    "State",
    {
        "purpose": str,
        "code": str,
        "review": str,
        "approved": t.Literal["true", "false"],
    },
)

In [None]:
class Code(BaseModel):
    # Structured-output schema for the coder model (bound via
    # `model.with_structured_output(Code)`); write_code reads `.code` from the
    # parsed reply.
    # NOTE(review): intentionally no class docstring -- LangChain presumably
    # sends a model docstring to the LLM as the schema/tool description, so
    # adding one would change the request. TODO confirm before adding.
    code: str = Field(description="the code to be written")

In [None]:
class Review(BaseModel):
    # Structured-output schema for the reviewer model; review_code copies both
    # fields into the graph state.
    # Field order is deliberate: structured outputs are generated in schema key
    # order (per OpenAI's Structured Outputs docs), so listing `review` first
    # makes the model write its critique *before* committing to a verdict,
    # rather than deciding `approved` first and rationalising afterwards.
    # NOTE(review): no class docstring on purpose -- it would presumably be sent
    # to the LLM as the schema description.
    review: str = Field(
        description="feedback on the code, including concrete changes required if it is not approved"
    )
    # Kept as the strings "true"/"false" (not bool) to match State's `approved`
    # key and the values router() matches on.
    approved: t.Literal["true", "false"] = Field(description="true if the code is approved, false otherwise")

In [None]:
# One chat model is shared by both graph nodes. Each node gets its own
# structured-output wrapper, so replies come back parsed as Code / Review.
MODEL_NAME = "4o-mini"
model = instantiate_chat(MODEL_NAME)
coder, reviewer = model.with_structured_output(Code), model.with_structured_output(Review)

In [None]:
def write_code(state: State):
    """Graph node: ask the coder model for code fulfilling ``state["purpose"]``.

    On the first pass only the purpose is known. When the graph loops back here
    after a rejected review (``state["approved"] == "false"``), the reviewer's
    feedback *and* the previous code are added to the prompt. That way the model
    revises the code the review is actually about instead of starting blind.

    Args:
        state: current graph state. ``purpose`` is required; ``code``,
            ``review`` and ``approved`` exist only after a review pass.

    Returns:
        Partial state update ``{"code": <generated code>}``.
    """
    # Dump the incoming state so every iteration of the write/review loop is
    # visible in the notebook output.
    for key, value in state.items():
        print(f" {key} ".center(88, "="))
        print(value)
        print()

    prompt_lines = [
        "You are a senior software engineer writing code.",
        f"The purpose of the code is to {state['purpose']}.",
    ]
    # Compare against the literal "false": the string "false" is truthy, so a
    # bare truthiness check cannot tell an approval from a rejection.
    if state.get("approved") == "false":
        prompt_lines += [
            "",
            "A previous version of this code was rejected in review. "
            "Revise it, taking the following review into account.",
            "",
            "Review:",
            state.get("review", ""),
            "",
            "Previous code:",
            state.get("code", ""),
        ]
    # Join plain lines rather than dedenting an interpolated f-string: a
    # multi-line review or code block would defeat textwrap.dedent and leave
    # the template lines indented.
    role = "\n".join(prompt_lines)

    messages = [
        {"role": "system", "content": role},
    ]
    res = coder.invoke(messages)
    return {"code": res.code}

In [None]:
def review_code(state: State):
    """Graph node: have the reviewer model judge ``state["code"]`` against its purpose.

    The purpose goes into the system prompt; the code under review is sent as
    the user message.

    Args:
        state: current graph state; must contain ``purpose`` and ``code``.

    Returns:
        Partial state update ``{"approved": "true" | "false", "review": <feedback>}``.
    """
    # Join plain lines rather than dedenting an interpolated f-string: a
    # multi-line purpose would otherwise defeat textwrap.dedent and leave the
    # template lines indented.
    role = "\n".join(
        [
            "You are a senior software engineer reviewing a piece of code.",
            f"The purpose of the code is to {state['purpose']}.",
            'If the code is not good enough, return "false" in the approved field. Otherwise, return "true".',
            # write_code feeds this review back to the coder, so ask for feedback it can act on.
            "In the review field, give concrete, actionable feedback the author can apply in the next revision.",
        ]
    )
    messages = [
        {"role": "system", "content": role},
        {"role": "user", "content": state["code"]},
    ]
    res = reviewer.invoke(messages)
    return {"approved": res.approved, "review": res.review}

In [None]:
def router(state: State):
    """Translate the reviewer's verdict into a branch label for the conditional edge.

    Returns "approved" for "true" and "needs_improvement" for "false"; any
    other value raises ValueError.
    """
    verdict = state["approved"]
    if verdict == "true":
        return "approved"
    if verdict == "false":
        return "needs_improvement"
    raise ValueError(f"Invalid approval status: {state['approved']}")

In [None]:
builder = StateGraph(State)

# The two worker nodes: one writes code, the other reviews it.
graph_nodes = {"write_code": write_code, "review_code": review_code}
for node_name, node_fn in graph_nodes.items():
    builder.add_node(node_name, node_fn)

# Straight-line part of the flow: START -> write_code -> review_code.
builder.add_edge(START, "write_code")
builder.add_edge("write_code", "review_code")

# Feedback loop: router() labels the verdict; "approved" finishes the run and
# "needs_improvement" sends the state back to write_code for another pass.
review_routes = {"approved": END, "needs_improvement": "write_code"}
builder.add_conditional_edges("review_code", router, review_routes)

graph = builder.compile()

In [None]:
from IPython.display import Image, display

# Render the compiled graph as a Mermaid PNG to sanity-check the wiring.
# NOTE(review): draw_mermaid_png() presumably renders through a remote Mermaid
# service by default, so this cell may need network access -- TODO confirm.
graph_png = graph.get_graph().draw_mermaid_png()
display(Image(graph_png))

In [None]:
# Run the write/review loop for a sample task; the final graph state is kept in `res`.
initial_state = {"purpose": "save all pages of a kindle book to a pdf file"}
res = graph.invoke(initial_state)

In [None]:
# Show the last version of the code the loop produced.
final_code = res["code"]
print(final_code)