In [None]:
from PIL import Image
import numpy as np

def read_png_image(file_path):
    """Load an image file from disk and return its pixels as a NumPy array.

    Args:
        file_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        numpy.ndarray: The decoded image, as produced by ``np.array(image)``.
    """
    # Open through a context manager so the underlying file handle is released
    # deterministically, even when decoding fails part-way through.
    with Image.open(file_path) as image:
        # np.array() copies the pixel data, so the result stays valid after
        # the image (and its file) has been closed.
        return np.array(image)

In [None]:
def convert_mask_to_single_channel(mask_3_channels, unknown_label=255):
    """Convert a colour-coded damage mask into a single-channel label mask.

    Every pixel whose RGB colour is listed in ``category_colors`` is replaced
    by the matching class id; every other pixel receives ``unknown_label``.

    Args:
        mask_3_channels (~numpy.ndarray): Mask of shape (height, width, C)
            with C >= 3. Only the first three channels are compared against
            the RGB colour table; further channels (e.g. alpha) are ignored.
        unknown_label (int): Value written for colours that are not in the
            colour table. Must fit in ``uint8``. The default, 255, is what the
            previous ``-1`` sentinel wrapped to on NumPy versions that allowed
            the assignment (NumPy 2.x raises ``OverflowError`` instead).

    Returns:
        ~numpy.ndarray: ``uint8`` array of shape (height, width, 1).

    Raises:
        ValueError: If the input is not 3-dimensional with at least three
            channels.
    """
    mask = np.asarray(mask_3_channels)
    if mask.ndim != 3 or mask.shape[2] < 3:
        raise ValueError(
            "Expected a mask of shape (height, width, 3), got {}".format(mask.shape)
        )

    # Drop any extra channel (e.g. alpha) so comparisons are against RGB only.
    rgb = mask[..., :3]

    # Define the colors representing each category (RGB values)
    category_colors = {
        (0, 0, 0): 0,        # Class 0 - Black (no building) or un-classified
        (255, 255, 255): 1,  # Class 1 - White (no-damage)
        (255, 255, 0): 2,    # Class 2 - Yellow (minor damage)
        (255, 165, 0): 3,    # Class 3 - Orange (major damage)
        (255, 0, 0): 4,      # Class 4 - Red (destroyed)
    }

    # Start with every pixel marked as unknown, then overwrite known colours.
    single_channel_mask = np.full(rgb.shape[:2] + (1,), unknown_label, dtype=np.uint8)

    # One vectorised comparison per class instead of a Python loop per pixel.
    for color, category in category_colors.items():
        matches = np.all(rgb == np.array(color), axis=-1)
        single_channel_mask[matches, 0] = category

    return single_channel_mask

In [None]:
# GET MASKS USING PNG FILES
from os import path, walk, makedirs
from tqdm import tqdm

# Root folder that holds one sub-directory per disaster, each of which is
# expected to contain an 'images' and a 'masks' folder.
path_example="/Users/gmeneses/DScourse/00_capstone/xView2_baseline_fork/xBD_last_subset_test_mask"

# NOTE(review): `tf` (TensorFlow) is used below but is not imported in any
# visible cell; it must be imported before this cell is run.

# Fail with a clear message instead of a bare StopIteration from next(walk(...)).
if not path.isdir(path_example):
    raise FileNotFoundError(
        "Error, could not find dataset directory {}.".format(path_example))

# Sorted so the order of the collected arrays is reproducible across runs
# (os.walk returns entries in arbitrary order).
disasters = sorted(next(walk(path_example))[1])

image_arrays = []
mask_arrays = []

for disaster in disasters:
    # Full paths to this disaster's image and mask directories.
    image_dir = path.join(path_example, disaster, 'images')
    mask_dir = path.join(path_example, disaster, 'masks')

    # Raise instead of print(..., file=stderr) + exit(): `stderr` was never
    # imported, and exit() shuts down the notebook kernel.
    if not path.isdir(image_dir):
        raise FileNotFoundError(
            "Error, could not find image files in {}.".format(image_dir))

    if not path.isdir(mask_dir):
        raise FileNotFoundError(
            "Error, could not find masks in {}.".format(mask_dir))

    # Iterate over the masks rather than the images: images with no features
    # may have no mask. Masks share their file name with the matching image.
    masks_list = sorted(j for j in next(walk(mask_dir))[2] if '_post' in j)

    for im in tqdm(masks_list, desc='Creating image and mask arrays for '+disaster, unit='im'):
        img_post_path = path.join(image_dir, im)
        mask_path = path.join(mask_dir, im)

        # Only the post-disaster image is fed to the model. The pre-disaster
        # image used to be read and decoded here as well, but the stacked
        # (pre + post, 6-channel) input does not work with this model, so that
        # unused work has been dropped.
        img_post = tf.io.read_file(img_post_path)
        array_post = tf.image.decode_png(img_post, channels=3, dtype=tf.uint8)

        # Load the colour-coded mask and collapse it to one label channel.
        array_mask_3d = read_png_image(mask_path)
        array_mask = convert_mask_to_single_channel(array_mask_3d)

        # Keep image and mask at the same index in their respective lists.
        image_arrays.append(array_post)
        mask_arrays.append(tf.convert_to_tensor(array_mask))


In [None]:
# # Define a shared normalizer so that every mask is plotted with a consistent colour scale
# NORM = mpl.colors.Normalize(vmin=0, vmax=58)

# # plot masks
# plt.figure(figsize=(25,13))
# for i in range(4,7):
#     plt.subplot(4,6,i)
#     img = masks[i]
#     plt.imshow(img, cmap='jet', norm=NORM)
#     plt.colorbar()
#     plt.axis('off')
# plt.show()

In [None]:
# #functions to resize the images and masks 
# def resize_image(image):
#     # scale the image
#     image = tf.cast(image, tf.float32)
#     image = image/255.0
#     # resize image
#     image = tf.image.resize(image, (128,128))
#     return image

# def resize_mask(mask):
#     # resize the mask
#     mask = tf.image.resize(mask, (128,128))
#     mask = tf.cast(mask, tf.uint8)
#     return mask    



#X = [resize_image(i) for i in images]
#y = [resize_mask(m) for m in masks]