Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 0 additions & 24 deletions torchao/quantization/quantize_/workflows/float8/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ class Float8Tensor(TorchAOBaseTensor):
sharing the same set of quantization parameters (scale), have the same rank as qdata or
is an empty list (representing per tensor quantization)
mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
hp_value_lb (Optional[float]): the lower bound on the high-precision floating-point value used when calculating the scale
hp_value_ub (Optional[float]): the upper bound on the high-precision floating-point value used when calculating the scale
act_quant_kwargs (QuantizeTensorToFloat8Kwargs): the kwargs for Float8Tensor.from_hp
kernel_preference (KernelPreference): the preferred kernel to use for quantize, mm, etc.;
by default, this is chosen for the user based on hardware, library availability, etc.
Expand All @@ -98,8 +96,6 @@ class Float8Tensor(TorchAOBaseTensor):
optional_tensor_attribute_names = [
"block_size",
"mm_config",
"hp_value_lb",
"hp_value_ub",
"act_quant_kwargs",
"kernel_preference",
"dtype",
Expand All @@ -111,8 +107,6 @@ def __new__(
scale: torch.Tensor,
block_size: Optional[List[int]] = None,
mm_config: Optional[Float8MMConfig] = None,
hp_value_lb: Optional[float] = None,
hp_value_ub: Optional[float] = None,
act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
kernel_preference: KernelPreference = KernelPreference.AUTO,
dtype: Optional[torch.dtype] = None,
Expand All @@ -130,8 +124,6 @@ def __init__(
scale: torch.Tensor,
block_size: Optional[List[int]] = None,
mm_config: Optional[Float8MMConfig] = None,
hp_value_lb: Optional[float] = None,
hp_value_ub: Optional[float] = None,
act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
kernel_preference: KernelPreference = KernelPreference.AUTO,
dtype: Optional[torch.dtype] = None,
Expand All @@ -141,8 +133,6 @@ def __init__(
self.scale = scale
self.block_size = block_size
self.mm_config = mm_config
self.hp_value_lb = hp_value_lb
self.hp_value_ub = hp_value_ub
self.act_quant_kwargs = act_quant_kwargs
self.kernel_preference = kernel_preference

Expand Down Expand Up @@ -248,8 +238,6 @@ def from_hp(
scale,
block_size=block_size,
mm_config=mm_config,
hp_value_lb=hp_value_lb,
hp_value_ub=hp_value_ub,
act_quant_kwargs=act_quant_kwargs,
kernel_preference=kernel_preference,
dtype=hp_dtype,
Expand Down Expand Up @@ -472,8 +460,6 @@ def _(func, types, args, kwargs):
sliced_scale,
block_size,
self.mm_config,
self.hp_value_lb,
self.hp_value_ub,
self.act_quant_kwargs,
self.kernel_preference,
dtype=self.dtype,
Expand Down Expand Up @@ -503,8 +489,6 @@ def _(func, types, args, kwargs):
assert tensor_0.scale.ndim == tensors[i].scale.ndim
assert tensor_0.block_size == tensors[i].block_size
assert tensor_0.mm_config == tensors[i].mm_config
assert tensor_0.hp_value_lb == tensors[i].hp_value_lb
assert tensor_0.hp_value_ub == tensors[i].hp_value_ub
assert tensor_0.act_quant_kwargs == tensors[i].act_quant_kwargs
assert tensor_0.kernel_preference == tensors[i].kernel_preference

Expand All @@ -528,8 +512,6 @@ def _(func, types, args, kwargs):
cat_scale,
block_size,
tensor_0.mm_config,
tensor_0.hp_value_lb,
tensor_0.hp_value_ub,
tensor_0.act_quant_kwargs,
tensor_0.kernel_preference,
tensor_0.dtype,
Expand All @@ -551,8 +533,6 @@ def _(func, types, args, kwargs):
scale,
block_size,
self.mm_config,
self.hp_value_lb,
self.hp_value_ub,
self.act_quant_kwargs,
self.kernel_preference,
self.dtype,
Expand Down Expand Up @@ -603,8 +583,6 @@ def _(func, types, args, kwargs):
scale,
block_size,
self.mm_config,
self.hp_value_lb,
self.hp_value_ub,
self.act_quant_kwargs,
self.kernel_preference,
self.dtype,
Expand All @@ -627,8 +605,6 @@ def _(func, types, args, kwargs):
scale,
block_size,
self.mm_config,
self.hp_value_lb,
self.hp_value_ub,
self.act_quant_kwargs,
self.kernel_preference,
self.dtype,
Expand Down
Loading