In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import jax
import jax.numpy as jnp
import jax.random as random
from utils import MMDVar, compute_mmd_sq, compute_K_matrices, h1_mean_var_gram, MMDu_var
import matplotlib.pyplot as plt
import json
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import os

# Expose only physical GPU 4 to CUDA (CUDA renumbers visible devices from 0).
# NOTE(review): presumably this must happen before JAX initialises its GPU backend
# (i.e. before the first JAX computation in this kernel) to take effect — confirm.
os.environ.update({"CUDA_VISIBLE_DEVICES": "4"})

In [4]:
def sample_blobs_Q(N1, sigma_mx_2, rows=3, cols=3, key=None):
    """Generate Blob-D for testing type-II error (or test power).

    Both samples scatter ``N1`` points over a ``rows x cols`` grid of cluster
    centres ``(r, c)``, with ``r in range(rows)`` and ``c in range(cols)``.
    Each point is assigned to a uniformly random cell.

    * ``X``: each point is ``N(0, 0.03 * I)`` noise shifted to its cell centre.
    * ``Y``: a standard-normal draw ``z`` is mapped to ``z @ L_k + centre_k``,
      where ``L_k`` is the lower Cholesky factor of ``sigma_mx_2[k]``. Here
      ``k = r * cols + c`` is the row-major index of the point's cell.

    Args:
        N1: number of points in each of X and Y.
        sigma_mx_2: stack of at least ``rows * cols`` positive-definite 2x2
            matrices, indexed in row-major cell order. Only the first
            ``rows * cols`` are used.
        rows: number of grid rows (default 3).
        cols: number of grid columns (default 3).
        key: JAX PRNG key; ``random.PRNGKey(0)`` when omitted.

    Returns:
        Tuple ``(X, Y)``, each of shape ``(N1, 2)``.

    Raises:
        ValueError: if ``sigma_mx_2`` holds fewer than ``rows * cols`` matrices.
    """
    if key is None:
        key = random.PRNGKey(0)

    sigma_mx_2 = jnp.asarray(sigma_mx_2)
    n_cells = rows * cols
    if sigma_mx_2.shape[0] < n_cells:
        raise ValueError(
            f"sigma_mx_2 has {sigma_mx_2.shape[0]} matrices but a {rows}x{cols} grid needs {n_cells}"
        )

    mu = jnp.zeros(2)
    sigma = jnp.eye(2) * 0.03

    # The key-splitting order matches the original implementation, so a given key
    # reproduces exactly the same draws as before.
    key, subkey1, subkey2, subkey3, subkey4 = random.split(key, 5)

    X = random.multivariate_normal(subkey1, mean=mu, cov=sigma, shape=(N1,))
    Y = random.multivariate_normal(subkey2, mean=mu, cov=jnp.eye(2), shape=(N1,))

    # Shift each X point onto a random grid cell.
    X = X.at[:, 0].add(random.randint(subkey3, (N1,), 0, rows))
    X = X.at[:, 1].add(random.randint(subkey4, (N1,), 0, cols))

    key, subkey5, subkey6 = random.split(key, 3)
    Y_row = random.randint(subkey5, (N1,), 0, rows)
    Y_col = random.randint(subkey6, (N1,), 0, cols)

    # Row-major cell centres. For the default 3x3 grid this is exactly
    # [[0,0],[0,1],[0,2],[1,0],...,[2,2]].
    locs = jnp.array([[r, c] for r in range(rows) for c in range(cols)])
    chol = jnp.linalg.cholesky(sigma_mx_2[:n_cells])  # (n_cells, 2, 2), lower factors
    cell = Y_row * cols + Y_col                        # row-major cell index per point

    # Each Y point is transformed exactly once, by its own cell's factor.
    # NOTE(review): for row vectors z ~ N(0, I), cov(z @ L) = L^T L, which is not
    # sigma_mx_2[k] = L L^T when the off-diagonal is non-zero; z @ L.T would give
    # sigma_mx_2[k]. This is kept as-is so the data match the reference Blob benchmark.
    Y = jnp.einsum("ni,nij->nj", Y, chol[cell]) + locs[cell]

    return X, Y

In [5]:
# Per-cell 2x2 matrices for Q (row-major 3x3 grid). Every cell has 0.03 on the
# diagonal. The off-diagonal term is -0.020, -0.022, -0.024, -0.026 for cells 0-3,
# 0 for the centre cell 4, and +0.020, +0.022, +0.024, +0.026 for cells 5-8.
sigma_mx_2_standard = jnp.array([[0.03, 0], [0, 0.03]])

off_diagonals = []
for i in range(9):
    if i < 4:
        off_diagonals.append(-0.02 - 0.002 * i)
    elif i == 4:
        off_diagonals.append(0.00)
    else:
        off_diagonals.append(0.02 + 0.002 * (i - 5))

sigma_mx_2 = jnp.stack(
    [sigma_mx_2_standard.at[0, 1].set(rho).at[1, 0].set(rho) for rho in off_diagonals]
)

In [None]:
def simulate_mmd(sigma_mx_2, ratio, num_samples, sigma0, n_trials=100, seed=42, sampler=None):
    """Monte-Carlo the MMD estimate and three variance estimates over repeated draws.

    Each trial draws a fresh ``(X, Y)`` pair with ``num_samples * ratio`` points
    per sample. It records the value returned by ``compute_mmd_sq`` on the kernel
    matrices, plus three variance estimates from the ``utils`` helpers.

    Args:
        sigma_mx_2: per-cell matrices forwarded to ``sample_blobs_Q``. Ignored when
            a custom ``sampler`` is given.
        ratio: multiplier on ``num_samples``. The per-sample size is
            ``num_samples * ratio``.
        num_samples: base sample size.
        sigma0: kernel parameter forwarded to ``compute_K_matrices`` and ``MMDVar``.
        n_trials: number of independent draws (default 100).
        seed: seed for the PRNG key chain (default 42).
        sampler: optional callable ``(n, key) -> (X, Y)``. Defaults to Blob data
            via ``sample_blobs_Q(n, sigma_mx_2, key=key)``.

    Returns:
        ``(mmd_samples, full_vars, complete_vars, incomplete_vars)``:
        ``mmd_samples`` is a length-``n_trials`` array; the other three are lists
        holding one estimate per trial.
    """
    if sampler is None:
        def sampler(n, key):
            return sample_blobs_Q(n, sigma_mx_2, key=key)

    n = num_samples * ratio
    mmd_samples = jnp.zeros(n_trials)
    full_vars, complete_vars, incomplete_vars = [], [], []

    key = random.PRNGKey(seed)
    for trial in range(n_trials):
        key, subkey = random.split(key)
        X, Y = sampler(n, subkey)

        Kxx, Kyy, Kxy = compute_K_matrices(X, Y, sigma0)
        mmd_value = compute_mmd_sq(Kxx, Kyy, Kxy, len(X), len(Y))
        mmd_samples = mmd_samples.at[trial].set(mmd_value)

        full_vars.append(MMDVar(X, Y, sigma0))
        # NOTE(review): stored as "complete" but computed with complete=False. Confirm
        # against MMDVar's definition that the list name and the printed label match.
        complete_vars.append(MMDVar(X, Y, sigma0, complete=False))
        # Only element [1] of h1_mean_var_gram's result is kept (presumably the variance).
        incomplete_vars.append(
            h1_mean_var_gram(Kxx, Kyy, Kxy, is_var_computed=True, use_1sample_U=True)[1]
        )

    return mmd_samples, full_vars, complete_vars, incomplete_vars

In [7]:
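# Commented-out Blob variant of the experiment below.
# NOTE(review): it passes N1 = num_samples * ratio * 9 to simulate_mmd, which multiplies by ratio again.
# ratios = jnp.array([1, 2, 3, 4, 5])  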
# num_samples = 10 
# results = {}
# print("Running...")
# for ratio in tqdm(ratios):
#     N1 = num_samples * ratio * 9
#     # Use the sample_blobs_Q function to generate your datasets
#     X, Y = sample_blobs_Q(N1, sigma_mx_2)
#     mmd_samples, full_vars, complete_vars, incomplete_vars = simulate_mmd(sigma_mx_2, ratio, N1, sigma0=1.0)
#     true_variance = jnp.var(mmd_samples, ddof=1)
#     results[str(ratio)] = {
#             'MMDu': jnp.mean(jnp.array(mmd_samples)),
#             'true_variance': true_variance,
#             'full_variance_estimate': jnp.mean(jnp.array(full_vars)),
#             'complete_variance_estimate': jnp.mean(jnp.array(complete_vars)),
#             'incomplete_variance_estimate': jnp.mean(jnp.array(incomplete_vars))
#         }

# for ratio, vals in results.items():
#     print(f"Ratio: {ratio}")
#     print(f"MMDu: {vals['MMDu']}")
#     print(f"True Variance of MMD: {vals['true_variance']}")
#     print(f"Full Variance Estimate (Ours 8): {vals['full_variance_estimate']}")
#     print(f"Complete Variance Estimate (Ours 2): {vals['complete_variance_estimate']}")
#     print(f"Incomplete Variance Estimate (Liu et al. 2): {vals['incomplete_variance_estimate']}")
#     print("-----"*10)

In [9]:
from jax.scipy.stats import norm, expon

def sample_mixture_of_gaussians(N, key):
    """Draw ``N`` 1-D points from an equal split of N(0, 1) and N(5, 1).

    The first ``N // 2`` values come from N(0, 1) and the remaining ``N - N // 2``
    from N(5, 1). Component membership is fixed by position, not sampled, and the
    result is not shuffled.
    """
    n_first = N // 2
    key_low, key_high = random.split(key, 2)
    low_component = random.normal(key_low, shape=(n_first,))
    high_component = random.normal(key_high, shape=(N - n_first,)) + 5
    return jnp.concatenate([low_component, high_component])

def sample_exponential_shift(N, key, shift=5):
    """Draw ``N`` values from a unit-rate exponential, each offset by ``shift``."""
    unit_exponential = random.exponential(key, shape=(N,))
    return unit_exponential + shift


# Sample-size multipliers. Plain ints keep every derived size a concrete Python int.
ratios = [1, 2, 3, 4, 5]
num_samples = 1000
results = {}
print("Running...")
for ratio in tqdm(ratios):
    # simulate_mmd draws its own data with num_samples * ratio points per sample, so
    # pass the *base* size. The previous call passed N1 = num_samples * ratio, which
    # squared the ratio (25 000 points per sample at ratio 5).
    # TODO: simulate_mmd samples Blob data via sample_blobs_Q. The mixture-of-Gaussians
    # vs shifted-exponential samplers defined above are not used by this experiment
    # (the X, Y previously drawn here were discarded). Wire them in via a sampler
    # before reporting results for that setting.
    mmd_samples, full_vars, complete_vars, incomplete_vars = simulate_mmd(
        sigma_mx_2, ratio, num_samples, sigma0=1.0
    )
    results[str(ratio)] = {
        'MMDu': jnp.mean(mmd_samples),
        # Empirical variance across the Monte-Carlo draws serves as the reference value.
        'true_variance': jnp.var(mmd_samples, ddof=1),
        'full_variance_estimate': jnp.mean(jnp.array(full_vars)),
        'complete_variance_estimate': jnp.mean(jnp.array(complete_vars)),
        'incomplete_variance_estimate': jnp.mean(jnp.array(incomplete_vars)),
    }
    print(results[str(ratio)])

for ratio, vals in results.items():
    print(f"Ratio: {ratio}")
    print(f"MMDu: {vals['MMDu']}")
    print(f"True Variance of MMD: {vals['true_variance']}")
    print(f"Full Variance Estimate (Ours 8): {vals['full_variance_estimate']}")
    print(f"Complete Variance Estimate (Ours 2): {vals['complete_variance_estimate']}")
    print(f"Incomplete Variance Estimate (Liu et al. 2): {vals['incomplete_variance_estimate']}")
    print("-----"*10)


Running...


  0%|          | 0/5 [00:00<?, ?it/s]

 20%|██        | 1/5 [03:14<12:58, 194.70s/it]

{'MMDu': Array(9.19976127e-05, dtype=float64), 'true_variance': Array(6.87780158e-07, dtype=float64), 'full_variance_estimate': Array(2.41551235e-06, dtype=float64), 'complete_variance_estimate': Array(1.27157332e-06, dtype=float64), 'incomplete_variance_estimate': Array(1.28013743e-06, dtype=float64)}


 20%|██        | 1/5 [13:40<54:41, 820.42s/it]


KeyboardInterrupt: 