In [16]:
# clean_bagls_masks detects obvious outlier frames (mis-segmented glottis masks) within each frame sequence and can repair them.

import os
import re
import cv2
import numpy as np
from glob import glob

# One or more consecutive decimal digits (captured as group 1).
# _stem_num searches a file stem with this pattern to obtain its numeric part.
_DIGITS_RE = re.compile(r'(\d+)')

def _stem(path):
    return os.path.basename(path).split('_', 1)[0]

def _stem_num(path):
    """Return the first run of digits in the file stem as an int.

    When the stem contains no digits at all, the stem string itself is
    returned unchanged (so the result may be either an int or a str).
    """
    stem = _stem(path)
    found = _DIGITS_RE.search(stem)
    if found is None:
        return stem
    return int(found.group(1))

def _block_id(n):
    """
    Map numeric stem n to its 100-frame block id.
    Example: 10101..10200 -> 101 ; 19901..20000 -> 199
    """
    return (int(n) - 1) // 100

def _largest_cc_binary(bin_img):
    """Keep only the largest 8-connected component of a binary image.

    If the labelling reports at most two labels, the input is returned as-is.
    Otherwise a new uint8 array is returned that is 1 on the component with
    the greatest area and 0 elsewhere.
    """
    n_labels, labels, stats, _centroids = cv2.connectedComponentsWithStats(bin_img, connectivity=8)
    if n_labels <= 2:
        return bin_img
    # Row 0 of `stats` is skipped when ranking areas (OpenCV's background label),
    # hence the +1 to convert the argmax position back into a label id.
    component_areas = stats[1:, cv2.CC_STAT_AREA]
    keep_label = np.argmax(component_areas) + 1
    return (labels == keep_label).astype(np.uint8)

def _rolling_ref_area_block(areas, i, window, idxs_in_block):
    """
    Rolling median within the same block (indices in idxs_in_block only).
    """
    # local window but clipped to block indices
    i_pos = idxs_in_block.index(i)
    lo_pos = max(0, i_pos - window)
    hi_pos = min(len(idxs_in_block), i_pos + window + 1)
    neigh_idxs = [j for j in idxs_in_block[lo_pos:hi_pos] if j != i]
    neigh = [areas[j] for j in neigh_idxs if areas[j] > 0]
    if len(neigh) == 0:
        # fallback: median over all non-zero in block
        nz = [areas[j] for j in idxs_in_block if areas[j] > 0]
        return np.median(nz) if len(nz) else 0
    return np.median(neigh)

def _nearest_valid_in_block(i, idxs_in_block, valid):
    """
    Return the index of the valid frame closest to frame i inside its block, or None.

    Distance is measured by rank within idxs_in_block (not by absolute index).
    At equal distance the earlier (left) frame is preferred.
    """
    i_pos = idxs_in_block.index(i)
    n = len(idxs_in_block)
    for r in range(1, max(i_pos + 1, n - i_pos)):
        for pos in (i_pos - r, i_pos + r):
            if 0 <= pos < n and valid[idxs_in_block[pos]]:
                return idxs_in_block[pos]
    return None


def _write_image(path, image):
    """Write image to path with cv2.imwrite; print a warning and return False if it reports failure."""
    if cv2.imwrite(path, image):
        return True
    print(f"  ! Failed to write {path}")
    return False


def clean_bagls_masks(
    root_dir,
    area_jump_thresh=5.0,
    window=5,
    use_neighbor=True,
    print_only=False,
    also_flag_tiny=False,
    tiny_frac=0.15,
    class_glottis=1
):
    """
    Detect (and optionally repair) glottis-area outliers in a SINGLE folder
    containing *_rgb.png / *_mask.png pairs.

    Frames are grouped into 100-frame 'video' blocks by their numeric stem
    (see _block_id). All detection and replacement are restricted to the SAME
    block. A frame's glottis area is the number of mask elements equal to
    class_glottis; its reference area is the rolling median of its neighbours
    (see _rolling_ref_area_block).

    Parameters
    ----------
    root_dir : folder searched for "*_mask.png"; the matching RGB file is
        expected at "<stem>_rgb.png" in the same folder.
    area_jump_thresh : a frame is an outlier when area > area_jump_thresh * reference.
    window : number of frames on each side used for the rolling median.
    use_neighbor : if True, an outlier's RGB and mask files are overwritten with
        those of the nearest non-outlier frame of the same block. If False, or
        if the block has no such frame, the mask is instead reduced to the
        largest connected component of the glottis class.
    print_only : if True, only report what would be done; no file is written.
    also_flag_tiny : if True, frames with area < tiny_frac * reference are
        flagged as outliers too.
    tiny_frac : fraction used by also_flag_tiny.
    class_glottis : mask value that marks glottis pixels.

    Returns
    -------
    None. Progress and a final summary are printed.

    Notes
    -----
    Frames whose mask cannot be read or contains no glottis pixels are never
    flagged and are never used as replacement source. Mask files whose stem
    contains no digits are skipped with a message.
    """

    # 1) Collect masks and sort by numeric stem.
    #    _stem_num returns a str when the stem has no digits; such files cannot be
    #    ordered or assigned to a block, so they are reported and left out.
    numbered = []
    for path in glob(os.path.join(root_dir, "*_mask.png")):
        num = _stem_num(path)
        if isinstance(num, int):
            numbered.append((num, path))
        else:
            print(f"  ! Skip: no numeric stem in {path}")
    if not numbered:
        print("No *_mask.png files found in", root_dir)
        return
    numbered.sort(key=lambda pair: pair[0])  # stable: ties keep glob order
    stems_num  = [pair[0] for pair in numbered]
    mask_files = [pair[1] for pair in numbered]
    rgb_files  = [os.path.join(root_dir, f"{_stem(mf)}_rgb.png") for mf in mask_files]
    blocks     = [_block_id(n) for n in stems_num]

    # 2) Build index lists per block (indices into mask_files, in sorted order)
    block_to_indices = {}
    for idx, b in enumerate(blocks):
        block_to_indices.setdefault(b, []).append(idx)

    # 3) Precompute glottis areas (0 for unreadable masks)
    areas = []
    for mf in mask_files:
        m = cv2.imread(mf, cv2.IMREAD_UNCHANGED)
        areas.append(0 if m is None else int(np.count_nonzero(m == class_glottis)))

    # 4) Classify every frame once, within its block.
    #    outlier[i]: area jumps above (or, optionally, drops below) the reference.
    #    valid[i]:   has a usable area and reference and is not an outlier;
    #                only valid frames may serve as replacement source.
    n_frames = len(mask_files)
    outlier = [False] * n_frames
    valid = [False] * n_frames
    for idxs in block_to_indices.values():
        for i in idxs:
            a = areas[i]
            if a <= 0:
                continue
            ref = _rolling_ref_area_block(areas, i, window, idxs)
            if ref <= 0:
                continue
            big = a > area_jump_thresh * ref
            tiny = also_flag_tiny and a < tiny_frac * ref
            outlier[i] = bool(big or tiny)
            valid[i] = not outlier[i]

    # 5) Process each outlier, constrained to its block
    total_outliers = 0
    for i, bad_mask in enumerate(mask_files):
        if not outlier[i]:
            continue
        total_outliers += 1
        stem_bad = _stem(bad_mask)
        bad_rgb = rgb_files[i]

        # Only look for a replacement source when neighbor replacement is requested
        neighbor_idx = None
        if use_neighbor:
            neighbor_idx = _nearest_valid_in_block(i, block_to_indices[blocks[i]], valid)

        if neighbor_idx is not None:
            nei_rgb, nei_mask = rgb_files[neighbor_idx], mask_files[neighbor_idx]
            print(f"[OUTLIER] {stem_bad} → neighbor {_stem(nei_mask)} (same block {blocks[i]})")
            if not print_only:
                img = cv2.imread(nei_rgb, cv2.IMREAD_UNCHANGED)
                msk = cv2.imread(nei_mask, cv2.IMREAD_UNCHANGED)
                if img is None or msk is None:
                    print(f"  ! Skip: neighbor files missing for {_stem(nei_mask)}")
                elif _write_image(bad_rgb, img):
                    # Write the mask only after the RGB write succeeded, so a failed
                    # RGB write does not leave a mismatched image/mask pair.
                    _write_image(bad_mask, msk)
        else:
            # No valid neighbor inside the block → largest-CC cleanup (stays within file)
            action = "no valid neighbor in block" if use_neighbor else "largest-CC cleanup"
            print(f"[OUTLIER] {stem_bad} → {action}; applying largest-CC cleanup (block {blocks[i]})")
            if not print_only:
                m = cv2.imread(bad_mask, cv2.IMREAD_UNCHANGED)
                if m is None:
                    print(f"  ! Skip: cannot read {bad_mask}")
                    continue
                # NOTE(review): assumes a single-channel label mask — TODO confirm
                gl = (m == class_glottis).astype(np.uint8)
                gl_cc = _largest_cc_binary(gl)
                # Clear every glottis pixel, then restore only the largest component
                fixed = m.copy()
                fixed[m == class_glottis] = 0
                fixed[gl_cc == 1] = class_glottis
                _write_image(bad_mask, fixed)

    print(f"\nDone. Outliers detected: {total_outliers} (print_only={print_only})")

In [31]:
# Dry run: with print_only=True the function only reports the detected outliers
# and the planned action; no image or mask file is written.
# NOTE(review): the folder is a machine-specific absolute path — consider moving it
# to a configurable constant so the notebook can run on other machines.
clean_bagls_masks(r"C:\Users\olegp\miniconda3\envs\proj1\0_jupyter_notebook\data\train_extension",
                  area_jump_thresh=5.0,
                  window=5,
                  use_neighbor=True,   # plan to duplicate neighbor pair
                  print_only=True)

[OUTLIER] 11100 → neighbor 11099 (same block 110)
[OUTLIER] 14199 → neighbor 14198 (same block 141)
[OUTLIER] 14200 → neighbor 14198 (same block 141)
[OUTLIER] 14299 → neighbor 14298 (same block 142)
[OUTLIER] 14300 → neighbor 14298 (same block 142)

Done. Outliers detected: 5 (print_only=True)
