In [None]:
%load_ext autoreload
%autoreload 2

import psutil
import dask.distributed
import rioxarray
import numpy as np
import xarray as xr
from odc.stac import stac_load
from pystac_client import Client

import warnings
warnings.filterwarnings("ignore", category=UserWarning, message="angle from rectified to skew grid parameter lost in conversion to CF")

In [None]:
# Report how much RAM is currently available, in gigabytes; this is the figure
# to keep in mind when choosing the Dask chunk sizes further down.
memory_stats = psutil.virtual_memory()
available_memory = memory_stats.available  # bytes
bytes_per_gb = 1024 ** 3
available_memory_gb = available_memory / bytes_per_gb

print(f"Available memory: {available_memory_gb:.2f} GB")

In [None]:
# Initiate Dask Env
# Client() is called without arguments, so it presumably starts (or attaches to)
# a default local cluster — TODO confirm against the local Dask configuration.
client = dask.distributed.Client()
# Render the client summary in the cell output; the link it shows is the one the
# "Perform calculation" cell refers to for monitoring resource usage.
display(client)

In [None]:
# --- Data source -------------------------------------------------------------
# STAC endpoint of the Swiss Data Cube
catalog = Client.open("https://explorer.swissdatacube.org/stac")

# --- What to load ------------------------------------------------------------
product = "s2_l2"
measurements = ["B02", "B04", "B08", "SCL"]  # bands used for EVI, plus SCL for masking
time = ("2020-06-20", "2020-09-23")  # (start, end) of the search window

# --- Area of interest --------------------------------------------------------
# Alternative, larger extent:
# longitude = (8.270401, 9.270401)
# latitude = (46.538201, 47.538201)
longitude = (7.385559, 7.785187)
latitude = (47.444126, 47.623999)

# --- Output grid -------------------------------------------------------------
output_crs = "epsg:2056"
resolution = (-10.0, 10.0)

# --- Dask chunking -----------------------------------------------------------
# 2048 values are OK with ~21Gb memory available
chunks = dict(x=2048, y=2048, time=1)

In [None]:
# Mask function
# See https://custom-scripts.sentinel-hub.com/custom-scripts/sentinel-2/scene-classification/
# for valid_cats
def create_scl_clean_mask(scl, valid_cats=(4, 5, 6, 7, 11)):
    """Return a boolean mask that is True where ``scl`` holds a valid category.

    Parameters
    ----------
    scl : xarray.DataArray
        Sentinel-2 scene-classification (SCL) band; may be Dask-backed.
    valid_cats : sequence of int, optional
        SCL class codes to keep. The default keeps 4 (vegetation),
        5 (not vegetated), 6 (water), 7 (unclassified) and 11 (snow/ice);
        see the Sentinel Hub scene-classification page for the full legend.

    Returns
    -------
    Object of the same kind as ``scl`` holding booleans: True where the
    pixel's class is one of ``valid_cats``, False elsewhere.
    """
    # The default is a tuple (not a list) so it cannot be mutated between calls.
    # dask='allowed' passes Dask arrays straight to np.isin, so the mask can
    # stay lazy instead of forcing a computation here.
    return xr.apply_ufunc(np.isin, scl, valid_cats, dask='allowed')

In [None]:
%%time

# Search the catalog for every item of `product` that intersects the area and
# the time window defined in the configuration cell.
query = catalog.search(
    collections=[product],
    datetime=f"{time[0]}/{time[1]}",
    limit=100,  # page size requested from the server, not a cap on the total
    bbox=(longitude[0], latitude[0],
          longitude[1], latitude[1]),  # (west, south, east, north)
)
items = list(query.items())

# Fail fast on an empty search result: everything below assumes at least one
# item, and an explicit message here points directly at the search parameters.
if not items:
    raise ValueError(
        f"No '{product}' items found for {time[0]}/{time[1]} "
        f"within lon {longitude} / lat {latitude}"
    )

# Build the dataset from the matching items on the requested CRS/resolution,
# limited to the lon/lat extent; `chunks` makes the arrays Dask-backed.
load_options = dict(
    bands=measurements,
    lon=longitude,
    lat=latitude,
    crs=output_crs,
    resolution=resolution[1],
    chunks=chunks,
)
lazy_ds = stac_load(items, **load_options)

# Set to NaN every pixel whose scene class is not in the valid set, then drop
# the SCL band, which is only needed to build the mask.
scl_clean_mask = create_scl_clean_mask(lazy_ds.SCL)
lazy_ds = lazy_ds.where(scl_clean_mask).drop_vars('SCL')

# Several tiles can cover the area on the same day: truncate every timestamp to
# its calendar day, then average the observations that end up sharing a day.
day_index = lazy_ds['time'].to_index().to_period('D').to_timestamp()
lazy_ds['time'] = day_index
if not day_index.is_unique:
    lazy_ds = lazy_ds.groupby('time').mean(dim='time', skipna=True)

# Scale digital numbers to surface reflectance; values above 10000 are treated
# as saturated and turned into NaN before the division.
not_saturated = lazy_ds <= 10000
lazy_ds = lazy_ds.where(not_saturated) / 10000

# Daily EVI from the near-infrared (B08), red (B04) and blue (B02) bands,
# then daily LAI as a linear function of EVI.
nir, red, blue = lazy_ds.B08, lazy_ds.B04, lazy_ds.B02
evi_numerator = nir - red
evi_denominator = nir + 6 * red - 7.5 * blue + 1
evi = 2.5 * (evi_numerator / evi_denominator)
lai = 3.618 * evi - 0.118

# Per-pixel median of the daily LAI over the whole period (NaN values ignored)
lai = lai.median(dim='time', skipna=True)

In [None]:
# Debugging aid: uncomment to inspect the raw search result as a dictionary
# query.item_collection_as_dict()

In [None]:
%%time

# Perform the calculation: .load() triggers the pending computation and brings
# the result into memory. You can open the link generated by the
# "Initiate Dask Env" cell to monitor how your resources are used.
lai_sdc = lai.load()

In [None]:
# Quick-look plot of the computed LAI with the default colour scaling
lai_sdc.plot(cmap='Greens')

In [None]:
# Same plot with robust=True, so that outliers do not stretch the colour scale
lai_sdc.plot(robust=True, cmap='Greens')

In [None]:
# Write the LAI raster to disk as a Cloud Optimized GeoTIFF
output_path = "lai_sdc.tif"
lai_sdc.rio.to_raster(raster_path=output_path, driver="COG")