## Drift correction of DPLE ensemble-mean files

In [1]:
import xarray as xr 
import numpy as np  
import os
import cftime
import copy
import scipy.stats
from scipy import signal
import cartopy.crs as ccrs
import glob
import dask

### Functions

In [5]:
def nested_file_list_by_year(filetemplate, field, firstyear, lastyear):
    ''' Collect one file path per start year in [firstyear, lastyear].

    Parameters
    ----------
    filetemplate : str
        Glob pattern passed to file_dict(), which maps start year -> path.
    field : str
        Not used by this function; kept so existing calls keep working.
    firstyear, lastyear : int
        Inclusive range of start years to look up.

    Returns
    -------
    files : list of str
        Flat list of paths, in ascending start-year order.
    years : numpy.ndarray
        The start years that contributed a path; always the same length as
        `files`, so the two can be paired element by element.
    '''
    yrs = np.arange(firstyear, lastyear + 1)
    # The glob result does not depend on the year, so look it up only once.
    filepaths = file_dict(filetemplate)
    files = []
    keep = np.zeros(yrs.shape, dtype=bool)
    previous = ''
    for i, yy in enumerate(yrs):
        path = filepaths.get(int(yy))
        # Drop years that have no file (previously these stayed in the
        # returned year array and made it longer than `files`) as well as a
        # path identical to the one just appended.
        if path is None or path == previous:
            continue
        files.append(path)
        keep[i] = True
        previous = path
    return files, yrs[keep]

In [6]:
def file_dict(filetempl):
    ''' Map initialization year -> file path for every file matching a glob.

    filetempl is a glob pattern. For each match, the text before the first
    '.pop.h.' is taken and its characters [-7:-3] are converted to int, so
    that text is expected to end in something like 'YYYY-MM'. If two matches
    give the same year, the one returned later by glob wins.
    '''
    return {
        # the slice picks the four year digits out of a trailing 'YYYY-MM'
        int(path.partition('.pop.h.')[0][-7:-3]): path
        for path in glob.glob(filetempl)
    }

## Main processing

In [7]:
# Variable to process. Alternatives left by the author: 'TEMP', 'O2', 'SALT'.
field = 'AOU'
# Location of the input files and the case-name glob; the four '?' wildcards
# sit where file_dict() reads the start year from each matching path.
datadir = '/glade/scratch/czhuomin/DPLE-ens-mean'
casename = 'b.e11.BDP.f09_g16.????-11'
filetemplate = f'{datadir}/{casename}.pop.h.{field}.nc'
# Inclusive range of start years to look for.
firstyear, lastyear = 1954, 2017
# File paths for the chosen field, plus the array of start years returned
# alongside them.
files, yrs = nested_file_list_by_year(filetemplate, field, firstyear, lastyear)

### combine all datasets

In [8]:
# Open all files as one dataset, concatenated along a new 'Y' dimension in
# the order of `files`. Only `field` is listed for concatenation, and the
# 'override'/'minimal' options skip the usual cross-file comparisons.
d0 = xr.open_mfdataset(
    files,
    combine='nested',
    concat_dim='Y',
    data_vars=[field],
    coords='minimal',
    compat='override',
    join='override',
    chunks={},
    parallel=True,
)  # a preprocess=preprocess argument was tried here and left disabled

In [9]:
# Label the concatenation dimension with the start years.
d0 = d0.assign_coords(Y=("Y", yrs))
# Lead-time index 1..N along the 'time' axis. N is read from the data rather
# than hard-coded (it was fixed at 122), because assign_coords requires the
# coordinate length to equal the size of 'time' anyway.
leadtimes = np.arange(1, d0.sizes['time'] + 1)
d0 = d0.assign_coords(L=("time", leadtimes))
# Make lead time the dimension in place of 'time' ...
d0 = d0.swap_dims({'time': 'L'})
# ... and turn the former 'time' coordinate into an ordinary data variable.
d0 = d0.reset_coords(["time"])
#d0

## Drift Correction

In [10]:
# Start from an empty Dataset and copy over only the variable being processed.
ds = xr.Dataset()
ds[field] = d0[field]
# Load this in memory to speed up later computations
# NOTE(review): for dask-backed data, persist() starts the computation and
# keeps the result as dask arrays held in memory — confirm this is intended.
ds = ds.persist()

In [11]:
%%time
# Mean over all start years (dimension Y). The result keeps the remaining
# dimensions, including lead time L, and is what the next cell subtracts.
climodrift = ds[field].mean('Y')

CPU times: user 3.97 ms, sys: 0 ns, total: 3.97 ms
Wall time: 42.9 ms


In [12]:
anos = xr.Dataset()
# Disabled intermediate calculation, kept for reference:
#ano_tmp = ds[field] - climodrift
#ano_dif = ds[field].isel(L=0) - ano_tmp.isel(L=0)
# Subtract the start-year mean at every lead time, then add back the
# start-year mean at the first lead time (L index 0). The stored values are
# therefore offset by that first-lead mean rather than centred on zero.
anos[field] = ds[field] - climodrift + climodrift.isel(L=0)

### write out the data

In [13]:
# Output goes under the current user's scratch directory. The lookup raises
# KeyError if the USER environment variable is not set.
USER = os.environ['USER']
dout = f'/glade/scratch/{USER}/DPLE-results'
# Create the directory if it is missing; an existing one is not an error.
os.makedirs(dout, exist_ok=True)

In [14]:
%%time
# Trigger the pending computation and load the result into memory (load()
# works in place) before the file is written in the next cell.
anos.load()

CPU times: user 10min 50s, sys: 8min 41s, total: 19min 32s
Wall time: 19min 58s


In [15]:
# Write the drift-corrected field to one netCDF file named after `field`;
# mode='w' replaces any file already at that path.
anos.to_netcdf(f'{dout}/DPLE_driftcorrected_{field}_ens_mean.nc', mode='w')