In [5]:
import os
import glob
import random
import time
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim

from tqdm import tqdm

In [7]:
# --- Data locations ---------------------------------------------------------
# Each folder can be overridden through an environment variable so the notebook
# runs on another machine without editing this cell. The fallbacks are the
# original local DIV2K folders.
TRAIN_HR_DIR   = os.environ.get(
    "DIV2K_TRAIN_HR_DIR",
    "/Users/nsvnathan/Downloads/div2k/DIV2K_train_HR/DIV2K_train_HR")
VAL_HR_DIR     = os.environ.get(
    "DIV2K_VAL_HR_DIR",
    "/Users/nsvnathan/Downloads/div2k/DIV2K_valid_HR/DIV2K_valid_HR")

# --- Reproducibility --------------------------------------------------------
# Random crops use Python's `random`; weight init and dropout use torch.
# Seed every RNG once, up front, so Restart & Run All is repeatable.
SEED           = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Patch geometry ---------------------------------------------------------
SCALE_FACTOR   = 2                              # HR side / LR side
HR_PATCH_SIZE  = 96                             # side of the square HR training patch
LR_PATCH_SIZE  = HR_PATCH_SIZE // SCALE_FACTOR   # side of the matching LR patch

# --- Optimisation -----------------------------------------------------------
BATCH_SIZE     = 4
NUM_EPOCHS     = 10
LEARNING_RATE  = 1e-4
BETA1, BETA2   = 0.9, 0.999                     # Adam betas, shared by both optimizers

# --- DataLoader -------------------------------------------------------------
NUM_WORKERS    = 0                              # 0 = load data in the main process
PIN_MEMORY     = False

# --- Generator loss weights -------------------------------------------------
LAMBDA_CONTENT = 1.0                            # weight of the L1 content term
LAMBDA_ADV     = 0.001                          # weight of the adversarial term

VERBOSE        = True                           # wrap the training loader in a tqdm bar

In [9]:
class DIV2KDataset_HR_Only(Dataset):
    """Paired (LR, HR) patches synthesised on the fly from HR PNG files.

    Every ``*.png`` directly inside ``hr_dir`` is one sample. An HR patch is
    taken from the image and the LR counterpart is produced by bicubic
    down-scaling of that patch, so no separate LR folder is needed.

    Args:
        hr_dir: folder containing the high-resolution PNG images.
        patch_size: side length (pixels) of the square HR patch.
        scale: down-scaling factor; the LR patch side is ``patch_size // scale``.
        random_crop: if True, take the HR patch at a random location;
            otherwise take a centre crop.

    Raises:
        RuntimeError: if ``hr_dir`` contains no PNG files.
        ValueError: if ``patch_size`` is not a multiple of ``scale``.
    """

    def __init__(self, hr_dir, patch_size=96, scale=4, random_crop=True):
        super().__init__()
        # Sorted so that index -> file is stable across runs and platforms.
        self.files = sorted(glob.glob(os.path.join(hr_dir, '*.png')))
        if not self.files:
            raise RuntimeError(f"No PNGs found in {hr_dir}")
        if patch_size % scale != 0:
            raise ValueError("patch_size must be divisible by scale")
        self.patch = patch_size
        self.lr_patch = patch_size // scale
        self.random_crop = random_crop
        print(f"Loaded {len(self.files)} images from {hr_dir}")

    def __len__(self):
        """Number of HR images (one patch pair is drawn per image)."""
        return len(self.files)

    def _extract_hr(self, img):
        """Return a ``patch x patch`` HR PIL image cut (or resized) from ``img``."""
        w, h = img.size
        size = self.patch

        # Image too small to crop from: stretch it to the patch size instead.
        if w < size or h < size:
            return TF.resize(img, (size, size),
                             interpolation=InterpolationMode.BICUBIC)

        if self.random_crop:
            top = random.randint(0, h - size)
            left = random.randint(0, w - size)
            return TF.crop(img, top, left, size, size)

        return TF.center_crop(img, (size, size))

    def __getitem__(self, i):
        """Return ``(lr, hr)`` float tensors for image ``i``, each normalised to [-1, 1]."""
        # Image.open is lazy; the context manager releases the file handle
        # deterministically once convert() has produced an in-memory copy.
        with Image.open(self.files[i]) as src:
            img = src.convert('RGB')

        hr = self._extract_hr(img)

        # LR generation: bicubic down-scale of the HR patch.
        lr = TF.resize(hr, (self.lr_patch, self.lr_patch),
                       interpolation=InterpolationMode.BICUBIC)

        # to_tensor gives [0, 1]; (x - 0.5) / 0.5 maps that to [-1, 1].
        mean = [0.5] * 3
        std = [0.5] * 3
        hr_t = TF.normalize(TF.to_tensor(hr), mean, std)
        lr_t = TF.normalize(TF.to_tensor(lr), mean, std)
        return lr_t, hr_t


def collate_fn(batch):
    """Collate a list of samples, discarding any that are ``None``.

    Returns ``None`` when nothing is left after filtering; otherwise the
    result of PyTorch's default collation on the remaining samples.
    """
    kept = list(filter(lambda sample: sample is not None, batch))
    if len(kept) == 0:
        return None
    return torch.utils.data.default_collate(kept)



In [11]:
def make_transformer(embed_dim, num_heads, mlp_ratio, dropout, num_layers=6):
    """Build a batch-first GELU transformer encoder stack.

    Args:
        embed_dim: token width (``d_model``).
        num_heads: number of attention heads per layer.
        mlp_ratio: feed-forward hidden width as a multiple of ``embed_dim``
            (truncated to an int).
        dropout: dropout probability used inside every layer.
        num_layers: number of stacked encoder layers (default 6, the depth
            previously hard-coded here).

    Returns:
        An ``nn.TransformerEncoder`` that maps ``[B, N, embed_dim]`` to a
        tensor of the same shape.
    """
    layer = nn.TransformerEncoderLayer(
        d_model=embed_dim,
        nhead=num_heads,
        dim_feedforward=int(embed_dim * mlp_ratio),
        dropout=dropout,
        activation='gelu',
        batch_first=True
    )
    return nn.TransformerEncoder(layer, num_layers=num_layers)



# TransGAN Generator
class TransGenerator(nn.Module):
    """Transformer-based super-resolution generator.

    Pipeline:
        strided-conv patch embedding -> learned positional embedding ->
        transformer encoder -> transposed-conv "depatch" back to the LR
        resolution -> ``log2(scale)`` PixelShuffle x2 stages ->
        9x9 output conv -> tanh.

    Args:
        in_ch: number of channels of the LR input.
        embed_dim: token width and channel count of the internal feature maps.
        patch_size: side of each square patch (kernel and stride of the
            embedding conv and of the depatch transposed conv).
        mlp_ratio: feed-forward width multiplier passed to ``make_transformer``.
        num_heads: attention heads per transformer layer.
        dropout: dropout probability inside the transformer.
        lr_size: side length of the square LR input that the positional
            embedding is sized for. ``None`` uses the module-level
            ``LR_PATCH_SIZE``.
        scale: total up-scaling factor, must be a power of two. ``None`` uses
            the module-level ``SCALE_FACTOR``.

    Raises:
        ValueError: if ``lr_size`` is not a multiple of ``patch_size`` or
            ``scale`` is not a power of two (either would silently yield an
            output whose size does not match the HR target).
    """

    def __init__(self,
                 in_ch=3,
                 embed_dim=256,
                 patch_size=2,
                 mlp_ratio=4.0,
                 num_heads=2,
                 dropout=0.1,
                 lr_size=None,
                 scale=None):
        super().__init__()

        # Resolve notebook-level defaults at construction time.
        if lr_size is None:
            lr_size = LR_PATCH_SIZE
        if scale is None:
            scale = SCALE_FACTOR

        if lr_size % patch_size != 0:
            raise ValueError(
                f"lr_size ({lr_size}) must be divisible by patch_size ({patch_size})")
        num_stages = int(round(np.log2(scale))) if scale >= 1 else -1
        if num_stages < 0 or 2 ** num_stages != scale:
            raise ValueError(f"scale must be a power of two, got {scale}")

        self.patch_size = patch_size
        # Patch embed: an lr_size x lr_size input becomes a grid of
        # (lr_size / patch_size)^2 tokens of width embed_dim.
        self.embed = nn.Conv2d(in_ch, embed_dim,
                               kernel_size=patch_size,
                               stride=patch_size)
        num_patches = (lr_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim))

        self.transformer = make_transformer(embed_dim, num_heads,
                                            mlp_ratio, dropout)

        # Unpatchify: undo the stride of the embedding conv so the feature
        # map is back at the LR resolution.
        self.depatch = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim,
                               kernel_size=patch_size,
                               stride=patch_size),
            nn.ReLU(inplace=True)
        )

        # Upsample by `scale` as a chain of x2 PixelShuffle stages. Each conv
        # expands channels x4 so PixelShuffle(2) returns to embed_dim channels.
        stages = []
        for _ in range(num_stages):
            stages += [
                nn.Conv2d(embed_dim, embed_dim * 4, 3, padding=1),
                nn.PixelShuffle(2),
                nn.ReLU(inplace=True)
            ]
        self.upsample = nn.Sequential(*stages)

        self.conv_out = nn.Conv2d(embed_dim, 3, 9, padding=4)
        self.tanh     = nn.Tanh()

    def forward(self, x):
        """Super-resolve a batch of LR images.

        Args:
            x: ``[B, in_ch, H, W]`` tensor whose token count after patch
                embedding matches the positional embedding.

        Returns:
            ``[B, 3, H * scale, W * scale]`` tensor squashed to [-1, 1] by tanh.
        """
        batch = x.size(0)
        # patchify + embed -> [B, D, H', W']
        x = self.embed(x)
        # Remember the real grid shape instead of re-deriving it from sqrt(N),
        # which would mis-reshape any non-square token grid.
        dim, grid_h, grid_w = x.shape[1:]
        # flatten -> [B, N, D]
        x = x.flatten(2).transpose(1, 2)
        x = x + self.pos_embed
        # transformer -> [B, N, D]
        x = self.transformer(x)
        # back to [B, D, H', W']
        x = x.transpose(1, 2).reshape(batch, dim, grid_h, grid_w)
        x = self.depatch(x)
        x = self.upsample(x)
        x = self.conv_out(x)
        return self.tanh(x)



# TransGAN Discriminator
class TransDiscriminator(nn.Module):
    """ViT-style discriminator that scores an image patch as real or generated.

    The image is split into non-overlapping patches by a strided conv, a
    learned class token is prepended, learned positional embeddings are added,
    and the transformer output at the class-token position is normalised and
    projected to a single logit.

    Args:
        in_ch: number of input image channels.
        embed_dim: token width.
        patch_size: side of each square patch (kernel and stride of the
            embedding conv).
        mlp_ratio: feed-forward width multiplier passed to ``make_transformer``.
        num_heads: attention heads per transformer layer.
        dropout: dropout probability inside the transformer.
        img_size: side length of the square input image that the positional
            embedding is sized for. ``None`` uses the module-level
            ``HR_PATCH_SIZE``.

    Raises:
        ValueError: if ``img_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self,
                 in_ch=3,
                 embed_dim=256,
                 patch_size=4,
                 mlp_ratio=4.0,
                 num_heads=2,
                 dropout=0.1,
                 img_size=None):
        super().__init__()

        # Resolve the notebook-level default at construction time.
        if img_size is None:
            img_size = HR_PATCH_SIZE
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})")

        self.patch_size = patch_size
        # Patch embed: img_size -> img_size / patch_size tokens per side.
        self.embed = nn.Conv2d(in_ch, embed_dim,
                               kernel_size=patch_size,
                               stride=patch_size)
        num_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        # One extra position for the class token.
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))

        self.transformer = make_transformer(embed_dim, num_heads,
                                            mlp_ratio, dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Return one raw logit per image.

        Args:
            x: ``[B, in_ch, H, W]`` tensor whose patch count matches the
                positional embedding.

        Returns:
            ``[B, 1]`` tensor of un-squashed scores (no sigmoid is applied here).
        """
        batch = x.size(0)
        # [B, D, H', W']
        tokens = self.embed(x)
        # [B, D, N] -> [B, N, D]
        tokens = tokens.flatten(2).transpose(1, 2)
        # prepend the class token, broadcast over the batch
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        tokens = tokens + self.pos_embed
        # transformer is batch-first: [B, N + 1, D]
        tokens = self.transformer(tokens)
        # classify from the class-token position only
        summary = self.norm(tokens[:, 0])
        return self.head(summary)  # [B, 1]



In [13]:
def evaluate(gen, loader, device):
    """Compute mean PSNR and SSIM of ``gen`` over every image in ``loader``.

    Generator outputs and HR targets are clamped to [-1, 1], mapped to [0, 1]
    and compared per image with ``data_range=1``.

    Args:
        gen: generator module mapping an LR batch to an SR batch.
        loader: iterable yielding ``(lr, hr)`` batches; ``None`` batches are
            skipped.
        device: device the batches are moved to before inference.

    Returns:
        ``(mean_psnr, mean_ssim)``, or ``(0, 0)`` if no image was evaluated.

    The generator's train/eval mode is restored to whatever it was on entry,
    even if evaluation raises.
    """
    was_training = gen.training
    gen.eval()
    tot_psnr = tot_ssim = 0.0
    n = 0
    try:
        with torch.no_grad():
            for batch in loader:
                if batch is None:
                    continue
                lr, hr = batch
                lr, hr = lr.to(device), hr.to(device)
                sr = gen(lr)

                # NCHW in [-1, 1] -> NHWC numpy in [0, 1]
                sr_np = ((sr.permute(0, 2, 3, 1).cpu().clamp(-1, 1) + 1) / 2).numpy()
                hr_np = ((hr.permute(0, 2, 3, 1).cpu().clamp(-1, 1) + 1) / 2).numpy()

                # Score every image in the batch, not just the first one.
                for hr_img, sr_img in zip(hr_np, sr_np):
                    tot_psnr += compare_psnr(hr_img, sr_img, data_range=1)
                    tot_ssim += compare_ssim(hr_img, sr_img,
                                             data_range=1,
                                             channel_axis=-1,
                                             win_size=7)
                    n += 1
    finally:
        # Put the model back in the mode the caller had it in.
        gen.train(was_training)
    return (tot_psnr/n, tot_ssim/n) if n else (0, 0)

In [15]:
if __name__ == '__main__':
    # End-to-end training: build datasets and loaders, create the generator /
    # discriminator pair, alternate D and G updates for NUM_EPOCHS, validate
    # after every epoch and checkpoint both networks whenever PSNR improves.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fail early with a clear message if either dataset folder is missing.
    for d in (TRAIN_HR_DIR, VAL_HR_DIR):
        if not os.path.isdir(d):
            raise RuntimeError(f"Directory not found: {d}")

    # Training uses random crops; validation uses deterministic centre crops.
    train_ds = DIV2KDataset_HR_Only(TRAIN_HR_DIR,
                                    patch_size=HR_PATCH_SIZE,
                                    scale=SCALE_FACTOR,
                                    random_crop=True)
    val_ds   = DIV2KDataset_HR_Only(VAL_HR_DIR,
                                    patch_size=HR_PATCH_SIZE,
                                    scale=SCALE_FACTOR,
                                    random_crop=False)

    train_loader = DataLoader(train_ds,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=NUM_WORKERS,
                              pin_memory=PIN_MEMORY,
                              collate_fn=collate_fn)
    val_loader = DataLoader(val_ds,
                            batch_size=1,
                            shuffle=False,
                            num_workers=NUM_WORKERS,
                            pin_memory=PIN_MEMORY,
                            collate_fn=collate_fn)

    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

    G = TransGenerator().to(device)
    D = TransDiscriminator().to(device)

    content_loss = nn.L1Loss().to(device)
    # D outputs raw logits, so the sigmoid lives inside the loss.
    adv_loss     = nn.BCEWithLogitsLoss().to(device)

    opt_g = optim.Adam(G.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2))
    opt_d = optim.Adam(D.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2))

    best_psnr = 0.0
    start = time.time()

    for epoch in range(1, NUM_EPOCHS+1):
        G.train()
        D.train()
        running_g = running_d = 0.0
        n_batches = 0  # batches actually processed (None batches are skipped)

        loop = train_loader
        if VERBOSE:
            loop = tqdm(train_loader,
                        desc=f"Epoch {epoch}/{NUM_EPOCHS}",
                        unit="batch")

        for batch in loop:
            if batch is None:
                continue
            lr_b, hr_b = batch
            lr_b, hr_b = lr_b.to(device), hr_b.to(device)

            # Single generator forward pass per iteration: the detached output
            # feeds the discriminator update, the attached one the generator
            # update. This avoids running G twice on the same batch.
            sr_b = G(lr_b)
            fake = sr_b.detach()

            # Discriminator step: real patches -> 1, generated patches -> 0.
            opt_d.zero_grad()
            real_out = D(hr_b)
            fake_out = D(fake)
            real_lbl = torch.ones_like(real_out)
            fake_lbl = torch.zeros_like(fake_out)
            loss_d = 0.5 * (adv_loss(real_out, real_lbl) +
                            adv_loss(fake_out, fake_lbl))
            loss_d.backward()
            opt_d.step()

            # Generator step: L1 content loss plus a small adversarial term
            # that rewards fooling the freshly updated discriminator.
            opt_g.zero_grad()
            adv_d = adv_loss(D(sr_b), real_lbl)
            cont  = content_loss(sr_b, hr_b)
            loss_g = LAMBDA_CONTENT * cont + LAMBDA_ADV * adv_d
            loss_g.backward()
            opt_g.step()

            running_d += loss_d.item()
            running_g += loss_g.item()
            n_batches += 1

            if VERBOSE:
                loop.set_postfix(D=f"{loss_d.item():.4f}",
                                 G=f"{loss_g.item():.4f}")

        # Average over the batches that were really trained on, so skipped
        # batches do not dilute the reported loss.
        avg_d = running_d / max(n_batches, 1)
        avg_g = running_g / max(n_batches, 1)
        print(f"Epoch {epoch}/{NUM_EPOCHS}  D_loss: {avg_d:.4f}  G_loss: {avg_g:.4f}")

        psnr_val, ssim_val = evaluate(G, val_loader, device)
        print(f" → Val PSNR: {psnr_val:.4f} dB  SSIM: {ssim_val:.4f}")

        # Keep the weights of both networks from the best-PSNR epoch.
        if psnr_val > best_psnr:
            best_psnr = psnr_val
            torch.save(G.state_dict(), "transgen_best.pth")
            torch.save(D.state_dict(), "transdisc_best.pth")

    total_min = (time.time() - start) / 60
    print(f"\nTraining complete in {total_min:.1f} min. Best PSNR: {best_psnr:.4f} dB")


Loaded 800 images from /Users/nsvnathan/Downloads/div2k/DIV2K_train_HR/DIV2K_train_HR
Loaded 100 images from /Users/nsvnathan/Downloads/div2k/DIV2K_valid_HR/DIV2K_valid_HR
Train batches: 200, Val batches: 100


Epoch 1/10: 100%|██████| 200/200 [10:05<00:00,  3.03s/batch, D=0.6547, G=0.1712]


Epoch 1/10  D_loss: 0.6840  G_loss: 0.1947
 → Val PSNR: 20.0949 dB  SSIM: 0.5093


Epoch 2/10: 100%|██████| 200/200 [10:07<00:00,  3.04s/batch, D=0.6671, G=0.1063]


Epoch 2/10  D_loss: 0.6716  G_loss: 0.1221
 → Val PSNR: 21.7878 dB  SSIM: 0.5568


Epoch 3/10: 100%|██████| 200/200 [10:10<00:00,  3.05s/batch, D=0.5349, G=0.0893]


Epoch 3/10  D_loss: 0.6331  G_loss: 0.1074
 → Val PSNR: 23.6374 dB  SSIM: 0.6463


Epoch 4/10: 100%|██████| 200/200 [09:58<00:00,  2.99s/batch, D=0.6065, G=0.0778]


Epoch 4/10  D_loss: 0.6128  G_loss: 0.0928
 → Val PSNR: 24.1893 dB  SSIM: 0.6842


Epoch 5/10: 100%|██████| 200/200 [10:11<00:00,  3.06s/batch, D=0.6264, G=0.1030]


Epoch 5/10  D_loss: 0.6096  G_loss: 0.0899
 → Val PSNR: 25.0596 dB  SSIM: 0.7186


Epoch 6/10: 100%|██████| 200/200 [10:16<00:00,  3.08s/batch, D=0.6538, G=0.0800]


Epoch 6/10  D_loss: 0.6386  G_loss: 0.0825
 → Val PSNR: 25.3156 dB  SSIM: 0.7321


Epoch 7/10: 100%|██████| 200/200 [10:19<00:00,  3.10s/batch, D=0.6976, G=0.0631]


Epoch 7/10  D_loss: 0.6522  G_loss: 0.0757
 → Val PSNR: 25.9979 dB  SSIM: 0.7528


Epoch 8/10: 100%|██████| 200/200 [10:18<00:00,  3.09s/batch, D=0.7445, G=0.1090]


Epoch 8/10  D_loss: 0.6592  G_loss: 0.0745
 → Val PSNR: 25.6842 dB  SSIM: 0.7629


Epoch 9/10: 100%|██████| 200/200 [21:34<00:00,  6.47s/batch, D=0.6751, G=0.0777]


Epoch 9/10  D_loss: 0.6684  G_loss: 0.0734
 → Val PSNR: 25.7926 dB  SSIM: 0.7671


Epoch 10/10: 100%|█████| 200/200 [10:16<00:00,  3.08s/batch, D=0.6218, G=0.0641]


Epoch 10/10  D_loss: 0.6730  G_loss: 0.0707
 → Val PSNR: 26.4782 dB  SSIM: 0.7748

Training complete in 115.6 min. Best PSNR: 26.4782 dB
