Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`register`, `get_extractor`, and `available_methods` from `Pylette.src.extractors`.
Extractors are now registered via the `@register(...)` decorator, making it
possible to plug in custom extraction methods.
- **OKLab extractor**: New `ExtractionMethod.OKLAB` extraction mode that runs
k-means in the perceptual [OKLab](https://bottosson.github.io/posts/oklab/)
color space, so clusters are grouped by perceived color difference. Pixels are
linearized before conversion.

### Changed

Expand Down
1 change: 1 addition & 0 deletions Pylette/src/extractors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Import for registration side-effect.
# Each extractor module is imported only so that it gets executed: oklab.py, for
# example, applies ``@register(...)`` to its extractor class at import time
# (k_means and median_cut presumably do the same -- confirm in those modules).
# The underscore aliases keep the modules out of the public namespace, and the
# trailing ``noqa: F401`` markers silence the "imported but unused" lint.
from Pylette.src.extractors import k_means as _k_means  # type: ignore # noqa: F401
from Pylette.src.extractors import median_cut as _median_cut  # type: ignore # noqa: F401
from Pylette.src.extractors import oklab as _oklab  # type: ignore # noqa: F401
from Pylette.src.extractors.registry import available_methods, get_extractor, register

# Names re-exported as the public API of the extractors package.
__all__ = ["available_methods", "get_extractor", "register"]
133 changes: 133 additions & 0 deletions Pylette/src/extractors/oklab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""OKLab-based color extraction.

OKLab (Björn Ottosson, 2020 -- https://bottosson.github.io/posts/oklab/) is a
perceptual color space built so that Euclidean distance approximates perceived
color difference. Running k-means in OKLab therefore groups colors the way the
eye does.

Pipeline::

sRGB8 -> [0,1] -> linear sRGB -> OKLab -> k-means -> centroids (OKLab)
centroids -> linear sRGB -> sRGB8

The sRGB <-> linear step (the IEC 61966-2-1 transfer function) is not optional:
OKLab is defined on *linear* light, so skipping linearization would distort the
k-means clustering.

Two deliberate differences from the plain-RGB ``KMeans`` extractor:

* Alpha is not folded into the distance metric. Clustering happens on the three
OKLab channels; a representative alpha is recovered per cluster by averaging
the alpha of its member pixels.
* Empty clusters are dropped rather than emitted as spurious swatches.
"""

import numpy as np
from numpy.typing import NDArray
from typing_extensions import override

from Pylette.src.color import Color
from Pylette.src.extractors.protocol import NP_T, ColorExtractorBase
from Pylette.src.extractors.registry import register
from Pylette.src.types import ExtractionMethod, FloatArray


# sRGB <-> linear sRGB (IEC 61966-2-1 transfer function)
def srgb_to_linear(srgb: FloatArray) -> FloatArray:
    """Decode gamma-encoded sRGB values in ``[0, 1]`` to linear-light sRGB.

    Applies the piecewise sRGB transfer function element-wise: a linear
    segment for values at or below ``0.04045`` and a 2.4-exponent power curve
    above it.
    """
    # Both branches are evaluated for every element; np.where then selects the
    # one that applies.
    linear_segment = srgb / 12.92
    power_segment = ((srgb + 0.055) / 1.055) ** 2.4
    in_linear_segment = srgb <= 0.04045
    return np.where(in_linear_segment, linear_segment, power_segment)


def linear_to_srgb(linear: FloatArray) -> FloatArray:
    """Encode linear-light sRGB values as gamma-encoded sRGB in ``[0, 1]``.

    The input is clamped to ``[0, 1]`` first, so out-of-gamut values are
    encoded as the nearest in-range value.
    """
    clamped = np.clip(linear, 0.0, 1.0)
    # Inverse of the sRGB transfer function: a linear segment near black and a
    # 1/2.4-exponent power curve elsewhere.
    linear_segment = 12.92 * clamped
    power_segment = 1.055 * np.power(clamped, 1.0 / 2.4) - 0.055
    in_linear_segment = clamped <= 0.0031308
    return np.where(in_linear_segment, linear_segment, power_segment)


# linear sRGB <-> OKLab (Ottosson 2020)
#
# The forward matrices are the ones published on Ottosson's OKLab page. The inverse
# matrices are derived from them numerically, so the round-trip is self-consistent
# (single source of truth) rather than relying on two independently transcribed
# constant sets.

# linear sRGB -> LMS
# Each row produces one output channel; the conversion helpers apply the matrix
# to row-vector pixel arrays as ``pixels @ MATRIX.T``.
_LRGB_TO_LMS = np.array(
    [
        [0.4122214708, 0.5363325363, 0.0514459929],
        [0.2119034982, 0.6806995451, 0.1073969566],
        [0.0883024619, 0.2817188376, 0.6299787005],
    ]
)
# nonlinear l'm's' -> OKLab
# (l', m', s') are the cube roots of the LMS values; the output is (L, a, b).
_LMS_TO_OKLAB = np.array(
    [
        [0.2104542553, 0.7936177850, -0.0040720468],
        [1.9779984951, -2.4285922050, 0.4505937099],
        [0.0259040371, 0.7827717662, -0.8086757660],
    ]
)
# Inverses are computed once at import time from the forward matrices instead
# of being transcribed separately.
_OKLAB_TO_LMS = np.linalg.inv(_LMS_TO_OKLAB)
_LMS_TO_LRGB = np.linalg.inv(_LRGB_TO_LMS)


def linear_srgb_to_oklab(rgb: FloatArray) -> FloatArray:
    """Convert an ``(N, 3)`` array of linear sRGB to OKLab.

    Each row is mapped to LMS, cube-rooted channel-wise, and then mapped to
    OKLab ``(L, a, b)``.
    """
    cone_response = rgb @ _LRGB_TO_LMS.T
    # The cube root is the non-linearity between the two matrix steps.
    return np.cbrt(cone_response) @ _LMS_TO_OKLAB.T


def oklab_to_linear_srgb(lab: FloatArray) -> FloatArray:
    """Convert an ``(N, 3)`` array of OKLab back to linear sRGB.

    Runs the forward conversion in reverse: OKLab to cube-rooted LMS, cube to
    undo the root, then LMS to linear sRGB. The result is not clamped.
    """
    compressed_lms = lab @ _OKLAB_TO_LMS.T
    cone_response = compressed_lms**3
    return cone_response @ _LMS_TO_LRGB.T


@register(ExtractionMethod.OKLAB)
class OKLabKMeansExtractor(ColorExtractorBase):
    """Palette extractor that runs k-means on OKLab coordinates."""

    @override
    def extract(self, arr: NDArray[NP_T], height: int, width: int, palette_size: int) -> list[Color]:
        """Cluster the image's pixels in OKLab space and return one color per cluster.

        Parameters:
            arr: Pixel array of shape ``(..., C)`` with ``C >= 3``; RGB(A), uint8.
            height: Image height (unused; retained for protocol compatibility).
            width: Image width (unused; retained for protocol compatibility).
            palette_size: Number of clusters / colors to extract.

        Returns:
            list[Color]: One color per non-empty cluster, in cluster-index
                order, with frequencies that sum to 1.
        """
        from sklearn.cluster import KMeans

        # Flatten to one row per pixel, keeping the channel axis.
        flat = np.asarray(arr).reshape(-1, arr.shape[-1])
        srgb_unit = flat[:, :3].astype(np.float64) / 255.0
        if flat.shape[1] >= 4:
            alpha = flat[:, 3].astype(np.float64)
        else:
            # No alpha channel: treat every pixel as fully opaque.
            alpha = np.full(len(flat), 255.0)

        # sRGB8 -> OKLab; alpha is deliberately left out of the clustering.
        lab = linear_srgb_to_oklab(srgb_to_linear(srgb_unit))

        kmeans = KMeans(n_clusters=palette_size, n_init="auto", init="k-means++", random_state=2024)
        labels = kmeans.fit_predict(lab)
        centroids_lab = np.asarray(kmeans.cluster_centers_)

        # OKLab centroids -> sRGB8
        centroids_srgb = linear_to_srgb(oklab_to_linear_srgb(centroids_lab))
        centroids_rgb8 = np.clip(np.round(centroids_srgb * 255.0), 0, 255).astype(int)

        sizes = np.bincount(labels, minlength=palette_size)
        pixel_total = float(sizes.sum())

        palette: list[Color] = []
        for cluster, (size, centroid) in enumerate(zip(sizes, centroids_rgb8)):
            if size == 0:
                # Empty clusters are skipped rather than emitted as swatches.
                continue
            # Representative alpha: mean alpha of the pixels in this cluster.
            member_alpha = alpha[labels == cluster]
            cluster_alpha = int(round(float(member_alpha.mean())))
            red, green, blue = map(int, centroid)
            palette.append(Color((red, green, blue, cluster_alpha), size / pixel_total))
        return palette
1 change: 1 addition & 0 deletions Pylette/src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __array__(self) -> NDArray[np.uint8]: ...
class ExtractionMethod(str, Enum):
    """Identifiers for the available palette-extraction algorithms.

    The ``str`` mixin makes each member compare equal to its string value.
    """

    MC = "MedianCut"  # presumably the median-cut extractor -- confirm in median_cut.py
    KM = "KMeans"  # presumably the plain-RGB k-means extractor -- confirm in k_means.py
    OKLAB = "OKLab"  # k-means in OKLab space (registered by OKLabKMeansExtractor)


class ColorSpace(str, Enum):
Expand Down
73 changes: 73 additions & 0 deletions tests/integration/test_oklab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Tests for the OKLab extractor.
These cover the OKLab color math (it must match Ottosson's reference vectors and
round-trip to near machine precision) and basic properties of extracted palettes.
"""

import numpy as np
import pytest
from PIL import Image

from Pylette import extract_colors
from Pylette.src.extractors.oklab import (
linear_srgb_to_oklab,
linear_to_srgb,
oklab_to_linear_srgb,
srgb_to_linear,
)
from Pylette.types import ExtractionMethod


class TestOKLabTransforms:
    """Checks on the color-space conversion helpers, independent of clustering."""

    # Reference vectors from Ottosson (linear sRGB -> OKLab).
    # Each entry is (linear sRGB input, expected OKLab output).
    REFERENCES = [
        ((1.0, 1.0, 1.0), (1.0, 0.0, 0.0)),
        ((1.0, 0.0, 0.0), (0.627955, 0.224863, 0.125846)),
        ((0.0, 1.0, 0.0), (0.866440, -0.233888, 0.179498)),
        ((0.0, 0.0, 1.0), (0.452014, -0.032457, -0.311528)),
    ]

    @pytest.mark.parametrize("linear_rgb,expected_oklab", REFERENCES)
    def test_forward_matches_reference(self, linear_rgb, expected_oklab):
        """The forward transform reproduces the published reference values."""
        single_pixel = np.array([linear_rgb], dtype=float)
        (actual,) = linear_srgb_to_oklab(single_pixel)
        np.testing.assert_allclose(actual, expected_oklab, atol=1e-5)

    def test_oklab_roundtrip_is_near_exact(self):
        """linear sRGB -> OKLab -> linear sRGB returns the input to within 1e-10."""
        original = np.random.default_rng(0).random((5000, 3))
        recovered = oklab_to_linear_srgb(linear_srgb_to_oklab(original))
        np.testing.assert_allclose(original, recovered, atol=1e-10)

    def test_full_srgb8_pipeline_roundtrip(self):
        """The whole sRGB8 pipeline round-trips far below one 8-bit code value."""
        codes = np.random.default_rng(1).integers(0, 256, size=(5000, 3))
        encoded = codes.astype(float) / 255.0
        lab = linear_srgb_to_oklab(srgb_to_linear(encoded))
        decoded = linear_to_srgb(oklab_to_linear_srgb(lab))
        # Worst-case error, expressed in units of 8-bit code values.
        worst_error = np.max(np.abs(encoded - decoded)) * 255
        assert worst_error < 1e-6


class TestOKLabExtraction:
    """End-to-end checks of ``extract_colors`` with ``ExtractionMethod.OKLAB``."""

    @pytest.fixture
    def gradient_image(self) -> Image.Image:
        """Return a 64x64 RGB image with a different ramp in each channel.

        Red varies along the columns, green along the rows, and blue along the
        diagonal, so the image contains many distinct colors to cluster.
        """
        arr = np.zeros((64, 64, 3), dtype=np.uint8)
        xs = np.arange(64)
        arr[..., 0] = (xs[None, :] * 4) % 256
        arr[..., 1] = (xs[:, None] * 4) % 256
        arr[..., 2] = ((xs[None, :] + xs[:, None]) * 2) % 256
        # Let Pillow infer "RGB" from the (H, W, 3) uint8 array: passing ``mode``
        # explicitly was deprecated in Pillow 11.3 and emits a DeprecationWarning.
        return Image.fromarray(arr)

    def test_respects_palette_size(self, gradient_image):
        """The palette is non-empty and never larger than requested."""
        palette = extract_colors(gradient_image, palette_size=6, mode=ExtractionMethod.OKLAB)
        # The upper bound is not an equality because empty clusters are dropped;
        # the lower bound stops an empty palette from passing vacuously.
        assert 1 <= len(palette) <= 6

    def test_frequencies_sum_to_one(self, gradient_image):
        """The reported frequencies sum to 1."""
        palette = extract_colors(gradient_image, palette_size=5, mode=ExtractionMethod.OKLAB)
        assert sum(palette.frequencies) == pytest.approx(1.0)

    def test_colors_are_valid_rgb(self, gradient_image):
        """Every channel of every extracted color lies in ``[0, 255]``."""
        palette = extract_colors(gradient_image, palette_size=5, mode=ExtractionMethod.OKLAB)
        for color in palette.colors:
            assert all(0 <= channel <= 255 for channel in color.rgb)

    def test_deterministic(self, gradient_image):
        """Two runs on the same image yield the same colors in the same order."""
        a = extract_colors(gradient_image, palette_size=5, mode=ExtractionMethod.OKLAB)
        b = extract_colors(gradient_image, palette_size=5, mode=ExtractionMethod.OKLAB)
        assert [c.rgb for c in a.colors] == [c.rgb for c in b.colors]