In [51]:
import ee
import geemap
import folium
import time

ee.Initialize()
Map = geemap.Map(center=[-5.0, -65.0], zoom=5)  # Center near the Amazon region

print("Start")
# Define the region of interest of Amazon basin
# amazon_region_whole = ee.Geometry.Polygon([[[-80.0, 10.0],
#             [-44.0, 10.0],
#             [-44.0, -18.0],                           
#             [-80.0, -18.0],
#             [-80.0, 10.0]]]) 
amazon_region = ee.Geometry.Polygon([[[-70.0, 4.0],
            [-60.0, 4.0],
            [-60.0, -10.0],                           
            [-70.0, -10.0],
            [-70.0, 4.0]]]) 


# Define the bitmasks
cloud_bit_mask = ee.Number(1 << 5)  # Cloud bit is in the 6th bit position
cirrus_bit_mask = ee.Number(1 << 9)  # Cirrus bit is in the 10th bit position

# Define a masking function
def mask_edges(image):
    edge = image.lt(-30.0)  # Define an edge mask where values are less than -30
    masked_image = image.mask().And(edge.Not())  # Mask out edges
    return image.updateMask(masked_image)  # Apply the mask

# Apply the mask using bitwise AND to check that both cloud and cirrus bits are 0
def mask_clouds(image):
    qa = image.select('QA60')  # Select the QA60 band that holds cloud and cirrus bit information
    mask = qa.bitwiseAnd(cloud_bit_mask).eq(0).And(qa.bitwiseAnd(cirrus_bit_mask).eq(0))
    return image.updateMask(mask)

# Calculate EVI
def calculate_evi(image):
    return image.expression(
        '2.5 * ((B8 - B4) / (B8 + 6 * B4 - 7.5 * B2 + 1))',
        {
            'B8': image.select('B8'),
            'B4': image.select('B4'),
            'B2': image.select('B2')
        }).rename('EVI')
    
# Function to create a grid of tiles over the region
def create_tile_grid(region, tile_size):
    bounds = region.bounds().coordinates().get(0).getInfo()
    lon_min, lat_min = bounds[0]
    lon_max, lat_max = bounds[2]
    tiles = []
    lon = int(lon_min)
    while int(lon) < int(lon_max):
        lat = int(lat_min)
        while int(lat) < int(lat_max):
            print(lon, lat, lon + tile_size, lat + tile_size)
            tile = ee.Geometry.Rectangle([lon, lat, lon + tile_size, lat + tile_size])
            tiles.append(ee.Feature(tile))
            lat += tile_size
        lon += tile_size
    return ee.FeatureCollection(tiles)



# Tile size in degrees (approx. 300 km × 300 km, depending on latitude)
tile_size_degrees = 1  # Approximate conversion for 300 km
tiles = create_tile_grid(amazon_region, tile_size_degrees)

# Apply batch processing to all tiles
# agb_maps = tiles.map(lambda tile: calculate_agb(tile))

# Process each tile and print the result
tile_list = tiles.toList(tiles.size()).getInfo()  # Convert FeatureCollection to a list
print("Size of tiles", len(tile_list))

print("Generating Map")
# Visualize the tiles on the map
Map.centerObject(amazon_region, 5);
Map.addLayer(amazon_region, {}, 'amazon region');
Map.addLayer(tiles, {'color': 'red'}, 'Tile Grid');
Map

Start
-70 -10 -69 -9
-70 -9 -69 -8
-70 -8 -69 -7
-70 -7 -69 -6
-70 -6 -69 -5
-70 -5 -69 -4
-70 -4 -69 -3
-70 -3 -69 -2
-70 -2 -69 -1
-70 -1 -69 0
-70 0 -69 1
-70 1 -69 2
-70 2 -69 3
-70 3 -69 4
-69 -10 -68 -9
-69 -9 -68 -8
-69 -8 -68 -7
-69 -7 -68 -6
-69 -6 -68 -5
-69 -5 -68 -4
-69 -4 -68 -3
-69 -3 -68 -2
-69 -2 -68 -1
-69 -1 -68 0
-69 0 -68 1
-69 1 -68 2
-69 2 -68 3
-69 3 -68 4
-68 -10 -67 -9
-68 -9 -67 -8
-68 -8 -67 -7
-68 -7 -67 -6
-68 -6 -67 -5
-68 -5 -67 -4
-68 -4 -67 -3
-68 -3 -67 -2
-68 -2 -67 -1
-68 -1 -67 0
-68 0 -67 1
-68 1 -67 2
-68 2 -67 3
-68 3 -67 4
-67 -10 -66 -9
-67 -9 -66 -8
-67 -8 -66 -7
-67 -7 -66 -6
-67 -6 -66 -5
-67 -5 -66 -4
-67 -4 -66 -3
-67 -3 -66 -2
-67 -2 -66 -1
-67 -1 -66 0
-67 0 -66 1
-67 1 -66 2
-67 2 -66 3
-67 3 -66 4
-66 -10 -65 -9
-66 -9 -65 -8
-66 -8 -65 -7
-66 -7 -65 -6
-66 -6 -65 -5
-66 -5 -65 -4
-66 -4 -65 -3
-66 -3 -65 -2
-66 -2 -65 -1
-66 -1 -65 0
-66 0 -65 1
-66 1 -65 2
-66 2 -65 3
-66 3 -65 4
-65 -10 -64 -9
-65 -9 -64 -8
-65 -8 -64 -7
-65 -7 -64 

Map(center=[-2.9961093155609886, -65.00000000000001], controls=(WidgetControl(options=['position', 'transparen…

In [63]:
###### Function to calculate AGB for each tile
def calculate_agb(tile):
    tile_geometry = ee.Feature(tile).geometry()

    # Load GEDI Level 4A data
    gedi_all = ee.FeatureCollection('LARSE/GEDI/GEDI04_A_002_INDEX')\
            .filter('time_start > "2022-01-01" && time_end < "2022-03-31"')\
            .filterBounds(tile_geometry);

    # Get the list of table_id values
    table_ids = gedi_all.aggregate_array('table_id').getInfo()
    print("Lenght of table id", len(table_ids))
    
    # Initialize an empty FeatureCollection
    gedi = ee.FeatureCollection([])

    # Loop through each table ID and merge them
    for table_id in table_ids:
        # Load each table and merge
        table = ee.FeatureCollection(table_id).filterBounds(tile_geometry)
        # image = image.reproject('EPSG:4326', None, 100)
        gedi = gedi.merge(table)
        # print("Size of Gedi", gedi.size().getInfo())

    
    # print('Number of filtered GEDI points:', gedi.size().getInfo())
    # print('Number of GEDI points:', gedi.size().getInfo())
    
    # Filter to keep only points with non-null 'agbd' values
    gedi = gedi.filter(ee.Filter.notNull(['agbd']))
    gedi_size = gedi.size().getInfo()
    # print('Number of filtered GEDI points:', gedi_size)
    if gedi_size != 0:
        # Map.setCenter(-60, 5, 5);
        # Map.addLayer(merged_granules);
        # print(gedi.first().getInfo())
        # load sentinel-1 data 
        spring = ee.Filter.date('2022-03-01', '2022-04-20');
        lateSpring = ee.Filter.date('2022-04-21', '2022-06-10');
        summer = ee.Filter.date('2022-06-11', '2022-08-31');
    
        sentinel1 = ee.ImageCollection('COPERNICUS/S1_GRD')\
                    .filterBounds(tile_geometry)\
                    .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV'))\
                    .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))\
                    .filter(ee.Filter.eq('instrumentMode', 'IW'))\
                    .filter(ee.Filter.inList('orbitProperties_pass', ['ASCENDING', 'DESCENDING']))
        # Select the VV and VH bands
        sentinel1_vv = sentinel1.select('VV')
        sentinel1_vh = sentinel1.select('VH')
        # Apply the masking function to each image in the collection
        sentinel1_vv_masked = sentinel1_vv.map(mask_edges)
        sentinel1_vv_final = ee.Image.cat(
                sentinel1_vv_masked.filter(spring).mean(),
                sentinel1_vv_masked.filter(lateSpring).mean(),
                sentinel1_vv_masked.filter(summer).mean());
        # Apply the masking function to each image in the collection
        sentinel1_vh_masked = sentinel1_vh.map(mask_edges)
        sentinel1_vh_final = ee.Image.cat(
                sentinel1_vh_masked.filter(spring).mean(),
                sentinel1_vh_masked.filter(lateSpring).mean(),
                sentinel1_vh_masked.filter(summer).mean());
        # print('Number of sentinel1 points:', sentinel1_vh_masked.size().getInfo(), sentinel1_vv_masked.size().getInfo()) 
        
        
        # Load Sentinel-2 surface reflectance data
        # Sentinel-2 Level 2A multispectral imagery, ref:  https://www.sciencedirect.com/science/article/pii/S1569843222002965#s0010
        sentinel2 = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED') \
                        .filterBounds(tile_geometry) \
                        .filterDate('2022-01-01', '2022-12-31') \
                        .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 40))
        # Define the bitmasks
        cloud_bit_mask = ee.Number(1 << 5)  # Cloud bit is in the 6th bit position
        cirrus_bit_mask = ee.Number(1 << 9)  # Cirrus bit is in the 10th bit position
    
        sentinel2 = sentinel2.map(mask_clouds)
        # Calculate NDVI
        ndvi = sentinel2.map(lambda image: image.normalizedDifference(['B8', 'B4']).rename('NDVI')).median()
        # print('Number of sentinel2 points:', sentinel2.size().getInfo()) 
       
        evi = sentinel2.map(calculate_evi).median()
        # print('Number of evi points:', evi.getInfo())
        
        
        # Load Landsat 8 Surface Reflectance data
        landsat8 = ee.ImageCollection('LANDSAT/LC08/C02/T1_L2') \
                      .filterBounds(tile_geometry) \
                      .filterDate('2022-01-01', '2022-12-31') \
                      .filter(ee.Filter.lt('CLOUD_COVER', 35))

        # Calculate NDVI for Landsat 8
        landsat_ndvi = landsat8.map(lambda image: image.normalizedDifference(['SR_B5', 'SR_B4']).rename('NDVI')).median()

        
        
        # Load the GLO-30 DEM data from the COPERNICUS collection
        dem = ee.ImageCollection('COPERNICUS/DEM/GLO30') \
                  .filterBounds(tile_geometry) \
                  .mosaic()
        
        # Calculate slope in degrees
        slope = ee.Terrain.slope(dem)
        # Calculate aspect in degrees
        aspect = ee.Terrain.aspect(dem)
        # print('Number of dem points:')

        # print('Number of ndvi points:', ndvi.getInfo())
        # print('Number of evi points:', evi.getInfo())
        # print('Number of landsat_ndvi points:', landsat_ndvi.getInfo())
        # print('Number of dem points:', dem.getInfo())
        # print('Number of slope points:', slope.getInfo())
        # print('Number of aspect points:', aspect.getInfo())
        # Stack all the features (Sentinel-1, Sentinel-2, Landsat, DEM)
        feature_stack = sentinel1_vh_final.addBands(sentinel1_vv_final) \
                                    .addBands(ndvi) \
                                    .addBands(evi) \
                                    .addBands(landsat_ndvi) \
                                    .addBands(dem) \
                                    .addBands(slope) \
                                    .addBands(aspect)
        
        # print('Size of feature stack :', feature_stack.getInfo())
        
        # Sample the remote sensing data at GEDI footprint locations
        training_data = feature_stack.sampleRegions(
            collection=gedi,
            properties=['agbd'],
            scale=100,
            tileScale=16,
            geometries=True
        )

        # # from sklearn.impute import KNNImputer

        # # Select numeric columns for imputation
        # features = training_data[['sentinel1_vh_final', 'sentinel1_vh_final', 'ndvi']]  # Add relevant predictors
        # imputer = KNNImputer(n_neighbors=5)
        # training_data['agbd'] = imputer.fit_transform(features)[:, -1]
                
        # print('Number of training_data points:', training_data.size().getInfo())
        # if training_data.size().getInfo() == 0:
        #     return None
        
        # Train a Random Forest model
        classifier = ee.Classifier.smileRandomForest(50).setOutputMode('REGRESSION')
        
        # print('Before trained_model points')
        # Train the model
        trained_model = classifier.train(
            features=training_data,
            classProperty='agbd',
            inputProperties=feature_stack.bandNames()
        )
        
        # Apply the trained model to predict AGB
        agb_prediction = feature_stack.classify(trained_model)
        # .reproject(crs='EPSG:3857', scale=100) 
        return agb_prediction
        

# Function to check task status periodically
def check_task_status(task):
    while True:
        status = task.status()
        state = status['state']
        print('Current task state:', state)
        
        if state == 'COMPLETED':
            print("Export task completed successfully.")
            break
        elif state == 'FAILED':
            print("Export task failed:", status)
            break
        
        # Wait before checking the status again
        time.sleep(30)


agb_list = []
c=0
for tile in tile_list:
    agb = calculate_agb(tile)
    # Assuming `agb_prediction` is an ee.Image and `region` is an ee.Geometry
    task = ee.batch.Export.image.toCloudStorage(
        image=agb,         # Pass image directly, not in a dictionary
        description='AGB_Prediction_GCS',  # Task description
        bucket='test-agb-bucket',    # GCS bucket name
        fileNamePrefix='agb_prediction',  # Prefix for the file name in the bucket
        scale=100,                    # Scale in meters
        maxPixels=1e8,               # Max number of pixels
        fileFormat='GeoTIFF'          # File format
    )
    if agb != None:
        agb_list.append(agb)
        # Start the task
        # task.start()
        # status = task.status()
        # print(status)
        # print("Export started")
        # # Start monitoring the task
        # check_task_status(task)
    c=c+1
    if c == 12:
        break

agb_map = ee.ImageCollection(agb_list)
    
print("Size of agbs", agb_map.size().getInfo())
Map.addLayer(agb_map, {'min': 0, 'max': 300, 'palette': ['5F9EA0', 'grey', 'yellow', '7CFC00','5F8575', '228B22','008000', 
        '355E3B', '4F7942']}, 'Predicted AGB')

Map.centerObject(amazon_region, 5);
Map

Lenght of table id 12
Lenght of table id 8
Lenght of table id 8
Lenght of table id 7
Lenght of table id 8
Lenght of table id 8
Lenght of table id 9
Lenght of table id 11
Lenght of table id 11
Lenght of table id 15
Lenght of table id 12
Lenght of table id 9
Size of agbs 12


Map(bottom=2348.0, center=[-2.9961093155609886, -65.00000000000001], controls=(WidgetControl(options=['positio…