In [2]:
import os
import glob
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm
import torch.nn.functional as F

class DIV2KPatchDataset(Dataset):
    def __init__(self, hr_dir, lr_dir, patch_size=48, scale=2):
        self.hr_files = sorted(glob.glob(os.path.join(hr_dir, "*.*")))
        self.lr_files = sorted(glob.glob(os.path.join(lr_dir, "*.*")))
        self.patch_size = patch_size
        self.scale = scale

    def __len__(self):
        return len(self.hr_files)

    def __getitem__(self, idx):
        hr_img = Image.open(self.hr_files[idx]).convert("RGB")
        lr_img = Image.open(self.lr_files[idx]).convert("RGB")
        hr_img = transforms.ToTensor()(hr_img)
        lr_img = transforms.ToTensor()(lr_img)
        _, h, w = lr_img.shape
        top = torch.randint(0, h - self.patch_size + 1, (1,)).item()
        left = torch.randint(0, w - self.patch_size + 1, (1,)).item()
        lr_patch = lr_img[:, top:top+self.patch_size, left:left+self.patch_size]
        hr_patch = hr_img[:, top*self.scale:(top+self.patch_size)*self.scale,
                          left*self.scale:(left+self.patch_size)*self.scale]
        return lr_patch, hr_patch

class ResidualBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.PReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64)
        )
    def forward(self, x):
        return x + self.block(x)

class UpsampleBlock(nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(64, 64 * scale * scale, 3, padding=1),
            nn.PixelShuffle(scale),
            nn.PReLU()
        )
    def forward(self, x):
        return self.block(x)

class SRResNet(nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 9, padding=4)
        self.prelu = nn.PReLU()
        self.res_blocks = nn.Sequential(*[ResidualBlock() for _ in range(16)])
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.upsample = UpsampleBlock(scale)
        self.conv3 = nn.Conv2d(64, 3, 9, padding=4)

    def forward(self, x):
        x1 = self.prelu(self.conv1(x))
        x = self.res_blocks(x1)
        x = self.bn2(self.conv2(x))
        x = x + x1
        x = self.upsample(x)
        x = self.conv3(x)
        return x

base_path = r"D:\DL Project"
scale = 2
train_hr = os.path.join(base_path, "DIV2K_train_HR")
train_lr = os.path.join(base_path, f"DIV2K_train_LR_x{scale}")
valid_hr = os.path.join(base_path, "DIV2K_valid_HR")
valid_lr = os.path.join(base_path, f"DIV2K_valid_LR_x{scale}")

batch_size = 4
lr = 1e-4
epochs = 20
patch_size = 48

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = DIV2KPatchDataset(train_hr, train_lr, patch_size=patch_size, scale=scale)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataset = DIV2KPatchDataset(valid_hr, valid_lr, patch_size=patch_size, scale=scale)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

model = SRResNet(scale).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

def psnr(pred, target):
    mse = nn.functional.mse_loss(pred, target)
    return 10 * torch.log10(1 / mse)

for epoch in range(1, epochs + 1):
    model.train()
    running_loss = 0
    for lr_patches, hr_patches in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
        lr_patches = lr_patches.to(device)
        hr_patches = hr_patches.to(device)
        optimizer.zero_grad()
        outputs = model(lr_patches)
        loss = criterion(outputs, hr_patches)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * lr_patches.size(0)
    print(f"Epoch [{epoch}/{epochs}] Loss: {running_loss / len(train_loader.dataset):.6f}")
    model.eval()
    total = 0
    with torch.no_grad():
        for lr_patches, hr_patches in valid_loader:
            lr_patches = lr_patches.to(device)
            hr_patches = hr_patches.to(device)
            outputs = model(lr_patches)
            total += psnr(outputs, hr_patches).item()
    print(f"Validation PSNR: {total / len(valid_loader):.2f} dB")
    torch.save(model.state_dict(), os.path.join(base_path, f"srresnet_epoch{epoch}.pth"))

print("Training complete")


Epoch 1/20: 100%|██████████| 200/200 [02:07<00:00,  1.57it/s]


Epoch [1/20] Loss: 0.020150
Validation PSNR: 23.03 dB


Epoch 2/20: 100%|██████████| 200/200 [02:00<00:00,  1.66it/s]


Epoch [2/20] Loss: 0.006338
Validation PSNR: 25.03 dB


Epoch 3/20: 100%|██████████| 200/200 [02:00<00:00,  1.65it/s]


Epoch [3/20] Loss: 0.004416
Validation PSNR: 27.53 dB


Epoch 4/20: 100%|██████████| 200/200 [02:02<00:00,  1.63it/s]


Epoch [4/20] Loss: 0.003376
Validation PSNR: 27.29 dB


Epoch 5/20: 100%|██████████| 200/200 [02:03<00:00,  1.63it/s]


Epoch [5/20] Loss: 0.003183
Validation PSNR: 27.96 dB


Epoch 6/20: 100%|██████████| 200/200 [02:02<00:00,  1.63it/s]


Epoch [6/20] Loss: 0.002641
Validation PSNR: 28.25 dB


Epoch 7/20: 100%|██████████| 200/200 [02:02<00:00,  1.63it/s]


Epoch [7/20] Loss: 0.002432
Validation PSNR: 28.09 dB


Epoch 8/20: 100%|██████████| 200/200 [02:02<00:00,  1.63it/s]


Epoch [8/20] Loss: 0.002405
Validation PSNR: 30.59 dB


Epoch 9/20: 100%|██████████| 200/200 [02:02<00:00,  1.64it/s]


Epoch [9/20] Loss: 0.002052
Validation PSNR: 29.86 dB


Epoch 10/20: 100%|██████████| 200/200 [02:02<00:00,  1.63it/s]


Epoch [10/20] Loss: 0.002053
Validation PSNR: 30.48 dB


Epoch 11/20: 100%|██████████| 200/200 [02:02<00:00,  1.63it/s]


Epoch [11/20] Loss: 0.002060
Validation PSNR: 28.84 dB


Epoch 12/20: 100%|██████████| 200/200 [02:02<00:00,  1.63it/s]


Epoch [12/20] Loss: 0.001816
Validation PSNR: 30.05 dB


Epoch 13/20: 100%|██████████| 200/200 [02:02<00:00,  1.63it/s]


Epoch [13/20] Loss: 0.001824
Validation PSNR: 28.46 dB


Epoch 14/20: 100%|██████████| 200/200 [02:03<00:00,  1.62it/s]


Epoch [14/20] Loss: 0.001821
Validation PSNR: 30.93 dB


Epoch 15/20: 100%|██████████| 200/200 [02:02<00:00,  1.64it/s]


Epoch [15/20] Loss: 0.001666
Validation PSNR: 29.83 dB


Epoch 16/20: 100%|██████████| 200/200 [02:01<00:00,  1.65it/s]


Epoch [16/20] Loss: 0.001745
Validation PSNR: 29.65 dB


Epoch 17/20: 100%|██████████| 200/200 [02:01<00:00,  1.65it/s]


Epoch [17/20] Loss: 0.001725
Validation PSNR: 31.99 dB


Epoch 18/20: 100%|██████████| 200/200 [02:00<00:00,  1.66it/s]


Epoch [18/20] Loss: 0.001680
Validation PSNR: 30.21 dB


Epoch 19/20: 100%|██████████| 200/200 [01:59<00:00,  1.67it/s]


Epoch [19/20] Loss: 0.001593
Validation PSNR: 30.43 dB


Epoch 20/20: 100%|██████████| 200/200 [02:00<00:00,  1.66it/s]


Epoch [20/20] Loss: 0.001620
Validation PSNR: 30.95 dB
Training complete
