<a href="https://githubtocolab.com/geonextgis/cropengine/blob/main/docs/examples/Run%20parameter%20optimization%20(wofost).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/></a>

Uncomment the following line to install the latest version of [cropengine](https://geonextgis.github.io/cropengine) if needed.

In [None]:
# !pip install -U cropengine

## Import libraries

In [None]:
import os
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from cropengine import WOFOSTCropSimulationBatchRunner
from cropengine.agromanagement import WOFOSTAgroEventBuilder
from cropengine.optimizer import WOFOSTOptimizer

from sklearn.metrics import mean_squared_error

## Instantiate batch crop simulation engine for WOFOST

In [None]:
# Define the model name (reused below when querying crop and variety options)
MODEL_NAME = "Wofost72_WLP_CWB"

# Define the csv path with 'id', 'latitude', and 'longitude'
# (relative path: resolved against the notebook's working directory)
locations_csv_path = "test_data/optimizer/location.csv"

# Initialize Engine
# NOTE(review): workspace_dir is presumably where the runner stores its
# per-location inputs and outputs - confirm against the cropengine docs.
batch_runner = WOFOSTCropSimulationBatchRunner(
    model_name=MODEL_NAME,
    locations_csv_path=locations_csv_path,
    workspace_dir="test_output/optimizer_workspace",
)

## User inputs

In [None]:
# --- Choices used throughout this example --------------------------------
# Crop and variety
CROP_NAME = "wheat"
CROP_VARIETY = "Winter_wheat_103"

# Campaign window
CAMPAIGN_START = "2019-09-01"
CAMPAIGN_END = "2020-09-30"

# Crop calendar: a start date and type, and an end type with no fixed end date
CROP_START = "2019-09-25"
CROP_START_TYPE = "sowing"
CROP_END = None
CROP_END_TYPE = "maturity"
MAX_DURATION = 365

# --- Valid options reported by the runner --------------------------------
# These are not used further down; inspect them to pick different values.
models = batch_runner.get_model_options()
crops = batch_runner.get_crop_options(MODEL_NAME)
varieties = batch_runner.get_variety_options(MODEL_NAME, CROP_NAME)
crop_start_end = batch_runner.get_crop_start_end_options()

## Create the agromanagement events from user inputs

In [None]:
agro_event_builder = WOFOSTAgroEventBuilder()

# Note: Use agro_event_builder.get_..._events_info() to see valid values if unsure
timed_events_info = agro_event_builder.get_timed_events_info()
state_events_info = agro_event_builder.get_state_events_info()

# Timed events: irrigation on fixed calendar dates.
# Every application uses the same efficiency; only date and amount vary.
irrigation_schedule = [
    {"event_date": event_date, "amount": amount, "efficiency": 0.7}
    for event_date, amount in (
        ("2020-03-20", 3.0),  # stem elongation
        ("2020-04-25", 2.5),  # booting/heading
        ("2020-05-20", 2.0),  # flowering
    )
]

irrigation_events = agro_event_builder.create_timed_events(
    signal_type="irrigate",
    events_list=irrigation_schedule,
)

# State events: nitrogen applied when DVS rises through a threshold.
# Every application uses the same recovery; only threshold and amount vary.
nitrogen_schedule = [
    {"threshold": dvs_threshold, "N_amount": n_amount, "N_recovery": 0.7}
    for dvs_threshold, n_amount in (
        (0.3, 40),  # early vegetative
        (0.6, 60),  # stem elongation
        (1.0, 40),  # heading
    )
]

nitrogen_events = agro_event_builder.create_state_events(
    signal_type="apply_n",
    state_var="DVS",
    zero_condition="rising",
    events_list=nitrogen_schedule,
)

## Prepare batch system

In [None]:
batch_runner.prepare_batch_system(
    # Crop and variety
    crop_name=CROP_NAME,
    variety_name=CROP_VARIETY,
    # Campaign window
    campaign_start=CAMPAIGN_START,
    campaign_end=CAMPAIGN_END,
    # Crop calendar
    crop_start=CROP_START,
    crop_start_type=CROP_START_TYPE,
    crop_end=CROP_END,
    crop_end_type=CROP_END_TYPE,
    max_duration=MAX_DURATION,
    # Management events (irrigation and nitrogen)
    timed_events=[irrigation_events],
    state_events=[nitrogen_events],
    # Parameter overrides
    crop_overrides=None,
    soil_overrides=None,
    site_overrides={"WAV": 10},  # Extra site params can be passed as overrides
    # Execution and refresh flags
    max_workers=5,
    force_update=False,
    force_param_update=True,
)

## Run the simulation first

In [None]:
# Baseline run: no optimizer has updated the crop parameters yet
# (prepare_batch_system was called with crop_overrides=None).
results = batch_runner.run_batch_simulation(max_workers=5)
print(results.shape)
# Last expression of the cell: rendered as the cell output.
results.head()

## Plot the results (before optimization)

In [None]:
# Baseline results: one panel per simulated variable, one line per location.

# Convert 'day' on a copy so `results` itself keeps its original values
batch_results = results.copy()
batch_results["day"] = pd.to_datetime(batch_results["day"])

# Every column that is not identifying metadata is a variable to plot
metadata_cols = ("point_id", "latitude", "longitude", "day")
vars_to_plot = [name for name in batch_results.columns if name not in metadata_cols]

# Two panels per row, as many rows as needed
cols = 2
rows = math.ceil(len(vars_to_plot) / cols)

# One fixed colour per location; each location's rows are selected once up front
unique_points = batch_results["point_id"].unique()
palette = sns.color_palette("tab10", len(unique_points))
color_map = dict(zip(unique_points, palette))
rows_by_point = {
    point: batch_results[batch_results["point_id"] == point]
    for point in unique_points
}

fig, axes = plt.subplots(rows, cols, figsize=(14, 3 * rows), sharex=True)
axes = axes.flatten()

for panel, variable in zip(axes, vars_to_plot):
    for point, point_rows in rows_by_point.items():
        sns.lineplot(
            x=point_rows["day"],
            y=point_rows[variable],
            ax=panel,
            label=f"Point {point}",
            color=color_map[point],
        )

    panel.set_title(variable)
    panel.legend()

# Switch off the trailing panels that have no variable assigned
for unused_panel in axes[len(vars_to_plot) :]:
    unused_panel.axis("off")

plt.tight_layout()
plt.show()

## Optimize phenology

In [None]:
# Load the observed phenology
# The phenology loss function in this cell reads the columns 'id',
# 'flowering_doy' and 'maturity_doy' from this table.
obs_phenology_df = pd.read_csv("test_data/optimizer/phenology_observed.csv")
# The optimizer is bound to the batch runner prepared above and to these observations.
phenology_optimizer = WOFOSTOptimizer(
    runner=batch_runner, observed_data=obs_phenology_df
)


# Create the loss function for phenology
def loss_fn_phenology(sim_df, obs_df):
    """Return the mean of the flowering and maturity RMSE, in days.

    Simulated flowering is the day of year of the earliest date on which
    ``DVS`` is at least 1, and simulated maturity the earliest date on which
    it is at least 2, per ``point_id``. Locations that never reach both
    stages, or that have no matching observation, are dropped by the inner
    joins and do not contribute to the loss.

    Neither input frame is modified.

    Args:
        sim_df: Simulation output with at least the columns ``point_id``,
            ``day`` (anything ``pd.to_datetime`` accepts) and ``DVS``.
        obs_df: Observations with the columns ``id``, ``flowering_doy`` and
            ``maturity_doy``.

    Returns:
        ``(flowering RMSE + maturity RMSE) / 2`` rounded to two decimals.

    Raises:
        ValueError: If no location has both observed and simulated dates,
            or if a compared value is missing or non-finite.
    """
    phenology_obs = obs_df[["id", "flowering_doy", "maturity_doy"]]

    # Copy only the columns needed: converting 'day' on sim_df itself would
    # silently change the frame owned by the caller.
    sim = sim_df[["point_id", "day", "DVS"]].copy()
    sim["day"] = pd.to_datetime(sim["day"])

    def first_doy_at(dvs_threshold, column):
        # ">=" rather than "==": DVS is a float, so it may step over the
        # threshold without ever equalling it, or stay at it for several
        # days. Taking the earliest qualifying date gives one row per point.
        reached = sim[sim["DVS"] >= dvs_threshold]
        first = reached.sort_values("day", kind="stable").drop_duplicates("point_id")
        first = first.assign(**{column: first["day"].dt.day_of_year})
        return first[["point_id", column]]

    def rmse(observed, simulated):
        diff = observed.to_numpy(dtype=float) - simulated.to_numpy(dtype=float)
        if not np.isfinite(diff).all():
            raise ValueError(
                "Phenology loss: observed or simulated day of year is missing "
                "or non-finite."
            )
        return np.sqrt(np.mean(diff**2))

    phenology_sim = pd.merge(
        left=first_doy_at(1.0, "flowering_doy_sim"),
        right=first_doy_at(2.0, "maturity_doy_sim"),
        on="point_id",
        how="inner",
    )

    merged_df = pd.merge(
        left=phenology_obs, right=phenology_sim, left_on="id", right_on="point_id"
    )
    if merged_df.empty:
        # Fail with an explicit message instead of an opaque metric error.
        raise ValueError(
            "Phenology loss: no location has both observed and simulated "
            "flowering and maturity dates."
        )

    flowering_loss = rmse(merged_df["flowering_doy"], merged_df["flowering_doy_sim"])
    maturity_loss = rmse(merged_df["maturity_doy"], merged_df["maturity_doy_sim"])

    return np.round((flowering_loss + maturity_loss) / 2, 2)


# Define the search space
def search_space(trial):
    """Suggest the crop parameters tuned for phenology.

    Args:
        trial: Object providing ``suggest_int(name, low, high)``.

    Returns:
        ``{"crop_params": {"TSUM1": ..., "TSUM2": ...}}`` with both values
        suggested as integers between 100 and 1200.
    """
    tsum1 = trial.suggest_int("TSUM1", 100, 1200)
    tsum2 = trial.suggest_int("TSUM2", 100, 1200)
    return {"crop_params": {"TSUM1": tsum1, "TSUM2": tsum2}}


# Calibrate phenology: 100 trials on 5 workers, minimizing the single value
# returned by loss_fn_phenology.
# NOTE(review): sampler="TPE" presumably selects Optuna's TPE sampler - confirm
# against the WOFOSTOptimizer documentation.
# NOTE: the name `study` is rebound by the yield optimization later in this notebook.
study = phenology_optimizer.optimize(
    search_space,
    loss_fn_phenology,
    n_trials=100,
    n_workers=5,
    directions=["minimize"],
    sampler="TPE",
)

### Run the simulation with optimized parameters

In [None]:
# Extract the best phenology parameters from the finished study.
# NOTE: `study` and `search_space` are both redefined by the yield-optimization
# cell further down. Run this cell before that one (or restart and run the
# notebook in order) so the phenology versions are the ones used here.
best_params = phenology_optimizer.get_best_params(study, search_space)

# Update the parameters in the workspace
batch_runner.update_parameters(crop_overrides=best_params["crop_params"])

# Run the simulations with updated parameters
# (`results` from the baseline run is overwritten here)
results = batch_runner.run_batch_simulation(max_workers=5)
print(results.shape)
results.head()

### Plot the simulation

In [None]:
# Results after phenology calibration: one panel per simulated variable,
# one line per location.

# Convert 'day' on a copy so `results` itself keeps its original values
batch_results = results.copy()
batch_results["day"] = pd.to_datetime(batch_results["day"])

# Every column that is not identifying metadata is a variable to plot
metadata_cols = ("point_id", "latitude", "longitude", "day")
vars_to_plot = [name for name in batch_results.columns if name not in metadata_cols]

# Two panels per row, as many rows as needed
cols = 2
rows = math.ceil(len(vars_to_plot) / cols)

# One fixed colour per location; each location's rows are selected once up front
unique_points = batch_results["point_id"].unique()
palette = sns.color_palette("tab10", len(unique_points))
color_map = dict(zip(unique_points, palette))
rows_by_point = {
    point: batch_results[batch_results["point_id"] == point]
    for point in unique_points
}

fig, axes = plt.subplots(rows, cols, figsize=(14, 3 * rows), sharex=True)
axes = axes.flatten()

for panel, variable in zip(axes, vars_to_plot):
    for point, point_rows in rows_by_point.items():
        sns.lineplot(
            x=point_rows["day"],
            y=point_rows[variable],
            ax=panel,
            label=f"Point {point}",
            color=color_map[point],
        )

    panel.set_title(variable)
    panel.legend()

# Switch off the trailing panels that have no variable assigned
for unused_panel in axes[len(vars_to_plot) :]:
    unused_panel.axis("off")

plt.tight_layout()
plt.show()

## Optimize yield (TWSO)

In [None]:
# Load the observed yield
# The yield loss function in this cell reads the columns 'id' and 'yield'
# and converts 'yield' from t/ha to kg/ha.
obs_yield_df = pd.read_csv("test_data/optimizer/yield_observed.csv")
# Reuses the same batch runner, so the crop parameters already written to the
# workspace by batch_runner.update_parameters(...) earlier remain in place.
yield_optimizer = WOFOSTOptimizer(runner=batch_runner, observed_data=obs_yield_df)


# Create the loss function for yield
def loss_fn_yield(sim_df, obs_df):
    """Return the RMSE between observed yield and simulated TWSO.

    The simulated value per ``point_id`` is ``TWSO`` on the earliest date on
    which ``DVS`` is at least 2, regardless of the row order of ``sim_df``.
    Observed ``yield`` is multiplied by 1000 (t/ha -> kg/ha) before the
    comparison. Locations without a matching observation, or that never
    reach ``DVS >= 2``, are dropped by the inner join.

    Neither input frame is modified.

    Args:
        sim_df: Simulation output with at least the columns ``point_id``,
            ``day`` (anything ``pd.to_datetime`` accepts), ``DVS`` and
            ``TWSO``.
        obs_df: Observations with the columns ``id`` and ``yield``.

    Returns:
        The RMSE rounded to two decimals.

    Raises:
        ValueError: If no location has both an observed and a simulated
            value, or if a compared value is missing or non-finite.
    """
    # Process observed data
    observed = obs_df[["id", "yield"]].copy()
    observed["yield"] = observed["yield"] * 1000  # t/ha -> kg/ha

    # Copy only the columns needed: converting 'day' on sim_df itself would
    # silently change the frame owned by the caller.
    sim = sim_df[["point_id", "day", "DVS", "TWSO"]].copy()
    sim["day"] = pd.to_datetime(sim["day"])

    # Sort by date first so "first" means the earliest day at DVS >= 2
    # rather than whichever row happens to come first in the input.
    simulated = (
        sim[sim["DVS"] >= 2]
        .sort_values("day", kind="stable")
        .groupby("point_id")["TWSO"]
        .first()
        .reset_index()
    )

    merged_df = pd.merge(
        left=observed, right=simulated, left_on="id", right_on="point_id"
    )
    if merged_df.empty:
        # Fail with an explicit message instead of an opaque metric error.
        raise ValueError(
            "Yield loss: no location has both an observed yield and a "
            "simulated TWSO at DVS >= 2."
        )

    diff = merged_df["yield"].to_numpy(dtype=float) - merged_df["TWSO"].to_numpy(
        dtype=float
    )
    if not np.isfinite(diff).all():
        raise ValueError("Yield loss: observed yield or simulated TWSO is missing or non-finite.")

    yield_loss = np.sqrt(np.mean(diff**2))

    return np.round(yield_loss, 2)


# Define the search space
def search_space(trial):
    """Suggest the crop parameters tuned for yield.

    Two scaling factors (each between 0.8 and 1.2) rescale the y-values of
    the SLATB and AMAXTB tables; SPAN, CVO and RDMCR are suggested directly.

    Args:
        trial: Object providing ``suggest_float(name, low, high)`` and
            ``suggest_int(name, low, high)``.

    Returns:
        ``{"crop_params": {...}}`` with the keys ``SPAN``, ``SLATB``,
        ``AMAXTB``, ``CVO`` and ``RDMCR``. The two tables are flat lists of
        alternating x and y values.
    """
    # Scaling factors: shift a whole table up or down by at most +/- 20%
    amax_factor = trial.suggest_float("amax_factor", 0.8, 1.2)  # photosynthesis (AMAX)
    sla_factor = trial.suggest_float("sla_factor", 0.8, 1.2)  # specific leaf area (SLA)

    # LEAF DYNAMICS (source capacity)
    # SPAN: leaf lifespan. Higher = longer green canopy duration.
    span = trial.suggest_float("SPAN", 25.0, 40.0)

    # SLATB: controls how much leaf area is built per kg biomass.
    # The same scaled value is used at every breakpoint.
    slatb = [
        entry
        for breakpoint_x in (0.0, 0.5, 2.0)
        for entry in (breakpoint_x, 0.00212 * sla_factor)
    ]

    # ASSIMILATION & CONVERSION (biomass production)
    # AMAXTB: max CO2 assimilation rate. Highly sensitive.
    amax_points = (
        (0.0, 35.83),  # Vegetative
        (1.0, 35.83),  # Flowering
        (1.3, 35.83),  # Early grain filling
        (2.0, 4.48),  # Maturity (senescence)
    )
    amaxtb = [
        entry
        for breakpoint_x, base_rate in amax_points
        for entry in (breakpoint_x, base_rate * amax_factor)
    ]

    # CVO: efficiency of conversion to storage organs (harvest index driver).
    cvo = trial.suggest_float("CVO", 0.65, 0.75)

    # ROOTING (water access)
    # RDMCR: max rooting depth. Critical for drought resistance.
    rdmcr = trial.suggest_int("RDMCR", 80, 150)

    return {
        "crop_params": {
            "SPAN": span,
            "SLATB": slatb,
            "AMAXTB": amaxtb,
            "CVO": cvo,
            "RDMCR": rdmcr,
        }
    }


# Calibrate yield: 1000 trials (ten times the phenology run) on 5 workers,
# minimizing the single value returned by loss_fn_yield.
# NOTE: this rebinds `study`; the phenology study object is no longer
# reachable under that name after this cell has run.
study = yield_optimizer.optimize(
    search_space,
    loss_fn_yield,
    n_trials=1000,
    n_workers=5,
    directions=["minimize"],
    sampler="TPE",
)

### Run the simulation with optimized parameters

In [None]:
# Extract the best yield parameters from the finished study
# (`study` and `search_space` here are the yield versions defined in the cell above).
best_params = yield_optimizer.get_best_params(study, search_space)

# Update the parameters in the workspace
batch_runner.update_parameters(crop_overrides=best_params["crop_params"])

# Run the simulations with updated parameters
# (`results` from the previous run is overwritten here)
results = batch_runner.run_batch_simulation(max_workers=5)
print(results.shape)
results.head()

### Plot the simulation

In [None]:
# Results after yield calibration: one panel per simulated variable,
# one line per location.

# Convert 'day' on a copy so `results` itself keeps its original values
batch_results = results.copy()
batch_results["day"] = pd.to_datetime(batch_results["day"])

# Every column that is not identifying metadata is a variable to plot
metadata_cols = ("point_id", "latitude", "longitude", "day")
vars_to_plot = [name for name in batch_results.columns if name not in metadata_cols]

# Two panels per row, as many rows as needed
cols = 2
rows = math.ceil(len(vars_to_plot) / cols)

# One fixed colour per location; each location's rows are selected once up front
unique_points = batch_results["point_id"].unique()
palette = sns.color_palette("tab10", len(unique_points))
color_map = dict(zip(unique_points, palette))
rows_by_point = {
    point: batch_results[batch_results["point_id"] == point]
    for point in unique_points
}

fig, axes = plt.subplots(rows, cols, figsize=(14, 3 * rows), sharex=True)
axes = axes.flatten()

for panel, variable in zip(axes, vars_to_plot):
    for point, point_rows in rows_by_point.items():
        sns.lineplot(
            x=point_rows["day"],
            y=point_rows[variable],
            ax=panel,
            label=f"Point {point}",
            color=color_map[point],
        )

    panel.set_title(variable)
    panel.legend()

# Switch off the trailing panels that have no variable assigned
for unused_panel in axes[len(vars_to_plot) :]:
    unused_panel.axis("off")

plt.tight_layout()
plt.show()