# Coding Task: Image Deblurring with Spectral Filters

- Implement TSVD regularization, $x_k = \sum_{i=1}^{k} \frac{u_i^T b}{\sigma_i} v_i$, which keeps only the $k$ largest singular values, and plot the reconstructions for different values of $k$.
- Implement Tikhonov regularization, $x_\delta = \sum_{i=1}^{r} \frac{\sigma_i}{\sigma_i^2 + \delta} (u_i^T b) v_i$, where $r$ is the rank of the blurring operator, and plot the reconstructions for different values of $\delta$.

In [5]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import convolve2d
from skimage import color, data
from scipy.fftpack import fft2, ifft2, fftshift

In [None]:
# ----------- This code is already written for you -----------
def gaussian_kernel(size, sigma):
    """
    Build a normalized, isotropic 2D Gaussian kernel (the blur's point-spread
    function).

    Args:
        size: Requested side length. It is converted with ``int`` and the
            output side is ``2 * (int(size) // 2) + 1``, so the kernel is
            always odd-sized (an even request grows by one).
        sigma: Standard deviation of the Gaussian, in pixels.

    Returns:
        np.ndarray: Square array whose entries sum to 1.
    """
    half = int(size) // 2
    offsets = np.arange(-half, half + 1)
    rows, cols = np.meshgrid(offsets, offsets, indexing="ij")
    weights = np.exp(-(rows**2 + cols**2) / (2.0 * sigma**2))
    # Normalize so blurring preserves the image's total intensity.
    return weights / weights.sum()

def get_padded_kernel_and_fft(image_shape, kernel):
    """
    Pads the kernel to match the image size and computes its 2D FFT.

    The kernel is embedded in an image-sized array so that its center pixel
    (ker_h // 2, ker_w // 2) lands exactly at (img_h // 2, img_w // 2).
    ``ifftshift`` then moves that index to (0, 0) for both even and odd image
    sizes. The result is the transfer function of the circular convolution
    ``convolve2d(x, kernel, 'same', boundary='wrap')`` for an odd-sized
    kernel, so ``ifft2(F_kernel * fft2(x))`` reproduces that blur without any
    spatial shift.

    The convolution operator is circulant, so the complex entries of F_kernel
    are its eigenvalues and their magnitudes |F_kernel| are its singular
    values (sigma_i).

    Args:
        image_shape: (height, width) of the image.
        kernel: 2D kernel array. Its center is taken to be
            (ker_h // 2, ker_w // 2), which is exact for odd-sized kernels.

    Returns:
        np.ndarray: Complex array of shape ``image_shape``.

    Raises:
        ValueError: If the kernel is larger than the image in any dimension.
    """
    img_h, img_w = image_shape
    ker_h, ker_w = kernel.shape
    if ker_h > img_h or ker_w > img_w:
        raise ValueError(
            f"kernel shape {kernel.shape} exceeds image shape {tuple(image_shape)}"
        )

    # Create a padded kernel with the same dimensions as the image
    padded_kernel = np.zeros(image_shape)

    # Put the kernel's center at (img_h // 2, img_w // 2): ifftshift maps
    # exactly this index to (0, 0). Any other offset would leave a residual
    # translation, and every deconvolution would come out shifted.
    h_start = img_h // 2 - ker_h // 2
    w_start = img_w // 2 - ker_w // 2
    padded_kernel[h_start:h_start + ker_h, w_start:w_start + ker_w] = kernel

    # Move the kernel center to the origin (zero-phase kernel), then FFT.
    return fft2(np.fft.ifftshift(padded_kernel))

In [None]:
# --- YOU need to implement the two deblurring (spectral filtering) methods: (1) Tikhonov (2) TSVD deblurring ---

def tikhonov_deblur(F_blurred_noisy, F_kernel, delta):
    """
    Performs Tikhonov deblurring in the frequency domain.

    The formula x_delta = sum(sigma_i / (sigma_i^2 + delta) * (u_i^T b) * v_i)
    is, for a circulant blur, equivalent to the Wiener-type filter
        F(x_delta) = [F(H)* / (|F(H)|^2 + delta)] * F(b)
    where F(H)* is the complex conjugate of F(H).

    Args:
        F_blurred_noisy (np.ndarray): FFT of the noisy, blurred image (F(b)).
        F_kernel (np.ndarray): FFT of the padded kernel (F(H)); its magnitudes
            are the singular values sigma_i.
        delta (float): The Tikhonov regularization parameter (>= 0).

    Returns:
        np.ndarray: The deblurred image (real part).

    Raises:
        ValueError: If ``delta`` is negative.
    """
    if delta < 0:
        raise ValueError(f"delta must be non-negative, got {delta}")

    F_kernel_conj = np.conj(F_kernel)
    F_kernel_abs_sq = np.abs(F_kernel)**2  # This is |F(H)|^2 = sigma_i^2

    # Tikhonov filter: conj(H) / (|H|^2 + delta).
    denom = F_kernel_abs_sq + delta
    # With delta == 0, frequencies the kernel annihilates (H == 0) would give
    # 0/0; define the filter as 0 there (nothing to recover) instead of nan.
    F_tikhonov = np.divide(
        F_kernel_conj,
        denom,
        out=np.zeros(np.shape(F_kernel_conj), dtype=complex),
        where=denom > 0,
    )

    # Apply the filter to the spectrum of the observed image.
    F_deblurred = F_tikhonov * F_blurred_noisy

    # Inverse FFT to get the image back to spatial domain
    return np.real(ifft2(F_deblurred))

def tsvd_deblur(F_blurred_noisy, F_kernel, k):
    """
    Performs Truncated SVD (TSVD) deblurring in the frequency domain.

    The formula x_k = sum_{i=1}^k (1 / sigma_i * (u_i^T b) * v_i) is, for a
    circulant blur, equivalent to
        F(x_k) = [1 / F(H) if |F(H)| >= threshold else 0] * F(b)
    where ``threshold`` is the k-th largest magnitude |F(H)|.

    Because magnitudes can tie (e.g. the conjugate-symmetric pairs in the
    spectrum of a real kernel), more than ``k`` frequencies may be kept.

    Args:
        F_blurred_noisy (np.ndarray): FFT of the noisy, blurred image (F(b)).
        F_kernel (np.ndarray): FFT of the padded kernel (F(H)); its magnitudes
            are the singular values sigma_i.
        k (int): The number of largest singular values (frequencies) to keep.
            ``k <= 0`` keeps none (zero image); ``k >= F_kernel.size`` keeps
            all of them (stabilized full inverse filter).

    Returns:
        np.ndarray: The deblurred image (real part).
    """
    F_kernel_abs = np.abs(F_kernel)
    n_freq = F_kernel_abs.size
    k = int(k)

    # 1. Decide which frequencies survive the truncation.
    if k <= 0:
        keep = np.zeros(F_kernel_abs.shape, dtype=bool)
    elif k >= n_freq:
        keep = np.ones(F_kernel_abs.shape, dtype=bool)
    else:
        # Index n_freq - k of the ascending order is the k-th largest value;
        # np.partition finds it in O(n) without a full sort.
        kth = n_freq - k
        threshold = np.partition(F_kernel_abs.ravel(), kth)[kth]
        keep = F_kernel_abs >= threshold

    # 2. Naive inverse filter 1 / F(H), written as conj(H) / (|H|^2 + eps)
    #    so that an exactly-zero H yields 0 instead of a division by zero.
    epsilon = 1e-16
    F_inv_kernel = np.conj(F_kernel) / (F_kernel_abs**2 + epsilon)

    # 3. Truncate: drop the small singular values that amplify noise.
    F_tsvd = np.where(keep, F_inv_kernel, 0)

    # 4. Apply the truncated filter.
    F_deblurred = F_tsvd * F_blurred_noisy

    # Inverse FFT to get the image back
    return np.real(ifft2(F_deblurred))


In [None]:
# ----------- This code is already written for you -----------
# 1. Load an Image
original_image = color.rgb2gray(data.astronaut())
original_image = original_image[50:250, 150:350]  # Crop

# 2. Create a Blur Kernel
kernel_size = 11
kernel_sigma = 4.0
kernel = gaussian_kernel(kernel_size, kernel_sigma)

# 3. Simulate Blurring and Noise
# boundary='wrap' makes the blur a circular convolution, which is exactly the
# operator the FFT-based filters below diagonalize.
blurred_image = convolve2d(original_image, kernel, 'same', boundary='wrap')
noise_level = 0.02
np.random.seed(42)
noisy_blurred_image = np.clip(
    blurred_image + noise_level * np.random.randn(*blurred_image.shape), 0, 1
)

# Precompute FFTs
F_noisy_blurred = fft2(noisy_blurred_image)
F_kernel = get_padded_kernel_and_fft(original_image.shape, kernel)

# --- Tikhonov ---
print("--- Calculating Tikhonov Deblurring ---")
delta_values = np.logspace(-10, 10, 38)
tikhonov_images = [
    tikhonov_deblur(F_noisy_blurred, F_kernel, dlt) for dlt in delta_values
]

# --- TSVD ---
print("\n--- Calculating TSVD Deblurring ---")
total_pixels = original_image.size
# Round (rather than truncate via dtype=int) so floating-point error cannot
# turn e.g. 40000.0 - 1e-11 into 39999; the last k then keeps every frequency.
k_values = np.rint(
    np.logspace(np.log10(100), np.log10(total_pixels), 38)
).astype(int)
k_values = np.unique(k_values)
tsvd_images = [tsvd_deblur(F_noisy_blurred, F_kernel, k) for k in k_values]

# --- Plotting ---

def _plot_reconstruction_grid(original, blurred, images, titles, suptitle,
                              nrows=5, ncols=8):
    """
    Show a grid: the original, the blurred input, then one panel per image.

    Args:
        original: Ground-truth image shown in the first panel.
        blurred: Noisy blurred image shown in the second panel.
        images: Reconstructions, plotted in order from the third panel on.
        titles: One title per reconstruction (paired with ``images``).
        suptitle: Figure-level title.
        nrows, ncols: Grid layout. Reconstructions that do not fit in the
            remaining ``nrows * ncols - 2`` panels are skipped, and panels
            left empty are removed from the figure.
    """
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(2.5 * ncols, 2.5 * nrows))
    ax = axes.ravel()
    for panel, img, title in ((ax[0], original, "Original"),
                              (ax[1], blurred, "Blurred")):
        panel.imshow(img, cmap='gray')
        panel.set_title(title)
        panel.axis('off')

    num_to_plot = min(len(images), len(titles), len(ax) - 2)
    for panel, img, title in zip(ax[2:2 + num_to_plot], images, titles):
        panel.imshow(img, cmap='gray')
        panel.set_title(title, fontsize=8)
        panel.axis('off')

    # Hide any remaining unused axes
    for panel in ax[2 + num_to_plot:]:
        fig.delaxes(panel)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
    plt.suptitle(suptitle, fontsize=20)
    plt.show()


# 1. Display Tikhonov Image Grid (5x8 Layout)
print("\nPlotting Tikhonov image grid...")
_plot_reconstruction_grid(
    original_image,
    noisy_blurred_image,
    tikhonov_images,
    [rf"$\delta$={dlt:.1e}" for dlt in delta_values],
    r"Tikhonov Deblurring with Varying $\delta$",
)

# 2. Display TSVD Image Grid (5x8 Layout)
print("\nPlotting TSVD image grid...")
_plot_reconstruction_grid(
    original_image,
    noisy_blurred_image,
    tsvd_images,
    [f"k={k_val}" for k_val in k_values],
    "TSVD Deblurring with Varying k",
)

print("\nDone.")