In [1]:
import sys

In [2]:
sys.path.append("/home/jovyan/sivtsov/associative-recurrent-memory-transformer")
sys.path.append("/home/jovyan/sivtsov/armt")

In [3]:
import copy
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

from grouped_batching.llama1b_grouping import wrap_model_with_armt, get_grouped_states, make_grouped_layer_from_single_layer, make_grouped_model_from_naive
from grouped_batching.batching import GroupedBatcher
from grouped_batching.executor import ArmtGroupedExecutor

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
dtype = torch.bfloat16
torch.set_default_dtype(dtype)
torch.set_grad_enabled(False)
;

''

In [76]:
source_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B"
                                            #  , attn_implementation="sdpa"
                                            , attn_implementation="flash_attention_2"
                                             ,torch_dtype=dtype)
reference_model = copy.deepcopy(source_model)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

In [77]:
model_config = source_model.config

In [78]:
armt_config = dict(
    # segment_size=32,
    # num_mem_tokens=16,
    segment_size=512,
    num_mem_tokens=128,
    d_mem=64,
)

In [79]:
armt_model = wrap_model_with_armt(source_model, **armt_config)
armt_model.to("cuda")

armt_reference_model = wrap_model_with_armt(reference_model, **armt_config)
armt_reference_model.to("cuda")
;

''

In [80]:
model_cpt = "/home/jovyan/.cache/huggingface/hub/models--irodkin--ARMT-llama3.2-1B/snapshots/746e74bba3edc4cb3eaa11e13df5d900495e2300/armt_llama3.2-1B_step19500.bin"
cpt = torch.load(model_cpt, map_location='cuda')

  cpt = torch.load(model_cpt, map_location='cuda')


In [81]:
# armt_model.load_state_dict(cpt, strict=False)
# armt_reference_model.load_state_dict(cpt, strict=False)

In [82]:
grouped_states = get_grouped_states(armt_model)
grouped_layer = make_grouped_layer_from_single_layer(
    copy.deepcopy(armt_model.memory_cell.model.model.layers[0]), *grouped_states)
grouped_layer._grouped_execution = True
armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)


In [83]:
batcher = GroupedBatcher(
    armt_grouped_model, 
    n_layers=model_config.num_hidden_layers, 
    seg_size=armt_config["segment_size"]+armt_config["num_mem_tokens"], 
    hid_dim=model_config.hidden_size, 
    pos_embed_dim=model_config.hidden_size
)
executor = ArmtGroupedExecutor(armt_grouped_model, grouped_layer, batcher)


In [84]:
num_segments = 5
input_ids = torch.randint(
    0, 5000, 
    (1, num_segments*armt_config["segment_size"]), 
    dtype=torch.long, 
    device="cuda"
)


In [85]:
reference_output = reference_model.forward(input_ids)

In [86]:
output = executor.forward(input_ids)

In [87]:
output.logits

tensor([[[ 5.6250,  5.4688,  7.1250,  ..., -5.0000, -5.0000, -5.0000],
         [ 2.6719,  2.5000,  2.4062,  ...,  2.4062,  2.4062,  2.4062],
         [15.1250, 12.8125, 10.3750,  ..., -0.6445, -0.6445, -0.6445],
         ...,
         [11.5000, 12.2500, 11.5625,  ...,  1.7969,  1.7969,  1.7969],
         [11.5000, 12.2500, 11.3750,  ...,  1.9688,  1.9688,  1.9688],
         [10.8750, 11.2500, 11.0625,  ...,  2.0469,  2.0469,  2.0469]]],
       device='cuda:0')

In [88]:
reference_output.logits

tensor([[[ 5.6250,  5.4688,  7.1250,  ..., -5.0000, -5.0000, -5.0000],
         [ 2.6875,  2.5781,  2.4688,  ...,  2.4375,  2.4375,  2.4375],
         [15.0625, 12.8125, 10.3750,  ..., -0.6406, -0.6406, -0.6406],
         ...,
         [10.1875, 12.7500, 10.6875,  ...,  0.8047,  0.8047,  0.8047],
         [10.7500, 12.7500, 11.1250,  ...,  1.3516,  1.3516,  1.3516],
         [ 9.8125, 12.3750, 10.8750,  ...,  1.3438,  1.3438,  1.3438]]],
       device='cuda:0', dtype=torch.float32)

In [89]:
torch.norm(output.logits-reference_output.logits)/torch.norm(reference_output.logits)

tensor(0.1459, device='cuda:0', dtype=torch.float32)