Skip to content

Commit

Permalink
Merge branch 'update-pre-commit' into py312
Browse files Browse the repository at this point in the history
  • Loading branch information
maresb committed Jan 9, 2024
2 parents e9efa5d + a66b277 commit 2cf7462
Show file tree
Hide file tree
Showing 33 changed files with 103 additions and 69 deletions.
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ exclude: |
)$
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: debug-statements
exclude: |
Expand All @@ -20,23 +20,23 @@ repos:
)$
- id: check-merge-conflict
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py39-plus]
- repo: https://github.com/psf/black
rev: 23.1.0
rev: 23.12.1
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
rev: 7.0.0
hooks:
- id: flake8
additional_dependencies:
- flake8-comprehensions
- repo: https://github.com/pycqa/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
- repo: https://github.com/humitos/mirrors-autoflake.git
Expand All @@ -54,7 +54,7 @@ repos:
)$
args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.0.0
rev: v1.8.0
hooks:
- id: mypy
language: python
Expand Down
4 changes: 2 additions & 2 deletions pytensor/compile/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ def _lessbroken_deepcopy(a):
else:
rval = copy.deepcopy(a)

assert type(rval) == type(a), (type(rval), type(a))
assert type(rval) is type(a), (type(rval), type(a))

if isinstance(rval, np.ndarray):
assert rval.dtype == a.dtype
Expand Down Expand Up @@ -1156,7 +1156,7 @@ def __str__(self):
return str(self.__dict__)

def __eq__(self, other):
rval = type(self) == type(other)
rval = type(self) is type(other)
if rval:
# nodes are not compared because this comparison is
# supposed to be true for corresponding events that happen
Expand Down
2 changes: 1 addition & 1 deletion pytensor/compile/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def __init__(self, fn, itypes, otypes, infer_shape):
self.infer_shape = self._infer_shape

def __eq__(self, other):
return type(self) == type(other) and self.__fn == other.__fn
return type(self) is type(other) and self.__fn == other.__fn

def __hash__(self):
return hash(type(self)) ^ hash(self.__fn)
Expand Down
4 changes: 2 additions & 2 deletions pytensor/compile/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,8 +1084,8 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):
viewof_change = []
# Use to track view_of changes

viewedby_add = defaultdict(lambda: [])
viewedby_remove = defaultdict(lambda: [])
viewedby_add = defaultdict(list)
viewedby_remove = defaultdict(list)
# Use to track viewed_by changes

for var in node.outputs:
Expand Down
56 changes: 44 additions & 12 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
TypeVar,
Union,
cast,
overload,
)

import numpy as np
Expand Down Expand Up @@ -718,7 +719,7 @@ def __eq__(self, other):
return True

return (
type(self) == type(other)
type(self) is type(other)
and self.id == other.id
and self.type == other.type
)
Expand Down Expand Up @@ -1301,9 +1302,31 @@ def clone_get_equiv(
return memo


@overload
def general_toposort(
outputs: Iterable[T],
deps: None,
compute_deps_cache: Callable[[T], Optional[Union[OrderedSet, list[T]]]],
deps_cache: Optional[dict[T, list[T]]],
clients: Optional[dict[T, list[T]]],
) -> list[T]:
...


@overload
def general_toposort(
outputs: Iterable[T],
deps: Callable[[T], Union[OrderedSet, list[T]]],
compute_deps_cache: None,
deps_cache: None,
clients: Optional[dict[T, list[T]]],
) -> list[T]:
...


def general_toposort(
outputs: Iterable[T],
deps: Optional[Callable[[T], Union[OrderedSet, list[T]]]],
compute_deps_cache: Optional[
Callable[[T], Optional[Union[OrderedSet, list[T]]]]
] = None,
Expand Down Expand Up @@ -1345,7 +1368,7 @@ def general_toposort(
if deps_cache is None:
deps_cache = {}

def _compute_deps_cache(io):
def _compute_deps_cache_(io):
if io not in deps_cache:
d = deps(io)

Expand All @@ -1363,6 +1386,8 @@ def _compute_deps_cache(io):
else:
return deps_cache[io]

_compute_deps_cache = _compute_deps_cache_

else:
_compute_deps_cache = compute_deps_cache

Expand Down Expand Up @@ -1451,15 +1476,14 @@ def io_toposort(
)
return order

compute_deps = None
compute_deps_cache = None
iset = set(inputs)
deps_cache: dict = {}

if not orderings: # ordering can be None or empty dict
# Specialized function that is faster when no ordering.
# Also include the cache in the function itself for speed up.

deps_cache: dict = {}

def compute_deps_cache(obj):
if obj in deps_cache:
return deps_cache[obj]
Expand All @@ -1478,6 +1502,14 @@ def compute_deps_cache(obj):
deps_cache[obj] = rval
return rval

topo = general_toposort(
outputs,
deps=None,
compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache,
clients=clients,
)

else:
# the inputs are used only here in the function that decides what
# 'predecessors' to explore
Expand All @@ -1494,13 +1526,13 @@ def compute_deps(obj):
assert not orderings.get(obj, None)
return rval

topo = general_toposort(
outputs,
deps=compute_deps,
compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache,
clients=clients,
)
topo = general_toposort(
outputs,
deps=compute_deps,
compute_deps_cache=None,
deps_cache=None,
clients=clients,
)
return [o for o in topo if isinstance(o, Apply)]


Expand Down
2 changes: 1 addition & 1 deletion pytensor/graph/null_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def values_eq(self, a, b, force_same_dtype=True):
raise ValueError("NullType has no values to compare")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down
10 changes: 6 additions & 4 deletions pytensor/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,8 +951,8 @@ class MetaNodeRewriter(NodeRewriter):

def __init__(self):
self.verbose = config.metaopt__verbose
self.track_dict = defaultdict(lambda: [])
self.tag_dict = defaultdict(lambda: [])
self.track_dict = defaultdict(list)
self.tag_dict = defaultdict(list)
self._tracks = []
self.rewriters = []

Expand Down Expand Up @@ -2406,13 +2406,15 @@ def importer(node):
if node is not current_node:
q.append(node)

chin = None
chin: Optional[Callable] = None
if self.tracks_on_change_inputs:

def chin(node, i, r, new_r, reason):
def chin_(node, i, r, new_r, reason):
if node is not current_node and not isinstance(node, str):
q.append(node)

chin = chin_

u = self.attach_updater(
fgraph, importer, None, chin=chin, name=getattr(self, "name", None)
)
Expand Down
4 changes: 2 additions & 2 deletions pytensor/graph/rewriting/unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def __new__(cls, constraint, token=None, prefix=""):
return obj

def __eq__(self, other):
if type(self) == type(other):
return self.token == other.token and self.constraint == other.constraint
if type(self) is type(other):
return self.token is other.token and self.constraint == other.constraint
return NotImplemented

def __hash__(self):
Expand Down
2 changes: 1 addition & 1 deletion pytensor/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def __hash__(self):
if "__eq__" not in dct:

def __eq__(self, other):
return type(self) == type(other) and tuple(
return type(self) is type(other) and tuple(
getattr(self, a) for a in props
) == tuple(getattr(other, a) for a in props)

Expand Down
2 changes: 1 addition & 1 deletion pytensor/ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, n_outs, as_view=False, name=None):
self.name = name

def __eq__(self, other):
if type(self) != type(other):
if type(self) is not type(other):
return False
if self.as_view != other.as_view:
return False
Expand Down
4 changes: 2 additions & 2 deletions pytensor/link/c/params_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def __hash__(self):

def __eq__(self, other):
return (
type(self) == type(other)
type(self) is type(other)
and self.__params_type__ == other.__params_type__
and all(
# NB: Params object should have been already filtered.
Expand Down Expand Up @@ -432,7 +432,7 @@ def __repr__(self):

def __eq__(self, other):
return (
type(self) == type(other)
type(self) is type(other)
and self.fields == other.fields
and self.types == other.types
)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/link/c/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def __hash__(self):

def __eq__(self, other):
return (
type(self) == type(other)
type(self) is type(other)
and self.ctype == other.ctype
and len(self) == len(other)
and len(self.aliases) == len(other.aliases)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/link/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ def get_destroy_dependencies(fgraph: FunctionGraph) -> dict[Apply, list[Variable
in destroy_dependencies.
"""
order = fgraph.orderings()
destroy_dependencies = defaultdict(lambda: [])
destroy_dependencies = defaultdict(list)
for node in fgraph.apply_nodes:
for prereq in order.get(node, []):
destroy_dependencies[node].extend(prereq.outputs)
Expand Down
4 changes: 2 additions & 2 deletions pytensor/raise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class ExceptionType(Generic):
def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -51,7 +51,7 @@ def __str__(self):
return f"CheckAndRaise{{{self.exc_type}({self.msg})}}"

def __eq__(self, other):
if type(self) != type(other):
if type(self) is not type(other):
return False

if self.msg == other.msg and self.exc_type == other.exc_type:
Expand Down
6 changes: 3 additions & 3 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ def __call__(self, *types):
return [rval]

def __eq__(self, other):
return type(self) == type(other) and self.tbl == other.tbl
return type(self) is type(other) and self.tbl == other.tbl

def __hash__(self):
return hash(type(self)) # ignore hash of table
Expand Down Expand Up @@ -1160,7 +1160,7 @@ def L_op(self, inputs, outputs, output_gradients):
return self.grad(inputs, output_gradients)

def __eq__(self, other):
test = type(self) == type(other) and getattr(
test = type(self) is type(other) and getattr(
self, "output_types_preference", None
) == getattr(other, "output_types_preference", None)
return test
Expand Down Expand Up @@ -4132,7 +4132,7 @@ def __eq__(self, other):
if self is other:
return True
if (
type(self) != type(other)
type(self) is not type(other)
or self.nin != other.nin
or self.nout != other.nout
):
Expand Down
10 changes: 5 additions & 5 deletions pytensor/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -675,7 +675,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -724,7 +724,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -1033,7 +1033,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -1074,7 +1074,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down
Loading

0 comments on commit 2cf7462

Please sign in to comment.