In [None]:
%load_ext autoreload

%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
from skimage.exposure import rescale_intensity

from sentinelhub import CRS, BBox, DataCollection, SHConfig
from sentinelhub import SentinelHubCatalog
from sentinelhub import SHConfig
from oxeo.core.stac import landsat
from oxeo.core.stac.constants import USWEST_URL ,ELEMENT84_URL
uswest_config = SHConfig()
uswest_config.sh_base_url = USWEST_URL
from oxeo.core.stac.constants import USWEST_URL ,ELEMENT84_URL, LANDSATLOOK_URL, LANDSAT_SEARCH_PARAMS

uswest_catalog = SentinelHubCatalog(config=uswest_config)
eu_catalog = SentinelHubCatalog(SHConfig())
from oxeo.core.data import get_aoi_from_stac_catalog
import os
os.environ["AWS_REQUEST_PAYER"] = "requester"

from oxeo.core.utils import get_bounding_box

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [10, 8]

# Example bbox

In [None]:
# Example AOI as a GeoJSON-style feature (positions are [lon, lat], ring closed
# by repeating the first vertex). get_bounding_box presumably reduces it to the
# extent handed to BBox in the next cell — confirm.
box = get_bounding_box({
    "geometry": {
        "type": "Polygon",
        "coordinates": [[
            [14.911966323852539, 37.30573714593416],
            [14.987583160400392, 37.30573714593416],
            [14.987583160400392, 37.345050859282736],
            [14.911966323852539, 37.345050859282736],
            [14.911966323852539, 37.30573714593416],
        ]],
    }
})

In [None]:
# Area of interest in WGS84, built from `box`.
# Other example AOIs kept from earlier runs (swap one in for the expression below):
#   BBox([49.9604, 44.7176, 51.0481, 45.2324], crs=CRS.WGS84)
#   BBox((-71.40254974365233, -46.9537775782648, -71.26213073730467, -46.89163931213445), crs=CRS.WGS84)
bbox = BBox(box, crs=CRS.WGS84)

# (start, end) date windows as ISO date strings.
landsat_time_interval = ("1999-01-01", "2002-03-01")
s1_s2_time_interval = ("2020-01-01", "2020-03-01")

# STAC query filter: keep scenes whose eo:cloud_cover is between 0 and 20.
search_params = {"query": {"eo:cloud_cover": {"gte": 0, "lte": 20}}}


In [None]:
# Sentinel-2 "sentinel-s2-l2a-cogs" scenes from the Element84 STAC catalog over
# the S1/S2 window, filtered by `search_params` (cloud cover 0-20).
s2_aoi = get_aoi_from_stac_catalog(
    catalog=ELEMENT84_URL,
    data_collection="sentinel-s2-l2a-cogs",
    bbox=bbox,
    time_interval=s1_s2_time_interval,
    search_params=search_params,
)

In [None]:
# Display the S2 stack returned above to check its dimensions, coordinates and bands.
s2_aoi

In [None]:
# Slice bounds used to crop the two spatial axes of every image shown below
# (None = open start, so at most the first 3000 positions are kept on each axis).
# NOTE(review): x_* slices the FIRST spatial axis (presumably rows / y) and y_*
# the second — the naming may be swapped; confirm before relying on it.
x_min, x_max = None, 3000
y_min, y_max = None, 3000

In [None]:
# True-colour preview: per-pixel median of the first 5 entries along the leading
# axis (presumably time), cropped, with channels moved last for imshow.
plt.figure(figsize=(10, 10))
s2_rgb_median = s2_aoi.sel(band=["B04", "B03", "B02"])[:5].median(axis=0)
img = s2_rgb_median[:, x_min:x_max, y_min:y_max].values.transpose(1, 2, 0)
# Stretch using the image's own NaN-ignoring min/max as the input range.
plt.imshow(rescale_intensity(img, (np.nanmin(img), np.nanmax(img))))

In [None]:
# Sentinel-1 scenes (descending orbit only) from the default Sentinel Hub catalog,
# requested at resolution=10. Radar is not affected by clouds, so no cloud filter
# is applied (empty search_params).
s1_aoi = get_aoi_from_stac_catalog(
    catalog=eu_catalog,
    data_collection=DataCollection.SENTINEL1,
    bbox=bbox,
    time_interval=s1_s2_time_interval,
    search_params={},
    resolution=10,
    orbit_state="descending",
)

In [None]:
# Display the S1 stack to check its dimensions and bands (vv and vh are used below).
s1_aoi

In [None]:
# Per polarisation: mean over the first 5 entries of the leading (presumably time)
# axis, cropped, then computed so the composite below works on concrete arrays.
vv, vh = (
    s1_aoi.sel(band=polarisation)[:5].mean(axis=0)[x_min:x_max, y_min:y_max].compute()
    for polarisation in ("vv", "vh")
)


In [None]:
# Spatial shape of the cropped VV mean (should respect the 3000 crop bounds).
vv.shape

In [None]:
# Two-regime false-colour composite of the S1 means: pixels with vv below
# `vv_thresh` use the first colour formula, all other pixels the second.
# NOTE(review): threshold and coefficients are magic numbers tied to the scaling
# of vv/vh — confirm the units of the loaded backscatter.
plt.figure(figsize=(10, 10))
vv_thresh = 80
low_vv_rgb = [vv, 8 * vv, 0.5 + 3 * vv + 2000 * vh]
high_vv_rgb = [3 * vv, 1.1 * vv + 8.75 * vh, 1.75 * vh]
# The 2-D condition broadcasts over the 3 colour planes; then move channels last.
rgb = np.where(vv < vv_thresh, low_vv_rgb, high_vv_rgb).transpose(1, 2, 0)

# Clip to [0, cutoff] and map that range onto [0, 1] for display.
cutoff = 2000
dis = rescale_intensity(rgb, in_range=(0, cutoff), out_range=(0, 1))

plt.imshow(dis)

In [None]:
# Landsat "landsat-c2l2-sr" scenes over the historical window, limited to
# eo:cloud_cover 0-10 and requested at resolution=10.
# Use the shared LANDSATLOOK_URL constant (as the Landsat prediction/NDVI cells do)
# instead of a duplicated hard-coded URL.
# NOTE(review): assumes LANDSATLOOK_URL is "https://landsatlook.usgs.gov/stac-server",
# the URL previously hard-coded here — confirm.
landsat_aoi = get_aoi_from_stac_catalog(
    catalog=LANDSATLOOK_URL,
    data_collection="landsat-c2l2-sr",
    bbox=bbox,
    time_interval=landsat_time_interval,
    search_params={"query": {"eo:cloud_cover": {"gte": 0, "lte": 10}}},
    resolution=10,
)

# Per-pixel minimum over the first 5 revisits (presumably to suppress bright
# clouds), cropped, channels last for imshow.
img = (
    landsat_aoi.sel(band=["red", "green", "blue"])[:5]
    .min(axis=0)[:, x_min:x_max, y_min:y_max]
    .values.transpose(1, 2, 0)
)

# Create the figure only after the data loaded, so a failed query leaves no empty figure.
plt.figure(figsize=(10, 10))
plt.imshow(rescale_intensity(img, (np.nanmin(img), np.nanmax(img))))

# Prediction

In [None]:

import numpy as np


import dask
from distributed import Client


from oxeo.water.models.segmentation import Segmentation2DPredictor, DaskSegmentationPredictor
from oxeo.water.models.segmentation import reconstruct_image_from_patches,stack_preds,reduce_to_timeseries
from oxeo.core.models.tile import load_tile_from_stac_as_dict, load_aoi_from_stac_as_dict, tile_from_id, TilePath, tile_to_geom
from oxeo.core import data
import matplotlib.pyplot as plt
from oxeo.core.constants import BAND_PREDICTOR_ORDER

In [None]:
# Semantic-segmentation predictor (Dask-backed) loaded from a local checkpoint.
# fs=None presumably means "read ckpt_path from the local filesystem" — confirm.
predictor = DaskSegmentationPredictor(fs=None, ckpt_path="../data/semseg_epoch_012.ckpt")

In [None]:
# Run the segmentation model on Sentinel-2 scenes with eo:cloud_cover 0-10.
# Returns the predictions (presumably per patch; they are stitched below) and the
# input stack `aoi` they were computed on.
preds, aoi = predictor.predict_stac_aoi(
    constellation="sentinel-2",
    catalog=ELEMENT84_URL,
    data_collection="sentinel-s2-l2a-cogs",
    bbox=bbox,
    time_interval=s1_s2_time_interval,
    search_params={"query": {"eo:cloud_cover": {"gte": 0, "lte": 10}}},
    resolution=10,
)

In [None]:
# Stitch the patch predictions back into full-size masks, one per revisit;
# aoi.shape is unpacked as (revisits, <unused>, height, width).
stack = stack_preds(preds)
revisits, _, target_h, target_w = aoi.shape
mask = reconstruct_image_from_patches(
    stack,
    revisits,
    target_h,
    target_w,
    patch_size=250,  # NOTE(review): presumably must match the predictor's patch size — confirm
)

In [None]:
# Optionally start a local Dask cluster before this cell to parallelise the compute, e.g.
#   client = Client(n_workers=4, threads_per_worker=1, memory_limit="16GB")

# Materialise only the first 4 revisits, then show the first mask.
mask_out = mask[:4, :, :].compute()
plt.imshow(mask_out[0])

In [None]:
# Reduce the computed masks to a time series (presumably one value per revisit).
# NOTE(review): mask_out is already a computed array (mask[:4, :, :].compute());
# calling .compute() here assumes reduce_to_timeseries returns a lazy object — confirm.
ts = reduce_to_timeseries(mask_out)
ts_out = ts.compute()
# Last expression -> rich notebook display instead of print().
ts_out

In [None]:
# Same segmentation run on Landsat 7 scenes only (eo:cloud_cover 0-10) over the
# historical window; rebinds preds/aoi to the Landsat results.
preds, aoi = predictor.predict_stac_aoi(
    constellation="landsat",
    catalog=LANDSATLOOK_URL,
    data_collection="landsat-c2l2-sr",
    bbox=bbox,
    time_interval=landsat_time_interval,
    search_params={
        "query": {
            "platform": {"in": ["LANDSAT_7"]},
            "eo:cloud_cover": {"gte": 0, "lte": 10},
        }
    },
    resolution=10,
)

In [None]:
# True-colour preview of Landsat revisit index 1, cropped, channels last.
# NOTE(review): the Landsat mask cell plots mask_out[0] (revisit 0) — confirm
# which date is meant to be compared with this image.
landsat_rgb = aoi.sel(band=["red", "green", "blue"])[1]
img = landsat_rgb[:, x_min:x_max, y_min:y_max].values.transpose(1, 2, 0)
plt.imshow(rescale_intensity(img, (np.nanmin(img), np.nanmax(img))))

In [None]:
# Stitch the Landsat patch predictions back into full-size masks, one per revisit;
# aoi.shape is unpacked as (revisits, <unused>, height, width).
stack = stack_preds(preds)
revisits, _, target_h, target_w = aoi.shape
mask = reconstruct_image_from_patches(
    stack,
    revisits,
    target_h,
    target_w,
    patch_size=250,  # NOTE(review): presumably must match the predictor's patch size — confirm
)

In [None]:
# Optionally start a local Dask cluster before this cell to parallelise the compute, e.g.
#   client = Client(n_workers=4, threads_per_worker=1, memory_limit="16GB")

# Materialise only the first 2 Landsat revisits, then show the first mask.
mask_out = mask[:2].compute()
plt.imshow(mask_out[0])

# NDVI

In [None]:
from oxeo.water.models.ndvi import NDVIPredictor

In [None]:
# NDVI for the Sentinel-2 scenes (presumably one image per revisit). No cloud
# filter is applied here (empty search_params).
# Reuse the notebook-wide `bbox` instead of rebuilding BBox(box, ...), so this cell
# follows whichever AOI the bbox cell selects.
ndvi_predictor = NDVIPredictor()

aoi = ndvi_predictor.predict_stac_aoi(
    catalog=ELEMENT84_URL,
    data_collection="sentinel-s2-l2a-cogs",
    bbox=bbox,
    time_interval=s1_s2_time_interval,
    search_params={},
    resolution=10,
)

In [None]:
# Median NDVI over the first 5 revisits.
s2_ndvi_median = aoi[:5].median(axis=0).compute()
plt.imshow(s2_ndvi_median)

In [None]:
# NDVI for the Landsat scenes over the historical window.
# Reuse the notebook-wide `bbox` instead of rebuilding BBox(box, ...), so this cell
# follows whichever AOI the bbox cell selects.
# NOTE(review): unlike the Landsat prediction cell, there is no platform or cloud
# filter here (empty search_params) — confirm this is intended.
aoi = ndvi_predictor.predict_stac_aoi(
    catalog=LANDSATLOOK_URL,
    data_collection="landsat-c2l2-sr",
    bbox=bbox,
    time_interval=landsat_time_interval,
    search_params={},
    resolution=10,
)

# Median NDVI over the first 5 revisits.
plt.imshow(aoi[:5].median(axis=0).compute())

In [None]:
# Shape of the Landsat NDVI stack (leading axis indexed as revisits by the plot above).
aoi.shape

# Soil Moisture

In [None]:
from oxeo.water.models.soil_moisture import SoilMoisturePredictor

# Soil-moisture estimate from Sentinel-1 via the default Sentinel Hub catalog.
# Reuse the notebook-wide `bbox` instead of rebuilding BBox(box, ...), so this cell
# follows whichever AOI the bbox cell selects.
# NOTE(review): the S1 load cell restricts to orbit_state="descending"; no orbit
# filter is passed here — confirm whether the predictor handles this.
soil_predictor = SoilMoisturePredictor()
aoi = soil_predictor.predict_stac_aoi(
    catalog=eu_catalog,
    data_collection=DataCollection.SENTINEL1,
    bbox=bbox,
    time_interval=s1_s2_time_interval,
    search_params={},
    resolution=10,
)

# Median over the first 5 revisits.
plt.imshow(aoi[:5].median(axis=0).compute(), cmap="jet")