In [1]:
import torch
import torch.nn as nn

import sys

In [2]:
# Extend the module search path with the parent directory and with the local
# checkout of associative-recurrent-memory-transformer so the project imports
# in the following cells resolve (presumably `grouped_batching` lives in one
# of these — TODO confirm).
sys.path.append('..')
sys.path.append("./associative-recurrent-memory-transformer")

In [3]:
from grouped_batching.llama1b_grouping import (
    wrap_model_with_armt, get_grouped_states, 
    make_grouped_layer_from_single_layer, make_grouped_model_from_naive,
    make_grouped_sliced_layer_from_single_layer
)
from grouped_batching.fast_executor import (
    FastGroupedArmtExecutor, GroupedLayerContext, 
    associate_with_context, update_mem_with_context
)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from grouped_batching.universal_grouping import (
    extract_params_from_module, get_universal_grouped_states,
    get_module_by_path, make_universal_grouped_layer,
    set_module_by_path, make_universal_grouped_model
)




In [5]:
# Global numeric configuration for the whole notebook: every tensor/module
# created from here on defaults to bfloat16, and autograd is switched off
# because the notebook only runs inference.
dtype = torch.bfloat16
torch.set_default_dtype(dtype)
# The trailing semicolon must sit on the statement itself to hide its repr.
# A ';' alone on its own line is parsed by IPython as the auto-quote escape
# and evaluates to '' (and is a SyntaxError in plain Python).
torch.set_grad_enabled(False);

''

In [6]:
import torch.nn as nn
import traceback

def add_trace_to_forward(module, msg):
    """Patch ``module.forward`` so every call prints where it came from.

    The bound ``forward`` currently installed on ``module`` is captured and
    replaced (on the instance) by a wrapper that prints a ``[TRACE]`` header
    containing ``msg``, the module's class name and the last 10 stack frames,
    and then delegates to the captured ``forward`` with the same arguments.

    Args:
        module: Object exposing a callable ``forward`` attribute.
        msg: Free-form tag inserted into the trace header.

    Returns:
        The same ``module`` object, now carrying the tracing ``forward``.
    """
    wrapped_forward = module.forward

    def traced_forward(*args, **kwargs):
        call_stack = ''.join(traceback.format_stack(limit=10))
        print(f"\n[TRACE] {msg} {module.__class__.__name__}.forward called from:\n" + call_stack)
        return wrapped_forward(*args, **kwargs)

    module.forward = traced_forward
    return module

In [7]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import copy

# Load the GPT-Neo 125M tokenizer and causal LM. The model is requested with
# the FlashAttention-2 attention implementation and in the notebook-wide dtype.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
source_model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-neo-125m",
    attn_implementation="flash_attention_2",
    torch_dtype=dtype,
)

# Replace all LayerNorm modules in the model with LayerNorm without weights and biases
# for name, module in source_model.named_modules():
#     if isinstance(module, nn.LayerNorm):
#         # Create a new LayerNorm with elementwise_affine=False (no weights and biases)
#         new_layernorm = nn.LayerNorm(
#             normalized_shape=module.normalized_shape,
#             eps=module.eps,
#             elementwise_affine=False
#         )
        
#         # Get the parent module and attribute name to replace the LayerNorm
#         parent_name = '.'.join(name.split('.')[:-1])
#         child_name = name.split('.')[-1]
        
#         if parent_name:
#             parent = source_model
#             for part in parent_name.split('.'):
#                 parent = getattr(parent, part)
#             setattr(parent, child_name, new_layernorm)
#         else:
#             setattr(source_model, child_name, new_layernorm)

# for l in source_model.transformer.h:
    # l.mlp = nn.Identity()
    # l.attn.attention.out_proj = nn.Identity()
    # l.attn.attention = nn.Identity()

# source_model.transformer.h = source_model.transformer.h[:1]

# source_model.transformer.wpe.weight.fill_(0)
# source_model.transformer.wpe.weight.data = torch.tensor([3,3,3], dtype=dtype)
# source_model.transformer.wte.weight.fill_(0)
# source_model.transformer.wte.weight.data = torch.tensor([3,3,3], dtype=dtype)
# add_trace_to_forward(source_model.transformer.wpe, 'wpe')
# add_trace_to_forward(source_model.transformer.wte, 'wte')


# source_model.ln_f = source_model.transformer.ln_f
# source_model.transformer.ln_f = nn.Identity()
# source_model.lm_head = nn.Identity()
# Keep an independent copy of the freshly loaded model: `source_model` is later
# wrapped in place (its layers show up as AssociativeLayerWrapper afterwards),
# while this copy is wrapped separately and used as the comparison baseline.
reference_model = copy.deepcopy(source_model)

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [8]:
source_model

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPTNeoBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoFlashAttention2(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=False)
            (q_proj): Linear(in_features=768, out_features=768, bias=False)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_

In [9]:
# Keyword arguments forwarded to wrap_model_with_armt for both the grouped
# and the reference model.
armt_config = {
    "segment_size": 1024,
    "num_mem_tokens": 128,
    "d_mem": 64,
}

In [10]:
source_model

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPTNeoBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoFlashAttention2(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=False)
            (q_proj): Linear(in_features=768, out_features=768, bias=False)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_

In [11]:
# Wrap both copies of the base model with ARMT and move them to the GPU.
# The RNG is re-seeded with the same value before each wrap so both wrappers
# are built from the same random state (presumably the added memory weights
# are randomly initialised — TODO confirm in wrap_model_with_armt).
torch.manual_seed(0)
armt_model = wrap_model_with_armt(source_model, **armt_config, layers_attr='transformer.h')
armt_model.to("cuda")

torch.manual_seed(0)
armt_reference_model = wrap_model_with_armt(reference_model, **armt_config, layers_attr='transformer.h')
# Semicolon on the statement suppresses the module repr. A ';' alone on the
# next line is IPython's auto-quote escape and evaluates to ''.
armt_reference_model.to("cuda");

''

In [12]:
# Gather the parameters/state of all wrapped transformer layers into one
# mapping keyed by the parameter's path inside a layer (e.g. 'W_mq.weight',
# 'layer.ln_1.weight', 'W_mem', 'z' — see the key listing in the next cell).
grouped_params = get_universal_grouped_states(armt_model.memory_cell.model.transformer.h)

In [13]:
list(grouped_params.keys())

['W_mq.weight',
 'W_mk.weight',
 'W_mv.weight',
 'W_mb.weight',
 'W_mb.bias',
 'layer.ln_1.weight',
 'layer.ln_1.bias',
 'layer.attn.attention.k_proj.weight',
 'layer.attn.attention.v_proj.weight',
 'layer.attn.attention.q_proj.weight',
 'layer.attn.attention.out_proj.weight',
 'layer.attn.attention.out_proj.bias',
 'layer.ln_2.weight',
 'layer.ln_2.bias',
 'layer.mlp.c_fc.weight',
 'layer.mlp.c_fc.bias',
 'layer.mlp.c_proj.weight',
 'layer.mlp.c_proj.bias',
 'W_mem',
 'z']

In [14]:
# Shared mutable context object; the executor below drives its `start_idx`,
# `end_idx` and `is_full` fields while running the grouped layer.
grouped_context = GroupedLayerContext()


In [15]:
# Build the single grouped layer from a deep copy of layer 0 (so the layer
# inside armt_model is not passed in directly) plus the gathered per-layer
# parameters. The printed SUBSTITUTE lines report which sub-modules were
# replaced.
grouped_layer = make_universal_grouped_layer(
    grouped_context, 
    copy.deepcopy(armt_model.memory_cell.model.transformer.h[0]),
    grouped_params
)

SUBSTITUTE efficient module_path='W_mq': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='W_mk': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='W_mv': len(weight_values)=12 None 
SUBSTITUTE naive module_path='W_mb': len(weight_values)=12 12 
SUBSTITUTE efficient module_path='layer.attn.attention.k_proj': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='layer.attn.attention.v_proj': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='layer.attn.attention.q_proj': len(weight_values)=12 None 
SUBSTITUTE efficient module_path='layer.attn.attention.out_proj': len(weight_values)=12 12 
SUBSTITUTE efficient module_path='layer.mlp.c_fc': len(weight_values)=12 12 
SUBSTITUTE efficient module_path='layer.mlp.c_proj': len(weight_values)=12 12 
SUBSTITUTE norm_path='layer.ln_1': torch.Size([12, 768]) torch.Size([12, 768]) 
SUBSTITUTE norm_path='layer.ln_2': torch.Size([12, 768]) torch.Size([12, 768]) 


In [16]:
source_model

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x AssociativeLayerWrapper(
        (W_mq): Linear(in_features=768, out_features=64, bias=False)
        (W_mk): Linear(in_features=768, out_features=64, bias=False)
        (W_mv): Linear(in_features=768, out_features=768, bias=False)
        (W_mb): Linear(in_features=768, out_features=1, bias=True)
        (layer): GPTNeoBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPTNeoAttention(
            (attention): GPTNeoFlashAttention2(
              (attn_dropout): Dropout(p=0.0, inplace=False)
              (resid_dropout): Dropout(p=0.0, inplace=False)
              (k_proj): Linear(in_features=768, out_features=768, bias=False)
              (v_proj): Linear(in_features=768, out_features=768, bias=False)
              (q_proj): Linear(in_features

In [17]:
grouped_layer.W_mem.data.shape, grouped_layer.z.data.shape

(torch.Size([12, 384, 768]), torch.Size([12, 384]))

In [18]:
grouped_layer.layer.ln_1.weight.shape

torch.Size([12, 1, 768])

In [19]:
grouped_layer.layer.ln_1.bias.shape

torch.Size([12, 1, 768])

In [20]:
import transformers

def zero_grouped_memory(self):
    """Reset the memory state of a grouped ARMT layer in place.

    Both ``self.W_mem`` and ``self.z`` are detached from any autograd graph
    and then overwritten with zeros; the tensor objects themselves are kept.

    Args:
        self: Object exposing tensor attributes ``W_mem`` and ``z``.
    """
    for state_name in ("W_mem", "z"):
        state = getattr(self, state_name)
        state.detach_()
        state.fill_(0)

class UniversalGroupedExecutor(torch.nn.Module):
    """
    Universal grouped executor for ARMT models.

    This class provides a flexible executor that can work with any model structure
    by using configurable paths to model components.

    ``forward`` pushes the input through ``grouped_layer`` segment by segment in
    a pipelined schedule: at every step at most one new segment is inserted in
    front of the in-flight stack, the whole stack is run through the grouped
    layer once, and from step ``n_layers - 1`` on the last entry of the stack is
    popped as a finished segment.
    """

    def __init__(
        self,
        model,
        grouped_layer,
        context,
        n_layers,
        model_path="memory_cell.model.model",
        out_norm_attr="out_norm",
        lm_head_path="memory_cell.model.lm_head",
        memory_path="memory_cell",
        segment_fn=None,
        process_input_fn=None,
        vanilla_model=None,
        preprocess_segment_fn = None,
        postprocess_segment_fn = None,
    ):
        """
        Initialize the universal grouped executor.

        Args:
            model: The model to execute
            grouped_layer: The grouped layer
            context: The context for grouped execution
            n_layers: Number of layers in the model
            model_path: Path to the main model module
            out_norm_attr: Path (resolved from ``model``) where the output norm is stored
            lm_head_path: Path to the language model head
            memory_path: Path to the memory cell module
            segment_fn: Function to segment inputs (if None, uses model.segment)
            process_input_fn: Function to process inputs (if None, uses memory_cell.process_input)
            vanilla_model: Original model for generation (optional)
            preprocess_segment_fn: Optional ``fn(model, segment)`` applied to each
                segment's input embeddings right before it enters the pipeline;
                its return value is what gets processed.
            postprocess_segment_fn: Optional ``fn(model, segment)`` applied to each
                finished segment after its trailing ``num_mem_tokens`` positions
                have been stripped; its return value is what gets collected.

        Raises:
            ValueError: If the base model or the memory cell cannot be resolved,
                or if no input-processing function is available.
        """
        super().__init__()
        self.model = model
        self.grouped_layer = grouped_layer
        self.context = context
        self.n_layers = n_layers
        self.vanilla_model = vanilla_model

        self.preprocess_segment_fn = preprocess_segment_fn
        self.postprocess_segment_fn = postprocess_segment_fn

        # Store paths
        self.model_path = model_path
        self.out_norm_attr = out_norm_attr
        self.lm_head_path = lm_head_path
        self.memory_path = memory_path

        # Get components by path. Base model and memory cell are mandatory;
        # output norm and LM head only produce a warning when missing.
        self.base_model = get_module_by_path(model, model_path)
        if self.base_model is None:
            raise ValueError(f"Could not find base model at path: {model_path}")

        self.out_norm = get_module_by_path(model, out_norm_attr)
        if self.out_norm is None:
            print(f"Warning: Could not find out norm at path: {out_norm_attr}")

        self.lm_head = get_module_by_path(model, lm_head_path)
        if self.lm_head is None:
            print(f"Warning: Could not find LM head at path: {lm_head_path}")

        self.memory_cell = get_module_by_path(model, memory_path)
        if self.memory_cell is None:
            raise ValueError(f"Could not find memory cell at path: {memory_path}")

        # Segmentation and input processing functions
        self.segment_fn = segment_fn if segment_fn is not None else model.segment

        if process_input_fn is not None:
            self.process_input_fn = process_input_fn
        elif hasattr(self.memory_cell, 'process_input'):
            self.process_input_fn = self.memory_cell.process_input
        else:
            raise ValueError("No process_input function provided or found")

        # Set generation mode
        self.grouped_layer.generate_mode = True

    def forward(self, input_ids, skip_concat=False):
        """
        Forward pass for the grouped model.

        Only the first batch element of every processed segment is used
        (``inputs_embeds[0]``), so the result has no batch dimension.

        Args:
            input_ids: Input token IDs
            skip_concat: Whether to skip concatenating outputs

        Returns:
            If ``skip_concat`` is true, the list of per-segment outputs in
            segment order. Otherwise a ``CausalLMOutputWithPast`` whose
            ``logits`` is the concatenation of those outputs along dim 0.
        """
        # Start every call from an empty layer window and empty memory.
        self.context.is_full = False
        self.context.start_idx = 0
        self.context.end_idx = 0

        # Zero out memory
        zero_grouped_memory(self.grouped_layer)

        # Segment inputs
        input_segments = list(self.segment_fn(input_ids=input_ids))
        segments = [self.process_input_fn(**iseg)['inputs_embeds'][0] for iseg in input_segments]

        segment_outputs = []
        grouped_input = []

        # n_layers + len(segments) - 1 steps: fill phase, steady state, drain.
        for i in range(self.n_layers + len(segments) - 1):
            if i < len(segments):
                # Feed the next pending segment in at the front of the stack
                new_segment = segments[i]
                if self.preprocess_segment_fn is not None:
                    new_segment = self.preprocess_segment_fn(self.model, new_segment)
                grouped_input.insert(0, new_segment)

            grouped_input_tensor = torch.stack(grouped_input).contiguous()
            if i < self.n_layers:
                # Compute before end_idx+=1 to skip first segment association
                if i > 0 and grouped_input_tensor.shape[0] > 1:
                    grouped_input_tensor[:-1, ...] += associate_with_context(self.grouped_layer, self.context, grouped_input_tensor[:-1, ...])

                # Allow more weights to be computed
                self.context.end_idx += 1
                if self.context.end_idx == self.n_layers and self.context.start_idx == 0:
                    self.context.is_full = True
            else:
                grouped_input_tensor += associate_with_context(self.grouped_layer, self.context, grouped_input_tensor)

            # Process through the grouped layer
            grouped_output = self.grouped_layer.forward(grouped_input_tensor)
            grouped_output = grouped_output[0]

            # Update memory with context, using the trailing memory-token positions
            update_mem_with_context(self.grouped_layer, self.context, grouped_output[:, -self.grouped_layer.num_mem_tokens:])

            grouped_input = list(grouped_output.unbind(0))
            if i >= self.n_layers - 1:
                # The last stack entry is finished: pop it and strip its memory tokens
                segment_out_logits = grouped_input.pop(-1)

                processed_segment = segment_out_logits[:-self.grouped_layer.num_mem_tokens]
                if self.postprocess_segment_fn is not None:
                    processed_segment = self.postprocess_segment_fn(self.model, processed_segment)

                segment_outputs.append(processed_segment)

            if i >= len(segments) - 1:
                # Reduce number of weights to be computed
                self.context.start_idx += 1
                self.context.is_full = False

        if skip_concat:
            return segment_outputs

        # Concatenate outputs
        output = torch.cat(segment_outputs, dim=0)

        # Return as a CausalLMOutput
        return transformers.modeling_outputs.CausalLMOutputWithPast(
            logits=output,
        )

    def generate(self, input_ids, attention_mask, seg_size, **generate_kwargs):
        """
        Generate text using the model.

        Inputs longer than one effective segment (``seg_size`` minus the number
        of memory tokens) are split: the leading full segments go through
        ``forward`` to build up the grouped memory, and the remaining tail is
        handed to the memory cell's ``generate``. When the input length is an
        exact multiple of the effective segment length, the last full segment
        is kept as the tail, so there is always a non-empty segment to
        generate from.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            seg_size: Segment size
            **generate_kwargs: Additional keyword arguments for generation

        Returns:
            Tuple ``(out, copy_time)``: the generation output and the time in
            seconds spent copying memory state into the vanilla model (``0``
            when no copy happened).

        Raises:
            ValueError: If ``seg_size`` does not exceed the number of memory
                tokens, or if no generation method is available.
        """
        # Reset the memory of whichever cell is going to run the generation.
        if self.vanilla_model is not None:
            vanilla_memory_cell = get_module_by_path(self.vanilla_model, self.memory_path)
            if vanilla_memory_cell is not None:
                vanilla_memory_cell.zero_mem()
        elif hasattr(self.memory_cell, 'zero_mem'):
            self.memory_cell.zero_mem()

        tokens_per_segment = seg_size - self.grouped_layer.num_mem_tokens
        if tokens_per_segment <= 0:
            raise ValueError(
                f"seg_size ({seg_size}) must be larger than num_mem_tokens "
                f"({self.grouped_layer.num_mem_tokens})"
            )

        total_tokens = input_ids.shape[-1]
        has_prefix = total_tokens > tokens_per_segment

        if has_prefix:
            # Split point = start of the last (possibly partial) segment.
            # Using (total - 1) keeps a full final segment when the length is
            # an exact multiple; otherwise the tail would be empty.
            last_segm = (total_tokens - 1) // tokens_per_segment * tokens_per_segment
            prev_ids = input_ids[..., :last_segm]
            last_ids = input_ids[..., last_segm:]
            last_attn_mask = attention_mask[..., last_segm:]

            # Process previous segments
            _ = self.forward(prev_ids)

            # Process last segment
            segmented = self.segment_fn(input_ids=last_ids, attention_mask=last_attn_mask)
        else:
            # Process inputs directly
            segmented = self.segment_fn(input_ids=input_ids, attention_mask=attention_mask)

        final_segment = segmented[-1]
        return self._generate_from_final_segment(final_segment, generate_kwargs, sync_memory=has_prefix)

    def _generate_from_final_segment(self, final_segment, generate_kwargs, sync_memory):
        """Run ``generate`` on the final segment with the best available memory cell.

        Prefers the vanilla model's memory cell; when ``sync_memory`` is true the
        grouped memory state is copied into it first and that copy is timed.
        Falls back to this executor's own memory cell.
        """
        import time

        # Use vanilla model for generation if available
        if self.vanilla_model is not None:
            vanilla_memory_cell = get_module_by_path(self.vanilla_model, self.memory_path)
            if vanilla_memory_cell is not None:
                copy_time = 0
                if sync_memory:
                    # Patch memory
                    time_start = time.time()
                    vanilla_memory_cell.memory = self.memory_cell.memory

                    # Copy weights: one slice of the grouped state per vanilla layer
                    for idx in range(len(vanilla_memory_cell.layers)):
                        if hasattr(self.grouped_layer, 'W_mem'):
                            vanilla_memory_cell.layers[idx].W_mem = self.grouped_layer.W_mem[idx].unsqueeze(0)
                        if hasattr(self.grouped_layer, 'z'):
                            vanilla_memory_cell.layers[idx].z = self.grouped_layer.z[idx].unsqueeze(0)
                        vanilla_memory_cell.layers[idx].first_seg = False

                    copy_time = time.time() - time_start

                out = vanilla_memory_cell.generate(**final_segment, zero_mem=False, **generate_kwargs)
                vanilla_memory_cell.zero_mem()
                return out, copy_time

        # Use memory cell for generation
        if hasattr(self.memory_cell, 'generate'):
            out = self.memory_cell.generate(**final_segment, zero_mem=False, **generate_kwargs)
            if hasattr(self.memory_cell, 'zero_mem'):
                self.memory_cell.zero_mem()
            return out, 0

        # Fallback
        raise ValueError("Could not generate output - no suitable generation method found")

    def to(self, *args, **kwargs):
        """Move/cast the wrapped models; accepts the same arguments as ``nn.Module.to``.

        The call is delegated explicitly to ``model``, ``grouped_layer`` and
        (if present) ``vanilla_model`` so any custom ``to`` they define is honoured.
        """
        self.model.to(*args, **kwargs)
        self.grouped_layer.to(*args, **kwargs)
        if self.vanilla_model is not None:
            self.vanilla_model.to(*args, **kwargs)
        return self

    def eval(self):
        """Set model to evaluation mode."""
        return self.train(False)

    def train(self, mode=True):
        """Set training mode (or evaluation mode when ``mode`` is False).

        Keeps the ``nn.Module.train(mode)`` contract: the flag is applied to
        this executor and to every registered sub-module (model, grouped layer
        and, if present, the vanilla model).
        """
        super().train(mode)
        return self

In [21]:
source_model

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-11): 12 x AssociativeLayerWrapper(
        (W_mq): Linear(in_features=768, out_features=64, bias=False)
        (W_mk): Linear(in_features=768, out_features=64, bias=False)
        (W_mv): Linear(in_features=768, out_features=768, bias=False)
        (W_mb): Linear(in_features=768, out_features=1, bias=True)
        (layer): GPTNeoBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPTNeoAttention(
            (attention): GPTNeoFlashAttention2(
              (attn_dropout): Dropout(p=0.0, inplace=False)
              (resid_dropout): Dropout(p=0.0, inplace=False)
              (k_proj): Linear(in_features=768, out_features=768, bias=False)
              (v_proj): Linear(in_features=768, out_features=768, bias=False)
              (q_proj): Linear(in_features

In [22]:
# Build the grouped model from the wrapped model and the grouped layer; the
# layer stack is addressed by the dotted path "memory_cell.model.transformer.h".
# The call returns the grouped model together with a second value that is
# stored as `source_model_layers` (presumably the original per-layer modules
# — TODO confirm against make_universal_grouped_model).
armt_grouped_model, source_model_layers = make_universal_grouped_model(
    armt_model, 
    grouped_layer,
    layers_path="memory_cell.model.transformer.h"
)

def preprocess_segment(model, input_segment):
    """Add learned positional embeddings to a segment, in place.

    Positions ``0 .. len(segment) - 1`` are looked up in the model's ``wpe``
    table and added to ``input_segment``. The input tensor itself is modified
    and returned.

    Args:
        model: Wrapper exposing ``memory_cell.model.transformer.wpe``.
        input_segment: Tensor whose first dimension is the position axis.

    Returns:
        ``input_segment`` (same object) with the positional embeddings added.
    """
    positional_table = model.memory_cell.model.transformer.wpe
    seq_len = input_segment.shape[0]
    positions = torch.arange(seq_len, device=input_segment.device)
    input_segment += positional_table(positions)
    return input_segment

def postprocess_segment(model, output_segment):
    """Turn a finished segment's hidden states into vocabulary logits.

    Applies the model's final norm (``transformer.ln_f``) followed by its
    ``lm_head`` to ``output_segment``.

    Args:
        model: Wrapper exposing ``memory_cell.model`` with ``transformer.ln_f``
            and ``lm_head``.
        output_segment: Hidden states of one segment.

    Returns:
        The output of ``lm_head`` for the normalised segment.
    """
    causal_lm = model.memory_cell.model
    normalised = causal_lm.transformer.ln_f(output_segment)
    return causal_lm.lm_head(normalised)

# Assemble the executor for the GPT-Neo layout: the transformer body lives at
# memory_cell.model.transformer, the final norm at ...transformer.ln_f and the
# LM head at memory_cell.model.lm_head. Positional embeddings are added per
# segment by `preprocess_segment`; final norm + LM head are applied per
# finished segment by `postprocess_segment`.
executor = UniversalGroupedExecutor(
    model=armt_grouped_model,
    grouped_layer=grouped_layer,
    context=grouped_context,
    n_layers=source_model.config.num_hidden_layers,
    model_path="memory_cell.model.transformer",
    out_norm_attr="memory_cell.model.transformer.ln_f",
    lm_head_path="memory_cell.model.lm_head",
    memory_path="memory_cell",
    preprocess_segment_fn = preprocess_segment,
    postprocess_segment_fn = postprocess_segment,
)


# executor = FastGroupedArmtExecutor(
#     armt_grouped_model, 
#     grouped_layer, 
#     grouped_context, 
#     source_model.config.num_hidden_layers
# )

In [23]:
executor.model.memory_cell.model.lm_head

Linear(in_features=768, out_features=50257, bias=False)

In [24]:
### ONLY FOR FAST LATENCY VERSION

# Warm-up pass with random inputs and the context window opened over all
# layers. The cell output shows a "jit compile" log for grouped GEMM kernels,
# so this appears to trigger kernel compilation ahead of the timed cells
# below. The state it leaves behind does not leak into the benchmark:
# executor.forward resets the context indices and zeroes the grouped memory
# at the start of every call.
segments_input = torch.rand((source_model.config.num_hidden_layers, armt_config['segment_size'], source_model.config.hidden_size), device="cuda", dtype=dtype)

# Full layer window: [0, num_hidden_layers)
i, j = 0, source_model.config.num_hidden_layers
grouped_context.start_idx = i
grouped_context.end_idx = j
grouped_context.is_full = True

# Exercise the three grouped code paths once: association, the transformer
# forward on the grouped model, and the memory update. `ao` is not used again.
ao = associate_with_context(grouped_layer, grouped_context, segments_input[i:j])
grouped_layer.generate_mode = True
armt_grouped_model.memory_cell.model.transformer(inputs_embeds=segments_input[i:j], use_cache=False)
update_mem_with_context(grouped_layer, grouped_context, segments_input[i:j])

# del ao
# del segments_input

jit compile As: [torch.Size([1024, 768]), torch.Size([1024, 768]), torch.Size([1024, 768]), torch.Size([1024, 768]), torch.Size([1024, 768]), torch.Size([1024, 768]), torch.Size([1024, 768]), torch.Size([1024, 768]), torch.Size([1024, 768]), torch.Size([1024, 768]), torch.Size([1024, 768]), torch.Size([1024, 768])] Bs: [torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64]), torch.Size([768, 64])]

// Gemm operator cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8
using cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base =
  typename cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
 

In [25]:
# Benchmark input: one random token sequence whose length is a whole number
# of segments. Sequence lengths of interest:
# 4096, 8192, 16384, 32768, 65536, 131072
num_segments = 16384 // armt_config["segment_size"]
input_ids = torch.randint(
    low=0,
    high=10000,
    size=(1, num_segments * armt_config["segment_size"]),
    dtype=torch.long,
    device="cuda",
)

In [26]:
input_ids.shape

torch.Size([1, 16384])

In [27]:
%%time
# %%timeit

# Baseline timing: full forward pass of the non-grouped reference ARMT model
# on the benchmark input, with the same seed as the grouped run.
torch.manual_seed(0)

with torch.no_grad():
    # armt_reference_model.memory_cell.zero_mem()
    armt_reference_model.memory_cell.generate_mode(False)
    reference_output = armt_reference_model.forward(input_ids)

# Block until queued CUDA work has finished so %%time covers the real compute.
torch.cuda.synchronize()

CPU times: user 846 ms, sys: 40.6 ms, total: 886 ms
Wall time: 380 ms


In [28]:
%%time
# %%timeit

# Grouped timing: the same input through the grouped executor, with the same
# seed as the reference run. skip_concat=False returns one concatenated
# logits tensor instead of a per-segment list.
torch.manual_seed(0)

with torch.no_grad():
    output = executor.forward(input_ids, skip_concat=False)

# Block until queued CUDA work has finished so %%time covers the real compute.
torch.cuda.synchronize()

CPU times: user 152 ms, sys: 16 ms, total: 168 ms
Wall time: 167 ms


In [29]:
torch.norm(output.logits-reference_output.logits)/torch.norm(reference_output.logits)

tensor(0.0325, device='cuda:0')

In [30]:
output.logits.shape, reference_output.logits.shape

(torch.Size([16384, 50257]), torch.Size([1, 16384, 50257]))

In [31]:
# Number of positions where the grouped executor and the reference model
# disagree on the top-1 token. The executor's logits carry no batch dimension,
# the reference logits do, so the reference is indexed at batch element 0.
missed_tokens = torch.ne(
    output.logits.argmax(dim=-1),
    reference_output.logits[0].argmax(dim=-1),
).sum()

In [32]:
missed_tokens, output.logits.shape[0], missed_tokens/output.logits.shape[0]

(tensor(2242, device='cuda:0'), 16384, tensor(0.1367, device='cuda:0'))