
Allow in-place operations on views #3384

Merged 14 commits on Nov 6, 2017

Conversation

@colesbury (Member) commented Oct 30, 2017

Allow in-place operations on views

Adds VariableViewImpl, a subclass of VariableImpl which has a pointer to
the base Variable on which it is a view. In-place operations on views
change the grad_fn of the base.

Fixes #3313
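
For illustration only (not part of the PR), here is a minimal sketch of the user-visible behavior this enables; the variable names are made up and the snippet assumes the Variable API of that era:

import torch
from torch.autograd import Variable

x = Variable(torch.randn(4, 4), requires_grad=True)
base = x * 2            # a non-leaf Variable
view = base[:2]         # a view into base
view.mul_(3)            # in-place op on the view; this rewrites base's grad_fn
base.sum().backward()   # gradients w.r.t. x now account for the in-place mul_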

@apaszke (Contributor) left a comment

Looks good for the most part, but I think there are a few things that should be changed before this is merged (unless I misunderstood some parts of the code).

return info['res'](name)
return name + suffix

formula = re.sub(regex.format(name), repl, formula)

var.requires_grad() = flags.requires_grad;
var.is_volatile() = flags.is_volatile;
if (grad_fn) {
var.output_nr() = grad_fn->num_inputs++;
var.grad_fn() = std::move(grad_fn);
if (inplace && var.is_view()) {

auto CopyBackwards::apply(const variable_list& grads) -> variable_list {
check_input_variables("CopyBackwards", grads, 1);
auto& grad = grads[0];
return variable_list{zeros_like(grad), grad};
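
Behaviorally (a hedged illustration, not code from the PR), the zeros_like/grad pair above means that the backward of dst.copy_(src) sends zero gradient to the overwritten contents of dst and passes the incoming gradient straight through to src; all names below are made up:

import torch
from torch.autograd import Variable

x = Variable(torch.randn(3), requires_grad=True)
y = Variable(torch.randn(3), requires_grad=True)
z = x * 1          # non-leaf copy of x so it can be modified in place
z.copy_(y)         # overwrite z with y's values
z.sum().backward()
print(x.grad)      # zeros: the overwritten values do not affect the output
print(y.grad)      # ones: gradient flows through the copy to the source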

check_input_variables("CopySlices", inputs, 1);
auto& grad = inputs[0];

auto result = grad.type().tensor(base.sizes, base.strides);

auto offset = view.storage_offset - base.storage_offset;
auto grad_slice = result.as_strided(view.sizes, view.strides, offset);
auto res = (*fn)({ grad_slice.clone() });
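
A rough Python-level sketch of this gradient routing, for orientation only: it is not the implementation, it substitutes the modern torch.empty_strided for grad.type().tensor(sizes, strides), and every name in it is hypothetical:

import torch

def copy_slices_backward(grad, base_geom, view_geom, view_backward):
    # grad: incoming gradient for the base
    # base_geom / view_geom: (sizes, strides, storage_offset) tuples
    # view_backward: the backward recorded for the in-place op on the view
    b_sizes, b_strides, b_offset = base_geom
    v_sizes, v_strides, v_offset = view_geom
    result = torch.empty_strided(b_sizes, b_strides)  # buffer with the base's geometry
    result.copy_(grad)                                # start from the upstream gradient
    offset = v_offset - b_offset                      # locate the view inside the base
    grad_slice = result.as_strided(v_sizes, v_strides, offset)
    # run the in-place op's backward on the slice and write the result back
    grad_slice.copy_(view_backward(grad_slice.clone()))
    return result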

Variable base;
int expected_version;
};

}
is_view = true;
version_counter = base.version_counter();
expected_version = version_counter.current_version();

}

std::shared_ptr<Function>& VariableViewImpl::get_grad_fn() {
std::lock_guard<std::mutex> lock(grad_accumulator_lock);

fn->self_geometry = TensorGeometry(base);
fn->size = sizes();
fn->stride = strides();
fn->storage_offset = Variable(this, true).storage_offset();

fn->set_flags(Function::flags({ base }));
fn->num_inputs = 1;
_grad_fn = std::move(fn);
expected_version = version_counter.current_version();
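
The effect of this lazy rebuild can be seen from Python (a behavioral illustration, not code from the PR; the exact grad_fn class names vary between versions):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(4), requires_grad=True)
base = x * 2
view = base[:2]
print(view.grad_fn)   # the slicing backward recorded when the view was created
base.add_(1)          # in-place op on the base bumps the shared version counter
print(view.grad_fn)   # accessed again: rebuilt against the base's new grad_fn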

@apaszke (Contributor) commented Oct 31, 2017

BTW do we still need live_refs in VariableVersion?

@colesbury (Member Author)
I've changed the logic in _wrap_outputs some more. It's worth looking over carefully.

@colesbury (Member Author) commented Oct 31, 2017

[EDIT]: I've removed live_refs.

@colesbury force-pushed the inplace branch 4 times, most recently from 65f5416 to 95929e6 on November 2, 2017 02:17

@apaszke (Contributor) left a comment

In general looks good, but I have a few concerns about correctness. It's really changing the critical parts of autograd and I'm starting to get nervous that it's this large 😕

@@ -382,18 +364,19 @@ static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
const t2var_type &shared_pairs,
PyObject *raw_output, PyObject *outputs, bool is_volatile)
{
auto cdata = is_volatile ? nullptr : THPFunction_asFunction(self);
bool is_executable = self->cdata.is_executable && !is_volatile;

if (!output_var) throw python_error();
// We already have the data tensor wrapped as a PyObject*
Py_INCREF(output);
Py_XDECREF(output_var->data);

return var;
// An input has been returned, but it wasn't modified. Return it as a view
// so that we can attach a new grad_fn to the Variable.
return make_variable_view(std::move(prev), std::move(data));

@@ -21,6 +21,15 @@ namespace torch { namespace autograd {
using at::Tensor;
struct VariableImpl;

// TODO: fix name conflict with jit VariableFlags

@@ -197,6 +211,16 @@ inline const std::shared_ptr<Function>& Variable::grad_fn() const {
inline std::shared_ptr<Function>& Variable::grad_fn() {
return get()->get_grad_fn();
};
inline void Variable::set_history(VarFlags flags, int output_nr, std::shared_ptr<Function> grad_fn) {

@@ -134,6 +146,8 @@ struct VariableViewImpl : public VariableImpl {
// re-create the grad_fn to express the up-to-date view relationship between
// this and the base Variable.
virtual std::shared_ptr<Function>& get_grad_fn() override;
// Sets the grad_fn. If

set_base_fn(var, std::move(grad_fn));
} else {
var.output_nr() = grad_fn->num_inputs++;
var.grad_fn() = std::move(grad_fn);

// If we have a previous grad_fn then this is an in-place modification
if (output_nr != 0 || grad_fn->num_inputs != 1) {
throw std::runtime_error("Functions which modify views in-place must return a single Variable");
}
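
A hedged sketch of a Python function that satisfies this restriction, using the old-style Function API of the time; the MulInPlace class and its contents are made up for illustration:

import torch
from torch.autograd import Function, Variable

class MulInPlace(Function):
    # Modifies its input in place. Because the input may be a view, the
    # function must return only the single modified Variable; returning
    # anything else alongside it would hit the runtime_error above.
    def forward(self, x):
        self.mark_dirty(x)
        return x.mul_(2)

    def backward(self, grad_output):
        return grad_output * 2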

@@ -121,6 +122,20 @@ std::shared_ptr<Function>& VariableViewImpl::get_grad_fn() {
return _grad_fn;
}

void VariableViewImpl::set_grad_fn(std::shared_ptr<Function> grad_fn) {
if (this->_grad_fn && grad_fn) {
// If we have a previous grad_fn then this is an in-place modification

@colesbury force-pushed the inplace branch 2 times, most recently from 89434d8 to bee1fe6 on November 2, 2017 18:45

@colesbury (Member Author)
Rebased to get TORCH_ASSERT and added two more commits.

@colesbury (Member Author)
Here's the diff for the generated VariableType.cpp:

https://gist.github.com/colesbury/bc6835aaef3f755dd063cb484e788701

@colesbury force-pushed the inplace branch 4 times, most recently from 9405868 to 6d83141 on November 3, 2017 02:03

grad_inputs[i] = std::move(res[i]);
for (size_t i = 0; i < res.size(); i++) {
if (should_compute_output(i)) {
TORCH_ASSERT(res[i].defined());

@@ -1542,6 +1542,14 @@ def func(root, b):
go = Variable(torch.randn(a.size()), requires_grad=True)
gradgradcheck(func, (a, b), (go,))

def test_inplace_view5(self):

if is_view:
base_call = 'auto ret = as_view(static_cast<const Variable&>(self), {})'.format(base_call)
else:
base_call = 'auto ret = as_variable({})'.format(base_call)

@@ -56,7 +56,10 @@ struct Variable : public at::Tensor {
inline const std::shared_ptr<Function>& grad_fn() const;
inline std::shared_ptr<Function>& grad_fn();

// Sets the flags and grad_fn ("history") of a new Variable
inline void set_history(VarFlags flags, int output_nr, std::shared_ptr<Function> grad_fn);

@@ -218,14 +220,17 @@ inline std::shared_ptr<Function>& Variable::grad_fn() {
return get()->get_grad_fn();
};
inline void Variable::set_history(VarFlags flags, int output_nr, std::shared_ptr<Function> grad_fn) {
assert(!get()->_grad_fn && "set_history can only be called on new Variables");

@apaszke (Contributor) left a comment

Looks good, but we should discuss the situation when requires_grad of the base changes.

@@ -53,6 +53,7 @@ struct VariableType : public at::Type {
std::tuple<Variable, Variable> as_variable(std::tuple<Tensor, Tensor> tensor) const;
std::tuple<Variable, Variable, Variable> as_variable(std::tuple<Tensor, Tensor, Tensor> tensor) const;
std::vector<Variable> as_variable(TensorList tensor) const;
Variable maybe_wrap(Tensor data, const Variable & self, bool inplace) const;

TORCH_ASSERTM(requires_grad, "Can't set grad_fn on view with requires_grad=False");
TORCH_ASSERT(output_nr == 0);
if (grad_fn->num_inputs != 1) {
throw std::runtime_error("Functions which modify views in-place must return a single Variable");

@@ -42,9 +51,16 @@ struct Variable : public at::Tensor {
inline const Variable & grad() const;
inline Variable & grad();

inline bool is_leaf() const;

inline const std::shared_ptr<Function>& grad_fn() const;
inline std::shared_ptr<Function>& grad_fn();

@apaszke (Contributor) left a comment

Last commits look good to me, but please add more comments to the tests to clearly state what you're checking. I need one more day to go over VariableViewImpl again and ensure that there are no more weird edge cases.

throw std::runtime_error(
"requires_grad is False and base.requires_grad is True. Cannot use "
"this view in a differentiable operation. Re-create the view from the "
"base Variable.");

@@ -391,6 +431,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
{"data", (getter)THPVariable_get_data, (setter)THPVariable_set_data, NULL, NULL},
{"_grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, NULL, NULL}, // only for legacy reasons
{"grad", (getter)THPVariable_get_grad, (setter)THPVariable_set_grad, NULL, NULL},
{"_base", (getter)THPVariable_get_base, NULL, NULL, NULL}, // only for legacy reasons

@apaszke (Contributor) left a comment

Alright, I think I'm OK with merging this once you address these comments. Please, please add the test descriptions as well; I have no idea what they are checking.

}
is_view = true;
version_counter = base.version_counter();
attr_version = version_counter.current_version();

}

std::shared_ptr<Function>& VariableViewImpl::get_grad_fn() {
std::lock_guard<std::mutex> lock(mutex);

}

TORCH_ASSERTM(requires_grad, "Can't set grad_fn on view with requires_grad=False");
TORCH_ASSERT(output_nr == 0);

auto copySlices = std::make_shared<CopySlices>(base, TensorGeometry(data), std::move(grad_fn));
base.output_nr() = 0;
base.get()->_grad_fn = std::move(copySlices);
base.requires_grad() = true;
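
From Python, this rebasing is visible as the base's grad_fn changing after an in-place op on one of its views (a behavioral illustration, not code from the PR; node class names vary between versions):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(3, 3), requires_grad=True)
base = x + 0
view = base[0]
print(type(base.grad_fn))   # the original AddBackward-style node
view.mul_(2)                # in-place op on the view
print(type(base.grad_fn))   # now the CopySlices node installed by the code above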

torch::createTensor(output),
get_shared_base(output),
is_modified);
if (is_modified) {

Adds VariableViewImpl, a subclass of VariableImpl which has a pointer to
the base Variable on which it is a view. In-place operations on views
change the grad_fn of the base.

Note that in-place operations on views from Python-implemented autograd functions
are not currently supported, but in-place operations on non-views still
work.

Fixes pytorch#3313
 - Python functions can modify views if they return only the single
   modified variable
 - Move shared code to variable.cpp
 - Remove live_refs
Record base in SavedVariable. We were not saving the "base" in
SavedVariable. This was probably a bug, but not detectable because we
don't do in-place operations on saved variables.
The detach() function still shares version counters with the original
Variable, but doesn't return a view.
 - Rename and add descriptions to autograd tests
 - Move most of set_history to make_variable constructors
 - Additional assertions

@colesbury (Member Author)
I've removed the calls to set_history in favor of passing those arguments to Variable factory functions. The generated VariableType.cpp still effectively calls "set_history", but it's inlined into the calling code. In a subsequent PR, I'll clean up VariableType.cpp. I don't want to do it here because that change is probably substantial and this isn't a correctness issue.

@apaszke (Contributor) left a comment

LGTM

: VariableImpl(std::move(data_))
VariableViewImpl::VariableViewImpl(Variable base_, at::Tensor data_, VarFlags flags,
int output_nr, std::shared_ptr<Function> grad_fn)
: VariableImpl(std::move(data_), flags, output_nr, std::move(grad_fn))

inline Variable make_variable(at::Tensor data, bool requires_grad, bool is_volatile=false) {
return Variable(new VariableImpl(std::move(data), requires_grad, is_volatile), false);
return make_variable(std::move(data), VarFlags(requires_grad, is_volatile));

@apaszke (Contributor) commented Nov 6, 2017

Actually, you haven't added a check for the case when we're rebasing a view that requires grad (we should check that the base still requires grad as well).
