In [1]:
import os
import numpy as np
import cv2
from tqdm import tqdm
import skimage.io as io
import matplotlib.pyplot as plt

from datasets.fewshotiseg.omniiseg_fst import OMNIFewShotISEG, OMNIISEG

In [2]:
# Hyperparameters
# k: later cells use 3 * k as the per-category instance target and embed k
# in the names of the exported JSON files.
k = 3

# Configuration handed to OMNIFewShotISEG as-is (the keys are part of its contract).
cfg = {
    'ds_base_': 'OMNIISEG',
    'ds_base__subset': 'train',
    'ds_novel': 'OMNIISEG',
    'ds_novel_subset': 'val',
    'sampling_origin_ds': 'OMNIISEG',
    'sampling_origin_ds_subset': 'train',
    'sampling_cats': 'base_',
    'first_parents__only': 0,
    'first_children_only': 0,
    'sampling_scenario': 'children',
    'repeats': 0,
}

# Resource roots are class attributes and must be set before the dataset is built.
# NOTE(review): the paths are relative, so the notebook's working directory matters.
OMNIISEG.root = '../datasets/omniiseg/resources'
OMNIFewShotISEG.root = '../datasets/fewshotiseg/resources/omniiseg_fst'
ds = OMNIFewShotISEG(config=cfg, read_data=True)

Total amount of images 8000
Total amount of bboxes 8000
Total amount of cat_ids 8000
Novel cats
['B', 'D', 'H', 'I', 'N', 'V']
[38;5;28;1;4mTotal qrys_parents_: 7675[0m
[38;5;28;1;4mTotal qrys_children: 13552[0m
[38;5;28;1;4mSampling by children[0m
[38;5;28;1;4mOrder by children  reduced: 13552 => 13552[0m
[38;5;28;1;4mRepeating the set of entries 1 times[0m


In [3]:
# Per-image category histogram:
# presence_arr[i, c] = number of annotations of category c on image i.
# The dtype is uint8, so a count above 255 would wrap around.
presence_arr = np.zeros((len(ds.cat_ids), ds.cats_total_amount), dtype=np.uint8)
for i, img_cat_ids in enumerate(ds.cat_ids):
    # np.add.at accumulates repeated ids, unlike `row[ids] += 1` which counts each id once.
    np.add.at(presence_arr[i], img_cat_ids, 1)

In [4]:
# 1. Keep only the images that contain no instance of any category to delete.
# cats_to_del__in[i] is the number of such instances on image i.
cats_to_del__in = presence_arr[:, ds.cats_to_del_].sum(axis=-1)
imgs_no_cats_to_del_ = np.flatnonzero(cats_to_del__in == 0)
print('Amount of images with no cats to delete', imgs_no_cats_to_del_.size)

Amount of images with no cats to delete 4483


In [5]:
# 2. Order the kept images by their total instance count, largest first.
total_objects_amount = presence_arr[imgs_no_cats_to_del_].sum(axis=-1)
indices = np.flip(np.argsort(total_objects_amount))
imgs_no_cats_to_del__desc = imgs_no_cats_to_del_[indices]
print('Amount of instances of categories to save for each image')
# Same values as re-summing presence_arr over the reordered image indices.
print(total_objects_amount[indices])

Amount of instances of categories to save for each image
[3 3 3 ... 2 2 2]


In [6]:
# 3. Drop images that have no annotated objects, and images where a single
#    category has more than 3 * K instances.
desc_counts = presence_arr[imgs_no_cats_to_del__desc]
has_objects = desc_counts.sum(axis=-1) != 0
within_cap = desc_counts.max(axis=-1) <= 3 * k
indices = np.nonzero(has_objects & within_cap)
imgs_no_cats_to_del__desc_non_zero = imgs_no_cats_to_del__desc[indices]
# imgs_pool rows are positioned by "hidden" index; the matching dataset ("real")
# index of row j is imgs_no_cats_to_del__desc_non_zero[j].
imgs_pool = presence_arr[imgs_no_cats_to_del__desc_non_zero]
print('The same as the last but after filtration of by [0; 3K] borders')
print(imgs_pool.sum(axis=-1))

The same as the last but after filtration of by [0; 3K] borders
[3 3 3 ... 2 2 2]


In [7]:
# 4. Sanity check: the pool must not contain any instance of the categories
#    to delete, so both the minimum and the maximum printed below should be 0.
pool_del_counts = imgs_pool[:, ds.cats_to_del_]
print('Images pool array shape with categories to delete', pool_del_counts.shape)
print(pool_del_counts.min(), pool_del_counts.max())
del pool_del_counts

Images pool array shape with categories to delete (4483, 6)
0 0


In [8]:
# 5. Check that the pool holds at least 3 * K instances of every category to save.
selection = imgs_pool[:, ds.cats_to_save]
# Total instances per category to save, in the order of ds.cats_to_save.
cats_to_save_only_presence_vec = selection.sum(axis=0)
print('Amount of instances in a dataset subselection')
print(cats_to_save_only_presence_vec)
# Positions within ds.cats_to_save ("middle" indices) that fall short of 3 * K.
indices_no_3k = np.flatnonzero(cats_to_save_only_presence_vec < 3 * k)
print(f'Critical categories which do not have {3 * k} instances')
print('Middle indices', indices_no_3k)
# Translate the positions back to the actual category ids.
print('Real indices', ds.cats_to_save[indices_no_3k])

Amount of instances in a dataset subselection
[507 516 471 484 490 503 526 516 527 513 512 475 490 497 469 473 489 478
 481 488]
Critical categories which do not have 9 instances
Middle indices []
Real indices []


In [9]:
# 6. For every category to save, count the pool images that carry
#    0, 1, 2, ... instances of that category.
for cat_id in ds.cats_to_save:
    column = imgs_pool[:, cat_id]
    amounts, img_counts = np.unique(column, return_counts=True)
    appearings = {amount: img_count for amount, img_count in zip(amounts, img_counts)}
    print(cat_id, appearings)
    del amounts, img_counts, appearings
del cat_id
# Observation: images that contain a category mostly have just 1 or 2 instances of it.

0 {0: 3989, 1: 481, 2: 13}
2 {0: 3987, 1: 476, 2: 20}
4 {0: 4032, 1: 431, 2: 20}
5 {0: 4014, 1: 454, 2: 15}
6 {0: 4001, 1: 474, 2: 8}
9 {0: 3997, 1: 469, 2: 17}
10 {0: 3975, 1: 490, 2: 18}
11 {0: 3987, 1: 477, 2: 18, 3: 1}
12 {0: 3971, 1: 497, 2: 15}
14 {0: 3990, 1: 473, 2: 20}
15 {0: 3986, 1: 482, 2: 15}
16 {0: 4021, 1: 449, 2: 13}
17 {0: 4004, 1: 468, 2: 11}
18 {0: 4003, 1: 463, 2: 17}
19 {0: 4034, 1: 430, 2: 18, 3: 1}
20 {0: 4016, 1: 461, 2: 6}
22 {0: 4013, 1: 451, 2: 19}
23 {0: 4019, 1: 450, 2: 14}
24 {0: 4019, 1: 447, 2: 17}
25 {0: 4011, 1: 456, 2: 16}


In [10]:
# 7. Estimate how hard it will be to assemble a balanced selection.
# 7.1. Number of distinct categories present on each pool image, and how many
#      images share each such number.
unique_cats_ids_on_img = np.sum(imgs_pool > 0, axis=-1)
unique_cats_ids_amount, unique_cats_ids_amount_counts = np.unique(
    unique_cats_ids_on_img, return_counts=True
)
print('Instances, images with this amount of instances')
# One row per distinct value: [number of categories, number of images].
print(np.column_stack((unique_cats_ids_amount, unique_cats_ids_amount_counts)))

Instances, images with this amount of instances
[[   1  170]
 [   2 3518]
 [   3  795]]


In [11]:
# 7.2. Restrict the pool to images showing exactly one category, then report,
#      per category, how the instance counts are distributed over those images.
cats_per_img = (imgs_pool > 0).sum(axis=-1)
indices_only_1cat = np.flatnonzero(cats_per_img == 1)
print('Amount of images with only 1 category', len(indices_only_1cat))

selection = imgs_pool[indices_only_1cat]
print('Selection shape', selection.shape)

for cat_id in ds.cats_to_save:
    column = selection[:, cat_id]
    histogram = dict(zip(*np.unique(column, return_counts=True)))
    print(f'Cat {cat_id:2}', histogram)
    del histogram
del selection

Amount of images with only 1 category 170
Selection shape (170, 26)
Cat  0 {0: 164, 2: 6}
Cat  2 {0: 158, 2: 12}
Cat  4 {0: 158, 2: 12}
Cat  5 {0: 162, 2: 8}
Cat  6 {0: 163, 2: 7}
Cat  9 {0: 163, 2: 7}
Cat 10 {0: 163, 2: 7}
Cat 11 {0: 160, 2: 9, 3: 1}
Cat 12 {0: 159, 2: 11}
Cat 14 {0: 161, 2: 9}
Cat 15 {0: 164, 2: 6}
Cat 16 {0: 163, 2: 7}
Cat 17 {0: 163, 2: 7}
Cat 18 {0: 160, 2: 10}
Cat 19 {0: 157, 2: 12, 3: 1}
Cat 20 {0: 166, 2: 4}
Cat 22 {0: 160, 2: 10}
Cat 23 {0: 163, 2: 7}
Cat 24 {0: 163, 2: 7}
Cat 25 {0: 160, 2: 10}


In [16]:
# 8. Greedy selection of images so that every category to save ends up with
#    exactly 3 * K instances in the selected set.
#
# Categories are processed from the least to the most represented in the pool.
# For each category, candidate images are grouped by how many instances of that
# category they carry (fewest first) and, inside a group, ordered by how many
# distinct categories they show (fewest first). An image is accepted only if it
# does not push ANY category above the target.
#
# "Hidden" indices address rows of imgs_pool; "real" indices address the dataset.
imgs_set_hidden_indexes = set()
total_cats_insts_in_set = np.zeros(ds.cats_total_amount)
# Categories to delete are marked with -1 so they stand out in the printed totals.
total_cats_insts_in_set[ds.cats_to_del_] = -1
order = np.argsort(cats_to_save_only_presence_vec)
cats_to_save_ascending_cat_ids = ds.cats_to_save[order]
total_required = 3 * k

for cat_id in cats_to_save_ascending_cat_ids:
    # Select images which have this category
    selection = imgs_pool[:, cat_id]
    hidden_indices = np.where(selection != 0)[0]
    real_indices = imgs_no_cats_to_del__desc_non_zero[hidden_indices]
    represented = imgs_pool[hidden_indices]
    print(f'Cat {cat_id} total images with this cat', len(represented))

    # One row per candidate image:
    # (instances of this category, distinct categories on the image, hidden index)
    selected_num_this_cat_examples = represented[:, cat_id]
    selected_num_each_cat_examples = (represented > 0).sum(axis=-1)
    selected_triple = np.column_stack((selected_num_this_cat_examples,
                                       selected_num_each_cat_examples,
                                       hidden_indices))
    groups = {}
    for amount in np.unique(selected_num_this_cat_examples):
        amount_group = selected_triple[selected_triple[:, 0] == amount]
        # Inside a group, prefer images showing fewer distinct categories.
        group_order = np.argsort(amount_group[:, 1])
        groups[amount] = amount_group[group_order]

    # Check that all images with these indices have this category (even show images)
    # With check_required == False only the first image is verified.
    check_required = False
    for real_index in real_indices:
        cat_ids = ds.cat_ids[real_index]
        assert cat_id in cat_ids
        if not check_required:
            print('Checked ONE image and everything is OK')
            break
    if check_required:
        print('Checked ALL images and everything is OK')

    # Perform a selection
    triples_selected = []  # images accepted while processing THIS category
    for amount in sorted(groups):
        if total_cats_insts_in_set[cat_id] == total_required:
            break
        # An image of this group would overshoot the target: evict images that
        # were accepted for this category until one more image of `amount` fits.
        while total_cats_insts_in_set[cat_id] + amount > total_required:
            if len(triples_selected) == 0:
                raise NotImplementedError(
                    f'Cannot reach exactly {total_required} instances for category {cat_id}: '
                    f'an image with {amount} instances would overshoot and there is nothing to evict'
                )
            # Eviction strategy: always drop the earliest accepted image.
            # (A random choice among triples_selected is the alternative strategy.)
            triple = triples_selected.pop(0)
            hidden_img_index = triple[2]
            imgs_set_hidden_indexes.remove(hidden_img_index)
            total_cats_insts_in_set -= imgs_pool[hidden_img_index]
            print('Deleted triple', triple)

        for triple in groups[amount]:
            hidden_img_index = triple[2]
            # An image accepted for an earlier category may show this category too.
            # Skip it, otherwise its instances would be counted twice in the totals
            # while the set still holds it only once.
            if hidden_img_index in imgs_set_hidden_indexes:
                continue
            cur_img_cats = imgs_pool[hidden_img_index]
            assert cur_img_cats[cat_id] == amount
            summary = total_cats_insts_in_set + cur_img_cats
            if summary.max() > total_required:
                # Accepting the image would overshoot some category.
                continue
            total_cats_insts_in_set = summary
            imgs_set_hidden_indexes.add(hidden_img_index)
            triples_selected.append(triple)
            if total_cats_insts_in_set[cat_id] == total_required:
                break
    print('Finished for this cat')

print('*** Finished for all cats *** ')


Cat 19 total images with this cat 449
Checked ONE image and everything is OK
Finished for this cat
Cat 4 total images with this cat 451
Checked ONE image and everything is OK
Finished for this cat
Cat 20 total images with this cat 467
Checked ONE image and everything is OK
Finished for this cat
Cat 16 total images with this cat 462
Checked ONE image and everything is OK
Finished for this cat
Cat 23 total images with this cat 464
Checked ONE image and everything is OK
Finished for this cat
Cat 24 total images with this cat 464
Checked ONE image and everything is OK
Finished for this cat
Cat 5 total images with this cat 469
Checked ONE image and everything is OK
Finished for this cat
Cat 25 total images with this cat 472
Checked ONE image and everything is OK
Finished for this cat
Cat 22 total images with this cat 470
Checked ONE image and everything is OK
Finished for this cat
Cat 6 total images with this cat 482
Checked ONE image and everything is OK
Finished for this cat
Cat 17 total 

In [13]:
# Summarise the selection and cross-check the running totals.
# Translate the selected pool rows ("hidden" indices) into dataset ("real") indices.
imgs_list_hidden_indexes = [*imgs_set_hidden_indexes]
real_indices = imgs_no_cats_to_del__desc_non_zero[imgs_list_hidden_indexes]

print('Total images in the set', len(imgs_set_hidden_indexes))
print('Amount of class instances depicted\n', total_cats_insts_in_set)
# Both recomputed sums should match the running totals for the categories to save
# (categories to delete show 0 here and -1 in the running totals).
print(imgs_pool[imgs_list_hidden_indexes, :].sum(axis=0))
print(presence_arr[real_indices, :].sum(axis=0))

Total images in the set 90
Amount of class instances depicted
 [ 9. -1.  9. -1.  9.  9.  9. -1. -1.  9.  9.  9.  9. -1.  9.  9.  9.  9.
  9.  9.  9. -1.  9.  9.  9.  9.]
[9 0 9 0 9 9 9 0 0 9 9 9 9 0 9 9 9 9 9 9 9 0 9 9 9 9]
[9 0 9 0 9 9 9 0 0 9 9 9 9 0 9 9 9 9 9 9 9 0 9 9 9 9]


In [14]:
# Look up the short path of every selected image, in the order of real_indices.
imgs_sps_selected = []
for img_index in real_indices:
    imgs_sps_selected.append(ds.imgs_sps[img_index])
print(imgs_sps_selected[:5])

['006541.jpg', '006353.jpg', '006341.jpg', '006340.jpg', '006447.jpg']


In [15]:
from cp_utils.cp_dir_file_ops import write_json_safe

# Persist the selection: dataset indices and the matching image short paths.
# The file name prefix encodes the setup, the sampling origin and K.
prefix = f'{ds.setup}_{ds.sampling_origin_ds}_{ds.sampling_origin_ds_subset}_{ds.sampling_cats}_K{k}_'

# NOTE(review): the captured output shows "JSON WRITE-SAFE FAIL" for both files,
# yet the "Saved ..." messages below are printed unconditionally. Confirm what
# write_json_safe reports on failure and whether the files were actually written.
file_fp = os.path.join(ds.root, prefix + 'FINETUNE_REAL_INDICES.json')
# .tolist() yields built-in ints; list(ndarray) would yield np.int64 scalars,
# which the standard json encoder cannot serialize.
write_json_safe(file_fp, np.asarray(real_indices).tolist())
print('Saved real indices to file:', file_fp)
file_fp = os.path.join(ds.root, prefix + 'FINETUNE_IMGS_SPS.json')
write_json_safe(file_fp, imgs_sps_selected)
print('Saved images short paths to file: ', file_fp)

JSON WRITE-SAFE FAIL ../datasets/fewshotiseg/resources/omniiseg_fst/OMNIISEG2OMNIISEG_OMNIISEG_train_base__K3_FINETUNE_REAL_INDICES.json
Saved real indices to file: ../datasets/fewshotiseg/resources/omniiseg_fst/OMNIISEG2OMNIISEG_OMNIISEG_train_base__K3_FINETUNE_REAL_INDICES.json
JSON WRITE-SAFE FAIL ../datasets/fewshotiseg/resources/omniiseg_fst/OMNIISEG2OMNIISEG_OMNIISEG_train_base__K3_FINETUNE_IMGS_SPS.json
Saved images short paths to file:  ../datasets/fewshotiseg/resources/omniiseg_fst/OMNIISEG2OMNIISEG_OMNIISEG_train_base__K3_FINETUNE_IMGS_SPS.json
