In [1]:
import torch
import json
import math
import hiq
import torch.nn as nn
import torch.nn.functional as F

from llama import ModelArgs, Transformer, Tokenizer, LLaMA
from llama.generation import sample_top_p

In [2]:
# Load the 7B model hyper-parameters from its params.json.
with open('../7B/params.json', "r") as params_file:
    params = json.load(params_file)

In [3]:
# Fixed batch/sequence capacity for the static export; the vocabulary size comes from the tokenizer.
model_args: ModelArgs = ModelArgs(max_batch_size=128, max_seq_len=512, **params)
tokenizer = Tokenizer('../tokenizer.model')
model_args.vocab_size = tokenizer.n_words  # overrides any vocab_size taken from params.json

In [4]:
# Generation settings.
# NOTE(review): none of these appear to be referenced by the export cells that follow.
temperature: float = 0.8
top_p: float = 0.95
max_seq_len: int = 512

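### Initialize new modules

The cells below re-define the attention and transformer-block modules for a static ONNX export. The batch size is fixed to `max_batch_size` (128), there is one new token per sequence (seqlen 1), and the KV-cache position is hard-coded (262).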

In [5]:
from llama.model import FeedForward, apply_rotary_emb

In [6]:
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable gain.

    Args:
        dim: size of the last (feature) dimension; also the length of the gain vector.
        eps: constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Gain starts at ones, so a fresh module is pure normalization.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), reduced over the last axis.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalize in float32, cast back to the input dtype, then apply the gain.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight

In [7]:
class Attention(nn.Module):
    """Self-attention for a single decoding step, specialised for a static-shape ONNX export.

    The batch size is fixed to ``args.max_batch_size`` and exactly one new token per
    sequence is processed. Its keys/values are written into a preallocated cache at
    ``start_pos``, and the query attends over cache slots ``[0, start_pos]``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        # All heads live on this device (no model-parallel split).
        self.n_local_heads = args.n_heads
        self.max_batch_size = args.max_batch_size
        self.head_dim = args.dim // args.n_heads
        proj_dim = args.n_heads * self.head_dim

        self.wq = nn.Linear(args.dim, proj_dim, bias=False)
        self.wk = nn.Linear(args.dim, proj_dim, bias=False)
        self.wv = nn.Linear(args.dim, proj_dim, bias=False)
        self.wo = nn.Linear(proj_dim, args.dim, bias=False)

        # KV cache laid out as (max_batch_size, max_seq_len, n_local_heads, head_dim).
        # These are plain tensor attributes, not registered buffers, so Module.half()/.cuda()
        # leave them alone; forward() casts them to the query's dtype/device instead.
        cache_shape = (self.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        self.cache_k = torch.zeros(cache_shape)
        self.cache_v = torch.zeros(cache_shape)

        # The historical variable name is misspelled ("CAHCHE"). Read it as the fallback so
        # existing setups keep working, but let the correctly spelled name take precedence.
        legacy_in_gpu = hiq.get_env_bool("KV_CAHCHE_IN_GPU", True)
        if hiq.get_env_bool("KV_CACHE_IN_GPU", legacy_in_gpu):
            self.cache_k = self.cache_k.cuda()
            self.cache_v = self.cache_v.cuda()

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int = 262):
        """Attend from one new token per sequence over the cached keys/values.

        Args:
            x: activations whose last dimension is ``args.dim``. The projected queries are
                viewed as (max_batch_size, 1, n_heads, head_dim), so x must hold exactly
                ``max_batch_size`` tokens.
            freqs_cis: rotary-embedding table forwarded to ``apply_rotary_emb``.
            start_pos: cache slot the new token is written to; attention covers slots
                ``[0, start_pos]``. Defaults to 262, the position previously hard-coded
                for this export.

        Returns:
            Tensor of shape (max_batch_size, 1, dim) after the output projection.
        """
        # Static shapes for the export: a full batch, one token per sequence.
        bsz, seqlen = self.max_batch_size, 1

        xq = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = self.wk(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = self.wv(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Make the caches follow the activations' dtype and device.
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        end_pos = start_pos + seqlen
        self.cache_k[:bsz, start_pos:end_pos] = xk
        self.cache_v[:bsz, start_pos:end_pos] = xv

        # (bsz, n_local_heads, end_pos, head_dim)
        keys = self.cache_k[:bsz, :end_pos].transpose(1, 2)
        values = self.cache_v[:bsz, :end_pos].transpose(1, 2)
        xq = xq.transpose(1, 2)

        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        # No causal mask is needed: the single query sits at start_pos and may attend to
        # every cached slot in [0, start_pos].
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bsz, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

In [8]:
# Attention sublayer; constructing it allocates the KV cache (on GPU unless disabled via env var).
attention = Attention(args=model_args)

In [9]:
# Feed-forward sublayer, cast to float16 immediately.
ffn = FeedForward(
    dim=model_args.dim,
    hidden_dim=4 * model_args.dim,
    multiple_of=model_args.multiple_of,
).half()

In [10]:
# Two independent float16 RMS norms: one before attention, one before the feed-forward sublayer.
attention_norm, ffn_norm = (
    RMSNorm(model_args.dim, eps=model_args.norm_eps).half() for _ in range(2)
)

In [11]:
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention and feed-forward sublayers with residuals.

    NOTE(review): the sublayers are not built here. They are the notebook-level
    ``attention``, ``ffn``, ``attention_norm`` and ``ffn_norm`` objects created in earlier
    cells, so every ``TransformerBlock`` instance shares the same modules (and KV cache).
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        # Submodules are registered in the same order as before, keeping state_dict key order stable.
        self.attention = attention
        self.feed_forward = ffn
        self.layer_id = layer_id
        self.attention_norm = attention_norm
        self.ffn_norm = ffn_norm

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
        """Apply ``x + attn(norm(x))`` followed by ``h + ffn(norm(h))``.

        Args:
            x: input activations, passed through the attention sublayer.
            freqs_cis: rotary-embedding table forwarded to the attention sublayer.

        Returns:
            The block output, same shape as the attention/feed-forward residual sums.
        """
        # Invoke the submodules via __call__ (not .forward) so nn.Module hooks and
        # tracing scopes are applied.
        h = x + self.attention(self.attention_norm(x), freqs_cis)
        return h + self.feed_forward(self.ffn_norm(h))

In [12]:
# Wrap the shared sublayers into a single block (layer 0) for export.
tf_block = TransformerBlock(layer_id=0, args=model_args)

In [13]:
# ONNX output path; the name encodes the static shapes (batch 128, sequence length 1).
save_file = "tf_block_static_bsz128_sq1_last_iter_v1.onnx"

In [14]:
# Dummy float16 activations used only to trace the export: (batch, seqlen=1, dim).
h = torch.randn(model_args.max_batch_size, 1, model_args.dim, dtype=torch.float16)

In [15]:
# Dummy rotary table for tracing. The export output shows apply_rotary_emb asserting
# freqs_cis.shape == (x.shape[1], x.shape[-2], 2); presumably (seqlen=1, head_dim // 2, 2).
# Derived from model_args instead of hard-coding 64 (= 128 // 2 for the 7B config).
freq_cis = torch.randn((1, (model_args.dim // model_args.n_heads) // 2, 2))

In [16]:
with torch.no_grad():
    # Module.half()/.cuda() act in place, so tf_block itself ends up float16 on the GPU.
    torch.onnx.export(
        tf_block.half().cuda(),
        (h.cuda(), freq_cis.cuda()),
        save_file,
    )

  xq_shape[-1] = int(xq_shape[-1]/2)
  xk_shape[-1] = int(xk_shape[-1]/2)
  assert freqs_cis.shape == (x.shape[1], x.shape[-2], 2)


verbose: False, log level: Level.ERROR

