# Expected Output:

In [1]:
import sys

from shapely.geometry import box
import geopandas as gpd
import ee
import pickle
import requests
import os
from osgeo import gdal, ogr
from common_functions import *
import shutil
from time import sleep
from math import ceil
from googleapiclient.discovery import build

#initialize the Earth Engine client library before any other ee.* call in this notebook
#NOTE(review): presumably requires Earth Engine credentials to be set up beforehand - confirm
ee.Initialize()

# Constants

In [2]:
#conversion factor from meters to decimal degrees: 1 meter is taken to be 1/30 of an
#arc-second (an arc-second is 1/3600 of a degree), so multiplying a pixel size in meters
#by this constant gives its approximate size in degrees (a 10 m pixel -> ~9.26e-5 degrees)
METERS_TO_DECIMAL_DEGREES_CONST = 1/30/3600

#the value we use to signify no data at a pixel
NO_DATA_VALUE = 65535

#this is the biggest region side length we allow (in the units of the area-of-interest
#coordinates) to avoid data overflow errors and keep files manageable
MAX_REGION_SIZE = 0.5

#the base region of interest folder; created on first run, left untouched if it already exists
BASE_ROI_FOLDER = 'regions'
os.makedirs(BASE_ROI_FOLDER, exist_ok=True)

# User Input Area

In [3]:
#path to the baseline wetlands raster; it is clipped to every sub-region during data collection
#NOTE(review): absolute, machine-specific path - adjust before running on another machine
BASELINE_WETLAND_PATH = 'C://Users/ritvik/Desktop/JPLProject/cifor_wetlands_colombia.tif'

In [4]:
#shapefile whose features define the areas of interest; the driver code splits the
#envelope of each feature into sub-region folders
AREA_OF_INTEREST_FILE = 'area_of_interest/mangrove.shp'

In [5]:
#id of the Google Drive folder that is listed when looking up the exported files
#NOTE(review): presumably this must be the same folder that `gdrive_folder` names - confirm
GOOGLE_EARTH_ENGINE_GDRIVE_FOLDER_ID = '1KvlrUHs_rN7xPlw53qtd9pweeLwmrJSP'

### Google Drive Access Instructions

In [6]:
#pixels whose 'elevation' band is greater than this value are set to NO_DATA_VALUE
#presumably expressed in meters, like the SRTM elevation band - TODO confirm
MAX_CONSIDERED_ELEVATION = 6

In [7]:
#the bands kept in the exported image; a pixel that is NO_DATA_VALUE in any of these
#bands is marked as NO_DATA_VALUE in all of them
SELECTED_BANDS = ['B2', 'NDVI', 'HH']

# Functions to Download File from Google Drive

In [8]:
def download_file_from_google_drive(file_id, destination):
    """
    Download a file from Google Drive and write it to a local path.

    file_id: the Google Drive id of the file to download
    destination: the local path the downloaded bytes are written to

    raises: requests.HTTPError if the final response carries an HTTP error status
    """
    URL = "https://docs.google.com/uc?export=download"

    #the context manager closes the session (and its pooled connections) even on errors
    with requests.Session() as session:
        response = session.get(URL, params = { 'id' : file_id }, stream = True)
        token = get_confirm_token(response)

        #a 'download_warning' cookie means the first request was not the file itself;
        #repeat the request with the cookie value passed as the 'confirm' parameter
        if token:
            params = { 'id' : file_id, 'confirm' : token }
            response = session.get(URL, params = params, stream = True)

        #fail loudly instead of saving an error page under the destination name
        response.raise_for_status()

        #the body is streamed, so it has to be consumed while the session is still open
        save_response_content(response, destination)

In [9]:
def get_confirm_token(response):
    """
    Look through the cookies of a response for a download warning.

    response: an object whose `cookies` attribute offers `items()` (name, value) pairs

    returns: the value of the first cookie whose name starts with 'download_warning',
             or None when there is no such cookie
    """
    warning_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    #next() with a default gives the first match, or None if nothing matched
    return next(warning_values, None)

In [10]:
def save_response_content(response, destination):
    """
    Stream the body of a response into a file, one chunk at a time.

    response: an object offering `iter_content(chunk_size)`
    destination: the path of the file to (over)write in binary mode
    """
    chunk_bytes = 32768

    with open(destination, "wb") as out_file:
        #filter(None, ...) skips empty keep-alive chunks
        for piece in filter(None, response.iter_content(chunk_bytes)):
            out_file.write(piece)

In [11]:
def get_file_ids_from_google_drive():
    """
    List the files in the configured Google Drive folder and map their name prefixes to file ids.

    Only files whose name splits into exactly two parts on '-' are considered; the part
    before the '-' becomes the key.

    returns: a dictionary of {name prefix: Google Drive file id}
    """
    creds = None
    #NOTE(review): pickle.load can execute arbitrary code - only use a token.pickle you created yourself
    if os.path.exists('token.pickle'):
        with open('token.pickle', 'rb') as token:
            creds = pickle.load(token)

    service = build('drive', 'v3', credentials=creds)

    folder_name_to_file_id = {}

    #the listing is returned one page at a time; keep following the pages until
    #list_next reports there are none left, otherwise files beyond the first page are missed
    request = service.files().list(q="parents in '%s'"%GOOGLE_EARTH_ENGINE_GDRIVE_FOLDER_ID)
    while request is not None:
        result = request.execute()

        for info in result.get('files', []):
            name_parts = info['name'].split('-')
            if len(name_parts) == 2:
                folder_name_to_file_id[name_parts[0]] = info['id']

        request = service.files().list_next(request, result)

    return folder_name_to_file_id

In [12]:
def delete_file_from_google_drive_by_file_id(fid):
    """
    Delete one file from Google Drive.

    fid: the Google Drive id of the file to delete

    Credentials are read from 'token.pickle' in the working directory when that file
    exists; otherwise the Drive service is built with credentials=None.
    """
    token_path = 'token.pickle'

    credentials = None
    if os.path.exists(token_path):
        with open(token_path, 'rb') as token_file:
            credentials = pickle.load(token_file)

    drive_service = build('drive', 'v3', credentials=credentials)

    delete_request = drive_service.files().delete(fileId=fid)
    delete_request.execute()

# Data Collection Functions

In [13]:
def clip_raster_by_shapefile(source_raster, shapefile_path, save_path):
    """
    Given some raster, this function clips the raster to the shape stored in a shapefile
    and saves the result as a GeoTIFF.
    
    source_raster: path of the TIFF that you wish to crop
    shapefile_path: path of the SHP whose geometry is used as the cutline for the crop
    save_path: the eventual place to save the cropped TIFF
    
    raises: RuntimeError if the source raster cannot be opened or the clipped raster
            cannot be produced
    """
    source_ds = gdal.Open(source_raster, gdal.GA_ReadOnly)
    
    #unless GDAL exceptions are enabled, a failed open returns None instead of raising
    if source_ds is None:
        raise RuntimeError('Could not open source raster: %s'%source_raster)
    
    options = gdal.WarpOptions(format='GTiff', cutlineDSName=shapefile_path, cropToCutline=True)
    ds = gdal.Warp(save_path, source_ds, options=options)
    warp_succeeded = ds is not None
    
    #dropping the references closes the datasets and flushes the output to disk
    ds = None
    source_ds = None
    
    if not warp_succeeded:
        raise RuntimeError('Could not clip %s with %s into %s'%(source_raster, shapefile_path, save_path))

In [14]:
def generate_region_folders(minx, miny, maxx, maxy, path):
    """
    Split the given extent into a grid of equally sized boxes, each side smaller than
    MAX_REGION_SIZE, and save every box with GeoDataFrame.to_file.
    
    minx, miny, maxx, maxy: the extent of the region we wish to analyze
    path: the common prefix of the outputs; the box in column i and row j is saved
          to '<path>_<i>_<j>'
    
    The number of columns and rows and the box width and height are printed.
    """
    
    width = maxx - minx
    height = maxy - miny
    
    #number of boxes along each axis
    num_cols = int(width // MAX_REGION_SIZE + 1)
    num_rows = int(height // MAX_REGION_SIZE + 1)
    
    #side lengths of a single box
    box_width = width / num_cols
    box_height = height / num_rows
    
    print(num_cols, num_rows, box_width, box_height)
    
    for col in range(num_cols):
        left = minx+col*box_width
        right = minx+(col+1)*box_width
        
        for row in range(num_rows):
            bottom = miny+row*box_height
            top = miny+(row+1)*box_height
            
            region = gpd.GeoDataFrame(geometry=[box(left, bottom, right, top)], crs={'init':'epsg:4326'})
            region.to_file('%s_%s_%s'%(path, col, row))

In [15]:
def get_bands_from_region(folders_to_process, features, gdrive_folder, date_range, primary_dataset, wetland_code, selected_bands):
    """
    This function accepts the below parameters and, for every region folder, builds a Google Earth Engine
    export task that would store a data cube for that region in Google Drive. The tasks are only defined
    here; none of them is started.
    
    Regions where less than 1% of the clipped baseline raster equals wetland_code are skipped and their
    folder is deleted from disk.
    
    folders_to_process: the folders where to find the regions of interest
    features: a dictionary of features to include in the resulting data cubes; keys are
              (data type, data source) tuples with data type 'collection' or 'image', values are lists of band names
    gdrive_folder: the name of the folder on Google Drive to store the results
    date_range: the [start, end] date range used to filter image collections
    primary_dataset: the dataset (a collection) to use for the eventual image resolution
    wetland_code: the code for the sub-type of wetland we will analyze
    selected_bands: the names of the bands kept in the exported image
    
    returns: a list of (file name prefix, export task) pairs; if a collection has no images for a region,
             the pairs gathered so far are returned immediately
    raises: ValueError if a feature has an unknown data type or if the resolution of primary_dataset
            could not be determined
    """
    
    #this will store all defined (not yet started) tasks
    tasks = {}
    
    #resolution of the primary dataset, filled in while working through the features
    eventual_res = None
    
    #bands requested per data source; an empty list when a source is not part of `features`
    alos_bands = features.get(('collection','JAXA/ALOS/PALSAR/YEARLY/SAR'), [])
    s2_bands = features.get(('collection','COPERNICUS/S2_SR'), [])
    srtm_bands = features.get(('image','CGIAR/SRTM90_V4'), [])
    
    #the exported bands are always selected in sorted order
    selected_bands = sorted(selected_bands)
    
    #work through each sub-region 
    for region_folder in folders_to_process:
        
        filtered_imgs = []
        region_name = region_folder.split('/')[-1]
        region_shapefile = '%s/%s.shp'%(region_folder, region_name)
        
        print('Working on region folder: %s...'%region_name)
        
        print('Created Baseline Wetlands Raster...')
        #clip the baseline map of wetlands and store in sub-directory
        print(region_shapefile)
        baseline_file_name = '%s/baseline_%s.tiff'%(region_folder, region_name)
        clip_raster_by_shapefile(BASELINE_WETLAND_PATH, region_shapefile, baseline_file_name)
        ds = gdal.Open(baseline_file_name, gdal.GA_ReadOnly)
        arr = ds.ReadAsArray()
        ds = None
        
        #fraction of pixels in this region that carry the requested wetland code
        #NOTE(review): np is not imported in the visible import cell; presumably it comes from
        #`from common_functions import *` - confirm
        pct_wetland = np.mean(arr == wetland_code)
        print(pct_wetland)
        if pct_wetland < 0.01:
            print('Deleting Region Folder')
            print('==================================')
            shutil.rmtree(region_folder)
            continue
        
    
        #read the area of interest
        df = gpd.read_file(region_folder)

        #get the coordinates of that area
        area_coords = df.geometry[0].exterior.coords[:]
        area_coords = [list(pair) for pair in area_coords]

        #get the minx, miny, maxx, maxy
        x1 = min([item[0] for item in area_coords])
        y1 = max([item[1] for item in area_coords])

        x2 = max([item[0] for item in area_coords])
        y2 = min([item[1] for item in area_coords])

        #store the reference coordinates (top-left corner: smallest x, largest y)
        ref_coords = (x1,y1)

        #create an area of interest from Earth Engine Geometry
        area_of_interest = ee.Geometry.Polygon(coords=area_coords)

        #iterate over each data source
        for data_type_source, bands in features.items():
            data_type = data_type_source[0]
            data_source = data_type_source[1]
                
            print('Working on data source: %s...'%data_source)
            
            if data_type == 'collection':
                #access the Earth Engine image collection with the specified bands
                data = ee.ImageCollection(data_source).select(bands)

                #filter on region and date range
                data_filtered = data.filterBounds(area_of_interest).filterDate(date_range[0], date_range[1])

                #ensure there is at least 1 image
                num_items = data_filtered.size().getInfo()
                if num_items == 0:
                    print('no items found, returning started tasks.')
                    #same list-of-pairs shape as the normal return, so callers can slice it
                    return list(tasks.items())

                band_info = data_filtered.first().getInfo()['bands'][0]

                #if crs is already EPSG 4326, get resolution directly, otherwise need to transform from meters
                if band_info['crs'] == 'EPSG:4326':
                    res = band_info['crs_transform'][0]
                else:
                    res = band_info['crs_transform'][0] * METERS_TO_DECIMAL_DEGREES_CONST

                #if this is the eventual primary dataset, store its resolution
                if data_source == primary_dataset:
                    eventual_res = res

                #get a mosaic as median of all returned images
                mosaic = ee.Image(data_filtered.median())

                if data_source == 'COPERNICUS/S2_SR':

                    #calculate NDVI = (NIR - Red) / (NIR + Red); normalizedDifference computes
                    #(first - second) / (first + second), so NIR (B8) has to come first
                    if 'B4' in bands and 'B8' in bands:
                        ndvi = mosaic.normalizedDifference(['B8', 'B4']).rename('NDVI')
                        mosaic = ee.Image.addBands(mosaic, ndvi)

                    #calculate NDWI = (Green - NIR) / (Green + NIR)
                    if 'B3' in bands and 'B8' in bands:
                        ndwi = mosaic.normalizedDifference(['B3', 'B8']).rename('NDWI')
                        mosaic = ee.Image.addBands(mosaic, ndwi)
                        
            elif data_type == 'image':
                mosaic = ee.Image(data_source).select(bands)
                
            else:
                #without this, the mosaic of the previous data source would be added a second time
                raise ValueError('Unknown data type "%s" for data source "%s"'%(data_type, data_source))

            #add this mosaic to the list
            filtered_imgs.append(mosaic)
            
        
        #the export transform below cannot be built without the primary dataset's resolution
        if eventual_res is None:
            raise ValueError('Could not determine the resolution of primary dataset "%s"; it must be one of the collections in features'%primary_dataset)
        
        #generate file name
        features_str = '_'.join([item[1] for item in features.keys()]).replace('/','_')
        fname = '%s-%s'%(region_name, features_str)
        print(fname)
        
        #add the various layers on top of each other to create a data cube with all features
        final_img = ee.Image()
        
        for img in filtered_imgs:
            final_img = ee.Image.addBands(final_img,img)
        
        #use the ALOS qa band to filter out invalid pixels
        if 'qa' in alos_bands:
            qa_band = final_img.select('qa')
            qa_mask = qa_band.lt(51)
            final_img = final_img.where(qa_mask, NO_DATA_VALUE)
        
        #use the Sentinel-2 SCL band to filter out invalid pixels
        if 'SCL' in s2_bands:
            scl_band = final_img.select('SCL')
            scl_nodata_vals = [0,3,6,8,9,10]
            scl_mask = scl_band.eq(0)
            for v in scl_nodata_vals:
                scl_mask = scl_mask.Or(scl_band.eq(v))
            final_img = final_img.where(scl_mask, NO_DATA_VALUE)
            
        #use the SRTM elevation band to filter out invalid pixels
        if 'elevation' in srtm_bands:
            elevation_band = final_img.select('elevation')
            elevation_mask = elevation_band.gt(MAX_CONSIDERED_ELEVATION)
            final_img = final_img.where(elevation_mask, NO_DATA_VALUE)
            
        #if any of the selected bands has NO_DATA_VALUE, mark that whole pixel as NO_DATA_VALUE
        for b in selected_bands:
            b_values = final_img.select(b)
            b_mask = b_values.eq(NO_DATA_VALUE)
            final_img = final_img.where(b_mask, NO_DATA_VALUE)
         
        #store the result with just the needed bands
        result = final_img.select(*selected_bands).float()
          
        #define the task to gather the data
        task = ee.batch.Export.image.toDrive(image=result,
                                             region=area_of_interest.getInfo()['coordinates'],
                                             description=region_name,
                                             folder=gdrive_folder,
                                             fileNamePrefix=fname,
                                             crs_transform=[eventual_res, 0.0, ref_coords[0], 0.0, -eventual_res, ref_coords[1]],
                                             crs='EPSG:4326')
        
        #store the task
        tasks[fname] = task
        
        print('==================================')
    
    return list(tasks.items())

# Driver Code

In [16]:
driver = ogr.GetDriverByName('ESRI Shapefile')

dataSource = driver.Open(AREA_OF_INTEREST_FILE, gdal.GA_ReadOnly)

#unless OGR exceptions are enabled, a failed open returns None instead of raising
if dataSource is None:
    raise RuntimeError('Could not open area of interest file: %s'%AREA_OF_INTEREST_FILE)

layer = dataSource.GetLayer()

#the wetland type is the layer description without '/' and spaces, in lower case;
#it is written to wetland_type.txt in the working directory
wetland_type = layer.GetDescription().replace('/','').replace(' ','').lower()
with open('wetland_type.txt', 'w') as wetland_type_file:
    wetland_type_file.write(wetland_type)

#NOTE(review): feature_to_code is not defined in this notebook; presumably it comes from
#`from common_functions import *` - confirm
wetland_code = feature_to_code[wetland_type]

curr_feature = layer.GetNextFeature()
curr_subregion_idx = 0

#create the sub-region folders for every feature of the area of interest
while curr_feature is not None:
    #note the order of the envelope: both x values first, then both y values
    minx, maxx, miny, maxy = curr_feature.GetGeometryRef().GetEnvelope()
    generate_region_folders(minx, miny, maxx, maxy, '%s/region_%s'%(BASE_ROI_FOLDER, curr_subregion_idx))
    
    curr_feature = layer.GetNextFeature()
    curr_subregion_idx += 1

2 2 0.3342691523898793 0.32751624022038595
1 1 0.47347826836990237 0.339933115752757
1 3 0.3035117104935239 0.48157191398307003
2 2 0.45526756574030003 0.26102007102443725


In [17]:
#only entries of the base folder whose name contains this text are processed
search_area_name = 'region'
folders_to_process = [
    '%s/%s'%(BASE_ROI_FOLDER, entry)
    for entry in os.listdir(BASE_ROI_FOLDER)
    if search_area_name in entry
]

#(data type, data source) -> the bands requested from that source
features = {
    ('collection','JAXA/ALOS/PALSAR/YEARLY/SAR'): ['HH', 'HV', 'qa'],
    ('collection', 'COPERNICUS/S2_SR'): ['B2', 'B3', 'B4', 'B8', 'SCL'],
    ('image','CGIAR/SRTM90_V4'): ['elevation'],
}

#start and end date used to filter the image collections
date_range = ['2017-01-01', '2019-01-01']

#name of the Google Drive folder that receives the exports
gdrive_folder = 'GoogleEarthEngine'

In [18]:
#define one export task per remaining region; the Sentinel-2 collection sets the output resolution
tasks = get_bands_from_region(
    folders_to_process=folders_to_process,
    features=features,
    gdrive_folder=gdrive_folder,
    date_range=date_range,
    primary_dataset='COPERNICUS/S2_SR',
    wetland_code=wetland_code,
    selected_bands=SELECTED_BANDS,
)

Working on region folder: region_0_0_0...
Created Baseline Wetlands Raster...
regions/region_0_0_0/region_0_0_0.shp
0.10341269841269841
Working on data source: JAXA/ALOS/PALSAR/YEARLY/SAR...
Working on data source: COPERNICUS/S2_SR...
Working on data source: CGIAR/SRTM90_V4...
region_0_0_0-JAXA_ALOS_PALSAR_YEARLY_SAR_COPERNICUS_S2_SR_CGIAR_SRTM90_V4
Working on region folder: region_0_0_1...
Created Baseline Wetlands Raster...
regions/region_0_0_1/region_0_0_1.shp
0.1104421768707483
Working on data source: JAXA/ALOS/PALSAR/YEARLY/SAR...
Working on data source: COPERNICUS/S2_SR...
Working on data source: CGIAR/SRTM90_V4...
region_0_0_1-JAXA_ALOS_PALSAR_YEARLY_SAR_COPERNICUS_S2_SR_CGIAR_SRTM90_V4
Working on region folder: region_0_1_0...
Created Baseline Wetlands Raster...
regions/region_0_1_0/region_0_1_0.shp
0.05307256235827664
Working on data source: JAXA/ALOS/PALSAR/YEARLY/SAR...
Working on data source: COPERNICUS/S2_SR...
Working on data source: CGIAR/SRTM90_V4...
region_0_1_0-JAXA_A

In [19]:
batch_size = 12

#states after which a task does not change any more; waiting only for 'COMPLETED'
#would loop forever as soon as a single task fails or is cancelled
TERMINAL_STATES = {'COMPLETED', 'FAILED', 'CANCELLED'}

#process the tasks in small batches to avoid memory running out
for batch_idx in range(ceil(len(tasks) / batch_size)):
    
    #get the current batch of tasks
    curr_tasks = tasks[batch_size*batch_idx:batch_size*(batch_idx+1)]
    print('Processing Batch %s'%(batch_idx+1))

    #start all tasks in that batch
    for name,task in curr_tasks:
        task.start()
        
    print('Started all tasks in batch')
     
    #wait until all tasks in that batch have finished, wait 2 minutes between checks
    curr_states = [task.status()['state'] for name,task in curr_tasks]
    while not set(curr_states) <= TERMINAL_STATES:
        print('Current states: %s'%curr_states)
        sleep(120)
        curr_states = [task.status()['state'] for name,task in curr_tasks]
    
    #report the tasks that finished without completing; they leave no file to download
    for (name,task), state in zip(curr_tasks, curr_states):
        if state != 'COMPLETED':
            print('Task %s ended in state %s'%(name, state))
      
    #once all tasks done, get their file ids on google drive
    folder_name_to_file_id = get_file_ids_from_google_drive()
    
    #for each file...
    for roi_folder, fid in folder_name_to_file_id.items():
    
        #get feature file name
        features_file_name = '%s/%s/features_%s.tiff'%(BASE_ROI_FOLDER, roi_folder, roi_folder)
        
        #check if data already downloaded
        if not os.path.exists(features_file_name):
            print('Downloading %s from Drive'%roi_folder)
            download_file_from_google_drive(fid, features_file_name)
         
        print('Deleting %s from Drive'%roi_folder)
        delete_file_from_google_drive_by_file_id(fid)
    
    print('================================')

Processing Batch 1
Started all tasks in batch
Current states: ['RUNNING', 'RUNNING', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY']
Current states: ['RUNNING', 'RUNNING', 'RUNNING', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY']
Current states: ['RUNNING', 'RUNNING', 'RUNNING', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY']
Current states: ['COMPLETED', 'COMPLETED', 'COMPLETED', 'RUNNING', 'RUNNING', 'RUNNING', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY']
Current states: ['COMPLETED', 'COMPLETED', 'COMPLETED', 'RUNNING', 'RUNNING', 'RUNNING', 'READY', 'READY', 'READY', 'READY', 'READY', 'READY']
Current states: ['COMPLETED', 'COMPLETED', 'COMPLETED', 'COMPLETED', 'RUNNING', 'RUNNING', 'RUNNING', 'READY', 'READY', 'READY', 'READY', 'READY']
Current states: ['COMPLETED', 'COMPLETED', 'COMPLETED', 'COMPLETED', 'COMPLETED', 'COMPLETED', 'RUNNING', 'RUNNING', 'RUNNING', 'R