Skip to content

Commit

Permalink
Merge pull request #14 from portugueslab/fix-free-tests
Browse files Browse the repository at this point in the history
Fix tests
  • Loading branch information
OtPrat committed Nov 26, 2021
2 parents 46d0316 + 9706980 commit e6dda68
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 8 deletions.
37 changes: 37 additions & 0 deletions bouter/tests/_create_assets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
Rerun this file if the logic of freely-swimming experiments changes
"""

import flammkuchen as fl

from bouter import free
from bouter.tests import ASSETS_PATH


def create_assets():
source_dataset_path = ASSETS_PATH / "freely_swimming_dataset"
experiment = free.FreelySwimmingExperiment(source_dataset_path)

bouts, continuity = experiment.get_bouts()
velocities_df = experiment.compute_velocity()

velocities = velocities_df[
["vel_f{}".format(i_fish) for i_fish in range(experiment.n_fish)]
]

bouts_summary = experiment.get_bout_properties()

# Load computed velocities
fl.save(
source_dataset_path / "test_extracted_bouts.h5",
dict(
bouts_summary=bouts_summary,
bouts=bouts,
continuity=continuity,
velocities=velocities,
),
)


if __name__ == "__main__":
create_assets()
Binary file not shown.
18 changes: 10 additions & 8 deletions bouter/tests/test_free.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,39 +107,41 @@ def test_compute_velocity(freely_swimming_exp_path):
experiment = free.FreelySwimmingExperiment(freely_swimming_exp_path)
velocities_df = experiment.compute_velocity()
fish_vels = velocities_df[
["vel2_f{}".format(i_fish) for i_fish in range(experiment.n_fish)]
["vel_f{}".format(i_fish) for i_fish in range(experiment.n_fish)]
]

# Load computed velocities
loaded_vel2 = fl.load(
loaded_vel = fl.load(
ASSETS_PATH / "freely_swimming_dataset" / "test_extracted_bouts.h5",
"/velocities",
)

# Compare DataFrame with velocities from the 3 experiment fish
assert_frame_equal(fish_vels, loaded_vel2)
assert_frame_equal(fish_vels, loaded_vel)


def test_bout_extraction(freely_swimming_exp_path):
experiment = free.FreelySwimmingExperiment(freely_swimming_exp_path)
bouts, cont = experiment.get_bouts()
bouts, continuities = experiment.get_bouts()

# Load expected bouts to be extracted. Only first fish is used for the assertion
loaded_bouts = fl.load(
ASSETS_PATH / "freely_swimming_dataset" / "test_extracted_bouts.h5",
"/bouts",
)
loaded_cont = fl.load(
loaded_continuities = fl.load(
ASSETS_PATH / "freely_swimming_dataset" / "test_extracted_bouts.h5",
"/continuity",
)

# Compare dataframes for each of the detected bouts in the first fish
for bout in range(len(bouts[0])):
assert_frame_equal(bouts[0][bout], loaded_bouts[bout])
for fish_bouts, fish_bouts_loaded in zip(bouts, loaded_bouts):
for bouts, loaded_bouts in zip(fish_bouts, fish_bouts_loaded):
assert_frame_equal(bouts, loaded_bouts)

# Compare also continuity array
assert_array_equal(cont[0], loaded_cont)
for cont, loaded_cont in zip(continuities, loaded_continuities):
assert_array_equal(cont, loaded_cont)


def test_bout_summary(freely_swimming_exp_path):
Expand Down

0 comments on commit e6dda68

Please sign in to comment.