From d3d391ec0a51adda40efb40c6e1d9e718813a242 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Lesp=C3=A9rance?= Date: Wed, 13 Nov 2024 16:45:53 -0500 Subject: [PATCH 1/3] fix: No longer crashing notebooks & spamming console w image urls --- dspy/adapters/image_utils.py | 5 +++++ dspy/primitives/example.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/dspy/adapters/image_utils.py b/dspy/adapters/image_utils.py index 68eeb9b882..ff4c1903fa 100644 --- a/dspy/adapters/image_utils.py +++ b/dspy/adapters/image_utils.py @@ -44,6 +44,11 @@ def from_PIL(cls, pil_image): import PIL return cls(url=encode_image(PIL.Image.open(pil_image))) + def __str__(self): + len_base64 = len(self.url.split("base64,")[1]) + return f"Image(url = {self.url.split('base64,')[0]}base64,)" + + def is_url(string: str) -> bool: """Check if a string is a valid URL.""" try: diff --git a/dspy/primitives/example.py b/dspy/primitives/example.py index 0d809482ab..9935652065 100644 --- a/dspy/primitives/example.py +++ b/dspy/primitives/example.py @@ -46,7 +46,7 @@ def __len__(self): def __repr__(self): # return f"Example({self._store})" + f" (input_keys={self._input_keys}, demos={self._demos})" - d = {k: v for k, v in self._store.items() if not k.startswith("dspy_")} + d = {k: str(v) for k, v in self._store.items() if not k.startswith("dspy_")} return f"Example({d})" + f" (input_keys={self._input_keys})" def __str__(self): From e2cc2ac421a2ea84696d1bd2abcac91cf60f4d5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Lesp=C3=A9rance?= Date: Wed, 13 Nov 2024 17:50:14 -0500 Subject: [PATCH 2/3] test: Added test for dspy.Image repr / str w Example --- tests/primitives/test_example.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/primitives/test_example.py b/tests/primitives/test_example.py index 8900f564a3..62a94c74be 100644 --- a/tests/primitives/test_example.py +++ b/tests/primitives/test_example.py @@ -1,4 +1,5 @@ import pytest +import dspy from dspy import Example @@ -49,6 +50,20 @@ def test_example_len(): assert len(example) == 2 +def test_example_repr_str_img(): + example = Example( + img=dspy.Image(url="data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7") + ) + assert ( + repr(example) + == "Example({'img': Image(url = data:image/gif;base64,)}) (input_keys=None)" + ) + assert ( + str(example) + == "Example({'img': Image(url = data:image/gif;base64,)}) (input_keys=None)" + ) + + def test_example_repr_str(): example = Example(a=1) assert repr(example) == "Example({'a': 1}) (input_keys=None)" From b3daeb1d39ec6f478c5167502ffe281fd89a85a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Lesp=C3=A9rance?= Date: Wed, 13 Nov 2024 17:51:03 -0500 Subject: [PATCH 3/3] fix: using repr instead of str to simplify handling images as when printing examples --- dspy/adapters/image_utils.py | 46 ++++++++++++++++++++++-------------- dspy/primitives/example.py | 2 +- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/dspy/adapters/image_utils.py b/dspy/adapters/image_utils.py index ff4c1903fa..5dee95b78e 100644 --- a/dspy/adapters/image_utils.py +++ b/dspy/adapters/image_utils.py @@ -2,19 +2,22 @@ import io import os from typing import Union -import requests from urllib.parse import urlparse + import pydantic +import requests try: from PIL import Image as PILImage + PIL_AVAILABLE = True except ImportError: PIL_AVAILABLE = False + class Image(pydantic.BaseModel): url: str - + @pydantic.model_validator(mode="before") @classmethod def validate_input(cls, values): @@ -34,17 +37,18 @@ def validate_input(cls, values): @classmethod def from_url(cls, url: str, download: bool = False): return cls(url=encode_image(url, download)) - + @classmethod def from_file(cls, file_path: str): return cls(url=encode_image(file_path)) - + @classmethod def from_PIL(cls, pil_image): import PIL + return cls(url=encode_image(PIL.Image.open(pil_image))) - def __str__(self): + def __repr__(self): len_base64 = len(self.url.split("base64,")[1]) return f"Image(url = {self.url.split('base64,')[0]}base64,)" @@ -53,11 +57,12 @@ def is_url(string: str) -> bool: """Check if a string is a valid URL.""" try: result = urlparse(string) - return all([result.scheme in ('http', 'https'), result.netloc]) + return all([result.scheme in ("http", "https"), result.netloc]) except ValueError: return False -def encode_image(image: Union[str, bytes, 'PILImage.Image', dict], download_images: bool = False) -> str: + +def encode_image(image: Union[str, bytes, "PILImage.Image", dict], download_images: bool = False) -> str: """ Encode an image to a base64 data URI. @@ -105,39 +110,44 @@ def encode_image(image: Union[str, bytes, 'PILImage.Image', dict], download_imag else: raise ValueError(f"Unsupported image type: {type(image)}") + def _encode_image_from_file(file_path: str) -> str: """Encode an image from a file path to a base64 data URI.""" with open(file_path, "rb") as image_file: image_data = image_file.read() file_extension = _get_file_extension(file_path) - encoded_image = base64.b64encode(image_data).decode('utf-8') + encoded_image = base64.b64encode(image_data).decode("utf-8") return f"data:image/{file_extension};base64,{encoded_image}" + def _encode_image_from_url(image_url: str) -> str: """Encode an image from a URL to a base64 data URI.""" response = requests.get(image_url) response.raise_for_status() - content_type = response.headers.get('Content-Type', '') - if content_type.startswith('image/'): - file_extension = content_type.split('/')[-1] + content_type = response.headers.get("Content-Type", "") + if content_type.startswith("image/"): + file_extension = content_type.split("/")[-1] else: # Fallback to file extension from URL or default to 'png' - file_extension = _get_file_extension(image_url) or 'png' - encoded_image = base64.b64encode(response.content).decode('utf-8') + file_extension = _get_file_extension(image_url) or "png" + encoded_image = base64.b64encode(response.content).decode("utf-8") return f"data:image/{file_extension};base64,{encoded_image}" -def _encode_pil_image(image: 'Image.Image') -> str: + +def _encode_pil_image(image: "Image.Image") -> str: """Encode a PIL Image object to a base64 data URI.""" buffered = io.BytesIO() - file_extension = (image.format or 'PNG').lower() + file_extension = (image.format or "PNG").lower() image.save(buffered, format=file_extension) - encoded_image = base64.b64encode(buffered.getvalue()).decode('utf-8') + encoded_image = base64.b64encode(buffered.getvalue()).decode("utf-8") return f"data:image/{file_extension};base64,{encoded_image}" + def _get_file_extension(path_or_url: str) -> str: """Extract the file extension from a file path or URL.""" - extension = os.path.splitext(urlparse(path_or_url).path)[1].lstrip('.').lower() - return extension or 'png' # Default to 'png' if no extension found + extension = os.path.splitext(urlparse(path_or_url).path)[1].lstrip(".").lower() + return extension or "png" # Default to 'png' if no extension found + def is_image(obj) -> bool: """Check if the object is an image or a valid image reference.""" diff --git a/dspy/primitives/example.py b/dspy/primitives/example.py index 9935652065..0d809482ab 100644 --- a/dspy/primitives/example.py +++ b/dspy/primitives/example.py @@ -46,7 +46,7 @@ def __len__(self): def __repr__(self): # return f"Example({self._store})" + f" (input_keys={self._input_keys}, demos={self._demos})" - d = {k: str(v) for k, v in self._store.items() if not k.startswith("dspy_")} + d = {k: v for k, v in self._store.items() if not k.startswith("dspy_")} return f"Example({d})" + f" (input_keys={self._input_keys})" def __str__(self):