In [6]:
import json
import os
import random

import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from torchvision.utils import save_image
from PIL import Image
from PIL import Image

from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torch.utils.data import DataLoader


In [7]:
import copy


# Registry mapping a dataset name (as used in config specs) to its class;
# entries are added by the `register` decorator.
datasets = dict()


def register(name):
    """Return a class decorator that records the class as ``datasets[name]``.

    The decorated class is returned unchanged. Registering a name that is
    already present replaces the previous entry, so re-running a notebook
    cell that defines a registered class is harmless.
    """
    def _register_cls(cls):
        datasets.update({name: cls})
        return cls

    return _register_cls


def make(dataset_spec, args=None):
    """Instantiate a registered dataset from a ``{'name': ..., 'args': {...}}`` spec.

    Args:
        dataset_spec: mapping with key ``'name'`` (a name passed to
            ``register``) and optional ``'args'`` (keyword arguments for the
            class). A missing or null ``'args'`` means "no arguments".
        args: optional extra keyword arguments that override entries of
            ``dataset_spec['args']``. The spec itself is never modified.

    Returns:
        The constructed dataset instance.

    Raises:
        KeyError: if ``dataset_spec['name']`` has not been registered.
    """
    # `args:` with no value in YAML loads as None; treat that like an empty dict.
    spec_args = dataset_spec.get('args') or {}
    if args is not None:
        # Deep-copy so the update never mutates the (shared) config dict.
        dataset_args = copy.deepcopy(spec_args)
        dataset_args.update(args)
    else:
        dataset_args = spec_args

    name = dataset_spec['name']
    try:
        dataset_cls = datasets[name]
    except KeyError:
        raise KeyError(
            f"unknown dataset {name!r}; registered: {sorted(datasets)}") from None
    return dataset_cls(**dataset_args)


In [8]:
class Resampler(torch.nn.Module):
    """Down- then up-sample an image to discard detail below a given scale.

    The image is resized to a square of side ``inp_size // resampling_factor``
    and then back to ``(inp_size, inp_size)`` with the same interpolation
    mode. The output keeps the working resolution but only carries the
    information of the smaller image. With ``resampling_factor == 1`` both
    resizes target ``inp_size``.

    Args:
        inp_size: side length of the (square) output.
        interpolation_mode: value accepted by ``InterpolationMode``, e.g.
            ``'nearest'``, ``'bilinear'`` or ``'bicubic'``.
        resampling_factor: positive downsampling factor. Non-integer factors
            are allowed; the intermediate size is truncated to an int.

    Raises:
        ValueError: if ``resampling_factor`` is not positive, or (at call time)
            if it would shrink the image below one pixel.
    """

    def __init__(self, inp_size, interpolation_mode, resampling_factor):
        super().__init__()
        # Validate early: a zero factor used to surface later as ZeroDivisionError.
        if not resampling_factor > 0:
            raise ValueError(
                f"resampling_factor must be positive, got {resampling_factor!r}")
        self.inp_size = inp_size
        self.resampling_factor = resampling_factor
        self.interpolation_mode = InterpolationMode(interpolation_mode)

    def forward(self, img):
        """Return ``img`` after the down/up-sampling round trip.

        ``img`` is anything ``transforms.Resize`` accepts. In this notebook it
        is a PIL image, because the Resampler runs before ``ToTensor``.
        """
        # int() so float factors still yield an integer size for Resize.
        new_size = int(self.inp_size // self.resampling_factor)
        if new_size < 1:
            raise ValueError(
                f"resampling_factor {self.resampling_factor} is too large for "
                f"inp_size {self.inp_size}")

        downsampler = transforms.Resize(size=(new_size, new_size),
                                        interpolation=self.interpolation_mode)
        upsampler = transforms.Resize(size=(self.inp_size, self.inp_size),
                                      interpolation=self.interpolation_mode)
        # Call the modules rather than .forward() so nn.Module hooks still run.
        return upsampler(downsampler(img))


@register('image-folder')
class ImageFolder(Dataset):
    """Dataset over the image files of one directory, returned as PIL images.

    Args:
        path: directory containing the images.
        split_file: optional JSON file. When given, the filenames are read
            from ``json.load(split_file)[split_key]`` instead of listing
            ``path`` (which is sorted).
        split_key: key into the split file.
        first_k: if not None, keep only the first ``first_k`` filenames.
        size: side length used by ``self.img_transform``.
        repeat: the dataset reports ``len(files) * repeat`` items and wraps
            indices modulo the number of files.
        cache: ``'none'`` opens the file on every access; ``'in_memory'``
            decodes every image once, in ``__init__``.
        mask: load images as single-channel ``'L'`` instead of ``'RGB'``.

    Raises:
        ValueError: if ``cache`` is not one of the supported modes.
    """

    # Supported values of the `cache` argument.
    _CACHE_MODES = ('none', 'in_memory')

    def __init__(self, path,  split_file=None, split_key=None, first_k=None, size=None,
                 repeat=1, cache='none', mask=False):
        # Fail fast: an unknown mode used to leave `self.files` silently empty.
        if cache not in self._CACHE_MODES:
            raise ValueError(
                f"cache must be one of {self._CACHE_MODES}, got {cache!r}")
        self.repeat = repeat
        self.cache = cache
        self.path = path
        self.Train = False  # not read by the visible code; kept for compatibility
        self.split_key = split_key

        self.size = size
        self.mask = mask
        # NOTE(review): this transform is built but never applied by this class;
        # presumably wrappers such as TrainDataset do their own preprocessing.
        if self.mask:
            # Nearest-neighbour so mask values are not blended by resizing.
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size),
                                  interpolation=InterpolationMode.NEAREST),
                transforms.ToTensor(),
            ])
        else:
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        if split_file is None:
            filenames = sorted(os.listdir(path))
        else:
            with open(split_file, 'r', encoding='utf-8') as f:
                filenames = json.load(f)[split_key]
        if first_k is not None:
            filenames = filenames[:first_k]

        self.files = []
        for filename in filenames:
            self.append_file(os.path.join(path, filename))

    def append_file(self, file):
        """Store ``file`` as a path (``'none'``) or as a decoded image (``'in_memory'``)."""
        if self.cache == 'none':
            self.files.append(file)
        elif self.cache == 'in_memory':
            self.files.append(self.img_process(file))

    def __len__(self):
        return len(self.files) * self.repeat

    def __getitem__(self, idx):
        """Return the PIL image for ``idx``; indices wrap over the file list."""
        x = self.files[idx % len(self.files)]

        if self.cache == 'none':
            return self.img_process(x)
        elif self.cache == 'in_memory':
            return x

    def img_process(self, file):
        """Open ``file`` and return it converted to ``'L'`` (mask) or ``'RGB'``.

        The file is opened in a ``with`` block so the handle is closed before
        returning; ``convert`` produces a separate, loaded image object.
        """
        mode = 'L' if self.mask else 'RGB'
        with Image.open(file) as img:
            return img.convert(mode)

@register('paired-image-folders')
class PairedImageFolders(Dataset):
    """Pair images from ``root_path_1`` with masks from ``root_path_2`` by index.

    Both folders are loaded with ``ImageFolder`` using the same keyword
    arguments; the second one is forced to ``mask=True`` (single-channel
    ``'L'`` images). Item ``i`` is ``(dataset_1[i], dataset_2[i])``, so the
    two folders must list corresponding files in the same order.

    Raises:
        ValueError: if the two folders yield different numbers of items. That
            would otherwise silently misalign image/mask pairs, because
            ``ImageFolder`` wraps indices modulo its file count.
    """

    def __init__(self, root_path_1, root_path_2, **kwargs):
        self.dataset_1 = ImageFolder(root_path_1, **kwargs)
        self.dataset_2 = ImageFolder(root_path_2, **kwargs, mask=True)
        if len(self.dataset_1) != len(self.dataset_2):
            raise ValueError(
                f"paired folders differ in size: {root_path_1!r} has "
                f"{len(self.dataset_1)} items, {root_path_2!r} has "
                f"{len(self.dataset_2)}")

    def __len__(self):
        return len(self.dataset_1)

    def __getitem__(self, idx):
        return self.dataset_1[idx], self.dataset_2[idx]


@register('train')
class TrainDataset(Dataset):
    """Turn ``(img, mask)`` PIL pairs from ``dataset`` into model-ready tensors.

    Each item of ``dataset`` must unpack to ``(img, mask)`` PIL images (for
    example the items of ``PairedImageFolders``). ``__getitem__`` randomly
    flips the pair, resizes both to ``inp_size`` and returns a dict:

    * ``'inp'``: the image resized, passed through ``Resampler``, converted to
      a tensor and ImageNet-normalised. ``inverse_transform`` undoes the
      normalisation, e.g. for visualisation.
    * ``'gt'``: the mask resized with nearest-neighbour, passed through a
      ``Resampler`` with the same settings and converted to a tensor (not
      normalised).

    NOTE(review): ``size_min``, ``size_max``, ``augment`` and ``gt_resize``
    are stored but not read by this class. In particular, the horizontal flip
    is applied whether or not ``augment`` is set.
    NOTE(review): the mask's Resampler also uses ``interpolation_mode``. If
    masks are binary, a non-nearest mode with ``resampling_factor > 1`` will
    make them non-binary; confirm this is intended.
    """

    def __init__(self, dataset, size_min=None, size_max=None, inp_size=None,
                 augment=False, interpolation_mode="nearest", resampling_factor = 1, gt_resize=None):
        self.dataset = dataset
        self.size_min = size_min
        if size_max is None:
            size_max = size_min
        self.size_max = size_max
        self.augment = augment
        self.gt_resize = gt_resize

        self.inp_size = inp_size
        self.img_transform = transforms.Compose([
                transforms.Resize((self.inp_size, self.inp_size)),
                Resampler(inp_size = inp_size, interpolation_mode=interpolation_mode, resampling_factor=resampling_factor),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        # Exact inverse of the Normalize above: x * std, then x + mean.
        self.inverse_transform = transforms.Compose([
                transforms.Normalize(mean=[0., 0., 0.],
                                     std=[1/0.229, 1/0.224, 1/0.225]),
                transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                     std=[1, 1, 1])
            ])
        self.mask_transform = transforms.Compose([
                transforms.Resize((self.inp_size, self.inp_size)),
                Resampler(inp_size = inp_size, interpolation_mode=interpolation_mode, resampling_factor=resampling_factor),
                transforms.ToTensor(),
            ])

        # Per-item resize ops used by __getitem__, built once here instead of
        # on every call. The mask uses nearest-neighbour to avoid blending labels.
        self._img_resize = transforms.Resize((self.inp_size, self.inp_size))
        self._mask_resize = transforms.Resize((self.inp_size, self.inp_size),
                                              interpolation=InterpolationMode.NEAREST)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, mask = self.dataset[idx]

        # Random horizontal flip; one draw for both so image and mask stay aligned.
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        img = self._img_resize(img)
        mask = self._mask_resize(mask)

        return {
            'inp': self.img_transform(img),
            'gt': self.mask_transform(mask)
        }

In [9]:
def make_data_loader(spec, tag=''):
    """Build a DataLoader from a config ``spec``, or return None if ``spec`` is None.

    ``spec['dataset']`` is built with ``make``. The result is then passed as
    the ``dataset`` argument when building ``spec['wrapper']``. The loader uses
    ``spec['batch_size']``. The optional keys ``shuffle`` (default True),
    ``num_workers`` (default 8) and ``pin_memory`` (default True) override the
    loader settings; the defaults match the previous hard-coded values.

    ``tag`` is accepted for call compatibility but is not used.

    NOTE(review): with ``num_workers > 0``, worker processes must be able to
    pickle dataset classes defined in this notebook. If loading fails on your
    platform, set ``num_workers: 0`` in the spec.
    """
    if spec is None:
        return None

    dataset = make(spec['dataset'])
    dataset = make(spec['wrapper'], args={'dataset': dataset})

    return DataLoader(
        dataset,
        batch_size=spec['batch_size'],
        shuffle=spec.get('shuffle', True),
        num_workers=spec.get('num_workers', 8),
        pin_memory=spec.get('pin_memory', True),
    )


In [10]:
import yaml

# Notebook-level settings.
# NOTE(review): inp_size, interpolation_mode and resampling_factor are not read
# by any later cell shown here. The data pipeline takes its settings from
# `config_file`, so either write these into `config` or remove them.
inp_size = 1024
interpolation_mode = 'bicubic'
resampling_factor = 1
config_file = "visualize.yaml"

# Load the config. safe_load builds only plain YAML types (dicts, lists,
# scalars), which is all the data-loader spec needs, and never executes
# arbitrary constructors.
with open(config_file, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)


In [11]:
# Build the training DataLoader from the `train_dataset` section of the config
# (this is None when that section is absent).
train_loader = make_data_loader(spec=config.get('train_dataset'), tag='train')

In [12]:
from torch.utils.tensorboard import SummaryWriter

# Write TensorBoard events to a fixed, descriptive directory under runs/
# instead of SummaryWriter's default location.
writer = SummaryWriter(log_dir='runs/visualize')

In [13]:
# Pull one (shuffled) batch from the training loader to check what the pipeline produces.
batch = next(iter(train_loader))
inp = batch['inp']  # (batch, channels, H, W) tensor, e.g. torch.Size([4, 3, 1024, 1024])
print(inp.shape)

  return torch._C._cuda_getDeviceCount() > 0


torch.Size([4, 3, 1024, 1024])


In [14]:
# Build a grid of the batch for TensorBoard.
# `inp` is ImageNet-normalised by the dataset's img_transform; undo that so the
# logged images show their real colours, and clamp to the [0, 1] range that
# float image tensors are expected to use.
# NOTE(review): assumes the loader's wrapper dataset is TrainDataset (which
# defines `inverse_transform`); confirm against the config's `wrapper.name`.
inp_display = train_loader.dataset.inverse_transform(inp).clamp(0, 1)
img_grid = torchvision.utils.make_grid(inp_display)
# Write to TensorBoard and flush so the event is on disk immediately.
writer.add_image('example image', img_grid)
writer.flush()