# Calculation of mixing diagnostics and yearly means of diapycnal transport

In [1]:
%matplotlib inline
import xarray as xr
import numpy as np
import cosima_cookbook as cc
from collections import OrderedDict
from dask.distributed import Client
import matplotlib.path as mpath

import cf_xarray
from metpy.interpolate import cross_section
import pyproj

import matplotlib.pyplot as plt
import cmocean.cm as cmo
import matplotlib.colors as col
import cartopy.crs as ccrs
from cartopy.mpl.ticker import LongitudeFormatter
import matplotlib.ticker as mticker

In [2]:
def yearly_mean(var):
    """Collapse monthly data into annual means weighted by month length.

    Each month is weighted by its number of days divided by the number of
    days in its calendar year, and the weighted values are summed per year.

    Parameters
    ----------
    var : xarray.DataArray or xarray.Dataset
        Monthly data with a datetime-like ``time`` coordinate.

    Returns
    -------
    Same type as ``var``, with one entry per calendar year along a dimension
    named ``time`` (holding the year number). Values that come out exactly 0
    are replaced by NaN.
    """
    days = var.time.dt.days_in_month
    days_per_year = days.groupby('time.year')
    # fraction of its year that each month represents
    weights = days_per_year / days_per_year.sum()
    annual = (var * weights).groupby('time.year').sum()
    annual = annual.rename({'year': 'time'})
    # NOTE(review): exact zeros are masked to NaN, presumably so that points
    # that were missing in every month come back as missing; genuine zero
    # values are masked as well.
    return annual.where(annual != 0)

In [3]:
# start a local dask cluster; the bare `client` line displays its summary
client = Client()
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: /proxy/45839/status,

0,1
Dashboard: /proxy/45839/status,Workers: 7
Total threads: 28,Total memory: 251.20 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:46063,Workers: 7
Dashboard: /proxy/45839/status,Total threads: 28
Started: Just now,Total memory: 251.20 GiB

0,1
Comm: tcp://127.0.0.1:35237,Total threads: 4
Dashboard: /proxy/36371/status,Memory: 35.89 GiB
Nanny: tcp://127.0.0.1:42335,
Local directory: /jobfs/105666458.gadi-pbs/dask-scratch-space/worker-11rq99ai,Local directory: /jobfs/105666458.gadi-pbs/dask-scratch-space/worker-11rq99ai

0,1
Comm: tcp://127.0.0.1:42823,Total threads: 4
Dashboard: /proxy/36399/status,Memory: 35.89 GiB
Nanny: tcp://127.0.0.1:41387,
Local directory: /jobfs/105666458.gadi-pbs/dask-scratch-space/worker-1shpfkyh,Local directory: /jobfs/105666458.gadi-pbs/dask-scratch-space/worker-1shpfkyh

0,1
Comm: tcp://127.0.0.1:43341,Total threads: 4
Dashboard: /proxy/42709/status,Memory: 35.89 GiB
Nanny: tcp://127.0.0.1:35413,
Local directory: /jobfs/105666458.gadi-pbs/dask-scratch-space/worker-cthfaec7,Local directory: /jobfs/105666458.gadi-pbs/dask-scratch-space/worker-cthfaec7

0,1
Comm: tcp://127.0.0.1:34257,Total threads: 4
Dashboard: /proxy/36955/status,Memory: 35.89 GiB
Nanny: tcp://127.0.0.1:39979,
Local directory: /jobfs/105666458.gadi-pbs/dask-scratch-space/worker-jj_3r3ij,Local directory: /jobfs/105666458.gadi-pbs/dask-scratch-space/worker-jj_3r3ij

0,1
Comm: tcp://127.0.0.1:37273,Total threads: 4
Dashboard: /proxy/37549/status,Memory: 35.89 GiB
Nanny: tcp://127.0.0.1:43329,
Local directory: /jobfs/105666458.gadi-pbs/dask-scratch-space/worker-weqri7mo,Local directory: /jobfs/105666458.gadi-pbs/dask-scratch-space/worker-weqri7mo

0,1
Comm: tcp://127.0.0.1:44703,Total threads: 4
Dashboard: /proxy/34535/status,Memory: 35.89 GiB
Nanny: tcp://127.0.0.1:38755,
Local directory: /jobfs/105666458.gadi-pbs/dask-scratch-space/worker-vezj4g31,Local directory: /jobfs/105666458.gadi-pbs/dask-scratch-space/worker-vezj4g31

0,1
Comm: tcp://127.0.0.1:41813,Total threads: 4
Dashboard: /proxy/44081/status,Memory: 35.89 GiB
Nanny: tcp://127.0.0.1:42995,
Local directory: /jobfs/105666458.gadi-pbs/dask-scratch-space/worker-rd5l4jcv,Local directory: /jobfs/105666458.gadi-pbs/dask-scratch-space/worker-rd5l4jcv


In [4]:
# Bounding boxes of the regions cut out below; entries at the same list
# index describe one region (looked up as DSW_region[key][a]).
# Some longitudes are written as X-360 (e.g. 90-360 = -270), presumably to
# match the model's longitude range of xh/xq — confirm.
DSW_region = {
    'name': ['Weddell', 'Prydz', 'Adelie', 'Ross'],
    'lon_min_area': [-58, 47, 90-360, 166-360],
    'lon_max_area': [-30, 72, 147-360, -170],
    'lat_min_area': [-75, -68, -67.5, -76.5],
    'lat_max_area': [-59, -64, -61.9, -65.8]}

In [9]:
# Database session and experiment settings used by all cells below.
session = cc.database.create_session()
expt = 'panant-005-zstar-ACCESSyr2'
expt_name = 'panan_005deg_jra55_ryf'
# '005deg' -> '005'; selects the processing branch in later cells
resolution = expt_name.split('_')[1][:-3]

# NOTE(review): the loops below reassign year/start_time/end_time, so these
# values are overwritten once any of those cells has run
year = '2000'
start_time= year + '-01-01'
end_time= year + '-12-31'

frequency = '1 monthly'
path_output = '/g/data/e14/cs6673/mom6_comparison/data_DSW/'

## Calculate age, rho, layer thickness and mixing diagnostics and save them as yearly means

In [None]:
%%time
for year in range(1997, 2003):
    start_time= str(year) + '-01-01'
    end_time= str(year) + '-12-31'
    print(start_time)
    sig_min = 1035
    
    rho = cc.querying.getvar(
        expt, 'rhopot2', session, frequency=frequency,
        start_time=start_time, end_time=end_time,
        chunks={'xh': '200MB', 'yh': '200MB'}).sel(
        time=slice(start_time, end_time), yh=slice(None, -55))
    var1 = yearly_mean(rho.isel(z_l=slice(0, 25))).compute()
    var2 = yearly_mean(rho.isel(z_l=slice(25, 50))).compute()
    var3 = yearly_mean(rho.isel(z_l=slice(50, 75))).compute()
    rho = xr.concat((var1, var2, var3), dim='z_l').squeeze()
    del var1, var2, var3
    print('rho done')
    
    age = cc.querying.getvar(
        expt, 'agessc', session, frequency=frequency,
        start_time=start_time, end_time=end_time,
        chunks={'xh': '200MB', 'yh': '200MB'}).sel(
        time=slice(start_time, end_time), yh=slice(None, -55))
    var1 = yearly_mean(age.isel(z_l=slice(0, 25))).compute()
    var2 = yearly_mean(age.isel(z_l=slice(25, 50))).compute()
    var3 = yearly_mean(age.isel(z_l=slice(50, 75))).compute()
    age = xr.concat((var1, var2, var3), dim='z_l').squeeze()
    del var1, var2, var3
    print('age done')
    
    # thickness of layers
    area = cc.querying.getvar(
        expt, 'areacello', session, n=1,
        chunks={'xh': '200MB', 'yh': '200MB'}).sel(
        yh=slice(None, -55))
    vol = cc.querying.getvar(
        expt, 'volcello', session,
        frequency='1 monthly',
        attrs={'cell_methods': 'area:sum rho2_l:sum yh:sum xh:sum time: mean'} ,
        start_time=start_time, end_time=end_time,
        chunks={'rho2_l': '200MB'}).sel(
        time=slice(start_time, end_time), yh=slice(None, -55),
        rho2_l=slice(sig_min, None))
    hmo = vol/area
    var1 = yearly_mean(hmo.isel(rho2_l=slice(0, 28))).compute()
    var2 = yearly_mean(hmo.isel(rho2_l=slice(28, 56))).compute()
    var3 = yearly_mean(hmo.isel(rho2_l=slice(56, 82))).compute()
    hmo = xr.concat((var1, var2, var3), dim='rho2_l').squeeze()
    hmo = hmo.cumsum('rho2_l')
    del var1, var2, var3
    print('hmo done')

    """save data"""
    rho.name = 'rhopot2'
    ds_z = rho.to_dataset()
    ds_z['agessc'] = age
    ds_z = ds_z.assign_coords(time=[np.int(year)])
    
    hmo.name = 'hmo'
    ds_rho = hmo.to_dataset()
    ds_rho = ds_rho.assign_coords(time=[np.int(year)])
    
    comp = dict(chunksizes=(42, 292, 1200),
                zlib=True, complevel=5, shuffle=True)
    enc_rho = {var: comp for var in ds_rho.data_vars}
    enc_z = {var: comp for var in ds_z.data_vars}
    ds_z.to_netcdf(
        path_output + 'Age_rhopot2_' + expt_name + '_' +
        str(year) + '.nc', encoding=enc_z)
    ds_rho.to_netcdf(
        path_output + 'Layer_thickness_' + expt_name + '_' +
        str(year) + '.nc', encoding=enc_rho)

1998-01-01


In [11]:
%%time
# append layer thickness, what I saved as hmo is the depth of each layer
for year in range(2003, 2006):
    start_time= str(year) + '-01-01'
    end_time= str(year) + '-12-31'
    print(start_time)
    sig_min = 1035
    
    # thickness of layers
    area = cc.querying.getvar(
        expt, 'areacello', session, n=1,
        chunks={'xh': '200MB', 'yh': '200MB'}).sel(
        yh=slice(None, -55))
    vol = cc.querying.getvar(
        expt, 'volcello', session,
        frequency='1 monthly',
        attrs={'cell_methods': 'area:sum rho2_l:sum yh:sum xh:sum time: mean'} ,
        start_time=start_time, end_time=end_time,
        chunks={'rho2_l': '200MB'}).sel(
        time=slice(start_time, end_time), yh=slice(None, -55),
        rho2_l=slice(sig_min, None))
    hmo = vol/area
    var1 = yearly_mean(hmo.isel(rho2_l=slice(0, 28))).compute()
    var2 = yearly_mean(hmo.isel(rho2_l=slice(28, 56))).compute()
    var3 = yearly_mean(hmo.isel(rho2_l=slice(56, 82))).compute()
    hmo = xr.concat((var1, var2, var3), dim='rho2_l').squeeze()
    del var1, var2, var3
    
    """save data"""
    hmo.name = 'dz'
    ds_rho = hmo.to_dataset()
    ds_rho = ds_rho.assign_coords(time=[np.int(year)])
    
    comp = dict(chunksizes=(42, 292, 1200),
                zlib=True, complevel=5, shuffle=True)
    enc_rho = {var: comp for var in ds_rho.data_vars}
    ds_rho.to_netcdf(
        path_output + 'Layer_thickness_' + expt_name + '_' +
        str(year) + '.nc', encoding=enc_rho, mode='a')

2003-01-01
2004-01-01
2005-01-01
CPU times: user 12min 34s, sys: 1min 14s, total: 13min 49s
Wall time: 16min 1s


  ret = callback()


In [8]:
%%time
if resolution != '0025':
    for year in range(2001, 2003):
        start_time= str(year) + '-01-01'
        end_time= str(year) + '-12-31'
        print(start_time)
        sig_min = 1035
        
        rho = cc.querying.getvar(
            expt, 'rhopot2', session, frequency=frequency,
            start_time=start_time, end_time=end_time,
            chunks={'xh': '200MB', 'yh': '200MB'}).sel(
            time=slice(start_time, end_time), yh=slice(None, -55))
        rho = yearly_mean(rho).mean('time').compute()
        
        age = cc.querying.getvar(
            expt, 'agessc', session, frequency=frequency,
            start_time=start_time, end_time=end_time,
            chunks={'xh': '200MB', 'yh': '200MB'}).sel(
            time=slice(start_time, end_time), yh=slice(None, -55))
        age = yearly_mean(age).mean('time').compute()
        print('age, rho done')
        
        # thickness of layers
        area = cc.querying.getvar(
            expt, 'areacello', session, n=1,
            chunks={'xh': '200MB', 'yh': '200MB'}).sel(
            yh=slice(None, -55))
        vol = cc.querying.getvar(
            expt, 'volcello', session,
            frequency='1 monthly',
            attrs={'cell_methods': 'area:sum rho2_l:sum yh:sum xh:sum time: mean'} ,
            start_time=start_time, end_time=end_time,
            chunks={'rho2_l': '200MB'}).sel(
            time=slice(start_time, end_time), yh=slice(None, -55),
            rho2_l=slice(sig_min, None))
        hmo = vol/area
        hmo = yearly_mean(hmo.cumsum('rho2_l')).mean('time').compute()
        print('hmo done')
        
        # mixing diagnostics
        Kd_heat = cc.querying.getvar(
            expt, 'Kd_heat', session, frequency=frequency,
            start_time=start_time, end_time=end_time,
            chunks={'rho2_i': '200MB'}).sel(
            time=slice(start_time, end_time), yh=slice(None, -55),
            rho2_i=slice(sig_min, None))
        Kd_heat = yearly_mean(Kd_heat).mean('time').compute()
        
        Kd_salt = cc.querying.getvar(
            expt, 'Kd_salt', session, frequency=frequency,
            start_time=start_time, end_time=end_time,
            chunks={'rho2_i': '200MB'}).sel(
            time=slice(start_time, end_time), yh=slice(None, -55),
            rho2_i=slice(sig_min, None))
        Kd_salt = yearly_mean(Kd_salt).mean('time').compute()
        print('Kd_heat, Kd_salt done')
        
        Kd_shear = cc.querying.getvar(
            expt, 'Kd_shear', session, frequency=frequency,
            start_time=start_time, end_time=end_time,
            chunks={'rho2_i': '200MB'}).sel(
            time=slice(start_time, end_time), yh=slice(None, -55),
            rho2_i=slice(sig_min, None))
        Kd_shear = yearly_mean(Kd_shear).mean('time').compute()
        
        Kd_BBL = cc.querying.getvar(
            expt, 'Kd_BBL', session, frequency=frequency,
            start_time=start_time, end_time=end_time,
            chunks={'rho2_i': '200MB'}).sel(
            time=slice(start_time, end_time), yh=slice(None, -55),
            rho2_i=slice(sig_min, None))
        Kd_BBL = yearly_mean(Kd_BBL).mean('time').compute()
        
        Kd_ePBL = cc.querying.getvar(
            expt, 'Kd_ePBL', session, frequency=frequency,
            start_time=start_time, end_time=end_time,
            chunks={'rho2_i': '200MB'}).sel(
            time=slice(start_time, end_time), yh=slice(None, -55),
            rho2_i=slice(sig_min, None))
        Kd_ePBL = yearly_mean(Kd_ePBL).mean('time').compute()
        print('Kd_shear, Kd_BBL, Kd_ePBL done')
        
        """save data"""
        rho.name = 'rhopot2'
        ds_z = rho.to_dataset()
        ds_z['agessc'] = age
        ds_z = ds_z.assign_coords(time=[np.int(year)])
        
        Kd_heat.name = 'Kd_heat'
        ds_rho = Kd_heat.to_dataset()
        ds_rho['Kd_salt'] = Kd_salt
        ds_rho['Kd_shear'] = Kd_shear
        ds_rho['Kd_BBL'] = Kd_BBL
        ds_rho['Kd_ePBL'] = Kd_ePBL
        ds_rho['hmo'] = hmo
        ds_rho = ds_rho.assign_coords(time=[np.int(year)])
        
        comp = dict(chunksizes=(42, 292, 1200),
                    zlib=True, complevel=5, shuffle=True)
        enc_rho = {var: comp for var in ds_rho.data_vars}
        enc_z = {var: comp for var in ds_z.data_vars}
        ds_z.to_netcdf(
            path_output + 'Age_rhopot2_' + expt_name + '_' +
            str(year) + '.nc', encoding=enc_z)
        
        ds_rho.to_netcdf(
            path_output + 'mixing_diagnostics_' + expt_name + '_' +
            str(year) + '.nc', encoding=enc_rho)

CPU times: user 5 µs, sys: 0 ns, total: 5 µs
Wall time: 8.34 µs


2024-01-02 09:33:30,694 - distributed.core - INFO - Connection to tcp://127.0.0.1:37795 has been closed.


## Cut out age, rho, layer thickness and diapycnal transport in the DSW regions and save them as monthly means

In [None]:
%%time
for year in range(1998, 2001):
    start_time= str(year) + '-01-01'
    end_time= str(year) + '-12-31'
    print(start_time)
    for a, area_text in enumerate(DSW_region['name']):
        print(a)
        rho = cc.querying.getvar(
            expt, 'rhopot2', session, frequency=frequency,
            start_time=start_time, end_time=end_time,
            chunks={'xh': '200MB', 'yh': '200MB'}).sel(
            time=slice(start_time, end_time)).assign_coords(
            {'area': area_text})
        rho = rho.sel(
            xh=slice(DSW_region['lon_min_area'][a],
                     DSW_region['lon_max_area'][a]),
            yh=slice(DSW_region['lat_min_area'][a],
                     DSW_region['lat_max_area'][a])).compute()
        
        age = cc.querying.getvar(
            expt, 'agessc', session, frequency=frequency,
            start_time=start_time, end_time=end_time,
            chunks={'xh': '200MB', 'yh': '200MB'}).sel(
            time=slice(start_time, end_time)).assign_coords(
            {'area': area_text})
        age = age.sel(
            xh=slice(DSW_region['lon_min_area'][a],
                     DSW_region['lon_max_area'][a]),
            yh=slice(DSW_region['lat_min_area'][a],
                     DSW_region['lat_max_area'][a])).compute()
    
        ds_z = rho.to_dataset()
        ds_z['agessc'] = age
    
        """save data"""
        comp = dict(chunksizes=(12, 75, 90, 200),
                    zlib=True, complevel=5, shuffle=True)
        enc_z = {var: comp for var in ds_z.data_vars}
        ds_z.to_netcdf(
            path_output + 'Age_rhopot2_in_' + area_text + '_' +
            expt_name + '_1m_' + str(year) + '.nc', encoding=enc_z)

In [49]:
%%time
for year in range(2003, 2006):
    for a, area_text in enumerate(DSW_region['name']):
        print(a)
        ds = xr.open_mfdataset(
            path_output + 'Diapycnal_transport_at_upper_interface_' +
            expt_name + '_' + frequency[:3:2] + '_' +
            str(year) + '*nc', concat_dim='time', combine='nested')
        ds = ds.sel(
            xh=slice(DSW_region['lon_min_area'][a],
                     DSW_region['lon_max_area'][a]),
            yh=slice(DSW_region['lat_min_area'][a],
                     DSW_region['lat_max_area'][a])).compute()
        enc = {'diapycnal_transport':
               {'chunksizes': (len(ds.time), 99, 90, 200),
                'zlib': True, 'complevel': 5, 'shuffle': True}}
        
        ds.to_netcdf(
            path_output + 'Diapycnal_transport_at_upper_interface_' +
            'in_' + area_text + '_' + expt_name + '_1m_' + str(year) +
            '.nc', encoding=enc)

0
1
2
3
0
1
2
3
0
1
2
3
CPU times: user 8min 30s, sys: 1min 25s, total: 9min 56s
Wall time: 12min 4s


In [12]:
%%time
for year in range(1997, 2001):
    for a, area_text in enumerate(DSW_region['name']):
        print(a)
        start_time = str(year) + '-01-01'
        end_time = str(year) + '-12-31'

        # thickness of layers
        area = cc.querying.getvar(
            expt, 'areacello', session, n=1,
            chunks={'xh': '200MB', 'yh': '200MB'}).sel(
            yh=slice(None, -55))
        vol = cc.querying.getvar(
            expt, 'volcello', session,
            frequency='1 monthly',
            attrs={'cell_methods': 'area:sum rho2_l:sum yh:sum xh:sum time: mean'} ,
            start_time=start_time, end_time=end_time,
            chunks={'rho2_l': '200MB'}).sel(
            time=slice(start_time, end_time), yh=slice(None, -55))
        hmo = vol/area
        hmo = hmo.sel(
            xh=slice(DSW_region['lon_min_area'][a],
                     DSW_region['lon_max_area'][a]),
            yh=slice(DSW_region['lat_min_area'][a],
                     DSW_region['lat_max_area'][a]))
        hmo = hmo.cumsum('rho2_l').compute()
        
        hmo.name = 'hmo'
        enc = {'hmo':
               {'chunksizes': (len(hmo.time), 99, 90, 200),
                'zlib': True, 'complevel': 5, 'shuffle': True}}
        hmo.to_netcdf(
            path_output + 'Layer_thickness_' +
            'in_' + area_text + '_' + expt_name + '_1m_' + str(year) +
            '.nc', encoding=enc)

0
1
2
3
0
1
2
3
0
1
2
3
0
1
2
3
CPU times: user 18min 54s, sys: 3min 20s, total: 22min 14s
Wall time: 23min 55s


In [10]:
%%time
# append layer thickness, what I saved as hmo is the depth of each layer
for year in range(1997, 2001):
    print(year)
    for a, area_text in enumerate(DSW_region['name']):
        print(a)
        start_time = str(year) + '-01-01'
        end_time = str(year) + '-12-31'

        # thickness of layers
        area = cc.querying.getvar(
            expt, 'areacello', session, n=1,
            chunks={'xh': '200MB', 'yh': '200MB'}).sel(
            yh=slice(None, -55))
        vol = cc.querying.getvar(
            expt, 'volcello', session,
            frequency='1 monthly',
            attrs={'cell_methods': 'area:sum rho2_l:sum yh:sum xh:sum time: mean'} ,
            start_time=start_time, end_time=end_time,
            chunks={'rho2_l': '200MB'}).sel(
            time=slice(start_time, end_time), yh=slice(None, -55))
        hmo = vol/area
        hmo = hmo.sel(
            xh=slice(DSW_region['lon_min_area'][a],
                     DSW_region['lon_max_area'][a]),
            yh=slice(DSW_region['lat_min_area'][a],
                     DSW_region['lat_max_area'][a]))
        hmo = hmo.compute()
        
        hmo.name = 'dz'
        enc = {'dz':
               {'chunksizes': (len(hmo.time), 99, 90, 200),
                'zlib': True, 'complevel': 5, 'shuffle': True}}
        hmo.to_netcdf(
            path_output + 'Layer_thickness_' +
            'in_' + area_text + '_' + expt_name + '_1m_' + str(year) +
            '.nc', encoding=enc, mode='a')

1997
0
1
2
3
1998
0
1
2
3
1999
0
1
2
3
2000
0
1
2
3
CPU times: user 18min 45s, sys: 3min 19s, total: 22min 4s
Wall time: 24min


### diapycnal transport: save as yearly means

Files with monthly data were calculated using `run_diapycnal_transp_calculation.sh`.

In [7]:
# Experiment settings re-declared for the diapycnal-transport cells below.
# NOTE: `session` is not re-created here; those cells still need it.
expt = 'panant-005-zstar-ACCESSyr2'
expt_name = 'panan_005deg_jra55_ryf'
# second '-'-separated field of expt, here '005'
resolution = expt.split('-')[1]

frequency = '1 monthly'
path_output = '/g/data/e14/cs6673/mom6_comparison/data_DSW/'

resolution

'005'

In [8]:
%%time
for year in range(2003, 2006):
    ds = xr.open_mfdataset(path_output + 'Diapycnal_transport_at_upper_interface_' +
                expt_name + '_' + frequency[:3:2] + '_' +
                str(year) + '*nc', concat_dim='time', combine='nested')
    ds_mean = yearly_mean(ds).squeeze().compute()
    enc = {'diapycnal_transport':
           {'chunksizes': (50, 292, 1200),
            'zlib': True, 'complevel': 5, 'shuffle': True}}
    
    ds_mean.to_netcdf(path_output + 'Diapycnal_transport_at_upper_interface_' +
                      expt_name + '_1y_' + str(year) + '.nc', encoding=enc)

CPU times: user 7min 31s, sys: 1min 1s, total: 8min 33s
Wall time: 11min 47s


### diapycnal transport: calculate for individual months/years

This can be run in parallel using `run_diapycnal_transp_calculation.sh`.

In [6]:
# period to process; the resolution '01' branch below uses only `year`
# (presumably these are set per job when run via the .sh script — confirm)
year = 2000
month = 1

In [7]:
# Time window to load: the whole year for resolution '01', otherwise the
# selected month. The end date lies one day into the following period,
# presumably so that the closing volcello snapshot is included — confirm.
if resolution == '01':
    start_time = f'{year}-01-01'
    end_time = f'{year + 1}-01-02'
else:
    start_time = f'{year}-{str(month).zfill(2)}-01'
    if month == 12:
        end_year, end_month = year + 1, 1
    else:
        end_year, end_month = year, month + 1
    end_time = f'{end_year}-{str(end_month).zfill(2)}-02'

In [32]:
def transport_across_isopycnals_12months(expt, U, V, dvol):
    """Diapycnal transport at the upper interface of each rho2_l layer, all time steps at once.

    From the last rho2_l index back to the first, the transport through the
    upper interface of layer k is

        D[k] = dvol[k] + D[k+1] - (U_east - U_west)[k] - (V_north - V_south)[k]

    where the D[k+1] term is absent for the last layer.

    Parameters
    ----------
    expt : str
        Experiment name; its second '-'-separated field is the resolution.
    U : xarray.DataArray
        Zonal transport on xq faces, time as the leading dimension.
    V : xarray.DataArray
        Meridional transport on yq faces, time as the leading dimension.
    dvol : xarray.DataArray
        Volume tendency times density; indexed as (time, rho2_l, ...)
        — assumed order, confirm.

    Returns
    -------
    xarray.DataArray
        Same shape as ``dvol`` (after any trimming), carrying ``U``'s time
        coordinate.
    """
    if expt.split('-')[1] == '01':
        # resolution '01': drop the last yh row of U and dvol
        U = U.isel(yh=slice(None, -1))
        dvol = dvol.isel(yh=slice(None, -1))
        if str(U.time[0].values)[:7] == '2003-01':
            # drop the first time step of U and V when they start in 2003-01
            U = U[1:, :]
            V = V[1:, :]

    def net_outflow(level):
        # east-minus-west and north-minus-south face transports of one layer,
        # as plain arrays so that no coordinate alignment takes place
        zonal = (U.isel(xq=slice(1, None), rho2_l=level).values -
                 U.isel(xq=slice(None, -1), rho2_l=level).values)
        meridional = (V.isel(yq=slice(1, None), rho2_l=level).values -
                      V.isel(yq=slice(None, -1), rho2_l=level).values)
        return zonal, meridional

    D = 0*dvol
    last = len(dvol.rho2_l) - 1
    zonal, meridional = net_outflow(last)
    D[:, last, :] = dvol.isel(rho2_l=last) - zonal - meridional
    for k in range(last - 1, -1, -1):
        zonal, meridional = net_outflow(k)
        D[:, k, :] = dvol.isel(rho2_l=k) + D[:, k+1, :] - zonal - meridional
    D['time'] = U.time
    return D

In [33]:
def transport_across_isopycnals_1month(expt, U, V, dvol):
    """Diapycnal transport at the upper interface of each rho2_l layer, single time step.

    Same recursion as the 12-month variant, but for arrays whose leading
    dimension is rho2_l: from the last rho2_l index back to the first,

        D[k] = dvol[k] + D[k+1] - (U_east - U_west)[k] - (V_north - V_south)[k]

    where the D[k+1] term is absent for the last layer.

    Parameters
    ----------
    expt : str
        Experiment name; its second '-'-separated field is the resolution.
    U : xarray.DataArray
        Zonal transport on xq faces.
    V : xarray.DataArray
        Meridional transport on yq faces.
    dvol : xarray.DataArray
        Volume tendency times density; indexed as (rho2_l, ...) — assumed
        order, confirm.

    Returns
    -------
    xarray.DataArray
        Same shape as ``dvol`` (after any trimming), carrying ``U``'s time
        coordinate.
    """
    if expt.split('-')[1] == '01':
        # resolution '01': drop the last yh row of U and dvol
        U = U.isel(yh=slice(None, -1))
        dvol = dvol.isel(yh=slice(None, -1))

    def net_outflow(level):
        # east-minus-west and north-minus-south face transports of one layer,
        # as plain arrays so that no coordinate alignment takes place
        zonal = (U.isel(xq=slice(1, None), rho2_l=level).values -
                 U.isel(xq=slice(None, -1), rho2_l=level).values)
        meridional = (V.isel(yq=slice(1, None), rho2_l=level).values -
                      V.isel(yq=slice(None, -1), rho2_l=level).values)
        return zonal, meridional

    D = 0*dvol
    last = len(dvol.rho2_l) - 1
    zonal, meridional = net_outflow(last)
    D[last, :] = dvol.isel(rho2_l=last) - zonal - meridional
    for k in range(last - 1, -1, -1):
        zonal, meridional = net_outflow(k)
        D[k, :] = dvol.isel(rho2_l=k) + D[k+1, :] - zonal - meridional
    D['time'] = U.time
    return D

In [7]:
if resolution != '0025':
    # Diapycnal transport for the whole domain south of 55S.
    # UMO and VMO
    U = cc.querying.getvar(
        expt, 'umo', session, frequency='1 monthly',
        start_time=start_time, end_time=end_time,
        chunks={'rho2_l': '200MB'}).sel(
        time=slice(start_time, end_time), yh=slice(None, -55)).squeeze()
    V = cc.querying.getvar(
        expt, 'vmo', session, frequency='1 monthly',
        start_time=start_time, end_time=end_time,
        chunks={'rho2_l': '200MB'}).sel(
        time=slice(start_time, end_time), yq=slice(None, -55)).squeeze()

    vol = cc.querying.getvar(
        expt, 'volcello', session,
        attrs={'cell_methods': 'area:sum rho2_l:sum yh:sum xh:sum time: point'},
        start_time=start_time, end_time=end_time,
        chunks={'rho2_l': '200MB'}).sel(
        time=slice(start_time, end_time), yh=slice(None, -55))
    # change in volume per second between monthly snapshots * density.
    # Dividing the time step by a 1 s timedelta gives seconds for any
    # datetime resolution; astype('int')/1e9 assumed nanoseconds and a
    # 64-bit platform int.
    dvol = vol.diff('time', label='lower')/(
        vol.time.diff('time', label='lower')/np.timedelta64(1, 's'))*vol.rho2_l
    dvol = dvol.squeeze()

    if resolution == '01':
        D = transport_across_isopycnals_12months(expt, U, V, dvol)
    else:
        D = transport_across_isopycnals_1month(expt, U, V, dvol)

    D.name = 'diapycnal_transport'
    if resolution == '01':
        time_str = str(year)
        enc = {'diapycnal_transport':
               {'chunksizes': (1, 50, 292, 1200),
                'zlib': True, 'complevel': 5, 'shuffle': True}}
    else:
        time_str = str(D.time.values)[:7]
        enc = {'diapycnal_transport':
               {'chunksizes': (50, 292, 1200),
                'zlib': True, 'complevel': 5, 'shuffle': True}}

    # NOTE(review): the '_test' suffix looks like a debugging leftover; the
    # yearly-mean cell globs '..._1m_<year>*nc' and would also pick these up
    D.to_netcdf(path_output + 'Diapycnal_transport_at_upper_interface_' +
                expt_name + '_' + frequency[:3:2] + '_' +
                time_str + '_test.nc', encoding=enc)

In [38]:
if resolution == '0025':
    # for 1/40th cut out DSW regions, otherwise it takes forever
    # run this for all regions (best with .py script)
    a = 0
    # taken from DSW_region so that the file name always matches index `a`
    area_text = DSW_region['name'][a]

    # UMO and VMO
    U = cc.querying.getvar(
        expt, 'umo', session, frequency='1 monthly',
        start_time=start_time, end_time=end_time,
        chunks={'rho2_l': '200MB'}).sel(
        time=slice(start_time, end_time),
        xq=slice(DSW_region['lon_min_area'][a],
                 DSW_region['lon_max_area'][a]),
        yh=slice(DSW_region['lat_min_area'][a],
                 DSW_region['lat_max_area'][a])).squeeze()
    V = cc.querying.getvar(
        expt, 'vmo', session, frequency='1 monthly',
        start_time=start_time, end_time=end_time,
        chunks={'rho2_l': '200MB'}).sel(
        time=slice(start_time, end_time),
        xh=slice(DSW_region['lon_min_area'][a],
                 DSW_region['lon_max_area'][a]),
        yq=slice(DSW_region['lat_min_area'][a],
                 DSW_region['lat_max_area'][a])).squeeze()

    vol = cc.querying.getvar(
        expt, 'volcello', session,
        attrs={'cell_methods': 'area:sum rho2_l:sum yh:sum xh:sum time: point'},
        start_time=start_time, end_time=end_time,
        chunks={'rho2_l': '200MB'}).sel(
        time=slice(start_time, end_time),
        xh=slice(DSW_region['lon_min_area'][a],
                 DSW_region['lon_max_area'][a]),
        yh=slice(DSW_region['lat_min_area'][a],
                 DSW_region['lat_max_area'][a])).squeeze()
    # change in volume per second between monthly snapshots * density.
    # Dividing the time step by a 1 s timedelta gives seconds for any
    # datetime resolution; astype('int')/1e9 assumed nanoseconds and a
    # 64-bit platform int.
    dvol = vol.diff('time', label='lower')/(
        vol.time.diff('time', label='lower')/np.timedelta64(1, 's'))*vol.rho2_l
    dvol = dvol.squeeze()

    # trim V/dvol in x and U/dvol in y so that the first and last face
    # coordinates (xq, yq) lie outside the first and last cell centres
    # (xh, yh), i.e. there is one more face than centres
    if U.xq[0] > V.xh[0]:
        V = V.isel(xh=slice(1, None))
        dvol = dvol.isel(xh=slice(1, None))
    if U.xq[-1] < V.xh[-1]:
        V = V.isel(xh=slice(0, -1))
        dvol = dvol.isel(xh=slice(0, -1))
    if V.yq[0] > U.yh[0]:
        U = U.isel(yh=slice(1, None))
        dvol = dvol.isel(yh=slice(1, None))
    if V.yq[-1] < U.yh[-1]:
        U = U.isel(yh=slice(0, -1))
        dvol = dvol.isel(yh=slice(0, -1))
    assert len(U.xq) == (len(V.xh) + 1), 'longitude has wrong dimensions'
    assert len(U.xq) == (len(dvol.xh) + 1), 'longitude of volume has wrong dimensions'
    assert len(V.yq) == (len(U.yh) + 1), 'latitude has wrong dimensions'
    assert len(V.yq) == (len(dvol.yh) + 1), 'latitude of volume has wrong dimensions'

    D = transport_across_isopycnals_1month(expt, U, V, dvol)

    D.name = 'diapycnal_transport'
    time_str = str(D.time.values)[:7]
    enc = {'diapycnal_transport':
           {'chunksizes': (99, 90, 200),
            'zlib': True, 'complevel': 5, 'shuffle': True}}

    # NOTE(review): the '_test' suffix looks like a debugging leftover
    D.to_netcdf(path_output + 'Diapycnal_transport_at_upper_interface_in_' +
                area_text + '_' + expt_name + '_' + frequency[:3:2] + '_' +
                time_str + '_test.nc', encoding=enc)