Update on "Have FutureNCCL record streams w/ allocator in addCallback"
This commit is part of a stack that reworks FutureNCCL in order to extract a generic CUDA-aware Future subclass. The stack deliberately breaks up this transition into elementary changes, to make it easier to verify that the behavior is preserved (or to highlight how it gets changed).

---

There are two ways to add a callback to a Future: `then` and `addCallback` (with the former deferring to the latter). FutureNCCL only "patched" `then`, which left `addCallback` unsupported. By patching `addCallback` instead, we cover both.
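
To make the relationship concrete, here is a minimal sketch with made-up names (`SimpleFuture` is a stand-in, not the actual `ivalue::Future` API): because `then` bottoms out in `addCallback`, overriding the latter is enough to cover both entry points.

```cpp
#include <functional>
#include <memory>
#include <vector>

// Illustrative stand-in for the real Future class; names and signatures
// are assumptions, not the actual API.
class SimpleFuture {
 public:
  virtual ~SimpleFuture() = default;

  // Subclasses (e.g. a CUDA-aware future) can override this single hook.
  virtual void addCallback(std::function<void()> callback) {
    if (completed_) {
      callback();
      return;
    }
    callbacks_.push_back(std::move(callback));
  }

  // then() defers to addCallback(): the callback runs once this future
  // completes, and its completion is propagated to a child future.
  std::shared_ptr<SimpleFuture> then(std::function<void()> callback) {
    auto child = std::make_shared<SimpleFuture>();
    addCallback([callback = std::move(callback), child]() {
      callback();
      child->markCompleted();
    });
    return child;
  }

  virtual void markCompleted() {
    completed_ = true;
    for (auto& cb : callbacks_) {
      cb();
    }
    callbacks_.clear();
  }

 private:
  bool completed_ = false;
  std::vector<std::function<void()>> callbacks_;
};
```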

The high-level goal of this change, though, is to remove all CUDA-specific logic from `then` and move it either into `markCompleted` or into a wrapper around the callback. This will take a few more steps to achieve.
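
A rough sketch of where this is headed, reusing the hypothetical `SimpleFuture` above; the CUDA bookkeeping is only indicated in comments, not the real c10::cuda calls:

```cpp
// Hypothetical CUDA-aware subclass: the device-specific work lives in a
// wrapper around the callback (and, eventually, in markCompleted()), so
// the inherited then() needs no CUDA knowledge at all.
class CudaAwareFuture : public SimpleFuture {
 public:
  void addCallback(std::function<void()> callback) override {
    SimpleFuture::addCallback([callback = std::move(callback)]() {
      // Placeholder for the CUDA-specific steps this stack moves here:
      // switch to the stream(s) the result was produced on and record the
      // result's storages with the CUDA caching allocator, then run the
      // user-provided callback.
      callback();
    });
  }
};
```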

Differential Revision: [D25177558](https://our.internmc.facebook.com/intern/diff/D25177558/)

[ghstack-poisoned]
lw committed Dec 9, 2020
2 parents 5e72780 + 203376d commit b4bd461
Showing 88 changed files with 1,923 additions and 860 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -75,7 +75,7 @@ jobs:
- name: Run flake8
run: |
set -eux
pip install flake8==3.8.2 flake8-bugbear==20.1.4 flake8-comprehensions==3.3.0 flake8-executable==2.0.4 flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
pip install -r requirements-flake8.txt
flake8 --version
flake8 | tee ${GITHUB_WORKSPACE}/flake8-output.txt
- name: Add annotations
31 changes: 0 additions & 31 deletions .travis.aten.yml

This file was deleted.

2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -891,7 +891,7 @@ which is in PyTorch's `requirements.txt`.
## Pre-commit tidy/linting hook

We use clang-tidy and flake8 (installed with flake8-bugbear,
flake8-comprehensions, flake8-mypy, and flake8-pyi) to perform additional
flake8-comprehensions, flake8-pyi, and others) to perform additional
formatting and semantic checking of code. We provide a pre-commit git hook for
performing these checks, before a commit is created:

2 changes: 1 addition & 1 deletion README.md
@@ -176,7 +176,7 @@ conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_ex
On Linux
```bash
# Add LAPACK support for the GPU if needed
conda install -c pytorch magma-cuda102 # or [ magma-cuda101 | magma-cuda100 | magma-cuda92 ] depending on your cuda version
conda install -c pytorch magma-cuda110 # or the magma-cuda* that matches your CUDA version from https://anaconda.org/pytorch/repo
```

On MacOS
4 changes: 2 additions & 2 deletions aten/src/ATen/BatchingRegistrations.cpp
@@ -941,8 +941,8 @@ Tensor new_empty_strided_batching_rule(
size.size(), ") must match dimensionality of strides (",
stride.size(), ")");
auto storage_size = native::storage_size_for(size, stride);
for (int64_t idx = 0; idx < physical_strides.size(); ++idx) {
physical_strides[idx] *= storage_size;
for (auto& physical_stride : physical_strides) {
physical_stride *= storage_size;
}

// physical_strides = [B1 * B2 * S, B2 * S, S] + strides
1 change: 1 addition & 0 deletions aten/src/ATen/Config.h.in
@@ -8,6 +8,7 @@

#define AT_MKLDNN_ENABLED() @AT_MKLDNN_ENABLED@
#define AT_MKL_ENABLED() @AT_MKL_ENABLED@
#define AT_FFTW_ENABLED() @AT_FFTW_ENABLED@
#define AT_NNPACK_ENABLED() @AT_NNPACK_ENABLED@
#define CAFFE2_STATIC_LINK_CUDA() @CAFFE2_STATIC_LINK_CUDA_INT@
#define AT_BUILD_WITH_BLAS() @USE_BLAS@
4 changes: 2 additions & 2 deletions aten/src/ATen/NamedTensorUtils.cpp
@@ -264,11 +264,11 @@ static std::vector<Dimname> compute_dot_product_outnames(
}
std::vector<Dimname> outnames(num_outnames, Dimname::wildcard());
int64_t index = 0;
for (int64_t j = 0; j < tensor_names.size(); ++j) {
for (size_t j = 0; j < tensor_names.size(); ++j) {
if (j == tensor_dotted_dim) continue;
outnames[index++] = tensor_names[j];
}
for (int64_t j = 0; j < other_names.size(); ++j) {
for (size_t j = 0; j < other_names.size(); ++j) {
if (j == other_dotted_dim) continue;
outnames[index++] = other_names[j];
}
2 changes: 2 additions & 0 deletions aten/src/ATen/SparseTensorImpl.cpp
@@ -46,6 +46,8 @@ SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::Typ
AT_ASSERT(values_.sizes() == IntArrayRef({0}));
AT_ASSERT(values_.device() == indices_.device());
AT_ASSERT(values_.device() == device());

is_non_overlapping_and_dense_ = false;
}

IntArrayRef SparseTensorImpl::strides() const {
4 changes: 2 additions & 2 deletions aten/src/ATen/TensorIterator.cpp
@@ -939,8 +939,8 @@ TensorIterator TensorIterator::reduce_op(Tensor& out1, Tensor& out2, const Tenso
}

void TensorIteratorBase::populate_operands(TensorIteratorConfig& config) {
for (int i = 0; i < config.tensors_.size(); i++) {
operands_.emplace_back(std::move(config.tensors_[i]));
for (auto& tensor: config.tensors_) {
operands_.emplace_back(std::move(tensor));
}
num_outputs_ = config.num_outputs_;
}
4 changes: 2 additions & 2 deletions aten/src/ATen/TensorNames.cpp
@@ -61,10 +61,10 @@ TensorNames::TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end) {
}

TensorNames& TensorNames::unifyFromRightInplace(const TensorNames& other, const char* op_name) {
int64_t size_diff = std::labs(names_.size() - other.names_.size());
size_t size_diff = std::labs(names_.size() - other.names_.size());

if (names_.size() > other.names_.size()) {
for (int64_t idx = size_diff; idx < names_.size(); ++idx) {
for (size_t idx = size_diff; idx < names_.size(); ++idx) {
names_[idx] = names_[idx].unify(other.names_[idx - size_diff], op_name);
}
} else {
1 change: 1 addition & 0 deletions aten/src/ATen/ThreadLocalState.cpp
@@ -19,6 +19,7 @@ ThreadLocalState::ThreadLocalState(bool keep_grad_mode)
grad_mode_enabled_ = GradMode::is_enabled();
}
#endif
bumped_record_all_functions_ = at::checkRecordAllFunctions();
}

/* static */
24 changes: 23 additions & 1 deletion aten/src/ATen/ThreadLocalState.h
@@ -38,25 +38,47 @@ class TORCH_API ThreadLocalState {
bool grad_mode_enabled_;
#endif

// Whether pre-sampling RecordFunction optimization was enabled
bool bumped_record_all_functions_ = false;

friend class ThreadLocalStateGuard;
};

// Guard to set and reset the thread local state
class TORCH_API ThreadLocalStateGuard {
public:
explicit ThreadLocalStateGuard(const ThreadLocalState& state)
: prev_state_(ThreadLocalState()) {
: prev_state_(ThreadLocalState()),
bumped_record_all_functions_(state.bumped_record_all_functions_) {
// Special handling of RecordFunction pre-sampling optimization:
// pre-sampling is enabled (bumped) when there're non-sampled
// (or high-frequency) global or TLS callbacks.
//
// ThreadLocalStateGuard simply resets RecordFunction's TLS and
// hence its thread local callbacks.
//
// Checking if the pre-sampling was enabled and preserving it in the
// async task by calling bumpRecordAllFunctions() and the corresponding
// releaseRecordAllFunctions()
if (bumped_record_all_functions_) {
at::bumpRecordAllFunctions();
}
// set the given state across the thread boundary
ThreadLocalState::setThreadLocalState(state);
}

~ThreadLocalStateGuard() {
// restore previously set variables
ThreadLocalState::setThreadLocalState(prev_state_);
if (bumped_record_all_functions_) {
at::releaseRecordAllFunctions();
}
}

private:
const ThreadLocalState prev_state_;
// Whether pre-sampling RecordFunction optimization was enabled
bool bumped_record_all_functions_ = false;
};

template <typename T>
81 changes: 49 additions & 32 deletions aten/src/ATen/core/dispatch/Dispatcher.h
@@ -371,28 +371,39 @@ inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandle<Return(A
const KernelFunction& kernel = op.operatorIterator_->op.lookup(dispatchKey);

#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
// Check if we need to run callbacks registered with RecordFunction
// If true and callbacks need inputs, we box the arguments and pass
// them into the callbacks and also into the kernel call

// Note: for perf reasons we wouldn't want to pass arguments into
// the function call or prematurely box them
at::RecordFunction guard(at::RecordScope::FUNCTION);
if (C10_UNLIKELY(guard.isActive())) {
if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) {
int64_t seq_num = -1;
// Setting sequence number in the Autograd case to associate
// the forward range with the corresponding Autograd node
if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
seq_num = at::sequence_number::peek();
}
if (guard.needsInputs()) {
torch::jit::Stack stack = impl::boxArgs(args...);
guard.before(op, stack, seq_num);
} else {
guard.before(op, seq_num);
// By default, when there're no high-frequency or non-sampled callbacks,
// RecordFunction is pre-sampled as a perf optimization;
// shouldRunRecordFunction checks whether RecordFunction should be executed,
// and sets pre_sampled boolean argument value to whether pre-sampling was used -
// this boolean is passed into RecordFunction to adjust the sampling rates of
// the callbacks
bool pre_sampled = false;
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
// Check if we need to run callbacks registered with RecordFunction
// If true and callbacks need inputs, we box the arguments and pass
// them into the callbacks and also into the kernel call

// Note: for perf reasons we wouldn't want to pass arguments into
// the function call or prematurely box them
at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
if (C10_UNLIKELY(guard.isActive())) {
if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) {
int64_t seq_num = -1;
// Setting sequence number in the Autograd case to associate
// the forward range with the corresponding Autograd node
if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
seq_num = at::sequence_number::peek();
}
if (guard.needsInputs()) {
torch::jit::Stack stack = impl::boxArgs(args...);
guard.before(op, stack, seq_num);
} else {
guard.before(op, seq_num);
}
}
}
// keeping the guard alive while executing the kernel
return kernel.template call<Return, Args...>(op, std::forward<Args>(args)...);
}
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
return kernel.template call<Return, Args...>(op, std::forward<Args>(args)...);
@@ -429,20 +440,26 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
const auto& kernel = entry.lookup(dispatchKey);

#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
// using already existing stack to record function execution in observers
at::RecordFunction guard(at::RecordScope::FUNCTION);
if (C10_UNLIKELY(guard.isActive())) {
if (shouldRecord(dispatchKey) && entry.isObserved()) {
int64_t seq_num = -1;
if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
seq_num = at::sequence_number::peek();
}
if (guard.needsInputs()) {
guard.before(op, *stack, seq_num);
} else {
guard.before(op, seq_num);
bool pre_sampled = false;
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
// using already existing stack to record function execution in observers
at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
if (C10_UNLIKELY(guard.isActive())) {
if (shouldRecord(dispatchKey) && entry.isObserved()) {
int64_t seq_num = -1;
if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
seq_num = at::sequence_number::peek();
}
if (guard.needsInputs()) {
guard.before(op, *stack, seq_num);
} else {
guard.before(op, seq_num);
}
}
}
// keeping the guard alive while executing the kernel
kernel.callBoxed(op, stack);
return;
}
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
kernel.callBoxed(op, stack);
8 changes: 2 additions & 6 deletions aten/src/ATen/native/Convolution.cpp
@@ -177,14 +177,10 @@ auto ConvParams::needs_64bit_indexing_no_split(const at::Tensor& input, const at
int64_t outsize = 1;
if (transposed) {
std::vector<int64_t> o = conv_input_size(input.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups);
for (int64_t i = 1; i < o.size(); i++) {
outsize *= o[i];
}
outsize = prod_intlist(o.begin() + 1, o.end());
} else {
std::vector<int64_t> o = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation);
for (int64_t i = 1; i < o.size(); i++) {
outsize *= o[i];
}
outsize = prod_intlist(o.begin() + 1, o.end());
}
return outsize > int_max;
}
2 changes: 1 addition & 1 deletion aten/src/ATen/native/ForeachOpsKernels.cpp
@@ -204,7 +204,7 @@ std::vector<Tensor> foreach_tensor_##NAME##_slow(TensorList tensors1, TensorList
\
std::vector<Tensor> result; \
result.reserve(tensors1.size()); \
for (int i = 0; i < tensors1.size(); i++) { \
for (size_t i = 0; i < tensors1.size(); i++) { \
result.emplace_back(at::NAME(tensors1[i], tensors2[i])); \
} \
\
1 change: 1 addition & 0 deletions aten/src/ATen/native/Normalization.cpp
@@ -415,6 +415,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(

bool use_cudnn = false;
use_cudnn = (input.is_cuda()
&& input.scalar_type() != at::kBFloat16 && weight.scalar_type() != at::kBFloat16
&& (input.scalar_type() != at::kHalf
|| weight.scalar_type() == at::kFloat)
&& weight.defined() && bias.defined()
