# Basic projection examples

The following cell defines a hypersurface starting from a graph (`f_paraboloid`) and then defines an implicit function (`f_implicit_paraboloid`) that is zero exactly on that surface. The solver then uses this implicit function to project points onto the surface.
The graph is defined as

$$f:\mathbb{R}^2\rightarrow\mathbb{R}, \qquad z = f(x,y) = x^2 + y^2$$

and the implicit function is defined as

$$f_{implicit}:\mathbb{R}^3\rightarrow\mathbb{R}, \qquad f_{implicit}(x,y,z) = f(x,y) - z = x^2 + y^2 - z$$

When f_implicit(x, y, z) = 0, the point (x, y, z) lies on the surface of the paraboloid.
The solver uses the f_implicit formulation to project points onto the surface.


In [None]:
# Set JAX config before importing JAX
import os
os.environ['JAX_PLATFORMS'] = "cpu"  # use "cuda" if you have a GPU
os.environ['JAX_ENABLE_X64'] = "1"

# NumPy draws the sample points, JAX evaluates the residuals, jnlr provides the projection solver
import numpy as np
import jax
from jnlr.reconcile import make_solver_alm_optax as make_solver
import jax.numpy as jnp

# Draw uniform (not Gaussian) samples in the cube [-1, 1)^3; they are projected
# onto the paraboloid below. A fixed seed keeps the printed residuals reproducible.
n_samples = 100
rng = np.random.default_rng(seed=0)
X = rng.uniform(-1.0, 1.0, size=(n_samples, 3))

# Define the graph z = f(x, y) and the implicit function built from it. The implicit function has a single component (m=1) that is zero exactly on the surface.
def f_paraboloid(v):
    """Graph of the paraboloid, z = x**2 + y**2.

    ``v`` must hold exactly the two planar coordinates ``(x, y)``.
    """
    x_coord, y_coord = v
    squared_x = x_coord**2
    squared_y = y_coord**2
    return squared_x + squared_y

def f_implicit_paraboloid(v):
    """Implicit form of the paraboloid: f(x, y) - z.

    For ``v = (x, y, z)`` the value is zero exactly when the point lies on the
    surface z = x**2 + y**2. It is positive below the surface and negative
    above it.
    """
    planar, height = v[:2], v[2]
    return f_paraboloid(planar) - height

# Build a solver for the implicit paraboloid and project the samples onto it.
solver = make_solver(f_implicit_paraboloid, n_iterations=30)
X_proj = solver(X)

# Batched residual |f_implicit| per point; it should drop to ~0 after projection.
residual = jax.vmap(f_implicit_paraboloid)
mean_before = jnp.mean(jnp.abs(residual(X)))
print("mean abs f before projection: {:0.2e}".format(mean_before))
mean_after = jnp.mean(jnp.abs(residual(X_proj)))
print("mean abs f after projection: {:0.2e}".format(mean_after))



mean abs f before projection: 7.37e-01
mean abs f after projection: 7.95e-10


The `jnlr.utils.manifolds` module contains further example surfaces (e.g. `f_paraboloid`, `f_rastrigin`, `f_ackley`), defined as graph functions $z = g(x, y)$. The `jnlr.utils.function_utils` module provides a utility function `f_impl` that converts a graph function into the corresponding implicit function. This is particularly useful for standard benchmark functions such as Ackley and Rastrigin.



In [3]:
import jnlr.utils.manifolds as mfs
from jnlr.utils.function_utils import f_impl


# Project the same samples onto surfaces given as graph functions z = g(x, y).
# f_impl converts each graph into an implicit function; the wrapper is built once
# per surface and shared by the solver and the residual evaluation.
# The second surface is Rastrigin, so its output is labelled accordingly.
for label, graph in (("Paraboloid", mfs.f_paraboloid), ("Rastrigin", mfs.f_rastrigin)):
    f_implicit_graph = f_impl(graph)
    solver = make_solver(f_implicit_graph, n_iterations=10)
    X_proj = solver(X)

    residual = jax.vmap(f_implicit_graph)
    print("{}, mean abs f before projection: {:0.2e}".format(label, jnp.mean(jnp.abs(residual(X)))))
    print("{}, mean abs f after projection: {:0.2e}".format(label, jnp.mean(jnp.abs(residual(X_proj)))))


Paraboloid, mean abs f before projection: 7.37e-01
Paraboloid, mean abs f after projection: 7.95e-10
Ackley, mean abs f before projection: 1.98e+01
Ackley, mean abs f after projection: 3.64e-10


In [4]:
from jnlr.utils.plot_utils import plot_3d_projection
# Plot the projection of the samples X for the paraboloid. f_paraboloid is passed in
# its graph form z = f(x, y), and solver_builder=make_solver supplies the solver factory.
# NOTE(review): presumably plot_history=True also draws the intermediate iterates of the
# n_iterations=4 solver steps and round_cutoff=None disables rounding; confirm in plot_utils.
plot_3d_projection(X, f_paraboloid, round_cutoff=None, solver_builder=make_solver, plot_history=True, n_iterations=4)


In [5]:
import jax.numpy as jnp

def sdf_box(p: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """Signed distance from ``p`` to an axis-aligned box with half-sizes ``b``.

    Based on ``sdBox`` from ``Cube.glsl``. The box is centred at the origin.
    The result is negative inside the box, zero on its surface and positive
    outside, so the box surface is the zero set of this function.

    Args:
        p: Query point in 3D space (``x``, ``y``, ``z``).
        b: Positive half-extents of the box along ``x``, ``y`` and ``z``.
            For a cube of side length 1 the appropriate ``b`` is
            ``(0.5, 0.5, 0.5)``.

    Returns:
        A 0-d JAX array holding the signed distance.
    """
    # Per-axis signed distance to the slabs |p_i| <= b_i.
    d = jnp.abs(p) - b
    # Inside the box all d_i <= 0 and the distance is the (negative) largest d_i;
    # outside, this term is clamped to 0.
    inside = jnp.minimum(jnp.maximum(d[0], jnp.maximum(d[1], d[2])), 0.0)
    # Outside the box: Euclidean length of the positive part of d (0 when inside).
    outside = jnp.linalg.norm(jnp.maximum(d, 0.0))
    return inside + outside

def sdf_cube(p: jnp.ndarray, size: float = 1.0) -> jnp.ndarray:
    """Signed distance from ``p`` to an origin-centred, axis-aligned cube.

    Convenience wrapper around :func:`sdf_box`. ``size`` is the full length of
    each side. The GLSL version uses a cube of size 1, corresponding to
    ``b = (0.5, 0.5, 0.5)``.

    Returns:
        A 0-d JAX array holding the signed distance.
    """
    half = 0.5 * size
    b = jnp.array([half, half, half])
    return sdf_box(p, b)

# Scale the samples by 10 so most of them start outside the unit cube, then
# project them onto its surface (the zero set of sdf_cube).
solver = make_solver(sdf_cube, n_iterations=10)
scaled_samples = 10 * X
X_proj = solver(scaled_samples)


In [6]:
# Same projection plot for the mfs.f_abs surface, running the solver for 40 iterations.
# NOTE(review): presumably more iterations are used here because this surface is
# non-smooth; confirm.
plot_3d_projection(X, mfs.f_abs, round_cutoff=None, solver_builder=make_solver, plot_history=True, n_iterations=40)

In [7]:
# Projection plot for the Ackley benchmark surface (mfs.f_ackley), 20 solver iterations.
plot_3d_projection(X, mfs.f_ackley, round_cutoff=None, solver_builder=make_solver, plot_history=True, n_iterations=20)

In [8]:
def f_shubert(v):
    """Two-dimensional Shubert function divided by 100.

    Computes (sum_j j*cos((j+1)*x1 + j)) * (sum_j j*cos((j+1)*x2 + j)) / 100
    for j = 1..5. Input: array of shape (2,). Use vmap externally for batching.
    """
    x1, x2 = v
    orders = range(1, 6)
    # Each factor is the same five-term cosine sum, evaluated on one coordinate.
    factor_x1 = sum(j * jnp.cos((j + 1) * x1 + j) for j in orders)
    factor_x2 = sum(j * jnp.cos((j + 1) * x2 + j) for j in orders)
    # NOTE(review): the 1/100 scaling presumably keeps the surface height comparable
    # to the [-1, 1) sample range; confirm.
    return factor_x1 * factor_x2 / 100

# Projection plot for the (scaled) Shubert surface defined above, 20 solver iterations.
plot_3d_projection(X, f_shubert, round_cutoff=None, solver_builder=make_solver, plot_history=True, n_iterations=20)


In [9]:
# Rastrigin surface scaled down by a factor of 100 before projecting.
# NOTE(review): presumably the scaling keeps its amplitude comparable to the sample
# range so the plot stays readable; confirm.
plot_3d_projection(X, lambda z: mfs.f_rastrigin(z)/100, round_cutoff=None, solver_builder=make_solver, plot_history=True, n_iterations=20)


In [10]:
# Draw a fresh set of uniform (not Gaussian) samples in [-1, 1)^3, this time to be
# projected onto the unit sphere. A fixed seed keeps the example reproducible.
n_samples = 100
rng = np.random.default_rng(seed=1)
X = rng.uniform(-1.0, 1.0, size=(n_samples, 3))

# Define an implicit function for the unit sphere. It has a single component (m=1) that is zero exactly on the surface, negative inside and positive outside.

def f_implicit_sphere(v):
    """Implicit unit sphere: x**2 + y**2 + z**2 - 1.

    Zero on the sphere of radius 1 centred at the origin, negative inside it
    and positive outside.
    """
    x, y, z = v[0], v[1], v[2]
    return x**2 + y**2 + z**2 - 1

# The sphere is passed through the f_implicit keyword rather than positionally.
# NOTE(review): presumably because f_implicit_sphere is already an implicit function,
# not a graph z = f(x, y) like the earlier examples; confirm in plot_utils.
plot_3d_projection(X, f_implicit=f_implicit_sphere, round_cutoff=None, solver_builder=make_solver, plot_history=True, n_iterations=20)
