In [1]:
import os
os.environ['USE_PYGEOS'] = '0'   # Suppresses some warning about geopandas
import geopandas as gpd

# scipy basics
import numpy as np
import botocore
from osgeo import gdal      # Necessary to do this import to get rasterio to import
import rasterio as rio
import rasterio.features

import time

# dask/parallelization libraries
import coiled
import dask
import dask.array as dar
from dask.distributed import Client, LocalCluster
import rioxarray
import xarray as xr
import xrspatial.local

<font size="6">Making cloud and local clusters</font> 

In [34]:
coiled_cluster = coiled.Cluster(
    n_workers=5,
    use_best_zone=True, 
    compute_purchase_option="spot_with_fallback",
    idle_timeout="20 minutes",
    # name="DGibbs Europe height flux model", 
    account='jterry64'   # Necessary to use the AWS environment that Justin set up in Coiled
)

Output()

Output()

In [35]:
# Coiled cluster (cloud run)
# Dask client connected to the Coiled cluster. As the cell's last expression it renders the
# client summary (dashboard link, worker/thread/memory counts).
coiled_client = coiled_cluster.get_client()
coiled_client

0,1
Connection method: Cluster object,Cluster type: coiled.Cluster
Dashboard: https://cluster-vmmpp.dask.host/dWhXJk2CoNkCwZbG/status,

0,1
Dashboard: https://cluster-vmmpp.dask.host/dWhXJk2CoNkCwZbG/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tls://10.1.32.121:8786,Workers: 0
Dashboard: http://10.1.32.121:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [None]:
# # Local cluster (local run). Doesn't work: the .compute() method kills workers for unknown reasons. Can't use for now.
# # local_cluster = LocalCluster(silence_logs=False)
# local_cluster = LocalCluster()
# local_client = Client(local_cluster)

In [None]:
# Local single-process cluster (local run). Will run .compute() on just one process, not a whole cluster.
# processes=False: the scheduler and workers run inside this Python process instead of separate worker processes.
# Used as a fallback because the multi-process LocalCluster above kills workers on .compute().
local_client = Client(processes=False)
local_client

<font size="6">Shutting down cloud and local clusters</font> 

In [103]:
# Close the Dask client first so it does not keep trying to reconnect to a scheduler
# that is being torn down, then shut down the Coiled cluster itself.
coiled_client.close()
coiled_cluster.shutdown()

In [None]:
# Shuts down the scheduler and workers that local_client is connected to (the in-process cluster created above).
local_client.shutdown()

<font size="6">Analysis</font> 

<font size="4">Paths and functions</font>

In [30]:
# General paths and constants

# Root S3 folder for the GLAD Europe forest height inputs
general_uri = 's3://gfw2-data/forest_change/GLAD_Europe_height_data/'

# Derived from general_uri so the two locations cannot drift apart
random_data_uri = f'{general_uri}dummy_random_data__20230901/'

# Local output folder. Set the CARBON_MODEL_EUROPE_OUT_DIR environment variable to write elsewhere.
# The value must end with a path separator because output file names are appended directly to it.
local_out_dir = os.environ.get('CARBON_MODEL_EUROPE_OUT_DIR', 'C:\\GIS\\Carbon_model_Europe\\outputs\\')

# Date stamp (YYYYMMDD) used in output file names
timestr = time.strftime("%Y%m%d")

# Dask chunk size passed to rioxarray.open_rasterio when reading input rasters
chunk_length = 800

In [5]:
# Reads tile
# From https://notebooks-staging.wri.org/user/dagibbs22/lab/tree/msims/biodiversity_global_stats.ipynb
# Bounding box use comes from https://github.com/corteva/rioxarray/issues/115#issuecomment-1206673437 and https://corteva.github.io/rioxarray/html/examples/clip_box.html. 

# Tile size is from the top left of the tile
def get_tile_dataset(uri, name, template=None, tile_size=10, chunks=None):
    """Open a raster lazily and clip it to a square window anchored at its top-left corner.

    Parameters
    ----------
    uri : str
        Path or URI of the raster (e.g. an ``s3://`` GeoTIFF), passed to ``rioxarray.open_rasterio``.
    name : str
        Name given to the returned DataArray (``default_name``).
    template : xarray.DataArray, optional
        If the raster cannot be opened (``RasterioIOError``), ``xr.zeros_like(template)`` is
        returned instead. If ``None``, the error is re-raised.
    tile_size : number, default 10
        Edge length of the clipping window in the raster's CRS units (degrees for the inputs in
        this notebook), measured from the raster's top-left corner. Values above 10 (the
        standard tile size) are capped at 10.
    chunks : optional
        Dask chunking passed to ``open_rasterio``. Defaults to the notebook-level ``chunk_length``.

    Returns
    -------
    xarray.DataArray
        The lazily loaded, clipped raster (still carrying its ``band`` dimension), or the zeros
        fallback built from ``template``.

    Raises
    ------
    ValueError
        If ``tile_size`` is not positive.
    rasterio.errors.RasterioIOError
        If the raster cannot be opened and no ``template`` was given.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size!r}")
    # If the input tile_size is too large, it reverts to 10 (standard tile size)
    tile_size = min(tile_size, 10)
    if chunks is None:
        # Resolved at call time so edits to chunk_length in the config cell are picked up.
        chunks = chunk_length
    try:
        raster = rioxarray.open_rasterio(uri, chunks=chunks, default_name=name)
    except rasterio.errors.RasterioIOError:
        if template is None:
            raise
        return xr.zeros_like(template)
    # bounds() is (minx, miny, maxx, maxy); only the top-left corner (minx, maxy) anchors the window.
    minx, _, _, maxy = raster.rio.bounds()
    return raster.rio.clip_box(minx=minx, miny=maxy - tile_size, maxx=minx + tile_size, maxy=maxy)

<font size="4">Model steps</font>

In [36]:
# Input file locations

# Every input below is for the same 10x10 degree tile; tile_id keeps them in sync.
tile_id = '50N_010E'

# Using 10x10 degree rasters of actual data
forest_height_previous_uri = f'{general_uri}202307_revision/test_10x10_deg/{tile_id}_FH_2020.tif'
forest_height_current_uri = f'{general_uri}202307_revision/test_10x10_deg/{tile_id}_FH_2021.tif'
forest_loss_detection_uri = f'{general_uri}202307_revision/test_10x10_deg/{tile_id}_DFL_2021.tif'

# forest_height_previous_uri = f'{general_uri}202307_revision/FH_2020.tif'
# forest_height_current_uri = f'{general_uri}202307_revision/FH_2021.tif'
# forest_loss_detection_uri = f'{general_uri}202307_revision/DFL_2021.tif'

driver_uri = f"s3://gfw2-data/climate/carbon_model/other_emissions_inputs/tree_cover_loss_drivers/processed/drivers_2022/20230407/{tile_id}_tree_cover_loss_driver_processed.tif"
planted_forest_type_uri = f"s3://gfw2-data/climate/carbon_model/other_emissions_inputs/planted_forest_type/SDPT_v1/standard/20200730/{tile_id}_plantation_type_oilpalm_woodfiber_other_unmasked.tif"
peat_uri = f"s3://gfw2-data/climate/carbon_model/other_emissions_inputs/peatlands/processed/20230315/{tile_id}_peat_mask_processed.tif"
tclf_uri = f"s3://gfw2-data/climate/carbon_model/other_emissions_inputs/tree_cover_loss_fires/20230315/processed/{tile_id}_tree_cover_loss_fire_processed.tif"

In [43]:
# Reads input files

tile_size = 1      # Tile size in degrees is from the top left of the tile. 10 is a full tile. Anything smaller is a subset of that.


def _open_band(uri, layer_name):
    """Open one input raster clipped to the current tile_size and drop its `band` dimension."""
    clipped = get_tile_dataset(uri, name=layer_name, tile_size=tile_size)
    return clipped.squeeze("band")


# Forest height and loss-detection layers
forest_height_previous = _open_band(forest_height_previous_uri, "forest_height_previous")
forest_height_current = _open_band(forest_height_current_uri, "forest_height_current")
forest_loss_detection = _open_band(forest_loss_detection_uri, "forest_loss_detection")

# Ancillary carbon-model layers
driver = _open_band(driver_uri, "driver")
planted_forest_type = _open_band(planted_forest_type_uri, "planted_forest_type")
peat = _open_band(peat_uri, "peat")
tclf = _open_band(tclf_uri, "tclf")

In [77]:
def classify_maintained(self, height_threshold=5):
    """Set pixels to 1 where forest height is at or above the threshold in both years.

    Parameters
    ----------
    self : xarray.DataArray or array-like
        Base forest-state array. Despite the name, this is a plain function argument, not an
        instance. Pixels that are not "maintained" keep their value from this array.
    height_threshold : number, default 5
        Minimum height in both forest_height_previous and forest_height_current for a pixel to
        count as maintained forest (presumably metres; confirm against the height product).

    Returns
    -------
    numpy.ndarray
        The computed result: 1 where both notebook-level height rasters are >= height_threshold,
        and the value from ``self`` elsewhere.
    """
    maintained_mask = np.logical_and(forest_height_previous >= height_threshold,
                                     forest_height_current >= height_threshold)
    # Build from the argument (not the global forest_state) and return the result. xarray
    # DataArrays do not accept new attributes, so it cannot be stored on the input.
    return dask.array.where(maintained_mask, 1, self).compute()

In [None]:
def classify_maintained(forest_state_in=None, height_threshold=5):
    """Set pixels to 1 where forest height is at or above the threshold in both years.

    Parameters
    ----------
    forest_state_in : xarray.DataArray or array-like, optional
        Base forest-state array. Pixels that are not "maintained" keep their value from it.
        Defaults to the notebook-level ``forest_state`` so zero-argument calls still work.
    height_threshold : number, default 5
        Minimum height in both forest_height_previous and forest_height_current for a pixel to
        count as maintained forest (presumably metres; confirm against the height product).

    Returns
    -------
    numpy.ndarray
        The computed result: 1 where both notebook-level height rasters are >= height_threshold,
        and the value from ``forest_state_in`` elsewhere.
    """
    if forest_state_in is None:
        forest_state_in = forest_state
    maintained_mask = np.logical_and(forest_height_previous >= height_threshold,
                                     forest_height_current >= height_threshold)
    return dask.array.where(maintained_mask, 1, forest_state_in).compute()

In [None]:
%%time

# xarray dataarray of 0s that has the properties of forest_height_current
forest_state = xr.zeros_like(forest_height_current)
print(forest_state)

# Applies forest state rules to the forest_state array of 0s
# https://stackoverflow.com/questions/60720294/2-dimensional-boolean-indexing-in-dask
forest_state = classify_maintained(forest_state)
# print(forest_state_array)
# forest_state_array = dask.array.where(
#     np.logical_or(np.logical_and(forest_height_previous >= 5, forest_height_current < 5), forest_loss_detection == 1), 2, forest_state_array).compute()
# forest_state_array = dask.array.where(
#     np.logical_and(forest_height_previous < 5, forest_height_current >= 5), 3, forest_state_array).compute()

In [94]:
# Boolean masks for each forest-state class, based on a minimum forest height.
# NOTE(review): the classes overlap. "lost" also includes every pixel with forest_loss_detection == 1,
# which can also satisfy "maintained" or "gained", so the order the masks are applied in decides the final class.
forest_height_threshold = 5
maintained = np.logical_and(forest_height_previous >= forest_height_threshold,
                            forest_height_current >= forest_height_threshold)
lost = np.logical_or(np.logical_and(forest_height_previous >= forest_height_threshold,
                                    forest_height_current < forest_height_threshold),
                     forest_loss_detection == 1)
gained = np.logical_and(forest_height_previous < forest_height_threshold,
                        forest_height_current >= forest_height_threshold)

In [100]:
%%time

# xarray dataarray of 0s that has the properties of forest_height_current
forest_state = xr.zeros_like(forest_height_current)

# Applies forest state rules to the forest_state array of 0s
# https://stackoverflow.com/questions/60720294/2-dimensional-boolean-indexing-in-dask
print("Calculating maintained pixels")
forest_state_array = dask.array.where(np.logical_and(maintained, maintained), 1, forest_state).compute()
print("Calculating lost pixels")
forest_state_array = dask.array.where(np.logical_and(lost, lost), 2, forest_state_array).compute()
print("Calculating gained pixels")
forest_state_array = dask.array.where(gained, 3, forest_state_array).compute()

Calculating maintained pixels
Calculating lost pixels
Calculating gained pixels
CPU times: total: 2.03 s
Wall time: 46.6 s


In [101]:
# Wrap the computed forest-state values back into a labelled xarray DataArray,
# reusing the x/y coordinates of the forest_state template they were built from.
state_coords = {'x': forest_state['x'], 'y': forest_state['y']}
state_da = xr.DataArray(forest_state_array, coords=state_coords, dims=('y', 'x'))

In [102]:
# Exports dataarray to raster.
# write_crs returns a new DataArray with the CRS stored as a spatial_ref coordinate, so the result
# is kept. rio.set_crs only records the CRS on the cached .rio accessor, and rioxarray recommends write_crs instead.
# NOTE(review): assumes the inputs are in EPSG:4326 (tile sizes are in degrees); confirm against the source rasters.
state_da = state_da.rio.write_crs("EPSG:4326")
state_da.rio.to_raster(f'{local_out_dir}forest_states_2021__{timestr}_{tile_size}_deg.tif', compress='DEFLATE', dtype='uint8')

In [38]:
%%time

# xarray dataarray of 0s that has the properties of forest_height_current
forest_state = xr.zeros_like(forest_height_current)

# Applies forest state rules to the forest_state array of 0s
# https://stackoverflow.com/questions/60720294/2-dimensional-boolean-indexing-in-dask
print("Calculating maintained pixels")
forest_state_array = dask.array.where(
    np.logical_and(forest_height_previous >= 5, forest_height_current >= 5), 1, forest_state).compute()
print("Calculating lost pixels")
forest_state_array = dask.array.where(
    np.logical_or(np.logical_and(forest_height_previous >= 5, forest_height_current < 5), forest_loss_detection == 1), 2, forest_state_array).compute()
print("Calculating gained pixels")
forest_state_array = dask.array.where(
    np.logical_and(forest_height_previous < 5, forest_height_current >= 5), 3, forest_state_array).compute()

Calculating maintained pixels
Calculating lost pxiels
Calculating gained pixels
CPU times: total: 8.75 s
Wall time: 3min
