diff --git a/livelossplot/core.py b/livelossplot/core.py index 6b8ab76..a9f7896 100644 --- a/livelossplot/core.py +++ b/livelossplot/core.py @@ -1,14 +1,16 @@ from __future__ import division +import warnings import matplotlib import matplotlib.pyplot as plt from IPython.display import clear_output -def check_inline(): - return "backend_inline" in matplotlib.get_backend() +def not_inline_warning(): + backend = matplotlib.get_backend() + if "backend_inline" not in backend: + warnings.warn("livelossplot requires inline plots.\nYour current backend is: {}\nRun in a Jupyter environment and execute '%matplotlib inline'.".format(backend)) # TODO -# * check backend # * object-oriented API # * only integer ticks diff --git a/livelossplot/generic_plot.py b/livelossplot/generic_plot.py index ff4221c..1f71d25 100644 --- a/livelossplot/generic_plot.py +++ b/livelossplot/generic_plot.py @@ -1,6 +1,6 @@ from __future__ import division -from .core import draw_plot +from .core import draw_plot, not_inline_warning class PlotLosses(): def __init__(self, figsize=None, cell_size=(6, 4), dynamic_x_axis=False, max_cols=2, max_epoch=None, metric2title={}, @@ -14,6 +14,8 @@ def __init__(self, figsize=None, cell_size=(6, 4), dynamic_x_axis=False, max_col self.validation_fmt = validation_fmt self.logs = None + not_inline_warning() + def set_metrics(self, metrics): self.base_metrics = metrics if self.figsize is None: diff --git a/livelossplot/keras_plot.py b/livelossplot/keras_plot.py index bd4ec70..a0b1c64 100644 --- a/livelossplot/keras_plot.py +++ b/livelossplot/keras_plot.py @@ -1,7 +1,7 @@ from __future__ import division from keras.callbacks import Callback -from .core import draw_plot +from .core import draw_plot, not_inline_warning metric2printable = { "acc": "Accuracy", @@ -31,6 +31,8 @@ def __init__(self, figsize=None, cell_size=(6, 4), dynamic_x_axis=False, max_col self.max_cols = max_cols self.metric2printable = metric2printable.copy() + not_inline_warning() + def on_train_begin(self, logs={}): self.base_metrics = [metric for metric in self.params['metrics'] if not metric.startswith('val_')] if self.figsize is None: