Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multi-dimensional output for CP regression #252

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 36 additions & 24 deletions tensorly/regression/cp_regression.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import numpy as np

from .. import backend as T
from ..base import partial_tensor_to_vec, partial_unfold
from ..tenalg import khatri_rao
from ..cp_tensor import cp_to_tensor, cp_to_vec
from .. import backend as T
from ..tenalg import khatri_rao
from ..utils import DefineDeprecated


# Author: Jean Kossaifi

# License: BSD 3 clause
Expand All @@ -21,8 +23,8 @@ class CPRegressor:
rank of the CP decomposition of the regression weights
tol : float
convergence value
reg_W : int, optional, default is 1
regularisation on the weights
reg_W : float, optional, default is 1
l2 regularisation constant for the regression weights (:math:`reg_W * \sum_i ||factors[i]||_F^2`)
n_iter_max : int, optional, default is 100
maximum number of iteration
random_state : None, int or RandomState, optional, default is None
Expand Down Expand Up @@ -69,9 +71,9 @@ def fit(self, X, y):

Parameters
----------
X : ndarray
tensor data of shape (n_samples, N1, ..., NS)
y : 1D-array of shape (n_samples, )
X : tensor of shape (n_samples, I_1, ..., I_p)
tensor data
y : tensor of shape (n_samples, O_1, ..., O_q)
labels associated with each sample

Returns
Expand All @@ -80,12 +82,13 @@ def fit(self, X, y):
"""
rng = T.check_random_state(self.random_state)

# Initialise randomly the weights
# Initialise the weights randomly
W = []

for i in range(
1, T.ndim(X)
): # The first dimension of X is the number of samples
W.append(T.tensor(rng.randn(X.shape[i], self.weight_rank), **T.context(X)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check that I resolved the merge conflict correctly here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, I definitely think I messed up the merge here.

W.append(T.tensor(rng.randn(y.shape[i], self.weight_rank), **T.context(X)))

# Norm of the weight tensor at each iteration
norm_W = []
Expand All @@ -95,19 +98,23 @@ def fit(self, X, y):

# Optimise each factor of W
for i in range(len(W)):
phi = T.reshape(
T.dot(
partial_unfold(X, i, skip_begin=1), khatri_rao(W, skip_matrix=i)
),
(X.shape[0], -1),
)
inv_term = T.dot(T.transpose(phi), phi) + self.reg_W * T.tensor(
np.eye(phi.shape[1]), **T.context(X)
)
W[i] = T.reshape(
T.solve(inv_term, T.dot(T.transpose(phi), y)),
(X.shape[i + 1], self.weight_rank),
)
if i < T.ndim(X) - 1:
X_unfolded = partial_unfold(X, i, skip_begin=1)
phi = T.dot(X_unfolded, T.reshape(khatri_rao(W, skip_matrix=i), (X_unfolded.shape[-1], -1)))
phi = T.transpose(T.reshape(phi, (X.shape[0], X.shape[i + 1], -1, self.weight_rank)), (0, 2, 1, 3))
phi = T.reshape(phi, (-1, X.shape[i + 1] * self.weight_rank))
y_reshaped = T.reshape(y, (-1,))
inv_term = T.dot(T.transpose(phi), phi) + self.reg_W * T.tensor(np.eye(phi.shape[1]), **T.context(X))
W[i] = T.reshape(
T.solve(inv_term, T.dot(T.transpose(phi), y_reshaped)),
(-1, self.weight_rank))
else:
X_unfolded = partial_tensor_to_vec(X, skip_begin=1)
phi = T.dot(X_unfolded, T.reshape(khatri_rao(W, skip_matrix=i), (X_unfolded.shape[-1], -1)))
phi = T.reshape(phi, (-1, self.weight_rank))
y_reshaped = T.reshape(T.moveaxis(y, i - T.ndim(X) + 2, -1), (-1, y.shape[i - T.ndim(X) + 2]))
inv_term = T.dot(T.transpose(phi), phi) + self.reg_W * T.tensor(np.eye(phi.shape[1]), **T.context(X))
W[i] = T.transpose(T.solve(inv_term, T.dot(T.transpose(phi), y_reshaped)))

weight_tensor_ = cp_to_tensor((weights, W))
norm_W.append(T.norm(weight_tensor_, 2))
Expand Down Expand Up @@ -136,9 +143,14 @@ def predict(self, X):
Parameters
----------
X : ndarray
tensor data of shape (n_samples, N1, ..., NS)
tensor data of shape (n_samples, I_1, ..., I_p)
"""
return T.dot(partial_tensor_to_vec(X), self.vec_W_)
out_shape = (-1, *self.weight_tensor_.shape[T.ndim(X) - 1:])
if T.ndim(self.weight_tensor_) > T.ndim(X) - 1:
weight_shape = (-1, int(np.prod(self.weight_tensor_.shape[T.ndim(X) - 1:])))
else:
weight_shape = (-1,)
return T.reshape(T.dot(partial_tensor_to_vec(X), T.reshape(self.weight_tensor_, weight_shape)), out_shape)


KruskalRegressor = DefineDeprecated("KruskalRegressor", CPRegressor)
20 changes: 20 additions & 0 deletions tensorly/regression/tests/test_cp_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ...base import tensor_to_vec, partial_tensor_to_vec
from ...metrics.regression import RMSE
from ... import backend as T
from ...random import random_cp
from ...testing import assert_


Expand Down Expand Up @@ -52,3 +53,22 @@ def test_CPRegressor():
estimator.weight_rank == 5,
msg="set_params did not correctly set the given parameters",
)


def test_multidim_CPRegressor():
tol = 0.005
rng = T.check_random_state(1234)

regression_weights = random_cp(shape=(12, 5, 4, 3, 2), rank=4, full=True, random_state=rng)
X = T.randn((1200, 12, 5, 4), seed=rng)
y = T.reshape(T.dot(partial_tensor_to_vec(X), T.reshape(regression_weights, (-1, 3*2))), (-1, 3, 2))
X_train = X[:1000]
X_test = X[1000:]
y_train = y[:1000]
y_test = y[1000:]

estimator = CPRegressor(weight_rank=20, tol=1e-8, reg_W=0., n_iter_max=200, verbose=True)
estimator.fit(X_train, y_train)
y_pred = estimator.predict(X_test)
error = RMSE(y_test, y_pred)
assert_(error <= tol, msg='CP Regressor : RMSE is too large, {} > {}'.format(error, tol))