Allow in-place operations on views #3384
Conversation
colesbury commented Oct 30, 2017 (edited)
Looks good for the most part, but I think there are a few things that should be changed before this is merged (unless I misunderstood some parts of the code)
return info['res'](name)
return name + suffix

formula = re.sub(regex.format(name), repl, formula)
var.requires_grad() = flags.requires_grad;
var.is_volatile() = flags.is_volatile;
if (grad_fn) {
  var.output_nr() = grad_fn->num_inputs++;
  var.grad_fn() = std::move(grad_fn);
if (inplace && var.is_view()) {
auto CopyBackwards::apply(const variable_list& grads) -> variable_list {
  check_input_variables("CopyBackwards", grads, 1);
  auto& grad = grads[0];
  return variable_list{zeros_like(grad), grad};
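For context: copy_(src) overwrites self, so the incoming gradient flows to src unchanged while the gradient with respect to the pre-copy contents of self is zero; that is exactly the {zeros_like(grad), grad} pair returned above.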
check_input_variables("CopySlices", inputs, 1); | ||
auto& grad = inputs[0]; | ||
|
||
auto result = grad.type().tensor(base.sizes, base.strides); |

auto offset = view.storage_offset - base.storage_offset;
auto grad_slice = result.as_strided(view.sizes, view.strides, offset);
auto res = (*fn)({ grad_slice.clone() });
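Taken together, the two CopySlices fragments imply a backward flow roughly like the sketch below. This is a reconstruction from the snippets, not the PR's exact code; base, view, and fn are assumed to be members captured when the view's history was rebased:

// Sketch only: gradient routing for an in-place op on a view.
struct CopySlicesSketch {
  TensorGeometry base;           // geometry of the base Variable
  TensorGeometry view;           // geometry of the view that was modified
  std::shared_ptr<Function> fn;  // the in-place op's original grad_fn

  variable_list apply(const variable_list& inputs) {
    auto& grad = inputs[0];
    // Allocate a gradient buffer with the base's geometry, seeded with the
    // incoming gradient (the seeding copy is assumed; it isn't in the fragments).
    auto result = grad.type().tensor(base.sizes, base.strides);
    result.copy_(grad);
    // Alias the region of the buffer covered by the view.
    auto offset = view.storage_offset - base.storage_offset;
    auto grad_slice = result.as_strided(view.sizes, view.strides, offset);
    // Backprop through fn on a copy of the slice, then write the produced
    // gradient back into the slice; the rest of the buffer is untouched.
    auto res = (*fn)({ grad_slice.clone() });
    grad_slice.copy_(res[0]);
    return variable_list{ result };
  }
};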
torch/csrc/autograd/variable.h (Outdated)

Variable base;
int expected_version;
};
torch/csrc/autograd/variable.cpp (Outdated)

}
is_view = true;
version_counter = base.version_counter();
expected_version = version_counter.current_version();
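These two assignments are the heart of the mechanism: the view shares the base's version counter and records the counter's value at the moment its grad_fn was known to be valid. A minimal sketch of the staleness test this enables, assuming the member names above:

// Sketch: any in-place modification of the base (or of another alias)
// advances the shared counter past expected_version.
bool grad_fn_is_stale() const {
  return expected_version != version_counter.current_version();
}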
torch/csrc/autograd/variable.cpp (Outdated)

}

std::shared_ptr<Function>& VariableViewImpl::get_grad_fn() {
  std::lock_guard<std::mutex> lock(grad_accumulator_lock);
torch/csrc/autograd/variable.cpp (Outdated)

fn->self_geometry = TensorGeometry(base);
fn->size = sizes();
fn->stride = strides();
fn->storage_offset = Variable(this, true).storage_offset();
torch/csrc/autograd/variable.cpp (Outdated)

fn->set_flags(Function::flags({ base }));
fn->num_inputs = 1;
_grad_fn = std::move(fn);
expected_version = version_counter.current_version();
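Assembled from the get_grad_fn fragments above, the lazy rebuild reads roughly as follows. This is a reconstruction, not the PR's exact code; the concrete backward node built for fn isn't visible in the fragments, so make_view_backward_node() stands in as a hypothetical factory:

std::shared_ptr<Function>& get_grad_fn_sketch() {
  std::lock_guard<std::mutex> lock(grad_accumulator_lock);
  if (expected_version != version_counter.current_version()) {
    // The base was modified in-place since the last sync: re-express the
    // view relationship over the base's current history.
    auto fn = make_view_backward_node();  // hypothetical factory
    fn->self_geometry = TensorGeometry(base);
    fn->size = sizes();
    fn->stride = strides();
    fn->storage_offset = Variable(this, true).storage_offset();
    fn->set_flags(Function::flags({ base }));
    fn->num_inputs = 1;
    _grad_fn = std::move(fn);
    expected_version = version_counter.current_version();
  }
  return _grad_fn;
}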
BTW do we still need …
I've changed the logic in …
[EDIT]: I've removed live refs.
Force-pushed from 65f5416 to 95929e6.
In general looks good, but I have a few concerns about correctness. It's really changing critical parts of autograd, and I'm starting to get nervous about how large it is 😕
@@ -382,18 +364,19 @@ static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
                          const t2var_type &shared_pairs,
                          PyObject *raw_output, PyObject *outputs, bool is_volatile)
{
  auto cdata = is_volatile ? nullptr : THPFunction_asFunction(self);
  bool is_executable = self->cdata.is_executable && !is_volatile;
torch/csrc/autograd/variable.h (Outdated)

Variable base;
int expected_version;
};
if (!output_var) throw python_error();
// We already have the data tensor wrapped as a PyObject*
Py_INCREF(output);
Py_XDECREF(output_var->data);
return var;
// An input has been returned, but it wasn't modified. Return it as a view
// so that we can attach a new grad_fn to the Variable.
return make_variable_view(std::move(prev), std::move(data));
@@ -21,6 +21,15 @@ namespace torch { namespace autograd {
using at::Tensor;
struct VariableImpl;

// TODO: fix name conflict with jit VariableFlags
torch/csrc/autograd/variable.h (Outdated)

@@ -197,6 +211,16 @@ inline const std::shared_ptr<Function>& Variable::grad_fn() const {
inline std::shared_ptr<Function>& Variable::grad_fn() {
  return get()->get_grad_fn();
};
inline void Variable::set_history(VarFlags flags, int output_nr, std::shared_ptr<Function> grad_fn) {
torch/csrc/autograd/variable.h (Outdated)

@@ -134,6 +146,8 @@ struct VariableViewImpl : public VariableImpl {
// re-create the grad_fn to express the up-to-date view relationship between
// this and the base Variable.
virtual std::shared_ptr<Function>& get_grad_fn() override;
// Sets the grad_fn. If …
  set_base_fn(var, std::move(grad_fn));
} else {
  var.output_nr() = grad_fn->num_inputs++;
  var.grad_fn() = std::move(grad_fn);
torch/csrc/autograd/variable.cpp (Outdated)

// If we have a previous grad_fn then this is an in-place modification
if (output_nr != 0 || grad_fn->num_inputs != 1) {
  throw std::runtime_error("Functions which modify views in-place must return a single Variable");
}
torch/csrc/autograd/variable.cpp (Outdated)

@@ -121,6 +122,20 @@ std::shared_ptr<Function>& VariableViewImpl::get_grad_fn() {
  return _grad_fn;
}

void VariableViewImpl::set_grad_fn(std::shared_ptr<Function> grad_fn) {
  if (this->_grad_fn && grad_fn) {
    // If we have a previous grad_fn then this is an in-place modification
Force-pushed from 89434d8 to bee1fe6.
Rebased to get TORCH_ASSERT and added two more commits.
Here's the diff for the generated VariableType.cpp: https://gist.github.com/colesbury/bc6835aaef3f755dd063cb484e788701
Force-pushed from 9405868 to 6d83141.
grad_inputs[i] = std::move(res[i]);
for (size_t i = 0; i < res.size(); i++) {
  if (should_compute_output(i)) {
    TORCH_ASSERT(res[i].defined());
test/test_autograd.py (Outdated)

@@ -1542,6 +1542,14 @@ def func(root, b):
    go = Variable(torch.randn(a.size()), requires_grad=True)
    gradgradcheck(func, (a, b), (go,))

def test_inplace_view5(self):
if is_view:
    base_call = 'auto ret = as_view(static_cast<const Variable&>(self), {})'.format(base_call)
else:
    base_call = 'auto ret = as_variable({})'.format(base_call)
torch/csrc/autograd/variable.h (Outdated)

@@ -56,7 +56,10 @@ struct Variable : public at::Tensor {
inline const std::shared_ptr<Function>& grad_fn() const;
inline std::shared_ptr<Function>& grad_fn();

// Sets the flags and grad_fn ("history") of a new Variable
inline void set_history(VarFlags flags, int output_nr, std::shared_ptr<Function> grad_fn);
torch/csrc/autograd/variable.h (Outdated)

@@ -218,14 +220,17 @@ inline std::shared_ptr<Function>& Variable::grad_fn() {
  return get()->get_grad_fn();
};
inline void Variable::set_history(VarFlags flags, int output_nr, std::shared_ptr<Function> grad_fn) {
  assert(!get()->_grad_fn && "set_history can only be called on new Variables");
torch/csrc/autograd/variable.cpp (Outdated)

@@ -121,6 +122,20 @@ std::shared_ptr<Function>& VariableViewImpl::get_grad_fn() {
  return _grad_fn;
}

void VariableViewImpl::set_grad_fn(std::shared_ptr<Function> grad_fn) {
  if (this->_grad_fn && grad_fn) {
    // If we have a previous grad_fn then this is an in-place modification
Looks good, but we should discuss the situation when requires_grad of base changes.
@@ -53,6 +53,7 @@ struct VariableType : public at::Type {
std::tuple<Variable, Variable> as_variable(std::tuple<Tensor, Tensor> tensor) const;
std::tuple<Variable, Variable, Variable> as_variable(std::tuple<Tensor, Tensor, Tensor> tensor) const;
std::vector<Variable> as_variable(TensorList tensor) const;
Variable maybe_wrap(Tensor data, const Variable & self, bool inplace) const;
torch/csrc/autograd/variable.cpp (Outdated)

TORCH_ASSERTM(requires_grad, "Can't set grad_fn on view with requires_grad=False");
TORCH_ASSERT(output_nr == 0);
if (grad_fn->num_inputs != 1) {
  throw std::runtime_error("Functions which modify views in-place must return a single Variable");
torch/csrc/autograd/variable.h (Outdated)

@@ -42,9 +51,16 @@ struct Variable : public at::Tensor {
inline const Variable & grad() const;
inline Variable & grad();

inline bool is_leaf() const;

inline const std::shared_ptr<Function>& grad_fn() const;
inline std::shared_ptr<Function>& grad_fn();
{
  auto cdata = is_volatile ? nullptr : THPFunction_asFunction(self);
  bool is_executable = self->cdata.is_executable && !is_volatile;
Last commits look good to me, but please add more comments to the tests to clearly state what you're checking. I need one more day to go over VariableViewImpl again and ensure that there are no more weird edge cases.
torch/csrc/autograd/variable.cpp (Outdated)

throw std::runtime_error(
    "requires_grad is False and base.requires_grad is True. Cannot use "
    "this view in a differentiable operation. Re-create the view from the "
    "base Variable.");
@@ -391,6 +431,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
  {"data", (getter)THPVariable_get_data, (setter)THPVariable_set_data, NULL, NULL},
  {"_grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, NULL, NULL}, // only for legacy reasons
  {"grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, NULL, NULL},
  {"_base", (getter)THPVariable_get_base, NULL, NULL, NULL}, // only for legacy reasons
Alright, I think I'm OK with merging this once you address these comments. Please please please please add the test descriptions as well; I have no idea what they're checking.
}
is_view = true;
version_counter = base.version_counter();
attr_version = version_counter.current_version();
}

std::shared_ptr<Function>& VariableViewImpl::get_grad_fn() {
  std::lock_guard<std::mutex> lock(mutex);
torch/csrc/autograd/variable.cpp (Outdated)

}

TORCH_ASSERTM(requires_grad, "Can't set grad_fn on view with requires_grad=False");
TORCH_ASSERT(output_nr == 0);
auto copySlices = std::make_shared<CopySlices>(base, TensorGeometry(data), std::move(grad_fn));
base.output_nr() = 0;
base.get()->_grad_fn = std::move(copySlices);
base.requires_grad() = true;
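This is the rebase step: after an in-place op modifies a view, the op's grad_fn is wrapped in a CopySlices node and installed as the grad_fn of the base, so later uses of the base (or of other views onto it) backprop through the in-place change. The earlier CopySlices sketch shows how that node then routes gradients back through fn.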
torch/csrc/autograd/variable.h (Outdated)

};
inline void Variable::set_history(VarFlags flags, int output_nr, std::shared_ptr<Function> grad_fn) {
  assert(!get()->_grad_fn && "set_history can only be called on new Variables");
    torch::createTensor(output),
    get_shared_base(output),
    is_modified);
if (is_modified) {
Adds VariableViewImpl, a subclass of VariableImpl which has a pointer to the base Variable on which it is a view. In-place operations on views change the grad_fn of the base. Note that in-place operations on views from Python-implemented autograd functions are not currently supported, but in-place operations on non-views still work. Fixes pytorch#3313
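The class this description introduces looks roughly like the sketch below, reconstructed from the variable.h fragments earlier in the review (a sketch, not the final header; expected_version was later renamed attr_version):

struct VariableViewImpl : public VariableImpl {
  Variable base;         // the Variable this view aliases
  int expected_version;  // base's version at the last grad_fn sync
  // Lazily re-creates grad_fn to reflect the up-to-date view relationship
  // between this and the base Variable.
  virtual std::shared_ptr<Function>& get_grad_fn() override;
};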
- Python functions can modify views if they return only the single modified variable
- Move shared code to variable.cpp
- Remove live_refs
…is used in a differentiable operation.
Record base in SavedVariable. We were not saving the "base" in SavedVariable. This was probably a bug, but not detectable because we don't do in-place operations on saved variables.
The detach() function still shares version counters with the original Variable, but doesn't return a view.
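Sharing the version counter is what still lets autograd detect stale saved variables: an in-place modification of either the detached Variable or the original bumps the shared counter, so a backward pass that uses a SavedVariable aliasing that data can notice the mismatch.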
- Rename and add descriptions to autograd tests
- Move most of set_history to make_variable constructors
- Additional assertions
I've removed the calls to …
LGTM
torch/csrc/autograd/variable.cpp (Outdated)

  : VariableImpl(std::move(data_))
VariableViewImpl::VariableViewImpl(Variable base_, at::Tensor data_, VarFlags flags,
                                   int output_nr, std::shared_ptr<Function> grad_fn)
  : VariableImpl(std::move(data_), flags, output_nr, std::move(grad_fn))
torch/csrc/autograd/variable.h (Outdated)

inline Variable make_variable(at::Tensor data, bool requires_grad, bool is_volatile=false) {
  return Variable(new VariableImpl(std::move(data), requires_grad, is_volatile), false);
  return make_variable(std::move(data), VarFlags(requires_grad, is_volatile));
Actually, you haven't added a check for the case when we're rebasing a view that requires grad (we should check that the base still requires grad as well).
- Views may have output_nr != 0, but then we don't allow in-place modifications to them
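This lines up with the runtime check quoted earlier: rebasing a view's history is only supported when the in-place function returns the single modified Variable, i.e. when output_nr == 0 and grad_fn->num_inputs == 1.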