Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 32 additions & 17 deletions dspy/adapters/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@
import io
import os
from typing import Union
import requests
from urllib.parse import urlparse

import pydantic
import requests

try:
from PIL import Image as PILImage

PIL_AVAILABLE = True
except ImportError:
PIL_AVAILABLE = False


class Image(pydantic.BaseModel):
url: str

@pydantic.model_validator(mode="before")
@classmethod
def validate_input(cls, values):
Expand All @@ -34,25 +37,32 @@ def validate_input(cls, values):
@classmethod
def from_url(cls, url: str, download: bool = False):
    """Build an instance whose ``url`` field is ``encode_image(url, download)``.

    ``download`` is forwarded as the second positional argument of
    ``encode_image`` (its ``download_images`` parameter).
    """
    encoded_url = encode_image(url, download)
    return cls(url=encoded_url)

@classmethod
def from_file(cls, file_path: str):
    """Build an instance whose ``url`` field is ``encode_image(file_path)``."""
    encoded_url = encode_image(file_path)
    return cls(url=encoded_url)

@classmethod
def from_PIL(cls, pil_image):
    """Build an instance from a PIL image.

    Args:
        pil_image: A ``PIL.Image.Image`` instance. For backward
            compatibility, any other value (for example a file path or a
            file object) is first opened with ``PIL.Image.open``.

    Returns:
        A new instance whose ``url`` field is ``encode_image`` applied to
        the PIL image.
    """
    # Import the submodule explicitly: a bare ``import PIL`` does not
    # guarantee that the ``PIL.Image`` attribute is loaded.
    from PIL import Image as _PILImage

    # ``PIL.Image.open`` expects a path or file object, so an image that is
    # already loaded must be passed through unchanged instead of re-opened.
    if not isinstance(pil_image, _PILImage.Image):
        pil_image = _PILImage.open(pil_image)
    return cls(url=encode_image(pil_image))

def __repr__(self):
    """Return a compact representation that elides the base64 payload.

    When ``self.url`` contains ``base64,`` the payload is replaced by a
    placeholder carrying its length. Any other URL is shown in full.
    """
    parts = self.url.split("base64,")
    if len(parts) < 2:
        # No embedded payload (e.g. a plain http(s) URL); indexing
        # parts[1] here would raise IndexError.
        return f"Image(url = {self.url})"
    len_base64 = len(parts[1])
    return f"Image(url = {parts[0]}base64,<IMAGE_BASE_64_ENCODED({len_base64})>)"


def is_url(string: str) -> bool:
    """Check whether a value is an absolute http(s) URL.

    Args:
        string: Candidate value.

    Returns:
        True when ``string`` is a ``str`` that parses with an ``http`` or
        ``https`` scheme and a non-empty network location. False otherwise,
        including for non-string input and for strings that ``urlparse``
        rejects with ``ValueError``.
    """
    if not isinstance(string, str):
        # Non-string input (e.g. an int) would otherwise escape the
        # ValueError handler below as AttributeError.
        return False
    try:
        parsed = urlparse(string)
    except ValueError:
        return False
    return parsed.scheme in ("http", "https") and bool(parsed.netloc)

def encode_image(image: Union[str, bytes, 'PILImage.Image', dict], download_images: bool = False) -> str:

def encode_image(image: Union[str, bytes, "PILImage.Image", dict], download_images: bool = False) -> str:
"""
Encode an image to a base64 data URI.

Expand Down Expand Up @@ -100,39 +110,44 @@ def encode_image(image: Union[str, bytes, 'PILImage.Image', dict], download_imag
else:
raise ValueError(f"Unsupported image type: {type(image)}")


def _encode_image_from_file(file_path: str) -> str:
    """Read a local file in binary mode and return it as a base64 data URI.

    The subtype after ``image/`` is whatever ``_get_file_extension`` returns
    for ``file_path``.
    """
    with open(file_path, "rb") as handle:
        raw_bytes = handle.read()
    subtype = _get_file_extension(file_path)
    payload = base64.b64encode(raw_bytes).decode("utf-8")
    return "data:image/" + subtype + ";base64," + payload


def _encode_image_from_url(image_url: str) -> str:
    """Download an image and encode it as a base64 data URI.

    Args:
        image_url: URL fetched with ``requests.get``.

    Returns:
        ``data:image/<subtype>;base64,<payload>``. ``<subtype>`` comes from
        the response's ``Content-Type`` header when that is an ``image/*``
        media type; otherwise from the URL's file extension, falling back
        to ``png``.

    Raises:
        requests.HTTPError: Raised by ``raise_for_status`` for an error
            status code.
    """
    # NOTE(review): no timeout is passed, so a stalled server can block this
    # call indefinitely; consider threading a timeout through from callers.
    response = requests.get(image_url)
    response.raise_for_status()

    # Media types are case-insensitive and may carry parameters
    # (e.g. "image/jpeg; charset=binary"). Keep only "type/subtype" so the
    # parameters do not leak into the data URI.
    media_type = response.headers.get("Content-Type", "").split(";", 1)[0].strip().lower()
    file_extension = media_type[len("image/"):] if media_type.startswith("image/") else ""
    if not file_extension:
        # Fallback to file extension from URL or default to 'png'
        file_extension = _get_file_extension(image_url) or "png"

    encoded_image = base64.b64encode(response.content).decode("utf-8")
    return f"data:image/{file_extension};base64,{encoded_image}"

def _encode_pil_image(image: "PILImage.Image") -> str:
    """Encode a PIL Image object to a base64 data URI.

    Args:
        image: Object exposing ``format`` and ``save(fp, format=...)``,
            i.e. a ``PIL.Image.Image``.

    Returns:
        ``data:image/<format>;base64,<payload>`` where ``<format>`` is the
        lower-cased ``image.format``, or ``png`` when the format is unset.
    """
    buffered = io.BytesIO()
    # ``image.format`` may be falsy (None); default to PNG in that case.
    file_extension = (image.format or "PNG").lower()
    image.save(buffered, format=file_extension)
    encoded_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/{file_extension};base64,{encoded_image}"


def _get_file_extension(path_or_url: str) -> str:
    """Return the lower-cased extension of a path or URL, without the dot.

    Only the ``path`` component produced by ``urlparse`` is inspected.
    When that component has no extension, ``"png"`` is returned.
    """
    url_path = urlparse(path_or_url).path
    _, dotted_suffix = os.path.splitext(url_path)
    suffix = dotted_suffix.lstrip(".").lower()
    if not suffix:
        return "png"  # Default to 'png' if no extension found
    return suffix


def is_image(obj) -> bool:
"""Check if the object is an image or a valid image reference."""
Expand Down
15 changes: 15 additions & 0 deletions tests/primitives/test_example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import dspy
from dspy import Example


Expand Down Expand Up @@ -49,6 +50,20 @@ def test_example_len():
assert len(example) == 2


def test_example_repr_str_img():
    """repr() and str() of an Example holding an Image must both elide the base64 payload."""
    gif_data_uri = "data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7"
    example = Example(img=dspy.Image(url=gif_data_uri))
    expected = "Example({'img': Image(url = data:image/gif;base64,<IMAGE_BASE_64_ENCODED(56)>)}) (input_keys=None)"
    assert repr(example) == expected
    assert str(example) == expected


def test_example_repr_str():
example = Example(a=1)
assert repr(example) == "Example({'a': 1}) (input_keys=None)"
Expand Down
Loading