Grid format and type for the prototype
=======================

### Basic Imports

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import textwrap
import numpy as np
import xarray as xr
import xesmf as xe
import cf_xarray as cfxr
import cartopy.crs as ccrs

print("Using xESMF in version %s" % xe.__version__)

xr.set_options(display_style='html');

import warnings
#warnings.simplefilter("ignore")

Using xESMF in version 0.5.3.dev11+gcb46501


### Define function to detect the grid format

In [2]:
def detect_format(ds):
    #TODO: Extend for formats CF, xESMF, ESMF, UGRID, SCRIP
    SCRIP_vars=["grid_center_lat", "grid_center_lon",
                "grid_corner_lat", "grid_corner_lon",
                "grid_dims",
                #"grid_imask", "grid_area"
                ]
    SCRIP_dims=["grid_corners", "grid_size", "grid_rank"]
    
    xESMF_vars=["lat", "lon", 
                "lat_b", "lon_b",
                #"mask",               
               ]
    xESMF_dims=["x", "y", "x_b", "y_b"]

    # Test if SCRIP
    if all([var in ds.data_vars for var in SCRIP_vars]) and \
    all([dim in ds.dims for dim in SCRIP_dims]):
        print("SCRIP")
    # Test if xESMF
    elif all([var in ds.coords for var in xESMF_vars]) and \
    all([dim in ds.dims for dim in xESMF_dims]):
        print("xESMF")
    # Test if CF standard_names latitude and longitude can be found
    elif cfxr.accessor._get_with_standard_name(ds, "latitude")!=[] and \
    cfxr.accessor._get_with_standard_name(ds, "longitude")!=[]:
        print("CF")
    else:
        print("unsupported")

In [3]:
ds=xr.open_dataset("../../target_grids/cmip6_361x576_scrip.20181001.nc")
#ds.dims.grid_corners
#ds.data_vars
ds

In [4]:
detect_format(ds)

SCRIP


In [5]:
ds=xe.util.grid_global(1.,1.)
ds

In [6]:
detect_format(ds)

xESMF


In [7]:
ds=xr.open_dataset("../../target_grids/land_sea_mask_05degree.nc4")
#ds.dims
#ds.data_vars
ds

In [8]:
detect_format(ds)

CF


### Define functions to detect the grid type

In [9]:
def detect_type(ds, grid_format="CF"):
    if grid_format=="CF":
        lat_var=cfxr.accessor._get_with_standard_name(ds, "latitude")[0]
        lon_var=cfxr.accessor._get_with_standard_name(ds, "longitude")[0]
        try:
            lat_bnds=ds[lat_var].attrs["bounds"]
        except KeyError:
            lat_bnds=""
            warnings.warn("The latitude variable does not have bounds.")       
        try:
            lon_bnds=ds[lon_var].attrs["bounds"]
        except KeyError:
            lon_bnds=""
            warnings.warn("The longitude variable does not have bounds.")   
        if len(ds[lat_var].dims)==1 and len(ds[lon_var].dims)==1:
            lat_1D=ds[lat_var].dims[0]
            lon_1D=ds[lon_var].dims[0]            
            #if lat_1D in ds[var].dims and lon_1D in ds[var].dims:
            if lat_bnds=="" or lon_bnds=="":
                if lat_1D==lon_1D:
                    return "irregular"
                else:
                    return "regular_lat_lon"
            else:
                if lat_1D==lon_1D and \
                all([len(ds[bnds].dims)==2 for bnds in [lon_bnds, lat_bnds]]) and \
                all([ds.dims[dim]>2 for dim in [ds[lon_bnds].dims[-1], ds[lat_bnds].dims[-1]]]):
                    return "irregular"
                elif all([len(ds[bnds].dims)==2 for bnds in [lon_bnds, lat_bnds]]) and \
                all([ds.dims[dim]==2 for dim in [ds[lon_bnds].dims[-1], ds[lat_bnds].dims[-1]]]):
                    return "regular_lat_lon"                
                else:
                    raise Exception("The grid type is not supported.")
            #else:
            #    raise Exception("The grid type is not supported.")
        elif len(ds[lat_var].dims)==2 and len(ds[lon_var].dims)==2:
            # Test for curvilinear or restructure lat/lon coordinate variables            
            # ToDo: Check if regular_lat_lon despite 2D 
            #  - requires additional function checking 
            #      lat[:,i]==lat[:,j] for all i,j
            #      lon[i,:]==lon[j,:] for all i,j
            #  - and if that is the case to extract lat/lon and *_bnds
            #      lat[:]=lat[:,j], lon[:]=lon[j,:]
            #      lat_bnds[:, 2]=[min(lat_bnds[:,j, :]), max(lat_bnds[:,j, :])]
            #      lon_bnds similar
            if not ds[lat_var].shape==ds[lon_var].shape:
                raise Exception("The grid type is not supported.")
            else:
                if lat_bnds=="" or lon_bnds=="": 
                    return "curvilinear"
                else:
                    print(list(ds[lat_var].shape), list(ds[lat_bnds].shape), [si+1 for si in ds[lat_var].shape], list(ds[lat_var].shape)+[4])
                    # Shape of curvilinear bounds either [nlat, nlon, 4] or [nlat+1, nlon+1]
                    if list(ds[lat_var].shape)+[4]==list(ds[lat_bnds].shape) and \
                    list(ds[lon_var].shape)+[4]==list(ds[lon_bnds].shape):
                        return "curvilinear"
                    elif [si+1 for si in ds[lat_var].shape]==list(ds[lat_bnds].shape) and \
                    [si+1 for si in ds[lon_var].shape]==list(ds[lon_bnds].shape):
                        return "curvilinear"
                    else:
                        raise Exception("The grid type is not supported.")            
        else: 
            raise Exception("The grid type is not supported.")      
    else:
        raise Exception("Grid type can only be determined for datasets following the CF conventions.")

In [10]:
detect_type(ds, "CF")



'regular_lat_lon'

In [11]:
with open("/home/dkrz/k204212/git/find_result.txt") as f:
    ifiles = [l.strip() for l in f.readlines()]

In [12]:
i=0
for ifile in ifiles:
    i+=1
    ds = xr.open_dataset(ifile).isel(time=0)
    print("\n"+20*"-"+"\n"+"%2d/%2d - " %(i, len(ifiles)) + ds.attrs["source_id"]+"\n"+20*"-" )
    print(detect_type(ds, "CF"))


--------------------
 1/52 - ACCESS-CM2
--------------------
[300, 360] [300, 360, 4] [301, 361] [300, 360, 4]
curvilinear

--------------------
 2/52 - ACCESS-ESM1-5
--------------------
[300, 360] [300, 360, 4] [301, 361] [300, 360, 4]
curvilinear

--------------------
 3/52 - AWI-CM-1-1-MR
--------------------
irregular

--------------------
 4/52 - AWI-ESM-1-1-LR
--------------------
irregular

--------------------
 5/52 - BCC-CSM2-MR
--------------------
regular_lat_lon

--------------------
 6/52 - BCC-ESM1
--------------------
regular_lat_lon

--------------------
 7/52 - CAMS-CSM1-0
--------------------
[200, 360] [200, 360, 4] [201, 361] [200, 360, 4]
curvilinear

--------------------
 8/52 - CanESM5-CanOE
--------------------
[291, 360] [291, 360, 4] [292, 361] [291, 360, 4]
curvilinear

--------------------
 9/52 - CanESM5
--------------------
[291, 360] [291, 360, 4] [292, 361] [291, 360, 4]
curvilinear

--------------------
10/52 - CAS-ESM2-0
--------------------
regular_

  new_vars[k] = decode_cf_variable(



--------------------
12/52 - CESM2
--------------------
[384, 320] [384, 320, 4] [385, 321] [384, 320, 4]
curvilinear

--------------------
13/52 - CESM2-WACCM-FV2
--------------------
[384, 320] [384, 320, 4] [385, 321] [384, 320, 4]
curvilinear

--------------------
14/52 - CESM2-WACCM
--------------------
[384, 320] [384, 320, 4] [385, 321] [384, 320, 4]
curvilinear

--------------------
15/52 - CIESM
--------------------
[384, 320] [384, 320, 4] [385, 321] [384, 320, 4]
curvilinear

--------------------
16/52 - CMCC-CM2-HR4
--------------------
[1051, 1442] [1051, 1442, 4] [1052, 1443] [1051, 1442, 4]
curvilinear

--------------------
17/52 - CMCC-CM2-SR5
--------------------
[292, 362] [292, 362, 4] [293, 363] [292, 362, 4]
curvilinear

--------------------
18/52 - CMCC-ESM2
--------------------
[292, 362] [292, 362, 4] [293, 363] [292, 362, 4]
curvilinear

--------------------
19/52 - CNRM-CM6-1
--------------------
[294, 362] [294, 362, 4] [295, 363] [294, 362, 4]
curvilinear






--------------------
28/52 - FGOALS-g3
--------------------
curvilinear

--------------------
29/52 - FIO-ESM-2-0
--------------------
[384, 320] [384, 320, 4] [385, 321] [384, 320, 4]
curvilinear

--------------------
30/52 - GFDL-CM4
--------------------
[1080, 1440] [1080, 1440, 4] [1081, 1441] [1080, 1440, 4]
curvilinear

--------------------
31/52 - GFDL-ESM4
--------------------
[576, 720] [576, 720, 4] [577, 721] [576, 720, 4]
curvilinear

--------------------
32/52 - GISS-E2-1-G-CC
--------------------
regular_lat_lon

--------------------
33/52 - GISS-E2-1-G
--------------------
regular_lat_lon

--------------------
34/52 - GISS-E2-1-H
--------------------
regular_lat_lon

--------------------
35/52 - HadGEM3-GC31-LL
--------------------
[330, 360] [330, 360, 4] [331, 361] [330, 360, 4]
curvilinear

--------------------
36/52 - HadGEM3-GC31-MM
--------------------
[1205, 1440] [1205, 1440, 4] [1206, 1441] [1205, 1440, 4]
curvilinear

--------------------
37/52 - IPSL-CM5A2-IN

### Define function to reformat grid

In [13]:
def reformat(ds, format_from, format_to="CF"):
    SCRIP_vars=["grid_center_lat", "grid_center_lon",
                "grid_corner_lat", "grid_corner_lon",
                "grid_dims", "grid_area", "grid_imask"]    
    if format_from=="SCRIP":
        if not (all([var in SCRIP_vars for var in ds.data_vars]) and 
                all([coord in SCRIP_vars for coord in ds.coords])):
            raise Exception("Converting the grid format from %s to %s is not yet possible for data variables."
                            %(format_from, format_to))
        if format_to=="CF":
            lat=ds.grid_center_lat.values.reshape((ds.grid_dims.values[1], ds.grid_dims.values[0]))
            lon=ds.grid_center_lon.values.reshape((ds.grid_dims.values[1], ds.grid_dims.values[0]))
            
            if (all([np.array_equal(lat[:,i], lat[:,i+1], equal_nan=True) for i in range(ds.grid_dims.values[0]-1)]) and
                all([np.array_equal(lon[i,:], lon[i+1,:], equal_nan=True) for i in range(ds.grid_dims.values[1]-1)])):
                # regular_lat_lon grid type:
                # Reshape vertices from (n,2) to (n+1) for lat and lon axis
                lat=lat[:,0]
                lon=lon[0,:]
                lat_b=ds.grid_corner_lat.values.reshape((ds.grid_dims.values[1], ds.grid_dims.values[0], ds.dims["grid_corners"]))
                lon_b=ds.grid_corner_lon.values.reshape((ds.grid_dims.values[1], ds.grid_dims.values[0], ds.dims["grid_corners"]))
                lat_bnds=np.zeros((ds.grid_dims.values[1], 2), dtype="double")
                lon_bnds=np.zeros((ds.grid_dims.values[0], 2), dtype="double")
                lat_bnds[:, 0]=np.min(lat_b[:,0,:], axis=1)
                lat_bnds[:, 1]=np.max(lat_b[:,0,:], axis=1)
                lon_bnds[:, 0]=np.min(lon_b[0,:,:], axis=1)
                lon_bnds[:, 1]=np.max(lon_b[0,:,:], axis=1)                
                ds_ref=xr.Dataset(data_vars={},
                                  coords={"lat":(["lat"], lat),
                                          "lon":(["lon"], lon),
                                          "lat_bnds":(["lat","bnds"], lat_bnds),
                                          "lon_bnds":(["lon","bnds"], lon_bnds)})
                # ToDo: Case of other units (rad), Case of "degrees_south/west"?!
                # ToDo: Reformat data variables if in ds, apply imask on data variables
                # ToDo: vertical axis, time axis, ... ?!
                ds_ref["lat"].attrs={"bounds":"lat_bnds",
                                     "units":"degrees_north",
                                     "long_name":"latitude",
                                     "standard_name":"latitude",
                                     "axis":"Y"}
                ds_ref["lon"].attrs={"bounds":"lon_bnds",
                                     "units":"degrees_east",
                                     "long_name":"longitude",
                                     "standard_name":"longitude",
                                     "axis":"X"}
                ds_ref["lat_bnds"].attrs={"long_name":"latitude_bounds",
                                       "units":"degrees_north"}
                ds_ref["lon_bnds"].attrs={"long_name":"longitude_bounds",
                                       "units":"degrees_east"}
                return ds_ref
                
            else:
                raise Exception("Converting the grid format from %s to %s is yet only possible for regular latitude longitude grids."
                                %(format_from, format_to))
                
        else:
            raise Exception("Converting the grid format from %s to %s is not yet supported." %(format_from, format_to))
    elif format_from=="xESMF":
        if format_to=="CF":
            lat=ds.lat[:, 0]
            lon=ds.lon[0, :]
            vertex_lat=np.zeros((lat.shape[0], lon.shape[0], 4), dtype="double")
            vertex_lon=np.zeros((lat.shape[0], lon.shape[0], 4), dtype="double")
            lat_bnds=np.zeros((lat.shape[0], 2), dtype="double")
            lon_bnds=np.zeros((lon.shape[0], 2), dtype="double")
            vertex_lat=_reravel(vertex_lat, ds.lat_b, lon.shape[0], lat.shape[0])
            vertex_lon=_reravel(vertex_lon, ds.lon_b, lon.shape[0], lat.shape[0])
            lat_bnds[:, 0]=np.min(vertex_lat[:,0,:], axis=1)
            lat_bnds[:, 1]=np.max(vertex_lat[:,0,:], axis=1)
            lon_bnds[:, 0]=np.min(vertex_lon[0,:,:], axis=1)
            lon_bnds[:, 1]=np.max(vertex_lon[0,:,:], axis=1) 
            ds_ref=xr.Dataset(data_vars={},
                              coords={"lat":(["lat"], lat),
                                      "lon":(["lon"], lon),
                                      "lat_bnds":(["lat","bnds"], lat_bnds),
                                      "lon_bnds":(["lon","bnds"], lon_bnds)})
            # ToDo: Case of other units (rad), Case of "degrees_south/west"?!
            # ToDo: Reformat data variables if in ds, apply imask on data variables
            # ToDo: vertical axis, time axis, ... ?!
            ds_ref["lat"].attrs={"bounds":"lat_bnds",
                                 "units":"degrees_north",
                                 "long_name":"latitude",
                                 "standard_name":"latitude",
                                 "axis":"Y"}
            ds_ref["lon"].attrs={"bounds":"lon_bnds",
                                 "units":"degrees_east",
                                 "long_name":"longitude",
                                 "standard_name":"longitude",
                                 "axis":"X"}
            ds_ref["lat_bnds"].attrs={"long_name":"latitude_bounds",
                                      "units":"degrees_north"}
            ds_ref["lon_bnds"].attrs={"long_name":"longitude_bounds",
                                      "units":"degrees_east"}
            return ds_ref
        else:
            raise Exception("Converting the grid format from %s to %s is not yet supported." %(format_from, format_to))
    else:
        raise Exception("Converting the grid format from %s to %s is not yet supported." %(format_from, format_to))


def _unravel(new_bounds, vertex_bounds, M, N):
    """
    Helper function to go from the vertex style to
    the M+1, N+1 style of lat/lon bounds.
    Taken from https://nbviewer.jupyter.org/gist/bradyrx/421627385666eefdb0a20567c2da9976
    """
    new_bounds[0:N, 0:M] = vertex_bounds[:, :, 0]

    # fill in missing row
    new_bounds[N, 0:M] = vertex_bounds[N-1, :, 1]
    # fill in missing column
    new_bounds[0:N, M] = vertex_bounds[:, M-1, 2]
    # fill in remaining element
    new_bounds[N, M] = vertex_bounds[N-1, M-1, 3]
    return new_bounds

def _reravel(vertex_bounds, bounds, M, N):
    """
    Helper function to go from the M+1, N+1 style to 
    the vertex style M, N, 4 of lat/lon bounds.
    
    Basically inverted _unravel.
    """    
    vertex_bounds[:, :, 0] = bounds[0:N, 0:M]  
    
    # fill in missing row
    vertex_bounds[N-1, :, 1] = bounds[N, 0:M]
    # fill in missing column
    vertex_bounds[:, M-1, 2] = bounds[0:N, M] 
    # fill in remaining element
    vertex_bounds[N-1, M-1, 3] = bounds[N, M]
    return vertex_bounds

"""
From NCL
------------

curvilinear to SCRIP

grid_size    = nlat*nlon   ; This is number of data points (grid nodes)
grid_corners = 4
grid_rank    = 2

DummyAtt1@units = "degrees"
DummyAtt2@units = "unitless"

FileAtt@Conventions  = "SCRIP"


FDimNames = (/ "grid_size","grid_corners","grid_rank" /)
FDimSizes = (/ grid_size,grid_corners,grid_rank /)
FDimUnlim = (/ False,False,False /)

    filevardef(fid,"grid_dims","integer","grid_rank")
    filevardef(fid,"grid_center_lat","double","grid_size")
    filevardef(fid,"grid_center_lon","double","grid_size")
    filevardef(fid,"grid_imask","integer","grid_size")
    filevardef(fid,"grid_corner_lat","double",(/ "grid_size", "grid_corners" /) )
    filevardef(fid,"grid_corner_lon","double",(/ "grid_size", "grid_corners" /) )

fid->grid_center_lat = (/ndtooned(lat2d)/)
fid->grid_center_lon = (/ndtooned(lon2d)/)
fid->grid_imask=(/ tointeger(ndtooned(grid_mask)) /)

grid_corner_lat = reshape( GridCornerLat,(/ grid_size, grid_corners /))
grid_corner_lon = reshape( GridCornerLon,(/ grid_size, grid_corners /))

fid->grid_corner_lat = (/ todouble(grid_corner_lat) /)
fid->grid_corner_lon = (/ todouble(grid_corner_lon) /)


"""
"""
rectilinear to SCRIP


nlat = dimsizes(lat)
nlon = dimsizes(lon)

Conform the lat/lon to 2D arrays
grid_center_lat = conform_dims((/nlat, nlon/),lat,0)
grid_center_lon = conform_dims((/nlat, nlon/),lon,1)


-Generate the mask
 grid_mask_name = get_mask_name(Opt2)
 if(grid_mask_name.eq."") then
   Opt2@GridMask = onedtond(1,(/nlat,nlon/))
 end if


rest as above
"""

'\nrectilinear to SCRIP\n\n\nnlat = dimsizes(lat)\nnlon = dimsizes(lon)\n\nConform the lat/lon to 2D arrays\ngrid_center_lat = conform_dims((/nlat, nlon/),lat,0)\ngrid_center_lon = conform_dims((/nlat, nlon/),lon,1)\n\n\n-Generate the mask\n grid_mask_name = get_mask_name(Opt2)\n if(grid_mask_name.eq."") then\n   Opt2@GridMask = onedtond(1,(/nlat,nlon/))\n end if\n\n\nrest as above\n'

In [14]:
ds=xr.open_dataset("../../target_grids/cmip6_361x576_scrip.20181001.nc")
#ds.dims
#ds.data_vars
ds

In [15]:
ds=reformat(ds, "SCRIP", "CF")
ds

In [16]:
ds=xe.util.grid_global(1.,1.)
ds

In [17]:
ds=reformat(ds, "xESMF", "CF")
ds

In [18]:
detect_format(ds)
detect_type(ds)

CF


'regular_lat_lon'