In [14]:
# import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax_sbgeom as jsb
%load_ext autoreload
%autoreload 2
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
import numpy as onp
import sys 
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(project_root)

import StellBlanket.SBGeom as SBGeom
from functools import partial
import jax_sbgeom.coils as jsc

from tests.coils.test_coils import _check_single_vectorized
import pyvista as pv
from dataclasses import dataclass

from jax_sbgeom.jax_utils.utils import _mesh_to_pyvista_mesh


vmec_files = ["/home/tbogaarts/stellarator_paper/base_data/vmecs/helias3_vmec.nc4", "/home/tbogaarts/stellarator_paper/base_data/vmecs/helias5_vmec.nc4", "/home/tbogaarts/stellarator_paper/base_data/vmecs/squid_vmec.nc4"]
vmec_file = vmec_files[2]

coil_files = ["/home/tbogaarts/stellarator_paper/base_data/vmecs/HELIAS3_coils_all.h5", "/home/tbogaarts/stellarator_paper/base_data/vmecs/HELIAS5_coils_all.h5", "/home/tbogaarts/stellarator_paper/base_data/vmecs/squid_coilset.h5"]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
class bb:
    param : str
    def __init__(self, param):
        self.param = param

def _get_all_discrete_coils(request):    
    coilset_sbgeom = SBGeom.Coils.Discrete_Coil_Set_From_HDF5(request.param)    
    coilset_jaxsbgeom = jsb.coils.CoilSet([jsb.coils.DiscreteCoil.from_positions(coilset_sbgeom[i].Get_Vertices()) for i in range(coilset_sbgeom.Number_of_Coils())])
    
    return coilset_jaxsbgeom, coilset_sbgeom


jsbc, sbgc = _get_all_discrete_coils(bb(coil_files[1]))

In [18]:
def _sampling_s_finite_size(n_s : int = 111):
    return jnp.linspace(0.0, 1.0, n_s, endpoint=False)

def _switch_finite_size(coil_sbgeom, width_0, width_1, method, ns, **kwargs):
    if method == "centroid":
        return coil_sbgeom.Finite_Size_Lines_Centroid(width_1, width_0, ns)
    elif method == "frenet_serret":
        return coil_sbgeom.Finite_Size_Lines_Frenet(width_1, width_0, ns)
    elif method == "rmf":
        return coil_sbgeom.Finite_Size_Lines_RMF(width_1, width_0, ns)
    else:
        raise NotImplementedError(f"Finite size method '{method}' not implemented for SBGeom discrete coils.")

def _switch_finite_size_cjax(coil_jax, width_0, width_1, method, ns, **kwargs):
    if method == "centroid":
        finitesize_coil = jsb.coils.base_coil.FiniteSizeCoil(coil_jax, jsb.coils.base_coil.CentroidFrame())
        return finitesize_coil.finite_size(_sampling_s_finite_size(ns), width_0, width_1)
    elif method == "frenet_serret":
        finitesize_coil = jsb.coils.base_coil.FiniteSizeCoil(coil_jax, jsb.coils.base_coil.FrenetSerretFrame())
        return finitesize_coil.finite_size(_sampling_s_finite_size(ns), width_0, width_1)
    elif method == "rmf":
        number_of_rmf_samples = kwargs.get("number_of_rmf_samples", 1000)
        finitesize_coil = jsb.coils.base_coil.FiniteSizeCoil(coil_jax, jsb.coils.base_coil.RotationMinimizedFrame.from_coil(coil_jax, number_of_rmf_samples))
        return finitesize_coil.finite_size(_sampling_s_finite_size(ns), width_0, width_1)
    else:
        raise NotImplementedError(f"Finite size method '{method}' not implemented for JAX-SBGeom discrete coils.")
def _check_finite_size(coilset_jsb, coilset_sbgeom, method, rtol = 1e-12, atol = 1e-12, **kwargs):
    for i in range(coilset_sbgeom.Number_of_Coils()):
        coil_jsb = coilset_jsb[i]
        coil_sbgeom = coilset_sbgeom[i]
        s_samples = _sampling_s_finite_size()
        width_0 = 0.3
        width_1 = 0.5
        jsb_lines    = _switch_finite_size_cjax(coil_jsb, width_0, width_1, method, ns = s_samples.shape[0], **kwargs)
        sbgeom_lines = _switch_finite_size(coil_sbgeom, width_0, width_1, method, ns = s_samples.shape[0])

        jsb_comparison = jnp.moveaxis(jsb_lines, 1,0).reshape(-1,3)
        onp.testing.assert_allclose(sbgeom_lines, jsb_comparison, rtol = rtol, atol=atol)


def test_discrete_coil_finite_size_rmf(_get_all_discrete_coils):
    coilset_jaxsbgeom, coilset_sbgeom = _get_all_discrete_coils
    nrmf = _sampling_s_finite_size().shape[0]
    _check_finite_size(coilset_jaxsbgeom, coilset_sbgeom, method="rmf", rtol =1e-7,  atol=1e-7, number_of_rmf_samples = nrmf)
test_discrete_coil_finite_size_rmf((jsbc, sbgc))

In [3]:
from jax_sbgeom.coils.base_coil import CentroidFrame, FiniteSizeCoil, RotationMinimizedFrame


finite_size_method = RotationMinimizedFrame.from_coil(jsbc[0], 40)

fsc = FiniteSizeCoil(coil=jsbc[0], finite_size_method=finite_size_method)

In [4]:
coilset = jax.tree.map(lambda *xs : jnp.stack(xs), *jsbc)

In [5]:
jax.vmap(coilset.position)(jnp.linspace(1,1,50)).shape  # (ncoil, ns, 3)

(50, 50, 3)

In [6]:
from jax_sbgeom.coils.discrete_coil import _discrete_coil_position

In [7]:
coilset

DiscreteCoil(positions=Array([[[ 20.16335247,  17.31449564,   0.7402    ],
        [ 20.30560909,  17.11427282,   0.8618    ],
        [ 20.42608102,  16.8980989 ,   0.9811    ],
        ...,
        [ 19.7297655 ,  17.77313121,   0.3       ],
        [ 19.86177965,  17.64791254,   0.4655    ],
        [ 20.00910203,  17.49413817,   0.6112    ]],

       [[ 17.4685707 ,  20.39372269,   1.4064    ],
        [ 17.43065533,  20.39299824,   1.754     ],
        [ 17.42477949,  20.31295822,   2.0561    ],
        ...,
        [ 17.63121244,  20.09619917,   0.2817    ],
        [ 17.58826135,  20.21962548,   0.6472    ],
        [ 17.52957542,  20.3250686 ,   1.0273    ]],

       [[ 14.7209944 ,  21.36726858,   4.1227    ],
        [ 14.69580635,  21.12190271,   4.3333    ],
        [ 14.68831129,  20.85354557,   4.5       ],
        ...,
        [ 14.89209778,  21.88391044,   3.2982    ],
        [ 14.82335936,  21.75411907,   3.595     ],
        [ 14.76498898,  21.58130533,   3.8725    ]

In [8]:
def coilset_position(coilset, s):
    return coilset.position(s)


jax.vmap(coilset_position, in_axes=(0,0))(coilset, jnp.linspace(1,1,50)).shape

(50, 3)

In [9]:
coilset_lines = coilset.position(jnp.linspace(0,1,200))

pv.Spline(coilset_lines.reshape(-1,3)).plot()

Widget(value='<iframe src="http://localhost:42653/index.html?ui=P_0x79abe816dc10_0&reconnect=auto" class="pyvi…

In [10]:
plotter = pv.Plotter()



coil_mesh = _mesh_to_pyvista_mesh(fsc.get_mesh())
plotter.add_mesh(coil_mesh, color="red", opacity=0.5)
plotter.show()

AttributeError: 'FiniteSizeCoil' object has no attribute 'get_mesh'

In [None]:
from jax_sbgeom.jax_utils.utils import interpolate_array_modulo_broadcasted

In [11]:
mesh_sb = sbgc[0].Mesh_Triangles_RMF(0.2,0.2, 40)
mesh_cj = jsb.coils.coil_meshing.mesh_coil_surface(fsc, 40, 0.2, 0.2)

onp.testing.assert_allclose(mesh_sb.vertices, mesh_cj[0], rtol=1e-12, atol=1e-12)

In [None]:
plotter =  pv.Plotter()


pvmesh_sb = sbgc[0].Mesh_Triangles_RMF(0.6,0.2, 100).to_pyvista()
pvmesh = _mesh_to_pyvista_mesh(*jsb.coils.coil_meshing.mesh_coil_surface(fsc, 100, 0.2, 0.6))
plotter.add_mesh(pvmesh_sb, color="blue", opacity=1.0, show_edges=True)
plotter.add_mesh(pvmesh, color="red", opacity=1.0, show_edges=True)
plotter.show()