In [1]:
import json
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from langchain_community.tools.tavily_search import TavilySearchResults

from langgraph.graph import StateGraph, MessagesState, START, END, MessageGraph
from IPython.display import display, Image
from typing import TypedDict, Annotated, Literal
from pydantic import BaseModel, Field
import operator
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode

# Load environment variables from the .env file one directory above this notebook.
# NOTE(review): presumably this supplies the Groq API key — confirm.
load_dotenv("../.env")
# Groq-hosted chat model; joke_finder calls llm.invoke() on it.
llm = ChatGroq(
    # Alternative model id, kept for reference:
    # model="llama3-groq-70b-8192-tool-use-preview",
    model="llama-3.1-70b-versatile",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    # other params...
)
# Embedding model is currently disabled; nothing in the visible cells uses it.
#embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

# Checkpointer and run config. Neither is referenced by the visible cells:
# the graph is compiled with a bare graph.compile(), so no checkpointer is
# attached there.
memory = MemorySaver()
config = {"configurable": {"thread_id": 1}}

In [2]:
from langchain_openai import ChatOpenAI
# Second chat client, pointed at a custom base_url instead of the default.
# NOTE(review): the host looks like a machine on a local network and the
# api_key is a placeholder string — presumably the server ignores it; confirm.
# In the visible cells it is only referenced from a commented-out line in
# joke_finder.
chatModel = ChatOpenAI(base_url="http://ai.mtcl.lan:11436/v1", api_key="fake_api_key", model="llama3.1")

In [16]:
# Shared counter: joke_finder increments it, router_node1_to_node2 reads it to
# decide when the loop stops. Nothing else in the visible cells resets it, so
# re-run this cell before invoking the graph again.
joke_call_count = 0

def agent(input: list[HumanMessage]):
    """Pass-through node: hand back the incoming message list unchanged.

    This node is registered as the graph's entry point. It does no work of
    its own; the conditional edge leaving it decides what runs next.

    Args:
        input: The messages received by the node.

    Returns:
        The same ``input`` object, untouched.
    """
    return input

def joke_finder(input: list[HumanMessage]):
    """Ask the LLM to respond to the messages and print what it says.

    Each call bumps the module-level ``joke_call_count`` (which the router
    reads to decide when to stop), prints the new count, then prints the
    ``content`` of the model's reply. The reply is only printed; it is not
    added to the returned messages.

    Args:
        input: The messages forwarded to ``llm.invoke``.

    Returns:
        The same ``input`` object that was passed in.
    """
    global joke_call_count
    joke_call_count = joke_call_count + 1
    # print() stringifies its arguments, so the counter can be passed as-is.
    print("joke_call_count: ", joke_call_count)

    # To try the other client instead, call chatModel.invoke(input) here.
    reply = llm.invoke(input)
    print(reply.content)
    return input


def router_node1_to_node2(input: list[HumanMessage]):
    """Pick the branch label for the conditional edge leaving the agent node.

    The decision depends only on the module-level ``joke_call_count``; the
    ``input`` messages are not inspected.

    Args:
        input: The messages received by the router (unused).

    Returns:
        ``"tell_joke"`` while fewer than 5 jokes have been requested,
        otherwise ``"end_joke"``.
    """
    return "tell_joke" if joke_call_count < 5 else "end_joke"

In [19]:

# Two-node loop: agent -> (router) -> joke_finder -> agent ... -> END.
graph = MessageGraph()

# Node names, kept in variables so the edges below stay in sync with add_node.
node1_id = "agent"
node2_id = "joke_finder"

graph.add_node(node1_id, agent)
graph.add_node(node2_id, joke_finder)

# After the agent node, router_node1_to_node2 returns a label and the mapping
# turns it into the next hop: "tell_joke" -> joke_finder, "end_joke" -> END.
graph.add_conditional_edges(
    node1_id,
    router_node1_to_node2,
    {"tell_joke": node2_id, "end_joke": END}
)
# joke_finder always hands control back to the agent node, closing the loop.
graph.add_edge(node2_id, node1_id)
graph.set_entry_point(node1_id)

# Compiled without arguments: the MemorySaver created in the first cell is
# not passed in here.
app = graph.compile()

In [18]:
def save_graph(runnable_graph, file_path):
    """Render a compiled graph as a Mermaid PNG and write it to disk.

    The image is produced before the destination is opened, so a rendering
    failure does not leave an empty or truncated file behind.

    Args:
        runnable_graph: Object exposing ``get_graph()``, whose result exposes
            ``draw_mermaid_png()`` returning the PNG as bytes.
        file_path: Destination path; an existing file is overwritten.

    Returns:
        None.
    """
    diagram = runnable_graph.get_graph()
    image_bytes = diagram.draw_mermaid_png()

    with open(file_path, "wb") as sink:
        sink.write(image_bytes)

In [20]:
# Write the compiled graph's diagram to third4_run.png (path is relative to
# the notebook's working directory).
save_graph(app,"third4_run.png")