# Quickstart Tutorial

`kdescent` provides a general framework for comparing the N-dimensional distribution of a model population to that of a training dataset. In short, it draws a mini-batched sample from the training data, computes a kernel density estimate (KDE) of the training distribution, and computes counts weighted by each mini-batched kernel (similar to computing the number density within a randomly drawn bin). These weighted counts can be compared directly to those of the model to compute a loss or likelihood, and even to perform gradient descent using JAX's autodiff functionality. To make gradient descent more effective still, `kdescent` also provides an analogous metric for comparing weighted counts in Fourier space. Combining the KDE and Fourier metrics into a single loss term for stochastic gradient descent has been shown to be a very powerful method of parameter optimization.
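
To make the idea of kernel-weighted counts concrete, here is a minimal NumPy sketch. It is not `kdescent`'s implementation: the helper name `gaussian_kernel_counts`, the Gaussian kernel shape, and the fixed bandwidth of 0.5 are all assumptions chosen for illustration.

```python
import numpy as np


def gaussian_kernel_counts(points, centers, bandwidth, weights=None):
    """Sum Gaussian kernel values of `points` around each kernel center.

    Hypothetical helper for illustration only (not part of kdescent).
    """
    points = np.asarray(points, dtype=float)
    centers = np.asarray(centers, dtype=float)
    if weights is None:
        weights = np.ones(len(points))
    # Squared distance from every center to every point: (n_centers, n_points)
    diff = centers[:, None, :] - points[None, :, :]
    sqdist = np.sum(diff**2, axis=-1)
    kernel_vals = np.exp(-0.5 * sqdist / bandwidth**2)
    # Weighted "count" of points near each center
    return kernel_vals @ weights


rng = np.random.default_rng(0)
training = rng.normal(loc=[1.6, 2.9], scale=1.0, size=(2000, 2))
model_good = rng.normal(loc=[1.6, 2.9], scale=1.0, size=(1000, 2))
model_bad = rng.normal(loc=[0.0, 0.0], scale=1.0, size=(1000, 2))

# Mini-batch of kernel centers drawn from the training data itself
centers = training[rng.choice(len(training), size=20, replace=False)]

# Divide by sample size so that differently sized samples are comparable
truth_density = gaussian_kernel_counts(training, centers, 0.5) / len(training)
good_density = gaussian_kernel_counts(model_good, centers, 0.5) / len(model_good)
bad_density = gaussian_kernel_counts(model_bad, centers, 0.5) / len(model_bad)

loss_good = np.mean((good_density - truth_density) ** 2)
loss_bad = np.mean((bad_density - truth_density) ** 2)
print("loss (matching model):  ", loss_good)
print("loss (mismatched model):", loss_bad)
```

The model drawn from the same distribution as the training data yields the smaller loss, which is the signal that gradient descent follows.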

In [None]:
import functools
import numpy as np
import jax.numpy as jnp
import jax.random
import matplotlib.pyplot as plt

from diffopt import kdescent

## Example model

Consider a population with variables $x_1$ and $x_2$ distributed as a 2D multivariate normal. We parameterize this distribution with five parameters, [$\mu_1$, $\mu_2$, $\sigma_1^2$, $\sigma_2^2$, $c_{12}$], where $\mu_i$ is the mean of each variable, $\sigma_i^2$ is the variance of each variable, and $c_{12}$ is the correlation coefficient between the two variables, which sets the off-diagonal element of the covariance matrix to $c_{12} \sigma_1 \sigma_2$.
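
As a quick check of this parameterization, the sketch below builds the mean vector and covariance matrix from the five parameters in plain NumPy. The helper name `params_to_mean_cov` and the parameter values are illustrative assumptions.

```python
import numpy as np


def params_to_mean_cov(params):
    """Unpack [mu1, mu2, var1, var2, corr] into a mean vector and covariance.

    Hypothetical helper for illustration only.
    """
    mu1, mu2, var1, var2, corr = params
    # Off-diagonal covariance = correlation * sigma1 * sigma2
    cov12 = corr * np.sqrt(var1 * var2)
    mean = np.array([mu1, mu2])
    cov = np.array([[var1, cov12],
                    [cov12, var2]])
    return mean, cov


mean, cov = params_to_mean_cov([1.6, 2.9, 0.8, 1.25, -0.2])
print(cov)
# A valid covariance matrix is symmetric with positive eigenvalues,
# which holds for positive variances and |corr| < 1
print(np.linalg.eigvalsh(cov))
```

Parameterizing by the correlation coefficient rather than the raw covariance keeps the matrix positive definite as long as the variances are positive and the coefficient stays strictly between -1 and 1.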

In [None]:
data_nsample = 10_000
model_nsample = 5_000

# Generate data from a 2D multivariate normal distribution given a
# 5-param model [mean1, mean2, sigma1**2, sigma2**2, correlation_coef]
@functools.partial(jax.jit, static_argnames=["nsample"])
def generate_data(params, randkey, nsample=model_nsample):
    mean = params[:2]
    cov11, cov22 = jnp.abs(params[2:4])
    cov12 = params[4] * jnp.sqrt(cov11 * cov22)
    cov = jnp.array([[cov11, cov12],
                     [cov12, cov22]])
    return jax.random.multivariate_normal(
        randkey, mean, cov, shape=(nsample,))

### Define "true" parameters to generate training data

In [None]:
truth_params = jnp.array([1.6, 2.9, 0.8, 1.25, -0.2])
truth_randkey = jax.random.key(42)

training_x = generate_data(truth_params, truth_randkey, nsample=data_nsample)

plt.hexbin(*training_x.T, mincnt=1, norm=plt.matplotlib.colors.LogNorm(),
           linewidth=0.3)
plt.text(0.95, 0.95, "Training data", fontsize=14,
         transform=plt.gca().transAxes, ha="right", va="top")
plt.xlabel("$x_1$", fontsize=14)
plt.ylabel("$x_2$", fontsize=14)
plt.show()

## Define loss function comparing ${\rm PDF}(x_1, x_2)$

- Characterize the loss as the difference between our training and model distributions
- We will evaluate these distributions around randomized kernel centers using the `compare_kde_counts` method (20 kernels by default)
- To perform this gradient descent stochastically, the random seed for drawing kernel centers must be updated at each step. To do this, we will have our loss function accept the `randkey` argument, which it will split and pass to all sources of stochasticity
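
The key-splitting pattern described in the last bullet can be sketched on its own with a toy loss. Here `noisy_loss` and its two sources of randomness are hypothetical stand-ins, not `kdescent` calls; only `jax.random.key`, `jax.random.split`, `jax.random.normal`, and `jax.random.choice` are real JAX functions.

```python
import jax
import jax.numpy as jnp


def noisy_loss(param, randkey):
    # One subkey per source of stochasticity, never reusing the parent key
    key_model, key_batch = jax.random.split(randkey)
    # (1) Stochastic "model population" (toy stand-in)
    model_draw = param + jax.random.normal(key_model, shape=(100,))
    # (2) Stochastic mini-batch selection (toy stand-in for kernel centers)
    batch_idx = jax.random.choice(key_batch, 100, shape=(5,), replace=False)
    return jnp.mean(model_draw[batch_idx] ** 2)


# A fresh key for every gradient descent step
step_keys = jax.random.split(jax.random.key(0), 3)
losses = [float(noisy_loss(1.0, key)) for key in step_keys]
print(losses)

# Re-using a key reproduces the same loss exactly
print(float(noisy_loss(1.0, step_keys[0])) == losses[0])
```

Because all randomness flows from the key argument, the loss is a deterministic function of `(param, randkey)`, which is what allows it to be jit-compiled and differentiated.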

In [None]:
kde = kdescent.KCalc(training_x, num_kernels=20)


def lossfunc(params, randkey):
    # Split random key for (1) multivariate draws and (2) kernel mini-batching
    key1, key2 = jax.random.split(randkey)
    model_x = generate_data(params, randkey=key1)
    model_kde_counts, truth_kde_counts = kde.compare_kde_counts(key2, model_x)
    
    # Must divide by total number in sample since the training dataset
    # is not the same size as the population generated by the model
    model_kde_density = model_kde_counts / model_nsample
    truth_kde_density = truth_kde_counts / data_nsample

    # Return the mean-squared error of our metrics
    return jnp.mean((model_kde_density - truth_kde_density)**2)

## Run gradient descent

In [None]:
# Define initial guess and bounds for our parameters
guess = jnp.array([0., 0., 1., 1., 0.])
bounds = jnp.array([[-jnp.inf, jnp.inf], [-jnp.inf, jnp.inf],
                    [0.001, jnp.inf], [0.001, jnp.inf], [-0.999, 0.999]])

# Run gradient descent to approximately recover the truth
adam_params, adam_losses = kdescent.adam(
    lossfunc, guess, nsteps=100, param_bounds=bounds,
    learning_rate=1.0, randkey=12345)
print("Final params =", adam_params[-1])
print("True params =", truth_params)
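
For readers unfamiliar with Adam, the following is a generic NumPy sketch of the update rule applied to a noisy toy gradient. It is not the code behind `kdescent.adam`: the function `adam_sketch`, the hand-written gradient, and the use of clipping to enforce bounds are assumptions for illustration, and the library may handle bounds differently.

```python
import numpy as np


def adam_sketch(grad_fn, guess, nsteps, learning_rate, bounds, seed=0,
                beta1=0.9, beta2=0.999, eps=1e-8):
    """Generic Adam loop with clipped bounds (hypothetical, for illustration)."""
    rng = np.random.default_rng(seed)
    low, high = np.asarray(bounds, dtype=float).T
    params = np.asarray(guess, dtype=float)
    m = np.zeros_like(params)  # running mean of gradients
    v = np.zeros_like(params)  # running mean of squared gradients
    history = [params.copy()]
    for step in range(1, nsteps + 1):
        grad = grad_fn(params, rng)  # fresh randomness at every step
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * grad**2
        m_hat = m / (1 - beta1**step)  # bias correction
        v_hat = v / (1 - beta2**step)
        params = params - learning_rate * m_hat / (np.sqrt(v_hat) + eps)
        params = np.clip(params, low, high)  # assumption: bounds via clipping
        history.append(params.copy())
    return np.array(history)


target = np.array([1.6, 0.8])


def noisy_grad(params, rng):
    # Gradient of sum((params - target)**2) plus mini-batch-like noise
    return 2 * (params - target) + rng.normal(scale=0.1, size=params.shape)


toy_bounds = [[-np.inf, np.inf], [0.001, np.inf]]
history = adam_sketch(noisy_grad, [0.0, 1.0], nsteps=500,
                      learning_rate=0.05, bounds=toy_bounds)
print("Final params =", history[-1])
print("Target       =", target)
```

As in the cell above, the returned history has one row per step plus the initial guess, so `history[-1]` is the final estimate.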

In [None]:
fig, axes = plt.subplots(ncols=2, figsize=(8, 3.5))

axes[0].hexbin(*training_x.T, mincnt=1, gridsize=100)
axes[0].scatter(*generate_data(guess, truth_randkey).T,
                s=1, alpha=0.3, color="C1")
axes[0].text(0.02, 0.02, "Initial guess", fontsize=13,
             color="C1", transform=axes[0].transAxes)
axes[0].set_xlim(-2.5, 4.5)
axes[0].set_ylim(-3, 7)

axes[1].hexbin(*training_x.T, mincnt=1, gridsize=100)
axes[1].scatter(*generate_data(adam_params[-1], truth_randkey).T,
                s=1, alpha=0.3, color="C1")
axes[1].text(0.02, 0.02, "Solution", fontsize=13,
             color="C1", transform=axes[1].transAxes)
axes[1].set_xlim(-2.5, 4.5)
axes[1].set_ylim(-3, 7)
plt.show()

# Advanced Usage

## More complex example model

- 20-parameter model that generates a non-trivial bimodal 3-dimensional distribution (variables: $\log M_\star, g-r, r-z$)
    - To help gradient descent navigate such a tricky parameter space, we will introduce Fourier-space terms into our loss
- To add even more complexity, the training dataset is undersampled at $\log M_\star < 10.5$
    - We must therefore rely on *conditional* probability distributions $P(g-r, r-z | \log M_\star)$, with a separate `KCalc` object handling each bin of our conditional variable, $\log M_\star$
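
To see why conditional distributions are robust to this kind of undersampling, consider the toy NumPy sketch below. The one-dimensional "color", the 30% retention rate, and the helper names are assumptions made for illustration; hard-edged windows stand in for the kernels.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 200_000
logm = rng.uniform(9.0, 12.0, size=n)
# Toy color whose mean depends on mass
color = rng.normal(loc=0.5 * (logm - 9.0), scale=0.3)

# Undersample: keep only ~30% of objects below logM* = 10.5
keep = (logm >= 10.5) | (rng.uniform(size=n) < 0.3)
logm_obs, color_obs = logm[keep], color[keep]


def joint_prob(logm, color, mass_bin, window):
    """P(window & bin) = N(window & bin) / N(total)"""
    in_bin = (mass_bin[0] < logm) & (logm < mass_bin[1])
    in_window = (window[0] < color) & (color < window[1])
    return np.sum(in_bin & in_window) / len(logm)


def conditional_prob(logm, color, mass_bin, window):
    """P(window | bin) = N(window & bin) / N(bin)"""
    in_bin = (mass_bin[0] < logm) & (logm < mass_bin[1])
    in_window = (window[0] < color) & (color < window[1])
    return np.sum(in_bin & in_window) / np.sum(in_bin)


mass_bin, window = (9.0, 9.5), (0.0, 0.3)
joint_full = joint_prob(logm, color, mass_bin, window)
joint_obs = joint_prob(logm_obs, color_obs, mass_bin, window)
cond_full = conditional_prob(logm, color, mass_bin, window)
cond_obs = conditional_prob(logm_obs, color_obs, mass_bin, window)

print("joint:      ", joint_full, "vs undersampled", joint_obs)
print("conditional:", cond_full, "vs undersampled", cond_obs)
```

The joint probability is biased by the selection, whereas the conditional probability within the mass bin agrees between the full and undersampled populations up to sampling noise.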

In [None]:
import seaborn as sns

In [None]:
model_nsample = 40_000
data_nsample = 20_000  # same volume, but undersampled at logM* < 10.5

# Generate data weighted from two mass-dependent multivariate normals
@functools.partial(jax.jit, static_argnames=["undersample", "nsample"])
def generate_model(params, randkey, undersample=False, nsample=model_nsample):
    # Parse all 20 parameters
    # =======================
    # Lower and upper bounds on log stellar mass
    logmlim = params[:2]
    logmlim = logmlim.at[1].add(logmlim[0] + 0.001)

    # Distribution parameters at lower mass bound
    mean_mmin = params[2:4]
    sigma11, sigma22 = params[4:6]
    maxcov = jnp.sqrt(sigma11 * sigma22)
    sigma12 = params[6] * maxcov
    cov_mmin = jnp.array([[sigma11, sigma12],
                          [sigma12, sigma22]])
    qfrac_mmin = params[7]
    qmean_mmin = mean_mmin + params[8:10]
    qscale_mmin = params[10]

    # Distribution parameters at upper mass bound
    mean_mmax = params[11:13]
    sigma11, sigma22 = params[13:15]
    maxcov = jnp.sqrt(sigma11 * sigma22)
    sigma12 = params[15] * maxcov
    cov_mmax = jnp.array([[sigma11, sigma12],
                          [sigma12, sigma22]])
    qfrac_mmax = params[16]
    qmean_mmax = mean_mmax + params[17:19]
    qscale_mmax = params[19]

    # Generate distribution from parameters
    # =====================================
    key1, key2 = jax.random.split(randkey, num=2)
    triangle_vals = (0, 0.5, 1) if undersample else (0, 0, 1)
    logm = jax.random.triangular(key1, *triangle_vals, shape=(nsample,))
    logm = logmlim[0] + logm * (logmlim[1] - logmlim[0])
    # Calculate slope of mass dependence
    dlogm = logmlim[1] - logmlim[0]
    dmean = (mean_mmax - mean_mmin) / dlogm
    dcov = (cov_mmax - cov_mmin) / dlogm
    dqfrac = (qfrac_mmax - qfrac_mmin) / dlogm
    dqmean = (qmean_mmax - qmean_mmin) / dlogm
    dqscale = (qscale_mmax - qscale_mmin) / dlogm
    # Apply mass dependence
    mean_sf = mean_mmin + dmean * (logm[:, None] - logmlim[0])
    cov_sf = cov_mmin + dcov * (logm[:, None, None] - logmlim[0])
    mean_q = qmean_mmin + dqmean * (logm[:, None] - logmlim[0])
    qscale = qscale_mmin + dqscale * (logm - logmlim[0])
    cov_q = cov_sf * qscale[:, None, None] ** 2
    qfrac = qfrac_mmin + dqfrac * (logm - logmlim[0])

    # Generate colors from two separate multivariate normals
    rz_sf, gr_sf = jax.random.multivariate_normal(key2, mean_sf, cov_sf).T
    rz_q, gr_q = jax.random.multivariate_normal(key2, mean_q, cov_q).T
    # Concatenate the quenched + star-forming values and assign weights
    data_sf = jnp.array([rz_sf, gr_sf, logm]).T
    data_q = jnp.array([rz_q, gr_q, logm]).T
    data = jnp.concatenate([data_sf, data_q])
    weights = jnp.concatenate([1 - qfrac, qfrac])
    return data, weights

### Define "true" parameters to generate training data

In [None]:
truth_logmmin = 9.0
truth_logmrange = 3.0

truth_mean_mmin = jnp.array([1.4, 1.1])
truth_var_mmin = jnp.array([0.7, 0.4])
truth_corr_mmin = 0.3
truth_qfrac_mmin = 0.2
truth_qmean_mmin = jnp.array([-0.1, 1.6])
truth_qscale_mmin = 0.3

truth_mean_mmax = jnp.array([2.0, 1.6])
truth_var_mmax = jnp.array([0.5, 0.5])
truth_corr_mmax = 0.75
truth_qfrac_mmax = 0.95
truth_qmean_mmax = jnp.array([-0.6, 1.2])
truth_qscale_mmax = 1.1

bounds_truth_logmrange = [0.001, jnp.inf]
bounds_var = ([0.001, jnp.inf], [0.001, jnp.inf])
bounds_corr = [-0.999, 0.999]
bounds_qfrac = [0.0, 1.0]
bounds_qmean_gr = [0.001, jnp.inf]
bounds_qscale = [0.001, jnp.inf]

truth = jnp.array([
    truth_logmmin, truth_logmrange,
    *truth_mean_mmin, *truth_var_mmin, truth_corr_mmin, truth_qfrac_mmin,
    *truth_qmean_mmin, truth_qscale_mmin,
    *truth_mean_mmax, *truth_var_mmax, truth_corr_mmax, truth_qfrac_mmax,
    *truth_qmean_mmax, truth_qscale_mmax
])
guess = jnp.array([
    9.25, 2.5, *[0.0, 0.0, 1.0, 1.0, 0.0, 0.5, 0.0, 1.0, 1.0]*2
])
bounds = [
    None, bounds_truth_logmrange,
    *[None, None, *bounds_var, bounds_corr, bounds_qfrac,
      None, bounds_qmean_gr, bounds_qscale]*2
]

# Generate training data from the truth parameters we just defined
truth_randkey = jax.random.key(43)
training_x_weighted, training_w = generate_model(
    truth, truth_randkey, undersample=True, nsample=data_nsample)

# KDescent allows weighted training data, but to make this more realistic,
# let's generate our actual training data by randomized weighted sampling
training_selection = jax.random.uniform(
    jax.random.split(truth_randkey)[0], training_w.shape) < training_w
training_x = training_x_weighted[training_selection]

### Define plotting function

- Plot the mass distribution + the color-color distribution in three separate mass bins

In [None]:
lowmass_cut = [9.0, 9.5]
midmass_cut = [10.25, 10.75]
highmass_cut = [11.5, 12.0]
is_lowmass = ((lowmass_cut[0] < training_x_weighted[:, 2])
              & (training_x_weighted[:, 2] < lowmass_cut[1]))
is_midmass = ((midmass_cut[0] < training_x_weighted[:, 2])
              & (training_x_weighted[:, 2] < midmass_cut[1]))
is_highmass = ((highmass_cut[0] < training_x_weighted[:, 2])
               & (training_x_weighted[:, 2] < highmass_cut[1]))
training_w_lowmass = training_w * is_lowmass
training_w_midmass = training_w * is_midmass
training_w_highmass = training_w * is_highmass
is_noweight_lowmass = (
    (lowmass_cut[0] < training_x[:, 2])
    & (training_x[:, 2] < lowmass_cut[1]))
is_noweight_midmass = (
    (midmass_cut[0] < training_x[:, 2])
    & (training_x[:, 2] < midmass_cut[1]))
is_noweight_highmass = (
    (highmass_cut[0] < training_x[:, 2])
    & (training_x[:, 2] < highmass_cut[1]))


def generate_model_into_mass_bins(params, randkey):
    key1 = jax.random.split(randkey)[0]
    model_x, model_w = generate_model(params, randkey=key1)
    is_low = ((lowmass_cut[0] < model_x[:, 2])
              & (model_x[:, 2] < lowmass_cut[1]))
    is_mid = ((midmass_cut[0] < model_x[:, 2])
              & (model_x[:, 2] < midmass_cut[1]))
    is_high = ((highmass_cut[0] < model_x[:, 2])
               & (model_x[:, 2] < highmass_cut[1]))
    return (model_x, model_x[is_low], model_x[is_mid], model_x[is_high],
            model_w, model_w[is_low], model_w[is_mid], model_w[is_high])


def make_sumstat_plot(params, txt="", fig=None, prev_layers=None):
    (modall, modlow, modmid, modhigh,
     w_all, w_low, w_mid, w_high) = generate_model_into_mass_bins(
         params, jax.random.key(13))
    if prev_layers is not None:
        for layer in prev_layers:
            layer.remove()

    fig = plt.figure(figsize=(10, 9)) if fig is None else fig
    ax = fig.add_subplot(221) if len(fig.axes) < 4 else fig.axes[0]
    ax.hist(training_x_weighted[:, 2], bins=50, color="red",
            weights=training_w)
    _, bins, hist1 = ax.hist(
        modall[:, 2], color="grey", bins=50, alpha=0.9, weights=w_all)
    hist2 = ax.hist(modlow[:, 2], bins=list(bins), color="C0",
                    alpha=0.9, weights=w_low)[-1]
    hist3 = ax.hist(modmid[:, 2], bins=list(bins), color="C0",
                    alpha=0.9, weights=w_mid)[-1]
    hist4 = ax.hist(modhigh[:, 2], bins=list(bins), color="C0",
                    alpha=0.9, weights=w_high)[-1]
    ax.set_xlabel("$\\log M_\\ast$", fontsize=14)
    text1 = ax.text(
        0.98, 0.98, "Training data", color="red", va="top", ha="right",
        fontsize=14, transform=ax.transAxes)
    text2 = ax.text(
        0.98, 0.91, txt, color="blue", va="top", ha="right",
        fontsize=14, transform=ax.transAxes)

    ax = fig.add_subplot(222) if len(fig.axes) < 4 else fig.axes[1]
    hex1 = ax.hexbin(*modlow[:, :2].T, mincnt=1,
                     C=w_low, reduce_C_function=np.sum,
                     norm=plt.matplotlib.colors.LogNorm())
    if prev_layers is None:
        sns.kdeplot(
            {"$r - z$": training_x_weighted[is_lowmass][:, 0],
             "$g - r$": training_x_weighted[is_lowmass][:, 1]},
            weights=training_w[is_lowmass],
            x="$r - z$", y="$g - r$", color="red", levels=7, ax=ax)
    ax.set_xlabel("$r - z$", fontsize=14)
    ax.set_ylabel("$g - r$", fontsize=14)
    text3 = ax.text(
        0.02, 0.02, f"${lowmass_cut[0]} < \\log M_\\ast < {lowmass_cut[1]}$",
        fontsize=14, transform=ax.transAxes)

    ax = fig.add_subplot(223, sharex=ax, sharey=ax) if len(
        fig.axes) < 4 else fig.axes[2]
    hex2 = ax.hexbin(*modmid[:, :2].T, mincnt=1,
                     C=w_mid, reduce_C_function=np.sum,
                     norm=plt.matplotlib.colors.LogNorm())
    if prev_layers is None:
        sns.kdeplot(
            {"$r - z$": training_x_weighted[is_midmass][:, 0],
             "$g - r$": training_x_weighted[is_midmass][:, 1]},
            weights=training_w[is_midmass],
            x="$r - z$", y="$g - r$", color="red", levels=7, ax=ax)
    ax.set_xlabel("$r - z$", fontsize=14)
    ax.set_ylabel("$g - r$", fontsize=14)
    text4 = ax.text(
        0.02, 0.02, f"${midmass_cut[0]} < \\log M_\\ast < {midmass_cut[1]}$",
        fontsize=14, transform=ax.transAxes)

    ax = fig.add_subplot(224, sharex=ax, sharey=ax) if len(
        fig.axes) < 4 else fig.axes[3]
    hex3 = ax.hexbin(*modhigh[:, :2].T, mincnt=1,
                     C=w_high, reduce_C_function=np.sum,
                     norm=plt.matplotlib.colors.LogNorm())
    if prev_layers is None:
        sns.kdeplot(
            {"$r - z$": training_x_weighted[is_highmass][:, 0],
             "$g - r$": training_x_weighted[is_highmass][:, 1]},
            weights=training_w[is_highmass],
            x="$r - z$", y="$g - r$", color="red", levels=7, ax=ax)
    ax.set_xlabel("$r - z$", fontsize=14)
    ax.set_ylabel("$g - r$", fontsize=14)
    text5 = ax.text(
        0.02, 0.02, f"${highmass_cut[0]} < \\log M_\\ast < {highmass_cut[1]}$",
        fontsize=14, transform=ax.transAxes)
    ax.set_xlim(-4, 7)
    ax.set_ylim(-4, 7)
    return [hex1, hex2, hex3, hist1, hist2, hist3, hist4,
            text1, text2, text3, text4, text5]

In [None]:
make_sumstat_plot(truth, txt="Truth")
plt.show()

## Define loss function comparing ${\rm PDF}(g-r, r-z | M_\ast)$ *and* its Fourier pair
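
As a rough illustration of what a Fourier-space summary statistic looks like, the sketch below computes a weighted empirical characteristic function (ECF) in NumPy and compares it with the known analytic result for a standard normal. The function name `empirical_cf` and the hand-picked wavevectors are assumptions; `compare_fourier_counts` may choose its wavevectors and normalization differently.

```python
import numpy as np


def empirical_cf(points, wavevectors, weights=None):
    """Weighted empirical characteristic function at each wavevector.

    Hypothetical helper for illustration only (not part of kdescent).
    """
    points = np.asarray(points, dtype=float)
    wavevectors = np.asarray(wavevectors, dtype=float)
    if weights is None:
        weights = np.ones(len(points))
    # exp(i k.x) for every wavevector and point: (n_k, n_points)
    phases = np.exp(1j * (wavevectors @ points.T))
    # Weighted mean over points gives one complex number per wavevector
    return phases @ weights / np.sum(weights)


rng = np.random.default_rng(2)
sample = rng.normal(size=(50_000, 2))
k = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, -0.5]])

ecf = empirical_cf(sample, k)
# The characteristic function of a standard normal is exp(-|k|^2 / 2)
exact = np.exp(-0.5 * np.sum(k**2, axis=1))

# The ECF is complex, so take abs() to obtain real squared residuals
sq_residuals = np.abs(ecf - exact) ** 2
print(ecf)
print(sq_residuals)
```

Dividing by the total weight plays the same role as dividing kernel counts by the number of objects in the mass bin, which is why the loss below treats the two statistics analogously.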

In [None]:
kcalc_lowmass = kdescent.KCalc(training_x[is_noweight_lowmass, :2])
kcalc_midmass = kdescent.KCalc(training_x[is_noweight_midmass, :2])
kcalc_highmass = kdescent.KCalc(training_x[is_noweight_highmass, :2])


# Differentiable alternative to hard binning in the loss function:
@jax.jit
def soft_tophat(x, low, high, squish=25.0):
    """Approximately return 1 when `low < x < high`, else return 0"""
    width = (high - low) / squish
    left = jax.nn.sigmoid((x - low) / width)
    right = jax.nn.sigmoid((high - x) / width)
    return left * right

@jax.jit
def lossfunc(params, randkey):
    key1, *keys = jax.random.split(randkey, 7)
    model_x, model_w = generate_model(params, randkey=key1)
    weight_low = soft_tophat(model_x[:, 2], *lowmass_cut) * model_w
    weight_mid = soft_tophat(model_x[:, 2], *midmass_cut) * model_w
    weight_high = soft_tophat(model_x[:, 2], *highmass_cut) * model_w

    model_low_counts, truth_low_counts = kcalc_lowmass.compare_kde_counts(
        keys[0], model_x[:, :2], weight_low)
    model_mid_counts, truth_mid_counts = kcalc_midmass.compare_kde_counts(
        keys[1], model_x[:, :2], weight_mid)
    model_high_counts, truth_high_counts = kcalc_highmass.compare_kde_counts(
        keys[2], model_x[:, :2], weight_high)

    model_low_fcounts, truth_low_fcounts = kcalc_lowmass.compare_fourier_counts(
        keys[3], model_x[:, :2], weight_low)
    model_mid_fcounts, truth_mid_fcounts = kcalc_midmass.compare_fourier_counts(
        keys[4], model_x[:, :2], weight_mid)
    model_high_fcounts, truth_high_fcounts = kcalc_highmass.compare_fourier_counts(
        keys[5], model_x[:, :2], weight_high)

    # Convert counts to conditional prob: P(krnl | M*) = N(krnl & M*) / N(M*)
    model_low_condprob = model_low_counts / (weight_low.sum() + 1e-10)
    model_mid_condprob = model_mid_counts / (weight_mid.sum() + 1e-10)
    model_high_condprob = model_high_counts / (weight_high.sum() + 1e-10)
    truth_low_condprob = truth_low_counts / (training_w_lowmass.sum() + 1e-10)
    truth_mid_condprob = truth_mid_counts / (training_w_midmass.sum() + 1e-10)
    truth_high_condprob = truth_high_counts / (
        training_w_highmass.sum() + 1e-10)
    # Convert Fourier counts to "conditional" ECF analogously
    model_low_ecf = model_low_fcounts / (weight_low.sum() + 1e-10)
    model_mid_ecf = model_mid_fcounts / (weight_mid.sum() + 1e-10)
    model_high_ecf = model_high_fcounts / (weight_high.sum() + 1e-10)
    truth_low_ecf = truth_low_fcounts / (training_w_lowmass.sum() + 1e-10)
    truth_mid_ecf = truth_mid_fcounts / (training_w_midmass.sum() + 1e-10)
    truth_high_ecf = truth_high_fcounts / (training_w_highmass.sum() + 1e-10)

    # One constraint on number density at the highest stellar mass bin
    volume = 100.0
    model_massfunc = jnp.array([weight_high.sum(),]) / volume
    truth_massfunc = jnp.array([training_w_highmass.sum(),]) / volume

    # Must abs() the Fourier residuals so that the loss is real
    sqerrs = jnp.concatenate([
        (model_low_condprob - truth_low_condprob)**2,
        (model_mid_condprob - truth_mid_condprob)**2,
        (model_high_condprob - truth_high_condprob)**2,
        jnp.abs(model_low_ecf - truth_low_ecf)**2,
        jnp.abs(model_mid_ecf - truth_mid_ecf)**2,
        jnp.abs(model_high_ecf - truth_high_ecf)**2,
        (model_massfunc - truth_massfunc)**2,
    ])

    return jnp.mean(sqerrs)
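
The `soft_tophat` helper defined above can be examined numerically with a NumPy re-implementation. The name `soft_tophat_np` and the test points are assumptions for illustration; the formula mirrors the JAX version in the cell above.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def soft_tophat_np(x, low, high, squish=25.0):
    """NumPy mirror of `soft_tophat`: ~1 inside (low, high), ~0 outside."""
    width = (high - low) / squish
    return sigmoid((x - low) / width) * sigmoid((high - x) / width)


low, high = 9.0, 9.5
x = np.array([8.8, 9.0, 9.25, 9.5, 9.7])
soft = soft_tophat_np(x, low, high)
hard = ((low < x) & (x < high)).astype(float)
print("soft:", soft)
print("hard:", hard)

# Central finite difference at the lower edge: the soft version has a
# finite, nonzero slope, so weights respond smoothly as masses shift
h = 1e-4
soft_slope = (soft_tophat_np(low + h, low, high)
              - soft_tophat_np(low - h, low, high)) / (2 * h)
print("slope at lower edge:", soft_slope)
```

A hard mask has zero gradient almost everywhere, so parameters that move objects across a bin edge (such as the mass limits) would receive no gradient signal from the bin membership; the sigmoid edges avoid this at the cost of slightly blurred bin boundaries.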

## Run gradient descent

In [None]:
adam_params, adam_losses = kdescent.adam(
    lossfunc, guess, nsteps=500, param_bounds=bounds,
    learning_rate=0.05, randkey=12345)
print("Best fit params =", adam_params[-1])

In [None]:
fig = plt.figure(figsize=(20, 9), layout="constrained")
fig.set_facecolor("0.05")
figs = fig.subfigures(1, 2, wspace=0.004)
figs[0].set_facecolor("white")
figs[1].set_facecolor("white")
make_sumstat_plot(
    adam_params[0], txt="Initial guess", fig=figs[0])
make_sumstat_plot(
    adam_params[-1],
    txt=f"Solution after {len(adam_params)-1} evaluations", fig=figs[1])
plt.show()

This is not a perfect fit (although it *would* continue to improve a little if we ran gradient descent for longer), but we achieved our goal of qualitatively recovering the target distribution. This shows how powerful `kdescent`'s Fourier counts can be as a summary statistic complementary to the PDF counts.