# Try out Bayesian update to environmental estimate

In [None]:
#%%
import pymc as pm
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import pandas as pd
import scipy.stats as stats
import seaborn as sns
import xarray as xr
import arviz as az
from hierarchical_normal_belk import hierarchical_normal
import itertools
#!! conda install -c conda-forge flox
import flox
from flox.xarray import xarray_reduce # useful in doing multiple coord groupings

#%%

In [None]:
# One seed for every random source so "Restart Kernel -> Run All" reproduces the notebook.
SEED = 1234
rng = np.random.Generator(np.random.PCG64(SEED))
# `rng` only controls code that is handed the generator explicitly.  sample_AR_signal
# (np.random.normal) and the scipy `.rvs(..., random_state=None)` calls further down draw
# from NumPy's legacy global stream instead, so that stream is seeded as well.
np.random.seed(SEED)
#%%
size = 160
# 6.5 degC per km expressed per metre (0.0065 degC/m), despite the "_Km" suffix.
mean_tempC_Km = 6.5/1000
# Number of altitude levels; `alt` below therefore runs 0, 1000, ..., 12000 m.
max_alt_Km = 13
#keep lat and long square for ease of matrixing
horz_offest = 10
lat = np.arange(horz_offest, size)
long = np.arange(0, size - horz_offest)
alt = np.arange(0, max_alt_Km)*1000
#%%

In [None]:
def sample_AR_signal(n_samples, corr, mu=0, sigma=1, rng=None):
    """Sample an AR(1) process with a given stationary mean and standard deviation.

    Parameters
    ----------
    n_samples : int
        Number of consecutive samples to draw.
    corr : float
        Lag-1 auto-correlation, strictly between 0 and 1.
    mu : scalar or array-like, default 0
        Stationary mean.  Anything that broadcasts under ``+`` and ``*``
        (e.g. a DataFrame) is accepted; each sample then has the shape of ``mu``.
    sigma : float, default 1
        Stationary standard deviation of the signal.
    rng : numpy.random.Generator or RandomState, optional
        Source of randomness.  When omitted, NumPy's global ``np.random``
        stream is used, as before.

    Returns
    -------
    numpy.ndarray
        Array whose first axis has length ``n_samples``.

    Raises
    ------
    AssertionError
        If ``corr`` is not strictly between 0 and 1.
    """
    assert 0 < corr < 1, "Auto-correlation must be between 0 and 1"

    noise_source = np.random if rng is None else rng

    # Offset `c` and white-noise std `sigma_e` that give the requested stationary
    # mean and variance.  See https://en.wikipedia.org/wiki/Autoregressive_model
    # under section "Example: An AR(1) process".
    c = mu * (1 - corr)
    sigma_e = np.sqrt((sigma ** 2) * (1 - corr ** 2))

    # Start from the stationary distribution N(mu, sigma).  Starting at
    # `c + noise` (mean mu * (1 - corr)) biases the first samples towards zero
    # and produces a visible start-up transient.
    signal = [mu + noise_source.normal(0, sigma)]
    for _ in range(1, n_samples):
        signal.append(c + corr * signal[-1] + noise_source.normal(0, sigma_e))

    return np.array(signal)

def compute_corr_lag_1(signal):
    """Return the correlation between a signal and itself shifted by one step."""
    earlier, later = signal[:-1], signal[1:]
    corr_matrix = np.corrcoef(earlier, later)
    # Off-diagonal entry of the 2x2 matrix is the cross-correlation.
    return corr_matrix[0][1]
#%%


Baseline thermal along latitude

In [None]:
# Standard deviation of the AR(1) variation in the ground-level field.
base_sigma = 0.05
# One AR(1) value per horizontal grid point (mean 2, lag-1 correlation 0.5),
# stored as a single-column DataFrame.
samp_lat = pd.DataFrame(
    sample_AR_signal(n_samples=size - horz_offest, corr=0.5, mu=2, sigma=base_sigma)
)
# %%

Extend along longitude

In [None]:
# Each latitude value acts as the mean of its own AR(1) walk.  The sampler returns
# one (n_lat, 1) frame per step, so the stacked array is 3-D and the trailing
# singleton axis is dropped before wrapping the result in a DataFrame.
samp = pd.DataFrame(
    sample_AR_signal(size - horz_offest, 0.5, mu=samp_lat, sigma=base_sigma)[:, :, 0]
)
# %%

In [None]:
def plot_temperature_env(samp):
    """Show a 2-D temperature frame as a 3-D surface.

    ``samp`` is a DataFrame; its index values feed the x grid and its column
    labels feed the y grid.  ``np.meshgrid`` is used with its default
    indexing, so the grids have shape ``(len(columns), len(index))`` while the
    heights come straight from ``samp.values`` -- the frame therefore has to
    be square for the shapes to line up.
    """
    grid_x, grid_y = np.meshgrid(samp.index.values, samp.columns.values)

    plt.figure(figsize=(6, 5))
    surface_ax = plt.axes(projection='3d')
    surface_ax.plot_surface(
        grid_x, grid_y, samp.values,
        cmap=cm.coolwarm, linewidth=0, antialiased=False,
    )
    surface_ax.set_xlabel('Latitude')
    surface_ax.set_ylabel('Longitude')
    surface_ax.set_zlabel('Temperature')

    # Keep padding between figure elements.
    plt.tight_layout()
    plt.show()


plot_temperature_env(samp)
# %%

Add a trend on top of the AR variation of the baseline thermal field

In [None]:
# Total rise added linearly across the latitude axis.
lat_inc_max = 5
# Mean and std of the per-step increments accumulated along the other axis.
long_inc_mu, long_inc_std = .01, .1

def add_inc_MA(size, horz_offest, sample_AR_signal, samp_lat, lat_inc_max, long_inc_mu, long_inc_std, rng=None):
    """Build a ground-level field with a trend added to the AR variation.

    Parameters
    ----------
    size, horz_offest : int
        The field is ``(size - horz_offest)`` points along each horizontal axis.
    sample_AR_signal : callable
        Sampler called as ``sample_AR_signal(n, corr=0.5, mu=frame)``; its
        result is indexed as ``[:, :, 0]``.
    samp_lat : pandas.DataFrame
        Baseline values; column ``0`` is used.
    lat_inc_max : float
        The linear ramp added to ``samp_lat[0]`` runs from 0 to this value.
    long_inc_mu, long_inc_std : float
        Mean and std of the normal increments that are cumulatively summed
        along axis 0.
    rng : numpy.random.Generator, RandomState or int, optional
        Passed to scipy as ``random_state`` for the increments.  The default
        ``None`` keeps the previous behaviour (NumPy's global stream).

    Returns
    -------
    pandas.DataFrame
        Square frame with ``size - horz_offest`` rows and columns.
    """
    n_horz = size - horz_offest

    # Linear ramp on top of the baseline, used as the mean of the AR sampler.
    lat_inc = np.linspace(0, lat_inc_max, len(samp_lat))
    sample_lat_inc = pd.DataFrame(samp_lat[0] + lat_inc)

    samp_inc = sample_AR_signal(n_horz, corr=0.5, mu=sample_lat_inc)

    # Random-walk drift: independent normal steps accumulated down axis 0.
    long_inc = stats.norm.rvs(loc=long_inc_mu, scale=long_inc_std,
                              size=(n_horz, n_horz), random_state=rng)
    long_inc = np.cumsum(long_inc, axis=0)

    return pd.DataFrame(samp_inc[:, :, 0] + long_inc)

# Ground-level field with the latitude ramp and the accumulated drift applied.
samp_inc = add_inc_MA(
    size=size,
    horz_offest=horz_offest,
    sample_AR_signal=sample_AR_signal,
    samp_lat=samp_lat,
    lat_inc_max=lat_inc_max,
    long_inc_mu=long_inc_mu,
    long_inc_std=long_inc_std,
)

plot_temperature_env(samp_inc)
# %%

Extend into atmosphere

In [None]:
#allow for inversion by having random lapse rate at diff altitudes
def add_altitude_effects(rng, samp_inc, mean_tempC_Km, max_alt_Km):
    """Extend a ground-level temperature field into a 3-D (horizontal x altitude) cube.

    Parameters
    ----------
    rng : numpy.random.Generator
        Generator used to draw one lapse rate per altitude level.
    samp_inc : array-like, 2-D
        Ground-level temperatures.
    mean_tempC_Km : float
        Mean lapse rate; it is multiplied by altitude in metres below.
    max_alt_Km : int
        Number of altitude levels; level ``k`` sits at ``k * 1000`` metres.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_cols, n_rows, max_alt_Km)`` where
        ``out[j, i, k] = samp_inc[i, j] - lapse[k] * k * 1000``.  The two
        horizontal axes are transposed relative to ``samp_inc``, exactly as
        in the previous implementation.
    """
    # One independent lapse rate per level (std = 10% of the mean).
    tempC_Km = rng.normal(loc=mean_tempC_Km, scale=mean_tempC_Km/10, size=max_alt_Km)

    # Temp at altitude = base temp - lapse rate at that level * altitude.
    cooling = tempC_Km * np.arange(max_alt_Km)*1000

    # The shape is taken from the input itself rather than from notebook globals,
    # so the function also works for non-square fields.
    ground = np.asarray(samp_inc).T
    temperature = ground[:, :, np.newaxis] - cooling
    return temperature

# Temperature cube built from the trended ground field, using the seeded generator.
temp_3D = add_altitude_effects(
    rng=rng,
    samp_inc=samp_inc,
    mean_tempC_Km=mean_tempC_Km,
    max_alt_Km=max_alt_Km,
)
# %%

In [None]:
# Attach labelled lat/long/alt coordinates to the temperature cube.
xr_temp_3D = xr.DataArray(
    temp_3D,
    dims=['lat', 'long', 'alt'],
    coords={'lat': lat, 'long': long, 'alt': alt},
)

# One filled-contour panel per altitude level, four panels per row.
fig = xr_temp_3D.plot.contourf(
    x='lat', y='long', col='alt', col_wrap=4,
    robust=True, vmin=-90, vmax=32, levels=20,
)
plt.suptitle(
    'Temperature at different altitudes',
    fontsize='xx-large',
    weight='extra bold',
)
plt.subplots_adjust(top=.92, right=.8, left=.05, bottom=.05)

# Scalar mean lapse rate attached to the altitude coordinate.
xr_tempC_Km = xr.DataArray(mean_tempC_Km, dims=['alt'], coords={'alt': alt})

Calculate pressure from the baseline temperature field and an assumed lapse rate L.


In [None]:
# %%
#barometric formula
def add_barometric_effects(T = 288.15-273.15, L = 0.0065, H = 0,  P0 = 101_325.00, g0 = 9.80665, M = 0.0289644, R = 8.3144598):
    """Pressure from temperature via the barometric formula.

    Evaluates ``P = P0 * (T / T0) ** (g0 * M / (R * L))`` where ``T`` is the
    temperature in kelvin and ``T0`` is the reference temperature taken from
    ``T`` itself (see below).

    Parameters
    ----------
    T : scalar, array-like or xarray object
        Temperature in degrees Celsius (273.15 is added internally).
    L : scalar or object with a ``.mean()`` method
        Temperature lapse rate (K/m).  If it has ``.mean()`` the mean is
        used, otherwise the value itself.
    H : any
        Altitude (m).  Accepted for interface compatibility; not used in the
        calculation.
    P0 : float
        Reference pressure (Pa), 101 325 Pa by default.
    g0 : float
        Gravitational acceleration (m/s^2).
    M : float
        Molar mass of air (kg/mol).
    R : float
        Universal gas constant (J/(mol*K)).

    Returns
    -------
    Same kind of object as ``T``
        Pressure in the units of ``P0``.

    Notes
    -----
    The reference temperature ``T0`` is chosen as follows:

    * objects exposing ``.sel`` (xarray): ``T.sel(alt=0)``;
    * zero-dimensional input (plain numbers): ``T`` itself;
    * anything else: the first element, ``T[0]``.
    """
    T = T + 273.15  # Celsius -> kelvin

    if hasattr(T, 'sel'):
        # Labelled data: reference is the alt == 0 slice.
        T0 = T.sel(alt=0)
    elif getattr(T, 'ndim', 0) == 0:
        # A bare number has no "first element"; it is its own reference.
        T0 = T
    else:
        T0 = T[0]

    # Plain floats (including the declared default) have no .mean().
    lapse_rate = L.mean() if hasattr(L, 'mean') else L

    return P0 * (T / T0) ** (g0 * M / (R * lapse_rate))


# Pressure field on the same lat/long/alt grid as the temperature cube.
pressure = add_barometric_effects(
    T=xr_temp_3D,
    L=xr_tempC_Km,
    H=xr_temp_3D.alt,
    P0=101_325.00,
    g0=9.80665,
    M=0.0289644,
    R=8.3144598,
)
   

In [None]:
# Display the pressure field as this cell's output for a quick visual check.
pressure

In [None]:
# %%
# Combine both fields into one object with variables "Temperature" and "Pressure".
xr_temp_pres = xr.merge(
    [xr_temp_3D.rename("Temperature"), pressure.rename("Pressure")]
)
# %%
# One filled-contour panel per altitude level, four panels per row.
xr_temp_pres.Pressure.plot.contourf(
    x='lat', y='long', col='alt', col_wrap=4,
    robust=True, levels=20,
)
plt.suptitle(
    'Pressure at different altitudes',
    fontsize='xx-large',
    weight='extra bold',
)
plt.subplots_adjust(top=.92, right=.8, left=.05, bottom=.05)
# %%

# make trajectory and get corresponding temp and pres

In [None]:
# %%
# Synthetic trajectory: X = sin of time and Y = cos of time trace a circle in the
# horizontal plane while Z (altitude) falls from the release altitude.
time = np.arange(0, 100, 0.1)
release_alt = 12_000 #Troposphere goes to about 12Km, thermal is about linear there
step_alt = 1
# NOTE(review): x spans 30 .. size+30 and y spans 0 .. size, which reaches past the
# largest lat/long grid coordinates (size-1 and size-horz_offest-1) -- confirm that
# leaving the grid is intended.
x = (np.sin(time) +1) * size/2 +30
y = (np.cos(time) +1 ) * size/2
# Draw samples from a Weibull (max) distribution and sort them so altitude changes
# monotonically.  The shared seeded generator is used so the trajectory is reproducible.
samples = stats.weibull_max.rvs(1, loc=0, scale=1, size=len(time), random_state=rng)
samples.sort()
steps = samples/(samples.max()-samples.min()) /1.3  #normalize and shrink
steps = steps - steps.min() #shift to 0
# Fraction of the release altitude already lost at each time step.
z = release_alt * (1- steps)

plt.plot(time, z)
plt.xlabel('Time')
plt.ylabel('Altitude')
plt.title('Altitude vs Time')
ax = plt.gca()
# Axis limit follows release_alt instead of repeating the number.
ax.set_ylim(0, release_alt)
plt.show()
#plot 3d trajectory of z by x and y
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(x, y, z)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

ax.set_zlim(0, release_alt)
plt.title('3D Trajectory')
plt.show()
# %%

In [None]:
# Sample the temperature and pressure fields at the trajectory positions: each
# coordinate is wrapped as a DataArray over the shared `time` dimension so the
# interpolation returns one value per time step.
xr_x = xr.DataArray(x, dims=['time'], coords={'time': time})
xr_y = xr.DataArray(y, dims=['time'], coords={'time': time})
xr_z = xr.DataArray(z, dims=['time'], coords={'time': time})

# Interpolate on the grid (alternative considered: method='nearest'), then fill
# gaps along time by linear interpolation / extrapolation.
xr_traj_env = (
    xr_temp_pres
    .interp(lat=xr_x, long=xr_y, alt=xr_z)
    .interpolate_na(dim='time', method='linear', fill_value="extrapolate")
)

In [None]:
# One figure per variable sampled along the trajectory.
for var_name, title in (
    ('Temperature', 'Temperature over time'),
    ('Pressure', 'pressure over time'),
):
    xr_traj_env[var_name].plot()
    plt.suptitle(title, fontsize='xx-large')
    plt.show()


In [None]:
#add wind direction and speed then its velocity relevant to the trajectory
# The components below are: long = speed * cos(direction), lat = speed * sin(direction).
# NOTE(review): with the stated "from north is 0, from east is 90" convention one would
# usually expect the eastward component to be -speed*sin(direction) and the northward
# component to be -speed*cos(direction) -- confirm which convention is intended before
# these values are used.
wind_direction = 180 #degrees from north is 0, from east is 90, from south is 180, from west is 270
wind_speed = 10 #m/s #TODO: add wind speed as a function of altitude and lat long
wind_speed_long = wind_speed * np.cos(np.deg2rad(wind_direction))
wind_speed_lat = wind_speed * np.sin(np.deg2rad(wind_direction))
wind_speed_z = 0
# `display` is the IPython/Jupyter rich-display helper; labels and values alternate.
display('wind_speed_long',wind_speed_long, 
        'wind_speed_lat',wind_speed_lat, 
        'wind_speed_z',wind_speed_z)
# %%

In [None]:
#add wind velocity relevant to the trajectory; dont add wind in
# The `if False` guard keeps this block from ever running, so none of the wind
# variables below are added to xr_traj_env.
if False:
    xr_traj_env['wind_speed_long'] = wind_speed_long
    xr_traj_env['wind_speed_lat'] = wind_speed_lat
    xr_traj_env['wind_speed_z'] = wind_speed_z
    # Magnitude of the three wind components.
    xr_traj_env['wind_speed'] = np.sqrt(wind_speed_long**2 + wind_speed_lat**2 + wind_speed_z**2)
    xr_traj_env['wind_direction'] = wind_direction
    xr_traj_env=xr_traj_env.interpolate_na(dim='time', method='linear', limit=None, use_coordinate=True, fill_value='extrapolate')
    xr_traj_env
# %%

# Using average values per Km; TODO: find more principled way to remove autocorrelation 

In [None]:
# Bin edges used to average the trajectory samples.
# 11 evenly spaced edges between the lowest and highest altitude level -> 10 altitude bins.
bins_alt = np.linspace(alt.min(), alt.max(), 11)
bins_lat =[lat.min(), lat.mean(),lat.max()] # two latitude halves (quadrants together with longitude)
bins_long = [long.min(), long.mean(), long.max()] # two longitude halves
# Time edges every 10 time units, starting at the first time stamp.
bins_time = np.arange(time.min(), time.max()+1, 10)

In [None]:
#grouping lat long and alt
# Disabled via `if False`.  It drops the wind variables, which are themselves only
# created inside another `if False` block in this notebook, so it would need that
# block enabled first.
if False:
  grp_traj_env=xarray_reduce(xr_traj_env.drop_vars(['wind_speed_long', 'wind_speed_lat', 'wind_speed_z', 'wind_speed', 'wind_direction']),
               'alt', 'lat', 'long',
                 func='mean',
                 expected_groups=(
                            pd.IntervalIndex.from_breaks(bins_alt, closed='left'),
                            pd.IntervalIndex.from_breaks(bins_lat, closed='left'),
                            pd.IntervalIndex.from_breaks(bins_long, closed='left')
                        ))
  grp_traj_env

In [None]:
#grouping lat long and alt and time
# Mean of every variable within each (alt, lat, long, time) bin.  The intervals are
# closed on the left, so a sample sitting exactly on the last edge of a dimension is
# not inside any bin.
grp_traj_env=xarray_reduce(xr_traj_env, #.drop_vars(['wind_speed_long', 'wind_speed_lat', 'wind_speed_z', 'wind_speed', 'wind_direction']),
               'alt', 'lat', 'long', 'time',
                 func='mean',
                 expected_groups=(
                            pd.IntervalIndex.from_breaks(bins_alt, closed='left'),
                            pd.IntervalIndex.from_breaks(bins_lat, closed='left'),
                            pd.IntervalIndex.from_breaks(bins_long, closed='left'),
                            pd.IntervalIndex.from_breaks(bins_time, closed='left')
                        ))

# Flatten the four bin dimensions into one stacked dimension and drop the
# combinations that contain no data.
grp_traj_env = grp_traj_env.stack(alt_lat_long_time=(
    'alt_bins', 
    'lat_bins', 
    'long_bins',
    'time_bins')).dropna(dim='alt_lat_long_time')

# Shown as the cell output.
grp_traj_env.coords

# Model temp and pressure varying by altitude, lat, & long

In [None]:

# Model coordinates: a plain integer label for each stacked (alt, lat, long, time) group.
coords = {
    'alt_lat_long_time': np.arange(
        grp_traj_env.sizes['alt_lat_long_time'], dtype=int
    ),
}
coords

In [None]:
with pm.Model(coords=coords) as thermal_pres:
    # Temperatures are in Celsius.

    # Predictors: mid-point of the altitude / latitude / longitude bin of each group.
    Alt_ = pm.ConstantData('Altitude_m', [bin_item.mid
                                          for bin_item in grp_traj_env.alt_bins.values],
                                          dims='alt_lat_long_time' )
    Lat_ = pm.ConstantData('Latitude', [bin_item.mid
                                        for bin_item in grp_traj_env.lat_bins.values],
                                        dims='alt_lat_long_time' )
    Long_ = pm.ConstantData('Longitude', [bin_item.mid
                                          for bin_item in grp_traj_env.long_bins.values],
                                          dims='alt_lat_long_time' )
    # Observed bin averages.
    Temp_ = pm.ConstantData('Temperature_Samples', grp_traj_env.Temperature.values, dims='alt_lat_long_time' )
    Pres_ = pm.ConstantData('Pressure_Samples', grp_traj_env.Pressure.values, dims='alt_lat_long_time' )

    # Priors on the effect on temperature (degC) of altitude (per km), latitude and longitude.
    # `hierarchical_normal` is a project helper; presumably it returns a random variable
    # centred on `mu` -- verify against its definition.
    baseline_temp = pm.Normal('baseline_temp', mu=0, sigma=20) #'L'
    Alt_effect_temp = hierarchical_normal('Alt_effect_temp_Km', mu=-6, sigma=2)
    Lat_effect_temp = hierarchical_normal('Lat_effect_temp', mu=0, sigma=1)
    Long_effect_temp = hierarchical_normal('Long_effect_temp', mu=0, sigma=1)

    #TODO: PULL FROM DATABASE into a pm.Interpolated...maybe not: need relationship between data spreads?
    # Expected temperature: linear in altitude (converted from per-km to per-m), lat and long.
    mu_t = pm.Deterministic('mu_t',
                               baseline_temp + Alt_effect_temp/1000 * Alt_ + Lat_effect_temp * Lat_ + Long_effect_temp * Long_,
                               dims='alt_lat_long_time')
    #mu_t = hierarchical_normal('temperature_mean', mu= mu_mu_t, sigma = 2, dims='alt_lat_long_time')
    #mu_p = hierarchical_normal('pressure_mean',

    # Expected pressure: P0 * (T/T0) ** (g0 * M / (R * L)).
    # The formula takes L as a positive lapse rate (K/m) when temperature falls with
    # height -- the synthetic pressure field is generated with L = +0.0065.
    # Alt_effect_temp is the temperature *change* per km (negative for cooling), so it
    # is negated here; passing it unchanged flips the sign of the exponent and makes
    # pressure rise as temperature drops.
    # NOTE(review): for non-xarray input add_barometric_effects uses T[0] as the
    # reference temperature, i.e. the first stacked group -- confirm that group is the
    # intended reference level.
    mu_p= pm.Deterministic('mu_p',add_barometric_effects(T = Temp_,
                                 L = -Alt_effect_temp/1000, H = Alt_,
                                 P0 = 101_325.00, g0 = 9.80665, M = 0.0289644, R = 8.3144598),
                                 dims='alt_lat_long_time')

    # Priors on the observation noise.
    sigma_t=pm.Exponential('model_error_t', 1/0.025)
    sigma_p=pm.Exponential('model_error_p', 10/1)

    # Likelihoods: normal error around the expected temperature and pressure.
    obs_t = pm.Normal('obs_t', mu=mu_t, sigma=sigma_t,
                    observed = Temp_, dims='alt_lat_long_time')
    obs_p = pm.Normal('obs_p', mu=mu_p, sigma=sigma_p,
                    observed = Pres_, dims='alt_lat_long_time')

pm.model_to_graphviz(thermal_pres)

In [None]:
with thermal_pres:
    # Seeded from the notebook's shared generator so the prior draws are reproducible.
    idata2 = pm.sample_prior_predictive(1000, random_seed=rng)
# Cumulative prior-predictive check against the observed values.
az.plot_ppc(idata2, group='prior', kind='cumulative')

In [None]:
with thermal_pres:
    # Seeded from the notebook's shared generator so the posterior draws are reproducible.
    idata2.extend(
        pm.sample(
            1000,
            tune=1000,
            nuts=dict(max_treedepth=15, target_accept=0.9),
            random_seed=rng,
        )
    )
    az.plot_trace(idata2)
    plt.subplots_adjust (hspace=0.4)#, wspace=0.4) 
    

In [None]:
# Display the inference data (prior and posterior groups added by the cells above).
idata2

In [None]:
def reset_multi_dim(need_it_expanded, have_it, dims):
    """Reset dims of one xarray object to match another and unstack them.

    For every name in ``dims`` that is a dimension of ``have_it``, the
    coordinate ``have_it[dim]`` is assigned onto ``need_it_expanded`` and that
    dimension is then unstacked.

    This stays best-effort: a dimension that cannot be reset is skipped, but
    a ``RuntimeWarning`` now reports why instead of the failure being hidden.
    ``KeyboardInterrupt`` and ``SystemExit`` are no longer swallowed.

    Parameters
    ----------
    need_it_expanded : xarray object
        Object to update; it is modified by the coordinate assignment.
    have_it : xarray object
        Source of the coordinates.
    dims : iterable of str
        Dimension names to process.

    Returns
    -------
    The updated (and, where successful, unstacked) object.
    """
    import warnings

    for dim in dims:
        if dim not in have_it.dims:
            continue
        try:
            need_it_expanded[dim] = have_it[dim]
            need_it_expanded = need_it_expanded.unstack(dim)
        except Exception as err:
            warnings.warn(
                f"reset_multi_dim: could not reset dimension {dim!r}: {err!r}",
                RuntimeWarning,
                stacklevel=2,
            )

    return need_it_expanded

# Apply reset_multi_dim to the groups of idata2, using the stacked index of grp_traj_env.
# NOTE(review): the value returned by `.map(...)` is not assigned.  ArviZ documents
# `InferenceData.map` with `inplace=False` by default, so idata2 itself is presumably
# left unchanged and the display below shows the original object -- confirm, and
# capture the result if the unstacked version is what later cells need.
idata2.map(reset_multi_dim, have_it=grp_traj_env, dims=['alt_lat_long_time'])
display(idata2)

In [None]:

#figures with lat in columns and long in rows
fig, ax = plt.subplots(2, 2, figsize=(10, 10))
ax = ax.flatten()
# Every (latitude bin, longitude bin) pair, latitude varying slowest.
lat_long_bins = itertools.product(
    np.unique(grp_traj_env.lat_bins.values),
    np.unique(grp_traj_env.long_bins.values),
)
# The loop variables are deliberately NOT called `lat` / `long`: those names hold the
# grid coordinate arrays, and overwriting them breaks every earlier cell on re-run.
for i, (lat_bin, long_bin) in enumerate(lat_long_bins):
    print(lat_bin, long_bin)
    ax[i].set_title(f'Lat: {lat_bin.mid} Long: {long_bin.mid}')
    # NOTE(review): `idata2_post` is not defined in the visible part of the notebook
    # and the model defines `mu_t` / `mu_p`, not `temperature_mean` -- confirm where
    # both come from before relying on this cell.
    az.plot_forest(idata2_post.temperature_mean.sel(lat_bins=lat_bin, long_bins=long_bin),
                   var_names=['temperature_mean'],
                   kind='ridgeplot',
                   combined=True, ax= ax[i]
                   )
    ax[i].grid()


In [None]:
# Ridge plots of the posterior for temperature and pressure.
# NOTE(review): the model in this notebook names its deterministic means `mu_t` and
# `mu_p`; `temperature_mean` / `pressure_mean` only appear in commented-out lines of
# the model cell, so these var_names look stale -- confirm.
az.plot_forest(idata2.posterior.unstack(), var_names=['temperature_mean'],kind='ridgeplot', combined=True,combine_dims='time_bins')
az.plot_forest(idata2, var_names=['pressure_mean'],kind='ridgeplot', combined=True)

In [None]:
with thermal_pres:
    # Posterior predictive check, seeded from the notebook's shared generator so the
    # draws are reproducible; the new group is added to idata2.
    pm.sample_posterior_predictive(idata2, extend_inferencedata=True, random_seed=rng)
    az.plot_ppc(idata2, group='posterior', kind='cumulative')


In [None]:
# Compare distributions for the observed variables stored in idata2.
az.plot_dist_comparison(idata2, kind='observed')
