In [207]:
import torch
import torch.nn as nn

import sys

In [208]:
# Make the parent directory and the vendored ARMT checkout importable (assumed to be
# where the grouped_batching / ARMT packages imported below live). Skip paths already
# present so re-running this cell does not keep growing sys.path.
for extra_path in ('..', "./associative-recurrent-memory-transformer"):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

In [209]:
from grouped_batching.llama1b_grouping import (
    wrap_model_with_armt, get_grouped_states, 
    make_grouped_layer_from_single_layer, make_grouped_model_from_naive,
    make_grouped_sliced_layer_from_single_layer
)
from grouped_batching.fast_executor import (
    FastGroupedArmtExecutor, GroupedLayerContext, 
    associate_with_context, update_mem_with_context
)

In [210]:
from grouped_batching.universal_grouping import (
    extract_params_from_module, get_universal_grouped_states,
    get_module_by_path, make_universal_grouped_layer,
    set_module_by_path, make_universal_grouped_model
)




In [211]:
# All new floating-point tensors default to bfloat16 (the models are loaded in it too).
dtype = torch.bfloat16
torch.set_default_dtype(dtype)
# No gradients are needed in this benchmark, so disable autograd globally.
# The trailing ';' suppresses the repr of the returned object; a lone ';' line would be
# IPython auto-quote syntax (it evaluates to '' and is a SyntaxError in plain Python).
torch.set_grad_enabled(False);

''

In [212]:
import torch.nn as nn
import traceback

def add_trace_to_forward(module, msg):
    """Monkey-patch ``module.forward`` so that every call prints who called it.

    The replacement prints a ``[TRACE]`` header (``msg`` plus the module's class name)
    followed by up to 10 frames of the current call stack, then delegates to the
    original bound ``forward`` with the same arguments.

    Args:
        module: any object with a ``forward`` attribute (typically an ``nn.Module``).
        msg: label included in the trace header to tell patched modules apart.

    Returns:
        The same ``module`` object, patched in place.
    """
    wrapped = module.forward

    def traced_forward(*args, **kwargs):
        header = f"\n[TRACE] {msg} {module.__class__.__name__}.forward called from:\n"
        stack_text = ''.join(traceback.format_stack(limit=10))
        print(header + stack_text)
        return wrapped(*args, **kwargs)

    module.forward = traced_forward
    return module

In [213]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import copy

# Backbone to benchmark: "gpt2" (GPT-Neo 125M) or "llama" (Llama 160M).
MODEL_TYPE = "llama" # "gpt2" "llama"

# Hugging Face checkpoint for each supported MODEL_TYPE.
MODEL_NAMES = {
    "gpt2": "EleutherAI/gpt-neo-125m",
    "llama": "JackFram/llama-160m",
}
# Fail loudly on a typo instead of silently falling back to one of the backbones.
if MODEL_TYPE not in MODEL_NAMES:
    raise ValueError(f"Unknown MODEL_TYPE={MODEL_TYPE!r}; expected one of {sorted(MODEL_NAMES)}")
model_name = MODEL_NAMES[MODEL_TYPE]

tokenizer = AutoTokenizer.from_pretrained(model_name)
source_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    torch_dtype=dtype,
)

# Keep an independent copy of the plain backbone. wrap_model_with_armt (used a few cells
# below) modifies the model it wraps in place: afterwards source_model's own layers print
# as AssociativeLayerWrapper. The reference ARMT model therefore has to be built from its
# own copy, taken before any wrapping happens.
reference_model = copy.deepcopy(source_model)

In [214]:
# Module tree of the freshly loaded Hugging Face backbone (before any ARMT wrapping).
source_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 768)
    (layers): ModuleList(
      (0-11): 12 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=768, out_features=768, bias=False)
          (k_proj): Linear(in_features=768, out_features=768, bias=False)
          (v_proj): Linear(in_features=768, out_features=768, bias=False)
          (o_proj): Linear(in_features=768, out_features=768, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=768, out_features=3072, bias=False)
          (up_proj): Linear(in_features=768, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=768, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((768,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((768,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((768,), eps=1e-06)
    (rotary_emb): LlamaRotaryEm

In [215]:
# ARMT wrapper hyper-parameters, shared by the grouped model and the reference model.
armt_config = {
    # Tokens per segment; the benchmark input length below is a multiple of it.
    "segment_size": 1024,
    # Number of memory tokens; NOTE(review): exact placement inside a segment is defined by the ARMT wrapper — confirm.
    "num_mem_tokens": 128,
    # Width of the associative-memory keys (W_mq / W_mk project 768 -> 64 in the wrapped model).
    "d_mem": 64,
}

In [216]:
# NOTE(review): same display as the earlier `source_model` cell. Only armt_config was
# defined in between, so the output is identical; this cell is a candidate for removal.
source_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 768)
    (layers): ModuleList(
      (0-11): 12 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=768, out_features=768, bias=False)
          (k_proj): Linear(in_features=768, out_features=768, bias=False)
          (v_proj): Linear(in_features=768, out_features=768, bias=False)
          (o_proj): Linear(in_features=768, out_features=768, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=768, out_features=3072, bias=False)
          (up_proj): Linear(in_features=768, out_features=3072, bias=False)
          (down_proj): Linear(in_features=3072, out_features=768, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((768,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((768,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((768,), eps=1e-06)
    (rotary_emb): LlamaRotaryEm

In [217]:
# Attribute path of the decoder-layer list inside each Hugging Face model.
if MODEL_TYPE == "gpt2":
    layers_attr = 'transformer.h'
else:
    layers_attr = 'model.layers'

# Re-seed before each wrap so that both ARMT wrappers get the same random initialisation
# of whatever wrap_model_with_armt creates (e.g. the W_mq/W_mk/W_mv/W_mb projections).
# Presumably this is what lets the grouped and reference outputs be compared exactly.
torch.manual_seed(0)
armt_model = wrap_model_with_armt(source_model, **armt_config, layers_attr=layers_attr)
armt_model.to("cuda")

torch.manual_seed(0)
armt_reference_model = wrap_model_with_armt(reference_model, **armt_config, layers_attr=layers_attr)
# Trailing ';' suppresses the model repr; a lone ';' line would be IPython auto-quote
# syntax that displays '' and is a SyntaxError in plain Python.
armt_reference_model.to("cuda");

''

In [218]:
# Collect the parameters of all wrapped decoder layers, keyed by their path inside one
# layer (see the key list printed below); presumably one entry per layer is stacked.
grouped_params = get_universal_grouped_states(
    armt_model.memory_cell.model.transformer.h
    if MODEL_TYPE == "gpt2"
    else armt_model.memory_cell.model.model.layers
)

In [128]:
# Parameter names gathered per layer: the AssociativeLayerWrapper projections
# (W_mq, W_mk, W_mv, W_mb), the wrapped decoder layer's weights, and W_mem / z
# (presumably the associative-memory state).
list(grouped_params.keys())

['W_mq.weight',
 'W_mk.weight',
 'W_mv.weight',
 'W_mb.weight',
 'W_mb.bias',
 'layer.self_attn.q_proj.weight',
 'layer.self_attn.k_proj.weight',
 'layer.self_attn.v_proj.weight',
 'layer.self_attn.o_proj.weight',
 'layer.mlp.gate_proj.weight',
 'layer.mlp.up_proj.weight',
 'layer.mlp.down_proj.weight',
 'layer.input_layernorm.weight',
 'layer.post_attention_layernorm.weight',
 'W_mem',
 'z']

In [129]:
# Mutable context object shared by the grouped layer and the executor; its
# start_idx / end_idx / is_full fields are set explicitly in the warm-up cell.
grouped_context = GroupedLayerContext()


In [130]:
# Template for the grouped layer: a private deep copy of the first wrapped decoder layer,
# so whatever substitution make_universal_grouped_layer performs cannot touch
# armt_model's own layers.
layer_base = copy.deepcopy(
    armt_model.memory_cell.model.transformer.h[0]
    if MODEL_TYPE == "gpt2"
    else armt_model.memory_cell.model.model.layers[0]
)

grouped_layer = make_universal_grouped_layer(
    grouped_context,
    layer_base,
    grouped_params,
    use_layer_norm=(MODEL_TYPE == "gpt2"),  # LayerNorm for GPT-2, RMSNorm for Llama
)

SUBSTITUTE efficient module_path='W_mq': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='W_mk': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='W_mv': len(weight_values)=12 None 
SUBSTITUTE naive module_path='W_mb': len(weight_values)=12 12 
SUBSTITUTE efficient module_path='layer.self_attn.q_proj': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='layer.self_attn.k_proj': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='layer.self_attn.v_proj': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='layer.self_attn.o_proj': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='layer.mlp.gate_proj': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='layer.mlp.up_proj': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='layer.mlp.down_proj': len(weight_values)=12 None 
SUBSTITUTE norm_path='layer.input_layernorm': torch.Size([12, 768]) None 
SUBSTITUTE norm_path='layer.post_attention_layernorm': tor

In [131]:
# After wrap_model_with_armt, source_model itself now shows AssociativeLayerWrapper
# layers: the wrapper modified the original model in place.
source_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 768)
    (layers): ModuleList(
      (0-11): 12 x AssociativeLayerWrapper(
        (W_mq): Linear(in_features=768, out_features=64, bias=False)
        (W_mk): Linear(in_features=768, out_features=64, bias=False)
        (W_mv): Linear(in_features=768, out_features=768, bias=False)
        (W_mb): Linear(in_features=768, out_features=1, bias=True)
        (layer): LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=768, out_features=768, bias=False)
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=False)
            (o_proj): Linear(in_features=768, out_features=768, bias=False)
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=768, out_features=3072, bias=False)
            (up_proj): Linear(in_features=768, out_features=3072, bias=Fal

In [132]:
# Shapes of the grouped W_mem / z tensors; the leading dim 12 matches the 12 decoder layers.
grouped_layer.W_mem.data.shape, grouped_layer.z.data.shape

(torch.Size([12, 384, 768]), torch.Size([12, 384]))

In [135]:
from grouped_batching.universal_executor import UniversalGroupedExecutor

In [160]:
# Re-inspect the wrapped model. NOTE(review): this repr lists layer (0) on its own instead
# of "(0-11): 12 x ...", so the layers no longer print identically — confirm which
# (possibly out-of-order) cell changed them.
source_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 768)
    (layers): ModuleList(
      (0): AssociativeLayerWrapper(
        (W_mq): Linear(in_features=768, out_features=64, bias=False)
        (W_mk): Linear(in_features=768, out_features=64, bias=False)
        (W_mv): Linear(in_features=768, out_features=768, bias=False)
        (W_mb): Linear(in_features=768, out_features=1, bias=True)
        (layer): LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=768, out_features=768, bias=False)
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=False)
            (o_proj): Linear(in_features=768, out_features=768, bias=False)
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=768, out_features=3072, bias=False)
            (up_proj): Linear(in_features=768, out_features=3072, bias=False)
    

In [None]:
# Dotted path, relative to the ARMT wrapper, of the decoder-layer list handed to
# make_universal_grouped_model.
layers_path = (
    "memory_cell.model.transformer.h"
    if MODEL_TYPE == "gpt2"
    else "memory_cell.model.model.layers"
)

armt_grouped_model, source_model_layers = make_universal_grouped_model(
    armt_model,
    grouped_layer,
    layers_path=layers_path,
)


In [192]:

def preprocess_segment_gpt2(model, input_segment):
    """Add GPT-Neo learned position embeddings to a segment, modifying it in place.

    Position ids ``0 .. input_segment.shape[0] - 1`` are looked up in
    ``model.memory_cell.model.transformer.wpe`` and added into ``input_segment`` with
    ``add_``, so the caller's tensor is changed. Returns that same tensor.
    """
    wpe = model.memory_cell.model.transformer.wpe
    positions = torch.arange(0, input_segment.shape[0], device=input_segment.device)
    # Tensor.add_ returns self, so this returns input_segment itself.
    return input_segment.add_(wpe(positions))

# GPT-2/GPT-Neo needs no custom grouped compute step.
grouped_compute_gpt2 = None # uses default grouped layer

def postprocess_segment_gpt2(model, output_segment):
    """Apply the final LayerNorm (``transformer.ln_f``), then ``lm_head``; return the head output."""
    hf_model = model.memory_cell.model
    normed = hf_model.transformer.ln_f(output_segment)
    return hf_model.lm_head(normed)



# Llama needs no per-segment preprocessing: positions are supplied as rotary embeddings
# inside grouped_compute_llama instead.
preprocess_segment_llama = None

def grouped_compute_llama(model, grouped_layer, grouped_input_tensor):
    """Run the grouped decoder layer with Llama rotary position embeddings.

    Position ids ``0 .. grouped_input_tensor.shape[1] - 1`` (shape ``(1, seq)``) are fed
    to ``model.memory_cell.model.model.rotary_emb``. The result is passed to
    ``grouped_layer.forward`` as ``position_embeddings``; its return value is returned as is.
    """
    seq_len = grouped_input_tensor.shape[1]
    position_ids = torch.arange(0, seq_len, device=grouped_input_tensor.device)[None, :]
    rotary = model.memory_cell.model.model.rotary_emb
    position_embeddings = rotary(grouped_input_tensor, position_ids)
    return grouped_layer.forward(grouped_input_tensor, position_embeddings=position_embeddings)


def postprocess_segment_llama(model, output_segment):
    """Apply the final RMSNorm (``model.norm``), then ``lm_head``; return the head output."""
    hf_model = model.memory_cell.model
    normed = hf_model.model.norm(output_segment)
    return hf_model.lm_head(normed)


In [193]:
# Model-specific attribute paths and segment hooks; everything else is shared.
if MODEL_TYPE == "gpt2":
    executor_kwargs = dict(
        model_path="memory_cell.model.transformer",
        out_norm_attr="memory_cell.model.transformer.ln_f",
        preprocess_segment_fn=preprocess_segment_gpt2,
        postprocess_segment_fn=postprocess_segment_gpt2,
        grouped_compute_fn=grouped_compute_gpt2,
    )
else:
    executor_kwargs = dict(
        model_path="memory_cell.model.model",
        out_norm_attr="memory_cell.model.model.norm",
        preprocess_segment_fn=preprocess_segment_llama,
        postprocess_segment_fn=postprocess_segment_llama,
        grouped_compute_fn=grouped_compute_llama,
    )

executor = UniversalGroupedExecutor(
    model=armt_grouped_model,
    grouped_layer=grouped_layer,
    context=grouped_context,
    n_layers=source_model.config.num_hidden_layers,
    lm_head_path="memory_cell.model.lm_head",
    memory_path="memory_cell",
    **executor_kwargs,
)

# Hand the un-grouped ARMT model to the executor as `vanilla_model`.
# NOTE(review): how the executor uses it is not visible in this notebook.
executor.vanilla_model = armt_reference_model

In [194]:
# The final RMSNorm reachable through executor.model (the path passed as out_norm_attr for Llama).
executor.model.memory_cell.model.model.norm

LlamaRMSNorm((768,), eps=1e-06)

In [195]:
### ONLY FOR FAST LATENCY VERSION

# compile full layers
# Warm-up pass that runs all layers as one group (start_idx=0, end_idx=n_layers,
# is_full=True) on random embeddings of shape (n_layers, segment_size, hidden_size).
# Presumably this compiles/caches the full-group path before the timed runs below.
segments_input = torch.rand((source_model.config.num_hidden_layers, armt_config['segment_size'], source_model.config.hidden_size), device="cuda", dtype=dtype)

i, j = 0, source_model.config.num_hidden_layers
grouped_context.start_idx = i
grouped_context.end_idx = j
grouped_context.is_full = True

# `ao` is not referenced by any later cell shown here.
ao = associate_with_context(grouped_layer, grouped_context, segments_input[i:j])
# NOTE(review): generate_mode is left True after this cell — confirm later executor calls reset it.
grouped_layer.generate_mode = True
if MODEL_TYPE == "gpt2":
    armt_grouped_model.memory_cell.model.transformer(inputs_embeds=segments_input[i:j], use_cache=False)
else:
    armt_grouped_model.memory_cell.model.model(inputs_embeds=segments_input[i:j], use_cache=False)
update_mem_with_context(grouped_layer, grouped_context, segments_input[i:j])

# Optionally free the warm-up tensors:
# del ao
# del segments_input

In [196]:
# Total input length in tokens. Other lengths to try: 8192, 16384, 32768, 65536, 131072.
SEQ_LEN = 4096
# Whole segments only: any remainder of SEQ_LEN beyond a multiple of segment_size is dropped.
num_segments = SEQ_LEN // armt_config["segment_size"]
# Random prompt token ids drawn from [0, 10000).
input_ids = torch.randint(
    0, 10000,
    (1, num_segments * armt_config["segment_size"]),
    dtype=torch.long,
    device="cuda",
)

In [197]:
# Expect (1, num_segments * segment_size); (1, 4096) with the current settings.
input_ids.shape

torch.Size([1, 4096])

In [198]:
%%time
# %%timeit
# Time one forward pass of the reference (un-grouped) ARMT model over input_ids.
# Swap the magic above for %%timeit to get repeated-run statistics.

torch.manual_seed(0)

with torch.no_grad():
    # armt_reference_model.memory_cell.zero_mem()
    # Presumably takes the memory cell out of generation mode before a plain forward pass.
    armt_reference_model.memory_cell.generate_mode(False)
    reference_output = armt_reference_model.forward(input_ids)

# CUDA work is launched asynchronously; wait for it so the timing covers the GPU work.
torch.cuda.synchronize()

CPU times: user 514 ms, sys: 21.9 ms, total: 536 ms
Wall time: 118 ms


In [199]:
%%time
# %%timeit
# Time the same forward pass through the grouped executor (compare with the reference cell).

torch.manual_seed(0)

with torch.no_grad():
    # NOTE(review): with skip_concat=False the returned logits come back as one (seq, vocab)
    # tensor (see the shapes printed later); confirm semantics in UniversalGroupedExecutor.
    output = executor.forward(input_ids, skip_concat=False)

# CUDA work is launched asynchronously; wait for it so the timing covers the GPU work.
torch.cuda.synchronize()

CPU times: user 64.1 ms, sys: 0 ns, total: 64.1 ms
Wall time: 63.3 ms


In [200]:
# Relative L2 error between the grouped-executor logits and the reference ARMT logits.
# The executor returns (seq, vocab) while the reference keeps a leading batch dim of 1,
# so drop it explicitly instead of relying on broadcasting, and check the shapes agree.
ref_logits = reference_output.logits[0]
assert output.logits.shape == ref_logits.shape, (output.logits.shape, ref_logits.shape)
torch.norm(output.logits - ref_logits) / torch.norm(ref_logits)

tensor(0., device='cuda:0')

In [201]:
# The executor's logits have no batch dimension; the reference keeps batch=1.
output.logits.shape, reference_output.logits.shape

(torch.Size([4096, 32000]), torch.Size([1, 4096, 32000]))

In [31]:
# Count positions where the grouped executor and the reference pick a different top-1
# token (the reference's batch dimension of 1 is indexed away first).
missed_tokens = (
    output.logits.argmax(dim=-1)
    != reference_output.logits[0].argmax(dim=-1)
).sum()

In [32]:
# Mismatch count, number of positions, and mismatch rate.
# NOTE(review): the saved output comes from an older execution (In [32]) and disagrees with
# the zero relative error reported by In [200]; re-run before trusting it.
missed_tokens, output.logits.shape[0], missed_tokens/output.logits.shape[0]

(tensor(576, device='cuda:0'), 4096, tensor(0.1406, device='cuda:0'))

In [202]:
# Shared generation settings for the reference model and the grouped executor.
generate_kwargs = {
    'max_new_tokens': 20,
    'pad_token_id': 0,
    # Use the loaded tokenizer's EOS id instead of a hard-coded 1, which need not be the
    # EOS token of either backbone's tokenizer.
    'eos_token_id': tokenizer.eos_token_id,
    'attention_mask': None,
    # 'attention_mask': torch.tril(torch.ones((1024, 1024), device='cuda', dtype=bool))
}

In [203]:
# Generation with the un-grouped reference ARMT model, for comparison with the executor.
gen_out_ref = armt_reference_model.generate(input_ids, **generate_kwargs)

In [204]:
# Reference generated token ids (compare with gen_out from the executor).
gen_out_ref

tensor([[29908,  1159,   376, 29908, 29908, 29908, 29908, 29908, 29908, 29908,
         29908, 29908, 29908, 29908, 29908, 29908, 29908, 29908, 29908, 29908]],
       device='cuda:0')

In [205]:
# Grouped-executor generation. seg_size is segment tokens + memory tokens
# (1024 + 128 with the current armt_config), derived so it stays in sync with the config.
gen_out = executor.generate(
    input_ids,
    seg_size=armt_config["segment_size"] + armt_config["num_mem_tokens"],
    **generate_kwargs,
)

torch.Size([1, 3072]) torch.Size([1, 1024])


In [206]:
# Executor output is a tuple: the first element is the generated token ids (compare with
# gen_out_ref). NOTE(review): the second element looks like a timing in seconds — confirm.
gen_out

(tensor([[29908,  1159,   376, 29908, 29908, 29908, 29908, 29908, 29908, 29908,
          29908, 29908, 29908, 29908, 29908, 29908, 29908, 29908, 29908, 29908]],
        device='cuda:0'),
 0.0005311965942382812)