From 43ce6ce63eeb5f73b67e0b27176c012e8da663af Mon Sep 17 00:00:00 2001 From: Advait Dixit Date: Thu, 5 Dec 2024 21:43:13 -0800 Subject: [PATCH] [mypyc] Fixing iteration over NamedTuple objects. --- mypyc/irbuild/builder.py | 22 +++++++++++++++++----- mypyc/test-data/run-loops.test | 26 ++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index a0837ba2bfc7..d8c85c6002b7 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -60,6 +60,7 @@ Type, TypedDictType, TypeOfAny, + TypeVarLikeType, UninhabitedType, UnionType, get_proper_type, @@ -910,11 +911,22 @@ def get_sequence_type_from_type(self, target_type: Type) -> RType: return RUnion.make_simplified_union( [self.get_sequence_type_from_type(item) for item in target_type.items] ) - assert isinstance(target_type, Instance), target_type - if target_type.type.fullname == "builtins.str": - return str_rprimitive - else: - return self.type_to_rtype(target_type.args[0]) + elif isinstance(target_type, Instance): + if target_type.type.fullname == "builtins.str": + return str_rprimitive + else: + return self.type_to_rtype(target_type.args[0]) + # This elif-blocks are needed for iterating over classes derived from NamedTuple. + elif isinstance(target_type, TypeVarLikeType): + return self.get_sequence_type_from_type(target_type.upper_bound) + elif isinstance(target_type, TupleType): + # Tuple might have elements of different types. + rtypes = {self.mapper.type_to_rtype(item) for item in target_type.items} + if len(rtypes) == 1: + return rtypes.pop() + else: + return RUnion.make_simplified_union(list(rtypes)) + assert False, target_type def get_dict_base_type(self, expr: Expression) -> list[Instance]: """Find dict type of a dict-like expression. diff --git a/mypyc/test-data/run-loops.test b/mypyc/test-data/run-loops.test index 76fbb06200a3..3cbb07297e6e 100644 --- a/mypyc/test-data/run-loops.test +++ b/mypyc/test-data/run-loops.test @@ -545,3 +545,29 @@ def test_range_object() -> None: r4 = range(4, 12, 0) except ValueError as e: assert "range() arg 3 must not be zero" in str(e) + +[case testNamedTupleLoop] +from collections.abc import Iterable +from typing import NamedTuple, Any +from typing_extensions import Self + + +class Vector2(NamedTuple): + x: int + y: float + + @classmethod + def from_iter(cls, iterable: Iterable[Any]) -> Self: + return cls(*iter(iterable)) + + def __neg__(self) -> Self: + return self.from_iter(-c for c in self) + +[file driver.py] +import native +print(-native.Vector2(2, -3.1)) +print([x for x in native.Vector2(4, -5.2)]) + +[out] +Vector2(x=-2, y=3.1) +\[4, -5.2]