Skip to content

Commit

Permalink
DAG.plot no longer uses matplotlib, using IPython.display.Image instead
Browse files Browse the repository at this point in the history
  • Loading branch information
Eduardo Blancas Reyes committed Oct 19, 2020
1 parent 6bf702d commit 7d57c21
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 52 deletions.
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ def read(*names, **kwargs):
# don't add it as such because it's gonna break installation for most
# setups, since we don't expect users to have graphviz installed
'pygraphviz',
# matplotlib only needed for dag.plot(output='matplotlib'),
'matplotlib',
# RemoteShellClient
'paramiko',
# Upload to S3
Expand Down
29 changes: 16 additions & 13 deletions src/ploomber/dag/DAG.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@
from tqdm.auto import tqdm
from jinja2 import Template
import mistune
from IPython.display import Image

from ploomber.Table import Table, TaskReport, BuildReport
from ploomber.products import MetaProduct
from ploomber.util import (image_bytes2html, isiterable, path2fig, requires,
markup)
from ploomber.util import (image_bytes2html, isiterable, requires, markup)
from ploomber import resources
from ploomber import executors
from ploomber.constants import TaskStatus, DAGStatus
Expand Down Expand Up @@ -550,7 +550,7 @@ def to_markup(self, path=None, fmt='html', sections=None):
"""
sections = sections or ['plot', 'status']

if fmt not in ['html', 'md']:
if fmt not in {'html', 'md'}:
raise ValueError('fmt must be html or md, got {}'.format(fmt))

if 'status' in sections:
Expand All @@ -559,8 +559,9 @@ def to_markup(self, path=None, fmt='html', sections=None):
status = False

if 'plot' in sections:
path_to_plot = Path(self.plot())
plot = image_bytes2html(path_to_plot.read_bytes())
_, path_to_plot = tempfile.mkstemp(suffix='.png')
self.plot(output=path_to_plot)
plot = image_bytes2html(Path(path_to_plot).read_bytes())
else:
plot = False

Expand Down Expand Up @@ -591,13 +592,13 @@ def to_markup(self, path=None, fmt='html', sections=None):
'dependency of "pygraphviz", the easiest way to install both is '
'through conda "conda install pygraphviz", for more options see: '
'https://graphviz.org/'))
def plot(self, output='tmp'):
def plot(self, output='embed'):
"""Plot the DAG
"""
if output in {'tmp', 'matplotlib'}:
path = tempfile.mktemp(suffix='.png')
if output == 'embed':
_, path = tempfile.mkstemp(suffix='.png')
else:
path = output
path = str(output)

# attributes docs:
# https://graphviz.gitlab.io/_pages/doc/info/attrs.html
Expand All @@ -617,10 +618,12 @@ def plot(self, output='tmp'):
G_ = nx.nx_agraph.to_agraph(G)
G_.draw(path, prog='dot', args='-Grankdir=LR')

if output == 'matplotlib':
return path2fig(path)
else:
return path
image = Image(filename=path)

if output == 'embed':
Path(path).unlink()

return image

def _add_task(self, task):
"""Adds a task to the DAG
Expand Down
9 changes: 5 additions & 4 deletions src/ploomber/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from ploomber.util.util import (safe_remove, image_bytes2html,
isiterable, path2fig,
from ploomber.util.util import (safe_remove, image_bytes2html, isiterable,
requires)
from ploomber.util import markup
from ploomber.util.param_grid import Interval, ParamGrid

__all__ = ['safe_remove', 'image_bytes2html', 'Interval',
'ParamGrid', 'isiterable', 'path2fig', 'requires', 'markup']
__all__ = [
'safe_remove', 'image_bytes2html', 'Interval', 'ParamGrid', 'isiterable',
'requires', 'markup'
]
20 changes: 0 additions & 20 deletions src/ploomber/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,26 +60,6 @@ def wrapper(*args, **kwargs):
return decorator


@requires(['matplotlib'])
def path2fig(path_to_image, dpi=50):
# FIXME: having this import at the top causes trouble with the
# multiprocessing library, moving it here solves the problem but we
# have to find a better solution.
# more info: https://stackoverflow.com/q/16254191/709975
import matplotlib.pyplot as plt

data = plt.imread(path_to_image)
height, width, _ = data.shape
fig = plt.figure()
fig.set_size_inches((width / dpi, height / dpi))
ax = plt.Axes(fig, [0, 0, 1, 1])
ax.set_axis_off()
fig.add_axes(ax)
ax.imshow(data)

return fig


def safe_remove(path):
if path.exists():
if path.is_file():
Expand Down
63 changes: 50 additions & 13 deletions tests/dag/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@

import pytest
import tqdm.auto
from IPython import display

from tests_util import executors_w_exception_logging
from ploomber import DAG
from ploomber import dag as dag_module
from ploomber.tasks import ShellScript, PythonCallable, SQLDump
from ploomber.products import File
from ploomber.constants import TaskStatus, DAGStatus
Expand All @@ -19,8 +21,6 @@
# TODO: a lot of these tests should be in a test_executor file
# since they test Errored or Executed status and the output errors, which
# is done by the executor
# TODO: test dag.plot(), create a function that returns an object and test
# such function, to avoid comparing images

# parametrize tests over these executors
_executors = [
Expand Down Expand Up @@ -87,20 +87,57 @@ def failing(upstream, product):
raise FailedTask('Bad things happened')


# can test this since this uses dag.plot(), which needs dot for plotting
# def test_to_html():
# def fn1(product):
# pass
@pytest.fixture
def dag():
def fn1(product):
pass

def fn2(upstream, product):
pass

dag = DAG()
t1 = PythonCallable(fn1, File('file1.txt'), dag)
t2 = PythonCallable(fn2, File('file2.txt'), dag)
t1 >> t2

return dag


def test_plot_embed(dag, monkeypatch):

obj = object()
mock = Mock(wraps=display.Image, return_value=obj)
monkeypatch.setattr(dag_module.DAG, 'Image', mock)

img = dag.plot()

kwargs = mock.call_args[1]
mock.assert_called_once()
assert set(kwargs) == {'filename'}
# file should not exist, it's just temporary
assert not Path(kwargs['filename']).exists()
assert img is obj


def test_plot_path(dag, tmp_directory, monkeypatch):

obj = object()
mock = Mock(wraps=display.Image, return_value=obj)
monkeypatch.setattr(dag_module.DAG, 'Image', mock)

img = dag.plot(output='pipeline.png')

# def fn2(product):
# pass
kwargs = mock.call_args[1]
mock.assert_called_once()
assert kwargs == {'filename': 'pipeline.png'}
assert Path('pipeline.png').exists()
assert img is obj

# dag = DAG()
# t1 = PythonCallable(fn1, File('file1.txt'), dag)
# t2 = PythonCallable(fn2, File('file2.txt'), dag)
# t1 >> t2

# dag.to_html('a.html')
@pytest.mark.parametrize('fmt', ['html', 'md'])
@pytest.mark.parametrize('sections', [None, 'plot', 'status', 'source'])
def test_to_markup(fmt, sections, dag):
dag.to_markup(fmt=fmt, sections=sections)


def test_count_in_progress_bar(monkeypatch, tmp_directory):
Expand Down

0 comments on commit 7d57c21

Please sign in to comment.