<a href="https://githubtocolab.com/geonextgis/cropengine/blob/main/docs/examples/Run data assimilation (wofost).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/></a>

Uncomment the following line to install the latest version of [cropengine](https://geonextgis.github.io/cropengine) if needed.


In [None]:
# !pip install -U cropengine

## Import libraries


In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
from cropengine import WOFOSTCropSimulationRunner
from cropengine.assimilation import WOFOSTEnKF

## Instantiate crop simulation engine for WOFOST


In [None]:
# Name of the WOFOST model configuration used throughout this notebook
MODEL_NAME = "Wofost72_WLP_CWB"

# Create the simulation runner for that model; workspace_dir is the folder
# handed to the runner for this assimilation example
runner = WOFOSTCropSimulationRunner(
    workspace_dir="test_output/assimilation_workspace",
    model_name=MODEL_NAME,
)

In [None]:
# --- Site location (passed to prepare_system as latitude / longitude) ---
LATITUDE = 53.3721
LONGITUDE = 13.82299

# --- Crop and variety ---
CROP_NAME = "sugarbeet"
CROP_VARIETY = "Sugarbeet_601"

# --- Campaign window and crop calendar ---
CAMPAIGN_START = "2006-01-01"
CAMPAIGN_END = "2007-01-01"
CROP_START = "2006-04-05"
CROP_START_TYPE = "emergence"
CROP_END = "2006-10-20"
CROP_END_TYPE = "harvest"
MAX_DURATION = 300

# --- Valid choices, kept for inspection ---
# If unsure which values are accepted above, look at what the
# runner.get_..._options() helpers return.
models = runner.get_model_options()
crops = runner.get_crop_options(MODEL_NAME)
varieties = runner.get_variety_options(MODEL_NAME, CROP_NAME)
crop_start_end = runner.get_crop_start_end_options()

## Prepare the system (must be run before running the simulation)


In [None]:
# Hand the full configuration defined above to the runner in a single call.
runner.prepare_system(
    # Site
    latitude=LATITUDE,
    longitude=LONGITUDE,
    site_overrides={"WAV": 10},  # Extra site params can be passed as overrides
    # Crop
    crop_name=CROP_NAME,
    variety_name=CROP_VARIETY,
    # Campaign window and crop calendar
    campaign_start=CAMPAIGN_START,
    campaign_end=CAMPAIGN_END,
    crop_start=CROP_START,
    crop_start_type=CROP_START_TYPE,
    crop_end=CROP_END,
    crop_end_type=CROP_END_TYPE,
    max_duration=MAX_DURATION,
    # Refresh flags -- NOTE(review): exact semantics are defined by
    # prepare_system; confirm against the cropengine documentation.
    force_update=False,
    force_param_update=True,
)

## Define the observations for LAI and SM (demo data)


In [None]:
# Build a small demo observation table: one row per observation date,
# with one column per observed variable (LAI and SM).
obs_df = pd.DataFrame(
    dict(
        date=pd.to_datetime(["2006-06-01", "2006-07-01", "2006-08-01"]),
        LAI=[1.5, 4.2, 3.8],
        SM=[0.2, 0.25, 0.22],
    )
)

# Uncertainty attached to each observed variable
# (passed to run_assimilation as observation_std)
uncertainty = dict(LAI=0.2, SM=0.05)

# Define the state variables to track during assimilation
states = ["LAI", "SM", "TWST", "TWSO"]

## Initialize and run data assimilation using Ensemble Kalman Filter


In [None]:
# Wrap the prepared runner in an Ensemble Kalman Filter with 50 members
enkf = WOFOSTEnKF(runner, ensemble_size=50)

# Set up the ensemble.
# The keys of param_std (TDWI, WAV, SPAN, SMFCF) are the parameters that get
# perturbed to create variety in the ensemble growth curves; the values are
# presumably per-parameter standard deviations -- confirm against
# WOFOSTEnKF.setup_ensemble.
# NOTE(review): the WAV spread (10) equals the WAV site override (10) given to
# prepare_system; confirm that out-of-range draws are handled by the sampler.
enkf.setup_ensemble(
    param_std={
        "TDWI": 0.5,
        "WAV": 10,
        "SPAN": 5.0,
        "SMFCF": 0.03,
    }
)

# Run the assimilation against the observation table, using the per-variable
# observation uncertainty and tracking the requested state variables
results = enkf.run_assimilation(
    observations_df=obs_df,
    observation_std=uncertainty,
    state_vars=states,
)

print(results.shape)
results.head()

## Summarize the ensemble


In [None]:
# Collapse the ensemble: for every day, compute the mean and standard deviation
# of each tracked variable. The result has two-level columns such as
# ("LAI", "mean") and ("LAI", "std").
summary = results.groupby("day").agg(
    {name: ["mean", "std"] for name in ("LAI", "SM", "TWST", "TWSO")}
)

print(summary.shape)
summary.tail()

## Plot the data


In [None]:
# Ensemble summary (EnKF output): drop days with missing statistics and use a
# datetime index so the x-axis lines up with the observation dates.
df = summary.dropna()
df.index = pd.to_datetime(df.index)

# Observations (work on a copy so the original table is left untouched)
obs = obs_df.copy()
obs["date"] = pd.to_datetime(obs["date"])

# One panel per state variable, laid out on a 2 x 2 grid
state_vars = ["LAI", "SM", "TWST", "TWSO"]

fig, axes = plt.subplots(2, 2, figsize=(14, 8), sharex=True)
axes = axes.flatten()

for ax, var in zip(axes, state_vars):
    mean = df[(var, "mean")]
    std = df[(var, "std")]

    # Ensemble mean ± std
    ax.plot(df.index, mean, label=f"Model {var} (mean)")
    ax.fill_between(
        df.index, mean - std, mean + std, alpha=0.3, label=f"Model {var} ± std"
    )

    # Observations (only LAI and SM are observed in this example)
    if var in obs.columns:
        # Without yerr no error bars are drawn and capsize has no effect, so
        # pass the observation standard deviation used for the assimilation.
        # .get() yields None (no bars) for a variable without an entry.
        ax.errorbar(
            obs["date"],
            obs[var],
            yerr=uncertainty.get(var),
            fmt="o",
            capsize=4,
            label=f"Observed {var}",
        )

    ax.set_title(f"{var} assimilation (EnKF)")
    ax.set_ylabel(var)
    ax.grid(alpha=0.3)
    ax.legend(fontsize=8)

# Common x-label on the bottom row only (the x-axis is shared)
for ax in axes[2:]:
    ax.set_xlabel("Date")

plt.tight_layout()
plt.show()