In [1]:
import datetime as dt
from functools import partial
import logging
import multiprocessing as mp
import sys

import numpy as np
import rasterio as rio
from rasterio.windows import Window
import fsspec

# from lcnext import lsc2ard as c2ard
# from lcnext.lsc2ard import find_observations
# from lcnext.lsc2ard import LandsatARDObservation
from lcnext.static import LANDSAT_SENSORS
from fsspec import AbstractFileSystem

from typing import List

In [2]:
"""
Functionality specifically targeted at Landsat Collection 2 ARD data, stored with
specific pathing.
"""

import datetime as dt
from dataclasses import dataclass
import re

from rasterio.transform import rowcol

from lcnext.static import LS_ARD_C2_CU_TileAff
from lcnext.static import LS_ARD_C2_HI_TileAff
from lcnext.static import LS_ARD_C2_AK_TileAff
from lcnext.static import LS_ARD_C1_SR_MULT
from lcnext.static import LS_ARD_C1_SR_ADD
from lcnext.static import LS_ARD_C1_SR_NODATA
from lcnext.static import LS_ARD_C1_BT_MULT
from lcnext.static import LS_ARD_C1_BT_ADD
from lcnext.static import LS_ARD_C1_BT_NODATA
from lcnext.static import LS_ARD_C1_TOA_MULT
from lcnext.static import LS_ARD_C1_TOA_ADD
from lcnext.static import LS_ARD_C1_TOA_NODATA
from lcnext.static import LS_ARD_C2_SR_MULT
from lcnext.static import LS_ARD_C2_SR_ADD
from lcnext.static import LS_ARD_C2_SR_NODATA
from lcnext.static import LS_ARD_C2_BT_MULT
from lcnext.static import LS_ARD_C2_BT_ADD
from lcnext.static import LS_ARD_C2_BT_NODATA
from lcnext.static import LS_ARD_C2_TOA_MULT
from lcnext.static import LS_ARD_C2_TOA_ADD
from lcnext.static import LS_ARD_C2_TOA_NODATA

# Typing imports
from typing import List
from typing import Tuple

from fsspec import AbstractFileSystem
import numpy as np
from rasterio import Affine


# Regex for a Collection 2 ARD observation id such as
# LC08_CU_003010_20130322_20210501_02, capturing platform, region, the H/V tile
# (three digits each), the acquisition date and the production date (YYYYMMDD).
C2ARD_PATTERN = re.compile(''.join((
    r"(?P<platform>L[CEOT]0[45789])_",
    r"(?P<region>[A-Z]{2})_",
    r"(?P<horiz>[0-9]{3})(?P<vert>[0-9]{3})_",
    r"(?P<acquired>[0-9]{8})_",
    r"(?P<prod_date>[0-9]{8})_02",
)))


@dataclass(frozen=True, slots=True)
class LandsatARDObservation:
    """
    Keep track of the details for a particular Landsat ARD observation stored somewhere.
    """
    root_id: str
    std_path: str
    acquired: dt.datetime.date
    region: str
    platform: str
    sensor: str
    horiz: int
    vert: int
    tileid: str
    prod_date: dt.datetime.date


def id_search(path_or_id: str) -> dict:
    """
    Locate a C2 ARD observation id inside ``path_or_id`` and split it into parts.

    Returns the regex groups (platform, region, horiz, vert, acquired, prod_date)
    plus ``root_id`` (the whole matched id) and ``sensor``.
    Raises ValueError when no id is present or the platform is unknown.
    """
    found = C2ARD_PATTERN.search(path_or_id)
    if found is None:
        raise ValueError(f'No matching C2 ARD string in {path_or_id}')

    parts = found.groupdict()
    parts['root_id'] = found.group(0)
    parts['sensor'] = platform_to_sensor(parts['platform'])
    return parts

def platform_to_sensor(platform: str) -> str:
    """
    Translate a Landsat platform code into the sensor name used in ARD paths:
    LC08/LC09 -> 'oli-tirs', LE07 -> 'etm', LT04/LT05 -> 'tm'.
    Any other code raises ValueError.
    """
    if platform in ('LC09', 'LC08'):
        return 'oli-tirs'
    if platform == 'LE07':
        return 'etm'
    if platform in ('LT05', 'LT04'):
        return 'tm'
    raise ValueError(f'Platform not recognized: {platform}')


def std_obs_path(sensor: str, year: str | int, region: str, horiz: str, vert: str, root_id: str):
    """
    Build the standard C2 ARD root pathing based on the given observation information
    """
    return f'usgs-landsat-ard/collection02/{sensor}/{year}/{region}/{horiz}/{vert}/{root_id}'


def obs_deets(path_or_id: str) -> LandsatARDObservation:
    """
    Parse an observation id (or a path containing one) into a LandsatARDObservation.
    ``std_path`` follows the layout used for Landsat ARD on S3.
    """
    info = id_search(path_or_id)

    print(f'  Running obs_deets:{path_or_id}')

    def as_date(yyyymmdd: str) -> dt.date:
        return dt.datetime.strptime(yyyymmdd, '%Y%m%d').date()

    horiz_str, vert_str = info['horiz'], info['vert']
    path = std_obs_path(info['sensor'], info['acquired'][:4], info['region'],
                        horiz_str, vert_str, info['root_id'])

    return LandsatARDObservation(
        root_id=info['root_id'],
        std_path=path,
        acquired=as_date(info['acquired']),
        region=info['region'],
        platform=info['platform'],
        sensor=info['sensor'],
        horiz=int(horiz_str),
        vert=int(vert_str),
        tileid=horiz_str + vert_str,
        prod_date=as_date(info['prod_date']),
    )

def year_deets(fs: AbstractFileSystem, year: int, sensor: str, region: str, horiz: int, vert: int) -> List[LandsatARDObservation]:
    """
    List one year's tile directory for ``sensor`` and parse each observation in it.

    ``horiz``/``vert`` are zero-padded to three digits to match the bucket layout.
    Entries ending in '.json' are skipped: each directory also holds a JSON file
    that is not an observation and would otherwise fail id parsing.
    """
    print('  Running year_deets')

    tile_dir = f'usgs-landsat-ard/collection02/{sensor}/{year}/{region}/{horiz:03}/{vert:03}'
    observations = []
    for entry in fs.ls(tile_dir):
        if entry.endswith('.json'):
            continue
        observations.append(obs_deets(entry))
    return observations

def find_observations(fs: AbstractFileSystem, start_date: dt.date, end_date: dt.date, sensor: str, region: str, horiz: int, vert: int):
    """
    Collect every ``sensor`` observation on tile (horiz, vert) of ``region`` whose
    acquisition date lies in [start_date, end_date] (inclusive at both ends).

    One year directory is listed for each calendar year the range touches. A year
    whose directory is missing is not handled here; any error raised while
    listing it propagates to the caller.

    Returns a list of LandsatARDObservation.
    """
    print('Running find_observations')
    return [obs
            for year in range(start_date.year, end_date.year + 1)
            for obs in year_deets(fs, year, sensor, region, horiz, vert)
            if start_date <= obs.acquired <= end_date]

def sr_bandnumbers(sensor: str) -> List[int]:
    """
    Surface-reflectance band numbers for ``sensor``:
    oli-tirs -> [2, 3, 4, 5, 6, 7]; tm / etm -> [1, 2, 3, 4, 5, 7].

    Raises ValueError (naming the sensor) for anything else.
    """
    if sensor == 'oli-tirs':
        return [2, 3, 4, 5, 6, 7]
    if sensor in ('tm', 'etm'):
        return [1, 2, 3, 4, 5, 7]
    raise ValueError(f'Unrecognized sensor: {sensor}')

def bt_bandnumbers(sensor: str) -> List[int]:
    """
    Brightness-temperature (thermal) band numbers for ``sensor``:
    oli-tirs -> [10, 11]; tm / etm -> [6].

    Raises ValueError (naming the sensor) for anything else.
    """
    if sensor == 'oli-tirs':
        return [10, 11]
    if sensor in ('tm', 'etm'):
        return [6]
    raise ValueError(f'Unrecognized sensor: {sensor}')


def regional_affine(region: str) -> Affine:
    """
    Return the tile-grid Affine transform defined for an ARD region code
    ('CU', 'AK' or 'HI').

    ``contains_hv`` passes this transform to rasterio's ``rowcol`` to locate tiles.
    Raises ValueError (naming the region) for any other code.
    """
    if region == 'CU':
        return LS_ARD_C2_CU_TileAff
    if region == 'AK':
        return LS_ARD_C2_AK_TileAff
    if region == 'HI':
        return LS_ARD_C2_HI_TileAff
    raise ValueError(f'Unrecognized ARD region: {region}')

def contains_hv(region: str, xs: List[float] | float, ys: List[float] | float) -> Tuple[List[int], List[int]] | Tuple[int, int]:
    """
    Determine the ARD tile (H, V) that each point falls in.

    ``xs``/``ys`` are single values or parallel lists of x/y coordinates (assumed
    to be in the region's ARD projection; TODO confirm). Returns ``(h, v)``:
    single ints for scalar input, lists for list input.
    """
    rows, cols = rowcol(regional_affine(region), xs, ys)

    # rowcol yields rows first; H is the tile column and V the tile row.
    return cols, rows

########################################################
# Helper functions for scaling and descaling L1/2 products such as SR/BT/TOA
########################################################
def find_nodata(arr: np.ndarray, nodata: float = np.nan) -> np.ndarray:
    """
    Boolean mask of the cells in ``arr`` that hold the ``nodata`` value.

    NaN never compares equal to itself, so a NaN ``nodata`` is matched with
    ``np.isnan``; any other value uses exact equality. If precision becomes an
    issue, ``np.isclose`` could be used instead:
    https://numpy.org/doc/stable/reference/generated/numpy.isclose.html
    """
    return np.isnan(arr) if np.isnan(nodata) else arr == nodata

def descale(data: np.ndarray,
            mult: float,
            add: float,
            in_nodata: float = np.nan,
            out_nodata: float = np.nan) -> np.ndarray:
    """
    Convert stored (scaled) values, usually SR/TOA/BT, to physical units:
    f(x) = x * mult + add

    The result is float. Cells of ``data`` equal to ``in_nodata`` become ``out_nodata``.
    """
    as_float = data.astype(float)
    physical = as_float * mult + add
    physical[find_nodata(data, in_nodata)] = out_nodata
    return physical

def scale(data: np.ndarray,
          mult: float,
          add: float,
          in_nodata: float = np.nan,
          out_nodata: float = np.nan) -> np.ndarray:
    """
    Convert physical values (usually SR/TOA/BT) back to the stored, scaled
    representation. This is the inverse of ``descale`` for the same
    ``mult``/``add`` pair:
    f(x) = (x - add) / mult

    The previous ``(x + add) / mult`` only inverted ``descale`` when ``add == 0``.

    Cells of ``data`` equal to ``in_nodata`` become ``out_nodata``.
    """
    out = (data.astype(float) - add) / mult
    out[find_nodata(data, in_nodata)] = out_nodata

    return out

def sr_to_c1(sr_data: np.ndarray) -> np.ndarray:
    """
    Rescale Collection 2 surface reflectance values to Collection 1 scaling.

    The values are descaled to physical reflectance with the C2 factors, then
    scaled with the C1 factors. C2 nodata cells become C1 nodata.

    The result is rounded to the nearest integer before the int16 cast. A bare
    cast truncates toward zero, so float round-off such as 749.9999 would land
    one count low.
    """
    sr = descale(sr_data,
                 mult=LS_ARD_C2_SR_MULT,
                 add=LS_ARD_C2_SR_ADD,
                 in_nodata=LS_ARD_C2_SR_NODATA,
                 out_nodata=LS_ARD_C1_SR_NODATA)

    c1 = scale(sr,
               mult=LS_ARD_C1_SR_MULT,
               add=LS_ARD_C1_SR_ADD,
               in_nodata=LS_ARD_C1_SR_NODATA,
               out_nodata=LS_ARD_C1_SR_NODATA)

    return np.rint(c1).astype(np.int16)

def bt_to_c1(bt_data: np.ndarray) -> np.ndarray:
    """
    Rescale Collection 2 brightness temperature values to Collection 1 scaling.

    The values are descaled to physical temperature with the C2 factors, then
    scaled with the C1 factors. C2 nodata cells become C1 nodata.

    The result is rounded to the nearest integer before the int16 cast, so float
    round-off does not truncate values one count low.
    """
    bt = descale(bt_data,
                 mult=LS_ARD_C2_BT_MULT,
                 add=LS_ARD_C2_BT_ADD,
                 in_nodata=LS_ARD_C2_BT_NODATA,
                 out_nodata=LS_ARD_C1_BT_NODATA)

    c1 = scale(bt,
               mult=LS_ARD_C1_BT_MULT,
               add=LS_ARD_C1_BT_ADD,
               in_nodata=LS_ARD_C1_BT_NODATA,
               out_nodata=LS_ARD_C1_BT_NODATA)

    return np.rint(c1).astype(np.int16)

################################################
# QA related functions
################################################
def qa_bitmask(qa_arr: np.ndarray, bit: int) -> np.ndarray:
    """
    Boolean mask: True wherever bit number ``bit`` (0 = least significant) is set
    in ``qa_arr``.
    """
    flag = 1 << bit
    return (qa_arr & flag) > 0

def qa_fill(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of pixels whose fill QA bit (bit 0) is set.
    """
    return qa_bitmask(qa_arr, bit=0)

def qa_cl_dilated(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of pixels whose dilated-cloud QA bit (bit 1) is set.
    """
    return qa_bitmask(qa_arr, bit=1)

def qa_cirrus(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of pixels whose cirrus QA bit (bit 2) is set.
    """
    return qa_bitmask(qa_arr, bit=2)

def qa_cloud(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of pixels whose cloud QA bit (bit 3) is set.
    """
    return qa_bitmask(qa_arr, bit=3)

def qa_cl_shadow(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of pixels whose cloud-shadow QA bit (bit 4) is set.
    """
    return qa_bitmask(qa_arr, bit=4)

def qa_snow(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of pixels whose snow QA bit (bit 5) is set.
    """
    return qa_bitmask(qa_arr, bit=5)

def qa_clear(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of pixels whose clear QA bit (bit 6) is set.
    """
    return qa_bitmask(qa_arr, bit=6)

def qa_water(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of pixels whose water QA bit (bit 7) is set.
    """
    return qa_bitmask(qa_arr, bit=7)

def qa_cl_lconf(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of low-confidence cloud pixels: bit 8 set and bit 9 clear.
    """
    lo_bit = qa_bitmask(qa_arr, bit=8)
    hi_bit = qa_bitmask(qa_arr, bit=9)
    return lo_bit & ~hi_bit

def qa_cl_mconf(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of medium-confidence cloud pixels: bit 8 clear and bit 9 set.
    """
    lo_bit = qa_bitmask(qa_arr, bit=8)
    hi_bit = qa_bitmask(qa_arr, bit=9)
    return ~lo_bit & hi_bit

def qa_cl_hconf(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of high-confidence cloud pixels: bits 8 and 9 both set.
    """
    lo_bit = qa_bitmask(qa_arr, bit=8)
    hi_bit = qa_bitmask(qa_arr, bit=9)
    return lo_bit & hi_bit

def qa_clsh_lconf(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of low-confidence cloud-shadow pixels: bit 10 set and bit 11 clear.
    """
    lo_bit = qa_bitmask(qa_arr, bit=10)
    hi_bit = qa_bitmask(qa_arr, bit=11)
    return lo_bit & ~hi_bit

def qa_clsh_mconf(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of medium-confidence cloud-shadow pixels: bit 10 clear and bit 11 set.
    """
    lo_bit = qa_bitmask(qa_arr, bit=10)
    hi_bit = qa_bitmask(qa_arr, bit=11)
    return ~lo_bit & hi_bit

def qa_clsh_hconf(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of high-confidence cloud-shadow pixels: bits 10 and 11 both set.
    """
    lo_bit = qa_bitmask(qa_arr, bit=10)
    hi_bit = qa_bitmask(qa_arr, bit=11)
    return lo_bit & hi_bit

def qa_snice_lconf(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of low-confidence snow/ice pixels: bit 12 set and bit 13 clear.
    """
    lo_bit = qa_bitmask(qa_arr, bit=12)
    hi_bit = qa_bitmask(qa_arr, bit=13)
    return lo_bit & ~hi_bit

def qa_snice_mconf(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of medium-confidence snow/ice pixels: bit 12 clear and bit 13 set.
    """
    lo_bit = qa_bitmask(qa_arr, bit=12)
    hi_bit = qa_bitmask(qa_arr, bit=13)
    return ~lo_bit & hi_bit

def qa_snice_hconf(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of high-confidence snow/ice pixels: bits 12 and 13 both set.
    """
    lo_bit = qa_bitmask(qa_arr, bit=12)
    hi_bit = qa_bitmask(qa_arr, bit=13)
    return lo_bit & hi_bit

def qa_cirrus_lconf(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of low-confidence cirrus pixels: bit 14 set and bit 15 clear.
    """
    lo_bit = qa_bitmask(qa_arr, bit=14)
    hi_bit = qa_bitmask(qa_arr, bit=15)
    return lo_bit & ~hi_bit

def qa_cirrus_mconf(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of medium-confidence cirrus pixels: bit 14 clear and bit 15 set.
    """
    lo_bit = qa_bitmask(qa_arr, bit=14)
    hi_bit = qa_bitmask(qa_arr, bit=15)
    return ~lo_bit & hi_bit

def qa_cirrus_hconf(qa_arr: np.ndarray) -> np.ndarray:
    """
    Mask of high-confidence cirrus pixels: bits 14 and 15 both set.
    """
    lo_bit = qa_bitmask(qa_arr, bit=14)
    hi_bit = qa_bitmask(qa_arr, bit=15)
    return lo_bit & hi_bit


In [3]:
# ENDPOINT = 'http://10.165.226.118:7848'
#ENDPOINT = 'http://nlilb.cr.usgs.gov:7848'


# Route INFO/WARNING records to stdout and ERROR+ to stderr. Both routes must be
# set up in a single basicConfig() call: once the root logger has handlers, any
# further basicConfig() call is a no-op.
_LOG_FORMAT = "%(asctime)s %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"

_stdout_handler = logging.StreamHandler(sys.stdout)
_stdout_handler.setLevel(logging.INFO)
# Keep errors off stdout so they are emitted once, on stderr.
_stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)

_stderr_handler = logging.StreamHandler(sys.stderr)
_stderr_handler.setLevel(logging.ERROR)

# force=True replaces handlers left over from a previous run of this cell,
# so re-running the cell does not stack duplicate handlers.
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    datefmt=_LOG_DATEFMT,
    handlers=[_stdout_handler, _stderr_handler],
    force=True,
)

log = logging.getLogger()

def read_raster(path: str, fs: fsspec.AbstractFileSystem, index=1, window=None, boundless=False, fill_value=None):
    """
    Read a single band of the raster at ``path`` on filesystem ``fs``.

    ``index``, ``window``, ``boundless`` and ``fill_value`` are passed straight to
    rasterio's ``DatasetReader.read``.

    Any error while opening or reading is logged together with ``path`` and then
    re-raised. This matters inside pool workers (``comp_worker``), where the path
    would otherwise be hard to recover. The handler is ``except Exception`` rather
    than a bare ``except`` so KeyboardInterrupt/SystemExit pass through untouched.
    """
    try:
        with fs.open(path) as f, rio.open(f) as ds:
            return ds.read(index, window=window, boundless=boundless, fill_value=fill_value)
    except Exception:
        log.error('Failed to read raster: %s', path)
        raise

# NOTE(review): this redefines sr_bandnumbers from the C2 ARD helper cell with
# identical behaviour; the redefinition can be dropped once that cell is the
# single source.
def sr_bandnumbers(sensor: str) -> List[int]:
    """
    Surface-reflectance band numbers for ``sensor``:
    oli-tirs -> [2, 3, 4, 5, 6, 7]; tm / etm -> [1, 2, 3, 4, 5, 7].

    Raises ValueError (naming the sensor) for anything else.
    """
    if sensor == 'oli-tirs':
        return [2, 3, 4, 5, 6, 7]
    if sensor in ('tm', 'etm'):
        return [1, 2, 3, 4, 5, 7]
    raise ValueError(f'Unrecognized sensor: {sensor}')

def qa_std_layers(deets: LandsatARDObservation) -> List[str]:
    """
    File names of the QA layers read for an observation; currently only QA_PIXEL.
    """
    pixel_qa = f'{deets.root_id}_QA_PIXEL.TIF'
    return [pixel_qa]


def sr_std_layers(deets: LandsatARDObservation) -> List[str]:
    """
    File names of the surface-reflectance band layers for an observation, in the
    band order returned by ``sr_bandnumbers`` for its sensor.
    """
    layers = []
    for band in sr_bandnumbers(deets.sensor):
        layers.append(f'{deets.root_id}_SR_B{band}.TIF')
    return layers
    
def dstack_idx(idxs: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Turn a 2-D array of axis-0 indices (e.g. from ``np.argmin(arr3d, axis=0)``) into
    an index tuple that selects those elements back out of the 3-D array:

        arr3d[dstack_idx(np.argmin(arr3d, axis=0))] == arr3d.min(axis=0)

    Returns ``(idxs, row_index, col_index)``; the row/col grids have the same shape
    as ``idxs``. Any (rows, cols) shape works. The previous np.repeat construction
    swapped the repeat counts and only worked for square inputs.
    """
    row_index, col_index = np.indices(idxs.shape)
    return idxs, row_index, col_index


def difference_absolute(arr1: np.ndarray, arr2: np.ndarray) -> np.ndarray:
    """
    Element-wise absolute difference |arr1 - arr2|.
    """
    return np.abs(arr1 - arr2)


def difference(arr1: np.ndarray, arr2: np.ndarray) -> np.ndarray:
    """
    Element-wise ``arr1 - arr2``; standard NumPy broadcasting applies.
    """
    delta = arr1 - arr2
    return delta


def difference_median(arr):
    """
    Absolute deviation of every element from the median taken along axis 0
    (the observation axis when called from distance_overall).
    """
    center = np.ma.median(arr, axis=0)
    return np.abs(arr - center)


def sum_squares(arrs):
    """
    Element-wise sum of squares across the given arrays (masked-array aware).
    """
    squared = [np.ma.power(arr, 2) for arr in arrs]
    return np.ma.sum(squared, axis=0)


def distance_overall(spectral):
    """
    Per-pixel "closest to the median" composite.

    ``spectral`` is indexed spectral[band][observation, row, col]; comp_worker
    passes a masked array of that shape. For each band, the absolute distance of
    every observation from that band's median over observations is computed. The
    observation with the smallest Euclidean distance across bands is chosen per
    pixel, and its value in every band is returned as an array stacked by band.
    """
    log.info('Running distance_overall')
    n_bands = spectral.shape[0]

    deviations = [difference_median(spectral[band]) for band in range(n_bands)]
    euclidean = np.ma.sqrt(sum_squares(deviations))

    best = dstack_idx(np.ma.argmin(euclidean, axis=0))
    return np.array([spectral[band][best] for band in range(n_bands)])
    
def std_mask(qa_arr: np.ndarray) -> np.ndarray:
    """
    Standard compositing mask: True where the QA value flags fill, cirrus, cloud,
    cloud shadow or snow, i.e. pixels to exclude from the composite.
    """
    # The qa_* helpers are defined directly in this notebook. The `c2ard` module
    # alias is never imported (its import is commented out), so calls must not
    # go through it.
    return (qa_fill(qa_arr) |
            qa_cirrus(qa_arr) |
            qa_cloud(qa_arr) |
            qa_cl_shadow(qa_arr) |
            qa_snow(qa_arr))

def comp_worker(window: Window, ids: List[LandsatARDObservation]) -> Tuple[Window, np.ndarray]:
    """
    Build the closest-to-median composite for a single window.

    For every observation in ``ids``, the SR bands and the QA_PIXEL layer are read
    for ``window``. Pixels flagged by ``std_mask`` are masked in every band, and
    ``distance_overall`` picks, per pixel, the observation nearest the median.

    Returns ``(window, composite)`` so results from ``Pool.imap_unordered`` can be
    matched back to their window.

    Raises ValueError if ``ids`` is empty.
    """
    if not ids:
        raise ValueError('comp_worker needs at least one observation')

    log.info(f'Processing {window}')
    # A fresh filesystem handle per call; this function runs inside pool workers.
    fs = fsspec.filesystem("s3", requester_pays=True)

    band_layers = [sr_std_layers(d) for d in ids]
    n_bands = len(band_layers[0])

    big_arr = np.zeros(shape=(n_bands, len(ids), window.height, window.width), dtype=int)
    qas = np.zeros(shape=(len(ids), window.height, window.width), dtype=int)

    for obs_idx, (d, layers) in enumerate(zip(ids, band_layers)):
        log.info(f'Pulling data for {d}')
        for band_idx, layer in enumerate(layers):
            path = '/'.join([d.std_path, layer])
            log.debug(f'Reading {path}')
            big_arr[band_idx, obs_idx] = read_raster(path, fs, window=window, boundless=False)

        qa_path = '/'.join([d.std_path, qa_std_layers(d)[0]])
        log.debug(f'Reading {qa_path}')
        qas[obs_idx] = read_raster(qa_path, fs, window=window, boundless=False)

    # The per-observation QA mask applies identically to every band.
    band_mask = np.repeat(np.expand_dims(std_mask(qas), axis=0), repeats=n_bands, axis=0)
    return window, distance_overall(np.ma.masked_array(big_arr, mask=band_mask))

# def main():
#     log.info('Initializing filesystem')
#     storage_options = {'profile': "ceph",
#                        'client_kwargs': {"endpoint_url": ENDPOINT}}
    
#     fs = fsspec.filesystem("s3", **storage_options)
        
#     start = dt.datetime.strptime('20130501', '%Y%m%d').date()
#     end = dt.datetime.strptime('20130901', '%Y%m%d').date()

#     log.info('Identifying inputs')
#     deets = []
#     for sensor in LANDSAT_SENSORS:
#         print(sensor)
#         deets.extend(find_observations(fs, start, end, sensor, 'CU', 3, 10))
#         print(deets)
        
#     windows = [Window(col_off=x, row_off=y, width=100, height=100)
#                for x in range(0, 5000, 100)
#                for y in range(0, 5000, 100)]
    
#     func = partial(comp_worker,
#                    ids=deets,
#                    storage_options=storage_options)

#     with mp.Pool() as pool:
#         log.info(f'Begin processing')
#         for window, arrs in pool.imap_unordered(func, windows):
#             pass
    
# if __name__ == '__main__':
#     mp.set_start_method('forkserver')
#     #main()


In [4]:
log.info('Initializing filesystem')
# The USGS ARD bucket is requester-pays, hence requester_pays=True.
fs = fsspec.filesystem("s3", requester_pays=True)

# Compositing period (inclusive at both ends).
start = dt.date(2013, 5, 1)
end = dt.date(2013, 9, 1)

log.info('Identifying inputs')
deets = []

# Years (inclusive) for which each sensor has data. OLI-TIRS (Landsat 8/9) is
# still collecting, so its range runs to the current year rather than being
# capped at a fixed year.
sensors = {
    "tm": (1982, 2012),
    "oli-tirs": (2013, dt.date.today().year),
    "etm": (1999, 2022)
}

# Function to determine sensor availability
def check_sensor_availability(sensor_name, year, end_year=None, periods=None):
    """
    Report whether ``sensor_name`` collected data in ``year``, or, when
    ``end_year`` is given, in any year of the inclusive span [year, end_year].

    ``periods`` maps sensor name -> (first_year, last_year) and defaults to the
    notebook-level ``sensors`` table.

    Returns "available" or "unavailable"; unknown sensors are "unavailable".
    """
    if periods is None:
        periods = sensors
    if end_year is None:
        end_year = year
    if sensor_name not in periods:
        return "unavailable"
    first_year, last_year = periods[sensor_name]
    # Two inclusive year ranges overlap when each starts no later than the other ends.
    if first_year <= end_year and year <= last_year:
        return "available"
    return "unavailable"

for sensor in LANDSAT_SENSORS:
    log.info(f'Checking {sensor}')
    if sensor not in sensors:
        continue
    # Clip the requested period to the sensor's operating years. A sensor that
    # starts or stops partway through the period is still searched, and no year
    # directory the sensor never produced gets listed.
    first_year, last_year = sensors[sensor]
    sensor_start = max(start, dt.date(first_year, 1, 1))
    sensor_end = min(end, dt.date(last_year, 12, 31))
    if sensor_start > sensor_end:
        continue  # sensor did not operate during the requested period
    deets.extend(find_observations(fs, sensor_start, sensor_end, sensor, 'CU', 3, 10))

# Debug-sized run: a single 10x10-pixel window at the tile origin.
WINDOW_SIZE = 10
EXTENT = 10
windows = [Window(col_off=x, row_off=y, width=WINDOW_SIZE, height=WINDOW_SIZE)
           for x in range(0, EXTENT, WINDOW_SIZE)
           for y in range(0, EXTENT, WINDOW_SIZE)]

func = partial(comp_worker,
               ids=deets)

# NOTE(review): comp_worker is defined in the notebook's __main__. This likely
# works only with the 'fork' start method (Linux default); 'spawn' workers
# cannot import functions defined interactively.
composites = []
with mp.Pool(4) as pool:
    log.info('Begin processing')
    for window, arr in pool.imap_unordered(func, windows):
        composites.append((window, arr))
# TODO: write `composites` out (e.g. mosaic the windows into a tile-sized raster).

2023-10-25 21:24:13 Initializing filesystem
2023-10-25 21:24:14 Identifying inputs


 oli-tirs
Running find_observations
  Running year_deets
  Running obs_deets:usgs-landsat-ard/collection02/oli-tirs/2013/CU/003/010/LC08_CU_003010_20130322_20210501_02
  Running obs_deets:usgs-landsat-ard/collection02/oli-tirs/2013/CU/003/010/LC08_CU_003010_20130327_20210501_02
  Running obs_deets:usgs-landsat-ard/collection02/oli-tirs/2013/CU/003/010/LC08_CU_003010_20130401_20210501_02
  Running obs_deets:usgs-landsat-ard/collection02/oli-tirs/2013/CU/003/010/LC08_CU_003010_20130411_20210501_02
  Running obs_deets:usgs-landsat-ard/collection02/oli-tirs/2013/CU/003/010/LC08_CU_003010_20130418_20210501_02
  Running obs_deets:usgs-landsat-ard/collection02/oli-tirs/2013/CU/003/010/LC08_CU_003010_20130425_20210501_02
  Running obs_deets:usgs-landsat-ard/collection02/oli-tirs/2013/CU/003/010/LC08_CU_003010_20130427_20210501_02
  Running obs_deets:usgs-landsat-ard/collection02/oli-tirs/2013/CU/003/010/LC08_C

KeyboardInterrupt: 

In [None]:
# Spot-check one SR band read for a single 10x10 window. read_raster expects one
# Window, not a list of them.
test_windows = [Window(col_off=x, row_off=y, width=10, height=10)
                for x in range(0, 10, 10)
                for y in range(0, 10, 10)]
read_raster('usgs-landsat-ard/collection02/oli-tirs/2013/CU/003/010/LC08_CU_003010_20130504_20210501_02/LC08_CU_003010_20130504_20210501_02_SR_B2.TIF',
            fs, window=test_windows[0], boundless=True)

In [None]:
# Named obs_files (not `list`) so the builtin list() is not shadowed in the kernel.
obs_files = fs.ls('usgs-landsat-ard/collection02/oli-tirs/2013/CU/003/010/LC08_CU_003010_20130504_20210501_02/')
obs_files

 https://landsatlook.usgs.gov/data/collection02/level-2/standard/oli-tirs/2013/CU/003/010/catalog.json

In [None]:
# ! aws s3 ls s3://usgs-landsat-ard/collection02/oli-tirs/2013/CU/003/010/ --request-payer requester

In [None]:
# ! aws s3api get-object --bucket usgs-landsat-ard --key collection02/oli-tirs/2013/CU/003/010/catalog.json  --request-payer requester catalog.json