Skip to content

Commit

Permalink
BUG: Can't do multiple dictionary expansions into fn on python2
Browse files Browse the repository at this point in the history
  • Loading branch information
scottclowe committed Jul 13, 2021
1 parent 9b31d16 commit 29ef258
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions fissa/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@
from .base_test import BaseTestCase


def merge_dicts(x, *args):
"""Merge multiple dictionaries together."""
z = x.copy()
for arg in args:
z.update(arg)
return z


class ExperimentTestMixin:
"""Base tests for Experiment class."""

Expand Down Expand Up @@ -860,7 +868,7 @@ def test_load_npz(self):
# Make a save file which contains values set to `None`
fname = os.path.join(self.output_dir, "dummy.npz")
os.makedirs(self.output_dir)
np.savez_compressed(fname, **kwargs, **fields)
np.savez_compressed(fname, **merge_dicts(kwargs, fields))
# Load the file and check the data appears correctly
exp.load(fname)
for key, value in fields.items():
Expand Down Expand Up @@ -893,7 +901,7 @@ def test_load_none(self):
# Make a save file which contains values set to `None`
fname = os.path.join(self.output_dir, "dummy.npz")
os.makedirs(self.output_dir)
np.savez_compressed(fname, **kwargs, **fields)
np.savez_compressed(fname, **merge_dicts(kwargs, fields))
# Load the file and check the data appears as None, not np.array(None)
exp.load(fname)
for key, value in fields.items():
Expand All @@ -920,7 +928,7 @@ def test_load_scalar(self):
# Make a save file which contains values set to a scalar`
fname = os.path.join(self.output_dir, "dummy.npz")
os.makedirs(self.output_dir)
np.savez_compressed(fname, **kwargs, **fields)
np.savez_compressed(fname, **merge_dicts(kwargs, fields))
# Load the file and check the data appears correctly
exp.load(fname)
for key, value in fields.items():
Expand All @@ -939,7 +947,7 @@ def test_load_wrong_nRegions(self):
# Make a save file which contains values set badly
fname = os.path.join(self.output_dir, "dummy.npz")
os.makedirs(self.output_dir)
np.savez_compressed(fname, **kwargs, **fields)
np.savez_compressed(fname, **merge_dicts(kwargs, fields))
# Load the file and check the data is not loaded
with self.assertRaises(ValueError):
exp.load(fname)
Expand All @@ -951,7 +959,7 @@ def test_load_wrong_in_init(self):
# Make a save file which contains values set badly
fname = os.path.join(self.output_dir, "preparation.npz")
os.makedirs(self.output_dir)
np.savez_compressed(fname, **kwargs, **fields)
np.savez_compressed(fname, **merge_dicts(kwargs, fields))
# Load the file and check the data is not loaded
with self.assertRaises(ValueError):
core.Experiment(
Expand All @@ -975,7 +983,7 @@ def test_load_wrong_expansion(self):
# Make a save file which contains values set badly
fname = os.path.join(self.output_dir, "dummy.npz")
os.makedirs(self.output_dir)
np.savez_compressed(fname, **kwargs, **fields)
np.savez_compressed(fname, **merge_dicts(kwargs, fields))
# Load the file and check the data is not loaded
with self.assertRaises(ValueError):
exp.load(fname)
Expand All @@ -994,7 +1002,7 @@ def test_load_wrong_alpha_only_sep_results(self):
# Make a save file which contains values set badly
fname = os.path.join(self.output_dir, "dummy.npz")
os.makedirs(self.output_dir)
np.savez_compressed(fname, **kwargs, **fields)
np.savez_compressed(fname, **merge_dicts(kwargs, fields))
# Load the file and check the data is not loaded
with self.assertRaises(ValueError):
exp.load(fname)
Expand All @@ -1013,7 +1021,7 @@ def test_load_wrong_alpha_mixed_prep_sep(self):
# Make a save file which contains values set badly
fname = os.path.join(self.output_dir, "dummy.npz")
os.makedirs(self.output_dir)
np.savez_compressed(fname, **kwargs, **fields)
np.savez_compressed(fname, **merge_dicts(kwargs, fields))
# Load the file and check the data appears correctly
exp.load(fname)
self.assert_equal(exp.raw, fields["raw"])
Expand Down

0 comments on commit 29ef258

Please sign in to comment.