In [1]:
import torch
import torch.nn as nn
import string

class CPTensor(nn.Module):
    """Learnable order-N tensor stored in CANDECOMP/PARAFAC (CP) form.

    The dense tensor is never stored. ``reconstruct`` materializes

        T[i1, ..., iN] = sum_r weights[r] * prod_k factors[k][ik, r]

    from the factor matrices and component weights. All of these are
    ``nn.Parameter`` objects and are therefore trainable.
    """

    # einsum subscript reserved for the rank (summation) index.
    _RANK_SUB = "r"
    # Subscripts for the tensor modes: every letter torch.einsum accepts
    # ([a-zA-Z]) except the rank subscript. A mode must never share the rank
    # subscript, or einsum would take a diagonal instead of summing over rank.
    # This allows up to 51 modes.
    _MODE_SUBS = string.ascii_letters.replace(_RANK_SUB, "")

    def __init__(self, shape, rank):
        """
        Args:
            shape (tuple of ints): (d1, d2, ..., dN)
            rank  (int): number of CP components M
        """
        super().__init__()
        self.shape = shape
        self.rank = rank
        # factor matrices U^(k): shape (d_k, rank), standard-normal init
        self.factors = nn.ParameterList([
            nn.Parameter(torch.randn(d, rank))
            for d in shape
        ])
        # component weights (lambda) of shape (rank,), standard-normal init
        self.weights = nn.Parameter(torch.randn(rank))

    def reconstruct(self):
        """Materialize the full tensor from the CP factors and weights.

        Returns:
            torch.Tensor with one dimension per entry of ``self.shape``.
            For an order-0 tensor (empty ``shape``), this is the scalar
            ``weights.sum()``: each component contributes its weight times
            an empty product of factors.

        Raises:
            ValueError: if there are more modes than available einsum
                subscripts (51).
        """
        n_modes = len(self.factors)
        if n_modes == 0:
            # No factor matrices: each rank component contributes just its weight.
            return self.weights.sum()
        if n_modes > len(self._MODE_SUBS):
            raise ValueError(
                f"CPTensor supports at most {len(self._MODE_SUBS)} modes, "
                f"got {n_modes}"
            )

        r = self._RANK_SUB
        modes = self._MODE_SUBS[:n_modes]
        # e.g. "r,ar,br,cr->abc" for N=3. The weights and every factor share
        # the rank index r, which is summed out because the output omits it.
        eq = ",".join([r] + [m + r for m in modes]) + "->" + modes
        return torch.einsum(eq, self.weights, *self.factors)

# --- Example ---------------------------------------------------------------
# A rank-10 CP model of a 4 x 5 x 6 (3-mode) tensor. The factor matrices
# (cp.factors) and component weights (cp.weights) are the variational
# parameters to optimize.
shape = (4, 5, 6)
rank = 10
cp = CPTensor(shape=shape, rank=rank)

# Materialize the dense tensor; it has one dimension per entry of `shape`.
T = cp.reconstruct()
print(T.shape)  # torch.Size([4, 5, 6])


torch.Size([4, 5, 6])
