## Essential Functions and Libraries

In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import torch
from pathlib import Path
import numpy as np
import tifffile as tiff
from difflib import get_close_matches

## Read data

In [22]:
# Main dataset directory
dataset_dir = Path('spotlite_dataset_loca1_date1_dir1')

images_dir = dataset_dir / 'images'
masks_dir = dataset_dir / 'masks'

assert images_dir.exists(), f"Image folder not found: {images_dir}"
assert masks_dir.exists(), f"Mask folder not found: {masks_dir}"

# Collect all image and mask files (multiple extensions), ordered by base name
image_paths = sorted(
    [*images_dir.glob('*.tif'), *images_dir.glob('*.tiff')],
    key=lambda p: p.stem
)
# NOTE(review): PNG/JPG masks are collected here, but every mask is read with
# tiff.imread below -- confirm that reader accepts them, or drop those patterns.
mask_paths = sorted(
    [*masks_dir.glob('*.tif'), *masks_dir.glob('*.tiff'),
     *masks_dir.glob('*.png'), *masks_dir.glob('*.jpg')],
    key=lambda p: p.stem
)

# Dictionary of masks by base name
mask_dict = {p.stem: p for p in mask_paths}
unused_masks = set(mask_dict.keys())

# A mask whose stem equals an image stem is reserved for that image. Only the
# remaining masks may be claimed by approximate matching, and each at most
# once, so a fuzzy match can neither duplicate a mask across several images
# nor take the mask that belongs to another image by exact name.
image_stems = {p.stem for p in image_paths}
fuzzy_pool = set(mask_dict) - image_stems

imgs_array, masks_array = [], []

for img_path in image_paths:
    stem = img_path.stem

    if stem in mask_dict:
        # 1) Exact match
        sel = stem
    else:
        # 2) Approximate match among the masks that are still unclaimed
        candidates = get_close_matches(stem, sorted(fuzzy_pool), n=1, cutoff=0.6)
        if not candidates:
            # 3) No match found; just notify and skip
            print(f"Warning: no mask found for {img_path.name}; skipping.")
            continue
        sel = candidates[0]
        fuzzy_pool.discard(sel)
        # Approximate pairs are worth auditing by eye, so report each one.
        print(f"Note: {img_path.name} paired with {mask_dict[sel].name} by approximate match.")

    mask_path = mask_dict[sel]
    unused_masks.discard(sel)

    # Read image and mask
    img = tiff.imread(str(img_path))
    msk = tiff.imread(str(mask_path))

    # Check dimensions: the mask must be 2-D and equal the image's first two axes
    if img.shape[:2] != msk.shape:
        raise ValueError(
            f"Incompatible dimensions: {img_path.name} {img.shape[:2]} vs "
            f"{mask_path.name} {msk.shape}"
        )

    imgs_array.append(img)
    masks_array.append(msk)

# Remaining masks
if unused_masks:
    print("Warning: these masks were not used (no corresponding image):")
    for s in sorted(unused_masks):
        print("  ", s)

# Fail with an explicit message instead of np.stack's generic error on an empty list
if not imgs_array:
    raise ValueError(f"No image/mask pairs could be loaded from {dataset_dir}")

# Stack arrays
imgs = np.stack(imgs_array, axis=0)   # (N, H, W, C)
masks = np.stack(masks_array, axis=0)  # (N, H, W)

print(f'Final: {len(imgs)} pairs loaded.')
print(f'Shape of imgs: {imgs.shape}')
print(f'Shape of masks: {masks.shape}')


Final: 391 pairs loaded.
Shape of imgs: (391, 512, 512, 4)
Shape of masks: (391, 512, 512)


## Extract features and labels

In [None]:
def remap_mask(mask, remapping=False, vegetation_classes=(2, 3, 4)):
    """Convert a label mask to a tensor, optionally collapsing it to binary.

    Args:
        mask: Label mask, either a ``torch.Tensor`` or a NumPy array accepted
            by ``torch.from_numpy``.
        remapping: When False (default) the mask is returned as a tensor with
            its values and dtype untouched. When True, every pixel whose label
            is in ``vegetation_classes`` becomes 1 and every other pixel 0.
        vegetation_classes: Iterable of label values merged into the positive
            class. Defaults to ``(2, 3, 4)``.

    Returns:
        torch.Tensor: the unchanged labels when ``remapping`` is False,
        otherwise an int64 tensor of 0/1 values with the same shape.
    """
    if not torch.is_tensor(mask):
        # from_numpy shares memory with the source array (no copy is made).
        mask = torch.from_numpy(mask)

    if not remapping:
        return mask

    # Accumulate membership in a boolean tensor, then cast once at the end.
    is_vegetation = torch.zeros_like(mask, dtype=torch.bool)
    for cls in vegetation_classes:
        is_vegetation |= (mask == cls)

    return is_vegetation.long()

# Collapse the label set to binary with remap_mask.
# NOTE: this rebinds `masks` to the remapped tensor, so running this cell a
# second time without reloading the data remaps the already-binary masks
# (label 1 is not one of the remapped classes) and yields all zeros.
masks = remap_mask(masks, remapping=True)
# Keep the result in a plain variable; the previous code assigned it as an
# attribute on the numpy module itself (`np.unique_labels`).
unique_labels = np.unique(masks)
print(f'Unique labels in masks: {unique_labels}')


N, H, W, C = imgs.shape
# Features are the image stack; labels are the remapped masks as a NumPy array.
X = np.array(imgs)
y = np.array(masks)


Unique labels in masks: [0 1]


## Split training/testing

In [24]:
# Hold out 25% of the samples for testing; the fixed seed keeps the partition
# reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.25,
    random_state=42,
    stratify=None,
)

# Report the shape of each resulting array, one per line.
print(
    f'shape of X_train: {X_train.shape}',
    f'shape of y_train: {y_train.shape}',
    f'shape of X_test: {X_test.shape}',
    f'shape of y_test: {y_test.shape}',
    sep='\n',
)

shape of X_train: (293, 512, 512, 4)
shape of y_train: (293, 512, 512)
shape of X_test: (98, 512, 512, 4)
shape of y_test: (98, 512, 512)


## Train Random Forest 

In [53]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import resample

# Treat every pixel as one sample whose features are its channel values. The
# channel count is read from the data rather than hard-coded to 4, so a
# different band count cannot silently scramble pixels during the reshape.
n_channels = X_train.shape[-1]
X_flat = X_train.reshape(-1, n_channels)
y_flat = y_train.reshape(-1)

# Draw a fixed-size random subset of pixels to train on
# (resample draws with replacement by default).
X_sub, y_sub = resample(
    X_flat, y_flat,
    n_samples=5_000_000,
    random_state=42
)

clf = RandomForestClassifier(
    n_estimators=20,
    max_depth=15,
    class_weight='balanced',
    n_jobs=-1,
    random_state=42
)

clf.fit(X_sub, y_sub)

X_test_flat = X_test.reshape(-1, n_channels)

# Evaluate the forest once and derive both outputs from the same probabilities:
# the predicted class is the one with the highest mean probability, which is
# what clf.predict computes, so a second pass over every test pixel is avoided.
proba = clf.predict_proba(X_test_flat)
probs = proba[:, 1]  # probability of the second entry in clf.classes_
preds = clf.classes_[np.argmax(proba, axis=1)]

# NOTE(review): `probs` is the probability of class 1 only, so this mask is
# True just where class 1 wins. Every other pixel -- including confidently
# predicted class 0 -- is overwritten with 255 below. Confirm this is intended.
confidence_mask = probs > 0.5
high_conf_preds = preds.copy()
high_conf_preds[~confidence_mask] = 255

## Plot

In [55]:
import os
import matplotlib.pyplot as plt
import numpy as np

def stretch_rgb(rgb):
    """Percentile-stretch the first three channels of an image into [0, 1].

    Channels 0, 1 and 2 are rescaled independently so that the channel's 2nd
    percentile maps to 0 and its 98th percentile maps to 1; values outside
    that range are clipped. A channel whose 98th percentile is not greater
    than its 2nd percentile is left at zero.

    Args:
        rgb: Array indexed as ``[..., channel]`` with at least three channels.

    Returns:
        A float32 array with the same shape as ``rgb``.
    """
    stretched = np.zeros_like(rgb, dtype=np.float32)
    for band in (0, 1, 2):
        channel = rgb[..., band]
        lo, hi = np.percentile(channel, (2, 98))
        if not hi > lo:
            # Degenerate range: keep the zeros already in place.
            continue
        stretched[..., band] = np.clip((channel - lo) / (hi - lo), 0, 1)
    return stretched

def save_comparisons(X_rgb, y_true, y_pred, save_dir, indices=None, prefix='img'):
    """Save one PNG per sample showing image, ground truth and prediction.

    Each figure has three panels: the percentile-stretched RGB composite of
    the image, the ground-truth mask, and the pixels predicted as class 1.
    Files are written to ``save_dir`` as ``{prefix}_{idx:03d}.png``.

    Args:
        X_rgb: Image stack indexed by sample on the first axis; each image is
            either channels-first (first axis of size 3 or 4) or channels-last
            (last axis of size >= 3).
        y_true: Ground-truth masks, indexed like ``X_rgb``.
        y_pred: Predicted masks, indexed like ``X_rgb``.
        save_dir: Output directory; created if it does not exist.
        indices: Sample indices to export (list, range or array). ``None``
            exports every sample; an empty sequence exports nothing.
        prefix: File-name prefix for the saved figures.

    Raises:
        ValueError: If an image is neither channels-first nor channels-last
            as described above.
    """
    os.makedirs(save_dir, exist_ok=True)
    N = X_rgb.shape[0]
    # Explicit None check: `indices or ...` raises on NumPy arrays and would
    # silently expand an empty selection into "all samples".
    if indices is None:
        indices = range(N)

    for idx in indices:
        img = X_rgb[idx]
        gt = y_true[idx]
        pred = y_pred[idx]

        # Detect the layout and extract the first three channels as RGB
        if img.ndim == 3 and img.shape[0] in (3, 4):  # CHW
            rgb = img[:3].transpose(1, 2, 0)
        elif img.ndim == 3 and img.shape[2] >= 3:      # HWC
            rgb = img[..., :3]
        else:
            raise ValueError("Unsupported image shape: {}".format(img.shape))

        # Normalize for display
        if np.issubdtype(rgb.dtype, np.integer):
            rgb = rgb.astype(np.float32)
        rgb = stretch_rgb(rgb)

        fig, axs = plt.subplots(1, 3, figsize=(8, 6))
        try:
            axs[0].imshow(rgb)
            axs[0].set_title("RGB Image")
            axs[1].imshow(gt, cmap='gray', vmin=0, vmax=1)
            axs[1].set_title("Ground Truth")
            axs[2].imshow(pred == 1, cmap='gray', vmin=0, vmax=1)
            axs[2].set_title("Prediction (binary)")
            for ax in axs:
                ax.axis('off')
            fig.tight_layout()

            # Save the figure
            filename = f"{prefix}_{idx:03d}.png"
            filepath = os.path.join(save_dir, filename)
            fig.savefig(filepath, bbox_inches='tight', dpi=150)
        finally:
            # Always release the figure, even if drawing or saving fails.
            plt.close(fig)


n_imgs = X_test.shape[0]
H, W = X_test.shape[1], X_test.shape[2]  # 512, 512 for this dataset
# Fold the flat per-pixel predictions back into one (H, W) map per test image.
high_conf_reshaped = high_conf_preds.reshape(n_imgs, H, W)

# Export at most the first 20 comparisons; capping at n_imgs avoids an
# IndexError when the test split holds fewer than 20 images.
save_comparisons(
    X_rgb=X_test,
    y_true=y_test,
    y_pred=high_conf_reshaped,
    save_dir='comparisons',
    indices=list(range(min(20, n_imgs)))  # or None to save all
)
