# Moving window operations

JAXScape's `WindowOperation` class enables processing large rasters by dividing them into overlapping windows with configurable buffer zones. This approach makes memory-intensive operations tractable by working on manageable tiles rather than entire landscapes. The notebook demonstrates lazy iteration (memory-efficient, one window at a time) and eager iteration (GPU-accelerated batch processing), plus techniques for reassembling processed windows into seamless outputs.

## Prerequisites
```bash
pip install jaxscape rasterio matplotlib
```

In [None]:
import rasterio
import jax.numpy as jnp
import matplotlib.pyplot as plt

from jaxscape import WindowOperation, padding

## Load input raster

We load a habitat quality raster representing continuous suitability values across a landscape.

In [None]:
# Load habitat quality raster
with rasterio.open("suitability.tif") as src:
    raster = src.read(1, masked=True)  # first band as a masked array (no-data cells masked)
    quality = jnp.array(raster.filled(0), dtype="float32")  # no-data cells -> 0

print(f"Raster shape: {quality.shape}")
print(f"Value range: [{quality.min():.2f}, {quality.max():.2f}]")
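
Reading with `masked=True` returns a NumPy masked array in which no-data cells are masked. `.filled(0)` then substitutes 0 for those cells, so they enter the analysis as zero-quality habitat and can no longer be told apart from genuinely zero-quality cells. A small stand-alone illustration on a synthetic array (the `-9999` no-data value is an assumption for the example, not necessarily the value stored in `suitability.tif`):

```python
import numpy as np

# Synthetic 2x3 "raster" where -9999 marks no-data (assumed value for illustration)
data = np.array([[0.2, -9999.0, 0.8],
                 [0.5, 0.9, -9999.0]], dtype="float32")

masked = np.ma.masked_equal(data, -9999.0)  # mask no-data cells, much as masked=True does on read
filled = masked.filled(0)                   # plain ndarray with masked cells replaced by 0

print(masked.count())  # number of valid (unmasked) cells
print(filled)
```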

In [None]:
# Visualize the full raster
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(quality, cmap='viridis')
ax.set_title('Habitat Quality Raster', fontsize=14, pad=20)
ax.axis('off')
plt.colorbar(im, ax=ax, label='Quality', shrink=0.7)
plt.tight_layout()
plt.savefig('quality_raster.png', dpi=300, bbox_inches='tight')
plt.show()

## Configure window parameters

Moving windows require three key parameters: **window size** (tile dimensions for each focal region), **buffer size** (overlap zone around each window's core to handle edge effects), and **padding** (extends the raster so boundary windows have complete neighborhoods).

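The padded raster is at least `original_size + 2 × buffer_size` along each axis, and each window measures `window_size + 2 × buffer_size` cells per side. Because `padding` also receives `window_size`, it may add further cells so that the raster divides evenly into windows; the shapes printed below show what it does for your data. Buffer zones provide context for edge pixels but are excluded when the final result is assembled.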
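The window arithmetic can be checked in plain Python. This sketch rests on one stated assumption: that padding rounds the original extent up to a whole number of window cores and then adds a buffer on each side. `padded_length` and `window_grid` are hypothetical helpers, not JAXScape functions; compare their output with `quality_padded.shape` and `window_op.nb_steps` for your own raster.

```python
def padded_length(n: int, window_size: int, buffer_size: int) -> int:
    """Hypothetical: padded length along one axis, ASSUMING the n original cells
    are rounded up to whole window cores and a buffer is added on both sides."""
    core = -(-n // window_size) * window_size  # ceiling division, integers only
    return core + 2 * buffer_size


def window_grid(shape, window_size, buffer_size):
    """Windows per axis and in total for the padded raster (same assumption)."""
    per_axis = tuple(
        (padded_length(n, window_size, buffer_size) - 2 * buffer_size) // window_size
        for n in shape
    )
    return per_axis, per_axis[0] * per_axis[1]


shape = (123, 87)  # example raster size, not the notebook's file
window_size, buffer_size = 50, 10
print([padded_length(n, window_size, buffer_size) for n in shape])  # [170, 120]
print(window_grid(shape, window_size, buffer_size))                 # ((3, 2), 6)
print(window_size + 2 * buffer_size)                                # 70
```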

In [None]:
# Define window parameters
buffer_size = 10  # Overlap zone to handle edge effects
window_size = 50  # Core area of each window (in pixels)

print(f"Window size: {window_size} × {window_size} pixels")
print(f"Buffer size: {buffer_size} pixels on each side")
print(f"Total window dimensions: {window_size + 2*buffer_size} × {window_size + 2*buffer_size}")

# Pad the raster to handle boundaries
quality_padded = padding(quality, buffer_size, window_size)

print(f"\nOriginal raster shape: {quality.shape}")
print(f"Padded raster shape: {quality_padded.shape}")
print(f"Padding added (rows, cols): {tuple(p - o for p, o in zip(quality_padded.shape, quality.shape))} pixels")

## Initialize window operation

The `WindowOperation` class manages raster tiling, tracking window positions and providing iterators for sequential or batch processing.
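One way such position tracking can work is sketched below. It assumes each window is identified by the top-left corner of its buffered extent and that windows are enumerated row by row; JAXScape's actual ordering and coordinate convention may differ, so `window_offsets` is a hypothetical illustration rather than the library's implementation.

```python
def window_offsets(padded_shape, window_size, buffer_size):
    """Hypothetical: top-left corner of every (window_size + 2*buffer_size) window
    in the padded raster, row-major. Consecutive windows start window_size apart,
    so neighbouring windows overlap by 2*buffer_size cells."""
    n_rows = (padded_shape[0] - 2 * buffer_size) // window_size
    n_cols = (padded_shape[1] - 2 * buffer_size) // window_size
    return [
        (i * window_size, j * window_size)
        for i in range(n_rows)
        for j in range(n_cols)
    ]


offsets = window_offsets((170, 120), window_size=50, buffer_size=10)
print(offsets)  # [(0, 0), (0, 50), (50, 0), (50, 50), (100, 0), (100, 50)]
```

With a 70-cell total window, the last row of windows starts at 100 and ends exactly at the padded edge (170), so no window runs off the raster.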

In [None]:
# Initialize window operation manager
window_op = WindowOperation(
    shape=quality_padded.shape,
    window_size=window_size,
    buffer_size=buffer_size
)

print(f"Total number of windows: {window_op.nb_steps}")
print(f"Each window: {window_size}² core + {buffer_size}-pixel buffer on each side")

## Lazy iterator: memory-efficient processing

The lazy iterator yields windows one at a time through a Python generator, which suits memory-constrained environments and sequential processing. Each iteration yields a pair: the window's coordinates `xy` in the padded raster and the window's data array. Only one window's data is materialized per step, so per-step memory scales with the window size rather than the raster size (in this notebook the input raster itself is still loaded in full).
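The generator pattern itself is easy to sketch. `lazy_windows` below is a hypothetical NumPy stand-in, not JAXScape's `lazy_iterator`; it assumes the same top-left, row-by-row window convention and simply slices each buffered window on demand.

```python
import numpy as np


def lazy_windows(raster, window_size, buffer_size):
    """Hypothetical generator: yield ((x, y), window) one buffered window at a time."""
    total = window_size + 2 * buffer_size
    n_rows = (raster.shape[0] - 2 * buffer_size) // window_size
    n_cols = (raster.shape[1] - 2 * buffer_size) // window_size
    for i in range(n_rows):
        for j in range(n_cols):
            x, y = i * window_size, j * window_size
            # NumPy basic slicing returns a view, so the raster is not copied
            yield (x, y), raster[x:x + total, y:y + total]


raster = np.arange(170 * 120, dtype="float32").reshape(170, 120)
gen = lazy_windows(raster, window_size=50, buffer_size=10)
xy, first = next(gen)  # only the first window has been sliced so far
print(xy, first.shape)  # (0, 0) (70, 70)
```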

In [None]:
# Visualize all windows using lazy iteration
import math
from matplotlib.patches import Rectangle

n_windows = window_op.nb_steps
n_cols = 4
n_rows = math.ceil(n_windows / n_cols)

# squeeze=False always returns a 2D array of axes, so flatten() is always safe
fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 3 * n_rows), squeeze=False)
axs = axs.flatten()

for i, (xy, window) in enumerate(window_op.lazy_iterator(quality_padded)):
    ax = axs[i]
    ax.imshow(window, cmap='viridis')
    ax.set_title(f'Window {i+1}\nPosition: ({xy[0]}, {xy[1]})', fontsize=10)
    ax.axis('off')

    # Outline the core area (excluding the buffer)
    rect = Rectangle(
        (buffer_size, buffer_size),
        window_size, window_size,
        linewidth=2, edgecolor='red', facecolor='none'
    )
    ax.add_patch(rect)

# Hide unused subplots
for j in range(i+1, len(axs)):
    axs[j].axis('off')

plt.tight_layout()
plt.savefig('windows.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"Processed {i+1} windows (red rectangles = core areas)")

## Eager iterator: vectorized batch processing

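The eager iterator materializes all windows at once, returning a coordinate array of shape `(n_windows, 2)` and a data array of shape `(n_windows, window_size + 2 × buffer_size, window_size + 2 × buffer_size)`. With every window stacked along a leading batch axis, a per-window function can be vectorized with `jax.vmap` and run as a single batched computation, which typically uses a GPU or TPU far better than a Python loop over windows. The trade-off is memory: the whole stack, buffers included, must fit on the device at once.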
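The memory cost of the eager batch can be estimated before calling it. Because each window carries its buffer, neighbouring windows duplicate cells, so the stack is larger than the area it covers. A back-of-the-envelope sketch (the helper names are hypothetical; float32, 4 bytes per cell, is assumed to match the `dtype` used above):

```python
def eager_bytes(n_windows: int, window_size: int, buffer_size: int, itemsize: int = 4) -> int:
    """Bytes needed to hold every buffered window at once (float32 by default)."""
    total = window_size + 2 * buffer_size
    return n_windows * total * total * itemsize


def overlap_factor(window_size: int, buffer_size: int) -> float:
    """Cells stored per buffered window divided by cells in its core."""
    total = window_size + 2 * buffer_size
    return (total * total) / (window_size * window_size)


print(eager_bytes(6, 50, 10))    # 117600
print(overlap_factor(50, 10))    # 1.96
```

With a 50-pixel window and a 10-pixel buffer, the batch therefore holds about twice as many cells as the cores it covers; shrinking the buffer relative to the window reduces that overhead.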

In [None]:
# Load all windows at once for batch processing
xy_coords, windows = window_op.eager_iterator(quality_padded)

print(f"Coordinates shape: {xy_coords.shape}")
print(f"Windows shape: {windows.shape}")
print(f"Memory usage: ~{windows.nbytes / 1e6:.2f} MB")

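# Example: per-window statistics (buffers included), computed for all windows in one vectorized call
window_means = windows.mean(axis=(1, 2))
window_stds = windows.std(axis=(1, 2))
print(f"Per-window mean range: [{window_means.min():.2f}, {window_means.max():.2f}]")
print(f"Per-window std range: [{window_stds.min():.2f}, {window_stds.max():.2f}]")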

## Update raster with processed windows

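The `update_raster_with_focal_window` method takes a window's coordinates, a raster, and a full (buffered) window, and returns a raster in which that window's core area is replaced by the core of the new window. Buffer cells are discarded, so each output cell is written by exactly one window and neighbouring cores join without overlap. Because JAX arrays are immutable, the input raster is left unchanged; keep the returned value.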
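The write-back rule can be sketched in NumPy. `write_core` is a hypothetical stand-in, not the library method; it assumes `xy` is the top-left corner of the buffered window. It copies before writing to mimic JAX's functional semantics; in JAX the same write would be expressed with an indexed update such as `.at[...].set(...)`.

```python
import numpy as np


def write_core(raster, xy, window, window_size, buffer_size):
    """Hypothetical: return a copy of `raster` whose cells under the window's core
    are replaced by the core of `window`; the window's buffer cells are discarded."""
    x, y = xy
    b, w = buffer_size, window_size
    out = raster.copy()  # leave the input untouched, like an immutable JAX array
    out[x + b:x + b + w, y + b:y + b + w] = window[b:b + w, b:b + w]
    return out


raster = np.zeros((170, 120), dtype="float32")
window = np.ones((70, 70), dtype="float32")
updated = write_core(raster, (50, 50), window, window_size=50, buffer_size=10)
print(updated.sum())                     # 2500.0: only the 50x50 core was written
print(updated[59, 59], updated[60, 60])  # 0.0 1.0
```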

In [None]:
# Example: Replace window 3's core area with ones (demonstrating update)
window_index = 2  # Zero-indexed: window 3
new_window = jnp.ones(windows[window_index].shape, dtype="float32")

# Update the padded raster
updated_raster = window_op.update_raster_with_focal_window(
    xy_coords[window_index], 
    quality_padded, 
    new_window
)

print(f"Updated window {window_index + 1} at {xy_coords[window_index]}")

In [None]:
# Visualize the update
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

# Original padded raster
im1 = ax1.imshow(quality_padded, cmap='viridis')
ax1.set_title('Original (Padded)', fontsize=12)
ax1.axis('off')
plt.colorbar(im1, ax=ax1, shrink=0.7)

# Updated window (before integration)
im2 = ax2.imshow(new_window, cmap='viridis')
ax2.set_title(f'Modified Window {window_index + 1}', fontsize=12)
ax2.axis('off')
# Draw core area
from matplotlib.patches import Rectangle
rect = Rectangle(
    (buffer_size, buffer_size),
    window_size, window_size,
    linewidth=2, edgecolor='red', facecolor='none'
)
ax2.add_patch(rect)
plt.colorbar(im2, ax=ax2, shrink=0.7)

# Updated raster
im3 = ax3.imshow(updated_raster, cmap='viridis')
ax3.set_title('Updated Raster', fontsize=12)
ax3.axis('off')
plt.colorbar(im3, ax=ax3, shrink=0.7)

plt.tight_layout()
plt.savefig('new_raster.png', dpi=300, bbox_inches='tight')
plt.show()



## Parallel processing pipeline

We now process all windows in parallel using JAX's vectorization. This example applies a per-window contrast enhancement (a min-max stretch of each window's core), but the same pattern extends to any per-window operation, such as distance calculations, filtering, or connectivity metrics.
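What `jax.vmap(process_window)` computes can also be written with an explicit batch axis. The NumPy sketch below is a stand-in for illustration, not the notebook's JAX code: it reduces each core over axes 1 and 2 with `keepdims=True`, so every window's minimum and maximum broadcast back onto that window only. `stretch_cores` is a hypothetical helper name.

```python
import numpy as np


def stretch_cores(windows, buffer_size, window_size):
    """Min-max stretch of every window's core at once; buffers are left as-is.
    `windows` has shape (n_windows, total, total)."""
    b, w = buffer_size, window_size
    out = windows.copy()
    cores = windows[:, b:b + w, b:b + w]
    lo = cores.min(axis=(1, 2), keepdims=True)  # per-window minimum, shape (n, 1, 1)
    hi = cores.max(axis=(1, 2), keepdims=True)  # per-window maximum, shape (n, 1, 1)
    out[:, b:b + w, b:b + w] = (cores - lo) / (hi - lo + 1e-8)
    return out


rng = np.random.default_rng(0)
windows = rng.random((6, 70, 70)).astype("float32")
stretched = stretch_cores(windows, buffer_size=10, window_size=50)
core = stretched[:, 10:60, 10:60]
print(core.min(axis=(1, 2)), core.max(axis=(1, 2)))  # each window spans about [0, 1]
```

Because every window is stretched to its own range, adjacent cores generally end up on different scales, so a per-window normalization like this one can leave visible seams along window boundaries in the reassembled raster.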

In [None]:
import jax

def process_window(window):
    """Example processing: contrast enhancement via min-max stretching of the core."""
    # Extract core region (exclude buffers); explicit end indices also work when buffer_size == 0
    core_slice = slice(buffer_size, buffer_size + window_size)
    core = window[core_slice, core_slice]

    # Stretch the core to [0, 1]; the small epsilon avoids division by zero on flat windows
    min_val, max_val = core.min(), core.max()
    core_enhanced = (core - min_val) / (max_val - min_val + 1e-8)

    # Write the enhanced core back; buffers stay unchanged and are discarded on reassembly
    return window.at[core_slice, core_slice].set(core_enhanced)

# Vectorize processing across all windows
process_all = jax.vmap(process_window)

# Process all windows in parallel
processed_windows = process_all(windows)

print(f"Processed {len(windows)} windows in parallel")
print(f"Output shape: {processed_windows.shape}")

In [None]:
# Reassemble processed windows into final raster
result_raster = quality_padded.copy()

for i, xy in enumerate(xy_coords):
    result_raster = window_op.update_raster_with_focal_window(
        xy, result_raster, processed_windows[i]
    )

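# Remove padding: the original raster starts buffer_size cells in from the top-left;
# slicing to its exact shape also drops any extra padding added on the far edges
final_result = result_raster[
    buffer_size:buffer_size + quality.shape[0],
    buffer_size:buffer_size + quality.shape[1]
]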

print(f"Final result shape: {final_result.shape} (matches original: {final_result.shape == quality.shape})")
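
Because the cores tile the padded interior without overlapping, the reassembly loop can in principle be replaced by a single reshape, provided the windows are ordered row by row. That ordering is an assumption here, so check it against `xy_coords` before relying on it. `assemble_cores` is a hypothetical helper shown with NumPy; `reshape` and `transpose` behave the same way on `jax.numpy` arrays.

```python
import numpy as np


def assemble_cores(cores, n_rows, n_cols):
    """Stitch (n_rows*n_cols, w, w) core tiles, ordered row by row, into one
    (n_rows*w, n_cols*w) array without a Python loop over windows."""
    w = cores.shape[-1]
    return (
        cores.reshape(n_rows, n_cols, w, w)
        .transpose(0, 2, 1, 3)  # -> (tile_row, row_in_tile, tile_col, col_in_tile)
        .reshape(n_rows * w, n_cols * w)
    )


# Round trip on a synthetic padded raster: cut cores row by row, then stitch them back
b, w, n_rows, n_cols = 10, 50, 3, 2
padded = np.arange(
    (n_rows * w + 2 * b) * (n_cols * w + 2 * b), dtype="float32"
).reshape(n_rows * w + 2 * b, n_cols * w + 2 * b)
cores = np.stack([
    padded[b + i * w:b + (i + 1) * w, b + j * w:b + (j + 1) * w]
    for i in range(n_rows) for j in range(n_cols)
])
interior = assemble_cores(cores, n_rows, n_cols)
print(np.array_equal(interior, padded[b:-b, b:-b]))  # True
```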

In [None]:
# Compare original and processed rasters
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 5))

im1 = ax1.imshow(quality, cmap='viridis')
ax1.set_title('Original Raster', fontsize=13, pad=15)
ax1.axis('off')
plt.colorbar(im1, ax=ax1, shrink=0.7, label='Quality')

im2 = ax2.imshow(final_result, cmap='viridis')
ax2.set_title('Processed (Contrast Enhanced)', fontsize=13, pad=15)
ax2.axis('off')
plt.colorbar(im2, ax=ax2, shrink=0.7, label='Enhanced')

# Difference map
diff = jnp.abs(final_result - quality)
im3 = ax3.imshow(diff, cmap='hot')
ax3.set_title('Absolute Difference', fontsize=13, pad=15)
ax3.axis('off')
plt.colorbar(im3, ax=ax3, shrink=0.7, label='|Δ|')

plt.tight_layout()
plt.show()

print(f"\nProcessing statistics:")
print(f"  Mean absolute change: {diff.mean():.4f}")
print(f"  Max absolute change: {diff.max():.4f}")