# Experimenting with visdom callbacks

## imports

In [1]:
from collections import ChainMap
from fnmatch import fnmatch
from functools import partial
import pickle
import time

In [2]:
from skorch.callbacks import Callback
from skorch.toy import make_classifier
from skorch import NeuralNetClassifier
from sklearn.datasets import make_classification
import numpy as np
import torch
import visdom

## data

In [3]:
# Seed torch's global RNG so that runs of this notebook are reproducible.
SEED = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x7ffad59a8670>

In [4]:
# Synthetic 5-class problem: 1000 samples, 20 features (10 informative),
# generated with a fixed random_state for reproducibility.
X, y = make_classification(
    n_samples=1000,
    n_features=20,
    n_informative=10,
    n_classes=5,
    random_state=42,
)

In [5]:
# float32 to match the default dtype of torch module weights.
X = np.asarray(X, dtype=np.float32)
# Relabel every third sample as class 0 to heavily skew the class balance.
y[0::3] = 0

## definitions

In [6]:
class FunctionCallback(Callback):
    """Callback that routes a single hook name to an arbitrary function.

    Looking up the attribute named ``method_name`` (e.g. ``'on_epoch_end'``)
    returns ``func`` itself, so code that calls ``cb.<method_name>(...)``
    ends up calling ``func(...)``. Every other attribute lookup behaves
    normally.

    Parameters
    ----------
    method_name : str
      Name of the callback hook to override.
    func : callable
      Function used for that hook. It is returned unbound, so it receives
      exactly the arguments passed to the hook (no ``self``).
    """
    def __init__(self, method_name, func):
        self.method_name = method_name
        self.func = func

    def __getattribute__(self, name):
        # Use object.__getattribute__ for every lookup so this method is never
        # re-entered (previously ``self.func`` recursed forever when
        # method_name == 'func').
        get = object.__getattribute__
        try:
            method_name = get(self, 'method_name')
        except AttributeError:
            # This may happen when object state is currently being restored
            # (e.g. while unpickling, before __dict__ is populated).
            return get(self, name)

        if name == method_name:
            return get(self, 'func')
        return get(self, name)

In [7]:
class _VisdomBase(Callback):
    """Shared plumbing for callbacks that plot to a visdom server.

    Keeps a lazily created ``visdom.Visdom`` client in ``_vis`` and a mapping
    from plot names to visdom window ids in ``windows_``. The client cannot
    be pickled, so ``__getstate__`` drops it and ``vis_()`` recreates it on
    the next use. Subclasses may set ``visdom_kwargs`` to configure the
    client.
    """
    def initialize(self):
        self._vis = None
        self.windows_ = {}
        return self

    def _check_visdom(self, vis):
        # TODO: friendly error message should visdom not be running
        return vis

    def vis_(self):
        """Return the visdom client, creating and caching it on first use.

        All this hassle is needed so that every reference to visdom can be
        deleted in ``__getstate__``, since the client is not picklable.
        """
        if self._vis is not None:
            return self._vis

        import visdom
        kwargs = getattr(self, 'visdom_kwargs', None) or {}
        self._vis = self._check_visdom(visdom.Visdom(**kwargs))
        # Return the checked client so the first call agrees with later ones.
        return self._vis

    def __getstate__(self):
        state = self.__dict__.copy()
        state['_vis'] = None  # cannot be pickled
        return state

In [8]:
class _VisdomHistoryLinePlotMixin:
    def __init__(
            self,
            keys=None,
            opts=None,
            xlabel='epoch',
            ylabel='loss',
            visdom_kwargs=None,
    ):
        self.keys = keys
        self.opts = opts
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.visdom_kwargs = visdom_kwargs

    def _format_keys(self, keys):
        if self.keys is None:
            return ('train_loss', 'valid_loss')
        if isinstance(self.keys, str):
            return (self.keys,)
        return tuple(self.keys)

    @property
    def opts_(self):
        defaults = {
            'legend': list(self.keys_),  # visdom doesn't accept tuples
            'xlabel': self.xlabel,
            'ylabel': self.ylabel,
        }
        return dict(ChainMap(self.opts or {}, defaults))

    def initialize(self):
        super().initialize()
        self.first_iteration_ = True
        self.keys_ = self._format_keys(self.keys)
        return self

    def plotfn(self):
        return self.vis_().line

    def plot(self, name, plot_params):
        plot_params['opts']['title'] = "train progress"
        plot_params['win'] = self.windows_.get(name)
        self.windows_[name] = self.plotfn()(**plot_params)
    
    def make_plot(self, net):
        X, Y = self._get_data(net.history)
        plot_params = {
            'X': X,
            'Y': Y,
            'opts': self.opts_,
            'update': 'append',
        }
        if self.first_iteration_:
            del plot_params['update']
            self.first_iteration_ = False
        self.plot(self.keys_, plot_params)
        
    def _get_data(self, history):
        raise NotImplementedError

In [9]:
class VisdomHistoryLinePlotter(_VisdomHistoryLinePlotMixin, _VisdomBase):
    """After every epoch, add the newest history values to a visdom line plot."""

    def on_epoch_end(self, net, **kwargs):
        self.make_plot(net)

    def _get_data(self, history):
        """Return ``(X, Y)``: the last epoch number and its ``keys_`` values."""
        last_row = slice(-1, None)
        epochs = np.atleast_1d(history[last_row, 'epoch'])
        values = np.atleast_1d(history[last_row, self.keys_])
        return epochs, values

In [10]:
def match(pattern):
    """Return a predicate ``name -> bool`` for selecting parameters.

    A callable ``pattern`` is returned unchanged. Anything else is treated as
    a shell-style glob and matched with ``fnmatch.fnmatch``.
    """
    if callable(pattern):
        return pattern
    return partial(fnmatch, pat=pattern)

In [11]:
class VisdomParamPlotter(_VisdomBase):
    """Plot a histogram of each module parameter whose name matches ``pattern``.

    Plots are refreshed from ``on_grad_computed``, with one visdom window per
    matching parameter name.

    Parameters
    ----------
    pattern : str or callable
      Shell-style glob, or a predicate ``name -> bool`` selecting parameters
      by name (see ``match``).
    opts : dict or None
      visdom ``opts`` used for every window. ``title`` is always replaced by
      the formatted parameter name.
    visdom_kwargs : dict or None
      Keyword arguments for ``visdom.Visdom``.
    """
    def __init__(
            self,
            pattern,
            opts=None,
            visdom_kwargs=None,
    ):
        self.pattern = pattern
        self.opts = opts
        self.visdom_kwargs = visdom_kwargs

    def initialize(self):
        super().initialize()
        self.match_ = match(self.pattern)
        self.defaults_ = {}
        self.opts_ = dict(ChainMap(self.opts or {}, self.defaults_))
        return self

    def plotfn(self):
        return self.vis_().histogram

    def plot(self, name, val, plot_params):
        plot_params['opts']['title'] = self._format_name(name)
        plot_params['win'] = self.windows_.get(name)
        self.windows_[name] = self._plot(val, plot_params)

    def _format_name(self, name):
        return name

    def _plot(self, val, plot_params):
        plotfn = self.plotfn()
        # Detach so the values can be converted to NumPy even though
        # parameters require grad.
        return plotfn(val.detach().view(-1), **plot_params)

    def on_grad_computed(self, net, named_parameters, **kwargs):
        for name, param in named_parameters:
            if not self.match_(name):
                continue
            # Build fresh params/opts per parameter. visdom may write
            # data-dependent defaults (e.g. 'numbins') into the opts dict it
            # receives, and a shared dict would leak them into the next plot.
            plot_params = {'opts': self.opts_.copy()}
            self.plot(name, param, plot_params)

In [12]:
class VisdomGradPlotter(VisdomParamPlotter):
    """Plot histograms of the gradients (``param.grad``) of matching parameters."""
    def _format_name(self, name):
        return "grad of '{}'".format(name)

    def plot(self, name, val, plot_params):
        # Parameters that received no gradient (e.g. ones with
        # requires_grad=False) have ``grad is None``; skip them instead of
        # failing on ``None.view`` in ``_plot``.
        if val.grad is None:
            return
        super().plot(name, val, plot_params)

    def _plot(self, val, plot_params):
        plotfn = self.plotfn()
        return plotfn(val.grad.view(-1), **plot_params)

In [13]:
class VisdomHeatmapPlotter(VisdomParamPlotter):
    """Plot matching parameters as visdom heatmaps.

    visdom heatmaps take 2-D input, so parameter values are reshaped first.
    2-D tensors are used as-is. 0-D and 1-D tensors become a single column.
    Tensors with more than two dimensions keep dim 0 and flatten the rest.
    """
    def plotfn(self):
        return self.vis_().heatmap

    def _plot(self, val, plot_params):
        plotfn = self.plotfn()
        # Detach so the values can be converted to NumPy even though
        # parameters require grad.
        data = val.detach()
        if data.dim() < 2:
            data = data.reshape(-1, 1)
        elif data.dim() > 2:
            data = data.reshape(data.shape[0], -1)
        return plotfn(data, **plot_params)

In [14]:
class VisdomBarPlotter(VisdomParamPlotter):
    """Plot matching parameters as visdom bar charts.

    Only the plotting function differs from ``VisdomParamPlotter``. The
    inherited ``_plot`` already flattens the value to 1-D before calling
    ``plotfn``, so it is not overridden here.
    """
    def plotfn(self):
        return self.vis_().bar

In [15]:
def sleep(*args, delay=0.25, **kwargs):
    """Block for ``delay`` seconds, ignoring all other arguments.

    Meant as a callback body (via ``FunctionCallback``) that slows training
    down so the live visdom plots are easier to follow. ``delay`` is
    keyword-only, so positional callback arguments such as the net can never
    bind to it. Use e.g. ``partial(sleep, delay=0.5)`` to change it.
    """
    time.sleep(delay)

In [16]:
def name_output_bias(name):
    """Return True if ``name`` is the bias of the submodule at index 3.

    In this notebook's classifier that is the output ``Linear`` layer
    (``sequential.3.bias``). Whole dotted components are compared, so names
    like ``sequential.13.bias`` are not matched; the previous
    ``endswith('3.bias')`` accepted them.
    """
    return name.split('.')[-2:] == ['3', 'bias']

## train model

Note: visdom appears to ignore 'ytickmin' unless 'ytickmax' is also specified, which is why the accuracy plot below sets both.

In [17]:
callbacks = [
    # Introduce delay for better observability.
    FunctionCallback('on_epoch_end', sleep),
    # Line plots: losses (default keys), validation accuracy, epoch duration.
    VisdomHistoryLinePlotter(opts={'xtickmin': 0, 'fillarea': True}),
    VisdomHistoryLinePlotter(
        keys='valid_acc',
        ylabel='accuracy',
        # visdom appears to ignore 'ytickmin' unless 'ytickmax' is also set.
        opts={'ytickmin': 0, 'ytickmax': 1},
    ),
    VisdomHistoryLinePlotter(keys='dur', ylabel='duration', opts={'markers': True}),
    # Gradient histograms; the explicit port only demonstrates visdom_kwargs.
    VisdomGradPlotter('*0.weight', visdom_kwargs={'port': 8097}),
    VisdomGradPlotter('*3.weight', opts={'numbins': 10}),
    # Output layer (index 3): bias as bars, weights as a heatmap.
    VisdomBarPlotter(name_output_bias),
    VisdomHeatmapPlotter('*3.weight'),
]
net = NeuralNetClassifier(
    make_classifier(input_units=20, output_units=5),
    lr=0.1,
    callbacks=callbacks,
)

In [18]:
# Train from scratch; the visdom callbacks update their windows every epoch.
net.fit(X=X, y=y)

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m1.5100[0m       [32m0.4877[0m        [35m1.4329[0m  0.6410
      2        [36m1.4014[0m       0.4680        [35m1.3756[0m  0.1945
      3        [36m1.3515[0m       0.4729        [35m1.3416[0m  0.1862
      4        [36m1.3199[0m       0.4729        [35m1.3136[0m  0.1753
      5        [36m1.2945[0m       0.4778        [35m1.2899[0m  0.1828
      6        [36m1.2725[0m       0.4680        [35m1.2686[0m  0.1765
      7        [36m1.2513[0m       0.4828        [35m1.2485[0m  0.1723
      8        [36m1.2319[0m       [32m0.4926[0m        [35m1.2319[0m  0.1813
      9        [36m1.2150[0m       0.4926        [35m1.2183[0m  0.1758
     10        [36m1.1999[0m       [32m0.5025[0m        [35m1.2065[0m  0.1680


<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=MLPModule(
    (nonlin): ReLU()
    (output_nonlin): Softmax()
    (sequential): Sequential(
      (0): Linear(in_features=20, out_features=10, bias=True)
      (1): ReLU()
      (2): Dropout(p=0)
      (3): Linear(in_features=10, out_features=5, bias=True)
      (4): Softmax()
    )
  ),
)

#### test pickling round-trip

In [19]:
# Round-trip the whole net through pickle. The visdom callbacks drop their
# client in __getstate__ (it cannot be pickled) and recreate it lazily on the
# next plot, so training can continue on the restored net below.
dump = pickle.dumps(net)
net = pickle.loads(dump)

In [20]:
# Continue training the unpickled net; plots should keep appending.
net.partial_fit(X=X, y=y)

     11        [36m1.1851[0m       0.5025        [35m1.1949[0m  0.5974
     12        [36m1.1702[0m       [32m0.5123[0m        [35m1.1837[0m  0.1801
     13        [36m1.1550[0m       [32m0.5172[0m        [35m1.1721[0m  0.1808
     14        [36m1.1411[0m       [32m0.5222[0m        [35m1.1626[0m  0.1852
     15        [36m1.1282[0m       [32m0.5271[0m        [35m1.1544[0m  0.1782
     16        [36m1.1170[0m       [32m0.5369[0m        [35m1.1483[0m  0.1786
     17        [36m1.1075[0m       0.5320        [35m1.1432[0m  0.1637
     18        [36m1.0984[0m       0.5271        [35m1.1381[0m  0.1790
     19        [36m1.0896[0m       0.5222        [35m1.1336[0m  0.1731
     20        [36m1.0816[0m       0.5222        [35m1.1297[0m  0.1777


<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=MLPModule(
    (nonlin): ReLU()
    (output_nonlin): Softmax()
    (sequential): Sequential(
      (0): Linear(in_features=20, out_features=10, bias=True)
      (1): ReLU()
      (2): Dropout(p=0)
      (3): Linear(in_features=10, out_features=5, bias=True)
      (4): Softmax()
    )
  ),
)

In [21]:
# Display visdom.png, presumably a saved screenshot of the visdom dashboard
# from the runs above (TODO confirm the file is kept next to the notebook).
from IPython.display import Image
Image(url='visdom.png')