In [3]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pandas import Timestamp
from daesim.utils import ODEModelSolver

from daesim.climate import *
from daesim.plantgrowthphases import PlantGrowthPhases
from daesim.management import ManagementModule
from daesim.plant_1000_thermaltime import PlantModuleCalculator

from daesim2_analysis.parameters import Parameters
from daesim2_analysis.forcing_data import ForcingData
from daesim2_analysis.utils import load_df_forcing

In [4]:
from pandas import ExcelFile
import geopandas as gpd
from pandas import read_csv

In [5]:
xlsx = ExcelFile('/g/data/xe2/ya6227/NVTAnalysis/data/NVT.xlsx')
df = xlsx.parse('Cereals')
df_wheat = df[df['Crop.Name'] == 'Wheat']
df_rockstar = df_wheat[df_wheat['VarietyDisplayName'] == 'RockStar']

In [7]:
df_rockstar['Single Site Yield']

11344    5.726809
11376    2.461329
11407    3.848269
11445    4.259623
11492    2.631956
           ...   
65067    3.075590
65096    4.875894
65120    4.217914
65151    3.201857
65175    3.651493
Name: Single Site Yield, Length: 1053, dtype: float64

In [8]:
wheat_rockstar_11344 = df.loc[11344]

In [9]:
parameters = Parameters.__from_file__("/g/data/xe2/ya6227/daesim2-analysis/parameters/Fast1.json")

In [10]:
path_df_forcing = '/g/data/xe2/ya6227/NVTAnalysis/data/DAESim/Wheat-Rockstar-11344/environmental/Wheat-Rockstar-11344_DAESim_forcing.csv'
CLatDeg = wheat_rockstar_11344['Trial GPS Lat']
CLonDeg = wheat_rockstar_11344['Trial GPS Long']
df_forcing = load_df_forcing(path_df_forcing)

In [11]:
SiteX = ClimateModule(CLatDeg=CLatDeg, CLonDeg=CLonDeg, timezone=10)
ForcingDataX = ForcingData(
    SiteX=SiteX,
    sowing_dates=[Timestamp(wheat_rockstar_11344['SowingDate'])],
    harvest_dates=[Timestamp(wheat_rockstar_11344['HarvestDate'])],
    df=df_forcing,
    df_type='0'
)

In [12]:
ManagementX = ManagementModule(cropType="Wheat", sowingDays=ForcingDataX.sowing_days, harvestDays=ForcingDataX.harvest_days, sowingYears=ForcingDataX.sowing_years, harvestYears=ForcingDataX.harvest_years)
PlantDevX = PlantGrowthPhases(
    phases=["germination", "vegetative", "anthesis", "grainfill", "maturity"],
    gdd_requirements=[120, 500, 200, 350, 200],
    vd_requirements=[0, 25, 0, 0, 0],
    allocation_coeffs=[
        [0.2, 0.1, 0.7, 0.0, 0.0],   # Phase 1
        [0.5, 0.1, 0.4, 0.0, 0.0],   # Phase 2
        [0.25, 0.5, 0.25, 0.0, 0.0], # Phase 3
        [0.1, 0.1, 0.1, 0.7, 0.0],   # Phase 4
        [0.1, 0.1, 0.1, 0.7, 0.0]    # Phase 5
    ],
    turnover_rates = [
        [0.001, 0.001, 0.001, 0.0, 0.0],  # Phase 1
        [0.01,  0.002, 0.01,  0.0, 0.0],  # Phase 2
        [0.02,  0.002, 0.04,  0.0, 0.0],  # Phase 3
        [0.10,  0.008, 0.10,  0.0, 0.0],  # Phase 4
        [0.50,  0.017, 0.50,  0.0, 0.0]   # Phase 5
    ]    ## Turnover rates per pool and developmental phase (days-1))
)

PlantX = PlantModuleCalculator(
    Site=SiteX,
    Management=ManagementX,
    PlantDev=PlantDevX,
    GDD_method="linear1",
    GDD_Tbase=0.0,
    GDD_Tupp=25.0,
)

In [13]:
## Define the callable calculator that defines the right-hand-side ODE function
PlantXCalc = PlantX.calculate

Model = ODEModelSolver(calculator=PlantXCalc, states_init=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], time_start=ForcingDataX.time_axis[0], log_diagnostics=True)

In [14]:
## Observations

## Example 1: Milgadara, Rocky East, Canola TT Y44, 2021
observables_names = ["start of vegetative", "start of flowering", "start of grainfill", "harvest"]
observables_units = ["ordinal day of year", "ordinal day of year", "ordinal day of year", "ordinal day of year"]
observables_values = [143, 234, 262, 339]
observables_uncertainty = [5, 5, 5, 5]

## Example 2: Milgadara, Horse Paddock Big, Canola, 2019
observables_names = ["start of vegetative", "start of flowering", "start of grainfill", "harvest"]
observables_units = ["ordinal day of year", "ordinal day of year", "ordinal day of year", "ordinal day of year"]
observables_values = [167, 258, 265, 307]    # 16-Jun, 3-Nov, 15-Sep, 22-Sep (2021)
observables_uncertainty = [5, 5, 5, 5]

## Put these into a combined pandas dataframe
target_df = pd.DataFrame({
    "Name": observables_names,
    "Units": observables_units,
    "Values": observables_values,
    "Uncertainty": observables_uncertainty,
})

# Put the observations vectors into arrays
y = target_df["Values"].values
U_y = target_df["Uncertainty"].values

In [15]:
from scipy.optimize import differential_evolution

from daesim2_analysis.run import update_attribute, update_attribute_in_phase

In [16]:
def run_model_and_get_outputs(Plant, ODEModelSolver, time_axis, forcing_inputs, reset_days, zero_crossing_indices):
    ## Define the callable calculator that defines the right-hand-side ODE function
    PlantCalc = Plant.calculate
    
    Model = ODEModelSolver(calculator=PlantCalc, states_init=[0.0, 0.0], time_start=time_axis[0], log_diagnostics=True)
    
    ## Run the model solver
    res = Model.run(
        time_axis=time_axis,
        forcing_inputs=forcing_inputs,
        solver="euler",
        zero_crossing_indices=zero_crossing_indices,
        reset_days=reset_days,
    )

    # Convert the defaultdict to a regular dictionary
    _diagnostics = dict(Model.diagnostics)
    # Convert each list in the dictionary to a NumPy array
    diagnostics = {key: np.array(value) for key, value in _diagnostics.items()}

    # Convert the array to a numeric type, handling mixed int and float types
    diagnostics['idevphase_numeric'] = np.array(diagnostics['idevphase'],dtype=np.float64)
    
    # In the model idevphase can equal None but that is not useable in post-processing, so we set None values to np.nan
    diagnostics["idevphase_numeric"][diagnostics["idevphase"] == None] = np.nan

    # Add np.nan to the end of each array in the dictionary to represent the last time point in the time_axis (corresponds to the last time point of the state vector)
    for key in diagnostics:
        if key == "t":
            diagnostics[key] = np.append(diagnostics[key], res["t"][-1])
        else:
            diagnostics[key] = np.append(diagnostics[key], np.nan)

    # Add state variables to the diagnostics dictionary
    diagnostics["GDD"] = res["y"][0,:]
    diagnostics["VD"] = res["y"][1,:]

    # Add forcing inputs to diagnostics dictionary
    for i,f in enumerate(forcing_inputs):
        ni = i+1
        if f(time_axis[0]).size == 1:
            fstr = f"forcing {ni:02}"
            diagnostics[fstr] = f(time_axis)
        elif f(time_axis[0]).size > 1:
            # this forcing input has levels/layers (e.g. multilayer soil moisture)
            nz = f(time_axis[0]).size
            for iz in range(nz):
                fstr = f"forcing {ni:02} z{iz}"
                diagnostics[fstr] = f(time_axis)[:,iz]
    
    # Observation Operator
    # Calculate model-equivalent observations from model run output

    # # Diagnose time indexes when developmental phase transitions occur
    ngrowing_seasons = (len(Plant.Management.sowingDays) if (isinstance(Plant.Management.sowingDays, int) == False) else 1)
    if ngrowing_seasons > 1:
        # print("Multiple sowing and harvest events occur. Only returning results for first growing season.")
        ## ignore any time steps before first sowing event and after last harvest event
        it_sowing = np.where(time_axis == reset_days[0])[0][0]  #sowing_steps_itax[0]
        
        if Plant.Management.harvestDays is not None:
            it_harvest = np.where(time_axis == reset_days[1])[0][0]  #harvest_steps_itax[0]   # np.where(np.floor(Climate_doy_f(time_axis)) == Plant.Management.harvestDay)[0][0]
        else:
            it_harvest = -1   # if there is no harvest day specified, we just take the last day of the simulation. 
    else:
        # print("Just one sowing event and one harvest event occurs. Returning results for first (and only) growing season.")
        ## ignore any time steps before first sowing event and after last harvest event
        it_sowing = np.where(time_axis == reset_days[0])[0][0]  #sowing_steps_itax[0]
        
        if Plant.Management.harvestDays is not None:
            it_harvest = np.where(time_axis == reset_days[1])[0][0]  #harvest_steps_itax[0]   # np.where(np.floor(Climate_doy_f(time_axis)) == Plant.Management.harvestDay)[0][0]
        else:
            it_harvest = -1   # if there is no harvest day specified, we just take the last day of the simulation. 

    # Diagnose time indexes when developmental phase transitions occur

    # Convert the array to a numeric type, handling mixed int and float types
    idevphase = diagnostics["idevphase_numeric"]   #[it_sowing:it_harvest+1]
    valid_mask = ~np.isnan(idevphase)
    
    # Identify all transitions (number-to-NaN, NaN-to-number, or number-to-different-number)
    it_phase_transitions = np.where(
        ~valid_mask[:-1] & valid_mask[1:] |  # NaN-to-number
        valid_mask[:-1] & ~valid_mask[1:] |  # Number-to-NaN
        (valid_mask[:-1] & valid_mask[1:] & (np.diff(idevphase) != 0))  # Number-to-different-number
    )[0] + 1
    
    # Time index for the end of the maturity phase
    if PlantX.PlantDev.phases.index('maturity') in idevphase:
        it_mature = np.where(idevphase == PlantX.PlantDev.phases.index('maturity'))[0][-1]    # Index for end of maturity phase
    elif PlantX.Management.harvestDays is not None: 
        it_mature = it_harvest    # Maturity developmental phase not completed, so take harvest as the end of growing season
    else:
        it_mature = -1    # if there is no harvest day specified, we just take the last day of the simulation. 

    # it_sowing = np.where(time_axis == Plant.Management.sowingDay)[0][0]
    # if Plant.Management.harvestDay is not None:
    #     it_harvest = np.where(time_axis == Plant.Management.harvestDay)[0][0]
    # else:
    #     it_harvest = -1   # if there is no harvest day specified, we just take the last day of the simulation. 

    # # Convert the array to a numeric type, handling mixed int and float types
    # idevphase = diagnostics["idevphase_numeric"]
    # valid_mask = ~np.isnan(idevphase)
    
    # # Identify all transitions (number-to-NaN, NaN-to-number, or number-to-different-number)
    # it_phase_transitions = np.where(
    #     ~valid_mask[:-1] & valid_mask[1:] |  # NaN-to-number
    #     valid_mask[:-1] & ~valid_mask[1:] |  # Number-to-NaN
    #     (valid_mask[:-1] & valid_mask[1:] & (np.diff(idevphase) != 0))  # Number-to-different-number
    # )[0] + 1
    
    # # Time index for the end of the maturity phase
    # if Plant.PlantDev.phases.index('maturity') in idevphase:
    #     it_mature = np.where(idevphase == Plant.PlantDev.phases.index('maturity'))[0][-1]    # Index for end of maturity phase
    # elif Plant.Management.harvestDay is not None: 
    #     it_mature = it_harvest    # Maturity developmental phase not completed, so take harvest as the end of growing season
    # else:
    #     it_mature = -1    # if there is no harvest day specified, we just take the last day of the simulation. 

    # import pdb; pdb.set_trace()
    # Filter out transitions that occur on or before the sowing day
    # it_phase_transitions = [t for t in it_phase_transitions if time_axis[t] > time_axis[it_sowing+1]]
    it_phase_transitions = [t for t in it_phase_transitions if t > int(it_sowing+1)]
    # Filter out transitions that occur after the maturity or harvest day
    # it_phase_transitions = [t for t in it_phase_transitions if time_axis[t] <= time_axis[it_mature]]
    it_phase_transitions = [t for t in it_phase_transitions if t <= it_mature]

    # Developmental phase indexes
    igermination = Plant.PlantDev.phases.index("germination")
    ivegetative = Plant.PlantDev.phases.index("vegetative")
    if Plant.Management.cropType == "Wheat":
        ispike = Plant.PlantDev.phases.index("spike")
    ianthesis = Plant.PlantDev.phases.index("anthesis")
    igrainfill = Plant.PlantDev.phases.index("grainfill")
    imaturity = Plant.PlantDev.phases.index("maturity")

    ip = np.where(diagnostics['idevphase'][it_phase_transitions] == Plant.PlantDev.phases.index('vegetative'))[0][0]
    tdoy_vegetative = time_axis[it_phase_transitions[ip]]   # ordinal day-of-year at transition point into vegetative phase
    if Plant.PlantDev.phases.index('anthesis') in idevphase[it_sowing+1:it_harvest+1]:
        ip = np.where(diagnostics['idevphase'][it_phase_transitions] == Plant.PlantDev.phases.index('anthesis'))[0][0]
        tdoy_anth0 = time_axis[it_phase_transitions[ip]]   # ordinal day-of-year at transition point into anthesis phase
    else:
        tdoy_anth0 = time_axis[it_harvest]
    if Plant.PlantDev.phases.index('grainfill') in idevphase[it_sowing+1:it_harvest+1]:
        ip = np.where(diagnostics['idevphase'][it_phase_transitions] == Plant.PlantDev.phases.index('grainfill'))[0][0]
        tdoy_anth1 = time_axis[it_phase_transitions[ip]]   # ordinal day-of-year at transition point into grainfill stage (out of anthesis phase)
    else:
        tdoy_anth1 = time_axis[it_harvest]
    tdoy_harvest = time_axis[it_harvest]   # ordinal day-of-year at harvest
    
    # import pdb; pdb.set_trace()
    # Model output (of observables) given the parameter vector p
    # - this is the model output that we compare to observations and use to calibrate the parameters
    M_p = np.array([
        tdoy_vegetative, 
        tdoy_anth0, 
        tdoy_anth1,
        tdoy_harvest,
    ])

    return M_p

def model_function(params, model_instance, input_data, param_info):
    # Update the model class with the new parameters
    # Plant.PlantDev.gdd_requirements = [params[ip_GDD_germ],params[ip_GDD_veg],params[ip_GDD_ant],params[ip_GDD_gf],params[ip_GDD_mat]]
    # Plant.Management.sowingDay = params[ip_sowingDay]
    # Plant.Management.harvestDay = params[ip_harvestDay]

    # Collate input data to pass to model run function
    ODEModelSolver, time_axis, forcing_inputs, reset_days, zero_crossing_indices, time_nday_f, time_doy_f, time_year_f = input_data
    
    for idx, value in enumerate(params):
        param_name = param_info["Name"].values[idx]
        param_path = param_info["Module Path"].values[idx]
        full_path = f"{param_path}.{param_name}"
        phase_specific = param_info["Phase Specific"].values[idx]
        
        if phase_specific:
            # Handle phase-specific parameters
            phase = param_info["Phase"].values[idx]
            update_attribute_in_phase(model_instance, full_path, value, phase)
        else:
            if (param_name == "sowingDays") or (param_name == "harvestDays"):
                # Update parameters that must be defined as a list type
                update_attribute(model_instance, full_path, [value])
            else:
                # Update regular parameters
                update_attribute(model_instance, full_path, value)

        # Make sure the solver knows about the sowing and harvest dates as well (to reset the state variables like GDD and VD)
        if (param_name == "sowingDays") or (param_name == "harvestDays"):
            # Find value of time_nday_f where time_doy_f == sowingDay and time_year_f == sowingYear.
            sowingDay, sowingYear = model_instance.Management.sowingDays, model_instance.Management.sowingYears
            sowing_nday = time_nday_f[(np.floor(time_doy_f) == sowingDay) & (np.array(time_year_f) == sowingYear)]
            
            # Find value of time_nday_f where time_doy_f == sowingDay and time_year_f == sowingYear.
            harvestDay, harvestYear = model_instance.Management.harvestDays, model_instance.Management.harvestYears
            harvest_nday = time_nday_f[(np.floor(time_doy_f) == harvestDay) & (np.array(time_year_f) == harvestYear)]
            
            # Set reset_days to be the updated sowing and harvest nday
            reset_days = [sowing_nday[0], harvest_nday[0]]
     
    model_output = run_model_and_get_outputs(model_instance, ODEModelSolver, time_axis, forcing_inputs, reset_days, zero_crossing_indices)

    return model_output

# Define an objective function to minimize
def objective_function_mse(params, observations, Plant, input_data, param_info):
    """
    Objective function using mean squared error (MSE). 

    Notes
    -----
    """
    # Round the parameters off to integers as these parameters must be integers representing day-of-year
    int_params = np.round(params).astype(int)
    # Calculate model outputs
    model_outputs = model_function(int_params, Plant, input_data, param_info)
    # Calculate the error (e.g., mean squared error) TODO: Include model-obs uncertainties here too
    error = np.mean((model_outputs - observations) ** 2)
    return error

# Define an objective function to minimize
def objective_function_wls(params, observations, observation_unc_sigma, Plant, input_data, param_info):
    """
    Objective function using weighted least squares (WLS). 

    Notes
    -----
    The cost function, $J$, is defined using a weighted least-squares as follows: 
    
    $J = \sum_{i=1}^n \frac{(M_i(p) - y_i)^2}{\sigma^2}$
    
    Where $y$ is the vector of observations, $M$ is the vector model predicted observables 
    given parameter set $p$, and $\sigma$ is the observation errors (assumed to include 
    structural model errors). Note that this formulation ignores the priors. 
    """
    # Round the parameters off to integers as these parameters must be integers representing day-of-year
    int_params = np.round(params).astype(int)
    # Calculate model outputs
    model_outputs = model_function(int_params, Plant, input_data, param_info)
    # Calculate the error as the weighted Least Squares: 
    # 
    # Error is the model - observed difference squared, normalised by the uncertainty (as a variance), and summed over all obs
    error = np.mean( ((model_outputs - observations) ** 2) / (observation_unc_sigma**2))
    print(f"Current error: {error}")
    return error



In [17]:
input_data = [ODEModelSolver, ForcingDataX.time_axis, ForcingDataX.inputs, ForcingDataX.reset_days, ForcingDataX.zero_crossing_indices, ForcingDataX.time_nday_f, ForcingDataX.time_doy_f, ForcingDataX.time_year_f]

param_info = parameters.df
params = parameters.df["Initial Value"].values

model_function(params, PlantX, input_data, param_info)

AttributeError: 'PlantModuleCalculator' object has no attribute 'PlantCH2O'

In [18]:
PlantX

PlantModuleCalculator(Site=ClimateModule(CLatDeg=-34.30348, CLonDeg=135.72141, timezone=10, Elevation=70.74206543, degSlope=4.62, slopeLength=97.2, iniSoilDepth=0.09, met_z_meas=10.0, rainConv=0.001, T_K0=273.15, g=9.80665, L=2450, R_w_mol=8.31446, R_w_mass=0.4615, MW_ratio_H2O=0.622, rho_air=1.293, cp_air=1.013, StefanBoltzmannConstant=5.6704e-08, S0_Wm2=1370, S0_MJm2min=0.0822, airPressSeaLevel=101325, TempSeaLevel=288.15, tropos_lapse_rate=0.0065, M=0.0289644, D_H2O_T20=2.42e-05, D_CO2_T20=1.51e-05), Management=ManagementModule(cropType='Wheat', sowingDays=array([135.]), harvestDays=array([338.]), sowingYears=array([2018.]), harvestYears=array([2018.]), sowingRate=80, sowingDepth=0.03, propHarvestSeed=1.0, propHarvestLeaf=0.9, propHarvestStem=0.7, propPhHarvesting=0.3, propNPhHarvest=0.4, PhHarvestTurnoverTime=1, NPhHarvestTurnoverTime=1, propTillage=0.5, propHarvPhLeft=0.1, propHarvNPhLeft=0.8), PlantDev=PlantGrowthPhases(nCpools=5, ileaf=0, istem=1, iroot=2, iseed=3, iexud=4, phas

In [19]:
param_info

Unnamed: 0,Module Path,Module,Name,Unit,Initial Value,Min,Max,Phase Specific,Phase
0,PlantCH2O.CanopyGasExchange.Leaf,Leaf,Vcmax_opt,mol CO2 m-2 s-1,6e-05,3e-05,0.00012,False,
1,PlantCH2O.CanopyGasExchange.Leaf,Leaf,g1,kPa^0.5,3.0,1.0,6.0,False,
2,PlantCH2O,PlantCH2O,SLA,m2 g d.wt-1,0.03,0.015,0.035,False,
3,PlantCH2O,PlantCH2O,maxLAI,m2 m-2,6.0,5.0,7.0,False,
4,PlantCH2O,PlantCH2O,ksr_coeff,g d.wt-1 m-1,1000.0,300.0,5000.0,False,
5,PlantCH2O,PlantCH2O,Psi_f,MPa,-3.5,-8.0,-1.0,False,
6,PlantCH2O,PlantCH2O,sf,MPa-1,3.5,1.5,7.0,False,
7,PlantDev,PlantDev,gdd_requirements,deg C d,900.0,600.0,1800.0,True,vegetative
8,PlantDev,PlantDev,gdd_requirements,deg C d,650.0,350.0,700.0,True,grainfill
9,,,GY_FE,thsnd grains g d.wt spike-1,0.1,0.08,0.21,False,


In [20]:
CLatDeg = wheat_rockstar_11344['Trial GPS Lat']
CLonDeg = wheat_rockstar_11344['Trial GPS Long']

In [21]:
CLatDeg

-34.30348

In [22]:
CLonDeg

135.72141

In [23]:
Timestamp(wheat_rockstar_11344['SowingDate'])

Timestamp('2018-05-15 00:00:00')

In [25]:
wheat_rockstar_11344['HarvestDate']

Timestamp('2018-12-04 00:00:00')