Skip to content

Commit

Permalink
feat(RandomSubset): a task to select a random subset of a catalog
Browse files Browse the repository at this point in the history
  • Loading branch information
jrs65 committed Oct 16, 2020
1 parent 8c8a533 commit abdb5fd
Showing 1 changed file with 76 additions and 1 deletion.
77 changes: 76 additions & 1 deletion draco/analysis/sourcestack.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import numpy as np
from mpi4py import MPI

from caput import config
from caput import config, pipeline
from cora.util import units

from ..util.tools import invert_no_zero
from ..util.random import RandomTask
from ..core import task, containers

# Constants
Expand Down Expand Up @@ -147,3 +148,77 @@ def process(self, formed_beam):
self.log.info("Number of quasars stacked: {0}".format(full_qcount))

return qstack


class RandomSubset(task.SingleTask, RandomTask):
    """Take a large mock catalog and draw `number` catalogs of a given `size`.

    Each call to `process` draws a fresh random selection (without
    replacement) of objects from the catalog supplied to `setup`.

    Attributes
    ----------
    number : int
        Number of catalogs to construct.
    size : int
        Number of objects in each catalog.
    """

    number = config.Property(proptype=int)
    size = config.Property(proptype=int)

    def __init__(self):
        super().__init__()
        # Counts the catalogs produced so far; also used to build output tags.
        self.catalog_ind = 0

    def setup(self, catalog):
        """Set the full mock catalog.

        Parameters
        ----------
        catalog : containers.SourceCatalog
            The mock catalog to draw from.
        """
        self.catalog = catalog
        # Template for the output tags: "<input tag or 'mock'>_{}", where the
        # remaining placeholder is filled with the catalog index in `process`.
        self.base_tag = f'{catalog.attrs.get("tag", "mock")}_{{}}'

    def process(self):
        """Draw a new random catalog.

        Returns
        -------
        new_catalog : containers.SourceCatalog subclass
            A catalog of the same type as the input catalog, with a random set of
            objects.

        Raises
        ------
        pipeline.PipelineStopIteration
            Once `number` catalogs have been produced.
        ValueError
            If `size` is larger than the number of objects in the input catalog.
        """

        if self.catalog_ind >= self.number:
            raise pipeline.PipelineStopIteration

        objects = self.catalog.index_map["object_id"]
        num_cat = len(objects)

        # Validate on *every* rank before any collective call. Sampling without
        # replacement fails when size > num_cat, and if only rank=0 raised (inside
        # rng.choice) the remaining ranks would block forever in the Bcast below.
        if self.size > num_cat:
            raise ValueError(
                f"Cannot draw {self.size} objects without replacement from a "
                f"catalog containing only {num_cat} objects."
            )

        # NOTE: We need to be very careful here, the RNG is initialised at first access
        # and this is a collective operation. So we need to ensure all ranks do it even
        # though only rank=0 is going to use the RNG in this task
        rng = self.rng

        # Generate a random selection of objects on rank=0 and broadcast to all other
        # ranks
        if self.comm.rank == 0:
            ind = rng.choice(num_cat, self.size, replace=False)
            # Make sure the send buffer has the same dtype as the receive buffers
            # allocated on the other ranks.
            ind = ind.astype(np.int64, copy=False)
        else:
            ind = np.zeros(self.size, dtype=np.int64)
        self.comm.Bcast(ind, root=0)

        new_catalog = self.catalog.__class__(
            object_id=objects[ind], attrs_from=self.catalog
        )
        new_catalog.attrs["tag"] = self.base_tag.format(self.catalog_ind)

        # Loop over all datasets and if they have an object_id axis, select the
        # relevant objects along that axis
        for name, dset in new_catalog.datasets.items():
            if dset.attrs["axis"][0] == "object_id":
                dset[:] = self.catalog.datasets[name][ind]

        self.catalog_ind += 1

        return new_catalog

0 comments on commit abdb5fd

Please sign in to comment.