Skip to content

Commit

Permalink
update extract_subimgs_single.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed Sep 6, 2018
1 parent c4d8621 commit 29ab5e1
Showing 1 changed file with 66 additions and 59 deletions.
125 changes: 66 additions & 59 deletions codes/scripts/extract_subimgs_single.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,87 @@
import os
import os.path
import sys
from multiprocessing import Pool
import time
import numpy as np
import cv2
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.progress_bar import ProgressBar


def main():
# NOTE(review): this page is a rendered commit diff whose +/- markers and
# indentation were lost, so lines of main() from before and after the commit
# are interleaved below. The "older"/"newer" labels in these comments are
# inferred from which names each line uses -- confirm against the repository.
#
# Older-looking settings: GT_dir is walked further down and save_GT_dir is
# handed to the two-argument worker call.
GT_dir = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800'
save_GT_dir = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub'
# NOTE(review): "imags" is a typo for "images". Pool is imported from
# multiprocessing, so the tasks run in a pool of processes rather than threads.
"""A multi-thread tool to crop sub imags."""
# Newer-looking settings: input_folder is walked to build img_list and
# save_folder is where worker writes the cropped PNG files.
input_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800'
save_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub'
# Size passed to Pool(...) below.
n_thread = 20
# Crop geometry forwarded to worker: crop_sz is the side length of each square
# crop, step is the stride between crop origins, and thres_sz is the leftover
# margin above which worker adds one extra crop flush with the image edge.
crop_sz = 480
step = 240
thres_sz = 48
compression_level = 3 # 3 is the default value in cv2
# CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
# compression time. If read raw images during training, use 0 for faster IO speed.

print('Parent process %s.' % os.getpid())
# start is only consumed by the elapsed-time print near the end of this span.
start = time.time()
# Create the output folder; if it already exists, report that and exit with
# status 1 instead of writing into it.
if not os.path.exists(save_folder):
os.makedirs(save_folder)
print('mkdir [{:s}] ...'.format(save_folder))
else:
print('Folder [{:s}] already exists. Exit...'.format(save_folder))
sys.exit(1)

# Older-looking flow: pool `p` receives one task per chunk of file paths.
p = Pool(n_thread)
# read all files to a list
all_files = []
for root, _, fnames in sorted(os.walk(GT_dir)):
full_path = [os.path.join(root, x) for x in fnames]
all_files.extend(full_path)
# cut into subtasks
# chunkify returns n strided slices of lst: lst[0::n], lst[1::n], ...
def chunkify(lst, n): # for non-continuous chunks
return [lst[i::n] for i in range(n)]
# Newer-looking flow: collect the path of every file found under input_folder.
img_list = []
for root, _, file_list in sorted(os.walk(input_folder)):
path = [os.path.join(root, x) for x in file_list] # assume only images in the input_folder
img_list.extend(path)

sub_lists = chunkify(all_files, n_thread)
# call workers
for i in range(n_thread):
p.apply_async(worker, args=(sub_lists[i], save_GT_dir))
print('Waiting for all subprocesses done...')
p.close()
p.join()
end = time.time()
print('All subprocesses done. Using time {} sec.'.format(end - start))
# Callback handed to apply_async below; it forwards the value returned by
# worker to the progress bar.
def update(arg):
pbar.update(arg)

# ProgressBar comes from utils.progress_bar and is constructed with the number
# of collected paths -- presumably the total to count up to (confirm there).
pbar = ProgressBar(len(img_list))

# Newer-looking flow: submit one asynchronous task per image path, then close
# the pool and wait for all tasks to finish.
# NOTE(review): no error_callback is passed and no result is ever fetched, so
# an exception raised inside worker would not be reported here -- confirm
# whether that is intended.
pool = Pool(n_thread)
for path in img_list:
pool.apply_async(worker,
args=(path, save_folder, crop_sz, step, thres_sz, compression_level),
callback=update)
pool.close()
pool.join()
print('All subprocesses done.')

def worker(GT_paths, save_GT_dir):
# NOTE(review): this page is a rendered commit diff whose +/- markers and
# indentation were lost, so two revisions of worker are interleaved below.
# The header above takes a list of paths plus an output folder and hard-codes
# the crop geometry; the second header further down takes a single path and
# receives the geometry and PNG compression level as arguments. Which lines
# belong to which revision is inferred from the names they use -- confirm
# against the repository.
crop_sz = 480
step = 240
thres_sz = 48

# List-based revision: loop over every path in the chunk and read each image.
for GT_path in GT_paths:
base_name = os.path.basename(GT_path)
print(base_name, os.getpid())
img_GT = cv2.imread(GT_path, cv2.IMREAD_UNCHANGED)
# Single-path revision: crop one image into crop_sz x crop_sz sub-images,
# write them into save_folder, and return a status string.
# NOTE(review): the result of cv2.imread is used without checking that the
# read succeeded; the caller submits every file found under the input folder,
# so confirm how unreadable or non-image files should be handled.
def worker(path, save_folder, crop_sz, step, thres_sz, compression_level):
img_name = os.path.basename(path)
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)

# n_channels holds the number of array dimensions (len of shape), not the
# channel count: 2 unpacks as (h, w), 3 as (h, w, c); anything else raises
# ValueError.
n_channels = len(img_GT.shape)
if n_channels == 2:
h, w = img_GT.shape
elif n_channels == 3:
h, w, c = img_GT.shape
else:
raise ValueError('Wrong image shape - {}'.format(n_channels))
# Same shape check, written against the single-path revision's `img`.
n_channels = len(img.shape)
if n_channels == 2:
h, w = img.shape
elif n_channels == 3:
h, w, c = img.shape
else:
raise ValueError('Wrong image shape - {}'.format(n_channels))

# Crop origins along each axis: 0, step, 2*step, ... while a full crop_sz
# window still fits. If more than thres_sz pixels remain past the last window,
# one extra origin is appended so the final window ends at the image edge.
# NOTE(review): when h or w is smaller than crop_sz the arange is empty and
# the [-1] lookup raises IndexError.
h_space = np.arange(0, h - crop_sz + 1, step)
if h - (h_space[-1] + crop_sz) > thres_sz:
h_space = np.append(h_space, h - crop_sz)
w_space = np.arange(0, w - crop_sz + 1, step)
if w - (w_space[-1] + crop_sz) > thres_sz:
w_space = np.append(w_space, w - crop_sz)
# index is a running crop counter, incremented before use so numbering starts
# at 1. Note that x walks h_space (row offsets) and y walks w_space (column
# offsets).
index = 0
for x in h_space:
for y in w_space:
index += 1
if n_channels == 2:
crop_img = img_GT[x:x + crop_sz, y:y + crop_sz]
else:
crop_img = img_GT[x:x + crop_sz, y:y + crop_sz, :]
# Same crop-origin computation, repeated for the other revision.
h_space = np.arange(0, h - crop_sz + 1, step)
if h - (h_space[-1] + crop_sz) > thres_sz:
h_space = np.append(h_space, h - crop_sz)
w_space = np.arange(0, w - crop_sz + 1, step)
if w - (w_space[-1] + crop_sz) > thres_sz:
w_space = np.append(w_space, w - crop_sz)

# The slice is passed through np.ascontiguousarray before being written.
crop_img = np.ascontiguousarray(crop_img)
index_str = '{:03d}'.format(index)
# var = np.var(crop_img / 255)
# if var > 0.008:
# print(index_str, var)
# List-based revision writes each crop with PNG compression level 0.
# NOTE(review): the output name is built with .replace('.png', ...); a file
# name that does not contain '.png' is left unchanged, so every crop of that
# image would be written to the same path.
cv2.imwrite(os.path.join(save_GT_dir, base_name.replace('.png', \
'_s'+index_str+'.png')), crop_img, [cv2.IMWRITE_PNG_COMPRESSION, 0])
# Single-path revision of the crop loop; same 1-based counter and x/y roles.
index = 0
for x in h_space:
for y in w_space:
index += 1
if n_channels == 2:
crop_img = img[x:x + crop_sz, y:y + crop_sz]
else:
crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
crop_img = np.ascontiguousarray(crop_img)
# var = np.var(crop_img / 255)
# if var > 0.008:
# print(img_name, index_str, var)
# Writes the crop as <name>_sNNN.png using the caller-supplied compression
# level; the same '.png' replacement caveat applies to img_name here.
cv2.imwrite(
os.path.join(save_folder, img_name.replace('.png', '_s{:03d}.png'.format(index))),
crop_img, [cv2.IMWRITE_PNG_COMPRESSION, compression_level])
# The returned string is what the apply_async callback in main receives.
return 'Processing {:s} ...'.format(img_name)


if __name__ == '__main__':
Expand Down

0 comments on commit 29ab5e1

Please sign in to comment.