In [1]:
# import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax_sbgeom as jsb
%load_ext autoreload
%autoreload 2
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
import numpy as onp
import sys 
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(project_root)

import StellBlanket.SBGeom as SBGeom
from functools import partial
import jax_sbgeom.flux_surfaces.flux_surface_meshing as fsm
import jax_sbgeom.flux_surfaces.flux_surfaces_base as fsb
import jax_sbgeom.flux_surfaces.flux_surfaces_extended as fse
import pyvista as pv
from dataclasses import dataclass


vmec_files = ["/home/tbogaarts/stellarator_paper/base_data/vmecs/helias3_vmec.nc4", "/home/tbogaarts/stellarator_paper/base_data/vmecs/helias5_vmec.nc4", "/home/tbogaarts/stellarator_paper/base_data/vmecs/squid_vmec.nc4"]
vmec_file = vmec_files[1]

Triangle Elements not compiled


In [2]:
def _get_flux_surfaces(vmec_file):
    fs_jax    = jsb.flux_surfaces.FluxSurface.from_hdf5(vmec_file)
    fs_sbgeom = SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file)
    return fs_jax, fs_sbgeom

def _get_extended_flux_surfaces(vmec_file):    
    fs_jax    = jsb.flux_surfaces.FluxSurfaceNormalExtended.from_hdf5(vmec_file)
    fs_sbgeom = SBGeom.Flux_Surfaces_Normal_Extended(SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file))
    return fs_jax, fs_sbgeom

def _get_extended_no_phi_flux_surfaces(vmec_file):    
    fs_jax    = jsb.flux_surfaces.FluxSurfaceNormalExtendedNoPhi.from_hdf5(vmec_file)
    fs_sbgeom = SBGeom.Flux_Surfaces_Normal_Extended_No_Phi(SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file))
    return fs_jax, fs_sbgeom

def _get_extended_constant_phi_flux_surfaces(vmec_file):    
    fs_jax    = jsb.flux_surfaces.FluxSurfaceNormalExtendedConstantPhi.from_hdf5(vmec_file)
    fs_sbgeom = SBGeom.Flux_Surfaces_Normal_Extended_Constant_Phi(SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file))
    return fs_jax, fs_sbgeom

fs_jax, fs_sbgeom = _get_extended_constant_phi_flux_surfaces(vmec_file)  # just to compile
fs_jax, fs_sbgeom = _get_extended_no_phi_flux_surfaces(vmec_file)  # just to compile


In [3]:
### We are going to fit RZ points to VMEC

In [4]:
d_total = 0.0
n_theta = 50
n_phi   = 60
symm = fs_jax.settings.nfp

theta = jnp.linspace(0, 2*jnp.pi, n_theta, endpoint=False)
phi   = jnp.linspace(0, 2*jnp.pi / symm, n_phi, endpoint=False)

theta_mg, phi_mg = jnp.meshgrid(theta, phi, indexing='ij')

In [5]:
positions_jax = fs_jax.cylindrical_position(1.0 + d_total, theta_mg, phi_mg)

In [6]:
positions_cart_sbg = fs_sbgeom.Return_Position(onp.ones_like(theta_mg).ravel(), onp.ones_like(theta_mg).ravel() * d_total, onp.array(theta_mg.ravel()), onp.array(phi_mg.ravel())).reshape(positions_jax.shape)


Rsbg = onp.sqrt(positions_cart_sbg[...,0]**2 + positions_cart_sbg[...,1]**2)
Zsbg = positions_cart_sbg[...,2]
phisb = onp.arctan2(positions_cart_sbg[...,1], positions_cart_sbg[...,0])


In [7]:
RSampling, ZSampling, phisampling, uu, vv = SBGeom.VMEC.Get_Sampling_Curve(fs_sbgeom, n_theta, n_phi, d_total)

In [10]:

def test_RZ_to_VMEC_lcfs(_get_flux_surfaces):
    fs_jax, fs_sbgeom = _get_flux_surfaces

    mpol_max = int(fs_jax.settings.mpol)
    print(mpol_max)
    
    ntor_max = int(fs_jax.settings.ntor)
    print(ntor_max)
    mpol_max = int(jnp.max(jnp.abs(fs_jax.data.mpol_vector)))
    print(jnp.max(jnp.abs(fs_jax.data.ntor_vector)) / fs_jax.settings.nfp)

    n_theta = mpol_max * 2 +1# just above nyquist
    n_phi   = ntor_max * 2  + 1 # just above nyquist

    sampling_r, sampling_z = jsb.flux_surfaces.convert_to_VMEC._create_sampling_rz(fs_jax, 1.0, n_theta, n_phi)
    print(sampling_r.shape, sampling_z.shape)

    Rmnc, Zmns             = jsb.flux_surfaces.convert_to_VMEC._rz_to_vmec_representation(sampling_r, sampling_z)
    print(Rmnc.shape)

    onp.testing.assert_allclose(Rmnc, fs_jax.data.Rmnc[-1,:], rtol=1e-12, atol=1e-12)
    onp.testing.assert_allclose(Zmns, fs_jax.data.Zmns[-1,:], rtol=1e-12, atol=1e-12)

test_RZ_to_VMEC_lcfs((fs_jax, fs_sbgeom))

12
12
12.0
(23, 25) (23, 25)
(288,)
