Skip to content

Commit

Permalink
BUG: Fix T-DaTing labels when continuing tracking (#280)
Browse files Browse the repository at this point in the history
* BUG: Fix T-DaTing labels when continuing tracking

When continuing tracking from previous list, the track labels
could be duplicated because the labeling started always from 0.
Instead, start from maximum label of previous cells.

* Add test for multi-step t-dating

* Update variable name
  • Loading branch information
ritvje committed Jun 9, 2022
1 parent 99c79d8 commit 6b5978b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
50 changes: 50 additions & 0 deletions pysteps/tests/test_tracking_tdating.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,56 @@
("mch", True),
]

arg_names_multistep = ("source", "len_timesteps")
arg_values_multistep = [
("mch", 6),
]


@pytest.mark.parametrize(arg_names_multistep, arg_values_multistep)
def test_tracking_tdating_dating_multistep(source, len_timesteps):
pytest.importorskip("skimage")

input_fields, metadata = get_precipitation_fields(
0, len_timesteps, True, True, 4000, source
)
input_fields, __ = to_reflectivity(input_fields, metadata)

timelist = metadata["timestamps"]

# First half of timesteps
tracks_1, cells, labels = dating(
input_fields[0 : len_timesteps // 2],
timelist[0 : len_timesteps // 2],
mintrack=1,
)
# Second half of timesteps
tracks_2, cells, _ = dating(
input_fields[len_timesteps // 2 - 2 :],
timelist[len_timesteps // 2 - 2 :],
mintrack=1,
start=2,
cell_list=cells,
label_list=labels,
)

# Since we are adding cells, number of tracks should increase
assert len(tracks_1) <= len(tracks_2)

# Tracks should be continuous in time so time difference should not exceed timestep
max_track_step = max([t.time.diff().max().seconds for t in tracks_2 if len(t) > 1])
timestep = np.diff(timelist).max().seconds
assert max_track_step <= timestep

# IDs of unmatched cells should increase in every timestep
for prev_df, cur_df in zip(cells[:-1], cells[1:]):
prev_ids = set(prev_df.ID)
cur_ids = set(cur_df.ID)
new_ids = list(cur_ids - prev_ids)
prev_unmatched = list(prev_ids - cur_ids)
if len(prev_unmatched):
assert np.all(np.array(new_ids) > max(prev_unmatched))


@pytest.mark.parametrize(arg_names, arg_values)
def test_tracking_tdating_dating(source, dry_input):
Expand Down
5 changes: 4 additions & 1 deletion pysteps/tracking/tdating.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ def dating(
raise ValueError("start > len(timelist)")

oflow_method = motion.get_method("LK")
max_ID = 0
if len(label_list) == 0:
max_ID = 0
else:
max_ID = np.nanmax([np.nanmax(np.unique(label_list)), 0])
for t in range(start, len(timelist)):
cells_id, labels = tstorm_detect.detection(
input_video[t, :, :],
Expand Down

0 comments on commit 6b5978b

Please sign in to comment.