In [2]:
from mmd_fuse import *
import numpy as np
import matplotlib.pyplot as plt
import pickle

In [3]:
def generate_cov_matrix(n_clusters, d):
    """Build the cluster means and covariance matrices used by the samplers.

    Parameters
    ----------
    n_clusters : int
        Number of mixture components; one mean row is produced per component.
    d : int
        Dimensionality. Must be at least 2, because entries (0, 1) and (1, 0)
        of the correlated covariances are written.

    Returns
    -------
    mu_mx : ndarray of shape (n_clusters, d)
        Row ``i`` is the constant vector ``0.5 * i``.
    sigma_mx_1 : ndarray of shape (d, d)
        Identity covariance.
    sigma_mx_2 : list of two ndarrays of shape (d, d)
        Identity matrices whose (0, 1)/(1, 0) entries are set to +0.5 for the
        first matrix and -0.5 for the second.

    Notes
    -----
    Exactly two correlated covariances are returned regardless of
    ``n_clusters``.
    """
    # Broadcasting a column of offsets fills every row i with 0.5 * i.
    offsets = 0.5 * np.arange(n_clusters)
    mu_mx = np.zeros([n_clusters, d]) + offsets[:, None]

    sigma_mx_1 = np.eye(d)

    # Same identity base, opposite-sign correlation between the first two
    # coordinates.
    sigma_mx_2 = []
    for rho in (0.5, -0.5):
        cov = np.eye(d)
        cov[0, 1] = cov[1, 0] = rho
        sigma_mx_2.append(cov)

    return mu_mx, sigma_mx_1, sigma_mx_2

def sample_hdgm_semi_t2(n_train, n_test, d=10, n_clusters=2, kk=0):
    """Draw two Gaussian-mixture samples with different covariances.

    Both samples use the cluster means from ``generate_cov_matrix``. The first
    sample uses the identity covariance for every cluster; the second uses the
    per-cluster correlated covariance ``sigma_mx_2[i]``.

    Parameters
    ----------
    n_train, n_test : int
        Per-cluster counts; each cluster contributes ``n_train + n_test`` rows.
    d : int
        Dimensionality of each point.
    n_clusters : int
        Number of clusters. ``generate_cov_matrix`` only supplies two
        correlated covariances, so values above 2 fail when indexing them.
    kk : int or array-like of ints
        Repetition identifier folded into every seed passed to
        ``np.random.seed``.

    Returns
    -------
    tuple of four ndarrays
        ``(s1_train, s1_test, s2_train, s2_test)``. The train rows are a random
        subset of ``n_train * n_clusters`` row indices drawn over the whole
        stacked sample (not stratified per cluster); the test rows are the
        remaining indices in ascending order. The same indices are applied to
        both samples.

    Notes
    -----
    Reseeds NumPy's global RNG several times as a side effect.

    NOTE(review): with ``kk=0`` the seed used for the first sample's cluster
    ``i + 1`` (``i + 1 + n``) equals the seed used for the second sample's
    cluster ``i`` (``i + n + 1``), so those draws share one random stream.
    Confirm whether that dependence is acceptable.
    """
    means, cov_first, covs_second = generate_cov_matrix(n_clusters, d)

    per_cluster = n_train + n_test
    total_rows = per_cluster * n_clusters

    sample_first = np.zeros([total_rows, d])
    sample_second = np.zeros([total_rows, d])

    # One global (unstratified) train/test split shared by both samples.
    np.random.seed(seed=1102*kk)
    tr_idx = np.random.choice(total_rows, n_train*n_clusters, replace=False)

    # Everything not picked for training, in ascending index order.
    held_out = np.ones(total_rows, dtype=bool)
    held_out[tr_idx] = False
    te_idx = np.flatnonzero(held_out)

    for c in range(n_clusters):
        block = slice(c*per_cluster, (c + 1)*per_cluster)

        np.random.seed(seed=1102*kk + c + per_cluster)
        sample_first[block, :] = np.random.multivariate_normal(
            means[c], cov_first, per_cluster)

        np.random.seed(seed=819*kk + c + per_cluster + 1)
        sample_second[block, :] = np.random.multivariate_normal(
            means[c], covs_second[c], per_cluster)

    return (sample_first[tr_idx], sample_first[te_idx],
            sample_second[tr_idx], sample_second[te_idx])

def sample_hdgm_semi_t1(n_train, n_test, d=10, n_clusters=2, kk=0):
    """Draw two Gaussian-mixture samples from the same distribution.

    Both samples use the cluster means and the identity covariance returned by
    ``generate_cov_matrix``; they differ only in the seeds used for the draws.

    Parameters
    ----------
    n_train, n_test : int
        Per-cluster counts; each cluster contributes ``n_train + n_test`` rows.
    d : int
        Dimensionality of each point.
    n_clusters : int
        Number of clusters.
    kk : int or array-like of ints
        Repetition identifier folded into every seed passed to
        ``np.random.seed``.

    Returns
    -------
    tuple of four ndarrays
        ``(s1_train, s1_test, s2_train, s2_test)``. The train rows are a random
        subset of ``n_train * n_clusters`` row indices drawn over the whole
        stacked sample (not stratified per cluster); the test rows are the
        remaining indices in ascending order. The same indices are applied to
        both samples.

    Notes
    -----
    Reseeds NumPy's global RNG several times as a side effect.

    NOTE(review): with ``kk=0`` the seed used for the first sample's cluster
    ``i + 1`` (``i + 1 + n``) equals the seed used for the second sample's
    cluster ``i`` (``i + n + 1``), so those draws are shifted copies of each
    other rather than independent. Confirm whether that is acceptable.
    """
    means, shared_cov, _ = generate_cov_matrix(n_clusters, d)

    per_cluster = n_train + n_test
    total_rows = per_cluster * n_clusters

    sample_first = np.zeros([total_rows, d])
    sample_second = np.zeros([total_rows, d])

    # One global (unstratified) train/test split shared by both samples.
    np.random.seed(seed=1102*kk)
    tr_idx = np.random.choice(total_rows, n_train*n_clusters, replace=False)

    # Everything not picked for training, in ascending index order.
    held_out = np.ones(total_rows, dtype=bool)
    held_out[tr_idx] = False
    te_idx = np.flatnonzero(held_out)

    for c in range(n_clusters):
        block = slice(c*per_cluster, (c + 1)*per_cluster)

        np.random.seed(seed=1102*kk + c + per_cluster)
        sample_first[block, :] = np.random.multivariate_normal(
            means[c], shared_cov, per_cluster)

        np.random.seed(seed=819*kk + c + per_cluster + 1)
        sample_second[block, :] = np.random.multivariate_normal(
            means[c], shared_cov, per_cluster)

    return (sample_first[tr_idx], sample_first[te_idx],
            sample_second[tr_idx], sample_second[te_idx])

In [4]:
# Single draw from the different-covariance sampler (repetition id 0),
# with equal per-cluster train and test sizes.
n_train = n_test = 500
s1_tr, s1_te, s2_tr, s2_te = sample_hdgm_semi_t2(n_train, n_test, kk=0)

# Stack the training portions of both samples row-wise.
S = np.vstack((s1_tr, s2_tr))

In [5]:
# Sweep over per-cluster sample sizes using the different-covariance sampler.
# For each size, run 100 repetitions and record what `mmdfuse` returns
# (presumably a 0/1 rejection indicator, so the printed mean would be the
# empirical rejection rate -- verify against mmd_fuse).
# `random` is expected to come from the star-import of mmd_fuse (its
# PRNGKey/split usage looks like jax.random -- confirm).
n_list = [125, 250, 500, 750, 1000, 1250]
summary = []
for n in n_list:
    n_train = n_test = n

    # The key is reset for every sample size, so each size sees the same
    # sequence of subkeys.
    key = random.PRNGKey(42)
    outputs = []
    for i in range(100):
        # First subkey seeds the data generation for this repetition.
        key, subkey = random.split(key)
        s1_tr, s1_te, s2_tr, s2_te = sample_hdgm_semi_t2(n_train, n_test, kk=subkey)

        # The test consumes the full samples: train and test parts are re-joined.
        S1 = np.vstack((s1_tr, s1_te))
        S2 = np.vstack((s2_tr, s2_te))

        # Second subkey is handed to the test itself.
        key, subkey = random.split(key)
        outputs.append(mmdfuse(S1, S2, subkey))

    print(np.mean(outputs), np.std(outputs))

    summary.append(outputs)
    


0.06 0.23748684174075832
0.12 0.32496153618543844
0.38 0.48538644398046393
0.75 0.4330127018922193
0.91 0.2861817604250837
0.99 0.09949874371066199


In [17]:
# Mean of whichever `outputs` list is currently in the kernel, i.e. the
# repetitions of the last sample size from the most recently run sweep.
# NOTE(review): the saved display value (0.15) matches neither the last value
# printed by the sweep above (0.99) nor the sweep below (0.0), and the
# execution count (In [17]) is out of order -- this output looks stale;
# re-run after a kernel restart.
np.mean(outputs)

0.15

In [7]:
# Persist the per-sample-size result lists held in `summary`.
# The "result/" directory must already exist, otherwise open() raises.
# NOTE(review): `summary` is whatever the last executed sweep cell left behind;
# confirm it is the intended one before relying on this file.
with open("result/mmdfuse_HGDM_baselin_result.pkl", "wb") as out_file:
    pickle.dump(summary, out_file)

In [5]:
# Sweep over per-cluster sample sizes using the same-distribution sampler
# (`sample_hdgm_semi_t1`). For each size, run 100 repetitions and record what
# `mmdfuse` returns (presumably a 0/1 rejection indicator, so the printed mean
# would be the empirical false-rejection rate -- verify against mmd_fuse).
# `random` is expected to come from the star-import of mmd_fuse (its
# PRNGKey/split usage looks like jax.random -- confirm).
# Note: this cell rebinds `summary`, replacing any earlier sweep's results.
n_list = [125, 250, 500, 750, 1000, 1250]
summary = []
for n in n_list:
    # Use the swept size. These were previously hard-coded to 500, which left
    # the loop variable unused and made all six iterations run the exact same
    # experiment.
    n_train = n
    n_test = n

    # The key is reset for every sample size, so each size sees the same
    # sequence of subkeys.
    key = random.PRNGKey(42)
    outputs = []
    for i in range(100):
        # First subkey seeds the data generation for this repetition.
        key, subkey = random.split(key)
        s1_tr, s1_te, s2_tr, s2_te = sample_hdgm_semi_t1(
            n_train, n_test, kk=subkey)

        # The test consumes the full samples: train and test parts are re-joined.
        S1 = np.concatenate((s1_tr, s1_te), axis=0)
        S2 = np.concatenate((s2_tr, s2_te), axis=0)

        # Second subkey is handed to the test itself.
        key, subkey = random.split(key)
        outputs.append(mmdfuse(S1, S2, subkey))

    print(np.mean(outputs), np.std(outputs))

    summary.append(outputs)

0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
0.0 0.0
