# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from skimage.color import rgb2lab, lab2rgb
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
import numpy as np
import matplotlib.pyplot as plt

# Preprocessing function
def preprocess_data(images):
    """Split RGB images into normalized Lab lightness and chroma arrays.

    Args:
        images: RGB image array passed straight to ``skimage.color.rgb2lab``
            (the script below passes float images scaled to [0, 1]).

    Returns:
        ``(l_channel, ab_channels)``: the Lab L channel divided by 100, kept
        with a trailing axis of size 1, and the a/b channels divided by 128.
    """
    lab = rgb2lab(images)
    lightness = lab[..., :1] / 100.0   # L is nominally in [0, 100]
    chroma = lab[..., 1:] / 128.0      # a/b are roughly in [-128, 127]
    return lightness, chroma

# U-Net Generator
def build_generator(input_shape):
    """Build a small U-Net that maps a 1-channel L image to 2 ab channels.

    Encoder: three 3x3 ReLU convolutions (64, 128, 256 filters), each
    followed by 2x2 max-pooling. Bottleneck: a 512-filter convolution.
    Decoder: three stride-2 transposed convolutions (256, 128, 64 filters),
    each concatenated with the encoder activation of the same resolution.
    A final 1x1 convolution with tanh emits two channels in [-1, 1].

    Args:
        input_shape: Shape of a single input excluding the batch axis,
            e.g. ``(32, 32, 1)``. Height and width should be divisible by 8
            so the skip connections line up after three pool/upsample steps.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    inputs = tf.keras.Input(shape=input_shape)

    # Encoder: remember each pre-pool activation for the skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256):
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = layers.Conv2D(512, (3, 3), activation='relu', padding='same')(x)

    # Decoder: upsample, then merge with the matching encoder activation
    # (deepest first).
    for filters, skip in zip((256, 128, 64), reversed(skips)):
        x = layers.Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding='same')(x)
        x = layers.concatenate([x, skip], axis=-1)

    outputs = layers.Conv2D(2, (1, 1), activation='tanh', padding='same')(x)

    return models.Model(inputs, outputs)

# Custom Metrics Callback for PSNR and SSIM
class CustomMetrics(tf.keras.callbacks.Callback):
    """Keras callback that prints PSNR and SSIM on held-out data after every epoch.

    The scores come from ``evaluate_model`` run on ``self.model`` (the model
    being fitted) with the arrays supplied at construction time.
    """

    def __init__(self, l_val, ab_val):
        super().__init__()
        # Normalized L inputs and ab targets to score against
        # (presumably the output of preprocess_data).
        self.l_val, self.ab_val = l_val, ab_val

    def on_epoch_end(self, epoch, logs=None):
        """Score the current model on the stored arrays and print the averages."""
        scores = evaluate_model(self.model, self.l_val, self.ab_val)
        psnr_value, ssim_value = scores
        print(f"  PSNR: {psnr_value:.2f}, SSIM: {ssim_value:.4f}")

# Evaluation Function for PSNR and SSIM
def evaluate_model(generator, l_test, ab_test):
    """Return mean PSNR and SSIM of the generator's colourisations vs. ground truth.

    For every sample, the predicted ab channels and the true ab channels are
    each recombined with the same input L channel, de-normalized
    (L * 100, ab * 128, the inverse of ``preprocess_data``), converted from
    Lab to RGB with ``lab2rgb``, and compared.

    Args:
        generator: Model whose ``predict`` maps normalized L images to
            normalized ab images.
        l_test: Normalized L channels; presumably ``(N, H, W, 1)``.
        ab_test: Normalized ground-truth ab channels; presumably ``(N, H, W, 2)``.

    Returns:
        ``(avg_psnr, avg_ssim)``: per-image scores averaged over all samples.
    """
    predictions = generator.predict(l_test)
    psnr_scores, ssim_scores = [], []

    for pred_ab_norm, true_ab_norm, l_norm in zip(predictions, ab_test, l_test):
        l_channel = l_norm * 100.0
        pred_rgb = lab2rgb(np.concatenate((l_channel, pred_ab_norm * 128.0), axis=-1))
        true_rgb = lab2rgb(np.concatenate((l_channel, true_ab_norm * 128.0), axis=-1))

        # lab2rgb yields RGB floats in [0, 1]; state the range explicitly
        # instead of letting scikit-image infer it from the dtype.
        psnr_scores.append(psnr(true_rgb, pred_rgb, data_range=1.0))

        # SSIM needs an odd window that fits inside the image: use the
        # scikit-image default of 7, shrunk only for very small images.
        win_size = min(7, *true_rgb.shape[:2])
        if win_size % 2 == 0:
            win_size -= 1
        # channel_axis=-1 scores each RGB channel as a 2-D image and averages.
        # `multichannel=True` is deprecated and may be ignored by recent
        # scikit-image, which then treats H x W x 3 as a 3-D volume.
        ssim_scores.append(
            ssim(true_rgb, pred_rgb, channel_axis=-1, win_size=win_size, data_range=1.0)
        )

    avg_psnr = np.mean(psnr_scores)
    avg_ssim = np.mean(ssim_scores)
    return avg_psnr, avg_ssim



# Training Function
def train_model(generator, l_train, ab_train, l_val, ab_val, epochs=10, batch_size=32):
    """Compile ``generator`` with an MSE loss and fit it on (L, ab) pairs.

    Args:
        generator: Keras model mapping normalized L images to normalized ab images.
        l_train: Training inputs (normalized L channels).
        ab_train: Training targets (normalized ab channels).
        l_val: Validation inputs; used for Keras validation loss/metrics and
            for the per-epoch PSNR/SSIM printout of ``CustomMetrics``.
        ab_val: Validation targets matching ``l_val``.
        epochs: Number of training epochs.
        batch_size: Mini-batch size for both the training and validation data.

    Returns:
        The ``History`` object returned by ``generator.fit``.
    """
    # Keras computes in float32 anyway; casting up front halves the memory the
    # datasets hold when the arrays arrive as float64.
    l_train = np.asarray(l_train, dtype=np.float32)
    ab_train = np.asarray(ab_train, dtype=np.float32)

    # Model.fit ignores its `shuffle` argument for tf.data inputs, so shuffle
    # explicitly (reshuffled every epoch) instead of replaying one batch order.
    train_dataset = (
        tf.data.Dataset.from_tensor_slices((l_train, ab_train))
        .shuffle(buffer_size=len(l_train))
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    val_dataset = (
        tf.data.Dataset.from_tensor_slices(
            (np.asarray(l_val, dtype=np.float32), np.asarray(ab_val, dtype=np.float32))
        )
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

    # Compile the Generator
    generator.compile(
        optimizer=optimizers.Adam(1e-4),
        loss='mean_squared_error',
        # MAE is a meaningful per-pixel error for this continuous regression
        # target. 'accuracy' is not meaningful for ab regression; it is kept
        # only so code reading history['accuracy'] / ['val_accuracy'] still works.
        metrics=['mae', 'accuracy'],
    )

    # Train the model with custom metrics callback
    history = generator.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[CustomMetrics(l_val, ab_val)],
    )

    return history

# Visualization Function
def visualize_results(generator, l_test, ab_test, num_samples=5):
    """Show grayscale input, true colour and predicted colour side by side.

    One matplotlib figure with three panels is shown per sample.

    Args:
        generator: Model whose ``predict`` maps normalized L images to
            normalized ab images.
        l_test: Normalized L channels (L / 100).
        ab_test: Normalized ground-truth ab channels (ab / 128).
        num_samples: Number of leading samples to display. Clamped to the
            number available, so a too-large value no longer raises IndexError.
            Nothing is shown when it is 0 or negative.
    """
    num_samples = min(num_samples, len(l_test))
    if num_samples <= 0:
        return

    predictions = generator.predict(l_test[:num_samples])

    for pred_ab_norm, true_ab_norm, l_norm in zip(
        predictions, ab_test[:num_samples], l_test[:num_samples]
    ):
        # Undo the normalization applied by preprocess_data.
        l_channel = l_norm * 100.0
        pred_rgb = lab2rgb(np.concatenate((l_channel, pred_ab_norm * 128.0), axis=-1))
        true_rgb = lab2rgb(np.concatenate((l_channel, true_ab_norm * 128.0), axis=-1))

        plt.figure(figsize=(10, 5))
        plt.subplot(1, 3, 1)
        plt.title('Grayscale')
        # Pin the display range to Lab lightness [0, 100]; otherwise imshow
        # stretches each image's own min..max and distorts its contrast.
        plt.imshow(l_channel[..., 0], cmap='gray', vmin=0, vmax=100)

        plt.subplot(1, 3, 2)
        plt.title('True Color')
        plt.imshow(true_rgb)

        plt.subplot(1, 3, 3)
        plt.title('Predicted Color')
        plt.imshow(pred_rgb)
        plt.show()

# Plot Training History
def plot_training_history(history):
    """Plot per-epoch training (and validation, when present) curves.

    One subplot is drawn for the loss and one for every other metric recorded
    in ``history.history`` (e.g. ``accuracy``, ``mae``). A validation curve is
    added only when the matching ``val_<name>`` key exists. Histories from
    ``fit`` without validation data, or without an accuracy metric, therefore
    no longer raise ``KeyError``.

    Args:
        history: A Keras ``History`` object (anything with a ``.history`` dict
            mapping metric names to per-epoch value lists).
    """
    records = history.history
    # Training-side metric names in recorded order, with 'loss' moved first
    # (list.sort is stable, so the remaining metrics keep their order).
    names = [key for key in records if not key.startswith('val_')]
    names.sort(key=lambda name: name != 'loss')

    display_names = {'loss': 'Loss', 'accuracy': 'Accuracy', 'mae': 'MAE', 'mse': 'MSE'}

    plt.figure(figsize=(6 * max(len(names), 1), 6))
    for position, name in enumerate(names, start=1):
        title = display_names.get(name, name.replace('_', ' ').title())
        plt.subplot(1, len(names), position)
        plt.plot(records[name], label=f'Train {title}')
        val_key = f'val_{name}'
        if val_key in records:
            plt.plot(records[val_key], label=f'Val {title}')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(title)
        plt.legend()

    plt.tight_layout()
    plt.show()


# Main
if __name__ == "__main__":
    # Load CIFAR-10 (class labels are not needed) and scale pixels to [0, 1].
    # Casting to float32 first avoids the float64 arrays that `/ 255.0` on
    # uint8 would produce, halving memory for the RGB data.
    (x_train, _), (x_test, _) = tf.keras.datasets.cifar10.load_data()
    x_train = x_train.astype(np.float32) / 255.0
    x_test = x_test.astype(np.float32) / 255.0
    l_train, ab_train = preprocess_data(x_train)
    l_test, ab_test = preprocess_data(x_test)

    # Derive the generator's per-image input shape (H, W, 1) from the data
    # instead of hard-coding CIFAR-10's 32x32.
    input_shape = l_train.shape[1:]
    generator = build_generator(input_shape)

    # The first `num_val` test images double as the validation set.
    num_val = 1000
    history = train_model(
        generator,
        l_train,
        ab_train,
        l_test[:num_val],
        ab_test[:num_val],
        epochs=10,
        batch_size=32,
    )

    # Plot Training History
    plot_training_history(history)

    # Visualize Results
    visualize_results(generator, l_test, ab_test, num_samples=5)
