In [1]:
# import accelerate

In [2]:
%env HF_ENDPOINT=https://hf-mirror.com
import os
os.environ['HF_HOME'] = '/data1/ckw'
os.environ['HF_ENDPOINT']='https://hf-mirror.com'

env: HF_ENDPOINT=https://hf-mirror.com


First, we construct the Tokenizer

In [3]:
import regex as re
import base64
import os
import json
import tiktoken
from torch import TensorType
from typing import List, Optional, Union, Dict, Any
from transformers import PreTrainedTokenizer
from transformers.utils import logging, PaddingStrategy
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding

The `ChatGLM4Tokenizer` class is a custom `PreTrainedTokenizer` class used to handle the tokenization needs of the specific GLM-4 model. This class mainly has the following structure and member functions:

### Class Structure
- **Attributes**:
- `vocab_files_names`: A dictionary containing vocabulary file names.
- `model_input_names`: A list containing model input names.
- `vocab_file`: Vocabulary file path.
- `name`: The name of the tokenizer.
- `pat_str`: A regular expression string used for tokenization.
- `encode_special_tokens`: Whether to encode special characters.
- `mergeable_ranks`: Mergeable vocabulary rankings.
- `tokenizer`: Encoder based on the `tiktoken` library.
- `decoder`: Decoder, mapping vocabulary rankings to vocabulary.
- `n_words`: Vocabulary size.

### Member function

- **Initialization function**:
```python
def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False, **kwargs)
```
Initialize the `ChatGLM4Tokenizer` class, load the vocabulary file, set the regular expression and tokenizer.

- **Vocabulary size property**:
```python
@property
def vocab_size(self)
```
Return the vocabulary size.

- **Get vocabulary**:
```python
def get_vocab(self)
```
Return the vocabulary dictionary.

- **Convert tokens to strings**:
```python
def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str
```
Convert a sequence of tokens to a string.

- **Tokenization**:
```python
def _tokenize(self, text, **kwargs)
```
Tokenize text into token columnstable.

- **Convert token to ID**:
```python
def _convert_token_to_id(self, token)
```
Convert a token (string) to an ID.

- **Convert ID to token**:
```python
def _convert_id_to_token(self, index)
```
Convert an ID (integer) to a token (string).

- **Save vocabulary**:
```python
def save_vocabulary(self, save_directory, filename_prefix=None)
```
Save the vocabulary to the specified directory.

- **Get prefix tokens**:
```python
def get_prefix_tokens(self)
```
Return a list of prefix tokens.

- **Build a single message**:
```python
def build_single_message(self, role, metadata, message, tokenize=True)
```
Build a single message, including roles, metadata, and message content.

- **Apply chat template**:
```python
def apply_chat_template(self, conversation, add_generation_prompt=False, tokenize=True, padding=False, truncation=False, max_length=None, return_tensors=None, return_dict=False, tokenizer_kwargs=None, add_special_tokens=True, **kwargs)
```
Apply conversation data to the chat template.

- **Build inputs with special tokens**:
```python
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None)
```
Build inputs with special tokens.

- **Padding inputs**:
```python
def _pad(self, encoded_inputs, max_length=None, padding_strategy=PaddingStrategy.DO_NOT_PAD, pad_to_multiple_of=None, return_attention_mask=None)
```
Pad the encoded input to ensure that the input length is consistent.

### Overview of the main functions and methods in the class

- **Initialization**: The `__init__` function is used to initialize the tokenizer, including loading the vocabulary, compiling regular expressions, and initializing the encoder.
- **Get vocabulary and size**: The `get_vocab` and `vocab_size` attributes are used to get the vocabulary and its size.
- **Tokenization and conversion**: The `_tokenize`, `convert_tokens_to_string`, `_convert_token_to_id`, and `_convert_id_to_token` functions are used to implement text tokenization and conversion between tokens and IDs.
- **Saving and loading**: `save_vocabulary` function is used to save the vocabulary to a specified directory.
- **Session handling**: `get_prefThe ix_tokens, build_single_message and apply_chat_template functions are used to process session data and template application.
- **Input processing**: The build_inputs_with_special_tokens and _pad functions are used to process model inputs to ensure that the input format and length meet the requirements.

Through the above structure and member functions, the ChatGLM4Tokenizer class implements a complete tokenizer function that can handle the tokenization requirements of specific GLM-4 models.

In [61]:
class ChatGLM4Tokenizer(PreTrainedTokenizer):
# Define vocabulary file name and model input name
    vocab_files_names = {"vocab_file": "tokenizer.model"}
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(
            self,
            vocab_file,
            padding_side="left",
            clean_up_tokenization_spaces=False,
            encode_special_tokens=False,
            **kwargs
    ):
# Initialize some basic properties
        self.name = "GLM4Tokenizer"
        self.vocab_file = vocab_file
        
# Regular expression pattern string for word segmentation
        pat_str = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
# Compile regular expression
        self.pat_str = regex.compile(pat_str)
# self.pat_str = re.compile(pat_str)
# Whether to encode special characters
        self.encode_special_tokens = encode_special_tokens

# Used to store mergeable word rankings
        mergeable_ranks = {}
# Read vocabulary file
        with open(vocab_file) as f:
            for line in f:
                token, rank = line.strip().split()  # 读取每一行，获取词汇和其对应的排名
                rank = int(rank)  # 将排名转换为整数
                token = base64.b64decode(token)  # 解码词汇
                mergeable_ranks[token] = rank  # 存储到词汇排名字典中

        self.mergeable_ranks = mergeable_ranks

# Initialize the encoder, using the tiktoken library
        self.tokenizer = tiktoken.Encoding(
            name="my_tokenizer",
            pat_str=pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens={}
        )
# decoder, mapping ranking to vocabulary
        self.decoder = {rank: token for token, rank in mergeable_ranks.items()}
# Vocabulary size
        self.n_words = len(self.decoder)

#Call the initialization method of the parent class
        super().__init__(
            padding_side=padding_side,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs
        )

    @property
    def vocab_size(self):
# Return the vocabulary size
        return self.n_words

    def get_vocab(self):
""" Returns the vocabulary dictionary """
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
        """
        将 tokens 序列转换为字符串。
        """
        text = ""
        temp = b""
        for t in tokens:
            if isinstance(t, str):
                if temp:
                    text += temp.decode("utf-8", errors="replace")
                    temp = b""
                text += t
            elif isinstance(t, bytes):
                temp += t
            else:
                raise TypeError("token should only be of type bytes or str")
        if temp:
            text += temp.decode("utf-8", errors="replace")
        return text

    def _tokenize(self, text, **kwargs):
# Use regular expressions and encoders to tokenize text
        tokens = []
        ids = self.tokenizer.encode(text)
        for t in ids:
            tokens.append(self.decoder[t])
        return tokens

    def _convert_token_to_id(self, token):
""" Convert token (string) to id (integer) """
        return self.mergeable_ranks[token]

    def _convert_id_to_token(self, index):
""" Convert id (integer) to token (string) """
        return self.decoder.get(index, "")

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        将词汇表和特殊字符文件保存到指定目录。

        Args:
            save_directory (`str`): 保存词汇表的目录。
            filename_prefix (`str`, *optional*): 保存文件名的前缀。

        Returns:
            `Tuple(str)`: 保存的文件路径。
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, self.vocab_files_names["vocab_file"]
            )
        else:
            vocab_file = save_directory

        with open(self.vocab_file, 'rb') as fin:
            proto_str = fin.read()

        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)

    def get_prefix_tokens(self):
# Return a list of prefix tokens
        prefix_tokens = [self.convert_tokens_to_ids("[gMASK]"), self.convert_tokens_to_ids("<sop>")]
        return prefix_tokens

    def build_single_message(self, role, metadata, message, tokenize=True):
# Build a single message, including roles, metadata, and message content
        assert role in ["system", "user", "assistant", "observation"], role
        if tokenize:
            role_tokens = [self.convert_tokens_to_ids(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n",
                                                                                              disallowed_special=())
            message_tokens = self.tokenizer.encode(message, disallowed_special=())
            tokens = role_tokens + message_tokens
            return tokens
        else:
            return str(f"<|{role}|>{metadata}\n{message}")

    def apply_chat_template(
            self,
            conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"],
            add_generation_prompt: bool = False,
            tokenize: bool = True,
            padding: bool = False,
            truncation: bool = False,
            max_length: Optional[int] = None,
            return_tensors: Optional[Union[str, TensorType]] = None,
            return_dict: bool = False,
            tokenizer_kwargs: Optional[Dict[str, Any]] = None,
            add_special_tokens: bool = True,
            **kwargs,
    ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
    
        if return_dict and not tokenize:
            raise ValueError(
                "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
                "of tokenizer outputs to return."
            )
    
        def handle_single_conversation(conversation):
            input_ids = self.get_prefix_tokens() if add_special_tokens else []
            input_message = "[gMASK]<sop>" if add_special_tokens else ""
            for item in conversation:
                if item.get("tools"):
                    tools = item["tools"]
                    content = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的，你的任务是针对用户的问题和要求提供适当的答复和支持。"
                    for tool in tools:
                        if tool["type"] == "function":
                            function = tool["function"]
                            content += f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}"
                            content += "\n在调用上述函数时，请使用 Json 格式表示调用的参数。"
                        elif tool["type"] == "python":
                            content += "\n\n## python\n\n当你向 `python` 发送包含 Python 代码的消息时，该代码将会在一个有状态的 Jupyter notebook 环境中执行。\n`python` 返回代码执行的输出，或在执行 60 秒后返回超时。\n`/mnt/data` 将会持久化存储你的文件。在此会话中，`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用，这些在线内容的访问将不会成功。"
                        elif tool["type"] == "simple_browser":
                            content += "\n\n## simple_browser\n\n你可以使用 `simple_browser` 工具。该工具支持以下函数：\n`search(query: str, recency_days: int)`：使用搜索引擎进行查询并显示结果，可以使用 `recency_days` 参数控制搜索内容的时效性。\n`mclick(ids: list[int])`：获取一系列指定 id 的页面内容。每次调用时，须选择3-10个页面。选择多个角度的页面，同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的，你也可以多打开一些可能有用的页面而不用担心内容过多。\n`open_url(url: str)`：打开指定的 URL。\n\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\n\n操作步骤：1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\n 如果用户提供了 URL，也可以用 `open_url` 直接打开页面。\n如果初次搜索结果没有找到合适的信息，也可以再次使用 `search` 进行搜索。"
                        elif tool["type"] == "cogview":
                            content += "\n\n## cogview\n\n如果用户的请求中包含了对图像的描述，你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述，规则：\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\n- 应当尽可能详细地描述图像生成的需求，需求描述约 100 英文单词。\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\n- 如无特殊说明，所在地为中国，持有中国立场并遵循中国社会主义价值观。"
                        else:
                            raise NotImplementedError(f"Unknown tool type {tool['type']}")
                    input = self.build_single_message("system", "", content, tokenize=tokenize)
                    if tokenize:
                        input_ids.extend(input)
                    else:
                        input_message += input
                if item["content"]:
                    input = self.build_single_message(
                        item["role"],
                        item.get("metadata", ""),
                        item["content"],
                        tokenize=tokenize
                    )
                    if tokenize:
                        input_ids.extend(input)
                    else:
                        input_message += input
            if add_generation_prompt:
                if tokenize:
                    input_ids.extend([self.convert_tokens_to_ids("<|assistant|>")])
                else:
                    input_message += "<|assistant|>"
# if tokenize:
# input_ids.extend([self.convert_tokens_to_ids("[gMASK]")]) # Use special tokens instead of empty strings
# else:
# input_message += "[gMASK]"
    
            return input_ids if tokenize else input_message
    
# Main logic for handling different session formats
        if isinstance(conversation, list) and all(isinstance(i, dict) for i in conversation):
            result = handle_single_conversation(conversation)
        elif isinstance(conversation, list) and all(isinstance(i, list) for i in conversation):
            result = [handle_single_conversation(c) for c in conversation]
        elif hasattr(conversation, "messages"):
            result = handle_single_conversation(conversation.messages)
        else:
            raise ValueError("Invalid conversation format")
    
        if tokenize:
            output = self.batch_encode_plus(
                [result] if isinstance(result[0], int) else result,
                padding=padding,
                truncation=truncation,
                max_length=max_length,
                return_tensors=return_tensors,
                is_split_into_words=True,
                add_special_tokens=False
            )
            if return_dict:
                return output
            else:
                return output["input_ids"]
        else:
            return result



    def build_inputs_with_special_tokens(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        构建模型输入，适用于序列分类任务，通过连接和添加特殊 tokens。
        BERT 序列格式：
        - 单序列: `[CLS] X [SEP]`
        - 序列对: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`): 要添加特殊 tokens 的 ID 列表。
            token_ids_1 (`List[int]`, *optional*): 可选的第二个 ID 列表，表示序列对。

        Returns:
            `List[int]`: 添加了适当特殊 tokens 的输入 ID 列表。
        """
        prefix_tokens = self.get_prefix_tokens()
        token_ids_0 = prefix_tokens + token_ids_0
        if token_ids_1 is not None:
            token_ids_0 = token_ids_0 + token_ids_1 + [self.convert_tokens_to_ids("<eos>")]
        return token_ids_0

    def _pad(
            self,
            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
            max_length: Optional[int] = None,
            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
            pad_to_multiple_of: Optional[int] = None,
            return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        填充编码后的输入（左右填充，并预定义长度或批处理中的最大长度）

        Args:
            encoded_inputs: 包含编码后的输入 (`List[int]`) 或批处理的输入 (`List[List[int]]`) 的字典。
            max_length: 返回列表的最大长度，并可选的填充长度。
            padding_strategy: 用于填充的策略。
                - PaddingStrategy.LONGEST: 填充到批处理中的最长序列
                - PaddingStrategy.MAX_LENGTH: 填充到最大长度（默认）
                - PaddingStrategy.DO_NOT_PAD: 不填充
                填充策略由 self.padding_side 定义：
                    - 'left': 在序列左侧填充
                    - 'right': 在序列右侧填充
            pad_to_multiple_of: (可选) 整数，如果设置将填充序列为该值的倍数。
            return_attention_mask: (可选) 设置为 False 以避免返回注意力掩码（默认：根据模型具体情况设置）
        """
# Make sure the fill side is 'left'
        assert self.padding_side == "left"

        required_input = encoded_inputs[self.model_input_names[0]]
        seq_length = len(required_input)

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

# If there is no attention mask, initialize
        if "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * seq_length

        if "position_ids" not in encoded_inputs:
            encoded_inputs["position_ids"] = list(range(seq_length))

        if needs_to_be_padded:
            difference = max_length - len(required_input)

            if "attention_mask" in encoded_inputs:
                encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
            if "position_ids" in encoded_inputs:
                encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

        return encoded_inputs


### Key Points Explanation

1. **`mergeable_ranks`**:
- Stores ranking information for mergeable words. When reading the vocabulary file, each word has a corresponding rank (integer), which is decoded into the original string and then stored in the `mergeable_ranks` dictionary.

2. **`pat_str`**:
- Regular expression string used to define the word segmentation pattern. This regular expression is mainly used to match common English abbreviations, words, numbers and other non-alphanumeric characters. The compiled regular expression is stored in `self.pat_str` for the tokenizer.

3. **`self.tokenizer`**:
- Encoder created using the `tiktoken` library. Initialize a custom tokenizer with the provided regular expression pattern and mergeable vocabulary ranking.

4. **`self.decoder`**:
- Decoder, mapping vocabulary ranking to vocabulary. The inverse dictionary created using the `mergeable_ranks` dictionary is used to convert from ID back to the original vocabulary.

5. **`self.n_words`**:
- Vocabulary size, i.e. the number of words. Obtained by calculating the length of `self.decoder`.

6. **Initialization process**：
- Read the vocabulary file, parse each line to get the vocabulary and its ranking, then decode the vocabulary and store it in the `mergeable_ranks` dictionary. Next, use the `tiktoken` library to initialize the encoder and create the decoder and vocabulary size.

There is one thing worth noting:
```python
def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
"""
Convert a sequence of tokens to a string.
"""
text = ""
temp = b""
for t in tokens:
if isinstance(t, str):
if temp:
text += temp.decode("utf-8", errors="replace")
temp = b""
text += telif isinstance(t, bytes):
temp += t
else:
raise TypeError("token should only be of type bytes or str")
if temp:
text += temp.decode("utf-8", errors="replace")
return text
```

This function decodes `token` to `bytes` because in real applications, tokens may be part of a string or a byte sequence. In some cases, tokens in the vocabulary may be stored in the form of `bytes` instead of directly stored as strings. Here are some detailed reasons and scenarios:

1. **Support multiple encoding formats**:
- The tokenizer may process multiple data sources, some of which may store tokens in `bytes` format. For example, when the data is compressed, encrypted, or uses a specific encoding, the token may be represented in bytes.

2. **DataFlexibility in processing**:
- By supporting both `bytes` and `str` types, the tokenizer can more flexibly handle input data from different sources and formats. This is particularly useful when processing text containing binary data, such as some special tokens or characters.

3. **Ensure data consistency**:
- During the tokenization and decoding process, mixed-type tokens (i.e. both strings and byte sequences) may be encountered. In order to ensure data consistency and correctly concatenate strings, the function needs to handle both `bytes` and `str` types.

4. **Compatibility considerations**:
- Some NLP tools and libraries may return tokens in byte form when processing text. In order to be compatible with these tools and libraries, the tokenizer needs to be able to process and convert these byte tokens.

Let's take a look at the workflow of this function in detail:

1. **Initialize empty string and byte sequence**:
- `text = ""` initializes an empty string to store the final result.
- `temp = b""` initializes an empty byte sequence to temporarily store byte tokens.

2. **Traverse the token list**:
- For each token, check its type.
- If the token is a string type (`str`), check `temp` is empty. If `temp` is not empty, decode `temp` into a string and add it to `text`, then clear `temp`. After that, add the current string token to `text`.
- If token is of byte type (`bytes`), add it to `temp` for subsequent decoding.
- If token is neither a string nor a byte type, throw a type error.

3. **Process the remaining byte sequence**:
- After the loop ends, if `temp` is not empty, decode it into a string and add it to `text`.

Through the above steps, the function can correctly handle a mixed type token list and convert it into a complete string. Here are the comments in the code to better explain these steps:

```python
def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
"""
Convert the tokens sequence to a string.
"""
text = "" # Initialize an empty string to store the result
temp = b"" # Initialize an empty byte sequenceTemporarily store byte tokens
for t in tokens:
if isinstance(t, str):
if temp:
# If temp is not empty, decode it to a string and add it to text
text += temp.decode("utf-8", errors="replace")
temp = b"" # Clear temp
text += t # Add string token to text
elif isinstance(t, bytes):
temp += t # Add byte token to temp
else:
raise TypeError("token should only be of type bytes or str") # Throw a type error
if temp:
# Process the remaining byte sequence and decode it toString and add it to text
text += temp.decode("utf-8", errors="replace")
return text # Return the final string
```

Next, set the config and simply set the relevant properties.

In [62]:
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    model_type = "chatglm"

    def __init__(
            self,
            num_layers=28,
            padded_vocab_size=65024,
            hidden_size=4096,
            ffn_hidden_size=13696,
            kv_channels=128,
            num_attention_heads=32,
            seq_length=2048,
            hidden_dropout=0.0,
            classifier_dropout=None,
            attention_dropout=0.0,
            layernorm_epsilon=1e-5,
            rmsnorm=True,
            apply_residual_connection_post_layernorm=False,
            post_layer_norm=True,
            add_bias_linear=False,
            add_qkv_bias=False,
            bias_dropout_fusion=True,
            multi_query_attention=False,
            multi_query_group_num=1,
            rope_ratio=1,
            apply_query_key_layer_scaling=True,
            attention_softmax_in_fp32=True,
            fp32_residual_connection=False,
            **kwargs
    ):
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.hidden_dropout = hidden_dropout
        self.classifier_dropout = classifier_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.rope_ratio = rope_ratio
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        super().__init__(**kwargs)


Then we can start building our model.

In [63]:
import json
import math
import copy
import warnings
import re
import sys

import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
from copy import deepcopy

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput

In [64]:
sys.platform

'linux'

In [65]:
# If the system platform is not Darwin (i.e. macOS or Linux), set the following to improve performance
if sys.platform != 'darwin':
# Disable JIT profiling mode
    torch._C._jit_set_profiling_mode(False)
# Disable JIT profiling executor
    torch._C._jit_set_profiling_executor(False)
# Enable tensor fusion on CPU
    torch._C._jit_override_can_fuse_on_cpu(True)
# Enable tensor fusion on GPU
    torch._C._jit_override_can_fuse_on_gpu(True)

# Get the logger
logger = logging.get_logger(__name__)

# Checkpoints for documentation
_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
# Configuration for documentation
_CONFIG_FOR_DOC = "ChatGLMConfig"

# Default initialization function
def default_init(cls, *args, **kwargs):
    return cls(*args, **kwargs)

# Invalid Score Logits Processor Class
class InvalidScoreLogitsProcessor(LogitsProcessor):
# Process input logits
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# If there are NaN or infinity in logits, reset them
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()  # 将所有分数重置为 0
            scores[..., 198] = 5e4  # 将特定位置的分数设置为一个大值
        return scores

# Function to split a tensor along the last dimension
def split_tensor_along_last_dim(
        tensor: torch.Tensor,
        num_partitions: int,
        contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """
    沿着最后一个维度分割张量。

    参数:
        tensor: 输入张量。
        num_partitions: 要分割的部分数量。
        contiguous_split_chunks: 如果为 True，则使每个分块在内存中是连续的。

    返回:
        张量的列表
    """
# Get the size of the last dimension
    last_dim = tensor.dim() - 1
    last_dim_size = tensor.size()[last_dim] // num_partitions
# Split tensor
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# NOTE: torch.split does not create contiguous tensors by default
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)
    return tensor_list


### Key points explanation

1. **System platform check**:
- If the system platform is not `Darwin` (macOS), turn off the profiling mode and executor of the JIT, and enable tensor fusion on the CPU and GPU. These settings can improve the performance of the model.

2. **Logger**:
- Get a logger for logging information and debugging during model building.

3. **Document string**:
- `_CHECKPOINT_FOR_DOC` and `_CONFIG_FOR_DOC` are string constants for document generation, indicating the checkpoint and configuration files of the model and configuration.

4. **Default initialization function**:
- The `default_init` function is a general initialization function used to instantiate the class.

5. **`InvalidScoreLogitsProcessor` class**:
- This is a logits processor class used to handle invalid logits scores. If there are `NaN` or infinite values ​​in the scores, reset them to 0, and set the scores of specific positions (such as 198) to a large value (`5e4`) to ensure that the model outputs valid results.

6. **`split_tensor_`along_last_dim` function:
- This function is used to split a tensor along its last dimension. The parameters include the input tensor, the number of parts to split, and a boolean value (whether to make each chunk contiguous in memory). If `contiguous_split_chunks` is true, each chunk is contiguous in memory; otherwise, a list of split tensors is returned directly.

In [66]:
# Rotation position embedding class
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None):
        super().__init__()
# Calculate the reciprocal frequency
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
# Register the countdown frequency as buffer
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl
        self.rope_ratio = rope_ratio

# Specific method to implement forward propagation
    def forward_impl(
            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
    ):
        """
        增强的 Transformer 使用旋转位置嵌入。

        参考自:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT 许可证:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
# Calculate $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        base = base * self.rope_ratio
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem))

# Create position index `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)

# Calculate the product of the position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

# Cache calculated cosine and sine values
        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

# Emulate the behavior of complex32, otherwise you will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

# Forward propagation method
    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
        )

# Using the TorchScript JIT compiler to optimize the rotation position embedding application function
@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
# x: [b, np, sq, hn]
    b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
# Truncating to support variable sizes
    rope_cache = rope_cache[:, :sq]
    xshaped = x.reshape(b, np, sq, rot_dim // 2, 2)
    rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)

# RMS normalization layer class
class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
# Initialize weight parameters
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
# Get the data type of the input tensor
        input_dtype = hidden_states.dtype
# Calculate variance
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
# Perform RMS normalization
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
# Apply weights and return the same data type as input
        return (self.weight * hidden_states).to(input_dtype)

### Key Points Explanation

1. **`RotaryEmbedding` class**:
- A module for implementing rotational position embedding.
- `inv_freq` is a tensor of inverse frequency used for position embedding calculation.
- `forward_impl` method calculates rotational position embedding based on sequence length and embedding dimension.
- `forward` method calls `forward_impl` for forward propagation.

2. **`apply_rotary_pos_emb` function**:
- Update the input tensor with rotational position embedding.
- `x` is the input tensor, `rope_cache` is the precomputed rotational position embedding.
- The function first splits and reshapes the input tensor, then applies the rotational position embedding to each block, and finally splices the result back to the original tensor.

3. **`RMSNorm` class**:
- A module for implementing RMS normalization.
- `weight` is the normalized weight parameter.
- `forward` method calculates the variance of the input tensor, normalizes it, and then applies the weights.

These comments and explanations can help understand the functions and implementation details of each part, which is very useful for model building and debugging.

In [67]:
class CoreAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number):
        super(CoreAttention, self).__init__()

# Configuration parameters
        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

# Calculate the projection size
        projection_size = config.kv_channels * config.num_attention_heads

# The value of each attention head and each partition
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.coeff = coeff

# Attention dropout
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        pytorch_major_version = int(torch.__version__.split('.')[0])
        if pytorch_major_version >= 2:
# PyTorch 2.0 and above
            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                                 is_causal=True)
            else:
                if attention_mask is not None:
                    attention_mask = ~attention_mask
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                                 attention_mask)
            context_layer = context_layer.transpose(1, 2).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.reshape(*new_context_layer_shape)
        else:
# Handling PyTorch versions below 2.0

# Raw attention score
# [b, np, sq, sk]
            output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2))

# Resize the view [b, np, sq, hn] -> [b * np, sq, hn]
            query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
# Resize the view [b, np, sk, hn] -> [b * np, sk, hn]
            key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)

# Pre-allocate input tensor: [b * np, sq, sk]
            matmul_input_buffer = torch.empty(
                output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
                device=query_layer.device
            )

# Calculate raw attention scores. [b * np, sq, sk]
            matmul_result = torch.baddbmm(
                matmul_input_buffer,
                query_layer,  # [b * np, sq, hn]
                key_layer.transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / self.norm_factor),
            )

# Change the view to [b, np, sq, sk]
            attention_scores = matmul_result.view(*output_size)

# Handling attention scores and dropout
            if self.attention_softmax_in_fp32:
                attention_scores = attention_scores.float()
            if self.coeff is not None:
                attention_scores = attention_scores * self.coeff
            if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
                attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
                                            device=attention_scores.device, dtype=torch.bool)
                attention_mask.tril_()
                attention_mask = ~attention_mask
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
            attention_probs = F.softmax(attention_scores, dim=-1)
            attention_probs = attention_probs.type_as(value_layer)

# Drop the attention of the entire token, this comes from the original Transformer paper
            attention_probs = self.attention_dropout(attention_probs)

# Resize the view [b * np, sq, hn]
            value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1)
# Resize the view [b * np, sq, sk]
            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# Calculate the context layer [b * np, sq, hn]
            context_layer = torch.bmm(attention_probs, value_layer)
# Resize the view [b, np, sq, hn]
            context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [b, sq, np, hn]
            context_layer = context_layer.transpose(1, 2).contiguous()
# [b, sq, np, hn] --> [b, sq, hp]
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.reshape(*new_context_layer_shape)

        return context_layer

### CoreAttention Detailed Explanation and Formula

In neural network models, especially Transformer architecture, attention mechanism plays a vital role. CoreAttention implements the core part of the self-attention mechanism, which includes the following steps: calculating attention scores, applying attention masks, calculating attention weights, and calculating context vectors.

#### 1. Calculating Attention Scores

The calculation of attention scores can be expressed as matrix multiplication. For query vector $Q$ and key vector $K$, the formula for calculating the attention score matrix $A$ is:
\begin{align*} A = \frac{QK^T}{\sqrt{d_k}} \end{align*}
Where $d_k$ is the dimension of the key vector, $\sqrt{d_k}$ is a scaling factor to prevent the score value from being too large.

In the code, it is implemented by matrix multiplication:
```python
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer, # [b * np, sq, hn]
key_layer.transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
```
Where `self.norm_factor` is $\sqrt{d_k}$.

#### 2. Apply Attention Mask

In order to prevent the model from focusing on unnecessary parts, use attention mask to mask certain positions. The attention mask is implemented by setting the attention score of the corresponding position to negative infinity:
```python
if attention_mask is not None:
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
```

#### 3. Calculate Attention Weights

Apply the softmax function to convert the attention score into a probability distribution, indicating the degree of attention of each query vector to the key vector:
\begin{align*} \text{AttentionWeights} = \text{softmax}(A) \end{align*}
Implemented in code as:
```python
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = attention_probs.type_as(value_layer)
```

#### 4. Calculate context vectors

The context vector is obtained by multiplying the attention weights with the value vector and summing them:
\begin{align*} \text{Context} = \text{AttentionWeights} \cdot V \end{align*}
Where $V$ is the value vector.

Implemented in code as:
```python
context_layer = torch.bmm(attention_probs, value_layer)
```

#### Detailed principle explanation

1. **Initialization and setting**:
- When initializing the class, the attention scale, number of layers, and other information are set according to the configuration parameters.
- `projection_size` defines the projection size, `hidden_size_per_attention_he`ad` and `num_attention_heads_per_partition` define the hidden layer size of each attention head and each partition.

2. **Forward propagation steps**:
- **Query, key and value calculation**: Calculate the query, key and value vectors.
- **Calculate attention scores**: Calculate the attention score matrix through matrix multiplication.
- **Apply attention mask**: Set the positions to be masked to negative infinity to avoid affecting subsequent calculations.
- **Calculate attention weights**: Apply the softmax function to get the attention weights.
- **Calculate context vector**: Multiply the attention weights with the value vector to get the context vector.

3. **Specific implementation details**:
- Choose the appropriate attention calculation method according to the PyTorch version.
- Transform tensors of different dimensions to ensure shape matching.
- Use dropout to prevent overfitting.

### Correspondence between formula and code

- **Attention score**:
\begin{align*}
A = \frac{QK^T}{\sqrt{d_k}}
\end{align*}
Corresponding code:
```python
matmul_result= torch.baddbmm(
matmul_input_buffer,
query_layer,
key_layer.transpose(1, 2),
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
```

- **Apply attention mask**:
\begin{align*}
A'_{ij} = \begin{cases} 
A_{ij} & \text{if } \text{mask}_{ij} = 1 \\
-\infty & \text{if } \text{mask}_{ij} = 0 
\end{cases}
\end{align*}
Corresponding code:
```python
if attention_mask is not None:
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
```

- **Attention weight**:\begin{align*}
\text{AttentionWeights}_{ij} = \frac{\exp(A'_{ij})}{\sum_k \exp(A'_{ik})}
\end{align*}
Corresponding code:
```python
attention_probs = F.softmax(attention_scores, dim=-1)
```

- **Context vector**:
\begin{align*}
\text{Context}_{i} = \sum_j \text{AttentionWeights}_{ij} V_j
\end{align*}
Corresponding code:
```python
context_layer = torch.bmm(attention_probs, value_layer)
```

Next, let's review the attention mechanism

### Principles and formula explanation of the attention mechanism

#### 1. What is the attention mechanism?
The attention mechanism is a technology in deep learning, especially in the field of natural language processing, which enables the model to dynamically focus on different parts when processing an input sequence. Simply put, the attention mechanism allows the model to selectively focus on other parts of the input sequence when processing a certain element, rather than treating all of them equally.

#### 2. The core formula of the attention mechanism
Let's look at the formula for the attention score:
\begin{align*} A = \frac{QK^T}{\sqrt{d_k}} \end{align*}

The symbols here are explained as follows:
- \( Q \) is the query vector.
- \( K \) is the key vector.
- \( d_k \) is the dimension of the key vector.
- \( A \) is the attention score matrix.

#### 3. Why use this formula to calculate the attention score?

The core idea of ​​the attention score formula is to measure the relevance between the query and each key by calculating the dot product of the query vector and the key vector. The larger the dot product result, the stronger the relevance between the query and the key, and the model should pay more attention to the value vector corresponding to the key.

Scaling factor in the formula\(\sqrt{d_k}\) is to avoid the dot product result being too large, because if it is not scaled, the larger vector dimension will cause the dot product result to be very large, which will cause the softmax function to output a gradient close to zero, affecting the model training effect.

#### 4. Calculate attention weights
After obtaining the attention score matrix \( A \), convert it to the attention weight matrix through the softmax function:
\begin{align*} \text{AttentionWeights}_{ij} = \frac{\exp(A'_{ij})}{\sum_k \exp(A'_{ik})} \end{align*}
Where \( A' \) is the attention score matrix after applying the mask.

#### 5. Calculate the context vector
Finally, use the attention weight matrix to perform weighted summation on the value vector to obtain the context vector:
\begin{align*} \text{Context}_{i} = \sum_j \text{AttentionWeights}_{ij} V_j \end{align*}
Here \( V \) is the value vector.

### Relationship between attention mechanism and human attention

The attention mechanism has certain similarities with human attention, but there are also significant differences.

- **Similarities**:
- **Selective attention**: Just like humans selectively focus on certain important paragraphs when reading an article and ignore other parts, the attention mechanism also allows the model to selectively focus on different parts when processing a sequence.
- **Dynamic adjustment**: Human attention is dynamic and adjusts its focus according to the context. The attention mechanism is also dynamic and can adjust the attention weight according to changes in the input.

- **Difference**:
- **Different mechanisms**: Human attention is achieved through the complex neural network of the brain, including the comprehensive processing of multiple sensory information such as vision and hearing. The attention mechanism is a mathematical method that is implemented through operations such as dot product and softmax.
- **Different purposes**: Human attention is used for understanding and interaction, while the attention mechanism is mainly used to improve the performance and efficiency of the model when processing long sequence data.

### Details in the specific code implementation

In the `CoreAttention` class, the calculation and application of the attention score is implemented by the following code snippet:
```python
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer, # [b * np, sq, hn]
key_layer.transpose(1, 2), #[b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
```
Here, the `torch.baddbmm` function performs batch matrix multiplication, calculates the \(QK^T\) part of the formula, and scales it by dividing it by \(\sqrt{d_k}\).

Apply softmax to get the attention weights:
```python
attention_probs = F.softmax(attention_scores, dim=-1)
```

Finally, calculate the context vector:
```python
context_layer = torch.bmm(attention_probs, value_layer)
```

This step multiplies the attention weights with the value vector and sums them to get the final context representation.

Through the above explanation, we can better understand the principle and implementation of the attention mechanism, as well as its important role in the model.

After completing the core implementation of the attention mechanism, we build the self-attention

### Principle and formula explanation of SelfAttention class

#### Function of SelfAttention class

The SelfAttention class implements the Self-Attention Mechanism, which is the core part of the Transformer model. The goal of the self-attention mechanism is to allow the representation of each position to dynamically pay attention to other positions of the input sequence, thereby capturing global information.

#### Steps and formulas of the self-attention mechanism

The self-attention mechanism includes the following key steps:

1. **Linear transformation**:
The input hidden state \(X\) is mapped to the query \(Q\), key \(K\) and value \(V\) spaces through linear layers:
\begin{align*}
Q = XW_Q, \quad K = XW_K, \quad V = XW_V
\end{align*}
Among them, \(W_Q\), \(W_K\) and \(W_V\) are the learned weight matrices.

2. **Calculate attention score**:
The attention score matrix \(A\) is obtained by calculating the dot product of the query \(Q\) and the key \(K\) and dividing by the scaling factor \(\sqrt{d_k}\):
\begin{align*}
A = \frac{QK^T}{\sqrt{d_k}}
\end{align*}

3. **Apply attention mask**:
For self-attention mechanism, if mask is used (for example, in decoding stage), set the positions that do not need to be paid attention to to negative infinity to mask:
\begin{align*}
A'_{ij} = \begin{cases}
A_{ij} & \text{if } \text{mask}_{ij} = 1 \\
-\infty & \text{if } \text{mask}_{ij} = 0
\end{cases}
\end{align*}

4. **Calculate attention weight**:
Apply softmax function to the attention score matrix \(A'\) to get the attention weight matrix:
\begin{align*}
\text{AttentionWeights}_{ij} = \frac{\exp(A'_{ij})}{\sum_k \exp(A'_{ik})}
\end{align*}

5. **Calculate the context vector**:
Use the attention weight to calculate the value \(V\)Weighted summation to get the context vector:
\begin{align*}
\text{Context} = \text{AttentionWeights} \cdot V
\end{align*}

### Specific implementation of the SelfAttention class

The following are the implementation details of the core steps and formulas in the SelfAttention class:

In [68]:
class SelfAttention(torch.nn.Module):
"""Abstract class for parallel self-attention layers.

    自注意力层接受形状为 [s, b, h] 的输入，并返回相同形状的输出。
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads

# The value of each attention head and each partition
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
            self.num_multi_query_groups_per_partition = config.multi_query_group_num
            self.qkv_hidden_size = (
                    self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
            )
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         device=device, **_config_to_kwargs(config)
                                         )

        self.core_attention = CoreAttention(config, self.layer_number)

# Output layer
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
                               device=device, **_config_to_kwargs(config)
                               )

    def forward(self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True):
# hidden_states: [b, sq, h]

        # =====================
# Query, Key and Value
        # =====================

# Attention head [b, sq, h] --> [b, sq, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                               (self.num_attention_heads_per_partition,
                                3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

# [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

# [b, sq, np, hn] -> [b, np, sq, hn]
        query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]]

# Apply relative position encoding (rotation embedding)
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

# Adjust key and value for inference
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=2)
            value_layer = torch.cat((cache_v, value_layer), dim=2)
        if use_cache:
            if kv_cache is None:
                kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
            else:
                kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        if self.multi_query_attention:
            key_layer = key_layer.unsqueeze(2)
            key_layer = key_layer.expand(
                -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1
            )
            key_layer = key_layer.contiguous().view(
                key_layer.size()[:1] + (self.num_attention_heads_per_partition,) + key_layer.size()[3:]
            )
            value_layer = value_layer.unsqueeze(2)
            value_layer = value_layer.expand(
                -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1
            )
            value_layer = value_layer.contiguous().view(
                value_layer.size()[:1] + (self.num_attention_heads_per_partition,) + value_layer.size()[3:]
            )

# Core attention calculation
        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

# Output. [sq, b, h]
        output = self.dense(context_layer)

        return output, kv_cache

### Detailed explanation of the principle

1. **Linear transformation**:
```python
mixed_x_layer = self.query_key_value(hidden_states)
```
Input hidden state `hidden_states` is mapped to the query, key and value space through the linear layer.

2. **Split query, key and value**:
```python
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
```
Split the result after linear transformation into query, key and value.

3. **Shape transformation**:
```python
query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]]
```
Transform the shape of query, key, and value from `[b, sq, np, hn]` to `[b, np, sq, hn]`.

4. **Should**Use rotary position encoding**:
```python
if rotary_pos_emb is not None:
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
```
If rotary position encoding is present, apply position encoding to query and key.

5. **Adjust keys and values ​​for inference**:
```python
if kv_cache is not None:
cache_k, cache_v = kv_cache

key_layer = torch.cat((cache_k, key_layer), dim=2)
value_layer = torch.cat((cache_v, value_layer), dim=2)
if use_cache:
if kv_cache is None:kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
else:
kv_cache = (key_layer, value_layer)
else:
kv_cache = None
```

6. **Core attention calculation**:
```python
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
```

7. **Output layer**:
```python
output = self.dense(context_layer)
```

Through the above detailed principle explanation and formula, we can better understand the implementation of the SelfAttention class and its role in the Transformer model. The self-attention mechanism dynamically adjusts the weights between different positions to make the modelThe model can capture global information more effectively, thereby improving the performance and generalization ability of the model.

In [69]:
def _config_to_kwargs(args):
    common_kwargs = {
        "dtype": args.torch_dtype,
    }
    return common_kwargs


class MLP(torch.nn.Module):
"""Multilayer Perceptron (MLP).

    MLP 将接受隐藏状态为 h 的输入，将其投影到 4*h 的隐藏维度，执行非线性变换，然后将状态投影回 h 的隐藏维度。
    """

    def __init__(self, config: ChatGLMConfig, device=None):
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear

# Project to 4h. If using swiglu, double the output width, see https://arxiv.org/pdf/2002.05202.pdf for details
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.activation_func = swiglu

# Project back to h.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

    def forward(self, hidden_states):
# [s, b, 4hp]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(torch.nn.Module):
"""A transformer layer.

    Transformer 层接受形状为 [s, b, h] 的输入，并返回相同形状的输出。
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(GLMBlock, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
        self.fp32_residual_connection = config.fp32_residual_connection

        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
# Layer normalization on input data
        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                             dtype=config.torch_dtype)

# Self-attention layer
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.hidden_dropout = config.hidden_dropout

# Layer normalization on attention output
        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                      dtype=config.torch_dtype)

# Multilayer Perceptron (MLP)
        self.mlp = MLP(config, device=device)

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
    ):
# hidden_states: [s, b, h]

# Layer normalization at the beginning of transformer layer
        layernorm_output = self.input_layernorm(hidden_states)
# Self-attention layer
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache
        )

# Residual connection
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        layernorm_input = residual + layernorm_input

# Layer normalization after self-attention
        layernorm_output = self.post_attention_layernorm(layernorm_input)

# Multilayer Perceptron (MLP)
        mlp_output = self.mlp(layernorm_output)

# The second residual connection
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
        output = residual + output

        return output, kv_cache


class GLMTransformer(torch.nn.Module):
"""Transformer class."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(GLMTransformer, self).__init__()

        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm

# Number of layers
        self.num_layers = config.num_layers

Transformer Layer
        def build_layer(layer_number):
            return GLMBlock(config, layer_number, device=device)

        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])

        if self.post_layer_norm:
            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
# Final layer normalization before output
            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                 dtype=config.torch_dtype)

        self.gradient_checkpointing = False

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None,
            use_cache: Optional[bool] = True,
            output_hidden_states: Optional[bool] = False,
    ):
        if not kv_caches:
            kv_caches = [None for _ in range(self.num_layers)]
        presents = () if use_cache else None
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_self_attentions = None
        all_hidden_states = () if output_hidden_states else None
        for index in range(self.num_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer = self._get_layer(index)
            if self.gradient_checkpointing and self.training:
                layer_ret = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_caches[index],
                    use_cache,
                    use_reentrant=False
                )
            else:
                layer_ret = layer(
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_cache=kv_caches[index],
                    use_cache=use_cache
                )
            hidden_states, kv_cache = layer_ret
            if use_cache:
# Token by token decoding, using tuple format
                if kv_caches[0] is not None:
                    presents = presents + (kv_cache,)
# Prefill decode, use tensor format to save CUDA memory
                else:
                    if len(presents) == 0:
                        presents = kv_cache
                    else:
                        presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

# Final layer normalization
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states, presents, all_hidden_states, all_self_attentions


## Changes in tensor shape and their meaning.

### Tensor shapes in `CoreAttention` class

1. **`query_layer`, `key_layer`, `value_layer` shape changes**:
```python
# Input shape [b, np, sq, hn] and [b, np, sk, hn]
# [b, np, sq, hn] -> [b * np, sq, hn]
query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1)
# [b, np, sk, hn] -> [b * np, sk, hn]
key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1)
```

2. **Calculation of raw attention score**:
```python
# Input shape [b * np, sq, hn] and [b * np, hn, sk]
#Result shape [b * np, sq, sk]
matmul_result = torch.baddbmm(matmul_input_buffer, query_layer, key_layer.transpose(1, 2), beta=0.0, alpha=(1.0 / self.norm_factor))
```

3. **Adjust the view to the original shape**:
```python
# [b * np, sq, sk] -> [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
```

4. **Attention probability shape change**:
```python
# Input shape [b, np, sq, sk]
# Adjusted shape [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
```

5. **Calculation contextLayer**:
```python
# Input shapes [b * np, sq, sk] and [b * np, sk, hn]
# Result shape [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer)
# Adjust view [b * np, sq, hn] -> [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] -> [b, sq, np, hn]
context_layer = context_layer.transpose(1, 2).contiguous()
# [b, sq, np, hn] -> [b, sq, hp]
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(*new_context_layer_shape)
```

### Tensor shape in `SelfAttention` class

1. **`mixed_x_layer` shape change**:
```python
# Input shape [b, sq, h]
# Result shape [b, sq, (np * 3 * hn)]
mixed_x_layer = self.query_key_value(hidden_states)
```

2. **Multi-query attention shape change**:
```python
# [b, sq, (np * 3 * hn)] -> [b, sq, np, hn] and [b, sq, np, hn]
(query_layer, key_layer, value_layer) = mixed_x_layer.split([...], dim=-1)
# Adjust view
query_layer = query_layer.view(query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))
key_layer = key_layer.view(key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head))
value_layer = value_layer.view(value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head))
```

3. **Shape change of normal attention**:
```python
# [b, sq, (np * 3 * hn)] -> [b, sq, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [b, sq, np, 3 * hn] -> 3 [b, sq, np, hn]
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
```

4. **Transpose operation**:
```python
# [b, sq, np, hn] -> [b, np, sq, hn]
query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]]
```

### Tensor shape in `MLP` class

1. **Forward pass**:
```python
# InputShape [s, b, h]
# Shape after dense_h_to_4h linear layer [s, b, 4hp]
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel)
# Shape after dense_4h_to_h linear layer [s, b, h]
output = self.dense_4h_to_h(intermediate_parallel)
```

### Tensor shape in `GLMBlock` class

1. **Input and output shapes of self-attention layer**:
```python
# Input shape [s, b, h]
# Output shape of self-attention layer [s, b, h]
attention_output, kv_cache = self.self_attention(layernorm_output, attention_mask, rotary_pos_emb, kv_cache=kv_cache, use_cache=use_cache)
```

2. **Residual connection**:
```python
# Residual connection, shape remains unchanged [s, b, h]
layernorm_input = residual + layernorm_input
```

3. **Input and output shape of multi-layer perceptron (MLP)**:
```python
# MLP output shape [s, b, h]
mlp_output = self.mlp(layernorm_output)
```

### Tensor shape in `GLMTransformer` class

1. **Layer-by-layer processing**:
```python
# Input and output shape of each layer [s, b, h]
for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states, kv_cache = layer(hidden_states, attention_mask, rotary_pos_emb, kv_cache=kv_caches[index], use_cache=use_cache)
```

2. **Final layer normalization**:
```python
# Final layer normalization, shape [s, b, h]
if self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
```

Summarizing these shape changes helps understand how the input and output of each layer are passed and processed, ensuring that the model maintains the correct tensor shape at each step.

### Key Points Explanation

1. **`CoreAttention` class**:
- Implements the core attention mechanism.
- Supports `scaled_dot_product_attention` of PyTorch 2.0 and above, as well as custom attention calculations of older versions.

2. **`SelfAttention` class**:
- Implements the self-attention layer.
- Includes the calculation of query, key, value, and the application of the core attention mechanism.
- Supports multi-query attention mechanism.

3. **`MLP` class**:
- Multilayer perceptron (MLP), including two linear layers and a nonlinear activation function.

4. **`GLMBlock` class**:
- Implements a Transformer layer, including a self-attention layer and an MLP layer.
- Includes layer normalization and residual connections.

5. **`GLMTransformer` class**:
- Implements the Transformer model, including multiple Transformer layers.
- Support for gradient checkpointing and final layer normalization.

### Steps for extracting input hidden states from the self-attention mechanism

In the SelfAttention class, the self-attention mechanism extracts and processes the input hidden states through a series of steps to finally generate a context vector. These steps include linear transformation, calculating attention scores, applying attention masks, calculating attention weights, and generating context vectors. The following are the detailed steps and explanations:

#### 1. Input hidden state

The shape of the input hidden state \( \text{hidden_states} \) is usually \([b, sq, h]\), where:
- \( b \) is the batch size.
- \( sq \) is the sequence length.
- \( h \) is the hidden layer dimension (hidden size).

#### 2. Linear transformation

The input hidden state \( \text{hidden_states} \) is projected to the query \( Q \), key \( K \), and value \( V \) space through a linear layer. The purpose of this step is to map the input to different subspaces for attention calculation. The formula is as follows: \begin{align*} Q = \text{hidden_states} \cdot W_Q \end{align*}
\begin{align*} K = \text{hidden_states} \cdot W_K \end{align*}
\begin{align*} V = \text{hidden_states} \cdot W_V \end{align*}

In the code, this is done as follows:
```python
mixed_x_layer = self.query_key_value(hidden_states)
```
Here `self.query_key_value` is a linear layer that projects the input hidden state into a larger space, and the resulting shape is \([b, sq, (3 \times \text{hidden_size})]\).

#### 3. Split query, key, and value

Split the above result into query, key, and value vectors. After splitting, the shape of each vector is \([b, sq, \text{hidden_size}]\).
```python
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer,3)
```
`split_tensor_along_last_dim` splits the tensor into three parts along the last dimension.

#### 4. Shape transformation

In order to adapt to the subsequent matrix multiplication operation, the query, key and value need to be reshaped. Convert them from \([b, sq, np, hn]\) to \([b, np, sq, hn]\), where \(np\) is the number of attention heads and \(hn\) is the dimension of each attention head.
```python
query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]]
```

#### 5. Apply rotational position encoding

If there is a rotational position embedding, it is applied to the query and key vectors. Rotational position encoding can help the model better capture position information.
```python
if rotary_pos_emb is not None:
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
```

#### 6. Adjust keys and values ​​for inference

During the inference phase, you may need to cache keys and values ​​so that they can be reused in the next time step to improve efficiency. Concatenate the cached keys and values ​​with the keys and values ​​of the current time step.
```python
if kv_cache is not None:
cache_k, cache_v = kv_cache
key_layer = torch.cat((cache_k, key_layer), dim=2)
value_layer = torch.cat((cache_v, value_layer), dim=2)
if use_cache:
if kv_cache is None:
kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), dim=1)
else:
kv_cache = (key_layer, value_layer)
else:
kv_cache = None
```

#### 7. Core Attention Calculation

Use the CoreAttention class to calculate the attention score, attention weight, and context vector.
```python
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)
```

#### 8. Output Layer

Finally, the context vector is projected back to the original hidden layer dimension through a linear layer. This step is to map the calculated result back to the shape of the input for subsequent processing.
```python
output = self.dense(context_layer)
```

### Summary

Through the above steps, the self-attention mechanism realizes the process of extracting and processing information from the input hidden state. The detailed explanation of each step is as follows:

1. **Input hidden state**: hidden representation of the input sequence.
2. **Linear transformation**: Project the hidden representation into the space of query, key and value.
3. **Split query, key and value**: Split the projected result into query, key and value vectors.
4. **Shape Transformation**: Adjust the shape of query, key and value to adapt to subsequent operations.
5. **Apply Rotation Position Encoding**: Enhance the position information of query and key vectors.
6. **Adjust keys and values ​​for inference**: Cache keys and values ​​during the inference phase to improve efficiency.
7. **Core Attention Computation**: Complete the core part of the attention mechanism by calculating attention scores, weights and context vectors.
8. **Output Layer**: Project the context vector back to the original hidden layer dimension.

Together, these steps form the self-attention mechanism, which enables the model to dynamically focus on different parts of the input sequence, capture global information and generate richer representations.

### ChatGLMPreTrainedModel class

In [70]:
class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    处理权重初始化和下载及加载预训练模型的简单接口的抽象类。
    """
    is_parallelizable = False  # 是否可并行化
    supports_gradient_checkpointing = True  # 是否支持梯度检查点
    config_class = ChatGLMConfig  # 配置类
    base_model_prefix = "transformer"  # 基础模型前缀
    _no_split_modules = ["GLMBlock"]  # 不拆分的模块列表

    def _init_weights(self, module: nn.Module):
"""Initialize weights"""
        return

    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        """
        获取注意力掩码
        """
        batch_size, seq_length = input_ids.shape
# Create a lower triangular matrix as the full attention mask
        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
        full_attention_mask.tril_()
        past_length = 0
        if past_key_values:
            past_length = past_key_values[0][0].shape[2]
        if past_length:
            full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, device=input_ids.device), full_attention_mask), dim=-1)
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        if not past_length and padding_mask is not None:
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1
        full_attention_mask = (full_attention_mask < 0.5).bool()
        full_attention_mask.unsqueeze_(1)
        return full_attention_mask

    def get_position_ids(self, input_ids, device):
        """
        获取位置ID
        """
        batch_size, seq_length = input_ids.shape
        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
        return position_ids

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

### Embedding Class

In [71]:
class Embedding(torch.nn.Module):
"""Language model embedding layer."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(Embedding, self).__init__()

        self.hidden_size = config.hidden_size
# Word Embedding Layer (Parallel)
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.torch_dtype,
            device=device
        )
        self.fp32_residual_connection = config.fp32_residual_connection

    def forward(self, input_ids):
# Get word embeddings
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
# If fp32 residual connection is set, convert to floating point
        if self.fp32_residual_connection:
            embeddings = embeddings.float()
        return embeddings

### ChatGLMModel class

In [72]:
class ChatGLMModel(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        super().__init__(config)
        if empty_init:
            init_method = skip_init  # 跳过初始化
        else:
            init_method = default_init  # 使用默认初始化方法
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        self.embedding = init_method(Embedding, config, **init_kwargs)  # 使用 Embedding 类
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels

# Rotation position embedding
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        )

        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, original_impl=config.original_rope, 
                                              device=device, dtype=config.torch_dtype)  # 使用 RotaryEmbedding 类
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)  # 使用 GLMTransformer 类
        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
                                        dtype=config.torch_dtype, **init_kwargs)  # 使用 nn.Linear 类

    def get_input_embeddings(self):
        return self.embedding.word_embeddings

    def set_input_embeddings(self, value):
        self.embedding.word_embeddings = value

    def forward(
            self,
            input_ids,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.BoolTensor] = None,
            full_attention_mask: Optional[torch.BoolTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length = input_ids.shape

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)  # 使用 Embedding 类

        if full_attention_mask is None:
            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

# Rotation position embedding
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)  # 使用 RotaryEmbedding 类
        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]

# Run the encoder
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
        )  # 使用 GLMTransformer 类
        if presents is not None and type(presents) is torch.Tensor:
            presents = presents.split(1, dim=0)
            presents = list(presents)
            presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents]
            presents = [tuple([x.squeeze(0) for x in y]) for y in presents]
            presents = tuple(presents)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

### Summary

1. **ChatGLMPreTrainedModel**:
- **get_masks**: Generates attention masks.
- **get_position_ids**: Generates position IDs.

2. **Embedding**:
- **word_embeddings**: Implements word embeddings.

3. **ChatGLMModel**:
- **RotaryEmbedding**: For position encoding.
- **GLMTransformer**: Implements Transformer encoder.
- **forward**: Performs forward propagation, integrating all modules.

4. **Previously built modules used**:
- **RotaryEmbedding**: For generating rotation position embeddings.
- **GLMTransformer**: For the encoder part.
- **Embedding**: For generating word embeddings.
- **CoreAttention**, **SelfAttention**: Indirectly used through GLMTransformer.

### ChatGLMForConditionalGeneration Class

In [73]:
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.max_sequence_length = config.max_length  # 最大序列长度
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)  # 使用 ChatGLMModel 类
        self.config = config

    def _update_model_kwargs_for_generation(
            self,
            outputs: ModelOutput,
            model_kwargs: Dict[str, Any],
            is_encoder_decoder: bool = False,
            standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
# Update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )

# Update the attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

# Update location ids
        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            new_position_id = position_ids[..., -1:].clone()
            new_position_id += 1
            model_kwargs["position_ids"] = torch.cat(
                [position_ids, new_position_id], dim=-1
            )

        model_kwargs["is_first_forward"] = False
        return model_kwargs

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past_key_values: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            is_first_forward: bool = True,
            **kwargs
    ) -> dict:
# If past_key_values ​​is not empty, only take the last token of input_ids
        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device)
        if not is_first_forward:
            if past_key_values is not None:
                position_ids = position_ids[..., -1:]
                input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "return_last_logit": True,
            "use_cache": use_cache
        }

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = False,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )  # 使用 ChatGLMModel 类

        hidden_states = transformer_outputs[0]
        if return_last_logit:
            hidden_states = hidden_states[:, -1:]
        lm_logits = self.transformer.output_layer(hidden_states)

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)

# Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
            past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        重新排序 `past_key_values` 缓存以匹配每个生成步骤中的 `beam_idx`。
        """
        return tuple(
            (
                layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)),
                layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)),
            )
            for layer_past in past
        )

    def process_response(self, output, history):
        content = ""
        history = deepcopy(history)
        for response in output.split(""):
            if "\n" in response:
                metadata, content = response.split("\n", maxsplit=1)
            else:
                metadata, content = "", response
            if not metadata.strip():
                content = content.strip()
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                content = content.replace("[[训练时间]]", "2023年")
            else:
                history.append({"role": "assistant", "metadata": metadata, "content": content})
                if history[0]["role"] == "system" and "tools" in history[0]:
                    parameters = json.loads(content)
                    content = {"name": metadata.strip(), "parameters": parameters}
                else:
                    content = {"name": metadata.strip(), "content": content}
        return content, history

    @torch.inference_mode()
    def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
             **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        history.append({"role": role, "content": query})
        inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=True,
                                               return_tensors="pt", return_dict=True)
        inputs = inputs.to(self.device)
        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(""),
                        tokenizer.convert_tokens_to_ids("")]
        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
        response = tokenizer.decode(outputs)
        response, history = self.process_response(response, history)
        return response, history

    @torch.inference_mode()
    def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
                    past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
                    logits_processor=None, return_past_key_values=False, **kwargs):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(""),
                        tokenizer.convert_tokens_to_ids("")]
        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        if past_key_values is None:
            inputs = tokenizer.apply_chat_template(history + [{"role": role, "content": query}],
                                                   add_generation_prompt=True, tokenize=True, return_tensors="pt",
                                                   return_dict=True)
        else:
            inputs = tokenizer.apply_chat_template([{"role": role, "content": query}], add_special_tokens=False,
                                                   add_generation_prompt=True, tokenize=True, return_tensors="pt",
                                                   return_dict=True)
        inputs = inputs.to(self.device)
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
            inputs.position_ids += past_length


            attention_mask = inputs.attention_mask
            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
            inputs['attention_mask'] = attention_mask
        history.append({"role": role, "content": query})
        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                            eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
                                            **gen_kwargs):
            if return_past_key_values:
                outputs, past_key_values = outputs
            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
            response = tokenizer.decode(outputs)
            if response and response[-1] != "�":
                response, new_history = self.process_response(response, history)
                if return_past_key_values:
                    yield response, new_history, past_key_values
                else:
                    yield response, new_history

    @torch.inference_mode()
    def stream_generate(
            self,
            input_ids,
            generation_config: Optional[GenerationConfig] = None,
            logits_processor: Optional[LogitsProcessorList] = None,
            stopping_criteria: Optional[StoppingCriteriaList] = None,
            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
            return_past_key_values=False,
            **kwargs,
    ):
        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        model_kwargs["use_cache"] = generation_config.use_cache
        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                logger.warn(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://hf-mirror.com/docs/transformers/main/en/main_classes/text_generation)",
                    UserWarning,
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

# 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# Forward pass to get the next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

# Preprocessing distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

# Sampling
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)
# Update generated ids, model inputs and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )
            if return_past_key_values:
                yield input_ids, outputs.past_key_values
            else:
                yield input_ids
# Stop when each sentence is complete or exceeds the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                break

### ChatGLMForSequenceClassification class

In [74]:
class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.num_labels = config.num_labels  # 标签数量
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)  # 使用 ChatGLMModel 类

        self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half)
        if config.classifier_dropout is not None:
            self.dropout = nn.Dropout(config.classifier_dropout)
        else:
            self.dropout = None
        self.config = config

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            full_attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            full_attention_mask=full_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )  # 使用 ChatGLMModel 类

        hidden_states = transformer_outputs[0]
        pooled_hidden_states = hidden_states[:, -1]
        if self.dropout is not None:
            pooled_hidden_states = self.dropout(pooled_hidden_states)
        logits = self.classifier_head(pooled_hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze().float(), labels.squeeze())
                else:
                    loss = loss_fct(logits.float(), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss

=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


### Previously built modules used

In this code, multiple modules and methods are based on previously built classes and functions:

1. **ChatGLMModel**:
- **Used in ChatGLMForConditionalGeneration and ChatGLMForSequenceClassification**, as the core part of the Transformer model.

2. **ChatGLMPreTrainedModel**:
- **As the base class of ChatGLMForConditionalGeneration and ChatGLMForSequenceClassification**, it provides interfaces for weight initialization and loading pre-trained models.

3. **RotaryEmbedding**:
- **Used for position encoding in ChatGLMModel**.

4. **CoreAttention and SelfAttention**:
- **Used in GLMTransformer**, implementing the core part of the attention mechanism.

The detailed comments and explanations provide a better understanding of the construction and implementation principles of the code. Together, these modules form the entire ChatGLM model.The whole architecture realizes the functions of conditional generation and sequence classification.

### Meaning of `past_key_values` variable

In Transformer models, especially models used for generation tasks, such as GPT-like models, `past_key_values` is a very important variable. It is used to cache the key and value vectors calculated by the model in the previous time step. These cached data can be reused in subsequent time steps, thereby improving computational efficiency, especially in long sequence generation tasks.

### The role of `past_key_values`

1. **Cache previous calculation results**:
In the process of generating text, a new word is generated at each step. At this time, the query vector (query) of the current time step needs to be calculated with the key and value vectors of all previous time steps. If all keys and values ​​are recalculated every time, it will be very inefficient. `past_key_values` caches the results of these previous time steps to avoid repeated calculations.

2. **Speed ​​up the generation process**:
In long sequence generation, by caching the key and value vectors of previous time steps, only the current time step needs to be calculated and combined with the cached results, which greatly speeds up the generation process.

### Specific implementation of `past_key_values`

#### In `ChatGLMForConditionalGeneration` Usage in class

```python
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
is_first_forward: bool = True,
**kwargs
) -> dict:
# If past_key_values ​​is not empty, only take the last token of input_ids
if position_ids is None:
position_ids = self.get_position_ids(input_ids, device=input_ids.device)
if not is_first_forward:
if past_key_values ​​is not None:
position_ids = position_ids[..., -1:]
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"position_ids": position_ids,
"attention_mask": attention_mask,
"return_last_logit": True,
"use_cache": use_cache
}
```

In the `prepare_inputs_for_generation` method, if `past_key_values` is not empty, only the last token of `input_ids` is taken. The purpose of this is to calculate only the data of the current time step when generating a new token, without recalculating the entire sequence.

#### In the `forward` method of the `ChatGLMForConditionalGeneration` class

```python
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
input_ids=input_ids,position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
) # Using ChatGLMModel class

hidden_states = transformer_outputs[0]
if return_last_logit:
hidden_states = hidden_states[:, -1:]
lm_logits = self.transformer.output_layer(hidden_states)

loss = None
if labels is not None:
lm_logits = lm_logits.to(torch.float32)

# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

lm_logits = lm_logits.to(hidden_states.dtype)
loss = loss.to(hidden_states.dtype)

ifnot return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
```

In the `forward` method, `past_key_values` is passed as a parameter to the `transformer` model. The `transformer` model uses these cached key and value vectors internally to speed up calculations.
`p`ast_key_values` is a caching mechanism used to speed up the Transformer model in generation tasks. It saves the key and value vectors calculated in the previous time step, avoiding repeated calculation of these vectors in each time step, thereby improving the efficiency of the generation process. By using `past_key_values`, the model can generate long sequence data faster, which is very important in practical applications.

### Function analysis and its differences and connections

In the `ChatGLMForConditionalGeneration` class, there are three important functions related to generation: `_update_model_kwargs_for_generation`, `prepare_inputs_for_generation` and `forward`. Their functions, differences and connections are as follows:

#### 1. `_update_model_kwargs_for_generation` function

```python
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
) -> Dict[str, Any]:
# Update past_key_values
model_kwargsrgs["past_key_values"] = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
# Update attention mask
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
)
# Update position ids
if "position_ids" in model_kwargs:
position_ids = model_kwargs["position_ids"]on_ids"]
new_position_id = position_ids[..., -1:].clone()
new_position_id += 1
model_kwargs["position_ids"] = torch.cat(
[position_ids, new_position_id], dim=-1
)

model_kwargs["is_first_forward"] = False
return model_kwargs
```

- **Function**: Updates the model parameters required for the generation process. Specifically:
- Update `past_key_values` to cache previously calculated key and value vectors.
- Update `attention_mask` to include the newly generated token.
- Update `position_ids` to add the new position ID.
- **Difference**: This function does not directly perform forward propagation, but updates the model parameters to prepare for the next step of generation.
- **Relationship**: This function is called after each step of generating a new token, usingTo update the model parameters and prepare for the next generation.

#### 2. `prepare_inputs_for_generation` function

```python
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
is_first_forward: bool = True,
**kwargs
) -> dict:
# If past_key_values ​​is not empty, only take the last token of input_ids
if position_ids is None:
position_ids = self.get_position_ids(input_ids, device=input_ids.device)
if not is_first_forward:
if past_key_values ​​is not None:
position_ids = position_ids[..., -1:]
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"position_ids": position_ids,
"attention_mask": attention_mask,
"return_last_logit": True,
"use_cache": use_cache
}```

- **Function**: Prepare the input required for the generation process. Specifically:
- Get or update `position_ids`.
- If it is not the first forward propagation and `past_key_values` exists, only the last token of `input_ids` and `position_ids` is taken.
- **Difference**: This function is mainly used to process input data to ensure that the shape and content of the input data are suitable for the current generation step.
- **Relationship**: In each step of the generation process, this function will be called to prepare the input data, especially to process `past_key_values` to improve the generation efficiency.

#### 3. `forward` function

```python
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
) # Use ChatGLMModel class

hidden_states = transformer_outputs[0]
if return_last_logit:
hidden_states = hidden_states[:, -1:]
lm_logits = self.transformer.output_layer(hidden_states)

loss = None
if labels is not None:
lm_logits = lm_logits.to(torch.float32)

# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shiftt_labels.view(-1))

lm_logits = lm_logits.to(hidden_states.dtype)

loss = loss.to(hidden_states.dtype)

if not return_dict:

output = (lm_logits,) + transformer_outputs[1:]

return ((loss,) + output) if loss is not None else output

return CausalLMOutputWithPast(

loss=loss,

logits=lm_logits,

past_key_values=transformer_outputs.past_key_values,

hidden_states=transformer_outputs.hidden_states,

attentions=transformer_outputs.attentions,
)
```

- **Function**: Perform forward propagation and generate the output of the model. Specifically:
- Pass the input data to `transformer` (`ChatGLMModel`) for forward calculation.
- Calculate the logits of the language model and (if there is a label) calculate the loss.
- Return the model output, including logits, `past_key_values`, hidden state, and attention weights.
- **Difference**: This is the core forward propagation logic of the model, which directly processes the input data and generates output.
- **Relationship**: The `forward` function uses the `ChatGLMModel` class to perform the actual forward propagation and calls the previous `prepare_inputs_for_generation` to prepare the input.

### Relationship and process

1. **Prepare input data**:
- The `prepare_inputs_for_generation` function is used to process the input data, especially the `past_key_values` so that only the necessary last token is passed.

2. **Perform forward propagation**:
- The `forward` function uses the prepared input data to perform forward propagation.Generate output.

3. **Update model parameters**:
- The `_update_model_kwargs_for_generation` function updates the key parameters of the model (such as `past_key_values`, `attention_mask`, and `position_ids`) after each generation step to ensure that the latest data is used in the next generation.

Through the close cooperation of these functions, the forward propagation and cache management in the generation task can be efficiently implemented, thereby improving the generation efficiency and effect of the model.

These three functions are the core functions in the `ChatGLMForConditionalGeneration` class, which are used to handle different generation task requirements. There is a certain connection between them, and they also have their own uses and characteristics. The following is a detailed explanation of them and their differences and connections:

### 1. `chat` function

```python
@torch.inference_mode()
def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
**kwargs):
if history is None:
history = []
if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
history.append({"role": role, "content": query})
inputs = tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=True,
return_tensors="pt", return_dict=True)inputs = inputs.to(self.device)
eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(""), tokenizer.convert_tokens_to_ids("")]
outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
response = tokenizer.decode(outputs)
response, history = self.process_response(response, history)
return response, history
```

- **Purpose**: Execute a complete chat session. Encode the user's query and history into model inputs, generate responses and update history.
- **Difference**:This is a high-level interface for generating a complete response in one go. It is suitable for scenarios where a complete answer is needed immediately.
- **Contact**: It relies on the `generate` function to actually generate the response, and calls the `process_response` function to process the generated output.

### 2. `stream_chat` function

```python
@torch.inference_mode()
def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
past_key_values=None, max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
logits_processor=None, return_past_key_values=False, **kwargs):
if history is None:
history = []if logits_processor is None:
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(""), tokenizer.convert_tokens_to_ids("")]
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
if past_key_values ​​is None:
inputs = tokenizer.apply_chat_template(history + [{"role": role, "content": query}],
add_generation_prompt=True, tokenize=True, return_tensors="pt",
return_dict=True)
else:
inputs = tokenizer.apply_chat_template([{"role": role, "content": query}], add_special_tokens=False,
add_generation_prompt=True, tokenize=True, return_tensors="pt",return_dict=True)
inputs = inputs.to(self.device)
if past_key_values ​​is not None:
past_length = past_key_values[0][0].shape[2]
inputs.position_ids += past_length
attention_mask = inputs.attention_mask
attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
inputs['attention_mask'] = attention_mask
history.append({"role": role, "content": query})
for outputs in self.stream_generate(**inputs,past_key_values=past_key_values,
eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
**gen_kwargs):
if return_past_key_values:
outputs, past_key_values ​​= outputs
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
response = tokenizer.decode(outputs)
if response and response[-1] != "�":
response, new_history = self.process_response(response, history)
if return_past_key_values:
yield response, new_history, past_key_values
else:
yield response, new_history
```

- **Purpose**: Implement a streaming chat session. Similar to the `chat` function, it returns responses gradually through a generator, which is suitable for streaming generation application scenarios.
- **Difference**: Supports gradual generation of responses, so that partial responses can be dynamically processed and displayed during the generation process.
- **Relationship**: Depends on the `stream_generate` function to gradually generate responses, and calls the `process_response` function to process and update history after each new response fragment is generated.

### 3. `stream_generate` function

```python
@torch.inference_mode()
def stream_generate(
self,
input_ids,generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
return_past_key_values=False,
**kwargs,
):
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

if generation_config is None:
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
model_kwargs["use_cache"] = generation_config.use_cache
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id

if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
" recommend using `max_new_tokens` to control the maximum length of the generation.",
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
if not has_default_max_length:
logger.warn(
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
"Please refer to the documentation forr more information. "
"(https://hf-mirror.com/docs/transformers/main/en/main_classes/text_generation)",
UserWarning,
)

if input_ids_seq_length >= generation_config.max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
"increasing `max_new_tokens`."
)

# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids = input_ids,
prefix_allowed_tokens_fn = prefix_allowed_tokens_fn,
logits_processor = logits_processor,
)
stopping_criteria = self._get_stopping_criteria(
generation_config = generation_config, stopping_criteria = stopping_criteria
)
logits_warper = self._get_logits_warper(generation_config)

unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = None
while True:
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# forward pass to get the next token
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)

next_token_logits = outputs.logits[:, -1, :]

# preprocess distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)

# sampling
probs= nn.functional.softmax(next_token_scores, dim=-1)
if generation_config.do_sample:
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(probs, dim=-1)

# Update generated ids, model inputs, and length of next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder)
unfinished_sequences = unfinished_sequences.mul(
next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
if return_past_key_values:
yield input_ids, outputs.past_key_values
else:
yield input_ids
# Stop when each sentence is completed or exceeds the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
break
```

- **Purpose**: Generate model output in a streaming manner, and gradually return the generated tokens for real-time processing.
- **Difference**: This functionThe generator is used to generate data in a streaming manner, and the result is returned once each token is generated. It is suitable for application scenarios that need to display the generated results step by step.
- **Relationship**: The `stream_chat` function relies on `stream_generate` to generate responses step by step, and updates the input and model parameters after each new token is generated.

### Summary of differences and connections

1. **Relationship**:
- `chat` function: used to generate a complete response at one time, suitable for application scenarios that need to get a complete answer immediately.
- `stream_chat` function: used to generate responses in a streaming manner, suitable for application scenarios that gradually display the generated results.
- `stream_generate` function: implements the core logic of streaming generation, and gradually returns the generated tokens through the generator.

2. **Relationship**:
- `chat` and `stream_chat` are both high-level interfaces through which users interact with the model.
- The `chat` function calls the `generate` function to generate data in one go, while the `stream_chat` function calls the `stream_generate` function to generate data in a streaming manner.
- `stream_generate` function uses `prepare_inputs_for_generation` to process the input data, and update the model parameters through `_update_model_kwargs_for_generation` to ensure that the latest data is used during the generation process.

Through the coordinated work of these three functions, the model can achieve efficient and flexible generation tasks to meet the needs of different application scenarios.

In [75]:
device = "cuda"

tokenizer = ChatGLM4Tokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)

query = "你好"

inputs = tokenizer.apply_chat_template([{"role": "user", "content": query}],
                                       add_generation_prompt=True,
                                       tokenize=True,
                                       return_tensors="pt",
                                       return_dict=True
                                       )

inputs = inputs.to(device)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [76]:
model = ChatGLMForConditionalGeneration.from_pretrained(
    "THUDM/glm-4-9b-chat",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    load_in_4bit=True,
    device_map='auto'
).eval()#.to(device)

gen_kwargs = {"max_length": 300, "do_sample": True, "top_k": 1}
with torch.no_grad():
    outputs = model.generate(**inputs, **gen_kwargs)
    outputs = outputs[:, inputs['input_ids'].shape[1]:]
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/10 [00:00<?, ?it/s]


你好👋！有什么可以帮助你的吗？


Time taken: 16 seconds

### Note: Other test drafts

In [None]:
def apply_chat_template(
        self,
        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"],
        add_generation_prompt: bool = False,
        tokenize: bool = True,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        add_special_tokens: bool = True,
        **kwargs,
) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:

    if return_dict and not tokenize:
        raise ValueError(
            "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
            "of tokenizer outputs to return."
        )

    def handle_single_conversation(conversation):
        input_ids = self.get_prefix_tokens() if add_special_tokens else []
        input_message = "[gMASK]<sop>" if add_special_tokens else ""
        for item in conversation:
            if item.get("tools"):
                tools = item["tools"]
                content = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的，你的任务是针对用户的问题和要求提供适当的答复和支持。"
                for tool in tools:
                    if tool["type"] == "function":
                        function = tool["function"]
                        content += f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}"
                        content += "\n在调用上述函数时，请使用 Json 格式表示调用的参数。"
                    elif tool["type"] == "python":
                        content += "\n\n## python\n\n当你向 `python` 发送包含 Python 代码的消息时，该代码将会在一个有状态的 Jupyter notebook 环境中执行。\n`python` 返回代码执行的输出，或在执行 60 秒后返回超时。\n`/mnt/data` 将会持久化存储你的文件。在此会话中，`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用，这些在线内容的访问将不会成功。"
                    elif tool["type"] == "simple_browser":
                        content += "\n\n## simple_browser\n\n你可以使用 `simple_browser` 工具。该工具支持以下函数：\n`search(query: str, recency_days: int)`：使用搜索引擎进行查询并显示结果，可以使用 `recency_days` 参数控制搜索内容的时效性。\n`mclick(ids: list[int])`：获取一系列指定 id 的页面内容。每次调用时，须选择3-10个页面。选择多个角度的页面，同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的，你也可以多打开一些可能有用的页面而不用担心内容过多。\n`open_url(url: str)`：打开指定的 URL。\n\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\n\n操作步骤：1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\n 如果用户提供了 URL，也可以用 `open_url` 直接打开页面。\n如果初次搜索结果没有找到合适的信息，也可以再次使用 `search` 进行搜索。"
                    elif tool["type"] == "cogview":
                        content += "\n\n## cogview\n\n如果用户的请求中包含了对图像的描述，你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述，规则：\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\n- 应当尽可能详细地描述图像生成的需求，需求描述约 100 英文单词。\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\n- 如无特殊说明，所在地为中国，持有中国立场并遵循中国社会主义价值观。"
                    else:
                        raise NotImplementedError(f"Unknown tool type {tool['type']}")
                input = self.build_single_message("system", "", content, tokenize=tokenize)
                if tokenize:
                    input_ids.extend(input)
                else:
                    input_message += input
            if item["content"]:
                input = self.build_single_message(
                    item["role"],
                    item.get("metadata", ""),
                    item["content"],
                    tokenize=tokenize
                )
                if tokenize:
                    input_ids.extend(input)
                else:
                    input_message += input
        if add_generation_prompt:
            if tokenize:
                input_ids.extend([self.convert_tokens_to_ids("<|assistant|>")])
            else:
                input_message += "<|assistant|>"

        return input_ids if tokenize else input_message

# Main logic to handle different conversation formats
    if isinstance(conversation, list) and all(isinstance(i, dict) for i in conversation):
        result = handle_single_conversation(conversation)
    elif isinstance(conversation, list) and all(isinstance(i, list) for i in conversation):
        result = [handle_single_conversation(c) for c in conversation]
    elif hasattr(conversation, "messages"):
        result = handle_single_conversation(conversation.messages)
    else:
        raise ValueError("Invalid conversation format")

    if tokenize:
        output = self.batch_encode_plus(
            [result] if isinstance(result[0], int) else result,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            is_split_into_words=True,
            add_special_tokens=False
        )
        if return_dict:
            return output
        else:
            return output["input_ids"]
    else:
        return result

In [None]:
def handle_single_conversation(conversation):
    input_ids = self.get_prefix_tokens() if add_special_tokens else []
    input_message = "[gMASK]<sop>" if add_special_tokens else ""
    for item in conversation:
        if item.get("tools"):
            tools = item["tools"]
            content = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的，你的任务是针对用户的问题和要求提供适当的答复和支持。"
            for tool in tools:
                if tool["type"] == "function":
                    function = tool["function"]
                    content += f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}"
                    content += "\n在调用上述函数时，请使用 Json 格式表示调用的参数。"
                elif tool["type"] == "python":
                    content += "\n\n## python\n\n当你向 `python` 发送包含 Python 代码的消息时，该代码将会在一个有状态的 Jupyter notebook 环境中执行。\n`python` 返回代码执行的输出，或在执行 60 秒后返回超时。\n`/mnt/data` 将会持久化存储你的文件。在此会话中，`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用，这些在线内容的访问将不会成功。"
                elif tool["type"] == "simple_browser":
                    content += "\n\n## simple_browser\n\n你可以使用 `simple_browser` 工具。该工具支持以下函数：\n`search(query: str, recency_days: int)`：使用搜索引擎进行查询并显示结果，可以使用 `recency_days` 参数控制搜索内容的时效性。\n`mclick(ids: list[int])`：获取一系列指定 id 的页面内容。每次调用时，须选择3-10个页面。选择多个角度的页面，同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的，你也可以多打开一些可能有用的页面而不用担心内容过多。\n`open_url(url: str)`：打开指定的 URL。\n\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\n\n操作步骤：1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\n 如果用户提供了 URL，也可以用 `open_url` 直接打开页面。\n如果初次搜索结果没有找到合适的信息，也可以再次使用 `search` 进行搜索。"
                elif tool["type"] == "cogview":
                    content += "\n\n## cogview\n\n如果用户的请求中包含了对图像的描述，你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述，规则：\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\n- 应当尽可能详细地描述图像生成的需求，需求描述约 100 英文单词。\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\n- 如无特殊说明，所在地为中国，持有中国立场并遵循中国社会主义价值观。"
                else:
                    raise NotImplementedError(f"Unknown tool type {tool['type']}")
            input = self.build_single_message("system", "", content, tokenize=tokenize)
            if tokenize:
                input_ids.extend(input)
            else:
                input_message += input
        if item["content"]:
            input = self.build_single_message(
                item["role"],
                item.get("metadata", ""),
                item["content"],
                tokenize=tokenize
            )
            if tokenize:
                input_ids.extend(input)
            else:
                input_message += input
    if add_generation_prompt:
        if tokenize:
            input_ids.extend([self.convert_tokens_to_ids("<|assistant|>")])
        else:
            input_message += "<|assistant|>"

    return input_ids if tokenize else input_message

In [None]:
convert_tokens_to_ids("<|assistant|>")

In [58]:
def handle_single_conversation(conversation):
    input_ids = self.get_prefix_tokens() if add_special_tokens else []
    input_message = "[gMASK]<sop>" if add_special_tokens else ""
    for item in conversation:
        if item.get("tools"):
            tools = item["tools"]
            content = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的，你的任务是针对用户的问题和要求提供适当的答复和支持。"
            for tool in tools:
                if tool["type"] == "function":
                    function = tool["function"]
                    content += f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}"
                    content += "\n在调用上述函数时，请使用 Json 格式表示调用的参数。"
                elif tool["type"] == "python":
                    content += "\n\n## python\n\n当你向 `python` 发送包含 Python 代码的消息时，该代码将会在一个有状态的 Jupyter notebook 环境中执行。\n`python` 返回代码执行的输出，或在执行 60 秒后返回超时。\n`/mnt/data` 将会持久化存储你的文件。在此会话中，`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用，这些在线内容的访问将不会成功。"
                elif tool["type"] == "simple_browser":
                    content += "\n\n## simple_browser\n\n你可以使用 `simple_browser` 工具。该工具支持以下函数：\n`search(query: str, recency_days: int)`：使用搜索引擎进行查询并显示结果，可以使用 `recency_days` 参数控制搜索内容的时效性。\n`mclick(ids: list[int])`：获取一系列指定 id 的页面内容。每次调用时，须选择3-10个页面。选择多个角度的页面，同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的，你也可以多打开一些可能有用的页面而不用担心内容过多。\n`open_url(url: str)`：打开指定的 URL。\n\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\n\n操作步骤：1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\n 如果用户提供了 URL，也可以用 `open_url` 直接打开页面。\n如果初次搜索结果没有找到合适的信息，也可以再次使用 `search` 进行搜索。"
                elif tool["type"] == "cogview":
                    content += "\n\n## cogview\n\n如果用户的请求中包含了对图像的描述，你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述，规则：\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\n- 应当尽可能详细地描述图像生成的需求，需求描述约 100 英文单词。\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\n- 如无特殊说明，所在地为中国，持有中国立场并遵循中国社会主义价值观。"
                else:
                    raise NotImplementedError(f"Unknown tool type {tool['type']}")
            input = self.build_single_message("system", "", content, tokenize=tokenize)
            if tokenize:
                input_ids.extend(input)
            else:
                input_message += input
        if item["content"]:
            input = self.build_single_message(
                item["role"],
                item.get("metadata", ""),
                item["content"],
                tokenize=tokenize
            )
            if tokenize:
                input_ids.extend(input)
            else:
                input_message += input
    if add_generation_prompt:
        if tokenize:
            input_ids.extend([self.convert_tokens_to_ids("<|assistant|>")])
        else:
            input_message += "<|assistant|>"
# if tokenize:
# input_ids.extend([self.convert_tokens_to_ids("[gMASK]")]) # Use special tokens instead of empty strings
# else:
# input_message += "[gMASK]"

    return input_ids if tokenize else input_message

In [60]:
def apply_chat_template(
        self,
        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"],
        add_generation_prompt: bool = False,
        tokenize: bool = True,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        add_special_tokens: bool = True,
        **kwargs,
) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:

    if return_dict and not tokenize:
        raise ValueError(
            "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
            "of tokenizer outputs to return."
        )

    def handle_single_conversation(conversation):
        input_ids = self.get_prefix_tokens() if add_special_tokens else []
        input_message = "[gMASK]<sop>" if add_special_tokens else ""
        for item in conversation:
            if item.get("tools"):
                tools = item["tools"]
                content = "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的，你的任务是针对用户的问题和要求提供适当的答复和支持。"
                for tool in tools:
                    if tool["type"] == "function":
                        function = tool["function"]
                        content += f"\n\n## {function['name']}\n\n{json.dumps(function, ensure_ascii=False, indent=4)}"
                        content += "\n在调用上述函数时，请使用 Json 格式表示调用的参数。"
                    elif tool["type"] == "python":
                        content += "\n\n## python\n\n当你向 `python` 发送包含 Python 代码的消息时，该代码将会在一个有状态的 Jupyter notebook 环境中执行。\n`python` 返回代码执行的输出，或在执行 60 秒后返回超时。\n`/mnt/data` 将会持久化存储你的文件。在此会话中，`python` 无法访问互联网。不要使用 `python` 进行任何网络请求或者在线 API 调用，这些在线内容的访问将不会成功。"
                    elif tool["type"] == "simple_browser":
                        content += "\n\n## simple_browser\n\n你可以使用 `simple_browser` 工具。该工具支持以下函数：\n`search(query: str, recency_days: int)`：使用搜索引擎进行查询并显示结果，可以使用 `recency_days` 参数控制搜索内容的时效性。\n`mclick(ids: list[int])`：获取一系列指定 id 的页面内容。每次调用时，须选择3-10个页面。选择多个角度的页面，同时尽可能选择可信任的信息来源。考虑到部分页面是无法加载的，你也可以多打开一些可能有用的页面而不用担心内容过多。\n`open_url(url: str)`：打开指定的 URL。\n\n使用 `【{引用 id}†{引用文本}】` 来引用内容。\n\n操作步骤：1. 使用 `search` 来获得信息列表; 2. 使用 `mclick` 来获取指定 ID 页面的内容; 3. 根据获得的内容进行回复。在回复中应当引用信息来源。\n 如果用户提供了 URL，也可以用 `open_url` 直接打开页面。\n如果初次搜索结果没有找到合适的信息，也可以再次使用 `search` 进行搜索。"
                    elif tool["type"] == "cogview":
                        content += "\n\n## cogview\n\n如果用户的请求中包含了对图像的描述，你可以使用 `cogview` 来生成图像并展示给用户。你需要向 `cogview` 发送图像描述，规则：\n- 发送给 `cogview` 的消息必须使用英语。用户的中文描述必须完全翻译为英语。\n- 应当尽可能详细地描述图像生成的需求，需求描述约 100 英文单词。\n- 保持用户原始描述的意图。不要虚构内容或者没见过的人物。\n- 如无特殊说明，所在地为中国，持有中国立场并遵循中国社会主义价值观。"
                    else:
                        raise NotImplementedError(f"Unknown tool type {tool['type']}")
                input = self.build_single_message("system", "", content, tokenize=tokenize)
                if tokenize:
                    input_ids.extend(input)
                else:
                    input_message += input
            if item["content"]:
                input = self.build_single_message(
                    item["role"],
                    item.get("metadata", ""),
                    item["content"],
                    tokenize=tokenize
                )
                if tokenize:
                    input_ids.extend(input)
                else:
                    input_message += input
        if add_generation_prompt:
            if tokenize:
                input_ids.extend([self.convert_tokens_to_ids("<|assistant|>")])
            else:
                input_message += "<|assistant|>"
# if tokenize:
# input_ids.extend([self.convert_tokens_to_ids("[gMASK]")]) # Use special tokens instead of empty strings
# else:
# input_message += "[gMASK]"

        return input_ids if tokenize else input_message

# Main logic for handling different session formats
    if isinstance(conversation, list) and all(isinstance(i, dict) for i in conversation):
        result = handle_single_conversation(conversation)
    elif isinstance(conversation, list) and all(isinstance(i, list) for i in conversation):
        result = [handle_single_conversation(c) for c in conversation]
    elif hasattr(conversation, "messages"):
        result = handle_single_conversation(conversation.messages)
    else:
        raise ValueError("Invalid conversation format")

    if tokenize:
        output = self.batch_encode_plus(
            [result] if isinstance(result[0], int) else result,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            is_split_into_words=True,
            add_special_tokens=False
        )
        if return_dict:
            return output
        else:
            return output["input_ids"]
    else:
        return result


In [24]:
%pip install regex

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Note: you may need to restart the kernel to use updated packages.


In [26]:
pat_str = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
regex.compile(pat_str)

regex.Regex("(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", flags=regex.V0)

In [25]:
import regex

pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
pattern = regex.compile(pat_str)