In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import colour
import numpy as np
import cv2
from colour_checker_detection import detect_colour_checkers_segmentation

# Directory holding the recorded frame sets (*.npy files).
# Override with the COLORCHECK_FRAMES_DIR environment variable; otherwise it
# defaults to ~/Videos/colorchecks. `expanduser` works even when $HOME is unset
# (e.g. on Windows), unlike os.environ['HOME'].
FRAMES_DIRECTORY = os.environ.get(
    'COLORCHECK_FRAMES_DIR',
    os.path.join(os.path.expanduser('~'), 'Videos', 'colorchecks'),
)

# File names (sorted, so the ordering is deterministic) of every frame set.
FRAME_SETS = sorted(name for name in os.listdir(FRAMES_DIRECTORY) if name.endswith('.npy'))

# Absolute paths corresponding to FRAME_SETS, in the same order.
FRAME_SET_PATHS = [os.path.join(FRAMES_DIRECTORY, name) for name in FRAME_SETS]

In [None]:
# Split the recorded frame sets by capture format. The tag is matched against
# the file name only, so a directory whose name happens to contain the tag
# cannot select every file in it.
CAPTURE_FRAMESET_PATHS = [
    path for path in FRAME_SET_PATHS if 'rwg-log3g10' in os.path.basename(path)
]
YUYV_FRAMESET_PATHS = [path for path in FRAME_SET_PATHS if path.endswith('v4l2-yuyv.npy')]

print(CAPTURE_FRAMESET_PATHS)
print(YUYV_FRAMESET_PATHS)

In [None]:
# Frame sets under analysis. Only every FRAME_STRIDE-th frame is kept to bound
# memory use and plotting time.
FRAME_STRIDE = 10
# Index (after subsampling) of the frame exported as a PNG preview per set;
# clamped to the last frame for short sets.
PNG_PREVIEW_INDEX = 5

ACTIVE_FRAMESET_PATHS = CAPTURE_FRAMESET_PATHS
if not ACTIVE_FRAMESET_PATHS:
    raise FileNotFoundError(f"No matching frame sets found in {FRAMES_DIRECTORY}")
ACTIVE_FRAMESET_PATH = ACTIVE_FRAMESET_PATHS[0]

ORIGINAL_FRAMESETS = [np.load(path)[::FRAME_STRIDE] for path in ACTIVE_FRAMESET_PATHS]
# ACTIVE_FRAMESET_PATH is ACTIVE_FRAMESET_PATHS[0]: reuse the array that was
# already loaded instead of reading the same .npy file a second time.
ORIGINAL_FRAMES = ORIGINAL_FRAMESETS[0]

# Working copies; later cells transform these, the ORIGINAL_* arrays stay intact.
frames = ORIGINAL_FRAMES.copy()
framesets = [f.copy() for f in ORIGINAL_FRAMESETS]
print(ACTIVE_FRAMESET_PATH)
print(frames.shape)

# Save one preview frame of each frame set next to its .npy file as a PNG.
# Frames have not been converted to RGB yet (that happens in a later cell), so
# they are presumably still in OpenCV's BGR order, which cv2.imwrite expects.
for npy_path, frameset in zip(ACTIVE_FRAMESET_PATHS, framesets):
    if len(frameset) == 0:
        print(f"Warning: {npy_path} contains no frames; no PNG written")
        continue
    preview_frame = frameset[min(PNG_PREVIEW_INDEX, len(frameset) - 1)]
    # splitext only swaps the extension; str.replace would also rewrite any
    # '.npy' occurring elsewhere in the path.
    png_path = os.path.splitext(npy_path)[0] + '.png'
    if not cv2.imwrite(png_path, preview_frame):
        print(f"Warning: failed to write {png_path}")

In [None]:
# Convert from OpenCV's BGR channel order to RGB (matplotlib expects RGB).
# The conversion starts from the untouched ORIGINAL_* arrays rather than from
# `frames` / `framesets`, so re-running this cell cannot swap the channels
# back to BGR. On a fresh Run All the result is identical, because the
# previous cell sets `frames` / `framesets` to plain copies of ORIGINAL_*.
# (cv2 is already imported in the first cell.)
frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in ORIGINAL_FRAMES]
framesets = [
    [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frameset]
    for frameset in ORIGINAL_FRAMESETS
]

In [None]:
# Frame sets whose path contains '-bgr' or '-mjpg' are flipped vertically
# (flip code 0 = around the x-axis); presumably those capture modes record the
# image upside down.
# NOTE(review): not idempotent; re-running this cell flips the frames back.


def _needs_vertical_flip(path):
    """Return True when `path` contains the '-bgr' or '-mjpg' tag."""
    return any(tag in path for tag in ('-bgr', '-mjpg'))


if _needs_vertical_flip(ACTIVE_FRAMESET_PATH):
    frames = [cv2.flip(frame, 0) for frame in frames]

for position, current_set in enumerate(framesets):
    if not _needs_vertical_flip(ACTIVE_FRAMESET_PATHS[position]):
        continue
    framesets[position] = [cv2.flip(frame, 0) for frame in current_set]

In [None]:
# Preview five evenly spaced frames (first to last) of the active frame set.
sample_indices = np.linspace(0, len(frames) - 1, num=5, dtype=int)  # Adjust the number of samples as needed

fig, axs = plt.subplots(1, len(sample_indices), figsize=(15, 5))

for frame_index, panel in zip(sample_indices, axs):
    panel.imshow(frames[frame_index])
    panel.set_title(f"Frame {frame_index}")
    panel.axis('off')

print(ACTIVE_FRAMESET_PATH)
plt.show()

In [None]:
def log3g10_to_linear(log3g10):
    """
    Convert Log3G10-encoded values to scene-linear light values.

    Implements the inverse of RED's Log3G10 curve (the "v3" curve from RED's
    REDWideGamutRGB / Log3G10 white paper, as implemented by colour-science's
    ``log_decoding_Log3G10``):

        x = (10 ** (y / a) - 1) / b - c    for y >= 0
        x = y / g - c                      for y <  0

    with a = 0.224282, b = 155.975327, c = 0.01, g = 15.1927.
    Code value 1/3 decodes to 18 % grey (0.18); code value 0 decodes to -c.

    Parameters
    ----------
    log3g10 : array_like
        Log3G10 code values normalised to [0, 1]. Values outside that range
        are decoded with the same formula.

    Returns
    -------
    numpy.ndarray
        Scene-linear float64 values with the same shape as ``log3g10``.
        Results can be slightly negative (-0.01 at code value 0) and exceed
        1.0 for highlights, so clip before displaying.
    """
    a = 0.224282
    b = 155.975327
    c = 0.01
    g = 15.1927
    y = np.asarray(log3g10, dtype=np.float64)
    # Below code value 0 the curve continues as a straight line (slope 1/g)
    # instead of being clamped, which keeps sub-black values invertible.
    return np.where(y < 0.0, y / g - c, (10.0 ** (y / a) - 1.0) / b - c)

def rwg_to_xyz(rwg):
    """
    Map linear REDWideGamutRGB values to CIE XYZ.

    ``rwg`` is expected to be an array of shape (..., 3) whose last axis holds
    R, G and B. The result has the same shape, and its last axis holds X, Y
    and Z.
    """
    # Each row gives one output component (X, Y, Z) as a weighted sum of R, G, B.
    xyz_from_rwg = np.array(
        [
            [0.735275, 0.068609, 0.146571],
            [0.286694, 0.842979, -0.129673],
            [-0.079681, -0.347343, 1.516081],
        ]
    )
    # Right-multiplying by the transpose applies the matrix to every pixel.
    return np.matmul(rwg, xyz_from_rwg.T)

def image_to_calibration_matrix(image: np.ndarray) -> np.ndarray:
    """Fit the 3x3 colour-calibration matrix for an image of a Macbeth chart.

    The chart is located in ``image`` (expected to be in a linear colour
    space) with ``detect_colour_checkers_segmentation``. The swatch colours of
    the first detected chart are fitted in the least-squares sense (via the
    Moore-Penrose pseudo-inverse) onto the reference "ColorChecker24 - After
    November 2014" swatches, converted to sRGB with ``colour.XYZ_to_RGB``.

    Raises
    ------
    ValueError
        If no colour chart is detected in ``image``.
    """
    detected_charts = detect_colour_checkers_segmentation(image)
    if not len(detected_charts):
        raise ValueError("No color charts detected in image!")

    reference_checker = colour.CCS_COLOURCHECKERS["ColorChecker24 - After November 2014"]
    reference_xyz = colour.xyY_to_XYZ(list(reference_checker.data.values()))
    reference_rgb = colour.XYZ_to_RGB(
        reference_xyz,
        "sRGB",
        reference_checker.illuminant,
    ).astype("float32")

    measured_swatches = detected_charts[0]
    # Solves measured_swatches @ M ~= reference_rgb for M.
    return np.linalg.pinv(measured_swatches) @ reference_rgb

def color_correct_with_matrix(image: np.ndarray, matrix: np.ndarray) -> np.ndarray:
    """Apply a colour-calibration matrix and gamma-encode the result for display.

    Every pixel (the last axis of ``image``) is right-multiplied by ``matrix``,
    clipped to [0, 1], encoded with a 1/2.2 power curve and scaled to
    [0, 255]. The returned array is floating point, not uint8.
    """
    corrected = np.clip(image @ matrix, 0, 1)
    encoded = corrected ** (1 / 2.2)
    return encoded * 255.0

# Display sample frames at each stage of the colour-correction pipeline.
# Rows (top to bottom):
#   0. the original (Log3G10-encoded) frame
#   1. the frame after linearisation
#   2. the frame after REDWideGamutRGB -> XYZ (computed from the linear frame)
#   3. reserved for the chart-detection visualisation (currently disabled)
#   4. the frame after colour correction
# Columns are the requested frame indices that actually exist in `frames`.


def _to_unit_range(frame: np.ndarray) -> np.ndarray:
    """Return `frame` as floats, scaling integer dtypes to [0, 1].

    Integer frames are divided by their dtype's maximum (255 for uint8,
    65535 for uint16). Float frames are only cast to float64.
    """
    frame = np.asarray(frame)
    if np.issubdtype(frame.dtype, np.integer):
        return frame / float(np.iinfo(frame.dtype).max)
    return frame.astype(np.float64)


REQUESTED_SAMPLE_INDICES = (0, 11)
# Drop indices past the end of the (subsampled) frame list instead of crashing.
sample_indices = np.array([idx for idx in REQUESTED_SAMPLE_INDICES if idx < len(frames)])
if sample_indices.size == 0:
    raise ValueError("No frames available to display.")

# squeeze=False keeps `axs` 2-D even when only one column remains.
fig, axs = plt.subplots(5, len(sample_indices), figsize=(15, 15), squeeze=False)

for col, idx in enumerate(sample_indices):
    # Index by the sample index, not by the column number.
    original_frame = frames[idx]
    axs[0, col].imshow(original_frame)
    axs[0, col].set_title(f"Frame {idx}")

    linear_frame = log3g10_to_linear(_to_unit_range(original_frame))
    # imshow expects floats in [0, 1]; clip rather than scaling to 0-255.
    axs[1, col].imshow(np.clip(linear_frame, 0.0, 1.0))
    axs[1, col].set_title(f"Linear Frame {idx}")

    # The RWG -> XYZ matrix only applies to *linear* RGB, so feed it the
    # linearised frame rather than the log-encoded original.
    xyz_frame = rwg_to_xyz(linear_frame)
    axs[2, col].imshow(np.clip(xyz_frame, 0.0, 1.0))
    axs[2, col].set_title(f"XYZ Frame {idx}")

    # TODO: render the chart-detection result in this row.
    axs[3, col].set_title(f"Detection Frame {idx} (disabled)")

    axs[4, col].set_title(f"Color Corrected Frame {idx}")
    try:
        # Fit the matrix in the same (linear XYZ) space it is applied to below.
        calibration_matrix = image_to_calibration_matrix(xyz_frame)
    except ValueError as err:
        # Keep drawing the other columns when no chart is found in this frame.
        axs[4, col].text(0.5, 0.5, str(err), ha="center", va="center", wrap=True)
    else:
        color_corrected_frame = color_correct_with_matrix(xyz_frame, calibration_matrix)
        # color_correct_with_matrix returns floats in [0, 255]; convert to
        # uint8 so imshow does not clip everything above 1.0 to white.
        axs[4, col].imshow(np.round(color_corrected_frame).astype(np.uint8))

for ax in axs.flat:
    ax.axis('off')

plt.show()

In [None]:

def plot_color_histograms(frames, title):
    """
    Plots color histograms for a list of frames.
    """
    plt.figure(figsize=(12, 4))
    
    # Plot the frame
    plt.subplot(1, 2, 1)
    plt.imshow(frames[4])
    plt.title(f'{title} Image')
    plt.axis('off')
    
    # Plot the color histogram
    plt.subplot(1, 2, 2)
    colors = ('r', 'g', 'b')
    for i, col in enumerate(colors):
        hist = cv2.calcHist(frames, [i], None, [256], [0, 256])
        plt.plot(hist, color=col)
    plt.xlim([0, 256])
    plt.title(f'{title} Color Histogram')

    plt.show()

# Colour histograms of the first five frames of every active frame set,
# titled with that frame set's path.
for frameset_path, frameset in zip(ACTIVE_FRAMESET_PATHS, framesets):
    first_five = frameset[:5]
    plot_color_histograms(first_five, f'{frameset_path}')