# EDA: Temperature timing

- Read in data
- Group into seasons
- Count: on what day had there been at least $D$ days with average temperature $T$ or lower?
- Compare that across seasons
- Reasonable values look like temps of -5 to 0 C and 10-30 days

Compare this with what I have read elsewhere: cherries require something like 860 hours at <=7C. The comparison is not direct, because I started with _daily_ average temperature rather than hourly temperature, which might matter.


In [1]:
import polars as pl
import altair as alt
import pyarrow.dataset as ds
import numpy as np

In [3]:
# Read the hive-partitioned parquet dataset under ../data/cdo through pyarrow
# and collect the lazy scan into an in-memory DataFrame.
data_raw = pl.scan_pyarrow_dataset(
    ds.dataset("../data/cdo", format="parquet", partitioning="hive")
).collect()

# Display 5 random rows as a quick sanity check
data_raw.sample(5)

date,value,year
date,f64,i32
2012-02-17,3.9,2012
2018-05-03,17.8,2018
2018-10-16,13.3,2018
2015-05-12,21.7,2015
2016-10-22,10.6,2016


In [4]:
# Data to be used for analysis.
#
# Each row is assigned to a "season": rows dated before June 1 of their `year`
# belong to season `year`, all later rows to season `year + 1`. A season
# therefore runs from June 1 of the previous calendar year up to (not
# including) June 1 of the season year.
data = (
    data_raw.rename({"value": "temp"})
    .with_columns(
        season=pl.when(pl.col("date") < pl.date(pl.col("year"), 6, 1))
        .then(pl.col("year"))
        .otherwise(pl.col("year") + 1)
    )
    .with_columns(
        # days since Jan 1 of the season year; negative for the part of the
        # season that falls in the previous calendar year
        day_in_season=(pl.col("date") - pl.date(pl.col("season"), 1, 1)).dt.total_days()
    )
    # keep seasons 2011 through 2022
    .filter(pl.col("season").is_between(2011, 2022))
    .sort("date")
)

# Display 5 random rows as a quick sanity check
data.sample(5)

date,temp,year,season,day_in_season
date,f64,i32,i32,i64
2018-03-18,1.1,2018,2018,76
2015-03-12,6.1,2015,2015,70
2012-04-14,7.8,2012,2012,104
2020-11-21,7.2,2020,2021,-41
2021-08-19,24.4,2021,2022,-135


In [6]:
# Line chart of daily temperature against day-in-season, one facet per season.
(
    alt.Chart(data)
    .mark_line()
    .encode(
        alt.X("day_in_season", title="day relative to Jan 1"),
        alt.Y("temp", title="daily average temperature"),
    )
    # facet first, then set the title on the resulting faceted chart
    .facet("season:N")
    .properties(title="Temperature")
)

In [8]:
def csnn(x: pl.Expr) -> pl.Expr:
    """Cumulative sum of the positive part of `x`.

    Values that are not greater than zero contribute 0 to the running total,
    so the result never decreases.
    """
    positive_part = pl.when(x > 0).then(x).otherwise(0)
    return positive_part.cum_sum()


# Line chart of cumulative (positive-only) daily degrees, one facet per season.
(
    alt.Chart(
        # drop days earlier than 100 days before Jan 1 of the season year
        data.filter(pl.col("day_in_season") > -100)
        # sort so the running sum accumulates in day order within each season
        .sort("season", "day_in_season")
        # .over("season") evaluates the running sum separately per season
        .with_columns(cdd=(pl.col("temp")).pipe(csnn).over("season"))
    )
    .mark_line()
    .encode(
        alt.X("day_in_season"), alt.Y("cdd", title="Cum. daily degrees, non-negative")
    )
    .facet("season:N")
    .properties(title="Cumulative daily degrees")
)

In [17]:
def first_after_below(
    dates: np.ndarray,
    xs: np.ndarray,
    x0: float,
    n: int,
    output_type=int,
    missing_value=None,
):
    """First date by which x has been below x0 on at least n days

    Days below the threshold are counted cumulatively, in the order given;
    they do not need to be consecutive.

    Args:
      dates: vector of dates (any array-like), aligned with xs
      xs: vector of values (any array-like)
      x0: threshold value; a value counts only when strictly below x0
      n: number of values of x under x0; non-integers are rounded down.
        If n <= 0 the first date is returned.
      output_type: type the selected element of dates is cast to
      missing_value: returned when fewer than n values are below x0

    Returns:
      The element of dates at which the running count of values below x0
      first reaches n, cast to output_type; or missing_value when the count
      never reaches n (including when the input is empty).
    """
    dates = np.asarray(dates)
    n = np.floor(n)
    # running count of values strictly below the threshold
    css = np.cumsum(np.asarray(xs) < x0)

    # css never decreases, so its last element is its maximum
    if css.size == 0 or css[-1] < n:
        return missing_value

    # argmax of a boolean array is the index of its first True
    return dates[np.argmax(css >= n)].astype(output_type)


def f(season, x0, n):
    """First date in one season by which n days were below x0

    Args:
      season: season identifier, matched against the "season" column of data
      x0: threshold value
      n: minimum number of days

    Returns:
      The result of first_after_below applied to that season's "date" and
      "temp" columns.
    """
    in_season = pl.col("season") == season
    rows = data.filter(in_season)

    season_dates = rows["date"].to_numpy()
    season_temps = rows["temp"].to_numpy()

    return first_after_below(season_dates, season_temps, x0, n)


# Column-wise accumulator: one list per output column
out = {x: [] for x in ["max_temp", "season", "n_days", "first_day"]}

# Evaluate f over every (season, threshold temperature, number of days) combination
for season in data["season"].unique().to_list():
    for max_temp in [-2.5, 0, 2.5, 5]:
        for n_days in [5, 10, 20, 30]:
            # None when the season has fewer than n_days days below max_temp
            first_day = f(season, max_temp, n_days)

            out["season"].append(season)
            out["max_temp"].append(max_temp)
            out["n_days"].append(n_days)
            out["first_day"].append(first_day)


# first_day comes back as integers (first_after_below's default output_type is
# int) or None; cast the column to a Date for readability and date arithmetic
results = pl.from_dict(out).with_columns(pl.col("first_day").cast(pl.Date))
# Display 5 random rows as a quick sanity check
results.sample(5)

max_temp,season,n_days,first_day
f64,i64,i64,date
-2.5,2013,10,2013-02-03
0.0,2012,5,2011-12-29
5.0,2012,10,2011-11-18
5.0,2018,5,2017-11-15
-2.5,2013,20,


In [21]:
# For each threshold temperature (one column per value), plot by season the day
# on which the required number of below-threshold days was reached.
(
    alt.Chart(
        results.with_columns(
            # y is first_day expressed as days relative to Jan 1 of the season year
            # TODO: should refactor this as days relative to Jan 1 of year after season
            y=(pl.col("first_day") - pl.date(pl.col("season"), 1, 1)).dt.total_days()
        ),
        title="Number of days below",
    )
    .encode(
        alt.X("season:N"),
        alt.Y("y", title="Days relative to Jan 1"),
        alt.Color("n_days:O", title="No. days below"),
        alt.Column("max_temp:N", title="Threshold temperature"),
    )
    .mark_line()
)

# Growth model


In [22]:
def bloom(dates, temps, N, T_star, T0: float, X_star):
    """Estimate bloom date

    D is the Nth day with temperature under T*. Compute X_i, the cumulative number of daily
    degrees (temperature minus T0) from D onward. B is the lowest value such that X_B >= X*

    Args:
      dates: vector of dates. NOTE(review): D is obtained from
        first_after_below with its default integer output type and then
        compared against this column, so numeric day numbers (e.g.
        day_in_season) appear to be expected here -- confirm before passing
        real dates.
      temps: vector of temperatures
      N: number of days under threshold temperature
      T_star: threshold temperature
      T0: reference temperature
      X_star: threshold cumulative daily degrees

    Returns:
      the date B when X_B >= X*, or None when there are fewer than N days
      under T* or when the cumulative daily degrees never reach X*
    """
    # compute starting date
    D = first_after_below(dates, temps, T_star, N)

    if D is None:
        return None

    # cumulative daily degrees from D onward, keeping only the rows at or
    # above the threshold
    reached = (
        pl.DataFrame({"date": dates, "temp": temps})
        .filter(pl.col("date") >= D)
        .sort("date")
        .with_columns(cdd=(pl.col("temp") - T0).cum_sum())
        .filter(pl.col("cdd") >= X_star)
    )

    # the threshold may never be reached within the data; report that the
    # same way as a missing start date instead of failing on an empty frame
    if reached.is_empty():
        return None

    return reached.item(row=0, column="date")


# Smoke test of bloom() on the 2016 season, using day_in_season as the "date"
d = data.filter(pl.col("season") == 2016).sort("date")

# N=15 days under T*=0, then accumulate degrees above T0=0 until X*=100
bloom(d["day_in_season"].to_numpy(), d["temp"].to_numpy(), 15, 0, 0, 100)

79

# Numpyro

E.g., what are the mean and standard deviation of the minimum temperature across seasons?


In [25]:
# Distinct seasons present in the temperature data, as a one-column frame.
# Cast to Int64, presumably so the join key dtype matches the "year" column
# read from the CSV -- TODO confirm.
seasons = (
    data.select("season")
    .with_columns(pl.col("season").cast(pl.Int64))
    .unique()
    .sort("season")
)

# Observed peak-bloom day per season, restricted to seasons with temperature data
bloom_dates = (
    pl.read_csv("../data/nps.csv", try_parse_dates=True)
    .filter(pl.col("stage_name") == pl.lit("Peak Bloom"))
    .with_columns(
        # y is the bloom date as days since Jan 1 of the same calendar year
        y=(pl.col("date") - pl.date(pl.col("date").dt.year(), 1, 1)).dt.total_days()
    )
    .rename({"year": "season"})
    .select(["season", "y"])
    .join(seasons, on="season", how="inner")
    .sort("season")
)

# NOTE: `seasons` is rebound here from a DataFrame to a numpy array that is
# aligned element-by-element with `y`
seasons = bloom_dates["season"].to_numpy()
y = bloom_dates["y"].to_numpy()

In [26]:
import numpyro
import numpyro.distributions as dist
import jax.random
from numpyro.infer import MCMC, NUTS


# Number of seasons with an observed bloom date; observations must align 1:1
J = len(seasons)
assert len(y) == J


def estimate_blooms(N, Tstar, T0, Xstar, data=data):
    """Predict the bloom day for every season in the module-level `seasons`.

    Args:
      N: number of days under the threshold temperature
      Tstar: threshold temperature
      T0: reference temperature
      Xstar: threshold cumulative daily degrees
      data: frame with "season", "date", "day_in_season" and "temp" columns

    Returns:
      Array of length J holding, per season, the value returned by bloom()
      when called with that season's day_in_season and temp columns.
    """

    def predict(season):
        # one season's rows in date order, handed to the bloom rule
        rows = data.filter(pl.col("season") == season).sort("date")
        return bloom(
            dates=rows["day_in_season"].to_numpy(),
            temps=rows["temp"].to_numpy(),
            N=N,
            T_star=Tstar,
            T0=T0,
            X_star=Xstar,
        )

    yhat = np.empty(J)
    for idx, season in enumerate(seasons):
        yhat[idx] = predict(season)

    return yhat


def model(J, y=None):
    """NumPyro model linking the bloom-rule parameters to observed bloom days.

    Args:
      J: number of seasons; used as the size of the "J" plate
      y: observed bloom days, one per season. Although it defaults to None,
        `yhat - y` below needs an actual array.

    NOTE(review): this model does not run under NUTS as written. During
    inference the sampled parameters are abstract JAX tracers, and calling
    `.item()` on them raises ConcretizationTypeError (see the traceback in
    this cell's output). estimate_blooms also does its work in polars/numpy,
    so the prediction presumably cannot be differentiated by JAX; a JAX
    implementation of the bloom rule or a gradient-free sampler is probably
    needed -- TODO confirm.
    """
    # prior distributions
    N = numpyro.sample("N", dist.Uniform(5, 30))
    Tstar = numpyro.sample("Tstar", dist.Uniform(-2.5, 5.0))
    T0 = numpyro.sample("T0", dist.Uniform(-10.0, 10.0))
    Xstar = numpyro.sample("Xstar", dist.Uniform(100.0, 1500.0))

    # compute the predicted bloom dates, given parameters
    # (.item() needs concrete values; this is the line that fails under tracing)
    yhat = estimate_blooms(N.item(), Tstar.item(), T0.item(), Xstar.item())
    # prediction error per season: predicted minus observed bloom day
    obs = yhat - y
    assert len(obs) == J

    with numpyro.plate("J", J):
        # observation error on predicted vs. obs. bloom dates
        numpyro.sample("obs", dist.Uniform(-7.0, 7.0), obs=obs)


# Run NUTS on the model: 500 warmup + 1000 sampling iterations, fixed seed.
# NOTE(review): this currently fails with ConcretizationTypeError raised from
# the .item() calls inside model (see the output below).
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
rng_key = jax.random.PRNGKey(0)
mcmc.run(rng_key, J, y=y)

mcmc.print_summary()

  from .autonotebook import tqdm as notebook_tqdm
  0%|          | 0/1500 [00:00<?, ?it/s]


ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float32[]
This occurred in the item() method of jax.Array

See https://docs.jax.dev/en/latest/errors.html#jax.errors.ConcretizationTypeError