In [1]:
import os
os.environ['USE_PYGEOS'] = '0'   # Suppresses some warning about geopandas
import geopandas as gpd

# scipy basics
import numpy as np
import botocore
from osgeo import gdal      # Necessary to do this import to get rasterio to import
import rasterio as rio
import rasterio.features

import time

# dask/parallelization libraries
import coiled
import dask
import dask.array as dar
from dask.distributed import Client, LocalCluster
import rioxarray
import xarray as xr
import xrspatial.local

<font size="6">Making cloud and local clusters</font> 

In [None]:
# Cloud cluster on Coiled: 10 workers; compute_purchase_option and idle_timeout control the
# instance purchase type and the automatic shutdown after inactivity.
coiled_cluster = coiled.Cluster(
    account='jterry64',   # Necessary to use the AWS environment that Justin set up in Coiled
    n_workers=10,
    compute_purchase_option="spot_with_fallback",
    use_best_zone=True,
    idle_timeout="20 minutes",
    # worker_vm_types=["t3.medium"],
    # name="DGibbs Europe height flux model",
)

In [None]:
# Coiled cluster (cloud run)
# Client connected to the Coiled cluster created above. The bare name on the last line makes the
# notebook display the client summary.
coiled_client = coiled_cluster.get_client()
coiled_client

In [None]:
# # Local cluster (local run). Doesn't work-- .compute() method kill workers for unknown reasons. Can't use for now.
# # local_cluster = LocalCluster(silence_logs=False)
# local_cluster = LocalCluster()
# local_client = Client(local_cluster)

In [2]:
# Local single-process cluster (local run). Will run .compute() on just one process, not a whole cluster.
# processes=False keeps the scheduler and worker threads inside this Python process
# (the output below reports 1 worker, 8 threads and "Using processes: False").
local_client = Client(processes=False)
local_client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://192.168.0.144:8787/status,

0,1
Dashboard: http://192.168.0.144:8787/status,Workers: 1
Total threads: 8,Total memory: 31.81 GiB
Status: running,Using processes: False

0,1
Comm: inproc://192.168.0.144/4932/1,Workers: 1
Dashboard: http://192.168.0.144:8787/status,Total threads: 8
Started: Just now,Total memory: 31.81 GiB

0,1
Comm: inproc://192.168.0.144/4932/4,Total threads: 8
Dashboard: http://192.168.0.144:60367/status,Memory: 31.81 GiB
Nanny: None,
Local directory: C:\Users\DAVID~1.GIB\AppData\Local\Temp\dask-scratch-space\worker-i1egw6dm,Local directory: C:\Users\DAVID~1.GIB\AppData\Local\Temp\dask-scratch-space\worker-i1egw6dm


<font size="6">Shutting down cloud and local clusters</font> 

In [None]:
# Shuts down the Coiled cloud cluster created in the "Making cloud and local clusters" section.
coiled_cluster.shutdown()

In [None]:
# Shuts down the scheduler and workers connected to the local client.
local_client.shutdown()

<font size="6">Analysis</font> 

<font size="4">General paths and functions</font>

In [3]:
# General paths and constants

# Date stamp (YYYYMMDD) embedded in the output file names
timestr = time.strftime("%Y%m%d")

# Root S3 prefix of the GLAD Europe forest height inputs
general_uri = 's3://gfw2-data/forest_change/GLAD_Europe_height_data/'

# Local output folder. Output paths are built as f'{local_out_dir}<file name>',
# so the value must keep its trailing backslash.
local_out_dir = r'C:\GIS\Carbon_model_Europe\outputs' + '\\'

In [None]:
# To make tiles with gdalwarp:
# gdalwarp from subprocess.check_call(cmd) isn't working, so the commands are kept here only as comments
# (presumably run manually from a shell instead — TODO confirm).
# The cmd list clips a 1x1 degree extent (-te 10 49 11 50); the gdalwarp lines after it clip a
# 10x10 degree extent (-te 10 40 20 50).
# NOTE(review): the output name '50N_010E_1FH_2018.tif' below does not follow the pattern of the other
# FH outputs — likely a typo.
# cmd = ['gdalwarp', '-tr', '0.00025', '0.00025', '-co', 'COMPRESS=DEFLATE', '-tap', '-te', str(10), str(49), str(11), str(50), '-dstnodata', '0', '-t_srs', 'EPSG:4326', 
#        '-overwrite', '-progress', '/vsis3/gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/FH_2021.tif', 'C:\\GIS\\Carbon_model_Europe\\outputs\\50N_010E_FH_2021.tif']
# check_call(cmd)
# gdalwarp -tr 0.00025 0.00025 -co COMPRESS=DEFLATE -tap -te 10 40 20 50 -dstnodata 0 -t_srs EPSG:4326 -overwrite /vsis3/gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/FH_2021.tif 50N_010E_FH_2021.tif
# gdalwarp -tr 0.00025 0.00025 -co COMPRESS=DEFLATE -tap -te 10 40 20 50 -dstnodata 0 -t_srs EPSG:4326 -overwrite /vsis3/gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/FH_2020.tif 50N_010E_FH_2020.tif
# gdalwarp -tr 0.00025 0.00025 -co COMPRESS=DEFLATE -tap -te 10 40 20 50 -dstnodata 0 -t_srs EPSG:4326 -overwrite /vsis3/gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/FH_2019.tif 50N_010E_FH_2019.tif
# gdalwarp -tr 0.00025 0.00025 -co COMPRESS=DEFLATE -tap -te 10 40 20 50 -dstnodata 0 -t_srs EPSG:4326 -overwrite /vsis3/gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/FH_2018.tif 50N_010E_1FH_2018.tif
# gdalwarp -tr 0.00025 0.00025 -co COMPRESS=DEFLATE -tap -te 10 40 20 50 -dstnodata 0 -t_srs EPSG:4326 -overwrite /vsis3/gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/DFL_2021.tif 50N_010E_DFL_2021.tif
# gdalwarp -tr 0.00025 0.00025 -co COMPRESS=DEFLATE -tap -te 10 40 20 50 -dstnodata 0 -t_srs EPSG:4326 -overwrite /vsis3/gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/DFL_2020.tif 50N_010E_DFL_2020.tif
# gdalwarp -tr 0.00025 0.00025 -co COMPRESS=DEFLATE -tap -te 10 40 20 50 -dstnodata 0 -t_srs EPSG:4326 -overwrite /vsis3/gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/DFL_2019.tif 50N_010E_DFL_2019.tif
# gdalwarp -tr 0.00025 0.00025 -co COMPRESS=DEFLATE -tap -te 10 40 20 50 -dstnodata 0 -t_srs EPSG:4326 -overwrite /vsis3/gfw2-data/forest_change/GLAD_Europe_height_data/202307_revision/DFL_2018.tif 50N_010E_DFL_2018.tif
# gdalwarp -tr 0.00025 0.00025 -co COMPRESS=DEFLATE -tap -te 10 40 20 50 -dstnodata 0 -t_srs EPSG:4326 -overwrite /vsis3/gfw2-data/climate/carbon_model/other_emissions_inputs/tree_cover_loss_drivers/processed/drivers_2022/20230407/50N_010E_tree_cover_loss_driver_processed.tif 50N_010E_1deg_tree_cover_loss_driver_processed.tif
# gdalwarp -tr 0.00025 0.00025 -co COMPRESS=DEFLATE -tap -te 10 40 20 50 -dstnodata 0 -t_srs EPSG:4326 -overwrite /vsis3/gfw2-data/climate/carbon_model/other_emissions_inputs/tree_cover_loss_fires/20230315/processed/50N_010E_tree_cover_loss_fire_processed.tif 50N_010E_1deg_tree_cover_loss_fire_processed.tif
# gdalwarp -tr 0.00025 0.00025 -co COMPRESS=DEFLATE -tap -te 10 40 20 50 -dstnodata 0 -t_srs EPSG:4326 -overwrite /vsis3/gfw2-data/climate/carbon_model/other_emissions_inputs/peatlands/processed/20230315/50N_010E_peat_mask_processed.tif 50N_010E_1deg_peat_mask_processed.tif


In [4]:
# Reads tile
# From https://notebooks-staging.wri.org/user/dagibbs22/lab/tree/msims/biodiversity_global_stats.ipynb
# Bounding box use comes from https://github.com/corteva/rioxarray/issues/115#issuecomment-1206673437 and https://corteva.github.io/rioxarray/html/examples/clip_box.html. 

def get_tile_dataset(uri, name, template=None, bbox=(10, 49, 11, 50), chunks=1000):
    """Open a raster lazily (dask-chunked) and clip it to a bounding box.

    Parameters
    ----------
    uri : str
        Path or URI that rioxarray/rasterio can open (e.g. an ``s3://`` GeoTIFF).
    name : str
        Passed to ``rioxarray.open_rasterio`` as ``default_name``.
    template : xarray.DataArray, optional
        If the raster cannot be opened (``RasterioIOError``), an all-zeros array shaped like
        ``template`` is returned instead. If None, the read error is re-raised.
    bbox : tuple (minx, miny, maxx, maxy) or None, optional
        Clip box passed to ``rio.clip_box``, in the raster's CRS (presumably EPSG:4326 degrees for the
        tiles used here — TODO confirm). Defaults to the 1x1 degree test tile that used to be
        hard-coded. Pass None to return the raster unclipped.
    chunks : optional
        Dask chunking passed to ``rioxarray.open_rasterio`` (default 1000, as before).

    Returns
    -------
    xarray.DataArray
        The (clipped) lazily loaded raster, or zeros shaped like ``template`` on a read error.
    """
    try:
        raster = rioxarray.open_rasterio(uri, chunks=chunks, default_name=name)
        if bbox is None:
            return raster
        minx, miny, maxx, maxy = bbox
        return raster.rio.clip_box(minx=minx, miny=miny, maxx=maxx, maxy=maxy)
    except rasterio.errors.RasterioIOError:
        if template is not None:
            return xr.zeros_like(template)
        # Bare raise keeps the original traceback intact.
        raise

In [33]:
# Number of values for each input layer in the order in which they will be stacked (so, must match order in consolidate_to_one_dimensional_index).
# +1 turns the max values into the number of options for that input layer

def index_dim_fx(forest_height_previous, forest_height_current, forest_loss_detection):
    """Return the number of possible values of each input layer, as a tuple of ints.

    Each entry is the layer's computed maximum plus one. The values are therefore assumed to be
    non-negative integers starting at 0. The tuple order (previous height, current height, loss
    detection) is the stacking order used as ``dims`` by ``np.ravel_multi_index`` /
    ``np.unravel_index`` elsewhere in this notebook. Each ``.compute()`` evaluates that layer's graph.
    """
    stacked_layers = (forest_height_previous, forest_height_current, forest_loss_detection)
    return tuple(
        layer.max().compute().values.item(0) + 1
        for layer in stacked_layers
    )

In [34]:
# Combines the values from all the inputs into a one-dimensional index of unique values for every combination of input layer pixel values.
# In other words, turns the multidimensional array into a 1-D array in which each combination of input values results in a unique index

index_dims = index_dim_fx(forest_height_previous, forest_height_current, forest_loss_detection)

def consolidate_to_one_dimensional_index(chunk, index_shape=None):
    """Encode each pixel's combination of input-layer values as a single integer index.

    Intended for ``Dataset.map_blocks``. ``chunk`` is one block of a dataset holding
    ``forest_height_previous``, ``forest_height_current`` and ``forest_loss_detection``.

    Parameters
    ----------
    chunk : xarray.Dataset
        One block containing the three variables above.
    index_shape : tuple of int, optional
        Number of possible values of each layer, in the stacking order above. Defaults to the
        module-level ``index_dims``. It must be identical for every block and for the decoding
        step, so that an index always decodes to the same layer values. To override it from
        ``map_blocks``, pass ``kwargs={"index_shape": ...}``.

    Returns
    -------
    xarray.DataArray
        Integer indices with the same dims and coords as ``forest_height_previous``.

    Raises
    ------
    ValueError
        From ``np.ravel_multi_index`` if any layer value is >= its entry in ``index_shape``.
    """
    if index_shape is None:
        index_shape = index_dims
    # Take dims/coords from a data variable, not the Dataset. ``Dataset.dims`` is a name->size mapping
    # (sorted by name in older xarray versions), so its order need not match the variable's axis order.
    # That would mislabel y/x, or fail outright for non-square blocks.
    reference = chunk.forest_height_previous
    flat_index = np.ravel_multi_index(
        [chunk.forest_height_previous.data,
         chunk.forest_height_current.data,
         chunk.forest_loss_detection.data],
        index_shape,
    )
    return xr.DataArray(flat_index, coords=reference.coords, dims=reference.dims)

In [32]:
# NOTE(review): duplicate of the "Uses the function above to map the dataset to 1D index" cell further down.
# It needs decision_tree_ds, which is only defined in a later cell, so it fails on a fresh Restart & Run All.
decision_tree_index = decision_tree_ds.map_blocks(consolidate_to_one_dimensional_index, template=forest_height_previous.copy())
decision_tree_index.compute()
# decision_tree_index.mean().compute()

In [41]:
# Makes a class that encapsulates all the pixel data passed into the decision tree for classification. 
# This makes it much easier to read how the decision tree uses the pixel data. 
# Also a place to do any modifications to the input data or add additional fields based off of the input data.

class ForestStateDecisionFactors:
    """Input-layer values for one pixel (or one decoded index), as read by the decision tree.

    Attributes
    ----------
    height_prev_year : forest height value in the previous year
    height_this_year : forest height value in the current year
    forest_loss_detection : forest-loss detection value
    """

    def __init__(self, height_prev_year, height_this_year, forest_loss_detection):
        (
            self.height_prev_year,
            self.height_this_year,
            self.forest_loss_detection,
        ) = (height_prev_year, height_this_year, forest_loss_detection)

In [42]:
# Decision tree for assigning forest classes. 
# Takes the above decision factors class, which includes all the data in input layers at one pixel, and classifies it using the decision tree logic.

class ForestStateDecisionTree:
    """Assigns a forest-state code to one pixel's input-layer values.

    Codes returned by ``classify``:
        0 -- no forest in either year
        1 -- forest maintained (at/above threshold in both years)
        2 -- forest loss (at/above threshold previous year, below it this year)
        3 -- forest gain (below threshold previous year, at/above it this year)
    """

    # Minimum height value (in the units of the forest height layers) for a pixel to count as forest.
    FOREST_HEIGHT_THRESHOLD = 5

    # define in a way similar to how we went over with Gary
    @classmethod
    def classify(cls, forestStateDecisionFactors):
        """Return the forest-state code for ``forestStateDecisionFactors``.

        Reads ``height_prev_year`` and ``height_this_year``. ``forest_loss_detection`` is not used
        by the tree yet. It is a classmethod, so it works whether called on the class (as the rest of
        the notebook does) or on an instance.
        """
        threshold = cls.FOREST_HEIGHT_THRESHOLD
        prev_height = forestStateDecisionFactors.height_prev_year
        this_height = forestStateDecisionFactors.height_this_year
        if prev_height >= threshold and this_height >= threshold:    # maintained
            return 1
        elif prev_height >= threshold and this_height < threshold:   # loss
            return 2
        elif prev_height < threshold and this_height >= threshold:   # gain
            return 3
        else:                                                        # no forest
            return 0

In [9]:
# Converts 1-D array of unique indexes back to your layer values and classifies them to a forest state

def classify_index(index, index_shape=None):
    """Decode one 1-D index back into input-layer values and classify that combination.

    Inverse of the encoding done by ``consolidate_to_one_dimensional_index``. ``index_shape`` must be
    the same tuple used for encoding; it defaults to the module-level ``index_dims``.

    Returns
    -------
    int
        The forest-state code from ``ForestStateDecisionTree.classify``.
    """
    if index_shape is None:
        index_shape = index_dims
    decisionFactors = ForestStateDecisionFactors(*np.unravel_index(index, index_shape))
    return ForestStateDecisionTree.classify(decisionFactors)

<font size="4">Running without for loop</font>

In [10]:
# Input file locations (1x1 degree test tile 50N_010E)
# Full-extent equivalents:
# forest_height_previous_uri = f'{general_uri}202307_revision/FH_2020.tif'
# forest_height_current_uri = f'{general_uri}202307_revision/FH_2021.tif'
# forest_loss_detection_uri = f'{general_uri}202307_revision/DFL_2021.tif'
# NOTE(review): the full-extent line uses DFL_2021 but the test tile below uses DFL_2020 — confirm which
# loss year belongs with the 2020 -> 2021 height comparison.

forest_height_previous_uri = general_uri + '202307_revision/test_1x1_deg/50N_010E_1deg_FH_2020.tif'
forest_height_current_uri = general_uri + '202307_revision/test_1x1_deg/50N_010E_1deg_FH_2021.tif'
forest_loss_detection_uri = general_uri + '202307_revision/test_1x1_deg/50N_010E_1deg_DFL_2020.tif'

# Auxiliary inputs (not yet used by the decision tree); they live under a different S3 prefix
driver_uri = "s3://gfw2-data/forest_change/GLAD_Europe_resubmission__202306/test_1x1_deg/50N_010E_1deg_tree_cover_loss_driver_processed.tif"
peat_uri = "s3://gfw2-data/forest_change/GLAD_Europe_resubmission__202306/test_1x1_deg/50N_010E_1deg_peat_mask_processed.tif"
tclf_uri = "s3://gfw2-data/forest_change/GLAD_Europe_resubmission__202306/test_1x1_deg/50N_010E_1deg_tree_cover_loss_fire_processed.tif"
# land_cover_uri = ...
# planted_forest_uri = ...

In [11]:
# Reads input files
# Each call returns a lazily loaded, chunked array clipped by get_tile_dataset. No template is passed,
# so a read error is raised rather than replaced with zeros.
forest_height_previous = get_tile_dataset(forest_height_previous_uri, name="forest_height_previous")
forest_height_current = get_tile_dataset(forest_height_current_uri, name="forest_height_current")
forest_loss_detection = get_tile_dataset(forest_loss_detection_uri, name="forest_loss_detection")

# driver = get_tile_dataset(driver_uri, name="driver")
# peat = get_tile_dataset(peat_uri, name="peat")
# TCLF = get_tile_dataset(tclf_uri, name="tclf")

In [12]:
%%time
# Some sample dask operations
# Sanity check: prints the mean of each input layer (each .compute() evaluates that layer's graph
# separately), then displays the lazy forest_height_previous array.

print(forest_height_previous.mean().compute().values.item(0))
print(forest_height_current.mean().compute().values.item(0))
print(forest_loss_detection.mean().compute().values.item(0))
forest_height_previous
# forest_height_current.mean().compute()
# print(forest_height_current.mean().compute())
# print(driver)        
# print(driver.mean().compute())
# print(forest_height_previous.where(driver > 1).mean().compute())

6.1926073125
6.1837449375
0.001600875
CPU times: total: 4.12 s
Wall time: 9.93 s


Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 15.26 MiB 0.95 MiB Shape (1, 4000, 4000) (1, 1000, 1000) Dask graph 16 chunks in 2 graph layers Data type uint8 numpy.ndarray",4000  4000  1,

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray


In [13]:
# Groups the three aligned input layers into one xarray Dataset (arrays sharing the same coordinates).
# The variable names are the ones consolidate_to_one_dimensional_index reads from each block.

decision_tree_ds = xr.Dataset(
    dict(
        forest_height_previous=forest_height_previous,
        forest_height_current=forest_height_current,
        forest_loss_detection=forest_loss_detection,
        # drivers=driver,
    )
)

decision_tree_ds

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 15.26 MiB 0.95 MiB Shape (1, 4000, 4000) (1, 1000, 1000) Dask graph 16 chunks in 2 graph layers Data type uint8 numpy.ndarray",4000  4000  1,

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 15.26 MiB 0.95 MiB Shape (1, 4000, 4000) (1, 1000, 1000) Dask graph 16 chunks in 2 graph layers Data type uint8 numpy.ndarray",4000  4000  1,

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray
"Array Chunk Bytes 15.26 MiB 0.95 MiB Shape (1, 4000, 4000) (1, 1000, 1000) Dask graph 16 chunks in 2 graph layers Data type uint8 numpy.ndarray",4000  4000  1,

Unnamed: 0,Array,Chunk
Bytes,15.26 MiB,0.95 MiB
Shape,"(1, 4000, 4000)","(1, 1000, 1000)"
Dask graph,16 chunks in 2 graph layers,16 chunks in 2 graph layers
Data type,uint8 numpy.ndarray,uint8 numpy.ndarray


In [35]:
# Uses the function above to map the dataset to 1D index
# map_blocks applies consolidate_to_one_dimensional_index to every block. The template (a copy of
# forest_height_previous) describes the expected output layout.
# NOTE(review): .compute() below only displays the values. decision_tree_index itself stays lazy, so the
# to_raster, dar.unique and reclassify cells each evaluate the whole graph again. The template is also
# uint8 (see the array repr above) while np.ravel_multi_index returns platform ints — confirm the dtype
# metadata mismatch is harmless downstream.

decision_tree_index = decision_tree_ds.map_blocks(consolidate_to_one_dimensional_index, template=forest_height_previous.copy())
decision_tree_index.compute()
# decision_tree_index.mean().compute()

In [36]:
# Saves the decision tree index array to raster
# Written as DEFLATE-compressed uint16, so every index (at most the product of index_dims minus 1) must be < 65536.
# NOTE(review): attributes come from the forest_height_previous template and presumably include
# _FillValue 0 (as shown on forest_states further down), so index 0 may be written as nodata — confirm.

decision_tree_index.rio.to_raster(f'{local_out_dir}decision_tree_index_2021__{timestr}_not_for_loop.tif', compress='DEFLATE', dtype='uint16')

In [40]:
# Finds all unique index values across the tile
# dar.unique returns a dask array with unknown (NaN) chunk sizes. .compute() handles that directly and
# yields a plain NumPy array that can be iterated. A separate compute_chunk_sizes() call would evaluate
# the entire graph once just to measure the chunks, and .compute() would then evaluate it a second time.
# See https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks

unique_indices = dar.unique(decision_tree_index.data).compute()
# print(unique_indices)

In [43]:
%%time
# Converts the array of unique index values back to forest states

forest_states_per_index = [classify_index(index) for index in unique_indices]
# forest_states_per_index

CPU times: total: 0 ns
Wall time: 3 ms


In [44]:
%%time
# Swaps the unique indices for their corresponding forest states
# The band dimension is squeezed out. bins are the sorted unique indices and new_values the matching codes.
# This assumes xrspatial's reclassify treats bins as upper bin edges, so each exact index falls into its
# own bin — confirm against the xrspatial docs.
# NOTE(review): forest_states_final is already computed, so the .compute() on its mean is a no-op. The mean
# of categorical state codes is only a rough sanity check.

forest_states = xrspatial.classify.reclassify(
    decision_tree_index.squeeze("band"),
    bins=unique_indices, 
    new_values=forest_states_per_index,
    name="forest_states"
)
forest_states_final = forest_states.compute()
print(forest_states_final)
print(forest_states_final.mean().compute())

<xarray.DataArray 'forest_states' (y: 4000, x: 4000)>
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)
Coordinates:
    band         int32 1
  * x            (x) float64 10.0 10.0 10.0 10.0 10.0 ... 11.0 11.0 11.0 11.0
  * y            (y) float64 50.0 50.0 50.0 50.0 50.0 ... 49.0 49.0 49.0 49.0
    spatial_ref  int32 0
Attributes:
    AREA_OR_POINT:  Area
    LAYER_TYPE:     athematic
    scale_factor:   1.0
    add_offset:     0.0
    long_name:      Layer_1
    _FillValue:     0
<xarray.DataArray 'forest_states' ()>
array(0.32106832, dtype=float32)
Coordinates:
    band         int32 1
    spatial_ref  int32 0
CPU times: total: 6.06 s
Wall time: 4.26 s


In [45]:
# Saves the forest state raster
# forest_states_final is float32 (see the printed array). It is written as DEFLATE-compressed uint8,
# which holds the integer state codes 0-3 returned by ForestStateDecisionTree.classify.

forest_states_final.rio.to_raster(f'{local_out_dir}forest_states_2021__{timestr}_not_for_loop.tif', compress='DEFLATE', dtype='uint8')

<font size="4">Running with for loop</font>

In [None]:
%%time

tile_id_list = ["50N_010E"]

years = [2019, 2020, 2021]
# years = [2019]

for tile_id in tile_id_list:
    
#     driver_uri ='{0}{1}_1deg_tree_cover_loss_driver_processed.tif'.format(general_uri, tile_id)
#     peat_uri = '{0}{1}_1deg_peat_mask_processed.tif'.format(general_uri, tile_id)
#     tclf_uri = '{0}{1}_1deg_tree_cover_loss_fire_processed.tif'.format(general_uri, tile_id)
#     land_cover_uri = ...
#     planted_forest_uri = ...
    
#     driver = get_tile_dataset(driver_uri, name="driver")
#     peat = get_tile_dataset(peat_uri, name="peat")
#     tclf = get_tile_dataset(tclf_uri, name="tclf")
#     land_cover = get_tile_dataset(tclf_uri, name="land_cover_uri")
#     planted_forest = get_tile_dataset(tclf_uri, name="planted_forest_uri")

    for year in years:
        
        start = time.time()
              
        current_year = year
        previous_year = current_year - 1
        print("Previous year: " + str(previous_year))
        print("Current year: "+ str(current_year))

        # forest_height_previous_uri = f'{general_uri}202307_revision/FH_{previous_year}.tif'
        # forest_height_current_uri = f'{general_uri}202307_revision/FH_{current_year}.tif'
        # forest_loss_detection_uri = f'{general_uri}202307_revision/DFL_{current_year}.tif'
               
        forest_height_previous_uri = f'{general_uri}202307_revision/test_1x1_deg/{tile_id}_1deg_FH_{previous_year}.tif'
        forest_height_current_uri = f'{general_uri}202307_revision/test_1x1_deg/{tile_id}_1deg_FH_{current_year}.tif'
        forest_loss_detection_uri = f'{general_uri}202307_revision/test_1x1_deg/{tile_id}_1deg_DFL_{previous_year}.tif'
        
        print(forest_height_previous_uri)
        print(forest_height_current_uri)
        print(forest_loss_detection_uri)
               
        forest_height_previous = get_tile_dataset(forest_height_previous_uri, name="forest_height_previous")
        forest_height_current = get_tile_dataset(forest_height_current_uri, name="forest_height_current_year")
        forest_loss_detection_previous = get_tile_dataset(forest_loss_detection_uri, name="forest_height_current_year")
        
        # print(forest_height_previous.mean().compute())
        # print(forest_height_current.mean().compute())

          
        # you'll want to combine them all into one xarray dataset, which is just a way to group arrays that
        # are aligned on the same coordinates
        decision_tree_ds = xr.Dataset({
            "forest_height_previous": forest_height_previous, 
            "forest_height_current": forest_height_current
        })
        
        # use the function above to map your dataset to that 1D index
        print("Making index matrix...")
        decision_tree_index = decision_tree_ds.map_blocks(consolidate_to_one_dimensional_index, template=forest_height_previous.copy())
        
        # # Save the decision tree index array to raster
        # decision_tree_index_matrix = decision_tree_index.compute()
        # print("decision_tree_index")
        # print(decision_tree_index_matrix)
        # decision_tree_index_matrix.rio.to_raster(f'{local_out_dir}{tile_id}_1deg_decision_tree_index_{current_year}__{timestr}.tif', compress='DEFLATE', dtype='uint16')

        # find all unique index values across the tile
        print("Finding unique indices...")
        unique_indices = dar.unique(decision_tree_index.data)
        unique_indices.compute_chunk_sizes()   #Added to resolve unknown chunk size error: https://docs.dask.org/en/latest/array-chunks.html#unknown-chunks
        unique_indices=unique_indices.compute()
        # print("Unique indices:")
        # print(unique_indices)
        
        print("Converting unique index list to forest state list...")
        forest_states_per_index = [classify_index(index) for index in unique_indices]

        # print("forest_states_per_index")
        # print(forest_states_per_index)

        print("Creating forest state matrix...")
        forest_states = xrspatial.classify.reclassify(
            decision_tree_index.squeeze("band"),
            bins=unique_indices, 
            new_values=forest_states_per_index,
            name="forest_states"
        )
        forest_states_final = forest_states.compute()
        forest_states_final

        print("Saving forest state raster...")
        forest_states_final.rio.to_raster(f'{local_out_dir}{tile_id}_1deg_FH_forest_states_{current_year}__{timestr}.tif', compress='DEFLATE', dtype='uint8')
        
        end = time.time()
        print("Elapsed time to generate forest states in tile_id {0} for {1} : ".format(tile_id, current_year), round(end-start))

#         # # now you can also swap your forest states for emissions/removals factors
#         # emissions = (forest_states[forest_states == 1] * (full_biomass + non_co2_peat + non_co2_fire)) + 
#         #                 (forest_states[forest_states == 2] * partial_biomass + non_co2_peat + non_co2_fire) + ...