In [1]:
import jax 
import jax.numpy as jnp
from jax.random import normal
import numpy as np
import pandas as pd
from IPython.display import display


![image.png](attachment:image.png)

In [2]:
def mmd(u1, u2, v1, v2, k):
    """
    u1, u2: N_u * d array where each row is a sample from pdf p(.)
    v1, v2: N_v * d array where each row is a sample from pdf q(.)
    k: kernel function that measures similarity of u / v
    """
    uu = jnp.mean(k(u1, u2))
    uv = jnp.mean(k(u1+u2, v1+v2))
    vv = jnp.mean(k(v1, v2))
    return uu + vv - 2*uv

In [3]:
def gaussian_kernel(u, v, s=1.):
    norm = jnp.linalg.norm(u-v, axis=1)
    return jnp.exp(-norm/(2*s**2))
    

In [4]:
def mmd_estimator_stats(batch_size=32, 
                        feature_dim=8, 
                        dm=0.1,
                        n_trials=10000,
                        seed=42,
                        kernel=gaussian_kernel):
    
    rng_key = jax.random.PRNGKey(seed)
    
    m1 = 0
    m2 = m1 + dm
    s1 = 1
    s2 = 1
    d = feature_dim
    N = batch_size
    shape = (N, d)
    
    mmd_estimates = []
    
    @jax.jit
    def aux(rng_key, i):
        *keys, rng_key = jax.random.split(rng_key, 5)
        u1 = s1*normal(keys[0], shape) + m1
        u2 = s1*normal(keys[1], shape) + m1
        v1 = s1*normal(keys[2], shape) + m2
        v2 = s1*normal(keys[3], shape) + m2
        return mmd(u1, u2, v1, v2, kernel), rng_key
    
    for i in range(n_trials):
        res, rng_key = aux(rng_key, i)
        mmd_estimates.append(res)
        
    mmd_mean = np.mean(mmd_estimates)
    mmd_std = np.std(mmd_estimates)
    return mmd_mean, mmd_std

In [9]:
def experiment(batch_size_lst=[16, 32, 64, 128],
               m_diff_lst=[1e-2, 1e-1, 1e0, 1e1],
               feature_dim_lst=[8, 16, 32, 64]):
    dfs = {}
    for batch_size in batch_size_lst:
        res_arr = np.empty((len(m_diff_lst), len(feature_dim_lst)), dtype=tuple)
        for i, m_diff in enumerate(m_diff_lst):
            for j, feature_dim in enumerate(feature_dim_lst):
                res_arr[i,j] = mmd_estimator_stats(batch_size=batch_size, 
                                                   dm=m_diff, 
                                                   feature_dim=feature_dim)
                
        df = pd.DataFrame(res_arr, 
                          pd.MultiIndex.from_product([['Difference in means'], m_diff_lst]),
                          pd.MultiIndex.from_product([['Number of feature dimensions'], feature_dim_lst]))
        dfs[batch_size] = df
        print(f"Batch size: {batch_size}")
        display(df.style.format(lambda ms: f"{ms[0]:.3} ({ms[1]:.3})" ))
        
    return dfs
                
        

In [10]:
dfs = experiment()

Batch size: 16


Unnamed: 0_level_0,Unnamed: 1_level_0,Number of feature dimensions,Number of feature dimensions,Number of feature dimensions,Number of feature dimensions
Unnamed: 0_level_1,Unnamed: 1_level_1,8,16,32,64
Difference in means,0.01,0.161 (0.0393),0.0898 (0.0153),0.0334 (0.00424),0.0072 (0.000765)
Difference in means,0.1,0.163 (0.0391),0.0906 (0.0153),0.0336 (0.00422),0.00723 (0.000764)
Difference in means,1.0,0.266 (0.0308),0.128 (0.0126),0.0417 (0.00387),0.00804 (0.000745)
Difference in means,10.0,0.323 (0.0274),0.139 (0.0123),0.0427 (0.00387),0.00808 (0.000745)


Batch size: 32


Unnamed: 0_level_0,Unnamed: 1_level_0,Number of feature dimensions,Number of feature dimensions,Number of feature dimensions,Number of feature dimensions
Unnamed: 0_level_1,Unnamed: 1_level_1,8,16,32,64
Difference in means,0.01,0.161 (0.0282),0.0899 (0.0108),0.0334 (0.00302),0.0072 (0.000542)
Difference in means,0.1,0.163 (0.0281),0.0907 (0.0108),0.0336 (0.00301),0.00723 (0.000541)
Difference in means,1.0,0.265 (0.0218),0.128 (0.00894),0.0417 (0.00276),0.00805 (0.000527)
Difference in means,10.0,0.323 (0.0196),0.139 (0.00875),0.0427 (0.00275),0.00808 (0.000527)


Batch size: 64


Unnamed: 0_level_0,Unnamed: 1_level_0,Number of feature dimensions,Number of feature dimensions,Number of feature dimensions,Number of feature dimensions
Unnamed: 0_level_1,Unnamed: 1_level_1,8,16,32,64
Difference in means,0.01,0.161 (0.02),0.0899 (0.00771),0.0334 (0.00214),0.0072 (0.000378)
Difference in means,0.1,0.163 (0.0199),0.0907 (0.00769),0.0336 (0.00213),0.00723 (0.000377)
Difference in means,1.0,0.266 (0.0155),0.128 (0.00636),0.0417 (0.00195),0.00804 (0.000368)
Difference in means,10.0,0.323 (0.0139),0.139 (0.0062),0.0427 (0.00194),0.00808 (0.000369)


Batch size: 128


Unnamed: 0_level_0,Unnamed: 1_level_0,Number of feature dimensions,Number of feature dimensions,Number of feature dimensions,Number of feature dimensions
Unnamed: 0_level_1,Unnamed: 1_level_1,8,16,32,64
Difference in means,0.01,0.161 (0.0142),0.0899 (0.0054),0.0334 (0.00149),0.00719 (0.000269)
Difference in means,0.1,0.163 (0.0141),0.0907 (0.00539),0.0336 (0.00149),0.00723 (0.000269)
Difference in means,1.0,0.265 (0.011),0.128 (0.00446),0.0417 (0.00136),0.00804 (0.000263)
Difference in means,10.0,0.323 (0.00982),0.139 (0.00436),0.0427 (0.00136),0.00808 (0.000263)
