# Basic projection examples

The following cell defines a hypersurface from a graph (`f_paraboloid`) and then an implicit function (`f_implicit_paraboloid`) that measures how far a point is from the surface. It is zero exactly on the surface and nonzero elsewhere, but it is a signed residual, not a Euclidean distance. The solver then projects points onto the zero level set of this implicit function.
The graph is defined as

$$f:\mathbb{R}^2\rightarrow\mathbb{R}, \qquad z = f(x,y) = x^2 + y^2$$

and the implicit function is defined as

$$f_{implicit}:\mathbb{R}^3\rightarrow\mathbb{R}, \qquad f_{implicit}(x,y,z) = f(x,y) - z = x^2 + y^2 - z$$

When f_implicit(x, y, z) = 0, the point (x, y, z) lies on the surface of the paraboloid.
The solver uses the f_implicit formulation to project points onto the surface.
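To make concrete what the implicit function measures, the sketch below lifts a few $(x, y)$ pairs onto the graph and then shifts them upward. The names `f_graph` and `f_implicit` are illustrative stand-ins for the notebook's functions. The residual is zero on the surface and signed elsewhere: with this convention it is negative above the paraboloid and positive below it, and its magnitude is the vertical gap to the graph, not the Euclidean distance.

```python
import jax
import jax.numpy as jnp

def f_graph(xy):
    # graph of the paraboloid: z = x^2 + y^2
    x, y = xy
    return x**2 + y**2

def f_implicit(v):
    # implicit form: zero exactly when v = (x, y, z) lies on the paraboloid
    return f_graph(v[:2]) - v[2]

xy = jnp.array([[0.0, 0.0], [1.0, -0.5], [0.3, 0.7]])
# lift each (x, y) onto the surface by setting z = f(x, y)
on_surface = jnp.column_stack([xy, jax.vmap(f_graph)(xy)])
# move the same points 0.25 upward along z
above = on_surface + jnp.array([0.0, 0.0, 0.25])

res_on = jax.vmap(f_implicit)(on_surface)
res_above = jax.vmap(f_implicit)(above)
print(res_on)     # ≈ 0 for every point
print(res_above)  # ≈ -0.25 for every point
```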


```python
import numpy as np
import jax
import jax.numpy as jnp
from jnlr.reconcile import make_projector_alm_optax as make_solver

# draw uniform samples in [-1, 1)^3 and project them onto the paraboloid
n_samples = 100
X = np.random.random((n_samples, 3)) * 2 - 1

# graph of the paraboloid: z = f(x, y) = x^2 + y^2
def f_paraboloid(v):
    x, y = v
    return x**2 + y**2

# implicit function with a single component (m = 1): zero exactly on the surface
def f_implicit_paraboloid(v):
    z = v[2]
    return f_paraboloid(v[:2]) - z

solver = make_solver(f_implicit_paraboloid, jnp.eye(3), n_iterations=30)
X_proj = solver(X)

print("mean abs f before projection: {:0.2e}".format(jnp.mean(jnp.abs(jax.vmap(f_implicit_paraboloid)(X)))))
print("mean abs f after projection: {:0.2e}".format(jnp.mean(jnp.abs(jax.vmap(f_implicit_paraboloid)(X_proj)))))
```

```
mean abs f before projection: 7.97e-01
mean abs f after projection: 5.69e-09
```
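The notebook does not show how `make_projector_alm_optax` works internally; its name suggests an augmented-Lagrangian method driven by optax, but that is an assumption. As a simpler, swapped-in technique that illustrates the same goal of driving $f_{implicit}$ to zero, the sketch below uses a minimum-norm Newton (Gauss-Newton) iteration $v \leftarrow v - f(v)\,\nabla f(v) / \lVert \nabla f(v) \rVert^2$. `newton_project`, `g_paraboloid` and `X_demo` are hypothetical names (chosen so they do not overwrite the notebook's variables), and unlike a least-distance formulation this iteration only guarantees landing on the surface, not at the point closest to the starting sample.

```python
import jax
import jax.numpy as jnp

def g_paraboloid(v):
    # residual of the paraboloid z = x^2 + y^2; zero on the surface
    return v[0]**2 + v[1]**2 - v[2]

def newton_project(g, x0, n_iterations=20):
    """Hypothetical helper: minimum-norm Newton iteration on a scalar residual g.

    Each step solves the linearized constraint g(x) + grad g(x) . delta = 0
    for the smallest delta. The limit lies on {g = 0} but is not guaranteed
    to be the Euclidean closest point to x0.
    """
    grad_g = jax.grad(g)

    def step(x, _):
        r = g(x)
        d = grad_g(x)
        return x - r * d / jnp.dot(d, d), None

    x_final, _ = jax.lax.scan(step, x0, None, length=n_iterations)
    return x_final

key = jax.random.PRNGKey(0)
X_demo = jax.random.uniform(key, (100, 3), minval=-1.0, maxval=1.0)
X_demo_proj = jax.vmap(lambda x: newton_project(g_paraboloid, x))(X_demo)

res_before = float(jnp.mean(jnp.abs(jax.vmap(g_paraboloid)(X_demo))))
res_after = float(jnp.mean(jnp.abs(jax.vmap(g_paraboloid)(X_demo_proj))))
print("mean abs g before: {:0.2e}".format(res_before))
print("mean abs g after:  {:0.2e}".format(res_after))
```

For a residual built from a graph, the $z$-component of the gradient is always $-1$, so the denominator $\lVert \nabla f \rVert^2$ never vanishes and every step is well defined.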


Other implicit-function examples for defining hypersurfaces are available in `jnlr.utils.manifolds`. In addition, `jnlr.utils.function_utils` provides `f_impl`, which converts a graph function into the corresponding implicit function. This is particularly convenient for standard benchmark functions such as Ackley and Rastrigin, which are usually written in graph form.
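The source of `f_impl` is not shown here, so the wrapper below is only a sketch of the same idea, assuming the last coordinate plays the role of $z$. `graph_to_implicit` and `paraboloid_graph` are hypothetical names, not part of jnlr.

```python
import jax
import jax.numpy as jnp

def graph_to_implicit(f_graph):
    """Hypothetical helper: turn a graph z = f(x_1, ..., x_{n-1}) into the
    implicit residual F(x_1, ..., x_{n-1}, z) = f(x_1, ..., x_{n-1}) - z."""
    def f_implicit(v):
        return f_graph(v[:-1]) - v[-1]
    return f_implicit

def paraboloid_graph(v):
    x, y = v
    return x**2 + y**2

F = graph_to_implicit(paraboloid_graph)
pts = jnp.array([[1.0, 2.0, 5.0],   # on the surface: 1 + 4 = 5
                 [1.0, 2.0, 3.0]])  # 2 below the surface
residuals = jax.vmap(F)(pts)
print(residuals)  # [0. 2.]
```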

```python
import jnlr.utils.manifolds as mfs
from jnlr.utils.function_utils import f_impl

# paraboloid: convert the graph function into an implicit one with f_impl
f_par = f_impl(mfs.f_paraboloid)
solver = make_solver(f_par, jnp.eye(3), n_iterations=10)
X_proj = solver(X)

print("Paraboloid, mean abs f before projection: {:0.2e}".format(jnp.mean(jnp.abs(jax.vmap(f_par)(X)))))
print("Paraboloid, mean abs f after projection: {:0.2e}".format(jnp.mean(jnp.abs(jax.vmap(f_par)(X_proj)))))

# Rastrigin surface
f_ras = f_impl(mfs.f_rastrigin)
solver = make_solver(f_ras, jnp.eye(3), n_iterations=10)
X_proj = solver(X)

print("Rastrigin, mean abs f before projection: {:0.2e}".format(jnp.mean(jnp.abs(jax.vmap(f_ras)(X)))))
print("Rastrigin, mean abs f after projection: {:0.2e}".format(jnp.mean(jnp.abs(jax.vmap(f_ras)(X_proj)))))
```

```
Paraboloid, mean abs f before projection: 7.97e-01
Paraboloid, mean abs f after projection: 6.28e-09
Rastrigin, mean abs f before projection: 1.93e+01
Rastrigin, mean abs f after projection: 1.26e-07
```
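The second run above, which uses `mfs.f_rastrigin`, starts from a mean residual of about 19, far larger than the paraboloid's. Assuming `mfs.f_rastrigin` follows the textbook definition $f(\mathbf{x}) = A n + \sum_i \left(x_i^2 - A\cos(2\pi x_i)\right)$ with $A = 10$ (not confirmed here), the constant offset $A n = 20$ in two dimensions would account for most of that gap, since the surface then sits well above the samples in $[-1, 1)^3$. The sketch below (`rastrigin` is a hypothetical stand-in) checks the textbook average over the square, which is $20 + 2/3$.

```python
import jax
import jax.numpy as jnp

def rastrigin(v, A=10.0):
    # textbook Rastrigin: A*n + sum(x_i^2 - A*cos(2*pi*x_i)); global minimum 0 at the origin
    return A * v.shape[0] + jnp.sum(v**2 - A * jnp.cos(2 * jnp.pi * v))

print(float(rastrigin(jnp.zeros(2))))  # 0.0

# Monte Carlo average over uniform samples in [-1, 1)^2
key = jax.random.PRNGKey(0)
xy = jax.random.uniform(key, (100_000, 2), minval=-1.0, maxval=1.0)
mean_f = jnp.mean(jax.vmap(rastrigin)(xy))
print(float(mean_f))  # close to 20 + 2/3 ≈ 20.67
```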


```python
from jnlr.utils.plot_utils import plot_3d_projection

plot_3d_projection(X, f_paraboloid, round_cutoff=None, solver_builder=make_solver, plot_history=True, n_iterations=4)
```
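`plot_3d_projection` is called with the graph function (here `f_paraboloid`) rather than the implicit one, and `plot_history=True` presumably draws the intermediate iterates. The internals are not shown, so the sketch below only illustrates one way to record such paths: it builds the graph residual itself and returns every iterate of the hypothetical minimum-norm Newton iteration via `jax.lax.scan`. `projection_history`, `paraboloid_graph` and `X_hist` are illustrative names, not jnlr API.

```python
import jax
import jax.numpy as jnp

def projection_history(f_graph, X, n_iterations=4):
    """Hypothetical helper: project each row of X onto z = f_graph(x, y) with a
    minimum-norm Newton iteration and return every iterate.

    Returns an array of shape (n_iterations + 1, n_points, 3); entry 0 is X.
    """
    def residual(v):
        return f_graph(v[:2]) - v[2]

    grad_r = jax.grad(residual)

    def step(x, _):
        r = residual(x)
        d = grad_r(x)  # z-component is always -1, so dot(d, d) >= 1
        x_new = x - r * d / jnp.dot(d, d)
        return x_new, x_new  # carry the iterate and also stack it as output

    def one_point(x0):
        _, path = jax.lax.scan(step, x0, None, length=n_iterations)
        return jnp.concatenate([x0[None, :], path], axis=0)

    # vmap gives (n_points, n_iterations + 1, 3); put the iteration axis first
    return jnp.swapaxes(jax.vmap(one_point)(X), 0, 1)

def paraboloid_graph(v):
    x, y = v
    return x**2 + y**2

X_hist = jax.random.uniform(jax.random.PRNGKey(1), (50, 3), minval=-1.0, maxval=1.0)
history = projection_history(paraboloid_graph, X_hist, n_iterations=4)
mean_res = [float(jnp.mean(jnp.abs(jax.vmap(lambda v: paraboloid_graph(v[:2]) - v[2])(h))))
            for h in history]
print(history.shape)  # (5, 50, 3)
print(mean_res)       # mean |residual| per iterate; should shrink toward 0
```

Each slice `history[k]` holds all points after `k` steps, which is a convenient layout for drawing line segments between consecutive iterates.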

```python
plot_3d_projection(X, mfs.f_abs, round_cutoff=None, solver_builder=make_solver, plot_history=True, n_iterations=20)
```

```python
plot_3d_projection(X, mfs.f_ackley, round_cutoff=None, solver_builder=make_solver, plot_history=True, n_iterations=20)
```


```python
def f_shubert(v):
    """
    Shubert function, scaled by 1/100. Input: array of shape (2,).
    Use vmap externally for batching.
    """
    x1, x2 = v
    total1 = 0.0
    total2 = 0.0
    for j in range(1, 6):
        total1 += j * jnp.cos((j + 1) * x1 + j)
        total2 += j * jnp.cos((j + 1) * x2 + j)
    return total1 * total2 / 100

plot_3d_projection(X, f_shubert, round_cutoff=None, solver_builder=make_solver, plot_history=True, n_iterations=20)
```
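The Python `for` loop in `f_shubert` is unrolled when JAX traces the function, which is harmless for five terms. An equivalent vectorized form broadcasts $j = 1, \dots, 5$ against both coordinates at once; the sketch below (`f_shubert_loop` copies the cell above, `f_shubert_vec` is a hypothetical alternative) checks that the two agree on random inputs.

```python
import jax
import jax.numpy as jnp

def f_shubert_loop(v):
    # loop version, identical to f_shubert above
    x1, x2 = v
    total1 = 0.0
    total2 = 0.0
    for j in range(1, 6):
        total1 += j * jnp.cos((j + 1) * x1 + j)
        total2 += j * jnp.cos((j + 1) * x2 + j)
    return total1 * total2 / 100

def f_shubert_vec(v):
    # same sums with array ops: row i holds the five terms for coordinate v[i]
    j = jnp.arange(1, 6, dtype=v.dtype)
    sums = jnp.sum(j * jnp.cos((j + 1) * v[:, None] + j), axis=1)  # shape (2,)
    return sums[0] * sums[1] / 100

pts = jax.random.uniform(jax.random.PRNGKey(2), (1000, 2), minval=-2.0, maxval=2.0)
loop_vals = jax.vmap(f_shubert_loop)(pts)
vec_vals = jax.vmap(f_shubert_vec)(pts)
print(bool(jnp.allclose(loop_vals, vec_vals, atol=1e-5)))  # True
```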


```python
# Rastrigin graph scaled down by a factor of 100
plot_3d_projection(X, lambda z: mfs.f_rastrigin(z)/100, round_cutoff=None, solver_builder=make_solver, plot_history=True, n_iterations=20)
```
