Skip to content

Commit

Permalink
Merge 3af6731 into 35ef710
Browse files Browse the repository at this point in the history
  • Loading branch information
willgraf committed Jul 15, 2021
2 parents 35ef710 + 3af6731 commit 8d82355
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 108 deletions.
23 changes: 9 additions & 14 deletions deepcell_tracking/tracking.py
Expand Up @@ -747,7 +747,7 @@ def dump(self, filename, track_review_dict=None):

filename = str(filename)

with tarfile.open(filename, 'w') as trks:
with tarfile.open(filename, 'w:gz') as trks:
# disable auto deletion and close/delete manually
# to resolve double-opening issue on Windows.
with tempfile.NamedTemporaryFile('w', delete=False) as lineage:
Expand All @@ -757,19 +757,14 @@ def dump(self, filename, track_review_dict=None):
trks.add(lineage.name, 'lineage.json')
os.remove(lineage.name)

with tempfile.NamedTemporaryFile(delete=False) as raw:
np.save(raw, track_review_dict['X'])
raw.flush()
raw.close()
trks.add(raw.name, 'raw.npy')
os.remove(raw.name)

with tempfile.NamedTemporaryFile(delete=False) as tracked:
np.save(tracked, track_review_dict['y_tracked'])
tracked.flush()
tracked.close()
trks.add(tracked.name, 'tracked.npy')
os.remove(tracked.name)
with tempfile.NamedTemporaryFile(delete=False) as npz_file:
raw = track_review_dict['X']
tracked = track_review_dict['y_tracked']
np.savez_compressed(npz_file, X=raw, y=tracked)
npz_file.flush()
npz_file.close()
trks.add(npz_file.name, 'data.npz')
os.remove(npz_file.name)

def _track_to_graph(self, tracks):
"""Create a graph from the lineage information"""
Expand Down
74 changes: 32 additions & 42 deletions deepcell_tracking/tracking_test.py
Expand Up @@ -29,10 +29,7 @@
from __future__ import division
from __future__ import print_function

import errno
import os
import shutil
import tempfile

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -123,7 +120,7 @@ def test_simple(self):
neighborhood_encoder=encoder,
data_format='invalid')

def test_track_cells(self):
def test_track_cells(self, tmpdir):
frames = 10
track_length = 3
labels_per_frame = 3
Expand Down Expand Up @@ -172,41 +169,34 @@ def test_track_cells(self):
with pytest.raises(ValueError):
tracker.dataframe(bad_value=-1)

try:
# test tracker.postprocess
tempdir = tempfile.mkdtemp() # create dir
path = os.path.join(tempdir, 'postprocess.xyz')
tracker.postprocess(filename=path)
post_saved_path = os.path.join(tempdir, 'postprocess.trk')
assert os.path.isfile(post_saved_path)

# test tracker.dump
path = os.path.join(tempdir, 'test.xyz')
tracker.dump(path)
dump_saved_path = os.path.join(tempdir, 'test.trk')
assert os.path.isfile(dump_saved_path)

# utility tests for loading trk files
# TODO: move utility tests into utils_test.py

# test trk_folder_to_trks
utils.trk_folder_to_trks(tempdir, os.path.join(tempdir, 'all.trks'))
assert os.path.isfile(os.path.join(tempdir, 'all.trks'))

# test load_trks
data = utils.load_trks(post_saved_path)
assert isinstance(data['lineages'], list)
assert all(isinstance(d, dict) for d in data['lineages'])
np.testing.assert_equal(data['X'], tracker.X)
np.testing.assert_equal(data['y'], tracker.y_tracked)
# load trks instead of trk
data = utils.load_trks(os.path.join(tempdir, 'all.trks'))

# test trks_stats
utils.trks_stats(os.path.join(tempdir, 'test.trk'))
finally:
try:
shutil.rmtree(tempdir) # delete directory
except OSError as exc:
if exc.errno != errno.ENOENT: # no such file or directory
raise # re-raise exception
# test tracker.postprocess
tempdir = str(tmpdir) # create dir
path = os.path.join(tempdir, 'postprocess.xyz')
tracker.postprocess(filename=path)
post_saved_path = os.path.join(tempdir, 'postprocess.trk')
assert os.path.isfile(post_saved_path)

# test tracker.dump
path = os.path.join(tempdir, 'test.xyz')
tracker.dump(path)
dump_saved_path = os.path.join(tempdir, 'test.trk')
assert os.path.isfile(dump_saved_path)

# utility tests for loading trk files
# TODO: move utility tests into utils_test.py

# test trk_folder_to_trks
utils.trk_folder_to_trks(tempdir, os.path.join(tempdir, 'all.trks'))
assert os.path.isfile(os.path.join(tempdir, 'all.trks'))

# test load_trks
data = utils.load_trks(post_saved_path)
assert isinstance(data['lineages'], list)
assert all(isinstance(d, dict) for d in data['lineages'])
np.testing.assert_equal(data['X'], tracker.X)
np.testing.assert_equal(data['y'], tracker.y_tracked)
# load trks instead of trk
data = utils.load_trks(os.path.join(tempdir, 'all.trks'))

# test trks_stats
utils.trks_stats(os.path.join(tempdir, 'test.trk'))
40 changes: 14 additions & 26 deletions deepcell_tracking/utils.py
Expand Up @@ -138,19 +138,14 @@ def load_trks(filename):
dict: A dictionary with raw, tracked, and lineage data.
"""
with tarfile.open(filename, 'r') as trks:

# numpy can't read these from disk...
array_file = io.BytesIO()
array_file.write(trks.extractfile('raw.npy').read())
array_file.seek(0)
raw = np.load(array_file)
array_file.close()

array_file = io.BytesIO()
array_file.write(trks.extractfile('tracked.npy').read())
array_file.seek(0)
tracked = np.load(array_file)
array_file.close()
npz_file = io.BytesIO()
npz_file.write(trks.extractfile('data.npz').read())
npz_file.seek(0)
data = np.load(npz_file)
raw = data['X']
tracked = data['y']
npz_file.close()

# trks.extractfile opens a file in bytes mode, json can't use bytes.
_, file_extension = os.path.splitext(filename)
Expand Down Expand Up @@ -214,27 +209,20 @@ def save_trks(filename, lineages, raw, tracked):
if not str(filename).lower().endswith('.trks'):
raise ValueError('filename must end with `.trks`. Found %s' % filename)

with tarfile.open(filename, 'w') as trks:
with tarfile.open(filename, 'w:gz') as trks:
with tempfile.NamedTemporaryFile('w', delete=False) as lineages_file:
json.dump(lineages, lineages_file, indent=4)
lineages_file.flush()
lineages_file.close()
trks.add(lineages_file.name, 'lineages.json')
os.remove(lineages_file.name)

with tempfile.NamedTemporaryFile(delete=False) as raw_file:
np.save(raw_file, raw)
raw_file.flush()
raw_file.close()
trks.add(raw_file.name, 'raw.npy')
os.remove(raw_file.name)

with tempfile.NamedTemporaryFile(delete=False) as tracked_file:
np.save(tracked_file, tracked)
tracked_file.flush()
tracked_file.close()
trks.add(tracked_file.name, 'tracked.npy')
os.remove(tracked_file.name)
with tempfile.NamedTemporaryFile(delete=False) as npz_file:
np.savez_compressed(npz_file, X=raw, y=tracked)
npz_file.flush()
npz_file.close()
trks.add(npz_file.name, 'data.npz')
os.remove(npz_file.name)


def trks_stats(filename):
Expand Down
41 changes: 15 additions & 26 deletions deepcell_tracking/utils_test.py
Expand Up @@ -29,10 +29,7 @@
from __future__ import print_function

import copy
import errno
import os
import shutil
import tempfile

import numpy as np
import skimage as sk
Expand Down Expand Up @@ -144,33 +141,25 @@ def test_count_pairs(self):
y, same_probability=prob, data_format='channels_first')
assert pairs == expected

def test_save_trks(self):
def test_save_trks(self, tmpdir):
X = get_image(30, 30)
y = np.random.randint(low=0, high=10, size=X.shape)
lineage = [dict()]

try:
tempdir = tempfile.mkdtemp() # create dir
with pytest.raises(ValueError):
badfilename = os.path.join(tempdir, 'x.trk')
utils.save_trks(badfilename, lineage, X, y)

filename = os.path.join(tempdir, 'x.trks')
utils.save_trks(filename, lineage, X, y)
assert os.path.isfile(filename)

# test saved tracks can be loaded
loaded = utils.load_trks(filename)
assert loaded['lineages'] == lineage
np.testing.assert_array_equal(X, loaded['X'])
np.testing.assert_array_equal(y, loaded['y'])

finally:
try:
shutil.rmtree(tempdir) # delete directory
except OSError as exc:
if exc.errno != errno.ENOENT: # no such file or directory
raise # re-raise exception
tempdir = str(tmpdir) # create dir
with pytest.raises(ValueError):
badfilename = os.path.join(tempdir, 'x.trk')
utils.save_trks(badfilename, lineage, X, y)

filename = os.path.join(tempdir, 'x.trks')
utils.save_trks(filename, lineage, X, y)
assert os.path.isfile(filename)

# test saved tracks can be loaded
loaded = utils.load_trks(filename)
assert loaded['lineages'] == lineage
np.testing.assert_array_equal(X, loaded['X'])
np.testing.assert_array_equal(y, loaded['y'])

def test_normalize_adj_matrix(self):
frames = 3
Expand Down

0 comments on commit 8d82355

Please sign in to comment.