1 change: 0 additions & 1 deletion .lintrunner.toml
@@ -2119,7 +2119,6 @@ exclude_patterns = [
'torch/fx/tensor_type.py',
'torch/fx/traceback.py',
'torch/hub.py',
-'torch/jit/_script.py', # "Callable[[], Any]" has no attribute "__func__"
'torch/jit/frontend.py', # "expr" has no attribute "id"
'torch/library.py',
'torch/linalg/__init__.py',
151 changes: 98 additions & 53 deletions torch/jit/_script.py
@@ -6,44 +6,51 @@
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
-import functools
import collections
+import copy
import enum
+import functools
import inspect
-import copy
import pickle
import warnings
-from typing import Any, Dict, List, Set, Tuple, Union, Callable

+from typing import Any, Callable, Dict, List, Set, Tuple, Union

import torch
import torch._jit_internal as _jit_internal
-from torch.utils import set_module
-from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module, infer_methods_to_compile, _compile_and_register_class
-from torch.nn import Module
-from torch.jit._state import _enabled
-from torch.jit._builtins import _register_builtin
-from torch.jit.frontend import get_jit_def, get_default_args, get_jit_class_def
+from torch._classes import classes
from torch._jit_internal import _qualified_name
+from torch.jit._builtins import _register_builtin
from torch.jit._fuser import _graph_for, _script_method_graph_for

+from torch.jit._monkeytype_config import (
+JitTypeTraceConfig,
+JitTypeTraceStore,
+monkeytype_trace,
+)
+from torch.jit._recursive import (
+_compile_and_register_class,
+infer_methods_to_compile,
+ScriptMethodStub,
+wrap_cpp_module,
+)
from torch.jit._state import (
-_try_get_jit_cached_function,
-_try_get_jit_cached_overloads,
+_enabled,
_set_jit_function_cache,
_set_jit_overload_cache,
+_try_get_jit_cached_function,
+_try_get_jit_cached_overloads,
)
+from torch.jit.frontend import get_default_args, get_jit_class_def, get_jit_def
+from torch.nn import Module
from torch.overrides import (
-has_torch_function, has_torch_function_unary, has_torch_function_variadic)
+has_torch_function,
+has_torch_function_unary,
+has_torch_function_variadic,
+)
from torch.package import PackageExporter, PackageImporter
+from torch.utils import set_module
from ._serialization import validate_map_location

-from torch.jit._monkeytype_config import (
-monkeytype_trace,
-JitTypeTraceConfig ,
-JitTypeTraceStore
-)
-from torch._classes import classes

type_trace_db = JitTypeTraceStore() # DB to hold all call traces from MonkeyType

torch._C.ScriptMethod.graph_for = _script_method_graph_for # type: ignore[attr-defined]
@@ -55,11 +62,13 @@
"""
set_module(ScriptFunction, "torch.jit")


# Throws an error if a jit function is pickled.
# Helps to avoid Python crashes for Python versions 3.9.5+ when protocol 0 or 1 is given as an argument.
def _reduce(cls):
raise pickle.PickleError("ScriptFunction cannot be pickled")


ScriptFunction.__reduce__ = _reduce # type: ignore[assignment]
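# Illustrative sketch (editor's note, not part of this diff): with the
# __reduce__ hook above installed, pickling a scripted function raises a
# PickleError instead of crashing the interpreter. `fn` is a hypothetical
# scripted function used only for this example.
#
#   import pickle
#
#   @torch.jit.script
#   def fn(x: torch.Tensor) -> torch.Tensor:
#       return x + 1
#
#   try:
#       pickle.dumps(fn, protocol=0)
#   except pickle.PickleError:
#       pass  # "ScriptFunction cannot be pickled"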


Expand All @@ -70,6 +79,7 @@ def _reduce(cls):
def Attribute(value, type): # type: ignore[no-redef]
return value


Attribute.__doc__ = """
This method is a pass-through function that returns `value`, mostly
used to indicate to the TorchScript compiler that the left-hand side
@@ -145,10 +155,12 @@ def __init__(self):
Returns `value`
"""


def _get_type_trace_db():
# This is a private API. Use of this for external purposes is discouraged.
return type_trace_db


# Gets a function from the name of a method on a type
def _get_function_from_type(cls, name):
return getattr(cls, name, None)
@@ -302,7 +314,9 @@ def make_stubs(module):

self.__dict__[
"_actual_script_module"
-] = torch.jit._recursive.create_script_module(self, make_stubs, share_types=not added_methods_in_init)
+] = torch.jit._recursive.create_script_module(
+self, make_stubs, share_types=not added_methods_in_init
+)

# Delete the Python attributes that now shadow the ScriptModule
# ones, so that __getattr__ and __setattr__ will properly find
@@ -356,7 +370,9 @@ def __getattr__(self, attr):
return self.const_mapping[attr]


-def unpackage_script_module(importer: PackageImporter, script_module_id: str) -> torch.nn.Module:
+def unpackage_script_module(
+importer: PackageImporter, script_module_id: str
+) -> torch.nn.Module:
"""
Called by ``torch.package.PackageImporter``'s Pickler's ``persistent_load`` function.
Performs work of loading and returning a ScriptModule from a ``torch.package`` archive.
@@ -421,13 +437,17 @@ class RecursiveScriptClass:
_props [Dict[str, property]]: A dictionary of properties fetched from self._c and
exposed on this wrapper.
"""

def __init__(self, cpp_class):
super().__init__()
self.__dict__["_initializing"] = True
self._c = cpp_class

# Add wrapped object's properties to this class instance.
-self._props = {prop.name: property(prop.getter, prop.setter) for prop in self._c._properties()}
+self._props = {
+prop.name: property(prop.getter, prop.setter)
+for prop in self._c._properties()
+}

self.__dict__["_initializing"] = False

@@ -467,8 +487,8 @@ def __iadd__(self, other):
else:
return self.forward_magic_method("__add__", other)


for method_name in _magic_methods:

def method_template(self, *args, **kwargs):
return self.forward_magic_method(method_name, *args, **kwargs)

@@ -488,7 +508,13 @@ class ScriptModule(Module, metaclass=ScriptMeta):
contain methods, attributes, parameters, and
constants. These can be accessed the same way as on a normal ``nn.Module``.
"""
-__jit_unused_properties__ = ['code', 'code_with_constants', 'graph', 'inlined_graph', 'original_name']
+__jit_unused_properties__ = [
+"code",
+"code_with_constants",
+"graph",
+"inlined_graph",
+"original_name",
+]

def __init__(self):
super().__init__()
@@ -843,8 +869,9 @@ def __contains__(self, key):
# it is not overridden, we call into the nn.Module __dir__ method
def __dir__(self):
self_method = self.__dir__
-if self_method.__func__ == _get_function_from_type( # type: ignore[attr-defined]
-RecursiveScriptModule, "__dir__"
+if (
+self_method.__func__ # type: ignore[attr-defined]
+== _get_function_from_type(RecursiveScriptModule, "__dir__")
):
return super().__dir__()
return self_method()
@@ -854,8 +881,9 @@ def __dir__(self):
# class throws if it isn't overridden, we define __bool__ to preserve default behavior
def __bool__(self):
self_method = self.__bool__
-if self_method.__func__ == _get_function_from_type( # type: ignore[attr-defined]
-RecursiveScriptModule, "__bool__"
+if (
+self_method.__func__ # type: ignore[attr-defined]
+== _get_function_from_type(RecursiveScriptModule, "__bool__")
):
return True
return self_method()
@@ -931,7 +959,7 @@ def _get_methods(cls):
"eval",
"train",
"get_extra_state",
"set_extra_state"
"set_extra_state",
}

def _make_fail(name):
@@ -963,6 +991,7 @@ class RecursiveScriptModule(ScriptModule): # type: ignore[no-redef]
def __init__(self, arg=None):
super().__init__()


def call_prepare_scriptable_func_impl(obj, memo):
if not isinstance(obj, torch.nn.Module):
return obj
@@ -974,19 +1003,21 @@ def call_prepare_scriptable_func_impl(obj, memo):
if obj_id in memo:
return memo[id(obj)]

-obj = obj.__prepare_scriptable__() if hasattr(obj, '__prepare_scriptable__') else obj # type: ignore[operator]
+obj = obj.__prepare_scriptable__() if hasattr(obj, "__prepare_scriptable__") else obj # type: ignore[operator]
# Record obj in memo to avoid infinite recursion in the case of cycles in the module
# hierarchy when recursing below.
memo[obj_id] = obj

new_obj_dict = {}

for name, sub_module in obj.__dict__.items():
-if name == '_modules':
+if name == "_modules":
for k, v in sub_module.items():
sub_module[k] = call_prepare_scriptable_func_impl(v, memo)
new_obj_dict[name] = sub_module
-elif isinstance(sub_module, torch.nn.Module) and not isinstance(sub_module, ScriptModule):
+elif isinstance(sub_module, torch.nn.Module) and not isinstance(
+sub_module, ScriptModule
+):
new_obj_dict[name] = call_prepare_scriptable_func_impl(sub_module, memo)
else:
new_obj_dict[name] = sub_module
@@ -1001,6 +1032,7 @@ def call_prepare_scriptable_func(obj):
memo: Dict[int, torch.nn.Module] = {}
return call_prepare_scriptable_func_impl(obj, memo)
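# Illustrative sketch (editor's note, not part of this diff): a module opts
# into this rewrite step by defining __prepare_scriptable__; the helper above
# walks the module hierarchy recursively, with `memo` guarding against cycles.
#
#   class Swappable(torch.nn.Module):
#       def __prepare_scriptable__(self):
#           # ScriptFriendlyVariant is a hypothetical stand-in module that
#           # avoids constructs TorchScript cannot compile.
#           return ScriptFriendlyVariant()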


def create_script_dict(obj):
"""
Create a ``torch._C.ScriptDict`` instance with the data from ``obj``.
@@ -1031,8 +1063,13 @@ def create_script_list(obj, type_hint=None):
return torch._C.ScriptList(obj) # type: ignore[attr-defined]
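# Usage sketch (editor's note, not part of this diff): these two helpers back
# torch.jit.script when it is handed a plain dict or list rather than a
# function or module.
#
#   sd = torch.jit.script({"a": torch.ones(1)})  # -> torch._C.ScriptDict
#   sl = torch.jit.script([1, 2, 3])             # -> torch._C.ScriptList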


-def script(obj, optimize=None, _frames_up=0, _rcb=None,
-example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None):
+def script(
+obj,
+optimize=None,
+_frames_up=0,
+_rcb=None,
+example_inputs: Union[List[Tuple], Dict[Callable, List[Tuple]], None] = None,
+):
r"""
Scripting a function or ``nn.Module`` will inspect the source code, compile
it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or
@@ -1269,12 +1306,16 @@ def forward(self, a) -> MyModule:
for examples in example_inputs:
obj(*examples)
else:
raise ValueError("Error: Unable to infer types. Please format the inputs to type `List[Tuple]`"
" or `Dict[Callable, List[Tuple]]` to be run with MonkeyType.")
raise ValueError(
"Error: Unable to infer types. Please format the inputs to type `List[Tuple]`"
" or `Dict[Callable, List[Tuple]]` to be run with MonkeyType."
)
else:
warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
"to enable Profile-Directed Typing in TorchScript. Refer to "
"https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ")
warnings.warn(
"Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
"to enable Profile-Directed Typing in TorchScript. Refer to "
"https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. "
)
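# A minimal usage sketch (editor's note, not part of this diff), assuming
# MonkeyType is installed: example_inputs lets script() record call types
# before compiling, instead of requiring explicit annotations.
#
#   def add(a, b):
#       return a + b
#
#   scripted = torch.jit.script(add, example_inputs=[(1, 2)])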

if isinstance(obj, torch.nn.Module):
obj = call_prepare_scriptable_func(obj)
@@ -1393,8 +1434,9 @@ def _get_overloads(obj):
return existing_compiled_fns

if obj in uncompiled_overloads:
-raise RuntimeError(_jit_internal.get_overload_no_implementation_error_message(
-'function', obj))
+raise RuntimeError(
+_jit_internal.get_overload_no_implementation_error_message("function", obj)
+)

compiled_fns = []
for overload_fn in uncompiled_overloads:
@@ -1460,14 +1502,15 @@ def _recursive_compile_class(obj, loc):
rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
return _compile_and_register_class(obj, rcb, _qual_name)


CompilationUnit = torch._C.CompilationUnit
set_module(CompilationUnit, "torch.jit")


-def pad(s: str, padding: int, offset: int = 0, char: str = ' '):
+def pad(s: str, padding: int, offset: int = 0, char: str = " "):
if padding >= len(s):
padding -= len(s)
-return ''.join([char for _ in range(padding + offset)]) + s
+return "".join([char for _ in range(padding + offset)]) + s


class _ScriptProfileColumn:
Expand All @@ -1483,7 +1526,7 @@ def add_row(self, lineno: int, value: Any):
def materialize(self):
max_length = len(self.header)
rows: List[Tuple[int, str]] = []
-for (key, value) in self.rows.items():
+for key, value in self.rows.items():
cell = str(value)
rows.append((key, cell))
max_length = max(len(cell), max_length)
@@ -1506,24 +1549,24 @@ def __init__(self, cols: List[_ScriptProfileColumn], source_range: List[int]):
def dump_string(self):
outputs: List[str] = []
cells: List[Tuple[str, Dict[int, str]]] = []
-header_buffer = ''
+header_buffer = ""
for col in self.cols:
header, rows = col.materialize()
header_buffer += header
cells.append((header, dict(rows)))

outputs.append(header_buffer)
-outputs.append(pad('', len(header_buffer), 0, '='))
+outputs.append(pad("", len(header_buffer), 0, "="))
for line in self.source_range:
-row_buffer = ''
+row_buffer = ""
for header, rows in cells:
cell = rows.get(line)
if cell is None:
-row_buffer += pad('', len(header))
+row_buffer += pad("", len(header))
else:
row_buffer += cell
outputs.append(row_buffer)
-return '\n'.join(outputs)
+return "\n".join(outputs)


class _ScriptProfile:
Expand All @@ -1541,7 +1584,7 @@ def dump_string(self) -> str:
for source_stats in self.profile._dump_stats():
source_ref = source_stats.source()
source_lines = source_ref.text().splitlines()
-dedent = min([len(line) - len(line.lstrip(' ')) for line in source_lines])
+dedent = min([len(line) - len(line.lstrip(" ")) for line in source_lines])
source_lines = [line[dedent:] for line in source_lines]

start_line = source_ref.starting_lineno()
@@ -1560,9 +1603,11 @@ def dump_string(self) -> str:
hits.add_row(line, stat.count())
time_ns.add_row(line, stat.duration_ns())

-table = _ScriptProfileTable([lineno, hits, time_ns, line_contents], list(source_range))
+table = _ScriptProfileTable(
+[lineno, hits, time_ns, line_contents], list(source_range)
+)
outputs.append(table.dump_string())
return '\n\n'.join(outputs)
return "\n\n".join(outputs)

def dump(self):
print(self.dump_string())
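# A hedged usage sketch (editor's note, not part of this diff); enable() and
# disable() are assumed from parts of this class that the diff view collapses:
#
#   profile = _ScriptProfile()
#   profile.enable()
#   scripted_fn(example_input)  # run some TorchScript code
#   profile.disable()
#   profile.dump()              # prints per-line hit counts and time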