<a href="https://colab.research.google.com/github/puzhao89/Adversarial_Autoencoder/blob/master/SAR4Wildfire_EO_Datacubes_Processing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


## **D3.2 Prototype for analysis ready Sentinel-1/2 data cubes**

This script is written to query, process and export time series earth observation data, including Sentinel-1 C-Band SAR, Sentinel-2 and Landsat-8 multispectral data, with Earth Engine Python API.

0. Set up software libraries
1. Define Study Areas
2. EO Data Processing
3. Export EO Data

# Step 0: Import libraries

Authenticate and import as necessary.

In [0]:
# Cloud authentication.
from google.colab import auth
auth.authenticate_user()

In [0]:
# Import, authenticate and initialize the Earth Engine library.
import ee
ee.Authenticate()
ee.Initialize()

In [0]:
# Folium setup.
import folium
print(folium.__version__)

# Step 1: Define Study Areas

### 1.1 Wildfire Event

In [0]:

"""////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////"""
""" ================================================================= Study Areas ==================================================================== """
"""////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////"""

""" ==============================> Chuckegg Wildfire (2019) <==============================="""
# Region of interest for the Chuckegg event, given as a lon/lat rectangle.
CA_Chuckegg_roi = (ee.Geometry.Rectangle(
    [-118.15516161677635, 58.80344955293542,
      -116.58411669490135, 57.78735191577551]))
# Event descriptor: one dict per wildfire, selected later through `fireEvent`.
Chuckegg = {
    'name': 'Chuckegg',
    'roi': CA_Chuckegg_roi,
    'crs': 'EPSG:32611',

    # Query window (the trailing comments are alternative dates left by the author).
    'startDate': '2019-05-15',  # 2019-05-18
    'endDate': '2019-08-20',  # 2019-10-01

    # Multispectral acquisition dates selected for export.
    'msiExportDates': ['2019-05-21', '2019-05-24', '2019-05-26', '2019-05-28', 
                        '2019-06-10', '2019-06-17', '2019-06-25', '2019-07-11', '2019-07-15', 
                        '2019-07-17', '2019-07-20', '2019-08-06', '2019-08-12'],

    # MSI
    'S2_Master': '2019-05-16',  # good S2 prefire date
    # NOTE(review): this date lies after 'startDate'; confirm it is really pre-fire.
    'L8_Master': '2019-07-05',  # good L8 prefire date

    # SAR
    # NOTE(review): these master dates are in 2017 although the event is dated 2019,
    # and they match the values in the `elephant` dict -- presumably copied; confirm.
    'DSC13': '2017-05-17',
    'DSC115': '2017-05-12',  # prefire date
    'ASC64': '2017-06-26',
    'ASC20': '2017-05-12',
    # SAR-orbit
    'orbNumList': [20],
    'bandList': ['VV', 'VH']
}
""" ============================= Elephant Wildfire (2017) =============================== """
# Region of interest and reference burn polygons for the Elephant event.
elephant_roi = ee.Geometry.Rectangle([-121.7697, 50.6512, -120.7068, 51.5224])
elephant_refPoly = ee.FeatureCollection("users/omegazhangpzh/elephant_refPoly")

# Event descriptor: one dict per wildfire, selected later through `fireEvent`.
elephant = {
    'name': 'elephant',
    'roi': elephant_roi,
    'poly': elephant_refPoly,
    # Written without a space after the colon, like every other event's 'crs'.
    'crs': 'EPSG:32610',

    # Query window (the trailing comments are alternative dates left by the author).
    'startDate': '2017-07-05',  # 2017-07-07
    'endDate': '2017-10-06',  # 2017-10-01

    # MSI
    # NOTE(review): this date lies after 'endDate'; confirm it is really pre-fire.
    'S2_Master': '2017-11-06',  # good S2 prefire date
    'L8_Master': '2017-07-05',  # good L8 prefire date

    # SAR master dates, keyed by pass direction + relative orbit number.
    'DSC13': '2017-05-17',
    'DSC115': '2017-05-12',  # prefire date
    'ASC64': '2017-06-26',

    # SAR-orbit
    'orbitList': ['ASC64'],  # ['DSC13', 'DSC115', 'ASC64']
    'bandList': ['VV', 'VH'],
}

""" ============================= Sydney Wildfire (2019) ======================================= """
# Region of interest for the Sydney event, given as a lon/lat rectangle.
Sydney_roi = (ee.Geometry.Rectangle(
    [149.62892366273888, -34.46629366617598,
      151.60646272523888, -32.24350555003469]))

# Multispectral acquisition dates selected for export (first entry is the master date).
Sydney_msiExportDateList = ['2019-10-22',  # master
      '2019-10-27', '2019-10-28', '2019-11-01', '2019-11-06', '2019-11-11', '2019-11-13', '2019-11-21',
      '2019-12-16', '2019-12-21', '2019-12-26', '2019-12-31', '2020-01-05', '2020-01-10',
      '2019-12-11',  # a bit cloudy, but still good to use
      # '2019-12-15', # not good
      ]

# SAR acquisition dates selected for export.
Sydney_sarExportDateList = ['2019-10-28', '2019-11-06', '2019-11-09', '2019-11-18', '2019-11-21', '2019-11-27',
                      '2019-11-30', '2019-12-12', '2019-12-15', '2019-12-24', '2019-12-27', '2020-01-05',
                      '2020-01-08']

# Event descriptor: one dict per wildfire, selected later through `fireEvent`.
Sydney_AU = {
    'name': 'Sydney_small',
    'roi': Sydney_roi,  # Sydney_roi_1111,
    'poly': Sydney_roi,
    'crs': 'EPSG:3577',  # 'EPSG:3577', # EPSG:3577
    # Presence of this key switches on the point-filter branch of the query script.
    'pntsRect': (ee.Geometry.Rectangle([150.0941018581052, -34.12241279581038,
                                        151.1927346706052, -32.894960055302946])),  # AU

    # Query window (the trailing comments are alternative dates left by the author).
    'startDate': '2019-10-22',  # 2018-11-08
    'endDate': '2020-01-11',  # 2018-12-10
    'msiExportDates': Sydney_msiExportDateList,
    'sarExportDates': Sydney_sarExportDateList,

    'S2_Master': '2019-10-22',
    # 'L8_Master': '2018-11-06',

    # SAR master dates, keyed by pass direction + relative orbit number.
    'ASC9': '2019-10-16',
    'DSC147': '2019-10-25',  # prefire date
    # 'DSC42': '2018-11-04',
    # 'ASC137': '2018-11-05',
    #
    # 'orbitList': ['ASC35', 'DSC115', 'DSC42', 'ASC137'],
    'orbNumList': [9, 147],
    'bandList': ['VV', 'VH']
}

### Karbole Wildfire 2018, Sweden
# Region of interest for the Karbole event, given as a lon/lat rectangle.
Karbole_roi = ee.Geometry.Rectangle([
    15.137434283688016, 61.86566784664094,
    15.604353229000516, 62.06961520427164])

# Multispectral acquisition dates selected for export.
Karbole_msiExportDateList = ['2018-07-16', '2018-07-17', '2018-07-19', '2018-07-24', '2018-07-26',
                             '2018-07-27', '2018-07-31', '2018-08-03', '2018-08-08']

# Event descriptor: one dict per wildfire, selected later through `fireEvent`.
Karbole_SE = {
    'name': 'Karbole_SE',
    'roi': Karbole_roi,
    'crs': 'EPSG:3006',

    # Query window (the trailing comments are alternative dates left by the author).
    'startDate': '2018-07-01',  # 2019-05-18
    'endDate': '2018-08-10',  # 2019-10-01
    # The query script looks up 'msiExportDates' (as the other events spell it);
    # the original 'msExportDates' spelling is kept as an alias for any existing reader.
    'msiExportDates': Karbole_msiExportDateList,
    'msExportDates': Karbole_msiExportDateList,

    # MSI Master Dates
    'S2_Master': '2018-07-04',  # good S2 prefire date
    # 'L8_Master': '2017-07-05',  # good L8 prefire date

    # SAR Master Dates
    # 'ASC102': '',
    'DSC66': '2018-07-09',
    'DSC168': '2018-07-10',
    'DSC95': '2018-07-11',  # '2018-07-05',

    # SAR-orbit
    'orbNumList': [66, 168, 95],
    'bandList': ['VV', 'VH'],
}


### 1.2 Set Global Variable 

In [0]:
# """////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////"""
# """ =============================================================== Configuration ==================================================================== """
# """////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////"""

# Event descriptor that the rest of the notebook processes.
fireEvent = Sydney_AU
CRS = "EPSG:4326" # fireEvent['crs'] #"EPSG:4326"  # fireEvent['epsg']#Amazon #"EPSG:3995"# Russia  # ""EPSG:32610"

""" Configuration """
scale = 20
cloudLevel = 100 # used for setting cloud level
# Number of leading characters of the formatted image date used as the grouping key.
groupLevel = 10  # 2017-04-28T12:12:35(len: 19), 2017-04-28T12(len: 13)
# Number of leading characters of the formatted image date used to build IMG_LABEL.
labelShowLevel = 10  # IMG_LABEL: if 10 then 2017-04-28 ==> 20170428, if 13 then 20170428T12

# Master switches; the SAR/MSI flags in the query script are AND-ed with these.
sarQueryFlag = True # query
msiQueryFlag = True



# Step 2:  EO Datacubes Processing



### **2.1** Datacube Processing Functions

In [0]:
""" ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// """
""" /////////////////////////////////////////     Functions for EE Preprocessing     //////////////////////////////////////// """
""" ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// """

import ee, time
ee.Initialize()

def set_timeEnd_newdays(img):
    """Overwrite ``system:time_end`` with the first 10 characters of the image's formatted date.

    Args:
        img: Image exposing ``date()`` and ``set()`` (an ``ee.Image`` in this notebook).

    Returns:
        The result of ``img.set('system:time_end', ...)``.
    """
    day_label = img.date().format().slice(0, 10)
    tagged = img.set('system:time_end', day_label)
    return tagged


""" SAR Group Days """
def set_group_index_4_S1(img):
    """Attach grouping metadata to a Sentinel-1 image.

    Sets three properties:
      * ``GROUP_INDEX`` - formatted acquisition date cut to ``groupLevel`` characters.
      * ``Orbit_Key``   - pass direction abbreviated to ASC/DSC followed by the
        integer relative orbit number.
      * ``IMG_LABEL``   - date cut to ``labelShowLevel`` characters with hyphens
        removed, an underscore, then the orbit key.

    Reads the module-level ``groupLevel`` and ``labelShowLevel`` settings.
    """
    stamp = img.date().format()

    pass_abbrev = (ee.String(img.get("orbitProperties_pass"))
                   .replace('DESCENDING', 'DSC')
                   .replace('ASCENDING', 'ASC'))
    orbit_number = ee.Number(img.get("relativeOrbitNumber_start")).int().format()
    orbitKey = pass_abbrev.cat(orbit_number)

    # replace() is applied twice, presumably once per hyphen in YYYY-MM-DD.
    compact_date = stamp.slice(0, labelShowLevel).replace('-', '').replace('-', '')

    properties = {
        'GROUP_INDEX': stamp.slice(0, groupLevel),  # 2017 - 07 - 23T14:11:22(len: 19)
        'IMG_LABEL': compact_date.cat('_').cat(orbitKey),
        'Orbit_Key': orbitKey,
    }
    return img.setMulti(properties)


# "group by" date
def group_S1_by_date_orbit(imgcollection):
    """Mosaic Sentinel-1 images that share date group, orbit and polarisation.

    Every image is first tagged by ``set_group_index_4_S1``. One representative
    image is kept per (``GROUP_INDEX``, ``Orbit_Key``) pair, and all images that
    match it on ``GROUP_INDEX``, ``Orbit_Key`` and
    ``transmitterReceiverPolarisation`` are joined to it and mosaicked.

    Args:
        imgcollection: ``ee.ImageCollection`` of Sentinel-1 scenes.

    Returns:
        ``ee.ImageCollection`` with one mosaic per group. Each mosaic carries the
        representative's properties and a ``system:footprint`` equal to the
        union of the member geometries.
    """
    imgCol = imgcollection.sort("system:time_start").map(set_group_index_4_S1)

    # The join below also requires equal Orbit_Key, so the representatives must
    # be distinct per date AND orbit. De-duplicating on the date alone would
    # silently drop every additional orbit acquired on the same date.
    representatives = ee.ImageCollection(
        imgCol.distinct(['GROUP_INDEX', 'Orbit_Key']))

    same_group_filter = ee.Filter.And(
        ee.Filter.equals(leftField='GROUP_INDEX', rightField='GROUP_INDEX'),
        ee.Filter.equals(leftField='Orbit_Key', rightField='Orbit_Key'),
        ee.Filter.equals(leftField='transmitterReceiverPolarisation',
                         rightField='transmitterReceiverPolarisation'),
    )

    joined = ee.ImageCollection(
        ee.Join.saveAll("to_mosaic").apply(representatives, imgCol, same_group_filter))

    def mosaicImageBydate(img):
        """Mosaic the images stored under 'to_mosaic' and restore a footprint."""
        members = ee.ImageCollection.fromImages(img.get('to_mosaic'))
        firstImgGeom = members.first().geometry()
        mosaicGeom = ee.Geometry(members.iterate(unionGeomFun, firstImgGeom))
        mosaiced = members.mosaic().copyProperties(img, img.propertyNames())
        # The footprint is set explicitly from the unioned member geometries
        # (the original author marked it as "lost" after mosaicking).
        return ee.Image(mosaiced).set("system:footprint", mosaicGeom)

    grouped = joined.map(mosaicImageBydate)
    return ee.ImageCollection(grouped.copyProperties(imgCol, imgCol.propertyNames()))


def add_RBR(img):
    """Append an ``RBR`` band computed by the expression ``b("VH")-b("VV")``.

    The source image's properties are copied onto the result.
    """
    rbr_band = img.expression('b("VH")-b("VV")').rename("RBR")  # (b("VV")-b("VH"))/(b("VV")+b("VH"))
    with_rbr = img.addBands(rbr_band)
    return ee.Image(with_rbr.copyProperties(img, img.propertyNames()))


def printList(inList, markString=None):
    """Print a banner with ``markString`` and the item count, each item, then a footer.

    Args:
        inList: Sized iterable whose elements are printed one per line.
        markString: Label shown in the banner line.
    """
    header = "---------{}: {}----------".format(markString, len(inList))
    footer = "---------------------\n"
    print(header)
    for item in inList:
        print(item)
    print(footer)


""" Preprocessing Funs for MS Images """


def S2_bandRename(img):
    """Select seven Sentinel-2 bands and rename them to the shared band names.

    B2, B3, B4, B8, B11, B12, QA60 become B, G, R, NIR, SWIR1, SWIR2, cloud.
    The source image's properties are copied onto the result.
    """
    source_bands = ['B2', 'B3', 'B4', 'B8', 'B11', 'B12', 'QA60']
    target_bands = ['B', 'G', 'R', 'NIR', 'SWIR1', 'SWIR2', 'cloud']
    renamed = img.select(source_bands).rename(target_bands)
    return renamed.copyProperties(img, img.propertyNames())


def L8_bandRename(img):
    """Select seven Landsat-8 bands and rename them to the shared band names.

    B2, B3, B4, B5, B6, B7, BQA become B, G, R, NIR, SWIR1, SWIR2, cloud.
    The source image's properties are copied onto the result.
    """
    source_bands = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'BQA']
    target_bands = ['B', 'G', 'R', 'NIR', 'SWIR1', 'SWIR2', 'cloud']
    renamed = img.select(source_bands).rename(target_bands)
    return renamed.copyProperties(img, img.propertyNames())


def add_NDVI(img):
    """Append an ``NDVI`` band: the normalized difference of ``NIR`` and ``R``.

    The source image's properties are copied onto the result.
    """
    nd = img.normalizedDifference(['NIR', 'R'])
    ndvi_band = nd.select('nd').rename('NDVI')
    return img.addBands(ndvi_band).copyProperties(img, img.propertyNames())


def add_NBR(img):
    """Append ``NBR`` (NIR vs SWIR2) and ``NBR1`` (SWIR1 vs SWIR2) bands.

    Both are normalized differences. The source image's properties are copied
    onto the result.
    """
    def _normalized(first, second, out_name):
        # normalizedDifference names its output 'nd'; rename it to out_name.
        return img.normalizedDifference([first, second]).select('nd').rename(out_name)

    nbr_band = _normalized('NIR', 'SWIR2', 'NBR')
    nbr1_band = _normalized('SWIR1', 'SWIR2', 'NBR1')
    with_indices = img.addBands(nbr_band).addBands(nbr1_band)
    return with_indices.copyProperties(img, img.propertyNames())


def updateCloudMaskL8(img):
    """Turn the ``cloud`` (BQA) band into a clear-sky flag using bit 4.

    The flag is 1 where bit 4 of the QA value is unset and 0 where it is set
    (0 for cloud, 1 for clear). It is added back with ``overwrite=True``.
    """
    cloud_bit = 1 << 4
    clear_flag = img.select('cloud').bitwiseAnd(cloud_bit).eq(0)
    return img.addBands(clear_flag, overwrite=True)

def updateCloudMaskS2(img):
    """Turn the ``cloud`` (QA60) band into a clear-sky flag using bits 10 and 11.

    The flag is 1 only where both the cloud bit (10) and the cirrus bit (11)
    are unset. It is added back with ``overwrite=True``.
    """
    qa = img.select('cloud')  # QA60
    no_cloud = qa.bitwiseAnd(1 << 10).eq(0)
    no_cirrus = qa.bitwiseAnd(1 << 11).eq(0)
    clear_flag = no_cloud.And(no_cirrus)
    return img.addBands(clear_flag, overwrite=True)

def mask_L8_clouds(img):
    """Mask ``img`` with the clear-sky flag produced by ``updateCloudMaskL8``."""
    clear_flag = updateCloudMaskL8(img).select('cloud')
    return img.updateMask(clear_flag)

def mask_S2_clouds(img):
    """Mask ``img`` with the clear-sky flag produced by ``updateCloudMaskS2``."""
    clear_flag = updateCloudMaskS2(img).select('cloud')
    return img.updateMask(clear_flag)


# def maskClouds(image):
#     score = image.select('cloud')  # S2: QA60
#     mask = score.lt(1)
#     return image.updateMask(mask).copyProperties(image, image.propertyNames())
#
#
# def L8_cloud_div10(img):
#     return img.addBands(srcImg=ee.Image(img).select('cloud').divide(10),
#                         overwrite=True)


""" ====================== Sentinel-2 ======================"""


def set_S2_group_index(img):
    """Attach grouping metadata to a Sentinel-2 image.

    Sets ``GROUP_INDEX`` (formatted date cut to ``groupLevel`` characters),
    ``SAT_NAME`` (``'S2'``) and ``IMG_LABEL`` (date cut to ``labelShowLevel``
    characters with hyphens removed, followed by ``_S2``).
    """
    stamp = img.date().format()
    # replace() is applied twice, presumably once per hyphen in YYYY-MM-DD.
    compact_date = stamp.slice(0, labelShowLevel).replace('-', '').replace('-', '')

    properties = {
        'GROUP_INDEX': stamp.slice(0, groupLevel),  # 2017 - 07 - 23 T14:11:22(len: 19)
        'SAT_NAME': 'S2',
        'IMG_LABEL': compact_date.cat("_S2"),
    }
    return img.setMulti(properties)


def unionGeomFun(img, first):
    """Accumulator for ``iterate``: union the running geometry with ``img``'s geometry.

    Args:
        img: Current image of the iteration.
        first: Geometry accumulated so far.
    """
    accumulated = ee.Geometry(first)
    current = ee.Geometry(img.geometry())
    return accumulated.union(current)


# "group by date
def group_S2_ImgCol(imgcollection, multiSensorGroupFlag=False):
    """Mosaic Sentinel-2 images that fall into the same date group.

    Every image is first tagged by ``set_S2_group_index``. One representative
    image per ``GROUP_INDEX`` is joined with all images of the same group and
    the matches are mosaicked.

    Args:
        imgcollection: ``ee.ImageCollection`` to group.
        multiSensorGroupFlag: When true, images are matched on ``GROUP_INDEX``
            only. Otherwise ``SAT_NAME`` must match as well.

    Returns:
        ``ee.ImageCollection`` with one mosaic per group. Each mosaic carries the
        representative's properties and a ``system:footprint`` equal to the
        union of the member geometries.
    """
    imgCol = imgcollection.sort("system:time_start").map(set_S2_group_index)
    representatives = ee.ImageCollection(imgCol.distinct(['GROUP_INDEX']))

    # Join the collection to itself, grouped by date (and sensor unless disabled).
    same_date = ee.Filter.equals(leftField='GROUP_INDEX', rightField='GROUP_INDEX')
    if multiSensorGroupFlag:
        date_eq_filter = same_date
    else:
        date_eq_filter = ee.Filter.And(
            same_date,
            ee.Filter.equals(leftField='SAT_NAME', rightField='SAT_NAME'))

    joined = ee.ImageCollection(
        ee.Join.saveAll("to_mosaic").apply(representatives, imgCol, date_eq_filter))

    def mosaicImageBydate(img):
        """Mosaic the images stored under 'to_mosaic' and restore a footprint."""
        members = ee.ImageCollection.fromImages(img.get('to_mosaic'))
        firstImgGeom = members.first().geometry()
        mosaicGeom = ee.Geometry(members.iterate(unionGeomFun, firstImgGeom))
        mosaiced = members.mosaic().copyProperties(img, img.propertyNames())
        # The footprint is set explicitly from the unioned member geometries
        # (the original author marked it as "lost" after mosaicking).
        return ee.Image(mosaiced).set("system:footprint", mosaicGeom)

    grouped = joined.map(mosaicImageBydate)
    return ee.ImageCollection(grouped.copyProperties(imgCol, imgCol.propertyNames()))


"""// // / == == == == == == == == Grouping Landsat-8 ImageCollections == == == == == == == == =="""


def set_L8_group_index(img):
    """Attach grouping metadata to a Landsat-8 image.

    Sets ``GROUP_INDEX`` (formatted date cut to ``groupLevel`` characters),
    ``SAT_NAME`` (``'L8'``) and ``IMG_LABEL`` (date cut to ``labelShowLevel``
    characters with hyphens removed, followed by ``_L8``).
    """
    stamp = img.date().format()
    # replace() is applied twice, presumably once per hyphen in YYYY-MM-DD.
    compact_date = stamp.slice(0, labelShowLevel).replace('-', '').replace('-', '')

    properties = {
        'GROUP_INDEX': stamp.slice(0, groupLevel),  # 2017 - 07 - 23T14:11:22(len: 19)
        'SAT_NAME': 'L8',
        'IMG_LABEL': compact_date.cat("_L8"),
    }
    return img.setMulti(properties)


# "group by date
def group_L8_ImgCol(imgcollection, multiSensorGroupFlag=False):
    """Mosaic Landsat-8 images that fall into the same date group.

    Every image is first tagged by ``set_L8_group_index``. One representative
    image per ``GROUP_INDEX`` is joined with all images of the same group and
    the matches are mosaicked.

    Args:
        imgcollection: ``ee.ImageCollection`` to group.
        multiSensorGroupFlag: When true, images are matched on ``GROUP_INDEX``
            only. Otherwise ``SAT_NAME`` must match as well.

    Returns:
        ``ee.ImageCollection`` with one mosaic per group. Each mosaic carries the
        representative's properties and a ``system:footprint`` equal to the
        union of the member geometries.
    """
    imgCol = imgcollection.sort("system:time_start").map(set_L8_group_index)
    representatives = ee.ImageCollection(imgCol.distinct(['GROUP_INDEX']))

    # Join the collection to itself, grouped by date (and sensor unless disabled).
    same_date = ee.Filter.equals(leftField='GROUP_INDEX', rightField='GROUP_INDEX')
    if multiSensorGroupFlag:
        date_eq_filter = same_date
    else:
        date_eq_filter = ee.Filter.And(
            same_date,
            ee.Filter.equals(leftField='SAT_NAME', rightField='SAT_NAME'))

    joined = ee.ImageCollection(
        ee.Join.saveAll("to_mosaic").apply(representatives, imgCol, date_eq_filter))

    def mosaicImageBydate(img):
        """Mosaic the images stored under 'to_mosaic' and restore a footprint."""
        members = ee.ImageCollection.fromImages(img.get('to_mosaic'))
        firstImgGeom = members.first().geometry()
        mosaicGeom = ee.Geometry(members.iterate(unionGeomFun, firstImgGeom))
        mosaiced = members.mosaic().copyProperties(img, img.propertyNames())
        # The footprint is set explicitly from the unioned member geometries
        # (the original author marked it as "lost" after mosaicking).
        return ee.Image(mosaiced).set("system:footprint", mosaicGeom)

    grouped = joined.map(mosaicImageBydate)
    return ee.ImageCollection(grouped.copyProperties(imgCol, imgCol.propertyNames()))

""" ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// """
""" /////////////////////////////////////////     Functions for Exporting Data from EE     ////////////////////////////////// """
""" ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// """

def check_export_task(task, imgName):
    """Block until an export task stops being active, then report the outcome.

    Polls ``task.active()`` every 30 seconds. Once the task is inactive, prints
    a completion message if its state is ``'COMPLETED'`` and an error message
    otherwise.

    Args:
        task: Object exposing ``active()`` and ``status()``.
        imgName: Label inserted into the printed message.
    """
    print('Running export from GEE to drive or cloudStorage...')

    # Wait for the task to finish.
    while True:
        if not task.active():
            break
        time.sleep(30)

    succeeded = task.status()['state'] == 'COMPLETED'
    template = "Export completed: {}" if succeeded else "Error with export: {}"
    print(template.format(imgName))

def batch_imgCol_export_to_drive(imgCol, N=100, KERNEL_SIZE=256, bands4download=[''], toDriveFolder="Sydney_tfRecord"):
    """Export randomly sampled patches of every image to Google Drive as TFRecord.

    For each image, a ``KERNEL_SIZE`` x ``KERNEL_SIZE`` neighbourhood array is
    sampled inside the module-level ``roi`` in 10 seeded rounds of
    ``int(N / 10)`` pixels each. The merged samples are exported as one table,
    and the function blocks until each export finishes.

    Args:
        imgCol: ``ee.ImageCollection`` whose images carry an ``IMG_LABEL`` property.
        N: Total number of sample pixels requested per image.
        KERNEL_SIZE: Side length of the square patch kernel.
        bands4download: Property names selected from the sampled features.
            The default is never mutated here.
        toDriveFolder: Drive folder passed to the export.
    """
    print("------------------------- Download Start ... ----------------------------------")
    downLoadList = (imgCol).aggregate_array('IMG_LABEL').getInfo()
    printList(downLoadList, 'downloadList: ')

    n = 10  # number of seeded sampling rounds per image
    # All-ones square kernel. Local names avoid shadowing the builtin `list`.
    ones_row = ee.List.repeat(1, KERNEL_SIZE)
    ones_grid = ee.List.repeat(ones_row, KERNEL_SIZE)
    kernel = ee.Kernel.fixed(KERNEL_SIZE, KERNEL_SIZE, ones_grid)

    num = imgCol.size().getInfo()
    imgColList = imgCol.toList(num)

    # Integer pixel count per round, consistent with the class-balanced exporter.
    pixels_per_round = int(N / n)

    for i in range(num):
        img = ee.Image(imgColList.get(i))
        imgName = img.get('IMG_LABEL').getInfo().replace(":", '')

        arrays = img.neighborhoodToArray(kernel)
        featCol = ee.FeatureCollection([])
        for SEED in range(n):
            tmpfeatCol = arrays.sample(
                region=roi,
                scale=20,
                numPixels=pixels_per_round,
                seed=SEED,
                tileScale=8).select(bands4download)
            featCol = featCol.merge(tmpfeatCol)

        # Evaluate the sample count once: each getInfo() is a server round trip.
        exportName = "{}_ks{}_N{}".format(imgName, KERNEL_SIZE, featCol.size().getInfo())
        task = ee.batch.Export.table.toDrive(
            collection=featCol,
            description=exportName,
            folder=toDriveFolder,
            fileNamePrefix=exportName,
            fileFormat="TFRecord")

        task.start()
        check_export_task(task, imgName)

def batch_imgCol_export_to_cloud(imgCol, N=100, KERNEL_SIZE=256, toDriveFolder="Sydney_NRT_roi"):
    """Export randomly sampled patches of every image to Cloud Storage as TFRecord.

    For each image, a ``KERNEL_SIZE`` x ``KERNEL_SIZE`` neighbourhood array is
    sampled inside the module-level ``roi``. The samples are exported to the
    module-level ``BUCKET`` with the module-level ``BANDS`` as selectors, and
    the function blocks until each export finishes.

    Args:
        imgCol: ``ee.ImageCollection`` whose images carry an ``IMG_LABEL`` property.
        N: Total number of sample pixels requested per image.
        KERNEL_SIZE: Side length of the square patch kernel.
        toDriveFolder: Top-level path inside the bucket. Objects are written
            under ``<toDriveFolder>/<desc>/<desc>``.
    """
    print("------------------------- Download Start ... ----------------------------------")
    downLoadList = (imgCol).aggregate_array('IMG_LABEL').getInfo()
    printList(downLoadList, 'downloadList: ')

    n = 1  # number of seeded sampling rounds per image
    # All-ones square kernel. Local names avoid shadowing the builtin `list`.
    ones_row = ee.List.repeat(1, KERNEL_SIZE)
    ones_grid = ee.List.repeat(ones_row, KERNEL_SIZE)
    kernel = ee.Kernel.fixed(KERNEL_SIZE, KERNEL_SIZE, ones_grid)

    num = imgCol.size().getInfo()
    imgColList = imgCol.toList(num)

    # Integer pixel count per round, consistent with the class-balanced exporter.
    pixels_per_round = int(N / n)

    for i in range(num):
        img = ee.Image(imgColList.get(i))
        imgName = img.get('IMG_LABEL').getInfo().replace(":", '')

        arrays = img.neighborhoodToArray(kernel)
        featCol = ee.FeatureCollection([])
        for SEED in range(n):
            tmpfeatCol = arrays.sample(
                region=roi,
                scale=20,
                numPixels=pixels_per_round,
                seed=SEED,
                tileScale=8)

            featCol = featCol.merge(tmpfeatCol)

        # A trailing comma here previously turned `desc` into a 1-tuple, so the
        # string concatenation that built the destination path raised TypeError.
        desc = "{}_ks{}_N{}".format(imgName, KERNEL_SIZE, featCol.size().getInfo())
        task = ee.batch.Export.table.toCloudStorage(
            collection=featCol,
            description=desc,
            bucket=BUCKET,
            # The documented toCloudStorage signature has no `folder` argument,
            # so the intended <folder>/<desc>/<desc> layout goes in the prefix.
            fileNamePrefix="{}/{}/{}".format(toDriveFolder, desc, desc),
            fileFormat="TFRecord",
            selectors=BANDS
        )

        task.start()
        check_export_task(task, imgName)

    
def buildGeometryPoint(pnt):
    """Give a feature a point geometry built from its ``longitude``/``latitude`` properties."""
    lon = pnt.get("longitude")
    lat = pnt.get("latitude")
    location = ee.Geometry.Point([lon, lat])
    return pnt.setGeometry(location)

def batch_imgCol_export_to_cloud_classBalanced(imgCol, classBand, N=100, KERNEL_SIZE=256, toDriveFolder="Sydney_NRT"):
    """Export class-stratified patch samples of every image as TFRecord tables.

    For each image, 30 seeded rounds of stratified sampling on ``classBand``
    (``int(N/30)`` points per class and round) pick point locations inside the
    module-level ``roi``. The ``KERNEL_SIZE`` x ``KERNEL_SIZE`` neighbourhood
    array is then sampled at those points. The merged table is exported with
    the module-level ``FEATURES`` as selectors, and the function blocks until
    each export finishes.

    NOTE(review): despite the name, the export call is ``Export.table.toDrive``
    (the ``bucket`` argument is commented out).

    Args:
        imgCol: ``ee.ImageCollection`` whose images carry an ``IMG_LABEL`` property.
        classBand: Band used as the stratification class.
        N: Total number of points per class requested across all rounds.
        KERNEL_SIZE: Side length of the square patch kernel.
        toDriveFolder: Drive folder passed to the export.
    """
    print("------------------------- Download Start ... ----------------------------------")
    labels = imgCol.aggregate_array('IMG_LABEL').getInfo()
    printList(labels, 'to_cloud_classBalanced')

    n = 30  # number of seeded sampling rounds per image
    ones_row = ee.List.repeat(1, KERNEL_SIZE)
    ones_grid = ee.List.repeat(ones_row, KERNEL_SIZE)
    kernel = ee.Kernel.fixed(KERNEL_SIZE, KERNEL_SIZE, ones_grid)

    total = imgCol.size().getInfo()
    images = imgCol.toList(total)

    for idx in range(total):
        img = ee.Image(images.get(idx))
        imgName = img.get('IMG_LABEL').getInfo().replace(":", '')

        arrays = img.neighborhoodToArray(kernel)
        featCol = ee.FeatureCollection([])
        for SEED in range(n):
            # Class band plus lon/lat bands, so each sample records its location.
            class_with_coords = (img.select(classBand).rename('class').toInt()
                                 .addBands(ee.Image.pixelLonLat()))
            stratified = class_with_coords.stratifiedSample(
                region=roi,
                numPoints=int(N/n),
                classBand='class',
                scale=20,
                seed=SEED,
                dropNulls=True,
                tileScale=8
            )
            pntCol = stratified.map(buildGeometryPoint)

            tmpfeatCol = arrays.sampleRegions(
                collection=pntCol,
                scale=20,
                tileScale=8
            )

            print("SEED {}: {}".format(SEED, tmpfeatCol.size().getInfo()))
            featCol = featCol.merge(tmpfeatCol)

        desc = "trainPatches_{}_ks{}_N{}".format(imgName, KERNEL_SIZE, featCol.size().getInfo())
        print("desc: {}".format(desc))
        export_args = dict(
            collection=featCol,
            description=desc,
            folder=toDriveFolder,
            fileNamePrefix=desc,
            fileFormat="TFRecord",
            selectors=FEATURES,
        )
        task = ee.batch.Export.table.toDrive(**export_args)

        task.start()
        check_export_task(task, desc)

def batch_imgCol_export_to_drive_by_pntCol(imgCol, pntCol, KERNEL_SIZE=256, bands4download=[''], toDriveFolder=''):
    """Export patches of every image, sampled at given points, to Drive as TFRecord.

    For each image, the selected bands are converted to
    ``KERNEL_SIZE`` x ``KERNEL_SIZE`` neighbourhood arrays and sampled at the
    features of ``pntCol``. The function blocks until each export finishes.

    Args:
        imgCol: ``ee.ImageCollection`` whose images carry an ``IMG_LABEL`` property.
        pntCol: Feature collection giving the sample locations.
        KERNEL_SIZE: Side length of the square patch kernel.
        bands4download: Bands selected from each image before sampling.
            The default is never mutated here.
        toDriveFolder: Prefix of the Drive folder; ``_by_pntCol`` is appended.
    """
    print("------------------------- Download Start ... ----------------------------------")
    downLoadList = (imgCol).aggregate_array('IMG_LABEL').getInfo()
    printList(downLoadList, 'downloadList')

    # All-ones square kernel. Local names avoid shadowing the builtin `list`.
    ones_row = ee.List.repeat(1, KERNEL_SIZE)
    ones_grid = ee.List.repeat(ones_row, KERNEL_SIZE)
    kernel = ee.Kernel.fixed(KERNEL_SIZE, KERNEL_SIZE, ones_grid)

    num = imgCol.size().getInfo()
    imgColList = imgCol.toList(num)

    for i in range(num):
        img = ee.Image(imgColList.get(i))
        imgName = img.get('IMG_LABEL').getInfo().replace(":", '')

        arrays = img.select(bands4download).neighborhoodToArray(kernel)
        featCol = arrays.sampleRegions(
            collection=pntCol,
            scale=20,
            tileScale=8
        )

        # Evaluate the sample count once: each getInfo() is a server round trip.
        exportName = "{}_ks{}_N{}_by_pntCol".format(imgName, KERNEL_SIZE, featCol.size().getInfo())
        task = ee.batch.Export.table.toDrive(
            collection=featCol,
            description=exportName,
            folder="{}_by_pntCol".format(toDriveFolder),
            fileNamePrefix=exportName,
            fileFormat="TFRecord")

        task.start()
        check_export_task(task, imgName)


def export_img_time_series_by_AOI(imgCol, pntCol, KERNEL_SIZE=256, BANDS=[], toFolder='', fileFormat='TFRecord'):
    """Export one table row per sample point, holding that point's patches for every image.

    Each image in ``imgCol`` is sampled at the features of ``pntCol`` as
    ``KERNEL_SIZE`` x ``KERNEL_SIZE`` neighbourhood arrays. The samples are then
    regrouped so that each row corresponds to one point (``aoi``) and carries
    one property per ``IMG_LABEL``. The table is exported to Drive and the
    function blocks until the task finishes.

    Args:
        imgCol: ``ee.ImageCollection`` whose images carry an ``IMG_LABEL`` property.
        pntCol: Feature collection giving the sample locations.
        KERNEL_SIZE: Side length of the square patch kernel.
        BANDS: Bands selected from each image before sampling.
        toFolder: Prefix of the Drive folder; ``_by_pntCol`` is appended.
        fileFormat: NOTE(review): currently ignored -- the export call
            hard-codes ``"TFRecord"``.
    """
    print("------------------------- Download Start ... ----------------------------------")
    downLoadList = (imgCol).aggregate_array('IMG_LABEL').getInfo()
    printList(downLoadList, ' aoi_time_series ')

    # All-ones KERNEL_SIZE x KERNEL_SIZE kernel used to turn pixels into patches.
    list = ee.List.repeat(1, KERNEL_SIZE)
    lists = ee.List.repeat(list, KERNEL_SIZE)
    kernel = ee.Kernel.fixed(KERNEL_SIZE, KERNEL_SIZE, lists)

    def set_aoi_id(pnt):
        """Store the feature's id under the 'aoi' property."""
        return pnt.set('aoi', pnt.id())

    def sampleOverImage(img):
        """Sample patches of ``img`` at every point and tag them with the image label."""
        def setImgLabel(f):
            return f.set('IMG_LABEL', img.get("IMG_LABEL"))

        return (img.select(BANDS).float()
                .neighborhoodToArray(kernel)
                .sampleRegions(
            collection=pntCol,
            scale=20,
            # properties=bands4download,
            tileScale=8
        ).map(setImgLabel))

    def formatData(table, rowId, colId):
        """Pivot ``table``: one row per distinct ``rowId``, one property per ``colId`` value."""
        rows = table.distinct(rowId)
        # Attach to every distinct row all table entries sharing its rowId.
        joined = ee.Join.saveAll('matches').apply(
            primary=rows,
            secondary=table,
            condition=ee.Filter.equals(
                leftField=rowId,
                rightField=rowId
            ))

        def prepareData(row):
            # def bandConCat(BANDS):

            def data2array(feat):
                feat = ee.Feature(feat)
                # NOTE(review): the band list is hard-coded here and does not
                # follow the BANDS argument -- confirm this is intended.
                array = ee.Array.cat(
                    [ee.Array(feat.get('NIR')), ee.Array(feat.get('SWIR1')), ee.Array(feat.get('SWIR2')),
                     ee.Array(feat.get('dNBR1'))])
                return [feat.get(colId), array]

            # Flattened [key, value, key, value, ...] list turned into a dictionary.
            values = ee.List(row.get('matches')).map(data2array)
            return row.select([rowId]).set(ee.Dictionary(values.flatten()))

        return joined.map(prepareData)

    pntCol = pntCol.map(set_aoi_id)
    # One feature per (image, point) pair.
    triplets = imgCol.map(sampleOverImage).flatten()
    featCol = formatData(triplets, 'aoi', 'IMG_LABEL')

    desc = "{}_ks{}_N{}_by_pntCol_TFRecord".format('aoi_time_series', KERNEL_SIZE, featCol.size().getInfo())
    task = ee.batch.Export.table.toDrive(
        collection=featCol,
        description=desc,
        folder="{}_by_pntCol".format(toFolder),
        fileNamePrefix=desc,
        # fileFormat=fileFormat
        fileFormat="TFRecord"
    )

    task.start()
    check_export_task(task, desc)



def export_imgCol_to_Cloud_by_patch(imgCol, KERNEL_SIZE=256, BANDS=[], fileFormat='GeoTIFF'):
    """Export every image of a collection to Cloud Storage as patched TFRecord files.

    Each image is exported over the module-level ``roi`` into the module-level
    ``BUCKET`` under ``<event name>/<image label>/<image label>``. The function
    blocks until each export finishes.

    Args:
        imgCol: ``ee.ImageCollection`` whose images carry an ``IMG_LABEL`` property.
        KERNEL_SIZE: Side length used for ``patchDimensions``.
        BANDS: Bands selected from each image before export.
        fileFormat: NOTE(review): currently ignored -- the export call
            hard-codes ``'TFRecord'`` although the default here is ``'GeoTIFF'``.
    """
    print("------------------------- Download Start ... ----------------------------------")
    downLoadList = (imgCol).aggregate_array('IMG_LABEL').getInfo()
    # NOTE(review): the label ' aoi_time_series ' matches the one used in
    # export_img_time_series_by_AOI -- presumably copied; confirm.
    printList(downLoadList, ' aoi_time_series ')

    num = imgCol.size().getInfo()
    imgColList = imgCol.toList(num)

    for i in range(0, num):
        img = ee.Image(imgColList.get(i))
        imgName = img.get('IMG_LABEL').getInfo().replace(":", '')

        task = ee.batch.Export.image.toCloudStorage(
            image=img.select(BANDS),
            description="{}_{}".format(fireEvent['name'], imgName),
            bucket=BUCKET,
            # folder="{}/{}".format(fireEvent['name'], imgName),
            fileNamePrefix='{}/{}'.format("{}/{}".format(fireEvent['name'], imgName), imgName),
            scale=20,
            region=roi,
            maxPixels=1e10,
            fileFormat='TFRecord',
            # shardSize=KERNEL_SIZE,
            # fileDimensions=KERNEL_SIZE,
            formatOptions={
                'patchDimensions': [KERNEL_SIZE, KERNEL_SIZE],
                'kernelSize': [128, 128],
                'compressed': True,
                'maxFileSize': 104857600 # 100 * 1024 * 1024
            }
        )
        task.start()
        check_export_task(task, imgName)

""" Transform image to poly """
def img2poly(img0):
    """Vectorise an image: collapse its bands with ``anyNonZero`` and convert to polygons.

    The reduction runs over the collapsed image's own geometry at scale 50.

    Returns:
        ``ee.FeatureCollection`` produced by ``reduceToVectors``.
    """
    collapsed = img0.reduce(ee.Reducer.anyNonZero())
    vector_args = dict(
        geometry=collapsed.geometry(),
        scale=50,
        maxPixels=1361828260,
    )
    return ee.FeatureCollection(collapsed.reduceToVectors(**vector_args))


def exportPolyToAsset(burnRef):
    """Vectorise a burn reference image and export the unburnt polygons to an asset.

    Builds burnt and unburnt polygon collections, but only the ``polyUnburn``
    collection is exported. Relies on the module-level ``waterMask0`` and
    ``roi``. Blocks until the export finishes.

    Args:
        burnRef: Image with an ``IMG_LABEL`` property; pixels equal to 1 are
            treated as burnt.
    """
    imgLabel = burnRef.get('IMG_LABEL').getInfo()

    not_burnt_and_in_mask = burnRef.neq(1).And(waterMask0.neq(0))
    unburnRef = burnRef.mask(not_burnt_and_in_mask).eq(0).clip(roi)

    polys = {
        'polyBurnt': img2poly(burnRef.mask(burnRef.eq(1))),
        'polyUnburn': img2poly(unburnRef.mask(unburnRef.eq(1))),
    }

    asset_root = "users/omegazhangpzh/Sydney_polys/{}".format(imgLabel)
    for polyKey in ('polyUnburn',):
        task = ee.batch.Export.table.toAsset(
            collection=polys[polyKey],
            description="{}_{}".format(imgLabel, polyKey),
            assetId="{}_{}".format(asset_root, polyKey))

        task.start()
        check_export_task(task, "{}_poly".format(imgLabel))

def exportImageToAsset(burnRef):
    """Export an image to an Earth Engine asset named after its ``IMG_LABEL``.

    The asset is written under ``users/omegazhangpzh/Sydney_polys/`` at scale
    20. Blocks until the export finishes.
    """
    imgLabel = burnRef.get('IMG_LABEL').getInfo()
    assetId = "users/omegazhangpzh/Sydney_polys/{}".format(imgLabel)

    export_args = dict(
        image=burnRef,
        description="{}".format(imgLabel),
        assetId="{}".format(assetId),
        scale=20,
    )
    task = ee.batch.Export.image.toAsset(**export_args)
    task.start()

    check_export_task(task, "{}_toAsset".format(imgLabel))

def filt_morph(img):
    """Smooth an image with a median filter followed by a max-then-min pass.

    A focal median (Gaussian kernel, radius 2) is applied first, then a focal
    max and a focal min (both with a Gaussian kernel of radius 1), i.e. a
    dilation followed by an erosion.

    Args:
        img: ee.Image to filter.

    Returns:
        The filtered ee.Image.
    """
    medianKernel = ee.Kernel.gaussian(radius=2)
    morphKernel = ee.Kernel.gaussian(radius=1)

    smoothed = img.focal_median(kernel=medianKernel, iterations=1)
    dilated = smoothed.focal_max(kernel=morphKernel, iterations=1)
    return dilated.focal_min(kernel=morphKernel, iterations=1)

### 2.2 Start to Query and Process Data

In [0]:

# Unpack the selected fire event (name, region of interest, monitoring period)
# and build the static masks shared by the SAR and multispectral sections below.
fireName = fireEvent['name']
roi = fireEvent['roi']

fireStartDate = fireEvent['startDate']
fireEndDate = fireEvent['endDate']

''' Point Filter '''
# Optional geometry filter, enabled only when the event defines 'pntsRect'.
pntsFilterFlag = False
if 'pntsRect' in fireEvent.keys():
    pntsFilterFlag = True
    pntsRect = fireEvent['pntsRect']
    # First ring of the ROI polygon; p0..p3 are its first four vertices.
    coordList = ee.List(roi.coordinates().get(0))
    p0 = ee.Geometry.Point(coordList.get(0))  # Bottom-Left
    p1 = ee.Geometry.Point(coordList.get(1))  # Bottom-Right
    p2 = ee.Geometry.Point(coordList.get(2))  # Top-Right
    p3 = ee.Geometry.Point(coordList.get(3))  # Top-Left
    # Only p0 is currently part of the filter; the other corners are disabled.
    pntsFilter = ee.Filter.And(
        ee.Filter.geometry(p0)
        # , ee.Filter.geometry(p1)
        # , ee.Filter.geometry(p2)
        # , ee.Filter.geometry(p3)
    )

# Region used to query the image collections (same as the export ROI).
quaryROI = roi
# Gaussian kernel (radius 21) used to smooth the SAR k-maps.
G_kernel = ee.Kernel.gaussian(21)

''' SAR '''
# `True & flag` evaluates to the flag itself, so both SAR switches simply
# follow sarQueryFlag.
sarKmapFlag = True & sarQueryFlag  # kmap flag
sarLogRtExportFlag = True & sarQueryFlag


''' MSI '''
# Same pattern: MSI export follows msiQueryFlag.
msiExportFlag = True & msiQueryFlag

checkStartDate = ee.Date(fireStartDate)
checkEndDate = ee.Date(fireEndDate)

print("==================> {}: [ {}, {} ] <=================".format(fireEvent['name'], fireStartDate, fireEndDate))
print(" Checking: [ {}, {} ]".format(checkStartDate.format().getInfo(), checkEndDate.format().getInfo()))
print("==================================================================")

""" Dates to Export """
# These names are only defined when the event provides them; later code
# re-checks the same keys before using them.
if 'msiExportDates' in fireEvent.keys():
    msiExportDates = fireEvent['msiExportDates']

if 'sarExportDates' in fireEvent.keys():
    sarExportDates = fireEvent['sarExportDates']

"""======================== DEM ============================"""
# NOTE(review): per the Hansen dataset documentation, datamask == 1 is mapped
# land surface, so this band should be 1 on land despite being named 'water'
# -- confirm.
hansenImage = ee.Image('UMD/hansen/global_forest_change_2015')
datamask = hansenImage.select('datamask')
waterMask0 = datamask.eq(1).rename('water')

""" =============== Froest Land Cover 2015 ================= """
# Binary mask of pixels whose land-cover class equals 112.
# NOTE(review): 112 is presumably a closed-forest class in this product -- confirm.
CGLC_2015 = ee.Image("COPERNICUS/Landcover/100m/Proba-V/Global/2015")
froestMask = CGLC_2015.select("discrete_classification").eq(112)

# SRTM terrain products; not referenced again in this cell.
dem_30m = ee.Image("USGS/SRTMGL1_003")
dem = ee.Terrain.products(dem_30m)

# Slope / aspect / hillshade derived from the ALOS DEM drive the terrain masks.
alos_dem = ee.Image("JAXA/ALOS/AW3D30_V1_1").select('AVE')
terrain = ee.Terrain.products(alos_dem)
slope = terrain.select("slope")
aspect = terrain.select("aspect")
hillshade = terrain.select("hillshade")

# Terrain masks: 1 everywhere except where slope > 20 AND hillshade is above
# (ASC) or below (DSC) 180, where they are 0; then smoothed with filt_morph.
ascMask = filt_morph(ee.Image(1).subtract(ee.Image(slope.gt(20)).multiply(hillshade.gt(180))).rename("ASC"))
dscMask = filt_morph(ee.Image(1).subtract(ee.Image(slope.gt(20)).multiply(hillshade.lt(180))).rename("DSC"))
# ASC.addBands(DSC)
# Three-band mask image with bands 'water', 'ASC' and 'DSC'.
waterMask = ee.Image((waterMask0.addBands(ascMask).addBands(dscMask)).setMulti({'IMG_LABEL': 'waterMask'}))

# Combined (waterMask0 x terrain) mask per orbit direction.
maskDict = {}
maskDict['ASC'] = ee.Image(waterMask0.multiply(ascMask))
maskDict['DSC'] = ee.Image(waterMask0.multiply(dscMask))

# dem

# """////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////"""
# """ ====================================================== Sentinel-1 C-Band Data ==================================================================== """
# """////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////"""

# Query Sentinel-1 GRD scenes over the ROI, group them by date/orbit, rescale
# them to [0, 1] and (when sarKmapFlag is set) collect the pre-fire history.
if sarQueryFlag:
    print("\n----------------- Sentinel-1 ------------------------")
    # IW-mode scenes intersecting the ROI that carry both VV and VH.
    S1_filtered = ee.ImageCollection(ee.ImageCollection("COPERNICUS/S1_GRD")
                                      .filterBounds(quaryROI)
                                      .filterMetadata('instrumentMode', "equals", 'IW')
                                    #  .filterMetadata('relativeOrbitNumber_start', 'not_equals', 102)
                                    #  .filterMetadata('relativeOrbitNumber_start', 'not_equals', 137) # elephant
                                    #  .filterMetadata('relativeOrbitNumber_start', 'not_equals', 166) # elephant
                                      .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VV'))
                                      .filter(ee.Filter.listContains('transmitterReceiverPolarisation', 'VH'))
                                      )

    # Optionally restrict to the relative orbits listed by the fire event.
    if 'orbNumList' in fireEvent.keys():
        orbNumList = fireEvent['orbNumList']
        S1_filtered = S1_filtered.filter(ee.Filter.inList(opt_leftField='relativeOrbitNumber_start', opt_rightValue=orbNumList))

    sarImgCol = S1_filtered.filterDate(checkStartDate, checkEndDate)
    # Bands are picked by position after add_RBR.
    # NOTE(review): indices 0, 1, 3 are presumably VV, VH and RBR -- confirm
    # against group_S1_by_date_orbit / add_RBR.
    sarImgCol_grouped = (group_S1_by_date_orbit(sarImgCol)
                          .map(add_RBR)
                          .select([0, 1, 3])
                          )


    def sarRescaleToOne(img):
        # Clamp VV/VH/RBR to [-25, 5], rescale that range to [0, 1] as float,
        # and carry over all properties of the input image.
        return (ee.Image(img.select(['VV', 'VH', 'RBR'])
                          .clamp(-25.0, 5.0).unitScale(-25.0, 5.0).float())
                          .copyProperties(img, img.propertyNames())
                )


    sarImgCol_grouped = sarImgCol_grouped.map(sarRescaleToOne)

    print("SAR dateRange: [" + checkStartDate.format().slice(0, 10).getInfo() + ", "
          + checkEndDate.format().slice(0, 10).getInfo() + "]")

    print('sarImgCol_grouped: ')
    printList(sarImgCol_grouped.aggregate_array('IMG_LABEL').getInfo())

    # """#= == == == == == == == == == == == == == == == == == OrbitList == == == == == == == == == == == == == == == == == == == == == ="""
    # ------- Method - 2: aggregate_array for obtaining distinct orbitKey -------------------------------------
    # Distinct, sorted 'Orbit_Key' values present in the grouped collection.
    orbitKeyList = ee.List(sarImgCol_grouped.aggregate_array("Orbit_Key")).distinct().sort()
    print("orbitKeyList: ", orbitKeyList.getInfo())

    # orbitKeyList = ['ASC20']
    if sarKmapFlag:
        # Scenes from the three months before the fire start; the per-orbit
        # mean/stdDev used by the SAR export section are computed from these.
        historyImgCol = (S1_filtered.filterDate(ee.Date(fireStartDate).advance(-3, 'month'), fireStartDate)
                          .map(set_group_index_4_S1)
                          .map(add_RBR)
                          .map(sarRescaleToOne)
                          )

        # printList(historyImgCol.aggregate_array('IMG_LABEL').getInfo(), historyImgCol)

# """////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////"""
# """ ========================================================== Multispectral Data ==================================================================== """
# """////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////"""

# Query Sentinel-2 and Landsat-8 scenes over the ROI, harmonise their bands,
# merge them into one collection and rescale reflectance to [0, 1].
if msiQueryFlag:
    print("\n---------------- Sentinel-2/Landsat-8 ------------------------")

    # cloudLevel = 50
    # Spatial filter shared by the S2 and L8 queries.
    S2_filters = ee.Filter.And(
        ee.Filter.geometry(quaryROI)
    )

    """ == == == == == == == == == == == = S2 Data == == == == == == == == == == == =="""
    def S2_bandScale(img):
        # Divide all 'B*' bands by 10000, re-attach QA60 unscaled and keep
        # the properties of the input image.
        return (ee.Image(img).select('B.*').divide(10000)
                .addBands(img.select('QA60'))  # keep BQA and cloud bands
                .copyProperties(img, img.propertyNames()))


    # cloudLevel must be defined by an earlier cell (see commented default above).
    S2_filtered = ee.ImageCollection(ee.ImageCollection("COPERNICUS/S2")
                                      .filter(S2_filters)
                                      .filterDate(checkStartDate, checkEndDate)
                                      .filter(ee.Filter.lt('CLOUD_COVERAGE_ASSESSMENT', cloudLevel))
                                      )

    S2_grouped = (group_S2_ImgCol(S2_filtered)
                  .map(S2_bandScale)
                  .map(S2_bandRename)
                  .map(updateCloudMaskS2)
                  # .map(mask_S2_clouds)
                  )

    # """ == == == == == == == == == == == L8 Data == == == == == == == == == == == == == == ="""

    L8_filtered = (ee.ImageCollection("LANDSAT/LC08/C01/T1_TOA")  # TR or TOA
                    .filter(S2_filters)
                    .filterDate(checkStartDate, checkEndDate)
                    .filter(ee.Filter.lt('CLOUD_COVER_LAND', cloudLevel))
                    )

    L8_grouped = ee.ImageCollection(group_L8_ImgCol(L8_filtered)
                                    .map(L8_bandRename)
                                    .map(updateCloudMaskL8)
                                    # .map(mask_L8_clouds)
                                    # .map(add_NBR)
                                )

    # """ ======= Merge Multispectral Data ============="""
    msiImgCol = (S2_grouped.merge(L8_grouped)
                  # .map(maskClouds)
                  )

    # printList(msiImgCol.aggregate_array('IMG_LABEL').getInfo(), "msiImgCol before pntsFilter")

    if pntsFilterFlag:  # apply point filter
        msiImgCol = msiImgCol.filter(pntsFilter)

    # Reference (master) date: the S2 master by default, the L8 master for
    # events whose name contains 'elephant'.
    msiMasterDate = fireEvent['S2_Master']
    if 'elephant' in fireEvent['name']:
        msiMasterDate = fireEvent['L8_Master']

    # Keep only the requested export dates plus the master date.
    if 'msiExportDates' in fireEvent.keys():
        msiImgCol = msiImgCol.filter(ee.Filter.inList(opt_leftField='GROUP_INDEX', opt_rightValue=msiExportDates+[msiMasterDate]))

    # printList(msiImgCol.aggregate_array('IMG_LABEL').getInfo(), "msiImgCol after pntsFilter")

    def msiRescaleToOne(img):
        # Keep the 'cloud' band as is; clamp the six reflectance bands to
        # [0, 0.5] and rescale that range to [0, 1] as float.
        return (img.select('cloud').addBands(
                        img.select(['R', 'G', 'B', 'NIR', 'SWIR1', 'SWIR2'])
                            .clamp(0, 0.5).unitScale(0, 0.5).float()
                        )    #.toUint8()
                              # .copyProperties(img, img.propertyNames())
                )

    # Rescale, add the NBR bands and sort by label in descending order.
    msiImgCol = (msiImgCol
                    .map(msiRescaleToOne)
                    .map(add_NBR)
                    .sort('IMG_LABEL', False)
                  )
    # print("bands: {}".format(msiImgCol.first().bandNames().getInfo()))
    printList(msiImgCol.aggregate_array('IMG_LABEL').getInfo(), "msiImgCol after rescale")
    # printList(L8_grouped.aggregate_array('IMG_LABEL').getInfo(), "L8_grouped")
    # printList(msiImgCol.aggregate_array('IMG_LABEL').getInfo(), "msiImgCol")

    print("----------------------------------------------------------------------")

# """////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////"""
# """ ================================================================= Export MS Data ================================================================= """
# """////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////"""

# Build the multispectral export collection: every image in the export window
# gets the master image, the dNBR bands and their binarized versions appended.
if msiExportFlag:
    def add_dNBR(slaveImg):
        """Append master bands, dNBR and binarized dNBR to ``slaveImg``.

        dNBR is ``master - slave`` for the 'NBR' and 'NBR1' bands. The binary
        bands are 1 where dNBR exceeds the per-band threshold (0.25 for
        'dNBR', 0.1 for 'dNBR1'), multiplied by ``waterMask0``, the image's
        own 'cloud' band and ``froestMask``, then smoothed with ``filt_morph``.

        Reads the module-level ``msiMasterImg``, which is assigned below
        before this function is mapped.

        Args:
            slaveImg: ee.Image with 'NBR', 'NBR1' and 'cloud' bands.

        Returns:
            ``slaveImg`` with the extra bands and its original properties.
        """
        TH = (ee.Image(0.25).rename('dNBR')
              .addBands(ee.Image(0.1).rename('dNBR1'))
              )

        dNBR = ee.Image(msiMasterImg.subtract(slaveImg).select(['NBR', 'NBR1']).rename(['dNBR', 'dNBR1']))
        dNBR_bin = (dNBR.select(['dNBR', 'dNBR1'])
                        .gt(TH)
                        .multiply(waterMask0)
                        .multiply(slaveImg.select('cloud'))
                        .multiply(froestMask)
                        .rename(['dNBR_bin', 'dNBR1_bin'])
                    )

        return (slaveImg
                  .addBands(msiMasterImg)
                  .addBands(dNBR)
                  .addBands(filt_morph(dNBR_bin))
                  # .multiply(slaveImg.select('cloud'))
                  .copyProperties(slaveImg, slaveImg.propertyNames())
                )

    ### MS Master Image
    # First image acquired on the master date (one-day window).
    msiMasterImg = ee.Image(msiImgCol.filterDate(ee.Date(msiMasterDate), ee.Date(msiMasterDate).advance(1, 'day')).first())

    # NOTE(review): the export window is hard-coded and independent of the
    # selected fire event's start/end dates -- confirm this is intended.
    msiImgCol_toExport = (msiImgCol.filterDate("2019-10-27", "2020-01-11")
                          .map(add_dNBR)
                          .sort('IMG_LABEL', False)
                  )
    print("===> MS BANDS: {}".format(ee.Image(msiImgCol_toExport.first()).bandNames().getInfo()))
    printList(msiImgCol_toExport.aggregate_array('IMG_LABEL').getInfo(), "msiImgCol_toExport")


# """////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////"""
# """ ================================================================= Export SAR Data ================================================================ """
# """////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////"""

# Build the SAR export collection one orbit at a time: for every orbit key the
# pre-fire history yields a mean and stdDev image, each scene gets log-ratio
# and k-map bands, and the k-map bands are finally binarized and masked.
if sarLogRtExportFlag:

    if 'sarExportDates' in fireEvent.keys():
        sarImgCol_grouped = sarImgCol_grouped.filter(ee.Filter.inList(opt_leftField='GROUP_INDEX', opt_rightValue=sarExportDates))

    # Binarization settings. The k-map bands are clamped to [-5, 5] and then
    # divided by 5, hence the threshold is compared as binThreshold / 5.0.
    binThreshold = 1
    BAND2BIN = ['kVH', 'kVV', 'kRBR']
    binBandNameList = ["{}_bin".format(band) for band in BAND2BIN]

    def makeImgBinarize(orbitKey):
        """Return a binarization mapper bound to one orbit's terrain mask.

        The mapper passed to ``ImageCollection.map`` is evaluated when
        ``map`` is called, so a mapper reading the loop variable after the
        loop has finished would use the last orbit's ASC/DSC mask for every
        image. Selecting the mask here binds it to ``orbitKey`` explicitly.

        Args:
            orbitKey: orbit key string whose first three characters name the
                terrain-mask band of ``waterMask`` ('ASC' or 'DSC').

        Returns:
            A function mapping an image to itself plus its ``*_bin`` bands.
        """
        terrainMask = waterMask.select(orbitKey[:3])

        def imgBinarize(img):
            imgBin = (img.select(BAND2BIN).gt(binThreshold / 5.0)
                      .multiply(waterMask.select('water'))
                      .multiply(terrainMask)
                      .multiply(froestMask)
                      .rename(binBandNameList)
                      )
            return img.addBands(filt_morph(imgBin))

        return imgBinarize

    sarMeanDict = {}
    sarStdDict = {}
    sarImgCol_toExport = ee.ImageCollection([])
    # Binarized counterpart of sarImgCol_toExport, filled per orbit.
    sarBinImgCol = ee.ImageCollection([])
    for orbKey in orbitKeyList.getInfo():
        # orbKey = 'DSC42'
        print("==> orbKey checking: ", orbKey)

        orbImgCol = ee.ImageCollection(sarImgCol_grouped.filter(ee.Filter.equals('Orbit_Key', orbKey)))

        print('orbImgCol: ')
        printList(orbImgCol.aggregate_array('IMG_LABEL').getInfo())

        print("-----------------------------------------------------\n\n")

        print('===> orbImgCol: {} <==='.format(orbKey))
        printList(orbImgCol.aggregate_array('IMG_LABEL').getInfo())

        num = orbImgCol.size().getInfo()
        orbImgList = orbImgCol.toList(num)

        # The fire event is expected to hold a master date under each orbit key.
        orbMasterDate = ee.Date(fireEvent[orbKey])
        orbMasterImg = ee.Image(
            sarImgCol_grouped.filterDate(orbMasterDate, orbMasterDate.advance(2, 'day')).first())

        # printList(ee.ImageCollection(orbMasterImg).aggregate_array('IMG_LABEL').getInfo())

        # masterImg = ee.Image(orbImgCol.first())

        def add_logRt(slaveImg):
            """Append log-ratio, history-mean and history-stdDev bands.

            For VV and VH the log-ratio is ``slave - mean``; for RBR it is
            ``mean - slave`` (signs set by the -1 / -1 / +1 multiplier).
            Mapped within the same loop iteration, so ``orbKey`` refers to
            the current orbit.
            """
            meanImg = sarMeanDict[orbKey]
            stdDevImg = sarStdDict[orbKey]
            logRt = (
                ee.Image(meanImg.select(['VV_mean', 'VH_mean', 'RBR_mean']).rename(['VV', 'VH', 'RBR']))
                    .subtract(slaveImg)
                    .multiply(ee.Image(-1).rename('VV').addBands(ee.Image(-1).rename('VH')).addBands(ee.Image(1).rename('RBR')))
                    # ee.Image(orbMasterImg).subtract(slaveImg)
                    # .multiply(waterMask.select('water'))
                    # .multiply(waterMask.select('ASC'))
                    .select(['VV', 'VH', 'RBR'])
                    .rename(['VV_logRt', 'VH_logRt', 'RBR_logRt'])
                    # .copyProperties(slaveImg, slaveImg.propertyNames())
            )

            return (slaveImg.addBands(logRt)
                    .addBands(meanImg)
                    .addBands(stdDevImg))


        def add_kmap(logRtImg):
            """Append k-map bands: log-ratio / stdDev, smoothed and scaled.

            The ratio is convolved with ``G_kernel``, clamped to [-5, 5] and
            divided by 5, giving 'kVV', 'kVH' and 'kRBR'.
            """
            stdDevImg = sarStdDict[orbKey].select(['VV_std', 'VH_std', 'RBR_std']).rename(['VV', 'VH', 'RBR'])
            kmap = ee.Image(logRtImg.select(['VV_logRt', 'VH_logRt', 'RBR_logRt']).rename(['VV', 'VH', 'RBR'])
                        .divide(stdDevImg)
                        .convolve(G_kernel)#.divide(10).float()
                        .clamp(-5.0, 5.0).divide(5).float()
                        .rename(['kVV', 'kVH', 'kRBR'])
                    )

            # k = kmap.expression(
            #     "sqrt(b('kVV')*b('kVV')+b('kVH')*b('kVH'))*b('kVV')/sqrt(b('kVV')*b('kVV'))")\
            #         .convolve(G_kernel).rename('kmap')
            return logRtImg.addBands(kmap)


        # Mean and stdDev of the historical time series for this orbit.
        orbHistoryImgCol = historyImgCol.filter(ee.Filter.equals('Orbit_Key', orbKey))

        printList(orbHistoryImgCol.aggregate_array('IMG_LABEL').getInfo(), 'orbHistoryImgCol')

        sarMeanDict[orbKey] = ee.Image(orbHistoryImgCol.reduce(ee.Reducer.mean())
                .setMulti({'IMG_LABEL': 'SAR_MEAN_{}'.format(orbKey)}))

        sarStdDict[orbKey] = ee.Image(orbHistoryImgCol.reduce(ee.Reducer.stdDev())
                .select(['VV_stdDev', 'VH_stdDev', 'RBR_stdDev'])
                .rename(['VV_std', 'VH_std', 'RBR_std'])
                .setMulti({'IMG_LABEL': 'SAR_STD_{}'.format(orbKey)}))

        # print("Std Bands: {}".format(sarStdDict[orbKey].bandNames().getInfo()))

        logRtImgCol = (orbImgCol.map(add_logRt)
                .map(add_kmap)
            )
        sarImgCol_toExport = sarImgCol_toExport.merge(logRtImgCol)
        # Binarize inside the loop so each orbit uses its own ASC/DSC mask.
        sarBinImgCol = sarBinImgCol.merge(logRtImgCol.map(makeImgBinarize(orbKey)))


    # Export polys
    # polyDate = ee.Date("2019-12-27")
    # burnRef = ee.Image(sarImgCol_toExport.filterDate(polyDate, polyDate.advance(1, 'day')).first())
    # print("burnRef bandnames: {}".format(burnRef.bandNames().getInfo()))
    # # exportPolyToAsset(burnRef.select('kRBR_bin1').clip(roi))
    # exportImageToAsset(burnRef.clip(roi))

    # def changeImgLabel(img):
    #     return img.setMulti({'IMG_LABEL': ee.String("SAR_{}".format(img.get('IMG_LABEL').getInfo()))})

    # Band names are reported before the *_bin bands are attached.
    print("SAR BANDS: {}".format(sarImgCol_toExport.first().bandNames().getInfo()))
    sarImgCol_toExport = sarBinImgCol.sort('IMG_LABEL', False)
    printList(ee.ImageCollection(sarImgCol_toExport).aggregate_array('IMG_LABEL').getInfo(), 'sarImgCol_toExport')

    # Multispectral Data
    # Rebuilt here only when MSI processing is enabled; msiImgCol and
    # add_dNBR do not exist otherwise.
    if msiExportFlag:
        msiImgCol_toExport = (msiImgCol.filterDate("2019-10-27", "2020-01-11")
                              .map(add_dNBR)
                              .sort('IMG_LABEL', False)
                      )
        print("===> MS BANDS: {}".format(ee.Image(msiImgCol_toExport.first()).bandNames().getInfo()))
        printList(msiImgCol_toExport.aggregate_array('IMG_LABEL').getInfo(), "msiImgCol_toExport")

print("------------------------ Earth Engine Preprocessing Finished -----------------------------------")



### 2.3 Visualize Images 

In [0]:
# Show the binarized dNBR1 (multispectral) and kRBR (SAR) layers on a folium map.
MSI = msiImgCol_toExport.filterDate("2019-12-26", "2019-12-28").first()
SAR = sarImgCol_toExport.filterDate("2019-12-27", "2019-12-28").first()

mapid = MSI.select('dNBR1_bin').getMapId({'min': 0, 'max': 1})
# Named fireMap rather than `map` so the builtin map() is not shadowed for
# the rest of the notebook session.
fireMap = folium.Map()
folium.TileLayer(
    tiles=mapid['tile_fetcher'].url_format,
    attr='Map Data &copy; <a href="https://earthengine.google.com/">Google Earth Engine</a>',
    overlay=True,
    name='dNBR1',
  ).add_to(fireMap)

mapid = SAR.select('kRBR_bin').getMapId({'min': 0, 'max': 1})
folium.TileLayer(
    tiles=mapid['tile_fetcher'].url_format,
    attr='Map Data &copy; <a href="https://earthengine.google.com/">Google Earth Engine</a>',
    overlay=True,
    name='kRBR',
  ).add_to(fireMap)
fireMap.add_child(folium.LayerControl())
fireMap

# Step 3: Export EO Data

##### 3.1 Define Export Method

In [0]:
def batch_imgCol_export_to_cloud_classBalanced(imgCol, classBand, N=100, KERNEL_SIZE=256, toDriveFolder="Sydney_NRT"):
    """Export class-balanced training patches for every image in a collection.

    For each image, points are drawn with ``stratifiedSample`` on ``classBand``
    over 30 sampling rounds (one seed each, ``int(N / 30)`` points per class
    per round, 100 m scale). ``KERNEL_SIZE`` x ``KERNEL_SIZE`` neighbourhood
    arrays are sampled at those points and exported to Drive as TFRecord, one
    export task per image.

    Relies on the module-level ``roi``, ``FEATURES``, ``printList``,
    ``buildGeometryPoint`` and ``check_export_task``.

    Args:
        imgCol: ee.ImageCollection whose images carry an ``IMG_LABEL`` property.
        classBand: name of the band used for stratification.
        N: target number of points per class, split over the sampling rounds.
        KERNEL_SIZE: patch width and height in pixels.
        toDriveFolder: Drive folder receiving the exported files.

    Returns:
        The ee.FeatureCollection of sampled patches: the single image's
        collection when ``imgCol`` has one image, the merge of all per-image
        collections when it has several, or None when ``imgCol`` is empty.
    """
    print("------------------------- Download Start ... ----------------------------------")
    downLoadList = (imgCol).aggregate_array('IMG_LABEL').getInfo()
    printList(downLoadList, 'to_cloud_classBalanced')

    n = 30  # number of sampling rounds, each with its own seed
    SCALE = 100
    print("Total Number of SEED : {} in {}m resolution".format(n, SCALE))

    # Square kernel of ones, KERNEL_SIZE x KERNEL_SIZE.
    kernelRow = ee.List.repeat(1, KERNEL_SIZE)
    kernelWeights = ee.List.repeat(kernelRow, KERNEL_SIZE)
    kernel = ee.Kernel.fixed(KERNEL_SIZE, KERNEL_SIZE, kernelWeights)

    num = imgCol.size().getInfo()
    imgColList = imgCol.toList(num)

    allFeatCol = None
    for i in range(num):
        img = ee.Image(imgColList.get(i))
        # Colons are stripped from the label before it is used in file names.
        imgName = img.get('IMG_LABEL').getInfo().replace(":", '')

        arrays = img.neighborhoodToArray(kernel)
        featCol = ee.FeatureCollection([])
        for SEED in range(n):
            pntCol = (img.select(classBand).rename('class').toInt().addBands(ee.Image.pixelLonLat())
                .stratifiedSample(
                    region=roi,
                    numPoints=int(N / n),
                    classBand='class',
                    scale=SCALE,
                    seed=SEED,
                    dropNulls=True,
                    tileScale=8
                ).map(buildGeometryPoint)
            )

            tmpfeatCol = arrays.sampleRegions(
                collection=pntCol,
                scale=SCALE,
                # properties=FEATURES,
                tileScale=8
            )

            print("SEED {}: {}".format(SEED, tmpfeatCol.size().getInfo()))
            featCol = featCol.merge(tmpfeatCol)

        desc = "trainPatches_{}_{}m_ks{}_N{}".format(imgName, SCALE, KERNEL_SIZE, featCol.size().getInfo())
        print("desc: {}".format(desc))
        task = ee.batch.Export.table.toDrive(
            collection=featCol,
            description=desc,
            # bucket=BUCKET,
            folder=toDriveFolder,
            fileNamePrefix=desc,
            fileFormat="TFRecord",
            selectors=FEATURES
        )

        task.start()
        check_export_task(task, desc)

        allFeatCol = featCol if allFeatCol is None else allFeatCol.merge(featCol)

    # Returning only after the loop lets every image in the collection be
    # exported, not just the first one.
    return allFeatCol

# printList(ee.ImageCollection(sarImgCol_toExport).aggregate_array('IMG_LABEL').getInfo(), 'sarImgCol_toExport')


### 3.2 Export Data to Drive/Cloud Storage

In [0]:
# Bands written to the exported TFRecords; also used to select the bands of
# the merged image below.
FEATURES = [
          'VV', 'VH', 'RBR',
          'VV_mean', 'VH_mean', 'RBR_mean',
          'VV_std', 'VH_std', 'RBR_std',
          # 'VV_logRt', 'VH_logRt', 'RBR_logRt',
          # 'kVV', 'kVH', 'kRBR', 'kmap',
          'kRBR_bin', 'ref'

          # 'NIR', 'SWIR1', 'SWIR2',
          # 'NIR_1', 'SWIR1_1', 'SWIR2_1',
          # 'dNBR1_bin'
          ]


""" to cloudStorage """
# Export destinations. BUCKET is not referenced by the statements in this cell.
BUCKET = "wildfire-unet"
FOLDER = 'Sydney_NRT'

# First image in each date window. msiImg is currently not merged into the
# export image (see the commented-out lines below).
msiImg =  msiImgCol_toExport.filterDate("2019-12-26", "2019-12-27").first()
sarImg = sarImgCol_toExport.filterDate("2019-12-27", "2019-12-28").first()
# Reference image loaded from an asset, unmasked and renamed to 'ref'.
ref = ee.Image("users/omegazhangpzh/Sydney_Ref/20191227_ASC9").unmask().rename('ref')
# SAR bands plus the reference band, restricted to FEATURES.
mrgImg_toExport = (
          # msiImg
          sarImg
          # .addBands(msiImg)
            .addBands(ref)
        ).select(FEATURES)

def printBandNames(img):
    """Print the band names of ``img`` three at a time.

    Each group of up to three names is printed as a list followed by a
    blank line.
    """
    names = img.bandNames().getInfo()
    start = 0
    while start < len(names):
        print("{}\n".format(names[start:start + 3]))
        start += 3

# Report the bands of the merged image, then export class-balanced patches.
printBandNames(mrgImg_toExport)
# print("mrgImg bandnames: {}".format(mrgImg_toExport.bandNames().getInfo()))


# Single-image collection wrapping the merged SAR + reference image.
sarImgCol_toExport_Test = ee.ImageCollection([mrgImg_toExport])

# batch_imgCol_export_to_cloud(sarImgCol_toExport_Test, N=1000, KERNEL_SIZE=256, toDriveFolder=FOLDER)
featCol = batch_imgCol_export_to_cloud_classBalanced(imgCol=sarImgCol_toExport_Test, classBand='ref', N=1000, KERNEL_SIZE=128, toDriveFolder=FOLDER)
# The closing parenthesis of print() was missing, which made the cell a SyntaxError.
print("Exported PropertyNames: {}".format(featCol.propertyNames().getInfo()))

