In [None]:
%reload_ext autoreload
%autoreload 2
# %matplotlib widget

In [None]:
import matplotlib.pyplot as plt
# import plotly.express as px

In [None]:
import warnings
import atexit

import pystac_client
import planetary_computer
import azure.storage.blob
import rioxarray
import rasterio
import stackstac
import numpy as np
import xarray as xr
import dask
import dask.array as da
import dask_image.ndmorph as ndmorph
from dataclasses import dataclass

import skimage
from sklearn.cluster import KMeans
from sklearn.preprocessing import PowerTransformer, StandardScaler, RobustScaler

import satio_pc
from satio_pc import parallelize
from satio_pc.clouds import scl_to_mask


# compositing code
import pandas as pd
from datetime import datetime, timedelta
from dateutil.parser import parse as parse_date


def _get_date_range(start, end, freq, window):
    
    before = round(window / 2)
    start, end = parse_date(start), parse_date(end)

    date_range = pd.date_range(start=start + timedelta(days=before),
                               end=end,
                               freq=f'{freq}D')
    return date_range


def calculate_moving_composite(darr: xr.DataArray,
                               freq=7,
                               window=None,
                               start=None,
                               end=None,
                               use_all_obs=False):
    """
    Calculate a moving median composite of a time series cube.

    For each composite date, the observations falling in the window
    [date - before, date + after] (see `_get_before_after`) are reduced
    with a per-band NaN-aware median (`nanmedian`, which treats 0 as nodata
    for non-float inputs).

    Parameters
    ----------
    darr : xr.DataArray
        Cube with dims (time, band, y, x) and `time`, `band`, `y`, `x`
        coordinates.
    freq : int
        Days interval between consecutive composite dates.
    window : int, optional
        Moving window size in days on which the median is computed.
        Defaults to `freq`; must not be smaller than `freq`.
    start : str or datetime, optional
        Start of the compositing period. Defaults to the date of the first
        observation of `darr`.
    end : str or datetime, optional
        End of the compositing period. Defaults to the date of the last
        observation of `darr`.
    use_all_obs : bool
        The last window might not reach the last observations, which would
        then be discarded. When True, the last window (if it contains at
        least one observation) is extended to include all the trailing
        observations, trading temporal resolution of the last composite for
        a better SNR.

    Returns
    -------
    xr.DataArray
        Lazy (dask-backed) composite with dims (time, band, y, x), where
        `time` holds the composite dates. Attributes are copied from `darr`.

    Raises
    ------
    ValueError
        If `window` is smaller than `freq`.
    """
    window = window or freq  # if window is None, default to `freq` days

    if window < freq:
        raise ValueError('`window` value should be equal or greater than '
                         '`freq` value.')

    before, after = _get_before_after(window)

    # Resolve the default period before building the dates (a date range
    # built from the raw arguments failed when start/end were None).
    time = darr.time.values
    start = str(time[0])[:10] if start is None else start
    end = str(time[-1])[:10] if end is None else end

    # `_get_date_range` expects the full window length and halves it itself;
    # passing `before` placed the first composite only a quarter window in.
    date_range = _get_date_range(start, end, freq, window)

    intervals_flags = _get_invervals_flags(date_range,
                                           time,
                                           before,
                                           after,
                                           use_all_obs)

    n_bands = darr.shape[1]
    composites = []
    for flag in intervals_flags:
        idxs = np.where(flag)[0]
        band_medians = [nanmedian(darr.isel(time=idxs, band=band_idx))
                        for band_idx in range(n_bands)]
        composites.append(da.stack(band_medians, axis=0))

    if composites:
        # One stacked lazy graph instead of a dask item-assignment per
        # (date, band) into a zeros array.
        comp = da.stack(composites, axis=0)
    else:
        comp = da.zeros((0, n_bands, darr.shape[2], darr.shape[3]),
                        dtype=darr.dtype)

    return xr.DataArray(comp,
                        dims=darr.dims,
                        coords={'time': date_range.values,
                                'band': darr.band,
                                'y': darr.y,
                                'x': darr.x},
                        attrs=darr.attrs)


def _include_last_obs(idxs):

    # check that all obs are used on last interval
    true_flags = np.where(idxs)[0]
    if true_flags.size:
        last_true_idx = np.where(idxs)[0][-1]
        if last_true_idx != idxs.size - 1:
            idxs[last_true_idx:] = True

    return idxs


def _get_invervals_flags(date_range,
                         time_vector,
                         before,
                         after,
                         use_all_obs):
    """
    Compute, for each composite date, which observations fall in its window.

    Parameters
    ----------
    date_range : iterable of datetime-like
        Composite dates.
    time_vector : array-like of datetimes
        Observation times.
    before, after : int
        Days included before / after each composite date
        (see `interval_flag`).
    use_all_obs : bool
        If True, the last window is extended with `_include_last_obs` so
        that trailing observations are not dropped.

    Returns
    -------
    list of np.ndarray
        One boolean array per date, aligned with `time_vector`.
    """
    # Convert once: the same times are compared against every date.
    times = pd.to_datetime(time_vector)
    last = len(date_range) - 1

    intervals_flags = []
    for i, d in enumerate(date_range):
        idxs = interval_flag(times, d, before=before, after=after)

        if use_all_obs and i == last:
            idxs = _include_last_obs(idxs)

        intervals_flags.append(idxs)

    return intervals_flags


def nanmedian(arr):
    """
    Median over the first (time) axis, ignoring missing values.

    Parameters
    ----------
    arr : xr.DataArray
        Array with dims (time, y, x), presumably dask-backed. For non-float
        dtypes, 0 is treated as nodata.

    Returns
    -------
    array
        (y, x) median with the dtype of `arr` (lazy when `arr` is
        dask-backed). For integer inputs, pixels without any valid
        observation are 0.
    """
    start_dtype = arr.dtype
    is_float = start_dtype in (np.float32, np.float64)

    if not is_float:
        # integers cannot hold NaN: work in float32 with 0 -> NaN
        arr = arr.astype(np.float32)
        arr = arr.where(arr != 0, np.nan)

    # NOTE(review): this filter only covers warnings raised while building
    # the lazy graph; warnings emitted at compute time are not suppressed.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        res = da.nanmedian(arr.data, axis=0)

    if not is_float:
        # Map all-NaN pixels back to the 0 nodata value explicitly: casting
        # NaN to an integer dtype is undefined (platform-dependent) in numpy.
        res = da.where(da.isnan(res), 0, res)

    return res.astype(start_dtype)


def _get_before_after(window: int):
    """
    Returns values for before and after in number of days, given a
    window length.
    """
    half, mod = window // 2, window % 2

    before = after = half

    if mod == 0:  # even window size
        after = max(0, after - 1)  # after >= 0

    return before, after


def _check_window_settings(freq,
                           before,
                           after,
                           mode,
                           supported_modes=None):
    """
    Perform a check on the values of before and after against freq, mode
    and supported_modes.

    If mode is sum before and after will be computed to give non overlapping
    windows.

    If one of before or after are None, the other is set to the values of the
    valid one.

    If before and after are both None,
    """
    if supported_modes is None:
        supported_modes = SUPPORTED_MODES

    if mode not in supported_modes:
        raise ValueError(('Compositing mode should be one of '
                          f'{supported_modes}, but got: `{mode}`'))

    no_ov_modes = ['sum', 'min', 'max']
    if mode in no_ov_modes:
        if (before is not None or after is not None):
            logger.warning(('`before` and `after` arguments are ignored for '
                            f'compositing mode `{no_ov_modes}`.'))
        # For these modes window overlap is not allowed in the time subsets
        # to avoid double counting of values
        before = after = None

    # if one of before or after is None, set it simmetrically to the valid one
    before = before or after
    after = after or before

    if (before is None and after is None):
        before, after = _get_default_before_after(freq)

    return before, after


def _get_default_before_after(freq):
    """
    Based on the freq, return before and after for non overlapping windows
    """
    if freq == 1:
        before = 0
        after = 0
    elif freq % 2 == 0:
        before = freq / 2 - 1
        after = freq / 2
    else:
        before = int(np.floor(freq / 2))
        after = int(np.floor(freq / 2))

    return before, after


def interval_flag(time_vector,
                  date: datetime,
                  before: int,
                  after: int):
    """
    Returns a boolean array True where the dates in time_vector fall in an
    interval between data - before and date + after

    Parameters
    ----------
    time_vector : datetime/np.datetime64 array
        time_vector of the source.
    date : datetime/np.datetime64
        target date from which we want to get the neighboring dates from
        time_vector
    """
    midnight = datetime(date.year, date.month, date.day)
    return ((time_vector >= midnight + timedelta(days=-before))
            & (time_vector < midnight + timedelta(days=after + 1)))


def mask_clouds(darr, scl_mask):
    """
    Set the pixels of `darr` to 0 where `scl_mask` is False.

    Parameters
    ----------
    darr : xr.DataArray or array
        Cube with dims (time, band, y, x).
    scl_mask : xr.DataArray or array (numpy / dask)
        Boolean mask, either broadcastable to `darr` (e.g. a single-band
        (time, 1, y, x) mask) or (time, y, x), in which case a band axis is
        inserted before broadcasting. True marks the pixels that are kept.

    Returns
    -------
    xr.DataArray or dask array
        A DataArray with the coords/attrs of `darr` when `darr` is a
        DataArray (so `.satio` can be chained), otherwise a dask array.
    """
    # `.data` only for DataArrays: on numpy arrays it is a raw memoryview
    mask = scl_mask.data if isinstance(scl_mask, xr.DataArray) else scl_mask
    values = darr.data if isinstance(darr, xr.DataArray) else darr

    if mask.ndim == values.ndim - 1:
        # (time, y, x) -> (time, 1, y, x) so it broadcasts over bands
        mask = mask[:, None, ...]

    mask = da.broadcast_to(mask, values.shape)
    masked = da.where(~mask, 0, values)

    if isinstance(darr, xr.DataArray):
        return darr.copy(data=masked)
    return masked


def force_unique_time(darr):
    """
    Make the ``time`` index of `darr` unique, as sometimes observations from
    the same day can be split in multiple obs sharing one timestamp.

    Every occurrence of a repeated timestamp gets a running counter
    (0, 1, 2, ... shared across all repeated timestamps, in array order)
    added to it, in units of the time dtype (nanoseconds for
    ``datetime64[ns]``). Non-repeated timestamps are left as they are.

    The ``time`` coordinate of `darr` is replaced in place and `darr` is
    returned.
    """
    stamps, occurrences = np.unique(darr.time, return_counts=True)
    repeated = stamps[occurrences > 1]

    offset = 0
    fixed = []
    for idx in range(darr.time.size):
        stamp = darr.time[idx].values
        if stamp in repeated:
            stamp = stamp + offset
            offset += 1
        fixed.append(stamp)

    darr['time'] = np.array(fixed)
    return darr


def harmonize(data):
    """
    Harmonize new Sentinel-2 data to the old baseline.

    For time steps with processing baseline >= 4, the 1000 offset is removed
    from the reflectance bands present in `data` (values are clipped at 1000
    first, so 0 nodata stays 0). Other bands (e.g. SCL) and older time steps
    are returned unchanged.

    https://planetarycomputer.microsoft.com/dataset/sentinel-2-l2a#Baseline-Change
    https://github.com/microsoft/PlanetaryComputer/issues/134

    Parameters
    ----------
    data: xarray.DataArray
        A DataArray with four dimensions: time, band, y, x, and a
        ``s2:processing_baseline`` coordinate along time.

    Returns
    -------
    harmonized: xarray.DataArray
        A DataArray with all values harmonized to the old
        processing baseline, in the same time and band order as `data`.
    """
    baseline = data.coords['s2:processing_baseline'].astype(float)
    is_old = (baseline < 4).values

    if is_old.all():
        return data

    offset = 1000
    bands = ["B01", "B02", "B03", "B04",
             "B05", "B06", "B07", "B08",
             "B8A", "B09", "B10", "B11", "B12"]

    band_order = data.band.data.tolist()
    to_process = [b for b in band_order if b in bands]

    old = data.isel(time=is_old)
    new = data.isel(time=~is_old)

    new_other = new.drop_sel(band=to_process)

    new_harmonized = new.sel(band=to_process).clip(offset)
    new_harmonized -= offset

    new = xr.concat([new_other, new_harmonized], "band").sel(band=band_order)
    merged = xr.concat([old, new], dim="time")

    # Concatenation groups time steps by baseline; undo that permutation so
    # the output keeps the input time order (later code uses the first/last
    # time step as the default period).
    positions = np.concatenate([np.flatnonzero(is_old),
                                np.flatnonzero(~is_old)])
    return merged.isel(time=np.argsort(positions))


import warnings

def _rescale_ts(ts,
                scale=2,
                order=1,
                preserve_range=True,
                nodata_val=0,
                sigma=0.5):
    
    if order > 1:
        raise ValueError('Skimage giving issues with nans and cubic interp')
    
    ts_dtype = ts.dtype
    
    if order > 0:    
        new_dtype = np.float32
        ts = ts.astype(new_dtype)
        ts[ts == nodata_val] = np.nan
    else:
        new_dtype = ts_dtype

    shape = ts.shape
    new_shape = shape[0], shape[1], int(shape[2] * scale), int(shape[3] * scale)
    new = np.empty(new_shape, dtype=new_dtype)
    
    anti_aliasing = None
    anti_aliasing_sigma = None
    
    if scale < 1:
        anti_aliasing = True,
        anti_aliasing_sigma = sigma
    
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for t in range(shape[0]):
            new[t, :, :, :] = skimage.transform.rescale(ts[t, ...],
                                                        scale=scale,
                                                        order=order,
                                                        preserve_range=preserve_range,
                                                        channel_axis=0,
                                                        anti_aliasing=anti_aliasing,
                                                        anti_aliasing_sigma=anti_aliasing_sigma)

    new[da.isnan(new)] = 0
    new = new.astype(ts_dtype)

    return new


def rescale_ts(ds20,
               scale=2,
               order=1,
               preserve_range=True,
               nodata_val=0,
               sigma=0.5):
    """
    Spatially rescale a dask-backed (time, band, y, x) DataArray by `scale`.

    Each dask block is resampled independently with `_rescale_ts`, and new
    pixel-center coordinates are generated from the original bounds.

    Parameters
    ----------
    ds20 : xr.DataArray
        Dask-backed cube with dims (time, band, y, x) and evenly spaced x/y
        coordinates.
    scale : float
        Spatial scale factor (e.g. 2 to halve the pixel size, 0.5 to double
        it).
    order : int
        Interpolation order (0 nearest, 1 bilinear), see `_rescale_ts`.
    preserve_range, nodata_val, sigma
        Forwarded to `_rescale_ts`.

    Returns
    -------
    xr.DataArray
        Rescaled cube with the dims, time, band and attrs of `ds20`, new
        `x`/`y` coordinates and an updated ``attrs['resolution']``.
    """
    def _new_size(size):
        # Same rounding as skimage.transform.rescale uses for its output,
        # so the declared chunks are integers that match each block's shape.
        return max(1, int(np.round(size * scale)))

    chunks = list(ds20.chunks)
    for axis in (-2, -1):
        chunks[axis] = tuple(_new_size(c) for c in chunks[axis])
    chunks = tuple(chunks)

    darr_scaled = da.map_blocks(
        _rescale_ts,
        ds20.data,
        dtype=ds20.dtype,
        chunks=chunks,
        scale=scale,
        order=order,
        preserve_range=preserve_range,
        nodata_val=nodata_val,
        sigma=sigma)

    xmin, ymin, xmax, ymax = ds20.satio.bounds
    # float(): the fallback coordinate spacing is a 0-d DataArray, which
    # cannot be stored as a netCDF attribute (see `.satio.cache`).
    new_res = float(ds20.attrs.get('resolution',
                                   ds20.x[1] - ds20.x[0])) / scale
    new_res_half = new_res / 2

    # Coordinate lengths must match the data: y is axis -2, x is axis -1.
    n_y, n_x = sum(chunks[-2]), sum(chunks[-1])

    new_x = np.linspace(xmin + new_res_half,
                        xmax - new_res_half,
                        n_x)

    new_y = np.linspace(ymax - new_res_half,
                        ymin + new_res_half,
                        n_y)

    ds20u = xr.DataArray(darr_scaled,
                         dims=ds20.dims,
                         coords={'time': ds20.time,
                                 'band': ds20.band,
                                 'y': new_y,
                                 'x': new_x},
                         attrs=dict(ds20.attrs))

    ds20u.attrs['resolution'] = new_res

    return ds20u


def load_sentinel2_tile(tile,
                        start_date,
                        end_date,
                        max_cloud_cover=90):
    """
    Lazily load Sentinel-2 L2A data of one MGRS tile from Planetary Computer.

    Items of the ``sentinel-2-l2a`` STAC collection matching `tile`, the
    "start_date/end_date" range and ``eo:cloud_cover < max_cloud_cover`` are
    stacked with stackstac into one DataArray per asset group.

    Parameters
    ----------
    tile : str
        MGRS tile id, e.g. '31UFS'.
    start_date, end_date : str
        Bounds of the STAC ``datetime`` search range.
    max_cloud_cover : number
        Strict upper bound on the item cloud cover.

    Returns
    -------
    tuple of xr.DataArray
        (10 m bands, 20 m bands, 60 m bands, SCL): uint16 for the spectral
        bands and uint8 for SCL, raw (unscaled) values with 0 as fill value.
        Only the time/band/y/x, id and s2:processing_baseline coords are
        kept (id, band and baseline as str), and time stamps are made
        unique with `force_unique_time`.
    """
    catalog = pystac_client.Client.open(
        "https://planetarycomputer.microsoft.com/api/stac/v1",
        modifier=planetary_computer.sign_inplace,
    )

    time_range = f"{start_date}/{end_date}"

    query_params = {"eo:cloud_cover": {"lt": max_cloud_cover},
                    "s2:mgrs_tile": {"eq": tile}}

    search = catalog.search(collections=["sentinel-2-l2a"],
                            datetime=time_range,
                            query=query_params)
    items = search.get_all_items()

    # NOTE(review): B09 appears in both the 20 and 60 groups — confirm this
    # is intended.
    assets = {10: ['B02', 'B03', 'B04', 'B08'],
              20: ['B05', 'B06', 'B07', 'B8A', 'B09',
                   'B11', 'B12'],
              60: ['B01', 'B09'],
              'scl': ['SCL']}

    chunksize = {10: 1024,
                 20: 512,
                 60: 512,
                 'scl': 512}

    dtype = {10: np.uint16,
             20: np.uint16,
             60: np.uint16,
             'scl': np.uint8}

    keep_vars = ['time', 'band', 'y', 'x', 'id', 's2:processing_baseline']

    ds = {}
    for res in assets.keys():
        ds[res] = stackstac.stack(items,
                                  assets=assets[res],
                                  chunksize=chunksize[res],
                                  xy_coords='center',
                                  rescale=False,
                                  dtype=dtype[res],
                                  fill_value=0)
        # the stackstac spec object cannot be serialized (e.g. to netCDF)
        ds[res].attrs.pop('spec', None)
        drop_vars = [v for v in ds[res].coords.keys() if v not in keep_vars]
        ds[res] = ds[res].drop_vars(drop_vars)
        ds[res] = force_unique_time(ds[res])

        # coerce dtypes
        for v in ['id', 'band', 's2:processing_baseline']:
            ds[res][v] = ds[res][v].astype(str)

    return ds[10], ds[20], ds[60], ds['scl']


def get_clusters(dsm, bands=None, n_clusters=20, coarsen=4):
    """
    Cluster the pixels of `dsm` with KMeans on a coarsened copy.

    The selected bands are block-averaged by `coarsen` along y and x,
    divided by 10000 and standardized (the scaler is fitted on pixels whose
    first band is > 0). All time steps are then clustered jointly into
    `n_clusters` groups, and the label map is upsampled back by `coarsen`
    with nearest-neighbour rescaling.

    Parameters
    ----------
    dsm : xr.DataArray
        Cube with dims (time, band, y, x). y/x sizes presumably need to be
        divisible by `coarsen` — TODO confirm.
    bands : list of str, optional
        Bands used as features. Defaults to B02, B03, B04, B08, B11, B12.
    n_clusters : int
        Number of clusters (at most 256: labels are stored as uint8).
    coarsen : int
        Coarsening factor along y and x.

    Returns
    -------
    xr.DataArray
        Persisted uint8 labels named 'kclusters' with dims
        (time, band=['K'], y, x).
    """
    if bands is None:
        bands = ['B02', 'B03', 'B04', 'B08', 'B11', 'B12']

    dsk = dsm.sel(band=bands).coarsen(y=coarsen,
                                      x=coarsen).mean().persist()

    dsk = dsk.chunk((-1, -1, -1, -1))

    data = dsk.data.compute() / 10000
    shape = data.shape
    # (time, band, y, x) -> (time * y * x, band) feature matrix
    data = np.transpose(data, [0, 2, 3, 1]).reshape((-1, shape[1]))

    scaler = StandardScaler().fit(data[data[:, 0] > 0, :])
    data = scaler.transform(data)

    # use the caller's n_clusters (a hard-coded 20 used to override it)
    labels = KMeans(n_clusters).fit(data).labels_

    kim = labels.reshape((shape[0], shape[2], shape[3])).astype(np.uint8)
    kim = np.expand_dims(kim, 1)

    kimd = xr.DataArray(kim,
                        dims=('time', 'band', 'y', 'x'),
                        coords={'time': dsk.time,
                                'band': ['K'],
                                'y': dsk.y,
                                'x': dsk.x},
                        name='kclusters').chunk((-1, -1, -1, -1))

    kimd20 = kimd.satio.rescale_ts(scale=coarsen, order=0)
    kimd20 = kimd20.persist()

    return kimd20


def get_dark_bright_map(kclusters_block, scl_block):
    """
    Label each k-means cluster as dark (0), other (1) or bright (2) from the
    SCL classes of its pixels, and map those labels back to the pixels.

    SCL values 0, 2 and 3 count as dark and 8, 9, 10 and 11 as bright;
    anything else is "other". Each cluster takes the class with the largest
    pixel share, after the "other" share is lowered by 0.2 (favouring dark
    and bright).

    Parameters
    ----------
    kclusters_block : xr.DataArray
        Cluster labels 0..K-1 with dims (time, band, y, x); band 0 is used.
    scl_block : xr.DataArray
        SCL classes on the same grid, dims (time, band, y, x); band 0 is used.

    Returns
    -------
    xr.DataArray
        uint8 map with dims (time, band=['K2'], y, x), named 'darkbright',
        in a single chunk.
    """
    # labels run from 0 to max, so there are max + 1 clusters
    # (range(max) skipped the last cluster, which then stayed 0)
    n_clusters = int(kclusters_block.max().values) + 1

    k = kclusters_block.isel(band=0).data.compute()
    s = scl_block.isel(band=0).data.compute()

    dark = 0, 2, 3
    bright = 8, 9, 10, 11

    s2 = np.ones(s.shape, dtype=np.uint8)
    s2[np.isin(s, dark)] = 0
    s2[np.isin(s, bright)] = 2

    counts = []
    for kv in range(n_clusters):
        values, n_values = np.unique(s2[k == kv], return_counts=True)
        counts.append({int(v): int(n) for v, n in zip(values, n_values)})

    # Fixed [dark, other, bright] columns: columns built from the dicts
    # follow first appearance, so positional access (iloc) was unreliable.
    df = pd.DataFrame(counts).reindex(columns=[0, 1, 2]).fillna(0)
    # clusters without pixels give 0/0: treat their shares as 0
    dfp = df.divide(df.sum(axis=1), axis=0).fillna(0)

    dfp[1] -= 0.2

    dark_bright_mapping = dfp.idxmax(axis=1).to_dict()
    kdb = np.zeros(s2.shape, dtype=np.uint8)
    for i, v in dark_bright_mapping.items():
        kdb[k == i] = v

    kdb = np.expand_dims(kdb, 1)
    kdb = xr.DataArray(kdb,
                       dims=('time', 'band', 'y', 'x'),
                       coords={'time': scl_block.time,
                               'band': ['K2'],
                               'y': scl_block.y,
                               'x': scl_block.x},
                       name='darkbright').chunk((-1, -1, -1, -1))

    return kdb


from scipy.ndimage import uniform_filter, gaussian_filter
from skimage.filters import sobel
from tqdm.auto import tqdm

# CDI does not work well with cirrus clouds. The cirrus band could help, but it is
# not available, so the CDI algorithm cannot be fully applied. Artefacts are also present.

def focal_variance(img, window_size=(1, 7, 7)):
    """
    Calculate the focal variance of `img` over a moving box window.

    Uses Var = E[x^2] - E[x]^2, with both moments computed by
    `scipy.ndimage.uniform_filter` in float32.

    Parameters
    ----------
    img : np.ndarray
        Input array of any dtype. With the default `window_size` it must be
        3d (e.g. (time, y, x), filtering only over y and x).
    window_size : int or tuple of int
        Window size per dimension (`uniform_filter` `size`).

    Returns
    -------
    np.ndarray
        float32 variance, same shape as `img`, clipped at 0.
    """
    img32 = img.astype(np.float32)
    focal_mean = uniform_filter(img32, size=window_size)
    mean_sq = uniform_filter(img32 ** 2, size=window_size)
    variance = mean_sq - focal_mean ** 2
    # float32 cancellation can make E[x^2] - E[x]^2 slightly negative
    return np.maximum(variance, 0)



def calc_cdi(darr, window_size):
    """
    Calculate the Cloud Displacement Index, as per Frantz et al (2018).

    Parameters
    ----------
    darr : xr.DataArray
        Dask-backed cube with a ``band`` dimension containing 'B07', 'B08'
        and 'B8A' (`.compute()` is called on each band).
    window_size : int or tuple of int
        Moving-window size passed to `focal_variance`; a tuple needs one
        entry per dimension of the squeezed band arrays.

    Returns
    -------
    tuple of np.ndarray
        (r8, r7, v8, v7, cdi): the B08/B8A and B07/B8A ratios (non-finite
        values set to 0), their focal variances, and
        cdi = (v7 - v8) / (v7 + v8), 0 where v7 + v8 == 0.
    """
    # Equations 5 & 6
    b07 = np.squeeze(darr.sel(band='B07').data).compute()
    b08 = np.squeeze(darr.sel(band='B08').data).compute()
    b8a = np.squeeze(darr.sel(band='B8A').data).compute()

    # B8A == 0 (e.g. nodata) yields inf/nan, which are reset to 0 below;
    # silence the corresponding divide/invalid RuntimeWarnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        r8 = b08 / b8a
        r7 = b07 / b8a

    r8[~np.isfinite(r8)] = 0
    r7[~np.isfinite(r7)] = 0

    # Equation 7
    v8 = focal_variance(r8, window_size)
    v7 = focal_variance(r7, window_size)

    # Mask out where we would divide by zero
    cdi = np.zeros(v7.shape, dtype=np.float32)
    valid = ((v7 + v8) != 0)
    cdi[valid] = (v7[valid] - v8[valid]) / (v7[valid] + v8[valid])

    return (r8, r7, v8, v7, cdi)


def filter_cdi(bright_mask, cdi):
    """
    Keep only the bright pixels confirmed as clouds by the CDI.

    Seeds are bright pixels with cdi < -0.5, eroded by one pixel. They are
    then region-grown (binary dilation repeated until convergence) within
    the bright pixels with cdi < -0.25.

    Parameters
    ----------
    bright_mask : np.ndarray of bool
        Candidate bright pixels. Modified in place.
    cdi : np.ndarray
        Cloud Displacement Index, same shape as `bright_mask`.

    Returns
    -------
    np.ndarray of bool
        `bright_mask`, with every pixel outside the grown selection set to
        False.
    """
    # `scipy` itself was never imported in this notebook (NameError)
    from scipy import ndimage

    selection = bright_mask & (cdi < -0.5)

    # erode selection with 1 px
    selection = ndimage.binary_erosion(selection)

    # region grow within (cdi < -0.25); iterations=0 repeats until stable
    rg_mask = bright_mask & (cdi < -0.25)
    selection = ndimage.binary_dilation(selection,
                                        mask=rg_mask,
                                        iterations=0)
    bright_mask[~selection] = False

    return bright_mask


def ram_usage():
    """Print the system RAM usage (percentage used and used GB) via psutil."""
    import psutil

    memory = psutil.virtual_memory()
    print('RAM memory % used:', memory.percent)
    print('RAM Used (GB):', memory.used / 1000000000)
    

# worldcover settings: per-product processing options used by this notebook
settings = dict(
    # Sentinel-2 L2A
    l2a=dict(
        bands=["B02", "B03", "B04", "B05",
               "B06", "B07", "B08", "B11", "B12"],
        rsis=["ndvi", "nbr", "nbr2", "evi",
              "ndmi", "ndwi", "ndgi"],
        # compositing frequency / window, in days
        composite=dict(freq=10, window=20),
        # options unpacked and passed to `scl_to_mask`
        mask=dict(erode_r=3,
                  dilate_r=13,
                  mask_values=[1, 3, 8, 9, 10, 11],
                  max_invalid_ratio=1),
    ),
    # gamma0 (VV / VH bands)
    gamma0=dict(
        bands=["VV", "VH"],
        rsis=["vh_vv"],
        composite=dict(freq=10, window=20),
    ),
)



# with open('../../sastoken') as f:
#     sas_token = f.read()
    
# container_client = azure.storage.blob.ContainerClient(
#     "https://dza2.blob.core.windows.net",
#     container_name="feats",
#     credential=sas_token,
# ) 

# """Algo:
# - compute dark and bright features (use blue band for bright and swir for dark)
# - join dark and bright pixels with SCL dark and bright + cirrus clouds
# - use MNDWI to remove dark water
# - use CDI to remove bright buildings
# - do erosion - dilation to remove single pixels
# - remove outliers based on frequency of dark and bright
#     - set threshold on image quality and probability that pixel is dark
#       e.g. if image is cloudy, discard pixel, if image is less than x cloud
#       and frequency of pixel dark/bright is above thresh vs block mean
#       then keep pixel (remove from mask)

# - buffer mask


# Then we can run the block median value and keep it as ref. for each band?
# plus we start rolling and we want to select:
#     1. best acquisition closer to window center
#         - acquisition score based on block invalid pixels
#         - if there is no clear image (cloud less than th=80)
#           then check which pixel/acquisition is closest to the median value for the
#           target bands, this gives the selected acquisition.
#         - store distance from median and cloud % of the acquisition as
#           quality metric
#     2. when using short windows, we can keep more good obs without cloud cover
#        but we will have more noise from artefacts and cloud residuals
#        After the timeseries is built, we can filter out outliers based on some
#        thresholds of the quality.
#        Once these are thresholded out along the time dimension,
#        we can run the interpolation to fill in missing values
       
    
# """


In [None]:
import tempfile
import ipywidgets as ipw
import hvplot.xarray # noqa
import hvplot.pandas # noqa
import panel as pn
import pandas as pd
import panel.widgets as pnw
import xarray as xr


@xr.register_dataarray_accessor("satio")
class ESAWorldCoverTimeSeries:
    """
    ``DataArray.satio`` accessor exposing this notebook's helpers
    (rescaling, compositing, harmonization, masking, caching and plotting).

    Methods assume a cube with dims (time, band, y, x) and evenly spaced
    ``x``/``y`` pixel-center coordinates.
    """

    def __init__(self, xarray_obj):
        self._obj = xarray_obj
        # Side effect: stores the spatial bounds in the wrapped object's attrs
        self._obj.attrs['bounds'] = self.bounds
        # run check that we have a timeseries
        # assert xarray_obj.dims == ('time', 'band', 'y', 'x')

    def rescale_ts(self,
                   scale=2,
                   order=1,
                   preserve_range=True,
                   nodata_val=0,
                   sigma=0.5):
        """
        Spatially rescale by `scale` (module-level `rescale_ts`).

        `sigma` is the anti-aliasing sigma used when downscaling; its default
        matches the module-level function.
        """
        return rescale_ts(self._obj,
                          scale=scale,
                          order=order,
                          preserve_range=preserve_range,
                          nodata_val=nodata_val,
                          sigma=sigma)

    def composite(self,
                  freq=7,
                  window=None,
                  start=None,
                  end=None,
                  use_all_obs=False):
        """Moving median composite (see `calculate_moving_composite`)."""
        return calculate_moving_composite(self._obj,
                                          freq,
                                          window,
                                          start,
                                          end,
                                          use_all_obs)

    @property
    def bounds(self):
        """
        (xmin, ymin, xmax, ymax) of the pixel edges.

        Computed from the x spacing assuming square pixels, ascending x and
        descending y coordinates.
        """
        darr = self._obj

        res = darr.x[1] - darr.x[0]
        hres = res / 2

        xmin = (darr.x[0] - hres).values.tolist()
        xmax = (darr.x[-1] + hres).values.tolist()

        ymin = (darr.y[-1] - hres).values.tolist()
        ymax = (darr.y[0] + hres).values.tolist()

        return xmin, ymin, xmax, ymax

    def harmonize(self):
        """Remove the baseline >= 4 reflectance offset (see `harmonize`)."""
        return harmonize(self._obj)

    def mask(self, mask):
        """Set pixels to 0 where `mask` is False (see `mask_clouds`)."""
        return mask_clouds(self._obj, mask)

    def cache(self, tempdir=None, chunks=None):
        """
        Write the array to a temporary netCDF file and reopen it lazily.

        Parameters
        ----------
        tempdir : str, optional
            Directory of the temporary file (system default if None).
        chunks : optional
            Chunks of the reopened array. Defaults to the current chunks, or
            xarray's default chunking when the array is not dask-backed.

        Returns
        -------
        xr.DataArray
            Array backed by the temporary file, whose handle is closed at
            interpreter exit.
        """
        tmpfile = tempfile.NamedTemporaryFile(suffix='.nc',
                                              prefix='satio-',
                                              dir=tempdir)

        if chunks is None:
            chunks = self._obj.chunks
        if chunks is None:
            # not dask-backed: `.chunks` is None, use xarray's default `{}`
            chunks = {}

        self._obj.to_netcdf(tmpfile.name)
        darr = xr.open_dataarray(tmpfile.name).chunk(chunks)

        atexit.register(tmpfile.close)
        return darr

    def rgb(self, bands=None, vmin=0, vmax=1000):
        """
        Interactive hvplot RGB view with a time slider.

        Values are clipped to [vmin, vmax] and rescaled to [0, 1].
        """
        bands = ['B04', 'B03', 'B02'] if bands is None else bands

        # subtract vmin so that vmin != 0 still maps to [0, 1]
        im = ((self._obj.sel(band=bands).clip(vmin, vmax) - vmin)
              / (vmax - vmin))

        return im.interactive.sel(time=pnw.DiscreteSlider).hvplot.rgb(
                x='x', y='y', bands='band', data_aspect=1,
                flip_yaxis=False, xaxis=False, yaxis=None)

    def plot(self, band=None, vmin=None, vmax=None, colormap='plasma'):
        """Interactive single-band plot with a time slider (first band by default)."""
        im = self._obj
        band = im.band[0] if band is None else band
        im = im.sel(band=band)
        return im.interactive.sel(time=pnw.DiscreteSlider).plot(vmin=vmin,
                                                                vmax=vmax,
                                                                colormap=colormap)

In [None]:
# # get bright and dark masks based on clusters
# from skimage.morphology import binary_erosion, binary_dilation, disk

# def binary_erosion_ts(ts, r, band=0):
#     d = disk(r)
#     for t in range(ts.shape[0]):
#         ts[t, band, ...] = binary_erosion(ts[t, band, ...], d)
#     return ts

# def binary_dilation_ts(ts, r, band=0):
#     d = disk(r)
#     for t in range(ts.shape[0]):
#         ts[t, band, ...] = binary_dilation(ts[t, band, ...], d)
#     return ts


# n_clusters = 20

# # bright map
# kim = get_clusters(dsm20, bands=['B02', 'B03'],
#                    n_clusters=n_clusters)
# bright = get_dark_bright_map(kim, scl_block)
# bright = bright == 2

# n_clusters = 20

# # bright map
# kim = get_clusters(dsm20, bands=['B03', 'B08', 'B12'], n_clusters=n_clusters)
# dark = get_dark_bright_map(kim, scl_block)
# dark = dark == 0

# bright['band'] = ['bright']
# dark['band'] = ['dark']

# dark_vals = 2, 3
# bright_vals = 8, 9, 10, 11

# scl_nodata = (scl_block.data == 0).compute()
# scl_bright = np.isin(scl_block.data, bright_vals).compute()
# scl_dark = np.isin(scl_block.data, dark_vals).compute()

# merged_dark = dark | scl_dark
# merged_bright = bright | scl_bright

# merged_dark = binary_erosion_ts(merged_dark, 2)
# merged_dark = binary_dilation_ts(merged_dark, 5)

# merged_bright = binary_erosion_ts(merged_bright, 1)
# merged_bright = binary_dilation_ts(merged_bright, 10)

# merged_mask = merged_dark.satio.rescale_ts(scale=2, order=0).data | merged_bright.satio.rescale_ts(scale=2, order=0).data

# mask = np.broadcast_to(merged_mask.compute(), im2.shape)

In [None]:
ram_usage()

In [None]:
from satio_pc.grid import get_blocks_gdf

# Processing target: one MGRS tile between start_date and end_date
year = 2022
tile_id = '31UFS'

start_date = f'{year}-01-01'
end_date = f'{year}-05-01'
max_cloud_cover = 90

# Processing blocks of the tile; only the first one is used below
blocks_gdf = get_blocks_gdf([tile_id])
block = blocks_gdf.iloc[0]
block  # NOTE(review): not the last expression of the cell, so not displayed

# Lazy stacks for the whole tile: 10 m, 20 m and 60 m bands plus SCL
ds10, ds20, ds60, scl = load_sentinel2_tile(tile_id,
                                        start_date,
                                        end_date,
                                        max_cloud_cover)

In [None]:
ds10

In [None]:
# Unpack the SCL mask options from the 'l2a' settings and build the mask
scl_settings = settings['l2a']['mask']

keys = "erode_r, dilate_r, max_invalid_ratio, mask_values".split(', ')
erode_r, dilate_r, max_invalid_ratio, mask_values = [scl_settings[k]
                                                     for k in keys]

scl_post = scl_to_mask(scl,
                       mask_values=mask_values,
                       erode_r=erode_r,
                       dilate_r=dilate_r,
                       max_invalid_ratio=max_invalid_ratio)
scl20 = scl_post.mask

### Block processing 

In [None]:
# Clip the tile stacks to the block bounds; spectral bands are harmonized
ds10_block = ds10.rio.clip_box(*block.bounds).satio.harmonize()
ds20_block = ds20.rio.clip_box(*block.bounds).satio.harmonize()
scl20_block = scl20.rio.clip_box(*block.bounds)

# ds60_block = ds60.rio.clip_box(*block.bounds).satio.harmonize()
# ds60_block = ds60_block.satio.cache('.', (-1, -1, 256, 256))

In [None]:
# Write the block inputs to temporary netCDF files in the working directory
# and reopen them lazily with (-1, -1, 256, 256) chunks
ds10_block = ds10_block.satio.cache('.', (-1, -1, 256, 256))
ds20_block = ds20_block.satio.cache('.', (-1, -1, 256, 256))
scl20_block = scl20_block.satio.cache('.', (-1, -1, 256, 256))

scl_block = scl.rio.clip_box(*block.bounds).satio.cache('.', (-1, -1, 256, 256))
# order=0 (nearest neighbour) keeps the mask values unchanged when upsampling
scl10_block = scl20_block.satio.rescale_ts(scale=2, order=0)
scl10_block = scl10_block.satio.cache('.', (-1, -1, 256, 256))

In [None]:
# 20 m bands upsampled (bilinear) and stacked with the 10 m bands
ds20_block_10m = ds20_block.satio.rescale_ts(scale=2,
                                             order=1)
dsm10 = xr.concat([ds10_block, ds20_block_10m], dim='band').satio.cache('.', (-1, -1, 128, 128))

In [None]:
# Same stack on the 20 m grid: the 10 m bands are downsampled
ds10_block_20m = ds10_block.satio.rescale_ts(scale=0.5, order=1)

dsm20 = xr.concat([ds10_block_20m, ds20_block], dim='band').satio.cache('.', (-1, -1, 128, 128))

In [None]:
ram_usage()

In [None]:
dsm10.satio.rgb()

In [None]:
# im3 = im2.where(np.broadcast_to(kdb10, im2.shape) == 1, 0)

In [None]:
# # MNDWI

# bands = ['B04', 'B03', 'B02']

# vmin = 0
# vmax = 1500

# b03 = dsm.sel(band=['B02']).data / 10000
# b12 = dsm.sel(band=['B12']).data / 10000

# mndwi = (b03 - b12) / (b03 + b12)

# mndwi = dsm.sel(band=['B03']).copy(data=mndwi).persist()

# mndwi.isel(band=0).interactive.sel(time=pnw.DiscreteSlider).hvplot(aspect=1,
#                                                                    colormap='PiYG',
#                                                                    clim=(-1, 1))
  

In [None]:
# mask clouds

In [None]:
# NOTE(review): the next cells call `.satio` on the result, so mask() must
# return a DataArray here — confirm.
dsm10 = dsm10.satio.mask(scl10_block)

In [None]:
dsm10

In [None]:
dsm10.satio.rgb()

In [None]:
# NOTE(review): here the masks are passed as raw arrays, presumably
# (time, y, x), while the cell above passes a DataArray — confirm mask()
# accepts and broadcasts both forms.
ds10_block = ds10_block.satio.mask(scl10_block.isel(band=0).data)
ds20_block = ds20_block.satio.mask(scl20_block.isel(band=0).data)

In [None]:
%%time
# Moving median composites; freq/window match settings['l2a']['composite']
ds10_block = ds10_block.satio.composite(
    freq=10,
    window=20,
    start='2022-01-01',
    end='2022-04-01').satio.cache('.', (-1, -1, 256, 256))

In [None]:
%%time
ds20_block = ds20_block.satio.composite(
    freq=10,
    window=20,
    start='2022-01-01',
    end='2022-04-01')

In [None]:
# Upsample the 20 m composites (bilinear) and cache them
ds20_block = ds20_block.satio.rescale_ts(scale=2, order=1).satio.cache('.', (-1, -1, 256, 256))

In [None]:
ram_usage()

### Interpolation

In [None]:
import numpy as np
from numba import guvectorize


def _interp1d(xnew, xvals, yvals, ynew):
    i = 0
    N = len(xvals)
    if xnew[0] < xvals[0]:
        # x_a = 0.0
        # y_a = 0.0
        # x_b = xvals[0]
        # y_b = yvals[0]
        ynew[0] = yvals[0]
        return
    elif xnew[-1] > xvals[-1]:
        ynew[-1] = yvals[-1]
        return
    else:
        while xnew[0] >= xvals[i] and i < N:
            i += 1
        if xnew[0] == xvals[i]:
            ynew[0] = yvals[i]
            return
        if i == N:
            i = N-1
        x_a = xvals[i-1]
        y_a = yvals[i-1]
        x_b = xvals[i]
        y_b = yvals[i]
    slope = (xnew[0] - x_a)/(x_b - x_a)
    ynew[0] = slope * (y_b-y_a) + y_a
    return


# numba gufunc layout shared by both kernels: a scalar target position, the
# (n,) valid positions and their (n,) values -> one interpolated value.
_INTERP1D_LAYOUT = "(),(n),(n) -> ()"

# float32 kernel
interp1d = guvectorize(['int64[:], int64[:], float32[:], float32[:]'],
                       _INTERP1D_LAYOUT,
                       nopython=True)(_interp1d)

# uint16 kernel
interp1d_uint16 = guvectorize(['int64[:], int64[:], uint16[:], uint16[:]'],
                              _INTERP1D_LAYOUT,
                              nopython=True)(_interp1d)


def interpolate_fast(arrs):
    """
    Fill NaN gaps along axis 1 of a 3d array, in place.

    Parameters
    ----------
    arrs : np.ndarray
        Array indexed as (band, time, pixel); interpolation runs along
        axis 1. Presumably float32, the only float signature compiled into
        `interp1d` — TODO confirm for other dtypes.

    Returns
    -------
    np.ndarray
        `arrs` itself: each NaN is replaced by linear interpolation between
        the neighbouring valid values (nearest valid value before the first
        / after the last one). Series without any valid value are left
        untouched.
    """
    for band in range(arrs.shape[0]):
        for px in range(arrs.shape[2]):
            y = arrs[band, :, px]  # view: assignments write into arrs

            nans_ids = np.isnan(y)
            xvals = np.where(~nans_ids)[0]
            xnew = np.where(nans_ids)[0]
            # nothing to fill, or nothing to interpolate from (the gufunc
            # cannot work with an empty set of valid samples)
            if xnew.size == 0 or xvals.size == 0:
                continue
            yvals = y[xvals]
            y[xnew] = interp1d(xnew, xvals, yvals)

    return arrs


def _interpolate_4d_float32(arrs):
    """
    Fill NaN gaps along axis 1 of a float32 4d array, in place.

    NOTE(review): this indexes arrs[band, :, y, x], i.e. it expects a
    (band, time, y, x) layout, whereas `_interpolate_4d_uint16` interpolates
    along axis 0 of (time, band, y, x) arrays, and `interpolate_4d`
    dispatches to both. Confirm which layout callers pass.

    Returns
    -------
    np.ndarray
        `arrs` itself, NaNs replaced by linear interpolation along axis 1
        (nearest valid value at the ends). Series without any valid value
        are left untouched.
    """
    for band in range(arrs.shape[0]):
        for px in range(arrs.shape[2]):
            for py in range(arrs.shape[3]):
                y = arrs[band, :, px, py]  # view into arrs

                nans_ids = np.isnan(y)
                xvals = np.where(~nans_ids)[0]
                xnew = np.where(nans_ids)[0]
                # skip complete series and all-NaN series (no samples)
                if xnew.size == 0 or xvals.size == 0:
                    continue
                yvals = y[xvals]
                y[xnew] = interp1d(xnew, xvals, yvals)

    return arrs


def _interpolate_4d_uint16(darr):
    """
    Fill 0-valued (nodata) gaps along axis 0 of a uint16 4d array, in place.

    The array is indexed as darr[:, band, y, x], presumably a
    (time, band, y, x) layout. Gaps are filled through the float32
    `interp1d` gufunc; its float results are truncated when written back
    into the uint16 array.

    Returns
    -------
    np.ndarray
        `darr` itself. Series that are entirely 0 are left untouched.
    """
    for band in range(darr.shape[1]):
        for py in range(darr.shape[2]):
            for px in range(darr.shape[3]):
                t = darr[:, band, py, px]  # view into darr

                nans_ids = (t == 0)
                xvals = np.where(~nans_ids)[0]
                xnew = np.where(nans_ids)[0]
                # skip complete series and all-nodata series (no samples)
                if xnew.size == 0 or xvals.size == 0:
                    continue
                yvals = t[xvals]
                t[xnew] = interp1d(xnew, xvals, yvals)

    return darr


def interpolate_4d(arrs):
    """
    Fill temporal gaps of a 4d array, dispatching on its dtype.

    float32 arrays have their NaNs filled (`_interpolate_4d_float32`),
    uint16 arrays their zeros (`_interpolate_4d_uint16`). Both modify
    `arrs` in place and return it.

    Raises
    ------
    ValueError
        For any other dtype.
    """
    dtype = arrs.dtype

    if dtype == np.float32:
        return _interpolate_4d_float32(arrs)

    if dtype == np.uint16:
        return _interpolate_4d_uint16(arrs)

    raise ValueError("Interpolate function is only available for "
                     "arrays of type float32 or uint16")

In [None]:
ds10_block.data

In [None]:
%%time
# NOTE(review): this interpolates the ds10_block composites, but the next
# cell wraps the result with ds20_block — check which block was intended
# (the band counts may differ).
i = interpolate_4d(ds10_block.compute().data)

In [None]:
ds20_block_i = ds20_block.copy(data=i)

In [None]:
%%time
# NOTE(review): no `to_10m` method is defined on the `satio` accessor in
# this notebook.
ds20_block_10 = ds20_block.satio.to_10m(bounds=block.bounds,
                                        order=3).persist()


In [None]:
%%time
# NOTE(review): `ds20_block_masked_comp` is not defined in this notebook.
ds20_block_masked_comp = ds20_block_masked_comp.persist()

In [None]:
import ipywidgets as ipw
import hvplot.xarray # noqa
import hvplot.pandas # noqa
import panel as pn
import pandas as pd
import panel.widgets as pnw
import xarray as xr

rgb_bands = ['B04', 'B03', 'B02']

# NOTE(review): `lim` is only assigned in a later cell; run that one first.
im = ds10_block.sel(band=rgb_bands).clip(0, lim) / lim

im.interactive.sel(time=pnw.DiscreteSlider).hvplot.rgb(
    x='x', y='y', bands='band', data_aspect=1,
    flip_yaxis=True, xaxis=False, yaxis=None)
# ds10_block.sel(band=rgb_bands).interactive.sel(time=pnw.DiscreteSlider).plot(vmin=0, vmax=2500)

# ds10_block.sel(band=rgb_bands).isel(time=0).hvplot.rgb(x='x', y='y', bands='band')



In [None]:
# Upper display value for the RGB stretch (values clipped to [0, lim])
lim = 1000

In [None]:
im = ds10_block.sel(band=rgb_bands).clip(0, lim) / lim

im.interactive.sel(time=pnw.DiscreteSlider).hvplot.rgb(
    x='x', y='y', bands='band', data_aspect=1,
    flip_yaxis=True, xaxis=False, yaxis=None)

