In [1]:
import numpy as np
import types, torch
from torch.nn import functional as F
from tokenizers import Tokenizer
from time import time

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
class RWKV_RNN(torch.jit.ScriptModule):
    """RWKV-4 language model evaluated in RNN mode, one token per ``forward`` call.

    The recurrent state is a tensor of shape ``(n_layer * 5, n_embd)``. For layer ``i``
    the rows are:

    * ``5*i+0``: previous input to channel mixing (token shift)
    * ``5*i+1``: previous input to time mixing (token shift)
    * ``5*i+2``: WKV numerator accumulator ``aa``
    * ``5*i+3``: WKV denominator accumulator ``bb``
    * ``5*i+4``: running exponent ``pp`` that keeps ``aa``/``bb`` numerically stable

    The mixing methods write into ``state`` in place.
    """

    def __init__(self, args):
        """Load ``args.MODEL_NAME + '.pth'`` and expose its tensors as nested namespaces.

        Args:
            args: namespace with ``MODEL_NAME`` (checkpoint path without ``.pth``),
                ``n_layer`` and ``n_embd``.
        """
        super().__init__()
        self.args = args
        self.eval()  # eval mode; autograd itself is disabled inside forward()

        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        w = torch.load(args.MODEL_NAME + '.pth', map_location=device)
        for k in w.keys():
            if '.time_' in k:
                w[k] = w[k].squeeze()
            if '.time_decay' in k:
                w[k] = -torch.exp(w[k].float())  # the real time decay is like e^{-e^x}
            else:
                w[k] = w[k].float()  # convert to f32 type

        # Turn dotted keys into attribute access, e.g.
        # "blocks.0.att.time_first" => self.w.blocks[0].att.time_first.
        # Numeric components index a dict, so they must follow ``blocks`` (the only dict created here).
        self.w = types.SimpleNamespace()
        self.w.blocks = {}
        for k in w.keys():
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here:
                        here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p):
                        setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

    def layer_norm(self, x, w):
        """Layer-normalise ``x`` over its last ``n_embd`` features with ``w.weight``/``w.bias``."""
        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)

    @torch.jit.script_method
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        # Channel mixing (the ``ffn`` sub-block) for layer i; updates state row 5*i+0 in place.
        # Token shift: blend the current input with the previous token's input.
        xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
        state[5*i+0] = x
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

    @torch.jit.script_method
    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
        # Time mixing (the ``att`` sub-block) for layer i; updates state rows 5*i+1..5*i+4 in place.
        # Token shift: blend the current input with the previous token's input.
        xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
        state[5*i+1] = x
        r = torch.sigmoid(rw @ xr)
        k = kw @ xk
        v = vw @ xv

        # aa/bb are stored scaled by exp(-pp); every exp() below subtracts the running max qq
        # so that the exponentials cannot overflow.
        aa = state[5*i+2]
        bb = state[5*i+3]
        pp = state[5*i+4]
        # Output for this token: history plus the current token weighted by the time_first bonus.
        ww = time_first + k
        qq = torch.maximum(pp, ww)
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        wkv = a / b
        # State update: decay the history by time_decay, then add the current token.
        ww = pp + time_decay
        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)
        state[5*i+2] = e1 * aa + e2 * v
        state[5*i+3] = e1 * bb + e2
        state[5*i+4] = qq
        return ow @ (r * wkv)

    def forward(self, token, state):
        """Run one token through every layer.

        Args:
            token: integer token id used to index the embedding matrix.
            state: state returned by a previous call, or ``None`` to start fresh.
                A tensor passed in is updated in place.

        Returns:
            ``(logits, state)``: float output of the ``head`` projection and the
            updated state.
        """
        with torch.no_grad():
            if state is None:
                state = torch.zeros(self.args.n_layer * 5, self.args.n_embd, device=device)
                state[4::5] = -1e30  # every layer's "pp" row starts at -infinity (no history yet)

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                    att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay,
                    att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                    ffn.time_mix_k, ffn.time_mix_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state

    def sample_logits(self, logits, temperature=1.0, top_p=0.8, top_k=0):
        """Sample one token id from ``logits`` with nucleus (top-p) and optional top-k filtering.

        Args:
            logits: 1-D tensor of unnormalised scores.
            temperature: the filtered probabilities are raised to ``1 / temperature``.
            top_p: sort probabilities in descending order and find the first position where
                the cumulative sum exceeds ``top_p``. Tokens less likely than the token at that
                position are dropped. If no position exceeds ``top_p`` (e.g. ``top_p >= 1``),
                nothing is dropped.
            top_k: if in ``(0, len(logits))``, keep only the ``top_k`` most likely tokens.

        Returns:
            The sampled token id as a Python ``int``.
        """
        top_k = int(top_k)

        probs = F.softmax(logits.float(), dim=-1)
        sorted_ids = torch.argsort(probs)  # ascending probability
        sorted_probs = torch.flip(probs[sorted_ids], dims=(0,))  # descending probability
        cumulative_probs = torch.cumsum(sorted_probs, 0)
        exceeds = cumulative_probs > top_p
        # Without this guard, argmax() of an all-False mask is 0. That would make the cutoff the
        # single largest probability and silently turn sampling into greedy decoding.
        if bool(exceeds.any()):
            cutoff = float(sorted_probs[torch.argmax(exceeds.to(torch.long))])
            probs[probs < cutoff] = 0
        if 0 < top_k < len(probs):
            probs[sorted_ids[:-top_k]] = 0  # zero everything except the top_k largest
        if temperature != 1.0:
            probs = probs.pow(1.0 / temperature)
        probs = probs / torch.sum(probs)
        out = torch.multinomial(probs, num_samples=1)[0]
        return int(out)

    def preprocess(self, token_ids, init_state=None):
        """Feed a prompt through the model one token at a time.

        Args:
            token_ids: iterable of token ids.
            init_state: optional state to continue from (updated in place by ``forward``).

        Returns:
            ``(logits, state)`` after the last token.

        Raises:
            ValueError: if ``token_ids`` is empty, since there are no logits to return.
        """
        init_out = None
        for token_id in token_ids:
            init_out, init_state = self.forward(token_id, init_state)
        if init_out is None:
            raise ValueError("preprocess() needs at least one token id")
        return init_out, init_state

    def generate_from_initial_state(self, init_out, init_state, temperature=0.5, top_p=0.85, max_num_tokens=100, top_k=0):
        """Sample ``max_num_tokens`` token ids continuing from a precomputed ``(logits, state)``.

        The inputs are cloned first because ``forward`` updates the state in place. This lets
        one prompt state be reused for several independent generations.
        """
        all_tokens = []
        out, state = init_out.clone(), init_state.clone()
        for _ in range(max_num_tokens):
            token = self.sample_logits(out, temperature, top_p, top_k)
            all_tokens.append(token)
            out, state = self.forward(token, state)
        return all_tokens

    def generate(self, token_ids, temperature=0.5, top_p=0.85, max_num_tokens=100, top_k=0):
        """Run ``preprocess`` on the prompt ids, then sample a continuation of token ids."""
        init_out, init_state = self.preprocess(token_ids)
        return self.generate_from_initial_state(init_out, init_state, temperature, top_p, max_num_tokens, top_k)
        

In [4]:
tokenizer = Tokenizer.from_file("data/rwkv/20B_tokenizer.json")

# Checkpoint path (without the ".pth" suffix) plus the architecture sizes of that checkpoint.
# n_layer and n_embd size the recurrent state and the layer norms, so they must match the weights.
args = types.SimpleNamespace(
    MODEL_NAME='data/rwkv/RWKV-4-Pile-430M-20220808-8066',
    n_layer=24,
    n_embd=1024,
)

print(f'\nLoading model {args.MODEL_NAME} using {device}...')
start_time = time()
model = RWKV_RNN(args)
elapsed_time = time() - start_time
print(f'Model loaded in {elapsed_time:.2f} seconds')


Loading model data/rwkv/RWKV-4-Pile-430M-20220808-8066 using cuda...
Model loaded in 41.98 seconds


In [5]:
context = "In a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."

# Run the prompt through the RNN once; each sample below starts from a clone of this state.
token_ids = tokenizer.encode(context).ids
init_out, init_state = model.preprocess(token_ids)

for i in range(5):
    sampled_ids = model.generate_from_initial_state(init_out, init_state)
    generated = tokenizer.decode(sampled_ids)
    print(f'\n\n------- Run number {i+1} -------\n',context, generated)





------- Run number 1 -------
 In a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese. 

The dragons were found in a valley in the far west of Tibet, which is home to the largest population of Tibetan sheep. The dragons were discovered by a herding team of Chinese scientists who were searching for an area of unpopulated land that had been discovered by a team of Chinese scientists.

The team of Chinese scientists discovered the dragons and discovered that there was a herd of dragons living in the valley. The dragons were discovered by a team of Chinese scientists who were searching for an


------- Run number 2 -------
 In a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chines