In [2]:
import numpy as np
from torch import load as torch_load  # Only for loading the model weights
from tokenizers import Tokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def layer_norm(x, w, b):
    """Normalise `x` to zero mean / unit std, then scale by `w` and shift by `b`.

    The mean and standard deviation are taken over every element of `x`
    (no axis argument). No epsilon is added, so an `x` whose std is 0
    results in a division by zero.
    """
    centered = x - np.mean(x)
    return centered / np.std(x) * w + b


# Short alias used by the model code in this notebook.
exp = np.exp


def sigmoid(x):
    """Logistic function 1 / (1 + e^-x), applied element-wise."""
    return 1 / (1 + exp(-x))

In [4]:
def time_mixing(x, last_x, last_num, last_den, decay, bonus, mix_k, mix_v, mix_r, Wk, Wv, Wr, Wout):
    """One time-mixing step for a single token.

    The arguments after `x` are filled positionally by the caller:

      * ``*state[i][:3]``             -> last_x, last_num, last_den
      * ``*params(f'blocks.{i}.att')`` -> in checkpoint order:

            time_decay        -> decay
            time_first        -> bonus
            time_mix_k        -> mix_k
            time_mix_v        -> mix_v
            time_mix_r        -> mix_r
            key.weight        -> Wk
            value.weight      -> Wv
            receptance.weight -> Wr
            output.weight     -> Wout

    Returns:
        ``(Wout @ rwkv, (x, num, den))`` - the block output and the three
        values the caller stores back as the next step's
        (last_x, last_num, last_den).
    """

    def blend(mix):
        # Per-element interpolation between the current and the previous input.
        return x * mix + last_x * (1 - mix)

    k = Wk @ blend(mix_k)
    v = Wv @ blend(mix_v)
    r = Wr @ blend(mix_r)

    # Weighted average of past values (carried in last_num / last_den) and the
    # current value, where the current token's weight uses `bonus`.
    current_weight = exp(bonus + k)
    wkv = (last_num + current_weight * v) / (last_den + current_weight)
    gated = sigmoid(r) * wkv

    # Decay the running numerator/denominator and fold in the current token.
    keep = exp(-exp(decay))
    token_weight = exp(k)
    num = keep * last_num + token_weight * v
    den = keep * last_den + token_weight

    return Wout @ gated, (x, num, den)

In [5]:
def channel_mixing(x, last_x, mix_k, mix_r, Wk, Wr, Wv):
    """One channel-mixing (feed-forward) step for a single token.

    `x` is interpolated with `last_x` using `mix_k` / `mix_r`, the key branch
    goes through a squared ReLU and `Wv`, and the result is gated by
    ``sigmoid(Wr @ ...)``.

    Returns:
        ``(output, x)`` - the block output and the input `x`, which the caller
        stores as the next step's `last_x`.
    """
    key_input = x * mix_k + last_x * (1 - mix_k)
    gate_input = x * mix_r + last_x * (1 - mix_r)

    k = Wk @ key_input
    r = Wr @ gate_input

    # Squared ReLU on the key activations before projecting with Wv.
    activated = np.maximum(k, 0) ** 2
    gate = sigmoid(r)
    return gate * (Wv @ activated), x

In [6]:
def RWKV(model, token, state):
    """Run the model for one token and return next-token probabilities.

    Args:
        model: Mapping from parameter name to array. Parameters are selected
            by name prefix and passed *positionally* to `layer_norm`,
            `time_mixing` and `channel_mixing`, so the mapping's iteration
            order must match those functions' argument order.
        token: Index of the input token's row in the embedding matrix.
        state: Per-layer recurrent state. ``state[i][:3]`` holds the
            (last_x, last_num, last_den) triple for time mixing and
            ``state[i][3]`` holds last_x for channel mixing. It is updated
            in place. The number of layers is taken from ``len(state)``
            rather than from a module-level constant.

    Returns:
        ``(probs, state)`` - the softmax over the output logits and the same
        (mutated) `state` object that was passed in.
    """

    def params(prefix):
        # All parameters whose name starts with `prefix`, in the model's own order.
        return [tensor for name, tensor in model.items() if name.startswith(prefix)]

    # Embedding lookup followed by the extra layer norm stored under block 0.
    x = params('emb')[0][token]
    x = layer_norm(x, *params('blocks.0.ln0'))

    for i in range(len(state)):
        # Time mixing with a residual connection.
        x_ = layer_norm(x, *params(f'blocks.{i}.ln1'))
        dx, state[i][:3] = time_mixing(x_, *state[i][:3], *params(f'blocks.{i}.att'))
        x = x + dx

        # Channel mixing with a residual connection.
        x_ = layer_norm(x, *params(f'blocks.{i}.ln2'))
        dx, state[i][3] = channel_mixing(x_, state[i][3], *params(f'blocks.{i}.ffn'))
        x = x + dx

    x = layer_norm(x, *params('ln_out'))
    logits = params('head')[0] @ x

    # Softmax; subtracting the max first keeps exp() from overflowing.
    e_x = exp(logits - np.max(logits))
    probs = e_x / e_x.sum()

    return probs, state

In [7]:
def sample_probs(probs, temperature=1.0, top_p=0.85):
    """Draw one token index from `probs` using top-p (nucleus) sampling.

    Args:
        probs: 1-D array of token probabilities. It is not modified; the
            function works on a float64 copy.
        temperature: Exponent ``1 / temperature`` is applied to the
            probabilities that survive the top-p cut before renormalising.
        top_p: Tokens are kept, from most to least likely, until their
            cumulative probability first exceeds this value. If the
            cumulative sum never exceeds it (e.g. ``top_p=1.0``), every
            token is kept.

    Returns:
        The sampled index, as returned by ``np.random.choice``.
    """
    # Work on a copy so the caller's array is not zeroed out as a side effect.
    probs = np.array(probs, dtype=np.float64)

    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)

    exceeds = cumulative_probs > top_p
    if exceeds.any():
        cutoff_index = np.argmax(exceeds)
    else:
        # argmax over an all-False mask would return 0 and keep only the
        # single most likely token; keep the whole distribution instead.
        cutoff_index = len(sorted_probs) - 1
    cutoff = sorted_probs[cutoff_index]

    probs[probs < cutoff] = 0
    probs = probs ** (1 / temperature)

    # Renormalise the surviving mass and draw a random index from it.
    return np.random.choice(a=len(probs), p=probs / np.sum(probs))

===================================================================================

In [8]:
# Available at https://huggingface.co/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth
MODEL_FILE = 'checkpoints/rwkv_file/RWKV-4-Pile-430M-20220808-8066.pth'
N_LAYER = 24
N_EMBD = 1024

print(f'\nLoading {MODEL_FILE}')
# NOTE(review): loading a .pth checkpoint may unpickle arbitrary objects - only
# load files from a trusted source.
weights = torch_load(MODEL_FILE, map_location='cpu')

# Convert every tensor to a float32 NumPy array. Parameters whose name contains
# '.time_' are squeezed first to drop their singleton dimensions.
for k in weights.keys():
    weights[k] = (weights[k].squeeze() if '.time_' in k else weights[k]).float().numpy()


Loading checkpoints/rwkv_file/RWKV-4-Pile-430M-20220808-8066.pth


In [9]:
# Show the checkpoint's parameter names in storage order. RWKV() selects
# parameters by name prefix and passes them positionally, so this order matters.
print(weights.keys())

odict_keys(['emb.weight', 'blocks.0.ln1.weight', 'blocks.0.ln1.bias', 'blocks.0.ln2.weight', 'blocks.0.ln2.bias', 'blocks.0.att.time_decay', 'blocks.0.att.time_first', 'blocks.0.att.time_mix_k', 'blocks.0.att.time_mix_v', 'blocks.0.att.time_mix_r', 'blocks.0.att.key.weight', 'blocks.0.att.value.weight', 'blocks.0.att.receptance.weight', 'blocks.0.att.output.weight', 'blocks.0.ffn.time_mix_k', 'blocks.0.ffn.time_mix_r', 'blocks.0.ffn.key.weight', 'blocks.0.ffn.receptance.weight', 'blocks.0.ffn.value.weight', 'blocks.0.ln0.weight', 'blocks.0.ln0.bias', 'blocks.1.ln1.weight', 'blocks.1.ln1.bias', 'blocks.1.ln2.weight', 'blocks.1.ln2.bias', 'blocks.1.att.time_decay', 'blocks.1.att.time_first', 'blocks.1.att.time_mix_k', 'blocks.1.att.time_mix_v', 'blocks.1.att.time_mix_r', 'blocks.1.att.key.weight', 'blocks.1.att.value.weight', 'blocks.1.att.receptance.weight', 'blocks.1.att.output.weight', 'blocks.1.ffn.time_mix_k', 'blocks.1.ffn.time_mix_r', 'blocks.1.ffn.key.weight', 'blocks.1.ffn.recep

In [10]:
# Build the tokenizer from the JSON definition stored next to the checkpoint.
tokenizer = Tokenizer.from_file("checkpoints/rwkv_file/20B_tokenizer.json")

In [11]:
print('\nPreprocessing context')

context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

# Fresh all-zero recurrent state: for each layer, 4 vectors of length N_EMBD
# (slots 0-2 are used by time mixing, slot 3 by channel mixing in RWKV()).
state = np.zeros(shape=(N_LAYER, 4, N_EMBD), dtype=np.float32)
print(state.shape)



Preprocessing context
(24, 4, 1024)


In [None]:
# Feed the prompt through the model one token at a time. Only the `probs` from
# the last prompt token is kept; it is the distribution the generation cell
# samples its first token from. RWKV() writes into `state` in place, so
# re-running this cell without re-creating `state` continues from the
# already-advanced state instead of starting over.
for token in tokenizer.encode(context).ids:
    probs, state = RWKV(weights, token, state)

In [14]:
# Echo the prompt, then generate 100 tokens autoregressively: sample a token
# from the current distribution, print its text immediately, and feed it back
# into the model to get the next distribution.
# NOTE(review): sampling goes through np.random and no seed is set in the
# visible cells, so the generated text is expected to differ between runs.
print(context, end="")
for i in range(100):
    token = sample_probs(probs)
    print(tokenizer.decode([token]), end="", flush=True)
    probs, state = RWKV(weights, token, state)


In a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.

The international team of scientists say they found evidence of three dragons, sharing one tongue, with “chosen” and “white” dragons.

The work was conducted by Dr Haijing Chen, of the Chinese Academy of Sciences, and her team from the Academy of Chinese Academy of Sciences, China Academy of Sciences.

“We detected these three dragons and their descendants, along with their grandchildren, in the high mountainous region of Tibet,” says Chen. “This is a rare

playground

In [12]:
# Scratch input for the playground cell below: the id of the prompt's first token.
token = tokenizer.encode(context).ids[0]

In [95]:

# Hand-unrolled version of one RWKV() step restricted to layer 0, kept flat so
# every intermediate (k, v, r, wkv, ...) stays available as a global for
# inspection.
# NOTE: this rebinds the globals x, k, v, r, Wk, ... and writes into state[0],
# so running it between generation cells changes the generation state.
params = lambda prefix : [weights[key] for key in weights.keys() if key.startswith(prefix)]
x = params('emb')[0][token]
x = layer_norm(x, *params('blocks.0.ln0'))
i = 0
x_ = layer_norm(x, *params(f'blocks.{i}.ln1'))
# Inlined body of:
# dx, state[i][:3] = time_mixing(x_, *state[i][:3], *params(f'blocks.{i}.att'))
_x, last_x, last_num, last_den, decay, bonus, mix_k, mix_v, mix_r, Wk, Wv, Wr, Wout = \
    x_, *state[i][:3], *params(f'blocks.{i}.att')
k = Wk @ ( _x * mix_k + last_x * (1 - mix_k) )
v = Wv @ ( _x * mix_v + last_x * (1 - mix_v) )
r = Wr @ ( _x * mix_r + last_x * (1 - mix_r) )

wkv = (last_num + exp(bonus + k) * v) / (last_den + exp(bonus + k))
rwkv = sigmoid(r) * wkv
num = exp(-exp(decay)) * last_num + exp(k) * v
den = exp(-exp(decay)) * last_den + exp(k)
dx = Wout @ rwkv
state[i][:3] = (_x,num,den)

# Residual connection around time mixing.
x = x + dx

x_ = layer_norm(x, *params(f'blocks.{i}.ln2'))

# Inlined body of:
# dx, state[i][3] = channel_mixing(x_, state[i][3], *params(f'blocks.{i}.ffn'))
# From here on mix_k, mix_r, Wk, Wr and Wv refer to the 'ffn' parameters and
# shadow the 'att' ones unpacked above.
_x, last_x, mix_k, mix_r, Wk, Wr, Wv = x_, state[i][3], *params(f'blocks.{i}.ffn')
k = Wk @ ( _x * mix_k + last_x * (1 - mix_k) )
r = Wr @ ( _x * mix_r + last_x * (1 - mix_r) )
vk = Wv @ np.maximum(k, 0)**2
dx, state[i][3] = sigmoid(r) * vk, _x

# Residual connection around channel mixing.
x = x + dx



In [96]:
# Wk is whatever the playground cell bound last, i.e. the third parameter found
# under the 'blocks.0.ffn' prefix (the FFN key weight), not the attention one.
Wk.shape

(4096, 1024)