# Samplers
The samplers module provides various sampling algorithms to generate samples from a given probability distribution. These algorithms can be used to sample from an explicit (graph, $y=\phi(x)$) or an implicit ($f(z)=0$) distribution. The samplers can be used for tasks such as:
* build a dataset or mesh for a manifold
* visualize a distribution
* build a dataset of geodesics using one of the methods in `geodesics.generate`

## Sample from an explicit manifold
The following code samples from an explicit manifold defined by a potential function $\phi(x):\mathbb{R}^i\rightarrow \mathbb{R}^m$, where $i$ is the number of (independent) inputs and $m$ the number of constraints. The dimension of the ambient space is $n = i+m$.
* Using `randinput_expl` we can obtain a random sample of the manifold by sampling the input space uniformly via Latin hypercube sampling and mapping the samples onto the manifold using $\phi$. This method is fast and yields a **uniform sample of the input space**, but it does not guarantee a uniform sample of the manifold itself.
* Using `volume_expl` we can obtain a non-repeated **uniform sample of the manifold**. The function performs an over-sampling of $\phi$, which is then re-sampled with weights $w_i \propto \sqrt{\det [G(x_i)]}$, where $G = J^\top J$ is the metric tensor induced by $\phi$ and $J$ is the Jacobian of the graph embedding $x \mapsto (x, \phi(x))$, i.e. $G = I + J_\phi^\top J_\phi$.

In [1]:
from jnlr.utils.samplers import volume_expl, randinput_expl
from jnlr.utils.manifolds import f_ackley as phi
from jnlr.utils.plot_utils import plot_3d_projection

# Volume-weighted sample of the f_ackley graph manifold (see volume_expl above).
Y_vol = volume_expl(phi, n_samples=5000, oversample=5, roi_R=5)
fig = plot_3d_projection(Y_vol, phi)

# Top view: place the camera eye on the +z axis.
fig.update_layout(scene_camera={"eye": {"x": 0.0, "y": 0.0, "z": 2.5}})


In [2]:

# Latin-hypercube sample of the input space, mapped onto the manifold by phi.
# NOTE: this reuses the name Y_vol, overwriting the volume sample of the previous cell.
Y_vol = randinput_expl(phi, n_samples=5000)
fig = plot_3d_projection(Y_vol, phi)
# Top view: place the camera eye on the +z axis.
fig.update_layout(scene_camera={"eye": {"x": 0.0, "y": 0.0, "z": 2.5}})

In [3]:
from jnlr.utils.samplers import langevin_implicit

def f(xyz):
    """Quartic implicit function ``x**4 - x**2 + y**2 + z**2 - 2*x*z``.

    ``xyz`` is unpacked into exactly three coordinates; the sampled surface is
    the zero set ``f(xyz) == 0``.
    """
    x, y, z = xyz
    quartic = x**4 - x**2 + y**2 + z**2
    return quartic - 2 * x * z
# Langevin sampling of the implicit surface f(xyz) = 0.
Y_lang = langevin_implicit(
    f,
    n_samples=5000,
    burn=100,
    thin=2,
    sigma=0.1,
    lam=0.01,
    kappa=0.02,
    R=3.0,
)
fig = plot_3d_projection(Y_lang)
fig.update_layout(scene_camera={"eye": {"x": 0.0, "y": 0.0, "z": 2.5}})


In [None]:
from jnlr.utils.implicit_hypersurfaces import f_cube

# Langevin sampling of the f_cube implicit surface (smaller run, loose tol).
Y_lang = langevin_implicit(
    f_cube,
    n_samples=1000,
    burn=100,
    thin=2,
    sigma=0.1,
    lam=0.1,
    kappa=0.01,
    R=3.0,
    tol=1,
)
fig = plot_3d_projection(Y_lang)
fig.update_layout(scene_camera={"eye": {"x": 0.0, "y": 0.0, "z": 2.5}})

In [None]:
from jnlr.utils.implicit_hypersurfaces import f_dodecahedron
# Langevin sampling of the f_dodecahedron implicit surface.
Y_lang = langevin_implicit(
    f_dodecahedron,
    n_samples=5000,
    burn=100,
    thin=2,
    sigma=0.1,
    lam=0.01,
    kappa=0.02,
    R=3.0,
)
fig = plot_3d_projection(Y_lang)
fig.update_layout(scene_camera={"eye": {"x": 0.0, "y": 0.0, "z": 2.5}})

In [None]:
from jnlr.utils.plot_utils import plot_3d_projection
from jnlr.utils.implicit_hypersurfaces import surface7, surface_b, surface_c

# Langevin sampling of surface7 with lam and kappa switched off.
Y_lang = langevin_implicit(
    surface7,
    n_samples=10000,
    burn=100,
    thin=1,
    sigma=0.3,
    lam=0,
    kappa=0,
    R=5.0,
    tol=1e-2,
)
fig = plot_3d_projection(Y_lang)
fig.update_layout(scene_camera={"eye": {"x": 0.0, "y": 0.0, "z": 2.5}})

In [None]:
# Langevin sampling of surface_c (kappa switched off, tight tol).
Y_lang = langevin_implicit(
    surface_c,
    n_samples=10000,
    burn=100,
    thin=1,
    sigma=0.5,
    lam=0.1,
    kappa=0,
    R=5.0,
    tol=1e-3,
)
fig = plot_3d_projection(Y_lang)
fig.update_layout(scene_camera={"eye": {"x": 0.0, "y": 0.0, "z": 2.5}})

In [None]:
# Langevin sampling of surface_b with lam and kappa switched off.
Y_lang = langevin_implicit(
    surface_b,
    n_samples=10000,
    burn=100,
    thin=1,
    sigma=0.3,
    lam=0,
    kappa=0,
    R=5.0,
    tol=1e-2,
)
fig = plot_3d_projection(Y_lang)
fig.update_layout(scene_camera={"eye": {"x": 0.0, "y": 0.0, "z": 2.5}})

In [10]:
import jax.numpy as jnp



# Geometry of the engraved sphere used by ``phi`` below.
R = 1.0            # sphere radius
t = 0.12           # engraving depth: the surface sits at r = R - t * indicator
stroke = 0.12      # distance to a segment within which the indicator is 1
epsilon = 0.06     # width of the linear ramp from 1 down to 0 beyond ``stroke``
extrude_x = 0.40   # gate half-width passed with the N letter
extrude_y = 0.40   # gate half-width passed with the L letter
extrude_z = 0.40   # gate half-width along z for the R letter
z_center_R = -0.40 # z at which the R gate peaks

# Letter strokes as 2D (start, end) segments in each letter's own plane:
# N in (y, z), L in (x, z), R in (x, y); all lie inside [-0.4, 0.4]^2.
N_segments = [
    (jnp.array((-0.4, -0.4)), jnp.array((-0.4, 0.4))),  # left vertical
    (jnp.array((0.4, -0.4)), jnp.array((0.4, 0.4))),    # right vertical
    (jnp.array((-0.4, 0.4)), jnp.array((0.4, -0.4))),   # diagonal
]
L_segments = [
    (jnp.array((-0.4, -0.4)), jnp.array((-0.4, 0.4))),  # vertical
    (jnp.array((-0.4, -0.4)), jnp.array((0.4, -0.4))),  # bottom horizontal
]
R_segments = [
    (jnp.array((-0.4, -0.4)), jnp.array((-0.4, 0.4))),  # vertical
    (jnp.array((-0.4, 0.4)), jnp.array((0.0, 0.4))),    # top horizontal
    (jnp.array((0.0, 0.4)), jnp.array((0.0, 0.2))),     # short vertical drop
]

def sd_segment(p, a, b):
    """Euclidean distance from point(s) ``p`` to the segment ``[a, b]``.

    Args:
        p: a single point of shape ``(D,)`` or a batch of points of shape
            ``(..., D)``.
        a: segment start point, shape ``(D,)``.
        b: segment end point, shape ``(D,)``.

    Returns:
        Distances of shape ``p.shape[:-1]`` (a scalar for a single point).
    """
    pa = p - a
    ba = b - a
    ba_dot = jnp.dot(ba, ba)
    # One projection parameter per point, clamped to [0, 1] so the closest
    # point stays on the segment; the 1e-8 guards degenerate segments (a == b).
    # Reducing over the last axis (instead of jnp.dot) and adding a trailing
    # axis below keeps the batch axis aligned with the coordinate axis, so
    # batched ``p`` works instead of mis-broadcasting against ``ba``.
    t_clamp = jnp.clip(jnp.sum(pa * ba, axis=-1) / (ba_dot + 1e-8), 0.0, 1.0)
    proj = a + t_clamp[..., None] * ba
    return jnp.linalg.norm(p - proj, axis=-1)

def smooth_indicator(p, segments):
    """Soft membership of ``p`` in the thickened union of ``segments``.

    The value is 1 where the nearest segment is within ``stroke``, decreases
    linearly to 0 over a further ``epsilon``, and is 0 beyond that (module-level
    ``stroke`` and ``epsilon``). With no segments the 1e6 sentinel distance
    drives the result to 0.
    """
    distances = [sd_segment(p, start, end) for start, end in segments]
    nearest = 1e6
    for dist in distances:
        nearest = jnp.minimum(nearest, dist)
    overshoot = nearest - stroke
    return jnp.clip(1.0 - overshoot / epsilon, 0.0, 1.0)

def gate(t, width):
    """Triangular window: 1 at ``t == 0``, falling linearly to 0 at ``|t| == width``."""
    ramp = 1.0 - jnp.abs(t) / width
    return jnp.clip(ramp, 0.0, 1.0)

def s_N(y, z):
    """Soft indicator of the letter N drawn in the (y, z) plane.

    NOTE(review): the gate is evaluated at the constant 0.0, so for a positive
    ``extrude_x`` it contributes a factor of 1 and the letter is not limited
    along x. Confirm this is intended.
    """
    gate_factor = gate(0.0, extrude_x)
    letter = smooth_indicator(jnp.stack([y, z], axis=-1), N_segments)
    return letter * gate_factor

def s_L(x, z):
    """Soft indicator of the letter L drawn in the (x, z) plane.

    NOTE(review): the gate is evaluated at the constant 0.0, so for a positive
    ``extrude_y`` it contributes a factor of 1 and the letter is not limited
    along y. Confirm this is intended.
    """
    gate_factor = gate(0.0, extrude_y)
    letter = smooth_indicator(jnp.stack([x, z], axis=-1), L_segments)
    return letter * gate_factor

def s_R(x, y, z):
    """Soft indicator of the letter R drawn in the (x, y) plane.

    The letter is gated along z by a triangular window centred at
    ``z_center_R`` with half-width ``extrude_z``.
    """
    letter = smooth_indicator(jnp.stack([x, y], axis=-1), R_segments)
    z_window = gate(z - z_center_R, extrude_z)
    return letter * z_window

def phi(xyz):
    """Implicit function of a sphere of radius ``R`` engraved with N, L and R.

    Where any letter indicator is active, the zero set ``phi(xyz) == 0`` is
    pulled inward to radius ``R - t * indicator``. The three letters are
    combined by a pointwise maximum.
    """
    x, y, z = xyz
    radius = jnp.sqrt(x * x + y * y + z * z)
    letters = jnp.maximum(s_N(y, z), jnp.maximum(s_L(x, z), s_R(x, y, z)))
    return radius - (R - t * letters)
# Langevin sampling of the engraved sphere phi(xyz) = 0.
Y_lang = langevin_implicit(
    phi,
    n_samples=10000,
    burn=100,
    thin=2,
    sigma=0.1,
    lam=0.01,
    kappa=0.02,
    R=3.0,
)
fig = plot_3d_projection(Y_lang)
fig.update_layout(scene_camera={"eye": {"x": 0.0, "y": 0.0, "z": 2.5}})

In [11]:
import numpy as np
import matplotlib.pyplot as plt

# Ramp width of the NumPy indicator below (same value as ``epsilon``).
smooth_width = 0.06


def sd_segment_vector(p, a, b):
    """Distance from each row of ``p`` to the segment ``[a, b]`` (NumPy).

    ``p`` holds one point per row; the result has one distance per row.
    If the squared segment length is below 1e-8, the segment is treated as
    the single point ``a``.
    """
    rel = p - a
    seg = b - a
    seg_len2 = np.dot(seg, seg)
    if seg_len2 < 1e-8:
        return np.linalg.norm(rel, axis=1)
    frac = np.clip(np.sum(rel * seg, axis=1) / seg_len2, 0, 1)
    closest = a + frac[:, np.newaxis] * seg
    diff = p - closest
    return np.sqrt(np.sum(diff**2, axis=1))


def s_indicator(p, segments):
    """NumPy soft membership of each row of ``p`` in the thickened ``segments``.

    The value is 1 within ``stroke`` of the nearest segment, ramps linearly to
    0 over a further ``smooth_width``, and is 0 beyond that.
    """
    nearest = np.full(p.shape[0], np.inf)
    for seg_start, seg_end in segments:
        nearest = np.minimum(nearest, sd_segment_vector(p, seg_start, seg_end))
    overshoot = nearest - stroke
    return np.clip(1 - overshoot / smooth_width, 0, 1)

def gate_np(t, width):
    """NumPy triangular window: 1 at ``t == 0``, falling linearly to 0 at ``|t| == width``.

    Deliberately not named ``gate``. Rebinding that name would replace the JAX
    ``gate`` that ``s_N``/``s_L``/``s_R`` (and therefore ``phi``) look up at
    call time. NumPy functions cannot consume JAX tracers, so calling ``phi``
    under a JAX transformation (e.g. ``jax.grad``/``jax.jit``) could then fail.
    """
    return np.clip(1 - np.abs(t) / width, 0, 1)


n = 200
xs = np.linspace(-1, 1, n)
ys = xs.copy()
# Evaluate the xy cross-section at z = z_center_R, where the R gate peaks.
Z_plane = z_center_R
# indexing='ij': xg[i, j] == xs[i] and yg[i, j] == ys[j], so axis 0 is x and axis 1 is y.
xg, yg = np.meshgrid(xs, ys, indexing='ij')

# N lives in the (y, z) plane: at fixed z it varies with y only, i.e. along
# axis 1. The original tile(...).T put it along axis 0 (x). Here it is gated along x.
# NOTE(review): phi's s_N uses gate(0.0, ...), i.e. no x gating; confirm which is intended.
p_N = np.stack([ys, np.full(n, Z_plane)], axis=1)
sN_row = s_indicator(p_N, N_segments)
sN_xy = sN_row[np.newaxis, :] * gate_np(xg, extrude_x)

# L lives in the (x, z) plane: at fixed z it varies with x only, i.e. along axis 0.
# It is gated along y.
p_L = np.stack([xs, np.full(n, Z_plane)], axis=1)
sL_row = s_indicator(p_L, L_segments)
sL_xy = sL_row[:, np.newaxis] * gate_np(yg, extrude_y)

# R lives in the (x, y) plane and is gated along z.
p_R = np.stack([xg.ravel(), yg.ravel()], axis=1)
sR_full = s_indicator(p_R, R_segments).reshape(n, n) * gate_np(Z_plane - z_center_R, extrude_z)

# Combine the letters with a pointwise maximum, as phi does, so overlaps stay in [0, 1].
s_total_xy_plane = np.maximum(sN_xy, np.maximum(sL_xy, sR_full))

plt.figure(figsize=(4, 4))
# Transpose so imshow rows are y and columns are x.
plt.imshow(s_total_xy_plane.T, extent=[xs.min(), xs.max(), ys.min(), ys.max()],
           origin='lower', aspect='auto', cmap='magma')
plt.title('xy cross-section at z=z_center_R')
plt.xlabel('x')
plt.ylabel('y')
plt.colorbar()
plt.tight_layout()
# Show the figure; plt.close() here would discard it before inline display.
plt.show()

ModuleNotFoundError: No module named 'matplotlib'