In [None]:
# Automatically reload edited project modules (e.g. the oxeo.* packages) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import numpy as np
from skimage.exposure import rescale_intensity
import matplotlib.pyplot as plt

from sentinelhub import CRS, BBox, DataCollection, SHConfig
from sentinelhub import SentinelHubCatalog
from sentinelhub import SHConfig

import dask
from dask_kubernetes import make_pod_spec, KubeCluster
from distributed import Client

In [None]:
from oxeo.core.utils import get_bounding_box
from oxeo.core.stac import landsat
from oxeo.core.stac.constants import USWEST_URL, ELEMENT84_URL
from oxeo.core.data import get_aoi_from_landsat_shub_catalog, get_aoi_from_stac_catalog
from oxeo.water.models.segmentation import (
    Segmentation2DPredictor,
    DaskSegmentationPredictor,
    reconstruct_image_from_patches,
)
from oxeo.core.models.tile import (
    load_tile_from_stac_as_dict,
    load_aoi_from_stac_as_dict,
    tile_from_id,
    TilePath,
    tile_to_geom,
)
from oxeo.core import data
from oxeo.core.constants import BAND_PREDICTOR_ORDER

In [None]:
# Some imagery buckets on S3 are requester-pays; GDAL's /vsis3/ reader (and anything
# honouring this variable) needs it set to agree to the transfer charge.
os.environ["AWS_REQUEST_PAYER"] = "requester"

# Sentinel Hub catalog pointed at the US-West deployment.
uswest_config = SHConfig()
uswest_config.sh_base_url = USWEST_URL
uswest_catalog = SentinelHubCatalog(config=uswest_config)

# Sentinel Hub catalog using the default SHConfig() settings (presumably the EU
# endpoint — verify against your local sentinelhub config). Pass the config by
# keyword, as above, so it cannot be bound to a different positional parameter.
eu_catalog = SentinelHubCatalog(config=SHConfig())

# Dask setup

In [None]:
def kube_cluster(
    workers: int = 3,
    memory: str = "32G",
    cpu: int = 4,
    image: str = "413730540186.dkr.ecr.eu-central-1.amazonaws.com/flows:latest",
) -> "KubeCluster":
    """Start a Dask cluster on Kubernetes for running the segmentation predictor.

    Args:
        workers: Number of worker pods to request.
        memory: Memory request *and* limit for each worker pod (e.g. "32G").
        cpu: CPU request *and* limit for each worker pod.
        image: Container image used for both the scheduler and the worker pods.
            Defaults to the project's ``flows`` image in ECR.

    Returns:
        A ``dask_kubernetes.KubeCluster``; connect to it with ``distributed.Client``.
    """
    # Request == limit pins every worker pod to a fixed allocation.
    worker_spec = make_pod_spec(
        image=image,
        cpu_request=cpu,
        cpu_limit=cpu,
        memory_request=memory,
        memory_limit=memory,
    )
    # The scheduler pod uses the same image but no explicit resource settings.
    scheduler_spec = make_pod_spec(image=image)
    return KubeCluster(
        n_workers=workers,
        pod_template=worker_spec,
        scheduler_pod_template=scheduler_spec,
    )

In [None]:
# Spin up the Kubernetes-backed Dask cluster (sizing spelled out; identical to kube_cluster's defaults).
cluster = kube_cluster(workers=3, memory="32G", cpu=4)

In [None]:
# Display the cluster object (its notebook repr) to check that it started.
cluster

In [None]:
# Connect a Dask client to the Kubernetes cluster.
# For a quick local run without Kubernetes, use this instead:
#   client = Client(n_workers=4, threads_per_worker=1, memory_limit="16GB")
client = Client(address=cluster)
client

# Create and run predictor

In [None]:
# Area of interest in WGS84, ordered (min_x, min_y, max_x, max_y) = (lon_min, lat_min, lon_max, lat_max).
box = (14.9119, 37.3057, 14.9875, 37.3450)
bbox = BBox(box, crs=CRS.WGS84)

# (start, end) dates for the STAC search.
time_interval = ("2020-12-10", "2021-02-01")

In [None]:
# Trained segmentation checkpoint, relative to the kernel's working directory.
CKPT_PATH = "../data/semseg_epoch_012.ckpt"

s2_predictor = DaskSegmentationPredictor(
    ckpt_path=CKPT_PATH,
    fs=None,  # NOTE(review): presumably None means "local filesystem" — confirm in DaskSegmentationPredictor
    bands=BAND_PREDICTOR_ORDER["sentinel-2"],
)

In [None]:
# STAC query for Sentinel-2 L2A COGs over the AOI (Element84 catalog).
stac_query = dict(
    constellation="sentinel-2",
    catalog=ELEMENT84_URL,
    data_collection="sentinel-s2-l2a-cogs",
    bbox=bbox,
    time_interval=time_interval,
    search_params={},
)
# `preds` is a lazy graph that is submitted to the cluster via client.compute below;
# `aoi` is the input imagery the predictions are made on.
preds, aoi = s2_predictor.predict_stac_aoi(**stac_query)

In [None]:
# Inspect the loaded AOI. NOTE(review): presumably an xarray DataArray indexed
# (time, band, y, x), judging by how it is sliced in the RGB preview — confirm.
aoi

In [None]:
# Submit the lazy prediction graph to the cluster. This returns futures right away
# without blocking; their results are collected when building the mask.
res = client.compute(collections=preds)

In [None]:
# Show the futures returned by client.compute (re-run to see their current status).
res

In [None]:
# Fetch all patch predictions in one batch. client.gather is the bulk, idiomatic
# equivalent of calling .result() on each future in turn, and it preserves order.
patch_preds = client.gather(res)
stack = np.vstack(patch_preds)

# Stitch the patches back into a full-size mask using the AOI's leading (time) and
# trailing two (spatial) dimensions.
# NOTE(review): patch_size must match the patch size the predictor used — TODO confirm 250.
mask = reconstruct_image_from_patches(
    stack, aoi.shape[0], aoi.shape[-2], aoi.shape[-1], patch_size=250
)

In [None]:
# RGB preview of timestep 1.
# NOTE(review): band indices [3, 2, 1] are assumed to be red/green/blue within
# BAND_PREDICTOR_ORDER["sentinel-2"] — confirm.
img = aoi[1, [3, 2, 1], :, :].values.transpose(1, 2, 0)

# Stretch the NaN-aware data range to floats in [0, 1]. With the default
# out_range="dtype", integer inputs keep their integer range, which imshow clips for
# RGB data. Inputs containing negatives map to [-1, 1]. An explicit range avoids both.
img_rgb = rescale_intensity(
    img, in_range=(np.nanmin(img), np.nanmax(img)), out_range=(0.0, 1.0)
)

fig, ax = plt.subplots()
ax.imshow(img_rgb)
ax.set_title("AOI RGB composite (timestep 1)");

In [None]:
# Predicted segmentation mask for the same timestep (index 1) as the RGB preview.
fig, ax = plt.subplots()
ax.imshow(mask[1, :, :])
ax.set_title("Predicted mask (timestep 1)");

# Don't forget to CLOSE the cluster!

In [None]:
# Disconnect the client first so it doesn't keep trying to reach a scheduler that
# is gone, then shut down the cluster (scheduler and worker pods).
client.close()
cluster.close()