In [1]:
import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax_sbgeom as jsb
%load_ext autoreload
%autoreload 2
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
import numpy as onp
import sys 
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(project_root)


from functools import partial
import jax_sbgeom.coils as jsc
import jax_sbgeom.flux_surfaces as jsf

from dataclasses import dataclass
import functools
from jax_sbgeom.jax_utils.utils import _mesh_to_pyvista_mesh, cumulative_trapezoid_uniform_periodic, interp1d_jax
#import StellBlanket.SBGeom as sbg
#import StellBlanket

import matplotlib.pyplot as plt
vmec_files = ["/home/tbogaarts/stellarator_paper/base_data/vmecs/helias3_vmec.nc4", "/home/tbogaarts/stellarator_paper/base_data/vmecs/helias5_vmec.nc4", "/home/tbogaarts/stellarator_paper/base_data/vmecs/squid_vmec.nc4"]
vmec_file = vmec_files[2]

coil_files = ["/home/tbogaarts/stellarator_paper/base_data/vmecs/HELIAS3_coils_all.h5", "/home/tbogaarts/stellarator_paper/base_data/vmecs/HELIAS5_coils_all.h5", "/home/tbogaarts/stellarator_paper/base_data/vmecs/squid_coilset.h5"]

In [2]:
fs_base = jsb.flux_surfaces.FluxSurface.from_hdf5(vmec_file)

In [3]:

coilsd = onp.loadtxt("filament_s4uu_5ci_23")[:,:]
amax = 101
i = 0
coils = []
currents = []
for i in range(40):
#    if i not in [0, 4, 8, 12, 16, 20,24, 28, 32,35, 36,39] :
    coils.append(coilsd[i * amax : i * amax + amax -1,:-1])
    currents.append(coilsd[i * amax, -1])



jax_coils = [jsc.DiscreteCoil.from_positions(jnp.array(coil)) for coil in coils]
jax_coilset = jsc.CoilSet.from_list(jax_coils)
jax_currents = jnp.array(currents)

#sbg_coilset = sbg.Coil_Set([sbg.Discrete_Coil(coils[i]) for i in range(len(coils))])
coilset = jsc.fourier_coil.convert_to_fourier_coilset(jax_coilset)
coilset_resampled = jsc.fourier_coil.convert_fourier_coilset_to_equal_arclength(coilset)

fs_coilset = jsc.FiniteSizeCoilSet.from_coilset(coilset_resampled, jsc.CentroidFrame)




In [109]:
from typing import Type, List
import pytest
classes          = [jsb.coils.CentroidFrame, jsb.coils.RotationMinimizedFrame, jsb.coils.RadialVectorFrame, jsb.coils.FrenetSerretFrame]
classes_discrete = classes[:-1]  # Frenet-Serret frame not implemented for discrete coils

def _radial_vector(coil_i, n_coils):
    x = jnp.cos(jnp.linspace(0, 2*jnp.pi, 100) + coil_i  / n_coils * 2 * jnp.pi)
    y = jnp.sin(jnp.linspace(0, 2*jnp.pi, 100) + coil_i  / n_coils * 2 * jnp.pi)
    z = jnp.zeros_like(x)
    return jnp.stack([x, y, z], axis=-1)

_radial_vectors = jax.vmap(_radial_vector, in_axes=(0, None))

def additional_arguments_per_coil(frame_class : Type[jsb.coils.base_coil.FiniteSizeMethod], coil_i : int, ncoils : int):
    if frame_class == jsb.coils.RotationMinimizedFrame:
        return (10,)
    elif frame_class == jsb.coils.RadialVectorFrame:
        return (_radial_vector(coil_i, ncoils), )
    else:
        return ()
def additional_arguments(frame_class : Type[jsb.coils.base_coil.FiniteSizeMethod], coils_list : List):
    if frame_class == jsb.coils.RotationMinimizedFrame:
        return (10,)
    elif frame_class == jsb.coils.RadialVectorFrame:
        ncoils = len(coils_list)
        radial_vectors = _radial_vectors(jnp.arange(ncoils), ncoils)
        return (radial_vectors, )
    else:
        return ()        


def check_reverse(coil):
    if isinstance(coil, jsb.coils.DiscreteCoil):
        # reversed discretecoils do not have the same tangent on *exactly* the discrete points of the coil: 
        #  the forward derivative is different on those points.
        # so we use the parametrisation without begin and endpoint and without the discrete points itself
        # if 211 (prime) > number of points in coil, this works fine
        assert coil.Number_of_Points() < 211, "Test needs to be adjusted for coils with more than 211 points"
        s         = jnp.linspace(0, 1.0,211, endpoint=False)[1:]
        s_reverse = jnp.linspace(0,-1.0,211, endpoint=False)[1:]
    else:
        s         = jnp.linspace(0, 1.0,100)
        s_reverse = jnp.linspace(0,-1.0,100)

    rev_coil = coil.reverse_parametrisation()
    position_original = coil.position(s)
    position_reversed = rev_coil.position(s_reverse)
    onp.testing.assert_allclose(position_original, position_reversed)

    tangent_original = coil.tangent(s)
    tangent_reversed = rev_coil.tangent(s_reverse)    
    onp.testing.assert_allclose(tangent_original, -tangent_reversed)

def check_reverse_finite_size(finitesize_coil):
    if isinstance(finitesize_coil.coil, jsb.coils.DiscreteCoil):
        # discretecoils do not have the same tangent on *exactly* the discrete points of the coil..
        #  the forward derivative is different on those points.
        # so we use the parametrisation without being and endpoint
        s         = jnp.linspace(0, 1.0,211, endpoint=False)[1:]
        s_reverse = jnp.linspace(0,-1.0,211, endpoint=False)[1:]
    else:
        s = jnp.linspace(0,1,100)
        s_reverse = jnp.linspace(0,-1,100)
    

    reversed_finitesize_coil = finitesize_coil.reverse_parametrisation()
    position_original = finitesize_coil.position(s)
    position_reversed = reversed_finitesize_coil.position(s_reverse)

    print(finitesize_coil.finite_size_method)
    radial_vector_original = finitesize_coil.radial_vector(s)
    radial_vector_reversed = reversed_finitesize_coil.radial_vector(s_reverse)

    tangent_original = finitesize_coil.tangent(s)
    tangent_reversed = reversed_finitesize_coil.tangent(s_reverse)
    
    
    onp.testing.assert_allclose(position_original,      position_reversed)        
    onp.testing.assert_allclose(tangent_original,      -tangent_reversed)    
    onp.testing.assert_allclose(radial_vector_original, radial_vector_reversed)

    finitesize_frame_original = finitesize_coil.finite_size_frame(s)
    finitesize_frame_reversed = reversed_finitesize_coil.finite_size_frame(s_reverse)
    # the tangent vector will be reversed, but the radial vector the same, so
    # we need to reverse the tangent vector in the reversed frame to compare
    finitesize_frame_reversed_corrected = finitesize_frame_reversed.at[...,1,:].set(-finitesize_frame_reversed[...,1,:])

    onp.testing.assert_allclose(finitesize_frame_original, finitesize_frame_reversed_corrected)

    width_phi = 0.245
    width_radial = 0.123
    finitesize_original = finitesize_coil.finite_size(s, width_phi, width_radial)
    finitesize_reversed = reversed_finitesize_coil.finite_size(s_reverse, width_phi, width_radial)

    # finite size has    
    #  The finite size is in the following order:
    # v_0 : + radial, + phi
    # v_1 : - radial, + phi
    # v_2 : - radial, - phi
    # v_3 : + radial, - phi
    # since phi -> -phi in the reversed coil
    # we need to swap v_0 with v_3 and v_1 with v_2 to compare
    # i can just flip the arrays around the first axis.    
    finitesize_reversed_corrected = finitesize_reversed[:, ::-1, :]
    onp.testing.assert_allclose(finitesize_original, finitesize_reversed_corrected)

def _get_finitesize_coilset(coils_jax, frame_class : Type[jsb.coils.base_coil.FiniteSizeMethod]):
    #additional_args = additional_arguments_per_coil(frame_class, coils_jax, len(coils_jax))
    finitesize_coilset = [jsb.coils.base_coil.FiniteSizeCoil(coil_jax, frame_class.from_coil(coil_jax, *additional_arguments_per_coil(frame_class, i, len(coils_jax)))) for i, coil_jax in enumerate(coils_jax)]
    return finitesize_coilset

@pytest.mark.parametrize("frame_class", classes_discrete)
def test_finitesize_coilset_reverse_discrete(_get_all_discrete_coils, frame_class):
    coils_jaxsbgeom, coilset_sbgeom = _get_all_discrete_coils
    coils_finitesize = _get_finitesize_coilset(coils_jaxsbgeom, frame_class)
    [check_reverse_finite_size(coils_finitesize[i]) for i in range(len(coils_jaxsbgeom))]

@pytest.mark.parametrize("frame_class", classes)
def test_finitesize_coilset_reverse_fourier(_get_all_fourier_coils, frame_class):
    coils_jaxsbgeom, coilset_sbgeom = _get_all_fourier_coils
    coils_finitesize = _get_finitesize_coilset(coils_jaxsbgeom, frame_class)
    [check_reverse_finite_size(coils_finitesize[i]) for i in range(len(coils_jaxsbgeom))]

import h5py
def get_disc_coils(hdf5_file):
    with h5py.File(hdf5_file) as f:
        coils = jnp.array(f['Dataset1'])

    return [jsc.DiscreteCoil.from_positions(coils[i]) for i in range(coils.shape[0])]

coilset_jax = get_disc_coils(coil_files[0])

test_finitesize_coilset_reverse_discrete((coilset_jax,None ), jsb.coils.RadialVectorFrame)

RadialVectorFrame(radial_vectors_i=Array([[ 1.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 9.97986676e-01,  6.34239197e-02,  0.00000000e+00],
       [ 9.91954813e-01,  1.26592454e-01,  0.00000000e+00],
       [ 9.81928697e-01,  1.89251244e-01,  0.00000000e+00],
       [ 9.67948701e-01,  2.51147987e-01,  0.00000000e+00],
       [ 9.50071118e-01,  3.12033446e-01,  0.00000000e+00],
       [ 9.28367933e-01,  3.71662456e-01,  0.00000000e+00],
       [ 9.02926538e-01,  4.29794912e-01,  0.00000000e+00],
       [ 8.73849377e-01,  4.86196736e-01,  0.00000000e+00],
       [ 8.41253533e-01,  5.40640817e-01,  0.00000000e+00],
       [ 8.05270258e-01,  5.92907929e-01,  0.00000000e+00],
       [ 7.66044443e-01,  6.42787610e-01,  0.00000000e+00],
       [ 7.23734038e-01,  6.90079011e-01,  0.00000000e+00],
       [ 6.78509412e-01,  7.34591709e-01,  0.00000000e+00],
       [ 6.30552667e-01,  7.76146464e-01,  0.00000000e+00],
       [ 5.80056910e-01,  8.14575952e-01,  0.00000000e+00],
     

In [98]:
print(jnp.linspace(0.0, 1.0, 213, endpoint=False)[1:][0:2])
print(1/213)


[0.00469484 0.00938967]
0.004694835680751174


In [49]:
from jax_sbgeom.coils.discrete_coil import  _discrete_coil_discrete_position, DiscreteCoil
from jax_sbgeom.jax_utils.utils import interpolate_fractions_modulo
def _discrete_coil_tangent(discrete_coil : DiscreteCoil, s):
    i0, i1, ds = interpolate_fractions_modulo(s, discrete_coil.positions.shape[0])

    pos_i0 = _discrete_coil_discrete_position(discrete_coil, i0)
    print(pos_i0)
    print(discrete_coil.positions[i0])
    print(discrete_coil.positions[i1])
    pos_i1 = _discrete_coil_discrete_position(discrete_coil, i1)
    print(pos_i1)
    print(i0, i1, ds)
    tangent = pos_i1 - pos_i0
    
    tangent = tangent / jnp.linalg.norm(tangent, axis=-1, keepdims=True)
    return tangent

rever_coil = jax_coilset[0].reverse_parametrisation()
coil = jax_coilset[0]

print(_discrete_coil_tangent(coil, 0.911))
print(_discrete_coil_tangent(rever_coil, 0.089))

[23.59251   2.461262 -2.576186]
[23.59251   2.461262 -2.576186]
[23.70941   2.366493 -2.32873 ]
[23.70941   2.366493 -2.32873 ]
91 92 0.10000000000000853
[ 0.40362873 -0.3272155   0.85440849]
[23.70941   2.366493 -2.32873 ]
[23.70941   2.366493 -2.32873 ]
[23.59251   2.461262 -2.576186]
[23.59251   2.461262 -2.576186]
8 9 0.9000000000000004
[-0.40362873  0.3272155  -0.85440849]


In [None]:
rever_coil = jax_coilset[0].reverse_parametrisation()
jax_coilset[0].reverse_parametrisation().tangent(0.9)

Array([ 0.49655428,  0.39494535, -0.7729502 ], dtype=float64)

In [None]:
#from StellBlanket.SBGeom import Coils_jax as CJ

#cj_set = CJ.CoilSet_JAX(coilset_resampled.coils.centre_i, coilset_resampled.coils.fourier_cos, coilset_resampled.coils.fourier_sin)

In [None]:
jsc.base_coil._coil_rotation_positive(coilset_resampled[0])

Array(True, dtype=bool)

In [None]:
import pyvista as pv
plotter = pv.Plotter()

    
plotter.add_mesh(pv.Spline(positions_0), scalars = jnp.linspace(0,1,positions_0.shape[0]), line_width=5, cmap="viridis")

plotter.show()

NameError: name 'positions_0' is not defined

In [None]:
coilset_flipped = jsc.coilset.ensure_coilset_rotation(coilset_resampled, False)

In [None]:
positions = coilset_flipped.position(jnp.linspace(0,1,100))

In [None]:
import pyvista as pv
plotter = pv.Plotter()
for i in range(positions.shape[0]):
    coil = positions[i]    
    plotter.add_mesh(pv.Spline(coil), scalars = jnp.linspace(0,1,coil.shape[0]), line_width=5, cmap="viridis")

plotter.show()

Widget(value='<iframe src="http://localhost:36099/index.html?ui=P_0x7fa997687590_1&reconnect=auto" class="pyvi…

In [None]:
%tb

No traceback available to show.
