<a href="https://colab.research.google.com/github/ziatdinovmax/gpax/blob/main/examples/gpax_hypo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# For github continuous integration only
# Please ignore if you're running this notebook!
import os
if os.environ.get("CI_SMOKE"):
    NUM_WARMUP = 100
    NUM_SAMPLES = 100
else:
    NUM_WARMUP = 2000
    NUM_SAMPLES = 2000

# Hypothesis learning: toy data example

This notebook demonstrates how to apply hypothesis learning to toy data. [Hypothesis learning](https://arxiv.org/abs/2112.06649) is based on the idea that, in active learning, the correct model of the system’s behavior leads to a faster decrease in the overall Bayesian uncertainty about the system under study. In the hypothesis learning setup, probabilistic models of the system’s possible behaviors (hypotheses) are wrapped into structured Gaussian processes, and a basic reinforcement learning policy (such as epsilon-greedy or softmax) is used to select the correct model from several competing hypotheses.

*Prepared by Maxim Ziatdinov (2023). Last updated in October 2023.*

## Install & Import

Install the latest GPax release from PyPI (this is the recommended route, since it gives you the most recent deployed and tested version).

In [None]:
!pip install gpax

Import needed packages:

In [None]:
try:
    # For use on Google Colab
    import gpax

except ImportError:
    # For use locally (where you're using the local version of gpax)
    print("Assuming notebook is being run locally, attempting to import local gpax module")
    import sys
    sys.path.append("..")
    import gpax

In [None]:
#@title Imports
from typing import Union, Dict, Type

import gpax

import jax.numpy as jnp
import numpy as onp
import numpyro
import matplotlib.pyplot as plt

gpax.utils.enable_x64()

Enable some pretty plotting.

In [None]:
import matplotlib as mpl

In [None]:
mpl.rcParams['mathtext.fontset'] = 'stix'
mpl.rcParams['font.family'] = 'STIXGeneral'
mpl.rcParams['text.usetex'] = False
plt.rc('xtick', labelsize=12)
plt.rc('ytick', labelsize=12)
plt.rc('axes', labelsize=12)
mpl.rcParams['figure.dpi'] = 200

# Hypothesis learning

In [None]:
#@title Plotting and data utilities { form-width: "20%" }

def update_datapoints(point_idx, point_measured, X_measured, y_measured, X_unmeasured):
    """Updates "measured" and "unmeasured" arrays of (dummy) data points"""
    X_measured = jnp.append(X_measured, X_unmeasured[point_idx][None], 0)
    X_unmeasured = jnp.delete(X_unmeasured, point_idx, 0)
    y_measured = jnp.append(y_measured, point_measured)
    return X_measured, y_measured, X_unmeasured
    

def plot_results(X_measured, y_measured, X_unmeasured, y_pred, y_sampled, obj, model_idx, rewards, **kwargs):
    X = jnp.concatenate([X_measured, X_unmeasured], axis=0).sort()
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    ax1.scatter(X_measured, y_measured, marker='x', s=100, c='k', label="Measured points", zorder=1)
    ax1.plot(X, y_pred, c='red', label='Model reconstruction', zorder=0)
    ax1.fill_between(X, y_pred - y_sampled.std(0), y_pred + y_sampled.std(0),
                     color='r', alpha=0.2, label="Model uncertainty", zorder=0)
    ax1.set_xlabel("$x$", fontsize=18)
    ax1.set_ylabel("$y$", fontsize=18)
    ax2.plot(X_unmeasured, obj, c='k')
    ax2.vlines(X_unmeasured[obj.argmax()], obj.min(), obj.max(), linestyles='dashed', label= "Next point")
    ax2.set_xlabel("$x$", fontsize=18)
    ax2.set_ylabel("Acquisition function", fontsize=18)
    ax1.legend(loc="upper left")
    ax2.legend(loc="upper left")
    step = kwargs.get("e", 0)
    plt.suptitle("Step: {},  Sampled Model: {}, Rewards: {}".format(
        step+1, model_idx, onp.around(rewards, 3).tolist()), fontsize=24)
    # fig.savefig("./{}.png".format(step))
    plt.show() 
    

def plot_acq(x, obj, idx):
    plt.plot(x.squeeze(), obj, c='k')
    plt.vlines(x[idx], obj.min(), obj.max(), linestyles='dashed')
    plt.xlabel("$x$", fontsize=18)
    plt.ylabel("Acquisition function", fontsize=18)
    plt.show()
    

def plot_final_result(X, y, X_unmeasured, y_pred, y_sampled, seed_points):
    plt.figure(figsize=(6, 4))
    plt.scatter(X[seed_points:], y[seed_points:], c=jnp.arange(1, len(X[seed_points:])+1),
                cmap='viridis', label="Sampled points", zorder=2)
    cbar = plt.colorbar(label="Exploration step")
    cbar_ticks = jnp.arange(2, len(X[seed_points:]) + 1, 2)
    cbar.set_ticks(cbar_ticks)
    plt.scatter(X[:seed_points], y[:seed_points], marker='x', s=64,
                c='k', label="Seed points", zorder=1)
    plt.plot(X_unmeasured, y_pred, '--', c='red', label='Model reconstruction', zorder=1)
    # truefunc is a global evaluated on the full grid X, so X_unmeasured must be that grid here
    plt.plot(X_unmeasured, truefunc, c='k', label="Ground truth", zorder=0)
    plt.fill_between(X_unmeasured, y_pred - y_sampled.std(0), y_pred + y_sampled.std(0),
                            color='r', alpha=0.2, label="Model uncertainty", zorder=0)
    plt.xlabel("$x$", fontsize=12)
    plt.ylabel("$y$", fontsize=12)
    plt.legend(fontsize=9, loc='upper left')
    #plt.ylim(1.8, 6.6)
    plt.show()

First, let's generate some data. As a practical example, we consider active learning of a phase diagram that has a transition between different phases. The phase transition manifests as a discontinuity in a measurable property of the system, such as heat capacity. However, we usually do not know precisely where the phase transition occurs, nor do we know the exact behavior of the property of interest in the different phases. Standard Gaussian process-based active learning is not an optimal choice in such a case, because a simple GP struggles around the discontinuity point.

In [None]:
def function_(x: jnp.ndarray, params: Dict[str, float]) -> jnp.ndarray:
    return jnp.piecewise(
        x,
        [x < params["t"], x >= params["t"]],
        [lambda x: x**params["beta1"], lambda x: x**params["beta2"]]
    )


X = jnp.linspace(0.0, 2.5, 100)
params_i = {"t": 1.6, "beta1": 4, "beta2": 2.5}

truefunc = function_(X, params_i)
Y = truefunc + 0.2 * onp.random.normal(size=len(X))

fig, ax = plt.subplots(1, 1, figsize=(6, 2))
ax.scatter(X, Y, alpha=0.5, c='k', label="Noisy observations")
ax.plot(X, truefunc, lw=2, c='k', label="True function")
ax.legend()
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
plt.show()
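
To get a feel for how large the discontinuity is relative to the noise, the short check below evaluates the toy function on both sides of the transition using the parameter values from the cell above (`t = 1.6`, `beta1 = 4`, `beta2 = 2.5`, noise standard deviation 0.2). It is a plain-Python sketch; the helper name `power_law_piecewise` is introduced here for illustration only and is not part of the notebook or of GPax.

```python
def power_law_piecewise(x, t, beta1, beta2):
    """Scalar version of the toy function: x**beta1 below t, x**beta2 at and above t."""
    return x**beta1 if x < t else x**beta2

# Parameter values and noise level copied from the data-generation cell
t, beta1, beta2 = 1.6, 4, 2.5
noise_std = 0.2

left_limit = t**beta1                                   # value approached from below t
right_value = power_law_piecewise(t, t, beta1, beta2)   # value at (and just above) t
jump = left_limit - right_value

print(f"left limit:  {left_limit:.4f}")
print(f"right value: {right_value:.4f}")
print(f"jump / noise std: {jump / noise_std:.1f}")
```

The function drops from about 6.55 to about 3.24 at the transition, a jump more than ten times the noise level, which is why a smooth GP kernel alone has trouble here.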

Of course, our algorithm is not going to see all these observations, nor is it going to see the true function. Instead, we are going to start with just 4 measured points.

In [None]:
onp.random.seed(1)

seed_idx = onp.array([0, 33, 66, 99])
X_measured = X[seed_idx]
X_unmeasured = onp.delete(X, seed_idx)
y_measured = function_(X_measured, params_i) + 0.2 * onp.random.normal(size=len(X_measured))
num_seed_points = len(X_measured)

fig, ax = plt.subplots(1, 1, figsize=(6, 2))
ax.scatter(X_measured, y_measured, alpha=1.0, c='k', marker='x', label="Noisy observations")
ax.legend()
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
plt.show()

Next, we define possible models of the system's behavior as deterministic functions:

In [None]:
def piecewise1(x: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Power law behavior before the transition and linear behavior after the transition"""
    return jnp.piecewise(
        x,
        [x < params["t"], x >= params["t"]],
        [lambda x: x**params["beta"], lambda x: params["c"]*x]
    )
    
def piecewise2(x: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Linear behavior before and after the transition"""
    return jnp.piecewise(
        x,
        [x < params["t"], x >= params["t"]],
        [lambda x: params["b"]*x, lambda x: params["c"]*x]
    )
    
def piecewise3(x: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Power-law behavior before and after the transition"""
    return jnp.piecewise(
        x,
        [x < params["t"], x >= params["t"]],
        [lambda x: x**params["beta1"], lambda x: x**params["beta2"]]
    )

We put priors over the parameters of each model to make them probabilistic:

In [None]:
def piecewise1_priors() -> Dict[str, jnp.ndarray]:
    # Sample model parameters
    t = numpyro.sample("t", numpyro.distributions.Uniform(0.5, 2.0))
    beta = numpyro.sample("beta", numpyro.distributions.Normal(3, 1))
    c = numpyro.sample("c", numpyro.distributions.Normal(3, 1))
    # Return sampled parameters as a dictionary
    return {"t": t, "beta": beta, "c": c}

def piecewise2_priors() -> Dict[str, jnp.ndarray]:
    # Sample model parameters
    t = numpyro.sample("t", numpyro.distributions.Uniform(0.5, 2.0))
    b = numpyro.sample("b", numpyro.distributions.Normal(3, 1))
    c = numpyro.sample("c", numpyro.distributions.Normal(3, 1))
    # Return sampled parameters as a dictionary
    return {"t": t, "b": b, "c": c}

def piecewise3_priors() -> Dict[str, jnp.ndarray]:
    # Sample model parameters
    t = numpyro.sample("t", numpyro.distributions.Uniform(0.5, 2.0))
    beta1 = numpyro.sample("beta1", numpyro.distributions.Normal(3, 1))
    beta2 = numpyro.sample("beta2", numpyro.distributions.Normal(3, 1))
    # Return sampled parameters as a dictionary
    return {"t": t, "beta1": beta1, "beta2": beta2}
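
To illustrate what "probabilistic" means here, the sketch below draws a few parameter sets from the same distributions as `piecewise1_priors` and evaluates the corresponding curves. It deliberately avoids NumPyro and JAX and uses only the standard library, so `piecewise1_scalar` and `draw_piecewise1_params` are hypothetical stand-ins written for this example, not functions from the notebook or from GPax.

```python
import random

def piecewise1_scalar(x, t, beta, c):
    """Scalar stand-in for piecewise1: power law below t, linear at and above t."""
    return x**beta if x < t else c * x

def draw_piecewise1_params(rng):
    """Draw one parameter set from the same distributions as piecewise1_priors."""
    return {
        "t": rng.uniform(0.5, 2.0),   # Uniform(0.5, 2.0)
        "beta": rng.gauss(3.0, 1.0),  # Normal(3, 1)
        "c": rng.gauss(3.0, 1.0),     # Normal(3, 1)
    }

rng = random.Random(0)
x_grid = [0.5, 1.0, 1.5, 2.0, 2.5]  # skips x = 0, where a negative exponent would fail
prior_draws = [draw_piecewise1_params(rng) for _ in range(5)]
prior_curves = [[piecewise1_scalar(x, **p) for x in x_grid] for p in prior_draws]

# Each row is one hypothesis-consistent curve before any data are seen
for p, curve in zip(prior_draws, prior_curves):
    print(f"t={p['t']:.2f}  " + "  ".join(f"{y:7.2f}" for y in curve))
```

Each draw places the transition somewhere in [0.5, 2.0] with a different exponent and slope; Bayesian inference then narrows this family of curves using the measured points.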

Let's also specify custom priors over the GP kernel parameters (this step is optional):

In [None]:
def gp_kernel_prior() -> Dict[str, jnp.ndarray]:
    length = numpyro.sample("k_length", numpyro.distributions.Uniform(0, 1))
    scale = numpyro.sample("k_scale", numpyro.distributions.LogNormal(0, 1))
    return {"k_length": length, "k_scale": scale}
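
The `Uniform(0, 1)` prior caps the kernel length scale at 1, well below the 2.5-wide input range, so the GP part can only make fairly local corrections to the parametric model. The sketch below illustrates this by evaluating a Matérn-5/2 covariance for several length scales. It assumes that `'Matern'` denotes the common ν = 5/2 form with `k_scale` acting as a multiplicative variance; the notebook does not state this, and `matern52` is a helper written for this example rather than a GPax function.

```python
import math

def matern52(r, length, scale):
    """Matern-5/2 covariance between two points a distance r apart (assumed form)."""
    z = math.sqrt(5.0) * abs(r) / length
    return scale * (1.0 + z + z * z / 3.0) * math.exp(-z)

# Correlation (scale = 1) across a distance of 0.25, i.e. 10% of the toy domain [0, 2.5]
for length in (0.1, 0.5, 1.0):
    print(f"k_length={length}: k(0.25) = {matern52(0.25, length, 1.0):.3f}")
```

Shorter length scales make the correlation between neighboring points fall off faster, which gives the GP more freedom to absorb local mismatch between the hypothesis and the data.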

Define a simple reward function for hypothesis learning:

In [None]:
def get_reward(obj_history):
    """A reward of +/-1 is given if the median uncertainty at the current step
    is smaller/larger than the median uncertainty at the previous step"""
    r = 1 if obj_history[-1] < obj_history[-2] else -1
    return r
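
As a quick illustration of how this reward behaves, the snippet below applies the same rule to a made-up history of median uncertainties (the numbers are hypothetical, chosen only for the example). The function is repeated so that the snippet runs on its own.

```python
def get_reward(obj_history):
    """+1 if the latest median uncertainty is below the previous one, otherwise -1."""
    return 1 if obj_history[-1] < obj_history[-2] else -1

# Hypothetical median uncertainties recorded over four consecutive steps
history = [0.50, 0.42, 0.47, 0.47]

# Reward after each step from the second one onward
rewards = [get_reward(history[:k]) for k in range(2, len(history) + 1)]
print(rewards)  # [1, -1, -1]
```

Note that an unchanged uncertainty counts as a failure (reward of -1), and that the function needs at least two entries in the history, which the warm-up phase provides.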

The main part (Algorithm 1 in the paper):

In [None]:
exploration_steps = 15
warmup_steps = 3
plot_reconstruction = True  # available only for exploration phase

# Create lists containing physical models and probabilistic priors over their parameters
models = [piecewise1, piecewise2, piecewise3]
model_priors = [piecewise1_priors, piecewise2_priors, piecewise3_priors]

onp.random.seed(1)  # rng seed for reproducibility

# Initialize the reward, predictive uncertainty and model selection records
record = onp.zeros((len(models), 2))
model_choices = []
obj_history = []

# Warm-up phase
for w in range(warmup_steps):
    print("Warmup step {}/{}".format(w+1, warmup_steps))
    obj_median_all, obj_all = [], []
    
    # Iterate over probabilistic models in the list
    for i, model in enumerate(models):

        # For each model, run Bayesian inference (BI) and store the uncertainty values
        obj, _ = gpax.hypo.step(
            model,
            model_priors[i],
            X_measured,
            y_measured,
            X_unmeasured,
            gp_wrap=True,  # wrap the model into a Gaussian process
            gp_kernel='Matern',
            gp_kernel_prior=gp_kernel_prior,
            num_warmup=NUM_WARMUP,
            num_samples=NUM_SAMPLES,
        )
        record[i, 0] += 1
        obj_all.append(obj)

        # (one can use integral uncertainty instead of median)
        obj_median_all.append(jnp.nanmedian(obj).item())

    # Reward a model that has the smallest integral/median uncertainty
    idx = onp.argmin(obj_median_all)
    model_choices.append(idx)
    record[idx, 1] += 1

    # Store the integral/median uncertainty
    obj_history.append(obj_median_all[idx])

    # Compute the next measurement point using the predictive uncertainty of rewarded model
    obj = obj_all[idx]
    next_point_idx = obj.argmax()

    # Evaluate the function at the suggested point
    measured_point = function_(X_unmeasured[next_point_idx], params_i) + 0.2*onp.random.normal()

    # Update arrays with measured and unmeasured points
    X_measured, y_measured, X_unmeasured = update_datapoints(
        next_point_idx, measured_point, X_measured, y_measured, X_unmeasured
    )

# Average over the number of warmup steps
record[:, 1] = record[:, 1] / warmup_steps

# Run exploration phase
for e in range(exploration_steps - warmup_steps):
    print("Exploration step {}/{}".format(e+warmup_steps+1, exploration_steps))

    # Choose model according to epsilon-greedy policy
    idx = gpax.hypo.sample_next(record[:, 1], method="eps-greedy", eps=0.4)
    model_choices.append(idx)
    print("Using model {}".format(idx+1))

    # Derive acquisition function with the selected model
    obj, m_post = gpax.hypo.step(
        models[idx],
        model_priors[idx],
        X_measured,
        y_measured,
        X_unmeasured,
        gp_wrap=True,  # wrap the sampled model into a Gaussian process
        gp_kernel='Matern',
        gp_kernel_prior=gp_kernel_prior,
        num_restarts=2,
        print_summary=False,
        num_warmup=NUM_WARMUP,
        num_samples=NUM_SAMPLES,
    )

    # Get reward
    obj_history.append(jnp.nanmedian(obj).item())
    r = get_reward(obj_history)

    # Update records
    record = gpax.hypo.update_record(record, idx, r)

    # Get the next measurement point from the predictive uncertainty of the sampled model
    next_point_idx = obj.argmax()

    # Evaluate the function at the suggested point
    measured_point = function_(X_unmeasured[next_point_idx], params_i) + 0.2*onp.random.normal()
    if plot_reconstruction:
        
        # Plot the current reconstruction and acquisition function
        y_pred, y_sampled = m_post.predict(gpax.utils.get_keys()[1], X)
        plot_results(
            X_measured,
            y_measured,
            X_unmeasured,
            y_pred,
            y_sampled.squeeze(),
            obj,
            idx+1,
            record[:, 1],
            e=e+warmup_steps
        )

    # Update arrays with measured and unmeasured points
    X_measured, y_measured, X_unmeasured = update_datapoints(
        next_point_idx,
        measured_point,
        X_measured,
        y_measured,
        X_unmeasured
    )
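
The exploration phase relies on two small bookkeeping steps: choosing a model with an epsilon-greedy policy and folding the new reward into that model's running average. The sketch below shows one way these steps could look in plain Python. It is an assumption about the general technique, not a copy of `gpax.hypo.sample_next` or `gpax.hypo.update_record`, whose internals are not shown in this notebook; the names `eps_greedy_choice` and `update_running_average` are hypothetical.

```python
import random

def eps_greedy_choice(avg_rewards, eps, rng):
    """With probability eps pick a random model; otherwise pick the best average reward."""
    if rng.random() < eps:
        return rng.randrange(len(avg_rewards))
    return max(range(len(avg_rewards)), key=lambda i: avg_rewards[i])

def update_running_average(record, idx, reward):
    """record[i] = [count, average reward]; return a copy with model idx updated."""
    new_record = [row[:] for row in record]
    count, avg = new_record[idx]
    new_record[idx] = [count + 1, (count * avg + reward) / (count + 1)]
    return new_record

rng = random.Random(0)
record = [[3, 0.0], [3, 1 / 3], [3, 2 / 3]]  # hypothetical state after three warm-up steps

# One exploration step: choose a model, observe a reward of -1, update its record
idx = eps_greedy_choice([row[1] for row in record], eps=0.4, rng=rng)
record = update_running_average(record, idx, reward=-1)
print(idx, record)
```

With `eps=0.4`, as used above, roughly 40% of the steps sample a model at random, so a hypothesis that looked poor during warm-up still gets a chance to recover.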

Plot the integral/median uncertainty as a function of the exploration step:

(note that for the warm-up steps, we plot only the model that produced the lowest uncertainty)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 2))
cmap = mpl.colormaps["rainbow"].resampled(3)
ax.plot(onp.arange(1, exploration_steps+1), obj_history, c='k')
for model_index in range(3):
    where = onp.where(onp.array(model_choices) == model_index)[0]
    ax.scatter(
        where + 1,
        onp.array(obj_history)[where],
        color=cmap(model_index),
        s=128,
        alpha=1,
        label=f"model {model_index+1}",
    )
ax.set_xlabel("Exploration step", fontsize=14)
ax.set_ylabel("Median uncertainty", fontsize=14)
ax.legend()
plt.show()

View the average reward associated with each model:

(note that the counts include the warm-up steps, where all the models were evaluated)

In [None]:
for i, r in enumerate(record):
    print("model {}:  counts {}  reward (avg) {}".format(i+1, (int(r[0])), onp.round(r[1], 3)))
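
To make the warm-up contribution to these numbers concrete, the sketch below rebuilds the record that the warm-up loop produces from a list of per-step winners: every model's count grows by one at each warm-up step, and the reward column ends up as the fraction of warm-up steps that the model won. The winner indices are hypothetical, and `warmup_record` is a helper written for this example only.

```python
def warmup_record(winners, num_models):
    """Build [count, average reward] rows from the winning model index at each warm-up step."""
    steps = len(winners)
    record = [[steps, 0.0] for _ in range(num_models)]  # every model is evaluated at every step
    for idx in winners:
        record[idx][1] += 1.0   # +1 for the model with the smallest median uncertainty
    for row in record:
        row[1] /= steps         # average over the warm-up steps
    return record

# Hypothetical outcome: model 3 wins twice, model 1 wins once (zero-based indices)
record_after_warmup = warmup_record(winners=[2, 2, 0], num_models=3)
for i, (count, avg) in enumerate(record_after_warmup):
    print(f"model {i + 1}:  counts {count}  reward (avg) {avg:.3f}")
```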

Compute (and plot) each model's prediction over the entire grid using the final set of discovered points:


In [None]:
for i, model in enumerate(models):
    # use the same parameters as in the main loop
    _, gp_model = gpax.hypo.step(
        model,
        model_priors[i],
        X_measured,
        y_measured,
        gp_wrap=True,
        gp_kernel='Matern',
        gp_kernel_prior=gp_kernel_prior,
        num_restarts=2,
        print_summary=False,
        num_warmup=NUM_WARMUP,
        num_samples=NUM_SAMPLES,
    )
    y_pred, y_sampled = gp_model.predict(gpax.utils.get_keys()[1], X)
    print("\n Model {}, Reward (avg) {}".format(i+1, onp.round(record[i, 1], 3)))
    plot_final_result(
        X_measured,
        y_measured,
        X, y_pred,
        y_sampled.squeeze(),
        seed_points=num_seed_points
    )

Note that because we wrapped our models in a GP and because each model had a transition point, even the first two cases showed a satisfactory fit. At the same time, the model that received the highest reward (i.e., was favored by our algorithm) provided the best fit accompanied by the smallest uncertainty. Hence, we were able to learn the distribution of the property of interest from a small number of sparse measurements while also identifying the correct model describing the system’s behavior.