In [1]:
import sys


In [2]:
sys.path.append("/home/jovyan/sivtsov/associative-recurrent-memory-transformer")
sys.path.append("/home/jovyan/sivtsov/armt")

In [3]:
import copy
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

from grouped_batching.llama1b_grouping import (
    wrap_model_with_armt, get_grouped_states, 
    make_grouped_layer_from_single_layer, make_grouped_model_from_naive,
    make_grouped_sliced_layer_from_single_layer
)
from grouped_batching.batching import GroupedBatcher
from grouped_batching.executor import ArmtGroupedExecutor
from grouped_batching.fast_executor import FastGroupedArmtExecutor, GroupedLayerContext, associate_with_context, update_mem_with_context

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# torch.set_default_device("cuda:1")

In [5]:
dtype = torch.bfloat16
torch.set_default_dtype(dtype)
torch.set_grad_enabled(False)
;

''

In [6]:
source_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B"
                                            #  , attn_implementation="sdpa"
                                            , attn_implementation="flash_attention_2"
                                             ,torch_dtype=dtype)
source_model.eval()
source_model.lm_head = torch.nn.Identity()
reference_model = copy.deepcopy(source_model)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [7]:
model_config = source_model.config

In [8]:
armt_config = dict(
    # segment_size=32,
    # num_mem_tokens=16,
    # segment_size=512,
    # num_mem_tokens=128,
    segment_size=1024,
    num_mem_tokens=128,
    d_mem=64,
)

In [9]:
torch.manual_seed(0)
armt_model = wrap_model_with_armt(source_model, **armt_config)
armt_model.to("cuda")

torch.manual_seed(0)
armt_reference_model = wrap_model_with_armt(reference_model, **armt_config)
armt_reference_model.to("cuda")
;

''

In [10]:
model_cpt = "/home/jovyan/.cache/huggingface/hub/models--irodkin--ARMT-llama3.2-1B/snapshots/746e74bba3edc4cb3eaa11e13df5d900495e2300/armt_llama3.2-1B_step19500.bin"
cpt = torch.load(model_cpt, map_location='cuda')

  cpt = torch.load(model_cpt, map_location='cuda')


In [11]:
# armt_model.load_state_dict(cpt, strict=False)
# armt_reference_model.load_state_dict(cpt, strict=False)

In [12]:
# grouped_states = get_grouped_states(armt_model)
# grouped_layer = make_grouped_layer_from_single_layer(
#     copy.deepcopy(armt_model.memory_cell.model.model.layers[0]), *grouped_states)
# # grouped_layer._grouped_execution = True
# # grouped_layer._skip_associating = True
# armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)


In [13]:
from grouped_batching.llama1b_grouping_autograd import make_grouped_training_layer_from_single_layer

In [14]:
### TRAINABLE VERSION

# grouped_layer = make_grouped_training_layer_from_single_layer(
#     copy.deepcopy(armt_model.memory_cell.model.model.layers[0]),
#     armt_model.memory_cell.model.model.layers
# )
# armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)

### AMORTIZIBLE VERSION

# grouped_states = get_grouped_states(armt_model)
# grouped_layer = make_grouped_layer_from_single_layer(
#         copy.deepcopy(armt_model.memory_cell.model.model.layers[0]), *grouped_states)
    
# armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)

# batcher = GroupedBatcher(
#     armt_grouped_model, 
#     n_layers=model_config.num_hidden_layers, 
#     seg_size=armt_config["segment_size"]+armt_config["num_mem_tokens"], 
#     hid_dim=model_config.hidden_size, 
#     pos_embed_dim=model_config.hidden_size
# )
# executor = ArmtGroupedExecutor(armt_grouped_model, grouped_layer, batcher)

### FAST LATENCY VERSION

grouped_context = GroupedLayerContext()

grouped_states = get_grouped_states(armt_model)
grouped_layer = make_grouped_sliced_layer_from_single_layer(
    grouped_context, copy.deepcopy(armt_model.memory_cell.model.model.layers[0]), *grouped_states
)
armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)


executor = FastGroupedArmtExecutor(
    armt_grouped_model, 
    grouped_layer, 
    grouped_context, 
    model_config.num_hidden_layers, 
)


In [15]:
### ONLY FOR FAST LATENCY VERSION

# compile full layers
segments_input = torch.rand((model_config.num_hidden_layers, 512, 2048), device="cuda", dtype=dtype)

i, j = 0, 16
grouped_context.start_idx = i
grouped_context.end_idx = j
grouped_context.is_full = True

ao = associate_with_context(grouped_layer, grouped_context, segments_input[i:j])
grouped_layer.generate_mode = True
armt_grouped_model.memory_cell.model.model(inputs_embeds=segments_input[i:j], use_cache=False)
update_mem_with_context(grouped_layer, grouped_context, segments_input[i:j])



jit compile As: [torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048]), torch.Size([512, 2048])] Bs: [torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64])]

// Gemm operator cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8
using cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base =
  typename cutlass::gemm:

In [16]:
torch.cuda.empty_cache()

In [17]:
num_segments = 150
input_ids = torch.randint(
    0, 10000, 
    (1, num_segments*armt_config["segment_size"]), 
    dtype=torch.long, 
    device="cuda"
)


In [18]:
input_ids.shape

torch.Size([1, 153600])

In [19]:
%%time
with torch.no_grad():
    # armt_reference_model.memory_cell.zero_mem()
    armt_reference_model.memory_cell.generate_mode(False)
    reference_output = armt_reference_model.forward(input_ids)

torch.cuda.synchronize()

CPU times: user 6.46 s, sys: 303 ms, total: 6.76 s
Wall time: 6.28 s


In [20]:
%%time

with torch.no_grad():
    output = executor.forward(input_ids)

torch.cuda.synchronize()

CPU times: user 2.92 s, sys: 1.9 ms, total: 2.92 s
Wall time: 2.92 s


In [21]:
torch.norm(output.logits-reference_output.logits)/torch.norm(reference_output.logits)

tensor(0.0104, device='cuda:0')

In [22]:
# for some it is zero during all computations
executor.armt_model.memory_cell.model.model.layers[0].W_mem.abs().sum(), armt_reference_model.memory_cell.model.model.layers[0].W_mem.abs().sum()

(tensor(0., device='cuda:0'), tensor(0., device='cuda:0'))

In [23]:
executor.armt_model.memory_cell.model.model.layers[0].z

tensor([[1.8516e+00, 3.1641e-01, 3.5400e-03,  ..., 8.1177e-03, 1.5545e-04,
         2.9602e-03],
        [1.0986e-02, 2.6001e-02, 3.6523e-01,  ..., 1.3428e-03, 6.4941e-02,
         1.6797e-01],
        [9.3079e-04, 7.9956e-03, 5.3101e-03,  ..., 4.4688e+00, 2.3535e-01,
         5.3101e-03],
        ...,
        [0.0000e+00, 0.0000e+00, 1.0645e-01,  ..., 9.7656e-02, 0.0000e+00,
         1.5312e+00],
        [6.6406e-02, 0.0000e+00, 8.3618e-03,  ..., 9.4238e-02, 5.4688e+00,
         1.1812e+01],
        [5.7188e+00, 2.6367e-01, 6.9336e-02,  ..., 0.0000e+00, 1.8164e-01,
         1.2695e-01]], device='cuda:0')

In [25]:
armt_reference_model.memory_cell.model.model.layers[1].z

tensor([[1.1047e-02, 2.6123e-02, 3.6523e-01, 1.9684e-03, 4.1875e+00, 5.4688e+00,
         2.0599e-04, 1.0312e+00, 4.1246e-05, 3.9577e-05, 3.0518e-02, 1.8311e-03,
         2.7466e-04, 0.0000e+00, 4.4861e-03, 2.4531e+00, 1.0889e-01, 6.4844e-01,
         1.3428e-02, 1.2031e+00, 3.0938e+00, 3.4809e-05, 3.8086e-01, 1.7422e+00,
         1.8516e+00, 1.1562e+00, 1.1230e-02, 6.5002e-03, 1.0791e-01, 9.7656e-04,
         1.1292e-02, 2.0117e-01, 1.1641e+00, 4.9375e+00, 1.0234e+00, 6.0156e-01,
         9.8145e-02, 2.1094e-01, 3.7689e-03, 3.4943e-03, 9.1934e-04, 1.1094e+00,
         1.2500e+00, 2.7812e+00, 5.3750e+00, 2.9449e-03, 7.9727e-04, 1.6406e-01,
         1.0681e-02, 5.4626e-03, 2.7812e+00, 3.1471e-05, 5.9009e-06, 4.5625e+00,
         7.3047e-01, 9.6680e-02, 2.0508e-01, 8.9453e-01, 1.4420e-03, 4.1260e-02,
         1.1292e-02, 1.7090e-02, 4.8828e-02, 1.4258e-01, 4.6094e-01, 1.1658e-02,
         0.0000e+00, 7.7515e-03, 2.0156e+00, 4.1875e+00, 1.9836e-03, 3.5938e+00,
         2.1973e-03, 1.5137e

#### this way you can "batch" several inputs to amortize the cost of the batcher

__BUT:__ this will work only for discriminative tasks for now, because for autoregressive generation memory of all entries in batch should be preserved (currently only last segment memory will be preserved) 

In [26]:
### ONLY FOR AMORTIZABLE VERSION

output_list = executor.forward([input_ids, input_ids])

torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16,

In [39]:
torch.allclose(output_list[0].logits, output_list[1].logits)

True