In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np
import PIL
from PIL import Image, ImageFilter

import torch
import torch.nn as nn
import torch.optim

import time

from skimage.measure import compare_psnr

# cuDNN with autotuning: the network input is created once with a fixed size,
# so benchmark mode can pick the fastest convolution algorithms up front.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Tensor type for the model and all tensors; fall back to the CPU so the
# notebook still runs (slowly) on a machine without a GPU.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

sigma = 25  # std of the synthetic Gaussian noise added by process_img

In [None]:
class Concat(nn.Module):
    """Apply two sub-modules to the same input and concatenate their outputs.

    Args:
        dim: dimension along which the two outputs are joined.
        skip: module whose output forms the first chunk (stored as ``layer1``).
        deeper: module whose output forms the second chunk (stored as ``layer2``).
    """

    def __init__(self, dim, skip, deeper):
        super(Concat, self).__init__()
        self.dim = dim
        # Registration order (layer1, then layer2) fixes the state_dict key order.
        self.layer1 = skip
        self.layer2 = deeper

    def forward(self, input):
        # Branches run in order; layer1's result always comes first.
        branches = (self.layer1, self.layer2)
        return torch.cat(tuple(branch(input) for branch in branches), dim=self.dim)

def get_name(name):
    """Advance the one-element counter ``name`` in place and return it as a string.

    Callers share one list across many calls to get sequential, unique
    sub-module names ("1", "2", ...).
    """
    counter = name[0] + 1
    name[0] = counter
    return str(counter)

def skip(c_in, c_out, c_down, c_up, c_skip, k_down, k_up, k_skip, upsample_mode):
    """Build the encoder-decoder ("skip") network as nested ``nn.Sequential`` s.

    Level ``i`` does the following:
      - downsamples with a stride-2 convolution followed by a stride-1 one;
      - recurses into level ``i + 1`` (except on the last level);
      - upsamples by 2;
      - optionally concatenates a ``c_skip[i]``-channel skip branch computed
        from the level's input;
      - maps the merged tensor to ``c_up[i]`` channels.
    A final 1x1 convolution plus sigmoid produces ``c_out`` channels.

    Args:
        c_in: channels of the network input.
        c_out: channels of the network output.
        c_down, c_up, c_skip: per-level channel counts. The loop runs
            ``len(c_down)`` levels and indexes ``c_up``/``c_skip`` by level.
        k_down, k_up, k_skip: kernel sizes of the down, up and skip convolutions.
        upsample_mode: ``mode`` passed to ``nn.Upsample``.

    Returns:
        The assembled ``nn.Sequential`` model.
    """
    counter = [0]       # shared counter -> sequential, unique sub-module names
    root = nn.Sequential()
    container = root    # Sequential receiving the current level's modules
    in_ch = c_in        # channels entering the current level
    depth = len(c_down)

    for lvl in range(depth):
        is_last = lvl == depth - 1
        n_down = c_down[lvl]
        n_skip = c_skip[lvl]

        # Downsampling path: stride-2 conv, then a stride-1 conv.
        down = nn.Sequential()
        down.add_module(get_name(counter), nn.Conv2d(in_ch, n_down, k_down, 2, padding=int((k_down - 1) / 2)))
        down.add_module(get_name(counter), nn.BatchNorm2d(n_down))
        down.add_module(get_name(counter), nn.LeakyReLU(0.2, inplace=True))
        down.add_module(get_name(counter), nn.Conv2d(n_down, n_down, k_down, 1, padding=int((k_down - 1) / 2)))
        down.add_module(get_name(counter), nn.BatchNorm2d(n_down))
        down.add_module(get_name(counter), nn.LeakyReLU(0.2, inplace=True))

        # Empty container for the next level; it is populated on the next
        # iteration (only wired in when a deeper level exists).
        inner = nn.Sequential()
        if not is_last:
            down.add_module(get_name(counter), inner)
        down.add_module(get_name(counter), nn.Upsample(scale_factor=2, mode=upsample_mode))

        if n_skip == 0:
            container.add_module(get_name(counter), down)
        else:
            skip_branch = nn.Sequential(
                nn.Conv2d(in_ch, n_skip, k_skip, 1, padding=int((k_skip - 1) / 2)),
                nn.BatchNorm2d(n_skip),
                nn.LeakyReLU(0.2, inplace=True),
            )
            container.add_module(get_name(counter), Concat(1, skip_branch, down))

        # Channels at the merge point: skip branch plus the deeper path
        # (next level's c_up on inner levels, this level's c_down on the last).
        merged = n_skip + (n_down if is_last else c_up[lvl + 1])
        container.add_module(get_name(counter), nn.BatchNorm2d(merged))
        container.add_module(get_name(counter), nn.Conv2d(merged, c_up[lvl], k_up, 1, padding=int((k_up - 1) / 2)))

        container.add_module(get_name(counter), nn.BatchNorm2d(c_up[lvl]))
        container.add_module(get_name(counter), nn.LeakyReLU(0.2, inplace=True))
        container.add_module(get_name(counter), nn.Conv2d(c_up[lvl], c_up[lvl], 1, 1))
        container.add_module(get_name(counter), nn.BatchNorm2d(c_up[lvl]))
        container.add_module(get_name(counter), nn.LeakyReLU(0.2, inplace=True))

        in_ch = n_down
        container = inner

    root.add_module(get_name(counter), nn.Conv2d(c_up[0], c_out, 1, 1))
    root.add_module(get_name(counter), nn.Sigmoid())
    return root

In [None]:
def _chw_to_hwc(arr):
    """Convert a CHW array to a layout that ``plt.imshow`` / ``Image.fromarray`` accept.

    A single-channel (1, H, W) array becomes (H, W); a multi-channel
    (C, H, W) array becomes (H, W, C).
    """
    return arr[0] if arr.shape[0] == 1 else arr.transpose(1, 2, 0)


def process_img(img_path, f=False, noise_sigma=None):
    """Load an image, center-crop it, and optionally add Gaussian noise.

    Width and height are cropped down to multiples of 32. Every level of
    ``skip`` halves the resolution (stride-2 conv) and later doubles it
    (x2 upsample), and the configurations used here have 5 levels. A
    multiple of 32 therefore maps back to the same size.

    Args:
        img_path: path of the image file to open.
        f: if True, add Gaussian noise to build the noisy observation;
            otherwise the cropped image itself is returned as the "noisy" one.
        noise_sigma: noise std on the 0-255 intensity scale; defaults to the
            module-level ``sigma``.

    Returns:
        (img, img_np, img_noise, img_noise_np):
          - img: the cropped PIL image;
          - img_np: its CHW float array in [0, 1];
          - img_noise: the noisy PIL image;
          - img_noise_np: the noisy CHW float array in [0, 1].
    """
    if noise_sigma is None:
        noise_sigma = sigma

    # Previously opened the undefined global `fname`; also close the file.
    with Image.open(img_path) as src:
        width, height = src.size
        crop_w, crop_h = width - width % 32, height - height % 32
        left, top = (width - crop_w) // 2, (height - crop_h) // 2
        # crop() loads the pixel data, so `img` stays valid after the file closes.
        img = src.crop((left, top, left + crop_w, top + crop_h))

    img_np = np.array(img) / 255
    if img_np.ndim == 3:
        img_np = img_np.transpose(2, 0, 1)  # HWC -> CHW
    else:
        img_np = img_np[None, ...]          # HW -> 1HW

    if f:
        # Add noise on the 0-255 scale and clip BEFORE the uint8 cast, so
        # out-of-range values saturate instead of wrapping around.
        noisy_u8 = np.clip(
            img_np * 255.0 + np.random.normal(scale=noise_sigma, size=img_np.shape),
            0, 255,
        ).astype(np.uint8)
        img_noise = Image.fromarray(_chw_to_hwc(noisy_u8))
        img_noise_np = noisy_u8 / 255
    else:
        img_noise = img
        img_noise_np = img_np

    # cmap only affects single-channel images; RGB is shown as-is.
    plt.figure()
    plt.title("img_noise_np")
    plt.imshow(_chw_to_hwc(img_noise_np), cmap="gray")
    plt.figure()
    plt.title("img_np")
    plt.imshow(_chw_to_hwc(img_np), cmap="gray")

    return img, img_np, img_noise, img_noise_np

In [None]:
# Choose the image to denoise. The training cell selects its network
# configuration from `img_path`, so only this one image is used downstream
# (previously snail.jpg was also loaded, then immediately overwritten).
DATA_DIR = '/home/yujiaq3/deep-image-prior/data/denoising'  # TODO: make configurable

# Whether synthetic noise is added per image:
#   - snail.jpg is used as-is;
#   - F16_GT.png is a clean image that gets synthetic Gaussian noise.
ADD_NOISE = {'snail.jpg': False, 'F16_GT.png': True}
IMAGE_NAME = 'F16_GT.png'

img_path = DATA_DIR + '/' + IMAGE_NAME
img, img_np, img_noise, img_noise_np = process_img(img_path, ADD_NOISE[IMAGE_NAME])

In [None]:
# Optimization hyperparameters.
reg_noise_std = 1. / 30  # set to 1./20. for sigma=50
LR = 0.01

exp_weight = 0.99   # decay intended for an exponential moving average of the output
psrn_noisy_pre = 0  # PSNR vs the noisy image at the last accepted checkpoint
pre_model = None    # CPU copy of the parameters at the last accepted checkpoint

# Network configuration depends on which image was loaded.
if 'snail.jpg' in img_path:
    num_iter = 2400
    input_depth = 3

    model = skip(input_depth, 3,
                 c_down=[8, 16, 32, 64, 128],
                 c_up=[8, 16, 32, 64, 128],
                 c_skip=[0, 0, 0, 4, 4],
                 k_up=3, k_down=3,
                 upsample_mode='bilinear', k_skip=1).type(dtype)

elif 'F16_GT.png' in img_path:
    num_iter = 3000
    input_depth = 32

    model = skip(input_depth, 3,
                 c_down=[128] * 5,
                 c_up=[128] * 5,
                 c_skip=[4] * 5,
                 k_up=3, k_down=3,
                 upsample_mode='bilinear', k_skip=1).type(dtype)

else:
    # Fail fast instead of leaving `model`/`num_iter`/`input_depth` undefined
    # (or stale from an earlier run of this cell).
    raise ValueError('No network configuration for image %r' % img_path)

mse = torch.nn.MSELoss().type(dtype)

# Fixed network input z: uniform values in [0, 0.1), shape (1, input_depth, H, W).
net_input = torch.zeros([1, input_depth, img.size[1], img.size[0]]).type(dtype)
net_input.uniform_().mul_(0.1)
# Two independent copies:
#   - `noise` is refilled in place every iteration;
#   - `net_input_saved` keeps the unperturbed input.
noise, net_input_saved = net_input.detach().clone(), net_input.detach().clone()

# Reference image (img_np, CHW) as a (1, C, H, W) tensor.
img_torch = torch.from_numpy(img_np)[None, :].type(dtype)

params = list(model.parameters())
optimizer = torch.optim.Adam(params, lr=LR)
PSNR = []  # (psnr_noisy, psnr_gt, psnr_gt_sm) recorded every iteration
start = time.time()

# Denoising fits the NOISY observation. The clean image (img_np) is only used
# to measure PSNR; training on it would leak the ground truth into the fit.
img_noisy_torch = torch.from_numpy(img_noise_np)[None, :].type(dtype)
out_avg_t = None  # running exponential average of the output, kept on device

for i in range(num_iter):
    optimizer.zero_grad()

    # Jitter the fixed input code every step (input-noise regularization).
    net_input = net_input_saved + (noise.normal_() * reg_noise_std)
    out = model(net_input)

    # out_avg <- exp_weight * out_avg + (1 - exp_weight) * out
    # (previously out_avg was just a copy of out and exp_weight was unused).
    if out_avg_t is None:
        out_avg_t = out.detach()
    else:
        out_avg_t = out_avg_t * exp_weight + out.detach() * (1 - exp_weight)

    loss = mse(out, img_noisy_torch)
    loss.backward()
    optimizer.step()

    out = out.detach().cpu().numpy()[0]
    out_avg = out_avg_t.cpu().numpy()[0]

    # Images are in [0, 1]; pass data_range explicitly rather than letting it
    # be inferred from mismatched float64/float32 inputs on every call.
    psrn_noisy = compare_psnr(img_noise_np, out, data_range=1.0)
    psrn_gt = compare_psnr(img_np, out, data_range=1.0)
    psrn_gt_sm = compare_psnr(img_np, out_avg, data_range=1.0)
    PSNR.append((psrn_noisy, psrn_gt, psrn_gt_sm))

    if i % 100 == 0:
        print("Op time: %f" % (time.time()-start))
        print ('Iteration %05d    Loss %f   PSNR_noisy: %f   PSRN_gt: %f PSNR_gt_sm: %f' % (i, loss.item(), psrn_noisy, psrn_gt, psrn_gt_sm), '\r', end='')

        for title, arr in (("out", out), ("out_avg", out_avg)):
            arr = np.clip(arr, 0, 1)
            plt.figure()
            plt.title(title)
            # Single-channel CHW -> HW; multi-channel CHW -> HWC for imshow.
            plt.imshow(arr[0] if arr.shape[0] == 1 else arr.transpose(1, 2, 0))
        plt.show()

        if psrn_noisy_pre - psrn_noisy > 5:
            # PSNR vs the noisy target dropped sharply: roll the weights back to
            # the last accepted checkpoint (the old `prev_noise` line raised NameError).
            print('Falling back to previous checkpoint.')
            with torch.no_grad():
                for prev_para, cur_para in zip(pre_model, model.parameters()):
                    cur_para.copy_(prev_para.to(cur_para.device))
        else:
            # copy=True: on a CPU model `.cpu()` alone would alias the live params.
            pre_model = [para.detach().to('cpu', copy=True) for para in model.parameters()]
            psrn_noisy_pre = psrn_noisy

print("Op time: %f" % (time.time()-start))
