In [None]:
import os
os.environ['USE_PYGEOS'] = '0'   # Suppresses some warning about geopandas
import geopandas as gpd

# scipy basics
import pandas as pd
import numpy as np
from shapely.wkb import loads
import botocore
from osgeo import gdal      # Necessary to do this import to get rasterio to import
import rasterio
import rasterio.features

import time

# dask/parallelization libraries
import coiled
import dask
import dask.array as dar
from dask.distributed import Client, LocalCluster
import dask_geopandas
import pystac
import rioxarray
import xarray as xr
import xrspatial.local

<font size="6">Making cloud and local clusters</font> 

In [None]:
# Cloud cluster on Coiled: 20 workers, placed in the best-available zone,
# bought as spot instances with a fallback option, and shut down after
# 20 idle minutes.
coiled_cluster = coiled.Cluster(
    account='jterry64',   # Necessary to use the AWS environment that Justin set up in Coiled
    n_workers=20,
    use_best_zone=True,
    compute_purchase_option="spot_with_fallback",
    idle_timeout="20 minutes",
    # Optional settings kept for reference:
    # worker_vm_types=["t3.medium"],
    # name="DGibbs Europe height flux model",
)

In [None]:
# Coiled cluster (cloud run)
# Attach a Dask client to the Coiled cluster created above; the bare name on the
# last line renders the client summary in the notebook.
coiled_client = coiled_cluster.get_client()
coiled_client

In [None]:
# # Local cluster (local run). Doesn't work -- the .compute() method kills workers for unknown reasons. Can't use for now.
# # Kept for reference only; the single-process client in the next cell is used instead.
# # local_cluster = LocalCluster(silence_logs=False)
# local_cluster = LocalCluster()
# local_client = Client(local_cluster)

In [None]:
# Local single-process cluster (local run). Will run .compute() on just one process, not a whole cluster.
# processes=False keeps the scheduler and workers inside this Python process, a
# workaround for the multi-process LocalCluster problem noted in the previous cell.
# NOTE(review): a newly created Client registers itself as the default Dask
# scheduler, so after running this cell .compute() calls presumably go here
# rather than to coiled_client -- confirm which client is intended.
local_client = Client(processes=False)
local_client

<font size="6">Shutting down cloud and local clusters</font> 

In [None]:
# Shut down the Coiled cloud cluster when finished with it.
coiled_cluster.shutdown()

In [None]:
# Shut down the local single-process client (and its in-process scheduler/workers).
local_client.shutdown()

<font size="6">Analysis</font> 

<font size="4">General paths and utilities</font>

In [42]:
# Run with and without for loop

# S3 prefix under which the GLAD Europe forest-height inputs are stored.
general_uri = 's3://gfw2-data/forest_change/GLAD_Europe_height_data/'

# Folder for local output rasters. Output paths are built by plain string
# concatenation (f'{local_out_dir}<file>.tif'), so the folder must end with a
# path separator. Set the CARBON_MODEL_OUT_DIR environment variable to write
# somewhere else (a trailing separator is appended if missing); otherwise the
# original Windows folder is used.
local_out_dir = os.environ.get('CARBON_MODEL_OUT_DIR', '')
local_out_dir = os.path.join(local_out_dir, '') if local_out_dir else 'C:\\GIS\\Carbon_model_Europe\\outputs\\'

# Run date (YYYYMMDD), used to stamp output file names.
timestr = time.strftime("%Y%m%d")

In [None]:
# Run with and without for loop
# From https://notebooks-staging.wri.org/user/dagibbs22/lab/tree/msims/biodiversity_global_stats.ipynb

def get_tile_dataset(uri, name, template=None, chunks=1000):
    """Lazily open a raster tile as a dask-backed xarray object.

    Parameters
    ----------
    uri : str
        Path or URI of the raster (e.g. an ``s3://`` GeoTIFF).
    name : str
        Passed to ``rioxarray.open_rasterio`` as ``default_name``, the array
        name used when the file itself does not provide one.
    template : xarray object, optional
        If given and the raster cannot be opened, ``xr.zeros_like(template)``
        is returned instead of raising.
    chunks : optional
        Dask chunking passed to ``open_rasterio``; defaults to 1000, the value
        previously hard-coded here.

    Raises
    ------
    rasterio.errors.RasterioIOError
        If the raster cannot be opened and no ``template`` was given.
    """
    try:
        return rioxarray.open_rasterio(uri, chunks=chunks, default_name=name)
    except rasterio.errors.RasterioIOError:
        if template is None:
            # Bare ``raise`` re-raises the active exception with its traceback intact.
            raise
        # NOTE(review): the zero-filled fallback keeps ``template``'s name/attrs, not ``name``.
        return xr.zeros_like(template)

In [None]:
# Run with and without for loop

class ForestStateDecisionFactors:
    """Per-pixel inputs handed to ``ForestStateDecisionTree.classify``.

    Attributes
    ----------
    height_prev_year : forest height in the previous year
    height_this_year : forest height in the current year
    """

    def __init__(self, height_prev_year, height_this_year):
        # Both values are stored unchanged; all classification logic lives in the tree.
        self.height_prev_year, self.height_this_year = height_prev_year, height_this_year

In [None]:
# Run with and without for loop

class ForestStateDecisionTree:
    """Classify one pixel's forest state from two years of forest height.

    A pixel counts as forest in a year when its height is at least
    ``HEIGHT_THRESHOLD`` (NOTE(review): presumably metres -- confirm against the
    GLAD height product). Codes returned by ``classify``:

    * 0 (``NO_FOREST``)  -- not forest in either year
    * 1 (``MAINTAINED``) -- forest in both years
    * 2 (``LOSS``)       -- forest last year, not this year
    * 3 (``GAIN``)       -- not forest last year, forest this year
    """
    # define in a way similar to how we went over with Gary

    HEIGHT_THRESHOLD = 5

    NO_FOREST = 0
    MAINTAINED = 1
    LOSS = 2
    GAIN = 3

    @staticmethod
    def classify(forestStateDecisionFactors):
        """Return the forest-state code (0-3) for one ``ForestStateDecisionFactors``.

        Declared ``@staticmethod`` so it works both as
        ``ForestStateDecisionTree.classify(f)`` (the existing call style) and on
        an instance; without it an instance call would pass the instance itself
        as the factors argument.
        """
        prev = forestStateDecisionFactors.height_prev_year
        curr = forestStateDecisionFactors.height_this_year
        threshold = ForestStateDecisionTree.HEIGHT_THRESHOLD
        # Explicit >= / < pairs (instead of two precomputed booleans) preserve the
        # original behaviour for non-comparable values such as NaN: every test is
        # False, so they fall through to NO_FOREST.
        if prev >= threshold and curr >= threshold:
            return ForestStateDecisionTree.MAINTAINED
        if prev >= threshold and curr < threshold:
            return ForestStateDecisionTree.LOSS
        if prev < threshold and curr >= threshold:
            return ForestStateDecisionTree.GAIN
        return ForestStateDecisionTree.NO_FOREST

In [61]:
# Run without for loop

# Each height is assumed to lie in [0, 40), so a (previous, current) height pair
# maps to one of 40 * 40 = 1600 flat indices.
index_dims = (40, 40)

def consolidate_to_one_dimensional_index(chunk):
    """Collapse the two height layers of one Dataset block into a single index array.

    Meant for ``Dataset.map_blocks``: ``chunk`` is a block holding the
    ``forest_height_previous`` and ``forest_height_current`` variables. Each
    pixel's (previous, current) pair becomes
    ``previous * index_dims[1] + current`` via ``np.ravel_multi_index``. Any
    height outside the ``index_dims`` range makes ``np.ravel_multi_index`` raise
    ``ValueError``.

    Returns a DataArray carrying the block's coords and dims.
    """
    prev_heights = chunk.forest_height_previous.data
    curr_heights = chunk.forest_height_current.data
    flat_index = np.ravel_multi_index([prev_heights, curr_heights], index_dims)
    return xr.DataArray(flat_index, coords=chunk.coords, dims=chunk.dims)

<font size="4">Running without a for loop</font>

In [43]:
# Run without for loop
# Input rasters for the single 1x1 degree test tile 50N_010E: forest height for the
# previous (2020) and current (2021) years, plus the 2020 "DFL" layer
# (NOTE(review): presumably detected forest loss -- confirm).
forest_height_previous_uri = f'{general_uri}202307_revision/test_1x1_deg/50N_010E_1deg_FH_2020.tif'
forest_height_current_uri = f'{general_uri}202307_revision/test_1x1_deg/50N_010E_1deg_FH_2021.tif'
forest_loss_detection_uri = f'{general_uri}202307_revision/test_1x1_deg/50N_010E_1deg_DFL_2020.tif'

# Ancillary layers for the same tile (loss driver, peat mask, tree cover loss from
# fire), from the 202306 resubmission folder. In this section they are only used
# by commented-out code.
driver_uri = "s3://gfw2-data/forest_change/GLAD_Europe_resubmission__202306/test_1x1_deg/50N_010E_1deg_tree_cover_loss_driver_processed.tif"
peat_uri = "s3://gfw2-data/forest_change/GLAD_Europe_resubmission__202306/test_1x1_deg/50N_010E_1deg_peat_mask_processed.tif"
tclf_uri = "s3://gfw2-data/forest_change/GLAD_Europe_resubmission__202306/test_1x1_deg/50N_010E_1deg_tree_cover_loss_fire_processed.tif"
# land_cover_uri = ...
# planted_forest_uri = ...

In [44]:
# Run without for loop
# Lazily open the input rasters as dask-backed arrays via get_tile_dataset.
forest_height_previous = get_tile_dataset(forest_height_previous_uri, name="forest_height_2020")
forest_height_current = get_tile_dataset(forest_height_current_uri, name="forest_height_2021")
# default_name must describe this layer (the 2020 DFL raster), not a forest height;
# "forest_height_current_year" made it print as a height layer.
forest_loss_detection_previous = get_tile_dataset(forest_loss_detection_uri, name="forest_loss_detection_2020")

# Ancillary layers, not used yet:
# driver = get_tile_dataset(driver_uri, name="driver")
# peat = get_tile_dataset(peat_uri, name="peat")
# TCLF = get_tile_dataset(tclf_uri, name="tclf")

In [45]:
%%time

# Run without for loop
# Some sample dask operations
# Quick sanity check: print the mean of each input layer. Each .compute() call
# triggers its own dask computation over the lazily opened raster.
# print(forest_height_previous)
print(forest_height_previous.mean().compute())
print(forest_height_current.mean().compute())
print(forest_loss_detection_previous.mean().compute())
# forest_height_current.mean().compute()
# print(forest_height_current.mean().compute())
# print(driver)
# print(driver.mean().compute())
# print(forest_height_previous.where(driver > 1).mean().compute())

<xarray.DataArray 'forest_height_2020' ()>
array(6.19260731)
Coordinates:
    spatial_ref  int32 0
<xarray.DataArray 'forest_height_2021' ()>
array(6.18374494)
Coordinates:
    spatial_ref  int32 0
<xarray.DataArray 'forest_height_current_year' ()>
array(0.00160088)
Coordinates:
    spatial_ref  int32 0
CPU times: total: 344 ms
Wall time: 288 ms


In [62]:
# Run without for loop

# Bundle the layers into one xarray Dataset (arrays aligned on the same
# coordinates) so map_blocks can hand each block of all layers to one function.
decision_tree_ds = xr.Dataset(
    data_vars=dict(
        forest_height_previous=forest_height_previous,
        forest_height_current=forest_height_current,
        # drivers=driver,
    )
)

decision_tree_ds

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 15.26 MiB 0.95 MiB Shape (1, 4000, 4000) (1, 1000, 1000) Dask graph 16 chunks in 2 graph layers Data type uint8 numpy.ndarray",4000  4000  1,

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 15.26 MiB 0.95 MiB Shape (1, 4000, 4000) (1, 1000, 1000) Dask graph 16 chunks in 2 graph layers Data type uint8 numpy.ndarray",4000  4000  1,

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray


In [63]:
# Run without for loop

# use the function above to map your dataset to that 1D index.
# The template tells map_blocks what the result looks like without running it:
# same dims/coords/chunks as the height layers, but with the dtype that
# np.ravel_multi_index actually returns (np.intp). Copying the uint8 height layer
# as-is would advertise uint8 although the indices go up to
# index_dims[0] * index_dims[1] - 1 (1599 for (40, 40)).
decision_tree_index = decision_tree_ds.map_blocks(consolidate_to_one_dimensional_index, template=forest_height_previous.astype(np.intp))
# decision_tree_index.compute()
# decision_tree_index.mean().compute()

In [64]:
# Run without for loop

# Write the per-pixel decision-tree index to a DEFLATE-compressed GeoTIFF.
# The indices fit in uint16 (maximum 1599 for index_dims (40, 40)).
decision_tree_index.rio.to_raster(
    f'{local_out_dir}decision_tree_index_2021__{timestr}_not_for_loop.tif',
    compress='DEFLATE',
    dtype='uint16',
)

In [68]:
# Run without for loop

# find all unique index values across the tile
unique_indices = dar.unique(decision_tree_index.data)
unique_indices.compute_chunk_sizes()   #Added to resolve unknown chunk size error: https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks
# NOTE(review): compute_chunk_sizes() evaluates the whole graph once and .compute()
# below evaluates it again. If the unknown-chunk error does not come back without
# it, the line above could be dropped to halve the work.
unique_indices = unique_indices.compute()   # in-memory array of unique indices (presumably sorted, as np.unique returns; reclassify below relies on that)
# print(unique_indices)

In [67]:
%%time

# Run without for loop
# you'll call this function on each unique index to convert the index back to your layer values and classify them to a forest state

def classify_index(index):
    """Decode one flat index back to its two heights and classify it.

    ``index`` is a value produced by ``np.ravel_multi_index`` with
    ``index_dims``. ``np.unravel_index`` recovers the (previous, current)
    heights, which are wrapped in ``ForestStateDecisionFactors`` and passed to
    ``ForestStateDecisionTree.classify``. Returns that forest-state code.
    """
    decisionFactors = ForestStateDecisionFactors(*np.unravel_index(index, index_dims))
    return ForestStateDecisionTree.classify(decisionFactors)

# Classify each distinct index once (at most index_dims[0] * index_dims[1] values)
# rather than every pixel; reclassify later maps the results back onto the raster.
# Alternatives tried earlier and dropped: a @dask.delayed classify_index evaluated
# with dask.compute(...), and unique_indices.apply(classify_index), which does not
# work because a NumPy ndarray has no .apply().
forest_states_per_index = [classify_index(index) for index in unique_indices]

CPU times: total: 0 ns
Wall time: 9 ms


In [69]:
%%time

# Run without for loop

# swap the indices for their corresponding forest state
# NOTE(review): this relies on reclassify placing each value in the bin whose edge
# equals it (bins = the sorted unique indices, new_values aligned position-by-position);
# confirm against the xrspatial.classify.reclassify documentation.
# NOTE(review): only xrspatial.local is imported explicitly; xrspatial.classify is
# presumably loaded by the xrspatial package import -- confirm.

forest_states = xrspatial.classify.reclassify(
    decision_tree_index.squeeze("band"),   # drop the length-1 "band" dimension
    bins=unique_indices, 
    new_values=forest_states_per_index,
    name="forest_states"
)
forest_states_final = forest_states.compute()   # load the classified raster into memory
forest_states_final

CPU times: total: 2.89 s
Wall time: 792 ms


In [None]:
# forest_states_final was already loaded into memory by .compute() when it was
# created, so a further .compute() here would be redundant.
forest_states_final.mean()

In [70]:
# Write the classified forest states (codes 0-3, so uint8 suffices) as a
# DEFLATE-compressed GeoTIFF.
forest_states_final.rio.to_raster(
    f'{local_out_dir}forest_states_2021__{timestr}_not_for_loop.tif',
    dtype='uint8',
    compress='DEFLATE',
)

<font size="4">Running with a for loop</font>

In [None]:
# Run with for loop

# Same encoding as the "without for loop" section: each height assumed in [0, 40).
index_dims = (40, 40)

def consolidate_to_one_dimensional_index_for_loop(chunk):
    """Encode each pixel's (previous, current) height pair in a Dataset block as one flat index.

    Same computation as ``consolidate_to_one_dimensional_index`` in the
    "without for loop" section, presumably duplicated so this section can run on
    its own. Heights outside the ``index_dims`` range make
    ``np.ravel_multi_index`` raise ``ValueError``.

    Returns a DataArray carrying the block's coords and dims.
    """
    layers = (chunk.forest_height_previous, chunk.forest_height_current)
    flat_index = np.ravel_multi_index([layer.data for layer in layers], index_dims)
    return xr.DataArray(flat_index, coords=chunk.coords, dims=chunk.dims)

In [None]:
%%time

# Run with for loop

def classify_index(index):
    """Decode a flat height-pair index and return its forest-state code.

    NOTE(review): identical to the classify_index in the "without for loop"
    section; whichever cell ran last defines the name. Presumably kept so this
    section can run on its own.
    """
    decoded_heights = np.unravel_index(index, index_dims)
    factors = ForestStateDecisionFactors(*decoded_heights)
    return ForestStateDecisionTree.classify(factors)

# forest_states_per_index = [classify_index(index) for index in unique_indices]
# forest_states_per_index