Adds BlockCP Tensorized Matrix and tests #17

Open · wants to merge 1 commit into base: main
319 changes: 319 additions & 0 deletions tltorch/factorized_tensors/tensorized_matrices.py
@@ -468,3 +468,322 @@ def init_from_tensor(self, tensor, **kwargs):
self.rank = tuple([f.shape[0] for f in factors] + [1])

return self




def validate_block_cp_rank(tensor_shape, rank='same', rounding='round'):
"""Returns the rank of a BlockCP Decomposition

Parameters
----------
tensor_shape : tuple
shape of the tensor to decompose
rank : {'same', float, int}, default is 'same'
way to determine the rank:
if 'same', the rank is computed so that the number of parameters is (at most) the same
if float, computes a rank so as to keep that fraction of the original number of parameters
if int, just returns rank
rounding : {'round', 'floor', 'ceil'}, default is 'round'
how to round the computed rank to an integer

Returns
-------
rank : int
rank of the decomposition
"""

if is_tensorized_shape(tensor_shape):
tensor_shape = tensorized_shape_to_shape(tensor_shape)
if rounding == 'ceil':
rounding_fun = np.ceil
elif rounding == 'floor':
rounding_fun = np.floor
elif rounding == 'round':
rounding_fun = np.round
else:
raise ValueError(f"Rounding should be one of 'round', 'floor' or 'ceil', but got {rounding}")

if rank == 'same':
rank = float(1)

if isinstance(rank, float):
rank = int(rounding_fun(np.prod(tensor_shape)*rank/np.sum(tensor_shape)))
return rank
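# Illustrative sketch of the heuristic above (assuming a flat tuple of ints, which
# is_tensorized_shape treats as a non-tensorized shape): with rank='same' the rank
# resolves to prod(shape) / sum(shape), rounded to an integer, e.g.
#
#     validate_block_cp_rank((8, 9, 10), rank='same')   # 720 / 27 -> 27
#     validate_block_cp_rank((8, 9, 10), rank=0.5)      # 360 / 27 -> 13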

class BlockCP(TensorizedTensor, name='BlockCP'):
"""BlockCP Factorization

Parameters
----------
weights : tensor of shape (rank, )
shared CP weights
factors : list of tensors
one factor per tensorized mode, each of shape (row_size_k, column_size_k, rank)
tensorized_shape : tuple
tensorized shape of the matrix, optionally with a leading batch size
rank : int
rank of the BlockCP decomposition
"""

def __init__(self, weights, factors, tensorized_shape=None, rank=None):
super().__init__()
self.shape = tensorized_shape_to_shape(tensorized_shape)
self.tensorized_shape = tensorized_shape
self.rank = rank
self.order = len(self.shape)
self.weights = weights
self.factors = FactorList(factors)

@classmethod
def new(cls, tensorized_shape, rank, device=None, dtype=None, **kwargs):

if all(isinstance(s, int) for s in tensorized_shape):
warnings.warn(f'Given a "flat" shape {tensorized_shape}, '
'this will be considered as the shape of a tensorized vector. '
'If you just want a 1D tensor, use CP instead.')
ndim = 1
factor_shapes = [tensorized_shape]
tensorized_shape = (tensorized_shape,)

else:
ndim = max([1 if isinstance(s, int) else len(s) for s in tensorized_shape])

factor_shapes = [(s, )*ndim if isinstance(s, int) else s for s in tensorized_shape]

rank = validate_block_cp_rank(tensorized_shape, rank)
# Register the parameters
weights = nn.Parameter(torch.empty(rank, device=device, dtype=dtype))

ranks = [rank] * len(tensorized_shape[1])
factor_shapes = factor_shapes + [ranks]
factor_shapes = list(zip(*factor_shapes))
factors = [nn.Parameter(torch.empty(s, device=device, dtype=dtype)) for s in factor_shapes]

return cls(weights, factors, tensorized_shape, rank=rank)
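# Minimal usage sketch (shapes are illustrative only): for tensorized_shape ((4, 5), (6, 7))
# and rank 8, ``new`` allocates weights of shape (8,) and two factors of shapes (4, 6, 8)
# and (5, 7, 8), which still need to be initialised, e.g.
#
#     fact = BlockCP.new(tensorized_shape=((4, 5), (6, 7)), rank=8)
#     fact.normal_()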


@classmethod
def from_tensor(cls, tensor, tensorized_shape, rank='same', **kwargs):

"""
Note: Not Implemented because we need a version of parafac
for order-n factors, currently have for order 2 (shape, rank)
"""

raise NotImplementedError("Not Implemented because we need a version of " +
"parafac for order-n factors, currently have for order 2 (shape, rank) ")

shape = tensor.shape
rank = validate_block_cp_rank(shape, rank)
dtype = tensor.dtype

with torch.no_grad():
weights, factors = parafac(tensor.to(torch.float64), rank, **kwargs)

return cls(nn.Parameter(weights.to(dtype)), [nn.Parameter(f.to(dtype)) for f in factors])


def init_from_tensor(self, tensor, l2_reg=1e-5, **kwargs):

"""
Note: Not Implemented because we need a version of parafac
for order-n factors, currently have for order 2 (shape, rank)
"""

raise NotImplementedError("Not Implemented because we need a version of " +
"parafac for order-n factors, currently have for order 2 (shape, rank) ")

with torch.no_grad():
weights, factors = parafac(tensor, self.rank, l2_reg=l2_reg, **kwargs)

self.weights = nn.Parameter(weights)
self.factors = FactorList([nn.Parameter(f) for f in factors])
return self

@property
def decomposition(self):
return self.weights, self.factors

def to_tensor(self):
factors = self.factors
weights = self.weights
ndim = len(factors)

rank_ind = ord('a')

if isinstance(self.tensorized_shape[0], int):
batched = True
batch_size = self.tensorized_shape[0]
else:
batched = False

if batched:
batch_ind = ord('b')
idx = ord('c') # Current character we can use for contraction
eq_terms = {}
for i in range(ndim):
eq_terms[i] = chr(rank_ind)
for i in range(ndim):
chr_idx = chr(idx)
idx +=1
chr_idx2 = chr(idx)
eq_terms[i] = chr_idx + chr_idx2 + chr(rank_ind)
if batched:
eq_terms[i] = chr(batch_ind) + eq_terms[i]
idx +=1
eq_out = ''
if batched:
eq_out += chr(batch_ind)
for j in range(1, 3):
if batched:
eq_out += ''.join([eq_terms[i][j] for i in range(len(eq_terms.keys()))])
else:
eq_out += ''.join([eq_terms[i][j-1] for i in range(len(eq_terms.keys()))])
eq_in = ','.join([*eq_terms.values()])
eq = eq_in + ',' + chr(rank_ind) + '->' + eq_out

res = tl.einsum(eq, *self.factors, weights)
return tl.reshape(res, self.tensor_shape)
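# For the two-factor example used above (tensorized_shape ((4, 5), (6, 7)), no batch mode),
# the equation built here comes out as 'cda,efa,a->cedf' when traced by hand: 'a' is the
# shared BlockCP rank index, each factor contributes one row and one column index, and the
# contraction result is reshaped to the dense tensor shape.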



def normal_(self, mean=0, std=1):
super().normal_(mean, std)
std_factors = (std/math.sqrt(self.rank))**(1/self.order)

with torch.no_grad():
self.weights.fill_(1)
for factor in self.factors:
factor.data.normal_(0, std_factors)
return self
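# Informal rationale for std_factors: each reconstructed entry is a sum over ``rank`` terms,
# each a product of one entry per factor, so drawing factor entries with std
# (std / sqrt(rank))**(1 / order) keeps the entries of the full tensor at roughly the
# requested std (assuming the number of factors equals self.order).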


def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}

args = [t.to_matrix() if hasattr(t, 'to_matrix') else t for t in args]
return func(*args, **kwargs)


def __getitem__(self, indices):
factors = self.factors
weights = self.weights
if not isinstance(indices, Iterable):
indices = [indices]

if len(indices) < self.ndim:
indices = list(indices)
indices.extend([slice(None)] * (self.ndim - len(indices)))  # e.g. indexing with [0] becomes [0, slice(None)]

elif len(indices) > self.ndim:
indices = [indices] # We're only indexing the first dimension

if isinstance(self.tensorized_shape[0], int):
batched = True
batch_size = self.tensorized_shape[0]
else:
batched = False

output_shape = []
ndim = len(self.factors)

contract_factors = False # If True, the result is dense, we need to form the full result
rank_ind = ord('a') # BlockCP rank character
if batched:
batch_ind = ord('b') # batch character

idx = ord('c') # Current character we can use for contraction
eq_terms = {}
for i in range(ndim):
eq_terms[i] = chr(rank_ind)
eq_out = ''

pad = ()
add_pad = False # whether to increment the padding post indexing
rebatched = False # Add a dim for batch size if only one item has been selected

for (index, shape) in zip(indices, self.tensorized_shape):
if isinstance(shape, int):
# We are indexing a "batched" mode, not a tensorized one
if not isinstance(index, (np.integer, int)):
if isinstance(index, slice):
index = list(range(*index.indices(shape)))
batch_size = len(index)

output_shape.append(len(index))
add_pad = True

# else: we've essentially removed a mode of each factor
else:
batch_size = 1
index = [index]*ndim
else:
# We are indexing a tensorized mode

if index == slice(None) or index == ():
# Keeping all indices (:)
output_shape.append(shape)
for i in range(ndim):
chr_idx = chr(idx)
eq_terms[i] = chr_idx + eq_terms[i] #+ chr(rank_ind)
idx +=1
eq_out += ''.join([eq_terms[i][0] for i in range(len(eq_terms.keys()))])
add_pad = True
index = [index]*ndim
else:

## index is an integer
contract_factors = True

if isinstance(index, slice):
# Since we've already filtered out :, this is a partial slice
# Convert into list
max_index = math.prod(shape)
index = list(range(*index.indices(max_index)))
add_pad = True

if isinstance(index, Iterable):
output_shape.append(len(index))
for i in range(ndim):
chr_idx = chr(idx)
eq_terms[i] = chr_idx + eq_terms[i]
eq_out += chr(idx)
idx += 1
add_pad = True

index = np.unravel_index(index, shape)

factors = [ff[pad + (idx,)] for (ff, idx) in zip(factors, index)]
if add_pad:
pad += (slice(None), )
add_pad = False

if contract_factors:
if batched and batch_size != 1: # only append a batch dimension if batch_size > 1; a batch of size 1 has no batch dimension
for i in range(ndim):
eq_terms[i] = chr(batch_ind) + eq_terms[i]
eq_out = chr(batch_ind) + eq_out
eq_in = ','.join([*eq_terms.values()])
eq = eq_in + ',' + chr(rank_ind) + '->' + eq_out
for i in range(len(factors)):
if all( isinstance(ind, slice) for ind in indices):
factors[i] = factors[i].transpose(1,0)
res = tl.einsum(eq, *factors, weights)

if not batched:
if any(isinstance(x, int) for x in indices ) and not all(isinstance(x, int) for x in indices ):
res = res.flatten()
else:
if any(isinstance(x, int) for x in indices[1:] ) and not all(isinstance(x, int) for x in indices[1:] ):
res = res.reshape(batch_size, -1).squeeze(0)

# ensure the correct shape when a single nontrivial slice is taken
output_shape = [s if isinstance(s, int) else np.prod(s) for s in output_shape]
res = res.reshape(output_shape)

return res
else:
return self.__class__(self.weights, factors, output_shape, self.rank)
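# Rough summary of the indexing behaviour, as read from the branches above (not verified):
# slicing every mode, e.g. fact[:, :], keeps the result in factorized BlockCP form, while an
# integer or partial slice on a tensorized mode sets contract_factors and returns a dense
# tensor assembled with tl.einsum.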
6 changes: 3 additions & 3 deletions tltorch/factorized_tensors/tests/test_factorizations.py
@@ -4,7 +4,7 @@
import torch
from torch import testing

-from tltorch.factorized_tensors.tensorized_matrices import CPTensorized, TuckerTensorized, BlockTT
+from tltorch.factorized_tensors.tensorized_matrices import CPTensorized, TuckerTensorized, BlockTT, BlockCP
from tltorch.factorized_tensors.core import TensorizedTensor

from ..factorized_tensors import FactorizedTensor, CPTensor, TuckerTensor, TTTensor
@@ -41,7 +41,7 @@ def test_FactorizedTensor(factorization):
testing.assert_allclose(reconstruction[idx], res)


-@pytest.mark.parametrize('factorization', ['BlockTT', 'CP']) #['CP', 'Tucker', 'BlockTT'])
+@pytest.mark.parametrize('factorization', ['BlockTT', 'CP', 'BlockCP']) #['CP', 'Tucker', 'BlockTT'])
@pytest.mark.parametrize('batch_size', [(), (4,)])
def test_TensorizedMatrix(factorization, batch_size):
"""Test for TensorizedMatrix"""
@@ -57,7 +57,7 @@ def test_TensorizedMatrix(factorization, batch_size):

# Check that the correct type of factorized tensor is created
assert fact_tensor._name.lower() == factorization.lower()
-mapping = dict(CP=CPTensorized, Tucker=TuckerTensorized, BlockTT=BlockTT)
+mapping = dict(CP=CPTensorized, Tucker=TuckerTensorized, BlockTT=BlockTT, BlockCP=BlockCP)
assert isinstance(fact_tensor, mapping[factorization])

# Check that the matrix has the right shape
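A possible follow-up test, sketched against helpers the suite already imports (that TensorizedTensor.new dispatches on factorization='BlockCP' and that to_matrix returns the flattened (rows, columns) view are assumptions carried over from the existing tests; the shapes and test name are illustrative):

def test_BlockCP_to_matrix():
    # Build a small BlockCP-factorized matrix and check the dense matrix shape:
    # rows are tensorized as 2 x 3 and columns as 4 x 5.
    fact_tensor = TensorizedTensor.new(((2, 3), (4, 5)), rank=6, factorization='BlockCP')
    fact_tensor.normal_()
    matrix = fact_tensor.to_matrix()
    assert matrix.shape == (2 * 3, 4 * 5)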