In [None]:
# import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax_sbgeom as jsb
%load_ext autoreload
%autoreload 2
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
import numpy as onp
import sys 
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(project_root)

import StellBlanket.SBGeom as SBGeom
from functools import partial
import jax_sbgeom.flux_surfaces.flux_surface_meshing as fsm
import jax_sbgeom.flux_surfaces.flux_surfaces_base as fsb
import jax_sbgeom.flux_surfaces.flux_surfaces_extended as fse
from jax_sbgeom.jax_utils.utils import _mesh_to_pyvista_mesh
import pyvista as pv
from dataclasses import dataclass
import h5py


# Directory holding the VMEC equilibria and coil sets. Set the VMEC_DATA_DIR
# environment variable to point elsewhere instead of editing absolute paths.
VMEC_DATA_DIR = os.environ.get(
    "VMEC_DATA_DIR", "/home/tbogaarts/stellarator_paper/base_data/vmecs"
)

vmec_files = [
    os.path.join(VMEC_DATA_DIR, name)
    for name in ("helias3_vmec.nc4", "helias5_vmec.nc4", "squid_vmec.nc4")
]
coil_files = [
    os.path.join(VMEC_DATA_DIR, name)
    for name in ("HELIAS3_coils_all.h5", "HELIAS5_coils_all.h5", "squid_coilset.h5")
]

# Index into vmec_files / coil_files: 0 = HELIAS 3, 1 = HELIAS 5, 2 = SQuID.
stell_i = 2
vmec_file = vmec_files[stell_i]
coil_file = coil_files[stell_i]

In [None]:
def _get_flux_surfaces(vmec_file):
    """
    Load one VMEC equilibrium into both geometry back-ends.

    Returns a ``(jax_sbgeom surface, SBGeom surface)`` pair built from the same
    HDF5 file, so the two implementations can be compared side by side.
    """
    jax_surface = jsb.flux_surfaces.FluxSurface.from_hdf5(vmec_file)
    reference_surface = SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file)
    return jax_surface, reference_surface

def _get_extended_flux_surfaces(vmec_file):
    """
    Load the "normal extended" flux-surface variant in both back-ends.

    Returns ``(jax_sbgeom FluxSurfaceNormalExtended, SBGeom Flux_Surfaces_Normal_Extended)``;
    the SBGeom object wraps a plain VMEC surface read from the same file.
    """
    jax_surface = jsb.flux_surfaces.FluxSurfaceNormalExtended.from_hdf5(vmec_file)
    base_reference = SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file)
    return jax_surface, SBGeom.Flux_Surfaces_Normal_Extended(base_reference)

def _get_extended_no_phi_flux_surfaces(vmec_file):
    """
    Load the "normal extended, no phi" flux-surface variant in both back-ends.

    Returns ``(jax_sbgeom FluxSurfaceNormalExtendedNoPhi,
    SBGeom Flux_Surfaces_Normal_Extended_No_Phi)``; the SBGeom object wraps a
    plain VMEC surface read from the same file.
    """
    jax_surface = jsb.flux_surfaces.FluxSurfaceNormalExtendedNoPhi.from_hdf5(vmec_file)
    base_reference = SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file)
    return jax_surface, SBGeom.Flux_Surfaces_Normal_Extended_No_Phi(base_reference)

def _get_extended_constant_phi_flux_surfaces(vmec_file):
    """
    Load the "normal extended, constant phi" flux-surface variant in both back-ends.

    Returns ``(jax_sbgeom FluxSurfaceNormalExtendedConstantPhi,
    SBGeom Flux_Surfaces_Normal_Extended_Constant_Phi)``; the SBGeom object wraps
    a plain VMEC surface read from the same file.
    """
    jax_surface = jsb.flux_surfaces.FluxSurfaceNormalExtendedConstantPhi.from_hdf5(vmec_file)
    base_reference = SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file)
    return jax_surface, SBGeom.Flux_Surfaces_Normal_Extended_Constant_Phi(base_reference)

def _get_discrete_coils(coil_file):
    """
    Read discrete coil filaments from an HDF5 file into a jax_sbgeom CoilSet.

    All positions come from the dataset ``'Dataset1'``. Its leading axis indexes the
    coils, and each slice is passed to ``DiscreteCoil.from_positions``
    (presumably an (n_points, 3) polyline per coil — confirm against the file).
    """
    with h5py.File(coil_file, 'r') as h5_file:
        # Materialise the data while the file is still open.
        coil_positions = jnp.array(h5_file['Dataset1'])

    n_coils = coil_positions.shape[0]
    coils = [
        jsb.coils.DiscreteCoil.from_positions(coil_positions[k])
        for k in range(n_coils)
    ]
    return jsb.coils.CoilSet.from_list(coils)

# Load the selected equilibrium (jax_sbgeom + SBGeom reference) and its coil set.
# The cells below use the "normal extended, no phi" variant; swap in another
# _get_*_flux_surfaces helper from the cell above to compare variants.
fs_jax, fs_sbgeom = _get_extended_no_phi_flux_surfaces(vmec_file)
coilset_jax       = _get_discrete_coils(coil_file)

In [None]:
# Meshing parameters.
# NOTE(review): d_total, n_theta and n_phi are not referenced by any visible cell.
d_total, n_theta, n_phi = 1.0, 53, 60

# Fit a coil winding surface through the discrete coils. The first output is
# later indexed as sarr[0] and passed to _create_total_s; the second is the coil
# set that is subsequently sampled with position_different_s.
sarr, coilset = jsb.coils.coil_winding_surface.optimize_coil_surface(coilset_jax)

In [None]:
# Sample every coil at the parameter values produced by the winding-surface fit.
positions_coilset = coilset.position_different_s(
    jsb.coils.coil_winding_surface._create_total_s(sarr[0], coilset.n_coils)
)

# Swap the first two axes -> ntheta, nphi [number of coils], 3 (the ordering the
# flux-surface meshing helper is called with below).
positions_standard_ordering = jnp.moveaxis(positions_coilset, 0, 1)

# Triangle connectivity for the (ntheta, nphi) grid.
vertices = fsm._mesh_surface_connectivity(
    positions_standard_ordering.shape[0],
    positions_standard_ordering.shape[1],
    True,
    True,
)

mesh_pv_cws = _mesh_to_pyvista_mesh(positions_standard_ordering.reshape(-1, 3), vertices)
mesh_cws_def = (positions_standard_ordering.reshape(-1, 3), vertices)

In [None]:
@jax.jit
def ray_triangle_intersect_single(origin, direction, triangles, eps=1e-8):
    """
    Intersect one ray with a batch of triangles (Moller-Trumbore, vectorised over triangles).

    Parameters
    ----------
    origin:    (3,) array - ray origin
    direction: (3,) array - ray direction; it is not normalised here, so ``t`` is
               measured in multiples of ``direction``
    triangles: (T, 3, 3) array - triangle vertices, ``triangles[k] = [v0, v1, v2]``
    eps: float - tolerance used twice: triangles with ``|det| <= eps`` (ray nearly
         parallel to the triangle plane) are rejected, and so are hits with ``t <= eps``

    Returns
    -------
    t:        (T,) hit parameter along the ray, ``jnp.inf`` where there is no hit
    u, v:     (T,) barycentric coordinates of the hit point (only meaningful where ``mask``)
    mask:     (T,) boolean array, True where the ray hits the triangle
    """
    vert0, vert1, vert2 = (triangles[:, k, :] for k in range(3))
    edge1, edge2 = vert1 - vert0, vert2 - vert0

    # det = edge1 . (direction x edge2); close to 0 means the ray is parallel to the plane.
    p_vec = jnp.cross(direction, edge2)
    det = jnp.einsum('ij,ij->i', edge1, p_vec)
    non_degenerate = jnp.abs(det) > eps
    # Use 0 instead of 1/det for rejected triangles; they are masked out below anyway.
    inv_det = jnp.where(non_degenerate, 1.0 / det, 0.0)

    s_vec = origin - vert0
    u = jnp.einsum('ij,ij->i', s_vec, p_vec) * inv_det

    q_vec = jnp.cross(s_vec, edge1)
    v = jnp.einsum('j,ij->i', direction, q_vec) * inv_det
    t = jnp.einsum('ij,ij->i', edge2, q_vec) * inv_det

    inside_triangle = (u >= 0.0) & (v >= 0.0) & ((u + v) <= 1.0)
    mask = non_degenerate & inside_triangle & (t > eps)

    return jnp.where(mask, t, jnp.inf), u, v, mask

@jax.jit
def ray_triangle_intersect_single_fori(origin, direction, triangles, eps=1e-8):
    """
    Intersect one ray with many triangles; same contract as ``ray_triangle_intersect_single``.

    Despite the ``_fori`` suffix this does not use ``jax.lax.fori_loop``. It forwards to
    ``ray_triangle_intersect_single`` so the two entry points cannot diverge.
    TODO: implement a genuine fori_loop variant or remove this alias.

    Parameters
    ----------
    origin:    (3,) array - ray origin
    direction: (3,) array - ray direction (``t`` is in multiples of its length)
    triangles: (T, 3, 3) array - triangles [v0, v1, v2]
    eps: float - numerical tolerance for the determinant and for ``t``

    Returns
    -------
    t:        (T,) distance along ray (jnp.inf if no hit)
    u, v:     (T,) barycentric coordinates
    mask:     (T,) boolean array, True if hit
    """
    return ray_triangle_intersect_single(origin, direction, triangles, eps)

# Ray grid: one ray per (theta, phi) sample, phi covering a single field period.
n_theta_rt, n_phi_rt = 100, 200
theta_rt = jnp.linspace(0, 2 * jnp.pi, n_theta_rt, endpoint=False)
phi_rt = jnp.linspace(0, 2 * jnp.pi / fs_jax.settings.nfp, n_phi_rt, endpoint=False)
theta_mg, phi_mg = jnp.meshgrid(theta_rt, phi_rt, indexing='ij')  # (n_theta_rt, n_phi_rt)

# Rays start on the s = 1 surface and point towards the s = 2 extended surface.
# The directions are not normalised, so hit parameters t are in units of their length.
positions_origins = fs_jax.cartesian_position(1.0, theta_mg, phi_mg)
directions = fs_jax.cartesian_position(2.0, theta_mg, phi_mg) - positions_origins

# Explicit vertex coordinates of every coil-winding-surface triangle
# (expected (T, 3, 3) by ray_triangle_intersect_single).
total_triangles = positions_standard_ordering.reshape(-1, 3)[vertices]

# Warm-up call on a single ray so the jitted kernel is compiled before batched use.
t, u, v, mask = ray_triangle_intersect_single(
    positions_origins[0, 0], directions[0, 0], total_triangles
)

In [None]:
from jax_sbgeom.jax_utils import raytracing as RT

In [None]:
# Scratch check of jnp.all semantics: evaluates to False because one element is False.
jnp.array([False, True]).all()

In [None]:
# Build the linear BVH over the coil-winding-surface mesh (positions, triangle
# connectivity). Keep the result so later cells can inspect or reuse it; the first
# call may also include JIT compilation time.
lbvh = RT.build_lbvh(*mesh_cws_def)
lbvh

In [None]:

def aabb_to_lines(aabbs):
    """
    Convert axis-aligned bounding boxes into one PyVista line mesh of box edges.

    Parameters
    ----------
    aabbs : array_like, shape (N, 2, 3) or (2, 3)
        ``aabbs[i, 0]`` is the minimum corner and ``aabbs[i, 1]`` the maximum corner
        of box ``i``. A single box of shape (2, 3) is also accepted.

    Returns
    -------
    pv.PolyData
        ``8 * N`` points (the box corners) and ``12 * N`` two-point line cells
        (the box edges). Built fully vectorised, without a Python loop over boxes.

    Raises
    ------
    ValueError
        If the trailing dimensions of ``aabbs`` are not (2, 3).
    """
    boxes = onp.asarray(aabbs, dtype=float)
    if boxes.shape[-2:] != (2, 3):
        raise ValueError(f"expected AABBs of shape (N, 2, 3) or (2, 3), got {boxes.shape}")
    boxes = boxes.reshape(-1, 2, 3)
    n_boxes = boxes.shape[0]
    box_min = boxes[:, 0, :]  # (N, 3)
    box_max = boxes[:, 1, :]  # (N, 3)

    # The 8 corners of the unit cube; bottom face (z=0) first, then top face (z=1).
    unit_corners = onp.array([
        [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],
        [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1],
    ])  # (8, 3)

    # Scale/translate the unit cube into every box: (N, 8, 3) -> (8N, 3).
    points = box_min[:, None, :] + (box_max - box_min)[:, None, :] * unit_corners[None, :, :]
    points = points.reshape(-1, 3)

    # The 12 edges of one cube as corner-index pairs: bottom ring, top ring, verticals.
    cube_edges = onp.array([
        [0, 1], [1, 2], [2, 3], [3, 0],
        [4, 5], [5, 6], [6, 7], [7, 4],
        [0, 4], [1, 5], [2, 6], [3, 7],
    ])  # (12, 2)

    # Shift each box's edge indices by the offset of its first corner: (12N, 2).
    corner_offsets = onp.arange(n_boxes) * 8
    all_edges = (cube_edges[None, :, :] + corner_offsets[:, None, None]).reshape(-1, 2)

    # VTK line cells are encoded as [n_points, i0, i1, n_points, i0, i1, ...].
    connectivity = onp.column_stack([onp.full(all_edges.shape[0], 2), all_edges]).ravel()

    return pv.PolyData(points, lines=connectivity)

In [None]:
# Per-triangle axis-aligned bounding boxes of the coil winding surface, shape (T, 2, 3):
# [:, 0] = minimum corner, [:, 1] = maximum corner. Defined explicitly because no
# visible cell assigns `aabb`, so this cell would otherwise fail on a fresh kernel.
# TODO: swap in the BVH node boxes from RT.build_lbvh if those are what should be shown.
aabb = jnp.stack([total_triangles.min(axis=1), total_triangles.max(axis=1)], axis=1)

plotter = pv.Plotter()
plotter.add_mesh(mesh_pv_cws, cmap='turbo', opacity=0.5)
plotter.add_mesh(aabb_to_lines(aabb), color='red', line_width=1)
plotter.show()

In [None]:
def build_bvh(positions : jnp.ndarray, connectivity : jnp.ndarray):
    """
    Placeholder for a BVH builder over a triangle mesh (not implemented yet).

    Parameters
    ----------
    positions : jnp.ndarray
        Mesh vertex positions (presumably (N, 3) — confirm).
    connectivity : jnp.ndarray
        Triangle vertex indices into ``positions`` (presumably (T, 3) — confirm).

    Raises
    ------
    NotImplementedError
        Always; the builder has not been written.
    """
    # TODO: implement, or reuse RT.build_lbvh(positions, connectivity), which this
    # notebook already calls on a (positions, connectivity) pair.
    raise NotImplementedError("build_bvh is not implemented yet")

In [None]:
def ray_box_intersect(ray_o, ray_d, box_min, box_max):
    """
    Slab test of a ray against axis-aligned bounding box(es).

    Parameters
    ----------
    ray_o, ray_d : (3,) arrays
        Ray origin and direction (direction need not be normalised).
    box_min, box_max : (..., 3) arrays
        Lower / upper box corners.

    Returns
    -------
    hit : (...,) bool
        True where the ray (restricted to t >= 0) overlaps the box.
    tmin : (...,) entry parameter, clamped to >= 0.
    tmax : (...,) exit parameter.
    """
    # A zero direction component would give 0 * inf = nan in the slab formula when the
    # origin lies exactly on that slab's plane, turning a touching ray into a miss.
    # Treat such axes explicitly: the slab either contains the origin (no constraint
    # on t) or excludes it (empty interval).
    parallel = ray_d == 0
    inv_d = 1.0 / jnp.where(parallel, 1.0, ray_d)
    t1 = (box_min - ray_o) * inv_d
    t2 = (box_max - ray_o) * inv_d
    inside_slab = (ray_o >= box_min) & (ray_o <= box_max)
    t_near = jnp.where(parallel, jnp.where(inside_slab, -jnp.inf, jnp.inf), jnp.minimum(t1, t2))
    t_far = jnp.where(parallel, jnp.where(inside_slab, jnp.inf, -jnp.inf), jnp.maximum(t1, t2))
    tmin = jnp.maximum(t_near.max(axis=-1), 0.0)
    tmax = t_far.min(axis=-1)
    hit = tmax >= tmin
    return hit, tmin, tmax

def traverse_bvh(bvh, ray_o, ray_d, max_hits=64, max_stack=64, return_padded=False):
    """
    Depth-first, stack-based BVH traversal for a single ray (jax.lax.while_loop).

    Parameters
    ----------
    bvh : dict
        "nodes_min", "nodes_max": (n_nodes, 3) node box corners;
        "left", "right": (n_nodes,) child indices of internal nodes;
        "is_leaf": (n_nodes,) bool; "leaf_id": (n_nodes,) value recorded when a leaf
        box is hit; "root_idx": index of the root node.
    ray_o, ray_d : (3,) arrays
        Ray origin and direction.
    max_hits : int
        Capacity of the hit buffer; traversal stops once it is full.
    max_stack : int
        Capacity of the traversal stack. NOTE: overflow is not detected. JAX drops
        out-of-bounds scatter updates, so children pushed past the end are lost and the
        result is incomplete. Choose it comfortably larger than the tree depth.
    return_padded : bool
        False (default): return ``hits[:hit_count]``. This needs a concrete hit count,
        i.e. it does not work under ``jax.jit``. True: return ``(hits, hit_count)`` with
        ``hits`` padded by -1, which is traceable (mark ``return_padded``, ``max_hits``
        and ``max_stack`` static when jitting).

    Returns
    -------
    The ``leaf_id`` of every leaf whose box the ray overlaps, in traversal order (right
    child before left), or ``(hits, hit_count)`` when ``return_padded`` is True.
    """
    nodes_min = bvh["nodes_min"]
    nodes_max = bvh["nodes_max"]
    left = bvh["left"]
    right = bvh["right"]
    is_leaf = bvh["is_leaf"]
    leaf_id = bvh["leaf_id"]
    root = bvh["root_idx"]

    # Carry: (stack pointer, number of hits, node stack, hit buffer).
    stack0 = jnp.full((max_stack,), -1, dtype=jnp.int32).at[0].set(root)
    hits0 = jnp.full((max_hits,), -1, dtype=jnp.int32)
    state0 = (jnp.array(1, dtype=jnp.int32), jnp.array(0, dtype=jnp.int32), stack0, hits0)

    def cond_fun(state):
        sp, hit_count, _, _ = state
        return (sp > 0) & (hit_count < max_hits)

    # Branches receive the full state explicitly as the cond operand. Rebinding
    # captured names inside nested closures would raise UnboundLocalError.
    def leaf_branch(args):
        node_idx, sp, hit_count, stack, hits = args
        hits = hits.at[hit_count].set(leaf_id[node_idx])
        return sp, hit_count + 1, stack, hits

    def internal_branch(args):
        node_idx, sp, hit_count, stack, hits = args
        stack = stack.at[sp].set(left[node_idx]).at[sp + 1].set(right[node_idx])
        return sp + 2, hit_count, stack, hits

    def miss_branch(args):
        _, sp, hit_count, stack, hits = args
        return sp, hit_count, stack, hits

    def hit_branch(args):
        return jax.lax.cond(is_leaf[args[0]], leaf_branch, internal_branch, args)

    def body_fun(state):
        sp, hit_count, stack, hits = state
        sp = sp - 1
        node_idx = stack[sp]
        hit, _, _ = ray_box_intersect(ray_o, ray_d, nodes_min[node_idx], nodes_max[node_idx])
        return jax.lax.cond(hit, hit_branch, miss_branch, (node_idx, sp, hit_count, stack, hits))

    _, hit_count, _, hits = jax.lax.while_loop(cond_fun, body_fun, state0)
    if return_padded:
        return hits, hit_count
    return hits[: int(hit_count)]

In [None]:
@jax.jit
def ray_tracing_min(origin, direction, triangles):
    """
    Smallest hit parameter of one ray over all triangles (``jnp.inf`` if nothing is hit).

    ``origin``/``direction`` are (3,) arrays and ``triangles`` is (T, 3, 3), as for
    ``ray_triangle_intersect_single``.
    """
    hit_t = ray_triangle_intersect_single(origin, direction, triangles)[0]
    return hit_t.min()

In [None]:
# Lift the single-ray kernel to arbitrary leading batch dimensions of origins and
# directions. Positional argument 2 (the triangles) is excluded from vectorisation
# and shared by every ray.
ray_tracing_vectorized = jax.jit(
    jnp.vectorize(ray_tracing_min, excluded=(2,), signature="(3),(3)->()")
)

ray_tracing_vectorized(positions_origins, directions, total_triangles)

In [None]:
# Benchmark the jnp.vectorize-based batch. block_until_ready() waits for JAX's
# asynchronously dispatched computation, so the timing covers the actual work.
%timeit ray_tracing_vectorized(positions_origins,  directions, total_triangles).block_until_ready()

In [None]:
# Same single-ray kernel, batched with jax.vmap over flattened rays; the triangle
# array is broadcast to every ray (in_axes None).
vmap_rt = jax.jit(jax.vmap(ray_tracing_min, in_axes=(0, 0, None)))

# Minimum hit parameter per ray, reshaped back onto the ray-grid shape.
distance_function = vmap_rt(
    positions_origins.reshape(-1, 3),
    directions.reshape(-1, 3),
    total_triangles,
).reshape(positions_origins.shape[:-1])

In [None]:
import matplotlib.pyplot as plt

# Ray-traced hit parameter over the ray grid. Rows follow theta and columns follow
# phi (meshgrid indexing='ij'). Rays that miss every triangle carry t = inf, which
# pcolormesh treats as invalid and leaves blank.
fig, ax = plt.subplots()
quad = ax.pcolormesh(onp.asarray(distance_function))
fig.colorbar(quad, ax=ax, label="hit parameter t (units of |direction|)")
ax.set(
    xlabel="phi index",
    ylabel="theta index",
    title="Ray-traced distance from s = 1 to the coil winding surface",
)
plt.show()

In [None]:
# Rays run from s = 1 towards s = 2, so s = 1 + t is used as the radial coordinate
# of the ray-traced surface.
# NOTE(review): this assumes the extended surface is linear in s beyond s = 1, so that
# origin + t * direction == position(1 + t); confirm. Rays that hit nothing have t = inf
# and will likely give non-finite positions.
positions_new = fs_jax.cartesian_position(1.0 + distance_function, theta_mg, phi_mg)

# NOTE(review): presumably the third flag marks whether phi spans the full torus
# (False here because phi_rt covers one field period); confirm against the helper.
connectivity_rt = fsm._mesh_surface_connectivity(n_theta_rt, n_phi_rt, False, True)

In [None]:
plotter = pv.Plotter()

# Coil sample points that define the coil winding surface.
plotter.add_mesh(
    pv.PolyData(onp.array(positions_coilset.reshape(-1, 3))),
    color='red',
    point_size=5,
    render_points_as_spheres=True,
)

# Surface reconstructed from the ray-traced distances.
plotter.add_mesh(
    _mesh_to_pyvista_mesh(onp.array(positions_new.reshape(-1, 3)), connectivity_rt),
    color='lightblue',
    opacity=0.5,
    show_edges=True,
)
plotter.show()