In [1]:
import jax.numpy as jnp
import numpy as np

from jnlr.utils.meshes import get_mesh
from jnlr.utils.plot_utils import plot_3d_projection
from jnlr.utils.plot_utils import plot_mesh_plotly
from jnlr.utils.samplers import langevin_implicit

# Scalar settings read by the field functions defined further down.
R = 1.0             # compared against x^2 + y^2 + z^2 in phi
stroke = 0.08       # subtracted from the nearest-segment distance in smooth_indicator
epsilon = 0.06      # divisor applied to (distance - stroke) in smooth_indicator
extrude_y = 0.20    # not referenced by the code visible in this cell
extrude_z = 0.20    # width handed to gate() inside s_R
z_center_R = -0.40  # z value at which the gate in s_R peaks
high = 0.5          # half-extent of the strokes along the second plane coordinate
width = 0.4         # left extent of the strokes along the first plane coordinate
offset = 0.2        # extra shift applied to the right-hand strokes

# Corner coordinates shared by the three stroke sets (2-D plane coordinates).
u_min, u_max = -width, width + offset
v_min, v_max = -high, high

# Each entry is a (start, end) pair of 2-D points describing one straight stroke.
# N: left vertical, right vertical, and the diagonal joining top-left to bottom-right.
N_segments = [
    (jnp.array([u_min, v_min]), jnp.array([u_min, v_max])),
    (jnp.array([u_max, v_max]), jnp.array([u_max, v_min])),
    (jnp.array([u_min, v_max]), jnp.array([u_max, v_min])),
]
# L: vertical stroke at u_max plus the bottom horizontal stroke.
L_segments = [
    (jnp.array([u_max, v_min]), jnp.array([u_max, v_max])),
    (jnp.array([u_min, v_min]), jnp.array([u_max, v_min])),
]
# R (the original comment calls it "rho"): left vertical, top bar up to u = 0,
# and a short vertical stroke at u = 0 running from v_max down to 0.2.
R_segments = [
    (jnp.array([u_min, v_min]), jnp.array([u_min, v_max])),
    (jnp.array([u_min, v_max]), jnp.array([0.0, v_max])),
    (jnp.array([0.0, v_max]), jnp.array([0.0, 0.2])),
]

def sd_segment(p, a, b):
    """Return the Euclidean distance from point(s) ``p`` to the segment ``a``-``b``.

    Args:
        p: A single point of shape ``(d,)`` or a batch of points of shape
            ``(..., d)``.
        a: Segment start, shape ``(d,)``.
        b: Segment end, shape ``(d,)``.

    Returns:
        The non-negative distance(s), with shape ``p.shape[:-1]`` (a scalar
        array for a single point).
    """
    pa = p - a
    ba = b - a
    # Reduce over the coordinate axis only and keep it as a size-1 axis, so the
    # projection parameter broadcasts against ``ba`` for batched ``p`` as well.
    ba_dot = jnp.sum(ba * ba, axis=-1, keepdims=True)
    # The 1e-8 keeps the division finite when ``a == b`` (zero-length segment);
    # the parameter is then 0 and the distance is measured to ``a``.
    t = jnp.sum(pa * ba, axis=-1, keepdims=True) / (ba_dot + 1e-8)
    # Clamping to [0, 1] restricts the projection to the segment itself.
    t_clamp = jnp.clip(t, 0.0, 1.0)
    proj = a + t_clamp * ba
    return jnp.linalg.norm(p - proj, axis=-1)

def smooth_indicator(p, segments):
    """Clipped, rescaled distance from ``p`` to the nearest stroke in ``segments``.

    Args:
        p: Point handed unchanged to ``sd_segment``.
        segments: Iterable of ``(start, end)`` point pairs.

    Returns:
        ``(nearest_distance - stroke) / epsilon`` clipped to ``[0, 1e6]``: zero
        while the point lies within ``stroke`` of some segment, then growing
        linearly with slope ``1 / epsilon``.
    """
    nearest = 1e6  # starting value; also the result basis for an empty ``segments``
    for dist in (sd_segment(p, start, end) for start, end in segments):
        nearest = jnp.minimum(nearest, dist)
    return jnp.clip((nearest - stroke) / epsilon, 0.0, 1e6)



def gate(t, width):
    """Triangular window: 1 at ``t == 0``, falling linearly to 0 at ``|t| == width``.

    Args:
        t: Offset from the window centre.
        width: Half-width of the window (this parameter shadows the
            module-level ``width`` inside the function).

    Returns:
        ``1 - |t| / width`` clipped to ``[0, 1]``.
    """
    ramp = 1.0 - jnp.abs(t) / width
    return jnp.clip(ramp, 0.0, 1.0)

def s_N(y, z, x):
    """Stroke term for ``N_segments`` evaluated at the 2-D point ``(y, z)``.

    Returns ``smooth_indicator((y, z), N_segments) + max(-x, 0)``; the second
    term is zero for ``x >= 0`` and grows linearly for negative ``x``.
    """
    point_yz = jnp.stack([y, z], axis=-1)
    negative_x_term = jnp.maximum(-x, 0)
    return smooth_indicator(point_yz, N_segments) + negative_x_term

def s_L(x, z, y):
    """Stroke term for ``L_segments`` evaluated at the 2-D point ``(x, z)``.

    Returns ``smooth_indicator((x, z), L_segments) + max(-y, 0)``; the second
    term is zero for ``y >= 0`` and grows linearly for negative ``y``.
    """
    point_xz = jnp.stack([x, z], axis=-1)
    negative_y_term = jnp.maximum(-y, 0)
    return smooth_indicator(point_xz, L_segments) + negative_y_term

def s_R(x, y, z):
    """Stroke term for ``R_segments`` at the 2-D point ``(x, y)``, gated along z.

    Returns ``smooth_indicator((x, y), R_segments)`` multiplied by a triangular
    gate that is centred on ``z_center_R`` and has half-width ``extrude_z``.
    """
    point_xy = jnp.stack([x, y], axis=-1)
    z_window = gate(z - z_center_R, extrude_z)
    return smooth_indicator(point_xy, R_segments) * z_window

def softmax(x, y):
    """Smooth approximation of ``max(x, y)``: ``log(exp(50*x) + exp(50*y)) / 50``.

    Evaluated with ``jnp.logaddexp`` so that large arguments do not overflow
    ``exp`` (a direct ``exp(50 * x)`` becomes ``inf`` in float32 once ``x``
    exceeds roughly 1.77).

    Args:
        x: First operand (scalar or array).
        y: Second operand (scalar or array, broadcast against ``x``).

    Returns:
        A value that is always >= ``max(x, y)`` and exceeds it by at most
        ``log(2) / 50``.
    """
    sharpness = 50.0  # larger values track the hard maximum more closely
    return jnp.logaddexp(sharpness * x, sharpness * y) / sharpness

def phi(xyz):
    """Scalar field handed to the sampler.

    Args:
        xyz: Length-3 point, unpacked as ``(x, y, z)``.

    Returns:
        ``sphere_shell + s_N * s_L + [s_N * s_L == 0]`` where ``sphere_shell``
        is the squared smooth maximum of ``R - 0.02 - r``, ``0`` and ``r - R``.
    """
    x, y, z = xyz
    # NOTE: ``r`` is the *squared* radius, yet it is compared against ``R``
    # directly below; the two readings coincide only when R == 1.
    r = x*x + y*y + z*z

    # Evaluate each stroke term once and reuse the product.
    # NOTE(review): s_R is not part of this field -- confirm that is intended.
    n_term = s_N(y, z, x)
    l_term = s_L(x, z, y)
    letters = n_term * l_term

    # Penalises points whose squared radius leaves the band [R - 0.02, R].
    sphere_shell = softmax(R-0.02-r , softmax(0, r-R))**2

    # The boolean term adds 1 wherever the product is exactly zero.
    return sphere_shell + letters + (letters == 0)
# Draw samples from the field `phi`.
# NOTE(review): the meaning of the sampler keywords is not visible in this
# notebook -- check `langevin_implicit` before tuning them. The `R=3.0` passed
# here is a literal and is distinct from the module-level `R = 1.0` used in phi.
Y_lang = langevin_implicit(
    phi,
    n_samples=10000,
    burn=100,
    thin=1,
    sigma=0.1,
    lam=0,
    kappa=0.3,
    R=3.0,
    tol=0.03,
)
# Alternative view of the samples, currently disabled:
#fig = plot_3d_projection(Y_lang)
#fig.update_layout(scene_camera=dict(eye=dict(x=0., y=0., z=2.5)))

# Reference mesh: the unit sphere parametrised by U[0] in [0, pi] and
# U[1] in [0, 2*pi] as (sin U0 cos U1, sin U0 sin U1, cos U0).
V_sphere, F_sphere = get_mesh(
    lambda U: jnp.array([
        jnp.sin(U[0])*jnp.cos(U[1]),
        jnp.sin(U[0])*jnp.sin(U[1]),
        jnp.cos(U[0]),
    ]),
    'explicit',
    nu=50,
    nv=50,
    grid_ranges=((0, np.pi), (0, 2*np.pi)),
)
# Build the sphere figure with the samples passed as `points`.
fig_sphere = plot_mesh_plotly(
    V_sphere,
    F_sphere,
    title="Sphere",
    opacity=0.1,
    points=Y_lang,
)


In [2]:
# Display the figure built in the first cell (sphere mesh with `Y_lang` passed as points).
fig_sphere.show()

In [4]:

# Rebuilds `fig_sphere` with exactly the same arguments as the first cell; the
# cell ends in an assignment, so nothing is displayed.
# NOTE(review): the execution counter jumps from In [2] to In [4] -- re-run the
# notebook from a fresh kernel and confirm whether this cell is still needed.
fig_sphere = plot_mesh_plotly(V_sphere, F_sphere, title="Sphere", opacity=0.1, points=Y_lang)
