Variable/Tensor Merge Proposal #13638

Open
yf225 opened this issue Nov 6, 2018 · 14 comments

🚀 High-level changes:

  1. IMPORTANT: Both Variable and Variable::Impl are removed; at::Tensor is the only tensor type passed around in PyTorch, and it can record autograd history when its autograd metadata (AutogradMeta) is not null.
  2. IMPORTANT: Autograd-related function implementations in Variable will be moved to VariableType.
  3. Autograd metadata now lives in an AutogradMeta struct that TensorImpl has a pointer to, and the AutogradMeta is only populated when the at::Tensor requires gradient.
  4. We decide whether to dispatch to VariableType / non-VariableType functions using the at::AutoNonVariableTypeMode in appropriate places internally. (We only dispatch to VariableType functions if we need profiling/JIT-tracing/autograd)
  5. Common Tensor functions (e.g. numel() / sizes() / dim()) are de-virtualized in TensorImpl and have their runtime reduced by 43%-86%.
  6. tensor.is_variable() and options.is_variable() always return true, because every at::Tensor is a variable (and can record autograd history when its AutogradMeta is not null). (We keep options.is_variable(...) for backward compatibility, and raise a warning if it's set to false.)
  7. API behavior change: changing shape/storage on tensor.data in Python or tensor.data() in C++ will no longer update tensor.

Pitch

Currently, the distinction between at::Tensor and Variable (subclass of at::Tensor that contains autograd metadata and functions) creates unnecessary cognitive overhead for PyTorch core development. We want to remove this distinction and make it possible to use at::Tensor everywhere in PyTorch. After merging Variable into at::Tensor, here are the common end-user APIs:

  • When a C++ user wants to create a non-history-recording at::Tensor from another at::Tensor:
    Current API (unchanged):
auto t = torch::ones({2, 2}, torch::requires_grad()); // t is recording history
auto t_detached = t.detach();  // t_detached is the non-history-recording version of t

When the user calls t.detach(), we do the following under the hood:

  1. We make a shallow copy of t's TensorImpl, which copies the storage pointer and all other TensorImpl fields (e.g. size / stride).
    • Note that subclasses of TensorImpl (e.g. SparseTensorImpl) need to know how to make a shallow copy of themselves, and we dispatch this operation to each TensorImpl subclass' own shallow_copy_and_detach() function (by making the shallow_copy_and_detach() function virtual in TensorImpl and overriding it in TensorImpl subclasses).
  2. We set the AutogradMeta pointer to NULL, to indicate that it doesn't need to record history.
  3. We return an at::Tensor that wraps the new TensorImpl.
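
A minimal, self-contained sketch of this scheme is shown below. All names here (Storage, AutogradMeta, ToyTensorImpl) are simplified stand-ins for illustration, not the real PyTorch types or signatures:

```cpp
#include <cstdint>
#include <memory>
#include <vector>

// Toy stand-ins for illustration only; the real TensorImpl/AutogradMeta have
// many more fields and different signatures.
struct Storage { std::vector<float> data; };          // ref-counted data payload
struct AutogradMeta { /* grad_, grad_fn_, grad_accumulator_, ... */ };

struct ToyTensorImpl {
  std::shared_ptr<Storage> storage;                   // shared between copies
  std::vector<int64_t> sizes, strides;                // copied by value
  std::unique_ptr<AutogradMeta> autograd_meta;        // null => no history recording

  // Step 1: shallow copy -- share the storage pointer, copy metadata by value.
  // Step 2: leave autograd_meta null so the copy does not record history.
  // Step 3: the caller wraps the returned impl in an at::Tensor.
  std::shared_ptr<ToyTensorImpl> shallow_copy_and_detach() const {
    auto copy = std::make_shared<ToyTensorImpl>();
    copy->storage = storage;
    copy->sizes = sizes;
    copy->strides = strides;
    copy->autograd_meta = nullptr;
    return copy;
  }
};
```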

  • When a C++ user wants to enable/disable history-recording for an at::Tensor:
    Proposed API:
auto t = torch::ones({2, 2});  // t is not recording history (this already works)
t.requires_grad_(true);  // t is recording history now (new API)
t.requires_grad_(false); // t is not recording history anymore (new API)

When the user calls t.requires_grad_(true), we do the following under the hood:

  1. We initialize a struct called AutogradMeta, which stores autograd-specific fields (such as grad_/grad_fn_/grad_accumulator_).
  2. We assign the struct to the AutogradMeta pointer in t's TensorImpl.

When the user calls t.requires_grad_(false), we do the following under the hood:

  1. We set the AutogradMeta pointer in t's TensorImpl to NULL.
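
Conceptually, requires_grad_(bool) then just attaches or drops the autograd metadata. A toy sketch under the same simplified types as above (not the real API):

```cpp
#include <memory>

struct AutogradMeta { /* grad_, grad_fn_, grad_accumulator_, ... */ };

struct ToyTensorImpl {
  std::unique_ptr<AutogradMeta> autograd_meta;   // null => not recording history

  void requires_grad_(bool flag) {
    if (flag) {
      // Steps 1 + 2: allocate an AutogradMeta and attach it to this impl.
      if (!autograd_meta) autograd_meta = std::make_unique<AutogradMeta>();
    } else {
      // Setting requires_grad to false drops the autograd state entirely.
      autograd_meta.reset();
    }
  }
};
```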

  • When a C++ user wants to call non-Variable operations on an at::Tensor that is dispatched through type():
    Proposed API:
{
  auto t_type = t.type();  // `t_type` is a Variable type if `t` contains AutogradMeta
}
{
  at::AutoNonVariableTypeMode guard(true);   // thread-local guard enabled (new API)
  auto non_var_type = t.type();  // `non_var_type` is a non-Variable type
}
{
  at::AutoNonVariableTypeMode guard(false);  // thread-local guard explicitly disabled (new API)
  auto var_type = t.type();  // `var_type` is a Variable type
}

Under the hood, type() checks whether the at::AutoNonVariableTypeMode thread-local guard is enabled when determining the type of the variable.
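
One plausible shape for such a guard is a thread-local flag plus an RAII wrapper, sketched below with toy names; the real at::AutoNonVariableTypeMode lives in ATen and is consulted by the actual type-dispatch machinery:

```cpp
#include <cassert>

// Illustrative thread-local flag plus RAII guard, mirroring the proposed
// at::AutoNonVariableTypeMode; all names here are toy stand-ins.
namespace toy {
thread_local bool non_variable_type_mode = false;

struct AutoNonVariableTypeMode {
  bool prev;
  explicit AutoNonVariableTypeMode(bool enabled = true) : prev(non_variable_type_mode) {
    non_variable_type_mode = enabled;
  }
  ~AutoNonVariableTypeMode() { non_variable_type_mode = prev; }
};

// type() would consult the flag: pick the Variable (history-recording) type only
// when the guard is not active and the tensor actually has AutogradMeta.
bool dispatch_to_variable_type(bool has_autograd_meta) {
  return has_autograd_meta && !non_variable_type_mode;
}
}  // namespace toy

int main() {
  assert(toy::dispatch_to_variable_type(true));     // Variable path by default
  {
    toy::AutoNonVariableTypeMode guard(true);
    assert(!toy::dispatch_to_variable_type(true));  // forced onto the non-Variable path
  }
  assert(toy::dispatch_to_variable_type(true));     // restored when the guard leaves scope
}
```

Because the destructor restores the previous value, nested guards compose as expected.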


  • When a C++ user wants to change the content of an at::Tensor that has AutogradMeta, without affecting the tensor's grad_fn or version_counter_:
    Proposed behavior:
auto t = torch::ones({2, 2});
t.requires_grad_(true);
AT_ASSERT(t.current_version() == 0);
t.data().add_(1);  // This is consistent with Python `.data` behavior: changing `.data` of a tensor in Python doesn't affect the tensor's `grad_fn` or `version_counter_`
AT_ASSERT(t.current_version() == 0);
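
A toy model of the intended behavior (simplified types; in particular, giving the .data() result its own fresh version counter is an assumption of this sketch, and version-counter sharing is revisited in the PR list below):

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Toy illustration of the intended `.data()` semantics: in-place writes through
// the normal path bump the version counter, writes through `.data()` do not.
struct ToyVersionCounter { unsigned version = 0; };

struct ToyTensor {
  std::shared_ptr<std::vector<float>> storage;
  std::shared_ptr<ToyVersionCounter> version_counter;

  void add_(float v) {                      // regular in-place op: bumps version
    for (auto& x : *storage) x += v;
    ++version_counter->version;
  }
  ToyTensor data() const {                  // shares storage, fresh version counter
    return ToyTensor{storage, std::make_shared<ToyVersionCounter>()};
  }
  unsigned current_version() const { return version_counter->version; }
};

int main() {
  ToyTensor t{std::make_shared<std::vector<float>>(4, 1.0f),
              std::make_shared<ToyVersionCounter>()};
  assert(t.current_version() == 0);
  t.data().add_(1.0f);                      // mutates the shared storage...
  assert(t.current_version() == 0);         // ...but t's version counter is untouched
  t.add_(1.0f);
  assert(t.current_version() == 1);
}
```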

Motivation

  • Overly complex OOP design: Currently, the distinction between Variable and Tensor is hard to grasp: Variable::Impl is a subclass of TensorImpl, but it also has an at::Tensor data member which internally wraps another TensorImpl. This co-existence of "is-a" and "has-a" relationships makes the code complicated and adds cognitive overhead. In particular, it's difficult to track which functions we have overridden in Variable::Impl, and which functions are applicable to Tensor vs. Variable (e.g. is_wrapped_number() is only valid on Tensor, not Variable) (for more context, also see the note: We regret making Variable hold a Tensor). Ideally, we want to use the same tensor type everywhere in PyTorch code.

  • Unused data members in Variable::Impl take up cache/memory space: Since Variable::Impl is a subclass of TensorImpl, it contains all of the data members that a normal TensorImpl would have (such as sizes_ / strides_ / etc.). However, the Variable::Impl functions always call into the underlying at::Tensor and ignore the rest of the fields, which wastes a lot of cache/memory space.

  • Virtual functions are slow: We care about how much time it takes to execute common Tensor functions such as numel() / sizes() / dim(). Currently, these functions are virtual in TensorImpl, so that Variable::Impl (a subclass of TensorImpl) can override them and dispatch those calls to the Variable::Impl's underlying at::Tensor. Virtual function calls are slow because they involve an extra vtable lookup. Specifically, we did the following comparison on the most common Tensor functions (all timings are in ns):

| Benchmark | Time (no flush) | Time (flush L1) | Time (flush L1+L2) | Time (flush L1+L2+L3) |
|---|---|---|---|---|
| Tensor.dim() - non-virtual | 1.3 | 3.33 | 7.6 | 58 |
| Variable.dim() - virtual | 4.5 | 24.4 | 52 | 173.67 |
| Runtime Savings | -71.11111% | -86.35246% | -85.38462% | -66.60333% |
| Tensor.numel() - non-virtual | 22.6 | 63.89 | 109.22 | 294.5 |
| Variable.numel() - virtual | 80.33 | 133.1 | 192 | 810.9 |
| Runtime Savings | -71.86605% | -51.9985% | -43.11458% | -63.68233% |
| Tensor.size(0) - non-virtual | 30.4 | 60.1 | 100.44 | 384.3 |
| Variable.size(0) - virtual | 75.4 | 127.67 | 203.8 | 875.9 |
| Runtime Savings | -59.6817% | -52.92551% | -50.71639% | -56.12513% |
| Tensor.sizes() - non-virtual | 2 | 4.25 | 13.25 | 67.6 |
| Variable.sizes() - virtual | 5.2 | 28.44 | 62.1 | 254.78 |
| Runtime Savings | -61.53846% | -85.05626% | -78.66345% | -73.46731% |
| Tensor.resize_({0}) no-op - non-virtual | 23.11 | 86.44 | 105.44 | 332.33 |
| Variable.resize_({0}) no-op - virtual | 168.4 | 254.22 | 348.56 | 890.9 |
| Runtime Savings | -86.27672% | -65.99795% | -69.74983% | -62.69727% |
| Tensor.resize_({64, 2048}) no-op - non-virtual | 33.4 | 102.56 | 129.56 | 407.22 |
| Variable.resize_({64, 2048}) no-op - virtual | 193 | 278.1 | 364.9 | 936.6 |
| Runtime Savings | -82.6943% | -63.12118% | -64.49438% | -56.52146% |

Benchmarked commit: f000101
Benchmark script: https://github.com/yf225/benchmark/blob/tensor_functions/timing/cpp2/benchmarks/aten_overheads.cpp
Non-virtual code: master...yf225:nonvirtual_tensorimpl
Virtual code: master...yf225:virtual_tensorimpl

Based on our current implementation, the runtime difference for dim(), numel(), size(), sizes(), and no-op resize() comes from the virtual function call overhead and the at::Tensor data member indirection in Variable::Impl. If we de-virtualize those functions, we would be able to cut the runtime by 43%-86% on the most common Tensor functions.
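
For intuition, here is a minimal sketch of the kind of comparison behind the table (it is not the linked benchmark script): a direct dim() call versus a virtual call through a wrapping impl that mimics Variable::Impl's extra indirection. A real benchmark additionally needs cache flushing and stronger optimizer barriers.

```cpp
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <vector>

// One impl answers dim() directly; the other goes through a virtual call plus
// an extra member indirection, mimicking Variable::Impl wrapping a tensor.
struct NonVirtualImpl {
  std::vector<int64_t> sizes{64, 2048};
  int64_t dim() const { return static_cast<int64_t>(sizes.size()); }
};

struct VirtualBase {
  virtual ~VirtualBase() = default;
  virtual int64_t dim() const = 0;
};
struct WrappingImpl : VirtualBase {          // stands in for Variable::Impl
  NonVirtualImpl inner;                      // the wrapped "underlying tensor"
  int64_t dim() const override { return inner.dim(); }
};

template <class F>
double time_ns(F f, int iters = 10'000'000) {
  auto start = std::chrono::steady_clock::now();
  int64_t sink = 0;
  for (int i = 0; i < iters; ++i) sink += f();
  auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
                std::chrono::steady_clock::now() - start).count();
  if (sink == 42) std::printf("unlikely\n");  // keep the loop from being optimized away
  return static_cast<double>(ns) / iters;
}

int main() {
  NonVirtualImpl direct;
  WrappingImpl wrapped;
  VirtualBase* base = &wrapped;              // a real benchmark would hide the concrete type better
  std::printf("non-virtual dim(): %.2f ns\n", time_ns([&] { return direct.dim(); }));
  std::printf("virtual dim():     %.2f ns\n", time_ns([&] { return base->dim(); }));
}
```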

Breaking changes

Note that this change will break the current API in the following way:

In the old world, whenever we want to create a Variable that shares the same data with another Variable, we simply do auto var_new = make_variable(var.data()) or auto var_new = var.detach(), and any shape / data / storage pointer changes to var_new will be reflected in var automatically, because internally they share the same underlying at::Tensor.

However, in the new world, there is no concept of the "underlying at::Tensor" of a Variable, since the Variable itself is the Tensor. When we want to create an at::Tensor that shares the same data with another at::Tensor, we can still call auto t_new = t.detach(), but only the tensor storage is shared (via a ref-counted pointer) between t_new and t; the size/stride information is copied by value. In other words, changing anything in the detached Tensor (t_new) that is not bits inside the tensor storage (e.g. size / stride / storage_ptr) won't update the original Tensor (t), and we should no longer expect that metadata to be shared.
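
The following toy illustration (simplified types, not the real API) makes the new sharing rules concrete: after a detach-style shallow copy, writes to the storage are visible through both tensors, while metadata changes are not:

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

// Toy model of the new detach() semantics: storage is shared, size metadata is not.
struct ToyImpl {
  std::shared_ptr<std::vector<float>> storage;
  std::vector<int64_t> sizes;
};

int main() {
  ToyImpl t{std::make_shared<std::vector<float>>(4, 1.0f), {2, 2}};

  // detach(): share the storage, copy metadata by value (shallow_copy_and_detach).
  ToyImpl t_new{t.storage, t.sizes};

  (*t_new.storage)[0] = 42.0f;       // a data change is visible through t ...
  assert((*t.storage)[0] == 42.0f);

  t_new.sizes = {4};                 // ... but a metadata change is not
  assert(t.sizes == (std::vector<int64_t>{2, 2}));
}
```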

This has implications for Python call sites that do

tensor.data.in_place_operation_()

or

tensor_detached = tensor.detach()
tensor_detached.in_place_operation_()

If in_place_operation_() only updates the data inside the tensor (such as zero_()), the operation will still work as before; if the in-place operation changes the size, stride, or storage pointer inside the TensorImpl (e.g. resize_ / resize_as_ / set_ / transpose_), the operation on tensor.data or tensor_detached will no longer update tensor. We will address this inconsistency in the following ways:

  1. Add an allow_tensor_metadata_change_ flag to TensorImpl, and use it to reject size/stride/storage_ptr changes from in-place operations such as resize_ / resize_as_ / set_ / transpose_ on the tensor returned by tensor.data in Python (see the sketch after this list).
  2. Write text in the docs to actively discourage changing the shape or storage of tensor_detached and expecting tensor to also be updated.
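
A sketch of how such a flag could be enforced is below. The toy types and the "allow" polarity (the flag being cleared on tensors returned by tensor.data / detach()) are assumptions of this sketch; the real check would live inside TensorImpl's metadata-mutating paths:

```cpp
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Toy enforcement of a metadata-change guard flag.
struct ToyImpl {
  std::vector<int64_t> sizes;
  bool allow_tensor_metadata_change = true;  // assumed cleared for .data / detach outputs

  void resize_(std::vector<int64_t> new_sizes) {
    if (!allow_tensor_metadata_change) {
      throw std::runtime_error(
          "metadata (size/stride/storage_ptr) changes are not allowed on this tensor");
    }
    sizes = std::move(new_sizes);
  }
};
```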

Upcoming PRs

  1. Add a flag to TensorImpl to disallow size/stride/storage_ptr changes from in-place operations such as resize_ / resize_as_ / set_ / transpose_, and set this flag to true when people call tensor.data in Python.
  2. Write text in the docs to actively discourage changing the shape or storage of tensor_detached and expecting tensor to also be updated.
  3. Move Variable::Impl data members into TensorImpl as AutogradMeta struct
  4. Change Variable::Impl functions to use data members in AutogradMeta struct
  5. Add shallow_copy() function to each subclass of TensorImpl
  6. Do shallow copy when the user calls make_variable(tensor) / variable.detach() (Reason: now that autograd metadata lives in TensorImpl, in order to create a new history for the Variable returned from variable.detach() we not only need to create a new AutogradMeta struct, but we also need to create a new TensorImpl object that stores a pointer to the new AutogradMeta struct (which we obtain by shallow-copying the original TensorImpl). Otherwise, changing the history of the detached Variable would also change the history of the original Variable, which is not the correct behavior.)
  7. Add AutogradMetaInterface class, and make AutogradMeta a subclass of it, so that we can make autograd_meta_ a unique_ptr in TensorImpl
  1. Move set_requires_grad() / requires_grad() / grad() from Variable::Impl to AutogradMeta
  2. Move Variable::Impl functions such as backward() / rebase_history() / grad_accumulator() / grad_fn() out of Variable::Impl and into AutogradMeta.
  3. Note: we need to make these changes so that we can remove Variable::Impl class in the next PR.
  1. Add thread-local guard (at::AutoNonVariableTypeMode) to make sure that in VariableType.cpp the operations on baseType still dispatch to non-Variable type, even if the parameters are now Variables
  1. Make gesv_out return the original input tensor instead of a new tensor (currently by copying the result tensor into the original input tensor, because a true in-place gesv is more difficult to implement. NOTE: also open an issue for this).
  2. In VariableType.cpp, after each in-place function on the "unpacked" tensor, check pointer address equality for storage in the original input variable's TensorImpl (check this for all arguments in unpacked_args)
  1. Remove .type() calls as much as possible, to reduce the need of using the at::AutoNonVariableTypeMode guard
  1. Make JIT attributes t_ and ts_ store Variable instead of Tensor (and in t_ and ts_ use sites, don't wrap the tensor into Variable again) (global search make_variable( in jit/ to find places where we are doing double-wrapping for t_ and ts_ attributes)
  1. tril_ and triu_ should not change the input tensor's TensorImpl pointer
  1. Move pyobj_ to TensorImpl itself, because we always need to be able to convert to and from the Python representation.
  1. Move version_counter_ to storage or TensorImpl, because we may capture non-requires-grad variables inside an autograd function, and we need a working version counter in these cases.
  2. We should not share version counter in shallow_copy_and_detach(), because a pure Tensor doesn't have concept of version counter, and it's managed by autograd instead.
  3. We should preserve the API semantics of tensor.data in Python, and allow it as an escape route for in-place operations without bumping version counter.
  1. tensor.is_variable() should check whether the TensorImpl has AutogradMeta. is_variable_ should be removed.
  • PR: Fix version counter sharing in Variable.set_data(...) #20391

  • PR: Move at::NonVariableTypeMode to TensorImpl, and check it in TensorImpl is_variable() #20392

  • PR: Require passing version_counter and allow_tensor_metadata_change to shallow_copy_and_detach(): #20496

  • PR: Shallow-copy indices and values in sparse tensor constructor #20330

  • PR: Remove Variable::Impl (#17072)

  1. Remove the at::Tensor data member (data_) from Variable::Impl
  2. In Variable construction and in Variable.set_data(), copy all data from data.impl to the variable's TensorImpl.
  3. Give Variable.data() the same semantics as tensor.data in Python, and watch for breakage in any Variable.data() call sites
  4. Remove the Variable::Impl class and the DifferentiableViewImpl class
  5. Remove mentions of Variable::Impl and DifferentiableViewImpl
  6. Fix comments in [Tensor versus Variable in C++], [We regret making Variable hold a Tensor], [ Autograd View Variables ]. Go through all comments in variable.h and variable.cpp and fix any inconsistency.
  7. NOTE: we don't need to add SparseVariableImpl that handles how to copy SparseTensorImpl, because SparseTensorImpl already implements the shallow_copy_and_detach() function that Variable factory functions can call.
  8. In places where we need to ensure the tensor is not requiring gradient, we should check !requires_grad() || at::NonVariableTypeMode::is_enabled(), instead of !requires_grad() || !at::GradMode::is_enabled(), because we don't want to move at::GradMode to ATen.
  • PR: (NOTE: blocked on asking XLA team to de-virtualize XLATensorImpl)
  1. Remove the virtual attribute from common Tensor functions in TensorImpl
  2. Move the sparse-specific implementation of common Tensor functions such as dim() / sizes() from SparseTensorImpl to TensorImpl, by branching based on is_sparse() in TensorImpl
  • PR: Remove tensor_data() call sites as much as possible, since this is the old semantics. Try to use the variable_data() semantics (and manually set requires_grad and edge if needed). And then remove tensor_data() API.

  • PR: [We might not need this, because this might just work: yf225:create_autograd_meta_function_pointer] Move set_requires_grad() to VariableType, make sure its implementation correctly reflect the implementations in AutogradMeta and DifferentiableViewMeta. (branch: yf225:move_set_requires_grad_to_VariableType) (PR: #21362)

  • PR:

  1. Can we initialize AutogradMeta for Python tensors that need grad_fn? So that we don't need to move grad_fn_ and output_nr_ from AutogradMeta to TensorImpl?
  2. Move grad_fn_ and output_nr_ from AutogradMeta to TensorImpl (in yf225:move_grad_fn_and_output_nr_to_tensorimpl branch) (Test PR: #19964)
  3. Why we need this: SavedVariable (
    SavedVariable::SavedVariable(const Variable& variable, bool is_output) {
      if (variable.defined()) {
        was_default_constructed_ = false;
        output_nr_ = variable.output_nr();
        requires_grad_ = variable.requires_grad();
        has_grad_fn_ = !variable.is_leaf();
        // These copies are all shared_ptr copies, so slightly more expensive.
        // Do them here instead of in the init list in case data is undefined.
        data_ = variable.data();
        if (variable.is_leaf()) {
          grad_accumulator_ = variable.grad_accumulator();
        } else if (!is_output) {
          grad_fn_ = variable.grad_fn();
        }
        version_counter_ = variable.version_counter();
        saved_version_ = version_counter_.current_version();
      }
    }
    ) can save variables that don't require grad, and it expects grad_fn_ and output_nr_ to exist for those non-requires-grad variables. So we need to move grad_fn_ and output_nr_ from AutogradMeta to TensorImpl to have them always exist. UPDATE: if a Tensor doesn't have AutogradMeta, it should not have grad_fn or output_nr_ because it's never an output of a history-tracking operation and we should never call backward() on it.
  4. The second reason why we need this:
import torch
from torch.autograd import Variable, Function

class CollectOnDelete(Function):
    def __del__(self):
        gc.collect()

a = Variable(torch.randn(10, 10), _grad_fn=CollectOnDelete())
a.requires_grad  # prints True because `a`'s grad_fn_ is not null, even though `a` doesn't have autograd_meta
  • PR: Don't set autograd_meta when requires_grad = false (we need to benchmark this) (After this PR, all tensors are variables)
  1. NOTE: A view on Variables that doesn't have autograd metadata should behave the same as view on non-requires-grad Variables.
  2. We only create tensor with autograd_meta populated when options.requires_grad() == true. In all other cases, we create tensor with autograd_meta = null, to optimize for memory usage. To make this work properly, in VariableType*.cpp we need to assert that tensors have autograd metadata before performing operations on them.
  3. Remove make_variable(), add attach_autograd_meta()? Think about this.
  4. Remove _tensor_data_deprecated() API
  5. tensor.requires_grad() behavior:
  • When tensor.requires_grad() == false, this tensor does not have autograd_meta.
  • When tensor.requires_grad() == true, if this tensor is a view, it is guaranteed to have autograd_meta; if this tensor is not a view, it either has autograd_meta (aka. a leaf variable), or its grad_fn_ is not null.
    Example:
import torch
from torch.autograd import Variable, Function

class CollectOnDelete(Function):
    def __del__(self):
        gc.collect()

a = Variable(torch.randn(10, 10), _grad_fn=CollectOnDelete())
a.requires_grad  # prints True because `a`'s grad_fn_ is not null, even though `a` doesn't have autograd_meta
  • When set_requires_grad(true) is called on a non-view tensor that doesn't have grad_fn_ (aka. a leaf variable), its autograd_meta will be populated.
  • When set_requires_grad(false) is called on a non-view tensor that doesn't have grad_fn_ (aka. a leaf variable), its autograd_meta will be set to null.
  1. Move requires_grad_ field from AutogradMeta to DifferentiableViewMeta.
    DifferentiableViewMeta can have requires_grad_ equal to true or false. Example cases:
  • requires_grad_ needs to be true:
    def randn_like(x):
        y = torch.testing.randn_like(x if x.is_floating_point() else x.double())
        if gen_non_contig_grad_outputs:
            y = torch.testing.make_non_contiguous(y)
        return y.requires_grad_()

    In this example, y = torch.testing.make_non_contiguous(y) returns a view variable with requires_grad_ = false, and y.requires_grad_() is used to set the view variable's requires_grad_ to true.
  • requires_grad_ needs to be false:
a = torch.Tensor(2, 2)
b = torch.zeros_like(a)
c = b.view(-1)
c._is_view()  # prints True
c.is_leaf  # prints True
c.zero_() # c's requires_grad_ needs to be false for this to work, otherwise this throws "RuntimeError: a leaf Variable that requires grad has been used in an in-place operation."

Conclusion: since we need both cases to work, we can't just remove requires_grad_ from DifferentiableViewMeta, as it is an important state field in DifferentiableViewMeta.
4. Move implementation of requires_grad() from AutogradMeta and DiffViewMeta to TensorImpl.
3. Remove make_variable(), since every Tensor is a Variable now, and conversion shouldn't be needed (look at places where a shallow-copy with new autograd history is expected, and create a function like shallow_copy_with_new_autograd_meta() for it).
5. Address https://github.com/pytorch/pytorch/pull/17072/files#r276326234.

  • PR:

  • Tensors passed into Caffe2 shouldn't require grad. Enforce this.

  • Remove the unwrap logic created in #21620.

  • PR:

  • tensorimpl.is_variable() should be removed and tensor.is_variable() should be deprecated. Check at::NonVariableTypeMode::is_enabled() in all original check sites of !is_variable() in ATen instead.

  • as_variable() in torch/csrc/autograd/VariableTypeUtils.h is not needed anymore.

  • PR:

  1. options.is_variable() should always return true. (We keep options.is_variable(...) for backward compatibility, and raise warning if it's set to false.) Fix factory functions to make this happen, if necessary.
  2. For getType(TensorOptions), we only check at::NonVariableTypeMode::is_enabled() to decide whether to choose Variable path. And we add at::AutoNonVariableTypeMode to all proper places to ensure we are dispatching to non-Variable path when appropriate.
  • PR: Merge Variable into at::Tensor
  1. Move autograd-related functions from Variable to VariableType or free functions in torch::autograd::. (See https://fb.quip.com/M48WApjXT2aj)
  2. Address https://github.com/pytorch/pytorch/pull/17072/files#r260882208 (clean up set_requires_grad() signature)
  3. Address "TODO: These factory functions don't need to be friends anymore."
  4. Remove Variable and use at::Tensor everywhere.
  • PR: remove mentions of "Variable and Tensor are merged"

  • PR:

  1. Remove unpack() in VariableType*.cpp.
  2. Clean up the unpack_args logic in gen_variable_type.py, since we are not doing unpack anymore.
  3. Fix comments for use_derived in gen_variable_type.py
  • PR:
  1. Remove https://pytorch.org/cppdocs/#aten section, and replace all at::Tensor with torch::Tensor, and remove/fix all mentions of ATen in cpp docs and tutorials, when Variable and Tensor are merged (since now ATen becomes just an implementation detail).
  2. Can this https://pytorch.org/tutorials/advanced/cpp_extension.html#backward-pass generate backward pass automatically now?
  • PR
  1. Improve "NOTE: After the Variable/Tensor merge" comment based on #18223 (comment)
  • PR: Remove requires_tensor: True in native_functions.yaml. Figure out how to fix _dimV, _dimS case (torch.randn(2, 3)._dimV() shouldn't hit that error)

  • PR: address comments in #20496 (comment).

  • PR: address comments in
    https://github.com/pytorch/pytorch/pull/17072/files#r281784475

  • PR: Make sure in-place update on the original value tensor also updates the version counter in the sparse tensor's values_ tensor, and it should throw version mismatch error in backward() when the original value tensor is changed (this requires saving the values tensor for backward in the sparse constructor). Example:

ind=torch.tensor([[0],[1]]).add_(1).sub_(1)
values = torch.tensor([1.]).add_(1).add_(1).sub_(1).sub_(1)
c = torch.sparse_coo_tensor(ind, values).requires_grad_()
values.add_(1)
c.sum().backward()  # maybe not this exact syntax, but you get the idea.
  • PR (not strictly related to the merge, but for improving consistency with Python API):
  1. Add Tensor.requires_grad_(bool) C++ API (which internally calls Tensor.set_requires_grad(bool)), to be consistent with Python API

Non-goals:

  • [TODO: put this note somewhere in code] We intentionally don't merge at::GradMode and at::NonVariableTypeMode, with the following reasoning:
    Semantically, at::GradMode and at::NonVariableTypeMode actually mean different things: at::GradMode controls whether a tensor should accumulate gradients, and at::NonVariableTypeMode controls whether a Variable should be treated as a non-Variable tensor in type dispatches. There are places where we don't want the tensor to accumulate gradients, but still want the Variable to be treated as a Variable. Here is one example:
#  torch/tensor.py
with torch.no_grad():
   ...
   new_tensor = self.new()    # `at::GradMode` is false at this point
   ...
// tools/autograd/templates/python_variable_methods.cpp
static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwargs)
{
  ...
  // if we merge `at::GradMode` and `at::NonVariableTypeMode`, since `at::GradMode` is false and `self_.type()` checks `at::GradMode` to decide whether to return non-Variable type, it will return a non-Variable type here, which is not what we want (and throws a "Tensor that was converted to Variable was not actually a Variable" error)
  return THPVariable_Wrap(torch::utils::legacy_tensor_new(self_.type(), args, kwargs));
  ...
}

For the above reason, we cannot merge at::GradMode and at::NonVariableTypeMode, as they have different purposes.
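
A compressed sketch of the two independent thread-locals (toy stand-ins, not the real guards), showing why flipping the grad mode must not affect which dispatch path is taken:

```cpp
// Toy illustration of keeping the two thread-local modes independent
// (simplified stand-ins for at::GradMode and at::NonVariableTypeMode).
namespace toy {
thread_local bool grad_mode_enabled = true;        // should ops record history?
thread_local bool non_variable_type_mode = false;  // should dispatch skip VariableType?

// Inside `with torch.no_grad():` only grad_mode_enabled flips to false; type
// dispatch still goes through the Variable path, so self.new() keeps returning
// a Variable as expected.
bool should_record_history() { return grad_mode_enabled; }
bool dispatch_to_variable_type(bool has_autograd_meta) {
  return has_autograd_meta && !non_variable_type_mode;
}
}  // namespace toy
```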

@zou3519 referenced this issue Nov 9, 2018: [perf] Reduce tensor & aten overhead #13049

facebook-github-bot added a commit that referenced this issue Dec 27, 2018

Move autograd metadata from VariableImpl to TensorImpl (#13827)
Summary:
Changes originally in this PR:
1. Move Variable::Impl data members into TensorImpl as `AutogradMeta` struct
2. Change Variable::Impl functions to use data members in `AutogradMeta` struct
3. Add `shallow_copy_and_detach()` function to each subclass of TensorImpl
4. Do shallow copy when the user calls `make_variable(tensor)` / `make_variable_view(tensor)` / `variable.set_data(tensor)` / `variable.detach()`

Changes moved from #13645:
1. Add a flag to Variable to disallow size/stride/storage_ptr changes from in-place operations such as `resize_` / `resize_as_` / `set_` / `transpose_`, and set this flag to true when people call `tensor.data` in Python.
2. Write text in the docs to actively discourage changing the shape or storage of `tensor_detached` and expecting `tensor` to also be updated.

This is the 1st+2nd PR mentioned in #13638.
Pull Request resolved: #13827

Differential Revision: D13507173

Pulled By: yf225

fbshipit-source-id: b177b08438d534a8197e34e1ad4a837e2db0ed6a

facebook-github-bot added a commit that referenced this issue Dec 28, 2018

Move VariableImpl functions to AutogradMeta and Variable (#15487)
Summary:
In this PR, we are moving all functions away from `Variable::Impl`, in order to get rid of `Variable::Impl` (and the `data_` Tensor in it) in the next PR. Some of the functions (such as `set_requires_grad` / `requires_grad` / `grad`) will be living in `AutogradMeta` class, while others (such as `backward()` / `rebase_history()` / `grad_accumulator()` / `grad_fn()`) will be living in `Variable` class.

This is the 2nd PR mentioned in #13638.
Pull Request resolved: #15487

Differential Revision: D13553173

Pulled By: yf225

fbshipit-source-id: 691f9432d0cd0640af380c757f3e3a2f64f8851c
@apaszke

Member

commented Jan 5, 2019

also see note: We regret making Variable hold a Tensor

The link gives a 404.

@yf225

Contributor Author

commented Jan 5, 2019

@apaszke Fixed with a permalink to the correct file =)

@ezyang

Contributor

commented Jan 16, 2019

@yf225: @zdevito, @smessmer and I were talking about at::AutoGradMode, and we think this is actually exactly the same thing as the existing at::NoGradGuard guard in C++. So might not even need a separate thread local for this one :)

@yf225

Contributor Author

commented Jan 17, 2019

@ezyang Currently torch::NoGradGuard only exists in libtorch, but for the base type dispatch to work we need to check the guard in ATen. I am planning to work on merging them as one of the last few PRs of the Variable/Tensor merge.

facebook-github-bot added a commit that referenced this issue Jan 24, 2019

Add thread-local guard: at::AutoNonVariableTypeMode (#15939)
Summary:
This PR adds thread-local guard (`at::AutoNonVariableTypeMode`) to make sure that in VariableType.cpp the operations on baseType still dispatch to non-Variable type, even if the parameters will become Variables after the Tensor/Variable merge. We achieve this by making `legacyTensorType()` and `getType()` check the `at::AutoNonVariableTypeMode` guard to decide whether to return non-Variable type for a variable.

This is part of the VariableImpl/TensorImpl merge work: #13638.
Pull Request resolved: #15939

Reviewed By: ezyang

Differential Revision: D13640980

Pulled By: yf225

fbshipit-source-id: d12c2543822958558d7d70d36c50999a5eb8783f

facebook-github-bot added a commit that referenced this issue Feb 7, 2019

Make JIT attributes t_ and ts_ store Variable instead of Tensor (#16596)
Summary:
Discussed with zdevito and we want to use Variable (with `set_requires_grad(false)`) instead of Tensor in all parts of JIT, to eliminate the distinction and the conceptual overhead when trying to figure out which one to use.

This also helps with the Variable/Tensor merge work tracked at #13638, which will make common functions (such as `numel()` / `sizes()` / `dim()`) on Variable much faster when finished.
Pull Request resolved: #16596

Differential Revision: D13979971

Pulled By: yf225

fbshipit-source-id: c69119deec5bce0c22809081115f1012fdbb7d5a

facebook-github-bot added a commit that referenced this issue Feb 11, 2019

Enforce same input tensor storage in VariableType functions (#16305)
Summary:
In VariableType.cpp, when a function modifies its input tensors, it should only change the input tensors' storage data in-place, and should never change the input tensors' storage pointers. This PR adds checks for this, and also fixes functions that fail this test.

This is part of the Variable/Tensor merge work (#13638).
Pull Request resolved: #16305

Differential Revision: D13897855

Pulled By: yf225

fbshipit-source-id: 0c4fc7eb530d30db88037b1f0981f6f8454d3b79

@yf225

Contributor Author

commented Feb 27, 2019

Semantics for new Variable design

  1. All tensors are Variables. tensor.is_variable() and options.is_variable() are redundant and should be removed. (We might want to keep options.is_variable(...) for backward compatibility, and throw error if it's set to false.)
  2. When we need to decide whether to dispatch to gradient-recording vs. non-gradient-recording operators, we do the following:
    1. For getType(TensorImpl), we check autograd_meta() && GradMode::is_enabled() to decide whether to choose Variable path. And we add NoGradGuard to all proper places to ensure we are dispatching to non-Variable path when appropriate.
    2. For getType(TensorOptions), we check GradMode::is_enabled() to decide whether to choose Variable path. And we add NoGradGuard to all proper places to ensure we are dispatching to non-Variable path when appropriate.
  3. We only create tensor with autograd_meta populated when options.requires_grad() == true. In all other cases, we create tensor with autograd_meta = null, to optimize for memory usage.
  4. tensor.requires_grad() behavior:
    1. When tensor.requires_grad() == false, this tensor may or may not have autograd_meta
    2. When tensor.requires_grad() == true, this tensor is guaranteed to have autograd_meta
    3. When set_requires_grad(true) is called on a tensor that doesn't have autograd_meta, its autograd_meta will be populated and requires_grad will be set to true
    4. We always need to cast an at::Tensor to Variable in order to call set_requires_grad(true), because at::Tensor has no concept of autograd.
  5. autograd_meta_ == null means that (see the sketch after this list):
    1. grad_ is undefined
    2. name_ is unset
    3. grad_fn_ and grad_accumulator_ are null
    4. There are no hooks
    5. Does not require grad (requires_grad() returns false)
    6. Calling autograd functions (e.g. rebase_history()) on this tensor will result in error
  6. version_counter_ is moved to storage, because we may capture non-requires-grad variables inside an autograd function, and we need a working version counter in these cases.
  7. pyobj_ is moved to TensorImpl itself, because we always need to be able to convert to and from the Python representation.
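
A toy sketch of the null-autograd_meta_ semantics listed above (simplified names, not the actual TensorImpl interface):

```cpp
#include <memory>
#include <stdexcept>

// Toy model of the rules above; all names are simplified stand-ins.
struct AutogradMeta { bool requires_grad = false; /* grad_, hooks_, ... */ };

struct ToyTensorImpl {
  std::unique_ptr<AutogradMeta> autograd_meta;   // null for plain tensors

  bool requires_grad() const {                   // rule 5.5: null => false
    return autograd_meta && autograd_meta->requires_grad;
  }
  void rebase_history() {                        // rule 5.6: error without metadata
    if (!autograd_meta) {
      throw std::runtime_error("tensor has no autograd metadata");
    }
    // ... real rebase logic would go here ...
  }
  void set_requires_grad(bool flag) {            // rule 4.3: populate lazily
    if (flag && !autograd_meta) autograd_meta = std::make_unique<AutogradMeta>();
    if (autograd_meta) autograd_meta->requires_grad = flag;
  }
};
```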
@gchanan

Contributor

commented Mar 1, 2019

2. For getType(TensorOptions), we check GradMode::is_enabled() to decide whether to choose Variable path. And we add NoGradGuard to all proper places to ensure we are dispatching to non-Variable path when appropriate.

Why do we need getType(TensorOptions)? From what I can tell, is_variable() today is always false in this case, so we can just always return the "non-variable" version. AFAIK this is only called in construction and there is no difference in constructing a Tensor vs a Variable. Is that correct?

@gchanan

Contributor

commented Mar 1, 2019

4. We always need to cast an at::Tensor to Variable in order to call set_requires_grad(true), because at::Tensor has no concept of autograd

But then you can't get rid of the Tensor / Variable distinction in user code, right? I.e. if I create a tensor, I need to manually cast it to a Variable in order to get autodiff.

@gchanan

Contributor

commented Mar 1, 2019

6. version_counter_ is moved to storage, because we may capture non-requires-grad variables inside an autograd function, and we need a working version counter in these cases.

This seems useful -- should it be done in a separate PR? It's not dependent on any of the rest of this workstream right?

@gchanan

Contributor

commented Mar 1, 2019

7. pyobj_ is moved to TensorImpl itself, because we always need to be able to convert to and from the Python representation.

I know @ezyang has some plans to reduce the space required by TensorImpl, but from what I understand, this isn't considered a pre-req. Just flagging this to keep in mind.

@gchanan

Contributor

commented Mar 1, 2019

I'd also like to do more thinking about the factory function case. E.g. if I pass in requires_grad=True in a !GradMode::enabled context, what happens? (I assume you can get the answer from python today). How will that work?

@apaszke

Member

commented Mar 2, 2019

@gchanan for the last question I think it will create a tensor that requires grad, but no operators performed before grad mode is enabled again will be differentiated. Once this block ends, it will look like any other tensor that requires grad.

@gchanan

Contributor

commented Mar 6, 2019

@apaszke that sounds correct to me, thanks!

@yf225

Contributor Author

commented Mar 13, 2019

  1. For getType(TensorOptions), we check GradMode::is_enabled() to decide whether to choose Variable path. And we add NoGradGuard to all proper places to ensure we are dispatching to non-Variable path when appropriate.

Why do we need getType(TensorOptions)? From what I can tell, is_variable() today is always false in this case, so we can just always return the "non-variable" version. AFAIK this is only called in construction and there is no difference in constructing a Tensor vs a Variable. Is that correct?

is_variable() in getType(TensorOptions) is actually not always false. For example:

// At build/aten/src/ATen/Functions.h:3590
static inline Tensor empty(IntArrayRef size, const TensorOptions & options) {
    return at::getType(options).empty(size, options);
}

This function can be called with options coming from a Variable, and we expect it to dispatch to VariableType::empty(...). If we let getType(TensorOptions) always return the "non-variable" version, it will dispatch to the wrong path:

Before:

#1  0x00007fffe341a30d in at::native::empty_cpu (size=..., options=...) at ../aten/src/ATen/native/TensorFactories.cpp:93
#2  0x00007fffe3602d8a in at::CPULongType::empty (this=0x555556257130, size=..., options=...) at aten/src/ATen/CPULongType.cpp:2070
#3  0x00007fffe0f18904 in torch::autograd::VariableType::<lambda()>::operator()(void) const (__closure=0x7fffffffb900)
    at ../torch/csrc/autograd/generated/VariableType_4.cpp:9375
#4  0x00007fffe0f18ba2 in torch::autograd::VariableType::empty (this=0x55555624f0c0, size=..., options=...)
    at ../torch/csrc/autograd/generated/VariableType_4.cpp:9376

/\ 
|| at::getType(options) returns variable version, which causes dispatch to torch::autograd::VariableType::empty(...), and it's correct

#5  0x00007fffe36b8c2a in at::empty (size=..., options=...) at aten/src/ATen/Functions.h:3590
#6  0x00007fffe36b8fd1 in at::TypeDefault::copy (this=0x55555624f0c0, src=..., non_blocking=false, to_device=...) at aten/src/ATen/TypeDefault.cpp:37
#7  0x00007fffe3416b50 in at::native::to_impl (self=..., options=..., non_blocking=false) at ../aten/src/ATen/native/TensorConversions.cpp:24
#8  0x00007fffe34173ee in at::native::to (self=..., dtype=c10::ScalarType::Long, non_blocking=false, copy=false)
    at ../aten/src/ATen/native/TensorConversions.cpp:69
#9  0x00007fffe37025b5 in at::TypeDefault::to (this=0x55555624ee70, self=..., dtype=c10::ScalarType::Long, non_blocking=false, copy=false)
    at aten/src/ATen/TypeDefault.cpp:4358
#10 0x00007fffe0d56979 in torch::autograd::VariableType::to (this=0x55555624ee70, self=..., dtype=c10::ScalarType::Long, non_blocking=false, copy=false)
    at ../torch/csrc/autograd/generated/VariableType_2.cpp:15946
#11 0x00007fffe83ce9f9 in at::Tensor::to (this=0x7fffdd9d7490, dtype=c10::ScalarType::Long, non_blocking=false, copy=false)
    at ../aten/src/ATen/core/TensorMethods.h:788
#12 0x00007fffe8337e03 in torch::autograd::dispatch_to (self=..., dtype=c10::ScalarType::Long, non_blocking=false, copy=false)
    at ../torch/csrc/autograd/generated/python_variable_methods.cpp:275
#13 0x00007fffe8338e4f in torch::autograd::THPVariable_to_type (self=0x7fffdd9d7480, scalarType=c10::ScalarType::Long)
    at ../torch/csrc/autograd/generated/python_variable_methods.cpp:311
#14 0x00007fffe833941f in torch::autograd::THPVariable_long (self=0x7fffdd9d7480, args=0x0)
    at ../torch/csrc/autograd/generated/python_variable_methods.cpp:339

If we let getType(TensorOptions) always return the "non-variable" version:

#0  __cxxabiv1::__cxa_throw (obj=0x5555569489a0, tinfo=0x7fffe8ce6d10 <typeinfo for c10::Error>, dest=0x7fffe31c824a <c10::Error::~Error()>)
    at /opt/conda/conda-bld/compilers_linux-64_1534514838838/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/eh_throw.cc:80
#1  0x00007fffe341a2d2 in at::native::empty_cpu (size=..., options=...) at ../aten/src/ATen/native/TensorFactories.cpp:93
#2  0x00007fffe3602cc2 in at::CPULongType::empty (this=0x555556257130, size=..., options=...) at aten/src/ATen/CPULongType.cpp:2070

/\ 
|| at::getType(options) returns non-variable version, which causes dispatch to at::CPULongType::empty, and it's incorrect

#3  0x00007fffe36b8b62 in at::empty (size=..., options=...) at aten/src/ATen/Functions.h:3590
#4  0x00007fffe36b8efc in at::TypeDefault::copy (this=0x55555624f0c0, src=..., non_blocking=false, to_device=...) at aten/src/ATen/TypeDefault.cpp:36
#5  0x00007fffe3416a8e in at::native::to_impl (self=..., options=..., non_blocking=false) at ../aten/src/ATen/native/TensorConversions.cpp:24
#6  0x00007fffe341732c in at::native::to (self=..., dtype=c10::ScalarType::Long, non_blocking=false, copy=false)
    at ../aten/src/ATen/native/TensorConversions.cpp:69
#7  0x00007fffe37024e1 in at::TypeDefault::to (this=0x55555624ee70, self=..., dtype=c10::ScalarType::Long, non_blocking=false, copy=false)
    at aten/src/ATen/TypeDefault.cpp:4357
#8  0x00007fffe0d56979 in torch::autograd::VariableType::to (this=0x55555624ee70, self=..., dtype=c10::ScalarType::Long, non_blocking=false, copy=false)
    at ../torch/csrc/autograd/generated/VariableType_2.cpp:15946
#9  0x00007fffe83ce9f9 in at::Tensor::to (this=0x7fffdd7209e8, dtype=c10::ScalarType::Long, non_blocking=false, copy=false)
    at ../aten/src/ATen/core/TensorMethods.h:788
#10 0x00007fffe8337e03 in torch::autograd::dispatch_to (self=..., dtype=c10::ScalarType::Long, non_blocking=false, copy=false)
    at ../torch/csrc/autograd/generated/python_variable_methods.cpp:275
#11 0x00007fffe8338e4f in torch::autograd::THPVariable_to_type (self=0x7fffdd7209d8, scalarType=c10::ScalarType::Long)
    at ../torch/csrc/autograd/generated/python_variable_methods.cpp:311
#12 0x00007fffe833941f in torch::autograd::THPVariable_long (self=0x7fffdd7209d8, args=0x0)
    at ../torch/csrc/autograd/generated/python_variable_methods.cpp:339

I think dispatching based on the "variable" vs. "non-variable" version of TensorOptions still matters after we merge the internals of Variable and Tensor and remove options.is_variable(), since that dispatch actually controls whether we use the autograd history-recording path. If we check at::NonVariableTypeMode::is_enabled() in getType(TensorOptions), we can control this dispatch by using at::AutoNonVariableTypeMode in the right places, which I think results in pretty clear semantics; it's also the same reason why we need to check !at::NonVariableTypeMode::is_enabled() in https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Context.cpp#L118.
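
To make the check concrete, here is a toy sketch of the proposed dispatch decision (hypothetical names; the real getType lives in ATen and returns a Type&):

```cpp
// Toy sketch: getType(TensorOptions) picks the Variable (history-recording)
// path unless the thread-local non-Variable guard is active.
namespace toy {
thread_local bool non_variable_type_mode = false;  // set by AutoNonVariableTypeMode

enum class TypeKind { Variable, NonVariable };

TypeKind getType(/* const at::TensorOptions& options */) {
  return non_variable_type_mode ? TypeKind::NonVariable : TypeKind::Variable;
}
}  // namespace toy
```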

@yf225

Contributor Author

commented Mar 14, 2019

Update: Based on the offline discussion, we will want to check !at::NonVariableTypeMode::is_enabled() in getType(TensorOptions) for doing the correct dispatch based on variable/non-variable type, similar to our implementation of getType(TensorImpl).

@yf225 changed the title from "VariableImpl/TensorImpl Merge Proposal" to "Variable/Tensor Merge Proposal" on Apr 10, 2019

facebook-github-bot added a commit that referenced this issue Apr 11, 2019

Move version_counter_ to TensorImpl (#18223)
Summary:
According to #13638 (comment), after the Variable/Tensor merge, we may capture variables without autograd metadata inside an autograd function, and we need a working version counter in these cases. This PR makes it possible by moving `version_counter_` out of autograd metadata and into TensorImpl, so that variables without autograd metadata still have version counters.
Pull Request resolved: #18223

Differential Revision: D14735123

Pulled By: yf225

fbshipit-source-id: 15f690311393ffd5a53522a226da82f5abb6c65b

pull bot pushed a commit to rahulunair/pytorch that referenced this issue May 24, 2019

Remove Variable::Impl and DifferentiableViewImpl (pytorch#17072)
Summary:
As part of the Variable/Tensor merge work: pytorch#13638, we make the following changes in this PR:
1. Remove the `Variable::Impl` class and the `DifferentiableViewImpl` class
2. Change all `Variable.data()` call sites to either use `Variable` directly, or use `Variable.tensor_data()`
3. Remove `Variable.data()` API
3. Add `Variable.variable_data()` that matches `tensor.data` in Python API, which creates a new `Variable` that shares the same storage and tensor metadata with the original `Variable`, but with a completely new autograd history.

After this PR, Variable doesn't wrap a Tensor internally anymore, and both Variable and Tensor use the same TensorImpl class as its `impl_`. The only difference is that Variable always has AutogradMeta in its TensorImpl, but Tensor doesn't.

**Note that this PR is BC-breaking in the following use cases:**

**Use Case 1:**
Previously, `x.data = y` works even if `x` and `y` are of different TensorImpl type (e.g. `x` is a CPU dense tensor whose impl is of type TensorImpl, while `y` is a CPU sparse tensor whose impl is of type SparseTensorImpl). However, after this PR, `x.data = y` doesn't work anymore if `x` and `y` are of different TensorImpl type, because the underlying implementation `variable.set_data(tensor)` no longer works if `variable` and `tensor` have different TensorImpl type.

**Use Case 2:**
If a tensor `x`'s `grad` is sparse, accumulating dense gradients to `x` will change the tensor that `x.grad` is pointing to. This is better illustrated with the following example:
```python
params = torch.tensor([1.5, 1.5]).requires_grad_()
with torch.no_grad():
    # Change gradient to a sparse tensor
    params.grad = torch.sparse_coo_tensor(torch.tensor([[1, 1]]).long(), torch.tensor([1., 1.]))

grad_saved = params.grad
params.backward(torch.tensor([1.5, 1.5]))
assert id(grad_saved) == id(params.grad)  # This will fail after this PR
```
The assertion in the last line will fail after this PR, because adding dense gradients to sparse gradients will change the `params.grad` tensor reference.
Pull Request resolved: pytorch#17072

Differential Revision: D14075257

Pulled By: yf225

fbshipit-source-id: 0e681df641270dea586042dd26db59f2e76b5957
