Skip to content

Commit

Permalink
BUG don't densify sparse matrix in BernoulliRBM.score_samples
Browse files Browse the repository at this point in the history
Also documented the non-determinism of this method.
  • Loading branch information
larsmans committed Dec 6, 2013
1 parent 21f0479 commit b0485c4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 23 deletions.
40 changes: 24 additions & 16 deletions sklearn/neural_network/rbm.py
Expand Up @@ -9,11 +9,12 @@
import time

import numpy as np
import scipy.sparse as sp

from ..base import BaseEstimator
from ..base import TransformerMixin
from ..externals.six.moves import xrange
from ..utils import check_arrays
from ..utils import atleast2d_or_csr, check_arrays
from ..utils import check_random_state
from ..utils import gen_even_slices
from ..utils import issparse
Expand Down Expand Up @@ -189,8 +190,8 @@ def _free_energy(self, v):
The value of the free energy.
"""
return (- safe_sparse_dot(v, self.intercept_visible_)
- np.log(1. + np.exp(safe_sparse_dot(v, self.components_.T)
+ self.intercept_hidden_)).sum(axis=1))
- np.log1p(np.exp(safe_sparse_dot(v, self.components_.T)
+ self.intercept_hidden_)).sum(axis=1))

def gibbs(self, v):
"""Perform one Gibbs sampling step.
Expand Down Expand Up @@ -246,33 +247,40 @@ def _fit(self, v_pos, rng):
h_neg[rng.uniform(size=h_neg.shape) < h_neg] = 1.0 # sample binomial
self.h_samples_ = np.floor(h_neg, h_neg)

if self.verbose:
return self.score_samples(v_pos)

def score_samples(self, v):
"""Compute the pseudo-likelihood of v.
def score_samples(self, X):
"""Compute the pseudo-likelihood of X.
Parameters
----------
v : {array-like, sparse matrix} shape (n_samples, n_features)
Values of the visible layer.
X : {array-like, sparse matrix} shape (n_samples, n_features)
Values of the visible layer. Must be all-boolean (not checked).
Returns
-------
pseudo_likelihood : array-like, shape (n_samples,)
Value of the pseudo-likelihood (proxy to likelihood).
Value of the pseudo-likelihood (proxy for likelihood).
Notes
-----
This method is not deterministic: it computes a quantity called the
free energy on X, then on a randomly corrupted version of X, and
returns the log of the logistic function of the difference.
"""
v = atleast2d_or_csr(X)
rng = check_random_state(self.random_state)
fe = self._free_energy(v)

# Randomly corrupt one feature in each sample in v.
ind = (np.arange(v.shape[0]),
rng.randint(0, v.shape[1], v.shape[0]))
if issparse(v):
v_ = v.toarray()
data = -2 * v[ind] + 1
v_ = v + sp.csr_matrix((data.A.ravel(), ind), shape=v.shape)
else:
v_ = v.copy()
i_ = rng.randint(0, v.shape[1], v.shape[0])
v_[np.arange(v.shape[0]), i_] = 1 - v_[np.arange(v.shape[0]), i_]
fe_ = self._free_energy(v_)
v_[ind] = 1 - v_[ind]

fe = self._free_energy(v)
fe_ = self._free_energy(v_)
return v.shape[1] * logistic_sigmoid(fe_ - fe, log=True)

def fit(self, X, y=None):
Expand Down
22 changes: 15 additions & 7 deletions sklearn/neural_network/tests/test_rbm.py
Expand Up @@ -2,8 +2,9 @@
import re

import numpy as np
from numpy.testing import assert_almost_equal, assert_array_equal
from scipy.sparse import csr_matrix
from scipy.sparse import csr_matrix, lil_matrix
from sklearn.utils.testing import (assert_almost_equal, assert_array_equal,
assert_true)

from sklearn.datasets import load_digits
from sklearn.externals.six.moves import cStringIO as StringIO
Expand Down Expand Up @@ -108,16 +109,23 @@ def test_gibbs_smoke():


def test_score_samples():
"""Check that the pseudo likelihood is computed without clipping.
http://fa.bianp.net/blog/2013/numerical-optimizers-for-logistic-regression/
"""
"""Test score_samples (pseudo-likelihood) method."""
# Assert that pseudo-likelihood is computed without clipping.
# http://fa.bianp.net/blog/2013/numerical-optimizers-for-logistic-regression
rng = np.random.RandomState(42)
X = np.vstack([np.zeros(1000), np.ones(1000)])
rbm1 = BernoulliRBM(n_components=10, batch_size=2,
n_iter=10, random_state=rng)
rbm1.fit(X)
assert((rbm1.score_samples(X) < -300).all())
assert_true((rbm1.score_samples(X) < -300).all())

# Sparse vs. dense should not affect the output. Also test sparse input
# validation.
rbm1.random_state = 42
d_score = rbm1.score_samples(X)
rbm1.random_state = 42
s_score = rbm1.score_samples(lil_matrix(X))
assert_almost_equal(d_score, s_score)


def test_rbm_verbose():
Expand Down

0 comments on commit b0485c4

Please sign in to comment.