Skip to content

Commit

Permalink
Use _get_annotated_movie in tracking tests.
Browse files Browse the repository at this point in the history
Remove _get_dummy_tracking_data, buggy.
  • Loading branch information
willgraf committed Apr 12, 2021
1 parent aebe142 commit 17e44a4
Showing 1 changed file with 31 additions and 29 deletions.
60 changes: 31 additions & 29 deletions deepcell_tracking/tracking_test.py
Expand Up @@ -42,30 +42,7 @@

from deepcell_tracking import tracking
from deepcell_tracking import utils


def _get_dummy_tracking_data(length=128, frames=3,
data_format='channels_last'):
if data_format == 'channels_last':
channel_axis = -1
else:
channel_axis = 0

x, y = [], []
while len(x) < frames:
_x = sk.data.binary_blobs(length=length, n_dim=2)
_y = sk.measure.label(_x)
if len(np.unique(_y)) > 3:
x.append(_x)
y.append(_y)

x = np.stack(x, axis=0) # expand to 3D
y = np.stack(y, axis=0) # expand to 3D

x = np.expand_dims(x, axis=channel_axis)
y = np.expand_dims(y, axis=channel_axis)

return x.astype('float32'), y.astype('int32')
from deepcell_tracking.test_utils import _get_annotated_movie


class DummyModel(object): # pylint: disable=useless-object-inheritance
Expand Down Expand Up @@ -107,9 +84,15 @@ def predict(self, data):
class TestTracking(object): # pylint: disable=useless-object-inheritance

def test_simple(self):
length = 128
data_format = 'channels_last'
frames = 3
x, y = _get_dummy_tracking_data(length, frames=frames)
labels_per_frame = 5
y = _get_annotated_movie(img_size=256,
labels_per_frame=labels_per_frame,
frames=frames,
mov_type='sequential', seed=0,
data_format=data_format)
x = np.random.random(y.shape)
num_objects = len(np.unique(y)) - 1
model = DummyModel()
encoder = DummyEncoder()
Expand Down Expand Up @@ -164,15 +147,34 @@ def test_simple(self):
# tracker.get_feature_shape('bad feature name')

def test_track_cells(self):
length = 128
frames = 5
track_length = 2

# TODO: Fix for channels_first
for data_format in ('channels_last',): # 'channels_first'):

x, y = _get_dummy_tracking_data(
length, frames=frames, data_format=data_format)
labels_per_frame = 5
frames = 2

y1 = _get_annotated_movie(img_size=256,
labels_per_frame=labels_per_frame,
frames=frames,
mov_type='sequential', seed=1,
data_format=data_format)
y2 = _get_annotated_movie(img_size=256,
labels_per_frame=labels_per_frame * 2,
frames=frames,
mov_type='sequential', seed=2,
data_format=data_format)
y3 = _get_annotated_movie(img_size=256,
labels_per_frame=labels_per_frame,
frames=frames,
mov_type='sequential', seed=3,
data_format=data_format)

y = np.concatenate((y1, y2, y3))

x = np.random.random(y.shape)

tracker = tracking.CellTracker(
x, y,
Expand Down

0 comments on commit 17e44a4

Please sign in to comment.