Fixed Prototype issues without input_size, updated README
mbeissinger committed May 9, 2015
1 parent 5905b7f commit 6b9beb3
Showing 8 changed files with 170 additions and 216 deletions.
106 changes: 62 additions & 44 deletions README.rst
@@ -3,65 +3,67 @@
   :alt: OpenDeep
   :align: center

==============================================================
OpenDeep: a fully modular & extensible deep learning framework
==============================================================
Developer hub: http://www.opendeep.org/

OpenDeep_ is a deep learning framework for Python built from the ground up
in Theano_ with a focus on flexibility and ease of use for both industry data scientists and cutting-edge researchers.

Use OpenDeep to:

* Quickly prototype complex networks through a focus on complete modularity and containers similar to Torch.
* Configure and train existing state-of-the-art models.
* Write your own models from scratch in Theano and plug into OpenDeep for easy training and dataset integration.
* Use visualization and debugging tools to see exactly what is happening with your neural net architecture.
* Run on the CPU or GPU.

**This library is currently undergoing rapid development and is in its alpha stages.**

.. _OpenDeep: http://www.opendeep.org/
.. _Theano: http://deeplearning.net/software/theano/


Quick example usage
===================
Train and evaluate a Multilayer Perceptron (MLP - your generic feedforward neural network for classification)
on the MNIST handwritten digit dataset::

    from opendeep.models.container import Prototype
    from opendeep.models.single_layer.basic import BasicLayer, SoftmaxLayer
    from opendeep.optimization.adadelta import AdaDelta
    from opendeep.data.standard_datasets.image.mnist import MNIST, datasets

    # define the model as a stack of layers: 784 -> 512 -> 512 -> 10
    mlp = Prototype()
    mlp.add(BasicLayer(input_size=28*28, output_size=512, activation='rectifier', noise='dropout'))
    mlp.add(BasicLayer(output_size=512, activation='rectifier', noise='dropout'))
    mlp.add(SoftmaxLayer(output_size=10))

    # train the model on MNIST with the AdaDelta optimizer
    data = MNIST()
    trainer = AdaDelta(model=mlp, dataset=data, n_epoch=10)
    trainer.train()

    # evaluate accuracy on the test set
    test_data, test_labels = data.getSubset(datasets.TEST)
    predictions = mlp.run(test_data.eval())
    print "Accuracy: ", float(sum(predictions==test_labels.eval())) / len(test_labels.eval())


Installation
============
Because OpenDeep is still in alpha, you have to install via setup.py. Also, please make sure you have these dependencies installed first.

Dependencies
------------

* Theano_: Theano and its dependencies are required to use OpenDeep. You need to install the bleeding-edge version directly from their GitHub, which has `installation instructions here`_.

* For GPU integration with Theano, you also need the latest `CUDA drivers`_. Here are `instructions for setting up Theano for the GPU`_. If you prefer to use a server on Amazon Web Services, here are instructions for setting up an `EC2 GPU server with Theano`_.

* CuDNN_ (optional but recommended for CNNs): for fast convolutional net support from Nvidia. You will want to move the files into Theano's directory as the instructions describe here: `Theano cuDNN integration`_.

* `Pillow (PIL)`_: image manipulation functionality.

@@ -73,7 +75,7 @@ Dependencies

.. _CUDA drivers: https://developer.nvidia.com/cuda-toolkit
.. _instructions for setting up Theano for the GPU: http://deeplearning.net/software/theano/tutorial/using_gpu.html
.. _EC2 GPU server with Theano: http://markus.com/install-theano-on-aws

.. _CuDNN: https://developer.nvidia.com/cuDNN
.. _Theano cuDNN integration: http://deeplearning.net/software/theano/library/sandbox/cuda/dnn.html
@@ -85,7 +87,7 @@ Dependencies
.. _Bokeh: http://bokeh.pydata.org/en/latest/

Install from source
-------------------
1) Navigate to your desired installation directory and download the github repository::

    git clone https://github.com/vitruvianscience/opendeep.git
@@ -96,14 +98,15 @@ Install from source
    python setup.py develop

Using :code:`python setup.py develop` instead of the normal :code:`python setup.py install` lets you pull the latest
files from git and have the whole package update automatically - no need to reinstall.

That's it! Now you should be able to import opendeep in your Python modules.
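
As a quick sanity check (a trivial sketch - it assumes nothing beyond the package importing cleanly)::

    # confirm Python can find the opendeep package and show where it was installed
    import opendeep
    print opendeep.__file__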


Quick Start
===========
To get up to speed on deep learning, check out a blog post here: `Deep Learning 101`_.
You can also go through tutorials on OpenDeep's documentation site: http://www.opendeep.org/

Let's say you want to train a Denoising Autoencoder on the MNIST handwritten digit dataset. You can get started
in just a few lines of code::
@@ -126,13 +129,13 @@
    # create the MNIST dataset
    mnist = MNIST()

    # define some model configuration parameters (this could have come from json!)
    config = {
        "input_size": 28*28,  # dimensions of the MNIST images
        "hidden_size": 1500   # number of hidden units - generally bigger than input size
    }
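    # (sketch) as mentioned above, this config could come from json instead -
    # e.g., assuming a hypothetical dae_config.json holding the same two keys:
    #     import json
    #     config = json.load(open("dae_config.json"))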
    # create the denoising autoencoder
    dae = DenoisingAutoencoder(**config)

    # create the optimizer to train the denoising autoencoder
    # AdaDelta is normally a good generic optimizer
@@ -158,14 +161,19 @@ Congrats, you just:

- and predicted some outputs given inputs (and saved them as an image)!
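
(A sketch of that last image-saving step, using the Pillow dependency listed above - here :code:`recons` stands in
for a hypothetical numpy array of the model's flattened 784-dimensional reconstructions)::

    import numpy
    from PIL import Image

    # reshape the first reconstruction to 28x28, scale to 0-255, save as grayscale
    img = Image.fromarray((recons[0].reshape(28, 28) * 255).astype(numpy.uint8))
    img.save("reconstruction.png")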

.. image:: readme_images/gatsby.gif
   :scale: 100 %
   :alt: Working example!
   :align: center

.. _Deep Learning 101: http://markus.com/deep-learning-101/


More Information
================
Source code: https://github.com/vitruvianscience/opendeep

Documentation and tutorials: http://www.opendeep.org/

User group: `opendeep-users`_

@@ -176,3 +184,13 @@ join the Google groups!

.. _opendeep-users: https://groups.google.com/forum/#!forum/opendeep-users/
.. _opendeep-dev: https://groups.google.com/forum/#!forum/opendeep-dev/


Why OpenDeep?
=============

- **Modularity**. A lot of recent deep learning progress has come from combining multiple models. Existing libraries are either too confusing or not extensible enough to support novel research while also quickly setting up existing algorithms at scale. This need for transparency and modularity is the main motivating factor for creating the OpenDeep library, where we hope novel research and industry use can both be easily implemented.

- **Ease of use**. Many libraries require a lot of familiarity with deep learning or their specific package structures. OpenDeep's goal is to be the best-documented deep learning library and have smart enough default code that someone without a background can start training models, while experienced practitioners can easily create and customize their own algorithms.

- **State of the art**. As a side effect of modularity and ease of use, OpenDeep aims to maintain state-of-the-art performance as new algorithms and papers get published. As a research library, citing and crediting the authors of the work and code we use is very important.
28 changes: 27 additions & 1 deletion opendeep/models/model.py
@@ -70,7 +70,7 @@ class Model(object):
"""

def __init__(self, inputs_hook=None, hiddens_hook=None, params_hook=None,
output_size=None,
input_size=None, output_size=None,
outdir=None,
**kwargs):
"""
@@ -101,6 +101,9 @@
            this model (instead of initializing your own shared variables). This parameter is useful when you want to
            have two versions of the model that use the same parameters - such as a training model with dropout applied
            to layers and one without for testing, where the parameters are shared between the two.
        input_size : int or shape tuple
            The dimensionality of the input for this model. This is required for stacking models
            automatically - where the input to one layer is the output of the previous layer.
        output_size : int or shape tuple
            The dimensionality of the output for this model. This is required for stacking models
            automatically - where the input to one layer is the output of the previous layer. Currently, we cannot
@@ -118,9 +121,32 @@
        self.inputs_hook = inputs_hook
        self.hiddens_hook = hiddens_hook
        self.params_hook = params_hook
        self.input_size = input_size
        self.output_size = output_size
        self.outdir = outdir

        # Combine arguments that could specify input_size -> overwrite input_size with inputs_hook[0] if it exists.
        if self.inputs_hook and self.inputs_hook[0] is not None:
            self.input_size = self.inputs_hook[0]

        # Check if the input_size wasn't provided - if this is the case, it could either be a programmer's error
        # or it could be during the automatic stacking in a Container. Since that is a common use case, set
        # the input_size to 1 to avoid errors when instantiating the model.
        if not self.input_size:
            # Could be an error, or more commonly, when adding models to a Container.
            log.warning("No input_size or inputs_hook! Make sure this is done in a Container. Setting "
                        "input_size=1 for the Container now...")
            self.input_size = 1

        # Also, check if no output_size was given - this could be the case for generative models. Copy input_size
        # in that case.
        if not self.output_size:
            # Could be an error (hopefully not), so give the warning.
            log.warning("No output_size given! Make sure this is from a generative model (where output_size is "
                        "the same as input_size). Setting output_size=input_size now...")
            self.output_size = self.input_size

        # copy all of the parameters from the class into an args (configuration) dictionary
        self.args = {}
        self._add_kwargs_to_dict(kwargs.copy(), self.args)
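
To see what this fixes in practice, a minimal sketch (it reuses only the Prototype API from the README example
above; the rewiring detail is an assumption based on the docstring, not shown in this diff)::

    from opendeep.models.container import Prototype
    from opendeep.models.single_layer.basic import BasicLayer

    # Before this commit, building a layer without input_size could error out.
    # Now the Model base class just warns and defaults input_size=1, so a layer
    # can be constructed standalone; when added to a container, its real
    # input_size presumably gets rewired to the previous layer's output_size.
    net = Prototype()
    net.add(BasicLayer(input_size=28*28, output_size=512, activation='rectifier'))
    net.add(BasicLayer(output_size=512, activation='rectifier'))  # no input_size needed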
