In [None]:
# import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax_sbgeom as jsb
%load_ext autoreload
%autoreload 2
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
import numpy as onp
import sys 
import os

# Make the directory one level above this notebook importable. It is presumably
# required for the `StellBlanket.SBGeom` import that follows.
# The membership check keeps the cell idempotent: re-running it in the same
# kernel no longer stacks duplicate entries onto sys.path.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

import StellBlanket.SBGeom as SBGeom
from functools import partial
import jax_sbgeom.flux_surfaces.flux_surface_meshing as fsm
import jax_sbgeom.flux_surfaces.flux_surfaces_base as fsb
import jax_sbgeom.flux_surfaces.flux_surfaces_extended as fse
from jax_sbgeom.jax_utils.utils import _mesh_to_pyvista_mesh, _vertices_to_pyvista_polyline
import pyvista as pv
from dataclasses import dataclass
import h5py
import jax_sbgeom
import matplotlib.pyplot as plt


# Directory holding the VMEC equilibria and the matching coil sets.
# Set the STELL_BASE_DATA_DIR environment variable to point elsewhere
# instead of editing this absolute path.
base_data_dir = os.environ.get(
    "STELL_BASE_DATA_DIR",
    "/home/tbogaarts/stellarator_paper/base_data/vmecs",
)

# Entries are index-aligned: vmec_files[i] belongs to coil_files[i].
vmec_files = [
    os.path.join(base_data_dir, name)
    for name in ("helias3_vmec.nc4", "helias5_vmec.nc4", "squid_vmec.nc4")
]
coil_files = [
    os.path.join(base_data_dir, name)
    for name in ("HELIAS3_coils_all.h5", "HELIAS5_coils_all.h5", "squid_coilset.h5")
]
assert len(vmec_files) == len(coil_files), "each VMEC file needs a matching coil file"

stell_i = 2  # 0: HELIAS3, 1: HELIAS5, 2: squid
vmec_file = vmec_files[stell_i]
coil_file = coil_files[stell_i]

In [None]:
def _get_extended_no_phi_flux_surfaces(vmec_file):
    """Load the same VMEC equilibrium into both the JAX and the SBGeom implementation.

    Parameters
    ----------
    vmec_file : str
        Path to the VMEC file. The same path is handed to both loaders.

    Returns
    -------
    tuple
        ``(jax_surfaces, sbgeom_surfaces)``. The first element comes from
        ``jsb.flux_surfaces.FluxSurfaceNormalExtendedNoPhi.from_hdf5``. The
        second is an ``SBGeom.Flux_Surfaces_Normal_Extended_No_Phi`` wrapping
        ``SBGeom.VMEC.Flux_Surfaces_From_HDF5``.
    """
    jax_surfaces = jsb.flux_surfaces.FluxSurfaceNormalExtendedNoPhi.from_hdf5(vmec_file)

    sbgeom_vmec = SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file)
    sbgeom_surfaces = SBGeom.Flux_Surfaces_Normal_Extended_No_Phi(sbgeom_vmec)

    return jax_surfaces, sbgeom_surfaces


def _get_discrete_coils(coil_file):
    """Build a ``jsb.coils.CoilSet`` from the coil positions stored in an HDF5 file.

    The file's ``'Dataset1'`` dataset is read into a JAX array. Each entry
    along its leading axis becomes one ``DiscreteCoil`` via ``from_positions``.

    Parameters
    ----------
    coil_file : str
        Path to the HDF5 coil file.

    Returns
    -------
    jsb.coils.CoilSet
        One coil per entry of the dataset's first axis, in file order.
    """
    # Read everything while the file is open. The JAX array owns its data afterwards.
    with h5py.File(coil_file, 'r') as h5:
        all_positions = jnp.array(h5['Dataset1'])

    n_coils = all_positions.shape[0]
    coils = [
        jsb.coils.DiscreteCoil.from_positions(all_positions[k])
        for k in range(n_coils)
    ]
    return jsb.coils.CoilSet.from_list(coils)

# Coil set for the selected stellarator.
coilset_jax = _get_discrete_coils(coil_file)
# Flux surfaces of the selected equilibrium, loaded through both JAX and SBGeom.
fs_jax, fs_sbgeom = _get_extended_no_phi_flux_surfaces(vmec_file)

import jax_sbgeom.jax_utils.raytracing as RT

In [None]:
# Optimized coil winding surface around the coil set.
# NOTE(review): the positional arguments 100, True and 150 are passed through
# unexplained. Their meaning is defined by create_optimized_coil_winding_surface.
# TODO: pass them by keyword or name them.
cws_mesh = jax_sbgeom.coils.coil_winding_surface.create_optimized_coil_winding_surface(
    coilset_jax,
    100,
    True,
    150,
)

In [None]:
# Render the coil winding surface mesh in 3D.
cws_pv_mesh = _mesh_to_pyvista_mesh(*cws_mesh)

plotter = pv.Plotter()
plotter.add_mesh(cws_pv_mesh, color='red')
plotter.show()

In [None]:
# Ray origins and directions on an (theta, phi) grid covering one field period.
# Origins lie on the s = 1 surface. Each direction is the displacement from the
# s = 1 point to the s = 2 point of the extended surface at the same (theta, phi).
n_theta_rays, n_phi_rays = 210, 100

theta_1d = jnp.linspace(0, 2 * jnp.pi, n_theta_rays)
phi_1d = jnp.linspace(0, 2 * jnp.pi / fs_jax.settings.nfp, n_phi_rays)
theta, phi = jnp.meshgrid(theta_1d, phi_1d, indexing='ij')  # both (n_theta_rays, n_phi_rays)

positions_lcfs_mg = fs_jax.cartesian_position(1.0, theta, phi)
outer_positions_mg = fs_jax.cartesian_position(2.0, theta, phi)
directions_lcfs_mg = outer_positions_mg - positions_lcfs_mg

In [None]:
@jax.jit
def find_minimum_distance_from_bvh(points, directions, mesh):
    """Nearest ray/triangle intersection per ray, accelerated by an LBVH.

    Parameters
    ----------
    points, directions : array
        Ray origins and directions, passed unchanged to the RT routines.
        Presumably the trailing axis holds xyz (TODO confirm).
    mesh : sequence
        ``mesh[0]`` is the vertex array and ``mesh[1]`` the triangle-to-vertex
        index array.

    Returns
    -------
    array
        ``nanmin`` over the candidate-triangle axis of the values returned by
        ``RT.ray_triangle_intersection_vectorized``. Presumably NaN where no
        candidate is hit (TODO confirm against RT).
    """
    vertices = mesh[0]
    faces = mesh[1]

    bvh = RT.build_lbvh(vertices, faces)
    # Candidate positions into bvh.order for every ray.
    candidate_idx = RT.ray_traversal_bvh_vectorized(bvh, points, directions)

    candidate_faces = faces[bvh.order[candidate_idx]]
    # Move the candidate axis to the front so the intersection routine
    # broadcasts over it against points/directions.
    candidate_triangles = jnp.moveaxis(vertices[candidate_faces], -3, 0)

    hit_values = RT.ray_triangle_intersection_vectorized(points, directions, candidate_triangles)
    return jnp.nanmin(hit_values, axis=0)

In [None]:
# Nearest hit on the coil winding surface for each ray launched from the LCFS.
# Presumably NaN where no candidate triangle was hit (see nanmin in the helper).
dmesh = find_minimum_distance_from_bvh(positions_lcfs_mg, directions_lcfs_mg, cws_mesh)

# The explicit figure/axes interface keeps the colorbar attached to this plot.
fig, ax = plt.subplots()
dist_plot = ax.pcolormesh(phi, theta, dmesh)
fig.colorbar(dist_plot, ax=ax, label="minimum LCFS → CWS distance")
ax.set_xlabel(r"$\phi$ [rad]")
ax.set_ylabel(r"$\theta$ [rad]")
ax.set_title("Distance from LCFS to coil winding surface");