-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
Dispatcher.h
468 lines (401 loc) · 19 KB
/
Dispatcher.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
#pragma once
#include <ATen/SequenceNumber.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/boxing/impl/boxing.h>
#include <ATen/core/dispatch/OperatorEntry.h>
#include <ATen/core/dispatch/CppSignature.h>
#include <ATen/core/dispatch/RegistrationHandleRAII.h>
#include <ATen/record_function.h>
#include <c10/util/Exception.h>
#include <c10/util/LeftRight.h>
#include <mutex>
#include <list>
#include <ATen/core/grad_mode.h>
namespace c10 {
class CAFFE2_API OperatorHandle;
template<class FuncType> class TypedOperatorHandle;
/**
* Implement this interface and register your instance with the dispatcher
* to get notified when operators are registered or deregistered with
* the dispatcher.
*
* NB: registration events only occur when a 'def' occurs; we don't trigger
* on 'impl' or 'fallback' calls.
*/
class CAFFE2_API OpRegistrationListener {
public:
// Virtual so that a listener can be destroyed through a pointer to this
// interface. Only declared here; no inline definition is provided.
virtual ~OpRegistrationListener();
// Notification that an operator was registered; `op` is the handle of that operator.
virtual void onOperatorRegistered(const OperatorHandle& op) = 0;
// Notification that an operator was deregistered; `op` is the handle of that operator.
virtual void onOperatorDeregistered(const OperatorHandle& op) = 0;
};
namespace detail {
class RegistrationListenerList;
}
class SchemaRegistrationHandleRAII;
/**
* Top-level dispatch interface for dispatching via the dynamic dispatcher.
* Most end users shouldn't use this directly; if you're trying to register
* ops look in op_registration
*/
class CAFFE2_API Dispatcher final {
private:
// For direct access to backend fallback information
friend class impl::OperatorEntry;
// One element of the operator list: the operator entry itself plus the
// bookkeeping counters for its outstanding registrations.
struct OperatorDef final {
explicit OperatorDef(OperatorName&& op_name)
: op(std::move(op_name)) {}
impl::OperatorEntry op;
// These refer to the number of outstanding RegistrationHandleRAII
// for this operator. def_count reflects only def() registrations
// (in the new world, this should only ever be 1, but old style
// registrations may register the schema multiple times, which
// will increase this count). def_and_impl_count reflects the number
// of combined def() and impl() registrations. When the last def() gets
// unregistered, we must immediately call the Deregistered listeners, but we
// must not actually delete the handle as there are other outstanding RAII
// destructors which will try to destruct and they had better still have a
// working operator handle in this case
size_t def_count = 0;
size_t def_and_impl_count = 0;
};
// The handle classes store an iterator to the private OperatorDef type and
// are read through op.operatorIterator_ by the call paths of this class.
friend class OperatorHandle;
template<class> friend class TypedOperatorHandle;
public:
~Dispatcher();
// Implementation note: this class abstracts over the fact that we have per-operator
// dispatch tables. This could be easily adjusted to have a single global hash
// table.
// Accessor for the shared Dispatcher instance. The constructor is private, so
// code outside this class obtains a Dispatcher through this function.
static Dispatcher& singleton();
// ------------------------------------------------------------------------
//
// Accessing operators by schema
//
// ------------------------------------------------------------------------
/**
* Looks for an operator schema with the given name and overload name
* and returns it if it is registered WITH A SCHEMA.
* Returns nullopt otherwise.
*/
c10::optional<OperatorHandle> findSchema(const OperatorName& operator_name);
/**
* Variant of findSchema that results in less code generated at the call site.
* It (1) takes const char* pointer rather than OperatorName (so we skip
* generating std::string constructor calls at the call site), and (2)
* it raises an exception if the operator is not found (so we skip
* generating exception raising code at the call site)
*
* Irritatingly, we still have to generate the handful of instructions
* for dealing with an exception being thrown during static initialization
* (e.g. __cxa_guard_abort). If we could annotate this method noexcept we
* could avoid this code too, but as the name of the function suggests,
* it does throw exceptions.
*/
OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name);
// Like findSchema, but also returns OperatorHandle even if there is no schema
c10::optional<OperatorHandle> findOp(const OperatorName& operator_name);
// ------------------------------------------------------------------------
//
// Invoking operators
//
// ------------------------------------------------------------------------
// Invoke an operator via the unboxed calling convention: the dispatch key is
// computed from the arguments and the call is then forwarded to
// callWithDispatchKey.
template<class Return, class... Args>
Return call(const TypedOperatorHandle<Return (Args...)>& op, Args... args) const;
// Like call, but override the default DispatchKey calculation code,
// instead dispatching straight to the provided DispatchKey
template<class Return, class... Args>
Return callWithDispatchKey(const TypedOperatorHandle<Return (Args...)>& op, DispatchKey dispatchKey, Args... args) const;
// Like call, but intended for use in a redispatch: you are currently
// in some currentDispatchKey, you have finished processing the key and
// you now want to redispatch to the next dispatch key in the chain.
// This will mask out the current key *and all previous keys* from the
// eligible set, and reinvoke the dispatcher.
// Unlike call and callWithDispatchKey, redispatch does not use RecordFunction.
template<class Return, class... Args>
Return redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKey currentDispatchKey, Args... args) const;
// Invoke an operator via the boxed calling convention using an IValue stack
void callBoxed(const OperatorHandle& op, Stack* stack) const;
// ------------------------------------------------------------------------
//
// Performing registrations (NON user public; use op_registration)
//
// ------------------------------------------------------------------------
/**
* Register a new operator schema.
*
* If a schema with the same operator name and overload name already exists,
* this function will check that both schemas are exactly identical.
*/
RegistrationHandleRAII registerDef(FunctionSchema schema, std::string debug);
/**
* Register a kernel to the dispatch table for an operator.
* If dispatch_key is nullopt, then this registers a fallback kernel.
*
* @return A RAII object that manages the lifetime of the registration.
* Once that object is destructed, the kernel will be deregistered.
*/
// NB: steals the inferred function schema, as we may need to hold on to
// it for a bit until the real schema turns up
RegistrationHandleRAII registerImpl(OperatorName op_name, c10::optional<DispatchKey> dispatch_key, KernelFunction kernel, c10::optional<impl::CppSignature> cpp_signature, std::unique_ptr<FunctionSchema> inferred_function_schema, std::string debug);
/**
* Register a new operator by name.
*/
RegistrationHandleRAII registerName(OperatorName op_name);
/**
* Register a fallback kernel for a backend.
* If an operator is called but there is no concrete kernel for the dispatch
* key of the given operator arguments, it will check if there is such a
* fallback kernel for the given dispatch key and, if yes, call that one.
*/
RegistrationHandleRAII registerFallback(DispatchKey dispatch_key, KernelFunction kernel, std::string debug);
/**
* Use to register whenever we had a TORCH_LIBRARY declaration in the frontend
* API. These invocations are only permitted once per program, so we raise
* an error if this is called again for the same namespace.
*/
RegistrationHandleRAII registerLibrary(std::string ns, std::string debug);
// This function is a temporary hack that allows generated_unboxing_wrappers.cpp to register its codegen'ed
// unboxing wrapper for aten operators. We still need those for some operators because not all work
// with the templated unboxing logic yet.
// TODO Delete setManuallyBoxedKernelFor_ once all operators work with the templated boxing logic
void setManuallyBoxedKernelFor_(const OperatorHandle& op, KernelFunction::InternalBoxedKernelFunction* func);
// ------------------------------------------------------------------------
//
// Listeners on registrations
//
// ------------------------------------------------------------------------
/**
* Add a listener that gets called whenever a new op is registered or an existing
* op is deregistered. Immediately after registering, this listener gets called
* for all previously registered ops, so it can be used to keep track of ops
* registered with this dispatcher.
*/
RegistrationHandleRAII addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener);
// Consistency check over the dispatcher state. The implementation is not part
// of this header; OperatorHandle::checkInvariants is the per-operator variant.
void checkInvariants() const;
/* Check if operator calls with a given dispatch key
* need to be observed with RecordFunction.
* Every key except DispatchKey::BackendSelect is observed.
*/
inline bool shouldRecord(DispatchKey dispatch_key) const {
return dispatch_key != DispatchKey::BackendSelect;
}
// ------------------------------------------------------------------------
//
// Assertions
//
// ------------------------------------------------------------------------
/**
* For testing purposes.
* Returns a list of all operators that were created through calls to registerImpl(),
* without any corresponding calls to registerDef(). After static initialization
* is done this is almost certainly a bug, as the created OperatorHandle won't have
* any schema associated with it and users calling the op through the dispatcher
* won't be able to access it
*
* Note that we cannot enforce this invariant "as we go" during static initialization,
* due to undefined static initialization order- we have no guarantees over the order
* in which .def() and .impl() calls are registered in the dispatcher at static
* initialization time. So this function should only be called after static initialization.
*/
std::vector<OperatorHandle> findDanglingImpls() const;
private:
// Private: instances are handed out by singleton().
Dispatcher();
OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema);
OperatorHandle findOrRegisterName_(const OperatorName& op_name);
// Counterparts of the register* methods above.
// NOTE(review): presumably invoked when the RegistrationHandleRAII returned by
// the matching register* call is destroyed - confirm in the implementation file.
void deregisterDef_(const OperatorHandle& op, const OperatorName& op_name);
void deregisterImpl_(
const OperatorHandle& op,
const OperatorName& op_name,
c10::optional<DispatchKey> dispatch_key,
std::list<impl::AnnotatedKernel>::iterator kernel_handle);
void deregisterName_(const OperatorHandle& op, const OperatorName& op_name);
void deregisterFallback_(DispatchKey dispatchKey);
void deregisterLibrary_(const std::string& ns);
void cleanup(const OperatorHandle& op, const OperatorName& op_name);
void checkSchemaCompatibility(const OperatorHandle& op, const FunctionSchema& schema, const std::string& debug);
// Storage for all operators. OperatorHandle keeps an iterator into this list;
// callBoxed relies on list writes keeping existing iterators intact.
std::list<OperatorDef> operators_;
// Lookup from operator name to its handle, wrapped in LeftRight.
LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
// Map from namespace to debug string (saying, e.g., where the library was defined)
ska::flat_hash_map<std::string, std::string> libraries_;
// One fallback kernel slot per dispatch key (array size is NumDispatchKeys).
std::array<impl::AnnotatedKernel, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> backendFallbackKernels_;
std::unique_ptr<detail::RegistrationListenerList> listeners_;
// NOTE(review): presumably guards registration/deregistration; the call paths
// defined in this header (call, redispatch, callBoxed) do not lock it.
std::mutex mutex_;
};
/**
* This is a handle to an operator schema registered with the dispatcher.
* This handle can be used to register kernels with the dispatcher or
* to lookup a kernel for a certain set of arguments.
*/
class CAFFE2_API OperatorHandle {
public:
OperatorHandle(OperatorHandle&&) noexcept = default;
OperatorHandle& operator=(OperatorHandle&&) noexcept = default;
OperatorHandle(const OperatorHandle&) = default;
OperatorHandle& operator=(const OperatorHandle&) = default;
const OperatorName& operator_name() const {
return operatorIterator_->op.operator_name();
}
bool hasSchema() const {
return operatorIterator_->op.hasSchema();
}
const FunctionSchema& schema() const {
return operatorIterator_->op.schema();
}
const std::string& debug() const {
return operatorIterator_->op.debug();
}
std::string dumpState() const {
return operatorIterator_->op.dumpState();
}
std::string dumpComputedTable() const {
return operatorIterator_->op.dumpComputedTable();
}
void checkInvariants() const {
return operatorIterator_->op.checkInvariants();
}
template<class FuncType>
TypedOperatorHandle<FuncType> typed() const {
// NB: This assert is not 100% sound: you can retrieve a typed() operator
// handle prior to ANY C++ signature being registered on the operator
// and the check will say everything is OK (at which point you can then
// smuggle in a kernel that is typed incorrectly). For everything
// in core library this won't happen, because all the static registrations
// will be done by the time a typed() handle is acquired.
operatorIterator_->op.assertSignatureIsCorrect<FuncType>();
return TypedOperatorHandle<FuncType>(operatorIterator_);
}
void callBoxed(Stack* stack) const {
c10::Dispatcher::singleton().callBoxed(*this, stack);
}
private:
explicit OperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
: operatorIterator_(std::move(operatorIterator)) {}
friend class Dispatcher;
template<class> friend class TypedOperatorHandle;
std::list<Dispatcher::OperatorDef>::iterator operatorIterator_;
};
/**
* This is a handle to an operator schema registered with the dispatcher.
* It holds the same information as an OperatorHandle, but it is templated
* on the operator arguments and allows calling the operator in an
* unboxed way.
*/
// Primary template: chosen only when FuncType does not match the
// `Return (Args...)` partial specialization, i.e. when it is not a plain
// function type. Instantiating it is reported as a compile-time error.
template<class FuncType>
class TypedOperatorHandle final {
// NOTE(review): guts::false_t<FuncType>() is presumably an always-false value
// that depends on FuncType, so the assertion only fires when this primary
// template is actually instantiated - confirm against the guts helpers.
static_assert(guts::false_t<FuncType>(), "FuncType in OperatorHandle::typed<FuncType> was not a valid function type");
};
template<class Return, class... Args>
class TypedOperatorHandle<Return (Args...)> final : public OperatorHandle {
public:
TypedOperatorHandle(TypedOperatorHandle&&) noexcept = default;
TypedOperatorHandle& operator=(TypedOperatorHandle&&) noexcept = default;
TypedOperatorHandle(const TypedOperatorHandle&) = default;
TypedOperatorHandle& operator=(const TypedOperatorHandle&) = default;
Return call(Args... args) const {
return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
}
Return callWithDispatchKey(DispatchKey dispatchKey, Args... args) const {
return c10::Dispatcher::singleton().callWithDispatchKey<Return, Args...>(*this, dispatchKey, std::forward<Args>(args)...);
}
private:
explicit TypedOperatorHandle(std::list<Dispatcher::OperatorDef>::iterator operatorIterator)
: OperatorHandle(std::move(operatorIterator)) {}
friend class OperatorHandle;
};
namespace detail {
// Accepts any number of arguments of any type, by const reference, and does
// nothing with them. Used by the dispatcher call paths to mark a parameter
// pack as "used" and silence a false-positive unused-parameter warning.
template <typename... Ts>
inline void unused_arg_(const Ts&... /* intentionally ignored */) {}
} // namespace detail
template<class Return, class... Args>
inline Return Dispatcher::callWithDispatchKey(const TypedOperatorHandle<Return(Args...)>& op, DispatchKey dispatchKey, Args... args) const {
detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5
// No alias dispatch key is allowed at runtime.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::isAliasDispatchKey(dispatchKey));
const KernelFunction& kernel = op.operatorIterator_->op.lookup(dispatchKey);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
// By default, when there're no high-frequency or non-sampled callbacks,
// RecordFunction is pre-sampled as a perf optimization;
// shouldRunRecordFunction checks whether RecordFunction should be executed,
// and sets pre_sampled boolean argument value to whether pre-sampling was used -
// this boolean is passed into RecordFunction to adjust the sampling rates of
// the callbacks
bool pre_sampled = false;
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
// Check if we need to run callbacks registered with RecordFunction
// If true and callbacks need inputs, we box the arguments and pass
// them into the callbacks and also into the kernel call
// Note: for perf reasons we wouldn't want to pass arguments into
// the function call or prematurely box them
at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
if (C10_UNLIKELY(guard.isActive())) {
if (shouldRecord(dispatchKey) && op.operatorIterator_->op.isObserved()) {
int64_t seq_num = -1;
// Setting sequence number in the Autograd case to associate
// the forward range with the coresponding Autograd's node
if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
seq_num = at::sequence_number::peek();
}
if (guard.needsInputs()) {
torch::jit::Stack stack = impl::boxArgs(args...);
guard.before(op, stack, seq_num);
} else {
guard.before(op, seq_num);
}
}
}
// keeping the guard alive while executing the kernel
return kernel.template call<Return, Args...>(op, std::forward<Args>(args)...);
}
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
return kernel.template call<Return, Args...>(op, std::forward<Args>(args)...);
}
// Unboxed entry point of the dispatcher.
//
// Computes the dispatch key for `op` from the full key set and the call
// arguments, then hands the call over to callWithDispatchKey.
//
// @param op    typed handle of the operator to invoke
// @param args  operator arguments, exactly as declared by the `Args...` signature
// @return      whatever the selected kernel returns
template<class Return, class... Args>
inline Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
  // Key extraction only inspects the arguments, so they are passed as lvalues here.
  auto dispatchKey = op.operatorIterator_->op.dispatchKeyExtractor()
    .template getDispatchKeyUnboxed<Args...>(
      DispatchKeySet::FULL,
      args...
    );
  // This is the last use of `args`, so forward them: by-value arguments are
  // moved rather than copied, and reference arguments keep their value category
  // (an lvalue `args` could not bind to an rvalue-reference parameter). This
  // matches how TypedOperatorHandle::call, callWithDispatchKey and redispatch
  // pass their arguments on.
  return callWithDispatchKey<Return, Args...>(op, dispatchKey, std::forward<Args>(args)...);
}
template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKey currentDispatchKey, Args... args) const {
  // gcc 5 emits a false-positive "unused parameter" warning for `args`; touching
  // the pack here silences it.
  detail::unused_arg_(args...);
  auto& entry = op.operatorIterator_->op;
  // Only keys after currentDispatchKey remain eligible for the next hop.
  const auto nextKey = entry.dispatchKeyExtractor().template getDispatchKeyUnboxed<Args...>(
      DispatchKeySet(DispatchKeySet::FULL_AFTER, currentDispatchKey),
      args...);
  // RecordFunction is deliberately not used on redispatch: go straight to the kernel.
  return entry.lookup(nextKey).template call<Return, Args...>(op, std::forward<Args>(args)...);
}
inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const {
// note: this doesn't need the mutex because write operations on the list keep iterators intact.
const auto& entry = op.operatorIterator_->op;
auto dispatchKey = entry.dispatchKeyExtractor().getDispatchKeyBoxed(stack);
const auto& kernel = entry.lookup(dispatchKey);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
bool pre_sampled = false;
if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
// using already existing stack to record function execution in observers
at::RecordFunction guard(at::RecordScope::FUNCTION, pre_sampled);
if (C10_UNLIKELY(guard.isActive())) {
if (shouldRecord(dispatchKey) && entry.isObserved()) {
int64_t seq_num = -1;
if (isIncludedInAlias(dispatchKey, DispatchKey::Autograd) && at::GradMode::is_enabled()) {
seq_num = at::sequence_number::peek();
}
if (guard.needsInputs()) {
guard.before(op, *stack, seq_num);
} else {
guard.before(op, seq_num);
}
}
}
// keeping the guard alive while executing the kernel
kernel.callBoxed(op, stack);
return;
}
#endif // PYTORCH_DISABLE_PER_OP_PROFILING
kernel.callBoxed(op, stack);
}
} // namespace c10