# Population Inference on GWTC-3

Adapted from the example at [gwpopulation](https://github.com/ColmTalbot/gwpopulation)


# Agenda
1. Take all the BBH events in the third gravitational-wave transient catalog, [GWTC-3](https://arxiv.org/abs/2111.03606)

   This includes all compact binary coalescences observed during Advanced LIGO/Virgo's first three observing runs.

1. Define a population model for the mass distributions (primary mass, mass ratio)

1. Carry out a hierarchical inference to infer the hyper-parameters of the mass model.
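
As a rough illustration of step 3, a hierarchical likelihood is commonly estimated by reweighting each event's posterior samples from the sampling prior to the population model and averaging the weights. The sketch below uses a toy one-dimensional Gaussian population with a flat sampling prior; the function names and the toy samples are made up for illustration, and selection effects are ignored.

```python
import math


def gaussian_pdf(x, mu, sigma):
    """Normal density, used here as a toy population model."""
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


def log_hierarchical_likelihood(event_samples, mu, sigma, sampling_prior=1.0):
    """Monte Carlo estimate of ln L(data | mu, sigma), up to an additive constant."""
    total = 0.0
    for samples in event_samples:
        # Reweight from the (flat) sampling prior to the population model.
        weights = [gaussian_pdf(s, mu, sigma) / sampling_prior for s in samples]
        total += math.log(sum(weights) / len(weights))
    return total


# Hypothetical posterior samples for three events (e.g. primary masses).
toy_events = [
    [29.0, 31.0, 30.5, 28.5],
    [34.0, 36.0, 35.5, 33.0],
    [31.0, 32.5, 30.0, 33.5],
]
near = log_hierarchical_likelihood(toy_events, mu=32.0, sigma=3.0)
far = log_hierarchical_likelihood(toy_events, mu=60.0, sigma=3.0)
print(near > far)
```

A population centred near the samples is favoured over one far away, which is the signal the sampler follows when exploring the hyper-parameters.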



# Tools

1. `GWPopulation`: simple, modular, user-friendly population inference. The mass, spin, and redshift models are defined here.
1. [Bilby](https://git.ligo.org/lscsoft/bilby) ([arXiv:1811.02042](https://arxiv.org/abs/1811.02042)) for sampling and for its language of priors and likelihoods.

# This exercise

- Use a mass distribution in primary mass and mass ratio from Talbot & Thrane (2018) ([arXiv:1801.02699](https://arxiv.org/abs/1801.02699)).
- This is equivalent to the `PowerLaw + Peak` model used in LVK analyses, but without the low-mass smoothing, which is omitted for computational efficiency.
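
To make the shape of this model concrete, here is a minimal pure-Python sketch of a power law plus truncated Gaussian peak in primary mass, with a power law in mass ratio. It is an illustration under the stated assumptions, not the `GWPopulation` implementation; the helper names are hypothetical and the parameter values are chosen only for the demo.

```python
import math


def power_law_pdf(x, index, low, high):
    """Normalised x**index on [low, high]; zero elsewhere."""
    if not low <= x <= high:
        return 0.0
    if index == -1:
        norm = math.log(high / low)
    else:
        norm = (high ** (index + 1) - low ** (index + 1)) / (index + 1)
    return x ** index / norm


def truncated_normal_pdf(x, mu, sigma, low, high):
    """Normal density restricted to [low, high] and renormalised."""
    if not low <= x <= high:
        return 0.0

    def cdf(v):
        return 0.5 * (1 + math.erf((v - mu) / (sigma * math.sqrt(2))))

    norm = cdf(high) - cdf(low)
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi) * norm)


def primary_mass_pdf(m1, alpha, mmin, mmax, lam, mpp, sigpp, gaussian_mass_maximum=100.0):
    """Mixture: fraction lam in the Gaussian peak, the rest in m1**(-alpha)."""
    p_pow = power_law_pdf(m1, -alpha, mmin, mmax)
    p_peak = truncated_normal_pdf(m1, mpp, sigpp, mmin, gaussian_mass_maximum)
    return (1 - lam) * p_pow + lam * p_peak


def mass_ratio_pdf(q, m1, beta, mmin):
    """q**beta between mmin / m1 and 1, so the secondary is never lighter than mmin."""
    return power_law_pdf(q, beta, mmin / m1, 1.0)


pars = dict(alpha=2.0, mmin=5.0, mmax=90.0, lam=0.1, mpp=34.0, sigpp=3.5)
grid = [5.0 + 0.01 * i for i in range(9501)]  # 5 to 100 solar masses
integral = sum(primary_mass_pdf(m, **pars) for m in grid) * 0.01
print(round(integral, 2))
```

The numerical integral is a quick sanity check that the mixture is normalised over the allowed mass range.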


## Also available
- [Implemented models](https://colmtalbot.github.io/gwpopulation/_autosummary/gwpopulation.models.html#module-gwpopulation.models).
- Half-Gaussian + isotropic spin tilt distribution from Talbot & Thrane (2017) ([arXiv:1704.08370](https://arxiv.org/abs/1704.08370)).
- Beta spin magnitude distribution from Wysocki+ (2018) ([arXiv:1805.06442](https://arxiv.org/abs/1805.06442)).
- Each of these is also available with independent but identically distributed spins.
- Redshift evolution model as in Fishbach+ (2018) ([arXiv:1805.10270](https://arxiv.org/abs/1805.10270)).
- Can implement custom models.

# Setup

## Google Colab
1. Choose a GPU-accelerated runtime (e.g. T4 GPU).

"runtime"->"change runtime type"->"Hardware accelerator = GPU"


### Install some needed packages

All of the dependencies for this are integrated into `GWPopulation`.
These include `Bilby` and `dynesty` for sampling.

In [None]:
!pip install gwpopulation --quiet --progress-bar off

In [None]:
!gdown https://drive.google.com/uc?id=16gStLIjt65gWBkw-gNOVUqNbZ89q8CLF
!gdown https://drive.google.com/uc?id=10pevUCM3V2-D-bROFEMAcTJsX_9RzeM6


## Personal laptop

1. Follow PC setup instructions
1. Download the files above to your notebook directory using the URLs.

## Download data

We need to download the data for the events and simulated "injections" used to characterize the detection sensitivity.

### Event posteriors

We're using the posteriors from the GWTC-3 data release in a pre-processed format.

The file was produced by [gwpopulation-pipe](https://docs.ligo.org/ratesAndPopulations/gwpopulation_pipe) to reduce the many GB of posterior sample files to a single ~30 MB file.

The choice of events in this file was not very careful and should only be considered qualitatively correct.

The data file can be found [here](https://drive.google.com/drive/folders/1wyfR6sYvYVdBefF9_vrVTp0Btu03OlzL?usp=drive_link).
The original data can be found at [zenodo:5546663](https://zenodo.org/records/5546663) and [zenodo:6513631](https://zenodo.org/records/6513631) along with citation information.

### Sensitivity injections

Again, I have pre-processed the full injection set using `gwpopulation-pipe` to reduce the file size.
The original data is available at [zenodo:7890398](https://zenodo.org/records/7890398) along with citation information.
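
To see how an injection set can characterize sensitivity, the sketch below estimates the detected fraction of a population by importance-weighting the found injections from their draw distribution to the population model. Everything here is a toy assumption: the uniform draw distribution, the "detected if heavier than 30" criterion, and the helper names are invented for illustration and do not reflect the real injection campaign.

```python
import random


def power_law_pdf(m, alpha, mmin, mmax):
    """Normalised m**(-alpha) on [mmin, mmax], assuming alpha != 1."""
    if not mmin <= m <= mmax:
        return 0.0
    norm = (mmax ** (1 - alpha) - mmin ** (1 - alpha)) / (1 - alpha)
    return m ** (-alpha) / norm


rng = random.Random(0)
mmin, mmax, n_total = 5.0, 100.0, 50_000

# Injections are drawn uniformly in mass; only some of them are "found".
draw_pdf = 1.0 / (mmax - mmin)
injected = [rng.uniform(mmin, mmax) for _ in range(n_total)]
found = [m for m in injected if m > 30.0]  # toy detection criterion


def detection_fraction(alpha):
    """Sum of population / draw weights over found injections, divided by all injections."""
    weights = (power_law_pdf(m, alpha, mmin, mmax) / draw_pdf for m in found)
    return sum(weights) / n_total


estimate = detection_fraction(alpha=2.0)
# For this toy criterion the answer is known analytically.
exact = (1 / 30.0 - 1 / 100.0) / (1 / 5.0 - 1 / 100.0)
print(f"estimated {estimate:.3f}, exact {exact:.3f}")
```

Note that the sum runs over found injections only, while the normalisation uses the total number injected.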

## Imports

Import the packages required for the script.
We also set the backend for array operations to `jax`, which lets us take advantage of just-in-time (JIT) compilation and, when available, GPU parallelisation.

In [None]:
import bilby as bb
import gwpopulation as gwpop
import jax
import matplotlib.pyplot as plt
import pandas as pd
from bilby.core.prior import PriorDict, Uniform
from gwpopulation.experimental.jax import JittedLikelihood, NonCachingModel
import numpy as np

gwpop.set_backend("jax")

xp = gwpop.utils.xp

%matplotlib inline

## Load posteriors

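We are only interested in BBHs for this demonstration, so we remove two events with NS-like secondaries that shouldn't be in the file. Note that the second `del` acts on the list after the first deletion, so its index refers to the already-shortened list.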

In [None]:
posteriors = pd.read_pickle("gwtc-3-samples.pkl")
del posteriors[15]
del posteriors[38]
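
Since `posteriors` is a list, it is worth checking which entries two consecutive `del` statements actually remove. The toy list below is a hypothetical stand-in for the list of DataFrames; which events the indices in the cell above are meant to target is not stated here, so this only illustrates the indexing behaviour.

```python
events = [f"event_{i}" for i in range(45)]  # stand-in for the list of posteriors

# Sequential deletion, as in the cell above: the second index is applied to the
# already-shortened list, so it removes the item originally at position 39.
sequential = list(events)
del sequential[15]
del sequential[38]
removed = sorted(set(events) - set(sequential), key=events.index)
print(removed)

# Deleting in descending order keeps every index tied to the original positions.
by_original_index = list(events)
for index in sorted([15, 38], reverse=True):
    del by_original_index[index]
```

If the indices are meant to refer to positions in the file as loaded, the descending-order loop is the safer pattern.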

## Load injections

Load the injections used to characterize the sensitivity of the gravitational-wave survey.

In [None]:
import dill

with open("gwtc-3-injections.pkl", "rb") as ff:
    injections = dill.load(ff)

## Define some models and the likelihood

We need to define `Bilby` `Model` objects for the numerator and denominator independently, as these cache some computations internally.

We create a model that uses a cosmology fixed to the Planck 2015 values for flat Lambda CDM.

The `HyperparameterLikelihood` marginalises over the local merger rate, with a uniform-in-log prior.
The posterior for the merger rate can be recovered in post-processing.
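
One way to picture the post-processing step: for a Poisson process with a uniform-in-log prior on the rate $R$, the conditional posterior given $N$ detections and a sensitive time-volume $VT$ is a Gamma distribution with shape $N$ and scale $1/VT$. The sketch below draws from that distribution with the standard library. The event count and `vt` value are made-up numbers, and in practice $VT$ would be re-evaluated for each hyper-parameter sample.

```python
import random


def sample_merger_rate(n_events, sensitive_volume_time, rng):
    """Draw R from Gamma(shape=n_events, scale=1 / VT)."""
    return rng.gammavariate(n_events, 1.0 / sensitive_volume_time)


rng = random.Random(1)
n_events = 69  # hypothetical number of events
vt = 3.0       # hypothetical VT in Gpc^3 yr for one hyper-parameter sample

rates = [sample_merger_rate(n_events, vt, rng) for _ in range(20_000)]
mean_rate = sum(rates) / len(rates)
print(f"mean rate ~ {mean_rate:.1f} Gpc^-3 yr^-1")
```

The mean of the draws should sit close to `n_events / vt`, which is a quick check that the shape and scale are the right way round.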

We provide:

- `posteriors`: a list of `pandas` DataFrames.
- `hyper_prior`: our population model (`model` in the cell below).
- `selection_function`: anything which evaluates the selection function.

We can also provide:

- `conversion_function`: this converts between the parameters we sample in and those needed by the model, e.g., for sampling in the mean and variance of the beta distribution.
- `max_samples`: the maximum number of samples to use from each posterior; this defaults to the length of the shortest posterior.
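
As an example of what a conversion function might do, the sketch below maps the mean and variance of a Beta distribution on $[0, 1]$ to its shape parameters. The helper name is hypothetical and this is not the library's own conversion routine; it only shows the algebra such a function would need.

```python
def mu_var_to_alpha_beta(mu, var):
    """Convert the mean and variance of a Beta distribution on [0, 1] to (alpha, beta)."""
    if not 0 < mu < 1 or not 0 < var < mu * (1 - mu):
        raise ValueError("need 0 < mu < 1 and 0 < var < mu * (1 - mu)")
    common = mu * (1 - mu) / var - 1  # equals alpha + beta
    return mu * common, (1 - mu) * common


# Mean and variance corresponding to alpha = 3, beta = 4.
alpha_chi, beta_chi = mu_var_to_alpha_beta(mu=3 / 7, var=3 / 98)
print(round(alpha_chi, 6), round(beta_chi, 6))
```

Sampling in the mean and variance and converting afterwards lets the prior be placed on more interpretable quantities than the shape parameters.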

In [None]:
model = NonCachingModel(
    model_functions=[
        gwpop.models.mass.two_component_primary_mass_ratio,
        gwpop.models.spin.iid_spin,
        gwpop.models.redshift.PowerLawRedshift(cosmo_model="Planck15"),
    ],
)

vt = gwpop.vt.ResamplingVT(model=model, data=injections, n_events=len(posteriors))

likelihood = gwpop.hyperpe.HyperparameterLikelihood(
    posteriors=posteriors,
    hyper_prior=model,
    selection_function=vt,
)

## Define our prior

The mass model has eight parameters, described in arXiv:1801.02699; we vary seven of them and fix `gaussian_mass_maximum`. This model is sometimes referred to as "PowerLaw+Peak".

The spin magnitude model is a `Beta` distribution with the usual parameterization, and the spin orientation model is a mixture of a uniform component and a truncated Gaussian that peaks at aligned spin. This combination is sometimes referred to as "Default". In this notebook the spin parameters are fixed rather than sampled.
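
For intuition, here is a pure-Python sketch of the two spin ingredients for a single spin: a Beta density for the magnitude, and a tilt density that mixes an isotropic (uniform in $\cos\theta$) component with a Gaussian centred on $\cos\theta = 1$ and truncated to $[-1, 1]$. The function names are hypothetical, this is not the `GWPopulation` code, and the numbers simply reuse the fixed values from the prior in this notebook.

```python
import math


def beta_pdf(a, alpha_chi, beta_chi):
    """Beta density on [0, 1] for the spin magnitude (assumes amax = 1)."""
    if not 0 <= a <= 1:
        return 0.0
    norm = math.gamma(alpha_chi) * math.gamma(beta_chi) / math.gamma(alpha_chi + beta_chi)
    return a ** (alpha_chi - 1) * (1 - a) ** (beta_chi - 1) / norm


def tilt_pdf(cos_tilt, xi_spin, sigma_spin):
    """Single-spin tilt density: isotropic part plus an aligned truncated Gaussian."""
    if not -1 <= cos_tilt <= 1:
        return 0.0
    isotropic = 0.5  # uniform in cos(tilt) on [-1, 1]
    # Probability mass of N(1, sigma) between -1 and 1, used to renormalise.
    norm = 0.5 * math.erf(math.sqrt(2) / sigma_spin)
    gaussian = math.exp(-0.5 * ((cos_tilt - 1) / sigma_spin) ** 2)
    aligned = gaussian / (sigma_spin * math.sqrt(2 * math.pi) * norm)
    return (1 - xi_spin) * isotropic + xi_spin * aligned


n = 4000
# Midpoint-rule checks that both densities integrate to one.
mag_integral = sum(beta_pdf((i + 0.5) / n, 3, 4) for i in range(n)) / n
tilt_integral = sum(tilt_pdf(-1 + (i + 0.5) * 2 / n, 0.7, 3.0) for i in range(n)) * 2 / n
print(round(mag_integral, 3), round(tilt_integral, 3))
```

With `sigma_spin = 3` the Gaussian is much wider than the allowed range of $\cos\theta$, so the aligned component only mildly prefers aligned spins.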

For redshift we use a model of the form

$$p(z) \propto \frac{d V_{c}}{dz} (1 + z)^{\lambda - 1}$$

where $\lambda$ corresponds to the `lamb` parameter in the prior.
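
The redshift density can be evaluated numerically without any cosmology package, as in the sketch below. It assumes flat Lambda CDM with radiation neglected and a matter density close to the Planck 2015 value; the constant prefactor of $dV_c/dz$ is dropped because it cancels on normalisation. The function names, `z_max`, and the grid size are illustrative choices, not taken from the library.

```python
import math

OMEGA_M = 0.3075  # assumed matter density, close to the Planck 2015 value


def efunc(z):
    """Dimensionless Hubble rate for flat Lambda CDM, neglecting radiation."""
    return math.sqrt(OMEGA_M * (1 + z) ** 3 + 1 - OMEGA_M)


def redshift_pdf_on_grid(lamb, z_max=2.0, n=2000):
    """Return midpoints z_i and normalised p(z_i) for dVc/dz * (1 + z)**(lamb - 1)."""
    dz = z_max / n
    zs = [(i + 0.5) * dz for i in range(n)]
    unnormalised = []
    distance = 0.0  # comoving distance in units of c / H0, up to the cell's lower edge
    for z in zs:
        step = dz / efunc(z)
        d_mid = distance + 0.5 * step  # comoving distance at the cell midpoint
        dvc_dz = d_mid ** 2 / efunc(z)  # constant 4 pi (c / H0)**3 dropped
        unnormalised.append(dvc_dz * (1 + z) ** (lamb - 1))
        distance += step
    norm = sum(unnormalised) * dz
    return zs, [u / norm for u in unnormalised]


zs, pdf_flat = redshift_pdf_on_grid(lamb=0.0)
_, pdf_steep = redshift_pdf_on_grid(lamb=3.0)
dz = zs[1] - zs[0]
mean_flat = sum(z * p for z, p in zip(zs, pdf_flat)) * dz
mean_steep = sum(z * p for z, p in zip(zs, pdf_steep)) * dz
print(mean_flat < mean_steep)
```

Larger values of $\lambda$ push more of the probability to high redshift, which is what the comparison of the two means checks.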

In [None]:
priors = PriorDict()

# mass
priors["alpha"] = Uniform(minimum=-2, maximum=4, latex_label="$\\alpha$")
priors["beta"] = Uniform(minimum=-4, maximum=12, latex_label="$\\beta$")
priors["mmin"] = Uniform(minimum=2, maximum=2.5, latex_label="$m_{\\min}$")
priors["mmax"] = Uniform(minimum=80, maximum=100, latex_label="$m_{\\max}$")
priors["lam"] = Uniform(minimum=0, maximum=1, latex_label="$\\lambda_{m}$")
priors["mpp"] = Uniform(minimum=10, maximum=50, latex_label="$\\mu_{m}$")
priors["sigpp"] = Uniform(minimum=1, maximum=10, latex_label="$\\sigma_{m}$")
priors["gaussian_mass_maximum"] = 100

# spin (fixed here; swap in the commented-out priors to sample these instead)
priors["amax"] = 1
priors["alpha_chi"] = 3  # Uniform(minimum=1, maximum=6, latex_label="$\\alpha_{\\chi}$")
priors["beta_chi"] = 4  # Uniform(minimum=1, maximum=6, latex_label="$\\beta_{\\chi}$")
priors["xi_spin"] = 0.7  # Uniform(minimum=0, maximum=1, latex_label="$\\xi$")
priors["sigma_spin"] = 3  # Uniform(minimum=0.3, maximum=4, latex_label="$\\sigma$")

# redshift (fixed here)
priors["lamb"] = 7  # Uniform(minimum=-1, maximum=10, latex_label="$\\lambda_{z}$")

## Just-in-time compile using JAX

We JIT compile the likelihood object before starting the sampler, using the `gwpopulation.experimental.jax.JittedLikelihood` class.

In [None]:
parameters = priors.sample()

In [None]:
parameters

In [None]:
red_pars = {'alpha': 0.3949072658140942,
             'beta': -0.9347469576470306,
             'mmin': 2.314923073404898,
             'mmax': 99.30823964844406,
             'lam': 0.551015290765012,
             'mpp': 34.07791899483966,
             'sigpp': 3.5147064761229663,
             'gaussian_mass_maximum': 100.0,}


dataset = {"mass_1" : np.arange(10, 100, 1/10),
           "mass_ratio" : np.arange(0.1, 1, 1/10)}

In [None]:
#pdist = gwpop.models.mass.two_component_primary_mass_ratio(dataset, **red_pars)
Marr = np.linspace(10, 100, 100)

pdist = gwpop.models.mass.two_component_single(mass=Marr,
                                               alpha=0.29,
                                               mmin=2.31,
                                               mmax=99.3,
                                               lam=0.55,
                                               mpp=34,
                                               sigpp=3.5)

In [None]:
plt.plot(Marr, pdist)
plt.xlabel(r"$m_1\ [M_\odot]$")  # primary mass
plt.ylabel(r"$p(m_1)$")  # probability density

In [None]:

likelihood.parameters.update(parameters)


likelihood.log_likelihood_ratio()
print("Usual evaluation")
%time print(likelihood.log_likelihood_ratio())


jit_likelihood = JittedLikelihood(likelihood)
jit_likelihood.parameters.update(parameters)

print("JAX implementation")
print("1st eval")
%time print(jit_likelihood.log_likelihood_ratio())
print(" Subsequent evals")
%time print(jit_likelihood.log_likelihood_ratio())

## Run the sampler

We'll use the `dynesty` sampler with a small number of live points to reduce the runtime (the total should be approximately 5 minutes on a T4 GPU via Google Colab).
The settings here may not give publication-quality results; a convergence test should be performed before making strong quantitative statements.

`bilby` times a single likelihood evaluation before beginning the run; however, this timing isn't well defined with JAX, since the first call includes JIT compilation.

**Note:** sometimes this finds a high likelihood mode, likely due to [breakdowns in the approximation](https://arxiv.org/abs/2304.06138) used to estimate the likelihood. If you see `dlogz > -190`, you should interrupt the execution and restart.

In [None]:
result = bb.run_sampler(
    likelihood=jit_likelihood,
    priors=priors,
    sampler="dynesty",
    nlive=500,
    sample="acceptance-walk",
    naccept=5,
    save="hdf5",
    outdir="hierarchical",
    label="mass_500"
)

## Plot some posteriors

We can look at the posteriors on the sampled parameters, here the seven mass hyper-parameters, including the location of the mass peak (`mpp`).

The cosmology is fixed to `Planck15`, and the spin and redshift parameters are fixed in the prior, so only the mass hyper-parameters appear in the plots.

In [None]:
result.posterior

In [None]:
result.plot_marginals()

In [None]:
_ = result.plot_corner(save=False, parameters=["alpha", "beta", "mmin", "mmax", "lam", "mpp", "sigpp"])
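
Alongside the plots, it is often handy to quote a median and a 90% credible interval for a parameter. The sketch below does this with the standard library on a stand-in list of samples; the `summarise` helper and the toy Gaussian samples are hypothetical, and the real input would be a column of `result.posterior`.

```python
import random
import statistics


def summarise(samples):
    """Return (5th percentile, median, 95th percentile) of a list of samples."""
    cuts = statistics.quantiles(samples, n=20)  # 19 cut points at 5%, 10%, ..., 95%
    return cuts[0], cuts[9], cuts[18]


rng = random.Random(2)
toy_mpp_samples = [rng.gauss(34.0, 2.0) for _ in range(5000)]  # stand-in for a posterior column

lower, median, upper = summarise(toy_mpp_samples)
print(f"mpp = {median:.1f} (+{upper - median:.1f} / -{median - lower:.1f})")
```

With the actual run, something like `summarise(list(result.posterior["mpp"]))` would give the same summary for the peak location.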