Use case: MoE modules usually have E experts, each with 2 fused linears, and MoE kernels today assume all E*2 linear operations are quantized the same way. We need an easy way to check in torchao whether a model is quantized with all experts using the same recipe, so we know whether we can safely map to a fused MoE kernel.
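If configs were hashable, the check would collapse to collecting the E*2 configs into a set. A minimal sketch of the intended usage (the caller is assumed to have already gathered the per-linear configs; how they are gathered is out of scope here):

def all_same_recipe(configs) -> bool:
    # configs: the E*2 quantization configs of the expert linears.
    # A single unique element means every linear shares one recipe,
    # so the model can be safely mapped to a fused MoE kernel.
    return len(set(configs)) == 1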
Today, adding a config to a set does not work because a config is not hashable:
>>> ct
Float8DynamicActivationFloat8WeightConfig(activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, granularity=[PerTensor(), PerTensor()], mm_config=Float8MMConfig(emulate=False, use_fast_accum=True, pad_inner_dim=False), activation_value_lb=None, activation_value_ub=None, kernel_preference=<KernelPreference.AUTO: 'auto'>, set_inductor_config=True, version=2)
>>> s = set()
>>> s.add(ct)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: unhashable type: 'Float8DynamicActivationFloat8WeightConfig'
We should either make configs hashable (ideal), or provide a utility that checks equality of N configs so the user does not have to compare them N times themselves.
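As an interim workaround, equality alone is enough: dataclass-style configs define __eq__ by default, and eq=True without frozen=True is exactly what sets __hash__ to None and makes them unhashable. A minimal sketch of such a utility, assuming the configs compare correctly with ==:

def configs_all_equal(configs) -> bool:
    # Avoids hashing by comparing every config against the first one.
    # Relies only on __eq__, which dataclass-based configs define even
    # when they are unhashable. An empty input is trivially True.
    configs = list(configs)
    return all(c == configs[0] for c in configs[1:])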