In [1]:
import os
os.environ['USE_PYGEOS'] = '0'   # Suppresses some warning about geopandas
import geopandas as gpd

# scipy basics
import pandas as pd
import numpy as np
from shapely.wkb import loads
import botocore
from osgeo import gdal      # Necessary to do this import to get rasterio to import
import rasterio
import rasterio.features

import time

# dask/parallelization libraries
import coiled
import dask
import dask.array as dar
from dask.distributed import Client, LocalCluster
import dask_geopandas
import pystac
import rioxarray
import xarray as xr
import xrspatial.local

First thing to do is spin up a Dask cluster. Below is a utility to create a Dask cluster on our AWS infrastructure. A Dask cluster consists of a client (this notebook), a scheduler, and any number of workers. The scheduler and workers will be created as tasks on ECS, which just means they'll all be individual, fairly small machines (2 cores, 8 GB RAM each). 

By default, the utility will create 30 workers, but you can pass an n_workers= param to get more. The output of the cell will be a link to the scheduler UI, which has some good graphs and data on what's going on in the cluster. Once you set the client here, Dask knows to use it for any future operations.

Making cloud and local clusters

In [40]:
# Create the cloud Dask cluster through Coiled. This call provisions remote machines,
# so it should be paired with coiled_cluster.shutdown() when the work is finished.
coiled_cluster = coiled.Cluster(
    n_workers=10,               # number of workers requested
    use_best_zone=True, 
    compute_purchase_option="spot_with_fallback",
    idle_timeout="20 minutes",  # NOTE(review): presumably the cluster stops itself after this much idle time -- confirm in Coiled docs
    # worker_vm_types=["t3.medium"],
    # name="DGibbs Europe height flux model", 
    account='jterry64'   # Necessary to use the AWS environment that Justin set up in Coiled
)

Output()

Output()

In [41]:
# Coiled cluster (cloud run)
# Get a Dask client connected to the Coiled cluster created above.
coiled_client = coiled_cluster.get_client()
# Bare expression so the notebook renders the client summary (dashboard link, workers, memory)
coiled_client

0,1
Connection method: Cluster object,Cluster type: coiled.Cluster
Dashboard: https://cluster-fzheh.dask.host/cIm7XZQouv-2uAmr/status,

0,1
Dashboard: https://cluster-fzheh.dask.host/cIm7XZQouv-2uAmr/status,Workers: 5
Total threads: 20,Total memory: 74.11 GiB

0,1
Comm: tls://10.1.37.178:8786,Workers: 5
Dashboard: http://10.1.37.178:8787/status,Total threads: 20
Started: Just now,Total memory: 74.11 GiB

0,1
Comm: tls://10.1.40.10:35889,Total threads: 4
Dashboard: http://10.1.40.10:8787/status,Memory: 14.82 GiB
Nanny: tls://10.1.40.10:41117,
Local directory: /scratch/dask-scratch-space/worker-ph6nd999,Local directory: /scratch/dask-scratch-space/worker-ph6nd999

0,1
Comm: tls://10.1.40.141:36851,Total threads: 4
Dashboard: http://10.1.40.141:8787/status,Memory: 14.83 GiB
Nanny: tls://10.1.40.141:40167,
Local directory: /scratch/dask-scratch-space/worker-5y9i2v2o,Local directory: /scratch/dask-scratch-space/worker-5y9i2v2o

0,1
Comm: tls://10.1.46.217:41305,Total threads: 4
Dashboard: http://10.1.46.217:8787/status,Memory: 14.83 GiB
Nanny: tls://10.1.46.217:36087,
Local directory: /scratch/dask-scratch-space/worker-u8cn4jh2,Local directory: /scratch/dask-scratch-space/worker-u8cn4jh2

0,1
Comm: tls://10.1.40.130:32939,Total threads: 4
Dashboard: http://10.1.40.130:8787/status,Memory: 14.82 GiB
Nanny: tls://10.1.40.130:43657,
Local directory: /scratch/dask-scratch-space/worker-m1n7xew6,Local directory: /scratch/dask-scratch-space/worker-m1n7xew6

0,1
Comm: tls://10.1.37.248:33847,Total threads: 4
Dashboard: http://10.1.37.248:8787/status,Memory: 14.82 GiB
Nanny: tls://10.1.37.248:44345,
Local directory: /scratch/dask-scratch-space/worker-kvx45qg5,Local directory: /scratch/dask-scratch-space/worker-kvx45qg5


In [None]:
# # Local cluster (local run). Doesn't work -- the .compute() method kills workers for unknown reasons
# # local_cluster = LocalCluster(silence_logs=False)
# local_cluster = LocalCluster()
# local_client = Client(local_cluster)

In [32]:
# Local single-process cluster (local run). Will run .compute() on just one process, not a whole cluster.
# processes=False keeps the scheduler and the worker inside this notebook's own process.
local_client = Client(processes=False)
# Bare expression so the notebook renders the client summary
local_client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://192.168.0.121:8787/status,

0,1
Dashboard: http://192.168.0.121:8787/status,Workers: 1
Total threads: 8,Total memory: 31.82 GiB
Status: running,Using processes: False

0,1
Comm: inproc://192.168.0.121/8144/1,Workers: 1
Dashboard: http://192.168.0.121:8787/status,Total threads: 8
Started: Just now,Total memory: 31.82 GiB

0,1
Comm: inproc://192.168.0.121/8144/4,Total threads: 8
Dashboard: http://192.168.0.121:56218/status,Memory: 31.82 GiB
Nanny: None,
Local directory: C:\Users\DAVID~1.GIB\AppData\Local\Temp\dask-scratch-space\worker-mmv_f5xn,Local directory: C:\Users\DAVID~1.GIB\AppData\Local\Temp\dask-scratch-space\worker-mmv_f5xn


Shutting down cloud and local clusters

In [31]:
# Shut down the cloud cluster. Run this only once the analysis is finished: this cell sits
# before the analysis cells, so running the notebook top to bottom would close the cluster first.
coiled_cluster.shutdown()

In [38]:
# Shut down the local client. Run this only once the analysis is finished: this cell sits
# before the analysis cells, so running the notebook top to bottom would close the client first.
local_client.shutdown()

Analysis

In [4]:
# Run with and without for loop

# S3 prefix that the forest height input URIs below are built from
general_uri = 's3://gfw2-data/forest_change/GLAD_Europe_height_data/'

# Local folder that output rasters are written to. This is an absolute Windows path with a
# trailing separator (file names are appended directly to it); edit it for your own machine.
local_out_dir = 'C:\\GIS\\Carbon_model_Europe\\outputs\\'

In [42]:
# Run with and without for loop
# From https://notebooks-staging.wri.org/user/dagibbs22/lab/tree/msims/biodiversity_global_stats.ipynb

def get_tile_dataset(uri, name, template=None):
    """Open the raster at `uri` as a chunked DataArray named `name`.

    The raster is opened through rioxarray with chunks=2000. If opening fails with a
    RasterioIOError and a `template` was supplied, an all-zero array shaped like
    `template` is returned instead; without a template the error is re-raised.
    """
    try:
        tile = rioxarray.open_rasterio(uri, chunks=2000, default_name=name)
    except rasterio.errors.RasterioIOError as read_error:
        # No stand-in available: surface the original read failure to the caller
        if template is None:
            raise read_error
        tile = xr.zeros_like(template)
    return tile

Some sample input files:

In [6]:
# Run without for loop
# Forest height rasters for two consecutive years of the 1x1 degree test tile 50N_010E
forest_height_2020_uri = f'{general_uri}202307_revision/test_1x1_deg/50N_010E_1deg_FH_2020.tif'
forest_height_2021_uri = f'{general_uri}202307_revision/test_1x1_deg/50N_010E_1deg_FH_2021.tif'

# Auxiliary layers for the same test tile (their loading is currently commented out in the next cell)
driver_uri = "s3://gfw2-data/forest_change/GLAD_Europe_resubmission__202306/test_1x1_deg/50N_010E_1deg_tree_cover_loss_driver_processed.tif"
peat_uri = "s3://gfw2-data/forest_change/GLAD_Europe_resubmission__202306/test_1x1_deg/50N_010E_1deg_peat_mask_processed.tif"
tclf_uri = "s3://gfw2-data/forest_change/GLAD_Europe_resubmission__202306/test_1x1_deg/50N_010E_1deg_tree_cover_loss_fire_processed.tif"
# land_cover_uri = ...
# planted_forest_uri = ...

In [7]:
# Run without for loop
# Open the two forest height rasters as chunked arrays; the names given here become the
# variable names that consolidate_to_one_dimensional_index reads from the dataset.
forest_height_2020 = get_tile_dataset(forest_height_2020_uri, name="forest_height_2020")
forest_height_2021 = get_tile_dataset(forest_height_2021_uri, name="forest_height_2021")
# driver = get_tile_dataset(driver_uri, name="driver")
# peat = get_tile_dataset(peat_uri, name="peat")
# TCLF = get_tile_dataset(tclf_uri, name="tclf")

In [8]:
%%time

# Run without for loop
# Some sample dask operations
# print(forest_height_2020)
# Mean height over the whole 2021 tile; .compute() triggers the actual read and reduction
forest_height_2021.mean().compute()
# forest_height_2021.mean().compute()
# print(forest_height_2021.mean().compute())
# print(driver)        
# print(driver.mean().compute())
# print(forest_height_2020.where(driver > 1).mean().compute())

CPU times: total: 1.66 s
Wall time: 5.48 s


You'll likely want to make a class that encapsulates all the pixel data you'll pass into your decision tree for classification. This will make it much easier to read how your decision tree uses the pixel data, and also will give you a place to do any modifications to the input data or add additional fields based off of the input data you may want to re-use.

In [9]:
# Run with and without for loop

class ForestStateDecisionFactors:
    """Container for the per-pixel inputs that the forest state decision tree reads."""

    def __init__(self, height_prev_year, height_this_year):
        # Store the heights of the two years being compared
        self.height_prev_year, self.height_this_year = height_prev_year, height_this_year

I didn't fill this out for now, but this is just a stub for the class we discussed with Gary. It will take in your above decision factors class, which includes all the data in your layers at one pixel, and classify it using your decision tree logic.

In [10]:
# Run with and without for loop

class ForestStateDecisionTree:
    """Decision tree that turns two consecutive years of forest height into a forest state code.

    State codes returned by `classify`:
        0 = no forest, 1 = maintained, 2 = loss, 3 = gain
    """

    # define in a way similar to how we went over with Gary
    @staticmethod
    def classify(forestStateDecisionFactors):
        """Return the forest state code for one pixel.

        Declared as a staticmethod so it can be called both on the class
        (ForestStateDecisionTree.classify(factors)) and on an instance.

        Parameters:
            forestStateDecisionFactors: object exposing `height_prev_year` and
                `height_this_year`; a height of 5 or more counts as forest.

        Returns:
            int: 1 (maintained), 2 (loss), 3 (gain) or 0 (no forest).
        """
        if forestStateDecisionFactors.height_prev_year >= 5 and forestStateDecisionFactors.height_this_year >= 5:   # maintained
            return 1
        elif forestStateDecisionFactors.height_prev_year >= 5 and forestStateDecisionFactors.height_this_year < 5:  # loss
            return 2
        elif forestStateDecisionFactors.height_prev_year < 5 and forestStateDecisionFactors.height_this_year >= 5:  # gain
            return 3
        else:                                                                                                       # no forest
            return 0

So, the xrspatial.local.combine method I mentioned doesn't seem to work very well on large arrays (it's not lazily computed), and I noticed they plan to deprecate it. I'm not sure whether they're going to replace it, but the method below works just as well, although it's a little more confusing to understand.

What we'd like to do is to smush all your pixel values across many layers into a single integer. We'd then like to be able to easily convert it back to the layer data later. This will make it much easier and faster to process your data in the decision tree.

NumPy has a nifty and very performant way of doing this called ravel_multi_index: https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html

For this kind of problem, we typically stack all our raster layers on top of each other to create a 3D array. Imagine instead we throw away the geospatial component and reshape the data as an N-dimensional array, one dimension for each layer. The length of each dimension is the number of distinct values the layer can take, i.e. its highest value plus one (e.g. for a boolean layer it's 2, for a categorical layer it'd be the number of categories).

You can map a pixel in your 3D geospatial array to this N-d array by putting it at the index of all its layer values. So if you have 2 boolean layers and a categorical layer, and have a pixel with values (true, false, 3), in the N-d array, its location would be arr[1][0][3].

Now, in NumPy, when we want to ravel an array, this means we view a multidimensional array and turn it into a 1D array by simply reading each row in order, wrapping around the end of a row to the next row. If you have a 2x8 array, you can ravel it by reading the first row of 8 and then the second row of 8. E.g. in the raveled view, arr[1][3] would actually be indexed as arr[11].

Ok, bearing with me? Putting these together, you can convert values in your N-d array into a single integer by converting them to the index in their raveled view. So in the above example, if the categorical layer has 4 categories (so the N-d array has shape (2, 2, 4)), arr[1][0][3] can be converted to the integer 11 (1 × 8 + 0 × 4 + 3), because it's the element you'd get if you called ravel(arr)[11].

This provides a very convenient and fast way to convert multidimensional values into an integer and back, as long as you know the size of each layer's dimension. So in index_dims below, you'll want to fill out the dimension size for each layer in the order you'll stack them; each size must be strictly greater than the largest value in that layer. For boolean layers, it'll just be 2; for categorical layers, it'll be the max category number plus one; and for temporal layers, it'll be the max time plus one. If you don't know it, you can also just use the number of values the data type can hold, as long as it's not int64 (so uint8 would be 256). If you have float layers you use here, let me know - this may not work as well, but I didn't notice any in your decision tree.

Despite all that explanation of what's going on, you can basically just call the function I made below, which will map a stack of layers into that raveled index and return it as a new array.

In [11]:
# Run without for loop

# Size of each layer's dimension in the raveled index, in stacking order (2020 height, 2021 height)
index_dims = (40, 40)

def consolidate_to_one_dimensional_index(chunk):
    """Collapse the two forest height layers of `chunk` into one raveled index per pixel.

    `chunk` must expose `forest_height_2020` and `forest_height_2021`; their values are
    combined with np.ravel_multi_index using `index_dims`. The result is returned as a
    DataArray carrying the coords and dims of `chunk`.
    """
    layer_values = [chunk.forest_height_2020.data, chunk.forest_height_2021.data]
    raveled_index = np.ravel_multi_index(layer_values, index_dims)
    return xr.DataArray(raveled_index, coords=chunk.coords, dims=chunk.dims)

In [12]:
# Run without for loop

# you'll want to combine them all into one xarray dataset, which is just a way to group arrays that
# are aligned on the same coordinates
# The keys below are the variable names that consolidate_to_one_dimensional_index reads.
decision_tree_ds = xr.Dataset({
    "forest_height_2020": forest_height_2020, 
    "forest_height_2021": forest_height_2021,
    # "drivers": driver
})

# Bare expression so the notebook renders the dataset summary
decision_tree_ds

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 15.26 MiB 0.95 MiB Shape (1, 4000, 4000) (1, 1000, 1000) Dask graph 16 chunks in 2 graph layers Data type uint8 numpy.ndarray",4000  4000  1,

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 15.26 MiB 0.95 MiB Shape (1, 4000, 4000) (1, 1000, 1000) Dask graph 16 chunks in 2 graph layers Data type uint8 numpy.ndarray",4000  4000  1,

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray


In [13]:
# Run without for loop

# use the function above to map your dataset to that 1D index
# The template describes the *result* of the mapped function. It keeps the coords, chunks and
# attrs of the height layer but declares the integer dtype that np.ravel_multi_index returns
# (np.intp), because index values (up to 40 * 40 - 1 = 1599) do not fit in the uint8 of the inputs.
decision_tree_index = decision_tree_ds.map_blocks(
    consolidate_to_one_dimensional_index,
    template=forest_height_2020.copy(data=forest_height_2020.data.astype(np.intp)),
)
#decision_tree_index.compute()
# decision_tree_index.mean().compute()

In [26]:
# Run without for loop

# Save the decision tree index array to raster
# uint16 is used because index values can reach 40 * 40 - 1 = 1599 with index_dims = (40, 40)
decision_tree_index.rio.to_raster(f'{local_out_dir}decision_tree_index_2021__20230811_not_for_loop_new_v2.tif', compress='DEFLATE', dtype='uint16')

In [15]:
# Run without for loop

# find all unique index values across the tile
# .compute() materialises the unique values directly as a NumPy array, so a separate
# compute_chunk_sizes() pass (which would execute the whole graph an extra time) is not needed.
unique_indices = dar.unique(decision_tree_index.data).compute()
print(unique_indices)

[   0    3    4    5  120  123  124  125  126  160  163  164  165  166
  167  200  203  204  205  206  207  208  240  243  244  245  246  247
  248  249  280  283  284  285  286  287  288  289  290  320  323  324
  325  326  327  328  329  330  331  360  363  364  365  366  367  368
  369  370  371  372  400  403  404  405  406  407  408  409  410  411
  412  413  440  443  444  445  446  447  448  449  450  451  452  453
  454  480  483  484  485  486  487  488  489  490  491  492  493  494
  495  520  523  524  525  526  527  528  529  530  531  532  533  534
  535  536  560  563  564  565  566  567  568  569  570  571  572  573
  574  575  576  577  600  603  604  605  606  607  608  609  610  611
  612  613  614  615  616  617  618  640  643  644  645  646  647  648
  649  650  651  652  653  654  655  656  657  658  659  680  683  684
  685  686  687  688  689  690  691  692  693  694  695  696  697  698
  699  700  720  723  724  725  726  727  728  729  730  731  732  733
  734 

In [18]:
%%time

# Run without for loop

def classify_index(index):
    """Decode one raveled index back into its per-layer values and classify it to a forest state."""
    layer_values = np.unravel_index(index, index_dims)
    factors = ForestStateDecisionFactors(*layer_values)
    return ForestStateDecisionTree.classify(factors)

# One forest state per unique index, in the same order as unique_indices
forest_states_per_index = list(map(classify_index, unique_indices))
# forest_states_per_index

CPU times: total: 0 ns
Wall time: 1 ms


In [17]:
%%time

# you'll call this function on each unique index to convert the index back to your layer values and classify them to a forest state
# This variant wraps every call in a dask.delayed task. Per the recorded timings it is much slower
# (338 ms wall time) than the plain list comprehension version (1 ms).
# Note that running this cell rebinds classify_index to the delayed version.

@dask.delayed
def classify_index(index):
    # Decode the raveled index into one value per layer, then classify those values
    decisionFactors = ForestStateDecisionFactors(*np.unravel_index(index, index_dims))
    return ForestStateDecisionTree.classify(decisionFactors)

# dask.compute returns a tuple with one forest state per unique index
forest_states_per_index = dask.compute(*[classify_index(index) for index in unique_indices])
# forest_states_per_index

CPU times: total: 156 ms
Wall time: 338 ms


In [None]:
%%time

# np.ndarray has no .apply() method (that is a pandas API), so the .apply() call Justin suggested
# raises AttributeError. np.vectorize is the NumPy way to apply a scalar function to every element.

def classify_index(index):
    """Decode one raveled index back into its per-layer values and classify it to a forest state."""
    decisionFactors = ForestStateDecisionFactors(*np.unravel_index(index, index_dims))
    return ForestStateDecisionTree.classify(decisionFactors)

# otypes is given explicitly so that an empty unique_indices array also works
forest_states_per_index = np.vectorize(classify_index, otypes=[int])(unique_indices)
forest_states_per_index

In [20]:
%%time

# Run without for loop

# swap the indices for their corresponding forest state
# bins and new_values are parallel sequences: forest_states_per_index[i] is the state for unique_indices[i].
# NOTE(review): this presumably relies on every pixel value being exactly one of unique_indices --
# confirm how xrspatial.classify.reclassify interprets `bins`.

forest_states = xrspatial.classify.reclassify(
    decision_tree_index.squeeze("band"),   # drop the "band" dimension before reclassifying
    bins=unique_indices, 
    new_values=forest_states_per_index,
    name="forest_states"
)
# Materialise the classified raster
forest_states_final = forest_states.compute()
# Bare expression so the notebook renders the result
forest_states_final

CPU times: total: 891 ms
Wall time: 17.4 s


In [21]:
# Mean forest state code over the tile (forest_states_final was already computed in the previous cell)
forest_states_final.mean().compute()

In [25]:
# Save the forest state raster; uint8 is enough because the decision tree only returns codes 0-3
forest_states_final.rio.to_raster(f'{local_out_dir}forest_states_2021__20230811_not_for_loop_new_v2.tif', compress='DEFLATE', dtype='uint8')

Code for for loop

In [27]:
# Run with for loop

# Size of each layer's dimension in the raveled index, in stacking order (previous year, current year)
index_dims = (40, 40)

def consolidate_to_one_dimensional_index_for_loop(chunk):
    """Collapse the previous/current year height layers of `chunk` into one raveled index per pixel.

    `chunk` must expose `forest_height_previous_year` and `forest_height_current_year`; their
    values are combined with np.ravel_multi_index using `index_dims`. The result is returned
    as a DataArray carrying the coords and dims of `chunk`.
    """
    layer_values = [chunk.forest_height_previous_year.data, chunk.forest_height_current_year.data]
    raveled_index = np.ravel_multi_index(layer_values, index_dims)
    return xr.DataArray(raveled_index, coords=chunk.coords, dims=chunk.dims)
    

In [28]:
%%time

# Run with for loop

def classify_index(index):
    """Decode one raveled index back into its per-layer values and classify it to a forest state."""
    layer_values = np.unravel_index(index, index_dims)
    factors = ForestStateDecisionFactors(*layer_values)
    return ForestStateDecisionTree.classify(factors)

# forest_states_per_index = [classify_index(index) for index in unique_indices]
# forest_states_per_index

CPU times: total: 0 ns
Wall time: 0 ns


In [43]:
%%time

# Tiles and years to process. Each year is compared against the year before it.
tile_id_list = ["50N_010E"]

years = [2019, 2020, 2021]
# years = [2019]

# Date stamp appended to every output file name
timestr = time.strftime("%Y%m%d")


# you'll call this function on each unique index to convert the index back to your layer values and classify them to a forest state
# (defined once here instead of being re-created on every loop iteration)
def classify_index(index):
    """Decode one raveled index back into its per-layer values and classify it to a forest state."""
    decisionFactors = ForestStateDecisionFactors(*np.unravel_index(index, index_dims))
    return ForestStateDecisionTree.classify(decisionFactors)


for tile_id in tile_id_list:
    
#     driver_uri ='{0}{1}_1deg_tree_cover_loss_driver_processed.tif'.format(general_uri, tile_id)
#     peat_uri = '{0}{1}_1deg_peat_mask_processed.tif'.format(general_uri, tile_id)
#     tclf_uri = '{0}{1}_1deg_tree_cover_loss_fire_processed.tif'.format(general_uri, tile_id)
#     land_cover_uri = ...
#     planted_forest_uri = ...
    
#     driver = get_tile_dataset(driver_uri, name="driver")
#     peat = get_tile_dataset(peat_uri, name="peat")
#     tclf = get_tile_dataset(tclf_uri, name="tclf")
#     land_cover = get_tile_dataset(tclf_uri, name="land_cover_uri")
#     planted_forest = get_tile_dataset(tclf_uri, name="planted_forest_uri")

    for year in years:
        
        start = time.time()
              
        current_year = year
        previous_year = current_year - 1
        print("Previous year: " + str(previous_year))
        print("Current year: "+ str(current_year))
               
        forest_height_previous_year_uri = f'{general_uri}202307_revision/test_1x1_deg/{tile_id}_1deg_FH_{previous_year}.tif'
        forest_height_current_year_uri = f'{general_uri}202307_revision/test_1x1_deg/{tile_id}_1deg_FH_{current_year}.tif'
        forest_loss_detection_uri = f'{general_uri}202307_revision/test_1x1_deg/{tile_id}_1deg_DFL_{current_year}.tif'
        
        print(forest_height_previous_year_uri)
        print(forest_height_current_year_uri)
        print(forest_loss_detection_uri)
               
        forest_height_previous_year = get_tile_dataset(forest_height_previous_year_uri, name="forest_height_previous_year")
        forest_height_current_year = get_tile_dataset(forest_height_current_year_uri, name="forest_height_current_year")
        # Opened under its own name (not the height layer's name); not yet used by the decision tree below
        forest_loss_detection_current_year = get_tile_dataset(forest_loss_detection_uri, name="forest_loss_detection_current_year")
        
        print(forest_height_previous_year.mean().compute())
        print(forest_height_current_year.mean().compute())

          
        # you'll want to combine them all into one xarray dataset, which is just a way to group arrays that
        # are aligned on the same coordinates
        decision_tree_ds = xr.Dataset({
            "forest_height_previous_year": forest_height_previous_year, 
            "forest_height_current_year": forest_height_current_year
        })
        
        # use the function above to map your dataset to that 1D index
        # The template describes the *result* of the mapped function: it keeps the coords, chunks and attrs
        # of the height layer but declares the integer dtype that np.ravel_multi_index returns (np.intp).
        decision_tree_index = decision_tree_ds.map_blocks(
            consolidate_to_one_dimensional_index_for_loop,
            template=forest_height_previous_year.copy(data=forest_height_previous_year.data.astype(np.intp)),
        )
        # Materialise the index once; every later step reuses this in-memory array instead of
        # re-running the lazy graph (and re-reading the inputs) again.
        decision_tree_index_matrix = decision_tree_index.compute()
        # print("decision_tree_index")
        # print(decision_tree_index_matrix)
        
        # Save the decision tree index array to raster
        decision_tree_index_matrix.rio.to_raster(f'{local_out_dir}{tile_id}_1deg_decision_tree_index_{current_year}__{timestr}.tif', compress='DEFLATE', dtype='uint16')

        # find all unique index values across the tile (sorted, as np.unique returns them)
        unique_indices = np.unique(decision_tree_index_matrix.data)
        # print("Unique indices:")
        # print(unique_indices)
        
        # One forest state per unique index, in the same order as unique_indices
        forest_states_per_index = [classify_index(index) for index in unique_indices]

        # print("forest_states_per_index")
        # print(forest_states_per_index)
        
        # swap the indices for their corresponding forest state
        forest_states = xrspatial.classify.reclassify(
            decision_tree_index_matrix.squeeze("band"),
            bins=unique_indices, 
            new_values=forest_states_per_index,
            name="forest_states"
        )
        forest_states_final = forest_states.compute()

        
        forest_states_final.rio.to_raster(f'{local_out_dir}{tile_id}_1deg_FH_forest_states_{current_year}__{timestr}.tif', compress='DEFLATE', dtype='uint8')
        
        end = time.time()
        print("Elapsed time to generate forest states in tile_id {0} for {1} : ".format(tile_id, current_year), round(end-start))

#         # # now you can also swap your forest states for emissions/removals factors
#         # emissions = (forest_states[forest_states == 1] * (full_biomass + non_co2_peat + non_co2_fire)) + 
#         #                 (forest_states[forest_states == 2] * partial_biomass + non_co2_peat + non_co2_fire) + ...

Previous year: 2018
Current year: 2019
s3://gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/test_1x1_deg/50N_010E_1deg_FH_2018.tif
s3://gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/test_1x1_deg/50N_010E_1deg_FH_2019.tif
s3://gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/test_1x1_deg/50N_010E_1deg_DFL_2019.tif
<xarray.DataArray 'forest_height_previous_year' ()>
array(6.21758581)
Coordinates:
    spatial_ref  int32 0
<xarray.DataArray 'forest_height_current_year' ()>
array(6.19990044)
Coordinates:
    spatial_ref  int32 0
Elapsed time to generate forest states in tile_id 50N_010E for 2019 :  54
Previous year: 2019
Current year: 2020
s3://gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/test_1x1_deg/50N_010E_1deg_FH_2019.tif
s3://gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/test_1x1_deg/50N_010E_1deg_FH_2020.tif
s3://gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/test_1x1_deg/50N_010E_1