In [None]:
# Reload edited modules automatically before executing code, so changes to
# imported packages are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
# Uncomment for interactive (ipympl) matplotlib figures.
# %matplotlib widget

In [None]:
import pystac_client
import planetary_computer
import azure.storage.blob
import rioxarray
import rasterio
import stackstac
import xarray as xr
import numpy as np
import dask
import dask.array as da
import dask_image.ndmorph as ndmorph
from dataclasses import dataclass


import satio_pc
from satio_pc import parallelize
from satio_pc.clouds import scl_to_mask


import skimage.transform


def resize_chunk(chunk, scale=2, order=0):
    """Upscale the two trailing (y, x) axes of a numpy chunk.

    Leading axes (e.g. time, band) are left untouched. Passing a scalar
    ``scale`` straight to ``skimage.transform.rescale`` would resize *every*
    axis, so a (1, 512, 512) chunk became (2, 1024, 1024) and no longer
    matched the chunk shape declared to ``da.map_blocks``.

    Parameters
    ----------
    chunk : array-like
        Array with at least 2 dims; the last two are treated as (y, x).
    scale : int, float or tuple
        A single number scales y and x only. A tuple is forwarded unchanged
        to skimage and must then give one factor per axis of ``chunk``.
    order : int
        Interpolation order (0 = nearest neighbour).

    Returns
    -------
    numpy.ndarray
        Upscaled array with the same dtype as ``chunk``.
    """
    chunk = np.asarray(chunk)

    if np.ndim(scale) != 0:
        # Explicit per-axis factors supplied by the caller.
        scales = scale
    else:
        if order == 0 and float(scale).is_integer() and scale >= 1:
            # Nearest-neighbour upscaling by an integer factor is plain pixel
            # replication: exact, dtype-preserving and far cheaper than a warp.
            factor = int(scale)
            return chunk.repeat(factor, axis=-2).repeat(factor, axis=-1)
        scales = (1,) * (chunk.ndim - 2) + (scale, scale)

    # preserve_range stops skimage from rescaling integer/bool input to
    # [0, 1] floats; cast back so the result matches the declared dtype.
    rescaled = skimage.transform.rescale(chunk, scale=scales, order=order,
                                         preserve_range=True)
    return rescaled.astype(chunk.dtype, copy=False)


def resize_chunk_scl(chunk):
    """Fixed 2x, order-0 (nearest-neighbour) upscale preset.

    Presumably intended for chunks of the SCL classification band, where
    interpolating between class codes would be meaningless.
    """
    upscaled = resize_chunk(chunk, order=0, scale=2)
    return upscaled
    
    
def rescale_da(darr, chunks=None, scale=2, order=0):
    """Upscale a dask array block-by-block with ``resize_chunk``.

    Parameters
    ----------
    darr : dask.array.Array
        Input array; every block is resized independently.
    chunks : tuple of tuples, optional
        Chunk structure of the *output*, which ``da.map_blocks`` needs
        because each block changes shape. If omitted (and ``scale`` is a
        single number), it is derived from ``darr.chunks``: leading axes
        keep their chunks, and every y/x chunk size becomes
        ``max(round(size * scale), 1)``.
    scale : int, float or tuple
        Scale factor forwarded to ``resize_chunk``.
    order : int
        Interpolation order forwarded to ``resize_chunk``.

    Returns
    -------
    dask.array.Array
        Lazily upscaled array, declared with ``darr.dtype``.

    Raises
    ------
    ValueError
        If ``chunks`` is omitted while ``scale`` is not a single number.
    """
    if chunks is None:
        try:
            factor = float(scale)
        except TypeError:
            raise ValueError(
                "chunks must be given explicitly when scale is not a single number"
            ) from None
        *lead, y_chunks, x_chunks = darr.chunks
        chunks = (*lead,
                  tuple(max(int(round(size * factor)), 1) for size in y_chunks),
                  tuple(max(int(round(size * factor)), 1) for size in x_chunks))

    # Extra keyword arguments (scale, order) are forwarded to resize_chunk.
    return da.map_blocks(
        resize_chunk,
        darr,
        dtype=darr.dtype,
        chunks=chunks,
        scale=scale,
        order=order)


def mask_clouds(darr, mask):
    """Set pixels of ``darr`` to 0 wherever ``mask`` is false.

    Parameters
    ----------
    darr : array-like
        Band stack with dims (time, band, y, x).
    mask : array-like
        Mask with dims (time, y, x); truthy = keep the pixel, falsy = set it
        to 0. Non-boolean masks (e.g. 0/1 uint8 or float) are cast to bool
        first. Applying ``~`` to a uint8 mask gives 255/254, both truthy,
        which would silently zero the whole stack.

    Returns
    -------
    dask.array.Array
        ``darr`` with masked pixels replaced by 0. This is a plain dask
        array: xarray coordinates on ``darr`` are not carried over.
    """
    keep = da.asarray(mask).astype(bool)
    # Insert a band axis so the (time, y, x) mask lines up with
    # (time, band, y, x), then repeat it across all bands.
    keep_bands = da.broadcast_to(da.expand_dims(keep, 1), darr.shape)
    return da.where(keep_bands, darr, 0)

In [None]:
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    # Sign the asset URLs of returned items in place.
    modifier=planetary_computer.sign_inplace,
)

# SAS token for the output blob container, read from a local file.
# strip(): a trailing newline (common in text files) would otherwise become
# part of the credential and break authentication.
with open('../../token') as f:
    sas_token = f.read().strip()

container_client = azure.storage.blob.ContainerClient(
    "https://dza2.blob.core.windows.net",
    container_name="feats",
    credential=sas_token,
)

In [None]:
# Processing target: calendar year and Sentinel-2 MGRS tile ID.
year, tile_id = 2020, '31UFS'

In [None]:
# Search the Planetary Computer STAC API for Sentinel-2 L2A scenes of
# `tile_id` within `year`, keeping scenes with eo:cloud_cover < 90.
time_range = f"{year}-01-01/{year + 1}-01-01"

# Use the configured tile instead of a hard-coded ID, so changing `tile_id`
# actually changes what is downloaded.
query_params = {"eo:cloud_cover": {"lt": 90},
                "s2:mgrs_tile": {"eq": tile_id}}

search = catalog.search(collections=["sentinel-2-l2a"],
                        datetime=time_range,
                        query=query_params)
items = search.get_all_items()

# NOTE(review): `bands` is not used by any visible cell; kept for reference.
bands = ['B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08',
         'B09', 'B11', 'B12', 'B8A', 'SCL']

assets_10m = ['B02', 'B03', 'B04', 'B08']
assets_20m = ['B05', 'B06', 'B07', 'B8A', 'B09',
              'B11', 'B12']
scl = 'SCL'

# One lazy stackstac stack per asset group, keyed by resolution / 'scl'.
assets = {10: assets_10m,
          20: assets_20m,
          'scl': [scl]}
ds = {res: stackstac.stack(items, assets=asset_list)
      for res, asset_list in assets.items()}

In [None]:
# Parameters forwarded to satio_pc's scl_to_mask.
# NOTE(review): mask_values presumably are Sentinel-2 L2A SCL class codes
# (0 no data, 1 saturated/defective, 3 cloud shadow, 8/9 cloud, 10 cirrus,
# 11 snow) — confirm against scl_to_mask's documentation.
mask_values = [0, 1, 3, 8, 9, 10, 11]
erode_r, dilate_r = 3, 12
max_invalid_ratio = 0.9

scl_data = ds['scl'].sel(band='SCL')
scl_mask = scl_to_mask(
    scl_data,
    mask_values=mask_values,
    erode_r=erode_r,
    dilate_r=dilate_r,
    max_invalid_ratio=max_invalid_ratio,
)

In [None]:
# Output chunk sizes for upscaling the 20 m mask: time chunks stay as they
# are, y/x chunk sizes are multiplied by `scale` (band chunks are dropped).
scale = 2
c0, _, c1, c2 = ds[20].chunks
c1 = tuple(size * scale for size in c1)
c2 = tuple(size * scale for size in c2)

In [None]:
ds20 = mask_clouds(ds[20], scl_mask.mask)

mask10 = rescale_da(scl_mask.mask,
                    chunks=(c0, c1, c2),
                    scale=scale).rechunk(chunks=(1, 1024, 1024))
ds10 = mask_clouds(ds[10], mask10)

In [None]:
# Inspect the raw (unmasked) 10 m band stack.
ds[10]

In [None]:
%%time

# Time computing a 256x256 window of band index 1 (presumably B03) for all
# dates: ts0 from the raw 10 m stack, ts1 from the cloud-masked stack.
ts0 = ds[10][:, 1, :256, :256].compute()
ts1 = ds10[:, 1, :256, :256].compute()

In [None]:
# ds10 comes from mask_clouds, which returns the bare dask array produced by
# da.where (no xarray coords), so .compute() above yields a numpy array.
# Re-wrap it with ts0's dims/coords so it can be plotted like ts0.
# NOTE(review): intended to run once, right after the compute cell above.
ts1 = ts0.copy(data=ts1)

In [None]:
import ipywidgets as ipw
import hvplot.xarray # noqa
import hvplot.pandas # noqa
import panel as pn
import pandas as pd
import panel.widgets as pnw
import xarray as xr

In [None]:
# Interactive per-date view of the raw window; the fixed colour limits match
# the masked view in the next cell, so the two are directly comparable.
ts0.interactive.sel(time=pnw.DiscreteSlider).plot(vmin=0, vmax=2500)

In [None]:
# Same interactive view for the cloud-masked window (masked pixels are 0).
ts1.interactive.sel(time=pnw.DiscreteSlider).plot(vmin=0, vmax=2500)

In [None]:
# Scratch alias for the cloud mask from scl_to_mask, used by the
# exploratory cells below.
darr = scl_mask.mask

In [None]:
# Show the 20 m cloud mask. (The name `mask20` used here previously is never
# defined in this notebook and raised NameError on a fresh kernel.)
scl_mask.mask

In [None]:
# Top-level scratch version of rescale_da, applied to `darr` (the cloud mask).
scale = 2
order = 0
chunks = (1, 1024, 1024)

shape = darr.shape
dtype = darr.dtype
new_shape = (shape[0], shape[1] * scale, shape[2] * scale)

# da.map_blocks needs the *output* chunk structure: time blocks unchanged,
# y/x block sizes multiplied by `scale`. Extra keyword arguments (scale,
# order) are forwarded to resize_chunk.
out_chunks = (darr.chunks[0],) + tuple(
    tuple(size * scale for size in axis) for axis in darr.chunks[1:])

darr_scaled = da.map_blocks(
    resize_chunk,
    darr,
    dtype=dtype,
    chunks=out_chunks,
    scale=scale,
    order=order,
).rechunk(chunks)

assert darr_scaled.shape == new_shape

In [None]:
# Wrap the bare cloud-mask array with the 20 m stack's coordinates so it can
# be interpolated onto the 10 m grid in later cells.
x, y, time = ds[20].x.values, ds[20].y.values, ds[20].time.values

# NOTE(review): coords come from the 20 m stack but the name says '10m' —
# presumably anticipating the later interpolation; confirm.
scl20_da = xr.DataArray(
    scl_mask.mask,
    coords=dict(time=time, y=y, x=x),
    dims=('time', 'y', 'x'),
    name='SCL 10m',
)

In [None]:
import numpy as np

In [None]:
# Cast the mask to uint8 before interpolation (presumably because interp
# does not handle bool data — confirm).
mask_10 = scl20_da.astype('uint8')

In [None]:
# Nearest-neighbour resample onto the y/x coordinates of the 10 m stack.
mask_10 = mask_10.interp(coords={'x': ds[10].x, 'y': ds[10].y}, method='nearest')

In [None]:
# Inspect the cloud mask wrapped as an xarray DataArray on the 20 m grid.
scl20_da

In [None]:
# Show the mask resampled onto the 10 m grid. (The name `resized_scl` used
# here previously is never defined and raised NameError on a fresh kernel.)
mask_10

In [None]:
# Apply the interpolated 10 m mask to the 10 m bands. Points outside the
# 20 m grid may come back from interp as NaN; fill them with 0 so they are
# treated as invalid instead of turning True when cast to bool.
ds10_masked = mask_clouds(ds[10], mask_10.fillna(0).astype(bool).data)

