In [1]:
import os
os.environ['USE_PYGEOS'] = '0'   # Suppresses some warning about geopandas
import geopandas as gpd

# scipy basics
import numpy as np
import botocore
from osgeo import gdal      # Necessary to do this import to get rasterio to import
import rasterio as rio
import rasterio.features

import time

# dask/parallelization libraries
import coiled
import dask
import dask.array as dar
from dask.distributed import Client, LocalCluster
import rioxarray
import xarray as xr
import xrspatial.local

# Suppresses the BokehUserWarnings
# https://discourse.bokeh.org/t/how-to-silence-bokeh-warnings/2491/7
from bokeh.util.warnings import BokehUserWarning, warnings 
warnings.simplefilter(action='ignore', category=BokehUserWarning)

<font size="6">Making cloud and local clusters</font> 

In [None]:
# Cloud (Coiled) cluster. All options are keyword arguments, so their order is irrelevant.
coiled_cluster = coiled.Cluster(
    account='jterry64',   # Necessary to use the AWS environment that Justin set up in Coiled
    n_workers=10,
    compute_purchase_option="spot_with_fallback",
    use_best_zone=True,
    idle_timeout="20 minutes",
    # name="DGibbs Europe height flux model",
)

In [None]:
# Coiled cluster (cloud run)
# Client connected to the Coiled cluster created in the previous cell; the bare name on the
# last line displays the client's dashboard summary.
coiled_client = coiled_cluster.get_client()
coiled_client

In [None]:
# # Local cluster (local run). Doesn't work: the .compute() method kills the workers for unknown reasons, so it can't be used for now.
# # local_cluster = LocalCluster(silence_logs=False)
# local_cluster = LocalCluster()
# local_client = Client(local_cluster)

In [None]:
# Local single-process cluster (local run). Will run .compute() on just one process, not a whole cluster.
# NOTE(review): dask treats the most recently created Client as the default one. If this cell
# runs after the Coiled client cell, later .compute() calls presumably run here rather than on
# Coiled — confirm which client is intended for the analysis.
local_client = Client(processes=False)
local_client

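<font size="6">Shutting down cloud and local clusters</font> (run these cells manually once the analysis below is finished; under Run All they would execute before it)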

In [None]:
# Shuts down the Coiled cluster.
# NOTE(review): this cell sits above the Analysis section, so Restart & Run All would shut the
# cluster down before any computation runs. Execute it manually when the analysis is done.
coiled_cluster.shutdown()

In [None]:
# Shuts down the local client.
# NOTE(review): like the Coiled shutdown cell, this runs before the Analysis section under
# Restart & Run All. Execute it manually when the analysis is done.
local_client.shutdown()

<font size="6">Analysis</font> 

<font size="4">General paths and functions</font>

In [2]:
# General paths and constants

# S3 prefix of the GLAD Europe forest height inputs
general_uri = 's3://gfw2-data/forest_change/GLAD_Europe_height_data/'

# Local (Windows) output folder. Output filenames are appended directly (f'{local_out_dir}...'),
# so the trailing backslash is required.
local_out_dir = 'C:\\GIS\\Carbon_model_Europe\\outputs\\'

# Date stamp (YYYYMMDD) added to output filenames
timestr = time.strftime("%Y%m%d")

# Edge length of the area clipped from each tile. get_tile_dataset caps values above 10.
tile_size = 10      # Tile size is from the top left of the tile

In [None]:
# Reads tile
# From https://notebooks-staging.wri.org/user/dagibbs22/lab/tree/msims/biodiversity_global_stats.ipynb
# Bounding box use comes from https://github.com/corteva/rioxarray/issues/115#issuecomment-1206673437 and https://corteva.github.io/rioxarray/html/examples/clip_box.html.

def get_tile_dataset(uri, name, template=None, tile_size=10):
    """Open a raster with dask chunks and clip it to a square anchored at its top-left corner.

    Parameters
    ----------
    uri : str
        Raster location passed to ``rioxarray.open_rasterio``.
    name : str
        ``default_name`` for the opened DataArray.
    template : xarray.DataArray, optional
        If the raster cannot be opened (``RasterioIOError``), an all-zeros array shaped like
        ``template`` is returned. Without a template, the error is re-raised.
    tile_size : number, default 10
        Edge length of the clipped square, in the raster's coordinate units, measured from
        the top-left corner. Values above 10 (the standard tile size) are capped at 10.
    """
    # Never clip to more than the standard 10-unit tile.
    tile_size = min(tile_size, 10)
    try:
        raster = rioxarray.open_rasterio(uri, chunks=1000, default_name=name)
        left, _, _, top = raster.rio.bounds()
        # Fixed boxes used during testing (minx, miny, maxx, maxy):
        #   5x5 deg: (10, 45, 15, 50)   2x2 deg: (10, 48, 12, 50)   1x1 deg: (10, 49, 11, 50)
        return raster.rio.clip_box(
            minx=left,
            miny=top - tile_size,
            maxx=left + tile_size,
            maxy=top,
        )
    except rasterio.errors.RasterioIOError as e:
        if template is None:
            raise e
        return xr.zeros_like(template)

In [None]:
def index_dims(forest_height_previous, forest_height_current, forest_loss_detection, driver, planted_forest_type, peat, tclf):
    """Return the number of possible values of each layer (its maximum + 1).

    The result is meant as the ``dims`` argument of ``np.ravel_multi_index`` /
    ``np.unravel_index``. Every value 0..max of a layer is then a valid coordinate.
    The layers are returned in argument order, which must match the order in which
    ``consolidate_to_one_dimensional_index`` stacks them.

    Each argument must support ``.max().compute().values`` (e.g. a dask-backed xarray
    DataArray); every maximum triggers a full computation of that layer.

    NOTE: a later cell rebinds the name ``index_dims`` to a fixed tuple of dims, which
    shadows this function once that cell has run.

    Returns
    -------
    tuple of int
        One dimension per layer. Values are cast to ``int`` because ``np.ravel_multi_index``
        rejects non-integer dims (e.g. when a raster is stored as float).
    """
    layers = (forest_height_previous, forest_height_current, forest_loss_detection,
              driver, planted_forest_type, peat, tclf)
    return tuple(int(layer.max().compute().values.item(0)) + 1 for layer in layers)

In [None]:
# Combines the values from all the inputs into a one-dimensional index of unique values for every combination of input layer pixel values

# # Maximum value for each input layer in the order in which they will be stacked (so, must match order in consolidate_to_one_dimensional_index)
# index_dims = (40, 40, 2, 6, 5, 2, 25)

# Turns the multidimensional array into a 1-D array in which each combination of input values results in a unique index
def consolidate_to_one_dimensional_index(chunk, index_shape=None):
    """Encode each pixel's combination of the seven layer values as one integer.

    Parameters
    ----------
    chunk : xarray.Dataset
        One block (as passed by ``Dataset.map_blocks``) holding the seven input layers.
    index_shape : tuple of int, optional
        Dims for ``np.ravel_multi_index``, one per layer in the stacking order below.
        Compute them once (e.g. with ``index_dims(...)``) and pass them via
        ``map_blocks(..., kwargs={"index_shape": dims})``. The dims must be identical for
        every chunk, otherwise indices from different chunks would not decode consistently.
        When omitted, the dims are derived from the notebook-level layers. If ``index_dims``
        is still the function, this recomputes the maximum of all seven full layers for
        every chunk, which is slow.

    Returns
    -------
    xarray.DataArray
        Integer index carrying the chunk's coords and dims.
    """
    if index_shape is None:
        if callable(index_dims):
            index_shape = index_dims(forest_height_previous, forest_height_current, forest_loss_detection,
                                     driver, planted_forest_type, peat, tclf)
        else:
            # Another cell has rebound `index_dims` to a fixed tuple of dims; use it directly.
            index_shape = index_dims
    stacked = [chunk.forest_height_previous.data,
               chunk.forest_height_current.data,
               chunk.forest_loss_detection.data,
               chunk.driver.data,
               chunk.planted_forest_type.data,
               chunk.peat.data,
               chunk.tclf.data]
    return xr.DataArray(np.ravel_multi_index(stacked, index_shape),
                        coords=chunk.coords, dims=chunk.dims)

In [None]:
# Combines the values from all the inputs into a one-dimensional index of unique values for every combination of input layer pixel values

# Number of possible values (maximum value + 1) for each input layer, in the order in which they
# are stacked (so, must match the order in consolidate_to_one_dimensional_index).
# Every pixel value must be smaller than its entry, or np.ravel_multi_index raises ValueError.
# The product, 4,800,000, is the number of distinct indices, so the index does not fit in 16 bits.
# NOTE: this rebinds the name `index_dims`, shadowing the index_dims() function from an earlier cell.
index_dims = (
    40,   # forest_height_previous
    40,   # forest_height_current
    2,    # forest_loss_detection
    6,    # driver
    5,    # planted_forest_type
    2,    # peat
    25,   # tclf
)

# Turns the multidimensional array into a 1-D array in which each combination of input values results in a unique index
def consolidate_to_one_dimensional_index(chunk, index_shape=None):
    """Encode each pixel's combination of the seven layer values as one integer.

    Parameters
    ----------
    chunk : xarray.Dataset
        One block (as passed by ``Dataset.map_blocks``) holding the seven input layers.
    index_shape : tuple of int, optional
        Dims for ``np.ravel_multi_index``, one per layer in the stacking order below.
        Defaults to the notebook-level ``index_dims`` tuple, looked up at call time.
        To use measured dims, pass them with ``map_blocks(..., kwargs={"index_shape": dims})``.

    Returns
    -------
    xarray.DataArray
        Integer index carrying the chunk's coords and dims.
    """
    if index_shape is None:
        index_shape = index_dims
    stacked = [chunk.forest_height_previous.data,
               chunk.forest_height_current.data,
               chunk.forest_loss_detection.data,
               chunk.driver.data,
               chunk.planted_forest_type.data,
               chunk.peat.data,
               chunk.tclf.data]
    return xr.DataArray(np.ravel_multi_index(stacked, index_shape),
                        coords=chunk.coords, dims=chunk.dims)

In [None]:
# Makes a class that encapsulates all the pixel data passed into the decision tree for classification.
# This makes it much easier to read how the decision tree uses the pixel data.
# Also a place to do any modifications to the input data or add additional fields based off of the input data.

class ForestStateDecisionFactors:
    """Named container for one pixel's input-layer values.

    The values are stored unchanged under attribute names that mirror the constructor
    arguments. The argument order matches the layer stacking order of the 1-D index, so
    ``ForestStateDecisionFactors(*np.unravel_index(index, dims))`` assigns each decoded
    value to the right attribute.
    """

    def __init__(self, height_prev_year, height_this_year, forest_loss_detection, driver, planted_forest_type, peat, tclf):
        # Forest height in the previous and the current year
        self.height_prev_year, self.height_this_year = height_prev_year, height_this_year
        # Loss detection flag and loss driver
        self.forest_loss_detection, self.driver = forest_loss_detection, driver
        # Remaining context layers
        self.planted_forest_type, self.peat, self.tclf = planted_forest_type, peat, tclf

In [None]:
# Decision tree for assigning forest classes.
# Takes the above decision factors class, which includes all the data in input layers at one pixel, and classifies it using the decision tree logic.

class ForestStateDecisionTree:
    """Height-only decision tree; a pixel counts as forest when its height is >= 5.

    Classes: 1 = maintained, 2 = loss, 3 = gain, 0 = no forest.
    ``classify`` takes no ``self``, so call it on the class:
    ``ForestStateDecisionTree.classify(factors)``.
    """
    # define in a way similar to how we went over with Gary
    def classify(forestStateDecisionFactors):
        threshold = 5
        prev_height = forestStateDecisionFactors.height_prev_year
        this_height = forestStateDecisionFactors.height_this_year
        # Both ">=" and "<" are evaluated explicitly (rather than using `not`), so values
        # for which both comparisons are False (e.g. NaN) still fall through to "no forest".
        was_tall, was_short = prev_height >= threshold, prev_height < threshold
        is_tall, is_short = this_height >= threshold, this_height < threshold
        if was_tall and is_tall:      # maintained
            return 1
        if was_tall and is_short:     # loss
            return 2
        if was_short and is_tall:     # gain
            return 3
        return 0                      # no forest

In [None]:
# Decision tree for assigning forest classes.
# Takes the above decision factors class, which includes all the data in input layers at one pixel, and classifies it using the decision tree logic.

class ForestStateDecisionTree:
    """Decision tree assigning a forest-state class to one pixel's decision factors.

    A pixel counts as forest when its height is >= 5. Classes, checked in this order:
      1    maintained: forest in both years
      2    gain: not forest last year, forest this year
      3    loss (forest last year but not this year, or forest_loss_detection == 1),
           not in planted forest
      31   loss in planted forest (planted_forest_type > 0), peat != 1
      311  loss in planted forest on peat (peat == 1)
      0    everything else (no forest)
    """
    # define in a way similar to how we went over with Gary
    @staticmethod
    def classify(forestStateDecisionFactors):
        """Classify one pixel. Expects scalar attribute values (e.g. from np.unravel_index)."""
        factors = forestStateDecisionFactors
        if factors.height_prev_year >= 5 and factors.height_this_year >= 5:   # maintained
            return 1
        elif factors.height_prev_year < 5 and factors.height_this_year >= 5:  # gain
            return 2
        # The loss flag comes from this pixel's factors, so the test stays a scalar boolean.
        # A raster-wide array here would need `.all()` and would reduce the whole layer per index.
        elif (factors.height_prev_year >= 5 and factors.height_this_year < 5) or factors.forest_loss_detection == 1:  # loss
            if factors.planted_forest_type > 0:
                return 311 if factors.peat == 1 else 31
            return 3
        else:                                                                 # no forest
            return 0

In [None]:
# Converts 1-D array of unique indexes back to your layer values and classifies them to a forest state
# three versions (the two alternatives below are kept commented out)

def classify_index(index, index_shape=None):
    """Decode one 1-D index back to its layer values and classify it.

    Parameters
    ----------
    index : int
        A value produced by np.ravel_multi_index in consolidate_to_one_dimensional_index.
    index_shape : tuple of int, optional
        Dims used when the index was built. Defaults to the notebook-level ``index_dims``
        tuple, looked up at call time. It must equal the encoding dims, or the decoded
        layer values are wrong.

    Returns
    -------
    The forest-state class returned by ``ForestStateDecisionTree.classify``.
    """
    if index_shape is None:
        index_shape = index_dims
    layer_values = np.unravel_index(index, index_shape)
    return ForestStateDecisionTree.classify(ForestStateDecisionFactors(*layer_values))


# @dask.delayed
# def classify_index(index):
#     decisionFactors = ForestStateDecisionFactors(*np.unravel_index(index, index_dims))
#     return ForestStateDecisionTree.classify(decisionFactors)


# # Does not work because ndarray does not accept .apply(). This was suggested by Justin
# def classify_index(index):
#     decisionFactors = ForestStateDecisionFactors(*np.unravel_index(index, index_dims))
#     return ForestStateDecisionTree.classify(decisionFactors)

<font size="4">Running one tile and year without the for loop</font>

In [None]:
# Input file locations
# Full-extent height inputs:
# forest_height_previous_uri = f'{general_uri}202307_revision/FH_2020.tif'
# forest_height_current_uri = f'{general_uri}202307_revision/FH_2021.tif'
# forest_loss_detection_uri = f'{general_uri}202307_revision/DFL_2021.tif'

# 1x1 degree test inputs:
# forest_height_previous_uri = f'{general_uri}202307_revision/test_1x1_deg/50N_010E_1deg_FH_2020.tif'
# forest_height_current_uri = f'{general_uri}202307_revision/test_1x1_deg/50N_010E_1deg_FH_2021.tif'
# forest_loss_detection_uri = f'{general_uri}202307_revision/test_1x1_deg/50N_010E_1deg_DFL_2020.tif'

# 10x10 degree test inputs for tile 50N_010E (active).
# NOTE(review): loss detection uses DFL_2020 while the current height year is 2021, but the
# full-extent line above uses DFL_2021. Confirm which year's DFL layer is intended.
forest_height_previous_uri = f'{general_uri}202307_revision/test_10x10_deg/50N_010E_FH_2020.tif'
forest_height_current_uri = f'{general_uri}202307_revision/test_10x10_deg/50N_010E_FH_2021.tif'
forest_loss_detection_uri = f'{general_uri}202307_revision/test_10x10_deg/50N_010E_DFL_2020.tif'

# Carbon-model input tiles for 50N_010E: loss drivers, planted forest type, peat, tree cover loss from fires
driver_uri = "s3://gfw2-data/climate/carbon_model/other_emissions_inputs/tree_cover_loss_drivers/processed/drivers_2022/20230407/50N_010E_tree_cover_loss_driver_processed.tif"
planted_forest_type_uri = "s3://gfw2-data/climate/carbon_model/other_emissions_inputs/planted_forest_type/SDPT_v1/standard/20200730/50N_010E_plantation_type_oilpalm_woodfiber_other_unmasked.tif"
peat_uri = "s3://gfw2-data/climate/carbon_model/other_emissions_inputs/peatlands/processed/20230315/50N_010E_peat_mask_processed.tif"
tclf_uri = "s3://gfw2-data/climate/carbon_model/other_emissions_inputs/tree_cover_loss_fires/20230315/processed/50N_010E_tree_cover_loss_fire_processed.tif"

# land_cover_uri = ...

In [None]:
# Reads input files
# Each layer is opened with dask chunks and clipped to tile_size from its top-left corner.
# No template is passed, so a missing or unreadable file raises RasterioIOError instead of
# falling back to an all-zeros layer.
forest_height_previous = get_tile_dataset(forest_height_previous_uri, name="forest_height_previous", tile_size=tile_size)
forest_height_current = get_tile_dataset(forest_height_current_uri, name="forest_height_current", tile_size=tile_size)
forest_loss_detection = get_tile_dataset(forest_loss_detection_uri, name="forest_loss_detection", tile_size=tile_size)

driver = get_tile_dataset(driver_uri, name="driver", tile_size=tile_size)
planted_forest_type = get_tile_dataset(planted_forest_type_uri, name="planted_forest_type", tile_size=tile_size)
peat = get_tile_dataset(peat_uri, name="peat", tile_size=tile_size)
tclf = get_tile_dataset(tclf_uri, name="tclf", tile_size=tile_size)

In [None]:
%%time
# Some sample dask operations: sanity checks on the loaded inputs.
# Each .compute() evaluates the whole clipped layer, so every uncommented line reads the full layer.

print(forest_height_previous.mean().compute().values.item(0))
print(forest_height_current.mean().compute().values.item(0))
# print(forest_loss_detection.mean().compute().values.item(0))
# print(driver.mean().compute().values.item(0))
# print(planted_forest_type.mean().compute().values.item(0))
# print(peat.mean().compute().values.item(0))
# print(tclf.mean().compute().values.item(0))
forest_height_previous   # displays the (lazy) DataArray summary
# print(forest_height_previous.where(driver > 1).mean().compute())

In [None]:
# Combines all inputs into one xarray dataset,
# which is just a way to group arrays that are aligned on the same coordinates.
# The variable names are the attributes consolidate_to_one_dimensional_index reads from each chunk.
decision_tree_ds = xr.Dataset(
    dict(
        forest_height_previous=forest_height_previous,
        forest_height_current=forest_height_current,
        forest_loss_detection=forest_loss_detection,
        driver=driver,
        planted_forest_type=planted_forest_type,
        peat=peat,
        tclf=tclf,
    )
)

decision_tree_ds

In [None]:
%%time

# Uses the function above to map the dataset to 1D index

decision_tree_index = decision_tree_ds.map_blocks(consolidate_to_one_dimensional_index, template=forest_height_previous.copy())
decision_tree_index.compute()
# decision_tree_index.mean().compute()

In [None]:
# Saves the decision tree index array to raster.
# Index values go up to prod(index_dims) - 1 (4,799,999 for the fixed dims), which overflows
# uint16 (max 65,535), so the raster is written as uint32.

decision_tree_index.rio.to_raster(f'{local_out_dir}decision_tree_index_2021__{timestr}_not_for_loop.tif', compress='DEFLATE', dtype='uint32')

In [None]:
%%time

# Finds all unique index values across the tile

unique_indices = dar.unique(decision_tree_index.data)
unique_indices.compute_chunk_sizes()   #Added to resolve unknown chunk size error: https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks
unique_indices = unique_indices.compute()
# print(unique_indices)

In [None]:
%%time
# Converts the array of unique index values back to forest states

forest_states_per_index = [classify_index(index) for index in unique_indices]
# forest_states_per_index

In [None]:
%%time
# Swaps the unique indices for their corresponding forest states
# NOTE(review): this relies on xrspatial's reclassify treating `bins` as upper bin edges.
# Because the bins are exactly the (presumably sorted) unique index values, each index would then
# map to its own entry of forest_states_per_index. Confirm against the xrspatial reclassify docs.

forest_states = xrspatial.classify.reclassify(
    decision_tree_index.squeeze("band"),
    bins=unique_indices, 
    new_values=forest_states_per_index,
    name="forest_states"
)
forest_states_final = forest_states.compute()
print(forest_states_final)
print(forest_states_final.mean().compute())

In [None]:
# Saves the forest state raster.
# Written as uint16: the decision tree produces classes up to 311, which do not fit in uint8.

forest_states_final.rio.to_raster(f'{local_out_dir}forest_states_2021__{timestr}_{tile_size}_deg__not_for_loop.tif', compress='DEFLATE', dtype='uint16')

<font size="4">Running over tiles and years with the for loop</font>

In [None]:
%%time

tile_id_list = ["50N_010E"]

tile_size = 1

# years = [2019, 2020, 2021]
years = [2021]

for tile_id in tile_id_list:

    driver_uri = "s3://gfw2-data/climate/carbon_model/other_emissions_inputs/tree_cover_loss_drivers/processed/drivers_2022/20230407/50N_010E_tree_cover_loss_driver_processed.tif"
    planted_forest_type_uri = "s3://gfw2-data/climate/carbon_model/other_emissions_inputs/planted_forest_type/SDPT_v1/standard/20200730/50N_010E_plantation_type_oilpalm_woodfiber_other_unmasked.tif"
    peat_uri = "s3://gfw2-data/climate/carbon_model/other_emissions_inputs/peatlands/processed/20230315/50N_010E_peat_mask_processed.tif"
    tclf_uri = "s3://gfw2-data/climate/carbon_model/other_emissions_inputs/tree_cover_loss_fires/20230315/processed/50N_010E_tree_cover_loss_fire_processed.tif"

    driver = get_tile_dataset(driver_uri, name="driver", tile_size)
    planted_forest_type = get_tile_dataset(planted_forest_type_uri, name="planted_forest_type", tile_size)
    peat = get_tile_dataset(peat_uri, name="peat", tile_size)
    tclf = get_tile_dataset(tclf_uri, name="tclf", tile_size)

    for year in years:
        
        start = time.time()
              
        current_year = year
        previous_year = current_year - 1
        print("Previous year: " + str(previous_year))
        print("Current year: "+ str(current_year))

        # forest_height_previous_uri = f'{general_uri}202307_revision/FH_{previous_year}.tif'
        # forest_height_current_uri = f'{general_uri}202307_revision/FH_{current_year}.tif'
        # forest_loss_detection_uri = f'{general_uri}202307_revision/DFL_{current_year}.tif'

        forest_height_previous_uri = f'{general_uri}202307_revision/test_10x10_deg/{tile_id}_FH_{previous_year}.tif'
        forest_height_current_uri = f'{general_uri}202307_revision/test_10x10_deg/{tile_id}_FH_{current_year}.tif'
        forest_loss_detection_uri = f'{general_uri}202307_revision/test_10x10_deg/{tile_id}_DFL_{previous_year}.tif'
                       
        forest_height_previous = get_tile_dataset(forest_height_previous_uri, name="forest_height_previous", tile_size)
        forest_height_current = get_tile_dataset(forest_height_current_uri, name="forest_height_current_year", tile_size)
        forest_loss_detection_previous = get_tile_dataset(forest_loss_detection_uri, name="forest_height_current_year", tile_size)
        
        print("Previous height mean =", forest_height_previous.mean().compute().values.item(0))
        print("Current height mean =",forest_height_current.mean().compute().values.item(0))
        # print(forest_loss_detection.mean().compute().values.item(0))
        # print(driver.mean().compute().values.item(0))
        # print(planted_forest_type.mean().compute().values.item(0))
        # print(peat.mean().compute().values.item(0))
        # print(tclf.mean().compute().values.item(0))

          
        # you'll want to combine them all into one xarray dataset, which is just a way to group arrays that
        # are aligned on the same coordinates
        decision_tree_ds = xr.Dataset({
            "forest_height_previous": forest_height_previous, 
            "forest_height_current": forest_height_current,
            "forest_loss_detection": forest_loss_detection,
            "driver": driver,
            "planted_forest_type": planted_forest_type,
            "peat": peat,
            "tclf": tclf
        })
        
        # use the function above to map your dataset to that 1D index
        print("Making index matrix...")
        decision_tree_index = decision_tree_ds.map_blocks(consolidate_to_one_dimensional_index, template=forest_height_previous.copy())
        
        # # Save the decision tree index array to raster
        decision_tree_index_matrix = decision_tree_index.compute()
        # print("decision_tree_index")
        print(decision_tree_index_matrix)
        # decision_tree_index_matrix.rio.to_raster(f'{local_out_dir}{tile_id}_1deg_decision_tree_index_{current_year}__{timestr}.tif', compress='DEFLATE', dtype='uint16')

        # find all unique index values across the tile
        print("Finding unique indices...")
        unique_indices = dar.unique(decision_tree_index.data)
        unique_indices.compute_chunk_sizes()   #Added to resolve unknown chunk size error: https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks
        unique_indices=unique_indices.compute()
        # print("Unique indices:")
        # print(unique_indices)
        
        print("Converting unique index list to forest state list...")
        forest_states_per_index = [classify_index(index) for index in unique_indices]

        # print("forest_states_per_index")
        # print(forest_states_per_index)

        print("Creating forest state matrix...")
        forest_states = xrspatial.classify.reclassify(
            decision_tree_index.squeeze("band"),
            bins=unique_indices, 
            new_values=forest_states_per_index,
            name="forest_states"
        )
        forest_states_final = forest_states.compute()
        forest_states_final

        print("Saving forest state raster...")
        forest_states_final.rio.to_raster(f'{local_out_dir}{tile_id}_5deg_FH_forest_states_{current_year}__{timestr}.tif', compress='DEFLATE', dtype='uint8')
        
        end = time.time()
        print("Elapsed time to generate forest states in tile_id {0} for {1} : ".format(tile_id, current_year), round(end-start))

#         # # now you can also swap your forest states for emissions/removals factors
#         # emissions = (forest_states[forest_states == 1] * (full_biomass + non_co2_peat + non_co2_fire)) + 
#         #                 (forest_states[forest_states == 2] * partial_biomass + non_co2_peat + non_co2_fire) + ...