In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
import xarray as xr

import hdstats
import odc.algo

## Set up a local Dask cluster

In [None]:
from datacube.utils.rio import configure_s3_access
from datacube.utils.dask import start_local_dask
import os
import dask
from dask.utils import parse_bytes

# Configure the dashboard link to go over the JupyterHub proxy
dask.config.set({"distributed.dashboard.link":
                 os.environ.get('JUPYTERHUB_SERVICE_PREFIX', '/')+"proxy/{port}/status"});

# Figure out how much memory/cpu we really have (these are set by JupyterHub);
# fall back to 4 threads / 8Gb when the variables are absent or non-positive.
mem_limit = int(os.environ.get('MEM_LIMIT', '0'))
cpu_limit = float(os.environ.get('CPU_LIMIT', '0'))
# A fractional CPU_LIMIT below 1 (e.g. 0.5) must still give at least one thread:
# plain int(0.5) == 0 would start a worker with zero threads.
cpu_limit = max(1, int(cpu_limit)) if cpu_limit > 0 else 4
mem_limit = mem_limit if mem_limit > 0 else parse_bytes('8Gb')

# Leave 4Gb for the notebook itself, but never give dask a zero or negative budget:
# when the total is 8Gb or less, split it evenly between notebook and dask instead.
mem_limit = max(mem_limit - parse_bytes('4Gb'), mem_limit // 2)

# Close the previous client, if any, so that this cell can be re-run without issues
client = locals().get('client', None)
if client is not None:
    client.close()
    del client

client = start_local_dask(n_workers=1,
                          threads_per_worker=cpu_limit,
                          memory_limit=mem_limit)
display(client)

# Configure GDAL for s3 access
configure_s3_access(aws_unsigned=True,  # works only when reading public resources
                    client=client);

In [None]:
from datacube import Datacube
from odc.algo import fmask_to_bool, to_f32, from_float, xr_geomedian

# Sentinel-2A and Sentinel-2B ARD NBAR granule products; datasets from both are combined below
product = tuple(f'ga_{sat}_ard_nbar_granule' for sat in ('s2a', 's2b'))

# Datacube handle constructed with no arguments (i.e. default configuration)
dc = Datacube()

In [None]:
region_code = '56HLK'
time = ('2019-06-01', '2019-08-31')  # alternative window: ('2019-06', '2019-11')

# Gather datasets from every product in `product` for this tile and time range
dss = []
for p in product:
    dss.extend(dc.find_datasets(product=p, time=time, region_code=region_code))

# Matching tsmask datasets (their classification band is used for cloud masking later)
tsm_dss = dc.find_datasets(product='s2_tsmask', region_code=region_code, time=time)

len(dss), len(tsm_dss)

## Load data in its native projection (lazily, with Dask)

In [None]:
# Reflectance bands to composite; uncomment others to include them
data_bands = [
    # 'nbar_coastal_aerosol',
    'nbar_blue',
    'nbar_green',
    'nbar_red',
    # 'nbar_red_edge_1',
    # 'nbar_red_edge_2',
    # 'nbar_red_edge_3',
    # 'nbar_nir_1',
    # 'nbar_nir_2',
    # 'nbar_swir_2',
    # 'nbar_swir_3',
]

mask_bands = ['fmask']

# dss[0] below would otherwise fail with an uninformative IndexError
if not dss:
    raise ValueError(f'No datasets found for products {product} '
                     f'in region {region_code} over {time}')

# Load every dataset onto the CRS of the first one, at resolution (-10, 10) in CRS units
# (presumably metres for this UTM tile), with the grid aligned to (0, 0).
# Observations are grouped by solar day; dask_chunks makes the load lazy.
xx = dc.load(product=dss[0].type.name,
             output_crs=dss[0].crs,
             resolution=(-10, 10),
             align=(0, 0),
             measurements=data_bands + mask_bands,
             group_by='solar_day',
             datasets=dss,
             dask_chunks=dict(
                 x=1000,
                 y=1000)
            )

In [None]:
# Inspect the lazily-loaded (dask-backed) dataset: dimensions, bands, chunking
xx

In [None]:
# Load the tsmask classification onto exactly the same pixel grid as `xx`,
# with the same lazy 1000x1000 chunking.
tsm = dc.load(
    product='s2_tsmask',
    datasets=tsm_dss,
    like=xx.geobox,
    dask_chunks={'x': 1000, 'y': 1000},
)
tsm

In [None]:
# Select a subset_size x subset_size subsection to speed up testing;
# set use_subset = False to process the full loaded extent.
# x keeps the first `subset_size` columns, y keeps the last `subset_size` rows.
# NOTE: this overwrites xx/tsm in place, so re-running with a larger subset_size
# requires re-running the load cells first.
use_subset = True
subset_size = 3000
if use_subset:
    _roi = dict(x=np.s_[0:subset_size], y=np.s_[-subset_size:])
    xx = xx.isel(**_roi)
    tsm = tsm.isel(**_roi)

## Compute the geometric median of `data_bands`
1. Convert the cloud mask (fmask or tsmask classification) to boolean: `True` – use, `False` – do not use
2. Apply masking in native dtype for data bands only
3. Convert to `float32` with scaling
4. Reduce time dimension with geometric median
5. Convert back to native dtype with scaling

All steps are lazy Dask operations, so no actual computation is done until `.compute()` is called.

In [None]:
# Boolean "usable pixel" masks (True = use) from the two available classifications
tsm_nocloud = fmask_to_bool(tsm.classification, ('valid',))
fm_nocloud = fmask_to_bool(xx.fmask, ('water', 'snow', 'valid'))

In [None]:
# Pick the mask used for the composite: the tsmask-based one here;
# assign `fm_nocloud` instead to use the fmask-based mask.
nocloud = tsm_nocloud

In [None]:
# Scaling from stored integers to float32; differs per product, aim for 0-1 values
scale, offset = 1/10_000, 0

xx_data = xx[data_bands]
# Keep only pixels where `nocloud` is True (masking in the native dtype),
# then convert to float32 with the scaling above.
xx_clean = to_f32(odc.algo.keep_good_only(xx_data, where=nocloud),
                  scale=scale,
                  offset=offset)

# Geometric median over time, then back to int16, undoing the float scaling
# (scale=1/scale, offset=-offset/scale). Everything stays lazy until .compute().
yy = from_float(
    xr_geomedian(
        xx_clean,
        num_threads=1,  # disable internal threading, dask will run several concurrently
        eps=0.2*scale,  # 1/5 pixel value
        nocheck=True,   # disable some checks inside geomedian library that use too much ram
    ),
    dtype='int16',
    nodata=-999,
    scale=1/scale,
    offset=-offset/scale,
)

## Now we can run the computation

In [None]:
%%time
yy = yy.compute()  # executes the whole lazy pipeline (load, mask, geomedian, rescale) via dask

## Convert to RGBA and display

In [None]:
from odc.ui import to_rgba, to_png_data
from IPython.display import Image

# Render the composite's colour bands as an RGBA image.
# NOTE(review): clamp=3000 presumably is the native-unit value mapped to full brightness
# (0.3 reflectance at scale 1/10_000) — confirm against odc.ui.to_rgba docs.
rgba = to_rgba(yy, clamp=3000)

In [None]:
%%time
png_data = to_png_data(rgba.data)  # encode the RGBA pixel array as PNG bytes (displayed and saved below)

In [None]:
# Skip inline rendering when any image dimension reaches 4000 pixels
if max(rgba.shape) >= 4000:
    print('image too large to show')
else:
    display(Image(png_data))

In [None]:
# Save the PNG into the current working directory.
# NOTE(review): the 'jja' (Jun-Jul-Aug) label is hard-coded and does not follow `time`,
# and 'xxx' looks like a placeholder — update both if the query window changes.
png_path = f'rgba-{region_code}-s2ab-jja-xxx.png'
with open(png_path, 'wb') as png_file:
    png_file.write(png_data)

------------------------------------------------------------------