**Setup and Imports**

In [None]:
!pip install torch torchvision albumentations torchmetrics lpips matplotlib tqdm

Collecting torchmetrics
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lpips-0.1.4-py3-none-any.whl (53 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Installing collected packages: lightning-utilities, torchmetrics, lpips
Successfully installed lightning-utilities-0.15.2 lpips-0.1.4 torchmetrics-1.8.2


In [None]:
# Mount Google Drive (Colab only) so the dataset folder and the results directory
# used by later cells under /content/drive are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn  # in case nn wasn't in scope

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

Using device: cuda


In [None]:
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2, os, glob, numpy as np, matplotlib.pyplot as plt, json
from tqdm import tqdm
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

**Dataset Loader**

In [None]:
import glob, os, cv2, numpy as np, torch, random
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2

class MRIDataset(Dataset):
    """Paired (degraded condition, clean target) grayscale MRI images for pix2pix-style training.

    For every image file found under ``root``:
      * the clean image is resized to ``img_size`` x ``img_size`` and normalized to
        [-1, 1] (to match the generator's Tanh output);
      * a noisy copy gets additive Gaussian noise with std ``noise_std`` (0-255 units);
      * with probability ``inpaint_prob`` random rectangles and short strokes are cut
        out of the noisy copy (filled with -1) and the binary hole mask is returned.

    ``__getitem__`` returns ``(x_cond, y_tgt)``: ``x_cond`` is ``(2, H, W)`` =
    [masked noisy image, mask] and ``y_tgt`` is the ``(1, H, W)`` clean image.
    """

    # Extensions collected recursively under ``root`` (lower- and upper-case variants).
    IMAGE_EXTS = ("png", "jpg", "jpeg")

    def __init__(self, root, inpaint_prob=0.5, noise_std=15.0, img_size=256):
        """
        Args:
            root: directory searched recursively for images.
            inpaint_prob: probability that a sample receives a random inpainting mask.
            noise_std: std of the additive Gaussian noise, in 0-255 pixel units.
            img_size: side length images are resized to (the generator in this
                notebook is built for 256).
        """
        found = set()
        for ext in self.IMAGE_EXTS:
            for variant in (ext, ext.upper()):
                pattern = os.path.join(root, "**", f"*.{variant}")
                found.update(glob.glob(pattern, recursive=True))
        # Union of all extensions: a png-only search with a jpg fallback would silently
        # drop every .jpg as soon as any .png exists under root. Sorted so sample
        # indices are stable across runs and filesystems.
        self.files = sorted(found)
        print(f"Found {len(self.files)} images in {root}")

        self.inpaint_prob = inpaint_prob  # fraction of samples that get a mask
        self.noise_std = noise_std

        # Normalize to [-1,1] to match Tanh
        self.tf = A.Compose([
            A.Resize(img_size, img_size),
            A.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=255.0),
            ToTensorV2(),
        ])

    def _random_mask(self, h, w):
        """Return an (h, w) uint8 mask where 1 marks a hole.

        Holes are 1-3 rectangles (each side between 1/8 and 1/3 of the image) plus
        10-20 short horizontal strokes of up to 4x16 pixels (clipped at the borders).
        """
        M = np.zeros((h, w), dtype=np.uint8)
        # rectangles
        for _ in range(random.randint(1, 3)):
            rh, rw = random.randint(h//8, h//3), random.randint(w//8, w//3)
            ry, rx = random.randint(0, h-rh), random.randint(0, w-rw)
            M[ry:ry+rh, rx:rx+rw] = 1
        # short strokes
        for _ in range(random.randint(10, 20)):
            y, x = random.randint(0, h-1), random.randint(0, w-1)
            M[max(0,y-2):min(h,y+2), max(0,x-8):min(w,x+8)] = 1
        return M

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Load file ``idx`` and return ``(x_cond, y_tgt)``.

        Raises:
            OSError: if the file cannot be decoded as an image.
        """
        path = self.files[idx]
        clean_u8 = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if clean_u8 is None:
            # cv2.imread returns None (no exception) for missing/corrupt/unsupported files.
            raise OSError(f"Could not read image: {path}")

        # Additive Gaussian noise in 0-255 units.
        # NOTE(review): noise is added at the file's native resolution, before the resize
        # in self.tf, so the effective noise level at img_size depends on each file's
        # original size — confirm this is intended.
        noisy_f = clean_u8 + np.random.normal(0, self.noise_std, clean_u8.shape)
        noisy_u8 = np.clip(noisy_f, 0, 255).astype(np.uint8)

        clean = self.tf(image=clean_u8)["image"]   # (1,H,W) in [-1,1]
        noisy = self.tf(image=noisy_u8)["image"]   # (1,H,W) in [-1,1]

        # optional inpainting mask
        if random.random() < self.inpaint_prob:
            H, W = clean.shape[1], clean.shape[2]
            M_u8 = self._random_mask(H, W)                      # 0/1
            M = torch.from_numpy(M_u8).float().unsqueeze(0)     # (1,H,W)
            masked_noisy = noisy.clone()
            masked_noisy[M.bool()] = -1.0                       # fill holes with -1
        else:
            M = torch.zeros_like(noisy)                         # no holes
            masked_noisy = noisy

        # condition = [masked_noisy, mask]; target = clean
        x_cond = torch.cat([masked_noisy, M], dim=0)            # (2,H,W)
        y_tgt  = clean                                          # (1,H,W)
        return x_cond, y_tgt


**Loading Training Dataset**

In [None]:
train_path = "/content/drive/MyDrive/Training"

# Seed every RNG this pipeline draws from: Python `random` (inpainting masks),
# NumPy (Gaussian noise) and torch (weight init, dropout). The loader gets its own
# seeded generator for shuffling. GPU kernels may still be nondeterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_ds = MRIDataset(train_path, inpaint_prob=0.5)
# Fail fast with a clear message instead of DataLoader's "num_samples=0" error.
assert len(train_ds) > 0, f"No .png/.jpg images found under {train_path}"
train_dl = DataLoader(train_ds, batch_size=1, shuffle=True,
                      generator=torch.Generator().manual_seed(SEED))

Found 5721 images in /content/drive/MyDrive/Training


**Generator (U-Net 256) and conditional PatchGAN discriminator**

In [None]:
import torch.nn as nn

class UNetGenerator(nn.Module):
    def __init__(self, in_c=2, out_c=1):  # <- 2-channel condition now
        super().__init__()
        def down(i,o,bn=True):
            layers=[nn.Conv2d(i,o,4,2,1,bias=False)]
            if bn: layers.append(nn.BatchNorm2d(o))
            layers.append(nn.LeakyReLU(0.2,True))
            return nn.Sequential(*layers)
        def up(i,o,drop=False):
            layers=[nn.ConvTranspose2d(i,o,4,2,1,bias=False),
                    nn.BatchNorm2d(o), nn.ReLU(True)]
            if drop: layers.append(nn.Dropout(0.5))
            return nn.Sequential(*layers)
        self.d1=down(in_c,64,False); self.d2=down(64,128)
        self.d3=down(128,256); self.d4=down(256,512)
        self.d5=down(512,512); self.d6=down(512,512)
        self.d7=down(512,512); self.b=down(512,512, bn=False)  # 1×1 no BN
        self.u1=up(512,512,True); self.u2=up(1024,512,True)
        self.u3=up(1024,512,True); self.u4=up(1024,512)
        self.u5=up(1024,256); self.u6=up(512,128)
        self.u7=up(256,64)
        self.out=nn.Sequential(nn.ConvTranspose2d(128,out_c,4,2,1), nn.Tanh())
    def forward(self,x):
        d1=self.d1(x); d2=self.d2(d1); d3=self.d3(d2); d4=self.d4(d3)
        d5=self.d5(d4); d6=self.d6(d5); d7=self.d7(d6); b=self.b(d7)
        u1=self.u1(b); u2=self.u2(torch.cat([u1,d7],1))
        u3=self.u3(torch.cat([u2,d6],1)); u4=self.u4(torch.cat([u3,d5],1))
        u5=self.u5(torch.cat([u4,d4],1)); u6=self.u6(torch.cat([u5,d3],1))
        u7=self.u7(torch.cat([u6,d2],1))
        return self.out(torch.cat([u7,d1],1))

class PatchDiscriminator(nn.Module):
    def __init__(self, in_c=3):  # 2-ch cond + 1-ch target
        super().__init__()
        def block(i,o,norm=True):
            layers=[nn.Conv2d(i,o,4,2,1,bias=False)]
            if norm: layers.append(nn.BatchNorm2d(o))
            layers.append(nn.LeakyReLU(0.2,True))
            return layers
        self.net=nn.Sequential(
            *block(in_c,64,False),
            *block(64,128),
            *block(128,256),
            nn.Conv2d(256,1,4,1,1)
        )
    def forward(self, cond, y):  # cond: (B,2,H,W)
        return self.net(torch.cat([cond, y], dim=1))


In [None]:
# Same device rule as the setup cell; build both networks directly on it.
device = "cuda" if torch.cuda.is_available() else "cpu"
G = UNetGenerator(in_c=2, out_c=1).to(device)   # condition: [masked noisy image, mask] -> clean image
D = PatchDiscriminator(in_c=3).to(device)       # 2-ch condition + 1-ch real/fake image

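**Discriminator (PatchGAN)**

Note: this cell re-declares `PatchDiscriminator` after `D` was already instantiated above, so it does not change the discriminator being trained; it only affects `PatchDiscriminator(...)` calls made after this cell runs.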

In [None]:
class PatchDiscriminator(nn.Module):
    """Conditional PatchGAN discriminator scoring (condition, image) pairs.

    The condition and the real or generated image are concatenated on the channel
    axis and mapped to a grid of raw logits (meant for BCEWithLogitsLoss), one per
    overlapping patch. Three stride-2 convs plus a stride-1 head give each logit a
    46x46-pixel receptive field, and a 256x256 input yields a 31x31 logit map.

    Args:
        in_c: channels after concatenation. Defaults to 3 = 2-channel condition
            ([masked noisy image, mask]) + 1-channel image, the layout this
            notebook's dataset and generator produce.
    """

    def __init__(self, in_c=3):
        super().__init__()

        def block(i, o, norm=True):
            # No conv bias (redundant ahead of BatchNorm; also off on the first,
            # un-normalized block) so the parameter layout matches the earlier
            # PatchDiscriminator definition that `D` was built from, and saved
            # D state_dicts load into this class.
            layers = [nn.Conv2d(i, o, 4, 2, 1, bias=False)]
            if norm:
                layers.append(nn.BatchNorm2d(o))
            layers.append(nn.LeakyReLU(0.2, True))
            return layers

        self.net = nn.Sequential(
            *block(in_c, 64, False),
            *block(64, 128),
            *block(128, 256),
            nn.Conv2d(256, 1, 4, 1, 1),
        )

    def forward(self, x, y):
        """Score the pair: ``x`` is the condition, ``y`` the real or generated image."""
        return self.net(torch.cat([x, y], 1))


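**Initialize Optimizers & Losses**

The generator `G` and discriminator `D` were already created above. This cell sets up their Adam optimizers, the adversarial (BCE-with-logits), L1 and SSIM losses, and the loss weights.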

In [None]:
import torch.optim as optim
from torchmetrics.image import StructuralSimilarityIndexMeasure

# Adam with the pix2pix settings (lr=2e-4, beta1=0.5) for both networks.
optG = optim.Adam(params=G.parameters(), lr=2e-4, betas=(0.5, 0.999))
optD = optim.Adam(params=D.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Adversarial loss on raw discriminator logits, plus pixel-wise L1 reconstruction loss.
bce = nn.BCEWithLogitsLoss()
l1 = nn.L1Loss()

# Images live in [-1, 1], so SSIM's data_range is 1 - (-1) = 2.
ssim_metric = StructuralSimilarityIndexMeasure(data_range=2.0).to(device)

# Weights of the generator objective: adv + λ_l1 * L1 + λ_ssim * (1 - SSIM).
λ_l1 = 100.0
λ_ssim = 10.0  # starting point; 5–20 is the suggested tuning range

**Training loop**

In [None]:
# Conditional-GAN training (pix2pix-style; batch size comes from train_dl).
# Each batch does one discriminator update on (condition, real) vs. (condition, fake)
# pairs, then one generator update with adversarial + weighted L1 + weighted (1 - SSIM) loss.
EPOCHS = 20
for epoch in range(EPOCHS):
    G.train(); D.train()
    g_losses, d_losses = [], []  # per-step losses, averaged for the epoch summary

    for x_cond, y_tgt in tqdm(train_dl):     # x_cond: [masked noisy, mask] condition; y_tgt: clean image
        x_cond, y_tgt = x_cond.to(device), y_tgt.to(device)
        fake = G(x_cond)  # single generator forward, reused by both updates below

        # --- Discriminator ---
        # zero_grad also clears the D grads left behind by the previous step's g_loss.backward().
        D.zero_grad()
        real_out = D(x_cond, y_tgt)
        fake_out = D(x_cond, fake.detach())  # detach: the D update must not backprop into G
        d_loss = (bce(real_out, torch.ones_like(real_out)) +
                  bce(fake_out, torch.zeros_like(fake_out))) * 0.5
        d_loss.backward(); optD.step()

        # --- Generator ---
        # Re-score the non-detached fake with the just-updated D. Gradients also reach
        # D's parameters here, but only optG steps.
        G.zero_grad()
        fake_out = D(x_cond, fake)
        g_adv  = bce(fake_out, torch.ones_like(fake_out))  # push D to label fakes as real
        g_l1   = l1(fake, y_tgt)
        # NOTE(review): ssim_metric is a stateful torchmetrics object. Each call returns the
        # batch SSIM and also updates its accumulated state, so reset it before reusing for eval.
        g_ssim = 1.0 - ssim_metric(fake, y_tgt)   # lower is better
        g_loss = g_adv + λ_l1 * g_l1 + λ_ssim * g_ssim
        g_loss.backward(); optG.step()

        g_losses.append(g_loss.item()); d_losses.append(d_loss.item())

    # G_loss is a weighted sum of three terms, so its scale is not directly comparable to D_loss.
    print(f"Epoch {epoch+1}/{EPOCHS} | G_loss={np.mean(g_losses):.3f} | D_loss={np.mean(d_losses):.3f}")


100%|██████████| 5721/5721 [29:05<00:00,  3.28it/s]


Epoch 1/20 | G_loss=7.962 | D_loss=0.497


100%|██████████| 5721/5721 [05:55<00:00, 16.10it/s]


Epoch 2/20 | G_loss=6.991 | D_loss=0.529


100%|██████████| 5721/5721 [05:55<00:00, 16.11it/s]


Epoch 3/20 | G_loss=6.502 | D_loss=0.563


100%|██████████| 5721/5721 [05:57<00:00, 16.01it/s]


Epoch 4/20 | G_loss=6.290 | D_loss=0.572


100%|██████████| 5721/5721 [05:54<00:00, 16.13it/s]


Epoch 5/20 | G_loss=6.059 | D_loss=0.577


100%|██████████| 5721/5721 [05:55<00:00, 16.07it/s]


Epoch 6/20 | G_loss=5.958 | D_loss=0.577


100%|██████████| 5721/5721 [05:58<00:00, 15.95it/s]


Epoch 7/20 | G_loss=5.841 | D_loss=0.577


100%|██████████| 5721/5721 [05:58<00:00, 15.98it/s]


Epoch 8/20 | G_loss=5.833 | D_loss=0.567


100%|██████████| 5721/5721 [05:58<00:00, 15.96it/s]


Epoch 9/20 | G_loss=5.768 | D_loss=0.589


100%|██████████| 5721/5721 [05:57<00:00, 15.99it/s]


Epoch 10/20 | G_loss=5.840 | D_loss=0.591


100%|██████████| 5721/5721 [05:55<00:00, 16.11it/s]


Epoch 11/20 | G_loss=5.726 | D_loss=0.591


100%|██████████| 5721/5721 [05:55<00:00, 16.10it/s]


Epoch 12/20 | G_loss=5.646 | D_loss=0.594


100%|██████████| 5721/5721 [05:54<00:00, 16.15it/s]


Epoch 13/20 | G_loss=5.610 | D_loss=0.591


100%|██████████| 5721/5721 [05:54<00:00, 16.12it/s]


Epoch 14/20 | G_loss=5.589 | D_loss=0.583


100%|██████████| 5721/5721 [06:00<00:00, 15.86it/s]


Epoch 15/20 | G_loss=5.642 | D_loss=0.581


100%|██████████| 5721/5721 [06:00<00:00, 15.88it/s]


Epoch 16/20 | G_loss=5.679 | D_loss=0.577


100%|██████████| 5721/5721 [06:00<00:00, 15.89it/s]


Epoch 17/20 | G_loss=5.724 | D_loss=0.572


100%|██████████| 5721/5721 [06:01<00:00, 15.83it/s]


Epoch 18/20 | G_loss=5.658 | D_loss=0.570


100%|██████████| 5721/5721 [06:00<00:00, 15.88it/s]


Epoch 19/20 | G_loss=5.672 | D_loss=0.570


100%|██████████| 5721/5721 [05:57<00:00, 16.02it/s]

Epoch 20/20 | G_loss=5.726 | D_loss=0.566





In [None]:
results_dir = "/content/drive/MyDrive/UMR_GAN/results"
os.makedirs(results_dir, exist_ok=True)

# Save only the learned weights (state_dicts); the classes must be rebuilt to load them.
ckpt_targets = [
    (G, f"{results_dir}/umr_pix2pix_inpaint_ssim_G.pth"),
    (D, f"{results_dir}/umr_pix2pix_inpaint_ssim_D.pth"),
]
for net, ckpt_path in ckpt_targets:
    torch.save(net.state_dict(), ckpt_path)
print("Saved to", results_dir)

Saved to /content/drive/MyDrive/UMR_GAN/results
