In [None]:
import sys
from pathlib import Path
import typing as t
from langgraph.graph import StateGraph, START, END
import textwrap
from IPython.display import Image

# Make the parent directory importable so `llm_experiments` (imported below)
# resolves when the notebook runs from its own folder. `Path.cwd()` is a
# classmethod, so no throwaway `Path()` instance is needed. The membership
# guard keeps the cell idempotent: re-running it does not keep appending
# duplicate entries to `sys.path`.
PROJECT_ROOT = str(Path.cwd().parent)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from llm_experiments.chat import instantiate_chat

In [None]:
# Chat model shared by every node of the debate graph (each node calls
# `model.invoke`).
# NOTE(review): the alias is passed straight to the project helper
# `instantiate_chat`; presumably it selects a GPT-4o-mini deployment — confirm.
MODEL_NAME = "4o-mini"
model = instantiate_chat(MODEL_NAME)

In [None]:
class State(t.TypedDict):
    """State shared by the nodes of the debate graph.

    The caller supplies only ``topic``; each graph node fills in one of the
    remaining keys. Field order is kept stable because the final cell prints
    the keys in this order.
    """

    # Subject of the debate, given in the initial `chain.invoke` input.
    topic: str
    # Filled in by `generate_liberal_opinion`.
    liberal_opinion: str
    # Filled in by `generate_conservative_opinion`.
    conservative_opinion: str
    # Filled in by `find_a_compromise` from the two opinions above.
    compromised_opinion: str

In [None]:
def _generate_opinion(state: State, leaning: str) -> str:
    """Ask the chat model for a partisan opinion on ``state["topic"]``.

    Args:
        state: Graph state; only the ``"topic"`` key is read.
        leaning: Political leaning written into the prompt, e.g. ``"liberal"``.

    Returns:
        The text content of the model's reply.
    """
    # Dedent the template *before* interpolating. The topic is free text; if
    # it spans several lines, interpolating first (as an f-string does) changes
    # the common margin `textwrap.dedent` strips and leaves the prompt indented.
    template = textwrap.dedent(
        """
        you are a partisan of a debate.
        you are given a topic and you need to say a {leaning} opinion on the topic.
        topic: {topic}
        """
    )
    msg = model.invoke(template.format(leaning=leaning, topic=state["topic"]))
    return msg.content


def generate_liberal_opinion(state: State):
    """Graph node: write a liberal take on the topic to ``liberal_opinion``."""
    return {"liberal_opinion": _generate_opinion(state, "liberal")}


def generate_conservative_opinion(state: State):
    """Graph node: write a conservative take on the topic to ``conservative_opinion``."""
    return {"conservative_opinion": _generate_opinion(state, "conservative")}


def find_a_compromise(state: State):
    """Graph node: ask the model for a compromise between the two opinions.

    Reads ``topic``, ``liberal_opinion`` and ``conservative_opinion`` from
    ``state`` and returns the model's reply under ``compromised_opinion``.
    """
    # The opinions are free-form model output and typically span several
    # lines that start at column 0. Interpolating them before
    # `textwrap.dedent` (as the f-string version did) makes the common margin
    # empty, so nothing is dedented and the instructions are sent with their
    # source-code indentation. Dedent the template first, then fill it in.
    template = textwrap.dedent(
        """
        you are a political analyst.
        you are given two opinions on a topic and you need to find a compromise between the two.
        topic: {topic}
        liberal opinion: {liberal_opinion}
        conservative opinion: {conservative_opinion}
        """
    )
    prompt = template.format(
        topic=state["topic"],
        liberal_opinion=state["liberal_opinion"],
        conservative_opinion=state["conservative_opinion"],
    )
    msg = model.invoke(prompt)
    return {"compromised_opinion": msg.content}


def _build_debate_graph() -> StateGraph:
    """Wire the debate workflow.

    Both partisan nodes start from START, both feed `find_a_compromise`, and
    that node leads to END. Nodes and edges are registered in a fixed order.
    """
    builder = StateGraph(State)

    node_table = {
        "generate_liberal_opinion": generate_liberal_opinion,
        "generate_conservative_opinion": generate_conservative_opinion,
        "find_a_compromise": find_a_compromise,
    }
    for node_name, node_fn in node_table.items():
        builder.add_node(node_name, node_fn)

    partisans = ("generate_liberal_opinion", "generate_conservative_opinion")
    # Fan out: each partisan node is entered directly from START ...
    for partisan in partisans:
        builder.add_edge(START, partisan)
    # ... and both flow into the compromise step.
    # NOTE(review): this relies on both partisan nodes finishing in the same
    # step so `find_a_compromise` sees both opinions; a single
    # `add_edge([...], "find_a_compromise")` would make the join explicit —
    # confirm against the LangGraph docs before changing.
    for partisan in partisans:
        builder.add_edge(partisan, "find_a_compromise")
    builder.add_edge("find_a_compromise", END)

    return builder


graph = _build_debate_graph()
chain = graph.compile()

In [None]:
# Render the compiled graph as a Mermaid diagram. As the cell's last
# expression, the `Image` (imported in the first cell) is shown through
# Jupyter's rich display, so no mid-notebook `display` import is needed.
# NOTE(review): `draw_mermaid_png()` may render via a remote service by
# default — confirm if the notebook must run offline.
Image(chain.get_graph().draw_mermaid_png())

In [None]:
# Run the whole debate once and keep the final state, so later cells can
# reuse it without re-invoking the model.
debate = chain.invoke({"topic": "migrants"})
for field, text in debate.items():
    # State key on one line, its value on the next, then a blank separator line.
    print(field, text, sep="\n", end="\n\n")