-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
_script.py
1110 lines (911 loc) · 43.2 KB
/
_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""TorchScript
This module contains functionality to support the JIT's scripting frontend, notably:
- torch.jit.script
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import functools
import collections
import inspect
import copy
import pickle
import warnings
from typing import Any, Dict
import torch
import torch._jit_internal as _jit_internal
from torch.utils import set_module
from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module, infer_methods_to_compile
from torch.nn import Module
from torch.jit._state import _enabled
from torch.jit._builtins import _register_builtin
from torch._six import with_metaclass, get_function_from_type
from torch.jit.frontend import get_jit_def, get_default_args, get_jit_class_def
from torch._jit_internal import _qualified_name
from torch.jit._fuser import _graph_for
from torch.jit._state import (
_try_get_jit_cached_function,
_try_get_jit_cached_overloads,
_set_jit_function_cache,
_set_jit_overload_cache,
)
from torch.overrides import (
has_torch_function, has_torch_function_unary, has_torch_function_variadic)
# Attach the fuser-aware ``graph_for`` helper to both kinds of compiled
# callables provided by the C++ extension (methods first, then functions).
for _compiled_callable_type in (torch._C.ScriptMethod, torch._C.ScriptFunction):
    _compiled_callable_type.graph_for = _graph_for  # type: ignore
del _compiled_callable_type

# Public alias for the C++ ScriptFunction type. Its docstring is filled in
# here, and ``set_module`` registers ``torch.jit`` as its module.
ScriptFunction = torch._C.ScriptFunction
ScriptFunction.__doc__ = """
Functionally equivalent to a :class:`ScriptModule`, but represents a single
function and does not have any attributes or Parameters.
"""
set_module(ScriptFunction, "torch.jit")
if not _enabled:
    # JIT disabled: ``Attribute(value, type)`` is a pass-through that hands
    # back ``value`` untouched, so plain Python code keeps working.
    def Attribute(value, type):  # type: ignore
        return value
else:
    # JIT enabled: a (value, type) record. ScriptModule.__setattr__ unwraps it,
    # storing ``value`` and recording ``type`` in the class __annotations__.
    Attribute = collections.namedtuple("Attribute", ["value", "type"])
# ScriptClasses must be new-style classes because we construct them using their
# __new__ method.
def _is_new_style_class(cls):
if hasattr(cls, "__class__"):
return "__dict__" in dir(cls) or hasattr(cls, "__slots__")
def _compile_and_register_class(obj, rcb, qualified_name):
    """Compile the Python class ``obj`` as a TorchScript class and register it.

    Args:
        obj: the Python class to compile.
        rcb: resolution callback used by the compiler to look up free names.
        qualified_name: the name under which the compiled class is registered.
    """
    class_def = get_jit_class_def(obj, obj.__name__)
    default_args = torch.jit.frontend.get_default_args_for_class(obj)
    torch._C._jit_script_class_compile(qualified_name, class_def, default_args, rcb)
    # Remember the Python class <-> qualified name association.
    torch.jit._state._add_script_class(obj, qualified_name)
# These OrderedDictWrapper classes replace the actual OrderedDicts in
# module with versions that get/set properties inside of Module.
# This allows us to reuse most of nn.Module while still storing the
# data in C++.
# Each OrderedDict needs to support:
#  x not in view
#  x in view
#  view[name] = ...
#  view.values()
#  del view[name]
#  view.items()
#  view.keys()
#  len(view)
class OrderedDictWrapper(object):
    """Dict-like view that forwards every read and write to a backing container.

    The backing object ``_c`` must provide ``items()``, ``contains(k)``,
    ``getattr(k)`` and ``setattr(k, v)``. The set of keys is fixed: existing
    entries may be overwritten, but entries can be neither added nor deleted.
    """

    def __init__(self, _c):
        self._c = _c

    def keys(self):
        return [k for k, _ in self.items()]

    def values(self):
        return [v for _, v in self.items()]

    def __len__(self):
        return len(self.values())

    def __delitem__(self, k):
        raise RuntimeError("cannot delete methods or parameters of a script module")

    def items(self):
        return self._c.items()

    def __setitem__(self, k, v):
        # Only names that already exist may be reassigned; the layout is
        # frozen once the ScriptModule has been constructed.
        if k not in self:
            raise RuntimeError(
                "Can't add a new parameter after ScriptModule construction."
                " Tried to add '{}'".format(k)
            )
        self._c.setattr(k, v)

    def __contains__(self, k):
        return self._c.contains(k)

    def __getitem__(self, k):
        if k not in self:
            raise KeyError(k)
        return self._c.getattr(k)
class OrderedModuleDict(OrderedDictWrapper):
    """Submodule view of a ScriptModule whose lookups are served from Python.

    Writes go to both the C++ module dict and ``python_dict``. Reads, membership
    and iteration use only ``python_dict``, so the same Python wrapper object
    is always returned for a given name.
    """

    def __init__(self, module, python_dict):
        super(OrderedModuleDict, self).__init__(torch._C.ModuleDict(module))
        # contains _both_ script modules and non-script python-only modules

        # because script modules are subclassed in python and the
        # C++ Module class will not hold references to them,
        # to ensure that you always get the same python value here
        # we store it in the python dict as well
        self._python_modules = python_dict

    def items(self):
        return self._python_modules.items()

    def __contains__(self, k):
        return k in self._python_modules

    def __setitem__(self, k, v):
        # Cases where sub-module can be re-assigned after ScriptModule construction
        # 1. If the attr is an module interface type, it's guaranteed that the module is
        #    not inlined in the graph, so it's safe to swap a new ScriptModule in.
        # 2. if the new value if a ScriptModule with the same JIT type, IR won't change
        #    and it's legit to swap a new module in.
        # In these two cases we allow swapping a new scripted module and update the
        # corresponding python module dict to keep sync.
        # Note: the value to be swapped in has to be ScriptModule instead of nn.Module,
        # otherwise it's illegal and we throw error.
        if not isinstance(v, ScriptModule):
            raise RuntimeError(
                "Cannot re-assign modules in a ScriptModule with non-scripted "
                "module, tried to replace existing module '{}': {}".format(k, v)
            )
        self._c.setattr(k, v)
        self._python_modules[k] = v

    def __getitem__(self, k):
        return self._python_modules[k]
# For each user-defined class that subclasses ScriptModule, this meta-class:
# (1) finds all the methods annotated with @script_method in a ScriptModule and
#     removes them from the class attributes
# (2) puts a wrapper around the class's __init__ method to recursively compile
#     all of the script_methods with the module after the original __init__ has
#     run. This has to occur after the user-defined __init__ so that submodules and
#     parameters are initialized _before_ the script compiler resolves references to
#     `self.param` or `self.module`.
class ScriptMeta(type):
    def __init__(cls, name, bases, attrs):  # noqa: B902
        """Collect script-method stubs and wrap ``__init__`` to compile them.

        Args:
            name, bases, attrs: the standard class-creation arguments.
        """
        # Aggregate all the ScriptMethods and constants from superclasses
        cls._methods: Dict[str, Any] = {}
        cls._constants_set = set(getattr(cls, "__constants__", ()))
        # Iterate bases in reverse so that earlier bases override later ones
        # when the same method name appears in more than one.
        for base in reversed(bases):
            for k, v in getattr(base, "_methods", {}).items():
                cls._methods[k] = v
            base_constants = getattr(base, "_constants_set", set())
            cls._constants_set = cls._constants_set.union(base_constants)

        # find all the script methods of the current class
        for k, v in sorted(attrs.items()):
            if isinstance(v, ScriptMethodStub):
                # Remove the stub from the class namespace; it is compiled
                # later from cls._methods instead.
                delattr(cls, k)
                cls._methods[v.original_method.__name__] = v

        if getattr(cls, "_disable_script_meta", False):
            # We leave built-in ScriptModule types alone, since this metaclass
            # is only for compiling user classes that inherit from
            # ScriptModule.
            return super(ScriptMeta, cls).__init__(name, bases, attrs)

        original_init = getattr(cls, "__init__", lambda self: None)

        @functools.wraps(original_init)
        def init_then_script(self, *args, **kwargs):
            # ScriptModule.define() called during __init__ adds stubs to
            # cls._methods; detect that so type sharing can be disabled.
            num_methods = len(cls._methods)
            original_init(self, *args, **kwargs)
            added_methods_in_init = len(cls._methods) > num_methods

            # Only the wrapper of the most-derived class compiles; wrappers of
            # base classes reached via super().__init__ skip this step.
            if type(self) == cls:

                def make_stubs(module):
                    # Stubs for this module: its collected script methods if
                    # its class has them, otherwise the inferred set.
                    cls = type(module)
                    if hasattr(cls, "_methods"):
                        return [v for k, v in sorted(cls._methods.items())]
                    else:
                        return infer_methods_to_compile(module)

                self.__dict__[
                    "_actual_script_module"
                ] = torch.jit._recursive.create_script_module(self, make_stubs, share_types=not added_methods_in_init)

                # Delete the Python attributes that now shadow the ScriptModule
                # ones, so that __getattr__ and __setattr__ will properly find
                # the scripted versions.
                concrete_type = self._actual_script_module._concrete_type
                for name in concrete_type.get_attributes():
                    delattr(self, name)
                for name, _ in concrete_type.get_modules():
                    delattr(self, name)
                for name in ("_parameters", "_buffers", "_modules"):
                    delattr(self, name)

        cls.__init__ = init_then_script  # type: ignore
        return super(ScriptMeta, cls).__init__(name, bases, attrs)
class _CachedForward(object):
    """Non-data descriptor installed as ``forward`` on ScriptModule classes.

    It shadows ``nn.Module.forward`` so that looking up ``forward`` on an
    instance ends up in the owning class's ``__getattr__`` instead of the
    inherited Python method.
    """

    def __get__(self, obj, cls):
        # NOTE(review): ``self`` is the descriptor here. Neither this class nor
        # ``object`` defines ``__getattr__``, so this expression raises
        # AttributeError. Python's attribute lookup then falls back to the
        # owner's ``__getattr__`` (e.g. RecursiveScriptModule.__getattr__,
        # which resolves and caches the compiled ``forward``). That fallback
        # appears to be what makes this descriptor work -- confirm before
        # "fixing" it to ``obj.__getattr__``.
        return self.__getattr__("forward")  # type: ignore
class ScriptWarning(Warning):
    """``Warning`` subclass defined by the TorchScript scripting frontend."""
def script_method(fn):
    """Mark ``fn`` as a TorchScript method of a ``ScriptModule`` subclass.

    With the JIT enabled, this returns a ``ScriptMethodStub``. ``ScriptMeta``
    collects these stubs from the class body and compiles them after the
    user's ``__init__`` has run. With the JIT disabled, ``fn`` is returned
    unchanged.
    """
    if _enabled:
        # NOTE: we need to traverse two frames here because the meta-class frame
        # for ScriptModule will be present, as opposed to invoking @script on a
        # a function or invoking define() on a CompilationUnit.
        # The stack will look like:
        #
        # 0. createResolutionCallback()
        # 1. script_method()
        # 2. ScriptModule metaclass frame
        # 3. Surrounding scope
        #
        # createResolutionCallback internally adds 1 to get us to the scope of this
        # function (the calling function). Adding 2 gets us to the proper surrounding scope.
        # This call must stay directly in this function so the frame count holds.
        resolution_cb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2)
        method_def = get_jit_def(fn, fn.__name__, self_name="ScriptModule")
        return ScriptMethodStub(resolution_cb, method_def, fn)
    return fn
class _ConstMapLookupError(AttributeError, KeyError):
    """Raised by ``ConstMap`` for a name that is not in its mapping.

    It derives from AttributeError so that ``hasattr``,
    ``getattr(obj, name, default)`` and ``copy``/``pickle`` protocol lookups
    behave normally. It also derives from KeyError so that existing callers
    catching KeyError keep working.
    """


class ConstMap:
    """Attribute-style read access to a mapping of constants.

    ``ConstMap({"c0": v}).c0`` returns ``v``.

    Args:
        const_mapping: mapping from constant name to value.
    """

    def __init__(self, const_mapping):
        self.const_mapping = const_mapping

    def __getattr__(self, attr):
        # __getattr__ runs only after normal lookup fails. Read the mapping
        # straight from __dict__: an instance created via __new__ without
        # __init__ (as copy/pickle do) has no ``const_mapping`` yet, and
        # ``self.const_mapping`` would re-enter __getattr__ endlessly.
        try:
            mapping = self.__dict__["const_mapping"]
        except KeyError:
            raise _ConstMapLookupError(attr) from None
        try:
            return mapping[attr]
        except KeyError:
            raise _ConstMapLookupError(attr) from None
if _enabled:
    # this is a Python 'non-data descriptor' that causes the first access
    # to ScriptModule's forward to lookup the forward method and stash
    # it in the objects dict. Due to the standard rules for attribute lookup
    # subsequent lookups will just directly return the previously looked up method.
    # This is necessary because nn.Module defines forward as a method. If we
    # did nothing __getattr__ would not be called. Instead we'd get nn.Module.forward
    # which always throws an exception.
    class ScriptModule(with_metaclass(ScriptMeta, Module)):  # type: ignore
        r"""
        A wrapper around C++ ``torch::jit::Module``. ``ScriptModule``\s
        contain methods, attributes, parameters, and
        constants. These can be accessed the same as on a normal ``nn.Module``.
        """

        # NOTE(review): presumably consumed by the recursive scripting code so
        # these Python-only properties are not compiled -- confirm in
        # torch.jit._recursive.
        __jit_unused_properties__ = ['code', 'code_with_constants', 'graph', 'inlined_graph', 'original_name']

        def __init__(self):
            super(ScriptModule, self).__init__()

        forward = _CachedForward()

        def __getattr__(self, attr):
            # Until ScriptMeta's wrapped __init__ has attached the compiled
            # module, behave like a plain nn.Module; afterwards delegate.
            if "_actual_script_module" not in self.__dict__:
                return super(ScriptModule, self).__getattr__(attr)
            return getattr(self._actual_script_module, attr)

        def __setattr__(self, attr, value):
            if "_actual_script_module" not in self.__dict__:
                # Unwrap torch.jit.Attribute into a regular setattr + recording
                # the provided type in __annotations__.
                #
                # This ensures that if we use the attr again in `__init__`, it
                # will look like the actual value, not an instance of Attribute.
                if isinstance(value, Attribute):
                    # NB: Ensure that we set __annotations__ on the specific
                    # class in question, and not on a superclass (which would
                    # be wrong wrong wrong!).
                    # See also https://github.com/pytorch/pytorch/issues/39463
                    if "__annotations__" not in self.__class__.__dict__:
                        self.__class__.__annotations__ = {}
                    self.__annotations__[attr] = value.type
                    value = value.value
                return super(ScriptModule, self).__setattr__(attr, value)

            # After construction, all writes go to the compiled module.
            setattr(self._actual_script_module, attr, value)

        def define(self, src):
            """Compile TorchScript source ``src`` as a method of this module."""
            if "_actual_script_module" in self.__dict__:
                # If we have completed initialization, just defer to the
                # backing RecursiveScriptModule to eagerly compile the provided
                # source.
                return self._actual_script_module.define(src)

            # Otherwise, we are still in the object's __init__.
            # In that case, add `src` as a stub to be compiled.
            #
            # We use frames_up=1 to get to the proper surrounding scope. The stack
            # will look like:
            # 0. createResolutionCallback
            # 1. define()
            # 2. surrounding scope.
            #
            # createResolutionCallback internally adds 1 to get us to our frame, then
            # we add 1 to get to the proper surrounding scope.
            rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1)
            ast = torch._C._parse_source_def(src)
            # self._methods is the class-level dict built by ScriptMeta.
            self._methods[ast.name().name] = ScriptMethodStub(rcb, ast, None)

        def _replicate_for_data_parallel(self):
            return self._actual_script_module._replicate_for_data_parallel()

    class RecursiveScriptModule(ScriptModule):
        # XXX: RecursiveScriptModule inherits from ScriptModule for the sole
        # reason that it retains the existing isinstance(ScriptModule)
        # behavior.
        r"""
        The core data structure in TorchScript is the ``ScriptModule``. It is an
        analogue of torch's ``nn.Module`` and represents an entire model as a tree of
        submodules. Like normal modules, each individual module in a ``ScriptModule`` can
        have submodules, parameters, and methods. In ``nn.Module``\s methods are implemented
        as Python functions, but in ``ScriptModule``\s methods are implemented as
        TorchScript functions, a statically-typed subset of Python that contains all
        of PyTorch's built-in Tensor operations. This difference allows your
        ``ScriptModule``\s code to run without the need for a Python interpreter.

        ``ScriptModule``\s should not be created manually, instead use
        either :func:`tracing <torch.jit.trace>` or :func:`scripting <torch.jit.script>`.
        Tracing and scripting can be applied incrementally and :ref:`composed as necessary <Types>`.

        * Tracing records the tensor operations as executed with a set of example inputs and uses these
          operations to construct a computation graph. You can use the full dynamic behavior of Python with tracing,
          but values other than Tensors and control flow aren't captured in the graph.

        * Scripting inspects the Python code of the model
          and compiles it to TorchScript. Scripting allows the use of many `types`_ of values and supports dynamic control flow.
          Many, but not all features of Python are supported by the compiler, so changes to the source code may be necessary.
        """

        # Tells ScriptMeta not to wrap __init__ for this class.
        _disable_script_meta = True

        def __init__(self, cpp_module):
            # Set directly in __dict__ so __getattr__/__setattr__ see the
            # "still initializing" state before anything else is assigned.
            self.__dict__["_initializing"] = True
            self._c = cpp_module
            super(RecursiveScriptModule, self).__init__()
            # Delete the 'training' attribute set up by `Module.__init__`. It
            # will get set on the underlying cpp module, so we delete it here
            # to avoid this version shadowing the cpp module version.
            delattr(self, "training")

        @staticmethod
        def _construct(cpp_module, init_fn):
            """
            Construct a RecursiveScriptModule that's ready for use. PyTorch
            code should use this to construct a RecursiveScriptModule instead
            of calling `__init__` directly, as it makes sure the
            object is properly finalized (and in the future we may take
            control of how the RecursiveScriptModule instance is created).

            Args:
                cpp_module:  The C++ Module that will hold the actual state of
                             this RecursiveScriptModule instance.
                init_fn:  Lambda that initializes the RecursiveScriptModule passed to it.
            """
            script_module = RecursiveScriptModule(cpp_module)
            init_fn(script_module)

            # Finalize the ScriptModule: replace the nn.Module state with our
            # custom implementations and flip the _initializing bit.
            RecursiveScriptModule._finalize_scriptmodule(script_module)
            return script_module

        @staticmethod
        def _finalize_scriptmodule(script_module):
            # Swap the plain nn.Module dicts for views backed by the C++ module.
            script_module._parameters = OrderedDictWrapper(
                torch._C.ParameterDict(script_module._c)
            )
            script_module._buffers = OrderedDictWrapper(
                torch._C.BufferDict(script_module._c)
            )
            script_module._modules = OrderedModuleDict(
                script_module._c, script_module._modules
            )
            script_module._initializing = False

        def _reconstruct(self, cpp_module):
            """
            Re-construct an instance of RecursiveScriptModule using an instance of a C++ module.

            Args:
                cpp_module: The C++ module that this RecursiveScriptModule will be rebuilt around.
            """
            self.__init__(cpp_module)  # type: ignore

            # Copy the concrete type from the C++ module to this ScriptModule.
            self._concrete_type = torch._C.ConcreteModuleType.from_jit_type(
                self._c._type()
            )

            # Copy submodules from the C++ module to this ScriptModule.
            modules = {}
            for name, cpp_module in torch._C.ModuleDict(self._c).items():
                modules[name] = wrap_cpp_module(cpp_module)
            self._modules = OrderedModuleDict(self._c, modules)

            # Copy parameters and buffers.
            self._parameters = OrderedDictWrapper(torch._C.ParameterDict(self._c))
            self._buffers = OrderedDictWrapper(torch._C.BufferDict(self._c))

            # Get rid of the functions from the old C++ module.
            # (These are ScriptMethods cached in __dict__ by __getattr__.)
            self.__dict__ = {
                k: v
                for k, v in self.__dict__.items()
                if not isinstance(v, torch._C.ScriptMethod)
            }
            self.__dict__["_initializing"] = False

        @property
        def graph(self):
            r"""
            Returns a string representation of the internal graph for the
            ``forward`` method. See :ref:`interpreting-graphs` for details.
            """
            return self._c._get_method("forward").graph

        @property
        def inlined_graph(self):
            r"""
            Returns a string representation of the internal graph for the
            ``forward`` method. This graph will be preprocessed to inline all function and method calls.
            See :ref:`interpreting-graphs` for details.
            """
            return self.forward.inlined_graph

        @property
        def code(self):
            r"""
            Returns a pretty-printed representation (as valid Python syntax) of
            the internal graph for the ``forward`` method. See
            :ref:`inspecting-code` for details.
            """
            return self.forward.code

        @property
        def code_with_constants(self):
            r"""
            Returns a tuple of:

            [0] a pretty-printed representation (as valid Python syntax) of
            the internal graph for the ``forward`` method. See `code`.
            [1] a ConstMap following the CONSTANT.cN format of the output in [0].
            The indices in the [0] output are keys to the underlying constant's values.

            See :ref:`inspecting-code` for details.
            """
            r = self.forward.code_with_constants
            return (r[0], ConstMap(r[1]))

        def save(self, f, **kwargs):
            r"""
            save(f, _extra_files={})

            See :func:`torch.jit.save <torch.jit.save>` for details.
            """
            # ``f`` is converted with str() before being handed to C++.
            return self._c.save(str(f), **kwargs)

        def _save_for_lite_interpreter(self, *args, **kwargs):
            r"""
            _save_for_lite_interpreter(f)

            Add (or update) the bytecode session to the script model. The updated model is used
            in lite interpreter for mobile applications.

            Args:
                f: a string containing a file name.
                _extra_files: Map from filename to contents which will be stored as part of 'f'.

            """
            return self._c._save_for_mobile(*args, **kwargs)

        def _save_to_buffer_for_lite_interpreter(self, *args, **kwargs):
            return self._c._save_to_buffer_for_mobile(*args, **kwargs)

        def save_to_buffer(self, *args, **kwargs):
            return self._c.save_to_buffer(*args, **kwargs)

        def get_debug_state(self, *args, **kwargs):
            # NOTE(review): *args/**kwargs are accepted but not forwarded.
            return self._c.get_debug_state()

        def extra_repr(self):
            return "original_name={}".format(self.original_name)

        def graph_for(self, *args, **kwargs):
            return self.forward.graph_for(*args, **kwargs)

        @property
        def original_name(self):
            # NOTE(review): this compares a type object with a str, so the
            # condition is always False and the type name is always returned.
            # The intended comparison is unclear -- confirm before changing,
            # since extra_repr output depends on it.
            if type(self) == str(self._c._type().name()):
                return ""
            return str(self._c._type().name())

        def define(self, src):
            """Compile TorchScript source ``src`` as a method of this module."""
            # We use frames_up=1 to get to the proper surrounding scope. The stack
            # will look like:
            # 0. createResolutionCallback
            # 1. define()
            # 2. surrounding scope.
            #
            # createResolutionCallback internally adds 1 to get us to our frame, then
            # we add 1 to get to the proper surrounding scope.
            rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=1)
            self._c._define(self._concrete_type, src, rcb)

        def __getattr__(self, attr):
            if "_initializing" not in self.__dict__:
                raise RuntimeError(
                    "ScriptModule has not been initialized, did you forget to call super's init?"
                )

            if self._initializing:
                return super(RecursiveScriptModule, self).__getattr__(attr)

            # _modules check is before hasattr since modules are included as attributes in _c,
            # but we want to get the python wrapper from _modules instead of the raw _c object.
            if attr in self._modules:
                return self._modules[attr]
            elif self._c.hasattr(attr):
                return self._c.getattr(attr)
            elif self._c._has_method(attr):
                script_method = self._c._get_method(attr)
                # cache method so future calls do not go through __getattr__
                # to improve invocation performance
                self.__dict__[attr] = script_method
                return script_method

            return super(RecursiveScriptModule, self).__getattr__(attr)

        def __setattr__(self, attr, value):
            if self._initializing:
                return super(RecursiveScriptModule, self).__setattr__(attr, value)

            if attr in self._modules:
                self._modules[attr] = value
            elif self._c.hasattr(attr):
                self._c.setattr(attr, value)
            elif (
                hasattr(self, "_concrete_type")
                and attr in self._concrete_type.get_constants().keys()
            ):
                # TODO: we don't have _concrete_type set after load(), and in general we lose constant information.
                # We should encode constants as class type attributes (or something) so it persists across save/load.
                raise AttributeError(
                    "Cannot mutate TorchScript constant value: '{}'. Value: '{}'".format(
                        attr, value
                    )
                )
            else:
                # We allow setting Python attributes on the ScriptModule, for
                # when people want to stash some convenience info on it.
                # TODO: it's possible that the following is confusing:
                #   s = torch.jit.script(...)
                #   s.python_attr = ...
                #   s.save()   <--- this doesn't have `python_attr`
                # It's fairly trivial to save enough info to warn in this case.
                return super(RecursiveScriptModule, self).__setattr__(attr, value)

        def __getstate__(self):
            raise pickle.PickleError(
                "ScriptModules cannot be deepcopied using copy.deepcopy or saved using torch.save. "
                + "Mixed serialization of script and non-script modules is not supported. "
                + "For purely script modules use my_script_module.save(<filename>) instead."
            )

        def __copy__(self):
            return torch.jit._recursive.wrap_cpp_module(copy.copy(self._c))

        def __deepcopy__(self, memo):
            return torch.jit._recursive.wrap_cpp_module(copy.deepcopy(self._c, memo))

        # Python magic methods do method lookups on an object's class type, instead of looking up
        # the methods defined on the class instance. In order to continue to expose the magic methods
        # of builtin-containers (ModuleList, Sequential, ModuleDict) to python we
        # define magic methods here as a shim to the correct attribute.
        def forward_magic_method(self, method_name, *args, **kwargs):
            self_method = getattr(self, method_name)
            # If lookup found this class's own shim (no override), the
            # operation is unsupported.
            if getattr(self_method, "__func__", None) == getattr(
                RecursiveScriptModule, method_name
            ):
                raise NotImplementedError()
            return self_method(*args, **kwargs)

        def __iter__(self):
            return self.forward_magic_method("__iter__")

        def __getitem__(self, idx):
            return self.forward_magic_method("__getitem__", idx)

        def __len__(self):
            return self.forward_magic_method("__len__")

        def __contains__(self, key):
            return self.forward_magic_method("__contains__", key)

        # dir is defined by the base nn.Module, so instead of throwing if
        # it is not overridden, we call into the nn.Module __dir__ method
        def __dir__(self):
            self_method = self.__dir__
            if self_method.__func__ == get_function_from_type(  # type: ignore
                RecursiveScriptModule, "__dir__"
            ):
                return super(RecursiveScriptModule, self).__dir__()
            return self_method()

        # to resolve bool(value), Python looks if __bool__ is defined then __iter__
        # is defined then returns true for classes. Since __iter__() on this
        # class throws if it isn't overridden, we define __bool__ to preserve default behavior
        def __bool__(self):
            self_method = self.__bool__
            if self_method.__func__ == get_function_from_type(  # type: ignore
                RecursiveScriptModule, "__bool__"
            ):
                return True
            return self_method()

        def _replicate_for_data_parallel(self):
            # we have to initialize ScriptModule properly so that
            # it works with pybind11
            def init_fn(script_module):
                # Don't do anything here, we'll initialize the ScriptModule below
                return

            return RecursiveScriptModule._construct(
                self._c._replicate_for_data_parallel(), init_fn
            )

    # Need to copy all RecursiveScriptModule methods to ScriptModule.
    #
    # This is because `super(MyScriptModule, self).foo()` does not use
    # `__getattr__` to look up `foo`. So we need to make each method available on
    # the ScriptModule manually.
    #
    # NOTE(review): staticmethod objects are callable only from Python 3.10
    # on, so whether _construct/_finalize_scriptmodule get copied depends on
    # the interpreter version.
    for name, item in RecursiveScriptModule.__dict__.items():
        if not callable(item) and not isinstance(item, property):
            continue
        if name.startswith("__") or hasattr(ScriptModule, name):
            continue
        # We can copy over the implementation wholesale because besides the
        # `super()` thing above, ScriptModule behaves exactly like
        # RecursiveScriptModule
        setattr(ScriptModule, name, item)

    def _get_methods(cls):
        """Return ``(name, member)`` pairs for the functions and methods of ``cls``."""
        import inspect

        # In Python 3 unbound methods are functions, but in Python 2 they are methods
        return inspect.getmembers(
            cls, predicate=lambda x: inspect.isfunction(x) or inspect.ismethod(x)
        )

    # nn.Module methods that remain usable on a RecursiveScriptModule; every
    # other non-dunder nn.Module method not defined by RecursiveScriptModule is
    # replaced by a stub that raises (see the loop below).
    _compiled_methods_allowlist = {
        "forward",
        "register_buffer",
        "register_parameter",
        "add_module",
        "_apply",
        "apply",
        "cuda",
        "cpu",
        "to",
        "type",
        "float",
        "double",
        "half",
        "state_dict",
        "_save_to_state_dict",
        "load_state_dict",
        "_load_from_state_dict",
        "_named_members",
        "parameters",
        "named_parameters",
        "buffers",
        "named_buffers",
        "children",
        "named_children",
        "modules",
        "named_modules",
        "zero_grad",
        "share_memory",
        "_get_name",
        "extra_repr",
        "_slow_forward",
        "_tracing_name",
        "eval",
        "train",
    }

    def _make_fail(name):
        """Build a method that raises RuntimeError naming the unsupported ``name``."""
        def fail(self, *args, **kwargs):
            raise RuntimeError(name + " is not supported on ScriptModules")

        return fail

    for name, method in _get_methods(torch.nn.Module):
        if name.startswith("__"):
            continue
        if (
            name not in RecursiveScriptModule.__dict__
            and name not in _compiled_methods_allowlist
        ):
            setattr(RecursiveScriptModule, method.__name__, _make_fail(name))

else:
    # Minimal stand-ins used when the JIT is disabled: plain nn.Modules that
    # accept (and ignore) a single optional constructor argument.
    # TODO MAKE SURE THAT DISABLING WORKS
    class ScriptModule(torch.nn.Module):  # type: ignore
        def __init__(self, arg=None):
            super().__init__()

    class RecursiveScriptModule(ScriptModule):  # type: ignore
        def __init__(self, arg=None):
            super().__init__()
def call_prepare_scriptable_func_impl(obj, memo):
    """Recursively apply ``__prepare_scriptable__`` hooks to a module tree.

    Values that are not ``nn.Module`` instances are returned unchanged. A module
    is replaced by the result of its ``__prepare_scriptable__()`` hook, if it has
    one. Every submodule of the result is then prepared the same way and written
    back onto it. Submodules are found in ``_modules`` and, except for
    ScriptModules, as values stored directly in ``__dict__``.

    Args:
        obj: the value to prepare.
        memo: maps ``id()`` of each already-visited original module to its
            prepared replacement; prevents infinite recursion on cycles.

    Returns:
        The prepared object (``obj`` itself if no hook replaced it).
    """
    if not isinstance(obj, torch.nn.Module):
        return obj

    obj_id = id(obj)

    # If obj_id is in memo, obj has already been prepared or is being
    # prepared in another call up the stack.
    if obj_id in memo:
        return memo[obj_id]

    obj = obj.__prepare_scriptable__() if hasattr(obj, '__prepare_scriptable__') else obj  # type: ignore
    # Record obj in memo to avoid infinite recursion in the case of cycles in the module
    # hierarchy when recursing below.
    memo[obj_id] = obj

    new_obj_dict = {}
    for name, sub_module in obj.__dict__.items():
        if name == '_modules':
            # The submodule container is updated in place.
            for k, v in sub_module.items():
                sub_module[k] = call_prepare_scriptable_func_impl(v, memo)
            new_obj_dict[name] = sub_module
        elif isinstance(sub_module, torch.nn.Module) and not isinstance(sub_module, ScriptModule):
            new_obj_dict[name] = call_prepare_scriptable_func_impl(sub_module, memo)
        else:
            new_obj_dict[name] = sub_module

    # Write each prepared value back under its own attribute name. Only
    # existing keys are assigned, so the dict's size never changes.
    for k, v in new_obj_dict.items():
        obj.__dict__[k] = v

    return obj
def call_prepare_scriptable_func(obj):
    """Run ``__prepare_scriptable__`` hooks over ``obj`` and its submodules.

    Starts the recursive walk in ``call_prepare_scriptable_func_impl`` with a
    fresh, empty memo and returns whatever that walk returns.
    """
    visited: Dict[int, torch.nn.Module] = {}
    return call_prepare_scriptable_func_impl(obj, visited)
def script(obj, optimize=None, _frames_up=0, _rcb=None):
    r"""
    Scripting a function or ``nn.Module`` will inspect the source code, compile
    it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or
    :class:`ScriptFunction`. TorchScript itself is a subset of the Python language, so not all
    features in Python work, but we provide enough functionality to compute on
    tensors and do control-dependent operations. For a complete guide, see the
    :ref:`language-reference`.

    ``torch.jit.script`` can be used as a function for modules and functions, and as a decorator
    ``@torch.jit.script`` for :ref:`torchscript-classes` and functions.

    Args:
        obj (callable, class, or ``nn.Module``): The ``nn.Module``, function, or class type to
            compile.
        optimize: Deprecated and ignored; passing anything other than ``None``
            emits a warning.
        _frames_up (int): Internal. Number of extra stack frames between the
            caller's scope and this call, used when building the resolution
            callback for a class.
        _rcb: Internal. Resolution callback to use instead of one derived
            from the caller's frame or the function's closure.

    Returns:
        If ``obj`` is ``nn.Module``, ``script`` returns
        a :class:`ScriptModule` object. The returned :class:`ScriptModule` will
        have the same set of sub-modules and parameters as the
        original ``nn.Module``. If ``obj`` is a standalone function,
        a :class:`ScriptFunction` will be returned. If ``obj`` is a class, it is
        compiled and registered as a TorchScript class and ``obj`` itself is
        returned. If ``obj`` is already a :class:`ScriptModule` (or the JIT is
        disabled), ``obj`` is returned unchanged.

    **Scripting a function**
        The ``@torch.jit.script`` decorator will construct a :class:`ScriptFunction`
        by compiling the body of the function.

        Example (scripting a function):

        .. testcode::

            import torch

            @torch.jit.script
            def foo(x, y):
                if x.max() > y.max():
                    r = x
                else:
                    r = y
                return r

            print(type(foo))  # torch.jit.ScriptFunction

            # See the compiled graph as Python code
            print(foo.code)

            # Call the function using the TorchScript interpreter
            foo(torch.ones(2, 2), torch.ones(2, 2))

        .. testoutput::
            :hide:

            ...

    **Scripting an nn.Module**
        Scripting an ``nn.Module`` by default will compile the ``forward`` method and recursively
        compile any methods, submodules, and functions called by ``forward``. If a ``nn.Module`` only uses
        features supported in TorchScript, no changes to the original module code should be necessary. ``script``
        will construct :class:`ScriptModule` that has copies of the attributes, parameters, and methods of
        the original module.

        Example (scripting a simple module with a Parameter):

        .. testcode::

            import torch

            class MyModule(torch.nn.Module):
                def __init__(self, N, M):
                    super(MyModule, self).__init__()
                    # This parameter will be copied to the new ScriptModule
                    self.weight = torch.nn.Parameter(torch.rand(N, M))

                    # When this submodule is used, it will be compiled
                    self.linear = torch.nn.Linear(N, M)

                def forward(self, input):
                    output = self.weight.mv(input)

                    # This calls the `forward` method of the `nn.Linear` module, which will
                    # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
                    output = self.linear(output)
                    return output

            scripted_module = torch.jit.script(MyModule(2, 3))

        Example (scripting a module with traced submodules):

        .. testcode::

            import torch
            import torch.nn as nn
            import torch.nn.functional as F

            class MyModule(nn.Module):
                def __init__(self):
                    super(MyModule, self).__init__()
                    # torch.jit.trace produces a ScriptModule's conv1 and conv2
                    self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
                    self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

                def forward(self, input):
                    input = F.relu(self.conv1(input))
                    input = F.relu(self.conv2(input))
                    return input

            scripted_module = torch.jit.script(MyModule())

        To compile a method other than ``forward`` (and recursively compile anything it calls), add
        the :func:`@torch.jit.export <torch.jit.export>` decorator to the method. To opt out of compilation
        use :func:`@torch.jit.ignore <torch.jit.ignore>` or :func:`@torch.jit.unused <torch.jit.unused>`.

        Example (an exported and ignored method in a module)::

            import torch
            import torch.nn as nn

            class MyModule(nn.Module):
                def __init__(self):
                    super(MyModule, self).__init__()

                @torch.jit.export
                def some_entry_point(self, input):
                    return input + 10

                @torch.jit.ignore
                def python_only_fn(self, input):
                    # This function won't be compiled, so any
                    # Python APIs can be used
                    import pdb
                    pdb.set_trace()

                def forward(self, input):
                    if self.training:
                        self.python_only_fn(input)
                    return input * 99

            scripted_module = torch.jit.script(MyModule())
            print(scripted_module.some_entry_point(torch.randn(2, 2)))
            print(scripted_module(torch.randn(2, 2)))
    """
    if not _enabled:
        return obj

    if optimize is not None:
        # stacklevel=2 attributes the warning to the caller of script().
        warnings.warn(
            "`optimize` is deprecated and has no effect. "
            "Use `with torch.jit.optimized_execution()` instead",
            stacklevel=2,
        )

    if isinstance(obj, ScriptModule):
        return obj

    if isinstance(obj, torch.nn.Module):
        obj = call_prepare_scriptable_func(obj)
        return torch.jit._recursive.create_script_module(
            obj, torch.jit._recursive.infer_methods_to_compile
        )

    qualified_name = _qualified_name(obj)
    if inspect.isclass(obj):
        # If this type is a `nn.Module` subclass, they probably meant to pass
        # an instance instead of a Module
        if issubclass(obj, torch.nn.Module):
            raise RuntimeError(
                "Type '{}' cannot be compiled since it inherits"
                " from nn.Module,"
                " pass an instance instead".format(obj)
            )

        if not _is_new_style_class(obj):
            raise RuntimeError(
                "TorchScript classes must be new-style classes. "
                "Please inherit from 'object'."
            )
        if len(obj.mro()) > 2:
            raise RuntimeError(
                "TorchScript classes does not support inheritance yet. "
                "Please directly inherit from 'object'."
            )
        if _rcb is None:
            # Must be called directly in this frame: the frame count assumes it.
            _rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
        _compile_and_register_class(obj, _rcb, qualified_name)
        return obj
    else:
        # this is a decorated fn, and we need to get the underlying fn and its rcb
        if hasattr(obj, "__script_if_tracing_wrapper"):
            obj = obj.__original_fn
            _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)

        _check_directly_compile_overloaded(obj)
        maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
        if maybe_already_compiled_fn:
            return maybe_already_compiled_fn
        ast = get_jit_def(obj, obj.__name__)
        if _rcb is None:
            _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
        fn = torch._C._jit_script_compile(
            qualified_name, ast, _rcb, get_default_args(obj)
        )
        # Forward docstrings
        fn.__doc__ = obj.__doc__
        _set_jit_function_cache(obj, fn)
        return fn
# overloads are registered in _jit_internal and compiled here so that _overload
# can be used in nn/functional.py without an import cycle
def _check_overload_defaults(impl_defaults, overload_defaults, loc):
for name, overload_value in overload_defaults.items():
if name not in impl_defaults or impl_defaults[name] != overload_value:
raise torch.jit.frontend.FrontendError(
loc,
"Default parameters on overloads do not affect the runtime so they "
"must equal to the default parameter on the implementation function. Found on "
"parameter {name}".format(name=name),
)