diff --git a/torchao/prototype/spinquant/hadamard_utils.py b/torchao/prototype/spinquant/hadamard_utils.py
index f3ed3b4290..1a88664c79 100644
--- a/torchao/prototype/spinquant/hadamard_utils.py
+++ b/torchao/prototype/spinquant/hadamard_utils.py
@@ -237,7 +237,7 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
         assert is_pow2(had_dim), "Hadamard dimension must be a power of 2!"
 
     W = module.weight.data
-    if module.bias is not None:
+    if output and module.bias is not None:
         B = module.bias.data
         bias_dtype_orig = B.dtype
         B = B.float()
@@ -248,12 +248,12 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
         if output:
             had_K, K = get_hadK(out_features)
             W = matmul_hadU(W.t(), had_K.to(W.device), K).t()
-            if module.bias is not None:
+            if output and module.bias is not None:
                 B = matmul_hadU(B, had_K.to(B.device), K)
         else:
             had_K, K = get_hadK(in_features)
             W = matmul_hadU(W, had_K.to(W.device), K)
-            if module.bias is not None:
+            if output and module.bias is not None:
                 B = matmul_hadU(B, had_K.to(B.device), K)
     else:
         if R2 is not None:
@@ -268,7 +268,7 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
         temp = W.reshape(-1, shape[-1] // had_dim, had_dim)
         temp = temp.to(torch.float64) @ hadK
         W = temp.reshape(shape)
-        if module.bias is not None:
+        if output and module.bias is not None:
             shape = B.shape
             temp = B.reshape(-1, had_dim)
             temp = temp.to(torch.float64) @ hadK
@@ -278,5 +278,5 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
         W = W.t()
 
     module.weight.data = W.to(dtype=dtype_orig)
-    if module.bias is not None:
+    if output and module.bias is not None:
         module.bias.data = B.to(dtype=bias_dtype_orig)
diff --git a/torchao/prototype/tests/test_spinquant.py b/torchao/prototype/tests/test_spinquant.py
new file mode 100644
index 0000000000..f9dce4d9d6
--- /dev/null
+++ b/torchao/prototype/tests/test_spinquant.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+import torch.nn as nn
+
+from torchao.prototype.spinquant.hadamard_utils import apply_exact_had_to_linear
+
+
+class TestSpinQuant(unittest.TestCase):
+    def test_rotate_in_and_out(self):
+        """Apply a rotation to the output of one linear layer and the inverse rotation to the input of the next; the composed output should be unchanged."""
+        with torch.no_grad():
+            layer1 = nn.Linear(256, 256, bias=True)
+            layer2 = nn.Linear(256, 256, bias=True)
+            model = nn.Sequential(layer1, layer2)
+            x = torch.rand(256)
+            output = model(x)
+            apply_exact_had_to_linear(layer1, output=True)
+            apply_exact_had_to_linear(layer2, output=False)
+            new_output = model(x)
+            torch.testing.assert_close(output, new_output)
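
For context, the invariance the new test exercises: if H is an orthonormal Hadamard matrix, rotating layer 1's output space multiplies both its weight and its bias by H on the left, while rotating layer 2's input space multiplies only its weight by H^T on the right; layer 2's bias is added after the matmul and must not be transformed, which is exactly what the `output and module.bias is not None` guard enforces. A minimal sketch of that identity, assuming a power-of-two dimension so a plain Sylvester construction can stand in for get_hadK/matmul_hadU (the sylvester_hadamard helper below is illustrative, not torchao API):

import torch

torch.manual_seed(0)
dt = torch.float64  # double precision keeps the closeness check comfortably tight


def sylvester_hadamard(n: int) -> torch.Tensor:
    # Sylvester construction; only valid when n is a power of two.
    H = torch.ones(1, 1, dtype=dt)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H / H.shape[0] ** 0.5  # normalized so that H @ H.T == I


n = 16
H = sylvester_hadamard(n)
W1, b1 = torch.randn(n, n, dtype=dt), torch.randn(n, dtype=dt)
W2, b2 = torch.randn(n, n, dtype=dt), torch.randn(n, dtype=dt)
x = torch.randn(n, dtype=dt)

baseline = W2 @ (W1 @ x + b1) + b2

# Rotate layer 1's output space: weight *and* bias pick up H.
W1r, b1r = H @ W1, H @ b1
# Rotate layer 2's input space: only the weight picks up H.T; b2 is added
# after the matmul, so transforming it would break the equivalence.
W2r = W2 @ H.T

rotated = W2r @ (W1r @ x + b1r) + b2
torch.testing.assert_close(baseline, rotated)  # H.T cancels H, so the outputs match

Calling apply_exact_had_to_linear with output=True on layer1 and output=False on layer2, as the new test does, performs exactly this pair of rotations.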