
Exception raised in backward() loses backtrace information #42560

Closed

froody opened this issue Aug 4, 2020 · 4 comments
Assignees
pritamdamania87

Labels
- better-engineering (Relatively self-contained tasks for better engineering contributors)
- high priority
- module: autograd (Related to torch.autograd, and the autograd engine in general)
- module: logging (Features which make it easier to tell what PyTorch is doing under the hood)
- module: regression (It used to work, and now it doesn't)
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

froody (Contributor) commented Aug 4, 2020

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

  1. Run the code below
```python
import torch

class Foo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, *grad):
        raise ValueError("something")

t = torch.Tensor(20)
t.requires_grad_()
output = Foo.apply(t)

loss = torch.nn.MSELoss()
loss(output, torch.Tensor(20)).backward()
```

Expected behavior

A stacktrace from the point of the raise, showing the error originating in Foo.backward

Actual output

```
Traceback (most recent call last):
  File "/private/home/tbirch/bug.py", line 17, in <module>
    loss(output, torch.Tensor(20)).backward()
  File "/private/home/tbirch/.conda/envs/py38/lib/python3.8/site-packages/torch/tensor.py", line 198, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/private/home/tbirch/.conda/envs/py38/lib/python3.8/site-packages/torch/autograd/__init__.py", line 98, in backward
    Variable._execution_engine.run_backward(
RuntimeError: something
```

Note that the original ValueError also surfaces as a RuntimeError, and the trace stops at the engine call rather than reaching Foo.backward.

Environment

(same issue in 1.5.1 and 1.6.0)

```
Collecting environment information...
PyTorch version: N/A
Is debug build: N/A
CUDA used to build PyTorch: N/A

OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2

Python version: 3.7
Is CUDA available: N/A
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 418.116.00
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.5.1
[pip3] torchtext==0.7.0
[pip3] torchvision==0.6.0a0+35d732a
[pip3] torchviz==0.0.1
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.1.243 h6bb024c_0
[conda] mkl 2019.4 243
[conda] mkl-service 2.3.0 py38h516909a_0 conda-forge
[conda] mkl_fft 1.1.0 py38hc1659b7_1 conda-forge
[conda] mkl_random 1.1.0 py38h962f231_0
[conda] numpy 1.18.5 py38ha1c710e_0
[conda] numpy-base 1.18.5 py38hde5b4d6_0
[conda] pytorch 1.5.1 py3.8_cuda10.1.243_cudnn7.6.3_0 pytorch
[conda] torchtext 0.7.0 pypi_0 pypi
[conda] torchvision 0.6.1 py38_cu101 pytorch
[conda] torchviz 0.0.1 pypi_0 pypi
```

cc @ezyang @gchanan @zou3519 @ssnl @albanD @gqchen

glaringlee added the better-engineering, module: logging and triaged labels on Aug 5, 2020
gchanan (Contributor) commented Aug 5, 2020

See https://pytorch.org/docs/stable/autograd.html#torch.autograd.detect_anomaly.
zou3519 added the module: autograd label on Aug 5, 2020
froody (Contributor, Author) commented Aug 6, 2020

> See https://pytorch.org/docs/stable/autograd.html#torch.autograd.detect_anomaly.

Thanks. Running this with torch-1.6.0 I see `Warning: Error detected in FooBackward`, which does make it slightly easier to see what's going on, but my original question still stands: why isn't the default behavior to preserve the stacktrace from the original raise statement when there is one?

```
/private/home/tbirch/bug.py:17: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
  with torch.autograd.detect_anomaly():
[W python_anomaly_mode.cpp:42] Warning: Error detected in FooBackward. No forward pass information available. Enable detect anomaly during forward pass for more information. (function print_stack)
Traceback (most recent call last):
  File "/private/home/tbirch/bug.py", line 18, in <module>
    loss(output, torch.Tensor(20)).backward()
  File "/private/home/tbirch/.conda/envs/torch160/lib/python3.7/site-packages/torch/tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/private/home/tbirch/.conda/envs/torch160/lib/python3.7/site-packages/torch/autograd/__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: something
```
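For reference, a minimal sketch of the anomaly-detection usage behind the output above, reusing the repro from the top of the issue (paths and line numbers will differ on other machines):

```python
import torch

class Foo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, *grad):
        raise ValueError("something")

t = torch.Tensor(20).requires_grad_()

# Running the forward pass inside detect_anomaly() as well lets the engine
# record where the failing Function was created, which addresses the
# "Enable detect anomaly during forward pass" hint in the warning above.
with torch.autograd.detect_anomaly():
    output = Foo.apply(t)
    torch.nn.MSELoss()(output, torch.Tensor(20)).backward()
```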

albanD (Collaborator) commented Aug 10, 2020

Hi,

Thanks for opening an issue for this (it was mentioned in #41659, but it's better to have a dedicated issue).
This comes from the fact that the Future API eats the original error and only throws a basic std::exception carrying the original message.
We should change that, as our custom error types contain much more info (Python error type, C++ stack traces, etc.).
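To illustrate the failure mode with a Python sketch (the same pattern, not the actual C++ Future code): rethrowing a fresh exception built from only the message string discards the original traceback, while propagating the original exception object preserves it.

```python
import traceback

def backward():
    raise ValueError("something")

# What the engine effectively did before the fix: catch the error and
# rethrow a new, generic exception carrying only the message string.
try:
    try:
        backward()
    except Exception as e:
        raise RuntimeError(str(e)) from None  # message-only rethrow
except RuntimeError:
    traceback.print_exc()  # trace points at the rethrow, not at backward()

# Propagating the original exception object instead preserves the full
# traceback; the eventual fix moves in this direction by carrying the
# error itself (an exception_ptr in C++) rather than a bare message.
try:
    try:
        backward()
    except Exception as e:
        err = e
    raise err
except ValueError:
    traceback.print_exc()  # trace includes the original raise site
```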

cc @pritamdamania87

pritamdamania87 self-assigned this on Aug 14, 2020
albanD added the module: regression label on Aug 17, 2020
izdeby removed the module: regression and triage review labels on Aug 17, 2020
albanD re-added the module: regression label on Aug 17, 2020
pritamdamania87 pushed a commit that referenced this issue Aug 26, 2020
This PR attempts to address
#42560. We add a "Python Traceback"
component to the error message, recording the original location from where the
exception was thrown.

This approach isn't ideal, since we're extending the exception message itself.
The alternative I considered was extending our Future to record this
information in addition to just the message. The tricky part about this was
avoiding a Python dependency: ideally, I'd like to avoid a Python dependency in
our Future class. Even if we avoid the Python dependency in our base class by
creating a subclass for Python, we would still end up having a Python
dependency in things like the autograd engine's GraphTask.

For the example in #42560, the
exception trace would now look like:

```
> Traceback (most recent call last):
>   File "test_autograd.py", line 6914, in test_preserve_backtrace
>     Foo.apply(t).sum().backward()
>   File "torch/tensor.py", line 214, in backward
>     torch.autograd.backward(self, gradient, retain_graph, create_graph)
>   File "autograd/__init__.py", line 127, in backward
>     allow_unreachable=True)  # allow_unreachable flag
> RuntimeError: something
> Python Traceback (most recent call last):
>   File "torch/autograd/function.py", line 87, in apply
>     return self._forward_cls.backward(self, *args)
>   File "test_autograd.py", line 6909, in backward
>     raise ValueError("something")
```

Differential Revision: [D23337371](https://our.internmc.facebook.com/intern/diff/D23337371/)

ghstack-source-id: 110722744
Pull Request resolved: #43608
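A rough Python analogue of the approach described in this commit (illustrative only; the actual change lives in the C++ engine): format the traceback of the caught error and append it to the message that gets rethrown.

```python
import traceback

def backward():
    raise ValueError("something")

try:
    backward()
except Exception as e:
    # format_exception output already begins with
    # "Traceback (most recent call last):", so prefixing "Python " yields
    # the "Python Traceback" header shown in the example trace above.
    tb = "".join(traceback.format_exception(type(e), e, e.__traceback__))
    raise RuntimeError(str(e) + "\nPython " + tb) from None
```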
facebook-github-bot pushed a commit that referenced this issue Sep 1, 2020
Summary:
Pull Request resolved: #43684

This PR attempts to address #42560 by capturing the appropriate
exception_ptr in the autograd engine and passing it over to the Future.

As part of this change, there is a significant change to the Future API: we
now only accept an exception_ptr as part of setError.

For the example in #42560, the exception trace would now look like:

```
> Traceback (most recent call last):
>   File "test_autograd.py", line 6914, in test_preserve_backtrace
>     Foo.apply(t).sum().backward()
>   File "torch/tensor.py", line 214, in backward
>     torch.autograd.backward(self, gradient, retain_graph, create_graph)
>   File "torch/autograd/__init__.py", line 127, in backward
>     allow_unreachable=True)  # allow_unreachable flag
>   File "torch/autograd/function.py", line 87, in apply
>     return self._forward_cls.backward(self, *args)
>   File "test_autograd.py", line 6910, in backward
>     raise ValueError("something")
> ValueError: something
```
ghstack-source-id: 111109637

Test Plan: waitforbuildbot

Reviewed By: albanD

Differential Revision: D23365408

fbshipit-source-id: 1470c4776ec8053ea92a6ee1663460a3bae6edc5
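A Python sketch of the landed approach, using a hypothetical Future stand-in (not PyTorch's actual C++ class): the future stores the caught exception object itself rather than a message string, so the waiting side can re-raise it with the original type and traceback intact.

```python
class Future:
    """Hypothetical stand-in for the C++ Future used by the autograd engine."""

    def __init__(self):
        self._error = None
        self._value = None

    def set_error(self, exc):
        # Analogous to setError(std::exception_ptr): store the original
        # exception object, not just its message.
        self._error = exc

    def set_value(self, value):
        self._value = value

    def wait(self):
        if self._error is not None:
            raise self._error  # original type and traceback survive
        return self._value

def run_backward(fut):
    try:
        raise ValueError("something")  # stands in for a user backward()
    except Exception as e:
        fut.set_error(e)

fut = Future()
run_backward(fut)
fut.wait()  # re-raises ValueError, with the raise site in the traceback
```

This mirrors why the final trace above ends in `ValueError: something` rather than `RuntimeError: something`.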
pritamdamania87 (Contributor) commented

Resolved in #43684
