In [None]:
import sys

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import cv2
import ipywidgets as widgets
from IPython.display import display
from basic_SVD import find_SVD, recombine_SVD

# 'log' mode writes floating-point error messages to the object registered via
# np.seterrcall; without one, numpy raises NameError on the first FP error.
np.seterrcall(sys.stderr)
np.seterr(all='log')

In [112]:
def assertion_check(color_image_arr:np.ndarray):
    """Validate that an array is a 3-channel colour image of shape (m, n, 3).

    Parameters
    ----------
    color_image_arr : np.ndarray
        Candidate image array.

    Raises
    ------
    AssertionError
        If the array is not 3-D, or its last axis does not have length 3
        (e.g. a 4-channel RGBA image or a single-channel image).
    """
    # Raise explicitly rather than using `assert` so the check still runs under
    # `python -O`; the exception type is unchanged for any caller.
    if color_image_arr.ndim != 3:
        raise AssertionError(
            f"expected an (m, n, 3) colour image, got an array with ndim={color_image_arr.ndim}"
        )
    if color_image_arr.shape[2] != 3:
        raise AssertionError(
            f"expected 3 colour channels, got {color_image_arr.shape[2]} "
            "(convert RGBA/paletted images to RGB first)"
        )





def interweave_encode(color_image_arr:np.ndarray):
    """Interleave the R, G, B channels of an image row by row.

    Returns an (m * 3, n) float matrix in which row 3*i holds the red values of
    image row i, row 3*i + 1 the green values and row 3*i + 2 the blue values,
    each divided by 255.
    """
    assertion_check(color_image_arr)
    column_count = color_image_arr.shape[1]
    # Move the channel axis between rows and columns: (m, n, 3) -> (m, 3, n).
    channel_rows = np.transpose(color_image_arr, (0, 2, 1)) / 255
    # Flatten the first two axes so each image row becomes three matrix rows.
    return channel_rows.reshape(-1, column_count)

def interweave_decode(interweaved_arr:np.ndarray):
    """Inverse of `interweave_encode`: rebuild an (m, n, 3) uint8 RGB image.

    Parameters
    ----------
    interweaved_arr : np.ndarray
        (m * 3, n) matrix whose rows cycle R, G, B with values nominally in
        [0, 1] (e.g. a rank-k SVD reconstruction of `interweave_encode`'s output).

    Returns
    -------
    np.ndarray
        (m, n, 3) uint8 image.
    """
    R = interweaved_arr[0::3]  # rows for the red channel
    G = interweaved_arr[1::3]  # rows for the green channel
    B = interweaved_arr[2::3]  # rows for the blue channel

    rgb_unit = np.stack((R, G, B), axis=-1)  # Shape: (m, n, 3)
    # Round (not truncate) so v / 255 * 255 round-trips exactly, and clip because
    # a truncated SVD can overshoot [0, 1]: casting out-of-range floats to uint8
    # does not saturate and can wrap around.
    return np.clip(np.rint(rgb_unit * 255), 0, 255).astype(np.uint8)




def hsv_encode(color_image_arr:np.ndarray):
    """Convert an RGB image to HSV and interleave the channels row by row.

    Parameters
    ----------
    color_image_arr : np.ndarray
        (m, n, 3) RGB image with values in 0..255 (presumably uint8 from PIL;
        TODO confirm for other dtypes).

    Returns
    -------
    np.ndarray
        (m * 3, n) float matrix whose rows cycle H, S, V for each image row,
        every channel in [0, 1].
    """
    assertion_check(color_image_arr)
    # matplotlib's rgb_to_hsv documents its input as RGB in [0, 1], so normalise
    # first; V then already lies in [0, 1] (H and S are ratios, unaffected by scale).
    rgb_unit = color_image_arr.astype(np.float32) / 255
    hsv_image = mcolors.rgb_to_hsv(rgb_unit)
    H = hsv_image[:, :, 0]
    S = hsv_image[:, :, 1]
    V = hsv_image[:, :, 2]

    stacked = np.stack((H, S, V), axis=1)  # Shape: (m, 3, n)
    return stacked.reshape(-1, H.shape[1])  # Shape: (m * 3, n)

def hsv_decode(interweaved_arr:np.ndarray):
    """Inverse of `hsv_encode`: rebuild an (m, n, 3) uint8 RGB image.

    Parameters
    ----------
    interweaved_arr : np.ndarray
        (m * 3, n) matrix whose rows cycle H, S, V, each nominally in [0, 1]
        (e.g. a rank-k SVD reconstruction of `hsv_encode`'s output).

    Returns
    -------
    np.ndarray
        (m, n, 3) uint8 RGB image.
    """
    # A truncated SVD can push channels outside [0, 1]. Hue is an angle, so wrap
    # it around; saturation and value are bounded quantities, so clip them.
    H = np.mod(interweaved_arr[0::3], 1.0)
    S = np.clip(interweaved_arr[1::3], 0.0, 1.0)
    V = np.clip(interweaved_arr[2::3], 0.0, 1.0)

    hsv_matrix = np.stack((H, S, V), axis=-1)  # Shape: (m, n, 3)
    # Convert in unit range, then scale to 0..255, round to nearest and clip to
    # guard against float rounding before the uint8 cast.
    rgb_unit = mcolors.hsv_to_rgb(hsv_matrix)
    return np.clip(np.rint(rgb_unit * 255), 0, 255).astype(np.uint8)





def ycr_encode(color_image_arr:np.ndarray):
    """Convert an RGB image to YCrCb (OpenCV) and interleave the channels row by row.

    Returns an (m * 3, n) float matrix whose rows cycle Y, Cr, Cb for each image
    row, each channel divided by 255.
    """
    assertion_check(color_image_arr)
    ycrcb_image = cv2.cvtColor(color_image_arr, cv2.COLOR_RGB2YCrCb)
    scaled_planes = [ycrcb_image[:, :, channel] / 255 for channel in range(3)]
    column_count = scaled_planes[0].shape[1]
    # Insert the channel axis between rows and columns -> (m, 3, n), then
    # flatten the first two axes -> (m * 3, n).
    return np.stack(scaled_planes, axis=1).reshape(-1, column_count)

def ycr_decode(interweaved_arr:np.ndarray):
    """Inverse of `ycr_encode`: rebuild an (m, n, 3) RGB image.

    Parameters
    ----------
    interweaved_arr : np.ndarray
        (m * 3, n) matrix whose rows cycle Y, Cr, Cb, each scaled by 1/255
        (e.g. a rank-k SVD reconstruction of `ycr_encode`'s output).

    Returns
    -------
    np.ndarray
        (m, n, 3) RGB image produced by cv2.cvtColor from a uint8 YCrCb array.
    """
    Y = interweaved_arr[0::3]
    CR = interweaved_arr[1::3]
    CB = interweaved_arr[2::3]

    ycrcb_unit = np.stack((Y, CR, CB), axis=-1)  # Shape: (m, n, 3)
    # A truncated SVD can overshoot [0, 1]. Round and clip to 0..255 before the
    # uint8 cast: out-of-range floats do not saturate in that cast and can wrap
    # around (a likely source of the stray white pixels seen with this coder).
    YcRcB_matrix = np.clip(np.rint(ycrcb_unit * 255), 0, 255).astype(np.uint8)
    rgb_matrix = cv2.cvtColor(YcRcB_matrix, cv2.COLOR_YCrCb2RGB)
    return rgb_matrix




# Each coder is an (encode, decode) pair: encode turns an (m, n, 3) RGB image
# into an (m * 3, n) matrix for the SVD, decode maps such a matrix back to an
# (m, n, 3) image.
# Observations from earlier runs:
#   RGB:   worse image quality and slower (~8 s) than HSV; dividing by 255 in
#          encode and multiplying by 255 in decode noticeably improved both.
#   HSV:   best results, ~2.7 s.
#   YCrCb: occasional white artifacts, ~2.1 s.
rgb_interweave = interweave_encode, interweave_decode
hsv_interweave = hsv_encode, hsv_decode
ycr_interweave = ycr_encode, ycr_decode

rgb_coder = hsv_interweave  # coder used by the cells below

In [None]:
img_path = "shanghai.png"
out_path = "output"
# PNGs may be RGBA, paletted or single-channel; force 3-channel RGB so the
# colour encoders, which require an (m, n, 3) array, accept any input file.
with Image.open(img_path) as source_image:
    image_array = np.array(source_image.convert("RGB"))
image_color_encoded = rgb_coder[0](image_array)
# Weighted R/G/B sum (the weights match the common BT.601 luma coefficients).
image_grayscale = np.dot(image_array[..., :3], [0.2989, 0.5870, 0.1140])
plt.imshow(image_array);

In [114]:
# Full SVD of the grayscale image and of the (m * 3, n) colour-encoded matrix;
# cost is O(m^2 n) or O(m n^2) depending on which dimension dominates.
U, S, V = find_SVD(image_grayscale)
U_encoded, S_encoded, V_encoded = find_SVD(image_color_encoded)
# TODO: why does the RGB interweave take longer than the HSV interweave?
# NOTE(review): interweave_encode yields float64 (uint8 / 255) while hsv_encode
# likely yields float32 via matplotlib's rgb_to_hsv; a possible cause, confirm.

m, n = image_array.shape[:2]
m_encoded, n_encoded = image_color_encoded.shape[:2]

In [None]:
# Rank selector: how many singular values to keep in the reconstruction.
slider = widgets.IntSlider(
    value=50,                 # initial rank
    min=1,
    max=min(m, n),            # grayscale matrix rank is at most min(m, n)
    step=1,
    description='Value:',
    continuous_update=False,  # re-render only when the slider is released
    layout=widgets.Layout(width='60%')
)

checkbox = widgets.Checkbox(
    value=False,              # start with the grayscale reconstruction
    description='Colored',
    disabled=False            # if True, the checkbox would be greyed out
)

# TODO: this button has no on_click handler yet, so clicking it does nothing;
# presumably it should write the current reconstruction under `out_path`.
save_button = widgets.Button(
    description='Save'
)


def show_reduced_gray_image(k):
    """Plot the rank-k reconstruction of the grayscale image."""
    image = recombine_SVD(U, S, V, k)
    plt.imshow(image, cmap='gray')
    plt.axis('off')
    plt.show()


def show_reduced_color_image(k):
    """Plot the rank-k reconstruction of the colour image, decoded back to RGB."""
    image_encoded_reduced = recombine_SVD(U_encoded, S_encoded, V_encoded, k)
    image_reduced = rgb_coder[1](image_encoded_reduced)
    plt.imshow(image_reduced)
    plt.axis('off')
    plt.show()


def _print_compression_stats(rows, cols, k):
    """Print the storage cost of a rank-k SVD of a rows x cols matrix.

    U is rows x k, S holds k values and V is cols x k, so the truncated factors
    take rows*k + cols*k + k numbers versus rows*cols for the full matrix.
    """
    reduced_size = rows * k + cols * k + k
    original_size = rows * cols
    print(f"matrices U:{rows}x{k}, S:{k}, V:{cols}x{k}")
    print(f"matrices size {reduced_size} numbers")
    print(f"original size {original_size} numbers")
    print(f"{(reduced_size / original_size * 100):.2f}% of original size")


def show_reduced_image(k, isColor):
    """Widget callback: show the rank-k reconstruction and its storage cost.

    Parameters
    ----------
    k : int
        Number of singular values kept.
    isColor : bool
        Reconstruct the colour-encoded matrix instead of the grayscale one.
    """
    if isColor:
        show_reduced_color_image(k)
        # Use the encoded (m * 3, n) dimensions for every figure, including the
        # percentage, so it matches the sizes printed alongside it.
        _print_compression_stats(m_encoded, n_encoded, k)
    else:
        show_reduced_gray_image(k)
        _print_compression_stats(m, n, k)


# Display the controls; `interactive` calls show_reduced_image whenever they change.
interactive_plot = widgets.interactive(show_reduced_image, k=slider, isColor=checkbox)
display(interactive_plot)
display(save_button)