---
sidebar_position: 4
---

# How to create a custom chat model class

```{=mdx}
:::info Prerequisites

This guide assumes familiarity with the following concepts:

- [Chat models](/docs/concepts/chat_models)

:::
```

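This notebook goes over how to create a custom chat model wrapper, in case you want to use your own chat model or a different wrapper than one that is directly supported in LangChain.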

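There are a few required things a chat model needs to implement after extending the [`SimpleChatModel` class](https://api.js.langchain.com/classes/langchain_core.language_models_chat_models.SimpleChatModel.html):

- A `_call` method that takes in a list of messages and call options (which include things like `stop` sequences) and returns a string.

- A `_llmType` method that returns a string. It is used for logging purposes only.

You can also implement the following optional method:

- A `_streamResponseChunks` method that returns an `AsyncGenerator` and yields [`ChatGenerationChunks`](https://api.js.langchain.com/classes/langchain_core.outputs.ChatGenerationChunk.html) one at a time. This allows the model to support streaming output.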
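To make the required/optional split concrete, here is a dependency-free toy model of that contract. It is a minimal sketch, not LangChain's code. `ToyMessage`, `ToyChatModel`, `streamWithFallback` and `collect` are hypothetical names. The fallback branch assumes, as we understand LangChain's base class to behave, that `.stream()` still works when `_streamResponseChunks` is not overridden, but returns the whole response as a single chunk.

```typescript
// Toy model of the SimpleChatModel contract; NOT LangChain's implementation.
// ToyMessage, ToyChatModel, streamWithFallback and collect are hypothetical.
type ToyMessage = { role: string; content: string };

interface ToyChatModel {
  // Required: return the whole reply as a single string.
  _call(messages: ToyMessage[]): Promise<string>;
  // Required: an identifier used only for logging.
  _llmType(): string;
  // Optional: yield the reply piece by piece.
  _streamResponseChunks?(messages: ToyMessage[]): AsyncGenerator<string>;
}

// Use the streaming method when it exists; otherwise emit the full
// _call result as one chunk (assumed to mirror LangChain's fallback).
async function* streamWithFallback(
  model: ToyChatModel,
  messages: ToyMessage[]
): AsyncGenerator<string> {
  if (model._streamResponseChunks) {
    yield* model._streamResponseChunks(messages);
  } else {
    yield await model._call(messages);
  }
}

// Drain a stream into an array so the chunk boundaries are visible.
async function collect(stream: AsyncGenerator<string>): Promise<string[]> {
  const chunks: string[] = [];
  for await (const chunk of stream) chunks.push(chunk);
  return chunks;
}

const N = 4;
const callOnly: ToyChatModel = {
  async _call(messages) {
    return messages[0].content.slice(0, N);
  },
  _llmType() {
    return "toy_echo";
  },
};
const withStreaming: ToyChatModel = {
  ...callOnly,
  async *_streamResponseChunks(messages) {
    for (const letter of messages[0].content.slice(0, N)) yield letter;
  },
};

const input: ToyMessage[] = [{ role: "human", content: "I am an LLM" }];
collect(streamWithFallback(callOnly, input)).then((c) => console.log(c)); // [ 'I am' ]
collect(streamWithFallback(withStreaming, input)).then((c) => console.log(c)); // [ 'I', ' ', 'a', 'm' ]
```

The only observable difference between the two toy models is chunk granularity. Both produce the same final text.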

Let's implement a very simple custom chat model that just echoes back the first `n` characters of the first input message.

```typescript
import {
  SimpleChatModel,
  type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { AIMessageChunk, type BaseMessage } from "@langchain/core/messages";
import { ChatGenerationChunk } from "@langchain/core/outputs";

interface CustomChatModelInput extends BaseChatModelParams {
  n: number;
}

class CustomChatModel extends SimpleChatModel {
  n: number;

  constructor(fields: CustomChatModelInput) {
    super(fields);
    this.n = fields.n;
  }

  _llmType() {
    return "custom";
  }

  async _call(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<string> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    return messages[0].content.slice(0, this.n);
  }

  async *_streamResponseChunks(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): AsyncGenerator<ChatGenerationChunk> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    for (const letter of messages[0].content.slice(0, this.n)) {
      yield new ChatGenerationChunk({
        message: new AIMessageChunk({
          content: letter,
        }),
        text: letter,
      });
      // Trigger the appropriate callback for new chunks
      await runManager?.handleLLMNewToken(letter);
    }
  }
}
```

We can now use this like any other chat model:

```typescript
const chatModel = new CustomChatModel({ n: 4 });

await chatModel.invoke([["human", "I am an LLM"]]);
```

```
AIMessage {
  lc_serializable: true,
  lc_kwargs: {
    content: 'I am',
    tool_calls: [],
    invalid_tool_calls: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'I am',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  usage_metadata: undefined
}
```
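The `invoke` result above (`'I am'`) is just the first 4 characters of the *first* message. The sketch below pulls the same validation-and-slice logic out of `_call` into a plain function. `echoFirstN` and the `{ role, content }` message shape are hypothetical stand-ins for `BaseMessage`. It shows one consequence of that choice: if the input starts with a system message, the model echoes the system prompt rather than the human turn.

```typescript
// Hypothetical minimal stand-in for LangChain's BaseMessage.
type SimpleMessage = { role: string; content: string | object[] };

// Mirrors the logic of CustomChatModel._call: validate the input, then echo
// the first n characters of the FIRST message only.
function echoFirstN(messages: SimpleMessage[], n: number): string {
  if (!messages.length) {
    throw new Error("No messages provided.");
  }
  const first = messages[0].content;
  if (typeof first !== "string") {
    throw new Error("Multimodal messages are not supported.");
  }
  return first.slice(0, n);
}

const humanOnly = echoFirstN([{ role: "human", content: "I am an LLM" }], 4);
const withSystem = echoFirstN(
  [
    { role: "system", content: "You are helpful." },
    { role: "human", content: "I am an LLM" },
  ],
  4
);
console.log(humanOnly); // I am
console.log(withSystem); // You (followed by a space)
```

If echoing the most recent turn is what you want, reading `messages[messages.length - 1]` instead of `messages[0]` is the natural change.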


And it supports streaming:

```typescript
const stream = await chatModel.stream([["human", "I am an LLM"]]);

for await (const chunk of stream) {
  console.log(chunk);
}
```

```
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: 'I',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: 'I',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: ' ',
    tool_calls: [],
    invalid_tool_calls: [],
    tool_call_chunks: [],
    additional_kwargs: {},
    response_metadata: {}
  },
  lc_namespace: [ 'langchain_core', 'messages' ],
  content: ' ',
  name: undefined,
  additional_kwargs: {},
  response_metadata: {},
  id: undefined,
  tool_calls: [],
  invalid_tool_calls: [],
  tool_call_chunks: [],
  usage_metadata: undefined
}
AIMessageChunk {
  lc_serializable: true,
  lc_kwargs: {
    content: 'a',
    tool_ca
```
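Each streamed chunk carries a single character because `_streamResponseChunks` yields once per character of the sliced string. This is a minimal sketch of that loop. Plain strings stand in for `ChatGenerationChunk`, and `echoChunks` and `collectStream` are hypothetical names. It shows how a consumer can reassemble the chunks into the same text `invoke` returned. With real LangChain chunks, `AIMessageChunk` objects can instead be merged with their `concat` method.

```typescript
// Mirrors the loop in CustomChatModel._streamResponseChunks:
// one chunk per character of the first n characters.
async function* echoChunks(text: string, n: number): AsyncGenerator<string> {
  for (const letter of text.slice(0, n)) {
    yield letter;
  }
}

// Consumer side: collect every chunk, then join them back into the full reply.
async function collectStream(
  stream: AsyncGenerator<string>
): Promise<{ chunks: string[]; full: string }> {
  const chunks: string[] = [];
  for await (const chunk of stream) {
    chunks.push(chunk);
  }
  return { chunks, full: chunks.join("") };
}

collectStream(echoChunks("I am an LLM", 4)).then(({ chunks, full }) => {
  console.log(chunks); // [ 'I', ' ', 'a', 'm' ]
  console.log(full); // I am
});
```

One caveat that applies equally to the original model: `slice` counts UTF-16 code units, so for text containing characters such as emoji, `n` can cut a surrogate pair in half. `Array.from(text).slice(0, n)` counts code points instead.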

## Richer outputs

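If you want to take advantage of LangChain's callback system for functionality like token tracking, you can extend the [`BaseChatModel`](https://api.js.langchain.com/classes/langchain_core.language_models_chat_models.BaseChatModel.html) class and implement the lower-level
`_generate` method. It also takes a list of `BaseMessage`s as input, but requires you to construct and return a `ChatResult` object: a list of `ChatGeneration`s plus an optional `llmOutput` field where additional metadata, such as token usage, can be attached.
Here's an example: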

```typescript
import { AIMessage, BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";
import {
  BaseChatModel,
  BaseChatModelCallOptions,
  BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";

interface AdvancedCustomChatModelOptions
  extends BaseChatModelCallOptions {}

interface AdvancedCustomChatModelParams extends BaseChatModelParams {
  n: number;
}

class AdvancedCustomChatModel extends BaseChatModel<AdvancedCustomChatModelOptions> {
  n: number;

  static lc_name(): string {
    return "AdvancedCustomChatModel";
  }

  constructor(fields: AdvancedCustomChatModelParams) {
    super(fields);
    this.n = fields.n;
  }

  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Pass `runManager?.getChild()` when invoking internal runnables to enable tracing
    // await subRunnable.invoke(params, runManager?.getChild());
    const content = messages[0].content.slice(0, this.n);
    const tokenUsage = {
      usedTokens: this.n,
    };
    return {
      generations: [{ message: new AIMessage({ content }), text: content }],
      llmOutput: { tokenUsage },
    };
  }

  _llmType(): string {
    return "advanced_custom_chat_model";
  }
}
```

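This passes the extra returned information through to callback events and to the `streamEvents` method. In the output below, the `tokenUsage` object from `llmOutput` appears in the message's `response_metadata`: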

```typescript
const chatModel = new AdvancedCustomChatModel({ n: 4 });

const eventStream = await chatModel.streamEvents([["human", "I am an LLM"]], {
  version: "v2",
});

for await (const event of eventStream) {
  if (event.event === "on_chat_model_end") {
    console.log(JSON.stringify(event, null, 2));
  }
}
```

```json
{
  "event": "on_chat_model_end",
  "data": {
    "output": {
      "lc": 1,
      "type": "constructor",
      "id": [
        "langchain_core",
        "messages",
        "AIMessage"
      ],
      "kwargs": {
        "content": "I am",
        "tool_calls": [],
        "invalid_tool_calls": [],
        "additional_kwargs": {},
        "response_metadata": {
          "tokenUsage": {
            "usedTokens": 4
          }
        }
      }
    }
  },
  "run_id": "11dbdef6-1b91-407e-a497-1a1ce2974788",
  "name": "AdvancedCustomChatModel",
  "tags": [],
  "metadata": {
    "ls_model_type": "chat"
  }
}
```
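As a minimal sketch of consuming that metadata, the helper below sums the custom token counter across `on_chat_model_end` events. `StreamEventLike` and `totalUsedTokens` are hypothetical names, and the type declares only the fields this example reads. The sketch assumes the live event's `data.output` is the message object itself, whose `response_metadata` field appears in the `AIMessage` dump earlier. `JSON.stringify` prints that same message in the serialized `kwargs` form shown above. `usedTokens` is the key chosen by `AdvancedCustomChatModel`, not a standard LangChain field.

```typescript
// Hypothetical minimal shape of a streamEvents event: only the fields used here.
interface StreamEventLike {
  event: string;
  data: {
    output?: {
      response_metadata?: {
        tokenUsage?: { usedTokens?: number };
      };
    };
  };
}

// Sum the custom `usedTokens` counter over every chat model end event,
// treating missing metadata as zero.
function totalUsedTokens(events: StreamEventLike[]): number {
  let total = 0;
  for (const e of events) {
    if (e.event !== "on_chat_model_end") continue;
    total += e.data.output?.response_metadata?.tokenUsage?.usedTokens ?? 0;
  }
  return total;
}

// Sample events modeled on the output above (values are illustrative).
const sampleEvents: StreamEventLike[] = [
  { event: "on_chat_model_start", data: {} },
  {
    event: "on_chat_model_end",
    data: { output: { response_metadata: { tokenUsage: { usedTokens: 4 } } } },
  },
  { event: "on_chat_model_end", data: { output: { response_metadata: {} } } },
];
console.log(totalUsedTokens(sampleEvents)); // 4
```

In a live `for await` loop over `streamEvents`, the same optional-chaining expression could be applied to each `event` as it arrives instead of to sample objects.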


## Tracing (advanced)

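If you are implementing a custom chat model and want to use it with a tracing service like [LangSmith](https://smith.langchain.com/),
you can automatically log the params used for a given invocation by implementing the `invocationParams()` method on the model.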

This method is purely optional, but anything it returns will be logged as metadata for the trace.

Here's one pattern you might use:

```typescript
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
  BaseChatModel,
  type BaseChatModelCallOptions,
  type BaseChatModelParams,
} from "@langchain/core/language_models/chat_models";
import { AIMessage, type BaseMessage } from "@langchain/core/messages";
import { ChatResult } from "@langchain/core/outputs";

interface CustomChatModelOptions extends BaseChatModelCallOptions {
  // Some required or optional inner args
  tools: Record<string, any>[];
}

interface CustomChatModelParams extends BaseChatModelParams {
  temperature: number;
  n: number;
}

// Hypothetical placeholder for a request to your model provider's API.
// Replace it with a real client call; this stub just echoes the first message.
async function someAPIRequest(
  messages: BaseMessage[],
  _params: Record<string, unknown>
): Promise<string> {
  return String(messages[0].content);
}

class CustomChatModel extends BaseChatModel<CustomChatModelOptions> {
  temperature: number;

  n: number;

  static lc_name(): string {
    return "CustomChatModel";
  }

  constructor(fields: CustomChatModelParams) {
    super(fields);
    this.temperature = fields.temperature;
    this.n = fields.n;
  }

  // Anything returned in this method will be logged as metadata in the trace.
  // It is common to pass it any options used to invoke the function.
  invocationParams(options?: this["ParsedCallOptions"]) {
    return {
      tools: options?.tools,
      n: this.n,
    };
  }

  async _generate(
    messages: BaseMessage[],
    options: this["ParsedCallOptions"],
    runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (!messages.length) {
      throw new Error("No messages provided.");
    }
    if (typeof messages[0].content !== "string") {
      throw new Error("Multimodal messages are not supported.");
    }
    // Reuse the traced params so the request matches what gets logged.
    const additionalParams = this.invocationParams(options);
    const content = await someAPIRequest(messages, additionalParams);
    return {
      generations: [{ message: new AIMessage({ content }), text: content }],
      llmOutput: {},
    };
  }

  _llmType(): string {
    return "advanced_custom_chat_model";
  }
}
```
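One design detail in the pattern above: `invocationParams` returns `tools` from the call options and `n` from the instance, but not `temperature`. As a result, temperature is not part of that metadata. The dependency-free sketch below is a variant that includes every instance setting and omits call options that weren't passed. `ToyTracedModel` and `ToyCallOptions` are hypothetical names, not LangChain APIs.

```typescript
// Hypothetical call options for the sketch (not LangChain's types).
interface ToyCallOptions {
  tools?: Record<string, unknown>[];
  stop?: string[];
}

class ToyTracedModel {
  temperature: number;
  n: number;

  constructor(fields: { temperature: number; n: number }) {
    this.temperature = fields.temperature;
    this.n = fields.n;
  }

  // Merge instance-level settings with per-call options, then drop keys whose
  // value is undefined so the metadata lists only values actually in use.
  invocationParams(options?: ToyCallOptions): Record<string, unknown> {
    const params: Record<string, unknown> = {
      temperature: this.temperature,
      n: this.n,
      tools: options?.tools,
      stop: options?.stop,
    };
    return Object.fromEntries(
      Object.entries(params).filter(([, value]) => value !== undefined)
    );
  }
}

const model = new ToyTracedModel({ temperature: 0.2, n: 4 });
console.log(model.invocationParams());
// { temperature: 0.2, n: 4 }
console.log(model.invocationParams({ tools: [{ name: "search" }] }));
// { temperature: 0.2, n: 4, tools: [ { name: 'search' } ] }
```

Including every setting that influences the output makes two traced runs easier to compare, at the cost of slightly larger metadata.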