# Inverse Landscape Modeling with JAXScape

## Overview

Landscape genetics aims to understand how landscape structure influences gene flow and genetic differentiation among populations. Traditional approaches assume resistance values for different habitat types, but **inverse modeling** offers a data-driven alternative: we infer landscape resistance patterns directly from observed genetic data.

This notebook demonstrates how to fit a neural network-based resistance model to genetic differentiation (Fst) measurements using JAXScape's differentiable distance metrics and gradient-based optimization. Rather than making *a priori* assumptions about which land-cover types facilitate or impede movement, we let the data reveal these patterns through automatic differentiation and iterative refinement. The approach leverages JAX's computational efficiency to make inverse modeling tractable for realistic landscape-scale problems, opening new possibilities for evidence-based conservation planning and understanding species-landscape relationships.
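The key idea is that a distance such as resistance distance is a differentiable function of the resistance values. The following is a minimal dense sketch of that idea on a tiny graph. It is not JAXScape's implementation, which works on sparse grid graphs with solvers such as `CholmodSolver`. It assumes that edge conductance = 1 / edge resistance, and it uses the identity that, for a connected graph with Laplacian $L$, $(L + \mathbf{1}\mathbf{1}^\top/n)^{-1} = L^{+} + \mathbf{1}\mathbf{1}^\top/n$. The constant term therefore cancels in $R_{ij} = M_{ii} + M_{jj} - 2M_{ij}$ with $M = (L + \mathbf{1}\mathbf{1}^\top/n)^{-1}$. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def graph_laplacian(n_nodes, edges, edge_resistance):
    """Weighted Laplacian of an undirected graph; conductance = 1 / resistance."""
    conductance = 1.0 / edge_resistance
    i, j = edges[:, 0], edges[:, 1]
    adj = jnp.zeros((n_nodes, n_nodes))
    adj = adj.at[i, j].add(conductance).at[j, i].add(conductance)
    return jnp.diag(adj.sum(axis=1)) - adj


def resistance_distance(lap):
    """All-pairs effective resistance of a connected graph (dense, small graphs only)."""
    n = lap.shape[0]
    # Adding the all-ones projector makes L invertible; its contribution cancels below.
    m = jnp.linalg.inv(lap + jnp.ones((n, n)) / n)
    d = jnp.diag(m)
    return d[:, None] + d[None, :] - 2.0 * m


# Toy example: a path graph 0 - 1 - 2 with edge resistances 2 and 3.
edges = jnp.array([[0, 1], [1, 2]])


def r_between_ends(edge_resistance):
    lap = graph_laplacian(3, edges, edge_resistance)
    return resistance_distance(lap)[0, 2]


r = jnp.array([2.0, 3.0])
print(r_between_ends(r))            # series resistances add: approximately 5.0
print(jax.grad(r_between_ends)(r))  # analytically, dR_02/dr_k = 1 for both edges
```

Because gradients flow through the linear solve, `jax.grad` reports how each resistance value changes the pairwise distance. That is the property an inverse model needs to fit resistances to genetic data.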

### Prerequisites
```bash
pip install jaxscape equinox optimistix rioxarray geopandas scikit-learn
```

In [None]:
import os
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import rioxarray
import xarray as xr
import geopandas as gpd
import equinox as eqx
from equinox import nn
import optimistix as optx

from jaxscape import GridGraph, ResistanceDistance, LCPDistance
from jaxscape.solvers import CholmodSolver
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.model_selection import train_test_split

# Seed NumPy's global RNG (JAX and scikit-learn take explicit seeds below)
np.random.seed(42)

### Configuration Parameters

The analysis requires several key hyperparameters. The **coarsening factor** controls spatial downsampling: each coarsened cell aggregates a block of `COARSENING_FACTOR × COARSENING_FACTOR` pixels, which reduces computational cost while preserving broad landscape patterns. We use a **CHOLMOD solver** (sparse Cholesky factorization) to compute resistance distances efficiently on large graphs. The **distance metric** defaults to `ResistanceDistance`, though `LCPDistance` can be substituted for experimentation. Finally, **max steps** caps the number of optimization iterations to bound runtime. Reaching this cap does not by itself guarantee convergence.
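As a rough guide to the savings, `boundary="trim"` (as used in `prepare_feature_targets`) keeps only complete blocks. An H × W raster therefore becomes (H // f) × (W // f) cells, and the node count drops by roughly f². The helpers below are hypothetical illustrations of that arithmetic.

```python
def coarsened_shape(height, width, factor):
    """Grid shape after block-mean coarsening with boundary='trim' (partial blocks dropped)."""
    return height // factor, width // factor


def node_reduction(height, width, factor):
    """Ratio of original to coarsened node counts."""
    h, w = coarsened_shape(height, width, factor)
    return (height * width) / (h * w)


# A 1,005 x 2,013 raster coarsened by 10 keeps 100 x 201 complete blocks.
print(coarsened_shape(1005, 2013, 10))           # (100, 201)
print(round(node_reduction(1000, 2000, 10), 1))  # 100.0
```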

In [None]:
# File paths (adjust to your data location)
LANDCOVER_PATH = '../data/cesar/landcover_7855.tif'
SITE_METADATA_PATH = '../data/cesar/cesar_site_metadata.gpkg'
GENETIC_DISTANCES_PATH = '../data/cesar/cesar_genetic_distances.npy'

# Model configuration
COARSENING_FACTOR = 10  # Spatial downsampling for computational efficiency
SOLVER = CholmodSolver()  # Fast linear solver for large graphs
DISTANCE_FUN = ResistanceDistance(solver=SOLVER)  # Effective resistance distance
MAX_STEPS = 500  # Maximum optimization iterations

# Alternative distance metric for experimentation:
# DISTANCE_FUN = LCPDistance()  # Least-cost path distance

## Load and Visualize Input Data

The analysis requires three key datasets. The **land-cover raster** from ESA WorldCover provides habitat type classifications across the study area. **Site metadata** contains the geographic coordinates of sampling locations where genetic data were collected. The **genetic distance matrix** stores pairwise Fst values (or similar differentiation metrics) quantifying how genetically distinct populations are at each pair of sites. This genetic differentiation reflects the cumulative effect of landscape resistance on gene flow, which our model will learn to predict from the spatial pattern of land-cover types.
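Because the model compares predicted distances against the upper triangle of this matrix, it is worth confirming that the loaded array is a valid pairwise distance matrix. It should be square, symmetric, and zero on the diagonal, with one row per site. The following is a minimal sketch; `check_distance_matrix` is a hypothetical helper, not part of JAXScape, and the Fst values are toy numbers.

```python
import numpy as np


def check_distance_matrix(dist, n_sites, atol=1e-8):
    """Return a list of problems found in a pairwise distance matrix (empty if none)."""
    dist = np.asarray(dist, dtype=float)
    if dist.ndim != 2 or dist.shape[0] != dist.shape[1]:
        return [f"expected a square matrix, got shape {dist.shape}"]
    problems = []
    if dist.shape[0] != n_sites:
        problems.append(f"{dist.shape[0]} rows but {n_sites} sites")
    if not np.allclose(dist, dist.T, atol=atol):
        problems.append("matrix is not symmetric")
    if not np.allclose(np.diag(dist), 0.0, atol=atol):
        problems.append("diagonal is not zero")
    if np.isnan(dist).any():
        problems.append("matrix contains NaN values")
    return problems


fst = np.array([[0.00, 0.02, 0.05],
                [0.02, 0.00, 0.03],
                [0.05, 0.03, 0.00]])
print(check_distance_matrix(fst, n_sites=3))  # []
```

With the notebook's data, the call would be `check_distance_matrix(genetic_distances, len(site_gdf))`.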

In [None]:
# Load land-cover raster as a Dataset with one variable per band ("band_1", ...)
predictor_raster = rioxarray.open_rasterio(
    LANDCOVER_PATH,
    mask_and_scale=True,
    band_as_variable=True,
)

# Load site metadata (geographic locations)
site_metadata = gpd.read_file(SITE_METADATA_PATH)
site_gdf = site_metadata.to_crs(epsg=7855)  # Reproject to GDA2020 / MGA zone 55 (EPSG:7855), matching the raster

# Load genetic distance matrix
genetic_distances = np.load(GENETIC_DISTANCES_PATH)

print(f"Land-cover raster shape: {predictor_raster['band_1'].shape}")
print(f"Number of sites: {len(site_gdf)}")
print(f"Genetic distance matrix shape: {genetic_distances.shape}")

In [None]:
# Visualize land-cover with sampling sites
fig, ax = plt.subplots(figsize=(12, 10))
predictor_raster["band_1"].plot(ax=ax, cmap="tab20", add_colorbar=True)
ax.scatter(
    site_gdf.geometry.x, 
    site_gdf.geometry.y, 
    c="red", 
    s=100, 
    edgecolor='white',
    linewidth=2,
    label="Sampling Sites",
    zorder=10
)
ax.set_title("Land-Cover Map with Sampling Sites", fontsize=14, pad=20)
ax.legend(loc='upper right', fontsize=12)
plt.tight_layout()
plt.show()

print(f"Sites: {', '.join(site_gdf['site_name'].values)}")

## Prepare Features and Targets

The land-cover data requires several preprocessing steps to create suitable model inputs. First, we **compress class IDs** by mapping the sparse WorldCover class codes (10, 20, ..., 100) to a contiguous range 0..K-1, so the one-hot encoding has exactly K channels instead of one per possible code. Next, we **one-hot encode** these categorical classes into binary feature vectors that the neural network can process. We then **coarsen via mean pooling**, which turns the one-hot vectors inside each coarsened cell into class proportions. This preserves habitat heterogeneity instead of forcing a single discrete class per cell. Finally, we **map site coordinates to grid nodes** by assigning each sampling location to its nearest coarsened grid cell. Together, these steps trade some spatial detail for computational efficiency while keeping the landscape composition most relevant to gene flow.
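The first three steps can be reproduced on a toy raster with NumPy alone. This makes it easy to see that mean pooling turns one-hot vectors into class proportions that sum to 1 in each coarsened cell. The sketch mirrors the notebook's `np.unique(..., return_inverse=True)` call. However, it replaces `xarray.coarsen(...).mean()` with an equivalent reshape-based block mean. The toy raster is made up.

```python
import numpy as np

# Toy land-cover raster with sparse WorldCover-style codes (4 x 6 pixels).
raw = np.array([
    [10, 10, 30, 30, 80, 80],
    [10, 10, 30, 80, 80, 80],
    [50, 50, 10, 10, 30, 30],
    [50, 10, 10, 10, 30, 30],
])

# 1. Compress sparse codes to contiguous IDs 0..K-1.
unique_vals, inverse = np.unique(raw, return_inverse=True)
class_ids = inverse.reshape(raw.shape)
K = len(unique_vals)                      # codes 10, 30, 50, 80 -> K = 4

# 2. One-hot encode: (H, W) -> (H, W, K).
onehot = np.eye(K)[class_ids]

# 3. Block-mean coarsening with factor f, trimming incomplete blocks.
f = 2
H, W = (raw.shape[0] // f) * f, (raw.shape[1] // f) * f
blocks = onehot[:H, :W].reshape(H // f, f, W // f, f, K)
proportions = blocks.mean(axis=(1, 3))    # (H // f, W // f, K)

print(proportions.shape)                  # (2, 3, 4)
print(proportions[0, 0])                  # [1. 0. 0. 0.]: top-left block is all code 10
```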

In [None]:
def prepare_feature_targets(predictor_raster, site_gdf, coarsening_factor):
    """Process land-cover and create model inputs.
    
    Returns
    -------
    features_onehot_coarse : array
        One-hot encoded land-cover features after coarsening (H, W, K)
    unique_classes : array
        Original WorldCover class values
    target_nodes : array
        Node indices for sampling sites
    grid : GridGraph
        Reference grid for node indexing
    feature_da : xarray.DataArray
        Coarsened feature raster with coordinates
    """
    # Compress WorldCover classes to contiguous IDs
    raw_band = np.asarray(predictor_raster["band_1"])
    unique_vals, inverse = np.unique(raw_band.ravel(), return_inverse=True)
    class_ids = inverse.reshape(raw_band.shape).astype(np.int32)
    features_categorical = jnp.array(class_ids).squeeze()
    unique_classes = jnp.array(unique_vals)
    
    print(f"Found {len(unique_vals)} unique land-cover classes")
    
    # One-hot encode: (H, W) -> (H, W, K)
    features_onehot = jax.nn.one_hot(
        features_categorical, 
        num_classes=len(unique_vals)
    )
    
    # Reorder for coarsening: (H, W, K) -> (K, H, W)
    features_onehot = jnp.moveaxis(features_onehot, -1, 0)
    
    # Coarsen using mean pooling (preserves class composition)
    coords = {
        "band": np.arange(features_onehot.shape[0]),
        "y": predictor_raster.y.values,
        "x": predictor_raster.x.values,
    }
    feature_da = xr.DataArray(
        np.asarray(features_onehot),  # plain NumPy array for xarray's coarsen/mean
        coords=coords,
        dims=("band", "y", "x"),
    )
    feature_da = feature_da.coarsen(
        x=coarsening_factor, 
        y=coarsening_factor, 
        boundary="trim"
    ).mean()
    
    # Back to (H, W, K)
    features_onehot_coarse = jnp.moveaxis(feature_da.data, 0, -1)
    print(f"Coarsened feature shape: {features_onehot_coarse.shape}")
    
    # Map site coordinates to coarsened grid indices
    x_idx = jnp.array([
        int(np.argmin(np.abs(feature_da.x.values - x))) 
        for x in site_gdf.geometry.x.values
    ])
    y_idx = jnp.array([
        int(np.argmin(np.abs(feature_da.y.values - y))) 
        for y in site_gdf.geometry.y.values
    ])
    
    # Create a reference grid with the same (H, W) shape as the feature/resistance
    # grid, so node indices match the graph built later in the loss function.
    # Assumes GridGraph.coord_to_index takes (row, col), i.e. (y, x).
    grid = GridGraph(
        jnp.ones(features_onehot_coarse.shape[:2]),
        fun=lambda x, y: (x + y) / 2
    )
    target_nodes = grid.coord_to_index(y_idx, x_idx)
    
    return features_onehot_coarse, unique_classes, target_nodes, grid, feature_da

# Process features
features_onehot, unique_classes, target_nodes, ref_grid, coarse_feature_da = prepare_feature_targets(
    predictor_raster, site_gdf, COARSENING_FACTOR
)

print(f"Target nodes (site indices): {target_nodes}")

In [None]:
# Visualize coarsened land-cover with sites
# index_to_coord returns (row, col) = (y, x) under the same assumption as above
node_coords = ref_grid.index_to_coord(target_nodes)
y_indices, x_indices = node_coords[:, 0], node_coords[:, 1]

fig, ax = plt.subplots(figsize=(12, 10))
ax.imshow(features_onehot.argmax(axis=-1), cmap="tab20")
ax.scatter(
    x_indices, y_indices, 
    c="blue", 
    s=150, 
    edgecolor='white',
    linewidth=2,
    label="Sites",
    zorder=10
)

# Annotate sites
for xi, yi, name in zip(x_indices, y_indices, site_gdf["site_name"].values):
    ax.text(
        int(xi), int(yi) - 5,
        str(name),
        color="black",
        fontsize=10,
        fontweight='bold',
        ha="center",
        va="bottom",
        bbox=dict(facecolor="white", alpha=0.9, edgecolor="black", pad=2),
    )

ax.set_title("Coarsened Land-Cover with Site Locations", fontsize=14, pad=20)
ax.legend(loc='upper right', fontsize=12)
ax.axis('off')
plt.tight_layout()
plt.show()
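The nearest-cell lookup in `prepare_feature_targets` can also be written without JAXScape, as shown below. The flattening step assumes row-major node numbering over an (H, W) grid, i.e. node = row * W + col. This is an assumption about how grid nodes are numbered, so confirm it against `GridGraph.coord_to_index` before relying on it. The coordinates are toy values and `nearest_cell_indices` is a hypothetical helper.

```python
import numpy as np


def nearest_cell_indices(cell_x, cell_y, site_x, site_y):
    """Row and column index of the nearest cell centre along each axis, per site."""
    cols = np.array([np.argmin(np.abs(cell_x - x)) for x in site_x])
    rows = np.array([np.argmin(np.abs(cell_y - y)) for y in site_y])
    return rows, cols


# Coarsened cell-centre coordinates: 4 columns, 3 rows (y decreasing, north-up).
cell_x = np.array([0.5, 1.5, 2.5, 3.5])
cell_y = np.array([2.5, 1.5, 0.5])

rows, cols = nearest_cell_indices(cell_x, cell_y, site_x=[1.4, 3.9], site_y=[0.6, 2.2])
# Assumed row-major node numbering on an (H, W) = (3, 4) grid.
nodes = np.ravel_multi_index((rows, cols), (len(cell_y), len(cell_x)))
print(rows, cols, nodes)  # [2 0] [1 3] [9 3]
```

Using a non-square toy grid is deliberate here: if rows and columns were swapped anywhere, the flat indices would change.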

## Define the Resistance Model

We build a **neural network** that maps one-hot land-cover features to positive resistance values according to the transformation $\text{resistance} = \exp(\text{NN}(\text{features})) + \epsilon$. The architecture consists of a K-dimensional one-hot vector as **input** (representing land-cover classes), followed by **two hidden layers** with ReLU activation and 16 hidden units each, culminating in a **single output** value that is exponentiated to ensure positive resistance. This model is applied pixel-wise via `vmap` to produce a complete resistance surface across the landscape.

The design reflects several considerations. The **small network** (16 hidden units per layer) limits capacity. This reduces the risk of overfitting the limited genetic data and encourages parsimonious resistance patterns. The **exponential output** guarantees strictly positive resistance values, which the distance metrics require. The **class-based input encoding** lets the model learn a distinct resistance for each land-cover type. Because coarsened cells contain class *proportions* rather than pure one-hot vectors, the hidden layers can also capture nonlinear effects of habitat mixtures, which gives more flexibility than a simple lookup table of per-class resistances.
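Stripped of Equinox, the model is a small MLP applied independently to every pixel. The following plain-JAX sketch shows the `exp(...) + eps` output and the nested `vmap` that maps an (H, W, K) feature array to an (H, W) resistance surface. The helpers `init_mlp` and `pixel_resistance` are hypothetical, and the weights use a simple scaled-normal initialization rather than Equinox's default.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, num_classes, hidden_dim=16):
    """List of (weight, bias) pairs for K -> hidden -> hidden -> 1."""
    sizes = [num_classes, hidden_dim, hidden_dim, 1]
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (n_out, n_in)) / jnp.sqrt(n_in), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def pixel_resistance(params, x, eps=1e-3):
    """Map one pixel's class vector (K,) to a strictly positive scalar resistance."""
    for w, b in params[:-1]:
        x = jax.nn.relu(w @ x + b)
    w, b = params[-1]
    return jnp.exp((w @ x + b)[0]) + eps


# Inner vmap runs over columns, outer vmap over rows; params are shared (None).
surface_fn = jax.vmap(jax.vmap(pixel_resistance, in_axes=(None, 0)), in_axes=(None, 0))

key_params, key_feats = jax.random.split(jax.random.PRNGKey(0))
params = init_mlp(key_params, num_classes=4)
# Toy class-proportion features for a 5 x 7 grid (each pixel sums to 1).
features = jax.nn.softmax(jax.random.normal(key_feats, (5, 7, 4)), axis=-1)

resistance = surface_fn(params, features)
print(resistance.shape)  # (5, 7)
```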

In [None]:
def build_model(num_classes: int, seed: int = 1) -> tuple:
    """Build neural resistance model.
    
    Parameters
    ----------
    num_classes : int
        Number of land-cover classes
    seed : int
        Random seed for initialization
        
    Returns
    -------
    model : eqx.Module
        Complete model
    params : pytree
        Trainable parameters
    static : pytree
        Static (non-trainable) components
    """
    key = jax.random.PRNGKey(seed)
    
    class ResistanceModel(eqx.Module):
        layers: list
        num_classes: int
        
        def __init__(self, num_classes: int, key):
            self.num_classes = num_classes
            k1, k2, k3 = jax.random.split(key, 3)
            hidden_dim = 16  # Small network to prevent overfitting
            
            self.layers = [
                nn.Linear(num_classes, hidden_dim, key=k1),
                jax.nn.relu,
                nn.Linear(hidden_dim, hidden_dim, key=k2),
                jax.nn.relu,
                nn.Linear(hidden_dim, 1, key=k3),
            ]
        
        def __call__(self, x):
            """Map one-hot feature to positive resistance."""
            for layer in self.layers:
                x = layer(x)
            return jnp.exp(x) + 1e-3  # Ensure positive resistance
    
    model = ResistanceModel(num_classes, key)
    params, static = eqx.partition(model, eqx.is_inexact_array)
    return model, params, static

# Initialize model
model, params, static = build_model(len(unique_classes))
print(f"Model initialized with {len(unique_classes)} land-cover classes")
print(f"Trainable parameters: {sum(p.size for p in jax.tree_util.tree_leaves(params))}")
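The printed parameter count can be checked by hand. Each `nn.Linear(n_in, n_out)` layer (with its default bias) holds n_in × n_out weights plus n_out biases. For K classes and 16 hidden units, the total is therefore (16K + 16) + (16·16 + 16) + (16 + 1) = 16K + 305. For example, if all 11 ESA WorldCover classes occurred in the raster, the model would have 481 parameters. The actual K depends on which classes appear in the study area. `mlp_param_count` is a hypothetical helper.

```python
def mlp_param_count(layer_sizes):
    """Weights plus biases for a chain of fully connected layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))


for num_classes in (8, 11):
    print(num_classes, mlp_param_count([num_classes, 16, 16, 1]))
# 8 433
# 11 481
```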

### Visualize Initial Resistance Prediction

Before training, the model produces random resistance values based on the initialization. This serves as a baseline for comparison after optimization.

In [None]:
# Apply model to all pixels via vmap
model_vmapped = jax.vmap(jax.vmap(model, in_axes=0), in_axes=0)
initial_resistance = model_vmapped(features_onehot).squeeze()

fig, ax = plt.subplots(figsize=(12, 10))
im = ax.imshow(initial_resistance, cmap="RdYlGn_r")
ax.set_title("Initial Resistance Prediction (Random)", fontsize=14, pad=20)
ax.axis("off")
plt.colorbar(im, ax=ax, label="Resistance", shrink=0.6)
plt.tight_layout()
plt.show()

print(f"Resistance range: [{initial_resistance.min():.3f}, {initial_resistance.max():.3f}]")

## Define Loss Function and Training Setup

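The **loss function** is the mean squared error between observed genetic distances and predicted resistance distances over the training pairs:

$$L = \frac{1}{N} \sum_{(i,j) \in \mathcal{P}_{\text{train}}} \left(d_{\text{genetic}}^{ij} - d_{\text{resistance}}^{ij}\right)^2$$

where $\mathcal{P}_{\text{train}}$ is the set of training site pairs $(i, j)$ with $i < j$ and $N = |\mathcal{P}_{\text{train}}|$.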

**Training procedure:**
1. Model predicts resistance for each pixel
2. Build GridGraph with predicted resistance
3. Compute pairwise resistance distances between sites
4. Compare to genetic distances via MSE
5. Backpropagate gradients and update model parameters

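We use a **train/test split** of the site pairs to check generalization. Because each site appears in many pairs, test pairs still share sites with training pairs. This split therefore measures generalization to unseen pairs rather than to unseen sites.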
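The loss only considers the upper triangle of each distance matrix (pairs with i < j), selected with index arrays like the `tri_i`/`tri_j` arrays built from `np.triu_indices`. The following NumPy sketch shows that indexing and the MSE. Toy matrices stand in for the predicted and genetic distances, and `pairwise_mse` is a hypothetical helper.

```python
import numpy as np


def pairwise_mse(pred_dist, target_dist, tri_i, tri_j):
    """Mean squared error over the selected (i, j) site pairs."""
    diff = np.asarray(target_dist)[tri_i, tri_j] - np.asarray(pred_dist)[tri_i, tri_j]
    return float(np.mean(diff ** 2))


target = np.array([[0.0, 0.10, 0.20],
                   [0.10, 0.0, 0.30],
                   [0.20, 0.30, 0.0]])
pred = target + 0.1 * (1 - np.eye(3))   # every off-diagonal entry off by 0.1

tri_i, tri_j = np.triu_indices(3, k=1)  # pairs (0,1), (0,2), (1,2)
print(pairwise_mse(pred, target, tri_i, tri_j))  # approximately 0.01
```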

In [None]:
@eqx.filter_jit
def loss_fn(params, args):
    """Compute squared error between predicted and target distances.
    
    Parameters
    ----------
    params : pytree
        Trainable model parameters
    args : tuple
        (static, features, target_flat_train, tri_i_train, tri_j_train)
        
    Returns
    -------
    loss : float
        Mean squared error on training pairs
    """
    static, features, target_flat_train, tri_i_train, tri_j_train = args
    
    # Reconstruct model and predict resistance surface
    model = eqx.combine(params, static)
    model_vmapped = jax.vmap(jax.vmap(model, in_axes=0), in_axes=0)
    resistance = model_vmapped(features).squeeze()
    
    # Build graph and compute resistance distances
    grid = GridGraph(resistance, fun=lambda x, y: (x + y) / 2)
    predicted_distances = DISTANCE_FUN(grid, nodes=target_nodes)
    
    # Extract training pairs and compute loss
    pred_flat_train = predicted_distances[tri_i_train, tri_j_train]
    return ((target_flat_train - pred_flat_train) ** 2).mean()

print("Loss function defined (JIT compilation happens on the first call)")

### Prepare Training and Test Sets

We split pairwise distances (upper triangle of distance matrix) into 80% training and 20% test sets. This allows us to evaluate whether the model learns generalizable resistance patterns rather than overfitting to the training data.
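A pair-level split places the same sites on both sides. The following hypothetical diagnostic, which is not part of the original workflow, counts how many test pairs have both of their sites present somewhere in the training pairs. It uses a plain NumPy permutation for the 80/20 split instead of `train_test_split`, so the exact split differs from the notebook's.

```python
import numpy as np


def shared_site_fraction(tri_i_train, tri_j_train, tri_i_test, tri_j_test):
    """Fraction of test pairs whose two sites both occur somewhere in the training pairs."""
    train_sites = set(np.concatenate([tri_i_train, tri_j_train]).tolist())
    both_seen = [i in train_sites and j in train_sites
                 for i, j in zip(np.asarray(tri_i_test).tolist(), np.asarray(tri_j_test).tolist())]
    return float(np.mean(both_seen))


n_sites = 10
tri_i, tri_j = np.triu_indices(n_sites, k=1)   # 45 unique pairs
rng = np.random.default_rng(42)
perm = rng.permutation(len(tri_i))
n_test = int(np.ceil(0.2 * len(tri_i)))        # 9 test pairs
test_idx, train_idx = perm[:n_test], perm[n_test:]

frac = shared_site_fraction(tri_i[train_idx], tri_j[train_idx], tri_i[test_idx], tri_j[test_idx])
print(len(train_idx), len(test_idx), frac)
```

A fraction close to 1 is expected for pair-level splits. It means the test score reflects unseen *pairs* of familiar sites.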

In [None]:
# Extract upper triangle indices (all unique pairs)
n_sites = genetic_distances.shape[0]
tri_i_all, tri_j_all = np.triu_indices(n_sites, k=1)
target_flat_all = np.asarray(genetic_distances)[tri_i_all, tri_j_all]

print(f"Total pairwise distances: {len(target_flat_all)}")

# Train/test split (80/20)
(
    target_flat_train,
    target_flat_test,
    tri_i_train,
    tri_i_test,
    tri_j_train,
    tri_j_test,
) = train_test_split(
    target_flat_all,
    tri_i_all,
    tri_j_all,
    test_size=0.2,
    random_state=42,
)

# Training indices as JAX arrays (used inside the jitted loss); test indices stay NumPy for evaluation
tri_i_train = jnp.array(tri_i_train)
tri_j_train = jnp.array(tri_j_train)
tri_i_test = np.array(tri_i_test)
tri_j_test = np.array(tri_j_test)

print(f"Training pairs: {len(target_flat_train)}")
print(f"Test pairs: {len(target_flat_test)}")

# Sanity check: compute initial loss
initial_loss = loss_fn(
    params, 
    (static, features_onehot, target_flat_train, tri_i_train, tri_j_train)
)
print(f"\nInitial training loss: {initial_loss:.6f}")

## Train the Model

We use **L-BFGS** (limited-memory Broyden-Fletcher-Goldfarb-Shanno), a quasi-Newton method that works well for small-to-medium parameter spaces. Instead of forming the Hessian, L-BFGS builds a low-rank approximation of the inverse Hessian from the most recent parameter and gradient differences, and uses a line search to choose step sizes. It only needs to store a few recent vector pairs. On deterministic, full-batch objectives such as this one, it typically converges in fewer iterations than stochastic gradient descent. Runtime depends mainly on graph size, because every loss and gradient evaluation requires solving the distance problem on the full grid. Expect several minutes for landscape-scale grids.
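To make "approximate the inverse Hessian from recent gradient differences" concrete, the following minimal NumPy sketch implements the textbook two-loop recursion with a simple backtracking (Armijo) line search. It is an illustration on a toy quadratic, not Optimistix's implementation, which has its own line search, tolerances, and conventions. The function names are hypothetical.

```python
import numpy as np


def lbfgs_direction(grad, s_hist, y_hist):
    """Two-loop recursion: returns -H_k @ grad using stored (s, y) pairs."""
    q = grad.copy()
    coeffs = []
    for s, y in zip(reversed(s_hist), reversed(y_hist)):  # newest to oldest
        rho = 1.0 / (y @ s)
        alpha = rho * (s @ q)
        q -= alpha * y
        coeffs.append((rho, alpha))
    # Scale the initial inverse Hessian by s'y / y'y from the newest pair.
    gamma = (s_hist[-1] @ y_hist[-1]) / (y_hist[-1] @ y_hist[-1]) if s_hist else 1.0
    r = gamma * q
    for (s, y), (rho, alpha) in zip(zip(s_hist, y_hist), reversed(coeffs)):  # oldest to newest
        beta = rho * (y @ r)
        r += (alpha - beta) * s
    return -r


def lbfgs_minimize(f, grad_f, x0, memory=5, max_steps=200, gtol=1e-8):
    x = np.asarray(x0, dtype=float)
    g = grad_f(x)
    s_hist, y_hist = [], []
    for _ in range(max_steps):
        if np.linalg.norm(g) < gtol:
            break
        d = lbfgs_direction(g, s_hist, y_hist)
        # Backtracking line search on the Armijo sufficient-decrease condition.
        t, fx = 1.0, f(x)
        while f(x + t * d) > fx + 1e-4 * t * (g @ d) and t > 1e-10:
            t *= 0.5
        x_new = x + t * d
        g_new = grad_f(x_new)
        s, y = x_new - x, g_new - g
        if y @ s > 1e-12:  # keep only pairs with positive curvature
            s_hist.append(s)
            y_hist.append(y)
            if len(s_hist) > memory:
                s_hist.pop(0)
                y_hist.pop(0)
        x, g = x_new, g_new
    return x


# Toy quadratic f(x) = 0.5 x'Ax - b'x, minimized at x* = A^{-1} b.
A = np.array([[3.0, 0.5, 0.0], [0.5, 2.0, 0.3], [0.0, 0.3, 1.0]])
b = np.array([1.0, -2.0, 0.5])
x_opt = lbfgs_minimize(lambda x: 0.5 * x @ A @ x - b @ x, lambda x: A @ x - b, np.zeros(3))
print(np.allclose(x_opt, np.linalg.solve(A, b)))  # True
```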

In [None]:
# Configure L-BFGS optimizer
solver = optx.LBFGS(
    rtol=1e-5,  # Relative tolerance for convergence
    atol=1e-5,  # Absolute tolerance
    verbose=frozenset({"loss"})  # Print loss during optimization
)

print("Starting optimization...\n")
print("="*60)

start_train_time = time.time()

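# Run optimization. With throw=False, the last iterate is returned instead of
# raising an error if max_steps is reached before convergence.
train_args = (static, features_onehot, target_flat_train, tri_i_train, tri_j_train)
opt_solution = optx.minimise(
    loss_fn,
    solver,
    params,
    args=train_args,
    max_steps=MAX_STEPS,
    throw=False,
)

training_time = time.time() - start_train_time

# opt_solution.value holds the optimized parameters, not the loss,
# so re-evaluate the loss at the solution
final_loss = loss_fn(opt_solution.value, train_args)

print("="*60)
print(f"\n✓ Optimization finished in {training_time:.2f} seconds")
print(f"Converged: {bool(opt_solution.result == optx.RESULTS.successful)}")
print(f"Final training loss: {final_loss:.6f}")
print("\nOptimization statistics:")
print(opt_solution.stats)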

### Visualize Fitted Resistance Surface

After training, we can visualize the learned resistance patterns across the landscape. High-resistance areas (red) are inferred barriers to gene flow, while low-resistance areas (green) facilitate connectivity between populations. To interpret these patterns, compare the resistance surface with the original land-cover map to see which habitat types the model treats as permeable or impeding. Plausible results should agree with what is known about the species' ecology and dispersal behavior. For example, a species known to avoid settlements should show built-up areas as barriers.
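One way to relate the surface back to land-cover types is to evaluate the network on each pure class, i.e. on the rows of a K × K identity matrix. With the notebook's objects, this would be `jax.vmap(fitted_model)(jnp.eye(len(unique_classes)))`, where row k corresponds to the WorldCover code `unique_classes[k]`. So that the sketch runs on its own, it substitutes a hypothetical hand-written stand-in for the fitted model, with toy scores.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a fitted model: a fixed linear score per class,
# passed through the same exp(...) + eps output used by ResistanceModel.
class_scores = jnp.array([-1.0, 0.0, 2.0])   # toy values for 3 classes
toy_codes = jnp.array([10, 40, 50])          # WorldCover: tree cover, cropland, built-up


def toy_model(x, eps=1e-3):
    return jnp.exp(x @ class_scores)[None] + eps  # shape (1,), like the notebook's model


# Evaluate on each pure (one-hot) class vector.
per_class = jax.vmap(toy_model)(jnp.eye(len(toy_codes))).squeeze(-1)
for code, res in zip(toy_codes.tolist(), per_class.tolist()):
    print(f"class {code}: resistance {res:.3f}")
```

Because coarsened cells mix classes, these per-class values only describe pure cells. Mixed cells can deviate from a proportion-weighted average, because the network is nonlinear.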

In [None]:
# Apply fitted model to landscape
fitted_model = eqx.combine(opt_solution.value, static)
fitted_vmapped = jax.vmap(jax.vmap(fitted_model, in_axes=0), in_axes=0)
fitted_resistance = fitted_vmapped(features_onehot).squeeze()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))

# Initial (random) resistance
im1 = ax1.imshow(initial_resistance, cmap="RdYlGn_r")
ax1.set_title("Initial Resistance (Random)", fontsize=13, pad=15)
ax1.axis("off")
plt.colorbar(im1, ax=ax1, shrink=0.6, label="Resistance")

# Fitted resistance
im2 = ax2.imshow(fitted_resistance, cmap="RdYlGn_r")
ax2.set_title("Fitted Resistance (Trained)", fontsize=13, pad=15)
ax2.axis("off")
plt.colorbar(im2, ax=ax2, shrink=0.6, label="Resistance")

plt.tight_layout()
plt.show()

print(f"Fitted resistance range: [{fitted_resistance.min():.3f}, {fitted_resistance.max():.3f}]")
print(f"Mean resistance: {fitted_resistance.mean():.3f}")

## Evaluate Model Performance

We assess model quality using two complementary metrics. **R² (coefficient of determination)** measures the proportion of variance in genetic distances explained by the model. A value of 1.0 indicates a perfect fit, 0 means no better than predicting the mean, and negative values mean worse than predicting the mean. **RMSE (root mean squared error)** quantifies the typical prediction error in the units of the genetic distances, which gives an interpretable scale of accuracy. Reporting both metrics separately for the training and test sets helps detect overfitting, where the model memorizes training pairs rather than learning generalizable patterns. Similar performance on both sets suggests that the learned resistance patterns capture genuine landscape-genetic relationships rather than noise.
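For reference, both metrics reduce to a few lines of NumPy. The `r2` helper below follows the same 1 − SS_res / SS_tot definition that scikit-learn's `r2_score` uses. The helper names and toy values are hypothetical.

```python
import numpy as np


def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


def rmse(y_true, y_pred):
    """Root mean squared error, in the units of y_true."""
    return float(np.sqrt(np.mean((np.asarray(y_true, float) - np.asarray(y_pred, float)) ** 2)))


y = np.array([0.01, 0.03, 0.05, 0.07])
print(r2(y, y), rmse(y, y))              # 1.0 0.0
print(r2(y, np.full_like(y, y.mean())))  # 0.0: predicting the mean explains nothing
print(r2(y, y[::-1]) < 0)                # True: worse than predicting the mean
```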

In [None]:
# Compute predicted distances using fitted resistance
pred_grid = GridGraph(fitted_resistance, fun=lambda x, y: (x + y) / 2)
pred_distances = DISTANCE_FUN(pred_grid, nodes=target_nodes)

genetic_np = np.asarray(genetic_distances)
pred_np = np.asarray(pred_distances)

# Extract predictions for train and test pairs
train_pred = pred_np[tri_i_train, tri_j_train]
test_pred = pred_np[tri_i_test, tri_j_test]
train_target = target_flat_train
test_target = target_flat_test

# Compute metrics
r2_train = r2_score(train_target, train_pred)
r2_test = r2_score(test_target, test_pred)
rmse_train = np.sqrt(mean_squared_error(train_target, train_pred))
rmse_test = np.sqrt(mean_squared_error(test_target, test_pred))

print("="*60)
print("MODEL PERFORMANCE")
print("="*60)
print(f"\n{'Metric':<15} {'Training':<15} {'Test':<15}")
print("-"*45)
print(f"{'R²':<15} {r2_train:>14.3f} {r2_test:>14.3f}")
print(f"{'RMSE':<15} {rmse_train:>14.4f} {rmse_test:>14.4f}")
print("="*60)

# Check for overfitting
if r2_train - r2_test > 0.2:
    print("\n⚠️  Warning: Possible overfitting detected (train R² >> test R²)")
elif r2_test >= 0.5:
    print("\n✓ Good generalization: Model explains genetic patterns well")
else:
    print("\n⚠️  Moderate fit: Consider model refinement or more data")

### Predicted vs. Observed: Scatterplot

The scatterplot visualizes how well predicted resistance distances match observed genetic distances:

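- **Points near the diagonal**: Accurate predictions
- **Training vs. test**: Similar scatter of the blue (training) and magenta (test) points indicates good generalization
- **Systematic deviations** (e.g. curvature, or a trend whose slope differs clearly from 1): May indicate model misspecification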

An ideal model would have all points falling on the 1:1 line with similar scatter for train and test sets.
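One simple, hypothetical check for systematic deviations is to fit a straight line of observed against predicted distances. A slope near 1 and an intercept near 0 agree with the 1:1 line, while a clearly different slope suggests the predictions are consistently stretched or compressed. This sketch uses `np.polyfit` on toy data and is not a formal test.

```python
import numpy as np


def calibration_line(pred, observed):
    """Least-squares slope and intercept of observed ~ pred."""
    slope, intercept = np.polyfit(np.asarray(pred, float), np.asarray(observed, float), deg=1)
    return float(slope), float(intercept)


pred = np.array([0.01, 0.02, 0.03, 0.04, 0.05])
observed = 0.5 * pred + 0.01   # toy data: observed differences are compressed
slope, intercept = calibration_line(pred, observed)
print(round(slope, 3), round(intercept, 3))  # 0.5 0.01
```

With the notebook's arrays, this would be `calibration_line(test_pred, test_target)`.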

In [None]:
fig, ax = plt.subplots(figsize=(9, 8))

# Plot training and test predictions
ax.scatter(
    train_pred, train_target, 
    s=60, alpha=0.6, 
    edgecolor="none", 
    label="Training",
    c='#2E86AB'
)
ax.scatter(
    test_pred, test_target, 
    s=80, alpha=0.8, 
    edgecolor="black",
    linewidth=1,
    label="Test",
    c='#A23B72'
)

# 1:1 reference line
min_val = min(pred_np.min(), genetic_np.min())
max_val = max(pred_np.max(), genetic_np.max())
ax.plot(
    [min_val, max_val], [min_val, max_val], 
    "k--", linewidth=2, 
    alpha=0.5,
    label="Perfect fit (1:1)"
)

ax.set_xlabel("Predicted Resistance Distance", fontsize=12)
ax.set_ylabel("Observed Genetic Distance (Fst)", fontsize=12)
ax.set_title(
    "Inverse Landscape Model: Predicted vs. Observed", 
    fontsize=14, 
    pad=20
)

# Add metrics box
textstr = (
    f"Training\n"
    f"  R² = {r2_train:.3f}\n"
    f"  RMSE = {rmse_train:.4f}\n\n"
    f"Test\n"
    f"  R² = {r2_test:.3f}\n"
    f"  RMSE = {rmse_test:.4f}"
)
ax.text(
    0.05, 0.95,
    textstr,
    transform=ax.transAxes,
    fontsize=10,
    verticalalignment="top",
    bbox=dict(
        boxstyle="round", 
        facecolor="white", 
        alpha=0.9, 
        edgecolor="gray",
        linewidth=1.5
    ),
)

ax.legend(loc="lower right", fontsize=11, framealpha=0.9)
ax.grid(True, alpha=0.3, linestyle=':')
plt.tight_layout()
plt.show()

## Key Takeaways

### Methodological Insights

**Inverse modeling** infers landscape resistance directly from genetic data, reversing the traditional approach of assuming resistance values *a priori*. **Automatic differentiation** in JAX makes gradient-based optimization tractable, even though each loss evaluation involves solving a distance problem over the entire landscape graph. **Neural networks** provide a flexible parameterization that can capture nonlinear relationships between land cover and resistance. The class-based inputs also keep the result interpretable, because evaluating the fitted network on each pure class yields a per-class resistance. Finally, **train/test splits** help verify that learned resistance patterns reflect generalizable landscape-genetic relationships rather than memorized training pairs.

### Practical Applications

This approach has direct utility for **conservation planning**, where identifying which landscape features most facilitate or impede gene flow guides restoration priorities. The fitted resistance surfaces support **corridor design** by highlighting candidate pathways for connectivity interventions. Repeating the analysis for different taxa yields **species-specific models** that quantify how different organisms respond to the same landscape, informing multi-species conservation strategies. The trained models also enable **scenario evaluation**, predicting how proposed land-use changes would affect genetic connectivity before interventions are implemented.

### Extensions

Several extensions could enhance the approach. Comparing **multiple distance metrics** (least-cost path, resistance distance, and randomized shortest paths) would reveal whether genetic patterns better match single-path or multi-path connectivity models. Incorporating additional **environmental covariates** beyond land cover, such as climate variables and topography, could improve predictions in heterogeneous landscapes. **Hierarchical models** that explicitly account for population structure and demographic history would separate landscape effects from other evolutionary processes. Finally, **uncertainty quantification** through bootstrap resampling or Bayesian approaches would provide confidence intervals on learned resistance values.
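A bootstrap can be sketched generically: resample site pairs with replacement, refit, and take percentiles of the quantity of interest. Refitting the full resistance model many times is expensive, so the sketch below uses a cheap hypothetical `fit_fn` (a least-squares slope) as a stand-in. In practice, `fit_fn` would rerun the optimization and return, for example, per-class resistances. Resampling individual pairs ignores the fact that pairs share sites, so treat the resulting intervals as approximate.

```python
import numpy as np


def bootstrap_ci(fit_fn, x, y, n_boot=1000, level=0.95, seed=0):
    """Percentile bootstrap interval for fit_fn(x, y), resampling pairs with replacement."""
    rng = np.random.default_rng(seed)
    n = len(x)
    stats = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)
        stats.append(fit_fn(x[idx], y[idx]))
    lo, hi = np.percentile(stats, [50 * (1 - level), 50 * (1 + level)], axis=0)
    return lo, hi


def slope_fn(x, y):
    """Toy stand-in for a model fit: least-squares slope of y on x."""
    return np.polyfit(x, y, deg=1)[0]


rng = np.random.default_rng(1)
x = np.linspace(0.0, 1.0, 40)
y = 2.0 * x + rng.normal(scale=0.1, size=x.size)   # toy data with true slope 2
print(bootstrap_ci(slope_fn, x, y))
```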

### Important Notes

Several practical considerations affect model performance. In this example, training losses above roughly 0.002 often indicate poor convergence, although the absolute value depends on the scale of the genetic distances. L-BFGS uses a line search rather than a fixed learning rate. Possible remedies therefore include increasing `MAX_STEPS`, changing the solver tolerances or the initialization seed, or adjusting the network architecture. The deliberately small hidden layers (16 units) help prevent overfitting on the limited genetic data typical of empirical studies. The coarsening factor trades spatial resolution against computational cost and should be tuned to the species' dispersal scale. Most importantly, validate results against independent ecological knowledge about the species' behavior and habitat use to ensure biological realism.