In [None]:
from image_functions import image_sampler, oklab_to_linear_srgb, linear_srgb_to_oklab
import skimage as ski
import matplotlib
from matplotlib import pyplot as plt
import matplotlib.image as mpimg
from jax import vmap
import jax.random as jr
from IPython.display import HTML

# Notebook-wide plotting defaults (plt.rcParams is the same object as matplotlib.rcParams).
# "animation.embed_limit" is measured in MB; an astronomically large value effectively
# removes the size cap on animations embedded with Animation.to_jshtml() further below.
# Hiding the right/top spines gives cleaner, open-framed axes.
matplotlib.rcParams.update(
    {
        "animation.embed_limit": 2**128,
        "axes.spines.right": False,
        "axes.spines.top": False,
    }
)



In [None]:
# Read and preprocess the target image. Works for PNG and JPEG inputs, and for
# grayscale, RGB or RGBA images, without having to comment lines in or out.
IMAGE_PATH = "bluecat.png"
# Downscaling is important: a smaller image needs fewer particles to simulate.
DOWNSCALE = 0.25

im = mpimg.imread(IMAGE_PATH)

# matplotlib returns PNGs as floats already in [0, 1], while other formats (e.g. JPEG)
# come back as uint8 in [0, 255]. img_as_float leaves float input untouched and
# rescales integer input to [0, 1], so both cases end up in the same range.
im = ski.util.img_as_float(im)

# Bring every image to 3-channel RGB: expand grayscale, and composite RGBA
# onto rgba2rgb's default (white) background. Plain RGB passes through unchanged.
if im.ndim == 2:
    im = ski.color.gray2rgb(im)
elif im.shape[-1] == 4:
    im = ski.color.rgba2rgb(im)

im = ski.transform.rescale(im, DOWNSCALE, channel_axis=-1, anti_aliasing=True, order=3)

In [None]:
# Configure the particle sampler on the preprocessed image, then run it with a fixed
# PRNG key (jr.key(4)) so the simulation is reproducible from run to run.
manager = image_sampler(
    im,
    num_particles=50_000,
    loss_space="oklab",
    # presumably quantizes the target to a 12-colour palette built in OKLab — confirm in image_functions
    posterization_params=dict(posterizer="oklab", n_colors=12),
    # kernel_size=1 presumably makes the reference-image smoothing a no-op (smoothing is not used here)
    smoother_params=dict(kernel_size=1, kernel_std=1.0),
    likelihood_params=dict(INF=100.0, scaled=True, box_bounds=(-6.0, 6.0)),
    # annealing schedule: lambd_range, annealing_steps, then extra_steps — semantics defined in image_functions
    sampler_params=dict(lambd_range=(-1, 2.0), annealing_steps=50, extra_steps=25),
)
out = manager.run(jr.key(4))

In [None]:
# View the posterized image. manager.palette is converted from OKLab back to RGB by
# mapping the per-colour conversion over the two leading (row, column) axes.
# NOTE(review): the converter's output is named *linear* sRGB, but imshow expects
# gamma-encoded values; if the input was never gamma-decoded before conversion this
# round-trips correctly, otherwise an sRGB transfer is needed here — confirm.
palette_rgb = vmap(vmap(oklab_to_linear_srgb))(manager.palette)

fig, ax = plt.subplots()
# Clip explicitly: imshow clips out-of-range RGB floats anyway, but warns when it does.
ax.imshow(palette_rgb.clip(0.0, 1.0))
ax.set_title("Posterized image")
ax.set_axis_off()

In [None]:
# View the (potentially smoothed) reference image held by the sampler. Smoothing is
# largely obsolete and not used in this run. manager.ref_img is converted from OKLab
# back to RGB by mapping the per-colour conversion over the two leading axes.
# NOTE(review): same linear-vs-encoded sRGB caveat as for the palette view — confirm.
ref_rgb = vmap(vmap(oklab_to_linear_srgb))(manager.ref_img)

fig, ax = plt.subplots()
# Clip explicitly: imshow clips out-of-range RGB floats anyway, but warns when it does.
ax.imshow(ref_rgb.clip(0.0, 1.0))
ax.set_title("Reference image")
ax.set_axis_off()

In [None]:
# Render the particle animation to 'cat_real.gif' (render='img', from frame 0) and
# show it inline as an HTML/JS player. The smoothing_params here are passed only to
# draw_gif (kernel_size=9, kernel_std=5), separately from the sampler's smoother_params.
ani = manager.draw_gif(
    'cat_real.gif',
    render='img',
    start_frame=0,
    smoothing_params=dict(kernel_size=9, kernel_std=5.0),
)
HTML(ani.to_jshtml())

In [None]:
# Show the loss map for each particle colour. The plotted quantity is a negative
# log-likelihood, so particles are drawn towards regions with low (dark) values.
manager.show_loss_map()