In [1]:
import torch
from torch import nn
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# a = torch.rand((batch_size, M, K), device=device, dtype=dtype)
# b_single = torch.rand((K, N), device=device, dtype=dtype)
# b_batch = torch.rand((batch_size, K, N), device=device, dtype=dtype)

In [3]:
# Model width: passed to LlamaConfig as hidden_size and used as the last
# dimension of every test tensor created in this notebook.
hidden_size = 4096

In [4]:
def _pick_num_heads(hidden_size: int, preferred_head_dim: int = 64) -> int:
    # choose a divisor of hidden_size close to hidden_size / preferred_head_dim
    candidates = [d for d in range(1, hidden_size + 1) if hidden_size % d == 0]
    target = max(1, hidden_size // preferred_head_dim)
    return min(candidates, key=lambda d: abs(d - target))

In [5]:
# Configuration for a one-layer Llama model sized from `hidden_size`, with the
# head count chosen by _pick_num_heads.
n_heads = _pick_num_heads(hidden_size)
cfg = LlamaConfig(
    hidden_size=hidden_size,
    num_attention_heads=n_heads,
    intermediate_size=4 * hidden_size,
    num_hidden_layers=1,
    max_position_embeddings=4096,
    attention_bias=False,
    torch_dtype=torch.bfloat16,
    # Other attention backends tried here: "flash_attention_2", "eager".
    attn_implementation="sdpa",
)

# nn.Module.to() returns the module itself, so build and move in one step.
layer = LlamaDecoderLayer(cfg, layer_idx=0).to("cuda", dtype=torch.bfloat16)
layer

LlamaDecoderLayer(
  (self_attn): LlamaAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
  )
  (mlp): LlamaMLP(
    (gate_proj): Linear(in_features=4096, out_features=16384, bias=False)
    (up_proj): Linear(in_features=4096, out_features=16384, bias=False)
    (down_proj): Linear(in_features=16384, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): LlamaRMSNorm((4096,), eps=1e-06)
  (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-06)
)

In [6]:
# Display the module tree of the decoder layer built in the previous cell.
layer

LlamaDecoderLayer(
  (self_attn): LlamaAttention(
    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
  )
  (mlp): LlamaMLP(
    (gate_proj): Linear(in_features=4096, out_features=16384, bias=False)
    (up_proj): Linear(in_features=4096, out_features=16384, bias=False)
    (down_proj): Linear(in_features=16384, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): LlamaRMSNorm((4096,), eps=1e-06)
  (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-06)
)

In [7]:
# Sequence length shared by every tensor in the benchmarks.
seq_len = 2048
# Random bf16 activations on the GPU with shape (16, seq_len, hidden_size).
test_inp = torch.rand(16, seq_len, hidden_size, dtype=torch.bfloat16, device="cuda")

In [8]:
%%time
# Time one forward pass of the layer's MLP sub-block on the test batch.
_ = layer.mlp(test_inp)

# CUDA kernels are queued asynchronously; wait for them so the reported wall
# time covers the GPU work and not just the launch.
torch.cuda.synchronize()

CPU times: user 126 ms, sys: 81.4 ms, total: 208 ms
Wall time: 207 ms


In [9]:
%%time
# Create dummy attention mask and position embeddings (cos, sin tuple for RoPE)
attention_mask = torch.ones((1, 1, seq_len, seq_len), device="cuda", dtype=torch.bool)
head_dim = hidden_size // n_heads
cos = torch.rand((1, seq_len, head_dim), device="cuda", dtype=torch.bfloat16)
sin = torch.rand((1, seq_len, head_dim), device="cuda", dtype=torch.bfloat16)
position_embeddings = (cos, sin)

_ = layer.self_attn(test_inp, attention_mask=attention_mask, position_embeddings=position_embeddings)

torch.cuda.synchronize()

CPU times: user 79.4 ms, sys: 35.7 ms, total: 115 ms
Wall time: 114 ms


In [10]:
# Drop the throwaway `_` binding, which still references the result of the
# timed self-attention call (presumably to let its GPU memory be reused).
del _

In [11]:
# Number of rounds run by both timing loops below.
n_iters = 12
# Batch size of the single-layer run, and the number of layers / CUDA streams
# in the multi-stream run.
# n_groups = 32
n_groups = 16

In [22]:
# Fresh random bf16 input on the GPU with one batch row per group:
# shape (n_groups, seq_len, hidden_size).
test_inp = torch.rand((n_groups, seq_len, hidden_size), device="cuda", dtype=torch.bfloat16)

In [23]:
# Separate copy of test_inp; the single-layer loop below repeatedly rebinds
# this name to its own output, while test_inp itself is reused later.
inp_concrete = test_inp.clone()

In [24]:
# Placeholder attention inputs for the single-layer loop: an all-ones boolean
# mask and random (cos, sin) tensors standing in for the RoPE embeddings.
head_dim = hidden_size // n_heads
attention_mask = torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool, device="cuda")
cos = torch.rand(1, seq_len, head_dim, dtype=torch.bfloat16, device="cuda")
sin = torch.rand(1, seq_len, head_dim, dtype=torch.bfloat16, device="cuda")
position_embeddings = (cos, sin)

In [25]:
%%time 

for i in range(n_iters):
    inp_concrete = layer.mlp(inp_concrete)
    inp_concrete, _ = layer.self_attn(
        inp_concrete, 
        attention_mask=attention_mask, 
        position_embeddings=position_embeddings,
        use_cache=False
    )
    # del _


CPU times: user 10.2 ms, sys: 373 μs, total: 10.6 ms
Wall time: 9.6 ms


In [26]:
# Drop the references left by the single-layer loop: its final output and the
# second value returned by self_attn.
del inp_concrete
del _

In [27]:
# One independent decoder layer per group, each built from the shared config
# and moved to the GPU in bf16 (nn.Module.to returns the module itself).
layers = [
    LlamaDecoderLayer(cfg, layer_idx=0).to("cuda", dtype=torch.bfloat16)
    for _ in range(n_groups)
]

In [28]:
# One fresh (1, seq_len, hidden_size) input per iteration, stored as a list so
# the pipelined loop can pick additional_inp[i] for its first stage.
additional_inp = list(
    torch.rand(n_iters, seq_len, hidden_size, dtype=torch.bfloat16, device="cuda").split(1, dim=0)
)

In [29]:
# Per-stage input slots for the pipelined loop: a copy of test_inp split along
# dim 0 into single-row tensors, one list entry per row.
inp_sol_concrete = list(test_inp.clone().split(1, dim=0))

In [30]:
# One CUDA stream per group.
streams = [torch.cuda.Stream() for _ in range(n_groups)]
# One event per adjacent pair of streams: in the pipelined loop, event k is
# recorded by stage k and waited on by stage k+1.
ready_from_prev = [torch.cuda.Event() for _ in range(n_groups-1)]

In [33]:
# Placeholder attention inputs for the pipelined loop: an all-ones boolean
# mask and random (cos, sin) tensors standing in for the RoPE embeddings.
head_dim = hidden_size // n_heads
attention_mask = torch.ones(1, 1, seq_len, seq_len, dtype=torch.bool, device="cuda")
cos = torch.rand(1, seq_len, head_dim, dtype=torch.bfloat16, device="cuda")
sin = torch.rand(1, seq_len, head_dim, dtype=torch.bfloat16, device="cuda")
position_embeddings = (cos, sin)

In [34]:
%%time
for i in range(n_iters):
    for j in range(n_groups):
        with torch.cuda.stream(streams[j]):
            if j > 0:
                streams[j].wait_event(ready_from_prev[j-1])
                # pass
            else:
                inp_sol_concrete[0] = additional_inp[i]
            cur_inp = inp_sol_concrete[j]
             
            cur_inp = layers[j].mlp(cur_inp)

            cur_inp, _ = layer.self_attn(
                cur_inp, 
                attention_mask=attention_mask, 
                position_embeddings=position_embeddings,
                use_cache=False
                )
            
            if j < n_groups - 1:
                inp_sol_concrete[j+1] = cur_inp
                ready_from_prev[j].record()
                
torch.cuda.synchronize()

CPU times: user 1.23 s, sys: 2.48 ms, total: 1.23 s
Wall time: 1.23 s


In [62]:
# Drop the per-stage input slots and the inter-stage events used by the
# pipelined loop.
del inp_sol_concrete
del ready_from_prev

In [71]:
# Inspect the shape of whatever `cur_inp` currently refers to.
# NOTE(review): the saved output does not look like it came from the cells
# shown in this notebook (execution counts jump) -- confirm with a fresh run.
cur_inp.shape

torch.Size([256, 512])