# INPAINTING & SUPER-RESOLUTION

In this project, I address two related problems in satellite imagery: inpainting (filling missing or corrupted regions) and single-image super-resolution (recovering fine details from low-resolution inputs). I first explore U-Net-based inpainting with adversarial training, which is a common and effective approach for structured image completion, before switching to diffusion-based inpainting when large-scale semantic consistency becomes the bottleneck. For resolution enhancement, I use a U-Net generator trained on bicubic upsampling residuals and a PatchGAN discriminator to encourage sharper textures, following established SRGAN-style methods.

The only required input is a folder named `train_images` containing the satellite images (12-band TIFFs). To use a Google Colab GPU, the folder can be placed on Google Drive, but this is optional. In addition, the checkpoint `unet_inpaint_residual_ckpt135.pt` can be downloaded to resume training of the inpainting model at epoch 135. Alternatively, training can be started from scratch (epoch 0); in that case the checkpoint-loading step has to be skipped.

Key sources:

https://www.mdpi.com/2073-8994/18/1/94

https://arxiv.org/pdf/1611.07004

https://arxiv.org/pdf/1609.04802

In [None]:
!pip install --quiet imagecodecs # Needed for reading TIFF files

In [None]:
import os
import glob
import tifffile
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from diffusers import StableDiffusionInpaintPipeline
from torch.utils.data import DataLoader, Dataset
from google.colab import drive, files
import shutil
from PIL import Image

In [None]:
""" Helper functions for converting the (H, W, 12) TIFF images to (H, W, 3) RGB """

def normalize_band(band):
    """Stretch a single band to the [0, 1] range for display.

    A 2-98 percentile stretch is computed over the strictly positive pixels
    only, so zero/no-data areas do not pull the lower bound down. NaNs are
    mapped to 0 first. A band with no positive pixel uses the range (0, 1);
    a band whose stretch range collapses (vmin == vmax) becomes all zeros.
    """
    cleaned = np.nan_to_num(band, nan=0.0)
    positive = cleaned > 0
    if positive.any():
        lo, hi = np.percentile(cleaned[positive], [2, 98])
    else:
        lo, hi = (0, 1)
    if hi == lo:
        return np.zeros_like(cleaned)
    return np.clip((cleaned - lo) / (hi - lo), 0, 1)

def prepare_rgb(img_tif):
    """
    Build a displayable true-colour composite from a multi-band image.

    img_tif: (H, W, 12) float; channels 3, 2 and 1 are used as R, G and B,
             each percentile-stretched independently by ``normalize_band``.
    returns: (H, W, 3) uint8
    """
    rgb_band_order = (3, 2, 1)  # R, G, B channel indices in the band stack
    channels = [normalize_band(img_tif[:, :, c]) for c in rgb_band_order]
    stacked = np.stack(channels, axis=-1)
    return (stacked * 255).astype(np.uint8)

def load_rgb_uint8(path):
    """Load an image file as an (H, W, 3) uint8 RGB array.

    Any alpha channel or palette is dropped by converting to RGB. The file is
    opened in a context manager so the handle is closed right away instead of
    lingering until garbage collection (this function is called for every
    image in a cached dataset).
    """
    with Image.open(path) as im:
        rgb = im.convert("RGB")  # convert() loads the pixels, so closing is safe
    return np.array(rgb, dtype=np.uint8)

def rgb_to_tensor(rgb_uint8):
    """Turn an (H, W, 3) uint8 RGB array into a float32 (3, H, W) tensor scaled to [0, 1]."""
    chw = torch.from_numpy(rgb_uint8).permute(2, 0, 1)
    return chw.float() / 255.0

def create_inpaint_mask(rgb_uint8, threshold=1):
    """
    Mark no-data holes in an RGB image.

    A pixel is a hole when all three channels are <= threshold.
    Returns an (H, W) uint8 mask with 255 for holes and 0 elsewhere.
    """
    is_hole = np.all(rgb_uint8[:, :, :3] <= threshold, axis=-1)
    return np.where(is_hole, 255, 0).astype(np.uint8)

def valid_pixel_mask_uint8(rgb_uint8, threshold=1):
    """
    Boolean (H, W) mask, True where a pixel is known (any channel > threshold).

    For the same threshold this is exactly the complement of the hole test
    "all channels <= threshold".
    """
    return np.any(rgb_uint8[:, :, :3] > threshold, axis=-1)



In [None]:
# Mount Google Drive (optional; only needed when the data lives on Drive, e.g. for Colab GPU runs)
DRIVE_ROOT = "/content/drive"
drive.mount(DRIVE_ROOT)
input_folder = os.path.join(DRIVE_ROOT, "MyDrive", "train_images", "train_images")  # raw 12-band TIFFs
output_folder = "/content/rgb_images"  # converted RGB PNGs on the Colab local disk
os.makedirs(output_folder, exist_ok=True)

In [None]:
"""Loops through to all images in input folder and converts them to (H, W, 3) png values"""
tif_files = sorted(glob.glob(os.path.join(input_folder, "train_*.tif")))
for tif_path in tif_files:
    # 12-band TIFF -> stretched (H, W, 3) uint8 composite -> PNG with the same stem.
    # matplotlib may store an alpha channel; load_rgb_uint8 drops it later via convert("RGB").
    png_name = os.path.basename(tif_path).replace(".tif", ".png")
    plt.imsave(os.path.join(output_folder, png_name), prepare_rgb(tifffile.imread(tif_path)))

print(f"All {len(tif_files)} images converted and saved to {output_folder}!")

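Some of the PNG images contain black holes (missing data), so these images have to be inpainted first. Before that, we archive the 3-channel RGB images as a zip file and download it, dropping the remaining spectral bands. The following cells read the PNGs from `/content/drive/MyDrive/rgb_images`, so the extracted folder must be uploaded to Google Drive. From now on, we only work with the RGB images.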

In [None]:
# Zip the converted PNGs (stored on the Colab local disk, not on Drive) and
# download the archive to this computer through the browser.
# NOTE: later cells read the PNGs from /content/drive/MyDrive/rgb_images, so the
# extracted folder has to be uploaded to Google Drive afterwards.
archive_path = shutil.make_archive(output_folder, "zip", output_folder)  # -> "<output_folder>.zip"
files.download(archive_path)


In [None]:
input_folder = "/content/drive/MyDrive/rgb_images"  # RGB PNGs re-uploaded to Drive

# Locate and load two example scenes (IndexError if a file is missing).
img20_path = glob.glob(os.path.join(input_folder, "train_20.png"))[0]
img25_path = glob.glob(os.path.join(input_folder, "train_25.png"))[0]
img20, img25 = (load_rgb_uint8(p) for p in (img20_path, img25_path))

fig, axes = plt.subplots(1, 2, figsize=(12, 6))
example_titles = ("train_20 (RGB input)", "train_25 (RGB input)")
for panel, picture, caption in zip(axes, (img20, img25), example_titles):
    panel.imshow(picture)
    panel.set_title(caption)
    panel.axis("off")
plt.tight_layout()
plt.show()

Some images, such as train_20, contain black invalid regions.
These holes must be inpainted before super-resolution; otherwise the SR model would sharpen and hallucinate detail along the hole edges.

# Inpainting pipeline
To train the inpainting model in a supervised way, we synthetically generate hole-like masks that mimic the missing regions observed in the satellite images. These masks are single connected blobs confined to valid pixels, so the U-Net learns to reconstruct realistic content from the surrounding context rather than from isolated pixel noise. Many modern inpainting methods explicitly handle irregular holes, so that models learn to fill large, non-rectangular missing regions by inferring content from neighbouring pixels.
https://www.mdpi.com/2073-8994/18/1/94

In [None]:
def conv_plus_conv(in_channels: int, out_channels: int):
    """
    Build the basic U-Net stage: two 3x3 same-padding convolutions, each
    followed by affine InstanceNorm and LeakyReLU(0.2).

    :param in_channels: channels entering the first convolution
    :param out_channels: channels produced by both convolutions
    :return: nn.Sequential of six layers (conv, norm, act, conv, norm, act)
    :source: https://www.kaggle.com/code/evgenia12/unet-ipynb
    """
    layers = []
    for conv_in in (in_channels, out_channels):
        layers += [
            nn.Conv2d(conv_in, out_channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        ]
    return nn.Sequential(*layers)

def bilinear_upsample_concat_conv(x, skip, conv_block):
    """Decoder step: resize, concatenate with the skip connection, convolve.

    x:          (N, Cx, Hx, Wx) decoder features, resized bilinearly to skip's (Hs, Ws)
    skip:       (N, Cs, Hs, Ws) encoder features from the matching resolution
    conv_block: module that accepts Cx + Cs input channels
    Resizing to the skip's exact size (instead of a fixed x2) keeps the tensors
    aligned when max-pooling floored an odd spatial size.
    """
    target_hw = skip.shape[-2:]
    resized = F.interpolate(x, size=target_hw, mode="bilinear", align_corners=False)
    return conv_block(torch.cat((resized, skip), dim=1))

class U_net_generator(nn.Module):
    """Four-level U-Net that maps (B, in_channels, H, W) to (B, 3, H, W).

    By default the input has 4 channels: corrupted RGB plus the hole mask.
    The output has no final activation; the training and inference code adds it
    as a residual inside the hole.
    Channel widths: 32 -> 64 -> 128 -> 256 -> bottleneck 512, mirrored in the decoder.
    """

    def __init__(self, in_channels=4):
        super().__init__()
        # Module registration order is kept fixed: it defines checkpoint keys,
        # parameter order for optimizers, and the order of random initialization.
        # Encoder
        self.down1 = conv_plus_conv(in_channels, 32)
        self.down2 = conv_plus_conv(32, 64)
        self.down3 = conv_plus_conv(64, 128)
        self.down4 = conv_plus_conv(128, 256)
        # Bottleneck
        self.bottleneck = conv_plus_conv(256, 512)
        # Decoder: each stage sees upsampled features concatenated with a skip
        self.up4 = conv_plus_conv(512 + 256, 256)
        self.up3 = conv_plus_conv(256 + 128, 128)
        self.up2 = conv_plus_conv(128 + 64, 64)
        self.up1 = conv_plus_conv(64 + 32, 32)
        # 1x1 projection to the 3 RGB output channels
        self.out = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=1)
        # Parameter-free 2x2 max-pool shared by all encoder transitions
        self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        skips = []
        h = x
        for encode in (self.down1, self.down2, self.down3, self.down4):
            h = encode(h)
            skips.append(h)          # keep full-resolution features for the decoder
            h = self.downsample(h)
        h = self.bottleneck(h)
        decoders = (self.up4, self.up3, self.up2, self.up1)
        for decode, skip in zip(decoders, reversed(skips)):
            h = bilinear_upsample_concat_conv(h, skip, decode)
        return self.out(h)

We use a PatchGAN discriminator, which has been shown to work well together with U-Net generators. Instead of producing a single real/fake score, the discriminator outputs a grid of per-patch logits, which encourages the generator to match local texture statistics rather than only global structure. This suits satellite inpainting, where a hole is best filled with "patches" that look like its neighbourhood.
https://arxiv.org/pdf/1611.07004


In [None]:
class PatchDiscriminator(nn.Module):
    """
    pix2pix-style PatchGAN critic.

    Returns a (B, 1, h, w) grid of logits, one per overlapping receptive-field
    patch, instead of a single score per image.
    """
    def __init__(self, in_channels=3, base=64):
        super().__init__()

        def down_block(c_in: int, c_out: int, stride: int):
            # 4x4 conv -> affine InstanceNorm -> LeakyReLU(0.2)
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1),
                nn.InstanceNorm2d(c_out, affine=True),
                nn.LeakyReLU(0.2, inplace=True),
            )

        # Stem without normalization, as is customary for the first GAN layer
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, base, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.layer2 = down_block(base, 2 * base, stride=2)
        self.layer3 = down_block(2 * base, 4 * base, stride=2)
        self.layer4 = down_block(4 * base, 8 * base, stride=1)
        # One logit per spatial position of the final feature map
        self.out = nn.Conv2d(8 * base, 1, kernel_size=4, stride=1, padding=1)

    def forward(self, x):
        # x: (B, in_channels, H, W) image
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.out):
            x = stage(x)
        return x


def d_hinge_loss(d_real, d_fake):
    """Hinge loss for the critic: mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))."""
    loss_real = F.relu(1.0 - d_real).mean()
    loss_fake = F.relu(1.0 + d_fake).mean()
    return loss_real + loss_fake


def g_hinge_loss(d_fake):
    """Generator side of the hinge GAN loss: minus the mean critic logit on fakes."""
    mean_logit = d_fake.mean()
    return -mean_logit


Since we train the U-Net in a supervised way, we need both an input image and a known target. To achieve this, we deliberately corrupt the original image by removing a connected region and ask the network to reconstruct it. The hole mask is produced by stochastic region growing from a random seed pixel, restricted to valid pixels and followed by a dilation, which yields connected, realistic missing regions for the model to learn from.

In [None]:
def build_connected_hole_mask(
    valid_mask_1hw,
    blob_steps=8000,
    blob_radius=12,
    target_cov=(0.03, 0.08),
    max_tries=5,
):
    """
    Build one connected hole mask inside the valid pixels, retrying until its
    coverage falls inside ``target_cov``.

    :param valid_mask_1hw: (1, H, W) float tensor, 1 = known pixel
    :param blob_steps: maximum number of growth iterations, forwarded to
        ``connected_blob_mask`` as ``grow_steps``
    :param blob_radius: dilation radius, forwarded as ``dilate_radius``
    :param target_cov: (low, high) accepted fraction of the whole image,
        measured on the final (dilated) mask
    :param max_tries: number of attempts before falling back to the last blob
    :return: (1, H, W) float tensor in {0, 1}
    :raises ValueError: if max_tries < 1
    """
    if max_tries < 1:
        raise ValueError("max_tries must be >= 1")

    blob_hw = None
    for _ in range(max_tries):
        # connected_blob_mask takes the (1,H,W) valid mask directly and derives
        # H, W and the device from it.
        blob_hw = connected_blob_mask(
            valid_mask_1hw,
            target_cov=target_cov,
            grow_steps=blob_steps,
            dilate_radius=blob_radius,
        )[0]

        # restrict to valid pixels (defensive; connected_blob_mask already does this)
        blob_hw = blob_hw * valid_mask_1hw[0]
        cov = blob_hw.mean().item()

        if target_cov[0] <= cov <= target_cov[1]:
            break

    # either an accepted blob or, as a fallback, the last attempt
    return blob_hw.unsqueeze(0)

@torch.no_grad()
def connected_blob_mask(
    valid_mask_1hw: torch.Tensor,     # (1,H,W) float {0,1}
    target_cov=(0.03, 0.08),          # growth target as a fraction of the whole image, checked BEFORE dilation
    grow_steps=256,                   # maximum number of growth iterations
    p_add=0.35,                       # probability of accepting each frontier pixel per iteration
    dilate_radius=6,                  # radius of the final square dilation (0 disables it)
):
    """
    Build ONE connected blob inside the valid pixels by stochastic region growing.

    Starting from a random valid seed pixel, each iteration takes the
    8-connected frontier of the blob (adjacent pixels that are valid and not yet
    in the blob) and adds each frontier pixel independently with probability
    ``p_add``. Growth stops when the blob reaches a coverage drawn uniformly
    from ``target_cov``, when no valid frontier pixel is left, or after
    ``grow_steps`` iterations. The blob is then dilated with a
    (2r+1)x(2r+1) square and restricted to the valid pixels again.

    Note: coverage is checked before dilation, so the returned mask usually
    covers more of the image than ``target_cov[1]``. The final restriction to
    valid pixels may, in principle, split the dilated blob.

    Returns: (1,H,W) float {0,1}
    Raises: RuntimeError if valid_mask_1hw contains no valid pixel.
    """
    vm = valid_mask_1hw[0]  # (H,W)
    H, W = vm.shape
    device = vm.device

    # pick this blob's pre-dilation coverage target inside the band
    tgt = float(torch.empty(1, device=device).uniform_(target_cov[0], target_cov[1]).item())
    desired = tgt * H * W

    ys, xs = torch.where(vm > 0.5)
    if ys.numel() == 0:
        raise RuntimeError("No valid pixels to start blob")

    # seed at one uniformly chosen valid pixel
    i = torch.randint(0, ys.numel(), (1,), device=device).item()
    y0, x0 = ys[i].item(), xs[i].item()

    blob = torch.zeros((H, W), device=device, dtype=torch.float32)
    blob[y0, x0] = 1.0

    # stochastic region growing (vectorized, one frontier per iteration)
    for _ in range(grow_steps):
        # 3x3 max-pool marks the blob plus its 8-neighbourhood
        neigh = F.max_pool2d(blob[None,None], kernel_size=3, stride=1, padding=1)[0,0]
        frontier = (neigh - blob).clamp(0, 1)          # neighbours not already in blob
        cand = frontier * vm                            # keep only valid pixels

        if cand.sum().item() == 0:
            break

        # accept each candidate independently with probability p_add
        add = (torch.rand((H, W), device=device) < p_add).float() * cand
        blob = (blob + add).clamp(0, 1)

        if blob.sum().item() >= desired:
            break

    # final thickness: square dilation via max-pooling
    if dilate_radius > 0:
        k = 2 * dilate_radius + 1
        blob = F.max_pool2d(blob[None,None], kernel_size=k, stride=1, padding=dilate_radius)[0,0]
        blob = (blob > 0).float()

    return (blob[None] * valid_mask_1hw).clamp(0, 1)   # (1,H,W)


To train the U-Net efficiently, we cache the dataset in memory instead of loading the images from disk in every iteration. An earlier uncached version worked the same way conceptually, but each epoch took about 20 seconds longer because every image had to be reloaded from disk. This class prepares supervised training pairs by returning the original image together with its valid-pixel mask; the corruption is applied later, during training.

In [None]:
class SatelliteInpaintDataCached:
    """
    In-memory store of RGB training images and their valid-pixel masks.

    No corruption happens here: ``prepare_sample`` returns the clean image and
    the mask of naturally known pixels; synthetic holes are added by the caller.
    """

    def __init__(self, image_paths, threshold=1, min_valid_frac=0.5):
        self.paths = image_paths
        self.threshold = threshold
        self.min_valid_frac = min_valid_frac

        # Decode every image once: list of (H, W, 3) uint8 arrays
        self.rgb_u8 = [load_rgb_uint8(p) for p in self.paths]

        # (1, H, W) float32 CPU masks, 1 = known pixel
        self.valid_masks = []
        for rgb in self.rgb_u8:
            vb = valid_pixel_mask_uint8(rgb, threshold=self.threshold)
            vm = torch.from_numpy(vb.astype(np.float32)).unsqueeze(0)
            self.valid_masks.append(vm)

        # Known-pixel fraction per image and the indices usable for training,
        # computed once so prepare_sample can redirect in O(1) without recursion.
        self.valid_fracs = [vm.mean().item() for vm in self.valid_masks]
        self.eligible_indices = [
            i for i, frac in enumerate(self.valid_fracs) if frac >= self.min_valid_frac
        ]

    def corrupt_with_mask(self, x_chw, mask_1hw, mode="zero"):
        """
        Replace the pixels inside the hole mask and keep the rest unchanged.

        x_chw:    (C, H, W) image
        mask_1hw: (1, H, W) hole mask, 1 = hole
        mode:     "zero" fills holes with 0; "blur" fills them with a 31x31 box
                  blur of x_chw itself (so for a clean x_chw the fill carries a
                  low-pass copy of the hole's true content)
        Raises ValueError for any other mode.
        """
        if mode == "zero":
            fill = torch.zeros_like(x_chw)
        elif mode == "blur":
            # Blur fill was more effective in early epochs, when the model had
            # not yet learned how to fill black holes.
            fill = F.avg_pool2d(x_chw.unsqueeze(0), kernel_size=31, stride=1, padding=15)[0]
        else:
            raise ValueError(mode)
        return x_chw * (1.0 - mask_1hw) + fill * mask_1hw

    def prepare_sample(self, idx):
        """
        Return (x_gt, valid_mask) for image ``idx``: a (3, H, W) float tensor in
        [0, 1] and its (1, H, W) valid-pixel mask.

        If image ``idx`` has fewer known pixels than ``min_valid_frac``, a
        uniformly random eligible image is returned instead.
        Raises RuntimeError if no image is eligible.
        """
        if self.valid_fracs[idx] < self.min_valid_frac:
            if not self.eligible_indices:
                raise RuntimeError(
                    f"No image has a valid-pixel fraction >= {self.min_valid_frac}"
                )
            pick = torch.randint(0, len(self.eligible_indices), (1,)).item()
            idx = self.eligible_indices[pick]

        x_gt = rgb_to_tensor(self.rgb_u8[idx])
        return x_gt, self.valid_masks[idx]


In [None]:
class DataAdapter(torch.utils.data.Dataset):
    """Minimal torch Dataset wrapper around a cached-images object.

    Its length is the number of cached images (``data_obj.rgb_u8``), and
    indexing delegates to ``data_obj.prepare_sample``.
    """
    def __init__(self, data_obj):
        self.data = data_obj

    def __len__(self):
        cached_images = self.data.rgb_u8
        return len(cached_images)

    def __getitem__(self, idx):
        sample = self.data.prepare_sample(idx)
        return sample

# Loss functions
Early experiments with simpler losses and single-image overfitting were discarded, as they failed to capture irregular hole geometry and edge structure.

For the loss function of our inpainting model, we use the following components:

**Masked L1 loss**: ensures the model is only penalized inside the missing regions, so learning focuses on reconstructing the hole rather than altering known pixels.

**Gradient (x/y) loss**: encourages edge and structure preservation, since pixel-wise losses alone tend to blur sharp features.

**Inner boundary loss**: applies the gradient loss only on a thin ring just inside the hole boundary, so the transition between reconstructed and known pixels stays seamless.

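**Residual total loss:** a weighted sum of the above terms,
$$\mathcal{L}_{\text{rec}} = w_{\text{L1}}\,\mathcal{L}_{\text{L1}} + w_{\text{grad}}\,\mathcal{L}_{\text{grad}} + w_{\text{ring}}\,\mathcal{L}_{\text{ring}},$$
where the weights (defaults 1.0, 0.4 and 0.9) control how strongly each term contributes.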


In [None]:
def masked_l1_sum(pred, target, mask, eps=1e-8):
    """Mean absolute error restricted to the hole region.

    mask is (B, 1, H, W) and is repeated over pred's channels; the masked sum
    of |pred - target| is divided by (number of masked entries + eps).
    """
    n_ch = pred.shape[1]
    m = mask.repeat(1, n_ch, 1, 1)  # (B, C, H, W)
    numer = torch.sum(torch.abs(pred - target) * m)
    return numer / (m.sum() + eps)

def grad_xy_loss_sum(pred, target, mask, eps=1e-8):
    """
    Masked L1 distance between horizontal and vertical finite differences.

    The difference between columns j-1 and j (rows i-1 and i) is counted when
    the mask is set at column j (row i). Each direction is normalised by its
    masked-pixel count times the channel count; the two terms are summed.
    """
    n_ch = pred.shape[1]

    def _term(d_pred, d_target, m):
        return ((d_pred - d_target).abs() * m).sum() / (m.sum() * n_ch + eps)

    dx = _term(
        pred[:, :, :, 1:] - pred[:, :, :, :-1],
        target[:, :, :, 1:] - target[:, :, :, :-1],
        mask[:, :, :, 1:],
    )
    dy = _term(
        pred[:, :, 1:, :] - pred[:, :, :-1, :],
        target[:, :, 1:, :] - target[:, :, :-1, :],
        mask[:, :, 1:, :],
    )
    return dx + dy

def inner_boundary(mask, radius=3):
    """
    Thin ring of hole pixels next to the hole boundary.

    Computed as mask minus its erosion with a (2r+1)x(2r+1) square window; the
    erosion is 1 - maxpool(1 - mask). Max-pool padding ignores the area outside
    the image, so the image border itself does not create a ring.
    """
    window = 2 * radius + 1
    grown_known = F.max_pool2d(1.0 - mask, kernel_size=window, stride=1, padding=radius)
    eroded = 1.0 - grown_known
    return torch.clamp(mask - eroded, 0, 1)

def residual_inpaint_loss(filled, x_gt, mask, w_l1=1.0, w_grad=0.4, w_ring=0.9, ring_radius=3):
    """
    Weighted reconstruction objective for the inpainting generator.

    total = w_l1 * masked L1
          + w_grad * masked gradient loss
          + w_ring * gradient loss on the inner boundary ring of the hole
    Returns (total, l1, grad, ring_grad) so every term can be logged.
    """
    l1_term = masked_l1_sum(filled, x_gt, mask)
    grad_term = grad_xy_loss_sum(filled, x_gt, mask)
    ring_mask = inner_boundary(mask, radius=ring_radius)
    ring_term = grad_xy_loss_sum(filled, x_gt, ring_mask)
    total = w_l1 * l1_term + w_grad * grad_term + w_ring * ring_term
    return total, l1_term, grad_term, ring_term



In [None]:
def build_batch_masks(valid_mask_b1hw):
    """
    Sample one connected hole per image in a batch.

    valid_mask_b1hw: (B, 1, H, W) float {0,1}, 1 = known pixel.
    Returns (B, 1, H, W). Growth targets 3-7% of the image, but each blob is
    then dilated by 9 px, so the final coverage is typically above that band.
    """
    return torch.stack(
        [
            connected_blob_mask(
                single_mask,
                target_cov=(0.03, 0.07),
                grow_steps=256,
                p_add=0.35,
                dilate_radius=9,
            )
            for single_mask in valid_mask_b1hw
        ],
        dim=0,
    )

In [None]:
# Cache every RGB training image in memory and wrap it for the DataLoader.
images = sorted(glob.glob("/content/drive/MyDrive/rgb_images/train_*.png"))
print(len(images))  # number of images found (175 in the original run)

data = SatelliteInpaintDataCached(images, threshold=1, min_valid_frac=0.5)
ds = DataAdapter(data)

# Two worker processes kept alive across epochs; pinned host memory for GPU transfers.
loader_kwargs = dict(
    batch_size=4,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
)
dl = torch.utils.data.DataLoader(ds, **loader_kwargs)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



In [None]:
# Preview one synthetic training corruption on the first cached sample.
x_gt, valid_mask = (t.to(device) for t in data.prepare_sample(0))  # (3,H,W), (1,H,W)

# one connected hole mask, exactly as generated during training
with torch.no_grad():
    train_mask = build_batch_masks(valid_mask.unsqueeze(0))[0]  # (1,H,W)

# same blur fill as the training loop: 31x31 box blur of the clean image
fill = F.avg_pool2d(x_gt.unsqueeze(0), kernel_size=31, stride=1, padding=15)[0]
# known pixels are kept; pixels inside the hole take the blurred value
x_corrupt = x_gt * (1.0 - train_mask) + fill * train_mask


def _chw_to_u8(t):
    """(3,H,W) float tensor in [0,1] -> (H,W,3) uint8 array."""
    return (t.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)


gt = _chw_to_u8(x_gt)
msk = (train_mask[0].cpu().numpy() * 255).astype(np.uint8)
cor = _chw_to_u8(x_corrupt)

preview_panels = [
    (gt, {}, "GT"),
    (msk, {"cmap": "gray"}, f"Mask cov={msk.mean()/255:.3f}"),
    (cor, {}, "Corrupted (blur fill)"),
]
plt.figure(figsize=(18,6))
for slot, (image, imshow_kwargs, title) in enumerate(preview_panels, start=1):
    plt.subplot(1, 3, slot)
    plt.imshow(image, **imshow_kwargs)
    plt.title(title)
    plt.axis("off")
plt.show()


The example shows that the generated mask forms a connected missing region whose shape resembles the real holes in the data (the hole content is a blur of the image rather than black). With this verified, we proceed to train the GAN so the U-Net learns to reconstruct these regions from the corrupted inputs.

Training is resumed from a fixed checkpoint (unet_inpaint_residual_ckpt135.pt). Because GPU availability during development was limited, intermediate checkpoints were saved, and the model is now simply reloaded to continue training for a fixed number of additional epochs. For the final run, we always start from this checkpoint and extend training rather than retraining from scratch.

In [None]:
# Resume from a saved inpainting checkpoint. NOTE: this path is fixed; the
# training loop below writes a new checkpoint file that is not picked up here
# automatically — point ckpt_path at it to continue from there.
ckpt_path = "/content/drive/MyDrive/unet_inpaint_residual_ckpt135.pt"
ckpt = torch.load(ckpt_path, map_location=device)

# Generator: restore weights. Its optimizer starts fresh at lr=1e-3 (the
# training loop halves G's LR once on the first resumed epoch).
net = U_net_generator(in_channels=4).to(device)
net.load_state_dict(ckpt["model_state"])
optG = torch.optim.Adam(net.parameters(), lr=1e-3)

# Discriminator: restore it (and its optimizer) when the checkpoint contains
# one; older checkpoints without "disc_state" leave it randomly initialized.
D = PatchDiscriminator(in_channels=3, base=64).to(device)
optD = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.0, 0.9))
d_restored = ckpt.get("disc_state") is not None
if d_restored:
    D.load_state_dict(ckpt["disc_state"])
    if ckpt.get("optD_state") is not None:
        optD.load_state_dict(ckpt["optD_state"])

# AMP (torch.device(...) also accepts a plain "cuda"/"cpu" string)
use_amp = (torch.device(device).type == "cuda")
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
if use_amp and ckpt.get("scaler_state"):
    scaler.load_state_dict(ckpt["scaler_state"])

# Training state + config (fallbacks match the original loss weights)
start_epoch = ckpt.get("epoch", -1) + 1
global_step = ckpt.get("global_step", 0)

cfg = ckpt.get("cfg", {})
w_l1 = cfg.get("w_l1", 1.0)
w_grad = cfg.get("w_grad", 0.4)
w_ring = cfg.get("w_ring", 0.9)
ring_radius = cfg.get("ring_radius", 3)
w_gan = cfg.get("w_gan", 0.05)

# Train a bit more
epochs_more = 12
epochs_total = start_epoch + epochs_more

d_status = "restored from checkpoint" if d_restored else "reinitialized"
print(f"Resuming G from epoch {start_epoch} (+{epochs_more} epochs). D is {d_status}.")


For every batch:

1. a synthetic hole mask is created and the image is corrupted with a blurred fill,
2. the discriminator is trained to tell real images from inpainted ones, and
3. the generator is trained to reconstruct the missing regions.

After every epoch the model is saved, and every third epoch a result is visualized to track progress.

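Even though the original images contain black pixels, blur-based corruption is used during training so the model learns to use the surrounding context instead of relying on pixel intensity alone. Note, however, that the blur is computed from the ground-truth image, so the fill still carries low-frequency information about the hole's true content. That information is not available at inference time, where the holes are black. This train/test mismatch is a likely contributor to the residual blur visible in the results below.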

In [None]:
for epoch in range(start_epoch, epochs_total):
    net.train()
    D.train()

    # Running sums of per-batch losses, averaged and printed at epoch end
    run_G = run_rec = run_l1 = run_g = run_ring = run_gan = run_D = 0.0
    n_batches = 0

    # One-off LR halving for G on the first epoch of this resumed run;
    # later epochs keep the halved LR.
    if epoch == start_epoch:
        for pg in optG.param_groups:
            pg["lr"] *= 0.5

    pbar = tqdm(dl, desc=f"Epoch {epoch+1}/{epochs_total}", leave=True)

    for x_gt, valid_mask in pbar:
        # move to the training device
        x_gt       = x_gt.to(device, non_blocking=True)        # (B,3,H,W)
        valid_mask = valid_mask.to(device, non_blocking=True)  # (B,1,H,W)

        # one synthetic connected hole per image, inside the known pixels
        mask = build_batch_masks(valid_mask)                   # (B,1,H,W)

        # blur-fill corruption: 31x31 box blur of the clean image inside the hole
        fill = F.avg_pool2d(x_gt, kernel_size=31, stride=1, padding=15)
        x_corrupt = x_gt * (1.0 - mask) + fill * mask
        # generator input = corrupted image + mask
        x_in = torch.cat([x_corrupt, mask], dim=1)             # (B,4,H,W)

        # ---- Discriminator step ----
        D.requires_grad_(True)
        optD.zero_grad(set_to_none=True)

        with torch.no_grad():
            with torch.amp.autocast("cuda", enabled=use_amp):
                delta  = net(x_in)
                filled = (x_corrupt + delta * mask).clamp(0, 1)

        with torch.amp.autocast("cuda", enabled=use_amp):
            d_real = D(x_gt)
            d_fake = D(filled.detach())
            d_loss = d_hinge_loss(d_real, d_fake)

        scaler.scale(d_loss).backward()
        scaler.step(optD)

        # ---- Generator step ----
        # Freeze D so the adversarial term only produces gradients for G
        # (D's grads from this pass would be discarded by the next zero_grad anyway).
        D.requires_grad_(False)
        optG.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda", enabled=use_amp):
            delta  = net(x_in)
            filled = (x_corrupt + delta * mask).clamp(0, 1)
            # reconstruction + structure losses
            rec_total, l1, g, g_ring = residual_inpaint_loss(
                filled, x_gt, mask,
                w_l1=w_l1, w_grad=w_grad, w_ring=w_ring, ring_radius=ring_radius
            )
            # adversarial loss
            gan = g_hinge_loss(D(filled))
            total_g = rec_total + w_gan * gan

        scaler.scale(total_g).backward()
        scaler.step(optG)
        # GradScaler must be updated once per iteration, after every optimizer
        # that used it has stepped (here: optD and optG).
        scaler.update()
        D.requires_grad_(True)

        # read each scalar once (every .item() is a device sync)
        g_val, rec_val, gan_val, d_val = (
            total_g.item(), rec_total.item(), gan.item(), d_loss.item()
        )
        run_G    += g_val
        run_rec  += rec_val
        run_l1   += float(l1.item())
        run_g    += float(g.item())
        run_ring += float(g_ring.item())
        run_gan  += gan_val
        run_D    += d_val
        n_batches += 1
        global_step += 1

        pbar.set_postfix({
            "G":    f"{g_val:.3f}",
            "rec":  f"{rec_val:.3f}",
            "gan":  f"{gan_val:.3f}",
            "D":    f"{d_val:.3f}",
            "cov":  f"{mask.mean().item():.3f}",
        })

    print(
        f"Epoch {epoch+1}: "
        f"G={run_G/n_batches:.4f} rec={run_rec/n_batches:.4f} "
        f"l1={run_l1/n_batches:.4f} g={run_g/n_batches:.4f} ring={run_ring/n_batches:.4f} "
        f"gan={run_gan/n_batches:.4f} D={run_D/n_batches:.4f}"
    )

    # Save (overwrite) a checkpoint every epoch. The file is named after the
    # last epoch of this run (e.g. ..._ckpt147.pt when resuming at 135 with 12
    # more epochs) instead of a hard-coded number.
    ckpt_path = f"/content/drive/MyDrive/unet_inpaint_residual_ckpt{epochs_total}.pt"
    torch.save({
        "epoch": epoch,
        "global_step": global_step,
        "model_state": net.state_dict(),
        "optim_state": optG.state_dict(),
        "disc_state": D.state_dict(),
        "optD_state": optD.state_dict(),
        "scaler_state": scaler.state_dict() if use_amp else None,
        "cfg": {
            "in_channels": 4,
            "arch": "U_net_generator",
            "w_l1": w_l1, "w_grad": w_grad, "w_ring": w_ring, "ring_radius": ring_radius,
            "w_gan": w_gan,
            "threshold": getattr(data, "threshold", None),
        },
    }, ckpt_path)
    print("Saved:", ckpt_path)

    # quick visual every 3 epochs (fresh random batch and fresh random mask)
    if (epoch + 1) % 3 == 0:
        net.eval()
        with torch.no_grad():
            x_gt_v, valid_mask_v = next(iter(dl))
            x_gt_v       = x_gt_v.to(device)
            valid_mask_v = valid_mask_v.to(device)

            mask_v = build_batch_masks(valid_mask_v)
            fill_v = F.avg_pool2d(x_gt_v, kernel_size=31, stride=1, padding=15)
            x_corrupt_v = x_gt_v * (1.0 - mask_v) + fill_v * mask_v

            delta_v = net(torch.cat([x_corrupt_v, mask_v], dim=1))
            filled_v = (x_corrupt_v + delta_v * mask_v).clamp(0, 1)

            gt_img   = (x_gt_v[0].permute(1,2,0).cpu().numpy()*255).astype("uint8")
            msk_img  = (mask_v[0,0].cpu().numpy()*255).astype("uint8")
            cor_img  = (x_corrupt_v[0].permute(1,2,0).cpu().numpy()*255).astype("uint8")
            fill_img = (filled_v[0].permute(1,2,0).cpu().numpy()*255).astype("uint8")

        plt.figure(figsize=(24,6))
        plt.subplot(1,4,1); plt.imshow(gt_img); plt.title("GT"); plt.axis("off")
        plt.subplot(1,4,2); plt.imshow(msk_img, cmap="gray"); plt.title("Mask"); plt.axis("off")
        plt.subplot(1,4,3); plt.imshow(cor_img); plt.title("Corrupted"); plt.axis("off")
        plt.subplot(1,4,4); plt.imshow(fill_img); plt.title(f"Filled @ epoch {epoch+1}"); plt.axis("off")
        plt.show()
        net.train()

During inference, the trained model is applied to an image with observed holes by providing both the image and its hole mask. Since the real data contains zero-valued holes, inference is performed on zero-filled inputs (unlike the blur fill used during training). The network predicts a residual that fills only the missing regions, and this is then composited back onto the observed image.

In [None]:
@torch.no_grad()
def infer_inpaint_with_mask(
    net,
    rgb_uint8: np.ndarray,     # image with holes
    hole_mask_u8: np.ndarray,  # 255 = hole
    device,
    corrupt_mode="zero", #real data has black holes
):
    """
    Inpaint the holes of one RGB image with the residual U-Net.

    rgb_uint8:    (H, W, 3) uint8 observed image
    hole_mask_u8: (H, W) mask in {0, 255} or {0, 1}; nonzero = hole
    corrupt_mode: what the network sees inside the holes: "zero" (black) or
                  "blur" (31x31 box blur of the observed image, holes included)
    Returns (filled_u8, pred_u8), both (H, W, 3) uint8: the observed image
    with only the hole pixels replaced, and the corrupted input plus the
    masked residual. The network is put into eval mode and left there.
    Raises ValueError for an unknown corrupt_mode.
    """
    net.eval()

    image = torch.from_numpy(rgb_uint8).float().permute(2, 0, 1) / 255.0

    mask_np = hole_mask_u8.astype(np.float32)
    # accept both {0,255} and {0,1} encodings
    mask_np = mask_np / 255.0 if mask_np.max() > 1.0 else mask_np
    hole = torch.from_numpy(mask_np)[None].clamp(0, 1)

    if corrupt_mode == "zero":
        hole_fill = torch.zeros_like(image)
    elif corrupt_mode == "blur":
        hole_fill = F.avg_pool2d(image.unsqueeze(0), kernel_size=31, stride=1, padding=15)[0]
    else:
        raise ValueError("corrupt_mode must be 'blur' or 'zero'")

    corrupted = image * (1.0 - hole) + hole_fill * hole

    net_in = torch.cat([corrupted, hole], dim=0).unsqueeze(0).to(device)
    residual = net(net_in)[0].cpu()

    # residual is only applied inside the hole
    prediction = (corrupted + residual * hole).clamp(0, 1)
    # composite: observed pixels outside the hole, prediction inside
    composite = image * (1.0 - hole) + prediction * hole

    def to_u8(t):
        return (t.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

    return to_u8(composite), to_u8(prediction)


In [None]:
input_folder = "/content/drive/MyDrive/rgb_images"

# Apply the trained U-Net to a real image that contains black holes.
img20_path = glob.glob(os.path.join(input_folder, "train_20.png"))[0]
rgb20 = load_rgb_uint8(img20_path)                  # (H,W,3)
mask20_u8 = create_inpaint_mask(rgb20, threshold=1) # (H,W) {0,255}

filled_u8, pred_u8 = infer_inpaint_with_mask(
    net, rgb20, mask20_u8, device, corrupt_mode="zero"
)

result_panels = [
    (rgb20, "Input", {}),
    (mask20_u8, "Hole mask", {"cmap": "gray"}),
    (filled_u8, "Inpainted result", {}),
]
plt.figure(figsize=(18,6))
for slot, (image, title, imshow_kwargs) in enumerate(result_panels, start=1):
    plt.subplot(1, 3, slot)
    plt.imshow(image, **imshow_kwargs)
    plt.title(title)
    plt.axis("off")

plt.tight_layout()
plt.show()


# Inpainting results and moving forward
The inpainting results shown above are not what I initially expected. Even after more than 100 epochs of training, the blurred regions inside the holes do not fully disappear and remain visible. While U-Net–based inpainting has been shown to work well for similar image restoration tasks, in this case the required training complexity and model capacity make it difficult to achieve high-quality results within the scope of this project.

For this reason, I switch to a diffusion-based inpainting pipeline, which can more easily generate plausible content for larger missing regions without requiring extensive task-specific training.

Importantly, the work on the U-Net generator and discriminator is not lost. These models are well suited for single-image super-resolution, where the task is more localized and aligns strongly with their strengths. As a result, they are reused in the next stage of the pipeline. Since the diffusion inpainter operates at a fixed resolution, the images are first resized before applying super-resolution.

In [None]:
class SRSatelliteCached(Dataset):
    """Super-resolution dataset that keeps every decoded source image in RAM.

    Each item is a pair ``(up, hr)`` produced with ``TF.to_tensor``:

    * ``hr`` - the cached image bicubic-resized to ``hr_size`` x ``hr_size`` (target).
    * ``up`` - ``hr`` bicubic-downsampled to ``lr_size`` x ``lr_size`` and
      bicubic-upsampled back to ``hr_size`` x ``hr_size`` (the degraded baseline
      the generator refines).

    Args:
        image_paths: paths passed to ``load_rgb_uint8``; all are loaded eagerly.
        lr_size: side length of the synthetic low-resolution image.
        hr_size: side length of the target and of the network input.
    """

    def __init__(self, image_paths, lr_size=128, hr_size=512):
        self.paths = image_paths
        self.lr_size = lr_size
        self.hr_size = hr_size
        # Decode everything once so __getitem__ never touches the disk.
        self.rgb_u8 = list(map(load_rgb_uint8, self.paths))

    def __len__(self):
        return len(self.rgb_u8)

    def __getitem__(self, idx):
        hr_shape = (self.hr_size, self.hr_size)
        lr_shape = (self.lr_size, self.lr_size)

        hr_img = Image.fromarray(self.rgb_u8[idx]).resize(hr_shape, Image.BICUBIC)
        # Synthetic degradation HR -> LR -> HR, all bicubic.
        up_img = hr_img.resize(lr_shape, Image.BICUBIC).resize(hr_shape, Image.BICUBIC)

        return TF.to_tensor(up_img), TF.to_tensor(hr_img)

In [None]:
# Run Stable Diffusion on the GPU when available (half precision there,
# full precision on CPU).
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): the runwayml/* repositories were reportedly removed from the
# Hugging Face Hub; if the download fails, a mirror of the same weights is needed.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=(torch.float16 if device == "cuda" else torch.float32),
)
pipe = pipe.to(device)

# Source RGB images and the folder that receives the inpainted copies.
input_dir = "/content/drive/MyDrive/rgb_images"
output_dir = "/content/drive/MyDrive/inpainted_images"
os.makedirs(output_dir, exist_ok=True)

In [None]:
SD_SIZE        = 256 # side length the image/mask are resized to before being passed to SD
STEPS          = 30
GUIDANCE       = 6 # Balance between faithful to input image and prompt
SEED           = 42 # reproducible images

PROMPT = "satellite photo, realistic terrain, consistent lighting, natural textures"
NEG_PROMPT = "blurry, oversmooth, repeating patterns, artifacts, cartoon"
paths = sorted(glob.glob(os.path.join(input_dir, "train_*.png")))

for p in tqdm(paths):
    rgb = load_rgb_uint8(p)
    mask_u8 = create_inpaint_mask(rgb, threshold=1)  # (H,W) {0,255}; 255 = region SD repaints

    img_pil = Image.fromarray(rgb)

    # If there are no holes, nothing needs inpainting, so copy the image.
    # The check uses the full-resolution mask: after downsampling, small holes
    # could vanish and the image would be skipped with its holes intact.
    if not np.any(mask_u8):
        out_img = img_pil
    else:
        mask_pil = Image.fromarray(mask_u8, mode="L")

        # resize to SD_SIZE for SD
        img_sd = img_pil.resize((SD_SIZE, SD_SIZE), resample=Image.BICUBIC)
        # BOX-average, then threshold at > 0: every low-res pixel that overlaps any
        # hole pixel is marked as hole. NEAREST sampling could drop thin holes,
        # leaving corrupted pixels that SD would treat as valid context.
        mask_sd = mask_pil.resize((SD_SIZE, SD_SIZE), resample=Image.BOX).point(
            lambda v: 255 if v > 0 else 0
        )

        gen = torch.Generator(device=device).manual_seed(SEED)

        # NOTE(review): height/width are not passed, so depending on the diffusers
        # version the pipeline may run at its own default size rather than SD_SIZE;
        # the result is resized to the original size below either way.
        out_sd = pipe(
            prompt=PROMPT,
            negative_prompt=NEG_PROMPT,
            image=img_sd,
            mask_image=mask_sd,
            guidance_scale=GUIDANCE,
            num_inference_steps=STEPS,
            strength=1.0,
            generator=gen,
        ).images[0]

        # resize back to original resolution
        out_full = out_sd.resize(img_pil.size, resample=Image.BICUBIC)
        # Only the holes take generated content. Every known pixel keeps its
        # original full-resolution value instead of the blurred SD_SIZE round-trip
        # (these images later serve as HR targets for super-resolution).
        out_img = Image.composite(out_full, img_pil, mask_pil)

    out_name = os.path.basename(p).replace("train_", "inpaint_")
    out_path = os.path.join(output_dir, out_name)
    out_img.save(out_path)

print("Inpainting complete. Saved to:", output_dir)


In [None]:
input_dir = "/content/drive/MyDrive/rgb_images"
output_dir = "/content/drive/MyDrive/inpainted_images"

# Load each example's original and inpainted version before drawing anything.
examples = [
    (
        idx,
        load_rgb_uint8(os.path.join(input_dir, f"train_{idx}.png")),
        load_rgb_uint8(os.path.join(output_dir, f"inpaint_{idx}.png")),
    )
    for idx in (20, 25)
]

# 2x2 grid: one row per example, original on the left, inpainted on the right.
plt.figure(figsize=(12, 8))
for row, (idx, original, inpainted) in enumerate(examples):
    panels = (
        (original, f"train_{idx} (original rgb)"),
        (inpainted, f"inpaint_{idx} (inpainted)"),
    )
    for col, (image, title) in enumerate(panels):
        plt.subplot(2, 2, 2 * row + col + 1)
        plt.imshow(image)
        plt.title(title)
        plt.axis("off")

plt.tight_layout()
plt.show()


These results look better than those of my U-Net inpainting model, so I apply single-image super-resolution to the inpainted images. I reuse large parts of the U-Net generator and PatchGAN discriminator classes. For each image, I create a synthetic low-resolution version by bicubic downsampling and then bicubically upsample it back to 512x512. The U-Net generator learns a residual refinement of this upsampled image, while the PatchGAN discriminator encourages more realistic textures.

In [None]:
def conv_plus_conv(in_channels: int, out_channels: int):
    """Build two 3x3 conv + ReLU layers that keep the spatial size.

    The first conv maps ``in_channels`` -> ``out_channels``, the second
    ``out_channels`` -> ``out_channels``; both use stride 1 and padding 1.
    There is deliberately no normalization layer: Lim et al. (EDSR) report that
    batch norm hurts super-resolution quality: https://arxiv.org/pdf/1707.02921
    """
    layers = []
    for source_channels in (in_channels, out_channels):
        layers.append(
            nn.Conv2d(source_channels, out_channels, kernel_size=3, stride=1, padding=1)
        )
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

def bilinear_upsample_concat_conv(x, skip, conv_block):
    """U-Net decoder step: resize ``x`` to ``skip``'s size, concatenate, convolve.

    Args:
        x: decoder feature map, (N, Cx, Hx, Wx).
        skip: encoder skip feature map, (N, Cs, Hs, Ws).
        conv_block: module applied to the (N, Cx + Cs, Hs, Ws) concatenation.

    Returns:
        Whatever ``conv_block`` returns for the fused tensor.
    """
    target_hw = skip.shape[-2:]
    upsampled = F.interpolate(x, size=target_hw, mode="bilinear", align_corners=False)
    fused = torch.cat((upsampled, skip), dim=1)  # stack along the channel axis
    return conv_block(fused)

class UNetSRGenerator(nn.Module):
    """Three-level U-Net that predicts an RGB residual for super-resolution.

    Input: the bicubic-upsampled low-resolution image (3 channels).
    Output: a 3-channel residual ``delta`` at the input's spatial size. The
    network does not add the baseline itself; callers form ``up + delta``.

    Args:
        base: channel width of the first encoder stage; the deeper stages and
            the bottleneck use 2x, 4x and 8x this width.
    """

    def __init__(self, base=32):
        super().__init__()
        c1, c2, c3, c4 = base, base * 2, base * 4, base * 8

        # encoder. Creation order is kept fixed: it determines the random
        # initialisation sequence and the registration order of submodules.
        self.down1 = conv_plus_conv(3, c1)
        self.down2 = conv_plus_conv(c1, c2)
        self.down3 = conv_plus_conv(c2, c3)
        self.pool = nn.MaxPool2d(2)

        # bottleneck after three 2x poolings
        self.bottleneck = conv_plus_conv(c3, c4)

        # decoder: each stage consumes upsampled features + the matching skip
        self.up3 = conv_plus_conv(c4 + c3, c3)
        self.up2 = conv_plus_conv(c3 + c2, c2)
        self.up1 = conv_plus_conv(c2 + c1, c1)

        # 1x1 projection to the 3-channel residual
        self.out = nn.Conv2d(c1, 3, kernel_size=1)

    def forward(self, x):
        # encoder: keep every stage's output as a skip connection
        e1 = self.down1(x)
        e2 = self.down2(self.pool(e1))
        e3 = self.down3(self.pool(e2))

        h = self.bottleneck(self.pool(e3))

        # decoder: climb back up, fusing with the encoder features at each level
        for block, skip in ((self.up3, e3), (self.up2, e2), (self.up1, e1)):
            h = bilinear_upsample_concat_conv(h, skip, block)

        return self.out(h)


In [None]:
def grad_xy_loss(pred, target):
    """Mean absolute difference of horizontal and vertical finite differences.

    Forward differences are taken along the last (width) and second-to-last
    (height) axes of ``pred`` and ``target``. The result is the mean L1 error of
    the horizontal differences plus that of the vertical differences.
    """
    dx_err = torch.diff(pred, dim=-1) - torch.diff(target, dim=-1)
    dy_err = torch.diff(pred, dim=-2) - torch.diff(target, dim=-2)
    return dx_err.abs().mean() + dy_err.abs().mean()

def mixge_loss(sr, hr, lambda_g=0.1):
    """MixGE reconstruction loss: MSE plus a weighted image-gradient term.

    Returns:
        ``(total, mse, grad)`` with ``mse = F.mse_loss(sr, hr)``,
        ``grad = grad_xy_loss(sr, hr)`` and ``total = mse + lambda_g * grad``.
        The two components are returned separately so they can be logged.
    """
    pixel_term = F.mse_loss(sr, hr)
    edge_term = grad_xy_loss(sr, hr)
    total = pixel_term + lambda_g * edge_term
    return total, pixel_term, edge_term


In [None]:
sr_image_dir = "/content/drive/MyDrive/inpainted_images"
images = sorted(glob.glob(os.path.join(sr_image_dir, "*.png")))
print("Num images:", len(images))
if not images:
    # Fail early with a clear message; DataLoader(shuffle=True) would otherwise
    # raise an opaque "num_samples should be a positive integer" error.
    raise FileNotFoundError(f"No .png images found in {sr_image_dir}")

dataset = SRSatelliteCached(
    images,
    lr_size=128,   # 4× SR
    hr_size=512
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dl = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    # pinned host memory only speeds up host->GPU copies; skip it on CPU-only runtimes
    pin_memory=(device.type == "cuda"),
    persistent_workers=True
)


In [None]:
def to_u8(t3hw):
    """Convert a (C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 array.

    The channel-last layout is what ``plt.imshow`` expects. Values are clamped to
    [0, 1] and then rounded to the nearest of the 256 levels. Rounding rather than
    truncating avoids losing a level on values such as 0.999, and on k/255 values
    that come back as k - epsilon after float arithmetic.
    """
    t = t3hw.detach().float().clamp(0, 1).permute(1, 2, 0).cpu()
    return np.round(t.numpy() * 255).astype(np.uint8)

def corner_crop(t3hw, size=220, corner="tr"):
    """Return a ``size`` x ``size`` crop from one corner of a (C, H, W) tensor.

    Used to zoom in on a corner so resolution differences are easier to see.

    Args:
        t3hw: tensor indexed as (C, H, W).
        size: side length of the crop. If it exceeds H or W, the crop is limited
            to the full extent along that axis.
        corner: "tl", "tr", "bl" or "br" (top/bottom + left/right).

    Raises:
        ValueError: if ``corner`` is not one of the four codes. Previously this
            surfaced as an UnboundLocalError.
    """
    if corner not in ("tl", "tr", "bl", "br"):
        raise ValueError(f"corner must be one of 'tl', 'tr', 'bl', 'br', got {corner!r}")
    _, H, W = t3hw.shape
    # Clamp the start at 0. A negative start (size > H or W) would be read by
    # Python slicing as an offset from the end and yield a smaller, wrong crop.
    y0 = max(0, H - size) if corner[0] == "b" else 0
    x0 = max(0, W - size) if corner[1] == "r" else 0
    return t3hw[:, y0:y0 + size, x0:x0 + size]

In [None]:
# Mixed precision only makes sense on the GPU.
use_amp = device.type == "cuda"

# Residual U-Net generator and PatchGAN critic (created in this order).
net = UNetSRGenerator(base=32).to(device)
D = PatchDiscriminator(in_channels=3, base=64).to(device)

# Adam with beta1 = 0; the discriminator uses twice the generator's learning rate.
optG = torch.optim.Adam(net.parameters(), lr=1e-4, betas=(0.0, 0.9))
optD = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.0, 0.9))

# One loss scaler shared by both optimizers (disabled when AMP is off).
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

# Loss schedule
lambda_g = 0.1       # weight of the gradient term inside MixGE
warmup_epochs = 10   # reconstruction-only epochs before the GAN term switches on
w_gan_after = 0.003  # GAN weight once warm-up is over
epochs = 50

ckpt_path = "/content/drive/MyDrive/srgan_unet_final.pt"

# Pre-draw one shuffled batch for each visualisation epoch (every 5th), so
# the plots generally show different samples.
visualisation = {e: next(iter(dl)) for e in range(5, epochs + 1, 5)}

In [None]:
# SR GAN training loop. Per batch, the generator runs forward once and its
# output is reused for both updates:
#   * while epoch < warmup_epochs, w_gan == 0: D is not updated and G is trained
#     on the MixGE reconstruction loss only;
#   * afterwards D gets a hinge-loss step and G adds w_gan * hinge GAN loss.
# Every 5 epochs a pre-drawn batch from `visualisation` is plotted. A checkpoint is
# written every 5 epochs and after the final epoch.
for epoch in range(epochs):
    net.train()
    D.train()

    # GAN weight for this epoch (0 during warm-up, which also skips the D step)
    w_gan = 0.0 if epoch < warmup_epochs else w_gan_after

    # running sums of per-batch losses, averaged over n_batches after the epoch
    run = {"G":0.0, "D":0.0, "mse":0.0, "grad":0.0, "gan":0.0}
    n_batches = 0

    pbar = tqdm(dl, desc=f"Epoch {epoch+1}/{epochs} | w_gan={w_gan:.3f}", leave=True)

    for up, hr in pbar:
        up = up.to(device, non_blocking=True)  # bicubic upsampled LR (input)
        hr = hr.to(device, non_blocking=True)  # HR target

        # forward generator once (reuse for D + G)
        # the network predicts a residual: SR = bicubic input + delta, clamped to [0, 1]
        with torch.amp.autocast("cuda", enabled=use_amp):
            delta = net(up)
            sr = (up + delta).clamp(0, 1)
        # Patch Discriminator
        if w_gan > 0:
            optD.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", enabled=use_amp):
                d_real = D(hr)
                # detached: the D update must not backpropagate into the generator
                d_fake = D(sr.detach())
                d_loss = d_hinge_loss(d_real, d_fake)

            scaler.scale(d_loss).backward()
            # shared GradScaler: each optimizer is stepped through it and
            # scaler.update() is called once per iteration, after the G step
            scaler.step(optD)
        else:
            # placeholder so the logging below also works during warm-up
            d_loss = hr.new_tensor(0.0)

        # U-net generator
        optG.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda", enabled=use_amp):
            rec, mse, g = mixge_loss(sr, hr, lambda_g=lambda_g)
            # gan loss
            # D(sr) is evaluated with the just-updated D. The backward below also
            # accumulates grads in D's parameters; they are never stepped and are
            # cleared by optD.zero_grad(set_to_none=True) before the next D step.
            gan = g_hinge_loss(D(sr)) if w_gan > 0 else sr.new_tensor(0.0)
            total_g = rec + w_gan * gan

        scaler.scale(total_g).backward()
        scaler.step(optG)
        scaler.update()

        # NOTE(review): every .item() synchronises with the device, and each loss is
        # read twice here (running sums + progress bar)
        run["G"]    += float(total_g.item())
        run["D"]    += float(d_loss.item())
        run["mse"]  += float(mse.item())
        run["grad"] += float(g.item())
        run["gan"]  += float(gan.item())
        n_batches += 1

        pbar.set_postfix({
            "G": f"{total_g.item():.3f}",
            "D": f"{d_loss.item():.3f}",
            "mse": f"{mse.item():.3f}",
            "grad": f"{g.item():.3f}",
            "gan": f"{gan.item():.3f}",
        })

    # per-epoch averages (max(1, ...) guards against an empty loader)
    for k in run:
        run[k] /= max(1, n_batches)

    print(
        f"Epoch {epoch+1}: "
        f"G={run['G']:.4f} D={run['D']:.4f} "
        f"mse={run['mse']:.4f} grad={run['grad']:.4f} gan={run['gan']:.4f}"
    )

    # plot an example every 5 epochs
    # (net is left in eval mode here; the next epoch switches it back with net.train())
    if (epoch + 1) % 5 == 0:
        net.eval()
        up_v, hr_v = visualisation[epoch + 1]
        up_v = up_v.to(device, non_blocking=True)
        hr_v = hr_v.to(device, non_blocking=True)

        with torch.no_grad():
            with torch.amp.autocast("cuda", enabled=use_amp):
                sr_v = (up_v + net(up_v)).clamp(0, 1)

        # zoom in image
        # first sample of the batch, as float32 on the CPU for plotting
        up0 = up_v[0].float().cpu()
        hr0 = hr_v[0].float().cpu()
        sr0 = sr_v[0].float().cpu()

        # Synthetic LR view
        # approximated by bicubic-downsampling the bicubic input up0 to 128x128;
        # it is not the exact LR image the dataset produced
        lr_vis = F.interpolate(
            up0.unsqueeze(0), size=(128, 128),
            mode="bicubic", align_corners=False
        )[0].cpu()

        corner = "tr"
        zoom_size = 220
        up_z = corner_crop(up0, size=zoom_size, corner=corner)
        hr_z = corner_crop(hr0, size=zoom_size, corner=corner)
        sr_z = corner_crop(sr0, size=zoom_size, corner=corner)

        # |SR - bicubic| in the zoomed crop, normalised to [0, 1] for display
        diff = (sr_z - up_z).abs()
        diff = diff / (diff.max() + 1e-8)

        plt.figure(figsize=(22,7))

        plt.subplot(2,4,1); plt.imshow(to_u8(hr0));    plt.title("HR (target)"); plt.axis("off")
        plt.subplot(2,4,2); plt.imshow(to_u8(lr_vis)); plt.title("Synthetic LR (downsampled)"); plt.axis("off")
        plt.subplot(2,4,3); plt.imshow(to_u8(up0));    plt.title("Bicubic upsample (input)"); plt.axis("off")
        plt.subplot(2,4,4); plt.imshow(to_u8(sr0));    plt.title(f"SR output (epoch {epoch+1})"); plt.axis("off")

        plt.subplot(2,4,5); plt.imshow(to_u8(hr_z));   plt.title(f"Zoom {corner}: HR"); plt.axis("off")
        plt.subplot(2,4,6); plt.imshow(to_u8(up_z));   plt.title(f"Zoom {corner}: Bicubic"); plt.axis("off")
        plt.subplot(2,4,7); plt.imshow(to_u8(sr_z));   plt.title(f"Zoom {corner}: SR"); plt.axis("off")
        plt.subplot(2,4,8); plt.imshow(to_u8(diff));   plt.title("|SR - Bicubic| (zoom)"); plt.axis("off")

        plt.tight_layout()
        plt.show()

    # Checkpoints and save backup every 5 epochs or at end
    # The same ckpt_path is overwritten each time, so only the latest checkpoint is
    # kept. "epoch" is stored 0-based; optimizer/scaler states are included,
    # presumably so training can be resumed.
    if (epoch + 1) % 5 == 0 or (epoch + 1) == epochs:
        torch.save({
            "epoch": epoch,
            "model_state": net.state_dict(),
            "disc_state": D.state_dict(),
            "optG_state": optG.state_dict(),
            "optD_state": optD.state_dict(),
            "scaler_state": scaler.state_dict() if use_amp else None,
            "cfg": {
                "lambda_g": lambda_g,
                "w_gan_after": w_gan_after,
                "warmup_epochs": warmup_epochs,
                "base": 32,
                "lr_size": 128,
                "hr_size": 512,
            }
        }, ckpt_path)
        print("Checkpoint saved:", ckpt_path)


In [None]:
net.eval()

input_folder = "/content/drive/MyDrive/inpainted_images"
paths = [
    os.path.join(input_folder, "inpaint_20.png"),
    os.path.join(input_folder, "inpaint_25.png"),
]

corner = "bl"
zoom_size = 220
hr_size, lr_size = 512, 128  # same sizes the SR dataset was built with

for p in paths:
    # HR "target" (the image we compare to); the context manager closes the file
    with Image.open(p) as im:
        hr_pil = im.convert("RGB").resize((hr_size, hr_size), Image.BICUBIC)
    hr0 = TF.to_tensor(hr_pil)

    # Synthetic LR view: downsample the PIL image directly, exactly as
    # SRSatelliteCached does. Going through to_tensor -> to_pil_image first is a
    # float round-trip that truncates and can shift pixels by one grey level, so
    # the input would no longer match what the model was trained on.
    lr_pil = hr_pil.resize((lr_size, lr_size), Image.BICUBIC)
    lr_vis = TF.to_tensor(lr_pil)
    # bicubic upsample back to HR size: the network input / baseline
    up0 = TF.to_tensor(lr_pil.resize((hr_size, hr_size), Image.BICUBIC))

    up_b = up0.unsqueeze(0).to(device, non_blocking=True)
    with torch.no_grad():
        with torch.amp.autocast("cuda", enabled=use_amp):
            sr0 = (up_b + net(up_b)).clamp(0, 1)[0].float().cpu()

    # zoom + diff
    up_z = corner_crop(up0, size=zoom_size, corner=corner)
    hr_z = corner_crop(hr0, size=zoom_size, corner=corner)
    sr_z = corner_crop(sr0, size=zoom_size, corner=corner)

    # |SR - bicubic| in the zoomed crop, normalised to [0, 1] for display
    diff = (sr_z - up_z).abs()
    diff = diff / (diff.max() + 1e-8)

    plt.figure(figsize=(22,7))

    plt.subplot(2,4,1); plt.imshow(to_u8(hr0));    plt.title("HR (target)"); plt.axis("off")
    plt.subplot(2,4,2); plt.imshow(to_u8(lr_vis)); plt.title("Synthetic LR (downsampled)"); plt.axis("off")
    plt.subplot(2,4,3); plt.imshow(to_u8(up0));    plt.title("Bicubic upsample (input)"); plt.axis("off")
    plt.subplot(2,4,4); plt.imshow(to_u8(sr0));    plt.title("SR output"); plt.axis("off")

    plt.subplot(2,4,5); plt.imshow(to_u8(hr_z));   plt.title(f"Zoom {corner}: HR"); plt.axis("off")
    plt.subplot(2,4,6); plt.imshow(to_u8(up_z));   plt.title(f"Zoom {corner}: Bicubic"); plt.axis("off")
    plt.subplot(2,4,7); plt.imshow(to_u8(sr_z));   plt.title(f"Zoom {corner}: SR"); plt.axis("off")
    plt.subplot(2,4,8); plt.imshow(to_u8(diff));   plt.title("|SR - Bicubic| (zoom)"); plt.axis("off")

    plt.suptitle(os.path.basename(p), y=1.02)
    plt.tight_layout()
    plt.show()


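At full resolution, the SR output looks very similar to the bicubic input, so the improvements are subtle in the overall image. When zooming into a corner, the SR result shows slightly sharper edges and more defined small structures. This is also visible in |SR - Bicubic|, where most updates are concentrated around edges rather than in flat regions. Note that both examples are also part of the SR training set (all inpainted images are used for training), so this is a qualitative sanity check rather than an evaluation on unseen images.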