In [None]:
"""
This file contains a draft of a high level API
covering most of the unray functionality.

Note that creating objects directly will be more efficient
and provide more detailed control of some parameters.
"""

import numbers
import numpy as np
import unray as ur
import pythreejs as three
from ipydatawidgets import NDArrayWidget

def as_array_widget(values):
    """Wrap *values* in an ``NDArrayWidget`` unless it already is one.

    ``None`` and existing ``NDArrayWidget`` instances are returned unchanged;
    anything else becomes ``NDArrayWidget(values=values)``.
    """
    # The None test comes first so NDArrayWidget is only consulted for real data.
    if values is None or isinstance(values, NDArrayWidget):
        return values
    return NDArrayWidget(values=values)

def as_scalar_lut(lut):
    """Coerce *lut* into a scalar lookup table.

    - tuple / list / ndarray -> ``ur.ArrayScalarLUT(values=lut)``
    - ``None``               -> a 256-sample linear ramp from 0.0 to 1.0 (float32)
    - anything else          -> returned as-is (assumed to already be a LUT object)
    """
    if isinstance(lut, (tuple, list, np.ndarray)):
        return ur.ArrayScalarLUT(values=lut)
    if lut is not None:
        return lut
    ramp = np.linspace(0.0, 1.0, 256, dtype='float32')
    return ur.ArrayScalarLUT(values=ramp)

def as_color_lut(lut):
    """Coerce *lut* into a color lookup table.

    - tuple / list / ndarray -> ``ur.ArrayColorLUT(values=lut)``
    - ``None``               -> a 256 x 3 float32 gray ramp (R == G == B, 0.0 to 1.0)
    - anything else          -> returned as-is (assumed to already be a LUT object)
    """
    if isinstance(lut, (tuple, list, np.ndarray)):
        return ur.ArrayColorLUT(values=lut)
    if lut is not None:
        return lut
    ramp = np.linspace(0.0, 1.0, 256, dtype='float32')
    # Same ramp in all three channels gives a black-to-white gradient.
    gray = np.column_stack([ramp, ramp, ramp])
    return ur.ArrayColorLUT(values=gray)

def as_field(mesh, values):
    """Return *values* as a ``ur.Field`` defined on *mesh*.

    An existing ``ur.Field`` is passed through untouched; any other value is
    wrapped as ``ur.Field(mesh=mesh, values=values)``.
    """
    already_field = isinstance(values, ur.Field)
    return values if already_field else ur.Field(mesh=mesh, values=values)

def as_scalar(mesh, values, scalarmap):
    """Build a scalar widget from *values* and an optional *scalarmap*.

    - real number      -> ``ur.ScalarConstant(value=values)``
    - other non-None   -> ``ur.ScalarField`` over ``as_field(mesh, values)`` with
                          ``as_scalar_lut(scalarmap)`` as lookup table
    - ``None``         -> *scalarmap* if it is a ``ur.ScalarConstant``, else ``None``
    """
    if isinstance(values, numbers.Real):
        return ur.ScalarConstant(value=values)
    if values is None:
        # Without data, only a ready-made constant passed as scalarmap is usable.
        return scalarmap if isinstance(scalarmap, ur.ScalarConstant) else None
    field = as_field(mesh, values)
    lut = as_scalar_lut(scalarmap)
    return ur.ScalarField(field=field, lut=lut)

def as_color(mesh, values, colormap):
    """Build a color widget from *values* and an optional *colormap*.

    - non-None *values* -> ``ur.ColorField`` over ``as_field(mesh, values)`` with
                           ``as_color_lut(colormap)`` as lookup table
    - ``None`` *values* -> *colormap* if it is a ``ur.ColorConstant``, a new
                           ``ur.ColorConstant(color=colormap)`` if it is a string,
                           otherwise ``None``

    NOTE(review): a plain number in *values* is not turned into a constant
    color; it is wrapped as a field like any other data.
    """
    if values is None:
        if isinstance(colormap, ur.ColorConstant):
            return colormap
        if isinstance(colormap, str):
            return ur.ColorConstant(color=colormap)
        return None
    field = as_field(mesh, values)
    lut = as_color_lut(colormap)
    return ur.ColorField(field=field, lut=lut)

def as_restrict(mesh, restrict):
    """Turn per-cell indicator values into a restriction widget.

    Returns ``None`` when *restrict* is ``None``. Otherwise wraps *restrict* in an
    ``ur.IndicatorField`` (space ``"I3"``) and returns ``ur.ScalarIndicators``
    with ``value=1``. Presumably this keeps the cells whose indicator equals 1;
    confirm against unray.
    """
    if restrict is None:
        return None
    indicators = ur.IndicatorField(mesh=mesh, values=restrict, space="I3")
    return ur.ScalarIndicators(field=indicators, value=1)

def as_wireframe_params(wireframe):
    """Coerce *wireframe* into ``ur.WireframeParams``.

    - ``ur.WireframeParams`` -> returned as-is
    - dict                   -> used as keyword arguments to ``ur.WireframeParams``
    - ``None`` / bool        -> ``ur.WireframeParams(enable=...)``, disabled for ``None``
    - anything else          -> ``None``
    """
    if isinstance(wireframe, ur.WireframeParams):
        return wireframe
    if isinstance(wireframe, dict):
        return ur.WireframeParams(**wireframe)
    if wireframe is None or isinstance(wireframe, bool):
        # bool(None) is False, so None maps to a disabled wireframe.
        return ur.WireframeParams(enable=bool(wireframe))
    return None

def surf(mesh, values=None, colormap=None, restrict=None, wireframe=None):
    """Create an unray SurfacePlot from numpy arrays.

    Consider this an example, the object based
    API is more flexible than this function.

    *values*/*colormap* are converted with ``as_color``, *restrict* with
    ``as_restrict`` and *wireframe* with ``as_wireframe_params``. All three
    results are passed to ``ur.SurfacePlot``, even when they are ``None``.
    """
    assert isinstance(mesh, ur.Mesh)
    return ur.SurfacePlot(
        mesh=mesh,
        restrict=as_restrict(mesh, restrict),
        color=as_color(mesh, values, colormap),
        wireframe=as_wireframe_params(wireframe),
    )

def xray(mesh, density=None, densitymap=None, restrict=None, extinction=None):
    """Create an unray XrayPlot from numpy arrays.

    Consider this an example, the object based
    API is more flexible than this function.

    *density*/*densitymap* are converted with ``as_scalar`` (a real number gives
    a constant density). *extinction* is forwarded only when it is not ``None``,
    leaving the XrayPlot default in place otherwise.
    """
    assert isinstance(mesh, ur.Mesh)
    attribs = dict(
        mesh=mesh,
        restrict=as_restrict(mesh, restrict),
        density=as_scalar(mesh, density, densitymap),
    )
    if extinction is not None:
        attribs["extinction"] = extinction
    return ur.XrayPlot(**attribs)

def sumproj(mesh, values=None, colormap=None, restrict=None, exposure=None):
    """Create an unray SumPlot (sum projection) from numpy arrays.

    Consider this an example, the object based
    API is more flexible than this function.

    Parameters
    ----------
    mesh : ur.Mesh
        The mesh to render.
    values : array-like or ur.Field, optional
        Data whose emitted color is looked up through *colormap*.
    colormap : optional
        A color LUT, an array-like of colors, or, when *values* is ``None``,
        a ``ur.ColorConstant`` or a color string (e.g. ``"blue"``).
    restrict : array-like, optional
        Per-cell indicator values; converted with ``as_restrict``.
    exposure : float, optional
        Forwarded to ``ur.SumPlot`` only when not ``None``.

    Returns
    -------
    ur.SumPlot
    """
    assert isinstance(mesh, ur.Mesh)
    attribs = {}
    attribs["mesh"] = mesh
    attribs["restrict"] = as_restrict(mesh, restrict)
    # color and exposure are only passed when given, so SumPlot's own
    # defaults apply otherwise.
    color = as_color(mesh, values, colormap)
    if color is not None:
        attribs["color"] = color
    if exposure is not None:
        attribs["exposure"] = exposure
    return ur.SumPlot(**attribs)

def setup_renderer(group, camera_position=(10, 10, 10), light_position=(0, 10, 10), width=800, height=600, background='#eeeeee'):
    """Build a basic pythreejs renderer whose scene contains *group*.

    The scene holds a directional key light at *light_position*, an ambient
    light (intensity 0.5), a perspective camera at *camera_position* with
    aspect ``width / height``, and *group*. The returned ``three.Renderer``
    is driven by orbit controls on that camera.
    """
    camera = three.PerspectiveCamera(position=camera_position, aspect=width / height)
    lights = [
        three.DirectionalLight(position=light_position),
        three.AmbientLight(intensity=0.5),
    ]
    scene = three.Scene(children=lights + [camera, group], background=background)
    orbit = three.OrbitControls(camera)
    return three.Renderer(scene, camera, [orbit], width=width, height=height)


In [None]:
def display_plot(plot, **kwargs):
    """Put *plot* into a new pythreejs Group and return a renderer for it.

    Extra keyword arguments (camera/light positions, size, background) are
    forwarded to ``setup_renderer``. The returned renderer is a widget; evaluate
    it (e.g. as the last expression of a cell) to show it.
    """
    container = three.Group()
    container.add(plot)
    return setup_renderer(container, **kwargs)

# Importing dependencies

In [None]:
# We'll use numpy for representing raw arrays
import numpy as np

# ipywidgets is the framework for handling GUI elements
# and communication between the python and browser context
import ipywidgets as widgets

# An NDArrayWidget represents a numpy array mirrored
# to the browser context, improving sharing of data
#from ipydatawidgets import NDArrayWidget

# pythreejs provides a scenegraph by mirroring
# three.js objects as ipywidgets
import pythreejs as three

# Finally the unray library
import unray as ur

# Currently defined in the cells above; either move these helpers into unray or rewrite the demos to use the regular object API:
#from unray.lab import surf, xray, sumproj, display_plot

# Set up some data for testing
All unray plots need a mesh, given as an M x 3 `points` array of vertex coordinates and an N x 4 `cells` array holding the vertex indices of each tetrahedron. Data for continuous piecewise linear functions is passed as arrays of length M (one value per vertex). (Discontinuous DP1 or P0 functions are not yet supported.)

In [None]:
def single_tetrahedron():
    """Return ``(coordinates, cells, values)`` for a one-tetrahedron mesh.

    coordinates : (4, 3) float32, the origin and the three unit-axis points
    cells       : (1, 4) uint32, the single cell [0, 1, 2, 3]
    values      : (4,) float32, one value per vertex: [1, 3, 2, -1]

    Note the return order (points first) differs from ``load_data``.
    """
    coordinates = np.array(
        [[0, 0, 0],
         [1, 0, 0],
         [0, 1, 0],
         [0, 0, 1]],
        dtype="float32",
    )
    cells = np.array([[0, 1, 2, 3]], dtype="uint32")
    values = np.array([1, 3, 2, -1], dtype="float32")
    return coordinates, cells, values

def load_data(filename, function="wave"):
    """Load a tetrahedral mesh from an ``.npz`` file and build a demo field on it.

    Parameters
    ----------
    filename : str or path-like
        ``.npz`` archive containing a ``"cells"`` array (vertex indices of each
        cell) and a ``"points"`` array (one row of coordinates per vertex).
    function : {"wave", "dist", "const", "x"}, default "wave"
        Which per-vertex demo function to return:
        ``"wave"``  -- 2 + sin(4 * 2*pi * d / radius), a wave pattern radiating
                       from the model center (d = distance from center),
        ``"dist"``  -- distance from the model center, normalized to max 1.0,
        ``"const"`` -- 1.0 for every vertex,
        ``"x"``     -- the x coordinate of every vertex.

    Returns
    -------
    cells_array : int32 array
    points_array : float32 array
    function_array : one value per vertex (row of ``points_array``)

    Raises
    ------
    ValueError
        If *function* is not one of the names listed above.
    """
    if function not in ("wave", "dist", "const", "x"):
        raise ValueError(
            f"unknown demo function {function!r}; "
            "expected 'wave', 'dist', 'const' or 'x'"
        )

    # Use the NpzFile as a context manager so the file handle is closed;
    # astype() copies, so the arrays stay valid after the file is closed.
    with np.load(filename) as mesh_data:
        cells_array = mesh_data["cells"].astype(np.int32)
        points_array = mesh_data["points"].astype(np.float32)

    if function == "x":
        return cells_array, points_array, points_array[:, 0]
    if function == "const":
        return cells_array, points_array, np.ones(points_array.shape[0])

    # Distance of every vertex from the model center (mean vertex position)
    offsets = points_array - points_array.mean(axis=0)
    dist = np.sqrt((offsets ** 2).sum(axis=1))
    radius = dist.max()

    if function == "dist":
        return cells_array, points_array, dist / radius

    freq = 4
    func_wave = 2.0 + np.sin((freq * 2 * np.pi / radius) * dist)
    return cells_array, points_array, func_wave

# Example data
filename = "data/heart.npz"
#filename = "data/brain.npz"
#filename = "data/aneurysm.npz"
cells_array, points_array, function_array = load_data(filename)

# Alternative: a single tetrahedron for quick testing. single_tetrahedron()
# returns (points, cells, values), a different order than load_data().
#points_array, cells_array, function_array = single_tetrahedron()

# Data widgets
See the `unray.datawidgets` submodule for a list of all data widgets. These widgets are used to set up the input to unray plots.

In [None]:
#import unray.datawidgets
#unray.datawidgets??

# Plot widgets
See the `unray.plotwidgets` submodule for a list of all plot widgets. Each plot widget class (e.g. `XrayPlot`) corresponds to one type of visualization and has its own set of valid parameters.

In [None]:
#import unray.plotwidgets
#unray.plotwidgets??

# Efficiency and memory usage
In many places the unray API accepts plain numpy arrays. To save memory and network traffic (data copied between the Python and browser contexts), it is highly recommended to wrap the numpy arrays in data objects
before setting up the plot objects. This allows data to be shared between plot objects on the browser side, and fields to be updated simultaneously across multiple plots.

In [None]:
# Wrap the point and cell arrays once in a reusable Mesh object,
# so every plot below can share the same mesh data.
mesh = ur.Mesh(points=points_array, cells=cells_array)
D = 6  # Rough mesh diameter for positioning plots (not referenced by the cells below)

# Surface plot
The surface plot draws the facets of the mesh as solid, opaque surfaces.
It can display just the mesh, overlay a configurable wireframe,
or show a scalar field mapped to colors on its surface.
All plot objects also support restricting the plot to a subset of cells.

In [None]:
# Plain surface of the mesh: no data, no wireframe
plot = surf(mesh=mesh)
display_plot(plot, background="white")

In [None]:
# Same surface with the wireframe switched on (default wireframe parameters)
plot = surf(mesh=mesh, wireframe=True)
display_plot(plot, background="white")

In [None]:
# Constant light-red surface with a blue wireframe at opacity 0.3
wp = ur.WireframeParams(enable=True, color="#0000ff", opacity=0.3)
plot = surf(mesh, wireframe=wp, colormap="#ff8888")
display_plot(plot, background="white")

In [None]:
# Color the surface by the per-vertex field, using the default gray-ramp color LUT
plot = surf(mesh, function_array)
display_plot(plot, background="white")

In [None]:
# Restriction demo (this cell previously duplicated the one above): mark the
# cells whose centroid has a lower x coordinate than the model's mean x.
# The marked cells get indicator 1 and all others 0.
# NOTE(review): assumes unray shows only cells with indicator 1 and accepts
# int32 indicators; confirm against unray.
cell_centroids = points_array[cells_array].mean(axis=1)
cell_indicators = (cell_centroids[:, 0] < points_array[:, 0].mean()).astype("int32")
plot = surf(mesh, values=function_array, restrict=cell_indicators)
display_plot(plot, background="white")

# Xray
The X-ray plot is a simple direct volume rendering mode with pure absorption of background light at every point in the mesh. A scalar density field is integrated along the view ray behind each pixel to compute the opacity of the mesh.
The image projected to screen is then simply the background image scaled by the transparency of the mesh (transparency = 1 - opacity).

In [None]:
# X-ray of the mesh without specifying a density
plot = xray(mesh=mesh)
display_plot(plot, background="white")

In [None]:
# Constant density of 0.15 everywhere in the mesh
plot = xray(mesh=mesh, density=0.15)
display_plot(plot, background="white")

# Integration with ipywidgets
Attributes of the data and plot widgets can be linked to sliders and other GUI elements from ipywidgets for interactive control.

In [None]:
# Sliders controlling the constant density and the plot's extinction
density_slider = widgets.FloatSlider(value=1.0, min=0.0, max=2.0, description="Density")
extinction_slider = widgets.FloatSlider(value=1.0, min=-10.0, max=10.0, description="Extinction")

plot = xray(mesh, density=density_slider.value)
renderer = display_plot(plot, background="white")

# jslink keeps slider and plot attribute in sync on the browser side
widgets.jslink((density_slider, "value"), (plot.density, "value"))
widgets.jslink((extinction_slider, "value"), (plot, "extinction"))

widgets.VBox([renderer, density_slider, extinction_slider])

# Sum projection
The sum projection is a simple direct volume rendering mode with pure emission of light at every point in the mesh. The image projected to screen is then simply the integral of the emitted light along the view ray behind each pixel.

In [None]:
# Constant blue emission with exposure 10
plot = sumproj(mesh=mesh, exposure=10.0, colormap="blue")
display_plot(plot, background="black")

In [None]:
# Emission colored by the field through a two-entry (red, blue) color LUT
plot = sumproj(mesh, values=function_array, colormap=[[1, 0, 0], [0, 0, 1]], exposure=-0.0)
display_plot(plot, background="black")

In [None]:
# Constant red emission with exposure 10
plot = sumproj(mesh=mesh, exposure=10.0, colormap="red")
display_plot(plot, background="black")

In [None]:
# Constant green emission with exposure -10
plot = sumproj(mesh=mesh, exposure=-10.0, colormap="green")
display_plot(plot, background="black")

# That's all!