In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import _crop_past_key_values
import torch

In [2]:
# Load the model and evaluate on LongBench
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

model_name = "deepseek-ai/deepseek-coder-1.3b-base"
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)


ds = load_dataset("tianyang/repobench_python_v1.1", split="cross_file_first")
ds = ds.select(list(range(1)))

Unrecognized keys in `rope_scaling` for 'rope_type'='linear': {'type'}


In [3]:
from transformers.cache_utils import DynamicCache

def get_key_values_with_indices(model, past_key_values, indices):
    new_past = []
    if isinstance(past_key_values, DynamicCache):
        past_key_values.batch_select_indices(indices)
    elif past_key_values is not None:
        for idx in range(len(past_key_values)):
            new_past.append(
                (
                    past_key_values[idx][0][:, :, indices, :],
                    past_key_values[idx][1][:, :, indices, :],
                )
            )
        past_key_values = tuple(new_past)
    return past_key_values

In [4]:
def get_punctuation_tokens(tokenizer):
    punctuation_list = ['.', ',', ':', "\n", "\t", "!"]
    punctuation_tokens = [idx for (idx, token) in enumerate(tokenizer.get_vocab().items()) if any(punctuation in token for punctuation in punctuation_list)]
    return punctuation_tokens

In [7]:
def efficient_generate(model, tokenizer, input_ids, max_tokens=200, sink_length=8, end_length=64, window_stride=8, print_stream=False):
    # prefill stage: load all of the inputs (since prefilling is fast, this shouldn't take too much time)
    current_tokens = input_ids[0]
    outputs = model(
        input_ids,
        use_cache=True,
    )
    # print(outputs)
    
    # TODO: attach the generated token to the end of the current tokens
    past_key_values = outputs.past_key_values

    # build the mask
    mask = torch.zeros(current_tokens.shape[-1], dtype=torch.bool)
    mask[:sink_length] = 1
    mask[-end_length:] = 1
    mask[::window_stride] = 1

    punctuation_tokens = get_punctuation_tokens(tokenizer)
    punctuation_token_indices = [index for (index, token_id) in enumerate(current_tokens.tolist()) if token_id in punctuation_tokens]
    mask[punctuation_token_indices] = 1
    
    # save the relevant ones
    # print(mask)
    context_positions = torch.arange(len(current_tokens))[mask]
    context_tokens = current_tokens[mask]
    # print("context position: ", context_positions, "mask shape: ", mask.shape)
    # print("past key values: ", past_key_values[0][0].shape, past_key_values[0][1].shape)
    # past_key_values = key_values_indices(model, past_key_values, context_positions)
    past_key_values = get_key_values_with_indices(model, past_key_values, context_positions)

    new_logits = outputs.logits[:, -1:]
    selected_tokens = new_logits.argmax(dim=-1)[0]
    # print(current_tokens, selected_tokens)
    current_tokens = torch.cat((current_tokens, selected_tokens), dim=-1)

    # generate subsequent tokens
    while True:
        # check if the length is long enough, and if so, confirm whether or not the last token of the sliding window left will be staying on
            # if so, keep the KV cache the same
            # otherwise, prune the KV cache


        print("Shape before trim: ", "tokens: ", context_tokens.shape, " positions: ", context_positions.shape)
        if len(context_tokens) >= end_length:
            if not(
                (len(context_tokens) - end_length - 1) % window_stride == 0
                or
                (context_tokens[-(end_length + 1)] in punctuation_token_indices)
                or
                (len(context_tokens) - end_length - 1) < sink_length
            ):
                del_idx = len(context_tokens) - end_length - 1
                indices_without_del = torch.cat((torch.arange(0, del_idx), torch.arange(del_idx + 1, len(context_tokens))), dim=-1)
                past_key_values = get_key_values_with_indices(model, past_key_values, indices_without_del)
                context_tokens = torch.cat((context_tokens[:-del_idx], context_tokens[-(del_idx - 1):]))
                context_positions = torch.cat((context_positions[:-del_idx], context_positions[-(del_idx - 1):]))
        print("Shape after trim: ", "tokens: ", context_tokens.shape, " positions: ", context_positions.shape)
        
        # print(context_positions)

        print("Context tokens vs context positions shape: ", context_tokens.unsqueeze(0).shape, context_positions.unsqueeze(0).shape)
        
        outputs = model(
            context_tokens.unsqueeze(0),
            past_key_values=past_key_values,
            position_ids=context_positions.unsqueeze(0)
        )
            
        new_logits = outputs.logits[:, -1:][0]
        selected_tokens = new_logits.argmax(dim=-1)

        print("Selected tokens: ", selected_tokens)
        
        current_tokens = torch.cat((current_tokens, selected_tokens), dim=-1)
        context_tokens = torch.cat((context_tokens, selected_tokens), dim=-1)
        context_positions = torch.cat((context_positions, torch.Tensor([current_tokens.shape[-1]])), dim=-1)

        if print_stream:
            print(tokenizer.decode(current_tokens[-1]), end="")
        
        past_key_values = outputs.past_key_values
        
        # calculate the next tokens  
        if current_tokens.shape[-1] >= max_tokens:
            break
        if current_tokens[-1] == tokenizer.eos_token:
            break

    return tokenizer.decode(current_tokens)

In [8]:
for row in ds:
    encoded_inputs = tokenizer.encode(row['all_code'], return_tensors="pt").to(model.device)
    output = efficient_generate(
        model,
        tokenizer,
        encoded_inputs,
        max_tokens=encoded_inputs.shape[-1] + 50,
        print_stream=True
    )
    print(output)

Shape before trim:  tokens:  torch.Size([74])  positions:  torch.Size([74])
Shape after trim:  tokens:  torch.Size([73])  positions:  torch.Size([73])
Context tokens vs context positions shape:  torch.Size([1, 73]) torch.Size([1, 73])
Selected tokens:  tensor([185])

Shape before trim:  tokens:  torch.Size([74])  positions:  torch.Size([74])
Shape after trim:  tokens:  torch.Size([73])  positions:  torch.Size([73])
Context tokens vs context positions shape:  torch.Size([1, 73]) torch.Size([1, 73])
Selected tokens:  tensor([185])

Shape before trim:  tokens:  torch.Size([74])  positions:  torch.Size([74])
Shape after trim:  tokens:  torch.Size([73])  positions:  torch.Size([73])
Context tokens vs context positions shape:  torch.Size([1, 73]) torch.Size([1, 73])
Selected tokens:  tensor([185])

Shape before trim:  tokens:  torch.Size([74])  positions:  torch.Size([74])
Shape after trim:  tokens:  torch.Size([73])  positions:  torch.Size([73])
Context tokens vs context positions shape:  t

KeyboardInterrupt: 