-
Notifications
You must be signed in to change notification settings - Fork 21.6k
/
common.py
1993 lines (1676 loc) · 67.7 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# mypy: allow-untyped-defs
import contextlib
import dataclasses
import functools
import itertools
import logging
import math
import operator
import re
from itertools import chain
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
NamedTuple,
Optional,
Set,
Tuple,
Union,
)
import sympy
from sympy.printing.printer import Printer
import torch
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
from .. import config, metrics
from ..utils import (
DeferredLineBase,
generate_assert,
IndentedBuffer,
sympy_dot,
sympy_subs,
unique,
)
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
# Artifact logger for the "schedule" artifact of this module; data_type_logger
# writes its DEBUG messages through it.
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
def data_type_logger(msg):
    """Log ``msg`` at DEBUG level on the schedule logger.

    The message is prefixed with "Data type propagation: ". Nothing is
    emitted when DEBUG is not enabled for the schedule logger.
    """
    if not schedule_log.isEnabledFor(logging.DEBUG):
        return
    schedule_log.debug("Data type propagation: %s", msg)
@dataclasses.dataclass
class WorkspaceArg:
    """A temporary buffer used for a single kernel, then discarded.

    Not registered as a traditional buffer since there are no users,
    so it would be dead code eliminated.
    """

    # Size of the workspace as a sympy expression (in bytes, going by the name).
    nbytes: sympy.Expr
    # NOTE(review): presumably requests that the workspace be zeroed before
    # use; the consumer of this flag is outside this section of the file.
    zero_fill: bool
@dataclasses.dataclass
class TensorArg:
    """Kernel argument record for a tensor."""

    # NOTE(review): ``name`` looks like the argument name used in the kernel
    # and ``buffer`` the name of the backing buffer -- confirm against callers.
    name: str
    buffer: str
    dtype: torch.dtype
    # Offset expression; defaults to the sympy integer 0.
    offset: sympy.Expr = sympy.Integer(0)
@dataclasses.dataclass
class SizeArg:
    """Kernel argument record pairing an argument name with a sympy expression."""

    name: str
    # The sympy expression associated with ``name``.
    expr: sympy.Expr
@dataclasses.dataclass
class DeviceCodegen:
    """Bundle of the codegen components registered for one device.

    Instances are created by ``register_backend_for_device`` and stored in
    ``device_codegens``.
    """

    # Scheduling object/class supplied by the backend (left untyped).
    scheduling: Any
    # Class used to generate the Python wrapper code.
    wrapper_codegen: type
    # Class used when a C++ wrapper is requested; ``NoneType`` when not provided.
    cpp_wrapper_codegen: type = type(None)
# Any of the argument record types defined above.
KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]
# Registry of codegen components keyed by device string; populated by
# register_backend_for_device and read by the get_*_for_device helpers.
device_codegens: Dict[str, DeviceCodegen] = {}
class DeviceOpOverrides:
    """Interface for device-specific operations.

    Every method here raises ``NotImplementedError``; presumably each device
    backend supplies a subclass that overrides them and registers an instance
    through ``register_device_op_overrides``.
    """

    def import_get_raw_stream_as(self, name):
        """Not implemented in the base class; always raises NotImplementedError."""
        raise NotImplementedError

    def set_device(self, device_idx):
        """Not implemented in the base class; always raises NotImplementedError."""
        raise NotImplementedError

    def synchronize(self):
        """Not implemented in the base class; always raises NotImplementedError."""
        raise NotImplementedError

    def device_guard(self, device_idx):
        """Not implemented in the base class; always raises NotImplementedError."""
        raise NotImplementedError
# Registry of DeviceOpOverrides instances keyed by device string; written by
# register_device_op_overrides and read by get_device_op_overrides.
device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}
# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# For any new backend looking to integrate with Inductor, customization of these two main
# parts is necessary to generate its specific code.
#
# Kernel code generation is determined by different Scheduling. Consequently, a new
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
#
# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
# and override specific member functions to create backend-specific Python wrapper code.
#
# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
# provide flexibility to the backend. A backend can choose to implement these classes from scratch,
# or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
# register_backend_for_device, to equip a new backend at runtime.
#
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
# This backend can be used as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
    device: str,
    device_scheduling: Any,
    device_wrapper_codegen: type,
    device_cpp_wrapper_codegen: type = type(None),
):
    """Register the codegen components for ``device``.

    The three components are bundled into a ``DeviceCodegen`` and stored in
    ``device_codegens`` under ``device``, replacing any existing entry.
    """
    codegen = DeviceCodegen(
        scheduling=device_scheduling,
        wrapper_codegen=device_wrapper_codegen,
        cpp_wrapper_codegen=device_cpp_wrapper_codegen,
    )
    device_codegens[device] = codegen
def get_scheduling_for_device(device: str):
    """Return the scheduling registered for ``device``, or None if none is registered."""
    try:
        registered = device_codegens[device]
    except KeyError:
        return None
    return registered.scheduling
def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False):
    """Return the wrapper codegen class registered for ``device``.

    Returns the C++ wrapper codegen when ``cpp_wrapper`` is true, otherwise
    the regular wrapper codegen. Returns None when ``device`` has no entry in
    ``device_codegens``.
    """
    if device not in device_codegens:
        return None

    registered: DeviceCodegen = device_codegens[device]
    if cpp_wrapper:
        return registered.cpp_wrapper_codegen
    return registered.wrapper_codegen
def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    """Return ``index`` extended with one extra contiguous index expression.

    The extra entry is the dot product of ``index_vars`` with the contiguous
    strides computed for ``sizes``.
    """
    from ..ir import FlexibleLayout

    strides = FlexibleLayout.contiguous_strides(sizes)
    contiguous_index = sympy_dot(index_vars, strides)
    # the appended contiguous index is what prevents reordering
    return [*index, contiguous_index]
def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
    """Store ``device_op_overrides`` in ``device_op_overrides_dict`` under ``device``.

    An existing entry for the same device is replaced.
    """
    device_op_overrides_dict[device] = device_op_overrides
def get_device_op_overrides(device: str):
    """Return the overrides registered for ``device``, or None if there are none.

    When the registry is still empty, the cuda and xpu override modules are
    imported first; presumably importing them registers their overrides as a
    side effect (the imported names themselves are unused).
    """
    assert isinstance(device, str)

    if not device_op_overrides_dict:
        from .cuda import device_op_overrides  # noqa: F401
        from .xpu import device_op_overrides as xpu_op_overrides  # noqa: F401

    return device_op_overrides_dict.get(device)
@functools.lru_cache(None)
def boolean_ops():
    """Return the tuple of op names that DataTypePropagation maps to torch.bool.

    The result is cached, so repeated calls return the same tuple object.
    """
    elementwise = ("is_inf", "is_nan", "bitwise_xor", "logical_not", "signbit")
    comparisons = ("le", "lt", "ge", "gt", "eq", "ne")
    return elementwise + comparisons
DTYPE_TO_COMPUTATION_DTYPE = {
torch.bfloat16: torch.float,
torch.float16: torch.float,
**{
dtype: dtype
for dtype in [
torch.bool,
torch.float32,
torch.float64,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
]
},
}
class DataTypePropagation:
    """Deduce a dtype for every FX node of a loop body.

    The deduced dtype is stored on ``node.meta[OptimizationContext.key].dtype``.
    ``self.graphs`` maps ``"root"`` to the root block's graph and each subblock
    key to that subblock's graph.
    """

    def __init__(self, body) -> None:
        self.body = body
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        """Return the promoted dtype of ``node``'s non-placeholder inputs.

        Returns None when there are no such inputs, or when any of them has
        not been given a dtype yet.
        """
        input_nodes = [
            n
            for n in node.all_input_nodes
            if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if not input_nodes:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        """Propagate through the subgraph named by ``node.target`` and return its dtype."""
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        # The subgraph must have produced an output dtype.
        assert dtype is not None
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        """Return the dtype deduced for ``node``, or None when it cannot be deduced."""
        if node.target in boolean_ops():
            return torch.bool

        if node.op == "placeholder":
            return None

        if node.target == "output":
            # we can infer output node if it only have 1 arg
            if len(node.args) != 1:
                return None

        # For these ops the dtype is passed explicitly as the last argument.
        if node.target in (
            "to_dtype",
            "index_expr",
        ):
            return node.args[-1]

        if node.target in (
            "rand",
            "randn",
        ):
            return torch.float

        # "index_expr" used to be listed here as well, but it can never reach
        # this branch: it is always handled by the explicit-dtype branch above.
        if node.target in (
            "get_index",
            "randint64",
        ):
            return torch.int64

        if node.target in (
            "load",
            "store",
            "store_reduction",
        ):
            # The dtype comes from the buffer named by the second argument.
            buf_name = node.args[1]
            return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]

        if node.target == operator.getitem:
            # getitem takes the dtype deduced for the node it indexes into.
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target == "reduction":
            return node.args[1]

        if node.target == "constant":
            return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]]  # type: ignore[index]

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        """Assign a dtype to every node of ``graph`` and return the output node's dtype."""
        assert graph.nodes
        graph_dtype = None
        # For masked_subblock, we use output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        """Run propagation on the root graph."""
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        """Build a propagator for ``body`` and run it."""
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        """Run propagation on the loop body of a SchedulerNode."""
        from ..ir import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        DataTypePropagation.propagate_loopbody(node._body)
# This printer contains rules that are supposed to be generic for both C/C++ and
# Python
class ExprPrinter(Printer):
    """Printer rules shared by the C/C++ and Python expression printers."""

    @staticmethod
    def paren(string):
        """Return ``string`` wrapped in parentheses unless that is unnecessary.

        CSE variables, plain identifiers/numbers, the empty string and text
        that is already fully enclosed in one pair of parentheses are returned
        unchanged.
        """

        def all_in_parens(text):
            # True only when the opening "(" at position 0 is closed by the
            # very last character of ``text``.
            if not (text[0] == "(" and len(text) >= 2):
                return False
            depth = 1
            final_pos = len(text) - 1
            for pos, ch in enumerate(text[1:], start=1):
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                if depth == 0 and pos != final_pos:
                    return False
            assert depth == 0
            return True

        needs_no_wrapping = (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.I)
            or re.match(r"^\([^)]*\)$", string, re.I)
            or string == ""
        )
        # don't put extra parens for strings that are already wrapped in parens
        if needs_no_wrapping or all_in_parens(string):
            return string
        return f"({string})"

    def _print_Relational(self, expr):
        separator = f" {expr.rel_op} "
        return separator.join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_Mul(self, expr):
        return "*".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_Add(self, expr):
        return " + ".join(self.paren(self._print(arg)) for arg in expr.args)

    # NB: this is OK to put here, because Mod is only defined for positive
    # numbers, and so across C/Python its behavior is consistent
    def _print_Mod(self, expr):
        return " % ".join(self.paren(self._print(arg)) for arg in expr.args)

    def _print_FloatTrueDiv(self, expr):
        numerator, denominator = expr.args
        left = self.paren(self._print(numerator))
        right = self.paren(self._print(denominator))
        return f"{left} / {right}"

    def _print_CleanDiv(self, expr):
        # CleanDiv is printed exactly like FloorDiv.
        return self._print_FloorDiv(expr)

    def _print_GreaterThan(self, expr):
        # GreaterThan: >=
        # StrictlyGreaterThan: >
        # Go figure...
        return " >= ".join(self.paren(self._print(arg)) for arg in expr.args)

    # NB: The C implementation is injected into codegen at
    # torch/_inductor/codegen/wrapper.py
    def _print_align(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"align({self._print(operand)})"

    # This must be implemented because sympy will collect x * x into Pow(x, 2), without
    # any explicit intervention.  We print it just like x * x; notably, we
    # never generate sympy.Pow with floats.
    #
    # NB: this is pow by natural. You should never have used builtin sympy.pow
    # for FloatPow, and a symbolic exponent should be PowByNatural.  This
    # means exp is guaranteed to be an integer.
    def _print_Pow(self, expr):
        base_expr, exponent = expr.args
        printed_base = self._print(base_expr)
        assert exponent == int(exponent), exponent
        power = int(exponent)
        assert power >= 0
        if power == 0:
            return "1"
        return "*".join([self.paren(printed_base)] * power)

    # Explicit NotImplemented functions are to prevent default sympy printing
    # behavior, which will just barf out ToFloat(...) to your IR.  The error
    # message is better here because it tells you which printer class it needs
    # to go in.

    def _print_ToFloat(self, expr):
        message = f"_print_ToFloat not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_Infinity(self, expr):
        message = f"_print_Infinity not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_NegativeInfinity(self, expr):
        message = f"_print_NegativeInfinity not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_FloorDiv(self, expr):
        message = f"_print_FloorDiv not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_PythonMod(self, expr):
        message = f"_print_PythonMod not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_IntTrueDiv(self, expr):
        message = f"_print_IntTrueDiv not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_PowByNatural(self, expr):
        message = f"_print_PowByNatural not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_FloatPow(self, expr):
        message = f"_print_FloatPow not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_TruncToInt(self, expr):
        message = f"_print_TruncToInt not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_RoundToInt(self, expr):
        message = f"_print_RoundToInt not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_RoundDecimal(self, expr):
        message = f"_print_RoundDecimal not implemented for {type(self)}"
        raise NotImplementedError(message)

    # NB: Some float operations are INTENTIONALLY not implemented for
    # printers.  You can implement them as a quick unblock, but it is better
    # to ask yourself why we haven't done this computation in the Tensor
    # universe instead

    def _print_TruncToFloat(self, expr):
        message = f"_print_TruncToFloat not implemented for {type(self)}"
        raise NotImplementedError(message)

    def doprint(self, expr, *, simplify: bool = True):
        """Print ``expr``, first simplifying sympy expressions when possible.

        Simplification is skipped when ``simplify`` is false, when ``expr`` is
        not a ``sympy.Expr``, or when ``V.graph`` has no ``sizevars``.
        """
        # TODO: why are people passing strings to the printer here :think:
        should_simplify = (
            simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars")
        )
        if should_simplify:
            expr = V.graph.sizevars.simplify(expr)
        return super().doprint(expr)
class PythonPrinter(ExprPrinter):
    """Expression printer that emits Python source text."""

    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"float({self._print(operand)})"

    def _print_ModularIndexing(self, expr):
        base, divisor, modulus = expr.args
        base_str = self.paren(self.doprint(base))
        divisor_str = self.paren(self.doprint(divisor))
        modulus_str = self.paren(self.doprint(modulus))
        # Dividing by the literal "1" is a no-op, so it is left out.
        if divisor_str != "1":
            base_str = f"({base_str} // {divisor_str})"
        return f"{base_str} % {modulus_str}"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_PythonMod(self, expr):
        return " % ".join(self.paren(self._print(arg)) for arg in expr.args)

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_FloorDiv(self, expr):
        dividend, divisor = expr.args
        dividend_str = self.paren(self.doprint(dividend))
        divisor_str = self.paren(self.doprint(divisor))
        return f"({dividend_str} // {divisor_str})"

    # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
    # does a special algorithm
    def _print_IntTrueDiv(self, expr):
        numerator, denominator = expr.args
        left = self.paren(self._print(numerator))
        right = self.paren(self._print(denominator))
        return f"{left} / {right}"

    def _helper_sqrt(self, expr):
        printed = self._print(expr)
        return f"math.sqrt({printed})"

    def _print_OpaqueUnaryFn_sqrt(self, expr):
        return self._helper_sqrt(expr.args[0])

    def _print_FloatPow(self, expr):
        base, exponent = expr.args
        base_str = self.paren(self._print(base))
        exponent_str = self.paren(self._print(exponent))
        return f"{base_str} ** {exponent_str}"

    # TODO: Not sure this works with Triton, even when base/exp are integral
    def _print_PowByNatural(self, expr):
        base, exponent = expr.args
        base_str = self.paren(self._print(base))
        exponent_str = self.paren(self._print(exponent))
        return f"{base_str} ** {exponent_str}"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"math.floor({self._print(operand)})"

    def _print_FloorToInt(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"math.floor({self._print(operand)})"

    def _print_TruncToInt(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        # This also could have been int(), they'll do the same thing for float
        return f"math.trunc({self._print(operand)})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"math.ceil({self._print(operand)})"

    def _print_CeilToInt(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"math.ceil({self._print(operand)})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"abs({self._print(operand)})"

    # NB: It's expected that we've made explicit any promotion in the sympy
    # expression, so it doesn't matter that Python max/min doesn't perform
    # promotion
    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        operands = ", ".join(self._print(arg) for arg in expr.args)
        return f"max({operands})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        operands = ", ".join(self._print(arg) for arg in expr.args)
        return f"min({operands})"

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"math.cos({self._print(operand)})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"math.cosh({self._print(operand)})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"math.acos({self._print(operand)})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"math.sin({self._print(operand)})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"math.sinh({self._print(operand)})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"math.asin({self._print(operand)})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"math.tan({self._print(operand)})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"math.tanh({self._print(operand)})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"math.atan({self._print(operand)})"

    def _print_RoundToInt(self, expr):
        assert len(expr.args) == 1
        (operand,) = expr.args
        return f"round({self._print(operand)})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        # The digit count must be a concrete integer; it is emitted verbatim.
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"
class OpOverrides:
    """Default op implementations layered on top of a parent handler.

    Attributes that are not defined here are looked up on the wrapped
    ``parent``. Most ops are expressed through other ``ops.*`` calls; the
    bitwise/logical ones return source text built from their operands.
    """

    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        # __getattr__ only runs when normal lookup fails.  If ``_parent`` itself
        # is missing (instance created without __init__, e.g. while copying or
        # unpickling), forwarding would re-enter this method forever, so report
        # a plain AttributeError instead of recursing.
        if item == "_parent":
            raise AttributeError(item)
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        """Return ``repr(value)``; ``dtype`` is not used here."""
        return repr(value)

    @staticmethod
    def reciprocal(x):
        """Compute 1 / x."""
        return ops.truediv(ops.constant(1, torch.int32), x)

    @staticmethod
    def square(x):
        """Compute x * x."""
        return ops.mul(x, x)

    @staticmethod
    def erfc(x):
        """Compute 1 - erf(x)."""
        return ops.sub(ops.constant(1, torch.float32), ops.erf(x))

    @staticmethod
    def erfcx(x):
        """Compute exp(x * x) * erfc(x)."""
        return ops.mul(ops.exp(ops.square(x)), ops.erfc(x))

    @staticmethod
    def expm1(x):
        """Compute exp(x) - 1."""
        return ops.sub(ops.exp(x), ops.constant(1, torch.float32))

    @staticmethod
    def log10(x):
        """Compute log(x) * (1 / ln 10)."""
        return ops.mul(ops.log(x), ops.constant(1 / math.log(10), torch.float32))

    @staticmethod
    def log2(x):
        """Compute log(x) * (1 / ln 2)."""
        return ops.mul(ops.log(x), ops.constant(1 / math.log(2), torch.float32))

    @staticmethod
    def exp2(x):
        """Compute exp(x * ln 2)."""
        return ops.exp(ops.mul(x, ops.constant(math.log(2), torch.float32)))

    @staticmethod
    def log1p(x):
        """Compute log(x + 1)."""
        return ops.log(ops.add(x, ops.constant(1, torch.int32)))

    @staticmethod
    def sigmoid(x):
        """Compute 1 / (1 + exp(-x))."""
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))

    @staticmethod
    def libdevice_sigmoid(x):
        """Compute 1 / (1 + libdevice_exp(-x))."""
        one = ops.constant(1, torch.int32)
        return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))

    @staticmethod
    def relu(x):
        """Compute maximum(x, 0)."""
        return ops.maximum(x, ops.constant(0, torch.int32))

    # The libdevice_* ops below default to the corresponding plain op.

    @staticmethod
    def libdevice_abs(x):
        return ops.abs(x)

    @staticmethod
    def libdevice_sqrt(x):
        return ops.sqrt(x)

    @staticmethod
    def libdevice_cos(x):
        return ops.cos(x)

    @staticmethod
    def libdevice_sin(x):
        return ops.sin(x)

    @staticmethod
    def libdevice_log(x):
        return ops.log(x)

    @staticmethod
    def libdevice_exp(x):
        return ops.exp(x)

    # The ops below build source text directly; operands are parenthesized
    # with ExprPrinter.paren when needed.

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

    @staticmethod
    def remainder(a, b):
        """Compute mod(a, b), adding ``b`` when the result is nonzero and its sign differs from ``b``'s."""
        r = ops.mod(a, b)
        cond = ops.and_(
            ops.ne(r, ops.constant(0, torch.int32)),
            ops.ne(ops.signbit(r), ops.signbit(b)),
        )
        return ops.where(cond, ops.add(r, b), r)

    @staticmethod
    def trunc_to_int(a, dtype):
        """Truncate ``a`` and convert the result to ``dtype``."""
        return ops.to_dtype(ops.trunc(a), dtype)

    @staticmethod
    def floor_to_int(a, dtype):
        """Floor ``a`` and convert the result to ``dtype``."""
        return ops.to_dtype(ops.floor(a), dtype)

    @staticmethod
    def ceil_to_int(a, dtype):
        """Ceil ``a`` and convert the result to ``dtype``."""
        return ops.to_dtype(ops.ceil(a), dtype)

    @staticmethod
    def round_to_int(a, dtype):
        """Round ``a`` and convert the result to ``dtype``."""
        return ops.to_dtype(ops.round(a), dtype)

    @staticmethod
    def int_truediv(a, b):
        # TODO: this is wrong
        # TODO: an easy bandaid is to generate runtime asserts that it's
        # <= 2**53, which is when this equation is correct
        return ops.truediv(a, b)

    @staticmethod
    def load_seed(name, offset):
        """Load from ``name`` at the integer index ``offset``."""
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        """Install the pointwise overrides available for ``target`` on ``cls``.

        For every entry of ``pointwise_overrides_data`` the implementation
        stored under the attribute named ``target`` is attached to ``cls`` as
        a static method; entries whose implementation is None are skipped.
        """
        assert target in {"triton", "cpp", "cppvec"}, target

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if impl is None:
                continue
            setattr(cls, funcname, staticmethod(impl))
@dataclasses.dataclass
class OverridesData:
    """Per-target implementations of one pointwise op.

    ``OpOverrides._initialize_pointwise_overrides`` reads the ``cpp``,
    ``triton`` or ``cppvec`` attribute by target name and skips entries whose
    value is None.
    """

    name: str
    # Implementation for the "cpp" target; a callable returning a string.
    cpp: Callable[..., str]
    # None when not impl in libdevice/triton
    triton: Optional[Callable[..., str]] = None
    # None when not impl in aten/.../vec
    cppvec: Optional[Callable[..., str]] = None
    # Type promotion kind for the op; DEFAULT unless specified.
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
# NB: if you add a new special function, don't forget to update
# torch._inductor.ops_handler too
pointwise_overrides_data: Dict[str, OverridesData] = dict(
airy_ai=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"airy_ai_forward({x})",
name="special_airy_ai",
),
bessel_j0=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"bessel_j0_forward({x})",
triton=lambda x: f"libdevice.j0({x})",
name="special_bessel_j0",
),
bessel_j1=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"bessel_j1_forward({x})",
triton=lambda x: f"libdevice.j1({x})",
name="special_bessel_j1",
),
bessel_y0=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"bessel_y0_forward({x})",
triton=lambda x: f"libdevice.y0({x})",
name="special_bessel_y0",
),
bessel_y1=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"bessel_y1_forward({x})",
triton=lambda x: f"libdevice.y1({x})",
name="special_bessel_y1",
),
digamma=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"calc_digamma({x})",
cppvec=lambda x: f"{x}.digamma()",
name="digamma",
),
# no cpp nor triton implementation for entr, it is defined as decomposition
# erf, erfc
erfcx=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"calc_erfcx({x})",
triton=lambda x: f"libdevice.erfcx({x})",
name="special_erfcx",
),
fma=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x, y, z: f"std::fma({x}, {y}, {z})",
cppvec=lambda x, y, z: f"fmadd({x}, {y}, {z})",
triton=lambda x, y, z: f"libdevice.fma({x}, {y}, {z})",
name="fma",
),
# erfinv, exp2, expit, gammaln
igamma=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x, y: f"calc_igamma({x}, {y})",
name="igamma",
),
igammac=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x, y: f"calc_igammac({x}, {y})",
name="igammac",
),
gammainc=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x, y: f"calc_igamma({x}, {y})",
name="special_gammainc",
),
gammaincc=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x, y: f"calc_igammac({x}, {y})",
name="special_gammaincc",
),
i0=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"calc_i0({x})",
triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
cppvec=lambda x: f"{x}.i0()",
name="i0",
),
i0e=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"calc_i0e({x})",
cppvec=lambda x: f"{x}.i0e()",
name="special_i0e",
),
i1=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"calc_i1({x})",
triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
name="special_i1",
),
i1e=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"calc_i1e({x})",
name="special_i1e",
),
log_ndtr=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"calc_log_ndtr({x})",
name="special_log_ndtr",
),
# logit
modified_bessel_i0=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"modified_bessel_i0_forward({x})",
triton=lambda x: f"libdevice.cyl_bessel_i0({x})",
name="special_modified_bessel_i0",
),
modified_bessel_i1=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"modified_bessel_i1_forward({x})",
triton=lambda x: f"libdevice.cyl_bessel_i1({x})",
name="special_modified_bessel_i1",
),
modified_bessel_k0=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"modified_bessel_k0_forward({x})",
name="special_modified_bessel_k0",
),
modified_bessel_k1=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"modified_bessel_k1_forward({x})",
name="special_modified_bessel_k1",
),
# multigamma
ndtr=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"calc_ndtr({x})",
name="special_ndtr",
),
ndtri=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"calc_ndtri({x})",
name="special_ndtri",
),
polygamma=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x, y: f"calc_polygamma({y}, {x})",
name="polygamma",
),
# psi - alias to digamma
# round
scaled_modified_bessel_k0=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"scaled_modified_bessel_k0_forward({x})",
name="special_scaled_modified_bessel_k0",
),
scaled_modified_bessel_k1=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"scaled_modified_bessel_k1_forward({x})",
name="special_scaled_modified_bessel_k1",
),
# sinc
spherical_bessel_j0=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x: f"spherical_bessel_j0_forward({x})",
name="special_spherical_bessel_j0",
),
zeta=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x, y: f"zeta({x}, {y})",
name="special_zeta",
),
chebyshev_polynomial_t=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x, y: f"chebyshev_polynomial_t_forward({x}, {y})",
name="special_chebyshev_polynomial_t",
),
chebyshev_polynomial_u=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x, y: f"chebyshev_polynomial_u_forward({x}, {y})",
name="special_chebyshev_polynomial_u",
),
chebyshev_polynomial_v=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x, y: f"chebyshev_polynomial_v_forward({x}, {y})",
name="special_chebyshev_polynomial_v",
),
chebyshev_polynomial_w=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x, y: f"chebyshev_polynomial_w_forward({x}, {y})",
name="special_chebyshev_polynomial_w",
),
legendre_polynomial_p=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x, y: f"legendre_polynomial_p_forward({x}, {y})",
name="special_legendre_polynomial_p",
),
shifted_chebyshev_polynomial_t=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x, y: f"shifted_chebyshev_polynomial_t_forward({x}, {y})",
name="special_shifted_chebyshev_polynomial_t",
),
shifted_chebyshev_polynomial_u=OverridesData(
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
cpp=lambda x, y: f"shifted_chebyshev_polynomial_u_forward({x}, {y})",
name="special_shifted_chebyshev_polynomial_u",
),