<a href="https://colab.research.google.com/github/sydney-machine-learning/cyclonegenesis_seasurfacetemperature/blob/albert/%5BJOSH%5Dfuture_cyclones.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro

  Preparing metadata (setup.py) ... [?25l[?25hdone


In [2]:
import os
# cpu cores available for sampling (we want this to equal num_chains)
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

from IPython.display import set_matplotlib_formats
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import arviz as az
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS, Predictive

def get_saffir_simpson_category(wind_kn):
    # NOTE: observations are rounded to nearest 5 so is this being a bit cheeky??
    if wind_kn <= 82:
        return 1
    if wind_kn <= 95:
        return 2
    if wind_kn <= 112:
        return 3
    if wind_kn <= 136:
        return 4
    return 5

DATASET_URL = "https://raw.githubusercontent.com/sydney-machine-learning/cyclonegenesis_seasurfacetemperature/albert/Climate%20Project/albert/cyclone_data/jtwc/cleaned/full_cleaned.csv"
dset = pd.read_csv(DATASET_URL)
# rename columns
dset.columns = dset.columns.str.replace(' ', '_').str.replace('(','').str.replace(')','').str.lower()
dset['category'] = dset['peak_vmax_kt'].apply(get_saffir_simpson_category)
dset

  dset.columns = dset.columns.str.replace(' ', '_').str.replace('(','').str.replace(')','').str.lower()
  dset.columns = dset.columns.str.replace(' ', '_').str.replace('(','').str.replace(')','').str.lower()


Unnamed: 0,timestamp,storm_id,basin,season,season_tc_number,latitude_degrees,longitude_degrees,vmax_kt,peak_vmax_kt,ace,maximum_24h_intensification,month,tropical_sst,tropical_anomaly,local_sst,relative_sst,saffir-simpson_category,category
0,1982-03-14 06:00:00,1982-N-1,WP,1982.0,1.0,7.1,153.0,15.0,60.0,0.73250,15.0,3,27.237669,-0.270336,28.082220,0.844551,1,1
1,1981-10-21 06:00:00,1982-S-2,SI,1982.0,2.0,-8.0,84.6,40.0,85.0,1.18725,25.0,10,26.766030,-0.146004,28.012896,1.246866,2,2
2,1982-03-18 06:00:00,1982-N-2,WP,1982.0,2.0,3.8,160.7,25.0,105.0,2.02800,25.0,3,27.237669,-0.270336,29.003502,1.765833,3,3
3,1981-11-03 18:00:00,1982-S-3,SI,1982.0,3.0,-8.6,92.9,55.0,80.0,1.08875,15.0,11,26.829243,-0.228115,27.334639,0.505396,1,1
4,1982-03-28 06:00:00,1982-N-3,WP,1982.0,3.0,3.5,156.6,20.0,75.0,0.70650,15.0,3,27.237669,-0.270336,28.872086,1.634417,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1854,2022-10-14 06:00:00,2022-N-24,WP,2022.0,24.0,26.4,154.4,25.0,40.0,0.19675,15.0,10,26.962795,0.050762,28.137896,1.175100,1,1
1855,2022-10-26 00:00:00,2022-N-26,WP,2022.0,26.0,12.2,133.8,20.0,75.0,0.67825,25.0,10,26.962795,0.050762,29.645842,2.683046,1,1
1856,2022-10-28 12:00:00,2022-N-27,WP,2022.0,27.0,8.0,140.4,25.0,40.0,0.10875,15.0,10,26.962795,0.050762,29.821840,2.859045,1,1
1857,2022-11-11 18:00:00,2022-N-28,WP,2022.0,28.0,20.2,166.3,20.0,40.0,0.10600,10.0,11,27.142609,0.085251,28.148760,1.006151,1,1


In [3]:
standardize = lambda x: (x - x.mean()) / x.std()

BASIN_TO_INDEX = {
    'SI' : 0,
    'SP' : 1,
    'WP' : 2,
}

INDEX_TO_BASIN = {
     0 : 'SI',
     1 : 'SP',
     2 : 'WP',
}

NUM_BASINS = 3

## FOR NUMERICAL STABILITY

s_wind = dset.peak_vmax_kt.std()
wind_bar = dset.peak_vmax_kt.mean()

sst_trop_bar = dset.tropical_sst.mean()
s_sst_trop = dset.tropical_sst.std()

sst_local_bar = dset.local_sst.mean()
s_sst_local = dset.local_sst.std()

dset["tropical_sst_scaled"] = dset.tropical_sst.pipe(standardize)
dset["local_sst_scaled"] = dset.local_sst.pipe(standardize)
dset["peak_wind_scaled"] = dset.peak_vmax_kt.pipe(standardize)
dset['basin_numerical'] = dset.basin.apply(lambda x: BASIN_TO_INDEX[x])
dset

Unnamed: 0,timestamp,storm_id,basin,season,season_tc_number,latitude_degrees,longitude_degrees,vmax_kt,peak_vmax_kt,ace,...,tropical_sst,tropical_anomaly,local_sst,relative_sst,saffir-simpson_category,category,tropical_sst_scaled,local_sst_scaled,peak_wind_scaled,basin_numerical
0,1982-03-14 06:00:00,1982-N-1,WP,1982.0,1.0,7.1,153.0,15.0,60.0,0.73250,...,27.237669,-0.270336,28.082220,0.844551,1,1,0.480337,-1.249390,-0.689819,2
1,1981-10-21 06:00:00,1982-S-2,SI,1982.0,2.0,-8.0,84.6,40.0,85.0,1.18725,...,26.766030,-0.146004,28.012896,1.246866,2,2,-0.644209,-1.351880,0.057589,0
2,1982-03-18 06:00:00,1982-N-2,WP,1982.0,2.0,3.8,160.7,25.0,105.0,2.02800,...,27.237669,-0.270336,29.003502,1.765833,3,3,0.480337,0.112647,0.655517,2
3,1981-11-03 18:00:00,1982-S-3,SI,1982.0,3.0,-8.6,92.9,55.0,80.0,1.08875,...,26.829243,-0.228115,27.334639,0.505396,1,1,-0.493488,-2.354626,-0.091892,0
4,1982-03-28 06:00:00,1982-N-3,WP,1982.0,3.0,3.5,156.6,20.0,75.0,0.70650,...,27.237669,-0.270336,28.872086,1.634417,1,1,0.480337,-0.081640,-0.241374,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1854,2022-10-14 06:00:00,2022-N-24,WP,2022.0,24.0,26.4,154.4,25.0,40.0,0.19675,...,26.962795,0.050762,28.137896,1.175100,1,1,-0.175055,-1.167078,-1.287747,2
1855,2022-10-26 00:00:00,2022-N-26,WP,2022.0,26.0,12.2,133.8,20.0,75.0,0.67825,...,26.962795,0.050762,29.645842,2.683046,1,1,-0.175055,1.062293,-0.241374,2
1856,2022-10-28 12:00:00,2022-N-27,WP,2022.0,27.0,8.0,140.4,25.0,40.0,0.10875,...,26.962795,0.050762,29.821840,2.859045,1,1,-0.175055,1.322491,-1.287747,2
1857,2022-11-11 18:00:00,2022-N-28,WP,2022.0,28.0,20.2,166.3,20.0,40.0,0.10600,...,27.142609,0.085251,28.148760,1.006151,1,1,0.253682,-1.151017,-1.287747,2


In [4]:
def noncentred_hierarchical(
    tropical_sst=None, local_sst=None,
    peak_wind=None, basin_num=None, q=None):

    ## TODO... choice of prior for hierarchical model??
    μ_α = numpyro.sample("μ_α", dist.Normal(dist.Normal(0,1).icdf(q), 0.1))
    σ_α = numpyro.sample("σ_α", dist.HalfCauchy(0.1))

    μ_β_trop = numpyro.sample("μ_β_trop", dist.Normal(0.0, 0.1))
    σ_β_trop = numpyro.sample("σ_β_trop", dist.HalfCauchy(0.3))

    μ_β_local = numpyro.sample("μ_β_local", dist.Normal(0.0, 0.1))
    σ_β_local = numpyro.sample("σ_β_local", dist.HalfCauchy(0.3))

    with numpyro.plate("basins", NUM_BASINS):
        α_prime = numpyro.sample("α_prime", dist.Normal(0, 1))
        α = numpyro.deterministic('α', α_prime * σ_α + μ_α)

        β_trop_prime = numpyro.sample("β_trop_prime", dist.Normal(0,1))
        β_trop = numpyro.deterministic('β_trop', β_trop_prime* σ_β_trop + μ_β_trop)


        β_local_prime = numpyro.sample("β_local_prime", dist.Normal(0,1))
        β_local = numpyro.deterministic('β_local', β_local_prime* σ_β_local + μ_β_local)

        assert α.shape == (NUM_BASINS, ), "alpha shape wrong"

    mu = numpyro.deterministic('mu', α[basin_num] +  β_trop[basin_num] * tropical_sst + β_local[basin_num] * local_sst)
    ### OBS IS ACTUALLY THE QUANTILE LOSS... so it's only pseudo-bayesian as it's a bit of a hack..
    with numpyro.plate("data", tropical_sst.shape[0], dim = -1):
    numpyro.sample('obs', dist.AsymmetricLaplaceQuantile(loc=mu, scale=1.0, quantile=q), obs=peak_wind)

In [5]:
FUTURE_TC_URL = "https://raw.githubusercontent.com/sydney-machine-learning/cyclonegenesis_seasurfacetemperature/albert/Climate%20Project/albert/cyclone_data/resampled/future_cyclones_cmip.csv"
future_tc_dset = pd.read_csv(FUTURE_TC_URL)
# rename columns
future_tc_dset.columns = future_tc_dset.columns.str.replace(' ', '_').str.replace('(','').str.replace(')','').str.lower()

scenarios = [126, 245, 370, 585]
for scenario in scenarios:
  future_tc_dset[f"tropical_sst_scaled_{scenario}"] = future_tc_dset[f'tropical_sst_{scenario}'].pipe(standardize)
  future_tc_dset[f"local_sst_scaled_{scenario}"] = future_tc_dset[f'local_sst_{scenario}'].pipe(standardize)

future_tc_dset['basin_numerical'] = future_tc_dset.basin.apply(lambda x: BASIN_TO_INDEX[x])
future_tc_dset

  future_tc_dset.columns = future_tc_dset.columns.str.replace(' ', '_').str.replace('(','').str.replace(')','').str.lower()
  future_tc_dset.columns = future_tc_dset.columns.str.replace(' ', '_').str.replace('(','').str.replace(')','').str.lower()


Unnamed: 0,timestamp,lat,lon,basin,tropical_sst_126,local_sst_126,relative_sst_126,tropical_sst_245,local_sst_245,relative_sst_245,...,relative_sst_585,tropical_sst_scaled_126,local_sst_scaled_126,tropical_sst_scaled_245,local_sst_scaled_245,tropical_sst_scaled_370,local_sst_scaled_370,tropical_sst_scaled_585,local_sst_scaled_585,basin_numerical
0,2015-02-28 00:00:00,7.3,149.8,WP,27.640710,28.875692,1.234981,27.698423,28.873789,1.175365,...,1.231745,-0.704961,-1.413644,-0.593272,-1.448255,-0.671059,-1.485915,-0.815066,-1.495362,2
1,2015-06-01 00:00:00,4.4,156.1,WP,27.765518,30.120071,2.354553,27.838305,30.184315,2.346010,...,2.334053,-0.357780,0.509027,-0.224746,0.552949,-0.337344,0.612467,-0.421283,0.376866,2
2,2015-07-08 00:00:00,7.5,145.3,WP,27.341799,30.098497,2.756699,27.401316,30.130117,2.728802,...,2.753607,-1.536447,0.475693,-1.376014,0.470187,-1.434295,0.496964,-1.470389,0.347199,2
3,2015-07-13 00:00:00,21.5,126.5,WP,27.341799,29.402730,2.060932,27.401316,29.424799,2.023483,...,1.900936,-1.536447,-0.599326,-1.376014,-0.606850,-1.434295,-0.834877,-1.470389,-0.912458,2
4,2015-07-17 00:00:00,20.1,150.8,WP,27.341799,28.793873,1.452074,27.401316,28.877990,1.476675,...,1.405762,-1.536447,-1.540061,-1.376014,-1.441840,-1.434295,-1.450143,-1.470389,-1.643982,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1471,2049-03-10 00:00:00,-8.7,157.2,SP,28.585459,30.772000,2.186541,28.796192,30.783005,1.986813,...,2.103435,1.923064,1.516312,2.298852,1.467163,2.321450,1.643522,2.345746,1.749180,1
1472,2049-04-10 00:00:00,-5.3,188.7,SP,28.849090,30.586481,1.737391,29.043240,30.740833,1.697594,...,1.524124,2.656411,1.229670,2.949711,1.402765,2.948347,1.409789,2.963455,1.275768,1
1473,2049-11-19 00:00:00,-11.8,171.5,SP,28.196543,29.662104,1.465561,28.308370,29.750036,1.441666,...,1.184885,0.841210,-0.198571,1.013662,-0.110205,1.284371,-0.113894,1.444019,-0.166041,1
1474,2049-12-09 00:00:00,-12.6,151.2,SP,28.211971,29.230350,1.018379,28.327753,29.466457,1.138704,...,0.929159,0.884126,-0.865668,1.064727,-0.543237,1.329460,-0.425632,1.514625,-0.500115,1


In [6]:
rng_key = random.PRNGKey(0)

lower = np.arange(0.05, 0.96, 0.1)
# upper = np.arange(0.65, 0.96, 0.1)

# full = np.concatenate((lower))
full = lower
qs   = np.round(full, 2)

NUM_SAMPLES = 10000
NUM_WARMUP = 1000

predictions = {}

numpyro.enable_x64()

### ADAPT DELTA FOR A HIERARCHICAL MODEL????

for q in qs:
    rng_key, rng_key_ = random.split(rng_key)
    kernel = NUTS(noncentred_hierarchical)
    mcmc = MCMC(kernel, num_chains=4, num_warmup=1000, num_samples=10000, progress_bar=True)
    mcmc.run(
        rng_key_,
        tropical_sst=dset.tropical_sst_scaled.values,
        local_sst=dset.local_sst_scaled.values,
        peak_wind=dset.peak_wind_scaled.values,
        basin_num=dset.basin_numerical.values,
        q=q )
    mcmc_samples = mcmc.get_samples()

    predictive_generator = Predictive(
        noncentred_hierarchical, posterior_samples=mcmc_samples,
        return_sites=["mu", "obs"]
    )

    predictions[q] = {}

    for scenario in scenarios:
      print(scenario)

      predictive_samples = predictive_generator(
        rng_key_,
        tropical_sst = future_tc_dset[f'tropical_sst_scaled_{scenario}'].values,
        local_sst    = future_tc_dset[f'local_sst_scaled_{scenario}'].values,
        basin_num    = future_tc_dset.basin_numerical.values,
        peak_wind    = None,
        q            = q,
      )

      print(predictive_samples)

      predictions[q][scenario] = az.from_numpyro(
        posterior_predictive = predictive_samples
      )['posterior_predictive']['mu']

      print(predictions[q][scenario])

      print('mean wind speed at this scenario:')
      print(predictions[q][scenario].mean())


    print(q)
    mcmc.print_summary()
    print()



  0%|          | 0/11000 [00:00<?, ?it/s]

  0%|          | 0/11000 [00:00<?, ?it/s]

  0%|          | 0/11000 [00:00<?, ?it/s]

  0%|          | 0/11000 [00:00<?, ?it/s]

126


ValueError: ignored

In [None]:
## Resampling
for scenario in scenarios:
  future_tc_dset[f'peak_wind_{scenario}'] = 0

future_tc_dset['quantile'] = 0

for i in range(future_tc_dset.shape[0]):
  print(i)

  cyclone_quantile = np.random.choice(qs, 1)[0]


  future_tc_dset.loc[i, 'quantile'] = cyclone_quantile

  # get the standardised wind prediction
  for scenario in scenarios:
    cyclone_draw = np.random.choice(40000, 1)[0]
    print(scenario)
    mu = predictions[cyclone_quantile][scenario].sel(draw=cyclone_draw, mu_dim_0 = i).values
    max_predicted = mu*s_wind + wind_bar
    print(max_predicted)
    future_tc_dset.loc[i, f'peak_wind_{scenario}'] = max_predicted

future_tc_dset