
Commit

Seems to be working
aarmey committed Feb 13, 2021
1 parent 6f79580 commit f08bbba
Showing 7 changed files with 48 additions and 88 deletions.
18 changes: 14 additions & 4 deletions tensorly/backend/core.py
@@ -478,8 +478,7 @@ def sqrt(tensor):
        """
        raise NotImplementedError

-    @staticmethod
-    def norm(tensor, order=2, axis=None):
+    def norm(self, tensor, order=2, axis=None):
        """Computes the l-`order` norm of a tensor.
        Parameters
@@ -493,7 +492,18 @@ def norm(tensor, order=2, axis=None):
        float or tensor
            If `axis` is provided returns a tensor.
        """
-        raise NotImplementedError
+        # handle difference in default axis notation
+        if axis == ():
+            axis = None
+
+        if order == 'inf':
+            return self.max(self.abs(tensor), axis=axis)
+        if order == 1:
+            return self.sum(self.abs(tensor), axis=axis)
+        elif order == 2:
+            return self.sqrt(self.sum(tensor**2, axis=axis))
+        else:
+            return self.sum(self.abs(tensor)**order, axis=axis)**(1 / order)
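Note: each backend previously carried its own copy of this logic; hoisting it into the base class means any backend that supplies max, abs, sqrt, and sum inherits the same norm semantics, and the deletions further down remove the now-redundant copies. A quick sketch of what the shared code computes, using NumPy directly in place of the backend's registered methods:

    import numpy as np

    t = np.array([[3.0, -4.0], [0.0, 12.0]])
    np.sum(np.abs(t))      # order=1     -> 19.0
    np.sqrt(np.sum(t**2))  # order=2     -> 13.0 (Frobenius norm)
    np.max(np.abs(t))      # order='inf' -> 12.0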

@staticmethod
def dot(a, b):
@@ -919,7 +929,7 @@ def symeig_svd(self, matrix, n_eigenvecs=None, **kwargs):
        S = self.sqrt(S)
        U = self.dot(matrix, V) / self.reshape(S, (1, -1))

-        U, S, V = U[:, ::-1], S[::-1], self.transpose(V)[::-1, :]
+        U, S, V = self.flip(U, axis=1), self.flip(S), self.flip(self.transpose(V), axis=0)
        return U[:, :n_eigenvecs], S[:n_eigenvecs], V[:n_eigenvecs, :]
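Replacing the negative-stride slices with an explicit flip keeps symeig_svd backend-agnostic: PyTorch, for one, rejects tensor[::-1]-style negative steps, which is why a flip method is added to each backend below. Where slicing is supported, the two spellings agree, as a small NumPy check shows:

    import numpy as np

    S = np.array([1.0, 3.0, 2.0])
    U = np.arange(6.0).reshape(2, 3)
    assert np.array_equal(S[::-1], np.flip(S))
    assert np.array_equal(U[:, ::-1], np.flip(U, axis=1))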

index = Index()
18 changes: 0 additions & 18 deletions tensorly/backend/cupy_backend.py
@@ -45,24 +45,6 @@ def ndim(tensor):
    def clip(tensor, a_min=None, a_max=None):
        return cp.clip(tensor, a_min, a_max)

-    def norm(self, tensor, order=2, axis=None):
-        # handle difference in default axis notation
-        if axis == ():
-            axis = None
-
-        if order == 'inf':
-            res = cp.max(cp.abs(tensor), axis=axis)
-        elif order == 1:
-            res = cp.sum(cp.abs(tensor), axis=axis)
-        elif order == 2:
-            res = cp.sqrt(cp.sum(tensor**2, axis=axis))
-        else:
-            res = cp.sum(cp.abs(tensor)**order, axis=axis)**(1 / order)
-
-        if res.shape == ():
-            return self.to_numpy(res)
-        return res

    def solve(self, matrix1, matrix2):
        try:
            cp.linalg.solve(matrix1, matrix2)
17 changes: 1 addition & 16 deletions tensorly/backend/jax_backend.py
@@ -54,21 +54,6 @@ def ndim(tensor):
    def dot(a, b):
        return a.dot(b)

-    @staticmethod
-    def norm(tensor, order=2, axis=None):
-        # handle difference in default axis notation
-        if axis == ():
-            axis = None
-
-        if order == 'inf':
-            return np.max(np.abs(tensor), axis=axis)
-        if order == 1:
-            return np.sum(np.abs(tensor), axis=axis)
-        elif order == 2:
-            return np.sqrt(np.sum(tensor**2, axis=axis))
-        else:
-            return np.sum(np.abs(tensor)**order, axis=axis)**(1 / order)

    def kr(self, matrices, weights=None, mask=None):
        n_columns = matrices[0].shape[1]
        n_factors = len(matrices)
@@ -93,7 +78,7 @@ def sort(tensor, axis, descending = False):
        return np.sort(tensor, axis=axis)

for name in ['int64', 'int32', 'float64', 'float32', 'reshape', 'moveaxis',
-            'where', 'transpose', 'arange', 'ones', 'zeros',
+            'where', 'transpose', 'arange', 'ones', 'zeros', 'flip',
             'zeros_like', 'eye', 'kron', 'concatenate', 'max', 'min',
             'all', 'mean', 'sum', 'prod', 'sign', 'abs', 'sqrt', 'argmin',
             'argmax', 'stack', 'conj', 'diag', 'clip', 'einsum']:
17 changes: 1 addition & 16 deletions tensorly/backend/mxnet_backend.py
@@ -52,21 +52,6 @@ def ndim(tensor):
    def dot(a, b):
        return np.dot(a, b)

-    @staticmethod
-    def norm(tensor, order=2, axis=None):
-        # handle difference in default axis notation
-        if axis == ():
-            axis = None
-
-        if order == 'inf':
-            return np.max(np.abs(tensor), axis=axis)
-        if order == 1:
-            return np.sum(np.abs(tensor), axis=axis)
-        if order == 2:
-            return np.sqrt(np.sum(tensor**2, axis=axis))
-
-        return np.sum(np.abs(tensor)**order, axis=axis)**(1 / order)

    @staticmethod
    def clip(tensor, a_min=None, a_max=None):
        return np.clip(tensor, a_min, a_max)
@@ -117,7 +102,7 @@ def sort(tensor, axis, descending = False):

for name in ['int64', 'int32', 'float64', 'float32', 'reshape', 'moveaxis',
             'where', 'copy', 'transpose', 'arange', 'ones', 'zeros',
-            'zeros_like', 'eye', 'concatenate', 'max', 'min',
+            'zeros_like', 'eye', 'concatenate', 'max', 'min', 'flip',
             'all', 'mean', 'sum', 'prod', 'sign', 'abs', 'sqrt', 'argmin',
             'argmax', 'stack', 'diag', 'einsum']:
    MxnetBackend.register_method(name, getattr(np, name))
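Adding 'flip' to this list is what exposes the library's flip under the common backend name. A rough sketch of the registration pattern (class and method list abbreviated; the real register_method in tensorly may do more bookkeeping):

    import numpy as np

    class BackendSketch:
        @classmethod
        def register_method(cls, name, func):
            # expose the library function as a static method on the backend
            setattr(cls, name, staticmethod(func))

    for name in ['flip', 'max', 'abs', 'sqrt', 'sum']:
        BackendSketch.register_method(name, getattr(np, name))

    assert BackendSketch.flip is np.flip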
17 changes: 1 addition & 16 deletions tensorly/backend/numpy_backend.py
@@ -37,21 +37,6 @@ def clip(tensor, a_min=None, a_max=None):
    def dot(a, b):
        return a.dot(b)

-    @staticmethod
-    def norm(tensor, order=2, axis=None):
-        # handle difference in default axis notation
-        if axis == ():
-            axis = None
-
-        if order == 'inf':
-            return np.max(np.abs(tensor), axis=axis)
-        if order == 1:
-            return np.sum(np.abs(tensor), axis=axis)
-        elif order == 2:
-            return np.sqrt(np.sum(tensor**2, axis=axis))
-        else:
-            return np.sum(np.abs(tensor)**order, axis=axis)**(1 / order)

    def kr(self, matrices, weights=None, mask=None):
        n_columns = matrices[0].shape[1]
        n_factors = len(matrices)
@@ -76,7 +61,7 @@ def sort(tensor, axis, descending = False):
        return np.sort(tensor, axis=axis)

for name in ['int64', 'int32', 'float64', 'float32', 'reshape', 'moveaxis',
-            'where', 'copy', 'transpose', 'arange', 'ones', 'zeros',
+            'where', 'copy', 'transpose', 'arange', 'ones', 'zeros', 'flip',
             'zeros_like', 'eye', 'kron', 'concatenate', 'max', 'min',
             'all', 'mean', 'sum', 'prod', 'sign', 'abs', 'sqrt', 'argmin',
             'argmax', 'stack', 'conj', 'diag', 'einsum']:
40 changes: 22 additions & 18 deletions tensorly/backend/pytorch_backend.py
@@ -157,6 +157,16 @@ def sum(tensor, axis=None):
            return torch.sum(tensor)
        else:
            return torch.sum(tensor, dim=axis)

+    @staticmethod
+    def flip(tensor, axis=None):
+        if isinstance(axis, int):
+            axis = [axis]
+
+        if axis is None:
+            return torch.flip(tensor, dims=[i for i in range(tensor.ndim)])
+        else:
+            return torch.flip(tensor, dims=axis)
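torch.flip wants an explicit list of dimensions, so the wrapper normalizes an int axis into a list and treats axis=None as flipping every dimension. For example:

    import torch

    t = torch.arange(6).reshape(2, 3)
    torch.flip(t, dims=[0, 1])  # axis=None case: tensor([[5, 4, 3], [2, 1, 0]])
    torch.flip(t, dims=[1])     # axis=1 case:    tensor([[2, 1, 0], [5, 4, 3]])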

    @staticmethod
    def concatenate(tensors, axis=0):
@@ -174,28 +184,22 @@ def argmax(input, axis=None):
    def stack(arrays, axis=0):
        return torch.stack(arrays, dim=axis)

-    @staticmethod
-    def _reverse(tensor, axis=0):
-        """Reverses the elements along the specified dimension
+    def svd(self, X, full_matrices=True):
+        # The torch SVD has accuracy issues. Try again when torch.linalg is stable.
+        ctx = self.context(X)
+        X = self.to_numpy(X)

-        Parameters
-        ----------
-        tensor : tl.tensor
-        axis : int, default is 0
-            axis along which to reverse the ordering of the elements
+        U, S, V = np.linalg.svd(X, full_matrices=full_matrices)

-        Returns
-        -------
-        reversed_tensor : for a 1-D tensor, returns the equivalent of
-                          tensor[::-1] in NumPy
-        """
-        indices = torch.arange(tensor.shape[axis] - 1, -1, -1, dtype=torch.int64)
-        return tensor.index_select(axis, indices)
+        U = self.tensor(U, **ctx)
+        S = self.tensor(S, **ctx)
+        V = self.tensor(V, **ctx)
+
+        return U, S, V
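Routing the SVD through NumPy trades speed for LAPACK's more robust factorization, while the context round-trip restores dtype and device on the way back. A self-contained sketch of the same pattern (CPU tensors assumed; svd_via_numpy is an illustrative name, not the backend API):

    import numpy as np
    import torch

    def svd_via_numpy(X, full_matrices=True):
        U, S, V = np.linalg.svd(X.cpu().numpy(), full_matrices=full_matrices)
        back = lambda a: torch.tensor(a, dtype=X.dtype)  # restore the torch context
        return back(U), back(S), back(V)

    X = torch.randn(4, 3)
    U, S, V = svd_via_numpy(X, full_matrices=False)
    # np.linalg.svd returns V already transposed, so X == U @ diag(S) @ V
    assert torch.allclose(U @ torch.diag(S) @ V, X, atol=1e-5)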

    @staticmethod
-    def svd(matrix, full_matrices=True):
-        """Computes the standard SVD."""
-        return torch.svd(matrix, some=full_matrices)
+    def eigh(tensor):
+        return torch.symeig(tensor, eigenvectors=True)
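torch.symeig (later deprecated in favor of torch.linalg.eigh) returns the eigenvalues of a symmetric matrix in ascending order, plus eigenvectors when eigenvectors=True:

    import torch

    A = torch.tensor([[2.0, 1.0], [1.0, 2.0]])
    evals, evecs = torch.symeig(A, eigenvectors=True)
    assert torch.allclose(evals, torch.tensor([1.0, 3.0]))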

    @staticmethod
    def sort(tensor, axis, descending = False):
9 changes: 9 additions & 0 deletions tensorly/backend/tensorflow_backend.py
@@ -125,6 +125,15 @@ def sort(tensor, axis, descending = False):
            axis = -1

        return tf.sort(tensor, axis=axis, direction = direction)

+    def flip(self, tensor, axis=None):
+        if isinstance(axis, int):
+            axis = [axis]
+
+        if axis is None:
+            return tf.reverse(tensor, axis=[i for i in range(self.ndim(tensor))])
+        else:
+            return tf.reverse(tensor, axis=axis)
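TensorFlow's counterpart is tf.reverse, which likewise takes a list of axes, so the wrapper applies the same int-to-list normalization:

    import tensorflow as tf

    t = tf.reshape(tf.range(6), (2, 3))
    tf.reverse(t, axis=[0, 1])  # [[5, 4, 3], [2, 1, 0]]
    tf.reverse(t, axis=[1])     # [[2, 1, 0], [5, 4, 3]]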

    def svd(self, matrix, full_matrices):
        """ Correct for the atypical return order of tf.linalg.svd. """
