In [1]:
import transformers
import torch
import time
import shutil
from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Load the model
ckpt = "gradientai/Llama-3-8B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",
).to("cuda")

generation_config = GenerationConfig.from_pretrained(ckpt)

# Build a fresh list: extending the list object held by generation_config (as `+=` would)
# silently mutates the config itself.
base_eos = generation_config.eos_token_id
if base_eos is None:
    eos_token_ids = []
elif isinstance(base_eos, (list, tuple)):
    eos_token_ids = list(base_eos)
else:
    eos_token_ids = [base_eos]

# Also stop once the model starts emitting closing tags such as "</user>" or "</s>".
# Only the first token of each marker is added: if a marker is split into several
# tokens, adding every piece would turn generic fragments into stop tokens.
for marker in ("</user>", "</s>", "</"):
    marker_ids = tokenizer.encode(marker, add_special_tokens=False)
    if marker_ids and marker_ids[0] not in eos_token_ids:
        eos_token_ids.append(marker_ids[0])

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:
from duo_attn.utils import load_attn_pattern, sparsify_attention_heads
from duo_attn.patch import enable_duo_attention_eval

# DuoAttention head pattern for this checkpoint: one value per (layer, KV head),
# together with the sink/recent window sizes stored alongside it.
attn_pattern_dir = (
    "attn_patterns/Llama-3-8B-Instruct-Gradient-1048k/lr=0.02-reg=0.05-ctx=1000_32000-multi_passkey10"
)
attn_heads, sink_size, recent_size = load_attn_pattern(attn_pattern_dir)
for summary in (attn_heads.shape, sink_size, recent_size):
    print(summary)

# Binarize the pattern at the requested sparsity (printed output shows 0/1 entries;
# presumably 1 marks a full-attention head — confirm in duo_attn.utils).
attn_heads, sparsity = sparsify_attention_heads(attn_heads, sparsity=0.5)
print(attn_heads, sparsity)


(32, 8)
128
256
[[0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 1. 0.]
 [0. 0. 0. 0. 0. 1. 1. 0.]
 [0. 0. 0. 0. 1. 1. 0. 0.]
 [0. 1. 1. 0. 1. 0. 0. 1.]
 [0. 0. 0. 0. 0. 1. 1. 0.]
 [1. 0. 0. 1. 0. 1. 1. 0.]
 [1. 1. 1. 0. 1. 0. 1. 1.]
 [0. 0. 0. 1. 0. 1. 1. 1.]
 [1. 0. 0. 1. 1. 0. 1. 1.]
 [1. 1. 0. 0. 0. 0. 0. 1.]
 [1. 0. 0. 0. 0. 1. 0. 0.]
 [1. 1. 1. 0. 1. 0. 1. 1.]
 [0. 1. 0. 0. 1. 1. 1. 1.]
 [1. 0. 1. 0. 1. 0. 1. 1.]
 [1. 1. 1. 0. 0. 1. 1. 0.]
 [1. 1. 0. 1. 0. 1. 1. 1.]
 [0. 1. 0. 1. 0. 1. 0. 0.]
 [1. 0. 1. 1. 0. 1. 0. 1.]
 [1. 0. 1. 1. 1. 1. 1. 0.]
 [1. 0. 1. 1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 1. 0. 0. 1.]
 [1. 1. 0. 1. 0. 1. 1. 1.]
 [1. 0. 0. 0. 1. 1. 1. 0.]
 [1. 1. 0. 1. 0. 1. 0. 0.]
 [1. 0. 0. 1. 0. 0. 0. 1.]
 [0. 1. 0. 0. 0. 1. 1. 1.]
 [1. 1. 0. 1. 1. 1. 1. 0.]
 [0. 0. 1. 0. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 0. 0. 0. 0. 1. 0. 1.]] 0.5


In [3]:
from transformers.models.llama.modeling_llama import (
    logger,
    apply_rotary_pos_emb,
    repeat_kv,
    LlamaSdpaAttention,
    LlamaFlashAttention2,
)
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.cache_utils import Cache, DynamicCache, StaticCache, OffloadedCache, OffloadedStaticCache
import types
from transformers.modeling_flash_attention_utils  import _flash_attention_forward

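# NOTE: superseded draft of a per-layer flash-attention forward, kept for reference only.
# The active implementation is the per-KV-head version in the next cell. This draft is not
# runnable as written: `target_dtype` is used below but never defined.
# def LlamaAttention_fast_forward(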
#     self,
#     hidden_states: torch.Tensor,
#     attention_mask: Optional[torch.Tensor] = None,
#     position_ids: Optional[torch.LongTensor] = None,
#     past_key_value: Optional[Cache] = None,
#     output_attentions: bool = False,
#     use_cache: bool = False,
#     cache_position: Optional[torch.LongTensor] = None,
#     position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
#     **kwargs,
# ) :
#     # LlamaFlashAttention2 attention does not support output_attentions
#     output_attentions = False

#     bsz, q_len, _ = hidden_states.size()

#     query_states = self.q_proj(hidden_states)
#     key_states = self.k_proj(hidden_states)
#     value_states = self.v_proj(hidden_states)

#     # Flash attention requires the input to have the shape
#     # batch_size x seq_length x head_dim x hidden_dim
#     # therefore we just need to keep the original shape
#     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
#     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
#     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

#     cos, sin = self.rotary_emb(value_states, position_ids)
    
#     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

#     if past_key_value is not None:
#         # sin and cos are specific to RoPE models; cache_position needed for the static cache
#         cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
#         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

#     # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
#     # to be able to avoid many of these transpose/reshape/view.
#     query_states = query_states.transpose(1, 2)
#     key_states = key_states.transpose(1, 2)
#     value_states = value_states.transpose(1, 2)

#     dropout_rate = self.attention_dropout if self.training else 0.0

#     # In PEFT, usually we cast the layer norms in float32 for training stability reasons
#     # therefore the input hidden states gets silently casted in float32. Hence, we need
#     # cast them back in the correct dtype just to be sure everything works as expected.
#     # This might slowdown training & inference so it is recommended to not cast the LayerNorms
#     # in fp32. (LlamaRMSNorm handles it correctly)
#     input_dtype = query_states.dtype
#     if input_dtype == torch.float32:
#         query_states = query_states.to(target_dtype)
#         key_states = key_states.to(target_dtype)
#         value_states = value_states.to(target_dtype)

#     attn_output = _flash_attention_forward(
#         query_states,
#         key_states,
#         value_states,
#         attention_mask,
#         q_len,
#         position_ids=position_ids,
#         dropout=dropout_rate,
#         sliding_window=getattr(self, "sliding_window", None),
#         use_top_left_mask=self._flash_attn_uses_top_left_mask,
#         is_causal=self.is_causal,
#     )

#     attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
#     attn_output = self.o_proj(attn_output)

#     if not output_attentions:
#         attn_weights = None

#     return attn_output, attn_weights, past_key_value

# def enable_duo_attention_eval(model, full_attention_heads):
    
#     device = next(model.parameters()).device
#     dtype = next(model.parameters()).dtype
#     for idx, layer in enumerate(model.model.layers):
#         module = layer.self_attn
#         layer_full_attention_heads = torch.tensor(
#             full_attention_heads[idx], device=device, dtype=dtype
#         )
#         module.full_attn_head_mask = layer_full_attention_heads > 0.5
#         module.num_full_attn_head = module.full_attn_head_mask.sum().item()
#         module.num_streaming_attn_head = (module.num_key_value_heads - module.num_full_attn_head)
#         print(layer_full_attention_heads)
#         module.forward = types.MethodType(
#             LlamaAttention_fast_forward, module
#         )

# enable_duo_attention_eval(model, attn_heads)

In [4]:

def _split_proj_bias(proj, rows_per_head, num_kv_heads):
    """Split ``proj.bias`` into per-KV-head chunks, or return ``None`` placeholders when ``proj`` has no bias."""
    if proj.bias is None:
        return [None] * num_kv_heads
    return list(proj.bias.split(rows_per_head, dim=0))


def LlamaAttention_fast_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) :
    """Flash-attention forward that processes each KV head (with its query group) separately.

    Intended to replace ``LlamaFlashAttention2.forward``. For KV head ``i`` it:
      * projects ``hidden_states`` with the matching rows of ``q_proj``/``k_proj``/``v_proj``,
      * applies RoPE,
      * updates cache slot ``self.layer_idx + i``, passing ``self.full_attn_head_mask[i]`` as
        ``cache_kwargs["full_head"]`` so the cache can treat full and streaming heads differently,
      * runs flash attention.
    Per-head outputs are concatenated in head order and passed through ``o_proj``.

    Requires ``self.full_attn_head_mask`` (indexable by KV head). ``self.layer_idx`` must leave
    ``num_key_value_heads`` consecutive cache slots free for this layer.

    Returns:
        ``(attn_output, None, past_key_value)``; attention weights are never computed.
    """
    # Flash attention does not produce attention weights.
    output_attentions = False

    bsz, q_len, _ = hidden_states.size()
    # Query heads sharing one KV head (GQA group size).
    num_heads = self.num_heads // self.num_key_value_heads
    q_rows = self.head_dim * num_heads

    # Per-KV-head views of the projections, built once per module and applied with F.linear.
    # Swapping `weight.data` on the nn.Linear modules instead would leave each projection
    # holding only the last head's slice after a forward and would mismatch a full-size bias.
    if not hasattr(self, 'q_proj_list'):
        self.q_proj_list = list(self.q_proj.weight.split(q_rows, dim=0))
    if not hasattr(self, 'k_proj_list'):
        self.k_proj_list = list(self.k_proj.weight.split(self.head_dim, dim=0))
    if not hasattr(self, 'v_proj_list'):
        self.v_proj_list = list(self.v_proj.weight.split(self.head_dim, dim=0))
    if not hasattr(self, 'qkv_bias_lists'):
        self.qkv_bias_lists = (
            _split_proj_bias(self.q_proj, q_rows, self.num_key_value_heads),
            _split_proj_bias(self.k_proj, self.head_dim, self.num_key_value_heads),
            _split_proj_bias(self.v_proj, self.head_dim, self.num_key_value_heads),
        )
    q_bias_list, k_bias_list, v_bias_list = self.qkv_bias_lists

    dropout_rate = self.attention_dropout if self.training else 0.0
    attn_output_list = []

    for i in range(self.num_key_value_heads):
        query_states = torch.nn.functional.linear(hidden_states, self.q_proj_list[i], q_bias_list[i])
        key_states = torch.nn.functional.linear(hidden_states, self.k_proj_list[i], k_bias_list[i])
        value_states = torch.nn.functional.linear(hidden_states, self.v_proj_list[i], v_bias_list[i])

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim) for RoPE and the cache.
        query_states = query_states.view(bsz, q_len, num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, 1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, 1, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache.
            # 'full_head' tells the cache whether this KV head keeps full history or streams.
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, 'full_head': self.full_attn_head_mask[i]}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx + i, cache_kwargs)

        # Flash attention expects [batch_size, sequence_length, num_heads, head_dim].
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # In PEFT, layer norms are often kept in float32, which silently upcasts the hidden
        # states; cast q/k/v back to the compute dtype that flash attention requires.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output_list.append(attn_output.reshape(bsz, q_len, -1).contiguous())

    # Head groups were produced in the same order as the q_proj rows, matching o_proj's input layout.
    attn_output = torch.cat(attn_output_list, dim=-1)
    attn_output = self.o_proj(attn_output)

    return attn_output, None, past_key_value

full_attention_heads = attn_heads

# Route every Llama flash-attention layer through the per-KV-head forward defined above.
# This patches the class, so it affects every LlamaFlashAttention2 instance in the process.
LlamaFlashAttention2.forward = LlamaAttention_fast_forward

num_model_layers = len(model.model.layers)
assert len(full_attention_heads) == num_model_layers, (
    f"attention pattern has {len(full_attention_heads)} layers but the model has {num_model_layers}"
)

# Each layer reserves num_key_value_heads consecutive cache slots: KV head i of a layer
# uses slot `module.layer_idx + i` in LlamaAttention_fast_forward.
layer_idx = 0
for idx, layer in enumerate(model.model.layers):
    module = layer.self_attn
    assert isinstance(module, LlamaFlashAttention2), (
        f"layer {idx} uses {type(module).__name__}; the patch only applies to LlamaFlashAttention2"
    )
    assert len(full_attention_heads[idx]) == module.num_key_value_heads, (
        f"layer {idx}: pattern has {len(full_attention_heads[idx])} heads, "
        f"module has {module.num_key_value_heads} KV heads"
    )
    module.layer_idx = layer_idx
    layer_idx += module.num_key_value_heads

    # Heads whose pattern value exceeds 0.5 are treated as full-attention heads.
    module.full_attn_head_mask = full_attention_heads[idx] > 0.5
    module.num_full_attn_head = module.full_attn_head_mask.sum().item()
    module.num_streaming_attn_head = module.num_key_value_heads - module.num_full_attn_head
    print(module.full_attn_head_mask)


[False False  True False False False False False]
[False False False  True False False False False]
[False False False False False  True  True False]
[False False False False False  True  True False]
[False False False False  True  True False False]
[False  True  True False  True False False  True]
[False False False False False  True  True False]
[ True False False  True False  True  True False]
[ True  True  True False  True False  True  True]
[False False False  True False  True  True  True]
[ True False False  True  True False  True  True]
[ True  True False False False False False  True]
[ True False False False False  True False False]
[ True  True  True False  True False  True  True]
[False  True False False  True  True  True  True]
[ True False  True False  True False  True  True]
[ True  True  True False False  True  True False]
[ True  True False  True False  True  True  True]
[False  True False  True False  True False False]
[ True False  True  True False  True False  True]


In [5]:

class OffloadedCache(DynamicCache):
    """
    Hybrid KV cache for DuoAttention-style decoding, with one cache slot per KV head.

    Every ``update`` call carries ``cache_kwargs["full_head"]``, which says whether that KV head
    uses full attention:

    * Streaming heads (``full_head`` false) live in ``key_cache`` / ``value_cache``.
      - They stay on their original device.
      - After each update they are pruned to the first ``sink_size`` plus the last
        ``recent_size`` positions.
      - The tensors returned for the current step are taken before pruning.
    * Full-attention heads keep their whole history in ``offload_key_cache`` /
      ``offload_value_cache``.
      - The previous offloaded entry is moved to the CPU when another one is accessed or added.
      - The next entry is prefetched back on a side CUDA stream.

    Bookkeeping:

    * ``real_id_dict`` maps the global slot index passed to ``update`` to the index in the
      matching list.
    * ``id_type_list[slot]`` is True for streaming slots.
    * Slots must first be seen in increasing order starting at 0.

    Note: this class shadows ``transformers.cache_utils.OffloadedCache`` imported earlier.

    Example:

        ```python
        >>> cache = OffloadedCache(sink_size=64, recent_size=256)
        >>> outputs = model(input_ids=input_ids, past_key_values=cache, use_cache=True)
        ```
    """

    def __init__(
        self,
        num_hidden_layers: Optional[int] = None,
        sink_size: Optional[int] = None,
        recent_size: Optional[int] = None,
    ) -> None:
        if not torch.cuda.is_available():
            raise RuntimeError("OffloadedCache can only be used with a GPU")
        super().__init__()
        # Streaming-head window sizes; they may also be assigned after construction.
        self.sink_size = sink_size
        self.recent_size = recent_size

        self.original_device = []  # device of each offloaded entry, indexed like offload_key_cache
        self.prefetch_stream = torch.cuda.Stream()
        self.beam_idx = None  # used to delay beam search operations

        # Streaming heads: kept on device, pruned to sink + recent tokens.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

        # Full-attention heads: complete history, moved between CPU and device.
        self.offload_key_cache: List[torch.Tensor] = []
        self.offload_value_cache: List[torch.Tensor] = []

        # id_type_list[slot] is True for streaming slots; real_id_dict maps a global slot to
        # its index in key_cache (streaming) or offload_key_cache (full attention).
        self.id_type_list = []
        self.real_id_dict = {}

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the number of tokens processed so far.

        Streaming entries in ``key_cache`` are truncated to ``sink_size + recent_size``, so a
        length read from ``key_cache[0]`` (the inherited DynamicCache behaviour — confirm against
        the installed transformers version) would under-report the context once it exceeds that
        window. ``_seen_tokens`` counts every token that passed through slot 0 instead.
        """
        return self._seen_tokens

    def prefetch_layer(self, layer_idx: int):
        "Starts copying offloaded entry `layer_idx` back to its original device on the prefetch stream."
        if layer_idx < len(self.offload_key_cache):
            with torch.cuda.stream(self.prefetch_stream):
                device = self.original_device[layer_idx]
                self.offload_key_cache[layer_idx] = self.offload_key_cache[layer_idx].to(device, non_blocking=True)
                self.offload_value_cache[layer_idx] = self.offload_value_cache[layer_idx].to(device, non_blocking=True)

    def evict_previous_layer(self, layer_idx: int):
        "Moves the offloaded entry before `layer_idx` (cyclically) to the CPU."
        if len(self.offload_key_cache) > 2:
            # Done on the default stream so it happens after earlier computations on these tensors.
            prev_layer_idx = (layer_idx - 1) % len(self.offload_key_cache)
            self.offload_key_cache[prev_layer_idx] = self.offload_key_cache[prev_layer_idx].to("cpu", non_blocking=True)
            self.offload_value_cache[prev_layer_idx] = self.offload_value_cache[prev_layer_idx].to("cpu", non_blocking=True)

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Return ``(key, value)`` for global slot ``layer_idx``.

        Streaming slots are returned directly. For a full-attention slot:
          * the previous offloaded entry is evicted to the CPU,
          * this entry's pending prefetch is awaited,
          * any delayed beam reordering is applied,
          * the next offloaded entry is prefetched.
        """
        if layer_idx >= len(self.id_type_list):
            raise KeyError(
                f"Cache only has {len(self.id_type_list)} layers, attempted to access layer with index {layer_idx}"
            )
        real_idx = self.real_id_dict[layer_idx]
        if self.id_type_list[layer_idx]:
            return (self.key_cache[real_idx], self.value_cache[real_idx])

        # Evict the previous offloaded entry if necessary.
        torch.cuda.current_stream().synchronize()
        self.evict_previous_layer(real_idx)
        # Wait for the prefetch that should have brought this entry back to its device.
        original_device = self.original_device[real_idx]
        self.prefetch_stream.synchronize()
        key_tensor = self.offload_key_cache[real_idx]
        value_tensor = self.offload_value_cache[real_idx]
        # Now deal with beam search ops which were delayed.
        if self.beam_idx is not None:
            self.beam_idx = self.beam_idx.to(original_device)
            key_tensor = key_tensor.index_select(0, self.beam_idx)
            value_tensor = value_tensor.index_select(0, self.beam_idx)
        # Prefetch the next offloaded entry (modulo the offloaded count, not the streaming count).
        self.prefetch_layer((real_idx + 1) % len(self.offload_key_cache))
        return (key_tensor, value_tensor)

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Saves the beam indices; they are applied to offloaded entries when those are read back on device."""
        # Delayed because torch.index_select on the CPU is very slow.
        del self.beam_idx
        self.beam_idx = beam_idx.clone()

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for global slot `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                Global slot index (one per KV head). Slots must first appear in increasing order.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                `full_head` selects full attention (offloaded) vs. streaming. A missing value is
                treated as a full-attention head.

        Return:
            A tuple containing the key and value states to attend over for this step.
        """
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        full_head = True if cache_kwargs is None else cache_kwargs.get('full_head', True)
        is_stream_head = not bool(full_head)
        if len(self.id_type_list) <= layer_idx:
            self.id_type_list.append(is_stream_head)

        if is_stream_head:
            return self._update_streaming(layer_idx, key_states, value_states)
        return self._update_offloaded(layer_idx, key_states, value_states)

    def _update_streaming(self, slot, key_states, value_states):
        """Append to a streaming head's on-device cache, then prune it to sink + recent tokens."""
        if self.sink_size is None or self.recent_size is None:
            raise ValueError("sink_size and recent_size must be set before caching streaming heads")

        if slot not in self.real_id_dict:
            self.real_id_dict[slot] = len(self.key_cache)
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            idx = self.real_id_dict[slot]
            self.key_cache[idx] = torch.cat([self.key_cache[idx], key_states], dim=-2)
            self.value_cache[idx] = torch.cat([self.value_cache[idx], value_states], dim=-2)

        idx = self.real_id_dict[slot]
        # The current step attends over the un-pruned tensors; pruning only affects later steps.
        key_tensor, value_tensor = self.key_cache[idx], self.value_cache[idx]
        seq_len = key_tensor.shape[2]
        if seq_len > self.sink_size + self.recent_size:
            recent_start = seq_len - self.recent_size
            self.key_cache[idx] = torch.cat(
                [key_tensor[:, :, : self.sink_size, :], key_tensor[:, :, recent_start:seq_len, :]], dim=-2
            )
            self.value_cache[idx] = torch.cat(
                [value_tensor[:, :, : self.sink_size, :], value_tensor[:, :, recent_start:seq_len, :]], dim=-2
            )
        return key_tensor, value_tensor

    def _update_offloaded(self, slot, key_states, value_states):
        """Append to a full-attention head's offloaded cache and return its complete history."""
        if slot not in self.real_id_dict:
            idx = self.real_id_dict[slot] = len(self.offload_key_cache)
            self.offload_key_cache.append(key_states)
            self.offload_value_cache.append(value_states)
            self.original_device.append(key_states.device)
            self.evict_previous_layer(idx)
            return key_states, value_states

        idx = self.real_id_dict[slot]
        # __getitem__ returns this entry on its device and handles eviction / prefetch.
        key_tensor, value_tensor = self[slot]
        self.offload_key_cache[idx] = torch.cat([key_tensor, key_states], dim=-2)
        self.offload_value_cache[idx] = torch.cat([value_tensor, value_states], dim=-2)
        return self.offload_key_cache[idx], self.offload_value_cache[idx]
        

In [6]:

# Hybrid cache: streaming heads keep sink_size + recent_size tokens. Use the window sizes
# stored with the attention pattern loaded above rather than hard-coded values.
past_key_values = OffloadedCache()
past_key_values.sink_size = sink_size
past_key_values.recent_size = recent_size

# Manual greedy decoding that reuses the KV cache between steps.
prompt = "Fun fact: The shortest"
max_new_tokens = 23
stop_token_ids = set(eos_token_ids)  # EOS / closing-tag ids collected in the first cell

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_ids = inputs["input_ids"]
generated_tokens = []

with torch.no_grad():
    for _ in range(max_new_tokens):
        outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values  # KV cache to be reused in the next step

        # Greedy decoding on the logits of the last position.
        next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)
        token_id = next_token.item()
        if token_id in stop_token_ids:
            break
        generated_tokens.append(token_id)

        # Only the new token is fed next; the cache supplies the context.
        input_ids = next_token.unsqueeze(-1)

output_text = prompt + tokenizer.decode(generated_tokens, skip_special_tokens=True)
print(output_text)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Fun fact: The shortest war in history was between Zanzibar and Great Britain on August 27, 1896. Zanzibar
