In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import os

# Root directory of the NEVERWORLD NEMO output. It can be overridden with the
# NEVERWORLD_DATA environment variable; the default is the original location.
# os.path.join(..., '') guarantees a trailing separator, because later cells
# build file paths by plain string concatenation (path + 'subdir/file.nc').
path = os.path.join(
    os.environ.get('NEVERWORLD_DATA', '/data/dkamm/nemo_output/NEVERWORLD/'), ''
)

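### Preparing the NetCDF files for xarray
Before loading, rename the coordinate variables in the restart files (replace `FILE_NAME` with the file's path):

```bash
ncrename -v y,nav_lat_grid_T FILE_NAME
ncrename -v x,nav_lon_grid_T FILE_NAME
```

Note: the high-resolution loading cell below refers to lowercase `nav_lat_grid_t` / `nav_lon_grid_t`. Check that these names match what `ncrename` produced.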

### Loading low/high resolution restart files

In [None]:
# High-resolution restart. Variable x becomes lon, nav_lat_grid_t becomes lat,
# and dimension nav_lon_grid_t becomes x, so the grid is exposed as lon/lat on
# (y, x) (presumably the names xesmf expects — confirm).
hr_file = path + 'high_res_short1/NEVERWORLD_00000032_restart.nc'
hr = (
    xr.open_dataset(hr_file)
    .rename_vars({'x': 'lon', 'nav_lat_grid_t': 'lat'})
    .rename_dims({'nav_lon_grid_t': 'x'})
)

In [None]:
# Low-resolution (1 degree, GM) restart, opened without any renaming.
# A disabled alternative applied the same kind of renaming as for `hr`:
#   .rename_vars({'x':'lon', 'nav_lat_grid_T':'lat'}).rename_dims({'nav_lon_grid_T':'x'})
lr_file = path + '1_deg_GM/NEVERWORLD_05760000_restart.nc'
lr = xr.open_dataset(lr_file)

In [None]:
# Quick look at avm_k in the x=100 slice of the high-resolution restart.
avm_k_section = hr['avm_k'].isel(x=100)
avm_k_section.plot()

In [None]:
# Display the low-resolution restart (variables, dimensions, attributes).
lr

### Create the padded low-resolution target dataset for extrapolation

In [None]:
# Grid spacing of the low-resolution run in degrees (used to extend longitude).
low_res = 1.0
# Number of columns appended after the last x index.
n_pad = 2

# tn is laid out as (time_counter, nav_lev, y, x), matching the dims used below.
temp    = lr.tn.values
# Pad temperature by repeating the edge columns. A longitude increment
# (n_pad * low_res) must not be added to a temperature field.
temp_p  = np.concatenate((temp, temp[:, :, :, -n_pad:]), axis=3)

lon     = lr.lon.values
# Continue longitude past the edge: the last n_pad columns shifted by n_pad grid
# steps (assumes uniform spacing of low_res degrees along x — confirm).
lon_p   = np.concatenate((lon, lon[:, -n_pad:] + n_pad * low_res), axis=1)

lat     = lr.lat.values
# Padding is along x, so latitude is repeated, not shifted (assumes latitude is
# constant along x, as on a regular lon/lat grid — confirm).
lat_p   = np.concatenate((lat, lat[:, -n_pad:]), axis=1)

nav_lev = lr.nav_lev.values

time    = lr.time_counter.values

# define data with variable attributes
data_vars = {'temperature': (['time_counter', 'nav_lev', 'y', 'x'], temp_p,
                             {'units': 'C'})}

# define coordinates
coords = {'time_counter': ('time_counter', time),
          'nav_lev': ('nav_lev', nav_lev),
          'lat': (['y', 'x'], lat_p),
          'lon': (['y', 'x'], lon_p)}

# Padded target grid for the nearest-neighbour extrapolator.
ds_lr = xr.Dataset(data_vars=data_vars, coords=coords)

In [None]:
# Drop one row/column on every side of the low-res grid. Presumably that outer
# frame is masked (land/boundary) — confirm. The result is the source grid for
# the extrapolation below.
interior = slice(1, -1)
lr_nomask = lr.isel(y=interior, x=interior)

### Extrapolating the low-resolution fields onto land points

In [None]:
import xesmf as xe

In [None]:
# Nearest-source regridder from the interior-only low-res grid onto the padded
# grid `ds_lr`. It gets its own name because the next cell rebinds
# `extrapolator`, which made the previous assignment here dead. It reuses
# `lr_nomask` rather than re-slicing `lr`, so both regridders share one source grid.
pad_extrapolator = xe.Regridder(lr_nomask, ds_lr, method="nearest_s2d")

In [None]:
# Regrid from the interior-only grid back onto the full low-res grid:
# bilinear weights where the source covers the target, and nearest-source
# extrapolation for the remaining (outer) points.
extrapolator = xe.Regridder(
    lr_nomask,
    lr,
    method="bilinear",
    extrap_method="nearest_s2d",
)

In [None]:
# Apply the extrapolating regridder to the interior-only low-res fields,
# producing values on the full low-res grid.
extrapolated_lr = extrapolator(lr_nomask)

### Regridding the extrapolated low-resolution fields onto the high-resolution grid

In [None]:
# Bilinear regridding from the extrapolated low-res grid onto the grid of `hr`.
# Targets outside the source use nearest-source extrapolation, and degenerate
# cells are tolerated rather than raising.
regridder = xe.Regridder(
    extrapolated_lr,
    hr,
    method="bilinear",
    extrap_method="nearest_s2d",
    ignore_degenerate=True,
)

In [None]:
# Interpolate the extrapolated low-res restart onto the high-resolution grid.
restart_regrid = regridder(extrapolated_lr)

### Apply the high-resolution land mask

In [None]:
# 0 where the high-res surface temperature is exactly 0.0 (treated as land),
# 1 everywhere else. Only the first time record and top level are used.
surface_tn = hr.tn.isel(time_counter=0, nav_lev=0)
mask = xr.where(surface_tn == 0.0, 0, 1)

In [None]:
# NOTE(review): disabled experiment, kept for reference only. It would align
# `mask` to restart_regrid's coordinates by forward-fill reindexing. Passing
# `fill_value=mask` (a DataArray) looks unintended. Delete this cell if the
# masking below works without it.
#mask = mask.reindex_like(restart_regrid, method='ffill', tolerance=0.01, fill_value=mask)

In [None]:
# Zero every variable of restart_regrid wherever `mask` is 0. The mask comes
# from a single (surface) level, so it is broadcast over all depth levels.
# NOTE(review): land points that exist only at deeper levels are not masked
# here — confirm this is acceptable.
# Re-running is harmless for a 0/1 mask (mask * mask == mask), but note that
# restart_regrid is rebound rather than given a new name.
restart_regrid = restart_regrid * mask

### Set velocities (and, optionally, ssh) to zero

In [None]:
# Variables reset to zero in the next cell: the velocity fields only.
# To also reset the sea-surface height, use this list instead:
# vars_to_zero = ['sshb', 'vb', 'ub', 'sshn', 'vn', 'un']
vars_to_zero = [
    'vb',
    'ub',
    'vn',
    'un',
]

In [None]:
# Overwrite every element of each listed variable with 0.0. This relies on
# restart_regrid[name] sharing its data with the Dataset, so the `.loc[:]`
# assignment writes through to restart_regrid.
for name in vars_to_zero:
    field = restart_regrid[name]
    field.loc[:] = 0.0

### Add variables missing after regridding (copied from the low-resolution restart) and reorder variables to match the high-resolution file (reordering may not be necessary)

In [None]:
# Value written as 'rdt' in the new restart. Presumably the model time step of
# the high-resolution run in seconds — TODO confirm against the namelist.
RESTART_RDT = 720.

# Horizontal coordinates come from the high-resolution target grid.
restart_regrid['lon'] = hr.lon
restart_regrid['lat'] = hr.lat

# Copied unchanged from the low-resolution restart (presumably dropped by the
# regridder because they are not horizontal fields — confirm).
for name in ('kt', 'ndastp', 'adatrj', 'ntime'):
    restart_regrid[name] = lr[name]

restart_regrid['rdt'] = RESTART_RDT

In [None]:
# Keep exactly the high-res restart's variables, in the same order. Selecting a
# list raises KeyError if any of them is missing from restart_regrid.
hr_var_order = list(hr.keys())
restart_regrid = restart_regrid[hr_var_order]

### Writing the NetCDF restart file

In [None]:
# Write the assembled restart, with time_counter as the unlimited dimension.
out_file = path + '1_deg_GM/restart_wssh.nc'
restart_regrid.to_netcdf(out_file, unlimited_dims='time_counter')