In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import _crop_past_key_values
import torch

In [2]:
# Load the causal LM, its tokenizer, and the dataset used by the cells below.
# The dataset loaded here is the RepoBench Python v1.1 hub dataset (not LongBench).
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

# Checkpoint identifier shared by the model and tokenizer loaders.
model_name = "deepseek-ai/deepseek-coder-1.3b-base"
# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint
# repository to run locally — confirm the checkpoint source is trusted.
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)


# Load the "cross_file_first" split, then keep only the row at index 0.
ds = load_dataset("tianyang/repobench_python_v1.1", split="cross_file_first")
ds = ds.select(list(range(1)))

Unrecognized keys in `rope_scaling` for 'rope_type'='linear': {'type'}


In [3]:
from transformers.cache_utils import DynamicCache

def get_key_values_with_indices(model, past_key_values, indices):
    """Keep only the cached key/value entries at the given sequence positions.

    Every layer's key and value tensor is indexed as ``tensor[:, :, indices, :]``,
    i.e. the selection happens on the third dimension (assumed to be the
    sequence axis of a ``(batch, heads, seq, head_dim)`` layout — confirm for
    caches that use a different layout).

    Args:
        model: Unused; kept so existing call sites keep working.
        past_key_values: ``None``, a legacy tuple/list of per-layer
            ``(key, value)`` pairs, or a cache object that supports ``len()``
            and per-layer indexing (e.g. ``DynamicCache``).
        indices: Index applied to the third dimension (integer tensor, list
            of ints, or slice).

    Returns:
        ``None`` when ``past_key_values`` is ``None``; a fresh ``DynamicCache``
        when a ``DynamicCache`` was passed; otherwise a tuple of per-layer
        ``(key, value)`` pairs. The input cache is never modified.
    """
    if past_key_values is None:
        return None

    selected_layers = []
    for layer_idx in range(len(past_key_values)):
        layer = past_key_values[layer_idx]
        selected_layers.append(
            (
                layer[0][:, :, indices, :],
                layer[1][:, :, indices, :],
            )
        )

    # The tuple/list check comes first so the legacy path never has to look at
    # the cache class at all.
    if isinstance(past_key_values, (tuple, list)) or not isinstance(past_key_values, DynamicCache):
        return tuple(selected_layers)

    # DynamicCache.batch_select_indices filters the *batch* dimension, which is
    # not what is wanted here; rebuild the cache from the sequence-sliced
    # tensors instead.
    pruned_cache = DynamicCache()
    for layer_idx, (keys, values) in enumerate(selected_layers):
        pruned_cache.update(keys, values, layer_idx)
    return pruned_cache

In [4]:
def get_punctuation_tokens(tokenizer):
    """Return the ids of all vocabulary entries whose text contains punctuation.

    A token matches when its string contains any of ``.``, ``,``, ``:``,
    newline, tab or ``!`` as a substring.

    Args:
        tokenizer: Object exposing ``get_vocab()`` that returns a mapping from
            token string to token id.

    Returns:
        list: Token ids of the matching entries, in vocabulary iteration order.
    """
    punctuation_list = ['.', ',', ':', "\n", "\t", "!"]
    # NOTE(review): vocabularies that store tokens in an escaped or byte-level
    # form may not contain literal "\n" / "\t" characters — confirm for the
    # tokenizer in use.
    return [
        token_id
        for token, token_id in tokenizer.get_vocab().items()
        if any(punctuation in token for punctuation in punctuation_list)
    ]

In [69]:
from transformers.generation.utils import _crop_past_key_values

def efficient_generate(model, tokenizer, input_ids, max_tokens=200, sink_length=8, end_length=32, window_stride=8, print_stream=False):
    """Greedily extend ``input_ids`` while feeding the model a thinned-out context.

    The context passed to the model keeps three groups of tokens: the first
    ``sink_length`` tokens, every ``window_stride``-th token, and the most
    recent ``end_length`` tokens. Kept tokens retain their original position
    ids. When a token slides out of the recent window it is dropped unless its
    position is a multiple of ``window_stride``.

    Args:
        model: Causal LM called as ``model(ids, position_ids=...)``; its output
            must expose ``logits`` and ``past_key_values``.
        tokenizer: Provides ``eos_token_id`` and ``batch_decode``.
        input_ids: Prompt token ids, shape ``(batch, seq)``. Position ids are
            built with a leading dimension of 1 and the EOS check only looks
            at row 0, so batch size 1 is the intended use.
        max_tokens: Upper bound on the total sequence length (prompt plus
            generated tokens). One token is always produced by the initial
            full-prompt pass, even if the prompt already reaches this bound.
        sink_length: Number of leading tokens that are always kept.
        end_length: Size of the trailing window that is always kept.
        window_stride: Tokens at positions that are multiples of this value
            are kept.
        print_stream: When true, print shapes and the decoded text at every
            step.

    Returns:
        torch.Tensor: The full (un-thinned) sequence of prompt and generated
        token ids.
    """
    model.eval()
    device = input_ids.device
    eos_token_id = tokenizer.eos_token_id

    def _is_done(tokens):
        # Stop at the length cap, or when row 0 just produced the EOS id.
        if tokens.shape[-1] >= max_tokens:
            return True
        return eos_token_id is not None and tokens[0, -1].item() == eos_token_id

    with torch.no_grad():
        current_tokens = input_ids
        position_ids = torch.arange(input_ids.shape[-1], device=device).unsqueeze(0)

        # Full-prompt pass: yields the first new token and the full KV cache.
        first_output = model(
            input_ids,
            position_ids=position_ids,
            use_cache=True
        )

        past_key_values = first_output.past_key_values
        selected_tokens = first_output.logits.argmax(dim=-1)
        if print_stream:
            print("Selected tokens: ", selected_tokens)

        current_tokens = torch.cat((current_tokens, selected_tokens[:, -1:]), dim=-1)
        position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)

        # Which of the tokens seen so far stay in the model's context.
        mask = torch.zeros(current_tokens.shape[-1], dtype=torch.bool, device=device)
        mask[:sink_length] = True
        mask[-end_length:] = True
        mask[::window_stride] = True

        if print_stream:
            print("Mask: ", mask)
            print("Position ids shape before: ", position_ids.shape)
        position_ids = position_ids[:, mask]
        if print_stream:
            print("Position ids shape: ", position_ids.shape)

        # The newest token has no cache entry yet, hence the [:-1].
        past_key_values = get_key_values_with_indices(model, past_key_values, position_ids[0, :-1])
        if print_stream:
            print("Context tokens shape before: ", current_tokens.shape)
        context_tokens = current_tokens[:, mask]
        if print_stream:
            print("Context tokens shape after: ", context_tokens.shape)

        while not _is_done(current_tokens):
            if print_stream:
                print("Current tokens shape: ", context_tokens.shape, " Position ids: ", position_ids.shape, " Key value shapes: ", past_key_values[0][0].shape)

            kept_count = position_ids.shape[-1]
            if kept_count > end_length + sink_length:
                # Index of the token that has just left the recent window.
                # The decision and the removal must refer to this same entry;
                # it survives only if it sits on a stride position.
                leaving_index = kept_count - end_length - 1
                if position_ids[0, leaving_index] % window_stride != 0:
                    keep = torch.ones(kept_count, dtype=torch.bool, device=device)
                    keep[leaving_index] = False
                    context_tokens = context_tokens[:, keep]
                    position_ids = position_ids[:, keep]

            # NOTE(review): the cache is cropped and tracked here but is not
            # passed to the forward call below, so every step recomputes the
            # whole thinned context.
            past_key_values = _crop_past_key_values(model, past_key_values, past_key_values[0][0].shape[2] - 1)
            if print_stream:
                print("Past key values shape before passing into output: ", past_key_values[0][0].shape)

            output = model(
                context_tokens,
                position_ids=position_ids,
            )

            past_key_values = output.past_key_values
            if print_stream:
                print("Generated past key values shape: ", output.past_key_values[0][0].shape)
            selected_tokens = output.logits.argmax(dim=-1)
            context_tokens = torch.cat((context_tokens, selected_tokens[:, -1:]), dim=-1)
            current_tokens = torch.cat((current_tokens, selected_tokens[:, -1:]), dim=-1)
            position_ids = torch.cat((position_ids, position_ids[:, -1:] + 1), dim=-1)

            if print_stream:
                print(tokenizer.batch_decode(current_tokens))

    return current_tokens

In [70]:
# Run the sparse-context generator once per row of the selected dataset.
for row in ds:
    # Tokenize the row's 'all_code' field into a PyTorch tensor and move it to
    # the device the model lives on.
    encoded_inputs = tokenizer.encode(row['all_code'], return_tensors="pt").to(model.device)

    # Baseline with the stock generate() API, left disabled for comparison:
    # output = model.generate(encoded_inputs, max_new_tokens=50, use_cache=True)
    # print(tokenizer.batch_decode(output))

    # max_tokens is the prompt length plus 50; print_stream=True asks for
    # per-step output.
    output = efficient_generate(
        model,
        tokenizer,
        encoded_inputs,
        max_tokens=encoded_inputs.shape[-1] + 50,
        print_stream=True
    )
    # Show whatever efficient_generate returned for this row.
    print(output)

Selected tokens:  tensor([[ 6422,   185,   185,   334,    66,     8,   207,    17,    15,    16,
            17,    11,  6580,  6971,    13,   685,    11,  2412,    13,   185,
          2418,  6724, 14663,    13,   185, 13244,    35,    55,    12, 31841,
            12, 21171,    25,   380,  6593,    12,    18,    12,  1982,  1029,
           185,  1487,  2192, 11732,  2422,    11,  1016,   254,   412,  2530,
         22451,  1753, 13063,    38,   262,  1753,   279,   254, 31124,  4330,
           409,  6486,  1615, 25510,  1079,    13,  2156,    14,   807,  7814,
            14,    33,  6593,    12,    18,    12,  1982,  1029,   185, 23984,
           185,   185,  1892,  1892,  1892]])
Mask:  tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True, False,
        False, False, False, False, False, False,  True, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False, False,  True, False, False, False, False, False,