# 0) Setup

In [None]:
import util
import math
import xarray as xr
import numpy as np

from blurs import *
from scipy.optimize import minimize

from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go

# 1) Load the original image

In [None]:
# Load the ground-truth image that the rest of the notebook blurs and restores.
image = util.load_img('../samples/pumpkins.tif')

# Display it in grayscale with the colorbar and tick labels hidden.
fig = px.imshow(
    image,
    color_continuous_scale='gray',
    title='Original Image',
    width=512,
    height=512,
)
(
    fig.update_layout(coloraxis_showscale=False)
    .update_xaxes(showticklabels=False)
    .update_yaxes(showticklabels=False)
)
fig.show()

# 2) Create the point spread function

In [None]:
from scipy.fft import ifft2, fftshift

# Build the PSF in Fourier space; the parameters are forwarded positionally
# to psf_fn after the image shape (their meaning is defined in blurs.gaussian).
psf_fn = gaussian
psf_params = [1, 1]
psf = psf_fn(image.shape, *psf_params)

# Transform back to real space and centre the kernel for display.
psf_img = fftshift(ifft2(psf)).real

psf_fig = px.imshow(
    psf_img,
    color_continuous_scale='gray',
    title='Point Spread Function',
    width=512,
    height=512,
)
psf_fig.show()

# 3) Create the blurred image

In [None]:
# Forward model: blur the true image with the PSF, then pass it through the
# optional noise step (noise_factor is 0 here; presumably that means no noise —
# depends on blurs.noise).
noise_factor = 0
image_blurred = noise(blur(image, psf), noise_factor)

fig = px.imshow(
    image_blurred,
    color_continuous_scale='gray',
    title='Blurred Image',
    width=512,
    height=512,
)
(
    fig.update_layout(coloraxis_showscale=False)
    .update_xaxes(showticklabels=False)
    .update_yaxes(showticklabels=False)
)
fig.show()

# 4) Deblur the image
This assumes that the point spread function (PSF) used to generate the blurred image is known. If noise was added when constructing the blurred image, regularization is necessary to compute an approximation of the true image that is not dominated by noise.

The code below deblurs using the blurred image and the point spread function. The blurred image must be given in real (spatial) space, whereas the point spread function must be given in Fourier space.

## Tikhonov Regularization
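This deblurring approach damps the naive inverse filter $1/S$ with the filter factors $\phi$, giving the estimate $X = \mathcal{F}^{-1}\left(\frac{\phi}{S}\,\mathcal{F}(B)\right)$, where $B$ is the blurred image, $S$ is the point spread function in Fourier space, and $\alpha \ge 0$ is the regularization parameter: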

$$\phi = \frac{|S|^2}{|S|^2 + \alpha^2}$$

In the code implementation, a few additional steps guard against numerical problems. Where both $S$ and $\alpha$ are zero, the denominator of $\phi$ vanishes and the division would fail. Such zero denominators are replaced by machine epsilon (`np.finfo(float).eps`, the gap between 1.0 and the next representable float — not the smallest representable positive number). This adjustment happens before $\phi$ is computed and affects only the denominator. Likewise, frequencies where $S = 0$ are left at zero in the filtered inverse $\phi / S$ instead of being divided.

In [None]:
def deblur_tik(alpha, S, blurred_img):
    """Deblur an image by Tikhonov-filtered division in Fourier space.

    Computes ``ifft2(fft2(blurred_img) * phi / S)`` with filter factors
    ``phi = |S|^2 / (|S|^2 + alpha^2)``.

    Parameters
    ----------
    alpha : float
        Regularization parameter; ``0`` gives the unregularized inverse filter.
    S : array_like
        Point spread function in Fourier space (``fft2(circshift(psf))``).
        May be complex.
    blurred_img : array_like
        Blurred image in real space, same shape as ``S``.

    Returns
    -------
    numpy.ndarray
        Real part of the deblurred image.
    """
    from scipy.fft import fft2, ifft2

    S = np.asarray(S)
    S_abs2 = np.square(np.abs(S))  # |S|^2

    # phi = |S|^2 / (|S|^2 + alpha^2). The denominator is zero only where
    # S == 0 and alpha == 0; replace those entries to avoid division by zero.
    denominator = S_abs2 + np.square(alpha)
    denominator = np.where(denominator == 0, np.finfo(float).eps, denominator)
    phi = S_abs2 / denominator

    # Filtered inverse phi / S, left at zero where S == 0. The buffer must
    # have S's (generally complex) dtype: a float buffer would silently
    # discard the imaginary part of phi / S and corrupt the result.
    S_filt = np.zeros(S.shape, dtype=np.result_type(S.dtype, np.float64))
    nonzero = S != 0
    S_filt[nonzero] = phi[nonzero] / S[nonzero]

    # Apply the filter to the blurred image's spectrum and return to real space.
    return ifft2(fft2(blurred_img) * S_filt).real

def deblur_tik_min(S, blurred_img=None, true_img=None, initial_alpha=0.1, alpha_bounds=(0, 32)):
    """Find the Tikhonov parameter alpha that best reproduces the true image.

    Minimizes ``||deblur_tik(alpha, S, blurred_img) - true_img||_F`` over
    ``alpha`` with ``scipy.optimize.minimize``. Because the objective needs
    the ground-truth image, this is only usable in synthetic experiments
    where that image is known.

    Parameters
    ----------
    S : array_like
        Point spread function in Fourier space.
    blurred_img : array_like, optional
        Blurred image in real space; defaults to the notebook's ``image_blurred``.
    true_img : array_like, optional
        Ground-truth image; defaults to the notebook's ``image``.
    initial_alpha : float, optional
        Starting value for the optimizer.
    alpha_bounds : tuple of float, optional
        ``(lower, upper)`` bounds on alpha.

    Returns
    -------
    scipy.optimize.OptimizeResult
        The optimizer result; the optimal alpha is ``result.x[0]``.
    """
    if blurred_img is None:
        blurred_img = image_blurred
    if true_img is None:
        true_img = image

    def error_function(params, S, blurred_img):
        # minimize works on vectors; params[0] holds alpha.
        return np.linalg.norm(deblur_tik(params[0], S, blurred_img) - true_img, ord='fro')

    return minimize(error_function, [initial_alpha], args=(S, blurred_img), bounds=[alpha_bounds])

def deblur_psf(psf_params, deblurred_img):
    """Re-blur an image estimate with a ``linear`` PSF.

    Parameters
    ----------
    psf_params : sequence
        Parameters forwarded positionally to ``linear`` after the shape
        (used elsewhere in this notebook as ``(length, angle)``).
    deblurred_img : array_like
        Current image estimate in real space.

    Returns
    -------
    The result of ``blur(deblurred_img, S_i)``.
    """
    # Size the PSF from the array it is applied to, not from a global.
    S_i = linear(np.shape(deblurred_img), *psf_params)
    return blur(deblurred_img, S_i)

def deblur_psf_min(deblurred_img, *initial_guess, blurred_img=None):
    """Fit the ``linear`` PSF parameters to the observed blurred image.

    Solves ``min_{length, angle} ||A(length, angle) x - b||_F`` with the
    image estimate ``x = deblurred_img`` held fixed.

    Parameters
    ----------
    deblurred_img : array_like
        Current image estimate in real space.
    *initial_guess : float
        Starting ``length, angle`` for the optimizer.
    blurred_img : array_like, optional
        Observed blurred image ``b``; defaults to the notebook's ``image_blurred``.

    Returns
    -------
    scipy.optimize.OptimizeResult
        The optimizer result; the fitted parameters are ``result.x``.

    Raises
    ------
    ValueError
        If ``initial_guess`` does not contain exactly two values.
    """
    if len(initial_guess) != 2:
        raise ValueError(f'expected 2 initial values (length, angle), got {len(initial_guess)}')
    if blurred_img is None:
        blurred_img = image_blurred

    def error_function(psf_params, deblurred_img, blurred_img):
        return np.linalg.norm(deblur_psf(psf_params, deblurred_img) - blurred_img, ord='fro')

    # Specifying the 'Powell' solver was important: the default bounded
    # solver (L-BFGS-B) got stuck at wrong values.
    return minimize(
        error_function,
        initial_guess,
        args=(deblurred_img, blurred_img),
        method='Powell',
        bounds=[(0, 16), (0, 2 * math.pi)],
    )

In [None]:
# Sweep the Tikhonov parameter over [0, 1]; each restoration is rescaled to
# [0, 1] with util.normalize so all animation frames share a display range.
alphas = np.linspace(0, 1, 11)
image_sequence = [
    util.normalize(deblur_tik(a, psf, image_blurred))
    for a in alphas
]

fig = px.imshow(
    np.array(image_sequence),
    animation_frame=0,
    color_continuous_scale='gray',
    title='Deblurred Image',
    width=512,
    height=512,
    labels=dict(animation_frame="alpha step"),
)
(
    fig.update_layout(coloraxis_showscale=False)
    .update_xaxes(showticklabels=False)
    .update_yaxes(showticklabels=False)
)
fig.show()

# 5) Deblur using Alternating Minimization Technique

In [None]:
# Alternate between (a) a Tikhonov deblur with the PSF parameters fixed and
# (b) refitting the PSF length/angle with the image estimate fixed.
max_iterations = 10
rel_tol = 1e-9  # convergence tolerance on the PSF parameters (math.isclose default)

prev_li = math.nan
li      = 10                      # initial PSF length guess
prev_ai = math.nan
ai      = (math.pi / 4) + 0.0005  # initial PSF angle guess (rad)

images = []

for step in range(1, max_iterations + 1):
    # (a) fix length and angle, solve for the image
    psfi = linear(image.shape, li, ai)
    min_alpha = deblur_tik_min(psfi).x[0]
    deblurred_img = deblur_tik(min_alpha, psfi, image_blurred)

    # (b) fix the image, solve for length and angle
    next_A = deblur_psf_min(deblurred_img, li, ai)
    prev_li, prev_ai = li, ai
    li, ai = next_A.x[0], next_A.x[1]

    print('i:', step, 'li:', li, 'ai:', ai)

    # deblur_tik already returns a real-valued image.
    images.append(deblurred_img)

    # Stop early once neither PSF parameter changes any more.
    if math.isclose(prev_li, li, rel_tol=rel_tol) and math.isclose(prev_ai, ai, rel_tol=rel_tol):
        break

print('Final li:', li)
print('Final ai:', ai)

fig = px.imshow(np.array(images), animation_frame=0, color_continuous_scale='gray', title='Deblurred Image', width=512, height=512,
                labels=dict(animation_frame="iteration"))
fig.update_layout(coloraxis_showscale=False)
fig.update_xaxes(showticklabels=False)
fig.update_yaxes(showticklabels=False)
fig.show()

# Plot of Relative Errors when Deblurring

In [None]:
# Sweep the linear-PSF parameters and measure how well re-blurring a fixed
# image estimate reproduces the observed blurred image b:
#     error(length, angle) = ||blur(x, A(length, angle)) - b|| / ||b||
num_samples = 20

lengths = np.linspace(1, 20, num_samples)
angles  = np.linspace(0, 2 * np.pi, num_samples)

# Fixed image estimate: the first iterate of the alternating minimization.
# (No per-sample Tikhonov solve is needed — its result was never used — and
# skipping it keeps this sweep cheap and leaves the global deblurred_img intact.)
reference_estimate = images[0]
blurred_norm = np.linalg.norm(image_blurred)

# errors[i][j] corresponds to (lengths[i], angles[j]); go.Surface maps rows of z
# to y and columns to x.
errors = [
    [
        np.linalg.norm(blur(reference_estimate, linear(image.shape, length, angle)) - image_blurred) / blurred_norm
        for angle in angles
    ]
    for length in lengths
]

fig = go.Figure(data=[go.Surface(z=errors, y=lengths, x=angles)])
fig.update_layout(title='Plot of Normalized Difference for PSF Parameters', autosize=False,
                  width=700, height=700,
                  scene = dict(
                    xaxis_title='Angle (rad)',
                    yaxis_title='Length (pixels)',
                    zaxis_title='Relative Error'))
fig.show()