diff --git a/deepcell_tracking/tracking.py b/deepcell_tracking/tracking.py index 33b5e16..dd83275 100644 --- a/deepcell_tracking/tracking.py +++ b/deepcell_tracking/tracking.py @@ -77,6 +77,11 @@ class CellTracker(object): # pylint: disable=useless-object-inheritance dtype (str): data type for features, can be 'float32', 'float16', etc. data_format (str): determines the order of the channel axis, one of 'channels_first' and 'channels_last'. + crop_mode (str): Whether to do a fixed crop or to crop and resize + to create the appearance features + norm (bool): Whether to remove non cell features and normalize the + foreground pixels by zero-meaning and dividing by the standard + deviation. Applies to fixed crop mode only. """ def __init__(self, @@ -91,6 +96,8 @@ def __init__(self, division=0.9, track_length=5, embedding_axis=0, + crop_mode='resize', + norm=True, dtype='float32', data_format='channels_last'): @@ -123,6 +130,8 @@ def __init__(self, self.dtype = dtype self.track_length = track_length self.embedding_axis = embedding_axis + self.crop_mode = crop_mode + self.norm = norm self.a_matrix = [] self.c_matrix = [] @@ -211,7 +220,9 @@ def _est_feats(self): frame_features = get_image_features( self.X[frame], self.y[frame], - appearance_dim=self.appearance_dim) + appearance_dim=self.appearance_dim, + crop_mode=self.crop_mode, + norm=self.norm) for cell_idx, cell_id in enumerate(frame_features['labels']): self.id_to_idx[cell_id] = cell_idx @@ -692,47 +703,6 @@ def dataframe(self, **kwargs): return dataframe - def postprocess(self, filename=None, time_excl=9): - """Use graph postprocessing to eliminate false positive division errors - using a graph-based detection method. False positive errors are when a - cell is noted as a daughter of itself before the actual division occurs. - If a filename is passed, save the state of the cell tracker to a .trk - ('track') file. time_excl is the minimum number of frames expected to - exist between legitimate divisions - """ - - # Load data - track_review_dict = self._track_review_dict() - - # Prep data - tracked = track_review_dict['y_tracked'].astype('uint16') - lineage = track_review_dict['tracks'] - - # Identify false positives (FPs) - G = self._track_to_graph(lineage) - FPs = self._flag_false_pos(G, time_excl) - FPs_candidates = sorted(FPs.items(), key=lambda v: int(v[0].split('_')[1])) - FPs_sorted = self._review_candidate_nodes(FPs_candidates) - - # If FPs exist, use the results to correct - while len(FPs_sorted) != 0: - - lineage, tracked = self._remove_false_pos(lineage, tracked, FPs_sorted[0]) - G = self._track_to_graph(lineage) - FPs = self._flag_false_pos(G, time_excl) - FPs_candidates = sorted(FPs.items(), key=lambda v: int(v[0].split('_')[1])) - FPs_sorted = self._review_candidate_nodes(FPs_candidates) - - # Make sure the assignment is correct - track_review_dict['y_tracked'] = tracked - track_review_dict['tracks'] = lineage - - # Save information to a track file file if requested - if filename is not None: - self.dump(filename, track_review_dict) - - return track_review_dict - def dump(self, filename, track_review_dict=None): """Writes the state of the cell tracker to a .trk ('track') file. Includes raw & tracked images, and a lineage.json for parent/daughter @@ -752,184 +722,3 @@ def dump(self, filename, track_review_dict=None): lineage=track_review_dict['tracks'], raw=track_review_dict['X'], tracked=track_review_dict['y_tracked']) - - def _track_to_graph(self, tracks): - """Create a graph from the lineage information""" - Dattr = {} - edges = pd.DataFrame() - - for L in tracks.values(): - # Calculate node ids - cellid = ['{}_{}'.format(L['label'], f) for f in L['frames']] - # Add edges from cell ids - edges = edges.append(pd.DataFrame({'source': cellid[0:-1], - 'target': cellid[1:]})) - - # Collect any division attributes - if L['frame_div'] is not None: - Dattr['{}_{}'.format(L['label'], L['frame_div'] - 1)] = {'division': True} - - # Create any daughter-parent edges - if L['parent'] is not None: - source = '{}_{}'.format(L['parent'], min(L['frames']) - 1) - target = '{}_{}'.format(L['label'], min(L['frames'])) - edges = edges.append(pd.DataFrame({'source': [source], - 'target': [target]})) - - G = nx.from_pandas_edgelist(edges, source='source', target='target') - nx.set_node_attributes(G, Dattr) - return G - - def _flag_false_pos(self, G, time_excl): - """Examine graph for false positive nodes - """ - - # TODO: Current implementation may eliminate some divisions at the edge of the frame - - # Further research needed - - # Identify false positive nodes - node_fix = [] - for g in (G.subgraph(c) for c in nx.connected_components(G)): - div_nodes = [n for n, d in g.nodes(data=True) if d.get('division')] - if len(div_nodes) > 1: - for nd in div_nodes: - if g.degree(nd) == 2: - # Check how close suspected FP is to other known divisions - - keep_div = True - for div_nd in div_nodes: - if div_nd != nd: - time_spacing = abs(int(nd.split('_')[1]) - - int(div_nd.split('_')[1])) - # If division is sufficiently far away - # we should exclude it from FP list - if time_spacing > time_excl: - keep_div = False - - if keep_div is True: - node_fix.append(nd) - - # Add supplementary information for each false positive - D = {} - for node in node_fix: - D[node] = { - 'false positive': node, - 'neighbors': list(G.neighbors(node)), - 'connected lineages': set([int(node.split('_')[0]) - for node in nx.node_connected_component(G, node)]) - } - - return D - - def _review_candidate_nodes(self, FPs_candidates): - """ review candidate false positive nodes and remove any errant degree 2 nodes. - """ - FPs_presort = {} - # review candidate false positive nodes and remove any errant degree 2 nodes - for candidate_node in FPs_candidates: - node = candidate_node[0] - node_info = candidate_node[1] - - neighbors = [] # structure will be [(neighbor1, frame), (neighbor2,frame)] - for neighbor in node_info['neighbors']: - neighbor_label = int(neighbor.split('_')[0]) - neighbor_frame = int(neighbor.split('_')[1]) - neighbors.append((neighbor_label, neighbor_frame)) - - # if this cell only exists in one frame (and then it divides) but its 2 neighbors - # both exist in the same frame it will be a degree 2 node but not be a false positive - if neighbors[0][1] != neighbors[1][1]: - FPs_presort[node] = node_info - - FPs_sorted = sorted(FPs_presort.items(), key=lambda v: int(v[0].split('_')[1])) - - return FPs_sorted - - def _remove_false_pos(self, lineage, tracked, FP_info): - """ Remove nodes that have been identified as false positive divisions. - """ - node = FP_info[0] - node_info = FP_info[1] - - fp_label = int(node.split('_')[0]) - fp_frame = int(node.split('_')[1]) - - neighbors = [] # structure will be [(neighbor1, frame), (neighbor2,frame)] - for neighbor in node_info['neighbors']: - neighbor_label = int(neighbor.split('_')[0]) - neighbor_frame = int(neighbor.split('_')[1]) - neighbors.append((neighbor_label, neighbor_frame)) - - # Verify that the FP node only 2 neighbors - 1 before it and one after it - if len(neighbors) == 2: - # order the neighbors such that the time (frame order) is respected - if neighbors[0][1] > neighbors[1][1]: - temp = neighbors[0] - neighbors[0] = neighbors[1] - neighbors[1] = temp - - # Decide which labels to extend and which to remove - - # Neighbor_1 has same label as fp - the actual division hasnt occurred yet - if fp_label == neighbors[0][0]: - # The model mistakenly identified a division before the actual division occurred - label_to_remove = neighbors[1][0] - label_to_extend = neighbors[0][0] - - # Give all of the errant divisions information to the correct track - lineage[label_to_extend]['frames'].extend(lineage[label_to_remove]['frames']) - lineage[label_to_extend]['daughters'] = lineage[label_to_remove]['daughters'] - lineage[label_to_extend]['frame_div'] = lineage[label_to_remove]['frame_div'] - - # Adjust the parent information for the actual daughters - daughter_labels = lineage[label_to_remove]['daughters'] - for daughter in daughter_labels: - lineage[daughter]['parent'] = lineage[label_to_remove]['parent'] - - # Remove the errant node from the annotated images - channel = 0 # These images should only have one channel - for frame in lineage[label_to_remove]['frames']: - label_loc = np.where(tracked[frame, :, :, channel] == label_to_remove) - tracked[frame, :, :, channel][label_loc] = label_to_extend - - # Remove the errant node from the lineage - del lineage[label_to_remove] - - # Neighbor_2 has same label as fp - the actual division ocurred & - # the model mistakenly allowed another - # elif fp_label == neighbors[1][0]: - # The model mistakenly identified a division after - # the actual division occurred - # label_to_remove = fp_label - - # Neither neighbor has same label as fp - the actual division - # ocurred & the model mistakenly allowed another - else: - # The model mistakenly identified a division after the actual division occurred - label_to_remove = fp_label - label_to_extend = neighbors[1][0] - - # Give all of the errant divisions information to the correct track - lineage[label_to_extend]['frames'] = \ - lineage[fp_label]['frames'] + lineage[label_to_extend]['frames'] - lineage[label_to_extend]['parent'] = lineage[fp_label]['parent'] - - # Adjust the parent information for the actual daughter - parent_label = lineage[fp_label]['parent'] - for d_idx, daughter in enumerate(lineage[parent_label]['daughters']): - if daughter == fp_label: - lineage[parent_label]['daughters'][d_idx] = label_to_extend - - # Remove the errant node from the annotated images - channel = 0 # These images should only have one channel - for frame in lineage[label_to_remove]['frames']: - label_loc = np.where(tracked[frame, :, :, channel] == label_to_remove) - tracked[frame, :, :, channel][label_loc] = label_to_extend - - # Remove the errant node - del lineage[label_to_remove] - - else: - self.logger.error('Error: More than 2 neighbor nodes') - - return lineage, tracked diff --git a/deepcell_tracking/tracking_test.py b/deepcell_tracking/tracking_test.py index 3e91732..188d384 100644 --- a/deepcell_tracking/tracking_test.py +++ b/deepcell_tracking/tracking_test.py @@ -170,12 +170,7 @@ def test_track_cells(self, tmpdir): with pytest.raises(ValueError): tracker.dataframe(bad_value=-1) - # test tracker.postprocess tempdir = str(tmpdir) - path = os.path.join(tempdir, 'postprocess.xyz') - tracker.postprocess(filename=path) - post_saved_path = os.path.join(tempdir, 'postprocess.trk') - assert os.path.isfile(post_saved_path) # test tracker.dump path = os.path.join(tempdir, 'test.xyz') @@ -191,7 +186,7 @@ def test_track_cells(self, tmpdir): assert os.path.isfile(os.path.join(tempdir, 'all.trks')) # test load_trks - data = trk_io.load_trks(post_saved_path) + data = trk_io.load_trks(dump_saved_path) assert isinstance(data['lineages'], list) assert all(isinstance(d, dict) for d in data['lineages']) np.testing.assert_equal(data['X'], tracker.X) diff --git a/setup.py b/setup.py index 6a243c4..4e2fca4 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ download_url=DOWNLOAD_URL, license=LICENSE, install_requires=['networkx>=2.1', - 'numpy', + 'numpy<1.24', 'pandas', 'scipy', 'scikit-image>=0.14.5',