In [1]:
# import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax_sbgeom as jsb
%load_ext autoreload
%autoreload 2
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
import numpy as onp
import sys 
import os

# Make the parent of the current working directory importable (presumably the
# repository root when the notebook is run from its own folder).
# NOTE(review): `jax_sbgeom` is already imported above, before this path is added,
# and `append` gives this entry the lowest priority -- so an installed jax_sbgeom
# would presumably shadow a local checkout. Confirm which copy is intended.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(project_root)


from functools import partial
import jax_sbgeom.coils as jsc
import jax_sbgeom.flux_surfaces as jsf

from dataclasses import dataclass
import functools
from jax_sbgeom.jax_utils.utils import _mesh_to_pyvista_mesh


# Directory holding the VMEC equilibria and coil sets. Set the environment
# variable JAX_SBGEOM_DATA_DIR to point elsewhere instead of editing this cell;
# without it the original hard-coded location is used.
BASE_DATA_DIR = os.environ.get(
    "JAX_SBGEOM_DATA_DIR", "/home/tbogaarts/stellarator_paper/base_data/vmecs"
)

# VMEC equilibria, in the order HELIAS3, HELIAS5, SQuID.
vmec_files = [
    os.path.join(BASE_DATA_DIR, name)
    for name in ("helias3_vmec.nc4", "helias5_vmec.nc4", "squid_vmec.nc4")
]
VMEC_INDEX = 2  # index into vmec_files: 2 -> squid_vmec.nc4
vmec_file = vmec_files[VMEC_INDEX]

# Coil sets, in the same order as `vmec_files`.
coil_files = [
    os.path.join(BASE_DATA_DIR, name)
    for name in ("HELIAS3_coils_all.h5", "HELIAS5_coils_all.h5", "squid_coilset.h5")
]

In [2]:

# Load the filament coil file. Rows are grouped per coil in blocks of
# ROWS_PER_COIL; the leading columns of each block are passed to
# DiscreteCoil.from_positions and the last column of a block's first row is
# taken as that coil's current.
# NOTE(review): the last N_DROPPED_ROWS rows of every block are discarded --
# presumably the repeated closing point plus a separator row; confirm against
# the file format.
COIL_FILE = "filament_s4uu_5ci_23"
N_COILS = 40
ROWS_PER_COIL = 101
N_DROPPED_ROWS = 2
# Coil indices to skip; empty keeps every coil. An earlier experiment excluded
# (0, 4, 8, 12, 16, 20, 24, 28, 32, 35, 36, 39).
EXCLUDED_COILS = ()

coil_rows = onp.loadtxt(COIL_FILE)
# Fail loudly on a short/malformed file instead of silently truncating coils.
if coil_rows.ndim != 2 or coil_rows.shape[0] < N_COILS * ROWS_PER_COIL:
    raise ValueError(
        f"{COIL_FILE!r} has shape {coil_rows.shape}; expected at least "
        f"{N_COILS * ROWS_PER_COIL} rows ({N_COILS} coils x {ROWS_PER_COIL} rows)."
    )

kept_coils = [c for c in range(N_COILS) if c not in EXCLUDED_COILS]
coils = [
    coil_rows[c * ROWS_PER_COIL : (c + 1) * ROWS_PER_COIL - N_DROPPED_ROWS, :-1]
    for c in kept_coils
]
currents = [coil_rows[c * ROWS_PER_COIL, -1] for c in kept_coils]

jax_coils = [jsc.DiscreteCoil.from_positions(jnp.array(coil)) for coil in coils]
jax_currents = jnp.array(currents)

In [3]:
# Bundle the discrete coils into a single CoilSet; its `position` method is what
# create_coilset_total_arrays samples further down.
jax_coilset = jsc.CoilSet.from_list(jax_coils)

In [4]:
# Load the VMEC equilibrium and mesh the surface at label 1.0 (the LCFS, per the
# variable name) over the full toroidal extent; the mesh is only used for the
# 3-D plot at the end.
# NOTE(review): 100 and 110 are presumably the poloidal/toroidal mesh resolutions
# -- confirm against jsf.mesh_surface.
flux_surface            =  jsf.FluxSurface.from_hdf5(vmec_file)
mesh_lcfs               = jsf.mesh_surface(flux_surface, 1.0, jsf.ToroidalExtent.full(), 100, 110)

In [5]:
# Cartesian points on the LCFS over a (theta, phi) grid of 100 x 1100 samples.
# Both angles span [0, 2*pi] with the endpoint included, and `indexing="ij"`
# puts theta on the first axis and phi on the second.
n_theta, n_phi = 100, 1100
full_turn = 2 * jnp.pi
theta = jnp.linspace(0, full_turn, n_theta)
phi = jnp.linspace(0, full_turn, n_phi)
theta_mg, phi_mg = jnp.meshgrid(theta, phi, indexing="ij")
positions_lcfs = flux_surface.cartesian_position(1.0, theta_mg, phi_mg)

In [6]:
def create_coilset_total_arrays(jax_coilset, currents, number_of_samples_per_coil):
    """Flatten a coil set into per-segment arrays for Biot-Savart summation.

    Each coil is sampled at ``number_of_samples_per_coil`` equally spaced
    parameter values in [0, 1). The endpoint is excluded so the closing point is
    not duplicated. Every sample is paired with the vector to the next sample on
    the same coil, and the last sample wraps around to the first.

    Parameters
    ----------
    jax_coilset : object with a ``position(s)`` method
        Presumably returns an array of shape (n_coils, n_samples, 3) for a 1-D
        array ``s``; the reshapes below assume a trailing axis of size 3.
    currents : array, shape (n_coils,)
        Current of each coil.
    number_of_samples_per_coil : int
        Number of samples per coil; must be at least 1.

    Returns
    -------
    currents_total : array, shape (n_coils * n_samples,)
        Each coil's current repeated once per sample, in coil-major order.
    coil_samples : array, shape (n_coils * n_samples, 3)
        Sample points, coil-major.
    coil_diff : array, shape (n_coils * n_samples, 3)
        Vector from each sample to the next sample on the same coil.

    Raises
    ------
    ValueError
        If ``number_of_samples_per_coil`` is smaller than 1.
    """
    if number_of_samples_per_coil < 1:
        raise ValueError(
            f"number_of_samples_per_coil must be >= 1, got {number_of_samples_per_coil}"
        )
    s = jnp.linspace(0, 1, number_of_samples_per_coil, endpoint=False)
    coil_samples = jax_coilset.position(s)
    # Next sample minus current sample along each coil (axis=1); roll makes the
    # last sample's segment close back onto the first one.
    coil_diff = jnp.roll(coil_samples, -1, axis=1) - coil_samples
    # [c0]*n + [c1]*n + ... matches the coil-major flattening of coil_samples.
    currents_total = jnp.repeat(currents, number_of_samples_per_coil)
    return currents_total, coil_samples.reshape(-1, 3), coil_diff.reshape(-1, 3)
    
# Discretise every coil into 200 sample points plus the vector to the next
# sample; these flat arrays feed the Biot-Savart evaluations below.
currents_total, coil_samples, coil_diff = create_coilset_total_arrays(jax_coilset, jax_currents, 200)


In [7]:
# Benchmark the batched Biot-Savart evaluation on the first 500 LCFS points.
# block_until_ready() makes %timeit wait for JAX's asynchronous dispatch to finish.
%timeit field_total2 = jsc.biot_savart.biot_savart_batch(currents_total, coil_samples, coil_diff, positions_lcfs.reshape(-1,3)[0:500], batch_size=None).block_until_ready()

3.25 ms ± 153 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
import scipy.integrate as integrate
def B_func(x):
    """Magnetic field at the single point ``x`` from the discretised coil set.

    Reads the notebook globals ``currents_total``, ``coil_samples`` and
    ``coil_diff`` at call time, so it always uses the latest discretisation.
    """
    single_point_field = jsc.biot_savart.biot_savart_single
    return single_point_field(currents_total, coil_samples, coil_diff, x)

# Seed point for the field-line trace: the LCFS sample at theta = 0, phi = 0
# (first entry of both grid axes, since the grid was built with indexing='ij').
r0 = positions_lcfs[0,0]

print(r0)

def field_line_rhs(s,r, b_func):
    """Right-hand side of the field-line ODE dr/ds = B(r) / |B(r)|.

    Because the returned tangent has unit length, ``s`` is arc length along
    the line. The system is autonomous, so ``s`` is unused, but it is required
    by the ``fun(t, y, *args)`` signature of ``scipy.integrate.solve_ivp``.

    Parameters
    ----------
    s : float
        Current arc-length value (unused).
    r : array_like
        Current position on the field line.
    b_func : callable
        Maps a position to the field vector at that position.

    Returns
    -------
    The unit vector B(r) / |B(r)|.

    Raises
    ------
    ValueError
        If |B(r)| is zero or not finite. The direction is undefined there, and
        dividing would feed NaN/inf into the integrator.
    """
    B = b_func(r)
    B_norm = jnp.linalg.norm(B)
    # `not (B_norm > 0)` also catches NaN, since NaN > 0 is False.
    if not (B_norm > 0) or not jnp.isfinite(B_norm):
        raise ValueError(f"Field-line direction undefined at r={r}: |B| = {B_norm}")
    return B / B_norm

def trace_field_line(r0, b_func, s_span = (0,1), max_step=0.01, rtol=1e-6, atol=1e-9, method="RK45"):
    """Trace the field line through ``r0`` by integrating dr/ds = B / |B|.

    Parameters
    ----------
    r0 : array_like
        Starting position (a 3-vector in this notebook).
    b_func : callable
        Maps a position to the field vector at that position.
    s_span : tuple (s0, s1)
        Arc-length interval to integrate over.
    max_step : float
        Largest arc-length step the solver may take.
    rtol, atol : float
        Relative and absolute tolerances passed to ``solve_ivp``. The defaults
        match the previously hard-coded values.
    method : str
        ``solve_ivp`` integration method (default ``"RK45"``).

    Returns
    -------
    The ``scipy.integrate.solve_ivp`` result object: ``sol.t`` holds the
    arc-length values and ``sol.y`` the positions, one column per point.
    """
    s0, s1 = s_span

    sol = integrate.solve_ivp(
        field_line_rhs,
        (s0, s1),
        r0,
        args=(b_func,),
        method=method,
        max_step=max_step,
        rtol=rtol,
        atol=atol,
    )

    return sol

# Trace the field line from r0 over arc length s in [0, 5000], with steps of at most 1.0.
# NOTE(review): s is in the units of the coil/LCFS coordinates (presumably metres) -- confirm.
field_line = trace_field_line(r0, B_func, s_span=(0,5000), max_step=1.0)


[21.74336723  0.          0.        ]


In [16]:
# Show the solve_ivp result: success flag, sampled arc lengths/positions, nfev.
field_line

  message: The solver successfully reached the end of the integration interval.
  success: True
   status: 0
        t: [ 0.000e+00  1.000e-03 ...  5.000e+03  5.000e+03]
        y: [[ 2.174e+01  2.174e+01 ... -1.960e+01 -1.962e+01]
            [ 0.000e+00  9.043e-04 ...  4.459e+00  4.412e+00]
            [ 0.000e+00  4.267e-04 ... -2.163e+00 -2.150e+00]]
      sol: None
 t_events: None
 y_events: None
     nfev: 30926
     njev: 0
      nlu: 0

In [17]:
import pyvista as pv

# Build both PyVista geometries first, then draw them together: the LCFS as a
# translucent white surface and the traced field line as a red spline.
lcfs_surface = _mesh_to_pyvista_mesh(*mesh_lcfs)
# solve_ivp stores states as columns (3, n_points); Spline wants (n_points, 3).
field_line_curve = pv.Spline(field_line.y.T)

plotter = pv.Plotter()
plotter.add_mesh(lcfs_surface, color="white", opacity=0.3, show_edges=False)
plotter.add_mesh(field_line_curve, color="red", line_width=4)
plotter.show()

Widget(value='<iframe src="http://localhost:39417/index.html?ui=P_0x700614c86030_2&reconnect=auto" class="pyvi…