diff --git a/examples/models/llama/source_transformation/pre_quantization.py b/examples/models/llama/source_transformation/pre_quantization.py index b6540b7f3ee..d284512e712 100644 --- a/examples/models/llama/source_transformation/pre_quantization.py +++ b/examples/models/llama/source_transformation/pre_quantization.py @@ -44,7 +44,7 @@ def replacement_fn(child: torch.nn.Module) -> torch.nn.Module: # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[Module, # Tensor]`. child.out_features, - bias=False, + bias=child.bias is not None, device=child.weight.device, groupsize=group_size, precision=precision,