diff --git a/torchvision/ops/deform_conv.py b/torchvision/ops/deform_conv.py index aa5e42c4a6e..ec58aedecf4 100644 --- a/torchvision/ops/deform_conv.py +++ b/torchvision/ops/deform_conv.py @@ -38,17 +38,17 @@ def deform_conv2d( Examples:: - >>> input = torch.rand(1, 3, 10, 10) + >>> input = torch.rand(4, 3, 10, 10) >>> kh, kw = 3, 3 >>> weight = torch.rand(5, 3, kh, kw) >>> # offset should have the same spatial size as the output >>> # of the convolution. In this case, for an input of 10, stride of 1 >>> # and kernel size of 3, without padding, the output size is 8 - >>> offset = torch.rand(5, 2 * kh * kw, 8, 8) + >>> offset = torch.rand(4, 2 * kh * kw, 8, 8) >>> out = deform_conv2d(input, offset, weight) >>> print(out.shape) >>> # returns - >>> torch.Size([1, 5, 8, 8]) + >>> torch.Size([4, 5, 8, 8]) """ out_channels = weight.shape[0]