# WOA 2018 T/S: preprocessing for MOM6 ocean model initial conditions and bias analysis

## Blends monthly with seasonal data to obtain full-depth fields
## Uses the TEOS-10 Gibbs SeaWater (`gsw`) toolbox

In [None]:
import xarray as xr
from vcr import utils, conserve
import matplotlib.pyplot as plt
import numpy as np
import gsw
import time
import pydap

%matplotlib inline

**URLs for the WOA18 dataset:**

In [None]:
# OPeNDAP (THREDDS) directories of the WOA18 "decav" climatology: one
# directory per variable (temperature / salinity) and grid resolution.
_woa18_thredds = 'https://www.ncei.noaa.gov/thredds-ocean/dodsC/ncei/woa'

url_temp_1deg = f'{_woa18_thredds}/temperature/decav/1.00/'
url_salt_1deg = f'{_woa18_thredds}/salinity/decav/1.00/'

url_temp_025deg = f'{_woa18_thredds}/temperature/decav/0.25/'
url_salt_025deg = f'{_woa18_thredds}/salinity/decav/0.25/'

## Functions

In [None]:
def roll_lon_to_0360(da, axis=0, res=1):
    """Rotate longitude values (or longitude bounds) onto a 0-360 degE axis.

    The values are rolled by half a revolution (``180 * res`` grid points,
    ``res`` being grid points per degree) along ``axis`` and then wrapped into
    [0, 360) with a modulo. For 2-D input (e.g. ``lon_bnds``) the very last
    element is raised by 360 afterwards, so the final cell's upper bound reads
    360 instead of wrapping to 0 (assumes the source axis starts at -180 --
    confirm for other grids).

    Returns a new DataArray carrying ``da``'s dims and attrs but no coords.
    """
    shift = -180 * res
    rolled = np.roll(da.values, shift, axis=axis)
    wrapped = np.mod(rolled + 360, 360)
    if wrapped.ndim == 2:
        # keep the last bounds pair monotonic: (359.x, 0) -> (359.x, 360)
        wrapped[-1, -1] += 360.
    return xr.DataArray(data=wrapped, dims=da.dims, attrs=da.attrs)

def roll_to_0360(da, axis=-1, res=1):
    """Roll a field by 180 degrees of longitude along ``axis``.

    ``res`` is the number of grid points per degree, so the shift is
    ``180 * res`` points. The default ``axis=-1`` assumes longitude is the
    last dimension. Returns a new DataArray with ``da``'s dims and attrs
    (coordinates are not carried over).
    """
    return xr.DataArray(
        data=np.roll(da.values, -180 * res, axis=axis),
        dims=da.dims,
        attrs=da.attrs,
    )

In [None]:
def _tpot_s_from_woa(woa18_t, woa18_s, resolution):
    """Compute rolled potential temperature / salinity from opened WOA18 T and S datasets."""
    dtype = woa18_t.t_an.dtype
    # 1. Sea pressure (dbar) from height and latitude. gsw.p_from_z takes a
    #    height z that is positive UP (negative below the surface), whereas
    #    WOA `depth` is positive down -- hence the sign flip.
    p = xr.apply_ufunc(gsw.p_from_z, -woa18_t.depth, woa18_t.lat,
                       dask='parallelized', output_dtypes=[dtype])
    # 2. Absolute salinity SA from practical salinity; assumes WOA s_an is SP
    #    (can this be done better?). Using SA instead of reference salinity SR
    #    is the improvement over
    #    https://github.com/NCAR/WOA_MOM6/blob/a7131c2d9b89bb032e006d27db97679f6cd1f42c/create_filled_ic.py#L171-L172
    sa = xr.apply_ufunc(gsw.SA_from_SP, woa18_s.s_an, p, woa18_s.lon, woa18_s.lat,
                        dask='parallelized', output_dtypes=[dtype])
    # 3. Potential temperature referenced to 0 dbar from SA, in-situ
    #    temperature and pressure.
    ptemp = xr.apply_ufunc(gsw.pt0_from_t, sa, woa18_t.t_an, p,
                           dask='parallelized', output_dtypes=[dtype])

    # Roll to 0-360 degE and collect everything into a new dataset.
    woa18_ts = xr.Dataset()
    woa18_ts['lon'] = roll_lon_to_0360(woa18_t['lon'], res=resolution)
    woa18_ts['lon_bnds'] = roll_lon_to_0360(woa18_t['lon_bnds'], res=resolution)
    # roll_to_0360 rolls the last axis -- assumes lon is the last dimension.
    # NaNs become the 1e20 fill value used throughout this notebook.
    woa18_ts['ptemp'] = roll_to_0360(ptemp, res=resolution).fillna(1.0e+20)
    woa18_ts['salt'] = roll_to_0360(woa18_s['s_an'], res=resolution).fillna(1.0e+20)
    for var in ['time', 'lat_bnds', 'depth_bnds', 'nbounds']:
        # Load now: the source datasets are closed once the caller returns.
        woa18_ts[var] = woa18_t[var].load()
    # set_coords returns a new Dataset, so the result must be kept.
    coords = [c for c in ['time', 'lon', 'lat', 'depth', 'nbounds']
              if c in woa18_ts.variables]
    woa18_ts = woa18_ts.set_coords(coords)
    woa18_ts.attrs = dict(woa18_t.attrs)
    return woa18_ts


def build_woa_TPOT_S_dataset(period=0,
                             resolution=1,
                             url_temp=url_temp_1deg,
                             url_salt=url_salt_1deg):
    """Read WOA18 T/S over OPeNDAP, compute potential temperature, roll to 0-360 degE.

    Parameters
    ----------
    period : int
        WOA time-period code, zero-padded to two digits in the file name.
        This notebook uses 1-12 for monthly and 13-16 for seasonal fields;
        the default 0 is presumably the annual climatology (WOA naming
        convention -- confirm).
    resolution : int
        Grid points per degree. Used as the zero-padded file-name suffix
        (1 -> "01", 4 -> "04") and to size the 180-degree longitude roll.
    url_temp, url_salt : str
        OPeNDAP directories holding the temperature / salinity files; a
        trailing '/' is optional.

    Returns
    -------
    xr.Dataset
        ``ptemp`` (potential temperature, 0 dbar reference) and ``salt``
        (WOA ``s_an``), both rolled in longitude with NaN replaced by 1e20,
        plus ``lon``/``lon_bnds`` on 0-360 and ``time``, ``lat_bnds``,
        ``depth_bnds``, ``nbounds`` copied from the temperature file. The data
        is held in memory; the remote datasets are closed before returning.
    """
    cperiod = str(period).zfill(2)
    cres = str(resolution).zfill(2)
    # The module-level URLs end with '/'; strip it so the path has one separator.
    url_temp = url_temp.rstrip('/')
    url_salt = url_salt.rstrip('/')
    kwargs = dict(decode_times=False, engine='pydap')
    with xr.open_dataset(f'{url_temp}/woa18_decav_t{cperiod}_{cres}.nc', **kwargs) as woa18_t:
        with xr.open_dataset(f'{url_salt}/woa18_decav_s{cperiod}_{cres}.nc', **kwargs) as woa18_s:
            return _tpot_s_from_woa(woa18_t, woa18_s, resolution)

In [None]:
def vertical_blend_seas(ds_seasonal, ds_monthly):
    """Append the deep ocean from a seasonal dataset below a monthly one.

    The monthly fields cover only the upper levels. The seasonal levels below
    them (index ``ds_monthly.sizes["depth"]`` onward -- assumes both share the
    same leading depth levels, as the WOA18 standard levels do) are
    concatenated underneath along ``depth``. The seasonal ``time`` coordinate
    is replaced by the monthly one so the two pieces line up.

    ``ds_seasonal`` is not modified. Returns the blended dataset.
    """
    n_upper = ds_monthly.sizes["depth"]

    # Shallow copy: relabel time without touching the caller's dataset.
    ds_seasonal = ds_seasonal.copy()
    ds_seasonal["time"] = ds_monthly["time"]

    deep = ds_seasonal.isel(depth=slice(n_upper, None))
    ds_blend = xr.concat([ds_monthly, deep], dim='depth')
    # xr.concat may stack data variables that have no depth dimension (such
    # as the horizontal bounds) along depth; restore the originals.
    for var in ["lon_bnds", "lat_bnds"]:
        ds_blend[var] = ds_seasonal[var]
    return ds_blend

In [None]:
def vertical_remap(ds, depth_tgt, depth_bnds_tgt, spval=1.0e+20):
    """Remap ``ptemp`` and ``salt`` onto new vertical levels.

    Parameters
    ----------
    ds : xr.Dataset
        Must contain ``ptemp``, ``salt`` (with a length-1 ``time`` dim;
        assumed (time, depth, lat, lon)), ``depth_bnds`` and
        ``time``/``lon``/``lat``/``lon_bnds``/``lat_bnds``. Missing values
        are stored as ``spval``.
    depth_tgt : array-like
        Target level centres, stored as ``z_l``.
    depth_bnds_tgt : array-like
        Target level interfaces, stored as ``z_i``.
    spval : float
        Fill value in ``ds`` that is masked before remapping.

    Returns
    -------
    xr.Dataset
        Remapped ``ptemp`` and ``salt`` on (time, z_l, lat, lon), keeping the
        source variables' attrs, plus the horizontal grid, ``z_l``/``z_i``
        and ``ds.attrs``.
    """
    # Re-arrange the WOA18 (depth, 2) bounds into the 1-D interface form
    # expected by create_remapping_matrix.
    depth_bnds_src = utils.bounds_2d_to_1d(ds['depth_bnds'])
    remapping = conserve.create_remapping_matrix(depth_bnds_src, depth_bnds_tgt, strict=False)

    ds_remapped = xr.Dataset()
    for var in ['time', 'lon', 'lat', 'lon_bnds', 'lat_bnds']:
        ds_remapped[var] = ds[var]

    ds_remapped['z_l'] = xr.DataArray(data=depth_tgt, dims=('z_l',))
    ds_remapped['z_i'] = xr.DataArray(data=depth_bnds_tgt, dims=('z_i',))

    for var in ['ptemp', 'salt']:
        # Mask the fill value, drop the single time record for the remapper,
        # then restore it as the leading axis.
        field = ds[var].where(ds[var] != spval).squeeze(dim='time').values
        remapped = conserve.vertical_remap_z2z(field, remapping)
        ds_remapped[var] = xr.DataArray(data=np.expand_dims(remapped, axis=0),
                                        dims=('time', 'z_l', 'lat', 'lon'),
                                        attrs=ds[var].attrs)

    # set_coords returns a new Dataset, so the result must be kept.
    ds_remapped = ds_remapped.set_coords(["time", "lon", "lat", "z_l"])
    ds_remapped.attrs = ds.attrs
    return ds_remapped

In [None]:
def concat_monthly(filepattern, fileout, format='NETCDF3_64BIT'):
    """Concatenate all files matching a glob pattern into one netCDF file.

    Parameters
    ----------
    filepattern : str
        Glob pattern for the per-month input files.
    fileout : str
        Path of the concatenated output file.
    format : str
        netCDF format passed to ``Dataset.to_netcdf``.

    Every variable except ``nbounds`` is written with ``_FillValue = 1e20``.
    The bounds ``lon_bnds``/``lat_bnds``/``depth_bnds`` can pick up a ``time``
    dimension from the multi-file concatenation; only their first time slice
    is kept.

    Raises
    ------
    FileNotFoundError
        If no file matches ``filepattern``.
    """
    import glob

    # glob order is arbitrary; sort for a deterministic input order.
    ncfiles = sorted(glob.glob(filepattern))
    if not ncfiles:
        raise FileNotFoundError(f"no files match the pattern {filepattern!r}")

    with xr.open_mfdataset(ncfiles, decode_times=False) as dsm:
        for var in ["lon_bnds", "lat_bnds", "depth_bnds"]:
            if var in dsm.variables and "time" in dsm[var].dims:
                dsm[var] = dsm[var].isel(time=0)
        # nbounds is skipped -- presumably an integer index that cannot take
        # a 1e20 fill value (confirm).
        encoding = {var: dict(_FillValue=1.0e+20)
                    for var in dsm.variables if var != "nbounds"}
        dsm.to_netcdf(fileout, format=format, encoding=encoding)

## Pre-processing of WOA18
### Note: this blends monthly with seasonal data
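See the [WOA18 data access page](https://www.ncei.noaa.gov/access/world-ocean-atlas-2018/bin/woa18.pl) for how WOA _seasons_ are defined. The seasonal time periods are numbered 13–16 (13 = January–March, 14 = April–June, 15 = July–September, 16 = October–December), so each month is blended with the season that contains it.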

For both the 1-deg and 1/4-deg resolutions, the notebook will:

* Compute the potential temperature and write the data on the original grid (longitudes rolled to 0–360°E)


* Blend the monthly data (which only reach 1500 m) with the seasonal data below 1500 m to obtain full-depth arrays


* Remap to the 35 WOA05 z-levels for model–observation comparison, using the following target level centres and interfaces:

In [None]:
# Target vertical grid: the 35 WOA05 z-levels (metres, positive down).
# Every level centre in `depth_tgt` is the midpoint of the two neighbouring
# interfaces in `depth_bnds_tgt` (36 interfaces, 0 m to 6750 m).
depth_tgt = np.array([
    2.5, 10.0, 20.0, 32.5, 51.25,
    75.0, 100.0, 125.0, 156.25, 200.0,
    250.0, 312.5, 400.0, 500.0, 600.0,
    700.0, 800.0, 900.0, 1000.0, 1100.0,
    1200.0, 1300.0, 1400.0, 1537.5, 1750.0,
    2062.5, 2500.0, 3000.0, 3500.0, 4000.0,
    4500.0, 5000.0, 5500.0, 6000.0, 6500.0,
])


depth_bnds_tgt = np.array([
    0.0, 5.0, 15.0, 25.0, 40.0, 62.5,
    87.5, 112.5, 137.5, 175.0, 225.0, 275.0,
    350.0, 450.0, 550.0, 650.0, 750.0, 850.0,
    950.0, 1050.0, 1150.0, 1250.0, 1350.0, 1450.0,
    1625.0, 1875.0, 2250.0, 2750.0, 3250.0, 3750.0,
    4250.0, 4750.0, 5250.0, 5750.0, 6250.0, 6750.0,
])

In [None]:
bottom_fill = False  # forward-fill NaNs down each seasonal water column before blending?

### 1 degree grid

nckwargs = dict(format='NETCDF3_64BIT',
                encoding={"ptemp": {"_FillValue": 1.0e+20},
                          "salt": {"_FillValue": 1.0e+20}})

# Each WOA season (13-16) spans three consecutive months: download each
# seasonal dataset once and reuse it for the three months it covers.
ds_seas_raw = None
loaded_seas = None

#--- Monthly
for period in range(1, 12 + 1):
    cperiod = f'{period:02d}'

    # original data with potential temperature
    ds_monthly = build_woa_TPOT_S_dataset(period=period, resolution=1)
    ds_monthly.to_netcdf(f'WOA18_decav_TPOTS_m{cperiod}_01.nc', **nckwargs)
    time.sleep(2)

    # seasonal data that goes to _full_ depth: months 1-3 -> 13, 4-6 -> 14, ...
    int_seas = 13 + (period - 1) // 3
    if int_seas != loaded_seas:
        ds_seas_raw = build_woa_TPOT_S_dataset(period=int_seas, resolution=1)
        # 's' prefix, not 'm': an "m13".."m16" file would also match the
        # "WOA18_decav_TPOTS_m??_01.nc" pattern used to concatenate the
        # monthly files, mixing full-depth seasonal fields into that product.
        ds_seas_raw.to_netcdf(f'WOA18_decav_TPOTS_s{int_seas:02d}_01.nc', **nckwargs)
        loaded_seas = int_seas
        time.sleep(2)

    # Optionally fill (bottom) NaNs in each column with the last valid value above.
    ds_seas = ds_seas_raw.ffill(dim='depth') if bottom_fill else ds_seas_raw

    print("\nBlending month [%i] with season [%i] data.\n"%(period, int_seas))

    # blend with seasonal under 1500 meters, i.e., below depth[56]
    ds_monthly_blend = vertical_blend_seas(ds_seas, ds_monthly)
    ds_monthly_blend.to_netcdf(f'WOA18_decav_TPOTS_m{cperiod}_fulldepth_01.nc', **nckwargs)
    time.sleep(2)

    # vertical remap to 35 levels
    ds_monthly_remap = vertical_remap(ds_monthly_blend, depth_tgt, depth_bnds_tgt)
    ds_monthly_remap.to_netcdf(f'WOA18_decav_TPOTS_m{cperiod}_35lev_01.nc', **nckwargs)
    time.sleep(2)

In [None]:
## concatenate the 1-degree files: original grid, full depth, and 35 levels
for suffix in ("", "_fulldepth", "_35lev"):
    concat_monthly(f"WOA18_decav_TPOTS_m??{suffix}_01.nc",
                   f"WOA18_decav_TPOTS_monthly{suffix}_01.nc")

In [None]:
bottom_fill = False  # forward-fill NaNs down each seasonal water column before blending?

### 0.25 degree grid

nckwargs = dict(format='NETCDF3_64BIT',
                encoding={"ptemp": {"_FillValue": 1.0e+20},
                          "salt": {"_FillValue": 1.0e+20}})

# Each WOA season (13-16) spans three consecutive months: download each
# seasonal dataset once and reuse it for the three months it covers.
ds_seas_raw = None
loaded_seas = None

#--- Monthly
for period in range(1, 12 + 1):
    cperiod = f'{period:02d}'

    # original data with potential temperature
    ds_monthly = build_woa_TPOT_S_dataset(period=period, resolution=4,
                                          url_temp=url_temp_025deg,
                                          url_salt=url_salt_025deg)
    ds_monthly.to_netcdf(f'WOA18_decav_TPOTS_m{cperiod}_025.nc', **nckwargs)
    time.sleep(2)

    # seasonal data that goes to _full_ depth: months 1-3 -> 13, 4-6 -> 14, ...
    int_seas = 13 + (period - 1) // 3
    if int_seas != loaded_seas:
        ds_seas_raw = build_woa_TPOT_S_dataset(period=int_seas, resolution=4,
                                               url_temp=url_temp_025deg,
                                               url_salt=url_salt_025deg)
        loaded_seas = int_seas
        # The seasonal 0.25-degree fields are not saved by default. If this is
        # enabled, keep the 's' prefix: an "m13".."m16" name would match the
        # "m??_025.nc" concatenation pattern.
        # ds_seas_raw.to_netcdf(f'WOA18_decav_TPOTS_s{int_seas:02d}_025.nc', **nckwargs)
        time.sleep(2)

    # Optionally fill (bottom) NaNs in each column with the last valid value above.
    ds_seas = ds_seas_raw.ffill(dim='depth') if bottom_fill else ds_seas_raw

    print("\nBlending month [%i] with season [%i] data.\n"%(period, int_seas))

    # blend with seasonal under 1500 meters, i.e., below depth[56]
    ds_monthly_blend = vertical_blend_seas(ds_seas, ds_monthly)
    # Same encoding as the other outputs, so the 1e20 fill is declared as _FillValue.
    ds_monthly_blend.to_netcdf(f'WOA18_decav_TPOTS_m{cperiod}_fulldepth_025.nc', **nckwargs)
    time.sleep(2)

    # vertical remap to 35 levels
    ds_monthly_remap = vertical_remap(ds_monthly_blend, depth_tgt, depth_bnds_tgt)
    ds_monthly_remap.to_netcdf(f'WOA18_decav_TPOTS_m{cperiod}_35lev_025.nc', **nckwargs)
    time.sleep(2)

In [None]:
## concatenate the 0.25-degree files (written as NETCDF4): original grid, full depth, 35 levels
for suffix in ("", "_fulldepth", "_35lev"):
    concat_monthly(f"WOA18_decav_TPOTS_m??{suffix}_025.nc",
                   f"WOA18_decav_TPOTS_monthly{suffix}_025.nc",
                   format='NETCDF4')

### ------ The End ------