Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Optimize classmethod calls via cls #14789

Merged
merged 6 commits into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,7 @@ def deserialize(cls, data: JsonDict) -> Decorator:

VAR_FLAGS: Final = [
"is_self",
"is_cls",
"is_initialized_in_class",
"is_staticmethod",
"is_classmethod",
Expand Down Expand Up @@ -935,6 +936,7 @@ class Var(SymbolNode):
"type",
"final_value",
"is_self",
"is_cls",
"is_ready",
"is_inferred",
"is_initialized_in_class",
Expand Down Expand Up @@ -967,6 +969,8 @@ def __init__(self, name: str, type: mypy.types.Type | None = None) -> None:
self.type: mypy.types.Type | None = type # Declared or inferred type, or None
# Is this the first argument to an ordinary method (usually "self")?
self.is_self = False
# Is this the first argument to a classmethod (typically "cls")?
self.is_cls = False
self.is_ready = True # If inferred, is the inferred type available?
self.is_inferred = self.type is None
# Is this initialized explicitly to a non-None value in class body?
Expand Down
7 changes: 5 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,8 +1369,11 @@ def analyze_function_body(self, defn: FuncItem) -> None:
# The first argument of a non-static, non-class method is like 'self'
# (though the name could be different), having the enclosing class's
# instance type.
if is_method and not defn.is_static and not defn.is_class and defn.arguments:
defn.arguments[0].variable.is_self = True
if is_method and not defn.is_static and defn.arguments:
if not defn.is_class:
defn.arguments[0].variable.is_self = True
else:
defn.arguments[0].variable.is_cls = True

defn.body.accept(self)
self.function_stack.pop()
Expand Down
7 changes: 6 additions & 1 deletion mypyc/ir/class_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def __init__(
self.base_mro: list[ClassIR] = [self]

# Direct subclasses of this class (use subclasses() to also include non-direct ones)
# None if separate compilation prevents this from working
# None if separate compilation prevents this from working.
#
# Often it's better to use has_no_subclasses() or subclasses() instead.
self.children: list[ClassIR] | None = []

# Instance attributes that are initialized in the class body.
Expand Down Expand Up @@ -301,6 +303,9 @@ def get_method(self, name: str, *, prefer_method: bool = False) -> FuncIR | None
def has_method_decl(self, name: str) -> bool:
return any(name in ir.method_decls for ir in self.mro)

def has_no_subclasses(self) -> bool:
return self.children == [] and not self.allow_interpreted_subclasses

def subclasses(self) -> set[ClassIR] | None:
"""Return all subclasses of this class, both direct and indirect.

Expand Down
8 changes: 7 additions & 1 deletion mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,11 @@ def load_final_literal_value(self, val: int | str | bytes | float | bool, line:
else:
assert False, "Unsupported final literal value"

def get_assignment_target(self, lvalue: Lvalue, line: int = -1) -> AssignmentTarget:
def get_assignment_target(
self, lvalue: Lvalue, line: int = -1, *, for_read: bool = False
) -> AssignmentTarget:
if line == -1:
line = lvalue.line
if isinstance(lvalue, NameExpr):
# If we are visiting a decorator, then the SymbolNode we really want to be looking at
# is the function that is decorated, not the entire Decorator node itself.
Expand All @@ -578,6 +582,8 @@ def get_assignment_target(self, lvalue: Lvalue, line: int = -1) -> AssignmentTar
# New semantic analyzer doesn't create ad-hoc Vars for special forms.
assert lvalue.is_special_form
symbol = Var(lvalue.name)
if not for_read and isinstance(symbol, Var) and symbol.is_cls:
self.error("Cannot assign to the first argument of classmethod", line)
if lvalue.kind == LDEF:
if symbol not in self.symtables[-1]:
# If the function is a generator function, then first define a new variable
Expand Down
65 changes: 40 additions & 25 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
from mypy.types import Instance, ProperType, TupleType, TypeType, get_proper_type
from mypyc.common import MAX_SHORT_INT
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.func_ir import FUNC_CLASSMETHOD, FUNC_STATICMETHOD
from mypyc.ir.ops import (
Assign,
Expand Down Expand Up @@ -174,7 +175,7 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value:
)
return obj
else:
return builder.read(builder.get_assignment_target(expr), expr.line)
return builder.read(builder.get_assignment_target(expr, for_read=True), expr.line)

return builder.load_global(expr)

Expand Down Expand Up @@ -336,30 +337,7 @@ def translate_method_call(builder: IRBuilder, expr: CallExpr, callee: MemberExpr
# Call a method via the *class*
assert isinstance(callee.expr.node, TypeInfo)
ir = builder.mapper.type_to_ir[callee.expr.node]
decl = ir.method_decl(callee.name)
args = []
arg_kinds, arg_names = expr.arg_kinds[:], expr.arg_names[:]
# Add the class argument for class methods in extension classes
if decl.kind == FUNC_CLASSMETHOD and ir.is_ext_class:
args.append(builder.load_native_type_object(callee.expr.node.fullname))
arg_kinds.insert(0, ARG_POS)
arg_names.insert(0, None)
args += [builder.accept(arg) for arg in expr.args]

if ir.is_ext_class:
return builder.builder.call(decl, args, arg_kinds, arg_names, expr.line)
else:
obj = builder.accept(callee.expr)
return builder.gen_method_call(
obj,
callee.name,
args,
builder.node_type(expr),
expr.line,
expr.arg_kinds,
expr.arg_names,
)

return call_classmethod(builder, ir, expr, callee)
elif builder.is_module_member_expr(callee):
# Fall back to a PyCall for non-native module calls
function = builder.accept(callee)
Expand All @@ -368,6 +346,17 @@ def translate_method_call(builder: IRBuilder, expr: CallExpr, callee: MemberExpr
function, args, expr.line, arg_kinds=expr.arg_kinds, arg_names=expr.arg_names
)
else:
if isinstance(callee.expr, RefExpr):
node = callee.expr.node
if isinstance(node, Var) and node.is_cls:
typ = get_proper_type(node.type)
if isinstance(typ, TypeType) and isinstance(typ.item, Instance):
class_ir = builder.mapper.type_to_ir.get(typ.item.type)
if class_ir and class_ir.is_ext_class and class_ir.has_no_subclasses():
# Call a native classmethod via cls that can be statically bound,
# since the class has no subclasses.
return call_classmethod(builder, class_ir, expr, callee)

receiver_typ = builder.node_type(callee.expr)

# If there is a specializer for this method name/type, try calling it.
Expand All @@ -389,6 +378,32 @@ def translate_method_call(builder: IRBuilder, expr: CallExpr, callee: MemberExpr
)


def call_classmethod(builder: IRBuilder, ir: ClassIR, expr: CallExpr, callee: MemberExpr) -> Value:
decl = ir.method_decl(callee.name)
args = []
arg_kinds, arg_names = expr.arg_kinds[:], expr.arg_names[:]
# Add the class argument for class methods in extension classes
if decl.kind == FUNC_CLASSMETHOD and ir.is_ext_class:
args.append(builder.load_native_type_object(ir.fullname))
arg_kinds.insert(0, ARG_POS)
arg_names.insert(0, None)
args += [builder.accept(arg) for arg in expr.args]

if ir.is_ext_class:
return builder.builder.call(decl, args, arg_kinds, arg_names, expr.line)
else:
obj = builder.accept(callee.expr)
return builder.gen_method_call(
obj,
callee.name,
args,
builder.node_type(expr),
expr.line,
expr.arg_kinds,
expr.arg_names,
)


def translate_super_method_call(builder: IRBuilder, expr: CallExpr, callee: SuperExpr) -> Value:
if callee.info is None or (len(callee.call.args) != 0 and len(callee.call.args) != 2):
return translate_call(builder, expr, callee)
Expand Down
69 changes: 69 additions & 0 deletions mypyc/test-data/irbuild-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,75 @@ L0:
r3 = CPyTagged_Add(r0, r2)
return r3

[case testCallClassMethodViaCls]
class C:
@classmethod
def f(cls, x: int) -> int:
return cls.g(x)

@classmethod
def g(cls, x: int) -> int:
return x

class D:
@classmethod
def f(cls, x: int) -> int:
# TODO: This could aso be optimized, since g is not ever overridden
return cls.g(x)

@classmethod
def g(cls, x: int) -> int:
return x

class DD(D):
pass
[out]
def C.f(cls, x):
cls :: object
x :: int
r0 :: object
r1 :: int
L0:
r0 = __main__.C :: type
r1 = C.g(r0, x)
return r1
def C.g(cls, x):
cls :: object
x :: int
L0:
return x
def D.f(cls, x):
cls :: object
x :: int
r0 :: str
r1, r2 :: object
r3 :: int
L0:
r0 = 'g'
r1 = box(int, x)
r2 = CPyObject_CallMethodObjArgs(cls, r0, r1, 0)
r3 = unbox(int, r2)
return r3
def D.g(cls, x):
cls :: object
x :: int
L0:
return x

[case testCannotAssignToClsArgument]
from typing import Any, cast

class C:
@classmethod
def m(cls) -> None:
cls = cast(Any, D) # E: Cannot assign to the first argument of classmethod
cls, x = cast(Any, D), 1 # E: Cannot assign to the first argument of classmethod
cls, x = cast(Any, [1, 2]) # E: Cannot assign to the first argument of classmethod
cls.m()

class D:
pass

[case testSuper1]
class A:
def __init__(self, x: int) -> None:
Expand Down
107 changes: 86 additions & 21 deletions mypyc/test-data/run-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -662,42 +662,107 @@ Traceback (most recent call last):
AttributeError: attribute 'x' of 'X' undefined

[case testClassMethods]
MYPY = False
if MYPY:
from typing import ClassVar
from typing import ClassVar, Any
from typing_extensions import final
from mypy_extensions import mypyc_attr

from interp import make_interpreted_subclass

class C:
lurr: 'ClassVar[int]' = 9
lurr: ClassVar[int] = 9
@staticmethod
def foo(x: int) -> int: return 10 + x
def foo(x: int) -> int:
return 10 + x
@classmethod
def bar(cls, x: int) -> int: return cls.lurr + x
def bar(cls, x: int) -> int:
return cls.lurr + x
@staticmethod
def baz(x: int, y: int = 10) -> int: return y - x
def baz(x: int, y: int = 10) -> int:
return y - x
@classmethod
def quux(cls, x: int, y: int = 10) -> int: return y - x
def quux(cls, x: int, y: int = 10) -> int:
return y - x
@classmethod
def call_other(cls, x: int) -> int:
return cls.quux(x, 3)

class D(C):
def f(self) -> int:
return super().foo(1) + super().bar(2) + super().baz(10) + super().quux(10)

def test1() -> int:
def ctest1() -> int:
return C.foo(1) + C.bar(2) + C.baz(10) + C.quux(10) + C.quux(y=10, x=9)
def test2() -> int:

def ctest2() -> int:
c = C()
return c.foo(1) + c.bar(2) + c.baz(10)
[file driver.py]
from native import *
assert C.foo(10) == 20
assert C.bar(10) == 19
c = C()
assert c.foo(10) == 20
assert c.bar(10) == 19

assert test1() == 23
assert test2() == 22
CAny: Any = C

def test_classmethod_using_any() -> None:
assert CAny.foo(10) == 20
assert CAny.bar(10) == 19

def test_classmethod_on_instance() -> None:
c = C()
assert c.foo(10) == 20
assert c.bar(10) == 19
assert c.call_other(1) == 2

def test_classmethod_misc() -> None:
assert ctest1() == 23
assert ctest2() == 22
assert C.call_other(2) == 1

def test_classmethod_using_super() -> None:
d = D()
assert d.f() == 22

d = D()
assert d.f() == 22
@final
class F1:
@classmethod
def f(cls, x: int) -> int:
return cls.g(x)

@classmethod
def g(cls, x: int) -> int:
return x + 1

class F2: # Implicitly final (no subclasses)
@classmethod
def f(cls, x: int) -> int:
return cls.g(x)

@classmethod
def g(cls, x: int) -> int:
return x + 1

def test_classmethod_of_final_class() -> None:
assert F1.f(5) == 6
assert F2.f(7) == 8

@mypyc_attr(allow_interpreted_subclasses=True)
class CI:
@classmethod
def f(cls, x: int) -> int:
return cls.g(x)

@classmethod
def g(cls, x: int) -> int:
return x + 1

def test_classmethod_with_allow_interpreted() -> None:
assert CI.f(4) == 5
sub = make_interpreted_subclass(CI)
assert sub.f(4) == 7

[file interp.py]
def make_interpreted_subclass(base):
class Sub(base):
@classmethod
def g(cls, x: int) -> int:
return x + 3
return Sub

[case testSuper]
from mypy_extensions import trait
Expand Down