
Support in-place operations on Variable views #3313

@colesbury

Description


Currently, we don't allow in-place operations on views if there are any other live Variables sharing the same data. We should allow these operations and still compute the gradient correctly. A few changes are necessary to support this.

Viewing operations include view, narrow, select, expand, [un]squeeze, and unfold. By "a view" I mean a Variable that is a view on another Variable.
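
For concreteness, here is a minimal sketch of the pattern in question, written against today's merged Tensor/Variable API (the names are illustrative):

```python
import torch

leaf = torch.randn(3, 4, requires_grad=True)
base = leaf.clone()    # non-leaf copy, so in-place ops are permitted
x = base[2]            # `x` is a view of `base`: same storage, new geometry
y = torch.randn(4)

x.mul_(y)              # in-place op on the view also mutates `base`'s data
base.sum().backward()  # the gradient must route through the in-place update
print(leaf.grad[2])    # equals y; every other row is all ones
```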

  1. Add a field "base" to Variable. Every view has a pointer to a single base Variable. (The base is never itself a view.)
  2. In-place operations on views change the grad_fn of the base, not of the view.
  3. The grad_fn on a view may become stale, so views also store an expected_version (see the sketch after this list).
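
A hypothetical Python rendering of this metadata, just to fix ideas (the field names come from the list above; the class is an illustration, not the actual C++ layout):

```python
class Variable:
    def __init__(self, data):
        self.data = data
        self.grad_fn = None
        self.version = 0           # bumped on every in-place modification;
                                   # in practice shared by all Variables
                                   # viewing the same data
        self.base = None           # set iff this Variable is a view
        self.expected_version = 0  # data version this view's grad_fn matches

    def is_view(self):
        return self.base is not None

    def grad_fn_is_stale(self):
        # stale once any in-place op has bumped the shared version counter
        return self.is_view() and self.version != self.expected_version
```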

In-place operations on views

Every in-place operation on a view replaces base.grad_fn with a stack of three Functions wrapping the old one. For example:

x = base[2]
x.mul_(y)

becomes:

Stride(MulBackward(Unstride(<old base.grad_fn>), <y.grad_fn>))

Assume, for now, that base is contiguous. The Stride op takes grad_output, makes it contiguous, and applies the sizes, strides, and storage offset of the view (x). This performs on grad_output the same viewing operation that was performed on base to get x. For the above example, this is equivalent to:

grad_input = grad_output[2]
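
A sketch of Stride in Python, using as_strided as a stand-in for the restride operation introduced below (the function name and parameters here are assumptions):

```python
import torch

def stride_backward(grad_output, view_sizes, view_strides, view_offset):
    # Perform on grad_output the same viewing operation that produced x
    # from base: make it contiguous, then reapply the view's geometry.
    return grad_output.contiguous().as_strided(view_sizes, view_strides, view_offset)

# For `x = base[2]` with a contiguous (3, 4) base:
#   stride_backward(g, (4,), (1,), 8)  ==  g[2]
```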

Unstride performs the reverse operation. It creates a grad_input with the same shape as base, applies the striding of the view x to grad_input, and copies grad_output into that view. For the above example, this is equivalent to:

grad_input = zeros_like(base)
grad_input[2] = grad_output
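
And the corresponding sketch of Unstride (same caveats: as_strided stands in for restride, and the names are assumptions):

```python
def unstride_backward(grad_output, base_sizes, view_sizes, view_strides, view_offset):
    # Embed the view's gradient into a zero gradient shaped like base.
    grad_input = grad_output.new_zeros(base_sizes)
    grad_input.as_strided(view_sizes, view_strides, view_offset).copy_(grad_output)
    return grad_input

# For `x = base[2]` with a contiguous (3, 4) base,
#   unstride_backward(g, (3, 4), (4,), (1,), 8)
# is zero everywhere except row 2, which equals g.
```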

Stride and Unstride are both implemented in terms of a differentiable restride operation that generalizes all of our viewing operations:

Tensor restride(Tensor self, IntList sizes, IntList strides, int offset);
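
Today's as_strided is the closest existing analogue. A quick illustration of how a single restride-style call reproduces a viewing op (numbers assume a contiguous 3×4 tensor):

```python
import torch

base = torch.arange(12.).reshape(3, 4)  # contiguous: strides (4, 1)
row = base.as_strided((4,), (1,), 8)    # sizes (4,), strides (1,), offset 2 * 4
assert torch.equal(row, base[2])        # identical to the select
```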

Stale grad_fn on views

An in-place operation on a view makes the grad_fn of every other view of the same base stale. compute_next_functions needs to be updated to handle views: if a view's grad_fn is stale, it should re-create the Function before returning it.
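
In pseudocode, the check might look like this (everything here apart from grad_fn is an assumption about the eventual implementation, including the helper name):

```python
def next_function_for(view):
    # Called from compute_next_functions while wiring up the graph.
    if view.version != view.expected_version:
        # Another view (or the base) was modified in place since this
        # view's grad_fn was created: rebuild it from the base via restride.
        view.grad_fn = make_restride_grad_fn(view.base, view)  # hypothetical helper
        view.expected_version = view.version
    return view.grad_fn
```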

Non-contiguous base Variables

If base is not contiguous, then restride also needs the base's sizes and strides:

Tensor restride(Tensor self, IntList sizes, IntList strides, IntList base_sizes, IntList base_strides, int offset) {
  // Materialize self in the base's (possibly non-contiguous) layout...
  Tensor copy = self.type().tensor(base_sizes, base_strides);
  copy.copy_(self);
  // ...then apply the view's geometry, as in the contiguous overload above.
  return restride(copy, sizes, strides, offset);
}
