In [None]:
import gcsfs
import intake
import matplotlib.pyplot as plt
import matplotlib.ticker as tick
import numpy as np
import pandas as pd
import xarray as xr
import xrft
from dask.diagnostics import ProgressBar
from xgcm.grid import Grid
%matplotlib inline

In [None]:
# Pangeo intake catalog describing the MITgcm channel simulations.
CHANNEL_CATALOG_URL = ("https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/"
                       "master/intake-catalogs/ocean/channel.yaml")
cat = intake.open_catalog(CHANNEL_CATALOG_URL)

# '-mon' physics output (the entry name contains '-', so item access is required).
dsmon = cat["MITgcm_channel_flatbottom_02km_run01_phys-mon"].to_dask()
dsmon

In [None]:
# '_snap15D' physics snapshots; item access mirrors how the '-mon' entry is opened.
snap_entry = cat["MITgcm_channel_flatbottom_02km_run01_phys_snap15D"]
ds15d = snap_entry.to_dask()
ds15d

# Surface temperature anomaly (relative to its horizontal mean)

**The data start on July 1.**

In [None]:
# Position along the `time` dimension of the snapshot mapped below
# (an index, not a calendar date; index 0 is the first record).
date = 2 * 2

In [None]:
# Temperature at snapshot `date` and vertical index 0 (the surface), shown as an
# anomaly from its own horizontal mean.
surface_T = ds15d['T'][date, 0]
(surface_T - surface_T.mean()).plot(figsize=(10, 12), vmax=4.)

# Relative vorticity ($\zeta$) and Rossby number ($\zeta/f$)

In [None]:
# xgcm grid over the staggered output: only X is periodic, so operations
# along Y must pass an explicit `boundary`.
periodic_axes = ['X']
grid = Grid(ds15d, periodic=periodic_axes)

In [None]:
# Relative vorticity in finite-volume form on the (YG, XG) points:
#   zeta = [delta_x(V * dyC) - delta_y(U * dxC)] / rAz
# X is periodic on `grid`; the Y difference uses boundary='fill' at the walls.
vx = grid.diff(ds15d.V * ds15d.dyC, 'X') / ds15d.rAz
uy = grid.diff(ds15d.U * ds15d.dxC, 'Y', boundary='fill') / ds15d.rAz

# Only the snapshot `date` at vertical index 0 (surface) is needed for the map;
# the rest of the lazy graph is never computed.
with ProgressBar():
    zeta = (vx - uy)[date, 0].compute()
zeta

In [None]:
# Beta-plane Coriolis parameter f = f0 + beta * (y - y_ref), evaluated on YG,
# where zeta lives; y is presumably in metres.
f0 = -1.1e-4     # [s^-1], negative sign (Southern-Hemisphere convention)
beta = 1.4e-11   # [m^-1 s^-1]
y_ref = 1e6      # reference y at which f == f0
cori = f0 + beta * (ds15d.YG - y_ref)

# Rossby number zeta/f, clipped to +-1 for the colour scale.
rossby = zeta / cori
rossby.plot(figsize=(10, 12), vmin=-1., vmax=1., cmap='RdBu_r')

# KE zonal wavenumber spectra

In [None]:
# Season label per record: 0=JFM, 1=AMJ, 2=JAS, 3=OND. Records start in July,
# so the cycle begins at label 2.
season_cycle = np.roll(np.arange(4), -2)                 # [2, 3, 0, 1]
mon_labels = np.repeat(season_cycle, 3)                  # 3 records per season
snap_labels = np.append(np.repeat(season_cycle, 6), 1)   # 6 per season + one final record
dsmon.coords['seas'] = ('time', mon_labels)
ds15d.coords['seas'] = ('time', snap_labels)

# Seasonal, zonally averaged mean flow (XC and XG cover both staggered grids).
dmsea = dsmon.groupby('seas').mean(['time', 'XC', 'XG'])

# Snapshot velocities minus the seasonal zonal mean of their own season.
up = ds15d.U.groupby('seas') - dmsea.uVeltave
vp = ds15d.V.groupby('seas') - dmsea.vVeltave

In [None]:
def _zonal_power_spectrum(da, y_band=slice(6e5, 14e5)):
    """Return the XC power spectrum of `da` restricted to the YC band `y_band`.

    `da` must already be on (YC, XC) points. XC is put into a single dask chunk
    (-1) because the FFT is taken along it; the zonal mean is removed
    (detrend='constant') before transforming.
    """
    band = da.sel(YC=y_band).chunk({'XC': -1})
    return xrft.power_spectrum(band, dim=['XC'], detrend='constant')


# Move the anomalies from their staggered points to tracer points
# (U: XG -> XC, V: YG -> YC) so both spectra share one grid.
uk2 = _zonal_power_spectrum(grid.interp(up, 'X'))
vk2 = _zonal_power_spectrum(grid.interp(vp, 'Y', boundary='fill'))

kek = uk2 + vk2
# Reattach each snapshot's season label from ds15d instead of re-typing the table.
kek.coords['seas'] = ('time', ds15d['seas'].values)
with ProgressBar():
    kek_seas = kek.groupby('seas').mean('time').compute()

In [None]:
# Seasonal zonal-wavenumber KE spectra (left, log-log), their JAS/JFM ratio
# (right) and the corresponding wavelength (top).
JFM, JAS = 0, 2  # labels of the 'seas' coordinate (0=JFM, 1=AMJ, 2=JAS, 3=OND)

# Keep positive wavenumbers only (the shifted FFT also holds k <= 0); this
# replaces a hard-coded [251:] slice that is only valid for one grid size.
positive_k = np.flatnonzero(kek_seas.freq_XC.values > 0)
spec = kek_seas.isel(freq_XC=positive_k)
k_cpkm = spec.freq_XC * 1e3  # per metre -> cycles per km
# Positional indexing as before: assumes dims (seas, Z, YC, freq_XC),
# with vertical index 0 the surface level.
spec_jas = spec[JAS, 0]
spec_jfm = spec[JFM, 0]

fig, ax = plt.subplots(figsize=(6, 5), tight_layout=True)
ax2 = ax.twinx()  # right axis: JAS/JFM ratio (linear)
ax3 = ax.twiny()  # top axis: wavelength 1/k

jas_line, = ax.plot(k_cpkm, spec_jas.mean('YC'), 'b', label='JAS')
jfm_line, = ax.plot(k_cpkm, spec_jfm.mean('YC'), 'r', label='JFM')
# The ratio is taken at each YC first, then averaged over YC.
ratio_line, = ax2.plot(k_cpkm, (spec_jas / spec_jfm).mean('YC'), 'g', label='JAS/JFM')
ratio_color = ratio_line.get_color()

# Main axes: log-log spectra labelled as powers of ten.
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlim(1e-3, 2.5e-1)
ax.set_ylim(4e-4, 4e4)
ax.set_xticks([1e-3, 1e-2, 1e-1])
ax.set_yticks([1e-3, 1e-1, 1e1, 1e3])
ax.xaxis.set_major_formatter(tick.LogFormatterSciNotation())
ax.yaxis.set_major_formatter(tick.LogFormatterSciNotation())
ax.tick_params(axis='both', which='major', labelsize=12)
ax.set_xlabel('k [cpkm]', fontsize=14)
ax.set_ylabel(r'[m$^3$ s$^{-2}$]', fontsize=14)

# Right axis: ratio, coloured like its line.
ax2.set_ylim(0.8, 5.)
ax2.set_yticks([1, 2, 3, 4, 5])
ax2.tick_params(axis='y', labelsize=14, colors=ratio_color)
ax2.spines['right'].set_edgecolor(ratio_color)
ax2.set_ylabel('JAS/JFM', fontsize=14, color=ratio_color)

# Top axis: same wavenumber range, labelled with wavelength 1/k in km.
ax3.set_xscale('log')
ax3.set_xlim(ax.get_xlim())
ax3.set_xticks([1e-3, 1e-2, 1e-1])
ax3.xaxis.set_major_formatter(tick.FuncFormatter(lambda k, _pos: f'{1 / k:g}'))
ax3.tick_params(axis='x', which='major', labelsize=12)
ax3.set_xlabel('wavelength [km]', fontsize=14)

# Legend on the topmost axes so it is drawn above all three lines.
ax2.legend(handles=[jas_line, jfm_line, ratio_line], loc='best', fontsize=12);