This script will take a model and predict across a 'fishnet' grid of cells.  I made the fishnet grid in ArcGIS and uploaded to earthengine.  I did this because I can only predict over relatively small areas. 
I then save these predictions as assets in an Image collection on earth engine

In [1]:
import os
import ee
import numpy as np
from geeml.extract import extractor
import pandas as pd

os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "my_json" #replace this with pathway to json credentials. 

service_account = 'my_service_account'
credentials = ee.ServiceAccountCredentials(service_account, "my_json")

ee.Initialize(credentials)#h high-volume end-point
ee.Initialize(opt_url='https://earthengine-highvolume.googleapis.com')

In [3]:
import geemap
import os
from google.cloud import storage
from google.cloud import client

In [4]:
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "my_json"
storage_client = storage.Client.from_service_account_json("my_json")

os.environ["GCLOUD_PROJECT"] = "gee-serdp-upload"
storage_client = storage.Client()

bucket_name = 'smp-scratch'

bucket = storage_client.bucket(bucket_name)

In [5]:
dem = ee.Image("UMN/PGC/ArcticDEM/V3/2m_mosaic") #arctic dem
sent_2A = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
s2Clouds = ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY')

grd =  ee.FeatureCollection("users/spotter/fire_cnn/raw/ak_ca_grid")
territory22 = ee.FeatureCollection("users/spotter/fire_cnn/raw/ak_ca_grid") #grid for alaska and canada to predict on


In [6]:
def mask(image):
    qa = image.select('QA_PIXEL')                                       
    mask = qa.bitwiseAnd(8).eq(0).And(qa.bitwiseAnd(10).eq(0)).And(qa.bitwiseAnd(32).eq(0))  
    return(image.updateMask(mask))

def land_scale(image):

    return(image.multiply(0.0000275).add(-0.2))

def sent_scale(image):
    return(image.multiply(0.0001))

In [7]:

#remove clouds from sentinel 2, cloud probability of less than 15%
def sent_maskcloud(image):

    QA60 = image.select(['QA60'])
    clouds = QA60.bitwiseAnd(1<<10).Or(QA60.bitwiseAnd(1<<11))# this gives us cloudy pixels
    image = image.select(['B2', 'B3', 'B4', 'B8', 'B11', 'B12'], ['B1', 'B2', 'B3', 'B4', 'B5', 'B7'])# rename bands
    image =  image.updateMask(clouds.Not()) # remove the clouds from image

    #different resolutions so change to 30 seperately
    image1 = image.select(['B1', 'B2', 'B3', 'B4'], ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4'])
    image2 = image.select(['SR_B5', 'SR_B7'])

    #reproject 30m
    image1 = image1.reproject(
    crs = image1.projection().crs(),
    scale = 30)

    image2 = image2.reproject(
    crs = image2.projection().crs(),
    scale = 30)

    image = image1.addBands(image2)

    return image.toShort()

MAX_CLOUD_PROBABILITY = 50

def sent_maskcloud(image):
    
    clouds = ee.Image(image.get('cloud_mask')).select('probability')
    
    isNotCloud = clouds.lt(MAX_CLOUD_PROBABILITY)
    
    return image.updateMask(isNotCloud)

#Join S2 SR with cloud probability dataset to add cloud mask.
s2SrWithCloudMask = ee.Join.saveFirst('cloud_mask').apply(
    
  primary=sent_2A,
  secondary=s2Clouds,
  condition=ee.Filter.equals(leftField='system:index', rightField='system:index'))

sent_2A = ee.ImageCollection(s2SrWithCloudMask).map(sent_maskcloud)


Functions to apply logan corrections, first pull the necessary intercepts and coeffiterritory22s from logans csv file


In [8]:
import pandas as pd
coeffs = pd.read_csv("/explore/nobackup/people/spotter5/cnn_mapping/raw_files/boreal_xcal_regression_coefficients.csv").fillna(0)
#l5
def landsat_correct(sat, bands):

    """argument 1 is which sattelite, LANDASAT_5 or LANDSAT_8
    argument 2 is bands of interest.  Bands must be in same order as EE,
    
    regression is of form,
    L7 = B0 + (B1 * L5/8) + (B2 * L^2) + (B3 * L^3)
    """

    #bands of interest in order of interest
    l5 = coeffs[(coeffs['satellite'] == sat) & (coeffs['band.or.si'] .isin (bands))] 

    #arrange the band or si column
    l5['band.or.si']=pd.Categorical(l5['band.or.si'],categories=bands)
    l5=l5.sort_values('band.or.si')

    b0 = l5['B0'].values.tolist()
    b1 = l5['B1'].values.tolist()
    b2 = l5['B2'].values.tolist()
    b3 = l5['B3'].values.tolist()

    return (b0, b1, b2, b3)

#get the corrections, each output is a list at one of the four locations
l8_corr = landsat_correct(sat = 'LANDSAT_8', bands = ['blue', 'green', 'red', 'nir', 'swir1', 'swir2', 'nbr', 'ndvi', 'ndii'])
l5_corr = landsat_correct(sat = 'LANDSAT_5', bands = ['blue', 'green', 'red', 'nir', 'swir1', 'swir2', 'nbr', 'ndvi', 'ndii'])




A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  l5['band.or.si']=pd.Categorical(l5['band.or.si'],categories=bands)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  l5['band.or.si']=pd.Categorical(l5['band.or.si'],categories=bands)


Read in the file which will normalize to the global values

In [9]:
glob_norm = pd.read_csv('/explore/nobackup/people/spotter5/cnn_mapping/nbac_training/l8_sent_collection2_global_min_max_cutoff_proj.csv').reset_index(drop = True)

glob_norm  = glob_norm [['0', '1', '2', '3', '4', '5', '6', '7', '8']]
glob_norm = glob_norm.reset_index()
glob_norm

Unnamed: 0,index,0,1,2,3,4,5,6,7,8
0,0,-54.0,-49.0,-63.0,-49.0,-129.0,-155.0,21.0,-68.0,-9.0
1,1,44.0,55.0,51.0,175.0,73.0,18.0,974.0,522.0,686.0


Read in the model

In [10]:

MODEL_NAME = 'l8_sent_collection2_crop_VI_0_two_256_eeify'
MODEL_NAME = 'l8_sent_collection2_crop_dnbr_0_two_long_128_crop_mtn_eeify'
MODEL_NAME = 'l8_collection2_dnbr_one_128_2d_ds_long_neg_eeify'
MODEL_NAME = 'l8_collection2_dnbr_one_128_2d_ds_proj_pos'
MODEL_NAME = 'l8_collection2_dnbr_one_128_2d_ds_proj_eeify'
MODEL_NAME = 'l8_collection2_VI_one_128_2d_ds_proj_final_eeify'
MODEL_NAME = 'l8_collection2_VI_one_128_2d_ds_proj_final_mtbs_gids_eeify'

# MODEL_NAME = 'l8_sent_collection2_crop_VI_0_two_128_eeify'
# MODEL_NAME = 'l8_sent_collection2_crop_VI_0_two_128_eeify'
VERSION_NAME = 'v0'
PROJECT = 'gee-serdp-upload'

overlap=42

model = ee.Model.fromAiPlatformPredictor(
    #projectName:PROJECT,
    projectId = 'gee-serdp-upload',
    modelName = MODEL_NAME,
    version = VERSION_NAME,
    region = 'us-east4',
    #Can be anything, but don't make it too big.
    inputTileSize = [128, 128],  
    inputOverlapSize = [32, 32],
    # inputTileSize=[128-overlap*2, 128-overlap*2],
    # inputOverlapSize= [overlap, overlap],
    #Keep this the same as your training data.
    proj=ee.Projection('EPSG:3413').atScale(30),
    # proj=ee.Projection('EPSG:4326').atScale(30),

    fixInputProj=True,
    inputShapes= {
      'array': [3] #this matches with the input channel number
    },
    #Note the names here need to match what you specified in the
    #output dictionary you passed to the EEifier.
    outputBands={'prediction': {
        'type': ee.PixelType.float(),
        'dimensions': 1
      }
    },
)


In [11]:
def to_float(image):

    b1 = image.select('SR_B1').cast({'SR_B1':'float'}) #0
    b2 = image.select('SR_B2').cast({'SR_B2':'float'}) #1
    b3 = image.select('SR_B3').cast({'SR_B3':'float'}) #2
    b4 = image.select('SR_B4').cast({'SR_B4':'float'}) #3
    b5 = image.select('SR_B5').cast({'SR_B5':'float'}) #4
    b6 = image.select('SR_B7').cast({'SR_B7':'float'}) #5

    image = b1.addBands(b2).addBands(b3).addBands(b4).addBands(b5).addBands(b6)

    return image





In [12]:
all_territories = ee.List(grd.distinct(["Id"]).aggregate_array("Id")).getInfo()

# all_territories = [40]




In [13]:
print(len(all_territories))

246


In [14]:
#years to predict on

years = [2004, 2005, 2010, 2014, 2015]

for year in years:
    
    #grid cell loop
    for territory in all_territories:
        
    
        territory22 = grd.filter(ee.Filter.eq('Id',  territory))
        
        
        #create output image collection
        base = 'projects/gee-serdp-upload/assets/cnn_mapping/ak_ca_' + str(year) + '_preds_VI'
        !earthengine create collection $base
        
        #full pathway to images within the collection
        full_out =os.path.join(base, str(year) + '_' + str(territory))
        
        #check if the asset exists, if it does skip it
        try:
            exists = ee.data.getAsset(full_out)
  
            print(f"Image {full_out} exists")

        except:
            print(f"Image {full_out} does not exist")
            
            #get pre dates
            pre_start = str(year - 1) + '-06-01'
            pre_end = str(year - 1) + '-08-31'

            #get post dates


            post_start = str(year + 1) + '-06-01'
            post_end = str(year + 1) + '-08-31'


            startYear = str(year - 1)
            endYear = str(year + 1)
            startDay  = '01-01' # what is the beginning of date filter | month-day
            endDay     = '12-30' # what is the end of date filter | month-day

            #########################################################################################################
            ###### ANNUAL SR TIME SERIES COLLECTION BUILDING FUNCTIONS ##### 
            #########################################################################################################

            #------ RETRIEVE A SENSOR SR COLLECTION FUNCTION -----
            def getSRcollection(start_date, end_date, sensor):
                # get a landsat collection for given year, day range, and sensor
                srCollection = ee.ImageCollection('LANDSAT/'+ sensor + '/C02/T1_L2').filterDate(start_date, end_date)

                return srCollection

            #get all collection5
            # lt4 = getSRcollection(startYear+'-'+startDay, endYear+'-'+endDay, 'LT04')  
            lt5 = getSRcollection(startYear+'-'+startDay, endYear+'-'+endDay, 'LT05').filterBounds(territory22)    
            le7 = getSRcollection(startYear+'-'+startDay, endYear+'-'+endDay, 'LE07').filterBounds(territory22)     
            lc8 = getSRcollection(startYear+'-'+startDay, endYear+'-'+endDay, 'LC08').filterBounds(territory22)        
            sent= sent_2A.filterDate(startYear+'-'+startDay, endYear+'-'+endDay).filterBounds(territory22.geometry())

            #         #------------------------------------------Landsat 5 corrections

            #select bands
            pre_lt5 = lt5.filterDate(pre_start, pre_end).map(mask).map(land_scale).select(['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']).map(to_float)

            #       #ensure we have imagery for the sensor
            if pre_lt5.size().getInfo() > 0 :


                #take the median
                pre_lt5 = pre_lt5.median().clip(territory22)

                #calculate nbr, ndvi and ndii
                pre_lt5_nbr = pre_lt5.normalizedDifference(['SR_B4', 'SR_B7']).select([0], ['NBR']).cast({'NBR': 'float'})
                pre_lt5_ndvi = pre_lt5.normalizedDifference(['SR_B4', 'SR_B3']).select([0], ['NDVI']).cast({'NDVI': 'float'})
                pre_lt5_ndii = pre_lt5.normalizedDifference(['SR_B4', 'SR_B5']).select([0], ['NDII']).cast({'NDII': 'float'})

                #add the bands back
                pre_lt5 = pre_lt5.addBands(pre_lt5_nbr).addBands(pre_lt5_ndvi).addBands(pre_lt5_ndii)

                #apply the corrections

                l5_pre_corr = pre_lt5.multiply(l5_corr[1]).add(pre_lt5.pow(2).multiply(l5_corr[2])).add(pre_lt5.pow(3).multiply(l5_corr[3])).add(l5_corr[0])

            #-------now do post-fire
            #select bands
            post_lt5 = lt5.filterDate(post_start, post_end).map(mask).map(land_scale).select(['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']).map(to_float)

            #       #ensure we have imagery for the sensor
            if post_lt5.size().getInfo() > 0 :



                #take the median
                post_lt5 = post_lt5.median().clip(territory22)

                #calculate nbr, ndvi and ndii
                post_lt5_nbr = post_lt5.normalizedDifference(['SR_B4', 'SR_B7']).select([0], ['NBR']).cast({'NBR': 'float'})
                post_lt5_ndvi = post_lt5.normalizedDifference(['SR_B4', 'SR_B3']).select([0], ['NDVI']).cast({'NDVI': 'float'})
                post_lt5_ndii = post_lt5.normalizedDifference(['SR_B4', 'SR_B5']).select([0], ['NDII']).cast({'NDII': 'float'})

                #add the bands back
                post_lt5 = post_lt5.addBands(post_lt5_nbr).addBands(post_lt5_ndvi).addBands(post_lt5_ndii)

                #apply the corrections

                l5_post_corr = post_lt5.multiply(l5_corr[1]).add(post_lt5.pow(2).multiply(l5_corr[2])).add(post_lt5.pow(3).multiply(l5_corr[3])).add(l5_corr[0])


                #         #------------------------------------------Landsat 7, no corrections but get things clipped and do pre fire/post_fire stuff


            #select bands
            pre_le7 = le7.filterDate(pre_start, pre_end).map(mask).map(land_scale).select(['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']).map(to_float)

            #       #ensure we have imagery for the sensor
            if pre_le7.size().getInfo() > 0 :



                #take the median
                pre_le7 = pre_le7.median().clip(territory22)

                #calculate nbr, ndvi and ndii
                pre_le7_nbr = pre_le7.normalizedDifference(['SR_B4', 'SR_B7']).select([0], ['NBR']).cast({'NBR': 'float'})
                pre_le7_ndvi = pre_le7.normalizedDifference(['SR_B4', 'SR_B3']).select([0], ['NDVI']).cast({'NDVI': 'float'})
                pre_le7_ndii = pre_le7.normalizedDifference(['SR_B4', 'SR_B5']).select([0], ['NDII']).cast({'NDII': 'float'})

                #add the bands back
                pre_le72 = pre_le7.addBands(pre_le7_nbr).addBands(pre_le7_ndvi).addBands(pre_le7_ndii)

            #-------now do post-fire
            #select bands
            post_le7 = le7.filterDate(post_start, post_end).map(mask).map(land_scale).select(['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']).map(to_float)
            #       #ensure we have imagery for the sensor
            if post_le7.size().getInfo() > 0 :


                #take the median
                post_le7 = post_le7.median().clip(territory22)

                #calculate nbr, ndvi and ndii
                post_le7_nbr = post_le7.normalizedDifference(['SR_B4', 'SR_B7']).select([0], ['NBR']).cast({'NBR': 'float'})
                post_le7_ndvi = post_le7.normalizedDifference(['SR_B4', 'SR_B3']).select([0], ['NDVI']).cast({'NDVI': 'float'})
                post_le7_ndii = post_le7.normalizedDifference(['SR_B4', 'SR_B5']).select([0], ['NDII']).cast({'NDII': 'float'})

                #add the bands back
                post_le72 = post_le7.addBands(post_le7_nbr).addBands(post_le7_ndvi).addBands(post_le7_ndii)

            #------------------------------------------Landsat 8 corrections


            #-------first do pre-fire

            #select bands
            pre_lc8 = lc8.filterDate(pre_start, pre_end).map(mask).map(land_scale).select(['SR_B2','SR_B3','SR_B4','SR_B5','SR_B6','SR_B7'],['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']) .map(to_float)

            #       #ensure we have imagery for the sensor
            if pre_lc8.size().getInfo() > 0 :



                #take the median
                pre_lc8 = pre_lc8.median().clip(territory22)

                #calculate nbr, ndvi and ndii
                pre_lc8_nbr = pre_lc8.normalizedDifference(['SR_B4', 'SR_B7']).select([0], ['NBR']).cast({'NBR': 'float'})
                pre_lc8_ndvi = pre_lc8.normalizedDifference(['SR_B4', 'SR_B3']).select([0], ['NDVI']).cast({'NDVI': 'float'})
                pre_lc8_ndii = pre_lc8.normalizedDifference(['SR_B4', 'SR_B5']).select([0], ['NDII']).cast({'NDII': 'float'})

                #add the bands back
                pre_lc8 = pre_lc8.addBands(pre_lc8_nbr).addBands(pre_lc8_ndvi).addBands(pre_lc8_ndii)

                #apply the corrections

                l8_pre_corr = pre_lc8.multiply(l8_corr[1]).add(pre_lc8.pow(2).multiply(l8_corr[2])).add(pre_lc8.pow(3).multiply(l8_corr[3])).add(l8_corr[0])

            #-------now do post-fire
              #select bands
            post_lc8 = lc8.filterDate(post_start, post_end).map(mask).map(land_scale).select(['SR_B2','SR_B3','SR_B4','SR_B5','SR_B6','SR_B7'],['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']) .map(to_float)

            #       #ensure we have imagery for the sensor
            if post_lc8.size().getInfo() > 0 :



            #take the median
                post_lc8 = post_lc8.median().clip(territory22)

                #calculate nbr, ndvi and ndii
                post_lc8_nbr = post_lc8.normalizedDifference(['SR_B4', 'SR_B7']).select([0], ['NBR']).cast({'NBR': 'float'})
                post_lc8_ndvi = post_lc8.normalizedDifference(['SR_B4', 'SR_B3']).select([0], ['NDVI']).cast({'NDVI': 'float'})
                post_lc8_ndii = post_lc8.normalizedDifference(['SR_B4', 'SR_B5']).select([0], ['NDII']).cast({'NDII': 'float'})

                #add the bands back
                post_lc8 = post_lc8.addBands(post_lc8_nbr).addBands(post_lc8_ndvi).addBands(post_lc8_ndii)

                #apply the corrections

                l8_post_corr = post_lc8.multiply(l8_corr[1]).add(post_lc8.pow(2).multiply(l8_corr[2])).add(post_lc8.pow(3).multiply(l8_corr[3])).add(l8_corr[0])

                # #          #------------------------------------------Sentinel 2 corrections, use landsat 8 coefficients


            ##-------first do pre-fire

            #select bands
            pre_sent = sent_2A.filterDate(pre_start, pre_end).map(sent_scale).select(['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']).map(to_float)

            #       #ensure we have imagery for the sensor
            if pre_sent.size().getInfo() > 0 :



                    #take the median
                pre_sent = pre_sent.median().clip(territory22)

                #calculate nbr, ndvi and ndii
                pre_sent_nbr = pre_sent.normalizedDifference(['SR_B4', 'SR_B7']).select([0], ['NBR']).cast({'NBR': 'float'})
                pre_sent_ndvi = pre_sent.normalizedDifference(['SR_B4', 'SR_B3']).select([0], ['NDVI']).cast({'NDVI': 'float'})
                pre_sent_ndii = pre_sent.normalizedDifference(['SR_B4', 'SR_B5']).select([0], ['NDII']).cast({'NDII': 'float'})

                #add the bands back
                pre_sent = pre_sent.addBands(pre_sent_nbr).addBands(pre_sent_ndvi).addBands(pre_sent_ndii)

                #apply the corrections

                sent_pre_corr = pre_sent.multiply(l8_corr[1]).add(pre_sent.pow(2).multiply(l8_corr[2])).add(pre_sent.pow(3).multiply(l8_corr[3])).add(l8_corr[0])

            #-------now do post-fire
            #select bands
            post_sent = sent_2A.filterDate(post_start, post_end).map(sent_scale).select(['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']).map(to_float)

            #       #ensure we have imagery for the sensor
            if post_sent.size().getInfo() > 0 :



                #take the median
                post_sent = post_sent.median().clip(territory22)

                #calculate nbr, ndvi and ndii
                post_sent_nbr = post_sent.normalizedDifference(['SR_B4', 'SR_B7']).select([0], ['NBR']).cast({'NBR': 'float'})
                post_sent_ndvi = post_sent.normalizedDifference(['SR_B4', 'SR_B3']).select([0], ['NDVI']).cast({'NDVI': 'float'})
                post_sent_ndii = post_sent.normalizedDifference(['SR_B4', 'SR_B5']).select([0], ['NDII']).cast({'NDII': 'float'})

                #add the bands back
                post_sent = post_sent.addBands(post_sent_nbr).addBands(post_sent_ndvi).addBands(post_sent_ndii)

                #apply the corrections

                sent_post_corr = post_sent.multiply(l8_corr[1]).add(post_sent.pow(2).multiply(l8_corr[2])).add(post_sent.pow(3).multiply(l8_corr[3])).add(l8_corr[0])


            #try to see if image exists, if so append

            #----------------------all prefire

            #       #empty list for pre-fire, use this to combine if we have land 5, 7, 8 or sentinel 
            pre_input = []

            try:
                l5_pre_corr.getInfo()
                pre_input.append(l5_pre_corr)

            except:
                pass

            try:
                pre_le72.getInfo()
                pre_input.append(pre_le72)

            except:
                pass

            try:
                l8_pre_corr.getInfo()
                pre_input.append(l8_pre_corr)

            except:
                pass

            try:
                sent_pre_corr.getInfo()
                pre_input.append(sent_pre_corr)

            except:
                pass


            #----------------------all postfire

            #         #       #empty list for post-fire, use this to combine if we have land 5, 7, 8 or sentinel 
            post_input = []

            try:
                l5_post_corr.getInfo()
                post_input.append(l5_post_corr)

            except:
                pass

            try:
                post_le72.getInfo()
                post_input.append(post_le72)

            except:
                pass

            try:
                l8_post_corr.getInfo()
                post_input.append(l8_post_corr)

            except:
                pass

            try:
                sent_post_corr.getInfo()
                post_input.append(sent_post_corr)

            except:
                pass

            if (len(pre_input) >0) and (len(post_input) > 0):

                #take the median of the image collections
                pre_input = ee.ImageCollection(pre_input)
                post_input = ee.ImageCollection(post_input)


                pre_input = pre_input.median()
                post_input= post_input.median()

                #difference the bands
                raw_bands = pre_input.subtract(post_input).multiply(1000)



                b1 = raw_bands.select('SR_B1').cast({'SR_B1':'short'}) #0
                b2 = raw_bands.select('SR_B2').cast({'SR_B2':'short'}) #1
                b3 = raw_bands.select('SR_B3').cast({'SR_B3':'short'}) #2
                b4 = raw_bands.select('SR_B4').cast({'SR_B4':'short'}) #3
                b5 = raw_bands.select('SR_B5').cast({'SR_B5':'short'}) #4
                b6 = raw_bands.select('SR_B7').cast({'SR_B7':'short'}) #5
                b7 = raw_bands.select('NBR').cast({'NBR':'short'}) #band 6 is dnbr is numpy
                b8 = raw_bands.select('NDVI').cast({'NDVI':'short'}) #7
                b9 = raw_bands.select('NDII').cast({'NDII':'short'}) #8


                #if using all bands
                raw_bands = b1.addBands(b2).addBands(b3).addBands(b4).addBands(b5).addBands(b6).addBands(b7).addBands(b8).addBands(b9)


                raw_bands2 = raw_bands.clip(territory22)

                
                #this is all to apply the global normalization values
                b1_min = glob_norm.reset_index().query('index == 0')['0'].values[0]
                b1_max = glob_norm.reset_index().query('index == 1')['0'].values[0]

                b2_min = glob_norm.reset_index().query('index == 0')['1'].values[0]
                b2_max = glob_norm.reset_index().query('index == 1')['1'].values[0]

                b3_min = glob_norm.reset_index().query('index == 0')['2'].values[0]
                b3_max = glob_norm.reset_index().query('index == 1')['2'].values[0]

                b4_min = glob_norm.reset_index().query('index == 0')['3'].values[0]
                b4_max = glob_norm.reset_index().query('index == 1')['3'].values[0]

                b5_min = glob_norm.reset_index().query('index == 0')['4'].values[0]
                b5_max = glob_norm.reset_index().query('index == 1')['4'].values[0]

                b6_min = glob_norm.reset_index().query('index == 0')['5'].values[0]
                b6_max = glob_norm.reset_index().query('index == 1')['5'].values[0]

                b7_min = glob_norm.reset_index().query('index == 0')['6'].values[0]
                b7_max = glob_norm.reset_index().query('index == 1')['6'].values[0]

                b8_min = glob_norm.reset_index().query('index == 0')['7'].values[0]
                b8_max = glob_norm.reset_index().query('index == 1')['7'].values[0]

                b9_min = glob_norm.reset_index().query('index == 0')['8'].values[0]
                b9_max = glob_norm.reset_index().query('index == 1')['8'].values[0]

                b7_min = glob_norm.reset_index().query('index == 0')['6'].values[0]
                b7_max = glob_norm.reset_index().query('index == 1')['6'].values[0]

                b8_min = glob_norm.reset_index().query('index == 0')['7'].values[0]
                b8_max = glob_norm.reset_index().query('index == 1')['7'].values[0]

                b9_min = glob_norm.reset_index().query('index == 0')['8'].values[0]
                b9_max = glob_norm.reset_index().query('index == 1')['8'].values[0]
                
                b1 = raw_bands2.select('SR_B1').toFloat().clamp(b1_min, b1_max).unitScale(b1_min, b1_max)
                b2 = raw_bands2.select('SR_B2').toFloat().clamp(b2_min, b2_max).unitScale(b2_min, b2_max)#.toFloat()
                b3 = raw_bands2.select('SR_B3').toFloat().clamp(b3_min, b3_max).unitScale(b3_min, b3_max)#.toFloat()
                b4 = raw_bands2.select('SR_B4').toFloat().clamp(b4_min, b4_max).unitScale(b4_min, b4_max)#.toFloat()
                b5 = raw_bands2.select('SR_B5').toFloat().clamp(b5_min, b5_max).unitScale(b5_min, b5_max)#.toFloat()
                b6 = raw_bands2.select('SR_B7').toFloat().clamp(b6_min, b6_max).unitScale(b6_min, b6_max)#.toFloat()
                b7 = raw_bands2.select('NBR').toFloat().clamp(b7_min, b7_max).unitScale(b7_min, b7_max)#.toFloat()
                b8 = raw_bands2.select('NDVI').toFloat().clamp(b8_min, b8_max).unitScale(b8_min, b8_max)#.toFloat()
                b9 = raw_bands2.select('NDII').toFloat().clamp(b9_min, b9_max).unitScale(b9_min, b9_max)#.toFloat()
                
                #need the bands to predict on in same order mdoel was trained on
                # for_predict = b1.addBands(b2).addBands(b3).addBands(b4).addBands(b5).addBands(b6).addBands(b7).addBands(b8).addBands(b9)
                for_predict = b7.addBands(b8).addBands(b9)
                # for_predict = b7

                for_predict = for_predict.reproject(crs = 'EPSG:3413', scale = 30)

                arrayImage = for_predict.float().toArray().rename('array')

                # #predict model
                predictions = model.predictImage(arrayImage).arrayGet([0])

                #scale the predictions to save disk space when saving
                final = predictions.multiply(1000).toShort()
                
                my_date = ee.Date(str(year) + '-01-01')
                final = final.set('system:time_start', my_date)


                print(f"Downloading {year} and {territory}")

    
                #export to asset
                task = ee.batch.Export.image.toAsset(
                                  image = final,
                                  region=territory22.geometry(), 
                                  description= 'predicted_' + str(year) + '_' + str(territory),
                                  scale=30,
                                  crs='EPSG:3413',
                                  # crs= 'EPSG:4326',
                                  maxPixels=1e13,
                                  assetId =  full_out)
                                  # assetId =  'users/spotter/fire_cnn/mtbs_predictions/' + str(year) + '_' + str(territory22))


                task.start()
   

Image projects/gee-serdp-upload/assets/cnn_mapping/ak_ca_2015_preds_VI/2015_0 does not exist
Downloading 2015 and 0
Asset projects/gee-serdp-upload/assets/cnn_mapping/ak_ca_2015_preds_VI already exists.
Image projects/gee-serdp-upload/assets/cnn_mapping/ak_ca_2015_preds_VI/2015_1 does not exist
Downloading 2015 and 1
Asset projects/gee-serdp-upload/assets/cnn_mapping/ak_ca_2015_preds_VI already exists.
Image projects/gee-serdp-upload/assets/cnn_mapping/ak_ca_2015_preds_VI/2015_2 does not exist
Downloading 2015 and 2
Asset projects/gee-serdp-upload/assets/cnn_mapping/ak_ca_2015_preds_VI already exists.
Image projects/gee-serdp-upload/assets/cnn_mapping/ak_ca_2015_preds_VI/2015_3 does not exist
Downloading 2015 and 3
Asset projects/gee-serdp-upload/assets/cnn_mapping/ak_ca_2015_preds_VI already exists.
Image projects/gee-serdp-upload/assets/cnn_mapping/ak_ca_2015_preds_VI/2015_4 does not exist
Downloading 2015 and 4
Asset projects/gee-serdp-upload/assets/cnn_mapping/ak_ca_2015_preds_VI a

In [15]:
print(final.getInfo())

{'type': 'Image', 'bands': [{'id': 'prediction', 'data_type': {'type': 'PixelType', 'precision': 'int', 'min': -32768, 'max': 32767}, 'crs': 'EPSG:3413', 'crs_transform': [30, 0, 0, 0, 30, 0]}], 'properties': {'system:time_start': {'type': 'Date', 'value': 1420070400000}}}


In [16]:
print(final.get('system:time_start').getInfo())

{'type': 'Date', 'value': 1420070400000}


In [17]:
test = final.filterDate('2015-01-01', '2015-12-31')
print(test.getInfo())

AttributeError: 'Image' object has no attribute 'filterDate'

In [None]:
't'