In [40]:
import numpy as np
import pandas as pd
from utils.jtwc_cyclone_utils import get_all_cyclones
# Load the JTWC cyclone records. `one_per_id=True` presumably keeps a single row
# per Storm ID (the table below shows one row per ID) — confirm in utils.
cyclone_df = get_all_cyclones(one_per_id=True)
cyclone_df

Unnamed: 0,timestamp,Storm ID,BASIN,Season,SEASON TC NUMBER,STORMNAME,Latitude (degrees),Longitude (degrees),VMAX (kt),Peak VMAX (kt),ACE,Maximum 24h Intensification
911,1982-03-14 06:00:00,1982-N-1,WP,1982.0,1.0,,7.1,153.0,15.0,60.0,0.73250,15.0
0,1981-10-21 06:00:00,1982-S-2,SI,1982.0,2.0,,-8.0,84.6,40.0,85.0,1.18725,25.0
912,1982-03-18 06:00:00,1982-N-2,WP,1982.0,2.0,,3.8,160.7,25.0,105.0,2.02800,25.0
1,1981-11-03 18:00:00,1982-S-3,AUS,1982.0,3.0,,-8.6,92.9,55.0,80.0,1.08875,15.0
913,1982-03-28 06:00:00,1982-N-3,WP,1982.0,3.0,,3.5,156.6,20.0,75.0,0.70650,15.0
...,...,...,...,...,...,...,...,...,...,...,...,...
1877,2022-10-14 06:00:00,2022-N-24,WP,2022.0,24.0,INVEST,26.4,154.4,25.0,40.0,0.19675,15.0
1878,2022-10-26 00:00:00,2022-N-26,WP,2022.0,26.0,INVEST,12.2,133.8,20.0,75.0,0.67825,25.0
1879,2022-10-28 12:00:00,2022-N-27,WP,2022.0,27.0,INVEST,8.0,140.4,25.0,40.0,0.10875,15.0
1880,2022-11-11 18:00:00,2022-N-28,WP,2022.0,28.0,INVEST,20.2,166.3,20.0,40.0,0.10600,10.0


In [41]:
from utils.globals import SI_MIN, SI_MAX, SP_MIN, SP_MAX
from utils.SST_utils import get_historical_ssts, get_tropical_avg
import seaborn as sns
import matplotlib.pyplot as plt

# Full gridded SST record (0.2-degree grid, per the original note).
historical_ssts = get_historical_ssts()

# Tropical-average SST series, smoothed with a 12-step rolling mean to damp the
# seasonal cycle; dropna removes the leading steps that lack a full window.
# NOTE(review): assumes a monthly time axis, so 12 steps = 1 year — confirm.
# NOTE(review): the window is trailing (no `center=True`), so each value averages
# the current step and the 11 before it — confirm that lag is intended.
# A relative SST (sst minus this tropical mean) was tried and deliberately not used.
tropical_avg_ssts = (
    get_tropical_avg(historical_ssts)
    .rolling(time=12)
    .mean()
    .dropna(dim='time')
)
tropical_avg_ssts

In [42]:
from utils.cyclone_utils import get_datetime
from utils.SST_utils import get_local_smooth_at_time, sel_mm_yyyyy, get_local_mean



def filter_nan_from_values(list_obj):
    """Return the first element of ``list_obj``, or NaN when there is nothing to return.

    Used to collapse the (possibly empty) ``.values`` array produced by a time
    selection into a single scalar. Despite the name, NaN *entries* are not
    skipped: ``[nan, 1.0]`` returns ``nan``.

    Parameters
    ----------
    list_obj : sequence, numpy array, or None
        Values to pick from. ``None`` and empty inputs yield NaN. A 0-d numpy
        array (or numpy scalar) has no first element to index, so its single
        value is returned instead.

    Returns
    -------
    The first element, the 0-d array's value, or ``np.nan``.
    """
    if list_obj is None:
        return np.nan
    # 0-d arrays / numpy scalars report ndim == 0 and reject both len() and [0].
    if getattr(list_obj, "ndim", None) == 0:
        return list_obj[()]
    if len(list_obj) == 0:
        return np.nan
    return list_obj[0]


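## EFFICIENCY CONCERN: the helpers below perform one SST lookup per storm; np.vectorize is a Python-level loop, not true vectorisation (TODO: consider a batched xarray selection if this is slow).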

def curried_local_smooth(lat, lon, time):
    """Evaluate ``get_local_smooth_at_time`` on the global ``historical_ssts['sst']`` field.

    The SST field is fixed here so the function takes only (lat, lon, time) and
    can be wrapped by ``np.vectorize``.
    """
    sst_field = historical_ssts['sst']
    return get_local_smooth_at_time(sst_field, lat, lon, time)

def curried_local_mean_for_month(lat, lon, time):
    """Evaluate ``get_local_mean`` on the global ``historical_ssts['sst']`` field.

    The SST field is fixed here so the function takes only (lat, lon, time) and
    can be wrapped by ``np.vectorize``.
    """
    sst_field = historical_ssts['sst']
    return get_local_mean(sst_field, lat, lon, time)

# np.vectorize is a convenience wrapper (essentially a Python for-loop), so each
# storm still costs one SST lookup. Declaring `otypes=[float]` fixes the output
# dtype up front. numpy then no longer calls the wrapped function an extra time
# on the first storm to infer the dtype, and empty inputs no longer raise ValueError.
local_smooth_vectorized = np.vectorize(curried_local_smooth, otypes=[float])
local_mean_for_month_vectorized = np.vectorize(curried_local_mean_for_month, otypes=[float])


def _tropical_sst_for_timestamp(timestamp):
    """Look up the smoothed tropical-average SST for the month/year of ``timestamp``.

    ``get_datetime`` is called once per storm. The selection is reduced to a
    scalar with ``filter_nan_from_values``, so a month that yields no values in
    ``tropical_avg_ssts`` becomes NaN.
    """
    when = get_datetime(timestamp)
    month_selector = sel_mm_yyyyy(tropical_avg_ssts, when.month, when.year)
    month_values = tropical_avg_ssts['sst'].sel(time=month_selector).values
    return filter_nan_from_values(month_values)


# Only the timestamp is needed, so map over that column instead of building a
# full row Series for every storm with DataFrame.apply(axis=1).
cyclone_df['Tropical SST'] = cyclone_df['timestamp'].map(_tropical_sst_for_timestamp)
cyclone_df['Local SST'] = local_smooth_vectorized(
    cyclone_df['Latitude (degrees)'], cyclone_df['Longitude (degrees)'], cyclone_df['timestamp']
)
cyclone_df['Local Month Mean'] = local_mean_for_month_vectorized(
    cyclone_df['Latitude (degrees)'], cyclone_df['Longitude (degrees)'], cyclone_df['timestamp']
)



In [43]:
# Keep only storms located over warm water. Both the local SST and that
# location's monthly-mean SST must reach 27.5 (presumably degrees C, judging by
# the values shown). Rows whose tropical-average SST is NaN (no matching month
# in the tropical series) are then dropped.
cyclone_df = (
    cyclone_df
    .loc[lambda df: df['Local Month Mean'].ge(27.5) & df['Local SST'].ge(27.5)]
    .dropna(subset=['Tropical SST'])
)
cyclone_df

Unnamed: 0,timestamp,Storm ID,BASIN,Season,SEASON TC NUMBER,STORMNAME,Latitude (degrees),Longitude (degrees),VMAX (kt),Peak VMAX (kt),ACE,Maximum 24h Intensification,Tropical SST,Local SST,Local Month Mean
921,1982-08-04 12:00:00,1982-N-12,WP,1982.0,12.0,,19.8,130.9,20.0,125.0,2.14875,45.0,27.565964,28.316607,29.196020
922,1982-08-08 00:00:00,1982-N-13,WP,1982.0,13.0,,7.7,153.9,20.0,80.0,0.86525,25.0,27.565964,28.282774,29.311720
923,1982-08-17 06:00:00,1982-N-14,WP,1982.0,14.0,,8.2,154.2,20.0,125.0,2.72850,25.0,27.565964,28.308462,29.301773
924,1982-08-20 00:00:00,1982-N-15,WP,1982.0,15.0,,11.3,124.8,20.0,90.0,1.47000,35.0,27.565964,28.343359,29.167967
925,1982-08-27 00:00:00,1982-N-16,WP,1982.0,16.0,,14.6,153.8,30.0,100.0,2.33900,30.0,27.565964,28.092667,29.231501
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1877,2022-10-14 06:00:00,2022-N-24,WP,2022.0,24.0,INVEST,26.4,154.4,25.0,40.0,0.19675,15.0,27.681841,28.137896,27.757618
1878,2022-10-26 00:00:00,2022-N-26,WP,2022.0,26.0,INVEST,12.2,133.8,20.0,75.0,0.67825,25.0,27.681841,29.645842,29.207081
1879,2022-10-28 12:00:00,2022-N-27,WP,2022.0,27.0,INVEST,8.0,140.4,25.0,40.0,0.10875,15.0,27.681841,29.821840,29.312254
1880,2022-11-11 18:00:00,2022-N-28,WP,2022.0,28.0,INVEST,20.2,166.3,20.0,40.0,0.10600,10.0,27.664591,28.148760,27.703903


In [44]:
# Normalise column names to snake_case identifiers, e.g.
# "Latitude (degrees)" -> "latitude_degrees".
# regex=False treats every pattern as a literal string. That keeps "(" and ")"
# literal and silences the two pandas FutureWarnings about the changing
# `regex` default that this line used to emit (one per parenthesis pattern).
renamed_df = cyclone_df.copy()
renamed_df.columns = (
    renamed_df.columns
    .str.replace(' ', '_', regex=False)
    .str.replace('(', '', regex=False)
    .str.replace(')', '', regex=False)
    .str.lower()
)
renamed_df

  renamed_df.columns = renamed_df.columns.str.replace(' ', '_').str.replace('(','').str.replace(')','').str.lower()
  renamed_df.columns = renamed_df.columns.str.replace(' ', '_').str.replace('(','').str.replace(')','').str.lower()


Unnamed: 0,timestamp,storm_id,basin,season,season_tc_number,stormname,latitude_degrees,longitude_degrees,vmax_kt,peak_vmax_kt,ace,maximum_24h_intensification,tropical_sst,local_sst,local_month_mean
921,1982-08-04 12:00:00,1982-N-12,WP,1982.0,12.0,,19.8,130.9,20.0,125.0,2.14875,45.0,27.565964,28.316607,29.196020
922,1982-08-08 00:00:00,1982-N-13,WP,1982.0,13.0,,7.7,153.9,20.0,80.0,0.86525,25.0,27.565964,28.282774,29.311720
923,1982-08-17 06:00:00,1982-N-14,WP,1982.0,14.0,,8.2,154.2,20.0,125.0,2.72850,25.0,27.565964,28.308462,29.301773
924,1982-08-20 00:00:00,1982-N-15,WP,1982.0,15.0,,11.3,124.8,20.0,90.0,1.47000,35.0,27.565964,28.343359,29.167967
925,1982-08-27 00:00:00,1982-N-16,WP,1982.0,16.0,,14.6,153.8,30.0,100.0,2.33900,30.0,27.565964,28.092667,29.231501
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1877,2022-10-14 06:00:00,2022-N-24,WP,2022.0,24.0,INVEST,26.4,154.4,25.0,40.0,0.19675,15.0,27.681841,28.137896,27.757618
1878,2022-10-26 00:00:00,2022-N-26,WP,2022.0,26.0,INVEST,12.2,133.8,20.0,75.0,0.67825,25.0,27.681841,29.645842,29.207081
1879,2022-10-28 12:00:00,2022-N-27,WP,2022.0,27.0,INVEST,8.0,140.4,25.0,40.0,0.10875,15.0,27.681841,29.821840,29.312254
1880,2022-11-11 18:00:00,2022-N-28,WP,2022.0,28.0,INVEST,20.2,166.3,20.0,40.0,0.10600,10.0,27.664591,28.148760,27.703903


In [45]:
from IPython.display import set_matplotlib_formats
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS

# Guard against running with a different NumPyro release than this notebook was
# written for: the version string must begin with "0.13.2".
assert numpyro.__version__[:6] == "0.13.2"



In [46]:
## DISTRIBUTIONAL ASSUMPTIONS

In [47]:
# McElreath's "rethinking" WaffleDivorce data: a semicolon-separated CSV fetched over HTTP.
DATASET_URL = "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/WaffleDivorce.csv"
dset = pd.read_csv(DATASET_URL, delimiter=";")
dset

Unnamed: 0,Location,Loc,Population,MedianAgeMarriage,Marriage,Marriage SE,Divorce,Divorce SE,WaffleHouses,South,Slaves1860,Population1860,PropSlaves1860
0,Alabama,AL,4.78,25.3,20.2,1.27,12.7,0.79,128,1,435080,964201,0.45
1,Alaska,AK,0.71,25.2,26.0,2.93,12.5,2.05,0,0,0,0,0.0
2,Arizona,AZ,6.33,25.8,20.3,0.98,10.8,0.74,18,0,0,0,0.0
3,Arkansas,AR,2.92,24.3,26.4,1.7,13.5,1.22,41,1,111115,435450,0.26
4,California,CA,37.25,26.8,19.1,0.39,8.0,0.24,0,0,0,379994,0.0
5,Colorado,CO,5.03,25.7,23.5,1.24,11.6,0.94,11,0,0,34277,0.0
6,Connecticut,CT,3.57,27.6,17.1,1.06,6.7,0.77,0,0,0,460147,0.0
7,Delaware,DE,0.9,26.6,23.1,2.89,8.9,1.39,3,0,1798,112216,0.016
8,District of Columbia,DC,0.6,29.7,17.7,2.53,6.3,1.89,0,0,0,75080,0.0
9,Florida,FL,18.8,26.4,17.0,0.58,8.5,0.32,133,1,61745,140424,0.44


In [48]:
def standardize(x):
    """Z-score ``x``: subtract its mean, then divide by its standard deviation.

    For a pandas Series, ``.std()`` uses the sample standard deviation (ddof=1).
    """
    return (x - x.mean()) / x.std()


# Scaled copies of the predictors and the outcome used by the regression model.
dset["AgeScaled"] = standardize(dset["MedianAgeMarriage"])
dset["MarriageScaled"] = standardize(dset["Marriage"])
dset["DivorceScaled"] = standardize(dset["Divorce"])

In [49]:
def model(marriage=None, age=None, divorce=None):
    """Gaussian linear regression of ``divorce`` on optional predictors.

    mu = a + bM * marriage + bA * age. Each predictor term, together with its
    coefficient sample site, is included only when that predictor is passed.

    Sample sites, in order:
    - "a"
    - "bM" (if marriage)
    - "bA" (if age)
    - "sigma"
    - the likelihood "obs", observed when ``divorce`` is given.
    """
    intercept = numpyro.sample("a", dist.Normal(0.0, 0.2))
    linear_pred = intercept
    if marriage is not None:
        linear_pred = linear_pred + numpyro.sample("bM", dist.Normal(0.0, 0.5)) * marriage
    if age is not None:
        linear_pred = linear_pred + numpyro.sample("bA", dist.Normal(0.0, 0.5)) * age
    noise_scale = numpyro.sample("sigma", dist.Exponential(1.0))
    numpyro.sample("obs", dist.Normal(linear_pred, noise_scale), obs=divorce)

In [50]:
# One explicit seed for the whole run. It is split so that `rng_key` stays
# available for later stochastic steps while `rng_key_` is consumed by MCMC.
rng_key, rng_key_ = random.split(random.PRNGKey(0))

# Fit the marriage-rate-only model with NUTS: 1000 warmup draws, then 2000 kept draws.
# NOTE(review): the saved output of this cell is
# "TypeError: JAX encountered invalid PRNG key data ... ndim=0". The key here comes
# straight from `random.PRNGKey`/`random.split`, so this presumably stems from the
# installed jax being incompatible with the pinned numpyro 0.13.2 — TODO confirm.
num_samples = 2000
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
mcmc.run(
    rng_key_,
    marriage=dset["MarriageScaled"].values,
    divorce=dset["DivorceScaled"].values,
)
mcmc.print_summary()
samples_1 = mcmc.get_samples()

TypeError: JAX encountered invalid PRNG key data: expected key_data.ndim >= 1; got ndim=0