In [None]:
import numpy as np
import ipywidgets as widgets
from ipywidgets import Widget, DOMWidget, register, widget_serialization
from traitlets import Unicode, CInt, CFloat, CBool, Enum, Instance, List, Dict
from traittypes import Array

# TODO: These NumPy trait helpers should live in a shared module instead of unray
from unray.traits_numpy import array_serialization, shape_constraints

In [None]:
# Setup some test data: a single tetrahedron with four vertices
coordinates = [
    [0, 0, 0],
    [1, 0, 0],
    [0, 1, 0],
    [.5, .5, 1],
]

# Vertex indices of each cell (here one tetrahedron)
cells = [
    [0, 1, 2, 3],
]

# One indicator label per cell
indicators = [7]

# One scalar value per vertex
values = [0.0, 0.2, 0.5, 1.0]

# Raw lookup-table data: scalar entries and two RGB colors.
# (Previously misspelled `color_lot`; the name was not used anywhere.)
lut = [0.0, 1.0]

color_lut = [(1, 0, 0), (0, 0, 1)]


In [None]:
# Generic traits

# TODO: Make proper traits with validation for these or use traits defined somewhere else?


def Color(default_value=(0.0, 0.0, 0.0)):
    """Build a trait holding an RGB color as a list of exactly three floats.

    Each component is coerced with ``CFloat``; the length is fixed at 3.

    Parameters
    ----------
    default_value : sequence of 3 numbers
        Initial value of the trait, ``(0.0, 0.0, 0.0)`` unless given.
    """
    component = CFloat()
    return List(
        trait=component,
        default_value=default_value,
        minlen=3,
        maxlen=3,
    )


def Range(default_value=(0.0, 1.0)):
    """Build a trait holding a pair of floats as a list of exactly two items.

    Presumably interpreted as ``(low, high)``; the ordering is not validated.

    Parameters
    ----------
    default_value : sequence of 2 numbers
        Initial value of the trait, ``(0.0, 1.0)`` unless given.
    """
    bound = CFloat()
    return List(
        trait=bound,
        default_value=default_value,
        minlen=2,
        maxlen=2,
    )


def CellType(default_value="tetrahedron"):
    """Build an ``Enum`` trait restricted to the supported mesh cell types.

    Parameters
    ----------
    default_value : str
        Initial cell type, ``"tetrahedron"`` unless given.
    """
    # TODO: Consider a dedicated trait class with additional validation.
    supported = [
        "point", "line",
        "triangle", "quadrilateral",
        "tetrahedron", "hexahedron", "wedge",
    ]
    return Enum(supported, default_value=default_value)


In [None]:
# Generic visualization data widgets

module_name = "datawidgets"
module_version = "0.1.0"

class DataWidget(Widget):
    """Base class for data widgets served by the ``datawidgets`` frontend module.

    Sets the model and view module name/version of every subclass to the
    values above.
    """
    _view_module = Unicode(module_name).tag(sync=True)
    _model_module = Unicode(module_name).tag(sync=True)
    _view_module_version = Unicode(module_version).tag(sync=True)
    _model_module_version = Unicode(module_version).tag(sync=True)

    def _force_update_array(self, name, array):
        """Assign ``array`` to the trait ``name`` and sync it to the frontend.

        If ``array`` is the very object already stored (for example after an
        in-place modification), the assignment alone may not produce a sync
        message, so the state for ``name`` is then sent explicitly.

        Parameters
        ----------
        name : str
            Name of the synced array trait to update.
        array : array-like
            New value for the trait.
        """
        old = getattr(self, name)
        force = old is array
        # Bug fix: the target object was missing (setattr(name, array)),
        # which always raised TypeError.
        setattr(self, name, array)
        if force:
            self.send_state(name)


@register
class LUT(DataWidget):
    """Base class for lookup tables; the table entries live in ``values``.

    No ``_model_name`` is set here; the concrete lookup-table classes set
    their own.
    """

    values = Array().tag(sync=True, **array_serialization)


@register
class ScalarLUT(LUT):
    """Lookup table with scalar entries, e.g. ``ScalarLUT(values=[0.0, 1.0])``.

    ``interpolation`` is ``"nearest"`` or ``"linear"`` (the default).
    """

    _model_name = Unicode("ScalarLUTModel").tag(sync=True)

    interpolation = Enum(
        ["nearest", "linear"],
        default_value="linear",
    ).tag(sync=True)


@register
class ColorLUT(LUT):
    """Lookup table with RGB entries, e.g. ``ColorLUT(values=[(1, 0, 0), (0, 0, 1)])``.

    ``interpolation`` is ``"nearest"`` or ``"linear"`` (the default).
    """

    _model_name = Unicode("ColorLUTModel").tag(sync=True)

    interpolation = Enum(
        ["nearest", "linear"],
        default_value="linear",
    ).tag(sync=True)

@register
class NominalScalarLUT(LUT):
    """Scalar lookup table for nominal data (used as ``ColorChannel.hue_lut``)."""

    _model_name = Unicode("NominalScalarLUTModel").tag(sync=True)


@register
class NominalColorLUT(LUT):
    """Color lookup table variant for nominal data; entries in ``values``."""

    _model_name = Unicode("NominalColorLUTModel").tag(sync=True)



@register
class Points(DataWidget):
    """Point coordinates synced as an array (used as ``Mesh.vertices``)."""

    _model_name = Unicode("PointsModel").tag(sync=True)

    coordinates = Array().tag(sync=True, **array_serialization)

    def update(self, coordinates):
        """Replace ``coordinates`` via ``_force_update_array``."""
        trait_name = "coordinates"
        self._force_update_array(trait_name, coordinates)


@register
class MeshCells(DataWidget):
    """Cell connectivity array of a mesh together with its ``celltype``.

    Each row of ``cells`` presumably lists the vertex indices of one cell.
    """

    _model_name = Unicode("MeshCellsModel").tag(sync=True)

    celltype = CellType().tag(sync=True)
    cells = Array().tag(sync=True, **array_serialization)

    def update(self, cells):
        """Replace ``cells`` via ``_force_update_array``."""
        trait_name = "cells"
        self._force_update_array(trait_name, cells)


@register
class Mesh(DataWidget):
    """Mesh built from a ``MeshCells`` widget and a ``Points`` widget."""

    _model_name = Unicode("MeshModel").tag(sync=True)

    # Sub-widgets are synced as widget references.
    cells = Instance(MeshCells).tag(sync=True, **widget_serialization)
    vertices = Instance(Points).tag(sync=True, **widget_serialization)


@register
class Field(DataWidget):
    """Field of ``values`` defined on ``mesh`` in the function space ``space``.

    ``space`` defaults to ``"P1"``. The notebook's examples also use ``"P0"``
    (cellwise constant) and ``"D1"`` (discontinuous linear).
    """

    _model_name = Unicode("FieldModel").tag(sync=True)

    mesh = Instance(Mesh).tag(sync=True, **widget_serialization)
    space = Unicode("P1").tag(sync=True)
    values = Array().tag(sync=True, **array_serialization)

    def update(self, values):
        """Replace ``values`` via ``_force_update_array``."""
        trait_name = "values"
        self._force_update_array(trait_name, values)


@register
class NominalField(DataWidget):
    """Nominal (label/indicator) ``values`` attached to entities of ``mesh``.

    ``dim`` presumably selects the entity dimension. The notebook's examples
    use ``dim=3`` for cell indicators and ``dim=2`` for facet indicators.
    """

    _model_name = Unicode("NominalFieldModel").tag(sync=True)

    mesh = Instance(Mesh).tag(sync=True, **widget_serialization)
    dim = CInt().tag(sync=True)
    values = Array().tag(sync=True, **array_serialization)

    def update(self, values):
        """Replace ``values`` via ``_force_update_array``."""
        trait_name = "values"
        self._force_update_array(trait_name, values)


In [None]:
# Generic plot widgets, modeled on the bqplot classes: http://bqplot.readthedocs.io/en/stable/

@register
class Scale(Widget):
    """Abstract base class for plot scales (cf. bqplot scales).

    NOTE(review): no ``_model_name``/``_model_module`` is defined here or in
    the subclasses' modules, so the defaults inherited from ``Widget`` apply.
    Confirm the frontend provides these models.
    """

@register
class LinearScale(Scale):
    """Linear plot scale."""

    _model_name = Unicode("LinearScaleModel").tag(sync=True)

@register
class LogScale(Scale):
    """Logarithmic plot scale."""

    _model_name = Unicode("LogScaleModel").tag(sync=True)


@register
class Axis(Widget):
    """Plot axis attached to a ``Scale`` widget."""

    _model_name = Unicode("AxisModel").tag(sync=True)

    scale = Instance(Scale).tag(sync=True, **widget_serialization)

@register
class Grid(Widget):
    """Collection of ``Axis`` widgets, synced as widget references."""

    _model_name = Unicode("GridModel").tag(sync=True)

    axes = List(Instance(Axis)).tag(sync=True, **widget_serialization)

#class Mark(pythreejs.Group):
@register
class Mark(Widget):
    """Abstract base class for plot marks (cf. bqplot marks).

    TODO: Possibly derive from ``pythreejs.Group`` instead of ``Widget``.
    """

#class Figure(pythreejs.Group):
@register
class Figure(DOMWidget):
    """Displayable figure combining a ``grid``, a list of ``marks`` and a ``title``.

    TODO: Possibly derive from ``pythreejs.Group`` instead of ``DOMWidget``.
    """

    _model_name = Unicode("FigureModel").tag(sync=True)
    _view_name = Unicode("FigureView").tag(sync=True)

    grid = Instance(Grid).tag(sync=True, **widget_serialization)
    marks = List(Instance(Mark)).tag(sync=True, **widget_serialization)
    title = Unicode().tag(sync=True)

    # TODO: Possibly add an _ipython_display_ that creates a renderer and
    # displays that instead.


In [None]:
# Unray specific plot channels

module_name = "unray"
module_version = "0.1.0"


class UnrayChannel(Widget):
    """Base class for plot channels served by the ``unray`` frontend module.

    Sets the model and view module name/version of every subclass to the
    values above.
    """

    _view_module = Unicode(module_name).tag(sync=True)
    _model_module = Unicode(module_name).tag(sync=True)
    _view_module_version = Unicode(module_version).tag(sync=True)
    _model_module_version = Unicode(module_version).tag(sync=True)


class MarkersChannel(UnrayChannel):
    """Channel driven by a nominal ``field`` with integer enable/disable lists.

    Presumably ``enable_values`` / ``disable_values`` select which field
    values are included / excluded.

    NOTE(review): unlike most widgets here, this class is not ``@register``-ed,
    and its model name lacks the ``Model`` suffix. Confirm against the frontend.
    """

    _model_name = Unicode("MarkersChannel").tag(sync=True)

    field = Instance(NominalField, allow_none=True).tag(
        sync=True, **widget_serialization
    )
    enable_values = List(CInt()).tag(sync=True)
    disable_values = List(CInt()).tag(sync=True)


class ScalarChannel(UnrayChannel):
    """Scalar channel: a ``constant`` (default 0.0) or a ``field`` with a ``lut``.

    ``field`` and ``lut`` may be ``None``.
    """

    _model_name = Unicode("ScalarChannel").tag(sync=True)

    constant = CFloat(0.0).tag(sync=True)
    field = Instance(Field, allow_none=True).tag(
        sync=True, **widget_serialization
    )
    lut = Instance(ScalarLUT, allow_none=True).tag(
        sync=True, **widget_serialization
    )


class ColorChannel(UnrayChannel):
    """Color channel supporting two alternative color models.

    RGB model: a fixed ``constant`` color, or a scalar ``field`` mapped
    through the color lookup table ``lut``, plus a ``range``.

    HSL model: separate hue, saturation and luminance subchannels. Each is
    given by a ``*_constant`` or by a ``*_field`` mapped through a ``*_lut``.
    Saturation and luminance also take a ``*_range``.

    TODO: Validate that data for only one model is provided.
    """

    _model_name = Unicode("ColorChannel").tag(sync=True)

    # TODO: Possibly an explicit model selector, e.g.
    # model = Unicode("rgb").tag(sync=True)

    # --- RGB model ---
    constant = Color().tag(sync=True)
    field = Instance(Field, allow_none=True).tag(
        sync=True, **widget_serialization
    )
    lut = Instance(ColorLUT, allow_none=True).tag(
        sync=True, **widget_serialization
    )
    range = Range().tag(sync=True)

    # --- HSL model: hue (from a nominal field and lookup table) ---
    hue_constant = CFloat(0.0).tag(sync=True)
    hue_field = Instance(NominalField, allow_none=True).tag(
        sync=True, **widget_serialization
    )
    hue_lut = Instance(NominalScalarLUT, allow_none=True).tag(
        sync=True, **widget_serialization
    )

    # --- HSL model: saturation ---
    sat_constant = CFloat(0.0).tag(sync=True)
    sat_field = Instance(Field, allow_none=True).tag(
        sync=True, **widget_serialization
    )
    sat_lut = Instance(ScalarLUT, allow_none=True).tag(
        sync=True, **widget_serialization
    )
    sat_range = Range().tag(sync=True)

    # --- HSL model: luminance ---
    lum_constant = CFloat(0.0).tag(sync=True)
    lum_field = Instance(Field, allow_none=True).tag(
        sync=True, **widget_serialization
    )
    lum_lut = Instance(ScalarLUT, allow_none=True).tag(
        sync=True, **widget_serialization
    )
    lum_range = Range().tag(sync=True)


class WireframeChannel(UnrayChannel):
    """Wireframe overlay settings: ``enable`` flag, line ``width`` and ``color``.

    Defaults: enabled, width 0.01, color (0, 0, 0).
    """

    _model_name = Unicode("WireframeChannel").tag(sync=True)

    enable = CBool(True).tag(sync=True)
    width = CFloat(0.01).tag(sync=True)
    color = Color().tag(sync=True)


In [None]:
# Unray specific plot marks

module_name = "unray"
module_version = "0.1.0"


class UnrayMark(Mark):
    """Base class for unray marks: a ``mesh`` plus an optional ``markers`` channel.

    Model and view module names/versions are set to the values above.
    """

    _view_module = Unicode(module_name).tag(sync=True)
    _model_module = Unicode(module_name).tag(sync=True)
    _view_module_version = Unicode(module_version).tag(sync=True)
    _model_module_version = Unicode(module_version).tag(sync=True)

    mesh = Instance(Mesh).tag(sync=True, **widget_serialization)
    markers = Instance(MarkersChannel, allow_none=True).tag(
        sync=True, **widget_serialization
    )


@register
class Surface(UnrayMark):
    """Surface mark colored by a ``ColorChannel``, with an optional wireframe."""

    _model_name = Unicode("SurfaceModel").tag(sync=True)

    color = Instance(ColorChannel).tag(sync=True, **widget_serialization)
    wireframe = Instance(WireframeChannel, allow_none=True).tag(
        sync=True, **widget_serialization
    )

    # TODO: Decide on the shading models (e.g. flat, phong, depth, gradient,
    # noise); currently a free-form string defaulting to "flat".
    shading = Unicode("flat").tag(sync=True)


@register
class Scatter(UnrayMark):
    """Scatter mark with ``size`` and ``color`` channels."""

    _model_name = Unicode("ScatterModel").tag(sync=True)

    size = Instance(ScalarChannel).tag(sync=True, **widget_serialization)
    color = Instance(ColorChannel).tag(sync=True, **widget_serialization)


@register
class Wireframe(UnrayMark):
    """Standalone wireframe mark with line ``width`` (default 0.01) and ``color``."""

    _model_name = Unicode("WireframeModel").tag(sync=True)

    width = CFloat(0.01).tag(sync=True)
    color = Color().tag(sync=True)


@register
class XRay(UnrayMark):
    """X-ray mark driven by a scalar ``density`` channel."""

    _model_name = Unicode("XRayModel").tag(sync=True)

    density = Instance(ScalarChannel).tag(sync=True, **widget_serialization)


@register
class MaxProjection(UnrayMark):
    """Maximum-projection mark driven by an ``emission`` color channel."""

    _model_name = Unicode("MaxProjectionModel").tag(sync=True)

    emission = Instance(ColorChannel).tag(sync=True, **widget_serialization)


@register
class MinProjection(UnrayMark):
    """Minimum-projection mark driven by an ``emission`` color channel."""

    # Bug fix: this previously reused 'MaxProjectionModel' (copy-paste from
    # MaxProjection). The frontend therefore could not tell the two marks
    # apart, and both classes registered under the same model name.
    # NOTE(review): confirm the frontend defines 'MinProjectionModel'.
    _model_name = Unicode('MinProjectionModel').tag(sync=True)

    emission = Instance(ColorChannel).tag(sync=True, **widget_serialization)


@register
class Splat(UnrayMark):
    """Splat mark driven by an ``emission`` color channel."""

    _model_name = Unicode("SplatModel").tag(sync=True)

    emission = Instance(ColorChannel).tag(sync=True, **widget_serialization)


@register
class Volume(UnrayMark):
    """Volume mark with ``density`` (scalar) and ``emission`` (color) channels."""

    _model_name = Unicode("VolumeModel").tag(sync=True)

    density = Instance(ScalarChannel).tag(sync=True, **widget_serialization)
    emission = Instance(ColorChannel).tag(sync=True, **widget_serialization)

    # TODO: There are multiple ways to approximate the emission-absorption
    # integral; possibly expose a selector, e.g.
    # method = Unicode("auto").tag(sync=True)


@register
class Isorange(UnrayMark):
    """Isorange mark: scalar ``field`` channel, ``color`` channel and a ``range``."""

    _model_name = Unicode("IsorangeModel").tag(sync=True)

    field = Instance(ScalarChannel).tag(sync=True, **widget_serialization)
    color = Instance(ColorChannel).tag(sync=True, **widget_serialization)

    # TODO: Maybe extend to multiple intervals, i.e.
    # (x - offset) in [n*dist, n*dist + width].
    range = Range().tag(sync=True)


@register
class Isosurface(UnrayMark):
    """Isosurface mark: scalar ``field`` channel, ``color`` channel and iso ``values``."""

    _model_name = Unicode("IsosurfaceModel").tag(sync=True)

    field = Instance(ScalarChannel).tag(sync=True, **widget_serialization)
    color = Instance(ColorChannel).tag(sync=True, **widget_serialization)

    # TODO: Support a list of values, or a range with strides (linear or log).
    values = List(trait=CFloat()).tag(sync=True)


In [None]:
# Data widgets for the test tetrahedron
mesh = Mesh(
    cells=MeshCells(cells=cells),
    vertices=Points(coordinates=coordinates),
)

field = Field(mesh=mesh, values=values)

# NOTE(review): `labels` reuses the float vertex `values`; a nominal field
# would normally hold integer labels. Confirm the intended data.
labels = NominalField(mesh=mesh, values=values)

# Lookup tables
lut = ScalarLUT(values=[0.0, 1.0])

color_lut = ColorLUT(values=[(1, 0, 0), (0, 0, 1)])

labels_lut = NominalScalarLUT(values=[0, 1])

# TODO: Make it possible to pass a label -> color mapping, e.g.
# labels_lut = NominalScalarLUT(values={ 3: "red", 2: "blue" })

In [None]:
# Color variant 1 (RGB model): a constant color.
# NOTE: the following cells rebind `color`; the marks created later use
# whichever variant was executed last.
color = ColorChannel()
color.constant = (1,0,0)

In [None]:
# Color variant 2 (RGB model): scalar `field` mapped through the color lookup table.
color = ColorChannel()
color.field = field
color.lut = color_lut

In [None]:
# Color variant 3 (HSL model): constant hue and luminance; saturation taken
# from `field` mapped through the scalar lookup table.
color = ColorChannel()
color.hue_constant = 0.0
color.sat_field = field
color.sat_lut = lut
color.lum_constant = 0.6

In [None]:
# Color variant 4 (HSL model): every subchannel driven by a field. Hue uses
# the nominal `labels` field and its lookup table; saturation and luminance
# use `field` with explicit ranges.
color = ColorChannel()
color.hue_field = labels
color.hue_lut = labels_lut
color.sat_field = field
color.sat_lut = lut
color.sat_range = (0.4, 0.8)
color.lum_field = field
color.lum_lut = lut
color.lum_range = (0.1, 0.9)

In [None]:
# Density variant 1: a constant value.
# NOTE: the next cell rebinds `density`; the XRay mark below uses whichever
# variant was executed last.
density = ScalarChannel()
density.constant = 0.1

In [None]:
# Density variant 2: `field` mapped through the scalar lookup table.
density = ScalarChannel()
density.field = field
density.lut = lut

In [None]:
# Marks rendering `mesh` with the most recently defined channels
surf = Surface(mesh=mesh, color=color)

xray = XRay(mesh=mesh, density=density)

mip = MaxProjection(mesh=mesh, emission=color)

In [None]:
# Setup grid
axes = []
# Pass the axes by keyword: widgets accept trait values only as keyword
# arguments, so Grid(*axes) would not assign them to Grid.axes.
grid = Grid(axes=axes)
# NOTE(review): the objects above may later become pythreejs.Object3D
# subclasses. Confirm the intended design.


# Setup figure
fig = Figure(
    grid=grid,
    marks=[surf, xray],
    title="Hello world",
    # TODO: Allow passing scene, renderer and camera, creating defaults
    # when they are not provided, e.g.
    #scene=pythreejs.Scene(...),
    #renderer=pythreejs.Renderer(...),
    #camera=pythreejs.Camera(...)
    )
#display(fig)

In [None]:
# Example configuration of data widgets


# Unstructured mesh. Mesh takes MeshCells/Points widgets, not raw arrays;
# the previous Mesh(cells=cells, coordinates=coordinates) raised a TraitError.
mesh = Mesh(
    cells=MeshCells(cells=cells),
    vertices=Points(coordinates=coordinates),
)
#mesh.vertices.update(new_coordinates)

# Cellwise constant field over mesh
# NOTE(review): a cellwise constant field presumably needs one value per
# cell (1 here), but `values` has one per vertex. Confirm the intended data.
field0 = Field(mesh=mesh, values=values, space="P0")
#field0.update(new_values)

# Cellwise continuous linear field over mesh
field1 = Field(mesh=mesh, values=values, space="P1")
#field1.update(new_values)

# Cellwise discontinuous linear field over mesh
field2 = Field(mesh=mesh, values=values, space="D1")
#field2.update(new_values)

# Cell indicator values over mesh
ind3 = NominalField(mesh=mesh, values=indicators, dim=3)
#ind3.update(new_values)

# Facet indicator values over mesh
ind2 = NominalField(mesh=mesh, values=values, dim=2)
#ind2.update(new_values)

# Lookup tables
lut0 = ScalarLUT(values=[0.2, 0.4, 0.7, 1.0], interpolation="nearest")
lut1 = ScalarLUT(values=[0, 1], interpolation="linear")
lut2 = ColorLUT(values=[[1,0,0], [0,0,1]], interpolation="linear")