In [None]:
import sys
import os
import matplotlib.dates as mdates
sys.path.append(os.environ['GOTMWORK_ROOT']+'/tools', )
from gotmanalysis import *
%matplotlib inline

In [None]:
# Turbulence closures to compare. turbmethod_list gives the run-directory
# prefixes, legend_list the names printed on the figure and tm_color the line
# colors; the three lists are parallel (same order, same length).
turbmethod_list = ['KPP-CVMix',
                   'KPP-ROMS',
                   'KPPLT-EFACTOR',
                   'KPPLT-ENTR',
                   'KPPLT-RWHGK',
                   'EPBL',
                   'EPBL-LT',
                   'SMC',
                   'SMCLT',
                   'K-EPSILON-SG',
                   'OSMOSIS']
legend_list = ['KPP-CVMix',
               'KPP-ROMS',
               'KPPLT-VR12',
               'KPPLT-LF17',
               'KPPLT-RWHGK16',
               'ePBL',
               'ePBL-LT',
               'SMC-KC94',
               'SMCLT-H15',
               'k-epsilon',
               'OSMOSIS']
# Vertical-resolution / time-step combinations. Each label encodes the grid
# spacing in m ('VR<dz>m') and the time step in s ('DT<dt>s'); the first entry
# (VR1m_DT60s) is the reference run the other runs are compared against.
dzdt_list = ['VR1m_DT60s',
             'VR1m_DT600s',
             'VR1m_DT1800s',
             'VR1m_DT3600s',
             'VR5m_DT60s',
             'VR5m_DT600s',
             'VR5m_DT1800s',
             'VR5m_DT3600s',
             'VR10m_DT60s',
             'VR10m_DT600s',
             'VR10m_DT1800s',
             'VR10m_DT3600s']
tm_color = ['black',
            'blue',
            'red',
            'orange',
            'purple',
            'skyblue',
            'steelblue',
            'limegreen',
            'green',
            'mediumvioletred',
            'darkgoldenrod']
dir_in = os.environ['GOTMRUN_ROOT']+'/TEST_RES'
dir_out = os.environ['GOTMRUN_ROOT']+'/TEST_RES'
# Panel position (row, column) and panel label of each turbulence method in the
# two-column figure, parallel to turbmethod_list; row 0 / column 0 holds panel (a).
irow_2col = [1, 2, 0, 1, 2, 3, 3, 4, 4, 5, 5]
icol_2col = [0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1]
labels_2col = ['(b)', '(c)', '(g)', '(h)', '(i)', '(d)', '(j)', '(e)', '(k)','(f)','(l)']
# Case to plot. depth is passed to read_timeseries (presumably the integration
# depth in m, negative downward -- confirm in gotmanalysis); ylim is the upper
# y-limit of the NRMSE panels. Alternative cases with their matching settings:
# case = 'OSMOSIS_spring'
# depth = -480
# ylim = 0.05
# case = 'OSMOSIS_winter'
# depth = -200
# ylim = 0.1
case = 'OCSPapa_20130621-20131201'
depth = -100
ylim = 0.2
var = 'PE'

nm = len(turbmethod_list)
nzt = len(dzdt_list)
# Every per-method list needs exactly one entry per turbulence method.
assert all(len(lst) == nm for lst in (legend_list, tm_color, irow_2col,
                                       icol_2col, labels_2col)), \
    'per-method lists must match turbmethod_list in length'
# Grid spacing (m) and time step (s) of every run, parsed from its label,
# e.g. 'VR5m_DT600s' -> dz=5.0, dt=600.0. Filled completely here instead of
# leaving all but the first entry to be filled by a later cell.
dz = np.array([float(label.split('_')[0].replace('VR', '').replace('m', ''))
               for label in dzdt_list])
dt = np.array([float(label.split('_')[1].replace('DT', '').replace('s', ''))
               for label in dzdt_list])

In [None]:
 # input data directory
dataroot = '/'.join([dir_in, case])
# One reference-resolution run (VR1m_DT60s) per turbulence method, keyed by
# the method's directory prefix.
paths = ['{}/{}_VR1m_DT60s/gotm_out.nc'.format(dataroot, turbmethod_list[i])
         for i in range(nm)]
data = GOTMOutputDataSet(paths=paths, keys=turbmethod_list)

In [None]:
# Figures are written to <dir_out>/<case>/, which is created if missing.
figdir = '{}/{}'.format(dir_out, case)
os.makedirs(figdir, exist_ok=True)
figname = '{}/IPE_cmp_dzdt_{}.png'.format(figdir, var)


In [None]:
# Figure: panel (a) shows the PE series of every turbulence method relative to
# KPP-CVMix; panels (b)-(l) show, for each method, the NRMSE of PE for each
# vertical resolution / time step against that method's own VR1m_DT60s run,
# plotted against the grid spacing.
nrow = (nm+2)//2
fig_width = 12
fig_height = 3+2*(nrow-1)
f, axarr = plt.subplots(nrow, 2)
f.set_size_inches(fig_width, fig_height)

# Grid spacing (m) and time step (s) encoded in each run label,
# e.g. 'VR5m_DT600s' -> dz=5.0, dt=600.0. Parsed once here (not once per
# method) so this cell does not depend on another cell filling the arrays.
dz = np.array([float(label.split('_')[0].replace('VR', '').replace('m', ''))
               for label in dzdt_list])
dt = np.array([float(label.split('_')[1].replace('DT', '').replace('s', ''))
               for label in dzdt_list])

# panel a
ax_a = axarr[0, 0]
ts = data.cases['KPP-CVMix'].read_timeseries(var, depth=depth)
ts0 = ts.data
dttime0 = ts.time
# Normalization shared by all NRMSE panels: RMS of the KPP-CVMix series.
dfld0 = np.sqrt((ts0**2).mean())
# Reference series in light gray on a twin y-axis behind the differences.
par1 = ax_a.twinx()
par1.plot(dttime0, ts0, color='lightgray', linewidth=3)
par1.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
par1.set_ylabel(r'$\Delta PE$ (J m$^{-2}$)', fontsize=12)
for tm, color in zip(turbmethod_list, tm_color):
    ts = data.cases[tm].read_timeseries(var, depth=depth)
    ax_a.plot(ts.time, ts.data-ts0, color=color, linewidth=1.5)
ax_a.set_ylabel(r'$\Delta PE -\Delta PE_r$ (J m$^{-2}$)', fontsize=12)
ax_a.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
ax_a.autoscale(enable=True, axis='x', tight=True)
# Draw the difference curves above the twin axis and make this axes'
# background transparent so the gray reference curve stays visible.
ax_a.set_zorder(par1.get_zorder()+1)
ax_a.patch.set_visible(False)
ax_a.text(0.04, 0.92, '(a)', transform=ax_a.transAxes, fontsize=16,
          fontweight='bold', va='top')
ax_a.xaxis.set_major_formatter(mdates.DateFormatter('%y-%m'))

# panels b-l: one panel per turbulence method
base_key = 'VR1m_DT60s'
# One line style per distinct time step, in ascending dt order
# (60 s black, 600 s red, 1800 s blue, 3600 s green for the current dzdt_list).
dt_values = np.unique(dt)
dt_styles = [':kx', ':rx', ':bx', ':gx']
assert len(dt_values) <= len(dt_styles), 'add a line style for each time step'
for i, tm in enumerate(turbmethod_list):
    ax = axarr[irow_2col[i], icol_2col[i]]
    tm_paths = [dataroot+'/'+tm+'_'+key+'/gotm_out.nc' for key in dzdt_list]
    tm_data = GOTMOutputDataSet(paths=tm_paths, keys=dzdt_list)
    fld0 = tm_data.cases[base_key].read_timeseries(var, depth=depth).data
    # NRMSE (a fraction, not a percentage) of each run against this method's
    # base run, normalized by dfld0 so every panel shares one scale. The base
    # run itself keeps an error of 0.
    error_dzdt = np.zeros(nzt)
    for j, key in enumerate(dzdt_list):
        if key == base_key:
            continue
        fld1 = tm_data.cases[key].read_timeseries(var, depth=depth).data
        error_dzdt[j] = np.sqrt(((fld1-fld0)**2).mean())/abs(dfld0)
    # One dotted line per time step, connecting its runs in increasing dz.
    for dt_val, style in zip(dt_values, dt_styles):
        sel = dt == dt_val
        ax.plot(dz[sel], error_dzdt[sel], style, linewidth=2)
    ax.axhline(0, color='black', linewidth=0.75)
    ax.set_xlabel(r'$\Delta z$ (m)', fontsize=12)
    ax.set_ylabel('NRMSE in $PE$', fontsize=12)
    ax.set_xlim(0, dz.max()+1)
    ax.text(0.04, 0.92, labels_2col[i], transform=ax.transAxes,
            fontsize=16, fontweight='bold', va='top')
    ax.text(0.98, 1.15, legend_list[i], color=tm_color[i],
            transform=ax.transAxes, fontsize=13, fontweight='bold',
            va='top', ha='right')
    ax.set_ylim(0, ylim)

# reduce margin
f.tight_layout()

# save figure; it is left open so it also renders inline below this cell
f.savefig(figname, dpi=300)