Skip to content

Commit

Permalink
checks if matplotlib backend is inline
Browse files Browse the repository at this point in the history
  • Loading branch information
stared committed Mar 23, 2018
1 parent fb20506 commit 45b152d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 5 deletions.
8 changes: 5 additions & 3 deletions livelossplot/core.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import division
import warnings

import matplotlib
import matplotlib.pyplot as plt
from IPython.display import clear_output

def check_inline():
return "backend_inline" in matplotlib.get_backend()
def not_inline_warning():
backend = matplotlib.get_backend()
if "backend_inline" not in backend:
warnings.warn("livelossplot requires inline plots.\nYour current backend is: {}\nRun in a Jupyter environment and execute '%matplotlib inline'.".format(backend))

# TODO
# * check backend
# * object-oriented API
# * only integer ticks

Expand Down
4 changes: 3 additions & 1 deletion livelossplot/generic_plot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import division

from .core import draw_plot
from .core import draw_plot, not_inline_warning

class PlotLosses():
def __init__(self, figsize=None, cell_size=(6, 4), dynamic_x_axis=False, max_cols=2, max_epoch=None, metric2title={},
Expand All @@ -14,6 +14,8 @@ def __init__(self, figsize=None, cell_size=(6, 4), dynamic_x_axis=False, max_col
self.validation_fmt = validation_fmt
self.logs = None

not_inline_warning()

def set_metrics(self, metrics):
self.base_metrics = metrics
if self.figsize is None:
Expand Down
4 changes: 3 additions & 1 deletion livelossplot/keras_plot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import division

from keras.callbacks import Callback
from .core import draw_plot
from .core import draw_plot, not_inline_warning

metric2printable = {
"acc": "Accuracy",
Expand Down Expand Up @@ -31,6 +31,8 @@ def __init__(self, figsize=None, cell_size=(6, 4), dynamic_x_axis=False, max_col
self.max_cols = max_cols
self.metric2printable = metric2printable.copy()

not_inline_warning()

def on_train_begin(self, logs={}):
self.base_metrics = [metric for metric in self.params['metrics'] if not metric.startswith('val_')]
if self.figsize is None:
Expand Down

0 comments on commit 45b152d

Please sign in to comment.