In [None]:
import matplotlib.pyplot as plt
import numpy as np
from skimage.io import imread
import os

from math import log10, sqrt

DATA_PATH = './CroppedYalePNG'

class DataLoader():
	"""Load grayscale face images of the cropped Extended Yale B dataset.

	Images are read from ``data_path`` by file name
	``yaleB<obj>_P00<light>.png`` with ``obj`` zero-padded to two digits
	(e.g. ``yaleB01_P00A-025E+00.png``) and returned divided by 255.

	Attributes:
		data_path: directory the images are read from.
		objs: subject ids 1..39 except 14.
		lights: sorted list of lighting-condition strings found in data_path.
	"""
	def __init__(self, data_path=None):
		"""Scan ``data_path`` for the available lighting conditions.

		:param data_path: image directory; ``None`` (default) uses the
			module-level ``DATA_PATH``.
		"""
		self.data_path = DATA_PATH if data_path is None else data_path
		self.objs = [i for i in range(1, 40)]
		self.objs.remove(14) # data for object 11-18 is not complete, and no 14

		lights = set()
		for _, _, files in os.walk(self.data_path, topdown=False):
			for name in files:
				# Parse the bare file name, not the joined path: a '_' in the
				# directory name would otherwise shift the split below.
				stem, ext = os.path.splitext(name)
				# Skip ambient-light shots and anything that is not a PNG named
				# like '<subject>_P00<light>.png' (e.g. stray metadata files).
				if ext != '.png' or 'Ambient' in stem or '_' not in stem:
					continue
				lights.add(stem.split('_')[1][3:])
		# Sorted for a reproducible order: iterating a set of strings depends
		# on hash randomisation, which made load_all() (and any slice of its
		# result) differ between interpreter runs.
		self.lights = sorted(lights)

	def load_img(self, obj, light):
		"""Load one image of subject ``obj`` under lighting ``light``.

		:param obj: subject id (int or str); zero-padded to two digits.
		:param light: lighting-condition string, e.g. ``'A-025E+00'``.
		:return: the image divided by 255, or None if the file does not exist.
		"""
		filename = 'yaleB' + str(obj).zfill(2) + '_P00' + light + '.png'
		filename = os.path.join(self.data_path, filename)
		try:
			img = imread(filename)
		except FileNotFoundError:
			return None
		return img / 255

	def load_all(self, skip=None):
		"""Stack every available (subject, light) image.

		:param skip: optional container of ``(obj, light)`` tuples to exclude.
		:return: images stacked along a new leading axis (np.stack raises
			ValueError if no image is found).
		"""
		imgs = []
		for o in self.objs:
			for l in self.lights:
				if skip is not None and (o, l) in skip:
					continue
				img = self.load_img(o, l)
				if img is not None:
					imgs.append(img)
		return np.stack(imgs)

	def load_lights(self, obj):
		"""Stack all available lighting conditions of subject ``obj``."""
		imgs = []
		for l in self.lights:
			img = self.load_img(obj, l)
			if img is not None:
				imgs.append(img)
		return np.stack(imgs)

	def load_obj(self, light):
		"""Stack all available subjects under lighting condition ``light``."""
		imgs = []
		for obj in self.objs:
			img = self.load_img(obj, light)
			if img is not None:
				imgs.append(img)
		return np.stack(imgs)

def add_noise(img, std):
	"""Return ``img`` plus i.i.d. Gaussian noise, clipped to [0, 1], as float32.

	:param img: image with intensities in [0, 1].
	:param std: noise standard deviation in 8-bit units (divided by 255 here).
	"""
	sigma = std / 255.
	noisy = img + np.random.normal(scale=sigma, size=img.shape)
	return np.clip(noisy, 0, 1).astype(np.float32)

def PSNR(original, compressed, max_pixel=1):
	"""Peak signal-to-noise ratio (dB) between two images.

	:param original: reference image.
	:param compressed: image to evaluate, same shape as ``original``.
	:param max_pixel: largest possible pixel value (1 for images scaled to
		[0, 1]; 255 for 8-bit data).
	:return: PSNR in dB, or 100 when the images are identical (MSE == 0).
	"""
	# Compute in float64: subtracting/squaring uint8 arrays wraps around.
	diff = np.asarray(original, dtype=np.float64) - np.asarray(compressed, dtype=np.float64)
	mse = np.mean(diff ** 2)
	if mse == 0:
		return 100
	return 20 * log10(max_pixel / sqrt(mse))

# Markov Random Field

Here we implement a Markov random field (MRF) with Gibbs sampling, where each pixel is resampled from its full conditional $$Y_{ij} \sim P(Y_{ij} \mid \text{others}).$$ In a Markov random field, the Markov blanket of $Y_{ij}$ is $\{Y_{i-1,j}, Y_{i+1,j}, Y_{i,j-1}, Y_{i,j+1}, X_{ij}\}$. Given its Markov blanket, $Y_{ij}$ is independent of all other variables, so the conditional probability is easy to compute.

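In the MRF, $X$ is the observed noisy image and $Y$ is the denoised image. We use the energy function
$$E(X,Y) = \beta \sum_i \sum_j |X_{ij} - Y_{ij}| + \alpha \sum_{i}\sum_{j} \sum_{n \in N(i,j)} |Y_{ij} - Y_n|$$
where $N(i,j)$ is the set of 4-connected neighbors of $Y_{ij}$. Here $\beta$ weights the data term and $\alpha$ weights the smoothness term, matching `MRFGibbs(alpha, beta)`. The probability then becomes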
$$P(X,Y) = \frac{1}{Z} exp(-E(X,Y))$$


In [None]:
import matplotlib.pyplot as plt
import numpy as np
from data_util import DataLoader, add_noise

class MRFGibbs():
	"""Denoise a grayscale image with a pairwise MRF and Gibbs sampling.

	Intensities are quantised to the 256 levels k/255, k = 0..255. The
	energy uses L1 potentials:
		E(X, Y) = beta * sum_ij |Y_ij - X_ij| + alpha * sum_ij sum_{n in N(ij)} |Y_ij - Y_n|
	where X is the observed noisy image, Y the estimate and N(ij) the
	4-connected neighbours of pixel (i, j).

	:param alpha: weight of the smoothness term between neighbouring pixels.
	:param beta: weight of the data term tying Y to the observation X.
	"""
	def __init__(self, alpha, beta):
		self.alpha = alpha
		self.beta = beta

	def compute_conditional(self, X, Y, i, j):
		"""Return P(Y_ij = k/255 | neighbours of (i, j) in Y, X_ij) for k = 0..255."""
		m, n = X.shape
		ys = np.arange(0, 256) / 255
		es = self.beta * np.abs(ys - X[i, j])

		if i - 1 >= 0: es += self.alpha * np.abs(ys - Y[i-1, j])
		if i + 1 <  m: es += self.alpha * np.abs(ys - Y[i+1, j])
		if j - 1 >= 0: es += self.alpha * np.abs(ys - Y[i, j-1])
		if j + 1 <  n: es += self.alpha * np.abs(ys - Y[i, j+1])

		# Softmax of -es, shifted by the minimum energy: for large weights
		# exp(-es) underflows to all zeros, and normalising gave 0/0 = NaN.
		ps = np.exp(-(es - es.min()))
		return ps / ps.sum()

	def gibbs_sampling_one_iter(self, X, Y):
		"""One Gibbs sweep: resample every pixel once from its conditional.

		Rows and columns are visited in independently shuffled orders. ``Y``
		is not modified; the updated copy is returned.
		"""
		m, n = X.shape
		Is = np.arange(0, m); np.random.shuffle(Is)
		Js = np.arange(0, n); np.random.shuffle(Js)
		Y_copy = Y.copy()

		for i in Is:
			for j in Js:
				ps = self.compute_conditional(X, Y_copy, i, j)
				idx = np.random.choice(len(ps), p=ps)
				Y_copy[i, j] = idx / 255
		return Y_copy

	def fit(self, X, num_burn=100, num_sampling=1000):
		"""Run the Gibbs chain on the noisy image ``X`` and return the estimate.

		:param X: noisy 2-D image with values in [0, 1]; also the chain's start.
		:param num_burn: number of sweeps discarded as burn-in.
		:param num_sampling: number of post-burn-in sweeps whose samples are
			averaged.
		:return: the mean of the post-burn-in samples (Monte Carlo
			posterior-mean estimate); the state after burn-in if
			``num_sampling <= 0``.
		"""
		Y = X.copy()
		for _ in range(num_burn):
			Y = self.gibbs_sampling_one_iter(X, Y)

		if num_sampling <= 0:
			return Y

		# Previously every post-burn-in sample but the last was discarded,
		# so num_sampling only moved the chain. Averaging the samples gives
		# a lower-variance estimate than any single sample.
		total = np.zeros(X.shape, dtype=np.float64)
		for _ in range(num_sampling):
			Y = self.gibbs_sampling_one_iter(X, Y)
			total += Y
		mean = total / num_sampling
		if np.issubdtype(X.dtype, np.floating):
			mean = mean.astype(X.dtype)
		return mean

In [None]:
d = DataLoader()
img = d.load_img(1, 'A-025E+00')
# add_noise requires the noise standard deviation (in 0-255 units).
img_noise = add_noise(img, 25)

fig, (ax1, ax2) = plt.subplots(1,2)
ax1.imshow(img, cmap='gray')
ax2.imshow(img_noise, cmap='gray')
plt.show()

# alpha = beta = 20 as discussed below: with small weights the conditional
# distributions are nearly uniform. fit() takes only the noisy image; its
# second positional argument is num_burn, so passing an image there raised.
mrf = MRFGibbs(20, 20)
img_denoise = mrf.fit(img_noise)

The results of the MRF denoiser under different noise levels are shown below.

![image info](./mrf.png)

Here we use $\alpha=20$ and $\beta=20$, which is not equivalent to $\alpha=2$ and $\beta=2$. The weights scale the energy. With $\alpha=\beta=2$ the energy differences between candidate intensities are small, so the probabilities lie in the "flat" region of the exponential function. Sampling is then close to uniform, and the result is very poor.

# Total Variation Minimization

In total variation (TV) minimization, we solve a simple convex optimization problem

$$
	\min_{\hat{X},E} \quad \mathrm{TV}(\hat{X}) + \lambda\|E\|_1\\
	s.t.\quad X = \hat{X} + E
$$
where $X$ is the noisy image, $\hat{X}$ is the denoised image, and the (anisotropic) total variation is
$$
	\mathrm{TV}(X) = \|X[:,1:] - X[:,:-1]\|_1 + \|X[1:,:] - X[:-1,:]\|_1
$$

In [None]:
import cvxpy
import matplotlib.pyplot as plt

from data_util import DataLoader, add_noise

class TVDenoise():
	"""Total-variation denoising solved as a convex program with cvxpy.

	Solves
		min_{X, E}  TV(X) + lamb * ||E||_1   s.t.  img = X + E
	with the anisotropic TV(X) = ||X[1:, :] - X[:-1, :]||_1 + ||X[:, 1:] - X[:, :-1]||_1.
	"""
	def __init__(self, lamb):
		""":param lamb: weight of the L1 residual term; larger values keep X closer to img."""
		self.lamb = lamb

	def fit(self, img):
		"""Denoise a 2-D image.

		:param img: 2-D array of shape [m, n].
		:return: the optimal X as an [m, n] array, or None if cvxpy finds no
			solution (Variable.value stays None in that case).
		"""
		m, n = img.shape
		X = cvxpy.Variable((m, n))
		E = cvxpy.Variable((m, n))
		tv = cvxpy.sum(cvxpy.abs(X[1:, :] - X[:-1, :])) + cvxpy.sum(cvxpy.abs(X[:, 1:] - X[:, :-1]))
		# The residual weight used to be hard-coded to 2 and self.lamb was ignored.
		objective = cvxpy.Minimize(tv + self.lamb * cvxpy.sum(cvxpy.abs(E)))
		problem = cvxpy.Problem(objective, [X + E == img])
		problem.solve()
		return X.value

The results of total variation minimization under different noise levels are shown below.

[<img src="./TV.png" width="500"/>](image.png)

# Low Rank Representation
Here, we flatten all images and stack them as the columns of a data matrix $X$. Since there are many images but only a few subjects, we assume that $X$ can be decomposed into a low-rank matrix plus an error term.


[<img src="./lr.png" width="800"/>](image.png)
Therefore, we can formulate our optimization problem as follows:

$$
	\min_{Z,E}\quad \operatorname{rank}(Z) + \lambda \|E\|_{2,1}\\
	s.t. \quad X=XZ+E
$$
where $X$ represents the data (used as its own dictionary), $XZ$ is the low-rank representation, and $E$ is the error term.
Since minimizing the matrix rank is NP-hard, we approximate it with its convex surrogate, the nuclear norm $\|Z\|_*$ (the sum of the singular values). The optimization problem then becomes
$$
	\min_{Z,E}\quad \|Z\|_{*}+\lambda \|E\|_{2,1}\\
	s.t.\quad \quad X=XZ+E
$$
which we solve with the augmented Lagrange multiplier (ALM) method.

In [None]:
import numpy as np
import scipy.sparse as SP

def solve_l2(w:np.array, alpha):
	"""
	Proximal operator of the Euclidean norm (block soft-thresholding).

	Solve the optimization problem:
		min. alpha * ||x||_2 + 0.5 * ||x-w||_2^2
	The solution shrinks w towards the origin by alpha in norm; it is the
	zero vector whenever ||w||_2 is not larger than alpha.
	"""
	norm_w = np.linalg.norm(w)
	keep = norm_w > alpha
	return (norm_w - alpha) / norm_w * w if keep else np.zeros_like(w)

def solve_l21(W:np.array, alpha):
	"""
	Solve the optimization problem
		min. alpha * ||E||_{2,1} + 0.5 * ||E - W||_F^2
	where ||E||_{2,1} is the sum of the Euclidean norms of the columns.

	The problem separates over columns: each column w of W becomes
	(||w|| - alpha) / ||w|| * w if ||w|| > alpha, and 0 otherwise.
	"""
	W = np.asarray(W)
	# Integer input previously produced an integer E, truncating the result.
	if W.dtype.kind in 'biu':
		W = W.astype(float)
	norms = np.linalg.norm(W, axis=0)
	# The masked divide leaves scale = 0 for columns with norm <= alpha and
	# avoids 0/0 warnings for all-zero columns.
	scale = np.zeros_like(norms)
	np.divide(norms - alpha, norms, out=scale, where=norms > alpha)
	return W * scale

def singular_value_shrink(A, alpha):
	"""Singular value thresholding: the proximal operator of alpha * ||.||_*.

	Returns U diag(max(s - alpha, 0)) V^T for the SVD A = U diag(s) V^T,
	i.e. the solution of
		min. alpha * ||X||_* + 0.5 * ||X - A||_F^2
	"""
	# The thin SVD is sufficient (the extra singular vectors of a full SVD
	# would get zero weight) and avoids building large square U / V^T.
	U, lamb, VT = np.linalg.svd(A, full_matrices=False)
	keep = lamb > alpha
	if not np.any(keep):
		return np.zeros_like(A)
	# Scale the retained left singular vectors column-wise by the shrunk
	# singular values; equivalent to U_k @ diag(s_k - alpha) @ VT_k.
	return (U[:, keep] * (lamb[keep] - alpha)) @ VT[keep, :]

def __solve_low_rank_representation(X:np.array, A:np.array, lamb, max_iter=1000000, tol=1e-8, verbose=True):
	"""
	Solve the nuclear-norm optimization problem
		min. ||Z||_* + lamb * ||E||_{2,1}
		s.t. X = AZ + E
	by solving the equivalent problem
		min. ||J||_* + lamb * ||E||_{2,1}
		s.t. X = AZ + E
			 Z = J
	using Augmented Lagrangian Multiplier

	:param X: Of shape [d, n], d is data dimensions, n is number of data
	:param A: Of shape [d, m], a dictionary
	:param lamb: weight of the ||E||_{2,1} error term
	:param max_iter: upper bound on ALM iterations (previously declared as
		maxIter but never checked, so the loop could run forever)
	:param tol: stop once both equality violations (max abs entry) drop below tol
	:param verbose: print shapes and per-iteration progress
	:return: (Z of shape [m, n], E of shape [d, n], list of the rank of Z per iteration)
	"""
	if verbose:
		print(X.shape, A.shape)

	d, n = X.shape
	_, m = A.shape
	c = 1e-3      # penalty parameter of the augmented Lagrangian
	maxc = 1e10   # upper bound on c
	rho = 1.2     # growth factor applied to c after every iteration

	# Initialize primal variables
	J = np.zeros(shape=[m,n])
	Z = np.zeros(shape=[m,n])
	# NOTE(review): E starts from Gaussian noise rather than zeros, so results
	# depend on the global np.random state.
	E = np.random.randn(d,n)

	# Initialize dual variables
	Y1 = np.zeros(shape=[d,n])
	Y2 = np.zeros(shape=[m,n])

	# Loop-invariant quantities for the Z update
	inv_atapi = np.linalg.inv(A.T @ A + np.eye(m))
	atx = A.T @ X

	ranks = []
	for it in range(1, int(max_iter) + 1):
		# Update primal variable J (singular value thresholding)
		J = singular_value_shrink(Z + Y2 / c, 1/c)

		# Update primal variable Z (closed-form least-squares step)
		Z = inv_atapi @ (atx - A.T@E + J + (A.T @ Y1 - Y2)/c)

		# Update primal variable E (column-wise shrinkage)
		xmaz = X - A @ Z
		E = solve_l21(xmaz + Y1 / c, alpha=lamb/c)

		# Update dual variables Y1 and Y2
		leq1 = xmaz - E # linear equality 1
		leq2 = Z - J    # linear equality 2
		Y1 += c * leq1
		Y2 += c * leq2

		# The rank needs an SVD; compute it once and reuse it for the log line.
		rank = np.linalg.matrix_rank(Z, tol=1e-3*np.linalg.norm(Z, 2))
		ranks.append(rank)
		viol1 = np.max(np.abs(leq1))
		viol2 = np.max(np.abs(leq2))
		if verbose:
			print("Iteration %d, c:%.6f, Rank:%d, Equality 1 violation: %.5f, Equality 2 violation: %.5f"
				  %
				  (it, c, rank, viol1, viol2)
			)
		c = min(c * rho, maxc)

		if max(viol1, viol2) < tol:
			break

	return Z, E, ranks

def solve_low_rank_representation(X, lamb):
	"""
	Low-rank representation of X using X itself as the dictionary.

	:param X: Of shape [d, n], d is data dimensions, n is number of data
	:param lamb: weight of the ||E||_{2,1} error term
	:return: (the low-rank reconstruction X @ Z, the error term E, the list
		of ranks of Z recorded at every ALM iteration)
	"""
	coeffs, err, rank_history = __solve_low_rank_representation(X, X, lamb)
	reconstruction = X @ coeffs
	return reconstruction, err, rank_history

class LowRankDenoise():
	"""Denoise images jointly with a reference set via low-rank representation.

	The noisy images and the reference images are flattened and stacked as
	the columns of one data matrix, which is decomposed as X = XZ + E; the
	columns of XZ that belong to the noisy images are returned.
	"""
	def __init__(self, lamb, data):
		"""
		:param lamb: weight of the ||E||_{2,1} error term.
		:param data: reference images whose first axis indexes images; stored
			flattened to shape [k, -1].
		"""
		self.lamb = lamb
		num_refs = data.shape[0]
		self.data = np.reshape(data, [num_refs, -1])

	def fit(self, imgs):
		"""
		:param imgs: noisy images of shape [m, h, w]
		:return: (denoised images of shape [m, h, w], list of ranks of Z per
			ALM iteration)
		"""
		count, height, width = imgs.shape
		flat = np.reshape(imgs, [count, height * width])
		# Columns: the m noisy images first, followed by the reference images.
		stacked = np.concatenate([flat, self.data]).T
		low_rank, _, rank_history = solve_low_rank_representation(stacked, self.lamb)
		denoised = np.reshape(low_rank.T[:count], [count, height, width])
		return denoised, rank_history

# (subject, lighting) pairs to denoise; they are excluded from the reference set.
objs = [
		(1, 'A-020E+10'),
		(2, 'A-010E+00'),
		(3, 'A-020E-10'),
		(6, 'A+020E+10'),
		(7, 'A+000E+00')
	]
std = 25        # standard deviation of the added Gaussian noise (0-255 units)
lamb = 0.005    # weight of the ||E||_{2,1} error term

d = DataLoader()
data = d.load_all(skip=objs)[:1000]

model = LowRankDenoise(lamb, data)

imgs_origin = []
imgs_noisy = []
for idx, light in objs:
    img = d.load_img(idx, light)
    if img is None:
        raise FileNotFoundError('missing image for subject %d, light %s' % (idx, light))
    img_noise = add_noise(img, std)

    imgs_origin.append(img)
    imgs_noisy.append(img_noise)

imgs_noisy = np.array(imgs_noisy)

# LowRankDenoise.fit returns (denoised images, rank history) -- two values.
imgs_denoise, ranks = model.fit(imgs_noisy)

for i, (idx, light) in enumerate(objs):
    # PSNR of the *denoised* image against the clean original; the title
    # below reports the result of denoising, not of the noisy input.
    result = PSNR(imgs_origin[i], imgs_denoise[i])
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 10))
    ax1.imshow(imgs_origin[i], cmap='gray')
    ax1.set_title('Original')
    ax2.imshow(imgs_noisy[i], cmap='gray')
    # add_noise's argument is the standard deviation sigma, not the variance.
    ax2.set_title(r'Noised with Gaussian $\sigma=%d$' % (std))
    ax3.imshow(imgs_denoise[i], cmap='gray')
    ax3.set_title('After Denoising by Low Rank Representation \n PSNR %.2f dB' % (result))
    # plt.show()
    fig.savefig('./%d_%s_%d.png' % (idx, light, std))
    # Close each figure so the loop does not accumulate open figures.
    plt.close(fig)

The rank of $XZ$ decreases over iterations as shown in the following figure.
[<img src="./ranks.jpg" width="400"/>](image.png)

The result of low-rank representation is shown in the figure below.

[<img src="./lowrank.png" width="400"/>](image.png)

The PSNR score is not very high. However, low-rank representation has an interesting side effect. In the figure below, the left part of each pair is the original image and the right part is its low-rank representation. The method repairs some corruption in the images: in the first image, the shadow near the nose is removed, and in the second image, the left eye is partially recovered. Low-rank representation may therefore be useful for image restoration.

[<img src="./lowrank_se.png" width="400"/>](image.png)

# UNet

Lehtinen, Jaakko, et al. "Noise2Noise: Learning image restoration without clean data." arXiv preprint arXiv:1803.04189 (2018).
We also tried a deep learning method. We trained a U-Net whose inputs are the (already corrupted) images with additional noise and whose targets are the corrupted images without the additional noise. We use 2k images as the training set and 200 images as the validation set. For the input, we randomly add noise to the (corrupted) original image. The results of the U-Net model are shown below. Its PSNR values are higher than those of all previous methods for all noise levels, and the visual quality is also very good.

In [None]:
import torch
import torch.nn as nn


class UNet(nn.Module):
    """Custom U-Net architecture for Noise2Noise (see Appendix, Table 2).

    The encoder has five 2x max-pool stages and the decoder five stride-2
    transposed convolutions, with skip connections concatenated on the
    channel axis (dim=1). The decoder output must match the encoder features
    in torch.cat, so the input height and width should be divisible by 32.

    NOTE(review): ``_block2`` is applied four times in the encoder and
    ``_block5`` three times in the decoder, so those stages share weights,
    although the layer comments name separate layers (enc_conv(i), i=2..5;
    dec_deconv(i), i=4..2). Changing this would change the state_dict
    layout of saved checkpoints.
    """

    def __init__(self, in_channels=3, out_channels=3):
        """Initializes U-Net.

        :param in_channels: channels of the input image (9 is used elsewhere
            for Monte Carlo renders).
        :param out_channels: channels of the output image.
        """

        super(UNet, self).__init__()

        # Layers: enc_conv0, enc_conv1, pool1
        # in_channels -> 48 channels, spatial size halved.
        self._block1 = nn.Sequential(
            nn.Conv2d(in_channels, 48, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(48, 48, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2))

        # Layers: enc_conv(i), pool(i); i=2..5
        # 48 -> 48 channels, spatial size halved; reused for all four stages.
        self._block2 = nn.Sequential(
            nn.Conv2d(48, 48, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2))

        # Layers: enc_conv6, upsample5
        # Transposed conv (k=3, s=2, p=1, output_padding=1) doubles H and W.
        self._block3 = nn.Sequential(
            nn.Conv2d(48, 48, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(48, 48, 3, stride=2, padding=1, output_padding=1))
            #nn.Upsample(scale_factor=2, mode='nearest'))

        # Layers: dec_conv5a, dec_conv5b, upsample4
        # Input: 48 (upsample5) + 48 (pool4 skip) = 96 channels.
        self._block4 = nn.Sequential(
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96, 96, 3, stride=2, padding=1, output_padding=1))
            #nn.Upsample(scale_factor=2, mode='nearest'))

        # Layers: dec_deconv(i)a, dec_deconv(i)b, upsample(i-1); i=4..2
        # Input: 96 (decoder) + 48 (encoder skip) = 144 channels; reused three times.
        self._block5 = nn.Sequential(
            nn.Conv2d(144, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(96, 96, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96, 96, 3, stride=2, padding=1, output_padding=1))
            #nn.Upsample(scale_factor=2, mode='nearest'))

        # Layers: dec_conv1a, dec_conv1b, dec_conv1c,
        # Input: 96 (decoder) + in_channels (the raw input skip); ends with
        # LeakyReLU(0.1) rather than a linear output.
        self._block6 = nn.Sequential(
            nn.Conv2d(96 + in_channels, 64, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, out_channels, 3, stride=1, padding=1),
            nn.LeakyReLU(0.1))

        # Initialize weights
        self._init_weights()


    def _init_weights(self):
        """Initializes weights using He et al. (2015).

        Every Conv2d / ConvTranspose2d gets Kaiming-normal weights and zero bias.
        """

        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data)
                m.bias.data.zero_()


    def forward(self, x):
        """Through encoder, then decoder by adding U-skip connections.

        :param x: batch of images with in_channels channels (NCHW); H and W
            should be divisible by 32 for the skip concatenations to line up.
        :return: batch with out_channels channels and the same H, W as x.
        """
        # Encoder (each stage halves H and W)
        pool1 = self._block1(x)
        pool2 = self._block2(pool1)
        pool3 = self._block2(pool2)
        pool4 = self._block2(pool3)
        pool5 = self._block2(pool4)

        # Decoder (each stage doubles H and W, then concatenates the skip)
        upsample5 = self._block3(pool5)
        concat5 = torch.cat((upsample5, pool4), dim=1)
        upsample4 = self._block4(concat5)
        concat4 = torch.cat((upsample4, pool3), dim=1)
        upsample3 = self._block5(concat4)
        concat3 = torch.cat((upsample3, pool2), dim=1)
        upsample2 = self._block5(concat3)
        concat2 = torch.cat((upsample2, pool1), dim=1)
        upsample1 = self._block5(concat2)
        concat1 = torch.cat((upsample1, x), dim=1)

        # Final activation
        return self._block6(concat1)

In [None]:
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as tvF

import os
import numpy as np
from math import log10
from datetime import datetime
import OpenEXR
from PIL import Image
import Imath

from matplotlib import rcParams
rcParams['font.family'] = 'serif'
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator


def clear_line():
    """Overwrite the current terminal line with 80 blanks and return the cursor to its start."""

    blank = ' ' * 80
    print('\r' + blank, end='\r')


def progress_bar(batch_idx, num_batches, report_interval, train_loss):
    """Redraw an in-place progress bar for the current reporting window.

    The bar fills as batch_idx advances within a window of report_interval
    batches; the line is rewritten in place with '\\r' and no newline.
    """

    digits = int(np.ceil(np.log10(num_batches)))
    width = 21 + digits
    frac = (batch_idx % report_interval) / report_interval
    filled = int(frac * width) + 1
    bar = '=' * filled + '>'
    pad = ' ' * (width - filled)
    line = '\rBatch {:>{dec}d} [{}{}] Train loss: {:>1.5f}'.format(
        batch_idx + 1, bar, pad, train_loss, dec=str(digits))
    print(line, end='')


def time_elapsed_since(start):
    """Computes elapsed time since start.

    :param start: a ``datetime`` taken with ``datetime.now()``.
    :return: (str() of the elapsed timedelta without fractional seconds,
        elapsed time in whole milliseconds)
    """

    delta = datetime.now() - start
    # str(timedelta) only has a '.ffffff' suffix when microseconds != 0, so
    # the old str(...)[:-7] cut off the seconds in that case.
    string = str(delta).split('.')[0]
    ms = int(delta.total_seconds() * 1000)

    return string, ms


def show_on_epoch_end(epoch_time, valid_time, valid_loss, valid_psnr):
    """Clear the progress line and print the epoch summary (times, validation loss and PSNR)."""

    clear_line()
    template = 'Train time: {} | Valid time: {} | Valid loss: {:>1.5f} | Avg PSNR: {:.2f} dB'
    print(template.format(epoch_time, valid_time, valid_loss, valid_psnr))


def show_on_report(batch_idx, num_batches, loss, elapsed):
    """Clear the progress line and print the average loss and time per batch of the last window."""

    clear_line()
    width = int(np.ceil(np.log10(num_batches)))
    template = 'Batch {:>{dec}d} / {:d} | Avg loss: {:>1.5f} | Avg train time / batch: {:d} ms'
    print(template.format(batch_idx + 1, num_batches, loss, int(elapsed), dec=width))


def plot_per_epoch(ckpt_dir, title, measurements, y_label):
    """Plot one tracked statistic against the epoch number and save it.

    The figure is written to ``ckpt_dir`` as ``<title>.png`` (lower-cased,
    spaces replaced by dashes) at 200 dpi and then closed.
    """

    epochs = range(1, len(measurements) + 1)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(epochs, measurements)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    fig.tight_layout()

    slug = title.replace(' ', '-').lower()
    fig.savefig(os.path.join(ckpt_dir, '{}.png'.format(slug)), dpi=200)
    plt.close(fig)


def load_hdr_as_tensor(img_path):
    """Converts OpenEXR image to torch float tensor.

    :param img_path: path to an OpenEXR file.
    :return: float32 tensor of shape (3, height, width) holding the R, G, B channels.
    :raises ValueError: if ``img_path`` is not a valid OpenEXR file.
    """

    # Read OpenEXR file
    if not OpenEXR.isOpenExrFile(img_path):
        raise ValueError(f'Image {img_path} is not a valid OpenEXR file')
    src = OpenEXR.InputFile(img_path)
    pixel_type = Imath.PixelType(Imath.PixelType.FLOAT)
    dw = src.header()['dataWindow']
    width = dw.max.x - dw.min.x + 1
    height = dw.max.y - dw.min.y + 1

    # Read into tensor
    tensor = torch.zeros((3, height, width))
    for i, c in enumerate('RGB'):
        # np.fromstring's binary mode is deprecated; frombuffer returns a
        # read-only view, so copy before handing the data to torch.
        rgb32f = np.frombuffer(src.channel(c, pixel_type), dtype=np.float32)
        tensor[i, :, :] = torch.from_numpy(rgb32f.reshape(height, width).copy())

    return tensor


def reinhard_tonemap(tensor):
    """Reinhard et al. (2002) tone mapping followed by a 1/2.2 gamma.

    Negative values are clamped to 0 first. The input tensor is left
    unchanged; the previous version zeroed negatives in place, which
    mutated the caller's tensor.
    """

    clamped = torch.clamp(tensor, min=0)
    return torch.pow(clamped / (1 + clamped), 1 / 2.2)


def psnr(input, target):
    """Peak signal-to-noise ratio in dB, assuming a peak value of 1 (returns a tensor)."""

    mse = F.mse_loss(input, target)
    return 10 * torch.log10(1 / mse)


def create_montage(img_name, noise_type, save_path, source_t, denoised_t, clean_t, show):
    """Creates montage for easy comparison.

    Writes ``<name>-<noise_type>-noisy.png``, ``-denoised.png`` and a
    side-by-side ``-montage.png`` (clean | noisy | denoised) to ``save_path``.

    :param img_name: file name of the test image (extension is stripped).
    :param noise_type: tag used in the titles and output file names.
    :param source_t: noisy input tensor (CHW); only its first 3 channels are used.
    :param denoised_t: network output tensor (CHW); clamped to [0, 1] for saving.
    :param clean_t: clean reference tensor (CHW).
    :param show: if > 0, also display the figure.
    """

    fig, ax = plt.subplots(1, 3, figsize=(30, 10))
    # FigureCanvasBase.set_window_title was removed in Matplotlib 3.6; the
    # window title lives on the figure manager, which may be None.
    manager = fig.canvas.manager
    if manager is not None:
        manager.set_window_title(img_name.capitalize()[:-4])

    # Bring tensors to CPU
    source_t = source_t.cpu().narrow(0, 0, 3)
    denoised_t = denoised_t.cpu()
    clean_t = clean_t.cpu()

    source = tvF.to_pil_image(source_t)
    denoised = tvF.to_pil_image(torch.clamp(denoised_t, 0, 1))
    clean = tvF.to_pil_image(clean_t)

    # Build image montage. The noisy title used to be hard-coded to
    # "sigma^2=9" regardless of the actual noise; report type and PSNR instead.
    psnr_vals = [psnr(source_t, clean_t), psnr(denoised_t, clean_t)]
    titles = ['Original',
              'Noisy input ({})\n PSNR {:.2f} dB'.format(noise_type, psnr_vals[0]),
              'Denoised by Noise2Noise\n PSNR {:.2f} dB'.format(psnr_vals[1])]
    for j, (title, img) in enumerate(zip(titles, [clean, source, denoised])):
        ax[j].imshow(img)
        ax[j].set_title(title, fontsize=30)

    # Open pop up window, if requested
    if show > 0:
        plt.show()

    # Save to files
    fname = os.path.splitext(img_name)[0]
    source.save(os.path.join(save_path, f'{fname}-{noise_type}-noisy.png'))
    denoised.save(os.path.join(save_path, f'{fname}-{noise_type}-denoised.png'))
    fig.savefig(os.path.join(save_path, f'{fname}-{noise_type}-montage.png'), bbox_inches='tight')
    # Release the figure: the caller creates one montage per test image.
    plt.close(fig)


class AvgMeter(object):
    """Running average of a scalar, e.g. elapsed times or minibatch losses.

    Attributes: ``val`` (last value passed to update), ``sum`` (weighted
    total), ``count`` (total weight) and ``avg`` (= sum / count).
    """

    def __init__(self):
        self.reset()


    def reset(self):
        """Forget every recorded value."""
        self.val, self.avg, self.sum, self.count = 0, 0., 0, 0


    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` observations."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam, lr_scheduler
import os
import json


class Noise2Noise(object):
    """Implementation of Noise2Noise from Lehtinen et al. (2018).

    Bundles a UNet with its optimizer, learning-rate scheduler and loss, and
    provides training, validation, testing and checkpointing. ``params`` is
    an options object read by attribute access (and passed to ``vars()``),
    presumably an ``argparse.Namespace`` -- TODO confirm against the caller.
    """

    def __init__(self, params, trainable):
        """Initializes model.

        :param params: options object (noise_type, learning_rate, adam, loss,
            nb_epochs, cuda, ...).
        :param trainable: if True, also build optimizer, scheduler and loss.
        """

        self.p = params
        self.trainable = trainable
        self._compile()


    def _compile(self):
        """Compiles model (architecture, loss function, optimizers, etc.)."""

        print('Noise2Noise: Learning Image Restoration without Clean Data (Lethinen et al., 2018)')

        # Model (3x3=9 channels for Monte Carlo since it uses 3 HDR buffers)
        if self.p.noise_type == 'mc':
            self.is_mc = True
            self.model = UNet(in_channels=9)
        else:
            self.is_mc = False
            self.model = UNet(in_channels=3)

        # Set optimizer and loss, if in training mode
        if self.trainable:
            self.optim = Adam(self.model.parameters(),
                              lr=self.p.learning_rate,
                              betas=self.p.adam[:2],
                              eps=self.p.adam[2])

            # Learning rate adjustment. ReduceLROnPlateau compares an integer
            # count of bad epochs with patience, so flooring nb_epochs / 4 is
            # equivalent to the former float value.
            self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optim,
                patience=self.p.nb_epochs // 4, factor=0.5, verbose=True)

            # Loss function
            if self.p.loss == 'hdr':
                assert self.is_mc, 'Using HDR loss on non Monte Carlo images'
                self.loss = HDRLoss()
            elif self.p.loss == 'l2':
                self.loss = nn.MSELoss()
            else:
                self.loss = nn.L1Loss()

        # CUDA support
        self.use_cuda = torch.cuda.is_available() and self.p.cuda
        if self.use_cuda:
            self.model = self.model.cuda()
            if self.trainable:
                self.loss = self.loss.cuda()


    def _print_params(self):
        """Formats parameters to print when training."""

        print('Training parameters: ')
        self.p.cuda = self.use_cuda
        param_dict = vars(self.p)
        pretty = lambda x: x.replace('_', ' ').capitalize()
        print('\n'.join('  {} = {}'.format(pretty(k), str(v)) for k, v in param_dict.items()))
        print()


    def save_model(self, epoch, stats, first=False):
        """Saves model to files; can be overwritten at every epoch to save disk space.

        :param epoch: zero-based epoch index.
        :param stats: dict of tracked statistics, also dumped to n2n-stats.json.
        :param first: create the checkpoint directory (done on the first epoch).
        """

        # Create directory for model checkpoints, if nonexistent
        if first:
            if self.p.clean_targets:
                ckpt_dir_name = f'{datetime.now():{self.p.noise_type}-clean-%H%M}'
            else:
                ckpt_dir_name = f'{datetime.now():{self.p.noise_type}-%H%M}'
            if self.p.ckpt_overwrite:
                if self.p.clean_targets:
                    ckpt_dir_name = f'{self.p.noise_type}-clean'
                else:
                    ckpt_dir_name = self.p.noise_type

            self.ckpt_dir = os.path.join(self.p.ckpt_save_path, ckpt_dir_name)
            if not os.path.isdir(self.p.ckpt_save_path):
                os.mkdir(self.p.ckpt_save_path)
            if not os.path.isdir(self.ckpt_dir):
                os.mkdir(self.ckpt_dir)

        # Save checkpoint dictionary
        if self.p.ckpt_overwrite:
            fname_unet = '{}/n2n-{}.pt'.format(self.ckpt_dir, self.p.noise_type)
        else:
            valid_loss = stats['valid_loss'][epoch]
            fname_unet = '{}/n2n-epoch{}-{:>1.5f}.pt'.format(self.ckpt_dir, epoch + 1, valid_loss)
        print('Saving checkpoint to: {}\n'.format(fname_unet))
        torch.save(self.model.state_dict(), fname_unet)

        # Save stats to JSON
        fname_dict = '{}/n2n-stats.json'.format(self.ckpt_dir)
        with open(fname_dict, 'w') as fp:
            json.dump(stats, fp, indent=2)


    def load_model(self, ckpt_fname):
        """Loads model from checkpoint file (mapped to CPU when CUDA is not used)."""

        print('Loading checkpoint from: {}'.format(ckpt_fname))
        if self.use_cuda:
            self.model.load_state_dict(torch.load(ckpt_fname))
        else:
            self.model.load_state_dict(torch.load(ckpt_fname, map_location='cpu'))


    def _on_epoch_end(self, stats, train_loss, epoch, epoch_start, valid_loader):
        """Tracks and saves stats after each epoch."""

        # Evaluate model on validation set
        print('\rTesting model on validation set... ', end='')
        epoch_time = time_elapsed_since(epoch_start)[0]
        valid_loss, valid_time, valid_psnr = self.eval(valid_loader)
        show_on_epoch_end(epoch_time, valid_time, valid_loss, valid_psnr)

        # Decrease learning rate if plateau
        self.scheduler.step(valid_loss)

        # Save checkpoint
        stats['train_loss'].append(train_loss)
        stats['valid_loss'].append(valid_loss)
        stats['valid_psnr'].append(valid_psnr)
        self.save_model(epoch, stats, epoch == 0)

        # Plot stats
        if self.p.plot_stats:
            loss_str = f'{self.p.loss.upper()} loss'
            plot_per_epoch(self.ckpt_dir, 'Valid loss', stats['valid_loss'], loss_str)
            plot_per_epoch(self.ckpt_dir, 'Valid PSNR', stats['valid_psnr'], 'PSNR (dB)')


    def test(self, test_loader, show, denoised_dir):
        """Evaluates denoiser on test set.

        Denoises the first ``show`` batches of ``test_loader`` (none when
        show == 0) and writes montages to ``denoised_dir``.
        """

        self.model.train(False)

        source_imgs = []
        denoised_imgs = []
        clean_imgs = []

        # Create directory for denoised images
        save_path = denoised_dir
        if not os.path.isdir(save_path):
            os.mkdir(save_path)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for batch_idx, (source, target) in enumerate(test_loader):
                # Only do first <show> images
                if show == 0 or batch_idx >= show:
                    break

                source_imgs.append(source)
                clean_imgs.append(target)

                if self.use_cuda:
                    source = source.cuda()

                # Denoise
                denoised_img = self.model(source).detach()
                denoised_imgs.append(denoised_img)

        # Squeeze tensors (drops the batch axis; presumably batch size 1 -- TODO confirm)
        source_imgs = [t.squeeze(0) for t in source_imgs]
        denoised_imgs = [t.squeeze(0) for t in denoised_imgs]
        clean_imgs = [t.squeeze(0) for t in clean_imgs]

        # Create montage and save images
        print('Saving images and montages to: {}'.format(save_path))
        for i in range(len(source_imgs)):
            img_name = test_loader.dataset.imgs[i]
            create_montage(img_name, self.p.noise_type, save_path, source_imgs[i], denoised_imgs[i], clean_imgs[i], show)


    def eval(self, valid_loader):
        """Evaluates denoiser on validation set.

        Leaves the model in inference mode.

        :return: (average loss, elapsed-time string, average per-image PSNR in dB)
        """

        self.model.train(False)

        valid_start = datetime.now()
        loss_meter = AvgMeter()
        psnr_meter = AvgMeter()

        # Validation needs no gradients; without no_grad() every forward pass
        # would build and keep an autograd graph.
        with torch.no_grad():
            for batch_idx, (source, target) in enumerate(valid_loader):
                if self.use_cuda:
                    source = source.cuda()
                    target = target.cuda()

                # Denoise
                source_denoised = self.model(source)

                # Update loss
                loss = self.loss(source_denoised, target)
                loss_meter.update(loss.item())

                # Compute PSNR
                if self.is_mc:
                    source_denoised = reinhard_tonemap(source_denoised)
                # Move to CPU once per batch and iterate over the actual batch
                # size: the last batch can be smaller than self.p.batch_size,
                # which used to raise IndexError.
                source_denoised = source_denoised.cpu()
                target = target.cpu()
                for i in range(source_denoised.size(0)):
                    psnr_meter.update(psnr(source_denoised[i], target[i]).item())

        valid_loss = loss_meter.avg
        valid_time = time_elapsed_since(valid_start)[0]
        psnr_avg = psnr_meter.avg

        return valid_loss, valid_time, psnr_avg


    def train(self, train_loader, valid_loader):
        """Trains denoiser on training set, validating and checkpointing after every epoch."""

        self.model.train(True)

        self._print_params()
        num_batches = len(train_loader)
        assert num_batches % self.p.report_interval == 0, 'Report interval must divide total number of batches'

        # Dictionaries of tracked stats
        stats = {'noise_type': self.p.noise_type,
                 'noise_param': self.p.noise_param,
                 'train_loss': [],
                 'valid_loss': [],
                 'valid_psnr': []}

        # Main training loop
        train_start = datetime.now()
        for epoch in range(self.p.nb_epochs):
            print('EPOCH {:d} / {:d}'.format(epoch + 1, self.p.nb_epochs))
            # eval() at the end of every epoch switches the network to
            # inference mode; switch back before training on the next epoch.
            self.model.train(True)

            # Some stats trackers
            epoch_start = datetime.now()
            train_loss_meter = AvgMeter()
            loss_meter = AvgMeter()
            time_meter = AvgMeter()

            # Minibatch SGD
            for batch_idx, (source, target) in enumerate(train_loader):
                batch_start = datetime.now()
                progress_bar(batch_idx, num_batches, self.p.report_interval, loss_meter.val)

                if self.use_cuda:
                    source = source.cuda()
                    target = target.cuda()

                # Denoise image
                source_denoised = self.model(source)

                loss = self.loss(source_denoised, target)
                loss_meter.update(loss.item())

                # Zero gradients, perform a backward pass, and update the weights
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                # Report/update statistics
                time_meter.update(time_elapsed_since(batch_start)[1])
                if (batch_idx + 1) % self.p.report_interval == 0 and batch_idx:
                    show_on_report(batch_idx, num_batches, loss_meter.avg, time_meter.avg)
                    train_loss_meter.update(loss_meter.avg)
                    loss_meter.reset()
                    time_meter.reset()

            # Epoch end, save and reset tracker
            self._on_epoch_end(stats, train_loss_meter.avg, epoch, epoch_start, valid_loader)
            train_loss_meter.reset()

        train_elapsed = time_elapsed_since(train_start)[0]
        print('Training done! Total elapsed time: {}\n'.format(train_elapsed))


class HDRLoss(nn.Module):
    """High dynamic range loss: squared error relative to the denoised value.

    loss = mean( (denoised - target)^2 / (denoised + eps)^2 )
    """

    def __init__(self, eps=0.01):
        """:param eps: added to the denominator for numerical stability."""

        super().__init__()
        self._eps = eps


    def forward(self, denoised, target):
        """Return the mean relative squared error over all elements."""

        sq_err = (denoised - target) ** 2
        denom = (denoised + self._eps) ** 2
        ratio = sq_err / denom
        return torch.mean(ratio.view(-1))


Here is the result of UNet model.
The PSNR values are higher than those of all previous methods for all noise levels, and the visual quality is also very good.

[<img src="./n2n.png" width="300"/>](image.png)

# Deep Image Prior

However, one drawback of deep learning is that it needs a lot of data to perform well. Can we get rid of this data dependency? Our answer is yes. If we look at our previous methods, all of them can be formulated in the following form:
$$
    \min_{X} E(X,X_0) + R(X)
$$
where $X_0$ is the noisy image, $E$ measures how far the prediction $X$ is from $X_0$, and $R$ is a regularizer.

We want the prediction to be close to the noisy image under some regularization. In MRF and total variation minimization, the regularization is that the prediction has low variation. In low-rank representation, we assume the original images lie in a low-rank subspace. In deep image prior, we want the prediction to be very close to the noisy image, subject to the constraint that the prediction is produced by a deep CNN, so the architecture of the CNN itself serves as the prior. The motivation is that a deep CNN stacks convolution layers that extract semantic features of an image, and CNNs do this better than other architectures. The architecture of a deep CNN may therefore encode prior knowledge about what an image should look like. In deep image prior, we feed random noise into the network, which outputs an image, and we optimize the weights of the CNN to make this output close to the noisy image.

In [None]:
import torch
import torch.nn as nn
import torchvision
import sys

import numpy as np
from PIL import Image
import PIL
import numpy as np

import matplotlib.pyplot as plt

def crop_image(img, d=32):
    '''Center-crop an image so that both dimensions are multiples of `d`.

    Args:
        img: PIL image (anything exposing `.size` as (width, height) and a
            `.crop(box)` method).
        d: divisor that the cropped width and height must be multiples of.

    Returns:
        Result of `img.crop([left, upper, right, lower])` for the centered box.
    '''
    width, height = img.size
    keep_w = width - width % d
    keep_h = height - height % d

    left = int((width - keep_w) / 2)
    upper = int((height - keep_h) / 2)
    right = int((width + keep_w) / 2)
    lower = int((height + keep_h) / 2)

    return img.crop([left, upper, right, lower])

def get_params(opt_over, net, net_input, downsampler=None):
    '''Collect the tensors that the optimizer should update.

    Args:
        opt_over: comma-separated list of targets, e.g. "net,input" or "net".
            Recognized entries: "net" (all parameters of `net`), "down" (all
            parameters of `downsampler`) and "input" (`net_input` itself,
            which is switched to `requires_grad=True`).
        net: network whose parameters are collected for "net".
        net_input: torch.Tensor that stores input `z`.
        downsampler: module whose parameters are collected for "down".

    Returns:
        List of tensors, in the order the targets appear in `opt_over`.

    Raises:
        ValueError: for an unknown target, or "down" without a downsampler.
    '''
    params = []

    for opt in opt_over.split(','):
        if opt == 'net':
            params += list(net.parameters())
        elif opt == 'down':
            if downsampler is None:
                raise ValueError("'down' requested but no downsampler was given")
            # Extend rather than overwrite, so targets listed earlier
            # (e.g. 'net') are kept.
            params += list(downsampler.parameters())
        elif opt == 'input':
            net_input.requires_grad = True
            params += [net_input]
        else:
            raise ValueError(f'unknown optimization target: {opt!r}')

    return params

def get_image_grid(images_np, nrow=8):
    '''Tile a list of C x H x W numpy images into one numpy grid.

    The images are wrapped as tensors (sharing memory) and laid out by
    torchvision.utils.make_grid with `nrow` images per row.
    '''
    tensors = list(map(torch.from_numpy, images_np))
    return torchvision.utils.make_grid(tensors, nrow).numpy()

def plot_image_grid(images_np, nrow=8, factor=1, interpolation='lanczos'):
    """Show a list of images side by side in one matplotlib figure.

    Args:
        images_np: list of images, each an np.array of shape 3xHxW or 1xHxW.
        nrow: how many images go in one row of the grid.
        factor: added to both figure dimensions to scale the plot size.
        interpolation: interpolation passed to plt.imshow.

    Returns:
        The C x H x W numpy grid that was displayed.
    """
    channels = max(img.shape[0] for img in images_np)
    assert (channels == 3) or (channels == 1), "images should have 1 or 3 channels"

    # Single-channel images mixed with RGB ones are replicated to 3 channels.
    uniform = []
    for img in images_np:
        if img.shape[0] != channels:
            img = np.concatenate([img, img, img], axis=0)
        uniform.append(img)

    grid = get_image_grid(uniform, nrow)

    plt.figure(figsize=(len(uniform) + factor, 12 + factor))
    if uniform[0].shape[0] == 1:
        plt.imshow(grid[0], cmap='gray', interpolation=interpolation)
    else:
        plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation)
    plt.show()

    return grid

def load(path):
    """Open the image file at `path` with PIL and return the PIL image."""
    return Image.open(path)

def get_image(path, imsize=-1):
    """Load an image and optionally resize it to a specific size.

    Args:
        path: path to image.
        imsize: (width, height) tuple, or a scalar used for both sides;
            -1 (or a tuple whose first entry is -1) means no resize.

    Returns:
        Tuple ``(img, img_np)``: the PIL image and its array produced by
        `pil_to_np` (channels first, float32, scaled to [0, 1]).
    """
    img = load(path)

    if isinstance(imsize, int):
        imsize = (imsize, imsize)

    if imsize[0] != -1 and img.size != imsize:
        # Bicubic when upscaling, Lanczos when downscaling. `Image.LANCZOS`
        # is the filter `Image.ANTIALIAS` aliased; ANTIALIAS was removed in
        # Pillow 10, so referencing it raises AttributeError there.
        resample = Image.BICUBIC if imsize[0] > img.size[0] else Image.LANCZOS
        img = img.resize(imsize, resample)

    img_np = pil_to_np(img)

    return img, img_np



def fill_noise(x, noise_type):
    """Fill tensor `x` in place with noise of type `noise_type`.

    Args:
        x: torch.Tensor, overwritten in place.
        noise_type: 'u' for uniform samples in [0, 1) (`uniform_`), or
            'n' for standard normal samples (`normal_`).

    Raises:
        ValueError: if `noise_type` is neither 'u' nor 'n'.
    """
    if noise_type == 'u':
        x.uniform_()
    elif noise_type == 'n':
        x.normal_()
    else:
        # An explicit raise survives `python -O`; a bare `assert False` would
        # be stripped and leave `x` silently unfilled.
        raise ValueError(f"noise_type must be 'u' or 'n', got {noise_type!r}")

def get_noise(input_depth, method, spatial_size, noise_type='u', var=1./10):
    """Build the network input tensor of shape 1 x input_depth x H x W.

    Args:
        input_depth: number of channels in the tensor.
        method: 'noise' fills the tensor with random noise (see `fill_noise`);
            'meshgrid' builds a 2-channel coordinate grid in [0, 1] with
            np.meshgrid (requires input_depth == 2).
        spatial_size: (H, W) tuple, or an int used for both.
        noise_type: 'u' for uniform, 'n' for normal (used by 'noise' only).
        var: multiplier applied to the noise; effectively a std scaler.
    """
    if isinstance(spatial_size, int):
        spatial_size = (spatial_size, spatial_size)
    height, width = spatial_size[0], spatial_size[1]

    if method == 'meshgrid':
        assert input_depth == 2
        cols = np.arange(0, width) / float(width - 1)
        rows = np.arange(0, height) / float(height - 1)
        X, Y = np.meshgrid(cols, rows)
        net_input = np_to_torch(np.concatenate([X[None, :], Y[None, :]]))
    elif method == 'noise':
        net_input = torch.zeros([1, input_depth, height, width])
        fill_noise(net_input, noise_type)
        net_input *= var
    else:
        assert False

    return net_input

def pil_to_np(img_PIL):
    '''Convert a PIL image to a float32 numpy array scaled to [0, 1].

    np.array() of a PIL image is H x W x C (or H x W for single-band images).
    The channel axis is moved first (C x H x W); a 2-D array gets a leading
    channel axis (1 x H x W). Values are divided by 255.
    '''
    arr = np.array(img_PIL)
    arr = arr.transpose(2, 0, 1) if arr.ndim == 3 else arr[None, ...]
    return arr.astype(np.float32) / 255.

def np_to_pil(img_np):
    '''Convert a C x H x W numpy image in [0, 1] to a PIL image.

    Values are scaled by 255, clipped to [0, 255] and truncated to uint8.
    A single-channel input becomes a 2-D (grayscale) image; otherwise the
    channel axis is moved last.
    '''
    pixels = np.clip(img_np * 255, 0, 255).astype(np.uint8)
    pixels = pixels[0] if img_np.shape[0] == 1 else pixels.transpose(1, 2, 0)
    return Image.fromarray(pixels)

def np_to_torch(img_np):
    '''Wrap a C x H x W numpy image as a 1 x C x H x W torch.Tensor.

    torch.from_numpy does not copy, so the tensor shares memory with `img_np`.
    '''
    return torch.from_numpy(img_np).unsqueeze(0)

def torch_to_np(img_var):
    '''Convert a 1 x C x H x W tensor to a C x H x W numpy array.

    Only the first batch element is returned. The tensor is detached from the
    autograd graph and moved to the CPU before conversion.
    '''
    host = img_var.detach().cpu()
    first = host[0]
    return first.numpy()


def optimize(optimizer_type, parameters, closure, LR, num_iter):
    """Run the optimization loop over `parameters`.

    Args:
        optimizer_type: 'LBFGS' or 'adam'.
        parameters: list of Tensors to optimize over.
        closure: callable that computes the loss and its gradients (this
            function never calls backward() itself) and returns the loss.
        LR: learning rate.
        num_iter: number of Adam steps, or LBFGS `max_iter`.
    """
    def run_adam(lr, steps):
        adam = torch.optim.Adam(parameters, lr=lr)
        for _ in range(steps):
            adam.zero_grad()
            closure()
            adam.step()

    if optimizer_type == 'adam':
        print('Starting optimization with ADAM')
        run_adam(LR, num_iter)
    elif optimizer_type == 'LBFGS':
        # Warm up with 100 small Adam steps before handing over to LBFGS.
        run_adam(0.001, 100)

        print('Starting optimization with LBFGS')
        lbfgs = torch.optim.LBFGS(parameters, max_iter=num_iter, lr=LR,
                                  tolerance_grad=-1, tolerance_change=-1)

        def lbfgs_closure():
            lbfgs.zero_grad()
            return closure()

        lbfgs.step(lbfgs_closure)
    else:
        assert False


def get_noisy_image(img_np, sigma):
    """Corrupt an image with additive zero-mean Gaussian noise.

    Args:
        img_np: image, np.array with values from 0 to 1.
        sigma: standard deviation of the noise (same scale as `img_np`).

    Returns:
        Tuple (PIL image, float32 np.array) of the noisy image, clipped to [0, 1].
    """
    noise = np.random.normal(scale=sigma, size=img_np.shape)
    noisy = np.clip(img_np + noise, 0, 1).astype(np.float32)
    return np_to_pil(noisy), noisy

In [None]:
import numpy as np
import torch
import torch.nn as nn

class Downsampler(nn.Module):
    '''Fixed-kernel strided convolution that downsamples each channel.

    Every channel is convolved with the same 2-D kernel from `get_kernel`
    (channel i only sees input channel i) and subsampled with stride `factor`.
    Kernel design: http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf
    '''
    def __init__(self, n_planes, factor, kernel_type, phase=0, kernel_width=None, support=None, sigma=None, preserve_size=False):
        """
        Args:
            n_planes: number of channels.
            factor: downsampling factor (the convolution stride).
            kernel_type: a preset ('lanczos2', 'lanczos3', 'gauss12',
                'gauss1sq2'), which fixes the kernel parameters, or a base
                type ('lanczos', 'gauss', 'box') that uses the explicit
                `kernel_width` / `support` / `sigma` arguments.
            phase: 0 or 0.5.
            preserve_size: replication-pad the input before filtering.
        """
        super().__init__()

        assert phase in [0, 0.5], 'phase should be 0 or 0.5'

        if kernel_type in ('lanczos2', 'lanczos3'):
            support = 2 if kernel_type == 'lanczos2' else 3
            kernel_width = 2 * support * factor + 1
            base_type = 'lanczos'
        elif kernel_type == 'gauss12':
            kernel_width, sigma, base_type = 7, 1/2, 'gauss'
        elif kernel_type == 'gauss1sq2':
            kernel_width, sigma, base_type = 9, 1./np.sqrt(2), 'gauss'
        elif kernel_type in ['lanczos', 'gauss', 'box']:
            base_type = kernel_type
        else:
            assert False, 'wrong name kernel'

        # For phase = 1/2 the actual kernel is one tap smaller than `kernel_width`.
        self.kernel = get_kernel(factor, base_type, phase, kernel_width, support=support, sigma=sigma)

        conv = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0)
        conv.weight.data.zero_()
        conv.bias.data.zero_()

        # Diagonal weights: plane p is filtered only from input plane p.
        kernel_torch = torch.from_numpy(self.kernel)
        for plane in range(n_planes):
            conv.weight.data[plane, plane] = kernel_torch

        self.downsampler_ = conv

        if preserve_size:
            side = self.kernel.shape[0]
            pad = int((side - 1) / 2.) if side % 2 == 1 else int((side - factor) / 2.)
            self.padding = nn.ReplicationPad2d(pad)

        self.preserve_size = preserve_size

    def forward(self, input):
        x = self.padding(input) if self.preserve_size else input
        # The last (possibly padded) input is kept on the module as `self.x`.
        self.x = x
        return self.downsampler_(x)

def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None):
    """Build a normalized square resampling kernel as a numpy array.

    Args:
        factor: downsampling factor; scales the Lanczos argument.
        kernel_type: 'lanczos', 'gauss' or 'box'.
        phase: 0 or 0.5. With 0.5 the non-box kernel is one tap smaller than
            `kernel_width` and Lanczos samples are taken at half-pixel
            offsets; 'box' requires 0.5 and 'gauss' does not support it.
        kernel_width: nominal side length of the kernel.
        support: Lanczos support (required for 'lanczos').
        sigma: Gaussian sigma (required for 'gauss').

    Returns:
        Square float64 array whose entries sum to 1.

    Raises:
        ValueError: if `kernel_type` is not a supported name.
    """
    if kernel_type not in ('lanczos', 'gauss', 'box'):
        raise ValueError(f'wrong method name: {kernel_type!r}')

    if phase == 0.5 and kernel_type != 'box':
        kernel = np.zeros([kernel_width - 1, kernel_width - 1])
    else:
        kernel = np.zeros([kernel_width, kernel_width])

    if kernel_type == 'box':
        assert phase == 0.5, 'Box filter is always half-phased'
        kernel[:] = 1./(kernel_width * kernel_width)

    elif kernel_type == 'gauss':
        assert sigma, 'sigma is not specified'
        assert phase != 0.5, 'phase 1/2 for gauss not implemented'

        center = (kernel_width + 1.)/2.
        sigma_sq = sigma * sigma

        for i in range(1, kernel.shape[0] + 1):
            for j in range(1, kernel.shape[1] + 1):
                # NOTE(review): offsets are halved, so the effective std in
                # pixels is 2 * sigma; kept as-is to preserve existing kernels.
                di = (i - center)/2.
                dj = (j - center)/2.
                kernel[i - 1][j - 1] = np.exp(-(di * di + dj * dj)/(2 * sigma_sq))
                kernel[i - 1][j - 1] = kernel[i - 1][j - 1]/(2. * np.pi * sigma_sq)

    else:  # 'lanczos'
        assert support, 'support is not specified'
        center = (kernel_width + 1) / 2.
        # Half-phase kernels are sampled half a pixel off the integer grid.
        shift = 0.5 if phase == 0.5 else 0.

        for i in range(1, kernel.shape[0] + 1):
            for j in range(1, kernel.shape[1] + 1):
                di = abs(i + shift - center) / factor
                dj = abs(j + shift - center) / factor

                # Separable Lanczos window; the d == 0 limit is 1.
                val = 1
                if di != 0:
                    val = val * support * np.sin(np.pi * di) * np.sin(np.pi * di / support)
                    val = val / (np.pi * np.pi * di * di)

                if dj != 0:
                    val = val * support * np.sin(np.pi * dj) * np.sin(np.pi * dj / support)
                    val = val / (np.pi * np.pi * dj * dj)

                kernel[i - 1][j - 1] = val

    kernel /= kernel.sum()

    return kernel


In [None]:
import torch
import torch.nn as nn
import numpy as np

def add_module(self, module):
    """Append `module` to a container, naming it by its 1-based position.

    Installed below as `torch.nn.Module.add`; requires a container that
    supports len() (e.g. nn.Sequential or Concat).
    """
    next_name = str(len(self) + 1)
    self.add_module(next_name, module)


torch.nn.Module.add = add_module

class Concat(nn.Module):
    """Apply every child module to the same input and concatenate along `dim`.

    If the outputs differ in height or width (dims 2 and 3), each one is
    center-cropped to the smallest height/width before concatenation.
    """
    def __init__(self, dim, *args):
        super().__init__()
        self.dim = dim

        for position, child in enumerate(args):
            self.add_module(str(position), child)

    def forward(self, input):
        outputs = [child(input) for child in self._modules.values()]

        min_h = min(o.shape[2] for o in outputs)
        min_w = min(o.shape[3] for o in outputs)

        if all(o.shape[2] == min_h and o.shape[3] == min_w for o in outputs):
            aligned = outputs
        else:
            aligned = []
            for o in outputs:
                top = (o.size(2) - min_h) // 2
                left = (o.size(3) - min_w) // 2
                aligned.append(o[:, :, top: top + min_h, left:left + min_w])

        return torch.cat(aligned, dim=self.dim)

    def __len__(self):
        return len(self._modules)


class GenNoise(nn.Module):
    """Return fresh standard-normal noise shaped like the input.

    The output has the input's shape, except that dimension 1 (channels) is
    `dim2`. It has the input's tensor type (via `type_as`). The input's
    values are ignored.
    """
    def __init__(self, dim2):
        super().__init__()
        self.dim2 = dim2

    def forward(self, input):
        shape = list(input.size())
        shape[1] = self.dim2

        # Return a plain tensor: torch.autograd.Variable is deprecated and a
        # no-op wrapper in modern PyTorch. The result does not require grad.
        noise = torch.zeros(shape).type_as(input.data)
        return noise.normal_()


class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)``.

    Reference: https://arxiv.org/abs/1710.05941
    """

    def __init__(self):
        super().__init__()
        self.s = nn.Sigmoid()

    def forward(self, x):
        gate = self.s(x)
        return x * gate


def act(act_fun = 'LeakyReLU'):
    '''Build an activation module.

    Args:
        act_fun: one of the names 'LeakyReLU' (slope 0.2, in place), 'Swish',
            'ELU' or 'none' (an empty nn.Sequential, i.e. identity); or a
            module class / zero-argument factory that is called to build it.
    '''
    if not isinstance(act_fun, str):
        return act_fun()

    # Lambdas defer name lookup until the chosen entry is actually built.
    builders = {
        'LeakyReLU': lambda: nn.LeakyReLU(0.2, inplace=True),
        'Swish': lambda: Swish(),
        'ELU': lambda: nn.ELU(),
        'none': lambda: nn.Sequential(),
    }
    if act_fun not in builders:
        assert False
    return builders[act_fun]()


def bn(num_features):
    """Return a 2-D batch-normalization layer over `num_features` channels."""
    layer = nn.BatchNorm2d(num_features)
    return layer


def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'):
    """Convolution with optional explicit padding and downsampling layer.

    Args:
        in_f, out_f: input / output channels.
        kernel_size: square kernel size; padding of int((kernel_size - 1) / 2)
            keeps the spatial size at stride 1.
        stride: downsampling factor.
        bias: whether the convolution has a bias.
        pad: 'reflection' puts an nn.ReflectionPad2d in front of the conv;
            any other value uses the conv's own zero padding.
        downsample_mode: 'stride' downsamples with the conv stride; 'avg',
            'max', 'lanczos2' or 'lanczos3' keep the conv at stride 1 and
            append a pooling / Downsampler layer with factor `stride`.

    Returns:
        nn.Sequential of [padder (optional), conv, downsampler (optional)].
    """
    post = None
    if stride != 1 and downsample_mode != 'stride':
        if downsample_mode == 'avg':
            post = nn.AvgPool2d(stride, stride)
        elif downsample_mode == 'max':
            post = nn.MaxPool2d(stride, stride)
        elif downsample_mode in ['lanczos2', 'lanczos3']:
            post = Downsampler(n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True)
        else:
            assert False
        # The separate layer downsamples, so the conv itself runs at stride 1.
        stride = 1

    same_pad = int((kernel_size - 1) / 2)
    pre = nn.ReflectionPad2d(same_pad) if pad == 'reflection' else None
    conv_padding = 0 if pre is not None else same_pad

    convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=conv_padding, bias=bias)

    return nn.Sequential(*[m for m in (pre, convolver, post) if m is not None])

In [None]:
import torch
import torch.nn as nn
from numpy.random import normal
from numpy.linalg import svd
from math import sqrt
import torch.nn.init

class ResidualSequential(nn.Sequential):
    """nn.Sequential with an identity skip connection: returns body(x) + x.

    If the body shrinks the spatial size, `x` is center-cropped to the body's
    output height/width before the addition. eval()/train() are inherited
    unchanged from nn.Module.
    """

    def forward(self, x):
        out = super().forward(x)
        if out.size(2) != x.size(2) or out.size(3) != x.size(3):
            # Floor division: slice bounds must be integers (true division
            # produced floats, which tensor slicing rejects).
            top = (x.size(2) - out.size(2)) // 2
            left = (x.size(3) - out.size(3)) // 2
            x = x[:, :, top:top + out.size(2), left:left + out.size(3)]
        return out + x


def get_block(num_channels, norm_layer, act_fun):
    """Return the layers of one residual-block body: conv, norm, act, conv, norm.

    Both convolutions are 3x3, stride 1, padding 1 and bias-free, so the
    channel count and spatial size are preserved.
    """
    def conv3x3():
        return nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=False)

    def norm():
        return norm_layer(num_channels, affine=True)

    return [conv3x3(), norm(), act(act_fun), conv3x3(), norm()]


class ResNet(nn.Module):
    def __init__(self, num_input_channels, num_output_channels, num_blocks, num_channels, need_residual=True, act_fun='LeakyReLU', need_sigmoid=True, norm_layer=nn.BatchNorm2d, pad='reflection'):
        '''Fully convolutional residual network at constant resolution.

        Layout: conv + activation, `num_blocks` blocks from `get_block`,
        3x3 conv + norm, output conv, and (if `need_sigmoid`) a sigmoid.

        Args:
            num_input_channels, num_output_channels: in / out channels.
            num_blocks: number of residual blocks.
            num_channels: hidden channel count.
            need_residual: wrap each block in ResidualSequential (skip
                connection) instead of a plain nn.Sequential.
            act_fun: activation passed to `act`.
            need_sigmoid: append nn.Sigmoid to the output.
            norm_layer: normalization layer class.
            pad: padding of the first and last conv; 'reflection' adds a
                ReflectionPad2d, anything else uses zero padding (see `conv`).

        eval()/train() are inherited from nn.Module, so they update
        `self.training` and return `self`.
        '''
        super().__init__()

        block_cls = ResidualSequential if need_residual else nn.Sequential

        layers = [
            conv(num_input_channels, num_channels, 3, stride=1, bias=True, pad=pad),
            act(act_fun),
        ]
        for _ in range(num_blocks):
            layers.append(block_cls(*get_block(num_channels, norm_layer, act_fun)))

        layers += [
            nn.Conv2d(num_channels, num_channels, 3, 1, 1),
            norm_layer(num_channels, affine=True),
            conv(num_channels, num_output_channels, 3, 1, bias=True, pad=pad),
        ]
        # Honor `need_sigmoid` (it used to be ignored, sigmoid always added).
        if need_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)

In [None]:
import torch
import torch.nn as nn

def dcgan(inp=2,
          ndf=32,
          num_ups=4, need_sigmoid=True, need_bias=True, pad='zero', upsample_mode='nearest', need_convT = True):
    """DCGAN-style generator mapping `inp` channels to a 3-channel image.

    The first 3x3 transposed conv (no padding) grows H and W by 2. Each of
    the ``num_ups - 3`` middle stages and the final stage then doubles them.

    Args:
        inp: input channels.
        ndf: hidden feature channels.
        num_ups: depth control; ``num_ups - 3`` middle upsampling stages.
        need_sigmoid: append a sigmoid to the output.
        need_bias, pad: accepted for interface compatibility; unused.
        upsample_mode: nn.Upsample mode of the middle stages when
            ``need_convT`` is False (the last stage always uses 'bilinear').
        need_convT: upsample with stride-2 transposed convs instead of
            nn.Upsample followed by a 3x3 conv.
    """
    def lrelu():
        # Pass the slope explicitly: `nn.LeakyReLU(True)` binds True to
        # `negative_slope` (== 1.0), which makes the layer the identity.
        # 0.2 matches act('LeakyReLU').
        return nn.LeakyReLU(0.2, inplace=True)

    layers = [nn.ConvTranspose2d(inp, ndf, kernel_size=3, stride=1, padding=0, bias=False),
              nn.BatchNorm2d(ndf),
              lrelu()]

    for _ in range(num_ups - 3):
        if need_convT:
            layers += [nn.ConvTranspose2d(ndf, ndf, kernel_size=4, stride=2, padding=1, bias=False),
                       nn.BatchNorm2d(ndf),
                       lrelu()]
        else:
            layers += [nn.Upsample(scale_factor=2, mode=upsample_mode),
                       nn.Conv2d(ndf, ndf, kernel_size=3, stride=1, padding=1, bias=False),
                       nn.BatchNorm2d(ndf),
                       lrelu()]

    if need_convT:
        layers += [nn.ConvTranspose2d(ndf, 3, 4, 2, 1, bias=False)]
    else:
        layers += [nn.Upsample(scale_factor=2, mode='bilinear'),
                   nn.Conv2d(ndf, 3, kernel_size=3, stride=1, padding=1, bias=False)]

    if need_sigmoid:
        layers += [nn.Sigmoid()]

    return nn.Sequential(*layers)

In [None]:
import numpy as np


import torch
import torch.optim

# from skimage.measure import compare_psnr
from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import structural_similarity
import cv2

# cuDNN benchmark mode: convolution algorithms are picked by benchmarking
# for the fixed input size used throughout this run.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark =True
# Networks and tensors below are cast to this CUDA tensor type, so a GPU is required.
dtype = torch.cuda.FloatTensor

imsize =-1  # -1: keep the image's native size (no resize in get_image)
PLOT = True  # show intermediate/final plots
sigma = 25  # noise level on the 0-255 scale; used as the std of the Gaussian, not its variance
sigma_ = sigma/255.  # the same std on the [0, 1] scale of the image arrays
# Yale face images available for the denoising experiment.
names = ['yaleB01_P00A-020E+10.png', 'yaleB02_P00A-010E+00.png', 'yaleB05_P00A-020E-10.png', \
         'yaleB06_P00A+020E+10.png', 'yaleB07_P00A+000E+00.png']

# deJPEG
# fname = 'data/denoising/0.jpg'

i = 4  # index into `names` of the image to denoise
## denoising
fname = 'data/denoising/'+names[i]
# Path of the saved result figure.
# NOTE(review): the directory is labelled "sigma^2" although `sigma` is a std,
# and the 'B0' prefix only yields the right subject number while i + 1 < 10.
filename = 'data/denoised/sigma^2='+str(sigma)+'/B0'+str(i+1)+'_scale'+str(sigma)+'.png'

In [None]:
# de-JPEG
if fname == 'data/denoising/0.jpg':
    img_noisy_pil = crop_image(get_image(fname, imsize)[0], d=32)
    img_noisy_np = pil_to_np(img_noisy_pil)

    # As we don't have ground truth
    img_pil = img_noisy_pil
    img_np = img_noisy_np

    if PLOT:
        plot_image_grid([img_np], 4, 5);

else: #denoising
    # Add synthetic noise
    img_pil = crop_image(get_image(fname, imsize)[0], d=32)
    img_np = pil_to_np(img_pil)

    if img_np.shape[0] == 1:
        img_np = np.stack((img_np[0],)*3, axis=0)

    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)

    if PLOT:
        plot_image_grid([img_np, img_noisy_np], 4, 6);
# else:
#     assert False

In [None]:
# --- Deep Image Prior hyper-parameters ---
INPUT = 'noise' # 'meshgrid'  (the `method` passed to get_noise)
pad = 'reflection'
OPT_OVER = 'net' # 'net,input'  (which tensors get_params hands to the optimizer)

# Std of the fresh Gaussian perturbation added to the fixed input on every closure() call.
reg_noise_std = 1./30. # set to 1./20. for sigma=50
LR = 0.01

OPTIMIZER='adam' # 'LBFGS'
show_every = 100  # plotting / checkpointing interval used inside closure()
exp_weight=0.99  # weight of the past in the running average `out_avg` computed in closure()

if fname == 'data/denoising/0.jpg':
    num_iter = 2400
    input_depth = 3
    figsize = 5

    # NOTE(review): `skip` is not defined in this part of the notebook;
    # presumably the Deep Image Prior skip-network builder — confirm it is
    # defined/imported before this cell runs.
    net = skip(
                input_depth, 3,
                num_channels_down = [8, 16, 32, 64, 128],
                num_channels_up   = [8, 16, 32, 64, 128],
                num_channels_skip = [0, 0, 0, 4, 4],
                upsample_mode='bilinear',
                need_sigmoid=True, need_bias=True, pad=pad, act_fun='LeakyReLU')

    net = net.type(dtype)

else:
    num_iter = 1000
    input_depth = 32
    figsize = 4


    # NOTE(review): `get_net` is likewise not defined in the visible code — confirm.
    net = get_net(input_depth, 'skip', pad,
                  skip_n33d=128,
                  skip_n33u=128,
                  skip_n11=4,
                  num_scales=5,
                  upsample_mode='bilinear').type(dtype)

# Fixed network input z of shape 1 x input_depth x H x W. PIL's .size is
# (width, height), while get_noise expects (height, width), hence the swap.
net_input = get_noise(input_depth, INPUT, (img_pil.size[1], img_pil.size[0])).type(dtype).detach()

# Compute number of parameters
s  = sum([np.prod(list(p.size())) for p in net.parameters()]);
print ('Number of params: %d' % s)

# Loss
# MSE against the noisy image itself (see closure()); no clean target is used for fitting.
mse = torch.nn.MSELoss().type(dtype)

img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)

In [None]:
# Optimization state shared with closure() through module-level globals.
net_input_saved = net_input.detach().clone()  # unperturbed input z
noise = net_input.detach().clone()  # scratch buffer, refilled with N(0, 1) samples on each call
out_avg = None  # exponential moving average of the network output
last_net = None  # CPU copy of the parameters at the last checkpoint
psrn_noisy_last = 0  # PSNR(noisy image, output) at the last checkpoint
psrn_gt_last = 0

i = 0  # iteration counter, advanced by closure()
def closure():
    """Run one forward/backward pass and return the loss.

    Called by optimize() once per step; reads and updates the globals above.
    """

    global i, out_avg, psrn_noisy_last, last_net, net_input, psrn_gt_last

    # Perturb the fixed input with fresh Gaussian noise of std reg_noise_std.
    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out = net(net_input)

    # Smoothing: exponential moving average of the outputs (weight exp_weight on the past).
    if out_avg is None:
        out_avg = out.detach()
    else:
        out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)

    # The fitting target is the noisy image itself.
    total_loss = mse(out, img_noisy_torch)
    total_loss.backward()

    # PSNR of the raw output vs. the noisy image, and of the raw / averaged outputs vs. the clean image.
    psrn_noisy = peak_signal_noise_ratio(img_noisy_np, out.detach().cpu().numpy()[0])
    psrn_gt    = peak_signal_noise_ratio(img_np, out.detach().cpu().numpy()[0])
    psrn_gt_sm = peak_signal_noise_ratio(img_np, out_avg.detach().cpu().numpy()[0])
    # ssim_score = structural_similarity(img_noisy_np.transpose(1,2,0), out.detach().cpu().numpy()[0].transpose(1,2,0),multichannel=True)

    # Note that we do not have GT for the "snail" example
    # So 'PSRN_gt', 'PSNR_gt_sm' make no sense
    if  PLOT and i % show_every == 0:
        print(f'Iteration {i} Loss {total_loss.item()} PSNR_noisy {psrn_noisy} PSRN_gt: {psrn_gt}')
        out_np = torch_to_np(out)
        plot_image_grid([np.clip(out_np, 0, 1),
                         np.clip(torch_to_np(out_avg), 0, 1)], factor=figsize, nrow=1)



    # Backtracking: on iterations that are not multiples of show_every, if the
    # PSNR against the noisy image dropped by more than 5 dB since the last
    # checkpoint, restore the checkpointed weights and return a zero loss;
    # otherwise take a new checkpoint.
    # NOTE(review): backward() already ran above, so optimize()'s following
    # optimizer.step() still applies the diverged gradients to the restored
    # weights; `i` is also not incremented on fallback. `.cuda()` needs a GPU.
    if i % show_every:
        if psrn_noisy - psrn_noisy_last < -5:
            print('Falling back to previous checkpoint.')

            for new_param, net_param in zip(last_net, net.parameters()):
                net_param.data.copy_(new_param.cuda())

            return total_loss*0
        else:
            last_net = [x.detach().cpu() for x in net.parameters()]
            psrn_noisy_last = psrn_noisy

    i += 1

    return total_loss

# Collect the tensors selected by OPT_OVER and run the optimization loop.
p = get_params(OPT_OVER, net, net_input)
optimize(OPTIMIZER, p, closure, LR, num_iter)

In [None]:
# Final reconstruction.
# NOTE(review): `net_input` is the perturbed input left by the last closure()
# call (net_input_saved + noise); consider net(net_input_saved) or `out_avg`.
out_np = torch_to_np(net(net_input))

# Both images are float arrays in [0, 1]. Pass data_range explicitly so the
# metrics do not fall back on skimage's dtype-based guess (for float input
# SSIM would otherwise assume a range of 2.0).
psnr = peak_signal_noise_ratio(img_np, out_np, data_range=1.0)
try:
    # scikit-image >= 0.19 names the channel argument `channel_axis`.
    ssim = structural_similarity(img_np.transpose(1,2,0), out_np.transpose(1,2,0),
                                 channel_axis=-1, data_range=1.0)
except TypeError:
    # Older releases only accept the former `multichannel` flag.
    ssim = structural_similarity(img_np.transpose(1,2,0), out_np.transpose(1,2,0),
                                 multichannel=True, data_range=1.0)

fig = plt.figure(figsize=(35, 20))
fig.add_subplot(1, 3, 1)
plt.imshow(img_np.transpose(1,2,0))
plt.title("Original",fontsize=30)

fig.add_subplot(1, 3, 2)
plt.imshow(img_noisy_np.transpose(1,2,0))
# `sigma` is the noise standard deviation (on the 0-255 scale), not its variance.
plt.title("Noised with Gaussian \u03C3="+str(sigma),fontsize=30)

fig.add_subplot(1, 3, 3)
plt.imshow(out_np.transpose(1,2,0))
plt.title(f"After Denoising by Deep Image Prior\n PSNR {psnr:.2f} dB",fontsize=30)

# The output directory is not guaranteed to exist yet.
os.makedirs(os.path.dirname(filename), exist_ok=True)
plt.savefig(filename)

And here is the result of the Deep Image Prior.

[<img src="./deepprior.png" width="400"/>](image.png)

Below are the results of all our methods under different noise scales.

[<img src="./result.png" width="600"/>](image.png)

The U-Net model achieves the highest PSNR. The deep image prior model is not as good as the U-Net model, but it gives a good result without using a lot of data (in fact, it uses only one image). Total variation minimization cannot beat the U-Net in PSNR, but it can still be a good choice because it is cheap to compute. As for low-rank representation, even though its PSNR is very low, its visual quality is not bad and it is promising in other respects.