# Histogram equalization for RGB images

- In this assignment, you will apply histogram equalization to RGB color images.
- You will try two methods: equalizing each channel independently, and applying a single equalization LUT, computed from the grayscale image, to all channels.
- Also, there is an optional task to implement a function for creating an arbitrary intensity transformation look-up table (LUT).

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import skimage

import sys; sys.path.append('..')
from tests import test_histogram_equalization

In [None]:
rgb = skimage.io.imread('../data/fruits.jpg')
gray = skimage.util.img_as_ubyte(skimage.color.rgb2gray(rgb))

In [None]:
fig, axes = plt.subplots(1, 2, figsize=plt.figaspect(0.5), layout='constrained')
axes[0].imshow(rgb)
axes[1].imshow(gray, cmap='gray', vmin=0, vmax=255);

# Task 0: prepare functions

Prepare the following two functions:
1. `create_histogram_equalization_lut`
2. `equalize_histogram`

Hints:
- You can copy and paste `create_histogram_equalization_lut` from the [intensity_transformations](../lectures/intensity_transformations.ipynb) lecture notebook.
- Note how the `cdf` function is used there.
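For orientation, histogram equalization maps each intensity through the normalized cumulative histogram (CDF), scaled to `[0, 255]`. The sketch below assumes an 8-bit image with a 256-bin histogram `h`, and uses `np.cumsum` in place of the lecture's `cdf` helper, whose exact definition (rounding, normalization) may differ. The name `create_histogram_equalization_lut_sketch` is hypothetical.

```python
import numpy as np

def create_histogram_equalization_lut_sketch(h: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: map intensity i to round(255 * CDF(i))."""
    h = np.asarray(h, dtype=np.float64)
    cdf = np.cumsum(h) / h.sum()  # normalized cumulative histogram, ends at 1.0
    return np.round(255 * cdf).astype(np.uint8)

# 25% of pixels at intensity 0, 75% at intensity 255
h = np.zeros(256)
h[0] = 25
h[255] = 75
lut = create_histogram_equalization_lut_sketch(h)
print(lut[0], lut[254], lut[255])  # 64 64 255
```

Because the CDF is non-decreasing, the resulting LUT is monotonic, so the brightness ordering of pixels is preserved.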

In [None]:
def create_histogram_equalization_lut(h: np.ndarray) -> np.ndarray:
    ########################################
    # TODO: implement

    raise NotImplementedError

    ########################################

    return lut

In [None]:
test_histogram_equalization.TestCreateHistogramEqualizationLut.eval(create_histogram_equalization_lut_fn=create_histogram_equalization_lut)

Implement a function called `equalize_histogram` that takes a grayscale image (a 2D `uint8` matrix) and returns its equalized version.
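One possible structure, sketched under the assumption that the input is a 2D `uint8` array: compute the 256-bin histogram, build the CDF-based LUT, and apply it to every pixel with NumPy fancy indexing. The function name `equalize_histogram_sketch` is hypothetical, and the LUT construction here may differ in detail from the lecture's version.

```python
import numpy as np

def equalize_histogram_sketch(gray: np.ndarray) -> np.ndarray:
    """Hypothetical sketch: equalize a 2D uint8 image via a CDF-based LUT."""
    # 256-bin histogram of the 8-bit intensities
    h = np.bincount(gray.ravel(), minlength=256)
    # Normalized cumulative histogram scaled to [0, 255]
    cdf = np.cumsum(h) / h.sum()
    lut = np.round(255 * cdf).astype(np.uint8)
    # Indexing the LUT with the image applies it to every pixel at once
    return lut[gray]

rng = np.random.default_rng(0)
dark = rng.integers(0, 64, size=(100, 100), dtype=np.uint8)  # low-contrast image
equalized = equalize_histogram_sketch(dark)
print(dark.max(), equalized.max())
```

The dark, low-contrast input is stretched so that its brightest values reach 255.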

In [None]:
def equalize_histogram(gray: np.ndarray) -> np.ndarray:
    ########################################
    # TODO: implement

    raise NotImplementedError

    ########################################

    return equalized

In [None]:
test_histogram_equalization.TestEqualizeHistogram.eval(equalize_histogram_fn=equalize_histogram)

The following should show the original grayscale image `gray` and its equalized version.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=plt.figaspect(0.5), layout='constrained')
axes[0].imshow(gray, cmap='gray', vmin=0, vmax=255)
axes[1].imshow(equalize_histogram(gray), cmap='gray', vmin=0, vmax=255);

# Task 1: histogram equalization for R, G and B channels independently

Implement the function `equalize_histogram_rgb_indep`. The function should perform the following steps:
1. Split the R, G and B channels from `rgb`.
2. Equalize each channel independently.
3. Stack them together to form the new equalized image `rgb_equ_indep`.
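The three steps above can be sketched as follows, assuming an `(H, W, 3)` `uint8` input. The helper `equalize_channel` and the name `equalize_histogram_rgb_indep_sketch` are hypothetical; the helper repeats the CDF-based equalization so that the block is self-contained.

```python
import numpy as np

def equalize_channel(channel: np.ndarray) -> np.ndarray:
    # CDF-based equalization of a single uint8 channel (assumed input type)
    h = np.bincount(channel.ravel(), minlength=256)
    lut = np.round(255 * np.cumsum(h) / h.sum()).astype(np.uint8)
    return lut[channel]

def equalize_histogram_rgb_indep_sketch(rgb: np.ndarray) -> np.ndarray:
    # 1. split R, G, B; 2. equalize each one; 3. stack back along the last axis
    channels = [equalize_channel(rgb[..., c]) for c in range(3)]
    return np.stack(channels, axis=-1)

rng = np.random.default_rng(1)
rgb_demo = rng.integers(0, 128, size=(50, 50, 3), dtype=np.uint8)
out = equalize_histogram_rgb_indep_sketch(rgb_demo)
print(out.shape, out.dtype)  # (50, 50, 3) uint8
```

Since each channel gets its own LUT, the balance between R, G and B can change, which may shift the colors of the result.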

In [None]:
def equalize_histogram_rgb_indep(rgb: np.ndarray) -> np.ndarray:
    ########################################
    # TODO: implement

    raise NotImplementedError

    # ENDTODO
    ########################################

    return rgb_equ_indep

In [None]:
test_histogram_equalization.TestEqualizeHistogramRGBIndep.eval(equalize_histogram_rgb_indep_fn=equalize_histogram_rgb_indep)

In [None]:
rgb_equ_indep = equalize_histogram_rgb_indep(rgb)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=plt.figaspect(0.5), layout='constrained')
axes[0].imshow(rgb)
axes[1].imshow(rgb_equ_indep);

# Task 2: equalize using transformation LUT computed from grayscale image

Implement the function `equalize_histogram_rgb_gray`. The function should perform the following steps:
1. Convert `rgb` to grayscale the same way the `gray` image above was computed.
2. Compute the histogram equalization lookup table from the grayscale image.
3. Apply this same LUT to every channel of the `rgb` image to form the new `rgb_equ_gray` equalized image.
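A minimal sketch of these steps, assuming an `(H, W, 3)` `uint8` input. To keep it dependency-free, the grayscale conversion is approximated with the luminance weights documented for `skimage.color.rgb2gray` (0.2125 R + 0.7154 G + 0.0721 B); rounding may differ slightly from `skimage.util.img_as_ubyte`. The names `to_gray_ubyte` and `equalize_histogram_rgb_gray_sketch` are hypothetical.

```python
import numpy as np

def to_gray_ubyte(rgb: np.ndarray) -> np.ndarray:
    # Assumed approximation of rgb2gray + img_as_ubyte: weighted sum, rounded to uint8
    weights = np.array([0.2125, 0.7154, 0.0721])
    gray = np.round(rgb.astype(np.float64) @ weights)
    return np.clip(gray, 0, 255).astype(np.uint8)

def equalize_histogram_rgb_gray_sketch(rgb: np.ndarray) -> np.ndarray:
    gray = to_gray_ubyte(rgb)
    # LUT computed from the grayscale histogram only
    h = np.bincount(gray.ravel(), minlength=256)
    lut = np.round(255 * np.cumsum(h) / h.sum()).astype(np.uint8)
    # Indexing with the full (H, W, 3) array applies the same LUT to R, G and B
    return lut[rgb]

rng = np.random.default_rng(2)
rgb_demo = rng.integers(0, 128, size=(50, 50, 3), dtype=np.uint8)
out = equalize_histogram_rgb_gray_sketch(rgb_demo)
```

Because one monotonic LUT is shared by all channels, the ordering of R, G and B within each pixel is preserved, so color shifts tend to be smaller than with independent equalization.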

In [None]:
def equalize_histogram_rgb_gray(rgb: np.ndarray) -> np.ndarray:
    ########################################
    # TODO: implement

    raise NotImplementedError

    # ENDTODO
    ########################################

    return rgb_equ_gray

In [None]:
test_histogram_equalization.TestEqualizeHistogramRGBGray.eval(equalize_histogram_rgb_gray_fn=equalize_histogram_rgb_gray)

In [None]:
rgb_equ_gray = equalize_histogram_rgb_gray(rgb)

In [None]:
fig, axes = plt.subplots(1, 2, figsize=plt.figaspect(0.5), layout='constrained')
axes[0].imshow(rgb)
axes[1].imshow(rgb_equ_gray);

# Task 3 (optional): implement `create_lut` function

Implement the function `create_lut`, which creates a custom intensity transformation lookup table from the user's input. Its input is:
- `points`: a list of `(a, b)` pairs such that input brightness `a` maps to output brightness `b` in the resulting table

Returns:
- `lut`: the lookup table, an `np.ndarray` of 256 values with shape `(256,)`

Notes:
- The user specifies an arbitrarily sized, sparse `list` of `(a, b)` pairs stating which input brightness `a` should map to which output brightness `b`.
- All values in between should be linearly interpolated.
- For example, `create_lut([(0, 255), (255, 0)])` should return the same lookup table as `lut_neg` from the lecture notebook, because `a=0` maps to `b=255`, `a=255` maps to `b=0`, and everything in between is linearly interpolated.
- If the list starts with a pair in which `a > 0`, all values below that smallest `a` map to the `b` of that first pair. Likewise, if the list ends with a pair in which `a < 255`, all values above that largest `a` map to the `b` of that last pair.
- For example, if the input list is `[(100, 100), (200, 200)]`, all values `0 <= a < 100` map to `b=100` and all values `200 < a < 256` map to `b=200`.
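The rules above match piecewise-linear interpolation with constant extrapolation, which is what `np.interp` does: outside the range of the given `a` values it returns the first or last `b`. The sketch below is one possible approach, not the reference solution; the name `create_lut_sketch` is hypothetical, and it assumes the `a` values are distinct (`np.interp` expects increasing x-coordinates, so the points are sorted first).

```python
import numpy as np

def create_lut_sketch(points: list[tuple[int, int]]) -> np.ndarray:
    """Hypothetical sketch of create_lut using piecewise-linear interpolation."""
    pts = sorted(points)  # np.interp needs increasing x-coordinates
    a = np.array([p[0] for p in pts], dtype=np.float64)
    b = np.array([p[1] for p in pts], dtype=np.float64)
    # Interpolate at every input brightness 0..255; np.interp holds the
    # first/last b constant outside [min(a), max(a)]
    lut = np.interp(np.arange(256), a, b)
    return np.round(lut).astype(np.uint8)

print(create_lut_sketch([(0, 255), (255, 0)])[:3])  # [255 254 253]
print(create_lut_sketch([(100, 100), (200, 200)])[[0, 150, 255]])  # [100 150 200]
```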

In [None]:
def create_lut(points: list[tuple[int, int]]) -> np.ndarray:
    """
    lookup table for arbitrary brightness transformation
    """

    ########################################
    # TODO: implement

    raise NotImplementedError
    
    # ENDTODO
    ########################################
    
    return lut

In [None]:
test_histogram_equalization.TestCreateLUT.eval(create_lut_fn=create_lut)

In [None]:
lut_neg = create_lut([(0, 255), (255, 0)])  # lookup table for negative
lut_enh = create_lut([(0, 0), (50, 10), (200, 245), (255, 255)])  # lookup table for contrast enhancement
lut_thr = create_lut([(0, 0), (100, 0), (101, 255), (255, 255)])  # lookup table for thresholding
lut_arb = create_lut([(50, 150), (128, 30), (200, 255)])

In [None]:
with sns.axes_style(style='darkgrid'):
    fig, axes = plt.subplots(1, 4, figsize=plt.figaspect(0.25), layout='constrained')
    axes[0].plot(lut_neg);
    axes[1].plot(lut_enh);
    axes[2].plot(lut_thr);
    axes[3].plot(lut_arb);

In [None]:
fig, axes = plt.subplots(1, 4, figsize=plt.figaspect(0.25), layout='constrained')
axes[0].imshow(lut_neg[gray], cmap='gray', vmin=0, vmax=255);
axes[1].imshow(lut_enh[gray], cmap='gray', vmin=0, vmax=255);
axes[2].imshow(lut_thr[gray], cmap='gray', vmin=0, vmax=255);
axes[3].imshow(lut_arb[gray], cmap='gray', vmin=0, vmax=255);