diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py
index 1b9c369be5..cc6e34a708 100644
--- a/torchao/prototype/mx_formats/inference_workflow.py
+++ b/torchao/prototype/mx_formats/inference_workflow.py
@@ -24,6 +24,7 @@
     QuantizeTensorToNVFP4Kwargs,
     per_tensor_amax_to_scale,
 )
+from torchao.quantization.quant_api import _quantization_type
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
@@ -89,7 +90,7 @@ def __post_init__(self):
 
 
 def _linear_extra_repr(self):
-    return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={repr(self.weight)}"
+    return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"
 
 
 @register_quantize_module_handler(MXFPInferenceConfig)
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 8532337477..0235df1b5a 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -544,6 +544,9 @@ def __repr__(self):
         # TODO better elem dtype print for fp4
         return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self._scale_e8m0}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}"  # noqa: E501
 
+    def _quantization_type(self):
+        return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}"
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         # avoid circular dependency
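
For reference, a minimal sketch of the dispatch pattern this change relies on: the `_quantization_type` helper imported from `torchao.quantization.quant_api` prefers a tensor subclass's own `_quantization_type()` hook (which `MXTensor` now provides) over printing the full tensor data. The exact body of the real helper may differ; `_quantization_type_sketch` below is a hypothetical stand-in, not the torchao implementation.

```python
import torch


def _quantization_type_sketch(weight: torch.Tensor) -> str:
    # Tensor subclasses that expose a _quantization_type() method
    # (e.g. MXTensor after this diff) return a compact one-line
    # summary of their quantization metadata.
    if hasattr(weight, "_quantization_type"):
        return f"{weight.__class__.__name__}({weight._quantization_type()})"
    # A plain torch.Tensor carries no quantization metadata.
    if type(weight) is torch.Tensor:
        return "not quantized"
    return "not recognized"
```

The net effect is that `extra_repr` of a quantized `Linear` renders a short metadata summary (element dtype, block size, original dtype, gemm kernel choice, activation-quant kwargs) instead of dumping the raw quantized weight tensor via `repr`.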