In [None]:
# Auto-reload imported modules (e.g. nzdownscale) before each cell runs, so edits to
# the project code are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
from functools import partial
from time import perf_counter, time

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from tqdm import tqdm

from nzdownscale.dataprocess import wrf
from nzdownscale.downscaler.preprocess import PreprocessForDownscaling

In [None]:
# Collect WRF output file paths between the two timestamps
# (the strings look like YYYYMMDDHH — TODO confirm against wrf.get_filepaths).
fpaths = wrf.get_filepaths('2023110100', '2023110200')
len(fpaths)

In [None]:
# Additional variables handed to the preprocessor as context for the target variable.
context_variables = ['temperature',
        'precipitation',
        '10m_u_component_of_wind',
        '10m_v_component_of_wind',
        'surface_pressure',
        'surface_solar_radiation_downwards',
        ]

# Target variable is temperature, with WRF as the base data source (base='wrf').
# Every file except the last is used for training and the last one alone for
# validation, so at least two files are needed for a non-empty training set.
data = PreprocessForDownscaling(
        variable='temperature',
        base='wrf',
        training_fpaths=fpaths[:-1], # WRF files used for training
        validation_fpaths=fpaths[-1:], # WRF file held out for validation
        context_variables=context_variables,
    )

In [None]:
# Load the WRF dataset, including the context variables, from every path known to the preprocessor.
base_ds = data.process_wrf.load_ds(filenames=data.all_paths,
                                    context_variables = data.context_variables)

In [None]:
# Size of the WRF grid's south_north dimension.
len(base_ds.south_north)

In [None]:
data.load_topography()
# NOTE(review): the meaning of the positional arguments (5, 4) is not visible here —
# confirm against preprocess_topography's signature and pass them by keyword.
# aux_raw_ds provides the target latitude/longitude axes used by the regridding cells below.
highres_aux_raw_ds, aux_raw_ds = data.preprocess_topography(5, 4)
aux_raw_ds

In [None]:
base_ds

# Method 1: LinearND Interpolation

In [None]:
# Time the original LinearND-based regridding onto the topography grid.
# perf_counter is a monotonic, high-resolution clock meant for measuring durations;
# time() is wall-clock time and can jump if the system clock is adjusted.
start = perf_counter()
LND = data.process_wrf.regrid_to_topo_old(base_ds, aux_raw_ds)
print(f'regrid_to_topo_old (LinearND): {perf_counter() - start:.2f} s')

# Method 2: xESMF

In [None]:
start = time()
new_ds = data.process_wrf.regrid_to_topo(base_ds,
                                aux_raw_ds)
print(time()-start)

In [None]:
new_ds

In [None]:
base_ds.T2.isel(Time=0).plot()

In [None]:
new_ds.T2.isel(Time=0).plot()

# Method 3: xESMF regridder built directly in the notebook

In [None]:
import xesmf as xe

In [None]:
ds = base_ds.rename({'XLONG': 'lon', 'XLAT': 'lat'})
ds

In [None]:
ds_out = xr.Dataset({
    'lat': (['lat'], aux_raw_ds.latitude.values),
    'lon': (['lon'], aux_raw_ds.longitude.values),
})

In [None]:
regridder = xe.Regridder(ds.isel(Time=0), ds_out, "bilinear")

In [None]:
# regridder.to_netcdf()

In [None]:
new = regridder(ds)

In [None]:
new

In [None]:
new.isel(Time=0).T2.plot()

In [None]:
new == new_ds

# Greg's ndimage method

In [None]:
from scipy import ndimage
from scipy.interpolate import LinearNDInterpolator

class regridder:
    """Regrid 2-D data from one rectilinear latitude/longitude grid to another.

    ``__init__`` converts every target grid point once into fractional index
    positions along the source axes; ``regrid`` then interpolates data values
    at those positions.
    """

    def __init__(self, from_lats, to_lats, from_lons, to_lons):
        """Precompute the source-grid index position of every target grid point.

        :param from_lats: 1-D latitudes of the source grid (any order).
        :param to_lats: 1-D latitudes of the target grid.
        :param from_lons: 1-D longitudes of the source grid (any order).
        :param to_lons: 1-D longitudes of the target grid.
        :raises AssertionError: if the target grid extends beyond the source grid.
        """
        from_lats = np.array(from_lats)
        self.to_lats = np.array(to_lats)
        from_lons = np.array(from_lons)
        self.to_lons = np.array(to_lons)
        assert (np.min(self.to_lats) >= np.min(from_lats)), \
            'The minimum latitude to interpolate to was smaller than the input grid.'
        assert (np.max(self.to_lats) <= np.max(from_lats)), \
            'The maximum latitude to interpolate to was greater than the input grid.'
        assert (np.min(self.to_lons) >= np.min(from_lons)), \
            'The minimum longitude to interpolate to was smaller than the input grid.'
        assert (np.max(self.to_lons) <= np.max(from_lons)), \
            'The maximum longitude to interpolate to was greater than the input grid.'

        # np.interp needs increasing sample points. Each sorted coordinate stays
        # paired with its index in the ORIGINAL array, so the interpolated
        # positions still address the unsorted data.
        from_lats, from_lats_indices = self._ascending(from_lats)
        from_lons, from_lons_indices = self._ascending(from_lons)
        lat_indices = np.interp(self.to_lats, from_lats, from_lats_indices)
        lon_indices = np.interp(self.to_lons, from_lons, from_lons_indices)
        lon_mesh, lat_mesh = np.meshgrid(lon_indices, lat_indices)
        self.shape = np.shape(lat_mesh)
        self.lat_mesh = lat_mesh.flatten()
        self.lon_mesh = lon_mesh.flatten()

    @staticmethod
    def _ascending(values):
        """Return ``values`` in non-decreasing order plus each entry's index in the original array."""
        indices = np.arange(len(values))
        steps = np.diff(values)
        if not np.any(steps < 0):
            # Already non-decreasing: nothing to reorder.
            return values, indices
        if np.any(steps > 0):
            # Non-monotonic: sort and keep the permutation. The permutation must
            # come from the unsorted array; argsort of the sorted array is arange.
            order = np.argsort(values, kind='stable')
            return values[order], order
        # Non-increasing: reversing is sufficient.
        return np.flip(values), np.flip(indices)

    def _as_dataarray(self, values, like):
        """Wrap a (len(to_lats), len(to_lons)) array as a DataArray carrying ``like``'s name and attrs."""
        result = xr.DataArray(values, coords=[self.to_lats, self.to_lons], dims=['latitude', 'longitude'],
                              name=like.name)
        result.attrs = like.attrs
        return result

    def regrid(self, data: 'xr.DataArray', method: int = 1) -> 'xr.DataArray':
        """Interpolate ``data`` onto the target grid.

        :param data: 2-D DataArray, assumed ordered (latitude, longitude).
        :param method: 1 = linear ``scipy.ndimage.map_coordinates`` at the index
            positions precomputed in ``__init__`` (so ``data`` must sit on the
            grid given to ``__init__``); 2 = ``LinearNDInterpolator`` built from
            the DataArray's own 1-D ``latitude`` / ``longitude`` coordinates.
        :return: DataArray with dims (latitude, longitude) = (to_lats, to_lons),
            keeping ``data``'s name and attrs.
        :raises ValueError: if ``method`` is neither 1 nor 2.
        """
        # todo: need to figure out how to loop over all dimensions except latitude and longitude.
        if method == 1:
            values = ndimage.map_coordinates(data.data, [self.lat_mesh, self.lon_mesh], order=1)
            return self._as_dataarray(values.reshape(self.shape), data)
        if method == 2:
            src_lons, src_lats = np.meshgrid(data.longitude.values, data.latitude.values)
            points = np.column_stack((src_lons.flatten(), src_lats.flatten()))
            interp = LinearNDInterpolator(points, data.data.flatten())
            dst_lons, dst_lats = np.meshgrid(self.to_lons, self.to_lats)
            return self._as_dataarray(interp(dst_lons, dst_lats), data)
        raise ValueError(f'Unknown regridding method {method!r}; expected 1 or 2.')

In [None]:
base_ds

In [None]:
# NOTE(review): `regridder` is written for 1-D source axes. __init__ turns each target
# lat/lon into a fractional index along `from_lats` / `from_lons`, and regrid(method=1)
# applies those indices to the first and second axes of the data. Passing the
# *flattened* XLAT/XLONG arrays (presumably 2-D grids, hence .flatten() — TODO confirm)
# makes the indices point into the flattened array instead of along the axes of T2,
# so the regridded field is likely wrong. Pass 1-D axis coordinates instead, or use
# the xESMF / LinearND methods for a curvilinear grid. (The name `re` also shadows
# the stdlib `re` module if it is ever imported.)
re = regridder(base_ds.XLAT.values.flatten(), 
               aux_raw_ds.latitude.values, 
               base_ds.XLONG.values.flatten(), 
               aux_raw_ds.longitude.values)

In [None]:
# Debug check: True if the flattened source latitudes ever decrease, which sends
# regridder.__init__ into its latitude re-ordering branch.
from_lats = base_ds.XLAT.values.flatten()
np.any(np.diff(from_lats) < 0)


In [None]:
# Regrid the first time step with the default method (1: ndimage.map_coordinates).
T2_regrid = re.regrid(base_ds.T2.isel(Time=0))

In [None]:
T2_regrid.plot()