Note: this code has not been tested yet.

In [None]:
import os
import logging
import boto3
import rioxarray
import rasterio
from rasterio.merge import merge as merge_arrays
import dask
from dask.distributed import Client, LocalCluster
from dask.diagnostics import ProgressBar
import atexit
import xarray as xr

# Configure root logging: INFO level, each record prefixed with a timestamp.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')

# S3 bucket used for every list/download/upload in this script, and the local
# scratch directory where downloaded and generated rasters are written.
s3_bucket = "gfw2-data"
local_temp_dir = "/tmp/aggregated"

# Module-level handles to the Dask cluster and client. They start unset,
# are assigned in main(), and are closed by cleanup().
cluster = None
client = None

def s3_file_exists(bucket, key):
    """Return True if the object ``s3://{bucket}/{key}`` can be confirmed to exist.

    A HEAD request is issued for the object. As before, any failure results in
    ``False`` rather than an exception, but a genuine "not found" response is
    now logged differently from other failures (credentials, permissions,
    network) so that those are not silently reported as a missing file.

    Args:
        bucket: Name of the S3 bucket.
        key: Object key inside the bucket.

    Returns:
        True if the HEAD request succeeded, False otherwise.
    """
    s3 = boto3.client('s3')
    try:
        s3.head_object(Bucket=bucket, Key=key)
    except Exception as exc:
        # `except Exception` (not a bare `except`) lets KeyboardInterrupt and
        # SystemExit propagate instead of being reported as a missing file.
        response = getattr(exc, "response", None)
        error_code = ""
        if isinstance(response, dict):
            error_code = str(response.get("Error", {}).get("Code", ""))

        if error_code in ("404", "NoSuchKey", "NotFound"):
            logging.info(f"File does not exist: s3://{bucket}/{key}")
        else:
            logging.warning(
                f"Could not check s3://{bucket}/{key}; treating it as missing: {exc}"
            )
        return False

    logging.info(f"File exists: s3://{bucket}/{key}")
    return True

def list_s3_files(bucket, prefix):
    """Collect the keys of all objects in ``bucket`` whose key starts with ``prefix``.

    Results are fetched page by page through the ``list_objects_v2`` paginator.
    If listing fails, the error is logged and whatever keys were gathered up to
    that point are returned (possibly an empty list).

    Args:
        bucket: Name of the S3 bucket.
        prefix: Key prefix to list under.

    Returns:
        A list of object keys.
    """
    found = []
    s3 = boto3.client('s3')
    try:
        pages = s3.get_paginator('list_objects_v2').paginate(Bucket=bucket, Prefix=prefix)
        for page in pages:
            # Pages without matching objects carry no 'Contents' entry.
            found.extend(entry['Key'] for entry in page.get('Contents', []))
    except Exception as e:
        logging.error(f"Error listing files in s3://{bucket}/{prefix}: {e}")
    return found

def aggregate_tile(tile_path, output_path, source_resolution_m=30, target_resolution_m=4000):
    """Downsample band 1 of a raster with average resampling and write the result.

    The output dimensions are the source dimensions multiplied by
    ``source_resolution_m / target_resolution_m`` and truncated to whole
    pixels (at least one pixel in each direction). The geotransform is scaled
    by the actual ratio between the old and new dimensions, so the output
    covers the same extent as the input.

    Args:
        tile_path: Path of the raster to read.
        output_path: Path of the raster to create.
        source_resolution_m: Nominal pixel size of the input (default 30).
        target_resolution_m: Nominal pixel size of the output (default 4000).
    """
    # Imported explicitly rather than relying on `rasterio.enums` being
    # available as an attribute of the top-level package.
    from rasterio.enums import Resampling

    with rasterio.open(tile_path) as src:
        # Take the size from the dataset metadata instead of loading the full
        # resolution band into memory just to inspect its shape. Multiplying
        # before the floor division keeps integer inputs exact (a float scale
        # factor followed by int() can turn e.g. 300 into 299).
        new_height = max(1, int(src.height * source_resolution_m // target_resolution_m))
        new_width = max(1, int(src.width * source_resolution_m // target_resolution_m))

        # Read band 1 directly at the reduced size; this yields a 2-D array,
        # which is what writing to a single band index expects.
        resampled_data = src.read(
            1,
            out_shape=(new_height, new_width),
            resampling=Resampling.average,
        )

        # Scale the pixel size by the real old/new size ratio so that the
        # georeferenced extent is preserved even after truncation.
        new_transform = src.transform * src.transform.scale(
            src.width / new_width,
            src.height / new_height,
        )

        profile = src.profile
        profile.update({
            "height": new_height,
            "width": new_width,
            "count": 1,  # only band 1 is read and written
            "transform": new_transform,
        })

        with rasterio.open(output_path, 'w', **profile) as dst:
            dst.write(resampled_data, 1)

def merge_global_tiles(tiles, output_path):
    """Merge several rasters into one LZW-compressed raster at ``output_path``.

    The output metadata is copied from the first tile and updated with the
    merged size and transform. Only the first band of the merged result is
    written.

    Args:
        tiles: Iterable of paths/URLs that ``rasterio.open`` can read.
        output_path: Path of the merged raster to create.

    Raises:
        ValueError: If ``tiles`` is empty.
    """
    from contextlib import ExitStack

    tiles = list(tiles)
    if not tiles:
        raise ValueError("merge_global_tiles requires at least one tile to merge")

    # ExitStack guarantees every opened dataset is closed again, including
    # when opening a later tile or the merge itself fails.
    with ExitStack() as stack:
        rasters = [stack.enter_context(rasterio.open(tile)) for tile in tiles]
        merged, out_transform = merge_arrays(rasters)
        out_meta = rasters[0].meta.copy()

    out_meta.update({
        "height": merged.shape[1],
        "width": merged.shape[2],
        "transform": out_transform,
        "compress": "lzw"
    })

    with rasterio.open(output_path, 'w', **out_meta) as dst:
        dst.write(merged[0], 1)  # Write the first band

def cleanup():
    """Close the module-level Dask client and then the cluster, skipping unset ones."""
    # The client is listed first so it is closed before the cluster.
    for handle in (client, cluster):
        if handle:
            handle.close()

def process_and_aggregate(tile_id, input_prefix, temp_prefix, output_prefix):
    """Download one tile from S3, downsample it, and upload the result.

    The tile ``{input_prefix}/{tile_id}.tif`` is downloaded into the local
    scratch directory, passed through ``aggregate_tile``, and the output is
    uploaded as ``{temp_prefix}/{tile_id}_4km.tif``. Local copies are removed
    afterwards, whether or not the processing succeeded.

    Args:
        tile_id: Tile identifier (file name without the ``.tif`` extension).
        input_prefix: S3 key prefix holding the input tiles.
        temp_prefix: S3 key prefix receiving the downsampled tiles.
        output_prefix: Unused here; kept so the signature stays unchanged.
    """
    s3_client = boto3.client('s3')

    # Prefixes may be passed with a trailing "/"; strip it so the joined key
    # does not contain "//", which would name a different (missing) object.
    input_file = f"{input_prefix.rstrip('/')}/{tile_id}.tif"
    temp_output_file = f"{temp_prefix.rstrip('/')}/{tile_id}_4km.tif"
    local_input_path = os.path.join(local_temp_dir, f'{tile_id}.tif')
    local_temp_output_path = os.path.join(local_temp_dir, f'{tile_id}_4km.tif')

    # exist_ok avoids the check-then-create race when several workers start
    # at the same time.
    os.makedirs(local_temp_dir, exist_ok=True)

    try:
        s3_client.download_file(s3_bucket, input_file, local_input_path)
        aggregate_tile(local_input_path, local_temp_output_path)
        s3_client.upload_file(local_temp_output_path, s3_bucket, temp_output_file)
    finally:
        # Remove whatever was created, even if a step above failed, so large
        # rasters do not pile up in the scratch directory.
        for path in (local_input_path, local_temp_output_path):
            if os.path.exists(path):
                os.remove(path)

def main(input_prefix, temp_prefix, output_prefix, tile_ids=None):
    """Downsample tiles from S3 in parallel and merge them into one global raster.

    Steps:
      1. Start a local Dask cluster and client.
      2. List the ``.tif`` objects under ``input_prefix`` and derive tile IDs
         from their file names.
      3. Run ``process_and_aggregate`` for each selected tile.
      4. Merge every ``.tif`` found under ``temp_prefix`` and upload the result
         to ``{output_prefix}/global_4km.tif``.

    Args:
        input_prefix: S3 key prefix holding the input tiles.
        temp_prefix: S3 key prefix for the per-tile downsampled rasters.
        output_prefix: S3 key prefix for the merged raster.
        tile_ids: Optional list of tile IDs to process. Requested IDs that are
            not present under ``input_prefix`` are skipped with a warning.
            When omitted or empty, all available tiles are processed.
    """
    global cluster, client
    cluster = LocalCluster()
    client = Client(cluster)
    atexit.register(cleanup)

    try:
        # Only real .tif objects are tiles; this skips "directory" placeholder
        # keys and any unrelated files stored under the prefix.
        available_tile_ids = [
            os.path.splitext(os.path.basename(key))[0]
            for key in list_s3_files(s3_bucket, input_prefix)
            if key.endswith('.tif')
        ]

        if tile_ids:
            # Filter the requested tile IDs to those that actually exist.
            available_lookup = set(available_tile_ids)
            tile_ids_to_process = [tile_id for tile_id in tile_ids if tile_id in available_lookup]
            missing = [tile_id for tile_id in tile_ids if tile_id not in available_lookup]
            if missing:
                logging.warning(f"Requested tiles not found under s3://{s3_bucket}/{input_prefix}: {missing}")
        else:
            tile_ids_to_process = available_tile_ids

        # Build all tasks first and compute them together so the cluster can
        # run tiles in parallel; computing inside the loop ran them one by one.
        delayed_tiles = [
            dask.delayed(process_and_aggregate)(tile_id, input_prefix, temp_prefix, output_prefix)
            for tile_id in tile_ids_to_process
        ]
        if delayed_tiles:
            # NOTE(review): dask.diagnostics.ProgressBar is documented for the
            # local schedulers; with a distributed Client active it may show
            # nothing. Consider dask.distributed.progress if output is needed.
            with ProgressBar():
                dask.compute(*delayed_tiles)

        # Merge all the 4km tiles into one global raster
        aggregated_tiles = [
            f's3://{s3_bucket}/{key}'
            for key in list_s3_files(s3_bucket, temp_prefix)
            if key.endswith('.tif')
        ]

        if aggregated_tiles:
            # The scratch directory is otherwise only created by the workers,
            # so make sure it exists before writing the merged raster.
            os.makedirs(local_temp_dir, exist_ok=True)
            local_global_path = os.path.join(local_temp_dir, 'global_4km.tif')
            s3_global_output = f"{output_prefix.rstrip('/')}/global_4km.tif"
            try:
                merge_global_tiles(aggregated_tiles, local_global_path)
                # The original referenced an undefined `s3_client` here.
                boto3.client('s3').upload_file(local_global_path, s3_bucket, s3_global_output)
            finally:
                if os.path.exists(local_global_path):
                    os.remove(local_global_path)
            logging.info(f"Uploaded global raster to s3://{s3_bucket}/{s3_global_output}")
    finally:
        # Shut the cluster down now instead of waiting for interpreter exit
        # (which never comes in a long-lived notebook kernel), then reset the
        # globals so the atexit hook has nothing left to close.
        cleanup()
        client = None
        cluster = None
        print("exit")

# Script entry point: process a fixed set of tiles for the 2020 organic-soils outputs.
if __name__ == "__main__":
    # S3 key prefixes inside s3_bucket: where the input tiles are listed from,
    # where the per-tile 4km rasters are written, and where the merged global
    # raster is uploaded.
    input_prefix = 'climate/AFOLU_flux_model/organic_soils/outputs/soil/2020/10x10_degrees/'
    temp_prefix = 'climate/AFOLU_flux_model/organic_soils/outputs/soil/2020/4km_temp/'
    output_prefix = 'climate/AFOLU_flux_model/organic_soils/outputs/soil/2020/global/'

    # Tile IDs to process; main() skips any ID that is not found under input_prefix.
    tile_ids = ['50N_070E', '70N_040E', '00N_150E', '10N_130E']  # Replace with the list of tile IDs you want to process
    main(input_prefix, temp_prefix, output_prefix, tile_ids)
