diff --git a/tltorch/factorized_tensors/tensorized_matrices.py b/tltorch/factorized_tensors/tensorized_matrices.py
index 56f95e6..81336df 100644
--- a/tltorch/factorized_tensors/tensorized_matrices.py
+++ b/tltorch/factorized_tensors/tensorized_matrices.py
@@ -468,3 +468,322 @@ def init_from_tensor(self, tensor, **kwargs):
         self.rank = tuple([f.shape[0] for f in factors] + [1])
 
         return self
+
+
+def validate_block_cp_rank(tensor_shape, rank='same', rounding='round'):
+    """Returns the rank of a BlockCP decomposition
+
+    Parameters
+    ----------
+    tensor_shape : tuple
+        shape of the tensor to decompose
+    rank : {'same', float, int}, default is 'same'
+        way to determine the rank:
+        if 'same', the rank is computed to keep the number of parameters (at most) the same;
+        if float, the rank is chosen to keep ``rank`` percent of the original number of parameters;
+        if int, the given rank is returned unchanged
+    rounding : {'round', 'floor', 'ceil'}
+
+    Returns
+    -------
+    rank : int
+        rank of the decomposition
+    """
+    if is_tensorized_shape(tensor_shape):
+        tensor_shape = tensorized_shape_to_shape(tensor_shape)
+
+    if rounding == 'ceil':
+        rounding_fun = np.ceil
+    elif rounding == 'floor':
+        rounding_fun = np.floor
+    elif rounding == 'round':
+        rounding_fun = np.round
+    else:
+        raise ValueError(f'Rounding should be one of round, floor or ceil, but got {rounding}')
+
+    if rank == 'same':
+        rank = float(1)
+
+    if isinstance(rank, float):
+        rank = int(rounding_fun(np.prod(tensor_shape)*rank/np.sum(tensor_shape)))
+
+    return rank
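+
+# Worked example of the rank heuristic above (hypothetical shape): for a
+# flattened shape (20, 6) with rank='same' (i.e. a fraction of 1.0),
+#
+#     validate_block_cp_rank((20, 6), rank='same')
+#     # int(round(20 * 6 * 1.0 / (20 + 6))) = int(round(4.615...)) = 5
+#
+# whereas an explicit integer rank is returned unchanged.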
+
+
+class BlockCP(TensorizedTensor, name='BlockCP'):
+    """BlockCP Factorization
+
+    Parameters
+    ----------
+    weights
+    factors
+    shape
+    rank
+    """
+
+    def __init__(self, weights, factors, tensorized_shape=None, rank=None):
+        super().__init__()
+        self.shape = tensorized_shape_to_shape(tensorized_shape)
+        self.tensorized_shape = tensorized_shape
+        self.rank = rank
+        self.order = len(self.shape)
+        self.weights = weights
+        self.factors = FactorList(factors)
+
+    @classmethod
+    def new(cls, tensorized_shape, rank, device=None, dtype=None, **kwargs):
+        if all(isinstance(s, int) for s in tensorized_shape):
+            warnings.warn(f'Given a "flat" shape {tensorized_shape}: '
+                          'this will be considered as the shape of a tensorized vector. '
+                          'If you just want a 1D tensor, use CP.')
+            ndim = 1
+            factor_shapes = [tensorized_shape]
+            tensorized_shape = (tensorized_shape,)
+        else:
+            # One factor per tensorized mode; "batched" (int) modes are repeated
+            # so that every factor carries the batch dimension.
+            ndim = max([1 if isinstance(s, int) else len(s) for s in tensorized_shape])
+            factor_shapes = [(s, )*ndim if isinstance(s, int) else s for s in tensorized_shape]
+
+        rank = validate_block_cp_rank(tensorized_shape, rank)
+
+        # Register the parameters
+        weights = nn.Parameter(torch.empty(rank, device=device, dtype=dtype))
+
+        # Append one rank mode to every factor: factor i gets shape (*modes_i, rank)
+        ranks = [rank] * ndim
+        factor_shapes = factor_shapes + [ranks]
+        factor_shapes = list(zip(*factor_shapes))
+        factors = [nn.Parameter(torch.empty(s, device=device, dtype=dtype)) for s in factor_shapes]
+
+        return cls(weights, factors, tensorized_shape, rank=rank)
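+
+    # Shape sketch for ``new`` (hypothetical sizes):
+    #
+    #     block_cp = BlockCP.new(tensorized_shape=((4, 5), (2, 3)), rank=6)
+    #     [tuple(f.shape) for f in block_cp.factors]   # [(4, 2, 6), (5, 3, 6)]
+    #     tuple(block_cp.weights.shape)                # (6,)
+    #
+    # i.e. one (row_dim, col_dim, rank) factor per tensorized mode plus a
+    # rank-sized weight vector; with a leading int ("batch") mode, every factor
+    # additionally gains a batch dimension in front.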
+
+    @classmethod
+    def from_tensor(cls, tensor, tensorized_shape, rank='same', **kwargs):
+        """
+        Note: not implemented because we need a version of parafac for
+        order-n factors; it is currently only available for order-2 factors (shape, rank).
+        """
+        raise NotImplementedError("Not implemented because we need a version of "
+                                  "parafac for order-n factors; it is currently only "
+                                  "available for order-2 factors (shape, rank).")
+
+        shape = tensor.shape
+        rank = validate_block_cp_rank(shape, rank)
+        dtype = tensor.dtype
+
+        with torch.no_grad():
+            weights, factors = parafac(tensor.to(torch.float64), rank, **kwargs)
+
+        return cls(nn.Parameter(weights.to(dtype)), [nn.Parameter(f.to(dtype)) for f in factors])
+
+    def init_from_tensor(self, tensor, l2_reg=1e-5, **kwargs):
+        """
+        Note: not implemented because we need a version of parafac for
+        order-n factors; it is currently only available for order-2 factors (shape, rank).
+        """
+        raise NotImplementedError("Not implemented because we need a version of "
+                                  "parafac for order-n factors; it is currently only "
+                                  "available for order-2 factors (shape, rank).")
+
+        with torch.no_grad():
+            weights, factors = parafac(tensor, self.rank, l2_reg=l2_reg, **kwargs)
+
+        self.weights = nn.Parameter(weights)
+        self.factors = FactorList([nn.Parameter(f) for f in factors])
+        return self
+
+    @property
+    def decomposition(self):
+        return self.weights, self.factors
+
+    def to_tensor(self):
+        factors = self.factors
+        weights = self.weights
+        ndim = len(factors)
+
+        rank_ind = ord('a')
+
+        if isinstance(self.tensorized_shape[0], int):
+            batched = True
+            batch_size = self.tensorized_shape[0]
+        else:
+            batched = False
+
+        if batched:
+            batch_ind = ord('b')
+        idx = ord('c')  # Current character we can use for contraction
+
+        eq_terms = {}
+        for i in range(ndim):
+            chr_idx = chr(idx)
+            idx += 1
+            chr_idx2 = chr(idx)
+            eq_terms[i] = chr_idx + chr_idx2 + chr(rank_ind)
+            if batched:
+                eq_terms[i] = chr(batch_ind) + eq_terms[i]
+            idx += 1
+
+        eq_out = ''
+        if batched:
+            eq_out += chr(batch_ind)
+        for j in range(1, 3):
+            if batched:
+                eq_out += ''.join([eq_terms[i][j] for i in range(len(eq_terms.keys()))])
+            else:
+                eq_out += ''.join([eq_terms[i][j-1] for i in range(len(eq_terms.keys()))])
+
+        eq_in = ','.join([*eq_terms.values()])
+        eq = eq_in + ',' + chr(rank_ind) + '->' + eq_out
+
+        res = tl.einsum(eq, *self.factors, weights)
+        return tl.reshape(res, self.tensor_shape)
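+
+    # Einsum sketch for the non-batched, two-factor case built by ``to_tensor``:
+    # with rank index 'a' and factors of shape (m0, n0, R) and (m1, n1, R),
+    # the loop above produces
+    #
+    #     'cda,efa,a->cedf'
+    #
+    # i.e. the output gathers all row indices first ('ce'), then all column
+    # indices ('df'), before the final reshape; in the batched case a shared
+    # batch index 'b' is prepended to every factor term and to the output.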
+
+    def normal_(self, mean=0, std=1):
+        super().normal_(mean, std)
+        std_factors = (std/math.sqrt(self.rank))**(1/self.order)
+
+        with torch.no_grad():
+            self.weights.fill_(1)
+            for factor in self.factors:
+                factor.data.normal_(0, std_factors)
+        return self
+
+    def __torch_function__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        args = [t.to_matrix() if hasattr(t, 'to_matrix') else t for t in args]
+        return func(*args, **kwargs)
+
+    def __getitem__(self, indices):
+        factors = self.factors
+        weights = self.weights
+        if not isinstance(indices, Iterable):
+            indices = [indices]
+
+        if len(indices) < self.ndim:
+            # e.g. indexing with [0] becomes [0, slice(None, None, None)]
+            indices = list(indices)
+            indices.extend([slice(None)]*(self.ndim - len(indices)))
+        elif len(indices) > self.ndim:
+            indices = [indices]  # We're only indexing the first dimension
+
+        if isinstance(self.tensorized_shape[0], int):
+            batched = True
+            batch_size = self.tensorized_shape[0]
+        else:
+            batched = False
+
+        output_shape = []
+        ndim = len(self.factors)
+
+        contract_factors = False  # If True, the result is dense and we need to form the full result
+        rank_ind = ord('a')  # BlockCP rank character
+        if batched:
+            batch_ind = ord('b')  # batch character
+
+        idx = ord('c')  # Current character we can use for contraction
+        eq_terms = {}
+        for i in range(ndim):
+            eq_terms[i] = chr(rank_ind)
+        eq_out = ''
+
+        pad = ()
+        add_pad = False  # whether to increment the padding post indexing
+        rebatched = False  # Add a dim for batch size if only one item has been selected
+
+        for (index, shape) in zip(indices, self.tensorized_shape):
+            if isinstance(shape, int):
+                # We are indexing a "batched" mode, not a tensorized one
+                if not isinstance(index, (np.integer, int)):
+                    if isinstance(index, slice):
+                        index = list(range(*index.indices(shape)))
+                    batch_size = len(index)
+                    output_shape.append(len(index))
+                    add_pad = True
+                # else: we've essentially removed a mode of each factor
+                else:
+                    batch_size = 1
+                index = [index]*ndim
+            else:
+                # We are indexing a tensorized mode
+                if index == slice(None) or index == ():
+                    # Keeping all indices (:)
+                    output_shape.append(shape)
+                    for i in range(ndim):
+                        eq_terms[i] = chr(idx) + eq_terms[i]
+                        idx += 1
+                    eq_out += ''.join([eq_terms[i][0] for i in range(len(eq_terms.keys()))])
+                    add_pad = True
+                    index = [index]*ndim
+                else:
+                    # index is an integer, a partial slice or a list of indices
+                    contract_factors = True
+
+                    if isinstance(index, slice):
+                        # Since we've already filtered out :, this is a partial slice:
+                        # convert it into a list
+                        max_index = math.prod(shape)
+                        index = list(range(*index.indices(max_index)))
+                        add_pad = True
+
+                    if isinstance(index, Iterable):
+                        output_shape.append(len(index))
+                        for i in range(ndim):
+                            eq_terms[i] = chr(idx) + eq_terms[i]
+                            eq_out += chr(idx)
+                            idx += 1
+                        add_pad = True
+
+                    index = np.unravel_index(index, shape)
+
+            factors = [ff[pad + (ind,)] for (ff, ind) in zip(factors, index)]
+            if add_pad:
+                pad += (slice(None), )
+                add_pad = False
+
+        if contract_factors:
+            # only append a batch dimension if batch_size > 1; batch_size = 1 has no batch dimension
+            if batched and batch_size != 1:
+                for i in range(ndim):
+                    eq_terms[i] = chr(batch_ind) + eq_terms[i]
+                eq_out = chr(batch_ind) + eq_out
+            eq_in = ','.join([*eq_terms.values()])
+            eq = eq_in + ',' + chr(rank_ind) + '->' + eq_out
+            for i in range(len(factors)):
+                if all(isinstance(ind, slice) for ind in indices):
+                    factors[i] = factors[i].transpose(1, 0)
+            res = tl.einsum(eq, *factors, weights)
+
+            if not batched:
+                if any(isinstance(x, int) for x in indices) and not all(isinstance(x, int) for x in indices):
+                    res = res.flatten()
+            else:
+                if any(isinstance(x, int) for x in indices[1:]) and not all(isinstance(x, int) for x in indices[1:]):
+                    res = res.reshape(batch_size, -1).squeeze(0)
+
+            # ensure the correct shape when a single nontrivial slice is taken
+            output_shape = [s if isinstance(s, int) else np.prod(s) for s in output_shape]
+            res = res.reshape(output_shape)
+
+            return res
+        else:
+            return self.__class__(self.weights, factors, output_shape, self.rank)
diff --git a/tltorch/factorized_tensors/tests/test_factorizations.py b/tltorch/factorized_tensors/tests/test_factorizations.py
index 16e5c39..b2619e7 100644
--- a/tltorch/factorized_tensors/tests/test_factorizations.py
+++ b/tltorch/factorized_tensors/tests/test_factorizations.py
@@ -4,7 +4,7 @@ import torch
 from torch import testing
 
-from tltorch.factorized_tensors.tensorized_matrices import CPTensorized, TuckerTensorized, BlockTT
+from tltorch.factorized_tensors.tensorized_matrices import CPTensorized, TuckerTensorized, BlockTT, BlockCP
 from tltorch.factorized_tensors.core import TensorizedTensor
 from ..factorized_tensors import FactorizedTensor, CPTensor, TuckerTensor, TTTensor
 
@@ -41,7 +41,7 @@ def test_FactorizedTensor(factorization):
         testing.assert_allclose(reconstruction[idx], res)
 
 
-@pytest.mark.parametrize('factorization', ['BlockTT', 'CP']) #['CP', 'Tucker', 'BlockTT'])
+@pytest.mark.parametrize('factorization', ['BlockTT', 'CP', 'BlockCP']) #['CP', 'Tucker', 'BlockTT'])
 @pytest.mark.parametrize('batch_size', [(), (4,)])
 def test_TensorizedMatrix(factorization, batch_size):
     """Test for TensorizedMatrix"""
@@ -57,7 +57,7 @@ def test_TensorizedMatrix(factorization, batch_size):
     # Check that the correct type of factorized tensor is created
     assert fact_tensor._name.lower() == factorization.lower()
 
-    mapping = dict(CP=CPTensorized, Tucker=TuckerTensorized, BlockTT=BlockTT)
+    mapping = dict(CP=CPTensorized, Tucker=TuckerTensorized, BlockTT=BlockTT, BlockCP=BlockCP)
     assert isinstance(fact_tensor, mapping[factorization])
 
     # Check that the matrix has the right shape
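
# Usage sketch (hypothetical shapes; assumes construction goes through the
# TensorizedTensor factory by the registered name 'BlockCP', mirroring the
# parametrized test above):
#
#     from tltorch import TensorizedTensor
#
#     fact = TensorizedTensor.new(((4, 5), (2, 3)), rank=6, factorization='BlockCP')
#     fact.normal_()            # initialize weights and factors in place
#     dense = fact.to_matrix()  # reconstruct the full 20 x 6 matrix
#     row = fact[0]             # dense slice via BlockCP.__getitem__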