In [1]:
import ee

# Authenticate with Earth Engine, then initialise the client against the
# Cloud project whose assets are read by load_aez_temporal_images().
ee.Authenticate()
ee.Initialize(project='raman-461708')


# -------------------------------
# 1. LOADING
# -------------------------------

def load_aez_temporal_images(aez, start_year, num_years, project='raman-461708'):
    """Load one `predicted_label` image per July-to-June year.

    Asset ids have the form
    ``projects/<project>/assets/LULC_v4_PanIndia_<Y>-07-01_<Y+1>-06-30``.

    Args:
        aez: Agro-ecological zone id. Not used when building the asset id
            (the Pan-India assets are not keyed by AEZ); kept so existing
            callers keep working.
        start_year: First year ``Y`` of the series.
        num_years: Number of consecutive yearly images to load.
        project: Cloud project that hosts the assets.

    Returns:
        List with ``num_years`` ``ee.Image`` objects restricted to the
        ``predicted_label`` band, oldest first.
    """
    # The per-AEZ prefix
    #   projects/{project}/assets/AEZ_{aez}_v4_temporal_3years__
    # is deliberately not assigned here: it was always overwritten by the
    # Pan-India prefix before being used, so it never had any effect.
    base = f"projects/{project}/assets/LULC_v4_PanIndia_"
    return [
        ee.Image(f"{base}{year}-07-01_{year + 1}-06-30").select('predicted_label')
        for year in range(start_year, start_year + num_years)
    ]


# -------------------------------
# 2. BACKGROUND (CLASS 0) CORRECTION
# -------------------------------

def fill_background_temporally(images):
    """Fill background pixels (class 0) using the neighbouring years.

    Interior years are filled only where both neighbours are non-background
    (value >= 1). The first and last year are then filled from their single
    neighbour wherever that neighbour is non-background.

    Args:
        images: Sequence of yearly label images, oldest first.

    Returns:
        A new list of the same length; the input sequence is not modified.
        A series with fewer than two images has no neighbour to borrow from
        and is returned unchanged.
    """
    images = list(images)  # copy
    length = len(images)

    # Nothing to borrow from: avoids indexing images[0] / images[1] on an
    # empty or single-image series.
    if length < 2:
        return images

    # intermediate years
    for i in range(1, length - 1):
        before = images[i - 1]
        middle = images[i]
        after = images[i + 1]

        cond = before.gte(1).And(after.gte(1)).And(middle.eq(0))

        # NOTE(review): the second year (i == 1) borrows from the following
        # year, every other interior year borrows from the preceding one
        # (which has already been filled by an earlier iteration). The
        # asymmetry is preserved as-is; confirm it is intended.
        donor = after if i == 1 else before
        images[i] = middle.where(cond, donor)

    # first year: borrow from the (already filled) second year
    first = images[0]
    cond_first = first.eq(0).And(images[1].gte(1))
    images[0] = first.where(cond_first, images[1])

    # last year: borrow from the (already filled) second-to-last year
    last = images[length - 1]
    cond_last = last.eq(0).And(images[length - 2].gte(1))
    images[length - 1] = last.where(cond_last, images[length - 2])

    return images


# -------------------------------
# 3. BUILD INCONSISTENCY COUNT
# -------------------------------

def compute_inconsistency_count(images):
    """Count, per pixel, how many suspicious A-B-A label patterns occur.

    Every window of three consecutive years is tested against eleven rules
    (for example shrub-green-shrub or tree-built-up-tree). Each rule that
    matches adds 1 to the pixel's count.

    Args:
        images: List of yearly label images, oldest first (at least one).

    Returns:
        ee.Image holding the per-pixel number of rule matches, clipped to the
        geometry of the first image.
    """
    crop_codes = (8, 9, 10, 11)
    water_codes = (2, 3, 4)

    def any_of(img, codes):
        # OR together the equality tests for every listed class code.
        hit = img.eq(codes[0])
        for code in codes[1:]:
            hit = hit.Or(img.eq(code))
        return hit

    count = ee.Image.constant(0).clip(images[0].geometry())

    for prev, cur, nxt in zip(images, images[1:], images[2:]):
        mid_crop = any_of(cur, crop_codes)
        mid_green = cur.eq(6).Or(mid_crop)  # tree or any crop class

        tree_both = prev.eq(6).And(nxt.eq(6))
        bu_both = prev.eq(1).And(nxt.eq(1))
        crop_both = any_of(prev, crop_codes).And(any_of(nxt, crop_codes))
        water_both = any_of(prev, water_codes).And(any_of(nxt, water_codes))

        rules = (
            prev.eq(12).And(nxt.eq(12)).And(mid_green),  # shrubs-green-shrubs
            water_both.And(mid_green),                   # water-green-water
            tree_both.And(cur.eq(12)),                   # tree-shrub-tree
            crop_both.And(cur.eq(12)),                   # crop-shrub-crop
            crop_both.And(cur.eq(7)),                    # crop-barren-crop
            tree_both.And(mid_crop),                     # tree-farm-tree
            crop_both.And(cur.eq(6)),                    # farm-tree-farm
            bu_both.And(cur.eq(6)),                      # BU-tree-BU
            tree_both.And(cur.eq(1)),                    # tree-BU-tree
            bu_both.And(mid_crop),                       # BU-farm-BU
            prev.eq(7).And(nxt.eq(7)).And(mid_green),    # barren-green-barren
        )
        for rule in rules:
            count = count.add(rule)

    return count


# -------------------------------
# 4. PROCESS CONDITIONS FOR ONE WINDOW
# -------------------------------

def process_conditions(before, middle, after,
                       zero_image, th1, th2,
                       i, length, images, crop_freq_array):
    """Apply rule-based temporal corrections for one (before, middle, after) window.

    Args:
        before, middle, after: Label images for years i-1, i and i+1.
        zero_image: Per-pixel inconsistency count; only pixels whose count is
            within [th1, th2] (both inclusive) are eligible for correction.
        th1, th2: Lower / upper bound on the inconsistency count.
        i: Index of `middle` within `images`.
        length: Number of images in the series.
        images: Full series; read at i-2 and i+2 for the 5-year BU rule.
        crop_freq_array: Sequence indexed like `images`; entries i-1, i and
            i+1 supply the replacement values for the crop rules.

    Returns:
        Tuple (before, middle, after) with the corrections applied. `before`
        and `after` are only changed by rule 10.

    All rule masks are built from the images as passed in, before any
    replacement; the replacements are then applied in a fixed order, so a
    later `where` overrides an earlier one on pixels matched by both.
    """
    # Eligibility mask: inconsistency count within [th1, th2].
    zmask = zero_image.gte(th1).And(zero_image.lte(th2))

    # Class groups: crops = 8..11, water = 2..4, "green" = tree (6) or crop.
    crops_ba = before.eq(8).Or(before.eq(9)).Or(before.eq(10)).Or(before.eq(11))
    crops_af = after.eq(8).Or(after.eq(9)).Or(after.eq(10)).Or(after.eq(11))
    crops_mid = middle.eq(8).Or(middle.eq(9)).Or(middle.eq(10)).Or(middle.eq(11))
    green_mid = middle.eq(6).Or(crops_mid)
    water_ba = before.eq(2).Or(before.eq(3)).Or(before.eq(4))
    water_af = after.eq(2).Or(after.eq(3)).Or(after.eq(4))

    # 1) shrubs-green-shrubs -> shrubs
    cond1 = zmask.And(before.eq(12)).And(after.eq(12)).And(green_mid)

    # 2) water-green-water -> barren
    cond2 = zmask.And(water_ba).And(water_af).And(green_mid)

    # 3) tree-shrub-tree -> tree
    cond3 = zmask.And(before.eq(6)).And(after.eq(6)).And(middle.eq(12))

    # 4) crop-shrub-crop -> crop (from frequency)
    cond4 = zmask.And(crops_ba).And(crops_af).And(middle.eq(12))

    # 5) crop-barren-crop -> crop (from frequency)
    cond5 = zmask.And(crops_ba).And(crops_af).And(middle.eq(7))

    # 6) tree-farm-tree -> tree
    cond6 = zmask.And(before.eq(6)).And(after.eq(6)).And(crops_mid)

    # 7) farm-tree-farm -> crop (from frequency)
    cond7 = zmask.And(crops_ba).And(crops_af).And(middle.eq(6))

    # 8) BU-tree-BU -> BU, but only with 5-year support (applied at the end)
    cond8 = zmask.And(before.eq(1)).And(after.eq(1)).And(middle.eq(6))

    # 9) tree-BU-tree -> tree
    cond9 = zmask.And(before.eq(6)).And(after.eq(6)).And(middle.eq(1))

    # 10) BU-farm-BU -> crop (from frequency) across 3 years
    cond10 = zmask.And(before.eq(1)).And(after.eq(1)).And(crops_mid)

    # 11) barren-green-barren -> barren
    cond11 = zmask.And(before.eq(7)).And(after.eq(7)).And(green_mid)

    # NOTE(review): every replacement in this branch is skipped for the
    # window centred on index 2. Nothing else in this function explains why
    # that one index is special - confirm this guard is intended.
    if i != 2:
        middle = middle.where(cond1, 12)
        middle = middle.where(cond2, 7)
        middle = middle.where(cond3, 6)
        middle = middle.where(cond6, 6)

        # Rules 7, 4 and 5 take the replacement value from this year's
        # cropping-frequency image.
        cropping_frequency_img = crop_freq_array[i]
        middle = middle.where(cond7, cropping_frequency_img)
        middle = middle.where(cond4, cropping_frequency_img)
        middle = middle.where(cond5, cropping_frequency_img)

        middle = middle.where(cond9, 6)

        # Rule 10 rewrites all three years of the window, not only the middle.
        before = before.where(cond10, crop_freq_array[i - 1])
        middle = middle.where(cond10, crop_freq_array[i])
        after = after.where(cond10, crop_freq_array[i + 1])

        middle = middle.where(cond11, 7)

    # BU-tree-BU with 5-year context
    # Skipped for i == 1 and i == length - 2; for every other index reached by
    # the callers' loops (1 .. length - 2) both images[i - 2] and images[i + 2]
    # exist. Unlike the branch above, this one also runs when i == 2.
    if i != 1 and i != length - 2:
        cond8_extended = cond8.And(images[i - 2].eq(1).And(images[i + 2].eq(1)))
        middle = middle.where(cond8_extended, 1)

    return before, middle, after


# -------------------------------
# 5. APPLY TEMPORAL CORRECTIONS (ALL YEARS)
# -------------------------------

def apply_temporal_corrections(images, zero_image, crop_freq_array):
    """Run the three correction sweeps over every three-year window.

    Sweep 1 targets pixels whose inconsistency count is exactly 1. Sweeps 2
    and 3 target counts from 2 up to ``len(images) - 4`` (inclusive); sweep 2
    stops one window earlier than sweeps 1 and 3.

    Args:
        images: Yearly label images, oldest first.
        zero_image: Per-pixel inconsistency count.
        crop_freq_array: Cropping-frequency images, indexed like ``images``.

    Returns:
        A new list holding the corrected images; the input list is untouched.
    """
    series = list(images)
    n_years = len(series)

    def sweep(centres, low, high):
        # Each window sees the results written back by the previous windows,
        # because `series` is updated in place after every call.
        for c in centres:
            prev_img, cur_img, next_img = process_conditions(
                series[c - 1], series[c], series[c + 1],
                zero_image, th1=low, th2=high,
                i=c, length=n_years,
                images=series,
                crop_freq_array=crop_freq_array
            )
            series[c - 1] = prev_img
            series[c] = cur_img
            series[c + 1] = next_img

    # Pass 1: pixels with count == 1
    sweep(range(1, n_years - 1), 1, 1)
    # Pass 2: counts 2 .. n_years - 4, shorter index range
    sweep(range(1, n_years - 2), 2, n_years - 4)
    # Pass 3: counts 2 .. n_years - 4, full index range
    sweep(range(1, n_years - 1), 2, n_years - 4)

    return series


# -------------------------------
# 6. FIRST-YEAR SPECIAL CORRECTIONS
# -------------------------------

def correct_first_year(images, crop_freq_array):
    """Correct built-up labels in the first year using the next two years.

    A first-year built-up pixel (class 1) is rewritten when the two following
    years agree on something else:

    * BU-tree-tree -> tree (class 6)
    * BU-farm-farm -> the value of ``crop_freq_array[0]``

    Args:
        images: Yearly label images, oldest first.
        crop_freq_array: Cropping-frequency images; only entry 0 is read.

    Returns:
        A new list in which only element 0 may differ from the input. A
        series with fewer than three images cannot be evaluated and is
        returned unchanged.
    """
    images = list(images)

    # Both rules need years 1 and 2; avoid indexing past a short series.
    if len(images) < 3:
        return images

    first = images[0]

    crops1 = images[1].eq(8).Or(images[1].eq(9)).Or(images[1].eq(10)).Or(images[1].eq(11))
    crops2 = images[2].eq(8).Or(images[2].eq(9)).Or(images[2].eq(10)).Or(images[2].eq(11))

    # BU-farm-farm
    cond_bu_farm_farm = first.eq(1).And(crops1).And(crops2)

    # BU-tree-tree
    cond_bu_tree_tree = first.eq(1).And(images[1].eq(6)).And(images[2].eq(6))

    first = first.where(cond_bu_tree_tree, 6)
    first = first.where(cond_bu_farm_farm, crop_freq_array[0])

    images[0] = first
    return images


# -------------------------------
# 7. EXPORT
# -------------------------------

def export_time_series(images, roi_boundary,
                       final_output_filename_array,
                       final_output_assetid_array,
                       scale, band_name='predicted_label'):
    """Start one Earth Engine asset-export task per image.

    Args:
        images: Images to export, in order.
        roi_boundary: Object whose ``geometry()`` is used to clip every image.
        final_output_filename_array: Task descriptions, indexed like ``images``.
        final_output_assetid_array: Destination asset ids, indexed like ``images``.
        scale: Export scale passed to the export call.
        band_name: Band that receives the ``'mode'`` pyramiding policy.

    Returns:
        None. The tasks are started but not returned.
    """
    for idx, yearly in enumerate(images):
        clipped = yearly.clip(roi_boundary.geometry())
        export_job = ee.batch.Export.image.toAsset(
            image=clipped,
            description=final_output_filename_array[idx],
            assetId=final_output_assetid_array[idx],
            pyramidingPolicy={band_name: 'mode'},
            scale=scale,
            maxPixels=1e13,
            crs='EPSG:4326'
        )
        export_job.start()


# -------------------------------
# 8. PIPELINE WRAPPER
# -------------------------------

def run_temporal_correction_pipeline(aez,
                                     start_year,
                                     num_years,
                                     crop_freq_array,
                                     roi_boundary,
                                     final_output_filename_array,
                                     final_output_assetid_array,
                                     scale,
                                     project='raman-461708'):
    """Load, correct and export a yearly label series end to end.

    Steps: load the yearly images, fill background pixels, count temporal
    inconsistencies, apply the rule-based corrections, apply the first-year
    corrections, then start one export task per year.

    Args:
        aez, start_year, num_years, project: Passed to
            ``load_aez_temporal_images``.
        crop_freq_array: Cropping-frequency images used by the corrections.
        roi_boundary, final_output_filename_array,
        final_output_assetid_array, scale: Passed to ``export_time_series``.

    Returns:
        The list of corrected images (the same ones that were exported).
    """
    raw_series = load_aez_temporal_images(aez, start_year, num_years, project)
    filled_series = fill_background_temporally(raw_series)

    # The inconsistency count is computed once, on the background-filled
    # series, and reused by every correction pass.
    violation_count = compute_inconsistency_count(filled_series)

    corrected = apply_temporal_corrections(filled_series, violation_count,
                                           crop_freq_array)
    corrected = correct_first_year(corrected, crop_freq_array)

    export_time_series(corrected, roi_boundary,
                       final_output_filename_array,
                       final_output_assetid_array,
                       scale)

    return corrected


In [2]:
# Series to inspect: 7 consecutive July-June years, the first one starting
# on 2017-07-01.
aez = 10
start_year = 2017
num_years = 7
project = 'raman-461708'

# 1. Load
images = load_aez_temporal_images(aez, start_year, num_years, project)

# 2. Fill background 0
images = fill_background_temporally(images)

# 3. Build inconsistency count (on the background-filled series)
zero_img = compute_inconsistency_count(images)

In [3]:
images

[<ee.image.Image at 0x7becceb80ad0>,
 <ee.image.Image at 0x7becce0fe950>,
 <ee.image.Image at 0x7becceb34e50>,
 <ee.image.Image at 0x7becceb65bd0>,
 <ee.image.Image at 0x7becceb80190>,
 <ee.image.Image at 0x7becceb80e90>,
 <ee.image.Image at 0x7becceb81350>]

In [4]:
# Region collection. Read as a notebook-level global by later cells
# (the masked_area() calls and classify_temporal_flip_categories()).
roi = ee.FeatureCollection("users/mtpictd/agro_eco_regions")

In [5]:
import geemap


# --------------------------------------------------
# Requires, from the earlier cells:
#   images   = list of yearly images (background-filled)
#   zero_img = compute_inconsistency_count(images)
# --------------------------------------------------

# --- 1. Build masks (one per inconsistency count) ---

# 0 violations: very stable pixels
mask0 = zero_img.eq(0).selfMask()

# exactly 1 violation
mask1 = zero_img.eq(1).selfMask()

# exactly 2 violations
mask2 = zero_img.eq(2).selfMask()

# exactly 3 violations (the "plus" in the name is historical; the variable
# name is kept because other cells refer to it)
mask3plus = zero_img.eq(3).selfMask()

# 4 or more violations
mask4plus = zero_img.gte(4).selfMask()

# --- 2. Create map ---

Map = geemap.Map()


# Single-colour palettes. Each count class gets its own colour so the layers
# can be told apart when several are switched on at once.
vis0 = {"palette": ["1a9850"]}     # green for stable
vis1 = {"palette": ["fee08b"]}     # yellow for exactly 1
vis2 = {"palette": ["d73027"]}     # red for exactly 2
vis3p = {"palette": ["a50026"]}    # dark red for 3 and for 4+
url = 'https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}'
Map.layout.height = '1000px'
Map.add_tile_layer(url, name="Google Map", attribution="Google")
# Layer labels state exactly the comparison used to build each mask.
Map.addLayer(mask0, vis0, "Zero violations (==0)")
Map.addLayer(mask1, vis1, "Exactly 1 violation (==1)")
Map.addLayer(mask2, vis2, "Exactly 2 violations (==2)")
Map.addLayer(mask3plus, vis3p, "Exactly 3 violations (==3)")
Map.addLayer(mask4plus, vis3p, "4 or more violations (>=4)")

Map


Map(center=[0, 0], controls=(WidgetControl(options=['position', 'transparent_bg'], position='topright', transp…

In [6]:


# images: your list of yearly AEZ images
# zero_img: the one you already built

def build_rule_any_masks(images):
    """Build, for each of the eleven rules, a mask of pixels where it ever fired.

    Args:
        images: List of yearly label images, oldest first (at least one).

    Returns:
        Dict mapping a rule name to a self-masked ee.Image that is set where
        that rule matched in at least one three-year window.
    """
    rule_names = [
        "Rule1_Shrub-Green-Shrub",
        "Rule2_Water-Green-Water",
        "Rule3_Tree-Shrub-Tree",
        "Rule4_Crop-Shrub-Crop",
        "Rule5_Crop-Barren-Crop",
        "Rule6_Tree-Crop-Tree",
        "Rule7_Crop-Tree-Crop",
        "Rule8_BU-Tree-BU",
        "Rule9_Tree-BU-Tree",
        "Rule10_BU-Crop-BU",
        "Rule11_Barren-Green-Barren",
    ]

    def is_crop(img):
        return img.eq(8).Or(img.eq(9)).Or(img.eq(10)).Or(img.eq(11))

    def is_water(img):
        return img.eq(2).Or(img.eq(3)).Or(img.eq(4))

    footprint = images[0].geometry()

    # One all-false accumulator per rule.
    fired = [ee.Image.constant(0).clip(footprint) for _ in rule_names]

    for prev, cur, nxt in zip(images, images[1:], images[2:]):
        mid_crop = is_crop(cur)
        mid_green = cur.eq(6).Or(mid_crop)

        crop_both = is_crop(prev).And(is_crop(nxt))
        water_both = is_water(prev).And(is_water(nxt))
        tree_both = prev.eq(6).And(nxt.eq(6))
        bu_both = prev.eq(1).And(nxt.eq(1))

        # Same order as rule_names.
        window_rules = [
            prev.eq(12).And(nxt.eq(12)).And(mid_green),  # shrubs-green-shrubs
            water_both.And(mid_green),                   # water-green-water
            tree_both.And(cur.eq(12)),                   # tree-shrub-tree
            crop_both.And(cur.eq(12)),                   # crop-shrub-crop
            crop_both.And(cur.eq(7)),                    # crop-barren-crop
            tree_both.And(mid_crop),                     # tree-crop-tree
            crop_both.And(cur.eq(6)),                    # crop-tree-crop
            bu_both.And(cur.eq(6)),                      # BU-tree-BU
            tree_both.And(cur.eq(1)),                    # tree-BU-tree
            bu_both.And(mid_crop),                       # BU-crop-BU
            prev.eq(7).And(nxt.eq(7)).And(mid_green),    # barren-green-barren
        ]

        # "ever true": OR this window's result into each accumulator
        fired = [seen.Or(now) for seen, now in zip(fired, window_rules)]

    return {
        label: layer.selfMask()
        for label, layer in zip(rule_names, fired)
    }

# Per-rule "fired at least once" masks for the whole series.
rule_any_masks = build_rule_any_masks(images)

# pixels with exactly 1 violation
# NOTE: this rebinds `mask1`. The earlier map cell defined it with
# .selfMask(); from here on it is the plain 0/1 comparison image.
mask1 = zero_img.eq(1)

# For pixels with exactly one violation, the rule that fired is the one
# responsible for it; keep only those pixels in each rule mask.
rule_single = {
    name: mask1.And(img).selfMask()
    for name, img in rule_any_masks.items()
}

# Add each rule's "single violation" mask as a separate layer
palette = ["ff0000"]  # red

for name, img in rule_single.items():
    Map.addLayer(img, {"palette": palette}, f"1-violation: {name}")


Map

Map(bottom=1012.0, center=[0, 0], controls=(WidgetControl(options=['position', 'transparent_bg'], position='to…

In [7]:
def masked_area(mask, geom, scale=10):
    """Return the area covered by ``mask`` inside ``geom``, in km².

    The mask is multiplied by the per-pixel area image, summed over ``geom``
    at the given ``scale``, and the result is fetched with ``getInfo()``.

    Args:
        mask: Image whose values weight the pixel area.
        geom: Region over which the sum is taken.
        scale: Scale passed to ``reduceRegion``.

    Returns:
        The first value of the fetched result dictionary divided by 1e6.
        The raw dictionary is also printed.
    """
    per_pixel_m2 = mask.multiply(ee.Image.pixelArea())
    totals = per_pixel_m2.reduceRegion(
        reducer=ee.Reducer.sum(),
        geometry=geom,
        scale=scale,
        maxPixels=1e13
    ).getInfo()
    print(totals)

    # m² -> km²
    first_total_m2 = list(totals.values())[0]
    return first_total_m2 / 1e6


# Area (km²) per inconsistency-count class over the whole `roi` collection.
# NOTE: the recorded run of this cell ended in KeyboardInterrupt.
a0 = masked_area(mask0, roi.geometry())        # zero inconsistency (stable)
a1 = masked_area(mask1, roi.geometry())        # exactly one violation
a2 = masked_area(mask2, roi.geometry())        # exactly two violations
a3 = masked_area(mask3plus, roi.geometry())    # exactly three violations (zero_img.eq(3))

KeyboardInterrupt: 

In [16]:
import numpy as np

# State codes 0..6 for the smoother.
# (presumably built-up, water, tree, crop, barren, shrub, plantation - confirm)
BU, W, T, C, Ba, S, P = 0, 1, 2, 3, 4, 5, 6
classes = [BU, W, T, C, Ba, S, P]    # encode as ints 0–6


def smooth_sequence(obs, lambda_penalty=2):
    """Return the lowest-cost state sequence for one observed label series.

    The cost of a candidate sequence is 1 for every time step where it
    disagrees with ``obs`` plus ``lambda_penalty`` for every change of state
    between consecutive steps. The minimum is found by dynamic programming
    over the ``len(classes)`` states; ties are resolved in favour of the
    lowest state index.

    Args:
        obs: Sequence of observed state codes, one per time step.
        lambda_penalty: Cost of switching state between consecutive steps.

    Returns:
        Integer NumPy array with the same length as ``obs``. An empty ``obs``
        yields an empty array.
    """
    n_steps = len(obs)
    n_states = len(classes)

    # Nothing to smooth; also avoids indexing row 0 of an empty table.
    if n_steps == 0:
        return np.zeros(0, dtype=int)

    states = np.arange(n_states)
    # switch_cost[prev, cur] is lambda_penalty when prev != cur, else 0.
    switch_cost = (states[:, None] != states[None, :]) * lambda_penalty

    dp = np.zeros((n_steps, n_states))
    back = np.zeros((n_steps, n_states), dtype=int)

    # initialize year 0: only the mismatch cost applies
    dp[0] = (obs[0] != states)

    # dynamic programming
    for t in range(1, n_steps):
        mismatch = (obs[t] != states)
        # total[prev, cur]: best cost of ending step t in `cur` via `prev`.
        total = dp[t - 1][:, None] + switch_cost + mismatch[None, :]
        back[t] = total.argmin(axis=0)  # first minimum wins on ties
        dp[t] = total.min(axis=0)

    # backtrack from the cheapest final state
    path = np.zeros(n_steps, dtype=int)
    path[-1] = np.argmin(dp[-1])
    for t in range(n_steps - 2, -1, -1):
        path[t] = back[t + 1, path[t + 1]]

    return path

smooth_sequence([1,0,1,1,0,0,0], lambda_penalty=1)

array([1, 1, 1, 1, 0, 0, 0])

In [None]:
import ee
ee.Initialize(project='raman-461708')

BAND_NAME = 'predicted_label'  # change if needed

# --------------------------------------------------------------------
# 1. Define grouping: original → grouped
# --------------------------------------------------------------------
# The two lists are parallel: orig_values[k] is remapped to group_values[k].
# Value 5 does not appear in orig_values.
orig_values = ee.List([0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13])
group_values = ee.List([
    0,   # 0 -> 0 (background)
    1,   # 1 -> 1 (BU)
    2,   # 2 -> 2 (water)
    2,   # 3 -> 2 (water)
    2,   # 4 -> 2 (water)
    6,   # 6 -> 6 (tree)
    7,   # 7 -> 7 (barren)
    8,   # 8 -> 8 (crop)
    8,   # 9 -> 8 (crop)
    8,   # 10 -> 8 (crop)
    8,   # 11 -> 8 (crop)
    12,  # 12 -> 12 (shrub)
    13   # 13 -> 13 (plantation)
])

# HMM state values (grouped)
# NOTE(review): this list has 8 values (background included), whereas the
# `classes` list used by smooth_sequence() in this cell has 7 states coded
# 0..6 - the two encodings do not line up yet.
CLASS_VALUES = ee.List([0, 1, 2, 6, 7, 8, 12, 13])
N_CLASSES = CLASS_VALUES.size()
LAMBDA_PENALTY = 2  # tune smoothness


def to_grouped(img):
    """Collapse the fine-grained labels of ``img`` into the grouped classes.

    Applies the ``orig_values`` -> ``group_values`` lookup and names the
    resulting band ``BAND_NAME``.
    """
    regrouped = img.remap(orig_values, group_values)
    return regrouped.rename(BAND_NAME)


# State codes 0..6 for the smoother.
# (presumably built-up, water, tree, crop, barren, shrub, plantation - confirm)
BU, W, T, C, Ba, S, P = 0, 1, 2, 3, 4, 5, 6
classes = [BU, W, T, C, Ba, S, P]    # encode as ints 0–6


def smooth_sequence(obs, lambda_penalty=2):
    """Return the lowest-cost state sequence for one observed label series.

    The cost of a candidate sequence is 1 for every time step where it
    disagrees with ``obs`` plus ``lambda_penalty`` for every change of state
    between consecutive steps. The minimum is found by dynamic programming
    over the ``len(classes)`` states; ties are resolved in favour of the
    lowest state index.

    Relies on ``np`` (NumPy) already being imported in the notebook session.

    Args:
        obs: Sequence of observed state codes, one per time step.
        lambda_penalty: Cost of switching state between consecutive steps.

    Returns:
        Integer NumPy array with the same length as ``obs``. An empty ``obs``
        yields an empty array.
    """
    n_steps = len(obs)
    n_states = len(classes)

    # Nothing to smooth; also avoids indexing row 0 of an empty table.
    if n_steps == 0:
        return np.zeros(0, dtype=int)

    states = np.arange(n_states)
    # switch_cost[prev, cur] is lambda_penalty when prev != cur, else 0.
    switch_cost = (states[:, None] != states[None, :]) * lambda_penalty

    dp = np.zeros((n_steps, n_states))
    back = np.zeros((n_steps, n_states), dtype=int)

    # initialize year 0: only the mismatch cost applies
    dp[0] = (obs[0] != states)

    # dynamic programming (no per-step debug printing)
    for t in range(1, n_steps):
        mismatch = (obs[t] != states)
        # total[prev, cur]: best cost of ending step t in `cur` via `prev`.
        total = dp[t - 1][:, None] + switch_cost + mismatch[None, :]
        back[t] = total.argmin(axis=0)  # first minimum wins on ties
        dp[t] = total.min(axis=0)

    # backtrack from the cheapest final state
    path = np.zeros(n_steps, dtype=int)
    path[-1] = np.argmin(dp[-1])
    for t in range(n_steps - 2, -1, -1):
        path[t] = back[t + 1, path[t + 1]]

    return path

def change_pixels(pixels):
    """Placeholder for per-pixel Viterbi smoothing (not implemented).

    The body is empty: ``pixels`` is ignored and the function returns None.
    """



def viterbi_smooth_images(images, class_values):
    """Stack the yearly images into one multi-band image and map ``change_pixels`` over it.

    Args:
        images: Yearly single-band images; band k is renamed ``band_<k>``.
        class_values: Currently unused.

    Returns:
        Whatever ``stacked.map(change_pixels)`` returns.

    NOTE(review): work in progress. ``change_pixels`` is an empty stub, and
    ``ee.Image`` does not appear to offer a ``map`` method (it is an
    ``ee.ImageCollection`` operation), so this call is expected to fail -
    confirm against the Earth Engine API before relying on this function.
    """
    # Rename each image band uniquely and then combine
    renamed = [img.rename(f'band_{i}') for i, img in enumerate(images)]
    stacked = ee.Image.cat(renamed)
    smoothed_images = stacked.map(change_pixels)
    return smoothed_images

# -------------------------------
# Load your 7-year AEZ stack
# -------------------------------
aez = 10
num_years = 7
start_year = 2017
# Only the Pan-India prefix is live. The per-AEZ prefix
#   projects/raman-461708/assets/AEZ_{aez}_v4_temporal_3years__
# is deliberately not assigned: it was always overwritten before use.
base = "projects/raman-461708/assets/LULC_v4_PanIndia_"

# One image per July-June year, restricted to the label band.
images_raw = []
for i in range(num_years):
    y1 = start_year + i
    y2 = y1 + 1
    asset = f"{base}{y1}-07-01_{y2}-06-30"
    img = ee.Image(asset).select(BAND_NAME)
    images_raw.append(img)

# Grouped versions (water unified, crop unified)
images_grouped = [to_grouped(img) for img in images_raw]

# Run Viterbi smoothing on grouped labels
smoothed_grouped = viterbi_smooth_images(images_grouped, CLASS_VALUES)


TypeError: can only concatenate str (not "String") to str

In [6]:


# -----------------------------
# 0) Class grouping
# -----------------------------
# Original class codes handled by the grouping (code 5 is not listed).
ORIG =  [0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13]

# Grouped codes, parallel to ORIG (ORIG[k] -> GROUP[k]):
# water: 2,3,4 -> 2
# crop:  8,9,10,11 -> 8
# every other code maps to itself
GROUP = [0, 1, 2, 2, 2, 6, 7, 8, 8,  8,  8,  12, 13]

BAND = 'predicted_label'   # change if your band name differs


def group_classes(img, band=BAND):
    """Return ``img`` with its labels collapsed to the grouped classes.

    Band ``band`` is selected and remapped with the parallel ``ORIG`` /
    ``GROUP`` lists (water and crop sub-classes merged). The output band is
    always named ``'predicted_label'``, whatever ``band`` was.
    """
    return img.select([band]).remap(ORIG, GROUP).rename('predicted_label')


# ----------------------------------------------------------
# 1) Stack 7 yearly layers into one 7-band image: y1..y7
# ----------------------------------------------------------
def stack_7_years(year_imgs):
    """Combine seven yearly images into one image with bands ``y1`` .. ``y7``.

    Args:
        year_imgs: List of exactly 7 ee.Image objects, oldest first, each
            providing a ``'predicted_label'`` band.

    Returns:
        ee.Image whose band ``y<k>`` is the ``'predicted_label'`` band of the
        k-th input image (1-based).

    Raises:
        ValueError: If ``year_imgs`` does not hold exactly 7 images.
    """
    if len(year_imgs) != 7:
        raise ValueError("Provide exactly 7 yearly images.")

    yearly_bands = [
        im.select(['predicted_label']).rename(f'y{pos}')
        for pos, im in enumerate(year_imgs, start=1)
    ]
    return ee.Image.cat(yearly_bands)


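# ----------------------------------------------------------
# 2) Bit-mask helpers used by singlet_category()
#    NOTE: the 6-bit flip-indicator encoding (f2..f7 where
#    ft = (yt != y(t-1)), bit5..bit0 = f2..f7) is not built anywhere in this
#    cell; the helpers below are applied to the 5-bit singlet-centre mask
#    (s2..s6) instead.
# ----------------------------------------------------------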


def _count_ones_nbit(mask, nbits):
    """Count, per pixel, the set bits among the lowest ``nbits`` bits of ``mask``.

    Returns an integer image with a single band named ``'singletCount'``.
    """
    bit_planes = []
    for shift in range(nbits):
        # Isolate bit `shift` as a 0/1 band.
        bit_planes.append(mask.rightShift(shift).bitwiseAnd(1))
    per_pixel_total = ee.Image.cat(bit_planes).reduce(ee.Reducer.sum())
    return per_pixel_total.rename('singletCount').toInt()

def _has_adjacent_ones(mask):
    """Return a band ``'hasAdj11'`` that is set where ``mask`` has two neighbouring 1 bits."""
    # A bit survives the AND only if the bit just below it is set as well.
    shifted = mask.leftShift(1)
    overlap = mask.bitwiseAnd(shifted)
    return overlap.neq(0).rename('hasAdj11')

def _has_run_k(mask, nbits, k):
    """Return an image that is true where ``mask`` holds ``k`` consecutive 1 bits.

    Only the lowest ``nbits`` bits are examined. If ``k`` exceeds ``nbits``
    no position is tested and the result is the constant 0 image.
    """
    run_bits = (1 << k) - 1  # k ones in the lowest positions
    found = ee.Image(0)
    offset = 0
    while offset <= nbits - k:
        # Slide the k-bit window to `offset` and require all of its bits.
        window = ee.Image(run_bits << offset).toInt()
        found = found.Or(mask.bitwiseAnd(window).eq(window))
        offset += 1
    return found

def _max_run_length(mask, nbits):
    """Return the length of the longest run of 1 bits in ``mask``, per pixel.

    The result is an integer band named ``'singletMaxRun'`` with values from
    0 (no bit set) up to ``nbits``.
    """
    longest = ee.Image(0).where(mask.neq(0), 1)
    # Ascending order matters: a longer run overwrites the shorter ones
    # recorded before it.
    for run in range(2, nbits + 1):
        run_present = _has_run_k(mask, nbits, run)
        longest = longest.where(run_present, run)
    return longest.rename('singletMaxRun').toInt()

def singlet_category(seq7):
    """
    Detect "singlets" (A-B-A patterns) in a 7-year class sequence and
    categorise each pixel by how those singlets cluster.

    A singlet centre is a year t in 2..6 whose two neighbours are equal to
    each other and different from year t.

    Input:
      seq7: ee.Image with bands y1..y7 (GROUPED classes).

    Output bands (in this order):
      - singletCategory (int):
          0: no singlets
          1: disjoint singlets only (no adjacent singlet centers)
          2: max run = 2 (cluster of 2 singlet centers) and not composite
          3: max run = 3
          4: max run = 4
          5: max run = 5
          6: composite (both isolated singlets and a run-2 cluster exist)
      - singletSeqMask  (0..31): 5-bit mask for centers y2..y6
      - singletCount    (0..5)
      - hasAdj11        (bool) adjacency among singlet centers
      - singletMaxRun   (0..5)
      - isolatedSingletMask (mask bits that are isolated, not in any adjacency)
      - s2..s6          per-centre 0/1 flags
    """
    y = [seq7.select(f'y{i}') for i in range(1, 8)]  # y[0]=y1 ... y[6]=y7

    # Singlet centers at years 2..6 (5 centers)
    centers = []
    for idx in range(1, 6):  # idx=1..5 corresponds to y2..y6
        left  = y[idx - 1]
        mid   = y[idx]
        right = y[idx + 1]
        # A-B-A: neighbours agree with each other, the centre differs.
        s = left.eq(right).And(mid.neq(left)).rename(f's{idx+1}').toUint8()  # s2..s6
        centers.append(s)

    # Pack s2..s6 into 5-bit mask: s2 is MSB (bit4), s6 is LSB (bit0)
    mask = ee.Image(0).toInt()
    for j, s in enumerate(centers):  # j=0..4 => s2..s6
        mask = mask.bitwiseOr(s.toInt().leftShift(4 - j))
    singletSeqMask = mask.rename('singletSeqMask')

    singletCount = _count_ones_nbit(singletSeqMask, 5)
    anySinglet = singletSeqMask.neq(0)

    adj11 = _has_adjacent_ones(singletSeqMask)
    maxRun = _max_run_length(singletSeqMask, 5)

    # Identify isolated singlet bits (not part of any adjacent pair)
    # bits_in_adj marks the upper bit of every adjacent pair; OR-ing with its
    # right shift also marks the lower bit, so both members are covered.
    bits_in_adj = singletSeqMask.bitwiseAnd(singletSeqMask.leftShift(1))
    bits_in_adj_covered = bits_in_adj.bitwiseOr(bits_in_adj.rightShift(1))
    isolatedMask = singletSeqMask.bitwiseAnd(bits_in_adj_covered.bitwiseNot()).rename('isolatedSingletMask')

    hasIsolated = isolatedMask.neq(0)
    isMax2 = maxRun.eq(2)

    # Composite = contains a run-2 cluster AND also has at least one isolated singlet elsewhere
    composite = anySinglet.And(isMax2).And(adj11).And(hasIsolated)

    disjointSingles = anySinglet.And(adj11.Not())

    # The where() calls are applied in order, so a later one overrides an
    # earlier one on pixels matched by both (composite overrides 2, and the
    # final call resets pixels without any singlet to 0).
    cat = ee.Image(0).toInt() \
        .where(disjointSingles, 1) \
        .where(isMax2.And(disjointSingles.Not()).And(composite.Not()), 2) \
        .where(maxRun.eq(3), 3) \
        .where(maxRun.eq(4), 4) \
        .where(maxRun.eq(5), 5) \
        .where(composite, 6) \
        .where(anySinglet.Not(), 0) \
        .rename('singletCategory')

    # Also return the per-center bands s2..s6 for debugging/visualization
    return ee.Image.cat([
        cat,
        singletSeqMask,
        singletCount,
        adj11,
        maxRun,
        isolatedMask.toInt(),
        ee.Image.cat(centers)  # bands: s2..s6
    ])




# ----------------------------------------------------------
# 5) Doublet-flip detection (your exact patterns)
#    Pattern: A A X Y A A A, with X!=Y, X!=A, Y!=A
#    Valid configurations require >=2 consecutive A's on one side:
#      (i)  A X Y A A * *     -> check y1=A, y4=y5=A, y2=X, y3=Y
#      (ii) A A X Y A A *     -> check y1=y2=A, y5=y6=A
#      (iii)* A A X Y A A     -> check y2=y3=A, y6=y7=A
#      (iv) * * A A X Y A     -> check y3=y4=A, y7=A and (two A's on left: y3,y4)
# ----------------------------------------------------------
def detect_doublet(seq7):
    """Detect two-year excursions (A .. X Y .. A) in a 7-year class sequence.

    Four placements of the X Y pair are tested. In each one X, Y and the
    surrounding class A are pairwise different:

      1: A X Y A A * *
      2: A A X Y A A *
      3: * A A X Y A A
      4: * * A A X Y A

    Args:
        seq7: ee.Image with bands y1..y7.

    Returns:
        ee.Image with band ``doubletPattern`` (0 = none, otherwise the number
        of the matching placement; when several match, the highest number is
        kept) and band ``hasDoublet`` (true if any placement matched).
    """
    y1, y2, y3, y4, y5, y6, y7 = (seq7.select(f'y{n}') for n in range(1, 8))

    def flanked_pair(anchor, first, second, *anchored):
        # Every band in `anchored` must equal the anchor class ...
        same = anchored[0].eq(anchor)
        for band in anchored[1:]:
            same = same.And(band.eq(anchor))
        # ... while the pair differs from the anchor and from each other.
        distinct = first.neq(second).And(first.neq(anchor)).And(second.neq(anchor))
        return same.And(distinct)

    p1 = flanked_pair(y1, y2, y3, y4, y5)       # A X Y A A * *
    p2 = flanked_pair(y1, y3, y4, y2, y5, y6)   # A A X Y A A *
    p3 = flanked_pair(y2, y4, y5, y3, y6, y7)   # * A A X Y A A
    p4 = flanked_pair(y3, y5, y6, y4, y7)       # * * A A X Y A

    # Later placements overwrite earlier ones where more than one matches.
    pattern_id = ee.Image(0).toInt()
    for code, hit in enumerate((p1, p2, p3, p4), start=1):
        pattern_id = pattern_id.where(hit, code)
    pattern_id = pattern_id.rename('doubletPattern')

    has_any = p1.Or(p2).Or(p3).Or(p4).rename('hasDoublet')
    return pattern_id.addBands(has_any)


# ----------------------------------------------------------
# 6) Main wrapper: provide 7 yearly images -> outputs categories
# ----------------------------------------------------------
def classify_temporal_flip_categories(year_imgs, region=None):
    """Group the labels of 7 yearly images and derive the singlet/doublet bands.

    Args:
        year_imgs: List of 7 ee.Image with band BAND (original labels).
        region: Area to which the derived bands are restricted. Defaults to
            the notebook-level ``roi`` so existing one-argument calls behave
            as before.

    Returns:
        ee.Image with:
          - y1..y7 (grouped labels; not masked to ``region``)
          - singletCategory (0..6), singletSeqMask, singletCount, hasAdj11,
            singletMaxRun, isolatedSingletMask, s2..s6
          - doubletPattern (0..4), hasDoublet
        The singlet and doublet bands are masked outside ``region``.
    """
    if region is None:
        # Fall back to the global defined in an earlier notebook cell.
        region = roi

    grouped = [group_classes(im, band=BAND) for im in year_imgs]
    seq7 = stack_7_years(grouped)

    # Constant-1 image clipped to the region, used purely as a mask.
    roi_mask = ee.Image.constant(1).clip(region).selfMask()
    sing = singlet_category(seq7).updateMask(roi_mask)
    dbl = detect_doublet(seq7).updateMask(roi_mask)

    return seq7.addBands(sing).addBands(dbl)




# Colour per singletCategory value (index = category).
singlet_vis = {
    'min': 0,
    'max': 6,
    'palette': [
        '000000',  # 0 - no singlets (black)
        '2ecc71',  # 1 - disjoint singles (green)
        'f1c40f',  # 2 - run of 2 (yellow)
        'e67e22',  # 3 - run of 3 (orange)
        'e74c3c',  # 4 - run of 4 (red)
        '8e44ad',  # 5 - run of 5 (purple)
        '34495e'   # 6 - composite (dark slate)
    ]
}

# -----------------------------
# Categorise the series loaded earlier and show the result
# -----------------------------
out = classify_temporal_flip_categories(images)

Map.addLayer(out.select('singletCategory'), singlet_vis, 'singletCategory')

# One layer per grouped yearly band, default visualisation.
for _year_band in ('y1', 'y2', 'y3', 'y4', 'y5', 'y6', 'y7'):
    Map.addLayer(out.select(_year_band), {}, _year_band)

# Optional: Map.addLayer(out.select('doubletPattern'), {}, 'doubletPattern')
Map


Map(center=[0, 0], controls=(WidgetControl(options=['position', 'transparent_bg'], position='topright', transp…

In [11]:
out.bandNames()