fix batch norm for empty inputs #30035

Closed · wants to merge 5 commits
7 changes: 5 additions & 2 deletions aten/src/ATen/native/Normalization.cpp
@@ -393,8 +393,8 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(const Tensor
    if (train) {
      // when in training mode
      // Q(X) = X - E[x] ; i.e. input centered to zero mean
-     // Y = Q(X) / σ ; i.e. BN output before weight and bias
-     // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / σ * w
+     // Y = Q(X) / sigma ; i.e. BN output before weight and bias
+     // dL/dX = (Q(dL/dY) - dot(Y, dL/dY) * Y) / sigma * w

      // projection of gradOutput on to output scaled by std
      scalar_t k = (scalar_t) dotp * invstd * invstd / n;
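For readers of the comment above: Q(·) subtracts the per-channel mean, so the shorthand expands to the standard training-mode batch-norm backward for a channel with n elements, weight w, and normalized output Y (the 1/n factors are the ones the code makes explicit in k = dotp * invstd * invstd / n):

$$
\frac{\partial L}{\partial x_i}
  = \frac{w}{\sigma}\left(
      \frac{\partial L}{\partial y_i}
      - \frac{1}{n}\sum_{j}\frac{\partial L}{\partial y_j}
      - Y_i \cdot \frac{1}{n}\sum_{j} Y_j \frac{\partial L}{\partial y_j}
    \right),
\qquad
Y_i = \frac{x_i - \mu}{\sigma}.
$$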
@@ -531,6 +531,9 @@ Tensor batch_norm(
    const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */,
    const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
    bool training, double momentum, double eps, bool cudnn_enabled) {
+  if (input.numel() == 0) {
+    return input; // return the input itself rather than a new empty tensor, because a new empty tensor would break the gradient chain
+  }
  return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var,
      training, momentum, eps, cudnn_enabled));
}
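A minimal usage sketch of what this early return enables: an empty batch flows through batch norm and backward still runs, with the gradient chain to the input intact (shapes mirror the new test in test_nn.py below; the values in the comments are what this code path is expected to produce, not output copied from a run):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
inp = torch.randn(0, 3, 2, 2, requires_grad=True)  # zero-sample batch

out = bn(inp)                       # empty output with the same shape as the input
out.backward(torch.rand_like(out))  # backward through the empty result still works

print(out.shape)       # torch.Size([0, 3, 2, 2])
print(inp.grad.shape)  # torch.Size([0, 3, 2, 2])
# Parameter gradients are either absent or all zeros, which is what the new test asserts.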
28 changes: 23 additions & 5 deletions test/test_nn.py
@@ -8821,6 +8821,16 @@ def _test_GroupNorm_cuda_half(self):
        output.sum().backward()
        self.assertEqual(output.type(), input.type())

+    def _test_module_empty_input(self, module, inp):
+        inp.requires_grad_(True)
+        out = module(inp)
+        gO = torch.rand_like(out)
+        out.backward(gO)
+        self.assertEqual(out.size(), inp.size())
+        for p in module.parameters():
+            if p.requires_grad and p.grad is not None:
+                self.assertEqual(p.grad, torch.zeros_like(p.grad))
+
    def test_Dropout(self, device):
        input = torch.Tensor(1000)
        self._test_dropout(nn.Dropout, device, input)
@@ -8850,7 +8860,7 @@ def test_InstanceNorm1d_general(self, device):
        input = torch.rand(b, c, d)
        self._test_InstanceNorm_general(nn.InstanceNorm1d, input, device)

-        if device == 'cuda':
+        if self.device_type == 'cuda':
            self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input, device)

    def test_InstanceNorm2d_general(self, device):
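The device == 'cuda' comparison is replaced because the device string handed to these device-generic tests may carry an index such as 'cuda:0'; comparing the device type, which is what self.device_type holds, is robust to that. A quick illustration (the value of device is assumed for the example):

import torch

device = 'cuda:0'                 # device strings can include an index
print(device == 'cuda')           # False, even though this is a CUDA device
print(torch.device(device).type)  # 'cuda' -- the device type these tests compare against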
@@ -8862,7 +8872,7 @@ def test_InstanceNorm2d_general(self, device):
        input = torch.rand(b, c, h, w)
        self._test_InstanceNorm_general(nn.InstanceNorm2d, input, device)

-        if 'cuda' in device:
+        if self.device_type == 'cuda':
            self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input, device)

    def test_InstanceNorm3d_general(self, device):
@@ -8875,21 +8885,29 @@ def test_InstanceNorm3d_general(self, device):
        input = torch.rand(b, c, h, w, d)
        self._test_InstanceNorm_general(nn.InstanceNorm3d, input, device)

-        if 'cuda' in device:
+        if self.device_type == 'cuda':
            self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input, device)

    def test_LayerNorm_general(self, device):
        self._test_LayerNorm_general(device)

-        if 'cuda' in device:
+        if self.device_type == 'cuda':
            self._test_LayerNorm_cuda_half(device)

    def test_GroupNorm_general(self, device):
        self._test_GroupNorm_general(device)

-        if 'cuda' in device:
+        if self.device_type == 'cuda':
            self._test_GroupNorm_cuda_half()

+    def test_BatchNorm_empty(self, device):
+        mod = torch.nn.BatchNorm2d(3).to(device)
+        inp = torch.randn(0, 3, 2, 2, device=device)
+        self._test_module_empty_input(mod, inp)
+        if self.device_type == 'cuda' and self.has_cudnn():
+            with torch.backends.cudnn.flags(enabled=False):
+                self._test_module_empty_input(mod, inp)
+
    def test_one_hot(self, device):
        with self.assertRaises(RuntimeError):
            torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)