/
utils.py
2652 lines (2145 loc) · 83.3 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import atexit
import collections
import contextlib
import copy
import dataclasses
import datetime
import dis
import enum
import functools
import gc
import inspect
import itertools
import linecache
import logging
import math
import operator
import os
import re
import sys
import textwrap
import threading
import time
import types
import typing
import weakref
from contextlib import contextmanager
from functools import lru_cache, wraps
from types import MethodWrapperType
from typing import (
Any,
Callable,
cast,
ClassVar,
Counter,
DefaultDict,
Deque,
Dict,
Iterator,
KeysView,
List,
Optional,
Set,
Tuple,
Type,
Union,
ValuesView,
)
from ..utils.hooks import RemovableHandle
try:
import numpy as np
except ModuleNotFoundError:
np = None # type: ignore[assignment]
try:
    # Torch-side and package-local imports plus the numpy -> torch._numpy
    # module tables. If any statement below raises ImportError, the
    # `except ImportError: pass` at the end of this block swallows it, leaving
    # every name not yet bound here (e.g. `config`, `is_fake`, which functions
    # later in this file use) undefined.
    # NOTE(review): confirm that silently continuing is intended.
    import torch._logging
    import torch._numpy as tnp
    from torch._guards import detect_fake_mode  # noqa: F401
    from torch._logging import LazyString
    from . import config

    # NOTE: Make sure `NP_SUPPORTED_MODULES` and `NP_TO_TNP_MODULE` are in sync.
    # `np` is None when numpy is not installed (see the numpy import above).
    if np:
        # numpy modules whose functions are supported.
        NP_SUPPORTED_MODULES: Tuple[types.ModuleType, ...] = (
            np,
            np.fft,
            np.linalg,
            np.random,
        )
        # Each supported numpy (sub)module mapped to its torch._numpy counterpart.
        NP_TO_TNP_MODULE = {
            np: tnp,
            np.fft: tnp.fft,
            np.linalg: tnp.linalg,
            np.random: tnp.random,
        }
    else:
        # No numpy: nothing is supported and nothing needs translating.
        NP_SUPPORTED_MODULES = tuple()
        NP_TO_TNP_MODULE = {}
    from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
except ImportError:
    pass
import importlib
import torch
import torch._functorch.config
import torch.fx.experimental.symbolic_shapes
import torch.utils._pytree as pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import log_compilation_event
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
from torch.nn.modules.lazy import LazyModuleMixin
from torch.utils._triton import has_triton, has_triton_package
# Nested counters: counters[category][event] -> count (missing entries read as 0).
counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter)
optimus_scuba_log: Dict[str, Any] = {}

# User-facing documentation links.
troubleshooting_url = "https://pytorch.org/docs/main/torch.compiler_troubleshooting.html"
nnmodule_doc_url = "https://pytorch.org/docs/main/torch.compiler_nn_module.html"
nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."

log = logging.getLogger(__name__)

# Compilation time per decorated function: qualname -> list of durations (s).
compilation_time_metrics: Dict[str, List[float]] = {}
# Compilation time per frame and phase: str(frame) -> {phase_name: seconds}.
frame_phase_timing: Dict[str, Dict[str, float]] = {}

timer_counter = itertools.count()
def tabulate(rows, headers):
    """Format ``rows`` under ``headers`` as text.

    Uses the optional ``tabulate`` package when it can be imported. Otherwise it
    falls back to one comma-separated line per row, with the headers first.
    """
    try:
        import tabulate

        return tabulate.tabulate(rows, headers=headers)
    except ImportError:
        all_lines = itertools.chain([headers], rows)
        return "\n".join(", ".join(str(cell) for cell in line) for line in all_lines)
# Index of the frame currently being compiled; keys frame_phase_timing.
curr_frame = 0


# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
def increment_frame():
    """Advance the module-wide frame counter by one."""
    global curr_frame
    curr_frame += 1
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
def reset_frame_count():
    """Reset the frame counter to 0 and drop all collected timing data."""
    global curr_frame
    for timing_table in (frame_phase_timing, compilation_time_metrics):
        timing_table.clear()
    curr_frame = 0
# Running total of ops, advanced via increment_op_count().
op_count = 0


def increment_op_count(cnt):
    """Add ``cnt`` to the module-wide ``op_count`` total."""
    global op_count
    op_count = op_count + cnt
# Print a report of time spent so far, summed per phase over all frames.
# Everything is printed on a single line, e.g.:
# TIMING: entire_frame_compile:8.57463 backend_compile:5.26806
def print_time_report():
    """Print total compile time per phase, accumulated across every frame.

    Reads ``frame_phase_timing`` ({frame key: {phase name: seconds}}) and prints
    ``TIMING:`` followed by `` <phase>:<total>`` for each phase. Phases appear
    in first-seen order and totals are rounded to 5 decimal places.
    """
    totals_by_phase = {}
    for timings in frame_phase_timing.values():
        for phase, timing in timings.items():
            if phase in totals_by_phase:
                totals_by_phase[phase] += timing
            else:
                totals_by_phase[phase] = timing

    parts = [f" {phase}:{round(total, 5)}" for phase, total in totals_by_phase.items()]
    print("TIMING:" + "".join(parts))
# dynamo_timed is a function decorator, usable bare (@dynamo_timed) or with
# arguments (@dynamo_timed(phase_name="...")). Each successful call of the
# wrapped function appends its wall-clock duration to compilation_time_metrics
# under the function's __qualname__. For example:
#
#   @dynamo_timed
#   def _foo(...):
#       ...
#
# would show up in the timing dict as:
#   {'bar.<locals>._foo': [0.083690, 0.23949, 3.1425e-05]}
# This is extremely useful for granular debugging.
#
# For a higher-level view, pass a phase_name. The duration is then also
# accumulated in frame_phase_timing[str(curr_frame)][phase_name], i.e. keyed on
# frame + phase rather than on function. The frame counter is advanced outside
# of this function, by increment_frame().
def dynamo_timed(original_function=None, phase_name=None):
    def dynamo_timed_inner(func):
        # With config.cprofile set, the function is returned without a timer.
        if config.cprofile:
            return func

        @wraps(func)
        def time_wrapper(*args, **kwargs):
            key = func.__qualname__
            compilation_time_metrics.setdefault(key, [])
            with torch.profiler.record_function(f"{key} (dynamo_timed)"):
                started = time.time()
                outcome = func(*args, **kwargs)
                elapsed = time.time() - started
            # Only reached on normal return: calls that raise are not recorded.
            compilation_time_metrics[key].append(elapsed)
            if phase_name:
                per_frame = frame_phase_timing.setdefault(str(curr_frame), {})
                if phase_name in per_frame:
                    per_frame[phase_name] += elapsed
                else:
                    per_frame[phase_name] = elapsed
            return outcome

        return time_wrapper

    # dynamo_timed(fn) decorates right away; dynamo_timed(phase_name=...) returns
    # the decorator itself.
    if not original_function:
        return dynamo_timed_inner
    return dynamo_timed_inner(original_function)
def compile_times(repr="str", aggregate=False):
    """
    Summarize the torchdynamo frontend/backend compile times recorded by
    functions decorated with `@dynamo_timed`.

    repr='str' gives a printable table (one row per timed function).
    repr='csv' gives a (headers, values) pair suitable for logging: headers are
    the function names and values their formatted timings, in the same order.
    Any other repr returns None.

    aggregate=True sums all durations for a function (e.g. from split graphs)
    into one value. With aggregate=False every duration is listed,
    comma-separated.
    """

    def fmt_fn(values, item_fn=lambda x: x):
        if not aggregate:
            return ", ".join(map(item_fn, values))
        return item_fn(sum(values))

    if repr == "csv":
        values = [
            fmt_fn(timings, item_fn=lambda x: f"{x:.6f}")
            for timings in compilation_time_metrics.values()
        ]
        headers = list(compilation_time_metrics.keys())
        return headers, values

    if repr == "str":
        rows = [
            (name, fmt_fn(timings, item_fn=lambda x: f"{x:.4f}"))
            for name, timings in compilation_time_metrics.items()
        ]
        table = tabulate(rows, headers=("Function", "Runtimes (s)"))
        return "TorchDynamo compilation metrics:\n" + table

    return None
@atexit.register
def dump_compile_times():
    """At interpreter exit, log the aggregated compile-time table at INFO level."""
    report = compile_times(repr="str", aggregate=True)
    log.info(report)
# Legacy typed tensor class -> the dtype(s) it corresponds to. The first entry
# of each tuple is the canonical dtype and any further entries are aliases of it.
tensortype_to_dtype = {
    torch.FloatTensor: (torch.float32, torch.float),
    torch.DoubleTensor: (torch.float64, torch.double),
    torch.HalfTensor: (torch.float16, torch.half),
    torch.BFloat16Tensor: (torch.bfloat16,),
    torch.ByteTensor: (torch.uint8,),
    torch.CharTensor: (torch.int8,),
    torch.LongTensor: (torch.int64, torch.long),
    torch.IntTensor: (torch.int32, torch.int),
    torch.ShortTensor: (torch.int16, torch.short),
    torch.BoolTensor: (torch.bool,),
}
class DuplicateWarningChecker:
    """Decide whether a possibly repeated warning key should be reported.

    Remembers at most ``maxsize`` keys in recency order and evicts the least
    recently seen key once the limit is exceeded.
    """

    def __init__(self, maxsize=4096):
        self.maxsize = maxsize
        self.reset()

    def reset(self):
        """Forget every key seen so far."""
        self.set = collections.OrderedDict()

    def add(self, key):
        """Record ``key`` and return True if it should be reported.

        A new key always returns True. A key that is already present is marked
        most recently seen and returns True only when ``config.verbose`` is
        truthy.
        """
        seen = self.set
        if key not in seen:
            seen[key] = None
            while len(seen) > self.maxsize:
                seen.popitem(last=False)
            return True
        seen.move_to_end(key, last=True)
        if config.verbose:
            return True
        return False


# Module-wide instance; reset by reset_graph_break_dup_checker().
graph_break_dup_warning_checker = DuplicateWarningChecker()
def setup_compile_debug():
    """Enable the dynamo debug log file if TORCH_COMPILE_DEBUG=1.

    Returns an ExitStack; closing it undoes the setup. Without the env flag the
    stack is empty.
    """
    if os.environ.get("TORCH_COMPILE_DEBUG", "0") != "1":
        return contextlib.ExitStack()
    return add_file_handler()


def reset_graph_break_dup_checker():
    """Clear the keys remembered by the module-wide graph-break warning checker."""
    checker = graph_break_dup_warning_checker
    checker.reset()
def add_file_handler():
    """Mirror the ``torch._dynamo`` logger into ``<debug dir>/torchdynamo/debug.log``.

    Creates the directory if needed and attaches a FileHandler to the
    ``torch._dynamo`` logger.

    Returns:
        contextlib.ExitStack: closing it detaches the handler and then closes
        it, so the log file descriptor is released instead of leaking.
    """
    log_path = os.path.join(get_debug_dir(), "torchdynamo")
    os.makedirs(log_path, exist_ok=True)
    log_file_handler = logging.FileHandler(os.path.join(log_path, "debug.log"))
    logger = logging.getLogger("torch._dynamo")
    logger.addHandler(log_file_handler)

    exitstack = contextlib.ExitStack()
    # ExitStack callbacks run LIFO: registering close() first makes it run
    # last, after the handler has been removed from the logger.
    exitstack.callback(log_file_handler.close)
    exitstack.callback(logger.removeHandler, log_file_handler)
    return exitstack
def setup_log_file():
    """Attach a file handler for ``config.log_file_name`` to every torch logger.

    Returns:
        contextlib.ExitStack: closing it removes the handler from *every*
        logger it was attached to and then closes the file. When
        ``config.log_file_name`` is None nothing is attached and the stack is
        empty.
    """
    exitstack = contextlib.ExitStack()
    if config.log_file_name is not None:
        log_file_handler = logging.FileHandler(config.log_file_name)
        # Callbacks run LIFO, so close() (registered first) runs after all the
        # removeHandler calls below.
        exitstack.callback(log_file_handler.close)
        for logger in torch._logging._internal.get_loggers():
            logger.addHandler(log_file_handler)
            # Bind the logger as a callback argument. A
            # `lambda: logger.removeHandler(...)` would late-bind the loop
            # variable and only ever detach the handler from the last logger.
            exitstack.callback(logger.removeHandler, log_file_handler)
    return exitstack
def gen_record_file_name(exc, code):
    """Path for an execution record of ``code`` that failed with ``exc``.

    Format: ``<debug dir>/error_recordings/<co_name>_<ExcType>_<co_firstlineno>.rec``.
    """
    record_dir = f"{get_debug_dir()}/error_recordings"
    record_name = f"{code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec"
    return f"{record_dir}/{record_name}"
def write_record_to_file(filename, exec_record):
    """Dump ``exec_record`` to ``filename`` in binary mode, best-effort.

    Never overwrites: an existing file only produces a warning. Any exception is
    logged and swallowed. Missing parent directories are created.
    """
    try:
        if os.path.exists(filename):
            log.warning(
                "Unable to write execution record %s; file already exists.", filename
            )
            return
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        with open(filename, "wb") as record_file:
            exec_record.dump(record_file)
    except Exception:
        log.exception("Unable to write execution record %s", filename)
def count_calls(g: fx.Graph):
    """Number of nodes in ``g`` whose op name contains ``"call"``."""
    return sum(1 for node in g.nodes if "call" in node.op)
def identity(x):
    """Return ``x`` unchanged."""
    return x


def hashable(x):
    """True if ``hash(x)`` succeeds.

    Both TypeError (unhashable types) and ValueError (e.g. "cannot hash writable
    memoryview object") count as not hashable.
    """
    try:
        hash(x)
    except (TypeError, ValueError):
        return False
    return True


def nothing(*args, **kwargs):
    """Accept any arguments and do nothing."""
    return None
class ExactWeakKeyDictionary:
    """Similar to weakref.WeakKeyDictionary, but use `is`/`id` rather than `==` to compare equality

    Entries are keyed by ``id(key)``. A weakref callback drops the entry when the
    key object is collected, so keys must support weak references.
    """

    def __init__(self):
        self.values = {}
        self.refs = {}

    def __getitem__(self, key):
        return self.values[id(key)]

    def get(self, key, default=None):
        return self.values.get(id(key), default)

    def __contains__(self, key):
        return id(key) in self.values

    def __setitem__(self, key, value):
        key_id = id(key)
        if key_id not in self.refs:
            # Dispatch through self._remove_id so subclasses can hook removal.
            self.refs[key_id] = weakref.ref(
                key, lambda _dead_ref: self._remove_id(key_id)
            )
        self.values[key_id] = value

    def _remove_id(self, idx):
        self.values.pop(idx, None)
        self.refs.pop(idx, None)

    def clear(self):
        self.refs.clear()
        self.values.clear()
def istype(obj, allowed_types):
    """isinstance() without subclasses

    ``allowed_types`` is either a single type or a tuple/list/set of types;
    ``type(obj)`` must be exactly one of them.
    """
    obj_type = type(obj)
    if not isinstance(allowed_types, (tuple, list, set)):
        return obj_type is allowed_types
    return obj_type in allowed_types
if sys.version_info >= (3, 12):
    # These typing classes are implemented in C starting with 3.12 and no
    # longer carry the typing._Final mixin, so they are matched explicitly.
    _builtin_final_typing_classes = (
        typing.ParamSpecArgs,
        typing.ParamSpecKwargs,
        typing.ParamSpec,
        typing.TypeVar,
        typing.TypeVarTuple,
        typing.TypeAliasType,
    )


def is_typing(value):
    """True if ``value`` is a typing construct.

    Matches instances of ``typing._Final`` (special forms such as Union or
    Optional, parameterized aliases like ``List[int]``, ...), ``typing.Generic``
    itself and, on 3.12+, the C-implemented classes listed above.

    Classes that merely inherit from Generic are intentionally NOT matched,
    since they can be used as both TypingVariable and
    UserDefinedClassVariable.

    NOTE(review): since Python 3.11 ``typing.Any`` is a class rather than a
    ``_Final`` instance, so it is probably not matched here. Confirm that
    callers handle it separately.
    """
    if sys.version_info >= (3, 12) and isinstance(value, _builtin_final_typing_classes):
        return True
    if value is typing.Generic:
        return True
    return isinstance(value, typing._Final)  # type: ignore[attr-defined]
def is_numpy_int_type(value):
    """True if ``type(value)`` is exactly one of numpy's fixed-width integer scalars.

    Always False when numpy is unavailable (``np`` is None).
    """
    if not np:
        return False
    signed = (np.int8, np.int16, np.int32, np.int64)
    unsigned = (np.uint8, np.uint16, np.uint32, np.uint64)
    return istype(value, signed + unsigned)
def is_numpy_float_type(value):
    """True if ``type(value)`` is exactly np.float16, np.float32 or np.float64.

    Always False when numpy is unavailable (``np`` is None).
    """
    if not np:
        return False
    float_types = (np.float16, np.float32, np.float64)
    return istype(value, float_types)
def is_function_or_wrapper(value):
    """True for plain functions, lru_cache wrappers of functions, and aten ops.

    An ``functools.lru_cache`` wrapper only counts when its ``__wrapped__``
    target itself satisfies ``is_function``.
    """
    if is_function(value):
        return True
    if isinstance(value, functools._lru_cache_wrapper) and is_function(
        inspect.getattr_static(value, "__wrapped__")
    ):
        return True
    return isinstance(value, (torch._ops.OpOverloadPacket, torch._ops.OpOverload))


def is_function(value):
    """True if ``value`` is a Python/builtin function, method or wrapper
    descriptor, or a TorchScript function."""
    function_like = (
        types.FunctionType,
        types.BuiltinFunctionType,
        types.MethodDescriptorType,
        types.WrapperDescriptorType,
        torch.jit.ScriptFunction,
    )
    return isinstance(value, function_like)
def unwrap_if_wrapper(fn):
    """Return the function wrapped by ``fn`` (or ``fn`` itself if not a known wrapper)."""
    unwrapped, _ = unwrap_with_attr_name_if_wrapper(fn)
    return unwrapped


def unwrap_with_attr_name_if_wrapper(fn):
    """Peel one known wrapper layer off ``fn``.

    Returns ``(inner, attr_name)``, where ``attr_name`` names the attribute the
    inner function was read from, or ``(fn, None)`` if nothing applied.
    """
    # unpack @functools.lru_cache wrapped function
    if isinstance(fn, functools._lru_cache_wrapper):
        return inspect.getattr_static(fn, "__wrapped__"), "__wrapped__"
    if is_function(fn):
        # unpack @torch._dynamo.optimize()(fn) wrapped function
        if inspect.getattr_static(fn, "_torchdynamo_inline", False):
            inner = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
            return inner, "_torchdynamo_inline"
        # unpack torch.jit.script_if_tracing
        if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
            inner = inspect.getattr_static(fn, "__original_fn", fn)
            return inner, "__original_fn"
    return fn, None
def is_numpy_ndarray(value):
    """True if ``type(value)`` is exactly ``np.ndarray`` (False without numpy)."""
    return bool(np) and istype(value, np.ndarray)
def istensor(obj):
    """Check if ``type(obj)`` is exactly one of the tensor types dynamo traces.

    Accepted: torch.Tensor, torch.nn.Parameter, every class listed in
    ``config.traceable_tensor_subclasses``, and FakeTensor. This is an exact type
    check (via ``istype``), so unlisted subclasses are rejected.
    """
    tensor_types = (
        torch.Tensor,
        torch.nn.Parameter,
        *config.traceable_tensor_subclasses,
        torch._subclasses.FakeTensor,
    )
    return istype(obj, tensor_types)
def is_lazy_module(mod):
    """True if ``mod`` is a lazily-initialized module (a LazyModuleMixin instance)."""
    lazy = isinstance(mod, LazyModuleMixin)
    return lazy


@functools.lru_cache(4096)
def print_once(*args):
    """print(*args), but only the first time a given argument tuple is seen.

    Deduplication relies on lru_cache, so arguments must be hashable and only the
    4096 most recent distinct argument tuples are remembered.
    """
    print(*args)
def make_cell(val=None):
    """Some black magic to create a cell object that usually only exists in a closure

    A closure over one local is built and its single ``__closure__`` cell, which
    holds ``val``, is returned.
    """
    captured = val

    def _reader():
        return captured

    cells = _reader.__closure__
    assert cells is not None and len(cells) == 1
    return cells[0]
def proxy_args_kwargs(args, kwargs):
    """Convert positional and keyword variables to their ``as_proxy()`` values.

    Returns ``(tuple_of_proxies, dict_of_proxies)``. If any ``as_proxy()`` raises
    NotImplementedError, the call is reported through ``unimplemented``
    (presumably raising a graph-break exception; confirm in ``.exc``).
    """
    try:
        proxied_args = tuple(arg.as_proxy() for arg in args)
        proxied_kwargs = {name: arg.as_proxy() for name, arg in kwargs.items()}
    except NotImplementedError as e:
        from .exc import unimplemented
        from .variables.base import typestr

        unimplemented(
            f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}",
            from_exc=e,
        )
    else:
        return proxied_args, proxied_kwargs
@dataclasses.dataclass
class CompilationMetrics:
    """Statistics for one compilation, passed to ``record_compilation_metrics``.

    That function keeps recent instances in a bounded deque and emits them as a
    structured "compilation_metrics" trace (set-valued fields become lists).
    Field order is the positional-constructor order; do not reorder.
    """

    # Identity of the compiled frame / code object.
    frame_key: str
    co_name: str
    co_filename: str
    co_firstlineno: int
    # Cache sizes; presumably entries for this code object vs. accumulated
    # across related code objects. Confirm against the caller.
    cache_size: int
    accumulated_cache_size: int
    # Size statistics. Optional, presumably None when compilation did not get
    # far enough to produce them. Confirm.
    guard_count: Optional[int]
    shape_env_guard_count: Optional[int]
    graph_op_count: Optional[int]
    graph_node_count: Optional[int]
    graph_input_count: Optional[int]
    # Timing. The `_s` suffix suggests durations in seconds; start_time is
    # presumably a wall-clock timestamp. Confirm.
    start_time: float
    entire_frame_compile_time_s: Optional[float]
    backend_compile_time_s: Optional[float]
    inductor_compile_time_s: Optional[float]
    code_gen_time_s: Optional[float]
    # Failure details; presumably all None when compilation succeeded.
    fail_type: Optional[str]
    fail_reason: Optional[str]
    fail_user_frame_filename: Optional[str]
    fail_user_frame_lineno: Optional[int]
    # Op names observed while compiling (classification defined by the caller).
    non_compliant_ops: Set[str]
    compliant_custom_ops: Set[str]
    # Why frame analysis was restarted, and time spent before the restart(s).
    restart_reasons: Set[str]
    dynamo_time_before_restart_s: float
    # Sometimes, we will finish analyzing a frame but conclude we don't want
    # to install any guarded code. True means we actually decided to install
    # a compiled frame
    has_guarded_code: bool
# Only the newest DEFAULT_COMPILATION_METRICS_LIMIT records are retained; the
# deque's maxlen discards the oldest record automatically on overflow.
DEFAULT_COMPILATION_METRICS_LIMIT = 64


_compilation_metrics: Deque[CompilationMetrics] = collections.deque(
    maxlen=DEFAULT_COMPILATION_METRICS_LIMIT
)


def record_compilation_metrics(compilation_metrics: CompilationMetrics):
    """Store ``compilation_metrics`` and publish it.

    The record is appended to the bounded in-memory history, then emitted as a
    "compilation_metrics" structured trace with set fields converted to lists.
    When ``config.log_compilation_metrics`` is set, it is also forwarded to
    ``log_compilation_event``.
    """
    _compilation_metrics.append(compilation_metrics)

    def _as_trace_payload():
        return {
            field: list(value) if isinstance(value, set) else value
            for field, value in dataclasses.asdict(compilation_metrics).items()
        }

    torch._logging.trace_structured("compilation_metrics", _as_trace_payload)
    if config.log_compilation_metrics:
        log_compilation_event(compilation_metrics)
def set_compilation_metrics_limit(new_size: int) -> None:
    """Change how many CompilationMetrics records are retained.

    The newest ``new_size`` records are kept and older ones are dropped.

    Raises:
        ValueError: if ``new_size`` is negative. The existing records are left
            untouched in that case (previously they were all popped before an
            IndexError surfaced).
    """
    global _compilation_metrics
    if new_size < 0:
        raise ValueError(
            f"compilation metrics limit must be non-negative, got {new_size}"
        )
    # deque(iterable, maxlen=n) keeps only the last n items it is fed, i.e. the
    # newest n records, so no manual popleft() loop is needed.
    _compilation_metrics = collections.deque(_compilation_metrics, maxlen=new_size)


def clear_compilation_metrics() -> None:
    """Drop every retained CompilationMetrics record (the limit is unchanged)."""
    # In-place mutation only; the module name is not rebound, so no `global`.
    _compilation_metrics.clear()
def get_compilation_metrics() -> List[CompilationMetrics]:
    """Snapshot of the retained CompilationMetrics records, oldest first."""
    return [*_compilation_metrics]
@dataclasses.dataclass
class CleanupHook:
    """Remove a global variable when hook is called

    Deletes ``scope[name]`` when invoked and keeps ``CleanupManager.count`` in
    step with the number of live hooks.
    """

    scope: Dict[str, Any]
    name: str

    def __call__(self, *args):
        # Make sure we're not shutting down: module globals such as
        # CleanupManager may already be None during interpreter teardown.
        if CleanupManager is not None:
            CleanupManager.count -= 1
        del self.scope[self.name]

    @staticmethod
    def create(scope, name, val):
        """Bind ``scope[name] = val`` (name must be unused) and return its hook."""
        assert name not in scope
        CleanupManager.count += 1
        scope[name] = val
        return CleanupHook(scope, name)


class CleanupManager(ExactWeakKeyDictionary):
    """Maps objects (by identity) to lists of CleanupHooks.

    The hooks run when the object is garbage collected.
    """

    count = 0
    instance: ClassVar["CleanupManager"]

    def _remove_id(self, idx):
        hooks = self.values[idx]
        for hook in hooks:
            hook()
        super()._remove_id(idx)


CleanupManager.instance = CleanupManager()
def clone_tensor(x):
    """Clone the tensor and its gradient

    The clone keeps ``x.requires_grad``. For leaf tensors with a ``.grad``, the
    gradient is cloned too.
    """
    cloned = x.clone()
    cloned.requires_grad_(x.requires_grad)
    if x.is_leaf and x.grad is not None:
        cloned.grad = x.grad.clone()
    return cloned
def clone_input(x, *, dtype=None):
    """copy while preserving strides

    Return a copy of tensor ``x`` laid out with ``x``'s size and stride.

    Args:
        x: tensor to copy. If ``is_fake(x)`` is true, ``x`` itself is returned.
        dtype: optional dtype override. Only the dense, non-quantized path uses
            it for the copy (``torch.empty(..., dtype=dtype or x.dtype)``). It
            is also forwarded when cloning ``x.grad``. The XLA, sparse,
            quantized and RuntimeError-fallback paths do not apply it to the
            copy itself.

    On the dense and ``torch_clone`` paths:
        - leaf tensors keep ``requires_grad`` and get a cloned ``.grad``;
        - a ``_dynamo_dynamic_indices`` attribute, if present, is copied.
    """
    # TODO: this is questionable
    if is_fake(x):
        # this func fails on fake tensors in __torch_dispatch__
        return x

    def torch_clone(x):
        # Plain torch.clone fallback. It still carries over requires_grad (leaf
        # only), a cloned grad and the dynamic-dims marker.
        y = torch.clone(x)
        if x.is_leaf:
            y.requires_grad_(x.requires_grad)
        if x.is_leaf and x.grad is not None:
            y.grad = clone_input(x.grad, dtype=dtype)
        if hasattr(x, "_dynamo_dynamic_indices"):
            y._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy()  # type: ignore[attr-defined]
        return y

    with torch.no_grad():
        if x.device.type == "xla":
            # Access data_ptr() for a xla tensor will cause crash
            return torch_clone(x)

        # Handle sparse storage (no stride).
        # Sparse tensors are rebuilt from cloned index/value components.
        if x.layout is torch.sparse_coo:
            return torch.sparse_coo_tensor(
                torch_clone(x._indices()),
                torch_clone(x._values()),
                x.shape,
                is_coalesced=x.is_coalesced(),
            )
        elif is_sparse_compressed(x):
            # Row-compressed layouts (CSR/BSR) vs column-compressed (CSC/BSC).
            if x.layout in {torch.sparse_csr, torch.sparse_bsr}:
                compressed_indices = x.crow_indices()
                plain_indices = x.col_indices()
            else:
                compressed_indices = x.ccol_indices()
                plain_indices = x.row_indices()
            return torch.sparse_compressed_tensor(
                torch_clone(compressed_indices),
                torch_clone(plain_indices),
                torch_clone(x.values()),
                x.shape,
                layout=x.layout,
            )

        # Element offset of x's last element relative to its first (for
        # non-empty x with non-negative strides): the span the strided view
        # must cover.
        needed_size = sum(
            (shape - 1) * stride for shape, stride in zip(x.size(), x.stride())
        )
        # 32 extra elements leave room for the alignment shift computed below
        # (that shift is always < 32 elements).
        if x.is_quantized:
            result = torch.empty_quantized((needed_size + 32,), x)
        else:
            result = torch.empty(
                needed_size + 32, dtype=dtype or x.dtype, device=x.device
            )
        # Start the view at an element offset that gives result the same
        # address modulo 32 bytes as x, presumably to mirror x's alignment (the
        # name suggests cache-line alignment).
        cache_line_offset = (
            (x.data_ptr() - result.data_ptr()) % 32
        ) // x.element_size()
        result.as_strided_(x.size(), x.stride(), cache_line_offset)
        try:
            result.copy_(x.clone())
            if x.is_leaf:
                result.requires_grad_(x.requires_grad)
            if x.is_leaf and x.grad is not None:
                result.grad = clone_input(x.grad, dtype=dtype)
        except RuntimeError:
            # RuntimeError: unsupported operation: more than one element of the written-to
            # tensor refers to a single memory location. Please clone() the tensor before
            # performing the operation.
            return torch_clone(x)
        if hasattr(x, "_dynamo_dynamic_indices"):
            result._dynamo_dynamic_indices = x._dynamo_dynamic_indices.copy()  # type: ignore[attr-defined]
        return result
def clone_inputs(example_inputs):
    """Clone example inputs, copying every tensor with ``clone_input``.

    For an exact ``dict``: returns a new dict in which every value is
    required to be a tensor (cloned) or a tuple (cloned recursively, which
    yields a list).

    For any other iterable: returns a new list in which tensors are cloned and
    all other items are kept as-is.
    """
    if type(example_inputs) is dict:
        cloned: Dict[Any, Any] = dict(example_inputs)
        for key, value in cloned.items():
            if isinstance(value, tuple):
                cloned[key] = clone_inputs(value)
            else:
                assert isinstance(value, torch.Tensor), type(value)
                cloned[key] = clone_input(value)
        return cloned

    cloned_list: List[Any] = list(example_inputs)
    for index, item in enumerate(cloned_list):
        if isinstance(item, torch.Tensor):
            cloned_list[index] = clone_input(item)
    return cloned_list
def skip_frame_if_in_functorch_mode(val: torch.Tensor):
    """Raise SkipFrame if ``val`` is a functorch-wrapped tensor.

    Functorch wrappers (GradTrackingTensor, BatchedTensor, ...) raise
    RuntimeError from ``data_ptr()``. The wrapper's name is taken from its repr,
    with everything from the first "(" onwards removed.
    """
    try:
        val.data_ptr()
    except RuntimeError as e:
        from .exc import SkipFrame

        wrapper_name = re.sub(r"\(.*", "", repr(val))
        message = f"torch.compile cannot be run in context: {wrapper_name}"
        raise SkipFrame(message) from e
@contextmanager
def preserve_rng_state():
    """Snapshot torch's CPU (and, if available, CUDA) RNG state; restore it on exit.

    The snapshot is taken with dispatch modes and functorch disabled. A
    functorch-wrapped state triggers SkipFrame via
    ``skip_frame_if_in_functorch_mode``. Restoration happens even if the body
    raises.
    """
    disable_functorch = torch._C._DisableFuncTorch
    disable_current_modes = torch.utils._python_dispatch._disable_current_modes
    with disable_current_modes(), disable_functorch():
        cpu_snapshot = torch.clone(torch.random.get_rng_state())
        skip_frame_if_in_functorch_mode(cpu_snapshot)
        if torch.cuda.is_available():
            cuda_snapshot = torch.clone(torch.cuda.get_rng_state())
    try:
        yield
    finally:
        with torch.utils._python_dispatch._disable_current_modes():
            torch.random.set_rng_state(cpu_snapshot)
            if torch.cuda.is_available():
                torch.cuda.set_rng_state(cuda_snapshot)  # type: ignore[possibly-undefined]
def is_jit_model(model0):
    """True if ``model0`` is already a TorchScript object (traced, scripted or ScriptFunction)."""
    torchscript_types = (
        torch.jit._trace.TopLevelTracedModule,
        torch.jit._script.RecursiveScriptModule,
        torch.jit.ScriptFunction,
        torch.jit.ScriptModule,
    )
    return isinstance(model0, torchscript_types)
def torchscript(model, example_inputs, verbose=False):
    """Convert ``model`` to TorchScript.

    Tries ``torch.jit.trace`` first and falls back to ``torch.jit.script``.
    Returns the model unchanged if it is already TorchScript, or None if both
    conversions fail. On failure, ``verbose`` logs the full traceback;
    otherwise a short error is logged.
    """
    if is_jit_model(model):
        # already done?
        return model
    try:
        return torch.jit.trace(model, example_inputs)
    except Exception:
        # The scripting fallback stays inside this handler so a logged traceback
        # is chained to the tracing failure as well.
        try:
            return torch.jit.script(model)
        except Exception:
            if not verbose:
                log.error("Both torch.jit.trace and torch.jit.script failed")
            else:
                log.exception("jit error")
    return None
def getfile(obj):
    """``inspect.getfile(obj)``, or None when it raises TypeError/OSError (e.g. builtins)."""
    try:
        path = inspect.getfile(obj)
    except (OSError, TypeError):
        return None
    return path
def is_namedtuple(obj):
    """Test if an object is a namedtuple or a torch.return_types.* quasi-namedtuple"""
    return is_namedtuple_cls(type(obj))


def is_namedtuple_cls(cls):
    """Test if an object is a namedtuple or a (torch.return_types|torch.autograd.forward_ad).* quasi-namedtuple

    A real namedtuple class derives directly from tuple and defines ``_make`` and
    ``_fields``. Non-classes (where issubclass raises TypeError) yield False.
    """
    try:
        if not issubclass(cls, tuple):
            return False
        module = getattr(cls, "__module__", None)
        if module in ("torch.return_types", "torch.autograd.forward_ad"):
            return True
        bases = getattr(cls, "__bases__", []) or [None]
        return bases[0] is tuple and hasattr(cls, "_make") and hasattr(cls, "_fields")
    except TypeError:
        return False
@functools.lru_cache(1)
def namedtuple_fields(cls):
    """Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple

    ``slice`` is special-cased to ["start", "stop", "step"]. Regular namedtuples
    return their ``_fields``. For torch.return_types structseqs, the names are
    recovered by instantiating the type with index markers and finding which
    public attribute exposes each marker.
    """
    if cls is slice:
        return ["start", "stop", "step"]

    assert issubclass(cls, tuple)
    if hasattr(cls, "_fields"):
        # normal namedtuples
        return cls._fields

    @dataclasses.dataclass
    class Marker:
        index: int

    # frustrating ones e.g. torch.return_types.max
    assert cls.__module__ == "torch.return_types"
    field_count = cls.n_fields
    probe = cls(map(Marker, range(field_count)))
    fields: List[Optional[str]] = [None] * field_count
    for attr in dir(probe):
        if attr.startswith("_"):
            continue
        candidate = getattr(probe, attr)
        if isinstance(candidate, Marker):
            fields[candidate.index] = attr
    return fields
def checkpoint_params(gm):
    """Snapshot RNG state plus all parameters/buffers of ``gm``; return a restore fn.

    Calling the returned ``restore()`` resets the CPU (and, if available, CUDA)
    RNG state. It also copies back the saved value of every parameter or buffer
    whose ``_version`` changed since the snapshot.
    """
    with torch.no_grad():
        cpu_rng = torch.clone(torch.random.get_rng_state())
        if torch.cuda.is_available():
            cuda_rng = torch.clone(torch.cuda.get_rng_state())
        snapshots = [
            (tensor, tensor._version, torch.clone(tensor))
            for tensor in itertools.chain(gm.parameters(), gm.buffers())
        ]

    def restore():
        with torch.no_grad():
            torch.random.set_rng_state(cpu_rng)
            if torch.cuda.is_available():
                torch.cuda.set_rng_state(cuda_rng)
            for tensor, version, saved in snapshots:
                # Only tensors mutated in place since the snapshot need a copy.
                if tensor._version != version:
                    tensor.copy_(saved)

    return restore
def timed(model, example_inputs, times=1):
    """Call ``model(*example_inputs)`` ``times`` times and time the whole loop.

    Before timing, the CUDA device is synchronized when available, the garbage
    collector runs, and ``torch.manual_seed(1337)`` is applied. With CUDA
    available, the device is also synchronized after every call, so queued GPU
    work is included in the measurement.

    Returns:
        ``(result, seconds)``: the output of the last call and the
        ``time.perf_counter`` delta across all calls.

    Raises:
        ValueError: if ``times`` < 1, since no call would produce a result
            (this used to surface as an UnboundLocalError).
    """
    if times < 1:
        raise ValueError(f"timed() requires times >= 1, got {times}")
    if torch.cuda.is_available():
        synchronize = torch.cuda.synchronize
    else:
        synchronize = nothing
    synchronize()
    gc.collect()
    torch.manual_seed(1337)
    t0 = time.perf_counter()
    for _ in range(times):
        result = model(*example_inputs)
        synchronize()
    t1 = time.perf_counter()
    return result, t1 - t0
def check_is_cuda(gm, example_inputs):
    """True if every example input and every (recursive) parameter of ``gm`` is on CUDA.

    Vacuously True when there are no inputs and no parameters.
    """
    candidates = itertools.chain(example_inputs, gm.parameters(True))
    return all(t.is_cuda for t in candidates)
@lru_cache(32)
def rot_n_helper(n):
    """Build (and cache) a function of ``n`` positional arguments, named ``rot_<n>_helper``.

    The function returns its first ``n - 1`` arguments in reverse order, followed
    by the last argument. For example, ``rot_n_helper(3)(a, b, c) == (b, a, c)``.

    It is generated with ``eval``, but the source contains only the generated
    identifiers ``v0 .. v{n-1}``, so no external text is evaluated.
    """
    assert n > 1
    names = [f"v{i}" for i in range(n)]
    # Move the last name to the front, then reverse the whole sequence.
    order = names[-1:] + names[:-1]
    order.reverse()
    source = "lambda " + ",".join(names) + ": (" + ",".join(order) + ")"
    fn = eval(source)
    fn.__name__ = f"rot_{n}_helper"
    return fn
common_constant_types = {
int,
float,
complex,
bool,
str,
bytes,
type(None),
Ellipsis.__class__,
types.CodeType,
torch.device,
torch.dtype,
torch.memory_format,
torch.layout,
}
if has_triton_package():
import triton
common_constant_types.add(triton.language.dtype)
def is_safe_constant(v):
    """True if ``v`` can be treated as a safe constant.

    Accepted values:
    - tuples/frozensets (exact types) whose items are all safe constants;
    - Enum members and classes;
    - values whose exact type is in ``common_constant_types`` or is ``slice``.
    """
    if istype(v, (tuple, frozenset)):
        return all(is_safe_constant(item) for item in v)
    if isinstance(v, (enum.Enum, type)):
        return True
    return istype(v, common_constant_types | {slice})
def specialize_symnode(arg):
    """Turn a SymNodeVariable into a ConstantVariable of its evaluated value.

    Any other ``arg`` is returned unchanged. ``evaluate_expr()`` presumably
    installs a guard on the symbolic value; confirm in SymNodeVariable.
    """
    from .variables import ConstantVariable, SymNodeVariable

    if not isinstance(arg, SymNodeVariable):
        return arg
    # Guard and specialize
    return ConstantVariable.create(arg.evaluate_expr())