Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement assign function #244

Merged
merged 19 commits into from Apr 23, 2019
Merged
Changes from 1 commit
Commits
File filter...
Filter file types
Jump to…
Jump to file or symbol
Failed to load files and symbols.

Always

Just for now

Fix variable names at F.assign

  • Loading branch information...
takuseno committed Apr 18, 2019
commit ed24bf0d686f4456ea65a3a877b09c358e4ce4f7
@@ -53,16 +53,14 @@ void Assign<T>::backward_impl(const Variables &inputs,
if (!propagate_down[0])
return;

inputs[0]->grad()->zero();

auto gy_ = make_shared<Variable>(outputs[0]->grad());
auto gx_ = make_shared<Variable>(inputs[0]->grad());
auto f_add_ = create_Add2(this->ctx_, true);
f_add_->setup(Variables{gx_.get(), gy_.get()}, Variables{gx_.get()});
auto gy = make_shared<Variable>(outputs[0]->grad());
This conversation was marked as resolved by TE-TakuyaNarihira

This comment has been minimized.

Copy link
@TE-TakuyaNarihira

TE-TakuyaNarihira Apr 18, 2019

Contributor

You don't have to use a shared pointer here because it's not shared with any other object. You can simply write;

Variable gy(outputs[0]->grad());

and pass a raw pointer to a Variables by &gy.

That is also the case for gx.

auto gx = make_shared<Variable>(inputs[0]->grad());
auto f_add = create_Add2(this->ctx_, true);
f_add->setup(Variables{gx.get(), gy.get()}, Variables{gx.get()});

if (!accum[0])
gx_->data()->zero();
gx->data()->zero();
This conversation was marked as resolved by TE-TakuyaNarihira

This comment has been minimized.

Copy link
@TE-TakuyaNarihira

TE-TakuyaNarihira Apr 18, 2019

Contributor

This should not be required because the same thing is actually performed in the parent class function Function::backward.

If we consider more performance improvement, we should copy gy values to gx values if accum is false. The current implementation with accum is False, the zeroing kernel is called, then the add kernel is called. Calling small kernels multiple times leads to some overhead. You can use either Identity class in a similar way to Add2 or Array::copy_from function as in the forward function.


f_add_->forward(Variables{gx_.get(), gy_.get()}, Variables{gx_.get()});
f_add->forward(Variables{gx.get(), gy.get()}, Variables{gx.get()});
}
}
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.