In [1]:
import os
import cv2
import torch as T
from torch.utils.data import Dataset, DataLoader
import numpy as np
from glob import glob
import re
from skimage.filters import threshold_multiotsu, threshold_otsu
from tqdm.notebook import tqdm
import time
import matplotlib.pyplot as plt

In [2]:
# 📁 Base image directory.
# Defaults to the original local dataset location. Set the VEG_BASE_PATH
# environment variable to point the notebook at another copy of the data
# without editing this cell.
BASE_PATH = os.environ.get("VEG_BASE_PATH", r"D:\AAU Internship\Code\CWF-788\IMAGE512x384")

# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
DEVICE = T.device("cuda" if T.cuda.is_available() else "cpu")
print(f"🚀 Using device: {DEVICE}")

# 📦 Dataset class
class VegetationDataset(Dataset):
    """Images from one folder, served as normalized RGB tensors.

    Collects every ``*.jpg`` and ``*.png`` file directly inside ``folder_path``
    (the search is not recursive) and orders the paths naturally, so that
    numeric runs compare as integers (``img2`` sorts before ``img10``).

    Args:
        folder_path: Directory to scan for images.
        image_size: ``(width, height)`` handed to ``cv2.resize``.
            Defaults to ``(512, 384)``.
    """

    def __init__(self, folder_path, image_size=(512, 384)):
        self.image_size = tuple(image_size)
        candidates = glob(os.path.join(folder_path, "*.jpg")) + glob(os.path.join(folder_path, "*.png"))
        self.image_paths = sorted(candidates, key=self._natural_key)

    @staticmethod
    def _natural_key(path):
        """Split ``path`` into text and integer chunks for natural ordering."""
        return [int(chunk) if chunk.isdigit() else chunk for chunk in re.split(r'(\d+)', path)]

    def __len__(self):
        """Return the number of image files found in the folder."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, resize and normalize the image at position ``idx``.

        Returns:
            ``(tensor, path)``: ``tensor`` is float32, laid out as
            ``[3, height, width]`` in R, G, B order, scaled by 1/255 and moved
            to the module-level ``DEVICE``; ``path`` is the source file path.

        Raises:
            OSError: If OpenCV cannot read or decode the file.
        """
        img_path = self.image_paths[idx]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports failure by returning None rather than raising;
            # fail here with the offending path instead of inside cvtColor.
            raise OSError(f"Could not read image: {img_path}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, self.image_size, interpolation=cv2.INTER_AREA)
        img = img.astype(np.float32) / 255.0
        # HWC -> CHW
        img_tensor = T.from_numpy(img).permute(2, 0, 1)

        # NOTE(review): returning device tensors from __getitem__ suits in-process
        # loading (DataLoader num_workers=0); confirm before enabling workers.
        return img_tensor.to(DEVICE), img_path

def get_loader(folder_path, batch_size=4):
    """Wrap the images found in ``folder_path`` in a sequential ``DataLoader``.

    Batches hold ``batch_size`` items and follow the dataset's own ordering
    (shuffling is disabled).
    """
    return DataLoader(
        VegetationDataset(folder_path),
        batch_size=batch_size,
        shuffle=False,
    )

# Folder of each dataset split below BASE_PATH.
train_path, val_path, test_path = (
    os.path.join(BASE_PATH, split_dir) for split_dir in ("train", "validation", "test")
)

# One sequential loader per split, built in train -> validation -> test order.
train_loader, val_loader, test_loader = (
    get_loader(split_path) for split_path in (train_path, val_path, test_path)
)

🚀 Using device: cuda


In [3]:
def compute_vegetation_indices(batch_imgs):
    """
    Fuse ten RGB vegetation indices (plus the red channel) into a single map.

    Every index is computed per pixel, min-max normalized per image over the
    spatial dimensions, and then combined in a fixed weighted sum.

    Input:  batch_imgs [B, 3, H, W], float32 in range [0, 1], channels ordered R, G, B
    Output: fused map [B, 1, H, W]; this is a raw weighted sum (it can be
            negative) and is NOT rescaled to [0, 1] here
    """
    eps = 1e-6  # to prevent division by zero

    def normalize_index(idx):
        # Per-image min-max scaling over the two spatial dimensions (H, W).
        idx_min = idx.amin(dim=(-2, -1), keepdim=True)
        idx_max = idx.amax(dim=(-2, -1), keepdim=True)
        return (idx - idx_min) / (idx_max - idx_min + eps)

    R = batch_imgs[:, 0, :, :]
    G = batch_imgs[:, 1, :, :]
    B = batch_imgs[:, 2, :, :]

    # Excess green (2G - R - B); only used as the numerator of GLI.
    ExG = 2 * G - R - B

    # ExR = 1.4R - G
    ExR = 1.4 * R - G

    # CIVE = 0.441R - 0.811G + 0.385B + 18.787
    CIVE = 0.441 * R - 0.811 * G + 0.385 * B + 18.787

    # VEG = G / (R^0.667 * B^0.333 + eps)
    VEG = G / ((R**0.667) * (B**0.333) + eps)

    # NDI = (G - R) / (G + R + eps)
    NDI = (G - R) / (G + R + eps)

    # GLI = (2G - R - B) / (2G + R + B + eps)
    GLI = ExG / (2 * G + R + B + eps)

    # AGRI = (G - B) / (G + B + eps)
    AGRI = (G - B) / (G + B + eps)

    # VARI = (G - R) / (G + R - B + eps)
    # The denominator can approach zero for bluish pixels, producing very large
    # magnitudes that dominate the min-max normalization below.
    VARI = (G - R) / (G + R - B + eps)

    # MVI = GLI - CIVE (both taken before normalization)
    MVI = GLI - CIVE

    # CIg = G / (R + eps)
    CIg = G / (R + eps)

    ExR = normalize_index(ExR)
    CIVE = normalize_index(CIVE)
    VEG = normalize_index(VEG)
    NDI = normalize_index(NDI)
    GLI = normalize_index(GLI)
    AGRI = normalize_index(AGRI)
    VARI = normalize_index(VARI)
    MVI = normalize_index(MVI)
    CIg = normalize_index(CIg)

    # NOTE(review): BGI is defined by the same formula as AGRI,
    # (G - B) / (G + B + eps), so the +0.3 * AGRI and -0.4 * BGI terms below
    # net out to roughly -0.1 * AGRI. Confirm this is intended.
    BGI = AGRI

    fused = (
        + 0.9 * GLI     # good for thin & yellowish-green vegetation
        + 0.7 * VARI    # excellent under variable lighting
        + 0.6 * NDI     # green-red contrast
        + 0.4 * CIVE    # bright/contrast patches (limited)
        + 0.4 * VEG     # strong for crops
        + 0.3 * AGRI    # soil separation
        + 0.2 * CIg     # helps crops & partial green
        - 0.6 * ExR     # suppress reddish areas
        - 0.5 * MVI     # suppress flat/dark patches
        - 0.4 * BGI     # suppress bluish/purple
        - 0.3 * R       # pipe suppressor
    )

    return fused.unsqueeze(1)  # [B, 1, H, W]

In [4]:
# Disabled PCA-based fusion helpers. They are kept inside a bare string literal,
# so neither function is defined or executed when this cell runs. `PCA` is not
# imported by the notebook's import cell, and the pipeline feeds the weighted
# fusion from compute_vegetation_indices to the normalization step instead.
# NOTE(review): as the cell's last expression, the string is echoed as cell output.
'''
def fit_pca_on_sample(indices_tensor, n_components=1, sample_size=700):
    """
    indices_tensor: [B, 6, H, W]
    """
    B, C, H, W = indices_tensor.shape
    reshaped = indices_tensor.permute(0, 2, 3, 1).reshape(-1, C)  # [B*H*W, 6]
    
    # Sample randomly to limit fitting time
    sample = reshaped[T.randperm(reshaped.shape[0])[:sample_size * H]]
    pca = PCA(n_components=n_components)
    pca.fit(sample.cpu().numpy())
    
    print("Explained variance ratio:", pca.explained_variance_ratio_)
    
    return pca

# 2️⃣ Apply PCA to full batch
def apply_pca_to_batch(indices_tensor, pca):
    """
    indices_tensor: [B, 6, H, W]
    Returns: [B, 1, H, W] (PCA projected grayscale)
    """
    B, C, H, W = indices_tensor.shape
    reshaped = indices_tensor.permute(0, 2, 3, 1).reshape(-1, C)  # [B*H*W, 6]

    projected = pca.transform(reshaped.cpu().numpy())  # [B*H*W, 1]
    out = T.tensor(projected, dtype=T.float32).reshape(B, H, W)
    return out.unsqueeze(1).to(indices_tensor.device)  # [B, 1, H, W]
'''

'\ndef fit_pca_on_sample(indices_tensor, n_components=1, sample_size=700):\n    """\n    indices_tensor: [B, 6, H, W]\n    """\n    B, C, H, W = indices_tensor.shape\n    reshaped = indices_tensor.permute(0, 2, 3, 1).reshape(-1, C)  # [B*H*W, 6]\n\n    # Sample randomly to limit fitting time\n    sample = reshaped[T.randperm(reshaped.shape[0])[:sample_size * H]]\n    pca = PCA(n_components=n_components)\n    pca.fit(sample.cpu().numpy())\n\n    print("Explained variance ratio:", pca.explained_variance_ratio_)\n\n    return pca\n\n# 2️⃣ Apply PCA to full batch\ndef apply_pca_to_batch(indices_tensor, pca):\n    """\n    indices_tensor: [B, 6, H, W]\n    Returns: [B, 1, H, W] (PCA projected grayscale)\n    """\n    B, C, H, W = indices_tensor.shape\n    reshaped = indices_tensor.permute(0, 2, 3, 1).reshape(-1, C)  # [B*H*W, 6]\n\n    projected = pca.transform(reshaped.cpu().numpy())  # [B*H*W, 1]\n    out = T.tensor(projected, dtype=T.float32).reshape(B, H, W)\n    return out.unsqueeze(1)

In [5]:
def normalize_pca_tensor(pca_tensor):
    """Min-max scale every item of a batch independently to [0, 1].

    An item whose value range is below 1e-5, or whose range is NaN, is
    replaced by zeros instead of being divided by a (near-)zero span.
    The result has the same shape as the input.
    """
    def _rescale(image):
        lowest, highest = image.min(), image.max()
        span = highest - lowest
        # Nearly constant (or NaN-contaminated) item: nothing meaningful to scale.
        if span < 1e-5 or T.isnan(span):
            return T.zeros_like(image)
        return (image - lowest) / span

    batch_size = pca_tensor.shape[0]
    return T.stack([_rescale(pca_tensor[i]) for i in range(batch_size)], dim=0)


In [6]:
def apply_thresholding(pca_tensor):
    """Binarize every map of a batch with Otsu thresholding.

    Args:
        pca_tensor: Batch tensor whose dimension 1 is squeezed away before
            thresholding (expected layout [B, 1, H, W]). It is copied to the
            CPU and converted to NumPy.

    Returns:
        A list with one uint8 mask per batch item, holding 1 where the pixel
        is above the threshold and 0 elsewhere. If both thresholding attempts
        fail for an item, its mask is all zeros.
    """
    def _binarize(img):
        # `except Exception` keeps the best-effort fallback chain while letting
        # KeyboardInterrupt / SystemExit through, so a long run can be stopped.
        try:
            thresholds = threshold_multiotsu(img, classes=2)
            return (img > thresholds[0]).astype(np.uint8)
        except Exception:
            pass

        # fallback in case of flat image
        try:
            thresh = threshold_otsu(img)
            return (img > thresh).astype(np.uint8)
        except Exception:
            return np.zeros_like(img, dtype=np.uint8)

    pca_np = pca_tensor.squeeze(1).cpu().numpy()
    return [_binarize(img) for img in pca_np]


In [7]:
def refine_masks_with_morphops(mask_list, apply_opening=True, apply_closing=True, dilate=False):
    """
    Clean up binary masks with optional morphological operations.

    Each enabled step runs in the order opening -> closing -> dilation,
    using a 5x5 all-ones structuring element.
    Input: list of [H, W] uint8 masks (0 and 1)
    Output: list of refined [H, W] masks (0 and 1), in the same order
    """
    structuring_element = np.ones((5, 5), np.uint8)

    def _refine(binary_mask):
        # OpenCV morphology is applied on a 0/255 image.
        scaled = (binary_mask * 255).astype(np.uint8)
        if apply_opening:
            scaled = cv2.morphologyEx(scaled, cv2.MORPH_OPEN, structuring_element)
        if apply_closing:
            scaled = cv2.morphologyEx(scaled, cv2.MORPH_CLOSE, structuring_element)
        if dilate:
            scaled = cv2.dilate(scaled, structuring_element, iterations=1)
        # Back to 0 or 1
        return scaled // 255

    return [_refine(binary_mask) for binary_mask in mask_list]


In [8]:
def _split_subfolder_for(img_path):
    """Map an image path to its output subfolder: "train", "val" or "test"."""
    split_to_subfolder = {"train": "train", "validation": "val", "test": "test"}

    # Prefer an exact directory-name match, nearest to the file first, so that
    # an unrelated parent folder or a file name containing "train"/"test"
    # cannot misroute the mask.
    directory = os.path.normpath(os.path.dirname(img_path))
    for component in reversed(directory.split(os.sep)):
        if component in split_to_subfolder:
            return split_to_subfolder[component]

    # Fall back to substring search (train, then validation, then test) so
    # paths such as ".../train_images/x.jpg" still resolve.
    for split_name, subfolder in split_to_subfolder.items():
        if split_name in img_path:
            return subfolder

    raise ValueError(f"❌ Cannot determine subfolder for: {img_path}")


def save_vegetation_masks(masks, paths, save_root="Vegetation_Masks"):
    """Write each mask as a 0/255 grayscale PNG under ``save_root/<split>/``.

    Args:
        masks: Iterable of arrays with values 0 and 1; each is multiplied by
            255 and cast to uint8 before writing.
        paths: Source image paths, paired with ``masks`` by position. The
            output file keeps the source base name with a ``.png`` extension,
            and the split subfolder is inferred from the path.
        save_root: Root output directory; created if missing.

    Raises:
        ValueError: If no split can be inferred from an image path.
        OSError: If OpenCV reports that a mask could not be written.
    """
    os.makedirs(save_root, exist_ok=True)

    for mask, img_path in zip(masks, paths):
        # Extract base filename
        base_filename = os.path.splitext(os.path.basename(img_path))[0] + ".png"

        # Create subfolder path
        save_dir = os.path.join(save_root, _split_subfolder_for(img_path))
        os.makedirs(save_dir, exist_ok=True)

        # Convert mask to 0-255 grayscale
        grayscale_mask = (mask * 255).astype(np.uint8)

        # cv2.imwrite signals failure through its return value, not an exception.
        save_path = os.path.join(save_dir, base_filename)
        if not cv2.imwrite(save_path, grayscale_mask):
            raise OSError(f"Failed to write mask: {save_path}")


In [9]:
def debug_image(img, mask, index):
    """Show an image and its mask side by side in grayscale, titled with ``index``."""
    panels = (
        (img, f'PCA [{index}]'),
        (mask, f'Mask [{index}]'),
    )

    plt.figure(figsize=(10, 4))
    for position, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(picture, cmap='gray')
        plt.title(caption)
    plt.show()

In [10]:
def run_vegetation_masking_pipeline(dataloader, set_name):
    """
    Generate and save a vegetation mask for every image a dataloader yields.

    dataloader: iterable of (batch_imgs, batch_paths) pairs
    set_name: label used only in the progress and log messages, e.g. one of
              ["train", "val", "test"]; the output subfolder is chosen by
              save_vegetation_masks from each image path
    Returns:
        total_images (int), total_time (float in seconds)
    """
    print(f"\n🚀 Starting vegetation masking for {set_name} set...")
    total_images = 0
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()

    for batch_imgs, batch_paths in tqdm(dataloader, desc=f"🌱 Processing {set_name}"):
        total_images += len(batch_imgs)

        # Step 1: Fuse the vegetation indices into one map
        veg_fused = compute_vegetation_indices(batch_imgs)  # [B, 1, H, W]

        # Step 2: No PCA — normalize the fused tensor per image
        veg_fused = normalize_pca_tensor(veg_fused)

        # Step 3: Threshold into binary masks
        binary_masks = apply_thresholding(veg_fused)

        # Step 4: Save. Morphological refinement (refine_masks_with_morphops)
        # is not applied; the raw thresholded masks are written.
        save_vegetation_masks(binary_masks, batch_paths)

    total_time = time.perf_counter() - start_time

    if total_images == 0:
        # Empty loader (e.g. wrong folder): report it instead of dividing by zero.
        print(f"⚠️ No images found for {set_name} set; nothing was processed.\n")
        return total_images, total_time

    ms_per_image = (total_time / total_images) * 1000

    print(f"✅ Completed {total_images} images in {total_time:.2f} sec")
    print(f"⚡ Avg processing time: {ms_per_image:.2f} ms/image\n")

    return total_images, total_time


In [11]:
# Run the pipeline on each split in turn and accumulate image counts and wall time.
total_imgs_all = 0
total_time_all = 0

for name, loader in (("train", train_loader), ("val", val_loader), ("test", test_loader)):
    imgs, time_taken = run_vegetation_masking_pipeline(loader, name)
    total_imgs_all, total_time_all = total_imgs_all + imgs, total_time_all + time_taken



🚀 Starting vegetation masking for train set...


🌱 Processing train:   0%|          | 0/100 [00:00<?, ?it/s]

✅ Completed 400 images in 7.26 sec
⚡ Avg processing time: 18.16 ms/image


🚀 Starting vegetation masking for val set...


🌱 Processing val:   0%|          | 0/22 [00:00<?, ?it/s]

✅ Completed 88 images in 3.55 sec
⚡ Avg processing time: 40.29 ms/image


🚀 Starting vegetation masking for test set...


🌱 Processing test:   0%|          | 0/75 [00:00<?, ?it/s]

✅ Completed 300 images in 7.56 sec
⚡ Avg processing time: 25.20 ms/image



In [12]:
# Final performance summary
# Guard the average: if no images were processed in any split, report NaN
# instead of raising ZeroDivisionError.
if total_imgs_all:
    overall_ms_per_image = (total_time_all / total_imgs_all) * 1000
else:
    overall_ms_per_image = float("nan")
print(f"🧠 Total: {total_imgs_all} images processed in {total_time_all:.2f} seconds")
print(f"🚀 Overall Speed: {overall_ms_per_image:.2f} ms/image")

🧠 Total: 788 images processed in 18.37 seconds
🚀 Overall Speed: 23.31 ms/image
