In [None]:
# import jax
# import jax.numpy as jnp

# jax.config.update("jax_compilation_cache_dir", "../jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [None]:
import sys
import os

# Put the working directory at the front of the import search path, and the
# parent directory at the very end (so an installed package of the same name
# would still win over it) — presumably so a local `desc` checkout is findable.
sys.path[0:0] = [os.path.abspath(".")]
sys.path.extend([os.path.abspath("../")])

# Choose the compute device; this runs before the `import desc` / star imports
# of the later cells (presumably required so the backend picks it up).
from desc import set_device

set_device("gpu")

In [None]:
import numpy as np
np.set_printoptions(
    threshold=sys.maxsize,  # print every element, never summarize with "..."
    linewidth=np.inf,  # never wrap long rows onto new lines
    suppress=True,  # fixed-point notation instead of scientific for small floats
    precision=4,
)
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.graph_objects as go
import functools
import scipy

In [None]:
import desc

from desc.basis import *
from desc.backend import *
from desc.compute import *
from desc.coils import *
from desc.equilibrium import *
from desc.examples import *
from desc.grid import *
from desc.geometry import *

from desc.objectives import *
from desc.objectives.objective_funs import *
from desc.objectives.getters import *
from desc.objectives.normalization import compute_scaling_factors
from desc.objectives.utils import *
from desc.optimize._constraint_wrappers import *

from desc.transform import Transform
from desc.plotting import *
from desc.optimize import *
from desc.perturbations import *
from desc.profiles import *
from desc.compat import *
from desc.utils import *
from desc.magnetic_fields import *

from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.input_reader import InputReader
from desc.continuation import solve_continuation_automatic

# Report which backend/device DESC ended up using (helper presumably provided by
# `from desc.backend import *`); useful to confirm the set_device("gpu") call took effect.
print_backend_info()

In [None]:
import numpy as np
import pyvista as pv
from pyvista import CellType

def export_coil_to_paraview(x, y, z, current, filename="coil.vtp", closed=False):
    """Export a single coil as a VTK polyline for ParaView.

    Parameters
    ----------
    x, y, z : array_like, shape (N,)
        Cartesian coordinates of consecutive points along the coil.
    current : float
        Coil current, stored as a constant point-data array named ``"current"``.
    filename : str
        Output path; the extension selects the VTK writer (``.vtp`` by default).
    closed : bool
        If True, the last point is connected back to the first so a closed coil
        is drawn without a gap. The default (False) keeps the polyline open.
        Leave it False when the last point already duplicates the first.

    Raises
    ------
    ValueError
        If fewer than two points are given.
    """
    # Stack coordinates into shape (N, 3); np.asarray also accepts JAX arrays.
    points = np.column_stack((np.asarray(x), np.asarray(y), np.asarray(z)))
    n_points = len(points)
    if n_points < 2:
        raise ValueError("A coil polyline needs at least 2 points.")

    # VTK polyline connectivity: [n_ids, id0, id1, ..., id_{n_ids-1}].
    # A closed loop repeats the first point index at the end.
    point_ids = np.arange(n_points)
    if closed:
        point_ids = np.append(point_ids, 0)
    lines = np.concatenate(([len(point_ids)], point_ids))

    poly = pv.PolyData()
    poly.points = points
    poly.lines = lines

    # One current value per point so ParaView can colour the coil by current.
    poly["current"] = np.full(n_points, current)

    poly.save(filename)

In [None]:
def export_field_to_paraview(nodes, B, filename="magnetic_field.vtu"):
    """
    Export magnetic field vectors at scattered points for ParaView.

    The nodes become a point cloud with one vertex cell per node. A ``.vtu``
    filename is written as an unstructured grid, because PyVista's ``PolyData``
    writer rejects the ``.vtu`` extension. Any other extension (e.g. ``.vtp``,
    ``.vtk``) is saved directly as ``PolyData``.

    Parameters
    ----------
    nodes : ndarray, shape (N, 3)
        Grid node positions.
    B : ndarray, shape (N, 3)
        Magnetic field vectors at each node.
    filename : str
        Output filename. The default ``.vtu`` produces an unstructured grid.

    Raises
    ------
    ValueError
        If ``nodes`` and ``B`` do not have the same number of rows.
    """
    # Coerce possible JAX arrays to NumPy before handing them to VTK.
    nodes = np.asarray(nodes)
    B = np.asarray(B)
    if len(B) != len(nodes):
        raise ValueError("B must have one vector per node.")

    # Point cloud with the field attached as point data.
    point_cloud = pv.PolyData(nodes)
    point_cloud["B"] = B

    # PolyData cannot be saved as .vtu; convert (point data is carried over).
    if os.path.splitext(filename)[1].lower() == ".vtu":
        dataset = point_cloud.cast_to_unstructured_grid()
    else:
        dataset = point_cloud

    dataset.save(filename)

In [None]:
def export_surface_to_vtu(
    nodes, data=None, grid_shape=None, filename="surface.vtu", wrap_toroidal=True
):
    """Export a toroidal surface as a quad mesh to a VTK unstructured grid.

    ``grid_shape`` is ``(Nt, Np)``. Nodes must be ordered with the poloidal
    index fastest, so node ``(i, j)`` (toroidal ``i``, poloidal ``j``) sits at
    flat index ``i * Np + j``. Each quad joins ``(i, j)``, ``(i, j+1)``,
    ``(i+1, j+1)`` and ``(i+1, j)``. The poloidal direction always wraps around.

    Parameters
    ----------
    nodes : array_like, shape (Nt * Np, 3)
        Node positions in the ordering described above.
    data : dict[str, array_like] or None
        Optional point-data arrays; each must have one entry (row) per node.
    grid_shape : tuple of int
        ``(Nt, Np)``. Required despite the ``None`` default.
    filename : str
        Output filename (``.vtu``).
    wrap_toroidal : bool
        If True (original behaviour), the last toroidal row is connected back
        to the first. Set False when the nodes do not close toroidally, e.g. a
        grid covering only part of the torus.

    Returns
    -------
    pyvista.UnstructuredGrid
        The grid that was written to ``filename``.

    Raises
    ------
    TypeError
        If ``grid_shape`` is not given.
    ValueError
        If an array in ``data`` does not have one entry per node.
    """
    if grid_shape is None:
        raise TypeError("grid_shape=(Nt, Np) is required.")
    nodes = np.asarray(nodes)
    Nt, Np = grid_shape
    assert nodes.shape == (Nt * Np, 3), "Shape mismatch between nodes and grid_shape"

    # (i, j) index of every quad. "ij" indexing keeps the original loop order:
    # toroidal row outer, poloidal column inner.
    n_rows = Nt if wrap_toroidal else Nt - 1
    i, j = np.meshgrid(np.arange(n_rows), np.arange(Np), indexing="ij")
    i_next = (i + 1) % Nt
    j_next = (j + 1) % Np
    quads = np.stack(
        [i * Np + j, i * Np + j_next, i_next * Np + j_next, i_next * Np + j],
        axis=-1,
    ).reshape(-1, 4)

    # VTK cell array layout: [4, p0, p1, p2, p3, 4, p0, ...].
    cells = np.column_stack((np.full(len(quads), 4), quads)).ravel()
    celltypes = np.full(len(quads), CellType.QUAD, dtype=np.uint8)

    grid = pv.UnstructuredGrid(cells, celltypes, nodes)

    # Attach optional per-node data.
    for name, values in (data or {}).items():
        values = np.asarray(values)
        if len(values) != len(nodes):
            raise ValueError(f"Length of {name} does not match number of points.")
        grid[name] = values

    grid.save(filename)
    return grid


def export_volume_to_vtu(
    nodes, data=None, grid_shape=None, filename="volume.vtu", wrap_toroidal=True
):
    """
    Export a toroidal volume as a hexahedral mesh to a VTK unstructured grid.

    ``grid_shape`` is ``(Np, Nr, Nt)``, and nodes must use the matching order:
    - p: poloidal (fastest)
    - r: radial (middle)
    - t: toroidal (slowest)
    Node ``(p, r, t)`` therefore sits at flat index ``t * Nr * Np + r * Np + p``.

    Each hexahedron spans one step in p, r and t. Its first four corners lie on
    toroidal plane ``t`` and its last four on plane ``t + 1``. The poloidal
    direction always wraps around, while the radial direction never does.

    Parameters
    ----------
    nodes : array_like, shape (Np * Nr * Nt, 3)
        Node positions in the ordering described above.
    data : dict[str, array_like] or None
        Optional point-data arrays; each must have one entry (row) per node.
    grid_shape : tuple of int
        ``(Np, Nr, Nt)`` in exactly this order. Required despite the ``None``
        default. Any permutation with the same product passes the node-count
        check but yields wrong connectivity.
    filename : str
        Output filename (``.vtu``).
    wrap_toroidal : bool
        If True (original behaviour), the last toroidal plane is connected back
        to the first. Set False when the nodes do not close toroidally, e.g. a
        grid covering a single field period.

    Returns
    -------
    pyvista.UnstructuredGrid
        The grid that was written to ``filename``.

    Raises
    ------
    TypeError
        If ``grid_shape`` is not given.
    ValueError
        If an array in ``data`` does not have one entry per node.
    """
    if grid_shape is None:
        raise TypeError("grid_shape=(Np, Nr, Nt) is required.")
    nodes = np.asarray(nodes)
    Np, Nr, Nt = grid_shape
    assert nodes.shape == (Np * Nr * Nt, 3), "Mismatch in node count and grid shape"

    def idx(p, r, t):
        # Flat node index for (poloidal, radial, toroidal); works elementwise on arrays.
        return t * (Nr * Np) + r * Np + p

    # Indices of every cell. "ij" indexing keeps the original loop order:
    # t outer, r middle, p inner.
    n_planes = Nt if wrap_toroidal else Nt - 1
    t, r, p = np.meshgrid(
        np.arange(n_planes), np.arange(Nr - 1), np.arange(Np), indexing="ij"
    )
    t1 = (t + 1) % Nt  # next toroidal plane
    p1 = (p + 1) % Np  # next poloidal node, wrapping around

    hexes = np.stack(
        [
            # face on toroidal plane t
            idx(p, r, t), idx(p1, r, t), idx(p1, r + 1, t), idx(p, r + 1, t),
            # matching face on toroidal plane t + 1
            idx(p, r, t1), idx(p1, r, t1), idx(p1, r + 1, t1), idx(p, r + 1, t1),
        ],
        axis=-1,
    ).reshape(-1, 8)

    # VTK cell array layout: [8, pt0, ..., pt7, 8, pt0, ...].
    cells = np.column_stack((np.full(len(hexes), 8), hexes)).ravel()
    celltypes = np.full(len(hexes), CellType.HEXAHEDRON, dtype=np.uint8)

    grid = pv.UnstructuredGrid(cells, celltypes, nodes)

    # Attach optional per-node data.
    for name, values in (data or {}).items():
        values = np.asarray(values)
        if len(values) != len(nodes):
            raise ValueError(f"Length of {name} does not match number of points.")
        grid[name] = values

    grid.save(filename)
    return grid

In [None]:
# Evaluate B and the force error |F| on a (rho, theta, zeta) tensor-product grid
# of the precise_QA equilibrium, then export them as a hexahedral volume mesh.
Nr = 20  # radial resolution (rho from 0.2 to 1)
Nt = 100  # toroidal resolution (number of zeta values)
Np = 100  # poloidal resolution (number of theta values)

grid = LinearGrid(rho=np.linspace(0.2, 1, Nr), theta=Np, zeta=Nt, NFP=1)
eq = get("precise_QA")
dataB = eq.compute(["B", "X", "Y", "Z", "|F|"], grid=grid)
B = dataB["B"]
X = dataB["X"]
Y = dataB["Y"]
Z = dataB["Z"]
F = dataB["|F|"]
nodes = np.column_stack((X.flatten(), Y.flatten(), Z.flatten()))
data = {"B": B, "|F|": F}

# export_volume_to_vtu unpacks grid_shape as (Np, Nr, Nt): poloidal fastest,
# then radial, then toroidal. This assumes DESC orders grid nodes theta-fastest,
# then rho, then zeta. The order must be exactly (Np, Nr, Nt): any permutation
# has the same node count, so the function's shape check cannot catch a mix-up.
# (export_surface_to_vtu expects a single-surface grid with grid_shape=(Nt, Np),
# so it needs a grid with one rho value, not these volume nodes.)
mesh = export_volume_to_vtu(
    nodes, data, grid_shape=(Np, Nr, Nt), filename="toroidal_volume3.vtu"
)