In [1]:
import ee

# Authenticate this session with Earth Engine and initialise the client for
# the 'raman-461708' Cloud project. The same project id is repeated below as
# `project` and as the default of load_aez_temporal_images -- keep them in sync.
ee.Authenticate()
ee.Initialize(project='raman-461708')


# -------------------------------
# 1. LOADING
# -------------------------------

def load_aez_temporal_images(aez, start_year, num_years, roi_boundary, project='raman-461708'):
    """Load one LULC_v4 image per agricultural year, clipped to *roi_boundary*.

    Year ``k`` (``start_year <= k < start_year + num_years``) is read from the
    asset ``projects/<project>/assets/LULC_v4_PanIndia_<k>-07-01_<k+1>-06-30``.

    ``aez`` is accepted but not used in the body; presumably the zone is
    already encoded in *roi_boundary* by the caller.

    Returns a list of ``num_years`` clipped ``ee.Image`` objects (empty when
    ``num_years <= 0``).
    """
    prefix = f"projects/{project}/assets/LULC_v4_PanIndia_"
    return [
        ee.Image(f"{prefix}{year}-07-01_{year + 1}-06-30").clip(roi_boundary)
        for year in range(start_year, start_year + num_years)
    ]

# Run configuration: agro-ecological region code, first agricultural year,
# number of years, and the Earth Engine Cloud project holding the assets.
aez = 19
start_year = 2017
num_years = 6
project = 'raman-461708'

# Geometry of the selected agro-ecological region.
roi_boundary = (
    ee.FeatureCollection("users/mtpictd/agro_eco_regions")
    .filter(ee.Filter.eq("ae_regcode", aez))
    .geometry()
)

# 1. Load one yearly LULC image per agricultural year, clipped to the region.
images = load_aez_temporal_images(aez, start_year, num_years, roi_boundary, project)


# -----------------------------
# 0) Class grouping
# -----------------------------
# Parallel lists: a pixel whose label is ORIG[i] becomes GROUP[i]
# (2/3/4 collapse to 2 and 8/9/10/11 collapse to 8).
ORIG =  [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13]
GROUP = [0, 1, 2, 2, 2, 6, 7, 8, 8,  8,  8,  12, 13]

BAND = 'predicted_label'


def group_classes(img, band=BAND):
    """Collapse the LULC labels in *band* using the ORIG -> GROUP mapping.

    No default value is passed to ``remap``, so labels absent from ORIG are
    masked (per the Earth Engine ``remap`` documentation). The output band is
    always named 'predicted_label', which is what stack_6_years selects.
    """
    return img.select([band]).remap(ORIG, GROUP).rename('predicted_label')


# ============================================================
# 0.5) ADMISSIBILITY MATRIX (applied during detection)
# ============================================================
# Grouped classes that take part in the matrix; background (0) is left out.
# Matrix index order: Built-up, Water, Tree, Barren, Crop, Scrub, Plantation.
MAT_CLASSES = [1, 2, 6, 7, 8, 12, 13]

# Flat, row-major 7x7 table. Row = target class A, column = current class B;
# an entry of 1 means flipping B -> A is admissible, 0 means it is not.
ALLOW_TABLE = [
    # B:  Bu Wa Tr Ba Cr Sc Pl          A (target)
          0, 0, 0, 1, 1, 1, 1,        # Built-up
          0, 0, 1, 0, 0, 0, 0,        # Water
          1, 1, 0, 1, 1, 1, 1,        # Tree/Forest
          1, 0, 1, 0, 1, 1, 1,        # Barren
          0, 0, 1, 1, 0, 1, 0,        # Crop
          0, 0, 1, 1, 1, 0, 1,        # Scrub
          1, 0, 1, 1, 0, 1, 0,        # Plantation
]

# Reshape the flat table into a 7x7 ee.Array (row = target A, column =
# current B), then wrap it in a constant array-valued image so allow_flip()
# can look entries up per pixel with arrayGet.
ALLOW_TABLE_ARR = ee.Array(ALLOW_TABLE).reshape([len(MAT_CLASSES), len(MAT_CLASSES)])
ALLOW_TABLE_IMG = ee.Image.constant(ALLOW_TABLE_ARR)


def _class_to_index(x):
    """Map grouped class codes to their 0..6 matrix index; any other value -> -1."""
    indices = list(range(len(MAT_CLASSES)))
    return ee.Image(x).remap(MAT_CLASSES, indices, -1).toInt()

def allow_flip(A, B):
    """Per-pixel admissibility of flipping class B to class A.

    A and B are grouped-class images. The result (band 'allowFlip') is
    ALLOW_TABLE[A][B] where both classes belong to MAT_CLASSES, and 0
    elsewhere (including pixels where either class is outside the matrix).
    """
    row = _class_to_index(A)
    col = _class_to_index(B)
    both_in_matrix = row.gte(0).And(col.gte(0))
    lookup = ALLOW_TABLE_IMG.arrayGet([row, col]).toInt()
    return lookup.updateMask(both_in_matrix).unmask(0).rename('allowFlip')


# ----------------------------------------------------------
# 1) Stack the 6 yearly layers into one 6-band image: y1..y6
# ----------------------------------------------------------
def stack_6_years(year_imgs):
    """Concatenate the 'predicted_label' band of six yearly images into y1..y6.

    Raises ValueError unless exactly six images are supplied.
    """
    if len(year_imgs) != 6:
        raise ValueError("Provide exactly 6 yearly images.")
    return ee.Image.cat([
        img.select(['predicted_label']).rename(f'y{year}')
        for year, img in enumerate(year_imgs, start=1)
    ])


# ----------------------------------------------------------
# 2) Bit-mask helpers
# ----------------------------------------------------------
def _count_ones_nbit(mask, nbits):
    """Number of set bits among the lowest *nbits* bits of *mask* ('singletCount')."""
    per_bit = [mask.rightShift(pos).bitwiseAnd(1) for pos in range(nbits)]
    total = ee.Image.cat(per_bit).reduce(ee.Reducer.sum())
    return total.rename('singletCount').toInt()

def _has_adjacent_ones(mask):
    """1 where *mask* has two neighbouring set bits ('hasAdj11'), else 0."""
    shifted = mask.leftShift(1)
    return mask.bitwiseAnd(shifted).neq(0).rename('hasAdj11')

def _has_run_k(mask, nbits, k):
    """1 where *mask* contains k consecutive set bits within its lowest *nbits* bits."""
    window = (1 << k) - 1
    found = ee.Image(0)
    for offset in range(nbits - k + 1):
        pattern = ee.Image(window << offset).toInt()
        found = found.Or(mask.bitwiseAnd(pattern).eq(pattern))
    return found

def _max_run_length(mask, nbits):
    """Length of the longest run of set bits in *mask* ('singletMaxRun'); 0 if none."""
    longest = ee.Image(0).where(mask.neq(0), 1)
    for length in range(2, nbits + 1):
        longest = longest.where(_has_run_k(mask, nbits, length), length)
    return longest.rename('singletMaxRun').toInt()


# ----------------------------------------------------------
# 3) Singlet categories WITH MATRIX constraint
#    Only count an ABA center if flipping mid->left is allowed.
# ----------------------------------------------------------
def singlet_category_6yr(seq6):
    """
    Detect admissible one-year "singlet" flips (A B A) in a 6-year sequence
    and summarise how the flagged centre years are arranged.

    seq6: image with bands y1..y6 (grouped class per year).

    A centre year t (t = 2..5) is flagged when y[t-1] == y[t+1] != y[t] AND
    allow_flip(y[t-1], y[t]) permits flipping the centre back to its
    neighbours' class. The four flags s2..s5 are packed into a 4-bit mask
    (s2 -> bit 3, s5 -> bit 0) and classified as:
        0  no admissible singlet
        1  flagged centres present, none adjacent (disjoint singles)
        2  longest run of adjacent centres is 2 (and not composite)
        3  longest run is 3
        4  longest run is 4 (all of s2..s5)
        6  composite: a run of 2 plus at least one isolated centre
           (masks 1101 or 1011)
    Category 5 is never produced.

    Returns an image with bands: singletCategory, singletSeqMask,
    singletCount, hasAdj11, singletMaxRun, isolatedSingletMask, s2..s5.
    """
    y = [seq6.select(f'y{i}') for i in range(1, 7)]  # y1..y6

    centers = []
    for idx in range(1, 5):  # centers y2..y5
        left  = y[idx - 1]
        mid   = y[idx]
        right = y[idx + 1]

        raw = left.eq(right).And(mid.neq(left))  # ABA local condition

        # matrix constraint: only keep this center if flip mid -> left is allowed
        adm = allow_flip(left, mid).eq(1)

        s = raw.And(adm).rename(f's{idx+1}').toUint8()  # s2..s5 (admissible)
        centers.append(s)

    # Pack s2..s5 into 4-bit mask (s2 MSB bit3, s5 LSB bit0)
    mask = ee.Image(0).toInt()
    for j, s in enumerate(centers):  # j=0..3 => s2..s5
        mask = mask.bitwiseOr(s.toInt().leftShift(3 - j))
    singletSeqMask = mask.rename('singletSeqMask')

    singletCount = _count_ones_nbit(singletSeqMask, 4).rename('singletCount')
    anySinglet = singletSeqMask.neq(0)

    adj11 = _has_adjacent_ones(singletSeqMask)
    maxRun = _max_run_length(singletSeqMask, 4).rename('singletMaxRun')

    # Bit k of bits_in_adj is set when bits k and k-1 are both set. OR-ing it
    # with its own right shift marks both members of every adjacent pair, so
    # the bits of singletSeqMask NOT covered are the isolated centres.
    bits_in_adj = singletSeqMask.bitwiseAnd(singletSeqMask.leftShift(1))
    bits_in_adj_covered = bits_in_adj.bitwiseOr(bits_in_adj.rightShift(1))
    isolatedMask = singletSeqMask.bitwiseAnd(bits_in_adj_covered.bitwiseNot()).rename('isolatedSingletMask')

    hasIsolated = isolatedMask.neq(0)
    isMax2 = maxRun.eq(2)

    composite = anySinglet.And(isMax2).And(adj11).And(hasIsolated)
    disjointSingles = anySinglet.And(adj11.Not())

    # Later .where() calls override earlier ones: composite (6) wins over the
    # run-2 label, and pixels with no singlet at all are forced back to 0.
    cat = ee.Image(0).toInt() \
        .where(disjointSingles, 1) \
        .where(isMax2.And(disjointSingles.Not()).And(composite.Not()), 2) \
        .where(maxRun.eq(3), 3) \
        .where(maxRun.eq(4), 4) \
        .where(composite, 6) \
        .where(anySinglet.Not(), 0) \
        .rename('singletCategory')

    return ee.Image.cat([
        cat,
        singletSeqMask,
        singletCount,
        adj11,
        maxRun,
        isolatedMask.toInt(),
        ee.Image.cat(centers)  # s2..s5 (admissible)
    ])


# ----------------------------------------------------------
# 4) Doublet patterns WITH MATRIX constraint
#    Keep a doublet only if BOTH X->A and Y->A are allowed.
# ----------------------------------------------------------
def detect_doublet_6yr(seq6):
    """Detect admissible two-year "doublet" excursions in a 6-year sequence.

    Patterns (A, X, Y pairwise distinct; '*' = any class):
        1: A X Y A A *   (A = y1)
        2: A A X Y A A   (A = y1)
        3: * A A X Y A   (A = y2)
    A pattern is kept only where allow_flip permits both X -> A and Y -> A.
    The three patterns cannot hold at the same pixel.

    Returns bands 'doubletPattern' (0 = none, else 1..3) and 'hasDoublet'.
    """
    yr = {i: seq6.select(f'y{i}') for i in range(1, 7)}

    def _admissible_pattern(anchor, x_year, y_year, repeat_years):
        # A must reappear in every year of repeat_years; X, Y sit in between.
        A = yr[anchor]
        X = yr[x_year]
        Y = yr[y_year]
        cond = yr[repeat_years[0]].eq(A)
        for year in repeat_years[1:]:
            cond = cond.And(yr[year].eq(A))
        all_distinct = X.neq(Y).And(X.neq(A)).And(Y.neq(A))
        cond = cond.And(all_distinct)
        return cond.And(allow_flip(A, X).eq(1)).And(allow_flip(A, Y).eq(1))

    p1 = _admissible_pattern(1, 2, 3, (4, 5))      # A X Y A A *
    p2 = _admissible_pattern(1, 3, 4, (2, 5, 6))   # A A X Y A A
    p3 = _admissible_pattern(2, 4, 5, (3, 6))      # * A A X Y A

    pattern_code = ee.Image(0).toInt() \
        .where(p1, 1) \
        .where(p2, 2) \
        .where(p3, 3) \
        .rename('doubletPattern')
    has_doublet = p1.Or(p2).Or(p3).rename('hasDoublet')
    return pattern_code.addBands(has_doublet)


# ----------------------------------------------------------
# 5) Main wrapper
# ----------------------------------------------------------
def classify_temporal_flip_categories_6yr(year_imgs, roi):
    """Group classes, stack y1..y6, and attach singlet/doublet diagnostics.

    The singlet and doublet bands are masked outside *roi*; the y1..y6 bands
    are returned as stacked (the inputs are expected to be clipped already).
    """
    seq6 = stack_6_years([group_classes(im, band=BAND) for im in year_imgs])
    inside_roi = ee.Image.constant(1).clip(roi).selfMask()

    singlet_bands = singlet_category_6yr(seq6).updateMask(inside_roi)
    doublet_bands = detect_doublet_6yr(seq6).updateMask(inside_roi)
    return seq6.addBands(singlet_bands).addBands(doublet_bands)


# -----------------------------
# Visualization parameters
# -----------------------------
# singletCategory: 0 none, 1 disjoint singles, 2/3/4 = longest run length,
# 5 unused (only centres y2..y5 exist, so a run of 5 cannot occur), 6 composite.
singlet_vis = {
    'min': 0,
    'max': 6,
    'palette': ['000000', '2ecc71', 'f1c40f', 'e67e22', 'e74c3c', '8e44ad', '34495e'],
}

# doubletPattern: 0 none, 1 = A X Y A A *, 2 = A A X Y A A, 3 = * A A X Y A.
doublet_vis_6yr = {
    'min': 0,
    'max': 3,
    'palette': ['000000', '00bcd4', 'ff9800', 'e91e63'],
}

# One colour per LULC label 0..13.
pallete_lulc = [
    '000000', 'ff0000', '74ccf4', '1ca3ec', '0f5e9c',
    'f1c232', '38761d', 'A9A9A9', 'BAD93E', 'f59d22',
    'FF9371', 'b3561d', 'a9a9a9', '84994F',
]

vis_params_lulc = dict(min=0, max=13, palette=pallete_lulc)


# -----------------------------
# Run
# -----------------------------
out = classify_temporal_flip_categories_6yr(images, roi_boundary)

import geemap

# Interactive map with a Google tile layer as basemap.
Map = geemap.Map()
url = 'https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}'
Map.layout.height = '1000px'
Map.add_tile_layer(url, name="Google Map", attribution="Google")

# Diagnostic layers first, then the grouped label for each year.
Map.addLayer(out.select('singletCategory'), singlet_vis, 'singletCategory (matrix-filtered)')
Map.addLayer(out.select('doubletPattern'), doublet_vis_6yr, 'doubletPattern (matrix-filtered)')
for _year in range(1, 7):
    Map.addLayer(out.select(f'y{_year}'), vis_params_lulc, f'y{_year}')

Map


Map(center=[0, 0], controls=(WidgetControl(options=['position', 'transparent_bg'], position='topright', transp…

In [2]:
# ============================================================
# 6) APPLY TEMPORAL CORRECTIONS (6-year) USING THE RULES BELOW
# ============================================================

def _eq(img, val):
    """Pixel-wise equality test; a thin wrapper around ``img.eq(val)``."""
    result = img.eq(val)
    return result

def _count_eq_6(y_list, val_img):
    """Per-pixel count of images in *y_list* equal to *val_img* (as int)."""
    total = ee.Image(0)
    for band in y_list:
        total = total.add(band.eq(val_img))
    return total.toInt()

def _apply_where_band(seq6, band_name, cond, new_val):
    """Return *seq6* with *band_name* set to *new_val* where *cond* is true.

    The band order of *seq6* is preserved.
    """
    names = seq6.bandNames()
    updated = seq6.select(band_name).where(cond, new_val).rename(band_name)
    untouched = seq6.select(names.remove(band_name))
    return ee.Image.cat([untouched, updated]).select(names)

def _apply_flip_to_A(seq6, cond, A):
    """Under *cond*, set every year y1..y6 that differs from *A* to *A*."""
    result = seq6
    for band in (f'y{i}' for i in range(1, 7)):
        current = result.select(band)
        result = _apply_where_band(result, band, cond.And(current.neq(A)), A)
    return result

def _apply_flip_B_to_A_in_mask(seq6, cond, A, B, years):
    """Under *cond*, replace class *B* by class *A* in the listed years.

    *years* are 1-based (e.g. [1, 2, 3, 4]). Only pixels equal to B change;
    A values and any other class are left untouched.
    """
    result = seq6
    for year in years:
        name = f'y{year}'
        is_B = result.select(name).eq(B)
        result = _apply_where_band(result, name, cond.And(is_B), A)
    return result


def _unpack_centers_from_mask_4bit(mask4):
    """Split a 4-bit centre mask into 0/1 images (s2, s3, s4, s5).

    Bit layout matches singlet_category_6yr: s2 is bit 3, ..., s5 is bit 0.
    """
    return tuple(
        mask4.rightShift(3 - k).bitwiseAnd(1).rename(f's{k + 2}')
        for k in range(4)
    )


def correct_singlets_6yr(seq6, sing):
    """
    Correct singlet (ABA-type) flips in a 6-year sequence, per singletCategory.

    seq6: image with bands y1..y6.
    sing: must contain 'singletCategory', 'singletSeqMask' and
          'isolatedSingletMask' as produced by singlet_category_6yr (so the
          centres are already admissibility-filtered).

    Rules by category:
      1 (disjoint singles): each flagged centre takes its left neighbour's class.
      3 (run of 3): flip B -> A inside the 5-year ABABA segment.
      4 (run of 4): flip B -> A (A = y1, B = y2) across all six years.
      2 (run of 2): majority of A vs B over all six years wins (tie -> A);
                    the minority class is flipped across all six years.
      6 (composite): fix isolated centres first, recompute singlets on the
                     result, then apply the run-2 rule to what remains.
    Returns the corrected y1..y6 image.
    """
    # NOTE(review): `y` is not used anywhere below.
    y = [seq6.select(f'y{i}') for i in range(1, 7)]
    cat = sing.select('singletCategory')
    mask4 = sing.select('singletSeqMask').toInt()           # 4-bit s2..s5
    iso4  = sing.select('isolatedSingletMask').toInt()      # subset of mask4 (same bit layout)
    s2, s3, s4, s5 = _unpack_centers_from_mask_4bit(mask4)
    i2, i3, i4, i5 = _unpack_centers_from_mask_4bit(iso4)

    out = seq6

    # -------------------------
    # CASE 1: disjoint singles
    # Flip each isolated center year to its surrounding A (left==right)
    # centers are y2..y5 (s2..s5)
    # -------------------------
    case1 = cat.eq(1)

    # For each center t, set y[t] = y[t-1] (which equals y[t+1] by ABA)
    # Only apply where the center is active AND case1
    out = _apply_where_band(out, 'y2', case1.And(s2.eq(1)), out.select('y1'))
    # Reads the already-updated y2, but in case 1 adjacent centres never both
    # fire, so y2 is unchanged wherever s3 is set.
    out = _apply_where_band(out, 'y3', case1.And(s3.eq(1)), out.select('y2'))  # note: uses updated y2 in out
    out = _apply_where_band(out, 'y4', case1.And(s4.eq(1)), out.select('y3'))
    out = _apply_where_band(out, 'y5', case1.And(s5.eq(1)), out.select('y4'))

    # -------------------------
    # CASE 3: run length 3 (ABABA*)
    # Winner is A in that segment; flip Bs -> A inside the 5-year segment.
    # Possible placements of 3 consecutive centers:
    #  - (s2,s3,s4)=1  => segment years 1..5 (A=y1, B=y2)
    #  - (s3,s4,s5)=1  => segment years 2..6 (A=y2, B=y3)
    # -------------------------
    case3 = cat.eq(3)
    run3_123 = case3.And(s2.eq(1)).And(s3.eq(1)).And(s4.eq(1))   # centers 2,3,4
    run3_234 = case3.And(s3.eq(1)).And(s4.eq(1)).And(s5.eq(1))   # centers 3,4,5

    # placement 1: years 1..5, A=y1, B=y2
    A = out.select('y1')
    B = out.select('y2')
    out = _apply_flip_B_to_A_in_mask(out, run3_123, A, B, years=[1,2,3,4,5])

    # placement 2: years 2..6, A=y2, B=y3
    A = out.select('y2')
    B = out.select('y3')
    out = _apply_flip_B_to_A_in_mask(out, run3_234, A, B, years=[2,3,4,5,6])

    # -------------------------
    # CASE 4: run length 4 (ABABAB)
    # Tie -> choose A and flip all B -> A across years 1..6 (A=y1, B=y2)
    # This only happens when all s2..s5 are 1.
    # -------------------------
    case4 = cat.eq(4)
    run4 = case4.And(s2.eq(1)).And(s3.eq(1)).And(s4.eq(1)).And(s5.eq(1))
    A = out.select('y1')
    B = out.select('y2')
    out = _apply_flip_B_to_A_in_mask(out, run4, A, B, years=[1,2,3,4,5,6])

    # -------------------------
    # CASE 2: run length 2 (ABAB**)
    # Use full-sequence majority between A and B; tie -> choose A.
    #
    # Possible placements:
    #  - (s2,s3)=1 => ABAB in years 1..4, A=y1, B=y2
    #  - (s3,s4)=1 => ABAB in years 2..5, A=y2, B=y3
    #  - (s4,s5)=1 => ABAB in years 3..6, A=y3, B=y4
    # -------------------------
    case2 = cat.eq(2)
    run2_12 = case2.And(s2.eq(1)).And(s3.eq(1)).And(s4.eq(0)).And(s5.eq(0))
    run2_23 = case2.And(s3.eq(1)).And(s4.eq(1)).And(s2.eq(0)).And(s5.eq(0))
    run2_34 = case2.And(s4.eq(1)).And(s5.eq(1)).And(s2.eq(0)).And(s3.eq(0))

    def _resolve_run2(out_img, cond, A_band, B_band):
        # Resolve an A/B run of 2 under `cond`: count A and B over all six
        # years and flip the minority class to the majority (tie -> A).
        A = out_img.select(A_band)
        B = out_img.select(B_band)

        # global counts across 6 years
        ylist = [out_img.select(f'y{i}') for i in range(1, 7)]
        cA = _count_eq_6(ylist, A)
        cB = _count_eq_6(ylist, B)

        A_wins = cA.gt(cB)
        B_wins = cB.gt(cA)
        tie    = cA.eq(cB)

        # if strict majority: flip the minority to majority (across all 6 years)
        # (the second call swaps the roles, i.e. flips A -> B where B wins)
        out2 = out_img
        out2 = _apply_flip_B_to_A_in_mask(out2, cond.And(A_wins), A, B, years=[1,2,3,4,5,6])
        out2 = _apply_flip_B_to_A_in_mask(out2, cond.And(B_wins), B, A, years=[1,2,3,4,5,6])

        # if tie: choose A and flip B->A
        out2 = _apply_flip_B_to_A_in_mask(out2, cond.And(tie), A, B, years=[1,2,3,4,5,6])
        return out2

    out = _resolve_run2(out, run2_12, 'y1', 'y2')
    out = _resolve_run2(out, run2_23, 'y2', 'y3')
    out = _resolve_run2(out, run2_34, 'y3', 'y4')

    # -------------------------
    # CASE 6: composite
    # Strategy: fix isolated (odd/disjoint) first, then recompute singlets and fix even (run2) next.
    # -------------------------
    case6 = cat.eq(6)

    # Pass-1: apply isolated centers only (from isolatedSingletMask bits i2..i5)
    tmp = out
    tmp = _apply_where_band(tmp, 'y2', case6.And(i2.eq(1)), tmp.select('y1'))
    tmp = _apply_where_band(tmp, 'y3', case6.And(i3.eq(1)), tmp.select('y2'))
    tmp = _apply_where_band(tmp, 'y4', case6.And(i4.eq(1)), tmp.select('y3'))
    tmp = _apply_where_band(tmp, 'y5', case6.And(i5.eq(1)), tmp.select('y4'))

    # Recompute singlet info on updated sequence, then apply run2 logic again but only where case6
    sing2 = singlet_category_6yr(tmp)
    cat2  = sing2.select('singletCategory')
    mask4b = sing2.select('singletSeqMask').toInt()
    t2, t3, t4, t5 = _unpack_centers_from_mask_4bit(mask4b)

    # Only handle the run2 part in pass-2 (category 2 equivalent) under original case6 pixels.
    # Any other residual category after pass-1 is left as-is.
    case6_run2 = case6.And(cat2.eq(2))

    run2b_12 = case6_run2.And(t2.eq(1)).And(t3.eq(1)).And(t4.eq(0)).And(t5.eq(0))
    run2b_23 = case6_run2.And(t3.eq(1)).And(t4.eq(1)).And(t2.eq(0)).And(t5.eq(0))
    run2b_34 = case6_run2.And(t4.eq(1)).And(t5.eq(1)).And(t2.eq(0)).And(t3.eq(0))

    tmp = _resolve_run2(tmp, run2b_12, 'y1', 'y2')
    tmp = _resolve_run2(tmp, run2b_23, 'y2', 'y3')
    tmp = _resolve_run2(tmp, run2b_34, 'y3', 'y4')

    # Merge composite-corrected pixels back into out
    # (only replace where case6; otherwise keep out)
    out = out.where(case6, tmp)

    return out


def correct_doublets_6yr(seq6, dbl):
    """Replace the two excursion years of each detected doublet with class A.

      pattern 1: A X Y A A *  (A = y1) -> y2, y3 set to A
      pattern 2: A A X Y A A  (A = y1) -> y3, y4 set to A
      pattern 3: * A A X Y A  (A = y2) -> y4, y5 set to A

    *dbl* must carry 'doubletPattern' and 'hasDoublet' from detect_doublet_6yr
    (already admissibility-filtered). Rules are applied in pattern order, and
    the anchor band is re-read after the previous rule has been applied.
    """
    pattern = dbl.select('doubletPattern').toInt()
    has_doublet = dbl.select('hasDoublet').eq(1)

    rules = (
        (1, 'y1', ('y2', 'y3')),
        (2, 'y1', ('y3', 'y4')),
        (3, 'y2', ('y4', 'y5')),
    )

    result = seq6
    for code, anchor_band, targets in rules:
        cond = has_doublet.And(pattern.eq(code))
        anchor = result.select(anchor_band)
        for band in targets:
            result = _apply_where_band(result, band, cond, anchor)
    return result


def apply_temporal_corrections_6yr(out_img):
    """Apply singlet then doublet corrections to a classified 6-year stack.

    *out_img* is the output of
    classify_temporal_flip_categories_6yr(year_imgs, roi): bands y1..y6 plus
    the singlet/doublet diagnostic bands.

    Returns *out_img* with six extra bands y1c..y6c holding the corrected
    sequence (existing bands of those names are overwritten).
    """
    year_bands = ['y1', 'y2', 'y3', 'y4', 'y5', 'y6']
    seq6 = out_img.select(year_bands)
    sing = out_img.select(['singletCategory', 'singletSeqMask', 'isolatedSingletMask'])
    dbl = out_img.select(['doubletPattern', 'hasDoublet'])

    # Singlets first, then doublets.
    corrected = correct_doublets_6yr(correct_singlets_6yr(seq6, sing), dbl)
    corrected_names = [f'{name}c' for name in year_bands]
    return out_img.addBands(corrected.rename(corrected_names), overwrite=True)


In [3]:
corrected = apply_temporal_corrections_6yr(out)

# Add the corrected yearly labels (y1c..y6c) to the existing map.
for _year in range(1, 7):
    Map.addLayer(corrected.select(f'y{_year}c'), vis_params_lulc, f'y{_year} corrected')
Map

Map(center=[0, 0], controls=(WidgetControl(options=['position', 'transparent_bg'], position='topright', transp…

In [4]:

from datetime import datetime, timedelta
import pandas as pd
from dateutil.relativedelta import relativedelta
# Common band names every sensor is renamed to before harmonization.
chastainBandNames = ['BLUE', 'GREEN', 'RED', 'NIR', 'SWIR1', 'SWIR2']

# Cross-sensor TOA regressions, one slope/intercept per band in the order of
# chastainBandNames.
# Table-4: MSI TOA reflectance as a function of OLI TOA reflectance.
msiOLISlopes = [1.0946, 1.0043, 1.0524, 0.8954, 1.0049, 1.0002]
msiOLIIntercepts = [-0.0107, 0.0026, -0.0015, 0.0033, 0.0065, 0.0046]

# Table-5: MSI TOA reflectance as a function of ETM+ TOA reflectance.
msiETMSlopes = [1.10601, 0.99091, 1.05681, 1.0045, 1.03611, 1.04011]
msiETMIntercepts = [-0.0139, 0.00411, -0.0024, -0.0076, 0.00411, 0.00861]

# Table-6: OLI TOA reflectance as a function of ETM+ TOA reflectance.
oliETMSlopes = [1.03501, 1.00921, 1.01991, 1.14061, 1.04351, 1.05271]
oliETMIntercepts = [-0.0055, -0.0008, -0.0021, -0.0163, -0.0045, 0.00261]

# Keyed '<FROM>_<TO>' -> [slopes, intercepts, direction].
# direction 0: apply the table forward (slope * x + intercept, dir0Regression),
#              used when FROM is the table's predictor sensor;
# direction 1: invert it ((x - intercept) / slope, dir1Regression),
#              used when FROM is the sensor the table models.
chastainCoeffDict = {
    'MSI_OLI': [msiOLISlopes, msiOLIIntercepts, 1],
    'MSI_ETM': [msiETMSlopes, msiETMIntercepts, 1],
    'OLI_ETM': [oliETMSlopes, oliETMIntercepts, 1],
    'OLI_MSI': [msiOLISlopes, msiOLIIntercepts, 0],
    'ETM_MSI': [msiETMSlopes, msiETMIntercepts, 0],
    'ETM_OLI': [oliETMSlopes, oliETMIntercepts, 0],
}

def maskL7cloud(image):
    """Mask cloud and cloud-shadow pixels in a Landsat-7 Collection-2 TOA image.

    Collection-2 ``QA_PIXEL`` flags: bit 3 = cloud, bit 4 = cloud shadow.
    Testing bit 4 alone would remove only cloud *shadow* and let clouds
    through, so both flags must be clear for a pixel to be kept.

    Returns the six reflective bands renamed to BLUE, GREEN, RED, NIR,
    SWIR1, SWIR2.
    """
    qa = image.select('QA_PIXEL')
    cloud_bit = 1 << 3
    shadow_bit = 1 << 4
    mask = qa.bitwiseAnd(cloud_bit | shadow_bit).eq(0)
    return image.updateMask(mask).select(['B1', 'B2', 'B3', 'B4', 'B5', 'B7']).rename('BLUE', 'GREEN', 'RED', 'NIR', 'SWIR1', 'SWIR2')


def maskL8cloud(image):
    """Mask cirrus, cloud and cloud-shadow pixels in a Landsat-8 Collection-2 TOA image.

    Collection-2 ``QA_PIXEL`` flags: bit 2 = cirrus, bit 3 = cloud,
    bit 4 = cloud shadow. Testing bit 4 alone would remove only cloud
    *shadow*, so all three flags must be clear for a pixel to be kept.
    Cirrus is included to match the Sentinel-2 masking (QA60 bit 11).

    Returns the six reflective bands renamed to BLUE, GREEN, RED, NIR,
    SWIR1, SWIR2.
    """
    qa = image.select('QA_PIXEL')
    cirrus_bit = 1 << 2
    cloud_bit = 1 << 3
    shadow_bit = 1 << 4
    mask = qa.bitwiseAnd(cirrus_bit | cloud_bit | shadow_bit).eq(0)
    return image.updateMask(mask).select(['B2', 'B3', 'B4', 'B5', 'B6', 'B7']).rename('BLUE', 'GREEN', 'RED', 'NIR', 'SWIR1', 'SWIR2')


def maskS2cloudTOA(image):
    """Mask clouds/cirrus in a Sentinel-2 TOA image and return 0-1 reflectance.

    QA60 bit 10 flags opaque clouds and bit 11 flags cirrus; a pixel is kept
    only when both are clear.

    COPERNICUS/S2_HARMONIZED stores TOA reflectance scaled by 10000, whereas
    the Landsat TOA collections and the Chastain intercepts are in 0-1
    reflectance. The bands are therefore rescaled here, before
    harmonizationChastain is applied.
    NOTE(review): QA60 is reportedly unpopulated for part of this collection's
    record; confirm cloud masking for recent dates.

    Returns the six reflective bands renamed to BLUE, GREEN, RED, NIR,
    SWIR1, SWIR2, with the input image's properties kept.
    """
    qa = image.select('QA60')
    cloudBitMask = 1 << 10
    cirrusBitMask = 1 << 11
    # Both flags should be zero, indicating clear conditions.
    mask = qa.bitwiseAnd(cloudBitMask).eq(0).And(qa.bitwiseAnd(cirrusBitMask).eq(0))

    optical = ['B2', 'B3', 'B4', 'B8', 'B11', 'B12']
    reflectance = image.select(optical).divide(10000)
    # Image arithmetic drops metadata. Overwriting the bands in place via
    # addBands keeps the image's properties, such as system:time_start,
    # which the later filterDate calls rely on.
    return image.addBands(reflectance, None, True) \
        .updateMask(mask) \
        .select(optical) \
        .rename(['BLUE', 'GREEN', 'RED', 'NIR', 'SWIR1', 'SWIR2'])


def Get_L7_L8_S2_ImageCollections(inputStartDate, inputEndDate, roi_boundary):
    """Load cloud-masked Landsat-7, Landsat-8 and Sentinel-2 TOA collections.

    Each collection is filtered to [inputStartDate, inputEndDate) and to
    *roi_boundary*, then mapped through its sensor-specific cloud mask.
    Returns the tuple (L7, L8, S2).
    """
    def _load(collection_id, cloud_mask):
        return ee.ImageCollection(collection_id) \
            .filterDate(inputStartDate, inputEndDate) \
            .filterBounds(roi_boundary) \
            .map(cloud_mask)

    L7 = _load('LANDSAT/LE07/C02/T1_TOA', maskL7cloud)
    L8 = _load('LANDSAT/LC08/C02/T1_TOA', maskL8cloud)
    S2 = _load('COPERNICUS/S2_HARMONIZED', maskS2cloudTOA)
    return L7, L8, S2


def dir0Regression(img, slopes, intercepts):
    """Apply a per-band linear model forward: slope * band + intercept."""
    bands = img.select(chastainBandNames)
    return bands.multiply(slopes).add(intercepts)


def dir1Regression(img, slopes, intercepts):
    """Invert a per-band linear model: (band - intercept) / slope."""
    bands = img.select(chastainBandNames)
    return bands.subtract(intercepts).divide(slopes)


def harmonizationChastain(img, fromSensor, toSensor):
    """Convert *img* from *fromSensor* to *toSensor* reflectance (Chastain models).

    The model is looked up in chastainCoeffDict under '<FROM>_<TO>'
    (upper-cased). Direction 0 applies it forward (dir0Regression);
    direction 1 inverts it (dir1Regression). Raises KeyError for an
    unsupported sensor pair.

    Returns the harmonized image with *img*'s properties and its
    system:time_start copied over, since image arithmetic drops them.
    """
    comboKey = fromSensor.upper() + '_' + toSensor.upper()
    slopes, intercepts, direction = chastainCoeffDict[comboKey]

    # `direction` is a plain Python int, so branch on the client instead of
    # building a server-side ee.Algorithms.If with both branches.
    if direction == 0:
        harmonized = dir0Regression(img, slopes, intercepts)
    else:
        harmonized = dir1Regression(img, slopes, intercepts)
    return ee.Image(harmonized).copyProperties(img).copyProperties(img, ['system:time_start'])


def Harmonize_L7_L8_S2(L7, L8, S2):
    """Calibrate Landsat-8 (OLI) and Sentinel-2 (MSI) to Landsat-7 (ETM+).

    Returns a single collection holding L7 followed by the harmonized L8
    and S2 images.
    """
    oli_as_etm = L8.map(lambda img: harmonizationChastain(img, 'OLI', 'ETM'))
    msi_as_etm = S2.map(lambda img: harmonizationChastain(img, 'MSI', 'ETM'))
    return ee.ImageCollection(L7.merge(oli_as_etm).merge(msi_as_etm))


def addNDVI(image):
    """Append an NDVI band, (NIR - RED) / (NIR + RED), and cast all bands to float."""
    ndvi = image.normalizedDifference(['NIR', 'RED']).rename('NDVI')
    return image.addBands(ndvi).float()


def Get_NDVI_image_datewise(harmonized_LS_ic, roi_boundary):
    """Build a per-date mapper producing 16-day median NDVI composites.

    The returned function takes a window start date and returns the median
    NDVI of *harmonized_LS_ic* over [date, date + 16 days), stamped with
    system:time_start = date. A fully masked placeholder image is merged in,
    so the composite still has an 'NDVI' band when no scene falls inside
    the window.
    """
    def get_NDVI_datewise(date):
        start = ee.Date(date)
        placeholder = ee.Image(0).float().rename(['NDVI']) \
            .updateMask(ee.Image(0).clip(roi_boundary))
        in_window = harmonized_LS_ic.select(['NDVI']) \
            .filterDate(start, start.advance(16, 'day'))
        return in_window.merge(placeholder) \
            .median() \
            .set('system:time_start', start.millis())
    return get_NDVI_datewise

def Get_LS_16Day_NDVI_TimeSeries(inputStartDate, inputEndDate, harmonized_LS_ic, roi_boundary):
    """Return an ImageCollection of 16-day median NDVI composites.

    Window starts run every 16 days from inputStartDate up to inputEndDate
    (both 'YYYY-MM-DD'). Each composite is built by Get_NDVI_image_datewise.
    """
    iso = "%Y-%m-%d"
    first = datetime.strptime(inputStartDate, iso)
    last = datetime.strptime(inputEndDate, iso)

    window_starts = ee.List([
        datetime.strftime(stamp, iso)
        for stamp in pd.date_range(start=first, end=last, freq='16D')
    ])
    per_window = Get_NDVI_image_datewise(harmonized_LS_ic, roi_boundary)
    return ee.ImageCollection.fromImages(window_starts.map(per_window))


def pairLSModis(lsRenameBands, roi=None):
    """Build a mapper that attaches MODIS NDVI/SummaryQA to each LSC image.

    Args:
        lsRenameBands: new names for the LSC image's bands (e.g. '<band>_lsc').
        roi: geometry used to filterBounds the MODIS collection. When omitted,
            the module-level ``roi_boundary`` is used (the previous, implicit
            behaviour), looked up when the mapper runs.

    The mapper takes the median of MOD13Q1 NDVI and SummaryQA within ±8 days
    of the image's system:time_start. These are added as 'NDVI_modis' and
    'SummaryQA_modis'. EVI could be added to the band list later.
    """
    def pair(lsc_image):
        region = roi if roi is not None else roi_boundary
        center = ee.Date(lsc_image.get('system:time_start'))

        modis = ee.ImageCollection('MODIS/061/MOD13Q1') \
            .filterDate(center.advance(-8, 'day'), center.advance(8, 'day')) \
            .select(['NDVI', 'SummaryQA']) \
            .filterBounds(region) \
            .median() \
            .rename(['NDVI_modis', 'SummaryQA_modis'])

        return lsc_image.rename(lsRenameBands).addBands(modis)
    return pair


def get_Pearson_Correlation_Coefficients(LSC_modis_paired_ic, roi_boundary, bandList):
    """Per-pixel Pearson correlation between the two bands in *bandList*.

    The collection is stacked along time and reduced with
    ee.Reducer.pearsonsCorrelation. Returns an image with bands 'c'
    (correlation) and 'p' (p-value).
    """
    stacked = LSC_modis_paired_ic.filterBounds(roi_boundary).select(bandList).toArray()
    reduced = stacked.arrayReduce(reducer=ee.Reducer.pearsonsCorrelation(), axes=[0], fieldAxis=1)
    return reduced.arrayProject([1]).arrayFlatten([['c', 'p']])


def gapfillLSM(LSC_modis_regression_model, LSC_bandName, modis_bandName):
    """Build a mapper that gap-fills an LSC band from a MODIS linear fit.

    *LSC_modis_regression_model* must provide 'scale' and 'offset' bands
    (as from ee.Reducer.linearFit). For each image, the mapper:
      * replaces masked LSC pixels with scale * MODIS + offset (pixels where
        that fit is itself masked keep the -1 placeholder);
      * builds a weight band named 'SummaryQA': the filled image's mask where
        LSC was observed, otherwise a MODIS-quality weight of 0.8 (QA 0,
        good), 0.5 (QA 1, marginal) or 0.2 (QA >= 2, snow/ice or cloudy;
        missing QA is treated as 3).
    Output bands: 'gapfilled_<LSC_bandName>' and 'SummaryQA'.
    """
    def performGapfilling(image):
        scale = LSC_modis_regression_model.select('scale')
        offset = LSC_modis_regression_model.select('offset')

        observed = image.select(LSC_bandName)
        predicted = image.select(modis_bandName).multiply(scale).add(offset)
        missing = observed.mask().Not()

        filled = observed.unmask(-1).where(missing, predicted)

        # MOD13Q1 SummaryQA: 0 good, 1 marginal, 2 snow/ice, 3 cloudy.
        modis_qa = image.select('SummaryQA_modis').unmask(3)
        modis_weight = predicted.mask().rename('w_m') \
            .where(modis_qa.eq(0), 0.8) \
            .where(modis_qa.eq(1), 0.5) \
            .where(modis_qa.gte(2), 0.2)

        weight = filled.mask().where(missing, modis_weight)
        return filled.addBands(weight).rename(['gapfilled_' + LSC_bandName, 'SummaryQA'])

    return performGapfilling


def Combine_LS_Modis(LSC):
    """Pair each LSC composite with MODIS and gap-fill its NDVI.

    Band names of *LSC* get an '_lsc' suffix. A per-pixel least-squares fit
    NDVI_lsc ~ scale * NDVI_modis + offset is computed over the whole paired
    collection and used by gapfillLSM to fill missing NDVI_lsc values.

    Returns the collection produced by gapfillLSM (bands 'gapfilled_NDVI_lsc'
    and 'SummaryQA').
    """
    lsRenameBands = ee.Image(LSC.first()).bandNames().map(lambda band: ee.String(band).cat('_lsc'))
    LSC_modis_paired_ic = LSC.map(pairLSModis(lsRenameBands))

    # linearFit takes [x, y] and outputs 'scale' and 'offset' bands.
    # (An unused Pearson-correlation image used to be built here as well; it
    # was never returned and silently depended on the global roi_boundary.)
    LSC_modis_regression_model_NDVI = LSC_modis_paired_ic.select(['NDVI_modis', 'NDVI_lsc']) \
                                                            .reduce(ee.Reducer.linearFit())

    return LSC_modis_paired_ic.map(gapfillLSM(LSC_modis_regression_model_NDVI, 'NDVI_lsc', 'NDVI_modis'))


def mask_low_QA(lsmc_image):
    """Mask pixels whose 'SummaryQA' weight is 0.2 and keep system:time_start.

    In gapfillLSM, 0.2 marks MODIS-filled pixels whose MODIS QA was
    snow/ice, cloudy or missing.
    """
    return lsmc_image.updateMask(lsmc_image.select('SummaryQA').neq(0.2)) \
        .copyProperties(lsmc_image, ['system:time_start'])


def add_timestamp(image):
    """Append a 'timestamp' band (the image's system:time_start).

    The band is masked like the image's first band, so a mosaic of
    neighbouring images carries the time of whichever image supplied each
    pixel.
    """
    first_band_mask = image.mask().select(0)
    stamp = image.metadata('system:time_start').rename('timestamp').updateMask(first_band_mask)
    return image.addBands(stamp)


def performInterpolation(image):
    """Fill masked pixels by linear interpolation in time.

    Expects the properties 'before' and 'after' to hold lists of neighbouring
    images with a 'timestamp' band. 'after' is set by the join in
    interpolate_timeseries; 'before' presumably comes from a matching join,
    to be confirmed there. Uses

        y = y1 + (y2 - y1) * (t - t1) / (t2 - t1)

    where (y1, t1) comes from the 'before' mosaic, (y2, t2) from the 'after'
    mosaic, and t is this image's system:time_start. Pixels that still cannot
    be interpolated take the mosaic of the two neighbours instead.
    """
    image = ee.Image(image)
    before = ee.ImageCollection.fromImages(ee.List(image.get('before'))).mosaic()
    after = ee.ImageCollection.fromImages(ee.List(image.get('after'))).mosaic()

    t1 = before.select('timestamp').rename('t1')
    t2 = after.select('timestamp').rename('t2')
    t = image.metadata('system:time_start').rename('t')
    times = ee.Image.cat([t1, t2, t])
    ratio = times.expression('(t - t1) / (t2 - t1)', {
        't': times.select('t'),
        't1': times.select('t1'),
        't2': times.select('t2'),
    })

    interpolated = before.add((after.subtract(before).multiply(ratio)))
    nearest = ee.ImageCollection([before, after]).mosaic()
    filled = image.unmask(interpolated).unmask(nearest)
    return filled.copyProperties(image, ['system:time_start'])


def interpolate_timeseries(S1_TS, time_window_days=120):
    """Gap-fill an image time series by temporal linear interpolation.

    Steps:
      1. mask low-quality pixels (``mask_low_QA``),
      2. add a masked ``timestamp`` band (``add_timestamp``),
      3. for every image, save-all join the images at or after it (property
         ``after``) and the images at or before it (property ``before``),
         both limited to ``time_window_days``,
      4. interpolate each image with ``performInterpolation``.

    Args:
        S1_TS: ee.ImageCollection to fill. Each image needs a ``SummaryQA``
            band and a ``system:time_start`` property. (Callers in this file
            pass the LSMC NDVI series despite the parameter name.)
        time_window_days: how many days to look backward/forward for an
            unmasked pixel. Defaults to 120, the previously hard-coded value.

    Returns:
        ee.ImageCollection of interpolated images.
    """
    prepared = S1_TS.map(mask_low_QA).map(add_timestamp)

    # Convert the window from days to milliseconds (units of system:time_start).
    window_ms = ee.Number(time_window_days).multiply(1000 * 60 * 60 * 24)
    within_window = ee.Filter.maxDifference(
        difference=window_ms,
        leftField='system:time_start',
        rightField='system:time_start',
    )

    # In a join, leftField refers to the primary and rightField to the secondary.
    # primary.time <= secondary.time  -> secondary is at/after the target image.
    secondary_not_earlier = ee.Filter.lessThanOrEquals(
        leftField='system:time_start',
        rightField='system:time_start',
    )
    # primary.time >= secondary.time  -> secondary is at/before the target image.
    secondary_not_later = ee.Filter.greaterThanOrEquals(
        leftField='system:time_start',
        rightField='system:time_start',
    )

    # First join: images after the target (within the window).
    # Sorted newest-first so that mosaic() (last image on top) prefers the
    # closest later image.
    after_join = ee.Join.saveAll(
        matchesKey='after',
        ordering='system:time_start',
        ascending=False,
    )
    with_after = after_join.apply(
        primary=prepared,
        secondary=prepared,
        condition=ee.Filter.And(within_window, secondary_not_earlier),
    )

    # Second join: images before the target (within the window).
    # Sorted oldest-first so that mosaic() prefers the closest earlier image.
    before_join = ee.Join.saveAll(
        matchesKey='before',
        ordering='system:time_start',
        ascending=True,
    )
    with_before = before_join.apply(
        primary=with_after,
        secondary=with_after,
        condition=ee.Filter.And(within_window, secondary_not_later),
    )

    return ee.ImageCollection(with_before.map(performInterpolation))


def Get_Padded_NDVI_TS_Image(startDate, endDate, roi_boundary):
    """Build the gap-filled LSMC NDVI time series for ``roi_boundary`` as one image.

    Pipeline:
      1. fetch L7/L8/S2 collections, harmonize them and add NDVI,
      2. build the 16-day Landsat NDVI series,
      3. combine it with MODIS (``Combine_LS_Modis``),
      4. interpolate temporal gaps (``interpolate_timeseries``),
      5. flatten ``gapfilled_NDVI_lsc`` with ``toBands()`` and clip to the ROI.

    Returns:
        (image, band_names): the flattened, clipped time-series image and its
        band names.
    """
    l7_ic, l8_ic, s2_ic = Get_L7_L8_S2_ImageCollections(startDate, endDate, roi_boundary)

    harmonized = Harmonize_L7_L8_S2(l7_ic, l8_ic, s2_ic).map(addNDVI)
    lsc_16day = Get_LS_16Day_NDVI_TimeSeries(startDate, endDate, harmonized, roi_boundary)
    gapfilled = interpolate_timeseries(Combine_LS_Modis(lsc_16day))

    ts_image = (
        gapfilled.select(['gapfilled_NDVI_lsc'])
        .toBands()
        .clip(roi_boundary)
    )
    return ts_image, ts_image.bandNames()


# Function to get distances as required from each pixel to each cluster centroid
def Get_Euclidean_Distance(cluster_centroids, roi_timeseries_img, input_bands, roi_boundary):
    """Compute a per-pixel distance image from the time series to every centroid.

    Each centroid feature is rasterised over ``roi_boundary`` into a constant
    image with one band per name in ``input_bands``. The band values are taken
    from the feature's properties of the same name. That image is compared
    with ``roi_timeseries_img`` using ``spectralDistance(..., 'sed')``.
    Per the Earth Engine docs, 'sed' is the *squared* Euclidean distance. That
    does not change which centroid is nearest.

    Args:
        cluster_centroids: ee.FeatureCollection of centroids. Each feature needs
            a ``class`` property plus one property per entry of ``input_bands``.
        roi_timeseries_img: time-series image clipped to the target area, with
            bands named as in ``input_bands``.
        input_bands: ee.List of band names of ``roi_timeseries_img``.
        roi_boundary: geometry of the region of interest.

    Returns:
        The result of ``cluster_centroids.map``: one image per centroid whose
        bands are the spectralDistance output (selected downstream as
        ``distance``), ``confidence`` (1 / distance) and ``class``. Each image
        also carries a ``class`` property.
        NOTE(review): ``confidence`` is infinite where the distance is 0.
    """

    def wrapper(curr_centroid):
        centroid = ee.Feature(curr_centroid).setGeometry(roi_boundary)
        centroid_fc = ee.FeatureCollection([centroid])

        def rasterize(prop_name):
            # Constant image over roi_boundary holding the centroid's value of prop_name.
            return centroid_fc.select([prop_name]).reduceToImage(
                [prop_name], ee.Reducer.first()).rename([prop_name])

        class_img = rasterize('class')

        band_imgs = input_bands.map(rasterize)
        stacked = ee.Image(band_imgs.iterate(
            lambda img, prev: ee.Image(prev).addBands(img), ee.Image()))
        # ee.Image() seeds the stack with a dummy 'constant' band; drop it.
        centroid_img = stacked.select(stacked.bandNames().remove('constant'))

        distance = roi_timeseries_img.spectralDistance(centroid_img, 'sed')
        confidence = ee.Image(1.0).divide(distance).rename(['confidence'])
        return (distance.addBands(confidence)
                .addBands(class_img)
                .set('class', centroid.get('class')))

    return cluster_centroids.map(wrapper)


def Get_final_prediction_image(distance_imgs_list):
    """Hard-classify every pixel by its nearest cluster centroid.

    Args:
        distance_imgs_list: collection of per-centroid images (as produced by
            ``Get_Euclidean_Distance``). Each must have ``distance`` and
            ``class`` bands.

    Returns:
        ee.Image with a single band ``predicted_label``. At each pixel it holds
        the ``class`` of the image whose ``distance`` is smallest.
    """
    image_axis, band_axis = 0, 1

    # Per pixel: a 2-D array with one row per centroid image (axis 0) and
    # the columns [distance, class] (axis 1).
    stacked = ee.ImageCollection(distance_imgs_list).select(['distance', 'class']).toArray()

    # Sort the rows by their distance column and keep only the first (nearest) row.
    distance_keys = stacked.arraySlice(band_axis, 0, 1)
    nearest_row = stacked.arraySort(distance_keys).arraySlice(image_axis, 0, 1)

    # Back to a regular 2-band image; band 1 is the class of the nearest centroid.
    nearest = nearest_row.arrayProject([band_axis]).arrayFlatten([['distance', 'class']])
    return nearest.select(1).rename(['predicted_label'])

## My Helper Functions
def change_clusters(cluster_centroids, start_class=13):
    """Relabel cluster centroids with consecutive class values.

    The i-th feature (in collection order) gets ``class = start_class + i``.
    This calls ``getInfo()``, a blocking client/server round trip, to learn
    the collection size.

    Args:
        cluster_centroids: ee.FeatureCollection of centroids.
        start_class: class value given to the first feature (default 13, the
            previously hard-coded value).

    Returns:
        ee.FeatureCollection of the relabelled features.
    """
    size = cluster_centroids.size().getInfo()
    # Build the server-side list once instead of once per feature.
    as_list = cluster_centroids.toList(size)
    return ee.FeatureCollection(
        [ee.Feature(as_list.get(i)).set("class", start_class + i) for i in range(size)]
    )


def get_cropping_frequency(roi_boundary, startDate, endDate, relabel_clusters=False):
    """Classify the NDVI time series of ``roi_boundary`` against the PanIndia clusters.

    Cluster class 12 is excluded as invalid. Each pixel gets the class of its
    nearest centroid (``predicted_label``) and a cluster id
    (``predicted_cluster``).

    Args:
        roi_boundary: geometry of the region of interest.
        startDate, endDate: date strings bounding the time series.
        relabel_clusters: if True, ``predicted_cluster`` is computed against
            centroids relabelled by ``change_clusters``. If False (the default
            and previous behaviour), it equals ``predicted_label``. The
            identical classification is reused rather than rebuilt.

    Returns:
        (final_classified_img, final_LSMC_NDVI_TS): an image with bands
        ``predicted_label`` and ``predicted_cluster``, and the NDVI
        time-series image it was derived from.
    """
    cluster_centroids = ee.FeatureCollection('projects/ee-indiasat/assets/L3_LULC_Clusters/Final_Level3_PanIndia_Clusters')
    ignore_clusters = [12]  # remove invalid clusters
    cluster_centroids = cluster_centroids.filter(ee.Filter.Not(ee.Filter.inList('class', ignore_clusters)))

    final_LSMC_NDVI_TS, input_bands = Get_Padded_NDVI_TS_Image(startDate, endDate, roi_boundary)
    distance_imgs_list = Get_Euclidean_Distance(cluster_centroids, final_LSMC_NDVI_TS, input_bands, roi_boundary)
    final_classified_img = Get_final_prediction_image(distance_imgs_list)

    if relabel_clusters:
        # Re-run the nearest-centroid search with cluster values starting after 12.
        relabelled = change_clusters(cluster_centroids)
        cluster_distances = Get_Euclidean_Distance(relabelled, final_LSMC_NDVI_TS, input_bands, roi_boundary)
        final_cluster_classified_img = Get_final_prediction_image(cluster_distances)
    else:
        # Same centroids and same inputs give the same prediction; reuse it
        # instead of building the whole distance graph a second time.
        final_cluster_classified_img = final_classified_img

    final_cluster_classified_img = final_cluster_classified_img.rename(['predicted_cluster'])
    final_classified_img = final_classified_img.addBands(final_cluster_classified_img)
    return final_classified_img, final_LSMC_NDVI_TS

def get_six_year_cropping_frequency_rasters(roi_boundary, start_year, num_years=6):
    """Compute one cropping-frequency image per agricultural year.

    Calls ``get_cropping_frequency`` once per year for the window
    ``<y>-07-01`` to ``<y+1>-06-30``, where ``y = start_year + i`` for
    ``i in range(num_years)``. Results are in chronological order.
    The predictions derive from the NDVI series that
    ``Get_Padded_NDVI_TS_Image`` clips to ``roi_boundary``.

    Args:
        roi_boundary: geometry of the region of interest.
        start_year: first year of the first agricultural year (e.g. 2017).
        num_years: number of consecutive agricultural years (default 6).

    Returns:
        crop_freq_imgs: list of ``num_years`` ee.Image, each containing ONLY the
            ``predicted_label`` band (``predicted_cluster`` is dropped).
        date_ranges: list of ``(startDate, endDate)`` string tuples, aligned
            with ``crop_freq_imgs``.
    """
    crop_freq_imgs = []
    date_ranges = []

    for i in range(num_years):
        y1 = start_year + i
        y2 = y1 + 1

        currStartDate = f"{y1}-07-01"
        currEndDate   = f"{y2}-06-30"

        # get_cropping_frequency returns (classified image, NDVI time series);
        # the time series is not needed here.
        cropping_frequency_img, _ = get_cropping_frequency(roi_boundary, currStartDate, currEndDate)

        crop_freq_imgs.append(cropping_frequency_img.select(['predicted_label']))
        date_ranges.append((currStartDate, currEndDate))

    return crop_freq_imgs, date_ranges


In [5]:
# ============================================================
# 7) RE-INTRODUCE MERGED CLASSES (WATER + CROP) INTO FINAL MAPS
# ============================================================

# Original label -> grouped label (same grouping as earlier in the notebook).
# Water subclasses 2/3/4 collapse to 2, crop subclasses 8..11 collapse to 8,
# every other label maps to itself. Unzipped into two parallel lists for remap().
ORIG, GROUP = (list(column) for column in zip(*[
    (0, 0),
    (1, 1),
    (2, 2), (3, 2), (4, 2),
    (6, 6),
    (7, 7),
    (8, 8), (9, 8), (10, 8), (11, 8),
    (12, 12),
    (13, 13),
]))

BAND = 'predicted_label'

WATER_SET = ee.List([2, 3, 4])        # original (un-grouped) water subclasses
CROP_SET = ee.List([8, 9, 10, 11])    # original (un-grouped) crop subclasses

def _is_in_list(img, vals):
    """Return a boolean image that is 1 where ``img`` equals a value in ``vals``.

    ``vals`` is an ee.List of integers. Listed values are remapped to 1 and
    everything else to 0 (remap default), then compared with 1.
    """
    ones = ee.List.repeat(1, vals.size())
    return ee.Image(img).remap(vals, ones, 0).eq(1)

def _group_from_orig(orig_img):
    """Convert an un-grouped single-band label image to grouped labels (ORIG -> GROUP), as int.

    No default value is given to ``remap``.
    NOTE(review): per the Earth Engine docs, labels absent from ORIG become
    masked — confirm that is intended.
    """
    grouped = orig_img.remap(ORIG, GROUP)
    return grouped.toInt()

def _safe_water_source(orig_img):
    """
    Return ``orig_img`` as an int image whose non-water pixels are replaced by 2.

    Pixels that already hold a water subclass (members of WATER_SET) keep it,
    so the result is always a valid water subclass wherever it is unmasked.
    """
    not_water = _is_in_list(orig_img, WATER_SET).Not()
    return ee.Image(orig_img).where(not_water, 2).toInt()

def reinstate_merged_classes_6yr(original_year_imgs, corrected_img, crop_intensity_imgs=None, band=BAND):
    """
    Restore fine (un-grouped) water and crop subclasses into corrected yearly maps.

    The temporal correction runs on grouped labels (water 2/3/4 -> 2,
    crop 8..11 -> 8). For each year t = 1..6 this maps the corrected grouped
    label back to the original label scheme:

      (A) Non-merged classes (background, built-up, tree, barren, scrub,
          plantation): where the corrected grouped label differs from the
          original grouped label, write the corrected label directly.
      (B) Corrected = crop (8): keep the original crop subclass if year t was
          already crop; otherwise use ``crop_intensity_imgs[t-1]`` (or 8 when
          not provided).
      (C) Corrected = water (2): keep the original water subclass if year t
          was already water; otherwise borrow one from an earlier year:
            * doublet Y year      -> primary t-2, fallback t-1
            * X / non-doublet     -> primary t-1, fallback t-2
          If neither year holds a water subclass, use 2 (generic water).
          Missing years are clamped: for t=1 both sources are year 1 itself,
          and for t=2 the t-2 source is year 1.

    Inputs:
      original_year_imgs : list of 6 ee.Image (un-grouped), each containing `band`.
      corrected_img      : ee.Image with grouped corrected bands y1c..y6c plus
                           doubletPattern / hasDoublet.
      crop_intensity_imgs: optional list of 6 ee.Image, each with `band`
                           (expected to hold crop subclasses 8..11). Used only
                           where the corrected class is crop (8) but the
                           original year is not crop.

    Output:
      ee.Image with bands y1f..y6f (final un-grouped labels).

    Raises:
      ValueError: if the image lists do not have exactly 6 entries.
    """
    if len(original_year_imgs) != 6:
        raise ValueError("Provide exactly 6 original yearly images.")
    if crop_intensity_imgs is not None and len(crop_intensity_imgs) != 6:
        raise ValueError("If provided, crop_intensity_imgs must have exactly 6 images.")

    def _orig_label(idx):
        # Original fine label of year idx+1 as an int image.
        return ee.Image(original_year_imgs[idx]).select([band]).toInt()

    # Doublet pattern flags (already matrix-filtered earlier)
    pat = corrected_img.select('doubletPattern').toInt()
    has = corrected_img.select('hasDoublet').eq(1)

    p1 = has.And(pat.eq(1))  # A X Y A A *
    p2 = has.And(pat.eq(2))  # A A X Y A A
    p3 = has.And(pat.eq(3))  # * A A X Y A

    # Per year: is it the X (first changed) or Y (second changed) year of a doublet?
    # Pattern 1 -> X=y2, Y=y3; pattern 2 -> X=y3, Y=y4; pattern 3 -> X=y4, Y=y5.
    never = ee.Image(0).eq(1)  # constant false
    isX = {1: never, 2: p1, 3: p2, 4: p3, 5: never, 6: never}
    isY = {1: never, 2: never, 3: p1, 4: p2, 5: p3, 6: never}

    out_bands = []

    for t in range(1, 7):
        orig_t = _orig_label(t - 1)                         # original fine label
        orig_grp_t = _group_from_orig(orig_t)               # grouped original
        corr_grp_t = corrected_img.select(f'y{t}c').toInt() # grouped corrected

        isWaterCorr = corr_grp_t.eq(2)
        isCropCorr = corr_grp_t.eq(8)

        # (A) Non-merged classes map 1:1, so the grouped label is also the fine label.
        nonMerged = isWaterCorr.Or(isCropCorr).Not()
        final_t = orig_t.where(nonMerged.And(corr_grp_t.neq(orig_grp_t)), corr_grp_t)

        # (B) Crop re-introduction.
        if crop_intensity_imgs is not None:
            crop_pred_t = ee.Image(crop_intensity_imgs[t - 1]).select([band]).toInt()
        else:
            crop_pred_t = ee.Image(8).toInt()

        orig_is_crop = _is_in_list(orig_t, CROP_SET)
        crop_fill = orig_t.where(orig_is_crop.Not(), crop_pred_t).toInt()
        final_t = final_t.where(isCropCorr, crop_fill)

        # (C) Water re-introduction.
        orig_is_water = _is_in_list(orig_t, WATER_SET)

        # Prior-year sources; out-of-range years fall back as documented above.
        prev1 = _orig_label(t - 2) if t >= 2 else orig_t
        prev2 = _orig_label(t - 3) if t >= 3 else prev1
        prev1_is_water = _is_in_list(prev1, WATER_SET)
        prev2_is_water = _is_in_list(prev2, WATER_SET)

        choose_prev2 = isY[t]                         # Y -> primary (t-2)
        choose_prev1 = isX[t].Or(choose_prev2.Not())  # X or non-doublet -> primary (t-1)

        water_source = ee.Image(2).toInt()
        # Fallbacks first (the other prior year), so the primary rules below win
        # whenever the primary year already holds a water subclass.
        water_source = water_source.where(choose_prev1.And(prev2_is_water), prev2)
        water_source = water_source.where(choose_prev2.And(prev1_is_water), prev1)
        water_source = water_source.where(choose_prev1.And(prev1_is_water), prev1)
        water_source = water_source.where(choose_prev2.And(prev2_is_water), prev2)

        # Keep the original subclass if it was already water; else use the chosen source.
        water_fill = orig_t.where(orig_is_water.Not(), water_source).toInt()
        final_t = final_t.where(isWaterCorr, water_fill)

        out_bands.append(final_t.rename(f'y{t}f'))

    return ee.Image.cat(out_bands)


In [6]:
# Yearly crop-intensity rasters for the same agricultural years that were loaded
# in the setup cell (start_year / num_years). The export cell also labels its
# assets with start_year, so reusing the variables keeps data and date tags in sync.
crop_freq_imgs, crop_freq_date_ranges = get_six_year_cropping_frequency_rasters(
    roi_boundary=roi_boundary,
    start_year=start_year,
    num_years=num_years,
)

# Restore water/crop subclasses into the temporally corrected maps.
# NOTE: `corrected` must already be defined (presumably by an earlier cell).
final_img = reinstate_merged_classes_6yr(
    original_year_imgs=images,
    corrected_img=corrected,
    crop_intensity_imgs=crop_freq_imgs,
)



In [7]:
#Map.addLayer(final_img.select('y1f'), vis_params_lulc, 'y1 final (un-grouped)')
#Map.addLayer(final_img.select('y2f'), vis_params_lulc, 'y2 final (un-grouped)')
#Map.addLayer(final_img.select('y3f'), vis_params_lulc, 'y3 final (un-grouped)')
#Map.addLayer(final_img.select('y4f'), vis_params_lulc, 'y4 final (un-grouped)')
#Map.addLayer(final_img.select('y5f'), vis_params_lulc, 'y5 final (un-grouped)')
#Map.addLayer(final_img.select('y6f'), vis_params_lulc, 'y6 final (un-grouped)')
#Map

In [8]:
def export_temporal_corrected_assets(
    imgs,                 # list of ee.Image (length = num_years)
    aez,                  # int (AEZ number)
    start_year,           # int (e.g., 2017)
    project='raman-461708',
    region_fc="users/mtpictd/agro_eco_regions",  # same as your ROI source
    scale=10,
    maxPixels=1e13,
    pyramiding_policy='mode',
):
    """Start one Earth Engine asset export per yearly image.

    Image ``i`` is exported to
    ``projects/<project>/assets/AEZ_<aez>_<y1>-07-01_<y2>-06-30_temporal_corrected``
    with ``y1 = start_year + i``. Each image is clipped to the AEZ polygon
    selected from ``region_fc`` by ``ae_regcode == aez``.

    Args:
        imgs: list of ee.Image, one per agricultural year, in order.
        aez: AEZ number used for the ROI filter and the asset names.
        start_year: first year of the first agricultural year.
        project: Earth Engine cloud project that owns the assets.
        region_fc: FeatureCollection ID holding the AEZ polygons.
        scale: export resolution in metres.
        maxPixels: pixel limit passed to the export.
        pyramiding_policy: pyramiding policy applied to all bands. Defaults to
            'mode' because the exported bands are categorical class labels;
            Earth Engine's default ('mean') averages class codes in the
            lower-resolution pyramid levels. Pass None to use EE's default.

    Returns:
        list of started ``ee.batch.Task`` objects.
    """
    # ROI geometry for the AEZ (keeps exports tight)
    roi_geom = ee.FeatureCollection(region_fc) \
        .filter(ee.Filter.eq("ae_regcode", aez)) \
        .geometry()

    tasks = []
    for i, img in enumerate(imgs):
        y1 = start_year + i
        y2 = y1 + 1

        # Same date naming used while loading the source assets
        date_tag = f"{y1}-07-01_{y2}-06-30"

        asset_id = f"projects/{project}/assets/AEZ_{aez}_{date_tag}_temporal_corrected"

        # Task description (must be short-ish and unique)
        desc = f"AEZ_{aez}_{y1}_{y2}_temporal_corrected"

        export_kwargs = dict(
            image=ee.Image(img).clip(roi_geom),
            description=desc,
            assetId=asset_id,
            region=roi_geom,
            scale=scale,
            maxPixels=maxPixels,
        )
        if pyramiding_policy is not None:
            export_kwargs['pyramidingPolicy'] = {'.default': pyramiding_policy}

        task = ee.batch.Export.image.toAsset(**export_kwargs)
        task.start()
        tasks.append(task)

        print(f"Started export: {asset_id}")

    return tasks

# Export the six final (un-grouped) yearly bands y1f..y6f as separate assets.
tasks = export_temporal_corrected_assets(
    imgs=[final_img.select(f'y{year_idx}f') for year_idx in range(1, 7)],
    aez=aez,
    start_year=start_year,
    project='raman-461708',
    scale=10,
)

Started export: projects/raman-461708/assets/AEZ_19_2017-07-01_2018-06-30_temporal_corrected
Started export: projects/raman-461708/assets/AEZ_19_2018-07-01_2019-06-30_temporal_corrected
Started export: projects/raman-461708/assets/AEZ_19_2019-07-01_2020-06-30_temporal_corrected
Started export: projects/raman-461708/assets/AEZ_19_2020-07-01_2021-06-30_temporal_corrected
Started export: projects/raman-461708/assets/AEZ_19_2021-07-01_2022-06-30_temporal_corrected
Started export: projects/raman-461708/assets/AEZ_19_2022-07-01_2023-06-30_temporal_corrected
