# Example: multilayered blanket

In [None]:
import os 
# Uncomment to make JAX use the CPU backend (only effective before jax is imported).
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
import jax 
# Enable 64-bit (double precision) arrays in JAX; set at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)
import jax_sbgeom
import jax.numpy as jnp
# Automatically reload edited modules (e.g. jax_sbgeom) before each cell is executed.
%load_ext autoreload
%autoreload 2
# Plain NumPy is imported as `onp` to keep it distinct from `jax.numpy` (`jnp`).
import numpy as onp
import matplotlib.pyplot as plt

##### Setting up geometry

Select a particular coil set and plasma equilibrium (VMEC file):

In [None]:
# Directory holding the VMEC equilibria and coil sets. Set the environment
# variable SBGEOM_BASE_DATA_DIR to point elsewhere instead of editing this cell.
BASE_DATA_DIR = os.environ.get(
    "SBGEOM_BASE_DATA_DIR",
    "/home/tbogaarts/stellarator_paper/base_data/vmecs",
)

# (VMEC equilibrium file, matching coil file) for each configuration. Keeping
# them as pairs guarantees the equilibrium and coil set always belong together.
STELLARATOR_FILES = [
    ("helias3_vmec.nc4", "HELIAS3_coils_all.h5"),  # 0
    ("helias4_vmec.nc4", "HELIAS4_coils_all.h5"),  # 1
    ("helias5_vmec.nc4", "HELIAS5_coils_all.h5"),  # 2
    ("squid_vmec.nc4",   "squid_coilset.h5"),      # 3
]

stell_i = 2
vmec_file, coil_file = (os.path.join(BASE_DATA_DIR, name) for name in STELLARATOR_FILES[stell_i])

We use a `FluxSurfaceNormalExtendedNoPhi`, because it has the property that $\phi_{in} = \phi_{out}$ even beyond the LCFS (as required by the FFTs). Furthermore, its extension lines are straight, which allows simple ray tracing.

In [None]:
from jax_sbgeom.flux_surfaces import FluxSurfaceNormalExtendedNoPhi
from jax_sbgeom.coils         import CoilSet, DiscreteCoil, convert_to_fourier_coilset
# Flux-surface representation built from the selected VMEC file.
flux_surface = FluxSurfaceNormalExtendedNoPhi.from_hdf5(
    vmec_file
)
def _get_discrete_coils(coil_file, dataset_name="Dataset1"):
    """Load a set of discrete (point-sampled) coils from an HDF5 file.

    Parameters
    ----------
    coil_file : str or path-like
        Path to the HDF5 file containing the coil positions.
    dataset_name : str, optional
        Name of the dataset inside the file. Defaults to ``"Dataset1"``,
        the key used by the coil files in this example.

    Returns
    -------
    CoilSet
        One ``DiscreteCoil`` per entry along the first axis of the dataset.
    """
    import h5py  # local import: only this loader needs h5py

    with h5py.File(coil_file, 'r') as f:
        # Materialise the data while the file is still open.
        coil_data = jnp.array(f[dataset_name])
    # NOTE(review): each entry is assumed to be an (n_points, 3) array of
    # Cartesian coil positions — confirm against the file format.
    return CoilSet.from_list([DiscreteCoil.from_positions(coil_positions) for coil_positions in coil_data])
# Discrete (point-sampled) coils read from disk, then converted to a Fourier representation.
coilset = _get_discrete_coils(coil_file)
fourier_coilset = convert_to_fourier_coilset(coilset)

We first optimize the coil winding surface:

In [None]:
from jax_sbgeom.coils import create_optimized_coil_winding_surface
# Resolution used for the coil-winding-surface fit (meaning inferred from the
# argument names: samples per coil and toroidal points — confirm in the API docs).
n_points_per_coil = 100
n_points_toroidal = 500
cws_optimized = create_optimized_coil_winding_surface(
    coilset,
    n_points_per_coil,
    n_points_toroidal,
    surface_type="spline",
)

We then compute the coil normals from the closest point on the coil-winding-surface mesh. The built-in ray tracing is sufficiently fast here.

In [None]:
from jax_sbgeom.coils.coil_winding_surface import calculate_normals_from_closest_point_on_mesh
from jax_sbgeom.coils import FiniteSizeCoilSet, FiniteSizeCoil, RadialVectorFrame
# Presumably the number of samples taken along each coil for the closest-point lookup.
n_coil_samples = 100
positions, normals = calculate_normals_from_closest_point_on_mesh(
    fourier_coilset, cws_optimized, n_coil_samples
)
# Finite-size coils whose frame is defined by the normals obtained above.
radial_frame = RadialVectorFrame(normals)
finitesize_coilset = FiniteSizeCoilSet(FiniteSizeCoil(fourier_coilset.coils, radial_frame))

In [None]:
coil_i: int = 4  # index of the coil visualised in the comparison below

In [None]:
from jax_sbgeom.coils import FiniteSizeCoilSet, RadialVectorFrame, DiscreteCoil, FiniteSizeCoil, CentroidFrame, FrenetSerretFrame, mesh_coil_surface, RotationMinimizedFrame
import pyvista as pv

# Reference filament: sample the selected coil on the parameter grid [0, 1]
# (both endpoints included) and render it as a thin tube.
coil_param = jnp.linspace(0, 1, 200, endpoint=True)
closed_coil_points = finitesize_coilset[coil_i].position(coil_param)
spline = pv.Spline(closed_coil_points, n_points=200)
tube = spline.tube(radius=0.05)

# Mesh the same coil with four different frame choices.
# NOTE(review): (100, 0.2, 0.2) is presumably (samples along the coil, width,
# height) of the cross-section — confirm against mesh_coil_surface.
mesh_args = (100, 0.2, 0.2)
selected_coil = fourier_coilset[coil_i]
m1 = mesh_coil_surface(FiniteSizeCoil.from_coil(selected_coil, CentroidFrame), *mesh_args)
# The extra 100 passed for the rotation-minimized frame is presumably a sample count — TODO confirm.
m2 = mesh_coil_surface(FiniteSizeCoil.from_coil(selected_coil, RotationMinimizedFrame, 100), *mesh_args)
m3 = mesh_coil_surface(FiniteSizeCoil.from_coil(selected_coil, FrenetSerretFrame), *mesh_args)
# Frame taken from the coil-winding-surface normals (RadialVectorFrame).
m4 = mesh_coil_surface(finitesize_coilset[coil_i], *mesh_args)

In [None]:
def update_centre(mesh, dphi):
    """Translate ``mesh`` so that its centroid is rotated by ``dphi`` about the z-axis.

    The centroid (mean of ``mesh.points``) is moved to the same cylindrical
    radius and height, but at toroidal angle ``phi + dphi``. The mesh is only
    translated, not rotated, so its orientation is unchanged. This is used to
    place several meshes of the same coil side by side for comparison.

    Parameters
    ----------
    mesh : object with a ``points`` attribute
        Presumably a PyVista mesh; ``points`` is an (n, 3) array of Cartesian
        coordinates. It is modified in place.
    dphi : float
        Toroidal angle increment in radians.

    Returns
    -------
    The same ``mesh`` object with updated ``points`` (returned for chaining).
    """
    old_centre = onp.mean(mesh.points, axis=0)
    phi_new = onp.arctan2(old_centre[1], old_centre[0]) + dphi
    r_old = onp.hypot(old_centre[0], old_centre[1])  # cylindrical radius of the centroid
    new_centre = onp.array([r_old * onp.cos(phi_new), r_old * onp.sin(phi_new), old_centre[2]])
    mesh.points = mesh.points - old_centre + new_centre
    return mesh

import pyvista as pv

In [None]:
from jax_sbgeom.jax_utils.utils import _mesh_to_pyvista_mesh
plotter = pv.Plotter(window_size=[1600, 1080])

# Shared colour palette (one entry per layer). Only the first four entries are
# used in this plot, one per frame variant.
colors = [
    "#6CA4DD",  # FW
    "#EBC77F",  # BZ
    "#E28893",  # SS
    "#5AB4AC",  # VVF
    "#B8E186",  # SH
    "#D6604D",  # VVB
]

# Legend labels, in the same order as the meshes m1..m4.
labels = ["Centroid", "RMF", "Frenet", "Surface"]

# Each successive mesh is shifted toroidally by one more multiple of dphi so
# the variants do not overlap (dphi = 0.2*pi/nfp, i.e. a tenth of 2*pi/nfp,
# assuming nfp is the number of field periods).
symm = flux_surface.settings.nfp
dphi = 0.2 * onp.pi / symm

plotter.add_mesh(tube, label="Filament")
frame_meshes = (m1, m2, m3, m4)
for offset_index, (coil_mesh, mesh_color, mesh_label) in enumerate(zip(frame_meshes, colors, labels), start=1):
    plotter.add_mesh(
        update_centre(_mesh_to_pyvista_mesh(*coil_mesh), offset_index * dphi),
        show_edges=True,
        color=mesh_color,
        label=mesh_label,
    )

plotter.add_legend(bcolor='white', face='rectangle', size=(0.4, 0.2))

# Fixed camera so the rendered figure is reproducible.
plotter.camera.position = (16.80119584826707, 41.94186272376006, 15.204265522182741)
plotter.camera.focal_point = (1.8428396279537544, 24.173815586754184, 0.3639140087839392)
plotter.camera.up = (-0.38198234295398886, -0.38148223104015794, 0.8417605342803884)
plotter.show()
#plotter.screenshot(filename=f"fig_{fig_id}_coil_different_finite_size.png", window_size=[1600, 1080])
#plotter.close()  # Close t