In [1]:
from torch.utils.data import Dataset
import albumentations as A

import numpy as np
from PIL import Image
import os

from tqdm.auto import tqdm

In [6]:
# Label value that shift_class_indices() wraps back to 0 after its +1 shift.
# NOTE(review): presumably 19 semantic classes plus one void label -- confirm
# against the dataset documentation.
NUM_CLASSES = 20

# Sample files are numbered 1..NUM_TOTAL. The first NUM_TRAIN of them are used
# as the training split and the remaining NUM_VAL as the validation split.
NUM_TOTAL = 7539
NUM_TRAIN = 6016
# Derived rather than spelled out so it cannot drift if NUM_TRAIN changes.
NUM_VAL = NUM_TOTAL - NUM_TRAIN

def shift_class_indices(segmap):
    """Shift every label up by one and wrap the label NUM_CLASSES back to 0.

    The arithmetic is performed on the result of ``segmap + 1``; the wrap-around
    assignment is applied to that result, not to ``segmap`` itself.

    Args:
        segmap: Label map supporting ``+``, ``==`` and boolean-mask assignment.

    Returns:
        The shifted label map.
    """
    shifted = segmap + 1
    wraps_around = shifted == NUM_CLASSES
    shifted[wraps_around] = 0
    return shifted


def albumentation_transform(transforms, x, y):
    """Run ``transforms`` on an image/mask pair and unpack the result.

    Args:
        transforms: Callable invoked as ``transforms(image=x, mask=y)`` that
            returns a mapping with the keys ``'image'`` and ``'mask'``.
        x: Image passed as the ``image`` keyword.
        y: Mask passed as the ``mask`` keyword.

    Returns:
        A ``(image, mask)`` tuple taken from the returned mapping.
    """
    result = transforms(image=x, mask=y)
    return tuple(result[key] for key in ('image', 'mask'))


class UrbanSynDataset(Dataset):
    """Map-style dataset of RGB images and semantic-segmentation maps.

    Samples are addressed by a 1-based file number ``i`` and read from
    ``<path>/rgb/rgb_{i:04}`` and ``<path>/ss/ss_{i:04}``. File numbers
    ``1..NUM_TRAIN`` belong to the training split, the following ``NUM_VAL``
    numbers to the validation split.

    Args:
        path: Root directory of the original PNG data. When ``resized`` is
            true this may be either that root (the data is then looked up in
            ``<path>_resized/<size>``) or a directory that itself contains the
            ``<size>`` sub-folder.
        transforms: Callable invoked as ``transforms(image=x, mask=y)`` that
            returns a mapping with ``'image'`` and ``'mask'`` entries.
        split: ``'train'`` or ``'val'``; any other value selects all samples.
        resized: Load pre-computed ``.npy`` arrays instead of the PNG files.
        downscaling: Downscaling factor of the resized data; the size
            sub-folder is named ``1024 // downscaling``.
        shift_class_indices: When true, the module-level
            ``shift_class_indices`` function is applied to the mask after the
            transforms. NOTE: the resize cell of this notebook writes its
            ``.npy`` masks with the shift already applied, so enabling this
            together with ``resized=True`` shifts those labels a second time.
    """

    def __init__(self, path, transforms, split='train', resized=True, downscaling=4, shift_class_indices=False):
        self.path = path
        self.transforms = transforms
        self.split = split
        self.resized = resized
        # Boolean flag; inside __init__ this parameter shadows the
        # module-level function of the same name.
        self.shift_class_indices = shift_class_indices
        if self.resized:
            size_str = str(1024 // downscaling)
            nested = os.path.join(path, size_str)
            if os.path.exists(nested):
                # `path` already holds the per-size folder: descend into it.
                # (Previously the path was left untouched here, so the files
                # were searched for one directory level too high.)
                self.path = nested
            else:
                self.path = os.path.join(path + '_resized', size_str)

    def __getitem__(self, index):
        """Return the transformed ``(image, mask)`` pair at ``index``.

        Negative indices count from the end of the split.

        Raises:
            IndexError: If ``index`` lies outside ``[-len(self), len(self))``.
                This also lets plain ``for`` loops over the dataset terminate
                and prevents the train split from reading validation files.
        """
        n = len(self)
        if not -n <= index < n:
            raise IndexError(
                f'index {index} is out of range for split {self.split!r} with {n} samples'
            )
        if index < 0:
            index += n

        # File numbers are 1-based; validation files follow the training ones.
        first_file = NUM_TRAIN + 1 if self.split == 'val' else 1
        i = index + first_file

        if self.resized:
            x = np.load(os.path.join(self.path, 'rgb', f'rgb_{i:04}.npy'))
            y = np.load(os.path.join(self.path, 'ss', f'ss_{i:04}.npy'))
        else:
            with Image.open(os.path.join(self.path, 'rgb', f'rgb_{i:04}.png')) as img:
                x = np.array(img.convert('RGB'))
            with Image.open(os.path.join(self.path, 'ss', f'ss_{i:04}.png')) as img:
                y = np.array(img.convert('L'))

        x, y = albumentation_transform(self.transforms, x, y)
        if self.shift_class_indices:
            # Refers to the module-level function, not the boolean attribute.
            y = shift_class_indices(y)
        return x, y

    def __len__(self):
        """Return the number of samples in the selected split."""
        if self.split == 'train':
            return NUM_TRAIN
        if self.split == 'val':
            return NUM_VAL
        # Every other split name means "all samples".
        return NUM_TRAIN + NUM_VAL

In [None]:
# Pre-compute downscaled copies of the dataset as .npy arrays, written to
# <output_path>/<height>/{rgb,ss}/ -- the layout UrbanSynDataset reads when
# it is constructed with resized=True.
input_path = './data/urbansyn'
output_path = './data/urbansyn_resized'

# Base size handed to A.Resize as (height, width) after division.
# NOTE(review): assumed to be the full resolution of the source images.
size = np.array((1024, 2048))
downsampling = [4, 2]
os.makedirs(output_path, exist_ok=True)

for ds in downsampling:
    new_size = size//ds
    # Plain Python ints rather than numpy scalars for the transform arguments.
    resize = A.Resize(int(new_size[0]), int(new_size[1]))
    # Masks are stored with shift_class_indices already applied.
    dataset = UrbanSynDataset(input_path, transforms=resize, split='all', resized=False, shift_class_indices=True)

    # makedirs creates the intermediate <output_path>/<height> folder as well.
    rgb_path = os.path.join(output_path, str(new_size[0]), 'rgb')
    os.makedirs(rgb_path, exist_ok=True)
    ss_path = os.path.join(output_path, str(new_size[0]), 'ss')
    os.makedirs(ss_path, exist_ok=True)

    # Index explicitly up to len(dataset) so the loop ends on its own instead
    # of depending on the dataset to signal exhaustion or on a manual break.
    for i in tqdm(range(len(dataset))):
        x, y = dataset[i]
        # File numbers are 1-based.
        np.save(os.path.join(rgb_path, f'rgb_{i+1:04}.npy'), x)
        np.save(os.path.join(ss_path, f'ss_{i+1:04}.npy'), y)