In [32]:
from __future__ import print_function, division
import os
import glob
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.utils.data as data
from math import log10

from os import listdir
from os.path import join
from PIL import Image
# Ignore warnings.
# NOTE(review): this silences *every* warning for the rest of the kernel
# session, including deprecation warnings from torch / torchvision / PIL.
# Consider narrowing it with a `category=` or `message=` filter.
import warnings
warnings.filterwarnings("ignore")

# Turn on matplotlib interactive mode.
# NOTE(review): none of the visible cells plot anything, so this call may be
# unnecessary.
plt.ion()   # interactive mode


In [33]:


def load_img(filepath):
    """Load an image file and split it into its YCbCr channels.

    Args:
        filepath: Path to an image file that PIL can open.

    Returns:
        Tuple ``(y, cb, cr)`` of single-band PIL Images: luma (Y),
        blue-difference chroma (Cb) and red-difference chroma (Cr).
    """
    # Open the file inside a context manager so its handle is released.
    # convert() returns an independent in-memory image, so it stays valid
    # after the source image is closed.
    with Image.open(filepath) as img:
        ycbcr = img.convert('YCbCr')
    y, cb, cr = ycbcr.split()
    return y, cb, cr
    
class SuperResolutionDataset(Dataset):
    """Dataset of (input Y, target Y) luma-channel pairs for super-resolution.

    Each image file directly inside ``root_dir`` is converted to YCbCr, and
    only its luma (Y) channel is used:

    - ``input_transform`` turns the Y channel into the network input.
    - ``target_transform`` turns the Y channel into the training target.

    The Cb/Cr chroma channels are discarded here.
    ``reconstruct_color_image`` can later recombine a super-resolved Y channel
    with chroma channels (bicubically resized to match) for display.
    """

    def __init__(self, root_dir, input_transform=None, target_transform=None):
        """
        Args:
            root_dir: Directory holding the image files (not searched
                recursively).
            input_transform: Optional callable applied to the Y-channel PIL
                Image to build the network input.
            target_transform: Optional callable applied to the Y-channel PIL
                Image to build the target.
        """
        self.input_transform = input_transform
        self.target_transform = target_transform
        self.root_dir = root_dir
        # Sort because listdir() order is arbitrary; a stable order gives a
        # deterministic index -> file mapping. Skip sub-directories and hidden
        # entries (e.g. .DS_Store, .ipynb_checkpoints), which are not images.
        self.img_list = [
            join(root_dir, name)
            for name in sorted(listdir(root_dir))
            if not name.startswith('.') and os.path.isfile(join(root_dir, name))
        ]

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        """Return ``(input_y, target_y)`` for the image at position ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Only the luma channel is used for training; chroma is ignored.
        y, _cb, _cr = load_img(self.img_list[idx])

        # Network input: whatever input_transform produces from a copy of Y
        # (a downscaled version in get_training_set / get_test_set).
        input_y = y.copy()
        if self.input_transform:
            input_y = self.input_transform(input_y)

        # Target: the full-resolution Y channel.
        target_y = y
        if self.target_transform:
            target_y = self.target_transform(target_y)

        return input_y, target_y

In [34]:

def _build_sr_dataset(root_dir, split, input_size):
    """Build the SuperResolutionDataset for one split sub-directory of ``root_dir``."""
    return SuperResolutionDataset(
        join(root_dir, split),
        # Input: Y channel resized to input_size (height, width), then tensor.
        input_transform=transforms.Compose([
            transforms.Resize(input_size),
            transforms.ToTensor(),
        ]),
        # Target: untouched full-resolution Y channel as a tensor.
        target_transform=transforms.ToTensor(),
    )


def get_training_set(root_dir, input_size=(540, 960)):
    """Return the training dataset read from ``<root_dir>/train``.

    Args:
        root_dir: Dataset root containing a ``train`` sub-directory.
        input_size: (height, width) the network input is resized to.

    NOTE(review): the target is not resized. For MSE against the model output,
    the target size must equal ``input_size`` times the model's upscale factor
    (presumably 1080x1920 for the default size and factor 2) — confirm.
    Default batching also needs every image in a batch to have the same size.
    """
    return _build_sr_dataset(root_dir, "train", input_size)


def get_test_set(root_dir, input_size=(540, 960)):
    """Return the test dataset read from ``<root_dir>/test``.

    Args:
        root_dir: Dataset root containing a ``test`` sub-directory.
        input_size: (height, width) the network input is resized to.

    NOTE(review): as for the training set, the target is not resized and must
    match the upscaled input size — confirm.
    """
    return _build_sr_dataset(root_dir, "test", input_size)

In [35]:
import torch.nn as nn
import torch.nn.init as init


class Net(nn.Module):
    """Single-channel super-resolution network built on sub-pixel convolution.

    Three ReLU-activated convolutions keep the input's spatial size. A fourth
    convolution emits ``upscale_factor ** 2`` channels, which
    ``nn.PixelShuffle`` rearranges into one channel that is ``upscale_factor``
    times larger in height and width.
    """

    def __init__(self, upscale_factor):
        super().__init__()
        shuffle_channels = upscale_factor ** 2

        # Registration order (relu, conv1..conv4, pixel_shuffle) fixes both
        # the state_dict key order and the order in which default parameter
        # initialisation draws from the RNG.
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv3 = nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv4 = nn.Conv2d(32, shuffle_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        features = x
        for hidden_conv in (self.conv1, self.conv2, self.conv3):
            features = self.relu(hidden_conv(features))
        return self.pixel_shuffle(self.conv4(features))

    def _initialize_weights(self):
        # Hidden layers feed a ReLU, so scale their orthogonal init by the
        # ReLU gain; the output layer uses the default gain of 1.
        relu_gain = init.calculate_gain('relu')
        for hidden_conv in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(hidden_conv.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)

In [36]:
import torch.optim as optim

# Configuration: everything a reader might want to tune.
SEED = 123                  # fixes weight initialisation (and any later torch RNG use)
DATA_DIR = './1080Dataset/'
UPSCALE_FACTOR = 2
LEARNING_RATE = 0.01

# Seed before building the model so Conv2d / orthogonal_ initialisation is
# reproducible. torch.manual_seed seeds the RNG on all devices.
torch.manual_seed(SEED)

cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

model = Net(upscale_factor=UPSCALE_FACTOR).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_set = get_training_set(DATA_DIR)
test_set = get_test_set(DATA_DIR)



In [37]:
BATCH_SIZE = 10

# Reshuffle the training data every epoch so batches differ between epochs.
# Keep the test order fixed so per-epoch PSNR numbers are directly comparable.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

In [38]:
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``device``. Prints the loss of every batch and the
    mean batch loss for the epoch.

    Args:
        epoch: Epoch number; used only in the log messages.

    Returns:
        Mean batch loss for the epoch, or ``float('nan')`` if
        ``train_loader`` yields no batches.
    """
    model.train()
    num_batches = len(train_loader)
    if num_batches == 0:
        # Avoid a ZeroDivisionError when averaging; report clearly instead.
        print("===> Epoch {}: no training batches".format(epoch))
        return float('nan')

    epoch_loss = 0.0
    for iteration, batch in enumerate(train_loader, 1):
        low_res, high_res = batch[0].to(device), batch[1].to(device)

        optimizer.zero_grad()
        loss = criterion(model(low_res), high_res)
        loss_value = loss.item()
        epoch_loss += loss_value
        loss.backward()
        optimizer.step()

        print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, num_batches, loss_value))

    avg_loss = epoch_loss / num_batches
    print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, avg_loss))
    return avg_loss

def test():
    """Evaluate ``model`` on ``test_loader`` and report the mean PSNR.

    Each batch's PSNR is ``10 * log10(1 / MSE)``, where MSE comes from
    ``criterion``. This assumes pixel values in [0, 1], which is what
    ``transforms.ToTensor`` produces for 8-bit images. The batch PSNRs are
    then averaged over batches.

    Returns:
        Mean PSNR in dB. It is ``inf`` if any batch is reconstructed
        perfectly, and ``float('nan')`` if ``test_loader`` yields no batches.
    """
    num_batches = len(test_loader)
    if num_batches == 0:
        print("===> No test batches; PSNR not computed")
        return float('nan')

    # Evaluate in eval mode, then restore whatever mode the model was in.
    was_training = model.training
    model.eval()
    avg_psnr = 0.0
    try:
        with torch.no_grad():
            for batch in test_loader:
                low_res, high_res = batch[0].to(device), batch[1].to(device)

                prediction = model(low_res)
                mse = criterion(prediction, high_res).item()
                # A perfect reconstruction has infinite PSNR; avoid dividing by zero.
                psnr = float('inf') if mse == 0 else 10 * log10(1 / mse)
                avg_psnr += psnr
    finally:
        model.train(was_training)

    avg_psnr /= num_batches
    print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr))
    return avg_psnr


def checkpoint(epoch):
    """Save the notebook-level ``model`` to ``model_epoch_<epoch>.pth``.

    The file is written to the current working directory, and the saved path
    is printed.

    NOTE(review): ``torch.save(model, ...)`` pickles the whole module object.
    Loading it later requires the ``Net`` class to be importable under the
    same name. Saving ``model.state_dict()`` is the more portable option.
    """
    destination = "model_epoch_{}.pth".format(epoch)
    torch.save(model, destination)
    print("Checkpoint saved to {}".format(destination))
    

In [None]:
def _channel_to_pil(channel):
    """Convert one image channel to a PIL Image; PIL Images pass through unchanged.

    For tensors:

    - Detach and move to the CPU, so tensors that require grad (e.g. raw model
      outputs) convert cleanly.
    - Drop a leading batch dimension of size 1, so a (1, 1, H, W) model output
      is accepted.
    - Clamp floating-point values to [0, 1]. ``ToPILImage`` scales float
      tensors by 255 and casts to uint8, so out-of-range network outputs would
      otherwise overflow into wrong pixel values.
    """
    if not torch.is_tensor(channel):
        return channel
    channel = channel.detach().cpu()
    if channel.dim() == 4 and channel.size(0) == 1:
        channel = channel.squeeze(0)
    if channel.is_floating_point():
        channel = channel.clamp(0.0, 1.0)
    return transforms.ToPILImage()(channel)


def reconstruct_color_image(y_channel, cb_channel, cr_channel):
    """Reconstruct a color image from Y, Cb, and Cr channels.

    Args:
        y_channel: Super-resolved Y channel. Either a PIL Image or a tensor of
            shape (H, W), (1, H, W) or (1, 1, H, W); float tensors are
            expected in [0, 1] and are clamped to that range.
        cb_channel: Cb channel (PIL Image or tensor, same rules as Y). It is
            bicubically resized to the Y channel's size if the sizes differ.
        cr_channel: Cr channel (PIL Image or tensor, same rules as Y). It is
            bicubically resized to the Y channel's size if the sizes differ.

    Returns:
        RGB PIL Image with the same size as the Y channel.
    """
    y_img = _channel_to_pil(y_channel)
    cb_img = _channel_to_pil(cb_channel)
    cr_img = _channel_to_pil(cr_channel)

    # Chroma is usually lower resolution than the super-resolved luma.
    target_size = y_img.size
    if cb_img.size != target_size:
        cb_img = cb_img.resize(target_size, Image.BICUBIC)
    if cr_img.size != target_size:
        cr_img = cr_img.resize(target_size, Image.BICUBIC)

    ycbcr_image = Image.merge('YCbCr', [y_img, cb_img, cr_img])
    return ycbcr_image.convert('RGB')


In [None]:
NUM_EPOCHS = 100

# Epochs are numbered 1..NUM_EPOCHS inclusive, so log lines and checkpoint
# files start at epoch 1 (model_epoch_1.pth) and exactly NUM_EPOCHS epochs run.
for epoch in range(1, NUM_EPOCHS + 1):
    train(epoch)
    test()
    checkpoint(epoch)