diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index 7cc29c7ee1..4fbe44bd2c 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -191,6 +191,16 @@ def test_inference_workflow_nvfp4( f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}" ) + # serialization + with tempfile.NamedTemporaryFile() as f: + torch.save(m_mx.state_dict(), f) + f.seek(0) + + # temporary workaround for https://github.com/pytorch/ao/issues/3077 + torch.serialization.add_safe_globals([getattr]) + + _ = torch.load(f, weights_only=True) + class VLLMIntegrationTestCase(TorchAOIntegrationTestCase): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index d0f1b04119..febaa1750f 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -211,6 +211,7 @@ def _nvfp4_inference_linear_transform( NVFP4MMConfig, MXGemmKernelChoice, QuantizeTensorToMXKwargs, + QuantizeTensorToNVFP4Kwargs, ScaleCalculationMode, ] )