In [3]:
import numpy as np
from numpy.random import randn, permutation, seed
from numpy.linalg import norm
from scipy.spatial.distance import pdist, squareform
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt

In [4]:
import sys
sys.path.append("..")
from pp5.stats import tw_test

In [5]:
def W2(mux,muy, Cx, Cy):
    """
    Wasserstein distance between two normal distributions
    """
    return np.sqrt(
        max(0, norm(mux-muy)**2 + np.trace( Cx + Cy - 2*sqrtm(sqrtm(Cx) @ Cy @ sqrtm(Cx)) ))
    )

In [6]:
def gen_random_data(n, nx, ny, shift, covdiff):
    # first-order statistics
    mu_x = randn(n)
    dmu = randn(n)
    mu_y = mu_x + dmu*shift/norm(dmu)

    # second-order statistics
    #A = randn(n, n)
    A = np.abs(np.diag(randn(n)))
    Cx = A @ A.T
    dA = np.abs(randn(n))
    B = A + np.diag(dA/norm(dA)*covdiff)
    Cy = B @ B.T

    # Generate random samples from normal distribution
    X = A @ randn(n, nx) + mu_x[:,None]
    Y = B @ randn(n, ny) + mu_y[:,None]
    
    return X, Y, mu_x, Cx, mu_y, Cy 

In [7]:
print(f"dims\t|mux-muy|\tW2(Cx,Cy)\tNx\tNy\tT2\tp")

RESULTS = []

n  = 5       # number of dimensions
trials = 10
for shift in np.linspace(0,2,11):
    for covdiff in (0,): #(0,0.1,1.,2.):   # uncomment for heteroschedasticity 
        for nx in (10,50,100,):     # number of samples in X
            for ny in (10,50,100,):
                if ny < nx: 
                    continue
                ps = []
                for trial in range(trials):
                    seed(trial)
                    X, Y, mu_x, Cx, mu_y, Cy = gen_random_data(n, nx, ny, shift, covdiff)
                    w2 = W2(mu_x*0,mu_y*0, Cx, Cy)
                    t2, p = tw_test(X,Y,k=1000)
                    ps.append(p)
                p_mean = np.mean(ps)
                p_std  = np.std(ps)
                print(f"{n:<2}\t{shift:<4.2f}\t\t{w2:<4.2f}\t\t{nx:<4}\t{ny:<4}\t{t2:<6.2f}\t{p_mean:<8.6f}±{p_std:<8.6f}")
                RESULTS.append( ((nx,ny), shift, p_mean, p_std) )

dims	|mux-muy|	W2(Cx,Cy)	Nx	Ny	T2	p
5 	0.00		0.00		10  	10  	0.41  	0.615100±0.255964
5 	0.00		0.00		10  	50  	0.37  	0.720100±0.203110
5 	0.00		0.00		10  	100 	0.47  	0.631500±0.211098
5 	0.00		0.00		50  	50  	0.26  	0.632900±0.319203
5 	0.00		0.00		50  	100 	0.22  	0.670200±0.262467
5 	0.00		0.00		100 	100 	0.42  	0.413200±0.252850
5 	0.20		0.00		10  	10  	0.46  	0.590400±0.256463
5 	0.20		0.00		10  	50  	0.43  	0.638100±0.189404
5 	0.20		0.00		10  	100 	0.61  	0.564200±0.201665
5 	0.20		0.00		50  	50  	0.49  	0.521500±0.302366
5 	0.20		0.00		50  	100 	0.59  	0.430100±0.262539
5 	0.20		0.00		100 	100 	0.99  	0.263100±0.209522
5 	0.40		0.00		10  	10  	0.56  	0.514100±0.264871
5 	0.40		0.00		10  	50  	0.54  	0.511900±0.235512
5 	0.40		0.00		10  	100 	0.79  	0.450600±0.228070
5 	0.40		0.00		50  	50  	0.97  	0.290100±0.224022
5 	0.40		0.00		50  	100 	1.32  	0.184500±0.139874
5 	0.40		0.00		100 	100 	2.08  	0.115400±0.124094
5 	0.60		0.00		10  	10  	0.69  	0.424600±0.277186
5 	0.60		0.00	

In [None]:
R = np.array(RESULTS)
plt.figure()
for n in np.unique(R[:,0]):
    idx = [r==n for r in R[:,0]]
    plt.errorbar(R[idx,1],R[idx,2],yerr=R[idx,3],capsize=3, marker='o')
plt.legend(np.unique(R[:,0]))
plt.yscale('log')
plt.ylim([1e-3,1])
plt.xlabel('Mean distance')
plt.ylabel('Significance')