Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing dataset_loader #6

Merged
merged 2 commits into from Feb 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
@@ -1,4 +1,3 @@
**/data
**/checkpoints
**/datasets
**/apex
Expand Down
Empty file added data/__init__.py
Empty file.
76 changes: 76 additions & 0 deletions data/aligned_dataset.py
@@ -0,0 +1,76 @@
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image

class AlignedDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot

### input A (label maps)
dir_A = '_A' if self.opt.label_nc == 0 else '_label'
self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
self.A_paths = sorted(make_dataset(self.dir_A))

### input B (real images)
if opt.isTrain or opt.use_encoded_image:
dir_B = '_B' if self.opt.label_nc == 0 else '_img'
self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
self.B_paths = sorted(make_dataset(self.dir_B))

### instance maps
if not opt.no_instance:
self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
self.inst_paths = sorted(make_dataset(self.dir_inst))

### load precomputed instance-wise encoded features
if opt.load_features:
self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
print('----------- loading features from %s ----------' % self.dir_feat)
self.feat_paths = sorted(make_dataset(self.dir_feat))

self.dataset_size = len(self.A_paths)

def __getitem__(self, index):
### input A (label maps)
A_path = self.A_paths[index]
A = Image.open(A_path)
params = get_params(self.opt, A.size)
if self.opt.label_nc == 0:
transform_A = get_transform(self.opt, params)
A_tensor = transform_A(A.convert('RGB'))
else:
transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
A_tensor = transform_A(A) * 255.0

B_tensor = inst_tensor = feat_tensor = 0
### input B (real images)
if self.opt.isTrain or self.opt.use_encoded_image:
B_path = self.B_paths[index]
B = Image.open(B_path).convert('RGB')
transform_B = get_transform(self.opt, params)
B_tensor = transform_B(B)

### if using instance maps
if not self.opt.no_instance:
inst_path = self.inst_paths[index]
inst = Image.open(inst_path)
inst_tensor = transform_A(inst)

if self.opt.load_features:
feat_path = self.feat_paths[index]
feat = Image.open(feat_path).convert('RGB')
norm = normalize()
feat_tensor = norm(transform_A(feat))

input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
'feat': feat_tensor, 'path': A_path}

return input_dict

def __len__(self):
return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize

def name(self):
return 'AlignedDataset'
14 changes: 14 additions & 0 deletions data/base_data_loader.py
@@ -0,0 +1,14 @@

class BaseDataLoader():
def __init__(self):
pass

def initialize(self, opt):
self.opt = opt
pass

def load_data():
return None



90 changes: 90 additions & 0 deletions data/base_dataset.py
@@ -0,0 +1,90 @@
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random

class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()

def name(self):
return 'BaseDataset'

def initialize(self, opt):
pass

def get_params(opt, size):
w, h = size
new_h = h
new_w = w
if opt.resize_or_crop == 'resize_and_crop':
new_h = new_w = opt.loadSize
elif opt.resize_or_crop == 'scale_width_and_crop':
new_w = opt.loadSize
new_h = opt.loadSize * h // w

x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
y = random.randint(0, np.maximum(0, new_h - opt.fineSize))

flip = random.random() > 0.5
return {'crop_pos': (x, y), 'flip': flip}

def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
transform_list = []
if 'resize' in opt.resize_or_crop:
osize = [opt.loadSize, opt.loadSize]
transform_list.append(transforms.Scale(osize, method))
elif 'scale_width' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))

if 'crop' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))

if opt.resize_or_crop == 'none':
base = float(2 ** opt.n_downsample_global)
if opt.netG == 'local':
base *= (2 ** opt.n_local_enhancers)
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

transform_list += [transforms.ToTensor()]

if normalize:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)

def normalize():
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if (h == oh) and (w == ow):
return img
return img.resize((w, h), method)

def __scale_width(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
if (ow == target_width):
return img
w = target_width
h = int(target_width * oh / ow)
return img.resize((w, h), method)

def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
if (ow > tw or oh > th):
return img.crop((x1, y1, x1 + tw, y1 + th))
return img

def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
31 changes: 31 additions & 0 deletions data/custom_dataset_data_loader.py
@@ -0,0 +1,31 @@
import torch.utils.data
from data.base_data_loader import BaseDataLoader


def CreateDataset(opt):
dataset = None
from data.aligned_dataset import AlignedDataset
dataset = AlignedDataset()

print("dataset [%s] was created" % (dataset.name()))
dataset.initialize(opt)
return dataset

class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'

def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.dataset = CreateDataset(opt)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads))

def load_data(self):
return self.dataloader

def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size)
7 changes: 7 additions & 0 deletions data/data_loader.py
@@ -0,0 +1,7 @@

def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
print(data_loader.name())
data_loader.initialize(opt)
return data_loader
65 changes: 65 additions & 0 deletions data/image_folder.py
@@ -0,0 +1,65 @@
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################
import torch.utils.data as data
from PIL import Image
import os

IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir

for root, _, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)

return images


def default_loader(path):
return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

def __init__(self, root, transform=None, return_paths=False,
loader=default_loader):
imgs = make_dataset(root)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " +
",".join(IMG_EXTENSIONS)))

self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader

def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.return_paths:
return img, path
else:
return img

def __len__(self):
return len(self.imgs)