In [5]:
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
import torch.nn as nn
from torch.nn import functional as F

# Short aliases for the legacy TorchScript API: RWKV_RNN below subclasses
# MyModule and marks its per-token hot paths with MyFunction so they are
# compiled by TorchScript.
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method

![](./img/01.png)

Figure 1: Overview of the RWKV architecture. Left: Time Mixing and Channel Mixing blocks; Top right: RWKV Time Mixing block as an RNN unit; Bottom middle: Token shift module in the FeedForward (Channel Mixing) module and in Eagle time mixing; Bottom right: Token shift module in Finch time mixing. All shape annotations assume a single head for simplicity. Dashed arrows (left, top right) indicate connections present in Finch but not in Eagle.

First, RWKV 6 improves Token Shift compared to RWKV 5. The bottom-middle and bottom-right panels of Figure 1 above show the Token Shift used by RWKV 4/5 and by RWKV 6 (Finch), respectively.

The details are as follows:

### Formula part

The data-dependent linear interpolation (ddlerp) used in Finch Token Shift is defined as follows:

\begin{align*}
\text{lora}_{\Box}(x) &= \lambda_{\Box} + \tanh(x A_{\Box}) B_{\Box} \tag{14} \\
\text{ddlerp}_{\Box}(a, b) &= a + (b - a) \odot \text{lora}_{\Box}(a + (b - a) \odot \mu_{x}) \tag{15}
\end{align*}

### Explanation part

- **Learnable vectors and matrices**:
  - $\mu_{x}$ and each $\lambda_{\Box}$ introduce a trainable vector of dimension $D$.
  - $A_{\Box} \in \mathbb{R}^{D \times 32}$ and $B_{\Box} \in \mathbb{R}^{32 \times D}$ introduce new trainable weight matrices.
  - For the decay LoRA ($\text{lora}_{\omega}$, written $\text{lora}_d$ in Equation 17 below), the trainable weight matrices are twice as large: $A_{\omega} \in \mathbb{R}^{D \times 64}$ and $B_{\omega} \in \mathbb{R}^{64 \times D}$.

- **Future model extensions**:
- A schematic is shown in the lower right corner of Figure 1.
- Future 7B and larger Finch models are expected to further increase the size of these weight matrices (possibly doubling or more).

### Function and effect

This new form of Token Shift with data dependency is designed to extend the model's capabilities beyond the RWKV-4/Eagle style Token Shift, so that the amount of new and old data assigned to each channel now depends on the input of the current and previous time steps.

### Detailed explanation

- **Data-dependent linear interpolation (ddlerp)**:
- ddlerp is implemented by Equation 14 and Equation 15, which combines information from the current time step and the previous time step to calculate the interpolation.
- $\text{lora}_{\Box}(x)$ is generated by using a $\lambda_{\Box}$ vector and the product of $x A_{\Box}$ and $B_{\Box}$ processed by the $\tanh$ function.

- **Model capability expansion**:
  - Through this data-dependent Token Shift, the Finch model can handle information transfer between time steps more flexibly, making the model more accurate and efficient when processing complex sequence data.

In summary, Finch introduces data-dependent linear interpolation on Token Shift, and uses trainable vectors and matrices to enhance the flexibility and capabilities of the model, enabling it to better handle information between time steps, thereby improving the overall performance of the model.

In [6]:
class RWKV_TOKENIZER():
    """Greedy byte-level tokenizer for an RWKV vocabulary file.

    The vocab file has one token per line in the form
    ``<idx> <python str or bytes literal> <byte length>``. ``str`` literals are
    stored UTF-8 encoded. The file must already be sorted (see ``__init__``) so
    that longer tokens are tried before shorter ones during encoding.

    Attributes:
        idx2token: token id -> token bytes.
        token2idx: token bytes -> token id.
        table: table[b0][b1] lists every token (length >= 2) starting with the
            bytes b0, b1, in reverse file order.
        good: good[b0] is the set of second bytes b1 for which table[b0][b1]
            is non-empty (fast rejection).
        wlen: wlen[b0] is the length of the longest token starting with b0.
    """
    table: list[list[list[bytes]]]
    good: list[set[int]]
    wlen: list[int]

    def __init__(self, file_name):
        """Load the vocabulary from ``file_name`` and build the match tables.

        Raises:
            AssertionError: if a literal is not str/bytes or its encoded length
                does not match the length column.
        """
        self.idx2token = {}
        ordered_tokens = []  # in file order; the file must already be sorted
        with open(file_name, "r", encoding="utf-8") as f:
            lines = f.readlines()
        for line in lines:
            idx = int(line[:line.index(' ')])
            # NOTE(security): eval() executes arbitrary code from the vocab file;
            # only load trusted vocabulary files.
            x = eval(line[line.index(' '):line.rindex(' ')])
            x = x.encode("utf-8") if isinstance(x, str) else x
            assert isinstance(x, bytes)
            assert len(x) == int(line[line.rindex(' '):])
            ordered_tokens.append(x)
            self.idx2token[idx] = x

        self.token2idx = {v: int(k) for k, v in self.idx2token.items()}

        # precompute some tables for fast matching
        self.table = [[[] for _ in range(256)] for _ in range(256)]
        self.good = [set() for _ in range(256)]
        self.wlen = [0] * 256

        # reverse order - match longer tokens first
        for s in reversed(ordered_tokens):
            if len(s) >= 2:
                s0, s1 = s[0], s[1]  # indexing bytes yields ints
                self.table[s0][s1].append(s)
                self.wlen[s0] = max(self.wlen[s0], len(s))
                self.good[s0].add(s1)

    def encodeBytes(self, src: bytes) -> list[int]:
        """Greedily split ``src`` into token ids.

        At each position the candidates in ``table`` are tried in stored order
        (longest first for a properly sorted vocab). If none is a prefix of the
        remaining input, the single byte is used.

        Raises:
            KeyError: if a needed single byte has no entry in the vocabulary.
        """
        src_len: int = len(src)
        tokens: list[int] = []
        i: int = 0
        while i < src_len:
            s: bytes = src[i : i + 1]  # fallback: the single byte at i
            if i < src_len - 1:
                s0: int = src[i]
                s1: int = src[i + 1]
                if s1 in self.good[s0]:
                    sss: bytes = src[i : i + self.wlen[s0]]
                    # first candidate that prefixes the input, else keep the single byte
                    s = next(filter(sss.startswith, self.table[s0][s1]), s)
            tokens.append(self.token2idx[s])
            i += len(s)
        return tokens

    def decodeBytes(self, tokens):
        """Concatenate the raw bytes of the given token ids."""
        return b''.join(self.idx2token[i] for i in tokens)

    def encode(self, src: str):
        """Encode a Python string (as UTF-8) into token ids."""
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        """Decode token ids into a string.

        Raises:
            UnicodeDecodeError: if the bytes are not valid UTF-8 (e.g. they end
                in the middle of a multi-byte character).
        """
        return self.decodeBytes(tokens).decode('utf-8')

    def printTokens(self, tokens):
        """Print each token (decoded when valid UTF-8, else raw bytes) with its id."""
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except UnicodeDecodeError:
                pass  # show the raw bytes instead
            print(f'{repr(s)}{i}', end=' ')
        print()

########################################################################################################

In [7]:
# The sampling method has not changed
def sample_logits(out, temperature=1.0, top_p=0.8):
    """Sample a token id from logits with nucleus (top-p) sampling.

    Args:
        out: 1-D torch tensor of logits, one per vocabulary entry.
        temperature: probabilities kept by the top-p filter are raised to
            ``1 / temperature`` before renormalising (values < 1 sharpen).
        top_p: keep the smallest set of most-likely tokens whose cumulative
            probability exceeds ``top_p``. If no prefix exceeds it
            (e.g. ``top_p >= 1``), every token is kept.

    Returns:
        The sampled index, as returned by ``np.random.choice``.
    """
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    exceeds = cumulative_probs > top_p
    # argmax of an all-False array is 0, which would silently turn sampling
    # greedy; in that case keep the whole distribution instead.
    cut_idx = int(np.argmax(exceeds)) if exceeds.any() else len(sorted_probs) - 1
    cutoff = float(sorted_probs[cut_idx])
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        # probs is a NumPy array: ndarray has no .pow(), so use np.power
        probs = np.power(probs, 1.0 / temperature)
    probs = probs / np.sum(probs)
    out = np.random.choice(a=len(probs), p=probs)
    return out

########################################################################################################

In [8]:
import os

# Tokenizer for the RWKV vocabulary (path relative to the notebook).
tokenizer = RWKV_TOKENIZER("./rwkv_vocab_v20230424.txt")

args = types.SimpleNamespace()
# Checkpoint path WITHOUT the '.pth' suffix (RWKV_RNN appends it when loading).
# Set the RWKV_MODEL_NAME environment variable to point at your own copy
# instead of editing this machine-specific default.
args.MODEL_NAME = os.environ.get(
    'RWKV_MODEL_NAME',
    '/data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096',
)
# Architecture hyper-parameters; they must match the checkpoint above.
args.n_layer = 24
args.n_embd = 2048
args.vocab_size = 65536

# Prompt that is prefilled into the RNN state before sampling.
context = "\nDatawhale is "
# Alternative prompt: context = "\nWe found"
NUM_TRIALS = 3          # number of independent continuations to sample
LENGTH_PER_TRIAL = 100  # tokens sampled per continuation
TEMPERATURE = 1.0       # passed to sample_logits
TOP_P = 0.7             # nucleus (top-p) threshold passed to sample_logits

The Channel Mixing of RWKV 6 (see `channel_mixing` in the code below) is unchanged from RWKV 5. The `time_maa_k` here and the `time_mix_k` in RWKV 5 are learnable parameters of the same shape: both are vectors of dimension $D$ (the hidden dimension of the model).

Finch made the following improvements on Time Mixing, the details are as follows:

### Formula

The formula of Finch Time Mixing is as follows:

\begin{align*}
\Box_t &= \text{ddlerp}_{\Box}(x_t, x_{t-1}) W_{\Box}, \quad \Box \in \{r, k, v, g\} \tag{16} \\
d_t &= \text{lora}_d(\text{ddlerp}_d(x_t, x_{t-1})) \tag{17} \\
w_t &= \exp(-\exp(d_t)) \tag{18} \\
\text{wk} \mathbf{v}_t &= \text{diag}(u) \cdot k_t^\top \cdot v_t + \sum_{i=1}^{t-1} \text{diag}\left( \prod_{j=i+1}^{t-1} w_j \right) \cdot k_i^\top \cdot v_i \in \mathbb{R}^{(D/h) \times (D/h)} \tag{19} \\
o_t &= \text{concat} \left(\text{SiLU}(g_t) \odot \text{LayerNorm}(r_t \cdot \text{wk} \mathbf{v}_t) \right) W_o \in \mathbb{R}^D \tag{20}
\end{align*}

### Explanation

- **Learnable vectors and matrices**:
  - $\Box_t$ is computed by data-dependent linear interpolation (ddlerp) of the current and previous inputs, followed by the projection $W_{\Box}$. This applies to the receptance, key, value and gate vectors.
  - $d_t$ is obtained by processing $\text{ddlerp}_d(x_t, x_{t-1})$ with the $\text{lora}_d$ function.
  - $w_t$ is computed from $d_t$ and controls the data-dependent decay. The product over $j$ in Equation 19 is element-wise, so each channel decays at its own rate.

- **Time-mixing calculation**:
  - $\text{wk} \mathbf{v}_t$ sums two parts: the current key-value pair $k_t^\top \cdot v_t$, weighted by the bonus $u$, and the key-value pairs $k_i^\top \cdot v_i$ of all previous time steps, each decayed by the product of the $w_j$ that came after it.
  - The output $o_t$ is formed in four steps: apply $\text{LayerNorm}$ to $r_t \cdot \text{wk} \mathbf{v}_t$ (per head), gate the result element-wise with $\text{SiLU}(g_t)$, concatenate the heads, and project with $W_o$.

- **Recursive form**:
\begin{align*}
\text{wk} \mathbf{v}' &= s + \text{diag}(u) \cdot k^\top \cdot v \tag{21} \\
s' &= \text{diag}(w) \cdot s + k^\top \cdot v \tag{22}
\end{align*}

### Function and role

Unlike Eagle, $w_t$ in Finch is not fixed throughout the sequence. $w_t$ of each channel can change dynamically over time, depending on the data input, which is also the core change of attenuation in Finch.

### Detailed explanation

- **Dynamic decay**:
  - The data-dependent decay introduced by Finch lets $w_t$ of each channel change dynamically based on the current and previous inputs, rather than being a fixed learned vector.
  - The data-dependent part is produced by the new LoRA mechanism and added to a learned base decay vector, which increases the flexibility of the model.

- **Advanced Token Shift**:
  - The new time decay $w_t$ also uses the LoRA mechanism, so $w_t$ of each channel can change based on the mix of the current and previous tokens.

In summary, Finch achieves higher flexibility and accuracy in time mixing by introducing data-dependent dynamic decay mechanism and advanced Token-Shift, enabling the model to better process and fuse information between time steps, thereby improving overall performance.

In [9]:
class RWKV_RNN(MyModule):
    """RWKV-6 ("Finch") language model evaluated as an RNN, one token per call.

    Loads ``args.MODEL_NAME + '.pth'`` on the CPU, casts every tensor to float32
    and exposes the weights as a nested namespace, e.g.
    ``self.w.blocks[0].att.key.weight``. ``args`` must provide ``MODEL_NAME``,
    ``n_layer`` and ``n_embd``.

    Recurrent state: a float tensor of shape
    ``(n_layer * (2 + head_size), n_embd)``. For layer ``i`` the rows starting at
    ``(2 + head_size) * i`` hold:
      +0              previous channel-mixing input (token shift)
      +1              previous time-mixing input (token shift)
      +2 .. +2+S-1    WKV state, viewed as (n_head, head_size, head_size)
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval() # set torch to inference mode

        # NOTE(security): torch.load unpickles the checkpoint, which can run
        # arbitrary code -- only load trusted files (or pass weights_only=True
        # on torch versions that support it).
        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')

        for k in w.keys():
            w[k] = w[k].float() # convert to f32 type
            # drop singleton dims of the time_* parameters (presumably stored as (1, 1, D))
            if      '.time_' in k: w[k] = w[k].squeeze()
            # trailing axis lets the bonus u broadcast over each head's (S, S) WKV matrix
            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)

        self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head

        self.w = types.SimpleNamespace() # set self.w from w
        self.w.blocks = {}
        for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    # numeric path component -> integer key in the blocks dict
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

    def layer_norm(self, x, w):
        """LayerNorm over the embedding dim using ``w.weight`` / ``w.bias``."""
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

    @MyFunction
    def channel_mixing(self, x, state, i:int, time_maa_k, time_maa_r, kw, vw, rw):
        # Channel mixing (feed-forward) for layer i, one token.
        # Token shift: interpolate between x_t and the stored x_{t-1}, then
        # store x_t in row (2+S)*i of `state` (in place).
        i0 = (2+self.head_size)*i+0
        sx = state[i0] - x
        xk = x + sx * time_maa_k
        xr = x + sx * time_maa_r
        state[i0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

    @MyFunction
    def time_mixing(self, x, state, i:int, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
        # Finch time mixing for layer i, one token. Updates this layer's
        # token-shift row and WKV state in `state` in place.
        H = self.n_head
        S = self.head_size

        # token shift: sx = x_{t-1} - x_t, then remember x_t for the next step
        i1 = (2+S)*i+1
        sx = state[i1] - x
        state[i1] = x

        # ddlerp: one shared LoRA (tm_w1 -> tanh -> tm_w2) produces five
        # data-dependent offsets, added to the static mix vectors of w, k, v, r, g
        xxx = x + sx * x_maa
        xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)
        xxx = torch.bmm(xxx, tm_w2).view(5, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)

        xw = x + sx * (w_maa + mw)
        xk = x + sx * (k_maa + mk)
        xv = x + sx * (v_maa + mv)
        xr = x + sx * (r_maa + mr)
        xg = x + sx * (g_maa + mg)

        # data-dependent decay: w = exp(-exp(time_decay + lora_d(xw))),
        # one value per key channel of each head
        w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, S, 1)
        w = torch.exp(-torch.exp(w.float()))

        r = (rw @ xr).view(H, 1, S)
        k = (kw @ xk).view(H, S, 1)
        v = (vw @ xv).view(H, 1, S)
        g = F.silu(gw @ xg)

        s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)

        # WKV recurrence per head: out = r @ (u * k^T v + s);  s' = k^T v + w * s
        a = k @ v
        x = r @ (time_first * a + s)
        s = a + w * s

        state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
        x = x.flatten()

        # per-head normalisation (one group per head), then the SiLU gate
        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)
        return ow @ x

    def forward(self, token, state):
        """Run one token through the model.

        Args:
            token: index into the embedding matrix.
            state: recurrent state tensor (see class docstring), updated in
                place, or ``None`` to start from an all-zero state.

        Returns:
            ``(logits, state)``: one float logit per row of ``head.weight``,
            and the updated state.
        """
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                    att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,
                    att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,
                    att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
                    att.ln_x.weight, att.ln_x.bias)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                    ffn.time_maa_k, ffn.time_maa_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state

In [10]:
# Build the model; RWKV_RNN loads args.MODEL_NAME + '.pth' onto the CPU.
print('\nUsing CPU. Loading {} ...'.format(args.MODEL_NAME))
model = RWKV_RNN(args)

# The prompt prefill (next cell) starts from an empty recurrent state.
init_state = None
print('\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')


Using CPU. Loading /data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096 ...

Preprocessing context (slow version. see v2/rwkv/model.py for fast version)


In [11]:
# Prefill: run the prompt through the RNN one token at a time and keep the
# final logits/state so every trial below starts from the same point.
# Reset first so re-running this cell does not feed the prompt twice.
init_state = None
context_tokens = tokenizer.encode(context)
assert context_tokens, "context must encode to at least one token"
for token in context_tokens:
    init_out, init_state = model.forward(token, init_state)

In [12]:
# Sample NUM_TRIALS independent continuations from the prefilled state.
for TRIAL in range(NUM_TRIALS):
    print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
    all_tokens = []
    out_last = 0  # index of the first sampled token not yet printed
    # clone so each trial starts from the same prefilled logits/state
    out, state = init_out.clone(), init_state.clone()
    for i in range(LENGTH_PER_TRIAL):
        token = sample_logits(out, TEMPERATURE, TOP_P)
        all_tokens += [token]
        try:
            text = tokenizer.decode(all_tokens[out_last:])
            if '\ufffd' not in text: # only print when we have a valid utf-8 string
                print(text, end="", flush=True)
                out_last = i + 1
        except (UnicodeDecodeError, KeyError):
            # UnicodeDecodeError: the pending bytes end mid-character; wait for
            # more tokens. KeyError: a sampled id has no vocab entry
            # (NOTE(review): presumably id 0 / end-of-text -- confirm); nothing
            # more is printed for this trial after that.
            pass
        out, state = model.forward(token, state)
print('\n')



--[ Trial 0 ]----------------- 
Datawhale is ➡️‼️
https://twitter.com/datawhale_cn/status/1463997087819689985
#Data #AI #DataAnalytics #AIOps #DataOps #MachineLearning #DataScience #DataLakeAnalytics #Hadoop #Amazon #Google #AWS #Azure #Dataprep #DevOps #OSS #Linux #Unix #BigData #BigDataOps #DataArchitecture #DataScienceOps #MachineLearningOps

--[ Trial 1 ]----------------- 
Datawhale is 🤓

--[ Trial 2 ]----------------- 
Datawhale is 🤯. They have a solid team, a really good SaaS product and the tools to support their users. That said, I have to take a serious look at the privacy and security of their platform before I buy into their story. I think this is a case of big companies buying into the hype, and they're not taking into account all the realities that go into building a privacy-focused product.
P.S. You can still apply to Datawhale's Program.



Compared with v5, v6 seems to use emoji more often, haha.