In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
## Run experiments 
## https://github.com/antoninschrab/mmdagg-paper/blob/master/experiments.py

import numpy as np
import itertools 
import pandas as pd 
import torch
from util import completeMMDVar, compute_mmd_sq, compute_K_matrices

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
sigma0 = 1

In [11]:
import jax.numpy as jnp
from jax import random

def mvn_sample(key, mean, cov, size):
    """Generate samples from a multivariate normal distribution."""
    n = mean.shape[0]
    chol = jnp.linalg.cholesky(cov)
    normal_samples = random.normal(key, (size, n))
    return mean + jnp.dot(normal_samples, chol.T)

key = random.PRNGKey(0)

# First dataset: 5D multivariate normal with no correlation
mean1 = jnp.zeros(5)
cov1 = jnp.eye(5)

# Second dataset: Multivariate normal where the last two dimensions have a correlation of 0.8
mean2 = jnp.zeros(5)
cov2 = jnp.eye(5)
cov2_values = jnp.array(cov2)
cov2_values = cov2_values.at[3, 4].set(0.8)
cov2_values = cov2_values.at[4, 3].set(0.8)
cov2 = cov2_values

# Sample data
X = mvn_sample(key, mean1, cov1, 1000)
Y = mvn_sample(key, mean2, cov2, 1000)

print("Data 1 Shape:", X.shape)
print("Data 2 Shape:", Y.shape)

num_samples = 10000
mmd_samples_list = []

for i in range(num_samples):
    key, subkey1, subkey2 = random.split(key, 3)
    
    # Sample new data
    X = mvn_sample(subkey1, mean1, cov1, 200)
    Y = mvn_sample(subkey2, mean2, cov2, 10)
    
    # Note: You'll have to translate `compute_K_matrices` and `compute_mmd_sq` to use JAX as well
    Kxx, Kyy, Kxy = compute_K_matrices(X, Y, sigma0)
    mmd_value = compute_mmd_sq(Kxx, Kyy, Kxy, len(X), len(Y))
    
    mmd_samples_list.append(mmd_value ** 2)

mmd_samples = jnp.array(mmd_samples_list)

# Calculate the variance
mmd_variance = jnp.var(mmd_samples, ddof=1)  # ddof=1 for unbiased variance
print("Estimated Variance of MMD:", mmd_variance)

Data 1 Shape: (1000, 5)
Data 2 Shape: (1000, 5)
Estimated Variance of MMD: 1.2265769314639631e-07


In [12]:
completeMMDVar(X, Y, sigma0)

Array(0.004493, dtype=float64)