# Satellite Image Segmentation – Gradio Deployment

This notebook deploys the trained UNet segmentation model using Gradio for interactive inference.

In [None]:
import torch
import segmentation_models_pytorch as smp
import numpy as np
import cv2
import gradio as gr

In [None]:
# Pick the inference backend: an NVIDIA GPU (CUDA) if present, otherwise
# Apple-silicon MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Number of segmentation classes; must match the head of the checkpoint that
# is loaded into this architecture, or load_state_dict will reject it.
NUM_CLASSES = 5
# Backbone name handed to smp.Unet as `encoder_name`.
ENCODER = "resnet34"

def build_model():
    """Create the UNet architecture on ``device`` with randomly initialised weights.

    ``encoder_weights=None`` means no pretrained encoder weights are fetched;
    the trained parameters are expected to be loaded into the returned model
    afterwards via ``load_state_dict``.
    """
    unet_kwargs = dict(
        encoder_name=ENCODER,
        encoder_weights=None,
        in_channels=3,
        classes=NUM_CLASSES,
    )
    network = smp.Unet(**unet_kwargs)
    return network.to(device)

In [None]:
# Recreate the UNet skeleton, then restore the trained weights into it.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model = build_model()
_state_dict = torch.load("../models/best_model.pth", map_location=device)
model.load_state_dict(_state_dict)
del _state_dict  # weights are now copied into `model`; drop the temporary
# Switch to evaluation mode so layers such as BatchNorm/Dropout (if present)
# use their inference behaviour.
model.eval()

In [None]:
def preprocess_image(img, *, bgr=False):
    """Resize and normalise an image and build the model's input tensor.

    Args:
        img: image array. ``gr.Image(type="numpy")`` delivers it already in
            RGB channel order, which is the default assumed here. Grayscale
            ``(H, W)`` / ``(H, W, 1)`` and 4-channel (alpha) inputs are
            converted to 3-channel RGB. Values are assumed to be 0-255
            (uint8) — NOTE(review): confirm for non-8-bit sources.
        bgr: pass ``True`` for arrays in OpenCV's BGR order (e.g. from
            ``cv2.imread``) so they are converted to RGB first.

    Returns:
        ``(img, x)``:
        - ``img`` is the resized RGB image as float32 in [0, 1] with shape
          (256, 256, 3).
        - ``x`` is the same data as a (1, 3, 256, 256) tensor on ``device``.
    """
    # NOTE(review): assumes the model was trained on RGB input (e.g.
    # cv2.imread followed by BGR2RGB) — confirm against the training code.
    if img.ndim == 2 or img.shape[2] == 1:
        # Replicate the single channel into three.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:
        # Drop the alpha channel (and swap to RGB if the source is BGRA).
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB if bgr else cv2.COLOR_RGBA2RGB)
    elif bgr:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (256, 256))  # dsize is (width, height)
    img = img.astype(np.float32) / 255.0
    # HWC -> CHW, then add a leading batch axis of size 1.
    x = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
    return img, x.to(device)

In [None]:
# Display colour per class id as uint8 triplets: row i is the colour that
# decode_mask paints for pixels predicted as class i.
COLORS = np.asarray(
    [
        (60, 16, 152),    # class 0
        (132, 41, 246),   # class 1
        (110, 193, 228),  # class 2
        (254, 221, 58),   # class 3
        (226, 169, 41),   # class 4
    ],
    dtype=np.uint8,
)

In [None]:
def decode_mask(mask):
    """Turn a 2-D array of class ids into a colour image.

    Every pixel whose value equals a class id in ``range(NUM_CLASSES)`` gets
    that class's colour from ``COLORS``; any other value is left black.

    Args:
        mask: 2-D array of class ids (e.g. the argmax of the model logits).

    Returns:
        uint8 array of shape ``(rows, cols, 3)``.
    """
    rows, cols = mask.shape
    painted = np.zeros((rows, cols, 3), dtype=np.uint8)
    for class_id in range(NUM_CLASSES):
        selected = mask == class_id
        painted[selected] = COLORS[class_id]
    return painted

In [None]:
def predict(image):
    """Segment one uploaded image and return a colour overlay.

    Args:
        image: NumPy image from the ``gr.Image(type="numpy")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        uint8 array of shape (256, 256, 3): 60% of the resized input image
        blended with 40% of the class-colour mask.

    Raises:
        gr.Error: if no image was provided; Gradio shows the message in the UI.
    """
    if image is None:
        # Fail with a readable message instead of an OpenCV error deep in
        # preprocessing.
        raise gr.Error("Please upload an image first.")
    img, x = preprocess_image(image)
    with torch.no_grad():
        logits = model(x)  # (1, NUM_CLASSES, H, W)
        # Take the single batch item explicitly; unlike squeeze() this can
        # never drop a spatial axis of size 1.
        pred = torch.argmax(logits, dim=1)[0].cpu().numpy()
    color_mask = decode_mask(pred)
    overlay = (0.6 * img * 255 + 0.4 * color_mask).astype(np.uint8)
    return overlay

In [None]:
# Wire `predict` into a one-image-in / one-image-out web UI.
demo = gr.Interface(
    title="Satellite Image Semantic Segmentation",
    description="Upload a satellite image to get pixel-wise land-use segmentation.",
    fn=predict,
    inputs=gr.Image(type="numpy", label="Input Satellite Image"),
    outputs=gr.Image(type="numpy", label="Segmentation Output"),
)

In [None]:
# Start the Gradio app defined above so the model can be queried interactively.
demo.launch()