In [None]:
import './../../loadenv.mjs'

# Helper utilities

In [2]:
import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
import { StructuredTool } from '@langchain/core/tools'
import { convertToOpenAITool } from '@langchain/core/utils/function_calling'
import { Runnable } from '@langchain/core/runnables'
import { getModel } from './../../utils.mjs'
import { ChatOpenAI } from '@langchain/openai'

/**
 * Builds a tool-calling agent for the collaborative multi-agent graph.
 *
 * The agent is a chat prompt (a shared system message describing the
 * collaboration protocol, followed by the running `messages` history) piped
 * into `llm` with `tools` bound in OpenAI tool-call format.
 *
 * @param llm - Chat model the prompt is piped into.
 * @param tools - Tools the model may call; their names are also listed in the prompt.
 * @param systemMessage - Agent-specific instructions appended after the shared prompt.
 * @returns A runnable whose input must include `messages` and whose output is the model reply.
 */
async function createAgent({
    llm,
    tools,
    systemMessage
}: {
    llm: ChatOpenAI;
    tools: StructuredTool[];
    systemMessage: string;
}): Promise<Runnable> {
    const toolNames = tools.map(tool => tool.name).join(', ')
    const formattedTools = tools.map(tool => convertToOpenAITool(tool))

    const basePrompt = ChatPromptTemplate.fromMessages([
        [
            "system",
            "You are a helpful AI assistant, collaborating with other assistants." +
            " Use the provided tools to progress towards answering the question." +
            " If you are unable to fully answer, that's OK, another assistant with different tools" +
            " will help where you left off. Execute what you can to make progress." +
            // Original wording:
            //   " If you or any of the other assistants have the final answer or deliverable,"
            //   " prefix your response with FINAL ANSWER so the team knows to stop."
            // Prompt modified to suit Doubao's capabilities: only stop once a bar chart
            // has actually been displayed to the user.
            " Only if you have already displayed a bar chart to the user" +
            " can you prefix your response with FINAL ANSWER so the team knows to stop." +
            " You have access to the following tools: {tool_names}.\n{system_message}",
        ],
        new MessagesPlaceholder("messages"),
    ])
    // Fill the static template variables now so callers only need to supply `messages`.
    const prompt = await basePrompt.partial({
        system_message: systemMessage,
        tool_names: toolNames,
    })

    return prompt.pipe(llm.bind({ tools: formattedTools }))
}

# Define graph state

In [3]:
import { BaseMessage } from '@langchain/core/messages'
import { Annotation } from '@langchain/langgraph'

// Shared graph state: the accumulated message log plus the name of the node that last wrote to it.
const AgentState = Annotation.Root({
    // Updates are appended to the existing history instead of replacing it.
    messages: Annotation<BaseMessage[]>({
        reducer: (history, incoming) => history.concat(incoming),
    }),
    // Last writer wins; a missing update keeps the previous value, falling back to 'user'.
    sender: Annotation<string>({
        default: () => 'user',
        reducer: (previous, next) => next ?? previous ?? 'user',
    }),
})

# Define tools

In [4]:
import { TavilySearchResults } from '@langchain/community/tools/tavily_search'
import { chartTool } from './../../utils.mjs'

// Web-search tool given to the Researcher agent, constructed with default options.
// NOTE(review): presumably reads its API key from the environment loaded by loadenv.mjs — confirm.
const tavilyTool = new TavilySearchResults()

# Create Graph

## Define agent nodes

In [5]:
import { HumanMessage } from '@langchain/core/messages'
import type { RunnableConfig } from '@langchain/core/runnables'

/**
 * Runs one agent against the shared state and packages its reply as a state update.
 *
 * A reply that requests tool calls is passed through untouched so the router can
 * send it to the tool node. Any other reply is re-wrapped as a HumanMessage tagged
 * with the agent's `name` — presumably so the other agent reads it as input rather
 * than as its own prior output (NOTE(review): confirm intent).
 *
 * @returns `{ messages: [reply], sender: name }`, merged into state by AgentState's reducers.
 */
async function runAgentNode(props: {
    state: typeof AgentState.State;
    agent: Runnable;
    name: string;
    config?: RunnableConfig;
}) {
    const reply = await props.agent.invoke(props.state, props.config)
    const requestsTools = Boolean(reply?.tool_calls) && reply.tool_calls.length !== 0
    const message = requestsTools
        ? reply
        : new HumanMessage({ ...reply, name: props.name, })
    return {
        messages: [message],
        sender: props.name,
    }
}

// One shared chat model backs both agents.
const llm = getModel()

// Research agent: can only search the web, and is told to feed data to the chart generator.
const researchAgent = await createAgent({
    llm,
    tools: [tavilyTool],
    systemMessage: [
        'You should provide accurate data for the chart generator to use.',
        "Let's think step by step.",
    ].join('\n'),
})

/**
 * Graph node for the Researcher agent: runs `researchAgent` on the current state,
 * forwarding the graph's RunnableConfig.
 */
async function researchNode(
    state: typeof AgentState.State,
    config?: RunnableConfig,
) {
    const run = { state, config, agent: researchAgent, name: 'Researcher' }
    return runAgentNode(run)
}

// Chart agent: only has the chart tool; its prompt says displayed charts are visible to the user.
const chartAgent = await createAgent({
    systemMessage: 'Any charts you display will be visible by the user.',
    tools: [chartTool],
    llm,
})

/**
 * Graph node for the ChartGenerator agent: runs `chartAgent` on the current state.
 *
 * Accepts and forwards the graph's RunnableConfig, the same as `researchNode`; the
 * previous version dropped it, so any config the graph passes to this node never
 * reached the chart agent.
 */
async function chartNode(
    state: typeof AgentState.State,
    config?: RunnableConfig,
) {
    return runAgentNode({
        state,
        agent: chartAgent,
        name: 'ChartGenerator',
        config,
    })
}

In [6]:
// const researchResults = await researchNode({
//     messages: [new HumanMessage('Research the US primaries in 2024')],
//     sender: 'User',
// })

// researchResults

## Define the tool node

In [7]:
import { ToolNode } from '@langchain/langgraph/prebuilt'

// Every tool either agent can call. A single tool node serves both agents: the router
// sends any message with tool calls to 'call_tool', and that node's outgoing edge
// returns control to whichever agent is recorded in `sender`.
const tools = [tavilyTool, chartTool]
const toolNode = new ToolNode<typeof AgentState.State>(tools)

In [8]:
// await toolNode.invoke(researchResults)

## Define edge (routing) logic

In [9]:
import { AIMessage } from '@langchain/core/messages'

/**
 * Flattens a message's `content` into plain text so it can be inspected.
 *
 * String content is returned unchanged. For an array of content parts, string parts
 * and the `text` of `{ type: 'text' }` parts are concatenated; other parts (e.g.
 * images) are ignored. Anything else yields ''.
 */
function extractMessageText(content: unknown): string {
    if (typeof content === 'string') {
        return content
    }
    if (!Array.isArray(content)) {
        return ''
    }
    return content
        .map(part =>
            typeof part === 'string'
                ? part
                : part?.type === 'text' && typeof part.text === 'string'
                    ? part.text
                    : '',
        )
        .join('')
}

/**
 * Edge router shared by both agent nodes.
 *
 * @returns
 * - 'call_tool' when the last message requests tool calls,
 * - 'end' when its text (ignoring leading whitespace) starts with "FINAL ANSWER",
 * - 'continue' otherwise, which hands control to the other agent.
 * @throws Error if the state holds no messages, since there is nothing to route on.
 */
function router(state: typeof AgentState.State) {
    const messages = state.messages
    const lastMessage = messages[messages.length - 1] as AIMessage | undefined
    if (lastMessage === undefined) {
        throw new Error('router: state.messages is empty; cannot decide the next step')
    }
    if (lastMessage.tool_calls && lastMessage.tool_calls.length > 0) {
        return 'call_tool'
    }
    // Content may be a plain string or an array of content parts, and models sometimes
    // emit leading newlines before the marker; without this the graph would never end.
    if (extractMessageText(lastMessage.content).trimStart().startsWith('FINAL ANSWER')) {
        return 'end'
    }
    return 'continue'
}

## Assemble and compile the graph

In [None]:
import { END, START, StateGraph } from '@langchain/langgraph'

// Wire the two agents and the shared tool node together. Each agent's reply goes through
// `router`: tool calls go to 'call_tool', "continue" hands off to the other agent, and
// "end" finishes the run. The tool node routes back to whichever agent is in `sender`.
const workflow = new StateGraph(AgentState)
    .addNode('Researcher', researchNode)
    .addNode('ChartGenerator', chartNode)
    .addNode('call_tool', toolNode)
    .addEdge(START, 'Researcher')
    .addConditionalEdges('Researcher', router, {
        continue: 'ChartGenerator',
        call_tool: 'call_tool',
        end: END,
    })
    .addConditionalEdges('ChartGenerator', router, {
        continue: 'Researcher',
        call_tool: 'call_tool',
        end: END,
    })
    .addConditionalEdges('call_tool', state => state.sender, {
        Researcher: 'Researcher',
        ChartGenerator: 'ChartGenerator',
    })

const graph = workflow.compile()

In [None]:
import { printGraph } from './../../utils.mjs'
// Visualize the compiled graph's structure via the project helper printGraph (utils.mjs).
// NOTE(review): output format (image vs. text diagram) is defined in utils.mjs — confirm there.
await printGraph(graph.getGraph())

# Invoke

In [None]:
// Start the run with the user's request and stream node updates as they happen.
// recursionLimit bounds the number of graph steps; presumably set high because the agents
// may hand off to each other and to the tool node many times (NOTE(review): confirm value).
const streamResults = await graph.stream(
    { messages: [new HumanMessage('Generate a bar chart of the US gdp over the past 3 years.')] },
    { recursionLimit: 150 },
)

/**
 * Logs a compact view of one streamed graph chunk.
 *
 * Chunks are keyed by node name (e.g. `{ Researcher: {...} }`), and only the first
 * key's update is shown:
 * - the type, content and tool calls of its last message, when it has a non-empty
 *   `messages` array;
 * - its `sender`, when present.
 *
 * Empty chunks, non-object node values, and empty message arrays are skipped
 * instead of throwing.
 */
const prettifyOutput = (output: Record<string, any>) => {
    const [firstKey] = output ? Object.keys(output) : []
    const firstItem = firstKey === undefined ? undefined : output[firstKey]
    // The `in` operator throws on undefined/null/primitives, so skip anything that isn't an object.
    if (firstItem === null || typeof firstItem !== 'object') {
        return
    }

    if ('messages' in firstItem && Array.isArray(firstItem.messages)) {
        const lastMessage = firstItem.messages[firstItem.messages.length - 1]
        if (lastMessage) {
            console.dir({
                type: lastMessage._getType(),
                content: lastMessage.content,
                tool_calls: lastMessage.tool_calls,
            })
        }
    }
    if ('sender' in firstItem) {
        console.log({
            sender: firstItem.sender,
        })
    }
}

// Drain the stream, printing each node update followed by a separator line.
// `streamResults` was already awaited when it was created, so it is iterated directly;
// chunks flagged with `__end__` are skipped.
for await (const chunk of streamResults) {
    if (chunk?.__end__) {
        continue
    }
    prettifyOutput(chunk)
    console.log('----')
}