In [1]:
import sys

In [2]:
# Make the parent directory importable (presumably where the local
# `grouped_batching` package imported below lives - TODO confirm).
# Guarded so that re-running this cell does not keep appending duplicates.
if ".." not in sys.path:
    sys.path.append("..")

In [3]:
import copy
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

from grouped_batching.llama1b_grouping import (
    wrap_model_with_armt, get_grouped_states, 
    make_grouped_layer_from_single_layer, make_grouped_model_from_naive,
    make_grouped_sliced_layer_from_single_layer
)
from grouped_batching.batching import GroupedBatcher
from grouped_batching.executor import ArmtGroupedExecutor
from grouped_batching.fast_executor import FastGroupedArmtExecutor, GroupedLayerContext, associate_with_context, update_mem_with_context



In [4]:
# torch.set_default_device("cuda:1")

In [5]:
# Global numeric setup for the whole notebook: every tensor/module created
# without an explicit dtype uses bfloat16, and autograd is switched off
# because the cells below only run forward passes.
dtype = torch.bfloat16
torch.set_default_dtype(dtype)
# The trailing semicolon suppresses the repr of the returned object.
# (A semicolon on a line of its own is not valid Python and made IPython
# print a stray '' as the cell output.)
torch.set_grad_enabled(False);

''

In [6]:
# Single source of truth for the checkpoint: the model and the tokenizer are
# both loaded from this id, so switching checkpoints cannot leave them out of
# sync. Alternatives left by the author:
#   "meta-llama/Llama-3.2-3B"
#   "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
#   "JackFram/llama-160m"
MODEL_NAME = "meta-llama/Llama-3.2-1B"

# The model is created on CPU here, which makes transformers warn about
# Flash Attention 2 on a non-GPU model; it is moved to CUDA in a later cell
# after being wrapped.
source_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation="flash_attention_2",  # alternative: "sdpa"
    torch_dtype=dtype,
)
source_model.eval()
# Replace the LM head with a no-op, so the model output's `.logits` field
# carries whatever is fed into the head instead of vocabulary logits.
source_model.lm_head = torch.nn.Identity()
# Independent copy used as the non-grouped reference; a deep copy keeps its
# modules separate from `source_model` (the alternative
# `reference_model = source_model` would share them).
reference_model = copy.deepcopy(source_model)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [7]:
# Configuration object of the base model; later cells read
# `num_hidden_layers` and `hidden_size` from it.
model_config = source_model.config

In [8]:
# Keyword arguments forwarded to `wrap_model_with_armt` (via `**armt_config`).
# Alternative segment sizes left by the author: 512, 2048, 4096.
armt_config = {
    "segment_size": 1024,
    "num_mem_tokens": 128,
    "d_mem": 64,
}

In [9]:
# Wrap both base models with ARMT and move them to the GPU.
# The RNG is re-seeded with the same value before each wrap so that any
# random initialisation performed inside `wrap_model_with_armt` is identical
# for the two models (presumably the ARMT-specific parameters - TODO confirm).
torch.manual_seed(0)
armt_model = wrap_model_with_armt(source_model, **armt_config)
armt_model.to("cuda")

torch.manual_seed(0)
armt_reference_model = wrap_model_with_armt(reference_model, **armt_config)
# The trailing semicolon suppresses the module repr as cell output.
# (A semicolon on a line of its own is not valid Python and made IPython
# print a stray '' instead.)
armt_reference_model.to("cuda");

''

In [10]:
# grouped_states = get_grouped_states(armt_model)
# grouped_layer = make_grouped_layer_from_single_layer(
#     copy.deepcopy(armt_model.memory_cell.model.model.layers[0]), *grouped_states)
# # grouped_layer._grouped_execution = True
# # grouped_layer._skip_associating = True
# armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)


In [11]:
from grouped_batching.llama1b_grouping_autograd import make_grouped_training_layer_from_single_layer

In [12]:
### TRAINABLE VERSION

# grouped_layer = make_grouped_training_layer_from_single_layer(
#     copy.deepcopy(armt_model.memory_cell.model.model.layers[0]),
#     armt_model.memory_cell.model.model.layers
# )
# armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)

### AMORTIZABLE VERSION

# grouped_states = get_grouped_states(armt_model)
# grouped_layer = make_grouped_layer_from_single_layer(
#         copy.deepcopy(armt_model.memory_cell.model.model.layers[0]), *grouped_states)
    
# armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)

# batcher = GroupedBatcher(
#     armt_grouped_model, 
#     n_layers=model_config.num_hidden_layers, 
#     seg_size=armt_config["segment_size"]+armt_config["num_mem_tokens"], 
#     hid_dim=model_config.hidden_size, 
#     pos_embed_dim=model_config.hidden_size
# )
# executor = ArmtGroupedExecutor(armt_grouped_model, grouped_layer, batcher)

### FAST LATENCY VERSION

# Context object shared between the sliced grouped layer and the executor:
# the same instance is handed to both constructors below.
grouped_context = GroupedLayerContext()

# Build one grouped (sliced) layer from a deep copy of the first ARMT layer
# plus the grouped states extracted from the whole model.
grouped_states = get_grouped_states(armt_model)
grouped_layer = make_grouped_sliced_layer_from_single_layer(
    grouped_context,
    copy.deepcopy(armt_model.memory_cell.model.model.layers[0]),
    *grouped_states,
)

# Produce the grouped model from the naive one using the grouped layer.
armt_grouped_model, source_model_layers = make_grouped_model_from_naive(
    armt_model,
    grouped_layer,
)

executor = FastGroupedArmtExecutor(
    armt_grouped_model, grouped_layer, grouped_context, model_config.num_hidden_layers
)


In [13]:
### ONLY FOR FAST LATENCY VERSION

# Warm-up pass for the fast grouped path, run once before the timed cells.
# Executing this cell prints a "jit compile" message and a CUTLASS grouped-GEMM
# operator name (see the cell output), so it appears to trigger kernel
# compilation ahead of the benchmark.

# compile full layers
# Random dummy activations: one (segment_size, hidden_size) slab per layer.
segments_input = torch.rand((model_config.num_hidden_layers, armt_config['segment_size'], model_config.hidden_size), device="cuda", dtype=dtype)

# Layer range covered by this warm-up: all layers (0 .. num_hidden_layers).
# NOTE: `i` and `j` are notebook-level names; the next cell reads them too.
i, j = 0, model_config.num_hidden_layers
grouped_context.start_idx = i
grouped_context.end_idx = j
grouped_context.is_full = True

# The three calls below run in this order: associate, model forward, memory
# update. The return value `ao` is not used afterwards.
ao = associate_with_context(grouped_layer, grouped_context, segments_input[i:j])
grouped_layer.generate_mode = True
armt_grouped_model.memory_cell.model.model(inputs_embeds=segments_input[i:j], use_cache=False)
# NOTE(review): this updates memory using random inputs - presumably the
# executor resets memory before the real run; confirm.
update_mem_with_context(grouped_layer, grouped_context, segments_input[i:j])

# del ao
# del segments_input



jit compile As: [torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048])] Bs: [torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64])]

// Gemm operator cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8
using cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base =
  typenam

In [14]:
# Display the shape of the warm-up slice. Depends on the notebook-level `i`
# and `j` still holding the values assigned in the warm-up cell.
segments_input[i:j].shape

torch.Size([16, 1024, 2048])

In [15]:
# Release cached, currently unused GPU memory held by PyTorch's caching
# allocator before building the benchmark input.
torch.cuda.empty_cache()

In [16]:
# Total sequence lengths listed by the author: 4096, 8192, 16384, 32768, 65536, 131072
segment_len = armt_config["segment_size"]
num_segments = 32768 // segment_len

# A single random sequence of token ids drawn from [0, 10000), whose length
# is a whole number of segments.
input_ids = torch.randint(
    low=0,
    high=10000,
    size=(1, num_segments * segment_len),
    dtype=torch.long,
    device="cuda",
)


In [17]:
# Sanity check: expected shape is (1, num_segments * segment_size).
input_ids.shape

torch.Size([1, 32768])

In [18]:
%%time
# %%timeit

# Baseline timing: full forward pass of the reference (non-grouped) ARMT model.
# NOTE(review): no separate warm-up of the reference model is visible before
# this cell, whereas the grouped path is warmed up in an earlier cell - the
# first-call overhead may therefore be included in this measurement; confirm
# (e.g. by switching to %%timeit).
with torch.no_grad():
    # armt_reference_model.memory_cell.zero_mem()
    armt_reference_model.memory_cell.generate_mode(False)
    reference_output = armt_reference_model.forward(input_ids)

# Block until all queued CUDA work has finished so the measured time covers
# the actual GPU execution, not just the kernel launches.
torch.cuda.synchronize()

CPU times: user 2.3 s, sys: 55.1 ms, total: 2.36 s
Wall time: 1.83 s


In [19]:
# del reference_output

In [20]:
%%time
# %%timeit

# Timing of the grouped executor on the same input as the reference run.
with torch.no_grad():
    output = executor.forward(input_ids, skip_concat=False)

# Block until all queued CUDA work has finished so the measured time covers
# the actual GPU execution, not just the kernel launches.
torch.cuda.synchronize()

CPU times: user 669 ms, sys: 3.3 ms, total: 673 ms
Wall time: 672 ms


In [21]:
# del output

In [22]:
# Relative error of the grouped output against the reference output:
# ||grouped - reference|| / ||reference||.
# Note: the LM head was replaced by `torch.nn.Identity()` when the models were
# loaded, so `.logits` here is the input of the head, not vocabulary logits.
torch.norm(output.logits-reference_output.logits)/torch.norm(reference_output.logits)

tensor(0.0103, device='cuda:0')

In [23]:
# del output

In [24]:
def _relative_error(reference, candidate, eps=1e-6):
    """Return ||reference - candidate|| / (||reference|| + eps).

    `eps` keeps the ratio finite when the reference tensor is all zeros.
    """
    return torch.norm(reference - candidate) / (torch.norm(reference) + eps)


# Compare the per-layer memory state of the reference model with the matching
# slice of the grouped model, whose single grouped layer (index 0) holds the
# state of every original layer stacked along its first dimension.
reference_layers = armt_reference_model.memory_cell.model.model.layers
grouped_first_layer = executor.armt_model.memory_cell.model.model.layers[0]

# The loop variable is deliberately not called `i`: another cell reads a
# notebook-level `i` (`segments_input[i:j]`), which a loop over `i` here
# would silently overwrite.
for layer_idx, reference_layer in enumerate(reference_layers):
    rel_norm_mem = _relative_error(
        reference_layer.W_mem, grouped_first_layer.W_mem[layer_idx]
    )
    rel_norm_zvalue = _relative_error(
        reference_layer.z, grouped_first_layer.z[layer_idx]
    )

    # NOTE(review): an error of exactly 0.0 is also what this formula returns
    # when both tensors are all zeros - worth verifying that `W_mem` is
    # actually populated when reading the results.
    print(f"Layer {layer_idx}: rel_norm_mem: {rel_norm_mem}, rel_norm_zvalue: {rel_norm_zvalue}")

Layer 0: rel_norm_mem: 0.0, rel_norm_zvalue: 0.000820159912109375
Layer 1: rel_norm_mem: 0.0, rel_norm_zvalue: 0.005950927734375
Layer 2: rel_norm_mem: 0.0, rel_norm_zvalue: 0.00567626953125
Layer 3: rel_norm_mem: 0.0, rel_norm_zvalue: 0.0086669921875
Layer 4: rel_norm_mem: 0.0, rel_norm_zvalue: 0.01544189453125
Layer 5: rel_norm_mem: 0.0, rel_norm_zvalue: 0.0142822265625
Layer 6: rel_norm_mem: 0.0, rel_norm_zvalue: 0.00787353515625
Layer 7: rel_norm_mem: 0.0, rel_norm_zvalue: 0.0211181640625
Layer 8: rel_norm_mem: 0.0, rel_norm_zvalue: 0.01312255859375
Layer 9: rel_norm_mem: 0.0, rel_norm_zvalue: 0.01544189453125
Layer 10: rel_norm_mem: 0.0, rel_norm_zvalue: 0.00909423828125
Layer 11: rel_norm_mem: 0.0, rel_norm_zvalue: 0.015869140625
Layer 12: rel_norm_mem: 0.0, rel_norm_zvalue: 0.0084228515625
Layer 13: rel_norm_mem: 0.0, rel_norm_zvalue: 0.01214599609375
Layer 14: rel_norm_mem: 0.0, rel_norm_zvalue: 0.00836181640625
Layer 15: rel_norm_mem: 0.0, rel_norm_zvalue: 0.00982666015625


#### this way you can "batch" several inputs to amortize the cost of the batcher

In [26]:
### ONLY FOR AMORTIZABLE VERSION

# Pass a list of inputs so the executor processes them together; the result is
# indexed per input in the next cell.
# NOTE(review): the setup cell currently builds a `FastGroupedArmtExecutor`,
# while this cell is marked as amortizable-only and its execution count is out
# of sequence - the recorded output presumably comes from a session where the
# (commented-out) `ArmtGroupedExecutor` setup was active. Confirm that this
# cell still works under Restart & Run All.
output_list = executor.forward([input_ids, input_ids])

torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16,

In [39]:
# Both list entries were produced from the same `input_ids`, so their outputs
# are expected to match; requires `output_list` from the batched forward cell.
torch.allclose(output_list[0].logits, output_list[1].logits)

True