In [1]:
import numpy as np
import math, json, time, types, copy, sys, os
import torch
from torch.nn import functional as F
import torch.nn as nn

from transformers import PreTrainedTokenizerFast

np.set_printoptions(precision=4, suppress=True, linewidth=200)

In [2]:
RUN_DEVICE = 'cpu'
ctx_len = 768
n_layer = 24
n_embd = 1024

MODEL_NAME = '/data1/ckw/20220615-10803'  # change this to your own model path (without the .pth extension)

vocab_size = 50277
VOCAB_NAME = '20B_tokenizer.json'

print(f'\n* running on {RUN_DEVICE}')


* running on cpu


### What is RWKV?

RWKV (Receptance Weighted Key Value) is a neural network architecture that combines the strengths of RNNs (recurrent neural networks) and Transformers. It targets the memory and compute cost that standard self-attention incurs on long sequences (quadratic in sequence length), while keeping the cheap inference of an RNN. RWKV uses a linear-attention-style mechanism that can be written either in a parallel, Transformer-like form or as an RNN: the parallel form is used for training, and the recurrent form gives constant compute and memory per generated token at inference time.

ChannelMix is RWKV's feed-forward sub-layer. It combines a time-mixing step (blending each position with the previous one) with channel mixing (a gated MLP over the embedding dimension). The steps and their formulas are:

1. **Time Mixing**:
Time mixing is implemented through the `time_mix` parameter and the `time_shift` operation. The purpose of this step is to combine the input of the current time step with the input of the previous time step.

Formula:

\begin{align*}
x' = x \cdot \text{time\_mix} + \text{time\_shift}(x) \cdot (1 - \text{time\_mix})
\end{align*}

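Here `time_shift` shifts the sequence one step along the time axis (position t receives the input from position t-1, and the first position receives zeros), and `time_mix` is a trainable per-channel weight.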

2. **Key Generation**:
Use a linear layer `self.key` to transform the input `x'` into a key `k`, and then apply the ReLU activation function and the square operation.

Formula:
\begin{align*}
k = \text{ReLU}(\text{key}(x'))^2
\end{align*}

3. **Value Generation**:
Pass `k` through the linear layer `self.value` (from `4 * n_embd` back to `n_embd`) to obtain `kv`.

Formula:
\begin{align*}
kv = \text{value}(k)
\end{align*}

4. **Receptance Function**:
Apply the linear layer `self.receptance` to `x'`, followed by the sigmoid, to obtain the receptance gate `r`.

Formula:
\begin{align*}
r = \sigma(\text{receptance}(x'))
\end{align*}

5. **Final Output**:
Multiply the receptance `r` element-wise with `kv` to obtain the final output `rkv`.

Formula:
\begin{align*}
rkv = r \cdot kv
\end{align*}

Combining these steps, the entire ChannelMix calculation process can be expressed by the following formula:

\begin{align*}
x' & = x \cdot \text{time\_mix} + \text{time\_shift}(x) \cdot (1 - \text{time\_mix}) \\
k & = \text{ReLU}(\text{key}(x'))^2 \\
kv & = \text{value}(k) \\
r & = \sigma(\text{receptance}(x')) \\
\text{output} & = r \cdot kv
\end{align*}

Each line corresponds directly to a statement in the `forward` method of `RWKV_ChannelMix`.
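The token shift in step 1 can be checked on its own. The sketch below assumes toy sizes and a made-up constant `time_mix` of 0.75 (the real parameter is learned per channel); it confirms that `nn.ZeroPad2d((0, 0, 1, -1))`, the `time_shift` used in the code below, moves every time step down by one and fills the first step with zeros.

```python
import torch
import torch.nn as nn

# ZeroPad2d pads the last two dims as (left, right, top, bottom). On a (B, T, C)
# tensor, (0, 0, 1, -1) leaves C alone, adds one zero step before t = 0 and
# crops the last step, so shifted[:, t] == x[:, t - 1].
time_shift = nn.ZeroPad2d((0, 0, 1, -1))

B, T, C = 2, 4, 3  # toy sizes (assumption, for illustration only)
x = torch.arange(B * T * C, dtype=torch.float32).reshape(B, T, C)
shifted = time_shift(x)

# The same shift written by hand: prepend a zero step, drop the last one.
manual = torch.cat([torch.zeros(B, 1, C), x[:, :-1]], dim=1)

# Token-shift mixing with a hypothetical constant time_mix of 0.75.
time_mix = torch.full((1, 1, C), 0.75)
x_mixed = x * time_mix + shifted * (1 - time_mix)

print(shifted[0])  # first row is all zeros, row t equals x[0, t - 1]
```

Negative padding crops instead of padding, which is why a single module can both insert the leading zeros and keep the sequence length unchanged.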

In [3]:
class RWKV_ChannelMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.time_shift = nn.ZeroPad2d((0,0,1,-1))  # shift the time axis by one step (zeros at t = 0)
        self.time_mix = nn.Parameter(torch.ones(1, 1, n_embd))  # per-channel mixing weight

        hidden_sz = 4 * n_embd
        self.key = nn.Linear(n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, n_embd, bias=False)

    def forward(self, x):
        # x: (B, T, C). Token shift: blend each position with the previous one.
        x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)

        k = self.key(x)                  # (B, T, 4C)
        k = torch.square(torch.relu(k))  # squared ReLU
        kv = self.value(k)               # back to (B, T, C)

        rkv = torch.sigmoid(self.receptance(x)) * kv  # gate with the receptance
        return rkv
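Because the token shift only reaches one step back, ChannelMix can also be run one token at a time by caching the previous input, which is what the `FF` method of `RWKV_RNN` does later in this notebook. The following sketch checks that both forms agree, assuming toy sizes and random weights in place of trained ones; `channel_mix_parallel` and `channel_mix_step` are hypothetical helper names.

```python
import torch

torch.manual_seed(0)
T, C, H = 5, 8, 32  # toy sequence length, channels, hidden size (assumptions)

# Random stand-ins for the trained parameters (nn.Linear stores weights as (out, in)).
time_mix = torch.rand(C)
W_key = torch.randn(H, C) / C ** 0.5
W_value = torch.randn(C, H) / H ** 0.5
W_recept = torch.randn(C, C) / C ** 0.5


def channel_mix_parallel(x):
    """Whole sequence at once, like RWKV_ChannelMix.forward; x has shape (T, C)."""
    prev = torch.cat([torch.zeros(1, C), x[:-1]], dim=0)  # token shift
    xm = x * time_mix + prev * (1 - time_mix)
    k = torch.square(torch.relu(xm @ W_key.T))
    return torch.sigmoid(xm @ W_recept.T) * (k @ W_value.T)


def channel_mix_step(x_t, prev):
    """One token, like RWKV_RNN.FF; prev is the cached input of the previous step."""
    xm = x_t * time_mix + prev * (1 - time_mix)
    k = torch.square(torch.relu(W_key @ xm))
    return torch.sigmoid(W_recept @ xm) * (W_value @ k)


x = torch.randn(T, C)
out_parallel = channel_mix_parallel(x)

prev = torch.zeros(C)  # the cache starts at zero, as in FF
steps = []
for t in range(T):
    steps.append(channel_mix_step(x[t], prev))
    prev = x[t]  # update the cache
out_recurrent = torch.stack(steps)

print(torch.max(torch.abs(out_parallel - out_recurrent)))
```

The only state the recurrent form needs is the previous input vector, so memory per layer stays constant no matter how long the sequence gets.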

`RWKV_TimeMix` is RWKV's time-mixing (attention-like) sub-layer: each output is a decaying weighted average of the current and past values. The implementation and the corresponding formulas are:

1. **Time Mixing**:
Time mixing is achieved through the `time_mix` parameter and the `time_shift` operation. The purpose of this step is to combine the input of the current time step with the input of the previous time step.

Formula:
\begin{align*}
x' = x \cdot \text{time\_mix} + \text{time\_shift}(x) \cdot (1 - \text{time\_mix})
\end{align*}

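2. **Key Generation**:
Use the linear layer `self.key` to map `x'` to a key, transpose it to shape `(B, C, T)`, clamp it at 60 (so the exponential cannot overflow), and exponentiate it so that all keys are positive.

Formula:
\begin{align*}
k = \exp\left(\min\left(\text{key}(x'), 60\right)\right)^T
\end{align*}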

3. **Value Generation**:
Use the linear layer `self.value` to map `x'` to a value `v`, then transpose it to shape `(B, C, T)`.

Formula:
\begin{align*}
v = \text{value}(x')^T
\end{align*}

4. **Receptance Function**:
Use the linear layer `self.receptance` to compute the receptance `r`; the sigmoid is applied in the final step.

Formula:
\begin{align*}
r = \text{receptance}(x')
\end{align*}

5. **Key-Value Multiplication**:
Multiply the key `k` and the value `v` element-wise to get `kv`.

Formula:
\begin{align*}
kv = k \cdot v
\end{align*}

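6. **Time Weight Calculation**:
Build the per-channel time weights `w`. `time_curve` is the fixed row `[-(ctx_len-2), ..., -1, 0]`, which is scaled by `exp(time_decay)`; `time_first` is appended as the last entry, and for a sequence of length `T` only the last `T` entries are kept.

Formula:
\begin{align*}
\text{self.time\_w} &= \left[\exp(\text{time\_decay}) \cdot \text{time\_curve},\ \text{time\_first}\right] \\
w &= \exp(\text{self.time\_w})_{[:,\,-T:]}
\end{align*}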

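7. **Convolution Operation**:
Left-pad `kv` and `k` with `T-1` zeros along the time axis (`nn.ZeroPad2d((T-1, 0, 0, 0))`, written pad_{T-1} below) and apply a depthwise one-dimensional convolution (one kernel per channel, `groups=C`) with kernel `w`. This yields the weighted key-value sum `wkv` and the weighted key sum `wk`; the padding makes each output depend only on the current and earlier positions.

Formula:
\begin{align*}
wkv &= \text{conv1d}(\text{pad}_{T-1}(kv), w, \text{groups}=C) \\
wk &= \text{conv1d}(\text{pad}_{T-1}(k), w, \text{groups}=C) + 1e-9
\end{align*}

Written out per channel, this is equivalent to the following sum (and likewise for `wk`, with `k` in place of `kv`):
\begin{align*}
wkv_t = e^{\text{time\_first}}\, kv_t + \sum_{s<t} \exp\left(-(t-s-1)\, e^{\text{time\_decay}}\right) kv_s
\end{align*}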

8. **Final Output**:
Apply the sigmoid to `r`, multiply it element-wise with the ratio `wkv / wk` (transposed back to `(B, T, C)`), and pass the result through the output linear layer `self.output` to obtain `rwkv`.

Formula:
\begin{align*}
rwkv &= \sigma(r) \cdot \left( \frac{wkv}{wk} \right)^T \\
rwkv &= \text{output}(rwkv)
\end{align*}

Combining these steps, the entire calculation of `RWKV_TimeMix` can be expressed as follows, where pad_{T-1} left-pads the time axis with T-1 zeros (`nn.ZeroPad2d((T-1, 0, 0, 0))`):

\begin{align*}
x' &= x \cdot \text{time\_mix} + \text{time\_shift}(x) \cdot (1 - \text{time\_mix}) \\
k &= \exp\left(\min\left(\text{key}(x'), 60\right)\right)^T \\
v &= \text{value}(x')^T \\
r &= \text{receptance}(x') \\
kv &= k \cdot v \\
\text{self.time\_w} &= \left[\exp(\text{time\_decay}) \cdot \text{time\_curve},\ \text{time\_first}\right] \\
w &= \exp(\text{self.time\_w})_{[:,\,-T:]} \\
wkv &= \text{conv1d}(\text{pad}_{T-1}(kv), w, \text{groups}=C) \\
wk &= \text{conv1d}(\text{pad}_{T-1}(k), w, \text{groups}=C) + 1e-9 \\
rwkv &= \sigma(r) \cdot \left( \frac{wkv}{wk} \right)^T \\
rwkv &= \text{output}(rwkv)
\end{align*}
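The kernel built from `time_curve`, `time_decay` and `time_first` gives the current position weight exp(time_first) and a position L >= 1 steps back the weight exp(-(L-1) * exp(time_decay)). The sketch below checks this by comparing `F.conv1d` with a direct double loop, assuming toy sizes (`ctx_len = 8` instead of 768) and random `time_decay` values; the kernel construction is copied from `RWKV_TimeMix.forward`.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, T = 1, 4, 6  # toy batch, channels, sequence length (assumptions)
ctx_len = 8        # toy context length; the notebook uses 768

# Parameters shaped as in RWKV_TimeMix; time_decay gets random stand-in values.
time_decay = torch.randn(C, 1)
time_first = torch.ones(C, 1) * math.log(0.3)
time_curve = torch.tensor([-(ctx_len - 2 - i) for i in range(ctx_len - 1)]).unsqueeze(0)

# Kernel construction as in RWKV_TimeMix.forward.
time_w = torch.cat([torch.exp(time_decay) * time_curve, time_first], dim=-1)
w = torch.exp(time_w)[:, -T:].unsqueeze(1)  # (C, 1, T)

kv = torch.randn(B, C, T)
wkv_conv = F.conv1d(nn.ZeroPad2d((T - 1, 0, 0, 0))(kv), w, groups=C)

# The same quantity as an explicit sum over lags.
wkv_sum = torch.zeros_like(kv)
for t in range(T):
    wkv_sum[:, :, t] = torch.exp(time_first[:, 0]) * kv[:, :, t]  # lag 0
    for s in range(t):
        lag = t - s
        weight = torch.exp(-(lag - 1) * torch.exp(time_decay[:, 0]))
        wkv_sum[:, :, t] += weight * kv[:, :, s]

print(torch.max(torch.abs(wkv_conv - wkv_sum)))
```

Note that `F.conv1d` computes a cross-correlation (the kernel is not flipped), which is why the last kernel entry, `exp(time_first)`, lines up with the current position.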

In [None]:
class RWKV_TimeMix(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.time_decay = nn.Parameter(torch.ones(n_embd, 1))
        self.time_curve = torch.tensor([-(ctx_len - 2 - i) for i in range(ctx_len-1)]).unsqueeze(0)
        self.time_first = nn.Parameter(torch.ones(n_embd, 1) * math.log(0.3))
        
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        self.time_mix = nn.Parameter(torch.ones(1,1,n_embd))

        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)

        self.output = nn.Linear(n_embd, n_embd, bias=False)

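    def forward(self, x):
        B, T, C = x.size()

        # Token shift: blend each position with the previous one.
        x = x * self.time_mix + self.time_shift(x) * (1 - self.time_mix)

        k = self.key(x).transpose(-1, -2)    # (B, C, T)
        v = self.value(x).transpose(-1, -2)  # (B, C, T)
        r = self.receptance(x)               # (B, T, C)

        k = torch.clamp(k, max=60)  # avoid overflow in exp
        k = torch.exp(k)            # positive keys

        kv = k * v

        # Per-channel decay kernel: [exp(time_decay) * time_curve, time_first], then exp.
        self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)
        w = torch.exp(self.time_w)

        # Keep the last T weights; shape (C, 1, T) for a depthwise conv1d.
        w = w[:,-T:].unsqueeze(1)
        # Left-pad by T-1 zeros so each position only sees itself and the past.
        wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
        wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + 1e-9

        # Gate the weighted average with sigmoid(r); transpose back to (B, T, C).
        rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)

        rwkv = self.output(rwkv)
        return rwkv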

### RWKV Block

The RWKV Block is the basic building unit of the model: it chains a TimeMix and a ChannelMix sub-layer. Each sub-layer is preceded by a LayerNorm and wrapped in a residual connection, which keeps training stable.

### Main components and operations

1. **LayerNorm**: normalizes the input to stabilize training.
   - `self.ln1` and `self.ln2` normalize the input before TimeMix and ChannelMix, respectively.

2. **TimeMix**: combines the current time step with earlier ones to capture temporal dependencies.
   - `self.att = RWKV_TimeMix(layer_id)` creates the TimeMix module.

3. **ChannelMix**: a position-wise feed-forward layer (with its own token shift) that mixes information across channels.
   - `self.ffn = RWKV_ChannelMix(layer_id)` creates the ChannelMix module.

4. **Residual connection**: the output of each mixing module is added back to that module's input, which preserves the information flow and eases gradient propagation. In this implementation `x` is overwritten by its normalized value before the addition, so the residual is taken from the normalized input (the recurrent `run` method of `RWKV_RNN` does the same).

With this setup, an RWKV Block processes sequence data efficiently by combining temporal and channel information.

In [4]:
class Block(nn.Module):
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id  # index of this layer

        # Two LayerNorm layers that normalize the input of each sub-module
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

        # Time mixing and channel mixing modules
        self.att = RWKV_TimeMix(layer_id)
        self.ffn = RWKV_ChannelMix(layer_id)

    def forward(self, x):
        # Normalize the input with LayerNorm
        x = self.ln1(x)

        # Time mixing, added back to the (normalized) input via the residual connection
        x = x + self.att(x)

        # Normalize again with LayerNorm
        x = self.ln2(x)

        # Channel mixing, added back via the residual connection
        x = x + self.ffn(x)

        # Return the block output
        return x


Next, the main parts of the RNN-mode RWKV model used for inference are implemented:

1. **Model loading and preprocessing**: the constructor loads the weights from `MODEL_NAME + '.pth'`, squeezes the `time_*` parameters, and precomputes `time_decay` as `exp(-exp(time_decay))` and `time_first` as `exp(time_first)`.
2. **LayerNorm**: the `LN` method applies layer normalization with a loaded weight and bias.
3. **Feedforward network (`FF`) and self-attention (`SA`)**: `FF` is the single-token (recurrent) form of ChannelMix, and `SA` is the recurrent form of TimeMix. Instead of convolving over the whole sequence, `SA` keeps per-layer running sums `aa` (numerator) and `bb` (denominator) that are decayed and updated at every step; both methods cache the previous input in `xx` for the token shift.
4. **Running the model**: the `run` method feeds the latest token through every layer in turn and returns the output logits; this is the model's inference step.
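Loading also turns the flat state-dict keys into nested attributes on `self.w`. The sketch below assumes a hypothetical four-entry state dict with placeholder strings instead of tensors and mirrors the loop in `RWKV_RNN.__init__`: numeric key segments index into a dict of layers, and all other segments become attributes.

```python
import types

# Hypothetical flat state-dict entries; strings stand in for tensors.
state = {
    'emb.weight': 'E',
    'blocks.0.att.key.weight': 'K0',
    'blocks.1.ffn.value.weight': 'V1',
    'ln_out.bias': 'b_out',
}

root = types.SimpleNamespace()
for key, value in state.items():
    parts = key.split('.')
    here = root
    for i, part in enumerate(parts):
        if part.isdigit():
            idx = int(part)
            if idx not in here:  # 'here' is the dict of layers
                here[idx] = types.SimpleNamespace()
            here = here[idx]
        elif i == len(parts) - 1:
            setattr(here, part, value)  # leaf: store the weight
        else:
            if not hasattr(here, part):
                # A numeric next segment means "indexed layers", stored in a dict.
                setattr(here, part, {} if parts[i + 1].isdigit() else types.SimpleNamespace())
            here = getattr(here, part)

print(root.blocks[1].ffn.value.weight)  # V1
```

This is why the `run` method can write `w.blocks[i].att` and `w.ln_out` instead of looking up string keys.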

In [5]:
time_buf = {}  # global dict for caching time-related information

class RWKV_RNN():
    def __init__(self, MODEL_NAME=MODEL_NAME):
        print('\nloading RWKV-RNN', MODEL_NAME)
        self.ctx_len = ctx_len  # context length
        self.n_layer = n_layer  # number of layers
        self.n_embd = n_embd    # embedding dimension
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)  # initialize the tokenizer

        self.w = types.SimpleNamespace()  # namespace that will hold the model weights
        
        # Load the model weight file
        w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))

        # Preprocess the time-dependent weights
        for x in w.keys():
            if '.time_' in x:
                w[x] = w[x].squeeze()  # drop singleton dimensions
            if '.time_decay' in x:
                w[x] = torch.exp(-torch.exp(w[x]))  # per-channel decay factor exp(-exp(time_decay)), in (0, 1)
            if '.time_first' in x:
                w[x] = torch.exp(w[x])  # weight applied to the current token
                    
            # Store the weight in the nested namespace, e.g. 'blocks.0.att.key.weight'
            # becomes self.w.blocks[0].att.key.weight
            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()  # create a namespace for this layer
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i+1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()  # initialize the recurrent state
    
    def clear(self):
        self.xx = {}  # previous input of each layer (for the token shift)
        self.aa = {}  # TimeMix numerator state
        self.bb = {}  # TimeMix denominator state

    def save(self, target):
        # Deep-copy the current recurrent state into target
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)

    def load(self, target):
        # Deep-copy the recurrent state from target into this instance
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)

    def LN(self, xx, w):
        # LayerNorm with the loaded weight and bias
        return F.layer_norm(xx, (n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        # ChannelMix for a single token (recurrent form)
        if name not in self.xx:
            self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)  # token shift: mix the current input with the cached previous one

        self.xx[name] = xx  # update the cache

        r = torch.sigmoid(w.receptance.weight @ x)  # receptance gate
        k = torch.square(torch.relu(w.key.weight @ x))  # squared-ReLU key
        kv = w.value.weight @ k  # value projection

        return r * kv  # ChannelMix output

    def SA(self, xx, w, name):
        # TimeMix for a single token (recurrent form)
        if name not in self.xx:
            self.xx[name] = torch.zeros(n_embd, device=RUN_DEVICE)
            self.aa[name] = torch.zeros(n_embd, device=RUN_DEVICE)
            self.bb[name] = torch.zeros(n_embd, device=RUN_DEVICE)
        x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix)  # token shift
        self.xx[name] = xx  # update the cache

        r = torch.sigmoid(w.receptance.weight @ x)  # receptance gate

        k = torch.exp(torch.clamp(w.key.weight @ x, max=60))  # positive key (clamped to avoid overflow)
        v = w.value.weight @ x  # value
        kv = k * v  # key-weighted value

        a = self.aa[name] + w.time_first * kv  # numerator: past state plus the current token weighted by exp(time_first)
        b = self.bb[name] + w.time_first * k  # denominator, same weighting
        self.aa[name] = w.time_decay * self.aa[name] + kv  # decay the numerator state and add the current token
        self.bb[name] = w.time_decay * self.bb[name] + k  # same for the denominator state

        rwkv = r * a / (b + 1e-9)  # gated weighted average

        return w.output.weight @ rwkv  # output projection

    def run(self, ctx):
        # Process the latest token of ctx and return the next-token logits
        w = self.w
        x = w.emb.weight[ctx[-1]]  # embedding of the latest token

        # Pass through each layer in turn
        for i in range(n_layer):
            x = self.LN(x, w.blocks[i].ln1)  # LayerNorm
            x = x + self.SA(x, w.blocks[i].att, f'att.{i}')  # TimeMix + residual
            x = self.LN(x, w.blocks[i].ln2)  # LayerNorm
            x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}')  # ChannelMix + residual

        x = self.LN(x, w.ln_out)  # final LayerNorm

        x = w.head.weight @ x  # project to vocabulary logits
        x = x.tolist()  # convert to a Python list

        return x  # logits for the next token
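The recurrent `SA` method computes the same weighted average as the convolution in `RWKV_TimeMix`, one token at a time. The sketch below checks this on toy sizes with random keys, values and `time_decay` (all assumptions); it applies the same preprocessing as the constructor (`exp(-exp(time_decay))` and `exp(time_first)`) and omits the receptance gate and the output projection, which are identical in both forms.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
C, T, ctx_len = 4, 6, 8  # toy sizes (assumptions)

time_decay = torch.randn(C, 1)  # raw parameter, as in RWKV_TimeMix
time_first = torch.ones(C, 1) * math.log(0.3)
k = torch.exp(torch.randn(C, T))  # positive keys, like exp(clamp(key))
v = torch.randn(C, T)
kv = k * v

# Parallel (training-time) form, as in RWKV_TimeMix.forward.
time_curve = torch.tensor([-(ctx_len - 2 - i) for i in range(ctx_len - 1)]).unsqueeze(0)
w = torch.exp(torch.cat([torch.exp(time_decay) * time_curve, time_first], dim=-1))
w = w[:, -T:].unsqueeze(1)  # (C, 1, T)
pad = nn.ZeroPad2d((T - 1, 0, 0, 0))
wkv = F.conv1d(pad(kv.unsqueeze(0)), w, groups=C)[0]
wk = F.conv1d(pad(k.unsqueeze(0)), w, groups=C)[0] + 1e-9
parallel = wkv / wk  # (C, T)

# Recurrent (inference-time) form, as in RWKV_RNN.SA, after the
# preprocessing that RWKV_RNN applies when loading the weights.
decay = torch.exp(-torch.exp(time_decay[:, 0]))
first = torch.exp(time_first[:, 0])
aa = torch.zeros(C)
bb = torch.zeros(C)
steps = []
for t in range(T):
    a = aa + first * kv[:, t]  # past state plus the weighted current token
    b = bb + first * k[:, t]
    aa = decay * aa + kv[:, t]  # decay the state, then add the current token
    bb = decay * bb + k[:, t]
    steps.append(a / (b + 1e-9))
recurrent = torch.stack(steps, dim=1)  # (C, T)

print(torch.max(torch.abs(parallel - recurrent)))
```

Multiplying the state by `decay` once per step is what turns the per-lag weight exp(-(L-1) * exp(time_decay)) of the convolution into a constant-size recurrence.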

In [6]:
print('''
******************************************************************************
* This is a preview of RWKV-v2-RNN trained on the Pile for only 50B tokens.
* It is NOT indicative of the final performance (which requires 300B tokens).
******************************************************************************''')


******************************************************************************
* This is a preview of RWKV-v2-RNN trained on the Pile for only 50B tokens.
* It is NOT indicative of the final performance (which requires 300B tokens).
******************************************************************************


In [7]:
# The device is set via RUN_DEVICE in the configuration cell above (CPU by default).

TEMPERATURE = 1.0
TOP_P = 0.7

DEBUG_DEBUG = False
LENGTH_OF_EACH = 333
NUM_TRIALS = 3

context = '\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.'

##############################################################################################################

In [8]:
model = RWKV_RNN()


loading RWKV-RNN /data1/ckw/20220615-10803


Next, we sample a new token from the output logits. The function below implements **temperature scaling** and **nucleus (top-p) sampling**:

1. **Softmax**: convert the logits output by the model into a probability distribution.
2. **Sorting and cumulative probability**: sort the probabilities from high to low and compute their running sum.
3. **Nucleus sampling**:
   - Find the first position in the sorted list where the cumulative probability exceeds `top_p`; the probability at that position is the cutoff value `cutoff`.
   - Set every probability below `cutoff` to 0, keeping only the most likely tokens up to and including the one that pushes the cumulative probability past `top_p` (tokens tied with the cutoff are kept as well).
4. **Temperature**: if `temperature` is not 1, raise the remaining probabilities to the power `1 / temperature`; values above 1 flatten the distribution and values below 1 sharpen it.
5. **Sampling**: draw one index from the adjusted (unnormalized) distribution with `torch.multinomial`.

This scheme is widely used in text generation: `temperature` and `top_p` control the trade-off between the diversity and the quality of the generated text.
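A worked example of steps 2 to 4 on a hypothetical four-token distribution (probabilities rather than logits, so the numbers are easy to follow) with `top_p = 0.7` and `temperature = 0.5`:

```python
import torch

# Hypothetical next-token probabilities over a 4-token vocabulary.
probs = torch.tensor([0.05, 0.50, 0.15, 0.30])
top_p = 0.7

# Step 2: sort high to low and accumulate.
sorted_probs, _ = torch.sort(probs, descending=True)  # [0.50, 0.30, 0.15, 0.05]
cumulative = torch.cumsum(sorted_probs, dim=-1)       # approx. [0.50, 0.80, 0.95, 1.00]

# Step 3: the first cumulative value above top_p sets the cutoff.
first_over = int(torch.argmax((cumulative > top_p).int()))  # index 1
cutoff = float(sorted_probs[first_over])                      # about 0.3
nucleus = probs.clone()
nucleus[nucleus < cutoff] = 0                                 # keeps tokens 1 and 3

# Step 4: temperature < 1 sharpens what is left (T = 0.5 squares it).
temperature = 0.5
sharpened = nucleus.pow(1.0 / temperature)
sharpened = sharpened / sharpened.sum()  # normalized here only for display
```

Only tokens 1 and 3 survive because 0.50 + 0.30 is the first cumulative sum above 0.7; squaring then raises token 1's share from 0.625 to about 0.735.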

In [9]:
def sample_logits(out, temperature=1.0, top_p=None):
    # Convert the logits into a probability distribution (softmax)
    probs = F.softmax(torch.tensor(out), dim=-1)

    # Nucleus (top-p) filtering; skipped when top_p is None or >= 1
    if top_p is not None and top_p < 1.0:
        # Sort by probability from high to low
        sorted_probs, _ = torch.sort(probs, descending=True)

        # Cumulative probability distribution
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy()

        # Cutoff: the probability at the first position where the cumulative sum exceeds top_p
        cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])

        # Zero out probabilities below the cutoff
        probs[probs < cutoff] = 0

    # Temperature: p ** (1 / T) flattens (T > 1) or sharpens (T < 1) the distribution
    if temperature != 1.0:
        probs = probs.pow(1.0 / temperature)

    # Sample one index; torch.multinomial does not require normalized weights
    return torch.multinomial(probs, num_samples=1)[0]


In [10]:
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    ctx = model.tokenizer.encode(context)
    src_len = len(ctx)
    print(context, end='')

    model.clear()
    if TRIAL == 0:
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[:i+1]
            if i == src_len - 1:
                init_state.out = model.run(x)
            else:
                model.run(x)
        model.save(init_state)
    else:
        model.load(init_state)

    if DEBUG_DEBUG:
        out = init_state.out
        print('\n', np.array(x), '==>', np.array(
            out), np.max(out), np.min(out))

    for i in range(src_len, src_len + (0 if DEBUG_DEBUG else LENGTH_OF_EACH)):
        x = ctx[:i+1]
        x = x[-model.ctx_len:]

        if i == src_len:
            out = copy.deepcopy(init_state.out)
        else:
            out = model.run(x)

        out[0] = -999999999  # disable <|endoftext|>

        char = sample_logits(out, temperature=TEMPERATURE, top_p=TOP_P)
        char = char.item()
        print(model.tokenizer.decode(char), end='', flush=True)

        ctx += [char]
    print('\n' + '-' * 70, end='')


DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. We want to change the way students learn artificial intelligence. We are very committed to the idea of artificial intelligence. We have already taken part in a joint research project with the European Research Council. This will bring us closer to our goal. We hope that this research project will give us the opportunity to develop a framework for a data-driven approach in artificial intelligence.

I.C. Pfeifer

Research Fellow

Our work in data science and machine learning has a strong connection with the Human Brain Project, an initiative of the University of California at Berkeley. The focus of our work is on the research of language and language learning, with an emphasis on language acquisition. Our research is directed at the development of technology to improve the language acquisition skills of students and their parents.

The Joint Data-Science-Technolo