
Qualcomm AI Engine Direct - Support attention sink for long context usecase#16574

Merged
cccclai merged 1 commit intopytorch:mainfrom
CodeLinaro:dev1/hutton/enable_attention_sink
Feb 3, 2026

Conversation

@shewu-quic (Collaborator) commented Jan 14, 2026

Summary

  • Support the narrow operation
  • Support attention sink for static llama
    • Add the --max_context_len option to set the maximum length of the model's memory, and use max_seq_len to define the maximum sequence length for evaluation.
    • Specify --use_attention_sink <sink_size>,<eviction_batch_size> to enable the attention sink feature in llama.py
    • Behavior matrix in llama.py
      • Given --compile_only:
        • With --use_attention_sink -> compile the LLM and the attention sink model
        • Otherwise -> compile the LLM only
      • Given --pre_gen_pte:
        • With --use_attention_sink -> compile the attention sink model before running inference unless both criteria below are met, then run inference with attention sink
          • the attention sink model already exists
          • its sink_size and eviction_batch_size match the requested values
        • Otherwise -> run inference without attention sink
      • Given neither --compile_only nor --pre_gen_pte:
        • With --use_attention_sink -> compile the LLM and the attention sink model, then run inference with attention sink
        • Otherwise -> compile the LLM and run inference without attention sink
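The behavior matrix above can be sketched as a small dispatch function. This is an illustrative sketch only; `plan_run` and its action names are made up and are not the actual llama.py implementation.

```python
def plan_run(compile_only, pre_gen_pte, use_attention_sink,
             evictor_exists=False, config_matches=False):
    """Return the list of actions implied by the flag combination."""
    actions = []
    if compile_only:
        actions.append("compile_llm")
        if use_attention_sink:
            actions.append("compile_sink_model")
    elif pre_gen_pte:
        if use_attention_sink:
            # Recompile the evictor unless one already exists with the
            # same sink_size / eviction_batch_size.
            if not (evictor_exists and config_matches):
                actions.append("compile_sink_model")
            actions.append("inference_with_sink")
        else:
            actions.append("inference_without_sink")
    else:
        actions.append("compile_llm")
        if use_attention_sink:
            actions.append("compile_sink_model")
            actions.append("inference_with_sink")
        else:
            actions.append("inference_without_sink")
    return actions
```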

Test plan

  • Test for narrow op:
    • python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNFloatingPointOperator.test_qnn_backend_narrow --model SM8750 --device $DEVICE --build_folder build-android
    • python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator.test_qnn_backend_narrow --model SM8750 --device $DEVICE --build_folder build-android
  • Test for attention sink in llama.py
    • python backends/qualcomm/tests/test_qnn_delegate.py TestExampleLLMScript.test_attention_sink --model SM8750 --device $DEVICE -b build-android -a unit_test

Results

  • Test multi-conversation using attention sink in llama.py
    • python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s $DEVICE -m SM8750 --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-1b_instruct --model_mode kv --max_seq_len 4096 --max_context_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" "Could you give a more difficult example in python?" "Could you add a GUI for this game?" "Could you tell me more about tkinter?" "Is it possible to deploy on a website?" --tasks wikitext --limit 1 --use_attention_sink 4,32
This sets Llama 3.2 1B Instruct to a max context length of 1024 and activates the attention sink feature with a sink_size of 4 and an eviction_batch_size of 32, then runs a multi-turn conversation with a sequence length of 4096 using five prompts:
"I would like to learn python, could you teach me with a simple example?",
"Could you give a more difficult example in python?",
"Could you add a GUI for this game?",
"Could you tell me more about tkinter?",
"Is it possible to deploy on a website?"
<|start_header_id|>user<|end_header_id|>

I would like to learn python, could you teach me with a simple example?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I'd be happy to help you learn Python. Here's a simple example to get you started:

**Example: Guessing Game**

In this example, we'll create a simple game where the computer thinks of a number between 1 and 100, and you try to guess it. After each guess, the computer will tell you if your guess is higher or lower than the number it's thinking of.

**Code:**
```python
import random

# The number the computer is thinking of
number = random.randint(1, 100)

# Keep track of your guesses
guesses = 0

# Keep asking for guesses until you guess the number
while True:
    # Ask for a guess
    guess = input("Guess a number: ")

    # Check if the guess is valid
    try:
        # Convert the guess to an integer
        guess = int(guess)

        # Check if the guess is within the range
        if guess < 1 or guess > 100:
            print("Please enter a number between 1 and 100.")
        else:
            # Increment the number of guesses
            guesses += 1

            # Check if you've guessed the number
            if guess == number:
                print(f"Congratulations! You guessed the number in {guesses} guesses.")
                break
            else:
                # Tell the computer how close your guess is
                print(f"Your guess is {guess} away from the number.")
    except ValueError:
        print("Invalid input. Please enter a number.")
```
**How to Run:**

1. Save this code in a file with a `.py` extension (e.g., `guessing_game.py`).
2. Open a terminal or command prompt and navigate to the directory where you saved the file.
3. Type `python guessing_game.py` to run the game.

**What to Do:**

1. Run the game by typing `python guessing_game.py`.
2. Follow the prompts to guess a number.
3. Keep guessing until you guess the number correctly.

I hope this example helps you get started with Python! Let me know if you have any questions or need further guidance.<|eot_id|><|start_header_id|>user<|end_header_id|>

Could you give a more difficult example in python?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Here's a more challenging example in Python:

**Example: Rock, Paper, Scissors**

In this example, we'll create a game of Rock, Paper, Scissors where the computer thinks of a random move and you try to guess it. After each guess, the computer will tell you if your move is Rock, Paper, or Scissors.

**Code:**
```python
import random

# Define the possible moves
moves = ["rock", "paper", "scissors"]

# The computer's move
computer_move = random.choice(moves)

# Keep track of your guesses
guesses = 0

# Keep asking for guesses until you guess the computer's move
while True:
    # Ask for a guess
    guess = input("Guess a move (rock, paper, or scissors): ")

    # Check if the guess is valid
    if guess.lower() in moves:
        # Convert the guess to a move
        move = guess.lower()
        if move == computer_move:
            print(f"Congratulations! You guessed the computer's move correctly!")
            break
        else:
            # Increment the number of guesses
            guesses += 1
            # Tell the computer how close your guess is
            print(f"Your guess is {guess} away from the computer's move.")
    else:
        print("Invalid input. Please enter a valid move (rock, paper, or scissors).")
```
**How to Run:**

1. Save this code in a file with a `.py` extension (e.g., `rock_paper_scissors.py`).
2. Open a terminal or command prompt and navigate to the directory where you saved the file.
3. Type `python rock_paper_scissors.py` to run the game.

**What to Do:**

1. Run the game by typing `python rock_paper_scissors.py`.
2. Follow the prompts to guess a move.
3. Keep guessing until you guess the computer's move correctly.

This example is more challenging because it:

* Requires you to keep track of your guesses and the computer's move.
* Requires you to convert your guess to a specific move (rock, paper, or scissors).
* Requires you to tell the computer how close your guess is to the computer's move.

I hope this example helps you get started with more advanced Python concepts! Let me know if you have any questions or need further guidance.<|eot_id|><|start_header_id|>user<|end_header_id|>

Could you add a GUI for this game?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Here's an example of how you can create a GUI for the Rock, Paper, Scissors game using Python's Tkinter library:

**Code:**
```python
import tkinter as tk
from random import randint

class RockPaperScissors:
    def __init__(self):
        self.window = tk.Tk()
        self.window.title("Rock, Paper, Scissors")
        self.window.geometry("300x200")

        self.move = tk.StringVar()
        self.move.set("rock")

        self.rock_button = tk.Button(self.window, text="Rock", command=self.rock)
        self.rock_button.pack()

        self.paper_button = tk.Button(self.window, text="Paper", command=self.paper)
        self.paper_button.pack()

        self.scissors_button = tk.Button(self.window, text="Scissors", command=self.scissors)
        self.scissors_button.pack()

        self.result_label = tk.Label(self.window, text="", font=('Arial', 24))
        self.result_label.pack()

    def rock(self):
        computer_move = randint(0, 2)
        if computer_move == 0:
            self.move.set("rock")
            self.result_label.config(text=f"Computer chose {self.move.get()}.")
        elif computer_move == 1:
            self.move.set("paper")
            self.result_label.config(text=f"Computer chose {self.move.get()}.")
        else:
            self.move.set("scissors")
            self.result_label.config(text=f"Computer chose {self.move.get()}.")

    def paper(self):
        computer_move = randint(0, 2)
        if computer_move == 0:
            self.move.set("paper")
            self.result_label.config(text=f"Computer chose {self.move.get()}.")
        elif computer_move == 1:
            self.move.set("rock")
            self.result_label.config(text=f"Computer chose {self.move.get()}.")
        else:
            self.move.set("scissors")
            self.result_label.config(text=f"Computer chose {self.move.get()}.")

    def scissors(self):
        computer_move = randint(0, 2)
        if computer_move == 0:
            self.move.set("scissors")
            self.result_label.config(text=f"Computer chose {self.move.get()}.")
        elif computer_move == 1:
            self.move.set("paper")
            self.result_label.config(text=f"Computer chose {self.move.get()}.")
        else:
            self.move.set("rock")
            self.result_label.config(text=f"Computer chose {self.move.get()}.")

    def run(self):
        self.window.mainloop()

if __name__ == "__main__":
    game = RockPaperScissors()
    game.run()
```
**How to Run:**

1. Save this code in a file with a `.py` extension (e.g., `rock_paper_scissors.py`).
2. Open a terminal or command prompt and navigate to the directory where you saved the file.
3. Type `python rock_paper_scissors.py` to run the game.
4. A GUI window will appear with three buttons: Rock, Paper, and Scissors.
5. Click on the buttons to play the game.

This GUI implementation provides a simple and intuitive way to play the Rock, Paper, Scissors game. The game logic is implemented in the `rock`, `paper`, and `scissors` methods, which are called when the corresponding buttons are clicked. The game displays the result in the label below the buttons.

Note that this is a basic implementation, and you can customize it to suit your needs. For example, you could add additional features like scoring, high scores, or a more complex game logic.<|eot_id|><|start_header_id|>user<|end_header_id|>

Could you tell me more about tkinter?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Tkinter is a Python library that allows you to create simple graphical user interfaces (GUIs) for your Python programs. It's a great tool for beginners and experienced Python developers alike.

Here are some key features and concepts related to Tkinter:

**Key Features:**

1. **Widgets:** Tkinter provides a variety of widgets, such as buttons, labels, text boxes, and more, that you can use to create your GUI.
2. **Layout Management:** Tkinter allows you to create a layout for your widgets using the `pack`, `grid`, or `place` methods.
3. **Event Handling:** Tkinter provides a way to handle events, such as button clicks, key presses, and mouse movements, which can trigger actions in your GUI.
4. **Graphics:** Tkinter can display images, charts, and other graphical elements in your GUI.

**Basic Concepts:**

1. **Widgets:** A widget is a graphical element that you can use to create your GUI. Examples of widgets include buttons, labels, and text boxes.
2. **Layout:** A layout is a way of arranging widgets in your GUI. You can use the `pack`, `grid`, or `place` methods to create a layout.
3. **Event Handling:** Event handling is the process of responding to events, such as button clicks or key presses, in your GUI.
4. **Widgets:** You can create a GUI by creating a `Frame` widget, which is a container that holds other widgets.

**Example Code:**

Here's an example of a simple GUI created using Tkinter:
```python
import tkinter as tk

class MyGUI:
    def __init__(self):
        self.window = tk.Tk()
        self.window.title("My GUI")
        self.label = tk.Label(self.window, text="Hello, World!")
        self.label.pack()

        self.button = tk.Button(self.window, text="Click me!", command=self.button_click)
        self.button.pack()

    def button_click(self):
        print("Button clicked!")

    def run(self):
        self.window.mainloop()

if __name__ == "__main__":
    gui = MyGUI()
    gui.run()
```
In this example, we create a `MyGUI` class that creates a window with a label and a button. When the button is clicked, the `button_click` method is called, which prints a message to the console.

**Tips and Tricks:**

1. **Use the `pack` or `grid` layout manager:** These are the most common layout managers in Tkinter.
2. **Use the `mainloop` method:** This is the main method that starts the event loop of your GUI.
3. **Use the `bind` method:** This is used to bind a function to a specific event, such as button click or key press.
4. **Use the `config` method:** This is used to change the properties of a widget, such as its text or color.

I hope this helps! Let me know if you have any specific questions or if you'd like more information on how to use Tkinter.<|eot_id|><|start_header_id|>user<|end_header_id|>

Is it possible to deploy on a website?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Yes, it is possible to deploy a Tkinter application on a website. However, it requires some additional steps and considerations.

Here are some options:

1. **Run the application as a web server:** You can use a web server like Apache or Nginx to run your Tkinter application on a website. This will allow users to access your application through a web browser.
2. **Use a web framework:** You can use a web framework like Flask or Django to create a web application that uses Tkinter as a backend. This will allow you to create a web application that can be accessed through a web browser.
3. **Use a Python web server:** You can use a Python web server like `http.server` or `webserver` to run your Tkinter application on a website. This will allow you to access your application through a web browser.

Here's an example of how you can deploy a Tkinter application on a website using `http.server`:
```python
import tkinter as tk

class MyGUI:
    def __init__(self):
        self.window = tk.Tk()
        self.window.title("My GUI")
        self.label = tk.Label(self.window, text="Hello, World!")
        self.label.pack()

    def run(self):
        self.window.mainloop()

if __name__ == "__main__":
    gui = MyGUI()
    gui.run()
```
To run this code on a website, you'll need to install `http.server` or `webserver` and add it to your `PATH` environment variable. Here's an example of how to do this:
```bash
# Install http.server
pip install http.server

# Add http.server to your PATH environment variable
export PATH=/usr/local/bin:/usr/local/bin
```
Once you've installed `http.server`, you can run the following code to deploy your Tkinter application on a website:
```python
import http.server

class MyGUI:
    def __init__(self):
        self.window = http.server.HTTPServer(("", 8000), MyGUI)

    def run(self):
        self.window.run()

if __name__ == "__main__":
    gui = MyGUI()
    gui.run()
```
This will start a web server on port 8000 and allow users to access your Tkinter application through a web browser.

Note that deploying a Tkinter application on a website requires some technical knowledge and setup. If you're not comfortable with this, you may want to consider using a more modern web development framework or a Python web framework like Flask or Django.<|eot_id|>

The performance of each run is nearly identical.

I 00:00:00.528911 executorch:prompt_processor.cpp:267] Prompt Processor: total 26 prompt tokens (AR-1 * 26 iters)
I 00:00:00.934570 executorch:runner.cpp:459] RSS after prompt prefill: 1326.957031 MiB (0 if unsupported)
I 00:00:08.314366 executorch:token_generator.cpp:345] 
Reached to the end of generation
I 00:00:08.314416 executorch:runner.cpp:474] RSS after finishing text generation: 1326.957031 MiB (0 if unsupported)
I 00:00:08.314457 executorch:stats.h:143] 	Prompt Tokens: 26    Generated Tokens: 446
I 00:00:08.314473 executorch:stats.h:149] 	Model Load Time:		0.526000 (seconds)
I 00:00:08.314489 executorch:stats.h:159] 	Total inference time:		7.786000 (seconds)		 Rate: 	57.282302 (tokens/second)
I 00:00:08.314506 executorch:stats.h:167] 		Prompt evaluation:	0.407000 (seconds)		 Rate: 	63.882064 (tokens/second)
I 00:00:08.314527 executorch:stats.h:178] 		Generated 446 tokens:	7.379000 (seconds)		 Rate: 	60.441794 (tokens/second)
I 00:00:08.314544 executorch:stats.h:186] 	Time to first generated token:	0.407000 (seconds)
I 00:00:08.314547 executorch:stats.h:193] 	Sampling time over 472 tokens:	0.797000 (seconds)

cc: @haowhsu-quic

@pytorch-bot

pytorch-bot bot commented Jan 14, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16574

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit b75fe29 with merge base 47b8d1d (image):

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 14, 2026
@shewu-quic
Collaborator Author

@pytorchbot label "release notes: qualcomm"

@pytorch-bot pytorch-bot bot added the release notes: qualcomm Changes to the Qualcomm backend delegate label Jan 14, 2026
@shewu-quic
Collaborator Author

Hi @cccclai,
This PR enables attention sink in static llama.
Could you please take a look?
Thanks

@cccclai
Contributor

cccclai commented Jan 16, 2026

Hi, sorry, a bit late on this. I'm currently out and will take a look next week.

@cccclai (Contributor) left a comment

Thank you for enabling attention sink, this is great!


def get_8a4w_qnn_ptq_config(
act_symmetric: bool = True,
act_symmetric: bool = False,
Contributor

Any specific reason we try to make act_symmetric default to True?

@shewu-quic (Collaborator, Author) commented Jan 26, 2026

Are you referring to the original design?
This configuration is actually for the 8-bit KV cache. In the case of QK @ V, according to the QNN documentation, the second input V, as a weight, is expected to be signed and symmetrically quantized. For some models, we try to annotate the value projection (wv) with 8a4w to improve performance, and we use symmetric quantization to avoid a convert op that converts asymmetric to symmetric.

                                                            QK (16 bits) ──┬─> matmul op (16 bits)
                  past v (8 bits symmetric) ┬─> cat op (8 bits symmetric) ─┘
value projection (new v) (8 bits symmetric) ┘
    

Example:
```bash
# Compile llama pte file and attention sink rope pte file with sink_size = 4 and eviction_batch_size = 64
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-1b_instruct --model_mode hybrid --prefill_ar_len 128 --max_seq_len 4096 --max_context_len 1024 --prompt "I would like to learn python, could you teach me with a simple example?" --tasks wikitext --limit 1 --use_attention_sink 4,64 --compile_only
```

Contributor

Do we have a definition somewhere to explain the difference between max_seq_len and max_context_len?

Collaborator Author

Thanks for pointing that out. I just tried to follow executorch's LLM naming. Let me add some descriptions for it.

After running this, the `attention_sink_evictor.pte` file will be generated in the artifacts directory. This file is necessary for using the attention sink feature, as it enables remove batch_eviction_size of the key and value cache and re-rotates the key cache at runtime.
Contributor

remove batch_eviction_size of the key and value cache and re-rotates the key cache at runtime.

what does it mean?

Collaborator Author

Thanks for catching!

```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2-1b_instruct --model_mode hybrid --prefill_ar_len 128 --max_seq_len 4096 --prompt "I would like to learn python, could you teach me with a simple example?" "Could you give more difficult example in python?" "Could you add a GUI for this game?" "Could you tell me more about tkinter?" "Is possible to deploy on website?" --pre_gen_pte ${PATH_TO_ARTIFACT_IN_1ST_RUN} --use_attention_sink 4,64
```

If you want to modify `sink_size` or `eviction_batch_size`, or if you have a pre-compiled LLM pte file and wish to use the attention sink feature, you can recompile `attention_sink_evictor.pte` with a different attention sink config.
Contributor

Can you also elaborate on eviction_batch_size? I guess this doc is good for explaining how to use it, but some of the concepts are coupled with attention sink itself.

Collaborator Author

Got it. Let me try to elaborate more on the attention sink mechanism.

As far as I know, attention sink is a way to evict the cache when the maximum context length is reached.
There are two main concepts behind attention sink:

  1. Maintain attention sinks: always keep several initial tokens as attention sinks in the KV cache.
  2. Redefine positional context: use positions relative to the cache instead of absolute positions from the original text, enhancing relevance and coherence in generated responses.

(image)

When the cache reaches capacity, follow these three steps:

  1. Keep the first sink_size tokens in the KV cache.
  2. Remove eviction_batch_size tokens from the KV cache.
  3. Rotate the remaining KV cache to maintain its positional relationship.

Afterward, you can continue generating tokens until the cache reaches capacity again, then repeat the process.
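The eviction steps above can be sketched in NumPy. This is an illustrative sketch only, not the actual `attention_sink_evictor.pte` logic; the `evict` helper and the cache layout are assumptions, and the key re-rotation of step 3 is omitted here.

```python
import numpy as np

def evict(cache, filled, sink_size, eviction_batch_size):
    """Keep the first sink_size entries, drop the next eviction_batch_size,
    and shift the remainder forward. cache has shape
    [num_head, max_context_len, head_dim]; `filled` is the number of valid
    entries. Returns the new fill count."""
    keep = cache[:, sink_size + eviction_batch_size:filled, :].copy()
    cache[:, sink_size:sink_size + keep.shape[1], :] = keep
    return filled - eviction_batch_size

# Tiny demo: 8 cached positions, sink_size=2, eviction_batch_size=3.
cache = np.arange(8, dtype=float).reshape(1, 8, 1)
filled = evict(cache, filled=8, sink_size=2, eviction_batch_size=3)
# Positions 2, 3, 4 are evicted; 5, 6, 7 shift forward behind the two sinks.
```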

Comment thread on examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

parser.add_argument(
"--max_seq_len",
help="The maximum length of sequence to evaluate.",
Contributor

It's still not super clear the difference between max_context_len and max_seq_len...

@shewu-quic (Collaborator, Author) commented Jan 26, 2026

Would the following option work, or do you have any recommendations?

  1. max_seq_len: Maximum sequence length the model can handle
  2. max_context_len: Maximum length of the model's memory/cache

Contributor

max_seq_len: Maximum sequence length the model can handle

just to confirm, does it mean max_seq_len = max_context_len + {max decode length}

Collaborator Author

For instance, the kv cache has the shape [batch, num_head, max_context_len, head_dim].
Previously, we could only generate tokens up to max_context_len - num_of_prompt_token.
With attention sink enabled, it's possible to generate more tokens than max_context_len - num_of_prompt_token.
The max_seq_len parameter determines the maximum number of tokens that can be generated.

Contributor

I feel like it's clearer to say

max_seq_len: Maximum sequence length the model can generate
max_context_len: Maximum length of the model's memory/cache, including both prompt tokens and generated tokens

Collaborator Author

Great!
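As a numeric sanity check of these definitions, using the values from the example run (the arithmetic is an interpretation of this thread, not code from llama.py):

```python
max_context_len = 1024   # KV cache capacity: [batch, num_head, 1024, head_dim]
max_seq_len = 4096       # upper bound on the overall sequence for evaluation
num_prompt_tokens = 26   # prompt size from the example run

# Without attention sink, generation stops once the cache fills:
max_generated_without_sink = max_context_len - num_prompt_tokens

# With attention sink, eviction keeps freeing cache space, so generation
# can continue well past the cache size, bounded by max_seq_len:
max_generated_with_sink = max_seq_len - num_prompt_tokens
```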

default=8,
type=int,
)
parser.add_argument(
Contributor

Does it work for all LLMs?

Collaborator Author

Yes, I think so.

sin(delta), cos(delta))
where delta = new_position * theta - original_position * theta

Based on https://github.com/huggingface/transformers/blame/main/src/transformers/cache_utils.py#L961
Contributor

This is helpful, though main might change, can we use a link from a commit instead?

Collaborator Author

Good catch! Thanks
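The delta formula quoted above can be checked with a tiny sketch: re-rotating a cached key from its original position to a new position is itself a rotation by delta = new_position * theta - original_position * theta, and matches applying RoPE at the new position directly. Illustrative code only; `rerotate` and `rope` are hypothetical names operating on one 2-d RoPE component.

```python
import numpy as np

def rerotate(key_pair, original_position, new_position, theta):
    """Re-rotate one 2-d RoPE component of a cached key by
    delta = new_position * theta - original_position * theta."""
    delta = new_position * theta - original_position * theta
    c, s = np.cos(delta), np.sin(delta)
    x, y = key_pair
    return np.array([x * c - y * s, x * s + y * c])

def rope(key_pair, position, theta):
    """Standard RoPE rotation of a raw key at the given position."""
    return rerotate(key_pair, 0, position, theta)
```

For example, a key cached at position 10 and shifted down to position 7 after eviction ends up identical to the raw key rotated at position 7, which is what lets the evictor fix up cached keys without the original activations.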

int32_t seq_len,
std::function<void(const std::string&)> token_callback,
bool dump_logits) {
bool dump_logits,
Contributor

Does attention sink work with lookahead?

Collaborator Author

Yes, the attention sink feature is a way to clear the cache when the number of generated tokens reaches max_context_len. The lookahead method is used to guess more tokens at once to enhance performance.

QuantDtype.use_8a4w,
False,
act_observer=MinMaxObserver,
act_symmetric=True,
Contributor

Not super clear when to use act_symmetric and when not..

Collaborator Author

Typically, the weight should be symmetric. So, if the subsequent operation involves weight, you can annotate the path as symmetric to prevent the need for a convert operation.
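The symmetric-vs-asymmetric distinction can be illustrated with a toy NumPy example (illustrative only, not the QNN quantizer; the helper names are made up). A symmetric tensor has zero_point fixed at 0, which is what a weight-like input such as V must satisfy, so annotating that path symmetric up front avoids inserting a convert op.

```python
import numpy as np

def quant_symmetric(x, bits=8):
    """Symmetric signed quantization: zero_point is fixed at 0."""
    qmax = 2 ** (bits - 1) - 1                 # int8: 127
    scale = np.abs(x).max() / qmax
    return np.round(x / scale).astype(np.int8), scale

def quant_asymmetric(x, bits=8):
    """Asymmetric unsigned quantization: needs a nonzero zero_point."""
    qmin, qmax = 0, 2 ** bits - 1              # uint8: 0..255
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = int(round(qmin - x.min() / scale))
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    return q.astype(np.uint8), scale, zero_point
```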

@shewu-quic shewu-quic force-pushed the dev1/hutton/enable_attention_sink branch from cc7cea3 to a9ab580 Compare January 27, 2026 07:20
@shewu-quic
Collaborator Author

Hi @cccclai,

I’ve rebased the PR.
When you have a moment, could you please review it?
Thank you.

@cccclai
Contributor

cccclai commented Jan 28, 2026

Hi can you rebase again?

@shewu-quic shewu-quic force-pushed the dev1/hutton/enable_attention_sink branch from a9ab580 to f833c5c Compare January 29, 2026 00:32
@shewu-quic
Collaborator Author

Hi can you rebase again?

Sure, I have rebased and refactored the statement. Thank you.

@meta-codesync
Contributor

meta-codesync bot commented Jan 29, 2026

@cccclai has imported this pull request. If you are a Meta employee, you can view this in D91796792.

@cccclai
Contributor

cccclai commented Jan 30, 2026

There are some internal failure..

File not found: `executorch/examples/qualcomm/oss_scripts/llama/wrappers.py`.
     Included in `executorch/examples/qualcomm/oss_scripts/llama/TARGETS` but does not exist
File not found: `fbcode//executorch/examples/qualcomm/oss_scripts/llama/wrappers.py`.
     Included in `executorch/examples/qualcomm/oss_scripts/llama/TARGETS` but does not exist

and another error

@shewu-quic
Collaborator Author

There are some internal failure..

File not found: `executorch/examples/qualcomm/oss_scripts/llama/wrappers.py`.
     Included in `executorch/examples/qualcomm/oss_scripts/llama/TARGETS` but does not exist
File not found: `fbcode//executorch/examples/qualcomm/oss_scripts/llama/wrappers.py`.
     Included in `executorch/examples/qualcomm/oss_scripts/llama/TARGETS` but does not exist

and another error

Resolved the naming in the TARGETS file.

@cccclai
Copy link
Copy Markdown
Contributor

cccclai commented Feb 1, 2026

It conflicts because the other qcom PR lands...can you rebase again? Sorry for the back and forth

- Support narrow operation
- Support attention sink for static llama
- Include the --max_context_len option to set the maximum length for the model's memory, and use max_seq_len to define the maximum sequence length for evaluation.
- Specify --use_attention_sink <sink_size>,<eviction_batch_size> to enable the attention sink feature in llama.py
- Add more descriptions for the attention sink feature and related parameters
- Behavior matrix in `llama.py`
  - Given `--compile_only`:
    - With `--use_attention_sink` -> Compile the LLM and the attention sink model
    - Otherwise -> Compile the LLM only
  - Given `--pre_gen_pte`:
    - With `--use_attention_sink` -> If the criteria below are not met, compile the attention sink model before running inference, then run inference with attention sink
      - Check if the attention sink model exists
      - Verify sink_size and eviction batch size are identical
    - Otherwise -> Run inference without attention sink
  - Neither `--compile_only` nor `--pre_gen_pte`:
    - With `--use_attention_sink` -> Compile the LLM and the attention sink model, then run inference with attention sink
    - Otherwise -> Compile the LLM and run inference without attention sink
- Test for narrow op:
  - python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNFloatingPointOperator.test_qnn_backend_narrow --model SM8750 --device $DEVICE --build_folder build-android
  - python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedOperator.test_qnn_backend_narrow --model SM8750 --device $DEVICE --build_folder build-android
- Test for attention sink in llama.py
  - python backends/qualcomm/tests/test_qnn_delegate.py TestExampleLLMScript.test_attention_sink --model SM8750 --device $DEVICE -b build-android -a unit_test
@shewu-quic shewu-quic force-pushed the dev1/hutton/enable_attention_sink branch from 2d3fd14 to b75fe29 Compare February 2, 2026 04:57
@shewu-quic
Collaborator Author

It conflicts because the other qcom PR lands...can you rebase again? Sorry for the back and forth

No problem. I have rebased.
Thanks :)

@cccclai
Contributor

cccclai commented Feb 2, 2026

I've forward-fixed the internal error; should be able to merge this after the import finishes.

@cccclai cccclai merged commit 3b3c9d4 into pytorch:main Feb 3, 2026
147 of 150 checks passed