In [1]:
import sys


In [2]:
sys.path.append("/home/jovyan/gkuzmin/rmt_it/optimized_armt/grouped_batching/associative-recurrent-memory-transformer")
sys.path.append("/home/jovyan/gkuzmin/rmt_it/optimized_armt/grouped_batching")
sys.path.append("/home/jovyan/gkuzmin/rmt_it/optimized_armt")

In [3]:
import cutlass

In [4]:
import copy
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM

from grouped_batching.llama1b_grouping import wrap_model_with_armt, get_grouped_states, make_grouped_layer_from_single_layer, make_grouped_model_from_naive
from grouped_batching.batching import GroupedBatcher
from grouped_batching.executor import ArmtGroupedExecutor
# switch to fast version for generation
#from grouped_batching.fast_executor import FastGroupedArmtExecutor, GroupedLayerContext, associate_with_context, update_mem_with_context
#from grouped_batching.llama1b_grouping_autograd import make_grouped_training_layer_from_single_layer

In [5]:
from modeling_amt.language_modeling_old import AssociativeRecurrentWrapper, AssociativeMemoryCell

In [6]:
#!wget "https://huggingface.co/AIRI-NLP/ARMT-Llama3-1b-Instruct-2x1024-v2/resolve/main/pytorch_model.bin"

In [7]:
armt_cpt_path = "../../data/pretrained_models/RMT-Llama-3.2-1B-Instruct-8x1024-mem16-lora-babilong-qa1-5_ct-v3.1/model.safetensors"

In [8]:
torch.set_default_device("cuda:0")

In [9]:
dtype = torch.bfloat16
torch.set_default_dtype(dtype)
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f1de7fea620>

In [10]:
# load base model
source_model = AutoModelForCausalLM.from_pretrained("unsloth/Llama-3.2-1B-Instruct",
                                                    attn_implementation="flash_attention_2",
                                                    torch_dtype=dtype,
                                                    device_map="cpu")
source_model.eval()
#source_model.lm_head = torch.nn.Identity()
#reference_model = copy.deepcopy(source_model)

tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct")

In [11]:
from peft import get_peft_model, LoraConfig, TaskType

In [12]:
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    inference_mode=True, 
    r=8, 
    lora_alpha=32, 
    lora_dropout=0.1,
    )
source_model = get_peft_model(source_model, peft_config)

In [13]:
# after wrap base model in original ARMT and ARMT with grouped batching, and load pretrained weigths
# the actual segment_size for this model is segment_size - mem_size, so we will use it later
segment_size = 1024
mem_size = 16
segment_alignment = "left"
attend_to_previous_input = False
device = "cpu"
max_n_segments = 32
mem_cell_args = dict(
    base_model=source_model,
    num_mem_tokens=mem_size,
    d_mem=64,
    layers_attr="model.model.layers",
    wrap_pos=False,
    correction=True,
)

cell = AssociativeMemoryCell(**mem_cell_args)
original_model = AssociativeRecurrentWrapper(cell,
                                            segment_size=segment_size-mem_size,
                                            max_n_segments=max_n_segments,
                                            segment_alignment=segment_alignment,
                                            attend_to_previous_input=attend_to_previous_input,
).to(device)

if "safetensors" in armt_cpt_path:
    from safetensors.torch import load_model
    load_model(original_model, armt_cpt_path, device="cuda:0")
else:
    cpt = torch.load(armt_cpt_path, map_location=device)
    original_model.load_state_dict(cpt, strict=True)
original_model.to("cuda")

AssociativeRecurrentWrapper(
  (memory_cell): AssociativeMemoryCell(
    (model): PeftModelForCausalLM(
      (base_model): LoraModel(
        (model): LlamaForCausalLM(
          (model): LlamaModel(
            (embed_tokens): Embedding(128256, 2048, padding_idx=128004)
            (layers): ModuleList(
              (0-15): 16 x AssociativeLayerWrapper(
                (W_mq): Linear(in_features=2048, out_features=64, bias=False)
                (W_mk): Linear(in_features=2048, out_features=64, bias=False)
                (W_mv): Linear(in_features=2048, out_features=2048, bias=False)
                (W_mb): Linear(in_features=2048, out_features=1, bias=True)
                (layer): LlamaDecoderLayer(
                  (self_attn): LlamaAttention(
                    (q_proj): lora.Linear(
                      (base_layer): Linear(in_features=2048, out_features=2048, bias=False)
                      (lora_dropout): ModuleDict(
                        (default): Dropout(p=0.1, inp

In [14]:
# merge lora
from safetensors.torch import save_model
merge_and_save = True
unmerge_and_save = False
if merge_and_save or unmerge_and_save:
    if merge_and_save:
        original_model.memory_cell.model.merge_and_unload()
    if unmerge_and_save:
        original_model.memory_cell.model.unload()
    original_model.memory_cell.model = original_model.memory_cell.model.base_model.model


In [15]:
armt_model = copy.deepcopy(original_model)
grouped_states = get_grouped_states(armt_model)
grouped_layer = make_grouped_layer_from_single_layer(
    copy.deepcopy(armt_model.memory_cell.model.model.layers[0]), *grouped_states)
# grouped_layer._grouped_execution = True
# grouped_layer._skip_associating = True
armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)


In [16]:
model_config = source_model.config

In [17]:
batcher = GroupedBatcher(
    armt_grouped_model, 
    n_layers=model_config.num_hidden_layers, 
    seg_size=segment_size, 
    hid_dim=model_config.hidden_size, 
    pos_embed_dim=model_config.hidden_size
)
executor = ArmtGroupedExecutor(armt_grouped_model, grouped_layer, batcher)


In [30]:
executor

ArmtGroupedExecutor(
  (armt_model): AssociativeRecurrentWrapper(
    (memory_cell): AssociativeMemoryCell(
      (model): LlamaForCausalLM(
        (model): LlamaModel(
          (embed_tokens): Embedding(128256, 2048, padding_idx=128004)
          (layers): ModuleList(
            (0): AssociativeLayerWrapper(
              (W_mq): Linear(in_features=2048, out_features=64, bias=False)
              (W_mk): Linear(in_features=2048, out_features=64, bias=False)
              (W_mv): Linear(in_features=2048, out_features=2048, bias=False)
              (W_mb): Linear(in_features=2048, out_features=1, bias=True)
              (layer): LlamaDecoderLayer(
                (self_attn): LlamaAttention(
                  (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
                  (k_proj): Linear(in_features=2048, out_features=512, bias=False)
                  (v_proj): Linear(in_features=2048, out_features=512, bias=False)
                  (o_proj): Linear(in_feature

## Test 1 - check on random ids

In [18]:
torch.cuda.empty_cache()

In [19]:
num_segments = 10
input_ids = torch.randint(
    0, 5000, 
    (1, num_segments*(segment_size-mem_size)), 
    dtype=torch.long, 
    device="cuda"
)


In [20]:
original_model.memory_cell.zero_mem()
reference_output = original_model.forward(input_ids)

In [21]:
output = executor.forward(input_ids)

jit compile As: [torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048]), torch.Size([1024, 2048])] Bs: [torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64]), torch.Size([2048, 64])]

// Gemm operator cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8
using cutlass_tensorop_bf16_s16816gemm_grouped_bf16_256x128_64x3_tt_align8_base =
  typenam

In [22]:
output.logits

tensor([[[ 4.5625,  5.3125,  7.0312,  ..., -2.2969, -2.2969, -2.2969],
         [ 7.2500,  8.1250,  3.5938,  ..., -1.2109, -1.2109, -1.2109],
         [ 5.5625,  6.8750,  3.8594,  ..., -0.5156, -0.5156, -0.5156],
         ...,
         [-0.0525,  3.1719,  1.2109,  ...,  0.8320,  0.8320,  0.8320],
         [ 1.5547,  4.5625,  1.2422,  ..., -1.0547, -1.0547, -1.0547],
         [ 0.6406,  3.4531,  0.2910,  ..., -1.5625, -1.5625, -1.5625]]],
       device='cuda:0')

In [23]:
reference_output.logits

tensor([[[ 4.5625,  5.3125,  7.0312,  ..., -2.2969, -2.2969, -2.2969],
         [ 7.2500,  8.1250,  3.5938,  ..., -1.2109, -1.2109, -1.2109],
         [ 5.5625,  6.8750,  3.8594,  ..., -0.5156, -0.5156, -0.5156],
         ...,
         [-0.0942,  3.1406,  1.1641,  ...,  0.8945,  0.8945,  0.8945],
         [ 1.5312,  4.5625,  1.2031,  ..., -1.0078, -1.0078, -1.0078],
         [ 0.6680,  3.4531,  0.2715,  ..., -1.6094, -1.6094, -1.6094]]],
       device='cuda:0')

In [24]:
torch.norm(output.logits-reference_output.logits)/torch.norm(reference_output.logits)

tensor(0.0187, device='cuda:0')

## Test 2 - check on some short text and gradually increase length

In [25]:
base_text = "The invention of the printing press by Johannes Gutenberg in the 15th century revolutionized the way information was shared. Before this, books were copied by hand, making them rare and expensive. The printing press allowed for faster and cheaper production of books, leading to a wider spread of knowledge. This innovation played a key role in the Renaissance, the Reformation, and the Scientific Revolution by making texts more accessible to the general public."
stacked_text = " ".join([base_text]*2000).strip()
messages = [
    {"role": "user", "content": stacked_text}
]
input_ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=False,
    return_tensors="pt"
)
input_ids = input_ids[..., :2*(segment_size-mem_size)]
input_ids.shape

Token indices sequence length is longer than the specified maximum sequence length for this model (172031 > 131072). Running this sequence through the model will result in indexing errors


torch.Size([1, 2016])

In [26]:
segments = [1, 2, 4, 8, 16, 32, 64, 128]
errors = []
for segm in segments:
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=False,
        return_tensors="pt"
    )
    input_ids = input_ids[..., :segm*(segment_size-mem_size)]
    original_model.memory_cell.zero_mem()
    reference_output = original_model.forward(input_ids)
    executor.armt_model.memory_cell.zero_mem()
    output = executor.forward(input_ids)
    diff = torch.norm(output.logits-reference_output.logits)/torch.norm(reference_output.logits)
    print(f"Norm on text with {segm} segments: {diff}")
    torch.cuda.empty_cache()
    errors.append(diff)

Norm on text with 1 segments: 0.0
Norm on text with 2 segments: 0.010986328125
Norm on text with 4 segments: 0.014892578125
Norm on text with 8 segments: 0.0174560546875
Norm on text with 16 segments: 0.0189208984375
Norm on text with 32 segments: 0.0186767578125


OutOfMemoryError: CUDA out of memory. Tried to allocate 15.41 GiB. GPU 0 has a total capacity of 79.11 GiB of which 4.80 GiB is free. Including non-PyTorch memory, this process has 74.30 GiB memory in use. Of the allocated memory 69.29 GiB is allocated by PyTorch, and 4.36 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [31]:
import numpy as np
[np.round(el.float().cpu().numpy()*100, 2) for el in errors]

[np.float32(0.0),
 np.float32(1.1),
 np.float32(1.49),
 np.float32(1.75),
 np.float32(1.89),
 np.float32(1.87)]

In [29]:
for segm in range(1,17):
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=False,
        return_tensors="pt"
    )
    input_ids = input_ids[..., :segm*(segment_size-mem_size)]
    original_model.memory_cell.zero_mem()
    reference_output = original_model.forward(input_ids)
    executor.armt_model.memory_cell.zero_mem()
    output = executor.forward(input_ids)
    diff = torch.norm(output.logits-reference_output.logits)/torch.norm(reference_output.logits)
    print(f"Norm on text with {segm} segments: {diff}")

Norm on text with 1 segments: 0.0
Norm on text with 2 segments: 0.010498046875
Norm on text with 3 segments: 0.01336669921875
Norm on text with 4 segments: 0.01470947265625
Norm on text with 5 segments: 0.015625
Norm on text with 6 segments: 0.0164794921875
Norm on text with 7 segments: 0.0169677734375
Norm on text with 8 segments: 0.0174560546875
Norm on text with 9 segments: 0.017822265625
Norm on text with 10 segments: 0.01806640625
Norm on text with 11 segments: 0.018310546875
Norm on text with 12 segments: 0.0184326171875
Norm on text with 13 segments: 0.0185546875
Norm on text with 14 segments: 0.0186767578125
Norm on text with 15 segments: 0.018798828125
Norm on text with 16 segments: 0.018798828125


## Test 3 - check on BABILong qa1

In [58]:
armt_model = copy.deepcopy(original_model)
grouped_states = get_grouped_states(armt_model)
grouped_layer = make_grouped_layer_from_single_layer(
    copy.deepcopy(armt_model.memory_cell.model.model.layers[0]), *grouped_states)
# grouped_layer._grouped_execution = True
# grouped_layer._skip_associating = True
#grouped_layer.generate_mode = True
armt_grouped_model, source_model_layers = make_grouped_model_from_naive(armt_model, grouped_layer)


In [59]:
batcher = GroupedBatcher(
    armt_grouped_model, 
    n_layers=model_config.num_hidden_layers, 
    seg_size=segment_size, 
    hid_dim=model_config.hidden_size, 
    pos_embed_dim=model_config.hidden_size
)
executor = ArmtGroupedExecutor(armt_grouped_model, grouped_layer, batcher, original_model)


In [60]:
from babilong.prompts import DEFAULT_PROMPTS, DEFAULT_TEMPLATE, get_formatted_input
from tqdm import tqdm
import datasets
from pathlib import Path
import json
import pandas as pd

In [61]:
tasks = ["qa1", "qa2"]
split_names = ["2k", "4k", "8k"]
dataset_name = "RMT-team/babilong"
results_folder = "./test_res"
model_name = "unsloth/Llama-3.2-1B-Instruct"
use_instruction = True
use_examples = True
use_post_prompt = True
use_chat_template = True
api_url = False

In [62]:
model = executor
model_cpt = "bl_model_mem_code_fix_executor_mem_patch_armt-1b-it-v2"
model.name_or_path = "custom_rmt"
model.device = "cuda"

In [63]:
generate_kwargs = {
    'max_new_tokens': 20,
    'max_length': None,
    'num_beams': 1,
    'do_sample': False,
    'temperature': None,
    'top_p': None,
    'top_k': None,
    'pad_token_id': tokenizer.pad_token_id,
    'eos_token_id': tokenizer.eos_token_id,
    #'logits_processor': [NormLogitsWrapper()],
}

In [64]:
template_to_use = DEFAULT_TEMPLATE
print(f'prompt template:\n{template_to_use}')

for task in tqdm(tasks, desc='tasks'):
    # configure the prompt
    prompt_cfg = {
        'instruction': DEFAULT_PROMPTS[task]['instruction'] if use_instruction else '',
        'examples': DEFAULT_PROMPTS[task]['examples'] if use_examples else '',
        'post_prompt': DEFAULT_PROMPTS[task]['post_prompt'] if use_post_prompt else '',
        'template': template_to_use,
        'chat_template': use_chat_template,
    }
    prompt_name = [f'{k}_yes' if prompt_cfg[k] else f'{k}_no' for k in prompt_cfg if k != 'template']
    prompt_name = '_'.join(prompt_name)

    for split_name in tqdm(split_names, desc='lengths'):
        # load dataset
        data = datasets.load_dataset(dataset_name, split_name)
        task_data = data[task]#.select([1])

        # Prepare files with predictions, prompt, and generation configurations
        outfile = Path(f'{results_folder}/{model_name.replace("../", "")}/{model_cpt.replace("../", "")}/{task}_{split_name}_{prompt_name}.csv')
        outfile.parent.mkdir(parents=True, exist_ok=True)
        cfg_file = f'./{results_folder}/{model_name.replace("../", "")}/{model_cpt.replace("../", "")}/{task}_{split_name}_{prompt_name}.json'
        json.dump({'prompt': prompt_cfg, 'generate_kwargs': generate_kwargs}, open(cfg_file, 'w'), indent=4)

        df = pd.DataFrame({'target': [], 'output': [], 'question': []})

        for sample in tqdm(task_data, desc=f'task: {task} length: {split_name}'):
            target = sample['target']
            context = sample['input']
            question = sample['question']

            # format input text
            input_text = get_formatted_input(context, question, prompt_cfg['examples'],
                                             prompt_cfg['instruction'], prompt_cfg['post_prompt'],
                                             template=prompt_cfg['template'])
            if api_url:
                # model is running via llamacpp's serve command
                headers = {'Content-Type': 'application/json'}
                if generate_kwargs['temperature'] is None:
                    generate_kwargs['temperature'] = 0.0

                if use_chat_template:
                    input_text = [{'role': 'user', 'content': input_text}]
                    model_inputs = tokenizer.apply_chat_template(input_text, tokenize=True,
                                                                 add_generation_prompt=True)
                else:
                    model_inputs = tokenizer.encode(input_text, add_special_tokens=True)

                request_data = {'prompt': model_inputs, 'temperature': generate_kwargs['temperature']}
                response = requests.post(api_url, headers=headers, json=request_data).json()
                output = response['content'].strip()
            else:
                # generate output using local model
                if model.name_or_path in ['THUDM/chatglm3-6b-128k', 'THUDM/LongAlign-6B-64k-base', 'THUDM/LongAlign-6B-64k']:
                    # have to add special code to run chatglm as tokenizer.chat_template tokenization is not
                    # the same as in model.chat (recommended in https://huggingface.co/THUDM/chatglm3-6b-128k)
                    with torch.no_grad():
                        output, _ = model.chat(tokenizer, input_text, history=[], **generate_kwargs)
                else:
                    if use_chat_template:
                        input_text = [{'role': 'user', 'content': input_text}]
                        model_inputs = tokenizer.apply_chat_template(input_text, add_generation_prompt=True,
                                                                     return_tensors='pt', return_dict=model.name_or_path=="custom_rmt").to(model.device)
                        if model.name_or_path != "custom_rmt":
                            model_inputs = {'input_ids': model_inputs}
                    else:
                        model_inputs = tokenizer(input_text, return_tensors='pt',
                                                 add_special_tokens=True).to(model.device)

                    sample_length = model_inputs['input_ids'].shape[1]
                    with torch.no_grad():
                        #print(model_inputs["input_ids"].shape)
                        #print(model_inputs)
                        last_segm = model_inputs["input_ids"].shape[-1] // (1024 - 16) * (1024 - 16)
                        prev_ids = model_inputs["input_ids"][..., :last_segm]
                        #print(prev_ids.shape)
                        #print(prev_ids)
                        output = model.generate(**model_inputs, **generate_kwargs)
                        # we need to reset memory states between samples for activation-beacon models
                        if 'activation-beacon' in model.name_or_path and hasattr(model, 'memory'):
                            model.memory.reset()
                    if model.name_or_path != "custom_rmt":
                        output = output[0][sample_length:]
                    else:
                        output = output[0]
                    output = tokenizer.decode(output, skip_special_tokens=True).strip()

            df.loc[len(df)] = [target, output, question]
            # write results to csv file
            df.to_csv(outfile, escapechar='\\')

prompt template:
{instruction}

{examples}

{post_prompt}

<context>
{context}
</context>

Question: {question}


tasks:   0%|          | 0/2 [00:00<?, ?it/s]
lengths:   0%|          | 0/3 [00:00<?, ?it/s][A

task: qa1 length: 2k:   0%|          | 0/100 [00:00<?, ?it/s][A[A

task: qa1 length: 2k:   1%|          | 1/100 [00:00<00:30,  3.28it/s][A[A

task: qa1 length: 2k:   2%|▏         | 2/100 [00:00<00:29,  3.33it/s][A[A

task: qa1 length: 2k:   3%|▎         | 3/100 [00:00<00:28,  3.36it/s][A[A

task: qa1 length: 2k:   4%|▍         | 4/100 [00:01<00:28,  3.34it/s][A[A

task: qa1 length: 2k:   5%|▌         | 5/100 [00:01<00:28,  3.28it/s][A[A

task: qa1 length: 2k:   6%|▌         | 6/100 [00:01<00:27,  3.39it/s][A[A

task: qa1 length: 2k:   7%|▋         | 7/100 [00:02<00:26,  3.48it/s][A[A

task: qa1 length: 2k:   8%|▊         | 8/100 [00:02<00:26,  3.45it/s][A[A

task: qa1 length: 2k:   9%|▉         | 9/100 [00:02<00:25,  3.50it/s][A[A

task: qa1 length: 2k:  10%|█         | 10/100 [00:02<00:25,  3.47it/s][A[A

task: qa1 length: 2k:  11%|█         | 11/100 [00:03<00:27,  3.2

In [65]:
tasks = ["qa1", "qa2"]
split_names = ["2k", "4k", "8k"]
dataset_name = "RMT-team/babilong"
results_folder = "./test_res"
model_name = "unsloth/Llama-3.2-1B-Instruct"
use_instruction = True
use_examples = True
use_post_prompt = True
use_chat_template = True
api_url = False

In [66]:
model = original_model
model_cpt = "bl_model_vanilla_armt-1b-it-v2"
model.name_or_path = "custom_rmt"
model.device = "cuda"

In [67]:
generate_kwargs = {
    'max_new_tokens': 20,
    'max_length': None,
    'num_beams': 1,
    'do_sample': False,
    'temperature': None,
    'top_p': None,
    'top_k': None,
    'pad_token_id': tokenizer.pad_token_id,
    'eos_token_id': tokenizer.eos_token_id,
    #'logits_processor': [NormLogitsWrapper()],
}

In [68]:
template_to_use = DEFAULT_TEMPLATE
print(f'prompt template:\n{template_to_use}')

for task in tqdm(tasks, desc='tasks'):
    # configure the prompt
    prompt_cfg = {
        'instruction': DEFAULT_PROMPTS[task]['instruction'] if use_instruction else '',
        'examples': DEFAULT_PROMPTS[task]['examples'] if use_examples else '',
        'post_prompt': DEFAULT_PROMPTS[task]['post_prompt'] if use_post_prompt else '',
        'template': template_to_use,
        'chat_template': use_chat_template,
    }
    prompt_name = [f'{k}_yes' if prompt_cfg[k] else f'{k}_no' for k in prompt_cfg if k != 'template']
    prompt_name = '_'.join(prompt_name)

    for split_name in tqdm(split_names, desc='lengths'):
        # load dataset
        data = datasets.load_dataset(dataset_name, split_name)
        task_data = data[task]#.select([1])

        # Prepare files with predictions, prompt, and generation configurations
        outfile = Path(f'{results_folder}/{model_name.replace("../", "")}/{model_cpt.replace("../", "")}/{task}_{split_name}_{prompt_name}.csv')
        outfile.parent.mkdir(parents=True, exist_ok=True)
        cfg_file = f'./{results_folder}/{model_name.replace("../", "")}/{model_cpt.replace("../", "")}/{task}_{split_name}_{prompt_name}.json'
        json.dump({'prompt': prompt_cfg, 'generate_kwargs': generate_kwargs}, open(cfg_file, 'w'), indent=4)

        df = pd.DataFrame({'target': [], 'output': [], 'question': []})

        for sample in tqdm(task_data, desc=f'task: {task} length: {split_name}'):
            target = sample['target']
            context = sample['input']
            question = sample['question']

            # format input text
            input_text = get_formatted_input(context, question, prompt_cfg['examples'],
                                             prompt_cfg['instruction'], prompt_cfg['post_prompt'],
                                             template=prompt_cfg['template'])
            if api_url:
                # model is running via llamacpp's serve command
                headers = {'Content-Type': 'application/json'}
                if generate_kwargs['temperature'] is None:
                    generate_kwargs['temperature'] = 0.0

                if use_chat_template:
                    input_text = [{'role': 'user', 'content': input_text}]
                    model_inputs = tokenizer.apply_chat_template(input_text, tokenize=True,
                                                                 add_generation_prompt=True)
                else:
                    model_inputs = tokenizer.encode(input_text, add_special_tokens=True)

                request_data = {'prompt': model_inputs, 'temperature': generate_kwargs['temperature']}
                response = requests.post(api_url, headers=headers, json=request_data).json()
                output = response['content'].strip()
            else:
                # generate output using local model
                if model.name_or_path in ['THUDM/chatglm3-6b-128k', 'THUDM/LongAlign-6B-64k-base', 'THUDM/LongAlign-6B-64k']:
                    # have to add special code to run chatglm as tokenizer.chat_template tokenization is not
                    # the same as in model.chat (recommended in https://huggingface.co/THUDM/chatglm3-6b-128k)
                    with torch.no_grad():
                        output, _ = model.chat(tokenizer, input_text, history=[], **generate_kwargs)
                else:
                    if use_chat_template:
                        input_text = [{'role': 'user', 'content': input_text}]
                        model_inputs = tokenizer.apply_chat_template(input_text, add_generation_prompt=True,
                                                                     return_tensors='pt', return_dict=model.name_or_path=="custom_rmt").to(model.device)
                        if model.name_or_path != "custom_rmt":
                            model_inputs = {'input_ids': model_inputs}
                    else:
                        model_inputs = tokenizer(input_text, return_tensors='pt',
                                                 add_special_tokens=True).to(model.device)

                    sample_length = model_inputs['input_ids'].shape[1]
                    with torch.no_grad():
                        #print(model_inputs["input_ids"].shape)
                        #print(model_inputs)
                        last_segm = model_inputs["input_ids"].shape[-1] // (1024 - 16) * (1024 - 16)
                        prev_ids = model_inputs["input_ids"][..., :last_segm]
                        #print(prev_ids.shape)
                        #print(prev_ids)
                        output = model.generate(**model_inputs, **generate_kwargs)
                        # we need to reset memory states between samples for activation-beacon models
                        if 'activation-beacon' in model.name_or_path and hasattr(model, 'memory'):
                            model.memory.reset()
                    if model.name_or_path != "custom_rmt":
                        output = output[0][sample_length:]
                    else:
                        output = output[0]
                    output = tokenizer.decode(output, skip_special_tokens=True).strip()

            df.loc[len(df)] = [target, output, question]
            # write results to csv file
            df.to_csv(outfile, escapechar='\\')

prompt template:
{instruction}

{examples}

{post_prompt}

<context>
{context}
</context>

Question: {question}


tasks:   0%|          | 0/2 [00:00<?, ?it/s]
lengths:   0%|          | 0/3 [00:00<?, ?it/s][A

task: qa1 length: 2k:   0%|          | 0/100 [00:00<?, ?it/s][A[A

task: qa1 length: 2k:   1%|          | 1/100 [00:00<00:13,  7.58it/s][A[A

task: qa1 length: 2k:   2%|▏         | 2/100 [00:00<00:12,  7.56it/s][A[A

task: qa1 length: 2k:   3%|▎         | 3/100 [00:00<00:12,  7.54it/s][A[A

task: qa1 length: 2k:   4%|▍         | 4/100 [00:00<00:12,  7.62it/s][A[A

task: qa1 length: 2k:   5%|▌         | 5/100 [00:00<00:13,  7.21it/s][A[A

task: qa1 length: 2k:   6%|▌         | 6/100 [00:00<00:11,  7.91it/s][A[A

task: qa1 length: 2k:   7%|▋         | 7/100 [00:00<00:11,  7.91it/s][A[A

task: qa1 length: 2k:   8%|▊         | 8/100 [00:01<00:11,  7.85it/s][A[A

task: qa1 length: 2k:   9%|▉         | 9/100 [00:01<00:11,  7.79it/s][A[A

task: qa1 length: 2k:  10%|█         | 10/100 [00:01<00:11,  7.76it/s][A[A

task: qa1 length: 2k:  11%|█         | 11/100 [00:01<00:10,  8.2

In [69]:
import numpy as np
def compare_generated_ids(orig, opt, tokenizer):
    toks_orig = np.array(tokenizer(orig)["input_ids"])
    toks_opt = np.array(tokenizer(opt)["input_ids"])
    length = min([len(toks_opt), len(toks_orig)])
    #print(toks_opt, toks_orig, length)
    #print(np.where(toks_orig[:length] == toks_opt[:length]))
    match = np.mean(toks_orig[:length] == toks_opt[:length])
    return match

In [70]:
for split in split_names:
    path_orig = f"./test_res/unsloth/Llama-3.2-1B-Instruct/bl_model_vanilla_armt-1b-it-v2/qa1_{split}_instruction_yes_examples_yes_post_prompt_yes_chat_template_yes.csv"
    path_opt = f"./test_res/unsloth/Llama-3.2-1B-Instruct/bl_model_mem_code_fix_executor_mem_patch_armt-1b-it-v2/qa1_{split}_instruction_yes_examples_yes_post_prompt_yes_chat_template_yes.csv"
    orig_answers = pd.read_csv(path_orig)["output"]
    opt_answers = pd.read_csv(path_opt)["output"]
    scores = []
    for orig, opt in zip(orig_answers, opt_answers):
        scores.append(compare_generated_ids(orig, opt, tokenizer))
    print(split, f"{np.round(np.mean(scores)*100,2)}%")

2k 83.17%
4k 64.17%
8k 56.5%


# Now calc scores on Babilong

In [1]:
TASK_LABELS = {'qa1': ['bathroom', 'bedroom', 'garden', 'hallway', 'kitchen', 'office'], 
 'qa2': ['bathroom', 'bedroom', 'garden', 'hallway', 'kitchen', 'office'], 
 'qa3': ['bathroom', 'bedroom', 'garden', 'hallway', 'kitchen', 'office'], 
 'qa4': ['bathroom', 'bedroom', 'garden', 'hallway', 'kitchen', 'office'], 
 'qa5': ['Bill', 'Fred', 'Jeff', 'Mary', 'apple', 'football', 'milk'], 
 'qa6': ['no', 'yes'], 
 'qa7': ['none', 'one', 'three', 'two'], 
 'qa8': ['apple', 'football', 'milk', 'nothing'], 
 'qa9': ['no', 'yes'], 
 'qa10': ['maybe', 'no', 'yes'],
 'qa11': ['bathroom', 'bedroom', 'garden', 'hallway', 'kitchen', 'office'], 
 'qa12': ['bathroom', 'bedroom', 'garden', 'hallway', 'kitchen', 'office'], 
 'qa13': ['bathroom', 'bedroom', 'garden', 'hallway', 'kitchen', 'office'], 
 'qa14': ['bedroom', 'cinema', 'kitchen', 'office', 'park', 'school'], 
 'qa15': ['cat', 'mouse', 'sheep', 'wolf'], 
 'qa16': ['gray', 'green', 'white', 'yellow'], 
 'qa17': ['no', 'yes'], 
 'qa18': ['no', 'yes'], 
 'qa19': ['e,e', 'e,n', 'e,s', 'n,e', 'n,n', 'n,w', 's,e', 's,s', 's,w', 'w,n', 'w,s', 'w,w'], 
 'qa20': ['bedroom', 'bored', 'garden', 'hungry', 'kitchen', 'thirsty', 'tired']
}


def preprocess_output(output):
    output = output.lower()
    # take only the first sentence from output
    output = output.split('.')[0]
    # filter responses when model tries to generate examples
    output = output.split('<context>')[0]
    output = output.split('<example>')[0]
    output = output.split('Question')[0]
    return output


def preprocess_output_cot(output):
    output = output.lower()
    # take only the first sentence from output
    splitted_output = output.split('.')
    if len(splitted_output) == 1:
        output = splitted_output[-1]
    else:
        output = splitted_output[-2]
    # filter responses when model tries to generate examples
    output = output.split('<context>')[0]
    output = output.split('<example>')[0]
    output = output.split('Question')[0]
    return output


def compare_answers(target, output, question, task_labels, cot_answer=False):
    if cot_answer:
        output = preprocess_output_cot(output)
    else:
        output = preprocess_output(output)
    target = str(target).lower()
    task_labels = {str(label).lower() for label in task_labels}

    # extract labels that were mentioned in the question
    labels_in_question = {label for label in task_labels if label in str(question).lower()}
    # extract labels that were mentioned in the model output
    labels_in_output = {label for label in task_labels if label in output}
    # filter labels in the output to exclude mentioned in the question
    # mentions in questions are never targets
    labels_in_output = labels_in_output - labels_in_question

    # check if the target is the only prediction
    if ',' in target and len(target) > 3: 
        # if target contains multiple subtargets in qa8
        subtargets = target.split(',')
        num_subtargets = len(subtargets)
        if all([t in labels_in_output for t in subtargets]) and len(labels_in_output) == num_subtargets:
            return True
    else:
        if target in labels_in_output and len(labels_in_output) == 1:
            return True

    return False

In [2]:
SYSTEM_TEMPLATE = '{instruction}\n\n{examples}\n\n{post_prompt}'
USER_TEMPLATE = '<context>\n{context}\n</context>\n\nQuestion: {question}'
DEFAULT_TEMPLATE = f'{SYSTEM_TEMPLATE}\n\n{USER_TEMPLATE}'

CUSTOM_SYSTEM_PROMPTS = {
    # https://github.com/dvlab-research/LongLoRA/blob/2345c6d030f61ac3a031906386a103a5b05e0e6f/inference.py#L18
    'LONGLORA_LLAMA2':
        'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. '
        'Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. '
        'Please ensure that your responses are socially unbiased and positive in nature.\n\n'
        'If a question does not make any sense, or is not factually coherent, explain why instead of answering '
        'something not correct. If you don\'t know the answer to a question, please don\'t share false information.'
}


def get_formatted_input(context, question, examples, instruction, post_prompt, template=DEFAULT_TEMPLATE):
    # pre_prompt - general instruction
    # examples - in-context examples
    # post_prompt - any additional instructions after examples
    # context - text to use for qa
    # question - question to answer based on context
    formatted_input = template.format(instruction=instruction, examples=examples, post_prompt=post_prompt,
                                      context=context.strip(), question=question)
    return formatted_input.strip()


DEFAULT_PROMPTS = {
    'qa1': {
        'instruction':
            'I will give you context with the facts about positions of different persons hidden in some random text '
            'and a question. You need to answer the question based only on the information from the facts. '
            'If a person was in different locations, use the latest location to answer the question.',
        'examples':
            '<example>\n'
            'Charlie went to the hallway. Judith come back to the kitchen. Charlie travelled to balcony. '
            'Where is Charlie?\n'
            'Answer: The most recent location of Charlie is balcony.\n'
            '</example>\n\n'
            '<example>\n'
            'Alan moved to the garage. Charlie went to the beach. Alan went to the shop. Rouse '
            'travelled to balcony. Where is Alan?\n'
            'Answer: The most recent location of Alan is shop.\n'
            '</example>',
        'post_prompt':
            'Always return your answer in the following format: '
            'The most recent location of ’person’ is ’location’. Do not write anything else after that.'
    },
    'qa2': {
        'instruction':
            'I give you context with the facts about locations and actions of different persons '
            'hidden in some random text and a question.'
            'You need to answer the question based only on the information from the facts.\n'
            'If a person got an item in the first location and travelled to the second location '
            'the item is also in the second location. '
            'If a person dropped an item in the first location and moved to the second location '
            'the item remains in the first location.',
        'examples':
            '<example>\n'
            'Charlie went to the kitchen. Charlie got a bottle. Charlie moved to the balcony. '
            'Where is the bottle?\n'
            'Answer: The bottle is in the balcony.\n'
            '</example>\n'
            '<example>\n'
            'Alan moved to the garage. Alan got a screw driver. Alan moved to the kitchen. Where '
            'is the screw driver?\n'
            'Answer: The screw driver is in the kitchen.\n'
            '</example>',
        'post_prompt':
            'Always return your answer in the following format: The ’item’ is in ’location’. '
            'Do not write anything else after that.'
    },
    'qa3': {
        'instruction':
            'I give you context with the facts about locations and actions of different persons '
            'hidden in some random text and a question. '
            'You need to answer the question based only on the information from the facts.\n'
            'If a person got an item in the first location and travelled to the second location '
            'the item is also in the second location. '
            'If a person dropped an item in the first location and moved to the second location '
            'the item remains in the first location.',
        'examples':
            '<example>\n'
            'John journeyed to the bedroom. Mary grabbed the apple. Mary went back to the bathroom. '
            'Daniel journeyed to the bedroom. Daniel moved to the garden. Mary travelled to the kitchen. '
            'Where was the apple before the kitchen?\n'
            'Answer: Before the kitchen the apple was in the bathroom.\n'
            '</example>\n'
            '<example>\n'
            'John went back to the bedroom. John went back to the garden. John went back to the kitchen. '
            'Sandra took the football. Sandra travelled to the garden. Sandra journeyed to the bedroom. '
            'Where was the football before the bedroom?\n'
            'Answer: Before the bedroom the football was in the garden.\n'
            '</example>',
        'post_prompt':
            'Always return your answer in the following format: '
            'Before the $location_1$ the $item$ was in the $location_2$. Do not write anything else after that.'
    },
    'qa4': {
        'instruction':
            'I will give you context with the facts about different people, their location and actions, hidden in '
            'some random text and a question. '
            'You need to answer the question based only on the information from the facts.',
        'examples':
            '<example>\n'
            'The hallway is south of the kitchen. The bedroom is north of the kitchen. '
            'What is the kitchen south of?\n'
            'Answer: bedroom\n'
            '</example>\n'
            '<example>\n'
            'The garden is west of the bedroom. The bedroom is west of the kitchen. What is west of the bedroom?\n'
            'Answer: garden\n'
            '</example>',
        'post_prompt':
            'Your answer should contain only one word - location. Do not write anything else after that.'
    },
    'qa5': {
        'instruction':
            'I will give you context with the facts about locations and their relations hidden in some random text '
            'and a question. You need to answer the question based only on the information from the facts.',
        'examples':
            '<example>\n'
            'Mary picked up the apple there. Mary gave the apple to Fred. Mary moved to the bedroom. '
            'Bill took the milk there. Who did Mary give the apple to?\n'
            'Answer: Fred\n'
            '</example>\n'
            '<example>\n'
            'Jeff took the football there. Jeff passed the football to Fred. Jeff got the milk there. '
            'Bill travelled to the bedroom. Who gave the football?\n'
            'Answer: Jeff\n'
            '</example>\n'
            '<example>\n'
            'Fred picked up the apple there. Fred handed the apple to Bill. Bill journeyed to the bedroom. '
            'Jeff went back to the garden. What did Fred give to Bill?\n'
            'Answer: apple\n'
            '</example>',
        'post_prompt':
            'Your answer should contain only one word. Do not write anything else after that. '
            'Do not explain your answer.'
    },
    'qa6': {
        'instruction':
            'I will give you context with the facts about people and their locations hidden in some random text and a '
            'question. You need to answer the question based only on the information from the facts. '
            'If a person was in different locations, use the latest location the person was in to answer the question.',
        'examples':
            '<example>\n'
            'John travelled to the hallway. John travelled to the garden. Is John in the garden?\n'
            'Answer: yes\n'
            '</example>\n'
            '<example>\n'
            'Mary went to the office. Daniel journeyed to the hallway. Mary went to the bedroom. '
            'Sandra went to the garden. Is Mary in the office?\n'
            'Answer: no\n'
            '</example>\n',
        'post_prompt':
            'Your answer should contain only one word - $yes$ or $no$. Do not write anything else after that. '
            'Do not explain your answer.'
    },
    'qa7': {
        'instruction':
            'I will give you context with the facts about people and objects they carry, hidden in some random text '
            'and a question. You need to answer the question based only on the information from the facts.',
        'examples':
            '<example>\n'
            'Daniel went to the bedroom. Daniel got the apple there. How many objects is Daniel carrying?\n'
            'Answer: one\n'
            '</example>\n'
            '<example>\n'
            'Mary grabbed the apple there. Mary gave the apple to John. How many objects is Mary carrying?\n'
            'Answer: none\n'
            '</example>\n'
            '<example>\n'
            'Sandra travelled to the hallway. Sandra picked up the milk there. Sandra took the apple there. '
            'Mary travelled to the garden. How many objects is Sandra carrying?\n'
            'Answer: two\n'
            '</example>\n',
        'post_prompt':
            'Your answer should contain only one word - $none$ or $number_of_objects$. '
            'Do not write anything else after that. Do not explain your answer.',
    },
    'qa8': {
        'instruction':
            'I will give you context with the facts about people and objects they carry, hidden in some random text '
            'and a question. You need to answer the question based only on the information from the facts.',
        'examples':
            '<example>\n'
            'Sandra travelled to the garden. Mary grabbed the milk there. What is Mary carrying?\n'
            'Answer: milk\n'
            '</example>\n'
            '<example>\n'
            'Mary travelled to the kitchen. Sandra travelled to the office. John travelled to the office. '
            'Sandra discarded the milk there. What is Sandra carrying?\n'
            'Answer: nothing\n'
            '</example>\n'
            '<example>\n'
            'Daniel grabbed the apple there. Mary went to the office. Daniel moved to the garden. '
            'Daniel grabbed the milk there. Mary went to the kitchen. What is Daniel carrying?\n'
            "Answer: apple,milk\n"
            "</example>\n",
        'post_prompt':
            'Your answer should contain only one or two words: $nothing$ or $object$ or $object_1$, $object_2$. '
            'Do not write anything else. Do not explain your answer.'
    },
    'qa9': {
        'instruction':
            'I will give you context with the facts about people and their locations hidden in some random text and '
            'a question. You need to answer the question based only on the information from the facts. '
            'If a person was in different locations, use the latest location the person was in to answer the question.',
        'examples':
            '<example>\n'
            'John is not in the bathroom. Sandra is not in the bedroom. Is John in the bathroom?\n'
            'Answer: no\n'
            '</example>\n'
            '<example>\n'
            'Mary journeyed to the kitchen. John is in the bedroom. Sandra is not in the garden. '
            'Is Mary in the kitchen?\n'
            'Answer: yes\n'
            '</example>\n',
        'post_prompt':
            'Your answer should contain only one word - $yes$ or $no$. Do not write anything else. '
            'Do not explain your answer.'
    },
    'qa10': {
        'instruction':
            'I will give you context with the facts about people and their locations hidden in some random text and a '
            'question. You need to answer the question based only on the information from the facts. '
            'If a person was in different locations, use the latest location the person was in to answer the question.',
        'examples':
            '<example>\n'
            'Bill is in the kitchen. Julie is either in the school or the cinema. Is Bill in the bedroom?\n'
            'Answer: no\n'
            '</example>\n'
            '<example>\n'
            'Fred is in the bedroom. Mary is either in the school or the cinema. Is Mary in the school?\n'
            'Answer: maybe\n'
            '</example>\n'
            '<example>\n'
            'Fred is either in the kitchen or the park. Bill moved to the cinema. Is Bill in the cinema?\n'
            'Answer: yes\n'
            '</example>\n'
            '<context>\n',
        'post_prompt':
            'Your answer should contain only one word - $yes$ or $no$ or $maybe$. Do not write anything else. '
            'Do not explain your answer.'
    },
    'qa11': {
        'instruction':
            'I will give you context with the facts about people and their locations hidden in some random text and a '
            'question. You need to answer the question based only on the information from the facts. '
            'If a person was in different locations, use the latest location the person was in to answer the question.',
        'examples':
            '<example>\n'
            'Daniel journeyed to the hallway. After that he journeyed to the garden. Where is Daniel?\n'
            'Answer: garden\n'
            '</example>\n'
            '<example>\n'
            'Mary moved to the office. Afterwards she journeyed to the kitchen. Daniel went to the hallway. '
            'Then he journeyed to the garden. Where is Mary?\n'
            'Answer: kitchen\n'
            '</example>\n'
            '<example>\n'
            'Sandra moved to the kitchen. After that she went back to the hallway. Sandra moved to the bedroom. '
            'Then she went to the hallway. Mary moved to the bedroom. Afterwards she travelled to the bathroom. '
            'Where is Sandra?\n'
            'Answer: hallway\n'
            '</example>\n'
            '<context>\n',
        'post_prompt':
            'Your answer should contain only one word - location. Do not write anything else after that. '
            'Do not explain your answer.'
    },
    'qa12': {
        'instruction':
            'I will give you context with the facts about people and their locations hidden in some random text and a '
            'question. You need to answer the question based only on the information from the facts. '
            'If a person was in different locations, use the latest location the person was in to answer the question.',
        'examples':
            '<example>\n'
            'Mary and Daniel travelled to the bathroom. John and Daniel travelled to the office. Where is Daniel?\n'
            'Answer: office\n'
            '</example>\n'
            '<example>\n'
            'Sandra and Mary went back to the office. Daniel and Sandra went to the bedroom. Sandra and Mary travelled to the hallway. '
            'John and Mary went to the kitchen. Where is Mary?\n'
            'Answer: kitchen\n'
            '</example>\n'
            '<example>\n'
            'Daniel and Sandra went back to the hallway. Daniel and John moved to the office. Daniel and John moved to the garden. '
            'Daniel and Mary went back to the bathroom. Daniel and John went back to the kitchen. Daniel and Sandra went to the bathroom. '
            'Where is John?\n'
            'Answer: kitchen\n'
            '</example>\n'
            '<context>\n',
        'post_prompt':
            'Your answer should contain only one word - location. Do not write anything else after that. '
            'Do not explain your answer.'
    },
    'qa13': {
        'instruction':
            'I will give you context with the facts about people and their locations hidden in some random text and a '
            'question. You need to answer the question based only on the information from the facts. '
            'If a person was in different locations, use the latest location the person was in to answer the question.',
        'examples':
            '<example>\n'
            'Mary and Daniel travelled to the bathroom. Then they journeyed to the hallway. Where is Daniel?\n'
            'Answer: hallway\n'
            '</example>\n'
            '<example>\n'
            'Daniel and Sandra travelled to the kitchen. After that they journeyed to the hallway. Mary and Daniel travelled to the bedroom. '
            'After that they travelled to the hallway. Where is Sandra?\n'
            'Answer: hallway\n'
            '</example>\n'
            '<example>\n'
            'John and Mary moved to the bathroom. Then they travelled to the office. John and Mary went to the kitchen. '
            'Afterwards they went to the bedroom. John and Sandra moved to the bathroom. Following that they went back to the kitchen. '
            'Where is Mary?\n'
            'Answer: bedroom\n'
            '</example>\n'
            '<context>\n',
        'post_prompt':
            'Your answer should contain only one word - location. Do not write anything else after that. '
            'Do not explain your answer.'
    },
    'qa14': {
        'instruction':
            'I will give you context with the facts about people and their locations hidden in some random text and a '
            'question. You need to answer the question based only on the information from the facts. '
            'If a person was in different locations, use the latest location the person was in to answer the question.',
        'examples':
            '<example>\n'
            'Bill went back to the cinema yesterday. Julie went to the school this morning. Fred went to the park yesterday. '
            'Yesterday Julie went to the office. Where was Julie before the school?\n'
            'Answer: office\n'
            '</example>\n'
            '<example>\n'
            'This morning Fred went to the kitchen. Fred journeyed to the bedroom yesterday. Mary travelled to the bedroom this morning. '
            'Yesterday Mary went to the cinema. Where was Mary before the bedroom?\n'
            'Answer: cinema\n'
            '</example>\n'
            '<example>\n'
            'Yesterday Julie went back to the park. Julie went to the bedroom this morning. Bill journeyed to the cinema yesterday. '
            'This morning Bill went back to the park. This evening Julie went to the school. This afternoon Julie went back to the park. '
            'Where was Julie before the bedroom?\n'
            'Answer: park\n'
            '</example>\n'
            '<context>\n',
        'post_prompt':
            'Your answer should contain only one word - location. Do not write anything else after that. '
            'Do not explain your answer.'
    },
    'qa15': {
        'instruction':
            'I will give you context with the facts about animals, their names and relations. The facts and a question '
            'are hidden in some random text. You need to answer the question based only on the information from the facts.',
        'examples':
            '<example>\n'
            'Mice are afraid of wolves. Gertrude is a mouse. Cats are afraid of sheep. '
            'Winona is a mouse. Sheep are afraid of wolves. Emily is a mouse. Jessica is a wolf. '
            'What is gertrude afraid of?\n'
            'Answer: wolf\n'
            '</example>\n'
            '<example>\n'
            'Mice are afraid of wolves. Gertrude is a mouse. Cats are afraid of sheep. '
            'Winona is a mouse. Sheep are afraid of wolves. Emily is a mouse. Jessica is a wolf. '
            'What is jessica afraid of?\n'
            'Answer: cat\n'
            '</example>\n'
            '<example>\n'
            'Mice are afraid of cats. Wolves are afraid of sheep. Emily is a wolf. '
            'Cats are afraid of sheep. Gertrude is a wolf. Sheep are afraid of cats. Winona is a wolf. '
            'What is emily afraid of?\n'
            'Answer: sheep\n'
            '</example>\n'
            '<context>\n',
        'post_prompt':
            'Your answer should contain only one word - an animal species. Do not write anything else after that. '
            'Do not explain your answer.'
    },
    'qa16': {
        'instruction':
            'I will give you context with the facts about animals, their names and colors. The facts and a question '
            'are hidden in some random text. You need to answer the question based only on the information from the facts.',
        'examples':
            '<example>\n'
            'Lily is a frog. Bernhard is a frog. Bernhard is green. Brian is a lion. Brian is white. '
            'Julius is a swan. Julius is green. Lily is green. Greg is a swan. What color is Greg?\n'
            'Answer: green\n'
            '</example>\n'
            '<example>\n'
            'Julius is a lion. Lily is a rhino. Bernhard is a swan. Lily is white. Bernhard is green. '
            'Greg is a rhino. Greg is gray. Julius is white. Brian is a lion. What color is Brian?\n'
            'Answer: white\n'
            '</example>\n'
            '<example>\n'
            'Brian is a rhino. Julius is a lion. Bernhard is a lion. Greg is a swan. Brian is gray. '
            'Greg is white. Lily is a rhino. Bernhard is yellow. Lily is gray. What color is Julius?\n'
            'Answer: yellow\n'
            '</example>\n'
            '<context>\n',
        'post_prompt':
            'Your answer should contain only one word - a color. Do not write anything else after that. '
            'Do not explain your answer.'
    },
    'qa17': {
        'instruction':
            'I will give you context with the facts about different figures, their location and colors, hidden in '
            'some random text and a question. '
            'You need to answer the question based only on the information from the facts.',
        'examples':
            '<example>\n'
            'The triangle is above the pink rectangle. The blue square is to the left of the triangle. '
            'Is the pink rectangle to the right of the blue square?\n'
            'Answer: yes\n'
            '</example>\n'
            '<example>\n'
            'The red sphere is to the left of the yellow square. The red sphere is below the pink rectangle. '
            'Is the pink rectangle to the left of the yellow square?\n'
            'Answer: yes\n'
            '</example>'
            '<example>\n'
            'The red sphere is above the pink rectangle. The red sphere is to the right of the red square. '
            'Is the pink rectangle above the red square?\n'
            'Answer: no\n'
            '</example>',
        'post_prompt':
            'Your answer should contain only one word - $yes$ or $no$. Do not write anything else. '
            'Do not explain your answer.'
    },
    'qa18': {
        'instruction':
            'I will give you context with the facts about different objects and their sizes, hidden in '
            'some random text and a question. '
            'You need to answer the question based only on the information from the facts.',
        'examples':
            '<example>\n'
            'The box of chocolates fits inside the chest. The box is bigger than the chest. The box is bigger than the suitcase. '
            'The suitcase fits inside the box. The container is bigger than the box of chocolates. Does the box fit in the box of chocolates?\n'
            'Answer: no\n'
            '</example>\n'
            '<example>\n'
            'The suitcase is bigger than the container. The container fits inside the box. The chest is bigger than the chocolate.'
            'The suitcase fits inside the box. The chest fits inside the box. Does the chocolate fit in the box?\n'
            'Answer: yes\n'
            '</example>'
            '<example>\n'
            'The chocolate fits inside the box of chocolates. The suitcase fits inside the box. The chocolate fits inside the box. '
            'The box is bigger than the box of chocolates. The suitcase is bigger than the box of chocolates. Is the chocolate bigger than the box?\n'
            'Answer: no\n'
            '</example>',
        'post_prompt':
            'Your answer should contain only one word - $yes$ or $no$. Do not write anything else. '
            'Do not explain your answer.'
    },
    'qa19': {
        'instruction':
            'I will give you context with the facts about different places and their locations, hidden in '
            'some random text and a question. '
            'You need to answer the question based only on the information from the facts.',
        'examples':
            '<example>\n'
            'The office is east of the hallway. The kitchen is north of the office. The garden is west of the bedroom. '
            'The office is west of the garden. The bathroom is north of the garden. How do you go from the kitchen to the garden?\n'
            'Answer: s,e\n'
            '</example>\n'
            '<example>\n'
            'The bedroom is west of the hallway. The office is east of the garden. The garden is north of the kitchen. '
            'The kitchen is north of the bathroom. The hallway is west of the garden. How do you go from the kitchen to the hallway?\n'
            'Answer: n,w\n'
            '</example>\n'
            '<example>\n'
            'The bedroom is south of the hallway. The bathroom is east of the office. The kitchen is west of the garden. '
            'The garden is south of the office. The office is south of the bedroom. How do you go from the garden to the bedroom?\n'
            'Answer: n,n\n'
            '</example>\n',
        'post_prompt':
            'Your answer should contain only two letters, separated by a comma - ordinal directions. You can choose the letters from '
             '$n$, $s$, $e$ and $w$. Do not write anything else after that.'
    },
    'qa20': {
        'instruction':
            'I will give you context with the facts about people, their locations and condition hidden in some random text and a '
            'question. You need to answer the question based only on the information from the facts. '
            'If a person was in different locations, use the latest location the person was in to answer the question.',
        'examples':
            '<example>\n'
            'Sumit is tired. Where will sumit go?\n'
            'Answer: bedroom\n'
            '</example>\n'
            '<example>\n'
            'Yann is hungry. Yann journeyed to the kitchen. Why did yann go to the kitchen?\n'
            'Answer: hungry\n'
            '</example>\n'
            '<example>\n'
            'Antoine is thirsty. Yann is tired. Yann went back to the bedroom. Yann picked up the pajamas there.'
            'Jason is thirsty. Antoine went back to the kitchen. Why did antoine go to the kitchen?\n'
            'Answer: thirsty\n'
            '</example>\n'
            '<context>\n',
        'post_prompt':
            'Your answer should contain only one word - a person condition or a place. Do not write anything else after that. '
            'Do not explain your answer.'
    }
    

}

In [3]:
import os
import numpy as np
import pandas as pd

In [35]:
# extract results into table
results_folder = './test_res'


mod_names = [
    #"unsloth/Llama-3.2-1B-Instruct/bl_model_vanilla_armt-1b-it-v2",
    #"unsloth/Llama-3.2-1B-Instruct/bl_model_mem_code_fix_executor_mem_patch_armt-1b-it-v2",
    #"unsloth/Llama-3.2-1B-Instruct/DEBUG_fast_executor_v2_mem_patch_armt-1b-it-v2",
    #"unsloth/Llama-3.2-1B-Instruct/DEBUG_trained_fast_executor_v2_mem_patch_armt-1b-it-v2",

    "unsloth/Llama-3.2-1B-Instruct/orig_v2_model_paper_armt-1b-it-v2",
    #"unsloth/Llama-3.2-1B-Instruct/fast_executor_from_orig_paper_v2_mem_patch_armt-1b-it-v2",
    #"unsloth/Llama-3.2-1B-Instruct/fast_executor_from_orig_paper_wo_copy_v2_mem_patch_armt-1b-it-v2",
    #"unsloth/Llama-3.2-1B-Instruct/fast_executor_from_orig_paper_wo_copy_time_fix_v2_mem_patch_armt-1b-it-v2",
    "unsloth/Llama-3.2-1B-Instruct/fast_executor_from_orig_paper_time_fix_v2_mem_patch_armt-1b-it-v2",
    
]
disp_names = [
    #"Original-ARMT-Llama-3.2-1B-Instruct",
    #"Optimized-ARMT-Llama-3.2-1B-Instruct",
    #"Optimized-Fixed-ARMT-Llama-3.2-1B-Instruct",
    #"Optimized-HalfTrained-ARMT-Llama-3.2-1B-Instruct",

    "Original-ARMT-Llama-3.2-1B-Instruct - Time, sec.",
    #"Optimized-ARMT-Llama-3.2-1B-Instruct - Time, sec.",
    #"Optimized-ARMT-Llama-3.2-1B-Instruct - Time, sec. (without memory copy)",
    #"Optimized-ARMT-Llama-3.2-1B-Instruct - Time, sec. (without memory copy)",
    "Optimized-ARMT-Llama-3.2-1B-Instruct - Time, sec.",
]

all_dfs = []
for dat_idx, mod_name in enumerate(mod_names):
    #gkuzmin/rmt_it/babilong/babilong_evals/test_exps/gpt_2_lora_ift_flan_1024_seed_42/checkpoint-1000
    overall_results = {}
    #dname = ift_dataset_path.split("-")[0]
    disp_name = disp_names[dat_idx]
    if "-E" in disp_name:
        prompt_name = 'instruction_no_examples_yes_post_prompt_no_chat_template_no'
        all_names = [
            [f'./{mod_name}']]
    elif mod_name == "meta-llama/Llama-3.2-1B" or "data/pretrained_models" in mod_name:
        prompt_name = 'instruction_yes_examples_yes_post_prompt_yes_chat_template_no'
        all_names = [
            [f'./{mod_name}']]
    else:
        prompt_name = 'instruction_yes_examples_yes_post_prompt_yes_chat_template_yes'
    all_names = [
        [f'./{mod_name}']]
    names = all_names[0]
    for model_name in names:
        overall_results[disp_names[dat_idx]] = {}
        tasks = ["qa2"]# ['qa1', 'qa2', 'qa4', 'qa5']#, 'qa3', 'qa4', 'qa5']#, 'qa6', 'qa7', 'qa8', 'qa9', 'qa10']
        tasks = ['qa1', 'qa2']#, 'qa3', 'qa4', 'qa5']
        lengths = ['0k', '1k', '2k', '4k', '8k']#, '16k']#, '32k']#, '32k', '64k']
        lengths = ['0k', '1k', '2k', '4k', '8k', '16k', '32k', '64k']
        accuracy = np.zeros((len(tasks), len(lengths)))
        for j, task in enumerate(tasks):
            #overall_results[model_name][task] = {}
            for i, ctx_length in enumerate(lengths):
                if "lm_no_st" in model_name:
                    fname = f'./{results_folder_rmt}/{model_name}/{task}_{ctx_length}_{prompt_name}.csv'
                elif "mix" in model_name or "mergekit" in model_name:
                    fname = f'./{new_results_folder}/{model_name}/{task}_{ctx_length}_{prompt_name}.csv'
                else:
                    fname = f'./{results_folder}/{model_name}/{task}_{ctx_length}_{prompt_name}.csv'
                if not os.path.isfile(fname):
                    print(f'No such file: {fname}')
                    continue
        
                df = pd.read_csv(fname)
        
                if df['output'].dtype != object:
                    df['output'] = df['output'].astype(str)
                df['output'] = df['output'].fillna('')
        
                if ctx_length == "32k":
                    df.dropna(inplace=True)
                    print(len(df))
                df['correct'] = df.apply(lambda row: compare_answers(row['target'], row['output'],
                                                                     row['question'], TASK_LABELS[task]
                                                                     ), axis=1)
                score = df['correct'].sum()
                #if model_name == "Qwen/Qwen2.5-0.5B-Instruct/runs/lm/flan/Qwen/Qwen2.5-0.5B-Instruct/linear_adamw_wd1e-03_4086-128-1x4096_mem5_bs32_regular_bptt-4_lora_freeze_from_cpt_0-1/run_2/model_best/model.safetensors":
                #    overall_results["RMT4k-Qwen2.5-0.5B-Instruct"][task + "_" + ctx_length] = float(np.round(100 * score / len(df) if len(df) > 0 else 0, 2))
                #else:
                overall_results[disp_names[dat_idx]][task + "_" + ctx_length] = float(np.round(100 * score / len(df) if len(df) > 0 else 0, 2))
                accuracy[j, i] = 100 * score / len(df) if len(df) > 0 else 0
        #print(accuracy)
    overall_results = {key.split("/")[-1]:value for key,value in overall_results.items()}
    mod_df = pd.DataFrame.from_dict(overall_results)
    all_dfs.append(mod_df)
full_df = pd.concat(all_dfs, axis=1)
from matplotlib.colors import LinearSegmentedColormap
cmap = LinearSegmentedColormap.from_list('ryg', ["red", "yellow", "green"], N=256)
cmap = "BuGn"
# old cmap
#full_df.style.format(precision=2).background_gradient(axis=0, cmap='YlOrRd')
s = full_df.style
s.set_table_styles([
    {"selector": "tr", "props": "line-height: 15px;"},
    {"selector": "td,th", "props": "line-height: inherit; padding: 5px;"}
])
for l0 in ['qa1_64k', 'qa2_64k', 'qa3_16k', 'qa4_16k', 'qa5_16k']:
    s.set_table_styles({l0: [{'selector': '', 'props': 'border-bottom: 3px solid black;'}]}, overwrite=False, axis=1)
fres = s.format(precision=2).background_gradient(axis=None, cmap=cmap)
#fres = full_df.style.format(precision=2).background_gradient(axis=0, cmap=cmap)
from pandas import option_context

with option_context('display.max_colwidth', 20):
    display(fres)


100
100
100
100


Unnamed: 0,"Original-ARMT-Llama-3.2-1B-Instruct - Time, sec.","Optimized-ARMT-Llama-3.2-1B-Instruct - Time, sec."
qa1_0k,100.0,100.0
qa1_1k,100.0,100.0
qa1_2k,100.0,100.0
qa1_4k,100.0,100.0
qa1_8k,100.0,100.0
qa1_16k,100.0,100.0
qa1_32k,100.0,100.0
qa1_64k,70.0,69.0
qa2_0k,100.0,100.0
qa2_1k,100.0,100.0


In [37]:
# extract results into table
import json
results_folder = './test_res'


mod_names = [
    #"unsloth/Llama-3.2-1B-Instruct/bl_model_vanilla_armt-1b-it-v2",
    #"unsloth/Llama-3.2-1B-Instruct/bl_model_mem_code_fix_executor_mem_patch_armt-1b-it-v2",
    #"unsloth/Llama-3.2-1B-Instruct/DEBUG_fast_executor_v2_mem_patch_armt-1b-it-v2",
    #"unsloth/Llama-3.2-1B-Instruct/DEBUG_trained_fast_executor_v2_mem_patch_armt-1b-it-v2",

    "unsloth/Llama-3.2-1B-Instruct/orig_v2_model_paper_armt-1b-it-v2",
    #"unsloth/Llama-3.2-1B-Instruct/fast_executor_from_orig_paper_v2_mem_patch_armt-1b-it-v2",
    #"unsloth/Llama-3.2-1B-Instruct/fast_executor_from_orig_paper_wo_copy_v2_mem_patch_armt-1b-it-v2",
    #"unsloth/Llama-3.2-1B-Instruct/fast_executor_from_orig_paper_wo_copy_time_fix_v2_mem_patch_armt-1b-it-v2",
    "unsloth/Llama-3.2-1B-Instruct/fast_executor_from_orig_paper_time_fix_v2_mem_patch_armt-1b-it-v2",
    
]
disp_names = [
    #"Original-ARMT-Llama-3.2-1B-Instruct",
    #"Optimized-ARMT-Llama-3.2-1B-Instruct",
    #"Optimized-Fixed-ARMT-Llama-3.2-1B-Instruct",
    #"Optimized-HalfTrained-ARMT-Llama-3.2-1B-Instruct",

    "Original-ARMT-Llama-3.2-1B-Instruct - Time, sec.",
    #"Optimized-ARMT-Llama-3.2-1B-Instruct - Time, sec.",
    #"Optimized-ARMT-Llama-3.2-1B-Instruct - Time, sec. (without memory copy)",
    #"Optimized-ARMT-Llama-3.2-1B-Instruct - Time, sec. (without memory copy)",
    "Optimized-ARMT-Llama-3.2-1B-Instruct - Time, sec.",
]

all_dfs = []
for dat_idx, mod_name in enumerate(mod_names):
    #gkuzmin/rmt_it/babilong/babilong_evals/test_exps/gpt_2_lora_ift_flan_1024_seed_42/checkpoint-1000
    overall_results = {}
    #dname = ift_dataset_path.split("-")[0]
    disp_name = disp_names[dat_idx]
    if "-E" in disp_name:
        prompt_name = 'instruction_no_examples_yes_post_prompt_no_chat_template_no'
        all_names = [
            [f'./{mod_name}']]
    elif mod_name == "meta-llama/Llama-3.2-1B" or "data/pretrained_models" in mod_name:
        prompt_name = 'instruction_yes_examples_yes_post_prompt_yes_chat_template_no'
        all_names = [
            [f'./{mod_name}']]
    else:
        prompt_name = 'instruction_yes_examples_yes_post_prompt_yes_chat_template_yes'
    all_names = [
        [f'./{mod_name}']]
    names = all_names[0]
    for model_name in names:
        overall_results[disp_names[dat_idx]] = {}
        tasks = ["qa2"]# ['qa1', 'qa2', 'qa4', 'qa5']#, 'qa3', 'qa4', 'qa5']#, 'qa6', 'qa7', 'qa8', 'qa9', 'qa10']
        tasks = ['qa1', 'qa2']#, 'qa3', 'qa4', 'qa5']
        lengths = ['0k', '1k', '2k', '4k', '8k']#, '16k']#, '32k']#, '32k', '64k']
        lengths = ['2k', '4k', '8k', '16k', '32k', '64k']
        accuracy = np.zeros((len(tasks), len(lengths)))
        for j, task in enumerate(tasks):
            #overall_results[model_name][task] = {}
            for i, ctx_length in enumerate(lengths):
                if "lm_no_st" in model_name:
                    fname = f'./{results_folder_rmt}/{model_name}/{task}_{ctx_length}_{prompt_name}.csv'
                elif "mix" in model_name or "mergekit" in model_name:
                    fname = f'./{new_results_folder}/{model_name}/{task}_{ctx_length}_{prompt_name}.csv'
                else:
                    fname = f'./{results_folder}/{model_name}/time_{task}_{ctx_length}_{prompt_name}.json'
                if not os.path.isfile(fname):
                    print(f'No such file: {fname}')
                    continue
                with open(fname, "r") as f:
                    df = json.load(f)
                
                overall_results[disp_names[dat_idx]][task + "_" + ctx_length] = np.round(float(df[task][ctx_length]), 2)#float(np.round(100 * score / len(df) if len(df) > 0 else 0, 2))
                accuracy[j, i] = 100 * score / len(df) if len(df) > 0 else 0
        #print(accuracy)
    overall_results = {key.split("/")[-1]:value for key,value in overall_results.items()}
    mod_df = pd.DataFrame.from_dict(overall_results)
    all_dfs.append(mod_df)
full_df = pd.concat(all_dfs, axis=1)
full_df["Speedup"] = full_df["Original-ARMT-Llama-3.2-1B-Instruct - Time, sec."] / full_df["Optimized-ARMT-Llama-3.2-1B-Instruct - Time, sec."]
from matplotlib.colors import LinearSegmentedColormap
cmap = LinearSegmentedColormap.from_list('ryg', ["red", "yellow", "green"], N=256)
cmap = "BuGn"
# old cmap
#full_df.style.format(precision=2).background_gradient(axis=0, cmap='YlOrRd')
s = full_df.style
s.set_table_styles([
    {"selector": "tr", "props": "line-height: 15px;"},
    {"selector": "td,th", "props": "line-height: inherit; padding: 5px;"}
])
for l0 in ['qa1_64k', 'qa2_64k', 'qa3_16k', 'qa4_16k', 'qa5_16k']:
    s.set_table_styles({l0: [{'selector': '', 'props': 'border-bottom: 3px solid black;'}]}, overwrite=False, axis=1)
fres = s.format(precision=2)#.background_gradient(axis=None, cmap=cmap)
#fres = full_df.style.format(precision=2).background_gradient(axis=0, cmap=cmap)
from pandas import option_context

with option_context('display.max_colwidth', 20):
    display(fres)


Unnamed: 0,"Original-ARMT-Llama-3.2-1B-Instruct - Time, sec.","Optimized-ARMT-Llama-3.2-1B-Instruct - Time, sec.",Speedup
qa1_2k,13.43,15.06,0.89
qa1_4k,22.45,17.99,1.25
qa1_8k,41.41,22.49,1.84
qa1_16k,79.16,33.12,2.39
qa1_32k,153.68,54.2,2.84
qa1_64k,302.15,94.36,3.2
qa2_2k,13.08,14.93,0.88
qa2_4k,22.66,18.21,1.24
qa2_8k,41.66,22.7,1.84
qa2_16k,79.8,33.38,2.39
