In [5]:
import numpy as np
import time
import matplotlib.pyplot as plt
import ot 

import sys
sys.path.append("../src")
from utility import HCP,IPRHCP,rand_projections

#### First, we verify numerically that if the orthogonal matrix is replaced with an all-ones matrix, then $\mathrm{IPRHCP}_{2,2} = \sqrt{2}\,\mathrm{SW}_2$.

In [None]:
# Default problem size: feature dimension d and sample count n.
d, n = 5, 1000

In [9]:
## equal weights IPRHCP
def IPRHCP_J(X, Y, q=2, nslice=500):

    d = X.shape[1]
    res = 0

        
    ## random directions may be faster
    proj = rand_projections(d, nslice)
    Xp = X@proj.T
    Yp = Y@proj.T

    for i in range(nslice):
        Xi = np.zeros((n,q))
        Yi = np.zeros((n,q))
        for j in range(q):
            Xi[:,j] = Xp[:,i]
            Yi[:,j] = Yp[:,i]
        res += HCP(Xi, Yi)**2

    return np.sqrt(res/nslice)

In [16]:
# Problem size, replicated projection dimension q, and the number of slices
# shared by both estimators so the comparison stays like-for-like.
d = 5
n = 1000
q = 2
n_slices = 500
np.random.seed(42)

for trial in range(5):
    # Gaussian samples vs. uniform samples shifted by +10 in every coordinate
    x = np.random.randn(n, d)
    y = np.random.rand(n, d) + 10

    # SW_2 via POT's sliced estimator (fixed projection seed)
    t1 = ot.sliced.sliced_wasserstein_distance(x, y, seed=2022, n_projections=n_slices)

    # IPRHCP with q identical projection columns
    t2 = IPRHCP_J(x, y, q=q, nslice=n_slices)

    # relative deviation from the claimed identity IPRHCP_J = sqrt(q) * SW_2
    print(abs(np.sqrt(q) * t1 / t2 - 1))


0.01605136491211767
0.035479340889759126
4.19545929195575e-05
0.0190594800770707
0.017204443367046185


#### Second, we verify numerically that $\mathrm{IPRHCP}_{2,q} \geq \sqrt{q}\,\mathrm{SW}_2$, and that equality can be attained.

In [30]:
## q = 2
d = 20
n = 1000
np.random.seed(42)

    
for i in range(5):
    # guassian and uniform
    x = np.random.randn(n,d)
    y = np.random.rand(n,d)*10

    # SW
    t1 = ot.sliced.sliced_wasserstein_distance(x, y, seed=2022, n_projections=1000) 

    # IPRHCP
    t2 = IPRHCP(x,y,nslice=1000)

    print(t2/(np.sqrt(2)*t1))


1.0012672176867758
1.0007622581190594
1.0004999605723135
1.0009290858583548
1.0015233203843477


In [29]:
# q = 10
d = 20
n = 1000
np.random.seed(42)

    
for i in range(5):
    # guassian and uniform
    x = np.random.randn(n,d)
    y = np.random.rand(n,d)*10

    # SW
    t1 = ot.sliced.sliced_wasserstein_distance(x, y, seed=2022, n_projections=1000) 

    # IPRHCP
    t2 = IPRHCP(x,y,q=10,nslice=1000)

    print(t2/(np.sqrt(10)*t1))


1.0425707265033552
1.0411260088653582
1.0412894332682217
1.0428737381899336
1.0419773827620398
