In [1]:
import sys


In [2]:
# Make the local ARMT checkouts importable.
# The membership check keeps this cell idempotent: re-running it does not
# stack duplicate entries onto sys.path.
# NOTE(review): these are absolute, machine-specific paths.
for repo_dir in (
    "/home/jovyan/sivtsov/associative-recurrent-memory-transformer",
    "/home/jovyan/sivtsov/armt",
):
    if repo_dir not in sys.path:
        sys.path.append(repo_dir)

In [3]:
import copy
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

from grouped_batching.llama1b_grouping import wrap_model_with_armt, get_grouped_states, make_grouped_layer_from_single_layer, make_grouped_model_from_naive
from grouped_batching.batching import GroupedBatcher
from grouped_batching.executor import ArmtGroupedExecutor

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Disabled: would make cuda:1 the default device. The models and the input
# tensor are placed on "cuda" explicitly in later cells instead.
# torch.set_default_device("cuda:1")

In [5]:
# Global numeric setup: bfloat16 as the default dtype and autograd switched
# off, since the cells below only run forward passes.
dtype = torch.bfloat16
torch.set_default_dtype(dtype)
# The trailing semicolon (on the same line as the call) keeps IPython from
# echoing the call's return value as the cell output.
torch.set_grad_enabled(False);

''

In [57]:
# Load the base Llama checkpoint with the FlashAttention-2 attention backend
# (alternative kept for reference: attn_implementation="sdpa").
source_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    attn_implementation="flash_attention_2",
    torch_dtype=dtype,
)
source_model.eval()
# Replace the LM head with a pass-through module.
# NOTE(review): presumably done to skip the vocabulary projection, so the
# model's "logits" are really the final hidden states — confirm.
source_model.lm_head = torch.nn.Identity()
# Independent copy of the already-modified model; it is wrapped separately
# later and serves as the reference implementation.
reference_model = copy.deepcopy(source_model)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [58]:
# Config object of the base model; num_hidden_layers and hidden_size are read
# from it when the batcher is built.
model_config = source_model.config

In [59]:
# ARMT settings, forwarded as keyword arguments to wrap_model_with_armt.
# Alternative settings kept for reference:
#   segment_size=32,  num_mem_tokens=16
#   segment_size=512, num_mem_tokens=128
armt_config = {
    "segment_size": 1024,
    "num_mem_tokens": 128,
    "d_mem": 64,
}

In [60]:
# Wrap both base models with ARMT and move them to the GPU.
# The seed is reset before each wrap so that anything the wrapper initialises
# randomly comes out identical in both models.
# NOTE(review): assumes wrap_model_with_armt draws from the global torch RNG — confirm.
torch.manual_seed(0)
armt_model = wrap_model_with_armt(source_model, **armt_config)
armt_model.to("cuda")

torch.manual_seed(0)
armt_reference_model = wrap_model_with_armt(reference_model, **armt_config)
# Trailing semicolon on the same line suppresses the echo of the returned
# module; a lone ";" on its own line is not valid Python and produced a
# spurious '' output.
armt_reference_model.to("cuda");

''

In [61]:
# Path of the ARMT checkpoint inside the local Hugging Face cache.
# NOTE(review): absolute, machine-specific path.
model_cpt = "/home/jovyan/.cache/huggingface/hub/models--irodkin--ARMT-llama3.2-1B/snapshots/746e74bba3edc4cb3eaa11e13df5d900495e2300/armt_llama3.2-1B_step19500.bin"
# weights_only=True restricts unpickling to tensors and plain containers, so
# loading the file cannot execute arbitrary pickled code; it also avoids the
# warning torch.load emits when the argument is left unset.
# NOTE(review): assumes the file is a plain state dict — confirm it loads.
cpt = torch.load(model_cpt, map_location='cuda', weights_only=True)

  cpt = torch.load(model_cpt, map_location='cuda')


In [62]:
# NOTE(review): applying the checkpoint is disabled, so in the cells shown here
# `cpt` is loaded but never copied into either model.
# armt_model.load_state_dict(cpt, strict=False)
# armt_reference_model.load_state_dict(cpt, strict=False)

In [63]:
# Disabled earlier version of the grouped-model construction; the same calls
# run uncommented in a later cell.
# grouped_states = get_grouped_states(armt_model)
# grouped_layer = make_grouped_layer_from_single_layer(
#     copy.deepcopy(armt_model.memory_cell.model.model.layers[0]), *grouped_states)
# # grouped_layer._grouped_execution = True
# # grouped_layer._skip_associating = True
# armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)


In [64]:
from grouped_batching.llama1b_grouping_autograd import make_grouped_training_layer_from_single_layer

In [65]:
# Disabled alternative: build the grouped layer with the helper imported from
# llama1b_grouping_autograd.
# grouped_layer = make_grouped_training_layer_from_single_layer(
#     copy.deepcopy(armt_model.memory_cell.model.model.layers[0]),
#     armt_model.memory_cell.model.model.layers
# )
# armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)

# Active path: collect the grouped states from the wrapped model, then build
# one grouped layer from a copy of layer 0 plus those states.
grouped_states = get_grouped_states(armt_model)
template_layer = copy.deepcopy(armt_model.memory_cell.model.model.layers[0])
grouped_layer = make_grouped_layer_from_single_layer(template_layer, *grouped_states)

# Turn the naive ARMT model into its grouped counterpart.
armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)
    

In [66]:
# Segment length handed to the batcher: configured segment size plus the
# number of memory tokens.
padded_segment_len = armt_config["segment_size"] + armt_config["num_mem_tokens"]

batcher = GroupedBatcher(
    armt_grouped_model,
    n_layers=model_config.num_hidden_layers,
    seg_size=padded_segment_len,
    hid_dim=model_config.hidden_size,
    pos_embed_dim=model_config.hidden_size,
)
# The executor ties together the grouped model, its grouped layer and the batcher.
executor = ArmtGroupedExecutor(armt_grouped_model, grouped_layer, batcher)


In [67]:
# Release unused cached GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [68]:
# Synthetic benchmark input: a single sequence spanning `num_segments` full
# segments, filled with random token ids drawn from [0, 10000).
num_segments = 150
total_tokens = num_segments * armt_config["segment_size"]
input_ids = torch.randint(
    low=0,
    high=10000,
    size=(1, total_tokens),
    dtype=torch.long,
    device="cuda",
)


In [69]:
# Sanity check: expected shape is (1, num_segments * segment_size).
input_ids.shape

torch.Size([1, 153600])

In [85]:
%%time
# Baseline timing: run the full input through the reference (non-grouped) ARMT model.
with torch.no_grad():
    # NOTE(review): presumably resets the memory state before the run — confirm.
    armt_reference_model.memory_cell.zero_mem()
    reference_output = armt_reference_model.forward(input_ids)

# Block until queued CUDA work has finished so the wall time includes it.
torch.cuda.synchronize()

CPU times: user 6.89 s, sys: 43.1 ms, total: 6.93 s
Wall time: 5.59 s


In [86]:
%%time
# Same input through the grouped executor, timed for comparison with the reference run.
# NOTE(review): unlike the reference cell there is no explicit memory reset
# here — confirm executor.forward starts from clean memory.

with torch.no_grad():
    output = executor.forward(input_ids)

# Block until queued CUDA work has finished so the wall time includes it.
torch.cuda.synchronize()

CPU times: user 3.53 s, sys: 3.75 ms, total: 3.54 s
Wall time: 3.54 s


In [87]:
# Relative error of the grouped output: norm of the difference divided by the
# norm of the reference.
# NOTE(review): lm_head was replaced by Identity at load time, so `.logits`
# here is presumably the final hidden state, not vocabulary logits.
torch.norm(output.logits-reference_output.logits)/torch.norm(reference_output.logits)

tensor(0.0104, device='cuda:0')

In [82]:
# Sum of |W_mem| for layer 0, in the executor's model and in the reference model.
# Original note: for some it is zero during all computations.
# NOTE(review): "some" presumably means some layers — confirm.
executor.armt_model.memory_cell.model.model.layers[0].W_mem.abs().sum(), armt_reference_model.memory_cell.model.model.layers[0].W_mem.abs().sum()

(tensor(0., device='cuda:0'), tensor(0.))

In [26]:
# Inspect the `z` attribute of layer 0 in the executor's model.
executor.armt_model.memory_cell.model.model.layers[0].z

tensor([[2.7895e-05, 3.6478e-05, 7.2098e-04,  ..., 2.5269e-02, 3.2471e-02,
         3.5048e-05],
        [1.6504e-01, 8.3594e-01, 0.0000e+00,  ..., 1.2741e-03, 2.8125e+00,
         6.1719e-01],
        [1.1875e+00, 1.0498e-01, 3.5400e-02,  ..., 8.8281e-01, 1.1826e-03,
         9.6094e-01],
        ...,
        [5.8594e-02, 9.1250e+00, 2.6250e+00,  ..., 3.4531e+00, 3.6133e-01,
         9.8877e-03],
        [5.9766e-01, 5.2002e-02, 7.8125e-03,  ..., 5.6763e-03, 0.0000e+00,
         2.5156e+00],
        [6.2988e-02, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 2.1625e+01,
         3.3691e-02]], device='cuda:0')

#### This way you can "batch" several inputs to amortize the cost of the batcher

__BUT:__ for now this works only for discriminative tasks: autoregressive generation requires the memory of every entry in the batch to be preserved, whereas currently only the last segment's memory is preserved.

In [26]:
# Pass a list of inputs to run several sequences through the executor in one
# call; the result is indexed per input.
output_list = executor.forward([input_ids, input_ids])

torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16, 1, 1])
torch.Size([16, 128, 1]) torch.Size([16,

In [39]:
# Both entries came from the same input_ids, so their outputs are expected to match.
torch.allclose(output_list[0].logits, output_list[1].logits)

True