
Commit

update data
xinntao committed Jun 15, 2018
1 parent 598615d commit 5e2b763
Showing 3 changed files with 38 additions and 26 deletions.
7 changes: 3 additions & 4 deletions codes/data/LRHR_dataset.py
@@ -42,8 +42,7 @@ def __init__(self, opt):
                 'HR and LR datasets have different number of images - {}, {}.'.format(\
                 len(self.paths_LR), len(self.paths_HR))
 
-        # self.random_scale_list = [1, 0.9, 0.8, 0.7, 0.6, 0.5]
-        self.random_scale_list = None
+        self.random_scale_list = [1]
 
     def __getitem__(self, index):
         HR_path, LR_path = None, None
@@ -63,7 +62,7 @@ def __getitem__(self, index):
             img_LR = util.read_img(self.LR_env, LR_path)
         else: # down-sampling on-the-fly
             # randomly scale during training
-            if self.opt['phase'] == 'train' and self.random_scale_list:
+            if self.opt['phase'] == 'train':
                 random_scale = random.choice(self.random_scale_list)
                 H_s, W_s, _ = img_HR.shape
                 def _mod(n, random_scale, scale, thres):
@@ -72,7 +71,7 @@ def _mod(n, random_scale, scale, thres):
                     return thres if rlt < thres else rlt
                 H_s = _mod(H_s, random_scale, scale, HR_size)
                 W_s = _mod(W_s, random_scale, scale, HR_size)
-                img_HR = cv2.resize(img_HR, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
+                img_HR = cv2.resize(np.copy(img_HR), (W_s, H_s), interpolation=cv2.INTER_LINEAR)
 
             H, W, _ = img_HR.shape
             # using matlab imresize
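Note on the hunk above: the collapsed context hides the first lines of _mod, and the new np.copy() call hands cv2.resize a writable, contiguous copy of the HR array, presumably to avoid resizing a shared or read-only view (e.g. one decoded from an lmdb buffer). A minimal sketch of the whole random-scale step follows; the body of _mod is a plausible reconstruction from its visible return line, not the committed code.

import random

import cv2
import numpy as np

def random_rescale_HR(img_HR, random_scale_list, scale, HR_size):
    # Pick a scale factor; after this commit the list is just [1].
    random_scale = random.choice(random_scale_list)
    H_s, W_s, _ = img_HR.shape

    def _mod(n, random_scale, scale, thres):
        # assumed: shrink the side by random_scale and snap it to a multiple of `scale`
        rlt = int(n * random_scale / scale) * scale
        return thres if rlt < thres else rlt  # never drop below the HR crop size

    H_s = _mod(H_s, random_scale, scale, HR_size)
    W_s = _mod(W_s, random_scale, scale, HR_size)
    # np.copy() ensures OpenCV receives a contiguous, writable array
    return cv2.resize(np.copy(img_HR), (W_s, H_s), interpolation=cv2.INTER_LINEAR)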
49 changes: 29 additions & 20 deletions codes/data/LRHR_seg_bg_dataset.py
@@ -7,51 +7,60 @@
 import data.util as util
 
 
-class LRHRSegDataset(data.Dataset):
+class LRHRSeg_BG_Dataset(data.Dataset):
     '''
-    Read HR, seg; generate LR, category
+    Read HR image, seg map; generate LR image, category
     for SFT-GAN
+    also sample general scenes for background
     '''
 
     def name(self):
-        return 'LRHRSegDataset'
+        return 'LRHRSeg_BG_Dataset'
 
     def __init__(self, opt):
-        super(LRHRSegDataset, self).__init__()
+        super(LRHRSeg_BG_Dataset, self).__init__()
         self.opt = opt
         self.paths_LR = None
         self.paths_HR = None
-        self.LR_env = None # environment for lmdb
+        self.paths_HR_bg = None
+        self.LR_env = None # environment for lmdb
         self.HR_env = None
+        self.HR_env_bg = None
 
         # read image list from lmdb or image files
         self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_HR'])
         self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])
+        self.HR_env_bg, self.paths_HR_bg = util.get_image_paths(opt['data_type'], \
+            opt['dataroot_HR_bg'])
 
         assert self.paths_HR, 'Error: HR paths are empty.'
         if self.paths_LR and self.paths_HR:
             assert len(self.paths_LR) == len(self.paths_HR), \
                 'HR and LR datasets have different number of images - {}, {}.'.format(\
                 len(self.paths_LR), len(self.paths_HR))
 
         # randomly scale list
         self.random_scale_list = [1, 0.9, 0.8, 0.7, 0.6, 0.5]
 
+        self.ration = 10 # 10 OST data and 1 DIV2K general data
 
     def __getitem__(self, index):
         HR_path, LR_path = None, None
         scale = self.opt['scale']
         HR_size = self.opt['HR_size']
 
         # get HR image
-        HR_path = self.paths_HR[index]
-        img_HR = util.read_img(self.HR_env, HR_path)
+        if random.choice(list(range(self.ration))) == 0: # read bg image
+            bg_index = random.randint(0, len(self.paths_HR_bg) - 1)
+            HR_path = self.paths_HR_bg[bg_index]
+            img_HR = util.read_img(self.HR_env_bg, HR_path)
+            seg = torch.FloatTensor(8, img_HR.shape[0], img_HR.shape[1]).fill_(0)
+            seg[0,:,:] = 1 # background
+        else:
+            HR_path = self.paths_HR[index]
+            img_HR = util.read_img(self.HR_env, HR_path)
+            seg = torch.load(HR_path.replace('/img/', '/bicseg/').replace('.png', '.pth'))
         # modcrop in validation phase
         if self.opt['phase'] != 'train':
             img_HR = util.modcrop(img_HR, 8)
 
-        # get segmentation probability map
-        seg = torch.load(HR_path.replace('/img/', '/bicseg/').replace('.png', '.pth'))
         seg = np.transpose(seg.numpy(), (1, 2, 0))
 
         # get LR image
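The self.ration = 10 field drives the new sampling branch: roughly one draw in ten swaps the OST training image for a general DIV2K background image whose 8-channel segmentation map is all background (channel 0 set to 1). A small standalone sketch of that logic, assuming the same channel ordering:

import random

import torch

ration = 10  # as in the commit: ~1 DIV2K background sample per 10 OST samples

def is_background_sample():
    # same test as the commit: P(True) = 1 / ration
    return random.choice(list(range(ration))) == 0

def background_seg(height, width, n_classes=8):
    # all-background probability map: channel 0 (background) set to 1
    seg = torch.FloatTensor(n_classes, height, width).fill_(0)
    seg[0, :, :] = 1
    return seg

hits = sum(is_background_sample() for _ in range(10000))
print('background fraction ~ {:.3f}'.format(hits / 10000))  # close to 0.1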
@@ -95,21 +104,21 @@ def _mod(n, random_scale, scale, thres):

             # category
             if 'building' in HR_path:
-                category = 0
-            elif 'plant' in HR_path:
                 category = 1
-            elif 'mountain' in HR_path:
+            elif 'plant' in HR_path:
                 category = 2
-            elif 'water' in HR_path:
+            elif 'mountain' in HR_path:
                 category = 3
-            elif 'sky' in HR_path:
+            elif 'water' in HR_path:
                 category = 4
-            elif 'grass' in HR_path:
+            elif 'sky' in HR_path:
                 category = 5
-            elif 'animal' in HR_path:
+            elif 'grass' in HR_path:
                 category = 6
+            elif 'animal' in HR_path:
+                category = 7
             else:
-                category = 7 # background
+                category = 0 # background
         else:
             category = -1 # during val, useless
 
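The category indices are shifted so that background becomes 0 and the seven named OST classes move to 1-7, which lines up with the segmentation channel set in the background branch (seg[0,:,:] = 1). An equivalent lookup, shown only to make the new ordering explicit (not the committed code):

CATEGORIES = ['building', 'plant', 'mountain', 'water', 'sky', 'grass', 'animal']

def category_from_path(HR_path, phase='train'):
    if phase != 'train':
        return -1  # during val the category is unused
    # first match wins, mirroring the elif chain above
    for idx, name in enumerate(CATEGORIES, start=1):
        if name in HR_path:
            return idx  # 1..7
    return 0  # background / everything else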
8 changes: 6 additions & 2 deletions codes/data/__init__.py
@@ -21,8 +21,12 @@ def create_dataset(dataset_opt):
         from data.LR_dataset import LRDataset as D
     elif mode == 'LRHR':
         from data.LRHR_dataset import LRHRDataset as D
-    elif mode == 'LRHRseg':
-        from data.LRHR_seg_dataset import LRHRSegDataset as D
+    # elif mode == 'LRHR_bg':
+    #     from data.LRHR_bg_dataset import LRHR_BG_Dataset as D
+    # elif mode == 'LRHRseg':
+    #     from data.LRHR_seg_dataset import LRHRSegDataset as D
+    elif mode == 'LRHRseg_bg':
+        from data.LRHR_seg_bg_dataset import LRHRSeg_BG_Dataset as D
     else:
         raise NotImplementedError("Dataset [%s] is not recognized." % mode)
     dataset = D(dataset_opt)
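With the factory updated, the new dataset is selected through mode 'LRHRseg_bg'. A hypothetical options dict is sketched below; the paths and numeric values are placeholders and the exact key set depends on the project's option files, but every key used here appears in the dataset code above.

from data import create_dataset

dataset_opt = {
    'mode': 'LRHRseg_bg',
    'phase': 'train',
    'data_type': 'img',                      # or 'lmdb'
    'dataroot_HR': '/path/to/OST/img',       # placeholder
    'dataroot_LR': '/path/to/OST/img_LR',    # placeholder
    'dataroot_HR_bg': '/path/to/DIV2K/img',  # placeholder for general background scenes
    'scale': 4,
    'HR_size': 96,
}

dataset = create_dataset(dataset_opt)  # returns an LRHRSeg_BG_Dataset instance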
