In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import DynamicCache
from transformers.generation import _crop_past_key_values
import torch

In [None]:
def _model_arch_matches(model, name):
    """Return True if ``name`` appears in the model's class name or its first configured architecture."""
    if name in model.__class__.__name__.lower():
        return True
    architectures = getattr(model.config, "architectures", None)
    # An empty architectures list must not raise IndexError.
    return bool(architectures) and name in architectures[0].lower()


def key_values_indices(model, past_key_values, indices: torch.Tensor):
    """Keep only the cached key/value entries at the given sequence positions.

    Gathers an arbitrary (possibly non-contiguous) set of positions along the
    sequence axis of every layer's cache, so individual tokens can be dropped
    from the middle of the context.

    Args:
        model: The model that produced ``past_key_values``; its class name and
            ``config`` select which cache layout is assumed.
        past_key_values: A ``DynamicCache``, a tuple of per-layer tuples, a
            per-layer sequence of single tensors (GPTBigCode), or ``None``.
        indices: Sequence positions to keep, in the order they should appear.
            A 1-D integer tensor, a 0-d tensor, a Python list, a float tensor
            (values are truncated to integers) or a boolean mask over the
            sequence are all accepted.

    Returns:
        A cache of the same kind containing only the selected positions, or
        ``None`` when ``past_key_values`` is ``None``. The input cache object
        is not modified.
    """
    if past_key_values is None:
        return None

    # Normalise to a flat int64 index tensor: a 0-d index would otherwise drop
    # the sequence axis entirely, and float tensors cannot be used as indices.
    indices = torch.as_tensor(indices)
    if indices.dtype == torch.bool:
        indices = indices.nonzero().flatten()
    else:
        indices = indices.to(torch.long).flatten()

    # Cache objects are checked first: they presumably store every layer as
    # (batch, heads, seq, head_dim) regardless of architecture, so the legacy
    # per-architecture layouts below do not apply to them. TODO confirm for the
    # installed transformers version (to/from_legacy_cache may be deprecated).
    if isinstance(past_key_values, DynamicCache):
        legacy = past_key_values.to_legacy_cache()
        return DynamicCache.from_legacy_cache(
            tuple((key[:, :, indices, :], value[:, :, indices, :]) for key, value in legacy)
        )

    if model.config.is_encoder_decoder:
        # Entries 0 and 1 are sliced on the sequence axis; entries 2 and 3
        # (presumably cross-attention over the encoder output) pass through.
        return tuple(
            (layer[0][:, :, indices, :], layer[1][:, :, indices, :], layer[2], layer[3])
            for layer in past_key_values
        )

    if _model_arch_matches(model, "bloom"):
        # Legacy BLOOM layout: keys are sliced on their last axis, values on axis 1.
        return tuple((layer[0][:, :, indices], layer[1][:, indices, :]) for layer in past_key_values)

    if _model_arch_matches(model, "gptbigcode"):
        # Each layer is one tensor; multi-query caches are sliced on axis 1,
        # the others on axis 2. A new list is built so the caller's container
        # is never mutated (and tuple inputs work too).
        if model.config.multi_query:
            return [layer[:, indices, :] for layer in past_key_values]
        return [layer[:, :, indices, :] for layer in past_key_values]

    return tuple(
        (layer[0][:, :, indices, :], layer[1][:, :, indices, :]) for layer in past_key_values
    )

In [None]:
_PUNCTUATION_CHARS = frozenset(".!?;:\n")


def _is_punctuation_token(tokenizer, token_id):
    """Return True if the decoded text of ``token_id`` contains punctuation or a newline."""
    # Decode instead of convert_ids_to_tokens: raw token strings may encode "\n"
    # as a placeholder symbol (e.g. in byte-level BPE vocabularies).
    text = tokenizer.decode([token_id])
    return any(char in text for char in _PUNCTUATION_CHARS)


def _keep_outside_window(position, is_punctuation, sink_length, window_stride):
    """Return True if the token at absolute ``position`` stays cached after leaving the recent window."""
    return (
        position < sink_length
        or (window_stride > 0 and position % window_stride == 0)
        or is_punctuation
    )


def efficient_generate(model, tokenizer, input_ids, max_tokens=200, sink_length=8, end_length=64, window_stride=8):
    """Greedily generate a continuation while keeping only a sparse KV cache.

    The prompt is prefilled in a single forward pass. Afterwards, the KV cache
    keeps a token only if it is
      * one of the first ``sink_length`` tokens,
      * among the most recent ``end_length`` tokens (sliding window),
      * at an absolute position that is a multiple of ``window_stride``, or
      * a token whose decoded text contains punctuation or a newline.
    Whenever a token slides out of the recent window it is evicted from the
    cache unless one of the other rules keeps it.

    Args:
        model: Causal LM whose outputs expose ``logits`` and ``past_key_values``.
        tokenizer: Matching tokenizer (punctuation detection, EOS id, decoding).
        input_ids: Prompt token ids, shape ``(seq,)`` or ``(1, seq)``.
        max_tokens: Stop once prompt + generated tokens reach this many tokens.
        sink_length: Number of leading tokens that are always kept.
        end_length: Size of the recent window that is always kept.
        window_stride: Keep every ``window_stride``-th absolute position outside
            the window; ``<= 0`` disables this rule.

    Returns:
        The decoded prompt followed by the generated continuation. Generation
        also stops right after the tokenizer's EOS token is produced.

    Raises:
        ValueError: If ``input_ids`` is empty or its batch size is not 1.
    """
    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)
    if input_ids.shape[0] != 1:
        raise ValueError("efficient_generate only supports a batch size of 1")
    if input_ids.shape[-1] == 0:
        raise ValueError("input_ids must contain at least one token")
    end_length = max(end_length, 0)

    # Every token seen so far (prompt + generated), indexed by absolute position.
    token_ids = input_ids[0].tolist()
    punctuation_flags = [_is_punctuation_token(tokenizer, tok) for tok in token_ids]
    if len(token_ids) >= max_tokens:
        return tokenizer.decode(token_ids)

    with torch.no_grad():
        # Prefill: one forward pass over the whole prompt builds the full cache.
        outputs = model(input_ids, use_cache=True)

        prompt_length = len(token_ids)
        # Absolute positions whose KV entries are still cached, in cache-slot order.
        cached_positions = [
            pos
            for pos in range(prompt_length)
            if pos >= prompt_length - end_length
            or _keep_outside_window(pos, punctuation_flags[pos], sink_length, window_stride)
        ]
        # Straight after prefill, cache slot i holds position i, so the
        # positions double as slot indices.
        past_key_values = key_values_indices(
            model, outputs.past_key_values, torch.tensor(cached_positions, dtype=torch.long)
        )

        while True:
            # Greedy decoding from the logits of the most recently fed token.
            next_token_id = int(outputs.logits[0, -1].argmax())
            token_ids.append(next_token_id)
            punctuation_flags.append(_is_punctuation_token(tokenizer, next_token_id))
            if next_token_id == tokenizer.eos_token_id or len(token_ids) >= max_tokens:
                break

            # Feed only the new token; the remaining context lives in the cache.
            # NOTE(review): no position_ids are passed, so the model presumably
            # derives positions from the pruned cache length rather than from
            # absolute positions -- confirm this is the intended behaviour.
            outputs = model(
                torch.tensor([[next_token_id]], device=input_ids.device),
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values
            cached_positions.append(len(token_ids) - 1)

            # The token that just slid out of the recent window is evicted
            # unless it is a sink, stride or punctuation token.
            leaving = len(token_ids) - end_length - 1
            if (
                leaving >= 0
                and leaving in cached_positions
                and not _keep_outside_window(leaving, punctuation_flags[leaving], sink_length, window_stride)
            ):
                slot = cached_positions.index(leaving)
                keep_slots = [i for i in range(len(cached_positions)) if i != slot]
                del cached_positions[slot]
                past_key_values = key_values_indices(
                    model, past_key_values, torch.tensor(keep_slots, dtype=torch.long)
                )

    return tokenizer.decode(token_ids)

In [None]:
def generate_attention_mask():
    """Placeholder for building a custom attention mask; not implemented yet.

    Currently performs no work and returns ``None``.
    """
    return None