Code for **"Blind restoration of a JPEG-compressed image"** and **"Blind image denoising"** figures. Select `fname` below to switch between the two.

- To observe overfitting, set `num_iter` to a large value.

In [None]:
"""
*Uncomment if running on colab*
Set Runtime -> Change runtime type -> Under Hardware Accelerator select GPU in Google Colab
"""
!git clone https://github.com/DmitryUlyanov/deep-image-prior
!mv deep-image-prior/* ./

# Import libs

In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import numpy as np
from models import *

import torch
import torch.optim

from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from utils.denoising_utils import *

# Let cuDNN benchmark and cache the fastest convolution algorithms.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Tensor type every module/tensor in this notebook is cast to (requires a GPU).
dtype = torch.cuda.FloatTensor

imsize = -1   # forwarded to get_image(); meaning of -1 defined there (presumably "no resize")
PLOT = True   # show intermediate figures
sigma = 25    # noise standard deviation on the 0-255 scale
sigma_ = sigma / 255.  # the same std rescaled by 1/255

In [None]:
# Choose the experiment by leaving exactly one of the two paths active.
# Blind JPEG-artifact removal (the compressed input is the only image available):
# fname = 'data/denoising/snail.jpg'
# Blind denoising (synthetic noise is added to a clean image):
fname = 'data/denoising/F16_GT.png'

# Load image

In [None]:
# Load the image selected by `fname`.
# - deJPEG: there is no clean reference, so the compressed input also serves as
#   the "ground truth" (img_np) for PSNR reporting.
# - denoising: synthetic noise with std sigma_ is added via get_noisy_image().
if fname == 'data/denoising/snail.jpg':
    img_noisy_pil = crop_image(get_image(fname, imsize)[0], d=32)
    img_noisy_np = pil_to_np(img_noisy_pil)

    # As we don't have ground truth
    img_pil = img_noisy_pil
    img_np = img_noisy_np

    if PLOT:
        plot_image_grid([img_np], 4, 5)

elif fname == 'data/denoising/F16_GT.png':
    # Add synthetic noise
    img_pil = crop_image(get_image(fname, imsize)[0], d=32)
    img_np = pil_to_np(img_pil)

    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)

    if PLOT:
        plot_image_grid([img_np, img_noisy_np], 4, 6)
else:
    # `assert False` disappears under `python -O`; fail explicitly instead.
    raise ValueError(f'Unsupported fname: {fname!r}')

# Setup

In [None]:
INPUT = 'noise'  # alternative: 'meshgrid'
pad = 'reflection'
OPT_OVER = 'net'  # alternative: 'net,input'

reg_noise_std = 1. / 30.  # set to 1./20. for sigma=50
LR = 0.01

OPTIMIZER = 'adam'  # alternative: 'LBFGS'
show_every = 100
exp_weight = 0.99  # decay of the output moving average kept by closure()

# Per-experiment iteration budget, input depth, figure size and architecture.
if fname == 'data/denoising/snail.jpg':
    num_iter = 2400
    input_depth = 3
    figsize = 5

    net = skip(
        input_depth, 3,
        num_channels_down=[8, 16, 32, 64, 128],
        num_channels_up=[8, 16, 32, 64, 128],
        num_channels_skip=[0, 0, 0, 4, 4],
        upsample_mode='bilinear',
        need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU')

    net = net.type(dtype)

elif fname == 'data/denoising/F16_GT.png':
    num_iter = 3000
    input_depth = 32
    figsize = 4

    net = get_net(input_depth, 'skip', pad,
                  skip_n33d=128,
                  skip_n33u=128,
                  skip_n11=4,
                  num_scales=5,
                  upsample_mode='bilinear').type(dtype)

else:
    # `assert False` disappears under `python -O`; fail explicitly instead.
    raise ValueError(f'Unsupported fname: {fname!r}')

# Fixed random input code. The spatial size is passed as (height, width);
# PIL's `.size` is (width, height), hence the swap.
net_input = get_noise(input_depth, INPUT, (img_pil.size[1], img_pil.size[0])).type(dtype).detach()

# Compute number of parameters
s = sum(p.numel() for p in net.parameters())
print('Number of params: %d' % s)

# Loss
mse = torch.nn.MSELoss().type(dtype)

# Training target used by closure().
img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)

# Optimize

In [None]:
# Notebook-level state read and updated by closure() through `global`.
net_input_saved = net_input.detach().clone()  # unperturbed input code
noise = net_input.detach().clone()            # buffer refilled with fresh noise every step
out_avg = None                                # moving average of the outputs
last_net = None                               # CPU copy of the last good weights
psrn_noisy_last = 0                           # PSNR (vs noisy image) at that checkpoint

i = 0       # step counter used by closure()
psrn = []   # ground-truth PSNR history appended by closure()

# Rebuild the denoising network so re-running this cell starts from fresh
# weights. Only do this for the denoising config: the deJPEG config builds a
# different skip() architecture in the Setup cell, which an unconditional
# rebuild here would silently replace.
if fname == 'data/denoising/F16_GT.png':
    net = get_net(input_depth, 'skip', pad,
                  skip_n33d=128,
                  skip_n33u=128,
                  skip_n11=4,
                  num_scales=5,
                  upsample_mode='bilinear').type(dtype)
def plot_psnr(psrn, step=1):
    """Plot recorded PSNR values as a line chart.

    Args:
        psrn: sequence of PSNR values (dB), e.g. the ground-truth PSNR that
            closure() appends once every `show_every` iterations.
        step: number of optimisation iterations between consecutive entries of
            `psrn`; x positions are ``k * step``. The default 1 plots against the
            entry index, exactly as before.
    """
    if len(psrn) == 0:
        # Nothing recorded yet; min()/max() below would raise on an empty sequence.
        return
    iterations = [k * step for k in range(len(psrn))]
    plt.figure(figsize=(10, 6))
    plt.plot(iterations, psrn, label='PSNR over iterations', color='blue', linestyle='-', marker='o')
    plt.title('PSNR over Iterations')
    plt.xlabel('Iteration')
    plt.ylabel('PSNR')
    plt.legend(loc='upper left')
    plt.grid(True)
    plt.ylim([min(psrn) - 1, max(psrn) + 1])
    plt.tight_layout()
    plt.show()

# Optimisation step passed to optimize(); called once per iteration.
def closure():
    """Run one Deep Image Prior optimisation step.

    Perturbs the network input, fits the output to ``img_noisy_torch`` with MSE,
    back-propagates, records PSNR, and checkpoints / rolls back the weights.

    Reads and updates notebook-level state through ``global``: ``i`` (step
    counter), ``out_avg``, ``psrn_noisy_last``, ``last_net`` and ``net_input``.
    It also appends to the global list ``psrn``.

    Returns:
        The loss tensor, or ``total_loss * 0`` when the step was rolled back.
    """
    global i, out_avg, psrn_noisy_last, last_net, net_input

    # Regularisation: jitter the fixed input code with fresh noise each step
    # (`noise.normal_()` refills the buffer in place).
    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out = net(net_input)

    # Smoothing: exponential moving average of the outputs (not used by the loss).
    if out_avg is None:
        out_avg = out.detach()
    else:
        out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)

    total_loss = mse(out, img_noisy_torch)
    total_loss.backward()

    # Copy the output to the host once and reuse it for both PSNR values.
    out_host = out.detach().cpu().numpy()[0]
    psrn_noisy = compare_psnr(img_noisy_np, out_host)
    psrn_gt = compare_psnr(img_np, out_host)

    if PLOT and i % show_every == 0:
        out_np = torch_to_np(out)
        plot_image_grid([img_np, np.clip(out_np, 0, 1)], factor=figsize)
        psrn.append(psrn_gt)
        plot_psnr(psrn)

    # Backtracking, on every step not divisible by show_every. If the PSNR against
    # the noisy target dropped by more than 5 dB since the last checkpoint,
    # restore the checkpointed weights; otherwise take a new checkpoint.
    if i % show_every:
        if last_net is not None and psrn_noisy - psrn_noisy_last < -5:
            print('Falling back to previous checkpoint.')
            for new_param, net_param in zip(last_net, net.parameters()):
                net_param.data.copy_(new_param.to(net_param.device))
                # backward() above computed this gradient on the diverged
                # weights. Drop it so an optimizer step right after this call
                # does not push the restored weights away again (optimizers
                # such as Adam skip parameters whose .grad is None).
                net_param.grad = None
            return total_loss * 0
        else:
            last_net = [x.detach().cpu() for x in net.parameters()]
            psrn_noisy_last = psrn_noisy

    i += 1

    return total_loss

## 原始DIP (original DIP baseline)


In [None]:
# Rebuild a noisy copy (sigma = 25 on the 0-255 scale) of the test image and show it.
sigma = 25
sigma_ = sigma / 255
img_pil = get_image(fname, imsize)[0]
img_pil = crop_image(img_pil, d=32)
img_np = pil_to_np(img_pil)
img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)
p = plot_image_grid([img_noisy_np])

In [None]:

sigma = 25
num_iter = 500
# Optimisation: five runs at the same noise level, each on a new noise realisation.
# NOTE: `net` is not re-created here, so every run continues training the same
# network.
normal_dip_plot = []
for run in range(5):
    sigma_ = sigma / 255.
    img_pil = crop_image(get_image(fname, imsize)[0], d=32)
    img_np = pil_to_np(img_pil)
    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)
    # closure() trains against img_noisy_torch. Refresh it so this run actually
    # fits the noisy image generated above instead of a stale one.
    img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)
    # Reset closure()'s bookkeeping so each run counts from step 0 and cannot
    # roll back to a checkpoint taken during a previous run. (The loop variable
    # is not `i`, because closure() uses the global `i` as its step counter.)
    i = 0
    out_avg = None
    last_net = None
    psrn_noisy_last = 0
    if PLOT:
        print("Original image and Current noise image")
        plot_image_grid([img_np, img_noisy_np], 4, 6)
    p = get_params(OPT_OVER, net, net_input)
    optimize(OPTIMIZER, p, closure, LR, num_iter)
    with torch.no_grad():  # inference only; no autograd graph needed
        out_np = torch_to_np(net(net_input))
    normal_dip_plot.append(out_np)

In [None]:
def plot_generate_reults(dip_plot):
    """Show a row of channel-first images side by side, one subplot per image.

    Args:
        dip_plot: sequence of numpy arrays laid out (channels, height, width),
            e.g. restored outputs or noisy inputs. Any number of images is
            accepted; single-channel images are drawn in grayscale.
    """
    count = len(dip_plot)
    if count == 0:
        return
    # squeeze=False keeps `axes` 2-D even for one image, so indexing works for any count.
    # The width of 4 per image reproduces the original 20-wide figure for 5 images.
    fig, axes = plt.subplots(1, count, figsize=(4 * count, 10), squeeze=False)

    for ax, tensor in zip(axes[0], dip_plot):
        # CHW -> HWC, the layout imshow expects.
        image = np.clip(np.transpose(tensor, (1, 2, 0)), 0, 1)
        if image.shape[2] == 1:
            # imshow rejects (H, W, 1); show it as a 2-D grayscale image.
            ax.imshow(image[:, :, 0], cmap='gray', vmin=0, vmax=1)
        else:
            ax.imshow(image)
        ax.axis('off')  # Hide the axes

    plt.show()


# Correctly spelled alias; the original name is kept for existing callers.
plot_generate_results = plot_generate_reults

plot_generate_reults(normal_dip_plot)

## 噪音從多到少 (noise level decreasing: high → low)

In [None]:
# Preview the noisy inputs of the high-to-low schedule:
# sigma = 65, 55, 45, 35, 25 (0-255 scale).
sigma = 75
image = []
for i in range(5):
    sigma -= 10
    sigma_ = sigma / 255.
    img_pil = get_image(fname, imsize)[0]
    img_pil = crop_image(img_pil, d=32)
    img_np = pil_to_np(img_pil)
    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)
    image.append(img_noisy_np)
plot_generate_reults(dip_plot=image)

In [None]:
import matplotlib.pyplot as plt
psrn = []
# Fresh network for this experiment. It is then trained sequentially on every
# noise level below (weights carry over from one level to the next).
net = get_net(input_depth, 'skip', pad,
              skip_n33d=128,
              skip_n33u=128,
              skip_n11=4,
              num_scales=5,
              upsample_mode='bilinear').type(dtype)
# Compute number of parameters
s = sum(p.numel() for p in net.parameters())
print('Number of params: %d' % s)

# Loss
mse = torch.nn.MSELoss().type(dtype)

# Noise schedule high -> low: sigma = 65, 55, 45, 35, 25 (0-255 scale).
sigma = 75
num_iter = 500
mosttoleast_dip_plot = []
for level in range(5):
    sigma -= 10
    sigma_ = sigma / 255.
    img_pil = crop_image(get_image(fname, imsize)[0], d=32)
    img_np = pil_to_np(img_pil)
    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)
    # closure() trains against img_noisy_torch. It must be rebuilt for every
    # level, otherwise all five runs fit the same (stale) noisy image.
    img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)
    # Reset closure()'s bookkeeping so each level counts from step 0 and cannot
    # roll back to a checkpoint taken while fitting the previous level.
    # (The loop variable is not `i`, because closure() uses the global `i`.)
    i = 0
    out_avg = None
    last_net = None
    psrn_noisy_last = 0
    if PLOT:
        print(f"Original image and Current noise image (noise level{level + 1}):")
        plot_image_grid([img_np, img_noisy_np], 4, 6)
    p = get_params(OPT_OVER, net, net_input)
    optimize(OPTIMIZER, p, closure, LR, num_iter)
    with torch.no_grad():  # inference only; no autograd graph needed
        out_np = torch_to_np(net(net_input))
    mosttoleast_dip_plot.append(out_np)

In [None]:
# Restored outputs of the high-to-low noise schedule, side by side.
plot_generate_reults(dip_plot=mosttoleast_dip_plot)

## 噪音從少到多 (noise level increasing: low → high)

In [None]:
import matplotlib.pyplot as plt
psrn = []
# Fresh network for this experiment. It is then trained sequentially on every
# noise level below (weights carry over from one level to the next).
net = get_net(input_depth, 'skip', pad,
              skip_n33d=128,
              skip_n33u=128,
              skip_n11=4,
              num_scales=5,
              upsample_mode='bilinear').type(dtype)
# Compute number of parameters
s = sum(p.numel() for p in net.parameters())
print('Number of params: %d' % s)

# Loss
mse = torch.nn.MSELoss().type(dtype)

# Noise schedule low -> high: sigma = 35, 45, 55, 65, 75 (0-255 scale).
sigma = 25
num_iter = 500
leasttomost_dip_plot = []
for level in range(5):
    sigma += 10
    sigma_ = sigma / 255.
    img_pil = crop_image(get_image(fname, imsize)[0], d=32)
    img_np = pil_to_np(img_pil)
    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)
    # closure() trains against this tensor, so rebuild it for every level.
    img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)
    # Reset closure()'s bookkeeping so each level counts from step 0 and cannot
    # roll back to a checkpoint taken while fitting the previous level.
    # (The loop variable is not `i`, because closure() uses the global `i`.)
    i = 0
    out_avg = None
    last_net = None
    psrn_noisy_last = 0
    if PLOT:
        print(f"Original image and Current noise image (noise level{level + 1}):")
        plot_image_grid([img_np, img_noisy_np], 4, 6)
    p = get_params(OPT_OVER, net, net_input)
    optimize(OPTIMIZER, p, closure, LR, num_iter)
    with torch.no_grad():  # inference only; no autograd graph needed
        out_np = torch_to_np(net(net_input))
    leasttomost_dip_plot.append(out_np)

In [None]:
# Restored outputs of the low-to-high noise schedule, side by side.
plot_generate_reults(dip_plot=leasttomost_dip_plot)

## 從大到小，噪聲級別大 (high → low, with larger noise levels)

In [None]:
import matplotlib.pyplot as plt
psrn = []
# Fresh network for this experiment. It is then trained sequentially on every
# noise level below (weights carry over from one level to the next).
net = get_net(input_depth, 'skip', pad,
              skip_n33d=128,
              skip_n33u=128,
              skip_n11=4,
              num_scales=5,
              upsample_mode='bilinear').type(dtype)
# Compute number of parameters
s = sum(p.numel() for p in net.parameters())
print('Number of params: %d' % s)

# Loss
mse = torch.nn.MSELoss().type(dtype)

# Noise schedule high -> low: sigma = 125, 100, 75, 50, 25 (0-255 scale).
sigma = 150
num_iter = 500
big_mosttoleast_dip_plot = []
for level in range(5):
    sigma -= 25
    sigma_ = sigma / 255.
    img_pil = crop_image(get_image(fname, imsize)[0], d=32)
    img_np = pil_to_np(img_pil)
    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)
    # closure() trains against img_noisy_torch. It must be rebuilt for every
    # level, otherwise all five runs fit the same (stale) noisy image.
    img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)
    # Reset closure()'s bookkeeping so each level counts from step 0 and cannot
    # roll back to a checkpoint taken while fitting the previous level.
    # (The loop variable is not `i`, because closure() uses the global `i`.)
    i = 0
    out_avg = None
    last_net = None
    psrn_noisy_last = 0
    if PLOT:
        print(f"Original image and Current noise image (noise level{level + 1}):")
        plot_image_grid([img_np, img_noisy_np], 4, 6)
    p = get_params(OPT_OVER, net, net_input)
    optimize(OPTIMIZER, p, closure, LR, num_iter)
    with torch.no_grad():  # inference only; no autograd graph needed
        out_np = torch_to_np(net(net_input))
    big_mosttoleast_dip_plot.append(out_np)

In [None]:
# Restored outputs of the large high-to-low noise schedule, side by side.
plot_generate_reults(dip_plot=big_mosttoleast_dip_plot)

## 從小到大，噪聲級別大 (low → high, with larger noise levels)

In [None]:
import matplotlib.pyplot as plt
psrn = []
# Fresh network for this experiment. It is then trained sequentially on every
# noise level below (weights carry over from one level to the next).
net = get_net(input_depth, 'skip', pad,
              skip_n33d=128,
              skip_n33u=128,
              skip_n11=4,
              num_scales=5,
              upsample_mode='bilinear').type(dtype)
# Compute number of parameters
s = sum(p.numel() for p in net.parameters())
print('Number of params: %d' % s)

# Loss
mse = torch.nn.MSELoss().type(dtype)

# Noise schedule low -> high: sigma = 50, 75, 100, 125, 150 (0-255 scale).
sigma = 25
num_iter = 500
big_leasttomost_dip_plot = []
for level in range(5):
    sigma += 25
    sigma_ = sigma / 255.
    img_pil = crop_image(get_image(fname, imsize)[0], d=32)
    img_np = pil_to_np(img_pil)
    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)
    # closure() trains against img_noisy_torch. It must be rebuilt for every
    # level, otherwise all five runs fit the same (stale) noisy image.
    img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)
    # Reset closure()'s bookkeeping so each level counts from step 0 and cannot
    # roll back to a checkpoint taken while fitting the previous level.
    # (The loop variable is not `i`, because closure() uses the global `i`.)
    i = 0
    out_avg = None
    last_net = None
    psrn_noisy_last = 0
    if PLOT:
        print(f"Original image and Current noise image (noise level{level + 1}):")
        plot_image_grid([img_np, img_noisy_np], 4, 6)
    p = get_params(OPT_OVER, net, net_input)
    optimize(OPTIMIZER, p, closure, LR, num_iter)
    with torch.no_grad():  # inference only; no autograd graph needed
        out_np = torch_to_np(net(net_input))
    big_leasttomost_dip_plot.append(out_np)

In [None]:
# Restored outputs of the large low-to-high noise schedule, side by side.
plot_generate_reults(dip_plot=big_leasttomost_dip_plot)