In [1]:
import sys


In [2]:
sys.path.append("/home/jovyan/sivtsov/associative-recurrent-memory-transformer")
sys.path.append("/home/jovyan/sivtsov/armt")

In [3]:
import copy
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

from grouped_batching.llama1b_grouping import wrap_model_with_armt, get_grouped_states, make_grouped_layer_from_single_layer, make_grouped_model_from_naive
from grouped_batching.batching import GroupedBatcher
from grouped_batching.executor import ArmtGroupedExecutor

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
torch.set_default_device("cuda:1")

In [5]:
dtype = torch.bfloat16
torch.set_default_dtype(dtype)
# torch.set_grad_enabled(False)
;

''

In [6]:
source_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B"
                                             , attn_implementation="sdpa"
                                            # , attn_implementation="flash_attention_2"
                                             ,torch_dtype=dtype)
source_model.eval()
source_model.lm_head = torch.nn.Identity()
reference_model = copy.deepcopy(source_model)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

In [7]:
model_config = source_model.config

In [8]:
armt_config = dict(
    segment_size=32,
    num_mem_tokens=16,
    d_mem=64,
)

In [9]:
armt_model = wrap_model_with_armt(source_model, **armt_config)
armt_model.to("cuda")

armt_reference_model = wrap_model_with_armt(reference_model, **armt_config)
armt_reference_model.to("cuda")
;

''

In [10]:
model_cpt = "/home/jovyan/.cache/huggingface/hub/models--irodkin--ARMT-llama3.2-1B/snapshots/746e74bba3edc4cb3eaa11e13df5d900495e2300/armt_llama3.2-1B_step19500.bin"
cpt = torch.load(model_cpt, map_location='cuda')

  cpt = torch.load(model_cpt, map_location='cuda')


In [11]:
armt_model.load_state_dict(cpt, strict=False)
armt_reference_model.load_state_dict(cpt, strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['memory_cell.model.lm_head.weight'])

In [12]:
from grouped_batching.llama1b_grouping_autograd import make_grouped_training_layer_from_single_layer

In [13]:
grouped_layer =make_grouped_training_layer_from_single_layer(
    copy.deepcopy(armt_model.memory_cell.model.model.layers[0]),
    armt_model.memory_cell.model.model.layers
)
armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)

In [14]:
batcher = GroupedBatcher(
    armt_grouped_model, 
    n_layers=model_config.num_hidden_layers, 
    seg_size=armt_config["segment_size"]+armt_config["num_mem_tokens"], 
    hid_dim=model_config.hidden_size, 
    pos_embed_dim=model_config.hidden_size
)
executor = ArmtGroupedExecutor(armt_grouped_model, grouped_layer, batcher)


In [15]:
torch.cuda.empty_cache()

In [16]:
num_segments = 4
input_ids = torch.randint(
    0, 5000, 
    (1, num_segments*armt_config["segment_size"]), 
    dtype=torch.long, 
    device="cuda"
)


In [17]:
armt_reference_model.memory_cell.zero_mem()
armt_reference_model.zero_grad()
reference_output = armt_reference_model.forward(input_ids)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


In [18]:
executor.armt_model.zero_grad()
output = executor.forward(input_ids)

GROUPED GEMM dtype: torch.bfloat16

// Gemm operator cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8
using cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base =
  typename cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::bfloat16_t, cutlass::layout::RowMajor,
    float,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<256, 128, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<cutlass::bfloat16_t, 8, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
    cutlass::arch::OpMultiplyAdd
>::GemmKernel;

// Define named type
struct cutlass_tensoro

In [19]:
output.logits

tensor([[[-0.2656,  5.3750, -0.7305,  ..., -0.8438, -2.3281,  1.3984],
         [-0.2236,  0.8672,  1.2109,  ...,  0.2539, -1.6250, -0.2637],
         [ 1.1328,  3.4844,  3.9219,  ..., -0.2695, -2.6719, -0.1152],
         ...,
         [-1.0938,  0.8555, -0.7461,  ..., -1.6797,  1.2266,  0.5273],
         [-0.5703,  4.5625,  0.5469,  ..., -0.0330, -3.0312,  0.9961],
         [-0.3828,  3.5000, -1.2109,  ...,  3.4375, -0.8320, -3.6094]]],
       device='cuda:0', grad_fn=<CatBackward0>)

In [20]:
reference_output.logits

tensor([[[-0.2812,  5.4688, -0.5781,  ..., -0.8594, -2.2812,  1.5234],
         [-0.2236,  0.9023,  1.2031,  ...,  0.2188, -1.7031, -0.2832],
         [ 1.1875,  3.4219,  3.9062,  ..., -0.3008, -2.7031, -0.1069],
         ...,
         [-1.1875,  0.8711, -0.8047,  ..., -1.2891,  1.2891,  0.5508],
         [-0.4570,  4.2812,  0.5547,  ...,  0.1279, -2.8594,  1.2109],
         [-0.6992,  3.9062, -1.2891,  ...,  2.9062, -0.8242, -3.4844]]],
       device='cuda:0', dtype=torch.float32, grad_fn=<CatBackward0>)

In [21]:
torch.norm(output.logits-reference_output.logits)/torch.norm(reference_output.logits)

tensor(0.0970, device='cuda:0', dtype=torch.float32, grad_fn=<DivBackward0>)

In [22]:
output.logits.sum().backward(retain_graph=True)

GROUPED GEMM dtype: torch.bfloat16

// Gemm operator cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8
using cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base =
  typename cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::bfloat16_t, cutlass::layout::RowMajor,
    float,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<256, 128, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<cutlass::bfloat16_t, 8, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
    cutlass::arch::OpMultiplyAdd
>::GemmKernel;

// Define named type
struct cutlass_tensoro

In [23]:
reference_output.logits.sum().backward(retain_graph=True)

In [40]:
reference_model.model.layers[-1].layer.mlp.down_proj.weight.grad

tensor([[-1.2500e+01, -6.9500e+01, -1.7188e+00,  ...,  3.0938e+00,
         -8.9453e-01,  2.0938e+00],
        [-1.1250e+01, -6.8000e+01, -2.4844e+00,  ...,  3.2969e+00,
         -1.1094e+00,  1.8125e+00],
        [-7.5312e+00, -4.2500e+01, -1.3750e+00,  ...,  1.8594e+00,
         -4.7070e-01,  1.1875e+00],
        ...,
        [-1.3500e+01, -8.6000e+01, -3.4062e+00,  ...,  4.6562e+00,
         -1.6016e+00,  2.1875e+00],
        [-1.1750e+01, -5.7750e+01, -6.6406e-02,  ...,  2.0312e+00,
         -4.5312e-01,  2.0312e+00],
        [-1.2000e+01, -6.3250e+01, -8.1641e-01,  ...,  2.6406e+00,
         -7.8516e-01,  2.0312e+00]], device='cuda:0')

In [39]:
grouped_layer.layer.mlp.down_proj.wg[-1].grad

tensor([[-12.8750, -11.6250,  -7.8125,  ..., -13.8750, -12.1250, -12.5000],
        [-73.5000, -72.0000, -46.2500,  ..., -88.5000, -60.5000, -67.0000],
        [ -1.2578,  -1.9688,  -1.0234,  ...,  -2.8125,   0.4414,  -0.2188],
        ...,
        [  4.0625,   4.0625,   2.5000,  ...,   5.3438,   2.9375,   3.4531],
        [ -1.4844,  -1.6094,  -0.8672,  ...,  -2.0781,  -1.0000,  -1.3281],
        [  1.5938,   1.4062,   0.9336,  ...,   1.7188,   1.5078,   1.5625]],
       device='cuda:0')