# The Survey Pipeline

If you have a galaxy catalog (either of `Parametric` origin or from a simulation), an [`EmissionModel`](../emission_models/emission_models.rst), and a set of [instruments](../instrumentation/instrument_example.ipynb) you want observables for, you can write a pipeline to generate the observations you want using the Synthesizer UI. However, let's say you have a new catalog you want to run the same analysis on, or a whole different set of instruments you want to use. You could modify your old pipeline or write a whole new one, but that's a lot of work and boilerplate.

This is where the `Survey` shines. Instead of writing a pipeline yourself, you can use the `Survey` class, a high-level interface that generates observations for a given catalog, emission model, and set of instruments. All you need to do is define a galaxy loader, set up the ``Survey`` object, and run the observable methods you want to include. Possible observables include:

- Spectra.
- Emission Lines.
- Photometry.
- Images (with or without PSF convolution/noise).
- Spectral data cubes (IFUs) [WIP].
- Instrument-specific spectroscopy [WIP].

The ``Survey`` will generate all the requested observations for all (compatible) instruments and galaxies, before writing them out to a standardised HDF5 format.

As a bonus, the abstraction into the `Survey` class allows for easy parallelisation of the analysis, not only over local threads but optionally over MPI ranks.

In the following sections we will show how to instantiate and use a ``Survey`` object to generate observations for a given catalog, emission model, and set of instruments.

## Setting up a ``Survey`` object

Before we instantiate a survey we need to define its "dependencies": a method to load a galaxy catalog, an emission model, and a set of instruments.

### Defining a galaxy loader

The galaxy loader function can be distributed across a number of threads (and MPI ranks, which we'll cover in more detail below). To ensure the galaxy loader works and is parallelisable, it must adhere to a set of rules:

- It must return a single ``Galaxy`` object or ``None``. The latter of these is to handle any galaxies which failed to be loaded. These will be sanitised out of the catalog before any analysis is run.
- Its first argument must be the galaxy's "index" in the catalog. This argument must be called "gal_index", since this is how we check the function is compatible under the hood. For instance, if you are loading a ``Galaxy`` from an HDF5 file, this index should be the galaxy's index into the file.
- It can take any number of additional arguments and keyword arguments.
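The `gal_index` naming rule is the kind of thing that can be verified with Python's standard `inspect` module. The sketch below is a hypothetical check (the helper `check_loader_signature` is ours, and `Survey`'s actual internal check may differ) that confirms a loader's first parameter is called `gal_index`:

```python
import inspect


def check_loader_signature(func):
    """Hypothetical check that a loader's first parameter is ``gal_index``.

    This is an illustrative sketch, not the Survey class's own validation.
    """
    params = list(inspect.signature(func).parameters)
    if not params or params[0] != "gal_index":
        raise ValueError(
            f"{func.__name__} must take 'gal_index' as its first argument, "
            f"got {params[:1] or 'no arguments'}"
        )
    return True


def good_loader(gal_index, log10age, metallicity):
    return None


def bad_loader(index, log10age):
    return None


print(check_loader_signature(good_loader))  # True
try:
    check_loader_signature(bad_loader)
except ValueError as err:
    print(err)
```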

Below we define a fake galaxy loader for illustrative purposes. This function generates a particle-based galaxy from a parametric star formation and metallicity history.

```python
import numpy as np
from unyt import Msun, Mpc, Myr

from synthesizer.particle.stars import sample_sfhz
from synthesizer.parametric.stars import Stars as ParametricStars
from synthesizer.parametric import SFH, ZDist
from synthesizer import Galaxy


def galaxy_loader(gal_index, log10age, metallicity):
    """
    Load a fake particle Galaxy.

    Args:
        gal_index (int):
            The index of the galaxy to load. (Here, this is unused
            but must be included.)
        log10age (array_like):
            The log10(age) axis of the SFZH.
        metallicity (array_like):
            The metallicity axis of the SFZH.

    Returns:
        Galaxy or None:
            The loaded galaxy, or None if it could not be created.
    """
    # Initialise the parametric Stars object
    param_stars = ParametricStars(
        log10age,
        metallicity,
        sf_hist=SFH.Constant(max_age=100 * Myr),
        metal_dist=ZDist.DeltaConstant(metallicity=0.01),
        initial_mass=10**10 * Msun,
    )

    # Define the number of stellar particles we want (between 25 and 74)
    n = int(50 * (np.random.rand() + 0.5))

    # Sample the parametric SFZH, producing a particle Stars object.
    # We also pass some keyword arguments for some example attributes
    part_stars = sample_sfhz(
        sfzh=param_stars.sfzh,
        log10ages=param_stars.log10ages,
        log10metallicities=param_stars.log10metallicities,
        nstar=n,
        current_masses=np.full(n, 10**9 / n) * Msun,
        redshift=1,
        coordinates=np.random.normal(0, 0.01, (n, 3)) * Mpc,
        centre=np.zeros(3) * Mpc,
        smoothing_lengths=0.01 * np.random.rand(n) * Mpc,
    )

    # If we got a bad galaxy just return None
    if part_stars is None:
        return None

    # And create the galaxy
    gal = Galaxy(
        stars=part_stars,
        redshift=1,
    )

    return gal
```

This way of defining a loader leaves the definition of a ``Galaxy`` entirely in the user's hands. You are free to add whatever attributes you see fit, and can load data from any source you desire.

Notice that our loader has three arguments: the required ``gal_index``, followed by the two ``Grid`` axes (``log10age`` and ``metallicity``) which the loader needs to define the SFZH. We'll provide these extra arguments later on, when we load the galaxies.
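To see how extra loader arguments and `None` returns fit together, here is a simplified, hypothetical sketch of mapping a loader over galaxy indices on a thread pool, forwarding the extra keyword arguments and sanitising out failed loads afterwards. It uses only the standard library's `concurrent.futures` and a toy loader (`toy_loader` and `load_all` are made-up names); it is not the `Survey` implementation.

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial


def toy_loader(gal_index, scale, fail_every=3):
    """Toy stand-in for a galaxy loader: returns None for 'failed' galaxies."""
    if gal_index % fail_every == 0:
        return None  # simulate a galaxy that could not be loaded
    return {"index": gal_index, "mass": scale * gal_index}


def load_all(loader, n_galaxies, nthreads=1, **loader_kwargs):
    """Call loader(gal_index, **loader_kwargs) for every index, dropping Nones."""
    # Bind the extra arguments so the pool only needs to pass gal_index
    func = partial(loader, **loader_kwargs)
    with ThreadPoolExecutor(max_workers=nthreads) as pool:
        results = list(pool.map(func, range(n_galaxies)))
    # Sanitise out any galaxies that failed to load
    return [gal for gal in results if gal is not None]


galaxies = load_all(toy_loader, n_galaxies=10, nthreads=2, scale=1e9)
print([gal["index"] for gal in galaxies])  # [1, 2, 4, 5, 7, 8]
```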

### Defining an emission model

The ``EmissionModel`` defines the emissions we'll generate, including the origin and any reprocessing the emission undergoes. For more details see the ``EmissionModel`` [docs](../emission_models/emission_models.rst). 

For demonstration, we'll use a simple premade ``IntrinsicEmission`` model which defines the intrinsic stellar emission (i.e. stellar emission without any ISM dust reprocessing).

```python
from synthesizer.emission_models import IntrinsicEmission
from synthesizer.grid import Grid

# Get the grid
grid_dir = "../../../tests/test_grid/"
grid_name = "test_grid"
grid = Grid(grid_name, grid_dir=grid_dir)

model = IntrinsicEmission(grid, fesc=0.1)
model.set_per_particle(True)  # we want per particle emissions
```

### Defining the instruments

We don't need any instruments if all we want is spectra at the ``Grid``'s native resolution or emission lines. However, to get anything more sophisticated we need ``Instrument`` objects that define the technical specifications of the observations we want to generate. For a full breakdown see the instrumentation [docs](../instrumentation/instrument_example.ipynb).

Here we'll define a simple set of instruments: a subset of NIRCam filters (capable of imaging at 0.1 kpc resolution) and a set of UVJ top-hat filters (only capable of photometry).

```python
from unyt import angstrom, kpc
from synthesizer.instruments import FilterCollection, UVJ
from synthesizer.instruments import Instrument

# Get the filters, interpolated onto a common wavelength grid
lam = np.linspace(10**3, 10**5, 1000) * angstrom
webb_filters = FilterCollection(
    filter_codes=[
        f"JWST/NIRCam.{f}"
        for f in ["F090W", "F150W", "F200W", "F277W", "F356W", "F444W"]
    ],
    new_lam=lam,
)
uvj_filters = UVJ(new_lam=lam)

# Instantiate the instruments
webb_inst = Instrument("JWST", filters=webb_filters, resolution=0.1 * kpc)
uvj_inst = Instrument("UVJ", filters=uvj_filters)
instruments = webb_inst + uvj_inst

print(instruments)
```

### Instantiating the ``Survey`` object

Now we have all the ingredients we need to instantiate a ``Survey`` object. We pass them in alongside the total number of galaxies in the catalog and the number of threads to use during the analysis (in this notebook we'll use only 1, since we only have a small handful of galaxies).

```python
from synthesizer.survey.survey import Survey

survey = Survey(
    gal_loader_func=galaxy_loader,
    emission_model=model,
    n_galaxies=10,
    instruments=instruments,
    nthreads=1,
    verbose=1,
)
```

Notice that we got a log from the ``Survey`` object detailing the basic setup. The ``Survey`` automatically outputs logging information to the console, but this can be suppressed by passing ``verbose=0``, which limits the output to saying hello, goodbye, and any errors that occur.

## Adding analysis functions

We could just run the analysis now and get whatever predefined outputs we want. However, we can also add our own analysis functions to the ``Survey`` object. These functions will be run on each galaxy in the catalog and can be used to generate any additional outputs we want. Importantly, these functions will be run **after** all other analysis has finished so they can make use of any outputs generated by the ``Survey`` object.

Below we'll define an analysis function to compute the half-light radius of each galaxy for each photometric band we've defined in our instruments. Any extra analysis function must obey the following rules:

- It must calculate the "result" for a single galaxy at a time.
- The function's first argument must be the galaxy to calculate for.
- It must store the result in an attribute on the galaxy object. This attribute can be anything as long as it doesn't clash with any existing attributes. It can either be a dictionary or a single value. The code will automatically parse these before writing the output to the HDF5 file.
- It can take any number of additional arguments and keyword arguments, but **beware**, adding large objects to the function signature will slow down the threadpools due to the need to serialise and deserialise these objects.
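The warning about large arguments comes down to serialisation: when work is shipped to worker processes or between MPI ranks, Python objects are typically serialised with `pickle`. Assuming that is the mechanism at play here, a quick way to gauge the cost of an argument is to measure its pickled size. The sketch below uses plain NumPy arrays as stand-ins for small and large objects:

```python
import pickle

import numpy as np

# A small argument versus a large one (e.g. a big array in the signature)
small_arg = np.arange(10, dtype=np.float64)
large_arg = np.zeros(1_000_000, dtype=np.float64)  # ~8 MB of raw data

for name, arg in [("small", small_arg), ("large", large_arg)]:
    nbytes = len(pickle.dumps(arg))
    print(f"{name}: {nbytes / 1e6:.3f} MB per pickled copy")
```

Every galaxy-level call that receives the large argument may pay this cost again, which is why keeping the signature lean matters.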

```python
def get_half_light_radius(gal):
    """
    Compute the half light radius of a galaxy in each filter.

    Args:
        gal (Galaxy):
            The galaxy to compute the half light radius of.
    """
    # Setup where we'll store the result
    gal.stars.half_flux_radii = {}

    # Loop over the emission types (spectra labels) with photometry
    for spec in gal.stars.photo_fnu.keys():
        gal.stars.half_flux_radii[spec] = {}
        # Loop over filters
        for filt in gal.stars.photo_fnu[spec].keys():
            # Get the half light radius for this emission type and filter
            gal.stars.half_flux_radii[spec][filt] = (
                gal.stars.get_half_flux_radius(spec, filt)
            )
```
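For intuition about what a half-flux radius measures, here is a minimal NumPy sketch for a set of particles: sort them by distance from the centre and find the radius enclosing half of the total flux. The helper `half_flux_radius` is hypothetical, and synthesizer's own `get_half_flux_radius` may differ in its details (for example, in how it interpolates between particles).

```python
import numpy as np


def half_flux_radius(coords, fluxes, centre=None):
    """Radius enclosing half of the total flux (hypothetical helper).

    Args:
        coords: (N, 3) array of particle positions.
        fluxes: (N,) array of particle fluxes in one filter.
        centre: (3,) array, defaults to the origin.
    """
    if centre is None:
        centre = np.zeros(3)
    radii = np.linalg.norm(coords - centre, axis=1)

    # Sort by radius and accumulate the flux outwards
    order = np.argsort(radii)
    cumulative = np.cumsum(fluxes[order])

    # First particle at which the enclosed flux reaches half the total
    idx = np.searchsorted(cumulative, 0.5 * cumulative[-1])
    return radii[order][idx]


# Four equal-flux particles at radii 1, 2, 3 and 4
coords = np.array([[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0], [4.0, 0, 0]])
fluxes = np.ones(4)
print(half_flux_radius(coords, fluxes))  # 2.0
```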


To add this to the ``Survey`` we pass it along with a string defining the attribute it stores its results in, and any arguments it requires (here there are none).

```python
survey.add_analysis_func(get_half_light_radius, result_attribute="stars.half_flux_radii")
```
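The `result_attribute` string reads as a dotted path relative to the galaxy, so `"stars.half_flux_radii"` points at `gal.stars.half_flux_radii`. A hypothetical sketch of resolving such a path with `getattr` and `functools.reduce`, using `SimpleNamespace` as a stand-in galaxy with made-up values, looks like this (how `Survey` resolves it internally is an assumption):

```python
from functools import reduce
from types import SimpleNamespace


def get_nested_attr(obj, path):
    """Resolve a dotted attribute path such as 'stars.half_flux_radii'."""
    return reduce(getattr, path.split("."), obj)


# A stand-in "galaxy" with a stars component holding an analysis result
gal = SimpleNamespace(
    stars=SimpleNamespace(half_flux_radii={"intrinsic": {"F150W": 1.2}})
)
print(get_nested_attr(gal, "stars.half_flux_radii"))
```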

## Running the pipeline

To run the pipeline we just need to first load our galaxies and then call the various observable generation methods. This approach allows you to explicitly control which observables you want to generate with a single line of code for each.

### Loading the galaxies

First we load the galaxies. Recall that our loader function requires the ``Grid`` axes (``log10age`` and ``metallicity``) as extra arguments; we pass those now.

```python
survey.load_galaxies(log10age=grid.log10ages, metallicity=grid.metallicity)
```

### Generating the observables

Now that we have the galaxies, we can generate their observables by calling the various observable generation methods on the ``Survey`` object. These automatically use the number of threads we defined when we instantiated the ``Survey`` object; if this is 1, everything is done in serial.

There is a required order to calling the observable methods. For instance, you can't generate photometry without first generating the spectra. Each method checks that its dependencies have been satisfied and raises an error if they have not.
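One way to picture these checks is that each step records that it has run and refuses to proceed until its prerequisites have. The toy class below is entirely hypothetical (it is not the `Survey` internals) and illustrates the idea for loading, spectra, and photometry:

```python
class ToyPipeline:
    """Minimal illustration of methods that check their prerequisites."""

    # Which steps must have run before each step (hypothetical ordering)
    requires = {
        "load_galaxies": [],
        "get_spectra": ["load_galaxies"],
        "get_photometry_luminosities": ["get_spectra"],
    }

    def __init__(self):
        self.completed = set()

    def run(self, step):
        missing = [dep for dep in self.requires[step] if dep not in self.completed]
        if missing:
            raise RuntimeError(f"Cannot run {step}: first call {', '.join(missing)}")
        self.completed.add(step)


pipe = ToyPipeline()
pipe.run("load_galaxies")
try:
    pipe.run("get_photometry_luminosities")  # spectra not generated yet
except RuntimeError as err:
    print(err)
pipe.run("get_spectra")
pipe.run("get_photometry_luminosities")  # now succeeds
```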

We'll start with the spectra. If we want fluxes, we'll need to pass an ``astropy.cosmology`` object.

```python
from astropy.cosmology import Planck18 as cosmo

survey.get_spectra(cosmo=cosmo)
```
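The cosmology is needed because turning a luminosity into an observed flux requires the luminosity distance $D_L$ at the galaxy's redshift. For a spectral luminosity per unit frequency, the standard relation is $f_\nu(\nu_{\rm obs}) = (1 + z)\,L_\nu(\nu_{\rm emit}) / (4\pi D_L^2)$. The minimal sketch below takes $D_L$ in centimetres directly (in practice astropy's `cosmo.luminosity_distance(z)` supplies it); the helper name `lnu_to_fnu` and the numbers are purely illustrative.

```python
import math


def lnu_to_fnu(lnu, redshift, lum_dist_cm):
    """Convert L_nu [erg/s/Hz] to observed f_nu [erg/s/cm^2/Hz].

    Uses f_nu = (1 + z) * L_nu / (4 * pi * D_L**2), with D_L in cm.
    """
    return (1.0 + redshift) * lnu / (4.0 * math.pi * lum_dist_cm**2)


# Illustrative numbers only: D_L = 1e28 cm, L_nu = 1e30 erg/s/Hz at z = 1
fnu = lnu_to_fnu(1e30, redshift=1.0, lum_dist_cm=1e28)
print(f"{fnu:.3e} erg/s/cm^2/Hz")
```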

Next we'll generate the emission lines. We can pass exactly which emission lines to generate by line ID; here we'll just generate all lines offered by the ``Grid``.

```python
survey.get_lines(line_ids=grid.available_lines)
```

Next, the photometry. This requires no extra inputs, but there are separate methods for luminosities and fluxes (the latter requires that an ``astropy.cosmology`` object was passed when the spectra were generated).

```python
survey.get_photometry_luminosities()
survey.get_photometry_fluxes()
```
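Photometry comes from integrating a spectrum against each filter's transmission curve. A common convention for $f_\nu$ band fluxes is the transmission-weighted mean $\langle f_\nu \rangle = \int f_\nu T(\lambda)\,d\lambda/\lambda \,/\, \int T(\lambda)\,d\lambda/\lambda$. The minimal sketch below applies it to a top-hat filter like the UVJ ones; the band edges and the helper `band_fnu` are made up, and synthesizer's own integration may use a different convention.

```python
import numpy as np


def trapezoid(y, x):
    """Trapezoidal integral of y(x), kept local to avoid NumPy version differences."""
    return np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x))


def band_fnu(lam, fnu, transmission):
    """Transmission-weighted mean f_nu: int f_nu T dlam/lam / int T dlam/lam."""
    return trapezoid(fnu * transmission / lam, lam) / trapezoid(transmission / lam, lam)


# Wavelength grid in Angstrom and a top-hat filter between 5000 and 6000 A
lam = np.linspace(1e3, 1e5, 10_000)
top_hat = ((lam >= 5000) & (lam <= 6000)).astype(float)

flat = np.full_like(lam, 3.0)  # a flat f_nu spectrum (arbitrary units)
print(band_fnu(lam, flat, top_hat))  # recovers the flat value, ~3.0
```

Because the filter is zero outside its band, flux outside 5000 to 6000 Å does not affect the result.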

Finally, we'll generate the images. Again, these are split into luminosity and flux flavours. Here we define our field of view and pass it into each method. We are also doing "smoothed" imaging, where each particle is smoothed over its SPH kernel. For this style of image generation we need to pass the kernel array, which we extract here.

Had we defined instruments with PSFs and/or noise these methods would automatically generate images with these effects/contributions included.

```python
from synthesizer.kernel_functions import Kernel

# Get the SPH kernel array used to smooth each particle
sph_kernel = Kernel()
kernel = sph_kernel.get_kernel()

survey.get_images_luminosity(fov=50 * kpc, kernel=kernel)
survey.get_images_flux(fov=50 * kpc, kernel=kernel)
```
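It can help to estimate how large the images will be. Assuming each image is a square whose side is the field of view, sampled at the instrument's resolution, the JWST instrument defined above (0.1 kpc resolution) with `fov=50 kpc` gives about 500 pixels per side. The UVJ instrument has no resolution, which is why it only contributes photometry. The arithmetic, as a hypothetical helper whose rounding may differ from synthesizer's:

```python
import math


def image_npix(fov_kpc, resolution_kpc):
    """Pixels per side of a square image (hypothetical; rounding may differ)."""
    # Round first to absorb floating point noise, then round up to whole pixels
    return math.ceil(round(fov_kpc / resolution_kpc, 6))


npix = image_npix(50.0, 0.1)
print(npix, "x", npix, "pixels")  # 500 x 500 pixels
```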

## Writing out the data

Finally, we write out the data to an HDF5 file. This file will contain all the observables we generated, as well as the results of any additional analysis we ran. The file is structured to mirror the structure of Synthesizer objects, with each galaxy being a group, each component being a subgroup, and each observable being a dataset (or, in the case of a dictionary attribute, a set of subgroups with the observables as datasets at their leaves).
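To picture this layout, the sketch below flattens a hypothetical galaxy's outputs into the slash-separated paths an HDF5 file would contain, with dictionary attributes becoming nested groups. The group and key names (`Galaxy_0`, `Stars`, `F150W`) are illustrative placeholders, not the exact names `write` uses; inspecting the real file with `h5py` will show those.

```python
def flatten_to_paths(tree, prefix=""):
    """Map a nested dict to {"/group/subgroup/dataset": value} paths."""
    paths = {}
    for key, value in tree.items():
        path = f"{prefix}/{key}"
        if isinstance(value, dict):
            # Dictionaries become groups; recurse into their contents
            paths.update(flatten_to_paths(value, path))
        else:
            # Leaves become datasets
            paths[path] = value
    return paths


# Hypothetical outputs for one galaxy (values are placeholders)
outputs = {
    "Galaxy_0": {
        "Stars": {
            "half_flux_radii": {"intrinsic": {"F150W": 1.2, "F444W": 1.5}},
        },
    },
}
for path, value in flatten_to_paths(outputs).items():
    print(path, "->", value)
```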

To write out the data, we just pass the path of the output file to the ``write`` method.

```python
survey.write("output.hdf5")
```


## Putting it all together

Here is what the pipeline would look like without all the descriptive fluff...

```python
survey = Survey(galaxy_loader, model, 10, instruments)
survey.add_analysis_func(get_half_light_radius, "stars.half_flux_radii")
survey.load_galaxies(log10age=grid.log10ages, metallicity=grid.metallicity)
survey.get_spectra(cosmo=cosmo)
survey.get_lines(line_ids=grid.available_lines)
survey.get_photometry_luminosities()
survey.get_photometry_fluxes()
survey.get_images_luminosity(fov=50 * kpc, kernel=kernel)
survey.get_images_flux(fov=50 * kpc, kernel=kernel)
survey.write("output.hdf5")
```

## Hybrid parallelism with MPI

Above we demonstrated how to run a pipeline using only local shared memory parallelism. We can also use `mpi4py` to not only use the shared memory parallelism but also distribute the analysis across multiple nodes (hence "hybrid parallelism"). 

To make use of MPI we only need to make a couple of changes to how we run the pipeline. The first is that we need to pass the ``comm`` object to the ``Survey`` object when we instantiate it.

```python
from mpi4py import MPI

survey = Survey(
    gal_loader_func=galaxy_loader,  
    emission_model=model, 
    n_galaxies=10, 
    instruments=instruments, 
    nthreads=4, 
    verbose=1,
    comm=MPI.COMM_WORLD,
)
```

Note that ``verbose=1`` means only rank 0 will output logging information. If you want all ranks to output logging information, set ``verbose=2``.
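The verbosity levels described so far can be summarised as a small decision function. This is our hypothetical reading of the documented behaviour (`should_log` is a made-up name), not code from `Survey`:

```python
def should_log(verbose, rank):
    """Hypothetical reading of the verbosity levels described above.

    verbose=0: routine logging suppressed (hello, goodbye and errors still shown)
    verbose=1: only rank 0 logs
    verbose=2: every rank logs
    """
    if verbose <= 0:
        return False
    if verbose == 1:
        return rank == 0
    return True


for rank in range(3):
    print(rank, [should_log(v, rank) for v in (0, 1, 2)])
```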

The only other thing we need to do is partition the galaxies **before** we load them.

Below we will spoof an MPI enabled ``Survey`` object to demonstrate this (we can't actually run MPI in a notebook).

```python
# Make a survey to demo partitioning
survey = Survey(
    gal_loader_func=galaxy_loader,
    emission_model=model,
    n_galaxies=10,
    instruments=instruments,
    nthreads=4,
    verbose=1,
)

# Fake the MPI ranks (you can ignore this)
survey.using_mpi = True
survey.rank = 0
survey.size = 4
```

To partition the galaxies we simply call the ``partition_galaxies`` method on the ``Survey`` object. This will split the galaxies evenly across all ranks.

```python
survey.partition_galaxies()
```
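An even split hands each rank a near-equal chunk of the galaxy indices. A sketch of that idea with `np.array_split`, for the 10 galaxies and 4 fake ranks above (contiguous chunks are our assumption; `partition_galaxies` may assign indices differently):

```python
import numpy as np

n_galaxies, size = 10, 4

# Contiguous, near-equal chunks of galaxy indices, one per rank
chunks = np.array_split(np.arange(n_galaxies), size)
for rank, idx in enumerate(chunks):
    print(f"rank {rank}: {idx.tolist()}")
```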

However, this will almost certainly result in a poor balance of work, since galaxies can have extremely different computational costs. To account for this, we can optionally pass an array of galaxy weights, e.g. the galaxy masses, the particle counts, or any other meaningful cost metric (here we're faking an imbalanced galaxy catalog).

```python
survey.partition_galaxies(galaxy_weights=np.logspace(1, 4, 10))
```
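With weights available, a common heuristic is greedy "longest processing time first" scheduling: sort galaxies by decreasing weight and give each one to the currently least-loaded rank. The sketch below implements that heuristic on the same `np.logspace(1, 4, 10)` weights; it is one reasonable choice, not necessarily what `partition_galaxies` does internally, and `greedy_partition` is a made-up name.

```python
import heapq

import numpy as np


def greedy_partition(weights, size):
    """Assign galaxy indices to ranks, heaviest first, to the least-loaded rank."""
    # Heap of (current load, rank) so the least-loaded rank is always on top
    loads = [(0.0, rank) for rank in range(size)]
    heapq.heapify(loads)
    assignment = [[] for _ in range(size)]
    for idx in np.argsort(weights)[::-1]:
        load, rank = heapq.heappop(loads)
        assignment[rank].append(int(idx))
        heapq.heappush(loads, (load + float(weights[idx]), rank))
    return assignment


weights = np.logspace(1, 4, 10)
parts = greedy_partition(weights, size=4)
for rank, idx in enumerate(parts):
    print(f"rank {rank}: {idx}  total weight {weights[idx].sum():.0f}")
```

Unlike the even split, each rank's indices are no longer contiguous, but the total weights per rank end up far closer together.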

And that's it, nothing else needs changing. You can now proceed with the process detailed above and run a pipeline utilising local threads and distributed nodes.

One final note: if you have a parallel-enabled ``h5py`` installed, the ``write`` method will automatically write out the data in parallel across all ranks. This can result in a significant speed-up when writing out large datasets. Otherwise, the writing will be done in serial on rank 0 after collecting the arrays from all other ranks.