From 48dc7cea55265e3044d179fc9f09fae5d3793c51 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Wed, 20 Nov 2024 15:54:36 -0800
Subject: [PATCH] [Executorch] Add quantized kv cache to oss ci

Fixes to make sure the quantized KV cache works in OSS.

Differential Revision: [D66269487](https://our.internmc.facebook.com/intern/diff/D66269487/)

[ghstack-poisoned]
---
 .../source_transformation/quantized_kv_cache.py      |  2 ++
 examples/models/llama/source_transformation/sdpa.py  |  2 +-
 exir/passes/_quant_patterns_and_replacements.py      | 13 +++++++++++++
 3 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama/source_transformation/quantized_kv_cache.py b/examples/models/llama/source_transformation/quantized_kv_cache.py
index 26567f3d52c..fae2f124c92 100644
--- a/examples/models/llama/source_transformation/quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/quantized_kv_cache.py
@@ -7,6 +7,8 @@
 import logging
 from enum import Enum
 
+import executorch.extension.llm.custom_ops  # noqa: F401
+
 import torch
 import torch.nn as nn
 from executorch.examples.models.llama.llama_transformer import KVCache
diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index 44541f6eaac..53b3a8f62ba 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -56,7 +56,7 @@ def forward(
         k_cache = self.kv_cache.k_cache
         v_cache = self.kv_cache.v_cache
 
-        if isinstance(self.kv_cache, QuantizedKVCache):
+        if hasattr(self.kv_cache, "quantized_cache_dtype"):
             # updated quantize cache, scale and zero points
             # returns dequantized kv cache
             # Not most optimal. Optimizations to follow next
diff --git a/exir/passes/_quant_patterns_and_replacements.py b/exir/passes/_quant_patterns_and_replacements.py
index 529b22b1d06..f718af3c7dd 100644
--- a/exir/passes/_quant_patterns_and_replacements.py
+++ b/exir/passes/_quant_patterns_and_replacements.py
@@ -192,6 +192,19 @@ def embedding_byte_dtype_out_meta(
     "int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
 )
 
+# TODO: move these registrations to pytorch core
+quantized_decomposed_lib.define(
+    "quantize_per_token.out(Tensor input, Tensor scales, Tensor zero_points, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+quantized_decomposed_lib.define(
+    "dequantize_per_token.out(Tensor input, Tensor scales, Tensor zero_points, int quant_min, int quant_max, ScalarType dtype, ScalarType output_dtype, *, Tensor(a!) out) -> Tensor(a!)",
+)
+
+quantized_decomposed_lib.define(
+    "choose_qparams_per_token_asymmetric.out(Tensor input, ScalarType dtype, *, Tensor(a!) scale_out, Tensor(b!) zero_point_out) -> (Tensor(a!), Tensor(b!))",
+)
+
 
 @impl(quantized_decomposed_lib, "embedding_2bit", "CompositeExplicitAutograd")
 def embedding_2bit(
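
Note on the `.out` registrations above: they follow the standard `torch.library` pattern of declaring an out-variant schema on an operator library and then attaching a kernel via `@impl`. Below is a minimal, self-contained sketch of that pattern; the `demo_quant` namespace and the `scale_mul` op are hypothetical stand-ins invented for illustration, not part of this patch (the real registrations extend the existing `quantized_decomposed` library, and their kernels live elsewhere).

```python
import torch
from torch.library import Library, impl

# Fresh hypothetical namespace for this sketch; the patch instead extends
# the existing "quantized_decomposed" library object.
demo_lib = Library("demo_quant", "DEF")

# Declare an out-variant schema: the `Tensor(a!) out` annotation marks
# `out` as a mutated, preallocated output buffer.
demo_lib.define(
    "scale_mul.out(Tensor input, Tensor scales, *, Tensor(a!) out) -> Tensor(a!)"
)


@impl(demo_lib, "scale_mul.out", "CompositeExplicitAutograd")
def scale_mul_out(input, scales, *, out):
    # Out-variant contract: write the result into `out` and return it.
    out.copy_(input * scales)
    return out


x = torch.randn(4)
s = torch.full((4,), 2.0)
buf = torch.empty(4)
torch.ops.demo_quant.scale_mul.out(x, s, out=buf)
assert torch.equal(buf, x * s)
```

The `import executorch.extension.llm.custom_ops  # noqa: F401` added in the first hunk relies on the same mechanism: the import runs purely for its op-registration side effects, which is why it is kept despite being "unused". Likewise, the `hasattr(self.kv_cache, "quantized_cache_dtype")` check in sdpa.py duck-types the cache instead of checking the concrete class, presumably so detection keeps working without a hard dependency on the `QuantizedKVCache` type at that point.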