Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ rather than message. Each subclass also derives from `ValueError`, so existing

### Fixed

- **Degenerate inputs no longer crash extraction**: a 1x1 image, a `palette_size`
larger than the number of distinct colors, or any case with fewer pixels than
the requested palette size previously raised raw sklearn/numpy `ValueError`s.
These inputs now yield a (possibly smaller) palette instead.
- **Reshape bug with alpha masking**: Extractors reshaped the pixel array to
`(height * width, n_channels)`, but after alpha masking the valid-pixel count
can be smaller than `height * width`. This caused the reshape to either raise
Expand Down
23 changes: 23 additions & 0 deletions pylette/src/color_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,29 @@ def extract_colors(
Returns:
Palette: A palette of the extracted colors.

Guarantees:
The returned palette satisfies these invariants for every extraction
method (pinned by the property suite in ``tests/integration/test_invariants.py``):

- ``len(palette) <= palette_size``. Fewer colors are returned when the
image has fewer distinct colors than requested.
- The color frequencies sum to ``1.0``.
- Every channel of every ``Color.rgb`` is an ``int`` in ``[0, 255]``.
- Extraction is deterministic: the same image and arguments always
produce the same palette.
- Colors are ordered by ``sort_mode`` — ascending ``luminance`` or, by
default, descending ``frequency`` — and that ordering is stable.
- Degenerate inputs (a solid color, a 1x1 image, ``palette_size``
greater than the number of distinct colors, a partial alpha mask) are
handled without error. The one expected failure is an image with no
pixels left to sample (e.g. a fully alpha-masked image), which raises
:class:`~pylette.NoValidPixelsError`.

Raises:
InvalidImageError: If the image cannot be loaded or its type is unsupported.
NoValidPixelsError: If no pixels remain after alpha masking.
UnknownExtractionMethodError: If ``mode`` is not a known extraction method.

Examples:
Colors can be extracted from a variety of sources, including local files, byte streams, URLs, and numpy arrays.

Expand Down
7 changes: 5 additions & 2 deletions pylette/src/extractors/k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@ def extract(self, arr: NDArray[NP_T], palette_size: int) -> list[Color]:

from sklearn.cluster import KMeans

arr = np.squeeze(arr)
model = KMeans(n_clusters=palette_size, n_init="auto", init="k-means++", random_state=2024)
arr = self._reshape_array(arr)
# Never request more clusters than there are pixels (degenerate inputs
# like a 1x1 image); KMeans requires n_clusters <= n_samples.
n_colors = min(palette_size, arr.shape[0])
model = KMeans(n_clusters=n_colors, n_init="auto", init="k-means++", random_state=2024)
labels = model.fit_predict(arr)
palette = np.array(model.cluster_centers_, dtype=int)
color_count = np.bincount(labels)
Expand Down
8 changes: 7 additions & 1 deletion pylette/src/extractors/median_cut.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ def extract(self, arr: NDArray[NP_T], palette_size: int) -> list[Color]:
valid_pixel_count = arr.shape[0]
boxes = [ColorBox(arr)]
while len(boxes) < palette_size:
largest_box_idx = np.argmax([box.size for box in boxes])
# Only boxes with at least 2 pixels can be split; a 1-pixel box would
# produce an empty box. Stop once nothing is splittable (e.g. there
# are fewer distinct pixels than the requested palette size).
splittable = [i for i, box in enumerate(boxes) if box.pixel_count >= 2]
if not splittable:
break
largest_box_idx = splittable[int(np.argmax([boxes[i].size for i in splittable]))]
boxes = boxes[:largest_box_idx] + boxes[largest_box_idx].split() + boxes[largest_box_idx + 1 :]
return [Color(tuple(map(int, box.average)), box.pixel_count / valid_pixel_count) for box in boxes]
9 changes: 6 additions & 3 deletions pylette/src/extractors/oklab.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,22 @@ def extract(self, arr: NDArray[NP_T], palette_size: int) -> list[Color]:
# sRGB8 -> OKLab
lab = linear_srgb_to_oklab(srgb_to_linear(rgb8 / 255.0))

model = KMeans(n_clusters=palette_size, n_init="auto", init="k-means++", random_state=2024)
# Never request more clusters than there are pixels (degenerate inputs
# like a 1x1 image); KMeans requires n_clusters <= n_samples.
n_clusters = min(palette_size, len(pixels))
model = KMeans(n_clusters=n_clusters, n_init="auto", init="k-means++", random_state=2024)
labels = model.fit_predict(lab)
centers_lab = np.asarray(model.cluster_centers_)

# OKLab centroids -> float sRGB in [0, 1], kept pre-quantization so the
# Color stores full precision; out-of-gamut values are clamped.
centers_srgb = np.clip(linear_to_srgb(oklab_to_linear_srgb(centers_lab)), 0.0, 1.0)

counts = np.bincount(labels, minlength=palette_size)
counts = np.bincount(labels, minlength=n_clusters)
total = float(counts.sum())

colors: list[Color] = []
for i in range(palette_size):
for i in range(n_clusters):
if counts[i] == 0:
continue
mean_alpha = float(alpha[labels == i].mean()) / 255.0
Expand Down
4 changes: 4 additions & 0 deletions pylette/src/palette.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def __init__(self, colors: list[Color], metadata: PaletteMetaData | None = None)

Parameters:
colors (list[Color]): A list of Color objects.

Note:
For a palette produced by :func:`~pylette.extract_colors`,
``frequencies`` are the per-color relative weights and sum to ``1.0``.
"""

self.colors = colors
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ dev = [
"requests-mock>=1.12.1",
"ruff>=0.5.0",
"pyright>=1.1.0",
"hypothesis>=6.0",
]

[project.scripts]
Expand Down
132 changes: 132 additions & 0 deletions tests/integration/test_invariants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
These invariant tests pin the "Guarantees" section documented on ``extract_colors``.

The tests are parametrized over ``available_methods()`` so a newly registered
extractor is covered automatically. A Hypothesis property test fuzzes arbitrary
small images to catch degenerate inputs.
"""

import numpy as np
import pytest
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays
from PIL import Image

from pylette import NoValidPixelsError, Palette, extract_colors
from pylette.src.extractors import available_methods

# Every registered extraction method, resolved once at import time so the
# parametrized tests below cover each of them.
METHODS = available_methods()

# Degenerate images make KMeans/OKLab emit sklearn ConvergenceWarnings; that is
# expected here and not what these tests are about.
# NOTE(review): the filter below silences every UserWarning raised in this
# module, not only convergence warnings — presumably ConvergenceWarning is a
# UserWarning subclass; confirm before relying on other warnings in these tests.
pytestmark = pytest.mark.filterwarnings("ignore::UserWarning")


def _assert_palette_invariants(palette: Palette, palette_size: int) -> None:
    """Assert the size, frequency-sum and channel-range invariants of ``palette``.

    An empty palette only has to respect the size bound; the remaining
    checks are skipped for it.
    """
    n_colors = len(palette)
    # Never more colors than were requested.
    assert n_colors <= palette_size
    if n_colors == 0:
        return
    # Frequencies are relative weights, so together they make up the whole.
    total_frequency = sum(palette.frequencies)
    assert total_frequency == pytest.approx(1.0)
    # Each channel must be a plain Python int inside the 8-bit range.
    for entry in palette.colors:
        for value in entry.rgb:
            assert isinstance(value, int) and 0 <= value <= 255


@pytest.mark.parametrize("mode", METHODS)
@pytest.mark.parametrize("palette_size", [1, 3, 5])
@pytest.mark.parametrize("resize", [True, False])
def test_solid_image_is_handled(mode: str, palette_size: int, resize: bool) -> None:
    """A single-color 8x8 image yields a non-empty palette that keeps the invariants."""
    solid = Image.new("RGB", (8, 8), (12, 200, 75))
    result = extract_colors(solid, palette_size=palette_size, mode=mode, resize=resize)
    _assert_palette_invariants(result, palette_size)
    assert len(result) >= 1


@pytest.mark.parametrize("mode", METHODS)
def test_one_by_one_image_is_handled(mode: str) -> None:
    """A 1x1 image with a larger requested palette still produces at least one color."""
    requested = 5
    single_pixel = np.array([[[10, 20, 30]]], dtype=np.uint8)
    image = Image.fromarray(single_pixel, "RGB")
    result = extract_colors(image, palette_size=requested, mode=mode, resize=False)
    _assert_palette_invariants(result, requested)
    assert len(result) >= 1


@pytest.mark.parametrize("mode", METHODS)
def test_palette_size_exceeds_distinct_colors(mode: str) -> None:
    """Requesting 10 colors from a 2x2 image with 4 distinct pixels keeps the invariants."""
    requested = 10
    top_row = [[0, 0, 0], [255, 255, 255]]
    bottom_row = [[255, 0, 0], [0, 0, 255]]
    four_colors = np.array([top_row, bottom_row], dtype=np.uint8)
    image = Image.fromarray(four_colors, "RGB")
    result = extract_colors(image, palette_size=requested, mode=mode, resize=False)
    _assert_palette_invariants(result, requested)


@pytest.mark.parametrize("mode", METHODS)
def test_partial_alpha_mask_is_handled(mode: str) -> None:
    """An RGBA image whose alternate rows are fully transparent keeps the invariants."""
    requested = 5
    random_rgb = np.random.default_rng(0).integers(0, 256, (16, 16, 3))
    rgba = np.zeros((16, 16, 4), dtype=np.uint8)
    rgba[..., :3] = random_rgb
    # Even rows become opaque; odd rows keep alpha 0 (half opaque, half transparent).
    rgba[::2, :, 3] = 255
    image = Image.fromarray(rgba, "RGBA")
    result = extract_colors(image, palette_size=requested, mode=mode, resize=False, alpha_mask_threshold=0)
    _assert_palette_invariants(result, requested)


@pytest.mark.parametrize("mode", METHODS)
def test_total_alpha_mask_raises_typed_error(mode: str) -> None:
    """An image whose alpha channel is 0 everywhere raises ``NoValidPixelsError``."""
    fully_transparent = np.zeros((16, 16, 4), dtype=np.uint8)
    image = Image.fromarray(fully_transparent, "RGBA")
    with pytest.raises(NoValidPixelsError):
        extract_colors(image, palette_size=5, mode=mode, resize=False, alpha_mask_threshold=0)


@pytest.mark.parametrize("mode", METHODS)
def test_deterministic_under_fixed_random_state(mode: str) -> None:
    """Two extractions of the same image with the same arguments give the same palette."""
    noise = np.random.default_rng(7).integers(0, 256, (20, 20, 3), dtype=np.uint8)
    image = Image.fromarray(noise, "RGB")
    first = extract_colors(image, palette_size=5, mode=mode)
    second = extract_colors(image, palette_size=5, mode=mode)
    first_rgb = [entry.rgb for entry in first.colors]
    second_rgb = [entry.rgb for entry in second.colors]
    assert first_rgb == second_rgb
    first_freq = [entry.frequency for entry in first.colors]
    second_freq = [entry.frequency for entry in second.colors]
    assert first_freq == second_freq


@pytest.mark.parametrize("mode", METHODS)
@pytest.mark.parametrize(
    "sort_mode, key, reverse",
    [
        ("luminance", lambda c: c.luminance, False),
        ("frequency", lambda c: c.frequency, True),
    ],
)
def test_sort_order_is_stable_and_idempotent(mode, sort_mode, key, reverse) -> None:  # type: ignore[no-untyped-def]
    """Sorting the returned colors again by the requested key leaves their order unchanged."""
    noise = np.random.default_rng(3).integers(0, 256, (24, 24, 3), dtype=np.uint8)
    image = Image.fromarray(noise, "RGB")
    result = extract_colors(image, palette_size=6, mode=mode, sort_mode=sort_mode)
    returned = result.colors
    # If the palette already respects the sort order, a second sort changes nothing.
    sorted_again = sorted(returned, key=key, reverse=reverse)
    rgb_after = [entry.rgb for entry in sorted_again]
    rgb_before = [entry.rgb for entry in returned]
    assert rgb_after == rgb_before


# Hypothesis strategy for small uint8 image arrays: the first two axes each
# range over 1..12 and the last axis is 3 or 4 (interpreted below as RGB vs RGBA).
_image_arrays = arrays(
    dtype=np.uint8,
    shape=st.tuples(st.integers(1, 12), st.integers(1, 12), st.sampled_from([3, 4])),
)


@settings(max_examples=40, deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(
    arr=_image_arrays,
    palette_size=st.integers(1, 8),
    mode=st.sampled_from(METHODS),
    sort_mode=st.sampled_from([None, "luminance", "frequency"]),
    resize=st.booleans(),
)
def test_property_invariants_hold_for_arbitrary_images(arr, palette_size, mode, sort_mode, resize) -> None:  # type: ignore[no-untyped-def]
    """Fuzz small RGB/RGBA arrays and check the palette invariants on every result."""
    has_alpha = arr.shape[-1] != 3
    pil_mode = "RGBA" if has_alpha else "RGB"
    image = Image.fromarray(arr, pil_mode)
    try:
        result = extract_colors(image, palette_size=palette_size, mode=mode, sort_mode=sort_mode, resize=resize)
    except NoValidPixelsError:
        # A fully alpha-masked image is an expected, typed failure.
        pass
    else:
        _assert_palette_invariants(result, palette_size)
24 changes: 24 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading