In [4]:
pip install torch

Collecting torch
  Downloading torch-2.7.0-cp310-none-macosx_11_0_arm64.whl.metadata (29 kB)
Collecting filelock (from torch)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2025.3.2-py3-none-any.whl.metadata (11 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch)
  Downloading mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Downloading torch-2.7.0-cp310-none-macosx_11_0_arm64.whl (68.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m68.6/68.6 MB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading sympy-1.14.0-py3-none-any.whl (6.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading filelock-3.18.0-py3-none-any.whl (16 kB)
Downloading fsspec-2025.3.2-py3-none-any

In [9]:
pip install torchvision

Collecting torchvision
  Downloading torchvision-0.22.0-cp310-cp310-macosx_11_0_arm64.whl.metadata (6.1 kB)
Downloading torchvision-0.22.0-cp310-cp310-macosx_11_0_arm64.whl (1.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchvision
Successfully installed torchvision-0.22.0
Note: you may need to restart the kernel to use updated packages.


In [55]:
import os
import glob
import random
import time
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from tqdm import tqdm

In [57]:
TRAIN_HR_DIR   = "/Users/snehabejugam/Downloads/div2k/DIV2K_train_HR/DIV2K_train_HR"
VAL_HR_DIR     = "/Users/snehabejugam/Downloads/div2k/DIV2K_valid_HR/DIV2K_valid_HR"
SCALE_FACTOR   = 2
HR_PATCH_SIZE  = 256
BATCH_SIZE     = 16
NUM_EPOCHS     = 10
LEARNING_RATE  = 1e-4
BETA1, BETA2   = 0.9, 0.999
NUM_WORKERS    = 0   
PIN_MEMORY     = False  
LAMBDA_CONTENT = 1.0
LAMBDA_ADV     = 0.001
VERBOSE = True

In [59]:
class DIV2KDataset_HR_Only(Dataset):
    def __init__(self, hr_dir, scale=4, patch_size=96, use_random_crop=True):
        super().__init__()
        self.hr_files = sorted(glob.glob(os.path.join(hr_dir, '*.png')))
        if not self.hr_files:
            raise FileNotFoundError(f"No PNG images found in {hr_dir}")
        if patch_size % scale != 0:
            raise ValueError("patch_size must be divisible by scale")
        self.scale = scale
        self.patch_size = patch_size
        self.lr_patch = patch_size // scale
        self.random_crop = use_random_crop
        print(f"Loaded {len(self.hr_files)} HR images from {hr_dir}")

    def __len__(self):
        return len(self.hr_files)

    def __getitem__(self, idx):
        path = self.hr_files[idx]
        try:
            hr = Image.open(path).convert('RGB')
        except:
            return None

        w, h = hr.size
        if self.random_crop:
            if w < self.patch_size or h < self.patch_size:
                hr_patch = TF.resize(hr, (self.patch_size, self.patch_size),
                                     interpolation=InterpolationMode.BICUBIC)
            else:
                i = random.randint(0, h - self.patch_size)
                j = random.randint(0, w - self.patch_size)
                hr_patch = TF.crop(hr, i, j, self.patch_size, self.patch_size)
        else:
            if w < self.patch_size or h < self.patch_size:
                hr_patch = TF.resize(hr, (self.patch_size, self.patch_size),
                                     interpolation=InterpolationMode.BICUBIC)
            else:
                hr_patch = TF.center_crop(hr, (self.patch_size, self.patch_size))

        # --- Generate LR patch ---
        lr_patch = TF.resize(hr_patch,
                             (self.lr_patch, self.lr_patch),
                             interpolation=InterpolationMode.BICUBIC)

        # --- To tensor & normalize to [-1,1] ---
        hr_t = TF.to_tensor(hr_patch)
        lr_t = TF.to_tensor(lr_patch)
        hr_t = TF.normalize(hr_t, [0.5]*3, [0.5]*3)
        lr_t = TF.normalize(lr_t, [0.5]*3, [0.5]*3)
        return lr_t, hr_t


def custom_collate(batch):
    batch = [b for b in batch if b is not None]
    return None if not batch else torch.utils.data.default_collate(batch)

In [61]:
class ResidualBlock(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(c, c, 3, padding=1), nn.BatchNorm2d(c), nn.ReLU(inplace=True),
            nn.Conv2d(c, c, 3, padding=1), nn.BatchNorm2d(c)
        )
    def forward(self, x):
        return x + self.net(x)


class Generator(nn.Module):
    def __init__(self, scale=4, n_res=5, c=64):
        super().__init__()
        self.conv_in  = nn.Conv2d(3, c, 9, padding=4)
        self.relu     = nn.ReLU(inplace=True)
        self.res_blocks = nn.Sequential(*(ResidualBlock(c) for _ in range(n_res)))
        self.conv_mid = nn.Conv2d(c, c, 3, padding=1)
        up = []
        for _ in range(int(np.log2(scale))):
            up += [nn.Conv2d(c, c*4, 3, padding=1),
                   nn.PixelShuffle(2),
                   nn.ReLU(inplace=True)]
        self.up       = nn.Sequential(*up)
        self.conv_out = nn.Conv2d(c, 3, 9, padding=4)
        self.tanh     = nn.Tanh()

    def forward(self, x):
        x1 = self.relu(self.conv_in(x))
        x2 = self.res_blocks(x1)
        x3 = self.conv_mid(x2) + x1
        x4 = self.up(x3)
        return self.tanh(self.conv_out(x4))


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        def block(in_c, out_c, bn=True):
            layers = [nn.Conv2d(in_c, out_c, 4, 2, 1, bias=not bn)]
            if bn: layers.append(nn.BatchNorm2d(out_c))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        self.net = nn.Sequential(
            *block(3, 64, bn=False),
            *block(64,128),
            *block(128,256),
            *block(256,512),
            nn.Conv2d(512,1,4,1,1)
        )
    def forward(self, x):
        return self.net(x)

In [63]:
def evaluate(gen, loader, device):
    gen.eval()
    tot_psnr = tot_ssim = n = 0
    with torch.no_grad():
        for batch in loader:
            if batch is None: continue
            lr, hr = batch
            lr, hr = lr.to(device), hr.to(device)
            sr = gen(lr)

            # take first image
            sr_img = ((sr[0].permute(1,2,0).cpu().clamp(-1,1)+1)/2).numpy()
            hr_img = ((hr[0].permute(1,2,0).cpu().clamp(-1,1)+1)/2).numpy()

            psnr = compare_psnr(hr_img, sr_img, data_range=1)
            ssim = compare_ssim(hr_img, sr_img, data_range=1, channel_axis=-1, win_size=7)
            tot_psnr += psnr
            tot_ssim += ssim
            n += 1
    gen.train()
    return (tot_psnr/n, tot_ssim/n) if n else (0,0)


In [65]:
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for d in (TRAIN_HR_DIR, VAL_HR_DIR):
        if not os.path.isdir(d):
            print(f"ERROR: directory not found: {d}")
            exit(1)

    train_ds = DIV2KDataset_HR_Only(
        TRAIN_HR_DIR, scale=SCALE_FACTOR,
        patch_size=HR_PATCH_SIZE, use_random_crop=True
    )
    val_ds = DIV2KDataset_HR_Only(
        VAL_HR_DIR, scale=SCALE_FACTOR,
        patch_size=HR_PATCH_SIZE, use_random_crop=False
    )

    train_loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
        collate_fn=custom_collate
    )
    val_loader = DataLoader(
        val_ds, batch_size=1, shuffle=False,
        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
        collate_fn=custom_collate
    )

    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

    G = Generator(scale=SCALE_FACTOR).to(device)
    D = Discriminator().to(device)

    content_loss = nn.L1Loss().to(device)
    adv_loss     = nn.BCEWithLogitsLoss().to(device)

    opt_g = optim.Adam(G.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2))
    opt_d = optim.Adam(D.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2))

    best_psnr    = 0.0
    best_epoch   = 0
    psnr_history = []
    ssim_history = []
    t0           = time.time()

    for epoch in range(1, NUM_EPOCHS + 1):
        G.train(); D.train()
        running_g = running_d = 0.0

        loop = train_loader
        if VERBOSE:
            loop = tqdm(train_loader,
                        desc=f"Epoch {epoch}/{NUM_EPOCHS}",
                        unit="batch")

        for batch in loop:
            if batch is None: 
                continue
            lr_b, hr_b = batch
            lr_b, hr_b = lr_b.to(device), hr_b.to(device)

           
            with torch.no_grad():
                out_shape = D(hr_b).shape
            real_lbl = torch.ones(out_shape, device=device)
            fake_lbl = torch.zeros(out_shape, device=device)

           
            opt_d.zero_grad()
            real_out  = D(hr_b)
            fake_img  = G(lr_b).detach()
            fake_out  = D(fake_img)
            loss_d    = 0.5 * (adv_loss(real_out, real_lbl) +
                              adv_loss(fake_out, fake_lbl))
            loss_d.backward(); opt_d.step()

       
            opt_g.zero_grad()
            gen_img = G(lr_b)
            loss_g  = (LAMBDA_CONTENT * content_loss(gen_img, hr_b) +
                       LAMBDA_ADV * adv_loss(D(gen_img), real_lbl))
            loss_g.backward(); opt_g.step()

            running_d += loss_d.item()
            running_g += loss_g.item()

        avg_d = running_d / len(train_loader)
        avg_g = running_g / len(train_loader)
        print(f"Epoch {epoch}/{NUM_EPOCHS} | D_loss: {avg_d:.4f} | G_loss: {avg_g:.4f}")

       
        psnr_val, ssim_val = evaluate(G, val_loader, device)
        print(f" → Val PSNR: {psnr_val:.4f} dB | SSIM: {ssim_val:.4f}")

        psnr_history.append(psnr_val)
        ssim_history.append(ssim_val)

        if psnr_val > best_psnr:
            best_psnr  = psnr_val
            best_epoch = epoch
            torch.save(G.state_dict(), "generator_best.pth")
            torch.save(D.state_dict(), "discriminator_best.pth")

  
    total_min  = (time.time() - t0) / 60
    avg_psnr   = sum(psnr_history) / len(psnr_history)
    avg_ssim   = sum(ssim_history) / len(ssim_history)

    print(f"\nTraining complete in {total_min:.1f} min")
    print(f"Average PSNR over {len(psnr_history)} epochs : {avg_psnr:.4f} dB")
    print(f"Average SSIM over {len(ssim_history)} epochs: {avg_ssim:.4f}")
    print(f"Highest PSNR was {best_psnr:.4f} dB at epoch {best_epoch}")


Loaded 800 HR images from /Users/snehabejugam/Downloads/div2k/DIV2K_train_HR/DIV2K_train_HR
Loaded 100 HR images from /Users/snehabejugam/Downloads/div2k/DIV2K_valid_HR/DIV2K_valid_HR
Train batches: 50, Val batches: 100


Epoch 1/10: 100%|████████████████████████████| 50/50 [13:58<00:00, 16.78s/batch]


Epoch 1/10 | D_loss: 0.5937 | G_loss: 0.2129
 → Val PSNR: 20.2583 dB | SSIM: 0.4750


Epoch 2/10: 100%|████████████████████████████| 50/50 [13:58<00:00, 16.77s/batch]


Epoch 2/10 | D_loss: 0.2037 | G_loss: 0.1397
 → Val PSNR: 22.2671 dB | SSIM: 0.5668


Epoch 3/10: 100%|████████████████████████████| 50/50 [14:00<00:00, 16.81s/batch]


Epoch 3/10 | D_loss: 0.0914 | G_loss: 0.1227
 → Val PSNR: 22.6553 dB | SSIM: 0.5987


Epoch 4/10: 100%|████████████████████████████| 50/50 [14:07<00:00, 16.94s/batch]


Epoch 4/10 | D_loss: 0.1033 | G_loss: 0.1129
 → Val PSNR: 23.3021 dB | SSIM: 0.6449


Epoch 5/10: 100%|████████████████████████████| 50/50 [14:05<00:00, 16.92s/batch]


Epoch 5/10 | D_loss: 0.1177 | G_loss: 0.1028
 → Val PSNR: 23.6722 dB | SSIM: 0.6703


Epoch 6/10: 100%|████████████████████████████| 50/50 [31:15<00:00, 37.51s/batch]


Epoch 6/10 | D_loss: 0.0988 | G_loss: 0.0984
 → Val PSNR: 24.1227 dB | SSIM: 0.6951


Epoch 7/10: 100%|████████████████████████████| 50/50 [16:05<00:00, 19.31s/batch]


Epoch 7/10 | D_loss: 0.1374 | G_loss: 0.0928
 → Val PSNR: 24.7016 dB | SSIM: 0.7153


Epoch 8/10: 100%|████████████████████████████| 50/50 [52:56<00:00, 63.54s/batch]


Epoch 8/10 | D_loss: 0.1034 | G_loss: 0.0907
 → Val PSNR: 24.8592 dB | SSIM: 0.7301


Epoch 9/10: 100%|████████████████████████████| 50/50 [18:43<00:00, 22.48s/batch]


Epoch 9/10 | D_loss: 0.2006 | G_loss: 0.0852
 → Val PSNR: 25.0509 dB | SSIM: 0.7454


Epoch 10/10: 100%|███████████████████████████| 50/50 [18:20<00:00, 22.00s/batch]


Epoch 10/10 | D_loss: 0.2527 | G_loss: 0.0812
 → Val PSNR: 25.6096 dB | SSIM: 0.7583

Training complete in 212.9 min
Average PSNR over 10 epochs : 23.6499 dB
Average SSIM over 10 epochs: 0.6600
Highest PSNR was 25.6096 dB at epoch 10
