In [1]:
# Standard-library modules used for environment setup and path handling.
import os
import sys
import warnings
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")
# Set before `import tensorflow` below; presumably this quiets TensorFlow's
# native (C++) logging — TODO confirm the meaning of level '3'.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# Extend the module search path so the `data` and `utils` packages imported
# below can be resolved (presumably they live under this directory).
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
sys.path.append("/home/pervinco/BKAI_MetaPolyp")

import cv2
import yaml
import random
import numpy as np
import tensorflow as tf
import albumentations as A
import matplotlib.pyplot as plt

from glob import glob
# NOTE(review): the wildcard import hides which names this notebook relies on;
# `encode_mask` is also imported explicitly from the same module two lines below.
from data.batch_preprocess import *
from utils.utils import decode_mask
from data.batch_preprocess import encode_mask
from data.BKAIDataset import BKAIDataset
from data.BalancedBKAIDataset import BalancedBKAIDataset

In [2]:
# Parse the project's YAML configuration once; later cells read settings from `config`.
with open("/home/pervinco/BKAI_MetaPolyp/config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Dataset sub-folders, all rooted at the configured data directory.
data_dir = config["data_dir"]
image_dir, mask_dir, gt_dir = (
    f"{data_dir}/train",
    f"{data_dir}/train_mask",
    f"{data_dir}/train_gt",
)

In [3]:
# Collect every entry of each dataset folder.
image_files = glob(f"{image_dir}/*")
mask_files = glob(f"{mask_dir}/*")
gt_files = glob(f"{gt_dir}/*")

# Sort in place so each listing has a deterministic, lexicographic order.
image_files.sort()
mask_files.sort()
gt_files.sort()

In [None]:
def compute_class_distribution(mask_files, num_classes, img_size=(256, 256)):
    """Count how many pixels of each class occur across a set of mask images.

    Every file is read with OpenCV, converted from BGR to RGB, resized to
    ``img_size`` and passed through ``encode_mask``; the encoded result is then
    compared against each class id in ``range(num_classes)``.

    Args:
        mask_files: Iterable of paths to mask images.
        num_classes: Number of class ids to count (ids ``0 .. num_classes - 1``).
        img_size: ``(width, height)`` handed to ``cv2.resize``. Defaults to
            ``(256, 256)``, the size that used to be hard-coded.

    Returns:
        ``np.ndarray`` of shape ``(num_classes,)`` and dtype ``int64`` holding
        the accumulated pixel count per class. All zeros when ``mask_files``
        is empty.

    Raises:
        FileNotFoundError: If OpenCV cannot read one of the files.
    """
    # 64-bit accumulator: an int32 total can wrap once enough pixels are summed.
    distribution = np.zeros(num_classes, dtype=np.int64)

    for mask_file in mask_files:
        mask = cv2.imread(mask_file)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if mask is None:
            raise FileNotFoundError(f"Could not read mask image: {mask_file}")

        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        # NOTE(review): cv2.resize is called with its default interpolation, which
        # may blend colours along class borders before encoding — confirm this
        # matches how the training pipeline resizes masks.
        mask = cv2.resize(mask, tuple(img_size))
        mask = encode_mask(mask)

        for class_id in range(num_classes):
            distribution[class_id] += (mask == class_id).sum()

    return distribution


def plot_class_distribution(distribution, class_names=None):
    """Draw a bar chart of per-class pixel counts and show it.

    Args:
        distribution: Sequence of counts, one entry per class.
        class_names: Optional sequence (list, tuple or NumPy array) of bar
            labels, one per entry of ``distribution``. When omitted or empty,
            the bars are placed at the integer positions
            ``0 .. len(distribution) - 1``.

    Returns:
        None. The figure is displayed via ``plt.show()``.
    """
    # Avoid `if class_names:` — the truth value of a multi-element NumPy array
    # is ambiguous and raises ValueError. Test for presence and length instead.
    has_names = class_names is not None and len(class_names) > 0
    bar_positions = class_names if has_names else np.arange(len(distribution))

    plt.figure(figsize=(10, 6))
    plt.bar(bar_positions, distribution)

    plt.ylabel('Number of Pixels')
    plt.xlabel('Class')
    plt.title('Class Distribution in Semantic Segmentation')
    # Rotate tick labels so longer class names do not overlap.
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

In [4]:
def calculate_batch_distribution(dataloader, max_batches=20, output_dir="./images", num_classes=None):
    """Print per-batch class pixel counts and save image/mask overlays.

    For each batch yielded by ``dataloader`` (up to ``max_batches``) this prints
    the batch index and tensor shapes, accumulates how many pixels fall into
    each class over the batch, writes one overlay PNG per sample to
    ``output_dir`` and finally prints the batch's class distribution.

    Args:
        dataloader: Iterable yielding ``(images, masks)`` pairs. Each image must
            expose ``.numpy()``; each mask is reduced with ``np.argmax`` over
            its last axis (presumably one-hot / per-class scores — confirm
            against the dataset).
        max_batches: Stop once this many batches were processed. Defaults to
            20, the limit that used to be hard-coded; ``None`` processes all.
        output_dir: Directory receiving ``batch{i}_no{j}.png`` files; created
            if missing. Defaults to ``"./images"``.
        num_classes: Number of class ids to count. When ``None`` the value is
            read from the notebook-level ``config["num_classes"]``.

    Returns:
        None. Results are printed and written to disk.
    """
    if num_classes is None:
        num_classes = config["num_classes"]

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(output_dir, exist_ok=True)

    for i, (images, masks) in enumerate(dataloader):
        if max_batches is not None and i >= max_batches:
            break

        print(i, images.shape, masks.shape)

        distribution = np.zeros(num_classes, dtype=np.int64)
        for j, (image, mask) in enumerate(zip(images, masks)):
            image = image.numpy()
            # Maps [-1, 1] onto [0, 255] (assumes that input normalisation).
            # Clip first so slightly out-of-range values saturate instead of
            # wrapping around in the uint8 cast.
            image = np.clip((1 + image) * 127.5, 0, 255)
            image = image.astype(np.uint8)
            # Swap the channel order before writing with OpenCV (presumably
            # the dataset yields RGB — TODO confirm).
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            prob_mask = np.argmax(mask, -1)
            for class_id in range(num_classes):
                distribution[class_id] += (prob_mask == class_id).sum()

            decoded_mask = decode_mask(prob_mask)
            decoded_mask = decoded_mask.astype(np.uint8)
            decoded_mask = cv2.cvtColor(decoded_mask, cv2.COLOR_BGR2RGB)

            # 70% image + 30% colour-coded mask.
            overlay = cv2.addWeighted(image, 0.7, decoded_mask, 0.3, 0)

            cv2.imwrite(os.path.join(output_dir, f"batch{i}_no{j}.png"), overlay)

        print(distribution,"\n")

In [None]:
# Per-class pixel counts over the files listed from `gt_dir`, for class ids 0-2.
# NOTE(review): the class count is hard-coded here, while
# `calculate_batch_distribution` reads it from config["num_classes"].
gt_dist = compute_class_distribution(gt_files, num_classes=3)
print(gt_dist)

In [None]:
# Same per-class pixel count, this time over the files listed from `mask_dir`.
dist = compute_class_distribution(mask_files, num_classes=3)
print(dist)

In [None]:
# train_dataset = BKAIDataset(config=config, split=config["train"])
# train_dataloader = tf.data.Dataset.from_generator(lambda: train_dataset, 
#                                                   output_signature=(tf.TensorSpec(shape=(None, config["img_size"], config["img_size"], 3), dtype=tf.float32),
#                                                                     tf.TensorSpec(shape=(None, config["img_size"], config["img_size"], 3), dtype=tf.float32)))

# calculate_batch_distribution(train_dataloader)

In [5]:
# Build the balanced training dataset and report its length.
dataset = BalancedBKAIDataset(config, split="train")
print(len(dataset))

# Wrap the dataset in a tf.data pipeline. Images and masks share one signature:
# float32 tensors of shape (batch, img_size, img_size, 3) with a free batch axis.
dataloader = tf.data.Dataset.from_generator(
    lambda: dataset,
    output_signature=tuple(
        tf.TensorSpec(shape=(None, config["img_size"], config["img_size"], 3), dtype=tf.float32)
        for _ in range(2)
    ),
)

# Print per-batch class counts and dump overlay images for visual inspection.
calculate_batch_distribution(dataloader)

800
0 (16, 256, 256, 3) (16, 256, 256, 3)
[974976  47182  26418] 

1 (16, 256, 256, 3) (16, 256, 256, 3)
[980985  34430  33161] 

2 (16, 256, 256, 3) (16, 256, 256, 3)
[984375  31834  32367] 

3 (16, 256, 256, 3) (16, 256, 256, 3)
[982178  35825  30573] 

4 (16, 256, 256, 3) (16, 256, 256, 3)
[950211  64377  33988] 

5 (16, 256, 256, 3) (16, 256, 256, 3)
[987088  22806  38682] 

6 (16, 256, 256, 3) (16, 256, 256, 3)
[990169  27432  30975] 

7 (16, 256, 256, 3) (16, 256, 256, 3)
[977961  33800  36815] 

8 (16, 256, 256, 3) (16, 256, 256, 3)
[977069  40840  30667] 

9 (16, 256, 256, 3) (16, 256, 256, 3)
[961113  48608  38855] 

10 (16, 256, 256, 3) (16, 256, 256, 3)
[985211  27733  35632] 

11 (16, 256, 256, 3) (16, 256, 256, 3)
[1000592   21323   26661] 

12 (16, 256, 256, 3) (16, 256, 256, 3)
[972720  37338  38518] 

13 (16, 256, 256, 3) (16, 256, 256, 3)
[1004899   22428   21249] 

14 (16, 256, 256, 3) (16, 256, 256, 3)
[982873  35597  30106] 

15 (16, 256, 256, 3) (16, 256, 256, 3)
[