# Wavelet inpainting 
### Author: Pawel Budzynski

In [None]:
import functools
from typing import Callable, Tuple

import matplotlib.pyplot as plt
import numpy as np
import skimage.io
import scipy.signal


import nt_toolbox.general as nt_general
from nt_toolbox.perform_wavelet_transf import perform_wavelet_transf
import nt_toolbox.signal as nt_signal

# Collection of filters taken from matlab implementation.
from wavelets.wavelets import get_wavelet_filter

In [None]:
# Plot helper: every image is drawn in grayscale without interpolation so the
# panels in this notebook are directly comparable.
imshow = functools.partial(plt.imshow, interpolation='none', cmap='gray')

# Load the test picture as a single-channel grayscale array
# (n=256 is forwarded to the loader; presumably the target size — confirm).
image = nt_signal.load_image('barbara.jpg', n=256)
imshow(image)
plt.show()

The goal of this exercise is to simulate the direct (forward) problem
\begin{equation}
y = Ax + b,
\end{equation}
where $x$ is the original image, $A$ is a binary masking operator (it zeroes out the missing pixels) and $b$ is additive white Gaussian noise.

In the following steps, the Fast Iterative Shrinkage-Thresholding Algorithm (FISTA) is implemented to reconstruct the image $x$ from the noisy observation $y$.

## Noisy image generation

The problem is generated by setting the missing pixels of the image to $0$ and adding white noise to the masked image. For simplicity, the mask $A$ is returned as well, so that the missing pixels can be identified easily.

In [None]:
def get_mask(
    size: Tuple[int, ...],
    p: float = 0.5,
) -> np.ndarray:
    """Create a random binary observation mask.

    Every entry is drawn independently: it is 0 ("missing") with
    probability ``p`` and 1 ("observed") with probability ``1 - p``.

    Parameters:
        size: shape of the mask, e.g. ``image.shape``.
        p: probability that a pixel is missing.

    Returns:
        Integer array of shape ``size`` containing only 0s and 1s.
    """
    # binomial(1, p) is a Bernoulli(p) draw that marks the *missing* pixels;
    # subtracting it from 1 turns it into a mask of *observed* pixels.
    return 1 - np.random.binomial(1, p, size=size)

In [None]:
p = 0.3
A = get_mask(image.shape, p)
# White noise: one independent N(0, 0.1^2) sample per pixel. A single scalar
# draw (no `size`) would only shift the brightness of the whole image.
b = np.random.normal(0, 0.1, size=image.shape)


plt.figure(figsize=(10, 5))
plt.subplot(1, 3, 1)
plt.title("Original image.")
imshow(image)
plt.subplot(1, 3, 2)
plt.title("With binary mask.")
imshow(A * image)
plt.subplot(1, 3, 3)
plt.title("With binary mask and noise.")
imshow(A * image + b)
plt.show()

### Noisy image for required SNR

In [None]:
def snr(
    x: np.ndarray, 
    y: np.ndarray,
) -> float:
    """Return the signal-to-noise ratio of ``y`` relative to ``x``, in decibels.

    Parameters:
        x: original (clean) signal.
        y: signal with noise added.
    """
    signal_energy = np.sum(x**2)
    noise_energy = np.sum((x - y)**2)
    ratio = signal_energy / noise_energy
    return 10 * np.log10(ratio)

def snr_to_n_std(
    snr: float, 
    signal: np.ndarray,
) -> float:
    """Compute the noise standard deviation that yields the requested SNR.

    Inverts ``SNR = 10 * log10(sum(x**2) / E[sum(noise**2)])`` for i.i.d.
    zero-mean noise, whose expected energy is ``N * std**2``.

    Parameters:
        snr: target signal-to-noise ratio in dB. (The parameter name shadows
            the module-level ``snr`` function inside this body only.)
        signal: clean reference signal.

    Returns:
        The noise standard deviation (not the variance).
    """
    n_samples = np.size(signal)
    # Per-sample noise power (variance) needed to reach the target SNR.
    noise_power = np.sum(signal**2) / (10**(snr / 10) * n_samples)
    return np.sqrt(noise_power)

In [None]:
def generate_problem(
    image: np.ndarray,
    p: float = 0.5,
    sigma: float = 0.1,
) -> Tuple[np.ndarray, np.ndarray]:
    """Simulate the inpainting problem ``y = A x + b`` for an image ``x``.

    Parameters:
        image: original image x.
        p: probability that a pixel is missing (forwarded to ``get_mask``).
        sigma: standard deviation of the additive Gaussian noise b.

    Returns:
        ``(y, A)``: the masked, noisy observation and the binary mask used
        (1 = observed pixel, 0 = missing pixel).
    """
    A = get_mask(image.shape, p=p)
    # Noise is added to every pixel, including the masked-out ones.
    b = np.random.normal(0, sigma, size=image.shape)
    return A * image + b, A

In [None]:
p = 0.2

# One panel per noise level; each panel is titled with the SNR it achieves.
plt.figure(figsize=(10, 5))
for position, sigma in enumerate((0.1, 0.01, 0), start=1):
    plt.subplot(1, 3, position)
    noisy_image, _ = generate_problem(image, p=p, sigma=sigma)
    imshow(noisy_image)
    plt.title(f"SNR: {round(snr(image, noisy_image), 2)}")
plt.show()

## Thresholding functions
Thresholding functions used later in the FISTA implementation.

In [None]:
def hard_threshold(
    arr: np.ndarray,
    t: float,
) -> np.ndarray:
    """Hard-threshold ``arr``: keep values with magnitude above ``t``, zero the rest."""
    keep = np.abs(arr) > t
    return arr * keep


def soft_threshold(
    arr: np.ndarray,
    t: float,
) -> np.ndarray:
    """Perform soft thresholding on an array of values.

    Each value is scaled by ``max(0, 1 - t / |x|)``, i.e. shrunk towards zero
    by ``t``; values with magnitude at most ``t`` become 0.

    Zero entries are handled explicitly: the factor is defined as 0 there,
    so ``t / 0`` no longer raises divide warnings, and ``x = t = 0`` no
    longer produces NaN.
    """
    magnitude = np.abs(arr)
    with np.errstate(divide="ignore", invalid="ignore"):
        gain = np.where(magnitude > 0, 1 - t / magnitude, 0)
    return arr * np.maximum(0, gain)


def empirical_weiner(
    arr: np.ndarray,
    t: float,
) -> np.ndarray:
    """Perform thresholding using the empirical Wiener rule.

    Each value is scaled by ``max(0, 1 - t**2 / |x|**2)``. Zero entries get
    a factor of 0, which avoids divide-by-zero warnings and the NaN that
    ``x = t = 0`` produced before. (The function name keeps its historical
    spelling so existing callers keep working.)
    """
    magnitude_sq = np.abs(arr) ** 2
    with np.errstate(divide="ignore", invalid="ignore"):
        gain = np.where(magnitude_sq > 0, 1 - t**2 / magnitude_sq, 0)
    return arr * np.maximum(0, gain)

## FISTA

The iterative algorithm is formulated as follows:

Initialize: $\alpha^{(0)}\in\mathbb{R}^N, z^{(0)}\in\mathbb{R}^N, L \geq \|A\Phi\|^2, t=0$

Repeat until convergence:
- $\alpha^{(t+1)} = \mathcal{S}_{\lambda/L} \left( z^{(t)} + \frac{1}{L}\Phi^* A^* \left( y - A\Phi z^{(t)} \right) \right)$,
- $z^{(t+1)} = \alpha^{(t+1)} + \frac{t}{t+5} \left( \alpha^{(t+1)} - \alpha^{(t)} \right)$,

where $\mathcal{S}_{\lambda}$ is a thresholding function, $\Phi$ is a dictionary (frame), $A$ is a binary masking matrix and $\alpha$ are the coefficients. With soft thresholding, the iteration minimizes $\frac{1}{2}\|y - A\Phi\alpha\|_2^2 + \lambda\|\alpha\|_1$. The condition $L \geq \|A\Phi\|^2$ ensures that the step size $1/L$ is small enough for convergence.

Furthermore:
- $\Phi^{*} x$ is the wavelet transform that computes the coefficients of a signal $x$,
- $\Phi \alpha$ is the inverse wavelet transform that reconstructs an image from the coefficients $\alpha$.

In [None]:
# Wavelet filter and number of decomposition levels used by Phi / PhiS below.
filter_ = get_wavelet_filter('Daubechies', 4)
n_levels = 2


def Phi(a):
    """Inverse orthogonal wavelet transform (direction -1): coefficients -> image."""
    return nt_signal.perform_wavortho_transf(a, n_levels, -1, filter_)


def PhiS(f):
    """Forward orthogonal wavelet transform (direction +1): image -> coefficients."""
    return nt_signal.perform_wavortho_transf(f, n_levels, +1, filter_)

In [None]:
# Generate the problem.
y, A = generate_problem(image, p=0.5, sigma=0.01)

# Start FISTA from the wavelet coefficients of the observation.
a = PhiS(y)
z = a.copy()
# Step size is 1/L; L must upper-bound ||A Phi||^2, which is <= 1 for a
# binary mask and an orthonormal Phi (assumed for perform_wavortho_transf).
L = 1
lambd = 0.03

loss = []
for t in range(50):
    a_prev = a.copy()
    # Gradient step on the data term, then soft thresholding (prox of the l1 term).
    a = soft_threshold(z + (1/L) * PhiS(A*(y - A * Phi(z))), lambd/L)
    # Momentum step with the t/(t+5) schedule.
    z = a + (t / (t + 5)) * (a - a_prev)

    # Objective F(a) = 1/2 ||y - A Phi a||^2 + lambda ||a||_1. The data fit is
    # measured in the image domain, not by comparing y with the coefficients.
    loss.append(0.5 * np.sum((y - A * Phi(a))**2) + lambd * np.sum(np.abs(a)))

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
imshow(Phi(a))
plt.title("Reconstructed image.")
plt.subplot(1, 2, 2)
plt.plot(loss)
plt.title("Loss.")
plt.show()

## Orthogonal wavelet transform
Perform image inpainting using an orthogonal wavelet transform.

In [None]:
def image_inpainting(
    noisy_image: np.ndarray,
    A: np.ndarray,
    L: float, 
    lambd: float,
    filter_: np.ndarray,
    threshold_fn: Callable,
    n_levels: int = 2,
    max_iter: int = 100,
    eps: float = 1e-3, 
) -> Tuple[np.ndarray, list]:
    """Inpaint an image with FISTA in an orthogonal wavelet basis.

    Iterates towards a minimiser of
    ``F(a) = 1/2 ||y - A Phi a||^2 + lambd * ||a||_1`` over the wavelet
    coefficients ``a``, where ``Phi`` is the inverse orthogonal wavelet
    transform.

    Parameters:
        noisy_image: observed image y (missing pixels zeroed, plus noise).
        A: binary mask, 1 where a pixel is observed and 0 where it is missing.
        L: constant bounding ||A Phi||^2; the gradient step size is 1/L.
        lambd: regularisation weight; the threshold applied is lambd / L.
        filter_: wavelet filter forwarded to the transform.
        threshold_fn: thresholding rule called as ``threshold_fn(coeffs, t)``.
        n_levels: number of wavelet decomposition levels.
        max_iter: maximum number of FISTA iterations.
        eps: stop once the absolute change of the objective falls below eps.

    Returns:
        The reconstructed image ``Phi(a)`` and the list of objective values,
        starting with the value at initialisation.
    """
    def Phi(a):
        # Inverse wavelet transform: coefficients -> image.
        return nt_signal.perform_wavortho_transf(a, n_levels, -1, filter_)

    def PhiS(f):
        # Wavelet transform: image -> coefficients.
        return nt_signal.perform_wavortho_transf(f, n_levels, +1, filter_)

    y = noisy_image

    def objective(a):
        # F(a) = 1/2 ||y - A Phi a||^2 + lambd ||a||_1. The data fit is
        # measured in the image domain, on the re-synthesised image.
        return 0.5 * np.sum((y - A * Phi(a))**2) + lambd * np.sum(np.abs(a))

    # Initialise the algorithm from the coefficients of the observation.
    a = PhiS(y)
    z = a.copy()
    loss = [objective(a)]

    for t in range(max_iter):
        # Keep a(t) for the momentum step.
        a_prev = a.copy()

        # Gradient step on the data term, then thresholding.
        a = threshold_fn(z + (1/L) * PhiS(A*(y - A * Phi(z))), lambd/L)
        z = a + (t / (t + 5)) * (a - a_prev)

        loss.append(objective(a))

        # Stop early once the objective has stabilised.
        if abs(loss[-2] - loss[-1]) < eps:
            break

    return Phi(a), loss

Test the defined function on a generated problem.

In [None]:
noisy_image, A = generate_problem(image, p=0.5, sigma=0.01)

inpainted_image, loss = image_inpainting(
    noisy_image,
    A,
    L=1,
    lambd=0.01,
    filter_=get_wavelet_filter('Daubechies', 4),
    threshold_fn=soft_threshold,
)

# Left: the observation; right: the orthogonal-wavelet reconstruction.
plt.figure(figsize=(10, 6))
for position, (caption, picture) in enumerate(
    (("Noised image.", noisy_image), ("Inpainted image.", inpainted_image)),
    start=1,
):
    plt.subplot(1, 2, position)
    plt.title(caption)
    imshow(picture)
plt.show()

# Objective value after each FISTA iteration.
plt.title("Loss.")
plt.plot(range(len(loss)), loss)
plt.show()

## Translation invariant wavelets

In [None]:
def image_inpainting_transl_invar(
    noisy_image: np.ndarray,
    A: np.ndarray,
    lambd: float,
    threshold_fn: Callable,
    max_iter: int = 1000,
    loss_threshold: float = 1e-3,
    eps: float = 1e-3,
) -> Tuple[np.ndarray, list]:
    """Inpaint an image with FISTA on translation-invariant wavelet coefficients.

    Iterates towards a minimiser of
    ``F(a) = 1/2 ||y - A Phi a||^2 + lambd * ||a||_1``, where ``Phi``
    re-synthesises an image from the weighted, redundant coefficients ``a``.

    Parameters:
        noisy_image: observed square image y (missing pixels zeroed, plus noise).
        A: binary mask, 1 where a pixel is observed and 0 where it is missing.
        lambd: regularisation weight; the threshold applied is lambd / L.
        threshold_fn: thresholding rule called as ``threshold_fn(coeffs, t)``.
        max_iter: maximum number of FISTA iterations.
        loss_threshold: unused; kept only so existing calls keep working.
            The stopping rule is controlled by ``eps``.
        eps: stop once the absolute change of the objective falls below eps.

    Returns:
        The reconstructed image passed through ``nt_general.clamp`` and the
        list of objective values, starting with the value at initialisation.
    """
    y = noisy_image
    # Size of the (square, n x n) image, taken from the input itself rather
    # than from the notebook-global `image`.
    n = y.shape[0]

    # Set algorithm parameters following the Numerical Tours tutorial.
    Jmax = np.log2(n) - 1
    Jmin = Jmax - 3
    J = Jmax - Jmin + 1

    # Weights broadcast to shape (len(u), n, n); presumably one weight per
    # band of the translation-invariant transform (TODO confirm layout).
    u = np.hstack(([4**(-J)], 4**(-np.floor(np.arange(J + 2./3, 1, -1./3)))))
    U = np.transpose(np.tile(u, (n, n, 1)), (2, 0, 1))

    def Xi(a):
        # coeff -> image, without the weighting.
        return perform_wavelet_transf(a, Jmin, -1, ti=1)

    def PhiS(f):
        # image -> coeff
        return perform_wavelet_transf(f, Jmin, +1, ti=1)

    def Phi(a):
        # coeff -> image, undoing the weighting first.
        return Xi(a / U)

    def objective(a):
        # F(a) = 1/2 ||y - A Phi a||^2 + lambd ||a||_1. The data fit is
        # measured in the image domain, on the re-synthesised image.
        return 0.5 * np.sum((y - A * Phi(a))**2) + lambd * np.sum(np.abs(a))

    # The gradient step size 1/L equals 1.9 * min(u).
    L = 1 / (1.9 * np.min(u))

    a = U * PhiS(y)
    z = a.copy()

    loss = [objective(a)]
    for t in range(max_iter):
        a_prev = a.copy()

        a = threshold_fn(z + (1 / L) * PhiS(A * (y - A * Phi(z))), lambd / L)
        z = a + (t / (t + 5)) * (a - a_prev)

        loss.append(objective(a))

        # Stop early once the objective has stabilised.
        if abs(loss[-2] - loss[-1]) < eps:
            break

    return nt_general.clamp(Phi(a)), loss

In [None]:
noisy_image, A = generate_problem(image, p=0.5, sigma=0.01)

inpainted_image, loss = image_inpainting_transl_invar(
    noisy_image, A,
    lambd=0.03, threshold_fn=soft_threshold, max_iter=1000, eps=1e-3,
)

# Left: the observation; right: the translation-invariant reconstruction.
plt.figure(figsize=(10, 6))
panels = [("Noised image.", noisy_image), ("Inpainted image.", inpainted_image)]
for slot, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(1, 2, slot)
    plt.title(caption)
    imshow(picture)
plt.show()

# Objective value after each FISTA iteration.
plt.title("Loss.")
plt.plot(np.arange(len(loss)), loss)
plt.show()

## Discussion

The following section discusses the results obtained by varying the algorithm parameters:

- the sparse representation (various orthogonal wavelet transforms and translation invariant wavelets)
- the thresholding rules (soft, hard, empirical Wiener)
- the choice of the $\lambda$ parameter
- the value of $p$ and the level of noise


### Orthogonal transform and translation invariant wavelets

In [None]:
# Stopping criteria shared by every reconstruction in this experiment.
eps = 1e-3
max_iter = 1000
# Orthogonal filters to compare (available families: Daubechies, Coiflet, Symmlet).
orthogonal_filters = (("Daubechies", 4), ("Coiflet", 2))

for p in (0.3, 0.5, 0.7):
    print(f"Generating problems for p={p}")
    noisy_image, A = generate_problem(image, p=p, sigma=0.01)

    fig = plt.figure(figsize=(10, 5))

    # First two panels: orthogonal wavelets with each filter.
    for i, (filter_name, param) in enumerate(orthogonal_filters, start=1):
        filter_ = get_wavelet_filter(filter_name, param)
        plt.subplot(1, 3, i)
        image_rec, _ = image_inpainting(
            noisy_image, A,
            L=1, lambd=0.02, filter_=filter_, threshold_fn=soft_threshold,
            max_iter=max_iter, eps=eps,
        )
        imshow(image_rec)
        plt.title(filter_name)

    # Last panel: translation invariant wavelets.
    plt.subplot(1, 3, 3)
    image_rec, _ = image_inpainting_transl_invar(
        noisy_image, A,
        lambd=0.02, threshold_fn=soft_threshold,
        max_iter=max_iter, eps=eps,
    )
    imshow(image_rec)
    plt.title("Translation invariant") 
    plt.show()

### Thresholding rules

In [None]:
noisy_image, A = generate_problem(image, p=0.5, sigma=0.01)
filter_ = get_wavelet_filter("Daubechies", 4)
max_iter = 1000
eps = 1e-2
lambd = 1e-2

# One row per thresholding rule: orthogonal wavelets left, translation invariant right.
fig, axs = plt.subplots(3, 2, figsize=(10, 10))
for i, threshold_fn in enumerate((soft_threshold, hard_threshold, empirical_weiner)):
    ortho_ax, ti_ax = axs[i]
    rule_name = threshold_fn.__name__

    image_rec, _ = image_inpainting(
        noisy_image, A,
        L=1, lambd=lambd, filter_=filter_, threshold_fn=threshold_fn,
        max_iter=max_iter, eps=eps,
    )
    ortho_ax.set_title(f"Orthogonal + {rule_name}")
    ortho_ax.imshow(image_rec, cmap='gray', interpolation='none')

    image_rec, _ = image_inpainting_transl_invar(
        noisy_image, A,
        lambd=lambd, threshold_fn=threshold_fn,
        max_iter=max_iter, eps=eps,
    )
    ti_ax.set_title(f"Trans. inv. + {rule_name}")
    ti_ax.imshow(image_rec, cmap='gray', interpolation='none')

plt.show()

### Choice of $\lambda$ parameter

In [None]:
noisy_image, A = generate_problem(image, p=0.5, sigma=0.01)
filter_ = get_wavelet_filter("Daubechies", 4)
max_iter = 1000
eps = 1e-2

# One row per lambda: orthogonal wavelets (left) vs translation invariant wavelets (right).
fig, axs = plt.subplots(3, 2, figsize=(10, 10))
for i, lambd in enumerate((1e-1, 1e-2, 1e-3)):
    image_rec, _ = image_inpainting(
        noisy_image,
        A,
        L=1, 
        lambd=lambd,
        filter_=filter_,
        threshold_fn=soft_threshold,
        max_iter=max_iter,
        eps=eps,
    )
    # Raw f-strings: "\l" is an invalid escape sequence in a normal literal
    # (DeprecationWarning; SyntaxWarning on Python 3.12+). Rendered text is unchanged.
    axs[i][0].set_title(rf"Orthogonal $\lambda$ = {lambd}")
    axs[i][0].imshow(image_rec, cmap='gray', interpolation='none')
    
    image_rec, _ = image_inpainting_transl_invar(
        noisy_image,
        A,
        lambd=lambd,
        threshold_fn=soft_threshold,
        max_iter=max_iter,
        eps=eps,
    )
    
    axs[i][1].set_title(rf"Trans. inv. $\lambda$ = {lambd}")
    axs[i][1].imshow(image_rec, cmap='gray', interpolation='none')
    
plt.show()

### The value of $p$ and level of noise

In [None]:
# Test results for various values of p.
filter_ = get_wavelet_filter("Daubechies", 4)
max_iter = 1000
lambd = 1e-2
eps = 1e-2

fig, axs = plt.subplots(3, 2, figsize=(10, 10))
for i, p in enumerate((0.3, 0.5, 0.7)):
    # A fresh problem for each probability p of a pixel being missing.
    noisy_image, A = generate_problem(image, p=p, sigma=0.01)
    ortho_ax, ti_ax = axs[i]

    image_rec, _ = image_inpainting(
        noisy_image, A,
        L=1, lambd=lambd, filter_=filter_, threshold_fn=soft_threshold,
        max_iter=max_iter, eps=eps,
    )
    ortho_ax.set_title(f"Orthogonal $p$ = {p}")
    ortho_ax.imshow(image_rec, cmap='gray', interpolation='none')

    image_rec, _ = image_inpainting_transl_invar(
        noisy_image, A,
        lambd=lambd, threshold_fn=soft_threshold,
        max_iter=max_iter, eps=eps,
    )
    ti_ax.set_title(f"Trans. inv. $p$ = {p}")
    ti_ax.imshow(image_rec, cmap='gray', interpolation='none')

plt.show()

In [None]:
# Test results for various values of sigma.
filter_ = get_wavelet_filter("Daubechies", 4)
max_iter = 1000
lambd = 1e-2
eps = 1e-2

fig, axs = plt.subplots(3, 2, figsize=(10, 10))
for i, sigma in enumerate((0.01, 0.1, 0.5)):
    # A fresh problem for each noise standard deviation.
    noisy_image, A = generate_problem(image, p=0.5, sigma=sigma)
    
    image_rec, _ = image_inpainting(
        noisy_image,
        A,
        L=1, 
        lambd=lambd,
        filter_=filter_,
        threshold_fn=soft_threshold,
        max_iter=max_iter,
        eps=eps,
    )
    # Raw f-strings: "\s" is an invalid escape sequence in a normal literal
    # (DeprecationWarning; SyntaxWarning on Python 3.12+). Rendered text is unchanged.
    axs[i][0].set_title(rf"Orthogonal $\sigma$ = {sigma}")
    axs[i][0].imshow(image_rec, cmap='gray', interpolation='none')
    
    image_rec, _ = image_inpainting_transl_invar(
        noisy_image,
        A,
        lambd=lambd,
        threshold_fn=soft_threshold,
        max_iter=max_iter,
        eps=eps,
    )
    
    axs[i][1].set_title(rf"Trans. inv. $\sigma$ = {sigma}")
    axs[i][1].imshow(image_rec, cmap='gray', interpolation='none')
    
plt.show()

## Conclusions
- Wavelet transforms make it possible to perform image inpainting without prior knowledge of the noise.
- With well-tuned parameters, the method produces very good results when the noise is weak. The results remain good even when the fraction of lost samples is significant (>50%).
- Translation invariant wavelets seem to produce better results; however, for some sets of parameters they fail to reconstruct the image.
- Orthogonal wavelets with the Coiflet filter seem to produce higher-contrast images than with the Daubechies filter.
- An attempt to use hard thresholding for inpainting failed: the image was not reconstructed. For other values of $\lambda$ (0.9, 0.5) the results were slightly better, but still not comparable to the other thresholding functions.
- The choice of the $\lambda$ parameter is crucial for the quality of the result. In the experiments above, $\lambda = 10^{-2}$ worked best; however, it is not universal and fails in some settings.
- For low values of $p$ the translation invariant wavelets failed to produce results, although they performed very well for $p \geq 0.5$.
- The level of noise has a great impact on the algorithm: when the SNR is low, the algorithm may fail to produce results.