In [None]:
#Using codes from https://github.com/EemeliSaari/dmcnn-vd as reference
from  scipy import ndimage
import os
import scipy
import random
from google.colab.patches import cv2_imshow
from zipfile import ZipFile
import matplotlib.pyplot as plt
import cv2
from  scipy import ndimage
import numpy as np
import requests
import torch
import glob
import torch.utils.data
import imageio
from torchsummary import summary

In [None]:
# Mount Google Drive: every "/content/drive/My Drive/..." path used below
# (training images, checkpoints, WED evaluation images) lives under it.
from google.colab import drive

drive.mount("/content/drive")

In [None]:
def demosaicing_checkboard_bili(img, mode='mirror'):
  """Bilinearly fill in the missing pixels of a checkerboard-sampled image.

  Each channel of ``img`` (H x W x C) is convolved with

      [[0, 1, 0],
       [1, 4, 1],
       [0, 1, 0]] / 4

  For a mosaic where only pixels with (row + col) even are non-zero, this
  leaves sampled pixels unchanged. Each missing pixel becomes the mean of its
  four sampled neighbours.

  Args:
    img: H x W x C array, typically the output of ``mosaicing_checkboard``.
    mode: boundary mode for ``scipy.ndimage.convolve``. The default
      ``'mirror'`` reflects about the edge pixel, so an out-of-range neighbour
      of a missing border pixel maps onto a *sampled* pixel. scipy's default
      ``'reflect'`` duplicates the edge pixel instead, which pulls a zero into
      the average and darkens the border by about 25%.

  Returns:
    H x W x C array with the same dtype as ``img`` (scipy's default output
    dtype).
  """
  kernel = np.array(
          [[0, 1, 0],
           [1, 4, 1],
           [0, 1, 0]]) / 4
  channels = [scipy.ndimage.convolve(img[:, :, c], kernel, mode=mode)
              for c in range(img.shape[2])]
  return np.dstack(channels)

def mosaicing_checkboard(img):
  """Simulate a checkerboard colour filter on an H x W x C image.

  Pixels whose (row + col) is odd are zeroed in every channel; pixels where it
  is even are kept. The image is multiplied by an integer mask, so the result
  has the promoted integer/float dtype of ``img * int``.
  """
  rows, cols = np.indices(img.shape[:2])
  keep = ((rows + cols) % 2 == 0).astype(int)
  return img * keep[:, :, np.newaxis]

In [None]:
class ImagePatchDataset(torch.utils.data.Dataset):
    """Randomly sampled square patches from every image in a directory.

    Each item is a ``(cfa, truth, bilin)`` triple of float arrays scaled by 1/255:
      * ``cfa``   - the checkerboard-mosaicked patch (``mosaicing_checkboard``),
      * ``truth`` - the original RGB patch,
      * ``bilin`` - the bilinear re-interpolation of the mosaic.
    Each is packed from H x W x 3 into ``(3, h, w)`` with ``reshape``.

    NOTE(review): ``reshape`` is not a channel-first transpose. It scrambles
    the spatial layout seen by the convolutions. It is kept because the
    inference code in this notebook uses the same packing and trained
    checkpoints depend on it.

    Args:
      root: directory whose regular files are all readable images.
      loader: callable ``path -> H x W x 3 array``; defaults to an OpenCV RGB loader.
      sample_size: patches drawn per image (without replacement). ``None``, or a
        value larger than the number of candidates, keeps every candidate.
      patch_size: ``(h, w)`` of each patch.
      stride: step between candidate patch corners.
    """

    def __init__(self, root, loader=None, sample_size=None, patch_size=(50, 50), stride=5):
        self.root = root
        self.transform = torch.from_numpy  # not applied; DataLoader's collate converts the arrays
        self.patch_size = patch_size
        self.stride = stride
        # Honour a caller-supplied loader instead of silently ignoring it.
        self.loader = loader if loader is not None else self._numpy_loader

        self.sample_size = sample_size
        # Sorted for a reproducible image order; subdirectories are skipped.
        files = sorted(name for name in os.listdir(root)
                       if os.path.isfile(os.path.join(root, name)))

        self.files_ = [os.path.join(root, name) for name in files]
        self.images_ = [np.array(self.loader(path)) for path in self.files_]
        self.cfa_ = [self._mosaic(img) for img in self.images_]
        self.patches_ = self._compute_patches(self.images_)
        self.bilinears_ = [self._bilin(cfa) for cfa in self.cfa_]

    def __getitem__(self, idx):
        (x, y), img_id = self.patches_[idx]
        b0, b1 = self.patch_size
        # Patch corners are bottom-right (exclusive) corners.
        rows, cols = slice(x - b0, x), slice(y - b1, y)
        packed = (3, b0, b1)
        truth = self.images_[img_id][rows, cols, :].reshape(packed) / 255
        cfa = self.cfa_[img_id][rows, cols].reshape(packed) / 255
        bilin = self.bilinears_[img_id][rows, cols, :].reshape(packed) / 255
        return cfa, truth, bilin

    def __len__(self):
        return len(self.patches_)

    def _compute_patches(self, images):
        """Return ``([row, col], image_index)`` entries for every image.

        Candidate corners lie on a ``self.stride`` grid in ``[b, dim - b)``.
        When ``self.sample_size`` is set, that many are drawn at random (capped
        at the number of candidates).
        """
        b0, b1 = self.patch_size
        patches = []
        for idx, img in enumerate(images):
            M, N = img.shape[:2]
            candidates = [([i, j], idx)
                          for i in range(b0, M - b0, self.stride)
                          for j in range(b1, N - b1, self.stride)]
            if self.sample_size is not None:
                candidates = random.sample(candidates, min(self.sample_size, len(candidates)))
            patches += candidates
        return patches

    def _numpy_loader(self, path):
        """Read ``path`` with OpenCV and return it in RGB channel order."""
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:  # cv2.imread signals failure by returning None
            raise OSError(f"could not read image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _mosaic(self, img):
        """Checkerboard-mosaic ``img`` and store it as uint8."""
        return mosaicing_checkboard(img).astype(np.uint8)

    def _bilin(self, cfa):
        """Bilinear baseline reconstruction of a mosaic."""
        return demosaicing_checkboard_bili(cfa)

In [None]:
class Residaul_CNN(torch.nn.Module):
    """Plain deep CNN that predicts a 3-channel residual image.

    Architecture (all 3x3 convolutions with padding 1, so spatial size is kept):
      * ``layer0``: Conv(3 -> 64) + BatchNorm + SELU
      * ``layer1`` .. ``layer{n_layers-1}``: Conv(64 -> 64) + BatchNorm + SELU,
        each an independent module with its own weights
      * ``residual``: Conv(64 -> 3, no bias) + BatchNorm + SELU
    The caller adds the output to a bilinear estimate.

    Weights use MSRA/He-normal init with std ``sqrt(2 / (k*k*out_channels))``.
    BatchNorm starts at weight 1, bias 0.
    """

    def __init__(self, n_layers=20):
        super().__init__()
        self.n_layers = n_layers
        self.layer0 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.SELU()
        )
        # conv_layer() must be *called*: each hidden stage needs a fresh module.
        # Assigning the bound method would register no parameters, and forward()
        # would raise TypeError. The attribute names keep state_dict keys as
        # "layer{i}.*".
        for i in range(1, self.n_layers):
            setattr(self, f'layer{i}', self.conv_layer())
        self.residual = torch.nn.Sequential(
            torch.nn.Conv2d(64, 3, kernel_size=3, padding=1, bias=False),
            torch.nn.BatchNorm2d(3),
            torch.nn.SELU(inplace=True)
        )
        self.apply(self._msra_init)

    def forward(self, x):
        """Map an (N, 3, H, W) batch to an (N, 3, H, W) residual."""
        out = self.layer0(x)
        for i in range(1, self.n_layers):
            out = getattr(self, f'layer{i}')(out)
        return self.residual(out)

    def conv_layer(self):
        """Build one new 64 -> 64 Conv + BatchNorm + SELU stage."""
        return torch.nn.Sequential(
            torch.nn.Conv2d(64, 64, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.SELU()
        )

    def n_params(self):
        """Number of trainable parameters of *this* model."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _msra_init(self, m):
        """He-normal (fan-out) init for convolutions; identity init for BatchNorm."""
        if isinstance(m, torch.nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, np.sqrt(2./n))
        elif isinstance(m, torch.nn.BatchNorm2d):
            torch.nn.init.constant_(m.weight, 1)
            torch.nn.init.constant_(m.bias, 0)

In [None]:
# 20-layer residual network on the GPU; summary() traces a single 3x50x50 patch
# (the training patch size) through it.
model_vd = Residaul_CNN(n_layers=20).cuda()
summary(model_vd, input_size=(3, 50, 50))
print(model_vd)

In [None]:
# Every image under pristine_images contributes sample_size=100 random patches
# (50x50 by default); batches of 128 are drawn in shuffled order.
dataset = ImagePatchDataset(
    root="/content/drive/My Drive/Colab Notebooks/pristine_images",
    sample_size=100,
)
print(len(dataset))
data_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=128)

In [None]:
import glob
import math
# Per-epoch history filled by the training loop: validation PSNR and training PSNR.
demo_PSNR_mean_list, demo_PSNR_training_mean_list = [], []
def PSNR(original, compressed, max_pixel=255.0):
    """Whole-image peak signal-to-noise ratio in dB.

    The error is computed in float64. Subtracting and squaring in the inputs'
    own dtype overflows: uint8 wraps on subtraction, and int16 wraps once
    ``|diff| > 181`` because 255**2 > 32767.

    Args:
      original, compressed: arrays of the same shape (any numeric dtype).
      max_pixel: peak signal value (255 for 8-bit images).

    Returns:
      ``20*log10(max_pixel / RMSE)``, or 100 when the inputs are identical.

    Note: a later cell redefines ``PSNR`` as a per-channel variant under the same name.
    """
    diff = np.asarray(original, dtype=np.float64) - np.asarray(compressed, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100
    return 20 * math.log10(max_pixel / math.sqrt(mse))

def cal_average(num):
    """Arithmetic mean of a sized sequence, accumulated left to right.

    Returns None when ``num`` is empty.
    """
    if len(num) == 0:
        return None
    total = 0
    for value in num:
        total = total + value
    return total / len(num)

def _scalar_psnr(original, reconstructed, max_pixel=255.0):
  """Whole-image PSNR in dB, computed in float64 (100 for identical inputs).

  Kept private to the validation code so it is unaffected by the per-channel
  ``PSNR`` that a later cell defines under the same global name. It also avoids
  the int16 overflow of squaring differences larger than 181.
  """
  diff = np.asarray(original, dtype=np.float64) - np.asarray(reconstructed, dtype=np.float64)
  mse = np.mean(diff ** 2)
  if mse == 0:
    return 100
  return 20 * math.log10(max_pixel / math.sqrt(mse))


def testing_function(model_vd, images, patch_size=50, batch_size=128):
  """Mean PSNR (dB) of the residual CNN's demosaicing over a list of images.

  For each image:
    1. Convert BGR -> RGB.
    2. Checkerboard-mosaic it and interpolate it bilinearly.
    3. Run the network on every full ``patch_size`` tile and add the predicted
       residual (scaled by 255) to the bilinear tile.
    4. Clip the stitched result to [0, 255].
    5. Score it against the original over the region that excludes the last
       ``patch_size`` rows and columns.

  Args:
    model_vd: the network. Inference runs on the device of its parameters.
    images: images in OpenCV BGR channel order (e.g. from ``cv2.imread``).
      The list itself is not modified.
    patch_size: tile size fed to the network (50 during training).
    batch_size: tiles per forward pass.

  Returns:
    Mean PSNR over the scored images (``cal_average``), or None if none was scored.

  The model is switched to eval mode for the call, so BatchNorm uses its
  running statistics and does not update them. Its previous mode is restored
  afterwards.
  """
  P = patch_size
  device_ = next(model_vd.parameters()).device
  was_training = model_vd.training
  model_vd.eval()
  demo_PSNR = []
  try:
    for bgr_image in images:
      # Convert into a new array: writing it back into ``images`` would flip the
      # caller's channel order on every call (i.e. every epoch).
      raw_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
      mosaic = mosaicing_checkboard(raw_image)
      bilin = demosaicing_checkboard_bili(mosaic).astype(np.int16)
      corners = [(i, j)
                 for i in range(P, mosaic.shape[0], P)
                 for j in range(P, mosaic.shape[1], P)]
      if not corners:
        continue  # image smaller than one tile: nothing to score
      tiles = np.stack([mosaic[i - P:i, j - P:j, :].astype(np.int16) for i, j in corners])
      # NOTE(review): reshape (not transpose) mirrors ImagePatchDataset's packing;
      # the two must stay in sync for a trained checkpoint to be meaningful.
      inputs = torch.from_numpy(tiles.reshape((len(corners), 3, P, P)) / 255.0).float().to(device_)
      with torch.no_grad():
        residuals = torch.cat([model_vd(inputs[k:k + batch_size])
                               for k in range(0, inputs.shape[0], batch_size)])
      residuals = residuals.cpu().numpy().astype(np.float64)

      demosaiced = np.zeros(raw_image.shape, np.int16)
      for k, (i, j) in enumerate(corners):
        tile = residuals[k].reshape(P, P, 3) * 255 + bilin[i - P:i, j - P:j, :]
        # Clip to the valid 8-bit range before the (truncating) int16 store.
        demosaiced[i - P:i, j - P:j, :] = np.clip(tile, 0, 255)

      h, w = raw_image.shape[0] - P, raw_image.shape[1] - P
      demo_PSNR.append(_scalar_psnr(raw_image[:h, :w, :], demosaiced[:h, :w, :]))
  finally:
    model_vd.train(was_training)
  return cal_average(demo_PSNR)

In [None]:
# Training: the network predicts a residual that is added to the bilinear
# estimate; the loss is the MSE between that sum and the ground-truth patch.
criterion = torch.nn.MSELoss()
optimizer_vd = torch.optim.Adam(model_vd.parameters(), lr=1e-5)
n_epochs = 200
device = torch.device('cuda:0')
total_step = len(data_loader)
loss_list = []
demo_best = 30  # a checkpoint is written only when validation PSNR (dB) beats this
# Epoch to start from; set it > 0 when resuming an earlier run.
start_epoch = 0
# Fail fast: the validation images are first needed only after a full epoch.
if "KodakImages" not in globals():
    raise NameError("KodakImages (list of BGR validation images, e.g. from cv2.imread) must be loaded before training")

for epoch in range(start_epoch, n_epochs):
    # Validation switches the model to inference mode; BatchNorm must train here.
    model_vd.train()
    epoch_loss = []
    for idx, (cfa, target, bilin) in enumerate(data_loader):
        cfa = cfa.float().to(device)
        target = target.float().to(device)
        bilin = bilin.float().to(device)
        outputs = model_vd(cfa)

        loss = criterion(outputs + bilin, target)
        epoch_loss.append(loss.item())

        optimizer_vd.zero_grad()
        loss.backward()
        optimizer_vd.step()

        if idx % 200 == 0:
            print(f'Epoch [{epoch}/{n_epochs}], Step [{idx}/{total_step}], Loss: {loss.item()}')
    epoch_stats = np.array(epoch_loss)
    # Targets are scaled to [0, 1], hence a peak value of 1.0.
    psnr_epoch = 20 * math.log10(1.0 / math.sqrt(epoch_stats.mean()))
    print(f'\nFinished Epoch {epoch}, Loss --- mean: {epoch_stats.mean()}, std {epoch_stats.std()}\n')
    loss_list.append(epoch_stats.mean())

    # Last training patch of the epoch. reshape (not permute) undoes the
    # packing done by ImagePatchDataset.
    pred_img = (outputs[-1] + bilin[-1]).detach().cpu().numpy().reshape((50, 50, 3))
    cfa_img = cfa[-1].cpu().numpy().reshape((50, 50, 3))
    target_img = target[-1].cpu().numpy().reshape((50, 50, 3))
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 8))
    ax1.imshow(np.clip(pred_img, 0, 1))
    ax1.set_title('prediction (residual + bilinear)')
    ax2.imshow(np.clip(cfa_img, 0, 1))
    ax2.set_title('checkerboard input')
    ax3.imshow(np.clip(target_img, 0, 1))
    ax3.set_title('ground truth')

    demo_PSNR_mean = testing_function(model_vd, KodakImages)
    print("Current mean PSNR",demo_PSNR_mean)
    print("Current mean PSNR training", psnr_epoch)
    if demo_PSNR_mean > demo_best:
      print("Better!")
      demo_best = demo_PSNR_mean
      torch.save(model_vd.state_dict(), "/content/drive/My Drive/Colab Notebooks/20layers_100sample_50size_checkboard_best.pt")
    print("Best mean PSNR",demo_best)
    demo_PSNR_mean_list.append(demo_PSNR_mean)
    demo_PSNR_training_mean_list.append(psnr_epoch)
    plt.show()



In [None]:
import glob
from sklearn.metrics import mean_absolute_error
from skimage import data, img_as_float, io, color
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error
device = torch.device('cuda:0')
model_path = (
    '/content/drive/My Drive/Colab Notebooks/'
    '20layers_100sample_50size_checkboard_best.pt'
)
# Rebuild the architecture on the GPU and load the best checkpoint saved by training.
model_vdd = Residaul_CNN().cuda()
model_vdd.load_state_dict(torch.load(model_path))
model_vdd.eval()  # inference mode: BatchNorm uses its running statistics
model_vd = model_vdd
#helper functions
def PSNR(original, compressed):
    """Per-channel PSNR in dB for the first three channels of H x W x C images.

    Each channel is compared in float with a peak value of 255. A channel with
    zero error scores 100. Returns a list of three values.
    Note: this replaces the scalar ``PSNR`` defined in an earlier cell.
    """
    def _channel_psnr(ch):
        err = original[:, :, ch].astype(float) - compressed[:, :, ch].astype(float)
        mse = np.mean(err ** 2)
        if mse == 0:
            return 100
        return 20 * math.log10(255.0 / math.sqrt(mse))

    return [_channel_psnr(ch) for ch in range(3)]

def MSE(original, compressed):
    """Per-channel *normalised* MSE for the first three channels.

    For each channel this is ``mean((original - compressed)**2)`` divided by the
    mean intensity of that channel of ``original``. It is therefore not a plain
    MSE. A channel whose mean is 0 yields inf/nan from the division.
    Returns a list of three floats.
    """
    scores = []
    for ch in range(3):
        ref = original[:, :, ch].astype(float)
        est = compressed[:, :, ch].astype(float)
        scores.append(np.mean((ref - est) ** 2) / np.mean(original[:, :, ch]))
    return scores

def SSIM(original, compressed,quant, data_range=None):
  """Structural *dis*similarity at a quantile of the local SSIM map (lower is better).

  Returns ``1 - Q_{1-quant}(ssim_map)``. With the default linear interpolation
  this equals the ``quant`` quantile of the per-pixel ``1 - SSIM`` map; for
  example, quant=0.95 gives the level reached by the worst 5% of pixels.

  Args:
    original: reference H x W x C image (channels on the last axis).
    compressed: reconstruction to score, same shape as ``original``.
    quant: quantile in [0, 1].
    data_range: dynamic range given to skimage. By default it is the full range
      of ``original``'s integer dtype (255 for uint8). For float inputs it falls
      back to ``original.max() - original.min()``. It must not depend on
      ``compressed``: otherwise two reconstructions of the same image would be
      scored on different scales.
  """
  if data_range is None:
    if np.issubdtype(original.dtype, np.integer):
      info = np.iinfo(original.dtype)
      data_range = int(info.max) - int(info.min)
    else:
      data_range = float(original.max() - original.min())
  _, ssim_map = ssim(original, compressed, data_range=data_range, channel_axis = -1, full=True)
  return 1-np.quantile(ssim_map.ravel(), 1-quant)



def delta_E(original, compressed,quant):
  """``quant`` quantile of the per-pixel CIEDE2000 colour difference.

  Both RGB images are converted to CIE Lab with skimage first; channels are on
  the last axis.
  """
  lab_ref, lab_test = (color.rgb2lab(img) for img in (original, compressed))
  per_pixel = color.deltaE_ciede2000(lab_ref, lab_test, channel_axis=-1)
  return np.quantile(per_pixel.ravel(), quant)

In [None]:
import struct
# Evaluate the trained network on the WED images against the bilinear baseline
# using per-channel PSNR / normalised MSE, quantile SSIM and CIEDE2000 delta E.
bili_PSNR_RGB = []
demo_PSNR_RGB = []
demo_MSE_RGB = []
bili_MSE_RGB = []
demo_SSIM_RGB = []
bili_SSIM_RGB = []
demo_deltaE_RGB = []
bili_deltaE_RGB = []
# sorted(): glob order is arbitrary, and the per-image metric lists should be reproducible.
images = [cv2.imread(file) for file in sorted(glob.glob("/content/drive/My Drive/Colab Notebooks/WED/*.bmp"))]


quant = 0.95  # quantile used by SSIM() and delta_E()
PATCH = 50    # tile size the network was trained on
BATCH = 128   # tiles per forward pass
for iii in range(len(images)):
  images[iii] = cv2.cvtColor(images[iii], cv2.COLOR_BGR2RGB)
  raw_image2 = images[iii]
  H, W = raw_image2.shape[:2]
  # Zero-pad one tile on the bottom/right so the tile grid covers the whole image.
  raw_image = np.pad(raw_image2, [(0, PATCH), (0, PATCH), (0, 0)], mode='constant').astype(float)

  mosaic = mosaicing_checkboard(raw_image)
  bilin = demosaicing_checkboard_bili(mosaic)

  patches = [(i, j)
             for i in range(PATCH, mosaic.shape[0], PATCH)
             for j in range(PATCH, mosaic.shape[1], PATCH)]
  image_patches = [mosaic[i-PATCH:i, j-PATCH:j, :].astype(np.int16) for i, j in patches]
  # NOTE(review): reshape (not transpose) mirrors the packing used during training.
  tensor_patches = torch.from_numpy(np.stack(image_patches).reshape((len(image_patches), 3, PATCH, PATCH)) / 255.0).float().to(device)
  results = []
  with torch.no_grad():
      for k in range(0, tensor_patches.shape[0], BATCH):
          results.append(model_vd(tensor_patches[k:k+BATCH, :, :, :]))
  results = torch.cat(results)

  popo = results.cpu().detach().numpy()  # raw network residuals, one (3, PATCH, PATCH) array per tile

  demosaiced = np.zeros(raw_image.shape, np.int16)
  last_layer = np.zeros(raw_image.shape, float)
  origin = np.zeros(raw_image.shape, np.int16)
  for idx, (i, j) in enumerate(patches):
      demo = popo[idx].reshape(PATCH, PATCH, 3).astype(np.float64) * 255
      # Saturate to [0, 255]. Taking |x| (cv2.convertScaleAbs) would turn
      # negative predictions into bright pixels.
      demosaiced[i-PATCH:i, j-PATCH:j, :] = np.clip(demo + bilin[i-PATCH:i, j-PATCH:j, :], 0, 255)
      origin[i-PATCH:i, j-PATCH:j, :] = raw_image[i-PATCH:i, j-PATCH:j, :]
      last_layer[i-PATCH:i, j-PATCH:j, :] = demo

  result = demosaiced[:H, :W, :].astype(np.uint8)  # residual cnn output
  bili = np.clip(bilin[:H, :W, :].astype(np.int16), 0, 255).astype(np.uint8)  # bilinear output
  ground_truth = raw_image[:H, :W, :].astype(np.uint8)  # ground truth (padding cropped off)

  bili_PSNR_RGB.append(PSNR(ground_truth,bili))
  demo_PSNR_RGB.append(PSNR(ground_truth,result))
  bili_MSE_RGB.append(MSE(ground_truth,bili))
  demo_MSE_RGB.append(MSE(ground_truth,result))
  bili_SSIM_RGB.append(SSIM(ground_truth,bili,quant))
  demo_SSIM_RGB.append(SSIM(ground_truth,result,quant))
  bili_deltaE_RGB.append(delta_E(ground_truth,bili,quant))
  demo_deltaE_RGB.append(delta_E(ground_truth,result,quant))

bili_PSNR_RGB_mean = np.mean(bili_PSNR_RGB, axis = 0)
demo_PSNR_RGB_mean = np.mean(demo_PSNR_RGB, axis = 0)
bili_MSE_RGB_mean = np.mean(bili_MSE_RGB, axis = 0)
demo_MSE_RGB_mean = np.mean(demo_MSE_RGB, axis = 0)
bili_SSIM_RGB_mean = np.mean(bili_SSIM_RGB, axis = 0)
demo_SSIM_RGB_mean = np.mean(demo_SSIM_RGB, axis = 0)
bili_deltaE_RGB_mean = np.mean(bili_deltaE_RGB, axis = 0)
demo_deltaE_RGB_mean = np.mean(demo_deltaE_RGB, axis = 0)

bili_PSNR_RGB_std = np.std(bili_PSNR_RGB, axis = 0)
demo_PSNR_RGB_std = np.std(demo_PSNR_RGB, axis = 0)
bili_MSE_RGB_std = np.std(bili_MSE_RGB, axis = 0)
demo_MSE_RGB_std = np.std(demo_MSE_RGB, axis = 0)
bili_SSIM_RGB_std = np.std(bili_SSIM_RGB, axis = 0)
demo_SSIM_RGB_std = np.std(demo_SSIM_RGB, axis = 0)
bili_deltaE_RGB_std = np.std(bili_deltaE_RGB, axis = 0)
demo_deltaE_RGB_std = np.std(demo_deltaE_RGB, axis = 0)

print("Demosaicing")
print("bili_PSNR_RGB_mean:",bili_PSNR_RGB_mean)
print("demo_PSNR_RGB_mean:",demo_PSNR_RGB_mean)
print("bili_MSE_RGB_mean:",bili_MSE_RGB_mean)
print("demo_MSE_RGB_mean:",demo_MSE_RGB_mean)
print("bili_SSIM_RGB_mean:",bili_SSIM_RGB_mean)
print("demo_SSIM_RGB_mean:",demo_SSIM_RGB_mean)
print("bili_deltaE_RGB_mean:",bili_deltaE_RGB_mean)
print("demo_deltaE_RGB_mean:",demo_deltaE_RGB_mean)

print("bili_PSNR_RGB_std:",bili_PSNR_RGB_std)
print("demo_PSNR_RGB_std:",demo_PSNR_RGB_std)
print("bili_MSE_RGB_std:",bili_MSE_RGB_std)
print("demo_MSE_RGB_std:",demo_MSE_RGB_std)
print("bili_SSIM_RGB_std:",bili_SSIM_RGB_std)
print("demo_SSIM_RGB_std:",demo_SSIM_RGB_std)
print("bili_deltaE_RGB_std:",bili_deltaE_RGB_std)
print("demo_deltaE_RGB_std:",demo_deltaE_RGB_std)
