In [None]:
import os
os.environ['USE_PYGEOS'] = '0'   # Suppresses some warning about geopandas
import geopandas as gpd

# scipy basics
import numpy as np
import botocore
from osgeo import gdal      # Necessary to do this import to get rasterio to import
import rasterio as rio
import rasterio.features

import time

# dask/parallelization libraries
import coiled
import dask
import dask.array as dar
from dask.distributed import Client, LocalCluster
import rioxarray
import xarray as xr
import xrspatial.local

<font size="6">Making cloud and local clusters</font> 

In [None]:
# Requests a six-worker Coiled cluster for the cloud run.
# The cluster is released by the coiled_cluster.shutdown() cell further down.
coiled_cluster = coiled.Cluster(
    n_workers=6,
    use_best_zone=True, 
    compute_purchase_option="spot_with_fallback",   # NOTE(review): presumably spot instances with on-demand fallback -- confirm in Coiled docs
    idle_timeout="20 minutes",
    # name="DGibbs Europe height flux model", 
    account='jterry64'   # Necessary to use the AWS environment that Justin set up in Coiled
)

In [None]:
# Coiled cluster (cloud run): gets a client from the coiled_cluster created above
coiled_client = coiled_cluster.get_client()
coiled_client   # Last expression in the cell, so the notebook displays it

In [None]:
# # Local cluster (local run). Doesn't work: the .compute() method kills workers for unknown reasons. Can't use for now.
# # local_cluster = LocalCluster(silence_logs=False)
# local_cluster = LocalCluster()
# local_client = Client(local_cluster)

In [None]:
# Local single-process cluster (local run). Will run .compute() on just one process, not a whole cluster.
# Used instead of the LocalCluster setup in the commented-out cell above, which is noted there as not working.
local_client = Client(processes=False)
local_client   # Last expression in the cell, so the notebook displays it

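<font size="6">Shutting down cloud and local clusters</font> 

Run the cells in this section only after the analysis below has finished: they appear before the analysis, so executing the notebook top to bottom stops the clusters before any model step runs.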

In [None]:
# Shuts down the Coiled (cloud) cluster created in the cluster-setup section
coiled_cluster.shutdown()

In [None]:
# Calls shutdown() on the local single-process client created in the cluster-setup section
local_client.shutdown()

<font size="6">Analysis</font> 

<font size="4">Paths and functions</font>

In [None]:
# General paths and constants

# Root S3 prefix for the GLAD Europe height data; the height input URIs are built from it
general_uri = 's3://gfw2-data/forest_change/GLAD_Europe_height_data/'

# Dummy random data. Built from general_uri so the two prefixes cannot drift apart.
random_data_uri = f'{general_uri}dummy_random_data__20230901/'

# Local output folder. Keep the trailing separator: output file names are appended to it directly.
local_out_dir = 'C:\\GIS\\Carbon_model_Europe\\outputs\\'

# Run date as YYYYMMDD, used in output file names
timestr = time.strftime("%Y%m%d")

tile_size = 10      # Tile size in degrees, measured from the top left of the tile. 10 is a full tile. Anything smaller is a subset of that.

In [None]:
# Reads tile
# From https://notebooks-staging.wri.org/user/dagibbs22/lab/tree/msims/biodiversity_global_stats.ipynb
# Bounding box use comes from https://github.com/corteva/rioxarray/issues/115#issuecomment-1206673437 and https://corteva.github.io/rioxarray/html/examples/clip_box.html. 

# Tile size is from the top left of the tile
def get_tile_dataset(uri, name, template=None, tile_size=10):
    """Open a raster with rioxarray and clip it to a square anchored at its top-left corner.

    Args:
        uri: Location of the raster, passed to rioxarray.open_rasterio.
        name: Passed to open_rasterio as default_name.
        template: Optional array. If the raster cannot be opened, zeros shaped
            like this array are returned instead of raising.
        tile_size: Side length of the clip window, in the raster's coordinate
            units. Values above 10 are treated as 10 (the standard tile size).

    Returns:
        The clipped raster, or xr.zeros_like(template) when opening fails and
        a template was supplied.

    Raises:
        rasterio.errors.RasterioIOError: If the raster cannot be read and no
            template was supplied.
    """
    # Requests larger than a standard tile are capped at 10
    tile_size = min(tile_size, 10)

    try:
        raster = rioxarray.open_rasterio(uri, chunks=4000, default_name=name)
        # bounds() is indexed as (minx, miny, maxx, maxy); only the top-left corner is needed
        left, _bottom, _right, top = raster.rio.bounds()
        return raster.rio.clip_box(
            minx=left,
            miny=top - tile_size,
            maxx=left + tile_size,
            maxy=top,
        )
    except rasterio.errors.RasterioIOError as e:
        if template is None:
            raise e
        return xr.zeros_like(template)

In [None]:
# Make a class that encapsulates all the pixel data passed into the decision tree for classification. 
# This makes it much easier to read how the decision tree uses the pixel data. 
# Also a place to do any modifications to the input data or add additional fields based off of the input data.
# The WRI engineers I've worked with have advised this for its flexibility. 

class ForestStateDecisionFactors:
    """Holds the input-layer values for one pixel so the decision trees can read them by name."""

    def __init__(
        self,
        forest_height_previous,
        forest_height_current,
        forest_loss_detection,
        driver,
        planted_forest_type,
        peat,
        fire_recent,
    ):
        # Forest height and loss layers
        self.forest_height_previous = forest_height_previous
        self.forest_height_current = forest_height_current
        self.forest_loss_detection = forest_loss_detection

        # Contextual layers
        self.driver = driver
        self.planted_forest_type = planted_forest_type
        self.peat = peat
        self.fire_recent = fire_recent

In [None]:
# Decision tree for assigning forest classes (sample). 
# Takes the above decision factors class, which includes all the data in input layers at one pixel, and classifies it using the decision tree logic.

class ForestStateDecisionTree:
    """Sample decision tree that assigns a forest class from previous and current forest height."""

    def classify(self, forestStateDecisionFactors):
        """Return 1 (maintained), 2 (loss), 3 (gain) or 0 (no forest), using a height threshold of 5."""
        previous = forestStateDecisionFactors.forest_height_previous
        current = forestStateDecisionFactors.forest_height_current

        if previous >= 5:
            if current >= 5:
                return 1    # maintained
            if current < 5:
                return 2    # loss
        # Both height comparisons are written out (rather than using "not") so a
        # value that is neither >= 5 nor < 5 falls through to "no forest"
        if previous < 5 and current >= 5:
            return 3        # gain
        return 0            # no forest

In [None]:
# Logic for whether loss is in peat
class LossInPeat():
    """Decision node that labels a pixel by its peat value.

    The no_peat and peat arguments are stored on the instance; classify does
    not consult them.
    """

    def __init__(self, no_peat, peat):
        self.no_peat, self.peat = no_peat, peat

    def classify(self, forest_state_decision_factors):
        """Return 'NoPeat' when the pixel's peat value equals 0, otherwise 'Peat'."""
        return 'NoPeat' if forest_state_decision_factors.peat == 0 else 'Peat'


# Module-level instance with no child nodes attached
loss_in_peat = LossInPeat(no_peat=None, peat=None)

In [None]:
# Logic for which forest change occurred
class ForestChange():
    """Decision node that classifies the change in forest height between two years.

    The maintained, lost and gained arguments are stored on the instance;
    classify does not consult them.
    """

    def __init__(self, maintained, lost, gained):
        self.maintained, self.lost, self.gained = maintained, lost, gained

    def classify(self, forest_state_decision_factors):
        """Return 1 (maintained), 2 (loss), 3 (gain) or 0 (no forest), using a height threshold of 5."""
        previous = forest_state_decision_factors.forest_height_previous
        current = forest_state_decision_factors.forest_height_current

        if previous >= 5:
            if current >= 5:
                return 1    # maintained
            if current < 5:
                return 2    # loss
        # Both comparisons are spelled out so a value that is neither >= 5 nor
        # < 5 falls through to "no forest"
        if previous < 5 and current >= 5:
            return 3        # gain
        return 0            # no forest


# Module-level instance with no child nodes attached
forest_change = ForestChange(maintained=None, lost=None, gained=None)

In [None]:
class RecentFire():
    """Decision node that branches on whether the pixel has a recent-fire value.

    Args:
        no_fire_recent: Child node consulted when the pixel's fire_recent value equals 0.
        fire_recent: Child node consulted otherwise.

    Both children must provide a classify(factors) method that returns a string.
    """

    def __init__(self, no_fire_recent, fire_recent):
        self.no_fire_recent = no_fire_recent
        self.fire_recent = fire_recent

    def classify(self, forest_state_decision_factors):
        """Return 'NoFire:<child label>' or 'Fire:<child label>' for the pixel."""
        if forest_state_decision_factors.fire_recent == 0:
            # Delegate to the no-fire child; using the fire child here would
            # leave no_fire_recent unused and send both branches down the same subtree
            return 'NoFire:' + self.no_fire_recent.classify(forest_state_decision_factors)
        return 'Fire:' + self.fire_recent.classify(forest_state_decision_factors)

# Module-level fire tree: each branch holds its own LossInPeat node with no children
recent_fire = RecentFire(
    no_fire_recent=LossInPeat(no_peat=None, peat=None),
    fire_recent=LossInPeat(no_peat=None, peat=None),
)

In [None]:
class ForestStateDecisionTree_v2:
    """Second version of the decision tree, assembled from the RecentFire and LossInPeat nodes."""

    def classify_v2(self, forestStateDecisionFactors):
        """Classify one pixel with the fire -> peat tree.

        Args:
            forestStateDecisionFactors: Object exposing the fire_recent and
                peat values of the pixel.

        Returns:
            The string produced by RecentFire.classify for this pixel.
        """
        recent_fire = RecentFire(
            no_fire_recent=LossInPeat(no_peat=None, peat=None),
            fire_recent=LossInPeat(no_peat=None, peat=None),
        )
        # Evaluate the tree and hand the result back; without this return the
        # method builds the tree and yields None
        return recent_fire.classify(forestStateDecisionFactors)

<font size="4">Model steps</font>

In [None]:
# Input file locations

# Tile to process. Every input path below is built from this ID, so switching tiles is a one-line change.
tile_id = '50N_010E'

# Using 10x10 degree rasters of actual data
height_data_uri = f'{general_uri}202307_revision/test_10x10_deg/'
forest_height_previous_uri = f'{height_data_uri}{tile_id}_FH_2020.tif'
forest_height_current_uri = f'{height_data_uri}{tile_id}_FH_2021.tif'
forest_loss_detection_uri = f'{height_data_uri}{tile_id}_DFL_2021.tif'

# Contextual layers, all stored under the same S3 prefix
emissions_inputs_uri = 's3://gfw2-data/climate/carbon_model/other_emissions_inputs/'
driver_uri = f'{emissions_inputs_uri}tree_cover_loss_drivers/processed/drivers_2022/20230407/{tile_id}_tree_cover_loss_driver_processed.tif'
planted_forest_type_uri = f'{emissions_inputs_uri}planted_forest_type/SDPT_v1/standard/20200730/{tile_id}_plantation_type_oilpalm_woodfiber_other_unmasked.tif'
peat_uri = f'{emissions_inputs_uri}peatlands/processed/20230315/{tile_id}_peat_mask_processed.tif'
fire_recent_uri = f'{emissions_inputs_uri}tree_cover_loss_fires/20230315/processed/{tile_id}_tree_cover_loss_fire_processed.tif'

In [None]:
# Reads input files

def _read_input_tile(uri, name):
    """Read one input layer with the notebook-wide tile_size."""
    return get_tile_dataset(uri, name=name, tile_size=tile_size)


# Forest height and loss layers
forest_height_previous = _read_input_tile(forest_height_previous_uri, "forest_height_previous")
forest_height_current = _read_input_tile(forest_height_current_uri, "forest_height_current")
forest_loss_detection = _read_input_tile(forest_loss_detection_uri, "forest_loss_detection")

# Contextual layers
driver = _read_input_tile(driver_uri, "driver")
planted_forest_type = _read_input_tile(planted_forest_type_uri, "planted_forest_type")
peat = _read_input_tile(peat_uri, "peat")
fire_recent = _read_input_tile(fire_recent_uri, "fire_recent")

In [None]:
# Combines all inputs into one xarray dataset

# Keep this order identical to the ForestStateDecisionFactors.__init__ parameters:
# classify() unpacks each row positionally with ForestStateDecisionFactors(*row)
input_layers = {
    "forest_height_previous": forest_height_previous,
    "forest_height_current": forest_height_current,
    "forest_loss_detection": forest_loss_detection,
    "driver": driver,
    "planted_forest_type": planted_forest_type,
    "peat": peat,
    "fire_recent": fire_recent,
}
decision_tree_ds = xr.Dataset(data_vars=input_layers)

decision_tree_ds

In [None]:
%%time

"""
Some code that applies the decision tree to decision_tree_ds to make an xarray of forest_states for the previous and current years
"""

def classify(row):
    """Classify one pixel from a row of input-layer values.

    Args:
        row: The input-layer values for one pixel, in the positional order of
            the ForestStateDecisionFactors.__init__ parameters. Must expose .size.

    Returns:
        0 for an empty row; otherwise the class returned by
        ForestStateDecisionTree.classify.
    """
    if row.size == 0:
        return 0

    # No per-row print here: this function runs once per pixel, so printing floods the output.

    # Alternative trees, kept for experimentation:
    # return loss_in_peat.classify(ForestStateDecisionFactors(*row))
    # return forest_change.classify(ForestStateDecisionFactors(*row))
    # return recent_fire.classify(ForestStateDecisionFactors(*row))
    # return ForestStateDecisionTree_v2().classify_v2(ForestStateDecisionFactors(*row))
    return ForestStateDecisionTree().classify(ForestStateDecisionFactors(*row))           # Justin's original invocation


def map_blocks(block):
    """Classify every pixel in one block of the input dataset.

    The block is flattened to a DataFrame indexed by (band, x, y), the
    spatial_ref column is dropped, classify is applied to each row, and the
    result is converted back to xarray.
    """
    pixel_rows = block.to_dataframe(dim_order=["band", "x", "y"]).drop(columns="spatial_ref")
    return pixel_rows.apply(classify, axis=1).to_xarray()

# Set up the block-wise classification first, then trigger the computation
forest_states_blockwise = decision_tree_ds.map_blocks(map_blocks)
forest_states = forest_states_blockwise.compute()
forest_states

In [None]:
# Reorders the dimensions to (band, y, x). map_blocks builds its DataFrame with
# dim_order ["band", "x", "y"], which is presumably why x and y come back swapped.
forest_states_corrected = forest_states.transpose('band', 'y', 'x')   # Had to add because exporting to raster showed error about columns being transposed
# NOTE(review): the return value of set_crs is not assigned, so this line relies on the CRS
# being set on forest_states_corrected itself. rioxarray also has rio.write_crs -- confirm which is intended.
forest_states_corrected.rio.set_crs("EPSG:4326")   # Had to add to get coordinate system for output raster

In [None]:
# Exports forest state array to raster

# Output name carries the run date and tile size so separate runs do not overwrite each other
forest_states_out_path = f'{local_out_dir}forest_states_2021__{timestr}_{tile_size}_deg.tif'
forest_states_corrected.rio.to_raster(forest_states_out_path, compress='DEFLATE', dtype='uint8')