In [1]:
import numpy as np
import cv2
import tensorflow as tf
import os
from tensorflow import keras



In [2]:
# Function to convert color bins to image
def convert_color_bins_to_image(pred_binary_map):
    bin_width = 256 / len(pred_binary_map)
    colorized_img = np.zeros((256, 256, 3), dtype=np.float32)
    
    for i in range(len(pred_binary_map)):
        mask = pred_binary_map[..., i] == 1
        bin_min = -128 + i * bin_width
        bin_max = bin_min + bin_width
        
        colorized_img[..., 0][mask] = 0  # Set L channel to 0
        colorized_img[..., 1][mask] = (bin_min + bin_max) / 2
        colorized_img[..., 2][mask] = bin_min + bin_width / 2
    
    # Convert Lab image to RGB
    colorized_img = cv2.cvtColor(colorized_img.astype(np.float32), cv2.COLOR_LAB2RGB)
    
    return colorized_img


# Load and preprocess the input data
def prepareInputData(path, h, w):
    X = []
    y = []
    for imageDir in os.listdir(path):
        try:
            img = cv2.imread(path + imageDir)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

            img = img.astype(np.float32)
            img_lab = cv2.cvtColor(img, cv2.COLOR_RGB2Lab)
            img_lab_rs = cv2.resize(img_lab, (w, h))
            img_l = img_lab_rs[:, :, 0]
            img_ab = img_lab_rs[:, :, 1:]
            img_ab = img_ab / 128.0

            X.append(img_l)
            y.append(img_ab)
        except:
            pass

    X = np.array(X)
    y = np.array(y)
    
    return X, y


# Build the model architecture
def buildModel(input_shape):
    model = keras.Sequential([
        keras.layers.Input(shape=input_shape),
        keras.layers.Reshape((input_shape[0], input_shape[1], 1)),
        keras.layers.Conv2D(16, (3, 3), activation='relu', padding='same'),
        keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
        keras.layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        keras.layers.Conv2D(2, (1, 1), activation='tanh', padding='same')
    ])
    return model

In [3]:
# Set up data and model
image_path = "photos/train/train_data/"
testPath = "photos/test/test_data/"
image_height = 256
image_width = 256
num_color_bins = 15

# Prepare input data
X_train, y_train = prepareInputData(image_path, image_height, image_width)
X_test, y_test = prepareInputData(testPath, image_height, image_width)

In [6]:
np.unique(X_train)

array([  0., 100.], dtype=float32)

In [4]:
# Create the model
input_shape = (image_height, image_width, 1)
model = buildModel(input_shape)

# Compile and train the model
model.compile(optimizer='adam', loss='mse')
model.fit(X_train, y_train, epochs=10, steps_per_epoch=5, validation_data=(X_test, y_test), validation_steps=10)


Epoch 1/10
Epoch 2/10
Epoch 3/10

KeyboardInterrupt: 

In [None]:
def convert_color_bins_to_image(pred_binary_map):
    bin_width = 256 / pred_binary_map.shape[-1]
    colorized_img = np.zeros((256, 256, 3), dtype=np.float32)

    # Reshape pred_binary_map to match colorized_img dimensions
    pred_binary_map = np.reshape(pred_binary_map, (256, 256, -1))

    for i in range(pred_binary_map.shape[-1]):
        mask = pred_binary_map[..., i] == 1
        bin_min = -128 + i * bin_width
        bin_max = bin_min + bin_width

        colorized_img[..., 0][mask] = 0  # Set L channel to 0
        colorized_img[..., 1][mask] = (bin_min + bin_max) / 2
        colorized_img[..., 2][mask] = bin_min + bin_width / 2

    # Convert Lab image to RGB
    colorized_img = cv2.cvtColor(colorized_img.astype(np.float32), cv2.COLOR_LAB2RGB)

    return colorized_img



In [None]:
# Generate colorized images using the trained model
sample_image = X_train[0]  # Select a sample image for demonstration
sample_image = np.expand_dims(sample_image, axis=0)
pred_binary_map = model.predict(sample_image)
colorized_img = convert_color_bins_to_image(pred_binary_map)

# Display the colorized image
cv2.imshow("Colorized Image", colorized_img)



In [None]:
# Display the colorized image
cv2.imshow("Colorized Image", colorized_img)
#cv2.waitKey(0)
#cv2.destroyAllWindows()
