In [1]:
import sys


In [2]:
sys.path.append("/home/jovyan/sivtsov/associative-recurrent-memory-transformer")
sys.path.append("/home/jovyan/sivtsov/armt")

In [3]:
import copy
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

from grouped_batching.llama1b_grouping import wrap_model_with_armt, get_grouped_states, make_grouped_layer_from_single_layer, make_grouped_model_from_naive
from grouped_batching.batching import GroupedBatcher
from grouped_batching.executor import ArmtGroupedExecutor

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
torch.set_default_device("cuda:1")

In [5]:
dtype = torch.bfloat16
torch.set_default_dtype(dtype)
torch.set_grad_enabled(False)
;

''

In [6]:
source_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B"
                                             , attn_implementation="sdpa"
                                            # , attn_implementation="flash_attention_2"
                                             ,torch_dtype=dtype)
source_model.eval()
source_model.lm_head = torch.nn.Identity()
reference_model = copy.deepcopy(source_model)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

In [7]:
model_config = source_model.config

In [8]:
armt_config = dict(
    # segment_size=32,
    # num_mem_tokens=16,
    segment_size=512,
    num_mem_tokens=128,
    d_mem=64,
)

In [9]:
armt_model = wrap_model_with_armt(source_model, **armt_config)
armt_model.to("cuda")

armt_reference_model = wrap_model_with_armt(reference_model, **armt_config)
armt_reference_model.to("cuda")
;

''

In [10]:
model_cpt = "/home/jovyan/.cache/huggingface/hub/models--irodkin--ARMT-llama3.2-1B/snapshots/746e74bba3edc4cb3eaa11e13df5d900495e2300/armt_llama3.2-1B_step19500.bin"
cpt = torch.load(model_cpt, map_location='cuda')

  cpt = torch.load(model_cpt, map_location='cuda')


In [11]:
# armt_model.load_state_dict(cpt, strict=False)
# armt_reference_model.load_state_dict(cpt, strict=False)

In [12]:
grouped_states = get_grouped_states(armt_model)
grouped_layer = make_grouped_layer_from_single_layer(
    copy.deepcopy(armt_model.memory_cell.model.model.layers[0]), *grouped_states)
# grouped_layer._grouped_execution = True
# grouped_layer._skip_associating = True
armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)


In [13]:
batcher = GroupedBatcher(
    armt_grouped_model, 
    n_layers=model_config.num_hidden_layers, 
    seg_size=armt_config["segment_size"]+armt_config["num_mem_tokens"], 
    hid_dim=model_config.hidden_size, 
    pos_embed_dim=model_config.hidden_size
)
executor = ArmtGroupedExecutor(armt_grouped_model, grouped_layer, batcher)


In [14]:
torch.cuda.empty_cache()

In [31]:
num_segments = 10
input_ids = torch.randint(
    0, 5000, 
    (1, num_segments*armt_config["segment_size"]), 
    dtype=torch.long, 
    device="cuda"
)


In [32]:
armt_reference_model.memory_cell.zero_mem()
reference_output = reference_model.forward(input_ids)

In [33]:
output = executor.forward(input_ids)

In [34]:
output.logits

tensor([[[-0.0247,  0.0835,  1.6250,  ...,  0.5977, -0.5859,  0.3242],
         [ 0.4023,  2.0938,  4.0938,  ..., -2.2969, -0.8828,  2.1875],
         [-0.2637,  4.8438,  3.8125,  ..., -2.0781, -6.3125,  1.6406],
         ...,
         [-1.6328,  4.2812,  0.6875,  ..., -0.7773, -2.3281, -0.4902],
         [-1.1953,  4.3750, -0.3418,  ..., -2.4688, -2.9375,  1.4922],
         [-1.7031,  4.1250,  2.4219,  ..., -0.4258, -2.7656,  0.0928]]],
       device='cuda:0')

In [35]:
reference_output.logits

tensor([[[-0.0247,  0.0835,  1.6250,  ...,  0.5977, -0.5859,  0.3242],
         [ 0.4023,  2.0938,  4.0938,  ..., -2.2969, -0.8828,  2.1875],
         [-0.2637,  4.8438,  3.8125,  ..., -2.0781, -6.3125,  1.6406],
         ...,
         [-1.1562,  4.0000,  0.7461,  ..., -1.9062, -2.4375,  0.4082],
         [-0.9609,  3.7188, -0.2090,  ..., -2.4219, -2.2031,  1.3438],
         [-1.4062,  3.9062,  2.7969,  ..., -2.2188, -2.3281,  0.7578]]],
       device='cuda:0', dtype=torch.float32)

In [36]:
torch.norm(output.logits-reference_output.logits)/torch.norm(reference_output.logits)

tensor(0.3627, device='cuda:0', dtype=torch.float32)

#### this way you can "batch" several inputs to amortize the cost of the batcher

In [38]:
output_list = executor.forward([input_ids, input_ids])

In [39]:
torch.allclose(output_list[0].logits, output_list[1].logits)

True