In [1]:
import ee
import geemap

# -------------------------------
# 0) INITIALIZE EE
# -------------------------------
# Cloud project id, defined once: it is used both to initialise the Earth
# Engine session and to build asset paths further down.
project     = 'raman-461708'

ee.Authenticate()
ee.Initialize(project=project)

# -------------------------------
# 1) PARAMETERS
# -------------------------------
aez         = 1
start_year  = 2017      # first LULC year (2017-07-01_2018-06-30)
total_years = 7         # number of yearly maps in the sequence

# Geometry of the agro-ecological region whose 'ae_regcode' equals `aez`.
roi_boundary = ee.FeatureCollection("users/mtpictd/agro_eco_regions") \
    .filter(ee.Filter.eq("ae_regcode", aez)).geometry()

# -------------------------------
# 2) CLASS GROUPING
# -------------------------------
# Original -> grouped mapping, consumed position-wise by `remap` in
# group_classes: ORIG[i] is rewritten to GROUP[i].
#   2, 3, 4      -> 2   (water subclasses merged)
#   8, 9, 10, 11 -> 8   (crop subclasses merged)
# All other listed codes map to themselves. Code 5 is not listed in ORIG.
ORIG  = [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13]
GROUP = [0, 1, 2, 2, 2, 6, 7, 8, 8,  8,  8,  12, 13]
# Default band name read by group_classes.
BAND  = 'predicted_label'

def group_classes(img, band=BAND):
    """Collapse fine LULC codes into grouped codes (water and crop merged).

    Args:
        img: ee.Image containing the band named by `band`.
        band: name of the band holding the fine class codes.

    Returns:
        Single-band integer ee.Image named 'predicted_label' whose values
        follow the ORIG -> GROUP lookup. The remap is called without a
        default value.
    """
    fine_labels = img.select([band])
    grouped_labels = fine_labels.remap(ORIG, GROUP)
    return grouped_labels.rename('predicted_label').toInt()

# -------------------------------
# 3) LOAD: y5 FROM CORRECTED, y6 & y7 FROM ORIGINAL
# -------------------------------
def load_y5_corrected_y6_y7_original(aez, start_year, total_years, roi_boundary, project='raman-461708'):
    """Load the last three yearly LULC images of the sequence.

    The third-from-last year is read from the temporally corrected asset
        projects/{project}/assets/AEZ_{aez}_{Y-07-01_(Y+1)-06-30}_temporal_corrected
    and the last two years from the original
        projects/{project}/assets/LULC_v4_PanIndia_{Y-07-01_(Y+1)-06-30}
    assets. Every image is clipped to `roi_boundary`.

    Args:
        aez: agro-ecological zone code used in the corrected asset name.
        start_year: first year of the sequence.
        total_years: number of years in the sequence.
        roi_boundary: geometry used to clip each image.
        project: cloud project that owns the assets.

    Returns:
        (corrected_img, next_img, last_img, (y5, y6, y7)) where the corrected
        image has its first band renamed to 'predicted_label' and the tuple
        holds the three start years as Python ints.
    """
    # Start years of the last three entries, e.g. 2017 + (7 - 3) = 2021.
    first_of_three = start_year + (total_years - 3)
    years = (first_of_three, first_of_three + 1, first_of_three + 2)

    # July-to-June date tags used inside the asset names.
    tag5, tag6, tag7 = [f"{yr}-07-01_{yr+1}-06-30" for yr in years]

    # Corrected asset: whatever its single band is called, expose it under
    # the same band name as the original assets.
    corrected = ee.Image(
        f"projects/{project}/assets/AEZ_{aez}_{tag5}_temporal_corrected"
    ).clip(roi_boundary)
    first_band = ee.String(corrected.bandNames().get(0))
    corrected = corrected.select([first_band], ['predicted_label'])

    # Original pan-India assets for the two following years.
    prefix = f"projects/{project}/assets/LULC_v4_PanIndia_"
    next_img = ee.Image(prefix + tag6).clip(roi_boundary)
    last_img = ee.Image(prefix + tag7).clip(roi_boundary)

    return corrected, next_img, last_img, years

# Load corrected year 5 plus original years 6 and 7 for the configured AEZ.
corr5_pl, img6_orig, img7_orig, (y5, y6, y7) = load_y5_corrected_y6_y7_original(
    aez, start_year, total_years, roi_boundary, project=project
)

# -------------------------------
# 4) GROUPED (y5, y6, y7)
# -------------------------------
# All three inputs expose their labels in band 'predicted_label', so the same
# grouping call applies to each of them.
y5_grp, y6_grp, y7_grp = [
    group_classes(source_img, band='predicted_label')
    for source_img in (corr5_pl, img6_orig, img7_orig)
]

# One image with the grouped labels as bands y5, y6, y7.
seq_last3 = ee.Image.cat([
    grouped.rename(band_name)
    for grouped, band_name in ((y5_grp, 'y5'), (y6_grp, 'y6'), (y7_grp, 'y7'))
])

print("Bands in seq_last3:", seq_last3.bandNames().getInfo())

# -------------------------------
# 5) ADMISSIBILITY MATRIX & allow_flip(A,B)
# -------------------------------
# Grouped classes covered by the admissibility matrix (0 is excluded).
# Order: [Built-up, Water, Tree, Barren, Crop, Scrub, Plantation]
# The position of a class in this list is its row/column index in ALLOW_TABLE
# (see _class_to_index).
MAT_CLASSES = [1, 2, 6, 7, 8, 12, 13]

# Flattened 7x7 matrix, row-major.
# A \ B (rows = target A, cols = current B): 1=allowed, 0=not allowed
# allow_flip(A, B) reads element [index(A), index(B)].
#            Bu Wa Tr Ba Cr Sc Pl
ALLOW_TABLE = [
    # A=Built-up
               0, 0, 0, 1, 1, 1, 1,
    # A=Water
               0, 0, 1, 0, 0, 0, 0,
    # A=Tree/Forest
               1, 1, 0, 1, 1, 1, 1,
    # A=Barren
               1, 0, 1, 0, 1, 1, 1,
    # A=Crop
               0, 0, 1, 1, 0, 1, 0,
    # A=Scrub
               0, 0, 1, 1, 1, 0, 1,
    # A=Plantation
               1, 0, 1, 1, 0, 1, 0
]

# Reshape the flat list to 7x7 and wrap it as a constant image so it can be
# indexed per pixel with arrayGet.
ALLOW_TABLE_ARR = ee.Array(ALLOW_TABLE).reshape([7, 7])
ALLOW_TABLE_IMG = ee.Image.constant(ALLOW_TABLE_ARR)

def _class_to_index(x):
    """Translate grouped class codes into matrix positions.

    Each code listed in MAT_CLASSES becomes its position (0..6) in that list;
    every other value becomes -1. The result is cast to integer.
    """
    matrix_positions = list(range(7))
    as_image = ee.Image(x)
    return as_image.remap(MAT_CLASSES, matrix_positions, -1).toInt()

def allow_flip(A, B):
    """Per-pixel admissibility of relabelling B as A.

    Args:
        A: grouped-class image holding the target label.
        B: grouped-class image holding the current label.

    Returns:
        Integer image named 'allowFlip': the ALLOW_TABLE entry at
        [index(A), index(B)] where both labels belong to MAT_CLASSES,
        and 0 everywhere else.
    """
    row = _class_to_index(A)
    col = _class_to_index(B)

    # Both labels must be known to the matrix for the lookup to count.
    both_known = row.gte(0).And(col.gte(0))

    looked_up = ALLOW_TABLE_IMG.arrayGet([row, col]).toInt()
    looked_up = looked_up.updateMask(both_known)
    return looked_up.unmask(0).rename('allowFlip')

# -------------------------------
# 6) ABA FLIPS CENTERED AT YEAR 6
# -------------------------------
def detect_ABA_flips_year6(seq3):
    """Find and repair admissible A-B-A label sequences centred on year 6.

    A pixel qualifies when its y5 and y7 labels agree, its y6 label differs
    from them, and allow_flip(y5, y6) equals 1.

    Args:
        seq3: image with grouped-label bands 'y5', 'y6' and 'y7'.

    Returns:
        (aba_mask, y6c)
          aba_mask: band 'ABA6_allowed', 1 where year 6 qualifies.
          y6c:      band 'y6c', the y6 labels with qualifying pixels
                    replaced by the y5 label.
    """
    before, centre, after = (seq3.select(name) for name in ('y5', 'y6', 'y7'))

    # Neighbours agree with each other but not with the centre year.
    is_aba = before.eq(after).And(centre.neq(before))

    # Matrix constraint with target A = y5 and current B = y6.
    is_admissible = allow_flip(before, centre).eq(1)

    aba_mask = is_aba.And(is_admissible).rename('ABA6_allowed')

    # Overwrite the centre label with the neighbour label where flagged.
    y6c = centre.where(aba_mask, before).rename('y6c')

    return aba_mask, y6c

# Run the year-6 ABA correction and append its outputs to the 3-year stack.
aba_mask, y6c = detect_ABA_flips_year6(seq_last3)
seq_last3_corrected = seq_last3.addBands([aba_mask, y6c])

print("Bands in corrected stack:", seq_last3_corrected.bandNames().getInfo())

# -------------------------------
# 7) OPTIONAL VISUALIZATION
# -------------------------------
# One colour per class code 0..13.
pallete_lulc = [
    '000000', 'ff0000', '74ccf4', '1ca3ec', '0f5e9c',
    'f1c232', '38761d', 'A9A9A9', 'BAD93E', 'f59d22',
    'FF9371', 'b3561d', 'a9a9a9', '84994F',
]
vis_params_lulc = dict(min=0, max=13, palette=pallete_lulc)

# 0 = black, 1 = magenta
mask_vis = dict(min=0, max=1, palette=['000000', 'ff00ff'])

Map = geemap.Map()
url = 'https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}'
Map.layout.height = '800px'
Map.add_tile_layer(url, name="Google Map", attribution="Google")

# Grouped labels before correction, then corrected year 6, then the flip mask.
Map.addLayer(seq_last3.select('y5'), vis_params_lulc, 'y5 grouped (before ABA correction)')
Map.addLayer(seq_last3.select('y6'), vis_params_lulc, 'y6 grouped (before ABA correction)')
Map.addLayer(seq_last3.select('y7'), vis_params_lulc, 'y7 grouped (before ABA correction)')
Map.addLayer(seq_last3_corrected.select('y6c'), vis_params_lulc, 'y6 grouped (after ABA correction)')
Map.addLayer(aba_mask, mask_vis, 'ABA flips on year 6 (allowed)')
Map


Bands in seq_last3: ['y5', 'y6', 'y7']
Bands in corrected stack: ['y5', 'y6', 'y7', 'ABA6_allowed', 'y6c']


Map(center=[0, 0], controls=(WidgetControl(options=['position', 'transparent_bg'], position='topright', transp…

In [2]:

from datetime import datetime, timedelta
import pandas as pd
from dateutil.relativedelta import relativedelta
# Common band names shared by all three sensors after renaming; the slope and
# intercept lists below are ordered to match this list.
chastainBandNames = ['BLUE', 'GREEN', 'RED', 'NIR', 'SWIR1', 'SWIR2']

# Regression model parameters from Table-4. MSI TOA reflectance as a function of OLI TOA reflectance.
msiOLISlopes = [1.0946,1.0043,1.0524,0.8954,1.0049,1.0002]
msiOLIIntercepts = [-0.0107,0.0026,-0.0015,0.0033,0.0065,0.0046]

# Regression model parameters from Table-5. MSI TOA reflectance as a function of ETM+ TOA reflectance.
msiETMSlopes = [1.10601,0.99091,1.05681,1.0045,1.03611,1.04011]
msiETMIntercepts = [-0.0139,0.00411,-0.0024,-0.0076,0.00411,0.00861]

# Regression model parameters from Table-6. OLI TOA reflectance as a function of ETM+ TOA reflectance.
oliETMSlopes =[1.03501,1.00921,1.01991,1.14061,1.04351,1.05271];
oliETMIntercepts = [-0.0055,-0.0008,-0.0021,-0.0163,-0.0045,0.00261]

# Lookup for all pairwise combos, keyed '<FROM>_<TO>'.
# Each value is [slopes, intercepts, direction]. harmonizationChastain reads
# the third item as a direction flag:
#   0 -> dir0Regression: img * slopes + intercepts
#   1 -> dir1Regression: (img - intercepts) / slopes
chastainCoeffDict = { 'MSI_OLI':[msiOLISlopes,msiOLIIntercepts,1], # direction flag 1: inverse model
                    'MSI_ETM':[msiETMSlopes,msiETMIntercepts,1],
                    'OLI_ETM':[oliETMSlopes,oliETMIntercepts,1],
                    'OLI_MSI':[msiOLISlopes,msiOLIIntercepts,0],
                    'ETM_MSI':[msiETMSlopes,msiETMIntercepts,0],
                    'ETM_OLI':[oliETMSlopes,oliETMIntercepts,0]
                    }

'''
Function to mask cloudy pixels in Landsat-7
'''
def maskL7cloud(image):
    """Mask cloud and cloud-shadow pixels of a Landsat-7 Collection 2 image.

    In Collection 2 the QA_PIXEL band flags cloud in bit 3 and cloud shadow
    in bit 4. Testing bit 4 alone removes only shadows and leaves the clouds
    in place, so both bits are required to be clear here.

    Args:
        image: ee.Image from a Landsat-7 Collection 2 collection with bands
            QA_PIXEL, B1, B2, B3, B4, B5 and B7.

    Returns:
        The masked image restricted to six bands renamed to
        BLUE, GREEN, RED, NIR, SWIR1, SWIR2.
    """
    qa = image.select('QA_PIXEL')
    cloud_bit = 1 << 3
    shadow_bit = 1 << 4
    mask = qa.bitwiseAnd(cloud_bit).eq(0).And(qa.bitwiseAnd(shadow_bit).eq(0))
    return image.updateMask(mask) \
        .select(['B1', 'B2', 'B3', 'B4', 'B5', 'B7']) \
        .rename(['BLUE', 'GREEN', 'RED', 'NIR', 'SWIR1', 'SWIR2'])


'''
Function to mask cloudy pixels in Landsat-8
'''
def maskL8cloud(image):
    """Mask cloud and cloud-shadow pixels of a Landsat-8 Collection 2 image.

    In Collection 2 the QA_PIXEL band flags cloud in bit 3 and cloud shadow
    in bit 4. Testing bit 4 alone removes only shadows and leaves the clouds
    in place, so both bits are required to be clear here.

    Args:
        image: ee.Image from a Landsat-8 Collection 2 collection with bands
            QA_PIXEL, B2, B3, B4, B5, B6 and B7.

    Returns:
        The masked image restricted to six bands renamed to
        BLUE, GREEN, RED, NIR, SWIR1, SWIR2.
    """
    qa = image.select('QA_PIXEL')
    cloud_bit = 1 << 3
    shadow_bit = 1 << 4
    mask = qa.bitwiseAnd(cloud_bit).eq(0).And(qa.bitwiseAnd(shadow_bit).eq(0))
    return image.updateMask(mask) \
        .select(['B2', 'B3', 'B4', 'B5', 'B6', 'B7']) \
        .rename(['BLUE', 'GREEN', 'RED', 'NIR', 'SWIR1', 'SWIR2'])


'''
Function to mask clouds using the quality band of Sentinel-2 TOA
'''
def maskS2cloudTOA(image):
    """Mask Sentinel-2 pixels flagged in QA60 and keep six renamed bands.

    A pixel is kept only when QA60 bit 10 and bit 11 are both zero. The
    output has bands B2, B3, B4, B8, B11, B12 renamed to
    BLUE, GREEN, RED, NIR, SWIR1, SWIR2.
    """
    quality = image.select('QA60')

    # Bit 10 and bit 11 must each be unset for a pixel to survive.
    bit10_clear = quality.bitwiseAnd(1 << 10).eq(0)
    bit11_clear = quality.bitwiseAnd(1 << 11).eq(0)
    clear = bit10_clear.And(bit11_clear)

    source_bands = ['B2', 'B3', 'B4', 'B8',  'B11', 'B12']
    common_names = ['BLUE', 'GREEN', 'RED', 'NIR', 'SWIR1', 'SWIR2']
    return image.updateMask(clear).select(source_bands).rename(common_names)


'''
Get Landsat and Sentinel image collections
'''
def Get_L7_L8_S2_ImageCollections(inputStartDate, inputEndDate, roi_boundary):
    """Build the three masked TOA collections for a date range and region.

    Each collection is filtered by date, then by bounds, and then mapped
    through its sensor-specific masking/renaming function.

    Returns:
        (L7, L8, S2) image collections.
    """
    def _masked_collection(asset_id, mask_fn):
        # Same filter chain for every sensor; only the id and mask differ.
        filtered = ee.ImageCollection(asset_id) \
            .filterDate(inputStartDate, inputEndDate) \
            .filterBounds(roi_boundary)
        return filtered.map(mask_fn)

    L7 = _masked_collection('LANDSAT/LE07/C02/T1_TOA', maskL7cloud)
    L8 = _masked_collection('LANDSAT/LC08/C02/T1_TOA', maskL8cloud)
    S2 = _masked_collection('COPERNICUS/S2_HARMONIZED', maskS2cloudTOA)

    return L7, L8, S2


'''
Function to apply model in one direction
'''
def dir0Regression(img,slopes,intercepts):
    """Apply the forward linear model: bands * slopes + intercepts.

    Only the bands listed in chastainBandNames are used.
    """
    reflectance = img.select(chastainBandNames)
    scaled = reflectance.multiply(slopes)
    return scaled.add(intercepts)


'''
Applying the model in the opposite direction
'''
def dir1Regression(img,slopes,intercepts):
    """Apply the inverse linear model: (bands - intercepts) / slopes.

    Only the bands listed in chastainBandNames are used.
    """
    reflectance = img.select(chastainBandNames)
    shifted = reflectance.subtract(intercepts)
    return shifted.divide(slopes)


'''
Function to correct one sensor to another
'''
def harmonizationChastain(img, fromSensor,toSensor):
    """Convert an image's reflectance from one sensor to another.

    Args:
        img: ee.Image carrying the bands listed in chastainBandNames.
        fromSensor: source sensor code ('MSI', 'OLI' or 'ETM', any case).
        toSensor: target sensor code ('MSI', 'OLI' or 'ETM', any case).

    Returns:
        The converted image with the properties of `img` copied onto it,
        including 'system:time_start'.

    Raises:
        KeyError: if the sensor pair has no entry in chastainCoeffDict.
    """
    # Look up the model for the requested sensor pair.
    comboKey = fromSensor.upper() + '_' + toSensor.upper()
    slopes, intercepts, direction = chastainCoeffDict[comboKey]

    # The direction flag is a plain Python int, so choose the branch on the
    # client. ee.Algorithms.If would construct both branches server-side.
    if direction == 0:
        out = dir0Regression(img, slopes, intercepts)
    else:
        out = dir1Regression(img, slopes, intercepts)

    return ee.Image(out).copyProperties(img).copyProperties(img,['system:time_start'])


'''
Calibrate Landsat-8 (OLI) and Sentinel-2 (MSI) to Landsat-7 (ETM+)
'''
def Harmonize_L7_L8_S2(L7, L8, S2):
    """Convert L8 (OLI) and S2 (MSI) to ETM and merge them with L7.

    Returns:
        One ee.ImageCollection containing L7 followed by the converted L8
        and S2 images.
    """
    def _oli_to_etm(img):
        return harmonizationChastain(img, 'OLI','ETM')

    def _msi_to_etm(img):
        return harmonizationChastain(img, 'MSI','ETM')

    l8_as_etm = L8.map(_oli_to_etm)
    s2_as_etm = S2.map(_msi_to_etm)

    merged = L7.merge(l8_as_etm).merge(s2_as_etm)
    return ee.ImageCollection(merged)


'''
Add NDVI band to harmonized image collection
'''
def addNDVI(image):
    """Append an 'NDVI' band (normalized difference of NIR and RED) and cast to float."""
    ndvi = image.normalizedDifference(['NIR', 'RED']).rename('NDVI')
    with_ndvi = image.addBands(ndvi)
    return with_ndvi.float()


'''
Function definitions to get NDVI values at each 16-day composites
'''
def Get_NDVI_image_datewise(harmonized_LS_ic, roi_boundary):
    """Return a function that builds one 16-day median NDVI composite.

    The returned function takes a start date and reduces, with a median, the
    NDVI images of `harmonized_LS_ic` that fall in the 16 days starting
    there. A fully masked placeholder image is merged in before the
    reduction. The composite is stamped with the start date as
    'system:time_start'.
    """
    def get_NDVI_datewise(date):
        window_start = ee.Date(date)
        window_end = window_start.advance(16, 'day')

        # Placeholder NDVI band masked with a zero image clipped to the ROI.
        zero_mask = ee.Image(0).clip(roi_boundary)
        empty_band_image = ee.Image(0).float().rename(['NDVI']).updateMask(zero_mask)

        in_window = harmonized_LS_ic.select(['NDVI']).filterDate(window_start, window_end)
        composite = in_window.merge(empty_band_image).median()
        return composite.set('system:time_start', window_start.millis())

    return get_NDVI_datewise

def Get_LS_16Day_NDVI_TimeSeries(inputStartDate, inputEndDate, harmonized_LS_ic, roi_boundary):
    """Build the collection of 16-day NDVI composites for a date range.

    Args:
        inputStartDate: 'YYYY-MM-DD' string, first composite start date.
        inputEndDate: 'YYYY-MM-DD' string, upper bound for start dates.
        harmonized_LS_ic: collection with an 'NDVI' band.
        roi_boundary: region passed on to the composite builder.

    Returns:
        ee.ImageCollection with one composite per 16-day step.
    """
    first_day = datetime.strptime(inputStartDate,"%Y-%m-%d")
    last_day = datetime.strptime(inputEndDate,"%Y-%m-%d")

    # Start dates every 16 days, formatted back to 'YYYY-MM-DD' strings.
    step_starts = pd.date_range(start=first_day, end=last_day, freq='16D').tolist()
    step_labels = [datetime.strftime(step, "%Y-%m-%d") for step in step_starts]

    build_composite = Get_NDVI_image_datewise(harmonized_LS_ic, roi_boundary)
    composites = ee.List(step_labels).map(build_composite)
    return ee.ImageCollection.fromImages(composites)


'''
Pair available LSC and modis values for each time stamp.
'''
def pairLSModis(lsRenameBands):
    """Return a mapper that attaches a MODIS composite to each LSC image.

    The mapper renames the image's bands to `lsRenameBands` and adds two
    bands, 'NDVI_modis' and 'SummaryQA_modis', taken from the median of
    MODIS/061/MOD13Q1 images within 8 days either side of the image's
    'system:time_start'.

    NOTE(review): the inner function reads `roi_boundary` from module scope
    rather than from an argument, so that global must exist when this runs.
    """
    def pair(feature):
        centre = ee.Date( feature.get('system:time_start') )
        window_start = centre.advance(-8,'day')
        window_end = centre.advance(8,'day')

        # MODIS vegetation index product; only NDVI and its QA are used.
        modis_ic = ee.ImageCollection('MODIS/061/MOD13Q1')
        modis_ic = modis_ic.filterDate(window_start, window_end)
        modis_ic = modis_ic.select(['NDVI','SummaryQA'])
        modis_ic = modis_ic.filterBounds(roi_boundary)
        modis = modis_ic.median().rename(['NDVI_modis', 'SummaryQA_modis'])

        renamed = feature.rename(lsRenameBands)
        return renamed.addBands(modis)

    return pair


'''
Function to get Pearson Correlation Coefficient to perform GapFilling
'''
def get_Pearson_Correlation_Coefficients(LSC_modis_paired_ic, roi_boundary, bandList):
    """Per-pixel Pearson correlation between the two bands in `bandList`.

    The collection is filtered to `roi_boundary`, stacked into an array
    image, and reduced along the image axis.

    Returns:
        Image with two bands, 'c' and 'p'.
    """
    stacked = LSC_modis_paired_ic.filterBounds(roi_boundary).select(bandList).toArray()
    reduced = stacked.arrayReduce(
        reducer = ee.Reducer.pearsonsCorrelation(),
        axes = [0],
        fieldAxis = 1
    )
    corr = reduced.arrayProject([1]).arrayFlatten([['c', 'p']])
    return corr


'''
Fill gaps in LSC timeseries using modis data
'''
def gapfillLSM(LSC_modis_regression_model, LSC_bandName, modis_bandName):
    """Return a mapper that fills masked LSC pixels with a MODIS-based estimate.

    Args:
        LSC_modis_regression_model: image with bands 'offset' and 'scale'.
        LSC_bandName: name of the band to gap-fill.
        modis_bandName: name of the MODIS band used to estimate missing values.

    Returns:
        A function taking an image (which must also carry 'SummaryQA_modis')
        and returning a two-band image:
          'gapfilled_' + LSC_bandName: the original value where present,
              otherwise modis * scale + offset;
          'SummaryQA': a weight band, taken from the filled band's mask where
              the original value was present and otherwise 0.8 / 0.5 / 0.2
              according to the MODIS QA value.
    """
    def peformGapfilling(image):
        offset = LSC_modis_regression_model.select('offset')
        scale = LSC_modis_regression_model.select('scale')
        nodata = -1

        lsc_image = image.select(LSC_bandName)
        # MODIS value transformed through the fitted linear model.
        modisfit = image.select(modis_bandName).multiply(scale).add(offset)

        # Remember where the original band had data before unmasking it.
        mask = lsc_image.mask()
        gapfill = lsc_image.unmask(nodata)
        # Wherever the original was masked, take the MODIS estimate instead.
        gapfill = gapfill.where(mask.Not(), modisfit)

        '''
        in SummaryQA,
        0: Good data, use with confidence
        1: Marginal data, useful but look at detailed QA for more information
        2: Pixel covered with snow/ice
        3: Pixel is cloudy
        '''
        qc_m = image.select('SummaryQA_modis').unmask(3)  # a missing QA value is treated as 3
        # Start from the mask of the MODIS estimate, then overwrite by QA value.
        w_m  = modisfit.mask().rename('w_m').where(qc_m.eq(0), 0.8)  # QA 0 -> 0.8
        w_m = w_m.where(qc_m.eq(1), 0.5)   # QA 1 -> 0.5
        w_m = w_m.where(qc_m.gte(2), 0.2) # QA >= 2 -> 0.2

        # Weight band: mask of the filled band, replaced by the MODIS weight
        # wherever the original value was missing.
        w_l = gapfill.mask()
        w_l = w_l.where(mask.Not(), w_m)

        # Note: the weight band is emitted under the name 'SummaryQA'.
        return gapfill.addBands(w_l).rename(['gapfilled_'+LSC_bandName,'SummaryQA'])

    return peformGapfilling


'''
Function to combine LSC with Modis data
'''
def Combine_LS_Modis(LSC):
    """Pair each LSC composite with MODIS and gap-fill its NDVI band.

    Steps:
      1. Suffix every band name of the composites with '_lsc' and attach the
         MODIS bands via pairLSModis.
      2. Fit, per pixel, a linear model with 'NDVI_modis' as the independent
         and 'NDVI_lsc' as the dependent variable (bands 'scale', 'offset').
      3. Map gapfillLSM over the paired collection using that model.

    Args:
        LSC: ee.ImageCollection of NDVI composites.

    Returns:
        ee.ImageCollection whose images carry 'gapfilled_NDVI_lsc' and
        'SummaryQA'.
    """
    lsRenameBands = ee.Image(LSC.first()).bandNames().map( lambda band: ee.String(band).cat('_lsc') )
    LSC_modis_paired_ic = LSC.map( pairLSModis(lsRenameBands) )

    # Output contains scale, offset i.e. two bands
    LSC_modis_regression_model_NDVI = LSC_modis_paired_ic.select(['NDVI_modis', 'NDVI_lsc']) \
                                                            .reduce(ee.Reducer.linearFit())

    # The Pearson-correlation image formerly built here was never used and
    # was the only reason this function read the global `roi_boundary`.
    LSMC_NDVI = LSC_modis_paired_ic.map(gapfillLSM(LSC_modis_regression_model_NDVI, 'NDVI_lsc', 'NDVI_modis'))

    return LSMC_NDVI


'''
Mask out low quality pixels
'''
def mask_low_QA(lsmc_image):
    """Mask pixels whose 'SummaryQA' value equals 0.2.

    The 'system:time_start' property of the input is carried over.
    """
    keep = lsmc_image.select('SummaryQA').neq(0.2)
    masked = lsmc_image.updateMask(keep)
    return masked.copyProperties(lsmc_image, ['system:time_start'])


'''
Add image timestamp to each image in time series
'''
def add_timestamp(image):
    """Append a 'timestamp' band holding the image's 'system:time_start'.

    The new band is masked with the mask of the image's first band.
    """
    first_band_mask = image.mask().select(0)
    stamp = image.metadata('system:time_start').rename('timestamp')
    return image.addBands(stamp.updateMask(first_band_mask))


'''
Perform linear interpolation on missing values
'''
def performInterpolation(image):
    """Fill masked pixels of one image by linear interpolation in time.

    The image is expected to carry two list properties, 'before' and 'after',
    each holding images with a 'timestamp' band (they are attached by the
    joins in interpolate_timeseries). Each list is mosaicked, the two mosaics
    are interpolated to the image's own 'system:time_start', and the result
    fills the masked pixels of `image`. Pixels still masked afterwards are
    filled from a mosaic of the before and after mosaics.

    Returns:
        The filled image with 'system:time_start' copied from the input.
    """
    image = ee.Image(image)
    beforeImages = ee.List(image.get('before'))
    beforeMosaic = ee.ImageCollection.fromImages(beforeImages).mosaic()
    afterImages = ee.List(image.get('after'))
    afterMosaic = ee.ImageCollection.fromImages(afterImages).mosaic()

    # Interpolation formula
    # y = y1 + (y2 - y1) * ((t - t1) / (t2 - t1))
    # y  = interpolated image
    # y1 = before image
    # y2 = after image
    # t  = interpolation timestamp
    # t1 = before image timestamp
    # t2 = after image timestamp

    t1 = beforeMosaic.select('timestamp').rename('t1')
    t2 = afterMosaic.select('timestamp').rename('t2')
    t = image.metadata('system:time_start').rename('t')
    timeImage = ee.Image.cat([t1, t2, t])
    # Fraction of the way from t1 to t2 at which this image sits.
    timeRatio = timeImage.expression('(t - t1) / (t2 - t1)', {
                    't': timeImage.select('t'),
                    't1': timeImage.select('t1'),
                    't2': timeImage.select('t2'),
                })

    interpolated = beforeMosaic.add((afterMosaic.subtract(beforeMosaic).multiply(timeRatio)))
    # Keep existing values; use the interpolated value only where masked.
    result = image.unmask(interpolated)
    # Second pass: fill what is still masked from a mosaic of both sides.
    fill_value = ee.ImageCollection([beforeMosaic, afterMosaic]).mosaic()
    result = result.unmask(fill_value)

    return result.copyProperties(image, ['system:time_start'])


def interpolate_timeseries(S1_TS):
    """Mask low-quality pixels and fill them by temporal interpolation.

    Each image is passed through mask_low_QA and add_timestamp, then joined
    against the same collection twice to attach the images within a 120-day
    window after it (property 'after') and before it (property 'before').
    performInterpolation is finally mapped over the joined collection.

    Args:
        S1_TS: ee.ImageCollection with a 'SummaryQA' band and a
            'system:time_start' property on every image.

    Returns:
        ee.ImageCollection of interpolated images.
    """
    lsmc_masked = S1_TS.map(mask_low_QA)
    filtered = lsmc_masked.map(add_timestamp)

    # Time window (days) in which we are willing to look forward and backward
    # for an unmasked pixel in the time series.
    timeWindow = 120

    # Define a maxDifference filter to find all images within the specified days. Convert days to milliseconds.
    millis = ee.Number(timeWindow).multiply(1000*60*60*24)
    # Matches image pairs whose timestamps are at most `millis` apart.
    maxDiffFilter = ee.Filter.maxDifference(
                                difference = millis,
                                leftField = 'system:time_start',
                                rightField = 'system:time_start',
                                )

    # Matches secondary images whose timestamp is at or after the primary's.
    lessEqFilter = ee.Filter.lessThanOrEquals(
                                leftField = 'system:time_start',
                                rightField = 'system:time_start'
                            )

    # Matches secondary images whose timestamp is at or before the primary's.
    greaterEqFilter = ee.Filter.greaterThanOrEquals(
                                leftField = 'system:time_start',
                                rightField = 'system:time_start'
                            )

    # First join: attach, as 'after', all images at or after the target image
    # and within the time window, ordered by descending timestamp.
    filter1 = ee.Filter.And( maxDiffFilter, lessEqFilter )
    join1 = ee.Join.saveAll(
                    matchesKey = 'after',
                    ordering = 'system:time_start',
                    ascending = False
            )
    join1Result = join1.apply(
                    primary = filtered,
                    secondary = filtered,
                    condition = filter1
                    )

    # Second join: attach, as 'before', all images at or before the target
    # image and within the time window, ordered by ascending timestamp.
    filter2 = ee.Filter.And( maxDiffFilter, greaterEqFilter )
    join2 = ee.Join.saveAll(
                    matchesKey = 'before',
                    ordering = 'system:time_start',
                    ascending = True
            )
    join2Result = join2.apply(
                    primary = join1Result,
                    secondary = join1Result,
                    condition = filter2
                    )

    interpolated_S1_TS = ee.ImageCollection(join2Result.map(performInterpolation))

    return interpolated_S1_TS


'''
Function Definition to get Padded NDVI LSMC timeseries image for a given ROI
'''
def Get_Padded_NDVI_TS_Image(startDate, endDate, roi_boundary):
    """Build the gap-filled, interpolated NDVI time-series image for a region.

    Pipeline: load L7/L8/S2 -> harmonize -> add NDVI -> 16-day composites ->
    MODIS gap-fill -> temporal interpolation -> one band per time step.

    Returns:
        (image, band_names): the multi-band 'gapfilled_NDVI_lsc' image
        clipped to `roi_boundary`, and its band names.
    """
    L7, L8, S2 = Get_L7_L8_S2_ImageCollections(startDate, endDate, roi_boundary)

    with_ndvi = Harmonize_L7_L8_S2(L7, L8, S2).map(addNDVI)
    composites = Get_LS_16Day_NDVI_TimeSeries(startDate, endDate, with_ndvi, roi_boundary)
    gapfilled = Combine_LS_Modis(composites)
    interpolated = interpolate_timeseries(gapfilled)

    # Flatten the collection into a single image, one band per composite.
    stacked = interpolated.select(['gapfilled_NDVI_lsc']).toBands()
    stacked = stacked.clip(roi_boundary)

    return stacked, stacked.bandNames()


'''
Function definition to compute euclidean distance to each cluster centroid
features ---> clusters
flattened ---> time series image clipped to target area
input_bands ---> band names for time series image
studyarea ---> geometry of region of interest
'''
# Function to get distances as required from each pixel to each cluster centroid
def Get_Euclidean_Distance(cluster_centroids, roi_timeseries_img, input_bands, roi_boundary):
    """Compute, for every centroid, a per-pixel distance image.

    Args:
        cluster_centroids: collection of features carrying a 'class' property
            and one property per name in `input_bands`.
        roi_timeseries_img: time-series image to compare against.
        input_bands: ee.List of band/property names.
        roi_boundary: geometry assigned to each centroid feature.

    Returns:
        The result of mapping over `cluster_centroids`; each element is an
        image with the spectral-distance band, a 'confidence' band
        (1 / distance) and a 'class' band, plus a 'class' property.
    """

    def wrapper(curr_centroid):
        temp_img = ee.Image()
        # Give the centroid the ROI geometry so it can be rasterised over it.
        curr_centroid = ee.Feature(curr_centroid).setGeometry(roi_boundary)
        temp_fc = ee.FeatureCollection( [curr_centroid] )
        class_img = temp_fc.select(['class']).reduceToImage(['class'], ee.Reducer.first()).rename(['class'])
        def create_img(band_name):
            # Constant image holding the centroid's value for one band.
            return temp_fc.select([band_name]).reduceToImage([band_name], ee.Reducer.first()).rename([band_name])

        temp_img = input_bands.map(create_img)
        empty = ee.Image()
        # Fold the per-band images into one multi-band centroid image.
        temp_img = ee.Image( temp_img.iterate( lambda img, prev: ee.Image(prev).addBands(img) , empty))

        # Drop the 'constant' band contributed by the empty seed image.
        temp_img = temp_img.select(temp_img.bandNames().remove('constant'))
        # 'sed' metric of spectralDistance between time series and centroid.
        distance = roi_timeseries_img.spectralDistance(temp_img, 'sed')
        confidence = ee.Image(1.0).divide(distance).rename(['confidence'])
        distance = distance.addBands(confidence)
        return distance.addBands(class_img.rename(['class'])).set('class', curr_centroid.get('class'))

    return cluster_centroids.map(wrapper)


'''
Function definition to get final prediction image from distance images
'''
def Get_final_prediction_image(distance_imgs_list):
    """Pick, per pixel, the class of the centroid with the smallest distance.

    Args:
        distance_imgs_list: collection of images carrying 'distance' and
            'class' bands, one image per cluster centroid.

    Returns:
        Single-band image named 'predicted_label' holding the class of the
        nearest centroid.
    """
    distance_imgs_ic = ee.ImageCollection( distance_imgs_list ).select(['distance','class'])

    # 2-D array per pixel: axis 0 runs over images (centroids), axis 1 over
    # the two bands [distance, class].
    array_img = distance_imgs_ic.toArray()
    image_axis, band_axis = 0, 1

    # Sort the rows by the distance column.
    sort_keys = array_img.arraySlice(band_axis, 0, 1)
    ordered = array_img.arraySort(sort_keys)

    # Keep only the first row, i.e. the smallest distance.
    nearest = ordered.arraySlice(image_axis, 0, 1)
    # Convert back to a regular two-band image.
    nearest_img = nearest.arrayProject([band_axis]).arrayFlatten([['distance', 'class']])

    # Band 1 is the class of the nearest centroid: the hard classification.
    predicted_class_img = nearest_img.select(1).rename(['predicted_label'])

    return predicted_class_img

## My Helper Functions
def change_clusters(cluster_centroids):
    """Return a copy of the centroids with 'class' renumbered from 13 upward.

    The i-th feature (in collection order) receives class 13 + i. The
    collection size is fetched to the client with getInfo().
    """
    size = cluster_centroids.size().getInfo()
    # Build the server-side list once instead of once per loop iteration.
    centroid_list = cluster_centroids.toList(size)
    features = [
        ee.Feature(centroid_list.get(i)).set("class", 13+i)
        for i in range(size)
    ]
    return ee.FeatureCollection(features)


def get_cropping_frequency(roi_boundary, startDate, endDate):
    """Classify each pixel by its nearest NDVI cluster centroid.

    Args:
        roi_boundary: region of interest geometry.
        startDate: 'YYYY-MM-DD' start of the NDVI time series.
        endDate: 'YYYY-MM-DD' end of the NDVI time series.

    Returns:
        (final_classified_img, final_LSMC_NDVI_TS)
          final_classified_img: bands 'predicted_label' and
              'predicted_cluster'.
          final_LSMC_NDVI_TS: the NDVI time-series image used as input.
    """
    cluster_centroids = ee.FeatureCollection('projects/ee-indiasat/assets/L3_LULC_Clusters/Final_Level3_PanIndia_Clusters')
    ignore_clusters = [12] # remove invalid clusters
    cluster_centroids = cluster_centroids.filter(ee.Filter.Not( ee.Filter.inList('class', ignore_clusters)))

    final_LSMC_NDVI_TS, input_bands =  Get_Padded_NDVI_TS_Image(startDate, endDate, roi_boundary)
    distance_imgs_list = Get_Euclidean_Distance(cluster_centroids, final_LSMC_NDVI_TS, input_bands, roi_boundary)
    final_classified_img = Get_final_prediction_image(distance_imgs_list)

    # With the centroid renumbering disabled, the cluster prediction uses the
    # same centroids and the same time series as the label prediction, so the
    # first result is reused instead of building the same pipeline twice.
    # To get distinct cluster ids, renumber the centroids here and recompute:
    #cluster_centroids = change_clusters(cluster_centroids)
    final_cluster_classified_img = final_classified_img.rename(['predicted_cluster'])
    final_classified_img = final_classified_img.addBands(final_cluster_classified_img)
    return final_classified_img, final_LSMC_NDVI_TS

def get_six_year_cropping_frequency_rasters(roi_boundary, start_year, num_years=6):
    """Run get_cropping_frequency once per July-to-June year.

    Year k (0-based) covers (start_year + k)-07-01 to (start_year + k + 1)-06-30.

    Args:
        roi_boundary: region of interest geometry.
        start_year: calendar year in which the first window starts.
        num_years: number of consecutive yearly windows to process.

    Returns:
      crop_freq_imgs: Python list of ee.Image, each reduced to the band
                      'predicted_label', in chronological order.
      date_ranges:    Python list of (startDate, endDate) strings, aligned
                      with crop_freq_imgs.
    """
    # Build every window first, then classify each one.
    date_ranges = [
        (f"{start_year + offset}-07-01", f"{start_year + offset + 1}-06-30")
        for offset in range(num_years)
    ]

    crop_freq_imgs = []
    for window_start, window_end in date_ranges:
        # The NDVI time series returned alongside is not needed here.
        classified, _ = get_cropping_frequency(roi_boundary, window_start, window_end)
        crop_freq_imgs.append(classified.select(['predicted_label',]))

    return crop_freq_imgs, date_ranges


In [3]:
# --------- MERGED CLASS CONSTANTS & HELPERS ---------
# Position-wise lookup used by _group_from_orig: ORIG[i] -> GROUP[i].
# NOTE(review): these two lists repeat the definitions from the first cell
# with the same values; keep the two copies in sync.
ORIG  = [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13]
GROUP = [0, 1, 2, 2, 2, 6, 7, 8, 8,  8,  8,  12, 13]

WATER_SET = ee.List([2, 3, 4])          # original water subclasses (grouped as 2)
CROP_SET  = ee.List([8, 9, 10, 11])     # original crop subclasses (grouped as 8)

def _is_in_list(img, vals):
    """Return a 0/1 image that is 1 where `img` equals any value in `vals`.

    `vals` is an ee.List of ints.
    """
    ones = ee.List.repeat(1, vals.size())
    membership = ee.Image(img).remap(vals, ones, 0)
    return membership.eq(1)

def _group_from_orig(orig_img):
    """Map a single-band fine-label image to grouped labels (ORIG -> GROUP), as integers."""
    grouped = orig_img.remap(ORIG, GROUP)
    return grouped.toInt()

def _safe_water_source(orig_img):
    """Return `orig_img` with every non-water pixel replaced by 2.

    Pixels whose value is in WATER_SET (2/3/4) keep their value, so the
    output always holds a water subclass. The result is cast to integer.
    """
    not_water = _is_in_list(orig_img, WATER_SET).Not()
    return ee.Image(orig_img).where(not_water, 2).toInt()


def reinstate_merged_classes_year6(
    original6_img,          # ee.Image, original year-6 LULC (fine classes), band 'predicted_label'
    grouped_y6c,            # ee.Image, grouped corrected year-6 band (y6c)
    original5_img,          # ee.Image, corrected year-5 LULC (fine classes), band 'predicted_label'
    crop_intensity6_img=None  # ee.Image for crop frequency/intensity, band 'predicted_label'
):
    """Expand corrected grouped year-6 labels back to fine classes.

    Starting from the original fine year-6 labels:
      (A) Where the corrected grouped label is neither Water(2) nor Crop(8)
          and differs from the grouped original, the corrected grouped value
          is written as-is.
      (B) Where the corrected grouped label is Crop(8): keep the original
          value if it is already a crop subclass (8/9/10/11), otherwise take
          the value from `crop_intensity6_img` (or constant 8 when that
          image is None).
      (C) Where the corrected grouped label is Water(2): keep the original
          value if it is already a water subclass (2/3/4), otherwise take
          the year-5 value, falling back to 2 when year 5 is not water.

    Returns:
        Single-band image named 'y6f'.
    """
    # Fine labels of both years, and the two grouped views of year 6.
    fine6 = ee.Image(original6_img).select('predicted_label').toInt()
    fine5 = ee.Image(original5_img).select('predicted_label').toInt()
    grouped_original = _group_from_orig(fine6)
    grouped_corrected = ee.Image(grouped_y6c).toInt()

    water_corrected = grouped_corrected.eq(2)
    crop_corrected = grouped_corrected.eq(8)

    # ------------- (A) Non-merged classes -------------
    not_merged = water_corrected.Or(crop_corrected).Not()
    relabelled = not_merged.And(grouped_corrected.neq(grouped_original))
    result = fine6.where(relabelled, grouped_corrected)

    # ------------- (B) Crop re-introduction -------------
    if crop_intensity6_img is None:
        crop_candidate = ee.Image(8).toInt()  # fallback
    else:
        crop_candidate = ee.Image(crop_intensity6_img).select('predicted_label').toInt()

    already_crop = _is_in_list(fine6, CROP_SET)
    crop_values = fine6.where(already_crop.Not(), crop_candidate).toInt()
    result = result.where(crop_corrected, crop_values)

    # ------------- (C) Water re-introduction -------------
    # Year 5 acts as the temporal reference for the water subclass.
    already_water = _is_in_list(fine6, WATER_SET)
    water_reference = _safe_water_source(fine5)
    water_values = fine6.where(already_water.Not(), water_reference).toInt()
    result = result.where(water_corrected, water_values)

    return result.rename('y6f')


In [4]:
# ----- Year indices for the year-6 window -----
# NOTE(review): y6 and y7 rebind the names unpacked from
# load_y5_corrected_y6_y7_original in the first cell. The values agree only
# while total_years == 7; confirm before changing total_years.
y1 = start_year          # e.g., 2017
y6 = start_year + 5      # sixth year
y7 = y6 + 1              # next year (for date window)

# July-to-June window of the sixth year.
start6 = f"{y6}-07-01"
end6   = f"{y7}-06-30"

# 1. Recompute cropping frequency for year 6
crop_freq6_img, _ = get_cropping_frequency(
    roi_boundary=roi_boundary,
    startDate=start6,
    endDate=end6
)
# crop_freq6_img has bands 'predicted_label' and 'predicted_cluster'

# 2. Reinstate merged classes for year 6, using:
#    - img6_orig: original fine LULC for year 6
#    - corr5_pl:  corrected fine LULC for year 5
#    - seq_last3_corrected: image with band 'y6c' (grouped corrected year 6)
final_y6 = reinstate_merged_classes_year6(
    original6_img=img6_orig,
    grouped_y6c=seq_last3_corrected.select('y6c'),
    original5_img=corr5_pl,
    crop_intensity6_img=crop_freq6_img   # recomputed cropping intensity for year 6
)

# (Optional) visualize
# Same palette and visualisation parameters as defined in the first cell.
pallete_lulc = [
  '000000','ff0000','74ccf4','1ca3ec','0f5e9c',
  'f1c232','38761d','A9A9A9','BAD93E','f59d22',
  'FF9371','b3561d','a9a9a9','84994F'
]
vis_params_lulc = {'min': 0, 'max': 13, 'palette': pallete_lulc}

# Relies on the `Map` widget created in the first cell.
Map.addLayer(final_y6, vis_params_lulc, 'year 6 final (un-grouped with new crop freq)')
Map


Map(bottom=912.0, center=[0, 0], controls=(WidgetControl(options=['position', 'transparent_bg'], position='top…

In [5]:
# --- PARAMETERS ---
project    = 'raman-461708'
scale      = 10
maxPixels  = 1e13

# Inputs expected from earlier cells:
#   aez
#   start_year
#   roi_boundary
#   final_y6  (image with band 'y6f')

# Compute the year-span for the 6th year
y6 = start_year + 5      # sixth year start
y7 = y6 + 1              # sixth year end

date_tag = f"{y6}-07-01_{y7}-06-30"

# Asset naming convention: AEZ_{aez}_{date_tag}_temporal_corrected
asset_id = f"projects/{project}/assets/AEZ_{aez}_{date_tag}_temporal_corrected"

# Task description
desc = f"AEZ_{aez}_{y6}_{y7}_temporal_corrected"

# Start export.
# The image holds categorical class codes. Asset exports build their
# reduced-resolution pyramid levels with a mean by default, which would blend
# class codes into meaningless values; 'mode' keeps every level a valid class.
task = ee.batch.Export.image.toAsset(
    image     = final_y6.clip(roi_boundary),  # should contain band 'y6f'
    description = desc,
    assetId   = asset_id,
    pyramidingPolicy = {'.default': 'mode'},
    region    = roi_boundary,
    scale     = scale,
    maxPixels = maxPixels
)

task.start()
print(f"Started export: {asset_id}")


Started export: projects/raman-461708/assets/AEZ_1_2022-07-01_2023-06-30_temporal_corrected
