Skip to content

Commit

Permalink
Tests for verifying behaviour of BatchNorm using 0-dim batch sizes. (p…
Browse files Browse the repository at this point in the history
…ytorch#32384)

Summary:
The `BatchNorm*` part of the issue (see pytorchgh-12013) seems to have been fixed in the master branch and these tests would make it concrete.

However I would appreciate comments on pytorch#12013 (comment) on whether the current behaviour is satisfactory.
Pull Request resolved: pytorch#32384

Differential Revision: D19704154

Pulled By: ngimel

fbshipit-source-id: 1bbbbf1ae1215a460b22cf26e6b263e518ecf60b
  • Loading branch information
v0dro authored and ttumiel committed Mar 4, 2020
1 parent fe0461a commit e3c4bdc
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
5 changes: 5 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9124,6 +9124,11 @@ def test_BatchNorm_empty(self, device):
with torch.backends.cudnn.flags(enabled=False):
self._test_module_empty_input(mod, inp)

self.assertEqual(mod.running_mean, torch.tensor([0., 0, 0], device=device))
self.assertEqual(mod.running_var, torch.tensor([1., 1, 1], device=device))
self.assertEqual(mod.weight.grad, torch.tensor([0., 0, 0], device=device))
self.assertEqual(mod.bias.grad, torch.tensor([0., 0, 0], device=device))

def test_group_conv_empty(self, device):
mod = torch.nn.Conv2d(4, 4, stride=2, kernel_size=3, padding=1, groups=4).to(device)
inp = torch.randn(0, 4, 4, 4, device=device)
Expand Down
24 changes: 24 additions & 0 deletions torch/testing/_internal/common_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,14 @@ def fractional_max_pool3d_test(test_case):
check_eval=True,
desc='3d_input_not_affine',
),
dict(
module_name='BatchNorm1d',
constructor_args=(5, 1e-3, 0.3, False),
input_size=(0, 5, 9),
cudnn=True,
check_eval=True,
desc='zero_batch',
),
dict(
module_name='BatchNorm2d',
constructor_args=(3,),
Expand Down Expand Up @@ -1044,6 +1052,14 @@ def fractional_max_pool3d_test(test_case):
check_eval=True,
desc='not_tracking_stats',
),
dict(
module_name='BatchNorm2d',
constructor_args=(5, 1e-3, 0.3, False),
input_size=(0, 5, 2, 2),
cudnn=True,
check_eval=True,
desc='zero_batch',
),
dict(
module_name='BatchNorm3d',
constructor_args=(3,),
Expand Down Expand Up @@ -1083,6 +1099,14 @@ def fractional_max_pool3d_test(test_case):
check_eval=True,
desc='not_tracking_stats',
),
dict(
module_name='BatchNorm3d',
constructor_args=(5, 1e-3, 0.3, False),
input_size=(0, 5, 2, 2, 2),
cudnn=True,
check_eval=True,
desc='zero_batch',
),
dict(
module_name='InstanceNorm1d',
constructor_args=(3, 1e-3, 0.3),
Expand Down

0 comments on commit e3c4bdc

Please sign in to comment.