diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md
index ba3d152c90..c3869f2761 100644
--- a/torchao/prototype/mx_formats/README.md
+++ b/torchao/prototype/mx_formats/README.md
@@ -70,7 +70,61 @@ quantize_(m, config)
 
 ## MX inference
 
-Coming soon!
+Below is a toy example of end-to-end mxfp8, mxfp4, and nvfp4 inference:
+
+```python
+import copy
+
+import torch
+import torch.nn as nn
+from torchao.quantization import quantize_
+from torchao.prototype.mx_formats.config import (
+    MXGemmKernelChoice,
+)
+from torchao.prototype.mx_formats.inference_workflow import (
+    MXFPInferenceConfig,
+    NVFP4InferenceConfig,
+    NVFP4MMConfig,
+)
+
+m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
+x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+
+# mxfp8: e4m3 activations and weights with e8m0 scales over blocks of 32,
+# using the cuBLAS mx gemm kernel
+m_mxfp8 = copy.deepcopy(m)
+config = MXFPInferenceConfig(
+    activation_dtype=torch.float8_e4m3fn,
+    weight_dtype=torch.float8_e4m3fn,
+    gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
+)
+quantize_(m_mxfp8, config=config)
+m_mxfp8 = torch.compile(m_mxfp8, fullgraph=True)
+y_mxfp8 = m_mxfp8(x)
+
+# mxfp4: e2m1 activations and weights with e8m0 scales over blocks of 32,
+# using the CUTLASS mx gemm kernel
+m_mxfp4 = copy.deepcopy(m)
+config = MXFPInferenceConfig(
+    activation_dtype=torch.float4_e2m1fn_x2,
+    weight_dtype=torch.float4_e2m1fn_x2,
+    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+)
+quantize_(m_mxfp4, config=config)
+m_mxfp4 = torch.compile(m_mxfp4, fullgraph=True)
+y_mxfp4 = m_mxfp4(x)
+
+# nvfp4: e2m1 activations and weights with e4m3 scales over blocks of 16,
+# plus a dynamically computed per-tensor fp32 scale
+m_nvfp4 = copy.deepcopy(m)
+config = NVFP4InferenceConfig(
+    mm_config=NVFP4MMConfig.DYNAMIC,
+    use_dynamic_per_tensor_scale=True,
+)
+quantize_(m_nvfp4, config=config)
+m_nvfp4 = torch.compile(m_nvfp4, fullgraph=True)
+y_nvfp4 = m_nvfp4(x)
+```
 
 ## MXTensor
 