Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Add attention sinks #3515

Draft
wants to merge 90 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
7914879
temp
felixzhu555 Mar 15, 2024
5b672d9
wip
felixzhu555 Mar 18, 2024
b35d7ba
wip
felixzhu555 Mar 18, 2024
e90cb58
wip
felixzhu555 Mar 19, 2024
831f18b
wip
felixzhu555 Mar 19, 2024
c8d86e6
change q pos
felixzhu555 Mar 21, 2024
0bd7566
evict
felixzhu555 Mar 21, 2024
f0263a4
edit xformers
felixzhu555 Mar 31, 2024
15b68ca
wip
Mar 31, 2024
9fe1895
wip
Apr 1, 2024
595638d
wip
felixzhu555 Apr 1, 2024
217743d
wip
felixzhu555 Apr 4, 2024
fd83c78
wip
felixzhu555 Apr 10, 2024
12e0e97
pull from main
felixzhu555 Apr 13, 2024
a9b094c
wip
felixzhu555 Apr 14, 2024
25e599d
cuda illegal memory access
felixzhu555 Apr 14, 2024
d14b94e
wip
Apr 17, 2024
8bb1840
cache current prerope key inside llama instead of xformers
felixzhu555 Apr 18, 2024
339305b
early eos
felixzhu555 Apr 18, 2024
1157cf3
fix small bugs
felixzhu555 Apr 18, 2024
0f0a414
wip
felixzhu555 Apr 21, 2024
6f01606
fix prefill
felixzhu555 Apr 22, 2024
740cbdb
wip
felixzhu555 Apr 23, 2024
15d586a
starting to work!
felixzhu555 Apr 24, 2024
c4a50b4
blockwise speedup
felixzhu555 Apr 24, 2024
455c814
wip
felixzhu555 Apr 25, 2024
ee12294
start removing loop over blocks
felixzhu555 May 15, 2024
d29b559
wip
felixzhu555 May 15, 2024
94ebe4d
wip
felixzhu555 May 16, 2024
899a7b3
wip: after refactor, generation abruptly ends
felixzhu555 May 17, 2024
016a6c6
speedup to 4 tok/s done
felixzhu555 May 18, 2024
2186c13
cache phys_bnums, 3x speedup
felixzhu555 May 18, 2024
e7acfbe
move logic out of xformers.py
felixzhu555 May 19, 2024
18042c6
refactor into new layer
felixzhu555 May 19, 2024
d2af329
pull from main
felixzhu555 May 21, 2024
8fe15d4
wip
felixzhu555 May 21, 2024
3ae06f5
add use_attention_sinks args
felixzhu555 May 21, 2024
67c3bdf
investigating eviction issue
felixzhu555 May 23, 2024
e09296b
flash attn works
felixzhu555 May 24, 2024
8f152d5
remove eviction & some refactoring
felixzhu555 May 30, 2024
a766775
start mixtral
felixzhu555 Jun 1, 2024
1e44278
refactor, start alibi, try falcon/bloom
felixzhu555 Jun 2, 2024
7413279
tiny
felixzhu555 Jun 2, 2024
05d7aa9
add mpt
felixzhu555 Jun 3, 2024
19a90f6
alibi not working
felixzhu555 Jun 3, 2024
34df763
fix seq len bug -> eviction and alibi work
felixzhu555 Jun 4, 2024
afb754c
eviction moved to block manager -> rope works, alibi not yet
felixzhu555 Jun 5, 2024
d7db6e1
fix alibi bug
felixzhu555 Jun 5, 2024
3d0929c
beam search not supported
felixzhu555 Jun 5, 2024
13b48c4
small fix
felixzhu555 Jun 5, 2024
9475536
pull main
felixzhu555 Jun 5, 2024
88a77d3
refactor models
felixzhu555 Jun 5, 2024
b3cfffb
pull main
felixzhu555 Jun 6, 2024
3e229a0
small
felixzhu555 Jun 7, 2024
b834de8
tests wip
felixzhu555 Jun 8, 2024
56b448a
tests failing
felixzhu555 Jun 10, 2024
7d9723c
wip
felixzhu555 Jun 11, 2024
2f92168
test correctness done
felixzhu555 Jun 12, 2024
c8416a0
add attn backend to tests, add eviction test
felixzhu555 Jun 12, 2024
b31ae95
small
felixzhu555 Jun 12, 2024
f241532
start refactor
felixzhu555 Jun 19, 2024
5c7f802
add wrapper method
felixzhu555 Jun 19, 2024
143db31
wip
felixzhu555 Jun 19, 2024
0722ff0
pull main
felixzhu555 Jun 20, 2024
e0848e3
refactor wip
felixzhu555 Jun 20, 2024
7abb285
fix test
felixzhu555 Jun 21, 2024
5bf0d5c
small
felixzhu555 Jun 21, 2024
ae31b1d
chunked prefill wip
felixzhu555 Jun 21, 2024
779b2a3
wip
felixzhu555 Jun 22, 2024
d527920
cuda mem error
felixzhu555 Jun 22, 2024
87bd485
chunked prefill working
felixzhu555 Jun 23, 2024
0a1abf8
wip
felixzhu555 Jun 23, 2024
08fd48f
fix paxos paper
felixzhu555 Jun 25, 2024
65f5f6d
wip
felixzhu555 Jun 26, 2024
cb12d5f
Merge branch 'main' of https://github.com/vllm-project/vllm into add_…
felixzhu555 Jun 26, 2024
fdc1365
chunked prefill for alibi
felixzhu555 Jun 27, 2024
da75ff6
add some docstrings
felixzhu555 Jun 28, 2024
fa8a253
fix test
felixzhu555 Jun 28, 2024
1763a44
pull main
felixzhu555 Jun 29, 2024
ef65724
fix after removal of logical block table
felixzhu555 Jul 16, 2024
38bd15f
change pos arange
felixzhu555 Jul 17, 2024
b0b8d0b
pull main
felixzhu555 Jul 17, 2024
7de1a21
small
felixzhu555 Aug 4, 2024
1ecec38
small
felixzhu555 Aug 4, 2024
5f03373
pull main, breaking changes to be fixed
felixzhu555 Aug 4, 2024
2da86a8
fix updates from pull main
felixzhu555 Aug 4, 2024
71ca701
refactor forward: remove rem logic, move torch ops out of loop
felixzhu555 Aug 4, 2024
bce7902
fix flash_attn.py
felixzhu555 Aug 4, 2024
be779fb
fix tests
felixzhu555 Aug 4, 2024
9d97b8d
pull main
felixzhu555 Aug 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions temp/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import json
import math
from typing import List, Tuple
import os
from transformers import AutoTokenizer
from huggingface_hub import login
from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput


login(token=os.environ.get("HF_TOKEN"))

MAX_GEN_TOKENS = 500
beam_search_params = SamplingParams(
temperature=0,
min_tokens=100,
max_tokens=150,
use_beam_search=True,
n=3
)


def get_prompt(model, file_path="/workspace/vllm/tests/prompts/attn-sinks-prompts.txt", magic_word=True) -> List[Tuple[str, SamplingParams]]:
with open(file_path, "r") as f:
prompts = json.load(f)

prompt = prompts[model]
if magic_word:
prompt = (
"Remember: my favorite color is mint green. "
"Here is a Harry Potter excerpt: " + prompt +
" First, summarize this excerpt. Then, print my favorite color AFTER the summary."
)
return [(prompt, SamplingParams(
logprobs=1,
min_tokens=100,
max_tokens=500,
temperature=0.5
)) for _ in range(1)]


def get_short_prompt() -> List[Tuple[str, SamplingParams]]:
prompt = "Tell me the story of the boy who cried wolf."
return [(prompt, SamplingParams(
logprobs=1,
min_tokens=50,
max_tokens=200
)) for _ in range(1)]


def get_long_prompt(file_path="./paxos-paper.txt", count=1) -> Tuple[str, SamplingParams]:
# this file is 4060 tokens
with open(file_path, "r") as f:
prompt = f.read()

# prompt = "Remember: the magic word is apple. " + prompt + " Then, print the magic word given earlier."
return [(prompt, SamplingParams(
logprobs=1,
temperature=1,
min_tokens=100,
max_tokens=MAX_GEN_TOKENS,
)) for _ in range(count)]


def process_requests(engine: LLMEngine,
test_prompts: List[Tuple[str, SamplingParams]],
do_print: bool):
"""Continuously process a list of prompts and handle the outputs."""
request_id = 0

while test_prompts or engine.has_unfinished_requests():
if test_prompts:
prompt, sampling_params = test_prompts.pop(0)
engine.add_request(str(request_id), prompt, sampling_params)
request_id += 1

request_outputs: List[RequestOutput] = engine.step()

for request_output in request_outputs:
if request_output.finished:
# print("\nPROMPT:")
# print(request_output.prompt)

out = request_output.outputs[0]
num_tokens = len(out.token_ids)
cum_logprob = out.cumulative_logprob
avg_logprob = cum_logprob / num_tokens

if do_print:
print("\n", "~" * 100)
print(f"Prompt length: {len(request_output.prompt_token_ids)} tokens")
print(f"OUTPUT: ({num_tokens} tokens)")
print(out.text, "\n")
print("Output stats:", cum_logprob, avg_logprob, out.finish_reason, f"isnan={math.isnan(cum_logprob)}")
print("~" * 100)


if __name__ == "__main__":
model = "meta-llama/Llama-2-13b-chat-hf"
model = "mistralai/Mixtral-8x7B-Instruct-v0.1" # TODO
model = "tiiuae/falcon-7b-instruct"
model = "bigscience/bloom-7b1"
model = "mistralai/Mistral-7B-Instruct-v0.2" # llama under the hood
model = "mosaicml/mpt-7b-chat"
model = "lmsys/vicuna-7b-v1.5"
model = "meta-llama/Meta-Llama-3-8B-Instruct"
args = EngineArgs(
model=model,
enforce_eager=True,
block_size=16,
dtype="bfloat16",
use_attention_sinks=True,
enable_chunked_prefill=False
)

engine = LLMEngine.from_engine_args(args)
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
prompts = get_prompt(model, magic_word=True)
# prompts = get_short_prompt()
# prompts = get_long_prompt()
process_requests(engine, prompts, do_print=False)
1 change: 1 addition & 0 deletions temp/paxos-paper.txt

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions temp/setup.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/bin/bash

read -p "Paste HF token: " token
export HF_TOKEN="$token"
export MAX_JOBS=8
git config --global user.email "felixzhu555@gmail.com"
git config --global user.name "Felix Zhu"
pip install -e .
19 changes: 19 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,25 @@ def generate_w_logprobs(
outputs.append((output_ids, output_str, output_logprobs))
return outputs

def generate_w_cum_logprobs(
self,
prompts: List[str],
sampling_params: SamplingParams,
) -> List[Tuple[str, float]]:
assert sampling_params.logprobs is not None
req_outputs = self.model.generate(prompts,
sampling_params=sampling_params)
outputs: List[Tuple[str, float]] = []
for req_output in req_outputs:
assert len(req_output.outputs) == 1, \
"This method expects only one CompletionOutput per request."
compl_output = req_output.outputs[0]
output_str = compl_output.text
output_logprob = compl_output.cumulative_logprob
output_avg_logprob = output_logprob / len(compl_output.token_ids)
outputs.append((output_str, output_avg_logprob))
return outputs

def generate_greedy(
self,
prompts: List[str],
Expand Down
10 changes: 10 additions & 0 deletions tests/prompts/attn-sinks-prompts.txt

Large diffs are not rendered by default.

194 changes: 194 additions & 0 deletions tests/test_attention_sinks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""Test attention sinks correctness for large models (7B).

Run `pytest tests/test_attention_sinks.py`.
"""
from functools import lru_cache
from math import isnan
import os
import pytest

from vllm import SamplingParams, EngineArgs, LLMEngine
from vllm.attention.selector import get_attn_backend


_ATTN_SINKS_PROMPTS_FILEPATH = os.path.join(
os.path.dirname(__file__),
"prompts",
"attn-sinks-prompts.txt"
)

_RETRIEVAL_COLOR = "mint green"


@pytest.mark.parametrize(
"model, max_model_len, test_retrieval, min_tokens, max_tokens, enable_chunked_prefill",
[
# rope models
("meta-llama/Meta-Llama-3-8B-Instruct", 8192, True, 100, 400, True),
("mistralai/Mistral-7B-Instruct-v0.2", 32768, True, 100, 600, False),
# alibi models
("mosaicml/mpt-7b-chat", 2048, False, 500, 800, False),
("bigscience/bloom-7b1", 2048, False, 500, 800, False)
]
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
def test_correctness(
vllm_runner,
model: str,
max_model_len: int,
test_retrieval: bool,
min_tokens: int,
max_tokens: int,
dtype: str,
batch_size: int,
attn_backend: str,
enable_chunked_prefill: bool,
monkeypatch: pytest.MonkeyPatch
):
if model == "mosaicml/mpt-7b-chat" and attn_backend == "XFORMERS":
return # sinks performance is worse than just alibi here

prompt = _get_prompt(model, test_retrieval=test_retrieval)
prompts = [prompt] * batch_size
params = SamplingParams(
logprobs=1,
temperature=0.5,
min_tokens=min_tokens,
max_tokens=max_tokens
)

monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
with vllm_runner(
model,
max_model_len=max_model_len,
dtype=dtype,
enforce_eager=True,
enable_chunked_prefill=enable_chunked_prefill
) as normal_model:
# bypass context length cap for normal generation
# to compare w/ attention sinks, which generates past context length
monkeypatch.setattr(
normal_model.model.llm_engine.output_processor.stop_checker,
"use_attention_sinks",
True
)
normal_outputs = normal_model.generate_w_cum_logprobs(prompts, params)
monkeypatch.undo() # undo setattr so that cleanup runs correctly

monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
with vllm_runner(
model,
max_model_len=max_model_len,
dtype=dtype,
enforce_eager=True,
use_attention_sinks=True,
enable_chunked_prefill=enable_chunked_prefill
) as sink_model:
sink_outputs = sink_model.generate_w_cum_logprobs(prompts, params)

get_attn_backend.cache_clear()

if test_retrieval:
for output_str, _ in sink_outputs:
assert _RETRIEVAL_COLOR in output_str.lower()

sum_normal_avg_logprob_per_token = sum(
avg_logprob for _, avg_logprob in normal_outputs)
sum_sink_avg_logprob_per_token = sum(
avg_logprob for _, avg_logprob in sink_outputs)

# attn sinks should be lower perplexity (higher logprob per token)
# nan logprob means negative infinity
assert sum_sink_avg_logprob_per_token > sum_normal_avg_logprob_per_token \
or isnan(sum_normal_avg_logprob_per_token)


@pytest.mark.parametrize("model, max_model_len", [
("meta-llama/Meta-Llama-3-8B-Instruct", 8192),
("mosaicml/mpt-7b-chat", 2048)
])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("attn_backend, block_size", [
("XFORMERS", 8),
("FLASH_ATTN", 16),
("FLASH_ATTN", 32)
])
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
def test_eviction(
model: str,
max_model_len: int,
dtype: str,
batch_size: int,
attn_backend: str,
block_size: int,
enable_chunked_prefill: bool,
monkeypatch: pytest.MonkeyPatch
):
prompt = _get_prompt(model)
prompts = [prompt] * batch_size
sampling_params = SamplingParams(
logprobs=1,
min_tokens=200,
max_tokens=201
)

monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

engine_args = EngineArgs(
model,
max_model_len=max_model_len,
dtype=dtype,
block_size=block_size,
enforce_eager=True,
use_attention_sinks=True,
enable_chunked_prefill=enable_chunked_prefill
)
engine = LLMEngine.from_engine_args(engine_args)

total_blocks = engine.scheduler[0].block_manager.get_num_free_gpu_blocks()
max_blocks_needed = (max_model_len // block_size) * batch_size

request_id = 0
while prompts or engine.has_unfinished_requests():
if prompts:
prompt = prompts.pop()
engine.add_request(str(request_id), prompt, sampling_params)
request_id += 1

engine.step()
free_blocks = engine.scheduler[0].block_manager.get_num_free_gpu_blocks()
used_blocks = total_blocks - free_blocks
assert used_blocks <= max_blocks_needed, (
f"Number of used blocks ({used_blocks}) should be "
f"at most {max_blocks_needed}"
)

del engine
get_attn_backend.cache_clear()


@lru_cache
def _get_prompt(model_name: str, test_retrieval: bool = False) -> str:
prompts = _get_prompts_json()
prompt = prompts[model_name]
# prompt is (model's context length - 100) tokens long

if test_retrieval:
return (
f"Remember: my favorite color is {_RETRIEVAL_COLOR}. "
f"Here is a Harry Potter excerpt: {prompt} "
"First, summarize this excerpt. "
"Then, print my favorite color AFTER the summary."
)
else:
return prompt


@lru_cache
def _get_prompts_json():
import json
with open(_ATTN_SINKS_PROMPTS_FILEPATH, "r") as f:
return json.load(f)
4 changes: 3 additions & 1 deletion vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
self.model_config = input_builder.model_config

def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
Expand Down Expand Up @@ -269,7 +270,8 @@ def _add_seq_group(
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
self.block_size, inter_data.block_tables,
self.model_config)

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
Expand Down
4 changes: 3 additions & 1 deletion vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
self.model_config = input_builder.model_config

# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
Expand Down Expand Up @@ -294,7 +295,8 @@ def _add_seq_group(
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
self.block_size, inter_data.block_tables,
self.model_config)

# It is not necessary to add paged_kv_indices, paged_kv_indptr,
# and paged_kv_last_page_len for profile run because we will
Expand Down
Loading
Loading