Revert D25111515: Extra sampling of record function events
Test Plan: revert-hammer

Differential Revision: D25111515 (09b974c)

Original commit changeset: 0d572a3636fe

fbshipit-source-id: d558d8052924d937d86db7dd40dc6388e6d28823
Mike Ruberry authored and facebook-github-bot committed Dec 9, 2020
1 parent 73f7178 commit 9f7fb54
Showing 8 changed files with 127 additions and 265 deletions.
1 change: 0 additions & 1 deletion aten/src/ATen/ThreadLocalState.cpp
@@ -19,7 +19,6 @@ ThreadLocalState::ThreadLocalState(bool keep_grad_mode)
grad_mode_enabled_ = GradMode::is_enabled();
}
#endif
bumped_record_all_functions_ = at::checkRecordAllFunctions();
}

/* static */
24 changes: 1 addition & 23 deletions aten/src/ATen/ThreadLocalState.h
@@ -38,47 +38,25 @@ class TORCH_API ThreadLocalState {
bool grad_mode_enabled_;
#endif

// Whether pre-sampling RecordFunction optimization was enabled
bool bumped_record_all_functions_ = false;

friend class ThreadLocalStateGuard;
};

// Guard to set and reset the thread local state
class TORCH_API ThreadLocalStateGuard {
public:
explicit ThreadLocalStateGuard(const ThreadLocalState& state)
: prev_state_(ThreadLocalState()),
bumped_record_all_functions_(state.bumped_record_all_functions_) {
// Special handling of RecordFunction pre-sampling optimization:
// pre-sampling is enabled (bumped) when there're non-sampled
// (or high-frequency) global or TLS callbacks.
//
// ThreadLocalStateGuard simply resets RecordFunction's TLS and
// hence its thread local callbacks.
//
// Checking if the pre-sampling was enabled and preserving it in the
// async task by calling bumpRecordAllFunctions() and the corresponding
// releaseRecordAllFunctions()
if (bumped_record_all_functions_) {
at::bumpRecordAllFunctions();
}
: prev_state_(ThreadLocalState()) {
// set the given state across the thread boundary
ThreadLocalState::setThreadLocalState(state);
}

~ThreadLocalStateGuard() {
// restore previously set variables
ThreadLocalState::setThreadLocalState(prev_state_);
if (bumped_record_all_functions_) {
at::releaseRecordAllFunctions();
}
}

private:
const ThreadLocalState prev_state_;
// Whether pre-sampling RecordFunction optimization was enabled
bool bumped_record_all_functions_ = false;
};
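The constructor comment above explains that, before this revert, the guard also re-bumped the pre-sampling counter so the optimization survived into async tasks. Independent of that detail, the guard's job is to carry thread-local state across a thread boundary. A minimal usage sketch, assuming a plain std::thread worker; run_on_worker_thread and body are illustrative names, not part of this codebase:

#include <functional>
#include <thread>
#include <ATen/ThreadLocalState.h>

void run_on_worker_thread(std::function<void()> body) {
  // Snapshot the caller's thread-local state (grad mode, RecordFunction TLS, ...).
  at::ThreadLocalState state;
  std::thread worker([state, body]() {
    // Re-establish the snapshot on the worker thread for the duration of the
    // task; the guard's destructor restores the worker's previous state.
    at::ThreadLocalStateGuard guard(state);
    body();
  });
  worker.join();
}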

template <typename T>
81 changes: 32 additions & 49 deletions aten/src/ATen/core/dispatch/Dispatcher.h
@@ -371,39 +371,28 @@ inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandle<Return(A
const KernelFunction& kernel = op.operatorIterator_->op.lookup(dispatchKey);

#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
// By default, when there're no high-frequency or non-sampled callbacks,
// RecordFunction is pre-sampled as a perf optimization;
// shouldRunRecordFunction checks whether RecordFunction should be executed,
// and sets pre_sampled boolean argument value to whether pre-sampling was used -
// this boolean is passed into RecordFunction to adjust the sampling rates of
// the callbacks
bool pre_sampled = false;
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
// Check if we need to run callbacks registered with RecordFunction
// If true and callbacks need inputs, we box the arguments and pass
// them into the callbacks and also into the kernel call

// Note: for perf reasons we wouldn't want to pass arguments into
// the function call or prematurely box them
at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
if (C10_UNLIKELY(guard.isActive())) {
if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) {
int64_t seq_num = -1;
// Setting sequence number in the Autograd case to associate
// the forward range with the corresponding Autograd's node
if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
seq_num = at::sequence_number::peek();
}
if (guard.needsInputs()) {
torch::jit::Stack stack = impl::boxArgs(args...);
guard.before(op, stack, seq_num);
} else {
guard.before(op, seq_num);
}
// Check if we need to run callbacks registered with RecordFunction
// If true and callbacks need inputs, we box the arguments and pass
// them into the callbacks and also into the kernel call

// Note: for perf reasons we wouldn't want to pass arguments into
// the function call or prematurely box them
at::RecordFunction guard(at::RecordScope::FUNCTION);
if (C10_UNLIKELY(guard.isActive())) {
if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) {
int64_t seq_num = -1;
// Setting sequence number in the Autograd case to associate
// the forward range with the corresponding Autograd's node
if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
seq_num = at::sequence_number::peek();
}
if (guard.needsInputs()) {
torch::jit::Stack stack = impl::boxArgs(args...);
guard.before(op, stack, seq_num);
} else {
guard.before(op, seq_num);
}
}
// keeping the guard alive while executing the kernel
return kernel.template call<Return, Args...>(op, std::forward<Args>(args)...);
}
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
return kernel.template call<Return, Args...>(op, std::forward<Args>(args)...);
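The removed comment above notes that the pre_sampled flag is passed into RecordFunction "to adjust the sampling rates of the callbacks". The adjustment, visible further down in this diff where callbackShouldRun divides sampling_prob by kLowProb, keeps the rate a callback actually observes unchanged. A worked example with kLowProb = 0.001 (the constant defined in record_function.cpp) and an illustrative callback probability of 0.0001:

  P(callback fires) = P(call site pre-sampled in) * P(callback passes the rescaled flip)
                    = kLowProb * (sampling_prob / kLowProb)
                    = 0.001    * (0.0001 / 0.001)
                    = 0.001    * 0.1
                    = 0.0001   (the callback's configured rate, unchanged)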
@@ -440,26 +429,20 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
const auto& kernel = entry.lookup(dispatchKey);

#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
bool pre_sampled = false;
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
// using already existing stack to record function execution in observers
at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
if (C10_UNLIKELY(guard.isActive())) {
if (shouldRecord(dispatchKey) && entry.isObserved()) {
int64_t seq_num = -1;
if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
seq_num = at::sequence_number::peek();
}
if (guard.needsInputs()) {
guard.before(op, *stack, seq_num);
} else {
guard.before(op, seq_num);
}
// using already existing stack to record function execution in observers
at::RecordFunction guard(at::RecordScope::FUNCTION);
if (C10_UNLIKELY(guard.isActive())) {
if (shouldRecord(dispatchKey) && entry.isObserved()) {
int64_t seq_num = -1;
if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
seq_num = at::sequence_number::peek();
}
if (guard.needsInputs()) {
guard.before(op, *stack, seq_num);
} else {
guard.before(op, seq_num);
}
}
// keeping the guard alive while executing the kernel
kernel.callBoxed(op, stack);
return;
}
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
kernel.callBoxed(op, stack);
114 changes: 24 additions & 90 deletions aten/src/ATen/record_function.cpp
@@ -30,6 +30,8 @@ std::atomic<int64_t> defaultNodeId(-1);
std::atomic<uint64_t> next_thread_id_ {0};
thread_local uint64_t current_thread_id_ = 0;

thread_local bool tls_record_function_enabled_ = true;

// Low probability constant
static const double kLowProb = 0.001;
struct CoinflipTLS {
@@ -66,10 +68,6 @@ void set_record_function_tls_(const RecordFunctionTLS& tls) {
class CallbackManager {
public:
CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb) {
if (cb.samplingProb() > kLowProb) {
// pre-sampling of RecordFunction with prob. kLowProb cannot be used
at::bumpRecordAllFunctions();
}
// note: monotonically increasing callbacks_unique_id keeps
// sorted_tls_callbacks_ sorted
auto handle = next_unique_callback_handle();
@@ -78,10 +76,6 @@ class CallbackManager {
}

CallbackHandle addGlobalCallback(RecordFunctionCallback cb) {
if (cb.samplingProb() > kLowProb) {
// pre-sampling of RecordFunction with prob. kLowProb cannot be used
at::bumpRecordAllFunctions();
}
auto handle = next_unique_callback_handle();
sorted_global_callbacks_.emplace_back(std::move(cb), handle);
return handle;
@@ -98,10 +92,6 @@ class CallbackManager {
return el.second == handle;
});
if (it != cbs.end()) {
if (it->first.samplingProb() > kLowProb) {
// try to restore pre-sampling of RecordFunction
at::releaseRecordAllFunctions();
}
// keeps it sorted
cbs.erase(it);
return true;
@@ -137,13 +127,7 @@ class CallbackManager {
// callbackShouldRun is even hotter because it's called multiple
// times per init(). Profiling shows that the function prologue is
// taking up a significant fraction of the time.
static bool C10_ALWAYS_INLINE callbackShouldRun(
const RecordFunctionCallback& cb, RecordScope scope, bool pre_sampled) {
TORCH_INTERNAL_ASSERT(
!pre_sampled || (cb.sampling_prob_ <= kLowProb),
"Incorrect usage of a pre-sampled RecordFunction with a high-frequency "
" or non-sampled callback");

static bool C10_ALWAYS_INLINE callbackShouldRun(const RecordFunctionCallback& cb, RecordScope scope) {
// first check whether this callback is interested in
// the given scope type
if (!cb.checkScope(scope)) {
@@ -154,45 +138,36 @@ class CallbackManager {
return cb.should_run_(cb);
}

// otherwise potentially do the sampling
double sampling_prob = cb.sampling_prob_;
if (pre_sampled) {
// adjust the sampling rate to account for kLowProb pre-sampling of
// the RecordFunction
sampling_prob /= kLowProb;
if (cb.sampling_prob_ == 1.0) {
return true;
}

if (sampling_prob < 1.0) {
// model the low probability events as events happening
// with probability kLowProb followed by another sampling with
// probability (sampling_prob / kLowProb), then replace the coin
// flip for kLowProb with a thread local number of tries tries_left_
// sampled from the geometric distribution.
if (sampling_prob < kLowProb) {
if (coinflip_tls_.tries_left_ == 0) {
coinflip_tls_.tries_left_ = sample_geometric();
return (sample_zero_one() < sampling_prob / kLowProb);
} else {
--coinflip_tls_.tries_left_;
return false;
}
// model the low probability events as events happening
// with probability kLowProb followed by another sampling with
// probability (sampling_prob_ / kLowProb), then replace the coin
// flip for kLowProb with a thread local number of tries tries_left_
// sampled from the geometric distribution.
if (cb.sampling_prob_ < kLowProb) {
if (coinflip_tls_.tries_left_ == 0) {
coinflip_tls_.tries_left_ = sample_geometric();
return (sample_zero_one() < cb.sampling_prob_ / kLowProb);
} else {
return (sample_zero_one() < sampling_prob);
--coinflip_tls_.tries_left_;
return false;
}
} else {
return (sample_zero_one() < cb.sampling_prob_);
}

return true;
}
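The comment above describes modelling a low-probability event as a Bernoulli(kLowProb) trial followed by a rescaled flip, with the kLowProb trial implemented as a thread-local countdown drawn from a geometric distribution. A small self-contained sketch (illustrative only, not the library's sampler) showing that the countdown reproduces the ~kLowProb hit rate while doing no random-number work on the common path:

#include <cstdio>
#include <random>

int main() {
  constexpr double kLowProb = 0.001;
  constexpr long kCalls = 10000000;

  std::mt19937 gen(42);
  // Number of failures before the next success of a Bernoulli(kLowProb) process.
  std::geometric_distribution<int> next_hit(kLowProb);

  long hits = 0;
  int tries_left = 0;
  for (long i = 0; i < kCalls; ++i) {
    if (tries_left == 0) {
      tries_left = next_hit(gen);  // draw once per hit instead of flipping every call
      ++hits;                      // the real code would now do the second, rescaled flip
    } else {
      --tries_left;                // cheap common path: a single decrement, no RNG
    }
  }
  std::printf("observed rate %.6f (expected ~%.3f)\n",
              static_cast<double>(hits) / kCalls, kLowProb);
  return 0;
}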

// init is called by RecordFunction in constructor to
// determine which thread local and global callbacks are going
// to be executed and whether any of them need inputs
inline void init(RecordFunction& rec_fn, RecordScope scope, bool pre_sampled) {
inline void init(RecordFunction& rec_fn, RecordScope scope) {
bool found_needs_inputs = false;
bool found_needs_ids = false;

for (const auto& cb: rf_tls_.sorted_tls_callbacks_) {
if (callbackShouldRun(cb.first, scope, pre_sampled)) {
if (callbackShouldRun(cb.first, scope)) {
if (cb.first.needsInputs()) {
found_needs_inputs = true;
}
@@ -207,7 +182,7 @@ class CallbackManager {
}

for (const auto& cb: sorted_global_callbacks_) {
if (callbackShouldRun(cb.first, scope, pre_sampled)) {
if (callbackShouldRun(cb.first, scope)) {
if (cb.first.needsInputs()) {
found_needs_inputs = true;
}
@@ -333,6 +308,7 @@ namespace {
}
} // namespace


RecordFunctionCallbacks _getTLSCallbacks() {
return rf_tls_.sorted_tls_callbacks_;
}
@@ -398,12 +374,12 @@ void enableRecordFunction(bool enable) {
rf_tls_.tls_record_function_enabled_ = enable;
}

RecordFunction::RecordFunction(RecordScope scope, bool pre_sampled) {
RecordFunction::RecordFunction(RecordScope scope) {
auto* rf_tls_ptr = &rf_tls_;
if (rf_tls_ptr->tls_record_function_enabled_) {
auto& m = manager();
if (!m.sorted_global_callbacks_.empty() || !rf_tls_ptr->sorted_tls_callbacks_.empty()) {
m.init(*this, scope, pre_sampled);
m.init(*this, scope);
}
}
}
@@ -475,46 +451,4 @@ void RecordFunction::end() {
}
}

// RecordFunction pre-sampling
namespace {
// Whether to try to create RecordFunction on each call (>0) or
// use pre-sampling (=0)
std::atomic<int> global_record_all_functions_ {0};
}

void bumpRecordAllFunctions() {
global_record_all_functions_.fetch_add(1, std::memory_order_relaxed);
}

void releaseRecordAllFunctions() {
TORCH_CHECK(global_record_all_functions_.fetch_sub(1, std::memory_order_relaxed) >= 0);
}

bool checkRecordAllFunctions() {
return (global_record_all_functions_.load(std::memory_order_relaxed) > 0);
}

bool shouldRunRecordFunction(bool* pre_sampled) {
auto* rf_tls_ptr = &rf_tls_;
if (!rf_tls_ptr->tls_record_function_enabled_) {
*pre_sampled = false;
return false;
}

if (global_record_all_functions_.load(std::memory_order_relaxed) > 0) {
*pre_sampled = false;
return true;
}

*pre_sampled = true;
auto* coinflip_tls_ptr = &coinflip_tls_;
if (coinflip_tls_ptr->tries_left_ == 0) {
coinflip_tls_ptr->tries_left_ = sample_geometric();
return true;
} else {
--coinflip_tls_ptr->tries_left_;
return false;
}
}

} // namespace at
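The helpers deleted at the end of this file (bumpRecordAllFunctions, releaseRecordAllFunctions, checkRecordAllFunctions, shouldRunRecordFunction) formed a reference-counted override: while any non-sampled or high-frequency callback was registered, pre-sampling was turned off and every call recorded. A minimal sketch of how the pairing was used before this revert; the RAII wrapper here is hypothetical and stands in for the callback-registration and ThreadLocalStateGuard call sites shown earlier in the diff:

#include <ATen/record_function.h>

// Hypothetical RAII wrapper; mirrors the bump/release pairing the removed
// registration paths and ThreadLocalStateGuard performed.
struct RecordAllFunctionsGuard {
  RecordAllFunctionsGuard()  { at::bumpRecordAllFunctions(); }
  ~RecordAllFunctionsGuard() { at::releaseRecordAllFunctions(); }
};

void run_with_full_recording() {
  RecordAllFunctionsGuard guard;   // pre-sampling disabled while this is alive
  bool pre_sampled = false;
  // With the counter bumped, the removed shouldRunRecordFunction returned true
  // with pre_sampled == false, so a RecordFunction was created on every call
  // rather than on roughly one call in 1/kLowProb.
  if (at::shouldRunRecordFunction(&pre_sampled)) {
    at::RecordFunction rf(at::RecordScope::FUNCTION, pre_sampled);
    // ... run the op under the (now unconditional) RecordFunction ...
  }
}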
