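**Note:** the rendered graph shows three unexpected edges: `agent->generate`, `agent->gradeDocuments` and `agent->rewrite` (original note: 多了三条边，不知道怎么出来的，需要研究一下 — "three extra edges appeared; where they come from needs investigating").

Cause: `addConditionalEdges('agent', shouldRetrieve)` is called without a path map. LangGraph therefore cannot tell which nodes `shouldRetrieve` can route to, and it draws a conditional edge from `agent` to every node. Passing an explicit path map (e.g. `{ retrieve: 'retrieve', [END]: END }`) restricts the drawn edges to the real targets.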

In [None]:
import './../../loadenv.mjs'

# Retriever

In [2]:
import { CheerioWebBaseLoader } from '@langchain/community/document_loaders/web/cheerio'
import { RecursiveCharacterTextSplitter } from '@langchain/textsplitters'
import { MemoryVectorStore } from 'langchain/vectorstores/memory'
import { getEmbeddings } from './../../utils.mjs'

// Blog posts that make up the retrieval corpus.
const urls = [
    'https://lilianweng.github.io/posts/2023-06-23-agent/',
    'https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/',
    'https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/'
]

// Load every page concurrently; each loader resolves to an array of documents.
const docs = await Promise.all(
    urls.map(async (url) => {
        const loader = new CheerioWebBaseLoader(url)
        return loader.load()
    }),
)
// Merge the per-URL arrays into a single list.
const docsList = docs.flat()

// chunkSize 500 / chunkOverlap 50, measured by the splitter's length function.
const textSplitter = new RecursiveCharacterTextSplitter({ chunkOverlap: 50, chunkSize: 500 })
const docSplits = await textSplitter.splitDocuments(docsList)

// Embed the chunks into an in-memory vector store and expose it as a retriever.
const embeddings = getEmbeddings()
const vectorStore = await MemoryVectorStore.fromDocuments(docSplits, embeddings)
const retriever = vectorStore.asRetriever()


In [None]:
// Inspect the split chunks; as the cell's last expression this is presumably echoed as the cell output.
docSplits

# Agent State

In [4]:
import { Annotation } from '@langchain/langgraph'
import { BaseMessage } from '@langchain/core/messages'

/**
 * Graph state: a single `messages` channel holding the conversation.
 * A node's returned messages are appended to the existing list (concat),
 * starting from an empty list.
 */
const GraphState = Annotation.Root({
  messages: Annotation<BaseMessage[]>({
    default: () => [],
    reducer: (current, update) => current.concat(update),
  }),
})

In [5]:
import { createRetrieverTool } from 'langchain/tools/retriever'
import { ToolNode } from '@langchain/langgraph/prebuilt'

// Expose the vector-store retriever as a named tool the agent model can call.
const tool = createRetrieverTool(retriever, {
  name: 'retrieve_blog_posts',
  description: 'Search and return information about Lilian Weng blog posts on LLM agents, prompt engineering, and adversarial attacks on LLMs.',
})

const tools = [tool]

// Prebuilt graph node that executes the agent's calls to the tools above.
const toolNode = new ToolNode<typeof GraphState.State>(tools)

# Nodes and Edges

In [6]:
import { END } from '@langchain/langgraph'
import { pull } from 'langchain/hub'
import { z } from 'zod'
import { ChatPromptTemplate } from '@langchain/core/prompts'
import { getModel } from './../../utils.mjs'
import { AIMessage } from '@langchain/core/messages'

/**
 * Router used after the `agent` node.
 *
 * @returns 'retrieve' when the latest message carries a non-empty `tool_calls`
 *          array, otherwise END (finish the run).
 */
function shouldRetrieve(state: typeof GraphState.State): string {
  console.log('---DECIDE TO RETRIEVE---')
  const latest = state.messages[state.messages.length - 1]
  const hasToolCalls =
    'tool_calls' in latest &&
    Array.isArray(latest.tool_calls) &&
    latest.tool_calls.length > 0
  if (!hasToolCalls) {
    return END
  }
  console.log('---DECISION: RETRIEVE---')
  return 'retrieve'
}

/**
 * Grade whether the most recently retrieved documents are relevant to the question.
 *
 * The question is the content of the first message; the documents are the content
 * of the latest message. The model is bound to the `give_relevance_score` tool with
 * `tool_choice` set to that tool, so the verdict is expected in the tool call's
 * `binaryScore` argument.
 *
 * @returns a state update appending the grader's AI message.
 */
async function gradeDocuments(state: typeof GraphState.State): Promise<Partial<typeof GraphState.State>> {
  console.log('---GET RELEVANCE---')

  const { messages } = state

  // Local name differs from the notebook-level `tool` (the retriever tool) to avoid shadowing.
  const gradingTool = {
    name: 'give_relevance_score',
    description: 'Give a relevance score to the retrieved documents.',
    schema: z.object({
      binaryScore: z.string().describe("Relevance score 'yes' or 'no'")
    })
  }

  const gradingPrompt = ChatPromptTemplate.fromTemplate(
    `You are a grader assessing relevance of retrieved docs to a user question.
Here are the retrieved docs:
\n ------- \n
{context} 
\n ------- \n
Here is the user question: {question}
If the content of the docs are relevant to the users question, score them as relevant.
Give a binary score 'yes' or 'no' score to indicate whether the docs are relevant to the question.
Yes: The docs are relevant to the question.
No: The docs are not relevant to the question.`,
  )

  const grader = gradingPrompt.pipe(
    getModel({ temperature: 0 }).bindTools([gradingTool], { tool_choice: gradingTool.name }),
  )

  const question = messages[0].content as string
  const context = messages[messages.length - 1].content as string
  const score = await grader.invoke({ question, context })

  return { messages: [score] }
}

/**
 * Router used after `gradeDocuments`: returns 'yes' (docs relevant) or 'no'.
 *
 * The verdict is the `binaryScore` argument of the first tool call on the latest
 * message. Because that value is produced by an LLM, it is compared after trimming
 * whitespace and lower-casing ("Yes", " yes\n" count as relevant). Any other value,
 * including a missing or non-string score, is treated as 'no'.
 *
 * @throws Error if there are no messages, the latest message has no `tool_calls`
 *         property, or its tool-call list is empty.
 */
function checkRelevance(state: typeof GraphState.State): string {
  console.log('---CHECK RELEVANCE---')
  const { messages } = state
  const lastMessage = messages[messages.length - 1]
  // Guard the empty-history case explicitly instead of letting `in` throw a TypeError on undefined.
  if (lastMessage === undefined || !('tool_calls' in lastMessage)) {
    throw new Error("The 'checkRelevance' node requires the most recent message to contain tool calls.")
  }
  const toolCalls = (lastMessage as AIMessage).tool_calls
  if (!toolCalls || !toolCalls.length) {
    throw new Error('Last message was not a function message')
  }
  // Normalize the LLM-produced score so capitalization or stray whitespace does not flip the route.
  const rawScore = toolCalls[0].args?.binaryScore
  const score = typeof rawScore === 'string' ? rawScore.trim().toLowerCase() : ''
  if (score === 'yes') {
    console.log('---DECISION: DOCS RELEVANT---')
    return 'yes'
  }
  console.log('---DECISION: DOCS NOT RELEVANT---')
  return 'no'
}

/**
 * Agent node: ask the tool-enabled chat model (bound to `tools`) for the next step.
 *
 * Grading messages are dropped before the call. These are messages whose first
 * tool call is `give_relevance_score`. All other messages are passed through, and
 * the filtered history is logged for debugging.
 *
 * @returns a state update appending the model's reply.
 */
async function agent(state: typeof GraphState.State): Promise<Partial<typeof GraphState.State>> {
  console.log('---CALL AGENT---')
  const history = state.messages.filter((message) => {
    const isGrading =
      'tool_calls' in message &&
      Array.isArray(message.tool_calls) &&
      message.tool_calls.length > 0 &&
      message.tool_calls[0].name === 'give_relevance_score'
    return !isGrading
  })
  console.log(history)
  const toolAwareModel = getModel({ temperature: 0, streaming: true }).bindTools(tools)
  const reply = await toolAwareModel.invoke(history)
  return { messages: [reply] }
}

/**
 * Rewrite node: ask the model to reformulate the original question.
 *
 * The question is always taken from the first message in the conversation.
 *
 * @returns a state update appending the model's reformulated question.
 */
async function rewrite(state: typeof GraphState.State): Promise<Partial<typeof GraphState.State>> {
  console.log('---TRANSFORM QUERY---')
  const [firstMessage] = state.messages
  const question = firstMessage.content as string
  const rewritePrompt = ChatPromptTemplate.fromTemplate(
    `Look at the input and try to reason about the underlying semantic intent / meaning. \n 
  Here is the initial question:
  \n ------- \n
  {question} 
  \n ------- \n
  Formulate an improved question:`,
  )
  const llm = getModel({ temperature: 0, streaming: true })
  const chain = rewritePrompt.pipe(llm)
  const improved = await chain.invoke({ question })
  return { messages: [improved] }
}

/**
 * Generate node: answer the original question from the most recent tool result.
 *
 * The question is the first message's content and the context is the content of
 * the latest `tool` message. The `rlm/rag-prompt` template is pulled from the
 * LangChain hub on every call.
 *
 * @throws Error when the history contains no tool message.
 * @returns a state update appending the generated answer.
 */
async function generate(state: typeof GraphState.State): Promise<Partial<typeof GraphState.State>> {
  console.log('---GENERATE---')
  const { messages } = state
  const question = messages[0].content as string

  // Scan from the end for the most recent tool (retrieval) result.
  let lastToolMessage: (typeof messages)[number] | undefined
  for (let i = messages.length - 1; i >= 0; i--) {
    if (messages[i]._getType() === 'tool') {
      lastToolMessage = messages[i]
      break
    }
  }
  if (lastToolMessage === undefined) {
    throw new Error('No tool message found in the conversation history')
  }

  const ragPrompt = await pull<ChatPromptTemplate>('rlm/rag-prompt')
  const ragChain = ragPrompt.pipe(getModel({ temperature: 0, streaming: true }))
  const answer = await ragChain.invoke({
    context: lastToolMessage.content as string,
    question,
  })
  return { messages: [answer] }
}

# Graph

In [13]:
import { StateGraph } from '@langchain/langgraph'

// Build the graph over GraphState and register its nodes:
//   agent          - calls the tool-enabled model (may request retrieval)
//   retrieve       - ToolNode running the retriever tool
//   gradeDocuments - LLM grades relevance of the retrieved docs
//   rewrite        - LLM reformulates the original question
//   generate       - answers from the retrieved docs via the RAG prompt
// NOTE(review): the addNode calls are chained, presumably so LangGraph.js can carry
// the node names in `workflow`'s type for the later edge calls. Keep it chained.
const workflow = new StateGraph(GraphState)
  .addNode('agent', agent)
  .addNode('retrieve', toolNode)
  .addNode('gradeDocuments', gradeDocuments)
  .addNode('rewrite', rewrite)
  .addNode('generate', generate)

In [None]:
import { START } from '@langchain/langgraph'

// Control flow:
//   START -> agent
//   agent -> retrieve | END               (decided by shouldRetrieve)
//   retrieve -> gradeDocuments
//   gradeDocuments -> generate | rewrite  (decided by checkRelevance)
//   generate -> END
//   rewrite -> agent                      (retry with the reformulated question)
workflow.addEdge(START, 'agent')

// An explicit path map lists the only targets shouldRetrieve can produce.
// Without it LangGraph cannot know the possible destinations, and the drawn graph
// shows spurious edges from `agent` to every node
// (agent->generate, agent->gradeDocuments, agent->rewrite).
workflow.addConditionalEdges(
  'agent',
  shouldRetrieve,
  {
    retrieve: 'retrieve',
    [END]: END,
  }
)

workflow.addEdge('retrieve', 'gradeDocuments')

workflow.addConditionalEdges(
  'gradeDocuments',
  checkRelevance,
  {
    yes: 'generate',
    no: 'rewrite',
  }
)

workflow.addEdge('generate', END)
workflow.addEdge('rewrite', 'agent')

const app = workflow.compile()

In [None]:
// const graph = app.getGraph()
// graph.edges

In [None]:
import { printGraph } from './../../utils.mjs'
// Render the compiled graph's structure with the project's printGraph helper.
const drawableGraph = app.getGraph()
await printGraph(drawableGraph)

In [None]:
import { HumanMessage } from '@langchain/core/messages'

// Run the graph on one question and stream per-node updates.
// Each streamed chunk maps a node name to the state update that node returned.
const inputs = {
    messages: [
        new HumanMessage(
            "What are the types of agent memory based on Lilian Weng's blog post?",
        )
    ]
}
// Holds the update returned by the LAST node that ran, not the full accumulated state.
let finalState: unknown
for await (const output of await app.stream(inputs)) {
    for (const [key, value] of Object.entries(output)) {
        console.log(`Output from node: '${key}'`)
        // Tolerate a node update without messages instead of crashing on an undefined last message.
        const nodeMessages = value?.messages ?? []
        const lastMsg = nodeMessages[nodeMessages.length - 1]
        if (lastMsg !== undefined) {
            console.dir({
                type: lastMsg._getType(),
                content: lastMsg.content,
                // `tool_calls` only exists on AI messages; other message types report undefined.
                tool_calls: 'tool_calls' in lastMsg ? lastMsg.tool_calls : undefined,
            }, { depth: null })
        }
        console.log('---\n')
        finalState = value
    }
}

console.log(JSON.stringify(finalState, null, 2))