In [None]:
import sys
import os
sys.path.append(os.environ['GOTMWORK_ROOT']+'/tools', )
from gotmanalysis import *
%matplotlib inline

In [None]:
# Case configuration. Raw GOTM (stage-1) output is read from s1data_root;
# processed (stage-2) .npz data and figures are written under ./data and ./fig.
timetag = '20080701-20080731'
casename = 'COREII_Global'
run_tag = 'VR1m_DT600s_'+timetag
# Root directory of the raw GOTM runs. Set the GOTMRUN_ROOT environment
# variable to point elsewhere instead of editing this machine-specific default.
gotmrun_root = os.environ.get('GOTMRUN_ROOT', '/Users/qingli/work/gotm/gotmrun')
s1data_root = os.path.join(gotmrun_root, casename, run_tag)
s2data_root = os.path.join('./data', casename, run_tag)
fig_root = os.path.join('./fig', casename, run_tag)
os.makedirs(s2data_root, exist_ok=True)
os.makedirs(fig_root, exist_ok=True)
# Set True to recompute the cached stage-2 data even if the .npz files exist.
update_data = False

In [None]:
# Classify each location's forcing regime from the KPP-CVMix stage-1 output.
# Results are cached as .npz files under s2data_root; set update_data = True
# in the configuration cell to force a recompute.
tmname = 'KPP-CVMix'
basepath = s1data_root+'/'+tmname
s2data_name = s2data_root+'/data_forcing_regime_'+tmname+'.npz'
s2data1_name = s2data_root+'/data_stable_'+tmname+'.npz'
figname = fig_root+'/fig_forcing_regime.png'
# The stage-1 directory is only needed when recomputing, so the cached maps can
# still be loaded on a machine that does not have the raw GOTM output.
loclist = sorted(os.listdir(basepath)) if os.path.isdir(basepath) else []


def _classify_forcing_regime(mfrac_st, mfrac_lt, mfrac_ct, threshold=0.25):
    """Return the forcing-regime code for the mean ST/LT/CT fractions.

    A component counts as contributing when its mean fraction is >= threshold.
    Codes: 1 ST dominant, 2 LT dominant, 3 CT dominant, 4 ST+LT, 5 ST+CT,
    6 LT+CT, 7 ST+LT+CT (all three contributing).

    Returns np.nan when any fraction is not finite (e.g. the mean over an
    empty set of unstable time steps), because no regime can be assigned.
    """
    fracs = (mfrac_st, mfrac_lt, mfrac_ct)
    if not np.all(np.isfinite(fracs)):
        return np.nan
    st_on, lt_on, ct_on = (f >= threshold for f in fracs)
    if not lt_on and not ct_on:
        return 1
    if not st_on and not ct_on:
        return 2
    if not st_on and not lt_on:
        return 3
    if st_on and lt_on and not ct_on:
        return 4
    if st_on and ct_on and not lt_on:
        return 5
    if lt_on and ct_on and not st_on:
        return 6
    return 7


# Recompute when asked to, or when EITHER cached file is missing (loading
# needs both).
if update_data or not (os.path.isfile(s2data_name) and os.path.isfile(s2data1_name)):
    # compute and save data
    if not loclist:
        raise FileNotFoundError('No stage-1 GOTM output found in '+basepath)
    pathlist = [basepath+'/'+x+'/gotm_out_s1.nc' for x in loclist]
    godmobj = GOTMOutputDataMap(pathlist)
    forcing_regime = np.zeros(godmobj.ncase)
    unstable = np.zeros(godmobj.ncase)
    for i in np.arange(godmobj.ncase):
        if np.mod(i, 100) == 0:
            print('{:6.2f} %'.format(i/godmobj.ncase*100.0))
        gotm_out = GOTMOutputData(godmobj._paths[i], init_time_location=False)
        # the first record of each time series is dropped ([1:])
        ts_laturb = gotm_out.read_timeseries('La_Turb', ignore_time=True).data[1:]
        ts_ustar = gotm_out.read_timeseries('u_taus', ignore_time=True).data[1:]
        ts_hbl = gotm_out.read_timeseries('mld_deltaR', ignore_time=True).data[1:]
        ts_obj = gotm_out.read_timeseries('bflux', ignore_time=True)
        ts_bflux = ts_obj.data[1:]
        m_bflux = ts_obj.data_mean
        # stable (0) or unstable (1) on average, from the sign of the mean bflux
        unstable[i] = 0 if m_bflux > 0 else 1
        # forcing regime, evaluated over the unstable (bflux < 0) steps only
        bmask = ts_bflux < 0
        if not np.any(bmask):
            # No unstable steps: the regime is undefined. Without this guard the
            # empty-slice mean is nan, fails every threshold test and would be
            # mislabelled as regime 7. (matplotlib scatter skips nan colours by
            # default, so these points get no regime dot.)
            forcing_regime[i] = np.nan
            continue
        laturb = ts_laturb[bmask]
        comp_ST = 2.0*(1.0-np.exp(-0.5*laturb))
        comp_LT = 0.22/laturb**2
        comp_CT = -0.3*ts_bflux[bmask]*ts_hbl[bmask]/ts_ustar[bmask]**3
        comp_total = comp_ST + comp_LT + comp_CT
        forcing_regime[i] = _classify_forcing_regime(
            np.mean(comp_ST/comp_total),
            np.mean(comp_LT/comp_total),
            np.mean(comp_CT/comp_total),
        )

    gmobj = GOTMMap(data=forcing_regime, lon=godmobj.lon, lat=godmobj.lat, name='forcing_regime')
    gmobj.save(s2data_name)
    gmobj1 = GOTMMap(data=unstable, lon=godmobj.lon, lat=godmobj.lat, name='unstable')
    gmobj1.save(s2data1_name)
else:
    # read data
    gmobj = GOTMMap().load(s2data_name)
    gmobj1 = GOTMMap().load(s2data1_name)
    lon = gmobj.lon
    lat = gmobj.lat
    name = gmobj.name
    units = gmobj.units

In [None]:
def plot_forcing_regime(fregime, unstable, axis=None, add_colorbar=True, **kwargs):
    """Scatter-plot forcing-regime codes on a global map, starring stable points.

    Parameters
    ----------
    fregime : GOTMMap-like
        Object with ``data`` (regime codes 1-7: S, L, C, SL, SC, LC, SLC),
        ``lon`` and ``lat`` arrays. Non-finite codes are skipped by
        matplotlib's scatter (default ``plotnonfinite=False``).
    unstable : GOTMMap-like
        Object whose ``data`` flags each point; points with ``data == 0`` get a
        black star. Assumed to share point ordering with ``fregime`` (its
        lon/lat are not used) -- TODO confirm for other callers.
    axis : matplotlib Axes, optional
        Axes to draw on; defaults to ``plt.gca()``.
    add_colorbar : bool
        Attach a colorbar labelled with the regime abbreviations.
    **kwargs
        Forwarded to both ``Basemap.scatter`` calls. They may override the
        regime layer's ``marker``/``s``; the star overlay's own styling
        (marker, s, c, linewidth, alpha) always takes precedence, so keywords
        such as ``alpha`` no longer collide.

    Returns
    -------
    The scatter collection of the regime layer (usable as a colorbar mappable).
    """
    if axis is None:
        axis = plt.gca()
    # Map centred so longitudes run from 20 to 380 degrees.
    m = Basemap(projection='cyl', llcrnrlat=-72, urcrnrlat=72, llcrnrlon=20, urcrnrlon=380, ax=axis)
    # plot coastlines, draw label meridians and parallels.
    m.drawcoastlines()
    m.drawmapboundary(fill_color='lightgray')
    m.fillcontinents(color='gray',lake_color='lightgray')
    m.drawparallels(np.arange(-90.,91.,30.), labels=[1,0,0,1])
    m.drawmeridians(np.arange(-180.,181.,60.), labels=[1,0,0,1])
    lat = fregime.lat
    lon = fregime.lon
    # Shift longitudes west of 20 deg by +360 so every point lies in the map window.
    lon = np.where(lon < 20., lon+360., lon)
    x, y = m(lon, lat)
    # One colour bin per integer regime code 1..7.
    levels = [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5]
    cb_ticks = [1, 2, 3, 4, 5, 6, 7]
    cb_ticks_labels = ['S', 'L', 'C', 'SL', 'SC', 'LC', 'SLC']
    color_list = ['firebrick','forestgreen','dodgerblue','gold','orchid','turquoise','w']
    cmap = colors.LinearSegmentedColormap.from_list('rgb', color_list, 7)
    norm = colors.BoundaryNorm(boundaries=np.array(levels), ncolors=7)
    regime_kwargs = {'marker': '.', 's': 32}
    regime_kwargs.update(kwargs)
    regime_sc = m.scatter(x, y, c=fregime.data, norm=norm, cmap=cmap, **regime_kwargs)

    # Overlay a black star where the unstable flag is 0 (stable on average).
    smask = unstable.data == 0
    x1, y1 = m(lon[smask], lat[smask])
    star_kwargs = dict(kwargs)
    star_kwargs.update(marker='*', s=6, c='black', linewidth=0.1, alpha=1)
    m.scatter(x1, y1, **star_kwargs)
    # add colorbar
    if add_colorbar:
        cb = m.colorbar(regime_sc, ax=axis, ticks=cb_ticks)
        cb.ax.set_yticklabels(cb_ticks_labels)
    return regime_sc

In [None]:
# Draw the forcing-regime map on an explicit figure/axes pair and save it.
fig, ax = plt.subplots(figsize=(6, 2.2))
plot_forcing_regime(gmobj, gmobj1, axis=ax)
fig.tight_layout()
fig.savefig(figname, dpi=300)