# Inverse landscape genetics

Traditional landscape genetics assumes *a priori* which habitat features facilitate or impede movement. Inverse landscape genetics reverses this approach, inferring permeability patterns directly from genetic data. This notebook demonstrates fitting a neural network that predicts landscape permeability (the inverse of resistance) from land-cover features, optimizing the model so that the resulting resistance distances match observed genetic differentiation (Fst) between sampling sites.

We use genetic data from the [Mountain Pygmy-possum](https://en.wikipedia.org/wiki/Mountain_pygmy_possum) (*Burramys parvus*), an endangered marsupial endemic to alpine regions of southeastern Australia, kindly provided by [Cesar Australia](https://cesaraustralia.com). We'll use land-cover data obtained from ESA WorldCover to model landscape permeability, and fit the model to observed genetic distances.

<div align="center">
  <img src="https://www.australiangeographic.com.au/wp-content/uploads/2018/06/Pygmy-Possum_Amanda-McLean-Copy-1.jpg" width="450">
</div>

### Prerequisites
```bash
pip install jaxscape equinox optimistix rioxarray geopandas scikit-learn
```

We've prepared the data so that you can focus on the modeling aspects; you can download the dataset [here](https://vboussange.github.io/jaxscape/data/inverse_landscape_genetics/).

In [1]:
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import rioxarray
import xarray as xr
import geopandas as gpd
import equinox as eqx
from equinox import nn
import optimistix as optx

from jaxscape import GridGraph, ResistanceDistance, LCPDistance
from jaxscape.solvers import CholmodSolver
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.model_selection import train_test_split

# Set random seed for reproducibility
np.random.seed(42)


### Configuration parameters

In [None]:
# File paths (adjust to your data location)
LANDCOVER_PATH = '../data/cesar/landcover_7855.tif'
SITE_METADATA_PATH = '../data/cesar/cesar_site_metadata.gpkg'
GENETIC_DISTANCES_PATH = '../data/cesar/cesar_genetic_distances.npy'

# Model configuration
COARSENING_FACTOR = 10  # Spatial downsampling of feature raster to accelerate computation
SOLVER = CholmodSolver()  # Fast linear solver for large graphs
DISTANCE_FUN = ResistanceDistance(solver=SOLVER)  # Effective resistance distance
MAX_STEPS = 500  # Maximum optimization iterations

# Alternative distance metric for experimentation:
# DISTANCE_FUN = LCPDistance()  # Least-cost path distance

## Load the data

In [None]:
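# Load land-cover raster; band_as_variable=True returns a Dataset whose
# first band is accessible as predictor_raster["band_1"]
predictor_raster = rioxarray.open_rasterio(
    LANDCOVER_PATH,
    mask_and_scale=True,
    band_as_variable=True,
)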

# Load site metadata (geographic locations)
site_metadata = gpd.read_file(SITE_METADATA_PATH)
site_gdf = site_metadata.to_crs(epsg=7855)  # Reproject to GDA2020 / MGA zone 55 (EPSG:7855)

# Load genetic distance matrix
genetic_distances = np.load(GENETIC_DISTANCES_PATH)

print(f"Land-cover raster shape: {predictor_raster['band_1'].shape}")
print(f"Number of sites: {len(site_gdf)}")
print(f"Genetic distance matrix shape: {genetic_distances.shape}")

In [None]:
# Visualize land-cover with sampling sites
fig, ax = plt.subplots(figsize=(12, 10))
predictor_raster["band_1"].plot(ax=ax, cmap="tab20", add_colorbar=True)
ax.scatter(
    site_gdf.geometry.x, 
    site_gdf.geometry.y, 
    c="red", 
    s=100, 
    edgecolor='white',
    linewidth=2,
    label="Sampling Sites",
    zorder=10
)
ax.set_title("Land-Cover Map with Sampling Sites", fontsize=14, pad=20)
ax.legend(loc='upper right', fontsize=12)
plt.tight_layout()
plt.show()

print(f"Sites: {', '.join(site_gdf['site_name'].values)}")

## Prepare features and targets

The land-cover raster requires preprocessing before feeding it to the neural network. We first compress the sparse WorldCover class IDs into a contiguous range (0..K-1), then one-hot encode them into binary feature vectors. Spatial coarsening via mean pooling downsamples the raster while preserving land-cover composition within each aggregated cell, capturing habitat heterogeneity without forcing discrete classifications. Finally, we map each sampling site to its nearest grid cell in the coarsened raster. 
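To make these steps concrete, here is a minimal NumPy sketch on a hypothetical 4×4 raster. The array `raw` and its class codes are made up for illustration; the notebook's own implementation in the next cell uses `jax.nn.one_hot` and `xarray`'s `coarsen` instead of the reshape trick shown here.

```python
import numpy as np

# Hypothetical 4x4 raster with sparse WorldCover-style class codes
raw = np.array([
    [10, 10, 30, 30],
    [10, 80, 30, 30],
    [10, 10, 80, 80],
    [10, 10, 80, 30],
])

# 1. Compress sparse codes to contiguous IDs 0..K-1
unique_vals, inverse = np.unique(raw.ravel(), return_inverse=True)
class_ids = inverse.reshape(raw.shape)
K = len(unique_vals)

# 2. One-hot encode: (H, W) -> (H, W, K)
onehot = np.eye(K)[class_ids]

# 3. Mean-pool non-overlapping 2x2 blocks: (H, W, K) -> (H/2, W/2, K)
f = 2
H, W = raw.shape
coarse = onehot.reshape(H // f, f, W // f, f, K).mean(axis=(1, 3))

print(unique_vals)   # [10 30 80]
print(coarse[0, 0])  # [0.75 0.   0.25]
```

The top-left coarse cell contains three pixels of class 10 and one of class 80, so its feature vector records a 75% / 25% composition rather than a single dominant class.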

In [None]:
def prepare_feature_targets(predictor_raster, site_gdf, coarsening_factor):
    """Process land-cover and create model inputs.
    
    Returns
    -------
    features_onehot_coarse : array
        One-hot encoded land-cover features after coarsening (H, W, K)
    unique_classes : array
        Original WorldCover class values
    target_nodes : array
        Node indices for sampling sites
    grid : GridGraph
        Reference grid for node indexing
    feature_da : xarray.DataArray
        Coarsened feature raster with coordinates
    """
    # Compress WorldCover classes to contiguous IDs
    raw_band = np.asarray(predictor_raster["band_1"])
    unique_vals, inverse = np.unique(raw_band.ravel(), return_inverse=True)
    class_ids = inverse.reshape(raw_band.shape).astype(np.int32)
    features_categorical = jnp.array(class_ids).squeeze()
    unique_classes = jnp.array(unique_vals)
    
    print(f"Found {len(unique_vals)} unique land-cover classes")
    
    # One-hot encode: (H, W) -> (H, W, K)
    features_onehot = jax.nn.one_hot(
        features_categorical, 
        num_classes=len(unique_vals)
    )
    
    # Reorder for coarsening: (H, W, K) -> (K, H, W)
    features_onehot = jnp.moveaxis(features_onehot, -1, 0)
    
    # Coarsen using mean pooling (preserves class composition)
    coords = {
        "band": np.arange(features_onehot.shape[0]),
        "y": predictor_raster.y.values,
        "x": predictor_raster.x.values,
    }
    feature_da = xr.DataArray(
        features_onehot,
        coords=coords,
        dims=("band", "y", "x"),
    )
    feature_da = feature_da.coarsen(
        x=coarsening_factor, 
        y=coarsening_factor, 
        boundary="trim"
    ).mean()
    
    # Back to (H, W, K)
    features_onehot_coarse = jnp.moveaxis(feature_da.data, 0, -1)
    print(f"Coarsened feature shape: {features_onehot_coarse.shape}")
    
    # Map site coordinates to coarsened grid indices
    x_idx = jnp.array([
        int(np.argmin(np.abs(feature_da.x.values - x))) 
        for x in site_gdf.geometry.x.values
    ])
    y_idx = jnp.array([
        int(np.argmin(np.abs(feature_da.y.values - y))) 
        for y in site_gdf.geometry.y.values
    ])
    
    # Create reference grid for node indexing
    grid = GridGraph(
        jnp.ones((feature_da.x.size, feature_da.y.size)), 
        fun=lambda x, y: (x + y) / 2
    )
    target_nodes = grid.coord_to_index(x_idx, y_idx)
    
    return features_onehot_coarse, unique_classes, target_nodes, grid, feature_da

# Process features
features_onehot, unique_classes, target_nodes, ref_grid, coarse_feature_da = prepare_feature_targets(
    predictor_raster, site_gdf, COARSENING_FACTOR
)

print(f"Target nodes (site indices): {target_nodes}")

In [None]:
# Visualize coarsened land-cover with sites
node_coords = ref_grid.index_to_coord(target_nodes)
x_indices, y_indices = node_coords[:, 0], node_coords[:, 1]

fig, ax = plt.subplots(figsize=(12, 10))
ax.imshow(features_onehot.argmax(axis=-1), cmap="tab20")
ax.scatter(
    x_indices, y_indices, 
    c="blue", 
    s=150, 
    edgecolor='white',
    linewidth=2,
    label="Sites",
    zorder=10
)

# Annotate sites
for xi, yi, name in zip(x_indices, y_indices, site_gdf["site_name"].values):
    ax.text(
        int(xi), int(yi) - 5,
        str(name),
        color="black",
        fontsize=10,
        fontweight='bold',
        ha="center",
        va="bottom",
        bbox=dict(facecolor="white", alpha=0.9, edgecolor="black", pad=2),
    )

ax.set_title("Coarsened Land-Cover with Site Locations", fontsize=14, pad=20)
ax.legend(loc='upper right', fontsize=12)
ax.axis('off')
plt.tight_layout()
plt.show()

## Define the permeability model

We build a neural network mapping land-cover features to positive permeability values: $\text{permeability} = \exp(\text{NN}(\text{features})) + \epsilon$, with $\epsilon = 10^{-3}$. The network takes a K-dimensional feature vector (the land-cover composition of a cell, which is one-hot only for cells with a single class), passes it through two hidden layers of 16 units each with ReLU activations, and outputs a single value that is exponentiated to ensure positivity. We apply this model pixel-wise via `vmap` to generate the full permeability surface.
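As a rough illustration of the formula above, the following NumPy sketch writes out the forward pass by hand for a network with random, untrained weights. The helper `permeability`, the weight names, and the 2×2 feature grid are assumptions made for illustration; the notebook itself uses `equinox.nn.Linear` layers and `jax.vmap`.

```python
import numpy as np

rng = np.random.default_rng(0)
K, hidden = 3, 16  # hypothetical: 3 land-cover classes, 16 hidden units
eps = 1e-3

# Random (untrained) weights for a K -> 16 -> 16 -> 1 network
W1, b1 = rng.normal(scale=0.3, size=(hidden, K)), np.zeros(hidden)
W2, b2 = rng.normal(scale=0.3, size=(hidden, hidden)), np.zeros(hidden)
W3, b3 = rng.normal(scale=0.3, size=(1, hidden)), np.zeros(1)

def permeability(x):
    """Forward pass for one feature vector x of shape (K,)."""
    h = np.maximum(W1 @ x + b1, 0.0)   # first hidden layer + ReLU
    h = np.maximum(W2 @ h + b2, 0.0)   # second hidden layer + ReLU
    return np.exp(W3 @ h + b3) + eps   # exponentiate, then add eps

# Hypothetical 2x2 grid of class compositions (each cell sums to 1)
features = rng.dirichlet(np.ones(K), size=(2, 2))  # shape (2, 2, K)

# Pixel-wise application, analogous to the nested vmap in the notebook
surface = np.apply_along_axis(permeability, -1, features).squeeze(-1)
print(surface.shape)  # (2, 2)
```

Because the last step is an exponential, the output stays positive whatever the weights are, which keeps the edge weights of the graph valid throughout optimization.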

In [None]:
def build_model(num_classes: int, seed: int = 1) -> tuple:
    """Build neural permeability model.
    
    Parameters
    ----------
    num_classes : int
        Number of land-cover classes
    seed : int
        Random seed for initialization
        
    Returns
    -------
    model : eqx.Module
        Complete model
    params : pytree
        Trainable parameters
    static : pytree
        Static (non-trainable) components
    """
    key = jax.random.PRNGKey(seed)
    
    class PermeabilityModel(eqx.Module):
        layers: list
        num_classes: int
        
        def __init__(self, num_classes: int, key):
            self.num_classes = num_classes
            k1, k2, k3 = jax.random.split(key, 3)
            hidden_dim = 16  # Small network to prevent overfitting
            
            self.layers = [
                nn.Linear(num_classes, hidden_dim, key=k1),
                jax.nn.relu,
                nn.Linear(hidden_dim, hidden_dim, key=k2),
                jax.nn.relu,
                nn.Linear(hidden_dim, 1, key=k3),
            ]
        
        def __call__(self, x):
            """Map one-hot feature to positive permeability."""
            for layer in self.layers:
                x = layer(x)
            return jnp.exp(x) + 1e-3  # Ensure positive permeability
    
    model = PermeabilityModel(num_classes, key)
    params, static = eqx.partition(model, eqx.is_inexact_array)
    return model, params, static

# Initialize model
model, params, static = build_model(len(unique_classes))
print(f"Model initialized with {len(unique_classes)} land-cover classes")
print(f"Trainable parameters: {sum(p.size for p in jax.tree_util.tree_leaves(params))}")

Before training, the model produces random permeability values based on the initialization.

In [None]:
# Apply model to all pixels via vmap
model_vmapped = jax.vmap(jax.vmap(model, in_axes=0), in_axes=0)
initial_permeability = model_vmapped(features_onehot).squeeze()

fig, ax = plt.subplots(figsize=(12, 10))
im = ax.imshow(initial_permeability, cmap="RdYlGn_r")
ax.set_title("Initial permeability Prediction (Random)", fontsize=14, pad=20)
ax.axis("off")
plt.colorbar(im, ax=ax, label="Permeability", shrink=0.6)
plt.tight_layout()
plt.show()

## Define loss function and training setup

We minimize the mean squared error (MSE) between predicted resistance distances and observed genetic distances. Each loss evaluation predicts pixel-wise permeability, constructs a `GridGraph`, and computes pairwise resistance distances between sampling sites. Gradients are backpropagated through this entire pipeline and passed to the L-BFGS optimizer.
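For intuition about the quantity being fitted, the effective resistance between two nodes can be obtained from the Moore-Penrose pseudoinverse $L^+$ of the graph Laplacian as $R_{ij} = L^+_{ii} + L^+_{jj} - 2L^+_{ij}$. The sketch below checks this on a hypothetical three-node path graph with NumPy. It is a dense illustration only and is not meant to describe how `ResistanceDistance` or `CholmodSolver` work internally.

```python
import numpy as np

# Hypothetical path graph 0 - 1 - 2; edge weights act as conductances
# (a more permeable edge has lower resistance)
A = np.array([
    [0.0, 1.0, 0.0],
    [1.0, 0.0, 2.0],
    [0.0, 2.0, 0.0],
])

L = np.diag(A.sum(axis=1)) - A   # graph Laplacian
L_pinv = np.linalg.pinv(L)       # Moore-Penrose pseudoinverse

# R_ij = L+_ii + L+_jj - 2 L+_ij, computed for all pairs at once
d = np.diag(L_pinv)
R = d[:, None] + d[None, :] - 2 * L_pinv

print(np.round(R, 3))
# Resistances in series add up: R[0, 2] = 1/1 + 1/2 = 1.5
```

In the notebook, each edge weight is the mean permeability of the two adjacent cells (the `fun` argument of `GridGraph`), so increasing the permeability of a land-cover class lowers the resistance distance between sites connected through it.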

We split pairwise distances (upper triangle of the matrix) into 80% training and 20% test sets to evaluate generalization. This train/test split reveals whether the model learns meaningful permeability patterns or simply memorizes training data.
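One caveat is worth keeping in mind when reading the test metrics: pairwise distances are not independent observations, because each site contributes to several pairs. The NumPy sketch below uses a hypothetical five-site example and a plain random permutation in place of `train_test_split`, and shows that the sites in held-out pairs also occur in training pairs.

```python
import numpy as np

n_sites = 5  # hypothetical number of sampling sites
tri_i, tri_j = np.triu_indices(n_sites, k=1)
n_pairs = len(tri_i)
print(n_pairs)  # n * (n - 1) / 2 = 10 unique pairs

# Random 80/20 split over pairs
rng = np.random.default_rng(42)
perm = rng.permutation(n_pairs)
n_test = int(round(0.2 * n_pairs))
test_idx, train_idx = perm[:n_test], perm[n_test:]

# Sites involved in each subset of pairs
train_sites = set(tri_i[train_idx].tolist()) | set(tri_j[train_idx].tolist())
test_sites = set(tri_i[test_idx].tolist()) | set(tri_j[test_idx].tolist())

# Every site in a test pair also appears in some training pair
print(test_sites <= train_sites)  # True
```

The test score therefore measures how well the model predicts unseen pairs among sites it has already seen, which is a weaker notion of generalization than predicting distances to entirely new sites.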

In [None]:
# Extract upper triangle indices (all unique pairs)
n_sites = genetic_distances.shape[0]
tri_i_all, tri_j_all = np.triu_indices(n_sites, k=1)
target_flat_all = np.asarray(genetic_distances)[tri_i_all, tri_j_all]

print(f"Total pairwise distances: {len(target_flat_all)}")

# Train/test split (80/20)
(
    target_flat_train,
    target_flat_test,
    tri_i_train,
    tri_i_test,
    tri_j_train,
    tri_j_test,
) = train_test_split(
    target_flat_all,
    tri_i_all,
    tri_j_all,
    test_size=0.2,
    random_state=42,
)

# Convert to JAX arrays for training
tri_i_train = jnp.array(tri_i_train)
tri_j_train = jnp.array(tri_j_train)
tri_i_test = np.array(tri_i_test)
tri_j_test = np.array(tri_j_test)

print(f"Training pairs: {len(target_flat_train)}")
print(f"Test pairs: {len(target_flat_test)}")

In [None]:
@eqx.filter_jit
def loss_fn(params, args):
    """Compute mean squared error between predicted and target distances.
    
    Parameters
    ----------
    params : pytree
        Trainable model parameters
    args : tuple
        (static, features, target_flat_train, tri_i_train, tri_j_train)
        
    Returns
    -------
    loss : float
        Mean squared error on training pairs
    """
    static, features, target_flat_train, tri_i_train, tri_j_train = args
    
    # Reconstruct model and predict permeability surface
    model = eqx.combine(params, static)
    model_vmapped = jax.vmap(jax.vmap(model, in_axes=0), in_axes=0)
    permeability = model_vmapped(features).squeeze()
    
    # Build graph and compute resistance distances between sites
    grid = GridGraph(permeability, fun=lambda x, y: (x + y) / 2)
    predicted_distances = DISTANCE_FUN(grid, nodes=target_nodes)
    
    # Extract training pairs and compute loss
    pred_flat_train = predicted_distances[tri_i_train, tri_j_train]
    return ((target_flat_train - pred_flat_train) ** 2).mean()

print("Loss function defined (JIT compilation happens on the first call)")

# Sanity check: compute initial loss (this first call triggers compilation)
initial_loss = loss_fn(
    params, 
    (static, features_onehot, target_flat_train, tri_i_train, tri_j_train)
)
print(f"\nInitial training loss: {initial_loss:.6f}")

## Training
The optimization typically takes several minutes depending on graph size and the number of iterations required for convergence.

In [None]:
# Configure L-BFGS optimizer
solver = optx.LBFGS(
    rtol=1e-5,  # Relative tolerance for convergence
    atol=1e-5,  # Absolute tolerance
    verbose=frozenset({"loss"})  # Print loss during optimization
)

print("Starting optimization...\n")
print("="*60)

start_train_time = time.time()

# Run optimization
opt_solution = optx.minimise(
    loss_fn,
    solver,
    params,
    args=(static, features_onehot, target_flat_train, tri_i_train, tri_j_train),
    max_steps=MAX_STEPS,
)

training_time = time.time() - start_train_time

print("="*60)
print(f"\n✓ Training completed in {training_time:.2f} seconds")
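# opt_solution.value holds the optimized parameters, not the loss,
# so re-evaluate the loss at the solution
final_loss = loss_fn(
    opt_solution.value,
    (static, features_onehot, target_flat_train, tri_i_train, tri_j_train),
)
print(f"Final loss: {final_loss:.6f}")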
print(f"\nOptimization statistics:")
print(opt_solution.stats)

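### Compare predicted and observed distances

In [None]:
# Apply the fitted model to the landscape
# (opt_solution.value holds the optimized parameters)
fitted_model = eqx.combine(opt_solution.value, static)
fitted_vmapped = jax.vmap(jax.vmap(fitted_model, in_axes=0), in_axes=0)
fitted_permeability = fitted_vmapped(features_onehot).squeeze()

# Compute predicted distances using fitted permeability
pred_grid = GridGraph(fitted_permeability, fun=lambda x, y: (x + y) / 2)
pred_distances = DISTANCE_FUN(pred_grid, nodes=target_nodes)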

genetic_np = np.asarray(genetic_distances)
pred_np = np.asarray(pred_distances)

# Extract predictions for train and test pairs
train_pred = pred_np[tri_i_train, tri_j_train]
test_pred = pred_np[tri_i_test, tri_j_test]
train_target = target_flat_train
test_target = target_flat_test

# Compute metrics
r2_train = r2_score(train_target, train_pred)
r2_test = r2_score(test_target, test_pred)
rmse_train = np.sqrt(mean_squared_error(train_target, train_pred))
rmse_test = np.sqrt(mean_squared_error(test_target, test_pred))

fig, ax = plt.subplots(figsize=(9, 8))

# Plot training and test predictions
ax.scatter(
    train_pred, train_target, 
    s=60, alpha=0.6, 
    edgecolor="none", 
    label="Training",
    c='#2E86AB'
)
ax.scatter(
    test_pred, test_target, 
    s=80, alpha=0.8, 
    edgecolor="black",
    linewidth=1,
    label="Test",
    c='#A23B72'
)

# 1:1 reference line
min_val = min(pred_np.min(), genetic_np.min())
max_val = max(pred_np.max(), genetic_np.max())
ax.plot(
    [min_val, max_val], [min_val, max_val], 
    "k--", linewidth=2, 
    alpha=0.5,
    label="Perfect fit (1:1)"
)

ax.set_xlabel("Predicted Resistance Distance", fontsize=12)
ax.set_ylabel("Observed Genetic Distance (Fst)", fontsize=12)
ax.set_title(
    "Inverse Landscape Model: Predicted vs. Observed", 
    fontsize=14, 
    pad=20
)

# Add metrics box
textstr = (
    f"Training\n"
    f"  R² = {r2_train:.3f}\n"
    f"  RMSE = {rmse_train:.4f}\n\n"
    f"Test\n"
    f"  R² = {r2_test:.3f}\n"
    f"  RMSE = {rmse_test:.4f}"
)
ax.text(
    0.05, 0.95,
    textstr,
    transform=ax.transAxes,
    fontsize=10,
    verticalalignment="top",
    bbox=dict(
        boxstyle="round", 
        facecolor="white", 
        alpha=0.9, 
        edgecolor="gray",
        linewidth=1.5
    ),
)

ax.legend(loc="lower right", fontsize=11, framealpha=0.9)
ax.grid(True, alpha=0.3, linestyle=':')
plt.tight_layout()
plt.show()

Judging by the R² and RMSE values reported in the plot, our model seems to perform reasonably well, capturing a substantial portion of the variance in genetic distances. Comparing the training and test values indicates how well it generalizes to held-out pairs.
We can now visualize the learned permeability patterns across the landscape. This surface should be interpreted with great caution, as it may reflect complex interactions and correlations in the data rather than direct causal relationships.
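One simple way to inspect what has been learned is to evaluate the model on pure one-hot vectors, which gives one permeability value per land-cover class. The sketch below illustrates the idea with a stand-in function `toy_model` and made-up weights, both hypothetical. In the notebook, one would instead evaluate `fitted_model` on the rows of an identity matrix and pair the results with `unique_classes`.

```python
import numpy as np

# Hypothetical subset of class codes (ESA WorldCover uses 10 = tree cover,
# 30 = grassland, 80 = permanent water bodies)
class_codes = np.array([10, 30, 80])
K = len(class_codes)

# Stand-in for a fitted model: log-linear permeability with made-up weights
toy_weights = np.array([0.5, 1.2, -2.0])

def toy_model(x):
    """Map a feature vector of shape (K,) to a positive permeability."""
    return np.exp(toy_weights @ x) + 1e-3

# Each row of the identity matrix represents a pure cell of one class
per_class = {
    int(code): float(toy_model(row))
    for code, row in zip(class_codes, np.eye(K))
}

# List classes from least to most permeable
for code, value in sorted(per_class.items(), key=lambda kv: kv[1]):
    print(f"class {code}: permeability {value:.3f}")
```

Such a per-class table is easier to compare with ecological expectations than the map alone, though it says nothing about cells with mixed composition, where the network may respond nonlinearly.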

In [None]:
# Apply fitted model to landscape
fitted_model = eqx.combine(opt_solution.value, static)
fitted_vmapped = jax.vmap(jax.vmap(fitted_model, in_axes=0), in_axes=0)
fitted_permeability = fitted_vmapped(features_onehot).squeeze()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))

# Initial (random) permeability
im1 = ax1.imshow(initial_permeability, cmap="RdYlGn_r")
ax1.set_title("Initial permeability surface", fontsize=13, pad=15)
ax1.axis("off")
plt.colorbar(im1, ax=ax1, shrink=0.6, label="Permeability")

# Fitted permeability
im2 = ax2.imshow(fitted_permeability, cmap="RdYlGn_r")
ax2.set_title("Fitted permeability surface (Trained)", fontsize=13, pad=15)
ax2.axis("off")
plt.colorbar(im2, ax=ax2, shrink=0.6, label="Permeability")

plt.tight_layout()
plt.show()

print(f"Fitted permeability range: [{fitted_permeability.min():.3f}, {fitted_permeability.max():.3f}]")
print(f"Mean permeability: {fitted_permeability.mean():.3f}")

## Key takeaways

This notebook demonstrated inverse landscape genetics: learning resistance patterns from genetic data rather than assuming them *a priori*. JAX's automatic differentiation enables gradient-based optimization through the distance computation, which allows us to train a neural network that maps landscape features to a permeability surface. The neural network provides a flexible parameterization that can capture nonlinear relationships between landscape features and permeability. It would be interesting to compare the learned surface with expert knowledge about the species' ecology and known barriers to movement in the landscape. We could also assess the predictive performance of other distance metrics, e.g. `LCPDistance`, but this goes beyond the scope of this notebook.