From e2dc28860ce2a0c1baf7d9f3a63479a3c656190c Mon Sep 17 00:00:00 2001 From: Michael Jones Date: Mon, 25 Nov 2024 09:10:58 +0000 Subject: [PATCH 1/3] feat(dspy): add datamodel-code-generator to dev reqs --- requirements-dev.txt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 98d89e732d..23984fa074 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,7 @@ black==24.2.0 +datamodel-code-generator==0.26.3 +litellm[proxy]==1.51.0 +pillow==10.4.0 pre-commit==3.7.0 pytest==8.3.3 pytest-env==1.1.3 @@ -6,5 +9,3 @@ pytest-mock==3.12.0 ruff==0.3.0 torch==2.2.1 transformers==4.38.2 -pillow==10.4.0 -litellm[proxy]==1.51.0 From 19dbeebe2b621c0767a6bd14b9e2b9b8bab98ea7 Mon Sep 17 00:00:00 2001 From: Michael Jones Date: Mon, 25 Nov 2024 09:12:39 +0000 Subject: [PATCH 2/3] fix(dspy): fix signature replace for pydantic v2.10 --- dspy/signatures/signature.py | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 88812be0c7..b02038abab 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -1,5 +1,6 @@ import ast import inspect +import logging import re import types import typing @@ -11,8 +12,9 @@ from pydantic.fields import FieldInfo import dsp -from dspy.signatures.field import InputField, OutputField, new_to_old_field from dspy.adapters.image_utils import Image +from dspy.signatures.field import InputField, OutputField, new_to_old_field + def signature_to_template(signature, adapter=None) -> dsp.Template: """Convert from new to legacy format.""" @@ -242,7 +244,7 @@ class Signature(BaseModel, metaclass=SignatureMeta): @classmethod @contextmanager def replace( - cls: "Signature", + cls, new_signature: "Signature", validate_new_signature: bool = True, ) -> typing.Generator[None, None, None]: @@ -262,16 +264,28 @@ def replace( f"Field '{field}' is missing from the updated signature '{new_signature.__class__}.", ) - class OldSignature(cls, Signature): + class OldSignature(cls): pass - replace_fields = ["__doc__", "model_fields", "model_extra", "model_config"] - for field in replace_fields: - setattr(cls, field, getattr(new_signature, field)) + replace_attrs = ["__doc__", "__pydantic_fields__", "model_fields", "model_extra", "model_config"] + for attr in replace_attrs: + try: + setattr(cls, attr, getattr(new_signature, attr)) + except AttributeError as exc: + if attr == "model_fields": + logging.debug("Model attribute model_fields not replaced, expected with pydantic > 2.10") + else: + raise exc cls.model_rebuild(force=True) yield - for field in replace_fields: - setattr(cls, field, getattr(OldSignature, field)) + for attr in replace_attrs: + try: + setattr(cls, attr, getattr(OldSignature, attr)) + except AttributeError as exc: + if attr == "model_fields": + logging.debug("Model attribute model_fields not replaced, expected with pydantic > 2.10") + else: + raise exc cls.model_rebuild(force=True) @@ -383,7 +397,7 @@ def _parse_type_node(node, names=None) -> Any: without using structural pattern matching introduced in Python 3.10. """ - + if names is None: names = typing.__dict__ @@ -401,7 +415,7 @@ def _parse_type_node(node, names=None) -> Any: id_ = node.id if id_ in names: return names[id_] - + for type_ in [int, str, float, bool, list, tuple, dict, Image]: if type_.__name__ == id_: return type_ @@ -420,7 +434,7 @@ def _parse_type_node(node, names=None) -> Any: keys = [kw.arg for kw in node.keywords] values = [kw.value.value for kw in node.keywords] return Field(**dict(zip(keys, values))) - + if isinstance(node, ast.Attribute) and node.attr == "Image": return Image From 15b090bdd67507fd1643104f1292de432cd75554 Mon Sep 17 00:00:00 2001 From: Michael Jones Date: Mon, 25 Nov 2024 09:36:09 +0000 Subject: [PATCH 3/3] fix(dspy): fix signature replace for pydantic v2.10 --- dspy/signatures/signature.py | 43 +++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index b02038abab..387b9d771d 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -245,7 +245,7 @@ class Signature(BaseModel, metaclass=SignatureMeta): @contextmanager def replace( cls, - new_signature: "Signature", + new_signature: "Type[Signature]", validate_new_signature: bool = True, ) -> typing.Generator[None, None, None]: """Replace the signature with an updated version. @@ -267,25 +267,32 @@ def replace( class OldSignature(cls): pass - replace_attrs = ["__doc__", "__pydantic_fields__", "model_fields", "model_extra", "model_config"] - for attr in replace_attrs: - try: - setattr(cls, attr, getattr(new_signature, attr)) - except AttributeError as exc: - if attr == "model_fields": - logging.debug("Model attribute model_fields not replaced, expected with pydantic > 2.10") - else: - raise exc + def swap_attributes(source: Type[Signature]): + unhandled = {} + + for attr in ["__doc__", "__pydantic_fields__", "model_fields", "model_extra", "model_config"]: + try: + setattr(cls, attr, getattr(source, attr)) + except AttributeError as exc: + if attr in ("__pydantic_fields__", "model_fields"): + version = "< 2.10" if attr == "__pydantic_fields__" else ">= 2.10" + logging.debug(f"Model attribute {attr} not replaced, expected with pydantic {version}") + unhandled[attr] = exc + else: + raise exc + + # if neither of the attributes were replaced, raise an error to prevent silent failures + if set(unhandled.keys()) >= {"model_fields", "__pydantic_fields__"}: + raise ValueError("Failed to replace either model_fields or __pydantic_fields__") from ( + unhandled.get("model_fields") or unhandled.get("__pydantic_fields__") + ) + + swap_attributes(new_signature) cls.model_rebuild(force=True) + yield - for attr in replace_attrs: - try: - setattr(cls, attr, getattr(OldSignature, attr)) - except AttributeError as exc: - if attr == "model_fields": - logging.debug("Model attribute model_fields not replaced, expected with pydantic > 2.10") - else: - raise exc + + swap_attributes(OldSignature) cls.model_rebuild(force=True)