In [1]:
import sys

In [2]:
# Put the directory two levels up on the import path; the `grouped_batching`
# imports in the next cell presumably resolve from there — TODO confirm repo layout.
sys.path.append("../..")

In [3]:
import copy
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

from grouped_batching.llama1b_grouping import (
    wrap_model_with_armt, get_grouped_states, 
    make_grouped_layer_from_single_layer, make_grouped_model_from_naive,
    make_grouped_sliced_layer_from_single_layer
)
from grouped_batching.batching import GroupedBatcher
from grouped_batching.executor import ArmtGroupedExecutor
from grouped_batching.fast_executor import FastGroupedArmtExecutor, GroupedLayerContext, associate_with_context, update_mem_with_context



In [4]:
# torch.set_default_device("cuda:1")

In [5]:
# Global numeric setup for the notebook: new floating-point tensors default to
# bfloat16, and autograd is switched off (no visible cell computes gradients).
dtype = torch.bfloat16
torch.set_default_dtype(dtype)
# The `;` must trail the last statement on the SAME line to suppress the cell's
# output; a lone `;` on its own line is not valid Python and still displayed ''.
torch.set_grad_enabled(False);

''

In [6]:
# Base checkpoint used for the benchmark. Other checkpoints the author left
# commented out at this spot:
#   "meta-llama/Llama-3.2-1B", "meta-llama/Llama-3.2-3B",
#   "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
source_model = AutoModelForCausalLM.from_pretrained(
    "JackFram/llama-160m",
    attn_implementation="flash_attention_2",  # commented-out alternative: "sdpa"
    torch_dtype=dtype,
)
source_model.eval()

# Replace the LM head with a pass-through module
# (presumably to skip the vocabulary projection while timing — confirm).
source_model.lm_head = torch.nn.Identity()

# Independent copy taken AFTER the head swap; the commented-out alternative
# `reference_model = source_model` would alias the same object instead.
reference_model = copy.deepcopy(source_model)

# NOTE(review): the tokenizer repo differs from the model checkpoint above and
# the tokenizer is not used in the visible cells — confirm this is intended.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [7]:
# Keep a handle on the model's config object; later cells read
# `num_hidden_layers` and `hidden_size` from it.
model_config = source_model.config

In [8]:
# ARMT wrapper settings; unpacked as keyword arguments into wrap_model_with_armt,
# and `segment_size` is also read when building the benchmark input.
armt_config = {
    "segment_size": 1024,
    "num_mem_tokens": 128,
    "d_mem": 64,
}

In [None]:
# Fix the RNG seed before wrapping (presumably so that any randomly initialised
# parameters added by the wrapper are reproducible — confirm in wrap_model_with_armt).
torch.manual_seed(0)
armt_model = wrap_model_with_armt(source_model, **armt_config)
# Move the wrapped model to the GPU. The `;` trails the call on the same line so
# the returned object's repr is not displayed; a lone `;` on its own line is not
# valid Python and still showed '' as the cell output.
armt_model.to("cuda");

''

In [14]:
from grouped_batching.llama1b_grouping_autograd import make_grouped_training_layer_from_single_layer

In [15]:

### FAST LATENCY VERSION

# One context object is shared by the grouped layer and the executor below.
grouped_context = GroupedLayerContext()

# Build the grouped layer from a deep copy of decoder layer 0 (so the object
# passed in is not the model's own layer) plus the states pulled from armt_model.
grouped_states = get_grouped_states(armt_model)
template_layer = copy.deepcopy(armt_model.memory_cell.model.model.layers[0])
grouped_layer = make_grouped_sliced_layer_from_single_layer(
    grouped_context, template_layer, *grouped_states
)

# Second return value is kept under the name the helper suggests (the original layers);
# it is not read again in the visible cells.
armt_grouped_model, source_model_layers = make_grouped_model_from_naive(
    armt_model, grouped_layer
)

executor = FastGroupedArmtExecutor(
    armt_grouped_model, grouped_layer, grouped_context, model_config.num_hidden_layers
)


In [16]:
### ONLY FOR FAST LATENCY VERSION

# Warm-up pass ("compile full layers"): run one full grouped step before timing.
# The recorded output of this cell shows a "jit compile" log for a CUTLASS grouped GEMM.
# Random stand-in input of shape (num_hidden_layers, segment_size, hidden_size),
# created on the GPU in `dtype`.
segments_input = torch.rand((model_config.num_hidden_layers, armt_config['segment_size'], model_config.hidden_size), device="cuda", dtype=dtype)

# i:j spans the whole leading dimension, so segments_input[i:j] is the full tensor.
# Both names are reused by the timing cell further down.
i, j = 0, model_config.num_hidden_layers
# Record the active slice on the shared context and flag the group as full
# (presumably the "full load" case — confirm against GroupedLayerContext).
grouped_context.start_idx = i
grouped_context.end_idx = j
grouped_context.is_full = True

# The three calls below are presumably order-sensitive (associate -> forward ->
# memory update); the helpers are opaque from this notebook — verify in fast_executor.
# `ao` is not read again in the visible cells.
ao = associate_with_context(grouped_layer, grouped_context, segments_input[i:j])
grouped_layer.generate_mode = True
# Forward on precomputed embeddings (no token ids) with use_cache=False; the result is discarded.
armt_grouped_model.memory_cell.model.model(inputs_embeds=segments_input[i:j], use_cache=False)
update_mem_with_context(grouped_layer, grouped_context, segments_input[i:j])

# del ao
# del segments_input



jit compile As: [torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768]), torch.Size([512, 768])] Bs: [torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64])]

// Gemm operator cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8
using cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base =
  typename cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::

In [17]:
# Sanity-check the shape fed to the grouped model:
# (num_hidden_layers, segment_size, hidden_size).
# NOTE(review): the recorded output shows 512 on dim 1 while armt_config sets
# segment_size=1024 — the saved outputs look stale; re-run before trusting the timings.
segments_input[i:j].shape

torch.Size([12, 512, 768])

In [18]:
# # %%time


# with torch.profiler.profile(
#     activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
#     record_shapes=True,
#     profile_memory=True,
# ) as prof:
#     # for _ in range(1):
#     #     for __ in range(model_config.num_hidden_layers):
#         # if __ != 0:
#         ao = associate_with_context(grouped_layer, grouped_context, segments_input[i:j])
#         grouped_layer.generate_mode = True
#         _ = armt_grouped_model.memory_cell.model.model(inputs_embeds=segments_input[i:j], use_cache=False)
#         # if __ != 0:
#         update_mem_with_context(grouped_layer, grouped_context, segments_input[i:j])

# torch.cuda.synchronize()

In [19]:
%%time
num_retries = 5
for _ in range(num_retries):
    for __ in range(model_config.num_hidden_layers):
        # if __ != 0:
        ao = associate_with_context(grouped_layer, grouped_context, segments_input[i:j])
        grouped_layer.generate_mode = True
        _ = armt_grouped_model.memory_cell.model.model(
            inputs_embeds=segments_input[i:j], use_cache=False
        )
        # if __ != 0:
        update_mem_with_context(grouped_layer, grouped_context, segments_input[i:j])

torch.cuda.synchronize()

CPU times: user 281 ms, sys: 0 ns, total: 281 ms
Wall time: 281 ms


In [20]:
# Display the layer count; it is the inner-loop length of the timing cell and
# one of the divisors in the per-segment calculation below.
model_config.num_hidden_layers

12

In [20]:
# Ideal-scaling reference for the paper's table: the average time per segment
# under full load, where the group size equals the number of layers.
# The total wall time is divided by num_retries and then by the number of layers.
wall_time_ms = 281  # copied by hand from the "Wall time" of the %%time cell above
wall_time_ms / num_retries / model_config.num_hidden_layers

4.683333333333334