Skip to content

Commit

Permalink
Merge pull request #190 from hyanwong/sample-ages
Browse files Browse the repository at this point in the history
Support for adding times to individuals and samples
  • Loading branch information
jeromekelleher committed Sep 24, 2019
2 parents db19f5c + f566084 commit 919174c
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 8 deletions.
28 changes: 27 additions & 1 deletion tests/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,20 @@ def get_example_ts(self, sample_size, sequence_length):
sample_size, recombination_rate=1, mutation_rate=10,
length=sequence_length, random_seed=100)

def get_example_historical_sampled_ts(self, sample_times, sequence_length):
    """
    Simulate an example tree sequence with one sample per entry of
    ``sample_times``, each placed in population 0 at the given time.

    :param sample_times: Iterable of sampling times, one per sample.
    :param sequence_length: Value passed to ``msprime.simulate`` as ``length``.
    :return: Whatever ``msprime.simulate`` returns for these settings; the
        random seed is fixed at 100.
    """
    historical_samples = []
    for sample_time in sample_times:
        historical_samples.append(
            msprime.Sample(population=0, time=sample_time))
    return msprime.simulate(
        samples=historical_samples,
        recombination_rate=1,
        mutation_rate=10,
        length=sequence_length,
        random_seed=100)

def verify_data_round_trip(self, ts, input_file):
self.assertGreater(ts.num_sites, 1)
for pop in ts.populations():
input_file.add_population()
for sample in ts.samples():
node = ts.node(sample)
input_file.add_individual(ploidy=1, population=node.population)
input_file.add_individual(
ploidy=1, population=node.population, time=node.time)
for v in ts.variants():
age = None
if len(v.site.mutations) == 1:
Expand Down Expand Up @@ -134,6 +141,14 @@ def test_from_tree_sequence(self):
sd2 = formats.SampleData.from_tree_sequence(ts)
self.assertTrue(sd1.data_equal(sd2))

def test_from_historical_tree_sequence(self):
    """
    Compare a SampleData filled in by ``verify_data_round_trip`` against one
    built by ``SampleData.from_tree_sequence`` for a tree sequence whose
    samples were taken at two different times.
    """
    # Five samples taken at time 1 followed by five taken at time 0.
    sample_times = [1] * 5 + [0] * 5
    ts = self.get_example_historical_sampled_ts(sample_times, 10)
    manual = formats.SampleData(sequence_length=ts.sequence_length)
    self.verify_data_round_trip(ts, manual)
    converted = formats.SampleData.from_tree_sequence(ts)
    self.assertTrue(manual.data_equal(converted))

def test_chunk_size(self):
ts = self.get_example_ts(4, 2)
self.assertGreater(ts.num_sites, 50)
Expand Down Expand Up @@ -387,6 +402,15 @@ def test_individual_metadata(self):
self.assertEqual(sample_data.populations_metadata[0], {"a": 1})
self.assertEqual(sample_data.populations_metadata[1], {"b": 2})

def test_add_individual_time(self):
    """
    Check that ``individuals_time`` holds 0 for an individual added without
    a ``time`` argument and the supplied value for one added with it.
    """
    data = formats.SampleData(sequence_length=10)
    data.add_individual()
    data.add_individual(time=0.5)
    # One site with a genotype for each of the two samples added above.
    data.add_site(0, [0, 0])
    data.finalise()
    for individual_id, expected_time in enumerate((0, 0.5)):
        self.assertEqual(data.individuals_time[individual_id], expected_time)

def test_add_individual_return(self):
sample_data = formats.SampleData(sequence_length=10)
iid, sids = sample_data.add_individual()
Expand Down Expand Up @@ -430,6 +454,8 @@ def test_add_individual_errors(self):
self.assertRaises(ValueError, sample_data.add_individual, population=1)
self.assertRaises(ValueError, sample_data.add_individual, location="x234")
self.assertRaises(ValueError, sample_data.add_individual, ploidy=0)
self.assertRaises(ValueError, sample_data.add_individual, time=None)
self.assertRaises(ValueError, sample_data.add_individual, time=[1, 2])

def test_no_data(self):
sample_data = formats.SampleData(sequence_length=10)
Expand Down
33 changes: 26 additions & 7 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,7 @@ class Individual(object):
# TODO document properly.
id = attr.ib()
location = attr.ib()
time = attr.ib()
metadata = attr.ib()


Expand Down Expand Up @@ -726,8 +727,11 @@ def __init__(self, sequence_length=0, **kwargs):
location = individuals_group.create_dataset(
"location", shape=(0,), chunks=chunks, compressor=self._compressor,
dtype="array:f8")
time = individuals_group.create_dataset(
"time", shape=(0,), chunks=chunks, compressor=self._compressor,
dtype=np.float64)
self._individuals_writer = BufferedItemWriter(
{"metadata": metadata, "location": location},
{"metadata": metadata, "location": location, "time": time},
num_threads=self._num_flush_threads)

samples_group = self.data.create_group("samples")
Expand Down Expand Up @@ -819,6 +823,10 @@ def individuals_metadata(self):
def individuals_location(self):
return self.data["individual/location"]

@property
def individuals_time(self):
    """
    The per-individual time array, looked up under the ``individual/time``
    key of ``self.data``.
    """
    time_array = self.data["individual/time"]
    return time_array

@property
def samples_population(self):
    """
    The per-sample population array, looked up under the
    ``samples/population`` key of ``self.data``.
    """
    population_array = self.data["samples/population"]
    return population_array
Expand Down Expand Up @@ -874,6 +882,7 @@ def __str__(self):
("populations/metadata", zarr_summary(self.populations_metadata)),
("individuals/metadata", zarr_summary(self.individuals_metadata)),
("individuals/location", zarr_summary(self.individuals_location)),
("individuals/time", zarr_summary(self.individuals_time)),
("samples/individual", zarr_summary(self.samples_individual)),
("samples/population", zarr_summary(self.samples_population)),
("samples/metadata", zarr_summary(self.samples_metadata)),
Expand Down Expand Up @@ -908,6 +917,7 @@ def data_equal(self, other):
self.num_samples == other.num_samples and
self.num_sites == other.num_sites and
self.num_inference_sites == other.num_inference_sites and
np.all(self.individuals_time[:] == other.individuals_time[:]) and
np.all(self.samples_individual[:] == other.samples_individual[:]) and
np.all(self.samples_population[:] == other.samples_population[:]) and
np.all(self.sites_position[:] == other.sites_position[:]) and
Expand Down Expand Up @@ -941,7 +951,7 @@ def from_tree_sequence(cls, ts, use_times=True, **kwargs):
self.add_population()
for u in ts.samples():
node = ts.node(u)
self.add_individual(population=node.population, ploidy=1)
self.add_individual(population=node.population, time=node.time, ploidy=1)
for v in ts.variants():
age = None
if len(v.site.mutations) == 1 and use_times:
Expand Down Expand Up @@ -987,7 +997,8 @@ def add_population(self, metadata=None):
raise ValueError("Cannot add populations after adding samples or sites")
return self._populations_writer.add(metadata=self._check_metadata(metadata))

def add_individual(self, ploidy=1, metadata=None, population=None, location=None):
def add_individual(
self, ploidy=1, metadata=None, population=None, location=None, time=0):
"""
Adds a new :ref:`sec_inference_data_model_individual` to this
:class:`.SampleData` and returns its ID and those of the resulting additional
Expand All @@ -1012,6 +1023,9 @@ def add_individual(self, ploidy=1, metadata=None, population=None, location=None
:param arraylike location: An array-like object defining n-dimensional
spatial location of this individual. If not specified or None, the
empty location is stored.
:param float time: The historical time into the past when the samples
associated with this individual were taken. By default we assume that
all samples come from the present time (i.e. the default time is 0).
:return: The ID of the newly added individual and a list of the sample
IDs also added.
:rtype: tuple(int, list(int))
Expand All @@ -1024,6 +1038,9 @@ def add_individual(self, ploidy=1, metadata=None, population=None, location=None
if self._build_state != self.ADDING_SAMPLES:
raise ValueError("Cannot add individuals after adding sites")

time = np.float64(time).item()
if not np.isfinite(time):
raise ValueError("time must be a single finite number")
if population is None:
population = tskit.NULL
if population >= self.num_populations:
Expand All @@ -1034,7 +1051,7 @@ def add_individual(self, ploidy=1, metadata=None, population=None, location=None
location = []
location = np.array(location, dtype=np.float64)
individual_id = self._individuals_writer.add(
metadata=self._check_metadata(metadata), location=location)
metadata=self._check_metadata(metadata), location=location, time=time)
sample_ids = []
for _ in range(ploidy):
# For now default the metadata to the empty dict.
Expand Down Expand Up @@ -1263,9 +1280,11 @@ def haplotypes(self, samples=None, inference_sites=None):

def individuals(self):
# TODO document
# NOTE(review): this span comes from a rendered diff, so what look like the
# pre-change and post-change bodies appear back to back without +/- markers.
# The next three code lines are presumably the removed version (they build
# Individual without ``time``) - confirm against the repository.
iterator = zip(self.individuals_location[:], self.individuals_metadata[:])
for j, (location, metadata) in enumerate(iterator):
yield Individual(j, location=location, metadata=metadata)
# Presumably the replacement version: walks the location, metadata and time
# arrays in lockstep and yields one Individual per index j, now including
# the individual's time.
iterator = zip(
self.individuals_location[:], self.individuals_metadata[:],
self.individuals_time[:])
for j, (location, metadata, time) in enumerate(iterator):
yield Individual(j, location=location, metadata=metadata, time=time)


@attr.s
Expand Down

0 comments on commit 919174c

Please sign in to comment.