From 8f75cfae89bb170605cafb82b13868579ef3ba08 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Fri, 23 Aug 2019 17:09:36 -0700 Subject: [PATCH 1/5] update tracking.py from deepcell --- redis_consumer/tracking.py | 417 ++++++++++++++++++++++++++++++------- 1 file changed, 344 insertions(+), 73 deletions(-) diff --git a/redis_consumer/tracking.py b/redis_consumer/tracking.py index 3ca5c833..b3287d4b 100644 --- a/redis_consumer/tracking.py +++ b/redis_consumer/tracking.py @@ -30,14 +30,17 @@ import json import pathlib import tarfile -import zipfile import tempfile +import timeit +import zipfile +import cv2 +import networkx as nx import numpy as np +import pandas as pd from scipy.optimize import linear_sum_assignment from skimage.measure import regionprops from skimage.transform import resize -from pandas import DataFrame class cell_tracker(): @@ -122,7 +125,6 @@ def _create_new_track(self, frame, old_label): """ This function creates new tracks """ - import traceback new_track = len(self.tracks.keys()) new_label = new_track + 1 @@ -198,6 +200,7 @@ def _get_cost_matrix(self, frame): """Uses the model to create the cost matrix for assigning the cells in frame to existing tracks. """ + t = timeit.default_timer() # Initialize matrices number_of_tracks = np.int(len(self.tracks.keys())) @@ -241,39 +244,59 @@ def _get_cost_matrix(self, frame): # Compute assignment matrix - Initialize and get model inputs # Fill the input matrices for track in range(number_of_tracks): + + # we need to get the future frame for the track we are comparing to + try: + track_label = self.tracks[track]['label'] + track_frame_features = self._get_features( + self.x, self.y_tracked, [frame - 1], [track_label]) + except: + # `track_label` might not exist in `frame - 1` + # if this happens, default to the cell's neighborhood + track_frame_features = dict() + for cell in range(number_of_cells): - feature_ok = True feature_vals = {} + + # If distance is a feature it is used to exclude + # impossible pairings from the get_feature call + if 'distance' in self.features: + _, _, is_cell_in_range = self._compute_feature( + 'distance', + track_features['distance'][track], + frame_features['distance'][cell]) + else: + # not worried about distance, just calculate features + is_cell_in_range = True + + if not is_cell_in_range: + # Cell is outside of range, set cost to max and move on + assignment_matrix[track, cell] = 1 + continue + + # The cell is within range so we should add + # all the information for all features for feature_name in self.features: - track_feature, frame_feature, ok = self._compute_feature( + + track_feature, frame_feature, _ = self._compute_feature( feature_name, track_features[feature_name][track], frame_features[feature_name][cell]) # this condition changes `frame_feature` if feature_name == 'neighborhood': - # we need to get the future frame for the track we are comparing to - track_label = self.tracks[track]['label'] - try: - track_frame_features = self._get_features( - self.x, self.y_tracked, [frame - 1], [track_label]) - frame_feature = track_frame_features['~future area'] - except: - # `track_label` might not exist in `frame - 1` - # if this happens, default to the cell's neighborhood - pass - - if ok: - feature_vals[feature_name] = (track_feature, frame_feature) - else: - feature_ok = False - assignment_matrix[track, cell] = 1 - - if feature_ok: - input_pairs.append((track, cell)) - for feature_name, (track_feature, frame_feature) in feature_vals.items(): - inputs[feature_name][0].append(track_feature) - inputs[feature_name][1].append(frame_feature) + # This segment of the loop should not be run + # if the disance check fails + frame_feature = track_frame_features.get('~future area', frame_feature) + + feature_vals[feature_name] = (track_feature, frame_feature) + + input_pairs.append((track, cell)) + for feature_name, (track_feature, frame_feature) in feature_vals.items(): + inputs[feature_name][0].append(track_feature) + inputs[feature_name][1].append(frame_feature) + + print('Got features in {}s'.format(timeit.default_timer() - t)) if input_pairs == []: # if the frame is empty @@ -585,30 +608,37 @@ def _fetch_track_neighborhoods(self, before_frame): return track_neighborhoods def _sub_area(self, X_frame, y_frame, cell_label, num_channels): - shape = (2 * self.neighborhood_scale_size + 1, - 2 * self.neighborhood_scale_size + 1, - 1) - neighborhood = np.zeros(shape, dtype='float32') - - pads = ((self.neighborhood_true_size, self.neighborhood_true_size), - (self.neighborhood_true_size, self.neighborhood_true_size), + t = timeit.default_timer() + true_size = self.neighborhood_true_size + pads = ((true_size, true_size), + (true_size, true_size), (0, 0)) + X_padded = np.pad(X_frame, pads, mode='constant', constant_values=0) y_padded = np.pad(y_frame, pads, mode='constant', constant_values=0) + props = regionprops(np.squeeze(np.int32(y_padded == cell_label))) + center_x, center_y = props[0].centroid center_x, center_y = np.int(center_x), np.int(center_y) - X_reduced = X_padded[ - center_x - self.neighborhood_true_size:center_x + self.neighborhood_true_size, - center_y - self.neighborhood_true_size:center_y + self.neighborhood_true_size, :] - # resize to neighborhood_scale_size + X_reduced = X_padded[center_x - true_size:center_x + true_size, + center_y - true_size:center_y + true_size] + + # resize to neighborhood_scale_size with skimage + # resize_shape = (2 * self.neighborhood_scale_size + 1, + # 2 * self.neighborhood_scale_size + 1, + # num_channels) + # X_reduced = resize(X_reduced, resize_shape, mode='constant', preserve_range=True) + + # resize to neighborhood_scale_size with cv2 resize_shape = (2 * self.neighborhood_scale_size + 1, - 2 * self.neighborhood_scale_size + 1, - num_channels) - X_reduced = resize(X_reduced, resize_shape, mode='constant', preserve_range=True) - # X_reduced /= np.amax(X_reduced) + 2 * self.neighborhood_scale_size + 1) + X_reduced = cv2.resize(np.squeeze(X_reduced), resize_shape) + # X_reduced /= np.amax(X_reduced) + X_reduced = np.expand_dims(X_reduced, axis=self.channel_axis) + print('_sub_area finished in {}s'.format(timeit.default_timer() - t)) return X_reduced def _get_features(self, X, y, frames, labels): @@ -671,8 +701,12 @@ def _get_features(self, X, y, frames, labels): resize_shape = (self.crop_dim, self.crop_dim, X.shape[channel_axis]) # Resize images from bounding box - appearance = resize(appearance, resize_shape, mode="constant", preserve_range=True) + t = timeit.default_timer() + # appearance = resize(appearance, resize_shape, mode="constant", preserve_range=True) + resize_shape = (self.crop_dim, self.crop_dim) + appearance = cv2.resize(np.squeeze(appearance), resize_shape) # appearance /= np.amax(appearance) + appearance = np.expand_dims(appearance, axis=self.channel_axis) if self.data_format == 'channels_first': appearances[:, counter] = appearance @@ -707,11 +741,26 @@ def _track_cells(self): """ for frame in range(1, self.x.shape[0]): self.logger.info('Tracking frame ' + str(frame)) + + t_whole = timeit.default_timer() # TODEL + t = timeit.default_timer() # TODEL + cost_matrix, predictions = self._get_cost_matrix(frame) + + print('Time to get_cost_matrix: ', timeit.default_timer() - t) # TODEL + t = timeit.default_timer() # TODEL + assignments = self._run_lap(cost_matrix) + + print('Time to run lap: ', timeit.default_timer() - t) # TODEL + t = timeit.default_timer() # TODEL + self._update_tracks(assignments, frame, predictions) self.model.progress(frame / self.x.shape[0]) + print('Time to update tracks: ', timeit.default_timer() - t) # TODEL + print('Time to track one frame: ', timeit.default_timer() - t_whole) # TODEL + def _track_review_dict(self): def process(key, track_item): if track_item is None: @@ -755,7 +804,7 @@ def dataframe(self, **kwargs): data = [] for cell_id, track in self.tracks.items(): data.append(extra_column_vals + [track[c] for c in track_columns]) - dataframe = DataFrame(data, columns=extra_columns + track_columns) + dataframe = pd.DataFrame(data, columns=extra_columns + track_columns) # daughters contains track_id not labels dataframe['daughters'] = dataframe['daughters'].apply( @@ -763,44 +812,51 @@ def dataframe(self, **kwargs): return dataframe - def dump(self, filename, file_format='.trk'): - """Writes the state of the cell tracker to a .trk ("track") or .zip - file. Includes raw & tracked images, and a lineage.json for - parent/daughter information. - - Args: - filename: PathLib or string to the output file - file_format: either 'zip' or 'trk' + def postprocess(self, filename=None, time_excl=9): + """Use graph postprocessing to eliminate false positive division errors + using a graph-based detection method. False positive errors are when a + cell is noted as a daughter of itself before the actual division occurs. + If a filename is passed, save the state of the cell tracker to a .trk + ('track') file. time_excl is the minimum number of frames expected to + exist between legitimate divisions """ + + # Load data track_review_dict = self._track_review_dict() - filename = pathlib.Path(filename) - if file_format not in ('.trk', '.zip'): - raise ValueError("file_format must be either '.zip' or '.trk'") + # Prep data + tracked = track_review_dict['y_tracked'].astype('uint16') + lineage = track_review_dict['tracks'] - if filename.suffix != file_format: - filename = filename.with_suffix(file_format) + # Identify false positives (FPs) + G = self._track_to_graph(lineage) + FPs = self._flag_false_pos(G, time_excl) + FPs_candidates = sorted(FPs.items(), key=lambda v: int(v[0].split('_')[1])) + FPs_sorted = self._review_candidate_nodes(FPs_candidates) - filename = str(filename) + # If FPs exist, use the results to correct + while len(FPs_sorted) != 0: - if file_format == '.zip': - with zipfile.open(filename, 'w') as trks: - with tempfile.NamedTemporaryFile('w') as lineage_file: - json.dump(track_review_dict['tracks'], lineage_file, indent=1) - lineage_file.flush() - trks.write(lineage_file.name, 'lineage.json') + lineage, tracked = self._remove_false_pos(lineage, tracked, FPs_sorted[0]) + G = self._track_to_graph(lineage) + FPs = self._flag_false_pos(G, time_excl) + FPs_candidates = sorted(FPs.items(), key=lambda v: int(v[0].split('_')[1])) + FPs_sorted = self._review_candidate_nodes(FPs_candidates) - with tempfile.NamedTemporaryFile() as raw_file: - np.save(raw_file, track_review_dict['X']) - raw_file.flush() - trks.write(raw_file.name, 'raw.npy') + # Make sure the assignment is correct + track_review_dict['y_tracked'] = tracked + track_review_dict['tracks'] = lineage - with tempfile.NamedTemporaryFile() as tracked_file: - np.save(tracked_file, track_review_dict['y_tracked']) - tracked_file.flush() - trks.write(tracked_file.name, 'tracked.npy') + # Save information to a track file file if requested + if filename is not None: + # Prep filepath + filename = pathlib.Path(filename) + if filename.suffix != '.trk': + filename = filename.with_suffix('.trk') - else: + filename = str(filename) + + # Save with tarfile.open(filename, 'w') as trks: with tempfile.NamedTemporaryFile('w') as lineage_file: json.dump(track_review_dict['tracks'], lineage_file, indent=1) @@ -816,3 +872,218 @@ def dump(self, filename, file_format='.trk'): np.save(tracked_file, track_review_dict['y_tracked']) tracked_file.flush() trks.add(tracked_file.name, 'tracked.npy') + + return track_review_dict + + def dump(self, filename): + """Writes the state of the cell tracker to a .trk ('track') file. + Includes raw & tracked images, and a lineage.json for parent/daughter + information. + """ + track_review_dict = self._track_review_dict() + filename = pathlib.Path(filename) + + if filename.suffix != '.trk': + filename = filename.with_suffix('.trk') + + filename = str(filename) + + with tarfile.open(filename, 'w') as trks: + with tempfile.NamedTemporaryFile('w') as lineage_file: + json.dump(track_review_dict['tracks'], lineage_file, indent=1) + lineage_file.flush() + trks.add(lineage_file.name, 'lineage.json') + + with tempfile.NamedTemporaryFile() as raw_file: + np.save(raw_file, track_review_dict['X']) + raw_file.flush() + trks.add(raw_file.name, 'raw.npy') + + with tempfile.NamedTemporaryFile() as tracked_file: + np.save(tracked_file, track_review_dict['y_tracked']) + tracked_file.flush() + trks.add(tracked_file.name, 'tracked.npy') + + def _track_to_graph(self, tracks): + """Create a graph from the lineage information""" + Dattr = {} + edges = pd.DataFrame() + + for L in tracks.values(): + # Calculate node ids + cellid = ['{}_{}'.format(L['label'], f) for f in L['frames']] + # Add edges from cell ids + edges = edges.append(pd.DataFrame({'source': cellid[0:-1], + 'target': cellid[1:]})) + + # Collect any division attributes + if L['frame_div'] is not None: + Dattr['{}_{}'.format(L['label'], L['frame_div'] - 1)] = {'division': True} + + # Create any daughter-parent edges + if L['parent'] is not None: + source = '{}_{}'.format(L['parent'], min(L['frames']) - 1) + target = '{}_{}'.format(L['label'], min(L['frames'])) + edges = edges.append(pd.DataFrame({'source': [source], + 'target': [target]})) + + G = nx.from_pandas_edgelist(edges, source='source', target='target') + nx.set_node_attributes(G, Dattr) + return G + + def _flag_false_pos(self, G, time_excl): + """Examine graph for false positive nodes + """ + + # TODO: Current implementation may eliminate some divisions at the edge of the frame - + # Further research needed + + # Identify false positive nodes + node_fix = [] + for g in nx.connected_component_subgraphs(G): + div_nodes = [node for node, d in g.node.data() if d.get('division', False) is True] + if len(div_nodes) > 1: + for nd in div_nodes: + if g.degree(nd) == 2: + # Check how close suspected FP is to other known divisions + neighbors = list(G.neighbors(nd)) + + keep_div = True + for div_nd in div_nodes: + if div_nd != nd: + time_spacing = abs(int(nd.split('_')[1]) - + int(div_nd.split('_')[1])) + # If division is sufficiently far away + # we should exclude it from FP list + if time_spacing > time_excl: + keep_div = False + + if keep_div is True: + node_fix.append(nd) + + # Add supplementary information for each false positive + D = {} + for node in node_fix: + D[node] = { + 'false positive': node, + 'neighbors': list(G.neighbors(node)), + 'connected lineages': set([int(n.split('_')[0]) + for n in nx.node_connected_component(G, n)]) + } + + return D + + def _review_candidate_nodes(self, FPs_candidates): + """ review candidate false positive nodes and remove any errant degree 2 nodes. + """ + FPs_presort = {} + # review candidate false positive nodes and remove any errant degree 2 nodes + for candidate_node in FPs_candidates: + node = candidate_node[0] + node_info = candidate_node[1] + fp_label = int(node.split('_')[0]) + fp_frame = int(node.split('_')[1]) + + neighbors = [] # structure will be [(neighbor1, frame), (neighbor2,frame)] + for neighbor in node_info['neighbors']: + neighbor_label = int(neighbor.split('_')[0]) + neighbor_frame = int(neighbor.split('_')[1]) + neighbors.append((neighbor_label, neighbor_frame)) + + # if this cell only exists in one frame (and then it divides) but its 2 neighbors + # both exist in the same frame it will be a degree 2 node but not be a false positive + if neighbors[0][1] != neighbors[1][1]: + FPs_presort[node] = node_info + + FPs_sorted = sorted(FPs_presort.items(), key=lambda v: int(v[0].split('_')[1])) + + return FPs_sorted + + def _remove_false_pos(self, lineage, tracked, FP_info): + """ Remove nodes that have been identified as false positive divisions. + """ + node = FP_info[0] + node_info = FP_info[1] + + fp_label = int(node.split('_')[0]) + fp_frame = int(node.split('_')[1]) + + neighbors = [] # structure will be [(neighbor1, frame), (neighbor2,frame)] + for neighbor in node_info['neighbors']: + neighbor_label = int(neighbor.split('_')[0]) + neighbor_frame = int(neighbor.split('_')[1]) + neighbors.append((neighbor_label, neighbor_frame)) + + # Verify that the FP node only 2 neighbors - 1 before it and one after it + if len(neighbors) == 2: + # order the neighbors such that the time (frame order) is respected + if neighbors[0][1] > neighbors[1][1]: + temp = neighbors[0] + neighbors[0] = neighbors[1] + neighbors[1] = temp + + # Decide which labels to extend and which to remove + + # Neighbor_1 has same label as fp - the actual division hasnt occurred yet + if fp_label == neighbors[0][0]: + # The model mistakenly identified a division before the actual division occurred + label_to_remove = neighbors[1][0] + label_to_extend = neighbors[0][0] + + # Give all of the errant divisions information to the correct track + lineage[label_to_extend]['frames'].extend(lineage[label_to_remove]['frames']) + lineage[label_to_extend]['daughters'] = lineage[label_to_remove]['daughters'] + lineage[label_to_extend]['frame_div'] = lineage[label_to_remove]['frame_div'] + + # Adjust the parent information for the actual daughters + daughter_labels = lineage[label_to_remove]['daughters'] + for daughter in daughter_labels: + lineage[daughter]['parent'] = lineage[label_to_remove]['parent'] + + # Remove the errant node from the annotated images + channel = 0 # These images should only have one channel + for frame in lineage[label_to_remove]['frames']: + label_loc = np.where(tracked[frame, :, :, channel] == label_to_remove) + tracked[frame, :, :, channel][label_loc] = label_to_extend + + # Remove the errant node from the lineage + del lineage[label_to_remove] + + # Neighbor_2 has same label as fp - the actual division ocurred & + # the model mistakenly allowed another + # elif fp_label == neighbors[1][0]: + # The model mistakenly identified a division after + # the actual division occurred + # label_to_remove = fp_label + + # Neither neighbor has same label as fp - the actual division + # ocurred & the model mistakenly allowed another + else: + # The model mistakenly identified a division after the actual division occurred + label_to_remove = fp_label + label_to_extend = neighbors[1][0] + + # Give all of the errant divisions information to the correct track + lineage[label_to_extend]['frames'] = \ + lineage[fp_label]['frames'] + lineage[label_to_extend]['frames'] + lineage[label_to_extend]['parent'] = lineage[fp_label]['parent'] + + # Adjust the parent information for the actual daughter + parent_label = lineage[fp_label]['parent'] + for d_idx, daughter in enumerate(lineage[parent_label]['daughters']): + if daughter == fp_label: + lineage[parent_label]['daughters'][d_idx] = label_to_extend + + # Remove the errant node from the annotated images + channel = 0 # These images should only have one channel + for frame in lineage[label_to_remove]['frames']: + label_loc = np.where(tracked[frame, :, :, channel] == label_to_remove) + tracked[frame, :, :, channel][label_loc] = label_to_extend + + # Remove the errant node + del lineage[label_to_remove] + + else: + print('Error: More than 2 neighbor nodes') + + return lineage, tracked From 7b5aa1f48efd3cc562df331be6ff32d56c455f3a Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Fri, 23 Aug 2019 17:11:53 -0700 Subject: [PATCH 2/5] update redis-consumer for new tracking.py --- redis_consumer/consumers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/redis_consumer/consumers.py b/redis_consumer/consumers.py index 7c707727..4195578f 100644 --- a/redis_consumer/consumers.py +++ b/redis_consumer/consumers.py @@ -1039,7 +1039,7 @@ def _load_data(self, hvalues, subdir, fname): # remove the last dimensions added by `get_image` - tiff_stack = np.squeeze(raw, -1) + tiff_stack = np.squeeze(raw, -1) # TODO: required? check the ndim? if len(tiff_stack.shape) != 3: raise ValueError("This tiff file has shape {}, which is not 3 " "dimensions. Tracking can only be done on images " @@ -1114,7 +1114,7 @@ def _load_data(self, hvalues, subdir, fname): self.logger.debug('Hash %s has status %s', segment_hash, status) - if status == 'failed': + if status == self.failed_status: reason = self.redis.hget(segment_hash, 'reason') raise RuntimeError( 'Tracking failed during segmentation on frame {}.' @@ -1185,7 +1185,8 @@ def _consume(self, redis_hash): save_name = os.path.join( tempdir, hvalues.get('original_name', fname)) + '.trk' - tracker.dump(save_name, file_format='.trk') + # Post-process and save the output file + tracker.postprocess(save_name) output_file_name, output_url = self.storage.upload(save_name) self.update_key(redis_hash, { From 98db264fd0865ef72feb5464aedba259a4a04a5d Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Fri, 23 Aug 2019 17:28:59 -0700 Subject: [PATCH 3/5] add opencv to requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index f3c21877..b127fea0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ grpcio==1.22.0 dict-to-protobuf==0.0.3.9 pandas>=0.24.2 pytz==2019.1 +opencv-python==4.1.0.25 From b303909d117e497186d16c2272eac2e73ca17a92 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 29 Aug 2019 13:16:22 -0700 Subject: [PATCH 4/5] add postprocess to DummyTracker --- redis_consumer/consumers_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/redis_consumer/consumers_test.py b/redis_consumer/consumers_test.py index 8b8203b0..d7dcf20a 100644 --- a/redis_consumer/consumers_test.py +++ b/redis_consumer/consumers_test.py @@ -177,6 +177,9 @@ def _track_cells(self): def dump(*_, **__): return None + def postprocess(*_, **__): + return None + class TestConsumer(object): From fedeb3a2a8df6fd177753017ebadaed22c36c1e3 Mon Sep 17 00:00:00 2001 From: William Graf <7930703+willgraf@users.noreply.github.com> Date: Thu, 29 Aug 2019 13:16:53 -0700 Subject: [PATCH 5/5] add self to the DummyTracker methods --- redis_consumer/consumers_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis_consumer/consumers_test.py b/redis_consumer/consumers_test.py index d7dcf20a..d04da669 100644 --- a/redis_consumer/consumers_test.py +++ b/redis_consumer/consumers_test.py @@ -174,10 +174,10 @@ class DummyTracker(object): def _track_cells(self): return None - def dump(*_, **__): + def dump(self, *_, **__): return None - def postprocess(*_, **__): + def postprocess(self, *_, **__): return None