Propagate CreationMeta when chaining views (#51061)
Summary:
Fixes #49824

## Background

When creating a view of a view, the new view could end up being less restrictive than the previous view, incorrectly sidestepping the error that should be thrown when an in-place operation is performed on the new view.

The fix addresses this by propagating `CreationMeta` from the previous view to the new view. Currently, the old view's `creation_meta` is propagated only when the new view's `creation_meta == CreationMeta::DEFAULT`. This ensures that the new view is never less restrictive than the previous view with respect to allowing in-place operations.
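
A minimal sketch of the user-visible change, adapted from the new tests below (before this fix, the final line silently succeeded):

```
import torch

a = torch.rand(10, requires_grad=True).clone()
b = a.unbind(0)          # multi-output view: in-place ops on its outputs must error
c = b[0].view_as(b[0])   # view of that view; previously reset to CreationMeta::DEFAULT
c.mul_(2)                # now raises RuntimeError instead of being silently allowed
```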

Pull Request resolved: #51061

Test Plan:
```
python test/test_autograd.py TestAutogradDeviceTypeCPU.test_inplace_view_of_multiple_output_view_cpu
python test/test_autograd.py TestAutogradDeviceTypeCUDA.test_inplace_view_of_multiple_output_view_cuda
python test/test_autograd.py TestAutogradDeviceTypeCPU.test_inplace_multiple_output_view_of_view_cpu
python test/test_autograd.py TestAutogradDeviceTypeCUDA.test_inplace_multiple_output_view_of_view_cuda
```

Reviewed By: heitorschueroff

Differential Revision: D26076434

Pulled By: jbschlosser

fbshipit-source-id: c47f0ddcef9b8449427b671aff9ad08edca70fcd
jbschlosser authored and facebook-github-bot committed Jan 27, 2021
1 parent 5ec2e26 commit 0b5303e
Showing 3 changed files with 33 additions and 0 deletions.
14 changes: 14 additions & 0 deletions test/test_autograd.py
@@ -7377,6 +7377,20 @@ def test_inplace_view_multiple_outputs(self, device):
        with self.assertRaises(RuntimeError):
            v1[0].mul_(2)

    def test_inplace_view_of_multiple_output_view(self, device):
        a = torch.rand(10, device=device, requires_grad=True).clone()
        b = a.unbind(0)
        c = b[0].view_as(b[0])
        with self.assertRaises(RuntimeError):
            c.mul_(2)

    def test_inplace_multiple_output_view_of_view(self, device):
        a = torch.rand(10, device=device, requires_grad=True).clone()
        b = a.view_as(a)
        c = b.unbind(0)
        with self.assertRaises(RuntimeError):
            c[0].mul_(2)

    def test_inplace_view_makes_base_require_grad(self, device):
        # in-place modification to view makes base require grad
        a = torch.randn(4, 4, device=device, requires_grad=False)
10 changes: 10 additions & 0 deletions torch/csrc/autograd/VariableTypeUtils.h
@@ -145,6 +145,7 @@ inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_dif
  if (base.is_view()) {
    auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base));
    const auto& base_bw_info = diff_view_meta->get_backward_view();
    creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta);
    return make_variable_differentiable_view(tensor, base_bw_info.chain(base, tensor, view_func),
                                             c10::nullopt, creation_meta, allow_tensor_metadata_change);
  } else {
@@ -188,6 +189,10 @@ inline Tensor as_view(const Tensor & base, const Tensor & tensor, bool is_bw_dif
  }

  if (is_fw_differentiable || is_bw_differentiable) {
    if (base.is_view()) {
      auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base));
      creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta);
    }
    return make_variable_differentiable_view(tensor, std::move(new_bw_info), std::move(new_fw_info),
                                             creation_meta, allow_tensor_metadata_change);
  } else {
@@ -234,6 +239,11 @@ inline std::vector<Tensor> as_view(const Tensor & base, std::vector<Tensor>& ten
    }
  }

  if ((is_fw_differentiable || is_bw_differentiable) && base.is_view()) {
    auto diff_view_meta = static_cast<DifferentiableViewMeta*>(torch::autograd::impl::get_autograd_meta(base));
    creation_meta = propagate_creation_meta(diff_view_meta->get_creation_meta(), creation_meta);
  }

  for(Tensor &tensor : tensors) {
    if (is_fw_differentiable || is_bw_differentiable) {
      tensor = make_variable_differentiable_view(tensor, new_bw_info, new_fw_info, creation_meta);
9 changes: 9 additions & 0 deletions torch/csrc/autograd/variable.h
@@ -502,6 +502,15 @@ struct TORCH_API ViewInfo {
enum class CreationMeta: uint8_t { DEFAULT, IN_CUSTOM_FUNCTION, MULTI_OUTPUT_NODE,
                                   NO_GRAD_MODE, MULTI_OUTPUT_SAFE };

/// Handles correctly propagating CreationMeta when a new view is created from a previous view.
/// In general, we don't want the new view to be _less_ restrictive than the previous view
/// (it's okay to be _more_ restrictive). A CreationMeta value of DEFAULT is currently the least
/// restrictive, as the behavior for all other CreationMeta values is to error out for in-place ops.
/// If this changes, the logic here will need to be updated to properly handle the new semantics.
inline CreationMeta propagate_creation_meta(CreationMeta prev_view_creation_meta, CreationMeta new_view_creation_meta) {
  return (new_view_creation_meta == CreationMeta::DEFAULT) ? prev_view_creation_meta : new_view_creation_meta;
}

/// Unified function to handle error checking when rebase happens
/// indirect=true means that the caller is not doing the inplace, but the inplace happened
/// somewhere else.
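For illustration, a minimal Python model of the propagation rule implemented by the C++ `propagate_creation_meta` helper above (the names mirror the C++ enum; this is a sketch, not PyTorch API):

```
from enum import Enum

class CreationMeta(Enum):
    # mirrors the C++ enum above (illustration only)
    DEFAULT = 0
    IN_CUSTOM_FUNCTION = 1
    MULTI_OUTPUT_NODE = 2
    NO_GRAD_MODE = 3
    MULTI_OUTPUT_SAFE = 4

def propagate_creation_meta(prev, new):
    # a DEFAULT new view inherits the (possibly more restrictive) previous meta
    return prev if new is CreationMeta.DEFAULT else new

# view of a multi-output view: the restriction is inherited
assert propagate_creation_meta(CreationMeta.MULTI_OUTPUT_NODE,
                               CreationMeta.DEFAULT) is CreationMeta.MULTI_OUTPUT_NODE
# an already-restrictive new view keeps its own meta
assert propagate_creation_meta(CreationMeta.DEFAULT,
                               CreationMeta.NO_GRAD_MODE) is CreationMeta.NO_GRAD_MODE
```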
