In [17]:
import torch
from tqdm.auto import tqdm

# Category names of interest.
# NOTE(review): `categories` is not read by any visible cell; the category to
# process is selected by hand through `category` below.
categories = ['bicycle', 'bus', 'motorcycle', 'rider', 'traffic light']
category = 'bicycle'
file_path = f'datasets/halo_extra_data/hipie/{category}_annotations.pt'
# NOTE(review): torch.load unpickles the file, so only load annotation files
# that come from a trusted source.
data = torch.load(file_path)

# Printed before the pop below, so this listing still includes the 'labels' key.
print("keys: ", data.keys())
# Take 'labels' out of `data` so that iterating `data` afterwards yields only
# the per-image entries; `labels` is later indexed by category_id.
labels = data.pop('labels')
print("labels: ", labels)

keys:  dict_keys(['labels', '00000', '00001', '00002', '00003', '00004', '00005', '00006', '00007', '00008', '00009', '00010', '00011', '00012', '00013', '00014', '00015', '00016', '00017', '00018', '00019', '00020', '00021', '00022', '00023', '00024', '00025', '00026', '00027', '00028', '00029', '00030', '00031', '00032', '00033', '00034', '00035', '00036', '00037', '00038', '00039', '00040', '00041', '00042', '00043', '00044', '00045', '00046', '00047', '00048', '00049', '00050', '00051', '00052', '00053', '00054', '00055', '00056', '00057', '00058', '00059', '00060', '00061', '00062', '00063', '00064', '00065', '00066', '00067', '00068', '00069', '00070', '00071', '00072', '00073', '00074', '00075', '00076', '00077', '00078', '00079', '00080', '00081', '00082', '00083', '00084', '00085', '00086', '00087', '00088', '00089', '00090', '00091', '00092', '00093', '00094', '00095', '00096', '00097', '00098', '00099', '00100', '00101', '00102', '00103', '00104', '00105', '00106', '00107', 

In [18]:
# Train-ID -> class-name table. A name's position in the list is its train ID;
# 255 is appended separately as the "unknown" label.
# The spelling "motocycle" is kept as-is: the reverse-mapping cell looks this
# exact key up when it registers the "motorcycle" alias.
trainid2name = dict(enumerate([
    "road",        # 0
    "sidewalk",    # 1
    "building",    # 2
    "wall",        # 3
    "fence",       # 4
    "pole",        # 5
    "light",       # 6
    "sign",        # 7
    "vegetation",  # 8
    "terrain",     # 9
    "sky",         # 10
    "person",      # 11
    "rider",       # 12
    "car",         # 13
    "truck",       # 14
    "bus",         # 15
    "train",       # 16
    "motocycle",   # 17
    "bicycle",     # 18
]))
trainid2name[255] = "unknown"

In [19]:
# Reverse lookup: class name -> train ID.
trainname2id = {name: train_id for train_id, name in trainid2name.items()}

# Extra spellings that should resolve to an existing class
# (alias -> name already present in the table).
for alias, canonical in (
    ("tree", "vegetation"),
    ("traffic light", "light"),
    ("motorcycle", "motocycle"),
):
    trainname2id[alias] = trainname2id[canonical]

print("trainname2id: ", trainname2id)

trainname2id:  {'road': 0, 'sidewalk': 1, 'building': 2, 'wall': 3, 'fence': 4, 'pole': 5, 'light': 6, 'sign': 7, 'vegetation': 8, 'terrain': 9, 'sky': 10, 'person': 11, 'rider': 12, 'car': 13, 'truck': 14, 'bus': 15, 'train': 16, 'motocycle': 17, 'bicycle': 18, 'unknown': 255, 'tree': 8, 'traffic light': 6, 'motorcycle': 17}


In [20]:
# Peek at the first (image_id, entry) pair to see how one annotation is laid out.
first_item = next(iter(data.items()))
image_id, first_entry = first_item
print("image_id: ", image_id)

keys = first_entry.keys()
print("keys: ", keys)

# Each entry carries the segment mask and the per-segment metadata list.
mask, info = first_entry['mask'], first_entry['info']
print("mask: ", mask.shape)
print("info: ", info)

image_id:  00000
keys:  dict_keys(['mask', 'info'])
mask:  torch.Size([1500, 1500])
info:  [{'id': 1, 'isthing': False, 'category_id': 18, 'area': 160021.0}, {'id': 2, 'isthing': False, 'category_id': 14, 'area': 354388.0}, {'id': 3, 'isthing': True, 'category_id': 7, 'area': 551563.0}, {'id': 4, 'isthing': False, 'category_id': 12, 'area': 1062544.0}]


In [21]:
def parse_info(info) -> dict:
    """Build the segment-id -> category-id lookup for one image.

    Args:
        info: iterable of dicts, each providing an 'id' and a 'category_id' key.

    Returns:
        A dict mapping every listed segment id to its category id. Segment
        id 0 is pre-seeded with 255 and keeps that value unless ``info``
        itself contains an entry with id 0.
    """
    segment_to_category = {0: 255}
    for segment in info:
        category = segment['category_id']
        segment_to_category[segment['id']] = category
    return segment_to_category

def process_mask(mask, id2category, label_names=None, name2id=None) -> torch.Tensor:
    """Convert a mask of per-image segment ids into a mask of train ids.

    Args:
        mask: integer tensor whose values are segment ids.
        id2category: mapping segment id -> category id (see ``parse_info``).
        label_names: container indexed by category id that yields the
            category name. Defaults to the notebook-level ``labels``.
        name2id: mapping category name -> train id. Defaults to the
            notebook-level ``trainname2id``.

    Returns:
        A new tensor with the shape and dtype of ``mask``. Segments whose
        category id is 0 or 255 become 255; every other segment becomes the
        train id of its category name. ``mask`` itself is not modified.

    Raises:
        KeyError: if a segment id present in ``mask`` is missing from
            ``id2category``, or a category name is missing from ``name2id``.
        IndexError: if ``label_names`` is a sequence and a category id is
            out of range.
    """
    if label_names is None:
        label_names = labels
    if name2id is None:
        name2id = trainname2id

    new_mask = mask.clone()
    for segment_id in torch.unique(mask).tolist():
        category_id = id2category[segment_id]
        if category_id == 255 or category_id == 0:
            class_id = 255
        else:
            class_id = name2id[label_names[category_id]]
        # Select pixels on the untouched input, not on new_mask: a train id
        # written for an earlier segment may equal a later segment id, and
        # matching on new_mask would remap those pixels a second time.
        new_mask[mask == segment_id] = class_id
    return new_mask

In [56]:
import os

# Imported here so this cell does not depend on a later cell having been run.
from PIL import Image

from core.utils.misc import get_color_pallete

# In-memory stack of every converted mask, shape (num_images, H, W), uint8.
# H and W are taken from the first mask instead of being hard-coded.
first_entry = next(iter(data.values()), None)
mask_hw = tuple(first_entry['mask'].shape) if first_entry is not None else (0, 0)
new_masks = torch.zeros((len(data), *mask_hw), dtype=torch.uint8)

outdir = f'datasets/halo_extra_data/gtFine/train/{category}'
# exist_ok avoids the check-then-create race and makes the cell re-runnable.
os.makedirs(outdir, exist_ok=True)

for i, (k, v) in enumerate(tqdm(data.items())):
    img_id = k
    mask = v['mask']
    info = v['info']
    id2category = parse_info(info)
    new_mask = process_mask(mask, id2category)
    new_masks[i] = new_mask

    # save masks
    filepath_id = os.path.join(outdir, f'{category}_{img_id}_000019_gtFine_labelIds.png')
    filepath_color = os.path.join(outdir, f'{category}_{img_id}_000019_gtFine_color.png')

    # Cast to uint8 before handing the array to PIL: a 2-D uint8 array becomes
    # an 8-bit greyscale ('L') image directly, whatever integer dtype the
    # source mask had. process_mask is expected to emit only train ids and
    # 255, which fit in uint8.
    png_id = Image.fromarray(new_mask.to(torch.uint8).numpy())
    png_id.save(filepath_id)

    # The colour image is produced by the project helper from the same mask.
    png_color = get_color_pallete(new_mask.numpy())
    png_color.save(filepath_color)


  0%|          | 0/200 [00:00<?, ?it/s]

In [29]:
def get_color_palette(npimg):
    """Turn an array of class ids into a palettised ('P' mode) PIL image.

    Args:
        npimg: array of class ids; it is cast to uint8 before conversion.

    Returns:
        A PIL image in mode 'P' whose palette entry ``i`` is the colour
        listed for class ``i`` below (19 entries, classes 0-18).

    NOTE(review): pixel value 255 has no entry in this palette; how it is
    rendered is left to Pillow - confirm this is the intended colour.
    """
    class_colors = [
        (128, 64, 128),   # 0
        (244, 35, 232),   # 1
        (70, 70, 70),     # 2
        (102, 102, 156),  # 3
        (190, 153, 153),  # 4
        (153, 153, 153),  # 5
        (250, 170, 30),   # 6
        (220, 220, 0),    # 7
        (107, 142, 35),   # 8
        (152, 251, 152),  # 9
        (0, 130, 180),    # 10
        (220, 20, 60),    # 11
        (255, 0, 0),      # 12
        (0, 0, 142),      # 13
        (0, 0, 70),       # 14
        (0, 60, 100),     # 15
        (0, 80, 100),     # 16
        (0, 0, 230),      # 17
        (119, 11, 32),    # 18
    ]
    # putpalette expects one flat [r, g, b, r, g, b, ...] sequence.
    flat_palette = [channel for rgb in class_colors for channel in rgb]

    indexed = Image.fromarray(npimg.astype('uint8'))
    indexed = indexed.convert('P')
    indexed.putpalette(flat_palette)
    return indexed

def extract_color(c):
    """Return the colour of class id ``c`` as ``[r, g, b]`` floats in [0, 1].

    Args:
        c: a class id in 0-18, or 255 (mapped to white).

    Returns:
        A list of three floats, each channel divided by 255.

    Raises:
        KeyError: if ``c`` is not one of the ids listed above.
    """
    rgb_rows = [
        (128, 64, 128),   # 0
        (244, 35, 232),   # 1
        (70, 70, 70),     # 2
        (102, 102, 156),  # 3
        (190, 153, 153),  # 4
        (153, 153, 153),  # 5
        (250, 170, 30),   # 6
        (220, 220, 0),    # 7
        (107, 142, 35),   # 8
        (152, 251, 152),  # 9
        (0, 130, 180),    # 10
        (220, 20, 60),    # 11
        (255, 0, 0),      # 12
        (0, 0, 142),      # 13
        (0, 0, 70),       # 14
        (0, 60, 100),     # 15
        (0, 80, 100),     # 16
        (0, 0, 230),      # 17
        (119, 11, 32),    # 18
    ]
    palette = dict(enumerate(rgb_rows))
    palette[255] = (255, 255, 255)

    red, green, blue = palette[c]
    return [red / 255, green / 255, blue / 255]

In [55]:
import numpy as np
from PIL import Image

# mask_color = Image.open("datasets/cityscapes/gtFine/train/aachen/aachen_000000_000019_gtFine_color.png")
# # mask_color = Image.open("datasets/halo_extra_data/gtFine/train/bicycle/bicycle_00000_000019_gtFine_color.png")
# mask_color = torch.from_numpy(np.array(mask_color))
# mask_color = mask_color.permute(2, 0, 1)
# print("mask_color: ", mask_color.shape)

# Two label-ID PNGs, loaded the same way so their shapes and value sets can be
# compared side by side.
paths = [
    "datasets/cityscapes/gtFine/train/aachen/aachen_000000_000019_gtFine_labelIds.png",
    "datasets/halo_extra_data/gtFine/train/bicycle/bicycle_00000_000019_gtFine_labelIds.png",
]

for path in paths:
    # Read as 8-bit greyscale, then wrap the pixel array in a tensor.
    label_image = Image.open(path).convert('L')
    mask_labelIds = torch.from_numpy(np.array(label_image))

    print("Path: ", path)
    print("mask_labelIds: ", mask_labelIds.shape)
    print("mask_labelIds.unique: ", mask_labelIds.unique())
    print()


Path:  datasets/cityscapes/gtFine/train/aachen/aachen_000000_000019_gtFine_labelIds.png
mask_labelIds:  torch.Size([1024, 2048])
mask_labelIds.unique:  tensor([ 0,  1,  3,  4,  7,  8, 11, 17, 20, 21, 22, 23, 24, 25, 26, 33],
       dtype=torch.uint8)

Path:  datasets/halo_extra_data/gtFine/train/bicycle/bicycle_00000_000019_gtFine_labelIds.png
mask_labelIds:  torch.Size([1500, 1500])
mask_labelIds.unique:  tensor([  0,   3,   8,  18, 255], dtype=torch.uint8)



In [50]:
# Display the tensor currently bound to `mask_labelIds`.
# NOTE(review): the saved output of this cell reports dtype=torch.int32, while
# the PNG-loading cell's saved output reports torch.uint8 - this output looks
# like it comes from an earlier execution; re-run to confirm.
mask_labelIds

tensor([[3, 3, 3,  ..., 8, 8, 8],
        [3, 3, 3,  ..., 8, 8, 8],
        [3, 3, 3,  ..., 8, 8, 8],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.int32)

In [46]:
# NOTE(review): in the visible cells `mask_color` is only assigned inside
# commented-out lines of the PNG-loading cell, so on a fresh kernel this cell
# raises NameError unless it is defined somewhere not shown here.
# Index 3 is presumably the alpha channel after the permute(2, 0, 1) in that
# commented-out code - confirm.
mask_color[3,:,:].unique()

tensor([255], dtype=torch.uint8)