diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index 46d4ca5426..33a3532683 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -99,6 +99,7 @@ class Float8Tensor(TorchAOBaseTensor): "act_quant_kwargs", "kernel_preference", "dtype", + "test_only_new_attr", ] def __new__( @@ -110,6 +111,7 @@ def __new__( act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, kernel_preference: KernelPreference = KernelPreference.AUTO, dtype: Optional[torch.dtype] = None, + test_only_new_attr: Optional[int] = None, ): shape = qdata.shape kwargs = {} @@ -127,6 +129,7 @@ def __init__( act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, kernel_preference: KernelPreference = KernelPreference.AUTO, dtype: Optional[torch.dtype] = None, + test_only_new_attr: Optional[int] = None, ): super().__init__() self.qdata = qdata @@ -135,6 +138,7 @@ def __init__( self.mm_config = mm_config self.act_quant_kwargs = act_quant_kwargs self.kernel_preference = kernel_preference + self.test_only_new_attr = test_only_new_attr def __repr__(self): return (