Support FP8-E5M2 KV Cache #2279

Merged Jan 29, 2024 (43 commits; changes shown from 28 commits).

Commits
0ac4ba1  test_cache.py passed (Dec 22, 2023)
63ec85b  format (Dec 27, 2023)
f98c816  fix compiling warning (Dec 27, 2023)
9881221  fix fp8x4 -> float4 (Dec 28, 2023)
850137b  test_attention pass (Dec 29, 2023)
a852f54  fix typo (Dec 29, 2023)
7be2ed4  fix typo (Dec 29, 2023)
4bcf15c  add latency & throughput benchmark (Dec 29, 2023)
1a13c5a  fix benchmark_latency.py (Dec 29, 2023)
5cdb619  Merge branch 'main' into fp8_cache (Jan 3, 2024)
82516df  fix copy_blocks (Jan 3, 2024)
556e5b2  use VLLM_DISPATCH_CASE_FLOATING_BYTE_TYPES (Jan 3, 2024)
c67277b  add default behavior in description (Jan 12, 2024)
c3760f8  grace code regarding to comments (Jan 13, 2024)
7bae850  add namespace fp8_e5m2_unscaled (Jan 13, 2024)
525003b  change interface (Jan 15, 2024)
537b5a7  solve conflict (Jan 15, 2024)
7e837dd  fix tp (Jan 15, 2024)
fe5f053  print log.info (Jan 15, 2024)
58d9817  add log and raise error on amd gpu (Jan 16, 2024)
dddd6eb  fix none error (Jan 18, 2024)
a61d828  do not use VLLM_LDG (Jan 18, 2024)
1cb7af6  fix mirror typo (Jan 19, 2024)
589297a  fix compiler error on lower cc (Jan 19, 2024)
6223984  fix unittest error (Jan 19, 2024)
4f85f9b  fix utest (Jan 19, 2024)
0ff1d14  Merge branch 'main' into fp8_cache (Jan 20, 2024)
b4db831  fix tp error (Jan 21, 2024)
3072560  more clear, tp error (Jan 24, 2024)
d837bbb  fix conflict (Jan 24, 2024)
b493300  avoid compile fp8 when cuda version is lower than 11.8 (Jan 24, 2024)
5461bd6  fix ut (Jan 25, 2024)
f66fb4e  fix typo (Jan 25, 2024)
7e5d61b  loose NUM_BLOCKS in test_attention (Jan 25, 2024)
b4aedf5  update test_cache (Jan 25, 2024)
eac2720  fix yapf (Jan 25, 2024)
bb6cc13  loose num_blocks (Jan 26, 2024)
455d0b5  fix stuck when tp>1 (Jan 26, 2024)
dbd464c  update (Jan 27, 2024)
11411e1  rename: create_kv_caches -> create_kv_caches_with_random (Jan 27, 2024)
4945577  remove cache_dtype_str (Jan 28, 2024)
b52e702  add co-author (Jan 28, 2024)
fee9a13  Merge branch 'main' into fp8_cache (zhuohan123, Jan 28, 2024)
6 changes: 6 additions & 0 deletions benchmarks/benchmark_latency.py
@@ -24,6 +24,7 @@ def main(args: argparse.Namespace):
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
)

sampling_params = SamplingParams(
@@ -115,6 +116,11 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument('--enforce-eager',
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument('--kv-cache-dtype',
type=str,
choices=['fp8', None],
Collaborator:
In general, please be explicit in this PR that fp8 means fp8_e5m2. Given that there are multiple ways to implement fp8, this will make things clearer.

Suggested change:
- choices=['fp8', None],
+ choices=['fp8_e5m2', None],

Contributor Author:
done
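For context on the naming request above: "fp8" alone is ambiguous, since there are at least two common 8-bit float layouts (e5m2 and e4m3) that trade exponent range for mantissa precision. A minimal standalone comparison, assuming a PyTorch build (2.1+) that ships the float8 dtypes; the PR itself handles e5m2 in its own CUDA kernels rather than relying on these dtypes:

import torch

# "fp8" does not identify a single format; the two variants differ in
# dynamic range and precision.
for dt in (torch.float8_e5m2, torch.float8_e4m3fn):
    info = torch.finfo(dt)
    print(dt, "max:", info.max, "smallest normal:", info.tiny)

# e5m2 keeps a wider exponent range (the same as fp16's), while e4m3 keeps an
# extra mantissa bit; hence the request to spell out fp8_e5m2 explicitly.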

default=None,
help='Data type for kv cache storage.')
Collaborator:
Please specify the default behavior here. And one question: why is 'fp16' not a valid option here?

Collaborator (@zhuohan123, Jan 13, 2024):
Additionally, can we call fp8 fp8_e5m2 across this PR? Because there are different ways to implement fp8.

Contributor Author:
The default kv cache data type is the same as the model dtype, so it may not be suitable to make fp16 the default: the model dtype can be float/bfloat16/half/...
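A minimal sketch of the default behavior described above, using a hypothetical helper name (the real option handling lives in vLLM's engine and cache configuration, not in this benchmark file):

import torch

# Hypothetical illustration: when --kv-cache-dtype is left unset (None),
# the KV cache inherits the model dtype; 'fp8_e5m2' switches the cache to a
# byte-sized layout that the attention kernels reinterpret as fp8.
def resolve_kv_cache_dtype(kv_cache_dtype, model_dtype):
    if kv_cache_dtype is None:
        return model_dtype                     # float32 / float16 / bfloat16
    if kv_cache_dtype == "fp8_e5m2":
        return torch.uint8                     # raw bytes holding fp8_e5m2 values
    raise ValueError(f"unsupported kv cache dtype: {kv_cache_dtype}")

print(resolve_kv_cache_dtype(None, torch.bfloat16))       # torch.bfloat16
print(resolve_kv_cache_dtype("fp8_e5m2", torch.float16))  # torch.uint8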

parser.add_argument(
'--profile',
action='store_true',
12 changes: 11 additions & 1 deletion benchmarks/benchmark_throughput.py
@@ -71,6 +71,7 @@ def run_vllm(
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@@ -83,6 +84,7 @@
dtype=dtype,
max_model_len=max_model_len,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
)

# Add the requests to the engine.
@@ -206,7 +208,8 @@ def main(args: argparse.Namespace):
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager)
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -284,6 +287,13 @@ def main(args: argparse.Namespace):
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['fp8', None],
Collaborator:
Suggested change:
- choices=['fp8', None],
+ choices=['fp8_e5m2', None],

default=None,
help=
'Data type for kv cache storage. If None, will use model data type.')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
38 changes: 32 additions & 6 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -4,7 +4,7 @@

import torch

from vllm._C import ops
from vllm._C import ops, cache_ops

NUM_BLOCKS = 1024
PARTITION_SIZE = 512
@@ -21,6 +21,7 @@ def main(
use_alibi: bool,
block_size: int,
dtype: torch.dtype,
use_fp8_kv_cache: bool,
Collaborator:
Please use a string option instead of bool.

Suggested change:
- use_fp8_kv_cache: bool,
+ kv_cache_dtype: Optional[str] = None,

seed: int,
do_profile: bool,
) -> None:
@@ -59,15 +60,36 @@ def main(
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")

# Create the KV cache.
x = 16 // torch.tensor([], dtype=dtype).element_size()
cache_dtype = dtype if not use_fp8_kv_cache else torch.uint8
x = 16 // torch.tensor([], dtype=cache_dtype).element_size()
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
key_cache.uniform_(-scale, scale)
key_cache = torch.empty(size=key_cache_shape,
dtype=cache_dtype,
device="cuda")
if not use_fp8_kv_cache:
Collaborator:
Suggested change:
- if not use_fp8_kv_cache:
+ if kv_cache_dtype == None:
key_cache.uniform_(-scale, scale)
else:
Collaborator:
Suggested change:
- else:
+ elif kv_cache_dtype == 'fp8_e5m2':

# NOTE(zhaoyang): Due to the NaN and Inf representations of the fp8 data type,
# Inf or NaN values may occur if we directly use torch.randint
# to generate random data for the fp8 cache.
# For example, s.11111.00 in fp8_e5m2 format represents Inf.
#      | E4M3        | E5M2
# -----|-------------|-------------------
#  Inf | N/A         | s.11111.00
#  NaN | s.1111.111  | s.11111.{01,10,11}
key_cache_tmp = torch.empty_like(key_cache, dtype=dtype)
key_cache_tmp.uniform_(-scale, scale)
cache_ops.convert_fp8(key_cache_tmp, key_cache)
Collaborator:
Suggested change:
- cache_ops.convert_fp8(key_cache_tmp, key_cache)
+ cache_ops.convert_fp8_e5m2(key_cache_tmp, key_cache)

value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
dtype=cache_dtype,
device="cuda")
value_cache.uniform_(-scale, scale)
if not use_fp8_kv_cache:
Collaborator:
ditto
value_cache.uniform_(-scale, scale)
else:
value_cache_tmp = torch.empty_like(value_cache, dtype=dtype)
value_cache_tmp.uniform_(-scale, scale)
cache_ops.convert_fp8(value_cache_tmp, value_cache)

# Prepare for the paged attention kernel.
output = torch.empty_like(query)
@@ -106,6 +128,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
block_size,
max_context_len,
alibi_slopes,
use_fp8_kv_cache,
)
elif version == "v2":
ops.paged_attention_v2(
@@ -123,6 +146,7 @@
block_size,
max_context_len,
alibi_slopes,
use_fp8_kv_cache,
)
else:
raise ValueError(f"Invalid version: {version}")
@@ -166,6 +190,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
type=str,
choices=["half", "bfloat16", "float"],
default="half")
parser.add_argument("--use-fp8-kv-cache", action="store_true")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
args = parser.parse_args()
@@ -188,6 +213,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float:
block_size=args.block_size,
use_alibi=args.use_alibi,
dtype=dtype_to_torch_dtype[args.dtype],
use_fp8_kv_cache=args.use_fp8_kv_cache,
seed=args.seed,
do_profile=args.profile,
)
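As a side note on the NaN/Inf comment in the benchmark diff above: any byte whose e5m2 exponent field is all ones decodes to Inf or NaN, which is exactly what uniform random bytes (e.g. from torch.randint) can produce. Filling a regular float tensor and converting it, as the benchmark does, keeps every value finite. A small standalone check, assuming PyTorch 2.1+ with torch.float8_e5m2 (the PR's own conversion goes through its cache_ops kernel instead):

import torch

# Bytes with an all-ones exponent field decode to Inf/NaN in fp8_e5m2:
# 0x7C -> +Inf, 0x7D/0x7E/0x7F -> NaN, 0xFC -> -Inf.
raw = torch.tensor([0x7C, 0x7D, 0x7E, 0x7F, 0xFC], dtype=torch.uint8)
print(raw.view(torch.float8_e5m2).float())   # tensor([inf, nan, nan, nan, -inf])

# Converting bounded uniform floats instead keeps every cache entry finite.
scale = 1.0
tmp = torch.empty(8, dtype=torch.float16).uniform_(-scale, scale)
fp8 = tmp.to(torch.float8_e5m2)
print(torch.isfinite(fp8.float()).all())     # tensor(True)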
1 change: 1 addition & 0 deletions csrc/attention/attention_dtypes.h
@@ -4,3 +4,4 @@
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
#include "dtype_fp8.cuh"
Collaborator:
Suggested change:
- #include "dtype_fp8.cuh"
+ #include "dtype_fp8_e5m2.cuh"
