In [None]:
import pymc as pm
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import pandas as pd
import scipy.stats as stats
import seaborn as sns
import xarray as xr
import arviz as az
import arviz.labels as azl


In [None]:
# next time save source code as .py file and import it

In [None]:
# Seeded random generator for the notebook; default_rng(seed) builds a
# Generator backed by a PCG64 bit generator seeded with the same value.
rng = np.random.default_rng(1234)

In [None]:
# Plot configuration: apply the "arviz-doc" style to ArviZ plots.
az.style.use("arviz-doc")
# Configure Bokeh as the ArviZ plotting backend.
az.rcParams["plot.backend"] = "bokeh"
# presumably enables inline Bokeh output for ArviZ in the notebook — TODO confirm
az.output_notebook()
# Get Bokeh/Panel output to work in VS Code notebooks.
# NOTE(review): these imports sit outside the top import cell; consider moving them there.
import panel as pn
pn.extension(comms="vscode")
# The name `hvplot` is never referenced directly; presumably imported for its side
# effect, since `.hvplot.violin` is called on an xarray object later in this notebook.
import hvplot.xarray

In [None]:
# save off pymc model
def save_pymc_trace(trace=None, xr_data_source=None, filename='pymc_meodel', multi_index_to_reset=None):
    """Write a trace to ``<filename>.nc``, optionally bundling its source data.

    Parameters
    ----------
    trace : object providing ``add_groups`` and ``to_netcdf``
        The trace to save (presumably an ArviZ ``InferenceData``).
    xr_data_source : optional
        Data to attach to the trace as a ``data_source`` group before saving.
        Skipped only when it is ``None``.
    filename : str
        Output path without the ``.nc`` extension.
    multi_index_to_reset : optional
        When truthy, ``xr_data_source.reset_index(multi_index_to_reset)`` is
        stored instead of ``xr_data_source`` itself.

    Raises
    ------
    ValueError
        If ``trace`` is ``None``.
    """
    if trace is None:
        raise ValueError('save_pymc_trace requires a trace to save')

    # Compare against None explicitly: truth-testing an array-like object is
    # ambiguous and raises for multi-element arrays.
    if xr_data_source is not None:
        if multi_index_to_reset:
            data_source = xr_data_source.reset_index(multi_index_to_reset)
        else:
            data_source = xr_data_source
        trace.add_groups({'data_source': data_source})

    trace.to_netcdf(f'{filename}.nc')
    # The model object itself is not persisted here; the original note was to
    # rebuild it by rerunning the model with no observed data.
# open pymc model and inference object
def open_pymc_trace(filename='pymc_meodel'):
    """Load and return the trace stored at ``<filename>.nc``.

    ``filename`` is given without the ``.nc`` extension, mirroring
    ``save_pymc_trace``.
    """
    netcdf_path = f'{filename}.nc'
    return az.from_netcdf(netcdf_path)

In [None]:
# Min-max normalize the data to the [0, 1] range (with helpers to undo / re-apply the scaling)

def normalization(xarray=None):
    """Min-max scale ``xarray``.

    Returns a 3-tuple ``(scaled, minimum, spread)`` where
    ``scaled = (xarray - minimum) / spread`` and ``spread = max - min``.
    The minimum and spread are returned so the scaling can be undone with
    ``revert_normalization``.
    """
    lowest = xarray.min()
    span = xarray.max() - lowest
    scaled = (xarray - lowest) / span
    return scaled, lowest, span
def revert_normalization(xarray=None, revert_min=None, revert_spread=None):
    """Undo min-max scaling: return ``xarray * revert_spread + revert_min``."""
    return xarray * revert_spread + revert_min
def convert_normalization(xr_old=None, xr_new=None):
    """Scale ``xr_new``'s lat/long with the min/spread stored on ``xr_old.attrs``.

    Adds ``lat_norm`` and ``long_norm`` to ``xr_new`` **in place** (nothing is
    returned). A field is skipped when it holds no values.
    """
    # (source field, output key, attrs key for the minimum, attrs key for the spread)
    field_table = (
        ('lat', 'lat_norm', 'sdz_lat_min', 'sdz_lat_spread'),
        ('long', 'long_norm', 'sdz_long_min', 'sdz_long_spread'),
    )
    for source_name, target_key, min_key, spread_key in field_table:
        if getattr(xr_new, source_name).values.size > 0:
            shifted = getattr(xr_new, source_name) - xr_old.attrs[min_key]
            xr_new[target_key] = shifted / xr_old.attrs[spread_key]

# NOTE(review): in the visible notebook `xr_traj_env_time` is first assigned in the
# later "Bring in the old trace" cell, so on a fresh kernel this code needs that
# cell to have run first.
# Each row: (source field, normalized variable, attrs key for min, attrs key for spread).
_NORMALIZED_FIELDS = (
    ('lat', 'stz_lat', 'sdz_lat_min', 'sdz_lat_spread'),
    ('long', 'stz_long', 'sdz_long_min', 'sdz_long_spread'),
    ('alt', 'stz_alt', 'sdz_alt_min', 'sdz_alt_spread'),
)
for _source, _stz_key, _min_key, _spread_key in _NORMALIZED_FIELDS:
    # Store the scaled variable and keep the scaling constants on attrs so the
    # transform can be re-applied or reverted later.
    xr_traj_env_time[_stz_key], sdz_min, stz_spread = normalization(getattr(xr_traj_env_time, _source))
    xr_traj_env_time.attrs[_min_key] = sdz_min.values
    xr_traj_env_time.attrs[_spread_key] = stz_spread.values

xr_traj_env_time

# Bring in the old trace

In [None]:
# presumably makes ArviZ read the file fully into memory instead of lazily — TODO confirm
az.rcParams["data.load"] = "eager"
# Reads 'thermal_pres.nc' (open_pymc_trace appends the extension).
idata3 = open_pymc_trace(filename='thermal_pres')
# Work on a copy; idata3 keeps the trace as loaded.
idata2 = idata3.copy()
# The `data_source` group is what save_pymc_trace attaches alongside the trace.
# NOTE(review): the normalization code earlier in the notebook reads
# `xr_traj_env_time` before this assignment.
xr_traj_env_time= idata2.data_source
xr_traj_env_time

# Run the model with no observed data

In [None]:

'''coords={'alt_lat_long_time':
                      np.arange(xr_traj_env_time.sizes['time'], dtype=int)
                      }'''
# Temperature / pressure model. No `observed=` is passed to obs_t / obs_p below
# (those arguments are commented out), so the data containers Temp_ and Pres_ are
# declared but not tied to a likelihood here.
with pm.Model() as thermal_pres:
    # Register the dimension as a mutable coord so it can be resized later
    # (make_temp_prediction calls set_dim on it). One label per `time` entry.
    thermal_pres.add_coord('alt_lat_long_time', np.arange(xr_traj_env_time.sizes['time'], dtype=int), mutable=True)
    # Temperature is in degrees Celsius
    
    # Data containers along 'alt_lat_long_time'; make_temp_prediction replaces
    # their contents through pm.set_data.
    Alt_ = pm.MutableData('Altitude_m', xr_traj_env_time.alt.values,# raw altitude (the normalized stz_alt is not used)
                                          dims='alt_lat_long_time' )
    Lat_ = pm.MutableData('Latitude', xr_traj_env_time.stz_lat.values,
                                        dims='alt_lat_long_time' )
    Long_ = pm.MutableData('Longitude', xr_traj_env_time.stz_long.values,
                                          dims='alt_lat_long_time' )
    Temp_ = pm.MutableData('Temperature_Samples', xr_traj_env_time.Temperature.values, dims='alt_lat_long_time' )
    Pres_ = pm.MutableData('Pressure_Samples', xr_traj_env_time.Pressure.values, dims='alt_lat_long_time' )
    # Priors on the effect on temperature (degC) of altitude, lat and long.
    # Lat_/Long_ hold the min-max normalized stz_lat / stz_long values.
    baseline_temp = pm.Normal('baseline_temp', mu=17, sigma=5)
    # Named per km; divided by 1000 where it multiplies Alt_ (named Altitude_m).
    Alt_effect_temp = pm.Normal('Alt_effect_temp_Km', mu=-6, sigma=0.5)
    Lat_effect_temp = pm.Normal('Lat_effect_temp', mu=0, sigma=10/4)
    Long_effect_temp = pm.Normal('Long_effect_temp', mu=0, sigma=25/4)
    # Interaction term is switched off (fixed at 0); the prior is kept for reference.
    Lat_Long_effect_temp = 0  #pm.Normal('Lat_Long_effect_temp', mu=0, sigma=1)
    # Prior on temperature and pressure
    #TODO: PULL FROM DATABASE into a pm.Interpolated...maybe not: need relationship between data spreads?
    # Expected temperature: linear in altitude, latitude and longitude.
    mu_t = pm.Deterministic('mu_t',
                               baseline_temp + 
                               Alt_effect_temp/1000 * Alt_ + 
                               Lat_effect_temp * Lat_ + 
                               Long_effect_temp * Long_ + 
                               Lat_Long_effect_temp * Lat_ * Long_, 
                               dims='alt_lat_long_time')

    # Reference pressure: a Normal(1, 0.01) factor times 101_325.00
    # (presumably one standard atmosphere in Pa — TODO confirm units).
    P0 = pm.Normal('P0', mu=1, sigma=.01)*101_325.00 # lat/long influence on ground level temp captured in Temp_0
    # Constants for the barometric formula; the values look like standard gravity,
    # molar mass of dry air and the universal gas constant in SI units — TODO confirm.
    g0 = 9.80665
    M = 0.0289644
    R = 8.3144598
    # NOTE: Temp_[0] is not the lowest altitude temperature, but the first temperature in the array
    # Temp_0 is mu_t without the altitude term, i.e. the expected temperature at Alt_ = 0.
    Temp_0 = baseline_temp+ Lat_effect_temp * Lat_ + Long_effect_temp * Long_ # account for lat/long influence on ground level temp
    # Expected pressure from the barometric formula; 273.15 is added to both temperatures
    # (degC -> K) and the altitude effect / 1000 plays the role of L.
    mu_p= pm.Deterministic('mu_p',P0 *  ((mu_t+273.15)/(Temp_0+273.15)) ** (-g0 * M / (R * (Alt_effect_temp/1000))), 
                                 dims='alt_lat_long_time')
    #add_barometric_effects = P0 * (T/T0) ** (-g0 * M / (R * L))
    # Priors on the error scale of each output
    sigma_t=pm.Exponential('model_error_t', 1/5)
    sigma_p=pm.Exponential('model_error_p', 1/500)
    # Outputs: normal error around mu_t / mu_p.
    # The `observed=` arguments are commented out so the variables can be re-sampled.
    obs_t = pm.Normal('obs_t', mu=mu_t, sigma=sigma_t, dims='alt_lat_long_time')# observed = Temp_, dims='alt_lat_long_time')#
    obs_p = pm.Normal('obs_p', mu=mu_p, sigma=sigma_p, dims='alt_lat_long_time')# observed = Pres_, dims='alt_lat_long_time')
    
# Render the model graph as the cell output.
pm.model_to_graphviz(thermal_pres)


# Do new predictions

In [None]:
# NOTE(review): `investigate_lat_long` is first assigned in the next cell, so on a
# fresh kernel (Restart & Run All) this cell raises NameError — run it after that cell.
investigate_lat_long

In [None]:
def investigate_dic(lat = [100, 100, 100, 100], long = [250, 100, 250, 100], alt = [5, 5, 20000, 20000]):
    """Build the dataset of query points handed to ``make_temp_prediction``.

    Parameters
    ----------
    lat, long, alt : sequence
        Latitude, longitude and altitude of each query point. The defaults are
        never mutated here. Pass lists or arrays rather than tuples, since the
        values are handed straight to ``xr.Dataset``.

    Returns
    -------
    xr.Dataset
        Data variables ``lat``, ``long``, ``Altitude_m`` plus zero-filled
        ``Temperature_Samples`` and ``Pressure_Samples`` placeholders.
    """
    # One placeholder per query point: make_temp_prediction sizes the model
    # dimension from len(Temperature_Samples), so a fixed length of 4 would
    # disagree with any other number of points.
    n_points = len(lat)
    investigate_lat_long = xr.Dataset(data_vars={
        'lat': lat,
        'long': long,
        'Altitude_m': alt,
        'Temperature_Samples': [0] * n_points,
        'Pressure_Samples': [0] * n_points,
    })

    return investigate_lat_long
# Module-level dataset of query points built from investigate_dic's default arguments.
investigate_lat_long = investigate_dic()

def make_temp_prediction(investigate_lat_long, trace=idata2, random_seed=None):
    """Sample out-of-sample ``obs_t`` predictions for new lat/long/altitude points.

    Parameters
    ----------
    investigate_lat_long : xr.Dataset
        Query points with ``lat``, ``long``, ``Altitude_m``,
        ``Temperature_Samples`` and ``Pressure_Samples``. It is modified in
        place: ``convert_normalization`` adds ``lat_norm`` / ``long_norm``.
    trace : InferenceData
        Trace to predict from. The default is whatever ``idata2`` referred to
        when this cell was executed; re-run the cell or pass ``trace=``
        explicitly after reloading the trace.
    random_seed : optional
        Forwarded to ``pm.sample_posterior_predictive`` so a prediction can be
        reproduced. ``None`` (the default) keeps the unseeded behaviour.

    Returns
    -------
    The object returned by ``pm.sample_posterior_predictive``, with Latitude /
    Longitude in ``predictions_constant_data`` reverted to their original scale
    and Latitude, Longitude and Altitude_m attached as coordinates of
    ``predictions``.

    Side effects: resizes the ``alt_lat_long_time`` dimension of the
    module-level ``thermal_pres`` model and replaces its data containers.
    """
    # Scale the new lat/long with the min/spread stored on the training data.
    convert_normalization(xr_old=xr_traj_env_time, xr_new=investigate_lat_long)

    new_length = len(investigate_lat_long.Temperature_Samples.values)
    # Label the new points starting right after the largest label in the trace.
    first_new_label = trace.constant_data.alt_lat_long_time.values.max() + 1
    new_coord_values = first_new_label + np.arange(new_length)

    with thermal_pres:
        # Resize the dimension first so the replacement data below fits it.
        thermal_pres.set_dim(name='alt_lat_long_time',
                             new_length=new_length,
                             coord_values=new_coord_values)
        # do-operator: swap the model inputs for the query points
        pm.set_data({'Altitude_m': investigate_lat_long.Altitude_m.values,
                     'Latitude': investigate_lat_long.lat_norm.values,
                     'Longitude': investigate_lat_long.long_norm.values,
                     'Temperature_Samples': investigate_lat_long.Temperature_Samples.values,
                     'Pressure_Samples': investigate_lat_long.Pressure_Samples.values,
                     })
        # sample from this out-of-sample posterior predictive distribution
        counterfactual = pm.sample_posterior_predictive(trace, var_names=["obs_t"], predictions=True,
                                                        progressbar=False, random_seed=random_seed)

    # Report lat/long on their original scale rather than the normalized one.
    constant_data = counterfactual.predictions_constant_data
    constant_data['Longitude'] = revert_normalization(constant_data.Longitude,
                                                      revert_min=xr_traj_env_time.attrs['sdz_long_min'],
                                                      revert_spread=xr_traj_env_time.attrs['sdz_long_spread'])
    constant_data['Latitude'] = revert_normalization(constant_data.Latitude,
                                                     revert_min=xr_traj_env_time.attrs['sdz_lat_min'],
                                                     revert_spread=xr_traj_env_time.attrs['sdz_lat_spread'])
    # Attach the query-point locations as coordinates so plots can be keyed on them.
    counterfactual.predictions = counterfactual.predictions.assign_coords(
        counterfactual.predictions_constant_data[['Latitude', 'Longitude', 'Altitude_m']])

    return counterfactual

In [None]:

# Interactive panel for comparing the temperature distributions at two sets of
# lat, long and altitude values. The distributions are drawn with hvplot.violin.
pn.extension(sizing_mode = 'stretch_width', template='fast')


def _make_location_sliders():
    """Return a fresh (lat, long, alt) trio of IntSlider widgets."""
    lat_slider = pn.widgets.IntSlider(name='lat', start=0, end=200, step=1, value=50)
    long_slider = pn.widgets.IntSlider(name='long', start=0, end=200, step=1, value=50)
    alt_slider = pn.widgets.IntSlider(name='alt', start=0, end=20_000, step=500, value=500)
    return lat_slider, long_slider, alt_slider


# first point
slider_lat, slider_long, slider_alt = _make_location_sliders()
# second point, for comparison
slider_lat2, slider_long2, slider_alt2 = _make_location_sliders()

# placeholder sample values (not driven by any slider)
no_slider = [0]

# plot_prediction takes the coordinates of one point (plus placeholder sample
# values) and plots the predicted temperature distribution at that point.
# (alternative plotting route noted originally: az.plot_violin)
def plot_prediction(slider_lat=50, slider_long=50, slider_alt=500, no_slider=[0], var_names="obs_t"):
    """Return an hvplot violin of the predicted ``obs_t`` at a single point.

    ``no_slider`` fills both ``Temperature_Samples`` and ``Pressure_Samples``
    of the query dataset. ``var_names`` is accepted but not used by this
    function.
    """
    query_point = xr.Dataset(data_vars={
        'lat': [slider_lat],
        'long': [slider_long],
        'Altitude_m': [slider_alt],
        'Temperature_Samples': no_slider,
        'Pressure_Samples': no_slider,
    })
    predictions = make_temp_prediction(query_point).predictions
    # Key the samples on altitude instead of the model's running index.
    by_altitude = predictions.swap_dims({'alt_lat_long_time': 'Altitude_m'})
    plot_title = f'Temperature Distribution at {slider_lat} lat, {slider_long} long, {slider_alt} alt'
    return by_altitude.hvplot.violin(
        y='obs_t',
        ylabel='Temperature (degC)',
        legend=False,
        title=plot_title,
        width=500,
        height=500,
        padding=0.4,
        shared_axes=True,
    )

def plot_prediction2(slider_lat=50, slider_long=50, slider_alt=500,
                     slider_lat2 = slider_lat2, slider_long2= slider_long2, slider_alt2=slider_alt2,
                       no_slider=[0], var_names="obs_t"):
    """Return the two single-point violin plots combined with ``+``.

    The first plot uses ``slider_lat`` / ``slider_long`` / ``slider_alt``, the
    second uses the ``*2`` arguments. Each argument may be a plain number or a
    widget; the defaults of the ``*2`` arguments are the slider widgets
    themselves, so they are resolved to their current ``.value`` at call time.
    ``no_slider`` and ``var_names`` are forwarded to ``plot_prediction``.
    """
    def _current(setting):
        # Widgets carry their setting in `.value`; plain numbers pass through.
        return getattr(setting, 'value', setting)

    first = plot_prediction(slider_lat=_current(slider_lat),
                            slider_long=_current(slider_long),
                            slider_alt=_current(slider_alt),
                            no_slider=no_slider,
                            var_names=var_names)
    second = plot_prediction(slider_lat=_current(slider_lat2),
                             slider_long=_current(slider_long2),
                             slider_alt=_current(slider_alt2),
                             no_slider=no_slider,
                             var_names=var_names)
    return first + second
     
    
    

# Tie plot_prediction2's arguments to the six slider widgets.
display_pn = pn.bind(plot_prediction2,
                     slider_lat=slider_lat,
                     slider_long=slider_long,
                     slider_alt=slider_alt,
                     slider_lat2=slider_lat2,
                     slider_long2=slider_long2,
                     slider_alt2=slider_alt2)

# Layout: title, the two slider columns side by side, then the plots.
# The heading needs a space after '##' to be a valid Markdown (CommonMark) heading.
really_display_pn = pn.Column(
    pn.Row('## Interactive Temperature Comparison'),
    pn.Row(pn.Column(slider_lat, slider_long, slider_alt),
           pn.Column(slider_lat2, slider_long2, slider_alt2)),
    pn.Row(display_pn),
)
# TODO: sync the y axis of both graphs, e.g. by linking 'ylim' between the two plots.
really_display_pn
