In [1]:
# The Google Earth Engine module
import ee

# The datetime module is used to specify the dates
# to search for imagery
import datetime

# Import the geemap (https://geemap.org/) module which
# has a visualisation tool
import geemap

# Geopandas allows us to read the shapefile used to
# define the region of interest (ROI)
import geopandas

# The colab module to access data from your google drive
from google.colab import drive

In [2]:
try:
  import pb_gee_tools
  import pb_gee_tools.datasets
  import pb_gee_tools.convert_types
except:
  !git clone https://github.com/remotesensinginfo/pb_gee_tools.git
  !pip install ./pb_gee_tools/.
  import pb_gee_tools
  import pb_gee_tools.datasets
  import pb_gee_tools.convert_types

Cloning into 'pb_gee_tools'...
remote: Enumerating objects: 375, done.
remote: Counting objects: 100% (104/104), done.
remote: Compressing objects: 100% (83/83), done.
remote: Total 375 (delta 49), reused 56 (delta 21), pack-reused 271 (from 1)
Receiving objects: 100% (375/375), 672.79 KiB | 4.05 MiB/s, done.
Resolving deltas: 100% (197/197), done.
Processing ./pb_gee_tools
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: pb_gee_tools
  Building wheel for pb_gee_tools (setup.py) ... done
  Created wheel for pb_gee_tools: filename=pb_gee_tools-0.3.0-py3-none-any.whl size=17966 sha256=259e6e58d817b9577e8507e0bc5d8d09b63767ac2d33dbbc18555748a5474ae5
  Stored in directory: /tmp/pip-ephem-wheel-cache-6aldr5f7/wheels/e8/da/17/69dc01c6cbd07adb00923d8c1dd9a22fb31e96a17b74c93812
Successfully built pb_gee_tools
Installing collected packages: pb_gee_tools
Successfully installed pb_gee_tools-0.3.0


In [3]:
# Authenticate with Google Earth Engine and initialise the client for a
# specific Cloud project. ee_prj_name is also used later to build the asset
# path that the trained classifier is exported to.
ee_prj_name = "ee-pb-dev"  # <==== Replace this with your own EE project string
ee.Authenticate()
ee.Initialize(project=ee_prj_name)

In [4]:
# Mount Google Drive so the ROI and training-sample vector files under
# /content/drive/MyDrive can be read.
drive.mount("/content/drive")

Mounted at /content/drive


In [5]:
# Path (on the mounted Google Drive) of the vector file defining the region of
# interest (ROI) to be classified.
vec_cls_roi_file = "/content/drive/MyDrive/mangrove_chng_cls/roi_poly.geojson"

# Date window used to search for imagery (1 Jan 2020 to 31 Dec 2020).
# NOTE(review): if the collection is filtered with ee.ImageCollection.filterDate,
# the end date is exclusive and imagery acquired on 31 Dec is dropped — confirm
# how pb_gee_tools treats end_date.
start_date, end_date = (
    datetime.datetime(2020, 1, 1),
    datetime.datetime(2020, 12, 31),
)

# No-data value for outputs.
no_data_val = 0.0

In [6]:
# Paths to the vector files holding the training points for each class
# (mng = mangrove, wtr = water, oth = other). These files contain samples for a
# larger area than the one defined in the ROI file; the points are clipped to
# the ROI when they are loaded.
(vec_mng_smpls_file, vec_wtr_smpls_file, vec_oth_smpls_file) = (
    '/content/drive/MyDrive/mangrove_chng_cls/mng_smpls.geojson',
    '/content/drive/MyDrive/mangrove_chng_cls/wtr_smpls.geojson',
    '/content/drive/MyDrive/mangrove_chng_cls/oth_smpls.geojson',
)

In [7]:
# Convert the training sample points for each class to GEE point geometries.
# The points are clipped to the ROI of the region being classified and randomly
# subsampled. Fewer training points reduce the memory footprint of the
# notebook, which is limited.
# The same sample size and seed are applied to every class.
n_train_smpls = 7500
train_smpl_seed = 42


def load_train_pts(vec_smpls_file, rnd_smpl=n_train_smpls, rnd_seed=train_smpl_seed):
    """Load one class's training points, clipped to the ROI and subsampled.

    :param vec_smpls_file: path to the vector file of training points.
    :param rnd_smpl: number of points to randomly subsample.
    :param rnd_seed: seed for the random subsampling.
    :return: the GEE point geometries returned by
             pb_gee_tools.convert_types.get_gee_pts.
    """
    return pb_gee_tools.convert_types.get_gee_pts(
        vec_smpls_file,
        rnd_smpl=rnd_smpl,
        rnd_seed=rnd_seed,
        vec_roi_file=vec_cls_roi_file,
    )


gee_mng_pts = load_train_pts(vec_mng_smpls_file)
gee_wtr_pts = load_train_pts(vec_wtr_smpls_file)
gee_oth_pts = load_train_pts(vec_oth_smpls_file)


In [8]:
# Merge the per-class training points into a single FeatureCollection. Each
# class's points become one ee.Feature carrying an integer 'class' property:
# 1 = mng, 2 = wtr, 3 = oth.
train_smpls = ee.FeatureCollection(
    [
        ee.Feature(class_pts, {'class': class_id})
        for class_id, class_pts in enumerate(
            (gee_mng_pts, gee_wtr_pts, gee_oth_pts), start=1
        )
    ]
)

In [9]:
# Read the ROI polygon from the vector file and convert it to a Google Earth
# Engine polygon; it is used as the search area (aoi) for the imagery.
roi_gee_poly = pb_gee_tools.convert_types.convert_vector_to_gee_polygon(vec_cls_roi_file)

In [10]:
# Load the Sentinel-2 imagery intersecting the ROI within the date window.
s2_img_col = pb_gee_tools.datasets.get_sen2_sr_collection(
    start_date=start_date,
    end_date=end_date,
    aoi=roi_gee_poly,
    # NOTE(review): presumably a maximum cloud-cover percentage — confirm in
    # pb_gee_tools.
    cloud_thres=70,
)

In [12]:
# Per-image function that adds spectral index bands; mapped over the collection.
def calc_band_indices(img):
    """Scale a Sentinel-2 image and append six spectral index bands.

    The image is multiplied by 0.0001 and cast to float (presumably converting
    integer-scaled reflectance to 0-1 reflectance — confirm for the input
    collection). The following bands are then added:

    * NDVI, NDWI, NBR: (B8 - X) / (B8 + X) for X = B4, B11 and B12.
    * EVI:  2.5 * ((B8 - B4) / (B8 + 6 * B4 - 7.5 * B2 + 1))
    * MVI:  (B8 - B3) / (B11 - B3)
    * REMI: (B6 - B4) / (B11 + B3)

    :param img: an ee.Image from the Sentinel-2 collection; it must contain
                bands B2, B3, B4, B6, B8, B11 and B12.
    :return: the scaled image with the six index bands appended.
    """
    scaled = img.multiply(.0001).float()

    # Normalised differences of NIR (B8) against a second band.
    nd_bands = [
        scaled.normalizedDifference(["B8", other_band]).rename(index_name)
        for index_name, other_band in (("NDVI", "B4"), ("NDWI", "B11"), ("NBR", "B12"))
    ]

    # Individual bands referenced by the expression-based indices.
    blue = scaled.select('B2')
    green = scaled.select('B3')
    red = scaled.select('B4')
    red_edge = scaled.select('B6')
    nir = scaled.select('B8')
    swir = scaled.select('B11')

    expr_specs = (
        ('EVI', '2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))',
         {'NIR': nir, 'RED': red, 'BLUE': blue}),
        ('MVI', '((NIR - GREEN) / (SWIR - GREEN))',
         {'NIR': nir, 'SWIR': swir, 'GREEN': green}),
        ('REMI', '((REDEDGE - RED) / (SWIR + GREEN))',
         {'REDEDGE': red_edge, 'RED': red, 'SWIR': swir, 'GREEN': green}),
    )
    expr_bands = [
        scaled.expression(formula, band_map).rename([index_name])
        for index_name, formula, band_map in expr_specs
    ]

    return scaled.addBands(nd_bands + expr_bands)

In [13]:
# Apply calc_band_indices to every image in the collection, so that each
# image is scaled and gains the index bands.
s2_indices_img_col = s2_img_col.map(calc_band_indices)


In [14]:
# Get the list of band names (scaled bands plus index bands) from the first
# image. These are used as the classifier's input features; the bare last
# line displays the list.
img_bands = s2_indices_img_col.first().bandNames()
img_bands

In [15]:
def sample_img_training(img):
    """Sample the training points from a single image.

    Extracts the value of every band of ``img`` at the training sample
    geometries in ``train_smpls`` at a 10 m scale, and copies each sample's
    'class' property onto the extracted samples.

    :param img: an ee.Image from the band-indices collection.
    :return: an ee.FeatureCollection of sampled pixels.
    """
    return img.sampleRegions(
        collection=train_smpls,
        properties=["class"],
        scale=10,
    )


# Sample every image in the collection, then flatten the resulting collection
# of collections into a single FeatureCollection of training samples.
training_data = s2_indices_img_col.map(sample_img_training).flatten()

In [16]:
# Train a 100-tree random forest to predict the 'class' property from the
# sampled band values.
rf_cls_mdl = ee.Classifier.smileRandomForest(numberOfTrees=100).train(
    features=training_data,
    classProperty="class",
    inputProperties=img_bands,
)

In [None]:
# Save the trained classifier as an asset in your Google Earth Engine project.
# The export runs as a batch task; its progress can be followed in the Earth
# Engine Tasks tab or with task.status().
# NOTE(review): the task description ('mng_test_rf_cls') differs from the asset
# name ('mng_indices_rf_cls') — confirm this is intended.
asset_id = f'projects/{ee_prj_name}/assets/mng_indices_rf_cls'
task = ee.batch.Export.classifier.toAsset(
  classifier=rf_cls_mdl,
  description='mng_test_rf_cls',
  assetId=asset_id
)
task.start()