In [None]:
%load_ext autoreload
%autoreload 3

from general import *
from call import *

from site_settings_coniferous_1 import *

from dataclasses import dataclass
import numpy as np
import pandas as pd
import torch

In [None]:
## PRIOR ##
file_prior = '../parameters/parameters_BC_CONIFEROUS-1.txt'

## DATA FILES ##
sitedata_ancillary_file = '../data/data_ancillary_CONIFEROUS-1.txt'
sitedata_daily_file = '../data/data_daily_CONIFEROUS-1.txt'
sitedata_EC_file = '../data/data_EC_CONIFEROUS-1.txt'

## COLUMN NUMBERS FOR SPECIFIC OUTPUTS ##
# Positions of the calibrated variables inside `outputNames`
# (i.e. columns of the model output matrix).
iWCoutput = outputNames.index("WC")
iNEEoutput = outputNames.index("NEE_gCm2d")
iGPPoutput = outputNames.index("GPP_gCm2d")
iEToutput = outputNames.index("ET_mmd")

# All column indices below are zero-based positions in the data tables.
# Daily data: block length (rows), NaN-count threshold per block,
# and coefficient of variation used to derive the observation sd.
int_daily, nnamax_daily, cv_daily = 30, 1, 0.3
iWCdata = 3
# EC data: same three settings, followed by the data columns.
int_EC, nnamax_EC, cv_EC = 30, 1, 0.3
iNEEdata, iGPPdata, iETdata = 3, 4, 5

In [None]:
# Read the three whitespace-delimited site tables. The interval slicing and
# the mean/sd computation happen further down in this cell.

# NOTE: the ancillary table is read again later in this cell with
# index_col=0; this first copy keeps a plain integer column layout.
dataset_ancillary = pd.read_csv(
    sitedata_ancillary_file,
    sep=r"\s+",
    header=None,
)
# The daily table has no header row; its first column is the text label
# "WC", which is dropped straight away.
dataset_daily = pd.read_csv(
    sitedata_daily_file,
    sep=r"\s+",
    header=None,
).iloc[:, 1:]
# The EC table carries a header row.
dataset_EC = pd.read_csv(
    sitedata_EC_file,
    sep=r"\s+",
    header=0,
)

@dataclass
class Intervals:
    """Observations of one variable, aggregated per interval.

    Each entry describes one interval by its start date (`year`, `doy`);
    `match_output_means` computes the matching means from a model output
    matrix so that they can be compared with `mean` / `sd`.
    """

    # Observed mean and standard deviation, one element per interval.
    mean: torch.Tensor
    sd: torch.Tensor
    # Year and day-of-year at which each interval starts.
    year: torch.Tensor
    doy: torch.Tensor
    # Time coordinate stored alongside each interval (not used by this class).
    time: torch.Tensor
    # Column of the model output matrix that holds the simulated variable.
    output_idx: int
    # Number of consecutive output rows averaged per interval.
    interval: int

    def match_output_means(self, output: torch.Tensor) -> torch.Tensor:
        """Return the simulated mean for every stored interval.

        Column 1 of `output` is read as the year, column 2 as the
        day-of-year and column `self.output_idx` as the simulated values.
        For each (year, doy) pair the first matching row is located and the
        NaN-ignoring mean over the next `self.interval` rows is taken.

        Raises:
            ValueError: if a (year, doy) pair does not occur in `output`.
        """
        years_col = output[:, 1]
        doys_col = output[:, 2]
        values_col = output[:, self.output_idx]

        window_means = []
        for y, d in zip(self.year, self.doy):
            hits = torch.where((years_col == y) & (doys_col == d))[0]
            if hits.numel() == 0:
                raise ValueError(f"No match found for year={y}, doy={d}")
            # First matching row, as a plain Python int usable as slice bound.
            start = hits[0].item()
            window_means.append(torch.nanmean(values_col[start : start + self.interval]))

        return torch.stack(window_means)
    
def intervals(dataset, interval, cv, idx, output_idx=None, nnamax=None) -> Intervals:
    """Aggregate one data column into fixed-length block means.

    The rows of `dataset` are split into consecutive blocks of `interval`
    rows. Blocks whose NaN count is not below `nnamax` are discarded; for
    the remaining blocks the NaN-ignoring mean is computed and an
    observation sd of max(cv * |mean|, 1.0) is assigned. Year and
    day-of-year are taken from columns 0 and 1 of the first row of each
    block.

    Args:
        dataset: DataFrame with year in column 0, day-of-year in column 1.
        interval: Block length in rows.
        cv: Coefficient of variation used for the observation sd.
        idx: Zero-based position of the data column in `dataset`.
        output_idx: Column of the model output matrix that corresponds to
            this variable. Defaults to `idx`, which reproduces the previous
            behaviour, but the data-file column and the model-output column
            are generally different, so callers should pass it explicitly.
        nnamax: A block is kept only if it holds fewer than `nnamax` NaNs.
            Defaults to the global `nnamax_daily` (previous behaviour).

    Returns:
        An `Intervals` instance holding the block statistics as tensors.
    """
    if output_idx is None:
        output_idx = idx
    if nnamax is None:
        nnamax = nnamax_daily

    data = dataset.iloc[:, idx].to_numpy()

    # Start row of every block; the last block may be shorter than `interval`.
    block_indices = np.arange(0, len(dataset), interval)
    nna = np.array([np.isnan(data[i : i + interval]).sum() for i in block_indices])
    block_indices = block_indices[nna < nnamax]

    data_mean = np.array([np.nanmean(data[i : i + interval]) for i in block_indices])
    data_year = dataset.iloc[block_indices, 0].to_numpy()
    data_doy = dataset.iloc[block_indices, 1].to_numpy()
    data_time = data_year + (data_doy - 0.5) / 366

    # Observation sd: cv-scaled mean, floored at 1.0.
    sdmin = cv * np.abs(data_mean)
    data_sd = np.maximum(sdmin, 1.0)

    dtype = torch.float64
    data_mean = torch.tensor(data_mean, dtype=dtype)
    data_sd = torch.tensor(data_sd, dtype=dtype)
    data_year = torch.tensor(data_year, dtype=torch.int32)
    data_doy = torch.tensor(data_doy, dtype=torch.int32)
    data_time = torch.tensor(data_time, dtype=dtype)

    return Intervals(data_mean, data_sd, data_year, data_doy, data_time, output_idx, interval)

# Pass the model-output column and the matching NaN threshold explicitly:
# the data-file column (i*data) is not the model-output column (i*output).
data_WC = intervals(dataset_daily, int_daily, cv_daily, iWCdata, output_idx=iWCoutput, nnamax=nnamax_daily)
data_NEE = intervals(dataset_EC, int_EC, cv_EC, iNEEdata, output_idx=iNEEoutput, nnamax=nnamax_EC)
data_GPP = intervals(dataset_EC, int_EC, cv_EC, iGPPdata, output_idx=iGPPoutput, nnamax=nnamax_EC)
data_ET = intervals(dataset_EC, int_EC, cv_EC, iETdata, output_idx=iEToutput, nnamax=nnamax_EC)

# Re-read the ancillary table, this time with its first column (the variable
# name) as index so that each row can be mapped onto a model output column.
dataset_ancillary = pd.read_csv(sitedata_ancillary_file, header=None, index_col=0, sep=r"\s+")
ancillary = []

# Every ancillary row is a single observation: (year, doy, value, sd).
for name, row in dataset_ancillary.iterrows():
    year, doy, value, sd = row
    output_idx = outputNames.index(name)

    value = torch.tensor([value], dtype=torch.float64)
    sd = torch.tensor([sd], dtype=torch.float64)
    year = torch.tensor([year], dtype=torch.int32)
    doy = torch.tensor([doy], dtype=torch.int32)
    # Compute the time coordinate in float64. Mixing the int32 tensors with
    # Python floats would promote to float32, which is inconsistent with the
    # float64 `time` produced by intervals().
    time = year.to(torch.float64) + (doy.to(torch.float64) - 0.5) / 366

    # Single observation -> interval length of one output row.
    interval = Intervals(mean=value, sd=sd, year=year, doy=doy, time=time, output_idx=output_idx, interval=1)

    ancillary.append(interval)

all_intervals = [data_WC, data_NEE, data_GPP, data_ET] + ancillary

In [None]:
# Prior table: parameter name as index, then min / mode / max / sites columns.
df_params_BC = pd.read_csv(file_prior, sep=r"\s+", header=None, index_col=0)
parname_BC = df_params_BC.index
parmin_BC, parmod_BC, parmax_BC = (
    df_params_BC.iloc[:, col].to_numpy() for col in range(3)
)
parsites_BC = df_params_BC.iloc[:, 3].to_numpy().astype(str)

# Position of every calibrated parameter inside the full parameter table.
ip_BC = list(map(df_params.index.get_loc, parname_BC))

# Overwrite the calibrated entries of `params` with their prior mode.
params[ip_BC] = parmod_BC

In [None]:
# Wrap the BASFOR model call as a PyTorch custom operator ("mylib::adbasfor",
# CPU only, declared as not mutating its argument) so that it can take part
# in autograd via the backward/setup_context functions registered below.
@torch.library.custom_op("mylib::adbasfor", mutates_args=[], device_types="cpu")
def adbasfor(params: torch.Tensor) -> torch.Tensor:
    """Run the model for `params` and return its output as a tensor.

    Allocates a zero-filled (NDAYS, NOUT) float64 buffer, hands it together
    with the parameters and the weather/calendar inputs to
    `submit(call_BASFOR_C, ...)`, and wraps the returned array with
    `torch.from_numpy`.
    """
    assert params.device.type == "cpu"
    # NumPy view of the parameter tensor, passed on to the model call.
    params_np = params.numpy()
    y = np.zeros((NDAYS, NOUT), dtype=np.float64)
    # NOTE(review): `submit` is defined outside this notebook; presumably it
    # runs the call and returns the filled output array - confirm.
    output = submit(call_BASFOR_C, params_np, matrix_weather, calendar_fert, calendar_Ndep, calendar_prunT, calendar_thinT, NDAYS, NOUT, y)
    return torch.from_numpy(output)

# The string literal below is a disabled fake-tensor registration; it is an
# expression statement with no effect and is kept for reference only.
"""
@adbasfor.register_fake
def _(params: torch.Tensor) -> torch.Tensor:
    return params.new_empty((NDAYS, NOUT))
"""

def backward(ctx, grad_output):
    """Autograd backward for `adbasfor`.

    Passes the saved input parameters and `grad_output` to
    `submit(call_dBASFOR_C, ...)` and returns the first element of the
    result, truncated to the number of input parameters, as a tensor.
    """
    input, = ctx.saved_tensors
    
    # NOTE(review): unlike the forward pass, torch tensors (not NumPy arrays)
    # are handed to `submit` here - confirm that the helper accepts them.
    grad_input = submit(call_dBASFOR_C, input.detach(), matrix_weather, calendar_fert, calendar_Ndep, calendar_prunT, calendar_thinT, NDAYS, NOUT, grad_output)

    # NOTE(review): presumably grad_input[0] is the gradient w.r.t. the
    # parameters and may be longer than the parameter vector - confirm.
    return torch.from_numpy(grad_input[0][:input.size(0)])

def setup_context(ctx, inputs, output):
    """Save the op's single input (the parameter tensor) for `backward`."""
    # FIXME (left by the original author; the open issue is not specified)
    ctx.save_for_backward(inputs[0])

# Attach the backward formula and its context setup to the custom op.
adbasfor.register_autograd(backward, setup_context=setup_context)

In [None]:
half_log_2pi = torch.tensor(0.5 * np.log(2 * np.pi))

def flogL(sims: torch.Tensor, data: torch.Tensor, data_s: torch.Tensor):
    """Summed log-likelihood of simulations given data.

    With the standardised residual R = (sims - data) / data_s, each element
    contributes

        log(1 - exp(-R**2 / 2)) - log(R**2) - log(data_s) - 0.5 * log(2 * pi)

    and, for |R| < 1e-8, the R -> 0 limit

        -log(2 * data_s) - 0.5 * log(2 * pi).

    Args:
        sims: Simulated values.
        data: Observed values, broadcastable against `sims`.
        data_s: Observation standard deviations, broadcastable against `sims`.

    Returns:
        Scalar tensor holding the sum of the element-wise log-likelihoods.
    """
    Ri = (sims - data) / data_s
    i0 = torch.abs(Ri) < 1.e-8

    # Swap near-zero residuals for a harmless placeholder *before* the general
    # formula is evaluated. Otherwise log(0) - log(0) produces NaN there and,
    # even though the value is replaced afterwards, the backward pass
    # multiplies a zero upstream gradient by inf and yields NaN gradients.
    Ri_safe = torch.where(i0, torch.ones_like(Ri), Ri)
    Ri2 = Ri_safe**2

    # -expm1(-x) equals 1 - exp(-x) but does not cancel catastrophically for
    # small x; the naive form returns 0 (log -> -inf) just above the threshold.
    logLi_general = torch.log(-torch.expm1(-0.5 * Ri2)) - torch.log(Ri2) - torch.log(data_s)
    logLi_limit = -torch.log(2 * data_s)

    logLi = torch.where(i0, logLi_limit, logLi_general) - half_log_2pi
    return torch.sum(logLi)

def loss(output):
    """Negative total log-likelihood of a model output matrix.

    For every entry of `all_intervals` the simulated interval means are
    extracted from `output` and scored against the observed mean and sd with
    `flogL`; the negated sum of those scores is returned.
    """
    per_variable = [
        flogL(obs.match_output_means(output), obs.mean, obs.sd)
        for obs in all_intervals
    ]
    total_logL = torch.stack(per_variable).sum()
    return -total_logL

In [None]:
# Optimise one multiplicative factor per parameter, starting from 1
# (i.e. from the current values in `params`).
factors = torch.ones(len(params), dtype=torch.float64, requires_grad=True)
optim = torch.optim.Adam([factors], lr=1e-1)

# `params` is not modified inside the loop, so convert it to a tensor once
# instead of rebuilding it on every iteration.
params_base = torch.tensor(params)

for i in range(20):
    optim.zero_grad()
    params_tensor = factors * params_base
    output = adbasfor(params_tensor)
    loss_value = loss(output)
    loss_value.backward()
    optim.step()
    # Keep the factors inside [0.2, 5] after every update; done under
    # no_grad so the in-place clamp is not recorded by autograd.
    with torch.no_grad():
        factors.clamp_(0.2, 5)
    print(loss_value.item())