In [1]:
# import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax_sbgeom as jsb
%load_ext autoreload
%autoreload 2
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
import numpy as onp
import sys 
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(project_root)

import StellBlanket.SBGeom as SBGeom
from functools import partial
import jax_sbgeom.flux_surfaces.flux_surface_meshing as fsm
import jax_sbgeom.flux_surfaces.flux_surfaces_base as fsb
import jax_sbgeom.flux_surfaces.flux_surfaces_extended as fse
import pyvista as pv
from dataclasses import dataclass


# Directory holding the VMEC equilibrium files. The default is the original
# author's location; set the VMEC_DATA_DIR environment variable to run this
# notebook on another machine without editing the cell.
VMEC_DATA_DIR = os.environ.get("VMEC_DATA_DIR", "/home/tbogaarts/stellarator_paper/base_data/vmecs")
VMEC_FILE_NAMES = ("helias3_vmec.nc4", "helias5_vmec.nc4", "squid_vmec.nc4")

# Full paths, in the same order as before.
vmec_files = [os.path.join(VMEC_DATA_DIR, name) for name in VMEC_FILE_NAMES]
# Equilibrium used by the rest of the notebook (index 1 -> helias5).
vmec_file = vmec_files[1]

Triangle Elements not compiled


In [2]:
def _get_flux_surfaces(vmec_file):
    """Load one VMEC file through both libraries.

    Args:
        vmec_file: Path handed unchanged to both loaders.

    Returns:
        Tuple ``(jax_surface, reference_surface)``: the
        ``jsb.flux_surfaces.FluxSurface`` built by ``from_hdf5`` and the object
        returned by ``SBGeom.VMEC.Flux_Surfaces_From_HDF5``.
    """
    jax_surface = jsb.flux_surfaces.FluxSurface.from_hdf5(vmec_file)
    reference_surface = SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file)
    return jax_surface, reference_surface

def _get_extended_flux_surfaces(vmec_file):
    """Load one VMEC file as a "normal extended" surface in both libraries.

    Args:
        vmec_file: Path handed unchanged to both loaders.

    Returns:
        Tuple ``(jax_surface, reference_surface)``: the
        ``FluxSurfaceNormalExtended`` from jax_sbgeom and the SBGeom
        ``Flux_Surfaces_Normal_Extended`` wrapping the SBGeom VMEC surfaces.
    """
    jax_surface = jsb.flux_surfaces.FluxSurfaceNormalExtended.from_hdf5(vmec_file)
    # SBGeom builds the extended variant by wrapping the plain VMEC surfaces.
    reference_base = SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file)
    reference_surface = SBGeom.Flux_Surfaces_Normal_Extended(reference_base)
    return jax_surface, reference_surface

def _get_extended_no_phi_flux_surfaces(vmec_file):
    """Load one VMEC file as a "normal extended, no phi" surface in both libraries.

    Args:
        vmec_file: Path handed unchanged to both loaders.

    Returns:
        Tuple ``(jax_surface, reference_surface)``: the
        ``FluxSurfaceNormalExtendedNoPhi`` from jax_sbgeom and the SBGeom
        ``Flux_Surfaces_Normal_Extended_No_Phi`` wrapping the SBGeom VMEC
        surfaces.
    """
    jax_surface = jsb.flux_surfaces.FluxSurfaceNormalExtendedNoPhi.from_hdf5(vmec_file)
    # SBGeom builds the extended variant by wrapping the plain VMEC surfaces.
    reference_base = SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file)
    reference_surface = SBGeom.Flux_Surfaces_Normal_Extended_No_Phi(reference_base)
    return jax_surface, reference_surface

def _get_extended_constant_phi_flux_surfaces(vmec_file):
    """Load one VMEC file as a "normal extended, constant phi" surface in both libraries.

    Args:
        vmec_file: Path handed unchanged to both loaders.

    Returns:
        Tuple ``(jax_surface, reference_surface)``: the
        ``FluxSurfaceNormalExtendedConstantPhi`` from jax_sbgeom and the SBGeom
        ``Flux_Surfaces_Normal_Extended_Constant_Phi`` wrapping the SBGeom VMEC
        surfaces.
    """
    jax_surface = jsb.flux_surfaces.FluxSurfaceNormalExtendedConstantPhi.from_hdf5(vmec_file)
    # SBGeom builds the extended variant by wrapping the plain VMEC surfaces.
    reference_base = SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file)
    reference_surface = SBGeom.Flux_Surfaces_Normal_Extended_Constant_Phi(reference_base)
    return jax_surface, reference_surface

# Both calls return a (jax_sbgeom, SBGeom) pair. The second assignment rebinds
# fs_jax / fs_sbgeom, so every later cell works with the "no phi" variant.
fs_jax, fs_sbgeom = _get_extended_constant_phi_flux_surfaces(vmec_file)  # result is overwritten on the next line (original note: "just to compile")
fs_jax, fs_sbgeom = _get_extended_no_phi_flux_surfaces(vmec_file)  # pair actually used by the cells below


In [3]:
# Evaluation grid shared by the jax_sbgeom and SBGeom calls in the next cells.
# d_total is added to 1.0 in the cylindrical_position call and passed as the
# second argument of Return_Position; presumably the distance beyond the last
# flux surface -- TODO confirm.
d_total = 0.0
n_theta = 50
n_phi   = 60
# nfp read from the loaded surface; presumably the number of field periods -- TODO confirm.
symm = fs_jax.settings.nfp

# theta covers [0, 2*pi) and phi covers [0, 2*pi / symm); endpoint=False keeps
# the upper bound out of both grids.
theta = jnp.linspace(0, 2*jnp.pi, n_theta, endpoint=False)
phi   = jnp.linspace(0, 2*jnp.pi / symm, n_phi, endpoint=False)

# indexing='ij' gives both grids shape (n_theta, n_phi).
theta_mg, phi_mg = jnp.meshgrid(theta, phi, indexing='ij')

In [4]:
# jax_sbgeom evaluation on the (theta, phi) grid; the first argument is 1.0 + d_total.
# Its shape is reused later to reshape the SBGeom result.
positions_jax = fs_jax.cylindrical_position(1.0 + d_total, theta_mg, phi_mg)

In [5]:
# SBGeom reference evaluation on the same grid. Return_Position receives four
# flattened arrays: ones, ones * d_total, theta and phi. The result is reshaped
# to positions_jax.shape.
# NOTE(review): positions_jax comes from cylindrical_position while this array is
# treated as Cartesian below -- confirm both have the same trailing dimension.
positions_cart_sbg = fs_sbgeom.Return_Position(onp.ones_like(theta_mg).ravel(), onp.ones_like(theta_mg).ravel() * d_total, onp.array(theta_mg.ravel()), onp.array(phi_mg.ravel())).reshape(positions_jax.shape)


# Convert to cylindrical components, treating the last axis as (x, y, z):
# R = sqrt(x^2 + y^2), Z = z, phi = atan2(y, x).
Rsbg = onp.sqrt(positions_cart_sbg[...,0]**2 + positions_cart_sbg[...,1]**2)
Zsbg = positions_cart_sbg[...,2]
phisb = onp.arctan2(positions_cart_sbg[...,1], positions_cart_sbg[...,0])


In [6]:
# SBGeom sampling helper called with the same n_theta, n_phi and d_total as the
# grid above; it returns five values that are unpacked here.
# NOTE(review): none of these names is used in the cells shown -- confirm they are needed.
RSampling, ZSampling, phisampling, uu, vv = SBGeom.VMEC.Get_Sampling_Curve(fs_sbgeom, n_theta, n_phi, d_total)

In [None]:
def convert_flux_surface_to_different_ntor_mpol(fs : jsb.flux_surfaces.FluxSurface, mpol_new : int, ntor_new : int) -> jsb.flux_surfaces.FluxSurface:
    """Build a new ``FluxSurface`` from ``fs`` using ``mpol_new`` / ``ntor_new``.

    The ``Rmnc`` and ``Zmns`` arrays of ``fs`` are stacked on a new leading axis
    and passed through ``_convert_to_different_ntor_mpol`` in a single call.
    Fresh mode vectors are created for the new resolution, and ``nfp`` is
    carried over from ``fs.settings``.

    Args:
        fs: Source surface; ``fs.data.Rmnc``, ``fs.data.Zmns`` and
            ``fs.settings.{mpol, ntor, nfp}`` are read.
        mpol_new: Value stored as ``mpol`` in the returned settings.
        ntor_new: Value stored as ``ntor`` in the returned settings.

    Returns:
        A ``jsb.flux_surfaces.FluxSurface`` holding the converted data.

    NOTE(review): ``_convert_to_different_ntor_mpol``, ``_create_ntor_vector``
    and ``_create_mpol_vector`` are not defined in the visible cells -- confirm
    where they come from so the notebook survives Restart & Run All.
    NOTE(review): ``nsurf`` is fixed to 1 regardless of the input surface, and
    the result is always a plain ``FluxSurface`` even if ``fs`` is a subclass --
    confirm both are intended.
    """
    source_settings = fs.settings

    # Convert both coefficient sets together: index 0 is Rmnc, index 1 is Zmns.
    stacked_coefficients = jnp.stack([fs.data.Rmnc, fs.data.Zmns], axis=0)
    converted = _convert_to_different_ntor_mpol(
        stacked_coefficients,
        mpol_new,
        ntor_new,
        source_settings.mpol,
        source_settings.ntor,
    )

    converted_data = fsb.FluxSurfaceData(
        Rmnc=converted[0],
        Zmns=converted[1],
        ntor_vector=_create_ntor_vector(ntor_new, mpol_new, source_settings.nfp),
        mpol_vector=_create_mpol_vector(ntor_new, mpol_new),
    )
    converted_settings = fsb.FluxSurfaceSettings(
        mpol=mpol_new,
        ntor=ntor_new,
        nfp=source_settings.nfp,
        nsurf=1,
    )
    return jsb.flux_surfaces.FluxSurface(data=converted_data, settings=converted_settings)

# Convert the currently loaded surface using mpol = ntor = 25; the result is
# compared against fs_jax in the next cell.
fs_new = convert_flux_surface_to_different_ntor_mpol(fs_jax, mpol_new=25, ntor_new=25)

In [200]:
# Visual check: evaluate the converted and the original surface with the same
# arguments (1.0, 0.25, 0.24). The two printed rows should agree.
print(fs_new.cartesian_position(1.0, 0.25, 0.24))
print(fs_jax.cartesian_position(1.0, 0.25, 0.24))

[23.36889775  5.7187596   0.96105676]
[23.36889775  5.7187596   0.96105676]
