# Example: optimizing coil winding surface and Fourier transformations

In [None]:
import os 
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
import jax 
jax.config.update("jax_enable_x64", True)
import jax_sbgeom
import jax.numpy as jnp
from jax_sbgeom.coils import CoilSet, DiscreteCoil, FiniteSizeCoilSet, RotationMinimizedFrame, mesh_coilset_surface
from jax_sbgeom.flux_surfaces import FluxSurfaceNormalExtendedNoPhi, mesh_surface, ToroidalExtent
from jax_sbgeom.flux_surfaces.flux_surfaces_utilities import generate_thickness_matrix
from jax_sbgeom.coils.coil_winding_surface import create_optimized_coil_winding_surface
%load_ext autoreload
%autoreload 2
from jax_sbgeom.jax_utils.utils import _mesh_to_pyvista_mesh

import matplotlib.pyplot as plt

##### Setting up geometry

Selecting a particular coil and plasma set:

In [None]:
# Configuration to load: 0 = HELIAS 3, 1 = HELIAS 5, 2 = SQuID (see file names below).
stell_i = 1
# Directory holding the VMEC equilibria and coil files. Defaults to the
# original author's location; set JAX_SBGEOM_BASE_DATA to use another one.
base_data_dir = os.environ.get("JAX_SBGEOM_BASE_DATA", "/home/tbogaarts/stellarator_paper/base_data/vmecs")
vmec_file = os.path.join(base_data_dir, ["helias3_vmec.nc4", "helias5_vmec.nc4", "squid_vmec.nc4"][stell_i])
coil_file = os.path.join(base_data_dir, ["HELIAS3_coils_all.h5", "HELIAS5_coils_all.h5", "squid_coilset.h5"][stell_i])

We use a `FluxSurfaceNormalExtendedNoPhi` because it has the property that $\phi_{in} = \phi_{out}$ even beyond the LCFS, as required by the FFTs used later. Furthermore, its extension lines are straight, which allows simple ray tracing.

In [None]:
from jax_sbgeom.flux_surfaces import FluxSurfaceNormalExtendedNoPhi
# Load the selected VMEC equilibrium as a flux surface that can also be evaluated beyond the LCFS.
flux_surface = FluxSurfaceNormalExtendedNoPhi.from_hdf5(vmec_file)


This flux surface can immediately be meshed to obtain the LCFS:

In [None]:
# Mesh the LCFS (s = 1.0) over the full torus, 0 to 2*pi.
# NOTE(review): 500, 500 are presumably the poloidal/toroidal resolutions — confirm against mesh_surface.
lcfs_mesh                     = mesh_surface(flux_surface, 1.0, ToroidalExtent(0, 2 * jnp.pi), 500, 500)

Using the convenience function `_mesh_to_pyvista_mesh`, all meshes can be plotted directly with PyVista (toggle this with the optional `plot_pyvista` flag).

In [None]:
# Set to False to skip the optional interactive PyVista plots.
plot_pyvista = True
import pyvista as pv
# Plotting is a side effect, so use a statement rather than a conditional expression.
if plot_pyvista:
    _mesh_to_pyvista_mesh(*lcfs_mesh).plot()

We load in some discrete coils, using the data in the HDF5.

In [None]:
def _get_discrete_coils(coil_file, dataset_name='Dataset1'):
    """Load a CoilSet of discrete (filament) coils from an HDF5 file.

    Parameters
    ----------
    coil_file : str
        Path to the HDF5 file holding the coil positions.
    dataset_name : str, optional
        Name of the dataset containing the coil positions (default 'Dataset1').

    Returns
    -------
    CoilSet
        One DiscreteCoil per entry along the first axis of the dataset.
    """
    import h5py
    with h5py.File(coil_file, 'r') as f:
        # Read the full dataset while the file is still open.
        # NOTE(review): presumably shaped (n_coils, n_points, 3) — confirm against the data file.
        coil_positions = jnp.asarray(f[dataset_name][()])
    # Iterating an array yields its rows along axis 0, i.e. one coil each.
    return CoilSet.from_list([DiscreteCoil.from_positions(positions) for positions in coil_positions])
# Filament coil set of the selected configuration; the coil winding surfaces below are built from it.
coilset      = _get_discrete_coils(coil_file)

Since these coils are filaments, we visualize them as a finite-size coil set built with a `RotationMinimizedFrame`. This frame requires an extra argument: the number of points at which the RMF is sampled. We use 100 here.

In [None]:
# Finite-size coils for visualisation; the RotationMinimizedFrame is sampled at 100 points.
finite_size_coilset           = FiniteSizeCoilSet.from_coilset(coilset, RotationMinimizedFrame, 100)

We then filter this coilset to a full module and mesh it, together with the LCFS:

In [None]:
# Restrict the finite-size coils to one module, using the toroidal range 0 to 2*pi/nfp
# (the exact inclusion rule lives in filter_coilset_phi), then mesh them.
# NOTE(review): 300, 0.2, 0.2 presumably set the samples along each coil and the
# cross-section dimensions — confirm against mesh_coilset_surface.
full_module_finitesizecoilset = jax_sbgeom.coils.coilset.filter_coilset_phi(finite_size_coilset, 0.0, 2 * jnp.pi / flux_surface.settings.nfp)
coilset_rmf_mesh              = mesh_coilset_surface(full_module_finitesizecoilset, 300, 0.2 ,0.2 )

In [None]:
if plot_pyvista:
    # One scene: the meshed coils of one module (cyan) and the full LCFS (light grey).
    plotter = pv.Plotter()
    for mesh_data, mesh_color in ((coilset_rmf_mesh, 'cyan'), (lcfs_mesh, 'lightgrey')):
        plotter.add_mesh(_mesh_to_pyvista_mesh(*mesh_data), color=mesh_color)
    plotter.show()

### Creating Coil Winding Surface


Given a coil set, we can draw a toroidal line connecting all the coils. To obtain a smooth surface, we interpolate this toroidal line with a periodic B-spline. Repeating this for `n_points_per_coil` points on each coil gives a surface that we can mesh:

In [None]:
from jax_sbgeom.coils.coil_winding_surface import create_coil_winding_surface

# Surface resolution: the plotting cells reshape the vertices to
# (n_points_per_coil, n_points_toroidal, 3).
n_points_per_coil = 100
n_points_toroidal = 400
# Coil winding surface from a fixed (non-optimized) sampling of each coil, interpolated with a spline.
cws_non_optimized = create_coil_winding_surface(coilset, n_points_per_coil, n_points_toroidal, 'spline')

We can plot this spline surface together with part of the coils and some toroidal lines (red):

In [None]:
from jax_sbgeom.jax_utils.utils import _vertices_to_pyvista_polyline
if plot_pyvista:
    plotter = pv.Plotter()
    plotter.add_mesh(_mesh_to_pyvista_mesh(*coilset_rmf_mesh), color='cyan')
    plotter.add_mesh(_mesh_to_pyvista_mesh(*cws_non_optimized), color='lightgreen', show_edges=True)
    # Vertex grid of the surface; a fixed index along axis 0 presumably gives one
    # toroidal line. Reshape once instead of on every loop iteration.
    cws_vertex_grid = cws_non_optimized[0].reshape(n_points_per_coil, n_points_toroidal, 3)
    n_lines = 10
    for i in range(n_lines):
        # Integer arithmetic picks evenly spaced rows without float round-off.
        row = i * n_points_per_coil // n_lines
        plotter.add_mesh(_vertices_to_pyvista_polyline(cws_vertex_grid[row]), color='red', line_width=5)
    plotter.show()

The red curves are extremely wavy, which leads to a very 'wobbly' coil winding surface. Next, we optimize the sampling on each coil to ensure both coverage of all coils and a less curvy line.

Technically, we optimize the sampling point $s$ for each coil independently for minimum distance to the next coil. We do have to apply some penalties to ensure that the points do not just cluster on the inboard side. 

Since we use JAX, everything is differentiable. The result is a simple optimization problem whose Jacobians are easy to compute.

In [None]:
# Same resolution as the non-optimized surface, but with the sampling point on each coil optimized (see the text above).
cws_optimized = create_optimized_coil_winding_surface(coilset, n_points_per_coil, n_points_toroidal, surface_type = "spline")

In [None]:
from jax_sbgeom.jax_utils.utils import _vertices_to_pyvista_polyline
if plot_pyvista:
    plotter = pv.Plotter()
    plotter.add_mesh(_mesh_to_pyvista_mesh(*coilset_rmf_mesh), color='cyan')
    plotter.add_mesh(_mesh_to_pyvista_mesh(*cws_optimized), color='lightgreen', show_edges=True)
    # Vertex grid of the surface; a fixed index along axis 0 presumably gives one
    # toroidal line. Reshape once instead of on every loop iteration.
    cws_vertex_grid = cws_optimized[0].reshape(n_points_per_coil, n_points_toroidal, 3)
    n_lines = 10
    for i in range(n_lines):
        # Integer arithmetic picks evenly spaced rows without float round-off.
        row = i * n_points_per_coil // n_lines
        plotter.add_mesh(_vertices_to_pyvista_polyline(cws_vertex_grid[row]), color='red', line_width=5)
    plotter.show()

This results in a much smoother surface.


### Generating Thickness matrix

We wish to create a VMEC representation of this surface. It is not directly possible to FFT the surface, since the toroidal parameter does not correspond to the standard $\phi$ coordinate.

Instead, we generate a thickness matrix that tells us at a particular $\theta,\phi$ point, which distance beyond the LCFS should be used to obtain a point on the coil winding surface.

We do this by ray-tracing points from the LCFS at a grid of $\theta,\phi$ and interpolating this grid later.

This procedure is shown in figure 3a of the SBGeom paper. To reproduce that figure, we truncate the coil winding surface, create the arrows from the LCFS, and pick a particular viewpoint:

In [None]:
def plot_figure_3a(show):
    """Render figure 3a of the SBGeom paper and return the screenshot.

    The scene contains the faces of the optimized coil winding surface for the
    second half of its toroidal samples (translucent light grey), the meshed
    coils of one module (cyan), the LCFS (light blue), and red arrows from a
    30 x 30 (theta, phi) grid on the LCFS towards s = 2.

    Parameters
    ----------
    show : bool
        If True, open the interactive window after the screenshot is taken;
        otherwise close the plotter without showing it.

    Returns
    -------
    The image returned by ``pv.Plotter.screenshot``.
    """
    # Faces reshaped to (coil sample, toroidal sample, 2 triangles, 3 indices);
    # keep the faces of the second half of the toroidal samples.
    face_grid = cws_optimized[1].reshape(n_points_per_coil, n_points_toroidal, 2, 3)
    cws_mesh_truncated = (cws_optimized[0], face_grid[:, n_points_toroidal // 2:, :].reshape(-1, 3))

    # Ray origins: theta over [0, 2*pi), phi over one field period (endpoint included).
    theta_grid, phi_grid = jnp.meshgrid(
        jnp.linspace(0, 2 * jnp.pi, 30, endpoint=False),
        jnp.linspace(0, 2 * jnp.pi / flux_surface.settings.nfp, 30),
        indexing='ij',
    )
    ray_origins = flux_surface.cartesian_position(1.0, theta_grid, phi_grid)
    # Vector from the LCFS point (s = 1) to the s = 2 point along the same extension line.
    ray_directions = flux_surface.cartesian_position(2.0, theta_grid, phi_grid) - ray_origins

    plotter = pv.Plotter(window_size=[900, int(900 * 0.75)])
    scene = (
        (cws_mesh_truncated, 'lightgrey', 0.4),
        (coilset_rmf_mesh, 'cyan', 1.0),
        (lcfs_mesh, 'lightblue', 1.0),
    )
    for mesh_data, mesh_color, mesh_opacity in scene:
        plotter.add_mesh(_mesh_to_pyvista_mesh(*mesh_data), color=mesh_color, opacity=mesh_opacity, show_edges=False)
    plotter.add_arrows(ray_origins, ray_directions, mag=2.0, color='red')

    # Fixed viewpoint used for the paper figure.
    plotter.camera.position = (7.920269193531761, -29.364895547441858, 13.260398034562193)
    plotter.camera.focal_point = (17.252318958192408, 7.859902146409471, 1.9152626501079728)
    plotter.camera.up = (0.015661256564529463, 0.28790403430695255, 0.9575311963965463)

    # Grab the rendered frame before the window is shown or closed.
    screenshot = plotter.screenshot(return_img=True)
    if not show:
        plotter.close()
    else:
        plotter.show()
    return screenshot


img = plot_figure_3a(show=True)

In [None]:
# Display the figure-3a screenshot inline with matplotlib, with the axes hidden.
plt.imshow(img); 
_ = plt.gca().axis('off')

Then, using the built-in ray-tracing in SBGeom (although not as efficient as dedicated raytracing, it is sufficiently fast for our purposes and keeps dependencies light) we can calculate the distance matrices for both the non-optimized and optimized surfaces:

In [None]:
# Ray-trace distances beyond the LCFS to each coil winding surface on a (theta, phi) grid.
# NOTE(review): 200 and 205 are presumably the theta/phi resolutions — confirm against
# generate_thickness_matrix. `theta`/`phi` are overwritten by the second call; this assumes
# both calls return the same grid for the same resolution.
theta, phi, d_grid_opt     = generate_thickness_matrix(flux_surface, cws_optimized, 200, 205)
theta, phi, d_grid_non_opt = generate_thickness_matrix(flux_surface, cws_non_optimized, 200, 205)

This reproduces figure 3b of the SBGeom paper:

In [None]:
import numpy as onp

def plot_distance_matrix(d_grid):
    """Render a distance matrix d(theta, phi) to an RGBA pixel array.

    The figure is drawn on its canvas and then closed, so the returned pixels
    can be composed into other figures with ``imshow``. The axes use the
    module-level ``phi`` and ``theta`` grids returned by
    ``generate_thickness_matrix``, and ``flux_surface.settings.nfp`` for the
    phi tick positions.

    Parameters
    ----------
    d_grid : array
        Distance values to colour. NOTE(review): presumably laid out on the
        ``theta`` x ``phi`` grid — confirm against generate_thickness_matrix.

    Returns
    -------
    numpy.ndarray
        The rendered RGBA pixels of the figure.
    """
    panel_size = 3.2
    fig1, ax1 = plt.subplots(figsize=(panel_size, 0.75 * panel_size), dpi=300)
    quad_mesh = ax1.pcolormesh(phi, theta, d_grid, shading='auto')

    # theta ticks at multiples of pi/2; phi ticks at 0, half and one full field period.
    n_field_periods = flux_surface.settings.nfp
    ax1.set_yticks([0, onp.pi / 2, onp.pi, 3 * onp.pi / 2, 2 * onp.pi])
    ax1.set_yticklabels([r'$0$', r'$\frac{\pi}{2}$', r'$\pi$', r'$\frac{3\pi}{2}$', r'$2\pi$'])
    ax1.set_xticks([0, 2 * onp.pi / n_field_periods / 2, 2 * onp.pi / n_field_periods])
    ax1.set_xticklabels([r'$0$', r'$\frac{2 \pi}{%d}$' % (2 * n_field_periods), r'$\frac{2 \pi}{%d}$' % n_field_periods])
    ax1.set_xlabel(r'$\phi$ [rad]')
    ax1.set_ylabel(r'$\theta$ [rad]')
    fig1.colorbar(quad_mesh, label=r'Distance [m]')
    fig1.tight_layout()

    # Drawing fills the canvas buffer that buffer_rgba() exposes.
    fig1.canvas.draw()
    pixels = onp.array(fig1.canvas.buffer_rgba())
    plt.close(fig1)
    return pixels

img_non_opt = plot_distance_matrix(d_grid_non_opt)
img_opt = plot_distance_matrix(d_grid_opt)

# Show both rendered distance matrices side by side, with the axes hidden.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
panel_specs = zip(
    axs,
    ("Non-optimized d($\\theta,\\phi$)", "Optimized d($\\theta,\\phi$)"),
    (img_non_opt, img_opt),
)
for panel_ax, panel_title, panel_img in panel_specs:
    panel_ax.set_title(panel_title)
    panel_ax.imshow(panel_img)
    panel_ax.axis('off')

The optimization clearly produces a much smoother distance matrix.

### Converting Distance Matrix to Fourier representation

This distance matrix can be used to convert the coil-winding surface to a VMEC representation. Optionally, use the equal arclength parametrisation to ensure balanced mesh sizes.

In [None]:
from jax_sbgeom.flux_surfaces.convert_to_vmec import  create_fourier_surface_extension_interp_equal_arclength, _create_fluxsurface_from_rmnc_zmns, create_fourier_surface_extension_interp
from jax_sbgeom.flux_surfaces import FluxSurfaceFourierExtended
# Fourier representation of the optimized CWS extension, built once with the
# equal-arclength parametrisation and once without.
# NOTE(review): the coefficient tuples are presumably (rmnc, zmns, ...) — they are
# unpacked straight into _create_fluxsurface_from_rmnc_zmns.
fourier_coefficients_equal_arclength = create_fourier_surface_extension_interp_equal_arclength(
    flux_surface, d_grid_opt, n_theta=100, n_phi=200, n_theta_s_arclength=400)
flux_surface_extension = _create_fluxsurface_from_rmnc_zmns(*fourier_coefficients_equal_arclength)

fourier_coefficients_non_equal_arclength = create_fourier_surface_extension_interp(
    flux_surface, d_grid_opt, n_theta=100, n_phi=200)
flux_surface_extension_non_equal_arclength = _create_fluxsurface_from_rmnc_zmns(*fourier_coefficients_non_equal_arclength)

# Combine the base flux surface with each extension into one surface object.
flux_surface_total = FluxSurfaceFourierExtended.from_flux_surface_and_extension(
    flux_surface, flux_surface_extension)
flux_surface_total_non_equal_arclength = FluxSurfaceFourierExtended.from_flux_surface_and_extension(
    flux_surface, flux_surface_extension_non_equal_arclength)

To reproduce figure 3c, we can use the following code:

In [None]:
from jax_sbgeom.flux_surfaces import mesh_surfaces_closed
# Closed blanket meshes over one field period for both Fourier extensions.
# NOTE(review): 1.1 and 2.0 are presumably the inner/outer s values and 60, 120, 10 the
# poloidal/toroidal/radial resolutions — confirm against mesh_surfaces_closed.
mesh_blanket_cws                     = mesh_surfaces_closed(flux_surface_total, 1.1, 2.0, ToroidalExtent(0,2*jnp.pi/flux_surface.settings.nfp),60,120, 10)
mesh_blanket_cws_non_equal_arclength = mesh_surfaces_closed(flux_surface_total_non_equal_arclength, 1.1, 2.0, ToroidalExtent(0,2*jnp.pi/flux_surface.settings.nfp),60,120, 10)

In [None]:
def plot_3c(show, mesh_blanket_cws):
    """Render figure 3c of the SBGeom paper for one blanket mesh.

    Draws ``mesh_blanket_cws`` (light green, with edges) together with the
    meshed coils of one module (cyan) and the LCFS (light blue), and returns
    the screenshot.

    Parameters
    ----------
    show : bool
        If True, open the interactive window after the screenshot is taken;
        otherwise close the plotter without showing it.
    mesh_blanket_cws : tuple
        Mesh arrays, unpacked into ``_mesh_to_pyvista_mesh``.

    Returns
    -------
    The image returned by ``pv.Plotter.screenshot``.
    """
    plotter = pv.Plotter(window_size=[900, int(900 * 0.75)])
    plotter.add_mesh(_mesh_to_pyvista_mesh(*mesh_blanket_cws), color="lightgreen", opacity=1.0, show_edges=True)
    for context_mesh, context_color in ((coilset_rmf_mesh, 'cyan'), (lcfs_mesh, 'lightblue')):
        plotter.add_mesh(_mesh_to_pyvista_mesh(*context_mesh), color=context_color, opacity=1.0, show_edges=False)

    # Same camera numbers as the figure-3a rendering, so the panels line up.
    plotter.camera.position = (7.920269193531761, -29.364895547441858, 13.260398034562193)
    plotter.camera.focal_point = (17.252318958192408, 7.859902146409471, 1.9152626501079728)
    plotter.camera.up = (0.015661256564529463, 0.28790403430695255, 0.9575311963965463)

    # Grab the rendered frame before the window is shown or closed.
    rendered = plotter.screenshot(return_img=True)
    if show:
        plotter.show()
    else:
        plotter.close()
    return rendered


img_cws_ea = plot_3c(show=plot_pyvista, mesh_blanket_cws=mesh_blanket_cws)
img_cws_non_ea = plot_3c(show=plot_pyvista, mesh_blanket_cws=mesh_blanket_cws_non_equal_arclength)

In [None]:
if plot_pyvista:
    # Side-by-side comparison of the two blanket renderings.
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (img_cws_ea, "Equal arclength"),
        (img_cws_non_ea, "Non equal arclength"),
    )
    for panel_ax, (panel_img, panel_title) in zip(axs, panels):
        panel_ax.imshow(panel_img)
        panel_ax.axis('off')
        panel_ax.set_title(panel_title)

The equal-arclength parametrisation produces significantly better, more evenly sized meshes.

To reproduce the full figures:

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Assemble the three-panel figure: (a) ray-tracing setup (figure 3a),
# (b) optimized distance matrix, (c) equal-arclength blanket mesh (figure 3c).
size = 3.2
fig, axes = plt.subplots(1, 3, figsize=(size * 3, size * 0.75))

labels = ['(a)', '(b)', '(c)']
panel_images = [img, img_opt, img_cws_ea]

for panel_ax, panel_img, panel_label in zip(axes, panel_images, labels):
    panel_ax.imshow(panel_img)
    panel_ax.axis('off')
    # Panel label in the lower-left corner, positioned in axes coordinates.
    panel_ax.text(0.01, 0.01, panel_label, transform=panel_ax.transAxes,
                  fontsize=12, va='bottom', ha='left',
                  color='black')
plt.tight_layout()

plt.show()

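Finally, we plot the poloidal arc-length profiles of the surface twice: once as given directly by the base flux surface extended with the distance matrix `d_grid_opt`, and once for the equal-arclength Fourier representation.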

In [None]:

def plot_arclengths_ad(ax, flux_surfaces, n_theta, phi_values, s):
    """Plot the theta arc-length quantity of a surface against theta.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    flux_surfaces :
        Surface passed to ``_arc_length_theta_interpolating_s_grid_full_mod``.
    n_theta : int
        Number of theta samples in [0, 2*pi), endpoint excluded.
    phi_values : array-like
        Toroidal angles at which to evaluate. NOTE(review): presumably one
        plotted curve per angle — confirm the helper's output shape.
    s : scalar or array
        Radial coordinate, either a single value or a (theta, phi) grid such as
        a distance matrix; promoted to at least 2-D before use.
    """
    from jax_sbgeom.flux_surfaces.flux_surfaces_base import _arc_length_theta_interpolating_s_grid_full_mod
    theta_samples = jnp.linspace(0, 2 * jnp.pi, n_theta, endpoint=False)                          # [n_theta]
    theta_grid, phi_grid = jnp.meshgrid(theta_samples, jnp.array(phi_values), indexing='ij')      # [n_theta, n_phi]
    s_grid = jnp.atleast_2d(s)
    ax.plot(theta_samples, _arc_length_theta_interpolating_s_grid_full_mod(flux_surfaces, s_grid, theta_grid, phi_grid))
# Compare arc-length profiles at 10 toroidal angles within one field period:
# base surface extended by the distance matrix (left) vs. the equal-arclength
# Fourier extension evaluated at s = 2.0 (right).
phi_samples = jnp.linspace(0, 2 * jnp.pi / flux_surface.settings.nfp, 10)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

plot_arclengths_ad(ax[0], flux_surface, 200, phi_samples, d_grid_opt)
ax[0].set_xlabel("$\\theta$ [rad]")
ax[0].set_ylabel("$dr/d\\theta$")
ax[0].set_title("Base")

plot_arclengths_ad(ax[1], flux_surface_total, 200, phi_samples, 2.0)
# Label the right panel through its Axes explicitly rather than relying on it
# being pyplot's current axes; give it the same y-label as the left panel.
ax[1].set_xlabel("$\\theta$ [rad]")
ax[1].set_ylabel("$dr/d\\theta$")
ax[1].set_title("Equal Arc Length")
 
