In [None]:
import os 
#os.environ["JAX_PLATFORM_NAME"] = "cpu"
import jax 
jax.config.update("jax_enable_x64", True)
import jax_sbgeom
import jax.numpy as jnp
from jax_sbgeom.flux_surfaces import FluxSurfaceNormalExtendedConstantPhi, mesh_surface, ToroidalExtent, FluxSurface

%load_ext autoreload
%autoreload 2
from jax_sbgeom.jax_utils.utils import _mesh_to_pyvista_mesh

import matplotlib.pyplot as plt
from jax_sbgeom.flux_surfaces import mesh_surfaces_closed

In [None]:
# Index selecting which entry (0, 1 or 2) of the two lists below is used.
stell_i = 1

# NOTE(review): absolute, machine-specific paths -- adjust when running elsewhere.
vmec_file = [
    "/home/tbogaarts/stellarator_paper/base_data/vmecs/helias3_vmec.nc4",
    "/home/tbogaarts/stellarator_paper/base_data/vmecs/helias5_vmec.nc4",
    "/home/tbogaarts/stellarator_paper/base_data/vmecs/squid_vmec.nc4",
][stell_i]

# Coil file matching the selected entry (not referenced by the cells visible here).
coil_file = [
    "/home/tbogaarts/stellarator_paper/base_data/vmecs/HELIAS3_coils_all.h5",
    "/home/tbogaarts/stellarator_paper/base_data/vmecs/HELIAS5_coils_all.h5",
    "/home/tbogaarts/stellarator_paper/base_data/vmecs/squid_coilset.h5",
][stell_i]

In [None]:
# Alias bound to the FluxSurface class object itself (no instance is created here);
# it is not referenced by the cells visible in this notebook section.
base_flux_surface = FluxSurface
# Build the constant-phi normal-extended flux surface object from the HDF5 file at `vmec_file`.
flux_surfaces_cphi = FluxSurfaceNormalExtendedConstantPhi.from_hdf5(vmec_file)

In [None]:
# Mesh resolution shared by the two closed meshes compared in this notebook
# (presumably the number of samples in theta and phi -- confirm against mesh_surfaces_closed).
n_theta, n_phi = 30, 40

# Extension parameter: used as the upper bound `1 + d_lcfs` below and passed on to the
# equal-arc-length conversion later in the notebook.
d_lcfs = 1.4

# Closed mesh built directly on the unconverted surface, over one full module.
# NOTE(review): the trailing positional `3` is passed through unchanged; its meaning is
# defined by mesh_surfaces_closed -- confirm there.
non_convert_extent = ToroidalExtent.full_module(flux_surfaces_cphi)
mesh_non_convert = mesh_surfaces_closed(
    flux_surfaces_cphi,
    1.2,
    1 + d_lcfs,
    non_convert_extent,
    n_theta,
    n_phi,
    3,
)

# Fine (500 x 500) surface mesh at s = 1.0 over the full toroidal extent; drawn as the
# semi-transparent reference surface in the 3D comparison plot.
lcfs = mesh_surface(flux_surfaces_cphi, 1.0, ToroidalExtent.full(), 500, 500)

In [None]:
from jax_sbgeom.flux_surfaces.convert_to_vmec import create_extended_flux_surface_d_interp_equal_arclength
from jax_sbgeom.flux_surfaces import FluxSurfaceFourierExtended
# Convert the constant-phi extended surface with the equal-arc-length helper,
# using the same d_lcfs as the direct mesh above.
flux_surface_total = create_extended_flux_surface_d_interp_equal_arclength(
    flux_surfaces_cphi,
    d_lcfs,
    n_theta=200,
    n_phi=200,
    n_theta_s_arclength=300,
)

In [None]:
# Closed mesh on the converted surface over one full module, at the same
# n_theta / n_phi resolution as the direct mesh. The bounds passed here are 1.2 and 2.
convert_extent = ToroidalExtent.full_module(flux_surfaces_cphi)
mesh_convert = mesh_surfaces_closed(
    flux_surface_total,
    1.2,
    2,
    convert_extent,
    n_theta,
    n_phi,
    3,
)

In [None]:
import pyvista as pv

# Camera settings applied identically to both panels.
camera_position = (12.117945733232254, -8.607424191419783, -14.037527528499304)
camera_focal_point = (21.510495462644545, 0.8983836673386539, -1.955428883477867)
camera_up = (0.41585256935017817, 0.5291767723432965, -0.7396205677083261)

# (closed mesh, colour) per panel: left = direct mesh, right = converted mesh.
panels = [
    (mesh_non_convert, "#FC8882"),
    (mesh_convert, "#66DB7C"),
]

plotter = pv.Plotter(shape = (1,2))
for panel_column, (panel_mesh, panel_color) in enumerate(panels):
    plotter.subplot(0, panel_column)
    # Semi-transparent reference surface, then the closed mesh with visible edges.
    plotter.add_mesh(_mesh_to_pyvista_mesh(*lcfs), color='white', opacity = 0.7)
    plotter.add_mesh(_mesh_to_pyvista_mesh(*panel_mesh), color=panel_color, show_edges=True, line_width = 1)
    plotter.camera.position = camera_position
    plotter.camera.focal_point = camera_focal_point
    plotter.camera.up = camera_up

plotter.link_views()
plotter.show()



# Plotting arc length

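Below we plot the arc length along $\theta$ for both cases: the base surface and the equal-arc-length conversion. First, we estimate it with finite differences, i.e. the distance between adjacent points sampled in $\theta$. Afterwards, we compute it with the automatically differentiable method. Both give the same result up to a scaling factor (the finite-difference values are approximately the derivative multiplied by the sample spacing $\Delta\theta = 2\pi / n_\theta$).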

In [None]:
def plot_arclengths_fd(ax, flux_surfaces, n_theta, phi_values, s):
    """Plot finite-difference point spacing along theta for several phi values.

    Samples ``n_theta`` equally spaced theta values in [0, 2*pi) (endpoint
    excluded), evaluates ``flux_surfaces.cartesian_position(s, theta, phi)`` on
    the theta x phi grid, and plots against theta the Euclidean distance between
    every sample and its predecessor along the theta axis. The difference wraps
    around, so the first sample is compared with the last one.

    Parameters
    ----------
    ax : object with a ``plot(x, y)`` method (e.g. a matplotlib Axes).
    flux_surfaces : object providing ``cartesian_position(s, theta, phi)``;
        assumed to return an array whose last axis holds the Cartesian components.
    n_theta : number of theta samples.
    phi_values : sequence of phi values; one curve is plotted per value.
    s : forwarded unchanged as the first argument of ``cartesian_position``.

    Returns
    -------
    None
    """
    theta_values = jnp.linspace(0, 2*jnp.pi, n_theta, endpoint=False)  # [n_theta]
    theta_grid, phi_grid = jnp.meshgrid(
        theta_values, jnp.array(phi_values), indexing='ij'
    )  # both [n_theta, n_phi]

    points = flux_surfaces.cartesian_position(s, theta_grid, phi_grid)

    # Rolling by one along the theta axis pairs each point with its predecessor
    # (index 0 is paired with the last index).
    previous_points = jnp.roll(points, shift=1, axis=0)
    segment_lengths = jnp.linalg.norm(previous_points - points, axis=-1)

    ax.plot(theta_values, segment_lengths)

fig, ax = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: base surface.
ax[0].set_ylim(0.0, 0.35)
plot_arclengths_fd(ax[0], flux_surfaces_cphi, 200, [0.0, jnp.pi/2, jnp.pi, 3*jnp.pi/2], 2.4)
ax[0].set_xlabel("$\\theta$ [rad]")
ax[0].set_ylabel("$\\Delta h$ [m]")
ax[0].set_title("Base")

# Right panel: equal-arc-length surface, same y-range for direct comparison.
# The labels are set on ax[1] explicitly; ax[1] is the Axes that the pyplot
# state machine treats as current after plt.subplots(1, 2).
plot_arclengths_fd(ax[1], flux_surface_total, 200, [0.0, jnp.pi/2, jnp.pi, 3*jnp.pi/2], 2.0)
ax[1].set_ylim(0.0, 0.35)
ax[1].set_xlabel("$\\theta$ [rad]")
ax[1].set_ylabel("$\\Delta h$ [m]")
ax[1].set_title("Equal Arc Length")
 

In [None]:
# 2x2 array with every entry equal to d_lcfs + 1.0.
# NOTE(review): not referenced by the cells visible here -- confirm it is still needed.
d_grid = 1.0 + d_lcfs * jnp.ones((2, 2))
def plot_arclengths_ad(ax, flux_surfaces, n_theta, phi_values, s):
    """Plot the library's theta arc-length quantity for several phi values.

    Samples ``n_theta`` equally spaced theta values in [0, 2*pi) (endpoint
    excluded), builds the theta x phi grid, and plots against theta the result
    of ``_arc_length_theta_interpolating_s_grid_full_mod`` evaluated on it.

    Parameters
    ----------
    ax : object with a ``plot(x, y)`` method (e.g. a matplotlib Axes).
    flux_surfaces : forwarded unchanged to the arc-length helper.
    n_theta : number of theta samples.
    phi_values : sequence of phi values; one curve is plotted per value.
    s : scalar or array; promoted to at least 2-D before being forwarded.

    Returns
    -------
    None
    """
    # Private helper of the library, imported locally as in the original cell.
    from jax_sbgeom.flux_surfaces.flux_surfaces_base import (
        _arc_length_theta_interpolating_s_grid_full_mod as arc_length_theta,
    )

    s_grid = jnp.atleast_2d(s)  # scalar -> shape (1, 1)
    theta_values = jnp.linspace(0, 2*jnp.pi, n_theta, endpoint=False)  # [n_theta]
    theta_grid, phi_grid = jnp.meshgrid(
        theta_values, jnp.array(phi_values), indexing='ij'
    )  # both [n_theta, n_phi]

    arc_lengths = arc_length_theta(flux_surfaces, s_grid, theta_grid, phi_grid)
    ax.plot(theta_values, arc_lengths)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: base surface.
ax[0].set_ylim(0.0, 12)
plot_arclengths_ad(ax[0], flux_surfaces_cphi, 200, [0.0, jnp.pi/2, jnp.pi, 3*jnp.pi/2], 2.4)
ax[0].set_xlabel("$\\theta$ [rad]")
ax[0].set_ylabel("$dr/d\\theta$")
ax[0].set_title("Base")

# Right panel: equal-arc-length surface, same y-range for direct comparison.
# As in the original cell, no y-axis label is set on this panel. Labels are set
# on ax[1] explicitly; ax[1] is the Axes that the pyplot state machine treats
# as current after plt.subplots(1, 2).
plot_arclengths_ad(ax[1], flux_surface_total, 200, [0.0, jnp.pi/2, jnp.pi, 3*jnp.pi/2], 2.0)
ax[1].set_ylim(0.0,12)
ax[1].set_xlabel("$\\theta$ [rad]")

ax[1].set_title("Equal Arc Length")
 
