Code for **super-resolution** (figures $1$ and $5$ from the main paper). Change `factor` to $8$ to reproduce the images from fig. $9$ of the supplementary material.

You can play with parameters and see how they affect the result. 

In [None]:
"""
*Uncomment if running on colab* 
Set Runtime -> Change runtime type -> Under Hardware Accelerator select GPU in Google Colab 
"""
# !git clone https://github.com/DmitryUlyanov/deep-image-prior
# !mv deep-image-prior/* ./

# Import libs

In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import argparse
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import numpy as np
from time import time 
from sklearn.feature_extraction.image import extract_patches_2d
from models import *
from torch.utils.data import DataLoader
import torch
import torch.optim

from my_utils import * 
from skimage.metrics import peak_signal_noise_ratio
from models.downsampler import Downsampler
from datetime import datetime

import shutil 
from utils.sr_utils import *
import PIL
# Enable cuDNN and let it benchmark convolution algorithms for the input sizes used.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Tensor type that every tensor/module in this notebook is cast to via `.type(dtype)`.
# Fall back to the CPU type when CUDA is unavailable so the notebook can still run.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

imsize = -1
enforse_div32 = 'CROP'  # we usually need the dimensions to be divisible by a power of two (32 in this case)
PLOT = True  # closure() only draws comparison figures when this is true
im_res = '20m'  # key used to index `imgs_sentinel`; also selects the plot layout in closure()

# Load image and baselines

In [None]:
# Starts here - read all images found under Data_Path. The result is indexed by
# image resolution further down (keys '10m', '20m' and '60m' are used).
Data_Path = '/home/savvas/Thesis/Data/Sentinel-2_Images_Testing/'
imgs_sentinel = get_data2(
    Data_Path,
    imgs=True,
    paths_to_imgs=False,
)

# Define closure and optimize

In [None]:
def closure():
    """Run one optimisation step's forward/backward pass and log progress.

    Meant to be handed to ``optimize``. Everything it works on lives in module
    globals: ``net``, ``net_input``/``net_input_saved``/``noise``,
    ``downsampler``, ``mse``, ``img_LR_var``, ``reg_noise_std``, ``tv_weight``,
    ``better_test_image``, ``im_res``, ``PLOT`` and the iteration counter ``i``
    (incremented on every call).

    Returns:
        tuple: ``(total_loss, out_HR)`` - the loss tensor, after ``backward()``
        has been called on it, and the network output for this step.
    """
    global i, net_input

    # Perturb the saved input with fresh noise; `noise` is refilled in place.
    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out_HR = net(net_input)
    out_LR = downsampler(out_HR)

    # The loss compares the downsampled output with the low-resolution input image.
    total_loss = mse(out_LR, img_LR_var)

    if tv_weight > 0:
        total_loss += tv_weight * tv_loss(out_HR)

    total_loss.backward()

    # Log (only the low-resolution PSNR is reported; the PSNR against a
    # high-resolution reference is disabled).
    psnr_LR = peak_signal_noise_ratio(torch_to_np(img_LR_var), torch_to_np(out_LR))
    print ('Iteration %05d    PSNR_LR %.3f   MSE %.8f' % (i, psnr_LR,total_loss), '\r', end='')

    if PLOT and i % 500 == 0:
        # The reference image is only needed for plotting, so prepare it here
        # rather than on every single iteration.
        tmp = beautify(better_test_image).astype(np.float32)/255.
        out_HR_np = beautify(torch_to_np(out_HR)).astype(np.float32)/255.

        if im_res == '10m':
            # FOR 10m IMAGES: first three channels, reference vs. output.
            ref_hwc = tmp.transpose(1, 2, 0)
            sr_hwc = out_HR_np.transpose(1, 2, 0)
            f, (ax1, ax2) = plt.subplots(ncols=2, figsize=(60, 60))
            ax1.imshow(ref_hwc[:, :, :3])
            ax1.set_title('Original Image ', fontsize=40)
            ax2.imshow(sr_hwc[:, :, :3])
            ax2.set_title('SR Image ', fontsize=40)
            plt.show()

        elif im_res == '20m':
            # FOR 20m IMAGES: channels from index 3 on, then the first three.
            ref_hwc = tmp.transpose(1, 2, 0)
            sr_hwc = out_HR_np.transpose(1, 2, 0)
            f, (ax1, ax2, ax3, ax4) = plt.subplots(ncols=4, figsize=(30, 30))
            ax1.imshow(ref_hwc[:, :, 3:])
            ax1.set_title('Original Image Last 3 Channels', fontsize=20)
            ax2.imshow(sr_hwc[:, :, 3:])
            ax2.set_title('SR Last 3 Channels', fontsize=20)
            ax3.imshow(ref_hwc[:, :, :3])
            ax3.set_title('Original Image First 3 Channels', fontsize=20)
            ax4.imshow(sr_hwc[:, :, :3])
            ax4.set_title('SR First 3 Channels', fontsize=20)
            plt.show()

        else:
            # FOR 60m IMAGES (any other `im_res`): channels 0 and 1 in grayscale.
            f, (ax1, ax2, ax3, ax4) = plt.subplots(ncols=4, figsize=(30, 30))
            ax1.imshow(tmp[0], cmap='gray')
            ax1.set_title('Original Image First Channel', fontsize=20)
            ax2.imshow(out_HR_np[0], cmap='gray')
            ax2.set_title('SR Firt Channel', fontsize=20)
            ax3.imshow(tmp[1], cmap='gray')
            ax3.set_title('Original Image Last Channel', fontsize=20)
            ax4.imshow(out_HR_np[1], cmap='gray')
            ax4.set_title('SR Last Channel', fontsize=20)
            plt.show()

    i += 1

    return total_loss,out_HR

### This part super-resolves a single image.

# Set up parameters and net

In [None]:
# Low-resolution image to super-resolve: entry 2 of the selected resolution.
better_test_image = imgs_sentinel[im_res][2]
input_depth = 32
# The first axis of the image is used as the channel count of the network
# output and of the downsampler.
channels = better_test_image.shape[0]

INPUT = 'noise'
pad = 'reflection'
OPT_OVER = 'net'
KERNEL_TYPE = 'lanczos2'

LR = 0.01
tv_weight = 0.0  # closure() adds the TV term only when this is > 0

OPTIMIZER = 'adam'

factor = 2 # 8

# Upscaling factor -> (num_iter, reg_noise_std).
_single_image_schedule = {
    2: (8000, 0.01),
    4: (8000, 0.03),
    6: (8000, 0.02),
    8: (8000, 0.02),
}
if factor not in _single_image_schedule:
    # Raise explicitly: an `assert` would be stripped under `python -O`,
    # leaving num_iter / reg_noise_std undefined or stale.
    raise ValueError('We did not experiment with other factors')
num_iter, reg_noise_std = _single_image_schedule[factor]

In [None]:
# Network input from get_noise(), sized to the image's spatial dimensions
# multiplied by the upscaling factor.
net_input = get_noise(
    input_depth,
    INPUT,
    (better_test_image.shape[1]*factor, better_test_image.shape[2]*factor),
).type(dtype).detach()

# Architecture passed to get_net(); alternatives tried: 'ResNet', 'UNet'.
NET_TYPE = 'skip'

net = get_net(
    input_depth,
    NET_TYPE,
    pad,
    n_channels=channels,
    skip_n33d=128,
    skip_n33u=128,
    skip_n11=4,
    num_scales=5,
    upsample_mode='bilinear',
).type(dtype)

# Loss used by closure()
mse = torch.nn.MSELoss().type(dtype)

# Low-resolution target that the downsampled network output is compared against.
img_LR_var = np_to_torch(better_test_image).type(dtype)

# Maps the network output back to low resolution inside closure().
downsampler = Downsampler(
    n_planes=channels,
    factor=factor,
    kernel_type=KERNEL_TYPE,
    phase=0.5,
    preserve_size=True,
).type(dtype)

In [None]:
# Iteration counter read and incremented by closure().
i = 0
# Not filled at the moment: the append in closure() is commented out.
psnr_history = []

# Two separate copies of the input: closure() keeps `net_input_saved` as the
# base and refills `noise` in place to perturb it on every step.
net_input_saved = net_input.detach().clone()
noise = net_input.detach().clone()

p = get_params(OPT_OVER, net, net_input)
stop_epoch, best_result = optimize(OPTIMIZER, p, closure, LR, num_iter)

### This part is for super resolving every image in every resolution 

In [None]:
# Hyper-parameters shared by the batch run below.
input_depth = 32
INPUT = 'noise'
pad = 'reflection'
OPT_OVER = 'net'
KERNEL_TYPE = 'lanczos2'
OPTIMIZER = 'adam'

LR = 0.01
tv_weight = 0.0

# Initial image and channel count; the batch loop reassigns both for every image.
better_test_image = imgs_sentinel[im_res][2]
channels = better_test_image.shape[0]
def set_params(imres):
    """Return the training settings used for one image resolution.

    Args:
        imres: Resolution key such as ``'10m'``, ``'20m'`` or ``'60m'``.

    Returns:
        tuple: ``(factor, num_iter, reg_noise_std)``.

    Raises:
        ValueError: If the factor selected for ``imres`` has no entry in the
            schedule table.
    """
    # Upscaling factor per resolution. Every resolution currently uses 6;
    # edit an entry to run a resolution with a different factor.
    factor_by_resolution = {'10m': 6, '20m': 6}
    default_factor = 6

    # factor -> (num_iter, reg_noise_std)
    schedule_by_factor = {
        2: (8000, 0.02),
        4: (8000, 0.03),
        6: (8000, 0.05),
        8: (8000, 0.05),
    }

    factor = factor_by_resolution.get(imres, default_factor)
    if factor not in schedule_by_factor:
        # Fail with a clear message instead of an UnboundLocalError on return.
        raise ValueError(f'No training schedule defined for factor {factor}')
    num_iter, reg_noise_std = schedule_by_factor[factor]

    return factor, num_iter, reg_noise_std

In [None]:
# Super-resolve images of every resolution and save the results to disk.
#
# All training state (net, net_input, noise, img_LR_var, downsampler, i, ...)
# is kept in module globals on purpose: closure() reads it from there.

# NOTE(review): `factor` here is whatever an earlier cell left in the global
# namespace, not the per-resolution value set_params() returns below - confirm
# this is the value create_folder() is meant to receive.
folder_name=create_folder(factor,'/home/savvas/Thesis/Results/DIP/').split('/')[-1]

# Per resolution: upscaling factor -> second argument of get_subimages().
# Resolutions without an entry ('60m') are processed as one whole image.
patch_args = {
    '10m': {2: 2, 4: 3, 6: 5},
    '20m': {2: 1, 4: 2, 6: 3},
}

for im_res in ['10m','20m','60m']:
    factor, num_iter,reg_noise_std = set_params(im_res)
    out_dir = f'/home/savvas/Thesis/Results/DIP/{folder_name}/{im_res}/'

    for num,image in enumerate(imgs_sentinel[im_res]):
        # Create directory to save results (no error if it already exists)
        os.makedirs(out_dir, exist_ok=True)

        print('#'*50,f' Image {num+1} ','#'*50)
        # Get image and its channels
        better_test_image = image
        channels = better_test_image.shape[0]

        # Get patches, or treat the whole image as the only work item
        use_patches = im_res in patch_args
        if use_patches:
            if factor not in patch_args[im_res]:
                # Previously `patches` would be undefined (or stale) here.
                raise ValueError(f'No patch layout defined for {im_res} images at factor {factor}')
            patches = get_subimages(better_test_image, patch_args[im_res][factor])
        else:
            patches = [better_test_image]

        for index, img in enumerate(patches):
            if use_patches:
                print('#'*50,f' Image{num+1} Patch {index+1} ','#'*50)
            # closure() plots `better_test_image`, so point it at the current item
            better_test_image = img

            # Initialize Net
            net_input = get_noise(input_depth, INPUT, (better_test_image.shape[1]*factor, better_test_image.shape[2]*factor)).type(dtype).detach()
            print('Initialize Network')

            NET_TYPE = 'skip'

            net = get_net(input_depth, NET_TYPE, pad,
                          n_channels=channels,
                          skip_n33d=128,
                          skip_n33u=128,
                          skip_n11=4,
                          num_scales=5,
                          upsample_mode='bilinear').type(dtype)

            # Losses
            mse = torch.nn.MSELoss().type(dtype)

            img_LR_var = np_to_torch(better_test_image).type(dtype)

            downsampler = Downsampler(n_planes=channels, factor=factor, kernel_type=KERNEL_TYPE, phase=0.5, preserve_size=True).type(dtype)

            # Train
            start_time = time()
            print('Starting Trainning')
            net_input_saved = net_input.detach().clone()
            noise = net_input.detach().clone()

            i = 0
            p = get_params(OPT_OVER, net, net_input)
            stop_epoch,best_result = optimize(OPTIMIZER, p, closure, LR, num_iter)

            print('Finished Trainning...')
            print(f'Training took {(time()-start_time)/60} minutes')

            SR = torch_to_np(best_result)
            if use_patches:
                save_image(SR,SR.shape[0],f'{out_dir}image_{num+1}_patch_{index+1}_x{factor}_{stop_epoch}_epochs')
                print('Patch Saved')
            else:
                save_image(SR,SR.shape[0],f'{out_dir}image_{num+1}_x{factor}_{stop_epoch}_epochs')
                print('Image Saved')

        # Stitch the saved patches back into one image. For 20m images this is
        # skipped at factor 2 (presumably because there is a single patch then
        # - verify against get_subimages).
        if use_patches and (im_res == '10m' or factor > 2):
            full_image = reconstruct_image(path_to_patches=out_dir,image_num=num+1)
            save_image(full_image,full_image.shape[0],f'{out_dir}image_{num+1}_x{factor}')

        # NOTE(review): only the first image of each resolution is processed
        # because of this break; remove it to process every image.
        break