In [1]:
import argparse
import pkg_resources as pkgr
import io
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from om4labs import m6plot
import palettable
import xarray as xr
import warnings

from xwmt.preprocessing import preprocessing
from xwmt.swmt import swmt

from om4labs.om4common import horizontal_grid
from om4labs.om4common import image_handler
from om4labs.om4common import date_range
from om4labs.om4common import open_intake_catalog
from om4labs.om4parser import default_diag_parser

# Ignore any warning whose message matches one of these regular expressions
# (messages mentioning "csr_matrix" or "dates out of range").
warnings.filterwarnings("ignore", message=".*csr_matrix.*")
warnings.filterwarnings("ignore", message=".*dates out of range.*")


def calculate(ds, bins, group_tend):
    """Compute the surface-flux watermass transformation in sigma0 space.

    Parameters
    ----------
    ds : object accepted by ``swmt``
        Input dataset handed directly to ``swmt``.
    bins : array-like
        Bin specification forwarded to ``swmt(ds).G``.
    group_tend : bool
        Forwarded to ``swmt(ds).G``; when truthy the result is additionally
        converted with ``.to_dataset()``.

    Returns
    -------
    The object returned by ``swmt(ds).G``, converted with ``.to_dataset()``
    when ``group_tend`` is truthy.
    """
    transformation = swmt(ds).G('sigma0', bins=bins, group_tend=group_tend)

    # With grouped tendencies the result is a DataArray; the plotting code
    # expects a dataset, so convert it in that case only.
    return transformation.to_dataset() if group_tend else transformation

def read(dictArgs, heatflux_varname="hfds", saltflux_varname="sfdsi", 
         fwflux_varname="wfo", sst_varname="tos", sss_varname="sos"):
    """Read surface flux data and assemble the inputs for the WMT calculation.

    Parameters
    ----------
    dictArgs : dict
        Run options. Required keys: ``"infile"`` (path or glob passed to
        ``xr.open_mfdataset``), ``"static"`` (path of the static grid file)
        and ``"group_tend"``. Optional key ``"bins"``: a comma-separated
        ``"start,stop,step"`` string passed to ``np.arange``; when the key is
        absent or its value is ``None`` the default ``np.arange(20, 30, 0.1)``
        is used.
    heatflux_varname, saltflux_varname, fwflux_varname, sst_varname, sss_varname : str
        Names of the variables that must be present in the input dataset.

    Returns
    -------
    tuple
        ``(ds, bins, group_tend)`` where ``ds`` is the output of
        ``preprocessing``, ``bins`` is a NumPy array and ``group_tend`` is
        ``dictArgs["group_tend"]`` unchanged.

    Raises
    ------
    RuntimeError
        If any of the required variables is missing from the input dataset.
    KeyError
        If a required key is missing from ``dictArgs``.
    """

    ds = xr.open_mfdataset(dictArgs["infile"], combine="by_coords", use_cftime=True)

    # TODO: impose a check to make sure this is not annual data

    # Check that all required variables are here
    required = [heatflux_varname, saltflux_varname, fwflux_varname,
                sst_varname, sss_varname]
    missing = set(required) - set(ds.data_vars)
    if missing:
        raise RuntimeError("Necessary variable {} not present in dataset".format(missing))

    # Open the static file once and pull every grid field from the same
    # handle instead of re-opening it for each variable.
    static = xr.open_mfdataset(dictArgs["static"])
    for grid_var in ("areacello", "deptho", "geolat", "geolon"):
        ds[grid_var] = static[grid_var]

    ### WMT preprocessing step
    # Perhaps we should pull out some of what happens in here ?
    ds = preprocessing(ds, grid=ds, decode_times=False, verbose=False)

    # A "bins" entry that exists but is None (e.g. an unset command-line
    # option) is treated the same as a missing entry.
    bins_spec = dictArgs.get("bins")
    if bins_spec is None:
        # Default bins
        bins = np.arange(20, 30, 0.1)
    else:
        bins = np.arange(*(float(x) for x in bins_spec.split(",")))

    # Retrieve group_tend boolean
    group_tend = dictArgs["group_tend"]

    return ds, bins, group_tend

def plot(G):
    """Plot the time-mean transformation of every term in ``G`` against sigma0.

    Parameters
    ----------
    G : dataset-like
        Must expose a ``sigma0`` dimension/coordinate, a ``time`` dimension
        and one data variable per transformation term.

    Returns
    -------
    The figure object created by ``plt.subplots()``.
    """
    # The first and last bins are skipped (per the original note they are
    # expanded to capture the full range).
    interior = G.isel(sigma0=slice(1,-1))
    levs = interior['sigma0'].values

    # Average over the time dimension and bring the result into memory
    mean_G = interior.mean('time').load()
    terms = list(mean_G.data_vars)

    # Fixed colours for the heat and salt terms; anything else is black
    palette = {'heat': 'tab:red', 'salt': 'tab:blue'}

    fig, ax = plt.subplots()
    for name in terms:
        ax.plot(levs, mean_G[name], label=name, color=palette.get(name, 'k'))

    # With more than one term (ungrouped tendencies) also draw their sum
    if len(terms) > 1:
        total = sum((mean_G[name] for name in terms),
                    xr.zeros_like(mean_G[terms[0]]))
        ax.plot(levs, total, label='total', color='k')

    ax.legend()
    ax.set_xlabel('SIGMA0')
    ax.set_ylabel('TRANSFORMATION ($m^3s^{-1}$)')
    ax.autoscale(enable=True, axis='x', tight=True)

    return fig


def run(dictArgs):
    """Run the read -> calculate -> plot pipeline and hand the figure to image_handler.

    Parameters
    ----------
    dictArgs : dict
        Run options; passed unchanged to ``read`` and ``image_handler``.
        ``dictArgs['outdir']`` is used to build the output file name.

    Returns
    -------
    Whatever ``image_handler`` returns for the single generated figure.
    """
    # Load the inputs
    ds, bins, group_tend = read(dictArgs)

    # Compute the transformation and draw it
    transformation = calculate(ds, bins, group_tend)
    figure = plot(transformation)

    # Output name (without extension) inside the requested directory
    outpath = f"{dictArgs['outdir']}/surface_wmt"
    return image_handler([figure], dictArgs, filename=outpath)

In [3]:
# Root directory of the ocean_monthly output used below (no trailing slash)
pp = ('/archive/oar.gfdl.cmip6/CM4/'+
          'warsaw_201710_om4_v1.0.1/CM4_piControl_C/'+
          'gfdl.ncrc4-intel16-prod-openmp/pp/ocean_monthly')

# Alternative experiment:
# pp = ('/archive/Raphael.Dussin/'+
#       'FMS2019.01.03_devgfdl_20210706/CM4_piControl_c192_OM4p125_v5/'+
#       'gfdl.ncrc4-intel18-prod-openmp/pp/ocean_monthly/')

# `pp` does not end with a slash, so the separator is added explicitly here
# (as is already done for `infile`); without it the directory name and the
# file name would be fused into a non-existent path.
static = pp+'/ocean_monthly.static.nc'
outdir = '.'
infile = pp+'/av/monthly_5yr/ocean_monthly.0001-0005.*.nc'
# "start,stop,step" specification of the sigma0 bins, parsed inside read()
bins = '20,30,0.1'
group_tend = True
# Options dictionary consumed by run()
dictArgs = {'infile':infile,'bins':bins,
            'group_tend':group_tend,'static':static,
            'outdir':outdir,'interactive':False,'format':'stream'}

In [None]:
# Time the full read -> calculate -> plot pipeline. Note that run() returns
# the output of image_handler, so that is what gets bound to `fig` here.
%time fig = run(dictArgs)