In [52]:
# Install required libs
# !python -m pip install --user numpy scipy matplotlib ipython jupyter segmentation-models-pytorch albumentations opencv-python tqdm natsort Pillow

## Loading data

In [53]:
import os
# Expose only GPU index 0 to CUDA. This is set right after importing os, ahead of the
# imports that follow (presumably so libraries loaded later pick it up — confirm).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import cv2
import matplotlib.pyplot as plt

In [55]:
# helper function for data visualization
def visualize(**images):
    """Show the given images side by side in a single row.

    Every keyword argument is drawn in its own panel without axis ticks. The
    keyword name is used as the panel title, with underscores turned into
    spaces and the result title-cased (``ground_truth_mask`` becomes
    "Ground Truth Mask").
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, label in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(images[label])
    plt.show()

In [None]:
from tqdm import tqdm
from natsort import natsorted
from functools import reduce
from PIL import Image
#  0: Background 1: Cap/hat 2: Helmet 3: Face 4: Hair 5: Left-arm 6: Right-arm 7: Left-hand 8: Right-hand 9: Protector 
# 10: Bikini/bra 11: Jacket/windbreaker/hoodie 12: Tee-shirt 13: Polo-shirt 14: Sweater 15: Singlet 16: Torso-skin 
# 17: Pants 18: Shorts/swim-shorts 19: Skirt 20: Stockings 21: Socks 22: Left-boot 23: Right-boot 24: Left-shoe 
# 25: Right-shoe 26: Left-highheel 27: Right-highheel 28: Left-sandal 29: Right-sandal 30: Left-leg 31: Right-leg 
# 32: Left-foot 33: Right-foot 34: Coat 35: Dress 36: Robe 37: Jumpsuit 38: Other-full-body-clothes 39: Headwear 
# 40: Backpack 41: Ball 42: Bats 43: Belt 44: Bottle 45: Carrybag 46: Cases 47: Sunglasses 48: Eyewear 49: Glove 
# 50: Scarf 51: Umbrella 52: Wallet/purse 53: Watch 54: Wristband 55: Tie 56: Other-accessary
# 57: Other-upper-body-clothes 58: Other-lower-body-clothes

# Class id from the label table above. convert_masks maps red-channel values equal to
# this id to 0 and every other value to 1 when it builds the binary mask.
CLASS_LABEL = 0 # Label to extract from dataset

def get_fp_infos(mask_fp):
    """Split a mask file name into its underscore-separated fields.

    Only the part of the base name before its first "." is considered, so
    ``"some/dir/12_03_01.png"`` yields ``("12", "03", "01")``.

    Args:
        mask_fp: Path (or bare file name) of a mask file.

    Returns:
        Tuple of the string fields found in the file name.
    """
    stem, _, _ = os.path.basename(mask_fp).partition(".")
    return tuple(stem.split("_"))

def create_dict(masks_dir):
    """Group the mask file paths found in ``masks_dir`` by their first name field.

    The directory entries are visited in natural sort order, joined with
    ``masks_dir`` and split into fields by ``get_fp_infos``.

    Args:
        masks_dir: Directory whose entries are the mask files.

    Returns:
        Dict mapping the first field of each file name to the list of paths
        sharing that field, in natural sort order.

    Raises:
        ValueError: If a file name does not split into exactly three fields.
    """
    grouped = {}
    for file_name in natsorted(os.listdir(masks_dir)):
        path = os.path.join(masks_dir, file_name)
        # Exactly three fields are required; only the first is used as the key.
        image_key, _, _ = get_fp_infos(path)
        grouped.setdefault(image_key, []).append(path)
    return grouped

def convert_masks(masks_dir, save_dir):
    """Merge the per-person masks of every image into one binary mask PNG.

    Mask files in ``masks_dir`` are grouped per image by ``create_dict`` (every
    image has one mask file per person). For each image whose output file
    ``<save_dir>/<image id>.png`` does not exist yet, the masks are combined
    with a bitwise OR, the first (red) channel is replaced by a binary map
    (0 where the combined value equals ``CLASS_LABEL``, 1 elsewhere) and the
    result is written to ``save_dir``.

    Images for which any mask is not a 3-dimensional array are skipped and
    their intended output path is printed.

    Args:
        masks_dir: Directory containing the per-person mask files.
        save_dir: Directory receiving the converted masks; created if missing.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    fps_dict = create_dict(masks_dir)
    for key, mask_fps in tqdm(fps_dict.items(), desc="Converting mask files"):
        # os.path.join keeps the path correct whether or not save_dir ends with a separator.
        out_fp = os.path.join(save_dir, key + ".png")
        if os.path.exists(out_fp):
            continue  # Already converted on an earlier run.

        # Load every person mask of this image, closing each file right after reading.
        masks = []
        for fp in mask_fps:
            with Image.open(fp) as mask_image:
                masks.append(np.array(mask_image))

        if not all(mask.ndim == 3 for mask in masks):
            print(out_fp)
            continue

        # Union of all person masks. The class labels are stored in the red channel.
        # NOTE(review): OR-ing label ids can produce a different id where persons
        # overlap (e.g. 1 | 2 == 3); the binarisation below is only exact for
        # CLASS_LABEL == 0 — confirm before switching to another label.
        mask = reduce(np.bitwise_or, masks)
        mask[:, :, 0] = np.where(mask[:, :, 0] == CLASS_LABEL, 0, 1)

        Image.fromarray(mask).save(out_fp)

# convert_masks("data/LV-MHP-v2/train/parsing_annos/", "data/LV-MHP-v2-augmented/train/parsing_annos/")
# convert_masks("data/LV-MHP-v2/val/parsing_annos/", "data/LV-MHP-v2-augmented/val/parsing_annos/")
