In [7]:
import os
import ee
import pandas as pd
from google.cloud import storage

# Authenticate GEE
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/explore/nobackup/people/spotter5/cnn_mapping/gee-serdp-upload-7cd81da3dc69.json"

service_account = 'gee-serdp-upload@appspot.gserviceaccount.com'
credentials = ee.ServiceAccountCredentials(service_account, "/explore/nobackup/people/spotter5/cnn_mapping/gee-serdp-upload-7cd81da3dc69.json")
ee.Initialize(credentials)

# Load collections
sent_2A = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED")
s2Clouds = ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY')
lfdb = ee.FeatureCollection("users/spotter/fire_cnn/ann_w_id")

# Constants
MAX_CLOUD_PROBABILITY = 50
bucket_name = 'smp-scratch'
folder_name = 'monthly_intervals_ndsi_fast'

# Cloud masking for Sentinel-2
def sent_maskcloud(image):
    image = image.select(['B2', 'B3', 'B4', 'B8', 'B11', 'B12'], 
                         ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']).toShort()
    clouds = ee.Image(image.get('cloud_mask')).select('probability')
    isNotCloud = clouds.lt(MAX_CLOUD_PROBABILITY)
    return image.updateMask(isNotCloud).reproject(crs=image.projection().crs(), scale=30)

# Apply cloud masking
s2SrWithCloudMask = ee.Join.saveFirst('cloud_mask').apply(
    primary=sent_2A, 
    secondary=s2Clouds, 
    condition=ee.Filter.equals(leftField='system:index', rightField='system:index')
)
sent_2A = ee.ImageCollection(s2SrWithCloudMask).map(sent_maskcloud)

# Helper functions
def mask(image):
    qa = image.select('QA_PIXEL')                                       
    return image.updateMask(qa.bitwiseAnd(8).eq(0).And(qa.bitwiseAnd(10).eq(0)).And(qa.bitwiseAnd(32).eq(0)))

def land_scale(image):
    return image.multiply(0.0000275).add(-0.2)

def sent_scale(image):
    return image.multiply(0.0001)

def mask_ndsi(image):
    ndsi = image.normalizedDifference(['SR_B2', 'SR_B5'])
    return image.updateMask(ndsi.lte(-0.2))

def to_float(image):
    return image.select(['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']).toFloat()

def landsat_correct(sat, bands):
    coeffs = pd.read_csv("/explore/nobackup/people/spotter5/cnn_mapping/raw_files/boreal_xcal_regression_coefficients.csv").fillna(0)
    l5 = coeffs[(coeffs['satellite'] == sat) & (coeffs['band.or.si'].isin(bands))]
    l5['band.or.si'] = pd.Categorical(l5['band.or.si'], categories=bands)
    l5 = l5.sort_values('band.or.si')
    return l5['B0'].values.tolist(), l5['B1'].values.tolist(), l5['B2'].values.tolist(), l5['B3'].values.tolist()

l8_corr = landsat_correct('LANDSAT_8', ['blue', 'green', 'red', 'nir', 'swir1', 'swir2', 'nbr', 'ndvi', 'ndii'])
l5_corr = landsat_correct('LANDSAT_5', ['blue', 'green', 'red', 'nir', 'swir1', 'swir2', 'nbr', 'ndvi', 'ndii'])

def apply_corrections(ic, pre_start, pre_end, post_start, post_end, geometry):
    pre_ic = ic.filterDate(pre_start, pre_end).filterBounds(geometry).map(mask).map(land_scale).select(
        ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']).map(to_float).map(mask_ndsi)
    post_ic = ic.filterDate(post_start, post_end).filterBounds(geometry).map(mask).map(land_scale).select(
        ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']).map(to_float).map(mask_ndsi)
    
    condition = pre_ic.size().gt(0).And(post_ic.size().gt(0))
    
    def calculate_difference(pre_ic, post_ic):
        pre_img = pre_ic.median().clip(geometry)
        post_img = post_ic.median().clip(geometry)
        return pre_img.subtract(post_img).multiply(1000).select(['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7'])
    
    return ee.Algorithms.If(condition, calculate_difference(pre_ic, post_ic), ee.Image().select())

def get_pre_post(pre_start, pre_end, post_start, post_end, geometry):
    lt5 = ee.ImageCollection('LANDSAT/LT05/C02/T1_L2')
    le7 = ee.ImageCollection('LANDSAT/LE07/C02/T1_L2')
    lc8 = ee.ImageCollection('LANDSAT/LC08/C02/T1_L2')
    sent = sent_2A
    
    lt5_img = apply_corrections(lt5, pre_start, pre_end, post_start, post_end, geometry)
    le7_img = apply_corrections(le7, pre_start, pre_end, post_start, post_end, geometry)
    lc8_img = apply_corrections(lc8, pre_start, pre_end, post_start, post_end, geometry)
    sent_img = apply_corrections(sent, pre_start, pre_end, post_start, post_end, geometry)

    return ee.ImageCollection([lt5_img, le7_img, lc8_img, sent_img]).mosaic()

# Main processing function for a fire polygon
def process_fire_polygon(i):
    sub_shape = lfdb.filter(ee.Filter.eq("ID", i))
    bbox = sub_shape.geometry().bounds()
    
    proj = ee.Projection("EPSG:4326").translate(0.00, 0.00)
    final_buffer = ee.Geometry.Polygon(bbox.coordinates(), proj).transform(proj)
    final_buffer2 = final_buffer.buffer(5000).bounds()
    final_buffer = final_buffer.buffer(40000)

    this_year = ee.Number(sub_shape.aggregate_array('Year').get(0))
    all_months = ee.List([
        ee.List(['-03-01', '-04-01']),
        ee.List(['-04-01', '-05-01']),
        ee.List(['-05-01', '-06-01']),
        ee.List(['-06-01', '-07-01']),
        ee.List(['-07-01', '-08-01']),
        ee.List(['-08-01', '-09-01']),
        ee.List(['-09-01', '-10-01']),
        ee.List(['-10-01', '-11-01']),
        ee.List(['-06-01', '-08-31'])
    ])
    
    def process_month(month):
        m1 = ee.String(ee.List(month).get(0))
        m2 = ee.String(ee.List(month).get(1))

        start_year = ee.Number(this_year).subtract(1)
        end_year = ee.Algorithms.If(m1.compareTo('-06-01').eq(0).And(m2.compareTo('-08-31').eq(0)), this_year.add(1), this_year)

        pre_start = ee.Date(ee.String(start_year).cat(m1))
        pre_end = ee.Date(ee.String(start_year).cat(m2))
        post_start = ee.Date(ee.String(end_year).cat(m1))
        post_end = ee.Date(ee.String(end_year).cat(m2))

        return get_pre_post(pre_start, pre_end, post_start, post_end, final_buffer)

    all_months_images = ee.ImageCollection(all_months.map(process_month)).max()

    raw_bands = all_months_images.clip(final_buffer)

    # Fire raster
    lfdb_filtered = lfdb.filterBounds(final_buffer).filter(ee.Filter.eq("Year", this_year))
    fire_rast = lfdb_filtered.reduceToImage(properties=['ID'], reducer=ee.Reducer.first())
    fire_rast = fire_rast.where(fire_rast.gt(0), 1)

    bad_filtered = lfdb.filterBounds(final_buffer).filter(ee.Filter.Or(
        ee.Filter.eq("Year", ee.Number(this_year).subtract(1)),
        ee.Filter.eq("Year", this_year),
        ee.Filter.eq("Year", ee.Number(this_year).add(1)),
        ee.Filter.eq("Year", ee.Number(this_year).add(2))
    ))
    
    bad_fire_rast = bad_filtered.reduceToImage(properties=['ID'], reducer=ee.Reducer.first())
    bad_fire_rast = bad_fire_rast.where(bad_fire_rast.gt(0), 1)
    bad_fire_rast = bad_fire_rast.where(bad_fire_rast.eq(1).And(fire_rast.eq(1)), 2).unmask(-999)

    raw_bands = raw_bands.updateMask(bad_fire_rast.neq(1))

    # Target variable
    y = raw_bands.select(['NBR'], ['y']).where(fire_rast.eq(1), 1).where(fire_rast.neq(1), 0)
    y = y.cast({'y': 'short'})
    raw_bands = raw_bands.addBands(y)

    # Start download
    fname = f"{folder_name}/final_{i}"
    task = ee.batch.Export.image.toCloudStorage(
        image=raw_bands.toShort(),
        region=final_buffer2,
        description=fname,
        scale=30,
        crs='EPSG:3413',
        maxPixels=1e13,
        bucket=bucket_name,
        fileNamePrefix=fname
    )
    task.start()
    print(f"Downloading {fname}")

    return raw_bands

# Convert all_ids to an ee.List and map over it
all_ids = ee.List([16086, 3381, 15894, 2464, 8737, 12494, 9763, 3266])
all_ids = ee.List([3266])

all_results = all_ids.map(process_fire_polygon)


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  l5['band.or.si'] = pd.Categorical(l5['band.or.si'], categories=bands)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  l5['band.or.si'] = pd.Categorical(l5['band.or.si'], categories=bands)


EEException: A mapped function's arguments cannot be used in client-side operations