diff --git a/dspy/primitives/module.py b/dspy/primitives/module.py index 33690c9242..5cf0b22315 100644 --- a/dspy/primitives/module.py +++ b/dspy/primitives/module.py @@ -4,17 +4,48 @@ import ujson +# NOTE: Note: It's important (temporary decision) to maintain named_parameters that's different in behavior from +# named_sub_modules for the time being. class BaseModule: def __init__(self): pass def named_parameters(self): - """Unlike PyTorch, handles lists of parameters too.""" + """ + Unlike PyTorch, handles (non-recursive) lists of parameters too. + """ + + import dspy from dspy.predict.parameter import Parameter - # Remove the 'self.' prefix from the names - return [(name[5:], param) for name, param in self.named_sub_modules(Parameter)] + visited = set() + named_parameters = [] + + def add_parameter(param_name, param_value): + if isinstance(param_value, Parameter) and id(param_value) not in visited: + visited.add(id(param_value)) + named_parameters.append((param_name, param_value)) + + for name, value in self.__dict__.items(): + if isinstance(value, Parameter): + add_parameter(name, value) + + elif isinstance(value, dspy.Module): + # When a sub-module is pre-compiled, keep it frozen. + if not getattr(value, "_compiled", False): + for sub_name, param in value.named_parameters(): + add_parameter(f"{name}.{sub_name}", param) + + elif isinstance(value, (list, tuple)): + for idx, item in enumerate(value): + add_parameter(f"{name}[{idx}]", item) + + elif isinstance(value, dict): + for key, item in value.items(): + add_parameter(f"{name}['{key}']", item) + + return named_parameters def named_sub_modules(self, type_=None, skip_compiled=False) -> Generator[tuple[str, "BaseModule"], None, None]: """Find all sub-modules in the module, as well as their names. diff --git a/tests/primitives/test_program.py b/tests/primitives/test_program.py index 3f86005154..219ad52594 100644 --- a/tests/primitives/test_program.py +++ b/tests/primitives/test_program.py @@ -126,9 +126,9 @@ def test_complex_module_traversal(): "self", "self.sub_module", "self.sub_module.nested_list[0]", - "self.sub_module.nested_list[1][key]", + "self.sub_module.nested_list[1][key]", # NOTE: named_sub_modules allows recursive structures "self.sub_module.nested_tuple[0]", - "self.sub_module.nested_tuple[1][0]", + "self.sub_module.nested_tuple[1][0]", # NEW: named_sub_modules allows recursive structures, but named_prameters does not # "self.sub_module.nested_tuple[1][1]", This should not be included, as it's the same module as the previous one } found_names = {name for name, _ in root.named_sub_modules()}