Torch nightly fails calling backward() #57900

Closed
EnricoMi opened this issue May 8, 2021 · 2 comments
Labels
module: cuda (Related to torch.cuda, and CUDA support in general)
module: norms and normalization
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

EnricoMi commented May 8, 2021

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

  1. Running the Horovod test test_horovod_sync_batch_norm fails with the following error (a minimal standalone reproduction sketch follows the snippet):
            # backward pass for gradient calculation
>           grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                mean_dy,
                mean_dy_xmu
            )
E           TypeError: batch_norm_backward_elemt() missing 1 required positional arguments: "count"
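
A minimal standalone sketch (dummy tensors, not taken from the Horovod test) that should raise the same TypeError on the affected nightlies, since the op's argument parser rejects the old seven-argument call before any CUDA kernel is dispatched:

    import torch

    # Dummy tensors shaped like the call in horovod/torch/sync_batch_norm.py;
    # the values do not matter because the argument parser fails first.
    grad_output = torch.ones(1, 4, 2)
    saved_input = torch.zeros(1, 4, 2)
    mean = torch.zeros(4)
    invstd = torch.ones(4)
    weight = torch.ones(4)
    mean_dy = torch.zeros(4)
    mean_dy_xmu = torch.zeros(4)

    # Old (pre-#46906) seven-argument call:
    torch.batch_norm_backward_elemt(
        grad_output, saved_input, mean, invstd, weight, mean_dy, mean_dy_xmu
    )
    # TypeError: batch_norm_backward_elemt() missing 1 required positional arguments: "count"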

Expected behavior

This unit test used to work with 1.9.0.dev20210303+cu111 and before.

Environment

Collecting environment information...
PyTorch version: 1.9.0.dev20210507+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.10.2

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 11.2.152
GPU models and configuration: GPU 0: GeForce MX150
Nvidia driver version: 460.39
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] pytorch-lightning==1.2.9
[pip3] torch==1.9.0.dev20210507+cu111
[pip3] torchmetrics==0.2.0
[pip3] torchvision==0.10.0.dev20210507+cu111
[conda] Could not collect
  • PyTorch Version (e.g., 1.0): 1.9.0.dev20210503+cu111
  • OS (e.g., Linux): nvidia/cuda:11.2.2-devel-ubuntu18.04
  • Build command you used (if compiling from source): docker-compose -f docker-compose.test.yml build test-gpu-gloo-py3_8-tfhead-keras_none-torchhead-mxnethead-pyspark_3_1_1 using docker-compose.test.yml

Additional context

The full error:

    def test_horovod_sync_batch_norm(self):
        """Tests Horovod version of SyncBatchNorm."""
        if not torch.cuda.is_available():
            self.skipTest("No GPUs available")
    
        hvd.init()
    
        ts_list = [
            torch.stack([
                torch.tensor([
                    [r, r + 1],
                    [r * 2, r * 2 + 1],
                    [r * 3, r * 3 + 1],
                    [r * 4, r * 4 + 1]
                ])
                for r in range(hvd.size())
            ]),
            torch.stack([
                torch.tensor([
                    [r + 1],
                    [r * 2 + 1],
                    [r * 3 + 1],
                    [r * 4 + 1]
                ])
                for r in range(hvd.size())
            ]),
        ]
    
        for ts in ts_list:
            sync_bn = hvd.SyncBatchNorm(num_features=4)
            sync_bn.cuda(hvd.local_rank())
    
            bn = torch.nn.BatchNorm1d(num_features=4)
            bn.cuda(hvd.local_rank())
    
            ts = ts.cuda(hvd.local_rank()).float()
            ts1 = ts.clone().requires_grad_()
            ts2 = ts.clone().requires_grad_()
    
            # Training
            sync_bn_out = sync_bn(ts1[hvd.rank()].unsqueeze(0))
            bn_out = bn(ts2)
            assert torch.allclose(sync_bn_out, bn_out[hvd.rank()].unsqueeze(0), 1e-6)
            assert torch.allclose(sync_bn.running_mean, bn.running_mean, 1e-6)
            assert torch.allclose(sync_bn.running_var, bn.running_var, 1e-6)
    
            # Gradients
>           sync_bn_out.sum().backward()

test_torch.py:2295: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/usr/local/lib/python3.8/dist-packages/torch/_tensor.py:255: in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py:147: in backward
    Variable._execution_engine.run_backward(
/usr/local/lib/python3.8/dist-packages/torch/autograd/function.py:89: in apply
    return self._forward_cls.backward(self, *args)  # type: ignore[attr-defined]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <torch.autograd.function._SyncBatchNormBackward object at 0x7f45a8345ba0>
grad_output = tensor([[[1., 1.],
         [1., 1.],
         [1., 1.],
         [1., 1.]]], device='cuda:1')

    @staticmethod
    def backward(self, grad_output):
        grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_all = self.saved_tensors
        need_input_grad, need_weight_grad, need_bias_grad = self.needs_input_grad[0:3]
    
        # calculate local stats as well as grad_weight / grad_bias
        sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            need_input_grad,
            need_weight_grad,
            need_bias_grad
        )
    
        if need_input_grad:
            # synchronizing stats used to calculate input gradient.
            sum_dy_handle = allreduce_async(sum_dy, op=Sum, name='sync_batch_norm.sum_dy')
            sum_dy_xmu_handle = allreduce_async(sum_dy_xmu, op=Sum, name='sync_batch_norm.sum_dy_xmu')
    
            # wait on the async communication to finish
            sum_dy = synchronize(sum_dy_handle)
            sum_dy_xmu = synchronize(sum_dy_xmu_handle)
    
            if _SYNC_BN_V2 or _SYNC_BN_V3:
                count_all_sum = count_all.sum()
                mean_dy = sum_dy / count_all_sum
                mean_dy_xmu = sum_dy_xmu / count_all_sum
            else:
                # before 1.5.0, sum_dy was sum of means from every worker, so we just
                # need to divide it by number of workers
                mean_dy = sum_dy / size()
                mean_dy_xmu = sum_dy_xmu / size()
    
            # backward pass for gradient calculation
>           grad_input = torch.batch_norm_backward_elemt(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                mean_dy,
                mean_dy_xmu
            )
E           TypeError: batch_norm_backward_elemt() missing 1 required positional arguments: "count"

/usr/local/lib/python3.8/dist-packages/horovod/torch/sync_batch_norm.py:179: TypeError

cc @ngimel

ngimel (Collaborator) commented May 8, 2021

The batch_norm_backward_elemt API was changed in #46906. batch_norm_backward_elemt is a non-public, undocumented function, and as such is not subject to backward compatibility requirements. Please update the code using it.
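
For reference, a rough sketch of what the updated call in horovod/torch/sync_batch_norm.py might look like under the post-#46906 signature, which takes the un-divided gradient sums plus a per-rank count tensor and performs the division inside the kernel. Variable names follow the backward() in the traceback above; the int32 cast and the device placement are assumptions, not taken from the Horovod source:

    # Hypothetical replacement for the failing call shown in the traceback,
    # assuming the post-#46906 signature of batch_norm_backward_elemt.
    grad_input = torch.batch_norm_backward_elemt(
        grad_output,
        saved_input,
        mean,
        invstd,
        weight,
        sum_dy,        # gradient sums, no longer pre-divided into mean_dy
        sum_dy_xmu,    # gradient sums, no longer pre-divided into mean_dy_xmu
        count_all.to(dtype=torch.int32, device=grad_output.device)  # new "count" argument
    )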

ngimel added the module: norms and normalization, module: cuda, and triaged labels on May 8, 2021
EnricoMi (Author) commented May 9, 2021

That's unfortunate. I will update our code then. Thanks for the quick pointer.

EnricoMi closed this as completed May 9, 2021