In [1]:
using Random, LinearAlgebra, TensorToolbox, Combinatorics, TensorOperations

In [62]:
"""
    makeRankedTensor(L, A, d)

Build the order-`d` tensor `T = Σₖ L[k] ⋅ A[:, k]^{⊗d}` whose modes all have size
`n = size(A, 1)`.

# Arguments
- `L`: weights, one per column of `A`.
- `A`: `n × r` factor matrix.
- `d`: tensor order.

# Returns
An `n × n × … × n` (`d` modes) array.

The weights are used as given (no complex conjugation). The mode size is taken
from `A` itself rather than from any global variable.
"""
function makeRankedTensor(L::Vector, A::Array, d::Int)
    n = size(A, 1)
    # column k of A_ is vec(A[:, k]^{⊗d}); A_ * L forms the weighted sum of columns
    A_ = kronMat(A, d)
    return reshape(A_ * L, ntuple(_ -> n, d))
end;

"""
    randomTensor(n, d; real=false)

Return a random symmetric tensor of order `d` with every mode of size `n`.

A tensor of standard Gaussian entries (complex, with independent Gaussian real and
imaginary parts, unless `real=true`) is symmetrized by averaging it over all `d!`
permutations of its indices.
"""
function randomTensor(n::Int, d::Int; real::Bool=false)
    dims = ntuple(_ -> n, d)
    base = randn(dims...)
    if !real
        base = base + im * randn(dims...)
    end
    acc = copy(base)          # identity permutation is already accounted for
    idperm = 1:d
    for p in permutations(1:d)
        p == idperm && continue
        acc .+= permutedims(base, p)
    end
    return acc / factorial(d)
end;

"""
    complexGaussian(m, n)

Return an `m × n` complex matrix whose real and imaginary parts are independent
standard Gaussian matrices; the real part is drawn first.
"""
function complexGaussian(m, n)
    re_part, im_part = randn(m, n), randn(m, n)
    return re_part .+ im .* im_part
end;

"""
    randomRankedTensor(n, d, r; real=false)

Draw a random `n × r` factor matrix `A` (complex Gaussian unless `real=true`) and
return `(T, Anorm, Lnorm)`, where:

- `T = Σₖ A[:, k]^{⊗d}` (unit weights, built with `makeRankedTensor`);
- `Anorm` is `A` with every column divided by its first entry;
- `Lnorm[k] = A[1, k]^d`, so that `T = Σₖ Lnorm[k] ⋅ Anorm[:, k]^{⊗d}` as well.
"""
function randomRankedTensor(n, d, r; real=false)
    A = real ? randn(n, r) : complexGaussian(n, r)
    lead = A[1, :]                      # first entry of every column
    Anorm = A ./ transpose(lead)        # column k divided by lead[k]
    Lnorm = lead .^ d                   # absorbs the removed scale into the weight
    return makeRankedTensor(ones(r), A, d), Anorm, Lnorm
end;

"""
    contract(T, V)

Contract the first `ndims(V)` indices of `T` against all indices of `V`, in order and
without complex conjugation. The remaining indices of `T`, kept in their original
order, form the result.
"""
function contract(T::Array, V::Array)
    order_T, order_V = ndims(T), ndims(V)
    # ncon convention: positive labels are summed, negative labels are output indices
    summed_T = collect(1:order_V)
    kept_T = -collect(1:order_T - order_V)
    summed_V = collect(1:order_V)
    return ncon((T, V), (vcat(summed_T, kept_T), summed_V))
end;

"""
    jennrich(T; tol=1e-10)

Decompose a symmetric order-`d` tensor `T` (all modes of size `n`, `d ≥ 3`) as
`T ≈ Σₖ Lhat[k] ⋅ Ahat[:, k]^{⊗d}` by simultaneous diagonalization (Jennrich).

The method works in three steps:

1. `T` is contracted against two random Gaussian vectors (odd `d`) or `n × n`
   matrices (even `d`). This gives two tensors of even order `2δ`, with
   `δ = ⌊(d-1)/2⌋`, which are flattened to `n^δ × n^δ` matrices `T1` and `T2`.
2. The eigenvectors of `T1 * pinv(T2)` with `|eigenvalue| > tol` are taken to be
   the flattened Kronecker powers of the factors. This assumes those powers are
   linearly independent. The number of such eigenvalues is the recovered rank `r`.
3. Each factor is read off as the leading left singular vector of its eigenvector
   reshaped to `n × n^(δ-1)`.

The weights are then solved from the `⌊d/2⌋ × ⌈d/2⌉` flattening of `T`.

# Returns
- `Ahat`: `n × r` factors, each column scaled so its first entry is 1 (assumes that
  entry is nonzero). Columns are recovered in no particular order.
- `Lhat`: length-`r` weights matching the columns of `Ahat`.

# Throws
- `ArgumentError` if `T` has order less than 3.
"""
function jennrich(T::Array; tol=1e-10)
    Tsize = size(T)
    d = length(Tsize)
    d >= 3 || throw(ArgumentError("jennrich requires a tensor of order ≥ 3, got order $d"))
    n = Tsize[1]

    # Random contraction operands: a vector for odd d, an n×n matrix for even d.
    k = iseven(d) ? 2 : 1
    B = randn(ntuple(_ -> n, k)...)
    C = randn(ntuple(_ -> n, k)...)

    # Contract to even order: d odd -> order d-1, d even -> order d-2.
    T1 = contract(T, B)
    T2 = contract(T, C)

    delt = div(d - 1, 2)   # T1, T2 have order 2*delt

    # Flatten to square matrices.
    T1flat = reshape(T1, (n^delt, n^delt))
    T2flat = reshape(T2, (n^delt, n^delt))

    # Eigenvectors with non-negligible eigenvalues are (up to scale) vec(a^{⊗delt}).
    w, V = eigen(T1flat * pinv(T2flat))
    V = V[:, abs.(w) .> tol]
    r = size(V, 2)
    Ahat = zeros(eltype(V), (n, r))
    for i in 1:r
        # Reshaped, vec(a^{⊗delt}) is the rank-one matrix a * vec(a^{⊗(delt-1)})ᵀ, so
        # its leading left singular vector spans a. Taking only column 1 avoids a shape
        # error when rounding noise leaves extra singular values above tol.
        U = svd(reshape(V[:, i], (n, n^(delt - 1)))).U
        Ahat[:, i] = U[:, 1]
    end
    dehomogenize!(Ahat)   # fix the per-column scale: first entry becomes 1

    # Tflat = Alow * Diagonal(L) * transpose(Ahigh), so sandwiching with
    # pseudo-inverses isolates the weights on the diagonal.
    low = fld(d, 2)
    high = cld(d, 2)
    Tflat = reshape(T, (n^low, n^high))
    Alow = kronMat(Ahat, low)
    Ahigh = kronMat(Ahat, high)

    Lhat = diag(pinv(Alow) * Tflat * pinv(transpose(Ahigh)))

    return Ahat, Lhat
end;

"""
    dehomogenize!(A)

Divide every column of `A` in place by its own first entry, so each column starts
with 1, and return `A`. A zero first entry produces Inf/NaN entries.
"""
function dehomogenize!(A)
    @views for j in axes(A, 2)
        pivot = A[1, j]
        A[:, j] ./= pivot
    end
    return A
end;
"""
    dehomogenize(A)

Non-mutating variant of `dehomogenize!`: works on a copy of `A`.
"""
dehomogenize(A) = dehomogenize!(copy(A));

"""
    normcol!(A)

Scale every column of `A` in place to unit Euclidean norm and return `A`.
An all-zero column becomes NaN (0/0).
"""
function normcol!(A)
    @views for j in axes(A, 2)
        nrm = norm(A[:, j])
        A[:, j] ./= nrm
    end
    return A
end;
"""
    normcol(A)

Non-mutating variant of `normcol!`: works on a copy of `A`.
"""
normcol(A) = normcol!(copy(A));

"""
    kronMat(A, d)

Return the column-wise `d`-fold Kronecker power of `A`: column `i` of the result is
`kron(A[:, i], A[:, i], …)` with `d` factors, so an `n × r` input gives an `n^d × r`
matrix.

For `d == 1`, `A` itself is returned (not a copy), so mutating the result mutates `A`.

# Throws
- `ArgumentError` if `d < 1`.
"""
function kronMat(A::AbstractMatrix, d)
    d >= 1 || throw(ArgumentError("Kronecker power d must be ≥ 1, got $d"))
    d == 1 && return A
    n, r = size(A)
    B = zeros(eltype(A), (n^d, r))
    for (i, col) in enumerate(eachcol(A))
        B[:, i] = kron(ntuple(_ -> col, d)...)
    end
    return B
end;


In [63]:
# Experiment parameters: every tensor mode has size n, tensor order is d.
n = 5;
d = 3;

In [67]:
# Target rank of the synthetic tensor.
r = 3;
# T: rank-r complex tensor; A: its factors with each column divided by its first entry;
# L: the matching weights (first entries raised to the d-th power).
T, A, L = randomRankedTensor(n, d, r; real=false);
# Recover factors and weights from T alone.
Ahat, Lhat = jennrich(T);

In [68]:
# True factors (first row normalized to 1).
A

5×3 Matrix{ComplexF64}:
       1.0-5.30388e-17im       1.0+0.0im                1.0-6.64589e-17im
 -0.563082+1.326im        0.292087+0.5028im       -0.148834-0.326267im
  0.684971-0.0811024im    0.454977-0.0414948im   -0.0294396-0.182617im
  0.109598+1.33521im      0.250857+1.27498im       -0.17348-0.467461im
   1.36524+0.648216im     0.260177-0.00382983im   -0.589395-0.18207im

In [69]:
# Recovered factors: in the printed outputs these equal the columns of A up to
# column order (a CP decomposition is only unique up to permutation).
Ahat

5×3 Matrix{ComplexF64}:
      1.0+0.0im               1.0+0.0im               1.0+0.0im
 0.292087+0.5028im      -0.563082+1.326im       -0.148834-0.326267im
 0.454977-0.0414948im    0.684971-0.0811024im  -0.0294396-0.182617im
 0.250857+1.27498im      0.109598+1.33521im      -0.17348-0.467461im
 0.260177-0.00382983im    1.36524+0.648216im    -0.589395-0.18207im

In [70]:
# True weights, one per column of A.
L

3-element Vector{ComplexF64}:
 -0.16213302913308117 + 0.6098398583343104im
  -2.8417767451062517 + 2.908695200921926im
   2.3946027424387464 + 0.7278595144668726im

In [71]:
# Recovered weights, ordered like the columns of Ahat (same permutation relative to L).
Lhat

3-element Vector{ComplexF64}:
  -2.841776745106246 + 2.9086952009219216im
 -0.1621330291330844 + 0.6098398583343101im
  2.3946027424387424 + 0.7278595144668736im