Super Resolution

Importing Packages

In [1]:
import sys
import keras
import cv2
import numpy
import matplotlib
import skimage

print('Python: {}'.format(sys.version))
print('Keras: {}'.format(keras.__version__))
print('OpenCV: {}'.format(cv2.__version__))
print('NumPy: {}'.format(numpy.__version__))
print('Matplotlib: {}'.format(matplotlib.__version__))
print('Scikit-Image: {}'.format(skimage.__version__))

Python: 3.7.15 (default, Oct 12 2022, 19:14:55) 
[GCC 7.5.0]
Keras: 2.9.0
OpenCV: 4.6.0
NumPy: 1.21.6
Matplotlib: 3.2.2
Scikit-Image: 0.18.3


In [3]:
from keras.models import Sequential
from keras.layers import Conv2D
from keras.optimizers import Adam
from skimage import measure # s = measure.compare_ssim(imageA, imageB)
#from skimage.measure import compare_ssim as ssim
from matplotlib import pyplot as plt
import cv2
import numpy as np
import math
import os

In [4]:
%matplotlib inline


In [5]:
#unzip file
import zipfile
path_to_zip_file = '/content/video_frames.zip'
with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
    zip_ref.extractall()

Image quality metrics

In [6]:
# function for peak signal-to-noise ratio (PSNR)
def psnr(target, ref):
         
    # assume RGB image
    target_data = target.astype(float)
    ref_data = ref.astype(float)
    print(target_data.shape)
    print(ref_data.shape)
    diff = ref_data - target_data
    
    diff = diff.flatten('C')
    
    rmse = math.sqrt(np.mean(diff ** 2.))

    return 20 * math.log10(255. / rmse)

In [7]:
# function for mean squared error (MSE)
def mse(target, ref):
    # the MSE between the two images is the sum of the squared difference between the two images
    err = np.sum((target.astype('float') - ref.astype('float')) ** 2)
    err /= float(target.shape[0] * target.shape[1])
    
    return err
  

In [8]:
# function that combines all three image quality metrics
def compare_images(target, ref):
    scores = []
    scores.append(psnr(target, ref))
    scores.append(mse(target, ref))
    scores.append(ssim(target, ref, multichannel =True))
    
    return scores

Building the SRCNN Model

In [9]:
  def model():
    
    # define model type
    SRCNN = Sequential()
    
    # add model layers
    SRCNN.add(Conv2D(filters=128, kernel_size = (9, 9), kernel_initializer='glorot_uniform',
                     activation='relu', padding='valid', use_bias=True, input_shape=(None, None, 1)))
    SRCNN.add(Conv2D(filters=64, kernel_size = (3, 3), kernel_initializer='glorot_uniform',
                     activation='relu', padding='same', use_bias=True))
    SRCNN.add(Conv2D(filters=1, kernel_size = (5, 5), kernel_initializer='glorot_uniform',
                     activation='linear', padding='valid', use_bias=True))
    
    # define optimizer
    adam = Adam(lr=0.0003)
    
    # compile model
    SRCNN.compile(optimizer=adam, loss='mean_squared_error', metrics=['mean_squared_error'])
    
    return SRCNN

Deploying the SRCNN

In [10]:
# Image processing functions
def modcrop(img, scale):
    tmpsz = img.shape
    sz = tmpsz[0:2]
    sz = sz - np.mod(sz, scale)
    img = img[0:sz[0], 1:sz[1]]
    return img


def shave(image, border):
    img = image[border: -border, border: -border]
    return img

Define main prediction function

In [11]:
def predict(image_path, frame):
    
    # load the srcnn model with weights
    srcnn = model()
    srcnn.load_weights('/content/3051crop_weight_200.h5')
    
    # load the degraded and reference images
    degraded = cv2.imread(image_path)
    
    # preprocess the image with modcrop
    degraded = modcrop(degraded, 3)
    
    # convert the image to YCrCb - (srcnn trained on Y channel)
    temp = cv2.cvtColor(degraded, cv2.COLOR_BGR2YCrCb)
    
    # create image slice and normalize  
    Y = numpy.zeros((1, temp.shape[0], temp.shape[1], 1), dtype=float)
    Y[0, :, :, 0] = temp[:, :, 0].astype(float) / 255
    
    # perform super-resolution with srcnn
    pre = srcnn.predict(Y, batch_size=1)
    
    # post-process output
    pre *= 255
    pre[pre[:] > 255] = 255
    pre[pre[:] < 0] = 0
    pre = pre.astype(np.uint8)
    
    # copy Y channel back to image and convert to BGR
    temp = shave(temp, 6)
    temp[:, :, 0] = pre[0, :, :, 0]
    output = cv2.cvtColor(temp, cv2.COLOR_YCrCb2BGR)

    #save image
    cv2.imwrite('/content/Super_resolution_videos/'+ frame, output)

    
    # remove border from reference and degraged image
    degraded = shave(degraded.astype(np.uint8), 6)
    
    
    # return images and scores
    return output

#output = predict('/content/video_frames/Frame0.jpg')

In [12]:
import os
import glob

import glob, os
os.chdir("/content/video_frames")

frames = []
for frame in glob.glob("*.jpg"):
    frames.append(frame)


In [15]:
outputs = []
for frame in frames:
  frame_path = "/content/video_frames/" + frame
  print(frame_path)
  output = predict(frame_path, frame)


/content/video_frames/Frame41.jpg




/content/video_frames/Frame58.jpg
/content/video_frames/Frame40.jpg
/content/video_frames/Frame10.jpg
/content/video_frames/Frame107.jpg
/content/video_frames/Frame84.jpg
/content/video_frames/Frame22.jpg
/content/video_frames/Frame3.jpg
/content/video_frames/Frame78.jpg
/content/video_frames/Frame87.jpg
/content/video_frames/Frame29.jpg
/content/video_frames/Frame56.jpg
/content/video_frames/Frame67.jpg
/content/video_frames/Frame16.jpg
/content/video_frames/Frame42.jpg
/content/video_frames/Frame79.jpg
/content/video_frames/Frame21.jpg
/content/video_frames/Frame83.jpg
/content/video_frames/Frame0.jpg
/content/video_frames/Frame106.jpg
/content/video_frames/Frame26.jpg
/content/video_frames/Frame27.jpg
/content/video_frames/Frame64.jpg
/content/video_frames/Frame68.jpg
/content/video_frames/Frame31.jpg
/content/video_frames/Frame9.jpg
/content/video_frames/Frame18.jpg
/content/video_frames/Frame98.jpg
/content/video_frames/Frame73.jpg
/content/video_frames/Frame59.jpg
/content/video_

In [16]:
# Get filename to have the frames ordered before converting to video
name = '/content/Super_resolution_videos/'

filenames = []
for i in range(114):
  filenames.append('/content/Super_resolution_videos/Frame' + str(i) + '.jpg')



In [None]:
#Convert frames to video
import cv2

img_array = []
for filename in filenames:
  img = cv2.imread(filename)
  height, width, layers = img.shape
  size = (width, height)
  img_array.append(img)

out = cv2.VideoWriter('/content/superresolution_video.mp4', cv2.VideoWriter_fourcc(*'DIVX'), 15, size)

for i in range(len(img_array)):
  out.write(img_array[i])
out.release()
