# Using the O'Neil (2019) Observation-based Priors

by Sarah Blunt (2024)

In [3]:
# TODO: write tutorial intro

import numpy as np
import multiprocessing as mp

import orbitize
from orbitize import driver, priors

### Set up MCMC run

In [4]:
# Relative-astrometry data file for GJ 504 that ships with orbitize!.
filename = f"{orbitize.DATADIR}/GJ504.csv"

# System parameters (the uncertainties on both the mass and the parallax are
# passed to the driver as zero below).
num_secondary_bodies = 1
system_mass = 1.75  # [Msol]
plx = 51.44  # [mas]

# MCMC sampler settings
num_temps = 10
num_walkers = 50
num_threads = mp.cpu_count()  # or a different number if you prefer

mcmc_settings = dict(
    num_temps=num_temps,
    num_walkers=num_walkers,
    num_threads=num_threads,
)

my_driver = driver.Driver(
    filename,
    "MCMC",
    num_secondary_bodies,
    system_mass,
    plx,
    mass_err=0,
    plx_err=0,
    mcmc_kwargs=mcmc_settings,
)

## Modify Priors

Set the priors on `sma`, `ecc`, and `tau` to the O'Neil observation-based prior. The parallax and total mass do not get this prior: the driver above was created with `plx_err=0` and `mass_err=0`.

In [13]:
# Convert the input sep/PA measurement uncertainties to RA/Dec uncertainties.
# With PA measured east of north, RA = sep * sin(PA) and Dec = sep * cos(PA),
# so linear error propagation (ignoring any sep/PA correlation) gives
#   sigma_RA**2  = (sin(PA) * sigma_sep)**2 + (sep * cos(PA) * sigma_PA)**2
#   sigma_Dec**2 = (cos(PA) * sigma_sep)**2 + (sep * sin(PA) * sigma_PA)**2
obs_table = my_driver.system.data_table

# The conversion below is only valid for rows stored as sep/PA.
assert np.all(
    np.array(obs_table["quant_type"]) == "seppa"
), "expected every measurement to be a sep/PA row"

sep = np.array(obs_table["quant1"])
sep_err = np.array(obs_table["quant1_err"])
pa = np.radians(np.array(obs_table["quant2"]))
pa_err = np.radians(np.array(obs_table["quant2_err"]))

ra_err = np.hypot(np.sin(pa) * sep_err, sep * np.cos(pa) * pa_err)
dec_err = np.hypot(np.cos(pa) * sep_err, sep * np.sin(pa) * pa_err)

epochs = np.array(obs_table["epoch"])

# define the `ObsPrior` object
my_obsprior = priors.ObsPrior(ra_err, dec_err, epochs)

# TODO: change the basis

print(my_driver.system.param_idx)

# Set the priors on `sma`, `ecc`, and `tau` to point to this object. Indices are
# looked up by name instead of being hard-coded. `plx` and `mtot` are left as the
# driver configured them (it was created with plx_err=0 and mass_err=0).
for param_name in ["sma1", "ecc1", "tau1"]:
    my_driver.system.sys_priors[my_driver.system.param_idx[param_name]] = my_obsprior

my_driver.system.sys_priors

{'sma1': 0, 'ecc1': 1, 'inc1': 2, 'aop1': 3, 'pan1': 4, 'tau1': 5, 'plx': 6, 'mtot': 7}


[ObsPrior, ObsPrior, Sine, Uniform, Uniform, ObsPrior, ObsPrior, ObsPrior]

### Run MCMC!

In [None]:
# Number of orbits to collect and number of burn-in steps for this run.
total_orbits = 100000
burn_steps = 1000
my_driver.sampler.run_sampler(total_orbits, burn_steps=burn_steps, examine_chains=True)

In [None]:
# Corner plot of the samples held in the sampler's results object.
mcmc_results = my_driver.sampler.results
my_corner = mcmc_results.plot_corner()
my_corner

### Examine the Prior Shapes

In [None]:
n_pts = 10

# make an array of sample orbits to compute lnprob: maybe 1000 orbits long, drawn from uniform dists in all 5 params