# Advanced Parameter Sampling

This tutorial builds on the "Sampling Parameters" notebook to describe how to implement more complex dependencies between parameters.

## Core Concepts

In the "Sampling Parameters" notebook, we introduced a few core concepts that will be heavily used throughout this notebook and are worth reviewing:

* A *parameter* is effectively a variable in the model's mathematical equations that is instantiated during each round of sampling. A simple model might contain a handful of parameters, ranging from position on the sky (RA, Dec) to inherent physical quantities (host mass) to purely functional parameters (curve decay rate).
* A `ParameterizedNode` is a computational unit for working with parameters. Each node provides the recipe for computing its own parameters, and those recipes may take as inputs parameters computed by other nodes.
* A `GraphState` is a data structure that holds the sampled values for all the parameters in the model. Each `ParameterizedNode` object is stateless and does not store the parameter values themselves. Instead, all operations take a `GraphState` that contains the necessary input parameters and store the resulting output parameters in it.

The combined values of **all** the parameters in the graph define a single sample of the model's parameters.
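
To make this division of labor concrete, here is a minimal sketch in plain NumPy. `ToyConstantNode` and its `sample` method are hypothetical stand-ins, not LightCurveLynx's API, and a nested dictionary (`state[node_label][param_name]`) plays the role of the `GraphState`. The node objects hold only their recipes; every sampled value lives in the shared dictionary.

```python
import numpy as np


class ToyConstantNode:
    """Hypothetical stateless node: it stores only its recipe, never sampled values."""

    def __init__(self, label, **params):
        self.label = label
        # Each value is either a constant or a callable that reads other parameters from the state.
        self.params = params

    def sample(self, state, num_samples):
        """Compute this node's parameters and write them into the shared state dictionary."""
        node_values = state.setdefault(self.label, {})
        for name, recipe in self.params.items():
            if callable(recipe):
                # Input taken from another node's already-computed parameter.
                node_values[name] = np.asarray(recipe(state), dtype=float)
            else:
                # Constant value repeated for every sample.
                node_values[name] = np.full(num_samples, recipe, dtype=float)
        return state


# A plain nested dictionary stands in for a GraphState.
state = {}
host = ToyConstantNode("host", ra=45.0, dec=-10.0)
source = ToyConstantNode("source", ra=lambda s: s["host"]["ra"], brightness=12.0)

# The host must be sampled first because the source reads from it.
host.sample(state, num_samples=3)
source.sample(state, num_samples=3)
print(state["source"]["ra"])  # [45. 45. 45.]
```

Because the nodes keep no sampled values, the same node objects can be reused with a fresh dictionary to produce an entirely new set of samples.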

## Basic Chaining

We can use one `ParameterizedNode` object to provide parameters for another object. As described in the previous notebook, this chaining can use the output of a node as the input to another node:

In [None]:
import numpy as np

from lightcurvelynx.math_nodes.np_random import NumpyRandomFunc
from lightcurvelynx.models.basic_models import ConstantSEDModel

brightness_dist = NumpyRandomFunc("normal", loc=20.0, scale=2.0, node_label="brightness_dist")
model = ConstantSEDModel(brightness=brightness_dist, node_label="test")
state = model.sample_parameters(num_samples=10)
print(state["test"]["brightness"])

In this example, the value of the object's brightness parameter is taken from the output of the `brightness_dist` node. Since that node generates samples from a normal distribution, the object's brightness is also normally distributed.
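As a rough NumPy analogy (an illustration, not how the library is implemented internally), chaining means the model's parameter is simply the upstream node's output, so it inherits that node's distribution. With many samples, the mean and standard deviation should land close to the `loc=20.0` and `scale=2.0` used above:

```python
import numpy as np

rng = np.random.default_rng(seed=42)

# Stand-in for the brightness_dist node: draw one normal value per sample.
brightness_samples = rng.normal(loc=20.0, scale=2.0, size=100_000)

# Stand-in for the model node: its brightness parameter is just the upstream output.
model_brightness = brightness_samples

print(f"mean = {model_brightness.mean():.2f}, std = {model_brightness.std():.2f}")
```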

We can also reference another node's parameter using the dot notation:

In [None]:
ra_node = NumpyRandomFunc("uniform", low=0.0, high=360.0)
host = ConstantSEDModel(brightness=15.0, ra=ra_node, dec=2.0, node_label="host")
source = ConstantSEDModel(brightness=10.0, ra=host.ra, dec=host.dec, node_label="source")
state = source.sample_parameters(num_samples=10)
print("Host RA:", state["host"]["ra"])
print("Source RA:", state["source"]["ra"])

Here we see that the source's RA is an exact copy of the host's RA in every sample. The source's Dec is likewise copied from `host.dec`.
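The difference between referencing another node's parameter and giving a node its own random input can be sketched in NumPy. The variable names below are illustrative only; the point is that a reference reuses the already-drawn values, while a separate random node would draw new, independent ones.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
num_samples = 10

# The host's RA is drawn once per sample.
host_ra = rng.uniform(0.0, 360.0, size=num_samples)

# Referencing host.ra reuses those exact values instead of drawing new ones.
source_ra = host_ra

# For contrast, giving the source its own uniform node would produce independent values.
independent_ra = rng.uniform(0.0, 360.0, size=num_samples)

print(np.array_equal(host_ra, source_ra))  # True
print(np.array_equal(host_ra, independent_ra))
```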

## Sampling from Known Values

LightCurveLynx provides multiple `ParameterizedNode` subclasses for selecting known values in `math_nodes/given_sampler.py`. These can be used for testing, allowing a user to input different (but known) values for each sample. As we will see later in this notebook, they can also be combined with other node types to perform more complex computations.

### BinarySampler

The `BinarySampler` node returns a single True or False value for each sample, where the first argument (0.25 below) is the probability of returning True. It is designed for probabilistically applying effects or making decisions in the simulation.

In [None]:
from lightcurvelynx.math_nodes.given_sampler import BinarySampler

apply_effect = BinarySampler(0.25, node_label="apply_effect")
states = apply_effect.sample_parameters(num_samples=5000)

num_true = np.count_nonzero(states["apply_effect"]["function_node_result"])
print(f"Returned {num_true} TRUE and {5000 - num_true} FALSE")
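
Conceptually, each sample is an independent Bernoulli trial, so roughly a quarter of the 5000 results should be True. The NumPy sketch below reproduces that behavior under this assumption (it is not the library's implementation) and shows how such a flag might gate a hypothetical effect, here an arbitrary +0.5 offset, using `np.where`:

```python
import numpy as np

rng = np.random.default_rng(seed=1)
probability = 0.25
num_samples = 5000

# One Bernoulli trial per sample: True with the given probability.
apply_effect = rng.random(num_samples) < probability

num_true = np.count_nonzero(apply_effect)
print(f"Returned {num_true} TRUE and {num_samples - num_true} FALSE")

# Hypothetical use: add an offset only to the samples whose flag is True.
brightness = np.full(num_samples, 20.0)
brightness = np.where(apply_effect, brightness + 0.5, brightness)
```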

### GivenValueList

The `GivenValueList` node returns the values from a given list (in the order in which they are given). This is primarily used for testing:

In [None]:
from lightcurvelynx.math_nodes.given_sampler import GivenValueList

brightness_dist = GivenValueList([18.0, 20.0, 22.0])
model = ConstantSEDModel(brightness=brightness_dist, node_label="test")
state = model.sample_parameters(num_samples=3)
print(state["test"]["brightness"])
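
Because the output is fully determined by the input list, tests can assert exact values. A minimal sketch of the idea follows; the explicit error when too many samples are requested is this sketch's own choice, not documented library behavior.

```python
import numpy as np

values = [18.0, 20.0, 22.0]
num_samples = 3

# Return the given values in order, one per sample.
# This sketch refuses to run past the end of the list (an assumption, not the library's rule).
if num_samples > len(values):
    raise ValueError(f"Requested {num_samples} samples but only {len(values)} values were given")
brightness = np.asarray(values[:num_samples])
print(brightness)  # [18. 20. 22.]
```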

### GivenValueSampler

The `GivenValueSampler` node returns a random value (with replacement) from a given list:

In [None]:
from lightcurvelynx.math_nodes.given_sampler import GivenValueSampler

brightness_dist = GivenValueSampler([18.0, 20.0, 22.0])
model = ConstantSEDModel(brightness=brightness_dist, node_label="test")
state = model.sample_parameters(num_samples=10)
print(state["test"]["brightness"])
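
Sampling with replacement means a value can repeat across samples while another value may never appear. NumPy's `Generator.choice` illustrates the same idea; this is an analogy, not the node's source code. `choice` also accepts an integer `n`, in which case it draws from `range(n)`, which is a convenient way to generate object indices:

```python
import numpy as np

rng = np.random.default_rng(seed=7)
values = np.array([18.0, 20.0, 22.0])

# Draw with replacement: the same value may appear many times, others may not appear at all.
brightness = rng.choice(values, size=10, replace=True)

# Passing an integer n draws from range(n), i.e. indices in [0, n).
indices = rng.choice(5, size=10, replace=True)

print(brightness)
print(indices)
```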

### GivenValueSelector

The `GivenValueSelector` node takes a single input parameter, *index*, and uses it to look up the parameter's value in a given list. Below we use a constant value for *index*, so we return the same element in every sample:

In [None]:
from lightcurvelynx.math_nodes.given_sampler import GivenValueSelector

brightness_dist = GivenValueSelector([18.0, 20.0, 22.0], index=2)
model = ConstantSEDModel(brightness=brightness_dist, node_label="test")
state = model.sample_parameters(num_samples=10)
print(state["test"]["brightness"])
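
The lookup itself behaves like NumPy fancy indexing. The sketch below assumes zero-based indices, as in Python, so `index=2` selects 22.0. It also shows the more useful case where the index differs from sample to sample:

```python
import numpy as np

values = np.array([18.0, 20.0, 22.0])

# Constant index: every sample looks up the same (zero-based) position.
constant_index = np.full(10, 2)
print(values[constant_index])  # ten copies of 22.0

# Per-sample index: each sample looks up its own position.
per_sample_index = np.array([0, 2, 1, 0])
print(values[per_sample_index])  # [18. 22. 20. 18.]
```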

## Combining Node Types

We can perform complex sampling operations by combining multiple types of nodes. For example, imagine that we want to sample from a set of known objects for which we have lists of RAs, Decs, and brightnesses. We can combine a random selection of the object's index with nodes that look up the value for that object index in each of the corresponding lists:

In [None]:
ra_list = [10.0, 20.0, 30.0, 40.0, 50.0]
dec_list = [1.0, 2.0, 3.0, 4.0, 5.0]
brightness_list = [15.0, 16.0, 17.0, 18.0, 19.0]

index_dist = GivenValueSampler(5)
model = ConstantSEDModel(
    brightness=GivenValueSelector(brightness_list, index=index_dist),
    ra=GivenValueSelector(ra_list, index=index_dist),
    dec=GivenValueSelector(dec_list, index=index_dist),
    node_label="model",
)

state = model.sample_parameters(num_samples=10)
for i in range(10):
    ra = state["model"]["ra"][i]
    dec = state["model"]["dec"][i]
    brightness = state["model"]["brightness"][i]
    print(f"Sample {i + 1}: ({ra}, {dec}) = {brightness}")

The `GivenValueSampler` node chooses an object index value from the range [0, 5). The output of this node (the index) is passed as the input to multiple `GivenValueSelector` nodes to extract the corresponding element from each of the lists.

An important consideration is that each node in the graph is sampled only once per sample. This means a *single* index is chosen and used for all three lists, so within each sample the RA, Dec, and brightness all belong to the same object.
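The NumPy sketch below (an illustration of the idea, not the library's code) contrasts a shared index, which keeps each object's values together, with independent indices per list, which would mix values from different objects. The lists are chosen so that a consistent object satisfies `ra / 10 == dec` and `brightness == dec + 14`:

```python
import numpy as np

rng = np.random.default_rng(seed=3)
ra_list = np.array([10.0, 20.0, 30.0, 40.0, 50.0])
dec_list = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
brightness_list = np.array([15.0, 16.0, 17.0, 18.0, 19.0])
num_samples = 10

# Shared index: one draw per sample, reused for every list.
index = rng.integers(0, 5, size=num_samples)
ra, dec, brightness = ra_list[index], dec_list[index], brightness_list[index]

# Independent indices: the Dec gets its own draw, so it may come from a different object.
mixed_dec = dec_list[rng.integers(0, 5, size=num_samples)]

for i in range(num_samples):
    print(f"Sample {i + 1}: shared ({ra[i]}, {dec[i]}) = {brightness[i]}, mixed dec = {mixed_dec[i]}")
```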

For other examples of how these types of nodes can be combined, see the implementation of the `MultiLightcurveTemplateModel` and the `RandomMultiObjectModel` models.

## Sampling from Tables

Instead of lists, we might want to extract values from tabular data represented as a dictionary, Astropy Table, or pandas DataFrame. The `TableSampler` node samples a row from the given tabular data and stores a separate parameter for each column of the table.

For example, we can create a table with columns 'A', 'B', and 'C' and sample from it:

In [None]:
from astropy.table import Table

from lightcurvelynx.math_nodes.given_sampler import TableSampler

raw_data_dict = {
    "A": [1, 2, 3, 4, 5, 6, 7, 8],
    "B": [2, 3, 4, 5, 4, 3, 2, 1],
    "C": [3, 4, 5, 6, 7, 8, 9, 10],
}
data = Table(raw_data_dict)

table_node = TableSampler(data, in_order=True, node_label="node")
state = table_node.sample_parameters(num_samples=3)
print(state)

The `in_order` flag tells the node whether to extract the rows in order (`True`) or randomly with replacement (`False`).
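
A rough sketch of both modes on the same dictionary is shown below. `sample_rows` is a hypothetical helper written for illustration, not part of LightCurveLynx; it picks row indices first and then slices every column with the same indices, so each sample is one complete row.

```python
import numpy as np

raw_data_dict = {
    "A": [1, 2, 3, 4, 5, 6, 7, 8],
    "B": [2, 3, 4, 5, 4, 3, 2, 1],
    "C": [3, 4, 5, 6, 7, 8, 9, 10],
}


def sample_rows(table, num_samples, in_order, rng=None):
    """Hypothetical helper: pick row indices, then return one array per column."""
    num_rows = len(next(iter(table.values())))
    if in_order:
        # Consecutive rows starting at the top (this sketch does not wrap around).
        rows = np.arange(num_samples)
    else:
        # Random rows, with replacement.
        if rng is None:
            rng = np.random.default_rng()
        rows = rng.integers(0, num_rows, size=num_samples)
    return {name: np.asarray(column)[rows] for name, column in table.items()}


print(sample_rows(raw_data_dict, 3, in_order=True))
print(sample_rows(raw_data_dict, 3, in_order=False, rng=np.random.default_rng(5)))
```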

As with other node types, we can use the dot notation to pass these values as inputs to other models. For example, let's assume that the 'B' column corresponds to brightness, 'A' corresponds to RA, and 'C' is unused:

In [None]:
table_node = TableSampler(data, in_order=False, node_label="node")
model = ConstantSEDModel(
    brightness=table_node.B,
    ra=table_node.A,
    node_label="test",
)

state = model.sample_parameters(num_samples=10)
print(state)
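
Since the table node is sampled once per sample, `table_node.A` and `table_node.B` should come from the same row, so each (RA, brightness) pair matches a real row of the table. The NumPy sketch below checks this property under that assumption; it mirrors the idea rather than calling the library.

```python
import numpy as np

rng = np.random.default_rng(seed=11)
A = np.array([1, 2, 3, 4, 5, 6, 7, 8])
B = np.array([2, 3, 4, 5, 4, 3, 2, 1])

# One row index per sample is shared by every column, mirroring table_node.A and table_node.B.
rows = rng.integers(0, len(A), size=10)
ra, brightness = A[rows], B[rows]

# Every (ra, brightness) pair is a real row of the table, never a mix of two rows.
valid_pairs = set(zip(A.tolist(), B.tolist()))
print(all(pair in valid_pairs for pair in zip(ra.tolist(), brightness.tolist())))  # True
```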