110 changes: 49 additions & 61 deletions test/test_c10d.py
@@ -2386,58 +2386,6 @@ def forward(self, x):
loss2 = criterion(output2, target)
loss2.backward()

@skip_if_not_nccl
@skip_if_not_multigpu
def test_no_used_parameters(self):
"""
Note: this test can be sped up by only running it on a CPU module
once DistributedDataParallel supports CPU modules.
"""
store = c10d.FileStore(self.file.name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

class NoUsedParameters(nn.Module):
def __init__(self):
super(NoUsedParameters, self).__init__()

# Make sure this module has some parameters, only to then decide
# to never use them from the `forward` function.
self.fc1 = nn.Linear(2, 10, bias=False)
self.fc2 = nn.Linear(10, 4, bias=False)
self.fc3 = nn.Linear(4, 4, bias=False)
self.relu = nn.ReLU()

def forward(self, x):
return x * 0.0

device_id = gpus_for_rank(self.world_size)[self.rank][0]
model = DistributedDataParallel(
NoUsedParameters().float().to(device_id),
device_ids=[device_id],
process_group=process_group,
find_unused_parameters=True,
)

batch_size = 4
input = torch.rand([batch_size, 2], dtype=torch.float)

# After initialization, no parameter has its gradient set.
for p in model.parameters():
self.assertTrue(p.requires_grad)
self.assertIsNone(p.grad)

# Run `forward` function.
model(input)

# Because none of the parameters were used, we expect reduction of all
# parameters to be executed right when initializing the reducer.
# Once `forward` returns, all of the parameters' gradients must be set.
for p in model.parameters():
self.assertTrue(p.requires_grad)
self.assertIsNotNone(p.grad)
self.assertTrue(torch.is_tensor(p.grad))
self.assertEqual(p.size(), p.grad.size())

@skip_if_not_nccl
@skip_if_not_multigpu
def test_no_grad(self):
@@ -2592,15 +2540,13 @@ def step_model(model, input, target):
torch.manual_seed(1337 + iteration)
input = input[torch.randperm(global_batch_size)]

@skip_if_not_nccl
@skip_if_not_multigpu
def test_ignored_output(self):
"""
Note: this test can be sped up by only running it on a CPU module
once DistributedDataParallel supports CPU modules.
Test that the output of a model can be ignored and that there is no
implicit requirement that `backward` gets called.
"""
store = c10d.FileStore(self.file.name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size)

class IgnoredOutput(nn.Module):
def __init__(self):
@@ -2614,17 +2560,59 @@ def forward(self, x):
x = self.relu(self.fc2(x))
return F.softmax(x, dim=1)

device_id = gpus_for_rank(self.world_size)[self.rank][0]
model = DistributedDataParallel(
IgnoredOutput().float().to(device_id),
device_ids=[device_id],
IgnoredOutput().float(),
process_group=process_group,
)

batch_size = 4
criterion = nn.CrossEntropyLoss()
input = torch.rand([batch_size, 2], dtype=torch.float)
target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)]).to(device_id)
target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

# Run a few iterations where we ignore the output.
for _ in range(4):
output = model(input)
del output

# Run a few iterations where we use the output.
for _ in range(4):
output = model(input)
loss = criterion(output, target)
loss.backward()

def test_ignored_output_with_unused_parameters(self):
"""
Test that the output of a model can be ignored and that there is no
implicit requirement that `backward` gets called, if not all model
parameters participated in computing the model output.
"""
store = c10d.FileStore(self.file.name, self.world_size)
process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size)

class IgnoredOutputWithUnusedParameters(nn.Module):
def __init__(self):
super(IgnoredOutputWithUnusedParameters, self).__init__()
self.fc1 = nn.Linear(2, 10, bias=False)
self.fc2 = nn.Linear(10, 4, bias=False)
self.fc3 = nn.Linear(4, 4, bias=False)
self.relu = nn.ReLU()

def forward(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
return F.softmax(x, dim=1)

model = DistributedDataParallel(
IgnoredOutputWithUnusedParameters().float(),
process_group=process_group,
find_unused_parameters=True,
)

batch_size = 4
criterion = nn.CrossEntropyLoss()
input = torch.rand([batch_size, 2], dtype=torch.float)
target = torch.LongTensor([random.randrange(4) for _ in range(batch_size)])

# Run a few iterations where we ignore the output.
for _ in range(4):
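For context, the pattern these tests exercise looks roughly as follows outside the test harness. This is an illustrative sketch, not part of the diff: the store path, rank, and world size are placeholders (a real job would get them from its launcher), and the toy module mirrors the IgnoredOutput class above.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as c10d
from torch.nn.parallel import DistributedDataParallel

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = nn.Linear(10, 4, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        return F.softmax(self.fc2(self.relu(self.fc1(x))), dim=1)

# Placeholder single-process setup; real jobs pass their actual rank/world size.
rank, world_size = 0, 1
store = c10d.FileStore("/tmp/ddp_example_store", world_size)
process_group = c10d.ProcessGroupGloo(store, rank, world_size)

model = DistributedDataParallel(Net().float(), process_group=process_group)
criterion = nn.CrossEntropyLoss()
input = torch.rand(4, 2)
target = torch.randint(0, 4, (4,))

# Iterations where the output is discarded: no backward() is required and
# no reduction is kicked off.
for _ in range(4):
    model(input)

# Iterations where the output is used: reductions run during backward().
for _ in range(4):
    loss = criterion(model(input), target)
    loss.backward()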
105 changes: 55 additions & 50 deletions torch/csrc/distributed/c10d/reducer.cpp
@@ -48,8 +48,8 @@ Reducer::Reducer(
expect_sparse_gradients_(std::move(expect_sparse_gradients)),
expect_autograd_hooks_(false),
require_finalize_(false),
has_marked_unused_parameters_(false),
next_bucket_(0),
has_marked_unused_parameters_(false),
backward_stats_base_(0) {
AT_ASSERTM(replicas_.size() >= 1, "Expected at least one model replica.");
AT_ASSERTM(replicas_[0].size() >= 1, "Expected at least one parameter.");
@@ -118,6 +118,10 @@ Reducer::Reducer(
for (size_t variable_index = 0; variable_index < variable_count;
variable_index++) {
auto& variable = replicas_[replica_index][variable_index];
const auto index = VariableIndex{
.replica_index = replica_index,
.variable_index = variable_index,
};

// The gradient accumulator function is lazily initialized once.
// Therefore we can use its presence in the autograd graph as
@@ -126,21 +130,14 @@

// Hook to execute after the gradient accumulator has executed.
hooks_.emplace_back(
grad_accumulator->add_post_hook(
torch::make_unique<LambdaPostHook>([=] {
std::lock_guard<std::mutex> lock(this->mutex_);
this->mark_variable_ready(
replica_index,
variable_index,
/* called_from_autograd= */ true);
})),
grad_accumulator->add_post_hook(torch::make_unique<LambdaPostHook>(
[=] { this->autograd_hook(index); })),
grad_accumulator);

// Map raw function pointer to replica index and parameter index.
// This is used later on when the autograd graph is traversed
// to check for parameters for which no gradient is computed.
func_[grad_accumulator.get()] =
std::make_tuple(replica_index, variable_index);
func_[grad_accumulator.get()] = index;

// The gradient accumulator is stored as weak_ptr in the autograd
// metadata of the variable, so we have to keep it alive here for
@@ -177,9 +174,9 @@ Reducer::~Reducer() noexcept(false) {
}
}

void Reducer::mark_variable_ready_dense(
size_t replica_index,
size_t variable_index) {
void Reducer::mark_variable_ready_dense(VariableIndex index) {
const auto replica_index = index.replica_index;
const auto variable_index = index.variable_index;
const auto& bucket_index = variable_locators_[variable_index];
auto& bucket = buckets_[bucket_index.bucket_index];
auto& replica = bucket.replicas[replica_index];
@@ -214,9 +211,9 @@ void Reducer::mark_variable_ready_dense(
}
}

void Reducer::mark_variable_ready_sparse(
size_t replica_index,
size_t variable_index) {
void Reducer::mark_variable_ready_sparse(VariableIndex index) {
const auto replica_index = index.replica_index;
const auto variable_index = index.variable_index;
const auto& bucket_index = variable_locators_[variable_index];
auto& bucket = buckets_[bucket_index.bucket_index];
auto& replica = bucket.replicas[replica_index];
@@ -235,22 +232,37 @@ void Reducer::mark_variable_ready_sparse(
replica.contents = grad;
}

// Called when the gradient for the specified variable is ready.
// It can be called from two places:
// - By an autograd thread after executing a gradient accumulator function.
// - By the `Reducer::prepare_for_backward` function if the variable doesn't
// show up in the autograd graph (and it wouldn't be called by autograd).
void Reducer::mark_variable_ready(
size_t replica_index,
size_t variable_index,
bool called_from_autograd) {
// The function `autograd_hook` is called after the gradient for a
// model parameter has been accumulated into its gradient tensor.
// This function is only to be called from the autograd thread.
void Reducer::autograd_hook(VariableIndex index) {
std::lock_guard<std::mutex> lock(this->mutex_);

// Ignore if we don't expect to be called.
// This may be the case if the user wants to accumulate gradients
// for a number of iterations before reducing them.
if (!expect_autograd_hooks_) {
return;
}

// If there are model parameters that went unused when computing the model
// output, they won't be part of the autograd graph, and won't receive
// gradients. These parameters are discovered in the `prepare_for_backward`
// function and their indexes stored in the `unused_parameters_` vector.
if (!has_marked_unused_parameters_ && !unused_parameters_.empty()) {
has_marked_unused_parameters_ = true;
for (const auto& unused_index : unused_parameters_) {
mark_variable_ready(unused_index);
}
}

// Finally, mark the variable for which this function was originally called.
mark_variable_ready(index);
}

void Reducer::mark_variable_ready(VariableIndex index) {
const auto replica_index = index.replica_index;
const auto variable_index = index.variable_index;
AT_ASSERTM(replica_index < replicas_.size(), "Out of range replica index.");
AT_ASSERTM(
variable_index < variable_locators_.size(),
@@ -293,9 +305,9 @@ void Reducer::mark_variable_ready(
}

if (bucket.expect_sparse_gradient) {
mark_variable_ready_sparse(replica_index, variable_index);
mark_variable_ready_sparse(index);
} else {
mark_variable_ready_dense(replica_index, variable_index);
mark_variable_ready_dense(index);
}

// TODO(@pietern): Make this work for both CPU/CUDA tensors.
@@ -316,14 +328,10 @@

// Run finalizer function once the final bucket was marked ready.
if (next_bucket_ == buckets_.size()) {
if (called_from_autograd) {
torch::autograd::Engine::get_default_engine().queue_callback([=] {
std::lock_guard<std::mutex> lock(this->mutex_);
this->finalize_backward();
});
} else {
finalize_backward();
}
torch::autograd::Engine::get_default_engine().queue_callback([=] {
std::lock_guard<std::mutex> lock(this->mutex_);
this->finalize_backward();
});
}
}

@@ -476,21 +484,19 @@ void Reducer::initialize_buckets(

// Traverse the autograd graph starting at the specified output.
// All parameters for which we have a pointer to their gradient accumulation
// functions and don't show up in this graph can be marked as ready
// for reduction immediately. Not doing this means we would deadlock waiting
// on a gradient for those parameters that will never be computed.
//
// Rough copy of torch::autograd::Engine::compute_dependencies.
//
// functions, but don't show up in the autograd graph will be marked ready
// for reduction as soon as the first autograd hook is called. This is not
// done immediately because the model output may be ignored, and we only
// want to start performing reductions on `torch.autograd.backward()`.
void Reducer::prepare_for_backward(
const std::vector<torch::autograd::Variable>& outputs) {
std::lock_guard<std::mutex> lock(mutex_);
std::unordered_set<torch::autograd::Function*> seen;
std::vector<torch::autograd::Function*> queue;

// Check that any prior reduction has finished.
// The variable `expect_autograd_hooks` is true until gradients for all
// parameters have been received and all buckets are ready.
// The variable `require_finalize_` is true until all gradients
// have been computed and reduction of all buckets has been kicked off.
if (require_finalize_) {
AT_ERROR(
"Expected to have finished reduction in the prior iteration before ",
@@ -513,7 +519,6 @@
}

// Reset accounting.
has_marked_unused_parameters_ = true;
expect_autograd_hooks_ = true;
next_bucket_ = 0;
backward_stats_base_ = current_time_in_nanos();
@@ -524,11 +529,14 @@
bucket.pending = bucket.replicas.size();
}

// Reset unused parameter accounting.
has_marked_unused_parameters_ = false;
unused_parameters_.clear();

// If no outputs are specified, we assume that autograd hooks for ALL
// variables will be called, and we don't have to search the autograd graph
// for presence of these hooks.
if (outputs.empty()) {
has_marked_unused_parameters_ = false;
return;
}

@@ -562,10 +570,7 @@
continue;
}

size_t replica_index;
size_t variable_index;
std::tie(replica_index, variable_index) = it.second;
mark_variable_ready(replica_index, variable_index);
unused_parameters_.push_back(it.second);
}
}

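Read together with the Python tests above, the reducer changes line up with the user-visible calls roughly as follows. This sketch paraphrases the comments in the diff and reuses the placeholder process_group, criterion, input, and target from the earlier sketch; it is one reading of the change, not authoritative documentation.

class NetWithUnusedParameter(nn.Module):
    def __init__(self):
        super(NetWithUnusedParameter, self).__init__()
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = nn.Linear(10, 4, bias=False)
        self.fc3 = nn.Linear(4, 4, bias=False)  # registered but never used
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return F.softmax(x, dim=1)

# Construction registers one post-hook per gradient accumulator; with this
# change each hook calls Reducer::autograd_hook(index) rather than
# mark_variable_ready directly.
model = DistributedDataParallel(
    NetWithUnusedParameter().float(),
    process_group=process_group,
    find_unused_parameters=True,
)

# forward() ends with Reducer::prepare_for_backward(outputs): per-bucket
# accounting is reset and parameters missing from the autograd graph (fc3)
# are recorded in unused_parameters_, but not yet marked ready.
output = model(input)

# Discarding the output means no autograd hook ever fires, so no reduction
# is started and nothing needs to be finalized.
del output

# backward() drives the autograd engine: the first hook to fire marks the
# recorded unused parameters ready, every hook then calls
# mark_variable_ready, and once the last bucket is ready a callback queued
# on the engine runs finalize_backward at the end of the backward pass.
loss = criterion(model(input), target)
loss.backward()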