Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 35 additions & 10 deletions livelossplot/outputs/matplotlib_subplots.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,36 @@
from typing import Literal, Optional

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap


class BaseSubplot:
def __init__(self):
pass
self.output_mode: Literal['notebook', 'script'] = 'notebook'

def draw(self, *args, **kwargs):
raise Exception("Not implemented")
raise NotImplementedError

def __call__(self, *args, **kwargs):
self.draw(*args, **kwargs)

def set_output_mode(self, mode: Literal['notebook', 'script']):
self.output_mode = mode

def _present(self, fig: plt.Figure):
"""Render fig appropriately for the current output mode."""
if self.output_mode == 'notebook':
try:
from IPython.display import clear_output
clear_output(wait=True)
except ImportError:
pass
plt.show()
else:
plt.draw()
plt.pause(0.05)


class LossSubplot(BaseSubplot):
"""To rewrire, this one now won't work"""
Expand All @@ -22,7 +40,7 @@ def __init__(
'validation': 'val_{}'
}, skip_first=2, max_epoch=None
):
super().__init__(self)
super().__init__()
self.metric = metric
self.title = title
self.series_fmt = series_fmt
Expand Down Expand Up @@ -63,7 +81,7 @@ def draw(self, logs):

class Plot1D(BaseSubplot):
def __init__(self, model, X, Y):
super().__init__(self)
super().__init__()
self.model = model
self.X = X
self.Y = Y
Expand All @@ -88,8 +106,6 @@ def __init__(self, model, X, Y, valiation_data=(None, None), h=0.02, margin=0.25
self.Y = Y
self.X_test, self.Y_test = valiation_data

# add size assertions

self.cm_bg = plt.cm.RdBu
self.cm_points = ListedColormap(['#FF0000', '#0000FF'])

Expand All @@ -103,6 +119,9 @@ def __init__(self, model, X, Y, valiation_data=(None, None), h=0.02, margin=0.25

self.torch_device = device

self._fig: Optional[plt.Figure] = None
self._ax: Optional[plt.Axes] = None

def _predict_pytorch(self, model, x_numpy):
import torch
x = torch.from_numpy(x_numpy).to(self.torch_device).float()
Expand All @@ -113,9 +132,15 @@ def predict(self, model, X):
return model.predict(X)

def send(self, logger):
Z = self._predict_pytorch(self.model, np.c_[self.xx.ravel(), self.yy.ravel()])[:, 1]
if self._fig is None or not plt.fignum_exists(self._fig.number):
self._fig, self._ax = plt.subplots()
else:
self._ax.clear()

Z = self.predict(self.model, np.c_[self.xx.ravel(), self.yy.ravel()])[:, 1]
Z = Z.reshape(self.xx.shape)
plt.contourf(self.xx, self.yy, Z, cmap=self.cm_bg, alpha=.8)
plt.scatter(self.X[:, 0], self.X[:, 1], c=self.Y, cmap=self.cm_points)
self._ax.contourf(self.xx, self.yy, Z, cmap=self.cm_bg, alpha=.8)
self._ax.scatter(self.X[:, 0], self.X[:, 1], c=self.Y, cmap=self.cm_points)
if self.X_test is not None:
plt.scatter(self.X_test[:, 0], self.X_test[:, 1], c=self.Y_test, cmap=self.cm_points, alpha=0.3)
self._ax.scatter(self.X_test[:, 0], self.X_test[:, 1], c=self.Y_test, cmap=self.cm_points, alpha=0.3)
self._present(self._fig)
Loading