# Parallel Computation

This notebook describes how to use LightCurveLynx to perform parallel computation. 

The core simulation function, `simulate_lightcurves`, accepts a `concurrent.futures.Executor` object and uses it to distribute the computation over multiple workers. The executor can be one of the standard library's built-in implementations, such as `ThreadPoolExecutor` or `ProcessPoolExecutor`, or a compatible object from another library, such as Dask.

Each process loads its own full copy of all the data, so parallel runs can be memory intensive.
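
To make the pattern concrete, here is a minimal, library-independent sketch of how work can be split into batches and handed to any `concurrent.futures.Executor`. The names `simulate_batch` and `run_in_batches` are hypothetical stand-ins for illustration only. They are not part of LightCurveLynx and do not claim to reflect how `simulate_lightcurves` is implemented internally.

```python
import concurrent.futures


def simulate_batch(start, count):
    """Hypothetical stand-in for simulating `count` samples starting at index `start`."""
    return [i * i for i in range(start, start + count)]


def run_in_batches(executor, num_samples, batch_size):
    """Split the samples into batches, submit each batch, and merge the results."""
    futures = []
    for start in range(0, num_samples, batch_size):
        # The last batch may be smaller than batch_size.
        count = min(batch_size, num_samples - start)
        futures.append(executor.submit(simulate_batch, start, count))

    # Collect in submission order so the output order matches the sample order.
    results = []
    for future in futures:
        results.extend(future.result())
    return results


# Any Executor works here; a thread pool keeps the sketch lightweight.
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    demo_results = run_in_batches(executor, num_samples=10, batch_size=4)

print(len(demo_results))
```

Because the function only relies on `executor.submit`, swapping in a `ProcessPoolExecutor` requires no other change, which is the same property the simulation function takes advantage of.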

In [None]:
from lightcurvelynx.astro_utils.passbands import PassbandGroup
from lightcurvelynx.models.basic_models import ConstantSEDModel
from lightcurvelynx.obstable.opsim import OpSim
from lightcurvelynx.simulate import simulate_lightcurves

# Usually we would not hardcode the path to the passband files, but for this demo we will use a relative path
# to the test data directory so that we do not have to download the files.
table_dir = "../../tests/lightcurvelynx/data/passbands"

## Prerequisite Data

We start by loading the standard information that we need for any simulation:

  * An `ObsTable` that includes the survey’s pointing and noise information.
  * A `PassbandGroup` for that survey.

Here we create a toy survey with pointings at two locations, (0.0, 10.0) and (180.0, -10.0), in the "g" and "r" bands, and then load the passband group.

In [None]:
obsdata1 = {
    "time": [0.0, 1.0, 2.0, 3.0],
    "ra": [0.0, 0.0, 180.0, 180.0],
    "dec": [10.0, 10.0, -10.0, -10.0],
    "filter": ["g", "r", "g", "r"],
    "zp": [5.0, 6.0, 7.0, 8.0],
    "seeing": [1.12, 1.12, 1.12, 1.12],
    "skybrightness": [20.0, 20.0, 20.0, 20.0],
    "exptime": [29.2, 29.2, 29.2, 29.2],
    "nexposure": [2, 2, 2, 2],
}
obstable1 = OpSim(obsdata1)

passband_group1 = PassbandGroup.from_preset(
    preset="LSST",
    table_dir=table_dir,
    filters=["g", "r", "i"],
)

## Model Creation

Next we create a model from which to simulate observations. We define the model and its parameters as we would for any other simulation. Here we use a constant SED model (the same value at all times and wavelengths). We place the object at (0.0, 10.0) so that it is observed by some of the survey's pointings.

In [None]:
model = ConstantSEDModel(brightness=100.0, t0=0.0, ra=0.0, dec=10.0, redshift=0.0, node_label="my_star")

## Simulation

The only change in running the simulation in parallel is that we create a `ProcessPoolExecutor` object and pass that to the simulation function:

In [None]:
import concurrent.futures

with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
    results = simulate_lightcurves(
        model=model,
        num_samples=10_000,
        obstable=obstable1,
        passbands=passband_group1,
        obstable_save_cols=["zp_nJy"],
        executor=executor,
        batch_size=100,
    )

print(f"Generated {len(results)} light curves")
print(results["lightcurve"][0])

If we provide a number of jobs instead of an executor object, the function automatically creates and manages the `ProcessPoolExecutor`. Here we run the simulation on 4 processes.

In [None]:
results = simulate_lightcurves(
    model=model,
    num_samples=10_000,
    obstable=obstable1,
    passbands=passband_group1,
    obstable_save_cols=["zp_nJy"],
    num_jobs=4,
    batch_size=100,
)

print(f"Generated {len(results)} light curves")
print(results["lightcurve"][0])

## Dask

We can parallelize the computation via Dask by using `dask.distributed`.

**Note:** Dask is not installed by default, so users will need to install Dask with its distributed scheduler (`pip install "dask[distributed]"`) to run this cell.

In [None]:
try:
    import dask.distributed

    with dask.distributed.Client() as client:
        results = simulate_lightcurves(
            model=model,
            num_samples=100,
            obstable=obstable1,
            passbands=passband_group1,
            obstable_save_cols=["zp_nJy"],
            executor=client,
        )
    print(f"Generated {len(results)} light curves")
    print(results["lightcurve"][0])
except ImportError:
    print("Dask is not installed, skipping Dask example")

## Ray

We can parallelize the computation via Ray by using `ray.util.multiprocessing.Pool`.

**Note:** Ray is not installed by default, so users will need to install Ray (`pip install -U "ray[default]"`) to run this cell.

In [None]:
try:
    import ray
    from ray.util.multiprocessing import Pool

    with Pool(processes=4) as executor:
        results = simulate_lightcurves(
            model=model,
            num_samples=100,
            obstable=obstable1,
            passbands=passband_group1,
            obstable_save_cols=["zp_nJy"],
            executor=executor,
        )
    print(f"Generated {len(results)} light curves")
    print(results["lightcurve"][0])

    ray.shutdown()
except ImportError:
    print("Ray is not installed, skipping Ray example")

## Saving to Files

Depending on the size of the simulated results, you might not want to load the full set into memory as a single table. The `simulate_lightcurves` function can instead save each shard (the result of each batch) to its own file when given an `output_file_path`. Rather than returning the results as a NestedFrame, the function then returns a list of the file paths containing the data. Users can load or analyze these files later.

In [None]:
file_paths = simulate_lightcurves(
    model=model,
    num_samples=10_000,
    obstable=obstable1,
    passbands=passband_group1,
    num_jobs=4,
    batch_size=1000,
    obstable_save_cols=["zp_nJy"],
    output_file_path="./scratch/nb_results.parquet",
)
print(file_paths)

As the output shows, the results are split into ten files, one for each batch of 1,000 samples.
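
The file count follows from simple batch arithmetic, assuming one file is written per batch (as the example above suggests). The helper `expected_num_batches` below is hypothetical and only illustrates the calculation. It is not a LightCurveLynx function.

```python
import math


def expected_num_batches(num_samples, batch_size):
    """Number of batches needed to cover num_samples, with a possibly smaller last batch."""
    return math.ceil(num_samples / batch_size)


# The settings used in the cell above: 10,000 samples in batches of 1,000.
print(expected_num_batches(10_000, 1000))  # 10
```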

## Overhead

As with any distributed computation, there is per-batch overhead. All of the input data (model, obstable, etc.) is pickled and sent to the worker processes, and it takes time to pack and unpack this information. For small simulations, this cost can outweigh the gain from running in parallel.

The user can provide a `batch_size` parameter to control the target batch size for each process. A larger value gives each process enough work per batch to amortize the overhead.
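
One rough way to gauge the serialization cost is to time a pickle round trip of the inputs. The sketch below uses a plain dictionary as a stand-in payload, which is an assumption made for illustration. In practice you would pickle the actual model and `ObsTable`. It measures serialization only and ignores process startup and inter-process transfer time.

```python
import pickle
import time

# Stand-in for the simulation inputs (hypothetical data, not a real ObsTable).
payload = {
    "time": [float(i) for i in range(100_000)],
    "filter": ["g", "r"] * 50_000,
}

start = time.perf_counter()
blob = pickle.dumps(payload)  # What gets sent to a worker process.
restored = pickle.loads(blob)  # What the worker does on arrival.
elapsed = time.perf_counter() - start

print(f"Pickled size: {len(blob) / 1e6:.2f} MB, round trip: {elapsed * 1e3:.1f} ms")
```

If the round-trip time is comparable to the time needed to simulate one batch, a larger `batch_size` is likely to help.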