Feature/svm (#280)
* SVM initial commit

* Finalize SVM example
ethanwharris authored and MattPainter01 committed Aug 3, 2018
1 parent b0fea52 commit c3074a3
Showing 5 changed files with 237 additions and 2 deletions.
100 changes: 100 additions & 0 deletions docs/_static/examples/svm_linear.py
@@ -0,0 +1,100 @@
# Based on svm-pytorch (https://github.com/kazuto1011/svm-pytorch)

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_blobs  # the samples_generator path is deprecated in newer scikit-learn

import torchbearer
import torchbearer.callbacks as callbacks
from torchbearer import Model
from torchbearer.callbacks import L2WeightDecay, ExponentialLR


class LinearSVM(nn.Module):
    """Support Vector Machine"""

    def __init__(self):
        super(LinearSVM, self).__init__()
        self.w = nn.Parameter(torch.randn(1, 2), requires_grad=True)
        self.b = nn.Parameter(torch.randn(1), requires_grad=True)

    def forward(self, x):
        h = x.matmul(self.w.t()) + self.b
        return h


def hinge_loss(y_pred, y_true):
    return torch.mean(torch.clamp(1 - y_pred.t() * y_true, min=0))


X, Y = make_blobs(n_samples=1024, centers=2, cluster_std=1.2, random_state=1)
X = (X - X.mean()) / X.std()
Y[np.where(Y == 0)] = -1
X, Y = torch.FloatTensor(X), torch.FloatTensor(Y)


delta = 0.01
x = np.arange(X[:, 0].min(), X[:, 0].max(), delta)
y = np.arange(X[:, 1].min(), X[:, 1].max(), delta)
x, y = np.meshgrid(x, y)
xy = list(map(np.ravel, [x, y]))


def mypause(interval):
    backend = plt.rcParams['backend']
    if backend in matplotlib.rcsetup.interactive_bk:
        figManager = matplotlib._pylab_helpers.Gcf.get_active()
        if figManager is not None:
            canvas = figManager.canvas
            if canvas.figure.stale:
                canvas.draw_idle()
            canvas.start_event_loop(interval)
            return


@callbacks.on_start
def scatter(_):
    plt.figure(figsize=(5, 5))
    plt.ion()
    plt.scatter(x=X[:, 0], y=X[:, 1], c="black", s=10)


@callbacks.on_step_training
def draw_margin(state):
    if state[torchbearer.BATCH] % 10 == 0:
        w = state[torchbearer.MODEL].w[0].detach().to('cpu').numpy()
        b = state[torchbearer.MODEL].b[0].detach().to('cpu').numpy()

        z = (w.dot(xy) + b).reshape(x.shape)
        z[np.where(z > 1.)] = 4                 # outside the margin, positive side
        z[np.where((z > 0.) & (z <= 1.))] = 3   # inside the margin, positive side
        z[np.where((z > -1.) & (z <= 0.))] = 2  # inside the margin, negative side
        z[np.where(z <= -1.)] = 1               # outside the margin, negative side

        if 'contour' in state:
            for coll in state['contour'].collections:
                coll.remove()
            state['contour'] = plt.contourf(x, y, z, cmap=plt.cm.jet, alpha=0.5)
        else:
            state['contour'] = plt.contourf(x, y, z, cmap=plt.cm.jet, alpha=0.5)
            plt.tight_layout()
            plt.show()

        mypause(0.001)


svm = LinearSVM()
model = Model(svm, optim.SGD(svm.parameters(), 0.1), hinge_loss, ['loss']).to('cuda')

model.fit(X, Y, batch_size=32, epochs=50, verbose=1,
          callbacks=[scatter,
                     draw_margin,
                     ExponentialLR(0.999, step_on_batch=True),
                     L2WeightDecay(0.01, params=[svm.w])])

plt.ioff()
plt.show()
Binary file added docs/_static/img/svm_fit.gif
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -45,7 +45,7 @@ def __getattr__(cls, name):
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
-extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode', 'sphinx.ext.intersphinx']
+extensions = ['sphinx.ext.mathjax', 'sphinx.ext.autodoc', 'sphinx.ext.viewcode', 'sphinx.ext.intersphinx']

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
128 changes: 128 additions & 0 deletions docs/examples/svm_linear.rst
@@ -0,0 +1,128 @@
Linear Support Vector Machine (SVM)
===================================

We've seen how to frame a problem as a differentiable program in the `Optimising Functions example <./basic_opt.html>`_.
Now we can take a look at a more useful example: a linear Support Vector Machine (SVM). Note that the model and loss used
in this guide are based on the code found `here <https://github.com/kazuto1011/svm-pytorch>`_.

SVM Recap
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Recall that an SVM tries to find the maximum margin hyperplane which separates the data classes. For a soft margin SVM
where :math:`\textbf{x}` is our data, we minimize:

:math:`\left[\frac 1 n \sum_{i=1}^n \max\left(0, 1 - y_i(\textbf{w}\cdot \textbf{x}_i - b)\right) \right] + \lambda\lVert \textbf{w} \rVert^2`


We can formulate this as an optimization over our weights :math:`\textbf{w}` and bias :math:`b`, where we minimize the
hinge loss subject to an L2 weight decay term. The hinge loss for some model outputs
:math:`z = \textbf{w}\cdot\textbf{x} + b` with targets :math:`y` is given by:

:math:`\ell(y,z) = \max\left(0, 1 - yz \right)`
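
To get a feel for this: a confidently correct prediction (:math:`y = 1`, :math:`z = 2`) gives
:math:`\max(0, 1 - 2) = 0`, a correct prediction inside the margin (:math:`y = 1`, :math:`z = 0.5`) gives
:math:`0.5`, and a misclassified point (:math:`y = 1`, :math:`z = -1`) gives :math:`2`.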

Defining the Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Let's put this into code. First we can define our module which will project the data through our weights and offset by
a bias. Note that this is identical to the function of a linear layer.

.. literalinclude:: /_static/examples/svm_linear.py
    :language: python
    :lines: 17-27
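
In fact, since the projection is identical to a linear layer, an equivalent module could be built around
:code:`nn.Linear` directly; a minimal sketch (ours, not part of the example script):

.. code-block:: python

    import torch.nn as nn

    # nn.Linear(2, 1) holds the same (1, 2) weight matrix and (1,) bias that
    # LinearSVM declares by hand, and computes x.matmul(w.t()) + b for us.
    linear_svm = nn.Linear(2, 1)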

Next, we define the hinge loss function:

.. literalinclude:: /_static/examples/svm_linear.py
    :language: python
    :lines: 30-31
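
As a quick sanity check (ours, not part of the example script), the loss is zero outside the margin and grows
linearly as points fall inside it or are misclassified:

.. code-block:: python

    import torch

    outputs = torch.tensor([[2.0], [0.5], [-1.0]])  # model outputs z, shape (3, 1)
    targets = torch.tensor([1.0, 1.0, 1.0])         # all positive class
    # per-sample losses are [0.0, 0.5, 2.0], so the mean is 0.8333
    print(hinge_loss(outputs, targets))  # tensor(0.8333)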

Creating Synthetic Data
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Now for some data; 1024 samples should do the trick. We normalise here so that our random initialisation is in the same
space as the data:

.. literalinclude:: /_static/examples/svm_linear.py
    :language: python
    :lines: 34-37
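
(For reference, :code:`X` ends up as a :code:`(1024, 2)` FloatTensor and :code:`Y` as a :code:`(1024,)` FloatTensor of
:math:`\pm 1` labels.)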

Subgradient Descent
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Since we don't know that our data is linearly separable, we would like to use a soft-margin SVM. That is, an SVM for
which the data does not all have to be outside of the margin. This takes the form of the weight decay term
:math:`\lambda\lVert \textbf{w} \rVert^2` in the above equation. The term is called weight decay because its gradient
corresponds to subtracting some amount (:math:`2\lambda\textbf{w}`) from our weights at each step. With torchbearer we
can use the :class:`.L2WeightDecay` callback to do this. The whole process is known as (stochastic) subgradient descent:
the hinge loss is not differentiable at the hinge, so we follow a subgradient there, and we approximate the gradient
over all of the data with a mini-batch (of size 32 in our example) at each step. With a suitable step size, this is
proven to converge to the minimum for convex problems such as our SVM. At this point we are ready to create and train
our model:

.. literalinclude:: /_static/examples/svm_linear.py
    :language: python
    :lines: 90-97
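
For intuition, the callback behaves roughly like adding the penalty term to the loss by hand; a hypothetical sketch
(the name :code:`lam` is ours):

.. code-block:: python

    lam = 0.01  # the same coefficient passed to L2WeightDecay above

    def soft_margin_loss(y_pred, y_true):
        # hinge loss plus the soft-margin penalty lambda * ||w||^2
        return hinge_loss(y_pred, y_true) + lam * (svm.w ** 2).sum()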

Visualizing the Training
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You might have noticed some strange things in that call to :meth:`.Model.fit`. Specifically, we use the
:class:`.ExponentialLR` callback to anneal the learning rate a little, and we have a couple of other callbacks:
:code:`scatter` and :code:`draw_margin`. These callbacks produce the following live visualisation (note: this doesn't
work in PyCharm; it is best run from a terminal):

.. figure:: /_static/img/svm_fit.gif
    :scale: 100 %
    :alt: Convergence of the SVM decision boundary

The code for the visualisation (using `pyplot <https://matplotlib.org/api/pyplot_api.html>`_) is a bit ugly but we'll
try to explain it to some degree. First, we need a mesh grid :code:`xy` over the range of our data:

.. literalinclude:: /_static/examples/svm_linear.py
    :language: python
    :lines: 40-44

Next, we have the scatter callback. This happens once at the start of our fit call and draws the figure with a scatter
plot of our data:

.. literalinclude:: /_static/examples/svm_linear.py
    :language: python
    :lines: 59-63

Now things get a little strange. We start by evaluating our model over the mesh grid from earlier:

.. literalinclude:: /_static/examples/svm_linear.py
    :language: python
    :lines: 66-70
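
Here :code:`w` has shape :code:`(2,)` and :code:`xy` is a pair of flattened grid arrays, so :code:`w.dot(xy) + b`
yields one value per grid point, which we reshape back to the shape of the mesh.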

For our outputs :math:`z \in \textbf{Z}`, we can make some observations about the decision boundary. First, we are
outside the margin if :math:`z \lt -1` or :math:`z \gt 1`. Conversely, we are inside the margin where
:math:`-1 \lt z \lt 1`. This gives us some rules for colouring, which we use here:

.. literalinclude:: /_static/examples/svm_linear.py
    :language: python
    :lines: 72-76

So far it's been relatively straightforward. The next part is a bit of a hack to get the update of the contour plot
working. If a reference to the plot is already in state, we remove the old one and add a new one; otherwise, we add it
and show the plot. Finally, we call :code:`mypause` to trigger an update. You could just use :code:`plt.pause`;
however, it grabs the mouse focus each time it is called, which can be annoying. Instead, :code:`mypause` is taken from
`stackoverflow <https://stackoverflow.com/questions/45729092/make-interactive-matplotlib-window-not-pop-to-front-on-each-update-windows-7>`_.

.. literalinclude:: /_static/examples/svm_linear.py
    :language: python
    :lines: 78-87

Final Comments
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

So, there you have it: a fun differentiable programming example with a live visualisation in under 100 lines of code
with torchbearer. It's easy to see how this could become more useful, perhaps by finding a way to use the kernel trick
with the standard form of an SVM (essentially an RBF network). You could also attempt to write some code that saves the
gif from earlier. We had some, but it was beyond a hack; can you do better?
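
One possible approach is sketched below (assuming the :code:`imageio` package is available; the callback name and the
:code:`frame_files` list are ours). It would need to appear after :code:`draw_margin` in the callbacks list:

.. code-block:: python

    import imageio

    frame_files = []

    @callbacks.on_step_training
    def grab_frame(state):
        # save the current figure every 10 batches, after draw_margin has redrawn it
        if state[torchbearer.BATCH] % 10 == 0:
            frame_files.append('frame_{}.png'.format(len(frame_files)))
            plt.savefig(frame_files[-1])

    # then, after the call to model.fit(...):
    # imageio.mimsave('svm_fit.gif', [imageio.imread(f) for f in frame_files], fps=10)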

Source Code
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The source code for the example is given below:

:download:`Download Python source code: svm_linear.py </_static/examples/svm_linear.py>`
9 changes: 8 additions & 1 deletion docs/index.rst
@@ -16,12 +16,19 @@ Welcome to torchbearer's documentation!
.. toctree::
    :glob:
    :maxdepth: 1
-    :caption: Examples
+    :caption: Deep Learning

    examples/quickstart
    examples/vae
    examples/gan

+.. toctree::
+    :glob:
+    :maxdepth: 1
+    :caption: Differentiable Programming
+
+    examples/basic_opt
+    examples/svm_linear

.. toctree::
    :glob:
