From f0ac61cf2a34fa0bfbeeb0989800fc6a5b8f2aef Mon Sep 17 00:00:00 2001
From: Rohan Joshi
Date: Fri, 29 Aug 2025 17:55:00 -0700
Subject: [PATCH] SpinQuant rotate bias

Summary: Added bias rotation. This is needed to apply SpinQuant R2 to models which have bias such as Qwen models.

Differential Revision: D81352249
---
 torchao/prototype/spinquant/hadamard_utils.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/torchao/prototype/spinquant/hadamard_utils.py b/torchao/prototype/spinquant/hadamard_utils.py
index 0b276a0d03..f3ed3b4290 100644
--- a/torchao/prototype/spinquant/hadamard_utils.py
+++ b/torchao/prototype/spinquant/hadamard_utils.py
@@ -237,6 +237,10 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
         assert is_pow2(had_dim), "Hadamard dimension must be a power of 2!"
 
     W = module.weight.data
+    if module.bias is not None:
+        B = module.bias.data
+        bias_dtype_orig = B.dtype
+        B = B.float()
     dtype_orig = W.dtype
     W = W.float()
 
@@ -244,9 +248,13 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
         if output:
             had_K, K = get_hadK(out_features)
             W = matmul_hadU(W.t(), had_K.to(W.device), K).t()
+            if module.bias is not None:
+                B = matmul_hadU(B, had_K.to(B.device), K)
         else:
             had_K, K = get_hadK(in_features)
             W = matmul_hadU(W, had_K.to(W.device), K)
+            if module.bias is not None:
+                B = matmul_hadU(B, had_K.to(B.device), K)
     else:
         if R2 is not None:
             hadK = R2.to(torch.float64)
@@ -260,8 +268,15 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
         temp = W.reshape(-1, shape[-1] // had_dim, had_dim)
         temp = temp.to(torch.float64) @ hadK
         W = temp.reshape(shape)
+        if module.bias is not None:
+            shape = B.shape
+            temp = B.reshape(-1, had_dim)
+            temp = temp.to(torch.float64) @ hadK
+            B = temp.reshape(shape)
     if output:
         W = W.t()
 
     module.weight.data = W.to(dtype=dtype_orig)
+    if module.bias is not None:
+        module.bias.data = B.to(dtype=bias_dtype_orig)
 
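
The sketch below is not part of the patch; it is a minimal sanity check of the identity
the output-side bias rotation relies on: for an orthogonal Hadamard matrix H, rotating a
linear layer's output means H(Wx + b) = (HW)x + (Hb), so the bias must be rotated together
with the weight rows (the case relevant to Qwen-style projections that carry a bias). It
builds its own normalized Sylvester Hadamard matrix rather than calling torchao's
get_hadK/matmul_hadU, so the helper sylvester_hadamard, the layer sizes, and the
tolerances are illustrative assumptions, not torchao API.

    import torch

    def sylvester_hadamard(n: int) -> torch.Tensor:
        # Normalized Sylvester Hadamard matrix for power-of-two n (H @ H.T == I).
        assert n > 0 and (n & (n - 1)) == 0, "n must be a power of 2"
        H = torch.ones(1, 1)
        while H.shape[0] < n:
            H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
        return H / (n ** 0.5)

    torch.manual_seed(0)
    n = 64
    layer = torch.nn.Linear(n, n, bias=True)
    x = torch.randn(8, n)
    H = sylvester_hadamard(n)

    # Output-side rotation: both the weight rows and the bias are multiplied by H,
    # mirroring what the patched output=True branch does via matmul_hadU.
    rotated = torch.nn.Linear(n, n, bias=True)
    with torch.no_grad():
        rotated.weight.copy_(H @ layer.weight)
        rotated.bias.copy_(H @ layer.bias)

    # The rotated layer reproduces the Hadamard-rotated output of the original layer.
    torch.testing.assert_close(rotated(x), layer(x) @ H.T, atol=1e-5, rtol=1e-5)
    print("output-side bias rotation: OK")

In apply_exact_had_to_linear, the added B = matmul_hadU(B, had_K.to(B.device), K) line plays
the role of H @ layer.bias above, and the final module.bias.data assignment restores the bias
to its original dtype the same way the existing weight path does.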