In [1]:
import torch

# Report whether PyTorch can see a CUDA device; nothing else depends on the result.
cuda_message = (
    "CUDA is available! You can use GPU acceleration."
    if torch.cuda.is_available()
    else "CUDA is not available. Using CPU instead."
)
print(cuda_message)

CUDA is available! You can use GPU acceleration.


# Load the dump model

In [2]:
import contextlib
import io
import os
import random
import warnings


def ids_tensor(shape, vocab_size, rng=None, name=None):
    """Build a tensor of random token ids.

    Args:
        shape: iterable of dimension sizes of the returned tensor.
        vocab_size: ids are drawn with ``rng.randint(0, vocab_size - 1)``,
            i.e. both bounds inclusive.
        rng: optional ``random.Random``-like object; when ``None`` a fresh,
            unseeded ``random.Random()`` is created, so results are not
            reproducible across calls.
        name: accepted but not used by this function.

    Returns:
        A contiguous ``torch.long`` (int64) tensor viewed as ``shape``.
    """
    import torch

    generator = random.Random() if rng is None else rng

    # Number of scalar ids to draw = product of all dimensions.
    element_count = 1
    for extent in shape:
        element_count = element_count * extent

    flat_ids = [generator.randint(0, vocab_size - 1) for _ in range(element_count)]

    return torch.tensor(data=flat_ids, dtype=torch.long).view(shape).contiguous()


def get_llama_model(
    input_dims=((2, 1024),),
    hidden_size=1024,  # 4096,
    num_hidden_layers=1,
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=4,  # 32,
    _attn_implementation="eager",
    with_mask: bool = True,
):
    """Create a randomly initialised Llama model wrapper plus example inputs.

    Args:
        input_dims: iterable of ``(batch, sequence_length)`` pairs; one example
            input tuple is generated per pair. The default is an immutable
            tuple so the default value cannot be mutated between calls.
        hidden_size: forwarded to ``LlamaConfig``.
        num_hidden_layers: forwarded to ``LlamaConfig``.
        vocab_size: forwarded to ``LlamaConfig`` and used as the exclusive
            upper bound of the random token ids.
        intermediate_size: forwarded to ``LlamaConfig``.
        max_position_embeddings: forwarded to ``LlamaConfig``.
        num_attention_heads: forwarded to ``LlamaConfig``.
        _attn_implementation: when truthy, written to
            ``config._attn_implementation``.
        with_mask: accepted but not used by this function; the wrapper always
            takes an attention mask.

    Returns:
        ``(model, example_args_collection)`` where ``model`` is a
        ``torch.nn.Module`` whose ``forward(input_ids, attention_mask)``
        returns ``LlamaModel(...).to_tuple()``, and
        ``example_args_collection`` is a list of ``(input_ids, input_mask)``
        tuples, one per entry of ``input_dims``.
    """
    import torch
    from transformers import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaModel

    config = LlamaConfig(
        num_hidden_layers=num_hidden_layers,
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        max_position_embeddings=max_position_embeddings,
        num_attention_heads=num_attention_heads,
    )
    if _attn_implementation:
        config._attn_implementation = _attn_implementation

    class LlamaModelWrapper(torch.nn.Module):
        """Thin wrapper exposing a positional ``(input_ids, attention_mask)`` forward."""

        def __init__(self, config):
            super().__init__()
            self.model = LlamaModel(config)

        def forward(self, input_ids, attention_mask):
            model_output = self.model(input_ids, attention_mask=attention_mask)
            # Return a plain tuple instead of the ModelOutput object.
            return model_output.to_tuple()

    def generate_example_inputs(batch: int, seq: int, vocab_size: int):
        """Return random ids of shape (batch, seq) and a float32 mask of the same shape."""
        input_ids = ids_tensor([batch, seq], vocab_size)
        # NOTE(review): tril over a (batch, seq) matrix keeps only the first
        # i + 1 positions of row i (row 0 -> one token, row 1 -> two tokens).
        # Confirm this is the intended padding mask and not a leftover of a
        # (seq, seq) causal mask.
        input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))
        return input_ids, input_mask

    example_args_collection = [
        generate_example_inputs(batch, seq, vocab_size) for batch, seq in input_dims
    ]

    return LlamaModelWrapper(config), example_args_collection


# Build the wrapped Llama model with the default arguments of get_llama_model
# and keep the generated example inputs for the cells below.
print("creation of the model.")
model, example_args_collection = get_llama_model()
print("done.")

creation of the model.


  from .autonotebook import tqdm as notebook_tqdm


done.


In [3]:
# Module
# Print the module tree of the wrapper to inspect its sub-modules.
print(model)

LlamaModelWrapper(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 1024)
    (layers): ModuleList(
      (0): LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=1024, out_features=11008, bias=False)
          (up_proj): Linear(in_features=1024, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
)


In [4]:
# Run the model once on the first generated example; the cell output is the
# tuple returned by the wrapper's forward.
input_dim = (2, 1024)
example_arg = example_args_collection[0]
model(*example_arg)

(tensor([[[-1.1170, -0.7974, -1.9838,  ...,  0.6843,  1.6093, -0.8728],
          [-1.1460, -0.9141, -2.0375,  ...,  0.6653,  1.5838, -0.8952],
          [-1.1642, -1.0227, -1.9189,  ...,  0.7274,  1.6082, -0.8445],
          ...,
          [-1.2189, -0.9158, -1.9213,  ...,  0.7216,  1.5668, -0.8586],
          [-1.1780, -0.9684, -1.8898,  ...,  0.7232,  1.4959, -0.9167],
          [-1.1783, -0.8946, -1.9557,  ...,  0.8079,  1.5094, -0.8847]],
 
         [[ 0.0058,  1.1115, -0.4164,  ...,  1.2131,  0.0819, -1.4715],
          [-0.9829,  1.3897,  0.8368,  ..., -0.1238, -0.0239, -1.1874],
          [-2.0979, -0.1517,  0.1348,  ..., -0.5682, -0.5758, -0.7007],
          ...,
          [-0.8932, -0.5153, -0.9346,  ...,  0.8343, -0.2991, -0.7240],
          [-0.2193,  0.0176,  0.0269,  ...,  0.3472, -0.6515, -1.4265],
          [-1.1566,  0.6324,  0.8453,  ..., -0.3906, -0.5370, -1.0049]]],
        grad_fn=<MulBackward0>),
 ((tensor([[[[-3.5316e-01, -2.8751e-01,  2.6478e-01,  ...,  1.0172e+

# About the dump model

The wrapping boils down to:
```
- wrapper
    - model
        - LlamaDecoderLayer
```
The model then is:  
![Llama model](llama.jpg)

Which corresponds on the netron model to:  
![Netron labeled](llama_netron_labeled.jpg)

In [5]:
# Step by step
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import BaseModelOutputWithPast

# Step-by-step replay of the wrapped model's forward pass, calling its
# sub-modules directly so intermediate tensors (notably `mlp_input`) can be
# captured. Only `_model.layers[0]` is replayed, so this covers the whole
# model only when it has a single decoder layer.

# args
input_ids, attention_mask = example_arg
_model = model.model

# Forward
output_attentions = _model.config.output_attentions
output_hidden_states = _model.config.output_hidden_states
use_cache = _model.config.use_cache
return_dict = _model.config.use_return_dict
inputs_embeds = _model.embed_tokens(input_ids)

# No cache is passed in, so start from a DynamicCache built from `None`.
# (The StaticCache checks of the original cell tested a constant and a value
# that is always a DynamicCache, so they were dead code.)
past_key_values = DynamicCache.from_legacy_cache(None)
past_seen_tokens = past_key_values.get_seq_length()

cache_position = torch.arange(
    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

position_ids = cache_position.unsqueeze(0)
causal_mask = _model._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

# embed positions
hidden_states = inputs_embeds

# decoder layers
# Start from empty tuples when the corresponding output is requested so the
# `+=` accumulations below work; `None` means "not collected".
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None

decoder_layer = _model.layers[0]

# NOTE(review): recording the layer input here is meant to match what the
# library's own forward collects — confirm against the installed transformers.
if output_hidden_states:
    all_hidden_states += (hidden_states,)

# forward from decoder layer
# From here on `attention_mask` is the expanded causal mask, not the 2D input mask.
attention_mask = causal_mask
past_key_value = past_key_values

residual = hidden_states
hidden_states = decoder_layer.input_layernorm(hidden_states)  # hidden_states is x

# Self Attention
hidden_states, self_attn_weights, present_key_value = decoder_layer.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_value=past_key_value,
    output_attentions=output_attentions,
    use_cache=use_cache,
    cache_position=cache_position,
)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = decoder_layer.post_attention_layernorm(hidden_states)
mlp_input = hidden_states  # kept as the input of the isolated MLP in the fused-kernel section
hidden_states = decoder_layer.mlp(hidden_states)
hidden_states = residual + hidden_states

outputs = (hidden_states,)

# Keep the tuple layout consistent with the index used below:
# (hidden_states, [self_attn_weights,] present_key_value).
if output_attentions:
    outputs += (self_attn_weights,)

outputs += (present_key_value,)

layer_outputs = outputs
# end of forward from decoder layer

hidden_states = layer_outputs[0]

next_decoder_cache = layer_outputs[2 if output_attentions else 1]

if output_attentions:
    all_self_attns += (layer_outputs[1],)

hidden_states = _model.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
    all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
    next_cache = (
        next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
    )

_return = BaseModelOutputWithPast(
    last_hidden_state=hidden_states,
    past_key_values=next_cache,
    hidden_states=all_hidden_states,
    attentions=all_self_attns,
)

# MLP fused kernel

In [46]:
# Isolate the MLP: take the MLP sub-module of the decoder layer and export it
# to "mlp.onnx", using the captured `mlp_input` as the example input.
mlp = decoder_layer.mlp
torch.onnx.export(mlp, (mlp_input,), "mlp.onnx")

The `mlp` module actually corresponds to a gated multi-layer perceptron.  
**Compare Netron, the gMLP paper, the default implementation and your implementation.**  
The gated MLP returns: `self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))`, that is to say:
$$(\sigma(X \times G^T) \odot (X \times U^T)) \times D^T$$
With the shapes:
- $X: n,d$
- $U: \tilde{d},d$
- $G: \tilde{d},d$
- $D: d,\tilde{d}$  

$\sigma$ corresponds to the $\mathrm{SiLU}$ activation function: $\sigma(x) = x \cdot \mathrm{sigmoid}(x)$

The notations from the original gMLP paper are:  
$$Z=\sigma(X U), \quad \tilde{Z}=s(Z), \quad Y=\tilde{Z} V$$

![MLP netron labeled](mlp_netron_labeled.png)

In [45]:
# Parameters
hidden_size = _model.config.hidden_size
intermediate_size = _model.config.intermediate_size
mlp_input_shape = mlp_input.shape
G = mlp.gate_proj.weight  # gate projection weight
U = mlp.up_proj.weight  # up projection weight
D = mlp.down_proj.weight  # down projection weight

# Forward
X = mlp_input
output = mlp(mlp_input)
act_fn = torch.nn.SiLU()
# Re-compute the gated MLP by hand from the raw weights. The result is bound
# to a real name rather than `_`, which IPython uses for the last cell output.
manual_output = (act_fn(X @ G.T) * (X @ U.T)) @ D.T
assert torch.allclose(output, manual_output), "manual gated-MLP formula disagrees with mlp(mlp_input)"

# Export

In [None]:
# Export: serialise the MLP module and its captured input next to the notebook.
for artifact, path in ((mlp, "mlp.pt"), (mlp_input, "mlp_input.pt")):
    torch.save(artifact, path)