Skip to content
Closed
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions onnxscript/_internal/analysis.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,8 @@
from __future__ import annotations

import ast
from typing import Any, Optional, Sequence, Set
from collections.abc import Sequence
from typing import Any, Optional

from onnxscript import sourceinfo
from onnxscript._internal import ast_utils
@@ -15,7 +16,7 @@ def _get_loop_var(for_stmt: ast.For, formatter: sourceinfo.Formatter) -> str:
return for_stmt.target.id


def _used_vars(expr: Optional[ast.expr]) -> Set[str]:
def _used_vars(expr: Optional[ast.expr]) -> set[str]:
"""Return set of all variables used, including function names, in an expression."""
if expr is None:
return set()
@@ -35,7 +36,7 @@ def _used_vars(expr: Optional[ast.expr]) -> Set[str]:
return result


def _lhs_vars(lhs: ast.expr) -> Set[str]:
def _lhs_vars(lhs: ast.expr) -> set[str]:
"""Return set of assigned variables in the lhs of an assignment statement."""

def get_id(e):
@@ -49,12 +50,12 @@ def get_id(e):

def assigned_vars(
stmt: ast.stmt | list[ast.stmt], formatter: sourceinfo.Formatter
) -> Set[str]:
) -> set[str]:
"""Return the set of all variables that may be assigned to in an execution of input stmt
or sequence of statements.
"""

def assigned_in_block(block: Sequence[ast.stmt]) -> Set[str]:
def assigned_in_block(block: Sequence[ast.stmt]) -> set[str]:
result: set[Any] = set()
for s in block:
result = result | assigned_vars(s, formatter)
@@ -90,14 +91,14 @@ def do_liveness_analysis(fun: ast.FunctionDef, formatter: sourceinfo.Formatter):
and `s.live_out`.
"""

def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
def visit(stmt: ast.stmt, live_out: set[str]) -> set[str]:
stmt.live_out = live_out # type: ignore[attr-defined]
live = do_visit(stmt, live_out)
stmt.live_in = live # type: ignore[attr-defined]
return live

def do_visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
def do_visit(stmt: ast.stmt, live_out: set[str]) -> set[str]:
def visitBlock(block: Sequence[ast.stmt], live_out: set[str]) -> set[str]:
for s in reversed(block):
live_out = visit(s, live_out)
return live_out
@@ -165,12 +166,12 @@ def exposed_uses(stmts: Sequence[ast.stmt], formatter: sourceinfo.Formatter):
(in the first statement). Hence x is included in the exposed_uses.
"""

def visitBlock(block: Sequence[ast.stmt], live_out: Set[str]) -> Set[str]:
def visitBlock(block: Sequence[ast.stmt], live_out: set[str]) -> set[str]:
for stmt in reversed(block):
live_out = visit(stmt, live_out)
return live_out

def visit(stmt: ast.stmt, live_out: Set[str]) -> Set[str]:
def visit(stmt: ast.stmt, live_out: set[str]) -> set[str]:
if isinstance(stmt, ast.Assign):
return live_out.difference(_lhs_vars(stmt.targets[0])) | _used_vars(stmt.value)
if isinstance(stmt, ast.AnnAssign):
3 changes: 2 additions & 1 deletion onnxscript/_internal/autocast.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable, Optional

import numpy as np
import onnx
4 changes: 3 additions & 1 deletion onnxscript/_internal/param_manipulation.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,9 @@
from __future__ import annotations

import collections
from typing import Any, OrderedDict, Sequence
from collections import OrderedDict
from collections.abc import Sequence
from typing import Any

from onnxscript import values

3 changes: 2 additions & 1 deletion onnxscript/_internal/utils.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,8 @@
from __future__ import annotations

import numbers
from typing import Optional, Sequence
from collections.abc import Sequence
from typing import Optional

import numpy as np
import onnx
3 changes: 2 additions & 1 deletion onnxscript/_internal/version_utils.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,8 @@
from __future__ import annotations

import warnings
from typing import Callable, Sequence
from collections.abc import Sequence
from typing import Callable

import packaging.version

6 changes: 3 additions & 3 deletions onnxscript/_legacy_ir/__init__.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@

import dataclasses
from collections import deque
from typing import List, Tuple, Union
from typing import Union

import numpy as np
import onnx
@@ -47,9 +47,9 @@ def __init__(self) -> None:
# TODO: Technically, SymbolicValue should be a recursive type to handle lists of lists of
# tensors, etc. However, we currently only handle lists of tensors.

SymbolicValue = Union[str, List[str]]
SymbolicValue = Union[str, list[str]]

FunctionId = Tuple[str, str, str]
FunctionId = tuple[str, str, str]


def get_function_id(function: onnx.FunctionProto) -> FunctionId:
3 changes: 2 additions & 1 deletion onnxscript/_legacy_ir/visitor.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,8 @@

import dataclasses
import logging
from typing import Any, Sequence
from collections.abc import Sequence
from typing import Any

import numpy as np
import onnx
2 changes: 1 addition & 1 deletion onnxscript/_thirdparty/asciichartpy.py
Original file line number Diff line number Diff line change
@@ -32,8 +32,8 @@

from __future__ import annotations

from collections.abc import Mapping
from math import ceil, floor, isnan
from typing import Mapping

black = "\033[30m"
red = "\033[31m"
2 changes: 1 addition & 1 deletion onnxscript/backend/onnx_backend.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@

import os
import textwrap
from typing import Iterator
from collections.abc import Iterator

import numpy as np
import onnx
3 changes: 2 additions & 1 deletion onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,8 @@
# Licensed under the MIT License.
from __future__ import annotations

from typing import Any, Optional, Sequence
from collections.abc import Sequence
from typing import Any, Optional

import numpy
import onnx
2 changes: 1 addition & 1 deletion onnxscript/backend/onnx_export_test.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
import re
import sys
import unittest
from typing import Pattern
from re import Pattern

import onnx
import onnxruntime as ort
45 changes: 15 additions & 30 deletions onnxscript/converter.py
Original file line number Diff line number Diff line change
@@ -4,15 +4,12 @@

import ast
import logging
from collections.abc import Sequence
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
NoReturn,
Optional,
Sequence,
Tuple,
Union,
)

@@ -178,11 +175,11 @@ def __init__(
self.default_opset_ = default_opset

# States initialized by `_init_function_translation`
self._outer: List[irbuilder.IRFunction] = []
self._outer: list[irbuilder.IRFunction] = []
self._current_fn: irbuilder.IRFunction = None
self._nextvar: int = 0
self._used_vars: set[str] = set()
self._locals: List[Dict[str, LocalSymValue]] = [{}]
self._locals: list[dict[str, LocalSymValue]] = [{}]

@property
def default_opset(self) -> values.Opset:
@@ -230,7 +227,7 @@ def _init_function_translation(self) -> None:
self._current_fn: Optional[irbuilder.IRFunction] = None
self._nextvar = 0
self._used_vars = set()
self._locals: List[Dict[str, LocalSymValue]] = [{}]
self._locals: list[dict[str, LocalSymValue]] = [{}]

def _source_of(self, node: ast.AST) -> sourceinfo.SourceInfo:
return sourceinfo.SourceInfo(node, self.source, self._current_fn.name)
@@ -269,7 +266,7 @@ def _exit_scope(self) -> irbuilder.IRFunction:
self._locals.pop(0)
return graph

def _current_scope(self) -> Dict[str, LocalSymValue]:
def _current_scope(self) -> dict[str, LocalSymValue]:
return self._locals[0]

def _bind(self, name: str, val: LocalSymValue) -> None:
@@ -433,7 +430,7 @@ def _is_constant_expr(self, node: ast.AST) -> None:
ast.UnaryOp,
ast.Compare,
ast.Attribute,
ast.List,
ast.list,
ast.Load,
ast.Constant,
),
@@ -528,12 +525,7 @@ def _translate_attr(
return attr

def _translate_docstring(self, node: ast.Expr) -> None:
if hasattr(node.value, "value"):
# python 3.8+
return self.ir_builder.add_docstring(self._current_fn, node.value.value)
raise TypeError(
f"Unexpected type {type(node)!r} for node. Unsupoorted version of python."
)
return self.ir_builder.add_docstring(self._current_fn, node.value.value)

def _translate_expr(
self, node: ast.AST, target: Optional[PreferredName] = None
@@ -697,9 +689,9 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:

# As the first step, we partition the index elements into four kinds: Slice (eg., 1:5:2),
# known-to-be-scalar (eg., 2), other-tensor (eg., I), skip/no-op (that is, just ":")
sliced_indices: List[Tuple[int, ast.expr]] = []
scalar_indices: List[Tuple[int, ast.expr]] = []
non_scalar_indices: List[Tuple[int, ast.expr]] = []
sliced_indices: list[tuple[int, ast.expr]] = []
scalar_indices: list[tuple[int, ast.expr]] = []
non_scalar_indices: list[tuple[int, ast.expr]] = []
for axis, elt in enumerate(indices):
if isinstance(elt, ast.Slice):
# Add to sliced_indices, unless it is "::", which is a no-op.
@@ -870,14 +862,7 @@ def _translate_unary_op_expr(self, node):
# should intercept this call and replace node
# by node.operand.
# This mechanism does not handle somthing like `(-(-5))`.
if hasattr(node.operand, "value"):
# python 3.8+
val = node.operand.value
else:
raise TypeError(
f"Unable to guess constant value from type {type(node.operand)!r} "
f"and attributes {dir(node.operand)!r}."
)
val = node.operand.value
if op == ast.USub:
cst = ast.Constant(-val, lineno=node.lineno, col_offset=node.col_offset)
return self._translate_expr(cst)
@@ -996,7 +981,7 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None:
typeinfo = None
var = values.Dynamic(t, values.DynamicKind.Intermediate, info, typeinfo)
self._bind(lhs, var)
elif isinstance(lhs, ast.Tuple):
elif isinstance(lhs, ast.tuple):
# Assignments of the form "x, y, z = op.SomeOp(...)"
if not isinstance(rhs, ast.Call):
self.fail(
@@ -1031,9 +1016,9 @@ def generate_onnx_name(x: ast.AST):
self.fail(stmt, "Multi-assignment not supported.")
lhs = targets[0]
rhs = stmt.value
if isinstance(rhs, ast.Tuple):
if isinstance(rhs, ast.tuple):
# Assignments of the form "... = Expression1, Expression2"
if not isinstance(lhs, ast.Tuple):
if not isinstance(lhs, ast.tuple):
# Assignments of the form "single_var = Expression1, Expression2".
# We do not support tuple-typed variables.
self.fail(lhs, f"Left term must be a tuple not '{type(lhs)!r}'.")
@@ -1082,7 +1067,7 @@ def ret(exp, i, suffix):

val = stmt.value
assert val is not None, "Return statement without return-value not supported."
if isinstance(val, ast.Tuple):
if isinstance(val, ast.tuple):
check_num_outputs(len(val.elts))
return [ret(exp, i, str(i)) for i, exp in enumerate(val.elts)]
check_num_outputs(1)
3 changes: 1 addition & 2 deletions onnxscript/evaluator.py
Original file line number Diff line number Diff line change
@@ -5,13 +5,12 @@
import abc
import contextlib
import pprint
from collections.abc import Mapping, Sequence
from typing import (
Any,
Callable,
Mapping,
Optional,
Protocol,
Sequence,
TypeVar,
Union,
runtime_checkable,
21 changes: 11 additions & 10 deletions onnxscript/function_libs/tools/torch_lib/deduce_type_constraints.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,8 @@
import copy
import dataclasses
import logging
from typing import Dict, Mapping, Optional, Sequence, Set
from collections.abc import Mapping, Sequence
from typing import Optional

import onnx
import onnx.defs
@@ -47,10 +48,10 @@ class TypeConstraint:
"""Type constraint shared by multiple values."""

name: str
type_strs: Set[str]
values: Set[Value]
type_strs: set[str]
values: set[Value]

def __init__(self, name: str, type_strs: Set[str]):
def __init__(self, name: str, type_strs: set[str]):
self.name = name
self.type_strs = type_strs
self.values = set()
@@ -125,9 +126,9 @@ def __repr__(self) -> str:

@dataclasses.dataclass
class OnnxFunctionTypeConstraints:
input_type_constraints: Dict[str, Optional[TypeConstraint]]
output_type_constraints: Dict[str, Optional[TypeConstraint]]
intermediate_type_constraints: Dict[str, Optional[TypeConstraint]]
input_type_constraints: dict[str, Optional[TypeConstraint]]
output_type_constraints: dict[str, Optional[TypeConstraint]]
intermediate_type_constraints: dict[str, Optional[TypeConstraint]]

def __repr__(self):
repr_strs = [
@@ -190,7 +191,7 @@ def __repr__(self):
class TypeConstraintDeducer:
def __init__(self, onnx_function: onnxscript.OnnxFunction):
self.onnx_function = onnx_function
self.values: Dict[str, Value] = {}
self.values: dict[str, Value] = {}

def type_constraints(self, signature_only: bool = True) -> OnnxFunctionTypeConstraints:
"""Retrieve deduced type constraints for the ONNX function."""
@@ -210,7 +211,7 @@ def type_constraints(self, signature_only: bool = True) -> OnnxFunctionTypeConst
)

# Rename type constraints to T0, T1, T2, ...
seen_type_constraints: Set[TypeConstraint] = set()
seen_type_constraints: set[TypeConstraint] = set()
for type_constraint in (
*input_type_constraints.values(),
*output_type_constraints.values(),
@@ -250,7 +251,7 @@ def _bind_signature(
node: onnx.NodeProto,
param_names: Sequence[str],
param_schemas: Sequence[onnx.defs.OpSchema.FormalParameter],
op_type_constraints: Dict[str, TypeConstraint],
op_type_constraints: dict[str, TypeConstraint],
is_output: bool = False,
):
param_schemas = list(param_schemas)
Loading
Oops, something went wrong.
Loading
Oops, something went wrong.