
Issue with Progressive Generation Using inputs_embeds and past_key_values #35707

Closed

Superbooming opened this issue Jan 15, 2025 · 17 comments

@Superbooming

Superbooming commented Jan 15, 2025

System Info

  • transformers version: 4.46.3
  • Platform: Linux-6.8.0-48-generic-x86_64-with-glibc2.17
  • Python version: 3.8.20
  • Huggingface_hub version: 0.26.1
  • Safetensors version: 0.4.5
  • Accelerate version: 1.0.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: yes
  • GPU type: NVIDIA RTX A6000

Who can help?

@gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I am currently rewriting the generate_progressively function for my custom model class. My goal is to enable the model to generate results progressively by concatenating the initial input_ids with each element of the compress_outputs sequence in turn. Specifically:

  1. In the first iteration, the model generates results by concatenating input_ids with the first element of compress_outputs.
  2. In the second iteration, it concatenates input_ids with the first and second elements of compress_outputs (the first two elements) to generate results.
  3. This process continues until the last element of the compress_outputs sequence is included.

To improve efficiency, I want to leverage caching, as the majority of the concatenated input in each iteration has already been used to compute past_key_values. Below is the code snippet for the function I implemented. In this context, self.model refers to mistral-7b-chat-v0.2.

# Module-level imports assumed by this method:
import torch
from transformers import DynamicCache

@torch.no_grad()
def generate_progressively(
        self,
        input_ids,
        attention_mask,
        compress_outputs,
        **kwargs,
):
    results = []
    compress_output_count = compress_outputs.size(1)
    batch_size = input_ids.size(0)

    # Embed the initial prompt and prefill the cache with a single forward pass.
    inputs_embs = self.base.model.embed_tokens(input_ids)
    prompt_cache = DynamicCache()
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        use_cache=True,
        past_key_values=prompt_cache,
    )
    prompt_cache = outputs.past_key_values

    for compress_ind in range(compress_output_count):
        # Feed the next compressed embedding through the model to extend the cache.
        current_compress_outputs = compress_outputs[:, compress_ind: compress_ind+1, :].type_as(input_ids)
        outputs = self.model(
            input_ids=None,
            inputs_embeds=current_compress_outputs,
            use_cache=True,
            past_key_values=prompt_cache,
        )
        prompt_cache = outputs.past_key_values

        inputs_embs = torch.cat([inputs_embs, current_compress_outputs], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones(batch_size, 1, device=input_ids.device)], dim=1)

        # Generate from the full embedded prefix while reusing the prefilled cache.
        generated_outputs = self.base.generate(
            inputs_embeds=inputs_embs,
            attention_mask=attention_mask,
            use_cache=True,
            past_key_values=prompt_cache,
            return_dict_in_generate=True,
            **kwargs,
        )
        results.append(generated_outputs.sequences)
    return results

When I execute this code, the program throws an error during execution. The error occurs at line 393 in transformers/generation/utils.py, specifically in the prepare_inputs_for_generation function.
The problematic line of code is:

if inputs_embeds is not None and cache_position[0] == 0:

The error message is: IndexError: index 0 is out of bounds for dimension 0 with size 0.

I traced the execution of the code, and here's a detailed breakdown of the issue:
The error occurs in transformers/generation/utils.py. Initially, the program enters the self._sample function and then proceeds to the self._get_initial_cache_position function.
Within this function, the following line:

if not is_torchdynamo_compiling():
    cache_position = cache_position[past_length:]

causes the cache_position slice to become empty, resulting in an IndexError in subsequent steps.
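
For illustration, here is a minimal standalone sketch (not the actual Transformers internals, just the slicing pattern described above, with made-up numbers) showing why the slice ends up empty once the cache already covers the full length of inputs_embeds:

import torch

# Hypothetical setup: the cache already holds the whole concatenated prefix,
# so past_length equals the number of positions computed for inputs_embeds.
past_length = 12                                  # tokens already stored in past_key_values
cache_position = torch.arange(12)                 # one position per inputs_embeds token
cache_position = cache_position[past_length:]     # -> tensor([], dtype=torch.int64)
print(cache_position.shape)                       # torch.Size([0])
# cache_position[0] now raises:
# IndexError: index 0 is out of bounds for dimension 0 with size 0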

Even if I manage to fix the issue with cache_position, another problem arises later in the self.prepare_inputs_for_generation function.
The relevant code is as follows:

if not self.config.is_encoder_decoder:
    if inputs_embeds is not None and cache_position[0] == 0:
        model_inputs[input_ids_key] = None
        model_inputs["inputs_embeds"] = inputs_embeds
    else:
        model_inputs[input_ids_key] = input_ids.clone(memory_format=torch.contiguous_format)
        model_inputs["inputs_embeds"] = None

In my case, I provide only inputs_embeds and past_key_values, and since cache_position[0] is not 0, the code attempts to set model_inputs[input_ids_key] using input_ids. However, since input_ids is None, this results in further issues.

Under the current implementation of the generate function in transformers, is it possible to use only inputs_embeds and past_key_values for generation? How can I modify my implementation to achieve progressive generation with caching as intended? Are there specific guidelines for correctly managing cache_position and ensuring compatibility with inputs_embeds?

Expected behavior

My primary objective is to progressively generate outputs by leveraging caching (past_key_values) to improve efficiency.

@zucchini-nlp
Member

Seems to be the same as #34678, and someone is working on it, as per the last comment.

@Superbooming
Author

Yes, it seems to be the same. I'll keep track of it. Thanks!

@haixuanTao

I get the same error: IndexError: index 0 is out of bounds for dimension 0 with size 0 with Qwen/Qwen2.5-VL-3B-Instruct, and #35890 currently does not solve this issue.

@zucchini-nlp
Member

@haixuanTao Yes, I believe that for Qwen2-VL and some other models the proposed change has to be applied in the overridden code as well.

@yaswanth19
Contributor

@haixuanTao Try it now - It should work!! 🤗

@haixuanTao

haixuanTao commented Feb 5, 2025

Hey thanks for the quick update!

Using your current branch, I get:

  File "/Users/xaviertao/Documents/work/dora/examples/vlm/.venv/bin/dora-qwen2-5-vl", line 8, in <module>
    sys.exit(main())
  File "/Users/xaviertao/Documents/work/dora/node-hub/dora-qwen2-5-vl/dora_qwen2_5_vl/main.py", line 233, in main
    response, history, past_key_values = generate(
  File "/Users/xaviertao/Documents/work/dora/node-hub/dora-qwen2-5-vl/dora_qwen2_5_vl/main.py", line 109, in generate
    outputs = model.generate(
  File "/Users/xaviertao/Documents/work/dora/examples/vlm/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/Users/xaviertao/Documents/work/dora/examples/vlm/.venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 2228, in generate
    result = self._sample(
  File "/Users/xaviertao/Documents/work/dora/examples/vlm/.venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 3202, in _sample
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  File "/Users/xaviertao/Documents/work/dora/examples/vlm/.venv/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 1882, in prepare_inputs_for_generation
    or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
IndexError: index -1 is out of bounds for dimension 0 with size 0

Maybe I'm missing something in my dev env?

@qinggangwu

(quoting @haixuanTao's comment and traceback above)

I encountered the same problem. The root cause is that the chat_templates of Qwen2-VL and Qwen2-VL-Instruct differ, which causes the input_ids to be empty before inference.
So the solution is to replace the chat_template.json of Qwen2-VL with the chat_template.json of the Instruct model. It solved my problem.

@haixuanTao

haixuanTao commented Feb 6, 2025

Agreed that this seems unrelated to the current issue, but I could not find how to change the chat_template.

I tried:

from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, StaticCache

DEFAULT_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"


model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    DEFAULT_PATH, torch_dtype="auto", device_map="auto"
)

frames = ["https://teachlikeachampion.org/wp-content/uploads/good2.jpg"]
processor = AutoProcessor.from_pretrained(DEFAULT_PATH)
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": image,
            }
            for image in frames
        ]
        + [
            {"type": "text", "text": "abc"},
        ],
    },
]
tmp_history = messages
# Preparation for inference
chat_template = "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
text = processor.apply_chat_template(
    tmp_history,
    tokenize=False,
    add_generation_prompt=True,
    chat_template=chat_template,
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)

past_key_values = None

outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    past_key_values=past_key_values,
    return_dict_in_generate=True,
)

past_key_values = outputs.past_key_values

## Error should happen below
outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    past_key_values=past_key_values,
    return_dict_in_generate=True,
)

## Error
#   File "/Users/xaviertao/Documents/work/dora/examples/vlm/.venv/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 1882, in prepare_inputs_for_generation
#    or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
# IndexError: index -1 is out of bounds for dimension 0 with size 0

Without much luck

@ArthurZucker
Collaborator

I think you need to pass cache_position to indicate where you are in the generation:

## Error should happen below
outputs = model.generate(
    **inputs,
    cache_position=torch.tensor([past_key_values.get_seq_length()], dtype=torch.long, device=model.device),
    max_new_tokens=128,
    past_key_values=past_key_values,
    return_dict_in_generate=True,
)

I admit that it's really not optimal 😿

@haixuanTao

haixuanTao commented Feb 10, 2025

I tried:

## Error should happen below
outputs = model.generate(
    **inputs,
    cache_position=torch.tensor([past_key_values.get_seq_length()], dtype=torch.long, device=model.device),
    max_new_tokens=128,
    past_key_values=past_key_values,
    return_dict_in_generate=True,
)

and it didn't work.

I tried replacing the cache_position within the problematic function and I got the following result:

So it seems that cache_position within prepare_inputs_for_generation is sometimes an empty tensor, which is different from torch.tensor([past_key_values.get_seq_length()], dtype=torch.long, device=model.device):

  • At the very beginning, it looks something like:
cache_position= tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19], device='mps:0')
cache_position_from_arthur= tensor([0], device='mps:0')
  • at the very end:
cache_position tensor([], device='mps:0', dtype=torch.int64)
cache_position_from_arthur tensor([46], device='mps:0')
cache_position tensor([], device='mps:0', dtype=torch.int64)
cache_position_from_arthur tensor([47], device='mps:0')
cache_position tensor([], device='mps:0', dtype=torch.int64) 
cache_position_from_arthur tensor([48], device='mps:0')

The empty tensor is what triggers the initial error.

def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    pixel_values=None,
    pixel_values_videos=None,
    image_grid_thw=None,
    video_grid_thw=None,
    second_per_grid_ts=None,
    **kwargs,
):
    # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

    # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
    # Exception 1: when passing input_embeds, input_ids may be missing entries
    # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
    # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
    #              (we can't check exception 3 while compiling)
    # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
    # generate the first token for each sequence. Later use the generated Input ids for continuation.
    cache_position_from_arthur = torch.tensor([past_key_values.get_seq_length()], dtype=torch.long, device=self.device)
    cache_position = cache_position if cache_position is not None and len(cache_position) > 0 else cache_position_from_arthur

Sadly, using your suggestion, Arthur, I couldn't figure out how to fix this issue, and with the code above I couldn't find a combination that worked for me.

@zucchini-nlp
Member

zucchini-nlp commented Feb 10, 2025

@haixuanTao the issue was fixed in the linked PR. The code you provided uses the same inputs for continuing generation, whereas the new input_ids should be the concatenation of the initial prompt and the generated text.

Adding this at the end will work:

outputs = model.generate(
    outputs.sequences, # This line uses the whole input sequence
    max_new_tokens=128,
    past_key_values=past_key_values,
    return_dict_in_generate=True,
)

@ArthurZucker
Collaborator

We might need to update our API to be a bit more friendly, no? 🤗

@zucchini-nlp
Member

@ArthurZucker yeah, we have it documented tbh. This question has been raised several times and stems from the fact that model.forward() and model.generate() expect different input shapes when caching (due to legacy reasons). With @gante we decided not to open that can of worms yet, but passing only the unprocessed tokens makes the most sense.
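
To make the convention concrete, here is a minimal sketch for a decoder-only text model (the checkpoint name is only an illustration): model.generate() receives the full sequence seen so far together with the cache and slices off the cached prefix internally, while model.forward() would receive only the new tokens.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-Instruct-v0.2"   # illustrative checkpoint
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

inputs = tok("The quick brown fox", return_tensors="pt").to(model.device)

# First call: returns both the generated sequences and the KV cache.
out = model.generate(**inputs, max_new_tokens=16, return_dict_in_generate=True)

# Continuation: pass the FULL sequence so far (prompt + generated text) together
# with the cache; do NOT pass only the new tokens as you would for model.forward().
out2 = model.generate(
    out.sequences,
    past_key_values=out.past_key_values,
    max_new_tokens=16,
    return_dict_in_generate=True,
)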

@haixuanTao

haixuanTao commented Feb 19, 2025

So, I got the inference to work using the code you provided, but the inference time gets longer each time I redo the same prompt with the same KV cache, when I would expect it to stay the same.

I haven't really dug into every nook and cranny of why, but here's my code:

import copy
import time

import torch
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, StaticCache

DEFAULT_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"


model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    DEFAULT_PATH, torch_dtype="auto", device_map="auto"
)

frames = ["https://teachlikeachampion.org/wp-content/uploads/good2.jpg"]
processor = AutoProcessor.from_pretrained(DEFAULT_PATH)
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": image,
            }
            for image in frames
        ]
        + [
            {"type": "text", "text": "abc"},
        ],
    },
]
tmp_history = messages
# Preparation for inference
text = processor.apply_chat_template(
    tmp_history,
    tokenize=False,
    add_generation_prompt=True,
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to(model.device)

past_key_values = None

now = time.time()
outputs = model.generate(
    **inputs,
    max_new_tokens=128,
    past_key_values=past_key_values,
    return_dict_in_generate=True,
)
init_outputs = outputs
sequence = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
print(sequence)
print(time.time() - now)

now = time.time()

past_key_values = outputs.past_key_values


def redo_inference():
    frames = [
        "https://www.hellyhansen.com/media/catalog/product/5/3/53851_787-1-onbody13.jpg?quality=90&bg-color=255,255,255&fit=bounds&height=&width="
    ]

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                }
                for image in frames
            ]
            + [
                {"type": "text", "text": "Is there a hoodie?"},
            ],
        },
    ]
    tmp_history = messages
    # Preparation for inference
    text = processor.apply_chat_template(
        tmp_history,
        tokenize=False,
        add_generation_prompt=True,
    )

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    kv = copy.deepcopy(past_key_values)
    tmp = copy.deepcopy(outputs.sequences)
    moutputs = model.generate(
        torch.concatenate(
            [tmp, inputs["input_ids"]], dim=1
        ),  # This line uses the whole input sequence
        max_new_tokens=128,
        past_key_values=kv,
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(moutputs.sequences, skip_special_tokens=True)[0]
    print(sequence)
    print(time.time() - now)


redo_inference()
redo_inference()
redo_inference()

Outputs:

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████| 2/2 [00:05<00:00,  2.85s/it]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
system
You are a helpful assistant.
user
abc
assistant
The image you provided is an emoji with a thumbs-up gesture, which typically represents approval or agreement. The emoji has a cheerful expression with wide eyes and a big smile. This type of emoji is often used to convey positive feedback or to indicate that something is good or acceptable.
6.706953048706055 # <---------------------- Inference time without KV cache
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
418 3051 36
system
You are a helpful assistant.
user
abc
assistant
The image you provided is an emoji with a thumbs-up gesture, which typically represents approval or agreement. The emoji has a cheerful expression with wide eyes and a big smile. This type of emoji is often used to convey positive feedback or to indicate that something is good or acceptable.system
You are a helpful assistant.
user
Is there a hoodie?
assistant
No, there is no hoodie in the image. The image shows a cartoon character giving a thumbs-up gesture.
8.745131969451904 # <---------------------- Inference time 1 time KV Cache
418 3051 36
system
You are a helpful assistant.
user
abc
assistant
The image you provided is an emoji with a thumbs-up gesture, which typically represents approval or agreement. The emoji has a cheerful expression with wide eyes and a big smile. This type of emoji is often used to convey positive feedback or to indicate that something is good or acceptable.system
You are a helpful assistant.
user
Is there a hoodie?
assistant
No, there is no hoodie in the image. The image shows a cartoon character giving a thumbs-up gesture.
16.76 # <---------------------- Inference time 2 time KV Cache
418 3051 36
system
You are a helpful assistant.
user
abc
assistant
The image you provided is an emoji with a thumbs-up gesture, which typically represents approval or agreement. The emoji has a cheerful expression with wide eyes and a big smile. This type of emoji is often used to convey positive feedback or to indicate that something is good or acceptable.system
You are a helpful assistant.
user
Is there a hoodie?
assistant
No, there is no hoodie in the image. The image shows a cartoon character giving a thumbs-up gesture.
24.8 # <---------------------- Inference time 3 time KV Cache

@zucchini-nlp
Member

@haixuanTao The script you provided re-uses the initial prompt 3 times with the same input, and each generation takes about 8 seconds. I think that is expected, since we're not changing anything and each new generation has inputs["input_ids"] new tokens to process.

To see whether there is a speed-up, try generating with no cache and compare against generation with the cached initial prompt. Make sure that in both cases the total input text processed is identical in length.
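
For reference, a rough sketch of that comparison for a plain text-only causal LM (for the VLM above, the un-cached call would also need the image inputs); the checkpoint and prompts are placeholders:

import copy
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-Instruct-v0.2"   # placeholder checkpoint
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

first = tok("Describe the weather today.", return_tensors="pt").to(model.device)
out = model.generate(**first, max_new_tokens=64, return_dict_in_generate=True)

follow_up = tok(" Is it going to rain?", add_special_tokens=False, return_tensors="pt").to(model.device)
full_ids = torch.cat([out.sequences, follow_up.input_ids], dim=1)

# Same total input length in both cases; only the second call reuses the cache.
t0 = time.time()
model.generate(full_ids, max_new_tokens=64)        # no cache: the whole sequence is prefilled
print("no cache:", time.time() - t0)

t0 = time.time()
model.generate(full_ids, max_new_tokens=64,
               past_key_values=copy.deepcopy(out.past_key_values))  # cached prefix: only the follow-up tokens are prefilled
print("cached:  ", time.time() - t0)

The deepcopy mirrors the script above: generate() mutates the cache in place, so a copy keeps the benchmark repeatable.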

@haixuanTao

Ok indeed! Sorry dumb mistake.

I can confirm that this now works on my Mac M3, shaving off about one second compared to running without the KV cache!

Thanks!


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
