fix batch norm for empty inputs (#30035)
Summary:
Fix for #29578
The shape check is moved up as early as possible because backends by and large don't handle empty inputs correctly, so the check has to happen before backend selection. That also automatically takes care of backward: the forward pass for an empty input is automatically differentiable, so no backend-specific backward routines are ever called.
Pull Request resolved: #30035

Test Plan: tests for empty inputs are added.
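
A rough stand-alone sketch of what the new empty-input test exercises, condensed from the added test_nn.py code (the `mod`/`inp` names simply mirror that test):

import torch
import torch.nn as nn

# Condensed sketch of the added empty-input test: batch norm on a zero-size
# batch should return an empty output of the same shape and still support backward.
mod = nn.BatchNorm2d(3)
inp = torch.randn(0, 3, 2, 2, requires_grad=True)

out = mod(inp)
assert out.size() == inp.size()          # empty output, same shape as input

out.backward(torch.rand_like(out))       # backward also runs for empty inputs
for p in mod.parameters():
    if p.requires_grad and p.grad is not None:
        assert torch.equal(p.grad, torch.zeros_like(p.grad))  # no spurious grads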

Differential Revision: D18584427

Pulled By: ngimel

fbshipit-source-id: a42918f50eb1f6995921aafa92879cd42dd5e9e1
ngimel authored and facebook-github-bot committed Nov 19, 2019
1 parent c272758 commit a9ad2e2
Showing 2 changed files with 28 additions and 7 deletions.
7 changes: 5 additions & 2 deletions aten/src/ATen/native/Normalization.cpp
@@ -393,8 +393,8 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(const Tensor
       if (train) {
         // when in training mode
         // Q(X) = X - E[x] ; i.e. input centered to zero mean
-        // Y = Q(X) / σ ; i.e. BN output before weight and bias
-        // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / σ * w
+        // Y = Q(X) / sigma ; i.e. BN output before weight and bias
+        // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / sigma * w
 
         // projection of gradOutput on to output scaled by std
         scalar_t k = (scalar_t) dotp * invstd * invstd / n;
@@ -531,6 +531,9 @@ Tensor batch_norm(
     const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */,
     const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
     bool training, double momentum, double eps, bool cudnn_enabled) {
+  if (input.numel()==0){
+    return input; //return input instead of new empty tensor, because new empty tensor breaks the gradient chain
+  }
   return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
                                                 training, momentum, eps, cudnn_enabled));
 }
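
The early return above hands back `input` itself rather than allocating a fresh empty tensor. A quick illustration of the gradient-chain point made in that comment (illustrative only, not part of the patch):

import torch

inp = torch.randn(0, 3, 2, 2, requires_grad=True)

kept = inp                          # returning the input keeps autograd connectivity
fresh = torch.empty_like(inp)       # a brand-new tensor starts with no autograd history

print(kept.requires_grad)           # True  -> gradients can still reach inp
print(fresh.requires_grad)          # False -> the gradient chain is broken
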
28 changes: 23 additions & 5 deletions test/test_nn.py
@@ -8821,6 +8821,16 @@ def _test_GroupNorm_cuda_half(self):
         output.sum().backward()
         self.assertEqual(output.type(), input.type())
 
+    def _test_module_empty_input(self, module, inp):
+        inp.requires_grad_(True)
+        out = module(inp)
+        gO = torch.rand_like(out)
+        out.backward(gO)
+        self.assertEqual(out.size(), inp.size())
+        for p in module.parameters():
+            if p.requires_grad and p.grad is not None:
+                self.assertEqual(p.grad, torch.zeros_like(p.grad))
+
     def test_Dropout(self, device):
         input = torch.Tensor(1000)
         self._test_dropout(nn.Dropout, device, input)
@@ -8850,7 +8860,7 @@ def test_InstanceNorm1d_general(self, device):
         input = torch.rand(b, c, d)
         self._test_InstanceNorm_general(nn.InstanceNorm1d, input, device)
 
-        if device == 'cuda':
+        if self.device_type == 'cuda':
             self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input, device)
 
     def test_InstanceNorm2d_general(self, device):
@@ -8862,7 +8872,7 @@ def test_InstanceNorm2d_general(self, device):
         input = torch.rand(b, c, h, w)
         self._test_InstanceNorm_general(nn.InstanceNorm2d, input, device)
 
-        if 'cuda' in device:
+        if self.device_type == 'cuda':
             self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input, device)
 
     def test_InstanceNorm3d_general(self, device):
@@ -8875,21 +8885,29 @@ def test_InstanceNorm3d_general(self, device):
         input = torch.rand(b, c, h, w, d)
         self._test_InstanceNorm_general(nn.InstanceNorm3d, input, device)
 
-        if 'cuda' in device:
+        if self.device_type == 'cuda':
             self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input, device)
 
     def test_LayerNorm_general(self, device):
         self._test_LayerNorm_general(device)
 
-        if 'cuda' in device:
+        if self.device_type == 'cuda':
             self._test_LayerNorm_cuda_half(device)
 
     def test_GroupNorm_general(self, device):
         self._test_GroupNorm_general(device)
 
-        if 'cuda' in device:
+        if self.device_type == 'cuda':
             self._test_GroupNorm_cuda_half()
 
+    def test_BatchNorm_empty(self, device):
+        mod = torch.nn.BatchNorm2d(3).to(device)
+        inp = torch.randn(0, 3, 2, 2, device=device)
+        self._test_module_empty_input(mod, inp)
+        if self.device_type == 'cuda' and self.has_cudnn():
+            with torch.backends.cudnn.flags(enabled=False):
+                self._test_module_empty_input(mod, inp)
+
     def test_one_hot(self, device):
         with self.assertRaises(RuntimeError):
             torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)
