In [80]:
import pytest
import numpy as np
from scipy.optimize import linear_sum_assignment

from ParTIpy.arch import AA
from ParTIpy.generate_test_data import simulate

def compute_dist_mtx(mtx_1, mtx_2):
    """Pairwise Euclidean distances between the rows of two 2-D matrices.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, which is fully
    vectorised but can produce slightly negative squared distances through
    floating-point cancellation; those are clamped to zero before the sqrt.

    Parameters
    ----------
    mtx_1 : np.ndarray, shape (n_1, d)
    mtx_2 : np.ndarray, shape (n_2, d)

    Returns
    -------
    np.ndarray, shape (n_2, n_1)
        Entry ``[j, i]`` is the distance between ``mtx_2[j]`` and ``mtx_1[i]``.
        Note the orientation: rows index ``mtx_2``, columns index ``mtx_1``.
        Neither input is modified.
    """
    # Local names deliberately avoid `AA`, which is the imported model class.
    cross = np.dot(mtx_1, mtx_2.T)                 # (n_1, n_2)
    sq_norms_1 = np.sum(np.square(mtx_1), axis=1)  # (n_1,)
    sq_norms_2 = np.sum(np.square(mtx_2), axis=1)  # (n_2,)
    sq_dist = (sq_norms_2 - 2 * cross).T + sq_norms_1  # (n_2, n_1)
    # Snap values numerically indistinguishable from zero (e.g. identical rows)
    # to exactly zero...
    sq_dist[np.isclose(sq_dist, 0)] = 0
    # ...and clamp any remaining negatives from cancellation. np.isclose's
    # absolute tolerance (1e-8) misses these when row norms are large, and
    # np.sqrt would otherwise turn them into NaN.
    sq_dist = np.maximum(sq_dist, 0)
    return np.sqrt(sq_dist)

def align_archetypes(ref_arch, query_arch):
    """Reorder the rows of ``query_arch`` to best match the rows of ``ref_arch``.

    Solves a linear assignment problem on the Euclidean distances between
    reference and query rows, so that row ``i`` of the result is the query row
    paired with ``ref_arch[i]`` at minimum total distance.

    Parameters
    ----------
    ref_arch : np.ndarray, shape (n_ref, d)
    query_arch : np.ndarray, shape (n_query, d), with n_query >= n_ref

    Returns
    -------
    np.ndarray, shape (n_ref, d)
        Rows of ``query_arch`` reordered to line up with ``ref_arch``.

    Raises
    ------
    ValueError
        If ``query_arch`` has fewer rows than ``ref_arch``. The assignment would
        then leave some reference rows unmatched, and the output rows would no
        longer correspond positionally to ``ref_arch``.
    """
    n_ref, n_query = ref_arch.shape[0], query_arch.shape[0]
    if n_query < n_ref:
        raise ValueError(
            f"query_arch has {n_query} rows but ref_arch has {n_ref}; "
            "cannot align every reference row"
        )
    # compute_dist_mtx only reads its inputs, so no defensive copy is needed.
    # It returns shape (n_query, n_ref); transpose so rows index the reference.
    cost = compute_dist_mtx(ref_arch, query_arch).T
    # With n_ref <= n_query every reference row is assigned and the returned
    # row indices are 0..n_ref-1 in order, so the column indices alone give the
    # matched query row for each reference row.
    _, query_idx = linear_sum_assignment(cost)
    return query_arch[query_idx, :]

# Settings for the archetype-recovery check.
N_SAMPLES = 1_000
N_ARCHETYPES = 4
N_DIMENSIONS = 10
MIN_CORR = 0.95  # minimum row-wise Pearson correlation between Z and the aligned Z_hat

# Noise-free synthetic data. Z is compared against the fitted archetypes below;
# presumably A holds the ground-truth weights (unused here) — confirm against simulate().
# NOTE(review): no random seed is set; if simulate/AA are stochastic this check
# may be non-deterministic — TODO confirm and seed if needed.
X, A, Z = simulate(
    n_samples=N_SAMPLES,
    n_archetypes=N_ARCHETYPES,
    n_dimensions=N_DIMENSIONS,
    noise_std=0.0,
)

In [81]:
# Fit archetypal analysis with projected gradients, then permute the fitted
# archetypes so that row i of Z_hat is paired with row i of the reference Z.
A_hat, B_hat, Z_hat, RSS, varexpl = (
    AA(n_archetypes=N_ARCHETYPES, optim="projected_gradients")
    .fit(X)
    .return_all()
)
Z_hat = align_archetypes(ref_arch=Z, query_arch=Z_hat)

In [82]:
def compute_rowwise_correlation(mtx_1, mtx_2):
    """Pearson correlation between corresponding rows of two equal-shape 2-D arrays.

    Each row is centred and scaled by its population standard deviation
    (ddof=0); the mean of the element-wise product of two such rows is their
    Pearson correlation. Inputs are not modified.

    Returns
    -------
    np.ndarray, shape (n_rows,)
        Entry ``i`` is the correlation between ``mtx_1[i]`` and ``mtx_2[i]``.
        A constant row yields a NaN/inf entry (zero standard deviation).
    """
    assert mtx_1.shape == mtx_2.shape

    def _standardize(rows):
        # Row-wise z-score; `rows - mean` allocates, so the caller's array is untouched.
        centered = rows - rows.mean(axis=1, keepdims=True)
        return centered / centered.std(axis=1, keepdims=True)

    product = _standardize(mtx_1) * _standardize(mtx_2)
    return product.mean(axis=1)

In [84]:
# Each aligned fitted archetype must correlate strongly with its ground-truth
# counterpart. Report the per-row correlations on failure; a NaN entry (e.g. a
# constant row) also fails this check.
archetype_corr = compute_rowwise_correlation(Z, Z_hat)
assert np.all(archetype_corr > MIN_CORR), (
    f"row-wise correlations {np.round(archetype_corr, 3)} "
    f"are not all above MIN_CORR={MIN_CORR}"
)