In [1]:
import torch
import torch.nn as nn

import sys

In [2]:
sys.path.append('..')
sys.path.append("./associative-recurrent-memory-transformer")

In [3]:
from grouped_batching.llama1b_grouping import (
    wrap_model_with_armt, get_grouped_states, 
    make_grouped_layer_from_single_layer, make_grouped_model_from_naive,
    make_grouped_sliced_layer_from_single_layer
)
from grouped_batching.fast_executor import (
    FastGroupedArmtExecutor, GroupedLayerContext, 
    associate_with_context, update_mem_with_context
)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from grouped_batching.universal_grouping import (
    extract_params_from_module, get_universal_grouped_states,
    get_module_by_path, make_universal_grouped_layer,
    set_module_by_path, make_universal_grouped_model
)




In [5]:
dtype = torch.bfloat16
torch.set_default_dtype(dtype)
torch.set_grad_enabled(False)
;

''

In [6]:
import torch.nn as nn
import traceback

def add_trace_to_forward(module, msg):
    original_forward = module.forward

    def traced_forward(*args, **kwargs):
        print(f"\n[TRACE] {msg} {module.__class__.__name__}.forward called from:\n" + ''.join(traceback.format_stack(limit=10)))
        return original_forward(*args, **kwargs)

    module.forward = traced_forward
    return module

In [37]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import copy

MODEL_TYPE = "gpt2" # "gpt2" "llama"


if MODEL_TYPE == "gpt2":
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
    source_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m"
                                                , attn_implementation="flash_attention_2"
                                                ,torch_dtype=dtype)
else:
    tokenizer = AutoTokenizer.from_pretrained("JackFram/llama-160m")
    source_model = AutoModelForCausalLM.from_pretrained("JackFram/llama-160m"
                                                , attn_implementation="flash_attention_2"
                                                ,torch_dtype=dtype)

# Replace all LayerNorm modules in the model with LayerNorm without weights and biases
# for name, module in source_model.named_modules():
#     if isinstance(module, nn.LayerNorm):
#         # Create a new LayerNorm with elementwise_affine=False (no weights and biases)
#         new_layernorm = nn.LayerNorm(
#             normalized_shape=module.normalized_shape,
#             eps=module.eps,
#             elementwise_affine=False
#         )
        
#         # Get the parent module and attribute name to replace the LayerNorm
#         parent_name = '.'.join(name.split('.')[:-1])
#         child_name = name.split('.')[-1]
        
#         if parent_name:
#             parent = source_model
#             for part in parent_name.split('.'):
#                 parent = getattr(parent, part)
#             setattr(parent, child_name, new_layernorm)
#         else:
#             setattr(source_model, child_name, new_layernorm)

# for l in source_model.transformer.h:
    # l.mlp = nn.Identity()
    # l.attn.attention.out_proj = nn.Identity()
    # l.attn.attention = nn.Identity()

# source_model.transformer.h = source_model.transformer.h[:1]

# source_model.transformer.wpe.weight.fill_(0)
# source_model.transformer.wpe.weight.data = torch.tensor([3,3,3], dtype=dtype)
# source_model.transformer.wte.weight.fill_(0)
# source_model.transformer.wte.weight.data = torch.tensor([3,3,3], dtype=dtype)
# add_trace_to_forward(source_model.transformer.wpe, 'wpe')
# add_trace_to_forward(source_model.transformer.wte, 'wte')


# source_model.ln_f = source_model.transformer.ln_f
# source_model.transformer.ln_f = nn.Identity()
# source_model.lm_head = nn.Identity()
reference_model = copy.deepcopy(source_model)

In [38]:
source_model

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPTNeoBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoFlashAttention2(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=False)
            (q_proj): Linear(in_features=768, out_features=768, bias=False)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_

In [39]:
armt_config = dict(
    segment_size=1024,
    num_mem_tokens=128,
    d_mem=64,
)

In [40]:
source_model

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPTNeoBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoFlashAttention2(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=False)
            (q_proj): Linear(in_features=768, out_features=768, bias=False)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_

In [41]:
if MODEL_TYPE == "gpt2":
    layers_attr = 'transformer.h'
else:
    layers_attr = 'model.layers'

torch.manual_seed(0)
armt_model = wrap_model_with_armt(source_model, **armt_config, layers_attr=layers_attr)
armt_model.to("cuda")

torch.manual_seed(0)
armt_reference_model = wrap_model_with_armt(reference_model, **armt_config, layers_attr=layers_attr)
armt_reference_model.to("cuda")
;

''

In [42]:
if MODEL_TYPE == "gpt2":
    grouped_params = get_universal_grouped_states(armt_model.memory_cell.model.transformer.h)
else:
    grouped_params = get_universal_grouped_states(armt_model.memory_cell.model.model.layers)

In [43]:
list(grouped_params.keys())

['W_mq.weight',
 'W_mk.weight',
 'W_mv.weight',
 'W_mb.weight',
 'W_mb.bias',
 'layer.ln_1.weight',
 'layer.ln_1.bias',
 'layer.attn.attention.k_proj.weight',
 'layer.attn.attention.v_proj.weight',
 'layer.attn.attention.q_proj.weight',
 'layer.attn.attention.out_proj.weight',
 'layer.attn.attention.out_proj.bias',
 'layer.ln_2.weight',
 'layer.ln_2.bias',
 'layer.mlp.c_fc.weight',
 'layer.mlp.c_fc.bias',
 'layer.mlp.c_proj.weight',
 'layer.mlp.c_proj.bias',
 'W_mem',
 'z']

In [44]:
grouped_context = GroupedLayerContext()


In [45]:
if MODEL_TYPE == "gpt2":
    layer_base = copy.deepcopy(armt_model.memory_cell.model.transformer.h[0])
else:
    layer_base = copy.deepcopy(armt_model.memory_cell.model.model.layers[0])

grouped_layer = make_universal_grouped_layer(
    grouped_context, 
    layer_base,
    grouped_params,
    use_layer_norm=(MODEL_TYPE == "gpt2"), # layer norm for gpt2, rms norm for llama
)

SUBSTITUTE efficient module_path='W_mq': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='W_mk': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='W_mv': len(weight_values)=12 None 
SUBSTITUTE naive module_path='W_mb': len(weight_values)=12 12 
SUBSTITUTE efficient module_path='layer.attn.attention.k_proj': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='layer.attn.attention.v_proj': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='layer.attn.attention.q_proj': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='layer.attn.attention.out_proj': len(weight_values)=12 12 
SUBSTITUTE efficient module_path='layer.mlp.c_fc': len(weight_values)=12 12 
SUBSTITUTE efficient module_path='layer.mlp.c_proj': len(weight_values)=12 12 
SUBSTITUTE norm_path='layer.ln_1': torch.Size([12, 768]) torch.Size([12, 768]) 
SUBSTITUTE norm_path='layer.ln_2': torch.Size([12, 768]) torch.Size([12, 768]) 


In [46]:
source_model

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x AssociativeLayerWrapper(
        (W_mq): Linear(in_features=768, out_features=64, bias=False)
        (W_mk): Linear(in_features=768, out_features=64, bias=False)
        (W_mv): Linear(in_features=768, out_features=768, bias=False)
        (W_mb): Linear(in_features=768, out_features=1, bias=True)
        (layer): GPTNeoBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPTNeoAttention(
            (attention): GPTNeoFlashAttention2(
              (attn_dropout): Dropout(p=0.0, inplace=False)
              (resid_dropout): Dropout(p=0.0, inplace=False)
              (k_proj): Linear(in_features=768, out_features=768, bias=False)
              (v_proj): Linear(in_features=768, out_features=768, bias=False)
              (q_proj): Linear(in_features

In [47]:
grouped_layer.W_mem.data.shape, grouped_layer.z.data.shape

(torch.Size([12, 384, 768]), torch.Size([12, 384]))

In [48]:
from grouped_batching.universal_executor import UniversalGroupedExecutor

In [49]:
source_model

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x AssociativeLayerWrapper(
        (W_mq): Linear(in_features=768, out_features=64, bias=False)
        (W_mk): Linear(in_features=768, out_features=64, bias=False)
        (W_mv): Linear(in_features=768, out_features=768, bias=False)
        (W_mb): Linear(in_features=768, out_features=1, bias=True)
        (layer): GPTNeoBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPTNeoAttention(
            (attention): GPTNeoFlashAttention2(
              (attn_dropout): Dropout(p=0.0, inplace=False)
              (resid_dropout): Dropout(p=0.0, inplace=False)
              (k_proj): Linear(in_features=768, out_features=768, bias=False)
              (v_proj): Linear(in_features=768, out_features=768, bias=False)
              (q_proj): Linear(in_features

In [50]:
if MODEL_TYPE == "gpt2":
    layers_path = "memory_cell.model.transformer.h"
else:
    layers_path = "memory_cell.model.model.layers"

armt_grouped_model, source_model_layers = make_universal_grouped_model(
    armt_model, 
    grouped_layer,
    layers_path=layers_path
)


In [51]:

def preprocess_segment_gpt2(model, input_segment):
    pos_ids = torch.arange(
        0, 
        input_segment.shape[0], 
        device=input_segment.device
    )
    wpe_emb = model.memory_cell.model.transformer.wpe(pos_ids)
    input_segment.add_(wpe_emb)
    return input_segment

grouped_compute_gpt2 = None # uses default grouped layer

def postprocess_segment_gpt2(model, output_segment):
    out = model.memory_cell.model.transformer.ln_f(output_segment)
    out_tokens = model.memory_cell.model.lm_head(out)
    return out_tokens



preprocess_segment_llama = None

def grouped_compute_llama(model, grouped_layer, grouped_input_tensor):
    position_ids = torch.arange(
        0, 
        grouped_input_tensor.shape[1], 
        device=grouped_input_tensor.device
    ).unsqueeze(0)
    position_embeddings = model.memory_cell.model.model.rotary_emb(grouped_input_tensor, position_ids)
    batch_output = grouped_layer.forward(grouped_input_tensor, position_embeddings=position_embeddings)
    return batch_output
    

def postprocess_segment_llama(model, output_segment):
    out = model.memory_cell.model.model.norm(output_segment)
    out_tokens = model.memory_cell.model.lm_head(out)
    return out_tokens


In [52]:
if MODEL_TYPE == "gpt2":
    executor = UniversalGroupedExecutor(
        model=armt_grouped_model,
        grouped_layer=grouped_layer,
        context=grouped_context,
        n_layers=source_model.config.num_hidden_layers,
        model_path="memory_cell.model.transformer",
        out_norm_attr="memory_cell.model.transformer.ln_f",
        lm_head_path="memory_cell.model.lm_head",
        memory_path="memory_cell",
        preprocess_segment_fn = preprocess_segment_gpt2,
        postprocess_segment_fn = postprocess_segment_gpt2,
        grouped_compute_fn = grouped_compute_gpt2,
    )
else:
    executor = UniversalGroupedExecutor(
        model=armt_grouped_model,
        grouped_layer=grouped_layer,
        context=grouped_context,
        n_layers=source_model.config.num_hidden_layers,
        model_path="memory_cell.model.model",
        out_norm_attr="memory_cell.model.model.norm",
        lm_head_path="memory_cell.model.lm_head",
        memory_path="memory_cell",
        preprocess_segment_fn = preprocess_segment_llama,
        postprocess_segment_fn = postprocess_segment_llama,
        grouped_compute_fn = grouped_compute_llama,
    )

executor.vanilla_model = armt_reference_model

# executor = FastGroupedArmtExecutor(
#     armt_grouped_model, 
#     grouped_layer, 
#     grouped_context, 
#     source_model.config.num_hidden_layers
# )

In [53]:
executor.model.memory_cell.model.model.norm

AttributeError: 'GPTNeoForCausalLM' object has no attribute 'model'

In [54]:
### ONLY FOR FAST LATENCY VERSION

# compile full layers
segments_input = torch.rand((source_model.config.num_hidden_layers, armt_config['segment_size'], source_model.config.hidden_size), device="cuda", dtype=dtype)

i, j = 0, source_model.config.num_hidden_layers
grouped_context.start_idx = i
grouped_context.end_idx = j
grouped_context.is_full = True

ao = associate_with_context(grouped_layer, grouped_context, segments_input[i:j])
grouped_layer.generate_mode = True
if MODEL_TYPE == "gpt2":
    armt_grouped_model.memory_cell.model.transformer(inputs_embeds=segments_input[i:j], use_cache=False)
else:
    armt_grouped_model.memory_cell.model.model(inputs_embeds=segments_input[i:j], use_cache=False)
update_mem_with_context(grouped_layer, grouped_context, segments_input[i:j])

# del ao
# del segments_input

In [55]:
# 4096, 8192, 16384, 32768, 65536, 131072
num_segments = 4096//armt_config["segment_size"]
input_ids = torch.randint(
    0, 10000, 
    (1, num_segments*armt_config["segment_size"]), 
    dtype=torch.long, 
    device="cuda"
)

In [56]:
input_ids.shape

torch.Size([1, 4096])

#### Forward (prefill)

In [57]:
%%time
# %%timeit

torch.manual_seed(0)

with torch.no_grad():
    # armt_reference_model.memory_cell.zero_mem()
    armt_reference_model.memory_cell.generate_mode(False)
    reference_output = armt_reference_model.forward(input_ids)

torch.cuda.synchronize()

CPU times: user 581 ms, sys: 6.47 ms, total: 588 ms
Wall time: 99.3 ms


In [58]:
%%time
# %%timeit

torch.manual_seed(0)

with torch.no_grad():
    output = executor.forward(input_ids, skip_concat=False)

torch.cuda.synchronize()

CPU times: user 60.4 ms, sys: 3.83 ms, total: 64.3 ms
Wall time: 67 ms


In [59]:
torch.norm(output.logits-reference_output.logits)/torch.norm(reference_output.logits)

tensor(0.0332, device='cuda:0')

#### Generate

In [60]:
generate_kwargs = {
    'max_new_tokens': 20,
    'pad_token_id': 0,
    'eos_token_id': 1,
    'attention_mask': None,
    # 'attention_mask': torch.tril(torch.ones((1024, 1024), device='cuda', dtype=bool))
}

In [61]:
gen_out_ref = armt_reference_model.generate(input_ids, **generate_kwargs)

In [62]:
gen_out_ref

tensor([[  13,  405, 1065,   13,  198,  198,    7,   64,    8,   77,  375,  198,
            8,   68,   13,   68,   13,   68,   13,   68]], device='cuda:0')

In [63]:
gen_out = executor.generate(input_ids, seg_size=1024+128, **generate_kwargs)

torch.Size([1, 3072]) torch.Size([1, 1024])


In [64]:
gen_out

(tensor([[  13,  405, 1065,   13,  198,  198,    7,   64,    8,   77,  375,  198,
             8,   68,   13,   68,   13,   68,   13,   68]], device='cuda:0'),
 0.00043964385986328125)

In [65]:
(gen_out_ref == gen_out[0]).all()

tensor(True, device='cuda:0')