# Data balancing
After preprocessing, some of the datasets were imbalanced. This notebook checks the mask balance of every dataset and rebalances it where necessary by deleting files from disk. Run it only once: every rerun deletes additional samples and will corrupt the datasets!

In [34]:
from utils.data import SegmentationDataset, ALL_DATASETS
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.io import read_image
import os
import random

In [22]:
def get_dataset_stats(dataset_name, subset, labels, n_labels):
    """Count how many samples of one dataset split contain each kind of mask.

    Args:
        dataset_name: Dataset key passed to ``SegmentationDataset``.
        subset: Split name, e.g. ``"train"``, ``"val"`` or ``"test"``.
        labels: Label names of the dataset. Not used here; kept so existing
            call sites keep working.
        n_labels: Number of mask channels; only 1 and 2 are supported.

    Returns:
        For ``n_labels == 1``: ``(total, mask, no_mask)``.
        For ``n_labels == 2``: ``(total, only_mask1, only_mask2, both_masks,
        no_masks)``, where mask1 / mask2 are channel 0 / channel 1 of the
        mask tensor.

    Raises:
        ValueError: If ``n_labels`` is not 1 or 2.
    """
    # Fail fast, before the (possibly slow) dataset construction.
    if n_labels not in (1, 2):
        raise ValueError("Currently only supports 1 or 2 labels")

    dataset = SegmentationDataset(dataset_name, subset)
    # Default batch_size=1, so every y holds the mask(s) of exactly one sample.
    dataloader = DataLoader(dataset)
    total = len(dataloader)

    if n_labels == 1:
        mask = 0
        no_mask = 0
        for _, y in tqdm(dataloader, desc="  " + subset, total=total):
            if y.numpy().sum() == 0:
                no_mask += 1
            else:
                mask += 1
        return total, mask, no_mask

    only_mask1 = 0
    only_mask2 = 0
    both_masks = 0
    no_masks = 0
    for _, y in tqdm(dataloader, desc="  " + subset, total=total):
        y = y.numpy()
        # Reduce each channel once instead of re-summing it in every branch.
        has_mask1 = y[:, 0, :, :].sum() > 0
        has_mask2 = y[:, 1, :, :].sum() > 0
        if has_mask1 and has_mask2:
            both_masks += 1
        elif has_mask1:
            only_mask1 += 1
        elif has_mask2:
            only_mask2 += 1
        else:
            no_masks += 1

    return total, only_mask1, only_mask2, both_masks, no_masks
    

def _percent(count, total):
    """Return ``count`` as a percentage of ``total`` rounded to 2 decimals (0.0 for an empty split)."""
    return round(count / total * 100, 2) if total else 0.0


# Print the mask statistics of every split of every dataset.
for dataset_name, dataset_info in ALL_DATASETS.items():
    n_labels = dataset_info["n_labels"]
    labels = dataset_info["labels"]

    print(f"{dataset_name}:")
    for subset in ["train", "val", "test"]:
        if n_labels == 1:
            total, mask, no_mask = get_dataset_stats(dataset_name, subset, labels, n_labels)
            print(f"    mask:    {mask} ({_percent(mask, total)}%)")
            print(f"    no mask: {no_mask} ({_percent(no_mask, total)}%)")
        elif n_labels == 2:
            total, only_mask1, only_mask2, both_masks, no_masks = get_dataset_stats(dataset_name, subset, labels, n_labels)
            print(f"    no masks:       {no_masks} ({_percent(no_masks, total)}%)")
            print(f"    only {labels[0]} mask: {only_mask1} ({_percent(only_mask1, total)}%)")
            print(f"    only {labels[1]} mask:   {only_mask2} ({_percent(only_mask2, total)}%)")
            print(f"    both masks:     {both_masks} ({_percent(both_masks, total)}%)")
        else:
            raise ValueError("Currently only supports 1 or 2 labels")
    print("")
    




fetal_head:


  train: 100%|██████████| 350/350 [00:01<00:00, 210.76it/s]


    mask:    350 (100.0%)
    no mask: 0 (0.0%)


  val: 100%|██████████| 74/74 [00:00<00:00, 207.31it/s]


    mask:    74 (100.0%)
    no mask: 0 (0.0%)


  test: 100%|██████████| 76/76 [00:00<00:00, 201.69it/s]


    mask:    76 (100.0%)
    no mask: 0 (0.0%)

breast_cancer:


  train: 100%|██████████| 259/259 [00:01<00:00, 187.73it/s]


    mask:    259 (100.0%)
    no mask: 0 (0.0%)


  val: 100%|██████████| 55/55 [00:00<00:00, 186.25it/s]


    mask:    55 (100.0%)
    no mask: 0 (0.0%)


  test: 100%|██████████| 56/56 [00:00<00:00, 185.53it/s]


    mask:    56 (100.0%)
    no mask: 0 (0.0%)

mouse_embryo:


  train: 100%|██████████| 18885/18885 [02:36<00:00, 120.87it/s]


    no masks:       4465 (23.64%)
    only body mask: 7213 (38.19%)
    only bv mask:   4 (0.02%)
    both masks:     7203 (38.14%)


  val: 100%|██████████| 4503/4503 [00:37<00:00, 121.20it/s]


    no masks:       728 (16.17%)
    only body mask: 2105 (46.75%)
    only bv mask:   1 (0.02%)
    both masks:     1669 (37.06%)


  test: 100%|██████████| 5604/5604 [00:42<00:00, 130.70it/s]


    no masks:       1040 (18.56%)
    only body mask: 2835 (50.59%)
    only bv mask:   0 (0.0%)
    both masks:     1729 (30.85%)

covid:


  train: 100%|██████████| 3122/3122 [00:24<00:00, 126.16it/s]


    mask:    1580 (50.61%)
    no mask: 1542 (49.39%)


  val: 100%|██████████| 218/218 [00:01<00:00, 121.99it/s]


    mask:    149 (68.35%)
    no mask: 69 (31.65%)


  test: 100%|██████████| 180/180 [00:01<00:00, 120.09it/s]


    mask:    115 (63.89%)
    no mask: 65 (36.11%)

pancreas:


  train: 100%|██████████| 11923/11923 [01:35<00:00, 125.41it/s]


    mask:    4607 (38.64%)
    no mask: 7316 (61.36%)


  val: 100%|██████████| 2491/2491 [00:27<00:00, 89.80it/s] 


    mask:    989 (39.7%)
    no mask: 1502 (60.3%)


  test: 100%|██████████| 2850/2850 [00:23<00:00, 123.25it/s]


    mask:    1109 (38.91%)
    no mask: 1741 (61.09%)

brain_tumor:


  train: 100%|██████████| 2895/2895 [00:19<00:00, 145.86it/s]


    mask:    1036 (35.79%)
    no mask: 1859 (64.21%)


  val: 100%|██████████| 462/462 [00:02<00:00, 159.41it/s]


    mask:    159 (34.42%)
    no mask: 303 (65.58%)


  test: 100%|██████████| 572/572 [00:03<00:00, 145.79it/s]


    mask:    178 (31.12%)
    no mask: 394 (68.88%)

prostate:


  train: 100%|██████████| 2687/2687 [00:19<00:00, 135.97it/s]


    mask:    1317 (49.01%)
    no mask: 1370 (50.99%)


  val: 100%|██████████| 585/585 [00:03<00:00, 147.73it/s]


    mask:    271 (46.32%)
    no mask: 314 (53.68%)


  test: 100%|██████████| 615/615 [00:04<00:00, 134.87it/s]

    mask:    279 (45.37%)
    no mask: 336 (54.63%)






### Balance mouse_embryo
From mouse_embryo: remove the samples that have no masks at all or only a brain (`bv`) mask, i.e. no body mask.

In [21]:
# Delete (from disk) every mouse_embryo sample that has no masks at all or only
# a bv mask without a body mask: the image file and both mask files.
for subset in ["train", "val", "test"]:
    dataset = SegmentationDataset("mouse_embryo", subset)
    # DataLoader does not shuffle by default, so (with batch_size=1) the k-th batch
    # should correspond to dataset.inputs[k] / dataset.masks[k] — presumably
    # SegmentationDataset.__getitem__(k) reads exactly those paths; confirm.
    # Deleting files inside the loop relies on each sample being loaded before its
    # files are removed (no worker prefetching with the default num_workers=0).
    dataloader = DataLoader(dataset)
    removed = 0
    samples = zip(dataset.inputs, dataset.masks, dataloader)
    for path_img, (path_mask_body, path_mask_bv), (_, mask) in tqdm(samples, total=len(dataloader), desc=subset):
        mask = mask.numpy()
        body_sum = mask[:, 0, :, :].sum()
        bv_sum = mask[:, 1, :, :].sum()

        if mask.sum() == 0 or (body_sum == 0 and bv_sum > 0):
            removed += 1
            for path in (path_img, path_mask_body, path_mask_bv):
                os.remove(path)
    print(f"Removed {removed} images from {subset}")


subset: 100%|██████████| 18885/18885 [02:38<00:00, 119.23it/s]


Removed 4469 images from train


subset: 100%|██████████| 4503/4503 [00:35<00:00, 126.84it/s]


Removed 729 images from val


subset: 100%|██████████| 5604/5604 [00:43<00:00, 127.63it/s]

Removed 1040 images from test





### Balance pancreas
From pancreas: randomly remove samples without a mask (2709 from train, 513 from val and 632 from test) so that samples with and without a mask are equally frequent.

In [37]:
# Randomly delete n samples WITHOUT a mask (image file + mask file) from each
# pancreas split so that samples with and without a mask become equally frequent.
for subset, n in [("train", 2709), ("val", 513), ("test", 632)]:
    dataset = SegmentationDataset("pancreas", subset)
    ids = list(range(len(dataset)))
    random.shuffle(ids)

    removed = 0
    for i in ids:
        # Check the quota before deleting anything, so n == 0 removes nothing.
        if removed >= n:
            break
        path_input = dataset.inputs[i]
        path_mask = dataset.masks[i]

        if read_image(path_mask).numpy().sum() == 0:
            os.remove(path_input)
            os.remove(path_mask)
            removed += 1

    # Always report; previously nothing was printed when the quota was not reached.
    print(f"Removed {removed} images from {subset}")
    if removed < n:
        print(f"  WARNING: only {removed} of the requested {n} images without a mask were found in {subset}")

Removed 2709 images from train
Removed 513 images from val
Removed 632 images from test


### Balance brain_tumor
From brain_tumor: randomly remove samples without a mask (823 from train, 144 from val and 216 from test) so that samples with and without a mask are equally frequent.

In [40]:
# Randomly delete n samples WITHOUT a mask (image file + mask file) from each
# brain_tumor split so that samples with and without a mask become equally frequent.
for subset, n in [("train", 823), ("val", 144), ("test", 216)]:
    dataset = SegmentationDataset("brain_tumor", subset)
    ids = list(range(len(dataset)))
    random.shuffle(ids)

    removed = 0
    for i in ids:
        # Check the quota before deleting anything, so n == 0 removes nothing.
        if removed >= n:
            break
        path_input = dataset.inputs[i]
        path_mask = dataset.masks[i]

        if read_image(path_mask).numpy().sum() == 0:
            os.remove(path_input)
            os.remove(path_mask)
            removed += 1

    # Always report; previously nothing was printed when the quota was not reached.
    print(f"Removed {removed} images from {subset}")
    if removed < n:
        print(f"  WARNING: only {removed} of the requested {n} images without a mask were found in {subset}")

Removed 823 images from train
Removed 144 images from val
Removed 216 images from test


### Balance covid
From covid: randomly remove samples **with** a mask (80 from val and 50 from test); the train split is already balanced.

In [45]:
# Randomly delete n samples WITH a mask (image file + mask file) from the covid
# val and test splits so that samples with and without a mask become equally frequent.
for subset, n in [("val", 80), ("test", 50)]:
    dataset = SegmentationDataset("covid", subset)
    ids = list(range(len(dataset)))
    random.shuffle(ids)

    removed = 0
    for i in ids:
        # Check the quota before deleting anything, so n == 0 removes nothing.
        if removed >= n:
            break
        path_input = dataset.inputs[i]
        path_mask = dataset.masks[i]

        if read_image(path_mask).numpy().sum() > 0:
            os.remove(path_input)
            os.remove(path_mask)
            removed += 1

    # Always report; previously nothing was printed when the quota was not reached.
    print(f"Removed {removed} images from {subset}")
    if removed < n:
        print(f"  WARNING: only {removed} of the requested {n} images with a mask were found in {subset}")

Removed 80 images from val
Removed 50 images from test


## Recheck statistics

In [47]:
def _percent(count, total):
    """Return ``count`` as a percentage of ``total`` rounded to 2 decimals (0.0 for an empty split)."""
    return round(count / total * 100, 2) if total else 0.0


# Re-print the mask statistics of every split of every dataset after rebalancing.
for dataset_name, dataset_info in ALL_DATASETS.items():
    n_labels = dataset_info["n_labels"]
    labels = dataset_info["labels"]

    print(f"{dataset_name}:")
    for subset in ["train", "val", "test"]:
        if n_labels == 1:
            total, mask, no_mask = get_dataset_stats(dataset_name, subset, labels, n_labels)
            print(f"    mask:    {mask} ({_percent(mask, total)}%)")
            print(f"    no mask: {no_mask} ({_percent(no_mask, total)}%)")
        elif n_labels == 2:
            total, only_mask1, only_mask2, both_masks, no_masks = get_dataset_stats(dataset_name, subset, labels, n_labels)
            print(f"    no masks:       {no_masks} ({_percent(no_masks, total)}%)")
            print(f"    only {labels[0]} mask: {only_mask1} ({_percent(only_mask1, total)}%)")
            print(f"    only {labels[1]} mask:   {only_mask2} ({_percent(only_mask2, total)}%)")
            print(f"    both masks:     {both_masks} ({_percent(both_masks, total)}%)")
        else:
            raise ValueError("Currently only supports 1 or 2 labels")
    print("")
    




fetal_head:


  train: 100%|██████████| 350/350 [00:02<00:00, 168.18it/s]


    mask:    350 (100.0%)
    no mask: 0 (0.0%)


  val: 100%|██████████| 74/74 [00:00<00:00, 168.33it/s]


    mask:    74 (100.0%)
    no mask: 0 (0.0%)


  test: 100%|██████████| 76/76 [00:00<00:00, 165.88it/s]


    mask:    76 (100.0%)
    no mask: 0 (0.0%)

breast_cancer:


  train: 100%|██████████| 259/259 [00:01<00:00, 151.71it/s]


    mask:    259 (100.0%)
    no mask: 0 (0.0%)


  val: 100%|██████████| 55/55 [00:00<00:00, 155.94it/s]


    mask:    55 (100.0%)
    no mask: 0 (0.0%)


  test: 100%|██████████| 56/56 [00:00<00:00, 156.06it/s]


    mask:    56 (100.0%)
    no mask: 0 (0.0%)

mouse_embryo:


  train: 100%|██████████| 14416/14416 [03:18<00:00, 72.50it/s]


    no masks:       0 (0.0%)
    only body mask: 7213 (50.03%)
    only bv mask:   0 (0.0%)
    both masks:     7203 (49.97%)


  val: 100%|██████████| 3774/3774 [00:50<00:00, 75.20it/s]


    no masks:       0 (0.0%)
    only body mask: 2105 (55.78%)
    only bv mask:   0 (0.0%)
    both masks:     1669 (44.22%)


  test: 100%|██████████| 4564/4564 [01:12<00:00, 63.21it/s]


    no masks:       0 (0.0%)
    only body mask: 2835 (62.12%)
    only bv mask:   0 (0.0%)
    both masks:     1729 (37.88%)

covid:


  train: 100%|██████████| 3122/3122 [00:37<00:00, 82.22it/s]


    mask:    1580 (50.61%)
    no mask: 1542 (49.39%)


  val: 100%|██████████| 138/138 [00:02<00:00, 68.98it/s]


    mask:    69 (50.0%)
    no mask: 69 (50.0%)


  test: 100%|██████████| 130/130 [00:01<00:00, 65.39it/s]


    mask:    65 (50.0%)
    no mask: 65 (50.0%)

pancreas:


  train: 100%|██████████| 9214/9214 [01:19<00:00, 115.82it/s]


    mask:    4607 (50.0%)
    no mask: 4607 (50.0%)


  val: 100%|██████████| 1978/1978 [00:14<00:00, 133.47it/s]


    mask:    989 (50.0%)
    no mask: 989 (50.0%)


  test: 100%|██████████| 2218/2218 [00:16<00:00, 132.58it/s]


    mask:    1109 (50.0%)
    no mask: 1109 (50.0%)

brain_tumor:


  train: 100%|██████████| 2072/2072 [00:17<00:00, 116.22it/s]


    mask:    1036 (50.0%)
    no mask: 1036 (50.0%)


  val: 100%|██████████| 318/318 [00:02<00:00, 106.32it/s]


    mask:    159 (50.0%)
    no mask: 159 (50.0%)


  test: 100%|██████████| 356/356 [00:03<00:00, 101.07it/s]


    mask:    178 (50.0%)
    no mask: 178 (50.0%)

prostate:


  train: 100%|██████████| 2687/2687 [00:29<00:00, 91.69it/s] 


    mask:    1317 (49.01%)
    no mask: 1370 (50.99%)


  val: 100%|██████████| 585/585 [00:07<00:00, 76.97it/s]


    mask:    271 (46.32%)
    no mask: 314 (53.68%)


  test: 100%|██████████| 615/615 [00:08<00:00, 76.23it/s]

    mask:    279 (45.37%)
    no mask: 336 (54.63%)




