I can't run inference with Llama-3.2-1B in mx_fp4 using torchao. The error log is below:
Traceback (most recent call last):
File "/home/pt-gpu/xw/infer.py", line 69, in <module>
y = model.generate(input_ids, generation_config, **generate_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/wengshiy/conda/envs/xw/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data/zhaoqion/transformers/src/transformers/generation/utils.py", line 2547, in generate
result = decoding_method(
^^^^^^^^^^^^^^^^
File "/data/zhaoqion/transformers/src/transformers/generation/utils.py", line 2774, in _sample
outputs = self(**model_inputs, return_dict=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/wengshiy/conda/envs/xw/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/wengshiy/conda/envs/xw/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/zhaoqion/transformers/src/transformers/utils/generic.py", line 940, in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/zhaoqion/transformers/src/transformers/models/llama/modeling_llama.py", line 473, in forward
logits = self.lm_head(hidden_states[:, slice_indices, :])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/wengshiy/conda/envs/xw/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/wengshiy/conda/envs/xw/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/wengshiy/conda/envs/xw/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 134, in forward
return F.linear(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/wengshiy/conda/envs/xw/lib/python3.12/site-packages/torchao/utils.py", line 649, in _dispatch__torch_dispatch__
raise NotImplementedError(
NotImplementedError: MXTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.expand', overload='default')>, types=(<class 'torchao.prototype.mx_formats.mx_tensor.MXTensor'>,), arg_types=(<class 'torchao.prototype.mx_formats.mx_tensor.MXTensor'>, <class 'list'>), kwarg_types={}
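For context, my reading of the failure (an assumption, not confirmed against the ATen/torchao sources): modeling_llama slices the hidden states before lm_head, so F.linear receives a non-contiguous 3D input; torch.matmul then seems to take the batched bmm path and expands the 2D weight, which would be how the quantized MXTensor weight ends up receiving aten.expand. A quick way to inspect which aten ops the call actually decomposes into, sketched with small plain tensors:

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LogOps(TorchDispatchMode):
    # Print every aten op seen below autograd, to check whether
    # aten.expand shows up on the weight path.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)
        return func(*args, **(kwargs or {}))

w = torch.randn(16, 8)               # stand-in lm_head weight, [out, in]
h = torch.randn(1, 4, 8)[:, -1:, :]  # non-contiguous 3D slice, as in modeling_llama
with LogOps():
    torch.nn.functional.linear(h, w)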
I also put together a minimal standalone reproducer:
import torch
import torchao.prototype.mx_formats
from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
from torchao.quantization import quantize_

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Same shape as the Llama-3.2-1B lm_head.
        self.m = torch.nn.Linear(2048, 128256, bias=False)

    def forward(self, x):
        return self.m(x)

model = ToyModel().eval().to(torch.bfloat16).to("cuda")

config = MXFPInferenceConfig(
    activation_dtype=torch.float4_e2m1fn_x2,
    weight_dtype=torch.float4_e2m1fn_x2,
    gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
)
quantize_(model, config=config)

# The slice mirrors hidden_states[:, slice_indices, :] in modeling_llama
# and produces a non-contiguous 3D input; this is what triggers the error.
x = torch.randn([1, 4, 2048], device="cuda", dtype=torch.bfloat16)[:, -1:, :]
y = model(x)
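As a possible workaround (untested, and only valid if my reading above is right): making the sliced input contiguous may let matmul fold the batch dims into a plain 2D mm and skip the expand on the weight.

# Hypothetical workaround: a contiguous input may avoid the
# batched-matmul decomposition that hits aten.expand on MXTensor.
y = model(x.contiguous())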
Environment:
torch   2.10.0.dev20250910+cu128
torchao 0.15.0.dev20251114+cu128