diff --git a/bouter/tests/_create_assets.py b/bouter/tests/_create_assets.py new file mode 100644 index 0000000..21510bb --- /dev/null +++ b/bouter/tests/_create_assets.py @@ -0,0 +1,37 @@ +""" +Rerun this file if the logic of freely-swimming experiments changes +""" + +import flammkuchen as fl + +from bouter import free +from bouter.tests import ASSETS_PATH + + +def create_assets(): + source_dataset_path = ASSETS_PATH / "freely_swimming_dataset" + experiment = free.FreelySwimmingExperiment(source_dataset_path) + + bouts, continuity = experiment.get_bouts() + velocities_df = experiment.compute_velocity() + + velocities = velocities_df[ + ["vel_f{}".format(i_fish) for i_fish in range(experiment.n_fish)] + ] + + bouts_summary = experiment.get_bout_properties() + + # Load computed velocities + fl.save( + source_dataset_path / "test_extracted_bouts.h5", + dict( + bouts_summary=bouts_summary, + bouts=bouts, + continuity=continuity, + velocities=velocities, + ), + ) + + +if __name__ == "__main__": + create_assets() diff --git a/bouter/tests/assets/freely_swimming_dataset/test_extracted_bouts.h5 b/bouter/tests/assets/freely_swimming_dataset/test_extracted_bouts.h5 index 3163edf..537c78b 100644 Binary files a/bouter/tests/assets/freely_swimming_dataset/test_extracted_bouts.h5 and b/bouter/tests/assets/freely_swimming_dataset/test_extracted_bouts.h5 differ diff --git a/bouter/tests/test_free.py b/bouter/tests/test_free.py index 4505d5f..7c055a6 100644 --- a/bouter/tests/test_free.py +++ b/bouter/tests/test_free.py @@ -107,39 +107,41 @@ def test_compute_velocity(freely_swimming_exp_path): experiment = free.FreelySwimmingExperiment(freely_swimming_exp_path) velocities_df = experiment.compute_velocity() fish_vels = velocities_df[ - ["vel2_f{}".format(i_fish) for i_fish in range(experiment.n_fish)] + ["vel_f{}".format(i_fish) for i_fish in range(experiment.n_fish)] ] # Load computed velocities - loaded_vel2 = fl.load( + loaded_vel = fl.load( ASSETS_PATH / "freely_swimming_dataset" / "test_extracted_bouts.h5", "/velocities", ) # Compare DataFrame with velocities from the 3 experiment fish - assert_frame_equal(fish_vels, loaded_vel2) + assert_frame_equal(fish_vels, loaded_vel) def test_bout_extraction(freely_swimming_exp_path): experiment = free.FreelySwimmingExperiment(freely_swimming_exp_path) - bouts, cont = experiment.get_bouts() + bouts, continuities = experiment.get_bouts() # Load expected bouts to be extracted. Only first fish is used for the assertion loaded_bouts = fl.load( ASSETS_PATH / "freely_swimming_dataset" / "test_extracted_bouts.h5", "/bouts", ) - loaded_cont = fl.load( + loaded_continuities = fl.load( ASSETS_PATH / "freely_swimming_dataset" / "test_extracted_bouts.h5", "/continuity", ) # Compare dataframes for each of the detected bouts in the first fish - for bout in range(len(bouts[0])): - assert_frame_equal(bouts[0][bout], loaded_bouts[bout]) + for fish_bouts, fish_bouts_loaded in zip(bouts, loaded_bouts): + for bouts, loaded_bouts in zip(fish_bouts, fish_bouts_loaded): + assert_frame_equal(bouts, loaded_bouts) # Compare also continuity array - assert_array_equal(cont[0], loaded_cont) + for cont, loaded_cont in zip(continuities, loaded_continuities): + assert_array_equal(cont, loaded_cont) def test_bout_summary(freely_swimming_exp_path):