In [1]:
import numpy as np
import math, json, time, types, copy, sys, os
import torch
from torch.nn import functional as F
import torch.nn as nn

from transformers import PreTrainedTokenizerFast

np.set_printoptions(precision=4, suppress=True, linewidth=200)

In [2]:
RUN_DEVICE = 'cpu'  # either 'cpu' or 'cuda'

# Architecture of the checkpoint being loaded (RWKV-3 169M: 12 layers x 768 channels).
ctx_len = 768
n_layer, n_embd = 12, 768
# Alternative size, presumably for the 430M checkpoint listed below:
# n_layer, n_embd = 24, 1024

# ---> download RWKV-3 169M model from https://huggingface.co/BlinkDL/rwkv-3-pile-169m/tree/main
# Checkpoint path WITHOUT the '.pth' suffix (the loader appends it).
# MODEL_NAME = '/data1/ckw/RWKV-3-Pile-430M-20220817-10602'
MODEL_NAME = '/data1/ckw/RWKV-3-Pile-20220720-10704'

# Small constant added to the WKV denominator to avoid division by zero.
K_EPS = 1e-8

# Tokenizer settings.
vocab_size = 50277
VOCAB_NAME = '20B_tokenizer.json'

print(f'\n* running on {RUN_DEVICE}')


* running on cpu


In the v3 version, some modifications have been made to the TimeMix and ChannelMix parts of the RWKV model. The main changes are as follows:

1. **ChannelMix**:
- In the v2 version, ChannelMix already used time mixing (token shift): each channel of the input was blended with the same channel of the previous time step, controlled by a single learned `time_mix` coefficient shared by the key and receptance paths.
- In the v3 version, time mixing is kept but split into two learned per-channel coefficients, `time_mix_k` and `time_mix_r`, so the key input `xk` and the receptance input `xr` are mixed independently. These two parameters exist in v3 but not in v2.

2. **TimeMix**:
- In the v2 version, TimeMix likewise used time mixing, with a single `time_mix` parameter controlling how much of the previous time step is blended into the input.
- In the v3 version, that single `time_mix` parameter is replaced by three learned per-channel coefficients, `time_mix_k`, `time_mix_v`, and `time_mix_r`. Each one blends the current input with the time-shifted input (`time_shift`) to produce a separate input (`xk`, `xv`, `xr`) for the key, value, and receptance projections.
- In particular, the values now get their own time-mixed input `xv` (via `time_mix_v`) instead of sharing one mixed input with the keys and the receptance.

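Overall, v3 keeps the v2 structure of time mixing and channel mixing but makes the token-shift mixing more fine-grained: every projection gets its own learned mixing coefficients. This costs only a few extra per-channel parameters and gives the model more flexibility. v3 also adds an extra LayerNorm (`ln0`) on the embedding in the first block.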

In [3]:
class RWKV_ChannelMix(nn.Module):
    """RWKV-3 channel-mixing (feed-forward) sub-layer, sequence-level form.

    Each input channel is blended with the same channel at the previous time
    step (token shift) using learned per-channel weights, separately for the
    key path and the receptance path. The key path is a squared-ReLU MLP with
    a 4x hidden expansion, and its output is gated by sigmoid(receptance).
    """

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        # Pads one zero step at the front of dim -2 and crops one at the end,
        # i.e. shifts the sequence by one time step (assuming (B, T, C) input).
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        mix_shape = (1, 1, n_embd)
        self.time_mix_k = nn.Parameter(torch.ones(*mix_shape))
        self.time_mix_r = nn.Parameter(torch.ones(*mix_shape))

        ffn_width = n_embd * 4
        self.key = nn.Linear(n_embd, ffn_width, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(ffn_width, n_embd, bias=False)

    def forward(self, x):
        prev = self.time_shift(x)
        mk, mr = self.time_mix_k, self.time_mix_r
        # Interpolate between the current step and the previous (shifted) step.
        key_in = x * mk + prev * (1 - mk)
        gate_in = x * mr + prev * (1 - mr)

        hidden = torch.relu(self.key(key_in)).square()
        gate = torch.sigmoid(self.receptance(gate_in))
        return gate * self.value(hidden)

In [4]:
class RWKV_TimeMix(nn.Module):
    """RWKV-3 time-mixing (attention-like) sub-layer, sequence-level form.

    Token-shifted inputs feed separate key/value/receptance projections. The
    WKV term for the whole sequence is computed at once with a depthwise,
    causal 1-D convolution whose kernel holds a per-channel exponential decay
    over past tokens plus a separate weight (`time_first`) for the current one.
    """
    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id
        # Per-channel decay rate. In forward(), the kv from m+1 steps back is
        # weighted by exp(-m * exp(time_decay)) for m >= 0.
        self.time_decay = nn.Parameter(torch.ones(n_embd, 1))
        # Exponents -(ctx_len-2), ..., -1, 0 (oldest past token -> previous token),
        # shape (1, ctx_len-1). A plain tensor, not a registered buffer, so
        # forward() moves it to the parameters' device explicitly.
        self.time_curve = torch.tensor([-(ctx_len - 2 - i) for i in range(ctx_len-1)]).unsqueeze(0)
        # Log-weight for the current token's kv; initialised to log(0.3).
        self.time_first = nn.Parameter(torch.ones(n_embd, 1) * math.log(0.3))
        
        # Pads one zero step at the front of dim -2 and crops one at the end:
        # a one-step shift along time for (B, T, C) input.
        self.time_shift = nn.ZeroPad2d((0,0,1,-1))
        # Per-channel interpolation weights between the current and previous
        # step, one set each for the key, value and receptance inputs.
        self.time_mix_k = nn.Parameter(torch.ones(1,1,n_embd))
        self.time_mix_v = nn.Parameter(torch.ones(1,1,n_embd))
        self.time_mix_r = nn.Parameter(torch.ones(1,1,n_embd))

        self.key = nn.Linear(n_embd, n_embd, bias=False)
        self.value = nn.Linear(n_embd, n_embd, bias=False)
        self.receptance = nn.Linear(n_embd, n_embd, bias=False)

        self.output = nn.Linear(n_embd, n_embd, bias=False)

    def forward(self, x):
        """Apply time mixing to x of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.size()

        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # k and v are transposed to (B, C, T), the channel-first layout conv1d expects.
        k = self.key(xk).transpose(-1, -2)
        v = self.value(xv).transpose(-1, -2)
        r = self.receptance(xr)

        # Clamp before exp so large keys cannot overflow to inf.
        k = torch.clamp(k, max=60)
        k = torch.exp(k)

        kv = k * v

        # Log-kernel of shape (C, ctx_len): decay terms for past positions,
        # followed by time_first as the last column (the current token).
        self.time_w = torch.cat([torch.exp(self.time_decay) * self.time_curve.to(self.time_decay.device), self.time_first], dim=-1)
        w = torch.exp(self.time_w)
        
        # Keep the last T taps -> depthwise kernel (C, 1, T).
        # NOTE(review): assumes T <= ctx_len; otherwise the kernel is shorter than T.
        w = w[:,-T:].unsqueeze(1)
        # Left-pad by T-1 so output step t only sees inputs at steps <= t (causal).
        wkv = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(kv), w, groups=C)
        wk = F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w, groups=C) + K_EPS

        # Weighted average of v (wkv / wk), gated by sigmoid(r), back to (B, T, C).
        rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
        
        rwkv = self.output(rwkv)
        return rwkv

In this code, `xk`, `xv`, and `xr` are the time-mixed inputs used to compute the key, value, and receptance vectors, respectively. In v3 each of the three is mixed with its own coefficient (`time_mix_k`, `time_mix_v`, `time_mix_r`), whereas in v2 a single mixing step produced one input shared by all three.

Specifically, `xk` is the time-mixed input for the key projection, `xv` for the value projection, and `xr` for the receptance projection. Mixing them separately makes the model more flexible, because each projection can learn how much of the previous time step it should see.

Computing keys, values, and a gating vector from separate projections is also standard elsewhere. For example, in the Transformer self-attention mechanism the query, key, and value vectors are computed by separate linear projections. This separation improves the flexibility and expressiveness of the model, which is why it is widely used in practice.

### RWKV Block

RWKV Block is a basic module that combines TimeMix and ChannelMix operations. Each module in Block (TimeMix and ChannelMix) processes input data through normalization and residual connection to enhance the stability and performance of the model.

### Main components and operations

1. **LayerNorm**: used to normalize input and enhance the stability of training.

- `self.ln1` and `self.ln2` normalize the input before TimeMix and ChannelMix, respectively.

2. **TimeMix**: Combine the information of the current time step and the previous time step to capture the time dependency.

- `self.att = RWKV_TimeMix(layer_id)` initializes the TimeMix module.

3. **ChannelMix**: Mix between different channels to enhance the expressiveness of the model.
- `self.ffn = RWKV_ChannelMix(layer_id)` initializes the channel mixing module.

4. **Residual connection**: By adding the output of the mixing operation back to the original input, it maintains the information flow and enhances the gradient propagation ability of the model.

With this design, RWKV's Block processes sequence data efficiently, combines time and channel information, and improves the performance of the model.

In [5]:
class Block(nn.Module):
    """One RWKV-3 layer: pre-LayerNorm TimeMix and ChannelMix, each with a residual add.

    The first layer (layer_id == 0) additionally owns `ln0`, a LayerNorm that
    is applied to the raw input before anything else.
    """

    def __init__(self, layer_id):
        super().__init__()
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        is_first = layer_id == 0
        if is_first:
            # Extra initial normalization, only in the first block.
            self.ln0 = nn.LayerNorm(n_embd)

        self.att = RWKV_TimeMix(layer_id)
        self.ffn = RWKV_ChannelMix(layer_id)

    def forward(self, x):
        h = self.ln0(x) if self.layer_id == 0 else x
        h = h + self.att(self.ln1(h))
        return h + self.ffn(self.ln2(h))

Next, the main parts of the RWKV model are implemented:

1. **Model loading and preprocessing**: The code loads the model weights and preprocesses the time-related weights.
2. **LayerNorm**: Layer normalization is implemented in the `LN` method.
3. **Feedforward network (FF) and self-attention (SA)**: The `FF` method implements the feedforward network, and the `SA` method implements the attention-like time mixing. They are the one-token-at-a-time (RNN) counterparts of ChannelMix and TimeMix, respectively.
4. **Running the model**: The `run` method implements the overall forward pass. It embeds the latest token, passes it through each layer in turn, and returns the output logits. This is the model's inference process.

In [11]:
time_buf = {}  # NOTE(review): not referenced anywhere in the visible code


class RWKV_RNN():
    """RNN-mode (one token per call) RWKV-3 inference over a loaded checkpoint.

    Weights from ``MODEL_NAME + '.pth'`` are stored in a nested
    ``types.SimpleNamespace`` tree that mirrors the state-dict keys: dotted
    components become attributes, and numeric components index a dict (e.g.
    ``blocks.3.att.key.weight`` -> ``self.w.blocks[3].att.key.weight``).
    The recurrent state lives in ``self.xx`` / ``self.aa`` / ``self.bb``, keyed
    by sub-layer name (``'att.{i}'`` / ``'ffn.{i}'``).
    """

    def __init__(self, MODEL_NAME=MODEL_NAME):
        print('\nloading RWKV-RNN', MODEL_NAME)
        self.ctx_len = ctx_len
        self.n_layer = n_layer
        self.n_embd = n_embd
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=VOCAB_NAME)

        self.w = types.SimpleNamespace()

        w = torch.load(MODEL_NAME + '.pth', map_location=torch.device(RUN_DEVICE))

        for x in w.keys():
            if '.time_' in x:
                # Drop the broadcasting singleton dims used by the sequence-level modules.
                w[x] = w[x].squeeze()
            if '.time_decay' in x:
                # Pre-compute the per-step decay factor exp(-exp(time_decay)) used by SA().
                w[x] = torch.exp(-torch.exp(w[x]))
            if '.time_first' in x:
                # Pre-compute the current-token weight exp(time_first) used by SA().
                w[x] = torch.exp(w[x])

            # Insert the tensor into the attribute tree, creating nodes on the way:
            # a component followed by a digit becomes a dict, others a SimpleNamespace.
            xx = x.split('.')
            here = self.w
            for i in range(len(xx)):
                if xx[i].isdigit():
                    ii = int(xx[i])
                    if ii not in here:
                        here[ii] = types.SimpleNamespace()
                    here = here[ii]
                else:
                    if i == len(xx) - 1:
                        setattr(here, xx[i], w[x])
                    elif not hasattr(here, xx[i]):
                        if xx[i+1].isdigit():
                            setattr(here, xx[i], {})
                        else:
                            setattr(here, xx[i], types.SimpleNamespace())
                    here = getattr(here, xx[i])

        self.clear()

    def clear(self):
        """Reset the recurrent state; it is re-created lazily on the next run()."""
        self.xx = {}
        self.aa = {}
        self.bb = {}

    def save(self, target):
        """Deep-copy the recurrent state into attributes of ``target``."""
        target.xx = copy.deepcopy(self.xx)
        target.aa = copy.deepcopy(self.aa)
        target.bb = copy.deepcopy(self.bb)

    def load(self, target):
        """Restore the recurrent state from a deep copy of ``target``'s attributes."""
        self.xx = copy.deepcopy(target.xx)
        self.aa = copy.deepcopy(target.aa)
        self.bb = copy.deepcopy(target.bb)

    def LN(self, xx, w):
        """LayerNorm over the last dimension (size ``self.n_embd``) using ``w.weight`` / ``w.bias``."""
        return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias)

    def FF(self, xx, w, name):
        """ChannelMix for a single token (RNN form of RWKV_ChannelMix).

        Returns ``sigmoid(W_r @ xr) * (W_v @ relu(W_k @ xk) ** 2)``, where ``xk`` and
        ``xr`` blend ``xx`` with the previous token's input stored under ``name``.
        """
        if name not in self.xx:
            # The "previous token" starts as zeros in the same dtype/device as the
            # input, so non-float32 checkpoints do not hit a dtype mismatch below.
            self.xx[name] = torch.zeros_like(xx)
        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)

        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)
        k = torch.square(torch.relu(w.key.weight @ xk))
        kv = w.value.weight @ k

        return r * kv

    def SA(self, xx, w, name):
        """TimeMix for a single token (RNN form of RWKV_TimeMix).

        With ``u = w.time_first`` and ``d = w.time_decay`` (both exponentiated at
        load time), ``k = exp(min(W_k @ xk, 60))`` and ``v = W_v @ xv``:

            a = aa + u * k * v          b = bb + u * k
            out = W_o @ (sigmoid(W_r @ xr) * a / (b + K_EPS))
            aa <- d * aa + k * v        bb <- d * bb + k
        """
        if name not in self.xx:
            # Zero state matching the input's dtype/device (see FF).
            self.xx[name] = torch.zeros_like(xx)
            self.aa[name] = torch.zeros_like(xx)
            self.bb[name] = torch.zeros_like(xx)

        xk = xx * w.time_mix_k + self.xx[name] * (1 - w.time_mix_k)
        xv = xx * w.time_mix_v + self.xx[name] * (1 - w.time_mix_v)
        xr = xx * w.time_mix_r + self.xx[name] * (1 - w.time_mix_r)

        self.xx[name] = xx

        r = torch.sigmoid(w.receptance.weight @ xr)

        # Clamp before exp so large keys cannot overflow to inf.
        k = torch.exp(torch.clamp(w.key.weight @ xk, max=60))
        v = w.value.weight @ xv
        kv = k * v

        # Current token contributes with weight time_first; the stored sums hold
        # the decayed contributions of all earlier tokens.
        a = self.aa[name] + w.time_first * kv
        b = self.bb[name] + w.time_first * k
        self.aa[name] = w.time_decay * self.aa[name] + kv
        self.bb[name] = w.time_decay * self.bb[name] + k

        rwkv = r * a / (b + K_EPS)

        return w.output.weight @ rwkv

    def run(self, ctx):
        """Advance the RNN by one token and return the next-token logits as a list.

        Only ``ctx[-1]`` is consumed; earlier tokens must already have been fed
        by previous calls (their effect lives in the recurrent state).
        """
        w = self.w
        x = w.emb.weight[ctx[-1]]

        # Compared with v2, RWKV-3 adds an initial LayerNorm (blocks[0].ln0) on the embedding.
        x = self.LN(x, w.blocks[0].ln0)
        for i in range(self.n_layer):
            block = w.blocks[i]
            x = x + self.SA(self.LN(x, block.ln1), block.att, f'att.{i}')
            x = x + self.FF(self.LN(x, block.ln2), block.ffn, f'ffn.{i}')

        x = self.LN(x, w.ln_out)

        x = w.head.weight @ x
        return x.tolist()

In [12]:
# Sampling / generation settings. The device itself is selected by RUN_DEVICE
# in the configuration cell (CPU by default).

TEMPERATURE = 1.0  # exponent 1/TEMPERATURE on the filtered probabilities; 1.0 = unchanged
TOP_P = 0.7        # nucleus-sampling probability mass to keep

DEBUG_DEBUG = False   # True: a single trial, no tokens generated, prefill logits printed
LENGTH_OF_EACH = 333  # number of tokens generated per trial
NUM_TRIALS = 3        # number of independent continuations of `context`

context = '\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.'

##############################################################################################################

In [13]:
# Load the checkpoint and tokenizer using the default MODEL_NAME; the recurrent state starts cleared.
model = RWKV_RNN()


loading RWKV-RNN /data1/ckw/RWKV-3-Pile-20220720-10704


Next, we sample a new token from the output logits. The function implements **temperature scaling** and **nucleus sampling (top-p sampling)**. The specific steps are as follows:

1. **Softmax conversion**: Convert the logits output by the model to a probability distribution through the softmax function.
2. **Sorting and cumulative probability calculation**: Sort the probabilities from high to low and calculate the cumulative probability distribution.
3. **Nucleus sampling**:
- Find the first position in the sorted list where the cumulative probability exceeds `top_p`; the probability at that position is the cutoff value `cutoff`.
- Set all probabilities below the cutoff to 0, so only the smallest set of most-likely tokens whose total probability exceeds `top_p` is kept.
4. **Temperature adjustment**: If `temperature` is not 1, raise the remaining probabilities to the power `1 / temperature`. This makes the distribution sharper (`temperature < 1`) or smoother (`temperature > 1`).
5. **Sampling**: Sample a value from the adjusted probability distribution and return the corresponding index.

This method is particularly commonly used in text generation tasks. By adjusting the `temperature` and `top_p` parameters, the diversity and quality of the generated text can be controlled.

The sampling method between v2 and v3 has not changed.

In [14]:
def sample_logits(out, temperature=1.0, top_p=None):
    """Draw one token index from logits using nucleus (top-p) sampling.

    Args:
        out: 1-D logits (list, NumPy array or tensor).
        temperature: after filtering, probabilities are raised to the power
            ``1 / temperature``; values < 1 sharpen, values > 1 flatten.
        top_p: keep the smallest set of most-likely tokens whose cumulative
            probability exceeds ``top_p``. ``None`` or a value >= 1 disables
            the filter and keeps the full distribution.

    Returns:
        A 0-d ``torch.Tensor`` holding the sampled index (use ``.item()``).
    """
    logits = torch.as_tensor(out)
    if not logits.is_floating_point():
        logits = logits.float()
    # Convert the logits into a probability distribution.
    probs = F.softmax(logits, dim=-1)

    if top_p is not None and top_p < 1.0:
        # Sort by probability from high to low and accumulate.
        sorted_probs, _ = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
        exceeds = cumulative_probs > top_p
        # First sorted position whose cumulative mass exceeds top_p; if float
        # rounding leaves none, cut at the last position (i.e. keep everything).
        cut_idx = int(np.argmax(exceeds)) if exceeds.any() else len(exceeds) - 1
        cutoff = float(sorted_probs[cut_idx])
        # Drop every token less likely than the cutoff.
        probs[probs < cutoff] = 0

    if temperature != 1.0:
        probs = probs.pow(1.0 / temperature)

    # torch.multinomial accepts unnormalised non-negative weights.
    return torch.multinomial(probs, num_samples=1)[0]


In [15]:
for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS):
    # Fresh list of prompt token ids each trial (generated ids are appended to it).
    ctx = model.tokenizer.encode(context)
    src_len = len(ctx)
    if src_len == 0:
        raise ValueError('context must encode to at least one token')
    print(context, end='')

    model.clear()
    if TRIAL == 0:
        # Prefill: feed the prompt one token at a time to build the RNN hidden
        # state, keeping the logits produced after the last prompt token.
        init_state = types.SimpleNamespace()
        for i in range(src_len):
            x = ctx[:i+1]
            if i == src_len - 1:
                init_state.out = model.run(x)
            else:
                model.run(x)
        model.save(init_state)
    else:
        # Later trials restore the cached post-prompt state instead of re-running the prompt.
        model.load(init_state)

    if DEBUG_DEBUG:
        out = init_state.out
        print('\n', np.array(x), '==>', np.array(out), np.max(out), np.min(out))

    for i in range(src_len, src_len + (0 if DEBUG_DEBUG else LENGTH_OF_EACH)):
        # ctx holds exactly i tokens here; model.run only consumes the last one.
        x = ctx[-model.ctx_len:]

        if i == src_len:
            out = copy.deepcopy(init_state.out)  # logits cached after the prefill
        else:
            out = model.run(x)  # advance the RNN by the token sampled last step

        out[0] = -999999999  # disable <|endoftext|>

        char = sample_logits(out, temperature=TEMPERATURE, top_p=TOP_P).item()
        # NOTE(review): decoding one token at a time can split multi-byte characters.
        print(model.tokenizer.decode(char), end='', flush=True)

        ctx.append(char)
    print('\n' + '-' * 70, end='')


DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The technology focuses on learning from human behavior and not information.

We are an independent Data Whalechina team. This team of trained Data Whalechina team is available to answer any question and provide guidance.

The information provided on this site is not legal advice, and should not be construed as legal advice. You should consult a lawyer for advice regarding your specific situation.

We take no responsibility for the content, accuracy, or completeness of any information on this site or any information provided from third parties.

Cookies are used to store information about you so that we can remember and provide you with products and services you may have used. You can opt-out of our use of cookies at any time. To learn more, please read our cookie policy.

Cookies are tiny files stored on your computer that allow you to access information such a