In [4]:
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import ast
import types

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

np.set_printoptions(precision=4, suppress=True, linewidth=200)

MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method

RWKV-5 is also called Eagle.

In [5]:
import torch

In [6]:
class RWKV_TOKENIZER():
    """Byte-level greedy tokenizer built from a vocabulary text file.

    Every line of the file has the form
    ``<index> <str-or-bytes literal> <byte length>``.

    Encoding walks the input left to right. At each position it takes the
    first candidate token (candidates are tried in reverse file order) that
    the remaining bytes start with, and otherwise falls back to the single
    byte at that position.
    """
    # table[b0][b1]: multi-byte tokens whose first two bytes are (b0, b1), in reverse file order
    table: list[list[list[bytes]]]
    # good[b0]: every second byte b1 for which table[b0][b1] is non-empty
    good: list[set[int]]
    # wlen[b0]: length of the longest multi-byte token starting with b0
    wlen: list[int]

    def __init__(self, file_name):
        """Load the vocabulary from ``file_name`` and build the lookup tables.

        Raises:
            AssertionError: if a token is neither str nor bytes, or if its
                byte length differs from the length recorded on its line.
            ValueError / SyntaxError: if a line is malformed.
        """
        self.idx2token = {}
        vocab = []  # tokens in file order; the file must be already sorted
        # `with` guarantees the handle is closed even if a line fails to parse.
        with open(file_name, "r", encoding="utf-8") as vocab_file:
            for line in vocab_file:
                first_space = line.index(' ')
                last_space = line.rindex(' ')
                idx = int(line[:first_space])
                # The middle field is a Python str/bytes literal. literal_eval parses
                # it without executing arbitrary code from the file (the previous
                # eval() would run anything placed there). strip() removes the
                # separator space, which older literal_eval versions reject.
                x = ast.literal_eval(line[first_space:last_space].strip())
                x = x.encode("utf-8") if isinstance(x, str) else x
                assert isinstance(x, bytes)
                assert len(x) == int(line[last_space:])
                vocab.append(x)
                self.idx2token[idx] = x

        self.token2idx = {token: int(idx) for idx, token in self.idx2token.items()}

        # precompute some tables for fast matching
        self.table = [[[] for j in range(256)] for i in range(256)]
        self.good = [set() for i in range(256)]
        self.wlen = [0 for i in range(256)]

        for s in reversed(vocab):  # reverse order - match longer tokens first
            if len(s) >= 2:
                s0 = int(s[0])
                s1 = int(s[1])
                self.table[s0][s1].append(s)
                self.wlen[s0] = max(self.wlen[s0], len(s))
                self.good[s0].add(s1)

    def encodeBytes(self, src: bytes) -> list[int]:
        """Tokenize raw bytes into a list of token ids.

        Raises:
            KeyError: if a byte that must be emitted on its own has no
                single-byte entry in the vocabulary.
        """
        src_len: int = len(src)
        tokens: list[int] = []
        i: int = 0
        while i < src_len:
            s: bytes = src[i : i + 1]  # fallback: the single byte at position i

            if i < src_len - 1:
                s1: int = int(src[i + 1])
                s0: int = int(src[i])
                if s1 in self.good[s0]:
                    sss: bytes = src[i : i + self.wlen[s0]]
                    # First matching candidate wins; keep the single byte when none match.
                    s = next(filter(sss.startswith, self.table[s0][s1]), s)
            tokens.append(self.token2idx[s])
            i += len(s)

        return tokens

    def decodeBytes(self, tokens):
        """Concatenate the byte strings of ``tokens`` (KeyError on an unknown id)."""
        return b''.join(map(lambda i: self.idx2token[i], tokens))

    def encode(self, src: str):
        """Tokenize a text string via its UTF-8 encoding."""
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        """Decode token ids to text; raises UnicodeDecodeError on invalid/incomplete UTF-8."""
        return self.decodeBytes(tokens).decode('utf-8')

    def printTokens(self, tokens):
        """Print each token's repr immediately followed by its id, space separated."""
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except UnicodeDecodeError:
                pass  # not valid UTF-8 on its own: show the raw bytes repr instead
            print(f'{repr(s)}{i}', end=' ')
        print()

########################################################################################################

In [7]:
def sample_logits(out, temperature=1.0, top_p=0.8):
    """Sample one index from a 1-D logits tensor with top-p filtering and temperature.

    Args:
        out: 1-D torch tensor of logits (must be convertible with ``.numpy()``).
        temperature: probabilities are raised to ``1 / temperature`` before
            renormalisation; ``1.0`` leaves them unchanged.
        top_p: cumulative-probability threshold; every probability smaller than
            the one at which the sorted cumulative sum first exceeds ``top_p``
            is set to zero.

    Returns:
        The sampled index, as returned by ``np.random.choice``.
    """
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        # `probs` is a NumPy array here, which has no `.pow` method (that is the
        # torch.Tensor API); use the power operator instead.
        probs = probs ** (1.0 / temperature)
    probs = probs / np.sum(probs)
    out = np.random.choice(a=len(probs), p=probs)
    return out

########################################################################################################

In [8]:
# Tokenizer built from the vocabulary file, given as a path relative to the working directory.
tokenizer = RWKV_TOKENIZER("./rwkv_vocab_v20230424.txt")

# THIS IS NOW UPDATED TO SUPPORT LATEST RWKV-5 WORLD v2 MODELS

# Model settings gathered in one namespace.
args = types.SimpleNamespace(
    MODEL_NAME='/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096',  # no ".pth" suffix here; RWKV_RNN appends it
    n_layer=24,
    n_embd=1024,
    vocab_size=65536,
)

In [9]:
# Alternative (smaller) setting kept for reference: N_LAYER="12", N_EMBD="768"
# Note that both values are strings, not integers.
N_LAYER, N_EMBD = "24", "1024"

In [10]:
# context = "\nElon Musk has"
# context = "\nWe found"
context = "Q:Do you know datawhalechina?\nA:"  # prompt fed to the model before sampling
NUM_TRIALS = 3           # number of independent generations from the same prompt
# The earlier `LENGTH_PER_TRIAL = 100` was overwritten on the very next line, so only
# the effective value is kept.
LENGTH_PER_TRIAL = 4096  # number of tokens sampled per trial
TEMPERATURE = 1.0        # passed to sample_logits as `temperature`
TOP_P = 0.7              # passed to sample_logits as `top_p`

Modeling improvements of Eagle (RWKV-5) and Finch (RWKV-6) compared to the basic RWKV-4 architecture:

1. **Improvement steps**:
- **Eagle improvements**: The Eagle model makes several improvements on top of RWKV-4, including the introduction of matrix-valued attention states, the application of LayerNorm over the attention heads, the use of SiLU (Sigmoid Linear Unit) for attention gating, and improved initialization methods. In addition, Eagle removes the sigmoid activation on the receptance in Time Mixing.
- **Finch improvements**: The Finch model further introduces data-dependence on decay schedules and token-shifts, making the model more flexible and accurate in handling time and token data.

2. **Core Architecture**:
- The core architecture of these models is still similar to RWKV-4, consisting of a series of stacked residual blocks, similar in shape to the traditional Transformer architecture.
- Each block contains a Pre-LayerNorm Time-Mixing sub-layer and a Pre-LayerNorm Channel-Mixing sub-layer, corresponding to the attention sub-layer and the feed-forward network sub-layer in a Transformer.

This is the code implementation of Channel Mixing of RWKV 5, which can be compared with the implementation of RWKV 4.

```python
@MyFunction
def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
    i0 = (2+self.head_size)*i+0
    xk = x * time_mix_k + state[i0] * (1 - time_mix_k)
    xr = x * time_mix_r + state[i0] * (1 - time_mix_r)
    state[i0] = x
    r = torch.sigmoid(rw @ xr)
    k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
    return r * (vw @ k)
```

The code implementation of Channel Mixing of RWKV 4 is:

```python
@torch.jit.script_method
def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
    xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
    xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
    state[5*i+0] = x
    r = torch.sigmoid(rw @ xr)
    k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
    return r * (vw @ k)
```

Here, `i` is the index of the current layer. In each layer of RWKV 4, Channel Mixing records one state and Time Mixing records 4 states, so there are 5 states per layer in total (hence the `5*i` offset). Each layer in RWKV 5 now records `2+self.head_size` states instead. The state recorded by Channel Mixing and its calculation process are exactly the same as in RWKV 4.

![](./img/01.png)

Figure 1: Overview of the RWKV architecture. Left: Time Mixing and Channel Mixing blocks; top right: the RWKV Time Mixing block as an RNN unit; bottom center: the Token Shift module in the feed-forward module and in Eagle Time Mixing; bottom right: the Token Shift module in Finch Time Mixing. All shape annotations assume a single head for simplicity. Dashed arrows (left, top right) indicate connections that exist in Finch but not in Eagle.

Token Shift technology used in Eagle model:

1. **Token Shift**:
- Eagle model adopts Token Shift technology from the previous RWKV model, which is similar to 1D causal convolution of size 2.
- A schematic diagram of this technology can be seen at the center bottom of Figure 1.

2. **Linear interpolation definition**:
- In order to better introduce the Token Shift technology, some symbols are defined.
- Linear interpolation (lerp) is used between time steps $t$ and $t-1$ for RWKV-4 and Eagle Token Shift, and is defined as follows:
\begin{align*}
\text{lerp}_{\Box}(a, b) = a + (b - a) \odot \mu_{\Box}
\end{align*}
- Where each $\mu_{\Box} \in \mathbb{R}^D$ is a learnable vector.

3. **Token Shift Function**:
- Token Shift allows the model to learn, for each time step, how much new versus old information is allocated to each channel of the receptance, key, value, and gate vectors ($r, k, v, g$), independently and uniquely for each head.
- This allows a single head to directly accumulate past and current token data into different subspaces of these vectors, even within a single layer, forming induction heads.

The settings of the Channel Mixing module in Eagle and Finch models and their similarities and differences with the RWKV-4 architecture are as follows:

1. **Module consistency**:
- The channel mixing module in Eagle and Finch models is basically the same as the previous RWKV-4 architecture.
- The only difference is that in Eagle model, the hidden dimension of the channel mixing module is reduced from 4D to 3.5D.

2. **Reason for reducing dimension**:
- This reduction in hidden dimension is to introduce new gating weights in Eagle Time Mixing and ensure the same number of parameters as the previous model (with the same number of layers and embedding dimensions).

3. **Processing in Finch model**:
- Although some new LoRA weight parameters are added in Finch model, there is no further reduction in hidden dimension.

4. **Formula consistency**:
- The formulas for channel mixing are the same as those for the RWKV-4 model. For notational consistency, these formulas are listed again:

\begin{align*}
r'_t &= \text{lerp}_{r'}(x'_t, x'_{t-1}) W_{r'} \in \mathbb{R}^D \quad \text{(Formula 10)} \\
k'_t &= \text{lerp}_{k'}(x'_t, x'_{t-1}) W_{k'} \in \mathbb{R}^{3.5D} \quad \text{(Formula 11)} \\
v'_t &= \text{ReLU}(k'_t)^2 W_{v'} \in \mathbb{R}^D \quad \text{(Formula 12)} \\
o'_t &= \sigma(r'_t) \odot v'_t \in \mathbb{R}^D \quad \text{(Formula 13)}
\end{align*}

These formulas describe the channel mixing operation at time step $t$:
- $r'_t$ and $k'_t$ are computed from the linear interpolation (lerp) of the current and previous inputs.
- $v'_t$ is obtained by multiplying the squared ReLU of $k'_t$ by the weight matrix $W_{v'}$.
- $o'_t$ is the element-wise product of $\sigma(r'_t)$ and $v'_t$, where $\sigma$ is the sigmoid function.

Among them, 3.5D refers to a way of representing dimensions. In deep learning models, D usually represents the hidden dimension of the model (i.e., the embedding dimension or the dimension of the feature space). For example, if the hidden dimension of the model is 256, then 4D means that this dimension is expanded by 4 times, that is, 1024.

However, 3.5D is an uncommon representation method. Usually, we see integer multiple representations (such as 2D, 4D, etc.). Here, 3.5D represents 3.5 times the hidden dimension.

Specifically, if the base dimension of the model is D, then 3.5D means: \begin{align*} 3.5D = 3.5 \times D \end{align*}

Assuming D is 256, then 3.5D is: \begin{align*} 3.5 \times 256 = 896 \end{align*}

So, 3.5D means that the feature dimension used by the model in a specific layer is 3.5 times the base dimension. In this document, the author mentioned that the reduction from 4D to 3.5D means that they reduced the feature dimension of a certain layer or module in order to introduce new gating weights and keep the number of parameters consistent.

The formula and operation method of Eagle Time Mixing are as follows:

### Formula part

The formula of Eagle Time Mixing is as follows:

\begin{align*}
\Box_t &= \text{lerp}_{\Box}(x_t, x_{t-1}) W_{\Box}, \quad \Box \in \{r, k, v, g\} \tag{4} \\
w &= \exp(-\exp(\omega)) \tag{5} \\
\mathbf{wkv}_t &= \text{diag}(u) \cdot k_t^\top \cdot v_t + \sum_{i=1}^{t-1} \text{diag}(w)^{t-1-i} \cdot k_i^\top \cdot v_i \in \mathbb{R}^{(D/h) \times (D/h)} \tag{6} \\
o_t &= \text{concat} \left( \text{SiLU}(g_t) \odot \text{LayerNorm}(r_t \cdot \mathbf{wkv}_t) \right) W_o \in \mathbb{R}^D \tag{7}
\end{align*}

### Explanation

- **LayerNorm operation**: LayerNorm operates independently on each head, which is equivalent to performing GroupNorm on h groups (Wu & He, 2018). It is worth noting that $w$ is calculated by $\omega \in \mathbb{R}^{D/h}$ through the formula $w = \exp(-\exp(\omega))$, where $\omega$ is the actual head trainable parameter. This ensures that $w$ is in the interval (0,1), thus ensuring that $\text{diag}(w)$ is a contraction matrix.

- **$\mathbf{wkv}_t$ calculation**: The attention calculation of $\mathbf{wkv}_t$ can be written in recursive form:
\begin{align*}
\mathbf{wkv}' &= s + \text{diag}(u) \cdot k^\top \cdot v \tag{8} \\
s' &= \text{diag}(w) \cdot s + k^\top \cdot v \tag{9}
\end{align*}

- **Explanation of the $\mathbf{wkv}_t$ term of RWKV**: The $\mathbf{wkv}_t$ term of RWKV can be thought of as the decay-based equivalent of the normalized $k^\top v$ term. Notably, for a given head $j$, the recurrent state $s$ is the sum of $k^\top v$, where each channel of $s$ is decayed individually by the corresponding channel of $w$ at each time step. Before applying the receptance vector, gating, and output weights, the $k^\top v$ of the current token is multiplied by a per-channel learned boost $u$ and added to the state (see the top right corner of Figure 1). This gives the current token special treatment relative to the sum of past tokens contained in the decayed state history. The receptance is multiplied by this sum, similar to the query term in linear attention.

The biggest improvement here is that the calculation is now divided into `H = self.n_head` heads, and the calculation results of each head are stored in the state. Compared with RWKV-4, this improvement can be compared to the change from the single-head self-attention mechanism of Transformer to the multi-head attention mechanism.
```python
@MyFunction
def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
    H = self.n_head
    S = self.head_size

    i1 = (2+S)*i+1
    xk = x * time_mix_k + state[i1] * (1 - time_mix_k)
    xv = x * time_mix_v + state[i1] * (1 - time_mix_v)
    xr = x * time_mix_r + state[i1] * (1 - time_mix_r)
    xg = x * time_mix_g + state[i1] * (1 - time_mix_g)
    state[i1] = x

    r = (rw @ xr).view(H, 1, S)
    k = (kw @ xk).view(H, S, 1)
    v = (vw @ xv).view(H, 1, S)
    g = F.silu(gw @ xg)

    s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)

    x = torch.zeros(H, S)
    a = k @ v
    x = r @ (time_first * a + s)
    s = a + time_decay * s

    state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
    x = x.flatten()

    x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)
    return ow @ x
```

In [11]:
class RWKV_RNN(MyModule):
    """RWKV-5 model evaluated as an RNN: one token in, one logits vector out.

    The recurrent state is one 2-D tensor with ``n_layer * (2 + head_size)``
    rows and ``n_embd`` columns. For layer ``i`` the rows starting at
    ``(2 + head_size) * i`` hold:

    * row ``+0``: the previous input of channel mixing (token shift),
    * row ``+1``: the previous input of time mixing (token shift),
    * the remaining ``head_size`` rows: the per-head wkv state, used as a
      ``(n_head, head_size, head_size)`` tensor.
    """

    def __init__(self, args):
        """Load ``args.MODEL_NAME + '.pth'`` onto the CPU and prepare the weights.

        ``args`` must also provide ``n_layer`` and ``n_embd`` (read in
        ``forward`` and ``layer_norm``).
        """
        super().__init__()
        self.args = args
        self.eval() # set torch to inference mode

        # NOTE(review): torch.load unpickles the file, so only load checkpoints from a
        # trusted source; consider `weights_only=True` where the installed torch supports it.
        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            w[k] = w[k].float() # convert to f32 type
            if      '.time_' in k: w[k] = w[k].squeeze()
            # decay is stored as omega; w = exp(-exp(omega)), plus a trailing axis for broadcasting
            if '.time_decay' in k: w[k] = torch.exp(-torch.exp(w[k])).unsqueeze(-1)
            # passed to time_mixing as `time_first`; trailing axis for broadcasting
            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)

        self.n_head = w['blocks.0.att.time_decay'].shape[0]
        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head

        self.w = types.SimpleNamespace() # set self.w from w
        self.w.blocks = {}
        for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    # numeric path component: index into the `blocks` dict
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

    def layer_norm(self, x, w):
        """Apply LayerNorm over ``n_embd`` features using ``w.weight`` and ``w.bias``."""
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

    @MyFunction
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        """Channel-mixing sub-layer of layer ``i``; updates its token-shift row of ``state`` in place."""
        i0 = (2+self.head_size)*i+0
        # token shift: blend the current input with the previous one stored in the state
        xk = x * time_mix_k + state[i0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[i0] * (1 - time_mix_r)
        state[i0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

    @MyFunction
    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_mix_g, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
        """Time-mixing sub-layer of layer ``i``; updates its token-shift row and wkv rows of ``state`` in place."""
        H = self.n_head
        S = self.head_size

        i1 = (2+S)*i+1
        # token shift: blend the current input with the previous one stored in the state
        xk = x * time_mix_k + state[i1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[i1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[i1] * (1 - time_mix_r)
        xg = x * time_mix_g + state[i1] * (1 - time_mix_g)
        state[i1] = x

        r = (rw @ xr).view(H, 1, S)
        k = (kw @ xk).view(H, S, 1)
        v = (vw @ xv).view(H, 1, S)
        g = F.silu(gw @ xg)

        s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)

        # (the former `x = torch.zeros(H, S)` was overwritten before use and is dropped)
        a = k @ v                       # per-head outer product k^T v, shape (H, S, S)
        x = r @ (time_first * a + s)    # current token gets the `time_first` bonus
        s = a + time_decay * s          # decay the old state and add the current token

        state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
        x = x.flatten()

        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)
        return ow @ x

    def forward(self, token, state):
        """Run one token through every layer.

        Args:
            token: index into the embedding table ``self.w.emb.weight``.
            state: recurrent state tensor, or ``None`` to start from zeros.
                A tensor passed in is modified in place.

        Returns:
            ``(logits, state)``: the float output of the head projection and
            the (same, updated) state tensor.
        """
        with torch.no_grad():
            # Identity test: `== None` on a tensor goes through Tensor.__eq__.
            if state is None:
                state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                    att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_mix_g, att.time_faaaa, att.time_decay,
                    att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
                    att.ln_x.weight, att.ln_x.bias)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                    ffn.time_mix_k, ffn.time_mix_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state

In [32]:
# Earlier prompt, kept for reference:
# context = "Q:Do you know datawhalechina?\nA:"
# Prompt actually used for the generation below.
context = (
    '\nQ:DataWhalechina is an organization founded at Shanghai Jiao Tong University '
    'that helps learners learn artificial intelligence. How do you think of it?'
)

In [33]:
# Display the checkpoint path currently configured (RWKV_RNN appends ".pth" when loading).
args.MODEL_NAME

'/data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096'

In [34]:
# Display the layer count and embedding width currently configured.
args.n_layer,args.n_embd

(24, 1024)

In [35]:
# args.n_layer = 24
# args.n_embd = 1024

In [36]:
# args.n_layer = 12
# args.n_embd = 768

In [37]:
# args.MODEL_NAME='../models/rwkv-5-world-1b5'

In [38]:
# Build the model: RWKV_RNN.__init__ loads `args.MODEL_NAME + '.pth'` with map_location='cpu'.
print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)

print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
# No recurrent state yet: forward() allocates a zero state when it receives None.
init_state = None


Using CPU. Loading /data1/ckw/RWKV-5-World-0.4B-v2-20231113-ctx4096 ...

Preprocessing context (slow version. see v2/rwkv/model.py for fast version)


In [39]:
# Reset the state so the prompt is processed from a zero state when the cells below are (re-)run.
init_state = None

In [40]:
# Override the number of tokens sampled per trial for the run below.
LENGTH_PER_TRIAL = 1024

In [41]:
# Feed the prompt one token at a time to obtain the logits and recurrent state after the prompt.
for token in tokenizer.encode(context):
    init_out, init_state = model.forward(token, init_state)

for TRIAL in range(NUM_TRIALS):
    print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
    all_tokens = []
    out_last = 0  # index of the first sampled token that has not been printed yet
    # Clone so every trial restarts from the same post-prompt point
    # (model.forward writes into the state tensor it is given).
    out, state = init_out.clone(), init_state.clone()
    for i in range(LENGTH_PER_TRIAL):
        token = sample_logits(out, TEMPERATURE, TOP_P)
        all_tokens += [token]
        try:
            tmp = tokenizer.decode(all_tokens[out_last:])
        except (UnicodeDecodeError, KeyError):
            # UnicodeDecodeError: the pending bytes stop in the middle of a multi-byte
            # character, so wait for more tokens. KeyError: a sampled id is missing from
            # the tokenizer table (kept as best-effort, as before). The former bare
            # `except:` also hid KeyboardInterrupt and print errors.
            tmp = None
        if tmp is not None and '\ufffd' not in tmp: # only print when we have a valid utf-8 string
            print(tmp, end="", flush=True)
            out_last = i + 1
        out, state = model.forward(token, state)
print('\n')



--[ Trial 0 ]----------------- 
Q:DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. How do you think of it?
QI: I think that the group of students is actually the whole AI community.
Q: In the first episode, how do you think you, a student, can use AI to solve a problem?
QI: It's a great opportunity to help develop and build knowledge, so that if we see AI problems, we can help solve them.
Q: How do you think that students can also participate in the teaching of AI?
QI: It is very important to let the students to think that there is an AI problem, and we can solve it by teaching AI.
Q: How do you think the research that we did on AI can be used to develop AI technologies?
QI: The research is interesting and it can be used to develop AI technologies.
Q: Do you think that students can learn from your research?
QI: I think so.
Q: You also talk about the use of AI in real-life applications. What do you think of t