In [4]:
!pip install gradio


Defaulting to user installation because normal site-packages is not writeable




Collecting gradio
  Downloading gradio-5.31.0-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<25.0,>=22.0 (from gradio)
  Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.10.1 (from gradio)
  Downloading gradio_client-1.10.1-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting httpx>=0.24.1 (from gradio)
  Downloading httpx-0.28.1-py3-none-any.whl.metadata (7.1 kB)
Collecting huggingface-hub>=0.28.1 (from gradio)
  Downloading huggingface_hub-0.32.0-py3-none-any.whl.metadata (14 kB)
Collecting orjson~=3.0 (from gradio)
  Downloading orjson-3.10.18-cp312-cp312-win_amd64.whl.metadata (43 kB)
     ---------------------------------------- 0.0/43.0 kB ?

In [5]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.models.segmentation import deeplabv3_resnet101
from PIL import Image
import gradio as gr
import numpy as np

In [6]:
class UNetColorizer(nn.Module):
    """Small fully-convolutional colorization network.

    The network is a plain stack of 3x3 convolutions (padding 1) with ReLU
    activations between them; there is no down/up-sampling and there are no
    skip connections, so the spatial size of the input is preserved.

    Args:
        input_channels: Number of channels in the input tensor (default 4).
        output_channels: Number of channels produced by the last convolution
            (default 2). No activation is applied to the output.
    """

    def __init__(self, input_channels=4, output_channels=2):
        super().__init__()

        def conv3x3(in_ch, out_ch):
            # Same-size 3x3 convolution used by every stage of the network.
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

        # Stages are built in the order they are applied so that parameter
        # initialisation order and state_dict keys stay stable.
        encoder_layers = [
            conv3x3(input_channels, 64),
            nn.ReLU(),
            conv3x3(64, 128),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        middle_layers = [conv3x3(128, 128), nn.ReLU()]
        self.middle = nn.Sequential(*middle_layers)

        decoder_layers = [
            conv3x3(128, 64),
            nn.ReLU(),
            conv3x3(64, output_channels),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Apply encoder, middle and decoder stages in sequence to ``x``."""
        for stage in (self.encoder, self.middle, self.decoder):
            x = stage(x)
        return x

In [7]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained DeepLabV3 used only for inference. `weights="DEFAULT"` replaces the
# deprecated `pretrained=True` flag and selects the same default checkpoint.
seg_model = deeplabv3_resnet101(weights="DEFAULT").eval().to(device)

# Colorizer with its default 4 input channels / 2 output channels, in eval mode
# because it is only used for inference below.
# NOTE(review): no checkpoint is loaded in the cells shown here, so the
# colorizer runs with randomly initialised weights — confirm whether trained
# weights should be loaded before serving the demo.
color_model = UNetColorizer().to(device).eval()



In [None]:

# Persist the colorizer's current parameters to a local checkpoint file.
# NOTE(review): in the cells shown here `color_model` has been neither trained
# nor loaded from a checkpoint, so this writes its current (initial) weights.
colorizer_weights = color_model.state_dict()
torch.save(colorizer_weights, 'unet_colorizer.pth')
print("Model saved as 'unet_colorizer.pth'")

In [8]:

# Preprocessing applied to every uploaded image, in this order:
# resize to 256x256, reduce to a single grayscale channel, convert to a tensor.
_preprocessing_steps = [
    T.Resize(size=(256, 256)),
    T.Grayscale(num_output_channels=1),
    T.ToTensor(),
]
transform_input = T.Compose(_preprocessing_steps)

def get_segmentation_map(image_tensor):
    """Segment one image and return a normalised per-pixel class-index map.

    Args:
        image_tensor: Float tensor of shape (3, H, W) with values in [0, 1]
            (the caller passes a grayscale image replicated to 3 channels).

    Returns:
        CPU float tensor of shape (1, H, W). Each pixel holds the arg-max class
        index divided by the number of classes in the model's logits (21 for
        the pretrained model loaded in this notebook), so values lie in [0, 1).
    """
    # torchvision's pretrained weights expect ImageNet-normalised input; feeding
    # raw [0, 1] values degrades the predicted classes.
    mean = image_tensor.new_tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = image_tensor.new_tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    batch = ((image_tensor - mean) / std).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = seg_model(batch)['out']  # (1, num_classes, H, W)
        num_classes = logits.shape[1]
        # Index the batch dimension explicitly: a bare squeeze() would also
        # drop H or W when either of them is 1.
        seg = torch.argmax(logits[0], dim=0).float() / num_classes
    return seg.unsqueeze(0).cpu()


In [9]:
def _lab_to_rgb_uint8(lab):
    """Convert an (H, W, 3) CIELAB array (D65 white point) to uint8 sRGB."""
    lab = np.asarray(lab, dtype=np.float64)
    lightness, a_chan, b_chan = lab[..., 0], lab[..., 1], lab[..., 2]

    # Lab -> XYZ (inverse of the CIE f() companding function).
    fy = (lightness + 16.0) / 116.0
    fx = fy + a_chan / 500.0
    fz = fy - b_chan / 200.0
    delta = 6.0 / 29.0

    def f_inv(t):
        return np.where(t > delta, t ** 3, 3.0 * delta ** 2 * (t - 4.0 / 29.0))

    white_d65 = np.array([0.95047, 1.0, 1.08883])
    xyz = np.stack([f_inv(fx), f_inv(fy), f_inv(fz)], axis=-1) * white_d65

    # XYZ -> linear sRGB, clipped to the displayable gamut.
    xyz_to_rgb = np.array([
        [3.2404542, -1.5371385, -0.4985314],
        [-0.9692660, 1.8760108, 0.0415560],
        [0.0556434, -0.2040259, 1.0572252],
    ])
    rgb_linear = np.clip(xyz @ xyz_to_rgb.T, 0.0, 1.0)

    # Linear sRGB -> gamma-encoded sRGB.
    rgb = np.where(
        rgb_linear <= 0.0031308,
        12.92 * rgb_linear,
        1.055 * np.power(rgb_linear, 1.0 / 2.4) - 0.055,
    )
    return np.rint(np.clip(rgb, 0.0, 1.0) * 255.0).astype(np.uint8)


def colorize_image(input_pil):
    """Colorize an uploaded image and return the result as an RGB PIL image.

    The image is preprocessed with ``transform_input`` (single grayscale
    channel), segmented with ``get_segmentation_map``, and the colorizer
    predicts two chroma channels that are combined with the grayscale channel
    as lightness in CIELAB before conversion to sRGB.

    Args:
        input_pil: PIL image supplied by the Gradio input, or ``None`` when the
            form is submitted without an image.

    Returns:
        RGB ``PIL.Image`` with the colorized result, or ``None`` if no image
        was supplied.
    """
    if input_pil is None:
        # Gradio passes None when the user submits an empty image field.
        return None

    input_gray = transform_input(input_pil)       # (1, H, W), values in [0, 1]
    gray_rgb = input_gray.expand(3, -1, -1)       # grayscale replicated to 3 channels
    seg_map = get_segmentation_map(gray_rgb)      # (1, H, W)

    # The colorizer is built with its default of 4 input channels, so feed it the
    # 3 replicated grayscale channels plus the segmentation map. Concatenating
    # only the single gray channel with the map gives 2 channels and makes the
    # first convolution raise a channel-mismatch error.
    model_input = torch.cat([gray_rgb, seg_map], dim=0).unsqueeze(0).to(device)
    with torch.no_grad():
        ab_channels = color_model(model_input).cpu().squeeze(0)

    # Diagnostics only. These are summary statistics of the intermediate
    # tensors, not accuracy measurements, so they are labelled as what they are
    # and no pass/fail accuracy claim is printed.
    print(f"Mean normalized segmentation class index: {seg_map.mean().item():.4f}")
    print(f"Predicted ab-channel std: {ab_channels.std().item():.4f}")

    # Assemble CIELAB: L in [0, 100], a/b scaled by 110 and kept signed. Clipping
    # a/b to [0, 255] and casting to uint8 would discard all negative chroma.
    lightness = (input_gray * 100.0).clamp(0.0, 100.0)
    chroma = (ab_channels * 110.0).clamp(-128.0, 127.0)
    lab = torch.cat([lightness, chroma], dim=0).permute(1, 2, 0).numpy()

    return Image.fromarray(_lab_to_rgb_uint8(lab))

In [10]:
# Gradio UI: one image upload wired to `colorize_image`, one image output.
image_input = gr.Image(type="pil", label="Upload Grayscale Image")
image_output = gr.Image(type="pil", label="Colorized Image")

demo = gr.Interface(
    fn=colorize_image,
    inputs=image_input,
    outputs=image_output,
    title="Context-Aware Scene Colorization",
    description="Upload a grayscale image of a complex scene (e.g., cityscape, forest). The model will colorize it using contextual scene understanding.",
)

# Start the local web server for the demo.
demo.launch()


* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


