# CPU 2D Median Filter - Understanding Neighbourhood Operations
This notebook introduces kernels that look at multiple pixels (a neighbourhood) to compute each output value.

The median filter is excellent at removing "salt and pepper" noise while preserving edges.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

## Functions for Median Filter

In [None]:
def calculate_median_explicit(window):
    """
    Return the median of a flat window by sorting a copy of it.

    The same routine serves the 1D and 2D filters: for 2D the caller
    flattens the neighbourhood before passing it in. The input itself is
    left untouched. With an even number of elements the upper of the two
    middle values is returned.
    """
    # work on a private copy so the caller's window keeps its order
    values = window.copy()
    count = len(values)

    # plain bubble sort: good enough for tiny windows, and it maps
    # one-to-one onto CUDA C
    for done in range(count):
        # each pass moves the largest unsorted value to the end
        for k in range(count - done - 1):
            if values[k] > values[k + 1]:
                values[k], values[k + 1] = values[k + 1], values[k]

    # middle of the sorted values (callers use odd window sizes)
    return values[count // 2]

In [None]:
def median_filter_2d(image, window_size=3, quiet=False):
    """
    Apply a median filter to a 2D image in "valid" mode (no padding).

    Only pixels whose complete neighbourhood lies inside the image are
    processed, so the output loses ``window_size // 2`` pixels on every
    side compared with the input.

    Parameters
    ----------
    image : 2D numpy array
        Input image.
    window_size : int
        Side length of the square window. An even value is bumped to the
        next odd value and a notice is printed.
    quiet : bool
        When False, print the sizes involved and the neighbourhood, sorted
        neighbourhood and median for the first three pixels of the first
        output row.

    Returns
    -------
    numpy array of shape ``(rows - 2*half, cols - 2*half)`` with the dtype
    of ``image``, where ``half = window_size // 2``. For an empty image,
    ``np.array([[]])`` is returned.

    Raises
    ------
    ValueError
        If the image is not 2D, the window size is not positive, or the
        image is smaller than the window.
    """
    # an empty image has nothing to filter
    if image.size == 0:
        return np.array([[]])

    if image.ndim != 2:
        raise ValueError(f"Expected a 2D image, got {image.ndim} dimension(s).")

    # ensure window size is odd so the window has a centre pixel
    if window_size % 2 == 0:
        window_size += 1
        print(f"Window size must be odd, using {window_size}")

    # a negative size would silently wrap around via negative indexing
    if window_size < 1:
        raise ValueError(f"Window size must be positive, got {window_size}.")

    rows, cols = image.shape
    half_window = window_size // 2

    # output is smaller due to "valid" mode (no padding)
    out_rows = rows - 2 * half_window
    out_cols = cols - 2 * half_window

    if out_rows <= 0 or out_cols <= 0:
        raise ValueError(f"Image too small for {window_size}x{window_size} filter.")

    result = np.zeros((out_rows, out_cols), dtype=image.dtype)

    if not quiet:
        print(f"Applying {window_size}x{window_size} median filter")
        print(f"Input size: {rows}×{cols}")
        print(f"Output size: {out_rows}x{out_cols} (shrinks by {half_window} pixels on each side)\n")

    # process each valid pixel (those with complete neighbourhoods)
    for row in range(half_window, rows - half_window):
        for col in range(half_window, cols - half_window):
            # extract neighbourhood into flat array (row-major order);
            # kept as explicit loops to mirror a per-pixel GPU kernel
            window = np.zeros(window_size * window_size, dtype=image.dtype)
            idx = 0

            for wr in range(-half_window, half_window + 1):
                for wc in range(-half_window, half_window + 1):
                    window[idx] = image[row + wr, col + wc]
                    idx += 1

            # calculate median using explicit operations
            median_val = calculate_median_explicit(window)

            # store in output (shift indices because the output is smaller)
            result[row - half_window, col - half_window] = median_val

            # show details for the first three pixels of the first output row
            if not quiet and row == half_window and col < half_window + 3:
                print(f"Pixel ({row},{col}):")
                print(f"  Neighbourhood: {window}")
                print(f"  Sorted: {np.sort(window)}")
                print(f"  Median: {median_val}\n")

    return result

## Function to Visualise Filtered Image

In [None]:
def visualise_median_filter_2d(original, filtered, noisy=None):
    """
    Show the original, the optional noisy image and the filtered result
    side by side, each as a grey-scale plot on a fixed 0-255 scale.
    """
    # panels in left-to-right order; the noisy one only appears when given
    panels = [("Original Image", original)]
    if noisy is not None:
        panels.append(("Image with Noise", noisy))
    panels.append(("After Median Filter", filtered))

    n_panels = len(panels)
    fig, ax = plt.subplots(1, n_panels, figsize=(5*n_panels, 5))

    for axis, (title, data) in zip(ax, panels):
        shown = axis.matshow(data, cmap='gray', vmin=0, vmax=255)
        axis.set_title(title)
        axis.set_xlabel("Column")
        axis.set_ylabel("Row")
        plt.colorbar(shown, ax=axis, fraction=0.046)

    plt.tight_layout()
    plt.show()

## Removing Salt and Pepper Noise

In [None]:
# base image: uniform mid-grey background
clean_image = np.full((10, 12), 128, dtype=np.float32)

# two flat rectangles give the filter some edges to preserve
clean_image[2:5, 2:6] = 200  # bright rectangle
clean_image[6:9, 7:11] = 50  # dark rectangle

# draw the noise masks (roughly 5% of pixels each): salt first, then pepper
rng = np.random.default_rng()
salt_mask = rng.random(clean_image.shape) < 0.05
pepper_mask = rng.random(clean_image.shape) < 0.05

# corrupt a copy; pepper is written last, so it wins where the masks overlap
noisy_image = clean_image.copy()
noisy_image[salt_mask] = 255  # salt (white spots)
noisy_image[pepper_mask] = 0  # pepper (black spots)

print("Clean image:", clean_image, sep="\n")
print("\nNoisy image (with salt & pepper):", noisy_image, sep="\n")

In [None]:
# run the 3x3 median filter on the noisy image (progress details are printed)
filtered = median_filter_2d(noisy_image, window_size=3)

# show the result as integers so the array is easier to read
print("\nFiltered image:", filtered.astype(int), sep="\n")

In [None]:
# compare clean, noisy and filtered images side by side
visualise_median_filter_2d(original=clean_image, filtered=filtered, noisy=noisy_image)

## Smoothing Noisy Gradient Image

In [None]:
# grid size and normalised coordinates in [0, 1]
rows, cols = 20, 30
x = np.linspace(0, 1, num=cols)
y = np.linspace(0, 1, num=rows)
xx, yy = np.meshgrid(x, y)

# smooth sin/cos pattern, shifted and scaled into the 0..255 range
clean_gradient = ((np.sin(5 * xx) * np.cos(3 * yy) + 1) * 127.5).astype(np.float32)

# additive Gaussian noise (mean 0, std 20), clipped back into 0..255
noise = rng.normal(loc=0, scale=20, size=clean_gradient.shape)
noisy_gradient = np.clip(clean_gradient + noise, 0, 255).astype(np.float32)

### Filter with Different Window Sizes

In [None]:
fig, ax = plt.subplots(2, 3, figsize=(15, 10))

# first two panels of the top row: the clean pattern and its noisy version
ax_clean, ax_noisy = ax[0, 0], ax[0, 1]

im00 = ax_clean.matshow(clean_gradient, cmap='gray', vmin=0, vmax=255)
ax_clean.set_title("Original")
plt.colorbar(im00, ax=ax_clean, fraction=0.046)

im01 = ax_noisy.matshow(noisy_gradient, cmap='gray', vmin=0, vmax=255)
ax_noisy.set_title("With Noise")
plt.colorbar(im01, ax=ax_noisy, fraction=0.046)

# remaining panels: one median-filtered result per window size
window_sizes = (3, 5, 7)
ax_pos = [(0, 2), (1, 0), (1, 1)]

for ws, (r, c) in zip(window_sizes, ax_pos):
    panel = ax[r, c]
    filtered = median_filter_2d(noisy_gradient, window_size=ws, quiet=True)
    im = panel.matshow(filtered, cmap='gray', vmin=0, vmax=255)
    panel.set_title(f"{ws}x{ws} Median Filter")
    plt.colorbar(im, ax=panel, fraction=0.046)

# the sixth slot has nothing to show
ax[1, 2].set_visible(False)

# axis labels on every panel that is still visible
for a in ax.flat:
    if not a.get_visible():
        continue
    a.set_xlabel("Column")
    a.set_ylabel("Row")

plt.tight_layout()
plt.show()

## Handling Edges
Different padding strategies for complete image filtering.

In [None]:
def median_filter_with_padding_2d(image, window_size=3, padding="edge"):
    """
    Median filter that pads the image first, so the output keeps the
    input's shape.

    Parameters
    ----------
    image : 2D numpy array
        Input image.
    window_size : int
        Side length of the square window. An even value is bumped to the
        next odd value (with a printed notice), matching median_filter_2d;
        an even window has no centre pixel and would sit off-centre.
    padding : str
        "edge" (repeat border values), "reflect" (mirror values at the
        borders) or "zero" (pad with zeros).

    Returns
    -------
    numpy array with the same shape and dtype as ``image``.

    Raises
    ------
    ValueError
        If ``padding`` is not a supported strategy or the window size is
        not positive.
    """
    # translate the strategy name into np.pad keyword arguments
    if padding == "zero":
        pad_kwargs = {"mode": "constant", "constant_values": 0}
    elif padding in ("edge", "reflect"):
        pad_kwargs = {"mode": padding}
    else:
        raise ValueError(f"Unsupported padding strategy '{padding}'.")

    # ensure window size is odd so the window is centred on the pixel
    if window_size % 2 == 0:
        window_size += 1
        print(f"Window size must be odd, using {window_size}")

    if window_size < 1:
        raise ValueError(f"Window size must be positive, got {window_size}.")

    half_window = window_size // 2
    padded = np.pad(image, half_window, **pad_kwargs)

    rows, cols = image.shape
    result = np.zeros((rows, cols), dtype=image.dtype)

    # now we can process ALL pixels in the original image
    for row in range(rows):
        for col in range(cols):
            # extract window from padded image
            window = np.zeros(window_size * window_size, dtype=image.dtype)
            idx = 0

            for wr in range(window_size):
                for wc in range(window_size):
                    # padded[row, col] is the window's top-left corner
                    # because the padding offset equals half_window
                    window[idx] = padded[row + wr, col + wc]
                    idx += 1

            result[row, col] = calculate_median_explicit(window)

    return result

In [None]:
# small image with sharp intensity jumps: two identical rows on top,
# two identical rows below
top_row = [200, 200, 0, 0, 100]
bottom_row = [50, 50, 255, 150, 150]
edge_test = np.array([top_row, top_row, bottom_row, bottom_row], dtype=np.float32)

print("Test image with strong edges:", edge_test, sep="\n")

# 2x2 grid: the original plus one panel per padding strategy
fig, ax = plt.subplots(2, 2, figsize=(10, 10))

# original
im0 = ax[0, 0].matshow(edge_test, cmap='gray', vmin=0, vmax=255)
ax[0, 0].set_title("Original")
plt.colorbar(im0, ax=ax[0, 0], fraction=0.046)

# filter once per padding strategy and place each result in its panel
paddings = ("edge", "reflect", "zero")
ax_pos = [(0, 1), (1, 0), (1, 1)]

for padding, (r, c) in zip(paddings, ax_pos):
    panel = ax[r, c]
    filtered = median_filter_with_padding_2d(edge_test, window_size=3, padding=padding)
    im = panel.matshow(filtered, cmap='gray', vmin=0, vmax=255)
    panel.set_title(f"Padding: {padding}")
    plt.colorbar(im, ax=panel, fraction=0.046)
    print(f"\n{padding} padding result:", filtered, sep="\n")

# same axis labels on every panel
for a in ax.flat:
    a.set_xlabel("Column")
    a.set_ylabel("Row")

plt.tight_layout()
plt.show()

## Visualising Padding Explicitly

In [None]:
def _edge_index(index, size):
    """Clamp index into [0, size - 1], i.e. repeat the border pixel."""
    return max(0, min(size - 1, index))


def _reflect_index(index, size):
    """Mirror index about the first/last element without repeating them."""
    if size == 1:
        # a single pixel can only mirror onto itself
        return 0
    # the mirrored sequence repeats every 2 * (size - 1) positions
    period = 2 * (size - 1)
    index = abs(index) % period
    return period - index if index >= size else index


def visualise_padding(image, padding_mode, half_window=1):
    """
    Build a padded copy of a 2D image pixel by pixel.

    Parameters
    ----------
    image : 2D numpy array
        Image to pad.
    padding_mode : str
        "zero" (border stays 0), "edge" (repeat the nearest border pixel)
        or "reflect" (mirror about the border pixel, without repeating it).
    half_window : int
        Number of pixels added on every side.

    Returns
    -------
    numpy array of shape ``(rows + 2*half_window, cols + 2*half_window)``
    with the dtype of ``image``.

    Raises
    ------
    ValueError
        If ``padding_mode`` is not one of the supported modes.
    """
    if padding_mode not in ("zero", "edge", "reflect"):
        raise ValueError(f"Unsupported padding strategy '{padding_mode}'.")

    rows, cols = image.shape
    padded_rows = rows + 2 * half_window
    padded_cols = cols + 2 * half_window
    padded = np.zeros((padded_rows, padded_cols), dtype=image.dtype)

    # copy original to center
    for r in range(rows):
        for c in range(cols):
            padded[r + half_window, c + half_window] = image[r, c]

    # zero padding: the border already holds zeros
    if padding_mode == "zero":
        return padded

    # both remaining modes only differ in how a padded coordinate is
    # mapped back to a source coordinate inside the image
    source_index = _edge_index if padding_mode == "edge" else _reflect_index

    for pr in range(padded_rows):
        row_inside = half_window <= pr < padded_rows - half_window
        src_row = source_index(pr - half_window, rows)
        for pc in range(padded_cols):
            # interior pixels were copied above
            if row_inside and half_window <= pc < padded_cols - half_window:
                continue
            src_col = source_index(pc - half_window, cols)
            padded[pr, pc] = image[src_row, src_col]

    return padded

In [None]:
# 3x3 example holding the values 1..9 in row-major order
small_test = np.arange(1, 10, dtype=np.float32).reshape(3, 3)

print("Original 3x3 image:", small_test, sep="\n")

# one panel per padding mode
fig, ax = plt.subplots(1, 3, figsize=(15, 5))

for idx, mode in enumerate(("zero", "edge", "reflect")):
    panel = ax[idx]

    if mode != "zero":
        padded = visualise_padding(small_test, mode, half_window=1)
    else:
        # zero padding: start from zeros and drop the image into the centre
        padded = np.zeros((5, 5))
        padded[1:4, 1:4] = small_test

    im = panel.matshow(padded, cmap='viridis')
    panel.set_title(f"{mode} padding")

    # write each value into its cell; padding cells get a bold label on a
    # translucent red box
    for (i, j), cell in np.ndenumerate(padded):
        val = int(cell)
        color = 'black' if cell <= 4 else 'white'
        text_style = dict(ha='center', va='center', color=color, fontsize=12)

        on_border = i in (0, 4) or j in (0, 4)
        if on_border:
            text_style.update(weight='bold',
                              bbox=dict(boxstyle='round,pad=0.3',
                                        facecolor='red', alpha=0.3))

        panel.text(j, i, f'{val}', **text_style)

    panel.set_xlabel("Column")
    panel.set_ylabel("Row")
    panel.set_xticks(range(5))
    panel.set_yticks(range(5))
    plt.colorbar(im, ax=panel, fraction=0.046)

plt.tight_layout()
plt.show()