# CPU 2D Pixel Binning - Data Reduction for Images

This demonstrates pixel binning by summing adjacent pixels in 2D blocks. In 2D, we typically use square bins (2x2, 3x3, etc.) where each output pixel is the sum of a block of input pixels.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def bin_pixels_2d(image, bin_size):
    """Bin 2D image data by summing adjacent pixels in square blocks.

    Each output pixel is the sum of a ``bin_size`` x ``bin_size`` block of
    input pixels. When the image size is not a multiple of ``bin_size`` the
    last row/column of bins is partial and sums only the pixels that exist.

    Parameters
    ----------
    image : array_like
        2D array of pixel values.
    bin_size : int
        Side length of the square bin; must be positive.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(ceil(rows / bin_size), ceil(cols / bin_size))``.
        Floating-point input keeps its dtype. Integer/bool input is returned
        in the accumulator dtype ``numpy.sum`` would use (at least the
        platform integer), so block sums cannot overflow narrow types such
        as ``uint8``/``uint16``.

    Raises
    ------
    ValueError
        If ``bin_size`` is not positive or ``image`` is not 2D.
    """
    image = np.asarray(image)
    # validate input
    if bin_size <= 0:
        raise ValueError("bin_size must be positive.")
    if image.ndim != 2:
        raise ValueError(f"image must be a 2D array, got shape {image.shape}.")

    rows, cols = image.shape
    # calculate number of output bins (round up for partial bins)
    n_bins_y = (rows + bin_size - 1) // bin_size
    n_bins_x = (cols + bin_size - 1) // bin_size

    if np.issubdtype(image.dtype, np.integer) or np.issubdtype(image.dtype, np.bool_):
        # Integer counts: widen to the dtype np.sum accumulates in, so a sum of
        # e.g. four uint16 pixels near 65535 cannot wrap around.
        out_dtype = np.sum(np.zeros(1, dtype=image.dtype)).dtype
        zero = out_dtype.type(0)
    else:
        # Floating point: output keeps the input dtype; starting from a float32
        # zero means half-precision input is accumulated in at least float32.
        out_dtype = image.dtype
        zero = np.float32(0.0)

    result = np.zeros((n_bins_y, n_bins_x), dtype=out_dtype)

    # process each output pixel (bin) independently, one "thread" per bin
    for bin_row in range(n_bins_y):
        row_start = bin_row * bin_size
        for bin_col in range(n_bins_x):
            col_start = bin_col * bin_size
            # Slicing clips at the image edge, so partial bins at the bottom/right
            # only see real pixels; this replaces an explicit per-pixel bounds check.
            block = image[row_start:row_start + bin_size,
                          col_start:col_start + bin_size]
            # accumulate in row-major order, one pixel at a time
            bin_sum = zero
            for value in block.flat:
                bin_sum += value
            result[bin_row, bin_col] = bin_sum

    return result

In [None]:
def demonstrate_binning_2d(image, bin_size):
    """Print a step-by-step trace of 2D binning and return the binned image.

    Each output pixel is presented as the work of one "thread". The first
    (up to) three threads, in row-major order, are printed together with the
    input pixel range and the values they sum. The binned array itself is
    produced by ``bin_pixels_2d``.

    Raises:
        ValueError: if ``bin_size`` is not positive.
    """
    if bin_size <= 0:
        raise ValueError("bin_size must be positive.")

    print(f"\nDemonstrating 2D binning algorithm:")
    print(f"Input image shape: {image.shape}")
    print(f"Bin size: {bin_size}x{bin_size}")

    height, width = image.shape
    out_h = (height + bin_size - 1) // bin_size
    out_w = (width + bin_size - 1) // bin_size
    print(f"Output shape: {out_h}x{out_w}")

    result = bin_pixels_2d(image, bin_size)

    # threads are numbered row-major, so divmod recovers each bin's position
    for thread_id in range(min(3, out_h * out_w)):
        bin_row, bin_col = divmod(thread_id, out_w)
        print(f"\nThread {thread_id} (computing bin [{bin_row},{bin_col}]):")

        first_row = bin_row * bin_size
        first_col = bin_col * bin_size
        # clip the block at the image edge so partial bins list only real pixels
        last_row = min(first_row + bin_size, height) - 1
        last_col = min(first_col + bin_size, width) - 1

        terms = [
            f"{image[r, c]:.0f}"
            for r in range(first_row, last_row + 1)
            for c in range(first_col, last_col + 1)
        ]

        print(f"  Sums pixels from rows {first_row}-{last_row}, "
              f"cols {first_col}-{last_col}")
        print(f"  Values: {' + '.join(terms)} = {result[bin_row, bin_col]:.0f}")

    print(f"\nTotal intensity preserved: "
          f"{np.sum(image):.0f} -> {np.sum(result):.0f}")

    return result

In [None]:
# small 4x6 test image of integer-valued counts, stored as float32
pixel_data = np.asarray(
    [
        [10, 12, 8, 15, 20, 18],
        [5, 7, 9, 11, 13, 14],
        [22, 24, 26, 28, 30, 32],
        [15, 17, 19, 21, 23, 25],
    ],
    dtype=np.float32,
)

print("Test image:")
print(pixel_data)

In [None]:
# trace 2x2 binning of the test image, then show the full binned array
binned = demonstrate_binning_2d(image=pixel_data, bin_size=2)
print("\nBinned result:", binned, sep="\n")

## Simulate Noisy Detector Data
The simulated detector data include Poisson noise, which models photon-counting statistics.

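We will compute the signal-to-noise ratio (SNR), a common metric for how noisy the data are: the higher the SNR, the 'cleaner' the data. For photon counts, the noise on a pixel with $S$ counts is $\sqrt{S}$. For pixels well above the background level $B$, we therefore use $\mathrm{SNR} = (S - B)/\sqrt{S}$. Summing a $b \times b$ bin multiplies the signal by roughly $b^2$ but the Poisson noise only by about $b$, which is why binning raises the SNR.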

### Generate the 2D Detector Data

In [None]:
# create 2D detector image with diffraction rings
detector_size = 100
y, x = np.ogrid[:detector_size, :detector_size]
center = detector_size // 2

# distance of every pixel from the detector centre (broadcasts to a 2D grid)
r = np.sqrt((x - center)**2 + (y - center)**2)

# fixed seed so the images and SNR numbers below are reproducible run to run
rng_seed = 0
rng = np.random.default_rng(rng_seed)

# flat Poisson background with a mean of 5 counts per pixel
background = rng.poisson(lam=5, size=(detector_size, detector_size)).astype(np.float32)

# add rings at different radii
signal = background.copy()
ring_radii = (15, 25, 35)
ring_width = 3
for radius in ring_radii:
    # pixels whose distance from the centre lies within ring_width/2 of the radius
    ring_mask = (r >= radius - ring_width/2) & (r <= radius + ring_width/2)
    # draw Poisson photons (mean 50) only for the pixels actually on this ring
    n_ring_pixels = np.count_nonzero(ring_mask)
    signal[ring_mask] += rng.poisson(lam=50, size=n_ring_pixels).astype(np.float32)

# add bright 3x3 spots (mean 200 counts per pixel), cast to match signal's dtype
peak_positions = [(30, 30), (30, 70), (70, 30), (70, 70)]
for py, px in peak_positions:
    signal[py-1:py+2, px-1:px+2] += rng.poisson(lam=200, size=(3, 3)).astype(np.float32)

### Compare Different Binning Sizes

In [None]:
bin_factors = (1, 2, 4, 8)
n_panels = len(bin_factors)
fig, ax = plt.subplots(2, n_panels, figsize=(4 * n_panels, 8))

# flatten ax array for easier indexing: the first n_panels entries are the
# image row, the next n_panels are the matching centre-line profiles
ax_flat = ax.flatten()

# store (bin_size, mean_snr, max_snr) for the summary table
snr_results = []


def _corner_background(image, bin_size):
    """Return the mean pixel value of the four corner patches of *image*.

    Corners are expected to hold mostly background. Each patch spans about
    10 original pixels per side (``10 // bin_size`` binned pixels), and is
    never smaller than one pixel.
    """
    size = max(1, 10 // bin_size)
    patches = (
        image[:size, :size],
        image[:size, -size:],
        image[-size:, :size],
        image[-size:, -size:],
    )
    return np.mean(np.concatenate([patch.ravel() for patch in patches]))


def _snr_stats(image, background):
    """Return (mean, max) Poisson SNR over pixels brighter than twice *background*.

    Per-pixel SNR is ``(counts - background) / sqrt(counts)``, because the
    Poisson noise on ``counts`` photons is ``sqrt(counts)``. Returns
    ``(0, 0)`` when no pixel clears the threshold.
    """
    bright = image[image > background * 2]
    if bright.size == 0:
        return 0, 0
    snr_values = (bright - background) / np.sqrt(bright)
    return np.mean(snr_values), np.max(snr_values)


for idx, bin_size in enumerate(bin_factors):
    image_ax = ax_flat[idx]
    profile_ax = ax_flat[idx + n_panels]

    # bin once per factor and reuse it for the image, statistics and profile
    if bin_size == 1:
        binned = signal
        title = "Original Detector Data\n(no binning)"
    else:
        binned = bin_pixels_2d(signal, bin_size)
        title = f"Binned {bin_size}x{bin_size}\n({binned.shape[0]}x{binned.shape[1]} pixels)"

    # display image
    im = image_ax.matshow(binned, cmap='gray')
    image_ax.set_title(title)
    image_ax.set_xlabel("Column")
    image_ax.set_ylabel("Row")
    plt.colorbar(im, ax=image_ax, fraction=0.046)

    # statistics
    total = np.sum(binned)
    peak = np.max(binned)
    background_per_pixel = _corner_background(binned, bin_size)
    mean_snr, max_snr = _snr_stats(binned, background_per_pixel)
    snr_results.append((bin_size, mean_snr, max_snr))

    # display statistics on plot
    stats_text = (f"Total: {total:.0f}\n"
                  f"Peak: {peak:.0f}\n"
                  f"Avg SNR: {mean_snr:.1f}\n"
                  f"Max SNR: {max_snr:.1f}")
    image_ax.text(0.02, 0.98, stats_text,
                  transform=image_ax.transAxes,
                  verticalalignment="top",
                  bbox=dict(boxstyle="round,pad=0.3",
                            facecolor="yellow", alpha=0.7),
                  fontsize=9)

    # centre line profile in the bottom row, marked on the image above it
    center_row = binned.shape[0] // 2
    profile = binned[center_row, :]
    profile_ax.plot(profile, 'r-', linewidth=2)
    profile_ax.set_title(f"Center Line Profile\n(bin size {bin_size}x{bin_size})")
    profile_ax.set_xlabel("Column")
    profile_ax.set_ylabel("Counts")
    profile_ax.grid(True, alpha=0.3)
    image_ax.axhline(y=center_row, color='red', linewidth=2, linestyle='-', alpha=0.7)

plt.tight_layout()
plt.show()

# print SNR summary: four right-aligned 10-character columns separated by one space
print("\nSNR Summary:")
print("=" * 50)
print(f"{'Bin Size':>10} {'Avg SNR':>10} {'Max SNR':>10} {'SNR Gain':>10}")
print("-" * 50)
base_snr = snr_results[0][1]
for bin_size, avg_snr, max_snr in snr_results:
    gain = avg_snr / base_snr if base_snr > 0 else 0
    label = f"{bin_size} x {bin_size}"
    # gain uses width 9 so that, with its "x" suffix, it fills the 10-char column
    print(f"{label:>10} {avg_snr:>10.1f} {max_snr:>10.1f} {gain:>9.2f}x")

## Visualising the Binning Process

In [None]:
# create small example to show binning clearly
demo_size = 6
demo_bin = 2
demo_image = np.arange(demo_size * demo_size, dtype=np.float32).reshape(demo_size, demo_size)
print(f"Demo image ({demo_size}x{demo_size}):")
print(demo_image)

# apply binning
binned_demo = bin_pixels_2d(demo_image, demo_bin)
out_rows, out_cols = binned_demo.shape
print(f"\nAfter {demo_bin}x{demo_bin} binning ({out_rows}×{out_cols}):")
print(binned_demo)

# visualise the binning
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

# original with grid
im1 = ax1.matshow(demo_image, cmap='viridis')
ax1.set_title(f"Original {demo_size}x{demo_size}")
ax1.set_xlabel("Column")
ax1.set_ylabel("Row")

# Draw every bin boundary, including the far right/bottom edge. Pixel centres
# sit on integer coordinates, so pixel edges are at k - 0.5.
bin_edges = list(range(0, demo_size, demo_bin)) + [demo_size]
for edge in bin_edges:
    ax1.axhline(edge - 0.5, color='red', linewidth=2)
    ax1.axvline(edge - 0.5, color='red', linewidth=2)

# add text values (matplotlib text takes x=column, y=row)
for (i, j), value in np.ndenumerate(demo_image):
    ax1.text(j, i, f'{int(value)}',
             ha='center', va='center', color='white', fontsize=12)

plt.colorbar(im1, ax=ax1, fraction=0.046)

# binned result
im2 = ax2.matshow(binned_demo, cmap='viridis')
ax2.set_title(f"Binned {out_rows}×{out_cols}")
ax2.set_xlabel("Column")
ax2.set_ylabel("Row")

# add text values
for (i, j), value in np.ndenumerate(binned_demo):
    ax2.text(j, i, f'{int(value)}', ha='center', va='center', color='white', fontsize=14)

plt.colorbar(im2, ax=ax2, fraction=0.046)

plt.tight_layout()
plt.show()

# show calculation for one bin
first_block = demo_image[:demo_bin, :demo_bin]
print("\nExample calculation for output pixel [0,0]:")
print(f"Sums input pixels [0:{demo_bin}, 0:{demo_bin}]:")
print(first_block)
print(f"Sum: {' + '.join(str(v) for v in first_block.flat)} = {binned_demo[0, 0]}")