Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Makes AccumulateGrad high priority in backwards passes #7604

Merged
merged 4 commits
May 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torch/csrc/autograd/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ struct FunctionTask {
, inputs(std::move(inputs)) {}
};

// Returns true when t2 should be (weakly) BEFORE t1 in the queue.
struct CompareFunctionTaskTime {
bool operator()(FunctionTask const & t1, FunctionTask const & t2) {
return t1.fn->sequence_nr() < t2.fn->sequence_nr();
Expand Down
18 changes: 13 additions & 5 deletions torch/csrc/autograd/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,22 @@ void deleteFunction(Function* function);
struct Function : std::enable_shared_from_this<Function> {
public:
/// Construct a new `Function` with `num_inputs` inputs and the given
/// `next_edges`.
/// `next_edges`. `sequence_nr` is currently the only hint used to
/// prioritize tasks in the backward() pass: tasks with higher sequence
/// numbers are executed before tasks with lower sequence numbers.
explicit Function(
uint32_t num_inputs = 0,
uint32_t num_inputs,
uint64_t sequence_nr,
edge_list&& next_edges = edge_list())
: sequence_nr_(next_sequence_nr_++),
num_inputs_(num_inputs),
next_edges_(std::move(next_edges)) {}
: sequence_nr_(sequence_nr),
num_inputs_(num_inputs),
next_edges_(std::move(next_edges)) {}

explicit Function(
uint32_t num_inputs = 0,
edge_list&& next_edges = edge_list())
: Function(num_inputs, next_sequence_nr_++, std::move(next_edges)) {}

/// Functions are neither copyable nor moveable.
Function(const Function& other) = delete;
Function(Function&& other) = delete;
Expand Down
8 changes: 7 additions & 1 deletion torch/csrc/autograd/functions/accumulate_grad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@
#include "torch/csrc/autograd/functions/utils.h"
#include "torch/csrc/utils/auto_gpu.h"

#include <cstdint>

using at::Tensor;

namespace torch { namespace autograd {

// AccumulateGrad sets sequence_nr to the maximum value so that it is
// always scheduled as early as possible during the backward pass.
AccumulateGrad::AccumulateGrad(Variable variable_)
: Function(/*num_inputs=*/1), variable(std::move(variable_)) {}
: Function(/*num_inputs=*/1
, /*sequence_nr=*/UINT64_MAX)
, variable(std::move(variable_)) {}

auto AccumulateGrad::apply(const variable_list& grads) -> variable_list {
// XXX: this method is not thread-safe!
Expand Down