diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py index 8eb6ddde11..379d5f7e76 100644 --- a/benchmarks/benchmark_aq.py +++ b/benchmarks/benchmark_aq.py @@ -16,30 +16,6 @@ _replace_with_custom_fn_if_matches_filter, quantize_, ) -from torchao.quantization.subclass import ( - Int4WeightOnlyQuantizedLinearWeight, - Int8WeightOnlyQuantizedLinearWeight, -) - - -def _int8wo_api(mod, **kwargs): - quantize_(mod, Int8WeightOnlyConfig(**kwargs), set_inductor_config=False) - - -def _int8da_int8w_api(mod, **kwargs): - quantize_( - mod, - Int8DynamicActivationInt8WeightConfig(**kwargs), - set_inductor_config=False, - ) - - -def _int4wo_api(mod, **kwargs): - kwargs_copy = kwargs.copy() - if "groupsize" in kwargs_copy: - kwargs_copy["group_size"] = kwargs_copy["groupsize"] - del kwargs_copy["groupsize"] - quantize_(mod, Int4WeightOnlyConfig(**kwargs_copy), set_inductor_config=False) class ToyLinearModel(torch.nn.Module): @@ -117,38 +93,18 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs): return _ref_change_linear_weights_to_woqtensors -_ref_change_linear_weights_to_int8_woqtensors = ( - _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) -) -_ref_change_linear_weights_to_int4_woqtensors = ( - _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) -) - - torch._dynamo.config.cache_size_limit = 50000 @torch.no_grad -def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): - if kwargs is None: - kwargs = {} - +def _bench_quantized_tensor_subclass_perf(api, config, M, N, K): m = ToyLinearModel( M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda" ).eval() m_bf16 = copy.deepcopy(m) - m_ref = copy.deepcopy(m) example_inputs = m.example_inputs() - api(m, **kwargs) - - # reference - ref_api(m_ref, **kwargs) - - res = m(*example_inputs) - ref = m_ref(*example_inputs) - - assert torch.equal(res, ref) + api(m, config) # Pass both model and config # perf comparison from torchao.utils import benchmark_model @@ -158,22 +114,17 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): RUNS = 100 torch._dynamo.reset() - m_ref = torch.compile(m_ref, mode="max-autotune", fullgraph=True) - benchmark_model(m_ref, WARMUP, example_inputs) - ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs) + m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True) + benchmark_model(m_bf16, WARMUP, example_inputs) + bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs) torch._dynamo.reset() m = torch.compile(m, mode="max-autotune", fullgraph=True) benchmark_model(m, WARMUP, example_inputs) elapsed_time = benchmark_model(m, RUNS, example_inputs) - torch._dynamo.reset() - m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True) - benchmark_model(m_bf16, WARMUP, example_inputs) - bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs) - print( - f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}" + f"{(M, N, K)}: elapsed time: {elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}" ) @@ -182,24 +133,32 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None): (20, 2048, 2048), ] - print("_int8da_int8w_api") - + print("Int8DynamicActivationInt8WeightConfig") for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( - _int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K + quantize_, + 
Int8DynamicActivationInt8WeightConfig(), + M, + N, + K, ) - print("_int8wo_api") - + print("Int8WeightOnlyConfig") for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( - _int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K + quantize_, + Int8WeightOnlyConfig(), + M, + N, + K, ) - print("_int4wo_api") - kwargs = {"groupsize": 32, "version": 1} - + print("Int4WeightOnlyConfig") for M, N, K in all_shapes: _bench_quantized_tensor_subclass_perf( - _int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs + quantize_, + Int4WeightOnlyConfig(group_size=32), + M, + N, + K, ) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index f99cf4a1b4..bcb6347f0c 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -48,7 +48,9 @@ ) from torchao.quantization.quant_primitives import ( MappingType, + choose_qparams_affine, dequantize_affine, + quantize_affine, ) from torchao.quantization.smoothquant import ( SmoothFakeDynamicallyQuantizedLinear, @@ -56,20 +58,13 @@ smooth_fq_linear_to_inference, swap_linear_with_smooth_fq_linear, ) -from torchao.quantization.subclass import ( - Int4WeightOnlyQuantizedLinearWeight, - Int8DynamicallyQuantizedLinearWeight, - Int8WeightOnlyQuantizedLinearWeight, -) from torchao.quantization.utils import ( LoggingTensorMode, _apply_logging_hook, _fqn_to_op_to_shape_to_count, - _quant_int8_dynamic_per_token_linear, _quantize_activation_per_token_absmax, compute_error, dequantize_per_channel, - dynamically_quantize_per_channel, ) from torchao.quantization.utils import ( compute_error as SQNR, @@ -476,9 +471,27 @@ def _test_dynamic_quant_per_channel_numerics_impl( # torch.aminmax support half on cpu x = torch.randn(16, 32, device=device, dtype=float_dtype) - y_vals, y_scale, y_zero_point = dynamically_quantize_per_channel( - x, qmin, qmax, int_dtype + + eps = torch.finfo(torch.float32).eps + block_size = (1, x.shape[1]) + zero_point_dtype = torch.int64 + + mapping_type = MappingType.SYMMETRIC + scale, zero_point = choose_qparams_affine( + x, + mapping_type, + block_size, + target_dtype=int_dtype, + quant_min=qmin, + quant_max=qmax, + eps=eps, + zero_point_dtype=zero_point_dtype, + ) + y_vals = quantize_affine( + x, block_size, scale, zero_point, int_dtype, qmin, qmax ) + y_scale = scale + y_zero_point = zero_point min_val, max_val = torch.aminmax(x, dim=1) @@ -561,30 +574,6 @@ def test_quantize_per_token_xpu(self): for dtype in (torch.float32, torch.float16, torch.bfloat16): self._test_quantize_per_token_impl("xpu", dtype) - def _test_per_token_linear_impl(self, device, dtype): - x = torch.randn(2, 16, 8, device=device, dtype=dtype) - w = torch.randn(16, 8, device=device, dtype=dtype) - wq, w_scales, _w_zp = dynamically_quantize_per_channel(w, -127, 127, torch.int8) - # Note: need to make the weight contiguous because we are - # testing in eager mode and cuBlas will not give correct results - # for a transposed weight - y = _quant_int8_dynamic_per_token_linear( - x, wq.t().contiguous(), w_scales, None, dtype - ) - y_ref = torch.matmul(x, w.t()) - sqnr = compute_error(y_ref, y) - self.assertTrue(sqnr >= 42.0) - - def test_per_token_linear_cpu(self): - for dtype in (torch.float32,): - self._test_per_token_linear_impl("cpu", dtype) - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - @skip_if_rocm("ROCm enablement in progress") - def test_per_token_linear_cuda(self): - for dtype in (torch.float32, torch.float16, torch.bfloat16): - 
self._test_per_token_linear_impl("cuda", dtype) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test__int_mm(self): # TODO(future): figure out what here needs to move to PT core, @@ -681,62 +670,6 @@ def _test_dequantize_impl( f"{lin.weight.__class__.__name__} failed transpose on dtype={test_dtype}", ) - @parameterized.expand(COMMON_DEVICE_DTYPE) - def test_dequantize_int8_dynamic_quant_subclass(self, device, dtype): - self._test_dequantize_impl( - Int8DynamicallyQuantizedLinearWeight.from_float, - device, - 35, - test_dtype=dtype, - ) - - @parameterized.expand(COMMON_DEVICE_DTYPE) - def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype): - self._test_dequantize_impl( - Int8WeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype - ) - - @parameterized.expand(COMMON_DEVICE_DTYPE) - @skip_if_rocm("ROCm enablement in progress") - def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype): - if device == "cpu": - self.skipTest(f"Temporarily skipping for {device}") - if dtype != torch.bfloat16: - self.skipTest("Currently only supports bfloat16.") - for test_shape in [(16, 1024, 16)] + ( - [(1, 1024, 8)] if device == "cuda" else [] - ): - self._test_dequantize_impl( - Int4WeightOnlyQuantizedLinearWeight.from_float, - device, - 15, - test_shape=test_shape, - test_dtype=dtype, - ) - - @parameterized.expand(COMMON_DEVICE_DTYPE) - @skip_if_rocm("ROCm enablement in progress") - def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype): - if device == "cpu": - self.skipTest(f"Temporarily skipping for {device}") - if dtype != torch.bfloat16: - self.skipTest("Currently only supports bfloat16.") - m_shapes = [16, 256] + ([1] if device == "cuda" else []) - n_shapes = [16] + ([8, 13] if device == "cuda" else []) - for groupsize in [256, 128]: - for inner_k_tiles in [8, 4, 2]: - for m in m_shapes: - for n in n_shapes: - self._test_dequantize_impl( - lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float( - w, groupsize, inner_k_tiles - ), - device, - 15, - test_shape=[m, 256, n], - test_dtype=dtype, - ) - @run_supported_device_dtype def _test_lin_weight_subclass_impl( self, @@ -771,22 +704,6 @@ def _test_lin_weight_subclass_impl( f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}", ) - @parameterized.expand(COMMON_DEVICE_DTYPE) - def test_int8_dynamic_quant_subclass(self, device, dtype): - self._test_lin_weight_subclass_impl( - Int8DynamicallyQuantizedLinearWeight.from_float, - device, - 35, - test_dtype=dtype, - ) - - @parameterized.expand(COMMON_DEVICE_DTYPE) - def test_int8_weight_only_quant_subclass(self, device, dtype): - undo_recommended_configs() - self._test_lin_weight_subclass_impl( - Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype - ) - @parameterized.expand(COMMON_DEVICE_DTYPE) def test_aq_int8_dynamic_quant_subclass(self, device, dtype): self._test_lin_weight_subclass_impl( @@ -891,46 +808,6 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype test_dtype=dtype, ) - @parameterized.expand(COMMON_DEVICE_DTYPE) - @skip_if_rocm("ROCm enablement in progress") - def test_int4_weight_only_quant_subclass(self, device, dtype): - if device == "cpu": - self.skipTest(f"Temporarily skipping for {device}") - if dtype != torch.bfloat16: - self.skipTest(f"Fails for {dtype}") - for test_shape in [(16, 1024, 16)] + ( - [(1, 1024, 8)] if device == "cuda" else [] - ): - self._test_lin_weight_subclass_impl( - 
Int4WeightOnlyQuantizedLinearWeight.from_float, - device, - 10, - test_shape=test_shape, - test_dtype=dtype, - ) - - @parameterized.expand(COMMON_DEVICE_DTYPE) - @skip_if_rocm("ROCm enablement in progress") - @unittest.skip("Skip to fix CI until we deprecate these APIs long term") - def test_int4_weight_only_quant_subclass_grouped(self, device, dtype): - if dtype != torch.bfloat16: - self.skipTest(f"Fails for {dtype}") - m_shapes = [16, 256] + ([1] if device == "cuda" else []) - n_shapes = [16] + ([8, 13] if device == "cuda" else []) - for groupsize in [128, 64]: - for inner_k_tiles in [8, 4, 2]: - for m in m_shapes: - for n in n_shapes: - self._test_lin_weight_subclass_impl( - lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float( - w, groupsize, inner_k_tiles - ), - device, - 10, - test_shape=[m, 256, n], - test_dtype=dtype, - ) - @torch.no_grad() @run_supported_device_dtype def _test_lin_weight_subclass_api_impl( @@ -1120,7 +997,6 @@ def test_dynamic_quant(self): sqnr = compute_error(y_ref, y_test) self.assertGreater(sqnr, 40.0) - # self.assertTrue(isinstance(m[0], DynamicallyPerAxisQuantizedLinear)) class TestWeightOnlyInt8Quant(unittest.TestCase): @@ -1324,30 +1200,6 @@ def test_save_load_int4woqtensors(self, device, dtype): self._test_handle_save_load_meta_impl(_int4wo_api, device, 20, test_dtype=dtype) -class TorchCompileUnitTest(unittest.TestCase): - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_fullgraph(self): - lin_fp16 = nn.Linear(32, 16, device="cuda", dtype=torch.float16) - lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( - lin_fp16, alpha=0.25 - ) - - x0 = torch.randn(17, 1, 32, device="cuda", dtype=torch.float16) - - # calibrate - _ = lin_smooth(x0) - - # inference - lin_smooth.to_inference() - - # torch.compile - lin_smooth_opt = torch.compile(lin_smooth, fullgraph=True) - # print(lin_smooth_opt) - - lin_smooth_opt(x0) - # print(y) - - class UtilsUnitTest(unittest.TestCase): def test_shape_logger(self): x = torch.randn(4, 4) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index d7d2b4a5b4..577ca6789a 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -59,10 +59,6 @@ _replace_with_custom_fn_if_matches_filter, ) from torchao.quantization.quant_primitives import MappingType -from torchao.quantization.subclass import ( - Int4WeightOnlyQuantizedLinearWeight, - Int8WeightOnlyQuantizedLinearWeight, -) from torchao.quantization.utils import compute_error from torchao.testing.utils import skip_if_rocm from torchao.utils import ( @@ -167,14 +163,6 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs): return _ref_change_linear_weights_to_woqtensors -_ref_change_linear_weights_to_int8_woqtensors = ( - _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight) -) -_ref_change_linear_weights_to_int4_woqtensors = ( - _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight) -) - - class TestQuantFlow(TestCase): GPU_DEVICES = (["cuda"] if torch.cuda.is_available() else []) + ( ["xpu"] if torch.xpu.is_available() else [] @@ -446,54 +434,6 @@ def test_quantized_tensor_subclass_8da4w(self, mapping_type): ref = m_copy(*example_inputs) self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(len(GPU_DEVICES) == 0, "Need GPU available") - def test_quantized_tensor_subclass_int4(self): - for device in self.GPU_DEVICES: - # use 1024 so that we don't need padding - m = ToyLinearModel(1024, 
1024, 1024).eval().to(torch.bfloat16).to(device) - m_copy = copy.deepcopy(m) - example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device) - - group_size = 32 - if device == "xpu": - quantize_( - m, - Int4WeightOnlyConfig( - group_size=group_size, layout=Int4XPULayout(), version=1 - ), - ) - else: - quantize_(m, Int4WeightOnlyConfig(group_size=group_size, version=1)) - assert isinstance(m.linear1.weight, AffineQuantizedTensor) - assert isinstance(m.linear2.weight, AffineQuantizedTensor) - - # reference - _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size) - - res = m(*example_inputs) - ref = m_copy(*example_inputs) - - self.assertTrue(torch.equal(res, ref)) - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_quantized_tensor_subclass_int8_wo(self): - m = ToyLinearModel().eval().to(torch.bfloat16) - m_copy = copy.deepcopy(m) - example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs())) - - quantize_(m, Int8WeightOnlyConfig()) - - assert isinstance(m.linear1.weight, AffineQuantizedTensor) - assert isinstance(m.linear2.weight, AffineQuantizedTensor) - - # reference - _ref_change_linear_weights_to_int8_woqtensors(m_copy) - - res = m(*example_inputs) - ref = m_copy(*example_inputs) - - self.assertTrue(torch.equal(res, ref)) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_save_load(self): m = ToyLinearModel().eval().to(torch.bfloat16) diff --git a/torchao/prototype/quantization/autoquant_v2.py b/torchao/prototype/quantization/autoquant_v2.py index 1240bbacd0..1851b29208 100644 --- a/torchao/prototype/quantization/autoquant_v2.py +++ b/torchao/prototype/quantization/autoquant_v2.py @@ -40,11 +40,6 @@ MappingType, ZeroPointDomain, ) -from torchao.quantization.subclass import ( # noqa - Int8DynamicallyQuantizedLinearWeight, - Int8WeightOnlyQuantizedLinearWeight, - QuantizedLinearWeightBase, -) from torchao.quantization.utils import _quantize_activation_per_token_absmax from torchao.utils import ( TorchAOBaseTensor, @@ -80,7 +75,6 @@ def _is_linear(mod, *args): return ( isinstance(mod, torch.nn.Linear) and hasattr(mod, "weight") - and not isinstance(mod.weight, QuantizedLinearWeightBase) and not isinstance(mod.weight, AutoQuantizableLinearWeightV1) and not isinstance(mod.weight, AffineQuantizedTensor) and not isinstance(mod.weight, LinearActivationQuantizedTensor) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index c8774e9426..bdb1d90c04 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -100,21 +100,11 @@ IntxOpaqueTensor, IntxUnpackedToInt8Tensor, ) -from .smoothquant import ( - SmoothFakeDynamicallyQuantizedLinear, - SmoothFakeDynQuantMixin, - get_scale, - set_smooth_fq_attribute, - smooth_fq_linear_to_inference, - swap_linear_with_smooth_fq_linear, -) -from .subclass import * # noqa: F403 from .transform_module import register_quantize_module_handler from .unified import Quantizer, TwoStepQuantizer from .utils import ( compute_error, ) -from .weight_only import WeightOnlyInt8QuantLinear # TODO: remove after migration of APIs are done AOPerModuleConfig = ModuleFqnToConfig @@ -172,13 +162,6 @@ "Int4TilePackedTo4dTensor", "Float8Tensor", "Int4OpaqueTensor", - # smooth quant - subject to change - "get_scale", - "SmoothFakeDynQuantMixin", - "SmoothFakeDynamicallyQuantizedLinear", - "swap_linear_with_smooth_fq_linear", - "smooth_fq_linear_to_inference", - "set_smooth_fq_attribute", 
"compute_error", # building blocks "to_linear_activation_quantized", @@ -210,7 +193,6 @@ "Int4WeightOnlyQuantizer", "Int8DynActInt4WeightQuantizer", "Int8DynActInt4WeightLinear", - "WeightOnlyInt8QuantLinear", "TwoStepQuantizer", "Quantizer", # Layouts for quant_api diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index eb19a00923..be9d546c66 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -41,11 +41,6 @@ PerRow, PerTensor, ) -from .subclass import ( # noqa - Int8DynamicallyQuantizedLinearWeight, - Int8WeightOnlyQuantizedLinearWeight, - QuantizedLinearWeightBase, -) __all__ = [ "AutoQuantizableLinearWeight", diff --git a/torchao/quantization/dynamic_quant.py b/torchao/quantization/dynamic_quant.py deleted file mode 100644 index 5c6ee9c8f9..0000000000 --- a/torchao/quantization/dynamic_quant.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch -import torch.nn as nn - -from .utils import ( - _quant_int8_dynamic_per_token_linear, - dynamically_quantize_per_channel, -) - -__all__ = ["DynamicallyPerAxisQuantizedLinear"] - - -class DynamicallyPerAxisQuantizedLinear(torch.nn.Linear): - """ - This class is a replacement for `torch.nn.Linear`. It implements a - quantized matmul using int8 dynamic symmetric per-token activation, - and int8 symmetric per-channel weight quantization - """ - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - ) -> None: - super().__init__(in_features, out_features, bias) - - def forward(self, X: torch.Tensor, *args, **kwargs) -> torch.Tensor: - """ - Performs the forward pass of the quantized linear layer which consists - of int8 dynamic symmetric per-token activation and int8 symmetric per-channel weight - quantization - - Args: - X (torch.Tensor): The input floating point tensor to the quantized linear layer. - - Returns: - torch.Tensor: The output floating point tensor after the quantized matmul and rescale. - - """ - - Y = _quant_int8_dynamic_per_token_linear( - X, self.W_int_repr_t, self.W_scales, self.bias, X.dtype - ) - return Y - - @classmethod - def from_float(cls, mod: torch.nn.Linear) -> "DynamicallyPerAxisQuantizedLinear": - """ - Converts a `mod` of class `torch.nn.Linear` to the - `DynamicallyPerAxisQuantizedLinear` class - - Args: - mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert. - - Returns: - DynamicallyPerAxisQuantizedLinear: The converted quantized linear module. 
- - """ - - # create the new module with a toy size to ensure initialization is fast - fake_in_features, fake_out_features = 8, 8 - new_mod = cls( - fake_in_features, - fake_out_features, - bias=mod.bias is not None, - ) - new_mod.in_features = mod.in_features - new_mod.out_features = mod.out_features - W_int_repr, W_scales, _W_zps = dynamically_quantize_per_channel( - mod.weight, -128, 127, torch.int8 - ) - new_mod.register_buffer("W_int_repr_t", W_int_repr.contiguous().t()) - new_mod.W_scales = nn.Parameter(W_scales) - new_mod.bias = mod.bias - del new_mod.weight - - device_to_use = next(mod.parameters()).device - new_mod.to(device_to_use) - return new_mod diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3bda8f91ab..643153e78d 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -130,9 +130,6 @@ ZeroPointDomain, quantize_affine, ) -from .subclass import ( - QuantizedLinearWeightBase, -) from .unified import Quantizer, TwoStepQuantizer from .utils import _get_per_token_block_size @@ -289,7 +286,6 @@ def _is_linear(mod, *args): return ( isinstance(mod, torch.nn.Linear) and hasattr(mod, "weight") - and not isinstance(mod.weight, QuantizedLinearWeightBase) and not isinstance(mod.weight, AutoQuantizableLinearWeight) and not isinstance(mod.weight, AffineQuantizedTensor) and not isinstance(mod.weight, LinearActivationQuantizedTensor) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py deleted file mode 100644 index caffef7b58..0000000000 --- a/torchao/quantization/subclass.py +++ /dev/null @@ -1,702 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - - -import torch -from torch.utils._python_dispatch import return_and_correct_aliasing - -from torchao.quantization.utils import ( - _quant_int8_dynamic_per_token_linear, - dequantize_per_channel, - dynamically_quantize_per_channel, - groupwise_affine_quantize_tensor, - unpack_tinygemm_scales_and_zeros, -) -from torchao.utils import ( - check_cpu_version, - check_xpu_version, - find_multiple, -) - -from .quant_primitives import ( - ZeroPointDomain, -) - -__all__ = [ - "Int8DynamicallyQuantizedLinearWeight", - "Int8WeightOnlyQuantizedLinearWeight", - "Int4WeightOnlyQuantizedLinearWeight", -] - - -aten = torch.ops.aten - - -class QuantizedLinearWeightBase(torch.Tensor): - """ - Base quantized tensor subclass for quantized linear weights. When the from_float method is used, - to create an instance of any QuantizedLinearWeightBase, we assume the input - weight is oriented the way it is in a normal linear op, i.e. out-channels x in-channels. - - The shape and dtype of the tensor subclass represent how the tensor subclass looks externally, - regardless of the internal representation's type or orientation. 
- """ - - @staticmethod - def __new__(cls, int_data, transposed, shape, *args, **kwargs): - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - assert "dtype" in kwargs - assert not kwargs.get("requires_grad", False) - kwargs["requires_grad"] = False - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__(self, int_data, transposed, *args, **kwargs): - self.int_data = int_data - - self.transposed = transposed - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - pass - - def __repr__(self): - return ( - f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, " - f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" - ) - - def dequantize(self): - pass - - def int_repr(self): - pass - - def q_params(self): - pass - - def half(self): - return self.to(torch.float16) - - def _get_to_kwargs(self, *args, **kwargs): - device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) - device = self.device if device is None else device - dtype = self.dtype if dtype is None else dtype - memory_format = ( - memory_format if memory_format is not None else torch.preserve_format - ) - kwargs = { - "device": device, - "dtype": dtype, - "memory_format": memory_format, - } - return kwargs - - def _apply_fn_to_data(self, fn): - pass - - def _change_shape(self): - pass - - def __tensor_flatten__(self): - pass - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - pass - - @classmethod - def from_float(cls, input_float): - pass - - # __torch_function__ = torch._C._disabled_torch_function_impl - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - kwargs = {} if kwargs is None else kwargs - - if func is torch.nn.functional.linear: - mat1, w_qtensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - assert not w_qtensor.transposed - return cls._quantized_op(mat1, w_qtensor, bias) - - try: - with torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - except Exception: - print(f"ERR: subclass doesn't implement {func}") - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - # two scenarios where we currently fall back to vanilla mm: - # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation - # for consistency and to allow people to test - # 2 - we're given non-floats - quantizing long to int8 is crazy - if ( - func in [aten.mm.default, aten.addmm.default] - and args[0].is_floating_point() - and args[0].is_cuda - ): - if func == aten.addmm.default: - assert args[1].shape[-1] == args[2].shape[0], ( - f"need mat1 shape: {args[1].shape} final" - f"dim to match mat2 shape: {args[2].shape} first dim " - ) - mat1, w_qtensor, bias = ( - args[1], - args[2], - args[0], - ) - else: - assert args[0].shape[-1] == args[1].shape[0], ( - f"need mat1 shape: {args[0].shape} final dim" - f"to match mat2 shape: {args[1].shape} first dim" - ) - mat1, w_qtensor, bias = ( - args[0], - args[1], - None if len(args) == 2 else args[2], - ) - # call the quantized op for the specific type - # of quantized tensor subclass - return cls._quantized_op(mat1, w_qtensor, bias) - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - if func is 
aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - if func is aten.t.default: - args[0].transposed = not args[0].transposed - new = args[0]._change_shape(args[0].shape[::-1]) - return return_and_correct_aliasing(func, args, kwargs, new) - - if func is aten._to_copy.default: - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), - ) - - -class ConstructTensorSubclass(torch.nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - self.args = args - self.kwargs = kwargs - - def forward(self, x): - pass - - def right_inverse(self, tensor_subclass_instance): - fields, _ = tensor_subclass_instance.__tensor_flatten__() - return [getattr(tensor_subclass_instance, field) for field in fields] - - -@torch._dynamo.allow_in_graph -def from_qtensor_components_int8dyn(*args, **kwargs): - return Int8DynamicallyQuantizedLinearWeight(*args, **kwargs) - - -class ConstructTensorSubclassInt8Dyn(ConstructTensorSubclass): - def forward(self, int_data, q_scales): - return from_qtensor_components_int8dyn( - int_data, q_scales, *self.args, **self.kwargs - ) - - -class Int8DynamicallyQuantizedLinearWeight(QuantizedLinearWeightBase): - """ - A Tensor subclass that when applied to a weight used in a linear op/module, changes the - linear op to a dynamically quantized linear op with symmetric per-token and per-channel - quantization on the activation and weight respectively. - """ - - subclass_constructor = ConstructTensorSubclassInt8Dyn - - @staticmethod - def __new__(cls, int_data, q_scales, transposed, shape, dtype=None, **kwargs): - if dtype is None: - dtype = q_scales.dtype - kwargs["dtype"] = dtype - return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined] - - def __init__(self, int_data, q_scales, transposed, shape, dtype=None, **kwargs): - self.q_scales = q_scales - super().__init__(int_data, transposed) - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - return _quant_int8_dynamic_per_token_linear( - act_mat, w_qtensor.int_data, w_qtensor.q_scales, bias, act_mat.dtype - ) - - def dequantize(self, dtype=None): - """ - Obtain the dequantized version of the quantized tensor subclass - """ - zero_points = torch.zeros( - self.q_scales.shape, device=self.q_scales.device, dtype=self.q_scales.dtype - ) - # zero_points = 0 - # TODO: fix dtype here? `to(self.dtype)` is not overwritten by `dtype` arg? - dq_t = dequantize_per_channel( - self.int_data.t(), - self.q_scales, - zero_points, - self.dtype if dtype is None else dtype, - ).to(self.dtype) - # data was transposed to dequantize so make sure shape is correct - return dq_t if not self.transposed else dq_t.t() - - def int_repr(self): - """ - Get the internal integer representation of the quantized tensor - """ - return self.int_data if self.transposed else self.int_data.t() - - def q_params(self): - """ - Get the quantization scales for the quantized tensor - """ - return {"q_scales": self.q_scales} - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.int_data.to(kwargs["device"]), - self.q_scales.to(kwargs["device"]), - self.transposed, - self.shape, - **kwargs, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.int_data), - fn(self.q_scales), - self.transposed, - self.shape, - dtype=self.dtype, - ) - - # `QuantizedLinearWeightBase` inconsistently. 
- - def _change_shape(self, shape): - return self.__class__( - self.int_data, self.q_scales, self.transposed, shape, dtype=self.dtype - ) - - def __tensor_flatten__(self): - # note: the order of args must match the order of args in __init__ - return ["int_data", "q_scales"], [self.transposed, self.shape, self.dtype] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None - ): - int_data, q_scales = tensor_data_dict["int_data"], tensor_data_dict["q_scales"] - transposed, shape, dtype = tensor_attributes - return cls( - int_data, - q_scales, - transposed, - shape if outer_size is None else outer_size, - dtype=dtype, - strides=outer_stride, - ) - - @classmethod - def from_float(cls, input_float, qmin=-128, qmax=127, dtype=None): - """ - Method used to convert a linear weight tensor to an instance of the - Int8DynamicallyQuantizedLinearWeight subclass. - - Example usage:: - - model.lin_mod.weight = ( - Int8DynamicallyQuantizedLinearWeight.from_float(model.lin_mod.weight) - ) - """ - if dtype is None: - dtype = input_float.dtype - - # because we call transpose in dequantization - w_int_repr, w_scales, _ = dynamically_quantize_per_channel( - input_float, qmin, qmax, torch.int8 - ) - # the desired representation shape for fast quantized matmul is - # transposed compared to how it's stored as a linear weight, - # i.e. we want in_channels as dim=0 and out_channels (and quantized axis) as dim=1 - # however the external representation of our tensor will maintain the correct - # shape attribute which needs to be tracked directly. - int_data = w_int_repr.contiguous().t() - if not issubclass(cls, Int8DynamicallyQuantizedLinearWeight): - int_data = int_data.contiguous() - return cls( - int_data, - w_scales, - False, - input_float.shape, - dtype=dtype, - ) - - -@torch._dynamo.allow_in_graph -def from_qtensor_components_int8wo(*args, **kwargs): - return Int8WeightOnlyQuantizedLinearWeight(*args, **kwargs) - - -class ConstructTensorSubclassInt8wo(ConstructTensorSubclass): - def forward(self, int_data, q_scales): - return from_qtensor_components_int8wo( - int_data, q_scales, *self.args, **self.kwargs - ) - - -class Int8WeightOnlyQuantizedLinearWeight(Int8DynamicallyQuantizedLinearWeight): - """ - A Tensor subclass that when applied to a weight used in a linear op/module, - changes the linear op to a weight-only quantized linear op with symmetric - per-channel quantization on the weight. - """ - - subclass_constructor = ConstructTensorSubclassInt8wo - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - orig_dtype = act_mat.dtype - y = ( - torch.mm( - act_mat.reshape(-1, act_mat.shape[-1]), - w_qtensor.int_data.to(act_mat.dtype), - ) - * w_qtensor.q_scales - ) - y = y.reshape(*act_mat.shape[:-1], y.shape[-1]) - if bias is not None: - y += bias - return y.to(orig_dtype) - - -@torch._dynamo.allow_in_graph -def from_qtensor_components_int4wo(*args, **kwargs): - return Int4WeightOnlyQuantizedLinearWeight(*args, **kwargs) - - -class ConstructTensorSubclassInt4wo(ConstructTensorSubclass): - def forward(self, int_data, scales_and_zeros): - return from_qtensor_components_int4wo( - int_data, scales_and_zeros, *self.args, **self.kwargs - ) - - -class Int4WeightOnlyQuantizedLinearWeight(QuantizedLinearWeightBase): - """ - A Tensor subclass that when applied to a weight used in a linear op/module, - changes that linear op to a weight-only int4 quantized linear op with groupwise - affine quantization on the weight. 
- """ - - subclass_constructor = ConstructTensorSubclassInt4wo - - @staticmethod - def __new__( - cls, - int_data, - scales_and_zeros, - transposed, - shape, - groupsize=128, - inner_k_tiles=8, - zero_point_domain=ZeroPointDomain.FLOAT, - preserve_zero=False, - dtype=None, - **kwargs, - ): - if dtype is None: - dtype = scales_and_zeros.dtype - kwargs["dtype"] = dtype - return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data, - scales_and_zeros, - transposed, - shape, - groupsize, - inner_k_tiles, - zero_point_domain, - preserve_zero, - dtype, - **kwargs, - ): - # the transposed flag tracks whether the tensor subclass has been transposed relative - # to how a weight is normally stored in a linear i.e. [out_features, in_features]. - # tracking both transposed and shape is slightly redundant but corner cases like - # square matrices can cause issues otherwise - - self.scales_and_zeros = scales_and_zeros - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - self.zero_point_domain = zero_point_domain - self.preserve_zero = preserve_zero - super().__init__(int_data, transposed) - - @staticmethod - def _quantized_op(act_mat, w_qtensor, bias): - orig_act_size = act_mat.size() - orig_dtype = act_mat.dtype - - # reshape and pad activation - act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16) - pad_size = find_multiple(act_mat.shape[-1], 1024) - act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) - - # matmul - if check_cpu_version(act_mat.device): - y = aten._weight_int4pack_mm_for_cpu( - act_mat.contiguous(), - w_qtensor.int_data, - w_qtensor.groupsize, - w_qtensor.scales_and_zeros, - ) - elif check_xpu_version(act_mat.device): - if not w_qtensor.zero_point_domain == ZeroPointDomain.INT: - y = aten._weight_int4pack_mm( - act_mat.contiguous(), - w_qtensor.int_data, - w_qtensor.groupsize, - w_qtensor.scales_and_zeros, - ) - else: - y = aten._weight_int4pack_mm_with_scales_and_zeros( - act_mat.contiguous(), - w_qtensor.int_data, - w_qtensor.groupsize, - w_qtensor.scales_and_zeros[0], - w_qtensor.scales_and_zeros[1], - ) - else: - y = aten._weight_int4pack_mm( - act_mat.contiguous(), - w_qtensor.int_data, - w_qtensor.groupsize, - w_qtensor.scales_and_zeros, - ) - - # remove out_feature padding - orig_out_features = ( - w_qtensor.shape[-1] if w_qtensor.transposed else w_qtensor.shape[-2] - ) - y = y[:, :orig_out_features] - - y = y.reshape(*orig_act_size[:-1], orig_out_features) - if bias is not None: - y += bias - return y.to(orig_dtype) - - def dequantize(self): - eye_shape = self.shape[1] if not self.transposed else self.shape[0] - w_dq = self._quantized_op( - torch.eye(eye_shape, device=self.device, dtype=self.dtype), self, None - ) - # we dequantized using linear with the identity matrix, output has shape [in_channels, out_channels] - # so we need to transpose back to get the original shape unless self.transposed is set. 
- w_dq = w_dq if self.transposed else w_dq.t() - return w_dq.to(self.dtype) - - def int_repr(self): - return self.int_data - - def q_params(self): - scales, zero_points = unpack_tinygemm_scales_and_zeros( - self.scales_and_zeros, - ) - return {"q_scales": scales, "q_zero_points": zero_points} - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.int_data.to(kwargs["device"]), - self.scales_and_zeros.to(kwargs["device"]), - self.transposed, - self.shape, - self.groupsize, - self.inner_k_tiles, - self.zero_point_domain, - self.preserve_zero, - **kwargs, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.int_data), - fn(self.scales_and_zeros), - self.transposed, - self.shape, - self.groupsize, - self.inner_k_tiles, - self.zero_point_domain, - self.preserve_zero, - dtype=self.dtype, - ) - - # `QuantizedLinearWeightBase` inconsistently. - - def _change_shape(self, shape): - return self.__class__( - self.int_data, - self.scales_and_zeros, - self.transposed, - shape, - self.groupsize, - self.inner_k_tiles, - self.zero_point_domain, - self.preserve_zero, - dtype=self.dtype, - ) - - def __tensor_flatten__(self): - return ["int_data", "scales_and_zeros"], ( - self.transposed, - self.shape, - self.groupsize, - self.inner_k_tiles, - self.zero_point_domain, - self.preserve_zero, - self.dtype, - ) - - @classmethod - - # `QuantizedLinearWeightBase` inconsistently. - - def __tensor_unflatten__( - cls, tensor_data_dict, attributes, outer_size=None, outer_stride=None - ): - int_data, scales_and_zeros = ( - tensor_data_dict["int_data"], - tensor_data_dict["scales_and_zeros"], - ) - ( - transposed, - shape, - groupsize, - inner_k_tiles, - zero_point_domain, - preserve_zero, - dtype, - ) = attributes - return cls( - int_data, - scales_and_zeros, - transposed, - shape if outer_size is None else outer_size, - groupsize, - inner_k_tiles, - zero_point_domain=zero_point_domain, - preserve_zero=preserve_zero, - dtype=dtype, - strides=outer_stride, - ) - - @classmethod - def from_float( - cls, - input_float, - groupsize=128, - inner_k_tiles=8, - zero_point_domain=ZeroPointDomain.FLOAT, - preserve_zero=False, - dtype=None, - ): - """ - Method used to convert a linear weight tensor to an instance of the - Int4WeightOnlyQuantizedLinearWeight subclass. 
- - Example usage:: - - model.lin_mod.weight = ( - Int4WeightOnlyQuantizedLinearWeight.from_float(model.lin_mod.weight) - ) - """ - if dtype is None: - dtype = input_float.dtype - - int_data, scales_and_zeros, transposed, groupsize, inner_k_tils = ( - cls.to_qtensor_components( - input_float, - groupsize, - inner_k_tiles, - zero_point_domain=zero_point_domain, - preserve_zero=preserve_zero, - ) - ) - return cls( - int_data, - scales_and_zeros, - transposed, - input_float.shape, - groupsize, - inner_k_tiles, - zero_point_domain=zero_point_domain, - preserve_zero=preserve_zero, - dtype=dtype, - ) - - @classmethod - def to_qtensor_components( - cls, - input_float, - groupsize=128, - inner_k_tiles=8, - zero_point_domain=ZeroPointDomain.FLOAT, - preserve_zero=False, - ): - assert groupsize in [256, 128, 64, 32] - assert inner_k_tiles in [8, 4, 2] - orig_out_features, orig_in_features = input_float.shape - - # padding - in_features = find_multiple(orig_in_features, 1024) - out_features = find_multiple(orig_out_features, 8) - input_float = torch.nn.functional.pad( - input_float, - (0, in_features - orig_in_features, 0, out_features - orig_out_features), - ) - - # quantization and packing - input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor( - input_float, - 4, - groupsize, - dtype=input_float.dtype, - zero_point_domain=zero_point_domain, - preserve_zero=preserve_zero, - ) - if check_cpu_version(input_float.device): - int_data = aten._convert_weight_to_int4pack_for_cpu( - input_int4x8, inner_k_tiles - ) - else: - int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles) - return int_data, scales_and_zeros, False, groupsize, inner_k_tiles diff --git a/torchao/quantization/weight_only.py b/torchao/quantization/weight_only.py deleted file mode 100644 index fb30c14936..0000000000 --- a/torchao/quantization/weight_only.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import torch - -from .utils import dynamically_quantize_per_channel - -__all__ = ["WeightOnlyInt8QuantLinear"] - - -class WeightOnlyInt8QuantLinear(torch.nn.Linear): - """ - This class is a replacement for `torch.nn.Linear`. It implements a - mixed dtype matrix multiplication using int8 symmetric per-channel weight quantization. - - The primary goal of this class is to leverage int8 quantization for weights to reduce the - memory footprint and computational requirements while performing linear transformations. - This can be particularly beneficial for deploying models in low latency environments - - Attributes: - w_int8 (torch.Tensor): The quantized weights in int8 format. - scales (torch.Tensor): The scaling factors for each channel to convert the quantized - weights back to floating point format during the forward pass. - """ - - def __init__(self, *args, **kwargs): - """ - Initializes the WeightOnlyInt8QuantLinear module. - - Args: - *args: Variable length argument list for `torch.nn.Linear`. - **kwargs: Arbitrary keyword arguments. - Must include 'w_int8' (int8 quantized weights) and 'scales' (scaling factors). 
- """ - w_int8 = kwargs.pop("w_int8") - scales = kwargs.pop("scales") - super().__init__(*args, **kwargs) - - self.register_buffer("w_int8", w_int8) - self.register_buffer("scales", scales) - - def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: - """ - Performs the forward pass of the quantized linear layer, which consists of - mixed dtype matrix multiplication using int8 symmetric per-channel weight quantization. - - Args: - x (torch.Tensor): The input floating point tensor to the quantized linear layer. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. - - Returns: - torch.Tensor: The output floating point tensor after the quantized matrix multiplication - and rescale. - """ - x_view = x.view(-1, x.shape[-1]) - y = torch.mm(x_view, self.w_int8.to(x.dtype)) * self.scales - y = y.reshape(*x.shape[:-1], -1) - if self.bias is not None: - y += self.bias - return y - - @classmethod - def from_float(cls, mod: torch.nn.Linear): - """ - Converts a `torch.nn.Linear` module to a `WeightOnlyInt8QuantLinear` module. - - This method performs the conversion by dynamically quantizing the weights of the original - floating point linear layer to int8 format and creating a new `WeightOnlyInt8QuantLinear` - instance with these quantized weights and the corresponding scaling factors. - - Args: - mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert. - - Returns: - WeightOnlyInt8QuantLinear: The converted quantized linear module with int8 weights. - """ - w_fp32 = mod.weight - w_int8, scales, _zp = dynamically_quantize_per_channel( - w_fp32, -128, 127, torch.int8 - ) - # Create the new module with a toy size to ensure initialization is fast - fake_in_features, fake_out_features = 8, 8 - new_mod = cls( - fake_in_features, - fake_out_features, - bias=mod.bias is not None, - w_int8=w_int8.t().contiguous(), - scales=scales, - ) - new_mod.in_features = mod.in_features - new_mod.out_features = mod.out_features - del new_mod.weight - new_mod.bias = mod.bias - device_to_use = next(mod.parameters()).device - new_mod.to(device_to_use) - return new_mod