Skip to content

Commit

Permalink
Reduce memory usage
Browse files Browse the repository at this point in the history
  • Loading branch information
tsurumeso committed Mar 3, 2017
1 parent b4aa54a commit 005489e
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions lib/dataset_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import six
import numpy as np
import multiprocessing
from tempfile import NamedTemporaryFile

from lib import iproc
from lib.pairwise_transform import pairwise_transform
Expand All @@ -14,7 +15,7 @@ def __init__(self, datalist, config, repeat=False):
self.config = config
self.repeat = repeat
self.worker = None
self.data_queue = None
self.name_queue = None
self.finalized = None
self.dataset = None
self.running = False
Expand All @@ -27,30 +28,38 @@ def __del__(self):
# Stop the background worker, if one is running, and dispose of its pending
# result.
# This span is a rendered diff without +/- markers: the `data_queue` line and
# `del garbage` are the pre-commit code; the `name_queue` line and
# `os.remove(garbage)` are their post-commit replacements.
def finalize(self):
if self.running:
# Set the event that _worker polls between images so it stops early.
self.finalized.set()
garbage = self.data_queue.get(timeout=0.5)
# Post-commit: the queue carries the cache file name, not the arrays.
# NOTE(review): get(timeout=0.5) raises queue.Empty if the worker has not
# put a name on the queue within 0.5 s -- confirm that is always enough.
garbage = self.name_queue.get(timeout=0.5)
self.worker.join()
del garbage
# Post-commit: delete the cache file the worker wrote, since nobody will
# load it.
os.remove(garbage)

# Set the `_switch` flag; get() checks it (together with `running`) before
# fetching a new dataset from the worker.
def reload_switch(self):
self._switch = True

# Create the queue and stop event, then start a daemon process running
# _worker.
# Rendered diff without +/- markers: each `data_queue` line is the pre-commit
# code and the `name_queue` line that follows it is the replacement.
def _init_process(self):
self.data_queue = multiprocessing.Queue()
# Post-commit: the queue transports a cache file name instead of the arrays.
self.name_queue = multiprocessing.Queue()
# Event used to ask the worker to stop (set in finalize, polled in _worker).
self.finalized = multiprocessing.Event()
args = [self.datalist, self.data_queue, self.config, self.finalized]
args = [self.datalist, self.name_queue, self.config, self.finalized]
self.worker = multiprocessing.Process(target=_worker, args=args)
self.worker.daemon = True
self.worker.start()
# Mark that a worker is active so get() and finalize() know to collect from it.
self.running = True

# Return the current dataset, first swapping in a freshly generated one when a
# worker is running and a reload was requested through `_switch`.
# Rendered diff without +/- markers: the `data_queue` line is the pre-commit
# code; the lines from `cache_name = ...` through `six.print_('done')`
# replace it.
def get(self):
if self.running and self._switch:
self.dataset = self.data_queue.get()
# Post-commit: the worker sends only the name of the cache file it wrote.
cache_name = self.name_queue.get()
self.worker.join()
six.print_(' * loading dataset from cache...',
end=' ', flush=True)
# The archive stores the arrays under the keys 'x' and 'y' (see np.savez in
# _worker).
with np.load(cache_name) as cached_arr:
self.dataset = cached_arr['x'], cached_arr['y']
# Remove the cache file after the arrays have been read from it.
os.remove(cache_name)
six.print_('done')

self.running = False
self._switch = False
# When repeat is enabled, start a new worker for the next dataset.
if self.repeat:
self._init_process()

return self.dataset

def save_images(self, dir):
Expand All @@ -70,18 +79,22 @@ def save_images(self, dir):
iy.save(os.path.join(dir, header + '_y.png'))


# Worker-process entry point: build the x/y patch arrays for every image in
# `datalist` and hand the result back through a queue.
# Rendered diff without +/- markers: the first `def` line,
# `out_queue.put([x, y])` and the first `del x, y` are the pre-commit code;
# the second `def` line and the NamedTemporaryFile block with its trailing
# `del x, y` are the replacements.
#
# datalist:   sequence of image paths passed to iproc.read_image_rgb_uint8.
# name_queue: queue that receives the cache file name (post-commit).
# cfg:        config object; this function reads patches, ch, insize, crop_size.
# finalized:  event checked before each image; when set, the loop stops early.
def _worker(datalist, out_queue, cfg, finalized):
def _worker(datalist, name_queue, cfg, finalized):
# Total number of patches: cfg.patches per image.
sample_size = cfg.patches * len(datalist)
x = np.zeros(
(sample_size, cfg.ch, cfg.insize, cfg.insize), dtype=np.uint8)
y = np.zeros(
(sample_size, cfg.ch, cfg.crop_size, cfg.crop_size), dtype=np.uint8)

for i in six.moves.range(len(datalist)):
# Stop producing as soon as the parent has asked for shutdown.
# NOTE(review): after an early stop the remaining rows of x and y stay zero.
if finalized.is_set():
break
img = iproc.read_image_rgb_uint8(datalist[i])
xc_batch, yc_batch = pairwise_transform(img, cfg)
# Image i fills rows [cfg.patches * i, cfg.patches * (i + 1)) of both arrays.
x[cfg.patches * i:cfg.patches * (i + 1)] = xc_batch[:]
y[cfg.patches * i:cfg.patches * (i + 1)] = yc_batch[:]
out_queue.put([x, y])
del x, y

# Post-commit: write both arrays to a temporary file and send only its name.
# NOTE(review): delete=False leaves the file on disk; the consumers in this
# file (get and finalize) remove it with os.remove.
with NamedTemporaryFile(delete=False) as cache:
np.savez(cache, x=x, y=y)
name_queue.put(cache.name)
del x, y

0 comments on commit 005489e

Please sign in to comment.