In [None]:
from scipy import io
import numpy as np
import ot
from gsw.gsw import GSW
import torch
import matplotlib.pyplot as plt
import time

In [2]:
# X,Y in R^{n x d}
def proj_wp(X, Y, theta, p=2):
    """Return the one-dimensional p-Wasserstein distance between X and Y along theta.

    Both samples are projected onto ``theta``; the projections are sorted and
    paired in order, and the p-th root of the mean p-th power of the paired
    absolute differences is returned.

    Parameters
    ----------
    X, Y : array_like, both with the same number of rows
        Samples, one point per row. The number of columns must match the
        length of ``theta``.
    theta : array_like
        Projection direction. It is flattened, so shapes ``(d,)``, ``(d, 1)``
        and ``(1, d)`` are all accepted.
    p : float, optional
        Order of the distance (default 2).

    Raises
    ------
    ValueError
        If X and Y have a different number of rows. The sorted pairing only
        makes sense for equal sample sizes.
    """
    n_x = np.shape(X)[0]
    n_y = np.shape(Y)[0]
    # Without this check a single-row sample would silently broadcast against
    # the other one and produce a meaningless value.
    if n_x != n_y:
        raise ValueError(
            f"X and Y must contain the same number of samples, got {n_x} and {n_y}"
        )

    direction = np.ravel(theta)
    x_sorted = np.sort(np.matmul(X, direction))
    y_sorted = np.sort(np.matmul(Y, direction))
    return np.mean(np.abs(x_sorted - y_sorted) ** p) ** (1 / p)

def norm(x):
    """Return the Euclidean (l2) norm of ``x`` as a scalar.

    ``x`` is treated as a single vector regardless of its shape, so ``(d,)``,
    ``(d, 1)`` and ``(1, d)`` inputs all give the same value. Uses
    ``np.linalg.norm`` rather than Python's builtin ``sum``, which loops in the
    interpreter and only reduces along the first axis.
    """
    return np.linalg.norm(x)

def samp_sph(d):
    """Draw a random direction of length one in ``d`` dimensions.

    A vector of ``d`` independent standard normal entries is sampled from
    NumPy's global random state and then rescaled by its ``norm``.
    """
    gaussian = np.random.normal(size=d)
    length = norm(gaussian)
    return gaussian / length

In [14]:
# lower bound W1 via the coupling that leaves points at 0 unmoved when possible
def W1_lower_bounds(clean_data, filtered_data):
    """Lower-bound W1 between the clean and filtered samples.

    The bound is the difference between the fractions of non-zero rows in the
    two samples, multiplied by the norm of the last clean row.

    Parameters
    ----------
    clean_data, filtered_data : ndarray, shape (n_samples, d)
        One point per row. The two samples may have different sizes; each
        fraction is normalised by its own sample size.

    Returns
    -------
    float
        ``(clean_fraction - filtered_fraction) * ||clean_data[-1]||``.
    """
    clean_norms = np.linalg.norm(clean_data, axis=1)
    filtered_norms = np.linalg.norm(filtered_data, axis=1)

    # Fraction of rows that are not exactly at the origin, per sample.
    clean_r = np.mean(clean_norms > 0)
    # The filtered fraction must be normalised by the filtered sample size,
    # not by the clean one; the two differ whenever filtering changes the count.
    filtered_r = np.mean(filtered_norms > 0)

    # NOTE(review): the norm of the *last* clean row is used as the transport
    # distance. This presumably assumes the last row is a non-zero point whose
    # norm is representative of all non-zero points -- confirm against the
    # MATLAB data generator.
    return (clean_r - filtered_r) * clean_norms[-1]

In [3]:
def subg_step(X, Y, theta, alpha):
    """Take one normalised ascent step on the squared projected distance.

    The samples are projected onto ``theta`` and matched by rank. The gradient
    of the mean squared projected gap with respect to ``theta`` (for that fixed
    matching) is added to ``theta`` with step size ``alpha``, and the result is
    rescaled by ``norm``.

    Returns the updated direction.
    """
    n_samples, _ = X.shape

    # One-dimensional projections of both samples.
    proj_x = np.matmul(X, theta)
    proj_y = np.matmul(Y, theta)

    # Rank-matching: i-th smallest projection of X pairs with i-th smallest of Y.
    order_x = np.argsort(proj_x)
    order_y = np.argsort(proj_y)
    gap_1d = proj_x[order_x] - proj_y[order_y]
    gap_nd = X[order_x, :] - Y[order_y, :]

    grad = 2 * np.dot(gap_1d, gap_nd) / n_samples
    stepped = theta + alpha * grad
    return stepped / norm(stepped)
             
def msw2_distance_subg(X, Y, n_step, theta0):
    """Estimate the max-sliced W2 distance between X and Y by iterating ``subg_step``.

    Starting from ``theta0``, ``n_step`` normalised ascent steps are applied and
    the projected distance ``proj_wp`` at the final direction is returned.

    Parameters
    ----------
    X, Y : ndarray, shape (n_samples, d)
        Samples, one point per row.
    n_step : int
        Number of ascent steps. With 0 steps the distance along ``theta0`` is
        returned.
    theta0 : ndarray, shape (d,)
        Initial direction.

    Returns
    -------
    float
        ``proj_wp(X, Y, theta)`` at the last iterate (not the best iterate).
    """
    # Constant step size; the original author notes that
    # np.sqrt(range(1, n_step + 1)) can also be tried.
    alpha = np.ones(n_step)

    # Per-iteration distances, timings and the iterate history used to be
    # recorded here but were never read or returned, so they were dropped;
    # this also avoids an extra sort of both projections on every step.
    theta = theta0
    for i in range(n_step):
        theta = subg_step(X, Y, theta, alpha[i])
    return proj_wp(X, Y, theta)

In [None]:
its = 50
lr = 1e-2
p = 1.0

dims = list(range(10, 201, 10))
MSW_list, W_list = {}, {}

for d in dims:
    print(f'd={d}')
    # Read the clean sample and its filtered counterpart exported from MATLAB.
    clean_data = io.loadmat('data2/clean' + str(d) + '.mat')['X']
    filtered_data = io.loadmat('data2/filtered' + str(d) + '.mat')['filteredData']

    # Build a clean comparison sample with exactly as many rows as the filtered
    # one: filtered rows that also occur in the clean data are kept, every other
    # row is replaced by a clean row drawn at random. This introduces a small
    # bias controlled by the empirical approximation error. 1D optimal couplings
    # can still be computed efficiently for unequal sample sizes, but the
    # implementation is less clean, so it is omitted for now.
    clean_data_set = {tuple(row) for row in clean_data.tolist()}
    clean_filter_cmp_data = np.copy(filtered_data)
    for i in range(filtered_data.shape[0]):
        x = filtered_data[i,:]
        if tuple(x.tolist()) in clean_data_set:
            continue
        j = np.random.choice(clean_data.shape[0])
        clean_filter_cmp_data[i,:] = clean_data[j,:]
    print('prepared clean data')

    print('estimating sliced distances')
    msw_filtered = msw2_distance_subg(filtered_data, clean_filter_cmp_data, its, samp_sph(d))
    MSW_list[d] = msw_filtered
    print('computing W1 lower bound')
    W_list[d] = W1_lower_bounds(clean_data, filtered_data)

In [None]:
dims = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200]

# Plot both error estimates against the ambient dimension.
# Raw strings keep the TeX backslashes intact (no invalid-escape warnings), and
# the closing '$' is required: with an unbalanced '$' matplotlib shows the label
# as literal text instead of rendering the math.
plt.plot(dims, [MSW_list[d] for d in dims], label=r'$\overline{W}_2$ error')
plt.plot(dims, [W_list[d] for d in dims], label='W1 error')
plt.legend()
plt.xlabel('dimension')
plt.ylabel(r'excess error ($\ell_2$ distance)')