
there should be an easy way to check that N torchao configs are equivalent #3062

@vkuzo

Description


Use case: MoE modules usually have E experts, each with 2 fused linears, and today's MoE kernels assume that all E*2 linear operations are quantized the same way. Before mapping a model to a fused MoE kernel, we need an easy way in torchao to check that all of its experts were quantized with the same recipe.
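A minimal sketch of that check, using only `==`. It assumes the per-linear configs have already been gathered into a dict keyed by module name. The names `experts.{i}.w13` / `experts.{i}.w2`, the helper `can_use_fused_moe_kernel`, and the `FakeQuantConfig` stand-in are all hypothetical, not torchao API. The stand-in is a plain (non-frozen) dataclass, so it gets a field-by-field `__eq__` but cannot be hashed.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


# Hypothetical stand-in for a torchao config. As a plain (non-frozen)
# dataclass it gets a field-by-field __eq__ but no __hash__.
@dataclass
class FakeQuantConfig:
    weight_dtype: str = "float8_e4m3fn"
    granularity: List[str] = field(default_factory=lambda: ["PerTensor", "PerTensor"])


def can_use_fused_moe_kernel(
    configs: Dict[str, Optional[FakeQuantConfig]], num_experts: int
) -> bool:
    """Hypothetical check: True only if all num_experts * 2 linears share one config."""
    if num_experts <= 0:
        return False
    # Assumed layout: each expert owns two linears, a fused "w13" and a "w2".
    expected = [f"experts.{i}.{name}" for i in range(num_experts) for name in ("w13", "w2")]
    if any(configs.get(fqn) is None for fqn in expected):
        return False  # a linear is missing or was not quantized
    first = configs[expected[0]]
    # E*2 - 1 comparisons against the first config; == compares every field.
    return all(configs[fqn] == first for fqn in expected[1:])


num_experts = 4
configs = {
    f"experts.{i}.{name}": FakeQuantConfig()
    for i in range(num_experts)
    for name in ("w13", "w2")
}
print(can_use_fused_moe_kernel(configs, num_experts))  # True

# One expert quantized with a different granularity breaks the invariant.
configs["experts.2.w2"] = FakeQuantConfig(granularity=["PerRow", "PerRow"])
print(can_use_fused_moe_kernel(configs, num_experts))  # False
```

Because the configs are unhashable, the check has to fall back to pairwise comparisons against one reference config instead of a single `set`.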

Today, adding a config to a set does not work because a config is not hashable:

>>> ct
Float8DynamicActivationFloat8WeightConfig(activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, granularity=[PerTensor(), PerTensor()], mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), activation_value_lb=None, activation_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>, set_inductor_config=True, version=2)
>>> s = set()
>>> s.add(ct)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: unhashable type: 'Float8DynamicActivationFloat8WeightConfig'
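The repr above has the shape of a standard dataclass repr. Assuming the config is a regular `@dataclass`, the error is expected: with the defaults `eq=True, frozen=False`, the dataclass decorator sets `__hash__` to `None`. Passing `frozen=True` would not be enough by itself, because the generated `__hash__` hashes the tuple of field values, and a list-valued field such as `granularity=[PerTensor(), PerTensor()]` is unhashable. A minimal demonstration with stand-in classes (hypothetical, not the real torchao config):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass  # eq=True, frozen=False (the defaults): __hash__ is set to None
class MutableConfig:
    granularity: List[str] = field(default_factory=lambda: ["PerTensor", "PerTensor"])


@dataclass(frozen=True)  # frozen: __hash__ is generated from the field values
class FrozenListConfig:
    granularity: List[str] = field(default_factory=lambda: ["PerTensor", "PerTensor"])


@dataclass(frozen=True)  # frozen with only hashable field values: hash() works
class FrozenTupleConfig:
    granularity: tuple = ("PerTensor", "PerTensor")


def is_hashable(obj) -> bool:
    """Return True if hash(obj) succeeds."""
    try:
        hash(obj)
        return True
    except TypeError:
        return False


print(MutableConfig.__hash__ is None)  # True
print(is_hashable(MutableConfig()))  # False
print(is_hashable(FrozenListConfig()))  # False: the list field is unhashable
print(is_hashable(FrozenTupleConfig()))  # True
print(len({FrozenTupleConfig(), FrozenTupleConfig()}))  # 1
```

So making configs hashable would likely also require storing list-valued fields as tuples, or normalizing them on construction.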

We should either make configs hashable (ideal) or provide a utility that checks whether N configs are equal, so that users do not have to compare them one by one themselves.
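One way to get most of the benefit of hashability without changing the config classes is to derive a hashable key from each config. The sketch below assumes configs are (possibly nested) dataclasses whose leaf values are hashable; `config_key`, `all_configs_equal`, and the `Fake*` stand-ins are hypothetical names, not existing torchao functions. Once each config maps to a key, the set-based check works directly, and the number of distinct keys tells you how many recipes are in use.

```python
import enum
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any, Hashable, Iterable, List


def config_key(obj: Any) -> Hashable:
    """Hypothetical helper: recursively turn a config into a hashable key.

    Dataclass instances become (type, ((field_name, key), ...)) using only the
    fields that take part in __eq__; lists and tuples become tuples. Any other
    leaf (str, enum, torch.dtype, ...) is assumed to be hashable already.
    """
    if is_dataclass(obj) and not isinstance(obj, type):
        return (
            type(obj),
            tuple(
                (f.name, config_key(getattr(obj, f.name)))
                for f in fields(obj)
                if f.compare
            ),
        )
    if isinstance(obj, (list, tuple)):
        return tuple(config_key(x) for x in obj)
    return obj


def all_configs_equal(configs: Iterable[Any]) -> bool:
    """Hypothetical utility: True if every config yields the same key."""
    return len({config_key(c) for c in configs}) <= 1


# Hypothetical stand-ins loosely mirroring the config in the traceback above.
class FakeKernelPreference(enum.Enum):
    AUTO = "auto"


@dataclass
class FakeConfig:
    granularity: List[str] = field(default_factory=lambda: ["PerTensor", "PerTensor"])
    kernel_preference: FakeKernelPreference = FakeKernelPreference.AUTO


configs = [FakeConfig() for _ in range(8)]  # e.g. E=4 experts x 2 linears
print(all_configs_equal(configs))  # True

configs.append(FakeConfig(granularity=["PerRow", "PerRow"]))
print(all_configs_equal(configs))  # False
print(len({config_key(c) for c in configs}))  # 2 distinct recipes
```

Because only `compare=True` fields are included and the class itself is part of the key, the key should agree with the dataclass `__eq__` as long as the leaves are hashable. One limitation of this simple version: a tuple subclass such as a `NamedTuple` is flattened to a plain tuple, so two different `NamedTuple` types with identical values would produce the same key.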
