31 changes: 31 additions & 0 deletions examples/models/llama2/TARGETS
@@ -160,13 +160,44 @@ runtime.python_library(
],
)

runtime.python_library(
name = "sdpa",
srcs = [
"source_transformation/sdpa.py",
],
_is_external_target = True,
visibility = ["//executorch/..."],
deps = [
"//caffe2:torch",
],
)

runtime.python_test(
name = "quantized_kv_cache_test",
srcs = [
"source_transformation/test_quantized_kv_cache.py",
],
preload_deps = [
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
],
deps = [
":quantized_kv_cache",
"//caffe2:torch",
"//executorch/examples/models/llama2:llama_transformer",
],
)

runtime.python_test(
name = "quantized_sdpa_with_kv_cache_test",
srcs = [
"source_transformation/test_sdpa_with_quantized_kv_cache.py",
],
preload_deps = [
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
],
deps = [
":quantized_kv_cache",
":sdpa",
"//caffe2:torch",
"//executorch/examples/models/llama2:llama_transformer",
],
4 changes: 1 addition & 3 deletions examples/models/llama2/export_llama_lib.py
@@ -890,9 +890,7 @@ def _get_source_transforms( # noqa
transforms.append(replace_sdpa_with_custom_op)

if args.quantize_kv_cache:
assert (
args.use_kv_cache and not args.use_sdpa_with_kv_cache
), "quantize_kv_cache requires use_kv_cache=True and use_sdpa_with_kv_cache=False"
assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
transforms.append(replace_kv_cache_with_quantized_kv_cache)

if args.use_kv_cache:
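With the assertion relaxed, quantize_kv_cache only requires use_kv_cache and no longer conflicts with use_sdpa_with_kv_cache. Below is a minimal sketch (not part of this diff) of that gating; the import locations of the two replacement functions and the transform(model) -> model calling convention are assumptions, and model is a placeholder nn.Module.

# Hedged sketch: gating logic enabled by the relaxed assert. Import paths are assumed.
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    replace_kv_cache_with_quantized_kv_cache,
)
from executorch.examples.models.llama2.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)


def apply_kv_cache_transforms(model, use_kv_cache, quantize_kv_cache, use_sdpa_with_kv_cache):
    transforms = []
    if use_sdpa_with_kv_cache:
        transforms.append(replace_sdpa_with_custom_op)
    if quantize_kv_cache:
        # Quantized KV cache now only requires use_kv_cache=True.
        assert use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
        transforms.append(replace_kv_cache_with_quantized_kv_cache)
    for transform in transforms:
        model = transform(model)
    return model
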
102 changes: 65 additions & 37 deletions examples/models/llama2/source_transformation/quantized_kv_cache.py
@@ -47,6 +47,7 @@ def __init__(
raise ValueError(
f"Only affine symmetric and asymmetric cache types are supported: got {cache_type}"
)

# For now supporting int8 only
self.quantized_cache_dtype = torch.int8
self.cache_fp_type = torch.float32
@@ -65,10 +66,10 @@ def __init__(
"v_cache", torch.zeros(cache_shape, dtype=self.quantized_cache_dtype)
)
self.register_buffer(
"k_cache_scales", torch.ones(scale_shape, dtype=torch.double)
"k_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
)
self.register_buffer(
"v_cache_scales", torch.ones(scale_shape, dtype=torch.double)
"v_cache_scales", torch.ones(scale_shape, dtype=torch.float64)
)
if cache_type == QuantizedCacheType.AffineAsymmetric:
self.register_buffer(
@@ -100,47 +101,74 @@ def update(self, input_pos, k_val, v_val):

quantized_v_val, v_scales, v_zero_points = self._quantize(v_val)

if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
dim_to_slice = 2 if self.is_transposed else 1
torch._check(start_pos < self.k_cache.size(dim_to_slice))
seq_length = k_val.size(dim_to_slice)
narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
narrowed_k_scales = self.k_cache_scales.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_k_zp = self.k_cache_zero_points.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_k.copy_(quantized_k_val)
narrowed_k_scales.copy_(k_scales)
narrowed_k_zp.copy_(k_zero_points)
narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
narrowed_v_scales = self.v_cache_scales.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_v_zp = self.v_cache_zero_points.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_v.copy_(quantized_v_val)
narrowed_v_scales.copy_(v_scales)
narrowed_v_zp.copy_(v_zero_points)
else:
if self.is_transposed:
if self.is_transposed:
# We cannot use the update_cache op at the moment
# if the cache is transposed.
# Also note that we should not need separate paths
# for dynamic vs. non-dynamic shapes; it is only done
# this way to accommodate the lowering pains of backends
# that work better with the index_put op.
if self.enable_dynamic_shape:
start_pos = input_pos[0].item()
torch._check_is_size(start_pos)
dim_to_slice = 2 if self.is_transposed else 1
torch._check(start_pos < self.k_cache.size(dim_to_slice))
seq_length = k_val.size(dim_to_slice)
narrowed_k = self.k_cache.narrow(dim_to_slice, start_pos, seq_length)
narrowed_k_scales = self.k_cache_scales.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_k_zp = self.k_cache_zero_points.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_k.copy_(quantized_k_val)
narrowed_k_scales.copy_(k_scales)
narrowed_k_zp.copy_(k_zero_points)
narrowed_v = self.v_cache.narrow(dim_to_slice, start_pos, seq_length)
narrowed_v_scales = self.v_cache_scales.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_v_zp = self.v_cache_zero_points.narrow(
dim_to_slice, start_pos, seq_length
)
narrowed_v.copy_(quantized_v_val)
narrowed_v_scales.copy_(v_scales)
narrowed_v_zp.copy_(v_zero_points)
else:
self.k_cache[:, :, input_pos] = quantized_k_val
self.k_cache_scales[:, :, input_pos] = k_scales
self.k_cache_zero_points[:, :, input_pos] = k_zero_points
self.v_cache[:, :, input_pos] = quantized_v_val
self.v_cache_scales[:, :, input_pos] = v_scales
self.v_cache_zero_points[:, :, input_pos] = v_zero_points
else:
self.k_cache[:, input_pos] = quantized_k_val
self.k_cache_scales[:, input_pos] = k_scales
self.k_cache_zero_points[:, input_pos] = k_zero_points
self.v_cache[:, input_pos] = quantized_v_val
self.v_cache_scales[:, input_pos] = v_scales
self.v_cache_zero_points[:, input_pos] = v_zero_points
else:
# Right now we use custom ops on this path.
# In the future the custom op can be updated to handle
# the transposed cache as well.
# Note that we may have to revert this change if other ET
# backends such as QNN want to use the quantized cache with
# dynamic shapes instead of quantizing on their own.
# But until then, we opt for code simplicity.
start_pos = input_pos[0].item()
_ = torch.ops.llama.update_quantized_cache(
quantized_k_val, self.k_cache, start_pos
)
_ = torch.ops.llama.update_quantized_cache(
k_scales, self.k_cache_scales, start_pos
)
_ = torch.ops.llama.update_quantized_cache(
k_zero_points, self.k_cache_zero_points, start_pos
)
_ = torch.ops.llama.update_quantized_cache(
quantized_v_val, self.v_cache, start_pos
)
_ = torch.ops.llama.update_quantized_cache(
v_scales, self.v_cache_scales, start_pos
)
_ = torch.ops.llama.update_quantized_cache(
v_zero_points, self.v_cache_zero_points, start_pos
)

k_out = torch.ops.quantized_decomposed.dequantize_per_token(
self.k_cache,
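The cache stores int8 values with per-token scales (float64) and, for the affine asymmetric type, per-token zero points. As a reference for what the quantized_decomposed ops above compute, here is a self-contained plain-torch sketch of the per-token affine asymmetric round trip; the shapes and tolerances are illustrative only.

import torch


def quantize_per_token_asymmetric(x: torch.Tensor, eps: float = 1e-9):
    # One (scale, zero_point) pair per token, computed over the last dim; int8 range.
    qmin, qmax = -128, 127
    x_min = x.amin(dim=-1, keepdim=True)
    x_max = x.amax(dim=-1, keepdim=True)
    scale = (x_max - x_min).clamp(min=eps) / (qmax - qmin)
    zero_point = qmin - torch.round(x_min / scale)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)
    return q, scale.to(torch.float64), zero_point.to(torch.int64)


def dequantize_per_token(q: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor):
    return (q.to(torch.float32) - zero_point.to(torch.float32)) * scale.to(torch.float32)


# Round trip: the dequantized values stay within roughly one quantization step of the input.
x = torch.randn(1, 5, 4, 17)
q, scales, zero_points = quantize_per_token_asymmetric(x)
torch.testing.assert_close(dequantize_per_token(q, scales, zero_points), x, atol=1e-1, rtol=1e-1)
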
34 changes: 29 additions & 5 deletions examples/models/llama2/source_transformation/sdpa.py
@@ -9,23 +9,32 @@
# Example script for exporting Llama2 to flatbuffer

import math
from typing import Tuple
from typing import Tuple, Union

import torch

from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
QuantizedKVCache,
)


class SDPACustom(torch.nn.Module):
def __init__(
self,
kv_cache: KVCache,
kv_cache: Union[KVCache, QuantizedKVCache],
dim: int,
):
super().__init__()
# Custom op only supports float32 currently. Converting to/from float32 is
# faster than not having the op.
self.kv_cache = kv_cache.to(torch.float)
self.kv_cache = kv_cache
if not isinstance(kv_cache, QuantizedKVCache):
self.kv_cache = kv_cache.to(torch.float)
else:
assert (
kv_cache.cache_fp_type == torch.float32
), "Only float32 is supported for custom SDPA"
self.dim = dim

def forward(
@@ -44,12 +53,27 @@ def forward(
q = q.to(dtype=torch.float)
k = k.to(dtype=torch.float)
v = v.to(dtype=torch.float)

k_cache = self.kv_cache.k_cache
v_cache = self.kv_cache.v_cache
if isinstance(self.kv_cache, QuantizedKVCache):
# Updates the quantized cache, scales, and zero points,
# and returns the dequantized k and v caches.
# Not the most optimal path; optimizations to follow.
k_cache, v_cache = self.kv_cache.update(input_pos, k, v)
# Note that this path will still mutate k_cache and v_cache in place.
# When we are not using the quantized kv cache, this just mutates
# the original kv cache.
# When we are using the quantized kv cache, this mutates the
# k_cache and v_cache returned from the cache update operation,
# which simply dequantizes the cache and returns the result.
# Future diffs will optimize this.
output = torch.ops.llama.sdpa_with_kv_cache(
q,
k,
v,
self.kv_cache.k_cache,
self.kv_cache.v_cache,
k_cache,
v_cache,
input_pos[-1].item(),
seqlen,
None, # Attention mask
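For context on how SDPACustom and QuantizedKVCache fit together, the following sketch mirrors the unit test added below: a float KVCache is converted with QuantizedKVCache.from_float and wrapped by SDPACustom, which dequantizes through the cache's update path. It assumes the llama custom ops library is already loaded (the test's preload_deps handle this under Buck); shapes are illustrative.

import torch

from executorch.examples.models.llama2.llama_transformer import KVCache
from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
    QuantizedCacheType,
    QuantizedKVCache,
)
from executorch.examples.models.llama2.source_transformation.sdpa import SDPACustom

n_kv_heads, n_heads, head_dim, seq_len = 4, 8, 17, 3
float_cache = KVCache(1, 5, n_kv_heads, head_dim, False, False, dtype=torch.float32)
quantized_cache = QuantizedKVCache.from_float(float_cache, QuantizedCacheType.AffineAsymmetric)
sdpa = SDPACustom(quantized_cache, dim=n_heads * head_dim)

q = torch.rand(1, seq_len, n_heads, head_dim)
k = torch.rand(1, seq_len, n_kv_heads, head_dim)
v = torch.rand(1, seq_len, n_kv_heads, head_dim)
out = sdpa(torch.tensor([0], dtype=torch.int64), q, k, v, 1, seq_len, None)
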
81 changes: 81 additions & 0 deletions examples/models/llama2/source_transformation/test_sdpa_with_quantized_kv_cache.py
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.examples.models.llama2.llama_transformer import KVCache

from executorch.examples.models.llama2.source_transformation.quantized_kv_cache import (
QuantizedCacheType,
QuantizedKVCache,
)

from executorch.examples.models.llama2.source_transformation.sdpa import SDPACustom


class SDPAWithQuantizedKVCacheTest(unittest.TestCase):

def _init_cache(self):
self.kv_cache = KVCache(
self.max_batch_size,
self.max_seq_len,
self.n_kv_heads,
self.head_dim,
False,
self.enable_dynamic_shape,
dtype=self.dtype,
)
self.quantized_kv_cache = QuantizedKVCache.from_float(
self.kv_cache, QuantizedCacheType.AffineAsymmetric
)

def _init_kv(self):
kv_shape = (1, self.seq_len, self.n_kv_heads, self.head_dim)
q_shape = (1, self.seq_len, self.n_heads, self.head_dim)
q = torch.rand(q_shape, dtype=self.dtype)
k = torch.rand(kv_shape, dtype=self.dtype)
v = torch.rand(kv_shape, dtype=self.dtype)
return q, k, v

def setUp(self):
torch.manual_seed(42)
self.max_batch_size = 1
self.max_seq_len = 5
self.n_kv_heads = 4
self.n_heads = 8
self.head_dim = 17
self.dim = self.n_heads * self.head_dim
self.enable_dynamic_shape = False
self.dtype = torch.float32

def test_simple(self, is_dynamic_shape=False):
self.enable_dynamic_shape = is_dynamic_shape
input_pos = torch.tensor([0], dtype=torch.int64)
self.seq_len = 3
self._init_cache()
q, k, v = self._init_kv()
self.float_sdpa = SDPACustom(self.kv_cache, self.dim)
self.quantized_sdpa = SDPACustom(self.quantized_kv_cache, self.dim)
float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
torch.testing.assert_close(
float_out,
quantized_out,
)

input_pos = torch.tensor([3], dtype=torch.int64)
self.seq_len = 1
q, k, v = self._init_kv()
float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
quantized_out = self.quantized_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
torch.testing.assert_close(
float_out,
quantized_out,
rtol=1e-03,
atol=1e-03,
)