<a href="https://githubtocolab.com/giswqs/geemap/blob/master/examples/notebooks/local_rf_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/></a>

Uncomment the following line to install [geemap](https://geemap.org) if needed.

In [1]:
# !pip install geemap

In [300]:
import multiprocessing as mp
from functools import partial

import numpy as np
from geemap.ml import tree_to_string


def rf_to_strings(estimator, feature_names, processes=2, output_mode="INFER"):
    """Convert an ensemble of decision trees into a list of strings. Wraps `tree_to_string`.

    args:
        estimator (sklearn.ensemble estimator): A fitted tree ensemble classifier or regressor (e.g. RandomForest, ExtraTrees) created using sklearn
        feature_names (list[str]): Names of the features (i.e. bands) used to fit the model
        processes (int): number of CPU processes to spawn. More processes speed up conversion of large models. default = 2
        output_mode (str): the output mode of the estimator. Options are "INFER", "CLASSIFICATION", "REGRESSION", or "PROBABILITY" (case-insensitive). default = "INFER"
    returns:
        trees (list[str]): one string per decision tree; together the strings represent the ensemble estimator (i.e. RandomForest)
    """

    # normalize the output mode to upper case
    output_mode = output_mode.upper()

    available_modes = ["INFER", "CLASSIFICATION", "REGRESSION", "PROBABILITY"]

    if output_mode not in available_modes:
        raise ValueError(
            f"The provided output_mode is not available, please provide one from the following list: {available_modes}"
        )

    # extract the individual trees; np.squeeze flattens 2-D arrays of estimators,
    # but it would collapse a single-tree ensemble into a 0-d array, so skip it then
    if len(estimator.estimators_) > 1:
        estimators = np.squeeze(estimator.estimators_)
    else:
        estimators = estimator.estimators_

    if output_mode == "INFER":
        # newer sklearn versions renamed "mse"/"mae" to "squared_error"/"absolute_error"
        if estimator.criterion in ["gini", "entropy", "log_loss"]:
            class_labels = estimator.classes_
        elif estimator.criterion in [
            "mse", "mae", "squared_error", "absolute_error", "friedman_mse", "poisson"
        ]:
            class_labels = None
        else:
            raise RuntimeError(
                "Could not infer the output type from the estimator, please explicitly provide the output_mode"
            )

    elif output_mode == "CLASSIFICATION":
        class_labels = estimator.classes_

    else:
        class_labels = None

    # never spawn more processes than available CPUs, but keep at least one
    if processes >= mp.cpu_count():
        processes = max(1, mp.cpu_count() - 1)

    # convert the trees to strings in parallel
    with mp.Pool(processes) as pool:
        trees = pool.map(
            partial(
                tree_to_string,
                feature_names=feature_names,
                labels=class_labels,
                output_mode=output_mode,
            ),
            estimators,
        )

    return list(trees)
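`rf_to_strings` delegates the per-tree work to `tree_to_string`. To give a feel for what that conversion involves, here is a minimal sketch (not geemap's implementation) that walks a tree stored the way scikit-learn stores a fitted tree in `estimator.tree_`: parallel arrays `children_left`, `children_right`, `feature`, `threshold`, `value` and `n_node_samples`, where a leaf has `children_left == -1`. It numbers the children of node k as 2k and 2k+1 and indents by depth, mimicking the layout of the printed tree further down. The helper name `tree_lines` and its simplified line format are hypothetical: it omits the loss column, and its output is not meant to be passed to EE.

```python
from typing import List, Sequence

LEAF = -1  # scikit-learn marks "no child" with -1 (TREE_LEAF)


def tree_lines(
    children_left: Sequence[int],
    children_right: Sequence[int],
    feature: Sequence[int],
    threshold: Sequence[float],
    value: Sequence[float],
    n_samples: Sequence[int],
    feature_names: Sequence[str],
) -> List[str]:
    """Hypothetical helper: render a binary tree stored in sklearn-style
    parallel arrays as one line per node (simplified, not EE's exact format)."""
    lines = [f"1) root {n_samples[0]} {value[0]}"]

    def walk(node: int, node_id: int, depth: int) -> None:
        left, right = children_left[node], children_right[node]
        if left == LEAF:  # leaf: no split to print below it
            return
        name = feature_names[feature[node]]
        split = threshold[node]
        # emit the left branch (and its whole subtree) before the right branch
        for child, child_id, op in ((left, 2 * node_id, "<="), (right, 2 * node_id + 1, ">")):
            marker = " *" if children_left[child] == LEAF else ""
            indent = "  " * (depth + 1)
            lines.append(
                f"{indent}{child_id}) {name} {op} {split:.6f} "
                f"{n_samples[child]} {value[child]}{marker}"
            )
            walk(child, child_id, depth + 1)

    walk(0, 1, 0)
    return lines


# toy tree: split on B2 at 0.5, then the right branch splits on B3 at 2.0
toy = tree_lines(
    children_left=[1, -1, 3, -1, -1],
    children_right=[2, -1, 4, -1, -1],
    feature=[0, -2, 1, -2, -2],
    threshold=[0.5, -2.0, 2.0, -2.0, -2.0],
    value=[2.8, 1.0, 4.0, 3.5, 4.5],
    n_samples=[10, 4, 6, 3, 3],
    feature_names=["B2", "B3"],
)
print("\n".join(toy))
```

With a fitted single-output regressor you could pass `t.children_left, t.children_right, t.feature, t.threshold, t.value[:, 0, 0], t.n_node_samples, feature_names` where `t = rf.estimators_[0].tree_` (`tree_.value` has shape `(n_nodes, 1, 1)` in that case).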

## local_rf_training

This notebook illustrates how to train a random forest (or any other ensemble tree estimator) locally using scikit-learn, convert the estimator into a string representation that Earth Engine can interpret, and apply the resulting model with EE.

In [101]:
# import packages
import ee
import geemap
from geemap import ml  # note new module within geemap
import numpy as np
import pandas as pd
import sklearn.ensemble  # importing the submodule makes sklearn.ensemble available

In [2]:
geemap.ee_initialize()

### Training a model locally and using it with EE

In [37]:
# get a Sentinel-2 surface reflectance image to sample training data from
i1 = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED") \
    .filterBounds(ee.Geometry.Point(-82.63164062500002, 46.36565483490312)) \
    .filterDate('2020-06-01', '2020-07-01') \
    .first()


In [342]:
# specify the band names to sample from the image; they become the columns
# of the training DataFrame (features plus the label)
inputBandNames = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7']
# sample roughly 1000 pixels from the image footprint
f1 = i1.select(inputBandNames).sample(numPixels=1000)
# bring the samples to the client as one list of values per band
samples = ee.List(inputBandNames).map(lambda bandName: f1.aggregate_array(bandName)).getInfo()
# transpose so that each row is a pixel and each column a band
df = pd.DataFrame(np.array(samples).transpose(), columns=inputBandNames)


In [343]:
# split the data into features and label: B4 is predicted from the other bands
feature_names = ['B2', 'B3', 'B5', 'B6', 'B7']
label = 'B4'
X = df[feature_names]
y = df[label]

In [344]:
# create a random forest regressor and fit it
n_trees = 1
rf = sklearn.ensemble.RandomForestRegressor(
    n_estimators=n_trees,
    min_samples_leaf=100,
    bootstrap=False,
    random_state=0,
    verbose=0,
    max_leaf_nodes=10,
).fit(X, y)

In [345]:
# group the samples by the first model's predicted value (one group per leaf
# of the single tree) and fit a separate RF regressor to each group
df['class'] = rf.predict(X)
classes = df['class'].unique()
rfdicts = []
for thisclass in classes:
    thisdf = df.loc[df['class'] == thisclass]
    X_class = thisdf[feature_names]
    y_class = thisdf[label]
    thisrf = sklearn.ensemble.RandomForestRegressor(
        n_estimators=n_trees, bootstrap=False, random_state=0, verbose=0
    ).fit(X_class, y_class)
    rfdicts.append({'class': thisclass, 'rf': thisrf})
print(rfdicts)

[{'class': 534.6492537313433, 'rf': RandomForestRegressor(bootstrap=False, n_estimators=1, random_state=0)}, {'class': 345.4950495049505, 'rf': RandomForestRegressor(bootstrap=False, n_estimators=1, random_state=0)}, {'class': 223.27551020408163, 'rf': RandomForestRegressor(bootstrap=False, n_estimators=1, random_state=0)}, {'class': 82.20792079207921, 'rf': RandomForestRegressor(bootstrap=False, n_estimators=1, random_state=0)}, {'class': 296.6016260162602, 'rf': RandomForestRegressor(bootstrap=False, n_estimators=1, random_state=0)}, {'class': 266.9111111111111, 'rf': RandomForestRegressor(bootstrap=False, n_estimators=1, random_state=0)}, {'class': 398.4818181818182, 'rf': RandomForestRegressor(bootstrap=False, n_estimators=1, random_state=0)}, {'class': 1233.58, 'rf': RandomForestRegressor(bootstrap=False, n_estimators=1, random_state=0)}]


In [302]:
# convert the estimator into a list of strings
# this function also works with the ensemble.ExtraTrees estimator
treeRF = rf_to_strings(rf, feature_names, output_mode="REGRESSION")


In [309]:
print(rf.get_params(deep=True))

{'bootstrap': False, 'ccp_alpha': 0.0, 'criterion': 'squared_error', 'max_depth': None, 'max_features': 'auto', 'max_leaf_nodes': None, 'max_samples': None, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 1, 'min_samples_split': 50, 'min_weight_fraction_leaf': 0.0, 'n_estimators': 1, 'n_jobs': None, 'oob_score': False, 'random_state': 0, 'verbose': 0, 'warm_start': False}


In [307]:
print(treeRF[0])

1) root 1000 9999 9999 (5179335.493648838)
  2) B3 <= 2495.000000 1000 267746.8895 398.309
    4) B2 <= 577.500000 993 61512.0627 363.042296
      8) B2 <= 281.500000 897 20160.6594 304.303233
        16) B3 <= 192.000000 555 7502.3211 224.468468
          32) B6 <= 77.500000 90 2188.6736 69.355556
            64) B5 <= 48.000000 43 107.7847 31.488372 *
            65) B5 > 48.000000 16 137.5586 63.9375 *
          33) B6 > 77.500000 31 1069.0572 124.677419 *
        17) B3 > 192.000000 555 7502.3211 224.468468
          34) B2 <= 215.500000 465 2972.6800 254.490323
            68) B2 <= 173.500000 242 1411.1801 227.958678
              136) B2 <= 98.500000 4 470.5000 130.0 *
              137) B2 > 98.500000 80 1467.2344 202.375
                274) B2 <= 108.000000 1 0.0000 377.0 *
                275) B2 > 108.000000 76 1229.4924 206.184211
                  550) B2 <= 136.500000 15 446.9956 172.933333 *
                  551) B2 > 136.500000 75 851.6580 203.906667
                 

At this point you can save the list of strings locally to avoid retraining. However, to use the model with EE we need to create an `ee.Classifier`, and for best results we should persist the trees in EE.
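Saving the strings locally can be as simple as writing the list to a JSON file, which escapes the newlines inside each tree string and restores them on load. A minimal standard-library sketch; the helper names `save_tree_strings` and `load_tree_strings` are hypothetical, and the example string is hand-written rather than real model output.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import List


def save_tree_strings(trees: List[str], path: str) -> None:
    """Write a list of tree strings to a JSON file (json escapes the newlines)."""
    Path(path).write_text(json.dumps(trees), encoding="utf-8")


def load_tree_strings(path: str) -> List[str]:
    """Read back the list written by save_tree_strings."""
    return json.loads(Path(path).read_text(encoding="utf-8"))


# round-trip a small hand-written tree string through a temporary file
example = ["1) root 10 2.8\n  2) B2 <= 0.500000 4 1.0 *\n  3) B2 > 0.500000 6 4.0 *"]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "rf_trees.json")
    save_tree_strings(example, path)
    restored = load_tree_strings(path)
print(restored == example)  # True
```

In the notebook this would be `save_tree_strings(treeRF, "rf_trees.json")`, and in a later session `ml.strings_to_classifier(load_tree_strings("rf_trees.json"))`.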

In [303]:
# create an ee.Classifier from the tree strings so it can be used with ee objects
ee_classifierRF = ml.strings_to_classifier(treeRF)


In [304]:
# classify the image using the classifier we created from the local training
# note: we select feature_names from the image so that the classifier knows which bands to use
classifiedRF = i1.select(feature_names).classify(ee_classifierRF)

In [305]:
# display results
Map = geemap.Map(center=(46.36565483490312,-82.63164062500002), zoom=11)

Map.addLayer(
    i1,
    {"bands": ['B7', 'B5', 'B3'], "min": 800, "max": 2000, "gamma": 1.5},
    'image',
)
Map.addLayer(
    classifiedRF,
    {"min": 500, "max": 3000, "palette": ['red', 'green', 'blue']},
    'classification',
)

Map

Yay!! 🎉 Looks like our example works. Don't party too much because there is a catch...

This workflow has several limitations, particularly in how much data you can pass from the client to the server and how large a model EE can actually handle. EE can only handle 40MB of data passed to the server, so if you have many large decision tree strings this will not work. Creating a classifier from strings also has limitations (see this ee-forum discussion: https://groups.google.com/g/google-earth-engine-developers/c/lFFU1GBPzi8/m/6MewQk1FBwAJ), which again stem from the length of the strings when EE builds the computation graph.
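Before building a classifier, it can help to estimate how much text you are about to send. The sketch below sums the UTF-8 size of the tree strings and compares it with the 40MB figure quoted above. The helper names are hypothetical, reading 40MB as 40 × 2^20 bytes is an assumption, and the real request also carries serialization overhead, so passing this check is necessary rather than sufficient.

```python
from typing import List

MB = 1024 * 1024  # assumption: "40MB" read as 40 * 2**20 bytes


def tree_strings_nbytes(trees: List[str]) -> int:
    """Total UTF-8 encoded size of the tree strings, in bytes."""
    return sum(len(t.encode("utf-8")) for t in trees)


def fits_in_request(trees: List[str], limit_mb: float = 40) -> bool:
    """Rough check of the raw string size against the quoted request limit."""
    return tree_strings_nbytes(trees) < limit_mb * MB


small = ["1) root 10 2.8\n"] * 3
print(tree_strings_nbytes(small), fits_in_request(small))  # 45 True
```

For the model above you could run `print(tree_strings_nbytes(treeRF) / MB, "MB")`; shrinking `max_leaf_nodes` or `n_estimators` shortens the strings.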

So you can use this approach, but expect to run into errors when training large models.

### Saving trees to ee.FeatureCollection

Now that we have the strings in a format EE can use, we want to save them for later use. There is a function to export a list of tree strings to a feature collection. The feature collection will have a pro

In [32]:
# specify the asset id where the trees will be saved
# be sure to change <user_name> to your ee user name
asset_id = "users/<user_name>/random_forest_strings_test"

In [20]:
# kick off an export process so the trees are saved to the ee asset
ml.export_trees_to_fc(treeRF, asset_id)

# this will kick off an export task, so wait a few minutes before moving on
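Rather than waiting a fixed few minutes, you can poll the export task until it finishes. Earth Engine task objects expose `status()`, which returns a dict whose `'state'` is, for example, `'READY'`, `'RUNNING'`, `'COMPLETED'`, `'FAILED'` or `'CANCELLED'`. The helper below is a hypothetical sketch that accepts any zero-argument status function, so it can be tried without EE; how you obtain the task object is an assumption (see the note after the code).

```python
import time
from typing import Callable, Dict

TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELLED"}


def wait_for_task(
    get_status: Callable[[], Dict[str, str]],
    poll_seconds: float = 10.0,
    timeout_seconds: float = 1800.0,
) -> str:
    """Poll get_status() until it reports a terminal state, then return that state."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        state = get_status().get("state", "UNKNOWN")
        if state in TERMINAL_STATES:
            return state
        if time.monotonic() >= deadline:
            raise TimeoutError(f"task still {state} after {timeout_seconds} s")
        time.sleep(poll_seconds)


# simulate a task that is queued, then running, then done
fake_states = iter(["READY", "RUNNING", "COMPLETED"])
result = wait_for_task(lambda: {"state": next(fake_states)}, poll_seconds=0)
print(result)  # COMPLETED
```

With EE, assuming the export is the most recently submitted task, something like `task = ee.batch.Task.list()[0]` followed by `wait_for_task(task.status)` should work; if the state is `'FAILED'`, inspect `task.status()` for an error message.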

In [33]:
# read the exported tree feature collection
rf_fc = ee.FeatureCollection(asset_id)

# convert it to a classifier, very similar to the `ml.strings_to_classifier` function
another_classifier = ml.fc_to_classifier(rf_fc)

# classify the image again but with the classifier from the persisted trees
classified = i1.select(feature_names).classify(another_classifier)

In [34]:
# display results
# we should get the exact same results as before
Map = geemap.Map(center=(46.36565483490312, -82.63164062500002), zoom=11)

Map.addLayer(
    i1,
    {"bands": ['B7', 'B5', 'B3'], "min": 800, "max": 2000, "gamma": 1.5},
    'image',
)
Map.addLayer(
    classified,
    {"min": 500, "max": 3000, "palette": ['red', 'green', 'blue']},
    'classification',
)

Map