# Epistemic Nearest Neighbors (ENN)

ENN is a non-parametric surrogate whose computation time scales as $O(N)$, where $N$ is the number of observations in the data set. ENN can be used in Bayesian optimization as a scalable alternative to a Gaussian process (GP), which scales as $O(N^2)$.

**Sweet, D., & Jadhav, S. A. (2025).** Taking the GP Out of the Loop. *arXiv preprint arXiv:2506.12818*.  
   https://arxiv.org/abs/2506.12818

   ---

In [None]:
import numpy as np

from enn import EpistemicNearestNeighbors, enn_fit


def plot_enn_demo(ax, num_samples: int, k: int, noise: float, m: int = 1) -> None:
    """Fit ENN to a noisy 1-D sine sample and draw the fit on ``ax``.

    Inputs are drawn from the global ``np.random`` state, so callers control
    reproducibility by seeding it beforehand. ``num_samples + 4`` points are
    generated: the last three sorted inputs are overwritten with the value of
    the fourth-from-last one (exact duplicates), and the second input is
    placed 0.03 above the first (presumably to exercise duplicate and
    near-duplicate observations -- confirm with the author).

    Args:
        ax: Plot target providing ``scatter``, ``plot``, ``fill_between``,
            ``set_ylim`` and ``set_title`` (presumably a matplotlib Axes).
        num_samples: Base sample count; four extra points are always added.
            The plot title reports this value, not the total point count.
        k: Forwarded to ``enn_fit`` as ``k``.
        noise: Scale of the Gaussian noise added to the sine; its square is
            passed to the model as the per-observation variance.
        m: Frequency multiplier; the target is ``sin(2 * m * pi * x)``.
    """
    total_points = num_samples + 4

    # Sorted uniform inputs on [0, 1), then the hand-placed duplicates.
    inputs = np.sort(np.random.rand(total_points))
    inputs[-3:] = inputs[-4]
    inputs[1] = inputs[0] + 0.03

    # Noise is drawn after the inputs so the global RNG stream is consumed
    # in the same order on every call.
    gaussian = np.random.randn(total_points)
    angular_freq = 2 * m * np.pi
    targets = np.sin(angular_freq * inputs) + noise * gaussian
    target_var = np.ones_like(targets) * (noise**2)

    # The model is handed 2-D column arrays.
    model = EpistemicNearestNeighbors(
        inputs[:, np.newaxis],
        targets[:, np.newaxis],
        target_var[:, np.newaxis],
    )

    # Fixed-seed generator, independent of the global state used above.
    fit_rng = np.random.default_rng(0)
    fitted = enn_fit(
        model,
        k=k,
        num_fit_candidates=100,
        num_fit_samples=min(10, num_samples),
        rng=fit_rng,
    )
    print(k, noise, fitted)

    # Evaluate on a regular 30-point grid over [0, 1].
    grid = np.linspace(0.0, 1.0, 30)
    prediction = model.posterior(
        grid[:, np.newaxis], params=fitted, exclude_nearest=False
    )
    mean = prediction.mu[:, 0]
    spread = prediction.se[:, 0]  # `se`: presumably a standard error
    half_band = 2 * spread

    # Small markers for large samples so the scatter stays readable.
    if num_samples >= 100:
        point_size = 3
    else:
        point_size = 15

    ax.scatter(inputs, targets, s=point_size, color="black", alpha=0.5)
    ax.plot(grid, mean, linestyle="--", color="tab:blue", alpha=0.7)
    ax.fill_between(
        grid, mean - half_band, mean + half_band, color="tab:blue", alpha=0.2
    )
    ax.set_ylim(-5, 5)
    ax.set_title(f"n={num_samples}, noise={noise}")

In [None]:
import matplotlib.pyplot as plt

# Blue area is the epistemic uncertainty only

# Grid of demos: one row per sample count, one column per noise level.
k = 5
num_samples_list = [5, 10]
noise_list = [0.0, 0.1, 0.3]
fig, axes = plt.subplots(
    len(num_samples_list),
    len(noise_list),
    figsize=(9, 6),
    sharex=True,
    sharey=True,
)
for row_axes, num_samples in zip(axes, num_samples_list):
    for ax, noise in zip(row_axes, noise_list):
        # Re-seed before every panel so each one starts from the same
        # global random state.
        np.random.seed(4)
        plot_enn_demo(ax, num_samples=num_samples, k=k, noise=noise)

# Label only the outer edges: x on the bottom row, y on the left column.
for ax in axes[-1]:
    ax.set_xlabel("x")
for ax in axes[:, 0]:
    ax.set_ylabel("y")
fig.tight_layout()

In [None]:
import time
import matplotlib.pyplot as plt

# Timing demo: run the full demo (data generation, fit, posterior, plotting)
# on one million samples and report the elapsed time.
np.random.seed(1)
fig, ax = plt.subplots(figsize=(5, 3))
# perf_counter() is monotonic and high-resolution, so the measured interval
# cannot be skewed by system clock adjustments the way time.time() can.
t_0 = time.perf_counter()
plot_enn_demo(ax, num_samples=1_000_000, k=5, noise=0.3, m=3)
t_1 = time.perf_counter()
print(f"Time taken: {t_1 - t_0:.2f} seconds")
ax.set_xlabel("x")
ax.set_ylabel("y")
fig.tight_layout()