diff --git a/onnxscript/_internal/analysis.py b/onnxscript/_internal/analysis.py index 0403f60c9..8b8d3f0b5 100644 --- a/onnxscript/_internal/analysis.py +++ b/onnxscript/_internal/analysis.py @@ -3,7 +3,8 @@ from __future__ import annotations import ast -from typing import Any, Optional, Sequence, Set +from collections.abc import Sequence +from typing import Any, Optional from onnxscript import sourceinfo from onnxscript._internal import ast_utils @@ -15,7 +16,7 @@ def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str: return for_stmt.target.id -def _used_vars(expr: Optional[ast.expr]) -> Set[str]: +def _used_vars(expr: Optional[ast.expr]) -> set[str]: """Return set of all variables used, including function names, in an expression.""" if expr is None: return set() @@ -35,7 +36,7 @@ def _used_vars(expr: Optional[ast.expr]) -> Set[str]: return result -def _lhs_vars(lhs: ast.expr) -> Set[str]: +def _lhs_vars(lhs: ast.expr) -> set[str]: """Return set of assigned variables in the lhs of an assignment statement.""" def get_id(e): @@ -49,12 +50,12 @@ def get_id(e): def assigned_vars( stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter -) -> Set[str]: +) -> set[str]: """Return the set of all variables that may be assigned to in an execution of input stmt or sequence of statements. """ - def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]: + def assigned_in_block(block: Sequence[ast.stmt]) -> set[str]: result: set[Any] = set() for s in block: result = result | assigned_vars(s, formatter) @@ -90,14 +91,14 @@ def do_liveness_analysis(fun: ast.FunctionDef, formatter: sourceinfo.Formatter): and `s.live_out`. """ - def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: + def visit(stmt: ast.stmt, live_out: set[str]) -> set[str]: stmt.live_out = live_out # type: ignore[attr-defined] live = do_visit(stmt, live_out) stmt.live_in = live # type: ignore[attr-defined] return live - def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: - def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: + def do_visit(stmt: ast.stmt, live_out: set[str]) -> set[str]: + def visitBlock(block: Sequence[ast.stmt], live_out: set[str]) -> set[str]: for s in reversed(block): live_out = visit(s, live_out) return live_out @@ -165,12 +166,12 @@ def exposed_uses(stmts: Sequence[ast.stmt], formatter: sourceinfo.Formatter): (in the first statement). Hence x is included in the exposed_uses. 
""" - def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: + def visitBlock(block: Sequence[ast.stmt], live_out: set[str]) -> set[str]: for stmt in reversed(block): live_out = visit(stmt, live_out) return live_out - def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: + def visit(stmt: ast.stmt, live_out: set[str]) -> set[str]: if isinstance(stmt, ast.Assign): return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value) if isinstance(stmt, ast.AnnAssign): diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index 836bafff9..7a57a7df4 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -3,7 +3,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Callable, Optional import numpy as np import onnx diff --git a/onnxscript/_internal/param_manipulation.py b/onnxscript/_internal/param_manipulation.py index b3591a0a8..0cda7ff33 100644 --- a/onnxscript/_internal/param_manipulation.py +++ b/onnxscript/_internal/param_manipulation.py @@ -5,7 +5,9 @@ from __future__ import annotations import collections -from typing import Any, OrderedDict, Sequence +from collections import OrderedDict +from collections.abc import Sequence +from typing import Any from onnxscript import values diff --git a/onnxscript/_internal/runtime_typing.py b/onnxscript/_internal/runtime_typing.py index 3cf8a8db5..d636bf638 100644 --- a/onnxscript/_internal/runtime_typing.py +++ b/onnxscript/_internal/runtime_typing.py @@ -22,10 +22,6 @@ checked = typing.cast(typing.Callable[[T], T], _beartype_decorator) - # Beartype warns when we import from typing because the types are deprecated - # in Python 3.9. But there will be a long time until we can move to using - # the native container types for type annotations (when 3.9 is the lowest - # supported version). So we silence the warning. warnings.filterwarnings( "ignore", category=_roar.BeartypeDecorHintPep585DeprecationWarning, diff --git a/onnxscript/_internal/utils.py b/onnxscript/_internal/utils.py index ce2b657cf..03ab1c2c8 100644 --- a/onnxscript/_internal/utils.py +++ b/onnxscript/_internal/utils.py @@ -3,7 +3,8 @@ from __future__ import annotations import numbers -from typing import Optional, Sequence +from collections.abc import Sequence +from typing import Optional import numpy as np import onnx diff --git a/onnxscript/_internal/version_utils.py b/onnxscript/_internal/version_utils.py index 2b43c54f4..8553e275e 100644 --- a/onnxscript/_internal/version_utils.py +++ b/onnxscript/_internal/version_utils.py @@ -5,7 +5,8 @@ from __future__ import annotations import warnings -from typing import Callable, Sequence +from collections.abc import Sequence +from typing import Callable import packaging.version diff --git a/onnxscript/_legacy_ir/__init__.py b/onnxscript/_legacy_ir/__init__.py index 29bba5458..ae5b67124 100644 --- a/onnxscript/_legacy_ir/__init__.py +++ b/onnxscript/_legacy_ir/__init__.py @@ -4,7 +4,7 @@ import dataclasses from collections import deque -from typing import List, Tuple, Union +from typing import Union import numpy as np import onnx @@ -47,9 +47,9 @@ def __init__(self) -> None: # TODO: Technically, SymbolicValue should be a recursive type to handle lists of lists of # tensors, etc. However, we currently only handle lists of tensors. 
-SymbolicValue = Union[str, List[str]] +SymbolicValue = Union[str, list[str]] -FunctionId = Tuple[str, str, str] +FunctionId = tuple[str, str, str] def get_function_id(function: onnx.FunctionProto) -> FunctionId: diff --git a/onnxscript/_legacy_ir/visitor.py b/onnxscript/_legacy_ir/visitor.py index 6adfeab6d..d4bc6948d 100644 --- a/onnxscript/_legacy_ir/visitor.py +++ b/onnxscript/_legacy_ir/visitor.py @@ -5,7 +5,8 @@ import dataclasses import logging -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any import numpy as np import onnx diff --git a/onnxscript/_thirdparty/asciichartpy.py b/onnxscript/_thirdparty/asciichartpy.py index 88c46202c..740d8c145 100644 --- a/onnxscript/_thirdparty/asciichartpy.py +++ b/onnxscript/_thirdparty/asciichartpy.py @@ -32,8 +32,8 @@ from __future__ import annotations +from collections.abc import Mapping from math import ceil, floor, isnan -from typing import Mapping black = "\033[30m" red = "\033[31m" diff --git a/onnxscript/backend/onnx_backend.py b/onnxscript/backend/onnx_backend.py index ef93bb50b..461b74849 100644 --- a/onnxscript/backend/onnx_backend.py +++ b/onnxscript/backend/onnx_backend.py @@ -4,7 +4,7 @@ import os import textwrap -from typing import Iterator +from collections.abc import Iterator import numpy as np import onnx diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py index 04c4639ea..6d7a06851 100644 --- a/onnxscript/backend/onnx_export.py +++ b/onnxscript/backend/onnx_export.py @@ -2,7 +2,8 @@ # Licensed under the MIT License. from __future__ import annotations -from typing import Any, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Optional import numpy import onnx diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index 1d05428a2..8a3ec33c2 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -9,7 +9,7 @@ import re import sys import unittest -from typing import Pattern +from re import Pattern import onnx import onnxruntime as ort diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 1ee6e0ecd..0a7ae3b60 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -4,15 +4,12 @@ import ast import logging +from collections.abc import Sequence from typing import ( TYPE_CHECKING, Any, - Dict, - List, NoReturn, Optional, - Sequence, - Tuple, Union, ) @@ -178,11 +175,11 @@ def __init__( self.default_opset_ = default_opset # States initialized by `_init_function_translation` - self._outer: List[irbuilder.IRFunction] = [] + self._outer: list[irbuilder.IRFunction] = [] self._current_fn: irbuilder.IRFunction = None self._nextvar: int = 0 self._used_vars: set[str] = set() - self._locals: List[Dict[str, LocalSymValue]] = [{}] + self._locals: list[dict[str, LocalSymValue]] = [{}] @property def default_opset(self) -> values.Opset: @@ -230,7 +227,7 @@ def _init_function_translation(self) -> None: self._current_fn: Optional[irbuilder.IRFunction] = None self._nextvar = 0 self._used_vars = set() - self._locals: List[Dict[str, LocalSymValue]] = [{}] + self._locals: list[dict[str, LocalSymValue]] = [{}] def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo: return sourceinfo.SourceInfo(node, self.source, self._current_fn.name) @@ -269,7 +266,7 @@ def _exit_scope(self) -> irbuilder.IRFunction: self._locals.pop(0) return graph - def _current_scope(self) -> Dict[str, LocalSymValue]: + def _current_scope(self) -> dict[str, 
LocalSymValue]: return self._locals[0] def _bind(self, name: str, val: LocalSymValue) -> None: @@ -528,12 +525,7 @@ def _translate_attr( return attr def _translate_docstring(self, node: ast.Expr) -> None: - if hasattr(node.value, "value"): - # python 3.8+ - return self.ir_builder.add_docstring(self._current_fn, node.value.value) - raise TypeError( - f"Unexpected type {type(node)!r} for node. Unsupoorted version of python." - ) + return self.ir_builder.add_docstring(self._current_fn, node.value.value) def _translate_expr( self, node: ast.AST, target: Optional[PreferredName] = None @@ -697,9 +689,9 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: # As the first step, we partition the index elements into four kinds: Slice (eg., 1:5:2), # known-to-be-scalar (eg., 2), other-tensor (eg., I), skip/no-op (that is, just ":") - sliced_indices: List[Tuple[int, ast.expr]] = [] - scalar_indices: List[Tuple[int, ast.expr]] = [] - non_scalar_indices: List[Tuple[int, ast.expr]] = [] + sliced_indices: list[tuple[int, ast.expr]] = [] + scalar_indices: list[tuple[int, ast.expr]] = [] + non_scalar_indices: list[tuple[int, ast.expr]] = [] for axis, elt in enumerate(indices): if isinstance(elt, ast.Slice): # Add to sliced_indices, unless it is "::", which is a no-op. @@ -870,14 +862,7 @@ def _translate_unary_op_expr(self, node): # should intercept this call and replace node # by node.operand. # This mechanism does not handle somthing like `(-(-5))`. - if hasattr(node.operand, "value"): - # python 3.8+ - val = node.operand.value - else: - raise TypeError( - f"Unable to guess constant value from type {type(node.operand)!r} " - f"and attributes {dir(node.operand)!r}." - ) + val = node.operand.value if op == ast.USub: cst = ast.Constant(-val, lineno=node.lineno, col_offset=node.col_offset) return self._translate_expr(cst) diff --git a/onnxscript/evaluator.py b/onnxscript/evaluator.py index 38784ca7f..d0374bec8 100644 --- a/onnxscript/evaluator.py +++ b/onnxscript/evaluator.py @@ -5,13 +5,12 @@ import abc import contextlib import pprint +from collections.abc import Mapping, Sequence from typing import ( Any, Callable, - Mapping, Optional, Protocol, - Sequence, TypeVar, Union, runtime_checkable, diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py index c5b87898c..9a1b2a798 100644 --- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py +++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py @@ -5,7 +5,8 @@ import copy import dataclasses import logging -from typing import Dict, Mapping, Optional, Sequence, Set +from collections.abc import Mapping, Sequence +from typing import Optional import onnx import onnx.defs @@ -47,10 +48,10 @@ class TypeConstraint: """Type constraint shared by multiple values.""" name: str - type_strs: Set[str] - values: Set[Value] + type_strs: set[str] + values: set[Value] - def __init__(self, name: str, type_strs: Set[str]): + def __init__(self, name: str, type_strs: set[str]): self.name = name self.type_strs = type_strs self.values = set() @@ -125,9 +126,9 @@ def __repr__(self) -> str: @dataclasses.dataclass class OnnxFunctionTypeConstraints: - input_type_constraints: Dict[str, Optional[TypeConstraint]] - output_type_constraints: Dict[str, Optional[TypeConstraint]] - intermediate_type_constraints: Dict[str, Optional[TypeConstraint]] + input_type_constraints: dict[str, Optional[TypeConstraint]] + output_type_constraints: dict[str, 
Optional[TypeConstraint]] + intermediate_type_constraints: dict[str, Optional[TypeConstraint]] def __repr__(self): repr_strs = [ @@ -190,7 +191,7 @@ def __repr__(self): class TypeConstraintDeducer: def __init__(self, onnx_function: onnxscript.OnnxFunction): self.onnx_function = onnx_function - self.values: Dict[str, Value] = {} + self.values: dict[str, Value] = {} def type_constraints(self, signature_only: bool = True) -> OnnxFunctionTypeConstraints: """Retrieve deduced type constraints for the ONNX function.""" @@ -210,7 +211,7 @@ def type_constraints(self, signature_only: bool = True) -> OnnxFunctionTypeConst ) # Rename type constraints to T0, T1, T2, ... - seen_type_constraints: Set[TypeConstraint] = set() + seen_type_constraints: set[TypeConstraint] = set() for type_constraint in ( *input_type_constraints.values(), *output_type_constraints.values(), @@ -250,7 +251,7 @@ def _bind_signature( node: onnx.NodeProto, param_names: Sequence[str], param_schemas: Sequence[onnx.defs.OpSchema.FormalParameter], - op_type_constraints: Dict[str, TypeConstraint], + op_type_constraints: dict[str, TypeConstraint], is_output: bool = False, ): param_schemas = list(param_schemas) diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py index a8d15c242..68031ccd1 100644 --- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py +++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py @@ -7,7 +7,7 @@ import inspect import logging import unittest -from typing import Generator +from collections.abc import Generator import parameterized diff --git a/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py b/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py index eb2d8015a..056ae4ed9 100644 --- a/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py +++ b/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py @@ -13,7 +13,8 @@ import os import textwrap import typing -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any import torchgen.gen import torchgen.model diff --git a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py index ebbdd43bd..f5ca893a9 100644 --- a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py +++ b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py @@ -13,7 +13,8 @@ import os import re import textwrap -from typing import Any, Dict, List, Sequence +from collections.abc import Sequence +from typing import Any import torch import torchgen.gen @@ -241,8 +242,8 @@ def copyright_header() -> str: ) -def _get_func_schema_in_namespace(namespaces: List[_OpNamespace]) -> Dict[str, FunctionSchema]: - table: Dict[str, FunctionSchema] = {} +def _get_func_schema_in_namespace(namespaces: list[_OpNamespace]) -> dict[str, FunctionSchema]: + table: dict[str, FunctionSchema] = {} for op_namespace in namespaces: for attr_name in dir(op_namespace): op_overload_packet = getattr(op_namespace, attr_name) diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index b5c1456c1..41ebcb78f 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ 
b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -8,7 +8,8 @@ import os import tempfile import typing -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union import numpy as np import onnx @@ -110,7 +111,7 @@ def __init__( super().__init__(None) self._torch_value: torch.Value = value self._concrete_value: Optional[np.ndarray] = None - self._shape: Optional[Tuple[int | str | None, ...]] = None + self._shape: Optional[tuple[int | str | None, ...]] = None self._torch_dtype: Optional[torch.dtype] = None self._name: Optional[str] = None self._is_complex: bool = False @@ -152,7 +153,7 @@ def rank(self) -> int | None: return value_type.dim() @property # type: ignore[override] - def shape(self) -> Tuple[int | str | None, ...] | None: + def shape(self) -> tuple[int | str | None, ...] | None: if self._shape is not None: return self._shape @@ -169,7 +170,7 @@ def shape(self) -> Tuple[int | str | None, ...] | None: return tuple(shape) @shape.setter - def shape(self, shape: Union[torch.Size, Tuple[int | str | None, ...]]): + def shape(self, shape: Union[torch.Size, tuple[int | str | None, ...]]): # Normalize torch symbolic dimension size to str. torch_sym_types = (torch.SymInt, torch.SymFloat, torch.SymBool) self._shape = tuple( @@ -250,9 +251,9 @@ def _unwrap_tensor_to_torch_value( ], ) -> Union[ ValidTorchValueType, - Dict[str, ValidTorchValueType], - List[ValidTorchValueType], - Tuple[ValidTorchValueType, ...], + dict[str, ValidTorchValueType], + list[ValidTorchValueType], + tuple[ValidTorchValueType, ...], ]: """Unwrap the TorchScriptTensor to torch.Value.""" if isinstance(value, TorchScriptTensor): @@ -274,14 +275,14 @@ def _wrap_torch_value_to_tensor( torch.Value, Mapping[str, ValidTorchValueType], Sequence[ValidTorchValueType] ], *, - shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None, + shape: Optional[Union[torch.Size, tuple[Union[int, str, None], ...]]] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ) -> Union[ ValidArgumentType, - Dict[str, ValidArgumentType], - List[ValidArgumentType], - Tuple[ValidArgumentType, ...], + dict[str, ValidArgumentType], + list[ValidArgumentType], + tuple[ValidArgumentType, ...], ]: """Wrap torch.Value to TorchScriptTensor.""" if isinstance(value, torch.Value): @@ -488,7 +489,7 @@ def _create_op_call_in_torch_graph( inputs: Sequence[torch.Value], attributes: Mapping[str, Any], n_outputs: int = 1, -) -> Tuple[torch.Value, ...]: +) -> tuple[torch.Value, ...]: """Creates a node representing an onnx op in `graph`. Args: @@ -548,17 +549,17 @@ def __init__( self._torch_graph = torch.Graph() # All the functions used, deduplicated by name # key: (name, domain) - self._function_store: Dict[Tuple[str, str], onnxscript.OnnxFunction] = {} + self._function_store: dict[tuple[str, str], onnxscript.OnnxFunction] = {} # Mapping from intializer name to data(torch.Tensor). - self._initializers: Dict[str, torch.Tensor] = {} + self._initializers: dict[str, torch.Tensor] = {} # Mapping from intializer name to input(TorchScriptTensor). - self._initializers_inputs: Dict[str, TorchScriptTensor] = {} + self._initializers_inputs: dict[str, TorchScriptTensor] = {} # Mapping from intializer name to input(TorchScriptTensor) from parent graph. 
- self._initializers_inputs_from_parent: Dict[str, TorchScriptTensor] = {} + self._initializers_inputs_from_parent: dict[str, TorchScriptTensor] = {} # Mapping from model local function type name to function graph. # Local function type name is expected to be unique. Converter creates # a unique name and a unique function graph for every module call. - self._sub_torch_script_graphs: Dict[str, TorchScriptGraph] = {} + self._sub_torch_script_graphs: dict[str, TorchScriptGraph] = {} # Parent graph. None if this is the top level graph. self._parent_torch_script_graph = parent_torch_script_graph # Domain name of the graph. None if this is the top level graph. @@ -572,7 +573,7 @@ def __init__( # This info is later serialized as `ValueInfoProto` inside ONNX, to # provide shape and dtype information for nodes within nested function calls. # https://github.com/onnx/onnx/issues/5487 - self._value_to_tensor: Dict[torch.Value, TorchScriptTensor] = {} + self._value_to_tensor: dict[torch.Value, TorchScriptTensor] = {} if self._domain_name is None and self._parent_torch_script_graph is not None: raise RuntimeError( @@ -592,7 +593,7 @@ def initializers(self) -> Mapping[str, torch.Tensor]: # we need to filter out the initializers that has fake tensor. This # is because we don't want to introduce fake tensor in onnxscript. @initializers.setter - def initializers(self, initializers: Dict[str, torch.Tensor]): + def initializers(self, initializers: dict[str, torch.Tensor]): self._initializers = initializers @property @@ -615,7 +616,7 @@ def domain_name(self) -> Optional[str]: def add_input( self, input_name: Optional[str], - shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None, + shape: Optional[Union[torch.Size, tuple[Union[int, str, None], ...]]] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ) -> TorchScriptTensor: @@ -642,7 +643,7 @@ def add_input( ) if isinstance(tensor_value, TorchScriptTensor): # NOTE: Only track value that maps to tensor. - # Value that maps to Sequence/Dict of tensors is not tracked. + # Value that maps to Sequence/dict of tensors is not tracked. 
self._value_to_tensor[torch_value] = tensor_value return tensor_value # type: ignore[return-value] @@ -682,7 +683,7 @@ def add_initializer(self, name: str, value: torch.Tensor) -> TorchScriptTensor: @runtime_typing.checked def register_outputs( - self, outputs: Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]] + self, outputs: Union[TorchScriptTensor, tuple[TorchScriptTensor, ...]] ): unwrapped_outputs = _unwrap_tensors_to_torch_values(outputs) if isinstance(unwrapped_outputs, torch.Value): @@ -737,7 +738,7 @@ def _add_constant_to_graph(self, constant) -> torch.Value: value.setDebugName(_rename_intermediate_value(value.debugName())) return value - def preprocess_inputs(self, onnx_inputs: Sequence[ValidInputType]) -> List[torch.Value]: + def preprocess_inputs(self, onnx_inputs: Sequence[ValidInputType]) -> list[torch.Value]: unwrapped_inputs = _unwrap_tensors_to_torch_values(onnx_inputs) graph_inputs = [] assert isinstance(unwrapped_inputs, Sequence) @@ -770,7 +771,7 @@ def _add_torchscript_op_call( onnx_inputs: Sequence[ValidInputType], onnx_attributes: Mapping[str, ValidArgumentType], n_outputs: int, - ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: + ) -> Union[TorchScriptTensor, tuple[TorchScriptTensor, ...]]: graph_inputs = self.preprocess_inputs(onnx_inputs) for key, value in onnx_attributes.items(): assert not isinstance(value, TorchScriptTensor), ( @@ -801,8 +802,8 @@ def _add_torchscript_op_call( @runtime_typing.checked def fetch_function_proto_dict( self, opset_version: int - ) -> Mapping[Tuple[str, str], onnx.FunctionProto]: - function_proto_dict: Dict[Tuple[str, str], onnx.FunctionProto] = {} + ) -> Mapping[tuple[str, str], onnx.FunctionProto]: + function_proto_dict: dict[tuple[str, str], onnx.FunctionProto] = {} # Fetch local function protos. E.g., local functions representing module calls. 
for ( sub_graph_name, @@ -893,7 +894,7 @@ def add_op_call( onnx_op_schema: onnx.defs.OpSchema, onnx_inputs: Sequence[ValidInputType], onnx_attributes: Mapping[str, ValidArgumentType], - ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: + ) -> Union[TorchScriptTensor, tuple[TorchScriptTensor, ...]]: # Compute outputs from the onnx_op op schema n_outputs = evaluator.compute_num_outputs(onnx_op_schema, onnx_inputs, onnx_attributes) result = self._add_torchscript_op_call( @@ -911,7 +912,7 @@ def add_function_call( onnx_function: onnxscript.OnnxFunction, onnx_inputs: Sequence[ValidInputType], onnx_attributes: Mapping[str, ValidArgumentType], - ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: + ) -> Union[TorchScriptTensor, tuple[TorchScriptTensor, ...]]: identifier = (onnx_function.name, onnx_function.function_ir.domain) self._function_store[identifier] = onnx_function @@ -931,7 +932,7 @@ def add_module_call( name: str, sub_torch_script_graph: TorchScriptGraph, onnx_inputs: Sequence[ValidInputType], - ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: + ) -> Union[TorchScriptTensor, tuple[TorchScriptTensor, ...]]: self._sub_torch_script_graphs[name] = sub_torch_script_graph domain_name = sub_torch_script_graph.domain_name assert domain_name is not None @@ -948,7 +949,7 @@ def add_module_call( def generate_function_value_info_proto( self, function_op_type: str ) -> Mapping[str, onnx.ValueInfoProto]: - named_value_info: Dict[str, onnx.ValueInfoProto] = {} + named_value_info: dict[str, onnx.ValueInfoProto] = {} function_id = _function_id(self.domain_name, function_op_type) for torch_value, tensor in self._value_to_tensor.items(): if (value_info := tensor.value_info()) is None: @@ -960,7 +961,7 @@ def generate_function_value_info_proto( return named_value_info @runtime_typing.checked - def generate_subgraphs_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]: + def generate_subgraphs_value_info_proto(self) -> dict[str, onnx.ValueInfoProto]: """Unique naming strategies for values inside subgraphs, i.e. local functions. {function_domain::function_op_type}/{value_name} @@ -970,15 +971,15 @@ def generate_subgraphs_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]: the `value_info` carried in `TorchScriptTensor` represents the general compatible shape and type. 
""" - named_value_info: Dict[str, onnx.ValueInfoProto] = {} + named_value_info: dict[str, onnx.ValueInfoProto] = {} for name, sub_graph in self._sub_torch_script_graphs.items(): named_value_info.update(sub_graph.generate_function_value_info_proto(name)) return named_value_info @runtime_typing.checked - def generate_maingraph_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]: + def generate_maingraph_value_info_proto(self) -> dict[str, onnx.ValueInfoProto]: """Returns value info proto for values in the main graph.""" - named_value_info: Dict[str, onnx.ValueInfoProto] = {} + named_value_info: dict[str, onnx.ValueInfoProto] = {} for torch_value, tensor in self._value_to_tensor.items(): if (value_info := tensor.value_info()) is None: continue @@ -1034,10 +1035,10 @@ def to_function_proto(self, opset_version: int, function_name: str) -> onnx.Func def to_model_proto( self, opset_version: int, include_initializers: bool = True ) -> onnx.ModelProto: - function_proto_dict: Mapping[Tuple[str, str], onnx.FunctionProto] = ( + function_proto_dict: Mapping[tuple[str, str], onnx.FunctionProto] = ( self.fetch_function_proto_dict(opset_version) ) - unique_custom_domains: Dict[str, int] = {} + unique_custom_domains: dict[str, int] = {} for function_proto in function_proto_dict.values(): # TODO(BowenBao): All local function domain versions are hardcoded as 1. diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ea43c2c4d..e0776d271 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -12,7 +12,8 @@ from __future__ import annotations import math -from typing import Any, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Optional, Union import numpy as np import torch @@ -594,7 +595,7 @@ def _adjust_args_for_arange_int_dtype( start: TRealUnlessFloat16OrInt8, end: TRealUnlessFloat16OrInt8, step: TRealUnlessFloat16OrInt8, -) -> Tuple[FLOAT, FLOAT, FLOAT]: +) -> tuple[FLOAT, FLOAT, FLOAT]: zero = op.Cast(0.0, to=FLOAT.dtype) start = op.Cast(start, to=FLOAT.dtype) end = op.Cast(end, to=FLOAT.dtype) @@ -2957,7 +2958,7 @@ def aten_embedding_bag( sparse: bool = False, per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, -) -> Tuple[TFloat, TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat, TFloat]: """embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)""" # assert(rank(indices) in [1,2]) @@ -2985,7 +2986,7 @@ def _aten_embedding_bag_onnx( mode: int, per_sample_weights: TFloat, include_last_offset: bool, -) -> Tuple[TFloat, TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat, TFloat]: neg_1 = op.Constant(value_ints=[-1]) # Assume indices is shape(5,2), indices_1d is shape(10,) indices_1d = op.Reshape(indices, neg_1) @@ -3092,7 +3093,7 @@ def aten_embedding_bag_padding_idx( per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, padding_idx: int = -1, -) -> Tuple[TFloat, TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat, TFloat]: """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? 
padding_idx) -> (Tensor, Tensor, Tensor, Tensor) We add default values for the attributes to accommodate _embedding_bag as well: @@ -3126,7 +3127,7 @@ def _aten_embedding_bag_1d_padding_idx_onnx( per_sample_weights: TFloat, include_last_offset: bool, padding_idx: int, -) -> Tuple[TFloat, TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat, TFloat]: neg_1 = op.Constant(value_ints=[-1]) # Get weight out according to indices, # e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]] @@ -5289,7 +5290,7 @@ def aten_max(self: TReal) -> TReal: @torch_op("aten::max.dim", trace_only=True) -def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, INT64]: +def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> tuple[TReal, INT64]: """max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)""" if len(self.shape) == 0: @@ -5356,7 +5357,7 @@ def aten_min(self: TReal) -> TReal: @torch_op("aten::min.dim", trace_only=True) -def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, TInt]: +def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> tuple[TReal, TInt]: """min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)""" if len(self.shape) == 0: result = self @@ -5891,7 +5892,7 @@ def aten__native_batch_norm_no_training( running_var: Optional[TFloat] = None, momentum: float = 0.9, eps: float = 1e-05, -) -> Tuple[TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat]: """_native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)""" return aten_native_batch_norm( @@ -5907,7 +5908,7 @@ def aten__native_batch_norm_no_stats( training: bool = False, momentum: float = 0.9, eps: float = 1e-05, -) -> Tuple[TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat]: """_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)""" return aten_native_batch_norm(input, weight, bias, None, None, training, momentum, eps) @@ -5923,7 +5924,7 @@ def aten_native_batch_norm( training: bool = False, momentum: float = 0.9, eps: float = 1e-05, -) -> Tuple[TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat]: """native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)""" if weight is None: # Set to 1.0 as default @@ -5991,7 +5992,7 @@ def _aten_native_batch_norm_training_onnx( axes: INT64, momentum: float, eps: float, -) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: """Batch normalization training mode. NOTE: momentum in PyTorch is 1.0-momentum in ONNX. @@ -6042,7 +6043,7 @@ def _aten_native_batch_norm_inference_onnx( running_var: TFloat, momentum: float, eps: float, -) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: """Batch normalization inference mode. NOTE: momentum in PyTorch is 1.0-momentum in ONNX. 
@@ -6082,7 +6083,7 @@ def aten__native_batch_norm_legit_functional( training: bool = False, momentum: float = 0.9, eps: float = 1e-05, -) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: if weight is None: # Set to 1.0 as default weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2)) @@ -6168,7 +6169,7 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType: @torch_op("aten::native_dropout", trace_only=True) -def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> Tuple[TFloat, BOOL]: +def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> tuple[TFloat, BOOL]: """native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)""" result, mask = op.Dropout(input, p, train) @@ -6193,7 +6194,7 @@ def aten_native_group_norm( HxW: Optional[INT64] = None, group: int = 1, eps: float = 1e-05, -) -> Tuple[TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat]: """native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)""" # Actually we don't need N,C,HxW value because the input tensor has that information @@ -6215,7 +6216,7 @@ def _aten_native_group_norm_onnx( bias: TFloat, group: INT64, eps: float, -) -> Tuple[TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat]: # Because onnx.GroupNorm() need size=group for weight and bias # But the torch's aten function's input need size=channel, the size mismatched # So we have to use onnx.InstanceNorm() to simulate @@ -6285,7 +6286,7 @@ def aten_native_layer_norm( weight: Optional[TReal] = None, bias: Optional[TReal] = None, eps: float = 1e-05, -) -> Tuple[TReal, TReal, TReal]: +) -> tuple[TReal, TReal, TReal]: """native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)""" # https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html#torch.nn.LayerNorm @@ -8113,7 +8114,7 @@ def aten_std_correction( # std_mean is decomposed by PyTroch -def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: +def aten_std_mean(self: TReal, unbiased: bool = True) -> tuple[TReal, TReal]: """std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)""" # Assume bool(True) and int(1) are same in ONNX, so pass "unbiased" directly as "correction" @@ -8125,7 +8126,7 @@ def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: # std_mean is decomposed by PyTroch def aten_std_mean_dim( self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False -) -> Tuple[TReal, TReal]: +) -> tuple[TReal, TReal]: """std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)""" # Although dim is Optional in signature, but we assume it must have value for this overload @@ -8143,7 +8144,7 @@ def aten_std_mean_correction( dim: Optional[int] = None, correction: Optional[float] = None, keepdim: bool = False, -) -> Tuple[TReal, TReal]: +) -> tuple[TReal, TReal]: """std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? 
correction=None, bool keepdim=False) -> (Tensor, Tensor)""" if correction is None: @@ -8443,7 +8444,7 @@ def aten_to_sparse_csr(self: TensorType) -> TensorType: @torch_op("aten::topk", trace_only=True) def aten_topk( self: TReal, k: int, dim: int = -1, largest: bool = True, sorted: bool = True -) -> Tuple[TReal, INT64]: +) -> tuple[TReal, INT64]: """topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)""" # We do not handle scalar inputs for topk @@ -8908,7 +8909,7 @@ def _aten_var_dim_onnx( # var_mean is decomposed by PyTroch -def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: +def aten_var_mean(self: TReal, unbiased: bool = True) -> tuple[TReal, TReal]: """var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)""" # Assume bool(True) and int(1) are same in ONNX, so pass "unbiased" directly as "correction" @@ -8919,7 +8920,7 @@ def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: # var_mean is decomposed by PyTroch def aten_var_mean_dim( self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False -) -> Tuple[TReal, TReal]: +) -> tuple[TReal, TReal]: """var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)""" # Although dim is Optional in signature, but we assume it must have value for this overload @@ -8934,7 +8935,7 @@ def aten_var_mean_correction( dim: Optional[int] = None, correction: Optional[float] = None, keepdim: bool = False, -) -> Tuple[TReal, TReal]: +) -> tuple[TReal, TReal]: """var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)""" if correction is None: @@ -8952,7 +8953,7 @@ def aten_var_mean_correction( # var_mean is decomposed by PyTroch def _aten_var_mean_onnx( self: TReal, correction: float = 1.0, keepdim: bool = False -) -> Tuple[TReal, TReal]: +) -> tuple[TReal, TReal]: # Compute mean and var mean = op.ReduceMean(self, keepdims=keepdim) sub_mean = op.Sub(self, mean) @@ -8972,7 +8973,7 @@ def _aten_var_mean_onnx( # var_mean is decomposed by PyTroch def _aten_var_mean_dim_onnx( self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False -) -> Tuple[TReal, TReal]: +) -> tuple[TReal, TReal]: dims = op.Reshape(dims, op.Constant(value_ints=[-1])) # Computer mean and var mean = op.ReduceMean(self, dims, keepdims=keepdim) diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index ea92dc347..e23a351bc 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -12,7 +12,8 @@ from __future__ import annotations -from typing import Optional, Sequence +from collections.abc import Sequence +from typing import Optional from onnxscript import INT64 from onnxscript.function_libs.torch_lib.registration import torch_op diff --git a/onnxscript/function_libs/torch_lib/ops/linalg.py b/onnxscript/function_libs/torch_lib/ops/linalg.py index 05bac181c..af4e7b437 100644 --- a/onnxscript/function_libs/torch_lib/ops/linalg.py +++ b/onnxscript/function_libs/torch_lib/ops/linalg.py @@ -13,7 +13,8 @@ from __future__ import annotations import math -from typing import Optional, Sequence +from collections.abc import Sequence +from typing import Optional from onnxscript import BOOL from onnxscript.function_libs.torch_lib.registration import torch_op diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py 
b/onnxscript/function_libs/torch_lib/ops/nn.py index 4a607e75b..e9073ff15 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -15,7 +15,8 @@ from __future__ import annotations import math -from typing import Optional, Sequence, Tuple, TypeVar, Union +from collections.abc import Sequence +from typing import Optional, TypeVar, Union from onnxscript import BFLOAT16, BOOL, DOUBLE, FLOAT, FLOAT16, INT64, ir from onnxscript.function_libs.torch_lib.ops import common as common_ops @@ -89,7 +90,7 @@ def _adjust_attributes_of_avg_pool( kernel_size: Sequence[int], stride: Sequence[int], padding: Sequence[int], -) -> Tuple[Sequence[int], Sequence[int], Sequence[int]]: +) -> tuple[Sequence[int], Sequence[int], Sequence[int]]: """Adjust attributes of avg_pool to match ONNX specification.""" if isinstance(kernel_size, int): @@ -894,7 +895,7 @@ def aten_max_pool1d_with_indices( padding: Sequence[int] = (0,), dilation: Sequence[int] = (1,), ceil_mode: bool = False, -) -> Tuple[TFloatOrUInt8, INT64]: +) -> tuple[TFloatOrUInt8, INT64]: """max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)""" # Torch prefers to use single number x for kernel, stride, pad and dilation on both sides implicitly. @@ -925,7 +926,7 @@ def _adjust_attributes_of_max_pool( stride: Sequence[int], padding: Sequence[int], dilation: Sequence[int], -) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]: +) -> tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]: if isinstance(dilation, int): dilations = [dilation] * expand_size else: @@ -1047,7 +1048,7 @@ def aten_max_pool2d_with_indices( padding: Sequence[int] = (0, 0), dilation: Sequence[int] = (1, 1), ceil_mode: bool = False, -) -> Tuple[TFloatOrUInt8, INT64]: +) -> tuple[TFloatOrUInt8, INT64]: """max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)""" # Torch prefers to use single number x for kernel, stride, pad and dilation on both sides implicitly. @@ -1095,7 +1096,7 @@ def aten_max_pool3d_with_indices( padding: Sequence[int] = (0, 0, 0), dilation: Sequence[int] = (1, 1, 1), ceil_mode: bool = False, -) -> Tuple[TFloatOrUInt8, INT64]: +) -> tuple[TFloatOrUInt8, INT64]: """max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)""" # Torch prefers to use single number x for kernel, stride, pad and dilation on both sides implicitly. 
@@ -1131,7 +1132,7 @@ def _aten_max_pool_with_indices_onnx( n_dims_one: Sequence[int], n_dims_zero: Sequence[int], n_dims_axes: Sequence[int], -) -> Tuple[TFloatOrUInt8, INT64]: +) -> tuple[TFloatOrUInt8, INT64]: self_rank_is_unbatched_rank = Rank(self) == unbatched_rank if self_rank_is_unbatched_rank: self = op.Unsqueeze(self, axes=[0]) @@ -1791,7 +1792,7 @@ def aten_scaled_dot_product_attention( def _aten__scaled_dot_product_flash_attention_fillin_empty_outputs( query: TFloat, -) -> Tuple[FLOAT, INT64, INT64, FLOAT]: +) -> tuple[FLOAT, INT64, INT64, FLOAT]: query_first_three_dims = op.Slice( op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3]) ) @@ -1816,7 +1817,7 @@ def aten__scaled_dot_product_flash_attention( is_causal: bool = False, return_debug_mask: bool = False, scale: Optional[float] = None, -) -> Tuple[TFloat, FLOAT, INT64, INT64, INT64, INT64, INT64, INT64, FLOAT]: +) -> tuple[TFloat, FLOAT, INT64, INT64, INT64, INT64, INT64, INT64, FLOAT]: """_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) One of the implementations of scaled_dot_product_attention. @@ -1855,7 +1856,7 @@ def aten__scaled_dot_product_flash_attention( def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs( query: TFloat, compute_log_sumexp: bool, -) -> Tuple[FLOAT, INT64]: +) -> tuple[FLOAT, INT64]: """_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)""" query = op.Transpose(query, perm=[0, 2, 1, 3]) @@ -1891,7 +1892,7 @@ def aten__scaled_dot_product_flash_attention_for_cpu( is_causal: bool = False, attn_mask: Optional[TFloat] = None, scale: Optional[float] = None, -) -> Tuple[TFloat, FLOAT]: +) -> tuple[TFloat, FLOAT]: """_scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)""" result = aten_scaled_dot_product_attention( query, @@ -1923,7 +1924,7 @@ def aten__scaled_dot_product_efficient_attention( dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, -) -> Tuple[TFloat, FLOAT, INT64, INT64]: +) -> tuple[TFloat, FLOAT, INT64, INT64]: """_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? 
scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)""" result = aten_scaled_dot_product_attention( diff --git a/onnxscript/function_libs/torch_lib/ops/prims.py b/onnxscript/function_libs/torch_lib/ops/prims.py index ed870b0d7..a26344b82 100644 --- a/onnxscript/function_libs/torch_lib/ops/prims.py +++ b/onnxscript/function_libs/torch_lib/ops/prims.py @@ -12,7 +12,8 @@ from __future__ import annotations -from typing import Optional, Sequence +from collections.abc import Sequence +from typing import Optional from onnxscript import INT64 from onnxscript.function_libs.torch_lib.ops import common as common_ops diff --git a/onnxscript/function_libs/torch_lib/ops/special.py b/onnxscript/function_libs/torch_lib/ops/special.py index 1b123394d..eabe969ed 100644 --- a/onnxscript/function_libs/torch_lib/ops/special.py +++ b/onnxscript/function_libs/torch_lib/ops/special.py @@ -13,7 +13,8 @@ from __future__ import annotations import math -from typing import Optional, Sequence +from collections.abc import Sequence +from typing import Optional from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import TFloat diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index 162d69d74..f265609e8 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -5,7 +5,8 @@ from __future__ import annotations import re -from typing import Any, Callable, Generator, Optional +from collections.abc import Generator +from typing import Any, Callable, Optional import onnxscript from onnxscript.function_libs.torch_lib import _constants diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 3c96f0eee..1a03529d3 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -2,6 +2,8 @@ # Licensed under the MIT License. 
"""In-memory intermediate representation for ONNX graphs.""" +from __future__ import annotations + __all__ = [ # Modules "serde", diff --git a/onnxscript/ir/_convenience/__init__.py b/onnxscript/ir/_convenience/__init__.py index f43685a6f..2afd1b22f 100644 --- a/onnxscript/ir/_convenience/__init__.py +++ b/onnxscript/ir/_convenience/__init__.py @@ -16,7 +16,8 @@ "replace_nodes_and_values", ] -from typing import Mapping, Sequence, Union +from collections.abc import Mapping, Sequence +from typing import Union import onnx diff --git a/onnxscript/ir/_convenience/_constructors.py b/onnxscript/ir/_convenience/_constructors.py index 86477bcf7..871d5a6bd 100644 --- a/onnxscript/ir/_convenience/_constructors.py +++ b/onnxscript/ir/_convenience/_constructors.py @@ -10,7 +10,7 @@ ] import typing -from typing import Mapping, Sequence +from collections.abc import Mapping, Sequence import numpy as np import onnx diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 32073c5b9..c3c8604c5 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -22,18 +22,14 @@ import sys import textwrap import typing -from collections.abc import Hashable +from collections import OrderedDict +from collections.abc import Collection, Hashable, Iterable, Iterator, Sequence +from collections.abc import Set as AbstractSet from typing import ( - AbstractSet, Any, Callable, - Collection, Generic, - Iterable, - Iterator, NamedTuple, - OrderedDict, - Sequence, SupportsInt, Union, ) diff --git a/onnxscript/ir/_linked_list.py b/onnxscript/ir/_linked_list.py index 0db770e20..65c48f296 100644 --- a/onnxscript/ir/_linked_list.py +++ b/onnxscript/ir/_linked_list.py @@ -4,7 +4,8 @@ from __future__ import annotations -from typing import Generic, Iterable, Iterator, Sequence, TypeVar +from collections.abc import Iterable, Iterator, Sequence +from typing import Generic, TypeVar T = TypeVar("T") diff --git a/onnxscript/ir/_metadata.py b/onnxscript/ir/_metadata.py index 77db7cc41..35fef2c94 100644 --- a/onnxscript/ir/_metadata.py +++ b/onnxscript/ir/_metadata.py @@ -5,7 +5,8 @@ from __future__ import annotations import collections -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any class MetadataStore(collections.UserDict): diff --git a/onnxscript/ir/_polyfill.py b/onnxscript/ir/_polyfill.py index fb6008db3..4980e9615 100644 --- a/onnxscript/ir/_polyfill.py +++ b/onnxscript/ir/_polyfill.py @@ -3,7 +3,8 @@ """Polyfill for Python builtin functions.""" import sys -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any if sys.version_info >= (3, 10): zip = zip # pylint: disable=self-assigning-variable diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py index fbc2c7c05..a2425b86d 100644 --- a/onnxscript/ir/_protocols.py +++ b/onnxscript/ir/_protocols.py @@ -31,19 +31,17 @@ from __future__ import annotations import typing -from typing import ( - Any, +from collections import OrderedDict +from collections.abc import ( Collection, Iterable, Iterator, Mapping, MutableMapping, MutableSequence, - OrderedDict, - Protocol, Sequence, - Tuple, ) +from typing import Any, Protocol, Tuple # noqa: UP035 from onnxscript.ir import _enums @@ -52,7 +50,11 @@ from typing_extensions import TypeAlias # An identifier that will uniquely identify an operator. 
E.g (domain, op_type, overload) -OperatorIdentifier: TypeAlias = Tuple[str, str, str] +OperatorIdentifier: TypeAlias = ( + Tuple[ # Requires Tuple because tuple[] does not have __module__ + str, str, str + ] +) @typing.runtime_checkable diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index d4d88ab5b..53d49bf7a 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -8,7 +8,8 @@ import logging import types import typing -from typing import Any, Iterator, Mapping, Optional, Sequence, TypeVar, Union +from collections.abc import Iterator, Mapping, Sequence +from typing import Any, Optional, TypeVar, Union import onnx diff --git a/onnxscript/ir/_schemas_test.py b/onnxscript/ir/_schemas_test.py index c134bd7a6..2c3e2b742 100644 --- a/onnxscript/ir/_schemas_test.py +++ b/onnxscript/ir/_schemas_test.py @@ -3,7 +3,8 @@ from __future__ import annotations import unittest -from typing import Any, Optional, Sequence, TypeVar, Union +from collections.abc import Sequence +from typing import Any, Optional, TypeVar, Union import parameterized diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index fbcfcb428..6cb6c8d65 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -4,19 +4,14 @@ from __future__ import annotations -from typing import ( - Any, - Mapping, - Optional, - Sequence, - Tuple, -) +from collections.abc import Mapping, Sequence +from typing import Any, Optional from onnxscript import ir from onnxscript.ir import _convenience # A type representing the domains/versions used in creating nodes in IR. -UsedOpsets = set[Tuple[str, Optional[int]]] +UsedOpsets = set[tuple[str, Optional[int]]] class Tape: diff --git a/onnxscript/ir/_type_casting.py b/onnxscript/ir/_type_casting.py index 20bab6903..f02656878 100644 --- a/onnxscript/ir/_type_casting.py +++ b/onnxscript/ir/_type_casting.py @@ -6,7 +6,7 @@ from __future__ import annotations import typing -from typing import Sequence +from collections.abc import Sequence import ml_dtypes import numpy as np diff --git a/onnxscript/ir/external_data.py b/onnxscript/ir/external_data.py index 4ca9ca503..824974792 100644 --- a/onnxscript/ir/external_data.py +++ b/onnxscript/ir/external_data.py @@ -15,7 +15,7 @@ import dataclasses import logging import os -from typing import Iterator, Sequence +from collections.abc import Iterator, Sequence from onnxscript.ir import _core, _enums, _protocols from onnxscript.ir import traversal as _traversal diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index 18e5c8715..e5f603ad1 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -16,7 +16,8 @@ import dataclasses import logging -from typing import Literal, Sequence, final +from collections.abc import Sequence +from typing import Literal, final __all__ = [ "PassBase", diff --git a/onnxscript/ir/passes/common/inliner.py b/onnxscript/ir/passes/common/inliner.py index 5cefc9426..61229d02f 100644 --- a/onnxscript/ir/passes/common/inliner.py +++ b/onnxscript/ir/passes/common/inliner.py @@ -9,7 +9,7 @@ __all__ = ["InlinePass", "InlinePassResult"] from collections import defaultdict -from typing import Iterable, List, Sequence, Tuple +from collections.abc import Iterable, Sequence import onnxscript.ir.convenience as _ir_convenience from onnxscript import ir @@ -17,13 +17,13 @@ # A replacement for a node specifies a list of nodes that replaces the original node, # and a list of values that replaces the original node's outputs. 
-NodeReplacement = Tuple[Sequence[ir.Node], Sequence[ir.Value]] +NodeReplacement = tuple[Sequence[ir.Node], Sequence[ir.Value]] # A call stack is a list of identifiers of call sites, where the first element is the # outermost call site, and the last element is the innermost call site. This is used # primarily for generating unique names for values in the inlined functions. CallSiteId = str -CallStack = List[CallSiteId] +CallStack = list[CallSiteId] def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str: # pylint: disable=unused-argument diff --git a/onnxscript/ir/passes/common/inliner_test.py b/onnxscript/ir/passes/common/inliner_test.py index 7a64a8d4b..95d3b27dc 100644 --- a/onnxscript/ir/passes/common/inliner_test.py +++ b/onnxscript/ir/passes/common/inliner_test.py @@ -5,7 +5,8 @@ from __future__ import annotations import unittest -from typing import Callable, Sequence +from collections.abc import Sequence +from typing import Callable import onnx from onnx import parser diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 64703b2ba..c6d97d464 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -61,7 +61,8 @@ import collections import logging import os -from typing import Any, Callable, List, Mapping, Sequence +from collections.abc import Mapping, Sequence +from typing import Any, Callable import numpy as np import onnx @@ -740,7 +741,7 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function: name=proto.name, overload=getattr(proto, "overload", ""), graph=graph, - attributes=typing.cast(List[_core.Attr], attributes), + attributes=typing.cast("list[_core.Attr]", attributes), ) diff --git a/onnxscript/ir/traversal.py b/onnxscript/ir/traversal.py index 5fa9a9acf..8c9107d9a 100644 --- a/onnxscript/ir/traversal.py +++ b/onnxscript/ir/traversal.py @@ -8,7 +8,8 @@ "RecursiveGraphIterator", ] -from typing import Callable, Iterator, Reversible, Union +from collections.abc import Iterator, Reversible +from typing import Callable, Union from typing_extensions import Self diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index a845dcbc5..77403b094 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -7,7 +7,8 @@ import io import logging import warnings -from typing import Any, Optional, Protocol, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Protocol, Union import onnx from onnx import ValueInfoProto, helper diff --git a/onnxscript/main.py b/onnxscript/main.py index 3ea3e50f9..8ce58dc67 100644 --- a/onnxscript/main.py +++ b/onnxscript/main.py @@ -6,7 +6,8 @@ import ast import inspect import sys -from typing import Any, Callable, Optional, Sequence, TypeVar +from collections.abc import Sequence +from typing import Any, Callable, Optional, TypeVar from typing_extensions import ParamSpec diff --git a/onnxscript/onnx_opset/__init__.py b/onnxscript/onnx_opset/__init__.py index c720c35bb..5fecc4eb8 100644 --- a/onnxscript/onnx_opset/__init__.py +++ b/onnxscript/onnx_opset/__init__.py @@ -13,7 +13,8 @@ from __future__ import annotations -from typing import Mapping, Tuple +from collections.abc import Mapping +from typing import Tuple from onnx.defs import onnx_opset_version diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py index af1d5b491..d64b6c629 100644 --- a/onnxscript/onnx_types.py +++ b/onnxscript/onnx_types.py @@ -4,7 +4,7 @@ from __future__ import annotations import abc -from typing import ClassVar, Optional, Tuple, Union +from typing 
diff --git a/onnxscript/ir/passes/common/inliner_test.py b/onnxscript/ir/passes/common/inliner_test.py
index 7a64a8d4b..95d3b27dc 100644
--- a/onnxscript/ir/passes/common/inliner_test.py
+++ b/onnxscript/ir/passes/common/inliner_test.py
@@ -5,7 +5,8 @@
 from __future__ import annotations
 
 import unittest
-from typing import Callable, Sequence
+from collections.abc import Sequence
+from typing import Callable
 
 import onnx
 from onnx import parser
diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py
index 64703b2ba..c6d97d464 100644
--- a/onnxscript/ir/serde.py
+++ b/onnxscript/ir/serde.py
@@ -61,7 +61,8 @@
 import collections
 import logging
 import os
-from typing import Any, Callable, List, Mapping, Sequence
+from collections.abc import Mapping, Sequence
+from typing import Any, Callable
 
 import numpy as np
 import onnx
@@ -740,7 +741,7 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
         name=proto.name,
         overload=getattr(proto, "overload", ""),
         graph=graph,
-        attributes=typing.cast(List[_core.Attr], attributes),
+        attributes=typing.cast("list[_core.Attr]", attributes),
     )
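Note (added for this review; not part of the diff): in the serde.py hunk above, `typing.cast` is a runtime no-op and never evaluates its type argument, so once `List` is no longer imported the target type can simply be passed as a string; static checkers still read it. A minimal sketch with a hypothetical helper, not the onnxscript code:

```python
# Minimal sketch: typing.cast is a runtime no-op, so a quoted type such as
# "list[int]" works without importing List or evaluating list[int] eagerly.
import typing


def as_int_list(values: typing.Any) -> "list[int]":
    # Returns the same object; only static checkers care about the target type.
    return typing.cast("list[int]", values)


print(as_int_list([1, 2, 3]))
```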
diff --git a/onnxscript/ir/traversal.py b/onnxscript/ir/traversal.py
index 5fa9a9acf..8c9107d9a 100644
--- a/onnxscript/ir/traversal.py
+++ b/onnxscript/ir/traversal.py
@@ -8,7 +8,8 @@
     "RecursiveGraphIterator",
 ]
 
-from typing import Callable, Iterator, Reversible, Union
+from collections.abc import Iterator, Reversible
+from typing import Callable, Union
 
 from typing_extensions import Self
diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py
index a845dcbc5..77403b094 100644
--- a/onnxscript/irbuilder.py
+++ b/onnxscript/irbuilder.py
@@ -7,7 +7,8 @@
 import io
 import logging
 import warnings
-from typing import Any, Optional, Protocol, Sequence, Union
+from collections.abc import Sequence
+from typing import Any, Optional, Protocol, Union
 
 import onnx
 from onnx import ValueInfoProto, helper
diff --git a/onnxscript/main.py b/onnxscript/main.py
index 3ea3e50f9..8ce58dc67 100644
--- a/onnxscript/main.py
+++ b/onnxscript/main.py
@@ -6,7 +6,8 @@
 import ast
 import inspect
 import sys
-from typing import Any, Callable, Optional, Sequence, TypeVar
+from collections.abc import Sequence
+from typing import Any, Callable, Optional, TypeVar
 
 from typing_extensions import ParamSpec
diff --git a/onnxscript/onnx_opset/__init__.py b/onnxscript/onnx_opset/__init__.py
index c720c35bb..5fecc4eb8 100644
--- a/onnxscript/onnx_opset/__init__.py
+++ b/onnxscript/onnx_opset/__init__.py
@@ -13,7 +13,8 @@
 from __future__ import annotations
 
-from typing import Mapping, Tuple
+from collections.abc import Mapping
+from typing import Tuple
 
 from onnx.defs import onnx_opset_version
diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py
index af1d5b491..d64b6c629 100644
--- a/onnxscript/onnx_types.py
+++ b/onnxscript/onnx_types.py
@@ -4,7 +4,7 @@
 from __future__ import annotations
 
 import abc
-from typing import ClassVar, Optional, Tuple, Union
+from typing import ClassVar, Optional, Union
 
 import onnx
@@ -12,7 +12,7 @@
 _DType = onnxscript.ir.DataType
 _DimType = Union[int, str, type(None)]
-_ShapeType = Union[Tuple[_DimType, ...], _DimType, type(Ellipsis)]
+_ShapeType = Union[tuple[_DimType, ...], _DimType, type(Ellipsis)]
 
 _tensor_type_shape_cache: dict[_DType, TensorType] = {}
 tensor_type_registry: dict[_DType, TensorType] = {}
diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py
index cce74cb13..514302022 100644
--- a/onnxscript/optimizer/_constant_folding.py
+++ b/onnxscript/optimizer/_constant_folding.py
@@ -9,7 +9,8 @@
 import logging
 import math
 import typing
-from typing import Any, Callable, Iterable, Sequence, Union
+from collections.abc import Iterable, Sequence
+from typing import Any, Callable, Union
 
 import numpy as np
 import onnx
diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py
index 5efaf784b..d0ed5000e 100644
--- a/onnxscript/rewriter/__init__.py
+++ b/onnxscript/rewriter/__init__.py
@@ -2,7 +2,8 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
-from typing import Sequence, TypeVar, Union
+from collections.abc import Sequence
+from typing import TypeVar, Union
 
 __all__ = [
     "pattern",
diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py
index 166b81d7e..232cd5475 100644
--- a/onnxscript/rewriter/_fusion_utils.py
+++ b/onnxscript/rewriter/_fusion_utils.py
@@ -2,7 +2,8 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
-from typing import Callable, Sequence, Union
+from collections.abc import Sequence
+from typing import Callable, Union
 
 import onnxscript.ir as ir
 from onnxscript.rewriter import pattern
diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py
index d6c4177ae..e4773c579 100644
--- a/onnxscript/rewriter/_ir_utils.py
+++ b/onnxscript/rewriter/_ir_utils.py
@@ -3,7 +3,8 @@
 from __future__ import annotations
 
 import math
-from typing import Callable, Sequence
+from collections.abc import Sequence
+from typing import Callable
 
 import numpy as np
diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py
index 42bc1ce76..26fe61071 100644
--- a/onnxscript/rewriter/generic_pattern.py
+++ b/onnxscript/rewriter/generic_pattern.py
@@ -7,7 +7,8 @@
 import os
 import textwrap
 import warnings
-from typing import Any, Callable, Iterator, Sequence
+from collections.abc import Iterator, Sequence
+from typing import Any, Callable
 
 import onnxscript.rewriter.pattern as orp
 from onnxscript import ir
diff --git a/onnxscript/rewriter/ort_fusions/attention.py b/onnxscript/rewriter/ort_fusions/attention.py
index 2738432cd..35fdf45e1 100644
--- a/onnxscript/rewriter/ort_fusions/attention.py
+++ b/onnxscript/rewriter/ort_fusions/attention.py
@@ -2,7 +2,8 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
-from typing import Sequence, Union
+from collections.abc import Sequence
+from typing import Union
 
 import onnxscript.ir as ir
 from onnxscript.rewriter import _fusion_utils, pattern
diff --git a/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa.py b/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa.py
index 75c4f66f9..83fed97aa 100644
--- a/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa.py
+++ b/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa.py
@@ -2,7 +2,8 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
-from typing import Sequence, Union
+from collections.abc import Sequence
+from typing import Union
 
 import onnxscript.ir as ir
 from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py
index 266987dd4..e6eb0d6b7 100644
--- a/onnxscript/rewriter/ort_fusions/gqa.py
+++ b/onnxscript/rewriter/ort_fusions/gqa.py
@@ -2,7 +2,8 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
-from typing import Sequence, Union
+from collections.abc import Sequence
+from typing import Union
 
 import numpy as np
diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py
index 5fed44691..dab51774d 100644
--- a/onnxscript/rewriter/ort_fusions/mha.py
+++ b/onnxscript/rewriter/ort_fusions/mha.py
@@ -2,7 +2,8 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
-from typing import Sequence, Union
+from collections.abc import Sequence
+from typing import Union
 
 import onnxscript.ir as ir
 from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index cfca31125..9363dc31a 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -10,15 +10,11 @@
 import itertools
 import math
 from collections import defaultdict
+from collections.abc import Iterable, Iterator, MutableSequence, Sequence
 from typing import (
     Any,
     Callable,
-    Iterable,
-    Iterator,
-    MutableSequence,
     Protocol,
-    Sequence,
-    Tuple,
     TypeVar,
     Union,
 )
@@ -535,7 +531,7 @@ def __str__(self) -> str:
         inputs_and_attributes = f"{inputs}, {attributes}" if attributes else inputs
         return f"{outputs} = {qualified_op} ({inputs_and_attributes})"
 
-    def op_identifier(self) -> Tuple[str, str, str] | None:
+    def op_identifier(self) -> tuple[str, str, str] | None:
         return self._op_identifier
 
     @property
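Note (added for this review; not part of the diff): the pattern.py hunk above mixes PEP 585 (`tuple[...]`) with PEP 604 (`| None`) in a return annotation. On Python 3.9 that expression would fail if evaluated, but because the module uses `from __future__ import annotations`, annotations are stored as strings and never evaluated. A minimal sketch with a hypothetical class, not the real pattern API:

```python
# Minimal sketch (hypothetical class): PEP 585/604 syntax in annotations is safe
# on Python 3.9 when postponed evaluation is enabled, because the annotation is
# kept as a string instead of being evaluated at definition time.
from __future__ import annotations


class NodePatternSketch:
    def __init__(self, op_identifier: tuple[str, str, str] | None = None) -> None:
        self._op_identifier = op_identifier

    def op_identifier(self) -> tuple[str, str, str] | None:
        return self._op_identifier


print(NodePatternSketch(("", "Add", "1")).op_identifier())
```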
diff --git a/onnxscript/testing/__init__.py b/onnxscript/testing/__init__.py
index 048b45e7e..fc150056c 100644
--- a/onnxscript/testing/__init__.py
+++ b/onnxscript/testing/__init__.py
@@ -11,7 +11,8 @@
 import difflib
 import math
-from typing import Any, Collection, Sequence
+from collections.abc import Collection, Sequence
+from typing import Any
 
 import google.protobuf.message
 import numpy as np
diff --git a/onnxscript/tools/transformers_models/__init__.py b/onnxscript/tools/transformers_models/__init__.py
index ed4648916..b23908244 100644
--- a/onnxscript/tools/transformers_models/__init__.py
+++ b/onnxscript/tools/transformers_models/__init__.py
@@ -6,7 +6,8 @@
 from __future__ import annotations
 
 import random
-from typing import Any, Sequence
+from collections.abc import Sequence
+from typing import Any
 
 import onnx
 import onnx.inliner
diff --git a/onnxscript/tools/transformers_models/llama.py b/onnxscript/tools/transformers_models/llama.py
index 9b1337167..ea62e613d 100644
--- a/onnxscript/tools/transformers_models/llama.py
+++ b/onnxscript/tools/transformers_models/llama.py
@@ -5,7 +5,8 @@
 # pylint: disable=import-outside-toplevel
 from __future__ import annotations
 
-from typing import Any, Sequence
+from collections.abc import Sequence
+from typing import Any
 
 import torch
diff --git a/onnxscript/tools/transformers_models/mistral.py b/onnxscript/tools/transformers_models/mistral.py
index d053b9057..e013c7126 100644
--- a/onnxscript/tools/transformers_models/mistral.py
+++ b/onnxscript/tools/transformers_models/mistral.py
@@ -5,7 +5,8 @@
 # pylint: disable=import-outside-toplevel
 from __future__ import annotations
 
-from typing import Any, Sequence
+from collections.abc import Sequence
+from typing import Any
 
 import torch
diff --git a/onnxscript/tools/transformers_models/phi.py b/onnxscript/tools/transformers_models/phi.py
index f1cb88edd..0c2cc5daf 100644
--- a/onnxscript/tools/transformers_models/phi.py
+++ b/onnxscript/tools/transformers_models/phi.py
@@ -5,7 +5,8 @@
 # pylint: disable=import-outside-toplevel
 from __future__ import annotations
 
-from typing import Any, Sequence
+from collections.abc import Sequence
+from typing import Any
 
 import torch
diff --git a/onnxscript/tools/transformers_models/phi3.py b/onnxscript/tools/transformers_models/phi3.py
index f5bf7beb5..9a6522e9b 100644
--- a/onnxscript/tools/transformers_models/phi3.py
+++ b/onnxscript/tools/transformers_models/phi3.py
@@ -4,7 +4,8 @@
 
 from __future__ import annotations
 
-from typing import Any, Sequence
+from collections.abc import Sequence
+from typing import Any
 
 import torch
diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py
index 8a71b5c2d..667b6b46c 100644
--- a/onnxscript/type_annotation.py
+++ b/onnxscript/type_annotation.py
@@ -2,10 +2,10 @@
 # Licensed under the MIT License.
 from __future__ import annotations
 
-import collections
 import inspect
 import typing
-from typing import Optional, Sequence, Union
+from collections.abc import Sequence
+from typing import Optional, Union
 
 import onnx
@@ -40,7 +40,7 @@
     bool: onnx.AttributeProto.INTS,  # experimental
 }
 
-_LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence])
+_LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, Sequence])
 
 # Map from ONNX AttributeProto type to its representation (in ONNX Script).
 _ATTRTYPE_TO_REPR = {
@@ -76,10 +76,8 @@ def onnx_attr_type_to_onnxscript_repr(attr_type: onnx.AttributeProto.AttributeType
 def _remove_annotation(typeinfo: TypeAnnotationValue) -> TypeAnnotationValue:
     """Remove Annotated wrapper if present, otherwise return typeinfo as is."""
-    if hasattr(typing, "Annotated"):
-        # Present in Python 3.9+
-        if typing.get_origin(typeinfo) is typing.Annotated:
-            return typing.get_args(typeinfo)[0]
+    if typing.get_origin(typeinfo) is typing.Annotated:
+        return typing.get_args(typeinfo)[0]
     return typeinfo
@@ -130,6 +128,10 @@ def base_type_is_bool(pytype: TypeAnnotationValue) -> bool:
 def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool:
     if isinstance(typeinfo, onnx_types.TensorType):
         return True
+    if typeinfo is onnx_types.TensorType:
+        # Special case to handle when typeinfo is TensorType itself.
+        # It seems abc.ABC in py39 has issues with issubclass.
+        return True
     if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType):
         return True
     return False
@@ -169,6 +171,10 @@ def is_value_type(typeinfo: TypeAnnotationValue) -> bool:
     if hasattr(typeinfo, "__bound__"):
         bound = typeinfo.__bound__
         return is_value_type(bound)
+    if hasattr(typeinfo, "__constraints__"):
+        constraints = typeinfo.__constraints__
+        if constraints:
+            return any(is_value_type(x) for x in constraints)
     raise ValueError(f"Unsupported type annotation {typeinfo}")
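Note (added for this review; not part of the diff): the `is_value_type` addition above distinguishes the two kinds of `TypeVar`. A `TypeVar` created with `bound=` reports the bound through `__bound__`, while a constrained `TypeVar` reports its options through `__constraints__` and has `__bound__` set to `None`, which is the case the new branch covers. A small standalone illustration:

```python
# Minimal sketch: bound vs. constrained TypeVars, as inspected by is_value_type().
from typing import TypeVar

Bounded = TypeVar("Bounded", bound=int)
Constrained = TypeVar("Constrained", int, float)

print(Bounded.__bound__, Bounded.__constraints__)          # <class 'int'> ()
print(Constrained.__bound__, Constrained.__constraints__)  # None (<class 'int'>, <class 'float'>)
```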
diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py
index 4104eb51d..b2e5b2791 100644
--- a/onnxscript/type_annotation_test.py
+++ b/onnxscript/type_annotation_test.py
@@ -2,7 +2,8 @@
 # Licensed under the MIT License.
 
 import unittest
-from typing import Any, List, Optional, Sequence, TypeVar, Union
+from collections.abc import Sequence
+from typing import Any, Optional, TypeVar, Union
 
 import parameterized
@@ -212,7 +213,7 @@ class TypeConversionFunctionsTest(unittest.TestCase):
             ),
         ]
     )
-    def test_pytype_to_type_strings(self, _, pytype: Any, expected: List[str]):
+    def test_pytype_to_type_strings(self, _, pytype: Any, expected: list[str]):
         self.assertEqual(type_annotation.pytype_to_type_strings(pytype), expected)
 
     @parameterized.parameterized.expand(
diff --git a/onnxscript/values.py b/onnxscript/values.py
index 266f7da57..9e7bb767c 100644
--- a/onnxscript/values.py
+++ b/onnxscript/values.py
@@ -8,6 +8,7 @@
 import logging
 import types
 import typing
+from collections.abc import Sequence
 from enum import IntFlag
 from typing import (  # type: ignore[attr-defined]
     Any,
@@ -16,7 +17,6 @@
     Generic,
     Optional,
     Protocol,
-    Sequence,
     TypeVar,
     _GenericAlias,
 )
@@ -749,7 +749,10 @@ def __init__(self, info: sourceinfo.SourceInfo) -> None:
 class AttrRef(SymbolValue):
     def __init__(
-        self, attr_name: str, typeinfo: _GenericAlias, info: sourceinfo.SourceInfo
+        self,
+        attr_name: str,
+        typeinfo: _GenericAlias | types.GenericAlias,
+        info: sourceinfo.SourceInfo,
     ) -> None:
         """Initializes AttrRef.
@@ -762,9 +765,10 @@ def __init__(
         super().__init__(info)
         self.value = attr_name
         self.typeinfo = typeinfo
-        if not isinstance(typeinfo, (type, _GenericAlias)):
+        if not isinstance(typeinfo, (type, _GenericAlias, types.GenericAlias)):
             # typing._GenericAlias for List[int] and List[str], etc.
-            raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.")
+            # types.GenericAlias for list[int] and tuple[int], etc.
+            raise TypeError(f"Expecting a type not {type(typeinfo)} for typeinfo.")
         self.typeinfo = typeinfo
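Note (added for this review; not part of the diff): the AttrRef change above widens the `isinstance` check because annotations written with builtin generics produce `types.GenericAlias` objects, which the old `(type, _GenericAlias)` tuple rejects. A standalone illustration (it relies on the private `typing._GenericAlias`, just as values.py does):

```python
# Minimal sketch: the two alias classes behind List[int] vs. list[int].
import types
import typing
from typing import _GenericAlias  # type: ignore[attr-defined]  # private, but used by values.py too

print(isinstance(typing.List[int], _GenericAlias))   # True  -> accepted by the old check
print(isinstance(list[int], _GenericAlias))          # False -> rejected by the old check
print(isinstance(list[int], types.GenericAlias))     # True  -> accepted once the check is widened
```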
diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py
index 46b4596fb..0fd39e87d 100644
--- a/onnxscript/version_converter/_version_converter.py
+++ b/onnxscript/version_converter/_version_converter.py
@@ -7,7 +7,8 @@
 import dataclasses
 import functools
 import logging
-from typing import Callable, Sequence, Union
+from collections.abc import Sequence
+from typing import Callable, Union
 
 import onnxscript.ir.convenience as ir_convenience
 import onnxscript.rewriter.pattern as orp
diff --git a/opgen/onnx_opset_builder.py b/opgen/onnx_opset_builder.py
index 5fd1f60b6..d94cba244 100644
--- a/opgen/onnx_opset_builder.py
+++ b/opgen/onnx_opset_builder.py
@@ -5,9 +5,10 @@
 
 from __future__ import annotations
 
+from collections.abc import Iterable
 from pathlib import Path
 from textwrap import dedent
-from typing import Annotated, Any, Iterable, Optional, Set, TextIO
+from typing import Annotated, Any, Optional, TextIO
 
 import pygen as cg
 from onnx.defs import (
@@ -157,8 +158,8 @@ def __init__(
         *,
         module_base_name: str,
         min_default_opset_version: int,
-        include_opsets: Optional[Set[OpsetId]] = None,
-        exclude_opsets: Optional[Set[OpsetId]] = None,
+        include_opsets: Optional[set[OpsetId]] = None,
+        exclude_opsets: Optional[set[OpsetId]] = None,
     ):
         self.module_base_name = module_base_name
         self.min_default_opset_version = min_default_opset_version
diff --git a/opgen/pygen.py b/opgen/pygen.py
index bea743118..a0228b093 100644
--- a/opgen/pygen.py
+++ b/opgen/pygen.py
@@ -8,18 +8,15 @@
 import io
 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from enum import Enum
 from textwrap import TextWrapper, dedent
 from typing import (
     Any,
     Callable,
     Generic,
-    Iterable,
     Optional,
-    Set,
     TextIO,
-    Tuple,
-    Type,
     TypeVar,
     Union,
 )
@@ -30,7 +27,7 @@
 NoneType = type(None)
 
-def _assert_instance(instance, expected_type: Union[Type, Tuple[Type, ...]]):
+def _assert_instance(instance, expected_type: Union[type, tuple[type, ...]]):
     if not isinstance(instance, expected_type):
         raise TypeError(f"expected: {expected_type!r}; actual: {instance!r}")
@@ -71,7 +68,7 @@ class NodePredicate:
     def __init__(
         self,
         role: Optional[Role] = None,
-        type_: Optional[Type[TNode]] = None,
+        type_: Optional[type[TNode]] = None,
         func: Optional[Callable[[Node], bool]] = None,
     ):
         _assert_instance(role, (Role, NoneType))
@@ -164,7 +161,7 @@ def get_children_in_role(self, role: Role):
         _assert_instance(role, Role)
         return self.get_children(NodePredicate(role=role))
 
-    def get_children_of_type(self, type_: Type[TNode]) -> Iterable[TNode]:
+    def get_children_of_type(self, type_: type[TNode]) -> Iterable[TNode]:
         _assert_instance(type_, type)
         return self.get_children(NodePredicate(type_=type_))
@@ -183,7 +180,7 @@ def get_ancestors_in_role(self, role: Role, and_self=False):
         _assert_instance(role, Role)
         return self.get_ancestors(NodePredicate(role=role), and_self=and_self)
 
-    def get_ancestors_of_type(self, type_: Type[TNode], and_self=False) -> Iterable[TNode]:
+    def get_ancestors_of_type(self, type_: type[TNode], and_self=False) -> Iterable[TNode]:
         _assert_instance(type_, type)
         return self.get_ancestors(NodePredicate(type_=type_), and_self=and_self)
@@ -1131,7 +1128,7 @@ def __init__(self, predicate: NodePredicate):
         super().__init__()
         _assert_instance(predicate, NodePredicate)
         self._predicate = predicate
-        self.names: Set[str] = set()
+        self.names: set[str] = set()
 
     def leave(self, node: Node) -> Optional[bool]:
         if self._predicate.matches(node) and hasattr(node, "name"):
@@ -1141,7 +1138,7 @@ def leave(self, node: Node) -> Optional[bool]:
 class ImportAdjuster(FixupVisitor):
     def __init__(self):
         super().__init__()
-        self.naming_conflicts: Set[str] = set()
+        self.naming_conflicts: set[str] = set()
 
     def enter(self, node: Node):
         if len(self.node_stack) == 0:
diff --git a/pyproject.toml b/pyproject.toml
index 361ba40aa..04b56cd0b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@
 dynamic = ["version", "urls"]
 description = "Naturally author ONNX functions and models using a subset of Python"
 authors = [{ name = "Microsoft Corporation", email = "onnx@microsoft.com" }]
 readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.9"
 license = { file = "LICENSE" }
 classifiers = [
     "Development Status :: 4 - Beta",
@@ -17,7 +17,6 @@
     "Operating System :: POSIX",
     "Operating System :: MacOS :: MacOS X",
     "Operating System :: Microsoft :: Windows",
-    "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
@@ -139,7 +138,7 @@ convention = "google"
 
 [tool.ruff]
 line-length = 95
-target-version = "py38"
+target-version = "py39"
 
 [tool.ruff.lint]
 select = [
@@ -217,6 +216,7 @@ ignore-init-module-imports = true
 "setup.py" = ["TID251"] # pathlib is allowed in supporting code
 "**/{examples,tests,docs,tools,utils,opgen,_framework_apis}/*" = ["TID251"] # pathlib is allowed in supporting code
 "**/*_test.py" = ["TID251"] # pathlib is allowed in tests
+"onnxscript/onnx_opset/*" = ["UP035"] # Need to update opgen to use the new types
 
 [tool.ruff.lint.flake8-tidy-imports]
 # Disallow all relative imports.
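Note (added for this review; not part of the diff): bumping `target-version` to "py39" is what lets Ruff's pyupgrade rules push code toward the new spellings, and the per-file UP035 ignore is there because onnxscript/onnx_opset is generated by opgen and still imports from typing. Purely as an illustrative sketch (exact diagnostics depend on the Ruff version), this is the kind of code those rules flag:

```python
# Illustrative sketch: code that ruff's pyupgrade rules are likely to flag once
# target-version is "py39" (rule behavior depends on the ruff version in use).
from typing import List, Sequence  # UP035: deprecated typing imports


def head(values: Sequence[int]) -> List[int]:  # UP006: prefer list[int] in annotations
    return [values[0]]
```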
diff --git a/tests/common/onnx_script_test_case.py b/tests/common/onnx_script_test_case.py
index 3a46a870a..ab6911432 100644
--- a/tests/common/onnx_script_test_case.py
+++ b/tests/common/onnx_script_test_case.py
@@ -7,7 +7,8 @@
 import numbers
 import unittest
 import warnings
-from typing import Any, Collection, Iterable, Optional, Sequence
+from collections.abc import Collection, Iterable, Sequence
+from typing import Any, Optional
 
 import numpy as np
 import onnx
diff --git a/tests/function_libs/torch_lib/error_reproduction.py b/tests/function_libs/torch_lib/error_reproduction.py
index 1eac88c48..93afd15c8 100644
--- a/tests/function_libs/torch_lib/error_reproduction.py
+++ b/tests/function_libs/torch_lib/error_reproduction.py
@@ -8,7 +8,8 @@
 import sys
 import time
 import traceback
-from typing import Any, Mapping
+from collections.abc import Mapping
+from typing import Any
 
 import numpy as np
 import onnx
diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py
index 26b75bf93..850567fc1 100644
--- a/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/tests/function_libs/torch_lib/extra_opinfo.py
@@ -7,7 +7,7 @@
 import functools
 import itertools
-from typing import Any, List
+from typing import Any
 
 import torch
 import torchvision
@@ -2150,7 +2150,7 @@ def __init__(self):
 # in ops_test_data.py and opinfo_core.OpInfo("unique_name", ...)
 # To avoid name duplication, it is possible to rename the OpInfo and specify
 # the `op` field explicitly.
-OP_DB: List[opinfo_core.OpInfo] = [
+OP_DB: list[opinfo_core.OpInfo] = [
     opinfo_core.OpInfo(
         "ops.aten.bernoulli.p",
         aten_name="bernoulli.p",
diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py
index 59e6c98c9..e01deecf5 100644
--- a/tests/function_libs/torch_lib/ops_test.py
+++ b/tests/function_libs/torch_lib/ops_test.py
@@ -27,7 +27,8 @@
 import os
 import unittest
-from typing import Callable, Optional, Sequence, Tuple
+from collections.abc import Sequence
+from typing import Callable, Optional
 
 import numpy as np
 import onnx
@@ -71,7 +72,7 @@ def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
 def _should_skip_xfail_test_sample(
     op_name: str, sample, dtype: torch.dtype, device_type: str
-) -> Tuple[Optional[str], Optional[str]]:
+) -> tuple[Optional[str], Optional[str]]:
     """Returns a reason if a test sample should be skipped."""
     if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
         return None, None
diff --git a/tests/function_libs/torch_lib/ops_test_common.py b/tests/function_libs/torch_lib/ops_test_common.py
index a9f922ce2..f22b86032 100644
--- a/tests/function_libs/torch_lib/ops_test_common.py
+++ b/tests/function_libs/torch_lib/ops_test_common.py
@@ -13,14 +13,11 @@
 import sys
 import unittest
 import warnings
+from collections.abc import Collection, Iterable, Mapping, Sequence
 from typing import (
     Any,
     Callable,
-    Collection,
-    Iterable,
-    Mapping,
     Optional,
-    Sequence,
     TypeVar,
 )
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 3628ed8c4..8f9188a3e 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -39,7 +39,8 @@
 import copy
 import dataclasses
 import functools
-from typing import Any, Callable, Collection, Optional
+from collections.abc import Collection
+from typing import Any, Callable, Optional
 
 import numpy as np
 import torch
diff --git a/tests/ir/public_api_test.py b/tests/ir/public_api_test.py
index ac2655cf4..38e054e8a 100644
--- a/tests/ir/public_api_test.py
+++ b/tests/ir/public_api_test.py
@@ -12,7 +12,7 @@
 import pathlib
 import pkgutil
 import unittest
-from typing import Iterable
+from collections.abc import Iterable
 
 import onnxscript.ir
diff --git a/tools/diagnostics/gen_diagnostics.py b/tools/diagnostics/gen_diagnostics.py
index cf0f0f35b..322107255 100644
--- a/tools/diagnostics/gen_diagnostics.py
+++ b/tools/diagnostics/gen_diagnostics.py
@@ -19,7 +19,8 @@
 import string
 import subprocess
 import textwrap
-from typing import Any, Mapping, Sequence
+from collections.abc import Mapping, Sequence
+from typing import Any
 
 import yaml
 from torchgen import utils as torchgen_utils