Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions dsp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .modules import *
from .primitives import *
from .templates import *
from .utils import settings
from .modules import * # noqa
from .primitives import * # noqa
from .adapters import * # noqa
from .utils import settings # noqa

"""
TODO:
Expand Down
4 changes: 4 additions & 0 deletions dsp/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .base_template import * # noqa
from .template import * # noqa
from .experimental_adapter import * # noqa
from .utils import * # noqa
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from collections import namedtuple
from typing import Callable

from dsp.templates import Field, TemplateV2, format_answers, passages2text
from .utils import format_answers, passages2text

Field = namedtuple("Field", "name separator input_variable output_variable description")


class Type:
Expand All @@ -19,7 +22,7 @@ def __eq__(self, __value: object) -> bool:
return isinstance(__value, Type) and self.__dict__ == __value.__dict__


class Template(TemplateV2):
class BaseTemplate:
"""A template datatype that represents the structure of communicate with the LM."""

def __init__(self, instructions: str, **kwargs):
Expand Down Expand Up @@ -61,9 +64,7 @@ def __eq__(self, other):
v1, v2 = self.kwargs[k], other.kwargs[k]
if not v1 == v2:
print(k, v1, v2)


# print("here?", self.instructions == other.instructions, self.kwargs == other.kwargs)
return self.instructions == other.instructions and self.kwargs == other.kwargs

def __str__(self) -> str:
Expand Down
202 changes: 202 additions & 0 deletions dsp/adapters/experimental_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
from typing import Any, Union

import dsp
from dsp.primitives.demonstrate import Example

from .base_template import BaseTemplate


class ExperimentalAdapter(BaseTemplate):
def query(self, example: Example, is_demo: bool = False) -> str:
"""Retrieves the input variables from the example and formats them into a query string."""
result: list[str] = []

# If not a demo, find the last field that doesn't have a value set in `example` and set it to ""
# This creates the "Output:" prefix at the end of the prompt.
if not is_demo:
has_value = [
field.input_variable in example
and example[field.input_variable] is not None
and example[field.input_variable] != ""
for field in self.fields
]

if not any(has_value):
assert False, "No input variables found in the example"

for i in range(1, len(has_value)):
if has_value[i - 1] and not any(has_value[i:]):
example[self.fields[i].input_variable] = ""
break

for field in self.fields:
if field.input_variable in example and example[field.input_variable] is not None:
if field.input_variable in self.format_handlers:
format_handler = self.format_handlers[field.input_variable]
else:
def format_handler(x):
return str(x).strip()

formatted_value = format_handler(example[field.input_variable])
separator = "\n" if field.separator == " " and "\n" in formatted_value else field.separator

result.append(f"{field.name}{separator}{formatted_value}",)

return "\n\n".join([r for r in result if r])

def guidelines(self, show_guidelines=True) -> str:
"""Returns the task guidelines as described in the lm prompt"""
if (not show_guidelines) or (hasattr(dsp.settings, "show_guidelines") and not dsp.settings.show_guidelines):
return ""

result = "Follow the following format.\n\n"

example = dsp.Example()
for field in self.fields:
example[field.input_variable] = field.description
example.augmented = True

result += self.query(example)
return result

def extract(
self,
example: Union[Example, dict[str, Any]],
raw_pred: str,
) -> Example:
"""Extracts the answer from the LM raw prediction using the template structure

Args:
example (Union[Example, dict[str, Any]]): Contains the input variables that raw_pred was completed on.
raw_pred (str): LM generated string

Returns:
Example: The example with the output variables filled in
"""
example = dsp.Example(example)

raw_pred = raw_pred.strip()
parts = raw_pred.split('\n')
adjusted_parts = []
for part in parts:
trimmed_part = part.strip()
if trimmed_part:
if adjusted_parts:
adjusted_parts.append('\n' + trimmed_part)
else:
adjusted_parts.append(trimmed_part)
raw_pred = '\n'.join(adjusted_parts)

idx = 0
while idx < len(self.fields):
if self.fields[idx].input_variable not in example or example[self.fields[idx].input_variable] is None:
break
idx += 1

import dspy

idx = min(idx, len(self.fields) - 1)
while raw_pred != "" and idx < len(self.fields):
if idx < len(self.fields) - 1:
next_field_name = "\n" + self.fields[idx + 1].name
offset = raw_pred.find(next_field_name)

if offset >= 0:
if dspy.settings.release >= 20231003:
example[self.fields[idx].output_variable] = raw_pred[:offset].strip().rstrip("---").strip()
raw_pred = raw_pred[offset + len(next_field_name) :].strip().rstrip("---").strip()
else:
field_name_parts = self.fields[idx].name.split()
start_pos = 0
for part in field_name_parts:
pos = raw_pred.find(part.strip())
if pos != -1:
start_pos = pos + len(part)
else:
break

example[self.fields[idx].output_variable] = raw_pred[start_pos:offset].strip().rstrip("---").strip()
raw_pred = raw_pred[offset + len(next_field_name) :].strip()
idx += 1
else:
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip("---").strip()

raw_pred = ""
idx += 1
break

else:
assert idx == len(self.fields) - 1, (idx, len(self.fields))

if dspy.settings.release >= 20231003:
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip("---").strip()
else:
field_name_parts = self.fields[idx].name.split()
start_pos = 0
for part in field_name_parts:
pos = raw_pred.find(part.strip())
if pos != -1:
start_pos = pos + len(part)
else:
break
example[self.fields[idx].output_variable] = raw_pred[start_pos:].strip()

break

return example

def __call__(self, example, show_guidelines=True) -> str:
example = dsp.Example(example)
output_fields = []
for i in range(len(self.fields)):
if self.fields[i].input_variable not in example:
output_field = self.fields[i].input_variable
if output_field not in output_fields:
output_fields.append(self.fields[i].name.split(':')[0])

if hasattr(dsp.settings, "query_only") and dsp.settings.query_only:
return self.query(example)

# The training data should not contain the output variable
assert self.fields[-1].input_variable not in example, f"Output variable {self.fields[-1].input_variable} should not be supplied for querying the LM."
# del example[self.fields[-1].input_variable]

rdemos = [
self.query(demo, is_demo=True)
for demo in example.demos
if (
(not demo.get("augmented", False))
and ( # validate that the training example has the same primitive input var as the template
self.fields[-1].input_variable in demo and demo[self.fields[-1].input_variable] is not None
)
)
]

ademos = [self.query(demo, is_demo=True) for demo in example.demos if demo.get("augmented", False)]

# Move the rdemos to ademos if rdemo has all the fields filled in
rdemos_ = []
new_ademos = []
for rdemo in rdemos:
if all((field.name in rdemo) for field in self.fields if field.input_variable in example):
new_ademos.append(rdemo)
else:
rdemos_.append(rdemo)

ademos = new_ademos + ademos
rdemos = rdemos_

example["augmented"] = True

query = self.query(example)
parts = [self.instructions, *rdemos, self.guidelines(show_guidelines), *ademos, query,]

prompt = "\n\n---\n\n".join([p.strip() for p in parts if p])
prompt_ = prompt[: prompt.rfind("\n")].strip()

s_or_not = "s" if len(output_fields) > 1 else ""
only_or_not = "only " if len(output_fields) == 1 else ""

prompt_ += f"\n\nPlease provide the output field{s_or_not} {', '.join(output_fields[:-1]) + (', then ' if len(output_fields) > 2 else ' then ') + output_fields[-1] if len(output_fields) > 1 else output_fields[0]}. Do so immediately, without additional content before or after, and precisely as the format above shows. Begin with {only_or_not}the field {output_fields[0]}."
return prompt_.strip()

65 changes: 2 additions & 63 deletions dsp/templates/template_v2.py → dsp/adapters/template.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,12 @@
import re
from collections import namedtuple
from typing import Any, Union

import dsp
from dsp.primitives.demonstrate import Example

from .utils import format_answers, passages2text
from .base_template import BaseTemplate

Field = namedtuple("Field", "name separator input_variable output_variable description")

# TODO: de-duplicate with dsp/templates/template.py


class TemplateV2:
def __init__(
self,
template,
format_handlers={
"passages": passages2text,
"context": passages2text,
"answer": format_answers,
"answers": format_answers,
},
):
self.format_handlers = format_handlers

template = template.strip()

self.instructions = re.search("(.*)\n", template).group(1)
template = template[len(self.instructions) :].strip()

self.fields = []
while len(template) > 0:
match = re.search("(.*)(\s){(.*)}\s(.*\${.*})", template)
if match is not None:
name = match.group(1)
separator = match.group(2)
variable = match.group(3)
description = match.group(4)
else:
match = re.search("(.*)(\s){(.*)}", template)
if match is not None:
name = match.group(1)
separator = match.group(2)
variable = match.group(3)
description = None
else:
raise ValueError("Could not parse template")

var_match = re.match("(.*) -> (.*)", variable)
if var_match is not None:
input_variable = var_match.group(1)
output_variable = var_match.group(2)
else:
input_variable = variable
output_variable = variable

self.fields.append(
Field(
name=name,
separator=separator,
input_variable=input_variable,
output_variable=output_variable,
description=description,
),
)

template = template[len(match.group(0)) :].strip()

class Template(BaseTemplate):
def query(self, example: Example, is_demo: bool = False) -> str:
"""Retrieves the input variables from the example and formats them into a query string."""
result: list[str] = []
Expand Down
38 changes: 19 additions & 19 deletions dsp/templates/utils.py → dsp/adapters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,29 @@ def passages2text(passages: Union[str, list, tuple]) -> str:
return "\n".join([f"[{idx+1}] «{txt}»" for idx, txt in enumerate(passages)])


def passages2textV2(passages: Union[str, list, tuple]) -> str:
"""Formats the given one or more passages into a single structured string."""
if isinstance(passages, str):
return passages

assert type(passages) in [list, tuple]

def psg2text(psg):
try:
title, snippet = psg.split("|", 1)
return f"Title: {title.strip()} | Snippet: «{snippet.strip()}»"
except Exception:
pass
# def passages2textV2(passages: Union[str, list, tuple]) -> str:
# """Formats the given one or more passages into a single structured string."""
# if isinstance(passages, str):
# return passages

# assert type(passages) in [list, tuple]

# def psg2text(psg):
# try:
# title, snippet = psg.split("|", 1)
# return f"Title: {title.strip()} | Snippet: «{snippet.strip()}»"
# except Exception:
# pass

return f"«{psg}»"
# return f"«{psg}»"

if len(passages) == 0:
return "N/A"
# if len(passages) == 0:
# return "N/A"

if len(passages) == 1:
return psg2text(passages[0])
# if len(passages) == 1:
# return psg2text(passages[0])

return "\n".join([f"[{idx+1}] {psg2text(txt)}" for idx, txt in enumerate(passages)])
# return "\n".join([f"[{idx+1}] {psg2text(txt)}" for idx, txt in enumerate(passages)])


def format_answers(answers: Union[str, list]) -> Optional[str]:
Expand Down
Loading