Read in all packages

In [6]:
import os
import ee
import geemap
from geeml.extract import extractor
from google.cloud import storage
from google.cloud import client
import random

os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/explore/nobackup/people/spotter5/cnn_mapping/gee-serdp-upload-7cd81da3dc69.json"

service_account = 'gee-serdp-upload@appspot.gserviceaccount.com'
credentials = ee.ServiceAccountCredentials(service_account, "/explore/nobackup/people/spotter5/cnn_mapping/gee-serdp-upload-7cd81da3dc69.json")
ee.Initialize(credentials)
# Initialize GEE with high-volume end-point
# ee.Initialize(opt_url='https://earthengine-highvolume.googleapis.com')
ee.Initialize()

In [7]:
os.environ["GCLOUD_PROJECT"] = "gee-serdp-upload"

os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/explore/nobackup/people/spotter5/cnn_mapping/gee-serdp-upload-7cd81da3dc69.json"
storage_client = storage.Client.from_service_account_json("/explore/nobackup/people/spotter5/cnn_mapping/gee-serdp-upload-7cd81da3dc69.json")

os.environ["GCLOUD_PROJECT"] = "gee-serdp-upload"
storage_client = storage.Client()
# bucket_name = 'smp-scratch/mtbs_1985'
bucket_name = 'smp-scratch'

bucket = storage_client.bucket(bucket_name)

Read in the feature collections

In [8]:

dem = ee.Image("UMN/PGC/ArcticDEM/V3/2m_mosaic") #arctic dem
sent_2A = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
ak = ee.FeatureCollection("users/spotter/alaska") #ak shapefile
# lfdb = ee.FeatureCollection("users/spotter/fire_cnn/raw/ak_lfdb_1985") #ak fire polygons
# lfdb = ee.FeatureCollection("users/spotter/fire_cnn/raw/ak_lfdb_2014") #ak fire polygons
lfdb = ee.FeatureCollection("users/spotter/fire_cnn/raw/nbac_1985") #ak fire polygons

# lfdb = ee.FeatureCollection("users/spotter/fire_cnn/raw/nbac_2013") #ak fire polygons
# lfdb = ee.FeatureCollection("users/spotter/fire_cnn/raw/ak_mtbs_2014")
lfdb = ee.FeatureCollection("users/spotter/fire_cnn/raw/ak_mtbs_1985")

water = ee.ImageCollection("JRC/GSW1_3/YearlyHistory") #water mask

#since we are using modis can only use fires from 2001 and onward
lfdb = lfdb.filter(ee.Filter.gte('Year', 2001))

MODIS collections

In [9]:
terra = ee.ImageCollection("MODIS/061/MOD09A1")
aqua = ee.ImageCollection("MODIS/061/MYD09A1")

modis = terra.merge(aqua)

#only select good quality pixels
# def filter(image): 
#     mask = image.select('QA').neq(0)
#     return ee.Image(image).updateMask(mask)


def maskClouds(image):
    # The StateQA band contains information about clouds, shadows, and more.
    qa = image.select('StateQA')
    
    # Bits 0-1 indicate the pixel is clear (00), cloudy (01), mixed (10), or not set (11)
    # Bits 2-3 indicate the pixel is a cloud shadow (01) or not
    # Create a mask to filter out cloudy and shadow pixels.
    cloudShadowBitMask = (1 << 2)
    cloudsBitMask = (1 << 0)
    
    # Use bitwise operations to mask out clouds and shadows.
    mask = qa.bitwiseAnd(cloudShadowBitMask).eq(0) \
            .And(qa.bitwiseAnd(cloudsBitMask).eq(0))
    
    # Apply the mask to the image, setting non-clear pixels to null.
    return image.updateMask(mask)

modis = modis.map(maskClouds);


Loop through and download

In [10]:
all_dates = ee.List(lfdb.distinct(["ID"]).aggregate_array("ID"))
all_dates = all_dates.getInfo()
# all_dates = [2173, 2174]
#all_dates = [2173]
# all_dates = [2]
for i in all_dates:

  # print(id)
  # if id not in [1909, 1066, 1716]:
  # if id in [3823]:

    # try:
    # print(raw_bands.toFloat().getInfo())
    fname = f"median_{id}.tif"

    # if os.path.isfile(os.path.join('gs://smp-scratch/test_cnn', fname)) == False:

    # name = os.path.join('gs://smp-scratch', fname)
    name = fname

    storage_client = storage.Client()
    # bucket_name = 'smp-scratch/mtbs_1985'
    bucket_name = 'smp-scratch'

    bucket = storage_client.bucket(bucket_name)
    stats = storage.Blob(bucket=bucket, name=name).exists(storage_client)

    if stats == False:

        #get the fire polygon of interest
        sub_shape = lfdb.filter(ee.Filter.eq("ID", i))

        #get all other fire ids that are not this one
        not_fires = lfdb.filter(ee.Filter.neq("ID", i))


        #first get the bounding box of the fire
        bbox = sub_shape.geometry().bounds()


        #offset the bounding box by a random number
        # all_rands = [0.00, 0.02, -0.02]
        all_rands = [0.00]


        rand1 = random.sample(all_rands, 1)[0]
        rand2 = random.sample(all_rands, 1)[0]

        #offset applied
        proj = ee.Projection("EPSG:4326").translate(rand1, rand2)

        #for the bounding box apply the randomly selected offset
        final_buffer = ee.Geometry.Polygon(bbox.coordinates(), proj).transform(proj)

        #this is a bit of a hack but we have two different bounding box sizes because when we export we need to use some additonal area to avoid cuttoffs
#         final_buffer2 = final_buffer.buffer(distance= 5000).bounds()

#         final_buffer = final_buffer.buffer(distance= 40000)#.bounds().transform(proj='EPSG:3413', maxError=1)

        final_buffer2 = final_buffer.buffer(distance= (5000)).bounds()

        final_buffer = final_buffer.buffer(distance= (40000))#.bounds().transform(proj='EPSG:3413', maxError=1)

     #get the year of this fire
        this_year = ee.Number(sub_shape.aggregate_array('Year').get(0))
        
        year = this_year.getInfo() 
        
        pre_start = ee.Date.fromYMD(this_year.subtract(1), 6, 1)
        pre_end = ee.Date.fromYMD(this_year.subtract(1), 8, 31)
        post_start = pre_start.advance(2, 'year')
        post_end = pre_end.advance(2, 'year')
     
        #just getting some date info here to ensure pre fire is one  year before and post fire is one year after the fire year of interest
        startYear = pre_start.get('year')

        #convert to client side
        startYear = startYear.getInfo()  # local string
        endYear = str(int(startYear) + 2)
        startYear = str(startYear)
        
        pre_input_collection = modis.filterDate(pre_start, pre_end).select(['sur_refl_b01', 'sur_refl_b02', 'sur_refl_b03', 'sur_refl_b04', 'sur_refl_b05', 'sur_refl_b06', 'sur_refl_b07'])

        pre_input = pre_input_collection.median().clip(final_buffer).multiply(1000)

        #post firre
        post_input_collection = modis.filterDate(post_start, post_end).select(['sur_refl_b01', 'sur_refl_b02', 'sur_refl_b03', 'sur_refl_b04', 'sur_refl_b05', 'sur_refl_b06', 'sur_refl_b07'])


        post_input = post_input_collection.median().clip(final_buffer).multiply(1000)

        # print(post_input.getInfo())
        #get pre and post NDVI, ndvi and ndii
        preNBR = pre_input.normalizedDifference(['sur_refl_b02', 'sur_refl_b07']).select([0], ['preNBR'])


        postNBR = post_input.normalizedDifference(['sur_refl_b02', 'sur_refl_b07']).select([0], ['postNBR'])

        #get dNDVI
        dNBR = preNBR.select('preNBR').subtract(postNBR.select('postNBR'))
        dNBR = dNBR.select('preNBR').rename('dNBR').multiply(1000) #don't need to scale for normalized difference, it goes -1 to 1 anyway


        #-----NDVI
        preNDVI = pre_input.normalizedDifference(['sur_refl_b02', 'sur_refl_b01']).select([0], ['preNDVI'])


        postNDVI = post_input.normalizedDifference(['sur_refl_b02', 'sur_refl_b01']).select([0], ['postNDVI'])

        #get dNDVI
        NDVI = preNDVI.select('preNDVI').subtract(postNDVI.select('postNDVI'))
        NDVI = NDVI.select('preNDVI').rename('NDVI').multiply(1000)


        #-----NDII
        preNDII = pre_input.normalizedDifference(['sur_refl_b02', 'sur_refl_b06']).select([0], ['preNDII'])


        postNDII = post_input.normalizedDifference(['sur_refl_b02', 'sur_refl_b06']).select([0], ['postNDII'])

        #get dNDVI
        NDII = preNDII.select('preNDII').subtract(postNDII.select('postNDII'))
        NDII = NDII.select('preNDII').rename('NDII').multiply(1000)

        #we need to see which image ids from the entire lfdb are already included in the buffer
        lfdb_filtered_orig = lfdb.filterBounds(final_buffer)

        #ensure all fires are within the actual year of interest (this_year) and two years prior, otherwise ignore

        first_year =  int(startYear) + 1
        second_year =  int(startYear)
        third_year =  int(startYear) - 1
        fourth_year = int(startYear) + 2

        lfdb_filtered = lfdb_filtered_orig.filter(ee.Filter.eq("Year", first_year))

        bad_filtered = lfdb_filtered_orig.filter(ee.Filter.Or(ee.Filter.eq("Year", second_year), ee.Filter.eq("Year", third_year), ee.Filter.eq("Year", fourth_year)))


        #get ids which are in image
        all_dates_new = ee.List(lfdb_filtered.distinct(["ID"]).aggregate_array("ID")).getInfo()


        #remove ids from all dates which we do not need anymore
        all_dates = [i for i in all_dates if i not in all_dates_new]


        #now where the fire is turn the values to 0 and where the fire isn't turn them to 0
        # fire_rast = sub_shape.filter(ee.Filter.notNull(['ID'])).reduceToImage(properties = ['ID'], reducer =  ee.Reducer.first())

        fire_rast = lfdb_filtered.reduceToImage(properties= ['ID'], reducer = ee.Reducer.first())

        #bad fire rast
        bad_fire_rast = bad_filtered.reduceToImage(properties= ['ID'], reducer = ee.Reducer.first())

        #change values
        fire_rast = fire_rast.where(fire_rast.gt(0), 1)

        #chabnge values for bad fire raster 
        bad_fire_rast = bad_fire_rast.where(bad_fire_rast.gt(0), 1)

        #if the fires overlap we want to keep those locations
        bad_fire_rast = bad_fire_rast.where(bad_fire_rast.eq(1).And(fire_rast.eq(1)), 2).unmask(-999)

        # #fire_locaation values to 1 by first copying ndvi
        fire_rast = fire_rast.rename(['y'])

        #copy the first values of raw_bands
        y = post_input.select(['sur_refl_b02'], ['y'])#.toShort()

        #turn all values of y to 0
        y  = y.where(y.gt(0), 0)

        #turn values to 1 where fire_rast is 1
        y  = y.where(fire_rast.eq(1), 1)

        #y  = y.where(fire_rast.eq(1).And(dNBR.gte(0)), 1)

        # y  = y.where(dNBR.gte(84), 1)


        #difference the raw bands
        diff = pre_input.subtract(post_input)

        #combine

        raw_bands = diff.addBands(dNBR).addBands(NDVI).addBands(NDII)

        b1 = raw_bands.select('sur_refl_b01').cast({'sur_refl_b01':'short'})
        b2 = raw_bands.select('sur_refl_b02').cast({'sur_refl_b02':'short'})
        b3 = raw_bands.select('sur_refl_b03').cast({'sur_refl_b03':'short'})
        b4 = raw_bands.select('sur_refl_b04').cast({'sur_refl_b04':'short'})
        b5 = raw_bands.select('sur_refl_b05').cast({'sur_refl_b05':'short'})
        b6 = raw_bands.select('sur_refl_b06').cast({'sur_refl_b06':'short'})
        b7 = raw_bands.select('sur_refl_b07').cast({'sur_refl_b07':'short'})

        b8 = raw_bands.select('dNBR').cast({'dNBR':'short'})
        b9 = raw_bands.select('NDVI').cast({'NDVI':'short'})
        b10 = raw_bands.select('NDII').cast({'NDII':'short'})
        # b10 = raw_bands.select('y').cast({'y':'float'})
        b11 = y.select('y').cast({'y':'short'})


        # b1 = raw_bands.select('sur_refl_b01').cast({'sur_refl_b01':'float'})
        # b2 = raw_bands.select('sur_refl_b02').cast({'sur_refl_b02':'float'})
        # b3 = raw_bands.select('sur_refl_b03').cast({'sur_refl_b03':'float'})
        # b4 = raw_bands.select('sur_refl_b04').cast({'sur_refl_b04':'float'})
        # b5 = raw_bands.select('sur_refl_b05').cast({'sur_refl_b05':'float'})
        # b6 = raw_bands.select('sur_refl_b06').cast({'sur_refl_b06':'float'})
        # b7 = raw_bands.select('sur_refl_b07').cast({'sur_refl_b07':'float'})

        # b8 = raw_bands.select('dNBR').cast({'dNBR':'float'})
        # b9 = raw_bands.select('NDVI').cast({'NDVI':'float'})
        # b10 = raw_bands.select('NDII').cast({'NDII':'float'})
        # #b10 = raw_bands.select('y').cast({'y':'float'})
        # b11 = y.select('y').cast({'y':'float'})




        raw_bands = b1.addBands(b2).addBands(b3).addBands(b4).addBands(b5).addBands(b6).addBands(b7).addBands(b8).addBands(b9).addBands(b10)

        #remove values in waterMask which are equal to 0 (0 is water 11 is not). 
        # raw_bands = raw_bands.updateMask(waterOccFilled.lt(50).And(bad_fire_rast.neq(1)))
        raw_bands = raw_bands.updateMask(bad_fire_rast.neq(1))

        raw_bands = raw_bands.addBands(b11)


        task = ee.batch.Export.image.toCloudStorage(
                              image = raw_bands.toShort(),
                              region=final_buffer2, 
                              description='median_' + str(i),
                              crs= 'SR-ORG:6974',
                              scale = 463.3127165275,
                              maxPixels=1e13,
                              bucket = 'smp-scratch')

        task.start()
        print(i)

    # except:
    #     pass


 

411
512
592
656
1029
1123
359
358
360
390
352
361
353
362
392
402
401
403
480
936
571
513
561
514
589
582
583
608
594
642
636
643
644
624
625
626
628
629
645
657
653
654
655
665
666
652
667
651
658
650
951
659
27
34
38
39
28
40
25
41
29
37
24
31
35
36
33
710
728
697
700
741
756
758
757
763
1077
1082
1134
346
347
348
349
350
363
364
398
351
365
395
397
366
367
354
399
404
405
406
407
408
410
425
426
427
428
941
429
469
481
423
430
431
432
470
471
421
422
434
424
437
435
436
412
419
433
472
482
463
452
473
439
515
516
517
518
519
520
580
521
522
567
568
569
572
511
529
530
531
562
573
532
533
534
535
536
537
538
539
540
541
542
543
560
570
574
523
524
525
526
527
528
544
545
548
549
483
550
546
575
547
551
581
584
585
590
593
607
612
613
602
611
595
619
621
631
630
627
634
646
647
632
622
637
660
661
662
648
663
664
668
669
673
649
670
689
671
690
691
672
43
5
9
22
1
12
14
21
2
10
13
18
19
20
8
23
17
15
16
7
11
693
721
733
695
698
699
704
705
707
709
729
736
745
708
714
727
694
715
749
7