Skip to content

Commit

Permalink
followed Jean's recommendation
Browse files Browse the repository at this point in the history
  • Loading branch information
amanj120 committed Jul 14, 2020
1 parent 92931b7 commit 99c6df2
Show file tree
Hide file tree
Showing 10 changed files with 96 additions and 71 deletions.
13 changes: 2 additions & 11 deletions tensorly/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
'pytorch':'PyTorchBackend',
'tensorflow':'TensorflowBackend',
'cupy':'CupyBackend',
'jax': 'JaxBackend',
'tensorflow.sparse': 'TensorflowSparseBackend'}
'jax': 'JaxBackend'}

_LOADED_BACKENDS = {}
_LOCAL_STATE = threading.local()
Expand Down Expand Up @@ -51,16 +50,8 @@ def register_backend(backend_name):
If `backend_name` does not correspond to one listed
in `_KNOWN_BACKEND`
"""
module_list = {'numpy': 'tensorly.backend.numpy_backend',
'mxnet': 'tensorly.backend.mxnet_backend',
'pytorch': 'tensorly.backend.pytorch_backend',
'tensorflow': 'tensorly.backend.tensorflow_backend',
'cupy': 'tensorly.backend.cupy_backend',
'jax': 'tensorly.backend.jax_backend',
'tensorflow.sparse': 'tensorly.contrib.sparse.backend.tensorflow_backend'}

if backend_name in _KNOWN_BACKENDS:
module = importlib.import_module(module_list[backend_name])
module = importlib.import_module('tensorly.backend.{0}_backend'.format(backend_name))
backend = getattr(module, _KNOWN_BACKENDS[backend_name])()
_LOADED_BACKENDS[backend_name] = backend
else:
Expand Down
11 changes: 11 additions & 0 deletions tensorly/backend/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,17 @@ def conj(x, *args, **kwargs):
"""
raise NotImplementedError

@staticmethod
def values(tensor):
""" Returns all the non zero values of the tensor in COO
"""
raise NotImplementedError

@staticmethod
def indices(tensor):
""" Returns all the indices of non zero values of the tensor in COO
"""
raise NotImplementedError

@staticmethod
def sort(tensor, axis, descending = False):
Expand Down
6 changes: 3 additions & 3 deletions tensorly/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
raise ImportError(message) from error

import numpy as np

import tensorly as tl
from . import Backend


Expand Down Expand Up @@ -43,11 +43,11 @@ def to_numpy(tensor):

@staticmethod
def ndim(tensor):
return len(tensor.get_shape()._dims)
return len(tl.to_numpy(tensor).shape)

@staticmethod
def shape(tensor):
return tuple(tensor.shape.as_list())
return tl.to_numpy(tensor).shape

@staticmethod
def arange(start, stop=None, step=1, dtype=np.float32):
Expand Down
2 changes: 1 addition & 1 deletion tensorly/contrib/sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
arange, ones, zeros, zeros_like, eye,
clip, where, max, min, all, mean, sum,
prod, sign, abs, sqrt, norm, dot, kron,
kr, solve, qr, partial_svd)
kr, solve, qr, partial_svd, values, indices)

from .core import wrap

Expand Down
3 changes: 3 additions & 0 deletions tensorly/contrib/sparse/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def inner(*args, **kwargs):

return inner


values = dispatch_sparse(backend.values)
indices = dispatch_sparse(backend.indices)
tensor = dispatch_sparse(backend.tensor)
is_tensor = dispatch_sparse(backend.is_tensor)
context = dispatch_sparse(backend.context)
Expand Down
8 changes: 8 additions & 0 deletions tensorly/contrib/sparse/backend/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ def norm(tensor, order=2, axis=None):
else:
return np.sum(np.abs(tensor)**order, axis=axis)**(1 / order)

@staticmethod
def values(tensor):
return tensor.coords

@staticmethod
def indices(tensor):
return tensor.data

def dot(self, x, y):
if is_sparse(x) or is_sparse(y):
return sparse.dot(x, y)
Expand Down
9 changes: 8 additions & 1 deletion tensorly/contrib/sparse/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,19 @@ def is_sparse(x):


class TensorflowSparseBackend(Backend):
backend_name = 'tensorflow.sparse'
backend_name = 'tensorflow'

@staticmethod
def tensor(data, dtype=np.float32, device=None, device_id=None):
if isinstance(data, tf.sparse.SparseTensor):
return data
elif isinstance(data, tuple):
if len(data) == 3:
if isinstance(data[0], np.ndarray):
if isinstance(data[1], np.ndarray):
if len(data[0]) == len(data[1]):
return tf.sparse.SparseTensor(indices=data[0], values=data[1], dense_shape=data[2])


@staticmethod
def context(tensor):
Expand Down
2 changes: 1 addition & 1 deletion tensorly/contrib/sparse/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,4 @@ def parafac(tensor, rank, n_iter_max=100, init='svd', svd='numpy_svd',\

factors[n] = An

return KruskalTensor((weights[0], factors))
return KruskalTensor((weights[0], factors))
52 changes: 28 additions & 24 deletions tensorly/contrib/sparse/kruskal_tensor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ...kruskal_tensor import kruskal_to_tensor, unfolding_dot_khatri_rao, kruskal_norm
from .core import wrap
from .backend import sparse_context

from numpy import zeros, multiply
import tensorly as tl
Expand All @@ -10,40 +11,43 @@


def sparse_mttkrp(tensor, factors, n, rank, dims):
values = tl.values(tensor)
indices = tl.indices(tensor)
output = zeros((dims[n], rank))
with sparse_context():
values = tl.values(tensor)
indices = tl.indices(tensor)
output = zeros((dims[n], rank))

for l in range(len(values)):
cur_index = indices[l]
prod = [values[l]] * rank # makes the value into a row
for l in range(len(values)):
cur_index = indices[l]
prod = [values[l]] * rank # makes the value into a row

for mode, cv in enumerate(cur_index): # does elementwise row multiplications
if mode != n:
for r in range(rank):
prod[r] *= factors[mode][cv][r]
for mode, cv in enumerate(cur_index): # does elementwise row multiplications
if mode != n:
for r in range(rank):
prod[r] *= factors[mode][cv][r]

for r in range(rank):
output[cur_index[n]][r] += prod[r]
for r in range(rank):
output[cur_index[n]][r] += prod[r]

return output
return output


def kruskal_sparse_inner_product(kt, st):
s = 0.0
weights, factors = kt
idxs = st.indices.numpy()
vals = st.values.numpy()
for i, index in enumerate(idxs):
st_val = vals[i]
kt_val = weights
for fac_no, dim in enumerate(index):
kt_val = multiply(factors[fac_no][dim], kt_val)
s += (sum(kt_val) * st_val)
return s
with sparse_context():
s = 0.0
weights, factors = kt
idxs = st.indices.numpy()
vals = st.values.numpy()
for i, index in enumerate(idxs):
st_val = vals[i]
kt_val = weights
for fac_no, dim in enumerate(index):
kt_val = multiply(factors[fac_no][dim], kt_val)
s += (sum(kt_val) * st_val)
return s


def kruskal_sparse_fit(kt, st):
with sparse_context():
normX = tl.norm(st)
normP = kruskal_norm(kt)
ip = kruskal_sparse_inner_product(kt, st)
Expand Down
61 changes: 31 additions & 30 deletions tensorly/contrib/sparse/tests/test_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from .... import backend as tl
from scipy import sparse
from ..kruskal_tensor import kruskal_sparse_fit
from ....backend import set_backend
from ....backend import set_backend, get_backend
from ..backend import sparse_context

import numpy as np
import tensorflow as tf

import pytest
if not tl.get_backend() == "numpy":
Expand Down Expand Up @@ -41,37 +41,38 @@ def generate_random_sp_tensor(dimensions, d=0.0001):
indices = list(set(idxs4))
values = np.random.rand(len(indices))
indices.sort()
st = tf.sparse.SparseTensor(indices=indices, values=values, dense_shape=dimensions)
indices = np.asarray(indices)
st = tl.tensor((indices, values, dimensions))
return st


def test_tf_sparse_cpd():
set_backend('tensorflow.sparse')

print("generating tensor")
shape = (100, 100, 1000)
density = 0.001
rank = 20

st = generate_random_sp_tensor(shape, d=density)
print("performing decomposition")
cpd = parafac(st, rank, n_iter_max=50, verbose=True)
print("testing fit")
fit_st_rebuilt = kruskal_sparse_fit(cpd, st)

result_text = '''
+--------------------------------------------
| shape: {}
| rank: {}
| iterations: 100
| density: {}
| actual density: {}
| number of non-zeros: {}
|--------------------------------------------
| fit: {}
+--------------------------------------------
'''.format(shape, rank, density, (st.values.shape[0] / np.prod(shape)), st.values.shape[0], fit_st_rebuilt)

print(result_text)
with sparse_context():
if get_backend() == 'tensorflow':
print("generating tensor")
shape = (100, 100, 100)
density = 0.001
rank = 20

st = generate_random_sp_tensor(shape, d=density)
print("performing decomposition")
cpd = parafac(st, rank, n_iter_max=50, verbose=True)
print("testing fit")
fit_st_rebuilt = kruskal_sparse_fit(cpd, st)

result_text = '''
+--------------------------------------------
| shape: {}
| rank: {}
| iterations: 100
| density: {}
| actual density: {}
| number of non-zeros: {}
|--------------------------------------------
| fit: {}
+--------------------------------------------
'''.format(shape, rank, density, (st.values.shape[0] / np.prod(shape)), st.values.shape[0], fit_st_rebuilt)

print(result_text)


0 comments on commit 99c6df2

Please sign in to comment.