In [None]:
import os
import glob
import pandas as pd
import numpy as np

from smrt import make_snowpack, sensor_list, make_model
from smrt.core.interface import make_interface

import matplotlib.pyplot as plt
%matplotlib widget

#intake.output_notebook()
%load_ext autoreload
%autoreload 2

from mw_antarctica.antarctic_snowpacks import expand_profile, compute_thickness
from mw_antarctica.antarctic_snowpacks import AntarcticSnowpacks

In [None]:
# Helper that lists the sites (`sim.sites`), builds their snowpacks (`sim.prepare_snowpacks`)
# and holds the observation table (`sim.database`) used throughout this notebook.
sim = AntarcticSnowpacks()
# Display the observation table. It is later joined with a frame indexed by `sim.sites`
# and read through `'{freq}V'` columns; presumably one row per site with observed V-pol Tb — TODO confirm.
sim.database

In [None]:
# AMSR2 frequencies (GHz) that are simulated and scored against the observed `'{freq}V'` columns.
freqs = [19, 37]

In [None]:
def run_model_amsr2(snowpacks, freqs):
    """Simulate AMSR2 brightness temperatures for ``snowpacks`` with SMRT.

    The model is IBA with the DORT solver (``prune_deep_snowpack=8``), and the
    snowpacks are run with ``parallel_computation=True``.

    Parameters
    ----------
    snowpacks :
        Whatever ``model.run`` accepts; here it is what ``sim.prepare_snowpacks`` returns.
    freqs : iterable of int
        AMSR2 frequencies; each is converted to its channel string (e.g. ``19`` -> ``"19"``).

    Returns
    -------
    The result of ``to_dataframe(channel_axis="column").unstack()``, relabelled so that
    each MultiIndex entry ``(column, row)`` becomes ``"column_row"`` — presumably a
    Series labelled like ``"19V_<microstructure>"``; TODO confirm.
    """
    smrt_model = make_model("iba", "dort", rtsolver_options={"prune_deep_snowpack": 8})

    amsr2 = sensor_list.amsr2([str(freq) for freq in freqs])
    tb = smrt_model.run(amsr2, snowpacks, parallel_computation=True)

    flat = tb.to_dataframe(channel_axis="column").unstack()
    # Flatten the two-level labels into single "a_b" strings.
    flat.index = list(map("_".join, flat.index))
    return flat

def run_all_site(freqs, snowpack_params=None, microstructure=None):
    """Simulate every site of ``sim`` and join the simulated values to ``sim.database``.

    Frequencies other than 89 GHz are run on the standard snowpacks. 89 GHz, when
    requested, is run separately on snowpacks prepared with ``summer_simulation=True``.
    Both results are stacked into a single row per site.

    Parameters
    ----------
    freqs : iterable of int
        AMSR2 frequencies to simulate.
    snowpack_params : dict, optional
        Forwarded as ``params`` to ``sim.prepare_snowpacks`` (default: ``{}``).
    microstructure : optional
        Forwarded to ``sim.prepare_snowpacks`` (default: ``['TS']``).

    Returns
    -------
    pandas.DataFrame
        ``sim.database`` left-joined with one row of simulated values per site
        (rows indexed by ``sim.sites``).
    """
    # Fresh defaults on every call; shared mutable defaults would carry any
    # downstream mutation over into later calls.
    if snowpack_params is None:
        snowpack_params = {}
    if microstructure is None:
        microstructure = ['TS']

    # Materialize once so a generator argument is not exhausted by the set() below.
    freqs = list(freqs)
    base_freqs = sorted(set(freqs) - {89})

    model_results = []
    for station in sim.sites:
        parts = []
        if base_freqs:  # skip the standard run entirely when only 89 GHz is requested
            snowpacks = sim.prepare_snowpacks(station,
                                              params=snowpack_params,
                                              microstructure=microstructure)
            parts.append(run_model_amsr2(snowpacks, base_freqs))
        if 89 in freqs:
            summer_snowpacks = sim.prepare_snowpacks(station,
                                                     params=snowpack_params,
                                                     microstructure=microstructure,
                                                     summer_simulation=True)
            parts.append(run_model_amsr2(summer_snowpacks, [89]))

        # Stack along the label axis (axis=0) so each site stays ONE row of values;
        # axis=1 would turn the two results into a 2-column DataFrame.
        model_results.append(pd.concat(parts) if parts else pd.Series(dtype=float))

    res = pd.DataFrame(model_results, index=sim.sites)
    return sim.database.join(res)

In [None]:

def compute_error(microstructure, polydispersity, frequencies=None):
    """Score simulated against observed V-pol Tb for each polydispersity value.

    For every ``p`` in ``polydispersity``, all sites are simulated with
    ``snowpack_params={'polydispersity': p}``. Then ``observed - simulated`` is
    compared per frequency, using the ``'{freq}V'`` (observed) and
    ``'{freq}V_{microstructure}'`` (simulated) columns.

    Parameters
    ----------
    microstructure : str
        Microstructure model name. It is forwarded to ``run_all_site`` and also used
        as the suffix of the simulated column names (presumably the label the
        simulation produces for that model — TODO confirm).
    polydispersity : iterable of float
        Polydispersity values to evaluate.
    frequencies : list of int, optional
        Frequencies to score; defaults to the notebook-level ``freqs``.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``polydispersity`` with columns ``rmse_f``, ``abse_f``,
        ``bias_f`` and ``count_f`` for each frequency ``f``. Bias is observed
        minus simulated.
    """
    if frequencies is None:
        frequencies = freqs

    rows = []
    for p in polydispersity:
        database_with_result = run_all_site(frequencies, microstructure=microstructure,
                                            snowpack_params={'polydispersity': p})

        row = {'polydispersity': p}
        for freq in frequencies:
            simulated = database_with_result[f'{freq}V_{microstructure}']
            diff = database_with_result[f'{freq}V'] - simulated  # NaN where either side is missing
            row[f'rmse_{freq}'] = np.sqrt((diff ** 2).mean())
            row[f'abse_{freq}'] = diff.abs().mean()
            row[f'bias_{freq}'] = diff.mean()
            # NOTE(review): counts non-NaN *simulated* values only, whereas the metrics above
            # use rows where both observation and simulation exist.
            row[f'count_{freq}'] = simulated.count()
        rows.append(row)

    # Explicit columns keep the output well-formed (and set_index valid) for an empty sweep.
    columns = ['polydispersity'] + [f'{kind}_{freq}'
                                    for freq in frequencies
                                    for kind in ('rmse', 'abse', 'bias', 'count')]
    return pd.DataFrame(rows, columns=columns).set_index('polydispersity')

# Polydispersity grid of the sweep (np.arange excludes the stop value).
polydispersity = np.arange(0.4, 0.95, 0.025)
# polydispersity = np.arange(0.6, 0.65, 0.01) # to refine

# Caching policy for the expensive sweep (one CSV per microstructure model):
#   lazzy=True,  update=False -> reuse the CSV when it exists, otherwise compute and save it;
#   lazzy=False, update=False -> always recompute and overwrite the CSV;
#   update=True               -> compute the current grid and merge it into the existing CSV
#                                (freshly computed rows replace cached rows with the same K).
lazzy = True
update = False

errors = {}
for m in ['UTS', 'USEXP', 'USHS']:

    filename = f"results/simulations-errors-antarctica-{m.lower()}.csv"
    file_exists = os.path.exists(filename)

    cached = None
    if (lazzy or update) and file_exists:
        cached = pd.read_csv(filename).set_index('polydispersity')
        # DataFrame.drop_duplicates ignores the index; deduplicate on K itself.
        cached = cached[~cached.index.duplicated(keep='first')].copy()
        errors[m] = cached

    # Compute whenever the cache is not to be used as-is (previously lazzy=False with an
    # existing file left errors[m] unset and crashed below).
    if update or not lazzy or not file_exists:
        err = compute_error(m, polydispersity)
        if cached is not None:
            # New results first, so keep='first' prefers them over cached rows with the same K.
            err = pd.concat((err, cached))
            err = err[~err.index.duplicated(keep='first')]
        errors[m] = err.sort_index()

        os.makedirs(os.path.dirname(filename), exist_ok=True)
        errors[m].to_csv(filename)

    # Combined score over all simulated frequencies: quadratic mean of per-frequency RMSEs,
    # arithmetic mean of per-frequency biases.
    errors[m]['rmse'] = np.sqrt(sum(errors[m][f'rmse_{freq}'] ** 2 for freq in freqs) / len(freqs))
    errors[m]['bias'] = sum(errors[m][f'bias_{freq}'] for freq in freqs) / len(freqs)

    print(m, errors[m]['rmse'].idxmin())

In [None]:
# Optimal polydispersity K (lowest combined RMSE) for each microstructure model

In [None]:
# For each microstructure model: the K with the lowest combined RMSE, and that RMSE.
for model_name, table in errors.items():
    k_opt = table['rmse'].idxmin()
    print(model_name, k_opt, table.loc[k_opt, 'rmse'])

In [None]:
# Combined RMSE and |bias| versus polydispersity K, one panel per microstructure model.
f, axs = plt.subplots(1, 3, figsize=(6, 3), sharey=True)

# (error-table key, panel title) in left-to-right panel order.
panels = [
    ('USEXP', 'Scaled exponential'),
    ('USHS', 'Sticky Hard Sphere'),
    ('UTS', 'Teubner-Strey'),
]

for ax, (micro, title) in zip(axs, panels):
    table = errors[micro]
    # Plot each model against its OWN K grid: cached tables may have been computed
    # or refined on different grids, so borrowing one model's index is wrong.
    ax.plot(table.index, table['rmse'], '-', label='RMSE')
    ax.plot(table.index, np.abs(table['bias']), '-', label='|bias|')
    ax.set_title(title)
    ax.set_xlabel('Polydispersity $K$')
    ax.grid(alpha=0.2)

# With sharey=True only the leftmost panel shows y tick labels, so the y label and
# the legend go there; the y-limits apply to all three panels.
axs[0].set_ylabel('Error (K) at 19 and 37 GHz ')
axs[0].set_ylim((0, 25))
axs[0].legend()

f.tight_layout()
f.savefig("fig-global-optimisation-coefs.pdf")