Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 78 additions & 8 deletions dspy/signatures/signature.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import dsp

from .field import Field
from .field import Field, InputField, OutputField
import threading

class SignatureMeta(type):
Expand Down Expand Up @@ -34,22 +34,92 @@ def __new__(cls, name, bases, class_dict):
return new_class

def __call__(cls, *args, **kwargs):
if len(args) == 1 and isinstance(args[0], str):
instance = super(SignatureMeta, cls).__call__(*args, **kwargs)
return instance
#old
return cls._template(*args, **kwargs)

def __getattr__(cls, attr):
# Redirect attribute access to the template object when accessed on the class directly
return getattr(cls._template, attr)

class Signature(metaclass=SignatureMeta):
def __init__(self, signature: str = "", instructions: str = ""):
self.signature = signature
self.instructions = instructions
self.fields = {}
self.parse_structure()

def __repr__(cls):
s = []
@property
def kwargs(self):
return {k: v for k, v in self.fields.items()}

def parse_structure(self):
inputs_str, outputs_str = self.signature.split("->")
for name in inputs_str.split(","):
self.add_field(name.strip(), InputField())
for name in outputs_str.split(","):
self.add_field(name.strip(), OutputField())

def attach(self, **kwargs):
for key, (prefix, desc) in kwargs.items():
field_type = self.fields.get(key)
if not field_type:
raise ValueError(f"{key} does not exist in this signature")
field_map = {
InputField: InputField(prefix=prefix, desc=desc),
OutputField: OutputField(prefix=prefix, desc=desc)
}
self.fields[key] = field_map.get(type(field_type))
return self

def add_field(self, field_name: str, field_type, position="append"):
if field_name in self.fields:
raise ValueError(f"{field_name} already exists in fields.")
if isinstance(field_type, (InputField, OutputField)):
field_instance = field_type
else:
raise ValueError(f"non-existent {field_type}.")
if isinstance(field_instance, InputField) and position == "append":
input_fields = self.input_fields()
if input_fields:
last_input_key = list(input_fields.keys())[-1]
index = list(self.fields.keys()).index(last_input_key) + 1
self.fields = {**dict(list(self.fields.items())[:index]), field_name: field_instance, **dict(list(self.fields.items())[index:])}
else:
self.fields[field_name] = field_instance
elif isinstance(field_instance, OutputField) and position == "prepend":
output_fields = self.output_fields()
if output_fields:
first_output_key = list(output_fields.keys())[0]
index = list(self.fields.keys()).index(first_output_key)
self.fields = {**dict(list(self.fields.items())[:index]), field_name: field_instance, **dict(list(self.fields.items())[index:])}
else:
self.fields[field_name] = field_instance
elif position == "prepend":
self.fields = {field_name: field_instance, **self.fields}
elif position == "append":
self.fields[field_name] = field_instance
else:
raise ValueError(f"invalid field addition. Please verify that your field name: {field_name}, field_type: {field_type}, and expected position: {position} are correct.")

for name, field in cls.signature.__dict__.items():
s.append(f"- {name} = {field}")

return f'{cls.__name__}\n' + '\n'.join(s)
def input_fields(self):
return {k: v for k, v in self.fields.items() if isinstance(v, InputField)}

def output_fields(self):
return {k: v for k, v in self.fields.items() if isinstance(v, OutputField)}

def __repr__(self):
s = []
for name, _ in self.fields.items():
value = getattr(self, name, None)
if value:
s.append(f"- {name} = {value}")
else:
s.append(f"- {name} = [field not attached]")
return f'{self.__class__.__name__}\n' + '\n'.join(s)

class Signature(metaclass=SignatureMeta):
def __eq__(self, __value: object) -> bool:
return self._template == __value._template

Expand Down