Skip to content

Commit

Permalink
Merge pull request #197 from hyanwong/set-site-times
Browse files Browse the repository at this point in the history
Allow sites_time to be set
  • Loading branch information
jeromekelleher committed Oct 25, 2019
2 parents 3c0f497 + e10305f commit 7442f6f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 10 deletions.
68 changes: 58 additions & 10 deletions tests/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,15 +701,17 @@ def test_error_not_edit_mode(self):
input_file.add_site(position=0, alleles=alleles, genotypes=genotypes)
editable_sample_data = input_file.copy()
# Try editing: use setter in the normal way
editable_sample_data.sites_inference = np.array([False])
editable_sample_data.sites_inference = [True]
editable_sample_data.sites_time = [0.0]
# Try editing: use setter via setattr
setattr(editable_sample_data, 'sites_inference', np.array([True]))
setattr(editable_sample_data, 'sites_inference', [False])
setattr(editable_sample_data, 'sites_time', [1.0])
editable_sample_data.add_provenance(datetime.datetime.now().isoformat(), {})

editable_sample_data.finalise()
self.assertRaises(
ValueError, setattr, editable_sample_data, 'sites_inference',
np.array([True]))
ValueError, setattr, editable_sample_data, 'sites_inference', [True])
self.assertRaises(
ValueError, setattr, editable_sample_data, 'sites_time', [0.0])
self.assertRaises(
ValueError, editable_sample_data.add_provenance,
datetime.datetime.now().isoformat(), {})
Expand All @@ -725,7 +727,7 @@ def test_copy_new_uuid(self):

@unittest.skipIf(sys.platform == "win32",
"windows simultaneous file permissions issue")
def test_copy_update_inference_sites(self):
def test_copy_update_sites_inference(self):
with formats.SampleData() as data:
for j in range(4):
data.add_site(position=j, alleles=["0", "1"], genotypes=[0, 1, 1, 0])
Expand All @@ -744,7 +746,28 @@ def test_copy_update_inference_sites(self):
self.assertEqual(list(copy.sites_inference), inference)
self.assertEqual(list(data.sites_inference), [True, True, True, True])

def test_update_inference_sites_bad_data(self):
@unittest.skipIf(sys.platform == "win32",
"windows simultaneous file permissions issue")
def test_copy_update_sites_time(self):
with formats.SampleData() as data:
for j in range(4):
data.add_site(position=j, alleles=["0", "1"], genotypes=[0, 1, 1, 0])
self.assertEqual(list(data.sites_time), [2.0, 2.0, 2.0, 2.0]) # Freq == 2.0

with tempfile.TemporaryDirectory(prefix="tsinf_format_test") as tempdir:
filename = os.path.join(tempdir, "samples.tmp")
for copy_path in [None, filename]:
copy = data.copy(path=copy_path)
copy.finalise()
self.assertTrue(copy.data_equal(data))
with data.copy(path=copy_path) as copy:
time = [0.0, 1.1, 2.2, 3.3]
copy.sites_time = time
self.assertFalse(copy.data_equal(data))
self.assertEqual(list(copy.sites_time), time)
self.assertEqual(list(data.sites_time), [2.0, 2.0, 2.0, 2.0])

def test_update_sites_inference_bad_data(self):
def set_value(data, value):
data.sites_inference = value

Expand All @@ -762,15 +785,40 @@ def set_value(data, value):
for a in bad_data:
self.assertRaises((ValueError, TypeError, OverflowError), set_value, copy, a)

def test_update_inference_sites_non_copy_mode(self):
def test_update_sites_time_bad_data(self):
def set_value(data, value):
data.sites_time = value

data = formats.SampleData()
for j in range(4):
data.add_site(position=j, alleles=["0", "1"], genotypes=[0, 1, 1, 0])
data.finalise()
self.assertEqual(list(data.sites_time), [2.0, 2.0, 2.0, 2.0])
copy = data.copy()
for bad_shape in [[], np.arange(100, dtype=np.float64), np.zeros((2, 2))]:
self.assertRaises((ValueError, TypeError), set_value, copy, bad_shape)
for bad_data in [["a", "b", "c", "d"]]:
self.assertRaises(ValueError, set_value, copy, bad_data)

def test_update_sites_inference_non_copy_mode(self):
def set_value(data, value):
data.sites_inference = value

data = formats.SampleData()
data.add_site(position=0, alleles=["0", "1"], genotypes=[0, 1, 1, 0])
self.assertRaises(ValueError, set_value, data, [])
self.assertRaises(ValueError, set_value, data, [True])
data.finalise()
self.assertRaises(ValueError, set_value, data, [True])

def test_update_sites_time_non_copy_mode(self):
def set_value(data, value):
data.sites_time = value

data = formats.SampleData()
data.add_site(position=0, alleles=["0", "1"], genotypes=[0, 1, 1, 0])
self.assertRaises(ValueError, set_value, data, [1.0])
data.finalise()
self.assertRaises(ValueError, set_value, data, [])
self.assertRaises(ValueError, set_value, data, [1.0])

@unittest.skipIf(sys.platform == "win32",
"windows simultaneous file permissions issue")
Expand Down
5 changes: 5 additions & 0 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,11 @@ def sites_position(self):
def sites_time(self):
return self.data["sites/time"]

@sites_time.setter
def sites_time(self, value):
self._check_edit_mode()
self.data["sites/time"][:] = np.array(value, dtype=np.float64, copy=False)

@property
def sites_alleles(self):
return self.data["sites/alleles"]
Expand Down

0 comments on commit 7442f6f

Please sign in to comment.