In [None]:
import jax_sbgeom as jsb
%load_ext autoreload
%autoreload 2
import h5py
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
import numpy as onp
import sys 
import os

# Make the parent of the notebook's working directory importable; the `StellBlanket` and
# `tests` imports below presumably live there.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:  # keep the cell idempotent: no duplicate entries on re-run
    sys.path.append(project_root)

# Directory holding the VMEC equilibria. Set the VMEC_DATA_DIR environment variable to point
# elsewhere instead of editing this hard-coded local default.
VMEC_DATA_DIR = os.environ.get(
    "VMEC_DATA_DIR", "/home/tbogaarts/stellarator_paper/base_data/vmecs"
)
vmec_files = [
    os.path.join(VMEC_DATA_DIR, file_name)
    for file_name in ("helias3_vmec.nc4", "helias5_vmec.nc4", "squid_vmec.nc4")
]

# Equilibrium analysed throughout the notebook (index 1 -> helias5).
vmec_file = vmec_files[1]
import StellBlanket.SBGeom as SBGeom
from jax_sbgeom.flux_surfaces.flux_surfaces_base import _cartesian_position_interpolated_jit, _cylindrical_position_interpolated, _cartesian_position_interpolated_grad, ToroidalExtent, FluxSurface, FluxSurfaceData
from jax_sbgeom.flux_surfaces.flux_surface_meshing import _volume_of_mesh
from tests.flux_surfaces.flux_surface_base import test_position, _get_flux_surfaces, _sampling_grid, _1d_sampling_grid, test_normals, test_meshing_surface, test_principal_curvatures, _get_all_closed_surfaces, test_all_closed_surfaces, _mesh_to_pyvista_mesh, test_volumes
import pyvista as pv

In [None]:
def _get_flux_surfaces(vmec_file):
    """Load one VMEC HDF5 file with both backends.

    Returns a ``(jax_flux_surface, sbgeom_flux_surface)`` pair: the JAX implementation first,
    then the reference SBGeom implementation.

    NOTE(review): this shadows the ``_get_flux_surfaces`` imported from
    ``tests.flux_surfaces.flux_surface_base`` in the first cell.
    """
    jax_surface = jsb.flux_surfaces.FluxSurface.from_hdf5(vmec_file)
    sbgeom_surface = SBGeom.VMEC.Flux_Surfaces_From_HDF5(vmec_file)
    return jax_surface, sbgeom_surface


# Original note on this call: "just to compile".
fs_jax, fs_sbgeom = _get_flux_surfaces(vmec_file)

In [None]:
# Reference surface mesh over the full toroidal extent; the last cell compares its volume
# between `_volume_of_mesh` and PyVista.
# NOTE(review): 1.0 is presumably the flux-surface label and 100 / 150 the mesh resolutions
# in the two angular directions — confirm against `mesh_surface`.
mesh = jsb.flux_surfaces.mesh_surface(
    fs_jax,
    1.0,
    ToroidalExtent.full(),
    100,
    150,
)

In [None]:
@jax.jit
def Rmnc_Zmns_wrapper(Rmnc, Zmns, flux_surface):
    """Return a new FluxSurface with `Rmnc`/`Zmns` swapped in.

    The mode-number vectors (`mpol_vector`, `ntor_vector`) and the settings are copied
    unchanged from `flux_surface`. The new coefficient arrays therefore presumably need the
    same shape as the originals.
    """
    data_old = flux_surface.data
    fs_new = FluxSurface(FluxSurfaceData(Rmnc, Zmns, data_old.mpol_vector, data_old.ntor_vector), flux_surface.settings)
    return fs_new


def Volume_Rmnc_Zmns_wrapper(Rmnc, Zmns, flux_surface):
    """Volume of the meshed surface built from the given `Rmnc`/`Zmns` coefficients.

    Only JAX operations are used, so the function can be traced by `jax.vmap` / `jax.jit`
    (as done for the scaling sweep) and differentiated w.r.t. the coefficient arrays.
    """
    fs_new = Rmnc_Zmns_wrapper(Rmnc, Zmns, flux_surface)
    # NOTE(review): positional arguments presumably mean: surface label 1.0, toroidal range
    # [0, 2*pi], then a flag, two 600-point resolutions, and another flag. Confirm against
    # `flux_surface_meshing._mesh_surface`.
    points, connectivity = jsb.flux_surfaces.flux_surface_meshing._mesh_surface(fs_new, 1.0, 0.0, 2 * jnp.pi, True, 600, 600, True)
    volume = _volume_of_mesh(points, connectivity)
    return volume


def Volume_Rmnc_wrapper(Rmnc, flux_surface):
    """Volume as a function of `Rmnc` only, keeping `flux_surface.data.Zmns` fixed.

    The finite-difference, `jax.grad` and timing cells further down call this name with
    `(Rmnc, flux_surface)`. It was not defined anywhere in the notebook, so those cells
    raised NameError on a fresh kernel.
    """
    return Volume_Rmnc_Zmns_wrapper(Rmnc, flux_surface.data.Zmns, flux_surface)


def scaling_factor_wrapper(scaling , flux_surface):
    """Volume after multiplying every Rmnc and Zmns coefficient by `scaling`.

    If positions are linear in the coefficients (the VMEC Fourier representation), this is a
    uniform geometric scaling. The volume should then grow as `scaling**3`, which is what
    the plotting cell checks.
    """
    Rmnc_new = flux_surface.data.Rmnc * scaling
    Zmns_new = flux_surface.data.Zmns * scaling
    return Volume_Rmnc_Zmns_wrapper(Rmnc_new, Zmns_new, flux_surface)


# Sanity check: doubling all coefficients should multiply the volume by ~8.
print(scaling_factor_wrapper(1.0, fs_jax))
print(scaling_factor_wrapper(2.0, fs_jax))

In [None]:
# Run the volume test imported from tests.flux_surfaces.flux_surface_base on the selected equilibrium.
test_volumes(vmec_file, n_repetitions=10)

In [None]:
from jax_sbgeom.flux_surfaces.flux_surfaces_base import _cylindrical_position_interpolated, _cylindrical_position_interpolated_grad, FluxSurfaceSettings, _cartesian_position_interpolated_grad, _normal_interpolated_jit, _dx_dphi_cross_dx_dtheta


In [None]:
# Batch `scaling_factor_wrapper` over an array of scaling factors (axis 0), broadcast the
# single flux surface to every call (`None`), then compile the batched function.
volumes_f = jax.jit(
    jax.vmap(
        scaling_factor_wrapper,
        in_axes=(0, None),
    )
)

In [None]:
# Sweep 1000 uniform scaling factors in [0.1, 5.0] in a single batched call (this first call
# also triggers compilation of `volumes_f`).
# NOTE(review): each sample is meshed at 600 x 600, so the vmapped batch may need a lot of
# device memory — reduce `num` if this runs out of memory.
scaling_factor = jnp.linspace(start=0.1, stop=5.0, num=1000)
volumes_scaling = volumes_f(scaling_factor, fs_jax)

In [None]:
# Time the compiled batched sweep; `block_until_ready()` makes the timing include the actual
# (asynchronously dispatched) computation, not just the dispatch.
%timeit volumes_f(scaling_factor, fs_jax).block_until_ready()

In [None]:
import matplotlib.pyplot as plt

# Uniformly scaling Rmnc and Zmns should scale the volume with the cube of the factor. Anchor
# the cubic reference at the first sample instead of hard-coding its value (0.1).
cubic_reference = volumes_scaling[0] * (scaling_factor / scaling_factor[0]) ** 3

fig, ax = plt.subplots()
ax.plot(scaling_factor, volumes_scaling, label="meshed volume")
ax.plot(scaling_factor, cubic_reference, "--", label=r"$V_0 (\lambda / \lambda_0)^3$")
ax.set_title("Volume vs. uniform coefficient scaling")
ax.set_xlabel("Scaling Factor")
ax.set_ylabel("Volume [m^3]")
ax.legend()

# Summarize the agreement instead of printing two 1000-element arrays.
max_rel_dev = float(jnp.max(jnp.abs(volumes_scaling / cubic_reference - 1.0)))
print(f"max relative deviation from cubic scaling: {max_rel_dev:.3e}")



In [None]:
# Finite-difference setup: the unperturbed coefficients and a copy with one entry
# (index [-1, 1]) shifted by 1e-5.
# NOTE(review): the last row is presumably the outermost surface — confirm the Rmnc layout.
Rmnc_2 = fs_jax.data.Rmnc
Rmnc_3 = Rmnc_2.at[(-1, 1)].add(1e-5)

In [None]:
# Forward finite-difference estimate of dV/dRmnc[-1, 1] (step 1e-5), for comparison with the
# automatic-differentiation gradient at the same index in the next cell.
# Expects `Volume_Rmnc_wrapper(Rmnc, flux_surface)` to be defined in an earlier cell.
(
    Volume_Rmnc_wrapper(Rmnc_3, fs_jax)
    - Volume_Rmnc_wrapper(Rmnc_2, fs_jax)
) / 1e-5

In [None]:
# Reverse-mode derivative of the volume w.r.t. every Rmnc coefficient. `jax.grad` differentiates
# w.r.t. the first positional argument by default (argnums=0); compare entry [-1, 1] with the
# finite-difference estimate above.
volume_grad = jax.grad(Volume_Rmnc_wrapper)

volume_grad(fs_jax.data.Rmnc, fs_jax)

In [None]:
# FIXME(review): `ree` is not defined anywhere in this notebook (presumably a leftover from a
# deleted cell), so an unguarded `%timeit ree(fs_jax)` raises NameError under Restart & Run All.
# Confirm what was meant to be benchmarked here and restore it.
if "ree" in globals():
    get_ipython().run_line_magic("timeit", "ree(fs_jax).block_until_ready()")
else:
    print("Skipping benchmark: `ree` is not defined in this notebook.")

In [None]:
%timeit volume_grad(fs_jax.data.Rmnc, fs_jax, 0.1, 0.2).block_until_ready()

In [None]:
# Benchmark a single forward volume evaluation, for comparison with the gradient timing above.
# Expects `Volume_Rmnc_wrapper(Rmnc, flux_surface)` to be defined in an earlier cell.
%timeit Volume_Rmnc_wrapper(fs_jax.data.Rmnc, fs_jax).block_until_ready()

In [None]:
# Cross-check: volume of the same reference mesh according to PyVista (printed) and
# according to `_volume_of_mesh` (displayed as the cell output).
points, connectivity = mesh
mesh_pv = _mesh_to_pyvista_mesh(points, connectivity)
print(mesh_pv.volume)

_volume_of_mesh(points, connectivity)
