# By-Band _g_-Point Reduction

# Dependencies

`numpy` is installed in the Python environment at NERSC (`module load python`), but `xarray` is not, so users must install it themselves (`pip install xarray`). `PIPPATH` is the directory where that installation is assumed to live, and it is appended to `sys.path` below. This notebook depends heavily on `xarray`.

In [None]:
import os, sys

# "standard" install
import numpy as np

# directory in which user-installed libraries (e.g., `pip install xarray`)
# are saved
PIPPATH = '{}/.local/'.format(os.path.expanduser('~')) + \
    'cori/3.7-anaconda-2019.10/lib/python3.7/site-packages'
PATHS = ['common', PIPPATH]
for path in PATHS: sys.path.append(path)

# user must do `pip install xarray` on cori (or other NERSC machines)
import xarray as XA

# common submodule
import utils


# Function and Class Definitions

In [None]:
def pathCheck(path, mkdir=False):
    """
    Verify that a path exists, or create it as a directory tree

    Inputs
        path -- string, file or directory path

    Keywords
        mkdir -- boolean; if True, create `path` and any missing parent
          directories (like `mkdir -p`) instead of checking for it

    Raises
        AssertionError -- if `mkdir` is False and `path` does not exist.
          Raised explicitly rather than with an `assert` statement so the
          check is not stripped when Python runs with -O
    """

    if mkdir:
        # mkdir -p -- create dir tree; exist_ok guards against another
        # process creating the directory between the check and the call
        if not os.path.exists(path): os.makedirs(path, exist_ok=True)
    elif not os.path.exists(path):
        raise AssertionError('Could not find {}'.format(path))
    # endif mkdir
# end pathCheck

def costFuncComp(tst_file, ref_file, levs=None, iRecord=0, ncVars=None):
    """
    Calculate same cost function components as `cost_function_components`,
    but allow for many more terms, yielding a more flexible cost function

    No forcing yet

    Inputs
        tst_file -- string, RRTMGP (test model) netCDF file with fluxes
        ref_file -- string, LBLRTM (reference model) netCDF file with fluxes

    Output
        outParams -- list with 1 element per input variable (ncVars); each
          element is the squared test-ref difference for record `iRecord`,
          averaged over columns and the pressure dimension. No square root
          is taken here (normCost() applies it after normalization)

    Keywords
        levs -- list of floats; pressure levels of interest in Pa
          (default [0, 10000, 102000])
        iRecord -- int; index for forcing scenario (default 0 is no forcing)
        ncVars -- list of strings; netCDF variable names of the arrays to
          include in the cost function
          (default ['net_flux', 'heating_rate', 'band_flux_net'])
    """

    # None sentinels instead of shared mutable list defaults
    if levs is None:
        levs = [0, 10000, 102000]
    if ncVars is None:
        ncVars = ['net_flux', 'heating_rate', 'band_flux_net']

    outParams = []
    with XA.open_dataset(tst_file) as tst, XA.open_dataset(ref_file) as ref:
        # Compute differences in all variables in datasets at levels
        # closest to user-provided pressure levels
        # TODO: confirm this is doing what we expect it to
        subsetErr = (tst-ref).sel(lev=levs, method='nearest')
        for ncVar in ncVars:
            # pressure dimension will depend on parameter
            # layer for HR, level for everything else
            # NOTE(review): `sel(lev=...)` above only subsets variables with
            # a `lev` dimension, so layer (`lay`) variables are presumably
            # averaged over all layers -- confirm that is intended
            pStr = 'lay' if 'heating_rate' in ncVar else 'lev'

            # squared test-ref error for the requested forcing scenario,
            # averaged over all columns and the pressure dimension
            ncParam = subsetErr[ncVar]
            outParams.append(
                (ncParam.isel(record=iRecord)**2).mean(dim=('col', pStr)))
        # end ncVar loop
    # endwith

    return outParams
# end costFuncComp
    
def normCost(tst_file, ref_file, norm, ncVars=None, levs=None, iRecord=0):
    """
    Returns the summary terms in the cost function
      Each element in each term is normalized (normally by the error at
      iteration 0)

    Inputs
        tst_file -- string, RRTMGP (test model) netCDF file with fluxes
        ref_file -- string, LBLRTM (reference model) netCDF file with fluxes
        norm -- sequence with one entry per cost function component
          (ncVars, same order). Each entry divides the squared error that
          costFuncComp() returns for that component, so it is presumably
          costFuncComp() output from iteration 0 -- TODO confirm. Entries
          may be scalars or arrays that broadcast against the component

    Output
        list with one element per component: the square root of the mean
        of the normalized squared error (RRTMGP-LBLRTM). zip() stops at
        the shorter of the component list and `norm`

    Keywords
        ncVars -- list of strings; netCDF variable names of the arrays to
          include in the cost function
          (default ['net_flux', 'heating_rate', 'band_flux_net'])
        levs -- list of floats; pressure levels of interest in Pa
          (default [0, 10000, 102000])
        iRecord -- int; index for whatever the 'record' dimension is in
          the input netCDF files (default 0)
    """

    # resolve sentinels here so costFuncComp() always receives lists
    if ncVars is None:
        ncVars = ['net_flux', 'heating_rate', 'band_flux_net']
    if levs is None:
        levs = [0, 10000, 102000]

    tst_cost = costFuncComp(
        tst_file, ref_file, levs=levs, iRecord=iRecord, ncVars=ncVars)

    # Each scalar term in the cost function is the RMS across the
    #   normalized error in each component. costFuncComp() returns
    #   the squared error
    return [np.sqrt((c/n).mean()) for (c, n) in zip(tst_cost, norm)]
# end normCost

def recordDimRename(inNC, outNC):
    """
    Copy a netCDF file, renaming its "record" dimension to "forcing"

    Inputs
        inNC -- string, path of the netCDF file to read
        outNC -- string, path of the netCDF file to write (overwritten if
          it exists)

    Variables without a "record" dimension and all global attributes are
    copied unedited. Prints a completion message with the output path.
    """

    outDS = XA.Dataset()

    with XA.open_dataset(inNC) as inObj:
        # save global attributes for later -- will stuff into buffer, unedited
        globalAtt = dict(inObj.attrs)

        for ncVar in list(inObj.keys()):
            # load into memory so the data no longer depends on inNC, which
            # is closed before writing (and may be the same path as outNC)
            ncDat = inObj[ncVar].load()

            if 'record' in ncDat.dims:
                # same dimensions, with `record` replaced by `forcing`
                dims = [
                    'forcing' if dim == 'record' else dim
                    for dim in ncDat.dims]
                outDS[ncVar] = XA.DataArray(ncDat, dims=dims)
            else:
                # retain any variables without a record dimension
                outDS[ncVar] = XA.DataArray(ncDat)
            # endif record
        # end ncVar loop
    # endwith

    # stuff the global attributes into the new dataset
    outDS.attrs.update(globalAtt)
    outDS.to_netcdf(outNC, mode='w')
    print('Completed {}'.format(outNC))
# end recordDimRename()

def kDistBandSplit(kFileNC, outDir='band_k_dist', weights=None):
    """
    Split a full k-distribution into separate files for each band

    Inputs
        kFileNC -- string, path of the full k-distribution netCDF

    Output
        bandFiles -- list of strings, paths of the netCDF files written,
          one per band in the order of the `bnd` values

    Keywords
        outDir -- string, directory for the band files (created if needed)
        weights -- list of floats; g-point quadrature weights written to
          every band file as `gpt_weights`. Defaults to the hard-coded
          16-element list, which assumes every band has len(weights)
          g-points -- TODO confirm for other k-distributions

    Each band file contains every variable of the input, with variables
    that have a `gpt` dimension restricted to that band's g-points, plus
    the input's global attributes. File names always use `lw`.
    """

    pathCheck(outDir, mkdir=True)

    if weights is None:
        weights = [
            0.1527534276, 0.1491729617, 0.1420961469, 0.1316886544,
            0.1181945205, 0.1019300893, 0.0832767040, 0.0626720116,
            0.0424925000, 0.0046269894, 0.0038279891, 0.0030260086,
            0.0022199750, 0.0014140010, 0.0005330000, 0.0000750000
        ]
    # endif weights
    xaWeights = XA.DataArray(weights, dims=['gpt'], name='gpt_weights')

    bandFiles = []
    with XA.open_dataset(kFileNC) as kAllObj:
        gLims = kAllObj.bnd_limits_gpt
        ncVars = list(kAllObj.keys())
        dimStr = 'gpt'

        for iBand in kAllObj.bnd.values:
            # make a separate netCDF for each band
            outNC = '{}/coefficients_lw_band{:02d}.nc'.format(outDir, iBand+1)

            # g-point limits of this band converted to zero-offset; the
            # upper limit is inclusive, hence the +1 in the slice
            i1, i2 = gLims[iBand].values-1
            gSlice = slice(i1, i2+1)

            # Dataset that will be written to netCDF with new variables and
            # unedited global attributes
            outDS = XA.Dataset(attrs=dict(kAllObj.attrs))

            for ncVar in ncVars:
                ncDat = kAllObj[ncVar]

                # grab only the g-point information for this band
                if dimStr in ncDat.dims: ncDat = ncDat.isel(gpt=gSlice)

                # write variable to output dataset
                outDS[ncVar] = XA.DataArray(ncDat)
            # end ncVar loop

            # write weights to output file
            outDS['gpt_weights'] = xaWeights

            outDS.to_netcdf(outNC, mode='w')
            bandFiles.append(outNC)
        # end band loop
    # endwith

    return bandFiles
# end kDistBandSplit()

class kDistOptBand:
    """
    Optimize the g-point combinations within a single band

    Planned workflow (only the g-point pairing step is implemented so far;
    runBandRRTMGP() is still a stub):
    - Run a RRTMGP executable that performs computations for a single band
    - Loop over the possible g-point combinations within the band, creating
        k-distribution and band-wise flux files for each possible
        combination
    - Compute broadband fluxes and heating rates
    - Compute cost function from broadband parameters and determine
        optimal combination of g-points
    """

    def __init__(self, inFile, band, lw, idxForce):
        """
        Input
          inFile -- string, netCDF created with kDistBandSplit()
          band -- int, band number that is being processed with object
          lw -- boolean, do longwave domain (otherwise shortwave)
          idxForce -- int, index of forcing scenario

        Creates the directory `workdir_band_<band>` under the current
        working directory if it does not already exist
        """

        self.inNC = str(inFile)
        self.iBand = int(band)
        self.domain = 'LW' if lw else 'SW'
        self.iForce = int(idxForce)

        # directory where model will be run for each g-point
        # combination
        self.workDir = '{}/workdir_band_{}'.format(os.getcwd(), self.iBand)
        pathCheck(self.workDir, mkdir=True)

        # attributes that get (re)assigned by gPointCombine()
        self.gCombine = []  # [g, g+1] zero-offset g-point index pairs
        self.wCombine = []  # gpt_weights subset for each pair in gCombine
        self.trialNC = []   # trial k-distribution path for each pair
    # end constructor

    def gPointCombine(self):
        """
        Combine each g-point in the band with its adjacent g-point

        Sets
          gCombine -- list of [g, g+1] index pairs (zero-offset)
          wCombine -- list of gpt_weights DataArrays, one per pair
          trialNC -- list of trial netCDF paths in workDir, one per pair
            (the files themselves are not written here)

        Calling this again replaces (does not extend) the three lists.

        TODO: will probably have to modify other variables in
        self.inNC (presumably kmajor) like Ben does in combine_gpoints_fn.py
        """

        with XA.open_dataset(self.inNC) as kDS:
            # load into memory so the subsets kept in wCombine remain
            # usable after the file is closed
            weights = kDS.gpt_weights.load()
            nGpt = kDS.sizes['gpt']
        # endwith

        # combine nearest neighbor g-point indices
        # and associated weights
        self.gCombine = [[x, x+1] for x in range(nGpt-1)]
        self.wCombine = [weights[np.array(gc)] for gc in self.gCombine]

        # one trial k-distribution file name per combination
        self.trialNC = [
            '{}/coefficients_{}_g{}-{}.nc'.format(
                self.workDir, self.domain, gc[0], gc[1])
            for gc in self.gCombine]
    # end gPointCombine()

    def runBandRRTMGP(self):
        """
        Run the RRTMGP executable for a single band

        Not yet implemented: currently a no-op that returns None
        """
    # end runBandRRTMGP()

# end kDistOptBand

def computeBB():
    """
    Compute broadband fluxes once the g-points have been combined

    Placeholder: not yet implemented, so it does nothing and returns None
    """
    return None
# end computeBB()


# Paths

In [None]:
# directory with the reference netCDF inputs for g-point reduction
PROJECT = (
    '/global/project/projectdirs/e3sm/'
    'pernak18/reference_netCDF/g-point-reduce'
)

# full (g256) longwave k-distribution; stop early if it cannot be found
kFullNC = '{}/rrtmgp-data-lw-g256-2018-12-04.nc'.format(PROJECT)
pathCheck(kFullNC)

# Static Inputs

In [None]:
# exactly one spectral domain is processed: longwave when doLW is True,
# shortwave otherwise
doLW = True
doSW = not doLW

# index of the forcing scenario (0 is no forcing...need a more
# comprehensive list)
IFORCING = 0

# Main Driver

In [None]:
# Split the full k-distribution into one netCDF per band
print('Band splitting commenced')
kFiles = kDistBandSplit(kFullNC)
print('Band splitting completed')

# For every band file, build an optimization object (band numbers handed
# to the object start at 1), pair up its g-points, and list the trial
# k-distribution files that will be produced
for iBand, kFile in enumerate(kFiles):
    kObj = kDistOptBand(kFile, iBand+1, doLW, IFORCING)
    kObj.gPointCombine()

    for outNC in kObj.trialNC:
        print(outNC)
    # end trial file loop
# end kFile loop

# small edit to flux file -- rename the `record` dimension