In [53]:
import glob
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import cv2
import os

In [55]:
def BFS(mask):
    visited = np.zeros_like(mask, dtype=np.int32)

    max_size = 0

    for iy, ix in np.ndindex(mask.shape):
        if mask[iy, ix] == 0 or (visited[iy, ix] == -1 and mask[iy, ix] == 255):
            continue
        else:
            queue = [(iy, ix)]
            size = 0
            while len(queue) > 0:
                y, x = queue.pop()
                visited[y, x] = -1
                size += 1
                for dy, dx in [(0, 1), (0, -1), (1, 0), (-1, 0)]:
                    if 0 <= y + dy < mask.shape[0] and 0 <= x + dx < mask.shape[1]:
                        if mask[y + dy, x + dx] == 255 and visited[y + dy, x + dx] == 0:
                            queue.append((y + dy, x + dx))

            visited[iy, ix] = size
            max_size = max(max_size, size)

    for iy, ix in np.ndindex(visited.shape):
        if 0 < visited[iy, ix] < max_size and mask[iy, ix] == 255:
            queue = [(iy, ix)]
            while len(queue) > 0:
                y, x = queue.pop()
                mask[y, x] = 0
                for dy, dx in [(0, 1), (0, -1), (1, 0), (-1, 0)]:
                    if 0 <= y + dy < mask.shape[0] and 0 <= x + dx < mask.shape[1]:
                        if mask[y + dy, x + dx] == 255:
                            queue.append((y + dy, x + dx))

    return mask

def display_masks(mask1, mask2):
    mask1_d = np.zeros((*mask1.shape, 3), dtype=np.uint8)
    mask2_d = np.zeros((*mask2.shape, 3), dtype=np.uint8)

    mask1_d[mask1 == 255] = [255, 0, 0]
    mask2_d[mask2 == 255] = [0, 255, 0]

    plt.figure(figsize=(10, 10))
    plt.imshow(mask1_d, alpha=0.5)
    plt.imshow(mask2_d, alpha=0.3)
    plt.colorbar()

In [56]:
def cleanup(original_path, annotated_path):
    os.makedirs(original_path, exist_ok=True)
    os.makedirs(annotated_path, exist_ok=True)
    
    mask_files = glob.glob(str(annotated_path) + '/*.jpg')
    image_files = glob.glob(str(original_path) + '/*.jpg')

    PATCH_SIZE = 256

    for mask_file in tqdm(mask_files):
        mask_image = Image.open(mask_file)
        mask = np.asarray(mask_image)
        mask_fixed = np.zeros_like(mask)
        # reject any values that are not 0 or 255
        for iy, ix in np.ndindex(mask.shape):
            if mask[iy, ix] > 127:
                mask_fixed[iy, ix] = 255

        assert np.all(np.logical_or(mask_fixed == 0, mask_fixed == 255))

        mask_fixed = BFS(mask_fixed)
        #display_masks(mask, mask_fixed)

        mask_fixed = cv2.resize(mask_fixed, (PATCH_SIZE, PATCH_SIZE), interpolation=cv2.INTER_NEAREST)
        mask_fixed = Image.fromarray(mask_fixed)
        mask_fixed.save("./annotated/" + mask_file[mask_file.rindex("\\")+1:mask_file.rindex(".")] + ".png", "PNG")
        os.remove(mask_file)

    for image_file in tqdm(image_files):
        image = np.asarray(Image.open(image_file))
        image = cv2.resize(image, (PATCH_SIZE, PATCH_SIZE), interpolation=cv2.INTER_CUBIC)
        image = Image.fromarray(image)
        image.save("./original/" + image_file[image_file.rindex("\\")+1:])
        os.remove(image_file)


100%|██████████| 22/22 [01:34<00:00,  4.27s/it]
100%|██████████| 36/36 [00:00<00:00, 212.09it/s]


In [None]:
cleanup("../validation/original", "../validation/annotated")

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

class TempDataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.files = glob.glob(path + "/*.jpg")
        self.transform = transforms.ToTensor()
    
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        image = np.asarray(Image.open(self.files[idx]))
        image = self.transform(image)
        return image
    
dataset = TempDataset(path="../original")
loader = DataLoader(dataset, batch_size=len(dataset))

mean = 0
std = 0

for images in loader:
    batch_samples = images.size(0)
    images = images.view(batch_samples, images.size(1), -1)
    mean += images.mean(2).sum(0)
    std += images.std(2).sum(0)

print(mean / len(dataset))
print(std / len(dataset))

tensor([0.5612, 0.5397, 0.5159])
tensor([0.2515, 0.2405, 0.2317])
