diff --git a/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py b/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py index c60076a156..35b7463b0c 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/bias_correction.py @@ -271,6 +271,8 @@ def pass_data_through_model(model, early_stopping_iterations=None, use_cuda=Fals if module.bias is None: if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)): output_size = module.out_channels + elif isinstance(module, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)): + output_size = module.out_channels elif isinstance(module, torch.nn.Linear): output_size = module.out_features module.bias = torch.nn.Parameter(torch.zeros(output_size)) @@ -396,4 +398,4 @@ def find_all_conv_bn_with_activation(model: torch.nn.Module, input_shape: Tuple) graph_searcher.find_all_patterns_in_graph_apply_actions() convs_bn_activation_dict = layer_select_handler.get_conv_linear_bn_info_dict() - return convs_bn_activation_dict \ No newline at end of file + return convs_bn_activation_dict diff --git a/TrainingExtensions/torch/src/python/aimet_torch/utils.py b/TrainingExtensions/torch/src/python/aimet_torch/utils.py index 93e3213888..204832457c 100644 --- a/TrainingExtensions/torch/src/python/aimet_torch/utils.py +++ b/TrainingExtensions/torch/src/python/aimet_torch/utils.py @@ -469,7 +469,7 @@ def get_input_shape_batch_size(data_loader): # finding shape of a batch input_shape = torch.Tensor.size(images_in_one_batch) - return input_shape[0], (1, input_shape[1], input_shape[2], input_shape[3]) + return input_shape[0], (1, *input_shape[1:]) def has_hooks(module: torch.nn.Module):