In [5]:
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
import torch.nn as nn
from torch.nn import functional as F

# TorchScript aliases: the model class below derives from MyModule and decorates
# its per-layer methods with MyFunction so they are compiled by torch.jit.
MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method

![](./img/01.png)

图1：RWKV架构概述。左侧：时间混合和通道混合块；右上角：作为RNN单元的RWKV时间混合块；中下部：前馈模块中的令牌移位模块和Eagle时间混合；右下角：Finch时间混合中的令牌移位模块。所有形状注释为简洁起见假设为单头。虚线箭头（左侧，右上角）表示在Finch中有连接，但在Eagle中没有。

首先，RWKV 6 相比于 RWKV 5 在 Token Shift 上进行了改进。具体见上方图1的中下部和右下角，它们分别是 RWKV 4/5 的 Token Shift 方式和 RWKV 6 的 Token Shift 方式。

具体内容如下：

### 公式部分

Finch Token Shift中使用的数据依赖线性插值（ddlerp）定义如下：

\begin{align*}
\text{lora}_{\Box}(x) &= \lambda_{\Box} + \tanh(x A_{\Box}) B_{\Box} \tag{14} \\
\text{ddlerp}_{\Box}(a, b) &= a + (b - a) \odot \text{lora}_{\Box}(a + (b - a) \odot \mu_{x}) \tag{15}
\end{align*}

### 解释部分

- **可学习向量和矩阵**：
  - $\mu_{x}$ 和每个 $\lambda_{\Box}$ 引入了维度为 $D$ 的可训练向量。
  - $A_{\Box} \in \mathbb{R}^{D \times 32}$ 和 $B_{\Box} \in \mathbb{R}^{32 \times D}$ 引入了新的可训练权重矩阵。
  - 对于公式中提到的LoRA$_{\omega}$的特殊情况，引入了双倍大小的可训练权重矩阵：$A_{\omega} \in \mathbb{R}^{D \times 64}$ 和 $B_{\omega} \in \mathbb{R}^{64 \times D}$。

- **未来模型扩展**：
  - 图1中右下角显示了一个示意图。
  - 未来7B及更大规模的Finch模型预计将进一步增加这些权重矩阵的大小（可能翻倍或更多）。

### 功能与作用

这种带有数据依赖性的Token Shift新形式旨在扩展模型超越RWKV-4/Eagle风格的Token Shift的能力，使得每个通道分配的新旧数据量现在依赖于当前和前一个时间步的输入。

### 详细解释

- **数据依赖线性插值（ddlerp）**：
  - ddlerp通过公式14和公式15实现，它结合了当前时间步和前一个时间步的信息来计算插值。
  - $\text{lora}_{\Box}(x)$ 由可训练向量 $\lambda_{\Box}$ 加上一个低秩项得到：先计算 $\tanh(x A_{\Box})$，再右乘 $B_{\Box}$。

- **模型能力扩展**：
  - 通过这种数据依赖的Token Shift，Finch模型能够更灵活地处理时间步之间的信息传递，使得模型在处理复杂序列数据时更加精确和高效。

总结来说，Finch在Token Shift上引入了数据依赖的线性插值，利用可训练的向量和矩阵来增强模型的灵活性和能力，使其能够更好地处理时间步之间的信息，从而提高了模型的整体性能。

In [6]:
class RWKV_TOKENIZER():
    """Greedy longest-match byte-level tokenizer driven by a vocabulary file.

    Every line of the vocabulary file has the form
    ``<token id> <python str/bytes literal> <length in bytes>``.
    String literals are stored as their UTF-8 encoding, so all tokens are ``bytes``.

    Attributes built by ``__init__``:
        idx2token / token2idx: id -> bytes and bytes -> id mappings.
        table[b0][b1]: tokens of length >= 2 starting with bytes (b0, b1), in
            reverse file order so that tokens appearing later in the file are tried first.
        good[b0]: set of second bytes b1 for which table[b0][b1] is non-empty.
        wlen[b0]: length of the longest token starting with byte b0.
    """
    table: list[list[list[bytes]]]
    good: list[set[int]]
    wlen: list[int]

    def __init__(self, file_name):
        """Parse the vocabulary at ``file_name`` and precompute the matching tables.

        Raises:
            AssertionError: if a literal is neither str nor bytes, or its byte
                length disagrees with the length column.
        """
        self.idx2token = {}
        ordered_tokens = [] # file order; the file must already be sorted
        # The context manager closes the handle even if reading fails.
        with open(file_name, "r", encoding="utf-8") as vocab_file:
            lines = vocab_file.readlines()
        for l in lines:
            idx = int(l[:l.index(' ')])
            # NOTE(security): eval() executes whatever expression the vocabulary
            # line contains - only load vocabulary files from a trusted source.
            x = eval(l[l.index(' '):l.rindex(' ')])
            x = x.encode("utf-8") if isinstance(x, str) else x
            assert isinstance(x, bytes)
            assert len(x) == int(l[l.rindex(' '):])
            ordered_tokens.append(x)
            self.idx2token[idx] = x

        self.token2idx = {v: int(k) for k, v in self.idx2token.items()}

        # precompute some tables for fast matching
        self.table = [[[] for j in range(256)] for i in range(256)]
        self.good = [set() for i in range(256)]
        self.wlen = [0 for i in range(256)]

        for s in reversed(ordered_tokens): # reverse order - match longer tokens first
            if len(s) >= 2:
                s0 = int(s[0])
                s1 = int(s[1])
                self.table[s0][s1].append(s)
                self.wlen[s0] = max(self.wlen[s0], len(s))
                self.good[s0].add(s1)

    def encodeBytes(self, src: bytes) -> list[int]:
        """Tokenize ``src`` greedily, returning the list of token ids.

        At each position the first multi-byte token in ``table`` that prefixes
        the remaining input is taken; otherwise the single byte is used.

        Raises:
            KeyError: if a single byte that has to be emitted is not in the vocabulary.
        """
        src_len: int = len(src)
        tokens: list[int] = []
        i: int = 0
        while i < src_len:
            s: bytes = src[i : i + 1]

            if i < src_len - 1:
                s1: int = int(src[i + 1])
                s0: int = int(src[i])
                if s1 in self.good[s0]:
                    # Only the next wlen[s0] bytes can possibly match a token.
                    sss: bytes = src[i : i + self.wlen[s0]]
                    # Fall back to the single byte when no candidate is a prefix.
                    s = next(filter(sss.startswith, self.table[s0][s1]), s)
            tokens.append(self.token2idx[s])
            i += len(s)

        return tokens

    def decodeBytes(self, tokens):
        """Concatenate the byte strings of ``tokens`` (KeyError on unknown ids)."""
        return b''.join(map(lambda i: self.idx2token[i], tokens))

    def encode(self, src: str):
        """Tokenize a text string via its UTF-8 encoding."""
        return self.encodeBytes(src.encode("utf-8"))

    def decode(self, tokens):
        """Decode token ids to text; raises UnicodeDecodeError on invalid UTF-8."""
        return self.decodeBytes(tokens).decode('utf-8')

    def printTokens(self, tokens):
        """Print each token's repr immediately followed by its id, space separated."""
        for i in tokens:
            s = self.idx2token[i]
            try:
                s = s.decode('utf-8')
            except UnicodeDecodeError:
                # Not valid UTF-8 on its own: show the raw bytes repr instead.
                pass
            print(f'{repr(s)}{i}', end=' ')
        print()

########################################################################################################

In [7]:
#采样方式没有变化
def sample_logits(out, temperature=1.0, top_p=0.8):
    """Sample one index from logits with top-p (nucleus) filtering and temperature.

    Args:
        out: logits as a 1-D tensor that ``.numpy()`` can convert (the sort,
            reversal and ``np.random.choice`` below all operate on one axis).
        temperature: after filtering, probabilities are raised to
            ``1 / temperature`` and renormalised; 1.0 leaves them unchanged.
        top_p: probabilities smaller than the one at which the descending
            cumulative sum first exceeds ``top_p`` are zeroed.

    Returns:
        The sampled index as returned by ``np.random.choice``.
    """
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        # probs is a NumPy array here; ndarray has no .pow() method, so the
        # tensor-style call raised AttributeError for any temperature != 1.0.
        probs = np.power(probs, 1.0 / temperature)
    probs = probs / np.sum(probs)
    out = np.random.choice(a=len(probs), p=probs)
    return out

########################################################################################################

In [8]:
# Tokenizer backed by the vocabulary file expected next to the notebook.
tokenizer = RWKV_TOKENIZER("./rwkv_vocab_v20230424.txt")

# Model settings. MODEL_NAME is the checkpoint path WITHOUT the '.pth' suffix
# (the model class appends it when loading).
# NOTE(review): this is a machine-specific absolute path - adjust it before running elsewhere.
args = types.SimpleNamespace(
    MODEL_NAME='/data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096',
    n_layer=24,
    n_embd=2048,
    vocab_size=65536,
)

# Prompt that seeds every generation trial.
context = "\nDatawhale is "
# Alternative (Chinese) prompt: context = "\n我们发现"

# Generation settings.
NUM_TRIALS = 3          # number of independent continuations
LENGTH_PER_TRIAL = 100  # tokens sampled per continuation
TEMPERATURE = 1.0
TOP_P = 0.7

相比于RWKV 5的Channel Mixing（见下面代码）来说，RWKV6的Channel Mixing没有变化，这里的`time_maa_k`和RWKV 5中的`time_mix_k`是相同形状的可学习参数，都是一个维度为D（模型的隐藏层维度）的张量。

Finch在时间混合（Time Mixing）上做了以下改进，具体内容如下：

### 公式部分

Finch时间混合的公式如下：

\begin{align*}
\Box_t &= \text{ddlerp}_{\Box}(x_t, x_{t-1}) W_{\Box}, \quad \Box \in \{r, k, v, g\} \tag{16} \\
d_t &= \text{lora}_d(\text{ddlerp}_d(x_t, x_{t-1})) \tag{17} \\
w_t &= \exp(-\exp(d_t)) \tag{18} \\
\text{wk} \mathbf{v}_t &= \text{diag}(u) \cdot k_t^\top \cdot v_t + \sum_{i=1}^{t-1} \text{diag}\left( \bigodot_{j=i+1}^{t-1} w_j \right) \cdot k_i^\top \cdot v_i \in \mathbb{R}^{(D/h) \times (D/h)} \tag{19} \\
o_t &= \text{concat} \left( \text{SiLU}(g_t) \odot \text{LayerNorm}(r_t \cdot \text{wk} \mathbf{v}_t) \right) W_o \in \mathbb{R}^D \tag{20}
\end{align*}

### 解释部分

- **可学习向量和矩阵**：
  - $\Box_t$ 是先对 $x_t$ 和 $x_{t-1}$ 做数据依赖线性插值（ddlerp）、再乘以对应的权重矩阵 $W_{\Box}$ 得到的，适用于接受度（receptance）、键（key）、值（value）和门控向量（gate vectors）。
  - $d_t$ 是通过 $\text{lora}_d$ 函数对 $\text{ddlerp}_d(x_t, x_{t-1})$ 进行处理得到的。
  - $w_t$ 是由 $d_t$ 计算得到的，用于控制衰减的动态变化。

- **时间混合计算**：
  - $\text{wk} \mathbf{v}_t$ 由两部分相加得到：当前时间步的键值外积 $k_t^\top \cdot v_t$（由 $u$ 加权），以及所有之前时间步的键值外积 $k_i^\top \cdot v_i$ 的加权和，其中第 $i$ 步的权重是从第 $i+1$ 步到第 $t-1$ 步各衰减 $w_j$ 的逐元素累积乘积。
  - 输出 $o_t$：先将 $\text{SiLU}(g_t)$ 与 $\text{LayerNorm}(r_t \cdot \text{wk} \mathbf{v}_t)$ 逐元素相乘（$\odot$），再把各个头的结果拼接（concat）起来，最后乘以输出矩阵 $W_o$。

- **递归形式**：
  \begin{align*}
  \text{wk} \mathbf{v}' &= s + \text{diag}(u) \cdot k^\top \cdot v \tag{21} \\
  s' &= \text{diag}(w) \cdot s + k^\top \cdot v \tag{22}
  \end{align*}

### 功能与作用

与Eagle不同，Finch中的 $w_t$ 不是在整个序列中固定的。每个通道的 $w_t$ 可以随时间动态变化，具体取决于数据输入，这也是Finch中衰减的核心变化。

### 详细解释

- **动态衰减**：
  - Finch引入的数据依赖衰减使得每个通道的 $w_t$ 可以根据当前和之前的输入动态变化，而不是固定的学习向量。
  - 这种动态衰减机制通过新的LoRA机制应用到学习向量上，增加了模型的灵活性。

- **高级Token-Shift**：
  - 新的时间衰减 $w_t$ 进一步应用了LoRA机制，允许每个通道的 $w_t$ 基于当前和之前的令牌混合来变化。

总结来说，Finch在时间混合上通过引入数据依赖的动态衰减机制和高级Token-Shift，实现了更高的灵活性和精确度，使模型能够更好地处理和融合时间步之间的信息，从而提高了整体性能。

In [9]:
class RWKV_RNN(MyModule):
    """RWKV-6 model evaluated as an RNN: one token in, logits and updated state out.

    The checkpoint ``args.MODEL_NAME + '.pth'`` is loaded on the CPU, converted
    to float32 and unpacked into nested namespaces under ``self.w``.

    All recurrent state lives in one 2-D tensor of shape
    ``(n_layer * (2 + head_size), n_embd)``. With ``S = head_size``, layer ``i`` owns:

    * row ``(2+S)*i + 0``: previous input of its channel mixing,
    * row ``(2+S)*i + 1``: previous input of its time mixing,
    * rows ``(2+S)*i + 2 .. (2+S)*(i+1) - 1``: its wkv state, viewed as ``(H, S, S)``.
    """

    def __init__(self, args):
        """Load the weights and derive ``n_head`` / ``head_size`` from them.

        Args:
            args: namespace providing ``MODEL_NAME`` (checkpoint path without
                the ``.pth`` suffix), ``n_layer`` and ``n_embd``.
        """
        super().__init__()
        self.args = args
        self.eval() # set torch to inference mode

        # NOTE(security): torch.load unpickles the file; only open trusted
        # checkpoints (consider weights_only=True where the installed torch supports it).
        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')

        for k in w.keys():
            w[k] = w[k].float() # convert to f32 type
            # Drop the size-1 dimensions of the time_* parameters.
            if      '.time_' in k: w[k] = w[k].squeeze()
            # time_faaaa gets a trailing axis so it can be viewed per head in time_mixing
            # (presumably stored as (n_head, head_size) - confirm against the checkpoint).
            if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)

        self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
        self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head

        self.w = types.SimpleNamespace() # set self.w from w
        self.w.blocks = {}
        for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    # Numeric path components index the self.w.blocks dict.
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

    def layer_norm(self, x, w):
        """Layer-normalise ``x`` over ``n_embd`` features using ``w.weight`` / ``w.bias``."""
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

    @MyFunction
    def channel_mixing(self, x, state, i:int, time_maa_k, time_maa_r, kw, vw, rw):
        """Channel-mixing block of layer ``i``.

        Mixes ``x`` with the previous input stored in ``state`` (token shift),
        stores ``x`` as the new previous input (in place) and returns
        ``sigmoid(rw @ xr) * (vw @ relu(kw @ xk)**2)``.
        """
        i0 = (2+self.head_size)*i+0
        sx = state[i0] - x          # previous input minus current input
        xk = x + sx * time_maa_k
        xr = x + sx * time_maa_r
        state[i0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

    @MyFunction
    def time_mixing(self, x, state, i:int, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2, time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
        """Time-mixing block of layer ``i`` with data-dependent token shift and decay.

        Updates, in place, both the previous-input row and the wkv rows that
        layer ``i`` owns in ``state``, and returns the block output ``ow @ (...)``.
        """
        H = self.n_head
        S = self.head_size

        i1 = (2+S)*i+1
        sx = state[i1] - x          # previous input minus current input
        state[i1] = x
        # Low-rank, input-dependent offsets for the five mixing coefficients (w, k, v, r, g).
        xxx = x + sx * x_maa
        xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)
        xxx = torch.bmm(xxx, tm_w2).view(5, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)

        xw = x + sx * (w_maa + mw)
        xk = x + sx * (k_maa + mk)
        xv = x + sx * (v_maa + mv)
        xr = x + sx * (r_maa + mr)
        xg = x + sx * (g_maa + mg)

        # Per-channel decay, recomputed from the input at every step.
        w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, S, 1)
        w = torch.exp(-torch.exp(w.float()))

        r = (rw @ xr).view(H, 1, S)
        k = (kw @ xk).view(H, S, 1)
        v = (vw @ xv).view(H, 1, S)
        g = F.silu(gw @ xg)

        s = state[(2+S)*i+2:(2+S)*(i+1), :].reshape(H, S, S)

        a = k @ v                           # per-head outer product, (H, S, S)
        x = r @ (time_first * a + s)        # current token weighted by time_first
        s = a + w * s                       # decay the old state, add the current token

        state[(2+S)*i+2:(2+S)*(i+1), :] = s.reshape(S, -1)
        x = x.flatten()

        x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps = 64e-5).squeeze(0) * g # same as gn(x/8, eps=1e-5)
        return ow @ x

    def forward(self, token, state):
        """Run one token through all layers.

        Args:
            token: index into the embedding table ``self.w.emb.weight``.
            state: recurrent state tensor (see the class docstring), or ``None``
                to start from an all-zero state. It is updated in place.

        Returns:
            ``(logits, state)``: float logits from the output head and the
            (in-place updated, or freshly allocated) state tensor.
        """
        with torch.no_grad():
            # Identity test: `== None` on a tensor relies on Tensor.__eq__ semantics.
            if state is None:
                state = torch.zeros(self.args.n_layer * (2+self.head_size), self.args.n_embd)

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                    att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r, att.time_maa_g, att.time_maa_w1, att.time_maa_w2,
                    att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,
                    att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight, att.output.weight,
                    att.ln_x.weight, att.ln_x.bias)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                    ffn.time_maa_k, ffn.time_maa_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state

In [10]:
# Build the model; the constructor loads the checkpoint onto the CPU.
print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)

print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
# None makes the first forward() call allocate an all-zero recurrent state.
init_state = None


Using CPU. Loading /data1/ckw/RWKV-x060-World-1B6-v2.1-20240328-ctx4096 ...

Preprocessing context (slow version. see v2/rwkv/model.py for fast version)


In [11]:
# Feed the prompt one token at a time. After the loop, init_out holds the logits
# following the last prompt token and init_state the matching recurrent state.
# NOTE(review): init_out stays undefined if the prompt encodes to no tokens.
for token in tokenizer.encode(context):
    init_out, init_state = model.forward(token, init_state)

In [12]:
# Sample NUM_TRIALS independent continuations of the prompt.
for TRIAL in range(NUM_TRIALS):
    print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
    all_tokens = []
    out_last = 0  # index of the first token that has not been printed yet
    # forward() updates the state tensor in place, so each trial works on its
    # own copy of the post-prompt logits and state.
    out, state = init_out.clone(), init_state.clone()
    for i in range(LENGTH_PER_TRIAL):
        token = sample_logits(out, TEMPERATURE, TOP_P)
        all_tokens += [token]
        try:
            tmp = tokenizer.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
                print(tmp, end="", flush=True)
                out_last = i + 1
        except (UnicodeError, KeyError):
            # Best effort: the pending tokens may end inside a multi-byte
            # character (or include an id without a vocabulary entry); keep
            # them and retry after the next token. Unlike a bare `except`,
            # this no longer swallows KeyboardInterrupt.
            pass
        out, state = model.forward(token, state)
print('\n')



--[ Trial 0 ]----------------- 
Datawhale is ➡️‼️
https://twitter.com/datawhale_cn/status/1463997087819689985
#Data #AI #DataAnalytics #AIOps #DataOps #MachineLearning #DataScience #DataLakeAnalytics #Hadoop #Amazon #Google #AWS #Azure #Dataprep #DevOps #OSS #Linux #Unix #BigData #BigDataOps #DataArchitecture #DataScienceOps #MachineLearningOps

--[ Trial 1 ]----------------- 
Datawhale is 🤓

--[ Trial 2 ]----------------- 
Datawhale is 🤯. They have a solid team, a really good SaaS product and the tools to support their users. That said, I have to take a serious look at the privacy and security of their platform before I buy into their story. I think this is a case of big companies buying into the hype, and they're not taking into account all the realities that go into building a privacy-focused product.
P.S. You can still apply to Datawhale's Program.



v6 和 v5 相比，感觉更喜欢使用 emoji 了，哈哈。