
Commit

Merge b6d73be into 586575e
willgraf committed Sep 4, 2019
2 parents 586575e + b6d73be commit 595fc5c
Showing 2 changed files with 22 additions and 131 deletions.
120 changes: 12 additions & 108 deletions redis_consumer/consumers.py
@@ -477,104 +477,6 @@ def _get_processing_function(self, process_type, function_name):
                             % (name, cat))
        return settings.PROCESSING_FUNCTIONS[cat][name]

-    # def _process(self, image, key, process_type, streaming=False):
-    #     """Apply each processing function to image.
-    #
-    #     Args:
-    #         image: numpy array of image data
-    #         key: function to apply to image
-    #         process_type: pre or post processing
-    #         streaming: boolean. if True, streams data in multiple requests
-    #
-    #     Returns:
-    #         list of processed image data
-    #     """
-    #     # Squeeze out batch dimension if unnecessary
-    #     if image.shape[0] == 1:
-    #         image = np.squeeze(image, axis=0)
-    #
-    #     if not key:
-    #         return image
-    #
-    #     self.logger.debug('Starting %s %s-processing image of shape %s',
-    #                       key, process_type, image.shape)
-    #
-    #     retrying = True
-    #     count = 0
-    #     start = timeit.default_timer()
-    #     while retrying:
-    #         try:
-    #             key = str(key).lower()
-    #             process_type = str(process_type).lower()
-    #             hostname = '{}:{}'.format(settings.DP_HOST, settings.DP_PORT)
-    #             client = ProcessClient(hostname, process_type, key)
-    #
-    #             if streaming:
-    #                 dtype = 'DT_STRING'
-    #             else:
-    #                 dtype = settings.TF_TENSOR_DTYPE
-    #
-    #             req_data = [{'in_tensor_name': settings.TF_TENSOR_NAME,
-    #                          'in_tensor_dtype': dtype,
-    #                          'data': np.expand_dims(image, axis=0)}]
-    #
-    #             if streaming:
-    #                 results = client.stream_process(req_data, settings.GRPC_TIMEOUT)
-    #             else:
-    #                 results = client.process(req_data, settings.GRPC_TIMEOUT)
-    #
-    #             finished = timeit.default_timer() - start
-    #
-    #             self.update_key(self._redis_hash, {
-    #                 '{}process_time'.format(process_type): finished
-    #             })
-    #
-    #             self.logger.debug('%s-processed key %s (model %s:%s, '
-    #                               'preprocessing: %s, postprocessing: %s)'
-    #                               ' (%s retries) in %s seconds.',
-    #                               process_type.capitalize(), self._redis_hash,
-    #                               self._redis_values.get('model_name'),
-    #                               self._redis_values.get('model_version'),
-    #                               self._redis_values.get('preprocess_function'),
-    #                               self._redis_values.get('postprocess_function'),
-    #                               count, finished)
-    #
-    #             results = results['results']
-    #             # Again, squeeze out batch dimension if unnecessary
-    #             if results.shape[0] == 1:
-    #                 results = np.squeeze(results, axis=0)
-    #
-    #             retrying = False
-    #             return results
-    #         except grpc.RpcError as err:
-    #             # pylint: disable=E1101
-    #             if err.code() in settings.GRPC_RETRY_STATUSES:
-    #                 count += 1
-    #                 temp_status = 'retry-processing - {} - {}'.format(
-    #                     count, err.code().name)
-    #                 self.update_key(self._redis_hash, {
-    #                     'status': temp_status,
-    #                     'process_retries': count,
-    #                 })
-    #                 self.logger.warning('%sException `%s: %s` during %s '
-    #                                     '%s-processing request. Waiting %s '
-    #                                     'seconds before retrying.',
-    #                                     type(err).__name__, err.code().name,
-    #                                     err.details(), key, process_type,
-    #                                     settings.GRPC_BACKOFF)
-    #                 self.logger.debug('Waiting for %s seconds before retrying',
-    #                                   settings.GRPC_BACKOFF)
-    #                 time.sleep(settings.GRPC_BACKOFF)  # sleep before retry
-    #                 retrying = True  # Unnecessary but explicit
-    #             else:
-    #                 retrying = False
-    #                 raise err
-    #         except Exception as err:
-    #             retrying = False
-    #             self.logger.error('Encountered %s during %s %s-processing: %s',
-    #                               type(err).__name__, key, process_type, err)
-    #             raise err

    def process(self, image, key, process_type):
        start = timeit.default_timer()
        if not key:
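
Note: the block removed above is the long commented-out _process method, which wrapped each data-processing gRPC call in a retry loop with a fixed backoff. A minimal sketch of that retry-with-backoff pattern, where make_request, retry_statuses, and backoff are hypothetical stand-ins for the consumer's client call and settings:

    import logging
    import time

    import grpc

    logger = logging.getLogger(__name__)

    def call_with_retry(make_request, retry_statuses, backoff):
        """Call make_request(), sleeping and retrying on retryable gRPC codes."""
        count = 0
        while True:
            try:
                return make_request()
            except grpc.RpcError as err:
                # pylint: disable=E1101
                if err.code() not in retry_statuses:
                    raise
                count += 1
                logger.warning('%s: %s. Waiting %s seconds before retry %s.',
                               err.code().name, err.details(), backoff, count)
                time.sleep(backoff)
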
@@ -659,11 +561,13 @@ def _consume(self, redis_hash):
            'identity_started': self.hostname,
        })

-        model_name = hvals.get('model_name')
-        model_version = hvals.get('model_version')
        cuts = hvals.get('cuts', '0')
        field = hvals.get('field_size', '61')

+        # Overridden with LABEL_DETECT_ENABLED
+        model_name = hvals.get('model_name')
+        model_version = hvals.get('model_version')
+
        with utils.get_tempdir() as tempdir:
            _ = timeit.default_timer()
            fname = self.storage.download(hvals.get('input_file_name'), tempdir)
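
Note: moving model_name and model_version below the new comment signals that, when settings.LABEL_DETECT_ENABLED is set, the detected label later overrides the model named in the Redis hash. A hypothetical sketch of that override; the LABEL_MODELS mapping and its values are illustrative, not the project's actual settings:

    # Illustrative only: map detected label types to 'name:version' pairs.
    LABEL_MODELS = {0: 'NuclearModel:0', 1: 'PhaseModel:0'}

    def pick_model_for_label(label, default='GeneralModel:0'):
        """Hypothetical override of the hash's model by detected label."""
        return LABEL_MODELS.get(int(label), default).split(':')
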
@@ -678,8 +582,8 @@ def _consume(self, redis_hash):
            })

            # Calculate scale of image and rescale
-            scale = hvals.get('scale')
-            if scale is None:
+            scale = hvals.get('scale', '')
+            if not scale:
                # Detect scale of image
                scale = self.detect_scale(image)
                self.logger.debug('Image scale detected: %s', scale)
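
Note: the old 'is None' test missed jobs submitted with an empty 'scale' field, since Redis hash values arrive as strings; 'if not scale:' catches both missing and blank values before falling back to detection. A sketch of the get-or-detect pattern (names are illustrative):

    def get_or_detect(hvals, key, detect):
        """Return a non-empty hash value, else compute one."""
        value = hvals.get(key, '')
        if not value:  # missing key or empty string
            value = detect()
        return value

    # e.g. scale = get_or_detect(hvals, 'scale', lambda: consumer.detect_scale(image))
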
@@ -692,8 +596,8 @@ def _consume(self, redis_hash):

            if settings.LABEL_DETECT_ENABLED:
                # Detect image label type
-                label = hvals.get('label')
-                if label is None:
+                label = hvals.get('label', '')
+                if not label:
                    label = self.detect_label(image)
                    self.logger.debug('Image label detected: %s', label)
                    self.update_key(redis_hash, {'label': str(label)})
@@ -720,7 +624,7 @@ def _consume(self, redis_hash):
            self.update_key(redis_hash, {'status': 'post-processing'})

            if settings.LABEL_DETECT_ENABLED:
-                post_funcs = utils._pick_postprocess(label)
+                post_funcs = utils._pick_postprocess(label).split(',')
            else:
                post_funcs = hvals.get('postprocess_function', '').split(',')
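
Note: with the added .split(','), both branches now yield a list of function names, so utils._pick_postprocess is assumed to return a comma-separated string in the same format as the 'postprocess_function' hash field. A tiny sketch with hypothetical values:

    # Hypothetical return value of utils._pick_postprocess(label).
    post_funcs = 'mibi,watershed'.split(',')
    assert post_funcs == ['mibi', 'watershed']
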

@@ -1100,9 +1004,9 @@ def _get_model(self, redis_hash, hvalues):
        hostname = '{}:{}'.format(settings.TF_HOST, settings.TF_PORT)

        # Pick model based on redis or default setting
-        model = hvalues.get('model_name')
-        version = hvalues.get('model_version')
-        if (model is None) or (version is None):
+        model = hvalues.get('model_name', '')
+        version = hvalues.get('model_version', '')
+        if not model or not version:
            model, version = settings.TRACKING_MODEL.split(':')

        t = timeit.default_timer()
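
Note: settings.TRACKING_MODEL is assumed to hold a 'name:version' string, so a blank or missing field in the Redis hash now falls back cleanly to the default pair. A minimal sketch of the fallback (the default value is hypothetical):

    def pick_model(hvalues, default='TrackingModel:0'):  # hypothetical default
        """Prefer the model named in the Redis hash, else the settings default."""
        model = hvalues.get('model_name', '')
        version = hvalues.get('model_version', '')
        if not model or not version:
            model, version = default.split(':')
        return model, version
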
33 changes: 10 additions & 23 deletions redis_consumer/tracking.py
@@ -97,7 +97,7 @@ def __init__(self,
        # Initialize tracks
        self._initialize_tracks()

-        self.logger.info("Tracks initialized and cleaned up.")
+        self.logger.info('Tracks initialized and cleaned up.')

    def _clean_up_annotations(self):
        """Relabels every frame in the label matrix.
@@ -296,7 +296,7 @@ def _get_cost_matrix(self, frame):
                inputs[feature_name][0].append(track_feature)
                inputs[feature_name][1].append(frame_feature)

-        print('Got features in {}s'.format(timeit.default_timer() - t))
+        self.logger.info('Got features in %s seconds.', timeit.default_timer() - t)

        if input_pairs == []:
            # if the frame is empty
@@ -315,9 +315,9 @@ def _get_cost_matrix(self, frame):

        predictions = self.model.predict(model_input)

-        self.logger.info("assignment_matrix.shape: %s",
+        self.logger.info('assignment_matrix.shape: %s',
                         assignment_matrix.shape)
-        self.logger.info("predictions.shape: %s", predictions.shape)
+        self.logger.info('predictions.shape: %s', predictions.shape)
        for i, (track, cell) in enumerate(input_pairs):
            assignment_matrix[track, cell] = 1 - predictions[i, 1]
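
Note: each input pair is a candidate (track, detection) match, and column 1 of the model's softmax output is read as the probability that the pair is the same cell, so the assignment cost is 1 - P(same). A sketch of how such a cost matrix feeds a linear assignment step, using SciPy's solver purely for illustration (the consumer's own _run_lap may differ):

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    def build_cost_matrix(n_tracks, n_cells, input_pairs, predictions):
        """Fill pairwise costs; unscored pairs keep the maximum cost of 1."""
        cost = np.ones((n_tracks, n_cells))
        for i, (track, cell) in enumerate(input_pairs):
            cost[track, cell] = 1 - predictions[i, 1]  # column 1: P(same cell)
        return cost

    # Hypothetical example: 2 tracks, 2 detections.
    pairs = [(0, 0), (0, 1), (1, 0), (1, 1)]
    preds = np.array([[0.1, 0.9], [0.8, 0.2], [0.7, 0.3], [0.05, 0.95]])
    rows, cols = linear_sum_assignment(build_cost_matrix(2, 2, pairs, preds))
    # rows, cols -> track 0 matched to cell 0, track 1 matched to cell 1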

@@ -608,7 +608,6 @@ def _fetch_track_neighborhoods(self, before_frame):
        return track_neighborhoods

    def _sub_area(self, X_frame, y_frame, cell_label, num_channels):
-        t = timeit.default_timer()
        true_size = self.neighborhood_true_size
        pads = ((true_size, true_size),
                (true_size, true_size),
@@ -638,7 +637,6 @@ def _sub_area(self, X_frame, y_frame, cell_label, num_channels):

        # X_reduced /= np.amax(X_reduced)
        X_reduced = np.expand_dims(X_reduced, axis=self.channel_axis)
-        print('_sub_area finished in {}s'.format(timeit.default_timer() - t))
        return X_reduced

    def _get_features(self, X, y, frames, labels):
@@ -740,26 +738,17 @@ def _track_cells(self):
"""Tracks all of the cells in every frame.
"""
for frame in range(1, self.x.shape[0]):
self.logger.info('Tracking frame ' + str(frame))

t_whole = timeit.default_timer() # TODEL
t = timeit.default_timer() # TODEL
t = timeit.default_timer()
self.logger.info('Tracking frame %s', frame)

cost_matrix, predictions = self._get_cost_matrix(frame)

print('Time to get_cost_matrix: ', timeit.default_timer() - t) # TODEL
t = timeit.default_timer() # TODEL

assignments = self._run_lap(cost_matrix)

print('Time to run lap: ', timeit.default_timer() - t) # TODEL
t = timeit.default_timer() # TODEL

self._update_tracks(assignments, frame, predictions)
self.model.progress(frame / self.x.shape[0])

print('Time to update tracks: ', timeit.default_timer() - t) # TODEL
print('Time to track one frame: ', timeit.default_timer() - t_whole) # TODEL
self.logger.info('Tracked frame %s in %s seconds.',
frame, timeit.default_timer() - t)
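
Note: the per-step print timers marked TODEL collapse into a single logger call that times the whole frame. A sketch of the consolidated pattern, assuming only the standard library:

    import logging
    import timeit

    logger = logging.getLogger(__name__)

    def run_timed_frame(frame, step):
        """Run one frame of tracking work and log its wall-clock duration."""
        t = timeit.default_timer()
        logger.info('Tracking frame %s', frame)
        result = step(frame)
        logger.info('Tracked frame %s in %s seconds.',
                    frame, timeit.default_timer() - t)
        return result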

    def _track_review_dict(self):
        def process(key, track_item):
@@ -836,7 +825,6 @@ def postprocess(self, filename=None, time_excl=9):

        # If FPs exist, use the results to correct
        while len(FPs_sorted) != 0:
-
            lineage, tracked = self._remove_false_pos(lineage, tracked, FPs_sorted[0])
            G = self._track_to_graph(lineage)
            FPs = self._flag_false_pos(G, time_excl)
@@ -968,9 +956,8 @@ def _flag_false_pos(self, G, time_excl):
                'false positive': node,
                'neighbors': list(G.neighbors(node)),
                'connected lineages': set([int(n.split('_')[0])
-                                          for n in nx.node_connected_component(G, n)])
+                                           for n in nx.node_connected_component(G, n)])
                }
-
        return D

    def _review_candidate_nodes(self, FPs_candidates):
@@ -1084,6 +1071,6 @@ def _remove_false_pos(self, lineage, tracked, FP_info):
            del lineage[label_to_remove]

        else:
-            print('Error: More than 2 neighbor nodes')
+            self.logger.error('Error: More than 2 neighbor nodes')

        return lineage, tracked
