diff --git a/test/test_onnx.py b/test/test_onnx.py index 5bb8eba2530..975cea7a58f 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -482,6 +482,17 @@ def test_keypoint_rcnn(self): dynamic_axes={"images_tensors": [0, 1, 2]}, tolerate_small_mismatch=True) + def test_shufflenet_v2_dynamic_axes(self): + model = models.shufflenet_v2_x0_5(pretrained=True) + dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True) + test_inputs = torch.cat([dummy_input, dummy_input, dummy_input], 0) + + self.run_model(model, [(dummy_input,), (test_inputs,)], + input_names=["input_images"], + output_names=["output"], + dynamic_axes={"input_images": {0: 'batch_size'}, "output": {0: 'batch_size'}}, + tolerate_small_mismatch=True) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 9ba090ad09b..9a4333eb10b 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -19,7 +19,7 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor: - batchsize, num_channels, height, width = x.data.size() + batchsize, num_channels, height, width = x.size() channels_per_group = num_channels // groups # reshape