In [None]:
"""
This notebook updates the combined masks so that training no longer emits the warning shown below
Logging setup complete!
2024-12-25 15:29:20,753 - INFO - Loading data for category: AnnualCrop
2024-12-25 15:29:24,359 - INFO - Loaded data from ../data/training_data/train/AnnualCrop_train.npy with shape (10, 513, 513, 3)
2024-12-25 15:29:24,836 - INFO - Loaded data from ../data/training_data/train/AnnualCrop_train_masks_combined.npy with shape (10, 513, 513)
2024-12-25 15:29:24,838 - WARNING - Mask values exceed num_classes=2. Mapping to valid indices.
2024-12-25 15:29:24,879 - INFO - Loading data for category: Forest
2024-12-25 15:29:29,643 - INFO - Loaded data from ../data/training_data/train/Forest_train.npy with shape (10, 513, 513, 3)
2024-12-25 15:29:30,124 - INFO - Loaded data from ../data/training_data/train/Forest_train_masks_combined.npy with shape (10, 513, 513)
2024-12-25 15:29:30,124 - WARNING - Mask values exceed num_classes=2. Mapping to valid indices.

"""

In [1]:
import os
import numpy as np

# Function to validate and remap combined masks
def validate_and_remap_combined_masks(mask_directory, num_classes):
    """
    Validate combined mask files and clip their labels into [0, num_classes - 1].

    Walks ``mask_directory`` recursively and processes every file whose name
    ends with ``_masks_combined.npy``. Labels below 0 become 0 and labels above
    ``num_classes - 1`` become ``num_classes - 1``. This is a clip, not a
    relabelling: with ``num_classes=2`` every label >= 1 becomes 1.

    Files whose labels are already in range are left untouched (including their
    dtype). Files that change are overwritten in place, stored with the smallest
    unsigned integer dtype able to hold ``num_classes - 1`` (``uint8`` for up to
    256 classes).

    Parameters:
        mask_directory (str): Path to the directory containing combined mask files.
        num_classes (int): Number of classes expected in the masks (must be >= 1).

    Returns:
        None: Prints a validation summary per file and updates masks in place if needed.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
        FileNotFoundError: If ``mask_directory`` is not an existing directory.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    if not os.path.isdir(mask_directory):
        # os.walk silently yields nothing for a missing path; fail loudly instead
        raise FileNotFoundError(f"Mask directory not found: {mask_directory}")

    max_label = num_classes - 1
    # uint8 whenever it fits (as before); a wider type avoids wrap-around for >256 classes
    target_dtype = np.min_scalar_type(max_label)

    # Iterate through all combined mask files, sorted for a reproducible order
    for root, dirs, files in os.walk(mask_directory):
        dirs.sort()  # in-place sort controls os.walk's descent order
        for file in sorted(files):
            if not file.endswith("_masks_combined.npy"):
                continue
            mask_path = os.path.join(root, file)

            # Load the combined mask
            masks = np.load(mask_path)
            print(f"Validating: {mask_path} (Shape: {masks.shape})")

            # Check for unique values in the mask
            unique_values = np.unique(masks)
            print(f"  Unique values before remapping: {unique_values}")

            # Clipping is element-wise, so applying it to the unique values tells
            # us whether the full array would change. This avoids building a
            # second full-size copy or re-running np.unique on it.
            remapped_values = np.clip(unique_values, 0, max_label).astype(target_dtype)

            if not np.array_equal(unique_values, remapped_values):
                print(f"  Remapping values to [0, {max_label}]...")
                remapped_masks = np.clip(masks, 0, max_label).astype(target_dtype)
                # Save to a temporary file, then atomically swap it in, so an
                # interrupted save cannot leave a truncated mask file behind.
                tmp_path = mask_path + ".tmp.npy"
                np.save(tmp_path, remapped_masks)
                os.replace(tmp_path, mask_path)
                print(f"  Updated mask saved: {mask_path}")
            else:
                print(f"  No remapping needed for: {mask_path}")

            # Validate again after remapping
            final_unique_values = np.unique(remapped_values)
            print(f"  Unique values after remapping: {final_unique_values}\n")

# --- Configuration -------------------------------------------------------------
# Root folder searched recursively for "*_masks_combined.npy" files.
mask_directory = "../data/training_data"  # adjust if the data lives elsewhere
# Must match the number of output classes of the segmentation model.
# NOTE(review): the output below shows labels 0-20 in these masks; clipping to
# 2 classes merges every label >= 1 into class 1 and overwrites the files in
# place. Confirm that a binary target is really intended.
num_classes = 2

# --- Validate and clip every combined mask file ---------------------------------
validate_and_remap_combined_masks(
    mask_directory=mask_directory,
    num_classes=num_classes,
)


Validating: ../data/training_data/train/AnnualCrop_train_masks_combined.npy (Shape: (1920, 513, 513))
  Unique values before remapping: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20]
  Remapping values to [0, 1]...
  Updated mask saved: ../data/training_data/train/AnnualCrop_train_masks_combined.npy
  Unique values after remapping: [0 1]

Validating: ../data/training_data/train/Forest_train_masks_combined.npy (Shape: (1920, 513, 513))
  Unique values before remapping: [ 0  3  4  5  6  7  8 11 14 15 16 17 18 19 20]
  Remapping values to [0, 1]...
  Updated mask saved: ../data/training_data/train/Forest_train_masks_combined.npy
  Unique values after remapping: [0 1]

Validating: ../data/training_data/train/HerbaceousVegetation_train_masks_combined.npy (Shape: (1920, 513, 513))
  Unique values before remapping: [ 0  1  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20]
  Remapping values to [0, 1]...
  Updated mask saved: ../data/training_data/train/HerbaceousVegetati