In [1]:
import os
os.environ['USE_PYGEOS'] = '0'   # Suppresses some warning about geopandas
import geopandas as gpd

# scipy basics
import numpy as np
import botocore
from osgeo import gdal      # Necessary to do this import to get rasterio to import
import rasterio as rio
import rasterio.features

from typing import NamedTuple

import time

from functools import reduce
from itertools import chain

# dask/parallelization libraries
import coiled
import dask
import dask.array as dar
from dask.distributed import Client, LocalCluster
import rioxarray
import xarray as xr
import xrspatial.local

<font size="6">Making cloud and local clusters</font> 

In [2]:
# Cloud cluster on Coiled; workers are reached through coiled_client in the next cell
coiled_cluster = coiled.Cluster(
    account='jterry64',   # Necessary to use the AWS environment that Justin set up in Coiled
    n_workers=5,
    use_best_zone=True,
    compute_purchase_option="spot_with_fallback",
    idle_timeout="20 minutes",
    # name="DGibbs Europe height flux model",
)

Output()

Output()

In [3]:
# Coiled cluster (cloud run)
# Dask client bound to the Coiled cluster created above; computations submitted from here run in the cloud.
coiled_client = coiled_cluster.get_client()
# Bare last expression: displays the client summary (dashboard link, workers, memory) as cell output.
coiled_client

0,1
Connection method: Cluster object,Cluster type: coiled.Cluster
Dashboard: https://cluster-lpzhi.dask.host/3Mx1gj4iAicKq-EF/status,

0,1
Dashboard: https://cluster-lpzhi.dask.host/3Mx1gj4iAicKq-EF/status,Workers: 2
Total threads: 8,Total memory: 29.68 GiB

0,1
Comm: tls://10.1.23.88:8786,Workers: 2
Dashboard: http://10.1.23.88:8787/status,Total threads: 8
Started: Just now,Total memory: 29.68 GiB

0,1
Comm: tls://10.1.27.153:41977,Total threads: 4
Dashboard: http://10.1.27.153:8787/status,Memory: 14.84 GiB
Nanny: tls://10.1.27.153:44839,
Local directory: /scratch/dask-scratch-space/worker-h73a5m9m,Local directory: /scratch/dask-scratch-space/worker-h73a5m9m

0,1
Comm: tls://10.1.26.5:36555,Total threads: 4
Dashboard: http://10.1.26.5:8787/status,Memory: 14.84 GiB
Nanny: tls://10.1.26.5:36183,
Local directory: /scratch/dask-scratch-space/worker-2ad7xrts,Local directory: /scratch/dask-scratch-space/worker-2ad7xrts


In [None]:
# Local cluster (local run). Doesn't work: .compute() kills workers for unknown reasons, so it can't be used for now.
# local_cluster = LocalCluster(silence_logs=False)
local_cluster = LocalCluster()
# Client's first parameter is `address`; passing the cluster object connects to it directly.
local_client = Client(address=local_cluster)

In [None]:
local_client  # Displays the local Dask client's summary as the cell output

In [None]:
# # Local single-process cluster (local run). Will run .compute() on just one process, not a whole cluster.
# local_client = Client(processes=False)
# local_client

<font size="6">Shutting down cloud and local clusters</font> 

In [None]:
coiled_cluster.shutdown()  # Shuts down the Coiled cloud cluster created in the cluster-setup cell

In [None]:
local_client.shutdown()  # Shuts down the local cluster's scheduler and workers through its client

<font size="6">Analysis</font> 

<font size="4">Paths and functions</font>

In [4]:
# General paths and constants

general_uri = 's3://gfw2-data/forest_change/GLAD_Europe_height_data/'

# Synthetic (random) versions of the inputs, used for testing the workflow
random_data_uri = 's3://gfw2-data/forest_change/GLAD_Europe_height_data/dummy_random_data__20230901/'

# Local folder for exported rasters. Set CARBON_MODEL_OUT_DIR to run on another machine
# instead of editing this hard-coded Windows path.
local_out_dir = os.environ.get('CARBON_MODEL_OUT_DIR', 'C:\\GIS\\Carbon_model_Europe\\outputs\\')
# Output paths are built as f'{local_out_dir}<name>.tif', so the folder must end with a separator.
if not local_out_dir.endswith(('/', '\\')):
    local_out_dir += os.sep

timestr = time.strftime("%Y%m%d")   # Date stamp (YYYYMMDD) appended to output filenames

height_cutoff = 5    # meters; pixels at or above this height count as forest

chunk_length = 1000   # Dimensions in pixels of dask chunks (width and height)

tile_size = 1      # Size in degrees of the processed window, measured from the top-left corner of the tile. 10 is a full tile; anything smaller is a subset of it.

current_year = 2021

In [5]:
# Reads tile
# From https://notebooks-staging.wri.org/user/dagibbs22/lab/tree/msims/biodiversity_global_stats.ipynb
# Bounding box use comes from https://github.com/corteva/rioxarray/issues/115#issuecomment-1206673437 and https://corteva.github.io/rioxarray/html/examples/clip_box.html. 

# Tile size is from the top left of the tile
def get_tile_dataset(uri, name, template=None, tile_size=10):
    """Open a raster lazily and clip it to a square window anchored at its top-left corner.

    Args:
        uri: path or URI readable by rioxarray (e.g. an s3:// GeoTIFF).
        name: name given to the DataArray when the file does not provide one.
        template: optional DataArray; if opening the file raises RasterioIOError,
            zeros shaped like the template are returned instead.
        tile_size: width/height of the window in raster coordinate units (degrees for
            the EPSG:4326 tiles used here). Values above 10 are capped at 10, the
            standard tile size.

    Returns:
        DataArray chunked into chunk_length-pixel dask chunks (module-level constant),
        with the "band" dimension squeezed away (assumes single-band rasters).

    Raises:
        rasterio.errors.RasterioIOError: if reading fails and no template was given.
    """
    # If the input tile_size is too large, it reverts to 10 (standard tile size)
    tile_size = min(tile_size, 10)
    try:
        raster = rioxarray.open_rasterio(uri, chunks=chunk_length, default_name=name)
        # bounds() is (minx, miny, maxx, maxy); only the top-left corner anchors the window.
        minx, _, _, maxy = raster.rio.bounds()
        return raster.rio.clip_box(
            minx=minx, miny=maxy - tile_size, maxx=minx + tile_size, maxy=maxy
        ).squeeze("band")
    except rasterio.errors.RasterioIOError:
        if template is not None:
            return xr.zeros_like(template)
        # Bare raise re-raises the original error with its traceback intact.
        raise

<font size="4">Reading in data</font>

In [22]:
# Input file locations
# Active inputs below are 50N_010E 10x10-degree tiles; get_tile_dataset later clips each one to tile_size.

# Full-extent height/loss rasters (currently disabled in favor of the 10x10 test tiles)
# forest_height_previous_uri = f'{general_uri}202307_revision/FH_2020.tif'
# forest_height_current_uri = f'{general_uri}202307_revision/FH_2021.tif'
# forest_loss_detection_uri = f'{general_uri}202307_revision/DFL_2021.tif'

# Annual forest height rasters 2016-2021 (used to find the final year each pixel was forest)
forest_height_2016_uri = f'{general_uri}202307_revision/test_10x10_deg/50N_010E_FH_2016.tif'
forest_height_2017_uri = f'{general_uri}202307_revision/test_10x10_deg/50N_010E_FH_2017.tif'
forest_height_2018_uri = f'{general_uri}202307_revision/test_10x10_deg/50N_010E_FH_2018.tif'
forest_height_2019_uri = f'{general_uri}202307_revision/test_10x10_deg/50N_010E_FH_2019.tif'
forest_height_2020_uri = f'{general_uri}202307_revision/test_10x10_deg/50N_010E_FH_2020.tif'
forest_height_2021_uri = f'{general_uri}202307_revision/test_10x10_deg/50N_010E_FH_2021.tif'

# Using 10x10 degree rasters of actual data; the years follow current_year
forest_height_previous_uri = f'{general_uri}202307_revision/test_10x10_deg/50N_010E_FH_{current_year-1}.tif'
forest_height_current_uri = f'{general_uri}202307_revision/test_10x10_deg/50N_010E_FH_{current_year}.tif'
forest_loss_detection_uri = f'{general_uri}202307_revision/test_10x10_deg/50N_010E_DFL_{current_year}.tif'

# Tree cover loss drivers, planted forest type (SDPT) and peat masks
driver_uri = "s3://gfw2-data/climate/carbon_model/other_emissions_inputs/tree_cover_loss_drivers/processed/drivers_2022/20230407/50N_010E_tree_cover_loss_driver_processed.tif"
planted_forest_type_uri = "s3://gfw2-data/climate/carbon_model/other_emissions_inputs/planted_forest_type/SDPT_v1/standard/20200730/50N_010E_plantation_type_oilpalm_woodfiber_other_unmasked.tif"
peat_uri = "s3://gfw2-data/climate/carbon_model/other_emissions_inputs/peatlands/processed/20230315/50N_010E_peat_mask_processed.tif"

# Composite land cover snapshots at five-year intervals (2015 and 2020 are read as LC_previous / LC_next)
landcover_composite_2000_uri = "s3://gfw2-data/landcover/composite/2000/50N_010E_composite_landcover_2000.tif"
landcover_composite_2005_uri = "s3://gfw2-data/landcover/composite/2005/50N_010E_composite_landcover_2005.tif"
landcover_composite_2010_uri = "s3://gfw2-data/landcover/composite/2010/50N_010E_composite_landcover_2010.tif"
landcover_composite_2015_uri = "s3://gfw2-data/landcover/composite/2015/50N_010E_composite_landcover_2015.tif"
landcover_composite_2020_uri = "s3://gfw2-data/landcover/composite/2020/50N_010E_composite_landcover_2020.tif"

# Global cropland (NE) rasters; only the 2019 layer is read in the next cell
cropland_NE_2003_uri = "s3://gfw2-data/landcover/cropland/2003/raw/Global_cropland_NE_2003.tif"
cropland_NE_2007_uri = "s3://gfw2-data/landcover/cropland/2007/raw/Global_cropland_NE_2007.tif"
cropland_NE_2011_uri = "s3://gfw2-data/landcover/cropland/2011/raw/Global_cropland_NE_2011.tif"
cropland_NE_2015_uri = "s3://gfw2-data/landcover/cropland/2015/raw/Global_cropland_NE_2015.tif"
cropland_NE_2019_uri = "s3://gfw2-data/landcover/cropland/2019/raw/Global_cropland_NE_2019.tif"

# Burned area for the two years preceding current_year
ba_two_before_uri = f's3://gfw2-data/climate/carbon_model/other_emissions_inputs/burn_year/burn_year_10x10_clip/ba_{current_year-2}_50N_010E.tif'
ba_one_before_uri = f's3://gfw2-data/climate/carbon_model/other_emissions_inputs/burn_year/burn_year_10x10_clip/ba_{current_year-1}_50N_010E.tif'

# Aboveground biomass for 2000 (presumably WHRC v4, per the path)
agb_2000_uri = "s3://gfw2-data/climate/WHRC_biomass/WHRC_V4/Processed/50N_010E_t_aboveground_biomass_ha_2000.tif"

# # Using random data (synthetic stand-ins for the inputs above)
# forest_height_previous_uri = f'{random_data_uri}50N_010E_FH_2020_random_data.tif'
# forest_height_current_uri = f'{random_data_uri}50N_010E_FH_2021_random_data.tif'
# forest_loss_detection_uri = f'{random_data_uri}50N_010E_DFL_2021_random_data.tif'

# driver_uri = f'{random_data_uri}50N_010E_tree_cover_loss_driver_processed_random_data.tif'
# planted_forest_type_uri = f'{random_data_uri}50N_010E_plantation_type_oilpalm_woodfiber_other_unmasked_random_data.tif'
# peat_uri = f'{random_data_uri}50N_010E_peat_mask_processed_random_data.tif'

# ba_2017_uri = f'{random_data_uri}50N_010E_burned_area_2017_random_data.tif'
# ba_2018_uri = f'{random_data_uri}50N_010E_burned_area_2018_random_data.tif'
# ba_2019_uri = f'{random_data_uri}50N_010E_burned_area_2019_random_data.tif'
# ba_2020_uri = f'{random_data_uri}50N_010E_burned_area_2020_random_data.tif'
# ba_2021_uri = f'{random_data_uri}50N_010E_burned_area_2021_random_data.tif'

# Sanity check of the year-dependent paths
print(forest_height_previous_uri)
print(ba_two_before_uri)
print(ba_one_before_uri)

s3://gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/test_10x10_deg/50N_010E_FH_2020.tif
s3://gfw2-data/climate/carbon_model/other_emissions_inputs/burn_year/burn_year_10x10_clip/ba_2019_50N_010E.tif
s3://gfw2-data/climate/carbon_model/other_emissions_inputs/burn_year/burn_year_10x10_clip/ba_2020_50N_010E.tif


In [23]:
# Reads input files, each clipped to the tile_size window by get_tile_dataset

forest_height_2016 = get_tile_dataset(forest_height_2016_uri, name="forest_height_2016", tile_size=tile_size)
forest_height_2017 = get_tile_dataset(forest_height_2017_uri, name="forest_height_2017", tile_size=tile_size)
forest_height_2018 = get_tile_dataset(forest_height_2018_uri, name="forest_height_2018", tile_size=tile_size)
forest_height_2019 = get_tile_dataset(forest_height_2019_uri, name="forest_height_2019", tile_size=tile_size)
forest_height_2020 = get_tile_dataset(forest_height_2020_uri, name="forest_height_2020", tile_size=tile_size)
forest_height_2021 = get_tile_dataset(forest_height_2021_uri, name="forest_height_2021", tile_size=tile_size)

forest_height_previous = get_tile_dataset(forest_height_previous_uri, name="forest_height_previous", tile_size=tile_size)
forest_height_current = get_tile_dataset(forest_height_current_uri, name="forest_height_current", tile_size=tile_size)
forest_loss_detection = get_tile_dataset(forest_loss_detection_uri, name="forest_loss_detection", tile_size=tile_size)

driver = get_tile_dataset(driver_uri, name="driver", tile_size=tile_size)
planted_forest = get_tile_dataset(planted_forest_type_uri, name="planted_forest", tile_size=tile_size)
peat = get_tile_dataset(peat_uri, name="peat", tile_size=tile_size)

LC_previous = get_tile_dataset(landcover_composite_2015_uri, name="LC_previous", tile_size=tile_size)
LC_next = get_tile_dataset(landcover_composite_2020_uri, name="LC_next", tile_size=tile_size)

# NOTE(review): the classification-definitions cell reassigns cropland_previous to an
# LC-based mask (LC_previous == 244), discarding this raster — confirm which is intended.
cropland_previous = get_tile_dataset(cropland_NE_2019_uri, name="cropland_previous", tile_size=tile_size)

ba_two_before = get_tile_dataset(ba_two_before_uri, name="ba_two_before", tile_size=tile_size)
# Fixed: this previously re-read ba_two_before_uri, so the "one year before" layer
# duplicated the "two years before" layer and recent fire ignored current_year-1.
ba_one_before = get_tile_dataset(ba_one_before_uri, name="ba_one_before", tile_size=tile_size)

agb_2000 = get_tile_dataset(agb_2000_uri, name="agb_2000", tile_size=tile_size)

<font size="4">Recent burned area</font>

In [8]:
%%time

# Maps burned area in the two preceding years: 1 where either year's burn raster is nonzero, else 0

# Zero-filled template carrying the shape, dtype and x/y coordinates of forest_height_current
burned_area_recent_blank = xr.zeros_like(forest_height_current)

print("Mapping burned area in two preceding years")
burned_area_recent_array = dask.array.where(
    (ba_two_before != 0) | (ba_one_before != 0),
    1,
    burned_area_recent_blank,
).compute()

# The computed result is a plain array; wrap it back onto the template's x/y grid
burned_area_recent = xr.DataArray(
    burned_area_recent_array,
    dims=('y', 'x'),
    coords={'x': burned_area_recent_blank['x'], 'y': burned_area_recent_blank['y']},
)

Mapping burned area in two preceding years
CPU times: total: 1.05 s
Wall time: 8.12 s


In [None]:
# Exports the recent-burn mask as a DEFLATE-compressed uint8 GeoTIFF in EPSG:4326
burned_area_recent.rio.set_crs("EPSG:4326")
burned_area_recent.rio.to_raster(
    f'{local_out_dir}burned_area_recent_{current_year}__{timestr}_{tile_size}_deg.tif',
    dtype='uint8',
    compress='DEFLATE',
)

<font size="4">Final year of forest</font>

In [9]:
# Collects the six annual forest-height layers into one Dataset; each key is the layer's variable name
forest_height_ds = xr.Dataset(dict(
    forest_height_2016=forest_height_2016,
    forest_height_2017=forest_height_2017,
    forest_height_2018=forest_height_2018,
    forest_height_2019=forest_height_2019,
    forest_height_2020=forest_height_2020,
    forest_height_2021=forest_height_2021,
))

# forest_height_ds

In [10]:
%%time

# Maps forest presence (>=5 m) for each year (each dataarray)

def map_forests(data_array, dataset=None):
    """Map forest presence for one year's height layer and export it as a GeoTIFF.

    Pixels with height >= height_cutoff (module-level, meters) are set to the year taken
    from the last four characters of data_array.name; all other pixels are 0.

    Args:
        data_array: height DataArray whose name ends in a 4-digit year
            (e.g. "forest_height_2016"), with x/y coordinates.
        dataset: unused. Kept optional so existing calls such as
            ``Dataset.map(map_forests, dataset=...)`` keep working.

    Returns:
        In-memory DataArray (dims y, x) of presence years. The same data is also written
        to local_out_dir as a DEFLATE-compressed uint16 GeoTIFF.
    """
    year = data_array.name[-4:]
    print(f"Mapping forest presence for {year}")

    # uint16 zeros so year values (2016+) fit regardless of the input height dtype
    # (NOTE(review): heights may be stored as uint8 — where() would otherwise keep that dtype
    # under NumPy 2 promotion and overflow). Matches the uint16 export below.
    forest_presence_zeros = xr.zeros_like(data_array, dtype='uint16')

    # Pixels with height >= height_cutoff get the year; the rest keep 0
    forest_presence = dask.array.where((data_array >= height_cutoff), int(year), forest_presence_zeros).compute()

    # Converts numpy array to xarray dataarray
    forest_presence_da = xr.DataArray(forest_presence, dims=('y', 'x'), coords={'x': data_array['x'], 'y': data_array['y']})

    # Exports dataarray to raster
    forest_presence_da.rio.set_crs("EPSG:4326")
    forest_presence_da.rio.to_raster(f'{local_out_dir}forest_presence_{year}__{timestr}_{tile_size}_deg.tif', compress='DEFLATE', dtype='uint16')

    return forest_presence_da

# Runs map_forests on every height layer; Dataset.map keeps each variable's key.
# The dataset is forwarded positionally as map_forests' second argument.
forest_presence_ds = forest_height_ds.map(map_forests, args=(forest_height_ds,))

Mapping forest presence for 2016
Mapping forest presence for 2017
Mapping forest presence for 2018
Mapping forest presence for 2019
Mapping forest presence for 2020
Mapping forest presence for 2021
CPU times: total: 3.06 s
Wall time: 23.1 s


In [11]:
# Renames the forest presence dataarrays (forest_height_YYYY -> forest_presence_YYYY) in one pass.
# A single rename keeps the cell idempotent and, unlike the previous for-loop, does not leak
# `variable`/`year` into the notebook namespace (the commented-out rules in classify() use a
# bare `year`, which would otherwise silently pick up the leaked value).
forest_presence_ds = forest_presence_ds.rename(
    {name: f'forest_presence_{name[-4:]}' for name in forest_presence_ds.data_vars}
)
# forest_presence_ds

In [12]:
# Per-pixel maximum across all forest_presence_YYYY layers, i.e. the last year the pixel was forest

# https://stackoverflow.com/questions/65149355/is-there-a-faster-way-to-sum-xarray-dataset-variables
# Stacking the whole dataset selects every variable directly; the previous `vars = ...` list
# shadowed Python's builtin vars() for the rest of the session.
final_forest_year = forest_presence_ds.to_array().max("variable").compute()
# final_forest_year

In [None]:
# Exports the final-forest-year layer as a DEFLATE-compressed uint16 GeoTIFF in EPSG:4326
final_forest_year.rio.set_crs("EPSG:4326")
final_forest_year.rio.to_raster(
    f'{local_out_dir}final_forest_year__{timestr}_{tile_size}_deg.tif',
    dtype='uint16',
    compress='DEFLATE',
)

<font size="4">Annual classification</font>

In [25]:
# Gathers every classifier input into one Dataset; the keys are the names the tree nodes look up
forest_data = xr.Dataset(dict(
    forest_height_previous=forest_height_previous,
    forest_height_current=forest_height_current,
    forest_loss_detection=forest_loss_detection,
    driver=driver,
    planted_forest=planted_forest,
    peat=peat,
    LC_previous=LC_previous,
    LC_next=LC_next,
    burned_area_recent=burned_area_recent,
    final_forest_year=final_forest_year,
    agb_2000=agb_2000,
))

# forest_data

In [None]:
# Classification definitions: boolean masks (not used by the active tree-based classify(),
# which builds its own predicates; referenced by the commented-out legacy rules)

# Forest state between the previous and current year, relative to height_cutoff (m)
maintained = (forest_height_previous >= height_cutoff) & (forest_height_current >= height_cutoff)
gained = (forest_height_previous < height_cutoff) & (forest_height_current >= height_cutoff)
lost = ((forest_height_previous >= height_cutoff) & (forest_height_current < height_cutoff)) | (forest_loss_detection == 1)
no_forest = (forest_height_previous < height_cutoff) & (forest_height_current < height_cutoff)

# Land cover code ranges in the next composite: 2-26 / 102-126 grassland, 27-48 / 127-148 forest
grassland_next = ((LC_next >= 2) & (LC_next <= 26)) | ((LC_next >= 102) & (LC_next <= 126))
forest_next = ((LC_next >= 27) & (LC_next <= 48)) | ((LC_next >= 127) & (LC_next <= 148))
other_next = ~grassland_next & ~forest_next

# NOTE(review): this replaces the cropland_previous raster read from Global_cropland_NE_2019
# with an LC-based mask — confirm which source is intended.
cropland_previous = LC_previous == 244
cropland_next = LC_next == 244

builtup_previous = LC_previous == 250
builtup_next = LC_next == 250

grassland_forest_previous = ((LC_previous >= 2) & (LC_previous <= 48)) | ((LC_previous >= 102) & (LC_previous <= 148))
grassland_forest_next = ((LC_next >= 2) & (LC_next <= 48)) | ((LC_next >= 102) & (LC_next <= 148))

forestry = driver == 3

# Forestry-driven loss on grassland/forest land that is not cropland before or after
non_sdpt_forestry = (
    forestry
    & (grassland_forest_previous | grassland_forest_next)
    & ~cropland_previous
    & ~cropland_next
)

In [42]:
class Context():
    """Read-only, human-readable labels attached to each branch of the classification tree."""

    # Top-level forest state
    FOREST_MAINTAINED = property(lambda self: 'FOREST MAINTAINED')
    FOREST_GAIN = property(lambda self: 'FOREST GAIN')
    FOREST_LOSS = property(lambda self: 'FOREST LOSS')

    # Planted-forest split
    NON_PLANTED = property(lambda self: 'NON PLANTED')
    PLANTED = property(lambda self: 'PLANTED')

    # Recent-fire split
    NO_RECENT_FIRE = property(lambda self: 'NO RECENT FIRE')
    RECENT_FIRE = property(lambda self: 'RECENT FIRE')

    # Peat split
    NO_PEAT = property(lambda self: 'NO PEAT')
    PEAT = property(lambda self: 'PEAT')
    

def flatten(t):
    """Return a list of the leaves of `t`, descending recursively into nested tuples only."""
    leaves = []
    for element in t:
        if isinstance(element, tuple):
            leaves.extend(flatten(element))
        else:
            # Non-tuple items (including lists) are kept as-is
            leaves.append(element)
    return leaves
  
ctx = Context()  # Shared label provider read by the classifier nodes when building Classification contexts


class CarbonBudgetClassifier():
  """Root of the classification tree: forest maintained / gained / lost.

  Each argument is a child node (e.g. PlantedForestNode or Leaf). The child's
  classifications are appended under this level's digit: '1' maintained, '2' gain,
  '3' loss. A child left as None is skipped, so that top-level state produces no
  classifications (previously it raised AttributeError).
  """
  def __init__(self, forest_maintained=None, forest_gain=None, forest_loss=None):
    self.forest_maintained = forest_maintained
    self.forest_gain = forest_gain
    self.forest_loss = forest_loss

  def classify(self, forest_data):
    """Return a flat tuple of Classifications, one per leaf reachable from this root.

    forest_data must provide 'forest_height_previous', 'forest_height_current' and
    'forest_loss_detection', plus whatever the child nodes read. The module-level
    height_cutoff (meters) is the forest threshold.
    """
    if (not self.forest_maintained and not self.forest_gain and not self.forest_loss):
      return ()
    previous = forest_data['forest_height_previous']
    current = forest_data['forest_height_current']
    # (child, context label, digit, predicate factory); predicates are built per
    # classification, as before, so each Classification gets its own expression.
    branches = (
      (self.forest_maintained, ctx.FOREST_MAINTAINED, '1',
       lambda: (previous >= height_cutoff) & (current >= height_cutoff)),
      (self.forest_gain, ctx.FOREST_GAIN, '2',
       lambda: (previous < height_cutoff) & (current >= height_cutoff)),
      # Loss: a drop below the cutoff, or an explicit loss-detection flag.
      (self.forest_loss, ctx.FOREST_LOSS, '3',
       lambda: ((previous >= height_cutoff) & (current < height_cutoff))
               | (forest_data['forest_loss_detection'] == 1)),
    )
    return tuple(flatten(tuple(
      tuple(Classification(predicate=make_predicate(), context=context, val=val) + x
            for x in child.classify(forest_data))
      for child, context, val, make_predicate in branches
      if child
    )))


class PlantedForestNode():
  """Splits pixels by planted-forest presence (forest_data['planted_forest']).

  Digits: '1' = non-planted (== 0), '2' = planted (> 0). A child left as None is
  skipped instead of raising AttributeError.
  """
  def __init__(self, non_planted=None, planted=None):
    self.non_planted = non_planted
    self.planted = planted

  def classify(self, forest_data=None):
    """Return a flat tuple of Classifications for every leaf below this node."""
    if (not self.non_planted and not self.planted):
      return ()
    branches = (
      (self.non_planted, ctx.NON_PLANTED, '1', lambda: forest_data['planted_forest'] == 0),
      (self.planted, ctx.PLANTED, '2', lambda: forest_data['planted_forest'] > 0),
    )
    return tuple(flatten(tuple(
      tuple(Classification(predicate=make_predicate(), context=context, val=val) + x
            for x in child.classify(forest_data))
      for child, context, val, make_predicate in branches
      if child
    )))

class RecentFireNode():
  """Splits pixels by recent burned area (forest_data['burned_area_recent']).

  Digits: '1' = no recent fire (== 0), '2' = recent fire (> 0). A child left as None
  is skipped instead of raising AttributeError.
  """
  def __init__(self, no_recent_fire=None, recent_fire=None):
    self.no_recent_fire = no_recent_fire
    self.recent_fire = recent_fire

  def classify(self, forest_data=None):
    """Return a flat tuple of Classifications for every leaf below this node."""
    if (not self.no_recent_fire and not self.recent_fire):
      return ()
    branches = (
      (self.no_recent_fire, ctx.NO_RECENT_FIRE, '1', lambda: forest_data['burned_area_recent'] == 0),
      (self.recent_fire, ctx.RECENT_FIRE, '2', lambda: forest_data['burned_area_recent'] > 0),
    )
    return tuple(flatten(tuple(
      tuple(Classification(predicate=make_predicate(), context=context, val=val) + x
            for x in child.classify(forest_data))
      for child, context, val, make_predicate in branches
      if child
    )))

class PeatNode():
  """Splits pixels by peat presence (forest_data['peat']).

  Digits: '1' = no peat (== 0), '2' = peat (> 0). A child left as None is skipped
  instead of raising AttributeError.
  """
  def __init__(self, no_peat=None, peat=None):
    self.no_peat = no_peat
    self.peat = peat

  def classify(self, forest_data=None):
    """Return a flat tuple of Classifications for every leaf below this node."""
    # Fixed: the guard previously tested `no_peat` twice and never looked at `peat`,
    # so a node with only a peat child returned nothing.
    if (not self.no_peat and not self.peat):
      return ()
    branches = (
      (self.no_peat, ctx.NO_PEAT, '1', lambda: forest_data['peat'] == 0),
      (self.peat, ctx.PEAT, '2', lambda: forest_data['peat'] > 0),
    )
    return tuple(flatten(tuple(
      tuple(Classification(predicate=make_predicate(), context=context, val=val) + x
            for x in child.classify(forest_data))
      for child, context, val, make_predicate in branches
      if child
    )))


class Leaf():
  """Terminal node of the classification tree.

  classify() yields a single None; Classification.__add__ returns the parent unchanged
  when added to None, so the path ends at the parent's Classification.
  """
  def classify(self, forest_data=None):
    # forest_data is accepted only for interface parity with the other nodes.
    end_of_path = (None,)
    return end_of_path


class Classification():
  """One path through the classification tree: a predicate, its context labels, and its digit code.

  Classifications compose with `+`: predicates are AND-ed, contexts are paired into a
  tuple, and digit strings are concatenated (e.g. '3' + '1' -> '31'), so a full path
  yields one Classification whose `val` is the complete state code.
  """
  def __init__(self, predicate=1, context=(), val=''):
    # predicate: anything supporting `&` (bools, numpy/xarray boolean arrays).
    self.predicate = predicate
    # context: label (or tuple of labels) describing the branch.
    self.context = context
    # val: digit string for this branch.
    self.val = val

  def __add__(self, other):
    """Combine with a child Classification; None (from a Leaf) returns self unchanged."""
    if (other is None):
      return self
    if (isinstance(other, Classification)):
      return Classification(
        predicate=(self.predicate & other.predicate),
        context=(self.context,) + (other.context,),
        val=self.val + other.val
      )
    # Previously fell through and silently returned None; let Python raise TypeError.
    return NotImplemented

  @staticmethod
  def _predicates_equal(left, right):
    """Compare predicates that may be scalars or arrays without an ambiguous truth value."""
    if left is right:
      return True
    equals = getattr(left, 'equals', None)
    if callable(equals):
      # xarray/pandas objects: coordinate-aware, whole-object comparison.
      return bool(equals(right))
    # Scalars and numpy arrays; `==` on arrays would give an array, not a bool.
    return bool(np.array_equal(left, right))

  def __eq__(self, other):
    if isinstance(other, Classification):
      return (self._predicates_equal(self.predicate, other.predicate) and
              self.context == other.context and
              self.val == other.val)
    return False

  def __str__(self):
    return f"""
    Classification:
    predicate: {self.predicate}
    context: {self.context}
    val: {self.val}
    """

  def __repr__(self):
    return f"Classification(predicate={self.predicate}, context={self.context}, val={self.val})"

In [32]:
def apply_where(forest_state_array, data):
    """Write one Classification's numeric state code wherever its predicate holds.

    Used as the reducer in classify(): reduce(apply_where, classifications, zeros), so
    later classifications overwrite earlier ones where predicates overlap.

    Args:
        forest_state_array: current state array (xarray, dask or numpy).
        data: a Classification; data.predicate is the boolean mask and data.val its
            digit string (e.g. '3122').

    Returns:
        Lazy dask array with the code of data.val where data.predicate is True and
        forest_state_array elsewhere.
    """
    # Classification.val is a digit *string*. Passing it straight to where() mixes a str
    # with a numeric array, which makes NumPy promote the result to a unicode dtype
    # (very memory-hungry and only turned back into numbers at export). Use an explicit
    # unsigned integer instead; the state raster is exported as uint32.
    state_code = np.uint32(int(data.val))
    return dask.array.where(data.predicate, state_code, forest_state_array)

    

In [46]:
# Applies forest state rules to an initial array of 0s
def classify(forest_data, forest_state_zeros=None):
    """Build the carbon-budget classification tree and reduce it to a lazy state array.

    Args:
        forest_data: xarray Dataset with the inputs read by the tree nodes
            (forest_height_previous, forest_height_current, forest_loss_detection,
            planted_forest, burned_area_recent, peat).
        forest_state_zeros: starting array; pixels matching no classification keep its
            value. Defaults to zeros shaped like forest_data['forest_height_current'].
            Previously this was read from a notebook global created only in a later
            cell (hidden state); pass that array explicitly to reuse it.

    Returns:
        The result of reducing all classifications with apply_where (a lazy dask array),
        or forest_state_zeros unchanged if the tree yields no classifications.
        Call dask.compute on it.
    """
    if forest_state_zeros is None:
        forest_state_zeros = xr.zeros_like(forest_data['forest_height_current'])

    print("Building classification tree")
    classifier = CarbonBudgetClassifier(
        forest_maintained=Leaf(),
        forest_gain=Leaf(),
        forest_loss=PlantedForestNode(
          non_planted=RecentFireNode(
              no_recent_fire=PeatNode(
                  no_peat=Leaf(),
                  peat=Leaf()),
              recent_fire=PeatNode(
                  no_peat=Leaf(),
                  peat=Leaf()
              )
          ),
          planted=RecentFireNode(
              no_recent_fire=PeatNode(
                  no_peat=Leaf(),
                  peat=Leaf()),
              recent_fire=PeatNode(
                  no_peat=Leaf(),
                  peat=Leaf()
              )
          )
        )
    )

    print("Creating classifier")
    result = classifier.classify(forest_data)

    print("Reducing")
    # Sequential where(): later classifications overwrite earlier ones wherever predicates overlap.
    forest_state_array = reduce(apply_where, result, forest_state_zeros)

    # Legacy rule-by-rule implementation, kept as the specification for branches not yet
    # ported to the tree (land use change, non-SDPT forestry, no-forest).
    # NOTE(review): these rules use `year`, which is not defined in this function, and the
    # notebook-level masks (lost, forest_next, ...) — define/pass them before re-enabling.

    # print("Maintained branch")
    # forest_state_array = dask.array.where(
    #     maintained, 1, forest_state_zeros)     #maintained

    # print("Gain branch")
    # forest_state_array = dask.array.where(gained, 
    #     2, forest_state_array)    #gained

    # print("Loss branch")

    # ###Land use change (haven't built out peat and fire branches yet)
    # #Loss:no SDPT:no non-SDPT forestry:no later forest:cropland next
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year < int(year)) & (forest_next == 0))
    #     & cropland_next, 31111, forest_state_array)
    # #Loss:no SDPT:no non-SDPT forestry:no later forest:settlement next
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year < int(year)) & (forest_next == 0))
    #     & builtup_next, 31112, forest_state_array)
    # #Loss:no SDPT:no non-SDPT forestry:no later forest:grassland next
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year < int(year)) & (forest_next == 0))
    #     & grassland_next, 31113, forest_state_array)
    # #Loss:no SDPT:no non-SDPT forestry:no later forest: other land cover next
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year < int(year)) & (forest_next == 0))
    #     & ((cropland_next == 0) & (builtup_next == 0) & (grassland_next == 0)), 31114, forest_state_array)

    # ### Land cover change
    # #Loss:no SDPT:no non-SDPT forestry:later forest:forest next:no recent fire:no peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year >= int(year)) | forest_next) 
    #     & forest_next & (burned_area_recent == 0) & (peat == 0), 3112111, forest_state_array)
    # #Loss:no SDPT:no non-SDPT forestry:later forest:forest next:no recent fire:peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year >= int(year)) | forest_next)
    #     & forest_next & (burned_area_recent == 0) & (peat == 1), 3112112, forest_state_array)
    # #Loss:no SDPT:no non-SDPT forestry:later forest:forest next:recent fire:no peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year >= int(year)) | forest_next)
    #     & forest_next & (burned_area_recent == 1) & (peat == 0), 3112121, forest_state_array)
    # #Loss:no SDPT:no non-SDPT forestry:later forest:forest next:recent fire:peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year >= int(year)) | forest_next)
    #     & forest_next & (burned_area_recent == 1) & (peat == 1), 3112122, forest_state_array)

    # #Loss:no SDPT:no non-SDPT forestry:later forest:grassland next:no recent fire:no peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year >= int(year)) | forest_next)
    #     & grassland_next & (burned_area_recent == 0) & (peat == 0), 3112211, forest_state_array)   
    # #Loss:no SDPT:no non-SDPT forestry:later forest:grassland next:no recent fire:peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year >= int(year)) | forest_next)
    #     & grassland_next & (burned_area_recent == 0) & (peat == 1), 3112212, forest_state_array)
    # #Loss:no SDPT:no non-SDPT forestry:later forest:grassland next:recent fire:no peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year >= int(year)) | forest_next)
    #     & grassland_next & (burned_area_recent == 1) & (peat == 0), 3112221, forest_state_array)   
    # #Loss:no SDPT:no non-SDPT forestry:later forest:grassland next:recent fire:peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year >= int(year)) | forest_next) 
    #     & grassland_next & (burned_area_recent == 1) & (peat == 1), 3112222, forest_state_array)   
    
    
    # #Loss:no SDPT:no non-SDPT forestry:later forest:other LC next:no recent fire:no peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year >= int(year)) | forest_next)
    #     & other_next & (burned_area_recent == 0) & (peat == 0), 3112311, forest_state_array)
    # #Loss:no SDPT:no non-SDPT forestry:later forest:other LC next:no recent fire:peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year >= int(year)) | forest_next) 
    #     & other_next & (burned_area_recent == 0) & (peat == 1), 3112312, forest_state_array)   
    # #Loss:no SDPT:no non-SDPT forestry:later forest:other LC next:recent fire:no peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year >= int(year)) | forest_next)
    #     & other_next & (burned_area_recent == 1) & (peat == 0), 3112321, forest_state_array)   
    # #Loss:no SDPT:no non-SDPT forestry:later forest:other LC next:no recent fire:no peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & (non_sdpt_forestry == 0) & ((final_forest_year >= int(year)) | forest_next)
    #     & other_next & (burned_area_recent == 1) & (peat == 1), 3112322, forest_state_array)   

    # ### Forestry
    # # Loss:no SDPT:non-SDPT forestry:no recent fire:no peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & non_sdpt_forestry & (burned_area_recent == 0) & (peat == 0), 31211, forest_state_array)   
    # #Loss:no SDPT:non-SDPT forestry:no recent fire:peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & non_sdpt_forestry & (burned_area_recent == 0) & (peat == 1), 31212, forest_state_array)   
    # #Loss:no SDPT:non-SDPT forestry:recent fire:no peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & non_sdpt_forestry & (burned_area_recent == 1) & (peat == 0), 31221, forest_state_array)   
    # #Loss:no SDPT:non-SDPT forestry:recent fire:peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest == 0) & non_sdpt_forestry & (burned_area_recent == 1) & (peat == 1), 31222, forest_state_array)   
    
    # ### Forestry
    # #Loss:SDPT:no recent fire:no peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest > 0) & (burned_area_recent == 0) & (peat == 0), 3211, forest_state_array)   
    # #Loss:SDPT:no recent fire:peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest > 0) & (burned_area_recent == 0) & (peat == 1), 3212, forest_state_array)   
    # #Loss:SDPT:recent fire:no peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest > 0) & (burned_area_recent == 1) & (peat == 0), 3221, forest_state_array)  
    # #Loss:SDPT:recent fire:peat
    # forest_state_array = dask.array.where(
    #     lost & (planted_forest > 0) & (burned_area_recent == 1) & (peat == 1), 3222, forest_state_array)  

    # print("No forest")
    # forest_state_array = dask.array.where(
    #     no_forest, 4, forest_state_array)     #non-forest remaining non-forest 
    

    print("At return statement")
    return forest_state_array

In [47]:
%%time

print("Assigning forest states")

# Zero-filled array with the shape, dtype and x/y coordinates of forest_height_current
forest_state_zeros = xr.zeros_like(forest_height_current)

# A single dask.compute() call for the whole classification graph
# per https://docs.dask.org/en/stable/best-practices.html#avoid-calling-compute-repeatedly
# dask.compute returns a tuple, so the computed state array is forest_states[0].
forest_states = dask.compute(classify(forest_data))
# forest_states[0]

Assigning forest states
Building classification tree
Creating classifier
Reducing
At return statement
CPU times: total: 3.95 s
Wall time: 2min 5s


In [44]:
%%time

# Wraps the computed state array onto the x/y grid of forest_state_zeros
state_da = xr.DataArray(
    forest_states[0],
    dims=('y', 'x'),
    coords={'x': forest_state_zeros['x'], 'y': forest_state_zeros['y']},
)

# Exports the forest states as a DEFLATE-compressed uint32 GeoTIFF in EPSG:4326
state_da.rio.set_crs("EPSG:4326")
state_da.rio.to_raster(
    f'{local_out_dir}forest_states_{current_year}__{timestr}_{tile_size}_deg.tif',
    dtype='uint32',
    compress='DEFLATE',
)

CPU times: total: 5.5 s
Wall time: 5.54 s


In [None]:
# Sums the annual forest heights per pixel across the six stacked layers
# https://stackoverflow.com/questions/65149355/is-there-a-faster-way-to-sum-xarray-dataset-variables
vars_to_sum = [
    "forest_height_2016",
    "forest_height_2017",
    "forest_height_2018",
    "forest_height_2019",
    "forest_height_2020",
    "forest_height_2021",
]

summed_variables = (
    forest_height_ds[vars_to_sum]
    .to_array()
    .sum("variable")
    .compute()
)

# NOTE(review): unlike the other exports there is no rio.set_crs call here; this relies on CRS
# metadata surviving the stack/sum — confirm the output raster is georeferenced.
summed_variables.rio.to_raster(
    f'{local_out_dir}summed_heights_{current_year}__{timestr}_{tile_size}_deg.tif',
    dtype='uint16',
    compress='DEFLATE',
)

In [None]:
def classify_simple(dfs, forest_state_zeros=None):
    """Three-way forest-state classification without the classification tree.

    Codes: 1 = maintained, 2 = lost (drop below height_cutoff or loss detection),
    3 = gained, 0 elsewhere. Rules are applied in that order, so a later rule
    overwrites an earlier one where both match.
    NOTE(review): these codes differ from CarbonBudgetClassifier (gain '2', loss '3') —
    confirm which numbering is intended.

    Args:
        dfs: Dataset with forest_height_previous, forest_height_current and
            forest_loss_detection.
        forest_state_zeros: starting array; defaults to zeros shaped like
            dfs.forest_height_current (previously read from a notebook global).

    Returns:
        Lazy dask array of state codes.
    """
    if forest_state_zeros is None:
        forest_state_zeros = xr.zeros_like(dfs.forest_height_current)
    print("Assigning forest states")
    previous = dfs.forest_height_previous
    current = dfs.forest_height_current
    # height_cutoff replaces the hard-coded 5 m so this stays consistent with the tree.
    forest_state_array = dask.array.where(
        (previous >= height_cutoff) & (current >= height_cutoff),
        1, forest_state_zeros)
    forest_state_array = dask.array.where(
        ((previous >= height_cutoff) & (current < height_cutoff))
        | (dfs.forest_loss_detection == 1),
        2, forest_state_array)
    forest_state_array = dask.array.where(
        (previous < height_cutoff) & (current >= height_cutoff),
        3, forest_state_array)

    return forest_state_array