/
utils.py
1652 lines (1316 loc) · 50.8 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import collections
import contextlib
import dataclasses
import enum
import functools
import inspect
import io
import itertools
import json
import logging
import math
import operator
import os
import platform
import shutil
import sys
import tempfile
import textwrap
import time
import unittest
from datetime import datetime
from io import StringIO
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
List,
NamedTuple,
Optional,
Protocol,
Set,
Tuple,
TypeVar,
Union,
ValuesView,
)
from unittest import mock
import sympy
from typing_extensions import Concatenate, ParamSpec
import torch
import torch._export
import torch.utils._pytree as pytree
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import detect_fake_mode
from torch.autograd import DeviceType
from torch.autograd.profiler_util import EventList
from torch.fx.passes.shape_prop import ShapeProp
from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing
from torch.utils._sympy.symbol import make_symbol, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
from . import config
from .runtime.runtime_utils import cache_dir, ceildiv as runtime_ceildiv
# Module-level logger for this file.
log = logging.getLogger(__name__)
# Generic element type variable (used e.g. by `unique`).
_T = TypeVar("_T")
# Type alias: dict keyed by sympy expressions with sympy-expression values.
# NOTE(review): presumably maps index variables to their ranges — confirm at use sites.
VarRanges = Dict[sympy.Expr, sympy.Expr]
# Alignment constant (16); NOTE(review): its unit and users are outside this view — confirm.
ALIGNMENT = 16
def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float:
    """
    Returns benchmark results by examining torch profiler events.
    This could be more accurate as it doesn't count CPU side overhead.
    However, this also requires manually excluding irrelevant event, e.g.
    vectorized_elementwise_kernel which is used to fill L2 cache,
    various CUDA events, etc, so could also be fragile.

    Args:
        fn: zero-argument callable to benchmark on the current CUDA device.
        warmup: approximate warm-up budget, in milliseconds.
        rep: approximate measurement budget, in milliseconds.

    Returns:
        Average profiled CUDA time per repetition of ``fn``, in milliseconds.

    Raises:
        RuntimeError: if the filtered CUDA events cannot be split evenly into
            one group per repetition.
    """
    fn()
    torch.cuda.synchronize()
    # int32 buffer of 256e6 bytes; zeroing it before each run flushes the L2 cache.
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")

    # Estimate the runtime of the function
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # compute number of warmup and repeat from the millisecond budgets
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))

    # Warm-up
    for _ in range(n_warmup):
        fn()

    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as p:
        # Benchmark
        for _ in range(n_repeat):
            # we clear the L2 cache before each run
            cache.zero_()
            # record time of `fn`
            fn()
        # Record clocks
        torch.cuda.synchronize()

    log.debug("raw events")
    log.debug(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))

    filtered_events = EventList(
        [
            event
            for event in p.events()
            if event.device_type == DeviceType.CUDA and event.name != "Context Sync"
        ]
    )
    if len(filtered_events) % n_repeat != 0:
        # Exceptions do not %-format their arguments, so build the message here.
        raise RuntimeError(
            "Failed to divide all profiling events into #repeat groups. "
            f"#CUDA events: {len(filtered_events)}, #repeats: {n_repeat}"
        )
    # Exact integer split is guaranteed by the check above.
    num_event_per_group = len(filtered_events) // n_repeat
    # Drop the first event of each group: each repetition launches cache.zero_()
    # before fn(), so (presumably a single kernel) it leads every group.
    actual_events = EventList(
        [
            event
            for i, event in enumerate(filtered_events)
            if i % num_event_per_group != 0
        ]
    )
    actual_events._build_tree()
    actual_events = actual_events.key_averages()

    log.debug("profiling time breakdown")
    log.debug(actual_events.table(row_limit=-1))

    # Profiler times are in microseconds: /1000 -> ms, then average per repetition.
    res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat
    log.debug("profiling results: %s ms", res)
    return res
@functools.lru_cache(None)
def has_torchvision_roi_align() -> bool:
    """
    Report whether torchvision's ``roi_align`` can be used.

    True only when ``torchvision.ops.roi_align`` imports and the
    ``torch.ops.torchvision`` namespace exposes ``roi_align``. The result is
    memoized for the life of the process.
    """
    try:
        from torchvision.ops import roi_align  # noqa: F401

        if roi_align is None:
            return False
        torchvision_namespace = getattr(torch.ops, "torchvision", None)
        return hasattr(torchvision_namespace, "roi_align")
    except ImportError:
        return False
def decode_device(device: Union[Optional[torch.device], str]) -> torch.device:
if device is None:
return torch.tensor(0.0).device # default device
if isinstance(device, str):
device = torch.device(device)
if device.type not in ("cpu", "meta") and device.index is None:
device_interface = get_interface_for_device(device.type)
return torch.device(device.type, index=device_interface.Worker.current_device())
return device
def sympy_product(it):
    """
    Multiply the items of ``it`` together, starting from ``sympy.Integer(1)``
    (so an empty iterable yields ``1``).
    """
    product = sympy.Integer(1)
    for factor in it:
        product = product * factor
    return product
def sympy_dot(seq1, seq2):
    """Return the expanded dot product ``sum(a * b)`` of two equal-length sequences."""
    assert len(seq1) == len(seq2)
    terms = [left * right for left, right in zip(seq1, seq2)]
    return sympy.expand(sum(terms))
def unique(it: Iterable[_T]) -> ValuesView[_T]:
    """
    De-duplicate ``it`` by object identity (``id``), not equality.

    Returns a dict values view with each distinct object once, ordered by first
    appearance. Unhashable items are fine since only their ids are hashed.
    """
    seen_by_id: Dict[int, _T] = {}
    for item in it:
        seen_by_id[id(item)] = item
    return seen_by_id.values()
def ceildiv(
    numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr]
) -> Union[int, sympy.Expr]:
    """
    Divide ``numer`` by ``denom`` rounding up.

    If either operand is a ``sympy.Expr`` a symbolic ``CeilDiv`` is returned;
    otherwise both must be ``int`` and the runtime ``ceildiv`` helper is used.
    """
    symbolic = isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr)
    if symbolic:
        return CeilDiv(numer, denom)
    # TODO: There is a bug in a call to this function, to repro:
    # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
    # --amp --only YituTechConvBert --dynamic-shapes
    assert isinstance(numer, int) and isinstance(
        denom, int
    ), f"{numer}: {type(numer)}, {denom}: {type(denom)}"
    return runtime_ceildiv(numer, denom)
def _type_of(key):
# Use the function here to get rid of dependencies on the Triton during the codegen.
# Refer to Triton implementation here:
# https://github.com/openai/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238
# `None` is nullptr. Implicitly convert to *i8.
if key is None:
return "*i8"
dtype_str = str(key).split(".")[-1]
tys = {
"bool": "i1",
"float8e4nv": "fp8e4nv",
"float8e5": "fp8e5",
"float8e4b15": "fp8e4b15",
"float8e4b15x4": "fp8e4b15x4",
"float8_e4m3fn": "fp8e4nv",
"float8_e5m2": "fp8e5",
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
"float64": "fp64",
"int8": "i8",
"int16": "i16",
"int32": "i32",
"int64": "i64",
"uint8": "u8",
"uint16": "u16",
"uint32": "u32",
"uint64": "u64",
}
# reinterpret can create triton type
for v in list(tys.values()):
tys[v] = v
return key if isinstance(key, str) else f"*{tys[dtype_str]}"
def convert_shape_to_inductor(
    lst: Iterable[Union[int, torch.SymInt]]
) -> List[sympy.Expr]:
    """
    Convert a sequence of sizes or strides into sympy expressions.

    Plain ints become ``sympy.Integer``; ``torch.SymInt`` entries are replaced by
    the sympy expression stored on their node (``i.node.expr``).
    """
    converted: List[sympy.Expr] = []
    for dim in lst:
        if isinstance(dim, torch.SymInt):
            converted.append(dim.node.expr)
        else:
            converted.append(sympy.Integer(dim))
    return converted
def convert_shape_to_symint(
    lst: Iterable[Union[int, sympy.Expr]]
) -> List[Union[int, torch.SymInt]]:
    """
    Convert Inductor shape entries back into ints / SymInts.

    ``int`` entries pass through, ``sympy.Integer`` entries become ``int``, and
    any other expression becomes a SymInt created by the current graph's
    ``shape_env`` (with no hint).
    """
    from .virtualized import V

    def _convert(dim):
        if isinstance(dim, int):
            return dim
        if isinstance(dim, sympy.Integer):
            return int(dim)
        return V.graph.sizevars.shape_env.create_symintnode(dim, hint=None)

    return [_convert(dim) for dim in lst]
def is_view(op: torch._ops.OpOverload):
    """
    Return True if any argument in the overload's schema carries alias
    information (``alias_info is not None``).
    """
    assert isinstance(op, torch._ops.OpOverload)
    for argument in op._schema.arguments:
        if argument.alias_info is not None:
            return True
    return False
def is_pointwise_use(use):
    """
    Return True if FX node ``use`` consumes its input pointwise.

    Only ``call_function`` nodes targeting an ``OpOverload`` or
    ``operator.getitem`` qualify. ``getitem`` and view ops are looked through:
    they count only if every one of their users is itself a pointwise use.
    Other overloads must carry ``torch.Tag.pointwise``.
    """
    if use.op != "call_function":
        return False
    target = use.target
    is_getitem = target is operator.getitem
    if not (isinstance(target, torch._ops.OpOverload) or is_getitem):
        return False
    if is_getitem or is_view(target):
        return all(is_pointwise_use(user) for user in use.users)
    return torch.Tag.pointwise in target.tags
def gen_gm_and_inputs(target, args, kwargs):
    """
    Build a single-op ``torch.fx.GraphModule`` that calls ``target``.

    Tensor positional arguments become placeholders named ``arg{position}``;
    non-tensor positional arguments and ``kwargs`` are baked into the call.
    When the schema has exactly one return of type ``Tensor``, the output is
    wrapped in a 1-tuple.

    Returns:
        ``(gm, tensor_args)``: the module and the tensors from ``args`` in order.
    """
    graph = torch.fx.Graph()
    call_args = []
    tensor_inputs = []
    for position, value in enumerate(args):
        if not isinstance(value, torch.Tensor):
            call_args.append(value)
            continue
        call_args.append(graph.placeholder(f"arg{position}"))
        tensor_inputs.append(value)
    # kwargs are embedded as constants, so they must not carry tensors.
    assert all(not isinstance(x, torch.Tensor) for x in kwargs.values())
    call_node = graph.call_function(target, tuple(call_args), kwargs)
    returns = target._schema.returns
    single_tensor_return = len(returns) == 1 and str(returns[0].type) == "Tensor"
    graph.output((call_node,) if single_tensor_return else call_node)
    return torch.fx.GraphModule({}, graph), tensor_inputs
def synchronize(device: str = "cuda"):
if device == "cpu":
return
device_interface = get_interface_for_device(device)
if device_interface.is_available():
device_interface.synchronize()
def timed(
    model: Callable[..., Any], example_inputs, times: int = 1, device: str = "cuda"
) -> float:
    """
    Call ``model(*example_inputs)`` ``times`` times and return elapsed seconds
    measured with ``time.perf_counter``.

    ``device`` is synchronized before starting and after every call. The RNG is
    seeded with 1337 first so repeated measurements see the same random state.
    NOTE(review): with ``times < 1`` the model never runs and the final check
    raises ``UnboundLocalError``.
    """
    synchronize(device)
    torch.manual_seed(1337)
    started = time.perf_counter()
    for _ in range(times):
        result = model(*example_inputs)
        synchronize(device)
    finished = time.perf_counter()
    # Referencing the last result keeps it alive so it is GC'd only after timing.
    assert result is not None  # type: ignore[possibly-undefined]
    return finished - started
def print_performance(
    fn, args=(), times=10, repeat=10, baseline=1.0, device: str = "cuda"
):
    """
    Benchmark ``fn(*args)`` with ``timed`` and print the per-call median.

    ``timed`` runs ``repeat`` times, each covering ``times`` calls; the median of
    those totals divided by ``times`` is the per-call time (seconds, 0-d tensor).
    That value divided by ``baseline`` is printed with 6 decimals; the
    undivided per-call time is returned.
    """
    samples = [timed(fn, args, times, device) for _ in range(repeat)]
    per_call = torch.median(torch.tensor(samples)) / times
    print(f"{per_call/baseline:.6f}")
    return per_call
def precompute_method(obj: Any, method: str):
    """
    Call ``obj.<method>()`` once now, then shadow it on ``obj`` with a
    zero-argument callable returning that stored value.
    """
    cached_value = getattr(obj, method)()
    setattr(obj, method, lambda: cached_value)
def precompute_methods(obj: Any, methods: List[str]):
    """Apply ``precompute_method`` to each name in ``methods``, in order."""
    for method_name in methods:
        precompute_method(obj, method_name)
def cmp(a, b) -> int:
    """Three-way comparison: 1 if ``a > b``, -1 if ``a < b``, otherwise 0."""
    greater = int(a > b)
    less = int(a < b)
    return greater - less
def pad_listlike(x, size):
    """
    Repeat the sole element of a length-1 list/tuple ``size`` times, keeping the
    container type; sequences of any other length are returned unchanged.
    """
    if len(x) != 1:
        return x
    return type(x)([x[0]]) * size
# Sorting by name gives a deterministic iteration order for sets.
def tuple_sorted(x):
    """
    Return the elements of ``x`` as a ``list`` (despite the name) sorted by name.

    Strings sort by themselves; other elements sort by ``elem.get_name()``.
    Those are expected to be ``scheduler.BaseSchedulerNode`` instances, which
    cannot be isinstance-checked here because of a circular dependency.
    """
    if len(x) == 0:
        return []

    def _name_of(elem):
        return elem if isinstance(elem, str) else elem.get_name()

    return sorted(x, key=_name_of)
# ParamSpec for a method wrapped by `cache_on_self`: its parameters after
# `self` (see that function's `Concatenate[Any, P]` signature).
P = ParamSpec("P")
# Return type of the wrapped method.
RV = TypeVar("RV", covariant=True)


class CachedMethod(Generic[P, RV], Protocol):
    """
    Typing-only protocol for what ``cache_on_self`` returns: a callable with the
    wrapped method's signature that also exposes ``clear_cache``.
    """

    # A staticmethod taking the instance explicitly: `cache_on_self` attaches
    # `clear_cache` to the wrapper as a plain function attribute, so it is
    # never bound to an instance.
    @staticmethod
    def clear_cache(self) -> None:
        ...

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV:
        ...
# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]:
    """
    Memoize a zero-argument method per instance.

    The first call stores ``fn(self)`` on the instance as attribute
    ``__<fn name>_cache``; later calls return it. The wrapper exposes
    ``clear_cache(instance)`` to drop the stored value.
    """
    attr = f"__{fn.__name__}_cache"

    @functools.wraps(fn)
    def wrapper(self):
        try:
            return getattr(self, attr)
        except AttributeError:
            pass
        setattr(self, attr, fn(self))
        return getattr(self, attr)

    def clear_cache(self):
        if hasattr(self, attr):
            delattr(self, attr)

    wrapper.clear_cache = clear_cache  # type: ignore[attr-defined]
    return wrapper  # type: ignore[return-value]
def aggregate_origins(node_schedule):
    """
    Collect the FX origin nodes behind a node schedule.

    - list: union of ``entry.node.origins`` over entries with a truthy ``.node``;
    - ``ir.ExternKernel``: its own ``origins``;
    - anything else: an empty set.
    """
    from . import ir

    if not isinstance(node_schedule, list):
        if isinstance(node_schedule, ir.ExternKernel):
            return node_schedule.origins
        return set()
    origin_sets = [
        entry.node.origins
        for entry in node_schedule
        if hasattr(entry, "node") and entry.node
    ]
    return functools.reduce(operator.or_, origin_sets, set())
def get_fused_kernel_name(node_schedule, descriptive_names):
    """
    Build a ``fused_<a>_<b>...`` kernel name from the origins of ``node_schedule``.

    Args:
        node_schedule: passed to ``aggregate_origins`` to collect FX origin nodes.
        descriptive_names: naming scheme:
            "original_aten": overload-packet names from ``meta["original_aten"]``
                (top-level aten op, pre-decomposition), de-duplicated and sorted;
            "torch": last ``meta["source_fn_stack"]`` entry's name (top-level
                torch op, post-dynamo graph), de-duplicated and sorted;
            "inductor_node": names of the ``call_function`` origin nodes, sorted.

    Returns:
        ``"fused"`` followed by the chosen names, joined with ``"_"``.

    Raises:
        NotImplementedError: for any other ``descriptive_names`` value.
    """
    all_origins = aggregate_origins(node_schedule)
    if descriptive_names == "original_aten":
        # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
        sources = [
            origin.meta["original_aten"]._overloadpacket.__name__
            for origin in all_origins
            if origin.op == "call_function"
            and "original_aten" in origin.meta
            and origin.meta["original_aten"] is not None
        ]
        sources = sorted(set(sources))
    elif descriptive_names == "torch":
        # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
        sources = []
        for origin in all_origins:
            if origin.op == "call_function" and "source_fn_stack" in origin.meta:
                source_fn = origin.meta["source_fn_stack"][-1]
                if isinstance(source_fn[1], str):
                    sources.append(source_fn[1])
                else:
                    sources.append(source_fn[1].__name__)
        sources = sorted(set(sources))
    elif descriptive_names == "inductor_node":
        # `all_origins` is set-like (see aggregate_origins), so its iteration
        # order is not stable across runs; sort like the branches above so the
        # generated kernel name is deterministic.
        sources = sorted(
            origin.name for origin in all_origins if origin.op == "call_function"
        )
    else:
        raise NotImplementedError(
            f"unsupported descriptive_names: {descriptive_names!r}"
        )
    return "_".join(["fused"] + sources)
def get_kernel_metadata(node_schedule, wrapper):
    """
    Produce comment text describing where a kernel came from.

    Returns ``(summary, details)``:
      - ``summary``: one line listing, sorted, the ``from_node`` source names and
        the ``original_aten`` overload packets of the ``call_function`` origins;
      - ``details``: one line per source name, ``<source> => <inductor nodes>``.
    Every line starts with ``wrapper.comment``.
    """
    all_origins = aggregate_origins(node_schedule)
    nodes_by_source = collections.defaultdict(list)
    nodes_by_aten_op = collections.defaultdict(list)
    for origin in all_origins:
        if origin.op != "call_function":
            continue
        meta = origin.meta
        if meta.get("original_aten") is not None:
            nodes_by_aten_op[str(meta["original_aten"]._overloadpacket)].append(
                origin.name
            )
        if "from_node" in meta:
            nodes_by_source[meta["from_node"][0][0]].append(origin.name)
    comment = wrapper.comment
    summary = (
        f"{comment} Source Nodes: [{', '.join(sorted(nodes_by_source.keys()))}], "
        f"Original ATen: [{', '.join(sorted(nodes_by_aten_op.keys()))}]"
    )
    # trace back to original node here
    details = "\n".join(
        f"{comment} {source} => {', '.join(sorted(names))}"
        for source, names in sorted(nodes_by_source.items())
    )
    return summary, details
def dominated_nodes(
    initial_queue: Iterable[torch.fx.Node], skip_filter=None
) -> Set[torch.fx.Node]:
    """
    Returns the set of nodes whose values depend on those within initial_queue.

    Follows ``node.users`` transitively from the starting nodes, which are part
    of the result. A user for which ``skip_filter(user)`` is truthy is neither
    added nor traversed through.
    """
    pending = list(initial_queue)
    reached = set(pending)
    while pending:
        current = pending.pop()
        for user in current.users:
            if skip_filter and skip_filter(user):
                continue
            if user in reached:
                continue
            reached.add(user)
            pending.append(user)
    return reached
def gather_origins(args, kwargs):
    """
    Union the ``origins`` of every positional or keyword argument that is an
    unrealized pointwise IR node, looking through TensorBox/StorageBox wrappers.
    """
    from . import ir

    def is_unrealized_node(n):
        while isinstance(n, (ir.TensorBox, ir.StorageBox)):
            n = n.data
        return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise)

    kwarg_origins = [v.origins for v in kwargs.values() if is_unrealized_node(v)]
    arg_origins = [a.origins for a in args if is_unrealized_node(a)]
    return set(itertools.chain(*arg_origins, *kwarg_origins))
def sympy_str(expr: sympy.Expr) -> str:
    """
    Normal sympy str is very slow; this is a lot faster. The results are
    somewhat worse, as it doesn't do as much simplification, so don't use this
    for final codegen.
    """
    if isinstance(expr, sympy.Symbol):
        return expr.name
    joiner = None
    if isinstance(expr, sympy.Add):
        joiner = " + "
    elif isinstance(expr, sympy.Mul):
        joiner = " * "
    if joiner is not None:
        return joiner.join(sympy_str(arg) for arg in expr.args)
    if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv)):
        inner = ", ".join(sympy_str(arg) for arg in expr.args)
        return f"{expr.func.__name__}({inner})"
    return str(expr)
def get_bounds_index_expr(index):
    """
    Return value-range bounds for the sympy expression ``index``.

    Bounds are computed with ``bound_sympy`` only when
    ``config.compute_all_bounds`` is set and the interpreter has a current FX
    node whose target is not "index_expr"; otherwise ``ValueRanges.unknown()``.
    """
    from .virtualized import V

    if not config.compute_all_bounds:
        return ValueRanges.unknown()
    # "index_expr" nodes are skipped here.
    # NOTE(review): presumably their bounds come from elsewhere — confirm.
    fx_node = getattr(V.interpreter, "current_node", None)
    if fx_node and fx_node.target != "index_expr":
        return bound_sympy(index)
    return ValueRanges.unknown()
def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
"""
Used to generate an integer-nonnegative symbol.
"""
# This should never be used for creating shape/stride symbols, as those
# should all be allocated before Inductor.
assert prefix != SymT.SIZE
# NOTE: shape symbols are positive (> 0), but index variables are only
# non-negative (>= 0).
return make_symbol(prefix, idx, integer=True, nonnegative=True)
def sympy_index_symbol(name: str) -> sympy.Symbol:
    """
    Create a sympy symbol called ``name``, declared integer and nonnegative,
    for use as an index variable.
    """
    # Names starting with "s" are shape/stride symbols, which are all allocated
    # before Inductor and must never be created here.
    assert name[0] != "s"
    # Shape symbols are positive (> 0); index variables are only non-negative (>= 0).
    return sympy.Symbol(name, integer=True, nonnegative=True)
def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr:
    """
    Substitute ``replacements`` into ``sympy.sympify(expr)`` with ``xreplace``.

    A replacement given as a ``str`` becomes a ``sympy.Symbol`` of that name
    carrying the ``is_integer`` / ``is_nonnegative`` assumptions of the
    expression it replaces; any other replacement is used as-is.
    """

    def _as_replacement(target, value):
        assert isinstance(target, sympy.Expr)
        if not isinstance(value, str):
            return value
        return sympy.Symbol(
            value,
            integer=target.is_integer,  # type: ignore[attr-defined]
            nonnegative=target.is_nonnegative,  # type: ignore[attr-defined]
        )

    base = sympy.sympify(expr)
    mapping = {
        target: _as_replacement(target, value)
        for target, value in replacements.items()
    }
    # xreplace is faster than subs, but is way more picky
    return base.xreplace(mapping)
def is_symbolic(a: Any) -> bool:
    """
    True if ``a`` is a ``torch.SymInt``, or a tensor with any symbolic size or
    stride.
    """
    if isinstance(a, torch.SymInt):
        return True
    if not isinstance(a, torch.Tensor):
        return False
    sizes_and_strides = itertools.chain(a.size(), a.stride())
    return any(is_symbolic(d) for d in sizes_and_strides)
def any_is_symbolic(*args: Any) -> bool:
    """True if at least one of ``args`` is symbolic according to ``is_symbolic``."""
    for candidate in args:
        if is_symbolic(candidate):
            return True
    return False
def get_first_incompatible_cudagraph_node(gm):
    """
    Return the first node of ``gm.graph`` (in graph order) that rules out CUDA
    graphs, or ``None`` if there is none.

    A node is flagged when ``str(node.target)`` is in the forbidden set below
    (extended with scatter/index_put variants when
    ``torch.are_deterministic_algorithms_enabled()`` is true), or when its
    ``meta["val"]`` is present and has free unbacked symbols.
    """
    from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

    # Entries are matched against str(node.target).
    # NOTE(review): "aten._local_scalar_dense", "aten._assert_scalar" and (below)
    # "aten.scatter_add_" lack the ".<overload>" suffix the other entries carry,
    # so they only match targets whose str() has none — confirm whether the
    # ".default" spellings were intended.
    forbidden_set = {
        "aten._fused_moving_avg_obs_fq_helper.default",
        "aten._fused_moving_avg_obs_fq_helper_functional.default",
        "aten.multinomial.default",
        "fbgemm.dense_to_jagged.default",
        "fbgemm.jagged_to_padded_dense.default",
        "run_and_save_rng_state",
        "run_with_rng_state",
        "aten._local_scalar_dense",
        # Technically, it's not necessary to ban this, because an
        # assert_scalar with constant arguments can be validly run
        # with CUDA graphs, but the operator is also pointless with
        # constant arguments, so might as well ban
        "aten._assert_scalar",
    }
    if torch.are_deterministic_algorithms_enabled():
        # Extra ops banned only while deterministic algorithms are requested.
        forbidden_set.update(
            {
                "aten._unsafe_index_put.default",
                "aten.index_put.default",
                "aten.index_put_.default",
                "aten.scatter.src",
                "aten.scatter.reduce",
                "aten.scatter.value_reduce",
                "aten.scatter_add_",
                "aten.scatter_add.default",
                "aten.scatter_reduce.two",
                "aten.scatter_reduce_.two",
                "aten.scatter_reduce.two_out",
            }
        )
    for node in gm.graph.nodes:
        if str(node.target) in forbidden_set:
            return node
        # Nodes whose recorded value has free unbacked symbols are also rejected.
        if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val):
            return node
    return None
def has_incompatible_cudagraph_ops(gm):
    """True if ``get_first_incompatible_cudagraph_node`` finds any node in ``gm``."""
    first_bad_node = get_first_incompatible_cudagraph_node(gm)
    return first_bad_node is not None
def output_node(gm: torch.fx.GraphModule):
    """Return the graph's last node, asserting that it is the ``output`` node."""
    nodes_backwards = iter(reversed(gm.graph.nodes))
    candidate = next(nodes_backwards)
    assert candidate.op == "output"
    return candidate
# Objects registered via `clear_on_fresh_inductor_cache`; each has a callable
# `cache_clear()` that `clear_inductor_caches` invokes.
_registered_caches: List[Any] = []


def clear_on_fresh_inductor_cache(obj: Any):
    """
    Decorator: register ``obj`` so ``clear_inductor_caches()`` (and hence
    ``fresh_inductor_cache()``) calls ``obj.cache_clear()``. Returns ``obj``.

    Raises:
        AttributeError: if ``obj`` has no callable ``cache_clear``.
    """
    if not callable(getattr(obj, "cache_clear", None)):
        raise AttributeError(f"{obj} does not have a cache_clear method")
    _registered_caches.append(obj)
    return obj


def clear_inductor_caches():
    """Call ``cache_clear()`` on every registered cache, in registration order."""
    for registered in _registered_caches:
        registered.cache_clear()
@contextlib.contextmanager
def fresh_inductor_cache(cache_entries=None):
    """
    Contextmanager that provides a clean tmp cachedir for inductor.
    Optionally, pass a dict as 'cache_entries' to receive the filenames and
    sizes generated with this cache instance.

    On entry: clears every cache registered with
    ``clear_on_fresh_inductor_cache``, creates a temporary directory, and uses
    ``mock.patch.dict`` on ``os.environ`` to point ``TORCHINDUCTOR_CACHE_DIR``
    at it and ``TRITON_CACHE_DIR`` at its ``triton`` subdirectory. The
    environment is restored on exit.

    On normal exit: if ``cache_entries`` is a dict (it must be empty), it is
    filled with ``{filename: size_in_bytes}`` for each entry of the triton cache
    directory whose name does not contain ".lock". The temporary directory is
    then deleted.

    On error: a warning with the directory path is logged, the directory is
    kept for inspection, and the exception is re-raised.
    """
    clear_inductor_caches()

    inductor_cache_dir = tempfile.mkdtemp()
    try:
        with mock.patch.dict(
            os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir}
        ):
            triton_cache_dir = os.path.join(inductor_cache_dir, "triton")
            with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}):
                yield
                if isinstance(cache_entries, dict):
                    # NOTE(review): `assert` is stripped under -O, so this check is
                    # best-effort only.
                    assert len(cache_entries) == 0, "expected empty cache_entries dict"
                    if os.path.exists(triton_cache_dir):
                        files = os.listdir(triton_cache_dir)
                        cache_entries.update(
                            {
                                f: os.path.getsize(os.path.join(triton_cache_dir, f))
                                for f in files
                                if ".lock" not in f
                            }
                        )
        shutil.rmtree(inductor_cache_dir)
    except Exception:
        log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
        raise
def argsort(seq) -> List[int]:
    """
    Return the indices of ``seq`` ordered by ascending element value.

    Ties: indices of equal elements appear in *descending* index order, because
    a stable descending sort is reversed (e.g. ``argsort([1, 1]) == [1, 0]``).
    """
    lookup = seq.__getitem__
    order = sorted(range(len(seq)), key=lookup, reverse=True)
    order.reverse()
    return order
@functools.lru_cache(8)
def get_dtype_size(dtype):
    """Return the size in bytes of one element of ``dtype`` (memoized, 8 entries)."""
    scalar = torch.empty((), dtype=dtype)
    return scalar.element_size()
class LineContext(NamedTuple):
    """
    Marker placed in an ``IndentedBuffer`` instead of a text line. It carries an
    arbitrary ``context`` payload and renders as no text.
    """

    # Arbitrary payload associated with the position of this marker.
    context: Any
class IndentedBuffer:
    """
    Line-oriented text buffer for generated code, with a current indent level.

    Entries of ``self._lines`` are one of:
      - ``str``: a text line; ``writeline`` stores it already prefixed with the
        current indentation (whitespace-only lines are stored as "");
      - ``DeferredLineBase``: called when the buffer is rendered; a ``None``
        result drops the line;
      - ``LineContext``: emits no text; ``getvaluewithlinemap`` reports the
        output line number at which it occurs.
    """

    # Spaces per indentation level.
    tabwidth = 4

    def __init__(self, initial_indent=0):
        # Raw entries (see class docstring) in insertion order.
        self._lines = []
        # Current indentation level, in units of `tabwidth` spaces.
        self._indent = initial_indent

    def getvaluewithlinemap(self) -> tuple[str, list[tuple[int, Any]]]:
        """
        Render the buffer and locate each ``LineContext``.

        Returns:
            ``(text, linemap)``: ``text`` is every resolved line followed by a
            newline; ``linemap`` holds ``(line_number, context)`` pairs, where
            ``line_number`` is the 1-based line of ``text`` that the next
            written line occupies at the point the ``LineContext`` appeared.
        """
        buf = StringIO()
        # 1-based number of the next output line; advanced by every newline
        # written, including ones embedded inside a single entry.
        p = 1
        linemap = []
        for line in self._lines:
            if isinstance(line, DeferredLineBase):
                line = line()
                if line is None:
                    # The deferred line was "unwritten"; emit nothing.
                    continue
            elif isinstance(line, LineContext):
                linemap.append((p, line.context))
                continue
            assert isinstance(line, str)
            buf.write(line)
            buf.write("\n")
            p += 1 + line.count("\n")
        return buf.getvalue(), linemap

    def getvalue(self) -> str:
        """Render the buffer to a string (``getvaluewithlinemap`` without the map)."""
        v, _ = self.getvaluewithlinemap()
        return v

    def getrawvalue(self) -> str:
        """
        Render the buffer, treating a trailing backslash as a line continuation:
        the backslash is dropped and no newline follows that line.
        ``LineContext`` entries are skipped.
        """
        buf = StringIO()
        for line in self._lines:
            if isinstance(line, DeferredLineBase):
                line = line()
                if line is None:
                    continue
            elif isinstance(line, LineContext):
                continue
            assert isinstance(line, str)
            # backslash implies line continuation
            if line.endswith("\\"):
                buf.write(line[:-1])
            else:
                buf.write(line)
                buf.write("\n")
        return buf.getvalue()

    def clear(self):
        """Remove all entries; the indentation level is left unchanged."""
        self._lines.clear()

    def __bool__(self):
        """True if any entry has been written."""
        return bool(self._lines)

    def prefix(self):
        """Whitespace string for the current indentation level."""
        return " " * (self._indent * self.tabwidth)

    def newline(self):
        """Append an empty line (stored as "")."""
        self.writeline("\n")

    def writeline(self, line):
        """
        Append one entry. A ``LineContext`` is stored as-is, a
        ``DeferredLineBase`` is re-created with the current prefix, a non-blank
        string is prefixed with the current indentation, and a blank string is
        stored as "".
        """
        if isinstance(line, LineContext):
            self._lines.append(line)
        elif isinstance(line, DeferredLineBase):
            self._lines.append(line.with_prefix(self.prefix()))
        elif line.strip():
            self._lines.append(f"{self.prefix()}{line}")
        else:
            self._lines.append("")

    def writelines(self, lines):
        """``writeline`` each element of ``lines`` in order."""
        for line in lines:
            self.writeline(line)

    def indent(self, offset=1):
        """
        Return a context manager that raises the indentation level by
        ``offset`` for the ``with`` body and restores it afterwards, even on error.
        """

        @contextlib.contextmanager
        def ctx():
            self._indent += offset
            try:
                yield
            finally:
                self._indent -= offset

        return ctx()

    def do_indent(self, offset=1):
        """Raise the indentation level by ``offset`` (no automatic restore)."""
        self._indent += offset

    def do_unindent(self, offset=1):
        """Lower the indentation level by ``offset``."""
        self._indent -= offset

    def splice(self, other_code, strip=False):
        """
        Append ``other_code`` re-indented to this buffer's current level.

        - ``IndentedBuffer``: the smallest leading-whitespace width among its
          non-empty, non-``LineContext`` entries is cut from each entry before
          it is written; ``LineContext`` entries are copied as-is.
        - ``str``: dedented with ``textwrap.dedent``; when ``strip`` is true
          leading whitespace is removed too. If the text is then empty nothing
          is written; otherwise trailing whitespace is removed and each line is
          written.
        """
        if isinstance(other_code, IndentedBuffer):
            dedent = float("inf")
            for line in other_code._lines:
                if not isinstance(line, LineContext) and line:
                    dedent = min(dedent, len(line) - len(line.lstrip()))
            if math.isinf(dedent):
                # No non-empty text entries; nothing to strip.
                dedent = 0
            for line in other_code._lines:
                if isinstance(line, LineContext):
                    self._lines.append(line)
                else:
                    # Explicit base-class call; presumably so a subclass override
                    # of `writeline` is bypassed here — confirm.
                    IndentedBuffer.writeline(self, line[int(dedent) :])
        else:
            other_code = textwrap.dedent(other_code)
            if strip:
                other_code = other_code.lstrip()
            if not other_code:
                return
            other_code = other_code.rstrip()
            for line in other_code.split("\n"):
                self.writeline(line)

    def map(self, func: Callable[[Any], Any]) -> IndentedBuffer:
        """
        Return a new buffer at the same indentation whose entries are ``func``
        applied to every raw entry (strings, deferred lines and markers alike).
        """
        res = IndentedBuffer(initial_indent=self._indent)
        res._lines = [func(line) for line in self._lines]
        return res

    def __repr__(self):
        return f"{type(self)}({self.getvalue()})"

    def __add__(self, other):
        """
        Concatenate two buffers that share the same indentation level.

        NOTE(review): ``writelines`` re-applies ``res``'s prefix to string
        entries that already carry indentation, so with ``_indent > 0`` those
        lines are indented a second time — confirm this is intended.
        """
        assert self._indent == other._indent
        res = IndentedBuffer(initial_indent=self._indent)
        res.writelines(self._lines)
        res.writelines(other._lines)
        return res
@contextlib.contextmanager
def restore_stdout_stderr(initial_stdout, initial_stderr):
    """
    Context manager that, on exit (normal or via an exception), reassigns
    ``sys.stdout`` and ``sys.stderr`` to the given streams.
    """
    try:
        yield
    finally:
        sys.stdout, sys.stderr = initial_stdout, initial_stderr
class DeferredLineBase:
    """
    A line whose final text is decided later; it is "unwritten" when it
    resolves to None.

    Subclasses implement ``__call__`` (resolve to the text or None) and
    ``_new_line`` (build a sibling with the same condition but new text).
    Whitespace-only input text is normalized to "".
    """

    def __init__(self, line):
        self.line = line if line.strip() else ""

    def __call__(self) -> Optional[str]:
        """Resolve to ``self.line``, or None if the line has been 'unwritten'."""
        raise NotImplementedError

    def _new_line(self, line: str) -> DeferredLineBase:
        """Build a new deferred line with the same condition as ``self``."""
        raise NotImplementedError

    # The helpers below apply str-like operations to the text and re-wrap the
    # result through `_new_line`.
    def with_prefix(self, prefix):
        return self._new_line(f"{prefix}{self.line}")

    def lstrip(self):
        return self._new_line(self.line.lstrip())

    def __getitem__(self, index):
        return self._new_line(self.line[index])

    def __bool__(self):
        return bool(self.line)

    def __len__(self):
        return len(self.line)
@functools.lru_cache(None)
def is_big_gpu(index) -> bool:
    """
    Return whether CUDA device ``index`` has at least 68 SMs (the "3080"
    threshold), logging a warning when it does not. Memoized per ``index``.
    """
    min_sms = 68  # 3080
    avail_sms = torch.cuda.get_device_properties(index).multi_processor_count
    if avail_sms >= min_sms:
        return True
    log.warning(
        "Not enough SMs to use max_autotune_gemm mode",
        extra={"min_sms": min_sms, "avail_sms": avail_sms},
    )
    return False
def use_max_autotune() -> bool:
    """
    Truthy when any of ``config.max_autotune``, ``config.max_autotune_gemm`` or
    ``config.search_autotune_cache`` is set (first truthy value is returned).
    """
    return (
        config.max_autotune
        or config.max_autotune_gemm
        or config.search_autotune_cache
    )
def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool:
    """
    Common gate for CUDA GEMM templates. All of the following must hold, checked
    in order with short-circuiting: max-autotune is enabled, the layout is on a
    CUDA device, its dtype is in ``allowed_layout_dtypes``, and the device
    (index defaulting to 0) passes ``is_big_gpu``.
    """
    return (
        use_max_autotune()
        and layout.device.type == "cuda"
        and layout.dtype in allowed_layout_dtypes
        and is_big_gpu(layout.device.index or 0)
    )
def _use_autotune_backend(backend: str) -> bool:
    """
    True if ``backend`` appears, case-insensitively, in the comma-separated
    ``config.max_autotune_gemm_backends`` list.
    """
    wanted = backend.upper()
    enabled = [
        name.strip() for name in config.max_autotune_gemm_backends.upper().split(",")
    ]
    return wanted in enabled
def use_triton_template(layout, *, enable_int32=False):
    """
    Whether the Triton GEMM template may be used for ``layout``.

    fp16/bf16/fp32 layouts are allowed (plus int32 when ``enable_int32``), and
    "TRITON" must be among the configured autotune backends.
    """
    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    if enable_int32:
        layout_dtypes.append(torch.int32)
    return _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend(
        "TRITON"
    )
def use_cutlass_template(layout, m, n, k):
    """
    Whether the CUTLASS GEMM template may be used for an ``m x n x k`` GEMM.

    Requires all of:
    - a positive size hint for ``m * n * k`` that is at least
      ``config.cuda.cutlass_backend_min_gemm_size``;
    - not running on ROCm;
    - the generic CUDA-template checks for fp16/bf16/fp32/int32 layouts;
    - "CUTLASS" among the configured autotune backends;
    - a successful ``try_import_cutlass()``; a warning is logged if it fails.
    """
    from .virtualized import V

    gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1)
    if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size:
        return False
    from .codegen.cuda.cutlass_utils import try_import_cutlass

    # Do not use cutlass template on ROCm
    if torch.version.hip:
        return False

    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
    eligible = _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend(
        "CUTLASS"
    )
    if not eligible:
        return eligible
    if not try_import_cutlass():
        log.warning(
            "Failed to import CUTLASS lib. Please check whether "
            "_inductor.config.cuda.cutlass_dir is set correctly. "
            "Skipping CUTLASS backend for now."
        )
        return False
    return eligible
def use_aten_gemm_kernels():
    """
    ATen GEMM kernels are used unless max-autotune is on and "ATEN" is missing
    from the configured autotune backends.
    """
    if not use_max_autotune():
        return True
    return _use_autotune_backend("ATEN")
class DebugDirManager:
counter = itertools.count(0)
prev_debug_name: str
def __init__(self):
self.id = next(DebugDirManager.counter)
def __enter__(self):
self.prev_debug_name = torch._dynamo.config.debug_dir_root