Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type inference for user code #25103

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,15 @@ module = [
]
ignore_missing_imports = true

# pyright type checker settings:
[tool.pyright]
include = ["sympy"]

# This complains when self or cls is not used as the parameter name for an
# instance method or class method. Since this is used a lot in SymPy we disable
# this check.
reportSelfClsParameterName = false

[tool.slotscheck]
strict-imports = true
exclude-modules = '''
Expand Down
41 changes: 38 additions & 3 deletions sympy/core/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import Mapping
from itertools import chain, zip_longest
from functools import cmp_to_key
from typing import TYPE_CHECKING

from .assumptions import _prepare_class_assumptions
from .cache import cacheit
Expand All @@ -21,6 +22,11 @@
from inspect import getmro


if TYPE_CHECKING:
from typing import ClassVar
from .assumptions import StdFactKB


def as_Basic(expr):
"""Return expr as a Basic instance using strict sympify
or raise a TypeError; this is just a wrapper to _sympify,
Expand Down Expand Up @@ -218,11 +224,40 @@ def __init_subclass__(cls):
is_Point = False
is_MatAdd = False
is_MatMul = False
is_real: bool | None
is_extended_real: bool | None
is_zero: bool | None

default_assumptions: ClassVar[StdFactKB]

is_composite: bool | None
is_noninteger: bool | None
is_extended_positive: bool | None
is_negative: bool | None
is_complex: bool | None
is_extended_nonpositive: bool | None
is_integer: bool | None
is_positive: bool | None
is_rational: bool | None
is_extended_nonnegative: bool | None
is_infinite: bool | None
is_antihermitian: bool | None
is_extended_negative: bool | None
is_extended_real: bool | None
is_finite: bool | None
is_polar: bool | None
is_imaginary: bool | None
is_transcendental: bool | None
is_extended_nonzero: bool | None
is_nonzero: bool | None
is_odd: bool | None
is_algebraic: bool | None
is_prime: bool | None
is_commutative: bool | None
is_nonnegative: bool | None
is_nonpositive: bool | None
is_hermitian: bool | None
is_irrational: bool | None
is_real: bool | None
is_zero: bool | None
is_even: bool | None

kind: Kind = UndefinedKind

Expand Down
26 changes: 19 additions & 7 deletions sympy/core/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,21 @@
dependencies, so that they can be easily imported anywhere in sympy/core.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

from functools import wraps
from .sympify import SympifyError, sympify


if TYPE_CHECKING:
from typing import Callable, TypeVar, Union
T1 = TypeVar('T1')
T2 = TypeVar('T2')
T3 = TypeVar('T3')


def _sympifyit(arg, retval=None):
"""
decorator to smartly _sympify function arguments
Expand Down Expand Up @@ -69,7 +80,8 @@ def __sympifyit_wrapper(a, b):
return __sympifyit_wrapper


def call_highest_priority(method_name):
def call_highest_priority(method_name: str
) -> Callable[[Callable[[T1, T2], T3]], Callable[[T1, T2], T3]]:
"""A decorator for binary special methods to handle _op_priority.

Explanation
Expand All @@ -95,12 +107,12 @@ def __mul__(self, other):
def __rmul__(self, other):
...
"""
def priority_decorator(func):
def priority_decorator(func: Callable[[T1, T2], T3]) -> Callable[[T1, T2], T3]:
@wraps(func)
def binary_op_wrapper(self, other):
def binary_op_wrapper(self: T1, other: T2) -> T3:
if hasattr(other, '_op_priority'):
if other._op_priority > self._op_priority:
f = getattr(other, method_name, None)
if other._op_priority > self._op_priority: # type: ignore
f: Union[Callable[[T1], T3], None] = getattr(other, method_name, None)
if f is not None:
return f(self)
return func(self, other)
Expand Down Expand Up @@ -187,8 +199,8 @@ def sympify_return(*args):
See the docstring of sympify_method_args for explanation.
'''
# Store a wrapper object for the decorated method
def wrapper(func):
return _SympifyWrapper(func, args)
def wrapper(func: Callable[[T1, T2], T3]) -> Callable[[T1, T2], T3]:
return _SympifyWrapper(func, args) # type: ignore
return wrapper


Expand Down
40 changes: 20 additions & 20 deletions sympy/core/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ def _add_handler(self):
def _mul_handler(self):
return Mul

def __pos__(self):
def __pos__(self) -> Expr:
return self

def __neg__(self):
def __neg__(self) -> Expr:
# Mul has its own __neg__ routine, so we just
# create a 2-args Mul with the -1 in the canonical
# slot 0.
Expand All @@ -195,32 +195,32 @@ def __abs__(self) -> Expr:

@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__radd__')
def __add__(self, other):
def __add__(self, other) -> Expr:
return Add(self, other)

@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__add__')
def __radd__(self, other):
def __radd__(self, other) -> Expr:
return Add(other, self)

@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__rsub__')
def __sub__(self, other):
def __sub__(self, other) -> Expr:
return Add(self, -other)

@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__sub__')
def __rsub__(self, other):
def __rsub__(self, other) -> Expr:
return Add(other, -self)

@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__rmul__')
def __mul__(self, other):
def __mul__(self, other) -> Expr:
return Mul(self, other)

@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__mul__')
def __rmul__(self, other):
def __rmul__(self, other) -> Expr:
return Mul(other, self)

@sympify_return([('other', 'Expr')], NotImplemented)
Expand All @@ -246,12 +246,12 @@ def __pow__(self, other, mod=None) -> Expr:

@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__pow__')
def __rpow__(self, other):
def __rpow__(self, other) -> Expr:
return Pow(other, self)

@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__rtruediv__')
def __truediv__(self, other):
def __truediv__(self, other) -> Expr:
denom = Pow(other, S.NegativeOne)
if self is S.One:
return denom
Expand All @@ -260,7 +260,7 @@ def __truediv__(self, other):

@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__truediv__')
def __rtruediv__(self, other):
def __rtruediv__(self, other) -> Expr:
denom = Pow(self, S.NegativeOne)
if other is S.One:
return denom
Expand All @@ -269,40 +269,40 @@ def __rtruediv__(self, other):

@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__rmod__')
def __mod__(self, other):
def __mod__(self, other) -> Expr:
return Mod(self, other)

@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__mod__')
def __rmod__(self, other):
def __rmod__(self, other) -> Expr:
return Mod(other, self)

@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__rfloordiv__')
def __floordiv__(self, other):
def __floordiv__(self, other) -> Expr:
from sympy.functions.elementary.integers import floor
return floor(self / other)

@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__floordiv__')
def __rfloordiv__(self, other):
def __rfloordiv__(self, other) -> Expr:
from sympy.functions.elementary.integers import floor
return floor(other / self)


@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__rdivmod__')
def __divmod__(self, other):
def __divmod__(self, other) -> tuple[Expr, Expr]:
from sympy.functions.elementary.integers import floor
return floor(self / other), Mod(self, other)

@sympify_return([('other', 'Expr')], NotImplemented)
@call_highest_priority('__divmod__')
def __rdivmod__(self, other):
def __rdivmod__(self, other) -> tuple[Expr, Expr]:
from sympy.functions.elementary.integers import floor
return floor(other / self), Mod(other, self)

def __int__(self):
def __int__(self) -> int:
if not self.is_number:
raise TypeError("Cannot convert symbols to int")
r = self.round(2)
Expand All @@ -328,7 +328,7 @@ def __int__(self):
return i - (1 if i > 0 else -1)
return i

def __float__(self):
def __float__(self) -> float:
# Don't bother testing if it's a number; if it's not this is going
# to fail, and if it is we still need to check that it evalf'ed to
# a number.
Expand All @@ -339,7 +339,7 @@ def __float__(self):
raise TypeError("Cannot convert complex to float")
raise TypeError("Cannot convert expression to float")

def __complex__(self):
def __complex__(self) -> complex:
result = self.evalf()
re, im = result.as_real_imag()
return complex(float(re), float(im))
Expand Down
32 changes: 24 additions & 8 deletions sympy/core/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,15 @@ def _diff_wrt(self):
return False

@cacheit
def __new__(cls, *args, **options):
def __new__(cls, *args, **options) -> type[AppliedUndef]: # type: ignore
# Handle calls like Function('f')
if cls is Function:
return UndefinedFunction(*args, **options)
return UndefinedFunction(*args, **options) # type: ignore
else:
return cls._new_(*args, **options) # type: ignore

@classmethod
def _new_(cls, *args, **options) -> Expr:
n = len(args)

if not cls._valid_nargs(n):
Expand Down Expand Up @@ -809,6 +813,14 @@ def _eval_as_leading_term(self, x, logx=None, cdir=0):
return self


class DefinedFunction(Function):
"""Base class for defined functions like ``sin``, ``cos``, ..."""

@cacheit
def __new__(cls, *args, **options) -> Expr: # type: ignore
return cls._new_(*args, **options)


class AppliedUndef(Function):
"""
Base class for expressions resulting from the application of an undefined
Expand All @@ -817,13 +829,13 @@ class AppliedUndef(Function):

is_number = False

def __new__(cls, *args, **options):
args = list(map(sympify, args))
def __new__(cls, *args, **options) -> Expr: # type: ignore
args = tuple(map(sympify, args))
u = [a.name for a in args if isinstance(a, UndefinedFunction)]
if u:
raise TypeError('Invalid argument: expecting an expression, not UndefinedFunction%s: %s' % (
's'*(len(u) > 1), ', '.join(u)))
obj = super().__new__(cls, *args, **options)
obj: Expr = super().__new__(cls, *args, **options) # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe _new_ should be used directly from here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type: ignore here is because mypy rejects having a cls.__new__ method whose return hint does not imply that it returns an instance of cls:

from __future__ import annotations

class A:
    def __new__(cls) -> A:
        return super().__new__(cls)

class B(A):
    def __new__(cls) -> A:
        return super().__new__(cls)

reveal_type(B())

Here mypy rejects this although pyright accepts it:

$ mypy q.py 
q.py:8: error: Incompatible return type for "__new__" (returns "A", but must return a subtype of "B")  [misc]
q.py:11: note: Revealed type is "q.B"
Found 1 error in 1 file (checked 1 source file)
$ pyright q.py 
  q.py:11:13 - information: Type of "B()" is "A"
0 errors, 0 warnings, 1 information 

I could change this to say that the return type is AppliedUndef and then mypy would not complain. I am not sure if this should really be considered a bug in mypy though.

The return type Expr was needed for Function.__new__ because you can have e.g. sin(pi) -> 0 so we have no guarantee that sin(obj) is of type sin but it should always be of type Expr. With AppliedUndef there is no evaluation though so it would be accurate to say that it returns AppliedUndef. I'm just not sure if that is more useful to users than Expr (which is not incorrect since AppliedUndef is a subclass of Expr).

return obj

def _eval_as_leading_term(self, x, logx=None, cdir=0):
Expand Down Expand Up @@ -862,11 +874,15 @@ def __get__(self, ins, typ):

_undef_sage_helper = UndefSageHelper()


class UndefinedFunction(FunctionClass):
"""
The (meta)class of undefined functions.
"""
def __new__(mcl, name, bases=(AppliedUndef,), __dict__=None, **kwargs):
name: str
_sage_: UndefSageHelper

def __new__(mcl, name, bases=(AppliedUndef,), __dict__=None, **kwargs) -> type[AppliedUndef]:
from .symbol import _filter_assumptions
# Allow Function('f', real=True)
# and/or Function(Symbol('f', real=True))
Expand Down Expand Up @@ -894,10 +910,10 @@ def __new__(mcl, name, bases=(AppliedUndef,), __dict__=None, **kwargs):
__dict__.update({'_kwargs': kwargs})
# do this for pickling
__dict__['__module__'] = None
obj = super().__new__(mcl, name, bases, __dict__)
obj = super().__new__(mcl, name, bases, __dict__) # type: ignore
obj.name = name
obj._sage_ = _undef_sage_helper
return obj
return obj # type: ignore

def __instancecheck__(cls, instance):
return cls in type(instance).__mro__
Expand Down
4 changes: 2 additions & 2 deletions sympy/core/mod.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from .add import Add
from .exprtools import gcd_terms
from .function import Function
from .function import DefinedFunction
from .kind import NumberKind
from .logic import fuzzy_and, fuzzy_not
from .mul import Mul
from .numbers import equal_valued
from .singleton import S


class Mod(Function):
class Mod(DefinedFunction):
"""Represents a modulo operation on symbolic expressions.

Parameters
Expand Down
5 changes: 5 additions & 0 deletions sympy/core/tests/test_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,11 @@ def test_sympy__core__function__AppliedUndef():
assert _test_args(AppliedUndef(1, 2, 3))


def test_sympy__core__function__DefinedFunction():
from sympy.core.function import DefinedFunction
assert _test_args(DefinedFunction(1, 2, 3))


def test_sympy__core__function__Derivative():
from sympy.core.function import Derivative
assert _test_args(Derivative(2, x, y, 3))
Expand Down
Loading
Loading