New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autograd refactor #1016

Merged
merged 8 commits into from May 1, 2017

Fix memory issues with Conv and BatchNorm

  • Loading branch information...
apaszke committed Apr 9, 2017
commit 9c9ea05de87c101b5ac7f01aa54d5e27f4509c53
@@ -34,12 +34,15 @@ struct BatchNormBackward : public Function, public BatchNormParams {
SavedVariable weight,
SavedVariable bias)
: Function(std::move(flags))
, BatchNormParams(std::move(params))
, save_mean(std::move(save_mean))
, save_std(std::move(save_std))
, input(std::move(input))
, weight(std::move(weight))
, bias(std::move(bias)) {}
, BatchNormParams(std::move(params)) {
if (is_executable) {
this->save_mean = std::move(save_mean);

This comment has been minimized.

@colesbury

colesbury Apr 19, 2017

Member

Why are these guarded by an if?

This comment has been minimized.

@apaszke

apaszke Apr 19, 2017

Member

Because you don't want to save anything if the function is not going to be executed. Otherwise it's going to be wasting a lot of memory. It was the cause of that recent memory regression.

this->save_std = std::move(save_std);
this->input = std::move(input);
this->weight = std::move(weight);
this->bias = std::move(bias);
}
}
virtual variable_list apply(const variable_list& gradOutputs) override;
@@ -54,12 +54,15 @@ struct ConvBackward : public Function, public ConvParams {
std::unique_ptr<torch::cudnn::Convolution> convolution)
: Function(std::move(flags))
, ConvParams(std::move(params))
, input_(std::move(input))
, weight_(std::move(weight))
, bias_(std::move(bias))
, columns(std::move(columns))
, ones(std::move(ones))
, convolution(std::move(convolution)) {}
, convolution(std::move(convolution)) {
if (is_executable) {
this->input_ = std::move(input);
this->weight_ = std::move(weight);
this->bias_ = std::move(bias);
this->columns = std::move(columns);
this->ones = std::move(ones);
}
}
virtual variable_list apply(const variable_list& gradOutputs) override;
ProTip! Use n and p to navigate between commits in a pull request.