In [16]:
from scipy.stats import wasserstein_distance as wsd
from scipy.stats import beta as beta_distribution

from ddsm import noise_factory
import torch

In [17]:
num_cat = 9
num_samples = 500000
device="cpu"

In [18]:
total_time = 4.0
order = 1000
steps_per_tick = 200
speed_balance = True #if True s=2/(a+b) is used; otherwise s=1
logspace = True
mode ='path'

In [19]:
alpha = torch.ones(1)
beta =  (num_cat-1)*torch.ones(1)

In [None]:
iteration = 10
for num_time_steps in [400]:#,800,1200,1600,2000]:#[100,400,800,1200,1600,2000]::
    for boundary_mode in ['clamp', 'reflect_boundaries', 'reflect']:
        distances = torch.zeros(iteration)
        for i in range(iteration):
            v_one, v_zero = noise_factory(num_samples, 
                                            num_time_steps,
                                            alpha,
                                            beta,
                                            total_time=total_time,
                                            order=order,
                                            time_steps=steps_per_tick,
                                            logspace=logspace,
                                            speed_balanced=speed_balance,
                                            mode=mode,
                                            device=device,
                                            boundary_mode=boundary_mode,
                                            noise_only=True)

            vT_approx = torch.cat([v_one[:,-1,0],v_zero[:,-1,0]])
            vT_true = beta_distribution.rvs(alpha, beta, size=vT_approx.shape[0])
            d = wsd(vT_approx,vT_true)
            print(f'{boundary_mode} with steps={num_time_steps}: wsd={d}')
            distances[i] = d

        #print('v shape one', v_one.shape)
        #print('v shape zero', v_zero.shape)

        print(f'\n{boundary_mode} with steps={num_time_steps}: mean={torch.mean(distances)} and speed_balance={speed_balance}')
        print(f'{boundary_mode} with steps={num_time_steps}: std={torch.std(distances)} and speed_balance={speed_balance}\n')

clamp with steps=400: wsd=0.007076658603011063


## Ground truth

In [74]:
iteration = 10
distances = torch.zeros(iteration)
for i in range(iteration):
    test1 = beta_distribution.rvs(alpha, beta, size=1000000)
    test2 = beta_distribution.rvs(alpha, beta, size=1000000)
    d = wsd(test1,test2)
    print('distance',d)
    distances[i] = d

test1 (1000000,)
test2 (1000000,)
distance 0.00015255260316868874
test1 (1000000,)
test2 (1000000,)
distance 0.00021186835176490015
test1 (1000000,)
test2 (1000000,)
distance 0.00011741920541712615
test1 (1000000,)
test2 (1000000,)
distance 0.00019672195663265597
test1 (1000000,)
test2 (1000000,)
distance 0.00014337657853888966
test1 (1000000,)
test2 (1000000,)
distance 0.00022759701250514168
test1 (1000000,)
test2 (1000000,)
distance 0.00029576271299861255
test1 (1000000,)
test2 (1000000,)
distance 0.00010474523625257931
test1 (1000000,)
test2 (1000000,)
distance 0.0001795923039257112
test1 (1000000,)
test2 (1000000,)
distance 0.00011724995625918759


In [75]:
print('mean',torch.mean(distances))
print('std',torch.std(distances))

mean tensor(0.0002)
std tensor(5.9900e-05)
