# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.ndimage as ndi

def load_image(size=400, noise_std=0.2, seed=None):
    """
    Create a synthetic test image together with a noisy copy of it.

    The clean image is the Gaussian bump exp(-(x^2 + y^2)) sampled on a
    ``size`` x ``size`` grid spanning [-3, 3] along both axes, so its values
    lie in (0, 1]. Additive Gaussian noise with standard deviation
    ``noise_std`` is added to produce the noisy image.

    Parameters:
    - size: Number of samples along each axis of the grid.
    - noise_std: Standard deviation of the additive Gaussian noise.
    - seed: Optional seed for reproducible noise. When None, the noise is
      drawn from NumPy's global random state, so ``np.random.seed`` still
      controls it.

    Returns:
    - (noisy_image, image): the noisy image and the clean reference, both of
      shape (size, size).
    """
    coords = np.linspace(-3, 3, size)
    X, Y = np.meshgrid(coords, coords)
    image = np.exp(-X**2 - Y**2)

    # np.random.standard_normal(shape) draws the same stream as
    # np.random.randn(*shape), so the unseeded default matches the old output.
    rng = np.random if seed is None else np.random.default_rng(seed)
    noisy_image = image + noise_std * rng.standard_normal(image.shape)

    return noisy_image, image

def total_variation_denoising(noisy_image, lambda_weight, n_iterations=100, dt=0.1, eps=1e-10):
    """
    Total Variation (TV) denoising by explicit gradient descent.

    Each iteration takes one step of size ``dt`` on the energy

        E(u) = 1/2 * ||u - f||^2 + lambda_weight * TV(u)

    using the update

        u <- u + dt * (f - u + lambda_weight * div(grad(u) / max(eps, |grad(u)|)))

    where f is ``noisy_image``. Derivatives come from ``np.gradient``, which
    uses central differences in the interior and one-sided differences at the
    borders.

    Parameters:
    - noisy_image: Input image with noise. Any number of dimensions is
      accepted; each axis needs at least 2 samples (np.gradient raises
      ValueError otherwise). Non-floating inputs are promoted to float64.
    - lambda_weight: Regularization weight; larger values smooth more.
    - n_iterations: Number of gradient-descent steps to run.
    - dt: Step size of each update. The scheme is explicit, so a large
      ``dt * lambda_weight`` can make the iteration oscillate instead of
      converge.
    - eps: Lower bound on the gradient magnitude used for normalization.
      It prevents division by zero in flat regions. Larger values make the
      penalty quadratic for gradients smaller than ``eps`` (Huber-like).

    Returns:
    - denoised_image: Image after denoising, same shape as the input.

    NOTE: central differences are zero for a pixel-level alternating
    (checkerboard) pattern, so such components are not penalized by this
    discretization of TV.
    """
    noisy = np.asarray(noisy_image)
    if not np.issubdtype(noisy.dtype, np.inexact):
        # Avoid integer arithmetic in the fidelity term.
        noisy = noisy.astype(np.float64)
    denoised_image = noisy.copy()
    axes = range(denoised_image.ndim)

    for _ in range(n_iterations):
        # Derivative of the current estimate along each axis.
        grads = [np.gradient(denoised_image, axis=k) for k in axes]

        # Gradient magnitude, floored at eps to keep the division finite.
        norm = np.maximum(eps, np.sqrt(sum(g**2 for g in grads)))

        # Divergence of the normalized gradient field. Only the derivative of
        # component k along axis k is needed, so compute just that one.
        divergence = sum(np.gradient(g / norm, axis=k) for k, g in zip(axes, grads))

        # One explicit descent step: fidelity pull toward the input plus TV smoothing.
        denoised_image = denoised_image + dt * (noisy - denoised_image + lambda_weight * divergence)

    return denoised_image

# Build the synthetic test data: a noisy observation and its clean reference.
noisy_image, original_image = load_image()

# Denoise with a fixed regularization weight (function defaults for
# n_iterations and dt).
# NOTE(review): the update in total_variation_denoising is explicit with
# dt=0.1 by default, so one step can move a pixel by about
# dt * lambda_weight * |divergence|. With lambda_weight = 10 that is large
# next to the clean image's (0, 1] range. Consider a smaller weight if the
# result looks noisier than the input.
lambda_weight = 10
denoised_image = total_variation_denoising(noisy_image, lambda_weight)

# Show clean, noisy and denoised images side by side, without axes.
panels = (
    ('Original Image', original_image),
    ('Noisy Image', noisy_image),
    ('Denoised Image', denoised_image),
)
plt.figure(figsize=(15, 5))
for position, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.imshow(picture, cmap='gray')
    plt.title(caption)
    plt.axis('off')

plt.tight_layout()
plt.show()
