In [None]:
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import numpy as np
import sys
from tqdm import tqdm

Image.MAX_IMAGE_PIXELS = sys.maxsize

from utils.patch import create_image_patches

In [None]:
# Destination directories for the generated patches, one per dataset split.
train_dir = "./data/BingRGB/train"
val_dir = "./data/BingRGB/val"
test_dir = "./data/BingRGB/test"

# Source rasters: each RGB image is paired with its colour-coded ground truth.
input_train, input_train_gt = (
    "./data/BingRGB/dhaka_train.tif",
    "./data/BingRGB/dhaka_train_gt.tif",
)
input_test, input_test_gt = (
    "./data/BingRGB/dhaka_test.tif",
    "./data/BingRGB/dhaka_test_gt.tif",
)

In [None]:
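# NOTE(review): disabled calls to `create_and_save_patch`, a helper that is not
# imported anywhere in this notebook. Presumably superseded by the
# `create_image_patches` calls further down — confirm, then remove this cell.
# create_and_save_patch(input_train, train_dir)
# create_and_save_patch(input_train_gt, train_dir, postfix='_gt', encode=1)
# create_and_save_patch(input_test, test_dir, stride=512)
# create_and_save_patch(input_test_gt, test_dir, stride=512, postfix='_gt', encode=1)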

In [None]:
# Colour -> class-index lookup for the colour-coded ground-truth masks.
# Several colours intentionally share an index: every built-up / road category
# collapses into 4, and meadow / marshland collapse into 5.
class_map = {
    (0, 0, 0):       0,  # bg
    (0, 255, 0):     1,  # farmland
    (0, 0, 255):     2,  # water
    (0, 255, 255):   3,  # forest
    (128, 0, 0):     4,  # urban_structure
    (255, 0, 255):   4,  # rural_built_up
    (255, 0, 0):     4,  # urban_built_up
    (160, 160, 164): 4,  # road
    (255, 255, 0):   5,  # meadow
    (255, 251, 240): 5,  # marshland
    (128, 0, 128):   4,  # brick_factory
}


def encode_bw_mask(rgb_mask, class_map=class_map):
    """Convert a colour-coded RGB mask into a single-channel label mask.

    Every pixel whose colour is a key of ``class_map`` receives the mapped
    integer label. Pixels whose colour is not in ``class_map`` keep the
    initial value ``0``, so they are indistinguishable from a class mapped
    to ``0``.

    Args:
        rgb_mask: PIL image that converts to an ``(H, W, 3)`` array.
        class_map: dict mapping ``(r, g, b)`` tuples to integer labels in
            ``0..255`` (the output array is ``uint8``).

    Returns:
        A PIL image built from the 2-D ``uint8`` label array.

    Raises:
        AssertionError: if ``rgb_mask`` is not a PIL image, ``class_map`` is
            malformed, the mask is not a 3-channel array, or one of the
            sanity checks on the number of distinct values fails.
    """
    # --- argument validation -------------------------------------------------
    assert isinstance(rgb_mask, Image.Image), "Object is not a PIL Image"
    assert all(isinstance(k, tuple) and len(k) == 3 for k in class_map.keys()), "Invalid keys in class_map"
    assert all(isinstance(v, int) for v in class_map.values()), "Invalid values in class_map"
    # Labels are written into a uint8 array; reject values that cannot be
    # stored instead of letting them wrap around or fail midway through.
    assert all(0 <= v <= 255 for v in class_map.values()), "class_map values must fit in uint8 (0-255)"

    rgb_mask_array = np.array(rgb_mask)
    num_classes = len(class_map)  # number of colours; an upper bound on labels
    channels = 3

    # Single-channel images (e.g. palette or grayscale) convert to 2-D arrays;
    # check the rank first so they fail with a clear message rather than an
    # IndexError on shape[2].
    assert rgb_mask_array.ndim == 3, (
        f"mask should have 3 channels but found an array of shape {rgb_mask_array.shape}."
    )
    assert rgb_mask_array.shape[2] == channels, f"mask should have 3 channels but found {rgb_mask_array.shape[2]}."
    # Coarse sanity check: this counts distinct channel intensities, not
    # distinct colours, so it only rejects masks with far too many values.
    assert len(np.unique(rgb_mask_array)) <= num_classes*channels, "rgb mask has more classes than expected"

    # --- label encoding ------------------------------------------------------
    bw_mask = np.zeros(rgb_mask_array.shape[:2], dtype=np.uint8)
    for rgb_val, class_label in class_map.items():
        # Boolean (H, W) mask of the pixels matching this colour on all channels.
        matches = np.all(rgb_mask_array == rgb_val, axis=-1)
        bw_mask[matches] = class_label

    assert len(bw_mask.shape) == 2, f"Invalid Shape {bw_mask.shape}"
    assert len(np.unique(bw_mask)) <= num_classes, f"Invalid number of classes {len(np.unique(bw_mask))}"

    return Image.fromarray(bw_mask)

In [None]:
# Alias for the mask encoder; it is handed to `create_image_patches` as
# `patch_function` for the ground-truth rasters in the cell below.
# NOTE(review): presumably applied to each mask patch before it is saved —
# confirm against utils.patch.
func = encode_bw_mask

In [None]:

# Patch every split: first the RGB raster, then its ground-truth raster, which
# is additionally given patch_name_suffix="_gt" and patch_function=func.
# NOTE(review): the val split is generated from the same rasters as the test
# split (input_test / input_test_gt), so the two directories presumably end up
# with the same patches — confirm this is intended.
for _image_path, _mask_path, _out_dir in (
    (input_train, input_train_gt, train_dir),
    (input_test, input_test_gt, test_dir),
    (input_test, input_test_gt, val_dir),
):
    create_image_patches(_image_path, _out_dir)
    create_image_patches(_mask_path, _out_dir, patch_name_suffix="_gt", patch_function=func)