# Creating a `TargetingFlags` column for flat files

In this notebook we will query the SDSS database and create a FITS file with two columns: `SDSS_ID` and `SDSS5_TARGET_FLAGS`.
The `SDSS5_TARGET_FLAGS` column will hold an array of integers for each source, created using `semaphore`. 

Andy Casey (andrew.casey@monash.edu)

You can install `semaphore` with:

```pip install sdss-semaphore==0.2.3```

In [1]:
import numpy as np
from tqdm import tqdm

from sdssdb.peewee.sdss5db import database

database.set_profile("operations")

from sdssdb.peewee.sdss5db.targetdb import Assignment, CartonToTarget, Target
from sdssdb.peewee.sdss5db.catalogdb import CatalogdbModel

from sdss_semaphore.targeting import TargetingFlags

# At the time of writing, this peewee model did not exist in my version of sdssdb.
# In future you can probably just import directly: `from sdssdb.peewee.sdss5db.catalogdb import SDSS_ID_Flat`

class SDSS_ID_Flat(CatalogdbModel):
    """Peewee model for the `sdss_id_flat` table, declared locally as a stand-in.

    No fields are declared here, although the query in this notebook reads
    `sdss_id` and `catalogid` from this model; presumably `CatalogdbModel`
    supplies the columns (e.g. by reflecting the table) -- TODO confirm
    against the installed version of sdssdb.
    """
    class Meta:
        # Name of the database table this model maps onto.
        table_name = "sdss_id_flat"




If we knew in advance how many sources there would be, we could create a single `TargetingFlags` object for N sources and use the source identifier (e.g., SDSS_ID) to look up the index at which each bit should be set. But that lookup can be expensive for many objects, and we want to avoid pre-computing the number of sources.

Instead we will create a dictionary of `TargetingFlags` objects, keyed by the source identifier, and merge them together into one `TargetingFlags` object at the end.


# Assignments to `TargetingFlags`

In [2]:

# The unrestricted query is enormous, so cap it at 10,000 rows while testing.
q = (
    SDSS_ID_Flat
    .select(SDSS_ID_Flat.sdss_id, CartonToTarget.carton_pk)
    .join(Target, on=(SDSS_ID_Flat.catalogid == Target.catalogid))
    .join(CartonToTarget, on=(Target.pk == CartonToTarget.target_pk))
    .join(Assignment, on=(Assignment.carton_to_target_pk == CartonToTarget.pk))
    .tuples()
    .limit(10_000)
    .iterator()
)

# carton_pk -> set of the sdss_id values seen for that carton.
# Kept as an independent tally so the merged flags can be cross-checked later.
manual_counts = {}

# sdss_id -> a TargetingFlags object holding just that one source.
flags_dict = {}

# total=1 is passed so that tqdm does not execute the count() query itself.
for sdss_id, carton_pk in tqdm(q, total=1):
    if sdss_id not in flags_dict:
        flags_dict[sdss_id] = TargetingFlags()

    # Index 0 because each per-source TargetingFlags contains only this object.
    flags_dict[sdss_id].set_bit_by_carton_pk(0, carton_pk)
    manual_counts.setdefault(carton_pk, set()).add(sdss_id)

10000it [00:00, 18325.14it/s]        


In [3]:
# Flatten the per-source dictionary into two parallel sequences. Keys and values
# of a dict iterate in the same order, so entry i of one matches entry i of the other.
#   sdss_ids : every source identifier
#   flags    : a single TargetingFlags built from all of the per-source ones

sdss_ids = [*flags_dict]
flags = TargetingFlags([*flags_dict.values()])

In [7]:
# A sanity check: the per-carton counts derived from the merged flags must agree
# with the sets of SDSS_IDs we tallied by hand while iterating over the query.
flag_counts = flags.count_by_attribute("carton_pk", skip_empty=True)
for carton_pk, count in flag_counts.items():
    expected = len(manual_counts[carton_pk])
    assert count == expected, (
        f"carton_pk={carton_pk}: flags report {count} sources, manual tally has {expected}"
    )

# The loop above only visits cartons that the flags report. With skip_empty=True
# a carton whose bit was never set would presumably be omitted and so go
# unnoticed; check the other direction explicitly.
missing_cartons = set(manual_counts) - set(flag_counts)
assert not missing_cartons, (
    f"cartons seen in the query but absent from the flags: {sorted(missing_cartons)}"
)
    

In [8]:
# Now let's write out our columns to a FITS file.
from astropy.io import fits

# flags.array is two-dimensional: N rows (one per source) by F flag values per row.
N, F = flags.array.shape

# Both columns go into the same table, so they must have the same number of rows;
# fail loudly here rather than write a misaligned file.
assert len(sdss_ids) == N, (
    f"{len(sdss_ids)} SDSS_IDs but {N} rows of targeting flags"
)

hdul = fits.HDUList([
    fits.PrimaryHDU(),
    fits.BinTableHDU.from_columns([
        # "K" is the FITS binary-table code for a 64-bit integer.
        fits.Column(name="SDSS_ID", array=sdss_ids, format="K"),
        # "{F}B" is a fixed-length vector of F unsigned bytes per row.
        fits.Column(name="SDSS5_TARGET_FLAGS", array=flags.array, format=f"{F}B", dim=f"({F})")
    ])
])

# overwrite=True keeps this cell re-runnable: by default writeto() raises an
# error when the output file already exists (e.g. after Restart & Run All).
hdul.writeto("output.fits", overwrite=True)