Commit

tensorflow multilayer-perceptron
rasbt committed Mar 22, 2016
1 parent d00888c commit 0c71a09
Showing 25 changed files with 963 additions and 57 deletions.
1 change: 1 addition & 0 deletions docs/mkdocs.yml
@@ -83,6 +83,7 @@ pages:
- user_guide/general_concepts/linear-gradient-derivative.md
- user_guide/general_concepts/regularization-linear.md
- Upcoming Features / 0.3.1dev:
- user_guide/tf_classifier/TfMultiLayerPerceptron.md
- user_guide/tf_classifier/TfSoftmaxRegression.md
- user_guide/classifier/SoftmaxRegression.md
- user_guide/regressor/StackingRegressor.md
1 change: 1 addition & 0 deletions docs/sources/USER_GUIDE_INDEX.md
@@ -12,6 +12,7 @@

## `tf_classifier` (TensorFlow Classifier)
- [`TfSoftmaxRegression`](user_guide/tf_classifier/TfSoftmaxRegression.md) (new in 0.3.1dev)
- [`TfMultiLayerPerceptron`](user_guide/tf_classifier/TfMultiLayerPerceptron.md) (new in 0.3.1dev)

## `regressor`

739 changes: 739 additions & 0 deletions docs/sources/user_guide/tf_classifier/TfMultiLayerPerceptron.ipynb

Large diffs are not rendered by default.

26 changes: 13 additions & 13 deletions docs/sources/user_guide/tf_classifier/TfSoftmaxRegression.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mlxtend/tf_classifier/__init__.py
@@ -5,6 +5,6 @@
# License: BSD 3 clause

from .tf_softmax import TfSoftmaxRegression
from .TfMultiLayerPerceptron import TfMultiLayerPerceptron
from .tf_multilayerperceptron import TfMultiLayerPerceptron

__all__ = ["TfSoftmaxRegression", "TfMultiLayerPerceptron"]
130 changes: 130 additions & 0 deletions mlxtend/tf_classifier/tests/tests_tf_multilayerperceptron.py
@@ -0,0 +1,130 @@
# Sebastian Raschka 2014-2016
# mlxtend Machine Learning Library Extensions
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause

from mlxtend.tf_classifier import TfMultiLayerPerceptron as MLP
from mlxtend.data import iris_data
import numpy as np
from nose.tools import raises


X, y = iris_data()
X = X[:, [0, 3]] # sepal length and petal width
X_bin = X[0:100] # class 0 and class 1
y_bin = y[0:100] # class 0 and class 1

# standardize
X_bin[:, 0] = (X_bin[:, 0] - X_bin[:, 0].mean()) / X_bin[:, 0].std()
X_bin[:, 1] = (X_bin[:, 1] - X_bin[:, 1].mean()) / X_bin[:, 1].std()
X[:, 0] = (X[:, 0] - X[:, 0].mean()) / X[:, 0].std()
X[:, 1] = (X[:, 1] - X[:, 1].mean()) / X[:, 1].std()


@raises(AttributeError)
def test_optimizer_init():
MLP(optimizer='no-optimizer')


@raises(AttributeError)
def test_activations_init_typo():
MLP(hidden_layers=[1, 2], activations=['logistic', 'invalid'])


@raises(AttributeError)
def test_activations_invalid_ele_1():
MLP(hidden_layers=[1], activations=['logistic', 'logistic'])


@raises(AttributeError)
def test_activations_invalid_ele_2():
MLP(hidden_layers=[10, 10], activations=['logistic'])


def test_mapping():
mlp = MLP()
w, b = mlp._layermapping(n_features=10,
n_classes=11,
hidden_layers=[8, 7, 6])

expect_b = {1: [[8], 'n_hidden_1'],
2: [[7], 'n_hidden_2'],
3: [[6], 'n_hidden_3'],
'out': [[11], 'n_classes']}

expect_w = {1: [[10, 8], 'n_features, n_hidden_1'],
2: [[8, 7], 'n_hidden_1, n_hidden_2'],
3: [[7, 6], 'n_hidden_2, n_hidden_3'],
'out': [[6, 11], 'n_hidden_3, n_classes']}

assert expect_b == b, b
assert expect_w == w, w


def test_binary_gd():
mlp = MLP(epochs=100,
eta=0.5,
hidden_layers=[5],
optimizer='gradientdescent',
activations=['logistic'],
minibatches=1,
random_seed=1)

mlp.fit(X_bin, y_bin)
assert((y_bin == mlp.predict(X_bin)).all())


def test_binary_sgd():
mlp = MLP(epochs=10,
eta=0.5,
hidden_layers=[5],
optimizer='gradientdescent',
activations=['logistic'],
minibatches=len(y_bin),
random_seed=1)

mlp.fit(X_bin, y_bin)
assert((y_bin == mlp.predict(X_bin)).all())


def test_multiclass_probas():
mlp = MLP(epochs=100,
eta=0.5,
hidden_layers=[5],
optimizer='gradientdescent',
activations=['logistic'],
minibatches=1,
random_seed=1)
mlp.fit(X, y)
idx = [0, 50, 149] # sample labels: 0, 1, 2
y_pred = mlp.predict_proba(X[idx])
exp = np.array([[0.9, 0.1, 0.0],
[0.0, 0.6, 0.4],
[0.0, 0.1, 0.9]])
np.testing.assert_almost_equal(y_pred, exp, 1)


def test_multiclass_gd_acc():
mlp = MLP(epochs=100,
eta=0.5,
hidden_layers=[5],
optimizer='gradientdescent',
activations=['logistic'],
minibatches=1,
random_seed=1)
mlp.fit(X, y)
assert((y == mlp.predict(X)).all())


@raises(AttributeError)
def test_fail_minibatches():
mlp = MLP(epochs=100,
eta=0.5,
hidden_layers=[5],
optimizer='gradientdescent',
activations=['logistic'],
minibatches=13,
random_seed=1)
mlp.fit(X, y)
assert((y == mlp.predict(X)).all())
8 changes: 4 additions & 4 deletions mlxtend/tf_classifier/tests/tests_tf_softmax.py
@@ -63,11 +63,11 @@ def test_multi_logistic_regression_gd_weights():

def test_multi_logistic_probas():
lr = TfSoftmaxRegression(epochs=200,
eta=0.75,
minibatches=1,
random_seed=1)
eta=0.75,
minibatches=1,
random_seed=1)
lr.fit(X, y)
idx = [0, 50, 149] # sample labels: 0, 1, 2
idx = [0, 50, 149] # sample labels: 0, 1, 2
y_pred = lr.predict_proba(X[idx])
exp = np.array([[0.99, 0.01, 0.0],
[0.01, 0.89, 0.1],
109 changes: 72 additions & 37 deletions mlxtend/tf_classifier/tf_multilayerperceptron.py
@@ -21,13 +21,23 @@ class TfMultiLayerPerceptron(_TfBaseClassifier):
Learning rate (between 0.0 and 1.0)
epochs : int (default: 50)
Passes over the training dataset.
n_hidden : list (default: [50, 10])
hidden_layers : list (default: [50, 10])
Number of units per hidden layer. By default 50 units in the
first hidden layer, and 10 hidden units in the second hidden layer.
activations : list (default: ['softmax', 'softmax'])
activations : list (default: ['logistic', 'logistic'])
Activation functions for each layer.
Available activation functions:
"softmax", "relu", "tanh", "elu", "softplus", "softsign"
"logistic", "relu", "tanh", "relu6", "elu", "softplus", "softsign"
optimizer : str (default: "gradientdescent")
Optimizer to minimize the cost function:
"gradientdescent", "momentum", "adam", "ftrl", "adagrad"
momentum : float (default: 0.0)
Momentum constant for momentum learning; only applies if
optimizer='momentum'
l1 : float (default: 0.0)
L1 regularization strength; only applies if optimizer='ftrl'
l2 : float (default: 0.0)
L2 regularization strength; only applies if optimizer='ftrl'
minibatches : int (default: 1)
Divide the training data into *k* minibatches
for accelerated stochastic gradient descent learning.
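
As a quick illustration of the constructor parameters documented above, here is a hedged instantiation sketch; the values mirror the settings exercised in tests_tf_multilayerperceptron.py earlier in this commit rather than any recommended defaults:

from mlxtend.tf_classifier import TfMultiLayerPerceptron

# One hidden layer with 5 logistic units, full-batch gradient descent,
# matching the multiclass test configuration added in this commit.
mlp = TfMultiLayerPerceptron(eta=0.5,
                             epochs=100,
                             hidden_layers=[5],
                             activations=['logistic'],
                             optimizer='gradientdescent',
                             minibatches=1,
                             random_seed=1)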
@@ -56,21 +66,26 @@ class TfMultiLayerPerceptron(_TfBaseClassifier):
"""
def __init__(self, eta=0.5, epochs=50,
n_hidden=[50, 10],
activations=['softmax', 'softmax'],
hidden_layers=[50, 10],
activations=['logistic', 'logistic'],
optimizer='gradientdescent',
momentum=0.0, l1=0.0, l2=0.0,
minibatches=1, random_seed=None,
print_progress=0, dtype=None):
self.eta = eta
if len(n_hidden) != len(activations):
raise AttributeError('Number n_hidden and'
if len(hidden_layers) != len(activations):
raise AttributeError('Number of hidden_layers and'
' activations must be equal.')
self.n_hidden = n_hidden
self.hidden_layers = hidden_layers
self.activations = self._get_activations(activations)

self.optimizer = self._init_optimizer(optimizer)
self.epochs = epochs
self.minibatches = minibatches
self.random_seed = random_seed
self.print_progress = print_progress
self.l1 = l1
self.l2 = l2
self.momentum = momentum

if dtype is None:
self.dtype = tf.float32
@@ -79,9 +94,30 @@ def __init__(self, eta=0.5, epochs=50,

return

def _init_optimizer(self, optimizer):
if optimizer == 'gradientdescent':
opt = tf.train.GradientDescentOptimizer(learning_rate=self.eta)
elif optimizer == 'momentum':
opt = tf.train.MomentumOptimizer(learning_rate=self.eta,
momentum=self.momentum)
elif optimizer == 'adam':
opt = tf.train.AdamOptimizer(learning_rate=self.eta)
elif optimizer == 'ftrl':
opt = tf.train.FtrlOptimizer(
learning_rate=self.eta,
l1_regularization_strength=self.l1,
l2_regularization_strength=self.l2)
elif optimizer == 'adagrad':
opt = tf.train.AdagradOptimizer(learning_rate=self.eta)
else:
raise AttributeError('optimizer must be "gradientdescent",'
' "momentum", "adam", "ftrl", or "adagrad"')
return opt

def _get_activations(self, activations):
adict = {'softmax': tf.nn.sigmoid,
adict = {'logistic': tf.nn.sigmoid,
'relu': tf.nn.relu,
'relu6': tf.nn.relu6,
'tanh': tf.nn.tanh,
'elu': tf.nn.elu,
'softplus': tf.nn.softplus,
@@ -134,15 +170,15 @@ def fit(self, X, y, init_weights=True, override_minibatches=None):
self._weight_maps, self._bias_maps = self._layermapping(
n_features=self._n_features,
n_classes=self._n_classes,
n_hidden=self.n_hidden)
hidden_layers=self.hidden_layers)
tf_weights, tf_biases = self._initialize_weights(
weight_maps=self._weight_maps,
bias_maps=self._bias_maps)
self.cost_ = []
else:
tf_weights, tf_biases = self._reuse_weights(
weights=self.weights_,
baises=self.biases_)
biases=self.biases_)

# Prepare the training data
y_enc = self._one_hot(y, self._n_classes)
@@ -156,18 +192,16 @@ def fit(self, X, y, init_weights=True, override_minibatches=None):
y_batch = tf.gather(params=tf_y, indices=tf_idx)

# Setup the graph for minimizing cross entropy cost
logits = self._predict(tf_X=tf_X,
tf_weights=tf_weights,
tf_biases=tf_biases,
activations=self.activations)
net = self._predict(tf_X=tf_X,
tf_weights=tf_weights,
tf_biases=tf_biases,
activations=self.activations)

# Define loss and optimizer
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits,
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(net,
tf_y)
cost = tf.reduce_mean(cross_entropy)
optimizer = tf.train.GradientDescentOptimizer(
learning_rate=self.eta)
train = optimizer.minimize(cost)
train = self.optimizer.minimize(cost)

# Initializing the variables
init = tf.initialize_all_variables()
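
An editorial note on the logits-to-net rename in this hunk: tf.nn.softmax_cross_entropy_with_logits applies the softmax internally, so fit() feeds it the unscaled net input, and only predict_proba() (next hunk) applies tf.nn.softmax explicitly. A small stand-alone sketch, assuming the TensorFlow 0.x positional call order used throughout this file:

import tensorflow as tf

net = tf.constant([[2.0, 1.0, 0.1]])     # unscaled net activations for one sample
y_true = tf.constant([[1.0, 0.0, 0.0]])  # one-hot target
# cross entropy is computed on the raw net input; softmax happens inside the op
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(net, y_true))
probas = tf.nn.softmax(net)              # explicit softmax only when probabilities are needed
with tf.Session() as sess:
    print(sess.run([loss, probas]))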
@@ -228,24 +262,25 @@ def predict_proba(self, X):
Class probabilties : array-like, shape= [n_samples, n_classes]
"""
self._check_arrays(X, y)
self._check_arrays(X)
if not hasattr(self, 'weights_'):
raise AttributeError('The model has not been fitted, yet.')

with tf.Session():
tf.initialize_all_variables().run()
tf_X = tf.convert_to_tensor(value=X, dtype=self.dtype)
logits = self._predict(tf_X=tf_X,
tf_weights=self.weights_,
tf_biases=self.biases_,
activations=self.activations)
net = self._predict(tf_X=tf_X,
tf_weights=self.weights_,
tf_biases=self.biases_,
activations=self.activations)
logits = tf.nn.softmax(net)
return logits.eval()

def _layermapping(self, n_features, n_classes, n_hidden):
def _layermapping(self, n_features, n_classes, hidden_layers):
"""Creates a dictionaries of layer dimensions for weights and biases.
For example, given
`n_features=10`, `n_classes=10`, and `n_hidden=[8, 7, 6]`:
`n_features=10`, `n_classes=10`, and `hidden_layers=[8, 7, 6]`:
biases =
{1: [[8], 'n_hidden_1'],
@@ -262,15 +297,15 @@ def _layermapping(self, n_features, n_classes, n_hidden):
}
"""
weights = {1: [[n_features, n_hidden[0]],
weights = {1: [[n_features, hidden_layers[0]],
'n_features, n_hidden_1'],
'out': [[n_hidden[-1], n_classes],
'n_hidden_%d, n_classes' % len(n_hidden)]}
biases = {1: [[n_hidden[0]], 'n_hidden_1'],
'out': [[hidden_layers[-1], n_classes],
'n_hidden_%d, n_classes' % len(hidden_layers)]}
biases = {1: [[hidden_layers[0]], 'n_hidden_1'],
'out': [[n_classes], 'n_classes']}

if len(n_hidden) > 1:
for i, h in enumerate(n_hidden[1:]):
if len(hidden_layers) > 1:
for i, h in enumerate(hidden_layers[1:]):
layer = i + 2
weights[layer] = [[weights[layer - 1][0][1], h],
'n_hidden_%d, n_hidden_%d' % (layer -
@@ -287,10 +322,10 @@ def _predict(self, tf_X, tf_weights, tf_biases, activations):
for layer in range(2, len(tf_weights)):
prev_layer = self.activations[layer](tf.add(tf.matmul(
prev_layer, tf_weights[layer]), tf_biases[layer]))
logits = tf.matmul(prev_layer, tf_weights['out']) + tf_biases['out']
return logits
net = tf.matmul(prev_layer, tf_weights['out']) + tf_biases['out']
return net

def _resuse_weights(self, weights, biases):
def _reuse_weights(self, weights, biases):
w = {k: tf.Variable(self.weights_[k]) for k in self.weights_}
b = {k: tf.Variable(self.biases_[k]) for k in self.biases_}
return w, b
@@ -302,7 +337,7 @@ def _initialize_weights(self, weight_maps, bias_maps):
seed = self.random_seed + i
else:
seed = None
tf_weights[k[0]] = tf.Variable(tf.truncated_normal(
tf_weights[k[0]] = tf.Variable(tf.random_normal(
weight_maps[k[0]][0], seed=seed))
tf_biases[k[1]] = tf.zeros(bias_maps[k[1]][0])
return tf_weights, tf_biases
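
To make the dimension bookkeeping in _layermapping and _initialize_weights concrete, here is a small pure-Python sketch for the two-feature iris setup used in the tests (n_features=2, hidden_layers=[5], n_classes=3); it reproduces only the shapes, not the TensorFlow variables:

# Shape bookkeeping mirroring _layermapping for the iris test setup.
n_features, hidden_layers, n_classes = 2, [5], 3

dims = [n_features] + hidden_layers + [n_classes]
weight_shapes = list(zip(dims[:-1], dims[1:]))  # [(2, 5), (5, 3)]
bias_shapes = dims[1:]                          # [5, 3]
print(weight_shapes, bias_shapes)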
4 changes: 2 additions & 2 deletions mlxtend/tf_classifier/tf_softmax.py
@@ -121,8 +121,8 @@ def fit(self, X, y,
y_batch = tf.gather(params=tf_y, indices=tf_idx)

# Setup the graph for minimizing cross entropy cost
logits = tf.matmul(X_batch, tf_weights_) + tf_biases_
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits,
net = tf.matmul(X_batch, tf_weights_) + tf_biases_
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(net,
y_batch)
cost = tf.reduce_mean(cross_entropy)
optimizer = tf.train.GradientDescentOptimizer(
