## 03_langgraph

LangGraph を用いた Agent 構成

In [None]:
import os
from dotenv import load_dotenv, find_dotenv
import uuid
from typing import Literal, List

import oracledb

from langfuse import Langfuse
from langfuse.callback import CallbackHandler

from langchain_core.documents import Document
from langchain_core.tools import tool
from langchain_core.tools import Tool
from langchain_core.prompt_values import HumanMessage
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_community.vectorstores.oraclevs import OracleVS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI
from langchain_community.embeddings.oci_generative_ai import OCIGenAIEmbeddings
from langchain_community.tools.tavily_search.tool import TavilySearchResults

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode

環境変数を `.env` から取得します

In [None]:
# Populate os.environ from the nearest .env file; the boolean result of
# load_dotenv is deliberately discarded.
_ = load_dotenv(find_dotenv())

# --- Oracle Database connection settings (wallet-based) ---
un = os.environ.get("ORACLE_USERNAME")
pw = os.environ.get("ORACLE_PASSWORD")
dsn = os.environ.get("ORACLE_DSN")
# The wallet directory is fixed; it doubles as the config directory.
config_dir = wallet_location = "/tmp/wallet"
wallet_password = os.environ.get("WALLET_PASSWORD")
table_name = os.environ.get("TABLE_NAME")

# --- Tavily web search ---
tavily_api_key = os.environ.get("TAVILY_API_KEY")

# --- OCI Generative AI ---
compartment_id = os.environ.get("COMPARTMENT_ID")
service_endpoint = os.environ.get("SERVICE_ENDPOINT")

# --- Langfuse tracing / prompt management ---
secret_key = os.environ.get("LANGFUSE_SECRET_KEY")
public_key = os.environ.get("LANGFUSE_PUBLIC_KEY")
langfuse_host = os.environ.get("LANGFUSE_HOST")

# --- Cohere ---
cohere_api_key = os.environ.get("COHERE_API_KEY")

Langfuse のクライアントと、LangChain 用のコールバックハンドラを初期化します

In [None]:
# Connection settings shared by the Langfuse SDK client and the LangChain
# callback handler, so both point at the same Langfuse project.
_langfuse_credentials = {
    "secret_key": secret_key,
    "public_key": public_key,
    "host": langfuse_host,
}

# SDK client (used in this notebook to fetch the managed "ochat-preamble" prompt).
langfuse = Langfuse(**_langfuse_credentials)

# Callback handler attached to graph invocations for tracing. sample_rate=0.5
# presumably records only about half of the traces — confirm against Langfuse docs.
langfuse_handler = CallbackHandler(**_langfuse_credentials, sample_rate=0.5)

Agent が使うツールの準備として、Oracle Database への接続・埋め込みモデル・ベクトルストアを初期化します

In [None]:
# Open a connection to Oracle Database using the credentials and wallet
# settings loaded from the environment.
connection = oracledb.connect(
    user=un,
    password=pw,
    dsn=dsn,
    config_dir=config_dir,
    wallet_location=wallet_location,
    wallet_password=wallet_password,
)

# Multilingual embedding model served by OCI Generative AI, authenticated
# with the instance principal of the host.
embeddings = OCIGenAIEmbeddings(
    auth_type="INSTANCE_PRINCIPAL",
    model_id="cohere.embed-multilingual-v3.0",
    service_endpoint=service_endpoint,
    compartment_id=compartment_id,
)

# Vector store over the document table, ranked by cosine distance.
# Honor the TABLE_NAME environment variable (previously read but ignored);
# fall back to "OCHAT" when it is unset or empty so existing setups keep working.
oracle_vs = OracleVS(
    client=connection,
    embedding_function=embeddings,
    table_name=table_name or "OCHAT",
    distance_strategy=DistanceStrategy.COSINE,
    # NOTE(review): presumably a sample text OracleVS embeds at construction
    # (e.g. to probe the embedding dimension) — confirm in the OracleVS docs.
    query="What is a Oracle Database",
)

使用するツール群を宣言します。  
アプリケーションとして実装する際は、ツールを個別のクラスとして実装するか、ツール群を一つにまとめて実装すると見通しが良くなりそうです（要検討）。

In [None]:
@tool
def search(query: str) -> str:
    """Call to surf the web."""
    # Mock tool: returns canned weather text rather than performing a real search.
    # Each keyword must be tested against `query` individually —
    # `"東京" or ... in query` parses as `"東京" or (...)`, which is always
    # truthy, so the Tokyo reply would be returned for every query.
    tokyo_keywords = ("東京", "とうきょう", "Tokyo", "tokyo")
    if any(keyword in query for keyword in tokyo_keywords):
        return "東京は今日も最高気温35度越えの猛暑です。"
    return "日本は今日、全国的に晴れです。"

@tool
def vector_search(query: str) -> List[Document]:
    """Using vector search(Oracle Database 23ai)."""
    # Look up the single closest document (k=1) in the Oracle vector store
    # built above and hand it straight back to the caller.
    return oracle_vs.similarity_search(query=query, k=1)

# Web search tool backed by Tavily, constructed with default settings.
# NOTE(review): no API key is passed here (the `tavily_api_key` variable read
# from .env is unused); presumably the tool picks up TAVILY_API_KEY from the
# environment — confirm.
web_search_tool = TavilySearchResults()

In [None]:
# Tools exposed to the model. `search` and `vector_search` are already
# LangChain tools (created with @tool); they are re-wrapped in `Tool` to give
# them model-facing names and descriptions. Delegate via `.invoke` instead of
# calling the tool object directly: `BaseTool.__call__` is deprecated in
# langchain-core in favor of `invoke`.
tools = [
    Tool(
        name="WeatherSearch",
        func=search.invoke,
        description="天気を検索します",
    ),
    Tool(
        name="VectorSearch",
        func=vector_search.invoke,
        description="OCHaCafe固有な話題やKubernetes, IaCなどクラウドネイティブ関連話題の検索に役立ちます",
    ),
    web_search_tool,
]

# Graph node that runs the tools requested by the model's tool calls.
tool_node = ToolNode(
    tools=tools,
    name="ochat-tools",
    tags=["ochat", "web", "vector"],
)

In [None]:
# Chat model choice: Cohere Command R+ on OCI Generative AI, with streaming on.
model_name = "cohere.command-r-plus"
is_stream = True

# Generation parameters forwarded to the model via model_kwargs.
# The preamble text is fetched from Langfuse prompt management (text prompt
# "ochat-preamble") and compiled with no variables; presumably it acts as the
# system preamble for the chat model.
models_args = dict(
    temperature=0.3,
    max_tokens=1024,
    top_p=0.75,
    top_k=0,
    frequency_penalty=0,
    presence_penalty=0,
    preamble_override=langfuse.get_prompt(name="ochat-preamble", type="text").compile(),
)

定義したツール群を使用できるよう、ツールをバインドしてモデルを宣言します

In [None]:
# OCI Generative AI chat model (instance-principal auth) with the tool list
# bound, so its responses may contain tool calls for the graph to route.
model = (
    ChatOCIGenAI(
        model_id=model_name,
        model_kwargs=models_args,
        is_stream=is_stream,
        compartment_id=compartment_id,
        service_endpoint=service_endpoint,
        auth_type="INSTANCE_PRINCIPAL",
    )
    .bind_tools(tools=tools)
)

In [None]:
def should_continue(state: MessagesState) -> Literal["tools", END]:
    """Decide where to go after the agent node.

    Returns "tools" when the newest message carries tool calls; otherwise
    returns END so the graph run finishes.
    """
    latest = state["messages"][-1]
    return "tools" if latest.tool_calls else END

def call_model(state: MessagesState):
    """Agent node: send the conversation so far to the tool-bound model.

    The model's reply is returned as a one-element list under "messages";
    presumably MessagesState's reducer appends it to the history — confirm.
    """
    reply = model.invoke(state["messages"])
    return {"messages": [reply]}


In [None]:
# Agent graph whose state is the running list of chat messages.
workflow = StateGraph(MessagesState)

# Register the LLM node first, then the tool executor (same order as before).
for _node_name, _node_action in (("agent", call_model), ("tools", tool_node)):
    workflow.add_node(_node_name, _node_action)

# Every run begins at the agent node.
workflow.set_entry_point("agent")

In [None]:
# Checkpointer that keeps graph state in memory, keyed by thread.
checkpointer = MemorySaver()

# After the agent node, should_continue picks "tools" or END; tool output
# always flows back to the agent for another model turn.
workflow.add_conditional_edges("agent", should_continue)
workflow.add_edge("tools", "agent")

In [None]:
# Compile the graph into a runnable app, attaching the checkpointer so state
# can be saved per thread_id.
app = workflow.compile(checkpointer=checkpointer)

# Print the compiled graph's structure as ASCII art for a quick visual check.
app.get_graph().print_ascii()

In [None]:
# Fresh conversation thread for this run; the checkpointer keys state on it.
session_id = str(uuid.uuid4())

# Ask one question and let the agent/tool loop run to completion, tracing
# the run through the Langfuse callback handler.
final_state = app.invoke(
    {
        "messages": [
            HumanMessage(
                content="OCHaCafeってなんでしょうか？また代表的なテーマには何がありますか？"
            )
        ]
    },
    config={
        "callbacks": [langfuse_handler],
        "configurable": {"thread_id": session_id},
    },
)

# The run ends on a model reply without tool calls; that reply is the answer.
result = final_state["messages"][-1].content
print(result)