In [3]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.fft import dctn, idctn
from skimage import data, color, img_as_float
import ipywidgets as widgets
from IPython.display import display


def zigzag_indices(n):
    """
    Return an (n*n, 2) integer array of (row, col) pairs that visit an n x n
    matrix in zigzag order.

    The walk goes anti-diagonal by anti-diagonal (row + col == d for
    d = 0 .. 2n-2). Even diagonals are traversed with the row decreasing
    (bottom-left -> top-right); odd diagonals with the row increasing
    (top-right -> bottom-left). For n = 3 the order is
    (0,0) (0,1) (1,0) (2,0) (1,1) (0,2) (1,2) (2,1) (2,2).
    """
    visit_order = []
    for diag in range(2 * n - 1):
        # Rows that intersect this anti-diagonal, listed top to bottom.
        top_row = max(0, diag - n + 1)
        bottom_row = min(diag, n - 1)
        rows = range(top_row, bottom_row + 1)
        if diag % 2 == 0:
            # Even diagonals run upward through the matrix.
            rows = reversed(rows)
        visit_order.extend((r, diag - r) for r in rows)

    result = np.empty((n * n, 2), dtype=int)
    if visit_order:
        result[:] = visit_order
    return result


def zigzag_flatten(matrix):
    """
    Flatten a square 2D matrix into a 1D array in zigzag order.

    Parameters
    ----------
    matrix : np.ndarray
        A 2D array of shape (n, n).

    Returns
    -------
    np.ndarray
        Shape (n * n,), the same dtype as ``matrix``. Element t is the entry at
        the t-th position of ``zigzag_indices(n)``.

    Raises
    ------
    ValueError
        If ``matrix`` is not square.
    """
    n, m = matrix.shape
    if n != m:
        raise ValueError("The zigzag function currently supports only square matrices.")
    indices = zigzag_indices(n)
    # One vectorized gather instead of a Python-level per-element loop: this is
    # called for every patch of every set on each slider redraw.
    return matrix[indices[:, 0], indices[:, 1]]


def retain_lowest_k_coefficients(dct_patch, k):
    """
    Keep the first ``k`` coefficients of a square DCT patch in zigzag order and
    zero everything else.

    A ``k`` of zero or less keeps nothing. A ``k`` of at least n*n keeps every
    coefficient. The input is not modified; a new array with the same dtype is
    returned.

    Raises
    ------
    ValueError
        If ``dct_patch`` is not square.
    """
    n, m = dct_patch.shape
    if n != m:
        raise ValueError("Patch must be square.")
    keep_count = max(min(k, n * n), 0)
    kept = zigzag_indices(n)[:keep_count]
    rows, cols = kept[:, 0], kept[:, 1]

    filtered = np.zeros_like(dct_patch)
    filtered[rows, cols] = dct_patch[rows, cols]
    return filtered


def process_image(image, patch_size=32, k=50, random_mean=4096, random_std=500, rng=None):
    """
    Split a 2D image into non-overlapping square patches and build two
    DCT-domain versions of every patch.

    The image is first cropped (bottom/right) so both dimensions are multiples
    of ``patch_size``. Patches are visited row by row, left to right. For each
    patch, two versions are built:
      1. Filtered: the orthonormal 2D DCT-II (``scipy.fft.dctn``,
         ``norm="ortho"``) with every coefficient except the lowest ``k`` (in
         zigzag order) set to zero.
      2. Random: an all-zero block whose DC coefficient (0, 0) is drawn from a
         Gaussian with mean ``random_mean`` and std ``random_std``. The AC
         coefficients are deliberately zero, not copied from (1), so each
         random patch reconstructs to a flat constant patch.

    With ``norm="ortho"``, the DC coefficient of an N x N patch equals N times
    the patch mean. With N = 32, the default ``random_mean`` of 4096 therefore
    corresponds to a mean intensity of 128.

    Parameters
    ----------
    rng : numpy.random.Generator or RandomState, optional
        Source of the random DC values. If None (default), NumPy's global RNG
        is used via ``np.random.normal``, exactly as before.

    Returns
    -------
    cropped_shape : tuple of int
        (height, width) of the cropped image.
    orig_dct_patches : list of np.ndarray
        Filtered DCT patches.
    random_dct_patches : list of np.ndarray
        Random-DC patches, in the same order.
    """
    # Crop image so dimensions are multiples of patch_size.
    height, width = image.shape
    height_crop = (height // patch_size) * patch_size
    width_crop = (width // patch_size) * patch_size
    image = image[:height_crop, :width_crop]

    draw_normal = np.random.normal if rng is None else rng.normal

    orig_dct_patches = []
    random_dct_patches = []

    for i in range(0, height_crop, patch_size):
        for j in range(0, width_crop, patch_size):
            patch = image[i : i + patch_size, j : j + patch_size]
            patch_dct = dctn(patch, type=2, norm="ortho")
            patch_dct_filtered = retain_lowest_k_coefficients(patch_dct, k)

            # Start from zeros on purpose: only the DC term is set, so blending
            # toward this block pulls every AC coefficient toward zero.
            patch_dct_random = np.zeros_like(patch_dct_filtered)
            # (0, 0) is the DC coefficient (first position in zigzag order).
            patch_dct_random[0, 0] = draw_normal(loc=random_mean, scale=random_std)

            orig_dct_patches.append(patch_dct_filtered)
            random_dct_patches.append(patch_dct_random)

    return (height_crop, width_crop), orig_dct_patches, random_dct_patches


def reconstruct_image_from_patches(patches, img_shape, patch_size=32):
    """
    Invert each DCT patch (orthonormal inverse of DCT-II) and tile the results
    back into an image of shape ``img_shape``.

    ``patches`` must be ordered row by row, left to right, matching the grid of
    ``patch_size`` x ``patch_size`` tiles that covers ``img_shape``.
    """
    height, width = img_shape
    canvas = np.zeros((height, width))
    tile_origins = [
        (top, left)
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
    for tile_number, (top, left) in enumerate(tile_origins):
        spatial = idctn(patches[tile_number], type=2, norm="ortho")
        canvas[top : top + patch_size, left : left + patch_size] = spatial
    return canvas


# Precompute a frequency weight map for a patch.
def compute_frequency_map(patch_size):
    """
    Build the radial-frequency magnitude of every coefficient in a
    ``patch_size`` x ``patch_size`` DCT block.

    Each axis index t maps to t * pi / (2 * patch_size). Coefficient (i, j)
    gets the Euclidean norm of its two axis frequencies:
        f(i, j) = sqrt(f_i**2 + f_j**2).

    Returns
    -------
    freq_map : np.ndarray
        Float array of shape (patch_size, patch_size).
    f_max : float
        The largest value in ``freq_map``.
    """
    axis_freq = (np.arange(patch_size) * np.pi) / (2 * patch_size)
    # Column vector + row vector broadcasts to the full 2D grid.
    freq_map = np.sqrt(axis_freq[:, None] ** 2 + axis_freq[None, :] ** 2)
    return freq_map, freq_map.max()


# --- Load and Preprocess Image ---
img_rgb = data.astronaut()
img_gray = color.rgb2gray(img_rgb)
# Scale to 0-255 (assumes rgb2gray output lies in [0, 1]); the display code
# uses imshow(vmin=0, vmax=255).
img_gray = img_as_float(img_gray) * 255.0

# Set parameters.
patch_size = 32
k = 512  # lowest zigzag coefficients kept per patch (out of patch_size**2)

# Seed NumPy's global RNG: process_image draws the random DC coefficients via
# np.random.normal. Without a seed, every Restart & Run All gives different figures.
SEED = 0
np.random.seed(SEED)

# Process the image into patches.
cropped_shape, orig_dct_patches, random_dct_patches = process_image(
    img_gray, patch_size=patch_size, k=k
)

# Precompute the reconstructions from the two sets; neither depends on lambda.
orig_recon = reconstruct_image_from_patches(
    orig_dct_patches, cropped_shape, patch_size=patch_size
)
random_recon = reconstruct_image_from_patches(
    random_dct_patches, cropped_shape, patch_size=patch_size
)

# Precompute frequency map for a patch.
freq_map, f_max = compute_frequency_map(patch_size)


def average_zigzag(patches, patch_size):
    """
    Average square DCT patches coefficient-by-coefficient and return the mean
    as a 1D vector in zigzag order.

    ``patch_size`` is accepted but not used. Each patch's size comes from its
    own shape (via ``zigzag_flatten``).
    """
    stacked = np.array(list(map(zigzag_flatten, patches)))
    return np.mean(stacked, axis=0)


def histogram_coefficients(patches):
    """
    Concatenate every coefficient of every patch into one 1D array.

    Coefficients are taken row-major within each patch, and patches appear in
    input order. The result is suitable for passing straight to a histogram.
    """
    pieces = [block.ravel() for block in patches]
    return np.concatenate(pieces)


def count_nonzeros(patches):
    """
    Return the fraction (not the count) of non-zero coefficients across all
    patches: total non-zero entries divided by total entries, a value in [0, 1].

    An empty ``patches`` evaluates 0/0 in NumPy and yields NaN.
    """
    nonzero_per_patch = [np.count_nonzero(block) for block in patches]
    entries_per_patch = [block.size for block in patches]
    total_nonzero = np.sum(nonzero_per_patch)
    total_entries = np.sum(entries_per_patch)
    return total_nonzero / total_entries


def _interpolate_patches(lambda_val):
    """Blend the module-level orig/random DCT patches with the frequency-dependent weight; return (weight, patches)."""
    # w(i, j) = clip(lambda * exp(10 * f / f_max), 0, 1). The factor 10 sets how
    # much faster high frequencies move than DC: since exp(10) ~ 2.2e4, the
    # highest frequency already has w = 1 once lambda >= exp(-10) ~ 4.5e-5.
    weight = np.clip(lambda_val * np.exp((freq_map / f_max) * 10), 0, 1)
    # weight has the patch shape and applies to each patch elementwise.
    patches = [
        (1 - weight) * orig_patch + weight * random_patch
        for orig_patch, random_patch in zip(orig_dct_patches, random_dct_patches)
    ]
    return weight, patches


def _log_magnitude(values):
    """Return log(1 + |values|), a compressed magnitude scale for plotting."""
    return np.log1p(np.abs(values))


def update(lambda_val):
    """
    Redraw the comparison figure for one value of the interpolation parameter.

    Every DCT coefficient is blended elementwise between the original and
    random patch sets:

        interp = (1 - w) * orig + w * random
        w(i, j) = clip(lambda_val * exp(10 * freq_map[i, j] / f_max), 0, 1)

    The DC coefficient (frequency 0) gets weight exactly ``lambda_val``. Higher
    frequencies reach w = 1 much sooner.

    Figure layout (3 x 3):
      - Row 1: reconstructions from the original, interpolated and random sets.
      - Row 2: log(1 + |mean coefficient|) over all patches, in zigzag order.
      - Row 3: fraction of non-zero coefficients per set, the weight map w, and
        a histogram of log(1 + |coefficient|) over all coefficients of each set.

    Reads the module-level globals orig_dct_patches, random_dct_patches,
    orig_recon, random_recon, cropped_shape, patch_size, freq_map and f_max.
    """
    weight, interp_patches = _interpolate_patches(lambda_val)
    interp_recon = reconstruct_image_from_patches(
        interp_patches, cropped_shape, patch_size=patch_size
    )

    # (label, DCT patches, spatial reconstruction, colour), one per column.
    columns = [
        ("Original", orig_dct_patches, orig_recon, "tab:blue"),
        ("Interpolated", interp_patches, interp_recon, "green"),
        ("Random", random_dct_patches, random_recon, "red"),
    ]

    fig, axes = plt.subplots(3, 3, figsize=(18, 15))

    for col, (label, patches, recon, colour) in enumerate(columns):
        # Row 1: spatial reconstruction on a fixed 0-255 grey scale.
        ax = axes[0, col]
        ax.imshow(recon, cmap="gray", vmin=0, vmax=255)
        title = f"{label} Reconstruction"
        if label == "Interpolated":
            title += f"\nλ = {lambda_val:.2f}"
        ax.set_title(title)
        ax.axis("off")

        # Row 2: mean coefficient per zigzag position. The mean is over signed
        # values, so AC terms of opposite sign in different patches partly cancel.
        avg = _log_magnitude(average_zigzag(patches, patch_size))
        ax = axes[1, col]
        ax.plot(np.arange(len(avg)), avg, marker="o", color=colour)
        ax.set_title(f"{label} DCT Coefficients (Zigzag)")
        ax.set_xlabel("Zigzag Index")
        ax.set_ylabel("log(1 + |mean coefficient|)")

    # Row 3, left: sparsity of each set (count_nonzeros returns a fraction).
    labels = [label for label, _, _, _ in columns]
    colours = [colour for _, _, _, colour in columns]
    fractions = [count_nonzeros(patches) for _, patches, _, _ in columns]
    axes[2, 0].bar(labels, fractions, color=colours)
    axes[2, 0].set_title("Fraction of Non-Zero DCT Coefficients")
    axes[2, 0].set_ylabel("Fraction of coefficients")
    axes[2, 0].set_ylim(0, 1)

    # Row 3, middle: the per-coefficient interpolation weight actually applied.
    weight_img = axes[2, 1].imshow(weight, cmap="viridis", vmin=0, vmax=1)
    axes[2, 1].set_title(f"Interpolation Weight w(i, j)\nλ = {lambda_val:.2f}")
    axes[2, 1].set_xlabel("Frequency index j")
    axes[2, 1].set_ylabel("Frequency index i")
    fig.colorbar(weight_img, ax=axes[2, 1])

    # Row 3, right: distribution of all coefficient magnitudes in each set.
    for label, patches, _, colour in columns:
        axes[2, 2].hist(
            _log_magnitude(histogram_coefficients(patches)),
            bins=50,
            histtype="step",
            log=True,
            color=colour,
            label=label,
        )
    axes[2, 2].set_title("Histogram of DCT Coefficient Magnitudes")
    axes[2, 2].set_xlabel("log(1 + |coefficient|)")
    axes[2, 2].set_ylabel("Count")
    axes[2, 2].legend()

    plt.tight_layout()
    plt.show()


# Create an interactive slider for lambda between 0 and 1.
# continuous_update=False redraws only when the slider is released. Each redraw
# re-interpolates every patch and renders a 3 x 3 figure, which is expensive
# to repeat on every drag tick.
slider = widgets.FloatSlider(
    value=0.0,
    min=0.0,
    max=1.0,
    step=0.01,
    description="λ:",
    continuous_update=False,
)
# The trailing ';' suppresses the repr of the function object interact() returns.
widgets.interact(update, lambda_val=slider);

interactive(children=(FloatSlider(value=0.0, description='λ:', max=1.0, step=0.01), Output()), _dom_classes=('…

<function __main__.update(lambda_val)>

In [3]:
!pip install ipywidgets

Collecting ipywidgets
  Downloading ipywidgets-8.1.6-py3-none-any.whl.metadata (2.4 kB)
Collecting widgetsnbextension~=4.0.14 (from ipywidgets)
  Downloading widgetsnbextension-4.0.14-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab_widgets~=3.0.14 (from ipywidgets)
  Downloading jupyterlab_widgets-3.0.14-py3-none-any.whl.metadata (4.1 kB)
Downloading ipywidgets-8.1.6-py3-none-any.whl (139 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.8/139.8 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jupyterlab_widgets-3.0.14-py3-none-any.whl (213 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m214.0/214.0 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading widgetsnbextension-4.0.14-py3-none-any.whl (2.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[?25hInstalling collected packages: widgetsnbextension, jupyterlab_widgets, ipywidgets
