In [None]:
from plt_utils import *
import numpy as np
import matplotlib.pyplot as plt

import torch
import tensorly as tl
from tensorly import unfold
from tensorly.decomposition import parafac
from tensorly.cp_tensor import cp_to_tensor
from tensorly.tenalg import khatri_rao

# Select PyTorch as TensorLy's backend so the tensorly routines used in this
# notebook (unfold, khatri_rao, cp_to_tensor) operate on torch tensors.
tl.set_backend('pytorch')

def generate_probability_matrix(dims, K):
    """Draw a random ``dims x K`` matrix whose columns each sum to one.

    Entries are sampled with ``torch.rand`` and every column is then divided
    by its own total.

    Args:
        dims: Number of rows.
        K: Number of columns.

    Returns:
        A ``(dims, K)`` tensor with column sums equal to one.
    """
    raw = torch.rand(dims, K)
    column_totals = raw.sum(dim=0, keepdim=True)
    return raw / column_totals

def generate_probability_vector(D):
    """Draw a random length-``D`` vector whose entries sum to one.

    Entries are sampled with ``torch.rand`` and divided by their total.

    Args:
        D: Number of entries.

    Returns:
        A 1-D tensor of length ``D`` summing to one.
    """
    raw = torch.rand(D)
    return raw / raw.sum()

def generate_tensor(D, K, d):
    """Build a random tensor from a CP model with normalised factors.

    ``D`` factor matrices of shape ``(d, K)`` are drawn first, followed by a
    length-``K`` weight vector; the tensor is reconstructed from them with
    ``cp_to_tensor``.

    Args:
        D: Number of factor matrices (tensor modes).
        K: Number of columns in every factor matrix / length of the weights.
        d: Number of rows in every factor matrix.

    Returns:
        A tuple ``(tensor, factors, weights)``.
    """
    # Factors are sampled before the weights; keep this order so a seeded
    # run consumes the random stream in the same sequence.
    factors = []
    for _ in range(D):
        factors.append(generate_probability_matrix(d, K))
    weights = generate_probability_vector(K)
    tensor = cp_to_tensor((weights, factors))
    return tensor, factors, weights

def generate_tensor_sampled(P, N):
    """Sample ``N`` cells of ``P`` and return the empirical frequencies.

    The entries of ``P`` are used as multinomial weights over its flattened
    cells; ``N`` cell indices are drawn with replacement and tallied.

    Args:
        P: Tensor of non-negative sampling weights.
        N: Number of draws.

    Returns:
        A tuple ``(P_hat, P_count)`` where ``P_count`` holds the integer
        number of draws per cell (same shape as ``P``) and
        ``P_hat = P_count / N``.
    """
    probs = P.flatten()
    draws = torch.multinomial(probs, N, replacement=True)
    # minlength guarantees one bin per cell even if trailing cells are never hit.
    P_count = torch.bincount(draws, minlength=probs.size(0)).view(P.shape)
    return P_count / N, P_count

def exponential_iterator(start, end, factor):
    """Yield ``start``, ``start*factor``, ``start*factor**2``, ... up to ``end``.

    Values are produced while they are ``<= end`` (inclusive bound).

    Note:
        The generator only stops once a value exceeds ``end``; inputs for
        which repeated multiplication never does so (for example
        ``factor == 1`` with ``start <= end``) yield forever.
    """
    current = start
    while True:
        if not current <= end:
            return
        yield current
        current *= factor

In [None]:
# Problem size: D modes of length d each, K CP components, N samples drawn.
D, K, d, N = 3, 4, 10, 1000

# Ground-truth tensor with its CP factors and weights.
P, factors, weights = generate_tensor(D, K, d)

# Empirical estimate (relative frequencies) and raw counts from N draws of P.
P_hat, P_count = generate_tensor_sampled(P, N)

# Float mask: 1.0 where a cell was sampled at least once, 0.0 elsewhere.
M = (P_count != 0).float()

In [None]:
# ADMM-style fit of a CP model (weights lmbda, factors S) to P_hat, with
# sum-to-one and non-negativity constraints on lmbda and on the columns of
# every factor matrix.
a = 2            # penalty weight on the sum-to-one constraints
b = a            # penalty weight on the non-negativity constraints
eps_abs = 1e-6   # absolute tolerance used in the stopping thresholds
eps_rel = 0      # relative tolerance used in the stopping thresholds
max_itr = 5000   # maximum number of outer iterations

# Initialize variables
_, S, lmbda = generate_tensor(D,K,d)
u = 0.                                                 # dual for sum(lmbda) = 1
w = torch.zeros_like(lmbda)                            # dual for lmbda >= 0
U = [torch.zeros(K) for mode in range(D)]              # duals for column sums of S[mode]
W = [torch.zeros_like(S[mode]) for mode in range(D)]   # duals for S[mode] >= 0

# Duals from one update earlier; they are needed to rebuild the previous
# projected (non-negative) copies max(x_prev + dual_prev, 0).
w_prev = w.clone()
W_prev = [W[mode].clone() for mode in range(D)]

err_hat = torch.tensor([]) # also objective
err_tru = torch.tensor([])

res_pri = torch.tensor([])
res_dua = torch.tensor([])
eps_pri = torch.tensor([])
eps_dua = torch.tensor([])

# Loop-invariant quantities, built once instead of on every iteration.
p_full = P_hat.reshape(-1)                                               # row-major vec(P_hat)
p_modes = [unfold(P_hat, mode=mode).reshape(-1) for mode in range(D)]    # vec of each unfolding
E = torch.kron( torch.ones((1,d)), torch.eye(K) )                        # E @ vec(S[mode]) = column sums
S_penalty = a*E.T@E + b*torch.eye(d*K)                                   # constant part of the S system

# NOTE(review): the dual steps below are multiplied by a / b, while the
# linear systems use u, w, U, W as if they were scaled duals (which would
# normally be updated with a unit step) - confirm this is intended.
for itr in range(max_itr):
    # lmbda update
    lmbda_prev = lmbda.clone()
    B = khatri_rao(S)
    M1 = torch.inverse(B.T@B + a*torch.ones((K,K)) + b*torch.eye(K))
    m2 = B.T@p_full + a*(1-u)*torch.ones(K) + b*(torch.maximum(lmbda_prev+w_prev,torch.zeros(K)) - w)
    lmbda = M1@m2

    u = u + a*(lmbda.sum()-1)
    w_prev = w.clone()
    w = w + b*(lmbda - torch.maximum(lmbda+w,torch.zeros(K)))

    # S update
    S_prev = [S[mode].clone() for mode in range(D)]
    for mode in range(D):
        S_kr = khatri_rao( S, lmbda, skip_matrix=mode )

        # Row-major vec of the mode unfolding equals B @ vec(S[mode]).
        B = torch.kron( torch.eye(d), S_kr )

        M1 = torch.inverse( B.T@B + S_penalty )
        m2 = B.T@p_modes[mode] + a*E.T@(1-U[mode]) + b*(torch.maximum(S_prev[mode]+W_prev[mode],torch.zeros_like(S[mode])) - W[mode]).reshape(-1)
        s = M1@m2
        S[mode] = s.reshape(S[mode].shape)

        U[mode] = U[mode] + a*(S[mode].sum(0) - 1)
        # Snapshot the dual before stepping it, exactly as done for w_prev.
        # Without this W_prev would stay at its all-zero initial value and
        # both the S system above and rd_W below would use a stale dual.
        W_prev[mode] = W[mode].clone()
        W[mode] = W[mode] + b*(S[mode] - torch.maximum(S[mode]+W[mode],torch.zeros_like(S[mode])))

    # Primal residuals: violation of each constraint.
    rp_u = torch.abs( lmbda.sum()-1 )
    rp_w = torch.norm( lmbda - torch.maximum(lmbda+w,torch.zeros(K)) )
    rp_U = sum([torch.norm( S[mode].sum(0)-1, 'fro' ) for mode in range(D)])
    rp_W = sum([torch.norm( S[mode] - torch.maximum(S[mode]+W[mode],torch.zeros_like(S[mode])), 'fro' ) for mode in range(D)])
    rp = torch.tensor([ rp_u + rp_w + rp_U + rp_W ]).float()

    # Dual residuals: change of the projected copies between iterations.
    rd_u = 0.
    rd_w = b*torch.norm( torch.maximum( lmbda+w,torch.zeros(K) ) - torch.maximum( lmbda_prev+w_prev,torch.zeros(K) ))
    rd_U = 0.
    rd_W = sum([b*torch.norm( torch.maximum(S[mode]+W[mode],torch.zeros_like(S[mode])) - torch.maximum(S_prev[mode]+W_prev[mode],torch.zeros_like(S[mode])) ) for mode in range(D)])
    rd = torch.tensor([ rd_u + rd_w + rd_U + rd_W ]).float()

    # Primal tolerance.
    ep_u = eps_abs*np.sqrt(lmbda.numel()) + eps_rel*torch.maximum( torch.abs(lmbda.sum()),torch.tensor(1) )
    ep_w = eps_abs*np.sqrt(lmbda.numel()) + eps_rel*torch.maximum( torch.norm(lmbda), torch.norm( torch.maximum(lmbda+w,torch.zeros(K)) ) )
    ep_U = sum([eps_abs*np.sqrt(S[mode].numel()) + eps_rel*torch.maximum( torch.norm(S[mode].sum(0),'fro'), torch.norm(torch.ones(K)) ) for mode in range(D)])
    ep_W = sum([eps_abs*np.sqrt(S[mode].numel()) + eps_rel*torch.maximum( torch.norm(S[mode],'fro'), torch.norm(torch.maximum(S[mode]+W[mode],torch.zeros_like(S[mode])),'fro') ) for mode in range(D)])
    ep = torch.tensor([ep_u + ep_w + ep_U + ep_W]).float()

    # Dual tolerance.
    ed_u = eps_abs * 1 + eps_rel * torch.norm( a*u*torch.ones(K) )
    ed_w = eps_abs * np.sqrt(lmbda.numel()) + eps_rel * torch.norm( b*w )
    ed_U = sum([eps_abs * np.sqrt(K) + eps_rel * torch.norm( a*torch.ones(K)[:,None]*U[mode][None] ) for mode in range(D)])
    ed_W = sum([eps_abs * np.sqrt(S[mode].numel()) + eps_rel * torch.norm( b*W[mode] ) for mode in range(D)])
    ed = torch.tensor([ed_u + ed_w + ed_U + ed_W]).float()

    # Per-iteration history (one entry appended per iteration).
    res_pri = torch.cat(( res_pri, rp ), 0)
    res_dua = torch.cat(( res_dua, rd ), 0)
    eps_pri = torch.cat(( eps_pri, ep ), 0)
    eps_dua = torch.cat(( eps_dua, ed ), 0)

    # Fit to the sampled tensor (the objective) and to the true tensor.
    P_est = cp_to_tensor((lmbda,S))
    eh = torch.tensor([torch.norm( P_est-P_hat,'fro' )]).float()
    et = torch.tensor([torch.norm( P_est-P,'fro' )]).float()
    err_hat = torch.cat(( err_hat, eh ), 0)
    err_tru = torch.cat(( err_tru, et ), 0)

    # Progress report on iterations 0, 300, 600, ...
    if itr % 300 == 0:
        print(f'Iter. {itr} | Err: {float(eh):.3f}')

    # Stop only when the CURRENT iterate meets both tolerances. Scanning the
    # whole history would stop as soon as each tolerance had been met at
    # some, possibly different, earlier iteration.
    if bool((rp <= ep).all()) and bool((rd <= ed).all()):
        break

if (itr<max_itr-1):
    print('Terminated early')

In [None]:
# Lay every tensor out as the same 25 x 40 image. This layout needs exactly
# 1000 entries, i.e. the d**D elements obtained with d = 10 and D = 3.
P_unfold = P.reshape(25, 40)
P_hat_unfold = P_hat.reshape(25, 40)
P_est_unfold = P_est.reshape(25, 40)

# Shared colour limits so the three images are directly comparable.
panels = [P_unfold, P_hat_unfold, P_est_unfold]
vmin = np.min([panel.min() for panel in panels])
vmax = np.max([panel.max() for panel in panels])
for panel, title in zip(panels, ('True', 'Sample', 'Estimated')):
    madimshow(panel[:,:], 'magma', vmin=vmin, vmax=vmax, axis=False)
    plt.title(title)

# Fit to the sampled tensor (objective) and to the true tensor per iteration.
fig, ax = plt.subplots()
ax.plot(err_hat, '-', c=bright_qual['red'], label='Objective')
ax.plot(err_tru, '--', c=bright_qual['blue'], label='Error')
ax.set(xlabel='Iterations', ylabel='Error')
ax.legend()
ax.grid(True)

# Residual (dark shade) against its tolerance (light shade), one figure each.
residual_plots = (
    (res_pri, eps_pri, 'dark_red', 'light_red', 'Primal residual'),
    (res_dua, eps_dua, 'dark_blue', 'light_blue', 'Dual residual'),
)
for residual, tolerance, dark, light, ylabel in residual_plots:
    fig, ax = plt.subplots()
    ax.plot(residual, c=medcont_qual[dark])
    ax.plot(tolerance, c=medcont_qual[light])
    ax.set(xlabel='Iterations', ylabel=ylabel)
    ax.grid(True)

# Pairwise distances: estimate vs. sample, estimate vs. truth, sample vs. truth.
for lhs, rhs in ((P_est, P_hat), (P_est, P), (P_hat, P)):
    print( torch.norm( lhs-rhs ) )
