In [282]:
import atexit
import contextlib
import logging
import os
import posixpath

import boto3
import dask
import rasterio
import rioxarray
from dask.diagnostics import ProgressBar
from dask.distributed import Client, LocalCluster, progress
from rasterio.merge import merge as merge_arrays

In [283]:
# Configure the root logger: INFO and above, each record prefixed by its timestamp.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)


In [284]:
# --- AWS S3 / scratch configuration ---
# Bucket that both the input tiles are listed from and the merged rasters are uploaded to.
s3_bucket = 'gfw2-data'
# Local scratch directory where merged GeoTIFFs are written before upload.
local_temp_dir = '/tmp/merged'

In [285]:
def s3_file_exists(bucket, key, s3_client=None):
    """Return True if ``s3://{bucket}/{key}`` exists, False otherwise.

    Issues a HEAD request, so the object body is never downloaded.

    Args:
        bucket: S3 bucket name.
        key: Object key within ``bucket``.
        s3_client: Optional boto3 S3 client to reuse; a new one is created
            when omitted.

    Returns:
        True when ``head_object`` succeeds. False when it raises: a
        404/NoSuchKey/NotFound error is logged at INFO as a missing file; any
        other error (e.g. access denied, network failure) is logged at WARNING
        with the error, because existence could not actually be determined.
    """
    s3 = s3_client if s3_client is not None else boto3.client('s3')
    try:
        s3.head_object(Bucket=bucket, Key=key)
    except Exception as exc:  # not a bare except: KeyboardInterrupt/SystemExit must propagate
        # botocore ClientError exposes the service error code under response['Error']['Code'].
        response = getattr(exc, 'response', None)
        error = response.get('Error', {}) if isinstance(response, dict) else {}
        error_code = str(error.get('Code', ''))
        if error_code in ('404', 'NoSuchKey', 'NotFound'):
            logging.info(f"File does not exist: s3://{bucket}/{key}")
        else:
            logging.warning(f"Could not check s3://{bucket}/{key}: {exc}")
        return False
    logging.info(f"File exists: s3://{bucket}/{key}")
    return True

In [286]:
def list_s3_files(bucket, prefix):
    """Collect every object key stored under ``s3://{bucket}/{prefix}``.

    Walks all pages of ``list_objects_v2``. Listing errors are logged rather
    than raised; in that case the keys gathered before the failure are
    returned (possibly an empty list).
    """
    s3_conn = boto3.client('s3')
    collected = []
    try:
        pages = s3_conn.get_paginator('list_objects_v2').paginate(Bucket=bucket, Prefix=prefix)
        for page in pages:
            # A page without matching objects carries no 'Contents' entry.
            collected.extend(entry['Key'] for entry in page.get('Contents', []))
    except Exception as err:
        logging.error(f"Error listing files in s3://{bucket}/{prefix}: {err}")
    return collected


In [287]:
def merge_tiles(tile_id, input_prefix, output_prefix):
    """Mosaic all rasters for one tile and upload the result to S3.

    Lists ``s3://{s3_bucket}/{input_prefix}``, keeps the keys containing
    ``tile_id``, merges them with ``rasterio.merge.merge``, writes an
    LZW-compressed GeoTIFF ``merged_{tile_id}.tif`` into ``local_temp_dir``,
    uploads it to ``{output_prefix}/merged_{tile_id}.tif`` in ``s3_bucket``
    and removes the local copy (also when writing or uploading fails).

    Args:
        tile_id: Identifier matched, as a substring, against each object key.
        input_prefix: S3 key prefix holding the per-tile rasters.
        output_prefix: S3 key prefix that receives the merged raster.

    Returns:
        None. Logs and returns early when no matching rasters are found.
    """
    # NOTE(review): substring match on the full key -- a tile_id that also occurs
    # inside another tile's name or inside the prefix would pull in extra files.
    small_raster_paths = [
        f's3://{s3_bucket}/{key}'
        for key in list_s3_files(s3_bucket, input_prefix)
        if tile_id in key
    ]
    if not small_raster_paths:
        logging.info(f"No small rasters found for tile {tile_id}.")
        return

    # ExitStack closes every opened dataset, even if a later open or the merge fails.
    with contextlib.ExitStack() as stack:
        small_rasters = [stack.enter_context(rasterio.open(path)) for path in small_raster_paths]
        merged, out_transform = merge_arrays(small_rasters)
        # Copy the metadata from one of the source rasters
        out_meta = small_rasters[0].meta.copy()

    # merged is (bands, rows, cols); keep the band count consistent with the data written.
    out_meta.update({
        "driver": "GTiff",
        "count": merged.shape[0],
        "height": merged.shape[1],
        "width": merged.shape[2],
        "transform": out_transform,
        "compress": "lzw",
    })

    # exist_ok: several dask tasks may try to create the directory at the same time.
    os.makedirs(local_temp_dir, exist_ok=True)
    out_file = f'merged_{tile_id}.tif'
    local_output_path = os.path.join(local_temp_dir, out_file)
    # S3 keys always use "/", independent of the local OS path separator.
    s3_output_path = posixpath.join(output_prefix, out_file)

    try:
        with rasterio.open(local_output_path, 'w', **out_meta) as dst:
            # Write all bands; writing the 3-D array to the single index 1 is rejected by rasterio.
            dst.write(merged)
        boto3.client('s3').upload_file(local_output_path, s3_bucket, s3_output_path)
        logging.info(f"Uploaded merged raster to s3://{s3_bucket}/{s3_output_path}")
    finally:
        # Don't leave partial or stale files in the scratch directory.
        if os.path.exists(local_output_path):
            os.remove(local_output_path)

In [None]:
def cleanup():
    """Close the module-level dask ``client`` and ``cluster`` if they are open.

    Safe to call repeatedly (e.g. from both ``main``'s ``finally`` and an
    ``atexit`` hook) and before the globals have ever been assigned. Each
    object is closed at most once, and the globals are reset to ``None``.
    The cluster is still closed if closing the client raises.
    """
    global client, cluster
    # globals().get tolerates the names never having been bound.
    open_client = globals().get('client')
    open_cluster = globals().get('cluster')
    # Clear first so a second call does not close the same objects again.
    client = None
    cluster = None
    try:
        # Close the client before the cluster it is connected to.
        if open_client is not None:
            open_client.close()
    finally:
        if open_cluster is not None:
            open_cluster.close()

In [None]:
def main(input_prefix, output_prefix):
    """Merge every tile found under ``input_prefix`` in parallel on a local dask cluster.

    Tile IDs are the text before the first ``_`` in each object's basename.
    One ``merge_tiles`` task per unique ID is computed on a ``LocalCluster``
    with a text progress bar. The client and cluster are stored in the module
    globals ``client``/``cluster`` and are always released via ``cleanup``.

    Args:
        input_prefix: S3 key prefix (in ``s3_bucket``) holding the rasters to merge.
        output_prefix: S3 key prefix that merged tiles are uploaded to.

    Raises:
        Whatever the first failing ``merge_tiles`` task raised.
    """
    global cluster, client
    # Pre-bind so cleanup() never sees unbound or stale globals if startup fails.
    cluster = None
    client = None
    atexit.register(cleanup)  # Ensure the cluster is closed when the script exits

    try:
        # Created inside try so a failing Client() still shuts the cluster down.
        cluster = LocalCluster()
        client = Client(cluster)

        # NOTE(review): assumes basenames look like "<tile_id>_..."; if tile IDs
        # themselves contain "_" only the first part is used -- confirm naming.
        # Empty IDs (from "folder/" marker keys) are dropped: '' would match every key.
        tile_ids = sorted({
            os.path.basename(key).split('_')[0]
            for key in list_s3_files(s3_bucket, input_prefix)
        } - {''})
        if not tile_ids:
            logging.info(f"No rasters found under s3://{s3_bucket}/{input_prefix}")
            return

        dask_tiles = [dask.delayed(merge_tiles)(tile_id, input_prefix, output_prefix) for tile_id in tile_ids]
        # dask.diagnostics.ProgressBar only reports on local schedulers; with a
        # distributed Client active, track the submitted futures instead.
        futures = client.compute(dask_tiles)
        progress(futures, notebook=False)
        client.gather(futures)  # raises if any task failed
    finally:
        cleanup()

In [None]:
if __name__ == "__main__":
    # Source prefix with the per-tile rasters to merge (replace with your input prefix).
    input_prefix = 'climate/AFOLU_flux_model/organic_soils/outputs/soil/2020/8000_pixels/20240603/'
    # Destination prefix for the merged tiles (replace with your desired output prefix).
    output_prefix = 'climate/AFOLU_flux_model/organic_soils/outputs/soil/2020/10x10_degrees/'
    main(input_prefix=input_prefix, output_prefix=output_prefix)