In [None]:
import sys
sys.path.append('../')
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pygmt
import pandas as pd
import pyshtools
from remit.utils.grid import make_dataarray
from remit.utils.region import load_mask_features, feature3mask

from remit.data.models import load_ocean_age_model, load_vis_model, create_vim, load_lcs

import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

import seaborn as sns

from basis_models import load_vim_models

%load_ext autoreload
%autoreload 2

mpl.rcParams['font.family'] = "sans-serif"
mpl.rcParams['font.sans-serif'] = "arial"
# Render math (e.g. the `$nT^2$` axis label) with matplotlib's built-in
# mathtext instead of an external LaTeX install. Setting this to True would
# require a working LaTeX toolchain and would override the Arial font above.
mpl.rcParams['text.usetex'] = False



def plot_polygon_on_map(ax, feature, grid=None):
    """Shade the mask of ``feature`` on ``ax`` and overlay grey land.

    Parameters
    ----------
    ax : cartopy GeoAxes
        Target axes. The mask is drawn with a PlateCarree transform, so any
        cartopy projection can be used for ``ax``.
    feature : row of the mask-feature table
        Polygon handed to ``feature3mask``.
    grid : optional
        Object whose ``to_xarray()`` defines the sampling of the mask.
        Defaults to the notebook-level ``lcs_rad`` (created in the
        data-loading cell), which preserves the original behaviour. Pass it
        explicitly to avoid relying on that global.
    """
    if grid is None:
        # Resolved at call time, so this function may be defined before the
        # cell that creates ``lcs_rad`` has run.
        grid = lcs_rad
    mask_dh = feature3mask(grid.to_xarray(), feature)

    # The mask covers the whole globe: longitude 0..360, latitude -90..90.
    ax.imshow(mask_dh, transform=ccrs.PlateCarree(),
              cmap='Blues',
              extent=(0, 360, -90, 90))

    # zorder=2 keeps the land layer above the mask image.
    ax.add_feature(cfeature.LAND, zorder=2, facecolor='gray')


def place_inset_map(feature, parent_ax, loc='lower right'):
    """Attach an orthographic inset map highlighting ``feature`` to ``parent_ax``.

    The inset is centred on the feature's label point and its contents are
    drawn by ``plot_polygon_on_map``.

    Parameters
    ----------
    feature : row of the mask-feature table
        Must provide ``LabelX``/``LabelY``, used as the projection's central
        longitude/latitude, and be accepted by ``plot_polygon_on_map``.
    parent_ax : matplotlib Axes
        Axes the inset is anchored to, in its axes-fraction coordinates.
    loc : {'lower right', 'upper right'}
        Corner of ``parent_ax`` in which to place the inset.

    Returns
    -------
    GeoAxes
        The newly created inset axes.

    Raises
    ------
    ValueError
        If ``loc`` is not a supported position. Previously an unsupported
        value surfaced as an UnboundLocalError.
    """
    # (x0, y0, width, height) in parent_ax axes-fraction coordinates. The inset
    # intentionally overhangs the right edge of the parent (0.78 + 0.3 > 1).
    anchors = {
        'lower right': (0.78, -0.03, 0.3, 0.5),
        'upper right': (0.78, 0.72, 0.3, 0.5),
    }
    # Validate before touching feature/parent_ax so bad input fails fast.
    if loc not in anchors:
        raise ValueError(
            "loc must be one of {}, got {!r}".format(sorted(anchors), loc))

    projection = ccrs.Orthographic(central_longitude=feature.LabelX,
                                   central_latitude=feature.LabelY)
    axins = inset_axes(parent_ax,
                       bbox_to_anchor=anchors[loc],
                       width="100%", height="100%",
                       bbox_transform=parent_ax.transAxes,
                       axes_class=cartopy.mpl.geoaxes.GeoAxes,
                       axes_kwargs=dict(projection=projection))

    plot_polygon_on_map(axins, feature)
    return axins
    

In [None]:
# Evaluation height above each model's reference radius r0. The figure file
# names later divide it by 1000 to label it in km, so it is presumably in
# metres -- TODO confirm against load_lcs().r0.
altitude = 0.
# Degree range: lmax selects the taper files and the x-axis limit; degrees
# below lmin are blanked out in every spectrum / correlation plot.
lmax, lmin = 142, 16

# Reference LCS model and its radial field at r0 + altitude. lcs_rad also
# provides the grid on which the inset-map masks are sampled.
lcs = load_lcs()
lcs_rad = lcs.expand(a=lcs.r0 + altitude).rad

# Comparison models, loaded at the same altitude.
model_dict = load_vim_models(
    ['DAH981', 'HM05', 'GK07', 'DAH982', 'DAH983', 'VIS'],
    altitude=altitude,
)


In [None]:
def region_spectrum(coeffs, tapers, eigvals, altitude=0, lmin=0):
    """Localise the radial field of a model to a region with Slepian functions.

    Steps:
    1. Upward-continue the model to ``coeffs.r0 + altitude`` and take its
       radial component on a grid.
    2. Re-expand that grid into spherical-harmonic coefficients.
    3. Project them onto the region's Slepian functions.
    4. Rebuild spherical-harmonic coefficients from the first Shannon-number
       functions.

    Parameters
    ----------
    coeffs : pyshtools magnetic-coefficient object
        Must expose ``r0`` and ``expand(a=...)`` returning an object with a
        ``rad`` grid.
    tapers : ndarray
        Spherical-harmonic coefficients of the Slepian functions, one function
        per column, i.e. shape ``((lmax+1)**2, n_functions)``. This is the
        ``galpha`` layout used by ``pyshtools.spectralanalysis.SlepianCoeffs``.
    eigvals : ndarray
        Concentration eigenvalues of the tapers. Their sum, truncated (not
        rounded) to an int, is the number of functions used to rebuild the
        field.
    altitude : float
        Height above ``coeffs.r0``, in the same units as ``r0``.
    lmin : int
        Unused; kept for backward compatibility. Callers blank low degrees on
        the resulting spectra themselves.

    Returns
    -------
    ndarray
        Localised spherical-harmonic coefficients, as returned by
        ``SlepianCoeffsToSH``.
    """
    # Slepian functions are stored column-wise, so the function count is the
    # number of columns. For the square full taper sets this equals shape[0].
    n_functions = tapers.shape[1]
    n_shannon = int(eigvals.sum())

    flm = coeffs.expand(a=coeffs.r0 + altitude).rad.expand().coeffs
    falpha = pyshtools.spectralanalysis.SlepianCoeffs(tapers, flm, n_functions)
    localised_coeffs = pyshtools.spectralanalysis.SlepianCoeffsToSH(
        falpha, tapers, n_shannon)

    return localised_coeffs



# Ocean regions analysed in this notebook. Each name is looked up as a NAME in
# the ocean-mask features and needs a matching taper file.
region_list = ['CentralAtlantic',
               'SouthAtlanticN',
               'SouthAtlanticS',
               'SouthwestIndianOcean',
               'Wharton-BayofBengal',
               'SoutheastIndianOcean',
               'PacificAntarcticRidge',
               'EastPacificRidgeS',
               'EastPacificRidgeN',
               'NEPacific',
               'PacificTriangleS',
               'PacificTriangleN']


def _load_tapers(region, lmax, taper_dir='../../vh0/notebooks/tapers/'):
    """Load the precomputed Slepian tapers and eigenvalues for ``region``.

    Reads ``<taper_dir><region>_tapers_<lmax>.npz`` (arrays ``'tapers'`` and
    ``'eigvals'``). For 'EastPacificRidgeN' only, a missing file falls back to
    ``<region>_tapers_<lmax>_0_0_45.npz``. For any other region the
    FileNotFoundError propagates, so a region can never silently reuse the
    previous region's tapers.

    Returns
    -------
    (tapers, eigvals) : tuple of ndarray
    """
    path = '{:s}{:s}_tapers_{:d}.npz'.format(taper_dir, region, lmax)
    try:
        npz = np.load(path)
    except FileNotFoundError:
        if region != 'EastPacificRidgeN':
            raise
        # TODO(review): this file appears to hold tapers for a rotated copy of
        # the region (rotation angles previously noted as [90, 45, 90]). The
        # model coefficients are NOT rotated before localisation, so check the
        # maps. The '142' in the original file name is assumed to be lmax.
        print('TODO Check rotated maps')
        npz = np.load('{:s}{:s}_tapers_{:d}_0_0_45.npz'.format(taper_dir, region, lmax))
    # Close the archive once both arrays have been read into memory.
    with npz:
        return npz['tapers'], npz['eigvals']


# Localised coefficients per region: {region: {'lcs': ..., <model name>: ...}}.
localised_coeffs_dict = {}

for region in region_list:
    tapers, eigvals = _load_tapers(region, lmax)

    # Localise LCS and every comparison model with the same tapers, so their
    # spectra and correlations are directly comparable.
    localised = {'lcs': region_spectrum(lcs, tapers, eigvals, altitude=altitude)}
    for model_name, model in model_dict.items():
        localised[model_name] = region_spectrum(model['coeffs'], tapers, eigvals,
                                                altitude=altitude)
    localised_coeffs_dict[region] = localised
    
    

In [None]:
# Ocean-region polygons, looked up by NAME below for the area normalisation
# and the inset maps.
ocean_mask_file = './gis/OceanMasks.geojson'
mask_features = load_mask_features(ocean_mask_file)


In [None]:
# Localised power spectra of LCS (thick grey) and each comparison model:
# one panel per region, each with an inset map of the region.
n_cols, n_rows = 3, 4
fig = plt.figure(figsize=(10, 11), constrained_layout=True)
spec = fig.add_gridspec(ncols=n_cols, nrows=n_rows)

normalize_to_area = True
yscale = 'log'  # 'linear' or 'log'


def _localised_spectrum(region, key):
    """Power spectrum of localised coefficients, with degrees < lmin set to NaN."""
    sp = pyshtools.SHCoeffs.from_array(localised_coeffs_dict[region][key]).spectrum()
    sp[:lmin] = np.nan
    return sp


for i, region in enumerate(region_list):
    ax = fig.add_subplot(spec[i])
    feature = mask_features.query('NAME == @region').iloc[0]
    region_area = feature.area if normalize_to_area else 1.

    ax.plot(_localised_spectrum(region, 'lcs') / region_area,
            label='LCS', linewidth=4, color='k', alpha=0.4)
    for model_name in model_dict:
        # Choose the style per model, so VIS alone is dashed black whatever
        # its position in model_dict.
        fmt = 'k--' if model_name == 'VIS' else '-'
        ax.plot(_localised_spectrum(region, model_name) / region_area, fmt,
                label=model_name, linewidth=1.2)

    ax.set_xlim(0, lmax)
    ax.set_yscale(yscale)
    # The y-range and inset corner depend on the evaluation altitude.
    if altitude == 0:
        ax.set_ylim(0.02, 30)
        inset_loc = 'lower right'
    else:
        ax.set_ylim(0.0002, 2)
        inset_loc = 'upper right'
    place_inset_map(feature, ax, loc=inset_loc)
    ax.grid()

    if i == n_cols:  # first panel of the second row
        ax.legend(ncols=2, fontsize=9)
    if i >= len(region_list) - n_cols:  # last panel in its column
        ax.set_xlabel('degree l')
    else:
        ax.set_xticklabels([])
    if i % n_cols == 0:  # left-most column
        ax.set_ylabel(r'power [$nT^2$]')
    else:
        ax.set_yticklabels([])
    ax.set_title(region, loc='left')

# Save the figure explicitly before showing it (inline show closes it).
fig.savefig('./figures/ocean_spectra_{:d}km.png'.format(int(altitude / 1000)),
            dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Degree correlation between LCS and each comparison model, one panel per
# region. The per-(model, region) mean correlation is collected in `results`
# for the summary plot in the next cell.
n_cols, n_rows = 3, 4
fig = plt.figure(figsize=(9, 9), constrained_layout=True)
spec = fig.add_gridspec(ncols=n_cols, nrows=n_rows)

results = []  # rows of [model name, region, mean correlation over degrees >= lmin]

for i, region in enumerate(region_list):
    ax = fig.add_subplot(spec[i])
    lcs_coeffs = localised_coeffs_dict[region]['lcs']

    for model_name in model_dict:
        model_coeffs = localised_coeffs_dict[region][model_name]
        # SHAdmitCorr returns (admittance, admittance error, correlation).
        _, _, corr = pyshtools.spectralanalysis.SHAdmitCorr(lcs_coeffs, model_coeffs)
        corr[:lmin] = np.nan
        # Choose the style per model, so VIS alone is dashed black whatever
        # its position in model_dict.
        fmt = 'k--' if model_name == 'VIS' else '-'
        ax.plot(corr, fmt, label=model_name)
        results.append([model_name, region, np.nanmean(corr)])

    ax.set_ylim(-1, 1)
    ax.set_xlim(0, lmax)
    ax.set_title(region, loc='left')
    ax.grid()

    place_inset_map(mask_features.query('NAME == @region').iloc[0], ax)

    if i >= len(region_list) - n_cols:  # last panel in its column
        ax.set_xlabel('degree l')
    else:
        ax.set_xticklabels([])

    if i % n_cols == 0:  # left-most column
        ax.set_ylabel('correlation')
    else:
        ax.set_yticklabels([])

    if i == n_cols:  # first panel of the second row
        ax.legend()

# Save the figure explicitly before showing it (inline show closes it).
fig.savefig('./figures/ocean_correlation_{:d}km.png'.format(int(altitude / 1000)),
            dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Mean LCS-vs-model correlation per (model, region), as collected by the
# correlation-figure cell. NOTE: the pairwise cell below rebinds `result_df`.
result_df = pd.DataFrame(data=results, columns=['model', 'region', 'mean correlation'])

# sns.set_theme updates matplotlib rcParams globally, so this whitegrid style
# also applies to every figure created after this cell.
sns.set_theme(style="whitegrid")

g = sns.relplot(data=result_df,
                x="model", y="region",
                hue="mean correlation", size="mean correlation",
                palette="magma_r", hue_norm=(0.1, 0.7),
                sizes=(25, 350), size_norm=(.2, 0.7),
                edgecolor=".3", aspect=0.9)
g.set_xticklabels(rotation=45., fontsize=9)
g.set_yticklabels(fontsize=9)

# Pad the categorical x-axis by a quarter category on each side.
x_lo, x_hi = g.ax.get_xlim()
g.ax.set_xlim(x_lo - 0.25, x_hi + 0.25)

g.savefig('./figures/model_correlation_comparison_{:d}km.png'.format(int(altitude / 1000)),
          bbox_inches='tight')


In [None]:

# Mean degree correlation (degrees >= lmin) between every ordered pair of
# comparison models, including each model with itself, for every region.
# Records go into their own list so the LCS-vs-model `results` are kept.
# NOTE: this rebinds `result_df` with model1/model2 columns, which the
# matrix plot below relies on.
pair_records = []

for region in region_list:
    localised = localised_coeffs_dict[region]
    for model_name1 in model_dict:
        for model_name2 in model_dict:
            _, _, corr = pyshtools.spectralanalysis.SHAdmitCorr(
                localised[model_name1], localised[model_name2])
            corr[:lmin] = np.nan
            pair_records.append([model_name1, model_name2, region, np.nanmean(corr)])

result_df = pd.DataFrame(data=pair_records,
                         columns=['model1', 'model2', 'region', 'mean correlation'])

result_df


In [None]:
# `result_df` must be the pairwise table. Re-running the LCS-summary cell
# rebinds it to a table without model1/model2 columns.
assert {'model1', 'model2', 'region', 'mean correlation'} <= set(result_df.columns), \
    "run the pairwise model-correlation cell before this one"

fig = plt.figure(figsize=(11, 12), constrained_layout=True)
spec = fig.add_gridspec(ncols=3, nrows=4, hspace=0.1, wspace=0.1)

for i, region in enumerate(region_list):
    ax = fig.add_subplot(spec[i])
    subset = result_df.query('`region` == @region')

    # Colour and size both encode the mean correlation between the two models.
    sns.scatterplot(data=subset,
                    x="model1", y="model2",
                    hue="mean correlation",
                    size="mean correlation",
                    palette="magma_r", hue_norm=(0.8, 1.), edgecolor=".3",
                    sizes=(50, 450), size_norm=(.6, 1.), ax=ax, legend=False)

    ax.set_xlabel(None)
    ax.set_ylabel(None)
    ax.set_title(region)
    ax.set_xticks(ax.get_xticks())
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45., fontsize=9)

    # Pad both categorical axes by a quarter category. The y-limits are
    # shifted in the opposite sense to x, presumably because the categorical
    # y-axis is inverted -- TODO confirm.
    x_lo, x_hi = ax.get_xlim()
    ax.set_xlim(x_lo - 0.25, x_hi + 0.25)
    y_lo, y_hi = ax.get_ylim()
    ax.set_ylim(y_lo + 0.25, y_hi - 0.25)

# Save the figure explicitly before showing it (inline show closes it).
fig.savefig('./figures/model2model_correlation_matrix.png', dpi=600)
plt.show()
