diff --git a/src/pykoala/data_container.py b/src/pykoala/data_container.py index b494d5bd..85a87c5f 100644 --- a/src/pykoala/data_container.py +++ b/src/pykoala/data_container.py @@ -2,8 +2,10 @@ This module contains the parent class that represents the data used during the reduction process """ +import numpy as np from astropy.io.fits import Header +from astropy.nddata import bitmask from pykoala.exceptions.exceptions import NoneAttrError @@ -248,7 +250,132 @@ class Parameter(object): """Class that represents some parameter and associated metadata""" def __init__(self) -> None: pass + +class GenericFlagMap(bitmask.BitFlagNameMap): + """PyKOALA generic BitMask flag map. + + Attributes + ---------- + - CR: Cosmic ray (2) + - HP: Hot pixel (4) + - DP: Dark pixel (8) + + Example + ------- + Accessing the description of each BitFlag is possible by + `GenericFlagMap.FLAG_NAME.__doc__` + """ + CR = 2, 'Cosmic ray' + HP = 4, 'Hot pixel' + DP = 8, 'Dark pixel' +class DataMask(object): + """This class contains mask information of DataContainers. + + Description + ----------- + This class represents a bitmask of a data structure defined by a given + dimensionality shape. + + Attributes + ---------- + - data: (np.ndarray) + The mask data values. + - flag_map: (astropy.bitmask.BitFlagNameMap) + A mapping between the flag names and their numerical values. + + Methods + ------- + - get_flags + - add_flags + - get_mask + """ + def __init__(self, shape, flag_map=GenericFlagMap): + # Initialise the mask with all pixels being valid + self.data = np.zeros(shape, dtype=int) + self.flag_map = flag_map + + def get_flags(self): + """Return the current flags used in the mask.""" + att = list(self.flag_map.__dict__) + return [a for a in att if "_" not in a] + + def get_flag_values(self): + """Get the values associated to the flags""" + flags = self.get_flags() + values = {} + for flag in flags: + values[flag] = getattr(self.flag_map, flag).real + return values + + def get_description(self): + """Get the description of the BitFlags.""" + flags = self.get_flags() + doc = {} + for flag in flags: + doc[flag] = getattr(self.flag_map, flag).__doc__ + return doc + + def add_flags(self, **extra_flags): + """Add a new flag to the mask.""" + self.flag_map = bitmask.extend_bit_flag_map( + "DQ_MASK", self.flag_map, **extra_flags) + + def get_mask(self, use_flags=None, use_bits=None, ignore_flags=None, + ignore_bits=None, dtype=bool): + """Get a mask. + + Description + ----------- + This method returns a mask computed by including or excluding a list + of bitmask values. + + Parameters + ---------- + - use_flags: (list, default=None) + A list of strings corresponding to the flag names to be included. + - use_bits: (list, default=None) + A list of bitmask values to be included in the mask. This is ignored + if `use_flags` is provided. + - ignore_flags: (list, default=None) + A list of strings corresponding to the flags to be excluded. Only + used if `use_flags` and `use_bits` are `None`. + - ignore_bits: (list, default=None) + A list of bitmask values to be excluded in the mask. This is ignored + if any of the other arguments is not `None`. + - dtype: (type, default=bool) + Data type of the output mask + Example + ------- + ``` + # Create a mask + mask = DataMask() + # Get a mask for all flagged pixels + total_mask = mask.get_mask() + # Get a mask only for pixels flagged as cosmic rays (CR) + cr_mask = mask.get_mask(use_flags=['CR']) + ``` + """ + flags = self.get_flag_values() + flag_names = list(flags.keys()) + flag_values = list(flags.values()) + if use_flags is not None: + flags_to_ignore = [f for f in flag_names if f not in use_flags] + elif use_bits is not None: + flags_to_ignore = [f for f in flag_names if flag_values[ + flag_names.index(f)] not in use_bits] + elif ignore_flags is not None: + flags_to_ignore = ignore_flags + elif ignore_bits is not None: + flags_to_ignore = [f for f in flag_names if flag_values[ + flag_names.index(f)] in ignore_bits] + else: + flags_to_ignore = [] + print(flags_to_ignore) + return bitmask.bitfield_to_boolean_mask(self.data, + ignore_flags=flags_to_ignore, + flag_name_map=self.flag_map, + dtype=dtype) class DataContainer(object): """