In [None]:
# import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax_sbgeom as jsb
%load_ext autoreload
%autoreload 2
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
import numpy as onp
import sys 
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(project_root)

import StellBlanket.SBGeom as SBGeom
from functools import partial
import jax_sbgeom.flux_surfaces.flux_surface_meshing as fsm
import jax_sbgeom.flux_surfaces.flux_surfaces_base as fsb
import jax_sbgeom.flux_surfaces.flux_surfaces_extended as fse
from jax_sbgeom.jax_utils.utils import _mesh_to_pyvista_mesh
import pyvista as pv
from dataclasses import dataclass
import h5py


# Input data: VMEC equilibria and the matching coil sets, one entry per device
# (helias3, helias5, squid). The directory defaults to the original author's
# location; set the STELLARATOR_BASE_DATA environment variable to use another one.
base_data_dir = os.environ.get(
    "STELLARATOR_BASE_DATA", "/home/tbogaarts/stellarator_paper/base_data/vmecs"
)

vmec_files = [
    os.path.join(base_data_dir, name)
    for name in ("helias3_vmec.nc4", "helias5_vmec.nc4", "squid_vmec.nc4")
]
coil_files = [
    os.path.join(base_data_dir, name)
    for name in ("HELIAS3_coils_all.h5", "HELIAS5_coils_all.h5", "squid_coilset.h5")
]

stell_i = 2  # index into vmec_files / coil_files: 0 = helias3, 1 = helias5, 2 = squid
vmec_file = vmec_files[stell_i]
coil_file = coil_files[stell_i]

In [None]:
def _load_flux_surface_pair(jax_cls, sbgeom_wrapper, vmec_file):
    """Load one VMEC file into both a jax_sbgeom class and SBGeom.

    Parameters
    ----------
    jax_cls        : jax_sbgeom flux-surface class providing ``from_hdf5``.
    sbgeom_wrapper : callable applied to the SBGeom VMEC flux surfaces
                     (e.g. a Normal_Extended wrapper), or ``None`` to return
                     the plain SBGeom VMEC object.
    vmec_file      : path to the VMEC file.

    Returns
    -------
    (fs_jax, fs_sbgeom)
    """
    fs_jax    = jax_cls.from_hdf5(vmec_file)
    fs_sbgeom = SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file)
    if sbgeom_wrapper is not None:
        fs_sbgeom = sbgeom_wrapper(fs_sbgeom)
    return fs_jax, fs_sbgeom


def _get_flux_surfaces(vmec_file):
    """Plain VMEC flux surfaces as (jax_sbgeom, SBGeom) pair."""
    return _load_flux_surface_pair(jsb.flux_surfaces.FluxSurface, None, vmec_file)


def _get_extended_flux_surfaces(vmec_file):
    """Normal-extended flux surfaces as (jax_sbgeom, SBGeom) pair."""
    return _load_flux_surface_pair(
        jsb.flux_surfaces.FluxSurfaceNormalExtended,
        SBGeom.Flux_Surfaces_Normal_Extended,
        vmec_file,
    )


def _get_extended_no_phi_flux_surfaces(vmec_file):
    """Normal-extended (no-phi variant) flux surfaces as (jax_sbgeom, SBGeom) pair."""
    return _load_flux_surface_pair(
        jsb.flux_surfaces.FluxSurfaceNormalExtendedNoPhi,
        SBGeom.Flux_Surfaces_Normal_Extended_No_Phi,
        vmec_file,
    )


def _get_extended_constant_phi_flux_surfaces(vmec_file):
    """Normal-extended (constant-phi variant) flux surfaces as (jax_sbgeom, SBGeom) pair."""
    return _load_flux_surface_pair(
        jsb.flux_surfaces.FluxSurfaceNormalExtendedConstantPhi,
        SBGeom.Flux_Surfaces_Normal_Extended_Constant_Phi,
        vmec_file,
    )

def _get_discrete_coils(coil_file):
    """Build a jax_sbgeom CoilSet of DiscreteCoils from an HDF5 coil file.

    All coil positions are read from the 'Dataset1' dataset; every entry along
    its first axis is turned into one coil via ``DiscreteCoil.from_positions``.
    """
    with h5py.File(coil_file, 'r') as h5:
        all_positions = jnp.array(h5['Dataset1'])

    discrete_coils = [
        jsb.coils.DiscreteCoil.from_positions(all_positions[k])
        for k in range(len(all_positions))
    ]
    return jsb.coils.CoilSet.from_list(discrete_coils)

# Flux-surface representation used for the ray origins/directions further down.
# To compare the other extension, use _get_extended_constant_phi_flux_surfaces(vmec_file) instead.
fs_jax, fs_sbgeom = _get_extended_no_phi_flux_surfaces(vmec_file)
coilset_jax       = _get_discrete_coils(coil_file)

In [None]:
# Resolution / thickness parameters.
# NOTE(review): d_total, n_theta and n_phi are not used by any cell visible here — confirm they are still needed.
n_theta, n_phi = 53, 60
d_total = 1.0

# Fit a coil winding surface to the discrete coil set; yields (sarr, coilset).
sarr, coilset = jsb.coils.coil_winding_surface.optimize_coil_surface(coilset_jax)

In [None]:
# Sample every coil at the parameters built from sarr[0], then triangulate the
# resulting coil winding surface.
positions_coilset = coilset.position_different_s(
    jsb.coils.coil_winding_surface._create_total_s(sarr[0], coilset.n_coils)
)

# Swap the first two axes: (n_coils, n_along_coil, 3) -> (n_along_coil, n_coils, 3),
# i.e. the (ntheta, nphi, 3) layout the meshing helpers expect.
positions_standard_ordering = jnp.swapaxes(positions_coilset, 0, 1)

# Despite its name, `vertices` is the triangle connectivity (vertex indices).
vertices = jsb.flux_surfaces.flux_surface_meshing._mesh_surface_connectivity(
    positions_standard_ordering.shape[0], positions_standard_ordering.shape[1], True, True
)
mesh_pv_cws = _mesh_to_pyvista_mesh(positions_standard_ordering.reshape(-1, 3), vertices)

In [None]:
@jax.jit
def ray_triangle_intersect_single(origin, direction, triangles, eps=1e-8):
    """
    Compute intersections of one ray with many triangles (Moller-Trumbore, JAX version).

    Parameters
    ----------
    origin:    (3,) float - ray origin
    direction: (3,) float - ray direction (need not be normalised; ``t`` is
               measured in units of ``direction``, hit = origin + t * direction)
    triangles: (T, 3, 3) float - triangles [v0, v1, v2]
    eps: float - tolerance on |det| (parallel-ray rejection) and on t
         (hits with t <= eps are rejected)

    Returns
    -------
    t:        (T,) ray parameter of the hit (jnp.inf if no hit)
    u, v:     (T,) barycentric coordinates (0 where the ray is parallel to the triangle)
    mask:     (T,) boolean array, True if hit
    """
    v0 = triangles[:, 0, :]
    v1 = triangles[:, 1, :]
    v2 = triangles[:, 2, :]

    e1 = v1 - v0
    e2 = v2 - v0

    pvec = jnp.cross(direction, e2)
    det  = jnp.einsum('ij,ij->i', e1, pvec)

    # Rays (nearly) parallel to the triangle plane have det ~ 0 and cannot hit.
    valid_det = jnp.abs(det) > eps
    # "Double where": divide by a safe placeholder so the discarded branch never
    # evaluates 1/0. Forward values are unchanged, but reverse-mode gradients no
    # longer turn into NaN (0 * inf) for parallel / degenerate triangles.
    safe_det = jnp.where(valid_det, det, 1.0)
    inv_det  = jnp.where(valid_det, 1.0 / safe_det, 0.0)

    tvec = origin - v0
    u = jnp.einsum('ij,ij->i', tvec, pvec) * inv_det

    qvec = jnp.cross(tvec, e1)
    v = jnp.einsum('j,ij->i', direction, qvec) * inv_det

    t = jnp.einsum('ij,ij->i', e2, qvec) * inv_det

    mask = (valid_det &
            (u >= 0.0) &
            (v >= 0.0) &
            ((u + v) <= 1.0) &
            (t > eps))

    t = jnp.where(mask, t, jnp.inf)
    return t, u, v, mask

@jax.jit
def ray_triangle_intersect_single_fori(origin, direction, triangles, eps=1e-8):
    """
    Compute intersections of one ray with many triangles, one triangle per
    iteration of ``jax.lax.fori_loop`` (Moller-Trumbore, JAX version).

    Same contract as the fully vectorised variant, but only per-triangle
    temporaries are created inside the loop instead of (T, 3) intermediates
    for all triangles at once.

    Parameters
    ----------
    origin:    (3,) float - ray origin
    direction: (3,) float - ray direction (hit = origin + t * direction)
    triangles: (T, 3, 3) float - triangles [v0, v1, v2]
    eps: float - tolerance on |det| (parallel-ray rejection) and on t
         (hits with t <= eps are rejected)

    Returns
    -------
    t:        (T,) ray parameter of the hit (jnp.inf if no hit)
    u, v:     (T,) barycentric coordinates (0 where the ray is parallel to the triangle)
    mask:     (T,) boolean array, True if hit
    """
    n_triangles = triangles.shape[0]
    dtype = jnp.result_type(origin, direction, triangles)

    init = (
        jnp.full((n_triangles,), jnp.inf, dtype=dtype),   # t
        jnp.zeros((n_triangles,), dtype=dtype),           # u
        jnp.zeros((n_triangles,), dtype=dtype),           # v
        jnp.zeros((n_triangles,), dtype=bool),            # mask
    )
    # The loop body indexes triangles[i]; tracing it on an empty axis would fail,
    # and there is nothing to compute anyway (shape is static under jit).
    if n_triangles == 0:
        return init

    def _body(i, carry):
        # Intersect the ray with triangle i and store the results at index i.
        t_all, u_all, v_all, mask_all = carry
        v0 = triangles[i, 0]
        e1 = triangles[i, 1] - v0
        e2 = triangles[i, 2] - v0

        pvec = jnp.cross(direction, e2)
        det  = jnp.dot(e1, pvec)

        valid_det = jnp.abs(det) > eps
        # "Double where" keeps gradients finite when det == 0.
        inv_det = jnp.where(valid_det, 1.0 / jnp.where(valid_det, det, 1.0), 0.0)

        tvec = origin - v0
        u = jnp.dot(tvec, pvec) * inv_det

        qvec = jnp.cross(tvec, e1)
        v = jnp.dot(direction, qvec) * inv_det
        t = jnp.dot(e2, qvec) * inv_det

        hit = valid_det & (u >= 0.0) & (v >= 0.0) & ((u + v) <= 1.0) & (t > eps)

        return (
            t_all.at[i].set(jnp.where(hit, t, jnp.inf)),
            u_all.at[i].set(u),
            v_all.at[i].set(v),
            mask_all.at[i].set(hit),
        )

    return jax.lax.fori_loop(0, n_triangles, _body, init)

# Ray grid: one ray per (theta, phi) sample; phi covers a single field period.
n_theta_rt = 100
n_phi_rt  = 200
theta_rt = jnp.linspace(0, 2 * jnp.pi, n_theta_rt, endpoint=False)
phi_rt   = jnp.linspace(0, 2 * jnp.pi / fs_jax.settings.nfp, n_phi_rt, endpoint=False)
theta_mg, phi_mg = jnp.meshgrid(theta_rt, phi_rt, indexing='ij')

# Rays start on the surface evaluated at 1.0 and point towards the one at 2.0
# (first argument of cartesian_position, presumably the radial label s).
positions_origins = fs_jax.cartesian_position(1.0, theta_mg, phi_mg)
directions        = fs_jax.cartesian_position(2.0, theta_mg, phi_mg) - positions_origins

# Gather vertex coordinates per triangle; ray_triangle_intersect_single expects (T, 3, 3).
total_triangles = positions_standard_ordering.reshape(-1, 3)[vertices]

# Warm-up call on a single ray so that ray_triangle_intersect_single is JIT-compiled.
t, u, v, mask = ray_triangle_intersect_single(
    positions_origins[0, 0], directions[0, 0], total_triangles
)

In [None]:
@jax.jit
def ray_tracing_min(origin, direction, triangles):
    """Ray parameter of the nearest triangle hit along a single ray.

    Returns ``jnp.inf`` when the ray misses every triangle.
    """
    hit_t = ray_triangle_intersect_single(origin, direction, triangles)[0]
    return hit_t.min()

In [None]:
# Broadcast ray_tracing_min over every (origin, direction) pair of the ray grid;
# positional argument 2 (the triangle array) is excluded and shared by all rays.
ray_tracing_vectorized = jax.jit(
    jnp.vectorize(ray_tracing_min, signature="(3),(3)->()", excluded=(2,))
)

ray_tracing_vectorized(positions_origins, directions, total_triangles)

In [None]:
# Benchmark the vectorized tracer over the full ray grid. block_until_ready() makes
# the timing include the device computation rather than just the async dispatch.
%timeit ray_tracing_vectorized(positions_origins,  directions, total_triangles).block_until_ready()

In [None]:
# Same per-ray computation as ray_tracing_vectorized, but via jax.vmap over a
# flattened list of rays; the triangle array is broadcast (in_axes=None).
vmap_rt = jax.jit(jax.vmap(ray_tracing_min, in_axes=(0, 0, None)))

# Nearest-hit ray parameter per ray, reshaped back onto the (n_theta_rt, n_phi_rt) grid.
# NOTE(review): vmapping over all rays likely materialises (n_rays, n_triangles)
# intermediates at once — watch memory for finer grids.
distance_function = vmap_rt(
    positions_origins.reshape(-1, 3),
    directions.reshape(-1, 3),
    total_triangles,
).reshape(positions_origins.shape[:-1])

In [None]:
import matplotlib.pyplot as plt

# Nearest-hit ray parameter on the (theta, phi) ray grid. Rays that miss every
# triangle carry jnp.inf; mask them so they do not break the colour normalisation.
distance_plot = onp.ma.masked_invalid(onp.asarray(distance_function))

fig, ax = plt.subplots()
quad = ax.pcolormesh(onp.asarray(phi_rt), onp.asarray(theta_rt), distance_plot, shading="auto")
fig.colorbar(quad, ax=ax, label="nearest-hit ray parameter t")
ax.set(xlabel=r"$\phi$ [rad]", ylabel=r"$\theta$ [rad]", title="Ray distance to coil winding surface");

In [None]:
# Evaluate the surface at 1 + t for every ray; for a normal-extended surface this
# presumably coincides with the ray's first hit on the coil winding surface.
# NOTE(review): rays that missed (distance_function == inf) yield non-finite positions.
positions_new = fs_jax.cartesian_position(distance_function + 1.0, theta_mg, phi_mg)

# Triangle connectivity for the (n_theta_rt, n_phi_rt) grid.
# NOTE(review): confirm which of the two boolean flags controls wrap-around in
# theta vs. phi (theta spans 2*pi, phi only one field period).
connectivity_rt = jsb.flux_surfaces.flux_surface_meshing._mesh_surface_connectivity(
    n_theta_rt, n_phi_rt, False, True
)

In [None]:
# 3D overlay: coil sample points (red spheres) and the ray-traced surface (light blue).
plotter = pv.Plotter()
plotter.add_mesh(
    pv.PolyData(onp.array(positions_coilset.reshape(-1, 3))),
    color='red',
    point_size=5,
    render_points_as_spheres=True,
)
# The triangulated coil winding surface (mesh_pv_cws) can also be added here for comparison.
plotter.add_mesh(
    _mesh_to_pyvista_mesh(onp.array(positions_new.reshape(-1, 3)), connectivity_rt),
    color='lightblue',
    opacity=0.5,
    show_edges=True,
)
plotter.show()