Currently, we don't allow in-place operations on views if there are any other live Variables sharing the same data. We should allow in-place operations and compute the gradient correctly. There are a few changes necessary to support this.
Viewing operations include `view`, `narrow`, `select`, `expand`, `[un]squeeze`, and `unfold`. By "a view" I mean a Variable that is a view on another Variable.
- Add a field "base" to Variable. Every view has a pointer to a single base Variable. (The base is never a view)
- In-place operations on views change the
grad_fn
of the base, not of the view. - The
grad_fn
on a view may become stale. So views also store anexpected_version
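A rough sketch of the bookkeeping these bullets imply; apart from `base` and `expected_version`, the names here are mine and purely illustrative:

```python
class Variable:
    def __init__(self, data, base=None):
        self.data = data
        self.grad_fn = None
        self.base = base          # None unless this Variable is a view
        self._version = 0         # bumped on every in-place modification
        # For views: the base version this view's grad_fn was built against.
        self.expected_version = base._version if base is not None else None

    def is_view(self):
        return self.base is not None

    def bump_version(self):
        # In-place ops bump the version on the base, so that sibling views
        # can detect that their grad_fn has gone stale.
        (self.base or self)._version += 1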
## In-place operations on views

Every in-place operation on a view inserts three Functions in front of the old `base.grad_fn`. For example:

```python
x = base[2]
x.mul_(y)
```

becomes:

```
Stride(MulBackward(Unstride(<old base.grad_fn>), <y.grad_fn>))
```
Assume, for now, that `base` is contiguous. The Stride op takes the `grad_output`, makes it contiguous, and applies the `stride`, `sizes`, and `storageOffset` of the view (`x`). This performs the identical viewing operation on `grad_output` that was performed on `base` to get `x`. For the above example, this is equivalent to:

```python
grad_input = grad_output[2]
```
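In today's Python API, `as_strided` can stand in for the restride primitive proposed below, so a hypothetical Python rendering of Stride (the proposal itself is for C++) might look like:

```python
import torch

def stride_backward(grad_output, sizes, strides, storage_offset):
    # Make the incoming gradient contiguous so the view's strides index
    # into a freshly laid-out buffer, then re-apply the view's geometry.
    return grad_output.contiguous().as_strided(sizes, strides, storage_offset)
```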
Unstride performs the reverse operation. It creates a `grad_input` with the same shape as `base`, applies the striding of the view `x` to `grad_input`, and copies `grad_output` into that view. For the above example, this is equivalent to:

```python
grad_input = zeros_like(base)
grad_input[2] = grad_output
```
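Continuing the same hypothetical Python rendering, Unstride might be (still assuming `base` is contiguous):

```python
import torch

def unstride_backward(grad_output, base, sizes, strides, storage_offset):
    # Start from a zero gradient with the base's shape, then scatter
    # grad_output into the region of storage that the view x covered.
    grad_input = torch.zeros_like(base)
    grad_input.as_strided(sizes, strides, storage_offset).copy_(grad_output)
    return grad_input
```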
`Stride` and `Unstride` are both implemented in terms of a differentiable restride operation that generalizes all of our viewing operations:

```cpp
Tensor restride(Tensor self, IntList sizes, IntList strides, int offset);
```
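For intuition, `as_strided` in the current Python API already generalizes the viewing operations in the same way (examples mine):

```python
import torch

t = torch.arange(12).view(3, 4)               # contiguous, strides (4, 1)
assert torch.equal(t.as_strided((4, 3), (3, 1)),    t.view(4, 3))       # view
assert torch.equal(t.as_strided((2, 4), (4, 1), 4), t.narrow(0, 1, 2))  # narrow
assert torch.equal(t.as_strided((4,), (1,), 8),     t.select(0, 2))     # select
```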
## Stale `grad_fn` on views

An in-place operation on a view makes the `grad_fn` of all other views stale. `compute_next_functions` needs to be updated to handle views: it should re-create the `Function` for a view if it is stale before returning it.
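A hypothetical sketch of that staleness check; `rebuild_view_grad_fn` and the field names are placeholders of mine, not part of the proposal:

```python
def grad_fn_for(var):
    # Views whose expected_version no longer matches the base's version were
    # invalidated by an in-place op on the base or on a sibling view.
    if var.base is not None and var.expected_version != var.base._version:
        var.grad_fn = rebuild_view_grad_fn(var)   # hypothetical helper
        var.expected_version = var.base._version
    return var.grad_fn
```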
## Non-contiguous base Variables

If `base` is not contiguous, then we also need the strides and sizes of the base in `restride`:

```cpp
Tensor restride(Tensor self, IntList sizes, IntList strides,
                IntList base_sizes, IntList base_strides, int offset) {
  // Materialize self with the base's (non-contiguous) layout first,
  // then apply the view's striding to that copy.
  Tensor copy = self.type().tensor(base_sizes, base_strides);
  copy.copy_(self);
  return restride(copy, sizes, strides, offset);
}
```
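A possible Python analogue of this overload, using `torch.empty_strided` to stand in for `self.type().tensor(sizes, strides)` (sketch of mine):

```python
import torch

def restride_noncontig(self, sizes, strides, base_sizes, base_strides, offset):
    # Lay the gradient out in the base's non-contiguous geometry first...
    copy = torch.empty_strided(base_sizes, base_strides, dtype=self.dtype)
    copy.copy_(self)
    # ...then carve the view's geometry out of that copy.
    return copy.as_strided(sizes, strides, offset)
```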