# Model Comparison Tutorial

Trey V. Wenger (c) March 2024

Here we demonstrate how to optimize the number of cloud components in a `Caribou` model. We start by importing the package and loading the example data.

In [1]:
from IPython.display import SVG, display
    
import os
import pickle

import matplotlib.pyplot as plt
import arviz as az
import pandas as pd

import caribou
from caribou import (
    SimpleModel,
    Optimize,
)

pd.options.display.max_rows = None

print("Caribou:", caribou.__version__)

# plot directory and extension
figdir = "figures"
ext = "svg"
if not os.path.isdir(figdir):
    os.mkdir(figdir)

Caribou: 0.0.1b


In [2]:
# Load example data
with open("example_data.pkl", "rb") as f:
    data = pickle.load(f)

## `Optimize`

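We use the `Optimize` class to build one model for each number of cloud components, from 1 up to `max_n_clouds`, and to identify the best of them.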

In [3]:
opt = Optimize(
    SimpleModel,  # model type
    data,  # data dictionary
    max_n_clouds=5,  # maximum number of clouds
    baseline_degree=2,  # polynomial baseline degree
    seed=1234,  # random seed
    verbose=True,  # verbosity
)
opt.set_priors()  # use default priors

`Optimize` has created `max_n_clouds` models, where `opt.models[1]` has `n_clouds=1`, `opt.models[2]` has `n_clouds=2`, etc.
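The indexing described here can be pictured as a mapping from cloud count to model. The sketch below is hypothetical. `ToyModel` is a stand-in invented for illustration. Storing the models in a dictionary keyed by `n_clouds` is an assumption about the layout, not a description of Caribou's internals. The point it shows is that the keys start at 1, not 0.

```python
from dataclasses import dataclass


@dataclass
class ToyModel:
    """Hypothetical stand-in for a Caribou model; it only records n_clouds."""

    n_clouds: int


max_n_clouds = 5

# Assumed layout: one model per cloud count, keyed by n_clouds (1..max_n_clouds),
# so models[1] has one cloud, models[2] has two, and so on. There is no models[0].
models = {n: ToyModel(n_clouds=n) for n in range(1, max_n_clouds + 1)}

print(models[4].n_clouds)
```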

In [4]:
print(opt.models[4])
print(opt.models[4].n_clouds)

<caribou.simple_model.SimpleModel object at 0x7f2be3574590>
4


The optimization algorithm first loops over every model and approximates each model's posterior distribution using variational inference (VI). The model with the lowest Bayesian Information Criterion (BIC) is then sampled with Markov chain Monte Carlo (MCMC). We pass keyword arguments to the VI step (`fit`) and to the MCMC step (`sample`) as dictionaries.
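For reference, the BIC of a model with $k$ free parameters, maximized likelihood $\hat{L}$, and $n$ data points is $\mathrm{BIC} = k \ln n - 2 \ln \hat{L}$. The sketch below implements this textbook definition. It is not Caribou's implementation: this tutorial does not show how Caribou evaluates the likelihood for a variational approximation. The numbers in the example are hypothetical.

```python
import math


def bic(max_log_likelihood: float, n_params: int, n_data: int) -> float:
    """Textbook Bayesian Information Criterion: k*ln(n) - 2*ln(L_hat).

    Lower values are preferred. The k*ln(n) term penalizes extra parameters,
    such as those introduced by each additional cloud component.
    """
    return n_params * math.log(n_data) - 2.0 * max_log_likelihood


# Hypothetical numbers, for illustration only: the more complex model fits
# slightly better (higher log-likelihood) but pays a larger parameter penalty.
bic_simple = bic(max_log_likelihood=-520.0, n_params=5, n_data=1000)
bic_complex = bic(max_log_likelihood=-515.0, n_params=10, n_data=1000)
print(round(bic_simple, 2), round(bic_complex, 2))
```

In this toy case the simpler model has the lower BIC. The likelihood gain of the complex model does not offset its penalty, which is why adding clouds does not automatically win.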

In [5]:
fit_kwargs = {
    "learning_rate": 1e-3,
    "abs_tolerance": 0.01,
    "rel_tolerance": 0.01,
}
sample_kwargs = {
    "chains": 4,
    "cores": 4,
    "init_kwargs": fit_kwargs,
    "nuts_kwargs": {"target_accept": 0.8},
}
opt.optimize(fit_kwargs=fit_kwargs, sample_kwargs=sample_kwargs)

Null hypothesis BIC = 1.544e+06
Approximating n_cloud = 1 posterior...
Convergence achieved at 38700
Interrupted at 38,699 [7%]: Average Loss = 2.6381e+06
GMM converged to unique solution
n_cloud = 1 BIC = 1.317e+05

Approximating n_cloud = 2 posterior...
Convergence achieved at 62200
Interrupted at 62,199 [12%]: Average Loss = 1.4092e+06
GMM converged to unique solution
n_cloud = 2 BIC = 3.942e+03

Approximating n_cloud = 3 posterior...
Convergence achieved at 68200
Interrupted at 68,199 [13%]: Average Loss = 6.8186e+05
GMM converged to unique solution
n_cloud = 3 BIC = 1.667e+03

Approximating n_cloud = 4 posterior...
Convergence achieved at 65900
Interrupted at 65,899 [13%]: Average Loss = 1.2237e+06
GMM converged to unique solution
n_cloud = 4 BIC = 1.490e+03

Approximating n_cloud = 5 posterior...
Convergence achieved at 69800
Interrupted at 69,799 [13%]: Average Loss = 1.502e+06
GMM converged to unique solution
n_cloud = 5 BIC = 1.534e+03

Sampling best model (n_cloud = 4)...
Initializing NUTS using custom advi+adapt_diag strategy
Convergence achieved at 65900
Interrupted at 65,899 [13%]: Average Loss = 1.2237e+06
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [emission_coeffs, absorption_coeffs, log10_NHI, log10_kinetic_temp, log10_density, log10_n_alpha, log10_larson_linewidth, larson_power, velocity]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 244 seconds.

Only 3 chains appear converged.
There were 16 divergences in converged chains.
GMM converged to unique solution
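The selection step can be checked by hand. The sketch below copies the BIC values from this log and picks the cloud count with the lowest BIC. It also computes each model's BIC difference relative to the selected one. This only reproduces the "lowest BIC wins" rule described in this tutorial; it does not call Caribou.

```python
# BIC values copied from the opt.optimize() log in this tutorial
null_bic = 1.544e06  # the "Null hypothesis" BIC reported first in the log
bics = {1: 1.317e05, 2: 3.942e03, 3: 1.667e03, 4: 1.490e03, 5: 1.534e03}

# Select the cloud count with the lowest BIC
best_n = min(bics, key=bics.get)

# BIC differences relative to the selected model (0 for the best model)
delta_bic = {n: value - bics[best_n] for n, value in bics.items()}

print(f"best n_clouds = {best_n}")
for n, delta in sorted(delta_bic.items()):
    print(f"n_clouds = {n}: delta BIC = {delta:.0f}")
```

This recovers `n_cloud = 4`, matching "Sampling best model" in the log. The five-cloud model is the closest competitor, with a BIC larger by 44. A difference of more than about 10 is conventionally read as strong evidence for the lower-BIC model.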


The best (lowest-BIC) model is saved in `opt.best_model`.
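The `az.summary` table for this model reports convergence diagnostics (`r_hat`, `ess_bulk`, `ess_tail`) for each parameter. A minimal sketch of screening them is shown below. It uses the commonly recommended thresholds of R-hat < 1.01 and bulk and tail ESS > 400 for four chains (Vehtari et al. 2021). The rows are copied from the summary table for the best model. The helper `flag_unconverged` is hypothetical, not part of Caribou or ArviZ.

```python
# (parameter, ess_bulk, ess_tail, r_hat), copied from the az.summary table for the best model
rows = [
    ("emission_coeffs[0]", 1886.0, 2107.0, 1.0),
    ("emission_coeffs[1]", 3255.0, 1882.0, 1.0),
    ("emission_coeffs[2]", 1905.0, 2287.0, 1.0),
    ("absorption_coeffs[0]", 2384.0, 2208.0, 1.0),
    ("absorption_coeffs[1]", 3406.0, 2105.0, 1.0),
    ("absorption_coeffs[2]", 2494.0, 1898.0, 1.0),
    ("log10_NHI[0]", 454.0, 550.0, 1.01),
    ("log10_NHI[1]", 474.0, 597.0, 1.01),
    ("log10_NHI[2]", 793.0, 702.0, 1.0),
    ("log10_NHI[3]", 787.0, 690.0, 1.0),
]


def flag_unconverged(rows, max_r_hat=1.01, min_ess=400.0):
    """Hypothetical helper: names of parameters that fail R-hat < max_r_hat
    or have bulk or tail ESS not exceeding min_ess."""
    return [
        name
        for name, ess_bulk, ess_tail, r_hat in rows
        if not (r_hat < max_r_hat and ess_bulk > min_ess and ess_tail > min_ess)
    ]


print(flag_unconverged(rows))
```

Only `log10_NHI[0]` and `log10_NHI[1]` are flagged, because their displayed `r_hat` sits exactly at 1.01. The table rounds its values, so these are borderline cases rather than clear failures. In a live session, `az.summary` returns a pandas `DataFrame`, so the same screen can be written as a filter on its `r_hat`, `ess_bulk`, and `ess_tail` columns.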

In [6]:
print(f"Best model has n_clouds = {opt.best_model.n_clouds}")
display(az.summary(opt.best_model.trace.solution_0))

Best model has n_clouds = 4


| Parameter | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| `emission_coeffs[0]` | 0.616 | 0.25 | 0.161 | 1.079 | 0.006 | 0.004 | 1886.0 | 2107.0 | 1.0 |
| `emission_coeffs[1]` | 0.209 | 0.081 | 0.055 | 0.354 | 0.001 | 0.001 | 3255.0 | 1882.0 | 1.0 |
| `emission_coeffs[2]` | -0.241 | 0.133 | -0.477 | 0.015 | 0.003 | 0.002 | 1905.0 | 2287.0 | 1.0 |
| `absorption_coeffs[0]` | -0.054 | 0.163 | -0.344 | 0.264 | 0.003 | 0.003 | 2384.0 | 2208.0 | 1.0 |
| `absorption_coeffs[1]` | 0.052 | 0.073 | -0.092 | 0.181 | 0.001 | 0.001 | 3406.0 | 2105.0 | 1.0 |
| `absorption_coeffs[2]` | 0.043 | 0.101 | -0.14 | 0.239 | 0.002 | 0.002 | 2494.0 | 1898.0 | 1.0 |
| `log10_NHI[0]` | 20.283 | 0.124 | 20.084 | 20.533 | 0.006 | 0.004 | 454.0 | 550.0 | 1.01 |
| `log10_NHI[1]` | 20.693 | 0.054 | 20.584 | 20.762 | 0.003 | 0.002 | 474.0 | 597.0 | 1.01 |
| `log10_NHI[2]` | 20.687 | 0.034 | 20.624 | 20.75 | 0.001 | 0.001 | 793.0 | 702.0 | 1.0 |
| `log10_NHI[3]` | 20.822 | 0.033 | 20.762 | 20.877 | 0.001 | 0.001 | 787.0 | 690.0 | 1.0 |
