In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "5"

In [3]:
import math
import jax
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt
from jax.random import multivariate_normal
# from sampler_mixture import sampler_mixture 
from tqdm.auto import tqdm
from pathlib import Path
# from utils_debias import MMDVar as MMD_var_u
from utils import MMDVar, compute_mmd_sq, compute_K_matrices, h1_mean_var_gram, MMDu_var, Vstat_MMD

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import jax.numpy as jnp

def compute_reduced_moment(KXY):
    m, n = KXY.shape
    
    # Create KXY' by removing the last column from KXY
    KXY_prime = KXY[:, :-1]

    # Compute the Kronecker product KXY' ⊗ KXY'
    KXY_KXY_kron = jnp.kron(KXY_prime, KXY_prime)

    # Create vectors of ones with shapes compatible with the Kronecker product dimensions
    one_m = jnp.ones((m * (m-1),))
    one_n = jnp.ones((n * (n-1),))

    # Calculate the reduced moment
    reduced_moment = one_m @ KXY_KXY_kron @ one_n
    
    return reduced_moment

# Example usage
KXY = jnp.array([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])  # Replace this with your actual KXY matrix

result = compute_reduced_moment(KXY)
print("The result of the reduced moment is:", result)


TypeError: dot_general requires contracting dimensions to have the same shape, got (6,) and (9,).

In [4]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu


In [5]:
# Function to generate samples from multivariate normal distribution
def sample_mvn(key, mean, cov, num_samples):
    return multivariate_normal(mean=mean, cov=cov, shape=num_samples, key=key)

In [None]:
# decreasing signal as n grows
key = random.PRNGKey(42)


# Number of times to estimate MMD
iteraion = 1000
# Array to store MMD values
mmd_samples = jnp.zeros(iteraion)

full_v = jnp.zeros(iteraion)
com_2_v = jnp.zeros(iteraion)

# u for debias
full_u = jnp.zeros(iteraion)
com_2_u = jnp.zeros(iteraion)
Liu = jnp.zeros(iteraion)


# Sample sizes and variance data
full_v_variances = []
complete_v_variances = []
full_u_variances = []
complete_u_variances = []
liu_variances = []


sigma0 = 1.0

num_samples = [32,64,100,200,300,400,500,600,700,800,900,1000,2000,3000,4000,5000]

for sample_size in num_samples :
    
    print("number of samples : ",sample_size)
    shift = 10 / jnp.sqrt(sample_size)
    mean1 = jnp.zeros(5)
    mean2 = mean1 + shift
    
    cov = jnp.eye(5)
    
    # Monte Carlo simulation to estimate MMD values
    for i in range(iteraion):
        key, subkey1, subkey2 = random.split(key, 3)
        X = sample_mvn(subkey1, mean1, cov, (sample_size,))
        Y = sample_mvn(subkey2, mean2, cov, (sample_size,))

        Kxx, Kyy, Kxy = compute_K_matrices(X, Y, sigma0)
        m = X.shape[0]
        n = Y.shape[0]
        mmd_value = compute_mmd_sq(Kxx, Kyy, Kxy, m, n)
        
        mmd_samples = mmd_samples.at[i].set(mmd_value)
        full_v = full_v.at[i].set(MMDVar(X, Y, sigma0, copmlete=True, bias=False))
        com_2_v = com_2_v.at[i].set(MMDVar(X, Y, sigma0, complete=False))
        Liu = Liu.at[i].set(h1_mean_var_gram(Kxx, Kyy, Kxy, is_var_computed=True, use_1sample_U=True)[1])
    
    # Compute estimated variance of MMD
    mmd_variance = jnp.var(mmd_samples, ddof=1)
    full_v_variances.append(jnp.abs(jnp.abs(jnp.median(full_v))-mmd_variance)/mmd_variance)
    complete_v_variances.append(jnp.abs(jnp.abs(jnp.median(com_2_v))-mmd_variance)/mmd_variance)
    liu_variances.append(jnp.abs(jnp.median(Liu)-mmd_variance)/mmd_variance)
    
    print("The average of MMDu",jnp.average(mmd_samples))
    print("True Variance of MMD:", mmd_variance)

    # Compare the results
    print()
    print("How accurate estimators estimate the variance : estimator/true")
    print(f"Biased Full Variance Estimate (Ours 8): {jnp.abs(jnp.abs(jnp.median(full_v))-mmd_variance)/mmd_variance}")
    print(f"Biased Complete Variance Estimate (Ours 2): {jnp.abs(jnp.abs(jnp.median(com_2_v))-mmd_variance)/mmd_variance}")
    print(f"Incomplete Variance Estimate (Liu et al. 2): {jnp.abs(jnp.median(Liu)-mmd_variance)/mmd_variance}")
    print()

#### Gaussian Mean Difference (Equal Sample Size)

In [None]:
def simulate_mmd(mean1, mean2, cov1, cov2, ratio, num_samples, sigma0):
    mmd_samples = jnp.zeros(100)
    full_vars, complete_vars, incomplete_vars = [], [], []

    key = random.PRNGKey(42) # Set random seed
    for i in range(100):
        key, subkey1, subkey2 = random.split(key, 3)
        X = multivariate_normal(mean=mean1, cov=cov1, shape=(num_samples * ratio,), key=subkey1)
        Y = multivariate_normal(mean=mean2, cov=cov2, shape=(num_samples * ratio,), key=subkey2)

        Kxx, Kyy, Kxy = compute_K_matrices(X, Y, sigma0)
        mmd_value = compute_mmd_sq(Kxx, Kyy, Kxy, len(X), len(Y))
        mmd_samples = mmd_samples.at[i].set(mmd_value)

        full_vars.append(MMDVar(X, Y, sigma0))
        complete_vars.append(MMDVar(X, Y, sigma0, complete=False))
        incomplete_vars.append(h1_mean_var_gram(Kxx, Kyy, Kxy, is_var_computed=True, use_1sample_U=True)[1])

    return mmd_samples, full_vars, complete_vars, incomplete_vars

In [None]:
def save_results_to_json(results, filename="mmd_results.json"):
    with open(filename, 'w') as f:
        # Convert JAX arrays to native Python lists for serialization
        json_results = jax.tree_map(lambda x: x.tolist() if isinstance(x, jnp.ndarray) else x, results)
        json.dump(json_results, f, indent=4)
        

In [None]:
sigma0 = 1.0
num_samples = 100
ratios = [1, 3, 5, 7, 10]
mean_differences = jnp.linspace(0.0001, 0.1, num=5)

cov1 = jnp.eye(5)
cov2 = jnp.eye(5)

results = {}
print("Running...")
for mean_diff in tqdm(mean_differences):
    mean1 = jnp.zeros(5)
    mean2 = jnp.array([mean_diff] * 5)
    
    for ratio in tqdm(ratios):
        mmd_samples, full_vars, complete_vars, incomplete_vars = simulate_mmd(mean1, mean2, cov1, cov2, ratio, num_samples, sigma0)

        true_variance = jnp.var(mmd_samples, ddof=1)
        
        results_key = (float(mean_diff.item()), ratio)
        results[results_key] = {
            'true_variance': true_variance,
            'full_variance_estimate': jnp.mean(jnp.array(full_vars)),
            'complete_variance_estimate': jnp.mean(jnp.array(complete_vars)),
            'incomplete_variance_estimate': jnp.mean(jnp.array(incomplete_vars))
        }
        print(results[results_key])

for (mean_diff, ratio), vals in results.items():
    print(f"Mean Difference: {mean_diff}, Ratio: {ratio}")
    print(f"MMDu: {vals['true_variance']}")
    print(f"True Variance of MMD: {vals['true_variance']}")
    print(f"Full Variance Estimate (Ours 8): {vals['full_variance_estimate']}")
    print(f"Complete Variance Estimate (Ours 2): {vals['complete_variance_estimate']}")
    print(f"Incomplete Variance Estimate (Liu et al. 2): {vals['incomplete_variance_estimate']}")
    print("-------")

# save_results_to_json(results, "mmd_results.json")
# print("Results saved to mmd_results.json")

In [None]:
import matplotlib.pyplot as plt

# Calculate means and variances for each estimation method
full_var_means = [jnp.mean(full_var_estimates[ratio]) for ratio in ratios]
full_var_std = [jnp.std(full_var_estimates[ratio]) for ratio in ratios]

complete_var_means = [jnp.mean(complete_var_estimates[ratio]) for ratio in ratios]
complete_var_std = [jnp.std(complete_var_estimates[ratio]) for ratio in ratios]

complete_var_h_means = [jnp.mean(complete_var_h_estimates[ratio]) for ratio in ratios]
complete_var_h_std = [jnp.std(complete_var_h_estimates[ratio]) for ratio in ratios]

plt.figure(figsize=(12, 8))

# Plot means and variance bands for each method
plt.errorbar(ratios, full_var_means, yerr=full_var_std, fmt='-o', label="Full Variance Estimate")
plt.errorbar(ratios, complete_var_means, yerr=complete_var_std, fmt='-o', label="Complete Variance Estimate")
plt.errorbar(ratios, complete_var_h_means, yerr=complete_var_h_std, fmt='-o', label="Complete Variance Estimate using h")

# If you want to plot the incomplete variance estimate (only for ratio = 1)
if incomplete_var_estimates:
    plt.scatter([1], [incomplete_var_estimates[0]], color='red', marker='x', label="Incomplete Variance Estimate")

plt.title("Variance Estimates against Imbalance Ratio")
plt.xlabel("Imbalance Ratio")
plt.ylabel("Variance Estimate")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
plt.figure(figsize=(10, 6))
plt.plot(ratios[2:], true_mmd_variances[2:], '-o', label='True Variance of MMD')
plt.plot(ratios[2:], full_var_estimates[2:], '-o', label='Full Variance Estimate (Ours 8)')
plt.plot(ratios[2:], complete_var_estimates[2:], '-o', label='Complete Variance Estimate (Ours 2)')
plt.plot(ratios[2:], complete_var_h_estimates[2:], '-o', label='Complete Variance Estimate using h (Ours 2)')
# plt.plot(ratios, incomplete_var_estimates, '-o', label='Incomplete Variance Estimate (Liu et al. 2)')

plt.xlabel('Imbalance Ratio')
plt.ylabel('Variance Estimate')
plt.title('Variance of MMD by Imbalance Ratios')
plt.legend()
plt.grid(True)
plt.show()


#### Equal Sample Size

In [None]:
# Experiment 2
key = random.PRNGKey(42)

# First dataset
mean1 = jnp.zeros(5)
cov1 = jnp.eye(5)

# Second dataset
mean2 = jnp.zeros(5)
cov2 = jnp.eye(5)
cov2 = cov2.at[3, 4].set(0.8)
cov2 = cov2.at[4, 3].set(0.8)

# Function to generate samples from multivariate normal distribution
def sample_mvn(key, mean, cov, num_samples):
    return multivariate_normal(mean=mean, cov=cov, shape=num_samples, key=key)

# Number of times to estimate MMD
num_samples = 1000

# Imbalance ratios 
ratios = [1, 10, 30, 50]

eq_true_mmd_variances = []
eq_full_var_estimates = []
eq_complete_var_estimates = []
eq_complete_var_h_estimates = []
eq_incomplete_var_estimates = []

# Array to store MMD values
mmd_samples = jnp.zeros(num_samples)
sigma0 = 1

# Monte Carlo simulation to estimate MMD values
for ratio in ratios:
    for i in range(num_samples):
        key, subkey1, subkey2 = random.split(key, 3)
        X = sample_mvn(subkey1, mean1, cov1, (10 * ratio,))
        Y = sample_mvn(subkey2, mean2, cov2, (10 * ratio,))

        Kxx, Kyy, Kxy = compute_K_matrices(X, Y, sigma0)
        mmd_value = compute_mmd_sq(Kxx, Kyy, Kxy, len(X), len(Y))

        mmd_samples = mmd_samples.at[i].set(mmd_value)

    # Compute estimated variance of MMD
    mmd_variance = jnp.var(mmd_samples, ddof=1)

    # Store results in the lists
    eq_true_mmd_variances.append(mmd_variance)
    eq_full_var_estimates.append(MMDVar(X, Y, sigma0))
    eq_complete_var_estimates.append(MMDVar(X, Y, sigma0, complete=False))
    eq_complete_var_h_estimates.append(MMDu_var(Kxx, Kyy, Kxy))
    eq_incomplete_var_estimates.append(h1_mean_var_gram(Kxx, Kyy, Kxy, is_var_computed=True, use_1sample_U=True)[1])
    
    print(f"Sample Size: {10 * ratio, 10 * ratio}")
    print(f"MMDu: {mmd_value}")
    print("True Variance of MMD:", mmd_variance)

    # Compare the results
    print(f"Full Variance Estimate (Ours 8): {MMDVar(X, Y, sigma0)}")
    print(f"Complete Variance Estimate (Ours 2): {MMDVar(X, Y, sigma0, complete=False)}")
    print(f"Complete Variance Estimate using h (Ours 2): {MMDu_var(Kxx, Kyy, Kxy)}")
    print(f"Incomplete Variance Estimate (Liu et al. 2): {h1_mean_var_gram(Kxx, Kyy, Kxy, is_var_computed=True, use_1sample_U=True)[1]}")

In [None]:
plt.figure(figsize=(10, 6))
plt.plot(ratios[2:], eq_true_mmd_variances[2:], '-o', label='True Variance of MMD')
plt.plot(ratios[2:], eq_full_var_estimates[2:], '-o', label='Full Variance Estimate (Ours 8)')
plt.plot(ratios[2:], eq_complete_var_estimates[2:], '-o', label='Complete Variance Estimate (Ours 2)')
plt.plot(ratios[2:], eq_complete_var_h_estimates[2:], '-o', label='Complete Variance Estimate using h (Ours 2)')
plt.plot(ratios[2:], eq_incomplete_var_estimates[2:], '-o', label='Incomplete Variance Estimate (Liu et al. 2)')

plt.xlabel('Sample Size')
plt.ylabel('Variance Estimate')
plt.title('Variance of MMD by Different Sample Size')
plt.legend()
plt.grid(True)
plt.show()


In [None]:
import time
import jax.numpy as jnp
import matplotlib.pyplot as plt

sigma0 = 1.0
num_samples = [1000, 1500, 2000, 2500, 3000]

# Arrays to store results
true_variances = []
method1_variances = []
method2_variances = []
method3_variances = []
method4_variances = []
method1_times = []
method2_times = []
method3_times = []
method4_times = []

for N in num_samples:
    mmd_samples = jnp.zeros(100)
    
    key, subkey1, subkey2 = random.split(key, 3)
    X = sample_mvn(subkey1, mean1, cov1, (10,))
    Y = sample_mvn(subkey2, mean2, cov2, (10,))
    Kxx, Kyy, Kxy = compute_K_matrices(X, Y, sigma0)

    mmd_value = compute_mmd_sq(Kxx, Kyy, Kxy, N, N)
    mmd_samples = mmd_samples.at[i].set(mmd_value)

    true_variance = jnp.var(mmd_samples, ddof=1)
    true_variances.append(true_variance)

    start_time = time.time()
    method1_variance = MMDVar(X, Y, sigma0) # Full Variance Estimate
    method1_variances.append(method1_variance)
    method1_time = time.time() - start_time
    method1_times.append(method1_time)

    start_time = time.time()
    method2_variance = MMDVar(X, Y, sigma0, complete=False) # Complete Variance Estimate
    method2_variances.append(method2_variance)
    method2_time = time.time() - start_time
    method2_times.append(method2_time)

    start_time = time.time()
    method3_variance = h1_mean_var_gram(Kxx, Kyy, Kxy, is_var_computed=True, use_1sample_U=True)[1]
    method3_variances.append(method3_variance)
    method3_time = time.time() - start_time
    method3_times.append(method3_time)

    start_time = time.time()
    method4_variance = MMDu_var(Kxx, Kyy, Kxy)
    method4_variances.append(method4_variance)
    method4_time = time.time() - start_time
    method4_times.append(method4_time)

# Display results
for N, t1, t2, t3, t4 in zip(num_samples, method1_times, method2_times, method3_times, method4_times):
    print(f"Sample Size {N}:")
    print(f"Full Variance Estimate(Ours 8) Time: {t1:.4f} seconds")
    print(f"Complete Variance Estimate(Ours 2) Time: {t2:.4f} seconds")
    print(f"Deep MMD(Liu et al. 2) Time: {t3:.4f} seconds")
    print(f"Complete Variance Estimate using h(Ours 2) Time: {t4:.4f} seconds")
    print()

# Visualization
plt.plot(num_samples, true_variances, label='True Variance')
plt.plot(num_samples, method1_variances, label='Full Variance Estimate(Ours 8)')
plt.plot(num_samples, method2_variances, label='Complete Variance Estimate(Ours 2)')
plt.plot(num_samples, method3_variances, label='Deep MMD (Liu et al. 2)')
plt.plot(num_samples, method4_variances, label='Complete Variance Estimate using h(Ours 2)')
plt.xlabel('Sample Size')
plt.ylabel('Variance')
plt.legend()
plt.show()


In [None]:
# Experiment 2

# First dataset
mean1 = jnp.zeros()
cov1 = jnp.eye(100)

# Second dataset
mean2 = jnp.ones(100)
cov2 = jnp.eye(100)
cov2 = cov2.at[3, 4].set(0.3)
cov2 = cov2.at[4, 3].set(0.5)

# Function to generate samples from multivariate normal distribution
def sample_mvn(key, mean, cov, num_samples):
    return multivariate_normal(mean=mean, cov=cov, shape=num_samples, key=key)

# Number of times to estimate MMD
num_samples = 10000

# Array to store MMD values
mmd_samples = jnp.zeros(num_samples)
sigma0 = 1

# Monte Carlo simulation to estimate MMD values
for i in range(num_samples):
    key, subkey1, subkey2 = random.split(key, 3)
    X = sample_mvn(subkey1, mean1, cov1, (200,))
    Y = sample_mvn(subkey2, mean2, cov2, (200,))

    Kxx, Kyy, Kxy = compute_K_matrices(X, Y, sigma0)
    mmd_value = compute_mmd_sq(Kxx, Kyy, Kxy, len(X), len(Y))

    mmd_samples = mmd_samples.at[i].set(mmd_value)

# Compute estimated variance of MMD
mmd_variance = jnp.var(mmd_samples, ddof=1)

print("True Variance of MMD:", mmd_variance)

# Compare the results
print(f"Full Variance Estimate (Ours 8): {MMDVar(X, Y, sigma0)}")
print(f"Complete Variance Estimate (Ours 2): {MMDVar(X, Y, sigma0, complete=False)}")
print(f"Complete Variance Estimate using h (Ours 2): {MMDu_var(Kxx, Kyy, Kxy)}")
print(f"Incomplete Variance Estimate (Liu et al. 2): {h1_mean_var_gram(Kxx, Kyy, Kxy, is_var_computed=True, use_1sample_U=True)[1]}")