In [1]:
import os
import pathlib

# Application packages
import datetime as dt
import glob
import logging
import json
import os
import sys
import shutil
import ray
import numpy as np
from osgeo import gdal
import hytools as ht
from PIL import Image
import matplotlib.pyplot as plt
import pystac
import subprocess

# stage_in packages
from unity_sds_client.resources.collection import Collection

# stage_out packages
from datetime import datetime, timezone
from unity_sds_client.resources.dataset import Dataset
from unity_sds_client.resources.data_file import DataFile

In [2]:
input_stac_collection_file = '/unity/ads/input_collections/TRAIT_MERGE/catalog.json' # type: stage-in
output_stac_catalog_dir    = '/unity/ads/outputs/SBG-L2B_VEGBIOCHEM'                    # type: stage-out

# When True, a disclaimer is prepended to product descriptions and the output
# files are renamed with an "EXPERIMENTAL-" prefix.
experimental = False
# CRID token written into every output file name.
crid = "000"
# Threshold on band 2 of the fractional-cover raster: pixels with a value of at
# least this are treated as vegetation when applying the trait models.
veg_cover = 0.5
# Directory containing the gdaladdo / gdal_translate executables.
gdal_dir = "/home/jovyan/conda-envs/sister-trait/bin"

# Collection ID used for eventual cataloging of the outputs in the Unity environment
output_collection="urn:nasa:unity:unity:dev:SBG-L2B_VEGBIOCHEM___1"

# Optional variables
# Scratch directory for the intermediate (uncompressed) GeoTIFFs.
temp_work_dir = "/unity/ads/temp/SBG-L2B_VEGBIOCHEM"


# Import Files from STAC Item Collection

Load the data filenames from the stage-in STAC item collection file.

In [3]:
# Open the stage-in collection and list every file that carries the "data" role.
inp_collection = Collection.from_stac(input_stac_collection_file)
data_filenames = inp_collection.data_locations(["data"])

# Empty output collection; the stage-out dataset is appended to it at the end.
out_collection = Collection(output_collection)

data_filenames

if2


['/unity/ads/input_collections/TRAIT_MERGE/./SISTER_EMIT_L2A_CORFL_20230807T182755_001.bin',
 '/unity/ads/input_collections/TRAIT_MERGE/./SISTER_EMIT_L2A_CORFL_20230807T182755_001.hdr',
 '/unity/ads/input_collections/TRAIT_MERGE/./SISTER_EMIT_L2B_FRCOV_20230807T182755_001.tif']

In [4]:
# Pick out the ENVI reflectance pair (.hdr/.bin) and the fractional-cover GeoTIFF.
# Match on the file extension only: a substring test would misclassify names such
# as "x.hdr.bin".
reflectance_hdr_file = reflectance_file = frcov_file = None
for filename in data_filenames:
    if filename.endswith(".hdr"):
        reflectance_hdr_file = filename
    elif filename.endswith(".bin"):
        reflectance_file = filename
    elif filename.endswith(".tif"):
        frcov_file = filename

# Fail here with a clear message rather than with a NameError/None error later on.
missing_inputs = [
    label
    for label, path in (
        ("reflectance header (.hdr)", reflectance_hdr_file),
        ("reflectance data (.bin)", reflectance_file),
        ("fractional cover (.tif)", frcov_file),
    )
    if path is None
]
if missing_inputs:
    raise FileNotFoundError(
        f"Stage-in collection is missing: {', '.join(missing_inputs)}"
    )


In [5]:
# Create the output and scratch directories, including any missing parents.
os.makedirs(output_stac_catalog_dir, exist_ok=True)
os.makedirs(temp_work_dir, exist_ok=True)

# Set up console logging using root logger
logging.basicConfig(format="%(asctime)s %(levelname)s: %(message)s", level=logging.INFO)
logger = logging.getLogger("sister-trait-estimate")

# Set up file handler logging. The named logger persists across cell runs, so only
# attach the file handler once; otherwise every re-run duplicates each log line.
log_path = os.path.abspath(f"{output_stac_catalog_dir}/pge_run.log")
if not any(
    isinstance(h, logging.FileHandler) and h.baseFilename == log_path
    for h in logger.handlers
):
    handler = logging.FileHandler(log_path)
    handler.setLevel(logging.INFO)
    handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)s [%(module)s]: %(message)s")
    )
    logger.addHandler(handler)
logger.info("Starting trait_estimate.py")

# Trait model JSON files are globbed from ./models later in the notebook.
if not os.path.exists("models"):
    logger.error("Can't find models directory!")

2024-03-21 16:57:19,166 INFO: Starting trait_estimate.py


# Helper Functions

In [6]:
def get_description_from_trait(trait, model_jsons):
    """Look up a trait's long name among loaded model definitions.

    Args:
        trait: Upper-case trait abbreviation to match (e.g. "NIT").
        model_jsons: Iterable of model dicts with "short_name" and "full_name" keys.

    Returns:
        The "full_name" of the first model whose upper-cased "short_name"
        equals ``trait``, or None when nothing matches.
    """
    matches = (
        entry["full_name"]
        for entry in model_jsons
        if trait == entry["short_name"].upper()
    )
    return next(matches, None)


def generate_stac_metadata(basename, trait, description, in_meta):
    """Assemble the metadata dict consumed by ``create_item``.

    Args:
        basename: Underscore-separated product name with at least six tokens,
            e.g. SISTER_<sensor>_<level>_<product>_<datetime>_<crid>.
        trait: Optional trait abbreviation appended to the product name.
        description: Free-text description stored in the properties.
        in_meta: Dict with "start_datetime"/"end_datetime" strings in
            "%Y-%m-%dT%H:%M:%SZ" form, plus "geometry" and "sensor".

    Returns:
        Dict with keys id, start_datetime, end_datetime (naive datetimes),
        geometry, collection and properties.
    """
    time_format = "%Y-%m-%dT%H:%M:%SZ"
    start = dt.datetime.strptime(in_meta['start_datetime'], time_format)
    end = dt.datetime.strptime(in_meta['end_datetime'], time_format)
    geometry = in_meta['geometry']

    tokens = basename.split('_')
    # Collection = SISTER_<sensor>_<level>_<product>_<crid>; the datetime token is dropped.
    collection = "SISTER_" + "_".join(tokens[i] for i in (1, 2, 3, 5))
    product = tokens[3] if trait is None else f"{tokens[3]}_{trait}"

    return {
        'id': basename,
        'start_datetime': start,
        'end_datetime': end,
        'geometry': geometry,
        'collection': collection,
        'properties': {
            'sensor': in_meta['sensor'],
            'description': description,
            'product': product,
            'processing_level': tokens[2],
        },
    }


def create_item(metadata, assets):
    """Build a ``pystac.Item`` from a metadata dict and attach its assets.

    ``metadata`` must supply the keys id, start_datetime, end_datetime,
    geometry, collection and properties (the shape returned by
    ``generate_stac_metadata``). ``start_datetime`` doubles as the item's
    nominal ``datetime``, and no bbox is set. ``assets`` maps each asset key
    to its href.
    """
    item_kwargs = dict(
        id=metadata['id'],
        datetime=metadata['start_datetime'],
        start_datetime=metadata['start_datetime'],
        end_datetime=metadata['end_datetime'],
        geometry=metadata['geometry'],
        collection=metadata['collection'],
        bbox=None,
        properties=metadata['properties'],
    )
    stac_item = pystac.Item(**item_kwargs)

    for asset_key, asset_href in assets.items():
        stac_item.add_asset(key=asset_key, asset=pystac.Asset(href=asset_href))
    return stac_item


def apply_trait_model(hy_obj, args):
    '''Apply one trait model to an image and export the result to file.

    Parameters
    ----------
    hy_obj : hytools HyTools object
        Loaded image. It must expose ``wavelengths``, ``lines``, ``columns``,
        ``base_name``, ``file_name`` and ``no_data``, and carry ``'no_data'``
        and ``'veg'`` entries in ``hy_obj.mask``.
    args : sequence
        ``(json_file, crid, disclaimer)``:
        - ``json_file`` is the path to the trait-model JSON.
        - ``crid`` is the CRID used in the output file name.
        - ``disclaimer`` is a prefix for the GeoTIFF DESCRIPTION metadata item.

    Writes a 3-band Float32 GeoTIFF (mean, standard deviation, QA mask) to
    ``temp_work_dir``, builds overviews on it with ``gdaladdo``, then uses
    ``gdal_translate`` to write an LZW-compressed, tiled copy to
    ``output_stac_catalog_dir``. If the model's wavelengths fall outside the
    image's wavelength range, it returns early without writing anything.

    Raises
    ------
    subprocess.CalledProcessError
        If either GDAL command-line tool exits with a non-zero status.
    '''

    logger.info("Applying Trait Model")
    print("Applying Trait Model")
    json_file, crid, disclaimer = args

    with open(json_file, 'r') as json_obj:
        trait_model = json.load(json_obj)
    coeffs = np.array(trait_model['model']['coefficients']).T
    intercept = np.array(trait_model['model']['intercepts'])
    model_waves = np.array(trait_model['wavelengths'])

    # Skip models that need wavelengths the image does not cover.
    if (hy_obj.wavelengths.min() > model_waves.min()) or (hy_obj.wavelengths.max() < model_waves.max()):
        print('%s model wavelengths outside of image wavelength range, skipping....' % trait_model["full_name"])
        return

    hy_obj.create_bad_bands([[300,400],[1337,1430],[1800,1960],[2450,2600]])
    hy_obj.resampler['type'] = 'cubic'

    # Resample only if some model wavelength is not exactly one of the image bands.
    resample = not all(x in hy_obj.wavelengths for x in model_waves)
    if resample:
        print('Spectral resampling required')
        hy_obj.resampler['out_waves'] = model_waves
    else:
        # Image band index for each model wavelength.
        wave_mask = [np.argwhere(x == hy_obj.wavelengths)[0][0] for x in model_waves]

    iterator = hy_obj.iterate(by='line', resample=resample)

    # Band 0: mean prediction, band 1: sample std dev of the predictions,
    # band 2: QA flag (1 where the mean lies strictly inside the model's diagnostic range).
    trait_array = np.zeros((3, hy_obj.lines, hy_obj.columns))

    while not iterator.complete:
        chunk = iterator.read_next()
        if not resample:
            chunk = chunk[:, wave_mask]

        # Apply spectrum transforms in the order listed by the model.
        for transform in trait_model['model']["transform"]:
            if transform == "vector":
                norm = np.linalg.norm(chunk, axis=1)
                chunk = chunk / norm[:, np.newaxis]
            if transform == "absorb":
                chunk = np.log(1 / chunk)
            if transform == "mean":
                mean = chunk.mean(axis=1)
                chunk = chunk / mean[:, np.newaxis]

        # One prediction per coefficient set (columns of `coeffs`); presumably
        # an ensemble of PLSR models — TODO confirm against the model JSON.
        trait_pred = np.dot(chunk, coeffs) + intercept
        trait_mean = trait_pred.mean(axis=1)
        qa = (trait_mean > trait_model['model_diagnostics']['min']) & (trait_mean < trait_model['model_diagnostics']['max'])

        line = iterator.current_line
        trait_array[0, line, :] = trait_mean
        trait_array[1, line, :] = trait_pred.std(ddof=1, axis=1)
        trait_array[2, line, :] = qa.astype(int)

        # Keep only pixels that are valid data AND vegetated; flag the rest.
        # NOTE(review): -9999 is hard-coded, while the bands' nodata value below is
        # hy_obj.no_data; these must agree — confirm.
        nd_mask = hy_obj.mask['no_data'][line] & hy_obj.mask['veg'][line]
        trait_array[:, line, ~nd_mask] = -9999

    trait_abbrv = trait_model["short_name"].upper()
    # base_name is SISTER_<sensor>_<level>_<product>_<datetime>_<crid>.
    _, sensor, _, _, datetime_var, _ = hy_obj.base_name.split('_')

    out_name = f'SISTER_{sensor}_L2B_VEGBIOCHEM_{datetime_var}_{crid}_{trait_abbrv}.tif'
    temp_file = f'{temp_work_dir}/{out_name}'
    out_file = f'{output_stac_catalog_dir}/{out_name}'

    logger.info(temp_file)
    logger.info(out_file)

    short_name = trait_model["short_name"].lower()
    full_name = trait_model["full_name"].upper()
    full_units = trait_model["full_units"].upper()

    band_names = [f"{short_name}_mean",
                  f"{short_name}_std_dev",
                  f"{short_name}_qa_mask"]

    units = [full_units,
             full_units,
             "NA"]

    descriptions = [f"{full_name} MEAN",
                    f"{full_name} STANDARD DEVIATION",
                    "QUALITY ASSURANCE MASK"]

    # Source image is only needed for its georeferencing.
    in_file = gdal.Open(hy_obj.file_name)

    # Set the output raster transform and projection properties
    driver = gdal.GetDriverByName("GTIFF")
    tiff = driver.Create(temp_file,
                         hy_obj.columns,
                         hy_obj.lines,
                         3,
                         gdal.GDT_Float32)

    tiff.SetGeoTransform(in_file.GetGeoTransform())
    tiff.SetProjection(in_file.GetProjection())
    # Built without %-formatting so a '%' in the disclaimer cannot break it.
    tiff.SetMetadataItem("DESCRIPTION", f"{disclaimer}L2B VEGETATION BIOCHEMISTRY {full_name}")
    in_file = None  # release the source dataset

    # Write bands to file
    for i, band_name in enumerate(band_names, start=1):
        band = tiff.GetRasterBand(i)
        band.WriteArray(trait_array[i-1])
        band.SetDescription(band_name)
        band.SetNoDataValue(hy_obj.no_data)
        band.SetMetadataItem("UNITS", units[i-1])
        band.SetMetadataItem("DESCRIPTION", descriptions[i-1])
    # Dereferencing the GDAL dataset closes it, flushing it to disk before the CLI tools read it.
    del tiff, driver

    print("running system gdal commands")

    # check=True: a failing GDAL tool raises instead of silently leaving no output file.
    subprocess.run([f'{gdal_dir}/gdaladdo', "-minsize", "900", temp_file], check=True)
    subprocess.run([f'{gdal_dir}/gdal_translate', temp_file, out_file,
                    "-co", "COMPRESS=LZW", "-co", "TILED=YES", "-co", "COPY_SRC_OVERVIEWS=YES"],
                   check=True)



# Process Data

In [7]:
# Experimental runs prefix product descriptions with a disclaimer.
if experimental:
    logger.info("Turning on experimental flags")
    disclaimer = "(DISCLAIMER: THIS DATA IS EXPERIMENTAL AND NOT INTENDED FOR SCIENTIFIC USE) "
else:
    disclaimer = ""


# Reflectance base name is SISTER_<sensor>_<level>_<product>_<datetime>_<crid>.
rfl_base_name = pathlib.Path(reflectance_file).stem
sister, sensor, level, product, datetime_var, in_crid = rfl_base_name.split('_')

rfl_file = reflectance_file
fc_file = frcov_file

qlook_file = f'{output_stac_catalog_dir}/SISTER_{sensor}_L2B_VEGBIOCHEM_{datetime_var}_{crid}.png'
qlook_met = qlook_file.replace('.png', '.met.json')

# Sorted so the model order (and therefore the actor assignment) is reproducible.
models = sorted(glob.glob('models/PLSR*.json'))
if not models:
    # Without models, Ray would be started with zero CPUs and no outputs would be written.
    raise FileNotFoundError("No trait models found matching 'models/PLSR*.json'")

In [8]:
# Start a fresh local Ray instance with one CPU per trait model.
if ray.is_initialized():
    ray.shutdown()
ray.init(num_cpus=len(models))

try:
    # One remote HyTools actor per trait model; each actor loads its own copy of the image.
    HyTools = ray.remote(ht.HyTools)
    actors = [HyTools.remote() for _ in models]

    # Load data
    logger.info("Loading data")
    ray.get([actor.read_file.remote(rfl_file, 'envi') for actor in actors])

    # Set fractional cover mask: a pixel is "veg" where band 2 of the FRCOV raster is
    # at least veg_cover (band 2 is presumably the vegetation fraction — confirm).
    logger.info("Setting fractional cover mask")
    fc_obj = gdal.Open(fc_file)
    veg_mask = fc_obj.GetRasterBand(2).ReadAsArray() >= veg_cover
    fc_obj = None  # release the FRCOV dataset

    ray.get([actor.set_mask.remote(veg_mask, 'veg') for actor in actors])

    # Run each trait model on its own actor in parallel.
    ray.get([actor.do.remote(apply_trait_model, [json_file, crid, disclaimer])
             for actor, json_file in zip(actors, models)])
finally:
    # Always tear Ray down, even when an actor raises.
    ray.shutdown()

bands = []

2024-03-21 16:57:35,979	INFO worker.py:1538 -- Started a local Ray instance.
2024-03-21 16:57:37,203 INFO: Loading data
2024-03-21 16:58:02,894 INFO: Setting fractional cover mask


[2m[36m(HyTools pid=1065)[0m Applying Trait Model
[2m[36m(HyTools pid=1066)[0m Applying Trait Model
[2m[36m(HyTools pid=1064)[0m Applying Trait Model
[2m[36m(HyTools pid=1066)[0m running system gdal commands
[2m[36m(HyTools pid=1066)[0m 0
[2m[36m(HyTools pid=1066)[0m ...10
[2m[36m(HyTools pid=1066)[0m ...20...30
[2m[36m(HyTools pid=1066)[0m ...40...50...60...70.
[2m[36m(HyTools pid=1066)[0m ..80...90
[2m[36m(HyTools pid=1066)[0m ..
[2m[36m(HyTools pid=1066)[0m .100 - done.
[2m[36m(HyTools pid=1066)[0m Input file size is 2003, 1935
[2m[36m(HyTools pid=1066)[0m 0
[2m[36m(HyTools pid=1066)[0m .
[2m[36m(HyTools pid=1066)[0m ..
[2m[36m(HyTools pid=1066)[0m 10
[2m[36m(HyTools pid=1066)[0m ..
[2m[36m(HyTools pid=1066)[0m .
[2m[36m(HyTools pid=1066)[0m 20.
[2m[36m(HyTools pid=1066)[0m .
[2m[36m(HyTools pid=1066)[0m .30..
[2m[36m(HyTools pid=1066)[0m .
[2m[36m(HyTools pid=1066)[0m 40
[2m[36m(HyTools pid=1066)[0m ..
[2m[36

In [9]:

# Build an RGB quicklook from the NIT, CHL and LMA mean bands; not produced for DESIS.
if sensor != 'DESIS':
    # Start from an empty list here so re-running this cell does not stack extra
    # bands onto those collected by an earlier run.
    bands = []
    for trait_abbrv in ['NIT', 'CHL', 'LMA']:
        tif_file = f'{output_stac_catalog_dir}/SISTER_{sensor}_L2B_VEGBIOCHEM_{datetime_var}_{crid}_{trait_abbrv}.tif'
        print(tif_file)
        gdal_obj = gdal.Open(tif_file)
        if gdal_obj is None:
            raise FileNotFoundError(f"Could not open trait output {tif_file}")
        band = gdal_obj.GetRasterBand(1)  # band 1 holds the trait mean
        bands.append(np.copy(band.ReadAsArray()))

    # Stack to (3, rows, cols), turn no-data into NaN, then go channels-last as float64.
    rgb = np.array(bands)
    rgb[rgb == band.GetNoDataValue()] = np.nan
    rgb = np.moveaxis(rgb, 0, -1).astype(float)

    # Per-channel 5th-95th percentile stretch to [0, 1].
    bottom = np.nanpercentile(rgb, 5, axis=(0, 1))
    top = np.nanpercentile(rgb, 95, axis=(0, 1))
    rgb = np.clip(rgb, bottom, top)
    channel_min = np.nanmin(rgb, axis=(0, 1))
    channel_max = np.nanmax(rgb, axis=(0, 1))
    with np.errstate(invalid='ignore', divide='ignore'):
        # A constant channel gives 0/0 = NaN here; it is handled below.
        rgb = (rgb - channel_min) / (channel_max - channel_min)

    # Casting NaN to uint8 is undefined, so render no-data/constant pixels as 0 explicitly.
    rgb = (np.nan_to_num(rgb, nan=0.0) * 255).astype(np.uint8)
    im = Image.fromarray(rgb)
    description = f'{disclaimer}Vegetation biochemistry RGB quicklook. R: Nitrogen, G: Chlorophyll, B: Leaf Mass ' \
                  f'per Area'


/unity/ads/outputs/SBG-L2B_VEGBIOCHEM/SISTER_EMIT_L2B_VEGBIOCHEM_20230807T182755_000_NIT.tif
/unity/ads/outputs/SBG-L2B_VEGBIOCHEM/SISTER_EMIT_L2B_VEGBIOCHEM_20230807T182755_000_CHL.tif
/unity/ads/outputs/SBG-L2B_VEGBIOCHEM/SISTER_EMIT_L2B_VEGBIOCHEM_20230807T182755_000_LMA.tif


In [10]:
# The quicklook image `im` is only built for non-DESIS sensors; saving it
# unconditionally raised a NameError for DESIS.
if sensor != 'DESIS':
    im.save(qlook_file)

# If experimental, prefix filenames with "EXPERIMENTAL-"
if experimental:
    for file in glob.glob(f"{output_stac_catalog_dir}/SISTER*"):
        shutil.move(file, f"{output_stac_catalog_dir}/EXPERIMENTAL-{os.path.basename(file)}")

# Create stage-out item catalog

In [11]:
# The output dataset reuses the time span of the first input dataset.
orig_dataset = inp_collection.datasets[0]

# Derive the dataset name from one of the trait COGs by dropping the trailing trait
# abbreviation, e.g. ..._20230807T182755_000_NIT.tif -> ..._20230807T182755_000.
# Sorted for a deterministic choice; all trait files share the same prefix.
data_files = sorted(glob.glob(output_stac_catalog_dir + "/*SISTER*.tif"))
if not data_files:
    raise FileNotFoundError(f"No SISTER GeoTIFF outputs found in {output_stac_catalog_dir}")
data_file = os.path.basename(data_files[0].replace("_UNC", ""))
name = os.path.splitext(data_file)[0]
name = "_".join(name.split("_")[:-1])

In [12]:
# Output dataset for this run; its time span is copied from the input dataset.
dataset = Dataset(
    name=name,
    collection_id=output_collection,
    start_time=orig_dataset.data_begin_time,
    end_time=orig_dataset.data_end_time,
    # Same timezone-aware UTC ISO timestamp as utcnow().replace(tzinfo=utc),
    # without datetime.utcnow(), which is deprecated since Python 3.12.
    creation_time=datetime.now(timezone.utc).isoformat(),
)

# STAC item file for this dataset (presumably written by Collection.to_stac below).
# It is registered explicitly after the loop, so skip it in the glob; otherwise a
# re-run of this cell would register it twice.
stac_item_file = os.path.join(output_stac_catalog_dir, name + ".json")

# Add output file(s) to the dataset: COGs as data, PNG quicklooks as browse,
# anything else as metadata.
for file in sorted(glob.glob(output_stac_catalog_dir + "/*SISTER*")):
    if os.path.abspath(file) == os.path.abspath(stac_item_file):
        continue
    if file.endswith(".tif"):
        dataset.add_data_file(DataFile("COG", file, ["data"]))
    elif file.endswith(".png"):
        dataset.add_data_file(DataFile("image/png", file, ["browse"]))
    else:
        dataset.add_data_file(DataFile(None, file, ["metadata"]))

# Add the STAC file we are creating (will eventually be moved into to_stac()).
# NOTE(review): the registered JSON media type is "application/json"; "text/json" is
# kept in case downstream consumers match on it — confirm before changing.
dataset.add_data_file(DataFile("text/json", stac_item_file, ["metadata"]))

# Add the dataset to the collection.
# NOTE(review): this appends to the private `_datasets` list; the public
# `add_dataset` call was left commented out — confirm whether it can be used instead.
out_collection._datasets.append(dataset)

Collection.to_stac(out_collection, output_stac_catalog_dir)