## EMD calculation with neural nets

In [1]:
# add modules to Python's search path
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

# import our EMD calculation module(s)
from modules import krd_nn as emd

# import rest of the helper modules
import numpy as np
import tensorflow as tf

###  Create samplers for our test distributions

In [29]:
# make a convenient wrapper for producing samples in form of a tensor
def gaussian_sampler(mean, cov, size):
    samples = np.random.multivariate_normal(mean, cov, size)
    return tf.convert_to_tensor(samples, dtype=tf.float64)

# set up parameters for our two test distributions
dimension = 3
mean_1 = np.zeros(dimension)
mean_2 = mean_1 + 100.0 * np.ones(dimension)
cov_1 = np.identity(dimension)
cov_2 = cov_1

# finally create the samplers our test distributions
sampler_1 = lambda size: gaussian_sampler(mean_1, cov_1, size)
sampler_2 = lambda size: gaussian_sampler(mean_2, cov_2, size)

# test our samplers
print("samples from distribution #1:\n{}".format(sampler_1(3)))
print("samples from distribution #2:\n{}".format(sampler_2(3)))

samples from distribution #1:
[[ 2.45010024 -0.6250064  -0.73813751]
 [-0.20509948 -0.92776888  1.48045425]
 [ 1.01551443 -0.66929107  0.38439421]]
samples from distribution #2:
[[100.11749504  98.53563425  99.57306788]
 [ 99.94913615  99.61431729 101.04460215]
 [ 99.69466636 101.81548539 100.12731173]]


### Calculate approximate EMD with the help of Kantorovich-Rubinstein duality

In [30]:
emd_calc= emd.W1L1Calc(sampler_1, sampler_2, clip_value=0.02)
emd_calc.calculate(epochs=500)

epoch = 1, EMD = 0.6991717217090121
epoch = 2, EMD = 0.0015139606920390892
epoch = 3, EMD = 0.003665025915088325
epoch = 4, EMD = 0.005253414862830491
epoch = 5, EMD = 0.006526137289124918
epoch = 6, EMD = 0.007590325389371284
epoch = 7, EMD = 0.008523600158792654
epoch = 8, EMD = 0.009313111946083481
epoch = 9, EMD = 0.010051953432828269
epoch = 10, EMD = 0.010712261319661185
epoch = 11, EMD = 0.011327935410066079
epoch = 12, EMD = 0.011913644519336476
epoch = 13, EMD = 0.012477953595137288
epoch = 14, EMD = 0.013001704121471978
epoch = 15, EMD = 0.013481344592686302
epoch = 16, EMD = 0.014005864316001296
epoch = 17, EMD = 0.014464161779842897
epoch = 18, EMD = 0.014937374057551214
epoch = 19, EMD = 0.015300417092154758
epoch = 20, EMD = 0.015715968514561307
epoch = 21, EMD = 0.01618203838975978
epoch = 22, EMD = 0.01658456373847814
epoch = 23, EMD = 0.017069264532426742
epoch = 24, EMD = 0.01756654293834343
epoch = 25, EMD = 0.018110530601385916
epoch = 26, EMD = 0.01861811915789849


epoch = 232, EMD = 0.6631288078094393
epoch = 233, EMD = 0.6639126995378595
epoch = 234, EMD = 0.6644467274873844
epoch = 235, EMD = 0.6679082189829679
epoch = 236, EMD = 0.66853681320485
epoch = 237, EMD = 0.6701631840430559
epoch = 238, EMD = 0.6674073308995268
epoch = 239, EMD = 0.6694691810788989
epoch = 240, EMD = 0.6711918411375217
epoch = 241, EMD = 0.6714587337009019
epoch = 242, EMD = 0.6748427721073931
epoch = 243, EMD = 0.6750611325947568
epoch = 244, EMD = 0.6742235176505756
epoch = 245, EMD = 0.6746716744293371
epoch = 246, EMD = 0.6770660846043808
epoch = 247, EMD = 0.678181102641039
epoch = 248, EMD = 0.6767221398547371
epoch = 249, EMD = 0.6781414026767314
epoch = 250, EMD = 0.6790606526650553
epoch = 251, EMD = 0.6771630654725852
epoch = 252, EMD = 0.677872508798963
epoch = 253, EMD = 0.6817469544298709
epoch = 254, EMD = 0.6807240631195698
epoch = 255, EMD = 0.680957654574789
epoch = 256, EMD = 0.6802957299575103
epoch = 257, EMD = 0.6796954329748282
epoch = 258, EMD 

epoch = 464, EMD = 0.6860147196825024
epoch = 465, EMD = 0.6856540081596888
epoch = 466, EMD = 0.6842016235141767
epoch = 467, EMD = 0.6847722886761378
epoch = 468, EMD = 0.6823802120346852
epoch = 469, EMD = 0.6857895258273197
epoch = 470, EMD = 0.6864915245022267
epoch = 471, EMD = 0.6855058011546837
epoch = 472, EMD = 0.6853281145060678
epoch = 473, EMD = 0.682263527511183
epoch = 474, EMD = 0.6827232105043904
epoch = 475, EMD = 0.6860740679541589
epoch = 476, EMD = 0.6838139107351495
epoch = 477, EMD = 0.6855710263849422
epoch = 478, EMD = 0.6848641574833
epoch = 479, EMD = 0.6839874681028288
epoch = 480, EMD = 0.6860267058139947
epoch = 481, EMD = 0.6840945538651573
epoch = 482, EMD = 0.6831030378142554
epoch = 483, EMD = 0.687192782203732
epoch = 484, EMD = 0.6855092551692192
epoch = 485, EMD = 0.6840733157504278
epoch = 486, EMD = 0.6836284167233964
epoch = 487, EMD = 0.6862577513172813
epoch = 488, EMD = 0.685722849452918
epoch = 489, EMD = 0.6877224131841894
epoch = 490, EMD =

<tf.Tensor: shape=(), dtype=float64, numpy=0.6861628675221866>