Makes AccumulateGrad high priority in backwards passes (pytorch#7604)
* Makes accumulate_grad functions high priority in backwards passes

* Delegating constructor and comments

* Sequence_nr ain't pretty no more

mruberry authored and weiyangfb committed Jun 11, 2018
1 parent 0cab270 commit df3f5b0
Showing 3 changed files with 21 additions and 6 deletions.
1 change: 1 addition & 0 deletions torch/csrc/autograd/engine.cpp
@@ -64,6 +64,7 @@ struct FunctionTask {
       , inputs(std::move(inputs)) {}
 };
 
+// Returns true when t2 should be (weakly) BEFORE t1 in the queue.
 struct CompareFunctionTaskTime {
   bool operator()(FunctionTask const & t1, FunctionTask const & t2) {
     return t1.fn->sequence_nr() < t2.fn->sequence_nr();
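For context: the engine's ready queue orders FunctionTask objects with this predicate, and paired with a max-heap the task with the highest sequence_nr() surfaces first. A minimal sketch of the predicate's contract; Fn and the task layout below are illustrative stand-ins, not the real autograd types:

#include <cassert>
#include <cstdint>

// Stand-in for Function, reduced to the one field the comparator reads.
struct Fn {
  uint64_t seq;
  uint64_t sequence_nr() const { return seq; }
};

// Stand-in for FunctionTask.
struct FunctionTask { Fn* fn; };

// Same predicate as CompareFunctionTaskTime: "t1 < t2" in heap terms
// means t2 is popped first, i.e. higher sequence numbers win.
struct CompareFunctionTaskTime {
  bool operator()(FunctionTask const& t1, FunctionTask const& t2) {
    return t1.fn->sequence_nr() < t2.fn->sequence_nr();
  }
};

int main() {
  Fn a{3}, b{7};
  FunctionTask ta{&a}, tb{&b};
  // ta compares "less", so a max-heap keyed on this predicate would
  // schedule tb (the newer node) before ta.
  assert(CompareFunctionTaskTime{}(ta, tb));
}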
18 changes: 13 additions & 5 deletions torch/csrc/autograd/function.h
@@ -87,14 +87,22 @@ void deleteFunction(Function* function);
 struct Function : std::enable_shared_from_this<Function> {
  public:
   /// Construct a new `Function` with `num_inputs` inputs and the given
-  /// `next_edges`.
+  /// `next_edges`. sequence_nr is a (currently THE) hint to prioritization
+  /// in the backward() pass, with higher sequence numbers prioritized
+  /// before lower sequence numbers.
   explicit Function(
-      uint32_t num_inputs = 0,
+      uint32_t num_inputs,
+      uint64_t sequence_nr,
       edge_list&& next_edges = edge_list())
-      : sequence_nr_(next_sequence_nr_++),
-        num_inputs_(num_inputs),
-        next_edges_(std::move(next_edges)) {}
+      : sequence_nr_(sequence_nr),
+        num_inputs_(num_inputs),
+        next_edges_(std::move(next_edges)) {}
+
+  explicit Function(
+      uint32_t num_inputs = 0,
+      edge_list&& next_edges = edge_list())
+      : Function(num_inputs, next_sequence_nr_++, std::move(next_edges)) {}
 
   /// Functions are neither copyable nor moveable.
   Function(const Function& other) = delete;
   Function(Function&& other) = delete;
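The refactor splits the old single constructor into a primary constructor taking an explicit sequence_nr, plus a delegating convenience overload that preserves the old auto-increment behavior. A self-contained sketch of the pattern; Node, Edge, and the member names are simplified stand-ins for the real classes:

#include <atomic>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

struct Edge {};
using edge_list = std::vector<Edge>;

struct Node {
  // Primary constructor: the caller picks the sequence number
  // (an AccumulateGrad-style node would pass UINT64_MAX).
  Node(uint32_t num_inputs, uint64_t sequence_nr,
       edge_list&& next_edges = edge_list())
      : sequence_nr_(sequence_nr),
        num_inputs_(num_inputs),
        next_edges_(std::move(next_edges)) {}

  // Convenience constructor: delegates to the primary one, drawing the
  // next monotonically increasing number, exactly what the old single
  // constructor did inline.
  explicit Node(uint32_t num_inputs = 0,
                edge_list&& next_edges = edge_list())
      : Node(num_inputs, next_sequence_nr_++, std::move(next_edges)) {}

  uint64_t sequence_nr_;
  uint32_t num_inputs_;
  edge_list next_edges_;
  static std::atomic<uint64_t> next_sequence_nr_;
};

std::atomic<uint64_t> Node::next_sequence_nr_{0};

int main() {
  Node a, b;                 // sequence numbers 0 and 1 via delegation
  Node leaf(1, UINT64_MAX);  // pinned to the front of the queue
  std::cout << a.sequence_nr_ << ' ' << b.sequence_nr_ << ' '
            << leaf.sequence_nr_ << '\n';
}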
8 changes: 7 additions & 1 deletion torch/csrc/autograd/functions/accumulate_grad.cpp
@@ -7,12 +7,18 @@
 #include "torch/csrc/autograd/functions/utils.h"
 #include "torch/csrc/utils/auto_gpu.h"
 
+#include <cstdint>
+
 using at::Tensor;
 
 namespace torch { namespace autograd {
 
+// AccumulateGrad sets sequence_nr to the max value so it's always called
+// ASAP during backwards.
 AccumulateGrad::AccumulateGrad(Variable variable_)
-    : Function(/*num_inputs=*/1), variable(std::move(variable_)) {}
+    : Function(/*num_inputs=*/1
+    , /*sequence_nr=*/UINT64_MAX)
+    , variable(std::move(variable_)) {}
 
 auto AccumulateGrad::apply(const variable_list& grads) -> variable_list {
   // XXX: this method is not thread-safe!
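Taken together, the three diffs mean a task pinned to UINT64_MAX drains before every ordinary task, so gradient accumulation into leaf variables runs as soon as it becomes ready. A toy simulation of the drain order, assuming the ready queue behaves like std::priority_queue; Task and Compare are hypothetical stand-ins, not the real engine:

#include <cstdint>
#include <iostream>
#include <queue>
#include <vector>

// Toy ready-queue entry, keyed by sequence number only.
struct Task { uint64_t sequence_nr; };

struct Compare {
  bool operator()(Task const& a, Task const& b) {
    return a.sequence_nr < b.sequence_nr;  // max-heap: highest pops first
  }
};

int main() {
  std::priority_queue<Task, std::vector<Task>, Compare> ready;
  ready.push({3});            // ordinary nodes, in creation order
  ready.push({7});
  ready.push({UINT64_MAX});   // AccumulateGrad-style leaf accumulator
  while (!ready.empty()) {
    std::cout << ready.top().sequence_nr << '\n';
    ready.pop();
  }
  // Prints 18446744073709551615 first (the accumulator runs ASAP),
  // then the remaining tasks in reverse creation order: 7, then 3.
}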
