In [None]:
import os
import glob
import itertools
import pandas as pd
import numpy as np

import sys

from smrt import make_snowpack, sensor_list, make_model
from smrt.core.interface import make_interface

import matplotlib.pyplot as plt
%matplotlib widget

#intake.output_notebook()
%load_ext autoreload
%autoreload 2

from mw_antarctica.antarctic_snowpacks import expand_profile, compute_thickness
from mw_antarctica.antarctic_snowpacks import AntarcticSnowpacks

In [None]:
# Project helper exposing `.profiles`, `.sites`, `.database` and `.prepare_snowpacks`,
# all of which are used by the cells below.
sim = AntarcticSnowpacks()

# Quick sanity check: column-wise minimum of the profiles, grouped by 'site'.
site_minima = sim.profiles.groupby(by='site').min()
site_minima

In [None]:
# AMSR2 frequencies (GHz) analysed in this notebook.
freqs = [10, 19, 37, 89]

# Channel labels "<freq><pola>", frequency-major, H before V: 10H, 10V, 19H, ...
channels = [f'{f}{p}' for f, p in itertools.product(freqs, 'HV')]

In [None]:
# Compare the SSA and density profiles of two sites on a shared vertical axis (z).
fig, axs = plt.subplots(1, 2, sharey=True)

for site in ['charcot-asuma', 'stop1-asuma']:
    profile = sim.profiles.loc[site]
    axs[0].plot(profile.ssa, profile.z, label=site)
    axs[1].plot(profile.density, profile.z, label=site)
    # To inspect the layer-expanded version instead:
    # profile = expand_profile(compute_thickness(profile))

axs[0].set_xlabel('ssa')
axs[1].set_xlabel('density')
axs[0].set_ylabel('z')
axs[0].legend();

In [None]:
def run_model_amsr2(snowpacks, freqs, prune_deep_snowpack=8):
    """Simulate AMSR2 brightness temperatures for ``snowpacks`` with SMRT (IBA + DORT).

    Parameters
    ----------
    snowpacks :
        Snowpack input accepted by ``model.run``. In this notebook it comes from
        ``sim.prepare_snowpacks`` (presumably one snowpack per microstructure
        name -- TODO confirm).
    freqs : iterable of int
        AMSR2 frequencies in GHz. They are sorted before use, so the channel order
        is deterministic even when a ``set`` is passed.
    prune_deep_snowpack : int, default 8
        Forwarded to the DORT solver through ``rtsolver_options``.

    Returns
    -------
    pandas.Series
        Unstacked results whose index levels are joined with ``"_"``
        (e.g. ``"19V_UTS"``). Empty when ``freqs`` is empty.
    """
    channels = [str(freq) for freq in sorted(freqs)]
    if not channels:
        # Nothing to simulate: skip building a model/sensor with no channels.
        return pd.Series(dtype=float)

    model = make_model("iba", "dort",
                       rtsolver_options=dict(prune_deep_snowpack=prune_deep_snowpack))
    sensor = sensor_list.amsr2(channels)
    results = model.run(sensor, snowpacks, parallel_computation=True).to_dataframe()

    results = results.unstack()
    # Collapse the MultiIndex into flat "<channel>_<snowpack>" labels.
    results.index = ["_".join(map(str, ind)) for ind in results.index]
    return results

def run_all_site(freqs, snowpack_params=None, microstructure=None):
    """Simulate every site in ``sim.sites`` and join the results onto ``sim.database``.

    89 GHz is simulated from snowpacks prepared with ``summer_simulation=True``.
    All other frequencies use the default snowpack preparation.

    Parameters
    ----------
    freqs : list of int
        Frequencies (GHz) to simulate.
    snowpack_params : dict, optional
        Passed as ``params`` to ``sim.prepare_snowpacks``. Defaults to ``{}``.
    microstructure : list of str, optional
        Passed to ``sim.prepare_snowpacks``. Defaults to ``['TS']``.

    Returns
    -------
    pandas.DataFrame
        ``sim.database`` joined with one column per simulated channel/microstructure.
    """
    # None sentinels avoid sharing mutable default objects between calls.
    if snowpack_params is None:
        snowpack_params = {}
    if microstructure is None:
        microstructure = ['TS']

    summer_freq = 89
    annual_freqs = sorted(set(freqs) - {summer_freq})
    run_summer = summer_freq in freqs

    model_results = []
    for station in sim.sites:
        parts = []
        # Skip the annual run entirely when only the summer frequency was requested.
        if annual_freqs:
            snowpacks = sim.prepare_snowpacks(station,
                                              params=snowpack_params,
                                              microstructure=microstructure)
            parts.append(run_model_amsr2(snowpacks, annual_freqs))

        if run_summer:
            summer_snowpacks = sim.prepare_snowpacks(station,
                                                     params=snowpack_params,
                                                     microstructure=microstructure,
                                                     summer_simulation=True)
            parts.append(run_model_amsr2(summer_snowpacks, [summer_freq]))

        model_results.append(pd.concat(parts) if parts else pd.Series(dtype=float))

    res = pd.DataFrame(model_results, index=sim.sites)
    return sim.database.join(res)

In [None]:
microstructure = ['UTS', 'USHS', 'USEXP']

# Set lazy=False to force a fresh (slow) SMRT run even if a cached CSV exists.
lazy = True
filename = "results/sites-database-with-results-unified-antarctica.csv"

if lazy and os.path.exists(filename):
    # index_col=0 restores the index that to_csv writes below. Without it the cached
    # frame gets a spurious "Unnamed: 0" column and a RangeIndex, unlike the fresh run.
    database_with_result = pd.read_csv(filename, index_col=0)
else:
    database_with_result = run_all_site(freqs, microstructure=microstructure)
    # Make sure the cache directory exists before writing.
    os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)
    database_with_result.to_csv(filename)

In [None]:
# Site database joined with the simulated columns ("<freq><pola>_<microstructure>").
database_with_result

In [None]:
# Residual = observed TB ("<freq><pola>") minus simulated TB ("<freq><pola>_<m>").
# Each row's 'residual' cell holds a whole per-site Series.
residual = pd.DataFrame(
    [
        (m, freq, pola, f'{freq}{pola}',
         database_with_result[f'{freq}{pola}'] - database_with_result[f'{freq}{pola}_{m}'])
        for m in microstructure for freq in freqs for pola in 'HV'
    ],
    columns=('microstructure', 'freq', 'pola', 'channel', 'residual'),
)


def rmse_func(x):
    """Root-mean-square of ``x``.

    For a pandas Series, ``np.mean`` defers to ``Series.mean``, so NaNs are skipped.
    """
    return np.sqrt(np.mean(x**2))


def bias_func(x):
    """Mean of ``x`` (NaN-skipping when ``x`` is a pandas Series)."""
    return np.mean(x)


residual['RMSE'] = residual['residual'].apply(rmse_func)
residual['bias'] = residual['residual'].apply(bias_func)
# Drop the per-site Series column for a readable table.
display(residual.drop(columns='residual'))

# Pooled statistics. Per (microstructure, freq), the RMSE is the RMS of the per-channel
# RMSEs (assumes each channel has the same number of valid sites -- TODO confirm).
stats_pola = residual.groupby(['microstructure', 'freq', 'pola']).aggregate({'RMSE': rmse_func, 'bias': bias_func})
stats = residual.groupby(['microstructure', 'freq']).aggregate({'RMSE': rmse_func, 'bias': bias_func})

# Combined 19V/37V statistics for UTS.
display(np.sqrt((stats_pola.loc['UTS', 19, 'V']**2 + stats_pola.loc['UTS', 37, 'V']**2) / 2))

# Per-microstructure RMSE pooled over frequencies.
display(stats.groupby('microstructure')['RMSE'].apply(rmse_func))

# Per-microstructure bias: RMS of the per-frequency biases (sign-insensitive magnitude) ...
display(stats.groupby('microstructure')['bias'].apply(rmse_func))
# ... and the signed mean bias.
display(stats.groupby('microstructure')['bias'].apply(bias_func))

(stats_pola.loc['USHS', 19, 'V'], stats_pola.loc['USHS', 37, 'V'])

In [None]:
# Inter-microstructure difference: simulated TB with UTS minus simulated TB with each model m.
# Each row's 'difference' cell holds a whole per-site Series (zero for m == 'UTS').
difference_TS = pd.DataFrame(
    [
        (m, freq, pola, f'{freq}{pola}',
         database_with_result[f'{freq}{pola}_UTS'] - database_with_result[f'{freq}{pola}_{m}'])
        for m in microstructure for freq in freqs for pola in 'HV'
    ],
    columns=('microstructure', 'freq', 'pola', 'channel', 'difference'),
)

difference_TS['RMSD'] = difference_TS['difference'].apply(rmse_func)
difference_TS['mean'] = difference_TS['difference'].apply(bias_func)

# Per-channel table (the per-site Series column is dropped for readability).
display(difference_TS.drop(columns='difference'))

# Per-microstructure RMSD pooled over channels, then the mean difference.
display(difference_TS.groupby('microstructure')['RMSD'].apply(rmse_func))
difference_TS.groupby('microstructure')['mean'].apply(bias_func)

In [None]:
def plot_all(emissivity=False, filename="fig-simulation-unified.pdf"):
    """Plot observed and simulated TBs (or emissivities) per site, one panel per frequency.

    Reads the notebook globals ``freqs``, ``database_with_result`` and ``residual``.
    Every simulated entry of ``simulations`` is annotated with RMSE/bias statistics
    computed from ``residual``.

    Parameters
    ----------
    emissivity : bool, default False
        If True, divide each TB by the site temperature (``temperature`` column
        + 273.15) and plot emissivity. Otherwise plot TB in K.
    filename : str
        Path the figure is saved to.
    """
    n_panels = len(freqs)
    fig, ax_grid = plt.subplots(n_panels, 1, figsize=(6, 7), squeeze=False,
                                gridspec_kw=dict(top=0.98, bottom=0.15, hspace=0.02, wspace=0.1))
    # `canvas.layout` is provided by the ipympl (%matplotlib widget) backend only.
    if hasattr(fig.canvas, 'layout'):
        fig.canvas.layout.height = '800px'

    # label -> (column suffix in database_with_result, line style). Uncomment to add models.
    simulations = {
        'USHS': ('_USHS', dict(marker='x', linestyle='--', alpha=0.9)),
        # 'USEXP': ('_USEXP', dict(marker='x', linestyle='--', alpha=0.9)),
        # 'UTS': ('_UTS', dict(marker='x', linestyle='--', alpha=0.9)),
        # 'GF': ('_GF', dict(marker='s', linestyle='--', alpha=0.7)),
        'Obs.': ('', dict(marker='o', alpha=0.5)),
    }

    color_V = ('#2F71A3', '#4D5CE7')
    color_H = ('#FEC93C', '#FEAB3C')
    icolor = 1  # observations and simulations currently share the same shade

    # Emissivity = TB / physical temperature. In TB mode, dividing by 1 reuses the same code.
    T = database_with_result['temperature'] + 273.15 if emissivity else 1

    stats = residual.groupby(['microstructure', 'freq']).aggregate({'RMSE': rmse_func, 'bias': bias_func})
    stats_pola = residual.groupby(['microstructure', 'freq', 'pola']).aggregate({'RMSE': rmse_func, 'bias': bias_func})

    axs = {freq: ax_grid[i, 0] for i, freq in enumerate(freqs)}

    annot_kwargs = dict(xycoords=('axes fraction', 'axes fraction'), alpha=0.8, fontsize=9.5)
    x0 = 0.40

    for label, (suffix, style) in simulations.items():
        for freq, ax in axs.items():
            ax.plot(database_with_result.shortname,
                    database_with_result['%iV%s' % (freq, suffix)] / T,
                    **style, lw=1, color=color_V[icolor])
            ax.plot(database_with_result.shortname,
                    database_with_result['%iH%s' % (freq, suffix)] / T,
                    **style, lw=1, color=color_H[icolor])

            if label == 'Obs.':
                continue  # statistics only make sense for simulations

            rmse = stats.loc[(label, freq)]['RMSE']
            rmseV = stats_pola.loc[(label, freq, 'V')]['RMSE']
            rmseH = stats_pola.loc[(label, freq, 'H')]['RMSE']
            bias = stats.loc[(label, freq)]['bias']
            biasV = stats_pola.loc[(label, freq, 'V')]['bias']
            biasH = stats_pola.loc[(label, freq, 'H')]['bias']
            ax.annotate(f'RMSE({freq}):   {rmse:.1f}K', xy=(x0, 0.19), **annot_kwargs)
            ax.annotate(f'RMSE({freq}V): {rmseV:.1f}K', xy=(x0, 0.11), **annot_kwargs)
            ax.annotate(f'RMSE({freq}H): {rmseH:.1f}K', xy=(x0, 0.03), **annot_kwargs)
            ax.annotate(f'bias({freq}):   {bias:.1f}K', xy=(x0+0.30, 0.19), **annot_kwargs)
            ax.annotate(f'bias({freq}V): {biasV:.1f}K', xy=(x0+0.30, 0.11), **annot_kwargs)
            ax.annotate(f'bias({freq}H): {biasH:.1f}K', xy=(x0+0.30, 0.03), **annot_kwargs)

    # Per-panel formatting, applied once after all lines are drawn.
    ylabel = "Emissivity" if emissivity else "Brightness\ntemperature (K)"
    for freq, ax in axs.items():
        # 89 GHz is simulated with summer snowpacks, the others with annual ones.
        season = "summer" if freq == 89 else "annual"
        ax.set_title("%i GHz %s" % (freq, season), y=0.0 if emissivity else 0.8)
        ax.set_ylim((0.55, 1) if emissivity else (135, 255))
        ax.set_ylabel(ylabel)
        ax.xaxis.set_tick_params(rotation=90)
        ax.grid(alpha=0.2)

    fig.tight_layout()
    fig.savefig(filename)


plot_all(emissivity=False)

In [None]:
# Observed vs simulated (USHS) TB for every channel; the grey line is the 1:1 reference.
fig, ax = plt.subplots()
ax.plot([120, 260], [120, 260], '0.8')

color = {10: '#03aa87', 19: '#f46817', 37: "#7b7bc6", 89: '#ee188d'}

for freq in freqs:
    for pola in 'HV':
        channel = f"{freq}{pola}"
        marker = "v" if pola == 'H' else "o"  # triangles = H, circles = V
        ax.plot(database_with_result[channel], database_with_result[channel + "_USHS"],
                marker, color=color[freq], alpha=0.5, label=channel)

ax.set_xlabel("Observed brightness temperature (K)")
ax.set_ylabel("Simulated brightness temperature, USHS (K)")
ax.legend(fontsize='small', ncol=2);
