In [1]:
import os
import glob
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm
import torch.nn.functional as F

class DIV2KPatchDataset(Dataset):
    """Paired random-crop dataset over a folder of HR and a folder of LR images.

    The two folders are globbed with ``*.*`` and sorted; the i-th HR file is
    paired with the i-th LR file, so pairing is purely positional.

    Each item is ``(lr_patch, hr_patch)``: a random ``patch_size`` x
    ``patch_size`` crop of the LR image and the crop of the HR image that
    covers the same region scaled by ``scale``. Both are float tensors
    produced by ``transforms.ToTensor()`` from RGB images.

    Args:
        hr_dir: Directory containing the high-resolution images.
        lr_dir: Directory containing the low-resolution images.
        patch_size: Side length of the LR crop, in LR pixels.
        scale: Integer factor between LR and HR coordinates.

    Raises:
        ValueError: If the two directories hold a different number of files
            (positional pairing would then be wrong or run out of range).
    """

    def __init__(self, hr_dir, lr_dir, patch_size=33, scale=2):
        self.hr_files = sorted(glob.glob(os.path.join(hr_dir, "*.*")))
        self.lr_files = sorted(glob.glob(os.path.join(lr_dir, "*.*")))
        # __len__ is driven by the HR list and LR files are looked up by the
        # same index, so unequal lists would mispair or fail mid-epoch.
        if len(self.hr_files) != len(self.lr_files):
            raise ValueError(
                f"HR/LR file count mismatch: {len(self.hr_files)} file(s) in "
                f"{hr_dir!r} vs {len(self.lr_files)} file(s) in {lr_dir!r}"
            )
        self.patch_size = patch_size
        self.scale = scale

    def __len__(self):
        """Return the number of HR/LR image pairs."""
        return len(self.hr_files)

    def __getitem__(self, idx):
        """Return one random aligned ``(lr_patch, hr_patch)`` pair for image ``idx``.

        Raises:
            ValueError: If the HR image is too small to contain the crop that
                corresponds to the chosen LR patch.
        """
        size = self.patch_size
        factor = self.scale
        to_tensor = transforms.ToTensor()

        with Image.open(self.hr_files[idx]) as hr_img:
            with Image.open(self.lr_files[idx]) as lr_img:
                # PIL reports (width, height); this matches the (h, w) of the
                # tensor the full image would convert to.
                w, h = lr_img.size
                top = torch.randint(0, h - size + 1, (1,)).item()
                left = torch.randint(0, w - size + 1, (1,)).item()
                # Crop before converting: only the patch is turned into a
                # float tensor instead of the whole image.
                lr_crop = lr_img.crop(
                    (left, top, left + size, top + size)
                ).convert("RGB")

            hr_left, hr_top = left * factor, top * factor
            hr_right, hr_bottom = (left + size) * factor, (top + size) * factor
            hr_w, hr_h = hr_img.size
            # PIL pads out-of-range crops with zeros; refuse instead of
            # silently returning a partly black target patch.
            if hr_right > hr_w or hr_bottom > hr_h:
                raise ValueError(
                    f"HR image {self.hr_files[idx]!r} ({hr_w}x{hr_h}) is too "
                    f"small for LR image {self.lr_files[idx]!r} ({w}x{h}) "
                    f"at scale {factor}"
                )
            hr_crop = hr_img.crop(
                (hr_left, hr_top, hr_right, hr_bottom)
            ).convert("RGB")

        hr_patch = to_tensor(hr_crop)
        lr_patch = to_tensor(lr_crop)
        return lr_patch, hr_patch

class ESPCN(nn.Module):
    """Three-convolution network that upsamples with a pixel shuffle.

    Two ReLU-activated convolutions (3 -> 64 -> 32 channels) are followed by
    a convolution producing ``3 * scale * scale`` channels, which
    ``nn.PixelShuffle(scale)`` rearranges into a 3-channel output whose
    height and width are ``scale`` times those of the input. All
    convolutions are padded so they keep the spatial size.

    Args:
        scale: Upscaling factor passed to the pixel shuffle.
    """

    def __init__(self, scale):
        super().__init__()
        shuffled_channels = 3 * scale * scale
        self.layer1 = nn.Conv2d(
            in_channels=3, out_channels=64, kernel_size=5, padding=2
        )
        self.layer2 = nn.Conv2d(
            in_channels=64, out_channels=32, kernel_size=3, padding=1
        )
        self.layer3 = nn.Conv2d(
            in_channels=32,
            out_channels=shuffled_channels,
            kernel_size=3,
            padding=1,
        )
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=scale)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map an LR batch ``x`` to its upscaled prediction."""
        features = self.relu(self.layer1(x))
        features = self.relu(self.layer2(features))
        # No activation on the last convolution: its channels are rearranged
        # directly into the upscaled image.
        sub_pixels = self.layer3(features)
        return self.pixel_shuffle(sub_pixels)

# --- Paths -----------------------------------------------------------------
# Root folder that holds the four DIV2K sub-directories below; per-epoch
# checkpoints are also written here by the training loop.
# Raw string so the backslash in the Windows path is kept literally.
base_path = r"D:\DL Project"
# Upscaling factor; it selects the LR folder suffix ("..._x2") and is passed
# to both the dataset (LR -> HR coordinate mapping) and the model.
scale = 2
train_hr = os.path.join(base_path, "DIV2K_train_HR")
train_lr = os.path.join(base_path, f"DIV2K_train_LR_x{scale}")
valid_hr = os.path.join(base_path, "DIV2K_valid_HR")
valid_lr = os.path.join(base_path, f"DIV2K_valid_LR_x{scale}")

# --- Hyper-parameters ------------------------------------------------------
batch_size = 4
learning_rate = 1e-4
num_epochs = 20
# Side length of the square crop taken from each LR image; the matching HR
# crop is patch_size * scale on each side.
patch_size = 33

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Data ------------------------------------------------------------------
# One random patch is drawn per image each time it is indexed, so an epoch
# sees len(train_dataset) patches.
train_dataset = DIV2KPatchDataset(train_hr, train_lr, patch_size=patch_size, scale=scale)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# NOTE: validation reuses the same random-crop dataset, so the validation
# metric is measured on a different random patch of each image every epoch.
valid_dataset = DIV2KPatchDataset(valid_hr, valid_lr, patch_size=patch_size, scale=scale)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

# --- Model, loss, optimiser ------------------------------------------------
model = ESPCN(scale).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

def psnr(pred, target, max_val=1.0):
    """Return the peak signal-to-noise ratio, in dB, between two tensors.

    Computed as ``10 * log10(max_val ** 2 / MSE)``, where the MSE is the mean
    over every element of the inputs (so a multi-image batch yields a single
    value based on the pooled error, not a mean of per-image PSNRs).

    Args:
        pred: Predicted tensor.
        target: Reference tensor with the same shape as ``pred``.
        max_val: Largest value the signal can take. The default of 1.0
            matches the previous hard-coded behaviour; pass 255.0 for
            8-bit-range data.

    Returns:
        A zero-dimensional tensor. It is ``inf`` when ``pred`` and ``target``
        are identical, because the MSE is then zero.
    """
    mse = F.mse_loss(pred, target)
    return 10 * torch.log10(max_val ** 2 / mse)

# Main loop: one training pass and one validation pass per epoch, followed by
# a checkpoint of the model weights.
for epoch in range(1, num_epochs + 1):
    # ---- Training pass ----
    model.train()
    running_loss = 0
    for lr_patches, hr_patches in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"):
        lr_patches = lr_patches.to(device)
        hr_patches = hr_patches.to(device)
        # Clear gradients left over from the previous step before the new
        # backward pass accumulates into them.
        optimizer.zero_grad()
        outputs = model(lr_patches)
        loss = criterion(outputs, hr_patches)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by the batch size so
        # the division below gives a per-sample average even when the last
        # batch is smaller than batch_size.
        running_loss += loss.item() * lr_patches.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch}/{num_epochs}] Loss: {epoch_loss:.6f}")

    # ---- Validation pass ----
    model.eval()
    psnr_total = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for lr_patches, hr_patches in valid_loader:
            lr_patches = lr_patches.to(device)
            hr_patches = hr_patches.to(device)
            outputs = model(lr_patches)
            # The raw model output is scored as-is (it is not clamped).
            psnr_total += psnr(outputs, hr_patches).item()

    # valid_loader uses batch_size=1, so this is the mean PSNR per validation
    # patch.
    print(f"Validation PSNR: {psnr_total / len(valid_loader):.2f} dB")
    # Keep a separate weights file for every epoch.
    torch.save(model.state_dict(), os.path.join(base_path, f"espcn_epoch{epoch}.pth"))

print("Training complete")


Epoch 1/20: 100%|██████████| 200/200 [02:34<00:00,  1.29it/s]


Epoch [1/20] Loss: 0.059943
Validation PSNR: 19.84 dB


Epoch 2/20: 100%|██████████| 200/200 [02:06<00:00,  1.57it/s]


Epoch [2/20] Loss: 0.009376
Validation PSNR: 23.54 dB


Epoch 3/20: 100%|██████████| 200/200 [02:00<00:00,  1.65it/s]


Epoch [3/20] Loss: 0.004707
Validation PSNR: 26.68 dB


Epoch 4/20: 100%|██████████| 200/200 [01:47<00:00,  1.87it/s]


Epoch [4/20] Loss: 0.003642
Validation PSNR: 27.49 dB


Epoch 5/20: 100%|██████████| 200/200 [01:51<00:00,  1.80it/s]


Epoch [5/20] Loss: 0.003090
Validation PSNR: 27.65 dB


Epoch 6/20: 100%|██████████| 200/200 [01:45<00:00,  1.89it/s]


Epoch [6/20] Loss: 0.002714
Validation PSNR: 28.72 dB


Epoch 7/20: 100%|██████████| 200/200 [01:46<00:00,  1.87it/s]


Epoch [7/20] Loss: 0.002613
Validation PSNR: 28.69 dB


Epoch 8/20: 100%|██████████| 200/200 [01:52<00:00,  1.78it/s]


Epoch [8/20] Loss: 0.002330
Validation PSNR: 29.17 dB


Epoch 9/20: 100%|██████████| 200/200 [01:56<00:00,  1.71it/s]


Epoch [9/20] Loss: 0.002231
Validation PSNR: 29.79 dB


Epoch 10/20: 100%|██████████| 200/200 [01:53<00:00,  1.77it/s]


Epoch [10/20] Loss: 0.002050
Validation PSNR: 29.35 dB


Epoch 11/20: 100%|██████████| 200/200 [01:51<00:00,  1.79it/s]


Epoch [11/20] Loss: 0.002121
Validation PSNR: 30.27 dB


Epoch 12/20: 100%|██████████| 200/200 [01:50<00:00,  1.81it/s]


Epoch [12/20] Loss: 0.001913
Validation PSNR: 30.63 dB


Epoch 13/20: 100%|██████████| 200/200 [01:45<00:00,  1.89it/s]


Epoch [13/20] Loss: 0.001856
Validation PSNR: 28.61 dB


Epoch 14/20: 100%|██████████| 200/200 [01:45<00:00,  1.90it/s]


Epoch [14/20] Loss: 0.001857
Validation PSNR: 30.11 dB


Epoch 15/20: 100%|██████████| 200/200 [01:43<00:00,  1.93it/s]


Epoch [15/20] Loss: 0.001690
Validation PSNR: 31.02 dB


Epoch 16/20: 100%|██████████| 200/200 [01:42<00:00,  1.95it/s]


Epoch [16/20] Loss: 0.001831
Validation PSNR: 31.46 dB


Epoch 17/20: 100%|██████████| 200/200 [01:44<00:00,  1.91it/s]


Epoch [17/20] Loss: 0.001702
Validation PSNR: 30.68 dB


Epoch 18/20: 100%|██████████| 200/200 [01:44<00:00,  1.92it/s]


Epoch [18/20] Loss: 0.001697
Validation PSNR: 31.46 dB


Epoch 19/20: 100%|██████████| 200/200 [01:42<00:00,  1.95it/s]


Epoch [19/20] Loss: 0.001565
Validation PSNR: 30.99 dB


Epoch 20/20: 100%|██████████| 200/200 [01:43<00:00,  1.93it/s]


Epoch [20/20] Loss: 0.001621
Validation PSNR: 31.84 dB
Training complete
