diff --git a/examples/glm/plot_ard.py b/examples/glm/plot_ard.py
index 87313e33a703a..52cd4f28b8dbb 100644
--- a/examples/glm/plot_ard.py
+++ b/examples/glm/plot_ard.py
@@ -1,50 +1,47 @@
 """
 ==================================================
-Automatic Relevance Determination Regression
+Automatic Relevance Determination Regression (ARD)
 ==================================================
+
+Fit regression model with ARD
 """
+print __doc__
 
-from scikits.learn.glm import ARDRegression
 import numpy as np
 import pylab as pl
-import numpy.random as nr
-import scipy.stats as st
+from scipy import stats
+from scikits.learn.glm import ARDRegression
 
 ################################################################################
 # Generating simulated data with Gaussian weigthts
 
 ### Parameters of the example
-nr.seed(0)
-n_samples = 50
-n_features = 100
+np.random.seed(0)
+n_samples, n_features = 50, 100
 ### Create gaussian data
-X = nr.randn(n_samples, n_features)
+X = np.random.randn(n_samples, n_features)
 ### Create weigts with a precision lambda_ of 4.
 lambda_ = 4.
 w = np.zeros(n_features)
 ### Only keep 10 weights of interest
-relevant_features = nr.randint(0,n_features,10)
+relevant_features = np.random.randint(0, n_features, 10)
 for i in relevant_features:
-    w[i] = st.norm.rvs(loc = 0, scale = 1./np.sqrt(lambda_))
+    w[i] = stats.norm.rvs(loc=0, scale=1. / np.sqrt(lambda_))
 ### Create noite with a precision alpha of 50.
 alpha_ = 50.
-noise = st.norm.rvs(loc = 0, scale = 1./np.sqrt(alpha_), size = n_samples)
+noise = stats.norm.rvs(loc=0, scale=1. / np.sqrt(alpha_), size=n_samples)
 ### Create the target
-Y = np.dot(X, w) + noise
-
+y = np.dot(X, w) + noise
 
 ################################################################################
 ### Fit the ARD Regression
 clf = ARDRegression(compute_score = True)
-clf.fit(X, Y)
-
-
+clf.fit(X, y)
 
 ################################################################################
 ### Plot the true weights, the estimated weights and the histogram of the
 ### weights
-
 pl.figure()
 axe = pl.axes([0.1,0.6,0.8,0.325])
 axe.set_title("ARD - Weights of the model")
@@ -65,7 +62,7 @@
 
 axe = pl.axes([0.65,0.1,0.3,0.325])
 axe.set_title("Objective function")
-axe.plot(clf.all_score_)
+axe.plot(clf.scores_)
 axe.set_ylabel("Score")
 axe.set_xlabel("Iterations")
 pl.show()
diff --git a/examples/glm/plot_bayesian_ridge.py b/examples/glm/plot_bayesian_ridge.py
index edcad10971130..87028ccde7afc 100644
--- a/examples/glm/plot_bayesian_ridge.py
+++ b/examples/glm/plot_bayesian_ridge.py
@@ -29,12 +29,12 @@
 alpha_ = 50.
 noise = stats.norm.rvs(loc = 0, scale = 1./np.sqrt(alpha_), size = n_samples)
 # Create the target
-Y = np.dot(X, w) + noise
+y = np.dot(X, w) + noise
 
 ################################################################################
 # Fit the Bayesian Ridge Regression
 clf = BayesianRidge(compute_score=True)
-clf.fit(X, Y)
+clf.fit(X, y)
 
 ################################################################################
 # Plot true weights, estimated weights and histogram of the weights
@@ -45,7 +45,7 @@
 axe.plot(w, 'g-', label="Ground truth")
 axe.set_xlabel("Features")
 axe.set_ylabel("Values of the weights")
-axe.legend(loc=1)
+axe.legend(loc="upper right")
 
 axe = pl.axes([0.1,0.1,0.45,0.325])
 axe.set_title("Histogram of the weights")
@@ -54,11 +54,11 @@
          label="Relevant features")
 axe.set_ylabel("Features")
 axe.set_xlabel("Values of the weights")
-axe.legend(loc=1)
+axe.legend(loc="lower left")
 
 axe = pl.axes([0.65,0.1,0.3,0.325])
 axe.set_title("Objective function")
-axe.plot(clf.all_score_)
+axe.plot(clf.scores_)
 axe.set_ylabel("Score")
 axe.set_xlabel("Iterations")
 pl.show()
diff --git a/scikits/learn/glm/bayes.py b/scikits/learn/glm/bayes.py
index 7855ae1566e47..f7dae1ba9f4ce 100644
--- a/scikits/learn/glm/bayes.py
+++ b/scikits/learn/glm/bayes.py
@@ -2,10 +2,10 @@
 Various bayesian regression
 """
 
-# Authors: V. Michel, F. Pedregosa
+# Authors: V. Michel, F. Pedregosa, A. Gramfort
 # License: BSD 3 clause
-
+from math import log
 import numpy as np
 from scipy import linalg
 
@@ -23,55 +23,58 @@ class BayesianRidge(LinearModel):
 
     Parameters
     ----------
-    X : numpy array of shape (length,features)
+    X : array, shape = (n_samples, n_features)
         Training vectors.
-    Y : numpy array of shape (length)
+    y : array, shape = (n_samples)
         Target values for training vectors
 
-    n_iter : int (default is 300)
-        Maximum number of interations.
+    n_iter : int, optional
+        Maximum number of iterations. Default is 300.
 
-    eps : float (default is 1.e-3)
-        Stop the algorithm if w has converged.
+    eps : float, optional
+        Stop the algorithm if w has converged. Default is 1.e-3.
 
-    alpha_1 : float (default is 1.e-6)
-        Hyper-parameter : shape parameter for the Gamma distribution prior over
-        the alpha parameter.
+    alpha_1 : float, optional
+        Hyper-parameter : shape parameter for the Gamma distribution prior
+        over the alpha parameter. Default is 1.e-6.
 
-    alpha_2 : float (default is 1.e-6)
-        Hyper-parameter : inverse scale parameter (rate parameter) for the Gamma
-        distribution prior over the alpha parameter.
+    alpha_2 : float, optional
+        Hyper-parameter : inverse scale parameter (rate parameter) for the
+        Gamma distribution prior over the alpha parameter.
+        Default is 1.e-6.
 
-    lambda_1 : float (default is 1.e-6)
-        Hyper-parameter : shape parameter for the Gamma distribution prior over
-        the lambda parameter.
+    lambda_1 : float, optional
+        Hyper-parameter : shape parameter for the Gamma distribution prior
+        over the lambda parameter. Default is 1.e-6.
 
-    lambda_2 : float (default is 1.e-6)
-        Hyper-parameter : inverse scale parameter (rate parameter) for the Gamma
-        distribution prior over the lambda parameter.
+    lambda_2 : float, optional
+        Hyper-parameter : inverse scale parameter (rate parameter) for the
+        Gamma distribution prior over the lambda parameter.
+        Default is 1.e-6.
 
-    compute_score : boolean (default is False)
+    compute_score : boolean, optional
         If True, compute the objective function at each step of the model.
+        Default is False.
 
-    fit_intercept : boolean (default is True)
+    fit_intercept : boolean, optional
         wether to calculate the intercept for this model. If set
         to false, no intercept will be used in calculations
         (e.g. data is expected to be already centered).
+        Default is True.
 
     Attributes
     ----------
-    coef_ : numpy array of shape (nb_features)
-        Coefficients of the regression model (mean of the weights
-        distribution.)
+    `coef_` : array, shape = (n_features)
+        Coefficients of the regression model (mean of distribution)
 
-    alpha_ : float
+    `alpha_` : float
        estimated precision of the noise.
 
-    lambda_ : numpy array of shape (nb_features)
+    `lambda_` : array, shape = (n_features)
        estimated precisions of the weights.
 
-    score_ : float
+    `scores_` : float
        if computed, value of the objective function (to be maximized)
 
     Methods
@@ -82,50 +85,27 @@ class BayesianRidge(LinearModel):
     predict(X) : array
         Predict using the model.
-
     Examples
     --------
-
+    >>> from scikits.learn import glm
+    >>> clf = glm.BayesianRidge()
+    >>> clf.fit([[0,0], [1, 1], [2, 2]], [0, 1, 2])
+    BayesianRidge(n_iter=300, verbose=False, lambda_1=1e-06, lambda_2=1e-06,
+           fit_intercept=True, eps=0.001, alpha_2=1e-06, alpha_1=1e-06,
+           compute_score=False)
+    >>> print clf.coef_
+    [ 0.49999975  0.49999975]
+    >>> print clf.intercept_
+    5.02041075379e-07
+
+    Notes
+    -----
+    See examples/glm/plot_bayesian_ridge.py for an example.
     """
-
     def __init__(self, n_iter=300, eps=1.e-3, alpha_1 = 1.e-6, alpha_2 = 1.e-6,
                  lambda_1=1.e-6, lambda_2=1.e-6, compute_score=False,
-                 fit_intercept=True):
-        """
-        Parameters
-        ----------
-        n_iter : int (default is 300)
-            Maximum number of interations.
-
-        eps : float (default is 1.e-3)
-            Stop the algorithm if w has converged.
-
-        alpha_1 : float (default is 1.e-6)
-            Hyper-parameter : shape parameter for the Gamma distribution prior
-            over the alpha parameter.
-
-        alpha_2 : float (default is 1.e-6)
-            Hyper-parameter : inverse scale parameter (rate parameter) for the
-            Gamma distribution prior over the alpha parameter.
-
-        lambda_1 : float (default is 1.e-6)
-            Hyper-parameter : shape parameter for the Gamma distribution prior
-            over the lambda parameter.
-
-        lambda_2 : float (default is 1.e-6)
-            Hyper-parameter : inverse scale parameter (rate parameter) for the
-            Gamma distribution prior over the lambda parameter.
-
-        compute_score : boolean (default is False)
-            If True, compute the objective function at each step of the model.
-
-        fit_intercept : boolean (default is True)
-            wether to calculate the intercept for this model. If set
-            to false, no intercept will be used in calculations
-            (e.g. data is expected to be already centered).
- - """ + fit_intercept=True, verbose=False): self.n_iter = n_iter self.eps = eps self.alpha_1 = alpha_1 @@ -134,14 +114,16 @@ def __init__(self, n_iter=300, eps=1.e-3, alpha_1 = 1.e-6, alpha_2 = 1.e-6, self.lambda_2 = lambda_2 self.compute_score = compute_score self.fit_intercept = fit_intercept + self.verbose = verbose - def fit(self, X, Y, verbose=False, **params): - """ + def fit(self, X, y, **params): + """Fit the model + Parameters ---------- X : numpy array of shape [n_samples,n_features] Training data - Y : numpy array of shape [n_samples] + y : numpy array of shape [n_samples] Target values Returns @@ -150,17 +132,26 @@ def fit(self, X, Y, verbose=False, **params): """ self._set_params(**params) X = np.asanyarray(X, dtype=np.float) - Y = np.asanyarray(Y, dtype=np.float) - X, Y, Xmean, Ymean = self._center_data (X, Y) + y = np.asanyarray(y, dtype=np.float) + X, y, Xmean, ymean = self._center_data(X, y) n_samples, n_features = X.shape ### Initialization of the values of the parameters - self.alpha_ = 1./np.var(Y) - self.lambda_ = 1. - XT_Y = np.dot(X.T, Y) + alpha_ = 1. / np.var(y) + lambda_ = 1. + + verbose = self.verbose + lambda_1 = self.lambda_1 + lambda_2 = self.lambda_2 + alpha_1 = self.alpha_1 + alpha_2 = self.alpha_2 + + self.scores_ = list() + coef_old_ = None + + XT_y = np.dot(X.T, y) U, S, Vh = linalg.svd(X, full_matrices=False) eigen_vals_ = S**2 - self.all_score_ = [] ### Convergence loop of the bayesian ridge regression for iter_ in range(self.n_iter): @@ -170,79 +161,58 @@ def fit(self, X, Y, verbose=False, **params): # coef_ = sigma_^-1 * XT * y if n_samples > n_features: coef_ = np.dot(Vh.T, - Vh / (eigen_vals_ + self.lambda_ / self.alpha_)[:,None]) - coef_ = np.dot(coef_, XT_Y) + Vh / (eigen_vals_ + lambda_ / alpha_)[:,None]) + coef_ = np.dot(coef_, XT_y) if self.compute_score: logdet_sigma_ = - np.sum( - np.log(self.lambda_ + self.alpha_* eigen_vals_)) + np.log(lambda_ + alpha_* eigen_vals_)) else: - coef_ = np.dot(X.T, np.dot(U / - (eigen_vals_ + self.lambda_ / self.alpha_)[None,:], - U.T)) - coef_ = np.dot(coef_, Y) + coef_ = np.dot(X.T, np.dot( + U / (eigen_vals_ + lambda_ / alpha_)[None,:], U.T)) + coef_ = np.dot(coef_, y) if self.compute_score: - logdet_sigma_ = self.lambda_ * np.ones(n_features) - logdet_sigma_[:n_samples] += self.alpha_ * eigen_vals_ + logdet_sigma_ = lambda_ * np.ones(n_features) + logdet_sigma_[:n_samples] += alpha_ * eigen_vals_ logdet_sigma_ = - np.sum(np.log(logdet_sigma_)) - if self.compute_score: - self.logdet_sigma_ = logdet_sigma_ - ### Update alpha and lambda - self.rmse_ = np.sum((Y - np.dot(X, coef_))**2) - self.gamma_ = np.sum((self.alpha_ * eigen_vals_) \ - / (self.lambda_ + self.alpha_ * eigen_vals_)) - self.lambda_ = (self.gamma_ + 2*self.lambda_1) \ - / (np.sum(coef_**2) + 2*self.lambda_2) - self.alpha_ = (n_samples - self.gamma_ + 2*self.alpha_1) \ - / (self.rmse_ + 2*self.alpha_2) - - self.coef_ = coef_ + rmse_ = np.sum((y - np.dot(X, coef_))**2) + gamma_ = np.sum((alpha_ * eigen_vals_) \ + / (lambda_ + alpha_ * eigen_vals_)) + lambda_ = (gamma_ + 2 * lambda_1) \ + / (np.sum(coef_**2) + 2 * lambda_2) + alpha_ = (n_samples - gamma_ + 2 * alpha_1) \ + / (rmse_ + 2 * alpha_2) ### Compute the objective function if self.compute_score: - self.all_score_.append(self.objective_function(X)) + s = lambda_1 * log(lambda_) - lambda_2 * lambda_ + s += alpha_1 * log(alpha_) - alpha_2 * alpha_ + s += 0.5 * n_features * log(lambda_) \ + + 0.5 * n_samples * log(alpha_) \ + - 0.5 * alpha_ * rmse_ \ + - 0.5 * (lambda_ * np.sum(coef_**2)) \ 
+                    - 0.5 * logdet_sigma_ \
+                    - 0.5 * n_samples * log(2 * np.pi)
+                self.scores_.append(s)
 
             ### Check for convergence
-            if (iter_ != 0 and np.sum(np.abs(coef_old_ - coef_)) < self.eps):
+            if iter_ != 0 and np.sum(np.abs(coef_old_ - coef_)) < self.eps:
                 if verbose:
                     print "Convergence after ", str(iter_), " iterations"
                 break
             coef_old_ = np.copy(coef_)
 
-        self._set_intercept(Xmean, Ymean)
+        self.alpha_ = alpha_
+        self.lambda_ = lambda_
+        self.coef_ = coef_
+
+        self._set_intercept(Xmean, ymean)
         # Store explained variance for __str__
-        self.explained_variance_ = self._explained_variance(X, Y)
+        self.explained_variance_ = self._explained_variance(X, y)
         return self
 
-    def objective_function(self, X):
-        """
-        Compute the objective function.
-
-        Parameters
-        ----------
-        X : array-like, shape = [n_samples, n_features]
-            Training vector, where n_samples in the number of samples and
-            n_features is the number of features.
-
-        Returns
-        -------
-        score_ : value of the objective function (to be maximized)
-        """
-        score_ = self.lambda_1 * np.log(self.lambda_) - self.lambda_2 \
-                        * self.lambda_
-        score_ += self.alpha_1 * np.log(self.alpha_) - self.alpha_2 \
-                        * self.alpha_
-        score_ += 0.5 * X.shape[1] * np.log(self.lambda_) \
-                + 0.5 * X.shape[0] * np.log(self.alpha_) \
-                - 0.5 * self.alpha_ * self.rmse_ \
-                - 0.5 * (self.lambda_ * np.sum(self.coef_**2)) \
-                - 0.5 * self.logdet_sigma_ \
-                - 0.5 * X.shape[0] * np.log(2*np.pi)
-        return score_
-
-
 ###############################################################################
 # ARD (Automatic Relevance Determination) regression
@@ -258,62 +228,66 @@ class ARDRegression(LinearModel):
 
     Parameters
     ----------
-    X : numpy array of shape (length,features)
+    X : array, shape = (n_samples, n_features)
         Training vectors.
-    Y : numpy array of shape (length)
+    y : array, shape = (n_samples)
         Target values for training vectors
 
-    n_iter : int (default is 300)
-        Maximum number of interations.
+    n_iter : int, optional
+        Maximum number of iterations. Default is 300.
 
-    eps : float (default is 1.e-3)
-        Stop the algorithm if w has converged.
+    eps : float, optional
+        Stop the algorithm if w has converged. Default is 1.e-3.
 
-    alpha_1 : float (default is 1.e-6)
-        Hyper-parameter : shape parameter for the Gamma distribution prior over
-        the alpha parameter.
+    alpha_1 : float, optional
+        Hyper-parameter : shape parameter for the Gamma distribution prior
+        over the alpha parameter. Default is 1.e-6.
 
-    alpha_2 : float (default is 1.e-6)
-        Hyper-parameter : inverse scale parameter (rate parameter) for the Gamma
-        distribution prior over the alpha parameter.
+    alpha_2 : float, optional
+        Hyper-parameter : inverse scale parameter (rate parameter) for the
+        Gamma distribution prior over the alpha parameter. Default is 1.e-6.
 
-    lambda_1 : float (default is 1.e-6)
-        Hyper-parameter : shape parameter for the Gamma distribution prior over
-        the lambda parameter.
+    lambda_1 : float, optional
+        Hyper-parameter : shape parameter for the Gamma distribution prior
+        over the lambda parameter. Default is 1.e-6.
 
-    lambda_2 : float (default is 1.e-6)
-        Hyper-parameter : inverse scale parameter (rate parameter) for the Gamma
-        distribution prior over the lambda parameter.
+    lambda_2 : float, optional
+        Hyper-parameter : inverse scale parameter (rate parameter) for the
+        Gamma distribution prior over the lambda parameter. Default is 1.e-6.
 
-    compute_score : boolean (default is False)
+    compute_score : boolean, optional
         If True, compute the objective function at each step of the model.
+        Default is False.
 
-    threshold_lambda : float (default is 1.e+4)
+    threshold_lambda : float, optional
         threshold for removing (pruning) weights with high precision from
-        the computation.
+        the computation. Default is 1.e+4.
 
-    fit_intercept : boolean (default is True)
+    fit_intercept : boolean, optional
         wether to calculate the intercept for this model. If set
         to false, no intercept will be used in calculations
         (e.g. data is expected to be already centered).
+        Default is True.
+
+    verbose : boolean, optional
+        Verbose mode when fitting the model. Default is False.
 
     Attributes
     ----------
-    coef_ : numpy array of shape (nb_features)
-        Coefficients of the regression model (mean of the weights
-        distribution.)
+    `coef_` : array, shape = (n_features)
+        Coefficients of the regression model (mean of distribution)
 
-    alpha_ : float
        estimated precision of the noise.
 
-    lambda_ : numpy array of shape (nb_features)
+    `lambda_` : array, shape = (n_features)
        estimated precisions of the weights.
 
-    sigma_ : numpy array of shape (nb_features,nb_features)
+    `sigma_` : array, shape = (n_features, n_features)
        estimated variance-covariance matrix of the weights
 
-    score_ : float
+    `scores_` : float
        if computed, value of the objective function (to be maximized)
 
     Methods
@@ -324,51 +298,27 @@ class ARDRegression(LinearModel):
     predict(X) : array
         Predict using the model.
-
     Examples
     --------
+    >>> from scikits.learn import glm
+    >>> clf = glm.ARDRegression()
+    >>> clf.fit([[0,0], [1, 1], [2, 2]], [0, 1, 2])
+    ARDRegression(n_iter=300, verbose=False, lambda_1=1e-06, lambda_2=1e-06,
+            fit_intercept=True, eps=0.001, threshold_lambda=10000.0,
+            alpha_2=1e-06, alpha_1=1e-06, compute_score=False)
+    >>> print clf.coef_
+    [ 0.49986097  0.49986097]
+    >>> print clf.intercept_
+    0.000278067476205
+
+    Notes
+    -----
+    See examples/glm/plot_ard.py for an example.
     """
-    def __init__(self, n_iter=300, eps=1.e-3, alpha_1 = 1.e-6, alpha_2 = 1.e-6,
-                 lambda_1 = 1.e-6, lambda_2 = 1.e-6, compute_score = False,
-                 threshold_lambda = 1.e+4, fit_intercept = True):
-        """
-        Parameters
-        ----------
-        n_iter : int (default is 300)
-            Maximum number of interations.
-
-        eps : float (default is 1.e-3)
-            Stop the algorithm if w has converged.
-
-        alpha_1 : float (default is 1.e-6)
-            Hyper-parameter : shape parameter for the Gamma distribution prior
-            over the alpha parameter.
-
-        alpha_2 : float (default is 1.e-6)
-            Hyper-parameter : inverse scale parameter (rate parameter) for the
-            Gamma distribution prior over the alpha parameter.
-
-        lambda_1 : float (default is 1.e-6)
-            Hyper-parameter : shape parameter for the Gamma distribution prior
-            over the lambda parameter.
-
-        lambda_2 : float (default is 1.e-6)
-            Hyper-parameter : inverse scale parameter (rate parameter) for the
-            Gamma distribution prior over the lambda parameter.
-
-        compute_score : boolean (default is False)
-            If True, compute the objective function at each step of the model.
-
-        threshold_lambda : float (default is 1.e+4)
-            threshold for removing (pruning) weights with high precision from
-            the computation.
-
-        fit_intercept : boolean (default is True)
-            wether to calculate the intercept for this model. If set
-            to false, no intercept will be used in calculations
-            (e.g. data is expected to be already centered).
- """ + def __init__(self, n_iter=300, eps=1.e-3, alpha_1=1.e-6, alpha_2=1.e-6, + lambda_1=1.e-6, lambda_2 = 1.e-6, compute_score=False, + threshold_lambda=1.e+4, fit_intercept=True, verbose=False): self.n_iter = n_iter self.eps = eps self.fit_intercept = fit_intercept @@ -378,19 +328,20 @@ def __init__(self, n_iter=300, eps=1.e-3, alpha_1 = 1.e-6, alpha_2 = 1.e-6, self.lambda_2 = lambda_2 self.compute_score = compute_score self.threshold_lambda = threshold_lambda + self.verbose = verbose + def fit(self, X, y, **params): + """Fit the ARDRegression model according to the given training data + and parameters. - def fit(self, X, Y, **params): - """ - Fit the ARDRegression model according to the given training data and - parameters. + Iterative procedure to maximize the evidence Parameters ---------- X : array-like, shape = [n_samples, n_features] Training vector, where n_samples in the number of samples and n_features is the number of features. - Y : array, shape = [n_samples] + y : array, shape = [n_samples] Target values (integers) Returns @@ -400,119 +351,78 @@ def fit(self, X, Y, **params): self._set_params(**params) X = np.asanyarray(X, dtype=np.float) - Y = np.asanyarray(Y, dtype=np.float) + y = np.asanyarray(y, dtype=np.float) n_samples, n_features = X.shape + coef_ = np.zeros(n_features) - X, Y, Xmean, Ymean = self._center_data (X, Y) - - ### Initialization of the values of the parameters - self.alpha_ = 1./np.var(Y) - self.lambda_ = np.ones(n_features) - self.all_score_ = [] + X, y, Xmean, ymean = self._center_data(X, y) ### Launch the convergence loop - self.evidence_maximization(X, Y) - - self._set_intercept(Xmean, Ymean) - # Store explained variance for __str__ - self.explained_variance_ = self._explained_variance(X, Y) - return self - - def evidence_maximization(self, X, Y, verbose=False): - """ - Iterative procedure for estimating the ARDRegression model according to - the given training data and parameters. + keep_lambda = np.ones(n_features, dtype=bool) - Parameters - ---------- - X : array-like, shape = [n_samples, n_features] - Training vector, where n_samples in the number of samples and - n_features is the number of features. - Y : array, shape = [n_samples] - Target values (integers) + lambda_1 = self.lambda_1 + lambda_2 = self.lambda_2 + alpha_1 = self.alpha_1 + alpha_2 = self.alpha_2 + verbose = self.verbose - Attributes - ---------- - keep_lambda : boolean numpy array of shape (nb_features) - Lambda under a given threshold, to be keep for the computation. - Avoid divergence when lambda is to high - """ + ### Initialization of the values of the parameters + alpha_ = 1. / np.var(y) + lambda_ = np.ones(n_features) - n_samples, n_features = X.shape - coef_ = np.zeros(n_features) - keep_lambda = np.ones(n_features,dtype=bool) + self.scores_ = list() + coef_old_ = None ### Iterative procedure of ARDRegression for iter_ in range(self.n_iter): - ### Compute mu and sigma (using Woodbury matrix identity) - self.sigma_ = linalg.pinv(np.eye(n_samples)/self.alpha_ + + sigma_ = linalg.pinv(np.eye(n_samples) / alpha_ + np.dot(X[:,keep_lambda] * - np.reshape(1./self.lambda_[keep_lambda],[1,-1]), + np.reshape(1. 
/ lambda_[keep_lambda], [1, -1]), X[:,keep_lambda].T)) - self.sigma_ = np.dot(self.sigma_,X[:,keep_lambda] - * np.reshape(1./self.lambda_[keep_lambda], - [1,-1])) - self.sigma_ = - np.dot(np.reshape(1./self.lambda_[keep_lambda], - [-1,1]) * X[:,keep_lambda].T ,self.sigma_) - self.sigma_.flat[::(self.sigma_.shape[1]+1)] += \ - 1./self.lambda_[keep_lambda] - coef_[keep_lambda] = self.alpha_ \ - * np.dot(self.sigma_,np.dot(X[:,keep_lambda].T, - Y)) + sigma_ = np.dot(sigma_, X[:,keep_lambda] + * np.reshape(1. / lambda_[keep_lambda], [1, -1])) + sigma_ = - np.dot(np.reshape( 1. / lambda_[keep_lambda], [-1, 1]) + * X[:,keep_lambda].T, sigma_) + sigma_.flat[::(sigma_.shape[1] + 1)] += \ + 1. / lambda_[keep_lambda] + coef_[keep_lambda] = alpha_ * np.dot( + sigma_,np.dot(X[:,keep_lambda].T, y)) ### Update alpha and lambda - self.rmse_ = np.sum((Y - np.dot(X, coef_))**2) - self.gamma_ = 1. - self.lambda_[keep_lambda]\ - *np.diag(self.sigma_) - self.lambda_[keep_lambda] = (self.gamma_ + 2*self.lambda_1)\ - /((coef_[keep_lambda])**2 + 2*self.lambda_2) - self.alpha_ = (n_samples - self.gamma_.sum() + 2*self.alpha_1)\ - /(self.rmse_ + 2*self.alpha_2) + rmse_ = np.sum((y - np.dot(X, coef_))**2) + gamma_ = 1. - lambda_[keep_lambda] * np.diag(sigma_) + lambda_[keep_lambda] = (gamma_ + 2. * lambda_1) \ + / ((coef_[keep_lambda])**2 + 2. * lambda_2) + alpha_ = (n_samples - gamma_.sum() + 2. * alpha_1) \ + / (rmse_ + 2. * alpha_2) ### Prune the weights with a precision over a threshold - keep_lambda = self.lambda_ < self.threshold_lambda + keep_lambda = lambda_ < self.threshold_lambda coef_[keep_lambda == False] = 0 - self.coef_ = coef_ - ### Compute the objective function if self.compute_score: - self.all_score_.append(self.objective_function(X)) + s = (lambda_1 * np.log(lambda_) - lambda_2 * lambda_).sum() + s += alpha_1 * log(alpha_) - alpha_2 * alpha_ + s += 0.5 * (fast_logdet(sigma_) + n_samples * log(alpha_) + + np.sum(np.log(lambda_))) + s -= 0.5 * (alpha_ * rmse_ + (lambda_ * coef_**2).sum()) + self.scores_.append(s) ### Check for convergence - if iter_ != 0 and np.sum(np.abs(coef_old_ - coef_)) < self.eps: + if iter_ > 0 and np.sum(np.abs(coef_old_ - coef_)) < self.eps: if verbose: - print "Convergence after %s iterations" % iter_ + print "Converged after %s iterations" % iter_ break coef_old_ = np.copy(coef_) + self.coef_ = coef_ + self.alpha_ = alpha_ + self.sigma_ = sigma_ - def objective_function(self, X): - """ - Compute the objective function. - - Parameters - ---------- - X : array-like, shape = [n_samples, n_features] - Training vector, where n_samples in the number of samples and - n_features is the number of features. 
-
-        Returns
-        -------
-        score_ : value of the objective function (to be maximized)
-        """
-
-        score_ = (self.lambda_1 * np.log(self.lambda_) - self.lambda_2\
-                        * self.lambda_).sum()
-        score_ += self.alpha_1 * np.log(self.alpha_) - self.alpha_2\
-                        * self.alpha_
-        score_ += 0.5 * (fast_logdet(self.sigma_) + X.shape[0]\
-                        * np.log(self.alpha_) + np.sum(np.log(self.lambda_)))
-        score_ -= 0.5 * (self.alpha_ * self.rmse_\
-                        + (self.lambda_ * self.coef_**2).sum())
-        return score_
-
-
-
+        self._set_intercept(Xmean, ymean)
+        # Store explained variance for __str__
+        self.explained_variance_ = self._explained_variance(X, y)
+        return self
diff --git a/scikits/learn/glm/benchmarks/bench_bayes.py b/scikits/learn/glm/benchmarks/bench_bayes.py
index 4af34aa16ce1a..8a7bd40424175 100644
--- a/scikits/learn/glm/benchmarks/bench_bayes.py
+++ b/scikits/learn/glm/benchmarks/bench_bayes.py
@@ -16,17 +16,17 @@
     n_iter = 20
 
-    time_ridge = np.empty (n_iter)
-    time_ols = np.empty (n_iter)
-    time_lasso = np.empty (n_iter)
+    time_ridge = np.empty(n_iter)
+    time_ols = np.empty(n_iter)
+    time_lasso = np.empty(n_iter)
 
     dimensions = 10 * np.arange(n_iter)
-    n, m = 100, 100
+    n_samples, n_features = 100, 100
 
-    X = np.random.randn (n, m)
-    Y = np.random.randn (n)
+    X = np.random.randn(n_samples, n_features)
+    y = np.random.randn(n_samples)
 
     start = datetime.now()
     ridge = glm.BayesianRidge()
-    ridge.fit (X, Y)
+    ridge.fit(X, y)
diff --git a/scikits/learn/glm/tests/test_bayes.py b/scikits/learn/glm/tests/test_bayes.py
index 9ebbf23bd5d5f..11a3e2c8394c8 100644
--- a/scikits/learn/glm/tests/test_bayes.py
+++ b/scikits/learn/glm/tests/test_bayes.py
@@ -25,14 +25,14 @@ def test_bayesian_on_diabetes():
     # Test with more samples than features
     clf.fit(X, y)
     # Test that scores are increasing at each iteration
-    assert_array_equal(np.diff(clf.all_score_) > 0, True)
+    assert_array_equal(np.diff(clf.scores_) > 0, True)
 
     # Test with more features than samples
     X = X[:5,:]
     y = y[:5]
     clf.fit(X, y)
     # Test that scores are increasing at each iteration
-    assert_array_equal(np.diff(clf.all_score_) > 0, True)
+    assert_array_equal(np.diff(clf.scores_) > 0, True)
 
 def test_toy_bayesian_ridge_object():
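
Appendix (not part of the patch): a minimal usage sketch of the refactored estimators, reflecting the API after this diff: lowercase `y` targets, the new `verbose` constructor argument, and the `scores_` attribute that replaces `all_score_`. The toy data, seed, and variable names below are illustrative assumptions, not code taken from the repository.

    import numpy as np
    from scikits.learn.glm import BayesianRidge, ARDRegression

    # Toy problem: a few relevant features plus Gaussian noise (illustrative only).
    np.random.seed(0)
    X = np.random.randn(40, 10)
    w = np.zeros(10)
    w[:3] = [1., 2., -1.]
    y = np.dot(X, w) + 0.1 * np.random.randn(40)

    clf = BayesianRidge(compute_score=True, verbose=True)
    clf.fit(X, y)                  # targets are passed as lowercase y
    print clf.coef_, clf.alpha_    # estimated weights and noise precision
    print len(clf.scores_)         # per-iteration objective values (was all_score_)

    ard = ARDRegression(compute_score=True)
    ard.fit(X, y)
    print ard.coef_                # weights with precision above threshold_lambda are pruned to zero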