ctx.save_for_backward doesn't save torch.Tensor subclasses fully #47117

Open · mlamarre opened this issue Oct 30, 2020 · 26 comments

Labels: enhancement, module: autograd, module: docs, module: __torch_function__, needs design, triaged

Comments

@mlamarre commented Oct 30, 2020

🐛 Bug

Saving a torch.Tensor subclass with ctx.save_for_backward only saves the base Tensor; the subclass type and any additional data are lost (object slicing, in C++ terminology).

To Reproduce

Following the Extending PyTorch doc; LoggingTensor is copy-pasted from there.

import torch
import logging
class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # following line was changed from the documentation to avoid an infinite recursion because of __repr__
        logging.info(f"func: {func.__name__}")
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

class SquareFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        print('forward x type',type(x),'x data_ptr',x.data_ptr())
        ctx.save_for_backward(x)
        y = torch.mul(x,x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        print('backward x type',type(x),'x data_ptr',x.data_ptr())
        return 2*x*grad_output

lt = LoggingTensor(torch.rand((3,3)))
lt.requires_grad_(True)
y = SquareFunction.apply(lt)
y.backward(torch.ones_like(y))
assert(lt.grad is not None) # that works

Expected behavior

I would expect the subclass type to be preserved.

Expected:

forward x type <class '__main__.LoggingTensor'> x data_ptr 1715819930816
backward x type <class '__main__.LoggingTensor'> x data_ptr 1715819930816

Current result:

forward x type <class '__main__.LoggingTensor'> x data_ptr 1715819930816
backward x type <class 'torch.Tensor'> x data_ptr 1715819930816

Environment

Collecting environment information...
PyTorch version: 1.7.0
Is debug build: True
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Enterprise
CMake version: version 3.18.0

Python version: 3.7 (64-bit runtime)
Is CUDA available: True

Versions of relevant libraries:
[pip] numpy==1.18.1
[pip] torch==1.7.0

[conda] blas                      1.0                         mkl
[conda] cudatoolkit               11.0.221             h74a9793_0
[conda] mkl                       2020.1                      216
[conda] mkl-service               2.3.0            py37hb782905_0
[conda] mkl_fft                   1.0.15           py37h14836fe_0
[conda] mkl_random                1.1.1            py37h47e9c7a_0
[conda] numpy                     1.18.1           py37h93ca92e_0
[conda] numpy-base                1.18.1           py37hc3f5095_1
[conda] pytorch                   1.7.0           py3.7_cuda110_cudnn8_0    pytorch

cc @hameerabbasi @rgommers @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @jlin27 @mruberry

@ezyang (Contributor) commented Oct 30, 2020

This is a nasty one. We actually do some very strange stuff to tensors (unpack them into "saved tensors") when you pass them into save_for_backward, so it is not surprising that we lose all this information. Note that we can't just directly save the Python object, because that is likely to result in reference cycles (as the saved tensors themselves contain references to the autograd graph, which in turn retains them). But if we don't save the Python object, then we can't save any of the extra metadata you might be interested in.

For non-metadata-carrying subclasses we could probably make this work by just saving the subclass type and using as_subclass on the way out. Metadata will be trickier.
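
As a user-side illustration of that idea, here is a minimal sketch assuming a metadata-free subclass; the ctx._input_cls attribute is invented for this example:

import torch

class SquareFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx._input_cls = type(x)  # hypothetical: remember the subclass type
        ctx.save_for_backward(x)
        return torch.mul(x, x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # rewrap the plain Tensor; only valid when the subclass carries no metadata
        x = x.as_subclass(ctx._input_cls)
        return 2 * x * grad_output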

cc @albanD

albanD added the enhancement, module: __torch_function__, module: autograd, triaged, and needs design labels on Oct 30, 2020
@albanD (Collaborator) commented Oct 30, 2020

Hi,

I just wanted to confirm that Tensor-like objects are off the table here, right? They wouldn't work with custom Functions at all.

For Tensor subclasses, I am curious how they are implemented in detail.
In particular, do they contain, just like regular Tensors, a cdata field in C that holds the corresponding C++ Tensor? And how do they work with other C++-implemented functions?

@hameerabbasi (Collaborator):

From the design side of things, I think the usual mytensorlike.__torch_function__(FunctionClass.apply, types, args, kwargs) might work to wrap it back into a subclass. For the extra metadata, one could do the processing in __torch_function__ as usual.

@albanD (Collaborator) commented Nov 2, 2020

But then the user would have to re-implement the apply()? Or just do the wrapping and call into the original apply with only true Tensor objects?

@hameerabbasi (Collaborator):

But then the user would have to re-implement the apply()? Or just do the wrapping and call into the original apply with only true Tensor objects?

Subclassing would then work OOTB (without anything extra on the user side); if they wanted to add metadata, they could use __torch_function__ to do that.

@albanD (Collaborator) commented Nov 2, 2020

Sorry, I am not super familiar with how this would work.
Can you show how the example above should be updated if we add the proper call to __torch_function__ in apply?

@hameerabbasi (Collaborator):

Something like this:

class SubTensor(Tensor):
    def __init__(self, ...):
        super().__init__(...)
        self._init_wo_super(...)

    def _init_wo_super(self, ...):
        ...

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        ret = super().__torch_function__(func, types, args, kwargs)
        ret._init_wo_super(...)
        return ret

@albanD (Collaborator) commented Nov 2, 2020

I am not sure what that would mean for the user in terms of types? Does that mean that the autograd.Function only ever works with plain Tensors and the user has to do all the proper unpacking around it?
Also, does it mean that the backward only ever runs with plain Tensors (as they were unpacked before the forward)?

@hameerabbasi (Collaborator):

I am not sure what that would mean for the user in terms of types?

This means, essentially, that PyTorch only defines the behavior in terms of Tensors. We basically change the type to whatever subclass was provided, but little else.

It's up to the subclass to decide what to do next: adding metadata, etc. For other consumers of __torch_function__ (I assume this is what you mean by user types), it'll be up to them to define what the combination of the type and the particular Function instance means. They may decide to, for example, perform the operation on some wrapped tensor à la NestedTensor. Or they may decide to say it's not supported, initially or permanently.

@albanD (Collaborator) commented Nov 3, 2020

Thanks for all the details @hameerabbasi !

So I'm happy with saying that the .apply() function should call into __torch_function__ and that the custom Function only ever deals with plain Tensors.
It is the user's responsibility to rewrap the Tensors on exit.
@ezyang, does that sound like a good solution to you?

@ezyang (Contributor) commented Nov 5, 2020

I'm a little skeptical. If I'm using __torch_function__ just to implement something like a LoggingTensor subclass, then shunting all calls to CustomFunc.apply straight to __torch_function__ seems unwanted and unnecessary. I'm only a little skeptical, though, because keeping autograd a Tensor-only shop certainly does make things a lot simpler on the implementation side.

One question about routing .apply() directly into __torch_function__ is whether a user can actually implement the functionality in question by hand. In particular, @mlamarre wants to run backward() with a logging tensor for the saved backward tensors. Let's suppose he is now implementing a LoggingTensor-specific version of the custom function (ugh). To get custom autograd, he still needs to write an autograd Function. So it would have to look something like this:

class SquareFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        print('forward x type',type(x),'x data_ptr',x.data_ptr())
        ctx.save_for_backward(x)
        y = torch.mul(x,x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        x = LoggingTensor(x)  # [NEW]
        print('backward x type',type(x),'x data_ptr',x.data_ptr())
        return 2*x*grad_output

Of course, this particular modification only really works because LoggingTensor doesn't carry any extra data.
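
To make that concrete, a small sketch (TaggedTensor and its tag attribute are hypothetical): as_subclass restores the type but not per-instance metadata.

import torch

class TaggedTensor(torch.Tensor):
    tag = None  # per-instance metadata, set after construction

t = TaggedTensor(torch.rand(3))
t.tag = "resource-42"

plain = t.as_subclass(torch.Tensor)       # simulates the type erasure in save_for_backward
rewrapped = plain.as_subclass(TaggedTensor)
print(type(rewrapped).__name__)           # TaggedTensor -- the type is back
print(rewrapped.tag)                      # None -- the metadata is not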

There are a few more orthogonal problems which might also need to be solved to support mlamarre's use case:

  • What if you pass in a grad_output that is a tensor subclass?
  • What if you return a tensor subclass from a custom function?
  • What is the subclass of x.grad?
  • Should it be possible to override operator calls that are internal to autograd?

@mlamarre it would be helpful to know what your expectations on these other matters are, too.

@mlamarre (Author) commented Nov 6, 2020

My personal use case requires adding metadata to a tensor. It's an external resource descriptor that I update and use in the custom forward method. For backward, I only need to set a dirty flag on the metadata. Before 1.7, I had to pass this metadata as an extra argument to forward and return None for the corresponding gradient output. Subclassing already makes the forward part more compact (2x fewer args for this type) and less error prone. Setting the dirty flag on the metadata when the grad is not null in backward would be nice, but I have workarounds. I reported this issue mostly because ctx.save_for_backward behaves differently from all other functions.
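
For concreteness, a sketch of that pre-1.7 pattern, with hypothetical names reconstructed from the description above:

import torch

class SquareWithDescriptor(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, descriptor):
        ctx.descriptor = descriptor        # non-tensor metadata can live on ctx
        ctx.save_for_backward(x)
        return torch.mul(x, x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        ctx.descriptor.dirty = True        # the dirty-flag workaround
        return 2 * x * grad_output, None   # None: no gradient for the metadata arg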

Having .apply() call into __torch_function__ would be ugly in my particular case: I would need to put the custom code that uses the metadata in a callback that's registered on the subclass and called from __torch_function__, making my custom Function an empty shell whose only purpose is to be a string that can be parsed in __torch_function__.

So I would prefer a documentation fix over the .apply()-calls-into-__torch_function__ solution.

About the questions:

What if you pass in a grad_output that is a tensor subclass?

In the example it's already a subclass (LoggingTensor), because torch.ones_like goes through __torch_function__.

What if you return a tensor subclass from a custom function?

That's also the case in the example: y is a subclass, and the backward method also goes through __torch_function__.

What is the subclass of x.grad?

If you mean lt.grad, again it's a LoggingTensor.

Should it be possible to override operator calls that are internal to autograd?

I don't understand this one.

ezyang added the module: docs label on Nov 9, 2020
@ezyang (Contributor) commented Nov 9, 2020

Thanks. At least the doc fix is easy for us to do :)

@hameerabbasi (Collaborator):

I don't understand exactly what needs to be done here. Can someone explain in more detail what example needs to be added? Hopefully I can then do the explaining.

@mlamarre (Author) commented Nov 9, 2020

General idea of the change:
In Extending PyTorch in the subsection Subclassing torch.Tensor :

As of version 1.7.0, methods and functions applied on torch.Tensor subclasses will return subclass instances instead of torch.Tensor instances:

Somewhere in this section there must be a warning that ctx.save_for_backward doesn't save the subclass type or any metadata, hence ctx.saved_tensors only contains plain torch.Tensor instances. A single sentence in this section would be sufficient.

@ezyang (Contributor) commented Nov 10, 2020

BTW @hameerabbasi, there are probably a bunch more functions like this, esp. the autograd API functions (like backward() and grad()); unless you actually added wrappers for these (I don't remember if you did, but I'm thinking of gradcheck now...)

@hameerabbasi (Collaborator) commented Nov 10, 2020

I think a more accurate statement would be the following.

All Tensor methods, and functions in the torch and torch.* namespaces, excluding methods of types other than Tensor.

@mlamarre (Author):

The backward() method goes through __torch_function__, and the grad attribute has the subclass type in the example.

@hameerabbasi (Collaborator):

I've created a PR to fix the docs (#51031)

@kosiokarchev:

What is the status of this issue with regard to subclasses with metadata? Why would it be a problem to have "reference cycles"? Python's garbage collector (I've heard) can handle reference cycles. Moreover, in many cases the subclass will behave exactly the same as a base Tensor with regard to autograd: if it just carries around some metadata (and maybe modifies it based on the operations it encounters); or if, say, a LoggingTensor wants to also log the operations in the backward pass. Currently, for intermediate gradient calculations, __torch_function__ / as_subclass don't seem to be invoked. At least there could be a switch on the subclass to indicate that it wants to be remembered as a Python object, or to provide some reduce-style method of being safely saved in a computational graph.
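
As one reading of the reduce-style suggestion, a sketch with entirely hypothetical method names (no such protocol exists in PyTorch at this point in the thread):

import torch

class TaggedTensor(torch.Tensor):
    tag = None

    # Hypothetical: split into a plain Tensor (safe to stash in the autograd
    # graph without keeping the Python object alive) plus picklable metadata.
    def __save_reduce__(self):
        return self.as_subclass(torch.Tensor), {"tag": self.tag}

    # Hypothetical: rebuild the subclass when the saved tensor is unpacked.
    @classmethod
    def __save_rebuild__(cls, plain, meta):
        t = plain.as_subclass(cls)
        t.tag = meta["tag"]
        return t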

ezyang reopened this on Mar 31, 2021
@ezyang (Contributor) commented Mar 31, 2021

First off, we accidentally closed this; the doc fix isn't a feature fix and we haven't done that.

Why would it be a problem to have "reference cycles"? Python's garbage collector (I've heard) can handle reference cycles.

It matters because the tensors are fairly large, so you care quite a bit about prompt deallocation (lest you OOM). Python GC can detect reference cycles but not promptly, and so you might end up OOMing before GC gets around to running. (btw #50186 is supposed to make this better and I really need to get around to reviewing it.)
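
A small CPython illustration of the promptness point: acyclic objects die immediately via reference counting, while cyclic ones linger until the collector runs.

import gc
import weakref

class Node:
    pass

a = Node()
r = weakref.ref(a)
del a
print(r() is None)   # True: freed promptly by reference counting

b = Node()
b.cycle = b          # self-reference creates a cycle
r = weakref.ref(b)
del b
print(r() is None)   # False: still alive, waiting for the garbage collector
gc.collect()
print(r() is None)   # True: freed only once the collector ran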

A reduce-style method of being saved would probably work, and is probably worth doing.

@ken012git commented Jul 28, 2021

I'm a little skeptical. If I'm using __torch_function__ just to implement something like a LoggingTensor subclass, then shunting all calls to CustomFunc.apply straight to __torch_function__ seems unwanted and unnecessary. I'm only a little skeptical, though, because keeping autograd a Tensor-only shop certainly does make things a lot simpler on the implementation side.

One question about routing .apply() directly into __torch_function__ is whether a user can actually implement the functionality in question by hand. In particular, @mlamarre wants to run backward() with a logging tensor for the saved backward tensors. Let's suppose he is now implementing a LoggingTensor-specific version of the custom function (ugh). To get custom autograd, he still needs to write an autograd Function. So it would have to look something like this:

class SquareFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        print('forward x type',type(x),'x data_ptr',x.data_ptr())
        ctx.save_for_backward(x)
        y = torch.mul(x,x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        x = LoggingTensor(x)  # [NEW]
        print('backward x type',type(x),'x data_ptr',x.data_ptr())
        return 2*x*grad_output

I tried adding a LoggingTensorBackward that converts Tensor to LoggingTensor in backward, but the layer still seems to receive Tensor instead of LoggingTensor.

import torch
import logging

class LoggingTensor(torch.Tensor):

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # following line was changed from the documentation to avoid an infinite recursion because of __repr__
        logging.info(f"func: {func.__name__}")
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

class LoggingTensorBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = LoggingTensor(grad_output)
        print('backward grad_input type',type(grad_input),'grad_input data_ptr',grad_input.data_ptr())
        return grad_input

class SquareFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        print('forward x type',type(x),'x data_ptr',x.data_ptr())
        ctx.save_for_backward(x)
        y = torch.mul(x,x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        print('backward x type',type(x),'x data_ptr',x.data_ptr())
        print('backward grad_output type',type(grad_output),'grad_output data_ptr',grad_output.data_ptr())
        return 2*x*grad_output

lt = LoggingTensor(torch.rand((3,3)))
lt.requires_grad_(True)
y = SquareFunction.apply(lt)
y = LoggingTensorBackward.apply(y)
y.backward(torch.ones_like(y))
assert(lt.grad is not None) # that works

output:

forward x type <class '__main__.LoggingTensor'> x data_ptr 94886885498048
backward grad_input type <class '__main__.LoggingTensor'> grad_input data_ptr 94886934698944
backward x type <class 'torch.Tensor'> x data_ptr 94886885498048
backward grad_output type <class 'torch.Tensor'> grad_output data_ptr 94886934698944  <- expected to be LoggingTensor ?

@ken012git commented Jul 28, 2021

The master branch (commit 73f1e2d) seems to have fixed the issue:

forward x type <class '__main__.LoggingTensor'> x data_ptr 94490178958336
backward grad_input type <class '__main__.LoggingTensor'> grad_input data_ptr 94490178887296
backward x type <class '__main__.LoggingTensor'> x data_ptr 94490178958336
backward grad_output type <class '__main__.LoggingTensor'> grad_output data_ptr 94490178887296

@ken012git commented Jul 29, 2021

However, when I add some torch functions in between, backward receives plain torch.Tensor again.
How can subclass tensors such as LoggingTensor be propagated through the backward computation graph?

lt = LoggingTensor(torch.rand((3,3)))
lt.requires_grad_(True)
y = SquareFunction.apply(lt)
# y = torch.clamp(y, -100, 100)  # <- gives 'torch.Tensor'
# y = torch.div(y, -100)  # <- gives 'torch.Tensor'
y = y / -100  # <- gives 'torch.Tensor'
y = LoggingTensorBackward.apply(y)
y.backward(torch.ones_like(y))
assert(lt.grad is not None) 

output:

forward x type <class '__main__.LoggingTensor'> x data_ptr 94712420593024
backward grad_input type <class '__main__.LoggingTensor'> grad_input data_ptr 94712420499648
backward x type <class '__main__.LoggingTensor'> x data_ptr 94712420593024
backward grad_output type <class 'torch.Tensor'> grad_output data_ptr 94712420659520

@ezyang (Contributor) commented Jul 30, 2021

The master branch (commit 73f1e2d) seems to have fixed the issue

@ken012git Yes, @albanD landed a change to unconditionally save the original Tensor (without shallow-copying and detaching it first), so this issue is technically fixed. However...

How to propagate subclass tensors such as LoggingTensor through the backward computation graph?

To do this, you need an even fancier mechanism that we are currently alpha testing internally. It's on master, so you can try it on master; however, we are probably going to make some BC-breaking API changes in the future (related to kwarg handling). Check #59760 and the quoted issue.

@Varal7 (Contributor) commented Aug 18, 2021

@ken012git Yes, @albanD landed a change to unconditionally save the original Tensor (without shallow-copying and detaching it first), so this issue is technically fixed. However...

I believe that "unconditionally" is not technically true here?
The check is whether the tensor being saved is a leaf or is not an output (meaning not the output of the Node that saves it, here the custom function):

if (!is_output || is_leaf_) {
  saved_original_ = true;
  data_ = variable;
  return;
}

But with custom functions, that can never be the case (I think??), so this issue is technically fixed.

Edit: after discussion with @albanD, here is a case where the original tensor is not saved:

import torch
import logging

class LoggingTensor(torch.Tensor):

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # following line was changed from the documentation to avoid an infinite recursion because of __repr__
        logging.info(f"func: {func.__name__}")
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

class IdentityFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        print('forward x type',type(x),'x data_ptr',x.data_ptr())
        y = x.clone()
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        y, = ctx.saved_tensors
        print('backward y type',type(y),'y data_ptr',y.data_ptr())
        print('backward grad_output type',type(grad_output),'grad_output data_ptr',grad_output.data_ptr())
        return grad_output
    

lt = LoggingTensor(torch.rand((3,3)))
lt.requires_grad_(True)
y = IdentityFunction.apply(lt)
y.sum().backward()

output:

forward x type <class '__main__.LoggingTensor'> x data_ptr 93982988949312
backward y type <class 'torch.Tensor'> y data_ptr 93982988489664
backward grad_output type <class 'torch.Tensor'> grad_output data_ptr 93982988936000

In general we're not necessarily saving the original tensor, for instance:

a = torch.randn(5, requires_grad=True)
y = torch.exp(a)
print(y.grad_fn._saved_result.equal(y))  # True
print(y.grad_fn._saved_result is y) # False

This stack was an attempt to fix it: #60399, but was abandoned.

I think an alternative way is to expose a post-unpacking hook, which would be called at line 187, after we create a new tensor and give it its grad metadata:

Variable var;
if (grad_fn) {
  var = make_variable(data, Edge(std::move(grad_fn), output_nr_));
} else {
  var = make_variable(data, requires_grad_);
}

Edit: a better way is to do #63485
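
For comparison, later PyTorch releases expose a pack/unpack hook pair via torch.autograd.graph.saved_tensors_hooks; a sketch assuming that API, reusing LoggingTensor and IdentityFunction from above (whether this corresponds to the PR referenced here is not confirmed):

import torch

def pack_hook(t):
    # stash the subclass type next to a plain Tensor
    return type(t), t.as_subclass(torch.Tensor)

def unpack_hook(packed):
    cls, t = packed
    return t.as_subclass(cls)  # rewrap on unpack; metadata would need extra handling

lt = LoggingTensor(torch.rand(3, 3))
lt.requires_grad_(True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = IdentityFunction.apply(lt)
y.sum().backward()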
