/
ir.py
4089 lines (3433 loc) · 129 KB
/
ir.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import contextlib
import dataclasses
import functools
import itertools
import logging
import re
import textwrap
from contextlib import nullcontext
from enum import Enum
from functools import partial
from inspect import signature
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
from unittest.mock import patch
import sympy
from sympy import Expr, Integer
import torch._dynamo.config as dynamo_config
import torch._logging
import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo.utils import identity
from torch._prims_common import (
is_boolean_dtype,
is_float_dtype,
make_channels_last_strides_for,
make_contiguous_strides_for,
)
from torch.fx.experimental.symbolic_shapes import FloorDiv
from . import config, dependencies
from .codegen.common import index_prevent_reordering
from .cuda_properties import get_device_properties
from .dependencies import extract_read_writes, var_builder
from .utils import (
argsort,
cache_on_self,
convert_shape_to_inductor,
convert_shape_to_symint,
developer_warning,
sympy_dot,
sympy_product,
sympy_subs,
sympy_symbol,
)
from .virtualized import ops, V
log = logging.getLogger(__name__)
# Prefixes every line of a text block; used by IRNode.str_helper to nest reprs.
indent = partial(textwrap.indent, prefix=" ")
# Shorthand for the aten operator namespace.
aten = torch.ops.aten
""" [Note: Inductor IR]
Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each
lowering is registered to a particular aten operator, and expects inputs that
correspond to the aten schema. However, in place of torch Tensor inputs, lowerings
expect Inductor TensorBox inputs.
TensorBox IR represents torch tensors. Tensors are sometimes single objects owning
storage, and sometimes views of another Tensor's storage. Mutating tensor operations
(such as add_()) affect the underlying storage and any associated views. Other operations
(such as .t_()) update metadata about the current view but don't modify the underlying storage.
To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer.
TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor
output from an operation. But just as torch.Tensors take different forms, TensorBox IR can
reference View IR or directly reference StorageBox IRs.
Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops)
may take an existing TensorBox and point it to a new underlying View IR.
Tensors that directly own storage are represented as a chain of:
TensorBox -> StorageBox -> Buffer
where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout.
If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer
(leaving the old buffer unmodified and functionalizing the operation).
Tensors backed by views add one more indirection to the IR.
TensorBox -> View -> StorageBox -> Buffer
In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.
"""
def validate_ir(node_or_nodes):
    """
    Sanity-check the value(s) returned by a lowering.

    ``node_or_nodes`` may be a single node or a list/tuple of nodes. Every
    node must be a DynamicScalar, TensorBox, RandSeedBuffer, or a sympy
    Symbol / Relational / Expr. A violation fails an ``assert``, so the check
    is skipped under ``python -O``. See [Note: Inductor IR].
    """

    def _check_tensorbox(node):
        # Could expand this to check deeper properties
        # (e.g. TensorBox points to View or StorageBox)
        assert isinstance(
            node,
            (
                DynamicScalar,
                TensorBox,
                RandSeedBuffer,
                sympy.Symbol,
                sympy.core.relational.Relational,
                Expr,
            ),
        ), f"Found {type(node)}, which is not a supported top level IR node. See [Note: Inductor IR]"

    # Be picky about the accepted data structure (don't use pytree here).
    # Check against the builtin types rather than the deprecated
    # typing.List / typing.Tuple aliases.
    if isinstance(node_or_nodes, (list, tuple)):
        for node in node_or_nodes:
            _check_tensorbox(node)
    else:
        _check_tensorbox(node_or_nodes)
def inverse_reorder(order):
    """
    Return a reindex function that undoes ``same_reorder(order)``.

    The returned function maps an index list of length ``len(order)`` to a new
    list whose i-th entry is ``index[k]`` where ``order[k] == i``.
    """
    position_of = {target: source for source, target in enumerate(order)}

    def reindex(index):
        assert len(index) == len(position_of)
        result = []
        for i in range(len(index)):
            result.append(index[position_of[i]])
        return result

    return reindex
def same_reorder(order):
    """
    Return a reindex function that permutes an index list by ``order``.

    The returned function maps ``index`` to ``[index[order[0]], index[order[1]], ...]``.
    """

    def reindex(index):
        assert len(index) == len(order)
        permuted = []
        for position in range(len(index)):
            permuted.append(index[order[position]])
        return permuted

    return reindex
def fuse_reindexing(reindex1, reindex2):
    """Compose two reindex functions: the result computes ``reindex1(reindex2(index))``."""

    def reindex(index):
        intermediate = reindex2(index)
        return reindex1(intermediate)

    return reindex
def stride_order2fill_order(order):
    """
    Convert stride order to fill order.

    ``fill_order[rank]`` is the position in ``order`` that holds ``rank``.
    For channel last format,
    stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0]
    """
    position_of = {}
    for idx, rank in enumerate(order):
        position_of[rank] = idx
    return [position_of[rank] for rank in range(len(order))]
def get_stride_order(seq):
    """
    Convert strides to stride order.

    The result has one entry per element of ``seq``: entry ``p`` is the rank of
    ``seq[p]`` in the ordering produced by ``argsort(seq)`` (presumably
    ascending, so the smallest stride gets 0 — confirm against utils.argsort).
    """
    order = [None] * len(seq)
    for rank, position in enumerate(argsort(seq)):
        order[position] = rank
    return order
def ir_node_to_tensor(x, guard_shape=True):
    """
    Materialize a zero-filled torch tensor matching IR node ``x``.

    Size, dtype and device come from ``x``. Strides come from its layout when
    ``x`` has storage and layout; otherwise contiguous strides are used.
    With ``guard_shape=True`` sizes and strides are passed through unchanged.
    With ``guard_shape=False`` each one is replaced by its size hint.
    ``None`` is passed through unchanged.
    """
    if x is None:
        return None

    shape_fn = identity if guard_shape else V.graph.sizevars.size_hint

    size = [shape_fn(dim) for dim in x.get_size()]
    if is_storage_and_layout(x):
        stride = [shape_fn(step) for step in x.get_layout().stride]
    else:
        stride = make_contiguous_strides_for(size)

    dtype, device = x.get_dtype(), x.get_device()
    size = convert_shape_to_symint(size)
    stride = convert_shape_to_symint(stride)
    return torch.empty_strided(
        size=size, stride=stride, dtype=dtype, device=device
    ).zero_()
class ModularIndexing(sympy.Function):
    """
    ModularIndexing(a, b, c) => (a // b) % c

    ``eval`` folds integer literals and applies a few algebraic
    simplifications. It returns None (leave unevaluated) when none apply.
    """

    nargs = (3,)
    is_integer = True

    @classmethod
    def eval(cls, base, divisor, modulus):
        # 0 // b is 0, and anything % 1 is 0.
        if base == 0 or modulus == 1:
            return sympy.Integer(0)

        # Fully concrete: fold to a literal.
        if all(isinstance(arg, sympy.Integer) for arg in (base, divisor, modulus)):
            return (base // divisor) % modulus

        # Cancel a factor shared by base and divisor.
        if divisor != 1:
            shared = sympy.gcd(base, divisor)
            if shared != 1:
                return ModularIndexing(base / shared, divisor / shared, modulus)

        if isinstance(base, sympy.Add):
            # A term that is a multiple of divisor * modulus cannot change
            # (base // divisor) % modulus, so such terms may be dropped.
            period = modulus * divisor
            kept = []
            for term in base.args:
                if sympy.gcd(term, period) == period:
                    continue
                negative = (isinstance(term, sympy.Integer) and term < 0) or (
                    isinstance(term, sympy.Mul)
                    and isinstance(term.args[0], sympy.Integer)
                    and term.args[0] < 0
                )
                if negative:
                    # workaround for https://github.com/openai/triton/issues/619,
                    # if there are negative terms, // produces wrong result
                    # TODO if https://github.com/openai/triton/issues/619 is fixed
                    # this optimization would become valid
                    break
                kept.append(term)
            else:
                if len(kept) != len(base.args):
                    return ModularIndexing(sum(kept), divisor, modulus)

        # (a // b) // c == a // (b * c) for positive b, c: fold the inner FloorDiv.
        if isinstance(base, FloorDiv):
            return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
class CleanDiv(FloorDiv):
    """
    FloorDiv for which we can assume no rounding occurs (the division is exact).

    It is a distinct type only to enable future optimizations.
    """
class CeilDiv(sympy.Function):
    """
    Div used in indexing that rounds up.

    Constructing one never yields a CeilDiv instance. It returns a CleanDiv
    when ``divisor`` divides ``base`` (per sympy.gcd); otherwise it returns
    FloorDiv(base + divisor - 1, divisor).
    """

    is_integer = True

    def __new__(cls, base, divisor):
        divides_evenly = sympy.gcd(base, divisor) == divisor
        if divides_evenly:
            return CleanDiv(base, divisor)
        # ceil(a / b) == floor((a + b - 1) / b) for integer a and positive b.
        return FloorDiv(base + (divisor - 1), divisor)
def get_device_type(x):
    """
    Return the device type string (e.g. "cpu", "cuda") for ``x``.

    ``x`` may be a torch.device or any object with a truthy ``get_device``
    attribute, which is followed recursively. Returns None otherwise.
    """
    if getattr(x, "get_device", None):
        return get_device_type(x.get_device())
    return x.type if isinstance(x, torch.device) else None
def is_triton(x):
    """True when ``x`` resolves to a CUDA device (the triton backend)."""
    device_type = get_device_type(x)
    return device_type == "cuda"
def is_cpu(x):
    """True when ``x`` resolves to a CPU device."""
    device_type = get_device_type(x)
    return device_type == "cpu"
@dataclasses.dataclass
class IRNode:
    """
    Base class of Inductor IR nodes.

    Every instance records in ``self.origins`` a snapshot of the origins that
    were active (via ``current_origins``) when it was constructed.
    """

    # Origins attributed to any IRNode created right now. The set is replaced,
    # never mutated, by current_origins().
    _current_origins: ClassVar[Set[Any]] = set()

    @staticmethod
    @contextlib.contextmanager
    def current_origins(origins: Set[torch.fx.Node]):
        """Within the context, add ``origins`` to the origins of newly created nodes."""
        saved = IRNode._current_origins
        IRNode._current_origins = saved | origins
        try:
            yield
        finally:
            IRNode._current_origins = saved

    def __post_init__(self):
        # Copy so that later changes to the class-level set do not leak in.
        self.origins = set(self._current_origins)

    def common_repr(self):
        text = f"origins={getattr(self, 'origins', '')}"
        max_len = 64
        if len(text) > max_len:
            # this can get *very* long
            text = text[: max_len - 3] + "..."
        return [text]

    def str_helper(self, lines):
        fields = lines + self.common_repr()
        body = indent(",\n".join(str(item) for item in fields))
        return f"{type(self).__name__}(\n{body}\n)"

    def is_user_of(self, name):
        for dep in self.get_reads():
            if name == dep.name:
                return True
        return False

    def get_numel(self):
        return sympy_product(self.get_size())
@dataclasses.dataclass
class Loops(IRNode):
    """
    IR node for a loop nest over ``ranges`` whose body is ``inner_fn``.

    ``inner_fn`` is called with a list of per-dimension index expressions
    (see ``_index``). Subclasses provide ``make_loader``,
    ``get_reduction_type`` and ``get_reduction_size``, which ``get_reads``
    uses.
    """

    device: torch.device
    dtype: torch.dtype
    inner_fn: Callable
    ranges: List[Expr]

    def __str__(self, names=("ranges",)):
        header = [
            f"'{self.device.type}'",
            str(self.dtype),
            self.inner_fn_str(),
        ]
        extras = [f"{name}={getattr(self, name)}" for name in names]
        return self.str_helper(header + extras)

    __repr__ = __str__

    def get_dtype(self):
        return self.dtype

    def get_device(self):
        return self.device

    def get_size(self):
        return self.ranges

    def is_extern(self):
        return False

    @classmethod
    def create(cls, *args, **kwargs):
        """Construct ``cls(*args, **kwargs)`` and wrap it in a TensorBox."""
        node = cls(*args, **kwargs)
        return TensorBox.create(node)

    @staticmethod
    def _index(ranges, prefix="i"):
        # Size-1 dimensions are pinned to 0; every other dimension gets a
        # symbol named "<prefix><dim>".
        return [
            sympy.Integer(0) if size == 1 else sympy_symbol(f"{prefix}{dim}")
            for dim, size in enumerate(ranges)
        ]

    @cache_on_self
    def inner_fn_str(self):
        symbolic_index = self._index(self.ranges)
        return V.KernelFormatterHandler.ir_to_string(self.inner_fn, symbolic_index)

    def is_zero_elements(self):
        return any(size == 0 for size in self.ranges)

    @cache_on_self
    def get_reads(self):
        """Reads performed by this node's loader, extracted with FlexibleLayout indexing allowed."""
        with patch.object(FlexibleLayout, "allow_indexing", True):
            has_reduction = self.get_reduction_type()
            loader = self.make_loader()
            loop_sizes = [self.get_size()]
            if has_reduction:
                loop_sizes.append(self.get_reduction_size())
            return extract_read_writes(loader, *loop_sizes).reads
class Pointwise(Loops):
    """Elementwise loop nest: each output position is ``inner_fn(index)``; no reduction."""

    def make_loader(self):
        return self.inner_fn

    def get_reduction_size(self):
        return []

    def get_reduction_type(self):
        return None

    def store_output(self, output_name, indexer, vars):
        location = indexer(vars)
        value = self.inner_fn(vars)
        return ops.store(output_name, location, value)

    def constant_to_device(self, device):
        """
        Move this to a given device. Requires that all reads are to constants.

        The loader is wrapped so that ConstantBuffer.override_device is patched
        to ``device`` while it runs.
        """
        loader = self.make_loader()
        relocated = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Pointwise(
            device=device, dtype=self.dtype, inner_fn=relocated, ranges=self.ranges
        )
@dataclasses.dataclass
class Scatter(Pointwise):
    """
    Pointwise computation whose value for ``index`` is stored at
    ``output_indexer(index)``. ``scatter_mode`` is forwarded to ``ops.store``
    as ``mode``.
    """

    output_indexer: Callable[[List[Expr]], Expr]
    scatter_mode: Optional[str] = None

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        relocated = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Scatter(
            device=device,
            dtype=self.dtype,
            inner_fn=relocated,
            ranges=self.ranges,
            output_indexer=self.output_indexer,
            scatter_mode=self.scatter_mode,
        )

    def store_output(self, output_name, indexer, vars):
        target = indexer(self.output_indexer(vars))
        value = self.inner_fn(vars)
        return ops.store(output_name, target, value, mode=self.scatter_mode)
class ReductionHint(Enum):
    """Heuristic tag attached to a Reduction (chosen in Reduction.num_splits / create_multilayer)."""

    INNER = 0
    OUTER = 1
    OUTER_TINY = 2
    DEFAULT = 3
class TileHint(Enum):
    """Tiling hint with two options: SQUARE or DEFAULT."""

    SQUARE = 0
    DEFAULT = 1
@dataclasses.dataclass
class Reduction(Loops):
    """
    Loop nest that, for every point of ``ranges``, combines
    ``inner_fn(index, rindex)`` over ``reduction_ranges`` using
    ``reduction_type``.

    ``self.dtype`` is the destination dtype. ``src_dtype`` is passed alongside
    it to ``ops.reduction``.
    """

    reduction_ranges: List[Expr]
    reduction_type: str
    # self.dtype represents the dst dtype
    src_dtype: torch.dtype
    reduction_hint: ReductionHint

    def __str__(self):
        return Loops.__str__(
            self, names=("ranges", "reduction_ranges", "reduction_type")
        )

    __repr__ = __str__

    def get_reduction_size(self):
        return self.reduction_ranges

    def get_reduction_type(self):
        return self.reduction_type

    def store_reduction(self, output_name, indexer, vars, reduction_vars):
        return ops.reduction(
            output_name,
            self.dtype,
            self.src_dtype,
            self.reduction_type,
            indexer(vars),
            self.inner_fn(vars, reduction_vars),
        )

    def index_length(self):
        return len(self.ranges) + len(self.reduction_ranges)

    @cache_on_self
    def inner_fn_str(self):
        index = self._index(self.ranges)
        rindex = self._index(self.reduction_ranges, "r")
        return V.KernelFormatterHandler.ir_to_string(
            self.inner_fn,
            index,
            rindex,
        )

    def constant_to_device(self, device):
        """Move this to a given device. Requires that all reads are to constants."""
        loader = self.make_loader()
        loader = patch.object(ConstantBuffer, "override_device", device)(loader)
        return Reduction(
            device,
            self.dtype,
            loader,
            self.ranges,
            self.reduction_ranges,
            self.reduction_type,
            self.src_dtype,
            ReductionHint.DEFAULT,
        )

    @staticmethod
    def num_splits(
        device,
        dst_dtype,
        src_dtype,
        inner_fn,
        ranges,
        reduction_ranges,
        reduction_type,
        reduction_numel,
    ):
        """
        Choose a ReductionHint and a split count for a reduction.

        Returns ``(hint, split)``. ``split == 1`` means the reduction is not
        split. The heuristics use the device's multiprocessor count and the
        size hints of the output and reduction sizes.
        """
        num_sm = get_device_properties(device).multi_processor_count
        min_elements_per_thread = 32
        max_elements_per_thread = 512
        threads_per_sm = 2048
        min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm
        max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm

        def inner_reduction_splits(reduction_numel_hint, numel_hint):
            # do heuristics that's close to eager mode for split inner reduction
            # we leak reduction autotune configs here, and will need to refactor to avoid this later
            num_warps = 8
            num_threads = 32 * num_warps
            if numel_hint >= 2 * num_sm:  # don't split if there are enough outputs
                return 1
            if reduction_numel_hint <= 8192:
                return 1
            if reduction_numel_hint * numel_hint <= min_elements_per_device:
                split_size = min_elements_per_thread
            elif reduction_numel_hint * numel_hint < max_elements_per_device:
                target_blocks = num_sm * threads_per_sm // (2 * num_threads)
                blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint
                tmp_split_size = (
                    reduction_numel_hint + num_threads * blocks_per_output - 1
                ) // (num_threads * blocks_per_output)
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
                if abs(closest - tmp_split_size) < 30:
                    # prefer even splits, but never smaller than min_elements_per_thread
                    split_size = max(closest, min_elements_per_thread)
                else:
                    split_size = tmp_split_size
            else:
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
                if abs(closest - max_elements_per_thread) < 50:
                    # prefer even splits
                    split_size = closest
                else:
                    split_size = max_elements_per_thread
            return (reduction_numel_hint + split_size * num_threads - 1) // (
                split_size * num_threads
            )

        def outer_reduction_splits(reduction_numel_hint, numel_hint):
            # TODO the best heuristic currently has XBLOCK (corresponding to numel_hint) 128
            # extend to even smaller number of outputs
            num_warps = 8
            num_threads = num_warps * 32
            rvals_per_thread = 4  # comes from heuristics, refactor to not leak here
            xvals_per_block = 128
            xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block
            if reduction_numel_hint * numel_hint < min_elements_per_device:
                split_size = min_elements_per_thread
            elif reduction_numel_hint * numel_hint < max_elements_per_device:
                target_blocks = num_sm * threads_per_sm // num_threads
                target_blocks = (target_blocks + xblocks - 1) // xblocks
                tmp_split_size = (
                    reduction_numel_hint + rvals_per_thread * target_blocks - 1
                ) // (rvals_per_thread * target_blocks)
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
                if abs(tmp_split_size - closest) < 20:
                    split_size = max(closest, min_elements_per_thread)
                else:
                    split_size = tmp_split_size
            else:
                divisors = sympy.divisors(reduction_numel_hint)
                closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
                if abs(closest - max_elements_per_thread) < 50:
                    # prefer even splits
                    split_size = closest
                else:
                    split_size = max_elements_per_thread
            return (reduction_numel_hint + rvals_per_thread * split_size - 1) // (
                rvals_per_thread * split_size
            )

        reduction_numel_hint = V.graph.sizevars.size_hint(reduction_numel)
        numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
        # easy cases
        if numel_hint == 1:
            return ReductionHint.INNER, inner_reduction_splits(
                reduction_numel_hint, numel_hint
            )
        if (
            reduction_numel_hint <= min_elements_per_thread
            or numel_hint >= num_sm * 2 * 32
        ):
            return ReductionHint.DEFAULT, 1

        r = Reduction(
            device,
            dst_dtype,
            inner_fn,
            ranges,
            reduction_ranges,
            reduction_type,
            src_dtype,
            ReductionHint.DEFAULT,
        )

        def get_read_indices(reduction):
            # Returns the index expressions of reads that use every range
            # variable, plus whether deciding a producer buffer's layout
            # changed its stride (in which case the caller recomputes).
            cb = ComputedBuffer(
                name=None,
                layout=FlexibleLayout(
                    device=reduction.get_device(),
                    dtype=reduction.get_dtype(),
                    size=reduction.get_size(),
                ),
                data=reduction,
            )
            read_writes = cb.get_read_writes()
            # try finding the full size producer
            # TODO this will fail for something like ((1, N) * (N, 1)).sum()
            # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare
            range_vars = [
                var
                for var in read_writes.range_vars
                if isinstance(var, sympy.Expr) and not isinstance(var, sympy.Number)
            ]
            indices = []
            changed = False
            for md in sorted(read_writes.reads, key=lambda x: x.name):
                if all(var in md.index.free_symbols for var in range_vars):
                    indices.append(md.index)
                    if md.name in V.graph.name_to_buffer:
                        buf = V.graph.name_to_buffer[md.name]
                        original_stride = buf.layout.stride
                        buf.decide_layout()
                        if buf.layout.stride != original_stride:
                            changed = True
            return indices, changed

        indices, changed = get_read_indices(r)
        if changed:
            indices, _ = get_read_indices(r)

        if not indices:
            # TODO determine splits when all inputs are broadcast
            return ReductionHint.DEFAULT, 1

        _, (_, reduction_vars), _ = dependencies.index_vars_squeeze(
            r.get_size(), r.get_reduction_size()
        )
        num_outer = 0
        num_inner = 0
        for index in indices:
            strides = V.graph.sizevars.stride_hints(index, reduction_vars)
            # every reduction stride > 1 => the reduction is not along the
            # innermost (contiguous) dimension of this read
            if all(s > 1 for s in strides):
                num_outer += 1
            else:
                num_inner += 1
        if num_inner > num_outer:
            return ReductionHint.INNER, inner_reduction_splits(
                reduction_numel_hint, numel_hint
            )
        else:
            return ReductionHint.OUTER, outer_reduction_splits(
                reduction_numel_hint, numel_hint
            )

    @staticmethod
    def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type):
        """
        Convert inner_fn from a reduction to a pointwise function.

        The (static) reduction ranges are fully unrolled and combined with a
        left fold. For argmin/argmax the returned function yields the
        flattened position of the winning element. Raises
        NotImplementedError for unknown reduction types.
        """
        reduction_ranges = [
            V.graph.sizevars.guard_static_shape(x) for x in reduction_ranges
        ]

        if reduction_type == "sum":

            def combine_fn(a, b):
                return ops.add(a, b)

        elif reduction_type == "min":

            def combine_fn(a, b):
                return ops.minimum(a, b)

        elif reduction_type == "max":

            def combine_fn(a, b):
                return ops.maximum(a, b)

        elif reduction_type == "any":

            def combine_fn(a, b):
                return ops.logical_or(a, b)

        elif reduction_type == "argmin":

            def combine_fn(a, b):
                return ops.minimum(a[0], b[0]), ops.where(
                    ops.lt(b[0], a[0]), b[1], a[1]
                )

        elif reduction_type == "argmax":

            def combine_fn(a, b):
                return ops.maximum(a[0], b[0]), ops.where(
                    ops.gt(b[0], a[0]), b[1], a[1]
                )

        else:
            raise NotImplementedError(f"unknown reduction_type={reduction_type}")

        is_arg_reduction = reduction_type in ("argmin", "argmax")
        if is_arg_reduction:
            flatten_index = FixedLayout(
                None,
                None,
                reduction_ranges,
                FlexibleLayout.contiguous_strides(reduction_ranges),
            ).make_indexer()

            def value_fn(index, rindex):
                # pair each value with its flattened reduction position
                rindex = [sympy.expand(i) for i in rindex]
                return (
                    inner_fn(index, rindex),
                    ops.index_expr(flatten_index(rindex), torch.int64),
                )

        else:
            value_fn = inner_fn

        def fn(index):
            return functools.reduce(
                combine_fn,
                (
                    value_fn(index, rindex)
                    for rindex in itertools.product(
                        *[range(x) for x in reduction_ranges]
                    )
                ),
            )

        if is_arg_reduction:
            return lambda index: fn(index)[1]
        return fn

    @classmethod
    def create(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable,
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        reduction_hint: ReductionHint = ReductionHint.DEFAULT,
    ):
        """
        Build the IR for a reduction and wrap it in a TensorBox.

        The method picks the first case that applies:

        * reduction numel 0: a Pointwise filled with the initial value
          (sum/prod/any only);
        * reduction numel 1: a Pointwise reading the single element, or 0
          for argmin/argmax;
        * small static numel below ``config.unroll_reductions_threshold``
          with non-trivial ``ranges``: an unrolled Pointwise;
        * on triton without dynamic shapes, a multi-layer split reduction
          when ``num_splits`` asks for more than one split;
        * otherwise a single Reduction node.
        """
        reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))

        if reduction_numel == 0:
            # N.B. This is a hack to generate the literal of the given type
            # Ideally, we should be fixing `def constant` in triton.py
            # but it breaks due to hardcoded dtypes in other places
            def py_cnst(val):
                return (
                    bool(val)
                    if dst_dtype == torch.bool
                    else float(val)
                    if dst_dtype.is_floating_point
                    else int(val)
                )

            rtypes_to_inits = {
                "sum": py_cnst(0),
                "prod": py_cnst(1),
                "any": py_cnst(0),
                # "all" is desugared to `!any(!val)`
            }

            assert (
                reduction_type in rtypes_to_inits
            ), f"{reduction_type} not supported for zero-dimension tensors!"

            def const_fn(index):
                return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)

            # const_fn produces a dst_dtype constant, and every other path in
            # this method builds its node with dst_dtype. Report dst_dtype here
            # too: src_dtype was wrong whenever dst_dtype != src_dtype.
            return Pointwise.create(
                device=device,
                dtype=dst_dtype,
                inner_fn=const_fn,
                ranges=list(ranges),
            )

        if reduction_numel == 1:
            # this reduction is actually a pointwise op
            if reduction_type in ("argmin", "argmax"):

                def fn(index):
                    return ops.constant(0, dst_dtype)

            else:

                def fn(index):
                    reduction_index = [sympy.Integer(0) for _ in reduction_ranges]
                    return inner_fn(index, reduction_index)

            return Pointwise.create(device, dst_dtype, fn, ranges)

        if (
            isinstance(reduction_numel, sympy.Integer)
            and V.graph.sizevars.size_hint(reduction_numel)
            < config.unroll_reductions_threshold
            and sympy_product(ranges) != 1
        ):
            return Pointwise.create(
                device,
                dst_dtype,
                cls._unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type),
                ranges,
            )

        split_reduction = is_triton(device) and reduction_type not in {
            "argmax",
            "argmin",
        }
        if split_reduction and not dynamo_config.dynamic_shapes:
            # triton doesn't support reduce to single element well, so break it up
            hint, split = cls.num_splits(
                device,
                dst_dtype,
                src_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                reduction_type,
                reduction_numel,
            )
            # intermediate reduction in split can contain complex indexing,
            # and num_splits will fail to correctly set the hint
            # reuse the passed hint if available
            if reduction_hint == ReductionHint.DEFAULT:
                reduction_hint = hint
            if split > 1:
                # triton doesn't support reduce to single element well, so break it up
                return cls.create_multilayer(
                    device,
                    dst_dtype,
                    src_dtype,
                    inner_fn,
                    ranges,
                    reduction_ranges,
                    reduction_type,
                    split,
                    reduction_hint,
                )
        elif split_reduction and dynamo_config.dynamic_shapes:
            torch._logging.warning_once(
                log,
                "Could not do split reduction due to dynamic shapes; performance may be worse",
            )

        return TensorBox.create(
            Reduction(
                device,
                dst_dtype,
                inner_fn,
                ranges,
                reduction_ranges,
                reduction_type,
                src_dtype,
                reduction_hint,
            )
        )

    @staticmethod
    def default_value(reduction_type, dtype):
        """
        Identity element of ``reduction_type`` for ``dtype``.

        create_multilayer uses it as the fill value for masked-out elements.
        Raises KeyError for reduction types other than
        max/argmax/min/argmin/sum/any.
        """
        if reduction_type in {"max", "argmax"}:
            if is_float_dtype(dtype):
                return float("-inf")
            elif is_boolean_dtype(dtype):
                return 0
            else:
                return torch.iinfo(dtype).min
        if reduction_type in {"min", "argmin"}:
            if is_float_dtype(dtype):
                return float("inf")
            elif is_boolean_dtype(dtype):
                return 1
            else:
                return torch.iinfo(dtype).max

        return {
            "sum": 0,
            "any": 0,
        }[reduction_type]

    @classmethod
    def create_multilayer(
        cls,
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable,
        ranges: List[Expr],
        reduction_ranges: List[Expr],
        reduction_type: str,
        split: int,
        reduction_hint: ReductionHint,
    ):
        """
        Break a large reduction up into multiple smaller reductions
        recursively.

        The first layer reduces ``split`` blocks of ``ceil(numel / split)``
        elements each into a realized intermediate buffer. Elements past the
        end are masked with ``default_value`` when ``split`` does not divide
        the numel. The second layer reduces the ``split`` partial results.
        """
        reduction_numel = sympy_product(reduction_ranges)

        # TODO(jansel): convert this to dynamic shapes
        # TODO(jansel): realize the reduction so we can do dynamic indexing
        reduction_ranges = [
            sympy.Integer(V.graph.sizevars.guard_static_shape(s))
            for s in reduction_ranges
        ]
        reduction_numel = sympy.Integer(
            V.graph.sizevars.guard_static_shape(reduction_numel)
        )

        # the last block is partial (and must be masked) unless split divides numel
        need_mask = V.graph.sizevars.size_hint(reduction_numel) % split != 0

        split = sympy.Integer(split)
        block_size = FloorDiv(reduction_numel + (split - 1), split)

        reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])

        def wrapper_fn(index, reduction_index):
            (reduction_index,) = reduction_index
            *new_index, reduction_block = index
            indices = block_size * reduction_block + reduction_index

            def body():
                return inner_fn(new_index, reindex([indices]))

            if need_mask:
                mask = ops.lt(
                    ops.index_expr(indices, torch.int32),
                    ops.index_expr(reduction_numel, torch.int32),
                )
                return ops.masked(
                    mask, body, cls.default_value(reduction_type, dst_dtype)
                )
            else:
                return body()

        # triton will automatically compute reductions in fp32 if reducing over fp16/bf16
        # within the kernel. keep the intermediate in fp32 so as to keep the whole reduction
        # in fp32 and not reduce precision by breaking up the kernel into multiple layers
        intermediate_dtype = (
            dst_dtype
            if dst_dtype not in (torch.float16, torch.bfloat16)
            else torch.float
        )
        intermediate = Reduction.create(
            device,
            intermediate_dtype,
            src_dtype,
            wrapper_fn,
            [*ranges, split],
            [block_size],
            reduction_type,
            reduction_hint,
        )
        intermediate.realize()
        intermediate_loader = intermediate.make_loader()

        def intermediate_fn(index, reduction_index):
            return intermediate_loader([*index, *reduction_index])

        numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
        # few splits and few outputs: switch OUTER to the OUTER_TINY heuristic
        if reduction_hint == ReductionHint.OUTER and (
            (split <= 512 and numel_hint <= 512)
            or (split <= 1024 and numel_hint <= 256)
        ):
            reduction_hint = ReductionHint.OUTER_TINY

        return TensorBox.create(
            Reduction(
                device,
                dst_dtype,
                intermediate_fn,
                ranges,
                [split],
                reduction_type,
                src_dtype,
                reduction_hint,
            )
        )
def is_storage_and_layout(x):
    """
    True if ``as_storage_and_layout(x, freeze=False)`` succeeds.

    False if it raises NotImplementedError. Any other exception propagates.
    """
    try:
        as_storage_and_layout(x, freeze=False)
    except NotImplementedError:
        return False
    else:
        return True