In [5]:
import os
import boto3
import logging
import time
import math
import pandas as pd
import pytz
import subprocess
import re
import requests
import concurrent.futures
from datetime import datetime
from io import BytesIO
from osgeo import gdal
from dask.distributed import Client, get_worker

# dask/parallelization libraries
import coiled
import dask
from dask.distributed import Client, LocalCluster
from dask.distributed import print
import distributed

# scipy basics
import numpy as np
import rasterio
import rasterio.transform
import rasterio.windows
import geopandas as gpd
import pandas as pd
import rioxarray
import xarray as xr

# numba
from numba import jit
from numba.typed import Dict
from numba.core import types

<font size="6">Cluster management</font> 

<font size="5">Creating clusters</font> 

In [42]:
# Full-scale cloud cluster: 60 workers of 4 vCPU / 32 GiB each
# (this worker size was noted as adequate for the carbon pool 2000 stage),
# "spot_with_fallback" purchasing and a 10-minute idle timeout.
# Larger worker alternative: worker_cpu=8, worker_memory="64GiB".
coiled_cluster = coiled.Cluster(
    name="AFOLU_flux_model",
    account='wri-forest-research',
    region="us-east-1",
    use_best_zone=True,
    n_workers=60,
    worker_cpu=4,
    worker_memory="32GiB",
    compute_purchase_option="spot_with_fallback",
    idle_timeout="10 minutes",
)

# Client connected to the cloud cluster (cloud run); the bare name on the last line renders its summary.
coiled_client = coiled_cluster.get_client()
coiled_client

Output()

Output()

0,1
Connection method: Cluster object,Cluster type: coiled.Cluster
Dashboard: https://cluster-bqjye.dask.host/YYxOdTWGei0LANoH/status,

0,1
Dashboard: https://cluster-bqjye.dask.host/YYxOdTWGei0LANoH/status,Workers: 5
Total threads: 20,Total memory: 151.06 GiB

0,1
Comm: tls://10.0.40.139:8786,Workers: 5
Dashboard: http://10.0.40.139:8788/status,Total threads: 20
Started: Just now,Total memory: 151.06 GiB

0,1
Comm: tls://10.0.35.173:37175,Total threads: 4
Dashboard: http://10.0.35.173:8787/status,Memory: 30.22 GiB
Nanny: tls://10.0.35.173:40805,
Local directory: /scratch/dask-scratch-space/worker-96aqx07f,Local directory: /scratch/dask-scratch-space/worker-96aqx07f

0,1
Comm: tls://10.0.45.216:42157,Total threads: 4
Dashboard: http://10.0.45.216:8787/status,Memory: 30.21 GiB
Nanny: tls://10.0.45.216:46105,
Local directory: /scratch/dask-scratch-space/worker-i7y_5eau,Local directory: /scratch/dask-scratch-space/worker-i7y_5eau

0,1
Comm: tls://10.0.41.67:37375,Total threads: 4
Dashboard: http://10.0.41.67:8787/status,Memory: 30.21 GiB
Nanny: tls://10.0.41.67:34467,
Local directory: /scratch/dask-scratch-space/worker-jicwcugq,Local directory: /scratch/dask-scratch-space/worker-jicwcugq

0,1
Comm: tls://10.0.41.143:45651,Total threads: 4
Dashboard: http://10.0.41.143:8787/status,Memory: 30.22 GiB
Nanny: tls://10.0.41.143:42393,
Local directory: /scratch/dask-scratch-space/worker-cq5au4_t,Local directory: /scratch/dask-scratch-space/worker-cq5au4_t

0,1
Comm: tls://10.0.32.94:41001,Total threads: 4
Dashboard: http://10.0.32.94:8787/status,Memory: 30.20 GiB
Nanny: tls://10.0.32.94:37763,
Local directory: /scratch/dask-scratch-space/worker-bow1gvqq,Local directory: /scratch/dask-scratch-space/worker-bow1gvqq


In [None]:
# Small test cloud cluster: 2 workers of 4 vCPU / 32 GiB each
# (this worker size was noted as adequate for the carbon pool 2000 stage),
# "spot_with_fallback" purchasing and a 20-minute idle timeout.
# Note it uses the same cluster name as the full cluster cell.
# Larger worker alternative: worker_cpu=8, worker_memory="64GiB".
coiled_cluster = coiled.Cluster(
    name="AFOLU_flux_model",
    account='wri-forest-research',
    region="us-east-1",
    use_best_zone=True,
    n_workers=2,
    worker_cpu=4,
    worker_memory="32GiB",
    compute_purchase_option="spot_with_fallback",
    idle_timeout="20 minutes",
)

# Client connected to the cloud cluster (cloud run); the bare name on the last line renders its summary.
coiled_client = coiled_cluster.get_client()
coiled_client

In [None]:
# Local single-process cluster (local run): processes=False keeps everything in this one process,
# so .compute() runs only here on this machine, not on a whole (cloud) cluster.
local_client = Client(processes=False)
local_client

In [None]:
# Local client with default settings: with no scheduler address, Client() starts its own local cluster on this machine.
local_client = Client()
local_client

In [None]:
# Local cluster with multiple workers (LocalCluster defaults), kept as `local_cluster`
# so the cluster object is available separately from the client connected to it.
local_cluster = LocalCluster()  
local_client = Client(local_cluster)
local_client

<font size="5">Shutting down cloud and local clusters</font> 

In [None]:
# Restart the workers connected to coiled_client (e.g. to clear worker state between runs)
coiled_client.restart() 

In [305]:
# Shut down the coiled cloud cluster
coiled_cluster.shutdown()

In [None]:
# Shut down the local client created by one of the local-cluster cells
local_client.shutdown()

<font size="6">Variables and constants</font> 

In [7]:
# General paths and constants

# s3 location of the land cover inputs
LC_uri = 's3://gfw2-data/landcover'

# Key prefix (inside the gfw2-data bucket) under which model outputs are uploaded
s3_out_dir = 'climate/AFOLU_flux_model/LULUCF/outputs'

# Module-level boto3 handles for the gfw2-data bucket.
# The upload functions in this notebook create their own client inside the function instead of using these.
s3 = boto3.resource('s3')
my_bucket = s3.Bucket('gfw2-data')
s3_client = boto3.client("s3")

tile_id_pattern = r"[0-9]{2}[A-Z][_][0-9]{3}[A-Z]"  # Regex for tile_ids, e.g. 00N_110E

IPCC_class_max_val = 6  # Maximum value of IPCC class codes

# IPCC codes
# NOTE(review): the "GLCLU codes" cell later rebinds `cropland = 244`, so once that cell has run
# `cropland` no longer holds the IPCC code 2. Confirm which meaning downstream code expects
# and consider distinct names (e.g. ipcc_cropland / glclu_cropland).
forest = 1
cropland = 2
settlement = 3
wetland = 4
grassland = 5
otherland = 6

first_year = 2000  # First year of model
last_year = 2020   # Last year of model

full_raster_dims = 40000    # Width/height of a 10x10 deg raster in pixels (i.e. 4000 pixels per degree)

interval_years = 5   # Number of years in interval. #TODO: calculate programmatically in numba function rather than coded here-- for greater flexibility.

# Threshold for height loss to be counted as tree loss (meters)
sig_height_loss_threshold = 5 

biomass_to_carbon_non_mangrove = 0.47   # Conversion of biomass to carbon for non-mangrove forests
biomass_to_carbon_mangrove = 0.45   # Conversion of biomass to carbon for mangroves (IPCC wetlands supplement table 4.2)

# Default root:shoot ratio, used only where the Huang et al. 2021 ratios are unavailable (remote Pacific islands).
# Value is the average slope of the AGB:BGB relationship in Figure 3 of Mokany et al. 2006.
default_r_s = 0.26   

# Lookup spreadsheet of rates and ratios, and the name of its tab holding the mangrove gain C ratios
rate_ratio_spreadsheet = 'http://gfw2-data.s3.amazonaws.com/climate/AFOLU_flux_model/LULUCF/rate_ratio_lookup_tables/rate_and_ratio_lookup_tables_20240718.xlsx'
mangrove_rate_ratio_tab = 'mang gain C ratio, for model'

# Non-mangrove deadwood C:AGC and litter C:AGC constants
# Deadwood and litter carbon as fractions of AGC are from
# https://cdm.unfccc.int/methodologies/ARmethodologies/tools/ar-am-tool-12-v3.0.pdf
# "Clean Development Mechanism A/R Methodological Tool: 
# Estimation of carbon stocks and change in carbon stocks in dead wood and litter in A/R CDM project activities version 03.0"
# Tables on pages 18 (deadwood) and 19 (litter).
# The ratios depend on the climate domain, elevation, and precipitation.
tropical_low_elev_low_precip_deadwood_c_ratio = 0.02
tropical_low_elev_low_precip_litter_c_ratio = 0.04
tropical_low_elev_med_precip_deadwood_c_ratio = 0.01
tropical_low_elev_med_precip_litter_c_ratio = 0.01
tropical_low_elev_high_precip_deadwood_c_ratio = 0.06
tropical_low_elev_high_precip_litter_c_ratio = 0.01
tropical_high_elev_deadwood_c_ratio = 0.07
tropical_high_elev_litter_c_ratio = 0.01
non_tropical_deadwood_c_ratio = 0.08
non_tropical_litter_c_ratio = 0.04

mang_no_data_val = 255   # NoData value in mangrove AGB raster

In [8]:
# Model version; written into the header of each uploaded stage log
model_version = 0.1

In [9]:
# GLCLU codes
# NOTE(review): this rebinds `cropland`, which the "General paths and constants" cell set to the
# IPCC code 2. After this cell runs, `cropland == 244` everywhere, so any code comparing IPCC
# classes against `cropland` would silently use the GLCLU value. Check call sites and consider
# separate names (e.g. glclu_cropland) before changing either value.
cropland = 244
builtup = 250

# First and last GLCLU codes of the tree classes, for the dry and wet class ranges
# (presumably each code within a range corresponds to a vegetation height; TODO confirm against the GLCLU legend)
tree_dry_min_height_code = 27
tree_dry_max_height_code = 48
tree_wet_min_height_code = 127
tree_wet_max_height_code = 148

tree_threshold = 5   # Height minimum for trees (meters)

In [10]:
# File name paths and patterns

# Key prefix (inside the gfw2-data bucket) for uploaded stage logs, and the base name of those log files
log_path = "climate/AFOLU_flux_model/LULUCF/model_logs/"
combined_log = "AFOLU_model_log"

# Model inputs: each *_path is an s3 folder and each *_pattern is the part of the file name that
# identifies that input (presumably combined with a tile_id to form the file name; TODO confirm at call sites)
agb_2000_path = "s3://gfw2-data/climate/WHRC_biomass/WHRC_V4/Processed/"
agb_2000_pattern = "t_aboveground_biomass_ha_2000"

mangrove_agb_2000_path = "s3://gfw2-data/climate/carbon_model/mangrove_biomass/processed/standard/20190220/"
mangrove_agb_2000_pattern = "mangrove_agb_t_ha_2000"

elevation_path = "s3://gfw2-data/climate/carbon_model/inputs_for_carbon_pools/processed/elevation/20190418/"
elevation_pattern = "elevation"

climate_domain_path = "s3://gfw2-data/climate/carbon_model/inputs_for_carbon_pools/processed/fao_ecozones_bor_tem_tro/20190418/"
climate_domain_pattern = "fao_ecozones_bor_tem_tro_processed"

precipitation_path = "s3://gfw2-data/climate/carbon_model/inputs_for_carbon_pools/processed/precip/20190418/"
precipitation_pattern = "precip_mm_annual"

r_s_ratio_path = "s3://gfw2-data/climate/carbon_model/BGB_AGB_ratio/processed/20230216/"
r_s_ratio_pattern = "BGB_AGB_ratio"

continent_ecozone_path = "s3://gfw2-data/climate/carbon_model/fao_ecozones/ecozone_continent/20190116/processed/"
continent_ecozone_pattern = "fao_ecozones_continents_processed"


### IPCC classes and change
# Names (presumably output folder names, despite the _path suffix) and file-name patterns for the IPCC class and class-change outputs
IPCC_class_path = "IPCC_basic_classes"
IPCC_class_pattern = "IPCC_classes"
IPCC_change_path = "IPCC_basic_change"
IPCC_change_pattern = "IPCC_change"

land_state_pattern = "land_state_node"

# Names of the carbon density and flux outputs (units are part of the names). They appear as output
# folder and file-name components, e.g. .../outputs/AGC_density_MgC_ha/2000/8000_pixels/...
agb_dens_pattern = "AGB_density_MgAGB_ha"
agc_dens_pattern = "AGC_density_MgC_ha"
bgc_dens_pattern = "BGC_density_MgC_ha"
deadwood_c_dens_pattern = "deadwood_C_density_MgC_ha"
litter_c_dens_pattern = "litter_C_density_MgC_ha"
agc_flux_pattern = "AGC_flux_MgC_ha"
bgc_flux_pattern = "BGC_flux_MgC_ha"
deadwood_c_flux_pattern = "deadwood_C_flux_MgC_ha"
litter_c_flux_pattern = "litter_C_flux_MgC_ha"

# Short layer names (presumably keys of the per-chunk download/output dictionaries; TODO confirm)
land_cover = "land_cover"
vegetation_height = "vegetation_height"

agb_2000 = "agb_2000"
mangrove_agb_2000 = "mangrove_agb_2000"
agc_2000 = "agc_2000"
bgc_2000 = "bgc_2000"
deadwood_c_2000 = "deadwood_c_2000"
litter_c_2000 = "litter_c_2000"
soil_c_2000 = "soil_c_2000"

r_s_ratio = "r_s_ratio"

burned_area = "burned_area"
forest_disturbance = "forest_disturbance"

planted_forest_type_layer = "planted_forest_type"
planted_forest_tree_crop_layer = "planted_forest_tree_crop"

elevation = "elevation"
climate_domain = "climate_domain"
precipitation = "precipitation"
continent_ecozone = "continent_ecozone"

<font size="6">Logging</font> 

In [11]:
# Log compilation and uploading
# From https://chatgpt.com/share/e/4fe1e9c8-05a0-4e9d-8eee-64168891b5e2
def compile_and_upload_log(logs, stage, chunk_count, chunk_size_deg, start_time_str, end_time_str, log_note):
    """Write the model ("flm") worker log lines from one stage run to a text file and upload it to s3.

    Args:
        logs: mapping of worker id -> that worker's log as one newline-separated string.
        stage: stage name; used in the file name and the header.
        chunk_count: number of chunks in the run (header only).
        chunk_size_deg: chunk edge length in degrees (header only).
        start_time_str: stage start time formatted "%Y%m%d_%H_%M_%S". Only log lines whose last
            whitespace-separated field parses in that format and is later than this are kept.
        end_time_str: stage end time; written as the final line of the file.
        log_note: free-text note for the header.

    Side effects:
        Writes logs/<combined_log>_<stage>_<local time>.txt (creating logs/ if needed) and uploads
        it to the gfw2-data bucket under log_path. Reads the module-level coiled_client and
        coiled_cluster for the worker count and memory.
    """
    log_name = f"logs/{combined_log}_{stage}_{time.strftime('%Y%m%d_%H_%M_%S')}.txt"

    # Parsed so it can be compared with the timestamp at the end of each log line
    start_time = datetime.strptime(start_time_str, "%Y%m%d_%H_%M_%S")

    n_workers = len(coiled_client.scheduler_info()['workers'])  # Number of connected workers

    scheduler_info = coiled_cluster.scheduler_info  # Scheduler info as a dictionary

    # Memory limit of the first listed worker, reported as the memory per worker.
    # Can't get it to report the worker instance type.
    try:
        first_worker = next(iter(scheduler_info['workers']))
        worker_memory_bytes = scheduler_info['workers'][first_worker]['memory_limit']
        worker_memory = f"{worker_memory_bytes / (1024 ** 3):.2f} GB"  # Bytes -> GB, 2 decimal places
    # StopIteration: no workers listed; TypeError: memory_limit isn't numeric (e.g. None)
    except (KeyError, StopIteration, TypeError):
        worker_memory = "Unknown"

    header_lines = [
        f"Stage: {stage}",
        f"Model version: {model_version}",
        f"Number of workers: {n_workers}",
        f"Memory per worker: {worker_memory}",
        f"Number of chunks: {chunk_count}",
        f"Chunk size (degrees): {chunk_size_deg}",
        f"Log note: {log_note}",
        f"Starting time: {start_time_str}",
        "",
        "Filtered logs:",
        ""
    ]

    # Keeps lines containing both 'distributed.worker' and 'flm' whose trailing timestamp is after start_time
    filtered_logs = []
    for log in logs.values():
        for line in log.split('\n'):
            if 'distributed.worker' in line and 'flm' in line:
                log_time_str = line.split()[-1]
                try:
                    log_time = datetime.strptime(log_time_str, "%Y%m%d_%H_%M_%S")
                except ValueError:
                    # Line doesn't end with a timestamp in the expected format; skip it
                    continue
                if log_time > start_time:
                    filtered_logs.append(line)

    end_time = f"Stage ended at: {end_time_str}"

    # end_time is one line, appended after a newline. ("\n".join(end_time) would put each of its characters on its own line.)
    combined_filtered_logs = "\n".join(header_lines) + "\n".join(filtered_logs) + "\n" + end_time

    # open() fails if the local logs/ directory doesn't exist yet
    os.makedirs(os.path.dirname(log_name), exist_ok=True)
    with open(log_name, "w") as file:
        file.write(combined_filtered_logs)

    s3_client = boto3.client("s3") # Needs to be in the same function as the upload_file call
    s3_client.upload_file(log_name, "gfw2-data", Key=f"{log_path}{log_name}")

In [12]:
# Logs a model ("flm") message and, unless this is a final run, also prints it to the console
def print_and_log(text, is_final, logger):
    """Send "flm: <text>" to `logger.info`, and echo it to the console when `is_final` is falsy.

    Args:
        text: message body (formatted with str()).
        is_final: when truthy, the message is only logged, not printed.
        logger: object with an `info` method, e.g. the logger returned by setup_logging().
    """
    message = f"flm: {text}"
    logger.info(message)
    if is_final:
        return
    print(message)

In [13]:
# Configure logging for the distributed workers
# https://chatgpt.com/share/e/6f80ccde-6a85-4837-94a0-4fcf09b96e43
def setup_logging():
    """Return the 'distributed.worker' logger set to INFO.

    A timestamped StreamHandler is attached only if neither this logger nor any ancestor
    already has a handler, so repeated calls don't duplicate output.
    """
    worker_logger = logging.getLogger('distributed.worker')
    worker_logger.setLevel(logging.INFO)

    if worker_logger.hasHandlers():
        return worker_logger

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    worker_logger.addHandler(stream_handler)
    return worker_logger

<font size="6">General functions</font> 

In [14]:
# Time in Eastern US timezone as a string
def timestr():
    """Return the current US/Eastern wall-clock time formatted as "YYYYmmdd_HH_MM_SS".

    This is the same format compile_and_upload_log parses ("%Y%m%d_%H_%M_%S").
    """
    eastern_now = datetime.now(pytz.timezone('US/Eastern'))
    return eastern_now.strftime("%Y%m%d_%H_%M_%S")

# Chunk bounds as a string
def boundstr(bounds):
    """Join the bounds, each rounded to the nearest integer, with underscores.

    e.g. [6, -2, 8, 0] -> "6_-2_8_0". Used to label chunks in log messages and output file names.
    NOTE(review): because each bound is rounded, chunks smaller than 1 degree can map to the same string.
    """
    rounded = map(round, bounds)
    return "_".join(map(str, rounded))

# Chunk length in pixels
def calc_chunk_length_pixels(bounds, pixels_per_degree=4000):
    """Return the edge length, in pixels, of a square chunk with the given bounds.

    Args:
        bounds: [W, S, E, N] in degrees. Only the N-S extent (bounds[3] - bounds[1]) is used,
            so the chunk is assumed to be square.
        pixels_per_degree: raster resolution. The default of 4000 matches a full 10x10 degree
            tile of full_raster_dims (40000) pixels.

    Returns:
        int number of pixels. The product is rounded rather than truncated so float error in
        fractional-degree bounds (e.g. 0.3 - 0.2 = 0.0999...) doesn't drop a pixel.
    """
    return int(round((bounds[3] - bounds[1]) * pixels_per_degree))

In [15]:
# Lookup from GDAL data type constants to their type names, e.g. gdal.GDT_Float32 -> 'Float32'
gdal_dtype_mapping = {
    getattr(gdal, f"GDT_{type_name}"): type_name
    for type_name in ('Byte', 'UInt16', 'Int16', 'UInt32', 'Int32', 'Float32', 'Float64')
}

In [16]:
# Gets the W, S, E, N bounds of a 10x10 degree tile
def get_10x10_tile_bounds(tile_id):
    """Return (W, S, E, N) of the 10x10 degree tile named by `tile_id` (e.g. "00N_110E").

    The tile_id gives the tile's north-west corner: characters 0-1 are the latitude and
    characters 4-6 the longitude, with S/W marking negative values.
    """
    lat = int(tile_id[:2])
    lon = int(tile_id[4:7])

    north = -lat if "S" in tile_id else lat
    west = -lon if "W" in tile_id else lon

    return west, north - 10, west + 10, north      # W, S, E, N

In [17]:
# Returns list of all chunk boundaries within a bounding box for chunks of a given size
def get_chunk_bounds(chunk_params):
    """Split a bounding box into a row-major list of square chunks.

    Args:
        chunk_params: sequence [min_x, min_y, max_x, max_y, chunk_size], in degrees.

    Returns:
        List of [W, S, E, N] chunk bounds, rows ordered from min_y upward and chunks within a
        row from min_x eastward. A last partial chunk along either axis is kept as a full
        chunk_size chunk (it may extend past max_x/max_y).

    Raises:
        ValueError: if chunk_size is not positive (the previous loop never terminated).
    """
    min_x, min_y, max_x, max_y, chunk_size = chunk_params[:5]

    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    def _n_steps(start, stop):
        # Chunks needed to cover [start, stop). Rounding the quotient before ceil() stops float error
        # (e.g. ten additions of 0.1 giving 0.9999...) from adding a sliver chunk at the edge.
        return max(0, math.ceil(round((stop - start) / chunk_size, 9)))

    chunks = []
    # Chunk corners are computed as start + i * size rather than by repeated addition,
    # so float error does not accumulate across a row or column.
    for row in range(_n_steps(min_y, max_y)):
        y = min_y + row * chunk_size
        for col in range(_n_steps(min_x, max_x)):
            x = min_x + col * chunk_size
            chunks.append([x, y, x + chunk_size, y + chunk_size])

    return chunks

In [18]:
# Returns the encompassing tile_id string in the form YYN/S_XXXE/W based on a coordinate
def xy_to_tile_id(top_left_x, top_left_y):
    """Name the 10x10 degree tile containing the point (top_left_x, top_left_y).

    Tiles are named after their north-west corner: latitude is y rounded up to a multiple
    of 10 and longitude is x rounded down to a multiple of 10,
    e.g. (110.5, -3.2) -> "00N_110E".
    """
    north_edge = 10 * math.ceil(top_left_y / 10.0)
    west_edge = 10 * math.floor(top_left_x / 10.0)

    lat_hemisphere = "N" if north_edge >= 0 else "S"
    lng_hemisphere = "E" if west_edge >= 0 else "W"

    lat = str(abs(north_edge)).zfill(2) + lat_hemisphere
    lng = str(abs(west_edge)).zfill(3) + lng_hemisphere

    return f"{lat}_{lng}"

In [19]:
# Lazily opens tile within provided bounds (i.e. one chunk) and returns as a numpy array
# If it can't open the uri for the chunk (tile does not exist), it returns nothing.
# Originally, I had it return an array of the NoData value if the chunk didn't exist but that seems inefficient.
def get_tile_dataset_rio(uri, bounds, chunk_length_pixels):
    """Read the window of band 1 of the raster at `uri` that covers `bounds`.

    Args:
        uri: path/URL of the raster (e.g. an s3:// tile path).
        bounds: [W, S, E, N] of the chunk, in the raster's coordinates.
        chunk_length_pixels: unused; kept so existing callers (e.g. prepare_to_download_chunk) still work.

    Returns:
        The window as a numpy array, or None if the raster could not be opened or read
        (typically because the tile doesn't exist for this input). A returned array can still
        be entirely NoData.
    """
    try:
        with rasterio.open(uri) as ds:
            window = rasterio.windows.from_bounds(*bounds, ds.transform)
            return ds.read(1, window=window)
    # Best-effort: a missing tile is expected and signalled by returning None.
    # `Exception` rather than a bare except, so KeyboardInterrupt/SystemExit still propagate.
    except Exception:
        return None

In [20]:
# Reads every input layer of one chunk concurrently.
# Chunks are defined by a bounding box.
def prepare_to_download_chunk(bounds, download_dict, is_final, logger):
    """Read all input layers for one chunk using a thread pool.

    Args:
        bounds: [W, S, E, N] of the chunk.
        download_dict: mapping of layer name -> raster uri.
        is_final, logger: passed through to print_and_log.

    Returns:
        dict mapping each Future (whose result is get_tile_dataset_rio's array, or None) to its layer name.

    Note: leaving the `with` block shuts the executor down and waits for it, so every future in
    the returned dict has already finished by the time this function returns.
    """
    bounds_str = boundstr(bounds)
    tile_id = xy_to_tile_id(bounds[0], bounds[3])
    chunk_length_pixels = calc_chunk_length_pixels(bounds)

    # All layers are submitted together so their reads overlap instead of running one after another
    with concurrent.futures.ThreadPoolExecutor() as executor:

        print_and_log(f"Requesting data in chunk {bounds_str} in {tile_id}: {timestr()}", is_final, logger)

        futures = {
            executor.submit(get_tile_dataset_rio, uri, bounds, chunk_length_pixels): layer_name
            for layer_name, uri in download_dict.items()
        }

    return futures

In [21]:
# Checks if tiles exist at all
def check_for_tile(download_dict, is_final, logger):
    """Return True as soon as any input's object exists in the gfw2-data bucket, else False.

    Args:
        download_dict: mapping of layer name -> "s3://gfw2-data/..." uri. The first 15 characters
            ("s3://gfw2-data/") are dropped to form the object key.
        is_final, logger: passed through to print_and_log.
    """
    s3 = boto3.client('s3')

    tile_id = None   # Stays None if download_dict is empty (previously an UnboundLocalError)

    for uri in download_dict.values():

        s3_key = uri[15:]   # Drops the "s3://gfw2-data/" prefix
        match = re.search(tile_id_pattern, uri)   # Extracts the tile_id from the s3 path
        tile_id = match.group(0) if match else None

        # A failing head_object (missing key, no access, network error) means "not available here";
        # try the next input. Not a bare except, so KeyboardInterrupt still stops the run.
        try:
            s3.head_object(Bucket='gfw2-data', Key=s3_key)
        except Exception:
            continue

        # No need to keep checking other inputs once one exists
        print_and_log(f"Tile id {tile_id} exists for some inputs. Proceeding: {timestr()} ", is_final, logger)
        return True

    print_and_log(f"Tile id {tile_id} does not exists. Skipping chunk: {timestr()}", is_final, logger)

    return False

In [22]:
# Checks whether a chunk has data in it.
# There are two options for how to assess if a chunk has data (any_or_all argument): if any assessed input has data, or if all assessed inputs are present.
# Any: the chunk has data if at least one assessed input (layer) is not a single uniform value.
# All: the chunk has data if every assessed input (layer) is present; the first missing one ends the check and the chunk is ultimately skipped.
def check_chunk_for_data(required_layers, item_to_check, bounds_str, tile_id, any_or_all, is_final, logger):
    """Decide whether a chunk should be processed, based on its downloaded input layers.

    Each decision message is both logged and printed, regardless of whether this is a final run.

    Args:
        required_layers: mapping of layer name -> array for this chunk, or None when that layer
            couldn't be read (see get_tile_dataset_rio).
        item_to_check, tile_id, is_final: unused; kept for interface compatibility.
        bounds_str: chunk label used in the messages.
        any_or_all: "any" -> True if at least one layer has differing min and max values (a layer
            holding a single value, e.g. all NoData, counts as empty). "all" -> True if every layer
            is present (not None and not zero-size); present layers are NOT checked for NoData.
        logger: logger whose `info` receives each message.

    Raises:
        ValueError: if any_or_all is neither "any" nor "all" (a subclass of the Exception raised previously).
    """
    def _is_missing(layer):
        # Unreadable layers arrive as None; a zero-size array has no values to compare
        return layer is None or np.size(layer) == 0

    if any_or_all == "any":

        # min != max is used instead of np.all(...) because the latter reported no data for chunks
        # that are mostly water even when they contained land.
        for layer in required_layers.values():
            if not _is_missing(layer) and np.min(layer) != np.max(layer):
                message = f"flm: Data in chunk {bounds_str}. Proceeding: {timestr()}"
                logger.info(message)
                print(message)
                return True

        message = f"flm: No data in chunk {bounds_str} for assessed inputs: {timestr()}"
        logger.info(message)
        print(message)
        return False

    if any_or_all == "all":

        # One missing input is enough to skip the chunk
        for key, layer in required_layers.items():
            if _is_missing(layer):
                message = f"flm: Chunk {bounds_str} does not exist for {key}. Skipping chunk: {timestr()}"
                logger.info(message)
                print(message)
                return False

        message = f"flm: Chunk {bounds_str} has data for all assessed inputs: {timestr()}"
        logger.info(message)
        print(message)
        return True

    raise ValueError("any_or_all argument not valid")

In [23]:
# Saves array as a raster locally, then uploads it to s3. NoData value for outputs is optional
def save_and_upload_small_raster_set(bounds, chunk_length_pixels, tile_id, bounds_str, output_dict, is_final, logger, no_data_val = None):
    """Write each output array of one chunk to a GeoTIFF in /tmp, upload it to s3, then delete the local file.

    Args:
        bounds: [W, S, E, N] of the chunk (EPSG:4326).
        chunk_length_pixels: width and height of each output raster.
        tile_id, bounds_str: used to build file names ("<tile_id>__<bounds_str>__<key>.tif").
        output_dict: mapping of output name -> sequence whose first four items are
            (array, dtype, data_meaning, year_out); data_meaning and year_out select the s3 folder.
        is_final: when True, file names have no timestamp suffix; otherwise a timestr() suffix is
            added. Also passed to print_and_log.
        logger: passed to print_and_log.
        no_data_val: NoData value written into each raster; omitted from the rasters when None.
    """
    s3_client = boto3.client("s3") # Needs to be in the same function as the upload_file call

    transform = rasterio.transform.from_bounds(*bounds, width=chunk_length_pixels, height=chunk_length_pixels)

    file_info = f'{tile_id}__{bounds_str}'

    # GeoTIFF creation options shared by every output of this chunk (only dtype varies per output).
    # nodata is only included when given, so rasters have no NoData value otherwise.
    profile = {
        'driver': 'GTiff', 'width': chunk_length_pixels, 'height': chunk_length_pixels, 'count': 1,
        'crs': 'EPSG:4326', 'transform': transform, 'compress': 'lzw',
        'blockxsize': 400, 'blockysize': 400,
    }
    if no_data_val is not None:
        profile['nodata'] = no_data_val

    # Can't save directly to s3, so each output is written locally first
    for key, value in output_dict.items():

        data_array = value[0]
        data_type = value[1]
        data_meaning = value[2]
        year_out = value[3]

        if is_final:
            file_name = f"{file_info}__{key}.tif"
        else:
            file_name = f"{file_info}__{key}__{timestr()}.tif"

        local_path = f"/tmp/{file_name}"

        print_and_log(f"Saving {bounds_str} in {tile_id} for {year_out}: {timestr()}", is_final, logger)

        # try/finally so a failed write or upload doesn't leave the raster behind in the worker's /tmp
        try:
            with rasterio.open(local_path, 'w', dtype=data_type, **profile) as dst:
                dst.write(data_array, 1)

            s3_path = f"{s3_out_dir}/{data_meaning}/{year_out}/{chunk_length_pixels}_pixels/{time.strftime('%Y%m%d')}"

            print_and_log(f"Uploading {bounds_str} in {tile_id} for {year_out} to {s3_path}: {timestr()}", is_final, logger)

            s3_client.upload_file(local_path, "gfw2-data", Key=f"{s3_path}/{file_name}")
        finally:
            if os.path.exists(local_path):
                os.remove(local_path)

In [24]:
# Lists rasters in an s3 folder and returns their names as a list
def list_rasters_in_folder(full_in_folder):
    """Return the names of the GeoTIFFs directly inside an s3 folder, as listed by `aws s3 ls`.

    Args:
        full_in_folder: "s3://bucket/prefix/" folder to list (not recursive).

    Returns:
        File names (not full paths), in listing order, ending in .tif or .tiff (case-insensitive).
        Sidecar files such as "x.tif.aux.xml" or "x.tif.ovr", which the old substring check
        ("tif" in name) let through, are excluded. Blank listing lines are skipped.

    Raises:
        subprocess.CalledProcessError: if `aws s3 ls` exits with a non-zero status.
    """
    cmd = ['aws', 's3', 'ls', full_in_folder]
    s3_contents_str = subprocess.check_output(cmd).decode('utf-8')

    rasters = []
    for line in s3_contents_str.splitlines():
        fields = line.split()
        if not fields:
            continue   # Blank line has no name field (previously an IndexError)
        name = fields[-1]   # Last whitespace-separated field is the object (or "PRE" folder) name
        if name.lower().endswith(('.tif', '.tiff')):
            rasters.append(name)

    return rasters

In [25]:
# Uploads a shapefile to s3
def upload_shp(full_in_folder, in_folder, shp):
    """Upload a shapefile and its .dbf/.prj/.shx companions from /tmp to s3, then delete the local copies.

    Args:
        full_in_folder: destination folder, used only in the progress message.
        in_folder: "gfw2-data/<key prefix>" destination; its first 10 characters ("gfw2-data/")
            are dropped to form the object key prefix.
        shp: shapefile name (ending in ".shp") located in /tmp.
    """
    print(f"flm: Uploading to {full_in_folder}{shp}: {timestr()}")

    stem = shp[:-4]
    key_prefix = in_folder[10:]
    components = [shp] + [f"{stem}.{ext}" for ext in ("dbf", "prj", "shx")]

    s3_client = boto3.client("s3")  # Needs to be in the same function as the upload_file call
    for component in components:
        s3_client.upload_file(f"/tmp/{component}", "gfw2-data", Key=f"{key_prefix}{component}")

    # Local copies are only removed after every component has been uploaded
    for component in components:
        os.remove(f"/tmp/{component}")

In [26]:
# Makes a shapefile of the footprints of rasters in a folder, for checking geographical completeness of rasters
def make_tile_footprint_shp(input_dict):
    """Build a tile-index shapefile of every raster in one s3 folder and upload it into that folder.

    Args:
        input_dict: dict whose first entry is {"gfw2-data/<folder>/": pattern}; only that entry is
            used. The output is named "raster_footprints_<pattern>.shp".

    Returns:
        A "Completed: <time>" message.

    Raises:
        subprocess.CalledProcessError: if listing the folder or gdaltindex fails.
    """
    import tempfile   # Only this function needs it

    in_folder, pattern = list(input_dict.items())[0]

    print(f"flm: Making tile index shapefile for: {in_folder}: {timestr()}")

    s3_in_folder = f's3://{in_folder}'
    vsis3_in_folder = f'/vsis3/{in_folder}'   # GDAL virtual file system path, so gdaltindex can read from s3

    filenames = list_rasters_in_folder(s3_in_folder)
    tile_paths = [vsis3_in_folder + filename for filename in filenames]

    shp = f"raster_footprints_{pattern}.shp"

    # gdaltindex reads the raster list from a file (--optfile) rather than the command line.
    # A unique temporary file (instead of a fixed /tmp/s3_paths.txt) keeps concurrent calls on
    # one worker from overwriting each other's list, and it is deleted afterwards.
    fd, paths_file = tempfile.mkstemp(prefix="s3_paths_", suffix=".txt", dir="/tmp")
    try:
        with os.fdopen(fd, 'w') as file:
            for item in tile_paths:
                file.write(item + '\n')

        cmd = ["gdaltindex", "-t_srs", "EPSG:4326", f"/tmp/{shp}", "--optfile", paths_file]
        subprocess.check_call(cmd)
    finally:
        os.remove(paths_file)

    upload_shp(s3_in_folder, in_folder, shp)

    return f"Completed: {timestr()}"

In [27]:
# Saves an xarray data array locally as a raster and then uploads it to s3
def save_and_upload_raster_10x10(**kwargs):
    """Write a rioxarray DataArray to an LZW-compressed GeoTIFF in /tmp, upload it, then delete the local file.

    Keyword Args:
        data: the DataArray to save (written with its `.rio.to_raster`).
        out_file_name: output file name (no directory).
        out_folder: "gfw2-data/<key prefix>" destination; its first 10 characters ("gfw2-data/")
            are dropped to form the object key prefix.

    Raises:
        KeyError: if any of the three keyword arguments is missing.
    """
    s3_client = boto3.client("s3") # Needs to be in the same function as the upload_file call

    data_array = kwargs['data']   # The data being saved
    out_file_name = kwargs['out_file_name']   # The output file name
    out_folder = kwargs['out_folder']   # The output folder
    local_path = f"/tmp/{out_file_name}"

    print(f"flm: Saving {out_file_name} locally")

    # try/finally so a failed write or upload doesn't leave the raster behind in /tmp
    try:
        data_array.rio.to_raster(local_path, compress='lzw')   # Compresses the output raster

        print(f"flm: Saving {out_file_name} to {out_folder[10:]}{out_file_name}")

        s3_client.upload_file(local_path, "gfw2-data", Key=f"{out_folder[10:]}{out_file_name}")
    finally:
        if os.path.exists(local_path):
            os.remove(local_path)

In [28]:
# Builds the list of 10x10 degree rasters to assemble from the smaller chunk rasters in each input folder.
def create_list_for_aggregation(s3_in_folders):
    """Return one {input folder: [10x10 output file name]} dict per 10x10 raster to create.

    Chunk rasters are named "<tile_id>__<bounds>__<pattern>.tif"; dropping the bounds gives the
    10x10 output name "<tile_id>__<pattern>.tif". Many chunks share one output name, so names are
    de-duplicated and sorted within each folder. Folders keep their input order.

    Args:
        s3_in_folders: folders given as "gfw2-data/<key prefix>/" (without "s3://").

    Returns:
        e.g. [{'gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/AGC_density_MgC_ha/2000/8000_pixels/20240821/': ['00N_110E__AGC_density_MgC_ha_2000.tif']},
              {'gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/AGC_density_MgC_ha/2000/8000_pixels/20240821/': ['00N_120E__AGC_density_MgC_ha_2000.tif']}, ...]
    """
    aggregation_jobs = []

    for s3_in_folder in s3_in_folders:

        chunk_names = list_rasters_in_folder(f"s3://{s3_in_folder}")

        # First 10 characters are "<tile_id>__"; the pattern is whatever follows the last "__"
        output_names = sorted({
            name[:10] + name[name.rfind("__") + len("__"):]
            for name in chunk_names
        })

        aggregation_jobs.extend({s3_in_folder: [output_name]} for output_name in output_names)

    print(f"flm: There are {len(aggregation_jobs)} 10x10 deg rasters to create across {len(s3_in_folders)} input folders.")

    return aggregation_jobs

# Flattens a nested list
def flatten_list(nested_list):
    """Concatenate the elements of each inner iterable into one flat list (one level deep)."""
    flat = []
    for inner in nested_list:
        flat.extend(inner)
    return flat

In [29]:
# Merges rasters that are <10x10 degrees into 10x10 degree rasters in the standard grid.
# Approach is to merge rasters with gdal_merge.py (run as a subprocess) and then upload them to s3.
def merge_small_tiles_gdal(s3_name_dict):
    """Merges the small rasters of one 10x10 tile into a single GeoTIFF and uploads it to s3.

    Args:
        s3_name_dict: single-entry dict {input folder (bucket included, no s3:// prefix): [output file name]},
            e.g. {'gfw2-data/climate/.../8000_pixels/20240821/': ['00N_110E__AGC_density_MgC_ha_2000.tif']}.

    Returns:
        f"success for {s3_name_dict}" when the merged raster was uploaded,
        otherwise f"failure for {s3_name_dict}" (no inputs, unreadable input, merge error, or upload error).
    """
    failure_msg = f"failure for {s3_name_dict}"

    in_folder = next(iter(s3_name_dict))   # The input s3 folder for the small rasters
    out_file_name = s3_name_dict[in_folder][0]   # The output file name for the combined rasters

    s3_in_folder = f's3://{in_folder}'   # The input s3 folder with s3:// prepended
    vsis3_in_folder = f'/vsis3/{in_folder}'   # The input s3 folder with /vsis3/ prepended (GDAL's s3 path prefix)

    # Lists all the rasters in the specified s3 folder
    filenames = list_rasters_in_folder(s3_in_folder)

    # Gets the tile_id (e.g., '00N_110E') from the output file name in the standard format
    tile_id = out_file_name[:8]

    # Paths of the input rasters inside this tile_id (the relevant 10x10 area)
    tile_paths = [vsis3_in_folder + filename for filename in filenames if tile_id in filename]

    # Nothing to merge: report failure instead of crashing on tile_paths[0] below
    if not tile_paths:
        print(f"flm: No rasters found for {tile_id} in {vsis3_in_folder}")
        return failure_msg

    print(f"flm: Merging small rasters in {tile_id} in {vsis3_in_folder}")

    # Names the output folder. Same as the input folder but with the dimensions in pixels replaced.
    # [10:] removes the 'gfw2-data/' bucket prefix; the bucket is passed separately to upload_file.
    out_folder = re.sub(r'\d+_pixels', f'{full_raster_dims}_pixels', in_folder[10:])

    min_x, min_y, max_x, max_y = get_10x10_tile_bounds(tile_id)

    # Dynamically sets the datatype and NoData value for the merged raster from the first input raster
    # (courtesy of https://chatgpt.com/share/e/a91c4c98-b2b1-4680-a4a7-453f1a878052)
    first_raster_path = tile_paths[0]
    ds = gdal.Open(first_raster_path)
    if ds is None:   # gdal.Open returns None on failure when GDAL exceptions are not enabled
        print(f"flm: Could not open {first_raster_path}")
        return failure_msg
    band = ds.GetRasterBand(1)
    raster_datatype = band.DataType
    raster_nodata_value = band.GetNoDataValue()   # None when the raster has no NoData value
    band = None
    ds = None   # Releases the dataset

    # Defaults to Float32 if not found
    dtype_str = gdal_dtype_mapping.get(raster_datatype, 'Float32')

    # Merges the rasters (courtesy of ChatGPT: https://chatgpt.com/share/e/13158ebb-dd0a-41d8-8dfb-9ee12e4c804e)
    # This is the only system I found that maintains the extent of all the constituent rasters and doesn't change their resolution or pixel size or shift them.
    # I also tried various gdal_translate, build_vrt, and numpy padding approaches, none of which worked in all cases.
    merged_file = f"/tmp/merged_{out_file_name}"

    merge_command = [
        'gdal_merge.py',
        '-o', merged_file,
        '-of', 'GTiff',
        '-co', 'COMPRESS=DEFLATE',
        '-co', 'TILED=YES',   # If not included, the size of the merged small rasters can be many times their sum. Answer at https://gis.stackexchange.com/a/258215
        '-co', 'BLOCKXSIZE=400',  # Internal tiling
        '-co', 'BLOCKYSIZE=400',  # Internal tiling
        '-ul_lr', str(min_x), str(max_y), str(max_x), str(min_y),   # Upper-left x, y then lower-right x, y
        '-ot', dtype_str,
    ]

    # gdal_merge.py parses -a_nodata as a number, so the string "None" would make the merge fail.
    # Only sets a NoData value when the input actually has one.
    if raster_nodata_value is not None:
        merge_command += ['-a_nodata', str(raster_nodata_value)]

    # Adds the input tile paths
    merge_command.extend(tile_paths)

    try:
        try:
            subprocess.check_call(merge_command)
        except subprocess.CalledProcessError as e:
            print(f"flm: Error merging rasters: {e}")
            return failure_msg
        print(f"flm: Successfully merged rasters into {merged_file}")

        s3_client = boto3.client("s3") # Needs to be in the same function as the upload_file call for uploading to work

        print(f"flm: Saving {out_file_name} to s3: {out_folder}{out_file_name}")

        try:
            s3_client.upload_file(merged_file, "gfw2-data", Key=f"{out_folder}{out_file_name}")
        except boto3.exceptions.S3UploadFailedError as e:
            print(f"flm: Error uploading file to s3: {e}")
            return failure_msg
        print(f"flm: Successfully uploaded {out_file_name} to s3")
    finally:
        # Deletes the local merged raster on success AND on failure so the worker's /tmp doesn't fill up
        if os.path.exists(merged_file):
            os.remove(merged_file)

    return f"success for {s3_name_dict}"

In [30]:
@jit(nopython=True)
def accrete_node(combo, new):
    """Appends `new` to the decimal code `combo`.

    Shifts combo one decimal place to the left and adds new, e.g. accrete_node(13, 1) -> 131.
    """
    return new + combo * 10

# accrete_node(1, 1)
# accrete_node(13, 1)

In [31]:
### Creates a separate dictionary for each chunk datatype so that they can be passed to Numba as separate arguments.
### Numba functions can accept (and return) dictionaries of arrays as long as each dictionary only has arrays of one data type (e.g., uint8, float32)
### Note: to support inputs with other data types, add them to the dtype table below and update callers that unpack the returned tuple
def create_typed_dicts(layers):
    """Splits a {name: array} dict into Numba typed dicts, one per supported dtype.

    Args:
        layers: dict mapping layer name (str) to a numpy array, or to None when the chunk
            doesn't exist for that input. Arrays are assumed to be 2D (the typed dicts are declared 2D).

    Returns:
        Tuple of four numba.typed.Dict objects, in this order: uint8, int16, int32, float32.
        Entries whose value is None or whose dtype is not one of these four are skipped.
    """
    # numpy dtype -> numba element type. Insertion order fixes the order of the returned tuple.
    numba_element_types = {
        np.dtype(np.uint8): types.uint8,
        np.dtype(np.int16): types.int16,
        np.dtype(np.int32): types.int32,
        np.dtype(np.float32): types.float32,
    }

    # One empty numba-compliant typed dict per dtype; values are 2D C-contiguous arrays of that dtype
    typed_dicts = {
        np_dtype: Dict.empty(
            key_type=types.unicode_type,
            value_type=types.Array(numba_type, 2, 'C'),
        )
        for np_dtype, numba_type in numba_element_types.items()
    }

    # Distributes each downloaded chunk array into the typed dict for its datatype
    for key, array in layers.items():

        # Skips the dictionary entry if it has no data (generally because the chunk doesn't exist for that input)
        if array is None:
            continue

        typed_dict = typed_dicts.get(array.dtype)
        if typed_dict is None:
            # Unsupported dtype: deliberately skipped rather than raised
            continue

        # The typed dicts are declared with 'C' layout; a non-contiguous view (e.g., a slice) would fail
        # Numba's type cast on insertion. ascontiguousarray returns the same array when it is already contiguous.
        typed_dict[key] = np.ascontiguousarray(array)

    return tuple(typed_dicts.values())

In [32]:
# Creates numpy array of rates or ratios from a tab in an Excel spreadsheet, e.g., removal factors or carbon pool ratios
def convert_lookup_table_to_array(spreadsheet, sheet_name, fields_to_keep, timeout=60):
    """Downloads an Excel workbook over HTTP and returns selected columns of one tab as a float array.

    Args:
        spreadsheet: URL of the Excel workbook.
        sheet_name: name of the tab to read.
        fields_to_keep: column name(s) to retain; with a list, output columns follow its order.
        timeout: seconds requests waits to connect / between received bytes before giving up.
            Without a timeout, requests can block forever on an unresponsive server.

    Returns:
        numpy float64 array of the retained columns (float, not object, so Numba can use it).

    Raises:
        requests.HTTPError: if the server returns an error status.
        requests.Timeout: if the server does not respond within `timeout`.
    """
    # Fetches the file content. Courtesy of ChatGPT: https://chatgpt.com/share/e/aff31681-c9a7-40fe-85c1-73a1cab62066
    response = requests.get(spreadsheet, timeout=timeout)
    response.raise_for_status()  # Ensure we notice bad responses

    # Reads the workbook tab into a dataframe. Courtesy of ChatGPT: https://chatgpt.com/share/e/aff31681-c9a7-40fe-85c1-73a1cab62066
    excel_df = pd.read_excel(BytesIO(response.content), sheet_name=sheet_name)

    # Retains only the relevant columns, then converts to a float numpy array:
    # Numba jit-decorated functions can't use dataframes or object-dtype arrays.
    return excel_df[fields_to_keep].to_numpy().astype(float)

In [33]:
# Creates arrays of 0s for any missing inputs and puts them in the corresponding typed dictionary
def complete_inputs(existing_input_list, typed_dict, datatype, chunk_length_pixels, bounds_str, tile_id, is_final, logger):
    """Fills in zero arrays for expected inputs that are absent from a typed dict.

    Every name in existing_input_list that is not yet a key of typed_dict gets a square
    chunk_length_pixels x chunk_length_pixels array of zeros with dtype `datatype`, and each
    creation is reported through print_and_log. typed_dict is modified in place and also returned.
    """
    square_shape = (chunk_length_pixels, chunk_length_pixels)

    for dataset_name in existing_input_list:
        if dataset_name in typed_dict:
            continue   # Input already present; left untouched

        typed_dict[dataset_name] = np.zeros(square_shape, dtype=datatype)
        print_and_log(f"Created {dataset_name} for chunk {bounds_str} in {tile_id}: {timestr()}", is_final, logger)

    return typed_dict

In [34]:
# Calculates stats for a chunk (numpy array)
# From https://chatgpt.com/share/e/5599b6b0-1aaa-4d54-98d3-c720a436dd9a
def calculate_stats(array, name, bounds_str, tile_id, in_out):
    """Summarizes one chunk array as a row for the chunk-stats table.

    Returns a dict with keys chunk_id, tile_id, layer_name, in_out, min_value, mean_value,
    max_value, data_type. The four value fields are the string 'no data' when `array` is None
    or when np.any(array) is False -- i.e. for empty arrays AND for arrays that are entirely 0.
    """
    row = {
        'chunk_id': bounds_str,
        'tile_id': tile_id,
        'layer_name': name,
        'in_out': in_out,
    }

    has_data = array is not None and np.any(array)
    if not has_data:
        row.update({
            'min_value': 'no data',
            'mean_value': 'no data',
            'max_value': 'no data',
            'data_type': 'no data',
        })
        return row

    # Only calculates stats if there is (nonzero) data in the array
    row.update({
        'min_value': np.min(array),
        'mean_value': np.mean(array),
        'max_value': np.max(array),
        'data_type': array.dtype.name,
    })
    return row

In [35]:
# Calculates chunk-level stats for all inputs and outputs and saves to Excel spreadsheet
# Also calculates the min and max value for each input and output across all chunks
# From https://chatgpt.com/share/e/5599b6b0-1aaa-4d54-98d3-c720a436dd9a
def calculate_chunk_stats(all_stats, stage, out_dir='chunk_stats'):
    """Writes per-chunk stats and per-layer min/max to an Excel workbook and prints a preview.

    Args:
        all_stats: list of dicts with keys in_out, layer_name, min_value, max_value, ... (as built by calculate_stats).
        stage: label prefixed to the output file name.
        out_dir: folder for the workbook (created if missing). Defaults to 'chunk_stats'.

    The workbook has two sheets: 'chunk_stats' (all rows, sorted by in_out then layer_name)
    and 'min_max_for_layers' (min of min_value and max of max_value per layer_name).
    """
    # Convert accumulated statistics to a DataFrame
    df_all_stats = pd.DataFrame(all_stats)
    sorted_stats = df_all_stats.sort_values(by=['in_out', 'layer_name']).reset_index(drop=True)

    # Chunks without data carry the string 'no data' in their value columns. Comparing those strings
    # with numbers in a groupby min/max raises TypeError, so treat them as missing (NaN) here;
    # groupby min/max skips NaN.
    numeric_stats = df_all_stats.assign(
        min_value=pd.to_numeric(df_all_stats['min_value'], errors='coerce'),
        max_value=pd.to_numeric(df_all_stats['max_value'], errors='coerce'),
    )

    # Calculate the min and max values for each layer_name
    min_max_stats = numeric_stats.groupby('layer_name').agg(
        min_value=('min_value', 'min'),
        max_value=('max_value', 'max'),
    ).reset_index()

    # Layers with no data in any chunk get the same 'no data' label used for individual chunks
    for column in ('min_value', 'max_value'):
        min_max_stats[column] = min_max_stats[column].astype(object).where(min_max_stats[column].notna(), 'no data')

    # ExcelWriter does not create missing folders
    os.makedirs(out_dir, exist_ok=True)

    # Write the combined statistics to a single Excel file
    with pd.ExcelWriter(os.path.join(out_dir, f'{stage}_chunk_statistics_{timestr()}.xlsx')) as writer:
        sorted_stats.to_excel(writer, sheet_name='chunk_stats', index=False)

        # Write the min and max statistics to the second sheet
        min_max_stats.to_excel(writer, sheet_name='min_max_for_layers', index=False)

    print(sorted_stats.head())  # Show first few rows of the stats DataFrame for inspection

In [36]:
# get_tile_dataset_rio("s3://gfw2-data/climate/carbon_model/carbon_pools/aboveground_carbon/extent_2000/standard/20230222/70N_010W_Mg_AGC_ha_2000.tif", [-10, 69, -9, 70, 1], 4000, 255)
# get_tile_dataset_rio("s3://gfw2-data/climate/carbon_model/other_emissions_inputs/burn_year/burn_year_10x10_clip/ba_2019_70N_010W.tif", [-10, 69, -9, 70, 1], 4000, 255)

In [37]:
# download_dict = {}
# download_dict["agc_2000"] = "s3://gfw2-data/climate/carbon_model/carbon_pools/aboveground_carbon/extent_2000/standard/20230222/70N_020E_Mg_AGC_ha_2000.tif"
# data = prepare_to_download_chunk([-10, 69, -9, 70, 1], download_dict, 255)
# data

In [38]:
#TODO: Function to track the number of land use changes per pixel