In [1]:
import discretisedfield as df  # df is here chosen to be an alias for discretisedfield

# Opposite corner points of the cuboid region.
p1, p2 = (0, 0, 0), (100e-9, 50e-9, 20e-9)

region = df.Region(p1=p1, p2=p2)

In [2]:
import pyvista as pv

# Render PyVista scenes with the trame backend, which displays them in the
# notebook as interactive widgets (the Widget(...) outputs).
pv.set_jupyter_backend("trame")

In [3]:
# Interactive PyVista rendering of the cuboid region defined above.
region.pyvista()

Widget(value="<iframe src='http://localhost:38453/index.html?ui=P_0x7f3f3841e970_0&reconnect=auto' style='widt…

In [4]:
# Edge lengths of the simulation box and a cubic discretisation cell.
lx, ly, lz = 100e-9, 50e-9, 20e-9
cell = (5e-9,) * 3

# Two named subregions: one occupying the lower half (z < lz / 2) between
# x = 20e-9 and x = 40e-9, the other the upper corner (z >= lz / 2,
# x >= 80e-9, y >= 40e-9).
subregions = {
    "bottom_subregion": df.Region(p1=(20e-9, 0, 0), p2=(40e-9, ly, lz / 2)),
    "top_subregion": df.Region(p1=(80e-9, 40e-9, lz / 2), p2=(lx, ly, lz)),
}

region = df.Region(p1=(0, 0, 0), p2=(lx, ly, lz))
mesh = df.Mesh(region=region, cell=cell, subregions=subregions)

In [5]:
# Interactive PyVista rendering of the mesh built above.
mesh.pyvista()

Widget(value="<iframe src='http://localhost:38453/index.html?ui=P_0x7f3f383a85e0_1&reconnect=auto' style='widt…

In [6]:
# Render the mesh's named subregions ("bottom_subregion" and "top_subregion").
mesh.pyvista.subregions()

Widget(value="<iframe src='http://localhost:38453/index.html?ui=P_0x7f3f2bbd87c0_2&reconnect=auto' style='widt…

In [7]:
import numpy as np
import ubermagutil.units as uu

In [133]:
class PyVistaField:
    """Interactive 3D PyVista plots of a three-dimensional discretisedfield field.

    All coordinates are divided by ``multiplier`` (by default the region's own
    multiplier) before plotting, so the axis values are in the SI-prefixed unit
    shown in the axis titles. Plotting never modifies the wrapped field or its
    mesh: each method works on a rescaled copy.

    Every plotting method accepts ``plot``. When it is ``None``, a new
    ``pv.Plotter`` is created and shown. Otherwise the geometry is added to the
    given plotter, and the caller is responsible for calling ``show()``.
    """

    def __init__(self, field):
        """Wrap a copy of ``field``.

        Raises:
            RuntimeError: If the field's region is not three-dimensional.
        """
        if field.mesh.region.ndim != 3:
            raise RuntimeError("Only 3d meshes can be plotted.")
        # Arithmetic presumably yields a new field object, so later value
        # changes on the caller's field do not leak into this plot object.
        self.field = field * 1

    def __call__(self, *, plot=None, multiplier=None, **kwargs):
        """Plot the field as arrows coloured by their z-component.

        Arrow lengths are proportional to the local norm. The longest arrow is
        as long as the smallest cell edge.

        Parameters:
            plot: Optional existing ``pv.Plotter`` to draw into.
            multiplier: Divisor applied to all coordinates. Defaults to the
                region's multiplier.
            **kwargs: Forwarded to ``plotter.add_mesh`` for the arrows. They
                override the defaults ``scalars="z"`` and ``cmap="coolwarm"``.

        Raises:
            RuntimeError: If the field does not have exactly 3 vector dimensions.
        """
        self._check_nvdim()
        plotter = pv.Plotter() if plot is None else plot
        multiplier = self._setup_multiplier(multiplier)
        field = self._rescaled_field(multiplier)

        field_pv = pv.wrap(field.to_vtk())

        cell_min = np.min(field.mesh.cell)
        max_norm = np.max(field.norm.array)
        # The arrow geometry is built with length ``scale`` and then stretched
        # by the norm in ``glyph``, so the largest vector spans one cell edge.
        # An all-zero field would otherwise give an infinite scale.
        scale = cell_min / max_norm if max_norm > 0 else cell_min

        arrow = pv.Arrow(
            tip_radius=0.18,
            tip_length=0.4,
            scale=scale,
            tip_resolution=80,
            shaft_resolution=80,
            shaft_radius=0.05,
            # Shift back by half a length so the arrow is centred on its
            # anchor point instead of starting there.
            start=(-0.5 * scale, 0, 0),
        )
        plotter.add_mesh(
            field_pv.glyph(orient="field", scale="norm", geom=arrow),
            **{"scalars": "z", "cmap": "coolwarm", **kwargs},
        )

        self._add_axes(plotter, field.mesh.region, multiplier)
        if plot is None:
            plotter.show()

    def valid(self, *, plot=None, multiplier=None, **kwargs):
        """Volume-render the field's ``valid`` mask.

        Parameters:
            plot: Optional existing ``pv.Plotter`` to draw into.
            multiplier: Divisor applied to all coordinates. Defaults to the
                region's multiplier.
            **kwargs: Forwarded to ``plotter.add_volume``.

        Raises:
            RuntimeError: If the field does not have exactly 3 vector dimensions.
        """
        self._check_nvdim()
        plotter = pv.Plotter() if plot is None else plot
        multiplier = self._setup_multiplier(multiplier)

        rescaled_mesh = self.field.mesh.scale(1 / multiplier, reference_point=(0, 0, 0))

        grid = pv.RectilinearGrid(*rescaled_mesh.vertices)
        # Fortran order makes the first (x) index vary fastest, matching the
        # VTK cell ordering of the rectilinear grid.
        grid.cell_data["values"] = self.field.valid.astype(int).flatten(order="F")

        plotter.add_volume(grid, scalars="values", flip_scalars=True, **kwargs)
        plotter.remove_scalar_bar()

        self._add_axes(plotter, rescaled_mesh.region, multiplier)
        if plot is None:
            plotter.show()

    def contour(self, *, plot=None, multiplier=None, **kwargs):
        """Plot the semi-transparent isosurface on which the z-component is 0.

        Parameters:
            plot: Optional existing ``pv.Plotter`` to draw into.
            multiplier: Divisor applied to all coordinates. Defaults to the
                region's multiplier.
            **kwargs: Forwarded to ``plotter.add_mesh`` for the isosurface.
                They override the defaults ``cmap="coolwarm"`` and
                ``opacity=0.5``.

        Raises:
            RuntimeError: If the field does not have exactly 3 vector dimensions.
        """
        self._check_nvdim()
        plotter = pv.Plotter() if plot is None else plot
        multiplier = self._setup_multiplier(multiplier)
        field = self._rescaled_field(multiplier)

        # Interpolate cell data onto points before contouring.
        field_pv = pv.wrap(field.to_vtk()).cell_data_to_point_data()

        plotter.add_mesh(
            field_pv.contour(scalars="z", isosurfaces=[0]),
            **{"cmap": "coolwarm", "opacity": 0.5, **kwargs},
        )

        self._add_axes(plotter, field.mesh.region, multiplier)
        if plot is None:
            plotter.show()

    def _check_nvdim(self):
        """Raise ``RuntimeError`` unless the field has exactly 3 vector dimensions."""
        if self.field.nvdim != 3:
            raise RuntimeError(
                "Only fields with 3 vector dimensions can be plotted, not"
                f" {self.field.nvdim=}."
            )

    def _rescaled_field(self, multiplier):
        """Return a copy of the field on a mesh scaled by ``1 / multiplier``.

        The scaling is about the origin, and a new mesh is created
        (``inplace`` is not used). The wrapped field and its mesh are
        therefore left untouched, which keeps repeated plotting idempotent.
        """
        rescaled_mesh = self.field.mesh.scale(1 / multiplier, reference_point=(0, 0, 0))
        return type(self.field)(
            rescaled_mesh,
            nvdim=self.field.nvdim,
            value=self.field.array,
            valid=self.field.valid,
        )

    def _add_axes(self, plotter, region, multiplier):
        """Add labelled axes/grid spanning ``region`` (already rescaled) to ``plotter``."""
        label = self._axis_labels(multiplier)
        # The invisible box spanning the region is only needed to work around
        # an axis bug.
        bounds = tuple(val for pair in zip(region.pmin, region.pmax) for val in pair)
        plotter.add_mesh(pv.Box(bounds), opacity=0.0)
        plotter.show_grid(xtitle=label[0], ytitle=label[1], ztitle=label[2])

    def _setup_multiplier(self, multiplier):
        """Return ``multiplier``, or the region's multiplier when it is ``None``."""
        return self.field.mesh.region.multiplier if multiplier is None else multiplier

    def _axis_labels(self, multiplier):
        """Return one ``"<dim> (<SI prefix><unit>)"`` title per region dimension."""
        return [
            rf"{dim} ({uu.rsi_prefixes[multiplier]}{unit})"
            for dim, unit in zip(
                self.field.mesh.region.dims, self.field.mesh.region.units
            )
        ]

    def _vtk_add_point_data(self, mesh):
        """Add a point-data copy named ``"<name>-points"`` for every array in ``mesh``.

        Modifies ``mesh`` in place and returns it. It is not called by the
        plotting methods of this class.
        """
        pmesh = mesh.cell_data_to_point_data()
        for name in mesh.array_names:
            mesh.point_data[f"{name}-points"] = pmesh.point_data[name]
        return mesh

In [134]:
# Semi-axes of the ellipsoid that ``norm_fun`` tests against.
a, b, c = 5e-9, 3e-9, 2e-9
cell = (0.5e-9,) * 3

# The mesh spans exactly the ellipsoid's bounding box.
mesh = df.Mesh(p1=(-a, -b, -c), p2=(a, b, c), cell=cell)


def norm_fun(pos):
    """Return the field norm at ``pos``.

    The value is 1e6 inside or on the surface of the origin-centred ellipsoid
    with semi-axes ``a``, ``b``, ``c`` (module-level values), and 0 outside.
    """
    px, py, pz = pos
    r2 = (px / a) ** 2 + (py / b) ** 2 + (pz / c) ** 2
    return 1e6 if r2 <= 1 else 0


def value_fun(pos):
    """Return the field vector ``(-k*y, k*x, k*z)`` at ``pos``, with ``k = 1e9``.

    The in-plane part circulates counter-clockwise around the z-axis. The
    z-component grows linearly with z.
    """
    px, py, pz = pos
    # Named ``k`` rather than ``c`` so it does not shadow the ellipsoid semi-axis.
    k = 1e9
    return (-k * py, k * px, k * pz)


# Three-component field whose direction follows ``value_fun`` and whose
# magnitude is presumably set by ``norm_fun`` (1e6 inside the ellipsoid,
# 0 outside).
# NOTE(review): valid="norm" presumably marks only cells with non-zero norm
# as valid — confirm against the discretisedfield documentation.
field = df.Field(mesh, nvdim=3, value=value_fun, norm=norm_fun, valid="norm")

In [135]:
# Semi-transparent isosurface on which the field's "z" component is zero.
PyVistaField(field).contour()

Widget(value="<iframe src='http://localhost:38453/index.html?ui=P_0x7f3c4faa74f0_45&reconnect=auto' style='wid…

In [22]:
# Draw the mesh subregions and the field arrows into one shared plotter, then
# show it once at the end. PyVistaField only calls show() itself when no
# plotter is passed in.
# NOTE(review): this cell's execution count ([22]) is lower than the cells
# above ([133]-[135]), so its output came from an earlier kernel state;
# re-run it after the cells above. The mesh built above defines no
# subregions — TODO confirm what subregions() draws in that case.
plot = pv.Plotter()
field.mesh.pyvista.subregions(plot=plot)
PyVistaField(field)(plot=plot)
plot.show()

Widget(value="<iframe src='http://localhost:38453/index.html?ui=P_0x7f3eea48cdf0_8&reconnect=auto' style='widt…