diff --git a/aten/src/ATen/record_function.cpp b/aten/src/ATen/record_function.cpp index 1b240280f5148..6185ab0f96bed 100644 --- a/aten/src/ATen/record_function.cpp +++ b/aten/src/ATen/record_function.cpp @@ -92,11 +92,14 @@ class CallbackManager { bool found_needs_ids = false; auto init_handles = [ scope, &found_active_cb, &found_needs_inputs, &found_needs_ids]( - CallbackHandles& handles, RecordFunctionCallbacks& cbs) { + CallbackHandles& handles, RecordFunctionCallbacks& cbs, ObserverContextList& ctx_list) { handles.clear(); + + size_t num_callbacks = 0; for (const auto& cb : cbs) { if (cb.first.shouldRun(scope)) { handles.push_back(cb.second); + ++num_callbacks; found_active_cb = true; if (cb.first.needsInputs()) { found_needs_inputs = true; @@ -106,10 +109,12 @@ class CallbackManager { } } } + // Pre-allocate observer context list with nullptr. + ctx_list.resize(num_callbacks); }; - init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_); - init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_); + init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_, rec_fn.tls_ctx_); + init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_, rec_fn.global_ctx_); rec_fn.active = found_active_cb; rec_fn.needs_inputs = found_needs_inputs; if (found_needs_ids && found_active_cb) { @@ -121,11 +126,13 @@ class CallbackManager { mergeRunCallbacks( sorted_global_callbacks_, rf.sorted_active_global_handles_, + rf.global_ctx_, /* is_start */ true, rf); mergeRunCallbacks( sorted_tls_callbacks_, rf.sorted_active_tls_handles_, + rf.tls_ctx_, /* is_start */ true, rf); rf.called_start_callbacks_ = true; @@ -135,21 +142,30 @@ class CallbackManager { mergeRunCallbacks( sorted_global_callbacks_, rf.sorted_active_global_handles_, + rf.global_ctx_, /* is_start */ false, rf); mergeRunCallbacks( sorted_tls_callbacks_, rf.sorted_active_tls_handles_, + rf.tls_ctx_, /* is_start */ false, rf); } private: bool tryRunCallback( - 
const std::function& fn, - RecordFunction& rf) { + const RecordFunctionCallback& rfcb, + RecordFunction& rf, + std::unique_ptr& ctx, + bool is_start) { try { - fn(rf); + if (is_start) { + ctx = rfcb.start()(rf); + } + else { + rfcb.end()(rf, ctx.get()); + } return true; } catch (const std::exception &e) { LOG(WARNING) << "Exception in RecordFunction callback: " @@ -165,11 +181,12 @@ class CallbackManager { void mergeRunCallbacks( const RecordFunctionCallbacks& sorted_callbacks, const CallbackHandles& sorted_handles, + ObserverContextList& ctx_list, bool is_start, RecordFunction& rf) { size_t num_executed = 0; size_t idx_c = 0; - for (size_t idx_h = 0; idx_h < sorted_handles.size(); ++idx_h) { + for (size_t idx_h = 0; idx_h < sorted_handles.size() && idx_h < ctx_list.size(); ++idx_h) { while (idx_c < sorted_callbacks.size() && sorted_callbacks[idx_c].second < sorted_handles[idx_h]) { ++idx_c; @@ -178,11 +195,7 @@ class CallbackManager { break; } if (sorted_callbacks[idx_c].second == sorted_handles[idx_h]) { - if (is_start) { - tryRunCallback(sorted_callbacks[idx_c].first.start(), rf); - } else { - tryRunCallback(sorted_callbacks[idx_c].first.end(), rf); - } + tryRunCallback(sorted_callbacks[idx_c].first, rf, ctx_list[idx_h], is_start); ++num_executed; } } diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 13b14302a18f6..3ff64144b03de 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -67,7 +67,16 @@ struct TORCH_API StringView { // Soft limit on the number of callbacks to use; constexpr std::size_t kSoftLimitCallbacks = 4; +// An abstract base class for various observer contexts that can be attached to +// the RecordFunction. 
+struct ObserverContext { + virtual ~ObserverContext() {} + protected: + ObserverContext() {} +}; + typedef c10::SmallVector CallbackHandles; +typedef std::vector> ObserverContextList; typedef uint64_t RecordFunctionHandle; struct TORCH_API RecordFunction { @@ -164,6 +173,15 @@ struct TORCH_API RecordFunction { // public because of anonymous "friend" class CallbackHandles sorted_active_tls_handles_; CallbackHandles sorted_active_global_handles_; + + // Stores various ObserverContext objects with event metadata for thread local + // callbacks. + ObserverContextList tls_ctx_; + + // Stores various ObserverContext objects with event metadata for global + // callbacks. + ObserverContextList global_ctx_; + // Whether this RecordFunction runs any callbacks bool active = false; /// Whether any of the picked callbacks require inputs @@ -198,6 +216,8 @@ struct TORCH_API RecordFunction { * RecordFunctionCallback represents a pair of callbacks to be used with * RecordFunction, members: * start, end - the callbacks to run when entering and exiting the scope; + * optionally, the start callback may return an ObserverContext which will + * be passed to the end callback, use appropriate constructor accordingly. * needs_inputs - whether the callbacks need the inputs passed from the observed * function/range; NOTE: passing the inputs incurs an additional overhead; * sampling_probability - if not 1.0, then the callback is probabilistically sampled @@ -211,12 +231,25 @@ struct TORCH_API RecordFunction { */ class TORCH_API RecordFunctionCallback { public: + // This interface supports observers that require passing an ObserverContext + // between start and end callbacks. 
+ explicit RecordFunctionCallback( + std::function(const RecordFunction&)> start, + std::function end = + [](const RecordFunction&, ObserverContext*) {}): + start_(std::move(start)), + end_(std::move(end)) { + scopes_.fill(true); + } + + // This interface is for observers that do not pass an ObserverContext object + // between start and end callbacks. explicit RecordFunctionCallback( std::function start, std::function end = [](const RecordFunction&) {}): - start_(std::move(start)), - end_(std::move(end)) { + start_{[start](const RecordFunction& rf) { start(rf); return nullptr; }}, + end_{[end](const RecordFunction& rf, ObserverContext*) { end(rf); }} { scopes_.fill(true); } @@ -272,11 +305,11 @@ class TORCH_API RecordFunctionCallback { return scopes_[(size_t)sc]; } - inline const std::function& start() const { + inline const std::function(const RecordFunction&)>& start() const { return start_; } - inline const std::function& end() const { + inline const std::function& end() const { return end_; } @@ -284,8 +317,8 @@ class TORCH_API RecordFunctionCallback { bool shouldRun(RecordScope scope) const; private: - std::function start_; - std::function end_; + std::function(const RecordFunction&)> start_; + std::function end_; std::function should_run_; bool needs_inputs_ = false; bool needs_ids_ = false; diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index 2e639bd94bcae..597d2b235849b 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -294,6 +294,9 @@ __host__ __device__ #endif // ANDROID / IOS // Portably determine if a type T is trivially copyable or not. +// Warning: __has_trivial_copy for GCC may not always detect non-POD types +// correctly. For example, T = std::unique_ptr may evaluate to true and be +// treated as POD. This can cause unexpected behavior.
#if defined(__GNUG__) && __GNUC__ < 5 #define C10_IS_TRIVIALLY_COPYABLE(T) __has_trivial_copy(T) #else diff --git a/c10/util/SmallVector.h b/c10/util/SmallVector.h index 9566e7ac4eb8f..076a1d4010651 100644 --- a/c10/util/SmallVector.h +++ b/c10/util/SmallVector.h @@ -378,6 +378,9 @@ class SmallVectorTemplateBase : public SmallVectorTemplateCommon { /// This class consists of common code factored out of the SmallVector class to /// reduce code duplication based on the SmallVector 'N' template parameter. +/// Warning: C10_IS_TRIVIALLY_COPYABLE may not always detect non-POD +/// types correctly. For example, std::unique_ptr may be treated as POD and cause +/// memory leaks. template class SmallVectorImpl : public SmallVectorTemplateBase {