# TimeSync Image Chip and Spectral Value Time Series Extraction Script
## This workflow demonstrates how to extract Landsat C2 ARD image and spectral data time series in AWS over reference plots using CHS Pangeo.
### Last Updated: 07-28-2023
#### Authors: Cole Krehbiel and Kelcy Smith (original), D. Wellington (revised)

The TimeSync software tool relies on pre-formatted image and spectral data loaded onto the appropriate location on the TimeSync server. This notebook generates these data files for subsequent download to the TimeSync server.

The data that must be generated for each TimeSync project are as follows:

1) A set of CSV files containing the sensor ID, project ID, plot ID, tile ID, year, day-of-year, datetime, and the plot pixel values from the blue, green, red, NIR, SWIR1, SWIR2, and QA bands. The QA value is represented as either 0 (clear or water) or 1 (fill or cloud). One CSV file is written per observation; these files are combined in subsequent processing.

2) Three image composites, output as three-channel, 8 bits-per-channel PNGs, as follows:
      * A composite of NIR/red/green bands (equivalent to TM/ETM+ bands 4/3/2)
      * A composite of SWIR2/NIR/red bands (equivalent to TM/ETM+ bands 7/4/3)
      * A composite of tasseled cap brightness/greenness/wetness bands
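To illustrate how one pixel becomes an 8-bit color in the tasseled cap composite, the sketch below applies the brightness/greenness/wetness coefficients from `tasseled_cap()` and the linear stretch bounds used for the `tcap` PNG in `process_group()`. It is a standard-library, single-pixel simplification (the notebook works on whole NumPy arrays). The input reflectances are made-up example values on the Collection 1 scale, and `byte_scale_value()` and `tcap_pixel()` are illustrative helpers, not notebook functions.

```python
from typing import Dict, List, Tuple

# Coefficients as used in tasseled_cap(); band order: blue, green, red, nir08, swir16, swir22
TC_COEFFS: Dict[str, List[float]] = {
    'brightness': [0.2043, 0.4158, 0.5524, 0.5741, 0.3124, 0.2303],
    'greenness': [-0.1603, -0.2819, -0.4934, 0.7940, -0.0002, -0.1446],
    'wetness': [0.0315, 0.2021, 0.3102, 0.1594, -0.6806, -0.6109],
}
# Stretch bounds used for the tasseled cap PNG in process_group()
TC_STRETCH: Dict[str, Tuple[int, int]] = {
    'brightness': (604, 5592),
    'greenness': (49, 3147),
    'wetness': (-2245, 843),
}


def byte_scale_value(value: float, min_value: float, max_value: float) -> int:
    """Linearly map [min_value, max_value] to [0, 255], clamping values outside the range."""
    scaled = 255 / (max_value - min_value) * (value - min_value)
    return int(min(max(scaled, 0), 255))


def tcap_pixel(bands: List[float], is_fill: bool = False) -> Tuple[int, int, int]:
    """Return the 8-bit (R, G, B) tasseled cap color for one pixel; fill is black."""
    if is_fill:
        return (0, 0, 0)
    out = []
    for name in ('brightness', 'greenness', 'wetness'):
        value = sum(c * b for c, b in zip(TC_COEFFS[name], bands))
        out.append(byte_scale_value(value, *TC_STRETCH[name]))
    return tuple(out)


# Hypothetical vegetated pixel (blue, green, red, nir08, swir16, swir22)
print(tcap_pixel([300, 600, 400, 3500, 1800, 900]))
```

The band composites (NIR/red/green and SWIR2/NIR/red) follow the same pattern without the coefficient step: each selected band is stretched with its own bounds and stacked as one PNG channel.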

Observations with fill at the plot pixel location are excluded from both CSV and image chip generation. Fill elsewhere in an image chip (away from the center pixel) is allowed.
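Both the exclusion rule and the 0/1 CSV `qa` value come down to bit tests on the `qa_pixel` band, mirroring `check_bit()`, `passes_qa_check()`, and `classify_qa()` in step 2: bit 0 marks fill, and bits 6 and 7 mark clear and water. The standalone sketch below restates that logic; `observation_excluded()` is an illustrative name, and the sample QA integers are chosen only to set one bit each, not read from real scenes. Note that `read_qa_at_plot()` returns 1 (fill) when the plot falls outside a tile's extent.

```python
from typing import Iterable

QA_FILL, QA_CLEAR, QA_WATER = 0, 6, 7  # bit positions, matching the notebook constants


def check_bit(value: int, bit: int) -> bool:
    """True if the given bit is set in value."""
    return bool(value & (1 << bit))


def observation_excluded(qa_at_plot: Iterable[int]) -> bool:
    """Skip a date when the plot pixel is fill in every record (tile) for that date."""
    return all(check_bit(qa, QA_FILL) for qa in qa_at_plot)


def classify_qa(qa: int) -> int:
    """Value of the CSV 'qa' column: 0 = clear or water, 1 = fill or cloud."""
    if check_bit(qa, QA_FILL):
        return 1
    return 0 if (check_bit(qa, QA_CLEAR) or check_bit(qa, QA_WATER)) else 1


# Illustrative QA integers that each set one bit of interest (bit 3 is the cloud bit)
clear, water, cloudy, fill = 1 << 6, 1 << 7, 1 << 3, 1
print([classify_qa(q) for q in (clear, water, cloudy, fill)])  # [0, 0, 1, 1]
print(observation_excluded([fill, fill]), observation_excluded([fill, clear]))  # True False
```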

The three composites (or "image chips") cover an area of 255x255 pixels centered on the plot location. Plot locations must be provided in a CSV file, with coordinates specified in the Landsat ARD coordinate grid. In all composites, fill values are set to black (0/0/0) in the output PNG.
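The chip extent determines which ARD tiles must be searched: a plot near a tile edge yields a chip spanning two (or up to four) tiles, and records from all of them are combined for each date. The sketch below mirrors `centered_bounds()` and `determine_hvs()` with only the standard library and the CONUS tile affine from `tile_grid_affine()`. The function names are illustrative, and the coordinates are arbitrary example points in the CONUS ARD grid, not real plots.

```python
import itertools
import math
from typing import List, Tuple

# GDAL-style affine for the CONUS ARD tile grid, copied from tile_grid_affine():
# (x_origin, tile_width, 0, y_origin, 0, -tile_height), in meters
CU_TILE_AFFINE = (-2565585, 150000, 0, 3314805, 0, -150000)
PIXEL_SIZE = 30  # meters


def chip_bounds(x: float, y: float, width: int = 255, height: int = 255) -> Tuple[float, float, float, float]:
    """(min_x, max_x, min_y, max_y) of a chip centered on x/y, as in centered_bounds()."""
    half_w, half_h = (width // 2) * PIXEL_SIZE, (height // 2) * PIXEL_SIZE
    return x - half_w, x + half_w, y - half_h, y + half_h


def tile_hv(x: float, y: float, affine=CU_TILE_AFFINE) -> Tuple[int, int]:
    """ARD tile (h, v) containing x/y on a north-up grid."""
    h = math.floor((x - affine[0]) / affine[1])
    v = math.floor((y - affine[3]) / affine[5])
    return h, v


def tiles_for_chip(x: float, y: float) -> List[Tuple[int, int]]:
    """All tiles a chip touches; each one needs its own STAC query."""
    min_x, max_x, min_y, max_y = chip_bounds(x, y)
    min_h, min_v = tile_hv(min_x, max_y)  # upper-left corner
    max_h, max_v = tile_hv(max_x, min_y)  # lower-right corner
    return list(itertools.product(range(min_h, max_h + 1), range(min_v, max_v + 1)))


print(tiles_for_chip(-1000000, 2000000))  # [(10, 8)]: chip inside a single tile
print(tiles_for_chip(-914585, 2000000))   # [(10, 8), (11, 8)]: chip straddles a tile edge
```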

## Instructions for Use

### 1. Update the parameters dictionary

The following parameter dictionary MUST be updated before running. The parameters are as follows:
- **project_dir**: The S3 bucket location and path-like prefix for objects created in this project. Use a different value for every project.
- **plot_file**: A CSV file with the headers project_id, plot_id, x, y. The x and y values are read as integers, and plot_id must also be an integer because it sets task priority on the cluster. This is assumed to be a local file. Use the file browser in the JupyterLab environment to upload the file, and specify its path with this parameter.
- **region**: Only 'CU' (for CONUS) is accepted by default, but other regions could be enabled by adding them to the `tile_grid_affine()` function.
- **chip_size**: For default TimeSync, leave this as [255, 255] (width, height in pixels).
- **process_on**: Options are 'kube_cluster' or 'local'. The 'local' option is for debugging; you generally want 'kube_cluster'. Subsequent instructions assume you are operating on the cluster; otherwise, skip step 3 and do not pass a client object at execution.
- **store_on**: Options are 'aws_s3' or 'local'. The 'local' option is for debugging; you generally want 'aws_s3'.
- **profile**: The name of an AWS profile with write access to the project_dir bucket, if writing to S3 (otherwise this parameter is optional and ignored). Profiles with access credentials are listed in ~/.aws/credentials. If you have never written to an AWS S3 bucket, you may need to request access from CHS and follow the AWS configuration instructions.
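A mistyped option only surfaces once processing starts, so it can help to check the dictionary first. The `validate_params()` helper below is hypothetical (it is not part of the notebook) and encodes only the options listed above; the bucket name in the usage line is a placeholder.

```python
from typing import Any, Dict

# Allowed values, as described in the parameter list
ALLOWED = {
    'region': {'CU'},
    'process_on': {'kube_cluster', 'local'},
    'store_on': {'aws_s3', 'local'},
}
REQUIRED = ('project_dir', 'plot_file', 'region', 'chip_size', 'process_on', 'store_on')


def validate_params(params: Dict[str, Any]) -> None:
    """Raise ValueError describing the first problem found (hypothetical helper)."""
    missing = [key for key in REQUIRED if key not in params]
    if missing:
        raise ValueError(f'missing parameters: {missing}')
    for key, allowed in ALLOWED.items():
        if params[key] not in allowed:
            raise ValueError(f'{key}={params[key]!r}; expected one of {sorted(allowed)}')
    if params['store_on'] == 'aws_s3':
        if not str(params['project_dir']).startswith('s3://'):
            raise ValueError("project_dir must be an s3:// URL when store_on='aws_s3'")
        if not params.get('profile'):
            raise ValueError("profile is required when store_on='aws_s3'")
    if len(params['chip_size']) != 2:
        raise ValueError('chip_size must be [width, height]')


# Usage with placeholder values
validate_params({'project_dir': 's3://my-bucket/timesync/', 'plot_file': './plots.csv',
                 'region': 'CU', 'chip_size': [255, 255], 'process_on': 'local',
                 'store_on': 'aws_s3', 'profile': 'default'})
```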

<div class="alert-warning">
Update the next cell.
</div>

In [1]:
params = {
    'project_dir': 's3://dev-nlcd-developer/junk/timesync/', 
    'plot_file': './10lines_PlotList.csv', 
    'region': 'CU', 
    'chip_size': [255, 255],
    'process_on': 'local', 
    'store_on': 'aws_s3', 
    'profile': 'default',
}

### Set AWS environment variables

Some of the AWS authentication setup can be simplified with `%env` or `os.environ`. For example, the cell below enables requester-pays access:

In [2]:
%env AWS_REQUEST_PAYER=requester

env: AWS_REQUEST_PAYER=requester
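Outside IPython the `%env` magic is unavailable; the same setting can be made through `os.environ`, as sketched below. Either way, the variable is set only in the notebook kernel's process. Dask workers run as separate processes and do not inherit it, which is one reason to configure requester pays explicitly inside functions that run on the workers, as `process_group()` does through `rio.Env`.

```python
import os

# Same effect as `%env AWS_REQUEST_PAYER=requester`, but valid in plain Python.
# GDAL reads this variable when opening objects in requester-pays buckets.
os.environ['AWS_REQUEST_PAYER'] = 'requester'
print(os.environ['AWS_REQUEST_PAYER'])
```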


In [3]:
#! aws s3 ls | grep dev

### 2. Import libraries and define functions

Run the following cell, which contains all library imports and locally defined functions for data extraction.

In [4]:
import os
import csv
import time
import random
import itertools
import configparser
from copy import copy
from dataclasses import dataclass
from datetime import datetime as dt
from functools import partial, reduce, wraps
from typing import List, Tuple, Optional, Any, Callable, Iterable

import s3fs
import tqdm
import boto3
import fsspec
import numpy as np
import pandas as pd
import pystac_client
import rasterio as rio
from dask.distributed import as_completed, worker_client, Client
from dask.distributed.client import Future
from fsspec.implementations.local import LocalFileSystem

Affine = Tuple[float, float, float, float, float, float]


# Constants
QA_FILL = 0
QA_CLEAR = 6
QA_WATER = 7
CONCURRENT_STAC_QUERIES = 2  # Prevents workers from consuming too much memory
LANDSAT_ARD_C2_FILL_VALUE = 0


@dataclass(frozen=True)
class Bounds:
    """
    Class to hold spatial coordinate bounds
    """
    min_x: float
    max_x: float
    min_y: float
    max_y: float


class StacRecord:
    """
    Class with methods for parsing STAC records
    """

    @staticmethod
    def sensor(record: dict) -> str:
        """
        Parse the sensor shorthand abbreviation at the beginning of the observation id
        """
        return record['id'].split('_')[0]

    @staticmethod
    def hv(record: dict) -> str:
        """
        Retrieve the tile coordinates as a string patterned like hhvv
        """
        return record['properties']['landsat:grid_horizontal'] + record['properties']['landsat:grid_vertical']

    @staticmethod
    def date(record: dict) -> str:
        """
        Retrieve the datetime string
        """
        return record['properties']['datetime']

    @staticmethod
    def year(record: dict) -> int:
        """
        Retrieve the observation year
        Note: Do not try to use '%Y-%m-%dT%H:%M:%S.%fZ' as the format code, it will not work
            because the microseconds are not always six digits
        """
        date = record['properties']['datetime']
        return dt.strptime(date.split('.')[0], '%Y-%m-%dT%H:%M:%S').year

    @staticmethod
    def doy(record: dict) -> int:
        """
        Retrieve the observation DOY
        Note: Do not try to use '%Y-%m-%dT%H:%M:%S.%fZ' as the format code, it will not work
            because the microseconds are not always six digits
        """
        date = record['properties']['datetime']
        return dt.strptime(date.split('.')[0], '%Y-%m-%dT%H:%M:%S').timetuple().tm_yday

    @staticmethod
    def year_doy(record: dict) -> str:
        """
        Retrieve a formatted year/doy string
        """
        return f'{StacRecord.year(record)}_{StacRecord.doy(record):03}'

    @staticmethod
    def asset_href(record: dict, band: str) -> str:
        """
        Retrieve the s3 location of a STAC asset
        """
        return record['assets'][band]['alternate']['s3']['href']

    @staticmethod
    def crs(record: dict) -> str:
        """
        Retrieve the coordinate reference system from a STAC record as WKT
        """
        return record['properties']['proj:wkt2']
    
    
def retry(retries: int, jitter: Tuple[int, int] = (1, 15)) -> Callable:
    """
    Simple retry decorator, for retrying any function that may throw an exception
    such as when trying to retrieve network resources
    """
    def retry_dec(func: Callable) -> Callable:
        def wrapper(*args, **kwargs):
            count = 1
            while True:
                try:
                    return func(*args, **kwargs)
                except Exception:
                    count += 1
                    if count > retries:
                        raise
                    time.sleep(random.randint(*jitter))
        return wrapper
    return retry_dec


def timesync_band_list() -> List[str]:
    """
    Get the bands relevant to TimeSync data retrieval
    """
    return ['blue', 'green', 'red', 'nir08', 'swir16', 'swir22', 'qa_pixel']


def landsat_optical_band_list() -> List[str]:
    """
    Get the optical wavelength Landsat bands
    """
    return ['blue', 'green', 'red', 'nir08', 'swir16', 'swir22']


def centered_window(x: float, y: float, width: int, height: int, ds: rio.io.DatasetReader) -> rio.windows.Window:
    """
    Create a window centered on the x, y of a pixel of interest
    Width and height are in pixels
    """
    row_offset, col_offset = ds.index(x, y)
    return rio.windows.Window(
        col_offset - (width // 2),
        row_offset - (height // 2),
        width,
        height)


def centered_bounds(x: float, y: float, width: int, height: int, pixel_size: int = 30):
    """
    Create coordinate bounds centered on an x/y coordinate
    """
    return Bounds(
        min_x=x - ((width // 2) * pixel_size),
        max_x=x + ((width // 2) * pixel_size),
        min_y=y - ((height // 2) * pixel_size),
        max_y=y + ((height // 2) * pixel_size))


def single_pixel_window(x: float, y: float, ds: rio.io.DatasetReader) -> rio.windows.Window:
    """
    Get a window representing a single pixel
    """
    row_offset, col_offset = ds.index(x, y)
    return rio.windows.Window(col_offset, row_offset, 1, 1)


def build_query(h: int, v: int, region: str = 'CU', collection: str = 'landsat-c2ard-sr',
                datetime: str = '1984-01/2022-12-31', limit: Optional[int] = None) -> dict:
    """
    Construct a STAC query based on h/v tile coordinates
    """
    return {
        'collections': collection,
        'datetime': datetime,
        'limit': limit,
        'query': {'landsat:grid_horizontal': {'eq': f'{h:02}'},
                  'landsat:grid_vertical': {'eq': f'{v:02}'},
                  'landsat:grid_region': {'eq': region}}}


def tile_grid_affine(region: str) -> Affine:
    """
    Get the ARD tile grid affine based on the regional code
    """
    return {
        'CU': (-2565585, 150000, 0, 3314805, 0, -150000),  # CONUS
    }[region]


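def transform_geo(x: float, y: float, affine: Affine) -> Tuple[int, int]:
    """
    Transform an x/y coordinate to col/row space using a GDAL-style affine
    (x_origin, x_size, x_rotation, y_origin, y_rotation, y_size).
    Assumes a north-up grid, i.e. both rotation terms are zero.
    """
    col = (x - affine[0]) / affine[1]
    row = (y - affine[3]) / affine[5]
    # Floor rather than truncate so coordinates left of/above the origin map to negative indices
    return int(np.floor(col)), int(np.floor(row))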


def determine_hv(x: float, y: float, region: str) -> Tuple[int, int]:
    """
    Determine the ARD tile (in h/v coordinates) containing the x/y coordinate
    """
    h, v = transform_geo(x, y, tile_grid_affine(region))
    return h, v


def determine_hvs(bbox, region: str) -> itertools.product:
    """
    Determine the h/v coordinates of tiles that intersect a bounding box
    """
    min_h, min_v = determine_hv(bbox.min_x, bbox.max_y, region)
    max_h, max_v = determine_hv(bbox.max_x, bbox.min_y, region)
    return itertools.product(range(min_h, max_h + 1), range(min_v, max_v + 1))


def query_stac(query_params: dict) -> List[dict]:
    """
    Query the STAC catalog using the provided query parameters
    """
    stac = pystac_client.Client.open('https://landsatlook.usgs.gov/stac-server')
    # This returns a dictionary with two keys, 'type' and 'features'
    results = stac.search(**query_params).item_collection_as_dict()
    # 'type' only contains the value 'FeatureCollection'; we care about what is in 'features'
    return results['features']


def group_dicts(records: List[dict], key_func: Callable) -> itertools.groupby:
    """
    Group a list of dictionaries based on key value
    """
    records = sorted(records, key=key_func)
    return itertools.groupby(records, key=key_func)


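def convert_sr(data: np.ndarray) -> np.ndarray:
    """
    Re-scale Landsat Collection 2 spectral data values back to the Collection 1 range
    (C2 reflectance = DN * 0.0000275 - 0.2; C1 stores reflectance * 10000)
    """
    # Round to the nearest integer; astype() alone truncates toward zero
    return np.rint((data.astype(float) * 0.0000275 - 0.2) * 10000).astype(np.int16)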


def data_to_collection_1(old_dict: dict, bands: List[str]) -> dict:
    """
    Convert a dictionary of surface reflectance bands to the Landsat Collection 1 numerical range
    """
    new_dict = old_dict.copy()
    for key in bands:
        new_dict[key] = convert_sr(old_dict[key])
    return new_dict


def read_bands(record: dict, bands: List[str], plot: Tuple[Any, ...], width: int, height: int) -> dict:
    """
    Read in an ROI for an observation in the STAC record for all bands
    """
    out = {}
    for band in bands:
        with rio.open(StacRecord.asset_href(record, band)) as ds:
            window = centered_window(plot.x, plot.y, width, height, ds)
            # Get a masked array and fill it to avoid a bug with gdal/rasterio
            out[band] = ds.read(1, window=window, boundless=True, fill_value=0, masked=True).filled()
    return out


def read_qa_at_plot(record: dict, plot: Tuple[Any, ...]) -> int:
    """
    Read a single pixel at the plot location
    """
    with rio.open(StacRecord.asset_href(record, 'qa_pixel')) as ds:
        window = single_pixel_window(plot.x, plot.y, ds)
        out = ds.read(1, window=window, boundless=False)
    if out.size == 0:
        return 1  # Treat values outside the spatial extent as fill
    return out.item()


def add_bands(dict_a: dict, dict_b: dict) -> dict:
    """
    Combine two observation dictionaries by adding the values for each band
    """
    out = {}
    for band in dict_a:
        out[band] = dict_a[band] + dict_b[band]
    return out


def composite(data: dict, bands: List[str], axis: int = 0) -> np.ndarray:
    """
    Create a multi-band ndarray from band names
    """
    return np.stack([data[band] for band in bands], axis=axis)


def tasseled_cap(data: dict) -> np.ndarray:
    """
    Create a composite of tasseled cap values
    """
    band_order = ['blue', 'green', 'red', 'nir08', 'swir16', 'swir22']  # Must match coefficient order below
    arr = np.stack([data[band] for band in band_order], axis=2)
    b = np.tensordot(arr, [0.2043, 0.4158, 0.5524, 0.5741, 0.3124, 0.2303], axes=1)  # Brightness
    g = np.tensordot(arr, [-0.1603, -0.2819, -0.4934, 0.7940, -0.0002, -0.1446], axes=1)  # Greenness
    w = np.tensordot(arr, [0.0315, 0.2021, 0.3102, 0.1594, -0.6806, -0.6109], axes=1)  # Wetness
    return np.stack([b, g, w])


def build_affine(x_off: float, y_off: float, x_size: float = 30, y_size: float = 30, x_shear: float = 0,
                 y_shear: float = 0) -> rio.Affine:
    """
    Build the affine tuple in the rasterio format (different from GDAL)
    """
    return rio.Affine(x_size, x_shear, x_off, y_shear, -y_size, y_off)


def write_to_png(file: str, array: np.ndarray, crs: str, transform: rio.Affine) -> None:
    """
    Write a PNG file as three 8-bit channels
    """
    profile = {
        'driver': 'PNG',
        'count': 3,
        'nodata': None,
        'crs': crs,
        'transform': transform,
        'height': array.shape[1],
        'width': array.shape[2],
        'dtype': np.uint8}

    with rio.open(file, mode='w', **profile) as ds:
        ds.write(array)


def array_mask(array: np.ndarray, value_to_mask = None, axis: int = 0) -> np.ndarray:
    """
    Boolean mask where the array matches the provided value anywhere along an axis
    """
    return (array == value_to_mask).any(axis=axis)


def apply_mask(array: np.ndarray, mask_array: np.ndarray, mask_value: float) -> np.ndarray:
    """
    Apply a Boolean mask 
    """
    arr = array.copy()
    arr[mask_array] = mask_value
    return arr


def byte_scale(array: np.ndarray, min_value: float, max_value: float) -> np.ndarray:
    """
    Scale the data between min_value and max_value to 0-255
    """
    out_array = (255 / (max_value - min_value)) * (array - min_value)
    out_array = np.minimum(out_array, 255)
    out_array = np.maximum(out_array, 0)
    return out_array.astype(np.uint8)


def byte_scale_bands(array: np.ndarray, all_bounds: List[Tuple[int, int]], 
                     mask: Optional[np.ndarray] = None, axis: int = 0) -> np.ndarray:
    """
    Convert a multi-band array to scaled 8-bit
    Masked values are set to 0
    """
    out = []
    for i, (min_value, max_value) in enumerate(all_bounds):
        byte_image = byte_scale(array.take(i, axis=axis), min_value, max_value)
        if mask is not None:  # arr[None] = 0 would otherwise zero the whole band
            byte_image = apply_mask(byte_image, mask, mask_value=0)
        out.append(byte_image)
    return np.stack(out, axis=axis)


def center(array: np.ndarray) -> Tuple[int, ...]:
    """
    Get the indices for the center of an array
    """
    return tuple(x // 2 for x in array.shape)


def center_value(array: np.ndarray) -> int:
    """
    Get the center value of an array
    """
    return array[center(array)]


def spectral_data(data: dict) -> dict:
    """
    Get the data for the center pixel in the chip
    """
    return {band: center_value(array) for band, array in data.items()}


def df_to_csv(df: pd.DataFrame, params: dict, output: dict) -> None:
    """
    Write the dataframe to the csv file
    """
    with params['fs'].open(output['scsv'], 'w') as f:
        df.to_csv(f, index=False)


def classify_qa(qa: int) -> int:
    """
    Return a value indicating the pixel is clear/water (0) or fill/cloud (1)
    """
    if passes_qa_check(qa, enable_cloud_filtering=True):
        return 0
    return 1


def build_df(pixel_data: dict, record: dict, project_id: str, plot_id: str) -> pd.DataFrame:
    """
    Build the output dataframe
    """
    return pd.DataFrame({
        'sensor': StacRecord.sensor(record),
        'project_id': project_id,
        'plot_id': plot_id,
        'hv': StacRecord.hv(record),
        'year': StacRecord.year(record),
        'doy': StacRecord.doy(record),
        'blue': pixel_data['blue'],
        'green': pixel_data['green'],
        'red': pixel_data['red'],
        'nir': pixel_data['nir08'],
        'swir1': pixel_data['swir16'],
        'swir2': pixel_data['swir22'],
        'qa': classify_qa(pixel_data['qa_pixel']),
        'date': StacRecord.date(record)}, index=[0])


def invalid_pixel() -> dict:
    """
    Get band values to represent an invalid pixel
    """
    return {
        'blue': 0,
        'green': 0,
        'red': 0,
        'nir08': 0,
        'swir16': 0,
        'swir22': 0,
        'qa_pixel': 1}


def adjust_for_s3(in_dict, filesystem) -> dict:
    """
    Adjust the output file names to use the /vsis3/ file system handler for image data
    """
    out_dict = copy(in_dict)
    if isinstance(filesystem, s3fs.core.S3FileSystem):
        for key in in_dict:
            if in_dict[key].endswith('.png') or in_dict[key].endswith('.tif'):
                out_dict[key] = in_dict[key].replace('s3:/', '/vsis3')
    return out_dict


def process_group(group: List[dict], plot: Tuple[Any, ...], params: dict) -> None:
    """
    Process a group of STAC records associated with a plot into output PNGs
    """
    with rio.Env(rio.session.AWSSession(boto3.Session(), requester_pays=True), 
                 GDAL_DISABLE_READDIR_ON_OPEN='EMPTY_DIR', 
                 GDAL_HTTP_MAX_RETRY=10,
                 GDAL_HTTP_RETRY_DELAY=3):

        # Set up the output filenames and (as applicable) directories
        output = adjust_for_s3(
            output_files(params['project_dir'], plot.project_id, plot.plot_id, StacRecord.year_doy(group[0])),
            params['fs'])
        for out in output.values():
            make_dirs(params['fs'], out)

        # Determine if the center pixel is fill
        if not any(passes_qa_check(read_qa_at_plot(record, plot)) for record in group):
            # Optionally, write an entry for an invalid pixel. This had been previous functionality,
            # but because TimeSync does not expect this entry, I am disabling it.
            # df_to_csv(build_df(invalid_pixel(), group[0], plot.project_id, plot.plot_id), params, output)
            return

        # Read and combine the data records
        data_stack = [
            read_bands(observation, timesync_band_list(), plot, params['chip_size'][0], params['chip_size'][1]) for
            observation in group]
        combined_data = data_to_collection_1(reduce(add_bands, data_stack), landsat_optical_band_list())

    with rio.Env(**params['rio_env']):

        # Get geospatial attributes
        bounds = centered_bounds(plot.x, plot.y, params['chip_size'][0], params['chip_size'][1])
        aff = build_affine(bounds.min_x, bounds.max_y)
        crs = StacRecord.crs(group[0])
        
        # Mask for fill values
        fill_value = convert_sr(np.array(LANDSAT_ARD_C2_FILL_VALUE))
        mask = array_mask(composite(combined_data, landsat_optical_band_list()), fill_value)

        # Calculate the output chip data
        output_tcap = tasseled_cap(combined_data)
        output_b743 = composite(combined_data, bands=['swir22', 'nir08', 'red'])
        output_b432 = composite(combined_data, bands=['nir08', 'red', 'green'])

        # Stretch the chip data and write to PNG
        write_to_png(output['tcap'], byte_scale_bands(output_tcap, [(604, 5592), (49, 3147), (-2245, 843)], mask), crs, aff)
        write_to_png(output['b743'], byte_scale_bands(output_b743, [(-904, 3696), (151, 4951), (-300, 2500)], mask), crs, aff)
        write_to_png(output['b432'], byte_scale_bands(output_b432, [(151, 4951), (-300, 2500), (50, 1150)], mask), crs, aff)

        # Define the metadata and spectral data for this observation and export to a csv
        df_to_csv(build_df(spectral_data(combined_data), group[0], plot.project_id, plot.plot_id), params, output)


def group_records(records: List[dict]) -> List[List[dict]]:
    """
    Group records by acquisition date (year/DOY), so records from adjacent tiles on the same date are combined
    """
    out = []
    for _, group in group_dicts(records, StacRecord.year_doy):
        out.append(list(group))
    return out


def output_files(project_dir: str, project_id: str, plot_id: str, year_doy: str) -> dict:
    """
    Build the file names for the output files for TimeSync
    """
    return {
        'scsv': os.path.join(project_dir, f'prj_{project_id}/{plot_id}_spectral_files_set_{year_doy}.csv'),
        'tcap': os.path.join(project_dir, f'prj_{project_id}/tc/plot_{plot_id}/plot_{plot_id}_{year_doy}.png'),
        'b743': os.path.join(project_dir, f'prj_{project_id}/b743/plot_{plot_id}/plot_{plot_id}_{year_doy}.png'),
        'b432': os.path.join(project_dir, f'prj_{project_id}/b432/plot_{plot_id}/plot_{plot_id}_{year_doy}.png')}


def report_status(func: Callable) -> Callable:
    @wraps(func)
    def wrapper(plot: Tuple[Any, ...], *args, **kwargs) -> Tuple[Tuple[Any, ...], str]:
        """
        Return the plot and any exception raised, or report complete
        """
        try:
            func(plot, *args, **kwargs)
            return plot, 'complete'
        except Exception as error:
            return plot, str(error)

    return wrapper


@report_status
def process_plot(plot: Tuple[Any, ...], params: dict) -> None:
    """
    Process an individual plot
    """
    groups = group_records(stac_records_for_plot(plot, params))
    for group in groups:
        process_group(group, plot, params)


@report_status
def process_plot_dask(plot: Tuple[Any, ...], params: dict) -> None:
    """
    Process an individual plot
    """
    groups = group_records(stac_records_for_plot(plot, params))
    func = partial(process_group, plot=plot, params=params)
    with worker_client() as client:
        futures = client.map(func, groups)
        try:
            client.gather(futures)
        except Exception:
            client.cancel(futures)
            raise


def check_bit(value: int, bit: int) -> bool:
    """
    Check whether a bit is set
    """
    return bool((value & (1 << bit)))


def passes_qa_check(qa: int, enable_cloud_filtering=False) -> bool:
    """
    Make sure the QA value is not indicating fill and (optionally) ensure clear or water bits are set
    """
    if check_bit(qa, QA_FILL):
        return False
    if enable_cloud_filtering and not (check_bit(qa, QA_CLEAR) or check_bit(qa, QA_WATER)):
        return False
    return True
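`convert_sr()` explains why the fill mask in `process_group()` compares against a converted value rather than 0: Collection 2 surface reflectance is stored as a DN with reflectance = DN * 0.0000275 - 0.2, while Collection 1 stored reflectance * 10000, so the ARD fill DN of 0 maps to -2000. The standard-library sketch below works through that arithmetic; `c2_to_c1()` is an illustrative name. It rounds to the nearest integer, whereas a plain integer cast truncates toward zero and can land one unit lower.

```python
def c2_to_c1(dn: int) -> int:
    """Collection 2 SR digital number -> Collection 1 scaled reflectance (x 10000)."""
    reflectance = dn * 0.0000275 - 0.2  # Collection 2 scale factor and offset
    return round(reflectance * 10000)   # Collection 1 stores reflectance * 10000


# The ARD fill DN (0) no longer maps to 0 after conversion
print(c2_to_c1(0))      # -2000
# A DN of 18182 corresponds to a reflectance of about 0.3
print(c2_to_c1(18182))  # 3000
```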

In [5]:
def make_dirs(fs, file: str) -> None:
    """
    Create parent directories if it makes sense to do so
    """
    if isinstance(fs, LocalFileSystem):
        fs.makedirs(os.path.dirname(file), exist_ok=True)


def stac_records_for_plot(plot: Tuple[Any, ...], params: dict) -> List[dict]:
    """
    Retrieve the stac records relevant for the plot
    """
    query_results = []
    roi = centered_bounds(plot.x, plot.y, params['chip_size'][0], params['chip_size'][1])
    for h, v in determine_hvs(roi, region=params['region']):
        query_results.extend(query_stac(build_query(h, v, region=params['region'])))
    return query_results


def format_plot_data(plot_file: str) -> pd.DataFrame:
    """
    Read in the csv file containing geospatial plot data
    """
    return pd.read_csv(
        plot_file,
        usecols=['project_id', 'plot_id', 'x', 'y'],
        dtype={'project_id': str, 'plot_id': str, 'x': int, 'y': int})


def format_log_data(log_file: str) -> pd.DataFrame:
    """
    Read in the csv file containing a record of previous run(s)
    """
    return pd.read_csv(
        log_file,
        usecols=['project_id', 'plot_id', 'time', 'status'],
        dtype={'project_id': str, 'plot_id': str, 'time': str, 'status': str})


def aws_credentials(profile: str) -> Tuple[str, str]:
    """
    Fetch the access key and secret for an AWS profile from ~/.aws/credentials
    Note: currently unused; aws_setup() relies on the default credential chain instead
    """
    parser = configparser.ConfigParser()
    parser.read(os.path.join(os.path.expanduser('~'), '.aws', 'credentials'))
    return parser[profile]['aws_access_key_id'], parser[profile]['aws_secret_access_key']


def local_setup(*args) -> dict:
    """
    Extra setup for writing output to a local file system
    """
    return {
        'fs': fsspec.filesystem('file'),
        'rio_env': {
            'session': None, 
            'GDAL_PAM_ENABLED': 'NO',  # Set to 'YES' to write XML metadata
        }}


def aws_setup(profile: str, *args) -> dict:
    """
    Extra setup for writing to an S3 bucket
    """
    # key, secret = aws_credentials(profile)
    return {
        # 'fs': fsspec.filesystem('s3', key=key, secret=secret),
        'fs': fsspec.filesystem('s3', anon=False, requester_pays=True),
        'rio_env': {
            'session': rio.session.AWSSession(),
            'GDAL_DISABLE_READDIR_ON_OPEN': 'EMPTY_DIR',
            'GDAL_PAM_ENABLED': 'NO',  # Set to 'YES' to write XML metadata
        }}

In [6]:
def append_to_csv(entry: list, csv_file: str) -> None:
    """
    Append a line to a csv file
    """
    with open(csv_file, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(entry)


def log_plot_status(plot: Tuple[Any, ...], status: str, log_file: str) -> None:
    """
    Write plot status to a log file
    """
    if not os.path.exists(log_file) or not (os.path.getsize(log_file) > 0):
        append_to_csv(['project_id', 'plot_id', 'time', 'status'], log_file)
    append_to_csv([plot.project_id, plot.plot_id, dt.now(), status], log_file)


def data_preparation(plot_file: str, log_file: str) -> Tuple[pd.DataFrame, int, int]:
    """
    Read in the plot geolocation information and prior processing history
    """
    # Read in the plot data
    plots_df = format_plot_data(plot_file)

    if os.path.exists(log_file) and (os.path.getsize(log_file) > 0):
        log_df = format_log_data(log_file)

        # Get the most recent status from any previous processing run
        df = plots_df.merge(
            log_df.drop_duplicates(subset=['project_id', 'plot_id'], keep='last'),
            how='left', on=['project_id', 'plot_id'])

    else:
        df = plots_df.copy().reindex(columns=plots_df.columns.tolist() + ['status'])

    n_total, n_completed = len(df), len(df[df.status == 'complete'])
    plots_to_process = df.loc[df.status != 'complete', plots_df.columns]

    return plots_to_process, n_completed, n_total


def submit_next_job(client: Client, func: Callable, plots: Iterable) -> Optional[Future]:
    """
    Run the processing of the next plot as a job on the cluster
    Set the priority of the job to favor completing plots in order
    """
    try:
        plot = next(plots)
        future = client.submit(func, plot, priority=-int(plot.plot_id))
    except StopIteration:
        future = None
    return future


def log_file_name(params: dict) -> str:
    """
    Define the output log file
    """
    return os.path.splitext(params['plot_file'])[0] + '.log'


def process_on_kube(params: dict) -> None:
    """
    Process a group of plots on the CHS Pangeo Kubernetes cluster
    """
    # Get input data
    plots_df, n_completed, n_total = data_preparation(params['plot_file'], log_file_name(params))
    if n_completed == n_total:
        print(f'All {n_total} plots processed successfully! Exiting...')

    else:

        # Get the client and dashboard link
        client = params.pop('client')
        print('Dashboard: https://pangeo.chs.usgs.gov' + client.dashboard_link)

        # Define the processing function
        processing_func = partial(process_plot_dask, params=params)

        # Prepare to iterate over plots and hold resulting futures
        plots, futures = plots_df.itertuples(), []

        # Submit the first n tasks; others will be submitted as each task completes
        for _ in range(CONCURRENT_STAC_QUERIES):
            future = submit_next_job(client, processing_func, plots)
            if future is not None:
                futures.append(future)

        # Prepare to process results as they complete
        watch_for_completion = as_completed(futures)

        # Track plot completion and handle subsequent task submissions
        with tqdm.tqdm(desc='Processing plots', initial=n_completed, total=n_total) as pbar:
            for completed in watch_for_completion:

                # Log plot completion (or error message)
                plot, status = completed.result()
                log_plot_status(plot, status, log_file_name(params))
                pbar.update()

                # Submit the next plot to the cluster
                future = submit_next_job(client, processing_func, plots)
                if future is not None:
                    watch_for_completion.add(future)


def process_on_local(params: dict) -> None:
    """
    Local single-threaded processing
    """
    # Get input data
    plots_df, n_completed, n_total = data_preparation(params['plot_file'], log_file_name(params))
    if n_completed == n_total:
        print(f'All {n_total} plots processed successfully! Exiting...')
        return

    # Define the processing function
    processing_func = partial(process_plot, params=params)

    # Iterate over the plots
    for plot in tqdm.tqdm(plots_df.itertuples(), desc='Processing plots', initial=n_completed, total=n_total):
        plot, status = processing_func(plot)
        log_plot_status(plot, status, log_file_name(params))

In [7]:
def timesync_data_extraction(project_dir: str, plot_file: str, region: str, chip_size: List[int], process_on: str,
                             store_on: str, profile: Optional[str] = None, client: Optional[Client] = None) -> None:
    """
    Run TimeSync data extraction
    """
    params = locals()

    storage = {
        'local': local_setup,
        'aws_s3': aws_setup,
    }[store_on]

    process = {
        'kube_cluster': process_on_kube,
        'local': process_on_local,
    }[process_on]

    params.update(storage(profile))
    process(params)

In [None]:
timesync_data_extraction(**params)  # Local run (process_on='local'): no Dask client is needed
# timesync_data_extraction(**params, client=client)  # Cluster run; see steps 3 and 4


In [None]:
#! head -11 TxL2Test_PlotList.csv >10lines_PlotList.csv
! cat 10lines_PlotList.csv

### 3. Start the cluster and get the client object

See the three red rectangles on the left JupyterLab sidebar? Click that button, then click +NEW at the bottom of the panel. An entry such as "KubeCluster 1" will appear at the bottom of the sidebar. Click "scale" and choose the number of workers to use (maximum 200), with either adaptive or manual scaling; more workers means faster processing. Then click the <> button at the bottom to insert a code cell containing the TCP address of the scheduler. The new cell appears above the currently highlighted cell, so select the cell below this text first if you want it to appear here. Run that cell before proceeding.

<div class="alert-warning">
Insert a cell above using the dask interface
</div>

### 4. Run the TimeSync extraction script (!)

At this point, the cluster and client should exist but have no tasks yet. There may or may not be workers: with adaptive scaling, additional workers may appear only after the scheduler has received enough tasks to keep them busy. Running the cell below starts submitting tasks to the scheduler. There are two types of tasks. Each `process_plot_dask` task handles a single plot and launches many `process_group` tasks, which write the data files for each observation. Only `CONCURRENT_STAC_QUERIES` (2 by default) `process_plot_dask` tasks are submitted at first; another is submitted each time a plot completes.
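The driver loop in `process_on_kube()` keeps only a fixed number of plot tasks in flight and submits the next plot each time one finishes. The sketch below reproduces that throttling pattern with the standard library's `concurrent.futures` instead of Dask, so it can run without a cluster. `run_throttled()`, its `pool_size` parameter, and the toy `work()` function are illustrative stand-ins, not notebook code; the pool is deliberately larger than `max_in_flight` to show that the throttle, not the pool size, limits concurrency.

```python
import itertools
import threading
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from typing import Any, Callable, Iterable, List


def run_throttled(func: Callable[[Any], Any], items: Iterable[Any],
                  max_in_flight: int = 2, pool_size: int = 8) -> List[Any]:
    """Apply func to items, keeping at most max_in_flight tasks submitted at any time."""
    items = iter(items)
    results: List[Any] = []
    with ThreadPoolExecutor(max_workers=pool_size) as pool:
        # Prime the pool, like the first loop in process_on_kube()
        pending = {pool.submit(func, item) for item in itertools.islice(items, max_in_flight)}
        while pending:
            done, pending = wait(pending, return_when=FIRST_COMPLETED)
            for future in done:
                results.append(future.result())  # the notebook logs the plot status here
                # Replace each finished task with the next item, if any remain
                for item in itertools.islice(items, 1):
                    pending.add(pool.submit(func, item))
    return results


# Toy workload that records how many calls run at the same time
_lock = threading.Lock()
_active = 0
peak = 0


def work(n: int) -> int:
    global _active, peak
    with _lock:
        _active += 1
        peak = max(peak, _active)
    time.sleep(0.01)
    with _lock:
        _active -= 1
    return n * n


results = run_throttled(work, range(6), max_in_flight=2)
print(sorted(results), 'peak concurrency:', peak)
```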

In [None]:
timesync_data_extraction(**params, client=client)

### If something goes wrong...

Progress through the plots is tracked by a log file that is written to the same directory as the plot file (it will have the same name as the plot file and end in .log). If the script stops running or you lose the connection to the notebook, simply re-run all cells in the notebook (note: you may have to start a new cluster and retrieve a new client object) and the progress bar should come back to exactly where it left off. Alternatively, if you do want to start over completely, find the .log file in the aforementioned directory and delete/move/rename it. 
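Resuming works because `data_preparation()` keeps only the most recent log entry per plot and skips plots whose latest status is 'complete'. The standard-library sketch below shows the same idea on plain CSV text, which can also be handy for inspecting a log by hand. `plots_remaining()` is a hypothetical helper, and the plot and log contents (including the "Read timed out" error message) are made-up examples.

```python
import csv
import io
from typing import Dict, List, Tuple


def plots_remaining(plot_text: str, log_text: str) -> List[Tuple[str, str]]:
    """Return (project_id, plot_id) pairs whose latest logged status is not 'complete'."""
    latest: Dict[Tuple[str, str], str] = {}
    for row in csv.DictReader(io.StringIO(log_text)):
        # Later rows overwrite earlier ones, so the last status per plot wins
        latest[(row['project_id'], row['plot_id'])] = row['status']
    remaining = []
    for row in csv.DictReader(io.StringIO(plot_text)):
        key = (row['project_id'], row['plot_id'])
        if latest.get(key) != 'complete':
            remaining.append(key)
    return remaining


plots = 'project_id,plot_id,x,y\n1,10,-1000000,2000000\n1,11,-914585,2000000\n1,12,-900000,1900000\n'
log = ('project_id,plot_id,time,status\n'
       '1,10,2023-07-28 10:00:00,complete\n'
       '1,11,2023-07-28 10:05:00,Read timed out\n'
       '1,11,2023-07-28 11:00:00,complete\n'
       '1,12,2023-07-28 11:10:00,Read timed out\n')
print(plots_remaining(plots, log))  # [('1', '12')]
```

For real files, pass the contents read with `open(path).read()`.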


### On completion...

Even if everything seems okay, re-run the cell above until it reports that all plots were processed successfully to ensure that no errors occurred. You should get a message like "All n plots processed successfully! Exiting..."
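As noted at the top, one spectral CSV is written per observation and the files are combined in subsequent processing. That combining step is not part of this notebook; the sketch below shows one way it might be done with the standard library on a local copy of the project directory (for example, after downloading it from S3). The `combine_plot_csvs()` helper is hypothetical, and its glob pattern is an assumption based on the file names built in `output_files()` (`prj_<project_id>/<plot_id>_spectral_files_set_<year>_<doy>.csv`). Write the combined file somewhere that does not match that pattern.

```python
import csv
import glob
import os
from typing import List


def combine_plot_csvs(project_dir: str, project_id: str, plot_id: str, out_file: str) -> int:
    """Concatenate a plot's per-observation CSVs into one file; return the number of data rows."""
    pattern = os.path.join(project_dir, f'prj_{project_id}', f'{plot_id}_spectral_files_set_*.csv')
    # year_doy is zero-padded (e.g. 2001_005), so sorting file names sorts chronologically
    files: List[str] = sorted(glob.glob(pattern))
    n_rows = 0
    with open(out_file, 'w', newline='') as out:
        writer = csv.writer(out)
        header_written = False
        for path in files:
            with open(path, newline='') as f:
                reader = csv.reader(f)
                header = next(reader, None)
                if header is None:
                    continue  # skip empty files
                if not header_written:
                    writer.writerow(header)  # write the header only once
                    header_written = True
                for row in reader:
                    writer.writerow(row)
                    n_rows += 1
    return n_rows
```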