# Regridding by multiprocessing

Building on [chapter04](https://github.com/zxdawn/GEOSChem-python-tutorial/blob/master/Chapter04_regridding_WRFChem_part1.ipynb), we can regrid many files in parallel with `multiprocessing`.

Because there are so many files, it's better to comment out the `reuse_weights` reminder messages that xESMF prints.

Edit lines 189–204 of `site-packages/xesmf/frontend.py` as follows:

```    
    def _write_weight_file(self):

        if os.path.exists(self.filename):
            if self.reuse_weights:
                #print('Reuse existing file: {}'.format(self.filename))
                return  # do not compute it again, just read it
            else:
                #print('Overwrite existing file: {} \n'.format(self.filename),
                #      'You can set reuse_weights=True to save computing time.')
                os.remove(self.filename)
        #else:
        #    print('Create weight file: {}'.format(self.filename))

        regrid = esmf_regrid_build(self._grid_in, self._grid_out, self.method,
                                   filename=self.filename)
        esmf_regrid_finalize(regrid)  # only need weights, not regrid object
```       

In [None]:
import os, fnmatch
import multiprocessing
from multiprocessing import Pool
import numpy as np
import xarray as xr
import xesmf as xe
from xgcm.autogenerate import generate_grid_ds
from xgcm import Grid

Prepare the global variables used by the worker processes:

In [None]:
# Target grid: a regular lon/lat grid over 20–50°N and 115–65°W.
resolution = 0.1  # grid spacing, in degrees
Lat_min = 20
Lat_max = 50
Lon_min = -115
Lon_max = -65

# The nested grid we want. Each range is padded by one grid step on both
# sides (lower bound - resolution, upper bound + resolution).
nested_grid = xe.util.grid_2d(
    Lon_min - resolution, Lon_max + resolution, resolution,  # longitude bounds and step
    Lat_min - resolution, Lat_max + resolution, resolution,  # latitude bounds and step
)

wrf_dir = '/chenq3/zhangxin/BEHR/data/wrf_profiles/us/2014/all/'  # directory holding the wrfout* files
save_dir = '/chenq3/zhangxin/BEHR/data/wrf_profiles/us/2014/regrid/'  # output directory for regridded files
vnames = ['IC_FLASHCOUNT', 'CG_FLASHCOUNT', 'no2', 'lno', 'lno2']  # variables to extract and regrid

Define the helper functions, following [chapter04](https://github.com/zxdawn/GEOSChem-python-tutorial/blob/master/Chapter04_regridding_WRFChem_part1.ipynb):

In [None]:
def add_attrs(result):
    """Annotate the regridded variables in *result* with metadata, in place.

    Sets ``attrs['description']`` on ``no2``, ``lno``, ``lno2``,
    ``IC_FLASHCOUNT``, ``CG_FLASHCOUNT`` and ``TL_FLASHCOUNT``, and
    ``attrs['units']`` on the three chemical tracers. *result* must support
    ``result[name].attrs`` for every one of those names; a missing name
    raises ``KeyError``. Returns None.
    """
    tracer_descriptions = (
        ('no2', 'NO2 mixing ratio'),
        ('lno', 'LNO mixing ratio'),
        ('lno2', 'LNO2 mixing ratio'),
    )
    for name, text in tracer_descriptions:
        result[name].attrs['description'] = text
    for name, _ in tracer_descriptions:
        result[name].attrs['units'] = 'ppmv'

    flash_descriptions = (
        ('IC_FLASHCOUNT', 'Accumulated IC flash count'),
        ('CG_FLASHCOUNT', 'Accumulated CG flash count'),
        ('TL_FLASHCOUNT', 'Accumulated Total flash count'),
    )
    for name, text in flash_descriptions:
        result[name].attrs['description'] = text


def bilinear_regridding(ds, nested_grid):
    """Bilinearly regrid the variables named in ``vnames`` onto *nested_grid*.

    Parameters
    ----------
    ds : xarray.Dataset
        Source dataset with ``lon``/``lat`` coordinates (``regrid`` renames
        ``XLONG``/``XLAT`` to these before calling).
    nested_grid : xarray.Dataset
        Destination grid, e.g. built with ``xe.util.grid_2d``.

    Returns
    -------
    tuple
        ``(regridder, result)``: the ``xe.Regridder`` used, and an
        ``xarray.Dataset`` holding the regridded variables plus a derived
        ``TL_FLASHCOUNT`` = ``CG_FLASHCOUNT`` + ``IC_FLASHCOUNT``, with
        metadata attached by ``add_attrs``.
    """
    # With reuse_weights=True an existing weight file is read instead of
    # being recomputed, so only the first call pays for building weights.
    regridder_bilinear = xe.Regridder(ds, nested_grid, method='bilinear', reuse_weights=True)

    # Regrid only the requested variables, preserving the dataset's order.
    # NOTE: xESMF v0.2+ can regrid a whole Dataset directly.
    bilinear_list = [regridder_bilinear(dr)
                     for varname, dr in ds.data_vars.items()
                     if varname in vnames]
    bilinear_result = xr.merge(bilinear_list)  # list of DataArrays -> one Dataset

    # Total flash count: TL = IC + CG
    bilinear_result['TL_FLASHCOUNT'] = bilinear_result.CG_FLASHCOUNT + bilinear_result.IC_FLASHCOUNT

    add_attrs(bilinear_result)

    return regridder_bilinear, bilinear_result


def _wrf_grid_with_bounds(ds):
    """Build an xESMF grid dict (cell centres and corners) for *ds*.

    Conservative regridding needs cell boundaries to compute overlapping
    areas. They are derived here from the centre ``lon``/``lat`` coordinates
    with xgcm:

    1. ``generate_grid_ds`` adds 'outer' dimensions along ``west_east`` (X)
       and ``south_north`` (Y).
    2. The centres are interpolated along X and then along Y onto those
       outer points, extrapolating at the domain edges.

    See https://gist.github.com/jbusecke/175d72d81e13f7f8d4dcf26aace511bd

    Returns a dict with keys ``lon``, ``lat``, ``lon_b`` and ``lat_b`` (numpy
    arrays), as accepted by ``xe.Regridder``.
    """
    ds_post = generate_grid_ds(ds, {'X': 'west_east', 'Y': 'south_north'}, position=('center', 'outer'))
    grid_ds = Grid(ds_post, periodic=False)
    bnd = 'extrapolate'

    def _to_corners(da):
        # Interpolate along X, then along Y, onto the outer (corner) points.
        along_x = grid_ds.interp(da, 'X', boundary=bnd, fill_value=np.nan)
        return grid_ds.interp(along_x, 'Y', boundary=bnd, fill_value=np.nan)

    ds_post.coords['xb'] = _to_corners(ds_post['lon'])
    ds_post.coords['yb'] = _to_corners(ds_post['lat'])

    return {'lon': ds['lon'].values,
            'lat': ds['lat'].values,
            'lon_b': ds_post.xb.data,
            'lat_b': ds_post.yb.data,
            }


def conservative_regridding(ds, nested_grid):
    """Conservatively regrid the variables named in ``vnames`` onto *nested_grid*.

    Parameters
    ----------
    ds : xarray.Dataset
        Source dataset with ``lon``/``lat`` coordinates and ``west_east`` /
        ``south_north`` dimensions (used to derive the cell bounds).
    nested_grid : xarray.Dataset
        Destination grid, e.g. built with ``xe.util.grid_2d``.

    Returns
    -------
    tuple
        ``(regridder, result)``: the ``xe.Regridder`` used, and an
        ``xarray.Dataset`` holding the regridded variables plus a derived
        ``TL_FLASHCOUNT`` = ``CG_FLASHCOUNT`` + ``IC_FLASHCOUNT``, with
        metadata attached by ``add_attrs``.
    """
    wrf_grid_with_bounds = _wrf_grid_with_bounds(ds)

    # With reuse_weights=True an existing weight file is read instead of
    # being recomputed.
    regridder_conserve = xe.Regridder(wrf_grid_with_bounds, nested_grid, method='conservative', reuse_weights=True)

    # Regrid only the requested variables, preserving the dataset's order.
    conservative_list = [regridder_conserve(dr)
                         for varname, dr in ds.data_vars.items()
                         if varname in vnames]
    conservative_result = xr.merge(conservative_list)

    # Total flash count: TL = IC + CG
    conservative_result['TL_FLASHCOUNT'] = conservative_result.CG_FLASHCOUNT + conservative_result.IC_FLASHCOUNT

    add_attrs(conservative_result)

    return regridder_conserve, conservative_result

def regrid(filename):
    """Regrid one WRF output file bilinearly and conservatively; save both.

    Reads ``wrf_dir/filename`` and keeps the first time step. It renames
    ``XLONG``/``XLAT`` to ``lon``/``lat`` (the names the regridding helpers
    read), then writes two files:

    * ``save_dir/<filename>_conservative_regridding.nc``
    * ``save_dir/<filename>_bilinear_regridding.nc``

    Parameters
    ----------
    filename : str
        Base name of a file inside ``wrf_dir``.
    """
    # Open the dataset as a context manager so its file handle is released
    # when this file is done; pool workers process many files each.
    with xr.open_dataset(os.path.join(wrf_dir, filename)) as raw:
        # Dataset.rename(..., inplace=True) is deprecated/removed in newer
        # xarray, so rebind the returned dataset instead.
        ds = raw.isel(Time=0).rename({'XLONG': 'lon', 'XLAT': 'lat'})

        _, bilinear_result = bilinear_regridding(ds, nested_grid)
        _, conservative_result = conservative_regridding(ds, nested_grid)

        # Write while the source file is still open, in case any result is
        # still lazily backed by it.
        conservative_result.to_netcdf(os.path.join(save_dir, filename + '_conservative_regridding.nc'))
        bilinear_result.to_netcdf(os.path.join(save_dir, filename + '_bilinear_regridding.nc'))

This is the main function:

In [None]:
def main():
    """Regrid every ``wrfout*`` file in ``wrf_dir`` using a process pool.

    The first file is regridded serially so the weight files are written
    once. The remaining files are then spread over one worker per CPU core;
    each worker reuses those weights (the regridders use
    ``reuse_weights=True``). Does nothing if no matching file exists.
    """
    # Sort for a deterministic processing order (os.listdir order is arbitrary).
    filenames = sorted(fnmatch.filter(os.listdir(wrf_dir), 'wrfout*'))
    if not filenames:
        return  # nothing to regrid

    # Regrid the first file on its own to create the weight files. Otherwise
    # several workers could find no weight file at the same time and each
    # try to build and write it.
    regrid(filenames[0])

    # The context manager shuts the pool's worker processes down after map()
    # has returned all results.
    # NOTE(review): if ESMF/NumPy already use several threads per process,
    # or the run is I/O-bound on the shared filesystem, more processes may not
    # shorten the wall time; profile before tuning the worker count.
    with Pool(multiprocessing.cpu_count()) as p:
        p.map(regrid, filenames[1:])

Let's check how long the main function takes to finish.
Without multiprocessing, it takes about `50 minutes` to process `1720` wrfout* files (12 km resolution, 430 × 345 grid).

In [None]:
# %%time
# if __name__ == '__main__':
#     main()

Since the computing node has `24 cores`, I decided to run the script directly and print the elapsed time here.

But the elapsed time is almost the same as before...

I'm trying to figure out why...