<a href="https://colab.research.google.com/github/sydney-machine-learning/cyclonegenesis_seasurfacetemperature/blob/albert/Climate%20Project/albert/notebooks/Hierarchical_Regressionipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro

  Preparing metadata (setup.py) ... [?25l[?25hdone


In [2]:
import os
# cpu cores available for sampling (we want this to equal num_chains)
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

from IPython.display import set_matplotlib_formats
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import arviz as az
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS, Predictive

def get_saffir_simpson_category(wind_kn):
    # NOTE: observations are rounded to nearest 5 so is this being a bit cheeky??
    if wind_kn <= 82:
        return 1
    if wind_kn <= 95:
        return 2
    if wind_kn <= 112:
        return 3
    if wind_kn <= 136:
        return 4
    return 5

DATASET_URL = "https://raw.githubusercontent.com/sydney-machine-learning/cyclonegenesis_seasurfacetemperature/albert/Climate%20Project/albert/cyclone_data/jtwc/cleaned/full_instantaneous.csv"
dset = pd.read_csv(DATASET_URL)
# rename columns
dset.columns = dset.columns.str.replace(' ', '_').str.replace('(','').str.replace(')','').str.lower()
dset['category'] = dset['peak_vmax_kt'].apply(get_saffir_simpson_category)
dset

  dset.columns = dset.columns.str.replace(' ', '_').str.replace('(','').str.replace(')','').str.lower()
  dset.columns = dset.columns.str.replace(' ', '_').str.replace('(','').str.replace(')','').str.lower()


Unnamed: 0,timestamp,storm_id,basin,season,season_tc_number,stormname,latitude_degrees,longitude_degrees,vmax_kt,peak_vmax_kt,ace,maximum_24h_intensification,tropical_sst,global_sst,local_sst,local_month_mean,local_anomaly,saffir-simpson_category,category
0,1982-03-14 06:00:00,1982-N-1,WP,1982.0,1.0,,7.1,153.0,15.0,60.0,0.73250,15.0,27.961514,13.499058,28.082220,28.562027,-0.479807,1,1
1,1981-10-21 06:00:00,1982-S-2,SI,1982.0,2.0,,-8.0,84.6,40.0,85.0,1.18725,25.0,27.240122,13.386498,28.012896,27.828356,0.184540,2,2
2,1982-03-18 06:00:00,1982-N-2,WP,1982.0,2.0,,3.8,160.7,25.0,105.0,2.02800,25.0,27.961514,13.499058,29.003502,28.968107,0.035395,3,3
3,1982-03-28 06:00:00,1982-N-3,WP,1982.0,3.0,,3.5,156.6,20.0,75.0,0.70650,15.0,27.961514,13.499058,28.872086,29.140408,-0.268322,1,1
4,1981-12-05 00:00:00,1982-S-4,AUS,1982.0,4.0,,-11.9,125.0,45.0,45.0,0.04850,-10.0,27.350914,13.314075,29.079239,30.004630,-0.925390,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1806,2022-10-14 06:00:00,2022-N-24,WP,2022.0,24.0,INVEST,26.4,154.4,25.0,40.0,0.19675,15.0,27.271444,13.853819,28.137896,27.757618,0.380278,1,1
1807,2022-10-26 00:00:00,2022-N-26,WP,2022.0,26.0,INVEST,12.2,133.8,20.0,75.0,0.67825,25.0,27.271444,13.853819,29.645842,29.207080,0.438761,1,1
1808,2022-10-28 12:00:00,2022-N-27,WP,2022.0,27.0,INVEST,8.0,140.4,25.0,40.0,0.10875,15.0,27.271444,13.853819,29.821840,29.312254,0.509586,1,1
1809,2022-11-11 18:00:00,2022-N-28,WP,2022.0,28.0,INVEST,20.2,166.3,20.0,40.0,0.10600,10.0,27.478075,13.700613,28.148760,27.703903,0.444857,1,1


In [3]:
standardize = lambda x: (x - x.mean()) / x.std()

BASIN_TO_INDEX = {
    'AUS': 0,
    'SI' : 1,
    'SP' : 2,
    'WP' : 3,
}

## FOR NUMERICAL STABILITY
dset["tropical_sst_scaled"] = dset.tropical_sst.pipe(standardize)
dset["local_sst_scaled"] = dset.local_sst.pipe(standardize)
dset["peak_wind_scaled"] = dset.peak_vmax_kt.pipe(lambda x: standardize(x))

dset['basin_numerical'] = dset.basin.apply(lambda x: BASIN_TO_INDEX[x])

In [4]:
NUM_BASINS = 4

def centred_hierarchical(
    tropical_sst=None, local_sst=None,
    peak_wind=None, basin_num=None, q=None):

    ## TODO... choice of prior??
    μ_α = numpyro.sample("μ_α", dist.Normal(0, 10.0))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(10.0))

    μ_β_trop = numpyro.sample("μ_β_trop", dist.Normal(0.0, 10.0))
    σ_β_trop = numpyro.sample("σ_β_trop", dist.HalfNormal(10.0))

    μ_β_local = numpyro.sample("μ_β_local", dist.Normal(0.0, 10.0))
    σ_β_local = numpyro.sample("σ_β_local", dist.HalfNormal(10.0))

    with numpyro.plate("basins", NUM_BASINS):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        β_trop = numpyro.sample("β_trop", dist.Normal(μ_β_trop, σ_β_trop))
        β_local = numpyro.sample("β_local", dist.Normal(μ_β_local, σ_β_local))

        assert α.shape == (NUM_BASINS, ), "alpha shape wrong"

    mu = numpyro.deterministic('mu', α[basin_num] +  β_trop[basin_num] * tropical_sst + β_local[basin_num] * local_sst)
    return numpyro.sample('obs', dist.AsymmetricLaplaceQuantile(loc=mu, scale=1.0,quantile=q), obs=peak_wind)


from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam

## try the numpyro reparametrisation as well
reparam_config = {
    "α": LocScaleReparam(0),
    "β_trop": LocScaleReparam(0),
    "β_local": LocScaleReparam(0),
}

reparamed_hierarchical = reparam(
    centred_hierarchical, config=reparam_config)


def noncentred_hierarchical(
    tropical_sst=None, local_sst=None,
    peak_wind=None, basin_num=None, q=None):

    ## TODO... choice of prior??
    μ_α = numpyro.sample("μ_α", dist.Normal(0, 10.0))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(10.0))

    μ_β_trop = numpyro.sample("μ_β_trop", dist.Normal(0.0, 10.0))
    σ_β_trop = numpyro.sample("σ_β_trop", dist.HalfNormal(10.0))

    μ_β_local = numpyro.sample("μ_β_local", dist.Normal(0.0, 10.0))
    σ_β_local = numpyro.sample("σ_β_local", dist.HalfNormal(10.0))

    with numpyro.plate("basins", NUM_BASINS):
        α_prime = numpyro.sample("α_prime", dist.Normal(0, 1))
        α = numpyro.deterministic('α', α_prime * σ_α + μ_α)

        β_trop_prime = numpyro.sample("β_trop_prime", dist.Normal(0,1))
        β_trop = numpyro.deterministic('β_trop', β_trop_prime* σ_β_trop + μ_β_trop)


        β_local_prime = numpyro.sample("β_local_prime", dist.Normal(0,1))
        β_local = numpyro.deterministic('β_local', β_local_prime* σ_β_local + μ_β_local)

        assert α.shape == (NUM_BASINS, ), "alpha shape wrong"

    mu = numpyro.deterministic('mu', α[basin_num] +  β_trop[basin_num] * tropical_sst + β_local[basin_num] * local_sst)
    return numpyro.sample('obs', dist.AsymmetricLaplaceQuantile(loc=mu, scale=1.0,quantile=q), obs=peak_wind)


In [None]:
rng_key = random.PRNGKey(0)
qs = np.round(np.arange(0.05, 0.96, 0.1), 2)
params = {}

numpyro.enable_x64()

for q in qs:
    rng_key, rng_key_ = random.split(rng_key)
    kernel = NUTS(noncentred_hierarchical)
    mcmc = MCMC(kernel, num_chains=4, num_warmup=500, num_samples=500, progress_bar=True)
    mcmc.run(
        rng_key_,
        tropical_sst=dset.tropical_sst_scaled.values,
        local_sst=dset.local_sst_scaled.values,
        peak_wind=dset.peak_wind_scaled.values,
        basin_num=dset.basin_numerical.values,
        q=q )
    mcmc_samples = mcmc.get_samples()
    params[q] = mcmc_samples
    print(q)
    mcmc.print_summary()
    print()

    posterior_predictive = Predictive(
        noncentred_hierarchical, posterior_samples=mcmc_samples,
        # return_sites=extract_vars
    )(rng_key_,
      tropical_sst = dset.tropical_sst_scaled.values,
      local_sst    = dset.local_sst_scaled.values,
      peak_wind    = dset.peak_wind_scaled.values,
      basin_num=dset.basin_numerical.values,
      q            = q
    )
    arviz_posterior = az.from_numpyro(
            mcmc,
            posterior_predictive=posterior_predictive,
            # coords=az_coords,
            # dims=az_dims
    )
    params[q] = arviz_posterior
    print(q)
    mcmc.print_summary()
    print()



  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

0.05

                      mean       std    median      5.0%     95.0%     n_eff     r_hat
      α_prime[0]     -0.01      0.88     -0.00     -1.52      1.40    897.16      1.00
      α_prime[1]      0.08      0.91      0.10     -1.36      1.61   1040.90      1.00
      α_prime[2]     -0.05      0.87     -0.04     -1.51      1.29   1139.02      1.00
      α_prime[3]     -0.04      0.85     -0.03     -1.42      1.35   1111.61      1.00
β_local_prime[0]     -0.14      0.85     -0.14     -1.44      1.36   1222.79      1.00
β_local_prime[1]     -0.02      0.88     -0.02     -1.35      1.53   1329.63      1.00
β_local_prime[2]      0.02      0.89     -0.00     -1.44      1.47   1196.70      1.00
β_local_prime[3]      0.07      0.84      0.05     -1.36      1.42    994.81      1.01
 β_trop_prime[0]      0.01      0.85      0.01     -1.25      1.58   1023.92      1.00
 β_trop_prime[1]      0.01      0.89      0.01     -1.40      1.57   1257.17      1.00
 β_trop_prime[2]     -0.03      0.90 

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

0.15

                      mean       std    median      5.0%     95.0%     n_eff     r_hat
      α_prime[0]     -0.17      0.89     -0.19     -1.57      1.32   1616.73      1.00
      α_prime[1]      0.05      0.84      0.04     -1.23      1.50   1641.88      1.00
      α_prime[2]     -0.02      0.88     -0.01     -1.35      1.50   1892.50      1.00
      α_prime[3]      0.08      0.82      0.05     -1.25      1.44   1465.31      1.00
β_local_prime[0]     -0.28      0.85     -0.29     -1.71      1.07   1278.15      1.00
β_local_prime[1]      0.00      0.86      0.01     -1.40      1.41   1568.92      1.00
β_local_prime[2]      0.15      0.90      0.15     -1.24      1.75   1610.61      1.00
β_local_prime[3]      0.15      0.81      0.14     -1.14      1.52   1573.50      1.00
 β_trop_prime[0]      0.11      0.90      0.12     -1.29      1.59   1553.91      1.00
 β_trop_prime[1]     -0.02      0.86     -0.01     -1.41      1.39   1256.48      1.00
 β_trop_prime[2]      0.03      0.88 

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

0.25

                      mean       std    median      5.0%     95.0%     n_eff     r_hat
      α_prime[0]     -0.40      0.82     -0.41     -1.84      0.82    954.42      1.00
      α_prime[1]      0.10      0.84      0.08     -1.27      1.48   1312.90      1.00
      α_prime[2]      0.01      0.87     -0.00     -1.46      1.35   1332.91      1.00
      α_prime[3]      0.25      0.76      0.25     -1.03      1.43   1059.56      1.00
β_local_prime[0]     -0.55      0.78     -0.55     -1.77      0.73   1070.23      1.00
β_local_prime[1]     -0.06      0.81     -0.06     -1.39      1.29   1235.81      1.00
β_local_prime[2]      0.31      0.84      0.31     -1.01      1.75   1693.35      1.00
β_local_prime[3]      0.32      0.77      0.31     -0.97      1.55   1138.27      1.00
 β_trop_prime[0]      0.18      0.83      0.15     -1.19      1.45   1462.82      1.00
 β_trop_prime[1]     -0.07      0.89     -0.07     -1.72      1.26   1640.96      1.00
 β_trop_prime[2]     -0.02      0.88 

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

0.35

                      mean       std    median      5.0%     95.0%     n_eff     r_hat
      α_prime[0]     -0.66      0.70     -0.63     -1.64      0.57    968.47      1.00
      α_prime[1]      0.06      0.72      0.06     -1.10      1.29   1076.72      1.00
      α_prime[2]     -0.06      0.71     -0.07     -1.21      1.06    992.07      1.01
      α_prime[3]      0.68      0.73      0.65     -0.51      1.84    919.40      1.00
β_local_prime[0]     -0.69      0.75     -0.66     -1.99      0.44    949.31      1.00
β_local_prime[1]     -0.14      0.72     -0.14     -1.29      1.05   1081.89      1.00
β_local_prime[2]      0.53      0.75      0.50     -0.64      1.79   1238.13      1.00
β_local_prime[3]      0.30      0.69      0.30     -0.86      1.39    879.97      1.00
 β_trop_prime[0]      0.25      0.85      0.22     -1.20      1.59   1422.90      1.00
 β_trop_prime[1]     -0.14      0.86     -0.13     -1.53      1.32   1485.61      1.00
 β_trop_prime[2]     -0.15      0.90 

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

0.45

                      mean       std    median      5.0%     95.0%     n_eff     r_hat
      α_prime[0]     -0.73      0.68     -0.67     -1.88      0.33    706.26      1.01
      α_prime[1]      0.03      0.66      0.05     -1.02      1.13    768.79      1.00
      α_prime[2]     -0.11      0.66     -0.10     -1.18      0.98    776.71      1.00
      α_prime[3]      0.81      0.68      0.75     -0.27      1.90    779.05      1.00
β_local_prime[0]     -0.77      0.71     -0.77     -1.86      0.40    590.58      1.00
β_local_prime[1]     -0.17      0.74     -0.19     -1.27      1.11    685.91      1.00
β_local_prime[2]      0.59      0.80      0.57     -0.70      1.91    636.13      1.00
β_local_prime[3]      0.27      0.67      0.26     -0.87      1.26    667.20      1.00
 β_trop_prime[0]      0.28      0.83      0.28     -1.08      1.59    782.20      1.00
 β_trop_prime[1]     -0.14      0.87     -0.13     -1.61      1.24    981.28      1.00
 β_trop_prime[2]     -0.23      0.92 

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

0.55

                      mean       std    median      5.0%     95.0%     n_eff     r_hat
      α_prime[0]     -0.74      0.66     -0.70     -1.73      0.39    759.21      1.00
      α_prime[1]      0.08      0.66      0.08     -1.02      1.16    881.54      1.00
      α_prime[2]     -0.17      0.61     -0.16     -1.14      0.86    805.94      1.01
      α_prime[3]      0.83      0.70      0.79     -0.36      1.95    774.15      1.00
β_local_prime[0]     -0.71      0.72     -0.69     -1.81      0.50    726.66      1.00
β_local_prime[1]     -0.08      0.72     -0.07     -1.25      1.09   1091.43      1.00
β_local_prime[2]      0.71      0.72      0.67     -0.36      2.01    805.05      1.00
β_local_prime[3]      0.15      0.68      0.14     -0.96      1.25    872.44      1.00
 β_trop_prime[0]      0.36      0.85      0.37     -1.07      1.78   1160.78      1.00
 β_trop_prime[1]     -0.18      0.85     -0.16     -1.54      1.22   1168.62      1.00
 β_trop_prime[2]     -0.49      0.87 

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

0.65

                      mean       std    median      5.0%     95.0%     n_eff     r_hat
      α_prime[0]     -0.71      0.66     -0.67     -1.70      0.40    820.25      1.00
      α_prime[1]      0.27      0.64      0.26     -0.76      1.30    362.56      1.01
      α_prime[2]     -0.29      0.63     -0.28     -1.31      0.74    777.08      1.00
      α_prime[3]      0.83      0.65      0.82     -0.08      2.02    398.81      1.01
β_local_prime[0]     -0.75      0.65     -0.69     -1.81      0.31    671.71      1.00
β_local_prime[1]     -0.08      0.66     -0.08     -1.20      0.96    795.61      1.01
β_local_prime[2]      0.71      0.70      0.67     -0.43      1.80    787.68      1.00
β_local_prime[3]      0.09      0.64      0.06     -0.95      1.06    788.85      1.01
 β_trop_prime[0]      0.50      0.73      0.50     -0.78      1.58    404.88      1.02
 β_trop_prime[1]     -0.12      0.71     -0.11     -1.26      1.07    582.94      1.02
 β_trop_prime[2]     -0.71      0.73 

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

0.75

                      mean       std    median      5.0%     95.0%     n_eff     r_hat
      α_prime[0]     -0.65      0.65     -0.60     -1.73      0.41    478.95      1.01
      α_prime[1]      0.14      0.63      0.13     -0.83      1.23    509.43      1.00
      α_prime[2]     -0.38      0.62     -0.34     -1.43      0.61    553.02      1.01
      α_prime[3]      0.86      0.71      0.85     -0.31      1.95    332.16      1.01
β_local_prime[0]     -0.81      0.68     -0.78     -1.91      0.35    636.72      1.00
β_local_prime[1]     -0.04      0.64     -0.06     -1.10      1.00    636.73      1.00
β_local_prime[2]      0.66      0.71      0.61     -0.37      1.90    691.31      1.01
β_local_prime[3]      0.08      0.63      0.06     -0.97      1.06    832.68      1.00
 β_trop_prime[0]      0.61      0.75      0.62     -0.63      1.78    772.52      1.01
 β_trop_prime[1]     -0.10      0.68     -0.10     -1.26      0.95    597.03      1.01
 β_trop_prime[2]     -0.72      0.73 

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

0.85

                      mean       std    median      5.0%     95.0%     n_eff     r_hat
      α_prime[0]     -0.49      0.60     -0.45     -1.48      0.42     48.87      1.08
      α_prime[1]     -0.03      0.56      0.05     -1.02      0.83    632.96      1.01
      α_prime[2]     -0.37      0.59     -0.29     -1.46      0.38    108.35      1.05
      α_prime[3]      0.85      0.69      0.78     -0.11      2.06    129.97      1.05
β_local_prime[0]     -0.70      0.70     -0.68     -1.75      0.52    788.22      1.01
β_local_prime[1]     -0.10      0.71     -0.16     -1.32      0.98    307.92      1.03
β_local_prime[2]      0.50      0.82      0.50     -0.70      1.82     44.31      1.09
β_local_prime[3]     -0.07      0.65     -0.13     -0.96      1.04    209.23      1.04
 β_trop_prime[0]      0.41      0.81      0.34     -0.80      1.77    539.41      1.02
 β_trop_prime[1]     -0.18      0.76     -0.18     -1.42      1.18   1262.20      1.01
 β_trop_prime[2]     -0.31      0.81 

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [None]:
az.plot_trace(arviz_posterior, var_names=['μ_β_local']);
plt.tight_layout()
plt.show()

## The 'Funnel' distribution and associated problems.

In [None]:
## B_trop_1
x = arviz_posterior['posterior']['β_trop'].sel(β_trop_dim_0=0).to_series()
y = arviz_posterior['posterior']['σ_β_trop'].to_series()
df = pd.DataFrame({'β_local AUS': x, 'σ_β_local': y})
print(df)
sns.jointplot(data = df, x='β_local AUS', y='σ_β_local', ylim=(0, .7));

In [None]:
arviz_posterior