Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions src/python_template/utils/split_args_for_inits.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def call_init_chain_respecting_super( # pylint: disable=too-many-locals,too-com
"""Call init methods of super-respecting classes via `super()` chain, and manually call
those that don't use `super()`.
"""
called = set()
# called = set()

# Step 1: Identify super-respecting classes:
# (crude guess: look at if "super" is in source)
Expand All @@ -379,15 +379,23 @@ def call_init_chain_respecting_super( # pylint: disable=too-many-locals,too-com
# except Exception:
# pass

# Step 2: Call init for non-super-respecting classes:
for base in cls.__mro__:
if base in (object, stop_at, skip_class) or base in called:
continue
if base not in super_respecting and base in split:
args = split[base].get("args", [])
kwargs = split[base].get("kwargs", {})
base.__init__(self, *args, **kwargs) # type: ignore[misc] # pylint: disable=unnecessary-dunder-call
called.add(base)
# Maybe skip this next step, since the classes will be called again
# as part of the super-respecting chains (as individual single-link
# chains)?
# # Step 2: Call init for non-super-respecting classes:
# for base in cls.__mro__:
# if base in (object, stop_at, skip_class) or base in called:
# continue
# if base not in super_respecting and base in split:
# args = split[base].get("args", [])
# kwargs = split[base].get("kwargs", {})
# if DEBUG_PRINTS:
# print(
# "[call_init_chain_respecting_super] Calling init of "
# f"non-super-respecting class {base.__name__}"
# )
# base.__init__(self, *args, **kwargs) # type: ignore[misc] # pylint: disable=unnecessary-dunder-call
# called.add(base)

# Step 3: Collect args and kwargs for super-respecting chain:
# first_bases = []
Expand Down Expand Up @@ -421,6 +429,11 @@ def call_init_chain_respecting_super( # pylint: disable=too-many-locals,too-com
for base in first_bases:
if base in (skip_class,):
continue
if DEBUG_PRINTS:
print(
"[call_init_chain_respecting_super] Calling init of "
f"super-respecting class {base.__name__}"
)
# super(base, self).__init__(*super_args[base], **super_kwargs[base])
base.__init__( # type: ignore[misc] # pylint: disable=unnecessary-dunder-call
self, *super_args[base], **super_kwargs[base]
Expand Down