ctx.save_for_backward doesn't save torch.Tensor subclasses fully #47117
This is a nasty one. We actually do some very strange stuff to tensors (unpack them into "saved tensors") when you pass them into `save_for_backward`. For non-metadata-carrying subclasses we could probably make this work by just saving the subclass directly. cc @albanD
Hi, I just wanted to confirm that Tensor-like objects are off the table here, right? They wouldn't work with custom `Function`s at all. As for Tensor subclasses, I am curious how they are implemented in detail.
From the design side of things, I think the usual `__torch_function__` machinery applies here.
But then the user would have to re-implement the …
Subclassing would then work OOTB (without anything extra on the user side); if they wanted to add metadata, they could use `__torch_function__`.
Sorry, I am not super familiar with how this would work.
Something like this:
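A minimal sketch of the kind of subclass being discussed (my reconstruction, not the original snippet): metadata lives in a hypothetical `metadata` attribute and is propagated through `__torch_function__`, so the subclass survives ordinary torch ops without extra user code:

```python
import torch

class MetadataTensor(torch.Tensor):
    metadata = None  # extra, non-tensor state (illustrative name)

    @staticmethod
    def __new__(cls, data, metadata=None):
        t = torch.Tensor._make_subclass(cls, data)
        t.metadata = metadata
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # run the op as usual, then propagate metadata from the first
        # MetadataTensor argument to the result
        out = super().__torch_function__(func, types, args, kwargs)
        src = next((a for a in args if isinstance(a, MetadataTensor)), None)
        if isinstance(out, torch.Tensor) and src is not None:
            out.metadata = src.metadata
        return out

x = MetadataTensor(torch.rand(3), metadata="descriptor")
print((x + 1).metadata)  # "descriptor": metadata survives the op
```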
I am not sure what that would mean for the user in terms of types. Does that mean that the …
This means, essentially, that PyTorch only defines the behavior in terms of plain `torch.Tensor`; it's up to the subclass to decide what to do next: adding metadata, etc. For other consumers of …
Thanks for all the details @hameerabbasi! So I'm happy with saying that the …
I'm a little skeptical. If I'm using … One question is whether or not direct …
Of course, this particular modification only really works because LoggingTensor doesn't really have any data with it. There are a few more orthogonal problems which might also need to be solved to support mlamarre's use case:

- What if you pass in a grad_output that is a tensor subclass?
- What if you return a tensor subclass from a custom function?
- What is the subclass of x.grad?
- Should it be possible to override operator calls that are internal to autograd?
@mlamarre it would be helpful to know what your expectations on these other matters are too.
My personal use case requires adding metadata to a tensor. It's an external resource descriptor that I update and use in the custom `Function`.
So I prefer a documentation fix rather than the … About the questions:

- "What if you pass in a grad_output that is a tensor subclass?" In the example it's already a subclass (LoggingTensor) because …
- "What if you return a tensor subclass from a custom function?" Also the case in the example, …
- "What is the subclass of x.grad?" If you mean …
- "Should it be possible to override operator calls that are internal to autograd?" I don't understand this one.
Thanks. At least a doc fix is easy for us to do :)
I don't understand exactly what needs to be done here. Can someone explain in more detail what example needs to be added? Hopefully I can then do the explaining.
General idea of the change: somewhere in this section there must be a warning that tensors saved via `save_for_backward` are unpacked as plain `torch.Tensor`s, so any subclass information is lost.
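For concreteness, the behavior such a warning would describe looks like this (a sketch against the versions affected by this issue; `MyTensor` is a stand-in subclass):

```python
import torch

class MyTensor(torch.Tensor):
    pass

class Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # on affected versions this prints <class 'torch.Tensor'>, not
        # <class 'MyTensor'>: the subclass is sliced away on unpacking
        print(type(x))
        return grad_output

t = MyTensor(torch.rand(3))
t.requires_grad_(True)
Identity.apply(t).sum().backward()
```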
BTW @hameerabbasi, there are probably a bunch more functions like this, especially the autograd API functions.
I think a more accurate statement would be the following: the …
I've created a PR to fix the docs (#51031).
What is the status of this issue with regard to subclasses with metadata? Why would it be a problem to have "reference cycles"? Python's garbage collector (I've heard) can handle reference cycles. Moreover, in many cases the subclass will behave exactly the same as a base Tensor with regard to autograd: if it just carries around some metadata (and maybe modifies it based on the operations it encounters), or if, say, a LoggingTensor wants to also log the operations in the backward pass. Currently, for intermediate gradient calculations, …
First off, we accidentally closed this; the doc fix isn't a feature fix, and we haven't done the latter.
It matters because the tensors are fairly large, so you care quite a bit about prompt deallocation (lest you OOM). Python GC can detect reference cycles, but not promptly, so you might end up OOMing before GC gets around to running. (BTW, #50186 is supposed to make this better and I really need to get around to reviewing it.) A `__reduce__`-style method of being saved would probably work, and is probably worth doing.
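A sketch of what a reduce-style save could look like, assuming a hypothetical `reduce_for_save`/`reconstruct` pair (names made up here): the subclass is split into a plain tensor plus its extra state before saving, so autograd would hold no strong reference to the subclass instance, and the pair is reassembled on unpack:

```python
import torch

class MetaTensor(torch.Tensor):
    metadata = None  # non-tensor state (hypothetical attribute name)

# hypothetical protocol, analogous to pickle's __reduce__:
def reduce_for_save(t):
    # only the plain tensor would enter autograd's saved-variable machinery
    return t.as_subclass(torch.Tensor), {"metadata": t.metadata}

def reconstruct(base, state):
    out = base.as_subclass(MetaTensor)
    out.metadata = state["metadata"]
    return out

t = MetaTensor(torch.rand(3))
t.metadata = "resource descriptor"
base, state = reduce_for_save(t)
assert type(base) is torch.Tensor
assert reconstruct(base, state).metadata == "resource descriptor"
```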
I tried the example above on the master branch (commit 73f1e2d), and it seems to have fixed the issue.
However, when I add some torch functions in between, the saved tensors come back as plain `torch.Tensor` in `backward`.
@ken012git Yes, @albanD landed a change to unconditionally save the original Tensor (without shallow-copying and detaching it first), so this issue is technically fixed. However...
To do this, you need an even fancier mechanism that we are currently alpha-testing internally. It's on master, so you can try it there; however, we are probably going to make some BC-breaking API changes in the future (related to kwarg handling). Check #59760 and the quoted issue.
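If the mechanism referred to here is `__torch_dispatch__` (my assumption; it matches the timeline and the kwarg-handling caveat), a minimal sketch of the wrapper-subclass pattern it enables; `LoggingDispatchTensor` and `elem` are made-up names. Because `__torch_dispatch__` interposes below autograd, it also sees the ops autograd runs internally during backward:

```python
import torch
from torch.utils._pytree import tree_map

class LoggingDispatchTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # wrapper subclass: the instance itself holds no storage
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device,
            requires_grad=elem.requires_grad)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # runs below autograd, so backward ops are visible too
        print(f"dispatch: {func}")
        unwrap = lambda t: t.elem if isinstance(t, cls) else t
        wrap = lambda t: cls(t) if isinstance(t, torch.Tensor) else t
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(wrap, out)

x = LoggingDispatchTensor(torch.rand(3))
y = x + 1  # prints the underlying aten op(s) as they are dispatched
```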
I believe that "unconditionally" is not technically true here? (See pytorch/torch/csrc/autograd/saved_variable.cpp, lines 69 to 73 at 4a390a5.)
Edit: after discussion with @albanD, here is a case where the original tensor is not saved:

```python
import torch
import logging

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # following line was changed from the documentation to avoid an
        # infinite recursion because of __repr__
        logging.info(f"func: {func.__name__}")
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

class IdentityFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        print('forward x type', type(x), 'x data_ptr', x.data_ptr())
        # y is also the output of the Function; in that case autograd does
        # not keep the original object and reconstructs it on unpacking
        y = x.clone()
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        y, = ctx.saved_tensors
        print('backward y type', type(y), 'y data_ptr', y.data_ptr())
        print('backward grad_output type', type(grad_output),
              'grad_output data_ptr', grad_output.data_ptr())
        return grad_output

lt = LoggingTensor(torch.rand((3, 3)))
lt.requires_grad_(True)
y = IdentityFunction.apply(lt)
y.sum().backward()
```
In general we're not necessarily saving the original tensor; for instance:

```python
a = torch.randn(5, requires_grad=True)
y = torch.exp(a)
print(y.grad_fn._saved_result.equal(y))  # True
print(y.grad_fn._saved_result is y)      # False
```

This stack was an attempt to fix it: #60399, but it was abandoned. I think an alternative way is to expose a post-unpacking hook, which would be called around line 187, after we create a new tensor and give it its grad metadata (pytorch/torch/csrc/autograd/saved_variable.cpp, lines 181 to 186 at 4a390a5).
Edit: a better way is to do #63485.
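For reference, user-facing pack/unpack hooks for saved tensors did land later as `torch.autograd.graph.saved_tensors_hooks`. A sketch of how a metadata-carrying subclass could survive saving with them; `MetaTensor` and its `metadata` attribute are made-up names:

```python
import torch

class MetaTensor(torch.Tensor):
    metadata = None  # non-tensor state (hypothetical attribute name)

def pack(t):
    # decompose into (plain tensor, extra state) before autograd stores it
    if isinstance(t, MetaTensor):
        return t.as_subclass(torch.Tensor), t.metadata
    return t, None

def unpack(packed):
    base, metadata = packed
    if metadata is None:
        return base
    out = base.as_subclass(MetaTensor)
    out.metadata = metadata
    return out

x = MetaTensor(torch.rand(3))
x.requires_grad_(True)
x.metadata = "resource descriptor"

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = (x * x).sum()  # mul's saved inputs go through pack()
y.backward()           # saved tensors come back through unpack()
```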
🐛 Bug
Saving a `torch.Tensor` subclass with `ctx.save_for_backward` only saves the base Tensor; the subclass type and additional data are removed (object slicing, in C++ terminology).

To Reproduce
Following the Extending PyTorch doc; `LoggingTensor` is copy-pasted from there.

Expected behavior
I would expect the subclass type to be preserved.
Expected: the tensor unpacked from `ctx.saved_tensors` in `backward` is still a `LoggingTensor`.
Current result: it comes back as a plain `torch.Tensor`.
Environment
cc @hameerabbasi @rgommers @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @jlin27 @mruberry