From 68f97f330291412d10b9d3a1c6db51180bbb6bef Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 27 Apr 2025 18:45:46 -0700 Subject: [PATCH 01/24] Formally drop python 3.8 support --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ff873319fb..88c1c71954 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version", "urls"] description = "Naturally author ONNX functions and models using a subset of Python" authors = [{ name = "Microsoft Corporation", email = "onnx@microsoft.com" }] readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = { file = "LICENSE" } classifiers = [ "Development Status :: 4 - Beta", @@ -17,7 +17,6 @@ classifiers = [ "Operating System :: POSIX", "Operating System :: MacOS :: MacOS X", "Operating System :: Microsoft :: Windows", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", From a258172029c1857fdccd09f4ecacc80ca3e2bd5c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 27 Apr 2025 18:51:25 -0700 Subject: [PATCH 02/24] Update converter.py --- onnxscript/converter.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 1ee6e0ecd0..1421376d09 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -528,12 +528,7 @@ def _translate_attr( return attr def _translate_docstring(self, node: ast.Expr) -> None: - if hasattr(node.value, "value"): - # python 3.8+ - return self.ir_builder.add_docstring(self._current_fn, node.value.value) - raise TypeError( - f"Unexpected type {type(node)!r} for node. Unsupoorted version of python." - ) + return self.ir_builder.add_docstring(self._current_fn, node.value.value) def _translate_expr( self, node: ast.AST, target: Optional[PreferredName] = None @@ -870,14 +865,7 @@ def _translate_unary_op_expr(self, node): # should intercept this call and replace node # by node.operand. # This mechanism does not handle somthing like `(-(-5))`. - if hasattr(node.operand, "value"): - # python 3.8+ - val = node.operand.value - else: - raise TypeError( - f"Unable to guess constant value from type {type(node.operand)!r} " - f"and attributes {dir(node.operand)!r}." 
- ) + val = node.operand.value if op == ast.USub: cst = ast.Constant(-val, lineno=node.lineno, col_offset=node.col_offset) return self._translate_expr(cst) From d65d0c745942e4abecb87b1e22f6ab5839fdb89b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 27 Apr 2025 18:53:44 -0700 Subject: [PATCH 03/24] py --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 88c1c71954..647e9b8a23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,7 +138,7 @@ convention = "google" [tool.ruff] line-length = 95 -target-version = "py38" +target-version = "py39" [tool.ruff.lint] select = [ From 7116f3697b50bbe16f8d9feb34705de8fbf88686 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 27 Apr 2025 19:01:02 -0700 Subject: [PATCH 04/24] Format --- onnxscript/_internal/analysis.py | 3 ++- onnxscript/_internal/autocast.py | 3 ++- onnxscript/_internal/param_manipulation.py | 4 +++- onnxscript/_internal/utils.py | 3 ++- onnxscript/_internal/version_utils.py | 3 ++- onnxscript/_legacy_ir/visitor.py | 3 ++- onnxscript/_thirdparty/asciichartpy.py | 2 +- onnxscript/backend/onnx_backend.py | 2 +- onnxscript/backend/onnx_export.py | 3 ++- onnxscript/backend/onnx_export_test.py | 2 +- onnxscript/converter.py | 2 +- onnxscript/evaluator.py | 3 +-- .../tools/torch_lib/deduce_type_constraints.py | 3 ++- .../tools/torch_lib/deduce_type_constraints_test.py | 2 +- .../tools/torch_lib/generate_aten_signatures.py | 3 ++- .../tools/torch_lib/generate_prims_signatures.py | 3 ++- .../torch_lib/graph_building/_graph_building_torch.py | 3 ++- onnxscript/function_libs/torch_lib/ops/core.py | 3 ++- onnxscript/function_libs/torch_lib/ops/fft.py | 3 ++- onnxscript/function_libs/torch_lib/ops/linalg.py | 3 ++- onnxscript/function_libs/torch_lib/ops/nn.py | 3 ++- onnxscript/function_libs/torch_lib/ops/prims.py | 3 ++- onnxscript/function_libs/torch_lib/ops/special.py | 3 ++- onnxscript/function_libs/torch_lib/registration.py | 3 ++- onnxscript/ir/_convenience/__init__.py | 3 ++- onnxscript/ir/_convenience/_constructors.py | 2 +- onnxscript/ir/_core.py | 8 ++------ onnxscript/ir/_linked_list.py | 3 ++- onnxscript/ir/_metadata.py | 3 ++- onnxscript/ir/_polyfill.py | 3 ++- onnxscript/ir/_protocols.py | 10 ++++++---- onnxscript/ir/_schemas.py | 3 ++- onnxscript/ir/_schemas_test.py | 3 ++- onnxscript/ir/_tape.py | 3 +-- onnxscript/ir/_type_casting.py | 2 +- onnxscript/ir/external_data.py | 2 +- onnxscript/ir/passes/_pass_infra.py | 2 +- onnxscript/ir/passes/common/inliner.py | 3 ++- onnxscript/ir/passes/common/inliner_test.py | 3 ++- onnxscript/ir/serde.py | 3 ++- onnxscript/ir/traversal.py | 3 ++- onnxscript/irbuilder.py | 3 ++- onnxscript/main.py | 3 ++- onnxscript/onnx_opset/__init__.py | 3 ++- onnxscript/optimizer/_constant_folding.py | 3 ++- onnxscript/rewriter/__init__.py | 3 ++- onnxscript/rewriter/_fusion_utils.py | 3 ++- onnxscript/rewriter/_ir_utils.py | 3 ++- onnxscript/rewriter/generic_pattern.py | 3 ++- onnxscript/rewriter/ort_fusions/attention.py | 3 ++- onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa.py | 3 ++- onnxscript/rewriter/ort_fusions/gqa.py | 3 ++- onnxscript/rewriter/ort_fusions/mha.py | 3 ++- onnxscript/rewriter/pattern.py | 5 +---- onnxscript/testing/__init__.py | 3 ++- onnxscript/tools/transformers_models/__init__.py | 3 ++- onnxscript/tools/transformers_models/llama.py | 3 ++- onnxscript/tools/transformers_models/mistral.py | 3 ++- onnxscript/tools/transformers_models/phi.py | 3 ++- onnxscript/tools/transformers_models/phi3.py | 3 ++- 
onnxscript/type_annotation.py | 3 ++- onnxscript/type_annotation_test.py | 3 ++- onnxscript/values.py | 2 +- onnxscript/version_converter/_version_converter.py | 3 ++- opgen/onnx_opset_builder.py | 3 ++- opgen/pygen.py | 2 +- pyproject.toml | 1 + tests/common/onnx_script_test_case.py | 3 ++- tests/function_libs/torch_lib/error_reproduction.py | 3 ++- tests/function_libs/torch_lib/ops_test.py | 3 ++- tests/function_libs/torch_lib/ops_test_common.py | 5 +---- tests/function_libs/torch_lib/ops_test_data.py | 3 ++- tests/ir/public_api_test.py | 2 +- tools/diagnostics/gen_diagnostics.py | 3 ++- 74 files changed, 136 insertions(+), 89 deletions(-) diff --git a/onnxscript/_internal/analysis.py b/onnxscript/_internal/analysis.py index 0403f60c91..9fae662d2a 100644 --- a/onnxscript/_internal/analysis.py +++ b/onnxscript/_internal/analysis.py @@ -3,7 +3,8 @@ from __future__ import annotations import ast -from typing import Any, Optional, Sequence, Set +from collections.abc import Sequence +from typing import Any, Optional, Set from onnxscript import sourceinfo from onnxscript._internal import ast_utils diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index 00fab2432d..477d58dd91 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -3,7 +3,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Callable, Optional import numpy as np import onnx diff --git a/onnxscript/_internal/param_manipulation.py b/onnxscript/_internal/param_manipulation.py index b3591a0a8d..0cda7ff335 100644 --- a/onnxscript/_internal/param_manipulation.py +++ b/onnxscript/_internal/param_manipulation.py @@ -5,7 +5,9 @@ from __future__ import annotations import collections -from typing import Any, OrderedDict, Sequence +from collections import OrderedDict +from collections.abc import Sequence +from typing import Any from onnxscript import values diff --git a/onnxscript/_internal/utils.py b/onnxscript/_internal/utils.py index e081bb34a2..5f53df7489 100644 --- a/onnxscript/_internal/utils.py +++ b/onnxscript/_internal/utils.py @@ -3,7 +3,8 @@ from __future__ import annotations import numbers -from typing import Optional, Sequence +from collections.abc import Sequence +from typing import Optional import numpy as np import onnx diff --git a/onnxscript/_internal/version_utils.py b/onnxscript/_internal/version_utils.py index 2b43c54f49..8553e275ea 100644 --- a/onnxscript/_internal/version_utils.py +++ b/onnxscript/_internal/version_utils.py @@ -5,7 +5,8 @@ from __future__ import annotations import warnings -from typing import Callable, Sequence +from collections.abc import Sequence +from typing import Callable import packaging.version diff --git a/onnxscript/_legacy_ir/visitor.py b/onnxscript/_legacy_ir/visitor.py index 8dcc3893ab..be3e5a89d8 100644 --- a/onnxscript/_legacy_ir/visitor.py +++ b/onnxscript/_legacy_ir/visitor.py @@ -4,7 +4,8 @@ import dataclasses import logging -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any import numpy as np import onnx diff --git a/onnxscript/_thirdparty/asciichartpy.py b/onnxscript/_thirdparty/asciichartpy.py index 88c46202ca..740d8c1451 100644 --- a/onnxscript/_thirdparty/asciichartpy.py +++ b/onnxscript/_thirdparty/asciichartpy.py @@ -32,8 +32,8 @@ from __future__ import annotations +from collections.abc import Mapping from math import ceil, floor, isnan 
-from typing import Mapping black = "\033[30m" red = "\033[31m" diff --git a/onnxscript/backend/onnx_backend.py b/onnxscript/backend/onnx_backend.py index 78089ebe6a..c6afae06b5 100644 --- a/onnxscript/backend/onnx_backend.py +++ b/onnxscript/backend/onnx_backend.py @@ -4,7 +4,7 @@ import os import textwrap -from typing import Iterator +from collections.abc import Iterator import numpy as np import onnx diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py index b3f695d700..22855a3f13 100644 --- a/onnxscript/backend/onnx_export.py +++ b/onnxscript/backend/onnx_export.py @@ -2,7 +2,8 @@ # Licensed under the MIT License. from __future__ import annotations -from typing import Any, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Optional import numpy import onnx diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index 1d05428a2c..8a3ec33c27 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -9,7 +9,7 @@ import re import sys import unittest -from typing import Pattern +from re import Pattern import onnx import onnxruntime as ort diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 1421376d09..46e57e3400 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -4,6 +4,7 @@ import ast import logging +from collections.abc import Sequence from typing import ( TYPE_CHECKING, Any, @@ -11,7 +12,6 @@ List, NoReturn, Optional, - Sequence, Tuple, Union, ) diff --git a/onnxscript/evaluator.py b/onnxscript/evaluator.py index 97551567bb..25125c1961 100644 --- a/onnxscript/evaluator.py +++ b/onnxscript/evaluator.py @@ -5,13 +5,12 @@ import abc import contextlib import pprint +from collections.abc import Mapping, Sequence from typing import ( Any, Callable, - Mapping, Optional, Protocol, - Sequence, TypeVar, Union, runtime_checkable, diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py index c5b87898c9..6bb2da11af 100644 --- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py +++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py @@ -5,7 +5,8 @@ import copy import dataclasses import logging -from typing import Dict, Mapping, Optional, Sequence, Set +from collections.abc import Mapping, Sequence +from typing import Dict, Optional, Set import onnx import onnx.defs diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py index a8d15c242a..68031ccd1b 100644 --- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py +++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py @@ -7,7 +7,7 @@ import inspect import logging import unittest -from typing import Generator +from collections.abc import Generator import parameterized diff --git a/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py b/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py index eb2d8015a4..056ae4ed95 100644 --- a/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py +++ b/onnxscript/function_libs/tools/torch_lib/generate_aten_signatures.py @@ -13,7 +13,8 @@ import os import textwrap import typing -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any import torchgen.gen import torchgen.model diff --git 
a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py index ebbdd43bd8..f93cfdd070 100644 --- a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py +++ b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py @@ -13,7 +13,8 @@ import os import re import textwrap -from typing import Any, Dict, List, Sequence +from collections.abc import Sequence +from typing import Any, Dict, List import torch import torchgen.gen diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index 8d0aab509e..6ed2b935d2 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -7,7 +7,8 @@ import os import tempfile import typing -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from collections.abc import Mapping, Sequence +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import onnx diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index ea43c2c4db..e0bfa74297 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -12,7 +12,8 @@ from __future__ import annotations import math -from typing import Any, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Optional, Tuple, Union import numpy as np import torch diff --git a/onnxscript/function_libs/torch_lib/ops/fft.py b/onnxscript/function_libs/torch_lib/ops/fft.py index ea92dc347d..e23a351bc8 100644 --- a/onnxscript/function_libs/torch_lib/ops/fft.py +++ b/onnxscript/function_libs/torch_lib/ops/fft.py @@ -12,7 +12,8 @@ from __future__ import annotations -from typing import Optional, Sequence +from collections.abc import Sequence +from typing import Optional from onnxscript import INT64 from onnxscript.function_libs.torch_lib.registration import torch_op diff --git a/onnxscript/function_libs/torch_lib/ops/linalg.py b/onnxscript/function_libs/torch_lib/ops/linalg.py index 05bac181ca..af4e7b437f 100644 --- a/onnxscript/function_libs/torch_lib/ops/linalg.py +++ b/onnxscript/function_libs/torch_lib/ops/linalg.py @@ -13,7 +13,8 @@ from __future__ import annotations import math -from typing import Optional, Sequence +from collections.abc import Sequence +from typing import Optional from onnxscript import BOOL from onnxscript.function_libs.torch_lib.registration import torch_op diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 34f143b4ee..93bed5a51c 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -15,7 +15,8 @@ from __future__ import annotations import math -from typing import Optional, Sequence, Tuple, TypeVar, Union +from collections.abc import Sequence +from typing import Optional, Tuple, TypeVar, Union import onnx diff --git a/onnxscript/function_libs/torch_lib/ops/prims.py b/onnxscript/function_libs/torch_lib/ops/prims.py index ed870b0d7d..a26344b826 100644 --- a/onnxscript/function_libs/torch_lib/ops/prims.py +++ b/onnxscript/function_libs/torch_lib/ops/prims.py @@ -12,7 +12,8 @@ from __future__ import annotations -from typing import Optional, Sequence +from collections.abc import Sequence 
+from typing import Optional from onnxscript import INT64 from onnxscript.function_libs.torch_lib.ops import common as common_ops diff --git a/onnxscript/function_libs/torch_lib/ops/special.py b/onnxscript/function_libs/torch_lib/ops/special.py index 1b123394d3..eabe969ed6 100644 --- a/onnxscript/function_libs/torch_lib/ops/special.py +++ b/onnxscript/function_libs/torch_lib/ops/special.py @@ -13,7 +13,8 @@ from __future__ import annotations import math -from typing import Optional, Sequence +from collections.abc import Sequence +from typing import Optional from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import TFloat diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index 162d69d747..f265609e88 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -5,7 +5,8 @@ from __future__ import annotations import re -from typing import Any, Callable, Generator, Optional +from collections.abc import Generator +from typing import Any, Callable, Optional import onnxscript from onnxscript.function_libs.torch_lib import _constants diff --git a/onnxscript/ir/_convenience/__init__.py b/onnxscript/ir/_convenience/__init__.py index 0addc9da2f..c237a9abdc 100644 --- a/onnxscript/ir/_convenience/__init__.py +++ b/onnxscript/ir/_convenience/__init__.py @@ -16,7 +16,8 @@ "replace_nodes_and_values", ] -from typing import Mapping, Sequence, Union +from collections.abc import Mapping, Sequence +from typing import Union import onnx diff --git a/onnxscript/ir/_convenience/_constructors.py b/onnxscript/ir/_convenience/_constructors.py index 3c6137f8cc..6245893298 100644 --- a/onnxscript/ir/_convenience/_constructors.py +++ b/onnxscript/ir/_convenience/_constructors.py @@ -10,7 +10,7 @@ ] import typing -from typing import Mapping, Sequence +from collections.abc import Mapping, Sequence import numpy as np import onnx diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 58dad2e6bb..0931ac533a 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -22,17 +22,13 @@ import sys import textwrap import typing -from collections.abc import Hashable +from collections import OrderedDict +from collections.abc import Collection, Hashable, Iterable, Iterator, Sequence from typing import ( AbstractSet, Any, - Collection, Generic, - Iterable, - Iterator, NamedTuple, - OrderedDict, - Sequence, SupportsInt, Union, ) diff --git a/onnxscript/ir/_linked_list.py b/onnxscript/ir/_linked_list.py index 0db770e20e..65c48f2960 100644 --- a/onnxscript/ir/_linked_list.py +++ b/onnxscript/ir/_linked_list.py @@ -4,7 +4,8 @@ from __future__ import annotations -from typing import Generic, Iterable, Iterator, Sequence, TypeVar +from collections.abc import Iterable, Iterator, Sequence +from typing import Generic, TypeVar T = TypeVar("T") diff --git a/onnxscript/ir/_metadata.py b/onnxscript/ir/_metadata.py index 77db7cc410..35fef2c945 100644 --- a/onnxscript/ir/_metadata.py +++ b/onnxscript/ir/_metadata.py @@ -5,7 +5,8 @@ from __future__ import annotations import collections -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any class MetadataStore(collections.UserDict): diff --git a/onnxscript/ir/_polyfill.py b/onnxscript/ir/_polyfill.py index fb6008db37..4980e96154 100644 --- a/onnxscript/ir/_polyfill.py +++ b/onnxscript/ir/_polyfill.py @@ -3,7 +3,8 @@ """Polyfill for Python 
builtin functions.""" import sys -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any if sys.version_info >= (3, 10): zip = zip # pylint: disable=self-assigning-variable diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py index fbc2c7c054..a68761f48d 100644 --- a/onnxscript/ir/_protocols.py +++ b/onnxscript/ir/_protocols.py @@ -31,17 +31,19 @@ from __future__ import annotations import typing -from typing import ( - Any, +from collections import OrderedDict +from collections.abc import ( Collection, Iterable, Iterator, Mapping, MutableMapping, MutableSequence, - OrderedDict, - Protocol, Sequence, +) +from typing import ( + Any, + Protocol, Tuple, ) diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py index d4d88ab5bb..53d49bf7aa 100644 --- a/onnxscript/ir/_schemas.py +++ b/onnxscript/ir/_schemas.py @@ -8,7 +8,8 @@ import logging import types import typing -from typing import Any, Iterator, Mapping, Optional, Sequence, TypeVar, Union +from collections.abc import Iterator, Mapping, Sequence +from typing import Any, Optional, TypeVar, Union import onnx diff --git a/onnxscript/ir/_schemas_test.py b/onnxscript/ir/_schemas_test.py index c134bd7a63..2c3e2b7429 100644 --- a/onnxscript/ir/_schemas_test.py +++ b/onnxscript/ir/_schemas_test.py @@ -3,7 +3,8 @@ from __future__ import annotations import unittest -from typing import Any, Optional, Sequence, TypeVar, Union +from collections.abc import Sequence +from typing import Any, Optional, TypeVar, Union import parameterized diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index 340142df3d..b3472949a5 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -4,11 +4,10 @@ from __future__ import annotations +from collections.abc import Mapping, Sequence from typing import ( Any, - Mapping, Optional, - Sequence, Tuple, ) diff --git a/onnxscript/ir/_type_casting.py b/onnxscript/ir/_type_casting.py index 20bab69037..f026568789 100644 --- a/onnxscript/ir/_type_casting.py +++ b/onnxscript/ir/_type_casting.py @@ -6,7 +6,7 @@ from __future__ import annotations import typing -from typing import Sequence +from collections.abc import Sequence import ml_dtypes import numpy as np diff --git a/onnxscript/ir/external_data.py b/onnxscript/ir/external_data.py index 4ca9ca5036..8249747927 100644 --- a/onnxscript/ir/external_data.py +++ b/onnxscript/ir/external_data.py @@ -15,7 +15,7 @@ import dataclasses import logging import os -from typing import Iterator, Sequence +from collections.abc import Iterator, Sequence from onnxscript.ir import _core, _enums, _protocols from onnxscript.ir import traversal as _traversal diff --git a/onnxscript/ir/passes/_pass_infra.py b/onnxscript/ir/passes/_pass_infra.py index 56566e7556..7416a6bda4 100644 --- a/onnxscript/ir/passes/_pass_infra.py +++ b/onnxscript/ir/passes/_pass_infra.py @@ -16,7 +16,7 @@ import dataclasses import logging -from typing import Sequence +from collections.abc import Sequence __all__ = [ "PassBase", diff --git a/onnxscript/ir/passes/common/inliner.py b/onnxscript/ir/passes/common/inliner.py index 5cefc94268..23034a1531 100644 --- a/onnxscript/ir/passes/common/inliner.py +++ b/onnxscript/ir/passes/common/inliner.py @@ -9,7 +9,8 @@ __all__ = ["InlinePass", "InlinePassResult"] from collections import defaultdict -from typing import Iterable, List, Sequence, Tuple +from collections.abc import Iterable, Sequence +from typing import List, Tuple import onnxscript.ir.convenience as _ir_convenience from onnxscript import 
ir diff --git a/onnxscript/ir/passes/common/inliner_test.py b/onnxscript/ir/passes/common/inliner_test.py index 7a64a8d4b4..95d3b27dc3 100644 --- a/onnxscript/ir/passes/common/inliner_test.py +++ b/onnxscript/ir/passes/common/inliner_test.py @@ -5,7 +5,8 @@ from __future__ import annotations import unittest -from typing import Callable, Sequence +from collections.abc import Sequence +from typing import Callable import onnx from onnx import parser diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 64703b2baa..8be63125e3 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -61,7 +61,8 @@ import collections import logging import os -from typing import Any, Callable, List, Mapping, Sequence +from collections.abc import Mapping, Sequence +from typing import Any, Callable, List import numpy as np import onnx diff --git a/onnxscript/ir/traversal.py b/onnxscript/ir/traversal.py index 5fa9a9acf7..8c9107d9aa 100644 --- a/onnxscript/ir/traversal.py +++ b/onnxscript/ir/traversal.py @@ -8,7 +8,8 @@ "RecursiveGraphIterator", ] -from typing import Callable, Iterator, Reversible, Union +from collections.abc import Iterator, Reversible +from typing import Callable, Union from typing_extensions import Self diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 407a1ccdb1..e30861657a 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -6,7 +6,8 @@ import io import logging import warnings -from typing import Any, Optional, Protocol, Sequence, Union +from collections.abc import Sequence +from typing import Any, Optional, Protocol, Union import onnx from onnx import ValueInfoProto, helper diff --git a/onnxscript/main.py b/onnxscript/main.py index 7407baedd1..3d0016fa70 100644 --- a/onnxscript/main.py +++ b/onnxscript/main.py @@ -6,7 +6,8 @@ import ast import inspect import sys -from typing import Any, Callable, Optional, Sequence, TypeVar +from collections.abc import Sequence +from typing import Any, Callable, Optional, TypeVar import onnx.helper from typing_extensions import ParamSpec diff --git a/onnxscript/onnx_opset/__init__.py b/onnxscript/onnx_opset/__init__.py index c720c35bbe..5fecc4eb8e 100644 --- a/onnxscript/onnx_opset/__init__.py +++ b/onnxscript/onnx_opset/__init__.py @@ -13,7 +13,8 @@ from __future__ import annotations -from typing import Mapping, Tuple +from collections.abc import Mapping +from typing import Tuple from onnx.defs import onnx_opset_version diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index cce74cb132..5143020225 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -9,7 +9,8 @@ import logging import math import typing -from typing import Any, Callable, Iterable, Sequence, Union +from collections.abc import Iterable, Sequence +from typing import Any, Callable, Union import numpy as np import onnx diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 5efaf784b0..d0ed5000e8 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -2,7 +2,8 @@ # Licensed under the MIT License. 
from __future__ import annotations -from typing import Sequence, TypeVar, Union +from collections.abc import Sequence +from typing import TypeVar, Union __all__ = [ "pattern", diff --git a/onnxscript/rewriter/_fusion_utils.py b/onnxscript/rewriter/_fusion_utils.py index 166b81d7e2..232cd54759 100644 --- a/onnxscript/rewriter/_fusion_utils.py +++ b/onnxscript/rewriter/_fusion_utils.py @@ -2,7 +2,8 @@ # Licensed under the MIT License. from __future__ import annotations -from typing import Callable, Sequence, Union +from collections.abc import Sequence +from typing import Callable, Union import onnxscript.ir as ir from onnxscript.rewriter import pattern diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index d6c4177ae8..e4773c5794 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -3,7 +3,8 @@ from __future__ import annotations import math -from typing import Callable, Sequence +from collections.abc import Sequence +from typing import Callable import numpy as np diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py index 42bc1ce766..26fe61071c 100644 --- a/onnxscript/rewriter/generic_pattern.py +++ b/onnxscript/rewriter/generic_pattern.py @@ -7,7 +7,8 @@ import os import textwrap import warnings -from typing import Any, Callable, Iterator, Sequence +from collections.abc import Iterator, Sequence +from typing import Any, Callable import onnxscript.rewriter.pattern as orp from onnxscript import ir diff --git a/onnxscript/rewriter/ort_fusions/attention.py b/onnxscript/rewriter/ort_fusions/attention.py index 2738432cd2..35fdf45e1c 100644 --- a/onnxscript/rewriter/ort_fusions/attention.py +++ b/onnxscript/rewriter/ort_fusions/attention.py @@ -2,7 +2,8 @@ # Licensed under the MIT License. from __future__ import annotations -from typing import Sequence, Union +from collections.abc import Sequence +from typing import Union import onnxscript.ir as ir from onnxscript.rewriter import _fusion_utils, pattern diff --git a/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa.py b/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa.py index 75c4f66f9d..83fed97aa0 100644 --- a/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa.py +++ b/onnxscript/rewriter/ort_fusions/fuse_packed_qkv_gqa.py @@ -2,7 +2,8 @@ # Licensed under the MIT License. from __future__ import annotations -from typing import Sequence, Union +from collections.abc import Sequence +from typing import Union import onnxscript.ir as ir from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 266987dd4d..e6eb0d6b7d 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -2,7 +2,8 @@ # Licensed under the MIT License. from __future__ import annotations -from typing import Sequence, Union +from collections.abc import Sequence +from typing import Union import numpy as np diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py index 5fed446911..dab51774d3 100644 --- a/onnxscript/rewriter/ort_fusions/mha.py +++ b/onnxscript/rewriter/ort_fusions/mha.py @@ -2,7 +2,8 @@ # Licensed under the MIT License. 
from __future__ import annotations -from typing import Sequence, Union +from collections.abc import Sequence +from typing import Union import onnxscript.ir as ir from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index cfca31125f..690fb0f069 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -10,14 +10,11 @@ import itertools import math from collections import defaultdict +from collections.abc import Iterable, Iterator, MutableSequence, Sequence from typing import ( Any, Callable, - Iterable, - Iterator, - MutableSequence, Protocol, - Sequence, Tuple, TypeVar, Union, diff --git a/onnxscript/testing/__init__.py b/onnxscript/testing/__init__.py index f7bb74980d..816875bddb 100644 --- a/onnxscript/testing/__init__.py +++ b/onnxscript/testing/__init__.py @@ -11,7 +11,8 @@ import difflib import math -from typing import Any, Collection, Sequence +from collections.abc import Collection, Sequence +from typing import Any import google.protobuf.message import onnx diff --git a/onnxscript/tools/transformers_models/__init__.py b/onnxscript/tools/transformers_models/__init__.py index ed4648916b..b239082447 100644 --- a/onnxscript/tools/transformers_models/__init__.py +++ b/onnxscript/tools/transformers_models/__init__.py @@ -6,7 +6,8 @@ from __future__ import annotations import random -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any import onnx import onnx.inliner diff --git a/onnxscript/tools/transformers_models/llama.py b/onnxscript/tools/transformers_models/llama.py index 9b1337167f..ea62e613de 100644 --- a/onnxscript/tools/transformers_models/llama.py +++ b/onnxscript/tools/transformers_models/llama.py @@ -5,7 +5,8 @@ # pylint: disable=import-outside-toplevel from __future__ import annotations -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any import torch diff --git a/onnxscript/tools/transformers_models/mistral.py b/onnxscript/tools/transformers_models/mistral.py index d053b90571..e013c7126e 100644 --- a/onnxscript/tools/transformers_models/mistral.py +++ b/onnxscript/tools/transformers_models/mistral.py @@ -5,7 +5,8 @@ # pylint: disable=import-outside-toplevel from __future__ import annotations -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any import torch diff --git a/onnxscript/tools/transformers_models/phi.py b/onnxscript/tools/transformers_models/phi.py index f1cb88edd0..0c2cc5dafd 100644 --- a/onnxscript/tools/transformers_models/phi.py +++ b/onnxscript/tools/transformers_models/phi.py @@ -5,7 +5,8 @@ # pylint: disable=import-outside-toplevel from __future__ import annotations -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any import torch diff --git a/onnxscript/tools/transformers_models/phi3.py b/onnxscript/tools/transformers_models/phi3.py index f5bf7beb54..9a6522e9b7 100644 --- a/onnxscript/tools/transformers_models/phi3.py +++ b/onnxscript/tools/transformers_models/phi3.py @@ -4,7 +4,8 @@ from __future__ import annotations -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any import torch diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 8a71b5c2d4..e66288aa25 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -5,7 +5,8 @@ import collections import inspect import typing -from 
typing import Optional, Sequence, Union +from collections.abc import Sequence +from typing import Optional, Union import onnx diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index 4104eb51dd..715dfc2cae 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -2,7 +2,8 @@ # Licensed under the MIT License. import unittest -from typing import Any, List, Optional, Sequence, TypeVar, Union +from collections.abc import Sequence +from typing import Any, List, Optional, TypeVar, Union import parameterized diff --git a/onnxscript/values.py b/onnxscript/values.py index d748dc6e64..ecba1ca5f6 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -8,6 +8,7 @@ import logging import types import typing +from collections.abc import Sequence from enum import IntFlag from typing import ( # type: ignore[attr-defined] Any, @@ -16,7 +17,6 @@ Generic, Optional, Protocol, - Sequence, TypeVar, _GenericAlias, ) diff --git a/onnxscript/version_converter/_version_converter.py b/onnxscript/version_converter/_version_converter.py index 46b4596fb5..0fd39e87dd 100644 --- a/onnxscript/version_converter/_version_converter.py +++ b/onnxscript/version_converter/_version_converter.py @@ -7,7 +7,8 @@ import dataclasses import functools import logging -from typing import Callable, Sequence, Union +from collections.abc import Sequence +from typing import Callable, Union import onnxscript.ir.convenience as ir_convenience import onnxscript.rewriter.pattern as orp diff --git a/opgen/onnx_opset_builder.py b/opgen/onnx_opset_builder.py index 5fd1f60b68..c0ae8d5158 100644 --- a/opgen/onnx_opset_builder.py +++ b/opgen/onnx_opset_builder.py @@ -5,9 +5,10 @@ from __future__ import annotations +from collections.abc import Iterable from pathlib import Path from textwrap import dedent -from typing import Annotated, Any, Iterable, Optional, Set, TextIO +from typing import Annotated, Any, Optional, Set, TextIO import pygen as cg from onnx.defs import ( diff --git a/opgen/pygen.py b/opgen/pygen.py index bea7431186..730b624573 100644 --- a/opgen/pygen.py +++ b/opgen/pygen.py @@ -8,13 +8,13 @@ import io from abc import ABC, abstractmethod +from collections.abc import Iterable from enum import Enum from textwrap import TextWrapper, dedent from typing import ( Any, Callable, Generic, - Iterable, Optional, Set, TextIO, diff --git a/pyproject.toml b/pyproject.toml index 647e9b8a23..27952cf4ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -214,6 +214,7 @@ ignore-init-module-imports = true "setup.py" = ["TID251"] # pathlib is allowed in supporting code "**/{examples,tests,docs,tools,utils,opgen,_framework_apis}/*" = ["TID251"] # pathlib is allowed in supporting code "**/*_test.py" = ["TID251"] # pathlib is allowed in tests +"onnxscript/onnx_opset/_impl/*" = ["UP035"] # Need to update opgen to use the new types [tool.ruff.lint.flake8-tidy-imports] # Disallow all relative imports. 
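The hunks above all perform the same Python 3.9 migration: abstract container types such as Sequence, Mapping, and Iterable move from `typing` to `collections.abc` (their `typing` aliases are deprecated since PEP 585), while concrete containers switch to builtin generics. A minimal sketch of the before/after import style this patch converges on; the `first_by_label` helper is hypothetical and not part of the patch:

# Before (Python 3.8 compatible):
#     from typing import Dict, Mapping, Optional, Sequence
# After (Python 3.9+), as applied throughout this patch:
from collections.abc import Mapping, Sequence
from typing import Optional


def first_by_label(values: Sequence[int], labels: Mapping[str, int]) -> dict[str, Optional[int]]:
    # Builtin generics like dict[str, ...] are subscriptable at runtime on 3.9+ (PEP 585).
    return {name: values[idx] if 0 <= idx < len(values) else None for name, idx in labels.items()}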
diff --git a/tests/common/onnx_script_test_case.py b/tests/common/onnx_script_test_case.py index 3a46a870a0..ab6911432a 100644 --- a/tests/common/onnx_script_test_case.py +++ b/tests/common/onnx_script_test_case.py @@ -7,7 +7,8 @@ import numbers import unittest import warnings -from typing import Any, Collection, Iterable, Optional, Sequence +from collections.abc import Collection, Iterable, Sequence +from typing import Any, Optional import numpy as np import onnx diff --git a/tests/function_libs/torch_lib/error_reproduction.py b/tests/function_libs/torch_lib/error_reproduction.py index 1eac88c48a..93afd15c8a 100644 --- a/tests/function_libs/torch_lib/error_reproduction.py +++ b/tests/function_libs/torch_lib/error_reproduction.py @@ -8,7 +8,8 @@ import sys import time import traceback -from typing import Any, Mapping +from collections.abc import Mapping +from typing import Any import numpy as np import onnx diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index 59e6c98c9f..c652916311 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -27,7 +27,8 @@ import os import unittest -from typing import Callable, Optional, Sequence, Tuple +from collections.abc import Sequence +from typing import Callable, Optional, Tuple import numpy as np import onnx diff --git a/tests/function_libs/torch_lib/ops_test_common.py b/tests/function_libs/torch_lib/ops_test_common.py index a9f922ce25..f22b86032a 100644 --- a/tests/function_libs/torch_lib/ops_test_common.py +++ b/tests/function_libs/torch_lib/ops_test_common.py @@ -13,14 +13,11 @@ import sys import unittest import warnings +from collections.abc import Collection, Iterable, Mapping, Sequence from typing import ( Any, Callable, - Collection, - Iterable, - Mapping, Optional, - Sequence, TypeVar, ) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 3628ed8c45..8f9188a3e0 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -39,7 +39,8 @@ import copy import dataclasses import functools -from typing import Any, Callable, Collection, Optional +from collections.abc import Collection +from typing import Any, Callable, Optional import numpy as np import torch diff --git a/tests/ir/public_api_test.py b/tests/ir/public_api_test.py index ac2655cf43..38e054e8a5 100644 --- a/tests/ir/public_api_test.py +++ b/tests/ir/public_api_test.py @@ -12,7 +12,7 @@ import pathlib import pkgutil import unittest -from typing import Iterable +from collections.abc import Iterable import onnxscript.ir diff --git a/tools/diagnostics/gen_diagnostics.py b/tools/diagnostics/gen_diagnostics.py index cf0f0f35b7..3221072550 100644 --- a/tools/diagnostics/gen_diagnostics.py +++ b/tools/diagnostics/gen_diagnostics.py @@ -19,7 +19,8 @@ import string import subprocess import textwrap -from typing import Any, Mapping, Sequence +from collections.abc import Mapping, Sequence +from typing import Any import yaml from torchgen import utils as torchgen_utils From e29256f0a2f649b729a6156ac256aa80cf185064 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 27 Apr 2025 19:01:12 -0700 Subject: [PATCH 05/24] format --- onnxscript/_internal/analysis.py | 20 +++--- onnxscript/_legacy_ir/__init__.py | 6 +- onnxscript/converter.py | 27 ++++--- .../torch_lib/deduce_type_constraints.py | 20 +++--- .../torch_lib/generate_prims_signatures.py | 12 ++-- 
.../graph_building/_graph_building_torch.py | 72 +++++++++---------- .../function_libs/torch_lib/ops/core.py | 54 +++++++------- onnxscript/ir/_core.py | 2 +- onnxscript/onnx_types.py | 4 +- 9 files changed, 107 insertions(+), 110 deletions(-) diff --git a/onnxscript/_internal/analysis.py b/onnxscript/_internal/analysis.py index 9fae662d2a..8b8d3f0b58 100644 --- a/onnxscript/_internal/analysis.py +++ b/onnxscript/_internal/analysis.py @@ -4,7 +4,7 @@ import ast from collections.abc import Sequence -from typing import Any, Optional, Set +from typing import Any, Optional from onnxscript import sourceinfo from onnxscript._internal import ast_utils @@ -16,7 +16,7 @@ def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str: return for_stmt.target.id -def _used_vars(expr: Optional[ast.expr]) -> Set[str]: +def _used_vars(expr: Optional[ast.expr]) -> set[str]: """Return set of all variables used, including function names, in an expression.""" if expr is None: return set() @@ -36,7 +36,7 @@ def _used_vars(expr: Optional[ast.expr]) -> Set[str]: return result -def _lhs_vars(lhs: ast.expr) -> Set[str]: +def _lhs_vars(lhs: ast.expr) -> set[str]: """Return set of assigned variables in the lhs of an assignment statement.""" def get_id(e): @@ -50,12 +50,12 @@ def get_id(e): def assigned_vars( stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter -) -> Set[str]: +) -> set[str]: """Return the set of all variables that may be assigned to in an execution of input stmt or sequence of statements. """ - def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]: + def assigned_in_block(block: Sequence[ast.stmt]) -> set[str]: result: set[Any] = set() for s in block: result = result | assigned_vars(s, formatter) @@ -91,14 +91,14 @@ def do_liveness_analysis(fun: ast.FunctionDef, formatter: sourceinfo.Formatter): and `s.live_out`. """ - def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: + def visit(stmt: ast.stmt, live_out: set[str]) -> set[str]: stmt.live_out = live_out # type: ignore[attr-defined] live = do_visit(stmt, live_out) stmt.live_in = live # type: ignore[attr-defined] return live - def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: - def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: + def do_visit(stmt: ast.stmt, live_out: set[str]) -> set[str]: + def visitBlock(block: Sequence[ast.stmt], live_out: set[str]) -> set[str]: for s in reversed(block): live_out = visit(s, live_out) return live_out @@ -166,12 +166,12 @@ def exposed_uses(stmts: Sequence[ast.stmt], formatter: sourceinfo.Formatter): (in the first statement). Hence x is included in the exposed_uses. 
""" - def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]: + def visitBlock(block: Sequence[ast.stmt], live_out: set[str]) -> set[str]: for stmt in reversed(block): live_out = visit(stmt, live_out) return live_out - def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]: + def visit(stmt: ast.stmt, live_out: set[str]) -> set[str]: if isinstance(stmt, ast.Assign): return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value) if isinstance(stmt, ast.AnnAssign): diff --git a/onnxscript/_legacy_ir/__init__.py b/onnxscript/_legacy_ir/__init__.py index 6c4e0c07ec..aec19c6437 100644 --- a/onnxscript/_legacy_ir/__init__.py +++ b/onnxscript/_legacy_ir/__init__.py @@ -4,7 +4,7 @@ import dataclasses from collections import deque -from typing import List, Tuple, Union +from typing import Union import numpy as np import onnx @@ -47,9 +47,9 @@ def __init__(self) -> None: # TODO: Technically, SymbolicValue should be a recursive type to handle lists of lists of # tensors, etc. However, we currently only handle lists of tensors. -SymbolicValue = Union[str, List[str]] +SymbolicValue = Union[str, list[str]] -FunctionId = Tuple[str, str, str] +FunctionId = tuple[str, str, str] def get_function_id(function: onnx.FunctionProto) -> FunctionId: diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 46e57e3400..44220a423a 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -8,11 +8,8 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, NoReturn, Optional, - Tuple, Union, ) @@ -178,11 +175,11 @@ def __init__( self.default_opset_ = default_opset # States initialized by `_init_function_translation` - self._outer: List[irbuilder.IRFunction] = [] + self._outer: list[irbuilder.IRFunction] = [] self._current_fn: irbuilder.IRFunction = None self._nextvar: int = 0 self._used_vars: set[str] = set() - self._locals: List[Dict[str, LocalSymValue]] = [{}] + self._locals: list[dict[str, LocalSymValue]] = [{}] @property def default_opset(self) -> values.Opset: @@ -230,7 +227,7 @@ def _init_function_translation(self) -> None: self._current_fn: Optional[irbuilder.IRFunction] = None self._nextvar = 0 self._used_vars = set() - self._locals: List[Dict[str, LocalSymValue]] = [{}] + self._locals: list[dict[str, LocalSymValue]] = [{}] def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo: return sourceinfo.SourceInfo(node, self.source, self._current_fn.name) @@ -269,7 +266,7 @@ def _exit_scope(self) -> irbuilder.IRFunction: self._locals.pop(0) return graph - def _current_scope(self) -> Dict[str, LocalSymValue]: + def _current_scope(self) -> dict[str, LocalSymValue]: return self._locals[0] def _bind(self, name: str, val: LocalSymValue) -> None: @@ -433,7 +430,7 @@ def _is_constant_expr(self, node: ast.AST) -> None: ast.UnaryOp, ast.Compare, ast.Attribute, - ast.List, + ast.list, ast.Load, ast.Constant, ), @@ -692,9 +689,9 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: # As the first step, we partition the index elements into four kinds: Slice (eg., 1:5:2), # known-to-be-scalar (eg., 2), other-tensor (eg., I), skip/no-op (that is, just ":") - sliced_indices: List[Tuple[int, ast.expr]] = [] - scalar_indices: List[Tuple[int, ast.expr]] = [] - non_scalar_indices: List[Tuple[int, ast.expr]] = [] + sliced_indices: list[tuple[int, ast.expr]] = [] + scalar_indices: list[tuple[int, ast.expr]] = [] + non_scalar_indices: list[tuple[int, ast.expr]] = [] for axis, elt in enumerate(indices): if isinstance(elt, ast.Slice): # Add to 
@@ -984,7 +981,7 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None:
                 typeinfo = None
             var = values.Dynamic(t, values.DynamicKind.Intermediate, info, typeinfo)
             self._bind(lhs, var)
         elif isinstance(lhs, ast.Tuple):
             # Assignments of the form "x, y, z = op.SomeOp(...)"
             if not isinstance(rhs, ast.Call):
                 self.fail(
@@ -1019,9 +1016,9 @@ def generate_onnx_name(x: ast.AST):
             self.fail(stmt, "Multi-assignment not supported.")
         lhs = targets[0]
         rhs = stmt.value
         if isinstance(rhs, ast.Tuple):
             # Assignments of the form "... = Expression1, Expression2"
             if not isinstance(lhs, ast.Tuple):
                 # Assignments of the form "single_var = Expression1, Expression2".
                 # We do not support tuple-typed variables.
                 self.fail(lhs, f"Left term must be a tuple not '{type(lhs)!r}'.")
@@ -1070,7 +1067,7 @@ def ret(exp, i, suffix):
         val = stmt.value
         assert val is not None, "Return statement without return-value not supported."
         if isinstance(val, ast.Tuple):
             check_num_outputs(len(val.elts))
             return [ret(exp, i, str(i)) for i, exp in enumerate(val.elts)]
         check_num_outputs(1)
diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py
index 6bb2da11af..9a1b2a7986 100644
--- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py
+++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py
@@ -6,7 +6,7 @@
 import dataclasses
 import logging
 from collections.abc import Mapping, Sequence
-from typing import Dict, Optional, Set
+from typing import Optional
 
 import onnx
 import onnx.defs
@@ -48,10 +48,10 @@ class TypeConstraint:
     """Type constraint shared by multiple values."""
 
     name: str
-    type_strs: Set[str]
-    values: Set[Value]
+    type_strs: set[str]
+    values: set[Value]
 
-    def __init__(self, name: str, type_strs: Set[str]):
+    def __init__(self, name: str, type_strs: set[str]):
         self.name = name
         self.type_strs = type_strs
         self.values = set()
@@ -126,9 +126,9 @@ def __repr__(self) -> str:
 
 @dataclasses.dataclass
 class OnnxFunctionTypeConstraints:
-    input_type_constraints: Dict[str, Optional[TypeConstraint]]
-    output_type_constraints: Dict[str, Optional[TypeConstraint]]
-    intermediate_type_constraints: Dict[str, Optional[TypeConstraint]]
+    input_type_constraints: dict[str, Optional[TypeConstraint]]
+    output_type_constraints: dict[str, Optional[TypeConstraint]]
+    intermediate_type_constraints: dict[str, Optional[TypeConstraint]]
 
     def __repr__(self):
         repr_strs = [
@@ -191,7 +191,7 @@ def __repr__(self):
 class TypeConstraintDeducer:
     def __init__(self, onnx_function: onnxscript.OnnxFunction):
         self.onnx_function = onnx_function
-        self.values: Dict[str, Value] = {}
+        self.values: dict[str, Value] = {}
 
     def type_constraints(self, signature_only: bool = True) -> OnnxFunctionTypeConstraints:
         """Retrieve deduced type constraints for the ONNX function."""
@@ -211,7 +211,7 @@ def type_constraints(self, signature_only: bool = True) -> OnnxFunctionTypeConst
         )
 
         # Rename type constraints to T0, T1, T2, ...
-        seen_type_constraints: Set[TypeConstraint] = set()
+        seen_type_constraints: set[TypeConstraint] = set()
         for type_constraint in (
             *input_type_constraints.values(),
             *output_type_constraints.values(),
@@ -251,7 +251,7 @@ def _bind_signature(
         node: onnx.NodeProto,
         param_names: Sequence[str],
         param_schemas: Sequence[onnx.defs.OpSchema.FormalParameter],
-        op_type_constraints: Dict[str, TypeConstraint],
+        op_type_constraints: dict[str, TypeConstraint],
         is_output: bool = False,
     ):
         param_schemas = list(param_schemas)
diff --git a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py
index f93cfdd070..76c975f4fb 100644
--- a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py
+++ b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py
@@ -14,7 +14,7 @@
 import re
 import textwrap
 from collections.abc import Sequence
-from typing import Any, Dict, List
+from typing import Any
 
 import torch
 import torchgen.gen
@@ -27,7 +27,7 @@ def create_list_type(arg: torchgen.model.Argument) -> cg.TypeRef:
     inner_arg_type = arg.type if not arg.type.is_nullable() else arg.type.elem
     assert isinstance(inner_arg_type, torchgen.model.ListType), f"arg: {arg}"
 
     arg_type = arg_type_to_str(arg.type)
     if type_is_builtin(arg_type):
@@ -75,7 +75,7 @@ def get_argument_type(arg: torchgen.model.Argument) -> cg.TypeRef:
     """Returns the Python type for the given argument."""
     inner_arg_type = arg.type if not arg.type.is_nullable() else arg.type.elem
     if isinstance(inner_arg_type, torchgen.model.ListType):
         inner_node = create_list_type(arg)
     else:
         arg_type_str = arg_type_to_str(inner_arg_type)
@@ -130,7 +130,7 @@ def parse_default_value(arg: torchgen.model.Argument) -> Any:
     else:
         if isinstance(value, int):
             # Expand the value to a tuple if the type is a list.
             if isinstance(arg.type, torchgen.model.ListType):
                 if arg.type.size is not None:
                     return (value,) * arg.type.size
                 return (value,)
@@ -242,8 +242,8 @@ def copyright_header() -> str:
 )
 
 
-def _get_func_schema_in_namespace(namespaces: List[_OpNamespace]) -> Dict[str, FunctionSchema]:
-    table: Dict[str, FunctionSchema] = {}
+def _get_func_schema_in_namespace(namespaces: list[_OpNamespace]) -> dict[str, FunctionSchema]:
+    table: dict[str, FunctionSchema] = {}
     for op_namespace in namespaces:
         for attr_name in dir(op_namespace):
             op_overload_packet = getattr(op_namespace, attr_name)
diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py
index 6ed2b935d2..fad5c6f818 100644
--- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py
+++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py
@@ -8,7 +8,7 @@
 import tempfile
 import typing
 from collections.abc import Mapping, Sequence
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 import numpy as np
 import onnx
@@ -110,7 +110,7 @@ def __init__(
         super().__init__(None)
         self._torch_value: torch.Value = value
         self._concrete_value: Optional[np.ndarray] = None
-        self._shape: Optional[Tuple[int | str | None, ...]] = None
+        self._shape: Optional[tuple[int | str | None, ...]] = None
         self._torch_dtype: Optional[torch.dtype] = None
         self._name: Optional[str] = None
         self._is_complex: bool = False
@@ -152,7 +152,7 @@ def rank(self) -> int | None:
         return value_type.dim()
 
     @property  # type: ignore[override]
-    def shape(self) -> Tuple[int | str | None, ...] | None:
+    def shape(self) -> tuple[int | str | None, ...] | None:
         if self._shape is not None:
             return self._shape
 
@@ -169,7 +169,7 @@ def shape(self) -> tuple[int | str | None, ...] | None:
         return tuple(shape)
 
     @shape.setter
-    def shape(self, shape: Union[torch.Size, Tuple[int | str | None, ...]]):
+    def shape(self, shape: Union[torch.Size, tuple[int | str | None, ...]]):
         # Normalize torch symbolic dimension size to str.
torch_sym_types = (torch.SymInt, torch.SymFloat, torch.SymBool) self._shape = tuple( @@ -250,9 +250,9 @@ def _unwrap_tensor_to_torch_value( ], ) -> Union[ ValidTorchValueType, - Dict[str, ValidTorchValueType], - List[ValidTorchValueType], - Tuple[ValidTorchValueType, ...], + dict[str, ValidTorchValueType], + list[ValidTorchValueType], + tuple[ValidTorchValueType, ...], ]: """Unwrap the TorchScriptTensor to torch.Value.""" if isinstance(value, TorchScriptTensor): @@ -274,14 +274,14 @@ def _wrap_torch_value_to_tensor( torch.Value, Mapping[str, ValidTorchValueType], Sequence[ValidTorchValueType] ], *, - shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None, + shape: Optional[Union[torch.Size, tuple[Union[int, str, None], ...]]] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ) -> Union[ ValidArgumentType, - Dict[str, ValidArgumentType], - List[ValidArgumentType], - Tuple[ValidArgumentType, ...], + dict[str, ValidArgumentType], + list[ValidArgumentType], + tuple[ValidArgumentType, ...], ]: """Wrap torch.Value to TorchScriptTensor.""" if isinstance(value, torch.Value): @@ -488,7 +488,7 @@ def _create_op_call_in_torch_graph( inputs: Sequence[torch.Value], attributes: Mapping[str, Any], n_outputs: int = 1, -) -> Tuple[torch.Value, ...]: +) -> tuple[torch.Value, ...]: """Creates a node representing an onnx op in `graph`. Args: @@ -548,17 +548,17 @@ def __init__( self._torch_graph = torch.Graph() # All the functions used, deduplicated by name # key: (name, domain) - self._function_store: Dict[Tuple[str, str], onnxscript.OnnxFunction] = {} + self._function_store: dict[tuple[str, str], onnxscript.OnnxFunction] = {} # Mapping from intializer name to data(torch.Tensor). - self._initializers: Dict[str, torch.Tensor] = {} + self._initializers: dict[str, torch.Tensor] = {} # Mapping from intializer name to input(TorchScriptTensor). - self._initializers_inputs: Dict[str, TorchScriptTensor] = {} + self._initializers_inputs: dict[str, TorchScriptTensor] = {} # Mapping from intializer name to input(TorchScriptTensor) from parent graph. - self._initializers_inputs_from_parent: Dict[str, TorchScriptTensor] = {} + self._initializers_inputs_from_parent: dict[str, TorchScriptTensor] = {} # Mapping from model local function type name to function graph. # Local function type name is expected to be unique. Converter creates # a unique name and a unique function graph for every module call. - self._sub_torch_script_graphs: Dict[str, TorchScriptGraph] = {} + self._sub_torch_script_graphs: dict[str, TorchScriptGraph] = {} # Parent graph. None if this is the top level graph. self._parent_torch_script_graph = parent_torch_script_graph # Domain name of the graph. None if this is the top level graph. @@ -572,7 +572,7 @@ def __init__( # This info is later serialized as `ValueInfoProto` inside ONNX, to # provide shape and dtype information for nodes within nested function calls. # https://github.com/onnx/onnx/issues/5487 - self._value_to_tensor: Dict[torch.Value, TorchScriptTensor] = {} + self._value_to_tensor: dict[torch.Value, TorchScriptTensor] = {} if self._domain_name is None and self._parent_torch_script_graph is not None: raise RuntimeError( @@ -592,7 +592,7 @@ def initializers(self) -> Mapping[str, torch.Tensor]: # we need to filter out the initializers that has fake tensor. This # is because we don't want to introduce fake tensor in onnxscript. 
@initializers.setter - def initializers(self, initializers: Dict[str, torch.Tensor]): + def initializers(self, initializers: dict[str, torch.Tensor]): self._initializers = initializers @property @@ -615,7 +615,7 @@ def domain_name(self) -> Optional[str]: def add_input( self, input_name: Optional[str], - shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None, + shape: Optional[Union[torch.Size, tuple[Union[int, str, None], ...]]] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ) -> TorchScriptTensor: @@ -642,7 +642,7 @@ def add_input( ) if isinstance(tensor_value, TorchScriptTensor): # NOTE: Only track value that maps to tensor. - # Value that maps to Sequence/Dict of tensors is not tracked. + # Value that maps to Sequence/dict of tensors is not tracked. self._value_to_tensor[torch_value] = tensor_value return tensor_value # type: ignore[return-value] @@ -682,7 +682,7 @@ def add_initializer(self, name: str, value: torch.Tensor) -> TorchScriptTensor: @runtime_typing.checked def register_outputs( - self, outputs: Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]] + self, outputs: Union[TorchScriptTensor, tuple[TorchScriptTensor, ...]] ): unwrapped_outputs = _unwrap_tensors_to_torch_values(outputs) if isinstance(unwrapped_outputs, torch.Value): @@ -737,7 +737,7 @@ def _add_constant_to_graph(self, constant) -> torch.Value: value.setDebugName(_rename_intermediate_value(value.debugName())) return value - def preprocess_inputs(self, onnx_inputs: Sequence[ValidInputType]) -> List[torch.Value]: + def preprocess_inputs(self, onnx_inputs: Sequence[ValidInputType]) -> list[torch.Value]: unwrapped_inputs = _unwrap_tensors_to_torch_values(onnx_inputs) graph_inputs = [] assert isinstance(unwrapped_inputs, Sequence) @@ -770,7 +770,7 @@ def _add_torchscript_op_call( onnx_inputs: Sequence[ValidInputType], onnx_attributes: Mapping[str, ValidArgumentType], n_outputs: int, - ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: + ) -> Union[TorchScriptTensor, tuple[TorchScriptTensor, ...]]: graph_inputs = self.preprocess_inputs(onnx_inputs) for key, value in onnx_attributes.items(): assert not isinstance(value, TorchScriptTensor), ( @@ -801,8 +801,8 @@ def _add_torchscript_op_call( @runtime_typing.checked def fetch_function_proto_dict( self, opset_version: int - ) -> Mapping[Tuple[str, str], onnx.FunctionProto]: - function_proto_dict: Dict[Tuple[str, str], onnx.FunctionProto] = {} + ) -> Mapping[tuple[str, str], onnx.FunctionProto]: + function_proto_dict: dict[tuple[str, str], onnx.FunctionProto] = {} # Fetch local function protos. E.g., local functions representing module calls. 
for ( sub_graph_name, @@ -893,7 +893,7 @@ def add_op_call( onnx_op_schema: onnx.defs.OpSchema, onnx_inputs: Sequence[ValidInputType], onnx_attributes: Mapping[str, ValidArgumentType], - ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: + ) -> Union[TorchScriptTensor, tuple[TorchScriptTensor, ...]]: # Compute outputs from the onnx_op op schema n_outputs = evaluator.compute_num_outputs(onnx_op_schema, onnx_inputs, onnx_attributes) result = self._add_torchscript_op_call( @@ -911,7 +911,7 @@ def add_function_call( onnx_function: onnxscript.OnnxFunction, onnx_inputs: Sequence[ValidInputType], onnx_attributes: Mapping[str, ValidArgumentType], - ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: + ) -> Union[TorchScriptTensor, tuple[TorchScriptTensor, ...]]: identifier = (onnx_function.name, onnx_function.function_ir.domain) self._function_store[identifier] = onnx_function @@ -931,7 +931,7 @@ def add_module_call( name: str, sub_torch_script_graph: TorchScriptGraph, onnx_inputs: Sequence[ValidInputType], - ) -> Union[TorchScriptTensor, Tuple[TorchScriptTensor, ...]]: + ) -> Union[TorchScriptTensor, tuple[TorchScriptTensor, ...]]: self._sub_torch_script_graphs[name] = sub_torch_script_graph domain_name = sub_torch_script_graph.domain_name assert domain_name is not None @@ -948,7 +948,7 @@ def add_module_call( def generate_function_value_info_proto( self, function_op_type: str ) -> Mapping[str, onnx.ValueInfoProto]: - named_value_info: Dict[str, onnx.ValueInfoProto] = {} + named_value_info: dict[str, onnx.ValueInfoProto] = {} function_id = _function_id(self.domain_name, function_op_type) for torch_value, tensor in self._value_to_tensor.items(): if (value_info := tensor.value_info()) is None: @@ -960,7 +960,7 @@ def generate_function_value_info_proto( return named_value_info @runtime_typing.checked - def generate_subgraphs_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]: + def generate_subgraphs_value_info_proto(self) -> dict[str, onnx.ValueInfoProto]: """Unique naming strategies for values inside subgraphs, i.e. local functions. {function_domain::function_op_type}/{value_name} @@ -970,15 +970,15 @@ def generate_subgraphs_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]: the `value_info` carried in `TorchScriptTensor` represents the general compatible shape and type. 
""" - named_value_info: Dict[str, onnx.ValueInfoProto] = {} + named_value_info: dict[str, onnx.ValueInfoProto] = {} for name, sub_graph in self._sub_torch_script_graphs.items(): named_value_info.update(sub_graph.generate_function_value_info_proto(name)) return named_value_info @runtime_typing.checked - def generate_maingraph_value_info_proto(self) -> Dict[str, onnx.ValueInfoProto]: + def generate_maingraph_value_info_proto(self) -> dict[str, onnx.ValueInfoProto]: """Returns value info proto for values in the main graph.""" - named_value_info: Dict[str, onnx.ValueInfoProto] = {} + named_value_info: dict[str, onnx.ValueInfoProto] = {} for torch_value, tensor in self._value_to_tensor.items(): if (value_info := tensor.value_info()) is None: continue @@ -1034,10 +1034,10 @@ def to_function_proto(self, opset_version: int, function_name: str) -> onnx.Func def to_model_proto( self, opset_version: int, include_initializers: bool = True ) -> onnx.ModelProto: - function_proto_dict: Mapping[Tuple[str, str], onnx.FunctionProto] = ( + function_proto_dict: Mapping[tuple[str, str], onnx.FunctionProto] = ( self.fetch_function_proto_dict(opset_version) ) - unique_custom_domains: Dict[str, int] = {} + unique_custom_domains: dict[str, int] = {} for function_proto in function_proto_dict.values(): # TODO(BowenBao): All local function domain versions are hardcoded as 1. diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e0bfa74297..e0776d271c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -13,7 +13,7 @@ import math from collections.abc import Sequence -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Union import numpy as np import torch @@ -595,7 +595,7 @@ def _adjust_args_for_arange_int_dtype( start: TRealUnlessFloat16OrInt8, end: TRealUnlessFloat16OrInt8, step: TRealUnlessFloat16OrInt8, -) -> Tuple[FLOAT, FLOAT, FLOAT]: +) -> tuple[FLOAT, FLOAT, FLOAT]: zero = op.Cast(0.0, to=FLOAT.dtype) start = op.Cast(start, to=FLOAT.dtype) end = op.Cast(end, to=FLOAT.dtype) @@ -2958,7 +2958,7 @@ def aten_embedding_bag( sparse: bool = False, per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, -) -> Tuple[TFloat, TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat, TFloat]: """embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)""" # assert(rank(indices) in [1,2]) @@ -2986,7 +2986,7 @@ def _aten_embedding_bag_onnx( mode: int, per_sample_weights: TFloat, include_last_offset: bool, -) -> Tuple[TFloat, TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat, TFloat]: neg_1 = op.Constant(value_ints=[-1]) # Assume indices is shape(5,2), indices_1d is shape(10,) indices_1d = op.Reshape(indices, neg_1) @@ -3093,7 +3093,7 @@ def aten_embedding_bag_padding_idx( per_sample_weights: Optional[TFloat] = None, include_last_offset: bool = False, padding_idx: int = -1, -) -> Tuple[TFloat, TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat, TFloat]: """embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? 
padding_idx) -> (Tensor, Tensor, Tensor, Tensor) We add default values for the attributes to accommodate _embedding_bag as well: @@ -3127,7 +3127,7 @@ def _aten_embedding_bag_1d_padding_idx_onnx( per_sample_weights: TFloat, include_last_offset: bool, padding_idx: int, -) -> Tuple[TFloat, TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat, TFloat]: neg_1 = op.Constant(value_ints=[-1]) # Get weight out according to indices, # e.g. indices=[3,1,4,5,3] means get weight[[3,1,4,5,3]] @@ -5290,7 +5290,7 @@ def aten_max(self: TReal) -> TReal: @torch_op("aten::max.dim", trace_only=True) -def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, INT64]: +def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> tuple[TReal, INT64]: """max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)""" if len(self.shape) == 0: @@ -5357,7 +5357,7 @@ def aten_min(self: TReal) -> TReal: @torch_op("aten::min.dim", trace_only=True) -def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, TInt]: +def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> tuple[TReal, TInt]: """min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)""" if len(self.shape) == 0: result = self @@ -5892,7 +5892,7 @@ def aten__native_batch_norm_no_training( running_var: Optional[TFloat] = None, momentum: float = 0.9, eps: float = 1e-05, -) -> Tuple[TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat]: """_native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)""" return aten_native_batch_norm( @@ -5908,7 +5908,7 @@ def aten__native_batch_norm_no_stats( training: bool = False, momentum: float = 0.9, eps: float = 1e-05, -) -> Tuple[TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat]: """_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)""" return aten_native_batch_norm(input, weight, bias, None, None, training, momentum, eps) @@ -5924,7 +5924,7 @@ def aten_native_batch_norm( training: bool = False, momentum: float = 0.9, eps: float = 1e-05, -) -> Tuple[TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat]: """native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)""" if weight is None: # Set to 1.0 as default @@ -5992,7 +5992,7 @@ def _aten_native_batch_norm_training_onnx( axes: INT64, momentum: float, eps: float, -) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: """Batch normalization training mode. NOTE: momentum in PyTorch is 1.0-momentum in ONNX. @@ -6043,7 +6043,7 @@ def _aten_native_batch_norm_inference_onnx( running_var: TFloat, momentum: float, eps: float, -) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: """Batch normalization inference mode. NOTE: momentum in PyTorch is 1.0-momentum in ONNX. 
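# [Editor's note] Illustrative aside, not part of the patch: the NOTE above in
# worked form. PyTorch updates running statistics as
#     new_running = (1 - momentum) * running + momentum * batch_stat
# whereas ONNX BatchNormalization uses
#     new_running = momentum * running + (1 - momentum) * batch_stat
# so the value forwarded to ONNX must be 1.0 - torch_momentum.
torch_momentum = 0.9
onnx_momentum = 1.0 - torch_momentum

running, batch_stat = 0.0, 10.0
torch_update = (1.0 - torch_momentum) * running + torch_momentum * batch_stat
onnx_update = onnx_momentum * running + (1.0 - onnx_momentum) * batch_stat
assert torch_update == onnx_update  # both 9.0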
@@ -6083,7 +6083,7 @@ def aten__native_batch_norm_legit_functional( training: bool = False, momentum: float = 0.9, eps: float = 1e-05, -) -> Tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat, TFloat, TFloat]: if weight is None: # Set to 1.0 as default weight = op.Expand(op.Constant(value_floats=[1.0]), op.Shape(input, start=1, end=2)) @@ -6169,7 +6169,7 @@ def aten_native_channel_shuffle(self: TensorType, groups: int) -> TensorType: @torch_op("aten::native_dropout", trace_only=True) -def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> Tuple[TFloat, BOOL]: +def aten_native_dropout(input: TFloat, p: float, train: bool = True) -> tuple[TFloat, BOOL]: """native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)""" result, mask = op.Dropout(input, p, train) @@ -6194,7 +6194,7 @@ def aten_native_group_norm( HxW: Optional[INT64] = None, group: int = 1, eps: float = 1e-05, -) -> Tuple[TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat]: """native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)""" # Actually we don't need N,C,HxW value because the input tensor has that information @@ -6216,7 +6216,7 @@ def _aten_native_group_norm_onnx( bias: TFloat, group: INT64, eps: float, -) -> Tuple[TFloat, TFloat, TFloat]: +) -> tuple[TFloat, TFloat, TFloat]: # Because onnx.GroupNorm() need size=group for weight and bias # But the torch's aten function's input need size=channel, the size mismatched # So we have to use onnx.InstanceNorm() to simulate @@ -6286,7 +6286,7 @@ def aten_native_layer_norm( weight: Optional[TReal] = None, bias: Optional[TReal] = None, eps: float = 1e-05, -) -> Tuple[TReal, TReal, TReal]: +) -> tuple[TReal, TReal, TReal]: """native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)""" # https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html#torch.nn.LayerNorm @@ -8114,7 +8114,7 @@ def aten_std_correction( # std_mean is decomposed by PyTroch -def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: +def aten_std_mean(self: TReal, unbiased: bool = True) -> tuple[TReal, TReal]: """std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)""" # Assume bool(True) and int(1) are same in ONNX, so pass "unbiased" directly as "correction" @@ -8126,7 +8126,7 @@ def aten_std_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: # std_mean is decomposed by PyTroch def aten_std_mean_dim( self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False -) -> Tuple[TReal, TReal]: +) -> tuple[TReal, TReal]: """std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)""" # Although dim is Optional in signature, but we assume it must have value for this overload @@ -8144,7 +8144,7 @@ def aten_std_mean_correction( dim: Optional[int] = None, correction: Optional[float] = None, keepdim: bool = False, -) -> Tuple[TReal, TReal]: +) -> tuple[TReal, TReal]: """std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? 
correction=None, bool keepdim=False) -> (Tensor, Tensor)""" if correction is None: @@ -8444,7 +8444,7 @@ def aten_to_sparse_csr(self: TensorType) -> TensorType: @torch_op("aten::topk", trace_only=True) def aten_topk( self: TReal, k: int, dim: int = -1, largest: bool = True, sorted: bool = True -) -> Tuple[TReal, INT64]: +) -> tuple[TReal, INT64]: """topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)""" # We do not handle scalar inputs for topk @@ -8909,7 +8909,7 @@ def _aten_var_dim_onnx( # var_mean is decomposed by PyTroch -def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: +def aten_var_mean(self: TReal, unbiased: bool = True) -> tuple[TReal, TReal]: """var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)""" # Assume bool(True) and int(1) are same in ONNX, so pass "unbiased" directly as "correction" @@ -8920,7 +8920,7 @@ def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]: # var_mean is decomposed by PyTroch def aten_var_mean_dim( self: TReal, dim: Sequence[int], unbiased: bool = True, keepdim: bool = False -) -> Tuple[TReal, TReal]: +) -> tuple[TReal, TReal]: """var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)""" # Although dim is Optional in signature, but we assume it must have value for this overload @@ -8935,7 +8935,7 @@ def aten_var_mean_correction( dim: Optional[int] = None, correction: Optional[float] = None, keepdim: bool = False, -) -> Tuple[TReal, TReal]: +) -> tuple[TReal, TReal]: """var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)""" if correction is None: @@ -8953,7 +8953,7 @@ def aten_var_mean_correction( # var_mean is decomposed by PyTroch def _aten_var_mean_onnx( self: TReal, correction: float = 1.0, keepdim: bool = False -) -> Tuple[TReal, TReal]: +) -> tuple[TReal, TReal]: # Compute mean and var mean = op.ReduceMean(self, keepdims=keepdim) sub_mean = op.Sub(self, mean) @@ -8973,7 +8973,7 @@ def _aten_var_mean_onnx( # var_mean is decomposed by PyTroch def _aten_var_mean_dim_onnx( self: TReal, dims: Sequence[int], correction: float, keepdim: bool = False -) -> Tuple[TReal, TReal]: +) -> tuple[TReal, TReal]: dims = op.Reshape(dims, op.Constant(value_ints=[-1])) # Computer mean and var mean = op.ReduceMean(self, dims, keepdims=keepdim) diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index 0931ac533a..34def04423 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -24,8 +24,8 @@ import typing from collections import OrderedDict from collections.abc import Collection, Hashable, Iterable, Iterator, Sequence +from collections.abc import Set as AbstractSet from typing import ( - AbstractSet, Any, Generic, NamedTuple, diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py index e83e5ac825..46dede393a 100644 --- a/onnxscript/onnx_types.py +++ b/onnxscript/onnx_types.py @@ -4,7 +4,7 @@ from __future__ import annotations import abc -from typing import ClassVar, Optional, Tuple, Union +from typing import ClassVar, Optional, Union import onnx import onnx.helper @@ -13,7 +13,7 @@ _DType = onnxscript.ir.DataType _DimType = Union[int, str, type(None)] -_ShapeType = Union[Tuple[_DimType, ...], _DimType, type(Ellipsis)] +_ShapeType = Union[tuple[_DimType, ...], _DimType, type(Ellipsis)] _tensor_type_shape_cache: dict[_DType, TensorType] = {} tensor_type_registry: dict[_DType, TensorType] = {} From 
ed5e1b88be696d3734e8b0d47ea8df376f2008ac Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 27 Apr 2025 19:03:42 -0700 Subject: [PATCH 06/24] lint --- onnxscript/function_libs/torch_lib/ops/nn.py | 24 ++++++++++---------- onnxscript/ir/_protocols.py | 8 ++----- onnxscript/ir/_tape.py | 8 ++----- onnxscript/ir/passes/common/inliner.py | 5 ++-- onnxscript/ir/serde.py | 4 ++-- onnxscript/rewriter/pattern.py | 3 +-- onnxscript/type_annotation_test.py | 4 ++-- 7 files changed, 23 insertions(+), 33 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 93bed5a51c..96ab15020a 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -16,7 +16,7 @@ import math from collections.abc import Sequence -from typing import Optional, Tuple, TypeVar, Union +from typing import Optional, TypeVar, Union import onnx @@ -92,7 +92,7 @@ def _adjust_attributes_of_avg_pool( kernel_size: Sequence[int], stride: Sequence[int], padding: Sequence[int], -) -> Tuple[Sequence[int], Sequence[int], Sequence[int]]: +) -> tuple[Sequence[int], Sequence[int], Sequence[int]]: """Adjust attributes of avg_pool to match ONNX specification.""" if isinstance(kernel_size, int): @@ -897,7 +897,7 @@ def aten_max_pool1d_with_indices( padding: Sequence[int] = (0,), dilation: Sequence[int] = (1,), ceil_mode: bool = False, -) -> Tuple[TFloatOrUInt8, INT64]: +) -> tuple[TFloatOrUInt8, INT64]: """max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)""" # Torch prefers to use single number x for kernel, stride, pad and dilation on both sides implicitly. @@ -928,7 +928,7 @@ def _adjust_attributes_of_max_pool( stride: Sequence[int], padding: Sequence[int], dilation: Sequence[int], -) -> Tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]: +) -> tuple[Sequence[int], Sequence[int], Sequence[int], Sequence[int]]: if isinstance(dilation, int): dilations = [dilation] * expand_size else: @@ -1050,7 +1050,7 @@ def aten_max_pool2d_with_indices( padding: Sequence[int] = (0, 0), dilation: Sequence[int] = (1, 1), ceil_mode: bool = False, -) -> Tuple[TFloatOrUInt8, INT64]: +) -> tuple[TFloatOrUInt8, INT64]: """max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)""" # Torch prefers to use single number x for kernel, stride, pad and dilation on both sides implicitly. @@ -1098,7 +1098,7 @@ def aten_max_pool3d_with_indices( padding: Sequence[int] = (0, 0, 0), dilation: Sequence[int] = (1, 1, 1), ceil_mode: bool = False, -) -> Tuple[TFloatOrUInt8, INT64]: +) -> tuple[TFloatOrUInt8, INT64]: """max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)""" # Torch prefers to use single number x for kernel, stride, pad and dilation on both sides implicitly. 
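# [Editor's note] Illustrative aside, not part of the patch: a minimal sketch
# of the broadcasting convention the pooling adapters above follow
# (simplified; the real helpers also normalize stride, padding, and dilation
# together). A bare int is expanded to one value per spatial dimension.
from collections.abc import Sequence
from typing import Union

def expand(value: Union[int, Sequence[int]], expand_size: int) -> tuple[int, ...]:
    # A single int applies symmetrically to every spatial dimension.
    if isinstance(value, int):
        return (value,) * expand_size
    return tuple(value)

assert expand(3, 2) == (3, 3)
assert expand([1, 2, 1], 3) == (1, 2, 1)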
@@ -1134,7 +1134,7 @@ def _aten_max_pool_with_indices_onnx( n_dims_one: Sequence[int], n_dims_zero: Sequence[int], n_dims_axes: Sequence[int], -) -> Tuple[TFloatOrUInt8, INT64]: +) -> tuple[TFloatOrUInt8, INT64]: self_rank_is_unbatched_rank = Rank(self) == unbatched_rank if self_rank_is_unbatched_rank: self = op.Unsqueeze(self, axes=[0]) @@ -1794,7 +1794,7 @@ def aten_scaled_dot_product_attention( def _aten__scaled_dot_product_flash_attention_fillin_empty_outputs( query: TFloat, -) -> Tuple[FLOAT, INT64, INT64, FLOAT]: +) -> tuple[FLOAT, INT64, INT64, FLOAT]: query_first_three_dims = op.Slice( op.Shape(query), op.Constant(value_ints=[0]), op.Constant(value_ints=[3]) ) @@ -1823,7 +1823,7 @@ def aten__scaled_dot_product_flash_attention( is_causal: bool = False, return_debug_mask: bool = False, scale: Optional[float] = None, -) -> Tuple[TFloat, FLOAT, INT64, INT64, INT64, INT64, INT64, INT64, FLOAT]: +) -> tuple[TFloat, FLOAT, INT64, INT64, INT64, INT64, INT64, INT64, FLOAT]: """_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) One of the implementations of scaled_dot_product_attention. @@ -1862,7 +1862,7 @@ def aten__scaled_dot_product_flash_attention( def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs( query: TFloat, compute_log_sumexp: bool, -) -> Tuple[FLOAT, INT64]: +) -> tuple[FLOAT, INT64]: """_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)""" query = op.Transpose(query, perm=[0, 2, 1, 3]) @@ -1901,7 +1901,7 @@ def aten__scaled_dot_product_flash_attention_for_cpu( is_causal: bool = False, attn_mask: Optional[TFloat] = None, scale: Optional[float] = None, -) -> Tuple[TFloat, FLOAT]: +) -> tuple[TFloat, FLOAT]: """_scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)""" result = aten_scaled_dot_product_attention( query, @@ -1933,7 +1933,7 @@ def aten__scaled_dot_product_efficient_attention( dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, -) -> Tuple[TFloat, FLOAT, INT64, INT64]: +) -> tuple[TFloat, FLOAT, INT64, INT64]: """_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)""" result = aten_scaled_dot_product_attention( diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py index a68761f48d..073f5b7959 100644 --- a/onnxscript/ir/_protocols.py +++ b/onnxscript/ir/_protocols.py @@ -41,11 +41,7 @@ MutableSequence, Sequence, ) -from typing import ( - Any, - Protocol, - Tuple, -) +from typing import Any, Protocol from onnxscript.ir import _enums @@ -54,7 +50,7 @@ from typing_extensions import TypeAlias # An identifier that will uniquely identify an operator. 
E.g (domain, op_type, overload) -OperatorIdentifier: TypeAlias = Tuple[str, str, str] +OperatorIdentifier: TypeAlias = tuple[str, str, str] @typing.runtime_checkable diff --git a/onnxscript/ir/_tape.py b/onnxscript/ir/_tape.py index b3472949a5..9180cb5b38 100644 --- a/onnxscript/ir/_tape.py +++ b/onnxscript/ir/_tape.py @@ -5,17 +5,13 @@ from __future__ import annotations from collections.abc import Mapping, Sequence -from typing import ( - Any, - Optional, - Tuple, -) +from typing import Any, Optional from onnxscript import ir from onnxscript.ir import _convenience # A type representing the domains/versions used in creating nodes in IR. -UsedOpsets = set[Tuple[str, Optional[int]]] +UsedOpsets = set[tuple[str, Optional[int]]] class Tape: diff --git a/onnxscript/ir/passes/common/inliner.py b/onnxscript/ir/passes/common/inliner.py index 23034a1531..61229d02fd 100644 --- a/onnxscript/ir/passes/common/inliner.py +++ b/onnxscript/ir/passes/common/inliner.py @@ -10,7 +10,6 @@ from collections import defaultdict from collections.abc import Iterable, Sequence -from typing import List, Tuple import onnxscript.ir.convenience as _ir_convenience from onnxscript import ir @@ -18,13 +17,13 @@ # A replacement for a node specifies a list of nodes that replaces the original node, # and a list of values that replaces the original node's outputs. -NodeReplacement = Tuple[Sequence[ir.Node], Sequence[ir.Value]] +NodeReplacement = tuple[Sequence[ir.Node], Sequence[ir.Value]] # A call stack is a list of identifiers of call sites, where the first element is the # outermost call site, and the last element is the innermost call site. This is used # primarily for generating unique names for values in the inlined functions. CallSiteId = str -CallStack = List[CallSiteId] +CallStack = list[CallSiteId] def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str: # pylint: disable=unused-argument diff --git a/onnxscript/ir/serde.py b/onnxscript/ir/serde.py index 8be63125e3..c6d97d464a 100644 --- a/onnxscript/ir/serde.py +++ b/onnxscript/ir/serde.py @@ -62,7 +62,7 @@ import logging import os from collections.abc import Mapping, Sequence -from typing import Any, Callable, List +from typing import Any, Callable import numpy as np import onnx @@ -741,7 +741,7 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function: name=proto.name, overload=getattr(proto, "overload", ""), graph=graph, - attributes=typing.cast(List[_core.Attr], attributes), + attributes=typing.cast("list[_core.Attr]", attributes), ) diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 690fb0f069..9363dc31a1 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -15,7 +15,6 @@ Any, Callable, Protocol, - Tuple, TypeVar, Union, ) @@ -532,7 +531,7 @@ def __str__(self) -> str: inputs_and_attributes = f"{inputs}, {attributes}" if attributes else inputs return f"{outputs} = {qualified_op} ({inputs_and_attributes})" - def op_identifier(self) -> Tuple[str, str, str] | None: + def op_identifier(self) -> tuple[str, str, str] | None: return self._op_identifier @property diff --git a/onnxscript/type_annotation_test.py b/onnxscript/type_annotation_test.py index 715dfc2cae..b2e5b27916 100644 --- a/onnxscript/type_annotation_test.py +++ b/onnxscript/type_annotation_test.py @@ -3,7 +3,7 @@ import unittest from collections.abc import Sequence -from typing import Any, List, Optional, TypeVar, Union +from typing import Any, Optional, TypeVar, Union import parameterized @@ -213,7 
+213,7 @@ class TypeConversionFunctionsTest(unittest.TestCase): ), ] ) - def test_pytype_to_type_strings(self, _, pytype: Any, expected: List[str]): + def test_pytype_to_type_strings(self, _, pytype: Any, expected: list[str]): self.assertEqual(type_annotation.pytype_to_type_strings(pytype), expected) @parameterized.parameterized.expand( From 442b9c5638627902fbaa166f10feed3b418e5a80 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Sun, 27 Apr 2025 19:07:04 -0700 Subject: [PATCH 07/24] lint --- opgen/onnx_opset_builder.py | 6 +++--- opgen/pygen.py | 15 ++++++--------- pyproject.toml | 2 +- tests/function_libs/torch_lib/extra_opinfo.py | 4 ++-- tests/function_libs/torch_lib/ops_test.py | 4 ++-- 5 files changed, 14 insertions(+), 17 deletions(-) diff --git a/opgen/onnx_opset_builder.py b/opgen/onnx_opset_builder.py index c0ae8d5158..d94cba244f 100644 --- a/opgen/onnx_opset_builder.py +++ b/opgen/onnx_opset_builder.py @@ -8,7 +8,7 @@ from collections.abc import Iterable from pathlib import Path from textwrap import dedent -from typing import Annotated, Any, Optional, Set, TextIO +from typing import Annotated, Any, Optional, TextIO import pygen as cg from onnx.defs import ( @@ -158,8 +158,8 @@ def __init__( *, module_base_name: str, min_default_opset_version: int, - include_opsets: Optional[Set[OpsetId]] = None, - exclude_opsets: Optional[Set[OpsetId]] = None, + include_opsets: Optional[set[OpsetId]] = None, + exclude_opsets: Optional[set[OpsetId]] = None, ): self.module_base_name = module_base_name self.min_default_opset_version = min_default_opset_version diff --git a/opgen/pygen.py b/opgen/pygen.py index 730b624573..a0228b0937 100644 --- a/opgen/pygen.py +++ b/opgen/pygen.py @@ -16,10 +16,7 @@ Callable, Generic, Optional, - Set, TextIO, - Tuple, - Type, TypeVar, Union, ) @@ -30,7 +27,7 @@ NoneType = type(None) -def _assert_instance(instance, expected_type: Union[Type, Tuple[Type, ...]]): +def _assert_instance(instance, expected_type: Union[type, tuple[type, ...]]): if not isinstance(instance, expected_type): raise TypeError(f"expected: {expected_type!r}; actual: {instance!r}") @@ -71,7 +68,7 @@ class NodePredicate: def __init__( self, role: Optional[Role] = None, - type_: Optional[Type[TNode]] = None, + type_: Optional[type[TNode]] = None, func: Optional[Callable[[Node], bool]] = None, ): _assert_instance(role, (Role, NoneType)) @@ -164,7 +161,7 @@ def get_children_in_role(self, role: Role): _assert_instance(role, Role) return self.get_children(NodePredicate(role=role)) - def get_children_of_type(self, type_: Type[TNode]) -> Iterable[TNode]: + def get_children_of_type(self, type_: type[TNode]) -> Iterable[TNode]: _assert_instance(type_, type) return self.get_children(NodePredicate(type_=type_)) @@ -183,7 +180,7 @@ def get_ancestors_in_role(self, role: Role, and_self=False): _assert_instance(role, Role) return self.get_ancestors(NodePredicate(role=role), and_self=and_self) - def get_ancestors_of_type(self, type_: Type[TNode], and_self=False) -> Iterable[TNode]: + def get_ancestors_of_type(self, type_: type[TNode], and_self=False) -> Iterable[TNode]: _assert_instance(type_, type) return self.get_ancestors(NodePredicate(type_=type_), and_self=and_self) @@ -1131,7 +1128,7 @@ def __init__(self, predicate: NodePredicate): super().__init__() _assert_instance(predicate, NodePredicate) self._predicate = predicate - self.names: Set[str] = set() + self.names: set[str] = set() def leave(self, node: Node) -> Optional[bool]: if self._predicate.matches(node) and hasattr(node, "name"): @@ -1141,7 +1138,7 @@ 
def leave(self, node: Node) -> Optional[bool]: class ImportAdjuster(FixupVisitor): def __init__(self): super().__init__() - self.naming_conflicts: Set[str] = set() + self.naming_conflicts: set[str] = set() def enter(self, node: Node): if len(self.node_stack) == 0: diff --git a/pyproject.toml b/pyproject.toml index 27952cf4ae..84f6484357 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -214,7 +214,7 @@ ignore-init-module-imports = true "setup.py" = ["TID251"] # pathlib is allowed in supporting code "**/{examples,tests,docs,tools,utils,opgen,_framework_apis}/*" = ["TID251"] # pathlib is allowed in supporting code "**/*_test.py" = ["TID251"] # pathlib is allowed in tests -"onnxscript/onnx_opset/_impl/*" = ["UP035"] # Need to update opgen to use the new types +"onnxscript/onnx_opset/*" = ["UP035"] # Need to update opgen to use the new types [tool.ruff.lint.flake8-tidy-imports] # Disallow all relative imports. diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 26b75bf93b..850567fc10 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -7,7 +7,7 @@ import functools import itertools -from typing import Any, List +from typing import Any import torch import torchvision @@ -2150,7 +2150,7 @@ def __init__(self): # in ops_test_data.py and opinfo_core.OpInfo("unique_name", ...) # To avoid name duplication, it is possible to rename the OpInfo and specify # the `op` field explicitly. -OP_DB: List[opinfo_core.OpInfo] = [ +OP_DB: list[opinfo_core.OpInfo] = [ opinfo_core.OpInfo( "ops.aten.bernoulli.p", aten_name="bernoulli.p", diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index c652916311..e01deecf5c 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -28,7 +28,7 @@ import os import unittest from collections.abc import Sequence -from typing import Callable, Optional, Tuple +from typing import Callable, Optional import numpy as np import onnx @@ -72,7 +72,7 @@ def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]: def _should_skip_xfail_test_sample( op_name: str, sample, dtype: torch.dtype, device_type: str -) -> Tuple[Optional[str], Optional[str]]: +) -> tuple[Optional[str], Optional[str]]: """Returns a reason if a test sample should be skipped.""" if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS: return None, None From 4d6ea344de8930b7d09fa31398259f4ca0a6f574 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 09:59:12 -0700 Subject: [PATCH 08/24] Fix ast --- onnxscript/converter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 44220a423a..0a7ae3b603 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -430,7 +430,7 @@ def _is_constant_expr(self, node: ast.AST) -> None: ast.UnaryOp, ast.Compare, ast.Attribute, - ast.list, + ast.List, ast.Load, ast.Constant, ), @@ -981,7 +981,7 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None: typeinfo = None var = values.Dynamic(t, values.DynamicKind.Intermediate, info, typeinfo) self._bind(lhs, var) - elif isinstance(lhs, ast.tuple): + elif isinstance(lhs, ast.Tuple): # Assignments of the form "x, y, z = op.SomeOp(...)" if not isinstance(rhs, ast.Call): self.fail( @@ -1016,9 +1016,9 @@ def generate_onnx_name(x: ast.AST): self.fail(stmt, "Multi-assignment not supported.") lhs = targets[0] rhs = 
stmt.value - if isinstance(rhs, ast.tuple): + if isinstance(rhs, ast.Tuple): # Assignments of the form "... = Expression1, Expression2" - if not isinstance(lhs, ast.tuple): + if not isinstance(lhs, ast.Tuple): # Assignments of the form "single_var = Expression1, Expression2". # We do not support tuple-typed variables. self.fail(lhs, f"Left term must be a tuple not '{type(lhs)!r}'.") @@ -1067,7 +1067,7 @@ def ret(exp, i, suffix): val = stmt.value assert val is not None, "Return statement without return-value not supported." - if isinstance(val, ast.tuple): + if isinstance(val, ast.Tuple): check_num_outputs(len(val.elts)) return [ret(exp, i, str(i)) for i, exp in enumerate(val.elts)] check_num_outputs(1) From bbe99cb648e8a6e8688a2f0e674ecd67fe1ba6d6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 10:06:16 -0700 Subject: [PATCH 09/24] debugging --- onnxscript/ir/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 04b5574c0b..7a2e1935ee 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -2,6 +2,8 @@ # Licensed under the MIT License. """In-memory intermediate representation for ONNX graphs.""" +from __future__ import annotations + __all__ = [ # Modules "serde", From 2f8aab4b07ba9b571fbee09e35480810aafce4c3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 11:56:59 -0700 Subject: [PATCH 10/24] update --- onnxscript/ir/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 1a03529d3a..481ecda152 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -153,7 +153,9 @@ def __set_module() -> None: """Set the module of all functions in this module to this public module.""" global_dict = globals() for name in __all__: - global_dict[name].__module__ = __name__ + if hasattr(global_dict[name], "__module__"): + # Set the module of the function to this module + global_dict[name].__module__ = __name__ __set_module() From e29d8c6e7ab18246545bf73e254f0ec342745f9e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 12:00:00 -0700 Subject: [PATCH 11/24] GenericAlias --- onnxscript/ir/__init__.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 481ecda152..fd9b94f930 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -152,10 +152,13 @@ def __set_module() -> None: """Set the module of all functions in this module to this public module.""" global_dict = globals() + import types + for name in __all__: - if hasattr(global_dict[name], "__module__"): - # Set the module of the function to this module - global_dict[name].__module__ = __name__ + if type(global_dict[name]) is types.GenericAlias: + continue + # Set the module of the function to this module + global_dict[name].__module__ = __name__ __set_module() From 1fa8c11fa6fbfccae2279aadf59227221a405a43 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 12:00:19 -0700 Subject: [PATCH 12/24] test --- onnxscript/ir/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index fd9b94f930..1613b681d9 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -156,6 +156,7 @@ def __set_module() -> None: for name in __all__: if type(global_dict[name]) is types.GenericAlias: + # GenericAlias doesn't have a __module__ attribute continue # Set 
the module of the function to this module global_dict[name].__module__ = __name__ From bcecabe09d20bb912a25910af851a29ea6d4d393 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 12:02:31 -0700 Subject: [PATCH 13/24] type --- onnxscript/ir/__init__.py | 6 ------ onnxscript/ir/_protocols.py | 6 ++++-- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/onnxscript/ir/__init__.py b/onnxscript/ir/__init__.py index 1613b681d9..1a03529d3a 100644 --- a/onnxscript/ir/__init__.py +++ b/onnxscript/ir/__init__.py @@ -152,13 +152,7 @@ def __set_module() -> None: """Set the module of all functions in this module to this public module.""" global_dict = globals() - import types - for name in __all__: - if type(global_dict[name]) is types.GenericAlias: - # GenericAlias doesn't have a __module__ attribute - continue - # Set the module of the function to this module global_dict[name].__module__ = __name__ diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py index 073f5b7959..029b4e056b 100644 --- a/onnxscript/ir/_protocols.py +++ b/onnxscript/ir/_protocols.py @@ -41,7 +41,7 @@ MutableSequence, Sequence, ) -from typing import Any, Protocol +from typing import Any, Protocol, Tuple # noqa: UP035 from onnxscript.ir import _enums @@ -50,7 +50,9 @@ from typing_extensions import TypeAlias # An identifier that will uniquely identify an operator. E.g (domain, op_type, overload) -OperatorIdentifier: TypeAlias = tuple[str, str, str] +OperatorIdentifier: TypeAlias = Tuple[ # Requires Tuple because tuple[] does not have __module__ + str, str, str +] @typing.runtime_checkable From be33f0abe13b29711a6af41f0ff9c7ed0f557d86 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 13:07:44 -0700 Subject: [PATCH 14/24] format --- onnxscript/ir/_protocols.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/_protocols.py b/onnxscript/ir/_protocols.py index 029b4e056b..a2425b86dd 100644 --- a/onnxscript/ir/_protocols.py +++ b/onnxscript/ir/_protocols.py @@ -50,9 +50,11 @@ from typing_extensions import TypeAlias # An identifier that will uniquely identify an operator. E.g (domain, op_type, overload) -OperatorIdentifier: TypeAlias = Tuple[ # Requires Tuple because tuple[] does not have __module__ - str, str, str -] +OperatorIdentifier: TypeAlias = ( + Tuple[ # Requires Tuple because tuple[] does not have __module__ + str, str, str + ] +) @typing.runtime_checkable From 51cb8e1cbae6dcb1b39fa504313dd2dc32363102 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 13:34:27 -0700 Subject: [PATCH 15/24] GenericAlias --- onnxscript/values.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxscript/values.py b/onnxscript/values.py index ecba1ca5f6..1b76780516 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -749,7 +749,10 @@ def __init__(self, info: sourceinfo.SourceInfo) -> None: class AttrRef(SymbolValue): def __init__( - self, attr_name: str, typeinfo: _GenericAlias, info: sourceinfo.SourceInfo + self, + attr_name: str, + typeinfo: _GenericAlias | types.GenericAlias, + info: sourceinfo.SourceInfo, ) -> None: """Initializes AttrRef. @@ -762,8 +765,9 @@ def __init__( super().__init__(info) self.value = attr_name self.typeinfo = typeinfo - if not isinstance(typeinfo, (type, _GenericAlias)): + if not isinstance(typeinfo, (type, _GenericAlias, types.GenericAlias)): # typing._GenericAlias for List[int] and List[str], etc. + # types.GenericAlias for list[int] and tuple[int], etc. 
raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.") self.typeinfo = typeinfo From 8d158cc662a7cda6a130fb5120f9ab1876a1a07d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 13:34:58 -0700 Subject: [PATCH 16/24] err --- onnxscript/values.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/values.py b/onnxscript/values.py index 1b76780516..cf3093d0ef 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -768,7 +768,7 @@ def __init__( if not isinstance(typeinfo, (type, _GenericAlias, types.GenericAlias)): # typing._GenericAlias for List[int] and List[str], etc. # types.GenericAlias for list[int] and tuple[int], etc. - raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.") + raise TypeError(f"Expecting a type not {type(typeinfo)} for typeinfo.") self.typeinfo = typeinfo From 72d481ad3ad45a154e3bfa84b3b883a30384d5cb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 13:36:59 -0700 Subject: [PATCH 17/24] Sequence --- onnxscript/type_annotation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index e66288aa25..6593d8fcbc 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. from __future__ import annotations -import collections import inspect import typing from collections.abc import Sequence @@ -41,7 +40,7 @@ bool: onnx.AttributeProto.INTS, # experimental } -_LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence]) +_LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, Sequence]) # Map from ONNX AttributeProto type to its representation (in ONNX Script). _ATTRTYPE_TO_REPR = { From 23e74d57be51dd5f66da9b8eb0f5cfd3e7253129 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 13:39:01 -0700 Subject: [PATCH 18/24] update --- onnxscript/_internal/runtime_typing.py | 4 ---- onnxscript/type_annotation.py | 6 ++---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/onnxscript/_internal/runtime_typing.py b/onnxscript/_internal/runtime_typing.py index 3cf8a8db57..d636bf6381 100644 --- a/onnxscript/_internal/runtime_typing.py +++ b/onnxscript/_internal/runtime_typing.py @@ -22,10 +22,6 @@ checked = typing.cast(typing.Callable[[T], T], _beartype_decorator) - # Beartype warns when we import from typing because the types are deprecated - # in Python 3.9. But there will be a long time until we can move to using - # the native container types for type annotations (when 3.9 is the lowest - # supported version). So we silence the warning. 
warnings.filterwarnings( "ignore", category=_roar.BeartypeDecorHintPep585DeprecationWarning, diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 6593d8fcbc..3f4bb4ee19 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -76,10 +76,8 @@ def onnx_attr_type_to_onnxscript_repr(attr_type: onnx.AttributeProto.AttributeTy def _remove_annotation(typeinfo: TypeAnnotationValue) -> TypeAnnotationValue: """Remove Annotated wrapper if present, otherwise return typeinfo as is.""" - if hasattr(typing, "Annotated"): - # Present in Python 3.9+ - if typing.get_origin(typeinfo) is typing.Annotated: - return typing.get_args(typeinfo)[0] + if typing.get_origin(typeinfo) is typing.Annotated: + return typing.get_args(typeinfo)[0] return typeinfo From 735e4a9342e87717abd640be3738437bcfc6f33e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 13:45:07 -0700 Subject: [PATCH 19/24] _is_tensor_type --- onnxscript/type_annotation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 3f4bb4ee19..659530d9bf 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -128,6 +128,10 @@ def base_type_is_bool(pytype: TypeAnnotationValue) -> bool: def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool: if isinstance(typeinfo, onnx_types.TensorType): return True + if isinstance(typeinfo, typing.TypeVar): + # Special case the handle TypeVar for py310 because inspect.isclass(typeinfo) + # seems to return True for TypeVar + return False if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType): return True return False From 2173b6bebb8293bf78f31c222382b77fe3f24deb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 13:47:57 -0700 Subject: [PATCH 20/24] constraints --- onnxscript/type_annotation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 659530d9bf..afcace9f28 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -171,6 +171,10 @@ def is_value_type(typeinfo: TypeAnnotationValue) -> bool: if hasattr(typeinfo, "__bound__"): bound = typeinfo.__bound__ return is_value_type(bound) + if hasattr(typeinfo, "__constraints__"): + constraints = typeinfo.__constraints__ + if constraints: + return any(is_value_type(x) for x in constraints) raise ValueError(f"Unsupported type annotation {typeinfo}") From bdde8bcdc8d46e32e150ed1af723fc63a2424ecd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 13:51:54 -0700 Subject: [PATCH 21/24] Fix --- .../tools/torch_lib/generate_prims_signatures.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py index 76c975f4fb..f5ca893a9d 100644 --- a/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py +++ b/onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py @@ -27,7 +27,7 @@ def create_list_type(arg: torchgen.model.Argument) -> cg.TypeRef: inner_arg_type = arg.type if not arg.type.is_nullable() else arg.type.elem - assert isinstance(inner_arg_type, torchgen.model.listType), f"arg: {arg}" + assert isinstance(inner_arg_type, torchgen.model.ListType), f"arg: {arg}" arg_type = arg_type_to_str(arg.type) if type_is_builtin(arg_type): @@ -75,7 +75,7 @@ def get_argument_type(arg: torchgen.model.Argument) -> 
cg.TypeRef: """Returns the Python type for the given argument.""" inner_arg_type = arg.type if not arg.type.is_nullable() else arg.type.elem - if isinstance(inner_arg_type, torchgen.model.listType): + if isinstance(inner_arg_type, torchgen.model.ListType): inner_node = create_list_type(arg) else: arg_type_str = arg_type_to_str(inner_arg_type) @@ -130,7 +130,7 @@ def parse_default_value(arg: torchgen.model.Argument) -> Any: else: if isinstance(value, int): # Expand the value to a tuple if the type is a list. - if isinstance(arg.type, torchgen.model.listType): + if isinstance(arg.type, torchgen.model.ListType): if arg.type.size is not None: return (value,) * arg.type.size return (value,) From bb048397822fd7c5857403a6d97b942d36eefed4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 14:05:48 -0700 Subject: [PATCH 22/24] abc --- onnxscript/type_annotation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index afcace9f28..3ae7e609c8 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -128,9 +128,9 @@ def base_type_is_bool(pytype: TypeAnnotationValue) -> bool: def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool: if isinstance(typeinfo, onnx_types.TensorType): return True - if isinstance(typeinfo, typing.TypeVar): - # Special case the handle TypeVar for py310 because inspect.isclass(typeinfo) - # seems to return True for TypeVar + if type(typeinfo) is onnx_types.TensorType: + # Special case the handle when typeinfo is TensorType. + # It seems abc.ABC in py39 has issues with issubclass return False if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType): return True From 18dc8863b201fd8211331638dba2e10333eeb1ce Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 14:05:59 -0700 Subject: [PATCH 23/24] true --- onnxscript/type_annotation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 3ae7e609c8..66ec221f0b 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -131,7 +131,7 @@ def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool: if type(typeinfo) is onnx_types.TensorType: # Special case the handle when typeinfo is TensorType. # It seems abc.ABC in py39 has issues with issubclass - return False + return True if inspect.isclass(typeinfo) and issubclass(typeinfo, onnx_types.TensorType): return True return False From e383fd772bc47e4a41e91276c760485ad6371385 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 28 Apr 2025 14:06:15 -0700 Subject: [PATCH 24/24] typeinfo --- onnxscript/type_annotation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/type_annotation.py b/onnxscript/type_annotation.py index 66ec221f0b..667b6b46c2 100644 --- a/onnxscript/type_annotation.py +++ b/onnxscript/type_annotation.py @@ -128,7 +128,7 @@ def base_type_is_bool(pytype: TypeAnnotationValue) -> bool: def _is_tensor_type(typeinfo: TypeAnnotationValue) -> bool: if isinstance(typeinfo, onnx_types.TensorType): return True - if type(typeinfo) is onnx_types.TensorType: + if typeinfo is onnx_types.TensorType: # Special case the handle when typeinfo is TensorType. # It seems abc.ABC in py39 has issues with issubclass return True
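# [Editor's note] Illustrative aside, not part of the patch series: the last
# three commits converge on a layered membership test, sketched here with
# stand-in classes (TensorType/FLOAT below are assumptions, not the real
# onnx_types definitions). The identity check sidesteps the issubclass quirk
# the comment above mentions for the abc-based TensorType on Python 3.9.
import abc
import inspect

class TensorType(abc.ABC):  # stand-in for onnx_types.TensorType
    pass

class FLOAT(TensorType):  # stand-in for a concrete tensor type
    pass

def is_tensor_type(typeinfo) -> bool:
    if isinstance(typeinfo, TensorType):
        return True
    if typeinfo is TensorType:
        # The base class itself, accepted without consulting issubclass.
        return True
    if inspect.isclass(typeinfo) and issubclass(typeinfo, TensorType):
        return True
    return False

assert is_tensor_type(TensorType)
assert is_tensor_type(FLOAT)
assert not is_tensor_type(int)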