In [1]:
# Reload edited modules (e.g. `util`) automatically before executing each cell.
%load_ext autoreload 
%autoreload 2

In [5]:
import jax
import jax.numpy as jnp
import jax.random as random
from jax.random import multivariate_normal
from util import MMDVar, compute_mmd_sq, compute_K_matrices

In [6]:
import os

# Pin the notebook to GPU 4 unless the environment already selects devices
# (e.g. `CUDA_VISIBLE_DEVICES=0 jupyter lab`), instead of clobbering that choice.
# CUDA reads this variable when JAX first initializes its GPU backend (on first
# device use, not at `import jax`), so this must run before any JAX computation.
# NOTE(review): assumes importing `util` in the previous cell does not touch devices.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "4")

In [9]:
def sampler_normal(key, N1, N2, d, mean_y=0.0, cov_scale_y=1.0):
    """Draw independent Gaussian samples X (N1 x d) and Y (N2 x d).

    X ~ N(0, I_d) and Y ~ N(mean_y * 1_d, cov_scale_y * I_d). With the default
    arguments both samples come from the same distribution N(0, I_d).

    Args:
        key: JAX PRNG key; it is split so X and Y use independent subkeys.
        N1: number of rows of X.
        N2: number of rows of Y.
        d: dimension of each sample point.
        mean_y: scalar (or length-d array) used as Y's mean via
            ``mean_y * ones(d)``. Defaults to 0.0.
        cov_scale_y: scalar multiplying Y's identity covariance. Defaults to 1.0.

    Returns:
        Tuple ``(X, Y)`` with shapes ``(N1, d)`` and ``(N2, d)``.
    """
    # A multiplicative factor on a zero mean (e.g. `1.1 * jnp.zeros((d,))`) is
    # still zero and cannot shift Y; use `mean_y` / `cov_scale_y` for that.
    # The split sequence matches the original so default samples are unchanged.
    _, subkey = random.split(key)
    key_x, key_y = random.split(subkey, num=2)
    X = random.multivariate_normal(key_x, jnp.zeros((d,)), jnp.eye(d), shape=(N1,))
    Y = random.multivariate_normal(
        key_y, mean_y * jnp.ones((d,)), cov_scale_y * jnp.eye(d), shape=(N2,)
    )
    return X, Y

In [11]:
# Experiment 1: Monte Carlo variance of the MMD^2 estimator on samples from
# `sampler_normal` (with its defaults X and Y share the N(0, I) distribution),
# compared with the variance estimates returned by `util.MMDVar`.

# Seed explicitly so this cell does not depend on a `key` left over in the kernel.
key = random.PRNGKey(0)

# Sample sizes and dimension, shared by the sampler and the MMD estimator.
N_X, N_Y, DIM = 200, 10, 5

# Number of times to estimate MMD
num_samples = 10000

# Passed to compute_K_matrices / MMDVar (presumably the kernel bandwidth).
sigma0 = 1

# Monte Carlo simulation to estimate MMD values. Collect scalars in a list and
# convert once: `arr.at[i].set(...)` copies the whole array on every iteration.
mmd_values = []
for _ in range(num_samples):
    key, subkey = random.split(key)
    X, Y = sampler_normal(subkey, N1=N_X, N2=N_Y, d=DIM)
    Kxx, Kyy, Kxy = compute_K_matrices(X, Y, sigma0)
    mmd_values.append(compute_mmd_sq(Kxx, Kyy, Kxy, len(X), len(Y)))
mmd_samples = jnp.asarray(mmd_values)

# Compute estimated variance of MMD (unbiased sample variance over the trials)
mmd_variance = jnp.var(mmd_samples, ddof=1)

print("Estimated Variance of MMD:", mmd_variance)

# Compare the results. MMDVar is evaluated on a single (X, Y) pair -- the one
# drawn in the last Monte Carlo trial -- not averaged over trials.
# NOTE(review): in the saved output these estimates were ~60x larger than the
# Monte Carlo variance; confirm MMDVar and compute_mmd_sq use the same scaling.
print(f"Full Variance Estimate: {MMDVar(X, Y, sigma0)}")
print(f"Incomplete Variance Estimate: {MMDVar(X, Y, sigma0, complete=False)}")

Estimated Variance of MMD: 8.540801317976095e-05
Full Variance Estimate: 0.005245646583268721
Incomplete Variance Estimate: 0.005034097172184765


In [12]:
# Experiment 2

# Seed setting (for reproducible results)
key = random.PRNGKey(0)

# First dataset: X ~ N(0, I_5)
mean1 = jnp.zeros(5)
cov1 = jnp.eye(5)

# Second dataset: Y ~ N(0, Sigma), Sigma = I_5 plus correlation RHO between
# coordinates 3 and 4. A covariance matrix must be symmetric. Asymmetric
# entries (0.3 above / 0.5 below the diagonal) are silently averaged by JAX's
# Cholesky, which symmetrizes its input, to 0.4. Writing 0.4 explicitly keeps
# the sampled distribution identical while making it visible.
RHO = 0.4
mean2 = jnp.zeros(5)
cov2 = jnp.eye(5)
cov2 = cov2.at[3, 4].set(RHO)
cov2 = cov2.at[4, 3].set(RHO)

# Function to generate samples from multivariate normal distribution
def sample_mvn(key, mean, cov, num_samples):
    """Draw samples from the multivariate normal N(mean, cov).

    Args:
        key: JAX PRNG key.
        mean: mean vector of length d.
        cov: (d, d) covariance matrix.
        num_samples: either an int count ``n`` (treated as shape ``(n,)``) or a
            batch-shape tuple such as ``(200,)``.

    Returns:
        Array of shape ``(*batch_shape, d)``.
    """
    # jax.random.multivariate_normal expects a shape sequence, so wrap a bare count.
    shape = (num_samples,) if isinstance(num_samples, int) else num_samples
    return multivariate_normal(key, mean, cov, shape=shape)

# Number of times to estimate MMD
num_samples = 10000

# Passed to compute_K_matrices / MMDVar (presumably the kernel bandwidth).
sigma0 = 1

# Monte Carlo simulation to estimate MMD values. Collect scalars in a list and
# convert once at the end: `arr.at[i].set(...)` copies the full array each time.
# The key-splitting sequence is unchanged, so the sampled values are identical.
mmd_values = []
for _ in range(num_samples):
    key, subkey1, subkey2 = random.split(key, 3)
    X = sample_mvn(subkey1, mean1, cov1, (200,))
    Y = sample_mvn(subkey2, mean2, cov2, (10,))

    Kxx, Kyy, Kxy = compute_K_matrices(X, Y, sigma0)
    mmd_values.append(compute_mmd_sq(Kxx, Kyy, Kxy, len(X), len(Y)))

mmd_samples = jnp.asarray(mmd_values)

# Compute estimated variance of MMD (unbiased sample variance over the trials)
mmd_variance = jnp.var(mmd_samples, ddof=1)

print("Estimated Variance of MMD:", mmd_variance)

# Compare the results. MMDVar is evaluated on a single (X, Y) pair -- the one
# drawn in the last Monte Carlo trial -- not averaged over trials.
print(f"Full Variance Estimate: {MMDVar(X, Y, sigma0)}")
print(f"Incomplete Variance Estimate: {MMDVar(X, Y, sigma0, complete=False)}")