----
## Transformer Attention

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Hyperparameters read by the Head module defined in this notebook:
#   n_embd     - feature size of the inputs to the key/query/value projections
#   block_size - side length of the causal mask buffer
n_embd, block_size = 512, 64

In [None]:
class Head(nn.Module):
    '''One head of causal (masked) self-attention.

    Reads the module-level ``n_embd`` (feature size of the input) and
    ``block_size`` (side length of the causal mask, i.e. the longest sequence
    this head can process).

    Args:
        head_size: output width of the key, query and value projections.
    '''

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias = False)
        self.query = nn.Linear(n_embd, head_size, bias = False)
        self.value = nn.Linear(n_embd, head_size, bias = False)
        # Lower-triangular matrix of ones: row t is 1 at positions <= t.
        # Stored as a buffer so it is part of the module state but not a parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        '''Apply masked self-attention.

        Args:
            x: tensor of shape (..., T, n_embd) with T <= block_size.

        Returns:
            Tensor of shape (..., T, head_size); position t is a weighted
            average of the value vectors at positions 0..t.
        '''
        T = x.shape[-2]
        k = self.key(x)
        q = self.query(x)
        # compute attention scores ("affinities"), scaled by 1/sqrt(head_size)
        # (scaled dot-product attention) so the softmax input keeps unit-order variance
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # hide future positions: -inf becomes weight 0 after the softmax
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim = -1)
        # perform the weighted aggregation of the values
        v = self.value(x)
        return wei @ v

----
## RWKV Attention

In [25]:
import numpy as np
np.set_printoptions(precision = 4, suppress = True, linewidth = 200)
import types
import torch
from torch.nn import functional as F
from tokenizers import Tokenizer

In [26]:
# Load the tokenizer definition from the JSON file next to the notebook.
tokenizer = Tokenizer.from_file('20B_tokenizer.json')

# Model settings read by RWKV_RNN: checkpoint name (without the '.pth'
# extension), number of layers, and embedding width.
args = types.SimpleNamespace(
    MODEL_NAME = 'RWKV-4-Pile-430M-20220808-8066',
    n_layer = 24,
    n_embd = 1024,
)

# Generation settings used by the text-generation cell.
context = "\n"                          # prompt that is encoded and fed to the model first
NUM_TRIALS, LENGTH_PER_TRIAL = 3, 100   # number of samples; tokens generated per sample
TEMPERATURE, TOP_P = 1.0, 0.85          # passed to sample_logits

class RWKV_RNN(torch.jit.ScriptModule):
    '''RWKV model evaluated one token at a time with an explicit recurrent state.

    The weights are read from ``args.MODEL_NAME + '.pth'`` and exposed as a tree
    of namespaces under ``self.w`` (e.g. ``self.w.blocks[3].att.key.weight``).

    The recurrent state is a tensor of shape ``(args.n_layer * 5, args.n_embd)``;
    layer ``i`` owns rows ``5*i`` .. ``5*i+4``:
        5*i+0  previous input of channel mixing
        5*i+1  previous input of time mixing
        5*i+2  running numerator of wkv   (``aa``)
        5*i+3  running denominator of wkv (``bb``)
        5*i+4  exponent offset ``pp`` that ``aa``/``bb`` are stored relative to
    '''

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.eval() # evaluation mode: dropout / batch-norm style layers would not run in training mode

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a source you trust.
        w = torch.load(args.MODEL_NAME + '.pth', map_location = 'cpu')
        for k in w.keys():
            if '.time_' in k:
                # drop singleton dimensions of the time_* parameters
                w[k] = w[k].squeeze()
            if '.time_decay' in k:
                # the stored value is mapped to -exp(value), so the decay is always negative
                w[k] = -torch.exp(w[k].float())
            else:
                w[k] = w[k].float()

        # Turn the flat 'a.b.0.c' keys into nested namespaces; numeric path
        # components become integer keys of a dict (self.w.blocks).
        self.w = types.SimpleNamespace()
        self.w.blocks = {}
        for k in w.keys():
            parts = k.split('.')
            last = parts.pop()
            here = self.w
            for p in parts:
                if p.isdigit():
                    p = int(p)
                    if p not in here:
                        here[p] = types.SimpleNamespace()
                    here = here[p]
                else:
                    if not hasattr(here, p):
                        setattr(here, p, types.SimpleNamespace())
                    here = getattr(here, p)
            setattr(here, last, w[k])

    def layer_norm(self, x, w):
        '''Layer-normalise x over its last n_embd features using w.weight and w.bias.'''
        return F.layer_norm(x, (self.args.n_embd,), weight = w.weight, bias = w.bias)

    @torch.jit.script_method
    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):
        '''Channel mixing function for RWKV model'''
        # blend the current input with the previous one stored in the state
        xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)
        xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)
        state[5*i+0] = x  # remember this input for the next token (in-place update)
        r = torch.sigmoid(rw @ xr)
        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
        return r * (vw @ k)

    @torch.jit.script_method
    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):
        '''Time mixing function for RWKV model'''
        # blend the current input with the previous one stored in the state
        xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)
        xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)
        xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)
        state[5*i+1] = x
        r = torch.sigmoid(rw @ xr)
        k = kw @ xk
        v = vw @ xv

        aa = state[5*i+2]
        bb = state[5*i+3]
        pp = state[5*i+4]
        # Output for this token: the current key gets the extra time_first bonus.
        # Subtracting the running maximum qq before exp() avoids overflow.
        ww = time_first + k
        qq = torch.maximum(pp, ww)
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)
        a = e1 * aa + e2 * v
        b = e1 * bb + e2
        wkv = a / b
        # State for the next token: decay the history, then add the current
        # key/value, again relative to a new maximum qq.
        ww = pp + time_decay
        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)
        state[5*i+2] = e1 * aa + e2 * v
        state[5*i+3] = e1 * bb + e2
        state[5*i+4] = qq
        return ow @ (r * wkv)

    def forward(self, token, state):
        '''Feed one token through the model.

        Args:
            token: index of the input token (row of the embedding matrix).
            state: recurrent state tensor of shape (n_layer * 5, n_embd), or
                None to start from a fresh state.

        Returns:
            (logits, state): the output of the head layer as a float tensor and
            the updated state. A state passed in by the caller is updated in
            place and the same tensor is returned.
        '''
        with torch.no_grad():
            # identity test: comparing a tensor with == is fragile
            if state is None:
                state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)
                for i in range(self.args.n_layer): state[5*i+4] = -1e30 # -infinity

            x = self.w.emb.weight[token]
            x = self.layer_norm(x, self.w.blocks[0].ln0)
            for i in range(self.args.n_layer):
                att = self.w.blocks[i].att
                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
                    att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay,
                    att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)
                ffn = self.w.blocks[i].ffn
                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
                    ffn.time_mix_k, ffn.time_mix_r,
                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)

            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
            return x.float(), state


----
采样函数

In [29]:
def sample_logits(out, temperature = 1.0, top_p = 0.8):
    '''Sample a token index from logits with top-p (nucleus) filtering and temperature.

    Args:
        out: 1-D tensor of logits (CPU, not requiring grad).
        temperature: the filtered probabilities are raised to 1/temperature
            before renormalising; 1.0 leaves them unchanged.
        top_p: keep the most likely tokens until their cumulative probability
            exceeds this value. If the cumulative probability never exceeds
            top_p (e.g. top_p >= 1), every token is kept.

    Returns:
        The sampled index, as returned by np.random.choice.
    '''
    probs = F.softmax(out, dim = -1).numpy()

    sorted_probs = np.sort(probs)[::-1]
    cumulative_probs = np.cumsum(sorted_probs)
    over_top_p = cumulative_probs > top_p
    # Without this guard argmax of an all-False mask is 0, which would make the
    # largest probability the cutoff and silently turn sampling into greedy decoding.
    if over_top_p.any():
        cutoff = float(sorted_probs[np.argmax(over_top_p)])
        probs[probs < cutoff] = 0
    if temperature != 1.0:
        # probs is a NumPy array here, so use np.power (arrays have no .pow method)
        probs = np.power(probs, 1.0 / temperature)
    probs = probs / np.sum(probs)

    return np.random.choice(a = len(probs), p = probs)

----
文本生成流程

In [30]:
# Announce that the model is being loaded on the CPU; args.MODEL_NAME is the checkpoint name.
print(f'\nUsing CPU to load {args.MODEL_NAME} ...')
model = RWKV_RNN(args)

# Feed the prompt through the model one token at a time to build the initial state.
print('\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')
init_state = None
for token in tokenizer.encode(context).ids:
    init_out, init_state = model.forward(token, init_state)

for TRIAL in range(NUM_TRIALS):
    print(f'\n\n--[ Trial {TRIAL}]-----------------', context, end = "")
    all_tokens = []
    out_last = 0  # index of the first token that has not been printed yet
    # model.forward updates the state in place, so every trial starts from a copy
    out, state = init_out.clone(), init_state.clone()
    for i in range(LENGTH_PER_TRIAL):
        token = sample_logits(out, TEMPERATURE, TOP_P)
        all_tokens += [token]
        tmp = tokenizer.decode(all_tokens[out_last:])
        # U+FFFD is the Unicode replacement character; presumably it shows up when
        # the pending tokens end in the middle of a multi-byte character, so hold
        # them back until the text decodes cleanly.
        if '\ufffd' not in tmp:
            print(tmp, end = "", flush = True)
            out_last = i + 1
        out, state = model.forward(token, state)
print('\n')


Using CPU to load RWKV-4-Pile-430M-20220808-8066 ...

Preprocessing context (slow version. see v2/rwkv/model.py for fast version)


--[ Trial 0]----------------- 
In a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese. According to the scientists, the dragons are the oldest living creatures in the entire world.

The strange, mysterious, and unique dragons have been living in this remote and inaccessible valley for thousands of years. This isolated valley was believed to be the world’s oldest active volcano, and the ancient Buddhist sage said that it should be removed from the endangered species list.

The researchers said that many previous scientific studies failed to agree that the dragons live in Tibet. The researchers also suspected that the dragons

--[ Trial 1]----------------- 
In a shocking finding, scientist discovered 