Alternative approach to running model that uses Dask bags for tasks and maps the operation onto them.
It doesn't seem to be any faster than the dask.delayed approach I'd been working with; indeed, it seemed slower.
Many things about this were different, including how data were downloaded. 
It comes from this very long exchange with ChatGPT: https://chatgpt.com/share/e/484b70f8-2179-4a5a-b034-8cd6076a3830.
My original reason for trying this was that the global run using dask.delayed was failing even with 1x1 chunks and I was seeing lots of 
messages in the Coiled dashboard log about GIL being held too long.
So I thought I'd try something different, although it doesn't solve the fundamental issue of the numba function just being
giant and taking a lot of time no matter how the rest of the code works. 
Basically, this helped me conclude that the bottleneck in performance really is the numba function,
although something else about the downloading system could be hurting performance.
Still, when I did head-to-head comparisons of just downloading 1x1 chunks of all the inputs used so far for a 10x10 tile (50N_010E or 00N_110E), 
the dask.delayed route I'd been working with was generally faster. 
So, I'm not pursuing this route, but I want to keep the code in case something about it is useful eventually.

In [None]:
### First pass at reading chunks from rasters using Dask bags and executing on the Dask bags with a lambda function.
### Just reads and prints the chunks.
### The number of bags comes from the number of input rasters

%%time

import dask
import dask.bag as db
import s3fs
import rasterio
from rasterio.windows import from_bounds
import boto3
from concurrent.futures import ThreadPoolExecutor

# Initialize the S3 filesystem
# NOTE(review): `s3` is not referenced by the code below in this cell; rasterio
# opens the s3:// URIs itself, so this object appears to be unused here.
s3 = s3fs.S3FileSystem(anon=False)

# Function to download a chunk from S3 based on bounds
def download_chunk(s3_uri, bounds):
    """Read the pixels of ``s3_uri`` that fall inside ``bounds``, printing diagnostics.

    ``bounds`` is unpacked into ``rasterio.windows.from_bounds`` as
    (left, bottom, right, top) in the raster's CRS. Prints the URI, the full
    raster bounds and the shape of the array read, then returns that array.
    """
    with rasterio.Env(), rasterio.open(s3_uri) as dataset:
        print(f"Reading from {s3_uri}")
        print(f"Raster bounds: {dataset.bounds}")
        # Translate the geographic bounds into a pixel window of this raster.
        chunk_window = from_bounds(*bounds, transform=dataset.transform)
        chunk = dataset.read(window=chunk_window)
        print(f"Data shape: {chunk.shape}")
    return chunk

# S3 URIs of the 2000 carbon-density rasters for tile 50N_010E.
# Plain string literals: none of these URIs has a placeholder to format.
s3_uris = [
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/AGC_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__AGC_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/BGC_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__BGC_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/deadwood_C_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__deadwood_C_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/litter_C_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__litter_C_density_MgC_ha_2000.tif",
    # Add more URIs as needed
]

# Bounds of the chunk to download (left, bottom, right, top)
# bounds = [50, -10, 120, 60]
# bounds = [110, -10, 120, 0]
bounds = [10, 49, 11, 50]    # 1x1 deg (50N_010E)

# One Dask bag partition per raster, so each file is read by its own task
s3_bag = db.from_sequence(s3_uris, npartitions=len(s3_uris))

# Read the same window from every raster in parallel and gather the arrays
chunks = s3_bag.map(lambda uri: download_chunk(uri, bounds)).compute()

# `chunks` now contains the downloaded data for each raster
print(chunks)
# print(chunks[0].shape)

In [None]:
### Experimenting with performing some operation on the chunks: getting the min and max of each chunk

%time

import dask
import dask.bag as db
import s3fs
import rasterio
from rasterio.windows import from_bounds
import numpy as np

# Initialize the S3 filesystem
# NOTE(review): `s3` is not referenced by the code below in this cell; rasterio
# opens the s3:// URIs itself, so this object appears to be unused here.
s3 = s3fs.S3FileSystem(anon=False)

# Function to download a chunk from S3 based on bounds
def download_chunk(s3_uri, bounds):
    """Return the pixel values of ``s3_uri`` inside ``bounds``.

    ``bounds`` is (left, bottom, right, top) in the raster's CRS. It is converted
    to a pixel window with ``rasterio.windows.from_bounds``, and every band is
    read for that window.
    """
    with rasterio.Env(), rasterio.open(s3_uri) as src:
        return src.read(window=from_bounds(*bounds, transform=src.transform))

# Function to calculate min and max values of the chunk
def calculate_min_max(data):
    """Return ``(minimum, maximum)`` over every element of ``data``.

    ``data`` may be any array-like that NumPy accepts; the reduction ignores shape.
    """
    lowest = np.min(data)
    highest = np.max(data)
    return lowest, highest

# Function to download and calculate min and max values
def download_and_calculate_min_max(s3_uri, bounds):
    """Read the ``bounds`` window of ``s3_uri`` and report its extremes.

    Returns:
        ``(s3_uri, min, max)``. The URI is included so that results coming back
        from parallel tasks can be matched to their source raster.
    """
    window_values = download_chunk(s3_uri, bounds)
    lowest, highest = calculate_min_max(window_values)
    return s3_uri, lowest, highest

# S3 URIs of the 2000 carbon-density rasters for tile 50N_010E.
# Plain string literals: none of these URIs has a placeholder to format.
s3_uris = [
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/AGC_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__AGC_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/BGC_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__BGC_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/deadwood_C_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__deadwood_C_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/litter_C_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__litter_C_density_MgC_ha_2000.tif",
    # Add more URIs as needed
]

# Bounds of the chunk to download (left, bottom, right, top)
bounds = [10, 48, 12, 50]    # 2x2 deg (inside 50N_010E)

# One Dask bag partition per raster, so each file is handled by its own task
s3_bag = db.from_sequence(s3_uris, npartitions=len(s3_uris))

# Download each window and reduce it to (uri, min, max) in parallel
results = s3_bag.map(lambda uri: download_and_calculate_min_max(uri, bounds)).compute()

# `results` now contains the URI, min, and max values for each raster
for uri, min_val, max_val in results:
    print(f"URI: {uri}, Min: {min_val}, Max: {max_val}")

In [None]:
### Example of performing a simple arithmetic operation on the chunks in a numba jit-decorated function (not pixel by pixel but on an entire array).
### Still uses Dask bags (determined by the number of input files) and a lambda function to iterate on them.

%%time

import dask
import dask.bag as db
import s3fs
import rasterio
from rasterio.windows import from_bounds
from numba import jit
import numpy as np

# Initialize the S3 filesystem
# NOTE(review): `s3` is not referenced by the code below in this cell; rasterio
# opens the s3:// URIs itself, so this object appears to be unused here.
s3 = s3fs.S3FileSystem(anon=False)

# Function to download a chunk from S3 based on bounds
def download_chunk(s3_uri, bounds):
    """Return the pixel values of ``s3_uri`` inside ``bounds``.

    ``bounds`` is (left, bottom, right, top) in the raster's CRS. It is converted
    to a pixel window with ``rasterio.windows.from_bounds``, and every band is
    read for that window.
    """
    with rasterio.Env(), rasterio.open(s3_uri) as src:
        pixel_window = from_bounds(*bounds, transform=src.transform)
        return src.read(window=pixel_window)

# Function to process each chunk with LULUCF_fluxes
@jit(nopython=True)
def LULUCF_fluxes(arr):
    """Placeholder flux step: multiply the whole array ``arr`` by 1.1.

    This is a whole-array operation, not a per-pixel loop. It is compiled by
    numba in nopython mode.
    """
    return arr * 1.1

def download_and_process_chunk(s3_uri, bounds):
    """Read the ``bounds`` window of ``s3_uri`` as float32 and run ``LULUCF_fluxes`` on it.

    Returns:
        ``(s3_uri, processed_array)``.
    """
    # Assuming data is float32
    window_values = download_chunk(s3_uri, bounds).astype(np.float32)
    return s3_uri, LULUCF_fluxes(window_values)

# S3 URIs of the 2000 carbon-density rasters for tile 50N_010E.
# Plain string literals: none of these URIs has a placeholder to format.
s3_uris = [
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/AGC_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__AGC_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/BGC_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__BGC_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/deadwood_C_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__deadwood_C_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/litter_C_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__litter_C_density_MgC_ha_2000.tif",
    # Add more URIs as needed
]

# Bounds of the chunk to download (left, bottom, right, top)
bounds = [10, 49, 11, 50]  # Example bounds

# One Dask bag partition per raster, so each file is handled by its own task
s3_bag = db.from_sequence(s3_uris, npartitions=len(s3_uris))

# Download and process each raster's window in parallel
processed_chunks = s3_bag.map(lambda uri: download_and_process_chunk(uri, bounds)).compute()

# `processed_chunks` holds one (uri, processed array) tuple per raster
print(processed_chunks)

In [None]:
### Uploads the output rasters to s3 after saving them locally first.
### Could not get numpy arrays saved directly to s3-- had to save them locally as rasters first

%%time

import os
import uuid
import dask
import dask.bag as db
import s3fs
import rasterio
from rasterio.windows import from_bounds
from numba import jit
import numpy as np
import boto3

# Initialize the S3 filesystem with appropriate credentials
# NOTE(review): neither `s3` nor this module-level `s3_client` is referenced by
# the code below in this cell; upload_to_s3 creates its own boto3 client.
s3 = s3fs.S3FileSystem(anon=False)  # Set anon=False to use AWS credentials
s3_client = boto3.client('s3')

# Function to download a chunk from S3 based on bounds
def download_chunk(s3_uri, bounds):
    """Read the window of ``s3_uri`` covering ``bounds`` plus matching write metadata.

    Args:
        s3_uri: Raster location understood by ``rasterio.open`` (s3:// URIs here).
        bounds: (left, bottom, right, top) in the raster's CRS, as taken by
            ``rasterio.windows.from_bounds``.

    Returns:
        ``(data, metadata)``. ``data`` is the array returned by ``src.read``
        (bands, rows, cols). ``metadata`` is a copy of the source profile,
        updated with the chunk's size, transform and LZW compression, and
        suitable for ``rasterio.open(path, "w", **metadata)``.
    """
    with rasterio.Env():
        with rasterio.open(s3_uri) as src:
            # Calculate window from bounds
            window = from_bounds(*bounds, transform=src.transform)
            # Read the windowed data
            data = src.read(window=window)
            metadata = src.meta.copy()
            metadata.update({
                # `from_bounds` can return fractional window lengths. Take the size
                # from the array actually read, so the output raster declared from
                # this metadata always matches the data written into it.
                "height": data.shape[-2],
                "width": data.shape[-1],
                "transform": src.window_transform(window),
                "compress": "lzw",  # Add compression to reduce file size
            })
    return data, metadata

# Function to process each chunk with LULUCF_fluxes
@jit(nopython=True)
def LULUCF_fluxes(arr):
    """Placeholder flux step: multiply the whole array ``arr`` by 1.1.

    This is a whole-array operation compiled by numba in nopython mode.
    """
    scaled = arr * 1.1
    return scaled

# Function to upload processed data to S3
def upload_to_s3(local_path, s3_uri):
    """Upload ``local_path`` to a ``test_out`` copy of ``s3_uri``, then delete the local file.

    The destination is ``s3_uri`` with "outputs" replaced by "test_out". The
    bucket and key are printed before uploading. The local file is removed even
    if creating the client or uploading raises, so failed tasks do not leave
    temporary rasters behind; the exception still propagates.
    """
    try:
        s3_client = boto3.client('s3')

        output_uri = s3_uri.replace("outputs", "test_out")
        bucket, key = output_uri.replace("s3://", "").split("/", 1)
        print(local_path, bucket, key)
        s3_client.upload_file(local_path, bucket, key)
    finally:
        os.remove(local_path)  # Clean up the local file

# Wrapper function to download, process, and upload a chunk
def download_process_upload_chunk(s3_uri, bounds):
    """Download a window of ``s3_uri``, scale it with ``LULUCF_fluxes`` and upload it.

    The result is written to a randomly named GeoTIFF in /tmp using the
    metadata from ``download_chunk``. That file is then passed to
    ``upload_to_s3``, which deletes it after uploading.

    Returns:
        ``(s3_uri, shape)`` of the processed array.
    """
    raw, profile = download_chunk(s3_uri, bounds)

    # Assuming data is float32
    result = LULUCF_fluxes(raw.astype(np.float32))

    # A random file name keeps concurrently running tasks from clobbering each other.
    tmp_tif = f"/tmp/{uuid.uuid4()}.tif"
    with rasterio.open(tmp_tif, "w", **profile) as dst:
        dst.write(result)

    upload_to_s3(tmp_tif, s3_uri)
    return s3_uri, result.shape

# S3 URIs of the 2000 carbon-density rasters for tile 50N_010E.
# Plain string literals: none of these URIs has a placeholder to format.
s3_uris = [
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/AGC_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__AGC_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/BGC_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__BGC_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/deadwood_C_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__deadwood_C_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/litter_C_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__litter_C_density_MgC_ha_2000.tif",
    # Add more URIs as needed
]

# Bounds of the chunk to download (left, bottom, right, top)
bounds = [10, 49, 11, 50]  # Example bounds

# One Dask bag partition per raster, so each file is handled by its own task
s3_bag = db.from_sequence(s3_uris, npartitions=len(s3_uris))

# Download, process and upload each raster's window in parallel
processed_chunks = s3_bag.map(lambda uri: download_process_upload_chunk(uri, bounds)).compute()

# `processed_chunks` now contains the URI and the shape of the processed data for each raster
for uri, shape in processed_chunks:
    print(f"URI: {uri}, Processed data shape: {shape}")


In [None]:
### Changed the tasks in the Dask bag to use chunks within a bounding box rather than the number of input files.
### So now you supply a bounding box and chunk size within and that creates the items in the Dask bag. 
### Also, no longer uses a lambda function (though I don't know why ChatGPT stopped doing that).

%%time

import os
import uuid
import dask
import dask.bag as db
import s3fs
import rasterio
from rasterio.windows import from_bounds
from numba import jit
import numpy as np
import boto3
from datetime import datetime

# Initialize the S3 filesystem with appropriate credentials
# NOTE(review): `s3` is not referenced by the code below in this cell; reads go
# through rasterio and uploads through boto3, so this appears to be unused here.
s3 = s3fs.S3FileSystem(anon=False)  # Set anon=False to use AWS credentials

# Function to download a chunk from S3 based on bounds
def download_chunk(s3_uri, bounds):
    """Read the window of ``s3_uri`` covering ``bounds`` plus matching write metadata.

    Args:
        s3_uri: Raster location understood by ``rasterio.open`` (s3:// URIs here).
        bounds: (left, bottom, right, top) in the raster's CRS, as taken by
            ``rasterio.windows.from_bounds``.

    Returns:
        ``(data, metadata)``. ``data`` is the array returned by ``src.read``
        (bands, rows, cols). ``metadata`` is a copy of the source profile,
        updated with the chunk's size, transform and LZW compression, and
        suitable for ``rasterio.open(path, "w", **metadata)``.
    """
    with rasterio.Env():
        with rasterio.open(s3_uri) as src:
            # Calculate window from bounds
            window = from_bounds(*bounds, transform=src.transform)
            # Read the windowed data
            data = src.read(window=window)
            metadata = src.meta.copy()
            metadata.update({
                # `from_bounds` can return fractional window lengths. Take the size
                # from the array actually read, so the output raster declared from
                # this metadata always matches the data written into it.
                "height": data.shape[-2],
                "width": data.shape[-1],
                "transform": src.window_transform(window),
                "compress": "lzw",  # Add compression to reduce file size
            })
    return data, metadata

# Function to process each chunk with LULUCF_fluxes
@jit(nopython=True)
def LULUCF_fluxes(arr):
    """Placeholder flux step: double every element of ``arr`` (whole-array, numba nopython)."""
    doubled = arr * 2
    return doubled

# Function to upload processed data to S3
def upload_to_s3(local_path, s3_uri, bounds):
    """Upload one processed chunk under ``test_outputs`` and delete the local copy.

    The key mirrors ``s3_uri`` with "outputs" replaced by "test_outputs". The
    chunk bounds are inserted into the file name as
    ``<tile>_<west>_<south>_<east>_<north>__<rest of original name>``.
    Previously ``bounds`` was ignored, so every chunk cut from the same input
    raster was uploaded to the same key and overwrote the previous chunk.

    Args:
        local_path: GeoTIFF to upload. It is deleted afterwards, even if the
            upload raises.
        s3_uri: Source raster URI whose path the output key is derived from.
        bounds: (west, south, east, north) of the chunk.
    """
    try:
        s3_client = boto3.client('s3')  # Ensure client is created in the worker process

        west, south, east, north = bounds
        output_uri = s3_uri.replace("outputs", "test_outputs")
        directory, file_name = output_uri.rsplit("/", 1)
        # Input file names look like "<tile_id>__<carbon pool>_2000.tif".
        tile_part, separator, pool_part = file_name.partition("__")
        file_name = f"{tile_part}_{west}_{south}_{east}_{north}{separator}{pool_part}"
        bucket, key = f"{directory}/{file_name}".replace("s3://", "").split("/", 1)

        # Upload the file
        s3_client.upload_file(local_path, bucket, key)
    finally:
        os.remove(local_path)  # Clean up the local file

# Wrapper function to download, process, and upload a chunk
def download_process_upload_chunk(s3_uri, bounds):
    """Download one raster window, apply ``LULUCF_fluxes`` and upload the result.

    Args:
        s3_uri: Source raster. It is also passed to ``upload_to_s3`` to derive
            the output location.
        bounds: (west, south, east, north) of the chunk.

    Returns:
        ``(s3_uri, shape)`` of the processed array.
    """
    # Download the chunk
    data, metadata = download_chunk(s3_uri, bounds)

    # Assuming data is float32
    processed_data = LULUCF_fluxes(data.astype(np.float32))

    # Every input URI is processed for the same bounds, and those tasks may run
    # concurrently on one machine. A path built only from the bounds would then
    # be shared: one task could overwrite the file, or delete it via
    # upload_to_s3, before another uploads it. The UUID makes each path unique;
    # the bounds stay in the name for readability.
    west, south, east, north = bounds
    local_output_path = f"/tmp/{west}_{south}_{east}_{north}_{uuid.uuid4()}.tif"

    # Save to a local file
    with rasterio.open(local_output_path, "w", **metadata) as dst:
        dst.write(processed_data)

    # Upload the processed data
    upload_to_s3(local_output_path, s3_uri, bounds)

    return s3_uri, processed_data.shape

# Function to generate bounding boxes within a specified bounding box with a specified chunk size
def generate_chunks_within_bounds(west, south, east, north, chunk_size):
    """Split a bounding box into a list of chunk bounds.

    Args:
        west, south, east, north: Edges of the area to cover.
        chunk_size: Width and height of each chunk, in the same units as the bounds.

    Returns:
        List of ``(west, south, east, north)`` tuples. They are ordered
        south-to-north, and west-to-east within each row. Chunks on the east
        and north edges are clipped to the box, so they can be smaller than
        ``chunk_size``. The list is empty when the box has no area.

    Raises:
        ValueError: If ``chunk_size`` is not positive; stepping would never terminate.
    """
    import math

    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    def _n_steps(start, stop):
        # Count chunks from the span instead of accumulating `+= chunk_size`.
        # Repeated float addition drifts (ten steps of 0.1 sum to
        # 0.9999999999999999), which used to add an extra sliver chunk.
        # The tiny tolerance absorbs rounding in the division itself.
        return max(0, math.ceil((stop - start) / chunk_size - 1e-9))

    n_rows = _n_steps(south, north)
    n_cols = _n_steps(west, east)

    chunks = []
    for row in range(n_rows):
        chunk_south = south + row * chunk_size
        chunk_north = min(south + (row + 1) * chunk_size, north)
        for col in range(n_cols):
            chunk_west = west + col * chunk_size
            chunk_east = min(west + (col + 1) * chunk_size, east)
            chunks.append((chunk_west, chunk_south, chunk_east, chunk_north))
    return chunks

# Specify the bounding box (west, south, east, north) and chunk size
# bounding_box = [10, 45, 15, 50]  # Updated bounding box
bounding_box = [10, 49, 11, 50]
chunk_size = 1  # 1x1 degree chunks

# Generate chunks within the specified bounding box
chunks = generate_chunks_within_bounds(*bounding_box, chunk_size)

# Input rasters for tile 50N_010E; every chunk is read from each of them.
# NOTE(review): the URIs are fixed to 50N_010E, so the bounding box is assumed
# to lie inside that tile.
s3_uris = [
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/AGC_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__AGC_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/BGC_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__BGC_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/deadwood_C_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__deadwood_C_density_MgC_ha_2000.tif",
    "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/litter_C_density_MgC_ha/2000/40000_pixels/20240729/50N_010E__litter_C_density_MgC_ha_2000.tif",
]

# Define the function to be used with Dask for each chunk and URI
def process_uri_chunk(args):
    """Dask adapter: unpack a ``(uri, bounds)`` task and run ``download_process_upload_chunk`` on it."""
    s3_uri, chunk_bounds = args
    return download_process_upload_chunk(s3_uri, chunk_bounds)

# Pair every raster URI with every chunk: one task per (uri, bounds) combination.
tasks = [(uri, chunk_bounds) for uri in s3_uris for chunk_bounds in chunks]

# One partition per task, so each pair runs as its own Dask task.
s3_bag = db.from_sequence(tasks, npartitions=len(tasks))

# Process all tasks in parallel and collect their (uri, shape) results.
lazy_results = s3_bag.map(process_uri_chunk)
processed_chunks = lazy_results.compute()

# `processed_chunks` now contains the URI and the shape of the processed data for each raster
for uri, shape in processed_chunks:
    print(f"URI: {uri}, Processed data shape: {shape}")


In [None]:
### The list of uris is now specified inside download_process_upload_chunk, so all the tiles for each input dataset can be accessed.
### This means that each chunk has to identify and access the right tile_id.
### Bags are now entirely dependent on the chunks.
### The LULUCF_fluxes function operates pixel by pixel using nested for loops, not on input arrays. 
### Inputs to LULUCF_fluxes are passed as *args (a variable number of arrays) instead of a fixed list of named parameters. 

%%time

import os
import math
import dask
import dask.bag as db
import s3fs
import rasterio
from rasterio.windows import from_bounds
from numba import jit
import numpy as np
import boto3
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

# Initialize the S3 filesystem with appropriate credentials
# NOTE(review): `s3` is not referenced by the code below in this cell; reads go
# through rasterio and uploads through boto3, so this appears to be unused here.
s3 = s3fs.S3FileSystem(anon=False)  # Set anon=False to use AWS credentials

# Function to download a chunk from S3 based on bounds
def download_chunk(s3_uri, bounds):
    """Read the window of ``s3_uri`` covering ``bounds`` plus matching write metadata.

    Args:
        s3_uri: Raster location understood by ``rasterio.open`` (s3:// URIs here).
        bounds: (left, bottom, right, top) in the raster's CRS, as taken by
            ``rasterio.windows.from_bounds``.

    Returns:
        ``(data, metadata)``. ``data`` is the array returned by ``src.read``
        (bands, rows, cols). ``metadata`` is a copy of the source profile,
        updated with the chunk's size, transform and LZW compression, and
        suitable for ``rasterio.open(path, "w", **metadata)``.
    """
    with rasterio.Env():
        with rasterio.open(s3_uri) as src:
            # Calculate window from bounds
            window = from_bounds(*bounds, transform=src.transform)
            # Read the windowed data
            data = src.read(window=window)
            metadata = src.meta.copy()
            metadata.update({
                # `from_bounds` can return fractional window lengths. Take the size
                # from the array actually read, so the output raster declared from
                # this metadata always matches the data written into it.
                "height": data.shape[-2],
                "width": data.shape[-1],
                "transform": src.window_transform(window),
                "compress": "lzw",  # Add compression to reduce file size
            })
    return data, metadata

# Function to process each chunk with LULUCF_fluxes
@jit(nopython=True)
def LULUCF_fluxes(*arrays):
    """Return the element-wise sum of four equally shaped carbon-pool arrays.

    ``arrays`` is read positionally as AGC, BGC, deadwood and litter (indices
    0-3). This matches the order in which download_process_upload_chunk in this
    cell builds its list. The output has the shape and dtype of ``arrays[0]``.

    NOTE(review): the arrays passed in by this cell come straight from
    ``src.read`` without squeezing, so they are 3-D (bands, rows, cols). The
    loops cover only the first two axes, so ``arrays[i][row, col]`` is a whole
    row of pixels and each addition is a vector operation, not a per-pixel one.
    The result should still be the element-wise sum; confirm this layout is
    intended.
    """

    # Output buffer with the same shape and dtype as the first (AGC) input.
    processed_arr = np.zeros_like(arrays[0])

    # Walk the first two axes of the inputs.
    for row in range(arrays[0].shape[0]):
        for col in range(arrays[0].shape[1]):

            agc_cell = arrays[0][row, col]
            bgc_cell = arrays[1][row, col]
            deadwood_cell = arrays[2][row, col]
            litter_cell = arrays[3][row, col]

            # Total carbon = sum of the four pools.
            total_c = agc_cell + bgc_cell + deadwood_cell + litter_cell
            processed_arr[row, col] = total_c

    return processed_arr

# Function to upload processed data to S3
def upload_to_s3(local_path, s3_uri, bounds, tile_id):
    """Upload a processed chunk to the dated total_C test-output prefix, then delete it locally.

    The key is
    ``.../test_outputs/total_C/2000/40000_pixels/<YYYYMMDD>/<tile_id>_<west>_<south>_<east>_<north>__<name>``,
    where ``<name>`` is the part of ``s3_uri``'s file name after the first "__"
    (it keeps the source's extension). The local file is removed even if the
    upload raises, so failed tasks do not leave temporary rasters behind.

    NOTE(review): this cell passes the AGC input URI, so the uploaded total-C
    raster is named after the AGC file. Confirm that is intended.
    """
    try:
        s3_client = boto3.client('s3')  # Ensure client is created in the worker process

        # Extract carbon pool from the URI
        carbon_pool = s3_uri.split('/')[-1].split('__')[1]

        # Get today's date
        today = datetime.today().strftime('%Y%m%d')

        # Construct the output URI
        west, south, east, north = bounds
        output_uri = f"s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/test_outputs/total_C/2000/40000_pixels/{today}/{tile_id}_{west}_{south}_{east}_{north}__{carbon_pool}"
        bucket, key = output_uri.replace("s3://", "").split("/", 1)

        # Upload the file
        s3_client.upload_file(local_path, bucket, key)
    finally:
        os.remove(local_path)  # Clean up the local file

# Returns the encompassing tile_id string in the form YYN/S_XXXE/W based on a coordinate
def xy_to_tile_id(top_left_x, top_left_y):
    """Return the 10x10-degree tile id ("YYN/S_XXXE/W") containing a point.

    Latitude is rounded up and longitude rounded down to the nearest multiple
    of 10, giving the tile's top-left corner.

    >>> xy_to_tile_id(10, 50)
    '50N_010E'
    """
    tile_top = 10 * math.ceil(top_left_y / 10.0)
    tile_left = 10 * math.floor(top_left_x / 10.0)

    lat_hemisphere = "N" if tile_top >= 0 else "S"
    lng_hemisphere = "E" if tile_left >= 0 else "W"

    lat = str(abs(tile_top)).zfill(2) + lat_hemisphere
    lng = str(abs(tile_left)).zfill(3) + lng_hemisphere
    return f"{lat}_{lng}"

# Wrapper function to download, process, and upload a chunk
def download_process_upload_chunk(bounds):
    """Sum the four carbon pools for one chunk and upload the result.

    The 10x10 tile containing the chunk is found from its top-left corner. Each
    input raster's window is read and cast to float32, and the arrays are summed
    with ``LULUCF_fluxes``. The result is written to /tmp using the metadata of
    the last input read, then uploaded with ``upload_to_s3``.

    Args:
        bounds: (west, south, east, north) of the chunk.

    Returns:
        ``(bounds, shape)`` of the processed array.
    """
    west, south, east, north = bounds
    tile_id = xy_to_tile_id(west, north)

    # Input rasters for the tile, in the order LULUCF_fluxes expects: AGC, BGC, deadwood, litter.
    base = "s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs"
    pools = (
        "AGC_density_MgC_ha",
        "BGC_density_MgC_ha",
        "deadwood_C_density_MgC_ha",
        "litter_C_density_MgC_ha",
    )
    s3_uris = [
        f"{base}/{pool}/2000/40000_pixels/20240729/{tile_id}__{pool}_2000.tif"
        for pool in pools
    ]

    # Read inputs one after another; `metadata` ends up holding the last one's profile.
    arrays = []
    for uri in s3_uris:
        window_values, metadata = download_chunk(uri, bounds)
        arrays.append(window_values.astype(np.float32))

    total_c = LULUCF_fluxes(*arrays)

    local_output_path = f"/tmp/{tile_id}_{west}_{south}_{east}_{north}.tif"
    with rasterio.open(local_output_path, "w", **metadata) as dst:
        dst.write(total_c)

    # The first (AGC) URI supplies the name used for the output key.
    upload_to_s3(local_output_path, s3_uris[0], bounds, tile_id)

    return bounds, total_c.shape

# Function to generate bounding boxes within a specified bounding box with a specified chunk size
def generate_chunks_within_bounds(west, south, east, north, chunk_size):
    """Split a bounding box into a list of chunk bounds.

    Args:
        west, south, east, north: Edges of the area to cover.
        chunk_size: Width and height of each chunk, in the same units as the bounds.

    Returns:
        List of ``(west, south, east, north)`` tuples. They are ordered
        south-to-north, and west-to-east within each row. Chunks on the east
        and north edges are clipped to the box, so they can be smaller than
        ``chunk_size``. The list is empty when the box has no area.

    Raises:
        ValueError: If ``chunk_size`` is not positive; stepping would never terminate.
    """
    import math

    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    def _n_steps(start, stop):
        # Count chunks from the span instead of accumulating `+= chunk_size`.
        # Repeated float addition drifts (ten steps of 0.1 sum to
        # 0.9999999999999999), which used to add an extra sliver chunk.
        # The tiny tolerance absorbs rounding in the division itself.
        return max(0, math.ceil((stop - start) / chunk_size - 1e-9))

    n_rows = _n_steps(south, north)
    n_cols = _n_steps(west, east)

    chunks = []
    for row in range(n_rows):
        chunk_south = south + row * chunk_size
        chunk_north = min(south + (row + 1) * chunk_size, north)
        for col in range(n_cols):
            chunk_west = west + col * chunk_size
            chunk_east = min(west + (col + 1) * chunk_size, east)
            chunks.append((chunk_west, chunk_south, chunk_east, chunk_north))
    return chunks

# Area to process as (west, south, east, north), in degrees.
# The 50N_010E tile is kept as a commented-out alternative; only the last
# assignment was ever in effect.
# bounding_box = [10, 40, 20, 50]  # 50N_010E
bounding_box = [110, -10, 120, 0]  # 00N_110E
chunk_size = 1  # 1x1 degree chunks

# Split the area into chunk_size x chunk_size squares (edges clipped to the box).
chunks = generate_chunks_within_bounds(*bounding_box, chunk_size)

# Define the function to be used with Dask for each chunk
def process_chunk(bounds):
    """Per-item function mapped over the Dask bag: handle one chunk's bounds."""
    outcome = download_process_upload_chunk(bounds)
    return outcome

# One partition per chunk, so each chunk becomes its own Dask task.
s3_bag = db.from_sequence(chunks, npartitions=len(chunks))

# Run every chunk through process_chunk in parallel and gather the results.
lazy_results = s3_bag.map(process_chunk)
processed_chunks = lazy_results.compute()

# Each result is (bounds, shape of the processed array).
for bounds, shape in processed_chunks:
    print(f"Bounds: {bounds}, Processed data shape: {shape}")


In [None]:
%%time

### Turns downloads into a dictionary where the keys are the inputs' names. 
### Converts the download dictionary to typed dictionaries that Numba will understand and uses them in the Numba function.
### Returns typed dictionaries, which are then saved to raster locally and uploaded to s3

import os
import math
import uuid
import dask
import dask.bag as db
import s3fs
import rasterio
from rasterio.windows import from_bounds
from numba import jit
import numpy as np
import boto3
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

# Initialize the S3 filesystem with appropriate credentials
# NOTE(review): `s3` is not referenced by the code below in this cell; reads go
# through rasterio and uploads through boto3, so this appears to be unused here.
s3 = s3fs.S3FileSystem(anon=False)  # Set anon=False to use AWS credentials

# Function to download a chunk from S3 based on bounds
def download_chunk(s3_uri, bounds):
    """Read the window of ``s3_uri`` covering ``bounds`` plus matching write metadata.

    Args:
        s3_uri: Raster location understood by ``rasterio.open`` (s3:// URIs here).
        bounds: (left, bottom, right, top) in the raster's CRS, as taken by
            ``rasterio.windows.from_bounds``.

    Returns:
        ``(data, metadata)``. ``data`` is the array returned by ``src.read``
        (bands, rows, cols). ``metadata`` is a copy of the source profile,
        updated with the chunk's size, transform and LZW compression, and
        suitable for ``rasterio.open(path, "w", **metadata)``.
    """
    with rasterio.Env():
        with rasterio.open(s3_uri) as src:
            # Calculate window from bounds
            window = from_bounds(*bounds, transform=src.transform)
            # Read the windowed data
            data = src.read(window=window)
            metadata = src.meta.copy()
            metadata.update({
                # `from_bounds` can return fractional window lengths. Take the size
                # from the array actually read, so the output raster declared from
                # this metadata always matches the data written into it.
                "height": data.shape[-2],
                "width": data.shape[-1],
                "transform": src.window_transform(window),
                "compress": "lzw",  # Add compression to reduce file size
            })
    return data, metadata

# Function to process each chunk with LULUCF_fluxes
@jit(nopython=True)
def LULUCF_fluxes_bag(in_dict_uint8, in_dict_int16, in_dict_float32):
    """Sum the four 2000 carbon-density pools into a total-carbon array.

    Args:
        in_dict_uint8: Dictionary of uint8 input arrays (not read in this version).
        in_dict_int16: Dictionary of int16 input arrays (not read in this version).
        in_dict_float32: Dictionary of float32 input arrays. It must contain the
            AGC, BGC, deadwood and litter density arrays, which are indexed as
            [row, col] and expected to share one 2-D shape.

    Returns:
        A dictionary with a single entry, "total_C": the per-pixel sum of the four
        pools as a float32 array.

    NOTE(review): the lookup keys `agc_2000`, `bgc_2000`, `deadwood_c_2000` and
    `litter_c_2000` are bare global names, not string literals, and they are not
    defined in this cell. Numba treats globals as compile-time constants, so
    they must already exist before the first call and hold the dictionary's key
    strings (presumably 'agc_2000', etc.). Confirm.
    """

    # Separate dictionaries for output numpy arrays of each datatype, named by output data type.
    # This is because a dictionary in a Numba function cannot have arrays with multiple data types, so each dictionary has to store only one data type,
    # just like inputs to the function.
    out_dict_float32 = {}

    agc_dens_curr_block = in_dict_float32[agc_2000].astype('float32')
    bgc_dens_curr_block = in_dict_float32[bgc_2000].astype('float32')
    deadwood_c_dens_curr_block = in_dict_float32[deadwood_c_2000].astype('float32')
    litter_c_dens_curr_block = in_dict_float32[litter_c_2000].astype('float32')
    
    processed_arr = np.zeros_like(agc_dens_curr_block).astype('float32')

    # Pixel-by-pixel total of the four carbon pools.
    for row in range(agc_dens_curr_block.shape[0]):
        for col in range(agc_dens_curr_block.shape[1]):

            agc_cell = agc_dens_curr_block[row, col]
            bgc_cell = bgc_dens_curr_block[row, col]
            deadwood_cell = deadwood_c_dens_curr_block[row, col]
            litter_cell = litter_c_dens_curr_block[row, col]

            total_c = agc_cell + bgc_cell + deadwood_cell + litter_cell
            processed_arr[row, col] = total_c

    out_dict_float32["total_C"] = processed_arr.copy()
    
    return out_dict_float32

# Function to upload processed data to S3
def upload_to_s3(local_path, s3_uri, bounds, tile_id):
    """Upload a processed chunk to the dated total_C test-output prefix, then delete it locally.

    The key is
    ``.../test_outputs/total_C/2000/40000_pixels/<YYYYMMDD>/<tile_id>_<west>_<south>_<east>_<north>__<name>``,
    where ``<name>`` is the part of ``s3_uri``'s file name after the first "__"
    (it keeps the source's extension). The local file is removed even if the
    upload raises, so failed tasks do not leave temporary rasters behind.

    NOTE(review): this cell passes the AGC input URI, so the uploaded total-C
    raster is named after the AGC file. Confirm that is intended.
    """
    try:
        s3_client = boto3.client('s3')  # Ensure client is created in the worker process

        # Extract carbon pool from the URI
        carbon_pool = s3_uri.split('/')[-1].split('__')[1]

        # Get today's date
        today = datetime.today().strftime('%Y%m%d')

        # Construct the output URI
        west, south, east, north = bounds
        output_uri = f"s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/test_outputs/total_C/2000/40000_pixels/{today}/{tile_id}_{west}_{south}_{east}_{north}__{carbon_pool}"
        bucket, key = output_uri.replace("s3://", "").split("/", 1)

        # Upload the file
        s3_client.upload_file(local_path, bucket, key)
    finally:
        os.remove(local_path)  # Clean up the local file

# Returns the encompassing tile_id string in the form YYN/S_XXXE/W based on a coordinate
def xy_to_tile_id(top_left_x, top_left_y):
    """Return the 10x10-degree tile id ("YYN/S_XXXE/W") containing a point.

    Latitude is rounded up and longitude rounded down to the nearest multiple
    of 10, giving the tile's top-left corner.

    >>> xy_to_tile_id(110, 0)
    '00N_110E'
    """
    north_edge = math.ceil(top_left_y / 10.0) * 10
    west_edge = math.floor(top_left_x / 10.0) * 10

    if north_edge >= 0:
        lat = str(abs(north_edge)).zfill(2) + "N"
    else:
        lat = str(abs(north_edge)).zfill(2) + "S"

    if west_edge >= 0:
        lng = str(abs(west_edge)).zfill(3) + "E"
    else:
        lng = str(abs(west_edge)).zfill(3) + "W"

    return f"{lat}_{lng}"

# Wrapper function to download, process, and upload a chunk
def download_process_upload_chunk(bounds):
    """Compute total carbon for one chunk and upload it as a GeoTIFF.

    Steps:
    1. Find the 10x10 tile containing the chunk.
    2. Read the four 2000 carbon-density inputs for ``bounds`` in parallel threads.
    3. Pack the inputs into per-dtype dictionaries with ``create_typed_dicts``
       (not defined in this cell).
    4. Run ``LULUCF_fluxes_bag``.
    5. Write its ``"total_C"`` output to /tmp and upload it with ``upload_to_s3``.

    Args:
        bounds: ``(west, south, east, north)`` of the chunk in degrees (written as EPSG:4326).

    Returns:
        ``(bounds, shape)``, where ``shape`` is the shape of the total-C array.
    """
    west, south, east, north = bounds

    tile_id = xy_to_tile_id(west, north)    # tile_id in YYN/S_XXXE/W

    s3_uris = {
        'agc_2000': f"s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/AGC_density_MgC_ha/2000/40000_pixels/20240729/{tile_id}__AGC_density_MgC_ha_2000.tif",
        'bgc_2000': f"s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/BGC_density_MgC_ha/2000/40000_pixels/20240729/{tile_id}__BGC_density_MgC_ha_2000.tif",
        'deadwood_c_2000': f"s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/deadwood_C_density_MgC_ha/2000/40000_pixels/20240729/{tile_id}__deadwood_C_density_MgC_ha_2000.tif",
        'litter_c_2000': f"s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/litter_C_density_MgC_ha/2000/40000_pixels/20240729/{tile_id}__litter_C_density_MgC_ha_2000.tif"
    }

    def download_and_process(name, uri):
        data, metadata = download_chunk(uri, bounds)
        # Drop the single band axis so inputs are 2-D (rows, cols).
        return name, np.squeeze(data.astype(np.float32)), metadata

    # Download each chunk in parallel and store them
    with ThreadPoolExecutor(max_workers=4) as executor:
        results = list(executor.map(lambda item: download_and_process(*item), s3_uris.items()))

    data_dict = {name: data for name, data, _ in results}

    # Only the uint8, int16 and float32 dictionaries are consumed by LULUCF_fluxes_bag.
    typed_dict_uint8, typed_dict_int16, _typed_dict_int32, typed_dict_float32 = create_typed_dicts(data_dict)

    # Sum the arrays using LULUCF_fluxes
    out_dict_float32 = LULUCF_fluxes_bag(
        typed_dict_uint8, typed_dict_int16, typed_dict_float32
    )

    # Copy the numba output dictionary into a plain Python dict
    out_dict_all_dtypes = {key: value for key, value in out_dict_float32.items()}

    # Clear memory of unneeded arrays
    del out_dict_float32

    total_c = out_dict_all_dtypes["total_C"]

    # Construct the local output file path with bounding box in the name
    local_output_path = f"/tmp/{tile_id}_{west}_{south}_{east}_{north}.tif"

    # Size the raster from the array actually produced, not from a single
    # square pixel count derived from the bounds. Chunks clipped at the edge of
    # the bounding box can be non-square, and the written array must match the
    # declared width and height.
    height, width = total_c.shape
    transform = rasterio.transform.from_bounds(*bounds, width=width, height=height)
    with rasterio.open(local_output_path, "w", driver='GTiff', width=width, height=height, count=1,
                       dtype='float32', crs='EPSG:4326', transform=transform, compress='lzw', blockxsize=400, blockysize=400) as dst:
        dst.write(total_c[np.newaxis, :, :])

    # Upload the processed data
    upload_to_s3(local_output_path, s3_uris['agc_2000'], bounds, tile_id)

    return bounds, total_c.shape

# Function to generate bounding boxes within a specified bounding box with a specified chunk size
def generate_chunks_within_bounds(west, south, east, north, chunk_size):
    """Split a bounding box into a list of chunk bounds.

    Args:
        west, south, east, north: Edges of the area to cover.
        chunk_size: Width and height of each chunk, in the same units as the bounds.

    Returns:
        List of ``(west, south, east, north)`` tuples. They are ordered
        south-to-north, and west-to-east within each row. Chunks on the east
        and north edges are clipped to the box, so they can be smaller than
        ``chunk_size``. The list is empty when the box has no area.

    Raises:
        ValueError: If ``chunk_size`` is not positive; stepping would never terminate.
    """
    import math

    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    def _n_steps(start, stop):
        # Count chunks from the span instead of accumulating `+= chunk_size`.
        # Repeated float addition drifts (ten steps of 0.1 sum to
        # 0.9999999999999999), which used to add an extra sliver chunk.
        # The tiny tolerance absorbs rounding in the division itself.
        return max(0, math.ceil((stop - start) / chunk_size - 1e-9))

    n_rows = _n_steps(south, north)
    n_cols = _n_steps(west, east)

    chunks = []
    for row in range(n_rows):
        chunk_south = south + row * chunk_size
        chunk_north = min(south + (row + 1) * chunk_size, north)
        for col in range(n_cols):
            chunk_west = west + col * chunk_size
            chunk_east = min(west + (col + 1) * chunk_size, east)
            chunks.append((chunk_west, chunk_south, chunk_east, chunk_north))
    return chunks

# Area to process as (west, south, east, north), in degrees.
# Alternative extents are kept commented out for quick switching.
# bounding_box = [10, 40, 20, 50]  # 50N_010E
bounding_box = [10, 49, 11, 50]
# bounding_box = [10, 49.75, 10.25, 50]
chunk_size = 1  # 1x1 degree chunks

# Split the area into chunk_size x chunk_size squares (edges clipped to the box).
chunks = generate_chunks_within_bounds(*bounding_box, chunk_size)

# Callback mapped over the Dask bag: one call per chunk
def process_chunk(bounds):
    """Run the download -> process -> upload pipeline for a single chunk.

    `bounds` is one (west, south, east, north) tuple from the bag of chunks; whatever
    download_process_upload_chunk returns is handed back to Dask unchanged.
    """
    chunk_result = download_process_upload_chunk(bounds)
    return chunk_result

# Create a Dask bag from the list of chunks.
# npartitions=len(chunks) asks for one partition (and so one task) per chunk.
# NOTE(review): if `chunks` is ever empty this passes npartitions=0 — confirm how dask handles that.
s3_bag = db.from_sequence(chunks, npartitions=len(chunks))

# Apply the process_chunk function to each chunk in parallel and gather the results locally
processed_chunks = s3_bag.map(process_chunk).compute()

# `processed_chunks` now contains whatever process_chunk returned for each chunk,
# unpacked here as (bounds, shape of the processed output array)
for bounds, shape in processed_chunks:
    print(f"Bounds: {bounds}, Processed data shape: {shape}")


In [None]:
# Function to calculate LULUCF fluxes and carbon densities.
# Operates pixel by pixel, so it is compiled with numba (Python JIT-compiled to machine code via LLVM).
@jit(nopython=True)
def LULUCF_fluxes_bag(in_dict_uint8, in_dict_int16, in_dict_float32):
    """Run the pixel-level land-state decision tree for every interval and return the outputs.

    Inputs are three dictionaries of 2D arrays grouped by data type (a numba dictionary can only
    hold one value type). Keys are layer names built from module-level globals, e.g. agc_2000,
    r_s_ratio, planted_forest_type_layer, f"{land_cover}_{year}", f"{vegetation_height}_{year}",
    and f"{burned_area}_{y}" / f"{forest_disturbance}_{y}" for y in year-4 .. year.

    For each end year in range(first_year, last_year + 1, interval_years) after the first, every
    pixel is assigned a decision-tree node (land state) plus AGC/BGC fluxes. AGC and BGC densities
    are updated in place so each interval starts from the previous interval's densities.

    Returns:
        (out_dict_uint32, out_dict_float32): land-state arrays keyed
        f"{land_state_pattern}_{start}_{end}", and density/flux arrays keyed
        f"{<pattern>}_{start}_{end}". Deadwood and litter fluxes are currently all zeros and their
        densities are passed through unchanged; soil C density is read but not yet used.
    """

    # Separate dictionaries for output numpy arrays of each datatype, named by output data type.
    # This is because a dictionary in a Numba function cannot have arrays with multiple data types, so each dictionary has to store only one data type,
    # just like inputs to the function.
    out_dict_uint32 = {}
    out_dict_float32 = {}

    end_years = list(range(first_year, last_year+1, interval_years))[1:]
    # end_years = [2005, 2010]

    # .astype() copies, so the in-place density updates below never modify the input arrays
    agc_dens_curr_block = in_dict_float32[agc_2000].astype('float32')
    bgc_dens_curr_block = in_dict_float32[bgc_2000].astype('float32')
    deadwood_c_dens_curr_block = in_dict_float32[deadwood_c_2000].astype('float32')
    litter_c_dens_curr_block = in_dict_float32[litter_c_2000].astype('float32')
    soil_c_dens_curr_block = in_dict_int16[soil_c_2000].astype('float32')

    r_s_ratio_block = in_dict_float32[r_s_ratio]

    # Planted forest layers do not depend on the year, so they are read once
    planted_forest_type_block = in_dict_uint8[planted_forest_type_layer]
    planted_forest_tree_crop_block = in_dict_uint8[planted_forest_tree_crop_layer]

    block_shape = in_dict_float32[agc_2000].shape

    for year in end_years:

        # print(year)

        # Creates array for each input
        LC_prev_block = in_dict_uint8[f"{land_cover}_{year-interval_years}"]
        LC_curr_block = in_dict_uint8[f"{land_cover}_{year}"]
        veg_h_prev_block = in_dict_uint8[f"{vegetation_height}_{year-interval_years}"]
        veg_h_curr_block = in_dict_uint8[f"{vegetation_height}_{year}"]

        burned_area_t_4_block = in_dict_uint8[f"{burned_area}_{year-4}"]
        burned_area_t_3_block = in_dict_uint8[f"{burned_area}_{year-3}"]
        burned_area_t_2_block = in_dict_uint8[f"{burned_area}_{year-2}"]
        burned_area_t_1_block = in_dict_uint8[f"{burned_area}_{year-1}"]
        burned_area_t_block = in_dict_uint8[f"{burned_area}_{year}"]

        forest_dist_t_4_block = in_dict_uint8[f"{forest_disturbance}_{year-4}"]
        forest_dist_t_3_block = in_dict_uint8[f"{forest_disturbance}_{year-3}"]
        forest_dist_t_2_block = in_dict_uint8[f"{forest_disturbance}_{year-2}"]
        forest_dist_t_1_block = in_dict_uint8[f"{forest_disturbance}_{year-1}"]
        forest_dist_t_block = in_dict_uint8[f"{forest_disturbance}_{year}"]

        # Numpy arrays for outputs that don't depend on previous values
        state_out = np.zeros(block_shape).astype('uint32')
        agc_flux_out_block = np.zeros(block_shape).astype('float32')
        bgc_flux_out_block = np.zeros(block_shape).astype('float32')
        deadwood_c_flux_out_block = np.zeros(block_shape).astype('float32')
        litter_c_flux_out_block = np.zeros(block_shape).astype('float32')

        # Iterates through all pixels in the chunk
        for row in range(LC_curr_block.shape[0]):
            for col in range(LC_curr_block.shape[1]):

                LC_prev = LC_prev_block[row, col]
                LC_curr = LC_curr_block[row, col]
                veg_h_prev = veg_h_prev_block[row, col]
                veg_h_curr = veg_h_curr_block[row, col]
                planted_forest_type = planted_forest_type_block[row, col]
                planted_forest_tree_crop = planted_forest_tree_crop_block[row, col]

                # Note: Stacking the burned area rasters using ndstack outside the pixel iteration did not work with numba.
                # So just reading each burned area raster separately.
                burned_area_t_4 = burned_area_t_4_block[row, col]
                burned_area_t_3 = burned_area_t_3_block[row, col]
                burned_area_t_2 = burned_area_t_2_block[row, col]
                burned_area_t_1 = burned_area_t_1_block[row, col]
                burned_area_t = burned_area_t_block[row, col]
                burned_area_last = max([burned_area_t_4, burned_area_t_3, burned_area_t_2, burned_area_t_1, burned_area_t])  # Most recent year with burned area during the interval

                forest_dist_t_4 = forest_dist_t_4_block[row, col]
                forest_dist_t_3 = forest_dist_t_3_block[row, col]
                forest_dist_t_2 = forest_dist_t_2_block[row, col]
                forest_dist_t_1 = forest_dist_t_1_block[row, col]
                forest_dist_t = forest_dist_t_block[row, col]
                forest_dist_last = max([forest_dist_t_4, forest_dist_t_3, forest_dist_t_2, forest_dist_t_1, forest_dist_t])  # Most recent year with forest disturbance during the interval

                agc_dens_curr = agc_dens_curr_block[row, col]
                bgc_dens_curr = bgc_dens_curr_block[row, col]
                deadwood_c_dens_curr = deadwood_c_dens_curr_block[row, col]
                litter_c_dens_curr = litter_c_dens_curr_block[row, col]
                soil_c_dens_curr = soil_c_dens_curr_block[row, col]

                r_s_ratio_cell = r_s_ratio_block[row, col]

                tree_prev = (veg_h_prev >= tree_threshold)
                tree_curr = (veg_h_curr >=  tree_threshold)
                tall_veg_prev = (((LC_prev >= tree_dry_min_height_code) and (LC_prev <= tree_dry_max_height_code)) or
                                        ((LC_prev >= tree_wet_min_height_code) and (LC_prev <= tree_wet_max_height_code)))
                tall_veg_curr = (((LC_curr >= tree_dry_min_height_code) and (LC_curr <= tree_dry_max_height_code)) or
                                       ((LC_curr >= tree_wet_min_height_code) and (LC_curr <= tree_wet_max_height_code)))
                short_med_veg_prev = (((LC_prev >= 2) and (LC_prev <= 26)) or
                                       ((LC_prev >= 102) and (LC_prev <= 126)))
                short_med_veg_curr = (((LC_curr >= 2) and (LC_curr <= 26)) or
                                       ((LC_curr >= 102) and (LC_curr <= 126)))

                # Heights come from uint8 arrays; subtracting two unsigned values wraps around when the
                # height increases (e.g. 5 - 10), which would falsely flag a significant height loss.
                # Cast to signed ints first so height gains give negative differences.
                sig_height_loss_prev_curr = ((int(veg_h_prev) - int(veg_h_curr)) >= sig_height_loss_threshold)

                node = 0

                ### Tree gain
                if (not tree_prev) and (tree_curr):                  # Non-tree converted to tree (1)    ##TODO: Include mangrove exception.
                    node = accrete_node(node, 1)
                    if planted_forest_type == 0:                     # New non-SDPT trees (11)
                        node = accrete_node(node, 1)
                        if not tall_veg_curr:                        # New trees outside forests (111)
                            node = accrete_node(node, 1)
                            state_out[row, col] = node
                            agc_rf = 2.8
                            agc_flux_out_block[row, col] = (agc_rf*interval_years)*-1
                            agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                            bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                            bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                        else:                                        # New terrestrial natural forest (112)
                            node = accrete_node(node, 2)
                            state_out[row, col] = node
                            agc_rf = 5.6
                            agc_flux_out_block[row, col] = (agc_rf*interval_years)*-1
                            agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                            bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                            bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                    else:                                            # New SDPT trees (12)
                        node = accrete_node(node, 2)
                        state_out[row, col] = node
                        agc_rf = 10
                        agc_flux_out_block[row, col] = (agc_rf*interval_years)*-1
                        agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                        bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                        bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]

                ### Tree loss
                elif (tree_prev) and (not tree_curr):                # Tree converted to non-tree (2)    ##TODO: Include forest disturbance condition.  ##TODO: Include mangrove exception.
                    node = 2
                    if planted_forest_type == 0:                     # Full loss of non-SDPT trees (21)
                        node = accrete_node(node, 1)
                        if not tall_veg_prev:                        # Full loss of trees outside forests (211)
                            node = accrete_node(node, 1)
                            if burned_area_last == 0:                # Full loss of trees outside forests without fire (2111)
                                node = accrete_node(node, 1)
                                state_out[row, col] = node
                                agc_ef = 0.8
                                agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                            else:                                    # Full loss of trees outside forests with fire (2112)
                                node = accrete_node(node, 2)
                                state_out[row, col] = node
                                agc_ef = 0.6
                                agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                        else:                                        # Full loss of natural forest (212)
                            node = accrete_node(node, 2)
                            if LC_curr == cropland:                  # Full loss of natural forest converted to cropland (2121)
                                node = accrete_node(node, 1)
                                if burned_area_last == 0:            # Full loss of natural forest converted to cropland, not burned (21211)
                                    node = accrete_node(node, 1)
                                    state_out[row, col] = node
                                    agc_ef = 0.3
                                    agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                    agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                    bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                    bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                                else:                                # Full loss of natural forest converted to cropland, burned (21212)
                                    node = accrete_node(node, 2)
                                    state_out[row, col] = node
                                    agc_ef = 0.3
                                    agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                    agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                    bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                    bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                            elif short_med_veg_curr:                 # Full loss of natural forest converted to short or medium vegetation (2122)
                                node = accrete_node(node, 2)
                                if burned_area_last == 0:            # Full loss of natural forest converted to short or medium vegetation, not burned (21221)
                                    node = accrete_node(node, 1)
                                    state_out[row, col] = node
                                    agc_ef = 0.9
                                    agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                    agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                    bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                    bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                                else:                                # Full loss of natural forest converted to short or medium vegetation, burned (21222)
                                    node = accrete_node(node, 2)
                                    state_out[row, col] = node
                                    agc_ef = 0.3
                                    agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                    agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                    bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                    bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                            elif LC_curr == builtup:                 # Full loss of natural forest converted to builtup (2123)
                                node = accrete_node(node, 3)
                                if burned_area_last == 0:            # Full loss of natural forest converted to builtup, not burned (21231)
                                    node = accrete_node(node, 1)
                                    state_out[row, col] = node
                                    agc_ef = 0.3
                                    agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                    agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                    bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                    bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                                else:                                # Full loss of natural forest converted to builtup, burned (21232)
                                    node = accrete_node(node, 2)
                                    state_out[row, col] = node
                                    agc_ef = 0.3
                                    agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                    agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                    bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                    bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                            else:                                    # Full loss of natural forest converted to anything else (2124)
                                node = accrete_node(node, 4)
                                if burned_area_last == 0:            # Full loss of natural forest converted to anything else, not burned (21241)
                                    node = accrete_node(node, 1)
                                    state_out[row, col] = node
                                    agc_ef = 0.3
                                    agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                    agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                    bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                    bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                                else:                                # Full loss of natural forest converted to anything else, burned (21242)
                                    node = accrete_node(node, 2)
                                    state_out[row, col] = node
                                    agc_ef = 0.3
                                    agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                    agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                    bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                    bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                    else:                                            # Full loss of SDPT trees (22)
                        node = accrete_node(node, 2)
                        if LC_curr == cropland:                      # Full loss of SDPT converted to cropland (221)
                            node = accrete_node(node, 1)
                            if burned_area_last == 0:                # Full loss of SDPT converted to cropland, not burned (2211)
                                node = accrete_node(node, 1)
                                state_out[row, col] = node
                                agc_ef = 0.3
                                agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                            else:                                    # Full loss of SDPT converted to cropland, burned (2212)
                                node = accrete_node(node, 2)
                                state_out[row, col] = node
                                agc_ef = 0.3
                                agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                        elif short_med_veg_curr:                     # Full loss of SDPT converted to short or medium vegetation (222)
                            node = accrete_node(node, 2)
                            if burned_area_last == 0:                # Full loss of SDPT converted to short or medium vegetation, not burned (2221)
                                node = accrete_node(node, 1)
                                state_out[row, col] = node
                                agc_ef = 0.3
                                agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                            else:                                    # Full loss of SDPT converted to short or medium vegetation, burned (2222)
                                node = accrete_node(node, 2)
                                state_out[row, col] = node
                                agc_ef = 0.3
                                agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                        elif LC_curr == builtup:                     # Full loss of SDPT converted to builtup (223)
                            node = accrete_node(node, 3)
                            if planted_forest_tree_crop == 1:        # Full loss of SDPT planted forest to builtup (2231)
                                node = accrete_node(node, 1)
                                if burned_area_last == 0:            # Full loss of SDPT planted forest converted to builtup, not burned (22311)
                                    node = accrete_node(node, 1)
                                    state_out[row, col] = node
                                    agc_ef = 0.3
                                    agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                    agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                    bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                    bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                                else:                                # Full loss of SDPT planted forest converted to builtup, burned (22312)
                                    node = accrete_node(node, 2)
                                    state_out[row, col] = node
                                    agc_ef = 0.3
                                    agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                    agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                    bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                    bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                            else:                                    # Full loss of SDPT tree crop to builtup (2232)
                                node = accrete_node(node, 2)
                                if burned_area_last == 0:            # Full loss of SDPT tree crop converted to builtup, not burned (22321)
                                    node = accrete_node(node, 1)
                                    state_out[row, col] = node
                                    agc_ef = 0.3
                                    agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                    agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                    bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                    bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                                else:                                # Full loss of SDPT tree crop converted to builtup, burned (22322)
                                    node = accrete_node(node, 2)
                                    state_out[row, col] = node
                                    agc_ef = 0.3
                                    agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                    agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                    bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                    bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                        else:                                        # Full loss of SDPT converted to anything else (224)
                            node = accrete_node(node, 4)
                            if burned_area_last == 0:                # Full loss of SDPT converted to anything else, not burned (2241)
                                node = accrete_node(node, 1)
                                state_out[row, col] = node
                                agc_ef = 0.3
                                agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]
                            else:                                    # Full loss of SDPT converted to anything else, burned (2242)
                                node = accrete_node(node, 2)
                                state_out[row, col] = node
                                agc_ef = 0.3
                                agc_flux_out_block[row, col] = (agc_dens_curr * agc_ef)
                                agc_dens_curr_block[row, col] = agc_dens_curr - agc_flux_out_block[row, col]
                                bgc_flux_out_block[row, col] = float(agc_flux_out_block[row, col]) * r_s_ratio_cell
                                bgc_dens_curr_block[row, col] = bgc_dens_curr - bgc_flux_out_block[row, col]

                ### Trees remaining trees
                elif (tree_prev) and (tree_curr):                    # Trees remaining trees (3)    ##TODO: Include mangrove exception.
                    node = accrete_node(node, 3)
                    if forest_dist_last == 0:                        # Trees without stand-replacing disturbances in the last interval (31)
                        node = accrete_node(node, 1)
                        if planted_forest_type == 0:                 # Non-planted trees without stand-replacing disturbance in the last interval (311)
                            node = accrete_node(node, 1)
                            if not tall_veg_curr:                    # Trees outside forests without stand-replacing disturbance in the last interval (3111)
                                node = accrete_node(node, 1)
                                if not sig_height_loss_prev_curr:    # Stable trees outside forests (31111)
                                    node = accrete_node(node, 1)
                                    state_out[row, col] = node
                                    agc_flux_out_block[row, col] = 5.54
                                    agc_dens_curr_block[row, col] = 13.59
                                    bgc_flux_out_block[row, col] = 2.83
                                    bgc_dens_curr_block[row, col] = 7.34
                                else:                                # Partially disturbed trees outside forests (31112)
                                    node = accrete_node(node, 2)
                                    if burned_area_last == 0:        # Partially disturbed trees outside forests without fire (311121)
                                        node = accrete_node(node, 1)
                                        state_out[row, col] = node
                                        agc_flux_out_block[row, col] = 5.54
                                        agc_dens_curr_block[row, col] = 13.59
                                        bgc_flux_out_block[row, col] = 2.83
                                        bgc_dens_curr_block[row, col] = 7.34
                                    else:                            # Partially disturbed trees outside forests with fire (311122)
                                        node = accrete_node(node, 2)
                                        state_out[row, col] = node
                                        agc_flux_out_block[row, col] = 5.54
                                        agc_dens_curr_block[row, col] = 13.59
                                        bgc_flux_out_block[row, col] = 2.83
                                        bgc_dens_curr_block[row, col] = 7.34
                            else:                                    # Natural forest without stand-replacing disturbance in the last interval (3112)
                                node = accrete_node(node, 2)
                                state_out[row, col] = node
                                agc_flux_out_block[row, col] = 5.54
                                agc_dens_curr_block[row, col] = 13.59
                                bgc_flux_out_block[row, col] = 2.83
                                bgc_dens_curr_block[row, col] = 7.34
                                # if not sig_height_loss_prev_curr:    # Stable natural forest (31121)

                        # NOTE(review): planted trees remaining trees (planted_forest_type != 0) leave
                        # state_out at 0 here — presumably a branch still to be written.

                    else:                                            # Trees with stand-replacing disturbance in the last interval (32)
                        state_out[row, col] = 32
                        agc_flux_out_block[row, col] = 5.54
                        agc_dens_curr_block[row, col] = 13.59
                        bgc_flux_out_block[row, col] = 2.83
                        bgc_dens_curr_block[row, col] = 7.34

                else:                                                # Not covered in above branches
                    state_out[row, col] = 4000000000                 # High value for uint32

        # Adds the output arrays to the dictionary with the appropriate data type
        # Outputs need .copy() so that previous intervals' arrays in the dictionary aren't overwritten, because the density arrays keep being mutated in later intervals.
        year_range = f"{year-interval_years}_{year}"
        out_dict_uint32[f"{land_state_pattern}_{year_range}"] = state_out.copy()
        out_dict_float32[f"{agc_dens_pattern}_{year_range}"] = agc_dens_curr_block.copy()
        out_dict_float32[f"{bgc_dens_pattern}_{year_range}"] = bgc_dens_curr_block.copy()
        out_dict_float32[f"{deadwood_c_dens_pattern}_{year_range}"] = deadwood_c_dens_curr_block.copy()
        out_dict_float32[f"{litter_c_dens_pattern}_{year_range}"] = litter_c_dens_curr_block.copy()

        out_dict_float32[f"{agc_flux_pattern}_{year_range}"] = agc_flux_out_block.copy()
        out_dict_float32[f"{bgc_flux_pattern}_{year_range}"] = bgc_flux_out_block.copy()
        out_dict_float32[f"{deadwood_c_flux_pattern}_{year_range}"] = deadwood_c_flux_out_block.copy()
        out_dict_float32[f"{litter_c_flux_pattern}_{year_range}"] = litter_c_flux_out_block.copy()

    return out_dict_uint32, out_dict_float32

In [None]:
%%time

### Turns downloads into a dictionary where the keys are the inputs' names. 
### Converts the download dictionary to typed dictionaries that Numba will understand and uses them in the Numba function.
### Returns typed dictionaries, which are then saved to raster locally and uploaded to s3

import os
import math
import uuid
import dask
import dask.bag as db
import s3fs
import rasterio
from rasterio.windows import from_bounds
from numba import jit
import numpy as np
import boto3
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

# Initialize the S3 filesystem with appropriate credentials
s3 = s3fs.S3FileSystem(anon=False)  # Set anon=False to use AWS credentials

# Function to download a chunk from S3 based on bounds
def download_chunk(s3_uri, bounds):
    """Read the window of a raster that covers the given bounds.

    Parameters
    ----------
    s3_uri : str
        Location of the raster, passed straight to ``rasterio.open`` (e.g. an s3:// URI).
    bounds : tuple
        (west, south, east, north) in the raster's coordinate reference system.

    Returns
    -------
    data : numpy.ndarray
        Result of ``src.read`` for the window. No band index is given, so rasterio
        returns a (bands, rows, cols) array.
    metadata : dict
        Copy of the source's ``meta`` with height, width and transform updated to
        describe the window, plus LZW compression for any raster written with it.
    """
    with rasterio.Env():
        with rasterio.open(s3_uri) as src:
            # Calculate window from bounds
            window = from_bounds(*bounds, transform=src.transform)
            # Read the windowed data
            data = src.read(window=window)
            transform = src.window_transform(window)
            metadata = src.meta.copy()
            # from_bounds can return a window with fractional height/width. Take the
            # dimensions from the array actually read so the metadata is integral and
            # always matches the data it describes.
            height, width = data.shape[-2:]
            metadata.update({
                "height": height,
                "width": width,
                "transform": transform,
                "compress": "lzw"  # Add compression to reduce file size
            })
    return data, metadata

# # Function to process each chunk with LULUCF_fluxes
# @jit(nopython=True)
# def LULUCF_fluxes_bag(in_dict_uint8, in_dict_int16, in_dict_float32):

#     # Separate dictionaries for output numpy arrays of each datatype, named by output data type).
#     # This is because a dictionary in a Numba function cannot have arrays with multiple data types, so each dictionary has to store only one data type,
#     # just like inputs to the function.
#     out_dict_float32 = {}

#     agc_dens_curr_block = in_dict_float32[agc_2000].astype('float32')
#     bgc_dens_curr_block = in_dict_float32[bgc_2000].astype('float32')
#     deadwood_c_dens_curr_block = in_dict_float32[deadwood_c_2000].astype('float32')
#     litter_c_dens_curr_block = in_dict_float32[litter_c_2000].astype('float32')
    
#     processed_arr = np.zeros_like(agc_dens_curr_block).astype('float32')

#     for row in range(agc_dens_curr_block.shape[0]):
#         for col in range(agc_dens_curr_block.shape[1]):

#             agc_cell = agc_dens_curr_block[row, col]
#             bgc_cell = bgc_dens_curr_block[row, col]
#             deadwood_cell = deadwood_c_dens_curr_block[row, col]
#             litter_cell = litter_c_dens_curr_block[row, col]

#             total_c = agc_cell + bgc_cell + deadwood_cell + litter_cell
#             processed_arr[row, col] = total_c

#     out_dict_float32["total_C"] = processed_arr.copy()
    
#     return out_dict_float32

# Function to upload processed data to S3
def upload_to_s3(local_path, s3_uri, bounds, tile_id):
    """Upload a locally written chunk raster to the dated test-output prefix on s3, then delete the local file.

    Parameters
    ----------
    local_path : str
        Path of the local file to upload. It is removed after a successful upload.
    s3_uri : str
        URI of an input raster. The text after ``__`` in its file name becomes the
        suffix of the output object name (raises IndexError if there is no ``__``).
    bounds : tuple
        (west, south, east, north) of the chunk, embedded in the output object name.
    tile_id : str
        Tile id (e.g. "50N_010E") that prefixes the output object name.
    """
    # Built here so each worker process creates its own client
    client = boto3.client('s3')

    # File name looks like "<tile_id>__<carbon pool>...", so keep what follows the double underscore
    file_name = s3_uri.split('/')[-1]
    carbon_pool = file_name.split('__')[1]

    date_stamp = datetime.today().strftime('%Y%m%d')

    west, south, east, north = bounds
    output_uri = f"s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/test_outputs/total_C/2000/40000_pixels/{date_stamp}/{tile_id}_{west}_{south}_{east}_{north}__{carbon_pool}"

    # Drop the scheme, then split into bucket name and object key at the first slash
    bucket_name, object_key = output_uri.replace("s3://", "").split("/", 1)

    client.upload_file(local_path, bucket_name, object_key)
    os.remove(local_path)

# Returns the encompassing tile_id string in the form YYN/S_XXXE/W based on a coordinate
def xy_to_tile_id(top_left_x, top_left_y):
    """Return the 10x10 degree tile id (e.g. "50N_010E") for a top-left coordinate.

    A tile is named after its north edge (latitude rounded up to a multiple of 10,
    two digits plus N/S) and its west edge (longitude rounded down to a multiple
    of 10, three digits plus E/W).
    """
    tile_north = math.ceil(top_left_y / 10.0) * 10
    tile_west = math.floor(top_left_x / 10.0) * 10

    ns = "N" if tile_north >= 0 else "S"
    ew = "E" if tile_west >= 0 else "W"

    return f"{abs(tile_north):02d}{ns}_{abs(tile_west):03d}{ew}"

# Wrapper function to download, process, and upload a chunk
def download_process_upload_chunk(bounds):
    """Download every model input for one chunk, run the Numba decision tree on it, and gather the outputs.

    Prints progress messages along the way.

    Parameters
    ----------
    bounds : tuple
        Chunk bounds as (west, south, east, north) in decimal degrees.

    Returns
    -------
    None
        The step that saves the outputs to a local GeoTIFF and uploads them to s3
        is currently commented out, so the gathered outputs are discarded.
    """
    west, south, east, north = bounds

    bounds_str = boundstr(bounds)    # String form of chunk bounds (not used below)
    tile_id = xy_to_tile_id(west, north)    # tile_id in YYN/S_XXXE/W
    chunk_length_pixels = calc_chunk_length_pixels(bounds)   # Chunk length in pixels (as opposed to decimal degrees); only used by the commented-out save step

    s3_uris = {
        f"{land_cover}_2000": f"{LC_uri}/composite/2000/raw/{tile_id}.tif",
        f"{land_cover}_2005": f"{LC_uri}/composite/2005/raw/{tile_id}.tif",
        f"{land_cover}_2010": f"{LC_uri}/composite/2010/raw/{tile_id}.tif",
        f"{land_cover}_2015": f"{LC_uri}/composite/2015/raw/{tile_id}.tif",
        f"{land_cover}_2020": f"{LC_uri}/composite/2020/raw/{tile_id}.tif",

        f"{vegetation_height}_2000": f"{LC_uri}/vegetation_height/2000/{tile_id}_vegetation_height_2000.tif",
        f"{vegetation_height}_2005": f"{LC_uri}/vegetation_height/2005/{tile_id}_vegetation_height_2005.tif",
        f"{vegetation_height}_2010": f"{LC_uri}/vegetation_height/2010/{tile_id}_vegetation_height_2010.tif",
        f"{vegetation_height}_2015": f"{LC_uri}/vegetation_height/2015/{tile_id}_vegetation_height_2015.tif",
        f"{vegetation_height}_2020": f"{LC_uri}/vegetation_height/2020/{tile_id}_vegetation_height_2020.tif",

        agc_2000: f"s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/AGC_density_MgC_ha/2000/40000_pixels/20240729/{tile_id}__AGC_density_MgC_ha_2000.tif",
        bgc_2000: f"s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/BGC_density_MgC_ha/2000/40000_pixels/20240729/{tile_id}__BGC_density_MgC_ha_2000.tif",
        deadwood_c_2000: f"s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/deadwood_C_density_MgC_ha/2000/40000_pixels/20240729/{tile_id}__deadwood_C_density_MgC_ha_2000.tif",
        litter_c_2000: f"s3://gfw2-data/climate/AFOLU_flux_model/LULUCF/outputs/litter_C_density_MgC_ha/2000/40000_pixels/20240729/{tile_id}__litter_C_density_MgC_ha_2000.tif",
        soil_c_2000: f"s3://gfw2-data/climate/carbon_model/carbon_pools/soil_carbon/intermediate_full_extent/standard/20231108/{tile_id}_soil_C_full_extent_2000_Mg_C_ha.tif",

        r_s_ratio: f"{r_s_ratio_path}{tile_id}_{r_s_ratio_pattern}.tif",

        # "drivers": f"s3://gfw2-data/climate/carbon_model/other_emissions_inputs/tree_cover_loss_drivers/processed/drivers_2022/20230407/{tile_id}_tree_cover_loss_driver_processed.tif",
        planted_forest_type_layer: f"s3://gfw2-data/climate/carbon_model/other_emissions_inputs/plantation_type/SDPTv2/20230911/{tile_id}_plantation_type_oilpalm_woodfiber_other.tif", # Originally from gfw-data-lake, so it's in 400x400 windows
        planted_forest_tree_crop_layer: f"s3://gfw2-data/climate/carbon_model/other_emissions_inputs/plantation_simpleType__planted_forest_tree_crop/SDPTv2/20230911/{tile_id}.tif"  # Originally from gfw-data-lake, so it's in 400x400 windows
        # "peat": f"s3://gfw2-data/climate/carbon_model/other_emissions_inputs/peatlands/processed/20230315/{tile_id}_peat_mask_processed.tif",
        # "ecozone": f"s3://gfw2-data/fao_ecozones/v2000/raster/epsg-4326/10/40000/class/gdal-geotiff/{tile_id}.tif",   # Originally from gfw-data-lake, so it's in 400x400 windows
        # "iso": f"s3://gfw2-data/gadm_administrative_boundaries/v3.6/raster/epsg-4326/10/40000/adm0/gdal-geotiff/{tile_id}.tif",  # Originally from gfw-data-lake, so it's in 400x400 windows
        # "ifl_primary": f"s3://gfw2-data/climate/carbon_model/ifl_primary_merged/processed/20200724/{tile_id}_ifl_2000_primary_2001_merged.tif"
    }

    # for year in range(first_year, last_year+1):     # Annual burned area maps start in 2000
    for year in range(2000, 2021):     # Annual burned area maps start in 2000
        s3_uris[f"{burned_area}_{year}"] = f"s3://gfw2-data/climate/carbon_model/other_emissions_inputs/burn_year/burn_year_10x10_clip/ba_{year}_{tile_id}.tif"

    # for year in range(first_year+1, last_year+1):     # Annual forest disturbance maps start in 2001 and ends in 2020
    for year in range(2001, 2021):     # Annual forest disturbance maps start in 2001 and ends in 2020
        s3_uris[f"{forest_disturbance}_{year}"] = f"{LC_uri}/annual_forest_disturbance/raw/{year}_{tile_id}.tif"

    print("Downloading data")

    def download_and_process(name, uri):
        data, metadata = download_chunk(uri, bounds)
        # Drop only the leading band axis. A bare np.squeeze would also collapse a
        # spatial axis of length 1 (a one-pixel-wide or one-pixel-tall chunk), turning
        # the 2D array the decision tree indexes as [row, col] into a 1D one.
        # Assumes single-band inputs; a multi-band raster raises ValueError here.
        return name, np.squeeze(data, axis=0), metadata

    # Download each input's chunk in parallel; threads let the downloads' waits overlap
    with ThreadPoolExecutor(max_workers=4) as executor:
        results = list(executor.map(download_and_process, s3_uris.keys(), s3_uris.values()))

    print("Done downloading")

    data_dict = {name: data for name, data, _ in results}
    metadata = results[0][2]   # Window metadata of the first input; only needed by the commented-out save step

    # typed_dict_int32 is not passed to LULUCF_fluxes_bag below
    typed_dict_uint8, typed_dict_int16, typed_dict_int32, typed_dict_float32 = create_typed_dicts(data_dict)

    print("Running decision tree")

    # Run the decision tree on the chunk's inputs
    out_dict_uint32, out_dict_float32 = LULUCF_fluxes_bag(
        typed_dict_uint8, typed_dict_int16, typed_dict_float32
    )

    # Copy the per-dtype typed dictionaries of output arrays into one ordinary Python dictionary
    out_dict_all_dtypes = {}

    for key, value in out_dict_uint32.items():
        out_dict_all_dtypes[key] = value

    for key, value in out_dict_float32.items():
        out_dict_all_dtypes[key] = value

    # Release the typed dictionaries; their arrays are still referenced by out_dict_all_dtypes
    del out_dict_uint32
    del out_dict_float32

    # Construct the local output file path with bounding box in the name
    local_output_path = f"/tmp/{tile_id}_{west}_{south}_{east}_{north}.tif"

    # NOTE: the save/upload code below is commented out, so nothing is saved or uploaded yet
    print("Saving and uploading data")

    # # Save to a local file
    # transform = rasterio.transform.from_bounds(*bounds, width=chunk_length_pixels, height=chunk_length_pixels)
    # with rasterio.open(local_output_path, "w", driver='GTiff', width=chunk_length_pixels, height=chunk_length_pixels, count=1,
    #                            dtype='float32', crs='EPSG:4326', transform=transform, compress='lzw', blockxsize=400, blockysize=400) as dst:
    #     dst.write(out_dict_all_dtypes["total_C"][np.newaxis, :, :])

    # # Upload the processed data
    # upload_to_s3(local_output_path, s3_uris[agc_2000], bounds, tile_id)

    # return
    # # return bounds, out_dict_all_dtypes["total_C"].shape

# Function to generate bounding boxes within a specified bounding box with a specified chunk size
def generate_chunks_within_bounds(west, south, east, north, chunk_size):
    """Split a bounding box into a grid of chunk bounding boxes.

    Chunks are listed row by row from south to north, and from west to east within
    each row. Chunks in the last row or column are clipped to the bounding box, so
    they can be smaller than chunk_size.

    Parameters
    ----------
    west, south, east, north : float
        Edges of the area to split, in decimal degrees.
    chunk_size : float
        Edge length of each chunk, in decimal degrees. Must be positive.

    Returns
    -------
    list of tuple
        (west, south, east, north) for every chunk. Empty if the box has zero width
        or height.

    Raises
    ------
    ValueError
        If chunk_size is not positive (such a step never advances past the far edge).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    def _edges(start, stop):
        # Compute each edge as start + i * chunk_size instead of repeatedly adding
        # chunk_size. Repeated float addition drifts (ten additions of 0.1 give
        # 0.9999999999999999), which would emit an extra sliver chunk at the far edge.
        # The small tolerance absorbs rounding error in the division.
        n_steps = max(0, math.ceil((stop - start) / chunk_size - 1e-9))
        return [(start + i * chunk_size, min(start + (i + 1) * chunk_size, stop))
                for i in range(n_steps)]

    lat_edges = _edges(south, north)
    lon_edges = _edges(west, east)

    return [(chunk_west, chunk_south, chunk_east, chunk_north)
            for chunk_south, chunk_north in lat_edges
            for chunk_west, chunk_east in lon_edges]

# Area to run, as [west, south, east, north] in decimal degrees, and the edge length of each chunk in degrees
bounding_box = [10, 40, 20, 50]  # 50N_010E
chunk_size = 2

# Other extents and chunk sizes used for testing:
# bounding_box = [10, 40, 20, 50]  # 50N_010E
# bounding_box = [10, 49, 11, 50]
# chunk_size = 1  # 1x1 degree chunks

# bounding_box = [10, 49.75, 10.25, 50]
# chunk_size = 0.25

# Split the bounding box into a list of (west, south, east, north) chunk bounds
chunks = generate_chunks_within_bounds(*bounding_box, chunk_size=chunk_size)

# Define the function to be used with Dask for each chunk
def process_chunk(bounds):
    """Run the download/process/upload pipeline for one chunk; this is the function mapped over the Dask bag.

    bounds is a (west, south, east, north) tuple. Returns whatever
    download_process_upload_chunk returns for it.
    """
    result = download_process_upload_chunk(bounds)
    return result

# Create a Dask bag from the list of chunks, with one partition per chunk so each chunk is its own task
s3_bag = db.from_sequence(chunks, npartitions=len(chunks))

# Apply the process_chunk function to each chunk in parallel.
# compute() gathers each chunk's return value into a list on the client; with the save/return
# step in download_process_upload_chunk commented out, each entry is None.
processed_chunks = s3_bag.map(process_chunk).compute()

# # `processed_chunks` now contains the bounds and the shape of the processed data for each raster
# for bounds, shape in processed_chunks:
#     print(f"Bounds: {bounds}, Processed data shape: {shape}")
