**Setup and Imports**

In [None]:
!pip install torch torchvision albumentations torchmetrics lpips matplotlib tqdm

Collecting torchmetrics
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lpips-0.1.4-py3-none-any.whl (53 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Installing collected packages: lightning-utilities, torchmetrics, lpips
Successfully installed lightning-utilities-0.15.2 lpips-0.1.4 torchmetrics-1.8.2


In [None]:
# Mount Google Drive at /content/drive (Colab-only helper). The training data
# path and the checkpoint path used later in this notebook both live under
# /content/drive/MyDrive, so this cell must run before them.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn  # in case nn wasn't in scope

# Pick the compute device once; kept as a plain string so it can be passed
# straight to `.to(device)` by the cells below.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

Using device: cuda


In [None]:
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2, os, glob, numpy as np, matplotlib.pyplot as plt, json
from tqdm import tqdm
from torchmetrics import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure

**Dataset Loader**

In [None]:
class MRIDataset(Dataset):
    """Paired (noisy, clean) grayscale image dataset with synthetic Gaussian noise.

    Images are discovered recursively under ``root``. PNG files are preferred;
    JPG files are only used when no PNG is found at all. Each item is read as
    grayscale, a noisy copy is synthesised on the fly, and both images go
    through the same resize / normalise / to-tensor pipeline.

    Args:
        root: Directory that is searched recursively for images.
        noise_std: Standard deviation of the additive Gaussian noise, in
            0-255 pixel units. Defaults to 15.

    Raises:
        RuntimeError: From ``__getitem__`` when an image file cannot be decoded.
    """

    def __init__(self, root, noise_std=15):
        # Sort so that index -> file is stable across runs (glob order is arbitrary).
        self.files = sorted(glob.glob(os.path.join(root, "**/*.png"), recursive=True))
        if len(self.files) == 0:
            # Fallback: only look for JPGs when the tree holds no PNG at all.
            self.files = sorted(glob.glob(os.path.join(root, "**/*.jpg"), recursive=True))
        print(f"Found {len(self.files)} images in {root}")

        self.noise_std = noise_std
        self.tf = A.Compose([
            A.Resize(256, 256),
            A.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=255.0),  # maps to [-1, 1]
            ToTensorV2(),
        ])

    def __len__(self):
        """Number of image files found under ``root``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(noisy, clean)`` for the image at ``idx``.

        Both outputs come from ``self.tf``; presumably single-channel
        256x256 tensors in [-1, 1] -- TODO confirm the channel layout
        ToTensorV2 produces for 2-D grayscale input.
        """
        path = self.files[idx]
        clean = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if clean is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than on `clean.shape`.
            raise RuntimeError(f"Could not read image: {path}")

        # Create synthetic noise for training. The sum is float, so it is
        # clipped back into the valid 8-bit range before the cast.
        noisy = clean + np.random.normal(0, self.noise_std, clean.shape)
        noisy = np.clip(noisy, 0, 255).astype(np.uint8)

        clean = self.tf(image=clean)["image"]
        noisy = self.tf(image=noisy)["image"]
        return noisy, clean


**Loading Training Dataset**

In [None]:
# Folder on the mounted Drive that holds the training images.
train_path = "/content/drive/MyDrive/Training"

# Wrap the folder in the dataset, then serve it one sample per batch,
# reshuffled on every pass.
train_ds = MRIDataset(root=train_path)
train_dl = DataLoader(dataset=train_ds, batch_size=1, shuffle=True)


Found 5721 images in /content/drive/MyDrive/Training


**Generator (U-Net 256)**

In [None]:
class UNetGenerator(nn.Module):
    """U-Net style encoder/decoder generator with skip connections.

    The encoder is eight stride-2 4x4 convolutions (``d1``..``d7`` plus the
    bottleneck ``b``); the decoder is seven stride-2 transposed convolutions
    (``u1``..``u7``) followed by a final transposed convolution and ``Tanh``.
    Every decoder stage after ``u1`` receives the previous decoder output
    concatenated channel-wise with the mirror encoder output. The notebook
    feeds it 256x256 inputs.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
    """

    @staticmethod
    def _down(c_in, c_out, bn=True):
        """Encoder stage: stride-2 conv, optional BatchNorm, LeakyReLU(0.2)."""
        conv = nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False)
        act = nn.LeakyReLU(0.2, inplace=True)
        parts = [conv, nn.BatchNorm2d(c_out), act] if bn else [conv, act]
        return nn.Sequential(*parts)

    @staticmethod
    def _up(c_in, c_out, drop=False):
        """Decoder stage: stride-2 transposed conv, BatchNorm, ReLU, optional Dropout(0.5)."""
        parts = [
            nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        ]
        if drop:
            parts.append(nn.Dropout(0.5))
        return nn.Sequential(*parts)

    def __init__(self, in_c=1, out_c=1):
        super().__init__()

        # (attribute name, in channels, out channels, use BatchNorm)
        encoder_spec = [
            ("d1", in_c, 64, False),
            ("d2", 64, 128, True),
            ("d3", 128, 256, True),
            ("d4", 256, 512, True),
            ("d5", 512, 512, True),
            ("d6", 512, 512, True),
            ("d7", 512, 512, True),
            ("b", 512, 512, False),
        ]
        for attr, c_in, c_out, use_bn in encoder_spec:
            setattr(self, attr, self._down(c_in, c_out, use_bn))

        # (attribute name, in channels, out channels, use Dropout).
        # Input widths from u2 onward are doubled by the skip concatenation.
        decoder_spec = [
            ("u1", 512, 512, True),
            ("u2", 1024, 512, True),
            ("u3", 1024, 512, True),
            ("u4", 1024, 512, False),
            ("u5", 1024, 256, False),
            ("u6", 512, 128, False),
            ("u7", 256, 64, False),
        ]
        for attr, c_in, c_out, use_drop in decoder_spec:
            setattr(self, attr, self._up(c_in, c_out, use_drop))

        self.out = nn.Sequential(
            nn.ConvTranspose2d(128, out_c, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the encoder, then decode while consuming skips in reverse order."""
        skips = []
        h = x
        for enc in (self.d1, self.d2, self.d3, self.d4, self.d5, self.d6, self.d7):
            h = enc(h)
            skips.append(h)

        h = self.u1(self.b(h))
        for dec in (self.u2, self.u3, self.u4, self.u5, self.u6, self.u7):
            # Pops d7, d6, ..., d2 in turn.
            h = dec(torch.cat([h, skips.pop()], 1))

        # The last remaining skip is d1.
        return self.out(torch.cat([h, skips.pop()], 1))


**Discriminator (PatchGAN)**

In [None]:
class PatchDiscriminator(nn.Module):
    """Conditional patch-level discriminator.

    Takes two image tensors, concatenates them along the channel axis and
    passes the result through three stride-2 4x4 conv stages (64, 128, 256
    channels; BatchNorm on all but the first; LeakyReLU(0.2) after each)
    followed by a stride-1 4x4 conv that emits one raw logit per patch.

    Args:
        in_c: Total channel count of the concatenated input pair.
    """

    def __init__(self, in_c=2):
        super().__init__()
        widths = [in_c, 64, 128, 256]
        stack = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            if stage > 0:
                # The first stage is left un-normalised.
                stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        # Final 1-channel logit map; no activation here.
        stack.append(nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1))
        self.net = nn.Sequential(*stack)

    def forward(self, x, y):
        """Score the pair ``(x, y)``; returns the logit map from ``self.net``."""
        pair = torch.cat([x, y], 1)
        return self.net(pair)


**Initialize Models & Losses**

In [None]:
# Build both networks on the selected device (generator first, then discriminator).
G = UNetGenerator().to(device)
D = PatchDiscriminator().to(device)

# Separate Adam optimizers with identical hyper-parameters for each network.
optG = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
optD = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))

# bce: adversarial loss on raw discriminator logits; l1: pixel-wise reconstruction loss.
bce = nn.BCEWithLogitsLoss()
l1 = nn.L1Loss()

**Training Loop**

In [None]:
# EPOCHS: number of passes over train_dl.
# λ: weight of the L1 reconstruction term relative to the adversarial term.
EPOCHS, λ = 20, 100

# Create the output folder before training so a bad or unmounted Drive path
# fails immediately instead of after the whole run.
ckpt_dir = "/content/drive/MyDrive/UMR_GAN/results"
ckpt_path = os.path.join(ckpt_dir, "umr_pix2pix.pth")
os.makedirs(ckpt_dir, exist_ok=True)

for epoch in range(EPOCHS):
    G.train(); D.train()
    g_losses, d_losses = [], []

    for noisy, clean in tqdm(train_dl):
        noisy, clean = noisy.to(device), clean.to(device)
        fake = G(noisy)

        # --- Discriminator ---
        # Real pair should score 1, generated pair 0. `fake` is detached so
        # this backward pass does not reach the generator.
        optD.zero_grad()
        real_out = D(noisy, clean)
        fake_out = D(noisy, fake.detach())
        d_loss = (bce(real_out, torch.ones_like(real_out)) +
                  bce(fake_out, torch.zeros_like(fake_out))) * 0.5
        d_loss.backward(); optD.step()

        # --- Generator ---
        # Re-score the (non-detached) fake with the just-updated discriminator;
        # the generator is rewarded for being classified as real and for
        # staying close to the clean target in L1.
        optG.zero_grad()
        fake_out = D(noisy, fake)
        g_adv = bce(fake_out, torch.ones_like(fake_out))
        g_l1  = l1(fake, clean)
        g_loss = g_adv + λ * g_l1
        g_loss.backward(); optG.step()

        g_losses.append(g_loss.item()); d_losses.append(d_loss.item())

    print(f"Epoch {epoch+1}/{EPOCHS} | G_loss={np.mean(g_losses):.3f} | D_loss={np.mean(d_losses):.3f}")

    # Checkpoint after every epoch so an interrupted session keeps the latest
    # weights. Write to a temp file and rename, so an interrupted write never
    # leaves a truncated file at the final path.
    tmp_path = ckpt_path + ".tmp"
    torch.save(G.state_dict(), tmp_path)
    os.replace(tmp_path, ckpt_path)

# The file at ckpt_path now holds the generator weights from the final epoch.
print("Model saved to Drive!")


100%|██████████| 5721/5721 [35:46<00:00,  2.67it/s]


Epoch 1/20 | G_loss=6.850 | D_loss=0.392


100%|██████████| 5721/5721 [05:49<00:00, 16.37it/s]


Epoch 2/20 | G_loss=6.028 | D_loss=0.415


100%|██████████| 5721/5721 [05:53<00:00, 16.18it/s]


Epoch 3/20 | G_loss=5.611 | D_loss=0.443


100%|██████████| 5721/5721 [05:54<00:00, 16.15it/s]


Epoch 4/20 | G_loss=5.449 | D_loss=0.449


100%|██████████| 5721/5721 [05:51<00:00, 16.26it/s]


Epoch 5/20 | G_loss=5.325 | D_loss=0.452


100%|██████████| 5721/5721 [05:52<00:00, 16.25it/s]


Epoch 6/20 | G_loss=5.343 | D_loss=0.445


100%|██████████| 5721/5721 [05:50<00:00, 16.32it/s]


Epoch 7/20 | G_loss=5.600 | D_loss=0.428


100%|██████████| 5721/5721 [05:51<00:00, 16.25it/s]


Epoch 8/20 | G_loss=5.370 | D_loss=0.439


100%|██████████| 5721/5721 [05:52<00:00, 16.23it/s]


Epoch 9/20 | G_loss=5.273 | D_loss=0.449


100%|██████████| 5721/5721 [05:52<00:00, 16.25it/s]


Epoch 10/20 | G_loss=5.320 | D_loss=0.439


100%|██████████| 5721/5721 [05:52<00:00, 16.25it/s]


Epoch 11/20 | G_loss=5.172 | D_loss=0.453


100%|██████████| 5721/5721 [05:52<00:00, 16.24it/s]


Epoch 12/20 | G_loss=5.206 | D_loss=0.444


100%|██████████| 5721/5721 [05:52<00:00, 16.21it/s]


Epoch 13/20 | G_loss=5.405 | D_loss=0.420


100%|██████████| 5721/5721 [05:52<00:00, 16.21it/s]


Epoch 14/20 | G_loss=5.213 | D_loss=0.436


100%|██████████| 5721/5721 [05:50<00:00, 16.32it/s]


Epoch 15/20 | G_loss=5.400 | D_loss=0.414


100%|██████████| 5721/5721 [05:50<00:00, 16.33it/s]


Epoch 16/20 | G_loss=5.312 | D_loss=0.420


100%|██████████| 5721/5721 [05:53<00:00, 16.18it/s]


Epoch 17/20 | G_loss=5.387 | D_loss=0.405


100%|██████████| 5721/5721 [05:52<00:00, 16.25it/s]


Epoch 18/20 | G_loss=5.474 | D_loss=0.390


100%|██████████| 5721/5721 [05:51<00:00, 16.26it/s]


Epoch 19/20 | G_loss=5.677 | D_loss=0.369


100%|██████████| 5721/5721 [05:54<00:00, 16.13it/s]


Epoch 20/20 | G_loss=5.601 | D_loss=0.374
Model saved to Drive!
