**Install & imports**

In [None]:
!pip install -q gradio albumentations opencv-python-headless torch torchvision

import os, cv2, torch, numpy as np, gradio as gr, torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2

**Mount Google Drive & configure paths**

In [None]:
from google.colab import drive

# Mount Google Drive (forcing a remount) so MODEL_PATH below can be checked and read.
drive.mount('/content/drive', force_remount=True)

# EDIT THIS if your path differs:
MODEL_PATH = "/content/drive/MyDrive/UMR_GAN/results/umr_pix2pix_inpaint_ssim_G.pth"

# Run on the GPU when one is attached to the runtime, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Quick sanity report: chosen device and whether the checkpoint path exists.
print("Device:", DEVICE)
print("Weights:", MODEL_PATH, "exists:", os.path.exists(MODEL_PATH))


Mounted at /content/drive
Device: cpu
Weights: /content/drive/MyDrive/UMR_GAN/results/umr_pix2pix_inpaint_ssim_G.pth exists: True


**Model (UNetGenerator in_c=2 → out_c=1) and load weights**

In [None]:
class UNetGenerator(nn.Module):
    """U-Net generator with eight stride-2 encoder stages and eight stride-2 decoder stages.

    Every encoder stage halves the spatial size and every decoder stage doubles
    it, so a 256x256 input is reduced to 1x1 at the bottleneck ``b`` and
    restored to 256x256 at the output. Each decoder stage after the first
    receives the previous decoder output concatenated (channel-wise) with the
    encoder feature map of matching resolution.

    Args:
        in_c: number of input channels (default 2).
        out_c: number of output channels (default 1).

    The output passes through ``Tanh`` and therefore lies in [-1, 1].
    """

    @staticmethod
    def _encoder_stage(c_in, c_out, use_bn=True):
        """Conv(k=4, s=2, p=1, no bias) -> optional BatchNorm -> LeakyReLU(0.2)."""
        layers = [nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False)]
        if use_bn:
            layers.append(nn.BatchNorm2d(c_out))
        layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
        return nn.Sequential(*layers)

    @staticmethod
    def _decoder_stage(c_in, c_out, use_dropout=False):
        """ConvTranspose(k=4, s=2, p=1, no bias) -> BatchNorm -> ReLU -> optional Dropout(0.5)."""
        layers = [
            nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        ]
        if use_dropout:
            layers.append(nn.Dropout(p=0.5))
        return nn.Sequential(*layers)

    def __init__(self, in_c=2, out_c=1):
        super().__init__()

        # (attribute name, in channels, out channels, use BatchNorm)
        encoder_plan = (
            ("d1", in_c, 64, False),
            ("d2", 64, 128, True),
            ("d3", 128, 256, True),
            ("d4", 256, 512, True),
            ("d5", 512, 512, True),
            ("d6", 512, 512, True),
            ("d7", 512, 512, True),
            ("b", 512, 512, False),  # no BN at 1×1
        )
        for attr, c_in, c_out, use_bn in encoder_plan:
            setattr(self, attr, self._encoder_stage(c_in, c_out, use_bn))

        # (attribute name, in channels incl. skip, out channels, use Dropout)
        decoder_plan = (
            ("u1", 512, 512, True),
            ("u2", 1024, 512, True),
            ("u3", 1024, 512, True),
            ("u4", 1024, 512, False),
            ("u5", 1024, 256, False),
            ("u6", 512, 128, False),
            ("u7", 256, 64, False),
        )
        for attr, c_in, c_out, use_dropout in decoder_plan:
            setattr(self, attr, self._decoder_stage(c_in, c_out, use_dropout))

        # Final upsampling to full resolution; 128 = 64 (u7) + 64 (d1 skip).
        self.out = nn.Sequential(
            nn.ConvTranspose2d(128, out_c, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode ``x``, then decode while concatenating the matching encoder features."""
        # Encoder: remember every stage output for the skip connections.
        skips = []
        h = x
        for stage in (self.d1, self.d2, self.d3, self.d4, self.d5, self.d6, self.d7):
            h = stage(h)
            skips.append(h)

        # Bottleneck and first decoder stage (no skip input).
        h = self.u1(self.b(h))

        # Remaining decoder stages consume skips from deepest (d7) to d2.
        decoders = (self.u2, self.u3, self.u4, self.u5, self.u6, self.u7)
        for stage, skip in zip(decoders, reversed(skips[1:])):
            h = stage(torch.cat([h, skip], 1))

        # Output layer uses the shallowest skip (d1).
        return self.out(torch.cat([h, skips[0]], 1))

# Load weights (handles a bare state_dict or a dict that wraps one).
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints you trust (or pass weights_only=True on PyTorch
# versions that support it).
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)

# Use the first wrapper key that is present; otherwise treat the checkpoint
# itself as the state_dict. The isinstance guard avoids calling dict methods
# on a checkpoint that is not a mapping.
state = ckpt
if isinstance(ckpt, dict):
    for _wrapper_key in ("G", "gen", "generator", "state_dict"):
        if _wrapper_key in ckpt:
            state = ckpt[_wrapper_key]
            break

G = UNetGenerator(in_c=2, out_c=1).to(DEVICE)

# strict=False tolerates key mismatches; the lists printed below must both be
# empty, otherwise some layers are still at their random initialisation.
missing, unexpected = G.load_state_dict(state, strict=False)
print("Loaded G. missing:", missing, "| unexpected:", unexpected)

# Inference mode: Dropout is disabled and BatchNorm uses its running statistics.
# The trailing semicolon keeps the full module repr out of the cell output.
G.eval();

Loaded G. missing: [] | unexpected: []


UNetGenerator(
  (d1): Sequential(
    (0): Conv2d(2, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (d2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (d3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (d4): Sequential(
    (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (d5): Sequential(
    (0): Conv2d(512, 512, kernel_si

**Pre/Post (resize→normalize to [-1,1] and denorm to [0,1])**

In [None]:
# Preprocessing pipeline: resize to 256x256, map 0..255 pixel values to
# [-1, 1] via (x/255 - 0.5) / 0.5, then convert to a torch tensor.
TF = A.Compose(
    transforms=[
        A.Resize(height=256, width=256),
        A.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

def to_gray(arr_rgb):
    """Convert an RGB image array to single-channel grayscale with OpenCV."""
    gray = cv2.cvtColor(arr_rgb, cv2.COLOR_RGB2GRAY)
    return gray

def denorm01(t):
    """Map a tensor from [-1, 1] to a NumPy array in [0, 1].

    Singleton dimensions are squeezed away, the values are rescaled with
    ``x * 0.5 + 0.5`` and anything outside [0, 1] is clipped.

    Args:
        t: torch tensor on any device; it may require grad.

    Returns:
        numpy.ndarray with the squeezed shape of ``t`` and values in [0, 1].
    """
    # detach() first: .numpy() raises on tensors that are part of an autograd graph.
    arr = t.detach().squeeze().cpu().numpy()
    return (arr * 0.5 + 0.5).clip(0, 1)


**Gradio app (Denoise or Inpaint with sketch mask)**

In [None]:
# --- UMR-GAN Gradio UI (denoise + inpaint) with noise match, median filter, TTA ---
import gradio as gr, numpy as np, cv2, torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Use the same preprocessing as training/eval
# Resize to 256x256, scale 0..255 pixels to [-1, 1], convert to a torch tensor.
TF = A.Compose(
    transforms=[
        A.Resize(height=256, width=256),
        A.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

def _editor_to_mask(sketch, H=256, W=256):
    """Parse ImageEditor output to a binary mask (H,W) in {0,1}."""
    if sketch is None:
        return np.zeros((H, W), dtype=np.uint8)
    arr = sketch
    if isinstance(arr, dict):                         # some gradio versions return dicts
        for key in ["composite", "image", "background"]:
            if key in arr and arr[key] is not None:
                arr = arr[key]; break
    arr = np.array(arr)
    if arr.ndim == 3 and arr.shape[2] == 4:          # RGBA -> alpha
        m = (arr[..., 3] > 0).astype(np.uint8)
    elif arr.ndim == 3:                               # RGB -> brightness
        m = (cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY) > 0).astype(np.uint8)
    else:                                             # grayscale
        m = (arr > 0).astype(np.uint8)
    return cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)

def _denorm01(t):   # [-1,1] -> [0,1] numpy
    return (t.squeeze().cpu().numpy()*0.5 + 0.5).clip(0,1)

def _run_G(x_cond):
    """Run the global generator ``G`` on ``x_cond`` without gradient tracking.

    Returns the generator output moved to the CPU.
    """
    with torch.no_grad():
        y = G(x_cond)
        return y.cpu()

def restore(img_rgb, mode, noise_sigma, use_median, use_tta, sketch):
    """Gradio callback: denoise or inpaint one image with the generator.

    Args:
        img_rgb: uploaded image as a NumPy array. RGB is the expected case;
            grayscale (H, W), (H, W, 1) and RGBA arrays are accepted too.
            Pixel values are assumed to be in 0..255 — TODO confirm for
            non-uint8 inputs.
        mode: "Denoise (zero mask)" uses an all-zero mask; any other value
            takes the mask from ``sketch``.
        noise_sigma: standard deviation of Gaussian noise added before
            inference; 0 or None disables it.
        use_median: apply a 3x3 median filter to the restored image.
        use_tta: average predictions over the four horizontal/vertical flip
            combinations.
        sketch: ImageEditor value, parsed by ``_editor_to_mask``.

    Returns:
        Tuple ``(model_input, mask, restored)`` of 256x256 arrays: the
        noisy/masked input in [0, 1], the mask as floats in {0, 1}, and the
        restored image in [0, 1].

    Raises:
        gr.Error: when no image has been uploaded.
    """
    # Clicking "Restore" without an upload passes None; report it in the UI
    # instead of failing inside OpenCV.
    if img_rgb is None:
        raise gr.Error("Please upload an MRI slice before pressing Restore.")

    # 1) to gray + resize to 256
    img = np.asarray(img_rgb)
    if img.ndim == 2:                      # already single-channel
        gray = img
    elif img.shape[2] == 1:                # (H, W, 1)
        gray = img[..., 0]
    elif img.shape[2] == 4:                # RGBA: drop alpha
        gray = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    gray = cv2.resize(gray, (256, 256), interpolation=cv2.INTER_AREA)

    # 2) add Gaussian noise (match training distribution); set to 0 if input already noisy
    if noise_sigma and noise_sigma > 0:
        noisy = np.clip(gray + np.random.normal(0, noise_sigma, gray.shape), 0, 255).astype(np.uint8)
    else:
        noisy = gray.copy()

    # 3) mask (zero for denoise; drawn for inpaint)
    if mode == "Denoise (zero mask)":
        M = np.zeros_like(noisy, dtype=np.uint8)
    else:
        M = _editor_to_mask(sketch, 256, 256)  # {0,1}

    # 4) preprocess to [-1,1] tensors
    noisy_t = TF(image=noisy)["image"]                  # (1,256,256)
    mask_t  = torch.from_numpy(M).float().unsqueeze(0)  # (1,256,256)

    # blank masked pixels to -1 (training convention)
    x_img = noisy_t.clone()
    x_img[mask_t.bool()] = -1.0

    # 5) condition tensor (B=1, C=2)
    x = torch.cat([x_img, mask_t], dim=0).unsqueeze(0).to(DEVICE)

    # 6) optional TTA (flip-average): flip the input, predict, undo the same
    #    flip on the prediction, then average. None = identity (no flip).
    if use_tta:
        flip_axes = [None, [3], [2], [2, 3]]
        outs = []
        for dims in flip_axes:
            if dims is None:
                outs.append(_run_G(x))
            else:
                outs.append(torch.flip(_run_G(torch.flip(x, dims=dims)), dims=dims))
        out = torch.mean(torch.stack(outs, dim=0), dim=0)
    else:
        out = _run_G(x)

    # 7) optional median filter to suppress pepper dots
    out_np = _denorm01(out[0])
    if use_median:
        # medianBlur is applied to an 8-bit image here, so quantise first and rescale after.
        out_np = cv2.medianBlur((out_np*255).astype(np.uint8), 3) / 255.0

    return (_denorm01(x_img), M.astype(float), out_np)

# Build the two-column UI: controls on the left, the three result images on the right.
with gr.Blocks() as demo:
    gr.Markdown("## UMR-GAN (pix2pix) — MRI Denoising & Inpainting")
    with gr.Row():
        # Left column: inputs passed to restore() in the order wired up in btn.click below.
        with gr.Column(scale=1):
            img_in = gr.Image(type="numpy", label="Upload MRI slice (auto-resized to 256×256)")
            # The first choice string is compared literally inside restore(); keep them in sync.
            mode   = gr.Radio(choices=["Denoise (zero mask)", "Inpaint (use drawn mask)"],
                              value="Denoise (zero mask)", label="Mode")
            noise_sigma = gr.Slider(0, 30, value=15, step=1, label="Add Gaussian noise σ (match training)")
            use_median  = gr.Checkbox(value=True,  label="Post-process: median 3×3 (reduce dots)")
            use_tta     = gr.Checkbox(value=False, label="TTA (flip-average)")
            # Mask canvas; its value is parsed by _editor_to_mask() when mode is not "Denoise (zero mask)".
            sketch      = gr.ImageEditor(label="Draw Mask (paint white where missing)")
            btn         = gr.Button("Restore")
        # Right column: the three arrays returned by restore(), shown as grayscale images.
        with gr.Column(scale=1):
            out1 = gr.Image(label="Input (masked/noisy)", image_mode="L", width=256, height=256)
            out2 = gr.Image(label="Mask",               image_mode="L", width=256, height=256)
            out3 = gr.Image(label="Restored",           image_mode="L", width=256, height=256)
    # Input order must match restore()'s parameter order; output order matches its return tuple.
    btn.click(fn=restore, inputs=[img_in, mode, noise_sigma, use_median, use_tta, sketch],
              outputs=[out1, out2, out3])

# share=True requests a temporary public URL for the app (see the launch log below).
demo.launch(share=True)


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://038ef6ea7de14ab509.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


