In [1]:
import os, sys
from glob import glob
import numpy as np
import dask
import xarray as xr
from scipy.signal import welch
from matplotlib import pyplot as plt
from cmocean import cm
%matplotlib inline

from mitequinox.utils import *
from mitequinox.sigp import *
from mitequinox.plot import *

In [2]:
from dask_jobqueue import PBSCluster

# PBS-backed dask cluster for the heavy processing: 6 worker processes sharing
# 6 cores per job, 2-hour walltime, then request 2*10 = 20 workers.
cluster = PBSCluster(processes=6, cores=6, walltime='02:00:00')
w = cluster.scale(20)

In [3]:
# Connect a dask client to the PBS cluster above; subsequent dask work runs
# there and its status (scheduler, dashboard, workers) can be inspected.
from dask.distributed import Client
client = Client(address=cluster)

In [4]:
client  # rich repr: scheduler address, dashboard link, worker/core/memory totals

0,1
Client  Scheduler: tcp://10.135.36.179:45367  Dashboard: /user/yux/proxy/8787/status,Cluster  Workers: 24  Cores: 24  Memory: 400.08 GB


___________
# Grid, wet-point mask, Coriolis parameter, and sea-ice mask

In [5]:
# Grid fields (alternative NetCDF loader kept for reference):
# grd = load_grd(ftype='nc').reset_coords()
grd = load_grd().reset_coords()

# Wet-point mask: True only where hFacW and hFacS are both exactly 1
# (presumably fully open u and v faces), moved onto the (i, j) cell centres.
wet_u = grd.hFacW.rename({'i_g': 'i'}) == 1
wet_v = grd.hFacS.rename({'j_g': 'j'}) == 1
mask = (wet_u & wet_v).rename('mask').reset_coords(drop=True)
grd_rspec = xr.merge([mask, grd.XC, grd.YC, grd.Depth])

# Coriolis parameter f = 2 * omega * sin(latitude)
lat = grd_rspec['YC']
omega = 7.3/100000  # approximate Earth rotation rate (rad/s)
f_ij = 2 * omega * np.sin(np.deg2rad(lat))

In [6]:
# Sea-ice mask field (AREA), kept only every `dij`-th point in i and j so it
# matches the subsampled KE fields below.
dij = 4
ds_ice = xr.open_zarr(work_data_dir + 'xy/sea_ice_mask.zarr')
ice = ds_ice['AREA'].isel({dim: slice(0, None, dij) for dim in ('i', 'j')})

_________________

## Part 1. Global KE comparison: ageostrophic, geostrophic, and total velocity (u, v)

Maps of band-integrated KE as a function of longitude and latitude.

In [7]:
# Band-integrated KE for the total, geostrophic and ageostrophic velocities.
# Each face lives in its own zarr store; faces are stacked along 'face'.
face_all = list(range(13))

D = [xr.open_zarr(work_data_dir + 'xy/total_uv/E_band_integral_f%02d.zarr' % face)
     for face in face_all]
ds_total = xr.concat(D, dim='face')
E_total, E_total_low = ds_total.E_all, ds_total.E_low

D = [xr.open_zarr(work_data_dir + 'xy/geo_uv/E_band_integral_f%02d.zarr' % face)
     for face in face_all]
ds_geo = xr.concat(D, dim='face')
E_geo = ds_geo.E_all

D = [xr.open_zarr(work_data_dir + 'xy/Ageo_uv/E_band_integral_f%02d.zarr' % face)
     for face in face_all]
ds_ageo = xr.concat(D, dim='face')
E_ageo, E_ageo_low, E_ageo_high = ds_ageo.E_all, ds_ageo.E_low, ds_ageo.E_high

In [8]:
# Same band-integrated KE products, from the "_all" stores.
# One zarr store per face, stacked along 'face'.
face_all = list(range(13))

D = [xr.open_zarr(work_data_dir + 'xy/total_uv_all/E_band_integral_f%02d.zarr' % face)
     for face in face_all]
ds_total_all = xr.concat(D, dim='face')
E_total_all, E_total_all_low = ds_total_all.E_all, ds_total_all.E_low

D = [xr.open_zarr(work_data_dir + 'xy/geo_uv_all/E_band_integral_f%02d.zarr' % face)
     for face in face_all]
ds_geo_all = xr.concat(D, dim='face')
E_geo_all, E_geo_low_all = ds_geo_all.E_all, ds_geo_all.E_low

D = [xr.open_zarr(work_data_dir + 'xy/Ageo_uv_all/E_band_integral_f%02d.zarr' % face)
     for face in face_all]
ds_ageo_all = xr.concat(D, dim='face')
E_ageo_all, E_ageo_all_low, E_ageo_all_high = (
    ds_ageo_all.E_all, ds_ageo_all.E_low, ds_ageo_all.E_high
)

In [9]:
# Bathymetry on the same every-`dij` subsampled grid as the ice mask.
depth = grd_rspec['Depth'].isel({dim: slice(0, None, dij) for dim in ('i', 'j')})
depth

Unnamed: 0,Array,Chunk
Bytes,60.65 MB,4.67 MB
Shape,"(13, 1080, 1080)","(1, 1080, 1080)"
Count,27 Tasks,13 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 60.65 MB 4.67 MB Shape (13, 1080, 1080) (1, 1080, 1080) Count 27 Tasks 13 Chunks Type float32 numpy.ndarray",1080  1080  13,

Unnamed: 0,Array,Chunk
Bytes,60.65 MB,4.67 MB
Shape,"(13, 1080, 1080)","(1, 1080, 1080)"
Count,27 Tasks,13 Chunks
Type,float32,numpy.ndarray


In [14]:
font_size = 24

def plot_pretty(v, colorbar=False, title=None, label=None, vmin=None, vmax=None, savefig=None,
                offline=False, figsize=(20,12), cmmap='thermal', coast_resolution='110m',
                ignore_face=None):
    """Map a log10-valued field over all faces on a PlateCarree projection (lat -60..60).

    The colorbar is labelled in powers of ten, with one tick per integer decade
    inside [vmin, vmax]. This relies on the module-level ``font_size`` and on
    pyplot global state (``plt.figure`` / ``plt.axes`` / ``plt.colorbar``).

    Parameters
    ----------
    v : xarray.DataArray
        Field with a ``face`` dimension and ``XC``/``YC`` coordinates, holding
        log10 values.
    colorbar : bool
        Forwarded as ``add_colorbar`` to every per-face ``pcolormesh``.
    title, label : str, optional
        Axes title and colorbar label.
    vmin, vmax : float, optional
        Colour limits. They default to the min/max of ``v``.
    savefig : str, optional
        Output path. If given, the figure is saved at 180 dpi and then closed.
    offline : bool
        Switch to the 'agg' backend and skip ``plt.show()``.
    figsize : tuple
        Figure size passed to ``plt.figure``.
    cmmap : str
        Attribute name of a ``cmocean.cm`` colormap.
    coast_resolution : str or None
        Passed to ``ax.coastlines``. None disables coastlines.
    ignore_face : iterable of int, optional
        Face values to skip.

    Raises
    ------
    ValueError
        If every face is ignored, so nothing would be plotted.
    """
    # Avoid a shared mutable default.
    if ignore_face is None:
        ignore_face = ()
    if vmin is None:
        vmin = v.min()
    if vmax is None:
        vmax = v.max()
    # (A per-call threading.Lock was removed: a lock created inside the call
    # is never contended, so it provided no mutual exclusion.)
    if offline:
        plt.switch_backend('agg')

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection=ccrs.PlateCarree(central_longitude=0))
    cmap = getattr(cm, cmmap)
    pcolor_kw = dict(transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax,
                     x='XC', y='YC', add_colorbar=colorbar, cmap=cmap)

    im = None
    for face in v.face.values:
        if face in ignore_face:
            continue
        vplt = v.sel(face=face)
        if face in [6, 7, 8, 9]:
            # These faces cross the dateline: plot the eastern and western
            # halves separately so pcolormesh does not smear across +-180.
            im = vplt.where((vplt.XC <= 179.9) & (vplt.XC >= 0.)).plot.pcolormesh(ax=ax, **pcolor_kw)
            im = vplt.where((vplt.XC > -179.9) & (vplt.XC <= 0)).plot.pcolormesh(ax=ax, **pcolor_kw)
        else:
            im = vplt.plot.pcolormesh(ax=ax, **pcolor_kw)
    if im is None:
        plt.close(fig)
        raise ValueError('plot_pretty: every face of v was ignored, nothing to plot')

    # Ticks at the integer decades that fall inside the colour range.
    decades = np.arange(np.ceil(float(vmin)), np.floor(float(vmax)) + 1).astype(int)
    cax = plt.axes([0.91, 0.2875, 0.02, 0.432])
    cb = plt.colorbar(im, cax=cax, ticks=decades)
    cb.ax.set_yticklabels(['10$^{%d}$' % k for k in decades])
    cb.ax.tick_params(labelsize=font_size)

    ax.set_title('', fontsize=font_size)
    ax.set_xticks([-180, -135, -90, -45, 0, 45, 90, 135, 180])
    ax.set_xticklabels([r'$180\degree$', r'$135\degree$W', r'$90\degree$W', r'$45\degree$W',
                        r'$0\degree$', r'$45\degree$E', r'$90\degree$E', r'$135\degree$E',
                        r'$180\degree$'])
    ax.set_ylim(-60., 60.)
    ax.set_yticks([-60, -50, -40, -30, -20, -10, 0, 10, 20, 30, 40, 50, 60])
    ax.set_yticklabels([r'$60\degree$S', '', r'$40\degree$S', '', r'$20\degree$S', '',
                        r'$0\degree$', '', r'$20\degree$N', '', r'$40\degree$N', '',
                        r'$60\degree$N'])
    ax.tick_params(direction='out', length=6, width=2)
    ax.set_ylabel('', fontsize=font_size)
    ax.set_xlabel('', fontsize=font_size)
    ax.tick_params(labelsize=font_size)

    if coast_resolution is not None:
        ax.coastlines(resolution=coast_resolution, color='k')
    ax.add_feature(cfeature.LAND)

    if title is not None:
        ax.set_title(title)
    if label is not None:
        cb.set_label(label=label, size=font_size)
    if savefig is not None:
        fig.savefig(savefig, dpi=180)
        plt.close(fig)
    if not offline:
        plt.show()

In [33]:
# Figure 3a: log10 total KE, masked with ice > 0 and depth > 500 (face 6 skipped)
plot_pretty(
    np.log10(E_total_all.isel(i=slice(1, -1), j=slice(1, -1)))
        .where((ice > 0) & (depth > 500)),
    label='(m$^2$ s$^{-2}$)',
    vmin=-2, vmax=0,
    cmmap='speed',
    ignore_face=[6],
    savefig='/home/uz/yux/mit_equinox_backup/hal/Geostrophy_assessment/Figures/Figure_03a_dpi180.png',
)

  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,


In [12]:
# Figure 3b: log10 geostrophic KE; also drops |YC| <= 10 (face 6 skipped)
plot_pretty(
    np.log10(E_geo_all.isel(i=slice(1, -1), j=slice(1, -1)))
        .where((ice > 0) & (depth > 500) & (np.abs(E_geo_all.YC) > 10)),
    label='(m$^2$ s$^{-2}$)',
    vmin=-2, vmax=0,
    cmmap='speed',
    ignore_face=[6],
    savefig='/home/uz/yux/mit_equinox_backup/hal/Geostrophy_assessment/Figures/Figure_03b_dpi180.png',
)

  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,


In [13]:
# Figure 3c: log10 ageostrophic KE; also drops |YC| <= 10 (face 6 skipped)
plot_pretty(
    np.log10(E_ageo_all.isel(i=slice(1, -1), j=slice(1, -1)))
        .where((ice > 0) & (depth > 500) & (np.abs(E_ageo_all.YC) > 10)),
    label='(m$^2$ s$^{-2}$)',
    vmin=-2, vmax=0,
    cmmap='speed',
    ignore_face=[6],
    savefig='/home/uz/yux/mit_equinox_backup/hal/Geostrophy_assessment/Figures/Figure_03c_dpi180.png',
)

  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,


In [15]:
# Figure S2: log10 low-band ageostrophic KE (E_ageo_all_low is ds_ageo_all.E_low)
plot_pretty(
    np.log10(E_ageo_all_low.isel(i=slice(1, -1), j=slice(1, -1)))
        .where((ice > 0) & (depth > 500) & (np.abs(E_geo_all.YC) > 10)),
    label='(m$^2$ s$^{-2}$)',
    vmin=-2, vmax=0,
    cmmap='speed',
    ignore_face=[6],
    savefig='/home/uz/yux/mit_equinox_backup/hal/Geostrophy_assessment/Figures/Figure_S2_low_KEa.png',
)

  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,


In [16]:
# Figure S2: log10 high-band ageostrophic KE (E_ageo_all_high is ds_ageo_all.E_high)
plot_pretty(
    np.log10(E_ageo_all_high.isel(i=slice(1, -1), j=slice(1, -1)))
        .where((ice > 0) & (depth > 500) & (np.abs(E_geo_all.YC) > 10)),
    label='(m$^2$ s$^{-2}$)',
    vmin=-2, vmax=0,
    cmmap='speed',
    ignore_face=[6],
    savefig='/home/uz/yux/mit_equinox_backup/hal/Geostrophy_assessment/Figures/Figure_S2_high_KEa.png',
)

  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,


In [18]:
# Figure S2: log10 low-band geostrophic KE.
# The |YC| > 10 latitude cut uses the plotted field's own YC coordinate, so the
# mask comes from the same "_all" dataset as the data. `where` aligns with an
# inner join, so a foreign dataset here could silently change the plotted extent.
plot_pretty(
    np.log10(E_geo_low_all.isel(i=slice(1, -1), j=slice(1, -1)))
        .where((ice > 0) & (depth > 500) & (np.abs(E_geo_low_all.YC) > 10)),
    label='(m$^2$ s$^{-2}$)',
    vmin=-2, vmax=0,
    cmmap='speed',
    ignore_face=[6],
    savefig='/home/uz/yux/mit_equinox_backup/hal/Geostrophy_assessment/Figures/Figure_S2_low_KEg.png',
)

  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,


______________
# KE ratios (ageostrophic / total)

In [10]:
import matplotlib.colors as colors

# Discrete colour scale for the ratio maps: one colour per bin between
# consecutive boundaries (5 bins on [0, 1]).
bounds = np.array([0, 0.2, 0.4, 0.6, 0.8, 1])
norm = colors.BoundaryNorm(bounds, len(bounds) - 1)
print(norm)

<matplotlib.colors.BoundaryNorm object at 0x2acb29130310>


In [19]:
font_size = 28

def plot_pretty(v, colorbar=False, title=None, label=None, vmin=None, vmax=None, savefig=None,
                offline=False, figsize=(20,12), cmmap='thermal', coast_resolution='110m',
                ignore_face=None):
    """Map a field over all faces using the discrete colour scale in the global ``norm``.

    This redefines the log10-KE ``plot_pretty`` from an earlier cell with the
    same name. Whichever cell was executed last is the one in effect, so re-run
    the defining cell before the figures that depend on it.

    Relies on module-level ``norm`` (a BoundaryNorm) and ``font_size``, and on
    pyplot global state.

    Parameters
    ----------
    v : xarray.DataArray
        Field with a ``face`` dimension and ``XC``/``YC`` coordinates.
    colorbar : bool
        Forwarded as ``add_colorbar`` to every per-face ``pcolormesh``.
    title, label : str, optional
        Axes title and colorbar label.
    vmin, vmax : float, optional
        Colour limits, defaulting to the min/max of ``v``. Keep them consistent
        with ``norm``: xarray rejects a vmin/vmax that differs from the norm's.
    savefig : str, optional
        Output path. If given, the figure is saved at 180 dpi and then closed.
    offline : bool
        Switch to the 'agg' backend and skip ``plt.show()``.
    figsize : tuple
        Figure size passed to ``plt.figure``.
    cmmap : str
        Accepted for signature compatibility but ignored. The colormap is
        always 'RdBu_r' resampled to 6 colours.
    coast_resolution : str or None
        Passed to ``ax.coastlines``. None disables coastlines.
    ignore_face : iterable of int, optional
        Face values to skip.

    Raises
    ------
    ValueError
        If every face is ignored, so nothing would be plotted.
    """
    # Avoid a shared mutable default.
    if ignore_face is None:
        ignore_face = ()
    if vmin is None:
        vmin = v.min()
    if vmax is None:
        vmax = v.max()
    # (A per-call threading.Lock was removed: a lock created inside the call
    # is never contended, so it provided no mutual exclusion.)
    if offline:
        plt.switch_backend('agg')

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection=ccrs.PlateCarree(central_longitude=0))
    # plt.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap(name, lut)
    # returns the same resampled colormap.
    cmap = plt.get_cmap('RdBu_r', 6)
    pcolor_kw = dict(transform=ccrs.PlateCarree(), vmin=vmin, vmax=vmax,
                     x='XC', y='YC', add_colorbar=colorbar, cmap=cmap, norm=norm)

    im = None
    for face in v.face.values:
        if face in ignore_face:
            continue
        vplt = v.sel(face=face)
        if face in [6, 7, 8, 9]:
            # These faces cross the dateline: plot the eastern and western
            # halves separately so pcolormesh does not smear across +-180.
            im = vplt.where((vplt.XC <= 179.9) & (vplt.XC >= 0.)).plot.pcolormesh(ax=ax, **pcolor_kw)
            im = vplt.where((vplt.XC > -179.9) & (vplt.XC <= 0)).plot.pcolormesh(ax=ax, **pcolor_kw)
        else:
            im = vplt.plot.pcolormesh(ax=ax, **pcolor_kw)
    if im is None:
        plt.close(fig)
        raise ValueError('plot_pretty: every face of v was ignored, nothing to plot')

    # Tick every boundary of the discrete scale. The previous hard-coded
    # [0, 0.4, ..., 2] no longer matched a norm on [0, 1], so the top
    # boundary was never labelled.
    cax = plt.axes([0.91, 0.2875, 0.02, 0.432])
    cb = plt.colorbar(im, cax=cax, ticks=norm.boundaries, extend='max')
    cb.ax.tick_params(labelsize=font_size)

    ax.set_title('', fontsize=font_size)
    ax.set_xticks([-180, -135, -90, -45, 0, 45, 90, 135, 180])
    ax.set_xticklabels([r'$180\degree$', r'$135\degree$W', r'$90\degree$W', r'$45\degree$W',
                        r'$0\degree$', r'$45\degree$E', r'$90\degree$E', r'$135\degree$E',
                        r'$180\degree$'])
    ax.set_ylim(-60., 60.)
    ax.set_yticks([-60, -50, -40, -30, -20, -10, 0, 10, 20, 30, 40, 50, 60])
    ax.set_yticklabels([r'$60\degree$S', '', r'$40\degree$S', '', r'$20\degree$S', '',
                        r'$0\degree$', '', r'$20\degree$N', '', r'$40\degree$N', '',
                        r'$60\degree$N'])
    ax.tick_params(direction='out', length=6, width=2)
    ax.set_ylabel('', fontsize=font_size)
    ax.set_xlabel('', fontsize=font_size)
    ax.tick_params(labelsize=font_size)

    if coast_resolution is not None:
        ax.coastlines(resolution=coast_resolution, color='k')
    ax.add_feature(cfeature.LAND)

    if title is not None:
        ax.set_title(title)
    if label is not None:
        cb.set_label(label=label, size=font_size)
    if savefig is not None:
        fig.savefig(savefig, dpi=180)
        plt.close(fig)
    if not offline:
        plt.show()

In [22]:
# Figure 6a: ratio KEa / KE on the discrete 0-1 scale.
# Mask: ice > 0, depth > 500, |YC| > 10, and the j == 40 line removed
# (NOTE(review): reason for excluding j == 40 is not documented here — confirm).
ratio_mask = ((ice > 0) & (depth > 500) & (E_ageo_all.j != 40)
              & (np.abs(E_ageo_all.YC) > 10))
plot_pretty(
    (E_ageo_all / E_total_all).isel(i=slice(1, -1), j=slice(1, -1)).where(ratio_mask),
    label='',
    vmin=0, vmax=1,
    cmmap='RdBu_r',
    ignore_face=[6],
    savefig='/home/uz/yux/mit_equinox_backup/hal/Geostrophy_assessment/Figures/Figure_06a_dpi180_0504.png',
)

  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,


In [23]:
# Figure 6b: low-band ratio KE_{a,low} / KE_{low} on the discrete 0-1 scale.
# Same mask as Figure 6a (ice > 0, depth > 500, j != 40, |YC| > 10).
low_ratio_mask = ((ice > 0) & (depth > 500) & (E_ageo_all.j != 40)
                  & (np.abs(E_ageo_all.YC) > 10))
plot_pretty(
    (E_ageo_all_low / E_total_all_low).isel(i=slice(1, -1), j=slice(1, -1)).where(low_ratio_mask),
    label='',
    vmin=0, vmax=1,
    cmmap='RdBu_r',
    ignore_face=[6],
    savefig='/home/uz/yux/mit_equinox_backup/hal/Geostrophy_assessment/Figures/Figure_06b_dpi180_0504.png',
)

  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
  X, Y, C, shading = self._pcolorargs('pcolormesh', *args,


In [None]:
# Release dask resources. Disconnect the client first so it does not keep
# trying to reach a scheduler that is being shut down, then stop the PBS jobs.
client.close()
cluster.close()