# Connectivity and Sensitivity Analysis with JAXScape

## Overview

This notebook demonstrates how to use **JAXScape** to quantify landscape connectivity and perform **sensitivity analysis** to identify critical areas for conservation prioritization.

### Prerequisites:
```bash
pip install jaxscape rasterio matplotlib
```

In [None]:
import jax.numpy as jnp
import jax.random as jr
import jax
import matplotlib.pyplot as plt
import matplotlib
import rasterio

from jaxscape import LCPDistance, ConnectivityAnalysis, SensitivityAnalysis

## Load and Prepare Input Data

We load a habitat suitability raster and use it as both the quality (habitat value) and permeability (movement ease) rasters. In practice, you may want to use different rasters for these purposes:

- **Quality raster**: Represents habitat value, species density, or patch importance
- **Permeability raster**: Represents ease of movement through each cell (higher values = easier movement, i.e. lower movement cost)

The raster values are normalized to the [0, 1] range, with 0 representing unsuitable habitat and 1 representing optimal conditions.
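
The loading cell in this notebook divides by 100, which only yields [0, 1] if the suitability scores range from 0 to 100. When the range of the raster is not known in advance, min-max scaling is one alternative. The sketch below is a minimal illustration on a synthetic masked array; `normalize_raster` and the `-9999` nodata value are hypothetical and not part of JAXScape or rasterio.

```python
import numpy as np
import jax.numpy as jnp

def normalize_raster(masked, fill_value=0.0):
    """Min-max scale the valid cells to [0, 1]; masked cells get fill_value.

    Assumes the valid cells are not all equal (otherwise the range is zero).
    """
    valid_min = masked.min()  # min over unmasked cells only
    valid_max = masked.max()  # max over unmasked cells only
    scaled = (masked - valid_min) / (valid_max - valid_min)
    return jnp.asarray(scaled.filled(fill_value), dtype=jnp.float32)

# Synthetic 2x2 raster with one nodata cell (hypothetical nodata value)
raw = np.ma.masked_values(np.array([[0.0, 50.0], [100.0, -9999.0]]), -9999.0)
quality = normalize_raster(raw)
print(quality)
```

Filling nodata cells with 0 treats them as unsuitable habitat, which matches the `raster.filled(0)` call used in this notebook.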

In [None]:
# Load habitat suitability raster
with rasterio.open("../data/suitability.tif") as src:
    raster = src.read(1, masked=True)  # Read first band with masking
    quality_raster = jnp.array(raster.filled(0), dtype="float32") / 100  # Scale to [0, 1]; assumes suitability values range from 0 to 100

# Visualize the landscape
plt.figure(figsize=(10, 8))
plt.imshow(quality_raster, cmap='viridis')
plt.colorbar(label='Habitat Quality/Permeability', shrink=0.7)
plt.title('Input Landscape')
plt.axis('off')
plt.show()

## Define Dispersal Parameters

We need to specify:

1. **Dispersal range (D)**: Maximum distance an individual can traverse (in pixels) through optimal habitat (permeability = 1)
2. **Distance metric**: Method to calculate effective distance (here we use Least-Cost Path)
3. **Proximity function**: How distance translates to connectivity

Common proximity functions:
- **Negative exponential**: `exp(-d/D)` - smooth decay (used here)
- **Threshold**: `(d < D).astype(float)` - binary connectivity
- **Power law**: `d**(-α)` - fat-tailed dispersal
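
The three proximity functions listed above can be sketched in JAX as follows. The function names, the exponent `ALPHA`, and the `eps` offset in the power law (added here to avoid dividing by zero at distance 0) are assumptions made for illustration.

```python
import jax.numpy as jnp

D = 20.0      # dispersal range in pixels, as in this notebook
ALPHA = 2.0   # hypothetical power-law exponent

def exponential_proximity(dist):
    """Smooth decay: 1 at distance 0, exp(-1) at distance D."""
    return jnp.exp(-dist / D)

def threshold_proximity(dist):
    """Binary connectivity: 1 within D, 0 otherwise."""
    return (dist < D).astype(jnp.float32)

def power_law_proximity(dist, eps=1.0):
    """Fat-tailed decay; eps is an assumed offset that keeps dist = 0 finite."""
    return (dist + eps) ** (-ALPHA)

d = jnp.array([0.0, 10.0, 20.0, 40.0])
for f in (exponential_proximity, threshold_proximity, power_law_proximity):
    print(f.__name__, f(d))
```

One design consideration: the threshold function is piecewise constant, so its gradient is zero almost everywhere. It is therefore a poor fit for gradient-based sensitivity analysis, whereas the exponential and power-law forms are differentiable.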

In [None]:
# Define dispersal range (in pixels)
D = 20  # Maximum dispersal distance through optimal habitat

# Initialize distance metric
distance = LCPDistance()  # Least-Cost Path distance

# Define proximity function: converts distance to connectivity
def proximity(dist):
    """Negative exponential decay of connectivity with distance."""
    return jnp.exp(-dist / D)

## Calculate Landscape Connectivity

`ConnectivityAnalysis` computes the overall landscape connectivity by summing quality-weighted proximities between all pairs of habitat cells. This produces a scalar measure of total connectivity.
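
To make the "sum of quality-weighted proximities over all pairs" concrete, here is a toy dense version for a tiny grid. It is a sketch under stated assumptions, not JAXScape's implementation: it uses Euclidean distance as a stand-in for least-cost-path distance, includes each cell paired with itself, and `toy_connectivity` is a hypothetical name.

```python
import jax.numpy as jnp

def toy_connectivity(quality, D=2.0):
    """Sum over all cell pairs (i, j) of q_i * q_j * exp(-d_ij / D)."""
    rows, cols = quality.shape
    yy, xx = jnp.meshgrid(jnp.arange(rows), jnp.arange(cols), indexing="ij")
    coords = jnp.stack([yy.ravel(), xx.ravel()], axis=1).astype(jnp.float32)
    # Pairwise Euclidean distances (assumed stand-in for LCP distance)
    diff = coords[:, None, :] - coords[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff**2, axis=-1))
    proximity = jnp.exp(-dist / D)
    q = quality.ravel()
    return q @ proximity @ q

quality = jnp.array([[1.0, 0.5], [0.0, 1.0]])
print(toy_connectivity(quality))
```

The dense pairwise matrix grows with the square of the number of cells, which is why a real analysis on a full raster needs batching and a limited dependency range.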

**Key parameters:**
- `quality_raster`: Habitat quality or patch importance
- `permeability_raster`: Ease of movement through each cell (higher values = easier movement)
- `distance`: Distance metric (LCP, Resistance, or RSP)
- `proximity`: Function converting distance to connectivity
- `dependency_range`: Maximum relevant distance (optimization parameter)
- `batch_size`: Parallel processing batch size (tune based on available memory)
- `coarsening_factor`: Spatial aggregation factor (0.0 = no aggregation)
- `q_weighted`: Whether to weight by quality (True) or not (False)

In [None]:
# Initialize connectivity analysis
connectivity_prob = ConnectivityAnalysis(
    quality_raster=quality_raster,
    permeability_raster=quality_raster,
    distance=distance,
    proximity=proximity,
    coarsening_factor=0.,  # No spatial aggregation
    dependency_range=D,
    batch_size=50  # Process 50 cells per batch
)

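# Compute connectivity unweighted by quality (q_weighted=False); the strategy comparison later in this notebook uses q_weighted=True, so the two values are not directly comparable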
connectivity = connectivity_prob.run(q_weighted=False)
print(f"Baseline landscape connectivity: {connectivity:.0f}")

## Sensitivity Analysis

`SensitivityAnalysis` computes the derivative of landscape connectivity with respect to either permeability or quality. This identifies which landscape cells, when improved, would most increase overall connectivity.

The output is a raster where each cell value represents the marginal effect of improving that cell on total connectivity.
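
As a rough illustration of what "derivative of connectivity with respect to permeability" means, the sketch below differentiates a toy connectivity function with `jax.grad`. It does not use JAXScape: `chain_connectivity`, the 1D chain of cells, and the step cost `2 / (p_i + p_j)` between neighbours are assumptions chosen so that the least-cost path is simply the walk along the chain.

```python
import jax
import jax.numpy as jnp

def chain_connectivity(permeability, D=2.0):
    """Toy connectivity on a 1D chain of cells (hypothetical model)."""
    # Cost of stepping between neighbours: inverse of their mean permeability
    step_cost = 2.0 / (permeability[:-1] + permeability[1:])
    # Cumulative cost from the first cell gives each cell's position
    pos = jnp.concatenate([jnp.zeros(1), jnp.cumsum(step_cost)])
    dist = jnp.abs(pos[:, None] - pos[None, :])
    return jnp.sum(jnp.exp(-dist / D))

# A chain with a low-permeability bottleneck in the middle
p = jnp.array([0.9, 0.8, 0.2, 0.8, 0.9])
sensitivity = jax.grad(chain_connectivity)(p)
print(sensitivity)
```

Each entry of `sensitivity` is the marginal change in total connectivity per unit increase in that cell's permeability. In this toy chain, the bottleneck cell receives a larger value than the end cells because every path between the two halves crosses it.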

In [None]:
# Initialize sensitivity analysis
sensitivity_prob = SensitivityAnalysis(
    quality_raster=quality_raster,
    permeability_raster=quality_raster,
    distance=distance,
    proximity=proximity,
    coarsening_factor=0.,
    dependency_range=D,
    batch_size=20  # Smaller batch size for memory efficiency
)

# Compute sensitivity with respect to permeability
sensitivity_permeability = sensitivity_prob.run("permeability", q_weighted=True)
print(f"Sensitivity raster computed with shape: {sensitivity_permeability.shape}")

## Visualize Elasticity

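We convert sensitivity to an **elasticity**-style score: the change in connectivity produced by a proportional change in permeability. This is computed as:

$$\text{Elasticity} = \frac{\partial C}{\partial p} \times p$$

where $C$ is connectivity and $p$ is permeability.

Strictly, the dimensionless elasticity (percent change in connectivity per percent change in permeability) is $\frac{\partial C}{\partial p} \times \frac{p}{C}$. Because $C$ is a single scalar for the whole landscape, dividing by it rescales every cell equally and leaves the ranking of cells unchanged, so the unnormalized form above is sufficient for prioritization. It is easier to interpret than raw sensitivity because it accounts for each cell's current permeability.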

**High elasticity areas** are the most impactful for conservation - small improvements there yield large connectivity gains.
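
The relationship between sensitivity, the score $\frac{\partial C}{\partial p} \times p$, and its version normalized by $C$ can be checked on a small example. The sketch below uses a hypothetical stand-in for connectivity, $C(p) = (\sum_i p_i)^2$, chosen because it is homogeneous of degree 2, so the normalized elasticities must sum to 2.

```python
import jax
import jax.numpy as jnp

def toy_connectivity(p):
    """Hypothetical stand-in: sum of all pairwise products p_i * p_j."""
    return jnp.sum(jnp.outer(p, p))

p = jnp.array([0.2, 0.5, 0.9])
C = toy_connectivity(p)
sensitivity = jax.grad(toy_connectivity)(p)   # dC/dp per cell
elasticity = sensitivity * p                  # score used in this notebook
normalized_elasticity = elasticity / C        # percent change per percent change

print(elasticity, normalized_elasticity)
```

Dividing by `C` rescales all cells by the same constant, so the ranking of cells is identical for `elasticity` and `normalized_elasticity`.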

In [None]:
# Compute elasticity: sensitivity × current permeability
elasticity = sensitivity_permeability * quality_raster
elasticity = jnp.nan_to_num(elasticity, nan=0.0)  # Replace NaN with 0

# Visualize elasticity on log scale (highlights variation)
plt.figure(figsize=(12, 10))
plt.imshow(
    elasticity + 1e-2,  # Add small constant for log scale
    cmap='plasma',
    norm=matplotlib.colors.LogNorm(vmin=1e0)
)
plt.axis('off')
cbar = plt.colorbar(shrink=0.6)
cbar.set_label('Elasticity w.r.t Permeability', fontsize=12)
plt.title('Conservation Priority: Elasticity Map', fontsize=14, pad=20)
plt.tight_layout()
plt.savefig('elasticity_permeability.png', dpi=300, bbox_inches='tight')
plt.show()

print("High elasticity areas (bright colors) have the greatest conservation impact!")

## Conservation Prioritization: Comparing Strategies

We now compare two prioritization strategies for habitat restoration:

1. **Elasticity-based**: Target the top 5% of cells by elasticity
2. **Random**: Target random cells (control)

For each strategy, we simulate improving permeability by 0.4 units (40% of the scale) and measure the resulting gain in landscape connectivity.

This demonstrates the practical value of sensitivity analysis for conservation decision-making.
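
A percentile threshold can select more than 5% of cells when many cells share the same elasticity value, and adding 0.4 to a cell can push it above 1. The sketch below shows one possible variant that selects an exact number of cells with `jax.lax.top_k` and clips the result to [0, 1]. `restore_top_cells` is a hypothetical helper, and clipping is an assumption about the intended permeability range rather than something this notebook's cells do.

```python
import jax
import jax.numpy as jnp

def restore_top_cells(permeability, elasticity, fraction=0.05, increment=0.4):
    """Raise permeability of the top `fraction` of cells by `increment`."""
    k = max(1, int(round(fraction * elasticity.size)))
    # Indices of the k largest elasticity values in the flattened raster
    _, flat_idx = jax.lax.top_k(elasticity.ravel(), k)
    coords = jnp.unravel_index(flat_idx, elasticity.shape)
    restored = permeability.at[coords].add(increment)
    # Keep permeability within the normalized [0, 1] range (assumption)
    return jnp.clip(restored, 0.0, 1.0), coords

# Synthetic inputs for illustration only
permeability = jnp.full((10, 10), 0.8)
elasticity = jnp.arange(100.0).reshape(10, 10)
restored, coords = restore_top_cells(permeability, elasticity)
print(coords[0].size, float(restored.max()))
```

Fixing the number of restored cells makes the comparison with a random control straightforward, since both strategies then spend exactly the same budget.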

In [None]:
# Define improvement magnitude
improved_permeability = 0.4  # Permeability increase per restored cell

# Strategy 1: Target high elasticity cells (top 5%)
threshold = jnp.percentile(elasticity, 95)  # 95th percentile
high_sensitivity_coords = jnp.where(elasticity >= threshold)
improved_quality_raster = quality_raster.at[high_sensitivity_coords].add(improved_permeability)

print(f"Strategy 1: Restoring {high_sensitivity_coords[0].size} cells based on elasticity")

# Strategy 2: Target random cells (same number as Strategy 1)
key = jr.PRNGKey(0)
random_indices = jr.choice(
    key,
    jnp.arange(elasticity.size),
    shape=(high_sensitivity_coords[0].size,),
    replace=False
)
random_coords = jnp.unravel_index(random_indices, quality_raster.shape)
modified_quality_raster = quality_raster.at[random_coords].add(improved_permeability)

print(f"Strategy 2: Restoring {random_coords[0].size} random cells (control)")

## Evaluate Connectivity Gains

We now compute landscape connectivity for:
1. Baseline (no restoration)
2. Elasticity-based restoration
3. Random restoration

The connectivity gain is expressed as a percentage increase over baseline.

In [None]:
def run_connectivity_analysis(raster):
    """Helper function to compute connectivity for a given permeability raster."""
    connectivity_prob = ConnectivityAnalysis(
        quality_raster=quality_raster,
        permeability_raster=raster,
        distance=distance,
        proximity=proximity,
        coarsening_factor=0.,
        dependency_range=D,
        batch_size=50
    )
    return connectivity_prob.run(q_weighted=True)

# Compute connectivity for all scenarios
print("Computing connectivity for baseline...")
base_connectivity = run_connectivity_analysis(quality_raster)

print("Computing connectivity for elasticity-based restoration...")
connectivity_improved = run_connectivity_analysis(improved_quality_raster)

print("Computing connectivity for random restoration...")
connectivity_improved_randomly = run_connectivity_analysis(modified_quality_raster)

## Results: Prioritization Comparison

If elasticity is a useful guide, targeting high-elasticity cells should yield a larger connectivity gain than random selection for the same restoration budget (the same number of restored cells). The cell below reports the gain for each strategy.
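
A single random draw is a noisy control. One way to make the comparison more robust is to repeat the random strategy over many seeds and look at the distribution of gains. The sketch below does this with `jax.vmap` on a toy 1D connectivity model; `toy_connectivity` and `random_gain` are hypothetical stand-ins, and a real run would call the connectivity analysis instead, at a much higher cost per draw.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def toy_connectivity(p, D=2.0):
    """Toy connectivity on a 1D chain of cells (hypothetical model)."""
    step_cost = 2.0 / (p[:-1] + p[1:])
    pos = jnp.concatenate([jnp.zeros(1), jnp.cumsum(step_cost)])
    dist = jnp.abs(pos[:, None] - pos[None, :])
    return jnp.sum(jnp.exp(-dist / D))

def random_gain(key, p, n_cells, increment):
    """Percent connectivity gain from restoring n_cells random cells."""
    idx = jr.choice(key, p.size, shape=(n_cells,), replace=False)
    restored = jnp.clip(p.at[idx].add(increment), 0.0, 1.0)
    base = toy_connectivity(p)
    return (toy_connectivity(restored) - base) / base * 100

p = jnp.linspace(0.2, 0.9, 30)
keys = jr.split(jr.PRNGKey(0), 100)  # one key per random draw
gains = jax.vmap(lambda k: random_gain(k, p, 3, 0.4))(keys)
print(f"Random control: mean {gains.mean():.2f}%, std {gains.std():.2f}%")
```

Comparing the elasticity-based gain against the mean and spread of `gains` gives a clearer picture than a ratio against one random draw.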

In [None]:
# Calculate percent gains
elasticity_gain = (connectivity_improved - base_connectivity) / base_connectivity * 100
random_gain = (connectivity_improved_randomly - base_connectivity) / base_connectivity * 100

print("="*60)
print("LANDSCAPE CONNECTIVITY GAIN")
print("="*60)
print(f"Baseline connectivity:           {base_connectivity:.2f}")
print(f"\nElasticity-based restoration:    {connectivity_improved:.2f} (+{elasticity_gain:.2f}%)")
print(f"Random restoration (control):    {connectivity_improved_randomly:.2f} (+{random_gain:.2f}%)")
print(f"\nElasticity advantage:            {elasticity_gain / random_gain:.2f}x more effective")
print("="*60)

# Visualize comparison
fig, ax = plt.subplots(figsize=(8, 6))
strategies = ['Baseline', 'Elasticity-based', 'Random']
connectivities = [base_connectivity, connectivity_improved, connectivity_improved_randomly]
colors = ['gray', 'green', 'orange']

bars = ax.bar(strategies, connectivities, color=colors, alpha=0.7, edgecolor='black')
ax.set_ylabel('Landscape Connectivity', fontsize=12)
ax.set_title('Restoration Strategy Comparison', fontsize=14, pad=20)
ax.grid(axis='y', alpha=0.3)

# Add percentage labels
for i, (bar, val) in enumerate(zip(bars, connectivities)):
    if i > 0:
        gain = (val - base_connectivity) / base_connectivity * 100
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height(),
                f'+{gain:.1f}%', ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.show()

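# Report the outcome of this particular run instead of asserting it
if elasticity_gain > random_gain:
    print("\n✓ Elasticity-based prioritization outperformed random selection in this run.")
else:
    print("\n✗ Elasticity-based prioritization did not outperform random selection in this run.")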

## Key Takeaways

1. **Sensitivity analysis** identifies which landscape modifications have the greatest impact on connectivity
2. **Elasticity** provides a scale-independent measure for prioritization
3. **Targeted restoration** based on elasticity can outperform random selection for the same number of restored cells
4. **JAXScape** enables efficient computation of these metrics through JAX's automatic differentiation

This approach can be adapted to different:
- Distance metrics (LCP, Resistance, RSP)
- Proximity functions (exponential, threshold, power law)
- Dispersal ranges and species characteristics
- Conservation objectives (quality vs. permeability improvements)