In [1]:
import torch
import json
import math
import hiq
import torch.nn as nn
import torch.nn.functional as F

from llama import ModelArgs, Transformer, Tokenizer, LLaMA
from llama.generation import sample_top_p

In [2]:
# Read the raw 7B checkpoint onto the CPU; tensors are copied out of this dict
# by later cells.
# NOTE(review): torch.load unpickles the file, so only use trusted checkpoints.
checkpoint = torch.load('../7B/consolidated.00.pth', map_location="cpu")

# Model hyper-parameters stored next to the checkpoint.
with open('../7B/params.json', "r") as params_file:
    params = json.load(params_file)

In [3]:
# Build the model configuration from params.json, with a 512-token sequence
# limit and a batch size of one.
model_args: ModelArgs = ModelArgs(
    max_seq_len=512, max_batch_size=1, **params
)
tokenizer = Tokenizer('../tokenizer.model')
# The vocabulary size is taken from the tokenizer and set on the config here.
model_args.vocab_size = tokenizer.n_words
# The model is constructed while CUDA half tensors are the default tensor type;
# the default is switched back to CPU float tensors right afterwards.
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
torch.set_default_tensor_type(torch.FloatTensor)
# NOTE(review): strict=False means missing or unexpected checkpoint keys are
# not raised as errors — confirm that this is intended.
model.load_state_dict(checkpoint,strict= False)

generator = LLaMA(model, tokenizer)

In [4]:
# Sampling settings read by the generation loop.
temperature: float = 0.8  # values <= 0 make the loop fall back to argmax
top_p: float = 0.95  # handed to sample_top_p

# Sequence and batch limits; they repeat the values passed to ModelArgs.
max_seq_len: int = 512
max_batch_size: int = 1

In [5]:
# Prompts to complete; the token buffer below gets one row per prompt.
prompts = ["I believe the meaning of life is"]

# Upper bound on the number of tokens produced after the longest prompt.
max_gen_len = 256

# Derive the batch size from the prompts so the buffer always has a row for
# every prompt, and fail early if that exceeds the configured batch limit.
bsz = len(prompts)
assert bsz <= max_batch_size, (bsz, max_batch_size)

prompt_tokens = [generator.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
prompt_lengths = [len(t) for t in prompt_tokens]
min_prompt_size = min(prompt_lengths)
max_prompt_size = max(prompt_lengths)

# Never exceed the configured sequence limit.
total_len = min(max_seq_len, max_gen_len + max_prompt_size)

# Buffer pre-filled with pad_id, created directly as int64 on the GPU.
tokens = torch.full(
    (bsz, total_len), generator.tokenizer.pad_id, dtype=torch.long, device="cuda"
)

# Copy each prompt into the start of its row.
for k, t in enumerate(prompt_tokens):
    tokens[k, : len(t)] = torch.tensor(t).long()
# True wherever the buffer holds something other than pad_id (the prompt tokens).
input_text_mask = tokens != generator.tokenizer.pad_id
# Generation starts right after the shortest prompt.
start_pos = min_prompt_size

In [6]:
# Index of the first token fed to the model on the next step. It is advanced to
# cur_pos after every step, so later steps feed only the newest token.
prev_pos = 0

with torch.no_grad():
    model.eval()
    for cur_pos in range(start_pos, total_len):
        # Stop one step before the end of the buffer. prev_pos and cur_pos keep
        # their values at module level, and the following cell reads them.
        if cur_pos == total_len-1:
            break
        # The position index is appended as one extra trailing column, so the
        # model receives the tokens and the position in a single tensor.
        # NOTE(review): presumably this Transformer.forward splits the last
        # column off as start_pos — confirm against the model in use.
        input_tensor = torch.cat((tokens[:, prev_pos:cur_pos],torch.tensor([[prev_pos]]).cuda()), 1)
        logits = model(input_tensor)
        if temperature > 0:
            # Sample from the temperature-scaled distribution via sample_top_p.
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            # Greedy decoding.
            next_token = torch.argmax(logits, dim=-1)
        next_token = next_token.reshape(-1)
        # only replace token if prompt has already been generated
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        tokens[:, cur_pos] = next_token
        prev_pos = cur_pos

In [7]:
# Rebuild the packed single-step input from the state the generation loop left
# behind: the token slice [prev_pos, cur_pos) plus prev_pos as a trailing column.
input_tensor = torch.cat(
    (tokens[:, prev_pos:cur_pos], torch.tensor([[prev_pos]]).cuda()), 1
)

# Unpack under dedicated names. Overwriting `tokens` and `start_pos` here would
# destroy the generation buffer and the loop's start index, so neither this cell
# nor the generation loop could be re-run without restarting from the top.
export_tokens = input_tensor[:, :-1]  # every column except the packed position
export_start_pos = input_tensor[:, -1].item()  # the packed position index
_bsz, seqlen = export_tokens.shape

# Inputs for the stand-alone block: token embeddings and the matching slice of
# the model's precomputed rotary frequencies.
h = model.tok_embeddings(export_tokens)
model.freqs_cis = model.freqs_cis.to(h.device)
freqs_cis = model.freqs_cis[export_start_pos : export_start_pos + seqlen]

# A mask is only built when more than one token is processed at once.
mask = None
if seqlen > 1:
    mask = torch.full(
        (1, 1, seqlen, seqlen), float("-inf"), device=export_tokens.device
    )
    mask = torch.triu(mask, diagonal=export_start_pos + 1).type_as(h)

### Initialize new modules

In [8]:
from llama.model import FeedForward, apply_rotary_emb

In [9]:
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned
    per-feature scale.

    Args:
        dim: size of the last dimension; the scale vector has this many entries
            and starts at all ones.
        eps: constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the root of its mean square along the last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        """Normalise in float32, cast back to the type of ``x``, then apply the scale."""
        normalized = self._norm(x.float())
        return normalized.type_as(x) * self.weight

In [10]:
class Attention(nn.Module):
    """Multi-head attention with a key/value cache, specialised for one decoding step.

    ``forward`` always treats its input as a single sequence holding a single
    token (``bsz`` and ``seqlen`` are fixed to 1) and applies no attention mask.

    Args:
        args: configuration object; ``n_heads``, ``dim``, ``max_batch_size`` and
            ``max_seq_len`` are read from it.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_local_heads = args.n_heads // 1

        self.head_dim = args.dim // args.n_heads

        # Bias-free query/key/value/output projections.
        self.wq = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )
        self.wk = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )
        self.wv = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
        )
        self.wo = nn.Linear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
        )
        # Zero-initialised caches laid out as (batch, position, head, head_dim).
        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        )
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        )
        # NOTE(review): the key is spelled "CAHCHE". It is left untouched because
        # changing it would change which environment variable is read.
        if hiq.get_env_bool("KV_CAHCHE_IN_GPU", True):
            self.cache_k = self.cache_k.cuda()
            self.cache_v = self.cache_v.cuda()

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int = 262):
        """Attend from a single new token over the cache up to ``start_pos``.

        Args:
            x: input activations; the projections are reshaped to
                ``(1, 1, n_local_heads, head_dim)``.
            freqs_cis: rotary-embedding frequencies passed to ``apply_rotary_emb``.
            start_pos: cache position the new key/value pair is written to.
                Defaults to 262, the value that used to be hard-coded here, so
                existing two-argument calls behave as before.

        Returns:
            The output projection applied to the attention result, reshaped to
            ``(1, 1, -1)`` beforehand.

        Raises:
            ValueError: if ``start_pos`` does not fit inside the cache.
        """
        bsz, seqlen = 1, 1
        cache_len = self.cache_k.shape[1]
        if not 0 <= start_pos <= cache_len - seqlen:
            raise ValueError(
                f"start_pos={start_pos} does not fit a KV cache of length {cache_len}"
            )

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Match the caches to the device and dtype of the projections.
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # Store the new key/value at start_pos, then read back everything up to
        # and including it. Positions that were never written are still zero.
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        # No mask is added: there is exactly one query token, and the keys read
        # from the cache end at that token's own position.
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(
            1, 2
        ).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)

In [11]:
# Stand-alone attention module built from the same config as the full model;
# its projection weights are replaced from the checkpoint in the next cell.
attention = Attention(model_args)

In [12]:
#### load specific layer weight from checkpoint
# Index of the transformer layer whose attention weights are copied.
layer_id = 0

# Checkpoint keys for the four attention projections, in wq, wk, wv, wo order.
name_list = [
    f'layers.{layer_id}.attention.{name}.weight'
    for name in ['wq', 'wk', 'wv', 'wo']
]

# Point each projection's weight at the matching checkpoint tensor.
attn_projections = (attention.wq, attention.wk, attention.wv, attention.wo)
for projection, checkpoint_key in zip(attn_projections, name_list):
    projection.weight.data = checkpoint[checkpoint_key]

In [13]:
# Stand-alone feed-forward block, converted to half precision; its weights are
# replaced from the checkpoint in the next cell.
ffn = FeedForward(dim=model_args.dim, hidden_dim=4 * model_args.dim, multiple_of=model_args.multiple_of).half()

In [14]:
#### load specific layer weight from checkpoint
# Checkpoint keys for the three feed-forward projections, in w1, w2, w3 order.
# They are collected in their own list: appending to and indexing the attention
# key list instead would load the attention wq/wk/wv tensors into w1/w2/w3.
name_list_ffn = []
for name in ['w1', 'w2', 'w3']:
    full_layer_name = 'layers.' + str(layer_id) + '.feed_forward.' + name + '.weight'
    name_list_ffn.append(full_layer_name)

ffn.w1.weight.data = checkpoint[name_list_ffn[0]]
ffn.w2.weight.data = checkpoint[name_list_ffn[1]]
ffn.w3.weight.data = checkpoint[name_list_ffn[2]]

In [15]:
# The two RMSNorm layers of the block, converted to half precision.
attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps).half()
ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps).half()

# Use layer_id rather than a literal 0 so the norm weights always come from the
# same layer as the attention and feed-forward weights.
attention_norm.weight.data = checkpoint[f'layers.{layer_id}.attention_norm.weight']
ffn_norm.weight.data = checkpoint[f'layers.{layer_id}.ffn_norm.weight']

In [16]:
class TransformerBlock(nn.Module):
    """Single transformer block assembled from the notebook-level modules.

    The sub-modules are not created here: the constructor picks up the
    module-level ``attention``, ``ffn``, ``attention_norm`` and ``ffn_norm``
    objects that exist when it runs.

    Args:
        layer_id: stored on the instance as ``layer_id``.
        args: configuration object; ``n_heads`` and ``dim`` are read from it.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        # Shared module-level instances, registered in this order.
        self.attention = attention
        self.feed_forward = ffn
        self.layer_id = layer_id
        self.attention_norm = attention_norm
        self.ffn_norm = ffn_norm

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
        """Apply pre-norm attention and pre-norm feed-forward, each with a residual add."""
        normed_input = self.attention_norm(x)
        attn_out = self.attention.forward(normed_input, freqs_cis)
        residual = x + attn_out

        normed_residual = self.ffn_norm(residual)
        ffn_out = self.feed_forward.forward(normed_residual)
        return residual + ffn_out

In [17]:
# Wrap the prepared sub-modules in one block and move it to the GPU.
# NOTE(review): the layer index is the literal 0 here, not layer_id; the class
# only stores it, but keep the two in agreement if layer_id changes.
tf_block = TransformerBlock(0, model_args).cuda()

In [18]:
# Output path for the exported ONNX graph (relative to the working directory).
save_file = 'tf_block_1input_weight_sq1_last_iter_v3.onnx'

In [20]:
# Export the block to ONNX using the example inputs prepared earlier: the token
# embeddings `h` and the rotary-frequency slice `freqs_cis`.
with torch.no_grad():
    torch.onnx.export(tf_block, (h,freqs_cis), save_file)

  xq_shape[-1] = int(xq_shape[-1]/2)
  xk_shape[-1] = int(xk_shape[-1]/2)
  assert freqs_cis.shape == (x.shape[1], x.shape[-2], 2)


verbose: False, log level: Level.ERROR

