In [15]:
import numpy as np

from PIL import Image

from scipy.fft import dctn, idctn
from matplotlib import pyplot as plt
from matplotlib import pylab as pylab
import colours
import oklab
from dqt import *

# Render matplotlib figures inline in the notebook output
%matplotlib inline
# Default size (width, height in inches) for every figure in this notebook
pylab.rcParams['figure.figsize'] = (12, 12)

# Colour space object used by im_rgb_to_cspace / im_cspace_to_rgb
cspace = colours.OKLAB

In [16]:
def im_rgb_to_cspace(im_r, im_g, im_b):
    """Convert three same-shaped RGB channel planes into the ``cspace`` colour space.

    The planes are flattened into an (N, 3) pixel table, passed through
    ``colours.SRGB.to_xyz`` and ``cspace.from_xyz``, and each resulting
    channel is shifted by the matching ``cspace.*_min`` (presumably so that
    every channel starts at zero -- depends on how ``cspace`` defines them).

    Parameters:
        im_r, im_g, im_b: array-likes holding the red, green and blue planes.

    Returns:
        Tuple ``(l, m, s)`` of arrays with the shape of ``im_r``.
    """
    red, green, blue = (np.asarray(plane) for plane in (im_r, im_g, im_b))
    out_shape = red.shape

    # One row per pixel, one column per channel
    pixels = np.zeros([red.size, 3])
    for column, plane in enumerate((red, green, blue)):
        pixels[:, column] = plane.ravel()

    converted = cspace.from_xyz(colours.SRGB.to_xyz(pixels))

    offsets = (cspace.l_min, cspace.m_min, cspace.s_min)
    return tuple(
        (converted[:, column] - offset).reshape(out_shape)
        for column, offset in enumerate(offsets)
    )

def im_cspace_to_rgb(im_l, im_m, im_s):
    """Convert three same-shaped ``cspace`` channel planes back to RGB planes.

    Each plane is first shifted by the matching ``cspace.*_min`` (the opposite
    of the subtraction done in ``im_rgb_to_cspace``), then the pixels are
    passed through ``cspace.to_xyz`` and ``colours.SRGB.from_xyz``.

    Parameters:
        im_l, im_m, im_s: array-likes holding the three shifted channels.

    Returns:
        Tuple ``(r, g, b)`` of arrays with the shape of ``im_l``.
    """
    l = np.asarray(im_l)
    m = np.asarray(im_m)
    s = np.asarray(im_s)

    # Restore the offsets removed by the forward conversion
    l = l + cspace.l_min
    m = m + cspace.m_min
    s = s + cspace.s_min

    # One row per pixel, one column per channel. The table must be built from
    # the shifted arrays (not the raw arguments), otherwise the offsets above
    # would be discarded.
    csp = np.zeros([l.size, 3])

    csp[:, 0] = l.ravel()
    csp[:, 1] = m.ravel()
    csp[:, 2] = s.ravel()

    xyz = cspace.to_xyz(csp)
    rgb = colours.SRGB.from_xyz(xyz)

    r, g, b = rgb[:, 0], rgb[:, 1], rgb[:, 2]

    r = r.reshape(l.shape)
    g = g.reshape(l.shape)
    b = b.reshape(l.shape)

    return (r, g, b)

In [17]:
def im_rgb_to_oklab(im_r, im_g, im_b):
    """Convert RGB channel planes with ``oklab.lsrgb_to_oklab`` and shift the result.

    Each returned channel has the matching ``oklab.*_min`` subtracted
    (presumably so the channel starts at zero -- depends on how the ``oklab``
    module defines those constants).

    Parameters:
        im_r, im_g, im_b: array-likes holding the red, green and blue planes.

    Returns:
        Tuple ``(l, m, s)`` of the shifted channels.
    """
    red, green, blue = (np.asarray(plane) for plane in (im_r, im_g, im_b))

    chan_l, chan_m, chan_s = oklab.lsrgb_to_oklab(red, green, blue)

    return (
        chan_l - oklab.l_min,
        chan_m - oklab.m_min,
        chan_s - oklab.s_min,
    )

def im_oklab_to_rgb(im_l, im_m, im_s):
    """Undo the ``oklab.*_min`` shift and convert with ``oklab.oklab_to_lsrgb``.

    Parameters:
        im_l, im_m, im_s: array-likes holding the three shifted channels, as
            produced by ``im_rgb_to_oklab``.

    Returns:
        Tuple ``(r, g, b)`` exactly as returned by ``oklab.oklab_to_lsrgb``.
    """
    # Add back the per-channel offsets before handing over to the converter
    shifted = (
        np.asarray(im_l) + oklab.l_min,
        np.asarray(im_m) + oklab.m_min,
        np.asarray(im_s) + oklab.s_min,
    )

    red, green, blue = oklab.oklab_to_lsrgb(*shifted)

    return (red, green, blue)

def read_image(name, size):
    """Load ``../../images/{name}/{name}-{size}.png`` as RGB, trimmed to whole 8x8 blocks.

    Rows and columns beyond the largest multiple of 8 are discarded from the
    bottom and right edges.

    Parameters:
        name: image name; used both as the directory and the file-name prefix.
        size: size tag that forms the file-name suffix.

    Returns:
        A PIL ``Image`` built from the trimmed uint8 pixel array.
    """
    pixels = np.asarray(
        Image.open(f'../../images/{name}/{name}-{size}.png').convert('RGB')
    )

    # Largest extents that are exact multiples of the 8-pixel block size
    height, width = pixels.shape[:2]
    trimmed = pixels[:height - height % 8, :width - width % 8]

    # Copy into a fresh uint8 buffer before wrapping it as an image
    return Image.fromarray(np.array(trimmed, dtype=np.uint8))


# Load the test picture (already trimmed to whole 8x8 blocks) and keep a copy on disk
image_pil = read_image('earth', 1024)
image_pil.save('original.png')

# Pixel array view of the same picture: rows x columns x (R, G, B)
image = np.asarray(image_pil)

r, g, b = (image[:, :, channel] for channel in range(3))

# Shifted channels that are fed to the quantiser below
l, m, s = im_rgb_to_oklab(r, g, b)

In [18]:
# JPEG luminance quantisation table at quality 90, stored row-major as a flat
# 64-element vector (one divisor per coefficient of an 8x8 DCT block).
# Row 5, column 5 is 21: the quality-50 value 104 scaled by 0.2. It previously
# read 12 (transposed digits), which no uniform scaling of that row can yield.
# NOTE(review): the last entry of row 6 (21) and the seventh entry of row 7
# (20) each differ by one from 0.2 x the quality-50 values; left as found --
# confirm against the intended source table.
dqt_90_dct_lum = np.array([
     3,   2,   2,   3,   5,   8,  10,  12,
     2,   2,   3,   4,   5,  12,  12,  11,
     3,   3,   3,   5,   8,  11,  14,  11,
     3,   3,   4,   6,  10,  17,  16,  12,
     4,   4,   7,  11,  14,  22,  21,  15,
     5,   7,  11,  13,  16,  21,  23,  18,
    10,  13,  16,  17,  21,  24,  24,  21,
    14,  18,  19,  20,  22,  20,  20,  20,
], dtype=np.float64)

# Standard JPEG luminance quantisation table (quality 50). Written as eight
# rows of the 8x8 block and flattened row-major into a 64-element vector.
dqt_50_dct_lum = np.array([
    [16,  11,  10,  16,  24,  40,  51,  61],
    [12,  12,  14,  19,  26,  58,  60,  55],
    [14,  13,  16,  24,  40,  57,  69,  56],
    [14,  17,  22,  29,  51,  87,  80,  62],
    [18,  22,  37,  56,  68, 109, 103,  77],
    [24,  35,  55,  64,  81, 104, 113,  92],
    [49,  64,  78,  87, 103, 121, 120, 101],
    [72,  92,  95,  98, 112, 100, 103,  99],
], dtype=np.float64).ravel()

# Per-channel quantisation tables for the shifted OKLab channels: a JPEG
# luminance table multiplied by a strength factor and by oklab.*_range / 255
# (presumably rescaling the 8-bit table to the channel's value range -- confirm
# against the oklab module).
# NOTE(review): the m and s tables carry "90" in their names but are built from
# the quality-50 table with a factor of 4 -- presumably a deliberate choice to
# quantise those two channels more coarsely than l; confirm.
dqt_90_dct_oklab_l = 2 * dqt_90_dct_lum * oklab.l_range/255
dqt_90_dct_oklab_m = 4 * dqt_50_dct_lum * oklab.m_range/255
dqt_90_dct_oklab_s = 4 * dqt_50_dct_lum * oklab.s_range/255

In [19]:
def process(im, dqt):
    """Simulate JPEG-style quantisation loss on a single 2-D channel.

    The channel is trimmed to whole 8x8 blocks. Every block is transformed
    with an orthonormal 2-D DCT, each coefficient is rounded to the nearest
    multiple of its ``dqt`` entry, and the block is transformed back.

    Parameters:
        im: 2-D array-like channel; it is copied, not modified.
        dqt: 64 quantisation steps, matched against the row-major flattened
            8x8 coefficient block.

    Returns:
        float64 array with the trimmed shape holding the reconstructed channel.
    """
    pixels = np.array(im)
    rows = pixels.shape[0] - pixels.shape[0] % 8
    cols = pixels.shape[1] - pixels.shape[1] % 8
    pixels = pixels[:rows, :cols].astype(np.float64)

    restored = np.zeros(pixels.shape)

    # Blocks are independent, so each one is quantised and rebuilt in one pass
    for top in range(0, rows, 8):
        for left in range(0, cols, 8):
            tile = (slice(top, top + 8), slice(left, left + 8))

            coeffs = dctn(pixels[tile], axes=[0, 1], norm='ortho')

            # Snap every coefficient to the nearest multiple of its step size
            quantised = (np.rint(coeffs.ravel() / dqt) * dqt).reshape((8, 8))

            restored[tile] = idctn(quantised, axes=[0, 1], norm='ortho')

    return restored


In [20]:
# Quantise each shifted channel with its own table
im_l, im_m, im_s = (
    process(channel, table)
    for channel, table in zip(
        (l, m, s),
        (dqt_90_dct_oklab_l, dqt_90_dct_oklab_m, dqt_90_dct_oklab_s),
    )
)

In [21]:
# Decode with the inverse of the transform that encoded the image
# (im_rgb_to_oklab), so the colour-space round trip is self-consistent.
im_r, im_g, im_b = im_oklab_to_rgb(im_l, im_m, im_s)

# Round to the nearest integer and saturate into the 8-bit range
im_r = np.clip(np.rint(im_r), 0, 255).astype(np.uint8)
im_g = np.clip(np.rint(im_g), 0, 255).astype(np.uint8)
im_b = np.clip(np.rint(im_b), 0, 255).astype(np.uint8)


# Interleave the planes into an array with the shape and dtype of the source
img_rec = np.zeros_like(image)

img_rec[:, :, 0] = im_r
img_rec[:, :, 1] = im_g
img_rec[:, :, 2] = im_b

im_rec = Image.fromarray(img_rec)

im_rec.save('recovered.png')