In [1]:
# jupyteronly
%load_ext autoreload
%autoreload 2
%matplotlib inline
# Suppress warnings
import warnings
warnings.filterwarnings('ignore')
# Set reference for util modules
#import sys
#sys.path.append('/home/jovyan/odc-hub/')
# Generic python
import matplotlib.pyplot as plt
import datacube
from matplotlib.cm import RdYlGn, Greens

import glob

#from datacube_utilities.interactive_maps import display_map

In [2]:
# generic python
import xarray as xr 
import odc.algo
from shapely import wkt
from datetime import datetime
import numpy as np

#from datacube_utilities.createAOI import create_lat_lon
#from datacube_utilities.createindices import NDVI, EVI
from odc.algo import to_f32, from_float, xr_geomedian
#from datacube_utilities.dc_mosaic import create_max_ndvi_mosaic, create_min_ndvi_mosaic, create_median_mosaic, create_mosaic, create_hdmedians_multiple_band_mosaic, create_mean_mosaic
from pyproj import Proj, transform
#from datacube_utilities.dc_fractional_coverage_classifier import frac_coverage_classify 
#from datacube_utilities.fromDCALscripts import threshold_plot
#from datacube_utilities.dc_utilities import write_png_from_xr, write_geotiff_from_xr
from datacube.utils.cog import write_cog

import yaml
import rioxarray as rxr

import dask
import dask.array as da
from dask.distributed import Client
# Start a local Dask cluster for this notebook. To use a remote scheduler
# instead, pass its address to Client, e.g.
# Client('dask-scheduler.dask.svc.cluster.local:8786').
client = Client(
    n_workers=2,
    threads_per_worker=4,
    memory_limit='7GB',
)

# Colour map name (not referenced by the cells shown in this notebook).
CMAP = "Blues"

In [3]:
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 2
Total threads: 8,Total memory: 13.04 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:37051,Workers: 2
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 13.04 GiB

0,1
Comm: tcp://127.0.0.1:34583,Total threads: 4
Dashboard: http://127.0.0.1:43989/status,Memory: 6.52 GiB
Nanny: tcp://127.0.0.1:42293,
Local directory: /tmp/dask-scratch-space/worker-hvu79359,Local directory: /tmp/dask-scratch-space/worker-hvu79359

0,1
Comm: tcp://127.0.0.1:35619,Total threads: 4
Dashboard: http://127.0.0.1:43811/status,Memory: 6.52 GiB
Nanny: tcp://127.0.0.1:37457,
Local directory: /tmp/dask-scratch-space/worker-2dhblibp,Local directory: /tmp/dask-scratch-space/worker-2dhblibp


### Parameters for Testing

In [4]:
# Product name and bands to load for the baseline period
baseline_product = 'landsat_8'
baseline_measurement = [
    "green", "red", "blue", "nir", "swir1", "swir2", "pixel_qa",
]

# Product name and bands to load for the analysis period
analysis_product = 'landsat_8'
analysis_measurement = [
    "green", "red", "blue", "nir", "swir1", "swir2", "pixel_qa",
]


### Running with dask chunks 
This process ran in about 1 minute and works for both the baseline and analysis datasets.

In [5]:
# Dask chunk sizes (in pixels along x and y) used when opening each scene
dask_chunks = {'x': 1000, 'y': 1000}

In [6]:
def rename_bands(in_xr, des_bands, position):
    """
    Label a band array with its desired name (from genprepWater.py).

    The array is modified in place: its ``name`` attribute is set to
    ``des_bands[position]`` and the same object is returned.
    """
    band_name = des_bands[position]
    in_xr.name = band_name
    return in_xr

def stack_arrays(array_list, ref_band, time_handling, timestamp=None, chunks=None):
    """Reproject, chunk and align a list of arrays so they can be merged or concatenated.

    Note that this function should be used on bands of a single scene or on
    scenes from the same tile (i.e. where the geographic area is the same).

    Parameters:
    array_list - list of xarray objects (exposing the rioxarray ``.rio`` accessor) to be stacked
    ref_band - object every array is matched to with ``rio.reproject_match``
    time_handling - how the ``time`` coordinate is treated:
        'Bands'      - assign ``timestamp`` as a length-1 ``time`` coordinate to every array
        'Scenes'     - keep each array's own ``time`` coordinate
        'Composites' - only copy the x/y coordinates; ``time`` is not touched
        any other value - no coordinates are reassigned before the final alignment
    timestamp - timestamp used when ``time_handling == 'Bands'`` (required in that mode)
    chunks - mapping passed to ``.chunk()``; defaults to ``{'x': 1000, 'y': 1000}``

    Returns:
    list of arrays, in the input order, aligned to the first array in the list

    Raises:
    ValueError - if ``time_handling == 'Bands'`` and no ``timestamp`` is given
    """
    if time_handling == 'Bands' and timestamp is None:
        raise ValueError("time_handling='Bands' requires a `timestamp`.")

    if chunks is None:
        chunks = {'x': 1000, 'y': 1000}

    # Reproject and resample all arrays to match the reference band
    arrays = [arr.rio.reproject_match(ref_band, fill_value=-9999) for arr in array_list]

    # Chunk the arrays with dask
    arrays = [arr.chunk(chunks) for arr in arrays]

    # Copy the first array's x/y onto every array to avoid coordinate differences
    # due to floats, and manage the time dimension
    if time_handling == 'Bands':
        print('Using provided timestamp for each band.')
        arrays = [arr.assign_coords({
            "x": arrays[0].x,
            "y": arrays[0].y,
            "time": xr.DataArray([timestamp], dims='time')}) for arr in arrays]
    elif time_handling == 'Scenes':
        print('Maintaining existing time dimension for each scene.')
        arrays = [arr.assign_coords({
            "x": arrays[0].x,
            "y": arrays[0].y,
            "time": arr.time}) for arr in arrays]
    elif time_handling == 'Composites':
        print('Not assigning timestamp or time dimension.')
        arrays = [arr.assign_coords({
            "x": arrays[0].x,
            "y": arrays[0].y}) for arr in arrays]
    # Any other value deliberately falls through (existing callers rely on this).

    # Align the arrays to the first array in the list
    arrays = [xr.align(arrays[0], arr, join="override", fill_value=-9999)[1] for arr in arrays]

    return arrays


def stack_bands(bands_data, band_nms, timestamp):
    """
    (Originally from genprepWater.py)

    Use one of the bands as a reference for reprojection and resampling (matching the resolution of
    the reference band), then align all bands before merging into a single xarray dataset.

    This version always uses the first band (bands_data[0]) as the reference band. The original
    author noted that this works for Landsat, but that for Sentinel-2 the first band has a 10m
    resolution which caused capacity issues, so a 20m band (bands_data[6]) was used there instead.

    Parameters:
    bands_data - list of xarray DataArrays, one per band, each with a 'band' dimension and coordinate
    band_nms - list of strings of the band names, in the same order as bands_data
    timestamp - timestamp assigned to every band as its 'time' coordinate

    Returns:
    bands_data - xarray dataset containing all bands, carrying the reference band's attributes

    Raises:
    ValueError - if bands_data and band_nms differ in length
    """
    if len(bands_data) != len(band_nms):
        raise ValueError(
            f'Expected one name per band: got {len(bands_data)} bands and {len(band_nms)} names.'
        )

    # Name the bands so they appear as named data variables in the xarray dataset
    bands_data = [rename_bands(band, band_nms, i) for i, band in enumerate(bands_data)]

    # Create new time dimension
    bands_data = [band.expand_dims(dim='time') for band in bands_data]

    # The first band is the reference grid (taken before the 'band' dimension is removed)
    ref_band = bands_data[0]

    # Remove the 'band' coordinate and dimension (drop_vars replaces the deprecated .drop)
    bands_data = [band.drop_vars('band').squeeze('band') for band in bands_data]

    # Add fill value attribute to -9999
    bands_data = [band.assign_attrs({'_FillValue': -9999}) for band in bands_data]

    print('Stacking the bands for each scene.')

    # Reproject, chunk and align the bands, stamping each with the scene timestamp
    bands_data = stack_arrays(bands_data, ref_band, time_handling='Bands', timestamp=timestamp)

    bands_data = xr.merge(bands_data, fill_value=-9999)

    # Add attributes from the original reference band
    bands_data = bands_data.assign_attrs(ref_band.attrs)

    return bands_data


def stack_scenes(array_list):
    """
    Concatenate all the scenes together into a single xarray dataset.

    The input is a list of xarray objects, each containing a single scene's band data
    for the same tile. The scenes are first aligned to the first scene in the list
    (its coordinates are copied across all the scenes so xarray is able to align
    and concatenate properly without dealing with slight differences in coordinates due
    to floats), then concatenated along 'time'.

    Parameters:
    array_list - non-empty list of per-scene xarray objects

    Returns:
    xarray object with all scenes along 'time' and a '_FillValue' attribute of -9999

    Raises:
    ValueError - if array_list is empty
    """
    if len(array_list) == 0:
        raise ValueError('No scenes to stack: `array_list` is empty.')

    print('Stacking the scenes.')

    # Align every scene to the first one, keeping each scene's own time coordinate
    scenes_aligned = stack_arrays(array_list, array_list[0], time_handling='Scenes')

    # Turn any 0s (nodata) into -9999s
    # NOTE(review): this applies to every variable in the scene, including any QA band
    # - confirm that 0 is never a valid value there.
    scenes_aligned = [scene.where(scene != 0, -9999) for scene in scenes_aligned]

    # Concatenate the aligned scenes into a single xarray dataset
    stacked = xr.concat(scenes_aligned, dim='time', fill_value=-9999)

    return stacked.assign_attrs({'_FillValue': -9999})


def prep_dataset(in_dir, measurement, product):
    """Prepare either the baseline or analysis dataset.

    Every sub-directory of in_dir is treated as one scene. Its
    'datacube-metadata.yaml' supplies the band file paths and the scene
    timestamp ('extent' -> 'center_dt'). The bands of each scene are stacked,
    the scenes are concatenated, and -9999 values are replaced with NaN.

    Parameters:
    in_dir - directory containing one sub-directory per scene
    measurement - list of band names to load (keys of the YAML 'image' -> 'bands' mapping)
    product - product name; currently unused, kept for interface compatibility

    Returns:
    ds - xarray dataset with all scenes and requested bands

    Raises:
    FileNotFoundError - if in_dir contains no scene sub-directories
    """
    # Sort so that the reference scene (the first one) and the time order do not
    # depend on the arbitrary order glob returns.
    scenes = sorted(glob.glob(f'{in_dir}/*/'))
    if not scenes:
        raise FileNotFoundError(f'No scene directories found in {in_dir!r}.')

    array_list = []

    for scene in scenes:

        yml = f'{scene}/datacube-metadata.yaml'
        with open(yml) as stream:
            yml_meta = yaml.safe_load(stream)

        band_info = yml_meta['image']['bands']
        o_bands_data = [
            rxr.open_rasterio(scene + band_info[b]['path'], chunks=dask_chunks)
            for b in measurement
        ]

        # Get the timestamp from the yaml file. An unquoted YAML timestamp is
        # already parsed to a datetime by safe_load; a quoted one is a string.
        center_dt = yml_meta['extent']['center_dt']
        if isinstance(center_dt, datetime):
            timestamp = center_dt
        else:
            timestamp = datetime.strptime(center_dt, '%Y-%m-%d %H:%M:%S')

        # Stack the bands together into a single xarray dataset
        band_data = stack_bands(o_bands_data, measurement, timestamp)

        # Append each stacked scene to a list to be combined later
        array_list.append(band_data)

    # Align and concatenate the data
    ds = stack_scenes(array_list)

    # Replace the -9999 fill value with NaN
    ds = ds.where(ds != -9999)
    print(f'FINAL DATASET {ds}')

    return ds
    

In [7]:
# Build the baseline dataset from the scenes on disk (takes about a minute to run)
baseline_dir = '/home/spatialdays/Documents/testing-wofs/test_masking/Tile7572/BaselineData/'
baseline_ds = prep_dataset(
    in_dir=baseline_dir,
    measurement=baseline_measurement,
    product=baseline_product,
)

Stacking the bands for each scene.
Using provided timestamp for each band.
Stacking the bands for each scene.
Using provided timestamp for each band.
Stacking the scenes.
Maintaining existing time dimension for each scene.
FINAL DATASET <xarray.Dataset>
Dimensions:      (time: 2, y: 7826, x: 8052)
Coordinates:
    spatial_ref  int64 0
  * x            (x) float64 176.2 176.2 176.2 176.2 ... 178.4 178.4 178.4 178.4
  * y            (y) float64 -16.29 -16.29 -16.29 -16.29 ... -18.4 -18.4 -18.4
  * time         (time) datetime64[ns] 2018-01-01 2018-01-01
Data variables:
    green        (time, y, x) float32 dask.array<chunksize=(1, 1000, 1000), meta=np.ndarray>
    red          (time, y, x) float32 dask.array<chunksize=(1, 1000, 1000), meta=np.ndarray>
    blue         (time, y, x) float32 dask.array<chunksize=(1, 1000, 1000), meta=np.ndarray>
    nir          (time, y, x) float32 dask.array<chunksize=(1, 1000, 1000), meta=np.ndarray>
    swir1        (time, y, x) float32 dask.array<chunk

In [8]:
# Build the analysis dataset from the scenes on disk (takes about a minute to run)
analysis_dir = '/home/spatialdays/Documents/testing-wofs/test_masking/Tile7572/AnalysisData/'
analysis_ds = prep_dataset(
    in_dir=analysis_dir,
    measurement=analysis_measurement,
    product=analysis_product,
)

Stacking the bands for each scene.
Using provided timestamp for each band.
Stacking the bands for each scene.
Using provided timestamp for each band.
Stacking the scenes.
Maintaining existing time dimension for each scene.
FINAL DATASET <xarray.Dataset>
Dimensions:      (time: 2, y: 7826, x: 8052)
Coordinates:
    spatial_ref  int64 0
  * x            (x) float64 176.2 176.2 176.2 176.2 ... 178.4 178.4 178.4 178.4
  * y            (y) float64 -16.29 -16.29 -16.29 -16.29 ... -18.4 -18.4 -18.4
  * time         (time) datetime64[ns] 2022-06-21 2022-06-21
Data variables:
    green        (time, y, x) float32 dask.array<chunksize=(1, 1000, 1000), meta=np.ndarray>
    red          (time, y, x) float32 dask.array<chunksize=(1, 1000, 1000), meta=np.ndarray>
    blue         (time, y, x) float32 dask.array<chunksize=(1, 1000, 1000), meta=np.ndarray>
    nir          (time, y, x) float32 dask.array<chunksize=(1, 1000, 1000), meta=np.ndarray>
    swir1        (time, y, x) float32 dask.array<chunk

In [9]:
# Replace any remaining -9999 fill values with NaN in both datasets
baseline_ds, analysis_ds = (
    ds.where(ds != -9999) for ds in (baseline_ds, analysis_ds)
)

## Cloud and Water Masking

In [None]:
# [TODO - insert masking code here]

## Perform Mosaic

In [10]:
# Copied this function from ard-docker-images - can't import datacube_utilities 

def create_median_mosaic(dataset_in, clean_mask=None, no_data=-9999, dtype=None, **kwargs):
    """
    Method for calculating the median pixel value over time for a given dataset.

    Parameters
    ----------
    dataset_in: xarray.Dataset
        A dataset with a `time` dimension containing the variables to be
        mosaicked (e.g. red, green, and blue bands).
    clean_mask: np.ndarray
        A boolean mask of the same shape as the variables in `dataset_in`,
        True where values should be kept. If no clean mask is specified,
        then all values are kept during compositing.
    no_data: int or float
        The no data value. Values equal to it are excluded from the median.
        Pass NaN (or None) when missing data is already NaN.
    dtype: str or numpy.dtype
        Accepted for compatibility with the original utility; currently not
        applied, so the output keeps the dtype produced by the median.
    **kwargs:
        Ignored.

    Returns
    -------
    dataset_out: xarray.Dataset
        Composited data without the `time` dimension, with the same variables
        as `dataset_in`. Attributes are not kept.
    """
    # Mask out everything the caller flagged as unclean.
    if clean_mask is not None:
        dataset_in = dataset_in.where(clean_mask)

    # NaN is the only value that is not equal to itself; in that case the data
    # are already NaN-masked and comparing against no_data would be a no-op.
    no_data_is_nan = no_data != no_data
    if no_data is not None and not no_data_is_nan:
        dataset_in = dataset_in.where(dataset_in != no_data)

    # Median (not mean) over time, skipping masked/NaN values.
    dataset_out = dataset_in.median(dim='time', skipna=True, keep_attrs=False)

    return dataset_out

def create_default_clean_mask(dataset_in):
    """
    Description:
        Builds an all-True boolean mask (i.e. one that masks nothing) shaped
        like the first data variable of the dataset.
    -----
    Inputs:
        dataset_in (xarray.Dataset) - dataset retrieved from the Data Cube.
    Returns:
        np.ndarray of bool with the shape of the first data variable.
    Throws:
        ValueError - if dataset_in has no data variables.
    """
    data_vars = dataset_in.data_vars
    if len(data_vars) == 0:
        raise ValueError('`dataset_in` has no data!')

    first_name = next(iter(data_vars))
    mask_shape = dataset_in[first_name].shape
    return np.ones(mask_shape, dtype=bool)

In [11]:
# Composite each dataset over time; missing data is already NaN, so no mask is passed
baseline_composite, analysis_composite = (
    create_median_mosaic(ds, clean_mask=None, no_data=np.nan)
    for ds in (baseline_ds, analysis_ds)
)

In [None]:
baseline_composite.red.compute()

## Spectral Indices

In [12]:
def NDVI(dataset):
    """Return (nir - red) / (nir + red), kept only where both bands are non-null."""
    nir, red = dataset.nir, dataset.red
    both_valid = nir.notnull() & red.notnull()
    ratio = (nir - red) / (nir + red)
    return ratio.where(both_valid)

In [13]:
# NDVI of each composite
parameter_baseline_composite, parameter_analysis_composite = (
    NDVI(composite) for composite in (baseline_composite, analysis_composite)
)

In [None]:
parameter_baseline_composite.compute()

In [None]:
parameter_analysis_composite

In [14]:
# Align and stack the baseline and analysis composites
array_list = [ parameter_baseline_composite, parameter_analysis_composite ]

# The composites have no time dimension, so use stack_arrays' 'Composites' mode:
# both arrays are reprojected to the baseline grid and share its x/y coordinates,
# and no timestamp is assigned.
array_list_aligned = stack_arrays(array_list, array_list[0], time_handling='Composites', timestamp=None)


In [15]:
# Replace the -9999 fill introduced by reprojection/alignment with NaN
parameter_baseline_composite, parameter_analysis_composite = (
    aligned.where(aligned != -9999)
    for aligned in (array_list_aligned[0], array_list_aligned[1])
)

In [17]:
# Anomaly = analysis composite minus baseline composite; .compute() evaluates it into memory
parameter_anomaly = (parameter_analysis_composite - parameter_baseline_composite).compute()

In [18]:
# write_cog is Dask-aware: given a Dask-backed array it returns a Delayed object and
# writes nothing until that is computed. The two composites are still lazy at this
# point, so load them into memory first so the files are actually written.
write_cog(geo_im=parameter_analysis_composite.compute(),
          fname='parameter_analysis_composite.tif',
          overwrite=True)

write_cog(geo_im=parameter_baseline_composite.compute(),
          fname='parameter_baseline_composite.tif',
          overwrite=True)

# parameter_anomaly was already computed when it was created
write_cog(geo_im=parameter_anomaly,
          fname='parameter_anomaly.tif',
          overwrite=True)

PosixPath('parameter_anomaly.tif')