# Compare geomedian (GM) outputs computed in this notebook and by odc-stats

In [None]:
from odc.stats.plugins.gm import StatsGMS2
from odc.stats.tasks import TaskReader
from odc.stats.model import OutputProduct

In [None]:
# Tile to inspect. Look the tile up in the .geojson written by
# `odc-stats save-tasks`; the x/y index and the temporal range (`year`)
# must match the values used there.
x = 235
y = -15
year = "2024--P1Y"

# Morphological clean-up per SCL cloud class, as (operation, radius) steps
# passed to mask_cleanup. Keep in sync with the odc-stats run configuration.
cloud_filters = {
    "cloud shadows": [["dilation", 5]],
    "cloud medium probability": [["opening", 5], ["dilation", 5]],
    "cloud high probability": [["opening", 5], ["dilation", 5]],
    "thin cirrus": [["dilation", 5]],
}


## Accessing task database

In [None]:
# Dummy output product required by TaskReader; the product name and version
# are appended to the results path.
name, version = 'gm_s2', '0-0-1'
product = OutputProduct(
    name=name,
    version=version,
    short_name=name,
    location="",
    properties={},
    measurements=[],
)

# The task database written by `odc-stats save-tasks` is named after the
# product and temporal range (e.g. s2_l2a_2024--P1Y.db). Build the name from
# `year` so changing the tile parameters keeps the database in sync.
tidx = (year, x, y)
rdr = TaskReader(f"s2_l2a_{year}.db", product=product)
task = rdr.load_task(tidx)
task


## Option: compute the GM directly in the notebook

In [None]:
import datacube

# Open a connection to the Open Data Cube index; the app name identifies this application.
app_name = "geomad_s2"
dc = datacube.Datacube(app=app_name)

In [None]:
# Query the index directly for s2_l2a scenes over this tile's footprint.
# find_datasets only accepts an EPSG CRS for `like`, so reproject the tile geobox first.
# NOTE(review): assumes save-tasks used the same cloud_cover range; confirm.
search_area = task.geobox.to_crs("EPSG:4326")
sentinel2_datasets = dc.find_datasets(
    product=["s2_l2a"],
    time="2024",
    like=search_area,
    cloud_cover=(0, 10),
)

In [None]:
# Sanity check: the odc-stats task and the direct query should find the same number of datasets.
n_task_datasets = len(task.datasets)
n_queried_datasets = len(sentinel2_datasets)
n_task_datasets, n_queried_datasets

In [None]:
# Load the same bands the odc-stats GM plugin uses (plus the SCL mask band).
# The geomedian is computed jointly across all input bands, so loading only
# red/green/blue would give different red/green/blue GM values than the
# odc-stats run, which makes the comparison meaningless.
ds = dc.load(
    datasets=sentinel2_datasets,
    geopolygon=task.geobox.extent,
    measurements=["blue", "green", "red", "nir", "swir16", "swir22", "scl"],
    dask_chunks={"x": 5000, "y": 5000, "time": -1},
    resolution=(-10, 10),
    group_by="solar_day",
    output_crs=task.geobox.crs,  # the tile CRS, e.g. "ESRI:54034"
    driver="rio",
)

In [None]:
# Inspect the loaded data: the number of time steps (one per solar day) should
# equal the number of days listed for this tile in s2_l2a_2024--P1Y.csv.
ds

In [None]:
# enum_to_bool (used below) matches SCL classes by name, which presumably relies
# on the band's flags_definition attribute; copy it from the product metadata.
scl_flags = dc.list_measurements().loc["s2_l2a", "scl"]["flags_definition"]
ds["scl"].attrs["flags_definition"] = scl_flags
ds["scl"].attrs["flags_definition"]

In [None]:
from odc.algo import enum_to_bool, erase_bad
from odc.algo import mask_cleanup

# Apply the same cloud filtering as the GM plugin: start from SCL "no data"
# pixels, then OR in each cloud class after its morphological clean-up.
mask = ds["scl"]
bad = enum_to_bool(mask, ("no data",))
for cloud_class, c_filter in cloud_filters.items():
    # A key may be a single class name or a tuple of names filtered together.
    if not isinstance(cloud_class, tuple):
        cloud_class = (cloud_class,)
    # These steps must run for every key; nesting them under the `if` above
    # would silently skip tuple keys.
    cloud_mask = enum_to_bool(mask, cloud_class)
    cloud_mask_buffered = mask_cleanup(cloud_mask, mask_filters=c_filter)
    bad = cloud_mask_buffered | bad

# Drop the mask band, then replace every pixel flagged in `bad` with nodata.
ds = ds.drop_vars(["scl"])
ds = erase_bad(ds, bad)

In [None]:
from odc.algo import geomedian_with_mads
# Geomedian plus MADs of the cloud-masked time series, using the function's default options.
# NOTE(review): the odc-stats GM plugin is configured with its own options (e.g.
# work_chunks) and may pass non-default arguments here; confirm they match
# before comparing pixel values.
gm = geomedian_with_mads(ds)

In [None]:
# GM result computed in the notebook
gm

### Option: compute the GM with the odc-stats GM plugin

In [None]:
# Run the odc-stats GM plugin on the same task, mirroring what odc-stats does:
# input_data() loads the task datasets and already applies the plugin's
# native_transform (cloud masking), then reduce() computes the GM. Calling
# native_transform() again would re-apply it to data whose mask band has
# already been dropped.
# Use the same plugin configuration as the odc-stats run being compared.
# Named `gm_plugin` so it does not overwrite the notebook result held in `gm`.
gm_plugin = StatsGMS2(
    resampling="nearest",
    bands=["blue", "green", "red", "nir", "swir16", "swir22"],
    cloud_filters=cloud_filters,
    work_chunks=(100, 100),
)
xx = gm_plugin.input_data(task.datasets, task.geobox)
output = gm_plugin.reduce(xx)

In [None]:
# GM result produced by the odc-stats plugin
output

In [None]:
#compute
#output = output.compute()

### Load odc-stats run output

In [None]:
import xarray as xr

# Assumes the odc-stats output was saved under the current working directory.
# NOTE(review): `{y:03d}` renders y=-15 as "-15"; confirm this matches the
# path naming odc-stats used for negative tile indices.
basepath = f"x{x:03d}/y{y:03d}/{year}/gm_s2_x{x:03d}y{y:03d}_{year}"

bands = ['red', 'green', 'blue']


def _open_band_as(path, band):
    """Open one per-band GeoTIFF lazily and rename its first data variable to ``band``."""
    band_ds = xr.open_dataset(path, chunks={}).squeeze()
    source_var = list(band_ds.data_vars)[0]
    return band_ds.rename_vars({source_var: band})


datasets = [_open_band_as(f'{basepath}_{band}.tif', band) for band in bands]

combined = xr.merge(datasets)
combined

In [None]:
# COUNT layer written by odc-stats (presumably clear observations per pixel).
count_path = f'{basepath}_COUNT.tif'
count_ds = xr.open_dataset(count_path).squeeze()
count = count_ds.to_array().squeeze()
count.plot.imshow()

In [None]:
# robust=True sets the colour range from the 2nd/98th percentiles so outliers don't wash out the image.
combined["red"].plot.imshow(robust=True)