In [1]:
import ee

# Reuse stored credentials when they work; run the interactive
# authentication flow only if initialisation fails (e.g. first run).
try:
    ee.Initialize(project='raman-461708')
except ee.EEException:
    ee.Authenticate()
    ee.Initialize(project='raman-461708')


# -------------------------------
# 1. LOADING

# -------------------------------

def load_aez_temporal_images(aez, start_year, num_years, project='raman-461708'):
    """
    Load one LULC image per year, each reduced to its 'predicted_label' band.

    Image i (0-based) is the asset
    ``projects/{project}/assets/LULC_v4_PanIndia_{Y}-07-01_{Y+1}-06-30``
    with ``Y = start_year + i``.

    Args:
        aez: Agro-ecological zone number. NOTE: currently unused -- the loader
            reads the Pan-India assets. It is kept so existing calls keep
            working (the earlier per-AEZ naming was
            ``AEZ_{aez}_v4_temporal_3years__...``).
        start_year: Year in which the first image's period starts.
        num_years: Number of consecutive yearly images to load.
        project: GCP project that owns the assets.

    Returns:
        list of ee.Image (length ``num_years``), oldest first.
    """
    base = f"projects/{project}/assets/LULC_v4_PanIndia_"
    return [
        ee.Image(f"{base}{y1}-07-01_{y1 + 1}-06-30").select('predicted_label')
        for y1 in range(start_year, start_year + num_years)
    ]

# Run configuration: zone id, first year, number of years, asset project.
aez, start_year, num_years, project = 10, 2017, 6, 'raman-461708'

# Agro-ecological region polygons (used later to mask outputs).
# NOTE(review): not filtered by `aez` -- confirm whole-collection masking is intended.
roi = ee.FeatureCollection("users/mtpictd/agro_eco_regions")

# Load one LULC image per year.
images = load_aez_temporal_images(aez, start_year, num_years, project)



In [2]:


# -----------------------------
# 0) Class grouping
# -----------------------------
# Original classes present in your data
ORIG = [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13]

# GROUP[i] is the grouped code for ORIG[i] (paired element-wise for remap):
#   water classes 2, 3, 4      -> 2
#   crop classes  8, 9, 10, 11 -> 8
#   every other class keeps its own code.
GROUP = [
    0, 1,
    2, 2, 2,      # water
    6, 7,
    8, 8, 8, 8,   # crop
    12, 13,
]

# Label band in the yearly LULC images; change if your band name differs.
BAND = 'predicted_label'


def group_classes(img, band=BAND):
    """
    Collapse original LULC labels into grouped classes (ORIG -> GROUP).

    Water (2/3/4) becomes 2 and crop (8/9/10/11) becomes 8. The output band
    is always named 'predicted_label', whatever `band` was, so downstream
    stacking can select it by that name.
    NOTE(review): ee.Image.remap masks values absent from ORIG (e.g. 5).
    """
    labels = img.select([band])
    return labels.remap(ORIG, GROUP).rename('predicted_label')


# ----------------------------------------------------------
# 1) Stack 7 yearly layers into one 7-band image: y1..y7
# ----------------------------------------------------------
def stack_7_years(year_imgs):
    """
    Stack seven single-band yearly images into one 7-band image y1..y7.

    Args:
        year_imgs: list of exactly 7 ee.Image, oldest first, each carrying a
            'predicted_label' band (e.g. the output of group_classes).

    Returns:
        ee.Image with bands y1..y7 (band i taken from year_imgs[i-1]).

    Raises:
        ValueError: if the list does not contain exactly 7 images.
    """
    if len(year_imgs) != 7:
        raise ValueError("Provide exactly 7 yearly images.")
    return ee.Image.cat([
        im.select(['predicted_label']).rename(f'y{i}')
        for i, im in enumerate(year_imgs, start=1)
    ])


# ----------------------------------------------------------
# 2) Bit-mask helpers for the packed singlet-centre mask: count set bits,
#    detect adjacent set bits ("11"), and find the longest run of set bits.
# ----------------------------------------------------------


def _count_ones_nbit(mask, nbits):
    """Count set bits among the lowest `nbits` bits of `mask`; int band 'singletCount'."""
    bit_bands = []
    for position in range(nbits):
        bit_bands.append(mask.rightShift(position).bitwiseAnd(1))
    total = ee.Image.cat(bit_bands).reduce(ee.Reducer.sum())
    return total.rename('singletCount').toInt()

def _has_adjacent_ones(mask):
    """Boolean band 'hasAdj11': true where `mask` has two neighbouring set bits ("11")."""
    shifted = mask.leftShift(1)
    overlap = mask.bitwiseAnd(shifted)
    return overlap.neq(0).rename('hasAdj11')

def _has_run_k(mask, nbits, k):
    """
    True where `mask` contains k consecutive set bits within its low `nbits` bits.

    Slides a window of k ones over every offset that fits inside nbits and ORs
    together the "window fully set" tests. Returns the constant 0 image when
    k > nbits.
    """
    window_bits = (1 << k) - 1
    found = ee.Image(0)
    offset = 0
    while offset <= nbits - k:
        window = ee.Image(window_bits << offset).toInt()
        found = found.Or(mask.bitwiseAnd(window).eq(window))
        offset += 1
    return found

def _max_run_length(mask, nbits):
    """
    Longest run of consecutive set bits in the low `nbits` bits of `mask`.

    0 where mask == 0, at least 1 where mask != 0; int band 'singletMaxRun'.
    Runs are tested in increasing length so the longest match wins.
    """
    longest = ee.Image(0).where(mask.neq(0), 1)
    run = 2
    while run <= nbits:
        longest = longest.where(_has_run_k(mask, nbits, run), run)
        run += 1
    return longest.rename('singletMaxRun').toInt()

def singlet_category(seq7):
    """
    Detect ABA "singlet" flips in a 7-year grouped sequence and categorise them.

    A singlet centre at year t (t = 2..6) means y(t-1) == y(t+1) and
    y(t) != y(t-1).

    Args:
        seq7: ee.Image with bands y1..y7 (GROUPED classes).

    Returns:
        ee.Image with bands, in order:
          - singletCategory (int):
              0 no singlets
              1 only disjoint singlets (no adjacent centres)
              2 longest run of adjacent centres is 2, not composite
              3 / 4 / 5 longest run of adjacent centres is 3 / 4 / 5
              6 composite: a run-2 cluster plus at least one isolated centre
          - singletSeqMask (0..31): s2 in bit 4 ... s6 in bit 0
          - singletCount (0..5)
          - hasAdj11: true where two singlet centres are adjacent
          - singletMaxRun (0..5)
          - isolatedSingletMask: mask bits not part of any adjacent pair
          - s2..s6: per-centre 0/1 indicators (for debugging/visualisation)
    """
    years = [seq7.select(f'y{n}') for n in range(1, 8)]

    # One 0/1 band per candidate centre year y2..y6.
    centers = []
    for n in range(2, 7):
        before, here, after = years[n - 2], years[n - 1], years[n]
        is_singlet = before.eq(after).And(here.neq(before))
        centers.append(is_singlet.rename(f's{n}').toUint8())

    # Pack s2..s6 into one integer: s2 -> bit 4 (MSB) ... s6 -> bit 0 (LSB).
    packed = ee.Image(0).toInt()
    for bit, center in zip(range(4, -1, -1), centers):
        packed = packed.bitwiseOr(center.toInt().leftShift(bit))
    singletSeqMask = packed.rename('singletSeqMask')

    singletCount = _count_ones_nbit(singletSeqMask, 5)
    anySinglet = singletSeqMask.neq(0)
    adj11 = _has_adjacent_ones(singletSeqMask)
    maxRun = _max_run_length(singletSeqMask, 5)

    # mask & (mask << 1) sets bit k when bits k and k-1 are both set; OR-ing
    # with its right shift also marks bit k-1, i.e. every bit in some pair.
    pair_tops = singletSeqMask.bitwiseAnd(singletSeqMask.leftShift(1))
    paired = pair_tops.bitwiseOr(pair_tops.rightShift(1))
    isolatedMask = singletSeqMask.bitwiseAnd(paired.bitwiseNot()).rename('isolatedSingletMask')

    hasIsolated = isolatedMask.neq(0)
    isMax2 = maxRun.eq(2)
    composite = anySinglet.And(isMax2).And(adj11).And(hasIsolated)
    disjointSingles = anySinglet.And(adj11.Not())

    category = ee.Image(0).toInt()
    category = category.where(disjointSingles, 1)
    category = category.where(isMax2.And(disjointSingles.Not()).And(composite.Not()), 2)
    for run in (3, 4, 5):
        category = category.where(maxRun.eq(run), run)
    category = category.where(composite, 6)
    category = category.where(anySinglet.Not(), 0)
    category = category.rename('singletCategory')

    return ee.Image.cat([
        category,
        singletSeqMask,
        singletCount,
        adj11,
        maxRun,
        isolatedMask.toInt(),
        ee.Image.cat(centers),
    ])




# ----------------------------------------------------------
# 5) Doublet-flip detection (your exact patterns)
#    Pattern: A A X Y A A A, with X!=Y, X!=A, Y!=A
#    Valid configurations require >=2 consecutive A's on one side:
#      (i)  A X Y A A * *     -> check y1=A, y4=y5=A, y2=X, y3=Y
#      (ii) A A X Y A A *     -> check y1=y2=A, y5=y6=A
#      (iii)* A A X Y A A     -> check y2=y3=A, y6=y7=A
#      (iv) * * A A X Y A     -> check y3=y4=A, y7=A and (two A's on left: y3,y4)
# ----------------------------------------------------------
def detect_doublet(seq7):
    """
    Detect "doublet" flips in a 7-year grouped sequence.

    A doublet is two consecutive years X, Y inside a stable class A with
    X != Y, X != A and Y != A, and the listed surrounding years equal to A.

    Placements (id: pattern -> A year; X, Y years; years that must equal A):
      1: A X Y A A * *  -> A=y1; X=y2, Y=y3; y4, y5
      2: A A X Y A A *  -> A=y1; X=y3, Y=y4; y2, y5, y6
      3: * A A X Y A A  -> A=y2; X=y4, Y=y5; y3, y6, y7
      4: * * A A X Y A  -> A=y3; X=y5, Y=y6; y4, y7

    Args:
        seq7: ee.Image with bands y1..y7.

    Returns:
        ee.Image with bands:
          - doubletPattern (0..4): 0 none, else the matched placement id
            (a later placement overwrites an earlier one)
          - hasDoublet: true where any placement matched
    """
    years = {n: seq7.select(f'y{n}') for n in range(1, 8)}

    # (A year, X year, Y year, years that must also equal A)
    placements = (
        (1, 2, 3, (4, 5)),
        (1, 3, 4, (2, 5, 6)),
        (2, 4, 5, (3, 6, 7)),
        (3, 5, 6, (4, 7)),
    )

    matches = []
    for a_n, x_n, y_n, same_ns in placements:
        A, X, Y = years[a_n], years[x_n], years[y_n]
        hit = years[same_ns[0]].eq(A)
        for n in same_ns[1:]:
            hit = hit.And(years[n].eq(A))
        matches.append(hit.And(X.neq(Y).And(X.neq(A)).And(Y.neq(A))))

    pattern = ee.Image(0).toInt()
    for pid, hit in enumerate(matches, start=1):
        pattern = pattern.where(hit, pid)

    any_hit = matches[0]
    for hit in matches[1:]:
        any_hit = any_hit.Or(hit)

    return pattern.rename('doubletPattern').addBands(any_hit.rename('hasDoublet'))


# ----------------------------------------------------------
# 6) Main wrapper: provide 7 yearly images -> outputs categories
# ----------------------------------------------------------
def classify_temporal_flip_categories(year_imgs, roi_fc=None):
    """
    Run singlet and doublet flip detection on a 7-year LULC sequence.

    Args:
        year_imgs: list of 7 ee.Image with band BAND (original labels),
            oldest first.
        roi_fc: optional ee.FeatureCollection used to mask the detection
            bands. Defaults to the module-level ``roi`` (previous behaviour).

    Returns:
        ee.Image with bands:
          - y1..y7: grouped class per year (not ROI-masked)
          - singletCategory (0..6), singletSeqMask, singletCount, hasAdj11,
            singletMaxRun, isolatedSingletMask, s2..s6  (singlet_category)
          - doubletPattern (0..4), hasDoublet           (detect_doublet)
        The singlet and doublet bands are masked to the ROI.

    Raises:
        ValueError: from stack_7_years if ``year_imgs`` does not hold 7 images.
    """
    grouped = [group_classes(im, band=BAND) for im in year_imgs]
    seq7 = stack_7_years(grouped)

    if roi_fc is None:
        roi_fc = roi
    roi_mask = ee.Image.constant(1).clip(roi_fc).selfMask()

    sing = singlet_category(seq7).updateMask(roi_mask)
    dbl = detect_doublet(seq7).updateMask(roi_mask)
    return seq7.addBands(sing).addBands(dbl)




# Colour ramp for the singletCategory band (values 0..6).
singlet_vis = dict(
    min=0,
    max=6,
    palette=[
        '000000',  # 0 - no singlets (black)
        '2ecc71',  # 1 - disjoint singles (green)
        'f1c40f',  # 2 - run = 2 (yellow)
        'e67e22',  # 3 - run = 3 (orange)
        'e74c3c',  # 4 - run = 4 (red)
        '8e44ad',  # 5 - run = 5 (purple)
        '34495e',  # 6 - composite (dark slate)
    ],
)

# -----------------------------
# Example usage (replace these)
# -----------------------------
# year_imgs = [img2017, img2018, img2019, img2020, img2021, img2022, img2023]
# The 7-year classifier needs exactly seven yearly images (stack_7_years
# raises ValueError otherwise). `images` holds `num_years` images, which may
# differ from 7, so load the seven years starting at `start_year` explicitly.
year_imgs = load_aez_temporal_images(aez, start_year, 7, project)
out = classify_temporal_flip_categories(year_imgs)

import geemap
Map = geemap.Map()
url = 'https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}'
Map.layout.height = '1000px'
Map.add_tile_layer(url, name="Google Map", attribution="Google")

Map.addLayer(out.select('singletCategory'), singlet_vis, 'singletCategory')
for _year_band in [f'y{i}' for i in range(1, 8)]:
    Map.addLayer(out.select(_year_band), {}, _year_band)

Map


Map(center=[0, 0], controls=(WidgetControl(options=['position', 'transparent_bg'], position='topright', transp…

In [2]:

import ee

# Reuse stored credentials when they work; run the interactive
# authentication flow only if initialisation fails (e.g. first run).
try:
    ee.Initialize(project='raman-461708')
except ee.EEException:
    ee.Authenticate()
    ee.Initialize(project='raman-461708')


# -------------------------------
# 1. LOADING

# -------------------------------

def load_aez_temporal_images(aez, start_year, num_years, project='raman-461708'):
    """
    Load one full LULC image (all bands) per year.

    Image i (0-based) is the asset
    ``projects/{project}/assets/LULC_v4_PanIndia_{Y}-07-01_{Y+1}-06-30``
    with ``Y = start_year + i``. Unlike the earlier variant, no band is
    selected here; group_classes selects the label band later.

    Args:
        aez: Agro-ecological zone number. NOTE: currently unused -- the loader
            reads the Pan-India assets. It is kept so existing calls keep
            working (the earlier per-AEZ naming was
            ``AEZ_{aez}_v4_temporal_3years__...``).
        start_year: Year in which the first image's period starts.
        num_years: Number of consecutive yearly images to load.
        project: GCP project that owns the assets.

    Returns:
        list of ee.Image (length ``num_years``), oldest first.
    """
    base = f"projects/{project}/assets/LULC_v4_PanIndia_"
    return [
        ee.Image(f"{base}{y1}-07-01_{y1 + 1}-06-30")
        for y1 in range(start_year, start_year + num_years)
    ]

# Run configuration: zone id, first year, number of years, asset project.
aez, start_year, num_years, project = 10, 2017, 6, 'raman-461708'

# Agro-ecological region polygons (used later to mask outputs).
# NOTE(review): not filtered by `aez` -- confirm whole-collection masking is intended.
roi = ee.FeatureCollection("users/mtpictd/agro_eco_regions")

# Load one LULC image per year.
images = load_aez_temporal_images(aez, start_year, num_years, project)


# -----------------------------
# 0) Class grouping
# -----------------------------
# Original classes present in your data
ORIG = [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13]

# GROUP[i] is the grouped code for ORIG[i] (paired element-wise for remap):
#   water classes 2, 3, 4      -> 2
#   crop classes  8, 9, 10, 11 -> 8
#   every other class keeps its own code.
GROUP = [
    0, 1,
    2, 2, 2,      # water
    6, 7,
    8, 8, 8, 8,   # crop
    12, 13,
]

# Label band in the yearly LULC images; change if your band name differs.
BAND = 'predicted_label'


def group_classes(img, band=BAND):
    """
    Collapse original LULC labels into grouped classes (ORIG -> GROUP).

    Water (2/3/4) becomes 2 and crop (8/9/10/11) becomes 8. The output band
    is always named 'predicted_label', whatever `band` was, so downstream
    stacking can select it by that name.
    NOTE(review): ee.Image.remap masks values absent from ORIG (e.g. 5).
    """
    labels = img.select([band])
    return labels.remap(ORIG, GROUP).rename('predicted_label')


# ----------------------------------------------------------
# 1) Stack 6 yearly layers into one 6-band image: y1..y6
# ----------------------------------------------------------
def stack_6_years(year_imgs):
    """
    Stack six single-band yearly images into one 6-band image y1..y6.

    Args:
        year_imgs: list of exactly 6 ee.Image, oldest first, each carrying a
            'predicted_label' band (e.g. the output of group_classes).

    Returns:
        ee.Image with bands y1..y6 (band i taken from year_imgs[i-1]).

    Raises:
        ValueError: if the list does not contain exactly 6 images.
    """
    if len(year_imgs) != 6:
        raise ValueError("Provide exactly 6 yearly images.")
    return ee.Image.cat([
        img.select(['predicted_label']).rename(f'y{n}')
        for n, img in enumerate(year_imgs, start=1)
    ])



# ----------------------------------------------------------
# 2) Bit-mask helpers for the packed singlet-centre mask: count set bits,
#    detect adjacent set bits ("11"), and find the longest run of set bits.
# ----------------------------------------------------------


def _count_ones_nbit(mask, nbits):
    """Count set bits among the lowest `nbits` bits of `mask`; int band 'singletCount'."""
    bit_bands = []
    for position in range(nbits):
        bit_bands.append(mask.rightShift(position).bitwiseAnd(1))
    total = ee.Image.cat(bit_bands).reduce(ee.Reducer.sum())
    return total.rename('singletCount').toInt()

def _has_adjacent_ones(mask):
    """Boolean band 'hasAdj11': true where `mask` has two neighbouring set bits ("11")."""
    shifted = mask.leftShift(1)
    overlap = mask.bitwiseAnd(shifted)
    return overlap.neq(0).rename('hasAdj11')

def _has_run_k(mask, nbits, k):
    """
    True where `mask` contains k consecutive set bits within its low `nbits` bits.

    Slides a window of k ones over every offset that fits inside nbits and ORs
    together the "window fully set" tests. Returns the constant 0 image when
    k > nbits.
    """
    window_bits = (1 << k) - 1
    found = ee.Image(0)
    offset = 0
    while offset <= nbits - k:
        window = ee.Image(window_bits << offset).toInt()
        found = found.Or(mask.bitwiseAnd(window).eq(window))
        offset += 1
    return found

def _max_run_length(mask, nbits):
    """
    Longest run of consecutive set bits in the low `nbits` bits of `mask`.

    0 where mask == 0, at least 1 where mask != 0; int band 'singletMaxRun'.
    Runs are tested in increasing length so the longest match wins.
    """
    longest = ee.Image(0).where(mask.neq(0), 1)
    run = 2
    while run <= nbits:
        longest = longest.where(_has_run_k(mask, nbits, run), run)
        run += 1
    return longest.rename('singletMaxRun').toInt()

def singlet_category_6yr(seq6):
    """
    ABA singlet detection and categorisation for a 6-year sequence (y1..y6).

    A singlet centre at year t (t = 2..5) means y(t-1) == y(t+1) and
    y(t) != y(t-1). Four possible centres give a 4-bit mask.

    Args:
        seq6: ee.Image with bands y1..y6 (grouped classes).

    Returns:
        ee.Image with bands, in order:
          - singletCategory (0..6): same labels as the 7-year version; no
            rule assigns 5 here because the longest possible run is 4
          - singletSeqMask (0..15): s2 in bit 3 ... s5 in bit 0
          - singletCount   (0..4)
          - hasAdj11
          - singletMaxRun  (0..4)
          - isolatedSingletMask
          - s2..s5
    """
    years = [seq6.select(f'y{n}') for n in range(1, 7)]

    # One 0/1 band per candidate centre year y2..y5.
    centers = []
    for n in range(2, 6):
        before, here, after = years[n - 2], years[n - 1], years[n]
        is_singlet = before.eq(after).And(here.neq(before))
        centers.append(is_singlet.rename(f's{n}').toUint8())

    # Pack s2..s5 into one integer: s2 -> bit 3 (MSB) ... s5 -> bit 0 (LSB).
    packed = ee.Image(0).toInt()
    for bit, center in zip(range(3, -1, -1), centers):
        packed = packed.bitwiseOr(center.toInt().leftShift(bit))
    singletSeqMask = packed.rename('singletSeqMask')

    singletCount = _count_ones_nbit(singletSeqMask, 4).rename('singletCount')
    anySinglet = singletSeqMask.neq(0)
    adj11 = _has_adjacent_ones(singletSeqMask)
    maxRun = _max_run_length(singletSeqMask, 4).rename('singletMaxRun')

    # mask & (mask << 1) sets bit k when bits k and k-1 are both set; OR-ing
    # with its right shift also marks bit k-1, i.e. every bit in some pair.
    pair_tops = singletSeqMask.bitwiseAnd(singletSeqMask.leftShift(1))
    paired = pair_tops.bitwiseOr(pair_tops.rightShift(1))
    isolatedMask = singletSeqMask.bitwiseAnd(paired.bitwiseNot()).rename('isolatedSingletMask')

    hasIsolated = isolatedMask.neq(0)
    isMax2 = maxRun.eq(2)
    composite = anySinglet.And(isMax2).And(adj11).And(hasIsolated)
    disjointSingles = anySinglet.And(adj11.Not())

    category = ee.Image(0).toInt()
    category = category.where(disjointSingles, 1)
    category = category.where(isMax2.And(disjointSingles.Not()).And(composite.Not()), 2)
    for run in (3, 4):
        category = category.where(maxRun.eq(run), run)
    category = category.where(composite, 6)
    category = category.where(anySinglet.Not(), 0)
    category = category.rename('singletCategory')

    return ee.Image.cat([
        category,
        singletSeqMask,
        singletCount,
        adj11,
        maxRun,
        isolatedMask.toInt(),
        ee.Image.cat(centers),
    ])


def detect_doublet_6yr(seq6):
    """
    Doublet flip detection for a 6-year sequence (y1..y6).

    Detects two consecutive anomalous labels X, Y inside an otherwise stable
    class A, with:
      - X != Y
      - X != A and Y != A
      - at least two consecutive A's on one side of the XY block

    Supported placements (length 6):
      (1) A X Y A A *  -> A=y1, X=y2, Y=y3, require y4 = y5 = A
      (2) A A X Y A A  -> A=y1, X=y3, Y=y4, require y2 = y5 = y6 = A
      (3) * A A X Y A  -> A=y2, X=y4, Y=y5, require y3 = y6 = A

    Args:
        seq6: ee.Image with bands y1..y6 (grouped classes).

    Returns:
        ee.Image with two bands:
          - doubletPattern (0..3): 0 none; 1/2/3 = placement matched
            (a later placement overwrites an earlier one)
          - hasDoublet: true where any placement matched
    """
    y1 = seq6.select('y1')
    y2 = seq6.select('y2')
    y3 = seq6.select('y3')
    y4 = seq6.select('y4')
    y5 = seq6.select('y5')
    y6 = seq6.select('y6')

    def xy_constraints(A, X, Y):
        # X and Y must differ from each other and from the stable class A.
        return X.neq(Y).And(X.neq(A)).And(Y.neq(A))

    # (1) A X Y A A *
    p1 = y4.eq(y1).And(y5.eq(y1)).And(xy_constraints(y1, y2, y3))

    # (2) A A X Y A A
    p2 = y2.eq(y1).And(y5.eq(y1)).And(y6.eq(y1)).And(xy_constraints(y1, y3, y4))

    # (3) * A A X Y A
    p3 = y3.eq(y2).And(y6.eq(y2)).And(xy_constraints(y2, y4, y5))

    pat = ee.Image(0).toInt() \
        .where(p1, 1) \
        .where(p2, 2) \
        .where(p3, 3) \
        .rename('doubletPattern')

    anyDoublet = p1.Or(p2).Or(p3).rename('hasDoublet')
    return pat.addBands(anyDoublet)


# ----------------------------------------------------------
# 6) Main wrapper: provide 6 yearly images -> outputs categories
# ----------------------------------------------------------
def classify_temporal_flip_categories_6yr(year_imgs, roi_fc=None):
    """
    Run singlet and doublet flip detection on a 6-year LULC sequence.

    Args:
        year_imgs: list of 6 ee.Image with band BAND (original labels),
            oldest first.
        roi_fc: optional ee.FeatureCollection used to mask the detection
            bands. Defaults to the module-level ``roi`` (previous behaviour).

    Returns:
        ee.Image with bands:
          - y1..y6: grouped class per year (not ROI-masked)
          - singletCategory, singletSeqMask, singletCount, hasAdj11,
            singletMaxRun, isolatedSingletMask, s2..s5 (singlet_category_6yr)
          - doubletPattern (0..3), hasDoublet          (detect_doublet_6yr)
        The singlet and doublet bands are masked to the ROI.

    Raises:
        ValueError: from stack_6_years if ``year_imgs`` does not hold 6 images.
    """
    grouped = [group_classes(im, band=BAND) for im in year_imgs]
    seq6 = stack_6_years(grouped)

    if roi_fc is None:
        roi_fc = roi
    roi_mask = ee.Image.constant(1).clip(roi_fc).selfMask()

    sing = singlet_category_6yr(seq6).updateMask(roi_mask)
    dbl = detect_doublet_6yr(seq6).updateMask(roi_mask)
    return seq6.addBands(sing).addBands(dbl)


# Colour ramp for the singletCategory band (values 0..6).
singlet_vis = dict(
    min=0,
    max=6,
    palette=[
        '000000',  # 0 - no singlets (black)
        '2ecc71',  # 1 - disjoint singles (green)
        'f1c40f',  # 2 - run = 2 (yellow)
        'e67e22',  # 3 - run = 3 (orange)
        'e74c3c',  # 4 - run = 4 (red)
        '8e44ad',  # 5 - run = 5 (purple)
        '34495e',  # 6 - composite (dark slate)
    ],
)

# Colour ramp for the 6-year doubletPattern band (values 0..3).
doublet_vis_6yr = dict(
    min=0,
    max=3,
    palette=[
        '000000',  # 0 - none
        '00bcd4',  # 1 - A X Y A A *
        'ff9800',  # 2 - A A X Y A A
        'e91e63',  # 3 - * A A X Y A
    ],
)

# Binary ramp for 0/1 flag bands (red = detected).
doublet_mask_vis = dict(min=0, max=1, palette=['000000', 'ff0000'])


# -----------------------------
# Example usage (replace these)
# -----------------------------
# year_imgs = [img2017, img2018, img2019, img2020, img2021, img2022]
out = classify_temporal_flip_categories_6yr(images)

# LULC palette, one colour per class code 0..13.
pallete_lulc = [
    '000000', 'ff0000', '74ccf4', '1ca3ec', '0f5e9c',
    'f1c232', '38761d', 'A9A9A9', 'BAD93E', 'f59d22',
    'FF9371', 'b3561d', 'a9a9a9', '84994F',
]
vis_params_lulc = {'min': 0, 'max': 13, 'palette': pallete_lulc}

import geemap
Map = geemap.Map()
url = 'https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}'
Map.layout.height = '1000px'
Map.add_tile_layer(url, name="Google Map", attribution="Google")

Map.addLayer(out.select('singletCategory'), singlet_vis, 'singletCategory')
Map.addLayer(out.select('doubletPattern'), doublet_vis_6yr, 'doubletPattern (6yr)')
for _year_band in (f'y{n}' for n in range(1, 7)):
    Map.addLayer(out.select(_year_band), vis_params_lulc, _year_band)

Map


Map(center=[0, 0], controls=(WidgetControl(options=['position', 'transparent_bg'], position='topright', transp…

In [6]:
# Reload the yearly images and show the band names of the first one.
images_ = load_aez_temporal_images(aez, start_year, num_years, project=project)
images_[0].bandNames()

In [2]:


# ------------------------------------------------------------
# 0) Utility: admissibility matrix (B -> A allowed?)
#    Classes used (after your grouping):
#      Built-up=1, Water=2, Tree=6, Barren=7, Crop=8, Scrub=12, Plantation=13
#    This function returns an ee.Image (0/1) for "is B->A allowed".
# ------------------------------------------------------------
def allow_flip(A, B):
    """
    Admissibility test: may an anomalous class B be corrected to class A?

    Classes (after grouping): Built-up=1, Water=2, Tree=6, Barren=7, Crop=8,
    Scrub=12, Plantation=13. Only the pairs listed in ``admissible`` are
    allowed; everything else (including A == B) yields 0.

    Args:
        A: single-band ee.Image, the class to correct towards.
        B: single-band ee.Image, the class currently present.

    Returns:
        ee.Image band 'allowFlip': 1 where B -> A is allowed, else 0.
    """
    # Target class A -> source classes B that may be replaced by A.
    admissible = {
        1: (7, 8, 12, 13),        # Built-up   <- Barren, Crop, Scrub, Plantation
        2: (6,),                  # Water      <- Tree
        6: (1, 2, 7, 8, 12, 13),  # Tree       <- Built-up, Water, Barren, Crop, Scrub, Plantation
        7: (1, 6, 8, 12, 13),     # Barren     <- Built-up, Tree, Crop, Scrub, Plantation
        8: (6, 7, 12),            # Crop       <- Tree, Barren, Scrub
        12: (6, 7, 8, 13),        # Scrub      <- Tree, Barren, Crop, Plantation
        13: (1, 6, 7, 12),        # Plantation <- Built-up, Tree, Barren, Scrub
    }
    codes = {code: ee.Image(code) for code in (1, 2, 6, 7, 8, 12, 13)}

    allowed = ee.Image(0).toUint8()
    for target, sources in admissible.items():
        for source in sources:
            allowed = allowed.Or(A.eq(codes[target]).And(B.eq(codes[source])))

    return allowed.rename('allowFlip')


# ------------------------------------------------------------
# 1) Singlet correction for 6 years (ABA at centers y2..y5)
# ------------------------------------------------------------
def apply_singlet_corrections_6yr(seq6):
    """
    Correct admissible ABA singlets in a 6-year sequence, left to right.

    For each centre year t = 2..5 the centre is replaced by its left
    neighbour A when A == y(t+1), y(t) != A and allow_flip(A, y(t)) == 1.
    The left neighbour is the already-corrected value for year t-1, so a
    correction at one centre can enable or prevent the next one.

    Args:
        seq6: ee.Image with bands y1..y6 (grouped classes).

    Returns:
        (corrected, changed): corrected is an ee.Image with bands y1..y6;
        changed is band 'changedSinglet', 1 where any centre was replaced.
    """
    years = [seq6.select(f'y{n}') for n in range(1, 7)]

    fixed = [years[0]]  # y1 is never a centre
    changed = ee.Image(0).toUint8()
    for t in range(1, 5):  # list positions 1..4 -> centres y2..y5
        left = fixed[t - 1]
        here = years[t]
        right = years[t + 1]
        hit = left.eq(right).And(here.neq(left)).And(allow_flip(left, here).eq(1))
        fixed.append(here.where(hit, left))
        changed = changed.Or(hit)
    fixed.append(years[5])  # y6 is never a centre

    out = ee.Image.cat([img.rename(f'y{n}') for n, img in enumerate(fixed, start=1)])
    return out, changed.rename('changedSinglet')


# ------------------------------------------------------------
# 2) Doublet correction for 6 years
#    We enforce:
#      - X != Y, X != A, Y != A
#      - pattern placement (3 placements for 6 years)
#      - AND admissibility: allow_flip(A,X) and allow_flip(A,Y)
# ------------------------------------------------------------
def apply_doublet_corrections_6yr(seq6):
    """
    Replace admissible XY doublets inside a stable class A with A.

    A placement matches when X != Y, X != A, Y != A, the listed years equal A,
    and allow_flip(A, X) and allow_flip(A, Y) are both 1. All placements are
    evaluated on the input sequence, and matched X and Y years are set to the
    (uncorrected) anchor value A:

      (1) A X Y A A *  : A=y1, X=y2, Y=y3, y4 = y5 = A
      (2) A A X Y A A  : A=y1, X=y3, Y=y4, y2 = y5 = y6 = A
      (3) * A A X Y A  : A=y2, X=y4, Y=y5, y3 = y6 = A

    Args:
        seq6: ee.Image with bands y1..y6 (grouped classes).

    Returns:
        (corrected, changed, pattern):
          corrected - ee.Image with bands y1..y6
          changed   - band 'changedDoublet', 1 where a placement matched
          pattern   - band 'doubletPattern' (0 none, else 1..3; a later
                      placement overwrites an earlier one)
    """
    years = {n: seq6.select(f'y{n}') for n in range(1, 7)}

    # (A year, X year, Y year, years that must also equal A)
    placements = (
        (1, 2, 3, (4, 5)),
        (1, 3, 4, (2, 5, 6)),
        (2, 4, 5, (3, 6)),
    )

    fixed = dict(years)
    changed = ee.Image(0).toUint8()
    pattern = ee.Image(0).toInt()
    for pid, (a_n, x_n, y_n, same_ns) in enumerate(placements, start=1):
        A, X, Y = years[a_n], years[x_n], years[y_n]
        hit = years[same_ns[0]].eq(A)
        for n in same_ns[1:]:
            hit = hit.And(years[n].eq(A))
        hit = hit.And(X.neq(Y).And(X.neq(A)).And(Y.neq(A)))
        hit = hit.And(allow_flip(A, X).eq(1)).And(allow_flip(A, Y).eq(1))

        pattern = pattern.where(hit, pid)
        fixed[x_n] = fixed[x_n].where(hit, A)
        fixed[y_n] = fixed[y_n].where(hit, A)
        changed = changed.Or(hit)

    out = ee.Image.cat([fixed[n].rename(f'y{n}') for n in range(1, 7)])
    return out, changed.rename('changedDoublet'), pattern.rename('doubletPattern')


# ------------------------------------------------------------
# 3) Full correction: singlets first, then doublets (stabilize the easier
#    single-year anomalies before the harder two-year ones).
# ------------------------------------------------------------
def temporal_correct_6yr(seq6, roi_fc=None):
    """
    Apply singlet corrections, then doublet corrections, to a 6-year sequence.

    Args:
        seq6: ee.Image with bands y1..y6 (grouped classes).
        roi_fc: optional ee.FeatureCollection; when given, every output band
            is masked to its footprint.

    Returns:
        ee.Image with bands y1..y6 (corrected), changedSinglet,
        changedDoublet, changedAny and doubletPattern.
    """
    after_singlets, changed_singlet = apply_singlet_corrections_6yr(seq6)
    # Doublets are searched on the singlet-corrected sequence.
    after_doublets, changed_doublet, pattern = apply_doublet_corrections_6yr(after_singlets)
    changed_any = changed_singlet.Or(changed_doublet).rename('changedAny')

    result = after_doublets.addBands([changed_singlet, changed_doublet, changed_any, pattern])
    if roi_fc is None:
        return result
    return result.updateMask(ee.Image.constant(1).clip(roi_fc).selfMask())


In [3]:
images6 = images[:6]

# Group classes per year (water -> 2, crop intensities -> 8), stack, correct.
grouped = [group_classes(im, band='predicted_label') for im in images6]
seq6 = stack_6_years(grouped)
corrected = temporal_correct_6yr(seq6, roi_fc=roi)

# QA layer: red where any singlet/doublet correction was applied.
Map.addLayer(corrected.select('changedAny'), {'min': 0, 'max': 1, 'palette': ['000000', 'ff0000']}, 'Any correction')

# Colour ramp for doublet placements 0..3.
doublet_vis_6yr = {'min': 0, 'max': 3, 'palette': ['000000', '00bcd4', 'ff9800', 'e91e63']}

# LULC palette, one colour per class code 0..13.
pallete_lulc = [
    '000000', 'ff0000', '74ccf4', '1ca3ec', '0f5e9c',
    'f1c232', '38761d', 'A9A9A9', 'BAD93E', 'f59d22',
    'FF9371', 'b3561d', 'a9a9a9', '84994F',
]
vis_params_lulc = {'min': 0, 'max': 13, 'palette': pallete_lulc}

Map.addLayer(corrected.select('doubletPattern'), doublet_vis_6yr, 'doubletPattern')
for _year_band in (f'y{n}' for n in range(1, 7)):
    Map.addLayer(corrected.select(_year_band), vis_params_lulc, f'{_year_band}|')
Map

Map(center=[0, 0], controls=(WidgetControl(options=['position', 'transparent_bg'], position='topright', transp…

In [2]:
# Show the band names of the first loaded yearly image (cell output).
images[0].bandNames()

In [3]:
import pandas as pd
from datetime import datetime, timedelta
from dateutil.relativedelta import relativedelta

import ee

# Reuse stored credentials when they work; run the interactive
# authentication flow only if initialisation fails (e.g. first run).
try:
    ee.Initialize(project='raman-461708')
except ee.EEException:
    ee.Authenticate()
    ee.Initialize(project='raman-461708')

# Agro-ecological zone to process and its boundary polygon(s).
AEZ_no = 8

roi_boundary = ee.FeatureCollection("users/mtpictd/agro_eco_regions").filter(ee.Filter.eq("ae_regcode", AEZ_no))


# Common band names every sensor is renamed to before harmonisation; the
# slope/intercept lists below are in this band order.
chastainBandNames = ['BLUE', 'GREEN', 'RED', 'NIR', 'SWIR1', 'SWIR2']

# Regression model parameters from Table-4. MSI TOA reflectance as a function of OLI TOA reflectance.
msiOLISlopes = [1.0946, 1.0043, 1.0524, 0.8954, 1.0049, 1.0002]
msiOLIIntercepts = [-0.0107, 0.0026, -0.0015, 0.0033, 0.0065, 0.0046]

# Regression model parameters from Table-5. MSI TOA reflectance as a function of ETM+ TOA reflectance.
msiETMSlopes = [1.10601, 0.99091, 1.05681, 1.0045, 1.03611, 1.04011]
msiETMIntercepts = [-0.0139, 0.00411, -0.0024, -0.0076, 0.00411, 0.00861]

# Regression model parameters from Table-6. OLI TOA reflectance as a function of ETM+ TOA reflectance.
oliETMSlopes = [1.03501, 1.00921, 1.01991, 1.14061, 1.04351, 1.05271]
oliETMIntercepts = [-0.0055, -0.0008, -0.0021, -0.0163, -0.0045, 0.00261]

# '<FROM>_<TO>' sensor pair -> [slopes, intercepts, direction].
# `direction` selects how harmonizationChastain applies the model:
#   0 -> dir0Regression: out = in * slope + intercept   (forward)
#   1 -> dir1Regression: out = (in - intercept) / slope (inverse)
# Each table models the first-named sensor as a function of the second
# (e.g. MSI = f(OLI)), so converting *from* the modelled sensor uses the
# inverse (1) and converting *to* it uses the forward model (0).
chastainCoeffDict = {
    'MSI_OLI': [msiOLISlopes, msiOLIIntercepts, 1],
    'MSI_ETM': [msiETMSlopes, msiETMIntercepts, 1],
    'OLI_ETM': [oliETMSlopes, oliETMIntercepts, 1],
    'OLI_MSI': [msiOLISlopes, msiOLIIntercepts, 0],
    'ETM_MSI': [msiETMSlopes, msiETMIntercepts, 0],
    'ETM_OLI': [oliETMSlopes, oliETMIntercepts, 0],
}


def maskL7cloud(image):
    """
    Mask cloud-contaminated pixels in a Landsat-7 Collection 2 TOA image and
    rename B1, B2, B3, B4, B5, B7 to BLUE, GREEN, RED, NIR, SWIR1, SWIR2.

    Collection 2 QA_PIXEL flags rejected: bit 1 dilated cloud, bit 3 cloud,
    bit 4 cloud shadow. A pixel is kept only if all of them are clear.
    """
    qa = image.select('QA_PIXEL')
    # In Collection 2, bit 4 alone (the previous test) is *cloud shadow*; the
    # cloud flag is bit 3, so clouds themselves were never masked before.
    reject_bits = (1 << 1) | (1 << 3) | (1 << 4)
    mask = qa.bitwiseAnd(reject_bits).eq(0)
    return image.updateMask(mask).select(['B1', 'B2', 'B3', 'B4', 'B5', 'B7']).rename('BLUE', 'GREEN', 'RED', 'NIR', 'SWIR1', 'SWIR2')


def maskL8cloud(image):
    """
    Mask cloud-contaminated pixels in a Landsat-8 Collection 2 TOA image and
    rename B2, B3, B4, B5, B6, B7 to BLUE, GREEN, RED, NIR, SWIR1, SWIR2.

    Collection 2 QA_PIXEL flags rejected: bit 1 dilated cloud, bit 2 cirrus,
    bit 3 cloud, bit 4 cloud shadow. A pixel is kept only if all are clear.
    """
    qa = image.select('QA_PIXEL')
    # In Collection 2, bit 4 alone (the previous test) is *cloud shadow*; the
    # cloud flag is bit 3, so clouds themselves were never masked before.
    reject_bits = (1 << 1) | (1 << 2) | (1 << 3) | (1 << 4)
    mask = qa.bitwiseAnd(reject_bits).eq(0)
    return image.updateMask(mask).select(['B2', 'B3', 'B4', 'B5', 'B6', 'B7']).rename('BLUE', 'GREEN', 'RED', 'NIR', 'SWIR1', 'SWIR2')


def maskS2cloudTOA(image):
    """
    Mask clouds in a Sentinel-2 TOA image using its QA60 band, then keep
    B2, B3, B4, B8, B11, B12 renamed to BLUE, GREEN, RED, NIR, SWIR1, SWIR2.

    QA60 bit 10 flags opaque clouds and bit 11 flags cirrus; a pixel is kept
    only when both bits are clear.

    NOTE(review): QA60 is reported to be unpopulated for newer Sentinel-2
    processing baselines (roughly 2022 onwards) -- verify before relying on
    this mask for recent dates.
    """
    qa_band = image.select('QA60')
    cloud_clear = qa_band.bitwiseAnd(1 << 10).eq(0)
    cirrus_clear = qa_band.bitwiseAnd(1 << 11).eq(0)
    kept = image.updateMask(cloud_clear.And(cirrus_clear))
    kept = kept.select(['B2', 'B3', 'B4', 'B8', 'B11', 'B12'])
    return kept.rename(['BLUE', 'GREEN', 'RED', 'NIR', 'SWIR1', 'SWIR2'])


def Get_L7_L8_S2_ImageCollections(inputStartDate, inputEndDate, roi_boundary):
    """
    Build cloud-masked Landsat-7, Landsat-8 and Sentinel-2 TOA collections.

    Each collection is filtered to [inputStartDate, inputEndDate) and to
    images intersecting ``roi_boundary``, then mapped through its sensor's
    masking/renaming function.

    Returns:
        (L7, L8, S2) tuple of ee.ImageCollection.
    """
    def _prepare(collection_id, mask_fn):
        # Shared date/bounds filtering, then per-sensor cloud masking.
        return (ee.ImageCollection(collection_id)
                .filterDate(inputStartDate, inputEndDate)
                .filterBounds(roi_boundary)
                .map(mask_fn))

    L7 = _prepare('LANDSAT/LE07/C02/T1_TOA', maskL7cloud)
    L8 = _prepare('LANDSAT/LC08/C02/T1_TOA', maskL8cloud)
    S2 = _prepare('COPERNICUS/S2_HARMONIZED', maskS2cloudTOA)
    return L7, L8, S2


def dir0Regression(img, slopes, intercepts):
    """Apply the forward regression y = x * slope + intercept to chastainBandNames."""
    harmonized_bands = img.select(chastainBandNames)
    scaled = harmonized_bands.multiply(slopes)
    return scaled.add(intercepts)


def dir1Regression(img, slopes, intercepts):
    """Apply the inverse regression y = (x - intercept) / slope to chastainBandNames."""
    harmonized_bands = img.select(chastainBandNames)
    shifted = harmonized_bands.subtract(intercepts)
    return shifted.divide(slopes)


def harmonizationChastain(img, fromSensor, toSensor):
    """Correct one image from ``fromSensor`` reflectance to ``toSensor`` reflectance.

    The coefficients and direction flag come from chastainCoeffDict under the key
    '<FROM>_<TO>' (upper-cased). Direction 0 applies dir0Regression
    (x * slope + intercept); any other value applies dir1Regression
    ((x - intercept) / slope).

    Args:
        img: ee.Image containing the chastainBandNames bands.
        fromSensor, toSensor: sensor names such as 'OLI', 'MSI', 'ETM' (any case).

    Returns:
        ee.Image with the corrected bands, carrying img's properties and its
        'system:time_start'.

    Raises:
        KeyError: if the sensor pair is not in chastainCoeffDict.
    """
    comboKey = fromSensor.upper() + '_' + toSensor.upper()
    slopes, intercepts, direction = chastainCoeffDict[comboKey]

    # The direction flag is a plain Python int, so choose the branch on the
    # client. A server-side ee.Algorithms.If would build (and send) both
    # regression graphs for every image.
    if direction == 0:
        out = dir0Regression(img, slopes, intercepts)
    else:
        out = dir1Regression(img, slopes, intercepts)

    # copyProperties() without a list skips system properties, so the
    # timestamp is copied explicitly. It returns an Element, so cast back.
    return ee.Image(out.copyProperties(img).copyProperties(img, ['system:time_start']))


def Harmonize_L7_L8_S2(L7, L8, S2):
    """Calibrate Landsat-8 (OLI) and Sentinel-2 (MSI) to Landsat-7 (ETM+) and merge.

    L7 is used as-is (it is the calibration target). The returned collection
    holds L7 images followed by the corrected L8 and S2 images.
    """
    def _to_etm(sensor):
        return lambda img: harmonizationChastain(img, sensor, 'ETM')

    oli_as_etm = L8.map(_to_etm('OLI'))
    msi_as_etm = S2.map(_to_etm('MSI'))
    merged = L7.merge(oli_as_etm).merge(msi_as_etm)
    return ee.ImageCollection(merged)


def addNDVI(image):
    """Append an 'NDVI' band, (NIR - RED) / (NIR + RED), and cast the image to float."""
    ndvi = image.normalizedDifference(['NIR', 'RED']).rename('NDVI')
    with_ndvi = image.addBands(ndvi)
    return with_ndvi.float()


def Get_NDVI_image_datewise(harmonized_LS_ic, roi_boundary):
    """Build a function that turns a start date into a 16-day median NDVI composite.

    The returned callable takes a date (string or ee.Date) and returns the
    median NDVI of harmonized_LS_ic over [date, date + 16 days), stamped with
    'system:time_start' = date.
    """
    ndvi_collection = harmonized_LS_ic.select(['NDVI'])
    # Fully masked NDVI image (value-0 mask inside the ROI, clipped away
    # outside). Merging it guarantees every window yields an 'NDVI' band even
    # when no real observation falls inside it.
    placeholder = ee.Image(0).float().rename(['NDVI']).updateMask(ee.Image(0).clip(roi_boundary))

    def get_NDVI_datewise(date):
        window_start = ee.Date(date)
        window_end = window_start.advance(16, 'day')
        window = ndvi_collection.filterDate(window_start, window_end).merge(placeholder)
        return window.median().set('system:time_start', window_start.millis())

    return get_NDVI_datewise

def Get_LS_16Day_NDVI_TimeSeries(inputStartDate, inputEndDate, harmonized_LS_ic, roi_boundary):
    """Return an ee.ImageCollection of 16-day median NDVI composites.

    Composite start dates run every 16 days from inputStartDate up to
    inputEndDate (both 'YYYY-MM-DD' strings); each composite is built by
    Get_NDVI_image_datewise.
    """
    date_fmt = "%Y-%m-%d"
    window_starts = pd.date_range(
        start=datetime.strptime(inputStartDate, date_fmt),
        end=datetime.strptime(inputEndDate, date_fmt),
        freq='16D',
    )
    server_dates = ee.List([stamp.strftime(date_fmt) for stamp in window_starts])
    composite_for = Get_NDVI_image_datewise(harmonized_LS_ic, roi_boundary)
    return ee.ImageCollection.fromImages(server_dates.map(composite_for))


def pairLSModis(lsRenameBands, roi_boundary=None):
    """Return a function that pairs an LSC composite with same-period MODIS NDVI.

    The returned function takes one image of the LSC collection, renames its
    bands to ``lsRenameBands`` and appends the median MOD13Q1 NDVI and
    SummaryQA within +/- 8 days of its 'system:time_start', renamed to
    'NDVI_modis' and 'SummaryQA_modis'.

    Args:
        lsRenameBands: new names for the LSC image's bands (e.g. ['NDVI_lsc']).
        roi_boundary: optional region used to filter the MODIS collection.
            Previously this silently read a module-level ``roi_boundary`` and
            raised NameError when none existed. That global is still used as a
            fallback when the argument is omitted. If neither is available, the
            spatial filter is skipped.
            NOTE(review): MOD13Q1 images in Earth Engine appear to be global
            composites, so skipping the filter should not change the result --
            confirm.
    """
    region = roi_boundary if roi_boundary is not None else globals().get('roi_boundary')

    def pair(feature):
        # `feature` is one image of the LSC ImageCollection being mapped.
        date = ee.Date(feature.get('system:time_start'))
        startDateT = date.advance(-8, 'day')
        endDateT = date.advance(8, 'day')

        # ------ MODIS VI ( We can add EVI to the band list later )
        modis_ic = ee.ImageCollection('MODIS/061/MOD13Q1') \
            .filterDate(startDateT, endDateT) \
            .select(['NDVI', 'SummaryQA'])
        if region is not None:
            modis_ic = modis_ic.filterBounds(region)
        modis = modis_ic.median().rename(['NDVI_modis', 'SummaryQA_modis'])

        return feature.rename(lsRenameBands).addBands(modis)

    return pair


def get_Pearson_Correlation_Coefficients(LSC_modis_paired_ic, roi_boundary, bandList):
    """Per-pixel Pearson correlation between two bands across a time series.

    Used to assess how well MODIS tracks LSC before gap filling. ``bandList``
    names exactly two bands. Returns an image with bands 'c' (correlation) and
    'p' (p-value) from ee.Reducer.pearsonsCorrelation.
    """
    time_by_band = LSC_modis_paired_ic.filterBounds(roi_boundary).select(bandList).toArray()
    reduced = time_by_band.arrayReduce(
        reducer=ee.Reducer.pearsonsCorrelation(),
        axes=[0],
        fieldAxis=1,
    )
    flat = reduced.arrayProject([1])
    return flat.arrayFlatten([['c', 'p']])


def gapfillLSM(LSC_modis_regression_model, LSC_bandName, modis_bandName):
    """Build a per-image function that fills LSC gaps with regressed MODIS values.

    Where ``LSC_bandName`` is masked, it is replaced by
    ``modis * scale + offset`` taken from LSC_modis_regression_model. Pixels
    still without a value are set to -1. A per-pixel weight band 'SummaryQA' is
    added: 1 where LSC data existed, otherwise a MODIS-quality weight.

    MODIS SummaryQA codes (missing treated as 3):
        0: good data -> 0.8
        1: marginal data -> 0.5
        2: snow/ice, 3: cloudy -> 0.2

    Returns:
        Function mapping an image to bands
        ['gapfilled_' + LSC_bandName, 'SummaryQA'].
    """
    def _gapfill_one(image):
        offset = LSC_modis_regression_model.select('offset')
        scale = LSC_modis_regression_model.select('scale')

        lsc = image.select(LSC_bandName)
        modis_pred = image.select(modis_bandName).multiply(scale).add(offset)

        lsc_missing = lsc.mask().Not()
        # -1 is the no-data marker for pixels MODIS cannot fill either.
        filled = lsc.unmask(-1).where(lsc_missing, modis_pred)

        qa = image.select('SummaryQA_modis').unmask(3)  # missing MODIS QA counts as cloudy
        modis_weight = (modis_pred.mask().rename('w_m')
                        .where(qa.eq(0), 0.8)
                        .where(qa.eq(1), 0.5)
                        .where(qa.gte(2), 0.2))

        # Real LSC observations keep their own mask value; filled pixels use the MODIS weight.
        weight = filled.mask().where(lsc_missing, modis_weight)

        return filled.addBands(weight).rename(['gapfilled_' + LSC_bandName, 'SummaryQA'])

    return _gapfill_one


def Combine_LS_Modis(LSC):
    """Pair each LSC composite with MODIS NDVI and gap-fill LSC from MODIS.

    A per-pixel linear fit of NDVI_lsc on NDVI_modis is estimated over the
    whole series and used by gapfillLSM to fill masked LSC pixels.

    The unused Pearson-correlation image that used to be built here was
    removed. It was never returned or consumed, and building it required a
    module-level ``roi_boundary``, raising NameError when that global was
    undefined.

    Args:
        LSC: ee.ImageCollection of NDVI composites.

    Returns:
        ee.ImageCollection with bands 'gapfilled_NDVI_lsc' and 'SummaryQA'.
    """
    # NDVI -> NDVI_lsc, so the LSC bands do not clash with the MODIS ones.
    lsRenameBands = ee.Image(LSC.first()).bandNames().map(lambda band: ee.String(band).cat('_lsc'))
    LSC_modis_paired_ic = LSC.map(pairLSModis(lsRenameBands))

    # linearFit over the time axis; output bands are 'scale' and 'offset'.
    LSC_modis_regression_model_NDVI = LSC_modis_paired_ic.select(['NDVI_modis', 'NDVI_lsc']) \
                                                            .reduce(ee.Reducer.linearFit())

    return LSC_modis_paired_ic.map(gapfillLSM(LSC_modis_regression_model_NDVI, 'NDVI_lsc', 'NDVI_modis'))


def mask_low_QA(lsmc_image):
    """Hide pixels whose 'SummaryQA' weight is exactly 0.2, keeping the timestamp.

    NOTE(review): this is an exact float comparison against 0.2 -- confirm the
    band keeps double precision so the 0.2 weights actually match.
    """
    is_acceptable = lsmc_image.select('SummaryQA').neq(0.2)
    masked = lsmc_image.updateMask(is_acceptable)
    return masked.copyProperties(lsmc_image, ['system:time_start'])


def add_timestamp(image):
    """Append a 'timestamp' band (system:time_start) masked like the image's first band."""
    first_band_mask = image.mask().select(0)
    stamp = image.metadata('system:time_start').rename('timestamp')
    return image.addBands(stamp.updateMask(first_band_mask))


def performInterpolation(image):
    """Fill masked pixels of ``image`` by linear interpolation in time.

    ``image`` must carry the 'before' and 'after' lists attached by the joins in
    interpolate_timeseries. Each list is mosaicked (last image on top), and
        y = y1 + (y2 - y1) * (t - t1) / (t2 - t1)
    fills the gaps, where y1/t1 come from the 'before' mosaic, y2/t2 from the
    'after' mosaic and t is this image's timestamp. Pixels still masked
    afterwards take the mosaic of the before/after mosaics (after on top).
    """
    image = ee.Image(image)

    def _mosaic_of(key):
        return ee.ImageCollection.fromImages(ee.List(image.get(key))).mosaic()

    before = _mosaic_of('before')
    after = _mosaic_of('after')

    times = ee.Image.cat([
        before.select('timestamp').rename('t1'),
        after.select('timestamp').rename('t2'),
        image.metadata('system:time_start').rename('t'),
    ])
    ratio = times.expression('(t - t1) / (t2 - t1)', {
        't': times.select('t'),
        't1': times.select('t1'),
        't2': times.select('t2'),
    })

    linear = before.add(after.subtract(before).multiply(ratio))
    nearest = ee.ImageCollection([before, after]).mosaic()
    filled = image.unmask(linear).unmask(nearest)

    return filled.copyProperties(image, ['system:time_start'])


def interpolate_timeseries(S1_TS):
    """Mask low-quality pixels and fill them by temporal linear interpolation.

    Despite the parameter name, the caller in this file passes the gap-filled
    LSC/MODIS NDVI collection (bands 'gapfilled_NDVI_lsc', 'SummaryQA').

    Steps: mask SummaryQA == 0.2 (mask_low_QA), add a 'timestamp' band
    (add_timestamp), attach to every image the lists of images at-or-after
    ('after') and at-or-before ('before') it within +/- timeWindow days, then
    fill each image with performInterpolation.

    Returns:
        ee.ImageCollection of interpolated images.
    """
    lsmc_masked = S1_TS.map(mask_low_QA)
    filtered = lsmc_masked.map(add_timestamp)

    # Time window in which we are willing to look forward and backward for unmasked pixel in time series
    timeWindow = 120

    # Define a maxDifference filter to find all images within the specified days. Convert days to milliseconds.
    millis = ee.Number(timeWindow).multiply(1000*60*60*24)
    # Keep only primary/secondary pairs whose timestamps are at most `millis` apart.
    maxDiffFilter = ee.Filter.maxDifference(
                                difference = millis,
                                leftField = 'system:time_start',
                                rightField = 'system:time_start',
                                )

    # primary.time <= secondary.time: the secondary image is at or after the target.
    # The target itself matches too (equal timestamps).
    lessEqFilter = ee.Filter.lessThanOrEquals(
                                leftField = 'system:time_start',
                                rightField = 'system:time_start'
                            )

    # primary.time >= secondary.time: the secondary image is at or before the target
    # (again including the target itself).
    greaterEqFilter = ee.Filter.greaterThanOrEquals(
                                leftField = 'system:time_start',
                                rightField = 'system:time_start'
                            )

    # First join: store in 'after' all images at/after the target within the timeWindow.
    # Sorted newest-first, so in performInterpolation's mosaic() (last image on top)
    # the image closest in time wins.
    filter1 = ee.Filter.And( maxDiffFilter, lessEqFilter )
    join1 = ee.Join.saveAll(
                    matchesKey = 'after',
                    ordering = 'system:time_start',
                    ascending = False
            )
    join1Result = join1.apply(
                    primary = filtered,
                    secondary = filtered,
                    condition = filter1
                    )

    # Second join: store in 'before' all images at/before the target within the timeWindow.
    # Sorted oldest-first, so the closest earlier image ends up on top of the mosaic.
    # It runs on join1Result so each image keeps its 'after' property.
    filter2 = ee.Filter.And( maxDiffFilter, greaterEqFilter )
    join2 = ee.Join.saveAll(
                    matchesKey = 'before',
                    ordering = 'system:time_start',
                    ascending = True
            )
    join2Result = join2.apply(
                    primary = join1Result,
                    secondary = join1Result,
                    condition = filter2
                    )

    interpolated_S1_TS = ee.ImageCollection(join2Result.map(performInterpolation))

    return interpolated_S1_TS


def Get_Padded_NDVI_TS_Image(startDate, endDate, roi_boundary):
    """Build the gap-filled, interpolated NDVI time series as one multi-band image.

    Pipeline: load masked L7/L8/S2 -> harmonize to ETM+ -> add NDVI ->
    16-day composites -> MODIS gap filling -> temporal interpolation ->
    stack 'gapfilled_NDVI_lsc' over time (toBands) -> clip to roi_boundary.

    Returns:
        (time-series ee.Image, its band names as an ee.List).
    """
    sensor_collections = Get_L7_L8_S2_ImageCollections(startDate, endDate, roi_boundary)
    harmonized = Harmonize_L7_L8_S2(*sensor_collections).map(addNDVI)

    composites = Get_LS_16Day_NDVI_TimeSeries(startDate, endDate, harmonized, roi_boundary)
    interpolated = interpolate_timeseries(Combine_LS_Modis(composites))

    ts_image = interpolated.select(['gapfilled_NDVI_lsc']).toBands().clip(roi_boundary)
    return ts_image, ts_image.bandNames()


def Get_Euclidean_Distance(cluster_centroids, roi_timeseries_img, input_bands, roi_boundary):
    """Compute, for every cluster centroid, a per-pixel distance image.

    Args:
        cluster_centroids: ee.FeatureCollection; each feature holds a 'class'
            property plus one property per name in ``input_bands``.
        roi_timeseries_img: time-series ee.Image whose bands are ``input_bands``.
        input_bands: ee.List of band names of ``roi_timeseries_img``.
        roi_boundary: geometry over which each centroid is rasterised.

    Returns:
        A collection (mapped over cluster_centroids) of images with the
        spectralDistance('sed') band, a 'confidence' band (1 / distance) and a
        'class' band. Each image also carries the centroid's 'class' property.
    """
    def _property_as_band(single_fc, prop_name):
        # Rasterise one numeric property of the (single-feature) collection.
        return single_fc.select([prop_name]).reduceToImage([prop_name], ee.Reducer.first()).rename([prop_name])

    def wrapper(curr_centroid):
        centroid = ee.Feature(curr_centroid).setGeometry(roi_boundary)
        centroid_fc = ee.FeatureCollection([centroid])
        class_img = _property_as_band(centroid_fc, 'class')

        per_band = input_bands.map(lambda name: _property_as_band(centroid_fc, name))
        stacked = ee.Image(per_band.iterate(
            lambda img, acc: ee.Image(acc).addBands(img),
            ee.Image(),
        ))
        # ee.Image() seeds the stack with a 'constant' band; drop it.
        centroid_img = stacked.select(stacked.bandNames().remove('constant'))

        dist = roi_timeseries_img.spectralDistance(centroid_img, 'sed')
        confidence = ee.Image(1.0).divide(dist).rename(['confidence'])
        return dist.addBands(confidence).addBands(class_img).set('class', centroid.get('class'))

    return cluster_centroids.map(wrapper)


def Get_final_prediction_image(distance_imgs_list):
    """Return, per pixel, the class of the nearest cluster centroid.

    Args:
        distance_imgs_list: collection of images (from Get_Euclidean_Distance)
            each holding 'distance' and 'class' bands for one centroid.

    Returns:
        Single-band ee.Image 'predicted_label'.
    """
    # Per pixel: a 2-D array, one row per centroid image, columns [distance, class].
    array_img = ee.ImageCollection(distance_imgs_list).select(['distance', 'class']).toArray()

    image_axis, band_axis = 0, 1
    # Sort the rows by their first column (distance), ascending.
    sort_keys = array_img.arraySlice(band_axis, 0, 1)
    sorted_rows = array_img.arraySort(sort_keys)

    # Keep the closest row and turn it back into bands.
    nearest_row = sorted_rows.arraySlice(image_axis, 0, 1)
    nearest = nearest_row.arrayProject([band_axis]).arrayFlatten([['distance', 'class']])

    # Hard classification = class column of the nearest centroid.
    return nearest.select(1).rename(['predicted_label'])

## My Helper Functions
def change_clusters(cluster_centroids):
    """Relabel centroids with consecutive 'class' ids starting at 13.

    Calls getInfo() once to learn the collection size (a blocking request).

    Returns:
        ee.FeatureCollection with the same features, 'class' = 13, 14, ...
    """
    size = cluster_centroids.size().getInfo()
    # Build the server-side list once instead of once per feature.
    centroid_list = cluster_centroids.toList(size)
    features = [
        ee.Feature(centroid_list.get(i)).set("class", 13 + i)
        for i in range(size)
    ]
    return ee.FeatureCollection(features)


def get_cropping_frequency(roi_boundary, startDate, endDate):
    """Classify each pixel's NDVI series to its nearest level-3 cluster centroid.

    Args:
        roi_boundary: region for the NDVI series and classification.
        startDate, endDate: 'YYYY-MM-DD' strings bounding the series.

    Returns:
        (classified ee.Image with bands 'predicted_label' and
         'predicted_cluster', the NDVI time-series ee.Image).
    """
    cluster_centroids = ee.FeatureCollection('projects/ee-indiasat/assets/L3_LULC_Clusters/Final_Level3_PanIndia_Clusters')
    ignore_clusters = [12]  # remove invalid clusters
    cluster_centroids = cluster_centroids.filter(ee.Filter.Not(ee.Filter.inList('class', ignore_clusters)))

    final_LSMC_NDVI_TS, input_bands = Get_Padded_NDVI_TS_Image(startDate, endDate, roi_boundary)
    distance_imgs_list = Get_Euclidean_Distance(cluster_centroids, final_LSMC_NDVI_TS, input_bands, roi_boundary)
    final_classified_img = Get_final_prediction_image(distance_imgs_list)

    # 'predicted_cluster' used to come from re-running the identical
    # distance/prediction pipeline with the same centroids, because the
    # relabelling step (change_clusters) is disabled. That always reproduced
    # 'predicted_label', so the single result is reused. To get distinct cluster
    # ids, classify again with change_clusters(cluster_centroids).
    final_cluster_classified_img = final_classified_img.rename(['predicted_cluster'])
    final_classified_img = final_classified_img.addBands(final_cluster_classified_img)
    return final_classified_img, final_LSMC_NDVI_TS

In [None]:
# ============================================================
# Temporal correction (6 years) with:
#  - Grouped correction logic (water + crop grouped)
#  - Admissibility matrix constraint (your ✓/✗ table)
#  - Restore original labels after correction:
#       * Water: use "first A-side" original class (left neighbor)
#       * Crop : use cropping frequency model output (8/9/10/11 per year)
#       * Others: write corrected grouped class directly
# ============================================================

# Earth Engine setup for this cell (repeats the notebook's first cell so the
# cell can be run on its own).
import ee
ee.Authenticate()
ee.Initialize(project='raman-461708')

# -------------------------------
# 1) LOADING YEARLY LULC IMAGES
# -------------------------------
def load_aez_temporal_images(aez, start_year, num_years, project='raman-461708'):
    """Load ``num_years`` consecutive yearly Pan-India LULC v4 label images.

    Year k covers <year>-07-01 to <year+1>-06-30, starting at ``start_year``.
    Only the 'predicted_label' band is kept. ``aez`` is accepted for call
    compatibility but is not used (the assets are Pan-India).

    Returns:
        list of ee.Image, oldest first.
    """
    asset_prefix = f"projects/{project}/assets/LULC_v4_PanIndia_"
    return [
        ee.Image(f"{asset_prefix}{year}-07-01_{year + 1}-06-30").select('predicted_label')
        for year in range(start_year, start_year + num_years)
    ]

# -------------------------------
# 2) CLASS GROUPING
# -------------------------------
# Original LULC label -> grouped label (water 2/3/4 -> 2, crop 8/9/10/11 -> 8;
# every other class maps to itself).
# NOTE(review): label 5 is absent, so ee.Image.remap masks any pixel carrying it.
_LABEL_GROUPS = {
    0: 0, 1: 1,
    2: 2, 3: 2, 4: 2,
    6: 6, 7: 7,
    8: 8, 9: 8, 10: 8, 11: 8,
    12: 12, 13: 13,
}
ORIG = list(_LABEL_GROUPS)
GROUP = list(_LABEL_GROUPS.values())
BAND = 'predicted_label'

def group_classes(img, band=BAND):
    """Remap ``band`` of ``img`` from ORIG labels to GROUP labels, named 'predicted_label'."""
    labels = img.select([band])
    grouped = labels.remap(ORIG, GROUP)
    return grouped.rename('predicted_label')

# -------------------------------
# 3) STACK ORIGINAL + GROUPED (6 years)
# -------------------------------
def stack_6_years_original(year_imgs, band=BAND):
    """Stack six yearly label images into one image with bands y1o..y6o.

    Args:
        year_imgs: sequence of exactly six ee.Image, oldest first.
        band: label band to take from each image.

    Raises:
        ValueError: if ``year_imgs`` does not hold exactly six images. The
            downstream corrections select y1..y6 explicitly, so a wrong count
            would otherwise fail later and obscurely on the server (same check
            as stack_6_years).
    """
    if len(year_imgs) != 6:
        raise ValueError("Provide exactly 6 yearly images.")
    return ee.Image.cat([
        im.select([band]).rename(f'y{i}o')
        for i, im in enumerate(year_imgs, start=1)
    ])

def stack_6_years_grouped(year_imgs, band=BAND):
    """Group each yearly label image (group_classes) and stack as bands y1..y6.

    Args:
        year_imgs: sequence of exactly six ee.Image, oldest first.
        band: label band to take from each image.

    Raises:
        ValueError: if ``year_imgs`` does not hold exactly six images. The
            downstream corrections select y1..y6 explicitly (same check as
            stack_6_years).
    """
    if len(year_imgs) != 6:
        raise ValueError("Provide exactly 6 yearly images.")
    return ee.Image.cat([
        group_classes(im, band=band).rename(f'y{i}')
        for i, im in enumerate(year_imgs, start=1)
    ])

# -------------------------------
# 4) ADMISSIBILITY MATRIX (B -> A allowed?)
# Classes (after grouping): Built-up=1, Water=2, Tree=6, Barren=7, Crop=8, Scrub=12, Plantation=13
# Using YOUR matrix exactly.
# -------------------------------
def allow_flip(A, B):
    """Return a 0/1 image 'allowFlip': 1 where relabelling grouped class B as A is admissible.

    A is the target class (matrix row), B the class being replaced (column).
    Grouped codes: Built-up=1, Water=2, Tree=6, Barren=7, Crop=8, Scrub=12,
    Plantation=13.
    """
    # (target A, admissible sources B) -- the checkmarks of the matrix.
    admissible = (
        (1,  (7, 8, 12, 13)),        # Built-up   <- Barren, Crop, Scrub, Plantation
        (2,  (6,)),                  # Water      <- Tree
        (6,  (1, 2, 7, 8, 12, 13)),  # Tree       <- Built-up, Water, Barren, Crop, Scrub, Plantation
        (7,  (1, 6, 8, 12, 13)),     # Barren     <- Built-up, Tree, Crop, Scrub, Plantation
        (8,  (6, 7, 12)),            # Crop       <- Tree, Barren, Scrub
        (12, (6, 7, 8, 13)),         # Scrub      <- Tree, Barren, Crop, Plantation
        (13, (1, 6, 7, 12)),         # Plantation <- Built-up, Tree, Barren, Scrub
    )

    allowed = ee.Image(0).toUint8()
    for target, sources in admissible:
        a_is_target = A.eq(ee.Image(target))
        for source in sources:
            allowed = allowed.Or(a_is_target.And(B.eq(ee.Image(source))))

    return allowed.rename('allowFlip')

# -------------------------------
# 5) SINGLET (ABA) CORRECTION on GROUPED seq6
#    Also records, per corrected year, "Aorig" = left A-side original label.
# -------------------------------
def apply_singlet_corrections_6yr_with_sources(seq6g, seq6o):
    """Correct single-year flips (A B A -> A A A) in a grouped 6-year stack.

    Centres y2..y5 are processed left to right. Each centre is compared with
    its already-corrected left neighbour (A) and its not-yet-corrected right
    neighbour. The centre becomes A when both neighbours agree, the centre
    differs from A, and allow_flip(A, centre) permits it.

    Args:
        seq6g: grouped stack with bands y1..y6.
        seq6o: original-label stack with bands y1o..y6o.

    Returns:
        (corrected_g, sources):
          corrected_g -- bands y1..y6 after singlet correction;
          sources     -- changed_y2..changed_y5 (uint8 0/1), then
                         Aorig_y2..Aorig_y5: the ORIGINAL label of the left
                         neighbour, meaningful only where changed_y{i} is 1.
    """
    grouped = [seq6g.select(f'y{i}') for i in range(1, 7)]
    original = [seq6o.select(f'y{i}o') for i in range(1, 7)]

    corrected = list(grouped)  # corrected[k] holds year k + 1
    changed_bands = []
    aorig_bands = []
    for k in range(1, 5):      # 0-based indices of the centre years y2..y5
        year = k + 1
        left = corrected[k - 1]   # A: left neighbour, already corrected
        centre = grouped[k]       # B: uncorrected centre
        right = grouped[k + 1]    # right neighbour, not yet corrected
        flip = left.eq(right).And(centre.neq(left)).And(allow_flip(left, centre).eq(1))
        corrected[k] = centre.where(flip, left)
        changed_bands.append(flip.rename(f'changed_y{year}').toUint8())
        # A flip at year-1 makes that year equal to this centre, so a flip here
        # only fires where the left year was not itself changed. Its original
        # label is therefore the A-side source used for water restoration.
        aorig_bands.append(original[k - 1].rename(f'Aorig_y{year}'))

    corrected_g = ee.Image.cat([img.rename(f'y{i}') for i, img in enumerate(corrected, start=1)])
    sources = ee.Image.cat(changed_bands + aorig_bands)
    return corrected_g, sources

# -------------------------------
# 6) DOUBLET CORRECTION on GROUPED seq6
#    Patterns:
#      (1) A X Y A A *
#      (2) A A X Y A A
#      (3) * A A X Y A
#    plus admissibility for BOTH X->A and Y->A.
#    Also records A-orig sources for water restoration.
# -------------------------------
def apply_doublet_corrections_6yr_with_sources(seq6g, seq6o):
    """Correct two-year excursions (X Y between runs of A) in a grouped 6-year stack.

    Patterns (X, Y distinct from each other and from A, and both X->A and
    Y->A admissible per allow_flip):
        1: A X Y A A *   -> y2, y3 := y1
        2: A A X Y A A   -> y3, y4 := y1
        3: * A A X Y A   -> y4, y5 := y2

    Returns:
        (corrected_g, sources):
          corrected_g -- bands y1..y6 after doublet correction;
          sources     -- changed_y2_d..changed_y5_d (uint8 0/1),
                         Aorig_y2_d..Aorig_y5_d (original label of the A year:
                         y1o for patterns 1/2, y2o for pattern 3),
                         doubletPattern (0 = none, else 1..3).
    """
    g = {i: seq6g.select(f'y{i}') for i in range(1, 7)}
    orig_y1 = seq6o.select('y1o')
    orig_y2 = seq6o.select('y2o')

    def doublet_match(a_year, x_year, y_year, same_years):
        a, x, y = g[a_year], g[x_year], g[y_year]
        cond = g[same_years[0]].eq(a)
        for yr in same_years[1:]:
            cond = cond.And(g[yr].eq(a))
        distinct = x.neq(y).And(x.neq(a)).And(y.neq(a))
        cond = cond.And(distinct)
        return cond.And(allow_flip(a, x).eq(1)).And(allow_flip(a, y).eq(1))

    p1 = doublet_match(1, 2, 3, (4, 5))      # A X Y A A *
    p2 = doublet_match(1, 3, 4, (2, 5, 6))   # A A X Y A A
    p3 = doublet_match(2, 4, 5, (3, 6))      # * A A X Y A

    pattern = (ee.Image(0).toInt()
               .where(p1, 1).where(p2, 2).where(p3, 3)
               .rename('doubletPattern'))

    corrected_g = ee.Image.cat([
        g[1].rename('y1'),
        g[2].where(p1, g[1]).rename('y2'),
        g[3].where(p1, g[1]).where(p2, g[1]).rename('y3'),
        g[4].where(p2, g[1]).where(p3, g[2]).rename('y4'),
        g[5].where(p3, g[2]).rename('y5'),
        g[6].rename('y6'),
    ])

    sources = ee.Image.cat([
        p1.rename('changed_y2_d').toUint8(),
        p1.Or(p2).rename('changed_y3_d').toUint8(),
        p2.Or(p3).rename('changed_y4_d').toUint8(),
        p3.rename('changed_y5_d').toUint8(),
        orig_y1.rename('Aorig_y2_d'),
        orig_y1.rename('Aorig_y3_d'),
        orig_y1.where(p3, orig_y2).rename('Aorig_y4_d'),
        orig_y2.rename('Aorig_y5_d'),
        pattern,
    ])

    return corrected_g, sources

# -------------------------------
# 7) CROPPING FREQUENCY (your function exists already)
#    Assumption from you: it returns predicted_label in {8,9,10,11}.
#    We call it per-year to get cropfreq_y1..cropfreq_y6.
# -------------------------------
def get_cropfreq_images_6yr(roi_boundary, start_year):
    """Run get_cropping_frequency for six agricultural years (Jul 1 - Jun 30).

    NOTE(review): assumes the returned 'predicted_label' is a crop sub-class
    (8/9/10/11) -- not verifiable from this cell.

    Returns:
        ee.Image with bands cropfreq_y1..cropfreq_y6.
    """
    season_bands = []
    for idx, season_year in enumerate(range(start_year, start_year + 6), start=1):
        season_start = f"{season_year}-07-01"
        season_end = f"{season_year + 1}-06-30"
        classified, _ = get_cropping_frequency(roi_boundary, season_start, season_end)
        season_bands.append(classified.select('predicted_label').rename(f'cropfreq_y{idx}'))
    return ee.Image.cat(season_bands)

# -------------------------------
# 8) RESTORE ORIGINAL LABELS AFTER GROUPED CORRECTION
#    - unchanged -> keep original
#    - changed to water-group (2) -> use left A-side original label (from Aorig bands)
#       preference: doublet source overrides singlet source
#    - changed to crop-group (8) -> use cropfreq (8/9/10/11)
#    - changed to others -> write grouped corrected label (same in original)
# -------------------------------
def reconstruct_original_labels_6yr(seq6o, seq6g_before, seq6g_after, sing_src, dbl_src, cropfreq6):
    """Map the grouped correction back to original (un-grouped) labels.

    Per year:
      - grouped label unchanged -> keep the original label;
      - changed to any class    -> write the corrected grouped label, then refine:
          * water (2), years 2..5: the A-side original label (doublet source
            where the doublet changed the year, else singlet source), if it is
            a water sub-class (2..4);
          * crop (8): the cropping-frequency label, if it is a crop sub-class
            (8..11).

    ee.Image.where keeps the input wherever the replacement is masked. So
    previously a masked or non-crop cropfreq pixel silently reverted a
    corrected crop pixel to its pre-correction label. Writing the grouped label
    first makes the correction stick in that case.

    Returns:
        ee.Image with bands y1_final..y6_final.
    """
    WATER, CROP = 2, 8
    out_bands = []
    for i in range(1, 7):
        yo = seq6o.select(f'y{i}o')
        gb = seq6g_before.select(f'y{i}')
        ga = seq6g_after.select(f'y{i}')

        changed = gb.neq(ga)

        # Baseline: corrected grouped label wherever the correction changed it
        # (final answer for classes 1, 6, 7, 12, 13).
        y_final = yo.where(changed, ga)

        # WATER restoration (only possible for years 2..5 from our singlet/doublet ops)
        if i in (2, 3, 4, 5):
            cd = dbl_src.select(f'changed_y{i}_d')
            Aorig = sing_src.select(f'Aorig_y{i}').where(cd.eq(1), dbl_src.select(f'Aorig_y{i}_d'))
            aorig_is_water = Aorig.gte(2).And(Aorig.lte(4))
            y_final = y_final.where(changed.And(ga.eq(WATER)).And(aorig_is_water), Aorig)

        # CROP restoration
        cf = cropfreq6.select(f'cropfreq_y{i}')
        cf_is_crop = cf.gte(8).And(cf.lte(11))
        y_final = y_final.where(changed.And(ga.eq(CROP)).And(cf_is_crop), cf)

        out_bands.append(y_final.rename(f'y{i}_final'))

    return ee.Image.cat(out_bands)

# -------------------------------
# 9) MAIN DRIVER (copy-paste)
# -------------------------------
def temporal_correct_and_restore_6yr(images6, roi_boundary, start_year, roi_fc=None):
    """Run singlet + doublet temporal correction on six yearly LULC maps and restore labels.

    Args:
        images6: six yearly label images, oldest first.
        roi_boundary: region passed to the per-year cropping-frequency model.
        start_year: first agricultural year (Jul 1 - Jun 30).
        roi_fc: optional geometry/collection; when given, the output is masked to it.

    Returns:
        ee.Image concatenating, in order: y1o..y6o, grouped input y1..y6,
        grouped corrected y1..y6, singlet sources, doublet sources,
        cropfreq_y1..y6, y1_final..y6_final, changedAny.
    """
    original = stack_6_years_original(images6)
    grouped = stack_6_years_grouped(images6)

    after_singlets, singlet_src = apply_singlet_corrections_6yr_with_sources(grouped, original)
    after_doublets, doublet_src = apply_doublet_corrections_6yr_with_sources(after_singlets, original)

    cropfreq = get_cropfreq_images_6yr(roi_boundary, start_year)

    restored = reconstruct_original_labels_6yr(
        original, grouped, after_doublets, singlet_src, doublet_src, cropfreq)

    any_change = ee.Image(0).toUint8()
    for yr in range(1, 7):
        band = f'y{yr}'
        any_change = any_change.Or(grouped.select(band).neq(after_doublets.select(band)))

    stacked = ee.Image.cat([
        original,
        grouped,
        after_doublets,
        singlet_src,
        doublet_src,
        cropfreq,
        restored,
        any_change.rename('changedAny'),
    ])

    if roi_fc is None:
        return stacked
    inside_roi = ee.Image.constant(1).clip(roi_fc).selfMask()
    return stacked.updateMask(inside_roi)

# ============================================================
# =====================  EXAMPLE USAGE  ======================
# ============================================================

# Inputs
AEZ_no = 8
start_year = 2017

roi = ee.FeatureCollection("users/mtpictd/agro_eco_regions").filter(ee.Filter.eq("ae_regcode", AEZ_no))

# The NDVI pipeline (Combine_LS_Modis / pairLSModis) reads a module-level
# `roi_boundary`. Bind it to this run's AEZ so the result does not depend on a
# value left over from another cell (or fail with NameError if none exists).
roi_boundary = roi.geometry()

# Load 6 yearly LULC images
images = load_aez_temporal_images(AEZ_no, start_year, num_years=6, project='raman-461708')
images6 = images[:6]

# Run correction + restoration
# NOTE: get_cropping_frequency(...) must already be defined in your notebook/script.
out = temporal_correct_and_restore_6yr(images6, roi_boundary=roi, start_year=start_year, roi_fc=roi)

# ============================================================
# =====================  VISUALIZATION  ======================
# ============================================================
import geemap
Map = geemap.Map()
url = 'https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}'
Map.layout.height = '900px'
Map.add_tile_layer(url, name="Google Map", attribution="Google")

# Red where the grouped correction changed at least one year.
Map.addLayer(out.select('changedAny'), {'min': 0, 'max': 1, 'palette': ['000000', 'ff0000']}, 'changedAny')

# Which doublet pattern fired (0 = none, 1..3 = pattern id).
doublet_vis_6yr = {'min': 0, 'max': 3, 'palette': ['000000', '00bcd4', 'ff9800', 'e91e63']}
Map.addLayer(out.select('doubletPattern'), doublet_vis_6yr, 'doubletPattern')

# One colour per original LULC label 0..13.
pallete_lulc = [
    '000000', 'ff0000', '74ccf4', '1ca3ec', '0f5e9c',
    'f1c232', '38761d', 'A9A9A9', 'BAD93E', 'f59d22',
    'FF9371', 'b3561d', 'a9a9a9', '84994F',
]
vis_params_lulc = {'min': 0, 'max': 13, 'palette': pallete_lulc}

# Original labels for every year first, then the restored final labels.
for yr in range(1, 7):
    Map.addLayer(out.select(f'y{yr}o'), vis_params_lulc, f'y{yr} original')
for yr in range(1, 7):
    Map.addLayer(out.select(f'y{yr}_final'), vis_params_lulc, f'y{yr} final')

Map


Map(center=[0, 0], controls=(WidgetControl(options=['position', 'transparent_bg'], position='topright', transp…

TypeError: 'NoneType' object is not iterable

TypeError: 'NoneType' object is not iterable

TypeError: 'NoneType' object is not iterable

In [None]:
#Rough

import ee

# Authenticate and bind the Earth Engine client to Cloud project
# 'raman-461708', which owns the LULC assets loaded below
# (projects/<project>/assets/...).
ee.Authenticate()
ee.Initialize(project='raman-461708')


# -------------------------------
# 1. LOADING

# -------------------------------

def load_aez_temporal_images(aez, start_year, num_years, project='raman-461708'):
    """Load one Pan-India LULC v4 image per agricultural year.

    Agricultural year i spans ``{start_year+i}-07-01`` to
    ``{start_year+i+1}-06-30``. The asset for that year is
    ``projects/{project}/assets/LULC_v4_PanIndia_<start>_<end>``.

    Args:
        aez: Agro-ecological zone code. It does not affect the asset path
            (only the Pan-India mosaic is read); restricting to a zone must
            happen downstream. Kept for call compatibility.
        start_year: Calendar year in which the first agricultural year starts.
        num_years: Number of consecutive agricultural years to load; a value
            <= 0 yields an empty list.
        project: Earth Engine Cloud project that owns the assets.

    Returns:
        list of ee.Image in chronological order, with all original bands
        (downstream ``group_classes`` selects the label band).
    """
    base = f"projects/{project}/assets/LULC_v4_PanIndia_"
    return [
        ee.Image(f"{base}{year}-07-01_{year + 1}-06-30")
        for year in range(start_year, start_year + num_years)
    ]

aez = 10
start_year = 2017
num_years = 6
project = 'raman-461708'

# AEZ boundary geometry used to mask the classification outputs. Filter on the
# same `aez` used for loading (rather than a hard-coded code) so changing `aez`
# above keeps the boundary in sync. `roi` and `roi_boundary` name the same
# geometry.
roi = ee.FeatureCollection("users/mtpictd/agro_eco_regions") \
    .filter(ee.Filter.eq("ae_regcode", aez)).geometry()
roi_boundary = roi

# 1. Load
images = load_aez_temporal_images(aez, start_year, num_years, project)


# -----------------------------
# 0) Class grouping: merge water and crop sub-classes
# -----------------------------
# Original label -> grouped label, in the order the labels are listed.
#   water: 2, 3, 4      -> 2
#   crop:  8, 9, 10, 11 -> 8
# Every other listed class maps to itself.
# NOTE: label 5 is not listed; ee.Image.remap (called without a default value)
# masks any label that has no entry here.
_CLASS_GROUPING = {
    0: 0,
    1: 1,
    2: 2, 3: 2, 4: 2,            # water
    6: 6,
    7: 7,
    8: 8, 9: 8, 10: 8, 11: 8,    # crop
    12: 12,
    13: 13,
}

# Parallel from/to lists in the shape ee.Image.remap expects.
ORIG = list(_CLASS_GROUPING.keys())
GROUP = list(_CLASS_GROUPING.values())

BAND = 'predicted_label'   # name of the label band in the source images


def group_classes(img, band=BAND):
    """Return *img*'s label band with the water and crop classes merged.

    The ``band`` band is remapped from ``ORIG`` to ``GROUP``. The result is
    always named ``'predicted_label'``, whatever the source band was called,
    so downstream stacking can select it by that fixed name.
    """
    label = img.select([band])
    grouped = label.remap(ORIG, GROUP)
    return grouped.rename('predicted_label')


# ----------------------------------------------------------
# 1) Stack 6 yearly layers into one 6-band image: y1..y6
# ----------------------------------------------------------
def stack_6_years(year_imgs):
    """Stack six single-year label images into one 6-band image.

    Args:
        year_imgs: Sequence of exactly 6 ee.Image objects in chronological
            order, each carrying a ``'predicted_label'`` band (as produced by
            ``group_classes``).

    Returns:
        ee.Image with bands ``y1`` .. ``y6``, year 1 first.

    Raises:
        ValueError: if ``year_imgs`` does not contain exactly 6 images.
    """
    if len(year_imgs) != 6:
        raise ValueError("Provide exactly 6 yearly images.")

    return ee.Image.cat([
        im.select(['predicted_label']).rename(f'y{i}')
        for i, im in enumerate(year_imgs, start=1)
    ])



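# ----------------------------------------------------------
# 2) Bit-mask helpers for integer mask images: count set bits, detect an
#    adjacent '11' pair, and measure the longest run of consecutive 1s.
#    (A year-to-year flip mask f2..f7 with ft = (yt != y(t-1)), packed as
#    bit5..bit0, is described in earlier notes but not computed here.)
# ----------------------------------------------------------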


def _count_ones_nbit(mask, nbits):
    """Count the set bits among the low *nbits* bits of integer image *mask*.

    Returns an int image with band name ``'singletCount'``.
    """
    bit_images = []
    for shift in range(nbits):
        # Isolate bit `shift` as a 0/1 image.
        bit_images.append(mask.rightShift(shift).bitwiseAnd(1))
    total = ee.Image.cat(bit_images).reduce(ee.Reducer.sum())
    return total.rename('singletCount').toInt()

def _has_adjacent_ones(mask):
    """Return a 0/1 image (band ``'hasAdj11'``) that is 1 wherever *mask*
    has two neighbouring set bits, i.e. a '11' pair anywhere in the mask."""
    # Bit i of the overlap is set exactly when bits i and i-1 are both set.
    shifted = mask.leftShift(1)
    overlap = mask.bitwiseAnd(shifted)
    return overlap.neq(0).rename('hasAdj11')

def _has_run_k(mask, nbits, k):
    """Return an image that is 1 where *mask* has k consecutive set bits.

    Every k-bit window at offsets 0 .. nbits-k is tested; a longer run also
    contains a k-bit window, so this is "run length >= k". When k > nbits no
    window fits and the constant 0 image is returned.
    """
    window = (1 << k) - 1
    result = ee.Image(0)
    offset = 0
    while offset <= nbits - k:
        pattern = ee.Image(window << offset).toInt()
        # All k bits of this window are set iff (mask & pattern) == pattern.
        result = result.Or(mask.bitwiseAnd(pattern).eq(pattern))
        offset += 1
    return result

def _max_run_length(mask, nbits):
    """Return the length of the longest run of consecutive 1s in *mask*.

    Runs of length >= 2 are searched within the low *nbits* bits, so the
    result lies in 0 .. nbits: 0 where mask == 0, at least 1 wherever mask is
    non-zero. Output is an int image with band ``'singletMaxRun'``.

    Args:
        mask: Integer ee.Image holding the bit mask.
        nbits: Number of meaningful low bits. The 6-year singlet mask uses 4
            bits (centres y2..y5), so its longest possible run is 4.
    """
    any1 = mask.neq(0)
    maxRun = ee.Image(0).where(any1, 1)
    # k rises from 2 to nbits; each hit overwrites the smaller value, leaving
    # the largest k for which a run of that length exists.
    for k in range(2, nbits + 1):
        maxRun = maxRun.where(_has_run_k(mask, nbits, k), k)
    return maxRun.rename('singletMaxRun').toInt()

def singlet_category_6yr(seq6):
    """
    Classify ABA "singlet" flips in a 6-year label sequence (bands y1..y6).

    A singlet centred on year t means y(t-1) == y(t+1) and y(t) != y(t-1):
    a one-year excursion that returns to the same class. Only y2..y5 can be
    centres, giving a 4-bit mask (0..15) with s2 in the most significant bit
    (bit3) and s5 in the least significant bit (bit0).

    Args:
        seq6: ee.Image with bands 'y1'..'y6' (e.g. from stack_6_years).

    Output bands:
      - singletCategory (0..6):
          0 no singlet
          1 singlets present, none adjacent (no '11' in the mask)
          2 longest run of adjacent singlets is 2 (and not composite)
          3 longest run is 3
          4 longest run is 4
          5 not produced here (would need a run of 5; the mask has 4 bits)
          6 composite: longest run is 2 AND at least one isolated singlet
      - singletSeqMask (0..15)
      - singletCount   (0..4)
      - hasAdj11       (0/1)
      - singletMaxRun  (0..4)
      - isolatedSingletMask (mask bits that belong to no adjacent pair)
      - s2..s5         (0/1 per-centre singlet flags)
    """
    y = [seq6.select(f'y{i}') for i in range(1, 7)]  # y1..y6

    # One 0/1 flag per possible centre year: 1 where the neighbours agree and
    # the centre differs from them (A-B-A).
    centers = []
    for idx in range(1, 5):  # idx=1..4 corresponds to y2..y5
        left  = y[idx - 1]
        mid   = y[idx]
        right = y[idx + 1]
        s = left.eq(right).And(mid.neq(left)).rename(f's{idx+1}').toUint8()  # s2..s5
        centers.append(s)

    # Pack s2..s5 into 4-bit mask: s2 is MSB (bit3), s5 is LSB (bit0)
    mask = ee.Image(0).toInt()
    for j, s in enumerate(centers):  # j=0..3 => s2..s5
        mask = mask.bitwiseOr(s.toInt().leftShift(3 - j))
    singletSeqMask = mask.rename('singletSeqMask')

    # count ones in 4-bit mask
    singletCount = _count_ones_nbit(singletSeqMask, 4).rename('singletCount')
    anySinglet = singletSeqMask.neq(0)

    adj11 = _has_adjacent_ones(singletSeqMask)  # works for any bit-length
    maxRun = _max_run_length(singletSeqMask, 4).rename('singletMaxRun')

    # Isolated singlet bits (not part of any adjacent pair).
    # bits_in_adj has bit i set when mask bits i and i-1 are both set; OR-ing
    # it with its right shift also marks bit i-1, so bits_in_adj_covered flags
    # every bit that belongs to some adjacent pair. Clearing those from the
    # mask leaves the isolated singlets.
    bits_in_adj = singletSeqMask.bitwiseAnd(singletSeqMask.leftShift(1))
    bits_in_adj_covered = bits_in_adj.bitwiseOr(bits_in_adj.rightShift(1))
    isolatedMask = singletSeqMask.bitwiseAnd(bits_in_adj_covered.bitwiseNot()).rename('isolatedSingletMask')

    hasIsolated = isolatedMask.neq(0)
    isMax2 = maxRun.eq(2)

    # Composite: exactly one adjacent pair (run of 2) plus a separate singlet.
    composite = anySinglet.And(isMax2).And(adj11).And(hasIsolated)
    # Disjoint: singlets exist but no two are adjacent.
    disjointSingles = anySinglet.And(adj11.Not())

    # Category logic (same labels; note max run is now 4, so case 5 won't occur).
    # Each .where() overwrites earlier assignments, so later lines win:
    # composite (6) overrides the run-length code 2, and the final
    # .where(anySinglet.Not(), 0) forces 0 wherever there is no singlet.
    cat = ee.Image(0).toInt() \
        .where(disjointSingles, 1) \
        .where(isMax2.And(disjointSingles.Not()).And(composite.Not()), 2) \
        .where(maxRun.eq(3), 3) \
        .where(maxRun.eq(4), 4) \
        .where(composite, 6) \
        .where(anySinglet.Not(), 0) \
        .rename('singletCategory')

    return ee.Image.cat([
        cat,
        singletSeqMask,
        singletCount,
        adj11,
        maxRun,
        isolatedMask.toInt(),
        ee.Image.cat(centers)  # s2..s5
    ])


def detect_doublet_6yr(seq6):
    """
    Detect "doublet" flips in a 6-year label sequence (bands y1..y6).

    A doublet is two consecutive anomalous labels X, Y inside an otherwise
    stable class A, with:
      - X != Y
      - X != A and Y != A
      - at least two consecutive A's on one side of the XY block and at least
        one A on the other side

    Supported placements (length-6):
      (1) A X Y A A *   -> A=y1, X=y2, Y=y3, require y4=y5=A
      (2) A A X Y A A   -> A=y1, X=y3, Y=y4, require y2=y5=y6=A
      (3) * A A X Y A   -> A=y2, X=y4, Y=y5, require y3=y6=A

    The three placements are mutually exclusive: (1) needs y2 != y1 while (2)
    needs y2 == y1, and (3) needs y3 == y2 while (1) and (2) need y3 to differ
    from the value before it. At most one therefore matches per pixel.

    Args:
        seq6: ee.Image with bands 'y1'..'y6' (e.g. from stack_6_years).

    Returns:
        ee.Image with two bands:
          - doubletPattern (0..3): 0 none; 1/2/3 = which placement matched
          - hasDoublet (0/1)
    """
    y1, y2, y3, y4, y5, y6 = (seq6.select(f'y{i}') for i in range(1, 7))

    def xy_constraints(A, X, Y):
        """1 where X and Y differ from each other and both differ from A."""
        return X.neq(Y).And(X.neq(A)).And(Y.neq(A))

    # (1) A X Y A A *   with A=y1, X=y2, Y=y3
    p1 = y4.eq(y1).And(y5.eq(y1)).And(xy_constraints(y1, y2, y3))

    # (2) A A X Y A A   with A=y1, X=y3, Y=y4
    p2 = y2.eq(y1).And(y5.eq(y1)).And(y6.eq(y1)).And(xy_constraints(y1, y3, y4))

    # (3) * A A X Y A   with A=y2, X=y4, Y=y5
    p3 = y3.eq(y2).And(y6.eq(y2)).And(xy_constraints(y2, y4, y5))

    pat = ee.Image(0).toInt() \
        .where(p1, 1) \
        .where(p2, 2) \
        .where(p3, 3) \
        .rename('doubletPattern')

    anyDoublet = p1.Or(p2).Or(p3).rename('hasDoublet')
    return pat.addBands(anyDoublet)


# ----------------------------------------------------------
# 6) Main wrapper: provide 6 yearly images -> outputs categories
# ----------------------------------------------------------
def classify_temporal_flip_categories_6yr(year_imgs, region=None):
    """Group, stack and classify six yearly LULC images.

    Args:
        year_imgs: Exactly 6 ee.Image objects in chronological order, each
            with a ``BAND`` label band.
        region: Geometry / Feature / FeatureCollection used to mask the
            singlet and doublet outputs. When None (the default), the
            module-level ``roi`` is used, looked up at call time, which
            preserves the previous behaviour.

    Returns:
        ee.Image with bands y1..y6 (grouped labels, NOT masked to the region),
        followed by the singlet bands from ``singlet_category_6yr`` and the
        doublet bands from ``detect_doublet_6yr``, both masked to ``region``.

    Raises:
        ValueError: from ``stack_6_years`` if not exactly 6 images are given.
    """
    grouped = [group_classes(im, band=BAND) for im in year_imgs]
    seq6 = stack_6_years(grouped)

    if region is None:
        region = roi  # module-level AEZ boundary geometry

    # Constant 1 inside the region, masked outside.
    roi_mask = ee.Image.constant(1).clip(region).selfMask()

    sing = singlet_category_6yr(seq6).updateMask(roi_mask)
    dbl = detect_doublet_6yr(seq6).updateMask(roi_mask)

    return seq6.addBands(sing).addBands(dbl)


# Colour tables (min / max / palette) for the classification layers.

singlet_vis = dict(
    min=0,
    max=6,
    palette=[
        '000000',  # 0 - No flips (black)
        '2ecc71',  # 1 - Disjoint singles (green)
        'f1c40f',  # 2 - Run=2 (yellow)
        'e67e22',  # 3 - Run=3 (orange)
        'e74c3c',  # 4 - Run=4 (red)
        '8e44ad',  # 5 - Run=5 (purple); not emitted by the 6-year classifier
        '34495e',  # 6 - Composite (dark slate)
    ],
)

doublet_vis_6yr = dict(
    min=0,
    max=3,
    palette=[
        '000000',  # 0 none
        '00bcd4',  # 1 pattern: A X Y A A *
        'ff9800',  # 2 pattern: A A X Y A A
        'e91e63',  # 3 pattern: * A A X Y A
    ],
)

doublet_mask_vis = dict(
    min=0,
    max=1,
    palette=['000000', 'ff0000'],  # red = detected
)


# -----------------------------
# Example usage
# -----------------------------
# `images` is the list returned by load_aez_temporal_images above (six years
# starting 2017-07-01 with the settings above). Any list of six
# chronologically ordered label images can be passed instead, e.g.
# year_imgs = [img2017, img2018, img2019, img2020, img2021, img2022]
out = classify_temporal_flip_categories_6yr(images)
# LULC colour table: with min=0 and max=13, palette entry i colours label value i.
pallete_lulc = [
    '000000', 'ff0000', '74ccf4', '1ca3ec', '0f5e9c',  # 0-4
    'f1c232', '38761d', 'A9A9A9', 'BAD93E', 'f59d22',  # 5-9
    'FF9371', 'b3561d', 'a9a9a9', '84994F',            # 10-13
]

# Visualisation parameters spanning every palette entry (0 .. 13).
vis_params_lulc = dict(min=0, max=len(pallete_lulc) - 1, palette=pallete_lulc)

import geemap

# Interactive map: Google tile basemap, the two flip classifications, then
# the grouped label for each of the six years (y1 .. y6).
Map = geemap.Map()
url = 'https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}'
Map.layout.height = '1000px'
Map.add_tile_layer(url, name="Google Map", attribution="Google")

Map.addLayer(out.select('singletCategory'), singlet_vis, 'singletCategory')
Map.addLayer(out.select('doubletPattern'), doublet_vis_6yr, 'doubletPattern (6yr)')
for _yr in range(1, 7):
    Map.addLayer(out.select(f'y{_yr}'), vis_params_lulc, f'y{_yr}')

Map
