In [1]:
import sys
import copy
sys.path.append("./associative-recurrent-memory-transformer")
sys.path.append("./")
sys.path.append("../")
from tqdm import tqdm

import logging
import math
import os
from pathlib import Path
from itertools import chain
from datetime import timedelta
import torch
import datasets
import numpy as np
from torch.utils.data import DataLoader
from transformers import set_seed

from torch.nn.utils.rnn import pad_sequence
from peft import get_peft_model, LoraConfig, TaskType
from babilong.babilong_utils import TaskDataset, SentenceSampler, NoiseInjectionDataset

from grouped_batching.batching import GroupedBatcher
from grouped_batching.executor import ArmtGroupedExecutor
from grouped_batching.fast_executor import FastGroupedArmtExecutor, GroupedLayerContext, associate_with_context, update_mem_with_context
from grouped_batching.llama1b_grouping import (
    wrap_model_with_armt, get_grouped_states, 
    make_grouped_layer_from_single_layer, make_grouped_model_from_naive,
    make_grouped_sliced_layer_from_single_layer
)

torch.autograd.set_detect_anomaly(True)
from grouped_batching.llama1b_grouping_autograd import make_grouped_training_layer_from_single_layer, make_grouped_sliced_training_layer_from_single_layer
from modeling_amt.language_modeling import AssociativeRecurrentWrapper, AssociativeMemoryCell

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM

In [3]:
def prepare_opt_armt_train(original_model, segment_size):
    # merge lora back to model
    merge_and_save = True
    unmerge_and_save = False
    if merge_and_save or unmerge_and_save:
        if merge_and_save:
            original_model.memory_cell.model.merge_and_unload()
        if unmerge_and_save:
            original_model.memory_cell.model.unload()
        original_model.memory_cell.model = original_model.memory_cell.model.base_model.model
    # transform model into grouped version
    dtype = torch.bfloat16
    device = "cuda"
    original_model = original_model.to(dtype)
    grouped_context = GroupedLayerContext()
    grouped_context.is_training = True
    # grouped_states = get_grouped_states(armt_model)
    grouped_layer = make_grouped_sliced_training_layer_from_single_layer(
        grouped_context, copy.deepcopy(original_model.memory_cell.model.model.layers[0]), original_model.memory_cell.model.model.layers
    )
    grouped_layer = grouped_layer.to(dtype)
    grouped_layer = grouped_layer.to(device)
    armt_grouped_model, source_model_layers = make_grouped_model_from_naive(original_model, grouped_layer)
    armt_grouped_model.to(device)
    executor = FastGroupedArmtExecutor(
        armt_grouped_model, 
        grouped_layer, 
        grouped_context, 
        16,#model_config.num_hidden_layers, 
        vanilla_armt_model=original_model,
    )
    ### ONLY FOR FAST LATENCY VERSION
    # compile full layers
    segments_input = torch.rand((16, segment_size, 2048), device="cuda", dtype=dtype)
    i, j = 0, 16
    grouped_context.start_idx = i
    grouped_context.end_idx = j
    grouped_context.is_full = True
    
    ao = associate_with_context(grouped_layer, grouped_context, segments_input[i:j])
    grouped_layer.generate_mode = True
    armt_grouped_model.memory_cell.model.model(inputs_embeds=segments_input[i:j], use_cache=False)
    update_mem_with_context(grouped_layer, grouped_context, segments_input[i:j])
    return executor

In [4]:
armt_cpt_path = "../../data/pretrained_models/RMT-Llama-3.2-1B-Instruct-8x1024-mem16-lora-babilong-qa1-5_ct-v3.1/model.safetensors"

In [5]:
torch.set_default_device("cuda:0")

In [6]:
dtype = torch.bfloat16
torch.set_default_dtype(dtype)

In [7]:
# load base model
source_model = AutoModelForCausalLM.from_pretrained("unsloth/Llama-3.2-1B-Instruct",
                                                    attn_implementation="flash_attention_2",
                                                    torch_dtype=dtype,
                                                    device_map="cpu")
source_model.eval()
#source_model.lm_head = torch.nn.Identity()
#reference_model = copy.deepcopy(source_model)

tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct")

In [8]:
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    inference_mode=True, 
    r=8, 
    lora_alpha=32, 
    lora_dropout=0.1,
    )
source_model = get_peft_model(source_model, peft_config)

In [9]:
# after wrap base model in original ARMT and ARMT with grouped batching, and load pretrained weigths
# the actual segment_size for this model is segment_size - mem_size, so we will use it later
segment_size = 1024
mem_size = 16
segment_alignment = "left"
attend_to_previous_input = False
device = "cpu"
max_n_segments = 32
mem_cell_args = dict(
    base_model=source_model,
    num_mem_tokens=mem_size,
    d_mem=64,
    layers_attr="model.model.layers",
    wrap_pos=False,
    correction=True,
)

cell = AssociativeMemoryCell(**mem_cell_args)
original_model = AssociativeRecurrentWrapper(cell,
                                            segment_size=segment_size,
                                            max_n_segments=max_n_segments,
                                            vary_n_segments=True,
                                            k2=-1,
                                            return_all_logits=False,
).to(device)

if "safetensors" in armt_cpt_path:
    from safetensors.torch import load_model
    load_model(original_model, armt_cpt_path, device="cuda:0")
else:
    cpt = torch.load(armt_cpt_path, map_location=device)
    original_model.load_state_dict(cpt, strict=True)
original_model = original_model.to("cuda")

In [10]:
# define params
# load and prepare data
output_dir = "./runs/test/babilong_multitask/train_test"
# set current working dir
working_dir = "./associative-recurrent-memory-transformer"
working_dir = str(Path(working_dir).expanduser().absolute())
os.chdir(working_dir)
seed = 1
set_seed(seed)
# set bfloat16 for compatibility with grouped version
dtype = torch.bfloat16
torch.set_default_dtype(dtype)
model_name = "unsloth/Llama-3.2-1B-Instruct"
noise_dataset = "wikitext"
noise_dataset_split = "wikitext-103-raw-v1"
babi_path = "../babilong/data/tasks_1-20_v1-2/en-10k"
max_n_facts = 800
segment_size = 1024
max_n_segments = 8
sample_size = segment_size * max_n_segments
learning_rate = 1e-05
task_start_pct = None
task_end_pct = None
mixed_length_ratio = 0.0

In [11]:
model = prepare_opt_armt_train(original_model, segment_size-mem_size)


// Gemm operator cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8
using cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base =
  typename cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::bfloat16_t, cutlass::layout::RowMajor,
    float,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<256, 128, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<cutlass::bfloat16_t, 8, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
    cutlass::arch::OpMultiplyAdd
>::GemmKernel;

// Define named type
struct cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x

In [12]:
noise_dataset = datasets.load_dataset(noise_dataset, noise_dataset_split)
noise_dataset_train = noise_dataset['train']
noise_dataset_test = noise_dataset['test']

# task dataset 
task_datasets = ["qa1_single-supporting-fact"]
train_paths = [os.path.join(babi_path, f"{td}_train.txt") for td in task_datasets]
test_paths = [os.path.join(babi_path, f"{td}_test.txt") for td in task_datasets]

task_dataset_train = TaskDataset(train_paths[0], max_n_facts=max_n_facts)
task_dataset_test = TaskDataset(test_paths[0], max_n_facts=max_n_facts)

# background text
qa_margin = 70          # leave space for questions and answers
train_sample_size = [int(segment_size * i) for i in range(1, max_n_segments)] + [sample_size]
train_sample_size = [s - qa_margin for s in train_sample_size]

test_sample_size = sample_size - qa_margin
max_sentence_len = None
if (task_start_pct is not None) and (task_end_pct is not None):
    # do not sample sentences longer than task position range * 0.5
    max_sentence_len = int((task_end_pct - task_start_pct) * 0.5 * sample_size)
    
noise_sampler_train = SentenceSampler(noise_dataset_train, tokenizer=tokenizer, max_sentence_len=max_sentence_len, shuffle=True, random_seed=None)
noise_sampler_test = SentenceSampler(noise_dataset_test, tokenizer=tokenizer, max_sentence_len=max_sentence_len, shuffle=True, random_seed=42)

train_dataset = NoiseInjectionDataset(task_dataset=task_dataset_train,
                                        noise_sampler=noise_sampler_train,
                                        tokenizer=tokenizer,
                                        sample_size=train_sample_size,
                                        mixed_length_ratio=mixed_length_ratio,
                                        task_start_pct=task_start_pct,
                                        task_end_pct=task_end_pct
                                        )

test_dataset = NoiseInjectionDataset(task_dataset=task_dataset_test,
                                        noise_sampler=noise_sampler_test,
                                        tokenizer=tokenizer,
                                        sample_size=test_sample_size,
                                        mixed_length_ratio=mixed_length_ratio,
                                        task_start_pct=task_start_pct,
                                        task_end_pct=task_end_pct
                                        )

id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
gen_token = tokenizer.encode('GEN')[0]
eos_token = tokenizer.eos_token_id

In [13]:
def get_input_ids(sample):
    template = "{} {}Answer with a single word."
    context = tokenizer.decode(sample['input_tokens'])
    messages = [
        {"role": "user", "content": template.format(context, sample['question'])},
        {"role": "assistant", "content": sample['answer']}
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=False
    )
    return input_ids

In [14]:
def collate_fn(batch):
    inputs = [get_input_ids(sample) for sample in batch]
    input_ids = [torch.tensor(i) for i in inputs]
    attention_mask = [torch.ones_like(b, dtype=bool) for b in input_ids]

    input_ids = pad_sequence(input_ids, padding_value=id_pad_value, batch_first=True)
    attention_mask = pad_sequence(attention_mask, padding_value=0, batch_first=True)

    collated = {}
    collated['input_ids'] = collated['labels'] = input_ids
    collated['attention_mask'] = attention_mask.bool()

    return collated

In [15]:
# run train
def train_one_epoch(training_loader, segment_size, max_steps=-1, device="cuda"):
    running_loss = 0.
    last_loss = 0.
    for i, data in tqdm(enumerate(training_loader)):
        # Every data instance is an input + label pair
        #inputs, labels = data["input_ids"][...,:-1].to(device), data["labels"][...,1:].to(device)
        inputs, labels = data["input_ids"].to(device), data["labels"].to(device)
        # pad each sample to the segm_num*segm_size, cause model cannot handle other samples
        
        pad_shape = (segment_size - inputs.shape[-1] % (segment_size), 0)
        # print(inputs.shape, pad_shape)
        inputs = torch.nn.functional.pad(inputs.squeeze(), pad_shape).unsqueeze(0)
        labels = torch.nn.functional.pad(labels.squeeze(), pad_shape, value=-100).unsqueeze(0)
        # print(inputs, labels)
        # print(inputs.shape, labels.shape)
        
        # Zero your gradients for every batch
        optimizer.zero_grad()
        
        # Make predictions for this batch
        outputs = model(inputs)

        outputs = outputs.logits
        # shift predicted and labels
        labels = labels[..., 1:]
        outputs = outputs[..., :-1, :]
        # flatten and calc loss
        labels = labels.view(-1)
        outputs = outputs.view(-1, outputs.size(-1))
        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()
        
        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if i % 10 == 0:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = i + 1
            print('Loss/train', last_loss, tb_x)
            running_loss = 0.
        if i > max_steps:
            break

    return last_loss

In [16]:
set_seed(421)
device = "cuda"
train_dl = DataLoader(train_dataset, batch_size=1,
                      shuffle=False, collate_fn=collate_fn)

loss_fn = torch.nn.CrossEntropyLoss().to(device)
#optimizer = torch.optim.AdamW(list(model.armt_model.parameters()) + list(model.grouped_layer.parameters()), lr=args.learning_rate)
optimizer = torch.optim.AdamW(model.armt_model.parameters(), lr=learning_rate)
# Quick fix, only for BS=1, w/o any mask and normal padding, etc.
for i in range(1):
    train_one_epoch(train_dl, segment_size, max_steps=100, device=device)
    # save model
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.grouped_layer.state_dict(), os.path.join(output_dir, "grouped_layer.pth"))
    torch.save(model.armt_model.state_dict(), os.path.join(output_dir, "armt_model.pth"))

# save model
os.makedirs(output_dir, exist_ok=True)
torch.save(model.grouped_layer.state_dict(), os.path.join(output_dir, "grouped_layer.pth"))
torch.save(model.armt_model.state_dict(), os.path.join(output_dir, "armt_model.pth"))

0it [00:00, ?it/s]


// Gemm operator cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8
using cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base =
  typename cutlass::gemm::kernel::DefaultGemmGrouped<
    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, 8,
    cutlass::bfloat16_t, cutlass::layout::RowMajor,
    float,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<256, 128, 64>,
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<cutlass::bfloat16_t, 8, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>,
    3,
    cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
    cutlass::arch::OpMultiplyAdd
>::GemmKernel;

// Define named type
struct cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x

1it [00:02,  2.42s/it]

  batch 1 loss: 0.0119375
Loss/train 0.0119375 1


11it [00:18,  1.56s/it]

  batch 11 loss: 0.06053125
Loss/train 0.06053125 11


21it [00:35,  1.53s/it]

  batch 21 loss: 0.05021875
Loss/train 0.05021875 21


31it [00:50,  1.39s/it]

  batch 31 loss: 0.046125
Loss/train 0.046125 31


41it [01:06,  1.60s/it]

  batch 41 loss: 0.04334375
Loss/train 0.04334375 41


51it [01:23,  1.55s/it]

  batch 51 loss: 0.042078125
Loss/train 0.042078125 51


61it [01:38,  1.44s/it]

  batch 61 loss: 0.040796875
Loss/train 0.040796875 61


71it [01:54,  1.65s/it]

  batch 71 loss: 0.040578125
Loss/train 0.040578125 71


81it [02:12,  1.73s/it]

  batch 81 loss: 0.0395
Loss/train 0.0395 81


91it [02:26,  1.57s/it]

  batch 91 loss: 0.038859375
Loss/train 0.038859375 91


101it [02:41,  1.45s/it]

  batch 101 loss: 0.039015625
Loss/train 0.039015625 101


101it [02:43,  1.62s/it]
