In [None]:
from scipy import io
import numpy as np
import ot
from gsw.gsw import GSW
import torch
import matplotlib.pyplot as plt

In [2]:
# Distance object from the GSW package, configured with linear slicing functions.
dist = GSW(ftype='linear')
# Prefer the GPU, but fall back to the CPU so the notebook still runs on hosts without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# we approximate SW1 using subgradient ascent as implemented by Kolouri et al. 2019
# NOTE(review): the experiment cell calls `dist.max_gsw`, i.e. it appears to estimate the
# max-sliced distance rather than the averaged SW1 - confirm the wording of the line above.

In [None]:
# --- Experiment configuration ------------------------------------------------
its = 50   # number of optimisation iterations passed to dist.max_gsw
lr = 1e-2  # learning rate passed to dist.max_gsw
p = 1.0    # order `p` passed to dist.max_gsw (presumably W_p, so W_1 here - confirm)

# Folder holding the MATLAB exports clean<d>.mat, corrupted<d>.mat and filtered<d>.mat.
DATA_DIR = '/home/ubuntu/sloan/sliced-OT-data/'
dims = list(range(10, 201, 10))  # 10, 20, ..., 200

# Results are keyed by (dimension, flag): flag is True for filtered data, False for corrupted data.
mean_diff_list = {}
MSW_list = {}


def _to_tensor(array):
    """Copy a numpy array into a float32 torch tensor on the notebook's `device`."""
    return torch.tensor(array, device=device, dtype=torch.float)


for d in dims:
    print(f'd={d}')
    # load data from MATLAB
    clean_data = io.loadmat(DATA_DIR + 'clean' + str(d) + '.mat')['X']
    corrupted_data = io.loadmat(DATA_DIR + 'corrupted' + str(d) + '.mat')['X']
    filtered_data = io.loadmat(DATA_DIR + 'filtered' + str(d) + '.mat')['filteredData']

    # match array sizes, introduces small bias controlled by empirical approximation error
    # 1D optimal couplings can still be computed efficiently when array sizes don't match,
    # but the implementation is less clean - we omit this at present

    # Reference sample for the filtered data: rows of `filtered_data` that also occur in
    # `clean_data` are kept as they are; every other row is replaced by a uniformly drawn
    # clean row, so the reference has the same number of rows as `filtered_data`.
    clean_data_set = set(tuple(row) for row in clean_data.tolist())
    clean_filter_cmp_data = np.copy(filtered_data)
    for i in range(filtered_data.shape[0]):
        x = filtered_data[i, :]
        if tuple(x.tolist()) not in clean_data_set:
            j = np.random.choice(clean_data.shape[0])
            clean_filter_cmp_data[i, :] = clean_data[j, :]

    # Reference sample for the corrupted data: all clean rows first, then padded with
    # uniformly drawn clean rows up to the number of rows of `corrupted_data`.
    # Assumes corrupted_data has at least as many rows as clean_data (the slice
    # assignment below fails otherwise).
    n_clean = clean_data.shape[0]
    clean_corrupted_cmp_data = np.copy(corrupted_data)
    clean_corrupted_cmp_data[0:n_clean, :] = clean_data
    for i in range(n_clean, corrupted_data.shape[0]):
        j = np.random.choice(n_clean)
        clean_corrupted_cmp_data[i, :] = clean_data[j, :]
    print('prepared clean data')

    # convert to torch data type
    clean_data = _to_tensor(clean_data)
    corrupted_data = _to_tensor(corrupted_data)
    filtered_data = _to_tensor(filtered_data)
    clean_filter_cmp_data = _to_tensor(clean_filter_cmp_data)
    clean_corrupted_cmp_data = _to_tensor(clean_corrupted_cmp_data)

    # estimate max-sliced distance between filtered data and clean data
    msw_filtered = dist.max_gsw(filtered_data, clean_filter_cmp_data, iterations=its, lr=lr, p=p, rand_init=False)
    MSW_list[(d, True)] = msw_filtered.detach().cpu()
    print(msw_filtered)
    # estimate max-sliced distance between corrupted data and clean data,
    # with the same optimiser settings as above so the two estimates are comparable
    sw_corrupted = dist.max_gsw(corrupted_data, clean_corrupted_cmp_data, iterations=its, lr=lr, p=p, rand_init=False)
    MSW_list[(d, False)] = sw_corrupted.detach().cpu()
    print(sw_corrupted)

    # Euclidean norm of the difference between the sample means of each data set and its reference
    mean_diff_list[(d, True)] = (filtered_data.mean(dim=0) - clean_filter_cmp_data.mean(dim=0)).norm().detach().cpu()
    mean_diff_list[(d, False)] = (corrupted_data.mean(dim=0) - clean_corrupted_cmp_data.mean(dim=0)).norm().detach().cpu()

In [None]:
# Dimensions shown on the x-axis; each must be a key in MSW_list / mean_diff_list.
plot_dims = list(range(10, 201, 10))  # 10, 20, ..., 200

# Labels are raw strings so the LaTeX backslashes (\overline, \ell) are not parsed as
# (invalid) Python escape sequences; the rendered text is unchanged.
# float(...) turns each stored scalar into a plain Python number for matplotlib.
fig, ax = plt.subplots()
ax.plot(plot_dims, [float(MSW_list[(d, True)]) for d in plot_dims], 'g', label=r'$\overline{W}_1$ error (filtered)')
ax.plot(plot_dims, [float(MSW_list[(d, False)]) for d in plot_dims], 'g--', label=r'$\overline{W}_1$ error (corrupted)')
ax.plot(plot_dims, [float(mean_diff_list[(d, True)]) for d in plot_dims], 'b', label=r'$\ell_2$ mean error (filtered)')
ax.plot(plot_dims, [float(mean_diff_list[(d, False)]) for d in plot_dims], 'b--', label=r'$\ell_2$ mean error (corrupted)')
ax.legend()
ax.set_xlabel('dimension')
ax.set_ylabel(r'excess error ($\ell_2$ distance)')
ax.set_title(r'Robust Estimation Error: $\overline{W}_1$ vs. Difference b/t Means')
plt.show()