In [1]:
import torch
import json
import math
import hiq

import torch.nn as nn
import torch.nn.functional as F

from llama.model import precompute_freqs_cis
from llama.model import apply_rotary_emb
from llama.model import ModelArgs, Attention, RMSNorm, FeedForward
from llama import Tokenizer

In [7]:
# Raw hyper-parameters of the 65B checkpoint, as a plain dict.
with open('./65B/params.json', "r") as params_file:
    params = json.load(params_file)

params

{'dim': 8192,
 'multiple_of': 256,
 'n_heads': 64,
 'n_layers': 80,
 'norm_eps': 1e-05,
 'vocab_size': -1}

In [8]:
# Checkpoint hyper-parameters plus the inference limits (single sequence, 512 positions).
model_args: ModelArgs = ModelArgs(
    **params,
    max_seq_len=512,
    max_batch_size=1,
)

In [9]:
tokenizer = Tokenizer('../tokenizer.model')
# params.json ships vocab_size=-1, so take the real vocabulary size from the tokenizer.
n_vocab = tokenizer.n_words
model_args.vocab_size = n_vocab

In [10]:
class Attention(nn.Module):
    """Single-token self-attention with a KV cache, specialized for ONNX export.

    Redefines (and shadows) the ``Attention`` imported from ``llama.model`` in the
    first cell. Batch size and sequence length are fixed at 1, and the new token
    is written to a fixed cache slot, so one traced call covers one decoding step.

    Args:
        args: model hyper-parameters; reads ``dim``, ``n_heads``,
            ``max_batch_size`` and ``max_seq_len``.
        start_pos: cache slot the current token is written to. The default (262)
            must match the row the exported ``freqs_cis`` slice is taken from.

    Raises:
        ValueError: if ``start_pos`` is not a valid index into the KV cache.
    """

    def __init__(self, args: ModelArgs, start_pos: int = 262):
        super().__init__()
        # seqlen is fixed at 1, so the write slot must satisfy start_pos < max_seq_len.
        if not 0 <= start_pos < args.max_seq_len:
            raise ValueError(
                f"start_pos={start_pos} does not fit in a KV cache of "
                f"max_seq_len={args.max_seq_len}"
            )
        self.start_pos = start_pos

        # Model-parallel world size is 1 here, so every head is local.
        self.n_local_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        # Prefer the correctly spelled KV_CACHE_IN_GPU; fall back to the historical
        # misspelling KV_CAHCHE_IN_GPU so existing setups keep working.
        # Assumes hiq.get_env_bool(name, default) returns `default` when `name` is unset.
        kv_cache_in_gpu = hiq.get_env_bool(
            "KV_CACHE_IN_GPU", hiq.get_env_bool("KV_CAHCHE_IN_GPU", True)
        )
        cache_shape = (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
        cache_device = "cuda" if kv_cache_in_gpu else None  # None -> default device
        # Plain tensor attributes (not registered buffers): Module.half()/.cuda() do
        # not convert them, which is why forward() casts them to match the queries.
        self.cache_k = torch.zeros(cache_shape, device=cache_device)
        self.cache_v = torch.zeros(cache_shape, device=cache_device)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
        """Attend one new token over cache positions ``0..start_pos``.

        Args:
            x: hidden states. Batch and sequence length are hard-coded to 1 below,
                so this assumes ``x`` is (1, 1, dim) — TODO confirm for other callers.
            freqs_cis: rotary table passed straight to ``apply_rotary_emb``; the
                caller is expected to supply only the row for ``start_pos``.

        Returns:
            Tensor of shape (1, 1, dim): the ``wo`` projection of the attended values.
        """
        # Constants rather than x.shape, presumably so the traced graph has static shapes.
        bsz, seqlen = 1, 1
        start_pos = self.start_pos
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Match the caches' dtype/device to the (possibly fp16, GPU) activations.
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # (bsz, slen, heads, head_dim) -> (bsz, heads, slen, head_dim)
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        # No causal mask: with a single query token every cached key is in the past.
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        return self.wo(output)

In [11]:
attention = Attention(args=model_args)  # standalone attention module built from the 65B hyper-parameters

In [12]:
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention then feed-forward, each with a residual.

    Args:
        layer_id: index of this layer; stored on the module but not otherwise used here.
        args: model hyper-parameters; reads ``dim``, ``n_heads``, ``multiple_of``
            and ``norm_eps``.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        # Build this layer's own attention module rather than reusing a global
        # instance from another cell: the class is self-contained and each block
        # gets its own weights and KV cache.
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
        """Apply norm -> attention -> residual, then norm -> feed-forward -> residual."""
        # Invoke submodules via __call__ (not .forward) so hooks and export scopes apply.
        h = x + self.attention(self.attention_norm(x), freqs_cis)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

In [13]:
tfb = TransformerBlock(layer_id=0, args=model_args)  # single layer to export

In [14]:
params: ModelArgs = model_args  # NOTE: rebinds `params` (previously the raw JSON dict) to the ModelArgs object

In [15]:
# Dummy fp16 hidden state for one token, (batch=1, seq=1, dim); dim comes from the
# model config instead of a hard-coded 8192 so other model sizes work unchanged.
h = torch.randn((1, 1, model_args.dim), dtype=torch.float16).cuda()

In [16]:
freqs_cis = precompute_freqs_cis(
    params.dim // params.n_heads,  # per-head dimension
    params.max_seq_len * 2,        # number of positions to precompute
)

In [17]:
# Cache slot of the exported decoding step; must equal Attention's start_pos (262).
START_POS = 262
# Keep only the rotary row for START_POS. The guard makes the cell safe to re-run:
# slicing an already-sliced (single-row) table again would yield an empty tensor.
if freqs_cis.shape[0] > 1:
    freqs_cis = freqs_cis[START_POS : START_POS + 1]

### Export the transformer block to ONNX

In [18]:
with torch.no_grad():
    # Cast the block to fp16 on the GPU; the rotary row must sit on the same device as h.
    block_fp16 = tfb.half().cuda()
    export_inputs = (h, freqs_cis.to(h.device))
    torch.onnx.export(
        block_fp16,
        export_inputs,
        'transformerblock_65B_sq1_last_iter_noweight_v1.onnx',
        opset_version=14,
    )

  xq_shape[-1] = int(xq_shape[-1]/2)
  xk_shape[-1] = int(xk_shape[-1]/2)
  assert freqs_cis.shape == (x.shape[1], x.shape[-2], 2)


verbose: False, log level: Level.ERROR



### Export the token embedding

In [20]:
token_embedding = nn.Embedding(
    num_embeddings=model_args.vocab_size,
    embedding_dim=model_args.dim,
)

In [21]:
# One dummy token id in [0, 10), shaped (batch=1, seq=1).
tokens = torch.randint(low=0, high=10, size=(1, 1))

In [22]:
with torch.no_grad():
    embedding_fp16 = token_embedding.half().cuda()
    torch.onnx.export(
        embedding_fp16,
        tokens.cuda(),
        '65B_token_embedding.onnx',
        opset_version=14,
    )

verbose: False, log level: Level.ERROR



### Export the last FC layer (hidden state → vocabulary logits)

In [25]:
last_fc_layer = nn.Linear(
    in_features=model_args.dim,
    out_features=model_args.vocab_size,
    bias=False,
)

In [26]:
# Dummy fp16 hidden state (batch=1, seq=1, dim) fed to the last FC layer; dim comes
# from the model config instead of a hard-coded 8192.
dummy_input = torch.randn((1, 1, model_args.dim), dtype=torch.float16)

In [27]:
with torch.no_grad():
    fc_fp16 = last_fc_layer.half().cuda()
    torch.onnx.export(
        fc_fp16,
        dummy_input.cuda(),
        '65B_last_fc_layer.onnx',
        opset_version=14,
    )

verbose: False, log level: Level.ERROR

