In [None]:
import math
import json
import time
import torch
import os
import glob
import h5py
import random
import numpy as np
from PIL import Image
import importlib
from collections import OrderedDict

import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.autograd import Variable

import skimage.measure as measure
from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import structural_similarity as ssim
import scipy.misc as misc

In [None]:
class MeanShift(nn.Module):
    """Frozen 1x1 convolution that adds or subtracts a fixed per-channel RGB mean.

    Args:
        mean_rgb: sequence whose first three entries are the R, G and B means.
        sub: truthy to subtract the means, falsy to add them back.
    """

    def __init__(self, mean_rgb, sub):
        super().__init__()

        sign = -1 if sub else 1
        offsets = [mean_rgb[idx] * sign for idx in range(3)]

        # Identity weights reduce the conv to "x + bias" on each channel.
        self.shifter = nn.Conv2d(3, 3, 1, 1, 0)
        self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.shifter.bias.data = torch.Tensor(offsets)

        # The shift is a constant preprocessing step, never trained.
        self.shifter.requires_grad_(False)

    def forward(self, x):
        return self.shifter(x)


class BuildingBlock(nn.Module):
    """A single Conv2d followed by an in-place ReLU.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        ksize, stride, pad: kernel size, stride and padding forwarded to nn.Conv2d.
    """

    def __init__(self,
                 in_channels, out_channels,
                 ksize=3, stride=1, pad=1):
        super().__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, ksize, stride, pad),
            nn.ReLU(inplace=True),
        ]
        self.body = nn.Sequential(*layers)

        init_weights(self.modules)

    def forward(self, x):
        return self.body(x)


class ResidualBlock(nn.Module):
    """Two 3x3 convs (ReLU between them) whose output is added to the input and ReLU'd.

    The skip addition ``out + x`` needs matching shapes, so in practice
    ``in_channels == out_channels``.
    """

    def __init__(self,
                 in_channels, out_channels):
        super().__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
        ]
        self.body = nn.Sequential(*layers)

        init_weights(self.modules)

    def forward(self, x):
        return F.relu(self.body(x) + x)


class EResidualBlock(nn.Module):
    """Residual unit: two (optionally grouped) 3x3 convs with ReLU, then a 1x1 conv,
    added to the input and passed through a final ReLU.

    The skip addition ``out + x`` needs matching shapes, so in practice
    ``in_channels == out_channels``.
    """

    def __init__(self,
                 in_channels, out_channels,
                 group=1):
        super().__init__()

        layers = [
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 1, 1, 0),
        ]
        self.body = nn.Sequential(*layers)

        init_weights(self.modules)

    def forward(self, x):
        return F.relu(self.body(x) + x)


class UpsampleBlock(nn.Module):
    """Upsampling head, either for one fixed scale or selectable among x2/x3/x4.

    Args:
        n_channels: feature channels entering (and leaving) the upsampler.
        scale: upscaling factor, used only when ``multi_scale`` is falsy.
        multi_scale: if truthy, build x2, x3 and x4 branches and pick one per call.
        group: conv groups forwarded to each Up_UpsampleBlock.
    """

    def __init__(self,
                 n_channels, scale, multi_scale,
                 group=1):
        super(UpsampleBlock, self).__init__()

        if multi_scale:
            self.up2 = Up_UpsampleBlock(n_channels, scale=2, group=group)
            self.up3 = Up_UpsampleBlock(n_channels, scale=3, group=group)
            self.up4 = Up_UpsampleBlock(n_channels, scale=4, group=group)
        else:
            self.up = Up_UpsampleBlock(n_channels, scale=scale, group=group)

        self.multi_scale = multi_scale

    def forward(self, x, scale):
        """Upsample ``x``; ``scale`` selects the branch in multi-scale mode.

        Raises:
            ValueError: in multi-scale mode, if ``scale`` is not 2, 3 or 4
                (previously this silently returned None).
        """
        if not self.multi_scale:
            return self.up(x)
        if scale == 2:
            return self.up2(x)
        if scale == 3:
            return self.up3(x)
        if scale == 4:
            return self.up4(x)
        raise ValueError(
            "multi-scale UpsampleBlock supports scale 2, 3 or 4, got {!r}".format(scale))


class Up_UpsampleBlock(nn.Module):
    """Sub-pixel upsampler: conv + ReLU + PixelShuffle stages.

    A power-of-two scale uses log2(scale) x2 stages; scale 3 uses one x3 stage.

    Args:
        n_channels: feature channels (unchanged by the block).
        scale: upscaling factor; 3 or a power of two >= 2.
        group: conv groups for the expansion convolutions.

    Raises:
        ValueError: for any other scale. Previously this silently built an
            identity block that did not upsample at all.
    """

    def __init__(self,
                 n_channels, scale,
                 group=1):
        super(Up_UpsampleBlock, self).__init__()

        modules = []
        if scale == 3:
            modules += [nn.Conv2d(n_channels, 9*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
            modules += [nn.PixelShuffle(3)]
        elif (isinstance(scale, (int, float)) and not isinstance(scale, bool)
              and scale >= 2 and 2 ** round(math.log2(scale)) == scale):
            # Each stage expands channels 4x, then PixelShuffle(2) trades them for 2x resolution.
            for _ in range(round(math.log2(scale))):
                modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]
                modules += [nn.PixelShuffle(2)]
        else:
            raise ValueError(
                "Up_UpsampleBlock supports scale 3 or a power of two >= 2, got {!r}".format(scale))

        self.body = nn.Sequential(*modules)
        init_weights(self.modules)

    def forward(self, x):
        return self.body(x)

class Block(nn.Module):
    """Cascading block: three residual units whose outputs are accumulated by
    channel concatenation and fused back to ``out_channels`` with 1x1 convs.

    Args:
        in_channels: channels of the input. Must equal ``out_channels`` because the
            residual units add their input to their output.
        out_channels: channels of every intermediate feature and of the output.
        group: accepted for interface compatibility; unused by this block.

    Raises:
        ValueError: if ``in_channels != out_channels``.
    """

    def __init__(self,
                 in_channels, out_channels,
                 group=1):
        super(Block, self).__init__()

        # The channel arguments used to be ignored in favour of a hard-coded 64.
        if in_channels != out_channels:
            raise ValueError("Block needs in_channels == out_channels, got {} and {}"
                             .format(in_channels, out_channels))
        ch = out_channels

        self.b1 = ResidualBlock(ch, ch)
        self.b2 = ResidualBlock(ch, ch)
        self.b3 = ResidualBlock(ch, ch)
        self.c1 = BuildingBlock(ch*2, ch, 1, 1, 0)
        self.c2 = BuildingBlock(ch*3, ch, 1, 1, 0)
        self.c3 = BuildingBlock(ch*4, ch, 1, 1, 0)

    def forward(self, x):
        # `cascade` grows by `ch` channels per stage; each 1x1 fuse maps it back to `ch`.
        cascade = x
        out = x
        for res, fuse in ((self.b1, self.c1), (self.b2, self.c2), (self.b3, self.c3)):
            cascade = torch.cat([cascade, res(out)], dim=1)
            out = fuse(cascade)
        return out

def init_weights(modules):
    """Placeholder weight initializer; intentionally does nothing.

    The network blocks call it with their bound ``self.modules`` method, so every
    layer keeps PyTorch's default initialization.
    """
    return None

class CARN(nn.Module):
    """Cascading Residual Network for single-image super-resolution.

    Keyword args:
        scale: upscaling factor used when ``multi_scale`` is falsy.
        multi_scale: if truthy, the upsampler holds x2/x3/x4 branches and
            ``forward``'s ``scale`` argument picks one.
        group: conv groups forwarded to the upsampler (default 1).
    """

    def __init__(self, **kwargs):
        super().__init__()

        scale = kwargs.get("scale")
        multi_scale = kwargs.get("multi_scale")
        group = kwargs.get("group", 1)

        rgb_mean = (0.4488, 0.4371, 0.4040)
        self.sub_mean = MeanShift(rgb_mean, sub=True)
        self.add_mean = MeanShift(rgb_mean, sub=False)

        self.entry = nn.Conv2d(3, 64, 3, 1, 1)

        # Registration order is kept stable so state_dict keys and init RNG usage match.
        self.b1 = Block(64, 64)
        self.b2 = Block(64, 64)
        self.b3 = Block(64, 64)
        self.c1 = BuildingBlock(64 * 2, 64, 1, 1, 0)
        self.c2 = BuildingBlock(64 * 3, 64, 1, 1, 0)
        self.c3 = BuildingBlock(64 * 4, 64, 1, 1, 0)

        self.upsample = UpsampleBlock(64, scale=scale,
                                      multi_scale=multi_scale,
                                      group=group)
        self.exit = nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, x, scale):
        feat = self.entry(self.sub_mean(x))

        # Global cascade: every stage sees the concatenation of all earlier features.
        cascade = feat
        out = feat
        for block, fuse in ((self.b1, self.c1), (self.b2, self.c2), (self.b3, self.c3)):
            cascade = torch.cat([cascade, block(out)], dim=1)
            out = fuse(cascade)

        upsampled = self.upsample(out, scale=scale)
        return self.add_mean(self.exit(upsampled))

In [None]:

def random_crop(hr, lr, size, scale):
    """Cut spatially aligned random patches from an LR image and its HR counterpart.

    Args:
        hr: HR image array with rows/cols as its first two axes.
        lr: LR image array, (H, W) or (H, W, C).
        size: side length of the square LR patch.
        scale: upscaling factor; the HR patch is ``size * scale`` pixels square and
            starts at the LR offset times ``scale``.

    Returns:
        ``(crop_hr, crop_lr)`` as independent copies.

    Raises:
        ValueError: if the LR image is smaller than ``size`` in either dimension.
    """
    # Only the first two axes are spatial, so grayscale (2-D) arrays work too;
    # the previous `shape[:-1]` failed to unpack for them.
    h, w = lr.shape[:2]
    if h < size or w < size:
        raise ValueError("LR image {}x{} is smaller than crop size {}".format(h, w, size))
    x = random.randint(0, w - size)
    y = random.randint(0, h - size)

    hsize = size * scale
    hx, hy = x * scale, y * scale

    crop_lr = lr[y:y + size, x:x + size].copy()
    crop_hr = hr[hy:hy + hsize, hx:hx + hsize].copy()
    return crop_hr, crop_lr


def random_flip_and_rotate(im1, im2):
    """Apply one shared random flip/rotation to a pair of images.

    The random draws happen in this order: an up-down flip coin, a left-right flip
    coin, then a number of 90-degree turns from {0, 1, 2, 3}. The same draws are
    applied to both images so they stay aligned. Returns contiguous copies.
    """
    pair = [im1, im2]
    if random.random() < 0.5:
        pair = [np.flipud(im) for im in pair]
    if random.random() < 0.5:
        pair = [np.fliplr(im) for im in pair]

    turns = random.choice([0, 1, 2, 3])
    first, second = (np.rot90(im, turns).copy() for im in pair)
    return first, second

# Train Dataset

In [None]:
class Preprcess_Train_Dataset(data.Dataset):
    """Training pairs loaded fully into memory from an HDF5 file.

    The file must contain an ``HR`` group and one ``X<scale>`` group per scale.
    Images are paired by their position in each group. ``scale == 0`` loads x2, x3
    and x4; any other value loads just that scale.

    Each item is a list with one ``(hr_tensor, lr_tensor)`` pair per loaded scale,
    randomly cropped (``size`` LR pixels square) and flipped/rotated.
    """

    def __init__(self, path, size, scale):
        super(Preprcess_Train_Dataset, self).__init__()

        self.size = size
        self.scale = [2, 3, 4] if scale == 0 else [scale]

        # `with` closes the file even if reading a group fails part-way through.
        with h5py.File(path, "r") as h5f:
            self.hr = [v[:] for v in h5f["HR"].values()]
            self.lr = [[v[:] for v in h5f["X{}".format(s)].values()]
                       for s in self.scale]

        # HR and LR images are paired purely by position, so the counts must agree.
        for s, lr_images in zip(self.scale, self.lr):
            if len(lr_images) != len(self.hr):
                raise ValueError("X{} group has {} images but HR has {}"
                                 .format(s, len(lr_images), len(self.hr)))

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, index):
        size = self.size

        # All crops are drawn before any flips, which keeps the RNG consumption order.
        item = [(self.hr[index], lr_images[index]) for lr_images in self.lr]
        item = [random_crop(hr, lr, size, s) for s, (hr, lr) in zip(self.scale, item)]
        item = [random_flip_and_rotate(hr, lr) for hr, lr in item]

        return [(self.transform(hr), self.transform(lr)) for hr, lr in item]

    def __len__(self):
        return len(self.hr)

# Test Dataset

In [None]:
class Preprcess_Test_Dataset(data.Dataset):
    """Paired HR/LR evaluation images read from PNG files.

    Two directory layouts are recognised:
      * directory name contains "DIV" (DIV2K): ``<dirname>_HR/*.png`` and
        ``<dirname>_LR_bicubic/X<scale>/*.png``;
      * otherwise (e.g. Set14): ``<dirname>/x<scale>/*.png``, split into HR and LR
        by whether the file name contains "HR" or "LR".
    Pairs are formed by sorting both lists, so matching files must sort alike.

    Each item is ``(hr_tensor, lr_tensor, hr_filename)``.

    Raises:
        ValueError: if the HR and LR file counts differ (pairs would be misaligned).
    """

    def __init__(self, dirname, scale):
        super(Preprcess_Test_Dataset, self).__init__()

        # normpath drops a trailing "/" so the name and the DIV2K sibling dirs resolve.
        dirname = os.path.normpath(dirname)
        self.name = os.path.basename(dirname)
        self.scale = scale

        if "DIV" in self.name:
            self.hr = glob.glob(os.path.join("{}_HR".format(dirname), "*.png"))
            self.lr = glob.glob(os.path.join("{}_LR_bicubic".format(dirname),
                                             "X{}/*.png".format(scale)))
        else:
            all_files = glob.glob(os.path.join(dirname, "x{}/*.png".format(scale)))
            # Match on the file name only; "HR"/"LR" in a parent directory must not count.
            self.hr = [p for p in all_files if "HR" in os.path.basename(p)]
            self.lr = [p for p in all_files if "LR" in os.path.basename(p)]
        self.hr.sort()
        self.lr.sort()

        if len(self.hr) != len(self.lr):
            raise ValueError("found {} HR but {} LR images for {} (x{})"
                             .format(len(self.hr), len(self.lr), dirname, scale))

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, index):
        hr_path, lr_path = self.hr[index], self.lr[index]
        # convert() returns a loaded copy, so the files can be closed right away.
        with Image.open(hr_path) as hr_img:
            hr = hr_img.convert("RGB")
        with Image.open(lr_path) as lr_img:
            lr = lr_img.convert("RGB")
        filename = os.path.basename(hr_path)
        return self.transform(hr), self.transform(lr), filename

    def __len__(self):
        return len(self.hr)


In [None]:
def _chop_forward(net, lr, scale, device, shave=20):
    """Super-resolve a (C, H, W) image by running `net` on four overlapping quadrants.

    Each quadrant extends ``shave`` LR pixels past the image centre (clamped to the
    image size) so the stitched seams come from interior, not border, outputs.
    Returns a (C, H*scale, W*scale) tensor on `device`.
    """
    _, h, w = lr.shape
    h_half, w_half = h // 2, w // 2
    # Clamping keeps the slices full-sized on small images; without it the patch
    # shapes disagree and copying them fails.
    h_chop, w_chop = min(h, h_half + shave), min(w, w_half + shave)

    lr_patch = torch.stack([
        lr[:, 0:h_chop, 0:w_chop],
        lr[:, 0:h_chop, w - w_chop:w],
        lr[:, h - h_chop:h, 0:w_chop],
        lr[:, h - h_chop:h, w - w_chop:w],
    ]).to(device)

    sr = net(lr_patch, scale)

    H, W = h * scale, w * scale
    Hh, Wh = h_half * scale, w_half * scale
    Hc, Wc = h_chop * scale, w_chop * scale

    result = sr.new_empty((sr.shape[1], H, W))
    result[:, 0:Hh, 0:Wh] = sr[0, :, 0:Hh, 0:Wh]
    result[:, 0:Hh, Wh:W] = sr[1, :, 0:Hh, Wc - W + Wh:Wc]
    result[:, Hh:H, 0:Wh] = sr[2, :, Hc - H + Hh:Hc, 0:Wh]
    result[:, Hh:H, Wh:W] = sr[3, :, Hc - H + Hh:Hc, Wc - W + Wh:Wc]
    return result


def evaluate(test_data_dir, carn, device, scale=2, num_step=0):
    """Mean PSNR (dB) of `carn` over the images in `test_data_dir`.

    PSNR is measured on uint8 images with a `scale`-pixel border removed. The model
    is left in eval mode.

    Args:
        test_data_dir: directory understood by Preprcess_Test_Dataset.
        carn: model called as ``carn(lr_batch, scale)``.
        device: device to run inference on.
        scale: upscaling factor.
        num_step: unused; kept for call compatibility.
    """
    mean_psnr = 0
    carn.eval()

    test_data = Preprcess_Test_Dataset(test_data_dir, scale=scale)
    test_loader = DataLoader(test_data,
                             batch_size=1,
                             num_workers=1,
                             shuffle=False)

    # No gradients are needed; skipping the autograd graph saves a lot of memory.
    with torch.no_grad():
        for hr, lr, _name in test_loader:
            hr = hr.squeeze(0)
            lr = lr.squeeze(0)

            sr = _chop_forward(carn, lr, scale, device)

            hr = hr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
            sr = sr.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()

            im1 = hr[scale:-scale, scale:-scale]
            im2 = sr[scale:-scale, scale:-scale]
            mean_psnr += psnr(im1, im2) / len(test_data)

    return mean_psnr

def load(carn, path, step=0):
    """Load a state_dict from `path` into `carn` and return the step encoded in its name.

    The step is the integer after the last "_" in the file's base name
    (e.g. ``ckpt_1500.pth`` -> 1500), or 0 if that part is not an integer.

    Args:
        carn: model to load the weights into.
        path: checkpoint file written with ``torch.save(model.state_dict(), ...)``.
        step: ignored; kept for backward compatibility (the parsed step used to be
            assigned to this local and lost).

    Returns:
        The parsed step.
    """
    # Load onto CPU first so GPU-saved checkpoints also work on CPU-only machines;
    # load_state_dict copies the values onto the model's own device.
    carn.load_state_dict(torch.load(path, map_location="cpu"))
    # Parse the base name only: dots in directory names broke the old split(".").
    stem = os.path.splitext(os.path.basename(path))[0]
    try:
        step = int(stem.split("_")[-1])
    except ValueError:
        step = 0
    print("Load pretrained {} model".format(path))
    return step

def save(ckpt_dir, ckpt_name, step, model=None):
    """Save a model's state_dict to ``<ckpt_dir>/<ckpt_name>_<step>.pth``.

    Args:
        ckpt_dir: output directory, created if missing.
        ckpt_name: file-name prefix.
        step: training step appended to the name.
        model: module to save. If omitted, falls back to the notebook-global
            ``carn`` (the previous, implicit behaviour).

    Returns:
        The path written.
    """
    if model is None:
        model = carn
    os.makedirs(ckpt_dir, exist_ok=True)
    save_path = os.path.join(ckpt_dir, "{}_{}.pth".format(ckpt_name, step))
    torch.save(model.state_dict(), save_path)
    return save_path

def decay_learning_rate(lr, step, decay_every=150000, gamma=0.5):
    """Step-decay schedule: return ``lr * gamma ** (step // decay_every)``.

    Args:
        lr: the *base* learning rate. Passing an already-decayed rate compounds the decay.
        step: current training step.
        decay_every: number of steps between decays (default 150000).
        gamma: multiplicative factor applied at each decay (default 0.5).
    """
    return lr * (gamma ** (step // decay_every))

def psnr(im1, im2):
    """Peak signal-to-noise ratio (dB) between two images, assuming an 8-bit range (0-255)."""
    return peak_signal_noise_ratio(im1, im2, data_range=255)

def compute_ssim(im1, im2, win_size=None, data_range=255):
    """Mean structural similarity between two images.

    Args:
        im1, im2: arrays of equal shape, (H, W, C) colour or (H, W) grayscale;
            both are cast to float before comparison.
        win_size: side length of the sliding window; ``None`` keeps skimage's
            default. This was previously accepted but ignored.
        data_range: span of possible pixel values (255 for 8-bit images, matching
            ``psnr``). It must describe the pixel type, not the image content.
            Estimating it from one image's own min/max made scores depend on
            that image's contrast.

    Returns:
        The mean SSIM as a float.
    """
    im1 = im1.astype(float)
    im2 = im2.astype(float)
    # The last axis holds colour channels for 3-D inputs; 2-D inputs have none.
    channel_axis = 2 if im1.ndim == 3 else None
    return ssim(im1, im2, channel_axis=channel_axis,
                data_range=data_range, win_size=win_size)

### To speed up training, first pack the DIV2K training images (HR plus x2/x3/x4 LR) into a single HDF5 (`.h5`) file

In [None]:
dataset_dir = "/home/desai.ven/CS7180/DIV2K"
dataset_type = "train"
# Written inside dataset_dir, which is where the training cell's train_data_path looks.
h5_path = os.path.join(dataset_dir, "DIV2K_{}.h5".format(dataset_type))

# `with` guarantees the HDF5 file is flushed and closed even if an image fails to load.
with h5py.File(h5_path, "w") as f:
    for subdir in ["HR", "X2", "X3", "X4"]:
        if subdir == "HR":
            pattern = os.path.join(dataset_dir,
                                   "DIV2K_{}_HR".format(dataset_type),
                                   "*.png")
        else:
            pattern = os.path.join(dataset_dir,
                                   "DIV2K_{}_LR_bicubic".format(dataset_type),
                                   subdir, "*.png")
        # Sorting keeps the groups index-aligned: dataset "i" is the i-th image in each.
        im_paths = sorted(glob.glob(pattern))
        grp = f.create_group(subdir)

        for i, path in enumerate(im_paths):
            # scipy.misc.imread no longer exists in current SciPy; decode with PIL instead.
            with Image.open(path) as img:
                im = np.array(img.convert("RGB"))
            grp.create_dataset(str(i), data=im)
        print("{}: stored {} images".format(subdir, len(im_paths)))

## Training the Model with scale=2

In [None]:
scale = 2
group = 1
patch_size = 64
batch_size = 64
learning_rate = 0.0001      # base LR; the decayed value is always recomputed from this
max_steps = 200000
print_interval = 1000       # evaluate + checkpoint every this many steps
clip = 10                   # max gradient L2 norm
ckpt_name = "chkpt_train_"
ckpt_dir = "/home/desai.ven/CS7180/DIV2K"
train_data_path = "/home/desai.ven/CS7180/DIV2K/DIV2K_train.h5"
valid_data_dir = "/home/desai.ven/CS7180/DIV2K/DIV2K_valid"

carn = CARN(scale=scale, group=group)
loss_fn = nn.L1Loss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, carn.parameters()), lr=learning_rate)

train_data = Preprcess_Train_Dataset(train_data_path,
                                     scale=scale,
                                     size=patch_size)
train_loader = DataLoader(train_data,
                          batch_size=batch_size,
                          num_workers=1,
                          shuffle=True, drop_last=True)
if len(train_loader) == 0:
    # With drop_last=True a dataset smaller than one batch yields no batches,
    # which would make the loop below spin forever.
    raise ValueError("training set has fewer than batch_size={} samples".format(batch_size))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
carn = carn.to(device)

steps = 0
os.makedirs(ckpt_dir, exist_ok=True)

carn.train()
while steps < max_steps:
    for inputs in train_loader:
        # Each batch is a list of (hr, lr) pairs, one per scale; use the last one.
        hr, lr = inputs[-1][0], inputs[-1][1]
        hr = hr.to(device)
        lr = lr.to(device)

        sr = carn(lr, scale)
        loss = loss_fn(sr, hr)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(carn.parameters(), clip)
        optimizer.step()

        # Decay from the *base* rate. Feeding the decayed value back in would
        # compound the 0.5 factor on every step once steps >= 150000.
        current_lr = decay_learning_rate(learning_rate, steps)
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr

        steps += 1
        if steps % print_interval == 0:
            # Not named `psnr`: that would shadow the psnr() helper evaluate() calls,
            # making the next evaluation crash.
            val_psnr = evaluate(test_data_dir=valid_data_dir, scale=scale, num_step=steps,
                                carn=carn, device=device)
            save(ckpt_dir, ckpt_name, steps)
            print("step {}: loss {:.4f}, valid PSNR {:.2f} dB, lr {:.2e}".format(
                steps, loss.item(), val_psnr, current_lr))
            carn.train()  # evaluate() leaves the model in eval mode

        if steps >= max_steps:
            break

In [None]:
def save_image(tensor, filename):
    """Write a (C, H, W) tensor (values assumed in [0, 1]) to `filename` as an 8-bit image."""
    pixels = (
        tensor.cpu()
        .mul(255)
        .clamp(0, 255)
        .byte()
        .permute(1, 2, 0)
        .numpy()
    )
    Image.fromarray(pixels).save(filename)


def _to_uint8_hwc(tensor):
    """Convert a (C, H, W) tensor (values assumed in [0, 1]) to an (H, W, C) uint8 NumPy array."""
    return tensor.cpu().mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()


def sample(net, device, dataset, ckpt_path, sample_dir, test_data_dir, scale):
    """Super-resolve every image in `dataset`, print PSNR/SSIM and save the SR and HR images.

    Outputs go to ``<sample_dir>/Results/<set name>/x<scale>/{SR,HR}/``. The set
    name is the last component of `test_data_dir`.

    Args:
        net: model called as ``net(lr_batch, scale)``.
        device: torch.device (or device string) to run inference on.
        dataset: iterable of ``(hr, lr, filename)`` with (C, H, W) image tensors.
        ckpt_path: unused; kept for interface compatibility.
        sample_dir: root directory for the saved images.
        test_data_dir: dataset directory; only used to name the output folder.
        scale: upscaling factor passed to `net`.
    """
    model_name = "Results"
    # normpath makes a trailing "/" harmless when extracting the set name.
    set_name = os.path.basename(os.path.normpath(test_data_dir))
    out_root = os.path.join(sample_dir, model_name, set_name, "x{}".format(scale))
    sr_dir = os.path.join(out_root, "SR")
    hr_dir = os.path.join(out_root, "HR")
    os.makedirs(sr_dir, exist_ok=True)
    os.makedirs(hr_dir, exist_ok=True)

    dev = torch.device(device)

    for hr, lr, name in dataset:
        t1 = time.time()
        lr = lr.unsqueeze(0).to(dev)
        with torch.no_grad():
            sr = net(lr, scale).squeeze(0)
        if dev.type == "cuda":
            # CUDA runs asynchronously; wait so t2 - t1 measures the real inference time.
            torch.cuda.synchronize(dev)
        t2 = time.time()
        lr = lr.squeeze(0)

        hr_np = _to_uint8_hwc(hr)
        sr_np = _to_uint8_hwc(sr)
        print("PSNR: ", psnr(hr_np, sr_np))
        print("SSIM: ", compute_ssim(hr_np, sr_np))

        sr_im_path = os.path.join(sr_dir, name.replace("HR", "SR"))
        hr_im_path = os.path.join(hr_dir, name)
        save_image(sr, sr_im_path)
        save_image(hr, hr_im_path)
        print("Saved {} ({}x{} -> {}x{}, {:.3f}s)"
              .format(sr_im_path, lr.shape[1], lr.shape[2], sr.shape[1], sr.shape[2], t2 - t1))


def main(scale, test_data_dir,
         ckpt_path="/home/desai.ven/CS7180/DIV2K/chkpt/chkpt_train_200000.pth",
         sample_dir="/home/desai.ven/CS7180/DIV2K/Testing Dataset"):
    """Load a multi-scale CARN checkpoint and super-resolve the test set in `test_data_dir`.

    Args:
        scale: upscaling factor to evaluate (2, 3 or 4 for the multi-scale model).
        test_data_dir: directory understood by Preprcess_Test_Dataset.
        ckpt_path: state_dict checkpoint to load (defaults to the original hard-coded path).
        sample_dir: root directory for the saved SR/HR images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = CARN(multi_scale=True, group=1)
    print("Scale:", scale)

    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    state_dict = torch.load(ckpt_path, map_location=device)
    # Drop the "module." prefix nn.DataParallel adds, so wrapped and plain checkpoints both load.
    prefix = "module."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()
    )
    net.load_state_dict(new_state_dict)

    net = net.to(device)
    net.eval()

    dataset = Preprcess_Test_Dataset(test_data_dir, scale)
    sample(net, device, dataset, ckpt_path, sample_dir, test_data_dir, scale)
    # print("Exit sample")

In [None]:
# Super-resolve Set14 at x2 with main().
scale = 2
test_data_dir = "/home/desai.ven/CS7180/DIV2K/Testing Dataset/Set14"
main(scale=scale, test_data_dir=test_data_dir)

Scale: 2
PSNR:  28.569910992005752
SSIM:  0.946403392237046
Saved /home/desai.ven/CS7180/DIV2K/Testing Dataset/Results/Set14/x2/SR/img_005_SRF_2_SR.png (180x125 -> 360x250, 0.694s)


## Cropping Results for visualization



In [15]:
import os
from PIL import Image

# Crop the same region from every Urban100 result image for side-by-side comparison.
directory = "/content/Urban100"
out_dir = "/content/Urban100/cropped"  # renamed from `dir`, which shadowed the builtin

# Crop box in pixels: (left, top, right, bottom).
left = 314
top = 186
right = 422
bottom = 260

# Image.save does not create missing folders.
os.makedirs(out_dir, exist_ok=True)

for filename in sorted(os.listdir(directory)):
    if filename.endswith(".png"):
        with Image.open(os.path.join(directory, filename)) as image:
            cropped_image = image.crop((left, top, right, bottom))
            cropped_image.save(os.path.join(out_dir, "cropped_" + filename))

In [16]:
# Bundle the cropped images into one zip archive and download it (requires Google Colab).
directory = "/content/Urban100/cropped"

# IPython shell escape; `$directory` is replaced by the Python variable's value.
!zip -r /content/Urban100/cropped.zip $directory

# Download the zip file to the local machine through google.colab.files.
from google.colab import files
files.download("/content/Urban100/cropped.zip")

  adding: content/Urban100/cropped/ (stored 0%)
  adding: content/Urban100/cropped/cropped_img_034_SRF_4_A+.png (stored 0%)
  adding: content/Urban100/cropped/cropped_img_034_SRF_4_glasner.png (stored 0%)
  adding: content/Urban100/cropped/cropped_img_034_SRF_4_HR.png (stored 0%)
  adding: content/Urban100/cropped/cropped_img_034_SRF_4_ScSR.png (deflated 0%)
  adding: content/Urban100/cropped/cropped_img_034_SRF_4_SelfExSR.png (stored 0%)
  adding: content/Urban100/cropped/cropped_img_034_SRF_4_abhishek.png (deflated 0%)
  adding: content/Urban100/cropped/cropped_img_034_SRF_4_Kim.png (stored 0%)
  adding: content/Urban100/cropped/cropped_img_034_SRF_4_SRCNN.png (deflated 0%)
  adding: content/Urban100/cropped/cropped_img_034_SRF_4_bicubic.png (deflated 0%)
  adding: content/Urban100/cropped/cropped_img_034_SRF_4_Venky_SR.png (stored 0%)
  adding: content/Urban100/cropped/cropped_img_034_SRF_4_nearest.png (deflated 1%)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [12]:
# !rm -r /content/Set14
# !rm -r /content/Urban100

# !mkdir /content/Set14
# !mkdir /content/Set14/cropped

# !mkdir /content/Urban100
# !mkdir /content/Urban100/cropped
