Example/simple optim (#256)
* Add simple optimisation example

* Update docs

* Update docs

* Update docs

* Update docs
MattPainter01 authored and ethanwharris committed Jul 26, 2018
1 parent e196bbf commit 4654ee1
Showing 4 changed files with 123 additions and 1 deletion.
51 changes: 51 additions & 0 deletions docs/_static/examples/basic_opt.py
@@ -0,0 +1,51 @@
import torch
from torch.nn import Module

import torchbearer as tb


class Net(Module):
    def __init__(self, x):
        super().__init__()
        self.pars = torch.nn.Parameter(x)  # current estimate of the minimum

    def f(self):
        """
        function to be minimised:
        f(x) = (x[0]-5)^2 + x[1]^2 + (x[2]-1)^2
        Solution:
        x = [5, 0, 1]
        """
        out = torch.zeros_like(self.pars)
        out[0] = self.pars[0] - 5
        out[1] = self.pars[1]
        out[2] = self.pars[2] - 1
        return torch.sum(out ** 2)

    def forward(self, _, state):
        state['est'] = self.pars  # expose the estimate so metrics can report it
        return self.f()


def loss(y_pred, y_true):
    return y_pred  # the function value itself is the loss; the target is unused


@tb.metrics.to_dict
class est(tb.metrics.Metric):
    def __init__(self):
        super().__init__('est')

    def process(self, state):
        return state['est'].data


steps = torch.tensor(list(range(50000)))  # one 'sample' per optimisation step
p = torch.tensor([2.0, 1.0, 10.0])  # initial guess

model = Net(p)
optim = torch.optim.SGD(model.parameters(), lr=0.0001)  # deliberately small lr

tbmodel = tb.Model(model, optim, loss, [est(), 'loss'])
tbmodel.fit(steps, steps, 1, pass_state=True)
print(list(model.parameters())[0].data)
70 changes: 70 additions & 0 deletions docs/examples/basic_opt.rst
@@ -0,0 +1,70 @@
Optimising functions
====================================

Now for something a bit different.
PyTorch is a tensor processing library and, whilst it has a focus on neural networks, it can also be used for more standard function optimisation.
In this example we will use torchbearer to minimise a simple function.


The Model
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

First we will need to create something that looks very similar to a neural network model - but with the purpose of minimising our function.
We store the current estimates for the minimum as parameters in the model (so PyTorch optimisers can find and optimise them) and we return the function value in the forward method.

.. literalinclude:: /_static/examples/basic_opt.py
    :language: python
    :lines: 7-27
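
Under the hood, fitting this "model" is just gradient descent on the function. As a rough sketch of the loop torchbearer will run for us (this plain-PyTorch version is our illustration, not part of the example):

.. code-block:: python

    import torch

    model = Net(torch.tensor([2.0, 1.0, 10.0]))
    optim = torch.optim.SGD(model.parameters(), lr=0.0001)

    for _ in range(50000):
        optim.zero_grad()
        value = model.f()  # evaluate the function at the current estimate
        value.backward()   # gradients with respect to the estimate
        optim.step()       # move the estimate downhill

    print(model.pars.data)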

The Loss
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

For function minimisation we have an analogue to neural network losses - we minimise the value of the function under the current estimates of the minimum.
Note that as we are using a base loss, torchbearer passes it the network output and the "label" (which is of no use here).

.. literalinclude:: /_static/examples/basic_opt.py
    :language: python
    :lines: 30-31


Optimising
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We need two more things before we can start optimising with torchbearer.
We need our initial guess - which we've set to [2.0, 1.0, 10.0] - and we need to tell torchbearer how "long" an epoch is, i.e. how many optimisation steps we want for each epoch.
For our simple function we can complete the optimisation in a single epoch, but for more complex problems we might want several epochs, with tensorboard logging and perhaps learning rate annealing along the way.
For this example we use 50000 optimisation steps.

.. literalinclude:: /_static/examples/basic_opt.py
    :language: python
    :lines: 43-44

The learning rate chosen for this example is very low; we could get convergence much faster with a larger rate, but the low rate lets us watch the convergence in real time.
We define the model and optimiser in the standard way.

.. literalinclude:: /_static/examples/basic_opt.py
    :language: python
    :lines: 46-47
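
If we just wanted the answer as quickly as possible, a larger rate or an adaptive optimiser would do the job in far fewer steps; for instance (our substitution, not what the example uses):

.. code-block:: python

    # Hypothetical alternative: an adaptive optimiser with a larger step
    optim = torch.optim.Adam(model.parameters(), lr=0.1)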

Finally we start the optimisation (the steps tensor serves as both the "data" and the "targets", fixing the number of optimisation steps) and print the final estimate of the minimum.

.. literalinclude:: /_static/examples/basic_opt.py
    :language: python
    :lines: 49-51
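
The positional ``1`` here is the batch size. For a longer optimisation we could also use the ``epochs`` argument of ``fit``; a sketch of a variant the example doesn't run:

.. code-block:: python

    # Hypothetical variant: three epochs of 50000 steps each
    tbmodel.fit(steps, steps, batch_size=1, epochs=3, pass_state=True)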

Note that we could use meaningful targets, since they are passed to the loss function; we don't do so in this example.
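
For instance, targets of all zeros could encode the desired function value, with a loss that penalises the gap; a hypothetical sketch:

.. code-block:: python

    # Hypothetical: targets as desired function values (not what the example does)
    def gap_loss(y_pred, y_true):
        return (y_pred - y_true) ** 2

    targets = torch.zeros(50000)
    tbmodel = tb.Model(model, optim, gap_loss, [est(), 'loss'])
    tbmodel.fit(steps, targets, 1, pass_state=True)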


Viewing Progress
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You might have noticed in the previous snippet that the example uses a metric we've not seen before.
This simple metric is used to display the estimate throughout the optimisation process - although this is probably only useful for very small optimisation problems.

.. literalinclude:: /_static/examples/basic_opt.py
    :language: python
    :lines: 34-40
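
The same pattern works for any value stored in ``state``. For example, a hypothetical extra metric tracking the distance from the known solution:

.. code-block:: python

    import torch
    import torchbearer as tb

    @tb.metrics.to_dict
    class dist(tb.metrics.Metric):
        """Hypothetical: L2 distance of the estimate from the known minimum."""
        def __init__(self):
            super().__init__('dist')
            self.solution = torch.tensor([5.0, 0.0, 1.0])

        def process(self, state):
            return torch.norm(state['est'].data - self.solution)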

The final estimate is very close to our desired minimum at [5, 0, 1]::

    tensor([ 4.9988e+00, 4.5355e-05, 1.0004e+00])
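
As a final sanity check (a sketch, not in the example file), the function value at the estimate should be near zero:

.. code-block:: python

    print(model.f().item())  # ~0 at the minimum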
1 change: 1 addition & 0 deletions docs/index.rst
@@ -21,6 +21,7 @@ Welcome to torchbearer's documentation!
    examples/quickstart
    examples/vae
    examples/gan
+   examples/basic_opt

.. toctree::
    :glob:
2 changes: 1 addition & 1 deletion torchbearer/torchbearer.py
@@ -36,7 +36,7 @@ def __init__(self, model, optimizer, loss_criterion, metrics=[]):
            torchbearer.SELF: self
        }

-    def fit(self, x, y, batch_size=None, epochs=1, verbose=1, callbacks=[], validation_split=0.0,
+    def fit(self, x, y, batch_size=None, epochs=1, verbose=1, callbacks=[], validation_split=None,
            validation_data=None, shuffle=True, initial_epoch=0,
            steps_per_epoch=None, validation_steps=None, workers=1, pass_state=False):
        """ Perform fitting of a model to given data and label tensors
