From 7376bb59230d615b2db0cc389aedbc9aec6b2e9f Mon Sep 17 00:00:00 2001 From: Douwe den Blanken Date: Sat, 25 Mar 2023 16:35:29 +0100 Subject: [PATCH 1/2] Make sure that `bias_correction.correct_bias` also works with 1D conv inputs Signed-off-by: Douwe den Blanken --- TrainingExtensions/torch/src/python/aimet_torch/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/utils.py index 93e32138885..204832457c8 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/utils.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/utils.py @@ -469,7 +469,7 @@ def get_input_shape_batch_size(data_loader): # finding shape of a batch input_shape = torch.Tensor.size(images_in_one_batch) - return input_shape[0], (1, input_shape[1], input_shape[2], input_shape[3]) + return input_shape[0], (1, *input_shape[1:]) def has_hooks(module: torch.nn.Module): From 7d544bbce9e1879e5baa2bbd25612056eb1faf34 Mon Sep 17 00:00:00 2001 From: Douwe den Blanken Date: Sat, 25 Mar 2023 16:36:55 +0100 Subject: [PATCH 2/2] Fix output_size not being instantiated for a Conv1d layer Signed-off-by: Douwe den Blanken --- .../torch/src/python/aimet_torch/bias_correction.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py b/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py index c60076a156a..35b7463b0c6 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py @@ -271,6 +271,8 @@ def pass_data_through_model(model, early_stopping_iterations=None, use_cuda=Fals if module.bias is None: if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)): output_size = module.out_channels + elif isinstance(module, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)): + output_size = module.out_channels elif isinstance(module, torch.nn.Linear): output_size = module.out_features module.bias = torch.nn.Parameter(torch.zeros(output_size)) @@ -396,4 +398,4 @@ def find_all_conv_bn_with_activation(model: torch.nn.Module, input_shape: Tuple) graph_searcher.find_all_patterns_in_graph_apply_actions() convs_bn_activation_dict = layer_select_handler.get_conv_linear_bn_info_dict() - return convs_bn_activation_dict \ No newline at end of file + return convs_bn_activation_dict