Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion torchao/prototype/mx_formats/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,59 @@ quantize_(m, config)

## MX inference

Coming soon!
```python
import copy

import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.prototype.mx_formats.config import (
MXGemmKernelChoice,
)
from torchao.prototype.mx_formats.inference_workflow import (
MXFPInferenceConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)

m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)

# mxfp8

m_mxfp8 = copy.deepcopy(m)
config = MXFPInferenceConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
gemm_kernel_choice=MXGemmKernelChoice.CUBLAS,
)
quantize_(m_mxfp8, config=config)
m_mxfp8 = torch.compile(m_mxfp8, fullgraph=True)
y_mxfp8 = m_mxfp8(x)

# mxfp4

m_mxfp4 = copy.deepcopy(m)
config = MXFPInferenceConfig(
activation_dtype=torch.float4_e2m1fn_x2,
weight_dtype=torch.float4_e2m1fn_x2,
gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
)
quantize_(m_mxfp4, config=config)
m_mxfp4 = torch.compile(m_mxfp4, fullgraph=True)
y_mxfp4 = m_mxfp4(x)

# nvfp4

m_nvfp4 = copy.deepcopy(m)
config = NVFP4InferenceConfig(
mm_config=NVFP4MMConfig.DYNAMIC,
use_dynamic_per_tensor_scale=True,
)
quantize_(m_nvfp4, config=config)
m_nvfp4 = torch.compile(m_nvfp4, fullgraph=True)
y_nvfp4 = m_nvfp4(x)
```

## MXTensor

Expand Down
Loading