In [1]:
from dask_gateway import GatewayCluster
import dask.distributed
import dask.utils
import planetary_computer
from pystac_client import Client
import odc.stac
import geopandas
import numpy
import xarray

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import Image

In [3]:
def expand_ls_qa_pixel_msks(scn_xa, qa_pxl_msk="QA_PIXEL"):
    """
    Expand a bit-packed Landsat Collection 2 QA band into separate 0/1 masks.

    Each output mask is a uint8 variable (1 where the flag is set, else 0)
    with the same dims, coords and attrs as the QA band. Bits decoded:

    ===============  ============
    Variable         Bit(s)
    ===============  ============
    FILL             0
    DILATED_CLOUDS   1
    CIRRUS           2
    CLOUDS           3
    CLOUD_SHADOWS    4
    SNOW             5
    CLEAR            6
    WATER            7
    ALL_CLOUDS       any of 1 - 4
    ===============  ============

    :param scn_xa: dataset (e.g. an xarray.Dataset) holding the QA band. Only
                   local pixel values are used, so it can be applied per block
                   with xarray's map_blocks.
    :param qa_pxl_msk: name of the QA band variable (default: "QA_PIXEL").
                       The band must hold integers (bitwise tests are used).
    :return: a copy of scn_xa with the nine mask variables above added;
             scn_xa itself is not modified.
    """
    qa_da = scn_xa[qa_pxl_msk]
    qa_vals = numpy.asarray(qa_da.values)

    # Bit mask tested for each output variable, in output order.
    qa_bit_msks = {
        "FILL": 1 << 0,
        "DILATED_CLOUDS": 1 << 1,
        "CIRRUS": 1 << 2,
        "CLOUDS": 1 << 3,
        "CLOUD_SHADOWS": 1 << 4,
        "SNOW": 1 << 5,
        "CLEAR": 1 << 6,
        "WATER": 1 << 7,
        # Any cloud-related flag: dilated cloud, cirrus, cloud or cloud shadow.
        "ALL_CLOUDS": (1 << 1) | (1 << 2) | (1 << 3) | (1 << 4),
    }

    scn_lcl_xa = scn_xa.copy()
    for msk_name, bit_msk in qa_bit_msks.items():
        # One vectorised bit test over the whole band, rather than looping
        # over the unique QA values and re-comparing the full band for each.
        msk_vals = ((qa_vals & bit_msk) != 0).astype(numpy.uint8)
        # copy(data=...) reuses the QA band's dims/coords/attrs for the mask.
        scn_lcl_xa[msk_name] = qa_da.copy(data=msk_vals)
    return scn_lcl_xa

In [4]:
def limit_range_np_arr(
    arr_data: numpy.ndarray,
    min_thres: float = 0,
    min_out_val: float = 0,
    max_thres: float = 1,
    max_out_val: float = 1,
) -> numpy.ndarray:
    """
    A function which can be used to limit the range of the numpy array.
    For example, to mask values less than 0 to 0 and values greater than
    1 to 1. NaN values compare false against both thresholds and so are
    left unchanged.

    :param arr_data: input numpy array (any array-like, e.g. a list, is
                     accepted and converted). The input is not modified.
    :param min_thres: the threshold for the minimum value.
    :param min_out_val: the value assigned to values below the min_thres
    :param max_thres: the threshold for the maximum value.
    :param max_out_val: the value assigned to the values above the max_thres
    :return: numpy array with output values (same dtype as the input array).

    """
    arr_in = numpy.asanyarray(arr_data)
    arr_data_out = arr_in.copy()
    # Both masks are computed from the untouched input so a value replaced
    # by min_out_val is not then re-tested against max_thres.
    arr_data_out[arr_in < min_thres] = min_out_val
    arr_data_out[arr_in > max_thres] = max_out_val
    return arr_data_out


def _cumulative_stretch_band(band_arr, lower, upper, out_off, out_gain):
    """
    Linearly rescale one float band so its lower/upper percentiles map to
    out_off and out_off + out_gain. NaN values are ignored when computing the
    percentiles and stay NaN in the output.
    """
    min_val, max_val = numpy.nanpercentile(band_arr, [lower, upper])
    range_val = max_val - min_val
    if range_val != 0:
        scaled = (band_arr - min_val) / range_val
    else:
        # Flat band (both percentiles equal): avoid a 0/0 division. Pixels at
        # the percentile value map to 0 and any pixels above/below it map to
        # +1/-1 before the gain and offset are applied.
        scaled = numpy.sign(band_arr - min_val)
    return (scaled * out_gain) + out_off


def cumulative_stretch_np_arr(
    arr_data: numpy.ndarray,
    no_data_val: float = None,
    lower: int = 2,
    upper: int = 98,
    out_off: float = 0,
    out_gain: float = 1,
    out_int_type=False,
    min_out_val: float = 0,
    max_out_val: float = 1,
) -> numpy.ndarray:
    """
    A function which performs a cumulative stretch using an upper and lower
    percentile to define the min-max values. This analysis is on a per
    band basis for a numpy array representing an image dataset. This function
    is useful in combination with get_gdal_raster_mpl_imshow for displaying
    raster data from an input image as a plot.

    By default this function returns float values in the range 0 - 1. For
    0 - 255 output set out_gain=255 AND max_out_val=255 (otherwise the result
    is still clamped to max_out_val=1); out_int_type=True then casts with
    ``astype(int)`` (the default numpy integer type, not uint8). NaN (no data)
    pixels cannot be represented as integers, so replace them (e.g. with
    numpy.nan_to_num) before relying on integer output.

    :param arr_data: The numpy array as either [n,m,b] or [n,m] where n and m are
                     the number of image pixels in the x and y axis' and b is the
                     number of image bands. The input is not modified.
    :param no_data_val: the no data value for the input data. If there isn't a no
                        data value then leave as None (default). Existing NaN
                        values are always treated as no data.
    :param lower: lower percentile (default: 2)
    :param upper: upper percentile (default: 98)
    :param out_off: Output offset value (value * gain) + offset. Default: 0
    :param out_gain: Output gain value (value * gain) + offset. Default: 1
    :param out_int_type: False (default) and the output type will be float and
                         True and the output type with be integers.
    :param min_out_val: Minimum output value within the output array (default: 0)
    :param max_out_val: Maximum output value within the output array (default: 1)
    :return: A number array with the rescaled values but same dimensions as the
             input numpy array.

    .. code:: python

        img_sub_bbox = [554756, 577168, 9903924, 9944315]
        input_img = "sen2_img_strch.kea"

        img_data_arr, coords_bbox = get_gdal_raster_mpl_imshow(input_img,
                                                               bands=[8,9,3],
                                                               bbox=img_sub_bbox)

        img_data_arr = cumulative_stretch_np_arr(img_data_arr, no_data_val=0.0)

        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()
        im = ax.imshow(img_data_arr, extent=coords_bbox)
        plt.show()

    """
    arr_data_in = numpy.asanyarray(arr_data)
    # Always stretch in floating point: writing stretched values back into an
    # integer array band-by-band would truncate them, and NaN marks no data.
    arr_data_out = arr_data_in.astype(float)
    if no_data_val is not None:
        arr_data_out[arr_data_in == no_data_val] = numpy.nan

    if arr_data_out.ndim < 3:
        # Single band ([n, m]): stretch it as a whole.
        arr_data_out = _cumulative_stretch_band(
            arr_data_out, lower, upper, out_off, out_gain
        )
    else:
        # [n, m, b]: each band is stretched with its own percentiles.
        for n in range(arr_data_out.shape[2]):
            arr_data_out[..., n] = _cumulative_stretch_band(
                arr_data_out[..., n], lower, upper, out_off, out_gain
            )

    arr_data_out = limit_range_np_arr(
        arr_data_out,
        min_thres=min_out_val,
        min_out_val=min_out_val,
        max_thres=max_out_val,
        max_out_val=max_out_val,
    )

    if out_int_type:
        arr_data_out = arr_data_out.astype(int)

    return arr_data_out

In [None]:
# Start a Dask Gateway cluster and let it scale between 4 and 24 workers.
cluster = GatewayCluster()  # Creates the Dask Scheduler. Might take a minute.
# Alternative way to obtain a client: client = cluster.get_client()
cluster.adapt(minimum=4, maximum=24)
# Link to the Dask dashboard for monitoring the cluster.
print(cluster.dashboard_link)
# NOTE(review): the cluster is not closed anywhere in this notebook view;
# consider calling cluster.close() when finished.

In [None]:
# Connect a distributed client to the gateway cluster created above.
client = dask.distributed.Client(cluster)

# For a local cluster instead, use: client = dask.distributed.Client()

# Configure rasterio defaults for cloud reads; the client is passed so the
# configuration presumably also reaches the Dask workers (see
# odc.stac.configure_rio docs to confirm).
odc.stac.configure_rio(cloud_defaults=True, client=client)
display(client)

In [5]:
catalog = Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")

In [6]:
time_range = "2018-01-01/2018-12-31"
bbox = [100.59, 4.79, 100.61, 4.81]
#bbox = [100.35, 4.35, 100.85, 5.1]

search = catalog.search(collections=["landsat-8-c2-l2"], bbox=bbox, datetime=time_range)
items = search.get_all_items()

In [7]:
len(items)

23

In [8]:
#items = items[:3]
#len(items)

In [9]:
signed_items = [planetary_computer.sign(item) for item in items]

In [10]:
# Surface-reflectance bands SR_B1 .. SR_B7 plus the bit-packed QA_PIXEL band.
bands = [f"SR_B{band_num}" for band_num in range(1, 8)] + ["QA_PIXEL"]

# NOTE(review): neither bbox nor chunks is passed here, so this presumably
# loads the full footprint of every item eagerly (not Dask-backed) — confirm
# this is intended, given the Dask cluster created above.
ls8_scn_xa = odc.stac.stac_load(
    signed_items,
    bands=bands,
    # chunks={"x": 1028, "y": 1028},
)

In [11]:
# Load into memory to make processing faster!
#ls8_scn_xa = ls8_scn_xa.persist()

In [12]:
# Subset to 6 images for testing
#ls8_scn_xa = ls8_scn_xa.isel(time=numpy.s_[:6])

In [13]:
ls8_scn_xa

In [None]:
ls8_scn_qa_xa = ls8_scn_xa.map_blocks(expand_ls_qa_pixel_msks)

In [None]:
#ls8_scn_qa_xa = expand_ls_qa_pixel_msks(ls8_scn_xa)

In [None]:
ls8_scn_qa_xa.coords["spatial_ref"] = ls8_scn_xa.coords["spatial_ref"]

In [None]:
ls8_scn_qa_xa

# Apply Cloud Masks

In [None]:
def apply_cloud_msk(scns_xa, bands=None):
    """
    Set reflectance bands to 0 wherever the ALL_CLOUDS or FILL mask is 1.

    Returns masked copies of the bands and leaves the input untouched. This
    matters because Dataset.copy() is shallow by default: writing through
    ``.values`` of such a copy would also change the caller's arrays, and
    Dask assumes block functions (map_blocks) do not mutate their inputs.

    :param scns_xa: dataset containing the bands plus 0/1 ALL_CLOUDS and FILL
                    mask variables (e.g. from expand_ls_qa_pixel_msks).
    :param bands: names of the bands to mask; defaults to SR_B1 ... SR_B7.
    :return: a copy of scns_xa with masked pixels set to 0 in each band.
             Band dtypes are preserved.
    """
    if bands is None:
        bands = ["SR_B1", "SR_B2", "SR_B3", "SR_B4", "SR_B5", "SR_B6", "SR_B7"]
    # Pixels flagged as any cloud type/shadow, or as fill, are removed.
    msk_arr = (scns_xa["ALL_CLOUDS"].values == 1) | (scns_xa["FILL"].values == 1)
    scns_lcl_xa = scns_xa.copy()
    for band in bands:
        band_da = scns_xa[band]
        band_vals = band_da.values
        # Cast back to the band's dtype so e.g. uint16 bands stay uint16.
        masked_vals = numpy.where(msk_arr, 0, band_vals).astype(band_vals.dtype, copy=False)
        scns_lcl_xa[band] = band_da.copy(data=masked_vals)
    return scns_lcl_xa

In [None]:
ls8_scn_qa_mskd_xa = ls8_scn_qa_xa.map_blocks(apply_cloud_msk)

In [None]:
ls8_scn_qa_mskd_xa.coords["spatial_ref"] = ls8_scn_xa.coords["spatial_ref"]

In [None]:
ls8_scn_qa_mskd_xa

In [None]:
ls8_scn_qa_mskd_xa = ls8_scn_qa_mskd_xa.drop(["QA_PIXEL", "FILL", "DILATED_CLOUDS", "CIRRUS", "CLOUDS", "CLOUD_SHADOWS", "SNOW", "CLEAR", "WATER", "ALL_CLOUDS"])

In [None]:
ls8_scn_qa_mskd_xa

In [None]:
# Debug - view band SR_B1 from an individual scene (uncomment to use).
# The extent below uses img_scn_bbox from the line above; img_bbox is only
# defined in a later cell.
#ls8_scn_sgl_xa = ls8_scn_qa_mskd_xa.isel(time=numpy.s_[1])
#img_scn_bbox = [float(ls8_scn_sgl_xa.x.min()), float(ls8_scn_sgl_xa.x.max()), float(ls8_scn_sgl_xa.y.min()), float(ls8_scn_sgl_xa.y.max())]
#fig, ax = plt.subplots(figsize=(10, 10))
#ax.imshow(numpy.squeeze(ls8_scn_sgl_xa["SR_B1"]), extent=img_scn_bbox)

## Convert masked values (all values ≤ 1, including the zeros written by the cloud mask) to NaN so they are ignored in calculations

In [None]:
# Keep only values > 1; everything else (including the 0s written by the
# cloud/fill mask) becomes NaN so the skipna median below ignores it.
ls8_scn_qa_mskd_xa = ls8_scn_qa_mskd_xa.where(lambda scns_xa: scns_xa > 1)

## Calculate Median Image

In [None]:
ls8_scn_med_xa = ls8_scn_qa_mskd_xa.median(dim="time", skipna=True).compute()

In [None]:
ls8_scn_med_xa

## Create Figure

In [None]:
# Composite of SR_B5, SR_B6 and SR_B4 displayed as R, G, B, shaped (y, x, band).
rgb_bands = ["SR_B5", "SR_B6", "SR_B4"]
band_stack = numpy.stack([ls8_scn_med_xa[band].values for band in rgb_bands], axis=-1)
band_stack_stch = cumulative_stretch_np_arr(band_stack, no_data_val=0.0)

# x/y coordinate values mark pixel centres, so pad by half a pixel to get the
# outer pixel edges imshow's extent expects: [left, right, bottom, top].
x_vals = ls8_scn_med_xa.x.values
y_vals = ls8_scn_med_xa.y.values
half_px_x = abs(float(x_vals[1] - x_vals[0])) / 2 if x_vals.size > 1 else 0.0
half_px_y = abs(float(y_vals[1] - y_vals[0])) / 2 if y_vals.size > 1 else 0.0
img_bbox = [
    float(x_vals.min()) - half_px_x,
    float(x_vals.max()) + half_px_x,
    float(y_vals.min()) - half_px_y,
    float(y_vals.max()) + half_px_y,
]
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(band_stack_stch, extent=img_bbox)
ax.set_title("Median composite (R=SR_B5, G=SR_B6, B=SR_B4)")
ax.set_xlabel("x (dataset CRS units)")
ax.set_ylabel("y (dataset CRS units)")
plt.show()

# Interactive Map

In [None]:
import folium

In [None]:
# Scale the 0-1 stretched composite to 0-255 bytes for the web map overlay;
# NaN (no data) pixels become 0.
band_stack_stch_uint = (numpy.nan_to_num(band_stack_stch) * 255).astype(numpy.uint8)

In [None]:
from shapely.geometry import Polygon
x_point_list = [img_bbox[0], img_bbox[1], img_bbox[1], img_bbox[0]]
y_point_list = [img_bbox[3], img_bbox[3], img_bbox[2], img_bbox[2]]

img_bbox_geom = Polygon(zip(x_point_list, y_point_list))
crs = {'init': ls8_scn_xa.crs}
img_bbox_gdf = geopandas.GeoDataFrame(index=[0], crs=crs, geometry=[img_bbox_geom])
img_bbox_wgs84_gdf = img_bbox_gdf.to_crs("EPSG:4326")

img_bounds = numpy.dstack(img_bbox_wgs84_gdf.geometry[0].boundary.coords.xy).tolist()

# switch x/y as lat/lon
img_bounds_latlon = [[]]
for pt in img_bounds[0]:
    n_pt = [pt[1], pt[0]]
    img_bounds_latlon[0].append(n_pt)

# lat / lon to switch
img_cen_srs = img_bbox_wgs84_gdf.centroid
scn_cen_y = img_cen_srs.geometry[0].x
scn_cen_x = img_cen_srs.geometry[0].y
print("{} {}".format(scn_cen_y, scn_cen_x))

In [None]:
m = folium.Map([scn_cen_x, scn_cen_y], zoom_start=8)
folium.raster_layers.ImageOverlay(
    image=band_stack_stch_uint,
    bounds=img_bounds_latlon,
).add_to(m)

m