From af77f84e1c1003553d10b1325aac2669b5905765 Mon Sep 17 00:00:00 2001 From: Isaac Miller Date: Sat, 2 Nov 2024 22:11:44 -0700 Subject: [PATCH 1/6] Fix broken inspect_history and broken prompt cache --- dspy/__init__.py | 10 +-- dspy/adapters/chat_adapter.py | 7 +- dspy/clients/base_lm.py | 18 +++-- dspy/clients/lm.py | 4 +- dspy/utils/dummies.py | 1 + dspy/utils/inspect_global_history.py | 4 + tests/utils/test_inspect_global_history.py | 86 ++++++++++++++++++++++ 7 files changed, 112 insertions(+), 18 deletions(-) create mode 100644 dspy/utils/inspect_global_history.py create mode 100644 tests/utils/test_inspect_global_history.py diff --git a/dspy/__init__.py b/dspy/__init__.py index 84d3600655..77a92b77ef 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -12,7 +12,7 @@ from dspy.clients import * # isort: skip from dspy.adapters import * # isort: skip from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging - +from dspy.utils.inspect_global_history import inspect_history settings = dsp.settings configure_dspy_loggers(__name__) @@ -70,10 +70,4 @@ BootstrapRS = dspy.teleprompt.BootstrapFewShotWithRandomSearch COPRO = dspy.teleprompt.COPRO MIPROv2 = dspy.teleprompt.MIPROv2 -Ensemble = dspy.teleprompt.Ensemble - - -# TODO: Consider if this should access settings.lm *or* a list that's shared across all LMs in the program. -def inspect_history(*args, **kwargs): - from dspy.clients.lm import GLOBAL_HISTORY, _inspect_history - return _inspect_history(GLOBAL_HISTORY, *args, **kwargs) \ No newline at end of file +Ensemble = dspy.teleprompt.Ensemble \ No newline at end of file diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 27ef87ecd7..432a850616 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -201,7 +201,10 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text= for field, field_value in fields_with_values.items(): formatted_field_value = _format_field_value(field_info=field.info, value=field_value, assume_text=assume_text) if assume_text: - output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}") + output.append(f"[[ ## {field.name} ## ]]") + # This conditional is specifically for the completed field, which is always the last field + if field_value: + output[-1] += f"\n{formatted_field_value}" else: output.append({"type": "text", "text": f"[[ ## {field.name} ## ]]\n"}) if isinstance(formatted_field_value, dict) and formatted_field_value.get("type") == "image_url": @@ -396,7 +399,7 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]): parts.append(format_signature_fields_for_instructions(signature.input_fields)) parts.append(format_signature_fields_for_instructions(signature.output_fields)) parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}, assume_text=True)) - + print(format_fields({BuiltInCompletedOutputFieldInfo: ""}, assume_text=True)) instructions = textwrap.dedent(signature.instructions) objective = ("\n" + " " * 8).join([""] + instructions.splitlines()) parts.append(f"In adhering to this structure, your objective is: {objective}") diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index b6f13d0ca5..5bbaf3232a 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +GLOBAL_HISTORY = [] class BaseLM(ABC): def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, cache=True, **kwargs): @@ -13,8 +14,11 @@ def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, c def __call__(self, prompt=None, messages=None, **kwargs): pass - def inspect_history(self, n: int = 1): - _inspect_history(self, n) + def inspect_history(self, n: int = 1, skip: int = 0): + _inspect_history(self.history, n, skip) + + def update_global_history(self, entry): + GLOBAL_HISTORY.append(entry) def _green(text: str, end: str = "\n"): @@ -25,10 +29,14 @@ def _red(text: str, end: str = "\n"): return "\x1b[31m" + str(text) + "\x1b[0m" + end -def _inspect_history(lm, n: int = 1): +def _inspect_history(history, n: int = 1, skip: int = 0): """Prints the last n prompts and their completions.""" - - for item in reversed(lm.history[-n:]): + if skip < 0: + raise ValueError("skip must be non-negative integers") + elif n <= 0: + raise ValueError("n must be a positive integer") + history_slice = history[-n-skip:-skip] if skip > 0 else history[-n:] + for item in reversed(history_slice): messages = item["messages"] or [{"role": "user", "content": item["prompt"]}] outputs = item["outputs"] diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 22e37019f4..c8eba2d377 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -23,8 +23,6 @@ if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ: os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" -GLOBAL_HISTORY = [] - logger = logging.getLogger(__name__) class LM(BaseLM): @@ -109,7 +107,7 @@ def __call__(self, prompt=None, messages=None, **kwargs): model_type=self.model_type, ) self.history.append(entry) - GLOBAL_HISTORY.append(entry) + self.update_global_history(entry) return outputs diff --git a/dspy/utils/dummies.py b/dspy/utils/dummies.py index 90f5def66a..2a581932fc 100644 --- a/dspy/utils/dummies.py +++ b/dspy/utils/dummies.py @@ -205,6 +205,7 @@ def format_answer_fields(field_names_and_values: Dict[str, Any]): entry = dict(**entry, outputs=outputs, usage=0) entry = dict(**entry, cost=0) self.history.append(entry) + self.update_global_history(entry) return outputs diff --git a/dspy/utils/inspect_global_history.py b/dspy/utils/inspect_global_history.py new file mode 100644 index 0000000000..52055be2e7 --- /dev/null +++ b/dspy/utils/inspect_global_history.py @@ -0,0 +1,4 @@ +# TODO: Consider if this should access settings.lm *or* a list that's shared across all LMs in the program. +def inspect_history(*args, **kwargs): + from dspy.clients.base_lm import GLOBAL_HISTORY, _inspect_history + return _inspect_history(GLOBAL_HISTORY, *args, **kwargs) \ No newline at end of file diff --git a/tests/utils/test_inspect_global_history.py b/tests/utils/test_inspect_global_history.py new file mode 100644 index 0000000000..cede61c60a --- /dev/null +++ b/tests/utils/test_inspect_global_history.py @@ -0,0 +1,86 @@ +import pytest +from dspy.utils.inspect_global_history import inspect_history +from dspy.utils.dummies import DummyLM +from dspy.clients.base_lm import GLOBAL_HISTORY +import dspy + +@pytest.fixture(autouse=True) +def clear_history(): + GLOBAL_HISTORY.clear() + yield + +def test_inspect_history_basic(capsys): + # Configure a DummyLM with some predefined responses + lm = DummyLM([{"response": "Hello"}, {"response": "How are you?"}]) + dspy.settings.configure(lm=lm) + + # Make some calls to generate history + predictor = dspy.Predict("query: str -> response: str") + predictor(query="Hi") + predictor(query="What's up?") + + # Test inspecting all history + history = GLOBAL_HISTORY + print(capsys) + assert len(history) > 0 + assert isinstance(history, list) + assert all(isinstance(entry, dict) for entry in history) + assert all("messages" in entry for entry in history) + +def test_inspect_history_with_n(capsys): + lm = DummyLM([{"response": "One"}, {"response": "Two"}, {"response": "Three"}]) + dspy.settings.configure(lm=lm) + + # Generate some history + predictor = dspy.Predict("query: str -> response: str") + predictor(query="First") + predictor(query="Second") + predictor(query="Third") + + inspect_history(n=2) + # Test getting last 2 entries + out, err = capsys.readouterr() + assert not "First" in out + assert "Second" in out + assert "Third" in out + +def test_inspect_empty_history(capsys): + # Configure fresh DummyLM + lm = DummyLM([]) + dspy.settings.configure(lm=lm) + + # Test inspecting empty history + inspect_history() + history = GLOBAL_HISTORY + assert len(history) == 0 + assert isinstance(history, list) + +def test_inspect_history_with_invalid_n(capsys): + lm = DummyLM([{"response": "Test"}]) + dspy.settings.configure(lm=lm) + + predictor = dspy.Predict("query: str -> response: str") + predictor(query="Test query") + + # Test with negative n + with pytest.raises(ValueError): + inspect_history(n=-1) + + # Test with n=0 + with pytest.raises(ValueError): + inspect_history(n=0) + out, err = capsys.readouterr() + assert out.strip() == "" + +def test_inspect_history_n_larger_than_history(capsys): + lm = DummyLM([{"response": "First"}, {"response": "Second"}]) + dspy.settings.configure(lm=lm) + + predictor = dspy.Predict("query: str -> response: str") + predictor(query="Query 1") + predictor(query="Query 2") + + # Request more entries than exist + inspect_history(n=5) + history = GLOBAL_HISTORY + assert len(history) == 2 # Should return all available entries From a2846fc8c9fcf755270ef8a463a48f44ca0b896d Mon Sep 17 00:00:00 2001 From: Isaac Miller Date: Sat, 2 Nov 2024 22:16:33 -0700 Subject: [PATCH 2/6] Remove errant print statement --- dspy/adapters/chat_adapter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 432a850616..8d9d6d6f52 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -399,7 +399,6 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]): parts.append(format_signature_fields_for_instructions(signature.input_fields)) parts.append(format_signature_fields_for_instructions(signature.output_fields)) parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}, assume_text=True)) - print(format_fields({BuiltInCompletedOutputFieldInfo: ""}, assume_text=True)) instructions = textwrap.dedent(signature.instructions) objective = ("\n" + " " * 8).join([""] + instructions.splitlines()) parts.append(f"In adhering to this structure, your objective is: {objective}") From 1e16d5bdc1edacfbea1052cd5f28195db8ddeced Mon Sep 17 00:00:00 2001 From: Isaac Miller Date: Sun, 3 Nov 2024 08:57:57 -0800 Subject: [PATCH 3/6] Move global inspect history into base_lm --- dspy/__init__.py | 1 - dspy/clients/__init__.py | 2 +- dspy/clients/base_lm.py | 26 ++++++++++++------- dspy/utils/inspect_global_history.py | 4 --- examples/temp.py | 9 +++++++ .../test_inspect_global_history.py | 24 +++-------------- 6 files changed, 30 insertions(+), 36 deletions(-) delete mode 100644 dspy/utils/inspect_global_history.py create mode 100644 examples/temp.py rename tests/{utils => clients}/test_inspect_global_history.py (76%) diff --git a/dspy/__init__.py b/dspy/__init__.py index 77a92b77ef..f80c8237fd 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -12,7 +12,6 @@ from dspy.clients import * # isort: skip from dspy.adapters import * # isort: skip from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging -from dspy.utils.inspect_global_history import inspect_history settings = dsp.settings configure_dspy_loggers(__name__) diff --git a/dspy/clients/__init__.py b/dspy/clients/__init__.py index 6a63509f5b..0056db0464 100644 --- a/dspy/clients/__init__.py +++ b/dspy/clients/__init__.py @@ -1,2 +1,2 @@ from .lm import LM -from .base_lm import BaseLM \ No newline at end of file +from .base_lm import BaseLM, inspect_history diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 5bbaf3232a..981f2338ed 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -28,19 +28,21 @@ def _green(text: str, end: str = "\n"): def _red(text: str, end: str = "\n"): return "\x1b[31m" + str(text) + "\x1b[0m" + end +def _blue(text: str, end: str = "\n"): + return "\x1b[34m" + str(text) + "\x1b[0m" + end -def _inspect_history(history, n: int = 1, skip: int = 0): + +def _inspect_history(history, n: int = 1): """Prints the last n prompts and their completions.""" - if skip < 0: - raise ValueError("skip must be non-negative integers") - elif n <= 0: - raise ValueError("n must be a positive integer") - history_slice = history[-n-skip:-skip] if skip > 0 else history[-n:] - for item in reversed(history_slice): + + for item in history[-n:]: messages = item["messages"] or [{"role": "user", "content": item["prompt"]}] outputs = item["outputs"] + timestamp = item.get("timestamp", "Unknown time") print("\n\n\n") + print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n") + for msg in messages: print(_red(f"{msg['role'].capitalize()} message:")) if isinstance(msg["content"], str): @@ -51,11 +53,13 @@ def _inspect_history(history, n: int = 1, skip: int = 0): if c["type"] == "text": print(c["text"].strip()) elif c["type"] == "image_url": + image_str = "" if "base64" in c["image_url"].get("url", ""): len_base64 = len(c["image_url"]["url"].split("base64,")[1]) - print(f"<{c['image_url']['url'].split('base64,')[0]}base64,") + image_str = f"<{c['image_url']['url'].split('base64,')[0]}base64," else: - print(f"") + image_str = f"" + print(_blue(image_str.strip())) print("\n") print(_red("Response:")) @@ -66,3 +70,7 @@ def _inspect_history(history, n: int = 1, skip: int = 0): print(_red(choices_text, end="")) print("\n\n\n") + +def inspect_history(*args, **kwargs): + """The global history shared across all LMs.""" + return _inspect_history(GLOBAL_HISTORY, *args, **kwargs) \ No newline at end of file diff --git a/dspy/utils/inspect_global_history.py b/dspy/utils/inspect_global_history.py deleted file mode 100644 index 52055be2e7..0000000000 --- a/dspy/utils/inspect_global_history.py +++ /dev/null @@ -1,4 +0,0 @@ -# TODO: Consider if this should access settings.lm *or* a list that's shared across all LMs in the program. -def inspect_history(*args, **kwargs): - from dspy.clients.base_lm import GLOBAL_HISTORY, _inspect_history - return _inspect_history(GLOBAL_HISTORY, *args, **kwargs) \ No newline at end of file diff --git a/examples/temp.py b/examples/temp.py new file mode 100644 index 0000000000..152344bac6 --- /dev/null +++ b/examples/temp.py @@ -0,0 +1,9 @@ +import dspy + +lm = dspy.LM("gpt-4o-mini") +dspy.settings.configure(lm=lm) + +predictor = dspy.Predict("query: str, image: Image -> response: str") +predictor(query="What is this dog?", image=dspy.Image.from_url("https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg")) + +dspy.inspect_history() \ No newline at end of file diff --git a/tests/utils/test_inspect_global_history.py b/tests/clients/test_inspect_global_history.py similarity index 76% rename from tests/utils/test_inspect_global_history.py rename to tests/clients/test_inspect_global_history.py index cede61c60a..f3bcf210d2 100644 --- a/tests/utils/test_inspect_global_history.py +++ b/tests/clients/test_inspect_global_history.py @@ -1,5 +1,4 @@ import pytest -from dspy.utils.inspect_global_history import inspect_history from dspy.utils.dummies import DummyLM from dspy.clients.base_lm import GLOBAL_HISTORY import dspy @@ -37,7 +36,7 @@ def test_inspect_history_with_n(capsys): predictor(query="Second") predictor(query="Third") - inspect_history(n=2) + dspy.inspect_history(n=2) # Test getting last 2 entries out, err = capsys.readouterr() assert not "First" in out @@ -50,28 +49,11 @@ def test_inspect_empty_history(capsys): dspy.settings.configure(lm=lm) # Test inspecting empty history - inspect_history() + dspy.inspect_history() history = GLOBAL_HISTORY assert len(history) == 0 assert isinstance(history, list) -def test_inspect_history_with_invalid_n(capsys): - lm = DummyLM([{"response": "Test"}]) - dspy.settings.configure(lm=lm) - - predictor = dspy.Predict("query: str -> response: str") - predictor(query="Test query") - - # Test with negative n - with pytest.raises(ValueError): - inspect_history(n=-1) - - # Test with n=0 - with pytest.raises(ValueError): - inspect_history(n=0) - out, err = capsys.readouterr() - assert out.strip() == "" - def test_inspect_history_n_larger_than_history(capsys): lm = DummyLM([{"response": "First"}, {"response": "Second"}]) dspy.settings.configure(lm=lm) @@ -81,6 +63,6 @@ def test_inspect_history_n_larger_than_history(capsys): predictor(query="Query 2") # Request more entries than exist - inspect_history(n=5) + dspy.inspect_history(n=5) history = GLOBAL_HISTORY assert len(history) == 2 # Should return all available entries From 7edcf52297bea00cc71596e15a3f303fa3ee4c82 Mon Sep 17 00:00:00 2001 From: Isaac Miller Date: Sun, 3 Nov 2024 09:12:36 -0800 Subject: [PATCH 4/6] Remove skip parameter --- dspy/clients/base_lm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py index 981f2338ed..d71d384b6b 100644 --- a/dspy/clients/base_lm.py +++ b/dspy/clients/base_lm.py @@ -14,8 +14,8 @@ def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, c def __call__(self, prompt=None, messages=None, **kwargs): pass - def inspect_history(self, n: int = 1, skip: int = 0): - _inspect_history(self.history, n, skip) + def inspect_history(self, n: int = 1): + _inspect_history(self.history, n) def update_global_history(self, entry): GLOBAL_HISTORY.append(entry) @@ -71,6 +71,6 @@ def _inspect_history(history, n: int = 1): print("\n\n\n") -def inspect_history(*args, **kwargs): +def inspect_history(n: int = 1): """The global history shared across all LMs.""" - return _inspect_history(GLOBAL_HISTORY, *args, **kwargs) \ No newline at end of file + return _inspect_history(GLOBAL_HISTORY, n) \ No newline at end of file From 88ed4e700f160aeb924fdb86a41cdae5ce9e9a29 Mon Sep 17 00:00:00 2001 From: Isaac Miller <17116851+isaacbmiller@users.noreply.github.com> Date: Sun, 3 Nov 2024 11:21:43 -0600 Subject: [PATCH 5/6] Delete examples/temp.py --- examples/temp.py | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 examples/temp.py diff --git a/examples/temp.py b/examples/temp.py deleted file mode 100644 index 152344bac6..0000000000 --- a/examples/temp.py +++ /dev/null @@ -1,9 +0,0 @@ -import dspy - -lm = dspy.LM("gpt-4o-mini") -dspy.settings.configure(lm=lm) - -predictor = dspy.Predict("query: str, image: Image -> response: str") -predictor(query="What is this dog?", image=dspy.Image.from_url("https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg")) - -dspy.inspect_history() \ No newline at end of file From b6d9c97e4a57c8d70c72c9c731702bba7bcb553a Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Sun, 3 Nov 2024 13:44:34 -0800 Subject: [PATCH 6/6] Minor adjustment to make adapters go back to original behavior --- dspy/adapters/chat_adapter.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 8d9d6d6f52..6727cd2458 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -201,10 +201,7 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text= for field, field_value in fields_with_values.items(): formatted_field_value = _format_field_value(field_info=field.info, value=field_value, assume_text=assume_text) if assume_text: - output.append(f"[[ ## {field.name} ## ]]") - # This conditional is specifically for the completed field, which is always the last field - if field_value: - output[-1] += f"\n{formatted_field_value}" + output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}") else: output.append({"type": "text", "text": f"[[ ## {field.name} ## ]]\n"}) if isinstance(formatted_field_value, dict) and formatted_field_value.get("type") == "image_url": @@ -212,7 +209,7 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text= else: output[-1]["text"] += formatted_field_value["text"] if assume_text: - return "\n\n".join(output) + return "\n\n".join(output).strip() else: return output