In [None]:
import polars as pl
import altair as alt
import numpy as np
import jax.numpy as jnp

# Make "vegafusion" the active Altair data transformer for every chart in this notebook.
alt.data_transformers.enable("vegafusion")

# Temperature data


In [None]:
# Load the raw temperature table; try_parse_dates asks polars to parse date-like columns.
# Later cells read the `date`, `variable` and `temperature` columns of this frame.
temp_raw = pl.read_csv("../data/noaa.csv", try_parse_dates=True)
temp_raw

In [None]:
# One line panel per measurement variable.
# TAVG is only present in some parts of the dataset.
(
    alt.Chart(temp_raw)
    .mark_line()
    .encode(x="date", y="temperature")
    .facet("variable")
)

In [None]:
# First and last observation date for each measurement variable.
# data start in June 1941 (1942 season)
# group_by does not guarantee output row order, so sort to keep the displayed
# table identical from run to run.
(
    temp_raw.group_by("variable")
    .agg(start=pl.col("date").min(), end=pl.col("date").max())
    .sort("variable")
)

In [None]:
# Group temperature data into seasons.
# Dates in January-May belong to the season of their own calendar year; dates in
# June-December are assigned to the following year's season.
obs_date = pl.col("date")
season_shift = pl.when(obs_date.dt.month() <= 5).then(0).otherwise(1)

temp_with_season = temp_raw.with_columns(season=obs_date.dt.year() + season_shift)

# Data to be used for analysis.
# day_in_season is the (fractional) number of days since Jan 1 of the season's
# year, so it is negative for the June-December part of a season.
temp = (
    temp_with_season.with_columns(
        day_in_season=(obs_date - pl.date(pl.col("season"), 1, 1))
        / pl.duration(days=1)
    )
    .select("date", "season", "day_in_season", "variable", "temperature")
    .sort("date")
)

temp.sample(5)

# Schumacher method

- Get the cumulative daily minimum temperature (in degrees) since Feb 1, then compare it with the actual bloom date
- Compare with a variant that only accumulates minimum temperatures above zero (i.e., negative values are ignored)


In [None]:
# Observed bloom dates. The CSV's `year` column is renamed to `season` so it can be
# joined and concatenated with the season-keyed temperature frames below.
# NOTE(review): the name `bloom` is rebound to a function in the "Growth model"
# section further down; re-running the cells in between after that point will see
# the function instead of this frame.
bloom = pl.read_csv("../data/bloom_dates.csv", try_parse_dates=True).rename(
    {"year": "season"}
)

bloom

In [None]:
# Daily minimum temperatures between day 30 and day 120 of each season, in date order.
tmin_window = temp.filter(
    pl.col("variable") == pl.lit("TMIN"),
    pl.col("day_in_season").is_between(30, 120),
).sort("date")

# Temperature with every non-positive value replaced by zero.
positive_tmin = (
    pl.when(pl.col("temperature") > 0).then(pl.col("temperature")).otherwise(0.0)
)

# Long-format frame with two running per-season sums:
#   cdd    - all minimum temperatures
#   cdd_nz - only the positive ones
cdd = (
    tmin_window.with_columns(
        cdd=pl.col("temperature").cum_sum().over("season"),
        cdd_nz=positive_tmin.cum_sum().over("season"),
    )
    .select("season", "date", "day_in_season", "cdd", "cdd_nz")
    .unpivot(index=["season", "date", "day_in_season"])
)

# One line per season, one panel per accumulation variant (cdd / cdd_nz).
(
    alt.Chart(cdd)
    .mark_line()
    .encode(
        x="day_in_season",
        y=alt.Y("value", title="Cum. daily degrees"),
        color="season",
    )
    .facet("variable")
    .properties(title="Cumulative daily degrees")
)

In [None]:
# Keep only the cumulative-degree rows that fall on each season's observed bloom date.
bloom_cdd = bloom.join(cdd, how="inner", on=["season", "date"])

# Distribution of the cumulative value reached on the bloom date, per variant.
(
    alt.Chart(bloom_cdd)
    .mark_bar()
    .encode(
        x=alt.X("value", bin=True, title="CDDs"),
        y=alt.Y("count()", title="Number of seasons"),
    )
    .facet("variable")
    .properties(title="CDD as of bloom date")
)

In [None]:
def score_cmp(
    cdd_nz_star=170.0, start_day=30.0, temp=temp, bloom=bloom
) -> pl.DataFrame:
    """Stack observed and predicted bloom dates for every season.

    The predicted date of a season is the first date on which the running sum of
    positive daily minimum temperatures, accumulated from `start_day` onward,
    reaches `cdd_nz_star`. Seasons that never reach it get no "pred" row.

    Args:
      cdd_nz_star: threshold on the cumulative positive minimum temperature
      start_day: first `day_in_season` included in the accumulation
      temp: temperature frame (defaults to the module-level `temp` as it was
        when this function was defined)
      bloom: observed bloom dates (defaults to the module-level `bloom` as it
        was when this function was defined)

    Returns:
      frame with columns season, source ("obs" or "pred"), date, day_in_season
    """
    tmin = pl.col("temperature")
    positive_tmin = pl.when(tmin > 0).then(tmin).otherwise(0.0)

    # running per-season sum of positive minimum temperatures, in date order
    daily = temp.filter(
        pl.col("variable") == pl.lit("TMIN"), pl.col("day_in_season") >= start_day
    ).sort("date")
    accumulated = daily.with_columns(
        cdd_nz=positive_tmin.cum_sum().over("season")
    ).select(["season", "date", "day_in_season", "cdd_nz"])

    # earliest date per season on which the threshold has been reached
    reached = accumulated.filter(pl.col("cdd_nz") >= cdd_nz_star)
    is_first_in_season = (pl.col("date") == pl.col("date").min()).over("season")
    predicted = (
        reached.filter(is_first_in_season)
        .select(["season", "date"])
        .with_columns(pl.col("season").cast(pl.Int64))
    )

    stacked = pl.concat(
        [
            bloom.with_columns(source=pl.lit("obs")),
            predicted.with_columns(source=pl.lit("pred")),
        ]
    )

    # days since Jan 1 of the season's year, for both sources
    offset_days = (pl.col("date") - pl.date(pl.col("season"), 1, 1)) / pl.duration(
        days=1
    )
    return stacked.with_columns(day_in_season=offset_days).select(
        ["season", "source", "date", "day_in_season"]
    )


def score(cdd_nz_star, start_day) -> float:
    """Sum of squared bloom-date errors (in days squared) for one parameter pair.

    Seasons lacking either an observed or a predicted date are left out.

    Args:
      cdd_nz_star: threshold passed through to `score_cmp`
      start_day: accumulation start day passed through to `score_cmp`
    """
    # one row per season, with an "obs" and a "pred" date column
    wide = (
        score_cmp(cdd_nz_star, start_day)
        .select("season", "source", "date")
        .pivot(index="season", on="source")
    )

    gap_in_days = (pl.col("obs") - pl.col("pred")).abs() / pl.duration(days=1)
    abs_err = wide.with_columns(abs_err=gap_in_days).get_column("abs_err").drop_nulls()

    return (abs_err**2).sum()


# Grid search: evaluate the score for every (threshold, start day) combination.
star_grid = np.linspace(0, 280, num=28 * 2 + 1)
start_day_grid = np.linspace(0, 100, num=41)

grid_rows = [
    {"cdd_nz_star": star, "start_day": start_day, "score": score(star, start_day)}
    for star in star_grid
    for start_day in start_day_grid
]

result = pl.from_dicts(grid_rows)

result

In [None]:
# Heat map of the grid-search score, coloured in bands split at the domain values.
score_scale = alt.Scale(
    type="threshold",
    domain=[1500, 2000, 3000, 10000],
    rangeMin=2000,
    scheme="inferno",
    reverse=True,
)

(
    alt.Chart(result)
    .mark_rect()
    .encode(
        x=alt.X("cdd_nz_star", type="ordinal"),
        y=alt.Y("start_day", type="ordinal", scale=alt.Scale(reverse=True)),
        color=alt.Color("score", scale=score_scale),
        tooltip=alt.Tooltip(["cdd_nz_star", "start_day", "score"]),
    )
)

In [None]:
# Observed vs. predicted bloom day for one parameter choice, from the 1942 season on.
point_data = score_cmp(160, 40).filter(pl.col("season") >= 1942)

# Per season, the lowest and highest day_in_season across the two sources;
# drawn as a vertical segment joining the observed and predicted points.
line_data = point_data.group_by("season").agg(
    ymin=pl.col("day_in_season").min(), ymax=pl.col("day_in_season").max()
)

# Shared y axis: weekly ticks from day 63 to day 119.
yargs = dict(
    axis=alt.Axis(values=range(63, 119 + 1, 7)),
    scale=alt.Scale(domain=(63, 119)),
)

point_chart = (
    alt.Chart(point_data)
    .mark_point(filled=True, size=40)
    .encode(
        x=alt.X("season", type="nominal"),
        y=alt.Y("day_in_season", **yargs),
        color="source",
        tooltip=alt.Tooltip(["source", "date"]),
    )
    .properties(width=1000)
)

line_chart = (
    alt.Chart(line_data)
    .mark_rule()
    .encode(
        x=alt.X("season", type="nominal"),
        x2="season",
        y=alt.Y("ymin", **yargs),
        y2="ymax",
    )
)

line_chart + point_chart

In [None]:
# Absolute error, in days, between observed and predicted bloom day per season.
# Seasons missing either value are dropped.
chart_data = (
    score_cmp(160, 40)
    .select(["season", "source", "day_in_season"])
    .pivot(index="season", on="source")
    .with_columns(abs_error=(pl.col("obs") - pl.col("pred")).abs())
    .filter(pl.col("abs_error").is_not_null())
)

# Seasons where the prediction is off by more than a week. `display` (the
# IPython built-in) renders the frame as a rich table instead of plain text.
display(chart_data.filter(pl.col("abs_error") > 7.0))

alt.Chart(chart_data).mark_bar().encode(alt.X("abs_error", bin=True), alt.Y("count()"))

# Chill units


In [None]:
# Daily minimum temperatures, with `temperature` renamed to `temp`. The cells
# below (chill units, growth model, Numpyro) all read this frame as `data` and
# expect the columns `season`, `date`, `day_in_season` and `temp`; without this
# assignment they fail on a fresh kernel.
data = temp.filter(pl.col("variable") == pl.lit("TMIN")).rename(
    {"temperature": "temp"}
)
data

In [None]:
def first_after_below(
    dates: np.ndarray,
    xs: np.ndarray,
    x0: float,
    n: int,
    output_type=int,
    missing_value=None,
):
    """Date on which the n-th value strictly below x0 occurs.

    Walks `xs` in the given order, counting entries that are strictly below
    `x0`, and returns the entry of `dates` at the position where that count
    first reaches `n`.

    Args:
      dates: vector of dates, aligned with `xs`
      xs: vector of values
      x0: threshold value
      n: number of values of x under x0 (floored if fractional); for n <= 0
        the first date is returned
      output_type: type the selected date is cast to via `.astype`
      missing_value: returned when fewer than n values are below x0, or when
        the input is empty

    Returns:
      the selected date cast to `output_type`, or `missing_value`
    """
    n = np.floor(n)
    # running count of observations strictly below the threshold
    css = np.cumsum(np.asarray(xs) < x0)

    # css is non-decreasing, so its last element is its maximum; an empty input
    # can never reach the requested count either
    if css.size == 0 or css[-1] < n:
        return missing_value

    # css grows by at most 1 per step, so the first position with css >= n is
    # exactly where the n-th below-threshold value occurs (position 0 if n <= 0)
    idx = np.searchsorted(css, n, side="left")
    return np.asarray(dates)[idx].astype(output_type)


def f(season, x0, n):
    """Date in one season on which the n-th temperature below x0 occurs.

    Reads the module-level `data` frame and delegates to `first_after_below`.

    Args:
      season: value matched against the `season` column of `data`
      x0: threshold value
      n: minimum number of days
    """
    in_season = data.filter(pl.col("season") == season)
    season_dates = in_season["date"].to_numpy()
    season_temps = in_season["temp"].to_numpy()

    return first_after_below(season_dates, season_temps, x0, n)


# Evaluate `f` for every season x threshold temperature x day-count combination.
threshold_temps = [-2.5, 0, 2.5, 5]
day_counts = [5, 10, 20, 30]

combos = [
    (season, max_temp, n_days)
    for season in data["season"].unique().to_list()
    for max_temp in threshold_temps
    for n_days in day_counts
]

# column-oriented layout: one list per output column
out = {
    "max_temp": [max_temp for _, max_temp, _ in combos],
    "season": [season for season, _, _ in combos],
    "n_days": [n_days for _, _, n_days in combos],
    "first_day": [f(season, max_temp, n_days) for season, max_temp, n_days in combos],
}

# `f` returns integers (or None); the cast turns them back into dates
results = pl.from_dict(out).with_columns(pl.col("first_day").cast(pl.Date))
results.sample(5)

In [None]:
# Offset of `first_day` from Jan 1 of the season's year, in whole days.
# should refactor this as days relative to Jan 1 of year after season
first_day_offsets = results.with_columns(
    y=(pl.col("first_day") - pl.date(pl.col("season"), 1, 1)).dt.total_days()
)

# One panel per threshold temperature, one line per required number of days.
(
    alt.Chart(first_day_offsets, title="Number of days below")
    .mark_line()
    .encode(
        x="season:N",
        y=alt.Y("y", title="Days relative to Jan 1"),
        color=alt.Color("n_days:O", title="No. days below"),
        column=alt.Column("max_temp:N", title="Threshold temperature"),
    )
)

# Growth model


In [None]:
def bloom(dates, temps, N, T_star, T0: float, X_star):
    """Estimate bloom date

    D is the Nth day with temperature under T*. Compute X_i, the cumulative sum
    of (temperature - T0) from D onward. B is the earliest date such that
    X_B >= X*.

    Args:
      dates: vector of dates
      temps: vector of temperatures
      N: number of days under threshold temperature
      T_star: threshold temperature
      T0: reference temperature
      X_star: threshold cumulative daily degrees

    Returns:
      the date B when X_B >= X*, or None when either fewer than N days fall
      under T* or the cumulative sum never reaches X*
    """
    # compute starting date
    # NOTE(review): first_after_below casts its result to int by default, so the
    # comparison below presumably expects numeric dates (callers in this
    # notebook pass `day_in_season`) - confirm before passing real dates.
    D = first_after_below(dates, temps, T_star, N)

    if D is None:
        return None

    # compute cumulative daily degrees from D onward and keep the rows at or
    # above the threshold
    reached = (
        pl.DataFrame({"date": dates, "temp": temps})
        .filter(pl.col("date") >= D)
        .sort("date")
        .with_columns(cdd=(pl.col("temp") - T0).cum_sum())
        .filter(pl.col("cdd") >= X_star)
    )

    # the threshold may never be reached; report that the same way as a
    # missing start date instead of failing on an out-of-range row lookup
    if reached.height == 0:
        return None

    return reached.item(row=0, column="date")


# Example: estimated bloom day for the 2016 season with hand-picked parameters.
d = data.filter(pl.col("season") == 2016).sort("date")

bloom(
    dates=d["day_in_season"].to_numpy(),
    temps=d["temp"].to_numpy(),
    N=15,
    T_star=0,
    T0=0,
    X_star=100,
)

# Numpyro

E.g., what are the mean and standard deviation of the minimum temperature across seasons?


In [None]:
# Distinct seasons present in the temperature data, cast to Int64 before the join.
season_keys = data.select(pl.col("season").cast(pl.Int64)).unique().sort("season")

# Peak-bloom records only.
peak_bloom = pl.read_csv("../data/nps.csv", try_parse_dates=True).filter(
    pl.col("stage_name") == pl.lit("Peak Bloom")
)

# Whole days between Jan 1 of the bloom date's own year and the bloom date.
days_since_jan1 = (
    pl.col("date") - pl.date(pl.col("date").dt.year(), 1, 1)
).dt.total_days()

# Keep only seasons that also have temperature data, ordered by season.
bloom_dates = (
    peak_bloom.with_columns(y=days_since_jan1)
    .rename({"year": "season"})
    .select("season", "y")
    .join(season_keys, on="season", how="inner")
    .sort("season")
)

# Aligned numpy vectors used by the model below.
seasons = bloom_dates["season"].to_numpy()
y = bloom_dates["y"].to_numpy()

In [None]:
import numpyro
import numpyro.distributions as dist
import jax.random
from numpyro.infer import MCMC, NUTS


# J: number of seasons that have both temperature data and an observed peak bloom day
J = len(seasons)
# `seasons` and `y` are columns of the same frame, so their lengths must agree
assert len(y) == J


def estimate_blooms(N, Tstar, T0, Xstar, data=data):
    """Predicted bloom day for every season in the module-level `seasons`.

    Args:
      N: number of days under the threshold temperature
      Tstar: threshold temperature
      T0: reference temperature
      Xstar: threshold cumulative daily degrees
      data: frame with `season`, `date`, `day_in_season` and `temp` columns
        (defaults to the module-level `data` as it was when this function was
        defined)

    Returns:
      float array of length J, aligned with `seasons`; an entry is NaN when
      `bloom` yields no prediction for that season
    """
    # start from NaN so seasons without a prediction stay explicitly missing
    yhat = np.full(J, np.nan)
    for i, season in enumerate(seasons):
        d = data.filter(pl.col("season") == season).sort("date")
        b = bloom(
            dates=d["day_in_season"].to_numpy(),
            temps=d["temp"].to_numpy(),
            N=N,
            T_star=Tstar,
            T0=T0,
            X_star=Xstar,
        )
        # bloom() returns None when it cannot find a date; storing None in a
        # float array would raise, so leave the NaN in place instead
        if b is not None:
            yhat[i] = b

    return yhat


def model(J, y=None):
    """Numpyro model for the four growth-model parameters.

    Samples N, Tstar, T0 and Xstar from uniform priors, predicts a bloom day per
    season with `estimate_blooms`, and treats the prediction error (predicted
    minus observed day) as Uniform(-7, 7).

    Args:
      J: number of seasons; must equal the length of the error vector
      y: observed bloom days, one per season

    NOTE(review): `.item()` and `estimate_blooms` need concrete Python numbers
    and run numpy/polars code; NUTS presumably traces and differentiates the
    model, which would make this fail or be uninformative - confirm it runs.
    NOTE(review): with a Uniform(-7, 7) likelihood, any error outside that range
    has zero density - verify that sampling can initialise.
    NOTE(review): the default y=None would make `yhat - y` fail; the call in
    this notebook always passes y.
    """
    # prior distributions
    N = numpyro.sample("N", dist.Uniform(5, 30))
    Tstar = numpyro.sample("Tstar", dist.Uniform(-2.5, 5.0))
    T0 = numpyro.sample("T0", dist.Uniform(-10.0, 10.0))
    Xstar = numpyro.sample("Xstar", dist.Uniform(100.0, 1500.0))

    # compute the predicted bloom dates, given parameters
    # (.item() turns each sampled value into a plain Python scalar)
    yhat = estimate_blooms(N.item(), Tstar.item(), T0.item(), Xstar.item())
    # prediction error in days: predicted minus observed
    obs = yhat - y
    assert len(obs) == J

    with numpyro.plate("J", J):
        # observation error on predicted vs. obs. bloom dates
        numpyro.sample("obs", dist.Uniform(-7.0, 7.0), obs=obs)


# Sample the posterior with NUTS: 500 warm-up iterations, then 1000 retained
# samples, using a fixed PRNG key.
rng_key = jax.random.PRNGKey(0)
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(rng_key, J, y=y)

mcmc.print_summary()