### Example: Equidistant geometry


In [None]:
import os 
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
import jax 
jax.config.update("jax_enable_x64", True)
import jax_sbgeom
import jax.numpy as jnp
%load_ext autoreload
%autoreload 2
import numpy as onp
import matplotlib.pyplot as plt
import pyvista as pv
from jax_sbgeom.jax_utils.utils import _mesh_to_pyvista_mesh

##### Setting up geometry

Select a particular coil set and plasma equilibrium:

In [None]:
# Index into the equilibrium / coil-set lists below:
# 0 = helias3, 1 = helias4, 2 = helias5, 3 = squid.
stell_i = 2
# Directory holding the VMEC and coil files. Set the SBGEOM_BASE_DATA environment
# variable to point elsewhere instead of editing the hard-coded default path.
base_data_dir = os.environ.get("SBGEOM_BASE_DATA", "/home/tbogaarts/stellarator_paper/base_data/vmecs")
vmec_file = os.path.join(base_data_dir, ["helias3_vmec.nc4", "helias4_vmec.nc4", "helias5_vmec.nc4", "squid_vmec.nc4"][stell_i])
coil_file = os.path.join(base_data_dir, ["HELIAS3_coils_all.h5", "HELIAS4_coils_all.h5", "HELIAS5_coils_all.h5", "squid_coilset.h5"][stell_i])

We use a `FluxSurfaceNormalExtendedConstantPhi`, because it keeps $\phi_{in} = \phi_{out}$ even beyond the LCFS, as required by the FFTs. Furthermore, for constant $d$ it produces exactly the same surface as extending along the original normal vector.

In [None]:
from jax_sbgeom.flux_surfaces import FluxSurfaceNormalExtendedConstantPhi, ToroidalExtent, FluxSurfaceFourierExtended
from jax_sbgeom.coils         import CoilSet, DiscreteCoil, convert_to_fourier_coilset, RotationMinimizedFrame, FiniteSizeCoilSet
# Flux-surface representation loaded from the selected VMEC file (`vmec_file`).
# The constant-phi, normal-extended variant is used so that surfaces beyond the LCFS
# keep phi_in == phi_out (see the markdown note above this cell).
flux_surface = FluxSurfaceNormalExtendedConstantPhi.from_hdf5(vmec_file)
def _get_discrete_coils(coil_file):
    """Build a CoilSet from the coil positions stored in an HDF5 file.

    The ``'Dataset1'`` dataset is read fully into memory while the file is open;
    its leading axis is treated as the coil index and each entry is turned into a
    ``DiscreteCoil`` via ``DiscreteCoil.from_positions``.
    """
    import h5py
    with h5py.File(coil_file, 'r') as h5_handle:
        all_positions = jnp.array(h5_handle['Dataset1'])
    discrete_coils = [DiscreteCoil.from_positions(coil_positions) for coil_positions in all_positions]
    return CoilSet.from_list(discrete_coils)
# Discrete coil filaments -> Fourier representation of each coil.
coilset = _get_discrete_coils(coil_file)
fourier_coilset = convert_to_fourier_coilset(coilset)

# Toroidal window [0, pi / nfp], i.e. half of one field period (2 * pi / nfp).
# filter_coilset_phi presumably keeps the coils lying inside this window.
half_module_phi_end = jnp.pi / flux_surface.settings.nfp
half_module_coils = jax_sbgeom.coils.coilset.filter_coilset_phi(fourier_coilset, 0.0, half_module_phi_end)

# Finite-size coils with a rotation-minimized frame, then a surface mesh for plotting.
# NOTE(review): 100 / 500 / 0.2 / 0.2 are resolution and size parameters whose exact
# meaning is defined in jax_sbgeom.coils — confirm before tuning.
fs_coilset = FiniteSizeCoilSet.from_coilset(half_module_coils, RotationMinimizedFrame, 100)
mesh_coils = jax_sbgeom.coils.mesh_coilset_surface(fs_coilset, 500, 0.2, 0.2)

In [None]:
# Distance of the first layer boundary from the LCFS.
fw_distance         = 0.2
n_theta_blanket     = 55
n_phi_blanket       = 65
# Number of radial elements per region. Index 0 is the inner region (up to the first
# layer boundary, d = fw_distance); index k > 0 is the blanket layer of thickness
# thicknesses[k - 1]. Hence there is one more entry than in `thicknesses`.
resolutions_blanket = [10, 1, 1, 6, 4, 3, 4, 3]
thicknesses         = [0.002, 0.025, 0.5, 0.385, 0.06, 0.2, 0.06]  # Thickness of each blanket layer
assert len(resolutions_blanket) == len(thicknesses) + 1, (
    "resolutions_blanket needs one entry for the inner region plus one per blanket layer"
)
# Layer-boundary distances from the LCFS: [fw, fw + t0, fw + t0 + t1, ...].
layers_jax          = jnp.concatenate([jnp.array([0.0]), jnp.cumsum(jnp.array(thicknesses))]) + fw_distance

In [None]:
import jax_sbgeom.interfaces.blanket_creation as bc

In [None]:
# Continuous layered blanket, defined only by its layer-boundary distances.
layer_boundaries = tuple(layers_jax)
blanket = bc.LayeredBlanket(layer_boundaries)


In [None]:
# Vectorise the Fourier-extension fit over the layer distance (argument 1); all other
# arguments are shared across the batch. Arguments 2-4 are declared static for jit
# (the single-surface call in spacing_uniform passes them as n_theta, n_phi and
# n_theta_s_arclength).
_extension_fit = jax_sbgeom.flux_surfaces.convert_to_vmec.create_fourier_surface_extension_interp_equal_arclength
_extension_fit_batched = jax.vmap(_extension_fit, in_axes=(None, 0, None, None, None))
f = jax.jit(_extension_fit_batched, static_argnums=(2, 3, 4))

In [None]:
# One Fourier extension surface per layer-boundary distance in blanket.d_layers.
# The fit resolution is 6x the equilibrium's poloidal / toroidal mode numbers.
n_theta_fit = flux_surface.settings.mpol * 6
n_phi_fit = flux_surface.settings.ntor * 6
n_arclength_fit = 100
Rmncstack, Zmnsstack, mpolstack, ntorstack, nfpstack = f(
    flux_surface, jnp.array(blanket.d_layers), n_theta_fit, n_phi_fit, n_arclength_fit
)

In [None]:
# Turn the stacked coefficients into one Fourier surface object. mpol/ntor/nfp follow
# from the static fit arguments, so they are presumably identical across the batch;
# only the first entry of each stack is used.
vmec_conversion = jax_sbgeom.flux_surfaces.convert_to_vmec
fourier_surface_stack = vmec_conversion._create_fluxsurface_from_rmnc_zmns(
    Rmncstack, Zmnsstack, mpolstack[0], ntorstack[0], nfpstack[0]
)

# Combined surface: the original flux surface together with the Fourier extension.
fs_total = FluxSurfaceFourierExtended.from_flux_surface_and_extension(flux_surface, fourier_surface_stack)

In [None]:
# Discretised blanket over the toroidal extent returned by ToroidalExtent.half_module
# (presumably half of one field period).
half_module_extent = ToroidalExtent.half_module(flux_surface)
discrete_blanket = bc.LayeredDiscreteBlanket(
    tuple(layers_jax),
    n_theta_blanket,
    n_phi_blanket,
    tuple(resolutions_blanket),
    half_module_extent,
)

In [None]:
# Tetrahedral mesh of the discrete blanket on the combined (flux surface + extension) surface.
# NOTE(review): the meaning of the trailing `2` is not visible here — confirm against
# bc.mesh_tetrahedral_blanket before changing it.
tetrahedral_mesh = bc.mesh_tetrahedral_blanket(
    fs_total,
    discrete_blanket,
    2,
)

In [None]:

# Per-cell physical layer index for the tetrahedral blanket mesh (one entry per cell of
# tetrahedral_mesh; it is attached as cell data when plotting).
total_array_layers = jnp.zeros(tetrahedral_mesh[1].shape[0], dtype=int)

# Physical layer k is subdivided into resolutions_blanket[k] discrete radial layers, so
# discrete layer j belongs to physical layer actual_layers[j].
actual_layers = [layer for layer, n_sub in enumerate(resolutions_blanket) for _ in range(n_sub)]
# A mismatch would either raise an opaque IndexError below or silently leave cells at 0.
assert discrete_blanket.n_discrete_layers == len(actual_layers), (
    f"discrete blanket has {discrete_blanket.n_discrete_layers} radial layers, "
    f"but resolutions_blanket describes {len(actual_layers)}"
)

for i in range(discrete_blanket.n_discrete_layers):
    total_array_layers = total_array_layers.at[discrete_blanket.layer_slice(i)].set(actual_layers[i])

In [None]:
# plt / pv are already imported at the top of the notebook; only the extra modules here.
import matplotlib.colors as mcolors
import numpy as np

# Wrap the tetrahedral blanket mesh for PyVista and attach the per-cell layer index.
# Convert the jax array explicitly so PyVista receives a plain numpy array.
pv_mesh = _mesh_to_pyvista_mesh(*tetrahedral_mesh)
layer_ids = np.asarray(total_array_layers)
pv_mesh.cell_data["layer"] = layer_ids

# One colour and one label per physical layer (index 0 .. 7).
colors = [
    "#D1E5F0",  # PL    – light blue
    "#92C5DE",  # TA    – medium blue
    "#2166AC",  # FW    – deep blue
    "#FBAE17",  # BZ    – warm golden
    "#B2182B",  # SS    – rich red
    "#5AB4AC",  # VVF   – teal
    "#B8E186",  # SH    – light green
    "#D6604D",  # VVB   – coral
]
layer_labels = ['PL', 'TA', 'FW', 'BZ', 'SS', 'VVF', 'SH', 'VVB']
assert len(colors) == len(layer_labels)

# Discrete colormap: with clim = [0, n_layers] each layer occupies a unit-wide band [i, i + 1].
cmap = mcolors.ListedColormap(colors)
n_layers = int(layer_ids.max()) + 1

plotter = pv.Plotter(window_size=[1800, 1080])
plotter.add_mesh(
    pv_mesh,
    scalars='layer',
    show_edges=True,
    edge_color="black",
    ambient=0.0,
    cmap=cmap,
    edge_opacity=0.5,
    categories=True,
    clim=[0, n_layers],
    # Place each label at the centre of its colour band.
    annotations={i + 0.5: label for i, label in enumerate(layer_labels)},
    scalar_bar_args={
        'n_labels': 0,  # no numeric tick labels; only the annotations are shown
        'title': "",
        'label_font_size': 36,
    },
)
plotter.add_mesh(
    _mesh_to_pyvista_mesh(*mesh_coils),
    color='white',
    show_edges=False,
)

plotter.camera.position = (7.001912929110858, -17.898356078421788, -9.091389900931919)
plotter.camera.focal_point = (15.091625430162669, -2.19314226528047, -2.879522983120314)
plotter.camera.up = (0.20579334869669091, 0.26643695903679715, -0.9416264888433344)

plotter.show(auto_close=False)  # Render and keep the plotter open

### Reproducing the exact geometry

The paper does not use 8 distinct layer surfaces. Instead, it only transforms the outermost layer into a Fourier surface and places the inner layers as fractions of the distance between the LCFS and that outer layer.

In [None]:
def _uniform_s_layers(blanket, s_power_sampling):
    """Radial sample positions ``s`` used by :func:`spacing_uniform`.

    * ``s`` in [0, 1] (presumably the region inside the LCFS): ``resolution_layers[0]``
      points ``linspace(0, 1) ** s_power_sampling``.
    * ``s`` in (1, 2] (extension region): layer ``i`` spans
      ``1 + d_i / d_max`` .. ``1 + d_{i+1} / d_max`` with ``resolution_layers[i + 1]``
      equally spaced points (right end excluded).
    * A final point ``s = 2`` closes the outermost layer (``d_max / d_max = 1``).
    """
    max_distance = blanket.d_layers[-1]
    s_fraction = jnp.array(blanket.d_layers) / max_distance
    inner_blanket_spacing = jnp.linspace(0.0, 1.0, blanket.resolution_layers[0]) ** s_power_sampling
    per_layer_spacing = [
        jnp.linspace(
            1.0 + s_fraction[i],
            1.0 + s_fraction[i + 1],
            blanket.resolution_layers[i + 1],
            endpoint=False,
        )
        for i in range(blanket.n_layers - 1)
    ]
    # Flat concatenation: also valid when there is only one layer (no per-layer pieces).
    return jnp.concatenate([inner_blanket_spacing, *per_layer_spacing, jnp.array([2.0])])


def spacing_uniform(blanket, s_power_sampling, base_surface=None,
                    fourier_resolution_factor=6, n_theta_s_arclength=100):
    """Tetrahedral mesh with a single Fourier extension at the outermost layer.

    Only the outermost layer distance ``blanket.d_layers[-1]`` is converted into a
    Fourier surface. All intermediate layers are placed at fractions of that distance
    (see :func:`_uniform_s_layers`).

    Parameters
    ----------
    blanket : discrete layered blanket
        Provides ``d_layers``, ``resolution_layers``, ``n_layers``, ``n_theta``,
        ``n_phi`` and ``toroidal_extent``.
    s_power_sampling : float
        Exponent applied to the uniform ``[0, 1]`` spacing of the inner region.
    base_surface : optional
        Flux surface to extend. Defaults to the notebook-level ``flux_surface``
        (the original behaviour).
    fourier_resolution_factor : int
        Fit resolution as a multiple of the base surface's ``mpol`` / ``ntor``.
    n_theta_s_arclength : int
        Passed through to the equal-arclength Fourier-extension fit.

    Returns
    -------
    Whatever ``flux_surface_meshing._mesh_tetrahedra`` returns (unpacked into
    ``_mesh_to_pyvista_mesh`` by the plotting cell).
    """
    if base_surface is None:
        base_surface = flux_surface  # notebook-level surface loaded at the top
    vmec_conversion = jax_sbgeom.flux_surfaces.convert_to_vmec

    max_distance = blanket.d_layers[-1]
    s_layers = _uniform_s_layers(blanket, s_power_sampling)

    Rmnc, Zmns, mpol, ntor, nfp = vmec_conversion.create_fourier_surface_extension_interp_equal_arclength(
        base_surface,
        max_distance,
        n_theta=fourier_resolution_factor * base_surface.settings.mpol,
        n_phi=fourier_resolution_factor * base_surface.settings.ntor,
        n_theta_s_arclength=n_theta_s_arclength,
    )
    fourier_surface = vmec_conversion._create_fluxsurface_from_rmnc_zmns(Rmnc, Zmns, mpol, ntor, nfp)
    total_surface = jax_sbgeom.flux_surfaces.FluxSurfaceFourierExtended.from_flux_surface_and_extension(
        base_surface, fourier_surface
    )

    extent = blanket.toroidal_extent
    return jax_sbgeom.flux_surfaces.flux_surface_meshing._mesh_tetrahedra(
        total_surface,
        s_layers,
        True,
        extent.start,
        extent.end,
        bool(extent.full_angle()),
        blanket.n_theta,
        blanket.n_phi,
    )
    

# Quadratic spacing of the inner radial points (denser towards s = 0).
mesh_new = spacing_uniform(discrete_blanket, s_power_sampling=2.0)

In [None]:
# plt / pv are already imported at the top of the notebook; only the extra modules here.
import matplotlib.colors as mcolors
import numpy as np

# Wrap the paper-style mesh for PyVista. The layer indices are reused from the
# discrete-blanket mesh, which is only valid if mesh_new has the same number of cells
# in the same order. The count is checked here; the ordering is assumed (TODO confirm).
pv_mesh = _mesh_to_pyvista_mesh(*mesh_new)
layer_ids = np.asarray(total_array_layers)
assert pv_mesh.n_cells == layer_ids.shape[0], (
    f"mesh_new has {pv_mesh.n_cells} cells but total_array_layers has "
    f"{layer_ids.shape[0]} entries; layer labels cannot be reused"
)
pv_mesh.cell_data["layer"] = layer_ids

# One colour and one label per physical layer (index 0 .. 7).
colors = [
    "#D1E5F0",  # PL    – light blue
    "#92C5DE",  # TA    – medium blue
    "#2166AC",  # FW    – deep blue
    "#FBAE17",  # BZ    – warm golden
    "#B2182B",  # SS    – rich red
    "#5AB4AC",  # VVF   – teal
    "#B8E186",  # SH    – light green
    "#D6604D",  # VVB   – coral
]
layer_labels = ['PL', 'TA', 'FW', 'BZ', 'SS', 'VVF', 'SH', 'VVB']
assert len(colors) == len(layer_labels)

# Discrete colormap: with clim = [0, n_layers] each layer occupies a unit-wide band [i, i + 1].
cmap = mcolors.ListedColormap(colors)
n_layers = int(layer_ids.max()) + 1

plotter = pv.Plotter(window_size=[1800, 1080])
plotter.add_mesh(
    pv_mesh,
    scalars='layer',
    show_edges=True,
    edge_color="black",
    ambient=0.0,
    cmap=cmap,
    edge_opacity=0.5,
    categories=True,
    clim=[0, n_layers],
    # Place each label at the centre of its colour band.
    annotations={i + 0.5: label for i, label in enumerate(layer_labels)},
    scalar_bar_args={
        'n_labels': 0,  # no numeric tick labels; only the annotations are shown
        'title': "",
        'label_font_size': 36,
    },
)
plotter.add_mesh(
    _mesh_to_pyvista_mesh(*mesh_coils),
    color='white',
    show_edges=False,
)

plotter.camera.position = (7.001912929110858, -17.898356078421788, -9.091389900931919)
plotter.camera.focal_point = (15.091625430162669, -2.19314226528047, -2.879522983120314)
plotter.camera.up = (0.20579334869669091, 0.26643695903679715, -0.9416264888433344)

plotter.show(auto_close=False)  # Render and keep the plotter open