# Moving Windows with JAXScape

## Overview

Spatial ecological data often arrives as large continuous rasters that exceed computational memory or would benefit from localized processing. Moving window operations—systematically dividing a raster into overlapping tiles—provide an elegant solution for both computational tractability and ecological interpretability. Rather than processing an entire landscape at once, moving windows allow us to compute metrics locally, capturing spatial heterogeneity while managing memory constraints through sequential or parallel processing of manageable chunks.

JAXScape's `WindowOperation` class offers flexible tools for iterating over raster windows with configurable buffer zones that handle edge effects gracefully. Whether computing local connectivity metrics, applying spatially-varying transformations, or parallelizing expensive distance calculations across a large landscape, moving windows transform intractable problems into sequences of tractable sub-problems. This notebook demonstrates both lazy iteration (processing one window at a time for memory efficiency) and eager iteration (loading all windows into memory for GPU-accelerated batch processing), along with techniques for seamlessly reassembling processed tiles back into coherent landscape-scale outputs.

### Prerequisites:
```bash
pip install jaxscape rasterio matplotlib
```

In [None]:
import rasterio
import jax.numpy as jnp
import matplotlib.pyplot as plt

from jaxscape import WindowOperation, padding

## Load and Visualize Input Raster

We begin by loading a habitat quality raster that will serve as our test dataset. This raster represents continuous habitat suitability values across a landscape, with higher values indicating better quality habitat. Understanding the spatial structure of your input data is important for choosing appropriate window sizes and buffer zones in subsequent analyses.

In [None]:
# Load habitat quality raster
with rasterio.open("suitability.tif") as src:
    raster = src.read(1, masked=True)  # Read first band with masking
    quality = jnp.array(
        raster.filled(0), 
        dtype="float32"
    )  # Replace no-data values with 0

print(f"Raster shape: {quality.shape}")
print(f"Value range: [{quality.min():.2f}, {quality.max():.2f}]")

In [None]:
# Visualize the full raster
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(quality, cmap='viridis')
ax.set_title('Habitat Quality Raster', fontsize=14, pad=20)
ax.axis('off')
plt.colorbar(im, ax=ax, label='Quality', shrink=0.7)
plt.tight_layout()
plt.savefig('quality_raster.png', dpi=300, bbox_inches='tight')
plt.show()

## Configure Window Parameters

Moving window operations require three key parameters. The **window size** determines the dimensions of each tile or focal region we'll process—larger windows capture more spatial context but increase computational cost. The **buffer size** defines an overlap zone around each window's core area, which is essential for handling edge effects in operations like distance calculations or convolutions where pixels near boundaries need information from neighboring regions. Finally, **padding** extends the original raster with boundary values to ensure windows near edges have complete neighborhoods.

The relationship between these parameters follows: `padded_size = original_size + 2 × buffer_size`, and each window has dimensions `window_size + 2 × buffer_size`. The buffer zone is processed but typically discarded when assembling final results, serving only to provide context for edge pixels.
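
As a quick worked example of the arithmetic above, the following sketch computes the derived sizes for a hypothetical 200 × 300 raster. The raster shape is made up for illustration, and the padded shape simply applies the relation stated above rather than calling JAXScape.

```python
# Hypothetical raster shape, chosen only for illustration.
original_shape = (200, 300)
window_size = 50   # core area, in pixels
buffer_size = 10   # overlap on each side, in pixels

# Padded shape under the relation stated above (an assumption of this sketch).
padded_shape = tuple(n + 2 * buffer_size for n in original_shape)

# Each extracted window spans the core plus a buffer on both sides.
total_window_size = window_size + 2 * buffer_size

# Pixels per window: core, full window, and the buffer ring around the core.
core_pixels = window_size ** 2
window_pixels = total_window_size ** 2
buffer_pixels = window_pixels - core_pixels

# Non-overlapping cores that fit along each axis once the outer buffer is set aside.
windows_per_axis = tuple((n - 2 * buffer_size) // window_size for n in padded_shape)

print(padded_shape)       # (220, 320)
print(total_window_size)  # 70
print(buffer_pixels)      # 2400
print(windows_per_axis)   # (4, 6)
```

Note that the buffer ring holds almost as many pixels as the core here (2400 versus 2500), which is worth keeping in mind when estimating memory use.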

In [None]:
# Define window parameters
buffer_size = 10  # Overlap zone to handle edge effects
window_size = 50  # Core area of each window (in pixels)

print(f"Window size: {window_size} × {window_size} pixels")
print(f"Buffer size: {buffer_size} pixels on each side")
print(f"Total window dimensions: {window_size + 2*buffer_size} × {window_size + 2*buffer_size}")

# Pad the raster to handle boundaries
quality_padded = padding(quality, buffer_size, window_size)

print(f"\nOriginal raster shape: {quality.shape}")
print(f"Padded raster shape: {quality_padded.shape}")
pad_rows = quality_padded.shape[0] - quality.shape[0]
pad_cols = quality_padded.shape[1] - quality.shape[1]
print(f"Padding added: {pad_rows} rows, {pad_cols} columns")

## Initialize WindowOperation

The `WindowOperation` class manages the logistics of dividing a raster into windows, tracking their positions, and providing iterators for sequential or batch processing. By specifying the padded raster shape along with window and buffer sizes, we create a systematic grid that tiles the entire landscape with the specified overlap.

In [None]:
# Initialize window operation manager
window_op = WindowOperation(
    shape=quality_padded.shape,
    window_size=window_size,
    buffer_size=buffer_size
)

print(f"Total number of windows: {window_op.nb_steps}")
print(f"\nEach window covers {window_size}² = {window_size**2} core pixels")
# Buffer ring = full window area minus core area
print(f"Plus {(window_size + 2*buffer_size)**2 - window_size**2} buffer pixels")

## Lazy Iterator: Memory-Efficient Processing

The **lazy iterator** yields windows one at a time, making it ideal for memory-constrained environments or when processing can be done sequentially. Each iteration provides the window's `(x, y)` coordinates (top-left corner position) and the corresponding data array. This approach is particularly useful when each window requires expensive computations that benefit from not holding all windows in memory simultaneously.

The lazy iterator is implemented as a Python generator, meaning windows are extracted on-demand rather than pre-computed. This allows processing arbitrarily large rasters as long as individual windows fit in memory.
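
To make the generator idea concrete, here is a minimal NumPy sketch of on-demand window extraction. It is not JAXScape's implementation: `iter_windows` is a hypothetical helper, and treating `x` as the row index of the top-left corner is an assumption of this sketch.

```python
import numpy as np

def iter_windows(raster, window_size, buffer_size):
    """Yield ((x, y), window) pairs one at a time (hypothetical helper)."""
    total = window_size + 2 * buffer_size
    n_rows = (raster.shape[0] - 2 * buffer_size) // window_size
    n_cols = (raster.shape[1] - 2 * buffer_size) // window_size
    for i in range(n_rows):
        for j in range(n_cols):
            # Top-left corner of the window, buffer included (x = row, y = column here).
            x, y = i * window_size, j * window_size
            yield (x, y), raster[x:x + total, y:y + total]

raster = np.arange(12 * 12, dtype=np.float32).reshape(12, 12)
gen = iter_windows(raster, window_size=4, buffer_size=2)

# Nothing is extracted until the generator is advanced.
first_xy, first_window = next(gen)
remaining = list(gen)

print(first_xy, first_window.shape)  # (0, 0) (8, 8)
print(len(remaining))                # 3
```

Because each window is only sliced when requested, peak memory is bounded by one window plus whatever the per-window computation allocates.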

In [None]:
# Visualize all windows using lazy iteration
from matplotlib.patches import Rectangle

n_windows = window_op.nb_steps
n_cols = 4
n_rows = -(-n_windows // n_cols)  # ceiling division

# squeeze=False always returns a 2D array of axes, so flatten() is safe
fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 3*n_rows), squeeze=False)
axs = axs.flatten()

for i, (xy, window) in enumerate(window_op.lazy_iterator(quality_padded)):
    ax = axs[i]
    im = ax.imshow(window, cmap='viridis')
    ax.set_title(f'Window {i+1}\nPosition: ({xy[0]}, {xy[1]})', fontsize=10)
    ax.axis('off')

    # Draw rectangle showing core area (excluding buffer)
    rect = Rectangle(
        (buffer_size, buffer_size),
        window_size, window_size,
        linewidth=2, edgecolor='red', facecolor='none'
    )
    ax.add_patch(rect)

# Hide unused subplots
for j in range(i+1, len(axs)):
    axs[j].axis('off')

plt.tight_layout()
plt.savefig('windows.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\n✓ Processed {i+1} windows using lazy iteration")
print("Red rectangles show core areas (buffers excluded)")

## Eager Iterator: Vectorized Batch Processing

The **eager iterator** loads all windows into memory at once, returning two arrays: window coordinates and window data. This approach enables vectorized operations across all windows simultaneously, which can substantially accelerate processing when GPU or TPU acceleration is available through JAX. The trade-off is higher memory consumption, because every window (buffers included) is held in memory at the same time.

The returned arrays have shapes `(n_windows, 2)` for coordinates and `(n_windows, height, width)` for window data, making them ready for `jax.vmap` or other batch processing operations.
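
The array layout described above can be reproduced with plain NumPy. This is only a sketch of the data structure, not of how JAXScape builds it; the small 12 × 12 raster and the parameter values are made up for illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

window_size, buffer_size = 4, 2
total = window_size + 2 * buffer_size
raster = np.arange(12 * 12, dtype=np.float32).reshape(12, 12)

# All (total x total) views, then keep every window_size-th one along each axis.
views = sliding_window_view(raster, (total, total))[::window_size, ::window_size]
n_rows, n_cols = views.shape[:2]

# Stack into (n_windows, height, width), matching the layout described above.
windows = views.reshape(n_rows * n_cols, total, total)

# Matching (n_windows, 2) array of top-left corners.
rows, cols = np.meshgrid(np.arange(n_rows), np.arange(n_cols), indexing="ij")
xy_coords = np.stack([rows.ravel(), cols.ravel()], axis=1) * window_size

# One reduction over the trailing axes gives a statistic per window.
window_means = windows.mean(axis=(1, 2))

print(windows.shape)    # (4, 8, 8)
print(xy_coords.shape)  # (4, 2)
```

With the leading axis indexing windows, any function written for a single `(height, width)` array can be mapped over that axis, which is the pattern `jax.vmap` automates.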

In [None]:
# Load all windows at once for batch processing
xy_coords, windows = window_op.eager_iterator(quality_padded)

print(f"Coordinates array shape: {xy_coords.shape}")
print(f"Windows array shape: {windows.shape}")
print(f"\nMemory usage: ~{windows.nbytes / 1e6:.2f} MB for window data")

# Example: compute statistics across all windows in parallel
window_means = windows.mean(axis=(1, 2))
window_stds = windows.std(axis=(1, 2))

print(f"\nWindow statistics:")
for i, (mean, std) in enumerate(zip(window_means, window_stds)):
    print(f"  Window {i+1}: mean={mean:.3f}, std={std:.3f}")

## Updating the Raster with Processed Windows

After processing individual windows, we often need to update the original raster with modified values. The `update_raster_with_focal_window` method handles this by replacing a specific window's core area (excluding buffers) with new data. This is essential for workflows where each window undergoes transformation—such as connectivity calculations or filtering operations—and results must be assembled back into a seamless raster.

The method automatically handles buffer zones, ensuring only the core region is written and that overlapping edges between adjacent windows are properly managed. This maintains spatial continuity across window boundaries.
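
The core-only write can be illustrated with a few lines of NumPy. `write_core` below is a hypothetical helper written for this explanation, assuming `xy` is the top-left corner of the full window (buffer included) with `x` as the row index; it is not the library's implementation.

```python
import numpy as np

def write_core(raster, xy, window, window_size, buffer_size):
    """Return a copy of raster with only the window's core written (hypothetical helper)."""
    x, y = xy
    b, w = buffer_size, window_size
    out = raster.copy()
    # Skip the buffer ring on both the source window and the destination raster.
    out[x + b:x + b + w, y + b:y + b + w] = window[b:b + w, b:b + w]
    return out

raster = np.zeros((12, 12), dtype=np.float32)
new_window = np.ones((8, 8), dtype=np.float32)
updated = write_core(raster, (4, 4), new_window, window_size=4, buffer_size=2)

print(int(updated.sum()))  # 16
```

Only the 4 × 4 core changes, even though the replacement window is 8 × 8. Since adjacent cores do not overlap, writing windows in any order gives the same result.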

In [None]:
# Example: Replace window 3's core area with ones (demonstrating update)
window_index = 2  # Zero-indexed: window 3
new_window = jnp.ones(windows[window_index].shape, dtype="float32")

# Update the padded raster
updated_raster = window_op.update_raster_with_focal_window(
    xy_coords[window_index], 
    quality_padded, 
    new_window
)

print(f"Updated window {window_index + 1} at position {xy_coords[window_index]}")
print(f"Modified pixels: {window_size**2} (core area only)")

In [None]:
# Visualize the update
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

# Original padded raster
im1 = ax1.imshow(quality_padded, cmap='viridis')
ax1.set_title('Original (Padded)', fontsize=12)
ax1.axis('off')
plt.colorbar(im1, ax=ax1, shrink=0.7)

# Updated window (before integration)
im2 = ax2.imshow(new_window, cmap='viridis')
ax2.set_title(f'Modified Window {window_index + 1}', fontsize=12)
ax2.axis('off')
# Draw core area
from matplotlib.patches import Rectangle
rect = Rectangle(
    (buffer_size, buffer_size),
    window_size, window_size,
    linewidth=2, edgecolor='red', facecolor='none'
)
ax2.add_patch(rect)
plt.colorbar(im2, ax=ax2, shrink=0.7)

# Updated raster
im3 = ax3.imshow(updated_raster, cmap='viridis')
ax3.set_title('Updated Raster', fontsize=12)
ax3.axis('off')
plt.colorbar(im3, ax=ax3, shrink=0.7)

plt.tight_layout()
plt.savefig('new_raster.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\n✓ Window {window_index + 1}'s core area replaced with ones (value = 1.0)")
print("Buffer zones remain unchanged to maintain spatial continuity")

## Example: Parallel Processing Pipeline

Here we demonstrate a complete pipeline that processes all windows in parallel using JAX's vectorization. This example applies a simple transformation (contrast enhancement) to each window, but the pattern extends to any per-window operation including distance calculations, filtering, or local connectivity metrics.

In [None]:
import jax

def process_window(window):
    """Example processing: enhance contrast via histogram stretching."""
    # Extract core region (exclude buffers)
    core = window[buffer_size:-buffer_size, buffer_size:-buffer_size]
    
    # Contrast enhancement
    min_val, max_val = core.min(), core.max()
    core_enhanced = (core - min_val) / (max_val - min_val + 1e-8)
    
    # Reconstruct full window (keep buffers unchanged)
    enhanced_window = window.at[
        buffer_size:-buffer_size, 
        buffer_size:-buffer_size
    ].set(core_enhanced)
    
    return enhanced_window

# Vectorize processing across all windows
process_all = jax.vmap(process_window)

# Process all windows in parallel
processed_windows = process_all(windows)

print(f"Processed {len(windows)} windows in parallel")
print(f"Output shape: {processed_windows.shape}")

In [None]:
# Reassemble processed windows into final raster
result_raster = quality_padded.copy()

for i, xy in enumerate(xy_coords):
    result_raster = window_op.update_raster_with_focal_window(
        xy, result_raster, processed_windows[i]
    )

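# Remove padding to get final result (assumes buffer_size pixels were added
# on every side, per the padded_size relation given earlier)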
final_result = result_raster[
    buffer_size:-buffer_size,
    buffer_size:-buffer_size
]

print(f"Final result shape: {final_result.shape}")
print(f"Matches original: {final_result.shape == quality.shape}")

In [None]:
# Compare original and processed rasters
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 5))

im1 = ax1.imshow(quality, cmap='viridis')
ax1.set_title('Original Raster', fontsize=13, pad=15)
ax1.axis('off')
plt.colorbar(im1, ax=ax1, shrink=0.7, label='Quality')

im2 = ax2.imshow(final_result, cmap='viridis')
ax2.set_title('Processed (Contrast Enhanced)', fontsize=13, pad=15)
ax2.axis('off')
plt.colorbar(im2, ax=ax2, shrink=0.7, label='Enhanced')

# Difference map
diff = jnp.abs(final_result - quality)
im3 = ax3.imshow(diff, cmap='hot')
ax3.set_title('Absolute Difference', fontsize=13, pad=15)
ax3.axis('off')
plt.colorbar(im3, ax=ax3, shrink=0.7, label='|Δ|')

plt.tight_layout()
plt.show()

print(f"\nProcessing statistics:")
print(f"  Mean absolute change: {diff.mean():.4f}")
print(f"  Max absolute change: {diff.max():.4f}")

## Key Takeaways

### Moving Window Fundamentals

Moving window operations provide a systematic framework for processing large rasters by dividing them into manageable, overlapping tiles. The **buffer zones** are critical for maintaining spatial continuity—they provide necessary context for edge pixels but are excluded when assembling final results. **Padding** the original raster ensures windows near boundaries have complete neighborhoods, preventing edge artifacts in downstream analyses. The choice between lazy and eager iteration depends on the balance between memory constraints and computational parallelism: lazy iteration conserves memory by processing one window at a time, while eager iteration enables GPU-accelerated batch processing through vectorization.

### Practical Applications

This approach is essential for **large-scale connectivity analysis**, where computing all-pairs distances for millions of pixels would exhaust memory, but dividing the landscape into windows makes the problem tractable. It supports **parallel computation** by distributing independent windows across multiple processors or devices, with JAX's `vmap` providing seamless vectorization. The technique enables **focal operations** that depend on local neighborhoods, such as moving averages, edge detection, or local connectivity metrics. It also facilitates **progressive processing** where results from one window inform calculations in adjacent windows, enabling iterative refinement of landscape-scale metrics.

### Implementation Considerations

Several practical factors affect performance and accuracy. **Window size** should match the spatial scale of the ecological process: too small and you lose context, too large and computation becomes prohibitive. **Buffer size** must cover the maximum distance or neighborhood size used by your operation, so it is typically set to the dispersal distance or kernel radius. **Overlap handling** requires careful attention when reassembling results: the `update_raster_with_focal_window` method manages this automatically by only writing core regions, but custom pipelines need explicit edge coordination. **Memory-computation trade-offs** are central: eager iteration's memory cost grows linearly with window count, but batching windows can yield large speedups on modern accelerators.
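
The buffer-size rule can be checked numerically. The sketch below uses plain NumPy and two hypothetical helpers (`focal_mean`, `tiled_focal_mean`) on random data to compare a tiled moving average against the same operation on the whole raster. It assumes a square kernel of radius 2 and is meant only to illustrate the principle.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def focal_mean(a, radius):
    """Moving average over a (2*radius+1)^2 neighborhood, zero-padded at the edges."""
    k = 2 * radius + 1
    padded = np.pad(a, radius, mode="constant")
    return sliding_window_view(padded, (k, k)).mean(axis=(-2, -1))

def tiled_focal_mean(raster, window_size, buffer_size, radius):
    """Apply focal_mean per window and keep only each window's core."""
    b, w = buffer_size, window_size
    total = w + 2 * b
    out = np.zeros_like(raster)
    for x in range(0, raster.shape[0] - 2 * b, w):
        for y in range(0, raster.shape[1] - 2 * b, w):
            tile = focal_mean(raster[x:x + total, y:y + total], radius)
            out[x + b:x + b + w, y + b:y + b + w] = tile[b:b + w, b:b + w]
    return out

rng = np.random.default_rng(0)
raster = rng.random((24, 24))
reference = focal_mean(raster, radius=2)

def max_core_error(buffer_size):
    # Two windows per axis: 24 = 2 * buffer_size + 2 * window_size
    window_size = (24 - 2 * buffer_size) // 2
    tiled = tiled_focal_mean(raster, window_size, buffer_size, radius=2)
    inner = (slice(buffer_size, -buffer_size),) * 2
    return np.abs(tiled[inner] - reference[inner]).max()

error_large_buffer = max_core_error(buffer_size=2)  # buffer == kernel radius
error_small_buffer = max_core_error(buffer_size=1)  # buffer < kernel radius

print(error_large_buffer < 1e-12)  # True
print(error_small_buffer > 1e-6)   # True
```

When the buffer is at least as wide as the kernel radius, every core pixel sees its full neighborhood inside the tile, so the tiled result matches the global one up to rounding. With a narrower buffer, pixels along tile seams pick up edge artifacts.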

### Extensions and Advanced Usage

The basic moving window framework can be extended in several directions. **Adaptive window sizing** could vary window dimensions based on local landscape heterogeneity or data density, allocating computational resources where they're most needed. **Multi-scale analysis** might process the same landscape at different window sizes to capture phenomena operating at multiple spatial scales simultaneously. **Streaming workflows** could combine lazy iteration with progressive disk writes to process rasters larger than available RAM. **Hierarchical tiling** might use coarse windows to identify regions of interest, then apply fine-grained windows only where needed, optimizing the computation-accuracy trade-off for large-scale applications.
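
As a rough illustration of the streaming idea, the sketch below writes each processed core directly into a disk-backed NumPy array. The file name, the doubling "computation", and the in-memory input raster are all placeholders; a real workflow would also read the input lazily, and nothing here reflects a JAXScape API.

```python
import os
import tempfile
import numpy as np

window_size, buffer_size = 4, 2
total = window_size + 2 * buffer_size
raster = np.arange(12 * 12, dtype=np.float32).reshape(12, 12)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "result.npy")  # hypothetical output file
    # Disk-backed output: written tile by tile instead of being built in RAM.
    out = np.lib.format.open_memmap(
        path, mode="w+", dtype=np.float32, shape=raster.shape
    )
    out[:] = 0
    b, w = buffer_size, window_size
    for x in range(0, raster.shape[0] - 2 * b, w):
        for y in range(0, raster.shape[1] - 2 * b, w):
            window = raster[x:x + total, y:y + total]
            processed = window * 2.0  # placeholder per-window computation
            out[x + b:x + b + w, y + b:y + b + w] = processed[b:b + w, b:b + w]
    out.flush()
    del out  # release the memory map before reading the file back
    result = np.load(path)

print(result.shape)  # (12, 12)
```

Only one window and its processed copy are in memory at any time; the assembled raster lives on disk until it is explicitly loaded.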