Upgrade TrackingConsumer with the new functionality. (#110)
* TrackingClient now inherits from PredictClient to DRY up the client code (see the sketch after the changed-file summary below).

* Leave the tempdir context as quickly as possible to save memory.

* skimage >= 0.17 raises an error on skimage.external imports.

* Add a stub test for _load_data.
willgraf committed May 27, 2020
1 parent c709b49 commit 4e268b2
Showing 4 changed files with 141 additions and 153 deletions.
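Editor's note on the first bullet: a minimal sketch of the DRY refactor, assuming hypothetical attribute names (the real PredictClient in redis_consumer may differ). TrackingClient keeps only what is tracking-specific and inherits the shared request plumbing:

    # Hypothetical sketch; only the TrackingClient and PredictClient
    # class names come from the commit, the rest is illustrative.
    class PredictClient(object):
        """Owns the connection details and shared request logic."""

        def __init__(self, host, redis_hash, model_name, model_version):
            self.host = host
            self.redis_hash = redis_hash
            self.model_name = model_name
            self.model_version = model_version


    class TrackingClient(PredictClient):
        """Adds only tracking-specific behavior on top of PredictClient."""

        def __init__(self, host, redis_hash, model_name, model_version,
                     progress_callback=None):
            super(TrackingClient, self).__init__(
                host, redis_hash, model_name, model_version)
            self.progress_callback = progress_callback

With the signature now inherited, the call site in the first hunk below switches to keyword arguments, which keeps the call unambiguous.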
202 changes: 100 additions & 102 deletions redis_consumer/consumers/tracking_consumer.py
@@ -75,10 +75,10 @@ def _get_model(self, redis_hash, hvalues):
         model, version = settings.TRACKING_MODEL.split(':')

         t = timeit.default_timer()
-        model = TrackingClient(hostname,
-                               redis_hash,
-                               model,
-                               int(version),
+        model = TrackingClient(host=hostname,
+                               redis_hash=redis_hash,
+                               model_name=model,
+                               model_version=int(version),
                                progress_callback=self._update_progress)

         self.logger.debug('Created the TrackingClient in %s seconds.',
@@ -92,7 +92,7 @@ def _get_tracker(self, redis_hash, hvalues, raw, segmented):
         # If not, the data must be normalized before being tracked.
         if settings.NORMALIZE_TRACKING:
             for frame in range(raw.shape[0]):
-                raw[frame, :, :, 0] = processing.normalize(raw[frame, :, :, 0])
+                raw[frame, ..., 0] = processing.normalize(raw[frame, ..., 0])

         features = {'appearance', 'distance', 'neighborhood', 'regionprop'}
         tracker = tracking.CellTracker(
@@ -138,7 +138,7 @@ def _load_data(self, redis_hash, subdir, fname):
         raw = utils.get_image(os.path.join(subdir, fname))

         # remove the last dimensions added by `get_image`
-        tiff_stack = np.squeeze(raw, -1)  # TODO: required? check the ndim?
+        tiff_stack = np.squeeze(raw, -1)
         if len(tiff_stack.shape) != 3:
             raise ValueError('This tiff file has shape {}, which is not 3 '
                              'dimensions. Tracking can only be done on images '
@@ -162,88 +162,82 @@ def _load_data(self, redis_hash, subdir, fname):
         num_frames = len(tiff_stack)
         hash_to_frame = {}
         remaining_hashes = set()
+        frames = {}

         self.logger.debug('Got tiffstack shape %s.', tiff_stack.shape)
         self.logger.debug('tiffstack num_frames %s.', num_frames)

-        with utils.get_tempdir() as tempdir:
-            uid = uuid.uuid4().hex
-            for (i, img) in enumerate(tiff_stack):
-                # make a file name for this frame
+        uid = uuid.uuid4().hex
+        for i, img in enumerate(tiff_stack):
+
+            with utils.get_tempdir() as tempdir:
+                # Save and upload the frame.
                 segment_fname = '{}-{}-tracking-frame-{}.tif'.format(
                     uid, hvalues.get('original_name'), i)
                 segment_local_path = os.path.join(tempdir, segment_fname)

                 # upload it
                 tifffile.imsave(segment_local_path, img)
                 upload_file_name, upload_file_url = self.storage.upload(
                     segment_local_path)

-                # prepare hvalues for this frame's hash
-                current_timestamp = self.get_current_timestamp()
-                frame_hvalues = {
-                    'identity_upload': self.name,
-                    'input_file_name': upload_file_name,
-                    'original_name': segment_fname,
-                    'model_name': model_name,
-                    'model_version': model_version,
-                    'postprocess_function': postprocess_function,
-                    'status': 'new',
-                    'created_at': current_timestamp,
-                    'updated_at': current_timestamp,
-                    'url': upload_file_url,
-                    'scale': scale,
-                    # 'label': str(label)
-                }
-
-                self.logger.debug('Setting %s', frame_hvalues)
-
-                # make a hash for this frame
-                segment_hash = '{prefix}:{file}:{hash}'.format(
-                    prefix=settings.SEGMENTATION_QUEUE,
-                    file=segment_fname,
-                    hash=uuid.uuid4().hex)
-
-                # push the hash to redis and the predict queue
-                self.redis.hmset(segment_hash, frame_hvalues)
-                self.redis.lpush(settings.SEGMENTATION_QUEUE, segment_hash)
-                self.logger.debug('Added new hash for segmentation `%s`: %s',
-                                  segment_hash, json.dumps(frame_hvalues,
-                                                           indent=4))
-                hash_to_frame[segment_hash] = i
-                remaining_hashes.add(segment_hash)
-
-            # pop hash, check it, and push it back if it's not done
-            # this checks the same hash over and over again, since set's
-            # pop is not random. This is fine, since we still need every
-            # hash to finish before doing anything.
-            frames = {}
-            while remaining_hashes:
-                finished_hashes = set()
-
-                self.logger.debug('Checking on hashes.')
-                for segment_hash in remaining_hashes:
-                    status = self.redis.hget(segment_hash, 'status')
-
-                    self.logger.debug('Hash %s has status %s',
-                                      segment_hash, status)
-
-                    if status == self.failed_status:
-                        reason = self.redis.hget(segment_hash, 'reason')
-                        raise RuntimeError(
-                            'Tracking failed during segmentation on frame {}.'
-                            '\nSegmentation Error: {}'.format(
-                                hash_to_frame[segment_hash], reason))
-
-                    if status == self.final_status:
-                        # if it's done, save the frame, as they'll be packed up
-                        # later
+            # prepare hvalues for this frame's hash
+            current_timestamp = self.get_current_timestamp()
+            frame_hvalues = {
+                'identity_upload': self.name,
+                'input_file_name': upload_file_name,
+                'original_name': segment_fname,
+                'model_name': model_name,
+                'model_version': model_version,
+                'postprocess_function': postprocess_function,
+                'status': 'new',
+                'created_at': current_timestamp,
+                'updated_at': current_timestamp,
+                'url': upload_file_url,
+                'scale': scale,
+                # 'label': str(label)
+            }
+
+            # make a hash for this frame
+            segment_hash = '{prefix}:{file}:{hash}'.format(
+                prefix=settings.SEGMENTATION_QUEUE,
+                file=segment_fname,
+                hash=uuid.uuid4().hex)
+
+            # push the hash to redis and the predict queue
+            self.redis.hmset(segment_hash, frame_hvalues)
+            self.redis.lpush(settings.SEGMENTATION_QUEUE, segment_hash)
+            self.logger.debug('Added new hash for segmentation `%s`: %s',
+                              segment_hash, json.dumps(frame_hvalues, indent=4))
+            hash_to_frame[segment_hash] = i
+            remaining_hashes.add(segment_hash)
+
+        # pop hash, check it, and push it back if it's not done
+        # this checks the same hash over and over again, since set's
+        # pop is not random. This is fine, since we still need every
+        # hash to finish before doing anything.
+        while remaining_hashes:
+            finished_hashes = set()
+            for segment_hash in remaining_hashes:
+                status = self.redis.hget(segment_hash, 'status')
+
+                self.logger.debug('Hash %s has status %s',
+                                  segment_hash, status)
+
+                if status == self.failed_status:
+                    # Segmentation failed, tracking cannot be finished.
+                    reason = self.redis.hget(segment_hash, 'reason')
+                    raise RuntimeError(
+                        'Tracking failed during segmentation on frame {}. '
+                        'Segmentation Error: {}'.format(
+                            hash_to_frame[segment_hash], reason))
+
+                if status == self.final_status:
+                    # Segmentation is finished, save and load the frame.
+                    with utils.get_tempdir() as tempdir:
                         frame_zip = self.storage.download(
                             self.redis.hget(segment_hash, 'output_file_name'),
                             tempdir)

-                        frame_files = list(utils.iter_image_archive(frame_zip,
-                                                                    tempdir))
+                        frame_files = list(utils.iter_image_archive(
+                            frame_zip, tempdir))

                         if len(frame_files) != 1:
                             raise RuntimeError(
@@ -255,16 +249,17 @@ def _load_data(self, redis_hash, subdir, fname):
                         frames[frame_idx] = utils.get_image(frame_files[0])
                         finished_hashes.add(segment_hash)

-                remaining_hashes -= finished_hashes
-                time.sleep(settings.INTERVAL)
+            remaining_hashes -= finished_hashes
+            time.sleep(settings.INTERVAL)

-        frames = [frames[i] for i in range(num_frames)]
+        labels = [frames[i] for i in range(num_frames)]

         # Cast y to int to avoid issues during fourier transform/drift correction
         return {'X': np.expand_dims(tiff_stack, axis=-1),
-                'y': np.array(frames, dtype='uint16')}
+                'y': np.array(labels, dtype='uint16')}

     def _consume(self, redis_hash):
+        start = timeit.default_timer()
         hvalues = self.redis.hgetall(redis_hash)
         self.logger.debug('Found `%s:*` hash to process "%s": %s',
                           self.queue, redis_hash, json.dumps(hvalues, indent=4))
@@ -278,36 +273,36 @@ def _consume(self, redis_hash):
         self.update_key(redis_hash, {
             'status': 'started',
             'progress': 0,
+            'identity_started': self.name,
         })

         with utils.get_tempdir() as tempdir:
             fname = self.storage.download(hvalues.get('input_file_name'),
                                           tempdir)
             data = self._load_data(redis_hash, tempdir, fname)

-            self.logger.debug('Got contents tracking file contents.')
-            self.logger.debug('X shape: %s', data['X'].shape)
-            self.logger.debug('y shape: %s', data['y'].shape)
-
-            # Correct for drift if enabled
-            if settings.DRIFT_CORRECT_ENABLED:
-                t = timeit.default_timer()
-                data['X'], data['y'] = processing.correct_drift(data['X'], data['y'])
-                self.logger.debug('Drift correction complete in %s seconds.',
-                                  timeit.default_timer() - t)
+        self.logger.debug('Got tracking file contents.')
+        self.logger.debug('X shape: %s', data['X'].shape)
+        self.logger.debug('y shape: %s', data['y'].shape)

-            # TODO Add support for rescaling in the tracker
-            tracker = self._get_tracker(redis_hash, hvalues,
-                                        data['X'], data['y'])
-            self.logger.debug('Trying to track...')
+        # Correct for drift if enabled
+        if settings.DRIFT_CORRECT_ENABLED:
+            t = timeit.default_timer()
+            data['X'], data['y'] = processing.correct_drift(data['X'], data['y'])
+            self.logger.debug('Drift correction complete in %s seconds.',
+                              timeit.default_timer() - t)

-            tracker.track_cells()
+        # TODO: Add support for rescaling in the tracker
+        tracker = self._get_tracker(redis_hash, hvalues, data['X'], data['y'])

-            self.logger.debug('Tracking done!')
+        self.logger.debug('Trying to track...')
+        tracker.track_cells()
+        self.logger.debug('Tracking done!')

-            # Post-process and save the output file
-            tracked_data = tracker.postprocess()
+        # Post-process and save the output file
+        tracked_data = tracker.postprocess()

+        with utils.get_tempdir() as tempdir:
             # Save lineage data to JSON file
             lineage_file = os.path.join(tempdir, 'lineage.json')
             with open(lineage_file, 'w') as fp:
@@ -329,10 +324,13 @@ def _consume(self, redis_hash):

             output_file_name, output_url = self.storage.upload(zip_file)

-        self.update_key(redis_hash, {
-            'status': self.final_status,
-            'output_url': output_url,
-            'output_file_name': output_file_name,
-            'finished_at': self.get_current_timestamp(),
-        })
+        t = timeit.default_timer() - start
+        self.update_key(redis_hash, {
+            'status': self.final_status,
+            'output_url': output_url,
+            'output_file_name': output_file_name,
+            'finished_at': self.get_current_timestamp(),
+            'total_jobs': 1,
+            'total_time': t,
+        })
         return self.final_status
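Editor's note: the _load_data changes above narrow each utils.get_tempdir() context to a single frame instead of wrapping the whole loop, so per-frame files are deleted as soon as they are uploaded or read. A minimal standard-library sketch of the pattern (the frame data and the upload step are stand-ins, not the consumer's code):

    import os
    import tempfile

    frames = [b'frame-0', b'frame-1', b'frame-2']  # stand-in frame data

    for i, data in enumerate(frames):
        # Each frame gets its own short-lived directory, deleted as soon
        # as the block exits, so disk usage stays bounded by one frame.
        with tempfile.TemporaryDirectory() as tempdir:
            path = os.path.join(tempdir, 'tracking-frame-{}.tif'.format(i))
            with open(path, 'wb') as f:
                f.write(data)
            # upload `path` here; the file is gone once the block exits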
18 changes: 16 additions & 2 deletions redis_consumer/consumers/tracking_consumer_test.py
@@ -31,11 +31,11 @@
 import os
 import copy

-import pytest
-
 import numpy as np
 from skimage.external import tifffile as tiff

+import pytest
+
 from redis_consumer import consumers
 from redis_consumer import settings
 from redis_consumer import utils
@@ -208,6 +208,20 @@ def test_is_valid_hash(self):
         assert consumer.is_valid_hash('track:1234567890:file.trk') is True
         assert consumer.is_valid_hash('track:1234567890:file.trks') is True

+    def test__load_data(self, tmpdir):
+        queue = 'track'
+        items = ['item%s' % x for x in range(1, 4)]
+        redis_hash = 'track:1234567890:file.trks'
+        storage = DummyStorage()
+        redis_client = DummyRedis(items)
+        consumer = consumers.TrackingConsumer(redis_client, storage, queue)
+
+        # test bad filetype
+        with pytest.raises(ValueError):
+            consumer._load_data(redis_hash, str(tmpdir), 'data.npz')
+
+        # TODO: test successful workflow
+
     def test__get_tracker(self):
         queue = 'track'
         items = ['item%s' % x for x in range(1, 4)]
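Editor's note: the polling hand-off that test__load_data stubs out follows a simple pattern: push one hash per frame onto the segmentation queue, then repeatedly re-check every pending hash until each reaches a terminal status. A self-contained sketch of that loop, with a dict of dicts standing in for Redis (all names here are illustrative, not the consumer's API):

    import time

    INTERVAL = 0.01  # stand-in for settings.INTERVAL

    def wait_for_hashes(redis_data, hashes, final='done', failed='failed'):
        """Poll every pending hash until each one succeeds or fails."""
        results = {}
        remaining = set(hashes)
        while remaining:
            finished = set()
            for h in remaining:
                status = redis_data[h].get('status')
                if status == failed:
                    raise RuntimeError(redis_data[h].get('reason'))
                if status == final:
                    results[h] = redis_data[h].get('output_file_name')
                    finished.add(h)
            remaining -= finished  # only unfinished hashes are re-checked
            time.sleep(INTERVAL)
        return results

    # Both hashes are already finished, so this returns immediately.
    fake_redis = {
        'seg:frame-0': {'status': 'done', 'output_file_name': 'frame-0.zip'},
        'seg:frame-1': {'status': 'done', 'output_file_name': 'frame-1.zip'},
    }
    print(wait_for_hashes(fake_redis, fake_redis))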