In [2]:
%load_ext autoreload 
%autoreload 2

In [3]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "7"

In [5]:
import math
from tqdm.auto import tqdm

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from pathlib import Path
from utils import compute_K, Ustat_MMD
from mmdvar import IncomMMDVar, ComMMDVar, h1_mean_var_gram

In [6]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu


In [7]:
# Function to generate samples from multivariate normal distribution
def sample_mvn(key, mean, cov, num_samples):
    return jax.random.multivariate_normal(mean=mean, cov=cov, shape=num_samples, key=key)

In [8]:
def MAPE(true, est):
    return jnp.abs(true - jnp.average(est)) / true

In [10]:
##########  Simulation 1  ############
#### decreasing signal as n grows ####
######################################
key = jax.random.PRNGKey(42)

# Parameter Settings 
n_iters = 1000 # Number of times to estimate MMD
sigma0 = 1.0 # Default bandwidth for Gaussian kernel 
d = 5 # dimension of data 

# Array to store MMD values
mmd_samples = jnp.zeros(n_iters)
full = jnp.zeros(n_iters) # Complete MMD Variance Estimate(Ours)
Liu = jnp.zeros(n_iters)  # V-stats MMD variance Estimate(Liu et al. 2020)
suth = jnp.zeros(n_iters)# Incomplete MMD Variance Estimate(Sutherland et al. 2019)
Vm = jnp.zeros(n_iters)
# com_2 = jnp.zeros(n_iters)

full_variance_estimates = []
# complete_variance_estimates = []
sutherland_estimates = []
vm_estimates = []
liu_estimates = []

num_samples = [64, 128, 256, 512, 1024, 2048]

for sample_size in tqdm(num_samples):
    
    print("number of samples : ",sample_size)
    delta = d ** (1/4) / jnp.sqrt(sample_size)
    # First dataset
    mean1 = jnp.zeros(d)
    cov = jnp.eye(d)

    # Second dataset
    mean2 = jnp.zeros(d) + delta
    
    # Monte Carlo simulation to estimate MMD values
    for i in tqdm(range(n_iters)):
        key, subkey1, subkey2 = jax.random.split(key, 3)
        
        # Mean Difference 
        X = sample_mvn(subkey1, mean1, cov, (sample_size,))
        Y = sample_mvn(subkey2, mean2, cov, (sample_size,))

        tKxx, tKyy, Kxy = compute_K(X, Y, sigma0, bias=False)
        Kxx, Kyy, Kxy = compute_K(X, Y, sigma0, bias=True)
        m = Kxx.shape[0]
        n = Kyy.shape[0]
        mmd_value = Ustat_MMD(tKxx, tKyy, Kxy, m, n)

        mmd_samples = mmd_samples.at[i].set(mmd_value)
        full = full.at[i].set(ComMMDVar(tKxx, tKyy, Kxy))
        # com_2 = com_2.at[i].set(compute_unbiased_var (Kxx, Kyy, Kxy, m ,n, complete = False))
        suth = suth.at[i].set(IncomMMDVar(tKxx, tKyy, Kxy))
        Liu = Liu.at[i].set(h1_mean_var_gram(Kxx, Kyy, Kxy, is_var_computed=True, use_1sample_U=True)[1])

    # Compute estimated variance of MMD
    mmd_variance = jnp.var(mmd_samples, ddof=1)
    
    full_variance_estimates.append(MAPE(mmd_variance, full))
    # complete_variance_estimates.append(jnp.abs(mmd_variance - jnp.average(com_2))/mmd_variance)
    sutherland_estimates.append(MAPE(mmd_variance, suth))
    liu_estimates.append(MAPE(mmd_variance, Liu))
    vm_estimates.append(MAPE(mmd_variance, Vm))

    print("The average of MMDu",jnp.average(mmd_samples))
    print("True Variance of MMD:", mmd_variance)

    # Compare the results
    print()
    print("How accurate estimators estimate the variance : estimator/true")
    print(f"Full Variance Estimate (Ours 8): {MAPE(mmd_variance, full)}")
    # print(f"Complete Variance Estimate (Ours 2): {MAPE}")
    print(f"Sutherland (2019) : {MAPE(mmd_variance, suth)}")
    print(f"Incomplete Variance Estimate (Liu et al. 2): {MAPE(mmd_variance, Liu)}")
    print(f"Vm: {MAPE(mmd_variance, Vm)}")
    print()

  0%|          | 0/6 [00:00<?, ?it/s]

number of samples :  64


  0%|          | 0/1000 [00:00<?, ?it/s]

The average of MMDu 0.001209887862551965
True Variance of MMD: 8.413538034923826e-06

How accurate estimators estimate the variance : estimator/true
Full Variance Estimate (Ours 8): 0.08058881685214747
Sutherland (2019) : 0.010705895095924895
Incomplete Variance Estimate (Liu et al. 2): 110.14100479662471
Vm: 1.0

number of samples :  128


  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

# Data
sample_sizes = [10, 100, 500, 1000, 2000]
full_variance_estimates = [0.17655509368585445, 0.11044403689984993, 0.15714277058262002, 0.14202993663589988, 0.05336349313853741]
complete_variance_estimates = [0.945722031121335, 0.8725068453005361, 0.8833989844950265, 0.8757223919743891, 0.8863523431043184]
sutherland_estimates = [0.03326352570744373, 0.02137627443252068, 0.07469605649442351, 0.05894031950457824, 1238.4849737345553]
incomplete_variance_estimates = [0.34708349699653546, 0.7648303734535044, 0.725764638276863, 0.7523154270156995, 0.964991847340291]

# Create plots for each estimator
plt.figure(figsize=(10, 6))
plt.plot(sample_sizes, full_variance_estimates, label='Full Variance Estimate (Ours 8)')
plt.plot(sample_sizes, complete_variance_estimates, label='Complete Variance Estimate (Ours 2)')
plt.plot(sample_sizes, sutherland_estimates, label='Sutherland (2022)')
plt.plot(sample_sizes, incomplete_variance_estimates, label='Incomplete Variance Estimate (Liu et al. 2)')

# Set plot labels and title
plt.xlabel('Sample Size')
plt.ylabel('Estimator/True Variance Ratio')
plt.yscale('log')  # You can choose to use a logarithmic scale if needed
plt.title('Estimator Variance Performance')
plt.legend()

# Show the plot
plt.grid(True)
plt.tight_layout()
plt.show()
