From fabb11d4f9fe4d867473bd4dcd6c7c09dcd78e4b Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 9 Oct 2025 17:17:45 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_inference_workflow.py | 10 ++++++++++ torchao/prototype/mx_formats/inference_workflow.py | 1 + 2 files changed, 11 insertions(+) diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index 7cc29c7ee1..4fbe44bd2c 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -191,6 +191,16 @@ def test_inference_workflow_nvfp4( f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}" ) + # serialization + with tempfile.NamedTemporaryFile() as f: + torch.save(m_mx.state_dict(), f) + f.seek(0) + + # temporary workaround for https://github.com/pytorch/ao/issues/3077 + torch.serialization.add_safe_globals([getattr]) + + _ = torch.load(f, weights_only=True) + class VLLMIntegrationTestCase(TorchAOIntegrationTestCase): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index d0f1b04119..febaa1750f 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -211,6 +211,7 @@ def _nvfp4_inference_linear_transform( NVFP4MMConfig, MXGemmKernelChoice, QuantizeTensorToMXKwargs, + QuantizeTensorToNVFP4Kwargs, ScaleCalculationMode, ] )