diff --git a/aten/src/ATen/VmapModeRegistrations.cpp b/aten/src/ATen/VmapModeRegistrations.cpp index 6bf0f027cf7c..ab4556c8c415 100644 --- a/aten/src/ATen/VmapModeRegistrations.cpp +++ b/aten/src/ATen/VmapModeRegistrations.cpp @@ -79,15 +79,15 @@ TORCH_LIBRARY_IMPL(aten, VmapMode, m) { m.impl("rand", unsupportedRandomOp); m.impl("rand.generator", unsupportedRandomOp, TENSOROPTIONS>); - m.impl_UNBOXED("rand.names", unsupportedRandomOp, const TensorOptions&>); - m.impl_UNBOXED("rand.generator_with_names", unsupportedRandomOp, optional, const TensorOptions&>); + m.impl("rand.names", unsupportedRandomOp, TENSOROPTIONS>); + m.impl("rand.generator_with_names", unsupportedRandomOp, optional, TENSOROPTIONS>); m.impl("rand.out", unsupportedRandomOp_); m.impl("rand.generator_out", unsupportedRandomOp_, Tensor&>); m.impl("randn", unsupportedRandomOp); m.impl("randn.generator", unsupportedRandomOp, TENSOROPTIONS>); - m.impl_UNBOXED("randn.names", unsupportedRandomOp, const TensorOptions&>); - m.impl_UNBOXED("randn.generator_with_names", unsupportedRandomOp, optional, const TensorOptions&>); + m.impl("randn.names", unsupportedRandomOp, TENSOROPTIONS>); + m.impl("randn.generator_with_names", unsupportedRandomOp, optional, TENSOROPTIONS>); m.impl("randn.out", unsupportedRandomOp_); m.impl("randn.generator_out", unsupportedRandomOp_, Tensor&>); diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h index ab603a09c86b..7bdb0d996a13 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -265,20 +265,13 @@ namespace impl { return ivalue_to_arg, AllowDeprecatedTypes>::call(std::move(v)); } }; - template - struct ivalue_to_arg>, AllowDeprecatedTypes> final { - // If an argument is optional>, convert the IValue to a optional> and pass that - // to the operator. - static OptionalArray call(IValue&& v) { - return std::move(v).toOptionalIntArray(); - } - }; - template - struct ivalue_to_arg>, AllowDeprecatedTypes> final { - // If an argument is optional>, convert the IValue to a optional> and pass that - // to the operator. - static OptionalArray call(IValue&& v) { - return std::move(v).toOptionalDoubleArray(); + template + struct ivalue_to_arg>, AllowDeprecatedTypes> final { + // If an argument is optional>, convert the IValue to an optional> and pass that + // to the operator. OptionalArray is basically a optional> but impliticly convertible + // to optional>. + static OptionalArray call(IValue&& v) { + return ivalue_to_arg, AllowDeprecatedTypes>::call(std::move(v)); } }; diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index d2e72933b532..5ab5a9c0a501 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -705,14 +705,6 @@ struct CAFFE2_API IValue final { template optional toOptional(); - /// @private [doxygen private] - /// Only for use in generated code. - OptionalArray toOptionalIntArray(); - - /// @private [doxygen private] - /// Only for use in generated code. 
- OptionalArray toOptionalDoubleArray(); - /// @private [doxygen private] /// this is a shallow comparison of two IValues to test the object identity bool isSameIdentity(const IValue& rhs) const; diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index b3b53aed994c..46bde6103043 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -861,6 +861,36 @@ c10::List generic_to(IValue ivalue, _fake_type>) { return impl::toTypedList(std::move(ivalue).toList()); } +template +static std::vector createVectorFromList(const c10::detail::ListImpl* impl) { + std::vector result; + result.reserve(impl->list.size()); + for (size_t i = 0, N = impl->list.size(); i < N; ++i) { + result.push_back(impl->list[i].to()); + } + return result; +} + +template +static std::vector createVectorFromList(const c10::List& impl) { + std::vector result; + result.reserve(impl.size()); + for (size_t i = 0, N = impl.size(); i < N; ++i) { + result.push_back(impl[i]); + } + return result; +} + +template +OptionalArray generic_to(IValue ivalue, _fake_type>) { + if (ivalue.isNone()) { + return {}; + } + return createVectorFromList( + std::move(ivalue).to>() + ); +} + namespace detail { template std::array generic_to_array( @@ -952,16 +982,6 @@ inline T IValue::to() const& { return generic_to(*this, _fake_type{}); } -template -static std::vector createVectorFromList(const c10::detail::ListImpl* impl) { - std::vector result; - result.reserve(impl->list.size()); - for (size_t i = 0, N = impl->list.size(); i < N; ++i) { - result.push_back(impl->list[i].to()); - } - return result; -} - inline c10::List IValue::toIntList() && { AT_ASSERT(isIntList(), "Expected IntList but got ", tagKind()); return c10::List(moveToIntrusivePtr()); @@ -1211,20 +1231,6 @@ inline optional IValue::toOptional() { return this->to(); } -inline OptionalArray IValue::toOptionalIntArray() { - if (this->isNone()) { - return {}; - } - return this->toIntVector(); -} - -inline OptionalArray IValue::toOptionalDoubleArray() { - if (this->isNone()) { - return {}; - } - return this->toDoubleVector(); -} - inline bool IValue::isCustomClass() const { return torch::isCustomClass(*this); } diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index f3147bdf78aa..1d9f9d9d2a12 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -290,6 +290,27 @@ Tensor index(const Tensor & self, TensorList indices) { return iter.output(); } +Tensor quantized_index(const Tensor & self, TensorList indices) { + TORCH_INTERNAL_ASSERT( + self.qscheme() == c10::kPerTensorAffine || + self.qscheme() == c10::kPerTensorSymmetric, + "Indexing is only supported for per-Tensor quantized Tensors."); + + // For now, this is a naive implementation which does dq -> index -> q. + // TODO(future PR): improve performance by removing the copies. 
+ const auto& self_dq = self.dequantize(); + + TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); + + auto info = make_info(self_dq, indices); + auto iter = make_index_iterator(info); + index_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides); + at::Tensor res = iter.output(); + + return at::quantize_per_tensor( + res, self.q_scale(), self.q_zero_point(), self.scalar_type()); +} + Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) { TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); at::assert_no_internal_overlap(result); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 09b7c5f7e762..2bbde22c9389 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -106,9 +106,11 @@ variants: method - func: rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!) + use_c10_dispatcher: full variants: method - func: rename(Tensor(a) self, Dimname[]? names) -> Tensor(a) + use_c10_dispatcher: full variants: method - func: align_to(Tensor(a) self, Dimname[] names) -> Tensor(a) @@ -1738,6 +1740,7 @@ use_c10_dispatcher: full - func: empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor @@ -1942,6 +1945,7 @@ variants: function, method - func: unflatten.int(Tensor(a) self, int dim, int[] sizes, Dimname[]? names=None) -> Tensor(a) + use_c10_dispatcher: full variants: method - func: unflatten.Dimname(Tensor(a) self, Dimname dim, int[] sizes, Dimname[] names) -> Tensor(a) @@ -2023,6 +2027,7 @@ CPU, CUDA: frac_out - func: full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: full(int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -2200,6 +2205,7 @@ variants: function, method dispatch: CPU, CUDA: index + QuantizedCPU: quantized_index # NB: This function is special-cased in tools/autograd/gen_variable_type.py # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp: # - Tensor Tensor::index(ArrayRef indices) @@ -3164,6 +3170,7 @@ variants: function - func: ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: ones(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -3310,9 +3317,11 @@ use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: rand.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? 
pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: rand.generator_with_names(int[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: rand(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -3367,9 +3376,11 @@ use_c10_dispatcher: hacky_wrapper_for_legacy_signatures - func: randn.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: randn.generator_with_names(int[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: randn.out(int[] size, *, Tensor(a!) out) -> Tensor(a!) @@ -4442,6 +4453,7 @@ variants: function - func: zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures device_guard: False - func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -9565,7 +9577,8 @@ CPU: slow_conv_transpose2d_cpu CUDA: slow_conv_transpose2d_cuda -- func: slow_conv_transpose2d_backward.grad_output(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, Tensor columns, Tensor ones, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +- func: slow_conv_transpose2d_backward.grad_output(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, Tensor columns, Tensor ones, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv_transpose2d_backward_out_cpu @@ -9592,7 +9605,8 @@ CPU: slow_conv_transpose3d_cpu CUDA: slow_conv_transpose3d_cuda -- func: slow_conv_transpose3d_backward.grad_output(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] output_padding, int[3] dilation, Tensor finput, Tensor fgrad_input, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +- func: slow_conv_transpose3d_backward.grad_output(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] output_padding, int[3] dilation, Tensor finput, Tensor fgrad_input, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) 
grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv_transpose3d_backward_out_cpu @@ -9627,7 +9641,8 @@ CPU: slow_conv2d_forward_cpu CUDA: legacy::cuda::_thnn_conv2d_forward -- func: thnn_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +- func: thnn_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv2d_backward_out_cpu @@ -9660,7 +9675,8 @@ dispatch: CUDA: legacy::cuda::_thnn_conv_depthwise2d_forward -- func: thnn_conv_depthwise2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight) -> (Tensor(a!), Tensor(b!)) +- func: thnn_conv_depthwise2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, *, Tensor(a!) grad_input, Tensor(b!) grad_weight) -> (Tensor(a!), Tensor(b!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CUDA: thnn_conv_depthwise2d_backward_out @@ -9691,7 +9707,8 @@ dispatch: CPU: slow_conv3d_forward_cpu -- func: slow_conv3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!)? grad_input, Tensor(b!)? grad_weight, Tensor(c!)? grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +- func: slow_conv3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, Tensor finput, Tensor fgrad_input, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn dispatch: CPU: slow_conv3d_backward_out_cpu diff --git a/aten/src/ATen/record_function.cpp b/aten/src/ATen/record_function.cpp index d1b0acb87c28..a75b1a1295db 100644 --- a/aten/src/ATen/record_function.cpp +++ b/aten/src/ATen/record_function.cpp @@ -277,10 +277,12 @@ class CallbackManager { bool is_start) { try { if (is_start) { - ctx = rfcb.start()(rf); + ctx = rfcb.start() ? rfcb.start()(rf) : nullptr; } else { - rfcb.end()(rf, ctx.get()); + if (rfcb.end()) { + rfcb.end()(rf, ctx.get()); + } } return true; } catch (const std::exception &e) { diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 6b2e08576068..e9939667feb7 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -305,14 +305,16 @@ struct TORCH_API RecordFunction { */ class TORCH_API RecordFunctionCallback { public: + using StartCallback = std::unique_ptr(*)(const RecordFunction&); + using EndCallback = void (*)(const RecordFunction&, ObserverContext*); + // This interface supports observers that require passing an ObserverContext // between start and end callbacks. 
explicit RecordFunctionCallback( - std::function(const RecordFunction&)> start, - std::function end = - [](const RecordFunction&, ObserverContext*) {}): - start_(std::move(start)), - end_(std::move(end)) { + StartCallback start, + EndCallback end = nullptr) : + start_(start), + end_(end) { scopes_.fill(true); } @@ -368,18 +370,18 @@ class TORCH_API RecordFunctionCallback { return scopes_[(size_t)sc]; } - inline const std::function(const RecordFunction&)>& start() const { + inline StartCallback start() const { return start_; } - inline const std::function& end() const { + inline EndCallback end() const { return end_; } private: friend class CallbackManager; - std::function(const RecordFunction&)> start_; - std::function end_; + StartCallback start_; + EndCallback end_; bool(*should_run_)(const RecordFunctionCallback&) = nullptr; double sampling_prob_ = 1.0; std::array(RecordScope::NUM_SCOPES)> scopes_ = {}; diff --git a/binaries/record_function_benchmark.cc b/binaries/record_function_benchmark.cc index d47cedada40f..c80f46d75652 100644 --- a/binaries/record_function_benchmark.cc +++ b/binaries/record_function_benchmark.cc @@ -19,10 +19,10 @@ const float kLowSamplingProb = 0.0001; void addTestCallback( double sampling_prob = 1.0, - std::function(const at::RecordFunction&)> fn = - [](const at::RecordFunction&) { return nullptr; }) { + at::RecordFunctionCallback::StartCallback fn = + [](const at::RecordFunction&) -> std::unique_ptr { return nullptr; }) { auto cb = at::RecordFunctionCallback( - std::move(fn), + fn, [](const at::RecordFunction&, at::ObserverContext*) {}) .needsInputs(false); if (sampling_prob < 1.0) { @@ -106,10 +106,10 @@ int main(int argc, char** argv) { at::clearCallbacks(); std::cout << "Checking number of sampled observer invocations" << std::endl; - int cb_count = 0; + static int cb_count = 0; addTestCallback( kLowSamplingProb, - [&](const at::RecordFunction& fn) { + [](const at::RecordFunction&) -> std::unique_ptr { ++cb_count; return nullptr; } diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index 347df066cc90..34e17c37f774 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -691,6 +691,8 @@ inline DeviceType computeDeviceType(DispatchKey tid) { return DeviceType::Vulkan; } else if (tid == DispatchKey::Metal) { return DeviceType::Metal; + } else if (tid == DispatchKey::QuantizedCPU) { + return DeviceType::CPU; } else { AT_ASSERTM(false, "Unknown DispatchKey: ", tid); } diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 61bca7c6ffc0..4fcf86be55e2 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -515,59 +515,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) endif() if(USE_CUDA OR USE_ROCM) - list(APPEND Caffe2_GPU_HIP_JIT_FUSERS_SRCS - ${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp - ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp - ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp - ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/arith.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/compute_at.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/codegen.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/dispatch.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/expr_evaluator.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/executor.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/executor_kernel_arg.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/executor_launch_params.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/executor_utils.cpp - ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/fusion.cpp 
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/graph_fuser.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/index_compute.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/instrumentation.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_base_nodes.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_cloner.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_graphviz.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_nodes.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_iostream.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/iter_visitor.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_cache.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_ir.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_ir_builder.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_ir_printer.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_index.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_alias_memory.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_insert_syncs.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_loops.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_thread_predicate.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_unroll.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_utils.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_validation.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower2device.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/manager.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/mutator.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/parser.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/partition.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/predicate_compute.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/register_interface.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/scheduler.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/shape_inference.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/tensor_view.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_iter.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_replay.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/transform_rfactor.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/type.cpp
-      ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp
-    )
+    append_filelist("libtorch_cuda_core_sources" Caffe2_GPU_HIP_JIT_FUSERS_SRCS)
   endif()
 
   if(USE_CUDA)
diff --git a/caffe2/operators/async_net_barrier_op.cc b/caffe2/operators/async_net_barrier_op.cc
new file mode 100644
index 000000000000..25d10e673eac
--- /dev/null
+++ b/caffe2/operators/async_net_barrier_op.cc
@@ -0,0 +1,50 @@
+#include "caffe2/operators/async_net_barrier_op.h"
+
+namespace caffe2 {
+
+namespace {
+std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>>
+asyncBarrierOpDevInfer(const OperatorDef& def) {
+  auto op_device =
+      def.has_device_option() ? def.device_option() : DeviceOption();
+  ArgumentHelper helper(def);
+  auto cross_device = helper.GetSingleArgument<int>("cross_device", 0);
+  std::vector<DeviceOption> opt;
+  for (int i = 0; i < def.input().size(); ++i) {
+    if (cross_device == 1) {
+      DeviceOption dev;
+      dev.set_device_type(op_device.device_type());
+      dev.set_device_id(i);
+      opt.push_back(dev);
+    } else {
+      opt.push_back(op_device);
+    }
+  }
+  return std::make_pair(opt, opt);
+}
+}
+
+OPERATOR_SCHEMA(AsyncNetBarrier)
+    .NumInputs(1, INT_MAX)
+    .NumOutputs(1, INT_MAX)
+    .IdenticalTypeAndShape()
+    .InputsCanCrossDevices()
+    .AllowOneToOneInplace()
+    .DeviceInferenceFunction(asyncBarrierOpDevInfer)
+    .SetDoc(R"DOC(
+This is pretty much a no-op operator, since its only purpose is to make sure that
+async_scheduling will schedule certain operations earlier than others.
+
+Example where this operator works well: a mixture of data-parallel and model-
+parallel training, where one wants to force all copies to start before the
+data-parallel part starts.
+)DOC")
+    .Arg(
+        "cross_device",
+        "Specifies whether inputs should be placed across different devices when inferring device options");
+
+SHOULD_NOT_DO_GRADIENT(AsyncNetBarrier);
+REGISTER_CPU_OPERATOR(AsyncNetBarrier, AsyncNetBarrierOp<CPUContext>);
+
+
+} // namespace caffe2
diff --git a/caffe2/operators/async_net_barrier_op.cu b/caffe2/operators/async_net_barrier_op.cu
new file mode 100644
index 000000000000..b516c4c14177
--- /dev/null
+++ b/caffe2/operators/async_net_barrier_op.cu
@@ -0,0 +1,8 @@
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/async_net_barrier_op.h"
+
+namespace caffe2 {
+
+REGISTER_CUDA_OPERATOR(AsyncNetBarrier, AsyncNetBarrierOp<CUDAContext>);
+
+} // namespace caffe2
diff --git a/caffe2/operators/async_net_barrier_op.h b/caffe2/operators/async_net_barrier_op.h
new file mode 100644
index 000000000000..9b44db317a7a
--- /dev/null
+++ b/caffe2/operators/async_net_barrier_op.h
@@ -0,0 +1,30 @@
+#ifndef CAFFE2_OPERATORS_ASYNC_BARRIER_OP_H_
+#define CAFFE2_OPERATORS_ASYNC_BARRIER_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/export_caffe2_op_to_c10.h"
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+template <typename Context>
+class AsyncNetBarrierOp : public Operator<Context> {
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+  USE_SIMPLE_CTOR_DTOR(AsyncNetBarrierOp)
+
+  bool RunOnDevice() override {
+    // This is pretty much a no-op operator, since its only purpose is to make
+    // sure that async_scheduling will schedule certain operations earlier
+    // than others.
+    //
+    // Example where this operator works well: a mixture of data-parallel and
+    // model-parallel training, where one wants to force all copies to start
+    // before the data-parallel part starts.
+ return true; + } +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_ASYNC_BARRIER_OP_H_ diff --git a/caffe2/python/operator_test/async_net_barrier_test.py b/caffe2/python/operator_test/async_net_barrier_test.py new file mode 100644 index 000000000000..e2c0ea0ccc1a --- /dev/null +++ b/caffe2/python/operator_test/async_net_barrier_test.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +import caffe2.python.hypothesis_test_util as hu +import hypothesis.strategies as st +import numpy as np +from caffe2.python import core +from hypothesis import given + + +class TestAsyncNetBarrierOp(hu.HypothesisTestCase): + @given( + n=st.integers(1, 5), + shape=st.lists(st.integers(0, 5), min_size=1, max_size=3), + **hu.gcs + ) + def test_async_net_barrier_op(self, n, shape, dc, gc): + test_inputs = [(100 * np.random.random(shape)).astype(np.float32) for _ in range(n)] + test_input_blobs = ["x_{}".format(i) for i in range(n)] + + barrier_op = core.CreateOperator( + "AsyncNetBarrier", + test_input_blobs, + test_input_blobs, + device_option=gc, + ) + + def reference_func(*args): + self.assertEquals(len(args), n) + return args + + self.assertReferenceChecks(gc, barrier_op, test_inputs, reference_func) diff --git a/docs/cpp/requirements.txt b/docs/cpp/requirements.txt index 452aa3eadad0..731a0475be79 100644 --- a/docs/cpp/requirements.txt +++ b/docs/cpp/requirements.txt @@ -1,5 +1,5 @@ sphinx==3.1.2 -breathe==4.19.2 +breathe==4.25.0 exhale==0.2.3 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme bs4 diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index deb7a161e1d3..fa2e54844935 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -190,6 +190,11 @@ ("aten::quantile", datetime.date(2021, 1, 31)), ("aten::nanquantile", datetime.date(2021, 1, 31)), ("aten::_fft_with_size", datetime.date(2021, 1, 31)), + ("aten::thnn_conv_depthwise2d_backward", datetime.date(2021, 1, 31)), + ("aten::slow_conv3d_backward", datetime.date(2021, 1, 31)), + ("aten::thnn_conv2d_backward", datetime.date(2021, 1, 31)), + ("aten::slow_conv_transpose3d_backward", datetime.date(2021, 1, 31)), + ("aten::slow_conv_transpose2d_backward", datetime.date(2021, 1, 31)), ] def allow_listed(schema, allow_list): diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 10f36cc8e394..521c9f3ccf78 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -721,12 +721,40 @@ void checkTracedInputs(const TracedTestInputs& inputs) { TORCH_CHECK(found_mul); } +static bool bad_scope = false; +template +std::unique_ptr checkScopeCallback( + const at::RecordFunction& fn) { + if (fn.scope() == scope) { + ++(*cnt); + } else { + bad_scope = true; + } + return nullptr; +} + +template +void pushScopedCallback() { + at::addGlobalCallback( + at::RecordFunctionCallback(checkScopeCallback) + .scopes({scope})); +} + +// These cannot be function-local because that would prohibit them +// from being used as template arguments prior to C++17. 
+static size_t fun_cnt; +static size_t ts_fun_cnt; +static size_t user_scope_cnt; + void checkScopeCallbacks() { - bool found_function_scope = false; - bool found_method_scope = false; - bool found_user_scope = false; + static bool found_function_scope; + static bool found_method_scope; + static bool found_user_scope; + found_function_scope = false; + found_method_scope = false; + found_user_scope = false; at::addGlobalCallback(at::RecordFunctionCallback( - [&](const at::RecordFunction& fn) { + [](const at::RecordFunction& fn) -> std::unique_ptr { if (fn.scope() == at::RecordScope::FUNCTION && std::string(fn.name().str()) == "test_function") { found_function_scope = true; @@ -742,27 +770,13 @@ void checkScopeCallbacks() { return nullptr; })); - bool bad_scope = false; - auto pushScopedCallback = [&](at::RecordScope scope, size_t& cnt) { - at::addGlobalCallback( - at::RecordFunctionCallback( - [&bad_scope, &cnt, scope](const at::RecordFunction& fn) { - if (fn.scope() == scope) { - ++cnt; - } else { - bad_scope = true; - } - return nullptr; - }) - .scopes({scope})); - }; - - size_t fun_cnt = 0; - pushScopedCallback(at::RecordScope::FUNCTION, fun_cnt); - size_t ts_fun_cnt = 0; - pushScopedCallback(at::RecordScope::TORCHSCRIPT_FUNCTION, ts_fun_cnt); - size_t user_scope_cnt = 0; - pushScopedCallback(at::RecordScope::USER_SCOPE, user_scope_cnt); + bad_scope = false; + fun_cnt = 0; + pushScopedCallback(); + ts_fun_cnt = 0; + pushScopedCallback(); + user_scope_cnt = 0; + pushScopedCallback(); TORCH_CHECK(at::hasCallbacks()); @@ -788,33 +802,35 @@ static bool shouldRunCallback(const RecordFunctionCallback&) { return should_run; } -TEST(RecordFunctionTest, Basic) { +static TracedTestInputs traced_inputs; +static std::unordered_set ts_names; + +std::unique_ptr tracedInputsCallback( + const RecordFunction& fn) { + if (fn.scope() == RecordScope::FUNCTION) { + auto inputs = fn.inputs(); + std::vector> sizes; + for (const auto& input : inputs) { + if (input.isTensor()) { + sizes.push_back(input.toTensor().sizes().vec()); + } else if (input.isScalar()) { + sizes.push_back(std::vector()); + } + } + traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes)); + } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) { + ts_names.insert(fn.name().str()); + } + return nullptr; +} + +TEST(RecordFunctionTest, TracedTestInputs) { // disabling the inlining of method calls GraphOptimizerEnabledGuard opt_guard(false); // [(fn, [[sizes], [sizes], ...]), ...] 
- TracedTestInputs traced_inputs; - std::unordered_set ts_names; addGlobalCallback( - RecordFunctionCallback( - [&](const RecordFunction& fn) { - if (fn.scope() == RecordScope::FUNCTION) { - auto inputs = fn.inputs(); - std::vector> sizes; - for (const auto& input : inputs) { - if (input.isTensor()) { - sizes.push_back(input.toTensor().sizes().vec()); - } else if (input.isScalar()) { - sizes.push_back(std::vector()); - } - } - traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes)); - } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) { - ts_names.insert(fn.name().str()); - } - return nullptr; - }) - .needsInputs(true)); + RecordFunctionCallback(tracedInputsCallback).needsInputs(true)); TracedTestInputs eager_inputs, jit_inputs; { @@ -841,28 +857,36 @@ TEST(RecordFunctionTest, Basic) { checkTracedInputs(eager_inputs); checkTracedInputs(jit_inputs); at::clearCallbacks(); +} + +static int sampled_cb_ctr = 0; +std::unique_ptr sampledCallback(const RecordFunction& fn) { + if (std::string(fn.name().str()) == "test") { + ++sampled_cb_ctr; + } + return nullptr; +} + +static int non_sampled_cb_ctr = 0; +std::unique_ptr nonSampledCallback(const RecordFunction& fn) { + if (std::string(fn.name().str()) == "test") { + ++non_sampled_cb_ctr; + } + return nullptr; +} + +TEST(RecordFunctionTest, SampledCallbacks) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); // test sampled callbacks - int sampled_cb_ctr = 0; - auto setup_sampled_callback = [&sampled_cb_ctr](double sampling_prob) { - return addGlobalCallback(RecordFunctionCallback( - [&sampled_cb_ctr](const RecordFunction& fn) { - if (std::string(fn.name().str()) == "test") { - ++sampled_cb_ctr; - } - return nullptr; - }) - .samplingProb(sampling_prob)); + sampled_cb_ctr = 0; + auto setup_sampled_callback = [](double sampling_prob) { + return addGlobalCallback( + RecordFunctionCallback(sampledCallback).samplingProb(sampling_prob)); }; - int non_sampled_cb_ctr = 0; - addGlobalCallback(RecordFunctionCallback( - [&non_sampled_cb_ctr](const RecordFunction& fn) { - if (std::string(fn.name().str()) == "test") { - ++non_sampled_cb_ctr; - } - return nullptr; - })); + addGlobalCallback(RecordFunctionCallback(nonSampledCallback)); auto handle = setup_sampled_callback(0.5); @@ -897,13 +921,19 @@ TEST(RecordFunctionTest, Basic) { // test the scope of the callbacks checkScopeCallbacks(); clearCallbacks(); +} + +TEST(RecordFunctionTest, RecordFunctionGuard) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); + + static std::vector fn_names; + static std::mutex guard_mtx; // check record function guard - std::vector fn_names; - std::mutex mtx; addGlobalCallback(RecordFunctionCallback( - [&fn_names, &mtx](const RecordFunction& fn) { - std::lock_guard lock(mtx); + [](const RecordFunction& fn) -> std::unique_ptr { + std::lock_guard lock(guard_mtx); fn_names.push_back(fn.name().str()); return nullptr; })); @@ -925,20 +955,26 @@ TEST(RecordFunctionTest, Basic) { TORCH_CHECK(fn_names.size() == 1); TORCH_CHECK(fn_names[0] == "B"); clearCallbacks(); +} - // test add/remove - std::vector ids; - auto add_remove_test_add_cb = [&ids](size_t id) { - return addGlobalCallback(RecordFunctionCallback( - [&ids, id](const RecordFunction& fn) { - ids.push_back(id); - return nullptr ; - })); - }; +static std::vector ids; - auto h1 = add_remove_test_add_cb(1); - auto h2 = add_remove_test_add_cb(2); - auto h3 = add_remove_test_add_cb(3); +template +auto add_remove_test_add_cb() { + return 
addGlobalCallback(RecordFunctionCallback( + [](const RecordFunction& fn) -> std::unique_ptr { + ids.push_back(id); + return nullptr; + })); +} + +TEST(RecordFunctionTest, Callbacks) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); + + auto h1 = add_remove_test_add_cb<1>(); + auto h2 = add_remove_test_add_cb<2>(); + auto h3 = add_remove_test_add_cb<3>(); { RECORD_USER_SCOPE("test"); } @@ -969,8 +1005,7 @@ TEST(RecordFunctionTest, Basic) { // thread local / global callbacks ids.clear(); - addGlobalCallback(RecordFunctionCallback( - [&ids](const RecordFunction& fn) { ids.push_back(1); return nullptr; })); + add_remove_test_add_cb<1>(); { RECORD_USER_SCOPE("test"); } @@ -978,9 +1013,12 @@ TEST(RecordFunctionTest, Basic) { TORCH_CHECK(ids[0] == 1); ids.clear(); - auto th = std::thread([&ids]() { + auto th = std::thread([]() { addThreadLocalCallback(RecordFunctionCallback( - [&ids](const RecordFunction& fn) { ids.push_back(2); return nullptr; })); + [](const RecordFunction& fn) -> std::unique_ptr { + ids.push_back(2); + return nullptr; + })); { RECORD_USER_SCOPE("test_thread"); } }); @@ -1005,22 +1043,20 @@ TEST(RecordFunctionTest, Basic) { }; ids.clear(); { // START: global test - const int test_val = 123; - const std::string test_str = "test str"; addGlobalCallback(RecordFunctionCallback( - [test_val, test_str, &ids](const RecordFunction& /* unused */) { + [](const RecordFunction & + /* unused */) -> std::unique_ptr { auto ctx = std::make_unique(); - ctx->a = test_val; - ctx->b = test_str; + ctx->a = 123; + ctx->b = "test_str"; ids.push_back(1); return ctx; }, - [test_val, test_str]( - const RecordFunction& /* unused */, ObserverContext* ctx_ptr) { + [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) { auto ctx = dynamic_cast(ctx_ptr); TORCH_CHECK(ctx_ptr != nullptr); - TORCH_CHECK(ctx->a == test_val); - TORCH_CHECK(ctx->b == test_str); + TORCH_CHECK(ctx->a == 123); + TORCH_CHECK(ctx->b == "test_str"); })); { RECORD_USER_SCOPE("test"); } @@ -1030,23 +1066,23 @@ TEST(RecordFunctionTest, Basic) { ids.clear(); } // END: global test { // START: thread local test - auto ctx_th = std::thread([&ids]() { + auto ctx_th = std::thread([]() { const int test_val = 234; const std::string test_str = "test thread str"; addThreadLocalCallback(RecordFunctionCallback( - [test_val, test_str, &ids](const RecordFunction& /* unused */) { + [](const RecordFunction & + /* unused */) -> std::unique_ptr { auto ctx = std::make_unique(); - ctx->a = test_val; - ctx->b = test_str; + ctx->a = 234; + ctx->b = "test_thread_str"; ids.push_back(2); return ctx; }, - [test_val, test_str]( - const RecordFunction& /* unused */, ObserverContext* ctx_ptr) { + [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) { auto ctx = dynamic_cast(ctx_ptr); TORCH_CHECK(ctx_ptr != nullptr); - TORCH_CHECK(ctx->a == test_val); - TORCH_CHECK(ctx->b == test_str); + TORCH_CHECK(ctx->a == 234); + TORCH_CHECK(ctx->b == "test_thread_str"); })); // Will call both global and thread local callbacks. 
@@ -1060,14 +1096,21 @@ TEST(RecordFunctionTest, Basic) { } // END: thread local test clearCallbacks(); +} - // test should_run +TEST(RecordFunctionTest, ShouldRun) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); - bool ran = false; should_run = false; - addGlobalCallback(RecordFunctionCallback( - [&ran](const RecordFunction& fn) { ran = true; return nullptr; }) - .setShouldRun(shouldRunCallback)); + static bool ran = false; + addGlobalCallback( + RecordFunctionCallback( + [](const RecordFunction& fn) -> std::unique_ptr { + ran = true; + return nullptr; + }) + .setShouldRun(shouldRunCallback)); { RECORD_USER_SCOPE("test"); } @@ -1080,13 +1123,20 @@ TEST(RecordFunctionTest, Basic) { TORCH_CHECK(ran); clearCallbacks(); +} + +TEST(RecordFunctionTest, Basic) { + // disabling the inlining of method calls + GraphOptimizerEnabledGuard opt_guard(false); + + static std::string recorded_op; + static bool has_ids = false; // test propagation of TLS callbacks std::thread t([]() { RecordFunctionGuard enable_rec_fn; - std::string recorded_op; auto handle = addThreadLocalCallback(RecordFunctionCallback( - [&recorded_op](const RecordFunction& fn) { + [](const RecordFunction& fn) -> std::unique_ptr { recorded_op = fn.name().str(); return nullptr; })); @@ -1096,17 +1146,16 @@ TEST(RecordFunctionTest, Basic) { RECORD_USER_SCOPE("test_in_thread"); }); t_child.join(); - TORCH_CHECK(recorded_op == "test_in_thread"); + EXPECT_EQ(recorded_op, "test_in_thread"); removeCallback(handle); }); t.join(); clearCallbacks(); // test set ids - bool has_ids = false; addGlobalCallback( RecordFunctionCallback( - [&has_ids](const RecordFunction& fn) { + [](const RecordFunction& fn) -> std::unique_ptr { has_ids = fn.handle() > 0; return nullptr; }) @@ -1116,7 +1165,7 @@ TEST(RecordFunctionTest, Basic) { clearCallbacks(); has_ids = false; addGlobalCallback(RecordFunctionCallback( - [&has_ids](const RecordFunction& fn) { + [](const RecordFunction& fn) -> std::unique_ptr { has_ids = fn.handle() > 0; return nullptr; })); @@ -1126,10 +1175,10 @@ TEST(RecordFunctionTest, Basic) { } TEST(RecordFunctionTest, OperatorNameOverload) { - std::set operator_names; - + static std::set operator_names; at::addGlobalCallback(at::RecordFunctionCallback( - [&operator_names](const at::RecordFunction& fn) { + [](const at::RecordFunction& fn) + -> std::unique_ptr { c10::optional op_name = fn.operator_name(); if (op_name.has_value()) { @@ -1178,6 +1227,8 @@ void checkDebugInfo(c10::DebugInfoKind kind, int model_id) { } TEST(ThreadLocalDebugInfoTest, Basic) { + static std::atomic done{false}; + TORCH_CHECK( c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); auto debug_info = std::make_shared(); @@ -1190,10 +1241,9 @@ TEST(ThreadLocalDebugInfoTest, Basic) { // check that thread local debug info is propagated through fork calls TORCH_CHECK( c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); - std::atomic done{false}; { c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info); - at::launch([&done]() { + at::launch([]() { checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); done = true; }); @@ -1206,7 +1256,7 @@ TEST(ThreadLocalDebugInfoTest, Basic) { c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); done = false; auto handle = addGlobalCallback(RecordFunctionCallback( - [&done](const RecordFunction&) { + [](const RecordFunction&) -> std::unique_ptr { checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); done = true; return nullptr; 
@@ -1236,7 +1286,7 @@ TEST(ThreadLocalDebugInfoTest, Basic) { checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314); done = false; - at::launch([&done]() { + at::launch([]() { checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314); done = true; diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py index 1c66c8fb986f..f1e52fc38d32 100644 --- a/test/quantization/test_quantized_op.py +++ b/test/quantization/test_quantized_op.py @@ -2274,6 +2274,45 @@ def test_empty_batch(self): result = torch.ops.quantized.linear_dynamic(X, w_packed) self.assertEqual(result.shape, (0, 2)) + def test_advanced_indexing(self): + """ + Verifies that the x[:, [0], :, :] syntax works for quantized tensors. + """ + for dtype in (torch.qint8, torch.quint8, torch.qint32): + scale = 0.1 + zp = 0 + x_q = torch.quantize_per_tensor( + torch.randn(1, 4, 4, 4), scale, zp, dtype) + # reference + x_fp32 = x_q.dequantize() + + # single dim, single index + x_q_s1 = x_q[:, [0], :, :] + x_fp32_s1 = x_fp32[:, [0], :, :] + x_fp32_s1_ref = \ + torch.quantize_per_tensor(x_fp32_s1, scale, zp, dtype) + self.assertEqual(x_q_s1, x_fp32_s1_ref) + + # multiple dim, single index + x_q_s2 = x_q[:, [0], [2], :] + x_fp32_s2 = x_fp32[:, [0], [2], :] + x_fp32_s2_ref = \ + torch.quantize_per_tensor(x_fp32_s2, scale, zp, dtype) + self.assertEqual(x_q_s2, x_fp32_s2_ref) + + # single dim, multiple indices + x_q_s3 = x_q[:, [2, 0, 1], :, :] + x_fp32_s3 = x_fp32[:, [2, 0, 1], :, :] + x_fp32_s3_ref = \ + torch.quantize_per_tensor(x_fp32_s3, scale, zp, dtype) + self.assertEqual(x_q_s3, x_fp32_s3_ref) + + # multiple dim, multiple indices + x_q_s4 = x_q[:, [2, 0, 1], :, [1]] + x_fp32_s4 = x_fp32[:, [2, 0, 1], :, [1]] + x_fp32_s4_ref = \ + torch.quantize_per_tensor(x_fp32_s4, scale, zp, dtype) + self.assertEqual(x_q_s4, x_fp32_s4_ref) class TestDynamicQuantizedLinear(TestCase): diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index eca10839ae88..4dd370b303f2 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -353,9 +353,8 @@ libtorch_extra_sources = libtorch_core_jit_sources + [ def libtorch_sources(gencode_pattern = ":generate-code[{}]"): return libtorch_generated_sources(gencode_pattern) + libtorch_core_sources + libtorch_distributed_sources + libtorch_extra_sources -libtorch_cuda_sources = [ +libtorch_cuda_core_sources = [ "torch/csrc/cuda/comm.cpp", - "torch/csrc/cuda/nccl.cpp", "torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp", "torch/csrc/autograd/profiler_cuda.cpp", "torch/csrc/autograd/functions/comm.cpp", @@ -408,6 +407,10 @@ libtorch_cuda_sources = [ "torch/csrc/jit/tensorexpr/cuda_codegen.cpp", ] +libtorch_cuda_sources = libtorch_cuda_core_sources + [ + "torch/csrc/cuda/nccl.cpp", +] + torch_cpp_srcs = [ "torch/csrc/api/src/cuda.cpp", # this just forwards stuff, no real CUDA "torch/csrc/api/src/data/datasets/mnist.cpp", diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 488b7be9bd8a..7bf11a4d6316 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -172,9 +172,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { at::enableRecordFunction(enable); }); m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) { - auto cb = at::RecordFunctionCallback( - [](const at::RecordFunction&) { return nullptr; }, - [](const at::RecordFunction&, at::ObserverContext*) {}) + auto cb = 
at::RecordFunctionCallback(nullptr) .needsInputs(true) .samplingProb(sampling_prob); if (is_global) { diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index ac6ef84104f3..1c3c351eeb09 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -136,7 +136,7 @@ void pushProfilingCallbacks() { auto state_ptr = getProfilerTLSState(); TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback( - [](const at::RecordFunction& fn) { + [](const at::RecordFunction& fn) -> std::unique_ptr { auto state_ptr = getProfilerTLSState(); if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) { return std::make_unique(); diff --git a/torch/csrc/autograd/profiler_legacy.cpp b/torch/csrc/autograd/profiler_legacy.cpp index eb52aec8920d..e40da5bfda1f 100644 --- a/torch/csrc/autograd/profiler_legacy.cpp +++ b/torch/csrc/autograd/profiler_legacy.cpp @@ -414,7 +414,7 @@ void pushProfilingCallbacksLegacy() { auto state_ptr = getProfilerTLSState(); TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback( - [](const at::RecordFunction& fn) { + [](const at::RecordFunction& fn) -> std::unique_ptr { auto state_ptr = getProfilerTLSState(); if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) { return nullptr; diff --git a/torch/csrc/jit/codegen/cuda/shape_inference.h b/torch/csrc/jit/codegen/cuda/shape_inference.h index da2a2ed3f3a9..ede73d97afc1 100644 --- a/torch/csrc/jit/codegen/cuda/shape_inference.h +++ b/torch/csrc/jit/codegen/cuda/shape_inference.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include namespace torch { diff --git a/torch/csrc/jit/frontend/concrete_module_type.h b/torch/csrc/jit/frontend/concrete_module_type.h index 91dc3023bc83..ff829d101fc1 100644 --- a/torch/csrc/jit/frontend/concrete_module_type.h +++ b/torch/csrc/jit/frontend/concrete_module_type.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include diff --git a/torch/csrc/jit/passes/cuda_graph_fuser.h b/torch/csrc/jit/passes/cuda_graph_fuser.h index 0f821845613b..104a437104aa 100644 --- a/torch/csrc/jit/passes/cuda_graph_fuser.h +++ b/torch/csrc/jit/passes/cuda_graph_fuser.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 5d88264a2f2c..24ca9dbf9793 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -891,6 +891,10 @@ struct CodeImpl { } void emitWarn(Node* node) { + if (FLAGS_torch_jit_disable_warning_prints) { + return; + } + emitLoadInputs(node->inputs()); int32_t idx = -1; if (node->hasAttribute(attr::warn_id)) { diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index 025ac67f6f2e..279d41c20cb9 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -8,6 +8,8 @@ #include #include +C10_DECLARE_bool(torch_jit_disable_warning_prints); + namespace at { class Tensor; CAFFE2_API void launch(std::function func); diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp index 31750636d762..6bd3c29d78fb 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp 
@@ -35,6 +35,11 @@ C10_DEFINE_bool( true, "If this flag is set to false TorchScript will be using the legacy/original executor"); +C10_DEFINE_bool( + torch_jit_disable_warning_prints, + false, + "Disables warning.warn prints in TorchScript graph"); + constexpr size_t kDefaultNumProfiledRuns = 1; constexpr size_t kDefaultBailoutDepth = 20; diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h index 42464cecd89a..d7c9ede9294f 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.h +++ b/torch/csrc/jit/runtime/register_ops_utils.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 57db79699e07..11fb5dae2d6c 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -32,7 +32,6 @@ bool canRunNatively(Node* n) { const static std::unordered_set native_nodes{ "aten::flatten", "aten::narrow", - "aten::permute", "aten::reshape", "aten::slice", "aten::transpose", diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp index a75d187b2a49..f264423fdec2 100644 --- a/torch/csrc/jit/runtime/static/passes.cpp +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -71,11 +71,30 @@ void ConcatBatchMatMulBatchGather(std::shared_ptr& graph) { fuse.runOnGraph(graph); } +void ClipRangesGatherRangesLengthsToOffsets( + std::shared_ptr& graph) { + // TODO:: check restrictions for inputs; outputs not used elsewhere + std::string pattern = R"IR( + graph(%a, %b, %c, %d): + %y0 : Tensor = fb::clip_ranges(%b, %c) + %y1 : Tensor, %y2 : Tensor = fb::gather_ranges(%a, %y0) + %y3 : Tensor = fb::lengths_to_offsets(%y2, %d) + return (%y3, %y1))IR"; + std::string fused_pattern = R"IR( + graph(%a, %b, %c, %d): + %y0 : Tensor, %y1 : Tensor = fb::clip_ranges_gather_lengths_to_offsets(%a, %b, %c, %d) + return (%y1, %y0))IR"; + SubgraphRewriter fuse; + fuse.RegisterRewritePattern(pattern, fused_pattern); + fuse.runOnGraph(graph); +} + void FuseInferenceOpsForSparseNN(std::shared_ptr& graph) { #ifdef FBCODE_CAFFE2 ConcatAddMulReplaceNaNClip(graph); CastedBatchOneHotLengths(graph); ConcatBatchMatMulBatchGather(graph); + ClipRangesGatherRangesLengthsToOffsets(graph); #endif } diff --git a/torch/csrc/jit/runtime/vararg_functions.cpp b/torch/csrc/jit/runtime/vararg_functions.cpp index e61676b83eca..44bc56206eaf 100644 --- a/torch/csrc/jit/runtime/vararg_functions.cpp +++ b/torch/csrc/jit/runtime/vararg_functions.cpp @@ -208,13 +208,20 @@ void listConstruct( Stack& stack, const at::ListTypePtr& type, size_t num_inputs) { - c10::List vals(type->getElementType()); - vals.reserve(num_inputs); - for (size_t i = stack.size() - num_inputs; i < stack.size(); ++i) { - vals.emplace_back(std::move(stack[i])); - } - drop(stack, num_inputs); - push(stack, std::move(vals)); + // Structuring the implementation this way allows NRVO to avoid + // move-constructing vals on its way onto the stack. Moving a List + // isn't free. 
+ auto makeList = + [](Stack& stack, const at::ListTypePtr& type, size_t num_inputs) { + c10::List vals(type->getElementType()); + vals.reserve(num_inputs); + for (size_t i = stack.size() - num_inputs; i < stack.size(); ++i) { + vals.emplace_back(std::move(stack[i])); + } + drop(stack, num_inputs); + return vals; + }; + stack.push_back(makeList(stack, type, num_inputs)); } void dictConstruct( diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py index a41c0e454f24..42847896c136 100644 --- a/torch/distributed/rpc/internal.py +++ b/torch/distributed/rpc/internal.py @@ -201,7 +201,7 @@ def _run_function(python_udf): def _handle_exception(result): if isinstance(result, RemoteException): - raise result.exception_type(result.msg) + raise result.exception_type(result.msg.encode("utf-8").decode("unicode_escape")) def _build_rpc_profiling_key( diff --git a/torch/fx/node.py b/torch/fx/node.py index d304a4c0a472..247896b5a920 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -59,7 +59,10 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', self.target = target # for method/module/function, the name of the method/module/function/attr # being invoked, e.g add, layer1, or torch.add - self._uses : Dict[Node, None] = {} + # All `Node`-valued inputs. Key is the Node, value is don't-care. + # The public API for this is `all_input_nodes`, this private attribute + # should not be accessed directly. + self._input_nodes : Dict[Node, None] = {} self._update_args_kwargs(map_arg(args, lambda x: x), map_arg(kwargs, lambda x: x)) # type: ignore # All of the nodes that use the value produced by this Node @@ -191,10 +194,7 @@ def all_input_nodes(self) -> List['Node']: List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this ``Node``, in that order. """ - all_nodes : List['Node'] = [] - map_arg(self.args, lambda n: all_nodes.append(n)) - map_arg(self.kwargs, lambda n: all_nodes.append(n)) - return all_nodes + return list(self._input_nodes.keys()) def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict[str, Argument]): """ @@ -203,14 +203,14 @@ def _update_args_kwargs(self, new_args : Tuple[Argument, ...], new_kwargs : Dict self._args = new_args self._kwargs = new_kwargs - for old_use in self._uses.keys(): + for old_use in self._input_nodes.keys(): old_use.users.pop(self) - self._uses = {} - map_arg(self._args, lambda n: self._uses.setdefault(n)) - map_arg(self._kwargs, lambda n: self._uses.setdefault(n)) + self._input_nodes = {} + map_arg(self._args, lambda n: self._input_nodes.setdefault(n)) + map_arg(self._kwargs, lambda n: self._input_nodes.setdefault(n)) - for new_use in self._uses.keys(): + for new_use in self._input_nodes.keys(): new_use.users.setdefault(self) def __repr__(self) -> str: diff --git a/torch/optim/optimizer.pyi b/torch/optim/optimizer.pyi index aa50a6fd1027..6202050f3493 100644 --- a/torch/optim/optimizer.pyi +++ b/torch/optim/optimizer.pyi @@ -10,7 +10,7 @@ class Optimizer: param_groups: List[dict] def __init__(self, params: _params_t, default: dict) -> None: ... - def __setstate__(self, statue: dict) -> None: ... + def __setstate__(self, state: dict) -> None: ... def state_dict(self) -> dict: ... def load_state_dict(self, state_dict: dict) -> None: ... def zero_grad(self, set_to_none: Optional[bool]=...) -> None: ... 
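The change to `_handle_exception` in torch/distributed/rpc/internal.py above re-raises the remote error after `msg.encode("utf-8").decode("unicode_escape")`. A minimal standalone sketch of why that round-trip matters follows; the variable names and the sample message are hypothetical, standing in for an error string whose newlines arrive over RPC as literal `\n` escape sequences:

```python
# Hypothetical escaped message as it might sit in RemoteException.msg:
# each "\n" below is two characters (backslash + n), not a real line break.
raw_msg = "\\nFirst line of error \\n next line of error \\n last line of error"

# The round-trip used in _handle_exception turns the escape sequences back
# into real newlines before the exception is re-raised locally.
restored = raw_msg.encode("utf-8").decode("unicode_escape")

assert "\\n" not in restored       # no literal backslash-n left
assert restored.count("\n") == 3   # real line breaks restored
```

This is the behavior the new `test_py_raise_in_user_func_escaped_str` test below asserts on the receiving side.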
diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 46dbacc3c2eb..a149c541a090 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -317,6 +317,10 @@ def my_script_func(tensor): def raise_func(): raise ValueError(expected_err) +expected_err_escape = "\nFirst line of error \n next line of error \n last line of error" +def raise_func_escape(): + raise ValueError(expected_err_escape) + global_rref = None @@ -1982,6 +1986,20 @@ def test_py_raise_in_user_func(self): stderr_lines = err.getvalue() self.assertTrue(expected_err in stderr_lines) + @dist_init + def test_py_raise_in_user_func_escaped_str(self): + n = self.rank + 1 + dst_rank = n % self.world_size + fut = rpc.rpc_async(worker_name(dst_rank), raise_func_escape) + try: + fut.wait() + except ValueError as e: + msg = str(e) + # Ensure newlines are unescaped to provide a better repr of error. + self.assertEqual(msg, msg.encode("utf-8").decode("unicode_escape")) + else: + self.assertTrue(False, "expected raise_func_escape to raise ValueError.") + @dist_init def test_nested_rpc(self): n = self.rank + 1
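To tie the quantized pieces of this PR together, here is a hedged Python-level sketch of what the new `QuantizedCPU: quantized_index` kernel effectively computes (its own comment describes the implementation as a naive dq -> index -> q). The tensor shape, scale, and zero point are arbitrary, and the variable names are not from the PR; it simply mirrors the reference path used in `test_advanced_indexing`:

```python
import torch

scale, zp = 0.1, 0
x_q = torch.quantize_per_tensor(torch.randn(1, 4, 4, 4), scale, zp, torch.quint8)

# With this PR, advanced indexing dispatches to quantized_index and works
# directly on a per-tensor quantized tensor.
y_q = x_q[:, [2, 0, 1], :, :]

# Reference path, mirroring what the kernel does internally:
# dequantize -> index in fp32 -> requantize with the same qparams.
y_ref = torch.quantize_per_tensor(
    x_q.dequantize()[:, [2, 0, 1], :, :], scale, zp, torch.quint8)

assert torch.equal(y_q.int_repr(), y_ref.int_repr())
```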