#### Run EnKF on all the parcels

In [2]:
import fpcup

In [3]:
from pcse.fileinput import YAMLCropDataProvider
from pcse.base import ParameterProvider
from pcse.util import WOFOST71SiteDataProvider
from pcse.models import Wofost72_WLP_FD

In [4]:
import copy
from pathlib import Path
import datetime as dt
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import os
import re
import rasterio
from rasterio.transform import rowcol
from pyproj import Transformer
from datetime import datetime
from tqdm import tqdm
import matplotlib.dates as mdates

In [5]:
def set_wofost_up(crop="maize", variety="Grain_maize_201",
                  meteo="./data/meteo/csv_meteo_mazie_new/weather_52.6574_6.8533_2022.csv",
                  soil="ec1", wav=60, rdmsol=100):
    """Assemble the inputs for a single WOFOST 7.2 run.

    Parameters
    ----------
    crop, variety : str
        Crop and variety activated in the PCSE ``YAMLCropDataProvider``.
    meteo : str
        Path of the weather CSV, loaded with ``fpcup.weather.load_example_csv``.
    soil : str
        Key into ``fpcup.soil.soil_types``.
    wav : float
        ``WAV`` passed to ``WOFOST71SiteDataProvider``.
    rdmsol : float
        Value written to the soil parameter ``RDMSOL``.

    Returns
    -------
    tuple
        ``(parameters, agromanagement, weather_data_provider)``.

    Notes
    -----
    The agromanagement always comes from ``fpcup.crop.crops["maize (green)"]``
    with its first sowing date in 2022, independent of ``crop``.
    """
    cropdata = YAMLCropDataProvider()
    cropdata.set_active_crop(crop, variety)

    # Work on a copy: fpcup.soil.soil_types is shared module-level data, and
    # writing RDMSOL into it directly would leak into every later call.
    soildata = copy.deepcopy(fpcup.soil.soil_types[soil])
    soildata["RDMSOL"] = rdmsol

    sitedata = WOFOST71SiteDataProvider(WAV=wav)
    parameters = ParameterProvider(cropdata=cropdata,
                                   soildata=soildata, sitedata=sitedata)

    # Use a separate name so the ``crop`` argument is not shadowed.
    maize_crop = fpcup.crop.crops["maize (green)"]
    agromanagement = maize_crop.agromanagement_first_sowingdate(2022)
    wdp = fpcup.weather.load_example_csv(meteo)
    return parameters, agromanagement, wdp

In [6]:
# --- Sentinel-2 observations ---
# Folder of LAI GeoTIFFs that is passed to extract_lai_with_uncertainty below.
tif_folder = "G:/S2_bio"


def get_transformer(crs):
    """Return a pyproj transformer from WGS84 lon/lat (EPSG:4326) into ``crs``.

    ``always_xy=True`` keeps the (lon, lat) -> (x, y) argument order no matter
    which axis order either CRS declares.
    """
    wgs84 = "EPSG:4326"
    return Transformer.from_crs(wgs84, crs, always_xy=True)

# Acquisition stamp in the raster file names, e.g. "..._20220715T104031_...".
_S2_DATE_PATTERN = re.compile(r"_([0-9]{8})T[0-9]{6}_")


def _valid_pixels(values, nodata):
    """Return the finite entries of ``values`` that differ from ``nodata``, flattened."""
    values = np.asarray(values)
    mask = np.isfinite(values)
    # None means the raster declares no nodata value. A NaN nodata is already
    # removed by isfinite; an equality test could never drop it (NaN != NaN).
    if nodata is not None and not np.isnan(nodata):
        mask &= values != nodata
    return values[mask]


def extract_lai_with_uncertainty(coordinates, date_start, date_end, tif_folder,
                                 window_size=5, lai_std_algorithm=0.9):
    """Sample LAI (mean and uncertainty) around points from a folder of GeoTIFFs.

    Every ``*.tif`` in ``tif_folder`` whose name contains an acquisition stamp
    ``_YYYYMMDDTHHMMSS_`` dated within ``[date_start, date_end]`` is opened.
    Band 1 is read, and for each point a ``window_size`` x ``window_size``
    pixel window centred on the point (clipped at the raster edges) is
    summarised.

    Parameters
    ----------
    coordinates : tuple or list of tuple
        A single ``(lon, lat)`` pair in WGS84 degrees, or a list of such pairs.
    date_start, date_end : datetime.datetime
        Inclusive acquisition-date range.
    tif_folder : str or path-like
        Folder containing the rasters.
    window_size : int
        Side length of the square sampling window, in pixels.
    lai_std_algorithm : float
        Retrieval standard deviation of the LAI product (default 0.9, the
        SNAP retrieval error used by this notebook).

    Returns
    -------
    dict
        Maps each ``(lon, lat)`` to a list of ``(date, lai_mean, lai_std)``
        tuples in file-name order. ``lai_std`` is
        ``sqrt(0.7 * spatial_std**2 + 0.3 * lai_std_algorithm**2)``, where
        ``spatial_std`` is the standard deviation of the valid pixels in the
        window. A point gets no entry for a raster it falls outside of, or
        whose window contains no finite, non-nodata pixels.
    """
    if isinstance(coordinates, tuple):
        coordinates = [coordinates]

    data_per_point = {coord: [] for coord in coordinates}
    half_window = window_size // 2

    # Sorted so the result order does not depend on the filesystem.
    for fname in sorted(os.listdir(tif_folder)):
        if not fname.lower().endswith(".tif"):
            continue
        match = _S2_DATE_PATTERN.search(fname)
        if not match:
            continue
        try:
            date = datetime.strptime(match.group(1), "%Y%m%d")
        except ValueError:
            continue  # eight digits that are not a real calendar date
        if not (date_start <= date <= date_end):
            continue

        fpath = os.path.join(tif_folder, fname)
        with rasterio.open(fpath) as src:
            transformer = get_transformer(src.crs)
            band1 = src.read(1)
            nodata = src.nodata
            n_rows, n_cols = band1.shape

            for lon, lat in coordinates:
                x, y = transformer.transform(lon, lat)
                row, col = rowcol(src.transform, x, y)
                if not (0 <= row < n_rows and 0 <= col < n_cols):
                    continue

                window_values = band1[
                    max(0, row - half_window):min(n_rows, row + half_window + 1),
                    max(0, col - half_window):min(n_cols, col + half_window + 1),
                ]
                valid_values = _valid_pixels(window_values, nodata)
                if valid_values.size == 0:
                    continue

                lai_mean = np.mean(valid_values)
                lai_std_spatial = np.std(valid_values)
                # Weighted blend of within-window spread and retrieval error.
                lai_std_total = np.sqrt(0.7 * lai_std_spatial**2
                                        + 0.3 * lai_std_algorithm**2)
                data_per_point[(lon, lat)].append((date, lai_mean, lai_std_total))

    return data_per_point

In [7]:

# --- Parcels and observation window ---
# One weather CSV per parcel; names look like "weather_<lat>_<lon>_<year>.csv"
# and are the source of the parcel coordinates.
weather_folder = "./data/meteo/ensemble/maize/maize_25"
# Sorted so parcels (and therefore lai_data and every downstream run) are
# processed in a reproducible order; os.listdir order is filesystem-dependent.
weather_files = sorted(f for f in os.listdir(weather_folder) if f.endswith(".csv"))

coordinates = []    # (lon, lat) pairs, the order extract_lai_with_uncertainty unpacks
site_name_map = {}  # (lon, lat) -> weather file name without ".csv"
for fname in weather_files:
    match = re.search(r"weather_([0-9.]+)_([0-9.]+)_\d{4}\.csv", fname)
    if match:
        lat = float(match.group(1))
        lon = float(match.group(2))
        coordinates.append((lon, lat))
        site_name_map[(lon, lat)] = fname.replace(".csv", "")

date_start = datetime(2022, 5, 1)
date_end = datetime(2022, 10, 30)
obs_period = (date_start, date_end)

# Per-parcel lists of (date, lai_mean, lai_std) sampled from the Sentinel-2 rasters.
lai_data = extract_lai_with_uncertainty(
    coordinates=coordinates,
    date_start=date_start,
    date_end=date_end,
    tif_folder=tif_folder,
    window_size=5,
)


In [8]:
# Number of parcels with an LAI series (one entry per parsed coordinate).
len(lai_data.keys())

25

In [9]:
class WOFOSTEnKF(object):
    """Ensemble Kalman filter around an ensemble of PCSE ``Wofost72_WLP_FD`` runs.

    Each member gets its own parameter overrides. At every observation date
    the members are advanced to that date. The assimilated state variables
    are then updated with a perturbed-observation EnKF analysis (identity
    observation operator), and the analysed values are written back into the
    models.

    Parameters
    ----------
    assimilation_variables : list of str
        Model state variables to update, e.g. ``["LAI"]`` or ``["LAI", "SM"]``.
    override_parameters : dict
        Maps a parameter name to a sequence of ``n_ensemble`` values; member
        ``i`` receives element ``i``.
    n_ensemble : int
        Ensemble size.
    observations : list of (date, dict)
        Each dict maps a variable name to ``(value, std)``.
    lai_unc, sm_unc : float
        Stored on the instance; not used by the methods of this class.
    """

    def __init__(self, assimilation_variables, override_parameters,
                 n_ensemble, observations,
                 lai_unc=0.1, sm_unc=0.25):
        self.n_ensemble = n_ensemble
        self.assimilation_variables = assimilation_variables
        self.override_parameters = override_parameters
        self.lai_unc = lai_unc
        self.sm_unc = sm_unc
        self.observations = observations

    def setup_wofost(self, crop="maize", variety="Grain_maize_201",
                     meteo="./data/meteo/csv_meteo_mazie_new/weather_52.6501_6.4941_2022.csv",
                     soil="ec1", wav=60, rdmsol=100):
        """Load the shared WOFOST inputs via ``set_wofost_up`` and build the ensemble."""
        self.parameters, self.agromanagement, self.weather_db = set_wofost_up(
            crop=crop, variety=variety,
            meteo=meteo, soil=soil, wav=wav, rdmsol=rdmsol)
        self._setup_ensemble()

    def _setup_ensemble(self):
        """Create one model per member, each with its own parameter overrides."""
        self.ensemble = []
        for i in range(self.n_ensemble):
            # Deep copy so one member's overrides do not affect the others.
            p = copy.deepcopy(self.parameters)
            for par, distr in self.override_parameters.items():
                p.set_override(par, distr[i])
            member = Wofost72_WLP_FD(p, self.weather_db, self.agromanagement)
            self.ensemble.append(member)

    def run_filter(self):
        """Assimilate all observations, then run every member to termination.

        An observation date is skipped if the observation lacks one of the
        assimilated variables, or if any member returns no value (NaN/None)
        for an assimilated variable on that date.

        Returns
        -------
        list of pandas.DataFrame
            Each member's ``get_output()``, indexed by ``day``.
        """
        variables = self.assimilation_variables
        if variables:
            for obs_date, obs in tqdm(self.observations):
                if any(var not in obs for var in variables):
                    continue
                ensemble_state = self._run_wofost_gather_sates(obs_date)
                if ensemble_state.isnull().values.any():
                    continue
                # Drawn only after the state is known to be usable, so the
                # sequence of random draws is unchanged.
                ensemble_obs = self._observations_ensemble(obs)

                P = np.array(ensemble_state.cov().values)
                R = np.array(ensemble_obs.cov().values)
                K = self.kalman_gain(variables, P, R)

                # Shape (n_vars, n_ensemble): one column per member.
                x = np.array(ensemble_state.values).T
                y = np.array(ensemble_obs.values).T
                x_opt = x + K @ (y - x)

                for member, new_values in zip(self.ensemble, x_opt.T):
                    for var, value in zip(variables, new_values):
                        member.set_variable(var, value)

        for member in self.ensemble:
            member.run_till_terminate()

        return [pd.DataFrame(member.get_output()).set_index("day")
                for member in self.ensemble]

    def kalman_gain(self, obs, P, R):
        """Return ``K = P H^T (H P H^T + R)^-1`` with ``H`` the identity.

        ``obs`` is only used for its length, which sets the size of ``H``.
        """
        H = np.identity(len(obs))
        return P @ H.T @ np.linalg.inv(H @ P @ H.T + R)

    def _run_wofost_gather_sates(self, date):
        """Advance every member to ``date`` and collect the assimilated variables.

        Returns a DataFrame with one row per member and one column per variable.
        """
        for member in self.ensemble:
            member.run_till(date)
        return pd.DataFrame([
            {state: member.get_variable(state) for state in self.assimilation_variables}
            for member in self.ensemble
        ])

    def _observations_ensemble(self, observations):
        """Draw ``n_ensemble`` perturbed observations per variable from N(value, std).

        Returns a DataFrame with one row per member and one column per variable.
        """
        fake_obs = []
        for state_var in self.assimilation_variables:
            value, std = observations[state_var]
            fake_obs.append(np.random.normal(value, std, self.n_ensemble))
        df_obs = pd.DataFrame(fake_obs).T
        df_obs.columns = self.assimilation_variables
        return df_obs

In [10]:
def run_ensemble(n_ensemble, obs_period,
                 assim_lai=True, assim_sm=False,
                 ens_param_inflation=1.,
                 observations=None,
                 site_name=None,
                 meteo=None):
    """Run the WOFOST EnKF for one site, save a diagnostic figure, return results.

    Parameters
    ----------
    n_ensemble : int
        Number of ensemble members.
    obs_period : tuple of datetime
        ``(start, end)``; sets the figure's x-range and the default figure name.
    assim_lai, assim_sm : bool
        Whether LAI and/or SM are assimilated.
    ens_param_inflation : float
        Multiplier on the standard deviation of every perturbed parameter.
    observations : list of (date, dict)
        Observations in the form expected by ``WOFOSTEnKF``; required.
    site_name : str, optional
        Stem of the figure file under ``G:/da_plots/all_maize``. Defaults to
        ``"enkf_<start>_<end>"`` with dates formatted as ``%d%b``.
    meteo : str, optional
        Weather CSV forwarded to ``WOFOSTEnKF.setup_wofost``; when omitted,
        that method's default file is used.

    Returns
    -------
    results : list of pandas.DataFrame
        Per-member output indexed by ``day``, with an added ``date`` column.
    df_mean : pandas.DataFrame
        Ensemble mean per day.

    Raises
    ------
    ValueError
        If ``observations`` is None.
    """
    if observations is None:
        raise ValueError("You must provide Sentinel-2 observation data via the 'observations' argument.")

    start_date, end_date = obs_period
    sd = start_date.strftime("%d%b")
    ed = end_date.strftime("%d%b")
    base_dir = Path("G:/da_plots/all_maize")
    # Always define an output stem so the figure can be saved without a site name.
    fname_out_str = base_dir / (f"{site_name}" if site_name else f"enkf_{sd}_{ed}")

    # Fixed seed: parameter sampling (and the later random draws in the
    # filter) is reproducible. The draw order below must stay as it is.
    np.random.seed(42)
    # NOTE(review): only TDWI is clipped; the other normals are unbounded, so
    # some members can get e.g. negative WAV or CVL > 1 -- confirm acceptable.
    override_parameters = {}
    # Initial conditions
    override_parameters["TDWI"] = np.clip(
        np.random.normal(60, ens_param_inflation * 30., n_ensemble), 30, 150)
    override_parameters["WAV"] = np.random.normal(10, ens_param_inflation * 5, n_ensemble)
    # Parameters
    override_parameters["SPAN"] = np.random.normal(33, ens_param_inflation * 5, n_ensemble)
    override_parameters["TSUM1"] = np.random.normal(695, ens_param_inflation * 50, n_ensemble)
    override_parameters["TSUM2"] = np.random.normal(800, ens_param_inflation * 50, n_ensemble)
    override_parameters["CVL"] = np.random.normal(0.68, ens_param_inflation * 0.2, n_ensemble)
    override_parameters["CVO"] = np.random.normal(0.67, ens_param_inflation * 0.1, n_ensemble)
    override_parameters["CVR"] = np.random.normal(0.69, ens_param_inflation * 0.1, n_ensemble)
    override_parameters["SMW"] = np.random.normal(0.3, ens_param_inflation * 0.03, n_ensemble)
    override_parameters["SMFCF"] = np.random.normal(0.46, ens_param_inflation * 0.04, n_ensemble)
    override_parameters["SM0"] = np.random.normal(0.57, ens_param_inflation * 0.057, n_ensemble)

    assim_vars = []
    if assim_lai:
        assim_vars.append("LAI")
    if assim_sm:
        assim_vars.append("SM")

    enkf = WOFOSTEnKF(assim_vars, override_parameters, n_ensemble, observations)
    if meteo is None:
        enkf.setup_wofost()
    else:
        enkf.setup_wofost(meteo=meteo)
    results = enkf.run_filter()

    # The datetime column is added in place, so it is part of the returned frames.
    for df_results in results:
        df_results["date"] = pd.to_datetime(df_results.index)
    df_mean = pd.concat(results).groupby(level=0).mean()
    df_mean["date"] = pd.to_datetime(df_mean.index)

    _plot_enkf_run(results, df_mean, observations, start_date, end_date, fname_out_str)
    return results, df_mean


def _plot_enkf_run(results, df_mean, observations, start_date, end_date, fname_out):
    """Draw members (grey), ensemble mean (red) and observations; save ``<fname_out>.png``."""
    fig, axs = plt.subplots(nrows=5, ncols=2, sharex=True, squeeze=True,
                            figsize=(18, 19))
    axs = axs.flatten()
    try:
        for ax, p in zip(axs, WOFOST_PARAMETERS):
            for df_results in results:
                ax.plot(df_results.date, df_results[p], '-', c="0.8")
            ax.set_ylabel(WOFOST_LABELS[p], fontsize=12)
            ax.plot(df_mean['date'], df_mean[p], '-', color='red', linewidth=2,
                    label='EnKF mean')
            ax.legend(fontsize=12)

        # LAI and SM observations go to panels 1 and 9.
        obs_panels = {"LAI": (axs[1], "#8DA0CB"), "SM": (axs[9], "#A6D854")}
        for obs_date, obs in observations:
            for var, (ax, colour) in obs_panels.items():
                if var in obs:
                    ax.errorbar(obs_date, obs[var][0], yerr=obs[var][1], c=colour)
                    ax.plot(obs_date, obs[var][0], 'o', c=colour)

        fig.autofmt_xdate()
        axs[9].fmt_xdata = mdates.DateFormatter('%Y-%m-%d')
        axs[9].set_xlim(start_date, end_date)
        axs[8].set_xlabel("Time [d]")
        axs[9].set_xlabel("Time [d]")
        Path(fname_out).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(f"{fname_out}.png", dpi=300, bbox_inches="tight")
    finally:
        # Close even on failure so figures do not pile up over repeated calls.
        plt.close(fig)


In [11]:
# Output variables plotted by run_ensemble, in panel order, with axis labels.
WOFOST_LABELS = {
    "DVS": "Development stage [-]",
    "LAI": "LAI [m2/m2]",
    "TAGP": "Total Biomass [kg/ha]",
    "TWSO": "Total Storage Organ Weight [kg/ha]",
    "TWLV": "Total Leaves Weight [kg/ha]",
    "TWST": "Total Stems Weight [kg/ha]",
    "TWRT": "Total Root Weight [kg/ha]",
    "TRA": "Transpiration rate [cm/d]",
    "RD": "Rooting depth [cm]",
    "SM": "Soil moisture [cm3/cm3]",
}
WOFOST_PARAMETERS = list(WOFOST_LABELS)
LABELS = list(WOFOST_LABELS.values())

In [12]:
# Run the EnKF for every parcel; keep the members, ensemble mean and
# observations per site.
all_results = {}

for coord, lai_obs in lai_data.items():
    # Sort on the date only: with two observations on the same date, sorting
    # whole (date, dict) tuples would compare the dicts and raise TypeError.
    # Same-date observations are kept and both get assimilated.
    observations = sorted(
        ((date, {"LAI": (lai_mean, lai_std)}) for date, lai_mean, lai_std in lai_obs),
        key=lambda item: item[0],
    )
    site_name = f"site_{coord[1]:.4f}_{coord[0]:.4f}"

    # NOTE(review): this parcel's own weather file (see site_name_map) is not
    # passed on, so each run presumably uses the default `meteo` of
    # WOFOSTEnKF.setup_wofost for every parcel -- confirm this is intended.
    results, df_mean = run_ensemble(
        n_ensemble=30,
        obs_period=obs_period,
        assim_lai=True,
        assim_sm=False,
        observations=observations,
        site_name=site_name,
    )

    all_results[site_name] = {
        "results": results,
        "mean": df_mean,
        "observations": observations,
    }

100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.63it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.57it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.58it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.49it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.45it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:07<00:00,  1.35it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:07<00:00,  1.30it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:07<00:00,  1.31it/s]
100%|███████████████████████████████████

In [13]:
import pickle

# Persist all per-site results. Only unpickle this file from a trusted source.
out_path = Path("./outputs/ensemble/maize_25.pkl")
# Create the output folder so the dump does not fail on a fresh checkout.
out_path.parent.mkdir(parents=True, exist_ok=True)
with open(out_path, "wb") as f:
    pickle.dump(all_results, f)