diff --git a/livelossplot/outputs/matplotlib_subplots.py b/livelossplot/outputs/matplotlib_subplots.py index 6057650..b11adc8 100644 --- a/livelossplot/outputs/matplotlib_subplots.py +++ b/livelossplot/outputs/matplotlib_subplots.py @@ -1,3 +1,5 @@ +from typing import Literal, Optional + import matplotlib.pyplot as plt import numpy as np from matplotlib.colors import ListedColormap @@ -5,14 +7,30 @@ class BaseSubplot: def __init__(self): - pass + self.output_mode: Literal['notebook', 'script'] = 'notebook' def draw(self, *args, **kwargs): - raise Exception("Not implemented") + raise NotImplementedError def __call__(self, *args, **kwargs): self.draw(*args, **kwargs) + def set_output_mode(self, mode: Literal['notebook', 'script']): + self.output_mode = mode + + def _present(self, fig: plt.Figure): + """Render fig appropriately for the current output mode.""" + if self.output_mode == 'notebook': + try: + from IPython.display import clear_output + clear_output(wait=True) + except ImportError: + pass + plt.show() + else: + plt.draw() + plt.pause(0.05) + class LossSubplot(BaseSubplot): """To rewrire, this one now won't work""" @@ -22,7 +40,7 @@ def __init__( 'validation': 'val_{}' }, skip_first=2, max_epoch=None ): - super().__init__(self) + super().__init__() self.metric = metric self.title = title self.series_fmt = series_fmt @@ -63,7 +81,7 @@ def draw(self, logs): class Plot1D(BaseSubplot): def __init__(self, model, X, Y): - super().__init__(self) + super().__init__() self.model = model self.X = X self.Y = Y @@ -88,8 +106,6 @@ def __init__(self, model, X, Y, valiation_data=(None, None), h=0.02, margin=0.25 self.Y = Y self.X_test, self.Y_test = valiation_data - # add size assertions - self.cm_bg = plt.cm.RdBu self.cm_points = ListedColormap(['#FF0000', '#0000FF']) @@ -103,6 +119,9 @@ def __init__(self, model, X, Y, valiation_data=(None, None), h=0.02, margin=0.25 self.torch_device = device + self._fig: Optional[plt.Figure] = None + self._ax: Optional[plt.Axes] = None + def _predict_pytorch(self, model, x_numpy): import torch x = torch.from_numpy(x_numpy).to(self.torch_device).float() @@ -113,9 +132,15 @@ def predict(self, model, X): return model.predict(X) def send(self, logger): - Z = self._predict_pytorch(self.model, np.c_[self.xx.ravel(), self.yy.ravel()])[:, 1] + if self._fig is None or not plt.fignum_exists(self._fig.number): + self._fig, self._ax = plt.subplots() + else: + self._ax.clear() + + Z = self.predict(self.model, np.c_[self.xx.ravel(), self.yy.ravel()])[:, 1] Z = Z.reshape(self.xx.shape) - plt.contourf(self.xx, self.yy, Z, cmap=self.cm_bg, alpha=.8) - plt.scatter(self.X[:, 0], self.X[:, 1], c=self.Y, cmap=self.cm_points) + self._ax.contourf(self.xx, self.yy, Z, cmap=self.cm_bg, alpha=.8) + self._ax.scatter(self.X[:, 0], self.X[:, 1], c=self.Y, cmap=self.cm_points) if self.X_test is not None: - plt.scatter(self.X_test[:, 0], self.X_test[:, 1], c=self.Y_test, cmap=self.cm_points, alpha=0.3) + self._ax.scatter(self.X_test[:, 0], self.X_test[:, 1], c=self.Y_test, cmap=self.cm_points, alpha=0.3) + self._present(self._fig)