In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2
import gradio as gr
import pandas as pd

# Residual block for feature learning
class ResidualBlock(nn.Module):
    """Two 3x3 convolution + batch-norm stages around an identity shortcut.

    Both convolutions keep the channel count and (kernel 3, padding 1) the
    spatial size, so the shortcut is a plain element-wise addition.

    Args:
        channels: number of input and output feature channels.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        # Sub-modules are created in the original order so that the random
        # initialisation and the state_dict keys stay the same.
        self.conv1 = conv3x3()
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = conv3x3()
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = x
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # The final ReLU is applied after adding the shortcut.
        return self.relu(hidden + shortcut)

# Complex U-Net with skip connections and residual blocks
class ComplexColorizationNet(nn.Module):
    """U-Net style network that predicts two chroma channels from a 4-channel input.

    Input:  a (B, 4, H, W) tensor. In this file prepare_input_tensor builds it
            as [L, a_hint, b_hint, hint_mask] along the channel axis.
    Output: a (B, 2, H, W) tensor squashed into [-1, 1] by tanh.

    H and W must be divisible by 8. The three stride-2 encoder convolutions
    each halve the spatial size (rounding up), and each ConvTranspose2d in the
    decoder exactly doubles it. Any other size makes the torch.cat skip
    connections in forward() fail with a size mismatch.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.enc1 = nn.Sequential(
            nn.Conv2d(4, 64, 3, padding=1),  # 4 input channels, full resolution
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            ResidualBlock(64)
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # halves H and W
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            ResidualBlock(128)
        )
        self.enc3 = nn.Sequential(
            nn.Conv2d(128, 256, 3, stride=2, padding=1),  # halves H and W
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            ResidualBlock(256)
        )
        self.enc4 = nn.Sequential(
            nn.Conv2d(256, 512, 3, stride=2, padding=1),  # halves H and W
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            ResidualBlock(512)
        )
        # Bottleneck (keeps the H/8 x W/8 resolution and 512 channels)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            ResidualBlock(512)
        )
        # Decoder: each stage upsamples 2x with ConvTranspose2d(k=4, s=2, p=1),
        # concatenates the matching encoder output (skip connection), then
        # normalises and refines the concatenated features.
        self.up4 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1)
        self.dec4 = nn.Sequential(
            nn.BatchNorm2d(512),  # 256 + 256 channels after concat
            nn.ReLU(inplace=True),
            ResidualBlock(512)
        )
        self.up3 = nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1)
        self.dec3 = nn.Sequential(
            nn.BatchNorm2d(256),  # 128 + 128
            nn.ReLU(inplace=True),
            ResidualBlock(256)
        )
        self.up2 = nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1)
        self.dec2 = nn.Sequential(
            nn.BatchNorm2d(128),  # 64 + 64
            nn.ReLU(inplace=True),
            ResidualBlock(128)
        )
        # Reduces the full-resolution 128-channel features to 32 channels.
        self.dec1 = nn.Sequential(
            nn.Conv2d(128, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.out_conv = nn.Conv2d(32, 2, kernel_size=1)  # ab channels output

    def forward(self, x):
        """Map a (B, 4, H, W) input to (B, 2, H, W) chroma values in [-1, 1]."""
        e1 = self.enc1(x)      # (B,64,H,W)
        e2 = self.enc2(e1)     # (B,128,H/2,W/2)
        e3 = self.enc3(e2)     # (B,256,H/4,W/4)
        e4 = self.enc4(e3)     # (B,512,H/8,W/8)

        b = self.bottleneck(e4)  # (B,512,H/8,W/8)

        # Skip connections: the cat calls below require matching spatial
        # sizes, hence the divisible-by-8 requirement on H and W.
        d4 = self.up4(b)          # (B,256,H/4,W/4)
        d4 = torch.cat([d4, e3], dim=1)  # (B,512,H/4,W/4)
        d4 = self.dec4(d4)        # (B,512,H/4,W/4)

        d3 = self.up3(d4)         # (B,128,H/2,W/2)
        d3 = torch.cat([d3, e2], dim=1)  # (B,256,H/2,W/2)
        d3 = self.dec3(d3)        # (B,256,H/2,W/2)

        d2 = self.up2(d3)         # (B,64,H,W)
        d2 = torch.cat([d2, e1], dim=1)  # (B,128,H,W)
        d2 = self.dec2(d2)        # (B,128,H,W)

        d1 = self.dec1(d2)        # (B,32,H,W)
        out = torch.tanh(self.out_conv(d1))  # (B,2,H,W), bounded to [-1, 1]
        return out

# NOTE(review): nothing visible in this file loads trained weights (no
# load_state_dict / torch.load), so the network appears to run with its random
# initialisation. Its colour predictions would then be meaningless.
# TODO: confirm and load a checkpoint here.
model = ComplexColorizationNet()
# Inference mode: BatchNorm layers use their running statistics instead of
# per-batch statistics.
model.eval()

def preprocess_image(image):
    """Resize *image* to 256x256 and extract its normalised lightness channel.

    Args:
        image: a PIL image in any mode; it is converted to RGB first.

    Returns:
        A tuple ``(gray_norm, resized)``:
        - ``gray_norm`` is a float32 (256, 256) array in [0, 1] holding CIE L*/100.
        - ``resized`` is the 256x256 RGB PIL image.
    """
    resized = image.convert('RGB').resize((256, 256))
    # Use the L* channel of CIE Lab rather than a luma grayscale.
    # postprocess_output rebuilds the colour image as Lab with L = gray_norm * 100,
    # so gray_norm must be L*/100 for the output to keep the input's lightness
    # (luma and L* differ noticeably in the midtones).
    # For 8-bit input OpenCV stores L* scaled to 0..255 (L * 255 / 100),
    # hence the division by 255.
    lab = cv2.cvtColor(np.array(resized), cv2.COLOR_RGB2LAB)
    gray_norm = lab[:, :, 0].astype(np.float32) / 255.0
    return gray_norm, resized

def prepare_input_tensor(gray, hint_points):
    """Pack lightness and sparse colour hints into the 4-channel model input.

    Args:
        gray: (H, W) numpy array of lightness values (float32 in this file).
        hint_points: iterable of ``(x, y, (a, b))`` tuples, where x is the
            column and y the row. Coordinates outside the image are clamped
            to the nearest border pixel.

    Returns:
        A (1, 4, H, W) tensor with channels [lightness, a hint, b hint, mask].
        The mask is 1 at hinted pixels and 0 elsewhere.
    """
    height, width = gray.shape
    lightness = torch.from_numpy(gray)[None, None]  # (1, 1, H, W), shares memory
    ab_channels = torch.zeros(1, 2, height, width)
    hint_mask = torch.zeros(1, 1, height, width)
    for col, row, chroma in hint_points:
        col = min(max(0, col), width - 1)
        row = min(max(0, row), height - 1)
        ab_channels[0, :, row, col] = torch.tensor(chroma)
        hint_mask[0, 0, row, col] = 1
    return torch.cat([lightness, ab_channels, hint_mask], dim=1)

def postprocess_output(L, ab):
    """Combine a lightness map with predicted ab channels into an RGB image.

    Args:
        L: (H, W) float array in [0, 1]. It is multiplied by 100 and used as
            the Lab L* channel.
        ab: model output tensor of shape (B, 2, h, w) with values in [-1, 1].
            Only the first batch element is used, and it is resized to
            (H, W) when h, w differ.

    Returns:
        A PIL RGB image of size W x H.
    """
    h, w = L.shape
    # Index the batch explicitly: squeeze() would also drop a spatial axis of
    # size 1 and break the per-channel loop below.
    ab_np = ab.detach().cpu().numpy()[0].astype(np.float32)
    ab_resized = np.stack([cv2.resize(channel, (w, h)) for channel in ab_np], axis=-1)
    lab = np.empty((h, w, 3), dtype=np.float32)
    # OpenCV's float Lab expects L in [0, 100] and a/b roughly in [-127, 127].
    lab[:, :, 0] = L * 100
    lab[:, :, 1:] = ab_resized * 110
    rgb = np.clip(cv2.cvtColor(lab, cv2.COLOR_LAB2RGB), 0, 1)
    # Round rather than truncate so that e.g. 0.999 maps to 255, not 254.
    return Image.fromarray(np.rint(rgb * 255).astype(np.uint8))

# Module-level list of (x, y, (a, b)) hints. add_user_hint appends to it;
# model_colorize and get_image_metrics read it.
user_hints = []

def clear_hints():
    """Forget every previously registered colour hint.

    Rebinds the module-level ``user_hints`` to a brand-new empty list. The
    old list object itself is left untouched.
    """
    global user_hints
    user_hints = list()

def add_user_hint(x, y, r, g, b):
    """Register a colour hint at pixel (x, y) in the module-level ``user_hints``.

    The RGB colour is converted to CIE Lab and only its chroma (a, b) is kept,
    divided by 128 so that typical values fall within [-1, 1].

    Args:
        x, y: pixel column and row. They are converted to int here and
            clamped to the image later by prepare_input_tensor.
        r, g, b: colour components in 0..255.
    """
    global user_hints
    rgb = np.array([[[r / 255.0, g / 255.0, b / 255.0]]], dtype=np.float32)
    # For float32 input OpenCV returns L in [0, 100] and a, b already centred
    # on 0 (roughly [-127, 127]). The +128 offset only applies to 8-bit
    # images, so subtracting 128 here would corrupt every hint's chroma.
    lab = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)[0, 0]
    # NOTE(review): postprocess_output decodes ab with a scale of 110, while
    # hints are encoded with 128. Confirm which scale the model expects.
    a = float(lab[1]) / 128
    b_ = float(lab[2]) / 128
    user_hints.append((int(x), int(y), (a, b_)))

def _parse_hex_color(color):
    """Parse a hex colour string into an ``(r, g, b)`` tuple of ints in 0..255.

    Accepts surrounding whitespace, optional leading '#' characters, and the
    forms ``rgb``, ``rrggbb`` and ``rrggbbaa`` (alpha is ignored).

    Raises:
        TypeError: if *color* is not a string (e.g. a NaN float).
        ValueError: if *color* is not a well-formed hex colour.
    """
    if not isinstance(color, str):
        raise TypeError(f"hint colour must be a string, got {type(color).__name__}")
    digits = color.strip().lstrip("#")
    if not digits or any(ch not in "0123456789abcdefABCDEF" for ch in digits):
        raise ValueError(f"invalid hex colour: {color!r}")
    if len(digits) == 3:
        # Shorthand '#f80' means '#ff8800'.
        digits = "".join(ch * 2 for ch in digits)
    elif len(digits) not in (6, 8):
        raise ValueError(f"invalid hex colour length: {color!r}")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def model_colorize(image, hint_points_df):
    """Colourise *image* using the hints listed in *hint_points_df*.

    Args:
        image: PIL image to colourise (passed to preprocess_image).
        hint_points_df: pandas DataFrame whose rows are ``[x, y, color]``,
            where ``color`` is a hex string such as ``"#ff0000"``; or None.
            Rows that cannot be parsed (blank cells, non-numeric coordinates,
            malformed colours) are skipped.

    Returns:
        The colourised PIL image produced by postprocess_output.

    Side effects:
        Replaces the module-level ``user_hints`` with the parsed hints.

    Hint coordinates are used as-is on the resized grid produced by
    preprocess_image. They are not rescaled from the uploaded image's size.
    """
    gray_norm, _ = preprocess_image(image)
    clear_hints()
    if hint_points_df is not None and not hint_points_df.empty:
        for row in hint_points_df.values.tolist():
            if len(row) != 3:
                continue
            x, y, color_hex = row
            try:
                # float() first so numeric strings such as "12.0" are accepted.
                x, y = int(float(x)), int(float(y))
                r, g, b = _parse_hex_color(color_hex)
            except (TypeError, ValueError, OverflowError):
                # Incomplete or malformed row: skip it, as before.
                continue
            add_user_hint(x, y, r, g, b)
    input_tensor = prepare_input_tensor(gray_norm, user_hints)
    with torch.no_grad():
        output_ab = model(input_tensor)
    return postprocess_output(gray_norm, output_ab)

def get_image_metrics(image):
    """Collect simple statistics about the uploaded image for display.

    Returns a dict with three entries:
    - the image resolution as ``"WxH"``;
    - the number of distinct grey levels after converting to PIL mode ``"L"``;
    - how many colour hints are currently registered in ``user_hints``.
    """
    gray_levels = np.array(image.convert("L"))
    metrics = {}
    metrics["Resolution"] = "{}x{}".format(image.width, image.height)
    metrics["Unique Gray Levels"] = np.unique(gray_levels).size
    metrics["Hint Points Count"] = len(user_hints)
    return metrics

def add_hint_row(hint_points):
    """Return *hint_points* with a default hint row ``(0, 0, "#ff0000")`` appended.

    ``None`` is treated as an empty table with columns x, y and color. The
    input frame is not modified, and the result has a fresh 0..n-1 index.
    """
    table = hint_points
    if table is None:
        table = pd.DataFrame(columns=["x", "y", "color"])
    default_hint = pd.DataFrame([[0, 0, "#ff0000"]], columns=table.columns)
    return pd.concat((table, default_hint), ignore_index=True)

def reset_hints():
    """Return a new, empty hint table with the columns x, y and color."""
    column_names = ["x", "y", "color"]
    return pd.DataFrame(columns=column_names)

with gr.Blocks() as demo:
    gr.Markdown("## Complex Interactive User-Guided Image Colorization")

    with gr.Row():
        img_in = gr.Image(label="Upload grayscale image", type="pil")
        color_out = gr.Image(label="Colorized Output")

    with gr.Row():
        point_input = gr.Dataframe(
            headers=["x", "y", "color"],
            datatype=["number", "number", "str"],
            label="Hint Locations and Colors",
            max_height=250,
            interactive=True
        )
        add_hint_btn = gr.Button("Add Hint")
        reset_btn = gr.Button("Reset Hints")

    add_hint_btn.click(add_hint_row, inputs=[point_input], outputs=[point_input])
    reset_btn.click(reset_hints, outputs=[point_input])

    run_btn = gr.Button("Colorize")
    resolution_lbl = gr.Label(value="Resolution: N/A")
    unique_gray_lbl = gr.Label(value="Unique Gray Levels: N/A")
    hint_count_lbl = gr.Label(value="Hint Points Count: 0")

    def run_colorization(image, points_df):
        """Colourise the uploaded image and refresh the three metric labels.

        Raises:
            gr.Error: if no image has been uploaded. Gradio shows the message
                in the UI instead of failing with an AttributeError deeper
                in preprocessing.
        """
        if image is None:
            raise gr.Error("Please upload an image before clicking Colorize.")
        out_img = model_colorize(image, points_df)
        metrics = get_image_metrics(image)
        # Keep the "Name: value" format of the labels' initial placeholders so
        # the three bare numbers remain distinguishable after a run.
        return (
            out_img,
            f"Resolution: {metrics['Resolution']}",
            f"Unique Gray Levels: {metrics['Unique Gray Levels']}",
            f"Hint Points Count: {metrics['Hint Points Count']}",
        )

    run_btn.click(
        fn=run_colorization,
        inputs=[img_in, point_input],
        outputs=[color_out, resolution_lbl, unique_gray_lbl, hint_count_lbl]
    )

# share=True also publishes the app on a temporary public URL. Disable it if
# the app should only be reachable locally. The guard keeps importing this
# module from starting a server; in a notebook __name__ is "__main__", so the
# app still launches there.
if __name__ == "__main__":
    demo.launch(share=True)


* Running on local URL:  http://127.0.0.1:7862
* Running on public URL: https://b3642216aa95a9bba7.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


