In [4]:
import os
import glob
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm
import torch.nn.functional as F

class DIV2KPatchDataset(Dataset):
    def __init__(self, hr_dir, lr_dir, patch_size=33, scale=4):
        self.hr_files = sorted(glob.glob(os.path.join(hr_dir, "*.*")))
        self.lr_files = sorted(glob.glob(os.path.join(lr_dir, "*.*")))
        self.patch_size = patch_size
        self.scale = scale

    def __len__(self):
        return len(self.hr_files)

    def __getitem__(self, idx):
        hr_img = Image.open(self.hr_files[idx]).convert("RGB")
        lr_img = Image.open(self.lr_files[idx]).convert("RGB")

        hr_img = transforms.ToTensor()(hr_img)
        lr_img = transforms.ToTensor()(lr_img)

        _, h, w = lr_img.shape
        top = torch.randint(0, h - self.patch_size + 1, (1,)).item()
        left = torch.randint(0, w - self.patch_size + 1, (1,)).item()
        lr_patch = lr_img[:, top:top+self.patch_size, left:left+self.patch_size]
        hr_patch = hr_img[:, top*self.scale:(top+self.patch_size)*self.scale,
                          left*self.scale:(left+self.patch_size)*self.scale]
        lr_patch_up = F.interpolate(lr_patch.unsqueeze(0), scale_factor=self.scale, mode='bicubic', align_corners=False)
        lr_patch_up = lr_patch_up.squeeze(0)

        return lr_patch_up, hr_patch

class SRCNN(nn.Module):
    def __init__(self):
        super(SRCNN, self).__init__()
        self.layer1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.layer2 = nn.Conv2d(64, 32, kernel_size=1, padding=0)
        self.layer3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.layer1(x))
        x = self.relu(self.layer2(x))
        x = self.layer3(x)
        return x

base_path = r"D:\DL Project"
train_hr = os.path.join(base_path, "DIV2K_train_HR")
train_lr = os.path.join(base_path, "DIV2K_train_LR_x2")
valid_hr = os.path.join(base_path, "DIV2K_valid_HR")
valid_lr = os.path.join(base_path, "DIV2K_valid_LR_x2")

batch_size = 4
lr = 1e-4
num_epochs = 20
patch_size = 33
scale = 2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

train_dataset = DIV2KPatchDataset(train_hr, train_lr, patch_size=patch_size, scale=scale)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

valid_dataset = DIV2KPatchDataset(valid_hr, valid_lr, patch_size=patch_size, scale=scale)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

model = SRCNN().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

def psnr(pred, target):
    mse = nn.functional.mse_loss(pred, target)
    return 10 * torch.log10(1 / mse)

for epoch in range(1, num_epochs + 1):
    model.train()
    running_loss = 0.0
    for lr_patches, hr_patches in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"):
        lr_patches = lr_patches.to(device)
        hr_patches = hr_patches.to(device)

        optimizer.zero_grad()
        outputs = model(lr_patches)
        loss = criterion(outputs, hr_patches)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * lr_patches.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch}/{num_epochs}] Loss: {epoch_loss:.6f}")

    model.eval()
    psnr_total = 0.0
    with torch.no_grad():
        for lr_patches, hr_patches in valid_loader:
            lr_patches = lr_patches.to(device)
            hr_patches = hr_patches.to(device)
            outputs = model(lr_patches)
            psnr_total += psnr(outputs, hr_patches).item()
    avg_psnr = psnr_total / len(valid_loader)
    print(f"Validation PSNR: {avg_psnr:.2f} dB")
    torch.save(model.state_dict(), os.path.join(base_path, f"srcnn_epoch{epoch}.pth"))

print("Training complete!")


Using device: cuda


Epoch 1/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1/20: 100%|██████████| 200/200 [02:01<00:00,  1.65it/s]


Epoch [1/20] Loss: 0.042270
Validation PSNR: 21.01 dB


Epoch 2/20: 100%|██████████| 200/200 [02:00<00:00,  1.66it/s]


Epoch [2/20] Loss: 0.006748
Validation PSNR: 24.40 dB


Epoch 3/20: 100%|██████████| 200/200 [02:00<00:00,  1.66it/s]


Epoch [3/20] Loss: 0.004258
Validation PSNR: 26.37 dB


Epoch 4/20: 100%|██████████| 200/200 [01:46<00:00,  1.88it/s]


Epoch [4/20] Loss: 0.003221
Validation PSNR: 26.54 dB


Epoch 5/20: 100%|██████████| 200/200 [01:43<00:00,  1.93it/s]


Epoch [5/20] Loss: 0.002903
Validation PSNR: 28.65 dB


Epoch 6/20: 100%|██████████| 200/200 [01:44<00:00,  1.91it/s]


Epoch [6/20] Loss: 0.002716
Validation PSNR: 28.45 dB


Epoch 7/20: 100%|██████████| 200/200 [01:46<00:00,  1.88it/s]


Epoch [7/20] Loss: 0.002375
Validation PSNR: 29.20 dB


Epoch 8/20: 100%|██████████| 200/200 [01:45<00:00,  1.90it/s]


Epoch [8/20] Loss: 0.002325
Validation PSNR: 29.78 dB


Epoch 9/20: 100%|██████████| 200/200 [01:44<00:00,  1.91it/s]


Epoch [9/20] Loss: 0.002150
Validation PSNR: 29.94 dB


Epoch 10/20: 100%|██████████| 200/200 [01:44<00:00,  1.90it/s]


Epoch [10/20] Loss: 0.002057
Validation PSNR: 29.04 dB


Epoch 11/20: 100%|██████████| 200/200 [01:44<00:00,  1.91it/s]


Epoch [11/20] Loss: 0.002000
Validation PSNR: 30.11 dB


Epoch 12/20: 100%|██████████| 200/200 [01:44<00:00,  1.91it/s]


Epoch [12/20] Loss: 0.002042
Validation PSNR: 30.00 dB


Epoch 13/20: 100%|██████████| 200/200 [01:44<00:00,  1.91it/s]


Epoch [13/20] Loss: 0.001778
Validation PSNR: 30.63 dB


Epoch 14/20: 100%|██████████| 200/200 [01:45<00:00,  1.90it/s]


Epoch [14/20] Loss: 0.001889
Validation PSNR: 29.76 dB


Epoch 15/20: 100%|██████████| 200/200 [01:44<00:00,  1.91it/s]


Epoch [15/20] Loss: 0.001577
Validation PSNR: 31.06 dB


Epoch 16/20: 100%|██████████| 200/200 [01:45<00:00,  1.90it/s]


Epoch [16/20] Loss: 0.001682
Validation PSNR: 30.54 dB


Epoch 17/20: 100%|██████████| 200/200 [01:45<00:00,  1.90it/s]


Epoch [17/20] Loss: 0.001566
Validation PSNR: 31.83 dB


Epoch 18/20: 100%|██████████| 200/200 [01:45<00:00,  1.90it/s]


Epoch [18/20] Loss: 0.001683
Validation PSNR: 31.23 dB


Epoch 19/20: 100%|██████████| 200/200 [01:44<00:00,  1.91it/s]


Epoch [19/20] Loss: 0.001627
Validation PSNR: 30.79 dB


Epoch 20/20: 100%|██████████| 200/200 [01:44<00:00,  1.91it/s]


Epoch [20/20] Loss: 0.001591
Validation PSNR: 31.27 dB
Training complete!
