From 31b9dbe9f8f535d8ab4dd23623181a97570810de Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Mon, 29 Apr 2024 13:39:34 -0700 Subject: [PATCH 1/2] revert named_parameters() to original behavior: DFS order and no recursion in lists --- dspy/primitives/module.py | 37 +++++++++++++++++++++++++++++--- tests/primitives/test_program.py | 4 ++-- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/dspy/primitives/module.py b/dspy/primitives/module.py index 33690c9242..5cf0b22315 100644 --- a/dspy/primitives/module.py +++ b/dspy/primitives/module.py @@ -4,17 +4,48 @@ import ujson +# NOTE: Note: It's important (temporary decision) to maintain named_parameters that's different in behavior from +# named_sub_modules for the time being. class BaseModule: def __init__(self): pass def named_parameters(self): - """Unlike PyTorch, handles lists of parameters too.""" + """ + Unlike PyTorch, handles (non-recursive) lists of parameters too. + """ + + import dspy from dspy.predict.parameter import Parameter - # Remove the 'self.' prefix from the names - return [(name[5:], param) for name, param in self.named_sub_modules(Parameter)] + visited = set() + named_parameters = [] + + def add_parameter(param_name, param_value): + if isinstance(param_value, Parameter) and id(param_value) not in visited: + visited.add(id(param_value)) + named_parameters.append((param_name, param_value)) + + for name, value in self.__dict__.items(): + if isinstance(value, Parameter): + add_parameter(name, value) + + elif isinstance(value, dspy.Module): + # When a sub-module is pre-compiled, keep it frozen. + if not getattr(value, "_compiled", False): + for sub_name, param in value.named_parameters(): + add_parameter(f"{name}.{sub_name}", param) + + elif isinstance(value, (list, tuple)): + for idx, item in enumerate(value): + add_parameter(f"{name}[{idx}]", item) + + elif isinstance(value, dict): + for key, item in value.items(): + add_parameter(f"{name}['{key}']", item) + + return named_parameters def named_sub_modules(self, type_=None, skip_compiled=False) -> Generator[tuple[str, "BaseModule"], None, None]: """Find all sub-modules in the module, as well as their names. diff --git a/tests/primitives/test_program.py b/tests/primitives/test_program.py index 3f86005154..2fe08741bc 100644 --- a/tests/primitives/test_program.py +++ b/tests/primitives/test_program.py @@ -126,9 +126,9 @@ def test_complex_module_traversal(): "self", "self.sub_module", "self.sub_module.nested_list[0]", - "self.sub_module.nested_list[1][key]", + # "self.sub_module.nested_list[1][key]", # NEW: NO recursive structures "self.sub_module.nested_tuple[0]", - "self.sub_module.nested_tuple[1][0]", + # "self.sub_module.nested_tuple[1][0]", # NEW: NO recursive structures # "self.sub_module.nested_tuple[1][1]", This should not be included, as it's the same module as the previous one } found_names = {name for name, _ in root.named_sub_modules()} From 6e4714bafcbf2250c1752cc579218403097c5b8e Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Mon, 29 Apr 2024 13:44:59 -0700 Subject: [PATCH 2/2] correction to test_program.py --- tests/primitives/test_program.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/primitives/test_program.py b/tests/primitives/test_program.py index 2fe08741bc..219ad52594 100644 --- a/tests/primitives/test_program.py +++ b/tests/primitives/test_program.py @@ -126,9 +126,9 @@ def test_complex_module_traversal(): "self", "self.sub_module", "self.sub_module.nested_list[0]", - # "self.sub_module.nested_list[1][key]", # NEW: NO recursive structures + "self.sub_module.nested_list[1][key]", # NOTE: named_sub_modules allows recursive structures "self.sub_module.nested_tuple[0]", - # "self.sub_module.nested_tuple[1][0]", # NEW: NO recursive structures + "self.sub_module.nested_tuple[1][0]", # NEW: named_sub_modules allows recursive structures, but named_prameters does not # "self.sub_module.nested_tuple[1][1]", This should not be included, as it's the same module as the previous one } found_names = {name for name, _ in root.named_sub_modules()}