Skip to content

Commit

Permalink
add 'LearnerND.plot_3D' and add an example to the docs
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Oct 20, 2018
1 parent df84178 commit 1119d95
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 1 deletion.
65 changes: 64 additions & 1 deletion adaptive/learner/learnerND.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from .base_learner import BaseLearner

from ..notebook_integration import ensure_holoviews
from ..notebook_integration import ensure_holoviews, ensure_plotly
from .triangulation import (Triangulation, point_in_simplex,
circumsphere, simplex_volume_in_embedding)
from ..utils import restore, cache_latest
Expand Down Expand Up @@ -585,6 +585,69 @@ def plot_slice(self, cut_mapping, n=None):
else:
raise ValueError("Only 1 or 2-dimensional plots can be generated.")

def plot_3D(self, with_triangulation=False):
"""Plot the learner's data in 3D using plotly.
Parameters
----------
with_triangulation : bool, default: False
Add the verticices to the plot.
Returns
-------
plot : plotly.offline.iplot object
The 3D plot of ``learner.data``.
"""
plotly = ensure_plotly()

plots = []

vertices = self.tri.vertices
if with_triangulation:
Xe, Ye, Ze = [], [], []
for simplex in self.tri.simplices:
for s in itertools.combinations(simplex, 2):
Xe += [vertices[i][0] for i in s] + [None]
Ye += [vertices[i][1] for i in s] + [None]
Ze += [vertices[i][2] for i in s] + [None]

plots.append(plotly.graph_objs.Scatter3d(
x=Xe, y=Ye, z=Ze, mode='lines',
line=dict(color='rgb(125,125,125)', width=1),
hoverinfo='none'
))

Xn, Yn, Zn = zip(*vertices)
colors = [self.data[p] for p in self.tri.vertices]
marker = dict(symbol='circle', size=3, color=colors,
colorscale='Viridis',
line=dict(color='rgb(50,50,50)', width=0.5))

plots.append(plotly.graph_objs.Scatter3d(
x=Xn, y=Yn, z=Zn, mode='markers',
name='actors', marker=marker,
hoverinfo='text'
))

axis = dict(
showbackground=False,
showline=False,
zeroline=False,
showgrid=False,
showticklabels=False,
title='',
)

layout = plotly.graph_objs.Layout(
showlegend=False,
scene=dict(xaxis=axis, yaxis=axis, zaxis=axis),
margin=dict(t=100),
hovermode='closest')

fig = plotly.graph_objs.Figure(data=plots, layout=layout)

return plotly.offline.iplot(fig)

def _get_data(self):
return self.data

Expand Down
17 changes: 17 additions & 0 deletions docs/source/docs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,23 @@ on the *Play* :fa:`play` button or move the sliders.
plots = {n: plot(learner, n) for n in range(10, 10000, 200)}
hv.HoloMap(plots, kdims=['npoints'])

`adaptive.LearnerND`
~~~~~~~~~~~~~~~~~~~~

.. jupyter-execute::
:hide-code:

def sphere(xyz):
import numpy as np
x, y, z = xyz
a = 0.4
return np.exp(-(x**2 + y**2 + z**2 - 0.75**2)**2/a**4)

learner = adaptive.LearnerND(sphere, bounds=[(-1, 1), (-1, 1), (-1, 1)])
adaptive.runner.simple(learner, lambda l: l.npoints == 3000)

learner.plot_3D()

see more in the :ref:`Tutorial Adaptive`.

.. include:: ../../README.rst
Expand Down

0 comments on commit 1119d95

Please sign in to comment.