In [2]:
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)
import types, torch
from torch.nn import functional as F
from tokenizers import Tokenizer

In [3]:
tokenizer = Tokenizer.from_file("20B_tokenizer.json")

args = types.SimpleNamespace()
args.MODEL_NAME = '/data1/ckw/RWKV-4-Pile-430M-20220808-8066'
args.n_layer = 24
args.n_embd = 1024

context = "\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence."
NUM_TRIALS = 3
LENGTH_PER_TRIAL = 100
TEMPERATURE = 1.0
TOP_P = 0.85
########################################################################################################

### RWKV Time Mixing Implementation

In the RWKV model, time mixing is the step that combines information across time steps, playing the role that attention plays in a Transformer. The following sections give the formulas behind the `time_mixing` function and an annotated version of its code.

#### Formula Description

The core idea of time mixing is to interpolate between the current input and the previous token's input using learned per-channel mixing coefficients, and to use the results to produce the key, value and gate (receptance) signals. This process involves the following steps:

1. **Mixed input**:
- Take a weighted average of the current input \(x\) and the previous token's input, stored in \(\text{state}[5i+1]\):
$$ x_k = x \cdot \text{time\_mix\_k} + \text{state}[5i+1] \cdot (1 - \text{time\_mix\_k}) $$
$$ x_v = x \cdot \text{time\_mix\_v} + \text{state}[5i+1] \cdot (1 - \text{time\_mix\_v}) $$
$$ x_r = x \cdot \text{time\_mix\_r} + \text{state}[5i+1] \cdot (1 - \text{time\_mix\_r}) $$

2. **State update**:
- Update state:
$$ \text{state}[5i+1] = x $$

3. **Calculate gate signal**:
- Use the sigmoid activation function to calculate the gate signal \( r \):
$$ r = \sigma(\text{rw} @ x_r) $$

4. **Calculate key and value**:
- Generate key \( k \) and value \( v \) through linear transformation:
$$ k = \text{kw} @ x_k $$
$$ v = \text{vw} @ x_v $$

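5. **Weighted sum calculation**:
- Read the accumulators \( aa = \text{state}[5i+2] \), \( bb = \text{state}[5i+3] \) and the running exponent offset \( pp = \text{state}[5i+4] \), then compute the weighted sum \( wkv \). Subtracting the larger exponent \( qq \) before exponentiating keeps the computation numerically stable:
$$ ww = \text{time\_first} + k, \quad qq = \max(pp, ww) $$
$$ e1 = e^{pp - qq}, \quad e2 = e^{ww - qq} $$
$$ a = e1 \cdot aa + e2 \cdot v $$
$$ b = e1 \cdot bb + e2 $$
$$ \text{wkv} = a / b $$

6. **Accumulator update**:
- Decay the accumulators by \( \text{time\_decay} \) (the loader stores it as \(-\exp\) of the raw parameter, so it is always negative) and add the current token:
$$ ww = pp + \text{time\_decay}, \quad qq = \max(ww, k) $$
$$ e1 = e^{ww - qq}, \quad e2 = e^{k - qq} $$
$$ \text{state}[5i+2] = e1 \cdot aa + e2 \cdot v, \quad \text{state}[5i+3] = e1 \cdot bb + e2, \quad \text{state}[5i+4] = qq $$

7. **Output**:
- Gate the weighted sum and project it:
$$ \text{output} = \text{ow} @ (r \cdot \text{wkv}) $$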
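
To see what the weighted sum computes, the recurrence can be unrolled. Writing \(u\) for `time_first` and \(w\) for `time_decay` (already negative after loading), \(\text{wkv}_t\) is a weighted average of all values seen so far, in which past token \(i\) gets weight \(e^{(t-1-i)w + k_i}\) and the current token gets the bonus weight \(e^{u + k_t}\):

$$ \text{wkv}_t = \frac{\sum_{i<t} e^{(t-1-i)w + k_i} v_i + e^{u+k_t} v_t}{\sum_{i<t} e^{(t-1-i)w + k_i} + e^{u+k_t}} $$

The sketch below is a single-channel NumPy re-implementation written for illustration (the function names are hypothetical, not from the RWKV repo). It checks that the O(T) stable recurrence used in `time_mixing` reproduces this direct O(T²) formula on random inputs.

```python
import numpy as np


def wkv_direct(k, v, u, w):
    """Direct O(T^2) evaluation of wkv for a single channel.

    u plays the role of time_first; w plays the role of time_decay (negative).
    """
    T = len(k)
    out = np.empty(T)
    for t in range(T):
        # Past tokens i < t are decayed by (t - 1 - i) steps.
        past = np.exp((t - 1 - np.arange(t)) * w + k[:t])
        # The current token gets the bonus u instead of a decay.
        cur = np.exp(u + k[t])
        out[t] = (past @ v[:t] + cur * v[t]) / (past.sum() + cur)
    return out


def wkv_recurrent(k, v, u, w):
    """O(T) recurrence with the same update rules as time_mixing."""
    aa, bb, pp = 0.0, 0.0, -1e30  # empty history; pp stands in for -infinity
    out = np.empty(len(k))
    for t in range(len(k)):
        # Output for step t, stabilized by the larger exponent qq.
        ww = u + k[t]
        qq = max(pp, ww)
        e1, e2 = np.exp(pp - qq), np.exp(ww - qq)
        out[t] = (e1 * aa + e2 * v[t]) / (e1 * bb + e2)
        # Decay the history by w and fold in the current token.
        ww = pp + w
        qq = max(ww, k[t])
        e1, e2 = np.exp(ww - qq), np.exp(k[t] - qq)
        aa, bb, pp = e1 * aa + e2 * v[t], e1 * bb + e2, qq
    return out


rng = np.random.default_rng(0)
k, v = rng.normal(size=16), rng.normal(size=16)
print(np.allclose(wkv_direct(k, v, 0.5, -0.3), wkv_recurrent(k, v, 0.5, -0.3)))  # True
```

The recurrence is what makes RWKV usable as an RNN: the three numbers `aa`, `bb`, `pp` per channel replace the whole sum over past tokens.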

The annotated `time_mixing` code is as follows:

```python
@torch.jit.script_method
def time_mixing(self, x, state, i: int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
    # Mix the current input with the previous token's input (token shift)
    xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
    xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
    xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)

    # Save the current input for the next time step
    state[5*i+1] = x

    # Compute the gate (receptance) signal
    r = torch.sigmoid(rw @ xr)

    # Compute the key and value
    k = kw @ xk
    v = vw @ xv

    # Read the accumulated numerator, denominator and exponent offset from the state
    aa = state[5*i+2]
    bb = state[5*i+3]
    pp = state[5*i+4]

    # Weighted sum: combine the history with the current token (bonus time_first),
    # subtracting the larger exponent qq so that torch.exp cannot overflow
    ww = time_first + k
    qq = torch.maximum(pp, ww)
    e1 = torch.exp(pp - qq)
    e2 = torch.exp(ww - qq)
    a = e1 * aa + e2 * v
    b = e1 * bb + e2
    wkv = a / b

    # Decay the accumulators by time_decay, add the current token, and store them
    ww = pp + time_decay
    qq = torch.maximum(ww, k)
    e1 = torch.exp(ww - qq)
    e2 = torch.exp(k - qq)
    state[5*i+2] = e1 * aa + e2 * v
    state[5*i+3] = e1 * bb + e2
    state[5*i+4] = qq

    # Gate the weighted sum and project it to the output
    return ow @ (r * wkv)
```

#### Detailed Explanation

1. **Mixed input**:
- `xk`, `xv`, `xr` are weighted mixtures of the current input `x` and the previous token's input stored in `state[5*i+1]`, used to compute the key, value and gating signal respectively.

2. **State Update**:
- Store the current input `x` in the state array for use in the next step.

3. **Calculate the gating signal**:
- Use `torch.sigmoid` to calculate the gating signal `r`, a value in (0, 1) per channel that controls how much of the weighted sum `wkv` is passed to the output.

4. **Calculate Keys and Values**:
- Use matrix multiplication to calculate the key `k` and value `v`.

5. **Weighted Sum Calculation**:
- Compute `wkv` as an exponentially decaying weighted average of past values, with the current token weighted by `time_first`. Subtracting the larger exponent (`qq`, obtained with `torch.maximum`) before calling `torch.exp` prevents overflow.

6. **Update State**:
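- Decay the accumulated values `aa` and `bb` by `time_decay`, add the current token's contribution, and store them in the state array together with the new exponent offset `qq` for use in subsequent time steps.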

7. **Calculate Final Output**:
- Multiply the weighted sum `wkv` by the gating signal `r` and project the result with `ow` to get the final output.
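
The overflow problem that `torch.maximum` guards against can be seen in isolation. The sketch below (plain NumPy written for illustration, not taken from the RWKV code) computes a two-element exponentially weighted average with large exponents: the naive version returns `nan`, while subtracting the maximum first, the same trick `time_mixing` applies with `qq`, gives the correct answer.

```python
import numpy as np


def weighted_avg_naive(logits, values):
    # exp() of a large exponent overflows to inf, and inf / inf is nan.
    weights = np.exp(logits)
    return (weights @ values) / weights.sum()


def weighted_avg_stable(logits, values):
    # Subtracting the max exponent leaves the ratio unchanged but keeps
    # every exp() argument <= 0, so nothing overflows.
    m = logits.max()
    weights = np.exp(logits - m)
    return (weights @ values) / weights.sum()


logits = np.array([800.0, 801.0])
values = np.array([1.0, 3.0])
with np.errstate(over="ignore", invalid="ignore"):
    print(weighted_avg_naive(logits, values))   # nan
print(weighted_avg_stable(logits, values))      # about 2.4621, i.e. (1 + 3e) / (1 + e)
```

RWKV goes one step further than this example: it keeps the running maximum in the state (`pp`) and stores `aa` and `bb` already scaled by it, so the accumulators themselves never overflow over long sequences.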

In this way, RWKV summarizes the entire history in a fixed-size state, so each new token is processed in constant time and memory regardless of how long the sequence is.

### RWKV Channel Mixing Implementation and Code Comments

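In the RWKV model, channel mixing is the counterpart of a Transformer's feed-forward block (its weights are stored under `ffn`): it transforms each token's representation across its embedding channels, using a one-step token shift to also see the previous input. The following sections give the formulas behind the `channel_mixing` function and an annotated version of its code.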

#### Formula Description

The core idea of channel mixing is to interpolate between the current input and the previous token's input using learned mixing coefficients, and to use the results to produce the key and gate signals. This process involves the following steps:

1. **Mixed input**:
- Take a weighted average of the current input \(x\) and the previous token's input, stored in \(\text{state}[5i+0]\):
$$ x_k = x \cdot \text{time\_mix\_k} + \text{state}[5i+0] \cdot (1 - \text{time\_mix\_k}) $$
$$ x_r = x \cdot \text{time\_mix\_r} + \text{state}[5i+0] \cdot (1 - \text{time\_mix\_r}) $$

2. **State update**:
- Update the state:
$$ \text{state}[5i+0] = x $$

3. **Calculate the gate signal**:
- Use the sigmoid activation function to calculate the gate signal \(r\):
$$r = \sigma(\text{rw} @ x_r) $$

4. **Calculate the key**:
- Generate the key \(k\) through ReLU and square transformation:
$$k = (\text{ReLU}(\text{kw} @ x_k))^2 $$

5. **Calculate the output**:
- Use the gate signal and the key to calculate the final output:
$$\text{output} = r \cdot (\text{vw} @ k) $$

The code is as follows:

```python
@torch.jit.script_method
def channel_mixing(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw):
    # Mix the current input with the previous token's input (token shift)
    xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
    xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)

    # Save the current input for the next time step
    state[5*i+0] = x

    # Compute the gate (receptance) signal
    r = torch.sigmoid(rw @ xr)

    # Compute the key with a squared ReLU (from the Primer paper)
    k = torch.square(torch.relu(kw @ xk))

    # Project back with vw and apply the gate
    return r * (vw @ k)
```

#### Detailed Explanation

1. **Mixed input**:
- `xk`, `xr` are weighted mixtures of the current input `x` and the previous token's input stored in `state[5*i+0]`, used to compute the key and gating signal respectively.

2. **State update**:
- Store the current input `x` in the state array for the next step of calculation.

3. **Calculate gating signal**:
- Use `torch.sigmoid` to calculate the gating signal `r`, a value in (0, 1) per channel that controls how much of the projected key is passed to the output.

4. **Calculate the key**:
- Apply `torch.relu` to `kw @ xk` and square the result (squared ReLU, from the Primer paper) to compute the key `k`, which adds a stronger nonlinearity than a plain ReLU.

5. **Calculate the final output**:
- Project `k` back to the embedding size with `vw` and multiply by the gating signal `r` to get the final output.
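
For experimenting outside the model class, channel mixing can be written as a plain function. The sketch below is a refactoring for illustration, not the repo's API: it takes the previous input as an argument and returns the new one instead of writing to `state`, and it uses random weights with toy sizes (`n_embd` and `hidden` are arbitrary here). It also shows that when both mixing coefficients are 1, the previous token has no influence on the output.

```python
import torch


def channel_mixing(x, prev_x, time_mix_k, time_mix_r, kw, vw, rw):
    """Standalone channel mixing for one token; returns (output, new prev_x).

    Same math as RWKV_RNN.channel_mixing, but the previous input is passed in
    and returned explicitly instead of being read from / written to `state`.
    """
    xk = x * time_mix_k + prev_x * (1 - time_mix_k)  # token shift for the key
    xr = x * time_mix_r + prev_x * (1 - time_mix_r)  # token shift for the gate
    r = torch.sigmoid(rw @ xr)                       # gate in (0, 1)
    k = torch.square(torch.relu(kw @ xk))            # squared ReLU
    return r * (vw @ k), x                           # output, new "state"


torch.manual_seed(0)
n_embd, hidden = 8, 32  # toy sizes chosen for illustration
kw = torch.randn(hidden, n_embd)
vw = torch.randn(n_embd, hidden)
rw = torch.randn(n_embd, n_embd)
x, prev_x = torch.randn(n_embd), torch.randn(n_embd)

mix = torch.full((n_embd,), 0.5)
out, new_prev = channel_mixing(x, prev_x, mix, mix, kw, vw, rw)
print(out.shape)  # torch.Size([8])

# With mixing coefficients of 1 the previous token is ignored entirely.
ones = torch.ones(n_embd)
a, _ = channel_mixing(x, prev_x, ones, ones, kw, vw, rw)
b, _ = channel_mixing(x, torch.zeros(n_embd), ones, ones, kw, vw, rw)
print(torch.allclose(a, b))  # True
```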

Through these steps, channel mixing applies a gated, nonlinear transformation across the embedding channels of each token, while the token shift lets it also draw on the previous token's input.

In [4]:
class RWKV_RNN(torch.jit.ScriptModule):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval() # put the module in evaluation mode
        
        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
        for k in w.keys():
            if      '.time_' in k: w[k] = w[k].squeeze()
            if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}
            else: w[k] = w[k].float() # convert to f32 type
        
        self.w = types.SimpleNamespace() # set self.w from w
        self.w.blocks = {}
        for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here: here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

    def layer_norm(self, x, w):
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

    @torch.jit.script_method
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
        state[5*i+0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

    @torch.jit.script_method
    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
        xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
        state[5*i+1] = x
        r = torch.sigmoid(rw @ xr)
        k = kw @ xk
        v = vw @ xv
        
        aa = state[5*i+2]
        bb = state[5*i+3]
        pp = state[5*i+4]
        ww = time_first + k
        qq = torch.maximum(pp, ww)
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        wkv = a / b
        ww = pp + time_decay
        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)
        state[5*i+2] = e1 * aa + e2 * v
        state[5*i+3] = e1 * bb + e2
        state[5*i+4] = qq
        return ow @ (r * wkv)

    def forward(self, token, state):
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)
                for i in range(self.args.n_layer): state[5*i+4] = -1e30 # -infinity
            
            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i, 
                    att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay, 
                    att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i, 
                    ffn.time_mix_k, ffn.time_mix_r, 
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
            
            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state

##########################################################################################################
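
The weight-loading and state conventions of `RWKV_RNN` can be checked without a real checkpoint. The sketch below re-implements them as standalone helpers (`preprocess`, `nest` and `init_state` are names chosen here, not part of the RWKV code) and runs them on a hypothetical toy checkpoint. The docstring of `init_state` lists what each of the five per-layer state rows holds, as read off `time_mixing` and `channel_mixing`.

```python
import types

import torch


def preprocess(w):
    """Same per-tensor conversions as RWKV_RNN.__init__."""
    for k in list(w):
        if '.time_' in k:
            w[k] = w[k].squeeze()
        if '.time_decay' in k:
            w[k] = -torch.exp(w[k].float())  # per-step decay factor is exp(-exp(raw))
        else:
            w[k] = w[k].float()
    return w


def nest(w):
    """Turn flat keys like 'blocks.0.att.time_first' into w.blocks[0].att.time_first."""
    root = types.SimpleNamespace(blocks={})
    for key, tensor in w.items():
        *parts, last = key.split('.')
        here = root
        for p in parts:
            if p.isdigit():  # numeric parts index into the `blocks` dict
                here = here.setdefault(int(p), types.SimpleNamespace())
            else:
                if not hasattr(here, p):
                    setattr(here, p, types.SimpleNamespace())
                here = getattr(here, p)
        setattr(here, last, tensor)
    return root


def init_state(n_layer, n_embd):
    """Five rows per layer i:
    5i+0: previous input of channel mixing   5i+1: previous input of time mixing
    5i+2: aa (numerator)   5i+3: bb (denominator)   5i+4: pp (exponent offset)
    """
    state = torch.zeros(n_layer * 5, n_embd)
    for i in range(n_layer):
        state[5 * i + 4] = -1e30  # stands in for -infinity
    return state


# Hypothetical checkpoint with toy shapes, not a real RWKV file.
fake = {
    'emb.weight': torch.randn(10, 4),
    'blocks.0.att.time_decay': torch.zeros(1, 1, 4),
    'blocks.0.att.time_first': torch.ones(1, 1, 4),
}
w = nest(preprocess(fake))
print(w.blocks[0].att.time_decay)             # tensor([-1., -1., -1., -1.])
print(init_state(n_layer=2, n_embd=4).shape)  # torch.Size([10, 4])
```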

The sampling method is unchanged from versions v2 and v3; only minor optimizations have been made to the code.

In [6]:
def sample_logits(out, temperature=1.0, top_p=0.8):
    probs = F.softmax(out, dim=-1).numpy()
    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        probs = np.power(probs, 1.0 / temperature)  # probs is a NumPy array; ndarray has no .pow()
    probs = probs / np.sum(probs)
    out = np.random.choice(a=len(probs), p=probs)
    return out

########################################################################################################
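
`sample_logits` implements nucleus (top-p) sampling: it keeps the most likely tokens up to the point where their cumulative probability first exceeds `top_p` (ties with the cutoff are also kept), zeroes the rest, and renormalizes. To inspect the distribution it samples from, the filtering part can be factored out. `top_p_filter` below is a hypothetical helper written for illustration that follows the same steps, using `np.power` for the temperature.

```python
import numpy as np


def top_p_filter(probs, top_p=0.8, temperature=1.0):
    """Return the renormalized distribution that sample_logits draws from."""
    probs = np.array(probs, dtype=np.float64)  # copy, so the caller's array is untouched
    sorted_probs = np.sort(probs)[::-1]
    cumulative = np.cumsum(sorted_probs)
    # The probability at which the running total first passes top_p is the cutoff.
    cutoff = float(sorted_probs[np.argmax(cumulative > top_p)])
    probs[probs < cutoff] = 0
    if temperature != 1.0:
        # On the surviving tokens, p ** (1/T) matches dividing the logits by T.
        probs = np.power(probs, 1.0 / temperature)
    return probs / probs.sum()


p = top_p_filter([0.05, 0.5, 0.15, 0.3], top_p=0.7)
print(p)  # approximately [0, 0.625, 0, 0.375]
token = np.random.choice(len(p), p=p)  # only index 1 or 3 can be drawn
```

Here the sorted probabilities 0.5, 0.3, 0.15, 0.05 have running totals 0.5, 0.8, ..., so the total first exceeds 0.7 at 0.3, and everything below 0.3 is dropped.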

In [7]:
print(f'\nUsing CPU. Loading {args.MODEL_NAME} ...')
model = RWKV_RNN(args)


Using CPU. Loading /data1/ckw/RWKV-4-Pile-430M-20220808-8066 ...


In [8]:
print(f'\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
init_state = None
for token in tokenizer.encode(context).ids:
    init_out, init_state = model.forward(token, init_state)


Preprocessing context (slow version. see v2/rwkv/model.py for fast version)


In [9]:
for TRIAL in range(NUM_TRIALS):
    print(f'\n\n--[ Trial {TRIAL} ]-----------------', context, end="")
    all_tokens = []
    out_last = 0
    out, state = init_out.clone(), init_state.clone()
    for i in range(LENGTH_PER_TRIAL):
        token = sample_logits(out, TEMPERATURE, TOP_P)
        all_tokens += [token]
        tmp = tokenizer.decode(all_tokens[out_last:])
        if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
            print(tmp, end="", flush=True)
            out_last = i + 1
        out, state = model.forward(token, state)       
print('\n')



--[ Trial 0 ]----------------- 
DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The machine learning solutions applied to the class are called Persona, which consist of several categories:

\begin{tabular}{|c|c|c|}
\hline
  Name   & Description  \\ \hline
\hline
  \end{tabular}

DataWhalechina organizes the data in two ways:

\begin{tabular}{|c|c|c|}
\hline
  \multicolumn{2}{|c}{

--[ Trial 1 ]----------------- 
DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The main goal is to allow learners to learn how to use artificial intelligence in an integrated fashion, by using both AI and deep learning techniques. Datawhalechina aims to teach AI algorithms from scratch and teach them from scratch to become competent with many algorithms that humans could not have.

Applications

Projects 
 DeeplearningAI : Encourage AI algorithms to bec
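
The generation loop above only prints once `tokenizer.decode` returns text without the replacement character `\ufffd`, because a token can end partway through a multi-byte UTF-8 character. The effect can be reproduced with plain bytes. In the sketch below, byte chunks stand in for tokens (an assumption for illustration; the real loop decodes token IDs), and the Chinese character 你, which is three bytes in UTF-8, is split across two chunks.

```python
def stream_decode(chunks):
    """Buffering logic of the generation loop, with byte chunks standing in for tokens."""
    emitted = []
    pending = b""
    for chunk in chunks:
        pending += chunk
        text = pending.decode("utf-8", errors="replace")
        if "\ufffd" not in text:  # only emit once the buffer is valid UTF-8
            emitted.append(text)
            pending = b""
    return emitted


ni = "你".encode("utf-8")  # 3 bytes
chunks = [b"Hi ", ni[:2], ni[2:], b"!"]
print(stream_decode(chunks))  # ['Hi ', '你', '!']
```

One trade-off of this design: if the model ever emits a genuine U+FFFD character, the check cannot tell it apart from an incomplete sequence, and that text is only printed together with the next valid chunk.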

### Note: RWKV's Scaling Law

RWKV's scaling laws describe the mathematical relationship between model performance and various factors. These factors include model size ($N$), dataset size ($D$), and the optimal compute budget ($C_{\min}$). Scaling laws are important in two ways:
1. **Prediction and Planning**: They allow us to predict and plan costs and performance by interpolation and extrapolation before training large models.
2. **Feedback and Research**: They provide important feedback on model failures and guide future research directions.

#### Summary of key points:
- **Comparison with previous RNN research**: Previous work has shown that LSTMs do not strictly follow the same log-log linear scaling law as Transformers. In contrast, the training results for RWKV show that it follows the same general form of scaling law as Transformers.
- **Experimental verification**: In the [v4 paper](https://arxiv.org/abs/2305.13048), 45 RWKV models were trained to verify the linear relationship (in log-log space) between loss and compute. The linear fit achieved $r^2 = 0.994$, and even when extrapolated by an additional order of magnitude, the fit remained good ($r^2 = 0.875$).

These results indicate that RWKV scales well, with performance scaling behavior similar to that of Transformers.
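
A scaling law of the form \(L = a\,C^{b}\) is a straight line in log-log space, which is why a linear fit and its $r^2$ are used to assess it. The sketch below shows that computation on synthetic data, generated here with an assumed exponent of -0.05 and about 1% noise; these are not the paper's measurements, and `fit_power_law` is a hypothetical helper.

```python
import numpy as np


def fit_power_law(compute, loss):
    """Least-squares fit of log(loss) = log(a) + b * log(compute); returns (a, b, r2)."""
    x, y = np.log(compute), np.log(loss)
    b, log_a = np.polyfit(x, y, 1)  # slope first, then intercept
    pred = log_a + b * x
    ss_res = np.sum((y - pred) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return np.exp(log_a), b, 1.0 - ss_res / ss_tot


# Synthetic data: loss = 20 * C^-0.05 with ~1% multiplicative noise.
rng = np.random.default_rng(0)
compute = np.logspace(18, 22, 12)
loss = 20.0 * compute ** -0.05 * np.exp(rng.normal(scale=0.01, size=compute.size))

a, b, r2 = fit_power_law(compute, loss)
print(f"b = {b:.3f}, r^2 = {r2:.3f}")
```

To mimic an extrapolation check, one would fit on the smaller-compute points only and evaluate the same $r^2$ formula on the larger ones.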