In [83]:
import argparse

import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image
import os

from models import SRCNN, SRCNN_video
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr
from torchvision.transforms.functional import to_tensor



In [75]:
#### Code for running the test on a full video (all frames). For the original single-frame SRCNN only -- NOT for SRCNN_video.
def run_full_video_tests_SRCNN(model_,test_set_path,path_to_weights,path_to_outputs,target_scale=4):
    """Compare a single-frame SRCNN model with plain bicubic upscaling on every frame of a clip.

    Each frame is degraded (bicubic down-sampling then up-sampling by
    ``target_scale``), its luma channel is passed through the network, and the
    result is scored against the ground-truth frame: PSNR on the luma channel
    (via ``calc_psnr``) and SSIM on the RGB image (via ``ssim``). The same two
    metrics are computed for the bicubic baseline. The per-frame SRCNN PSNR and
    the clip averages are printed; nothing is returned.

    Args:
        model_: Callable invoked with no arguments to build the network (e.g. ``SRCNN``).
        test_set_path: Directory with the ground-truth frames. Every directory
            entry is opened as an image.
        path_to_weights: Checkpoint containing the parameters to copy into the model.
        path_to_outputs: Filename prefix for the saved ``_original.png``,
            ``_bicubic_x{scale}.png`` and ``_srcnn_x{scale}.png`` images. The same
            three files are rewritten for every frame, so only the last frame's
            images remain on disk.
        target_scale: Integer down/up-sampling factor.

    Raises:
        KeyError: If the checkpoint contains a parameter name the model does not have.
        ZeroDivisionError: If ``test_set_path`` contains no entries (the averages
            are taken over an empty list).
    """
    # Prefer CUDA, then Apple-silicon MPS, and only then the CPU, so the function
    # also runs on machines that have neither accelerator.
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')

    cudnn.benchmark = True

    # Build the network and load its weights once: neither depends on the frame.
    model = model_().to(device)
    state_dict = model.state_dict()
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    for n, p in torch.load(path_to_weights, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)
    model.eval()

    bicubic_avg_ssim = []
    bicubic_avg_psnr = []
    srcnn_avg_psnr = []
    srcnn_avg_ssim = []

    # os.listdir returns entries in arbitrary order; sort so that the frame
    # indices printed below are deterministic.
    for idx, image_name in enumerate(sorted(os.listdir(test_set_path))):
        path_to_image = os.path.join(test_set_path, image_name)

        image_org = pil_image.open(path_to_image).convert('RGB')
        image_org.save(path_to_outputs + '_original.png')

        # Largest size that is an exact multiple of the scale factor.
        image_width = (image_org.width // target_scale) * target_scale
        image_height = (image_org.height // target_scale) * target_scale

        # The ground truth must live on the same scale-aligned grid as the
        # degraded input, otherwise the metrics below compare arrays of
        # different shapes whenever the frame size is not a multiple of the scale.
        image_gt = image_org.resize((image_width, image_height), resample=pil_image.BICUBIC)

        # Simulate the low-resolution input: down-sample, then up-sample back.
        image_lr = image_gt.resize((image_width // target_scale, image_height // target_scale), resample=pil_image.BICUBIC)
        image_bic = image_lr.resize((image_lr.width * target_scale, image_lr.height * target_scale), resample=pil_image.BICUBIC)
        image_bic.save(path_to_outputs + '_bicubic_x{}.png'.format(target_scale))

        rgb_org = np.array(image_gt).astype(np.float32)
        rgb_bic = np.array(image_bic).astype(np.float32)
        ycbcr_org = convert_rgb_to_ycbcr(rgb_org)
        ycbcr_bic = convert_rgb_to_ycbcr(rgb_bic)

        # Luma channels scaled to [0, 1] and shaped (1, 1, H, W). The division
        # creates new arrays, leaving the YCbCr arrays untouched.
        y_org = torch.from_numpy(ycbcr_org[..., 0] / 255.).to(device)
        y_org = y_org.unsqueeze(0).unsqueeze(0)
        y_bic = torch.from_numpy(ycbcr_bic[..., 0] / 255.).to(device)
        y_bic = y_bic.unsqueeze(0).unsqueeze(0)

        with torch.no_grad():
            preds = model(y_bic).clamp(0.0, 1.0)

        psnr_srcnn = calc_psnr(y_org, preds)
        psnr_bicubic = calc_psnr(y_org, y_bic)
        srcnn_avg_psnr.append(psnr_srcnn)
        bicubic_avg_psnr.append(psnr_bicubic)

        print(f"Frame idx: {idx} ---  SRCNN PSNR: {psnr_srcnn}")

        # Rebuild an RGB image from the predicted luma and the bicubic chroma.
        preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)
        output = np.array([preds, ycbcr_bic[..., 1],  ycbcr_bic[..., 2]]).transpose([1, 2, 0])
        output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)

        srcnn_avg_ssim.append(ssim(rgb_org, output))
        bicubic_avg_ssim.append(ssim(rgb_org, rgb_bic))

        pil_image.fromarray(output).save(path_to_outputs + '_srcnn_x{}.png'.format(target_scale))

    print(f"Average SRCNN PSNR: {sum(srcnn_avg_psnr)/len(srcnn_avg_psnr)}")
    print(f"Average Bicubic PSNR: {sum(bicubic_avg_psnr)/len(bicubic_avg_psnr)}")
    srcnn_just_item = [t.item() for t in srcnn_avg_psnr]
    bicubic_just_item = [t.item() for t in bicubic_avg_psnr]
    print(f"Full List SRCNN:  {repr(srcnn_just_item)}")
    print(f"Full List Bicubic:  {repr(bicubic_just_item)}")
    print(f"Average SRCNN SSIM: {sum(srcnn_avg_ssim)/len(srcnn_avg_ssim)}")
    print(f"Average Bicubic SSIM: {sum(bicubic_avg_ssim)/len(bicubic_avg_ssim)}")





# Evaluate the pretrained x4 SRCNN weights on one test clip.
# Other available clips:
#   './videos/test_set/veni3_011/truth/'
#   './videos/test_set/land9_007/truth/'
path_to_weights = 'pretrained/srcnn_x4.pth'
test_set_path = './videos/test_set/AMVTG_004/truth/'
path_to_outputs = 'outputs/test_video_output/AMVTG_004'
run_full_video_tests_SRCNN(
    model_=SRCNN,
    test_set_path=test_set_path,
    path_to_weights=path_to_weights,
    path_to_outputs=path_to_outputs,
    target_scale=4,
)

Frame idx: 0 ---  SRCNN PSNR: 24.526451110839844
Frame idx: 1 ---  SRCNN PSNR: 24.506916046142578
Frame idx: 2 ---  SRCNN PSNR: 24.509681701660156
Frame idx: 3 ---  SRCNN PSNR: 24.52451515197754
Frame idx: 4 ---  SRCNN PSNR: 24.44962501525879
Frame idx: 5 ---  SRCNN PSNR: 24.475175857543945
Frame idx: 6 ---  SRCNN PSNR: 24.528913497924805
Frame idx: 7 ---  SRCNN PSNR: 24.521081924438477
Frame idx: 8 ---  SRCNN PSNR: 24.525861740112305
Frame idx: 9 ---  SRCNN PSNR: 24.456911087036133
Frame idx: 10 ---  SRCNN PSNR: 24.495269775390625
Frame idx: 11 ---  SRCNN PSNR: 24.521902084350586
Frame idx: 12 ---  SRCNN PSNR: 24.46906280517578
Frame idx: 13 ---  SRCNN PSNR: 24.52183723449707
Frame idx: 14 ---  SRCNN PSNR: 24.475507736206055
Frame idx: 15 ---  SRCNN PSNR: 24.530380249023438
Frame idx: 16 ---  SRCNN PSNR: 24.455615997314453
Frame idx: 17 ---  SRCNN PSNR: 24.52719497680664
Frame idx: 18 ---  SRCNN PSNR: 24.468820571899414
Frame idx: 19 ---  SRCNN PSNR: 24.538400650024414
Frame idx: 20 -

In [66]:
# HELPERS CODE!

import cv2
import numpy as np
# from skimage.metrics import structural_similarity
from skimage.metrics import structural_similarity as compare_ssim
# import cv2



#Manual SSIM calculation
def ssim(img1, img2):
    """Return the mean structural similarity (SSIM) of two three-channel images.

    SSIM is computed independently on each of the first three channels with an
    11x11 Gaussian window (sigma 1.5) and the three scores are averaged. The
    stabilising constants assume pixel values on a 0-255 scale.

    Interpretation: values lie between -1 and 1. A value of 1 means the images
    are identical, values near 1 mean little perceptual difference, 0 means no
    structural correlation, and negative values indicate strongly differing
    structure (for example inverted colours).

    Args:
        img1: Array indexed as ``[row, col, channel]`` with at least three channels.
        img2: Array of the same layout as ``img1``.

    Returns:
        The average of the three per-channel SSIM scores.
    """
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    gauss_2d = np.outer(gauss_1d, gauss_1d.transpose())

    def _smooth_valid(plane):
        # Gaussian-weighted local average, trimmed by the window half-width (5 px)
        # so that only positions with full window support are kept.
        return cv2.filter2D(plane, -1, gauss_2d)[5:-5, 5:-5]

    def _plane_ssim(plane_a, plane_b):
        # Local means
        mean_a = _smooth_valid(plane_a)
        mean_b = _smooth_valid(plane_b)
        mean_a_sq = mean_a**2
        mean_b_sq = mean_b**2
        mean_ab = mean_a * mean_b

        # Local variances and covariance
        var_a = _smooth_valid(plane_a**2) - mean_a_sq
        var_b = _smooth_valid(plane_b**2) - mean_b_sq
        cov_ab = _smooth_valid(plane_a * plane_b) - mean_ab

        score_map = ((2 * mean_ab + C1) * (2 * cov_ab + C2)) / ((mean_a_sq + mean_b_sq + C1) * (var_a + var_b + C2))
        return score_map.mean()

    # Work in double precision regardless of the input dtype.
    first = img1.astype(np.float64)
    second = img2.astype(np.float64)

    channel_scores = [
        _plane_ssim(first[:, :, channel], second[:, :, channel])
        for channel in range(3)
    ]
    return np.mean(channel_scores)

# Example usage
# img1 = cv2.imread('path_to_hr_image.jpg', cv2.IMREAD_COLOR)
# img2 = cv2.imread('path_to_sr_image.jpg', cv2.IMREAD_COLOR)
# print("SSIM:", ssim(img1, img2))


# SSIM using skimage.metrics
def compute_ssim(img1, img2):
    """Return the SSIM of two colour images, measured on their grayscale versions.

    Both inputs are converted with ``cv2.COLOR_BGR2GRAY`` (i.e. they are treated
    as BGR, the channel order OpenCV uses when loading images) and compared with
    ``skimage.metrics.structural_similarity``.

    Args:
        img1: First colour image.
        img2: Second colour image.

    Returns:
        The mean SSIM value; the full SSIM map is computed but discarded.
    """
    gray_first, gray_second = (
        cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) for image in (img1, img2)
    )
    score, _ssim_map = compare_ssim(gray_first, gray_second, full=True)
    return score




In [67]:
import cv2
import numpy as np
import matplotlib.pyplot as plt


def bicubic_super_resolution(lr_image_path, target_dimensions):
    """Upscale a low-resolution image file to ``target_dimensions`` with bicubic interpolation.

    Args:
        lr_image_path: Path of the low-resolution image on disk.
        target_dimensions: Output size passed to ``cv2.resize`` as ``(width, height)``.

    Returns:
        The resized image as loaded by OpenCV (colour, ``cv2.IMREAD_COLOR``).

    Raises:
        FileNotFoundError: If the image cannot be read from ``lr_image_path``.
    """
    lr_image = cv2.imread(lr_image_path, cv2.IMREAD_COLOR)
    # cv2.imread does not raise on a missing or unreadable file; it returns None,
    # which would otherwise surface as an opaque error inside cv2.resize.
    if lr_image is None:
        raise FileNotFoundError(f"Could not read image: {lr_image_path}")

    # Basic bicubic interpolation
    return cv2.resize(lr_image, target_dimensions, interpolation=cv2.INTER_CUBIC)


# From Paper, PSNR calculation
# def calc_psnr(img1, img2):
#     return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))


def calculate_psnr(hr_image, sr_image):
    """Return the peak signal-to-noise ratio (dB) between two images on a 0-255 scale.

    Args:
        hr_image: Reference (high-resolution) image array.
        sr_image: Reconstructed image array with the same shape as ``hr_image``.

    Returns:
        ``20 * log10(255 / sqrt(MSE))``, or ``float('inf')`` when the images are identical.

    Raises:
        ValueError: If the two images do not have the same shape.
    """
    # Ensure the images are of the same dimension
    if hr_image.shape != sr_image.shape:
        raise ValueError("HR and SR images must have the same dimensions for PSNR calculation.")

    # Promote to float64 before subtracting: with uint8 inputs (what cv2.imread
    # returns) the difference and its square would wrap around modulo 256 and
    # produce a wrong error value.
    hr = np.asarray(hr_image, dtype=np.float64)
    sr = np.asarray(sr_image, dtype=np.float64)

    # Mean Squared Error (MSE) between the HR and SR images
    mse = np.mean((hr - sr) ** 2)

    if mse == 0:
        return float('inf')  # Infinite PSNR means no error

    max_pixel = 255.0
    return 20 * np.log10(max_pixel / np.sqrt(mse))

# # Example usage
# # hr_image_path = 'path_to_high_res_image.jpg'
# # lr_image_path = 'path_to_low_res_image.jpg'
# lr_image_path = 'Datasets/Set14/Set14_LR_x4/comic.png'
# hr_image_path = 'Datasets/Set14/Set14_HR/comic.png'
# hr_image = cv2.imread(hr_image_path, cv2.IMREAD_COLOR)  # Read HR image to get dimensions

# # Generate the super-resolved image using the dimensions of the HR image
# lr_image = cv2.imread(lr_image_path, cv2.IMREAD_COLOR)
# sr_image = bicubic_super_resolution(lr_image_path, (hr_image.shape[1], hr_image.shape[0]))

# # Calculate PSNR
# psnr_value = calculate_psnr(hr_image, sr_image)

# im_to_plot = [lr_image, hr_image, sr_image]

# #Printing LR, HR, and SR image shapes.
# print(f'LR Image shape: {lr_image.shape}')
# print(f'HR Image shape: {hr_image.shape}')
# print(f'SR Image shape: {sr_image.shape}')



# plt.figure(figsize=(15, 5))

# for count,im in enumerate(im_to_plot):
#     # convert bgr to rgb 
    
#     plt.subplot(1,3,count+1)
#     if count ==0:
#         plt.title('LR Image')   
#     elif count ==1:
#         plt.title('HR Image')
#     else:
#         plt.title('Super Resolution (Bicubic)')
#     plt.axis('off')
    
#     rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
#     plt.imshow(rgb, cmap = plt.cm.Spectral)



# print("PSNR value:", psnr_value)

# print("SSIM value:", ssim(hr_image, sr_image))
# print("SSIM value using skimage.metrics:", compute_ssim(hr_image, sr_image))


In [82]:
# Just for SSIM, Delete later
# Code for SRCNN_video model 
def run_full_video_tests_SRCNN_video(model_,test_set_path,path_to_weights,path_to_outputs,target_scale=4):
    """Compare the two-frame SRCNN_video model with plain bicubic upscaling on every frame of a clip.

    Frames are processed in sorted filename order. For each frame, the current
    and the previous frame (the first frame is paired with itself) are degraded
    by bicubic down-sampling then up-sampling by ``target_scale``; their luma
    channels are stacked as a two-channel input (current first, previous second)
    and passed through the network. The prediction and the bicubic baseline are
    scored against the ground-truth frame: PSNR on the luma channel (via
    ``calc_psnr``) and SSIM on the RGB image (via ``ssim``). The clip averages
    and the full per-frame PSNR lists are printed; nothing is returned.

    Args:
        model_: Callable invoked with no arguments to build the network (e.g. ``SRCNN_video``).
        test_set_path: Directory with the ground-truth frames. Every directory
            entry is opened as an image.
        path_to_weights: Checkpoint containing the parameters to copy into the model.
        path_to_outputs: Filename prefix for the saved ``_original.png``,
            ``_bicubic_x{scale}.png`` and ``_srcnn_x{scale}.png`` images. The same
            three files are rewritten for every frame, so only the last frame's
            images remain on disk.
        target_scale: Integer down/up-sampling factor.

    Raises:
        KeyError: If the checkpoint contains a parameter name the model does not have.
        ZeroDivisionError: If ``test_set_path`` contains no entries (the averages
            are taken over an empty list).
    """
    # Prefer CUDA, then Apple-silicon MPS, and only then the CPU, so the function
    # also runs on machines that have neither accelerator.
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')

    cudnn.benchmark = True

    model = model_().to(device)
    state_dict = model.state_dict()
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    for n, p in torch.load(path_to_weights, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)
    model.eval()

    bicubic_avg_ssim = []
    bicubic_avg_psnr = []
    srcnn_avg_psnr = []
    srcnn_avg_ssim = []

    def degrade(image, size):
        # Bicubic down-sample by the scale factor, then up-sample back to `size`
        # (a scale-aligned (width, height)), simulating the low-resolution input.
        aligned = image.resize(size, resample=pil_image.BICUBIC)
        low_res = aligned.resize((aligned.width // target_scale, aligned.height // target_scale), resample=pil_image.BICUBIC)
        return low_res.resize((low_res.width * target_scale, low_res.height * target_scale), resample=pil_image.BICUBIC)

    def luma_tensor(ycbcr):
        # Y channel scaled to [0, 1] and shaped (1, 1, H, W). The division
        # creates a new array, leaving the YCbCr array untouched.
        return torch.from_numpy(ycbcr[..., 0] / 255.).to(device).unsqueeze(0).unsqueeze(0)

    images = sorted(os.listdir(test_set_path))
    last_image = None  # RGB frame from the previous iteration

    for image_name in images:
        current_image = pil_image.open(os.path.join(test_set_path, image_name)).convert('RGB')
        # Keep the previous frame in memory instead of re-reading it from disk;
        # the first frame is paired with itself.
        prev_image = current_image if last_image is None else last_image
        last_image = current_image

        current_image.save(path_to_outputs + '_original.png')

        # Largest size that is an exact multiple of the scale factor.
        width, height = current_image.size
        width = (width // target_scale) * target_scale
        height = (height // target_scale) * target_scale

        # The ground truth must live on the same scale-aligned grid as the
        # degraded input, otherwise the metrics below compare arrays of
        # different shapes whenever the frame size is not a multiple of the scale.
        current_image_gt = current_image.resize((width, height), resample=pil_image.BICUBIC)

        current_image_bic = degrade(current_image, (width, height))
        current_image_bic.save(path_to_outputs + '_bicubic_x{}.png'.format(target_scale))
        prev_image_bic = degrade(prev_image, (width, height))

        rgb_current = np.array(current_image_gt).astype(np.float32)
        rgb_current_bic = np.array(current_image_bic).astype(np.float32)
        rgb_prev_bic = np.array(prev_image_bic).astype(np.float32)

        ycbcr_current = convert_rgb_to_ycbcr(rgb_current)
        ycbcr_current_bic = convert_rgb_to_ycbcr(rgb_current_bic)
        ycbcr_prev_bic = convert_rgb_to_ycbcr(rgb_prev_bic)

        y_org = luma_tensor(ycbcr_current)
        y_bic = luma_tensor(ycbcr_current_bic)
        y_prev_bic = luma_tensor(ycbcr_prev_bic)

        # Channel 0: current frame, channel 1: previous frame.
        input_tensor = torch.cat([y_bic,y_prev_bic], dim=1).to(device)

        with torch.no_grad():
            preds = model(input_tensor).clamp(0.0, 1.0)

        psnr_srcnn = calc_psnr(y_org, preds)
        psnr_bicubic = calc_psnr(y_org, y_bic)
        srcnn_avg_psnr.append(psnr_srcnn)
        bicubic_avg_psnr.append(psnr_bicubic)

        preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

        # Rebuild RGB from the predicted luma and the *bicubic* chroma. Using the
        # ground-truth Cb/Cr here would leak reference data into the output and
        # inflate the SSIM reported for the network.
        output = np.array([preds, ycbcr_current_bic[..., 1],  ycbcr_current_bic[..., 2]]).transpose([1, 2, 0])
        output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)

        srcnn_avg_ssim.append(ssim(rgb_current, output))
        bicubic_avg_ssim.append(ssim(rgb_current, rgb_current_bic))

        pil_image.fromarray(output).save(path_to_outputs + '_srcnn_x{}.png'.format(target_scale))

    print(f"Average SRCNN PSNR: {sum(srcnn_avg_psnr)/len(srcnn_avg_psnr)}")
    print(f"Average Bicubic PSNR: {sum(bicubic_avg_psnr)/len(bicubic_avg_psnr)}")
    srcnn_just_item = [t.item() for t in srcnn_avg_psnr]
    bicubic_just_item = [t.item() for t in bicubic_avg_psnr]
    print(f"Full List SRCNN:  {repr(srcnn_just_item)}")
    print(f"Full List Bicubic:  {repr(bicubic_just_item)}")
    print(f"Average SRCNN SSIM: {sum(srcnn_avg_ssim)/len(srcnn_avg_ssim)}")
    print(f"Average Bicubic SSIM: {sum(bicubic_avg_ssim)/len(bicubic_avg_ssim)}")





# Evaluate the trained SRCNN_video weights on one test clip.
# Other checkpoints that can be evaluated:
#   'outputs/x4/best.pth'
#   'outputs/x4/previous_only_10_epochs_3databases/best.pth'
path_to_weights= 'outputs/x4/SRCNN_video_v2_trained/best.pth'

# Clip to evaluate; other available clips: 'AMVTG_004', 'veni3_011'.
# The output prefix is derived from the clip name so the saved images are
# always labelled with the clip they actually came from.
test_clip = 'land9_007'
test_set_path = f'./videos/test_set/{test_clip}/truth/'
path_to_outputs = f'outputs/test_video_srcnn_custom/{test_clip}'
run_full_video_tests_SRCNN_video(SRCNN_video,test_set_path,path_to_weights,path_to_outputs,target_scale=4)

Average SRCNN PSNR: 32.32741928100586
Average Bicubic PSNR: 31.428936004638672
Full List SRCNN:  [32.276798248291016, 32.33149719238281, 32.325050354003906, 32.23936462402344, 32.22751998901367, 32.26617431640625, 32.28605651855469, 32.286651611328125, 32.30243682861328, 32.295501708984375, 32.29678726196289, 32.25708770751953, 32.212337493896484, 32.20344161987305, 32.236122131347656, 32.278255462646484, 32.2684440612793, 32.288063049316406, 32.33387756347656, 32.327430725097656, 32.38397216796875, 32.393775939941406, 32.37760925292969, 32.38277816772461, 32.38229751586914, 32.38859939575195, 32.390682220458984, 32.42462921142578, 32.4670524597168, 32.511756896972656, 32.50786590576172]
Full List Bicubic:  [31.371917724609375, 31.398056030273438, 31.38946533203125, 31.325408935546875, 31.324148178100586, 31.34177017211914, 31.34881019592285, 31.34765625, 31.364585876464844, 31.36539077758789, 31.376361846923828, 31.35464096069336, 31.324462890625, 31.321537017822266, 31.35280418395996