## Create the biomass map for the Mexico region

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import glob
import os
import pickle
import random

import dask
import dask.array as da
import geopandas
import learn2map.raster_tools as rt
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import rioxarray
import tensorflow as tf
import xarray as xr
from a2105_global100.ml_functions import (
    data_cnn_prediction,
    densenet_model,
    density_scatter_plot,
    load_tf,
    regression_tfrecord_from_df,
    residual2_box_plot,
    residual_box_plot,
)
from geocube.api.core import make_geocube
from joblib import Parallel, delayed
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Report how many GPU devices TensorFlow can see.
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices("GPU")))

# Set the GDAL_MAX_DATASET_POOL_SIZE environment variable for this process.
# NOTE(review): presumably raises GDAL's limit on simultaneously open datasets
# for the large VRT mosaics used below — confirm against the GDAL docs.
os.environ["GDAL_MAX_DATASET_POOL_SIZE"] = "999"


# Make the GEDI 25 m folder the working directory; relative file names written
# later (e.g. "merge_list.txt" in build_vrt_mosaic) resolve against it.
os.chdir("/Volumes/STVHD/Mexico_AGB_30m/data/gedi_25m/")

Num GPUs Available:  1


## Define the helper functions

In [2]:
def raster_adding_yxgrid(
    darr,
    y_lat,
    x_lon,
    block_info=None,
):
    """Append a latitude grid and a longitude grid to one block as two extra bands.

    Reads ``block_info[0]["array-location"]`` to find which (start, stop) index
    range of the full array this block covers along axes 1 and 2, slices
    ``y_lat`` / ``x_lon`` accordingly and broadcasts them to 2-D grids.

    Args:
        darr: block of shape (band, y, x).
        y_lat: 1-D coordinate values for the full y axis.
        x_lon: 1-D coordinate values for the full x axis.
        block_info: mapping whose entry ``0`` holds an ``"array-location"``
            list of (start, stop) pairs, one per axis. Must not be None.

    Returns:
        Array of shape (band + 2, y, x): the input bands, then the y grid,
        then the x grid.
    """
    location = block_info[0]["array-location"]
    row_span = location[1]
    col_span = location[2]

    lon_slice = x_lon[col_span[0] : col_span[1]]
    lat_slice = y_lat[row_span[0] : row_span[1]]

    # meshgrid returns (x varying along columns, y varying along rows).
    lon_grid, lat_grid = np.meshgrid(lon_slice, lat_slice)

    return np.concatenate(
        [darr, lat_grid[np.newaxis, :, :], lon_grid[np.newaxis, :, :]],
        axis=0,
    )


def mosaic_adding_yxgrid(xarr0, chunk_size=3000):
    """Return a copy of ``xarr0`` with "y_grid" and "x_grid" bands appended.

    The array is rechunked so that every block holds all bands and at most
    ``chunk_size`` x ``chunk_size`` pixels, then ``raster_adding_yxgrid`` is
    mapped over the blocks with the full y / x coordinate vectors.

    Args:
        xarr0: DataArray with dims (band, y, x) and ``band``, ``y``, ``x``
            coordinates.
        chunk_size: block edge length along y and x.

    Returns:
        DataArray with two more bands than the input; the band coordinate is
        the input band labels followed by "y_grid" and "x_grid".
    """
    n_bands = xarr0.shape[0]
    rechunked = xarr0.chunk(chunks=(n_bands, chunk_size, chunk_size))

    # Output blocks keep the spatial chunking but gain the two grid bands.
    stacked = rechunked.data.map_blocks(
        raster_adding_yxgrid,
        rechunked.y.values,
        rechunked.x.values,
        chunks=(n_bands + 2, rechunked.chunks[1], rechunked.chunks[2]),
        meta=np.array(()),
    )

    band_labels = rechunked.band.values.tolist() + ["y_grid", "x_grid"]
    return xr.DataArray(
        stacked,
        coords=[
            ("band", band_labels),
            ("y", rechunked.y.values),
            ("x", rechunked.x.values),
        ],
        dims=rechunked.dims,
    )


def extract_training_lst(
    i,
    tiles0,
    tiles1,
    scale=1,  # scale is the size ratio of tiles1 vs. tiles0
    width=1,
    valid_range=(0, 9999),
    mask_band=0,
    target_bands=[0],
    out_path=None,
):
    """Extract (feature window, target) samples from one tile pair and pickle them.

    Every pixel of ``tiles0[mask_band]`` strictly inside ``valid_range`` is a
    sample. Its target is the ``target_bands`` values of ``tiles0`` at that
    pixel; its features are a ``width`` x ``width`` window cut from ``tiles1``.

    Args:
        i: tile index; used only to build the output file name.
        tiles0: array (band, y, x) holding the mask band and the target bands.
        tiles1: array (band, y, x) holding the predictor bands.
            NOTE(review): the window offsets assume ``tiles1`` was padded by
            ``width`` pixels per side, as done in ``get_layers_xy`` — confirm.
        scale: size ratio of ``tiles1`` vs. ``tiles0``.
        width: edge length of the feature window.
        valid_range: exclusive (low, high) bounds applied to the mask band.
        mask_band: index of the band in ``tiles0`` used to select valid pixels.
        target_bands: band indices of ``tiles0`` returned as targets (not mutated).
        out_path: output path prefix; the file written is ``f"{out_path}_{i}"``.

    Returns:
        The path of the pickle file holding ``[X_lst, y_lst]``, or ``None`` when
        the tile has no valid pixel (no file is written in that case).
        ``X_lst`` has shape (n, window_y, window_x, n_feature_bands) and
        ``y_lst`` has shape (n, n_target_bands).
    """
    dims0 = tiles0.shape
    dims1 = tiles1.shape

    mask_arr = tiles0[mask_band, :, :]
    valid_mask = (mask_arr > valid_range[0]) & (mask_arr < valid_range[1])
    # Flat (row-major) indices of the valid pixels.
    valid_idx = np.flatnonzero(valid_mask)
    if len(valid_idx) == 0:
        return None

    sample0_lst = []
    sample1_lst = []
    # The loop variable must not be called `i`: that would overwrite the tile
    # index and make the output file name depend on the last valid pixel.
    for flat_idx in valid_idx:
        row, col = np.unravel_index(flat_idx, (dims0[1], dims0[2]))
        target_px = tiles0[
            :,
            row : row + 1,
            col : col + 1,
        ]
        sample0_lst.append(np.transpose(target_px, [1, 2, 0]))  # to [y, x, band]
        window = tiles1[
            :,
            row * scale + width // 2 : row * scale + width * 3 // 2,
            col * scale + width // 2 : col * scale + width * 3 // 2,
        ]
        sample1_lst.append(np.transpose(window, [1, 2, 0]))  # to [y, x, band]
    valid_target = np.stack(sample0_lst, axis=0)
    valid_sample = np.stack(sample1_lst, axis=0)

    # Boolean selector over the bands of tiles0; when nothing matches, y_lst is
    # an (n, 0) array instead of an unbound name.
    y_bands = np.isin(np.arange(dims0[0]), target_bands)
    y_lst = valid_target[:, 0, 0, y_bands]

    X_lst = valid_sample[:, :, :, np.arange(dims1[0])]

    out_file = f"{out_path}_{i}"
    # Context manager so the handle is flushed and closed even on error.
    with open(out_file, "wb") as fh:
        pickle.dump([X_lst, y_lst], fh)
    return out_file


def get_layers_xy(
    xarr0,
    xarr1,
    scale,
    width,
    mask_band,
    target_bands,
    valid_range=(0, 9999),
    out_path=None,
):
    """Build one delayed ``extract_training_lst`` task per block.

    ``xarr0`` and ``xarr1`` must share the same chunk layout, because their
    blocks are paired by position in the flattened block list.

    Args:
        xarr0: array wrapper whose ``.data`` is a dask array (targets / mask).
        xarr1: array wrapper whose ``.data`` is a dask array (predictors);
            its blocks are padded by ``width`` pixels along the last two axes.
        scale, width, mask_band, target_bands, valid_range, out_path:
            forwarded unchanged to ``extract_training_lst``.

    Returns:
        List of dask delayed objects, one per block, in flattened block order.
    """
    # Make sure xarr0/xarr1 have the same chunks
    target_blocks = (
        da.overlap.overlap(xarr0.data, depth=(0, 0, 0), boundary="nearest")
        .to_delayed()
        .ravel()
    )
    feature_blocks = (
        da.overlap.overlap(xarr1.data, depth=(0, width, width), boundary="nearest")
        .to_delayed()
        .ravel()
    )

    forwarded = dict(
        scale=scale,
        width=width,
        valid_range=valid_range,
        mask_band=mask_band,
        target_bands=target_bands,
        out_path=out_path,
    )

    tasks = []
    for block_id in range(len(target_blocks)):
        task = dask.delayed(extract_training_lst)(
            block_id,
            target_blocks[block_id],
            feature_blocks[block_id],
            **forwarded,
        )
        tasks.append(task)
    return tasks


def build_vrt_mosaic(in_file, out_file):
    """Build a VRT mosaic from a list of rasters with ``gdalbuildvrt``.

    Writes the input paths, one per line, to "merge_list.txt" in the current
    working directory, runs ``gdalbuildvrt -overwrite`` on that list and prints
    both the command and the exit status returned by ``os.system``.

    Args:
        in_file: iterable of input raster paths.
        out_file: path of the VRT to create.
    """
    import shlex  # local import: only needed to quote the shell arguments

    in_fname = "merge_list.txt"
    with open(in_fname, "w") as f:
        f.writelines(f"{path}\n" for path in in_file)

    # Quote both paths so spaces or shell metacharacters cannot split or alter
    # the command; plain paths are left unchanged by shlex.quote (POSIX shells).
    command = (
        "gdalbuildvrt -overwrite "
        f"-input_file_list {shlex.quote(in_fname)} {shlex.quote(str(out_file))}"
    )
    print(command)
    output = os.system(command)
    print(output)

## Open the GeoTIFFs and save them as Zarr stores

In [20]:


# Input rasters, in order: five HLS bands, two ALOS polarisations, one GEDI layer.
file_lst = [
    "HLS_2020_red_30m_mex.tif",
    "HLS_2020_nir_30m_mex.tif",
    "HLS_2020_swir1_30m_mex.tif",
    "HLS_2020_swir2_30m_mex.tif",
    "HLS_2020_vi_30m_mex.tif",
    "ALOS_HH_2019_2021_30m.tif",
    "ALOS_HV_2019_2021_30m.tif",
    "GEDI_25m_CH98.tif",
]
# Absolute paths per source; the slices must stay in sync with file_lst above.
new_lst0 = [f"/Volumes/STVHD/Mexico_AGB_30m/data/hls_30m/{file}" for file in file_lst[0:5]]
new_lst1 = [f"/Volumes/STVHD/Mexico_AGB_30m/data/alos_30m/{file}" for file in file_lst[5:7]]
new_lst2 = [f"/Volumes/STVHD/Mexico_AGB_30m/data/gedi_25m/{file}" for file in file_lst[7:8]]
new_lst = [new_lst0, new_lst1, new_lst2]
# Bare expression in the middle of the cell: it is evaluated but not displayed.
new_lst

# Output name stems, index-aligned with new_lst.
name_lst = [
    "hls_30m_mx",
    "alos_30m_mx",
    "GEDI_25m_mx",
]
# range(1,2) processes index 1 only, i.e. the ALOS group; widen the range to
# build the HLS (0) and GEDI (2) stacks as well.
for ii in range(1,2):
    path_vrt = f"/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_{name_lst[ii]}.vrt"
    in_fname = f"/Volumes/STVHD/Mexico_AGB_30m/test/merge_list_{name_lst[ii]}.txt"
    # Write the group's raster paths, one per line, for build_stack_vrt to read.
    with open(in_fname, "w") as f:
        f.writelines("%s\n" % i for i in new_lst[ii])
    # NOTE(review): judging by the recorded output, the return value is the list
    # of band names of the stacked VRT — confirm against learn2map.raster_tools.
    band_name = rt.build_stack_vrt(in_fname, path_vrt)

['ALOS_HH_2019_2021_30m_b1', 'ALOS_HV_2019_2021_30m_b1']


In [4]:
# First argument handed to rt.raster_clip (the GEDI 25 m raster).
# NOTE(review): despite the name `out_file`, this is an input; going by the
# gdalwarp command echoed below, it appears to supply the target extent and
# size onto which the DEM is warped — confirm against learn2map.raster_tools.
out_file = "/Volumes/STVHD/Mexico_AGB_30m/data/gedi_25m/GEDI_25m_CH98.tif"
# Source DEM, read through GDAL's /vsis3/ virtual file system.
path_dem = "/vsis3/ctrees-input-data/digital_elevation_model/copernicus_GLO30/vrt/Copernicus_DEM_30m.vrt"
# Third argument is the destination path of the clipped DEM.
rt.raster_clip(out_file, path_dem, "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx.vrt")

gdalwarp -s_srs "+proj=longlat +datum=WGS84 +no_defs" -t_srs "+proj=longlat +datum=WGS84 +no_defs" -te -117.41923994490472 14.17339397401678 -86.61331390659396 32.72697327340035 -ts 137172 82615 -srcnodata nan -dstnodata nan -overwrite -multi -co COMPRESS=DEFLATE -co ZLEVEL=1 -co PREDICTOR=2 -co BIGTIFF=YES -r near -ot Float32 "/vsis3/ctrees-input-data/digital_elevation_model/copernicus_GLO30/vrt/Copernicus_DEM_30m.vrt" "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx.vrt"


In [5]:
'''
rt.raster_clip(
    "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_hls_30m_mx.vrt",
    "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx.vrt",
    "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx_rs.vrt",
)

da1 = rioxarray.open_rasterio(
    "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx.vrt", chunck=(1, 1000, 1000), lock=False
).astype("float32")
da1 = da1.fillna(np.nan)
da1.rio.write_crs("EPSG:4326", inplace=True)
ds = xr.Dataset({"da": da1})
ds.to_zarr("/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx.zarr", mode="w")

'''

# Open the clipped DEM lazily in (1, 1000, 1000) chunks and cast to float32.
da1 = rioxarray.open_rasterio(
    "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx.vrt",
    chunks=(1, 1000, 1000),
    lock=False
).astype("float32")

# NOTE(review): fillna(np.nan) replaces NaN with NaN, so the values are left as
# they were; presumably some other nodata handling was intended — confirm or drop.
da1 = da1.fillna(np.nan)

# Record EPSG:4326 as the CRS on the array (modifies da1 in place).
da1.rio.write_crs("EPSG:4326", inplace=True)

# Wrap the array in a Dataset under the variable name "da", which is the name
# the later cells read back with xr.open_zarr(...)["da"].
ds = xr.Dataset({"da": da1})

# Write the Zarr store; mode="w" overwrites an existing store at that path.
ds.to_zarr("/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx.zarr", mode="w")

print("Dataset saved successfully.")



ERROR 1: PROJ: internal_proj_create_from_database: /Users/stephaniegeorge/miniforge3/envs/ctrees-dev/share/proj/proj.db contains DATABASE.LAYOUT.VERSION.MINOR = 0 whereas a number >= 2 is expected. It comes from another PROJ installation.


Dataset saved successfully.


Process ALOS

Note: This section prepares the ALOS layers in Zarr format.
The native ALOS resolution is 25 m.
Step 1: resample to 30 m.
Step 2: save a VRT with the stacked ALOS HH and HV bands.
Step 3: save the Zarr dataset.

In [6]:
import rioxarray

input_path = "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_alos_30m_mx.vrt"

# The raster gets its own name: binding it to `da` would overwrite the
# `dask.array as da` alias imported at the top of the notebook, which
# get_layers_xy needs for da.overlap.overlap.
try:
    alos_vrt = rioxarray.open_rasterio(input_path, chunks=(1, 1000, 1000))
    print("Raster file read successfully.")
except Exception as e:
    print(f"Failed to read raster file: {e}")
    raise  # there is nothing valid to display below if the open failed
alos_vrt

Raster file read successfully.


Unnamed: 0,Array,Chunk
Bytes,59.64 GiB,3.81 MiB
Shape,"(2, 69334, 115449)","(1, 1000, 1000)"
Dask graph,16240 chunks in 2 graph layers,16240 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 59.64 GiB 3.81 MiB Shape (2, 69334, 115449) (1, 1000, 1000) Dask graph 16240 chunks in 2 graph layers Data type float32 numpy.ndarray",115449  69334  2,

Unnamed: 0,Array,Chunk
Bytes,59.64 GiB,3.81 MiB
Shape,"(2, 69334, 115449)","(1, 1000, 1000)"
Dask graph,16240 chunks in 2 graph layers,16240 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


Resample ALOS rasters

Make directory

In [14]:
import os

# Folder that receives the resampled 30 m ALOS rasters; exist_ok=True makes the
# cell safe to re-run when the folder is already there.
output_dir = "/Volumes/STVHD/Mexico_AGB_30m/data/alos_30m"
os.makedirs(output_dir, exist_ok=True)




Resample with gdalwarp

In [17]:
import subprocess
import logging

# Configure logging: INFO level on the root logger, plus a module-level logger
# that later cells reuse.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 25 m inputs and their 30 m outputs; the two lists are paired by position.
input_paths = [
    "/Volumes/STVHD/Mexico_AGB_30m/data/alos_25m/ALOS_HH_2019_2021_25m.tif",
    "/Volumes/STVHD/Mexico_AGB_30m/data/alos_25m/ALOS_HV_2019_2021_25m.tif"
]
output_paths = [
    "/Volumes/STVHD/Mexico_AGB_30m/data/alos_30m/ALOS_HH_2019_2021_30m.tif",
    "/Volumes/STVHD/Mexico_AGB_30m/data/alos_30m/ALOS_HV_2019_2021_30m.tif"
]

# Resample each raster with gdalwarp. Only CalledProcessError is handled, so a
# failing file is logged and the loop continues with the next one.
for input_path, output_path in zip(input_paths, output_paths):
    try:
        # Resample the raster to 30m resolution using gdalwarp
        logger.info(f"Resampling raster to 30m resolution: {input_path}")
        # Arguments are passed as a list (no shell), so paths need no quoting.
        # NOTE(review): no -overwrite flag is given — confirm how gdalwarp
        # behaves when the output file already exists from an earlier run.
        subprocess.run([
            "gdalwarp",
            "-tr", "0.0002695", "0.0002695",  # Target pixel size in degrees (approximately 30 m)
            "-r", "near",  # Nearest-neighbour resampling; NOTE(review): confirm this is intended for continuous data
            "-of", "GTiff",  # Output format
            input_path,
            output_path
        ], check=True)  # check=True raises CalledProcessError on a non-zero exit status

        logger.info(f"Raster resampled and saved successfully: {output_path}")

    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to resample or save raster file {input_path}: {e}")



INFO:__main__:Resampling raster to 30m resolution: /Volumes/STVHD/Mexico_AGB_30m/data/alos_25m/ALOS_HH_2019_2021_25m.tif


Creating output file that is 115449P x 69334L.
Processing /Volumes/STVHD/Mexico_AGB_30m/data/alos_25m/ALOS_HH_2019_2021_25m.tif [1/1] : 0Using internal nodata values (e.g. -3.4e+38) for image /Volumes/STVHD/Mexico_AGB_30m/data/alos_25m/ALOS_HH_2019_2021_25m.tif.
Copying nodata values from source /Volumes/STVHD/Mexico_AGB_30m/data/alos_25m/ALOS_HH_2019_2021_25m.tif to destination /Volumes/STVHD/Mexico_AGB_30m/data/alos_30m/ALOS_HH_2019_2021_30m.tif.
...10...20...30...40...50...60...70...80...90...

INFO:__main__:Raster resampled and saved successfully: /Volumes/STVHD/Mexico_AGB_30m/data/alos_30m/ALOS_HH_2019_2021_30m.tif
INFO:__main__:Resampling raster to 30m resolution: /Volumes/STVHD/Mexico_AGB_30m/data/alos_25m/ALOS_HV_2019_2021_25m.tif


100 - done.
Creating output file that is 115449P x 69334L.
Processing /Volumes/STVHD/Mexico_AGB_30m/data/alos_25m/ALOS_HV_2019_2021_25m.tif [1/1] : 0Using internal nodata values (e.g. -3.4e+38) for image /Volumes/STVHD/Mexico_AGB_30m/data/alos_25m/ALOS_HV_2019_2021_25m.tif.
Copying nodata values from source /Volumes/STVHD/Mexico_AGB_30m/data/alos_25m/ALOS_HV_2019_2021_25m.tif to destination /Volumes/STVHD/Mexico_AGB_30m/data/alos_30m/ALOS_HV_2019_2021_30m.tif.
...10...20...30...40...50...60...70...80...90...

INFO:__main__:Raster resampled and saved successfully: /Volumes/STVHD/Mexico_AGB_30m/data/alos_30m/ALOS_HV_2019_2021_30m.tif


100 - done.


Save the ALOS VRT to Zarr (if this step crashes the kernel, run it in a separate script)

In [3]:

# Test opening a raster.
# The result is bound to its own name: assigning it to `da` would overwrite the
# `dask.array as da` alias imported at the top of the notebook, which
# get_layers_xy needs for da.overlap.overlap.
try:
    alos_vrt_probe = rioxarray.open_rasterio("/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_alos_30m_mx.vrt", chunks={'x': 1000, 'y': 1000})
    print("Raster opened successfully.")
except Exception as e:
    print(f"Failed to open raster: {e}")


Raster opened successfully.


In [3]:
def process_raster(input_path, output_path):
    """Convert a raster into a Zarr store holding one variable named "da".

    The raster is opened with 1000 x 1000 chunks along x and y; if it carries
    no CRS, EPSG:4326 is written onto it. Any exception is caught and reported
    with ``print`` — nothing is raised to the caller.

    Args:
        input_path: path of the raster to read.
        output_path: path of the Zarr store to write (overwritten, mode="w").

    Returns:
        None, both on success and on failure.
    """
    try:
        print("Opening raster file.")
        raster = rioxarray.open_rasterio(input_path, chunks={'x': 1000, 'y': 1000})

        # Only set a CRS when the file does not provide one.
        crs_missing = not raster.rio.crs
        if crs_missing:
            raster.rio.write_crs("EPSG:4326", inplace=True)

        dataset = xr.Dataset({"da": raster})

        print("Saving dataset to Zarr.")
        dataset.to_zarr(output_path, mode="w")

        print("Dataset saved successfully.")
    except Exception as err:
        print(f"Failed to process raster file: {err}")

# Source VRT (stacked ALOS bands) and the Zarr store to create from it.
input_path = "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_alos_30m_mx.vrt"
output_path = "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_alos_30m_mx.zarr"

# Run the conversion. process_raster prints errors instead of raising, so read
# the printed messages to see whether the store was actually written.
process_raster(input_path, output_path)


Opening raster file.
Saving dataset to Zarr.


: 

In [10]:
# Inspection cell: relies on `ds` and `da1` left in the kernel by an earlier
# cell. The recorded output is a NameError for `ds`, i.e. the cell was run in a
# kernel where that earlier cell had not been executed.
ds
da1

NameError: name 'ds' is not defined

Alternative to process ALOS

In [None]:
'''
import os
import rasterio
import xarray as xr
import rioxarray
import tempfile
import shutil
from dask.diagnostics import ProgressBar
'''

In [None]:
'''
def is_s3_path(path):
    return path.startswith('/vsis3/')

def download_to_tempfile(vrt_path, temp_dir):
    local_vrt_path = os.path.join(temp_dir, "local_copy.vrt")
    with rasterio.Env():
        with rasterio.open(vrt_path) as src:
            profile = src.profile
            with rasterio.open(local_vrt_path, 'w', **profile) as dst:
                dst.write(src.read())
    return local_vrt_path

def calculate_bounds_for_stratum(vrt_path, i, j):
    with rasterio.open(vrt_path) as src:
        bounds = src.bounds
        width = (bounds.right - bounds.left) / 3
        height = (bounds.top - bounds.bottom) / 3
        left = bounds.left + i * width
        right = left + width
        bottom = bounds.bottom + j * height
        top = bottom + height
        return (left, bottom, right, top)

def process_stratum(input_vrt_path, bounds, output_zarr_path):
    try:
        with rasterio.open(input_vrt_path) as src:
            window = rasterio.windows.from_bounds(*bounds, transform=src.transform)
            data = src.read(window=window)
        
        stratum = xr.DataArray(
            data,
            dims=["band", "y", "x"],
            coords={
                "band": [1],
                "y": range(data.shape[1]),
                "x": range(data.shape[2])
            }
        )
        stratum.rio.write_crs("EPSG:4326", inplace=True)
        
        # Rechunking to ensure uniform chunk sizes
        stratum = stratum.chunk({'band': 1, 'x': 1000, 'y': 1000})

        ds = xr.Dataset({"da": stratum})
        
        # Using dask to_zarr with progress bar for better tracking
        with ProgressBar():
            ds.to_zarr(output_zarr_path, mode="w")
    except rasterio.errors.RasterioIOError as e:
        print(f"Error processing stratum with bounds {bounds}: {e}")

def process_and_save_chunks(input_vrt_path, output_zarr_base_path):
    temp_dir = None

    try:
        # Check if the VRT path is an S3 path
        if is_s3_path(input_vrt_path):
            # Create a temporary directory
            temp_dir = tempfile.mkdtemp()
            # Download the VRT to the temporary directory
            local_vrt_path = download_to_tempfile(input_vrt_path, temp_dir)
        else:
            local_vrt_path = input_vrt_path

        for i in range(3):
            for j in range(3):
                stratum_bounds = calculate_bounds_for_stratum(local_vrt_path, i, j)
                stratum_output_path = f"{output_zarr_base_path}{i}_{j}.zarr"
                print(f"Processing stratum {i}_{j} with bounds: {stratum_bounds}")
                
                process_stratum(local_vrt_path, stratum_bounds, stratum_output_path)
                print(f"Stratum {i}_{j} saved to {stratum_output_path}")
    
    finally:
        if temp_dir:
            # Clean up the temporary directory
            shutil.rmtree(temp_dir)
'''

In [None]:
'''
# Example usage:
process_and_save_chunks(input_vrt_dem_path, output_zarr_dem_base_path)
'''

Additional step: run the following cells if some tiles failed

1/3

In [None]:
'''
import os
import rasterio
import xarray as xr
import rioxarray
import tempfile
import shutil
from dask.diagnostics import ProgressBar
from time import sleep
'''

2/3

In [None]:
'''
def is_s3_path(path):
    return path.startswith('/vsis3/')

def download_to_tempfile(vrt_path, temp_dir):
    local_vrt_path = os.path.join(temp_dir, "local_copy.vrt")
    with rasterio.Env():
        with rasterio.open(vrt_path) as src:
            profile = src.profile
            with rasterio.open(local_vrt_path, 'w', **profile) as dst:
                dst.write(src.read())
    return local_vrt_path

def calculate_bounds_for_stratum(vrt_path, i, j):
    with rasterio.open(vrt_path) as src:
        bounds = src.bounds
        width = (bounds.right - bounds.left) / 3
        height = (bounds.top - bounds.bottom) / 3
        left = bounds.left + i * width
        right = left + width
        bottom = bounds.bottom + j * height
        top = bottom + height
        return (left, bottom, right, top)

def process_stratum(input_vrt_path, bounds, output_zarr_path, retries=3):
    attempt = 0
    while attempt < retries:
        try:
            with rasterio.open(input_vrt_path) as src:
                window = rasterio.windows.from_bounds(*bounds, transform=src.transform)
                data = src.read(window=window)
            
            stratum = xr.DataArray(
                data,
                dims=["band", "y", "x"],
                coords={
                    "band": [1],
                    "y": range(data.shape[1]),
                    "x": range(data.shape[2])
                }
            )
            stratum.rio.write_crs("EPSG:4326", inplace=True)
            
            # Rechunking to ensure uniform chunk sizes
            stratum = stratum.chunk({'band': 1, 'x': 1000, 'y': 1000})

            ds = xr.Dataset({"da": stratum})
            
            # Using dask to_zarr with progress bar for better tracking
            with ProgressBar():
                ds.to_zarr(output_zarr_path, mode="w")
            
            # If successful, break out of the retry loop
            return
        except rasterio.errors.RasterioIOError as e:
            print(f"Error processing stratum with bounds {bounds}: {e}")
            attempt += 1
            sleep(2)  # Wait a bit before retrying
            if attempt >= retries:
                print(f"Failed to process stratum with bounds {bounds} after {retries} attempts")

def process_missing_stratum(input_vrt_path, output_zarr_base_path, missing_stratum):
    temp_dir = None

    try:
        # Check if the VRT path is an S3 path
        if is_s3_path(input_vrt_path):
            # Create a temporary directory
            temp_dir = tempfile.mkdtemp()
            # Download the VRT to the temporary directory
            local_vrt_path = download_to_tempfile(input_vrt_path, temp_dir)
        else:
            local_vrt_path = input_vrt_path

        i, j = missing_stratum
        stratum_bounds = calculate_bounds_for_stratum(local_vrt_path, i, j)
        stratum_output_path = f"{output_zarr_base_path}{i}_{j}.zarr"
        print(f"Processing stratum {i}_{j} with bounds: {stratum_bounds}")
        
        process_stratum(local_vrt_path, stratum_bounds, stratum_output_path)
        print(f"Stratum {i}_{j} saved to {stratum_output_path}")
    
    finally:
        if temp_dir:
            # Clean up the temporary directory
            shutil.rmtree(temp_dir)
'''

3/3: Define missing stratum here

In [None]:
'''
# Define the missing stratum to process (e.g., stratum 2_2)
missing_stratum = (2, 2)

# Example usage:
process_missing_stratum(input_vrt_dem_path, output_zarr_dem_base_path, missing_stratum)
'''

Save alos in zarr in strata

In [None]:
'''
import os
import rasterio
import xarray as xr
import rioxarray
from dask.diagnostics import ProgressBar
'''

In [None]:
'''
def calculate_bounds_for_stratum(vrt_path, i, j):
    with rasterio.open(vrt_path) as src:
        bounds = src.bounds
        width = (bounds.right - bounds.left) / 3
        height = (bounds.top - bounds.bottom) / 3
        left = bounds.left + i * width
        right = left + width
        bottom = bounds.bottom + j * height
        top = bottom + height
        return (left, bottom, right, top)

def process_stratum(input_vrt_path, bounds, output_zarr_path):
    try:
        with rasterio.open(input_vrt_path) as src:
            window = rasterio.windows.from_bounds(*bounds, transform=src.transform)
            data = src.read(window=window)
            num_bands = src.count  # Get the number of bands
        
        stratum = xr.DataArray(
            data,
            dims=["band", "y", "x"],
            coords={
                "band": range(1, num_bands + 1),  # Set correct band coordinates
                "y": range(data.shape[1]),
                "x": range(data.shape[2])
            }
        )
        stratum.rio.write_crs("EPSG:4326", inplace=True)
        
        # Rechunking to ensure uniform chunk sizes
        stratum = stratum.chunk({'band': 1, 'x': 1000, 'y': 1000})

        ds = xr.Dataset({"da": stratum})
        
        # Using dask to_zarr with progress bar for better tracking
        with ProgressBar():
            ds.to_zarr(output_zarr_path, mode="w")
    except rasterio.errors.RasterioIOError as e:
        print(f"Error processing stratum with bounds {bounds}: {e}")

def process_and_save_chunks(input_vrt_path, output_zarr_base_path):
    for i in range(3):
        for j in range(3):
            stratum_bounds = calculate_bounds_for_stratum(input_vrt_path, i, j)
            stratum_output_path = f"{output_zarr_base_path}{i}_{j}.zarr"
            print(f"Processing stratum {i}_{j} with bounds: {stratum_bounds}")
            
            process_stratum(input_vrt_path, stratum_bounds, stratum_output_path)
            print(f"Stratum {i}_{j} saved to {stratum_output_path}")
'''

In [None]:
'''
# Example usage:
input_vrt_alos_path = '/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_alos_30m_mx.vrt'
output_zarr_alos_base_path = '/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_alos_30m_mx_stratum_'
process_and_save_chunks(input_vrt_alos_path, output_zarr_alos_base_path)
'''

In [2]:

# Path of the ALOS Zarr store (the variable is named path_tif, but it points at
# a .zarr directory).
path_tif = '/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_alos_30m_mx.zarr'
# Read back the "da" variable written by the conversion cell.
# NOTE(review): the recorded output shows shape (2, 1, 1), far smaller than the
# (2, 69334, 115449) VRT opened earlier — the store looks incomplete; confirm.
dax = xr.open_zarr(path_tif)["da"]
dax

Unnamed: 0,Array,Chunk
Bytes,8 B,4 B
Shape,"(2, 1, 1)","(1, 1, 1)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 8 B 4 B Shape (2, 1, 1) (1, 1, 1) Dask graph 2 chunks in 2 graph layers Data type float32 numpy.ndarray",1  1  2,

Unnamed: 0,Array,Chunk
Bytes,8 B,4 B
Shape,"(2, 1, 1)","(1, 1, 1)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


Save hls zarr in strata

In [None]:
'''
import os
import tempfile
import shutil
import rasterio
import xarray as xr
import rioxarray
from dask.diagnostics import ProgressBar
'''

In [None]:
'''

def calculate_bounds_for_stratum(vrt_path, i, j):
    with rasterio.open(vrt_path) as src:
        bounds = src.bounds
        width = (bounds.right - bounds.left) / 3
        height = (bounds.top - bounds.bottom) / 3
        left = bounds.left + i * width
        right = left + width
        bottom = bounds.bottom + j * height
        top = bottom + height
        return (left, bottom, right, top)

def process_stratum(input_vrt_path, bounds, output_zarr_path):
    try:
        with rasterio.open(input_vrt_path) as src:
            window = rasterio.windows.from_bounds(*bounds, transform=src.transform)
            data = src.read(window=window)
            num_bands = src.count  # Get the number of bands
        
        stratum = xr.DataArray(
            data,
            dims=["band", "y", "x"],
            coords={
                "band": range(1, num_bands + 1),  # Set correct band coordinates
                "y": range(data.shape[1]),
                "x": range(data.shape[2])
            }
        )
        stratum.rio.write_crs("EPSG:4326", inplace=True)
        
        # Rechunking to ensure uniform chunk sizes
        stratum = stratum.chunk({'band': 1, 'x': 1000, 'y': 1000})

        ds = xr.Dataset({"da": stratum})
        
        # Using dask to_zarr with progress bar for better tracking
        with ProgressBar():
            ds.to_zarr(output_zarr_path, mode="w")
    except rasterio.errors.RasterioIOError as e:
        print(f"Error processing stratum with bounds {bounds}: {e}")

def process_and_save_chunks(input_vrt_path, output_zarr_base_path):
    for i in range(3):
        for j in range(3):
            stratum_bounds = calculate_bounds_for_stratum(input_vrt_path, i, j)
            stratum_output_path = f"{output_zarr_base_path}{i}_{j}.zarr"
            print(f"Processing stratum {i}_{j} with bounds: {stratum_bounds}")
            
            process_stratum(input_vrt_path, stratum_bounds, stratum_output_path)
            print(f"Stratum {i}_{j} saved to {stratum_output_path}")


'''


In [None]:
'''
# Example usage:
input_vrt_hls_path = '/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_hls_30m_mx.vrt'
output_zarr_hls_base_path = '/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_hls_30m_mx_stratum_'
process_and_save_chunks(input_vrt_hls_path, output_zarr_hls_base_path)

'''


In [None]:
'''
rt.raster_clip(
    "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_hls_30m_mx.vrt",
    "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_GEDI_25m_mx.vrt",
    "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_GEDI_30m_mx_rs.vrt",
)

da1 = rioxarray.open_rasterio(
    "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_GEDI_30m_mx_rs.vrt", chunck=(1, 1000, 1000), lock=False
).astype("float32")
da1 = da1.fillna(np.nan)
da1.rio.write_crs("EPSG:4326", inplace=True)
ds = xr.Dataset({"da": da1})
ds.to_zarr("/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_GEDI_30m_mx.zarr", mode="w")
'''

In [5]:
import rioxarray
import xarray as xr
import numpy as np

# Source VRT (stacked HLS bands) and the Zarr store to create from it.
input_path = "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_hls_30m_mx.vrt"
output_path = "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_hls_30m_mx.zarr"

# Open the raster lazily in (1, 1000, 1000) chunks and cast to float32.
da1 = rioxarray.open_rasterio(
    input_path,
    chunks=(1, 1000, 1000),
    lock=False
).astype("float32")

# NOTE(review): fillna(np.nan) replaces NaN with NaN, so the values are left as
# they were; presumably some other nodata handling was intended — confirm or drop.
da1 = da1.fillna(np.nan)

# Record EPSG:4326 as the CRS on the array (modifies da1 in place).
da1.rio.write_crs("EPSG:4326", inplace=True)

# Wrap the array in a Dataset under the variable name "da", which is the name
# the later cells read back with xr.open_zarr(...)["da"].
ds = xr.Dataset({"da": da1})

# Write the Zarr store; mode="w" overwrites an existing store at that path.
# The recorded run of this cell ended in a KeyboardInterrupt.
ds.to_zarr(output_path, mode="w")

print("Dataset saved successfully.")



ERROR 1: PROJ: internal_proj_create_from_database: /Users/stephaniegeorge/miniforge3/envs/ctrees-dev/share/proj/proj.db contains DATABASE.LAYOUT.VERSION.MINOR = 0 whereas a number >= 2 is expected. It comes from another PROJ installation.


KeyboardInterrupt: 

In [None]:
'''
ds1 = xr.open_zarr("/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx.zarr")["da"].chunk(
    [1, 1000, 1000]
)
ds2 = xr.open_zarr("/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_hls_30m_mx.zarr")["da"].chunk(
    [1, 1000, 1000]
)
ds3 = xr.open_zarr("/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_alos_30m_mx.zarr")["da"].chunk(
    [1, 1000, 1000]
)
ds2["x"] = ds1.x.values
ds2["y"] = ds1.y.values
ds3["x"] = ds1.x.values
ds3["y"] = ds1.y.values

combined_ds = xr.concat([ds1, ds2, ds3], dim="band")
combined_ds = combined_ds.rio.write_crs("EPSG:4326")
combined_ds = combined_ds.chunk([-1, 1000, 10000])
combined_ds.attrs = {}
# Save the combined dataset back to a Zarr dataset
ds = xr.Dataset({"da": combined_ds})
del ds.da.encoding["chunks"]
ds.to_zarr("/Volumes/STVHD/Mexico_AGB_30m/test/xlayers_combine_30m.zarr", mode="w")
ds

ds

'''

In [16]:
import rioxarray
import xarray as xr
import numpy as np
import logging

# Configure logging: INFO level on the root logger, plus the module-level logger
# used by the helper functions and the processing block below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Zarr stores to combine. Order matters: the first entry (DEM) is used as the
# reference grid, and bands are concatenated in this order.
input_paths = [
    "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx.zarr",
    "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_hls_30m_mx.zarr",
    "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_alos_30m_mx.zarr"
]
# Destination of the combined predictor stack.
output_path = "/Volumes/STVHD/Mexico_AGB_30m/test/xlayers_combine_30m.zarr"

# CRS written to any layer that has none, and to the combined result.
target_crs = "EPSG:4326"

def verify_and_set_crs(ds, target_crs):
    """Return ``ds`` with a CRS, writing ``target_crs`` only when none is set.

    Args:
        ds: object exposing the rioxarray ``.rio`` accessor.
        target_crs: CRS to write when ``ds.rio.crs`` is empty.

    Returns:
        ``ds`` itself when it already has a CRS, otherwise the result of
        ``ds.rio.write_crs(target_crs)``.
    """
    if ds.rio.crs:
        return ds
    logger.info("Setting CRS.")
    return ds.rio.write_crs(target_crs)

def reproject_to_match(ds, reference_ds):
    """Reproject ``ds`` onto the reference grid when their x/y sizes differ.

    Only the lengths of the ``x`` and ``y`` coordinates are compared; the
    coordinate values themselves are not checked.

    Args:
        ds: array to align, exposing ``.x``, ``.y`` and the ``.rio`` accessor.
        reference_ds: array whose grid is the target.

    Returns:
        ``ds`` itself when both sizes already match, otherwise the result of
        ``ds.rio.reproject_match(reference_ds)``.
    """
    if ds.x.size == reference_ds.x.size and ds.y.size == reference_ds.y.size:
        return ds
    logger.info("Reprojecting to match dimensions.")
    return ds.rio.reproject_match(reference_ds)

# Combine the DEM, HLS and ALOS layers into one band-stacked Zarr store.
# Failures are logged (with traceback) rather than raised, so check the log to
# see whether the store was written.
try:
    # Open the raster files with appropriate chunking
    logger.info("Opening raster files.")
    datasets = [xr.open_zarr(path)["da"].chunk([1, 1000, 1000]) for path in input_paths]

    # Step 1: Verify and set CRS for each dataset
    datasets = [verify_and_set_crs(layer, target_crs) for layer in datasets]

    # Step 2: Reproject datasets if necessary; the first dataset is the
    # reference grid and is kept as it is.
    reference_ds = datasets[0]
    datasets = [reference_ds] + [
        reproject_to_match(layer, reference_ds) for layer in datasets[1:]
    ]

    # Step 3: Align dimensions by overwriting every layer's x/y coordinates
    # with the reference values.
    # NOTE(review): this assumes the grids already coincide pixel for pixel;
    # reproject_to_match only compares sizes, not coordinate values — confirm.
    for layer in datasets:
        layer["x"] = reference_ds.x.values
        layer["y"] = reference_ds.y.values

    # Step 4: Concatenate along the band dimension
    logger.info("Concatenating datasets.")
    combined_ds = xr.concat(datasets, dim="band")

    # Write the CRS to the combined dataset
    combined_ds = combined_ds.rio.write_crs(target_crs)

    # Chunk the dataset: all bands together, 1000 rows by 10000 columns.
    combined_ds = combined_ds.chunk([-1, 1000, 10000])
    combined_ds.attrs = {}

    # Save the combined dataset back to a Zarr dataset
    ds = xr.Dataset({"da": combined_ds})
    # Drop any chunk encoding inherited from the source stores so it does not
    # conflict with the new dask chunks; pop() tolerates a missing key, whereas
    # `del` raised KeyError and aborted the save.
    ds.da.encoding.pop("chunks", None)
    logger.info("Saving dataset to Zarr.")
    ds.to_zarr(output_path, mode="w")

    logger.info("Dataset saved successfully.")
except Exception as e:
    # logger.exception logs at ERROR level and keeps the traceback.
    logger.exception(f"Failed to process raster files: {e}")


INFO:__main__:Opening raster files.
INFO:__main__:Setting CRS for dataset 0.
ERROR 1: PROJ: internal_proj_create_from_database: /Users/stephaniegeorge/miniforge3/envs/ctrees-dev/share/proj/proj.db contains DATABASE.LAYOUT.VERSION.MINOR = 0 whereas a number >= 2 is expected. It comes from another PROJ installation.
INFO:__main__:Setting CRS for dataset 1.
ERROR 1: PROJ: internal_proj_create_from_database: /Users/stephaniegeorge/miniforge3/envs/ctrees-dev/share/proj/proj.db contains DATABASE.LAYOUT.VERSION.MINOR = 0 whereas a number >= 2 is expected. It comes from another PROJ installation.
INFO:__main__:Reprojecting dataset 1 to match dimensions.


: 

Step 1: Load and concatenate the data

In [11]:


def load_and_concat_strata(base_path, variable_name, num_strata):
    """Open every ``<variable_name>_stratum_<i>_<j>.zarr`` store and concatenate them.

    ``i`` and ``j`` both run over ``range(num_strata)``. A store that fails to
    open is reported on stdout and skipped.

    Args:
        base_path: Directory that holds the per-stratum Zarr stores.
        variable_name: Filename prefix shared by the stores.
        num_strata: Number of stratum indices per axis.

    Returns:
        The result of ``xr.concat`` over the opened datasets along ``'concat_dim'``.

    Raises:
        ValueError: If no store could be opened.
        Exception: Whatever ``xr.concat`` raises (reported on stdout, then re-raised).
    """
    opened_stores = []
    stratum_ids = ((row, col) for row in range(num_strata) for col in range(num_strata))

    for row, col in stratum_ids:
        store = f"{base_path}/{variable_name}_stratum_{row}_{col}.zarr"
        try:
            stratum_ds = xr.open_zarr(store)
            print(f"Loaded {store}")
            opened_stores.append(stratum_ds)
        except Exception as err:
            # Best effort: keep going with whichever strata are readable.
            print(f"Error opening {store}: {err}")

    if len(opened_stores) == 0:
        raise ValueError(f"No datasets found or loaded for {variable_name}.")

    # NOTE(review): 'concat_dim' is passed as the dimension name; confirm this is
    # intended rather than stitching the strata along their spatial dimensions.
    try:
        return xr.concat(opened_stores, dim='concat_dim')
    except Exception as err:
        print(f"Error during concatenation: {err}")
        raise

# Example usage: load the DEM, HLS and ALOS strata (num_strata x num_strata
# stores per layer) found under base_path.
base_path = "/Volumes/STVHD/Mexico_AGB_30m/test"
num_strata = 3

# NOTE(review): a failure is only printed here, so ds_dem / ds_hls / ds_alos can
# be left undefined and the cells below will then fail with NameError.
try:
    ds_dem = load_and_concat_strata(base_path, "test_layers_dem_30m_mx", num_strata)
    ds_hls = load_and_concat_strata(base_path, "test_layers_hls_30m_mx", num_strata)
    ds_alos = load_and_concat_strata(base_path, "test_layers_alos_30m_mx", num_strata)
    print("Datasets loaded and concatenated successfully.")
except Exception as e:
    print(f"Error during loading and concatenating datasets: {e}")


Error opening /Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx_stratum_0_0.zarr: No such file or directory: '/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx_stratum_0_0.zarr'
Error opening /Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx_stratum_0_1.zarr: No such file or directory: '/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx_stratum_0_1.zarr'
Error opening /Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx_stratum_0_2.zarr: No such file or directory: '/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx_stratum_0_2.zarr'
Error opening /Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx_stratum_1_0.zarr: No such file or directory: '/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx_stratum_1_0.zarr'
Error opening /Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx_stratum_1_1.zarr: No such file or directory: '/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx_stratum_1_1.zarr'
Error opening /Volumes/STVHD/M

Step 2: Interpolation

In [None]:
def interpolate_datasets(ds_dem, ds_hls, ds_alos):
    """Resample the HLS and ALOS datasets onto the DEM's ``x``/``y`` coordinates.

    All three inputs are first rechunked to a single chunk along ``x`` and ``y``;
    HLS and ALOS are then interpolated with ``method="nearest"`` at the DEM
    coordinates.

    Returns:
        ``(ds_hls_interp, ds_alos_interp)``.

    Raises:
        Exception: Any failure is reported on stdout and re-raised.
    """
    # NOTE(review): -1 puts the full x/y extent into one chunk; confirm that
    # fits in memory for the full-size rasters.
    whole_plane = {'x': -1, 'y': -1}
    try:
        dem, hls, alos = (
            layer.chunk(whole_plane) for layer in (ds_dem, ds_hls, ds_alos)
        )
        resampled = tuple(
            layer.interp(x=dem.x, y=dem.y, method="nearest") for layer in (hls, alos)
        )
        print("Interpolation successful.")
        return resampled
    except Exception as err:
        print(f"Error during interpolation: {err}")
        raise

# Example usage: resample HLS and ALOS onto the DEM grid.
# NOTE(review): relies on ds_dem / ds_hls / ds_alos from the loading cell; the
# error is only printed, so ds_hls_interp / ds_alos_interp may stay undefined.
try:
    ds_hls_interp, ds_alos_interp = interpolate_datasets(ds_dem, ds_hls, ds_alos)
except Exception as e:
    print(f"Error during preprocessing or interpolation: {e}")


Step 3: Combine and save in chunks

In [None]:
import xarray as xr
import dask

def preprocess_dataset(ds, chunk_size_x=1000, chunk_size_y=1000):
    """Return ``ds`` rechunked to the given chunk sizes along ``x`` and ``y``."""
    spatial_chunks = dict(x=chunk_size_x, y=chunk_size_y)
    return ds.chunk(spatial_chunks)

def save_combined_dataset(ds, output_path):
    """Rechunk ``ds`` to 1000 x 1000 in ``x``/``y`` and write it to a Zarr store.

    The store at ``output_path`` is opened with ``mode='w'``. Any failure is
    reported on stdout and re-raised.

    NOTE(review): this notebook defines ``save_combined_dataset`` again in later
    cells; whichever definition ran last is the one in effect.
    """
    try:
        rechunked = preprocess_dataset(ds, chunk_size_x=1000, chunk_size_y=1000)

        dask_overrides = {'array.slicing.split_large_chunks': True}
        with dask.config.set(dask_overrides):
            rechunked.to_zarr(output_path, mode='w')

        print(f"Combined dataset saved successfully to {output_path}")
    except Exception as err:
        print(f"Error during saving the combined dataset: {err}")
        raise

def incremental_merge(datasets, output_path):
    """Merge ``datasets`` one at a time and save the result to ``output_path``.

    Every dataset is rechunked to 1000 x 1000 in ``x``/``y`` before it is merged
    into the running result with ``xr.merge(..., compat='override')``. The merged
    dataset is then handed to ``save_combined_dataset``.

    Raises:
        Exception: Any failure (including an empty ``datasets``) is reported on
            stdout and re-raised.
    """
    try:
        merged = preprocess_dataset(datasets[0], chunk_size_x=1000, chunk_size_y=1000)

        for position in range(1, len(datasets)):
            addition = preprocess_dataset(
                datasets[position], chunk_size_x=1000, chunk_size_y=1000
            )
            merged = xr.merge([merged, addition], compat='override')
            print("Dataset merged successfully.")

        save_combined_dataset(merged, output_path)

    except Exception as err:
        print(f"Error during incremental merging: {err}")
        raise

# Example usage: merge the DEM with the resampled HLS and ALOS layers and write
# the result to output_path.
# NOTE(review): needs ds_dem, ds_hls_interp and ds_alos_interp from the cells
# above; the error is only printed here.
output_path = "/Volumes/STVHD/Mexico_AGB_30m/test/xlayers_combine_30m.zarr"
try:
    datasets = [ds_dem, ds_hls_interp, ds_alos_interp]  # List of datasets to merge
    incremental_merge(datasets, output_path)
except Exception as e:
    print(f"Error during processing: {e}")


Load X-layers (in Zarr) and set up

In [None]:
# Open the "da" variable of the GEDI Zarr store and display it.
# NOTE(review): despite its name, path_tif points at a Zarr store, and it is
# reassigned by several later cells.
path_tif = "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_GEDI_30m_mx.zarr"
da0 = xr.open_zarr(path_tif)["da"]
da0


In [None]:
import xarray as xr
import dask

# Open the "da" variable of one Zarr store per stratum, using small chunks
def load_datasets(file_pattern, strata):
    """Open ``file_pattern.format(*stratum)`` for every entry of ``strata``.

    Each store's ``"da"`` variable is rechunked to ``[1, 1000, 1000]``. A store
    that fails to load is reported on stdout and skipped.

    Returns:
        List of the arrays that loaded successfully (possibly empty).
    """
    loaded = []
    store_paths = (file_pattern.format(*ids) for ids in strata)
    for store_path in store_paths:
        try:
            layer = xr.open_zarr(store_path)["da"].chunk([1, 1000, 1000])
            loaded.append(layer)
            print(f"Loaded {store_path}")
        except Exception as err:
            # Best effort: a missing or unreadable stratum is skipped.
            print(f"Error loading {store_path}: {err}")
    return loaded

# Concatenate arrays along "band" and wrap the result in a Dataset
def combine_datasets(datasets):
    """Concatenate ``datasets`` along ``"band"`` and return them as a Dataset.

    The concatenated array is tagged with CRS ``"EPSG:4326"``, rechunked to
    ``[-1, 1000, 1000]`` and stored under the variable name ``"da"``.

    Raises:
        ValueError: If ``datasets`` is empty.
        Exception: Any failure is reported on stdout and re-raised.
    """
    try:
        if not datasets:
            raise ValueError("No datasets to combine.")

        stacked = (
            xr.concat(datasets, dim="band")
            .rio.write_crs("EPSG:4326")
            .chunk([-1, 1000, 1000])
        )
        wrapped = xr.Dataset({"da": stacked})

        print("Datasets combined successfully.")
        return wrapped
    except Exception as err:
        print(f"Error during combining datasets: {err}")
        raise

# Function to save the combined dataset with handling attributes
def save_combined_dataset(ds, output_path):
    """Write ``ds`` to a consolidated Zarr store at ``output_path``.

    Before writing, the ``"da"`` variable (when present) loses its
    ``grid_mapping`` attribute and its ``chunks`` encoding entry. Both removals
    tolerate a missing key. Any failure is reported on stdout and re-raised.

    NOTE(review): this notebook defines ``save_combined_dataset`` more than once;
    whichever definition ran last is the one in effect.
    """
    try:
        if "da" in ds:
            da = ds["da"]
            # Remove the attribute that conflicts on write.
            da.attrs.pop('grid_mapping', None)
            # pop() instead of del: the encoding may have no "chunks" entry.
            da.encoding.pop("chunks", None)

        # Adjust Dask configuration to manage large chunks
        with dask.config.set({'array.slicing.split_large_chunks': True}):
            ds.to_zarr(output_path, mode='w', consolidated=True)

        print(f"Combined dataset saved successfully to {output_path}")
    except Exception as e:
        print(f"Error during saving the combined dataset: {e}")
        raise


In [None]:

# Zarr store name templates per layer; the two "{}" slots take the stratum labels
file_patterns = dict(
    alos='/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_alos_30m_mx_stratum_{}_{}.zarr',
    hls='/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_hls_30m_mx_stratum_{}_{}.zarr',
    dem='/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx_stratum_{}_{}.zarr',
)

# Every (stratum1, stratum2) label pair of the 3 x 3 grid, as strings
strata = [(str(first), str(second)) for first in range(3) for second in range(3)]


In [None]:

# Load the per-stratum arrays of each layer; strata that fail to load are skipped
# by load_datasets, so any of these lists may be shorter than len(strata).
datasets_alos = load_datasets(file_patterns['alos'], strata)
datasets_hls = load_datasets(file_patterns['hls'], strata)
datasets_dem = load_datasets(file_patterns['dem'], strata)


In [None]:

# Combine the strata of each layer into one Dataset (variable "da").
# combine_datasets raises ValueError when a layer's list is empty.
combined_ds_alos = combine_datasets(datasets_alos)
combined_ds_hls = combine_datasets(datasets_hls)
combined_ds_dem = combine_datasets(datasets_dem)


In [None]:

# Destination Zarr store for each combined layer (written by the cells below).
output_paths = {
    'alos': "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_alos_30m_mx.zarr",
    'hls': "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_hls_30m_mx.zarr",
    'dem': "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx.zarr"
}


In [None]:

# Write the combined ALOS layer; the target store is overwritten (mode='w').
save_combined_dataset(combined_ds_alos, output_paths['alos'])


In [12]:
# Re-open the saved ALOS store to inspect it.
# NOTE(review): the output recorded for this cell reports shape (2, 1, 1), i.e. a
# 1 x 1 spatial extent; confirm the ALOS store was written correctly.
path_tif = "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_alos_30m_mx.zarr"
dax = xr.open_zarr(path_tif)["da"]
dax

Unnamed: 0,Array,Chunk
Bytes,8 B,4 B
Shape,"(2, 1, 1)","(1, 1, 1)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 8 B 4 B Shape (2, 1, 1) (1, 1, 1) Dask graph 2 chunks in 2 graph layers Data type float32 numpy.ndarray",1  1  2,

Unnamed: 0,Array,Chunk
Bytes,8 B,4 B
Shape,"(2, 1, 1)","(1, 1, 1)"
Dask graph,2 chunks in 2 graph layers,2 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [None]:
# Write the combined HLS layer; the target store is overwritten (mode='w').
save_combined_dataset(combined_ds_hls, output_paths['hls'])


In [13]:
# Re-open the saved HLS store to inspect it (the recorded output shows shape
# (5, 74214, 123224)).
path_tif = "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_hls_30m_mx.zarr"
dax = xr.open_zarr(path_tif)["da"]
dax

Unnamed: 0,Array,Chunk
Bytes,170.34 GiB,3.81 MiB
Shape,"(5, 74214, 123224)","(1, 1000, 1000)"
Dask graph,46500 chunks in 2 graph layers,46500 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 170.34 GiB 3.81 MiB Shape (5, 74214, 123224) (1, 1000, 1000) Dask graph 46500 chunks in 2 graph layers Data type float32 numpy.ndarray",123224  74214  5,

Unnamed: 0,Array,Chunk
Bytes,170.34 GiB,3.81 MiB
Shape,"(5, 74214, 123224)","(1, 1000, 1000)"
Dask graph,46500 chunks in 2 graph layers,46500 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [None]:
# Write the combined DEM layer; the target store is overwritten (mode='w').
save_combined_dataset(combined_ds_dem, output_paths['dem'])


In [14]:
# Re-open the saved DEM store to inspect it (the recorded output shows shape
# (1, 82615, 137172), which differs from the HLS store's y/x extent).
path_tif = "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx.zarr"
dax = xr.open_zarr(path_tif)["da"]
dax

Unnamed: 0,Array,Chunk
Bytes,42.22 GiB,3.81 MiB
Shape,"(1, 82615, 137172)","(1, 1000, 1000)"
Dask graph,11454 chunks in 2 graph layers,11454 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 42.22 GiB 3.81 MiB Shape (1, 82615, 137172) (1, 1000, 1000) Dask graph 11454 chunks in 2 graph layers Data type float32 numpy.ndarray",137172  82615  1,

Unnamed: 0,Array,Chunk
Bytes,42.22 GiB,3.81 MiB
Shape,"(1, 82615, 137172)","(1, 1000, 1000)"
Dask graph,11454 chunks in 2 graph layers,11454 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [None]:


# Paths of the three per-layer Zarr stores (the same locations as the
# output_paths entries defined earlier).
alos_path = "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_alos_30m_mx.zarr"
hls_path = "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_hls_30m_mx.zarr"
dem_path = "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_dem_30m_mx.zarr"


In [None]:
# Open one combined Zarr store, reporting the outcome on stdout
def load_combined_dataset(path):
    """Open the Zarr store at ``path`` and return it as a Dataset.

    Raises:
        Exception: Whatever ``xr.open_zarr`` raises (reported on stdout, then
            re-raised).
    """
    try:
        opened = xr.open_zarr(path)
    except Exception as err:
        print(f"Error loading dataset from {path}: {err}")
        raise
    else:
        print(f"Loaded dataset from {path}")
        return opened

# Ensure unique 'band' dimension values
def ensure_unique_band(ds):
    """Drop repeated ``band`` labels, keeping the first occurrence of each.

    Args:
        ds: Object with ``dims``, ``ds['band'].values`` and ``isel`` (an xarray
            Dataset or DataArray).

    Returns:
        ``ds`` restricted to the first position of every distinct band label, in
        original order. ``ds`` is returned unchanged when it has no ``band``
        dimension.
    """
    if 'band' in ds.dims:
        # return_index=True gives the position of the first occurrence of each
        # distinct label. The previous pd.factorize call returned the unique
        # labels themselves, which isel() then misread as positions.
        _, first_positions = np.unique(ds['band'].values, return_index=True)
        ds = ds.isel(band=sorted(first_positions.tolist()))
        print("Ensured unique 'band' values for dataset")
    return ds

In [None]:

# Load each per-layer store as a Dataset.
# NOTE(review): this rebinds ds_hls / ds_dem / ds_alos, which earlier cells used
# for the per-stratum concatenations.
ds_alos = load_combined_dataset(alos_path)
ds_hls = load_combined_dataset(hls_path)
ds_dem = load_combined_dataset(dem_path)


In [None]:

# Pass each layer through ensure_unique_band; layers without a 'band' dimension
# come back unchanged.
ds_alos = ensure_unique_band(ds_alos)
ds_hls = ensure_unique_band(ds_hls)
ds_dem = ensure_unique_band(ds_dem)


In [None]:
# Combine the three layers along a new "variable" dimension.
# NOTE(review): the outputs recorded above show different y/x extents per layer
# (HLS 74214 x 123224, DEM 82615 x 137172, ALOS 1 x 1); confirm the grids are
# aligned before concatenating.
try:
    combined_ds = xr.concat([ds_alos, ds_hls, ds_dem], dim="variable")
    combined_ds = combined_ds.chunk({'variable': -1, 'y': 1000, 'x': 1000})  # Adjust chunk size if necessary
    print("Datasets combined successfully.")
except Exception as e:
    print(f"Error during combining datasets: {e}")
    raise

In [None]:
# Destination of the combined X-layer store; it is written by the
# save_combined_dataset call further down and opened again for sample extraction.
output_path = "/Volumes/STVHD/Mexico_AGB_30m/test/xlayers_combine_30m.zarr"


In [None]:

def save_combined_dataset(ds, output_path):
    """Write ``ds`` to a consolidated Zarr store at ``output_path``.

    Before writing, ``grid_mapping`` is removed from the attributes of every data
    variable and ``chunks`` is removed from the encoding of every variable; both
    removals tolerate a missing key. Any failure is reported on stdout and
    re-raised.

    NOTE(review): this redefines ``save_combined_dataset`` from earlier cells.
    """
    try:
        # Attribute that conflicts on write
        for data_var in ds.data_vars.values():
            data_var.attrs.pop('grid_mapping', None)

        # Stale chunk encoding carried over from the source stores
        for var_name in ds.variables:
            ds[var_name].encoding.pop('chunks', None)

        dask_overrides = {'array.slicing.split_large_chunks': True}
        with dask.config.set(dask_overrides):
            ds.to_zarr(output_path, mode='w', consolidated=True)

        print(f"Combined dataset saved successfully to {output_path}")
    except Exception as err:
        print(f"Error during saving the combined dataset: {err}")
        raise


In [None]:

# Save the combined dataset; the store at output_path is overwritten (mode='w').
save_combined_dataset(combined_ds, output_path)

In [None]:

# Compare the pixel resolution and width of the GEDI array (da0, opened in an
# earlier cell) against the HLS array.
path_tif = "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_hls_30m_mx.zarr"
da1 = xr.open_zarr(path_tif)["da"]
# A bare expression in the middle of a cell displays nothing.
da1
print(da0.rio.resolution())
print(da1.rio.resolution())

# Ratio of the sizes of the last axis (HLS / GEDI).
print(da1.shape[2] / da0.shape[2])

Extract training/eval/test samples

In [None]:
# Output directory for the extracted samples; it is also used as a filename
# prefix for the model checkpoint and the prediction outputs further down.
outname = "/Volumes/STVHD/Mexico_AGB_30m/test/mexico_agb/RH98_30m"
os.makedirs(outname, exist_ok=True)

# NOTE(review): pix_size is presumably the pixel size in degrees; confirm.
pix_size = 0.00022457882102988513
# NOTE(review): offsets is not referenced anywhere in the visible cells.
offsets = 0.00005 + pix_size * 2.5
# Patch edge length in pixels (samples are width x width).
width = 3

In [None]:
# Re-open the GEDI store with all of axis 0 in one chunk and 1000 x 1000 chunks
# on the remaining two axes; this rebinds da0.
path_tif = "/Volumes/STVHD/Mexico_AGB_30m/test/test_layers_GEDI_30m_mx.zarr"
da0 = xr.open_zarr(path_tif)["da"].chunk((-1, 1000, 1000))
da0

In [None]:
# Open the combined X-layer store with the same chunk layout as da0; this
# rebinds da1 and path_tif (path_tif is later passed to the prediction step).
# NOTE(review): the store is written from a concat along a new "variable"
# dimension, so "da" may have more than the three axes this chunk tuple assumes.
path_tif = "/Volumes/STVHD/Mexico_AGB_30m/test/xlayers_combine_30m.zarr"
da1 = xr.open_zarr(path_tif)["da"].chunk((-1, 1000, 1000))

da1

In [None]:


def get_layers_xy(xarr0, xarr1, scale, width, mask_band, target_bands, valid_range, out_path):
    """Build one delayed ``extract_training_lst`` call per pair of tiles.

    Both arrays are rechunked so that axis 0 is a single chunk and axes 1 and 2
    are cut into ``width``-sized chunks; the resulting tiles are flattened and
    paired by position.

    Args:
        xarr0, xarr1: Arrays exposing a dask array through ``.data``.
        scale, width, mask_band, target_bands, valid_range, out_path: Forwarded
            unchanged to ``extract_training_lst`` as keyword arguments.

    Returns:
        List of dask delayed objects; nothing is computed here.
    """
    tile_chunks = {0: -1, 1: width, 2: width}
    tiles_first = xarr0.data.rechunk(tile_chunks).to_delayed().ravel()
    tiles_second = xarr1.data.rechunk(tile_chunks).to_delayed().ravel()

    # zip() stops at the shorter tile list, so surplus tiles of the longer one
    # are dropped.
    # NOTE(review): when the two grids differ in size, tiles paired by flat
    # position may not cover the same location; confirm the inputs are aligned.
    return [
        dask.delayed(extract_training_lst)(
            tile_idx,
            tile_first,
            tile_second,
            scale=scale,
            width=width,
            valid_range=valid_range,
            mask_band=mask_band,
            target_bands=target_bands,
            out_path=out_path,
        )
        for tile_idx, (tile_first, tile_second) in enumerate(zip(tiles_first, tiles_second))
    ]


In [None]:

# Build the delayed extraction tasks (da0 and da1 tiled into width x width
# patches) and run them. Files are written under the
# "<outname>/training_samples" prefix.
results = get_layers_xy(
    da0,
    da1,
    scale=1,
    width=width,
    mask_band=0,
    target_bands=[0],
    valid_range=(0, 700),
    out_path=f"{outname}/training_samples",
)

# Executes every delayed task; this is the expensive step.
computed = dask.compute(*results)


In [None]:
# NOTE(review): this cell repeats the previous one verbatim (same inputs, same
# out_path); running both performs the extraction twice.
results = get_layers_xy(
    da0,
    da1,
    scale=1,
    width=width,
    mask_band=0,
    target_bands=[0],
    valid_range=(0, 700),
    out_path=f"{outname}/training_samples",
)

computed = dask.compute(*results)

In [None]:
# Quick-look plot of the GEDI array with the colour range clipped to [0, 400].
da0.plot(vmin=0, vmax=400, figsize=(6, 8))

In [None]:
# Example of one sample file: /Volumes/STVHD/Mexico_AGB_30m/test/training_samples_753564
# Read every pickled (X, y) batch written by the extraction step.

# Sorted so the concatenation order (and therefore the seeded train/test split
# below) is the same on every run; glob returns files in arbitrary order.
files_15 = sorted(glob.glob(f"{outname}/training_samples*"))
if not files_15:
    raise FileNotFoundError(f"No sample files match {outname}/training_samples*")

X_15_lst = []
y_15_lst = []

for ifile in files_15:
    # NOTE: pickle.load can execute arbitrary code; only load files produced by
    # this notebook.
    with open(ifile, "rb") as sample_file:
        X_lst, y_lst = pickle.load(sample_file)
    X_15_lst.append(X_lst)
    y_15_lst.append(y_lst)

X_15 = np.concatenate(X_15_lst, axis=0)
y_15 = np.concatenate(y_15_lst, axis=0)

print(f"Concatenated X shape: {X_15.shape}")
print(f"Concatenated y shape: {y_15.shape}")

# Optional, currently disabled: drop samples whose features contain any NaN.
# (Kept as comments so the cell does not echo a string literal as its output.)
# valid_idx = ~np.isnan(X_15.reshape(X_15.shape[0], -1)).any(axis=1)
# X_15 = X_15[valid_idx]
# y_15 = y_15[valid_idx]

In [None]:
# Display the concatenated feature array (numpy abbreviates large arrays).
X_15

## Save as TFRecords

In [None]:
# Hold out 20% of the samples for testing; only column 0 of y_15 is used as the
# label.
X_train, X_test, y_train, y_test = train_test_split(
    X_15, y_15[:, 0], test_size=0.2, random_state=44 
)
# Split a further 20% of the remaining samples off for evaluation, giving
# roughly 64% train / 16% eval / 20% test overall.
X_train, X_eval, y_train, y_eval = train_test_split(
    X_train, y_train, test_size=0.2, random_state=55
)

In [None]:
# Sanity check: training split shapes and the label distribution over all
# samples (column 0 of y_15, before splitting).
print(X_train.shape)
print(y_train.shape)
plt.hist(y_15[:, 0], bins=50)

In [None]:
# Shape of the training features (already printed in the previous cell).
X_train.shape

In [None]:
# Encode one (feature, label) pair as a serialized tf.train.Example
def serialize_example(feature, label):
    """Serialize one sample to the bytes of a ``tf.train.Example``.

    Args:
        feature: Array with a ``flatten()`` method; stored flattened under the
            key ``"features"`` as a float list.
        label: Scalar stored under the key ``"label"`` as a one-element float
            list.

    Returns:
        The serialized example as a byte string.
    """
    flat_values = feature.flatten()  # the multi-dimensional patch becomes 1D
    features_field = tf.train.Feature(float_list=tf.train.FloatList(value=flat_values))
    label_field = tf.train.Feature(float_list=tf.train.FloatList(value=[label]))
    record = tf.train.Example(
        features=tf.train.Features(
            feature={"features": features_field, "label": label_field}
        )
    )
    return record.SerializeToString()


# Write one TFRecord file per split: tf_training, tf_eval and tf_test (in that
# order), each under the outname directory.
for split_name, split_features, split_labels in (
    ("training", X_train, y_train),
    ("eval", X_eval, y_eval),
    ("test", X_test, y_test),
):
    with tf.io.TFRecordWriter(f"{outname}/tf_{split_name}.tfrecord") as writer:
        for feature, label in zip(split_features, split_labels):
            serialized_example = serialize_example(feature, label)
            writer.write(serialized_example)

## Load training and preprocessing

In [None]:
# Input-pipeline and model-shape settings.
AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 128
BATCH_SIZE_TEST = 32
# width and nbands must match the patches stored in the TFRecords:
# parse_function reshapes each record to (width, width, nbands).
width = 3
nbands = 8
# Number of training samples; used to derive steps_per_epoch.
n_train = y_train.shape[0]

In [None]:
# Decode one serialized example back into a (patch, label) pair
def parse_function(example_proto):
    """Parse one serialized example into ``(features, label)``.

    The flat ``"features"`` list must hold ``width * width * nbands`` floats (the
    module-level ``width`` and ``nbands``); it is reshaped to
    ``(width, width, nbands)``. ``"label"`` is parsed as a scalar float.
    """
    flat_length = width * width * nbands  # must match the flattened patch size
    schema = {
        "features": tf.io.FixedLenFeature([flat_length], tf.float32),
        "label": tf.io.FixedLenFeature([], tf.float32),
    }
    parsed = tf.io.parse_single_example(example_proto, schema)
    patch = tf.reshape(parsed["features"], (width, width, nbands))
    return patch, parsed["label"]


# Lazily read the three TFRecord files; each element is decoded by
# parse_function into a (patch, label) pair.
train_dataset = tf.data.TFRecordDataset(f"{outname}/tf_training.tfrecord").map(
    parse_function
)
eval_dataset = tf.data.TFRecordDataset(f"{outname}/tf_eval.tfrecord").map(
    parse_function
)
test_dataset = tf.data.TFRecordDataset(f"{outname}/tf_test.tfrecord").map(
    parse_function
)

In [None]:
# Display the dataset object (its element spec), not its contents.
train_dataset

In [None]:
# Find mean/std of training data from a random sample of up to SAMPLE_N examples.
SAMPLE_N = 50000
X_list = []
y_list = []
for X, Y in train_dataset.shuffle(buffer_size=200000).take(SAMPLE_N):
    X_list.append(X.numpy())
    y_list.append(Y.numpy())
X_all = np.stack(X_list, axis=0)
y_all = np.stack(y_list, axis=0)
# Per-band statistics: every pixel of every patch counts as one observation,
# and NaNs are ignored.
X_mean = np.nanmean(X_all.reshape(-1, nbands), axis=0).tolist()
X_std = np.nanstd(X_all.reshape(-1, nbands), axis=0).tolist()
# Label statistics, kept as one-element lists.
Y_mean = np.nanmean(y_all.reshape(-1, 1), axis=0).tolist()
Y_std = np.nanstd(y_all.reshape(-1, 1), axis=0).tolist()
print(f"X_mean: {X_mean}")
print(f"X_std: {X_std}")
print(f"y_mean: {Y_mean}")
print(f"y_std: {Y_std}")

In [None]:
def batch_normalize_xy(X, Y):
    """Standardize a batch of features and labels with the global statistics.

    ``X`` is centred and scaled with the module-level ``X_mean`` / ``X_std``
    (wrapped in two extra leading axes so they broadcast over the patch), and
    NaNs in the result are replaced by zeros. ``Y`` is standardized with
    ``Y_mean`` / ``Y_std``; its NaNs are left as they are.

    Returns:
        ``(X_standardized, Y_standardized)``.
    """
    band_mean = tf.constant([[X_mean]])
    band_std = tf.constant([[X_std]])
    # NOTE(review): a zero entry in X_std yields inf here, which the NaN
    # replacement below does not catch; confirm no band is constant.
    X_scaled = (X - band_mean) / band_std
    X_scaled = tf.where(tf.math.is_nan(X_scaled), tf.zeros_like(X_scaled), X_scaled)

    label_mean = tf.constant(Y_mean)
    label_std = tf.constant(Y_std)
    Y_scaled = (Y - label_mean) / label_std
    return X_scaled, Y_scaled


# Training pipeline: repeated indefinitely, shuffled with an 80000-element
# buffer, batched, standardized and prefetched. Because it never ends, fit()
# must be given steps_per_epoch.
ds1 = (
    train_dataset.repeat()
    .shuffle(80000)
    .batch(
        BATCH_SIZE,
    )
    .map(batch_normalize_xy, num_parallel_calls=AUTOTUNE)
    .prefetch(AUTOTUNE)
)
# Validation pipeline: a single pass with no shuffling.
ds2 = (
    eval_dataset.batch(
        BATCH_SIZE_TEST,
    )
    .map(batch_normalize_xy, num_parallel_calls=AUTOTUNE)
    .prefetch(AUTOTUNE)
)

# Weights file path; it is also reused as the prefix for the saved figures.
fname = f"{outname}_densenet_l1_ckpt"
dn_structure = (12,)
epochs = 100

# Multi-GPU training is currently disabled:
# tf.debugging.set_log_device_placement(True)
# strategy = tf.distribute.MirroredStrategy()
# print("Number of devices: {}".format(strategy.num_replicas_in_sync))

m = densenet_model(dn_structure, width, nbands, 1)
# summary() prints the table itself and returns None, so wrapping it in print()
# would add a stray "None" line to the output.
m.summary()

# Mean absolute error loss, with RMSE reported as an additional metric.
m.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    # loss=tf.keras.losses.get("MeanSquaredError"),
    loss=tf.keras.losses.get("MeanAbsoluteError"),
    metrics=[tf.keras.metrics.get(metric) for metric in ["RootMeanSquaredError"]],
)

# Resume from previously saved weights when the file exists.
# with strategy.scope():
if os.path.exists(fname):
    m.load_weights(fname)
    print("Pretrained model loaded...")
else:
    print("Building new model...")

# Stop once val_loss has not improved by at least 0.001 for 5 epochs and roll
# back to the best weights seen.
earlystop_callback = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.001,
    verbose=1,
    patience=5,
    restore_best_weights=True,
)

# ds1 repeats forever, so the epoch length is fixed through steps_per_epoch.
history = m.fit(
    x=ds1,
    epochs=epochs,
    callbacks=[
        earlystop_callback,
    ],
    steps_per_epoch=int(n_train / BATCH_SIZE),
    validation_data=ds2,
)
m.save_weights(fname, save_format="h5")

# Training vs. validation loss per epoch, saved next to the weights file.
plt.figure(figsize=(6, 4))
plt.plot(history.history["loss"], label="Training")
plt.plot(history.history["val_loss"], label="Validation")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(loc="upper right")
plt.savefig(f"{fname}_history.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# define evaluation function
def eval_plot(
    ds_test,
    model,
    out_name,
    n_samples=None,
):
    """Predict on ``ds_test`` and save scatter and residual plots.

    The dataset is batched and standardized with ``batch_normalize_xy``, fully
    materialized in memory, optionally subsampled, and passed through ``model``.
    Labels and predictions are mapped back to original units with the
    module-level ``Y_std`` / ``Y_mean`` before plotting.

    Args:
        ds_test: Unbatched ``tf.data`` dataset of ``(features, label)`` pairs.
        model: Object with a Keras-style ``predict(X, verbose=...)`` method.
        out_name: Prefix for the ``_xyscatter.png``, ``_residual.png`` and
            ``_residual2.png`` output files.
        n_samples: When given, plot a random subset of at most this many samples
            (drawn without replacement); ``None`` plots every sample.
    """
    ds2 = (
        ds_test.batch(BATCH_SIZE_TEST)
        .map(batch_normalize_xy, num_parallel_calls=AUTOTUNE)
        .prefetch(AUTOTUNE)
    )

    X_list = []
    y_list = []
    for X, Y in ds2:
        X_list.append(X.numpy())
        y_list.append(Y.numpy())
    X_test = np.concatenate(X_list)
    y_test = np.concatenate(y_list)

    if n_samples is not None:
        # Clamp the request: sampling without replacement cannot return more
        # items than exist.
        n_pick = min(n_samples, X_test.shape[0])
        idx_test = np.random.choice(X_test.shape[0], size=n_pick, replace=False)
        X_test = X_test[idx_test]
        y_test = y_test[idx_test]
    # Labels arrive as a 1-D vector; give them a column axis in every case so the
    # de-normalization below yields (N, 1) rather than broadcasting to (1, N).
    if y_test.ndim == 1:
        y_test = y_test[:, None]

    # Undo the label standardization applied by batch_normalize_xy.
    y_scale = np.array(Y_std)[None, :]
    y_offset = np.array(Y_mean)[None, :]
    y_test = y_test * y_scale + y_offset
    predictions = model.predict(X_test, verbose=1) * y_scale + y_offset

    # One plot of each kind per label column; rows with a NaN label are dropped.
    # NOTE(review): the file names do not include the column index, so with more
    # than one label column later plots overwrite earlier ones.
    Parallel(n_jobs=-1)(
        delayed(density_scatter_plot)(
            y_test[:, i_test][~np.isnan(y_test[:, i_test])],
            predictions[:, i_test][~np.isnan(y_test[:, i_test])],
            file_name=f"{out_name}_xyscatter.png",
        )
        for i_test in tqdm(range(y_test.shape[1]))
    )
    Parallel(n_jobs=-1)(
        delayed(residual_box_plot)(
            y_test[:, i_test][~np.isnan(y_test[:, i_test])],
            predictions[:, i_test][~np.isnan(y_test[:, i_test])],
            outname=f"{out_name}_residual.png",
            floating=True,
        )
        for i_test in tqdm(range(y_test.shape[1]))
    )
    Parallel(n_jobs=-1)(
        delayed(residual2_box_plot)(
            y_test[:, i_test][~np.isnan(y_test[:, i_test])],
            predictions[:, i_test][~np.isnan(y_test[:, i_test])],
            outname=f"{out_name}_residual2.png",
            floating=True,
        )
        for i_test in tqdm(range(y_test.shape[1]))
    )

In [None]:
# Diagnostic plots on a random subset of 1000 training samples.
eval_plot(
    train_dataset,
    m,
    f"{fname}_train",
    n_samples=1000,
)

In [None]:
# Diagnostic plots on a random subset of 3000 evaluation samples.
eval_plot(
    eval_dataset,
    m,
    f"{fname}_eval",
    n_samples=3000,
)

In [None]:
# Diagnostic plots on a random subset of 3000 held-out test samples.
eval_plot(
    test_dataset,
    m,
    f"{fname}_test",
    n_samples=3000,
)

## Prediction

In [None]:
# Prediction settings. width must match the patch size used in training.
width = 3
bands = list(range(8))
chunk_size = 800
mask_band = 4

dn_structure = (12,)
fname = f"{outname}_densenet_l1_ckpt"

# Rebuild the network and load the trained weights.
# NOTE(review): nbands is not set in this cell; it comes from the training
# configuration cell, so that cell must have been run first.
m = densenet_model(dn_structure, width, nbands, 1)
m.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.get("MeanAbsoluteError"),
    metrics=[tf.keras.metrics.get(metric) for metric in ["RootMeanSquaredError"]],
)
m.load_weights(fname)

In [None]:
def build_vrt_mosaic(in_file, out_file):
    """Build a VRT mosaic of the rasters in ``in_file`` with ``gdalbuildvrt``.

    The input paths are written one per line to ``merge_list.txt`` in the current
    working directory, and that list is handed to ``gdalbuildvrt -overwrite``.
    The command line and the status returned by ``os.system`` are printed.

    Args:
        in_file: Iterable of input raster paths.
        out_file: Path of the VRT file to create.
    """
    import shlex  # local import: only this helper needs shell quoting

    in_fname = "merge_list.txt"
    with open(in_fname, "w") as f:
        f.writelines("%s\n" % i for i in in_file)
    # Quote both paths so spaces or shell metacharacters cannot split or alter
    # the command; paths made of plain characters are left unchanged.
    command = (
        f"gdalbuildvrt -overwrite -input_file_list {shlex.quote(in_fname)} "
        f"{shlex.quote(out_file)}"
    )
    print(command)
    output = os.system(command)
    print(output)


def yprep_model(Y):
    """Map standardized predictions back to label units: ``Y * Y_std + Y_mean``.

    ``Y_mean`` and ``Y_std`` are the module-level training statistics; a leading
    axis is added to each so they broadcast across the rows of ``Y``.
    """
    label_offset = np.array(Y_mean)[None, :]
    label_scale = np.array(Y_std)[None, :]
    return Y * label_scale + label_offset


def xprep_model(X):
    """Standardize features with the training statistics and zero out NaNs.

    ``X_mean`` and ``X_std`` are the module-level per-band statistics; a leading
    axis is added to each so they broadcast across the rows of ``X``.
    """
    band_centre = np.array(X_mean)[None, :]
    band_spread = np.array(X_std)[None, :]
    standardized = (X - band_centre) / band_spread
    cleaned = np.where(np.isnan(standardized), np.zeros_like(standardized), standardized)
    return cleaned[:, :]

In [None]:
# Show the output prefix that the prediction files below are written under.
print(outname)

In [None]:
# Predict tile by tile; output names start with "<outname>_densenet_l1/pred".
os.makedirs(f"{outname}_densenet_l1", exist_ok=True)
out_prefix = f"{outname}_densenet_l1/pred"

# Every (row, col) tile index of a chunk_size x chunk_size grid over da1,
# visited in random order.
ij_lst = [
    (i, j)
    for i in range(np.ceil(da1.shape[1] / chunk_size).astype(int))
    for j in range(np.ceil(da1.shape[2] / chunk_size).astype(int))
]
random.shuffle(ij_lst)

# NOTE(review): path_tif is whatever an earlier cell last assigned; confirm it
# points at the same store as da1 before running this cell.
# Parallel(n_jobs=2)(
Parallel(n_jobs=1, prefer="threads")(
    delayed(data_cnn_prediction)(
        ij,
        chunk_size,
        path_tif,
        m,
        yprep_model,
        xprep_model,
        bands=bands,
        mask_band=mask_band,
        nbands_out=1,
        patchsize=(width, width),
        strides=(1, 1),
        scale=10,
        crs="EPSG:4326",
        withxy=False,
        out_name=out_prefix,
    )
    for ij in tqdm(ij_lst)
)

In [None]:
# Show the output prefix again before building the mosaic.
print(outname)

In [None]:
# Mosaic every prediction tile into one VRT.
out_file = f"{outname}_densenet_l1_rh98_mosaic.vrt"
flist = f"{outname}_densenet_l1/pred*.tif"
in_file = glob.glob(flist)
build_vrt_mosaic(in_file, out_file)

# Convert the VRT to a DEFLATE-compressed Int16 Cloud Optimized GeoTIFF.
cog_file = f"{outname}_densenet_l1_rh98_mosaic.tif"
cmmd = (
    f"gdal_translate {out_file} {cog_file} -of COG -co BIGTIFF=YES -co COMPRESS=DEFLATE "
    f"-co PREDICTOR=2 -co NUM_THREADS=16 -ot Int16"
)
status = os.system(cmmd)
if status != 0:
    # Report the failure instead of silently continuing without the output file.
    print(f"gdal_translate exited with status {status}; {cog_file} may be missing or incomplete")

# Optional upload, currently disabled. s3_path is not defined in the visible
# cells. Kept as comments so the cell does not echo a string literal as output.
# cmmd = (
#     f"aws s3 cp {outname}_densenet_l1_rh98_mosaic.tif "
#     f"{s3_path}{outname}_densenet_l1_rh98_mosaic.tif"
# )
# os.system(cmmd)