In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import os 
os.chdir('/home/oldrain123/MMD/')
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

In [3]:
import numpy as np
import torch
from tqdm.auto import tqdm
from sklearn.utils import check_random_state
from sampler.sampler_perturbations import sampler_perturbations
from utils_HD import MatConvert, MMDu, TST_MMD_u

In [4]:
class ModelLatentF(torch.nn.Module):
    """Latent space for both domains."""

    def __init__(self, x_in, H, x_out):
        """Init latent features."""
        super(ModelLatentF, self).__init__()
        self.restored = False

        self.latent = torch.nn.Sequential(
            torch.nn.Linear(x_in, H, bias=True),
            torch.nn.Softplus(),
            torch.nn.Linear(H, H, bias=True),
            torch.nn.Softplus(),
            torch.nn.Linear(H, H, bias=True),
            torch.nn.Softplus(),
            torch.nn.Linear(H, x_out, bias=True),
        )
    def forward(self, input):
        """Forward the LeNet."""
        fealant = self.latent(input)
        return fealant

In [5]:
# Setup for experiments
dtype = torch.float
device = torch.device("cuda:0")
N_per = 100 # permutation times
alpha = 0.05 # test threshold
scale = 0.2
n_list = [500, 1000, 1500, 2000, 2500, 3000]
x_in = 1
H = 50 
x_out = 50
d = 1
number_perturbations = 2
learning_rate = 0.005
N_epoch = 1000 # number of training epochs
K = 10 # number of trails
N = 100 # number of test sets
N_f = 100.0 # number of test sets (float)
is_cuda = True

In [6]:
for n in n_list:
    N1 = n 
    np.random.seed(1102)
    torch.manual_seed(1102)
    torch.cuda.manual_seed(1102)
    Results = np.zeros([1, K])
    Opt_ep = np.zeros([1, K])
    J_star_u = np.zeros([K, N_epoch])
    ep_OPT = np.zeros([K])
    s_OPT = np.zeros([K])
    s0_OPT = np.zeros([K])
    # Repeat experiments K times (K = 10) and report average test power (rejection rate)
    for kk in tqdm(range(K),desc="Experiment"):
        # Initialize parameters
        if is_cuda:
            model_u = ModelLatentF(x_in, H, x_out).cuda()
        else:
            model_u = ModelLatentF(x_in, H, x_out)
        epsilonOPT = MatConvert(np.random.rand(1) * (10 ** (-10)), device, dtype)
        epsilonOPT.requires_grad = True
        sigmaOPT = MatConvert(np.sqrt(np.random.rand(1) * 0.3), device, dtype)
        sigmaOPT.requires_grad = True
        sigma0OPT = MatConvert(np.sqrt(np.random.rand(1) * 0.002), device, dtype)
        sigma0OPT.requires_grad = True
        # Setup optimizer for training deep kernel
        optimizer_u = torch.optim.Adam(list(model_u.parameters())+[epsilonOPT]+[sigmaOPT]+[sigma0OPT], lr=learning_rate) #
        # Generate Blob-D
        np.random.seed(seed=112 * kk + 1 + n)
        s1,s2 = sampler_perturbations(m=n, n=n, d=d, scale=scale, number_perturbations=number_perturbations, seed=1102)
        # REPLACE above line with
        # s1,s2 = sample_blobs(N1)
        # for validating type-I error (s1 ans s2 are from the same distribution)
        if kk==0:
            s1_o = s1
            s2_o = s2
        S = np.concatenate((s1, s2), axis=0)
        S = MatConvert(S, device, dtype)
        # Train deep kernel to maximize test power
        np.random.seed(seed=1102)
        torch.manual_seed(1102)
        torch.cuda.manual_seed(1102)
        for t in range(N_epoch):
            # Compute epsilon, sigma and sigma_0
            ep = torch.exp(epsilonOPT)/(1+torch.exp(epsilonOPT))
            sigma = sigmaOPT ** 2
            sigma0_u = sigma0OPT ** 2
            # Compute output of the deep network
            modelu_output = model_u(S)
            # Compute J (STAT_u)
            TEMP = MMDu(modelu_output, N1, S, sigma, sigma0_u, ep, complete=True)
            mmd_value_temp = -1 * TEMP[0]
            mmd_std_temp = torch.sqrt(TEMP[1] + 10**(-6))
            # STAT_u = torch.div(mmd_value_temp, mmd_std_temp)
            STAT_u = mmd_value_temp - 1e-3 * mmd_std_temp
            J_star_u[kk, t] = STAT_u.item()
            # Initialize optimizer and Compute gradient
            optimizer_u.zero_grad()
            STAT_u.backward(retain_graph=True)
            # Update weights using gradient descent
            optimizer_u.step()
            # Print MMD, std of MMD and J
            if t % 100 == 0:
                print("mmd_value: ", -1 * mmd_value_temp.item(), "mmd_std: ", mmd_std_temp.item(), "Statistic J: ",
                      -1 * STAT_u.item())
        h_u, threshold_u, mmd_value_u = TST_MMD_u(model_u(S), N_per, N1, S, sigma, sigma0_u, ep, alpha, device,
                                                  dtype, complete=True)
        ep_OPT[kk] = ep.item()
        s_OPT[kk] = sigma.item()
        s0_OPT[kk] = sigma0_u.item()
        print(ep, epsilonOPT)
        # Compute test power of deep kernel based MMD
        H_u = np.zeros(N)
        T_u = np.zeros(N)
        M_u = np.zeros(N)
        np.random.seed(1102)
        count_u = 0
        for k in tqdm(range(N), desc="testing"):
            # Generate Blob-D
            np.random.seed(seed=11 * k + 10 + n)
            s1,s2 = sampler_perturbations(m=n, n=n, d=d, scale=scale, number_perturbations=number_perturbations, seed=11*k+10+n)
            # REPLACE above line with
            # s1,s2 = sample_blobs(N1)
            # for validating type-I error (s1 ans s2 are from the same distribution)
            S = np.concatenate((s1, s2), axis=0)
            S = MatConvert(S, device, dtype)
            # Run two sample test (deep kernel) on generated data
            h_u, threshold_u, mmd_value_u = TST_MMD_u(model_u(S), N_per, N1, S, sigma, sigma0_u, ep, alpha, device, dtype, complete=True)
            # Gather results
            count_u = count_u + h_u
            print("MMD-DK:", count_u, "Threshold:", threshold_u, "MMD_value:", mmd_value_u)
            H_u[k] = h_u
            T_u[k] = threshold_u
            M_u[k] = mmd_value_u
        # Print test power of MMD-D
        print("n =",str(n),"--- Test Power of MMD-D: ", H_u.sum()/N_f)
        Results[0, kk] = H_u.sum() / N_f
        print("n =",str(n),"--- Test Power of MMD-D (K times): ",Results[0])
        print("n =",str(n),"--- Average Test Power of MMD-D: ",Results[0].sum()/(kk+1))
    np.save('./Results_Blob_'+str(n)+'_H1_MMD-D',Results)

Experiment:   0%|          | 0/10 [00:00<?, ?it/s]

mmd_value:  0.0024214982986450195 mmd_std:  0.0021556988550694364 Statistic J:  0.002423653997500089
mmd_value:  0.01057770848274231 mmd_std:  0.004789912317759402 Statistic J:  0.010582498395060069
mmd_value:  0.01063382625579834 mmd_std:  0.004808308887211932 Statistic J:  0.010638634564685551
mmd_value:  0.010718569159507751 mmd_std:  0.004825286696613669 Statistic J:  0.010723394446204366
mmd_value:  0.010803118348121643 mmd_std:  0.004847129067597997 Statistic J:  0.01080796547718924
mmd_value:  0.010892987251281738 mmd_std:  0.004872685326004242 Statistic J:  0.010897859936607742
mmd_value:  0.010994315147399902 mmd_std:  0.004910347039425082 Statistic J:  0.010999225494439328
mmd_value:  0.011095628142356873 mmd_std:  0.004947953617188046 Statistic J:  0.011100576095974061
mmd_value:  0.011188700795173645 mmd_std:  0.004988233088939093 Statistic J:  0.011193689028262585
mmd_value:  0.01127845048904419 mmd_std:  0.005044742497731711 Statistic J:  0.01128349523154192
tensor([0.080

testing:   0%|          | 0/100 [00:00<?, ?it/s]

MMD-DK: 0 Threshold: NaN MMD_value: 0.0018687695264816284
MMD-DK: 0 Threshold: NaN MMD_value: 0.0029642432928085327
MMD-DK: 0 Threshold: NaN MMD_value: 0.00016982853412628174
MMD-DK: 0 Threshold: NaN MMD_value: 0.0024282485246658325
MMD-DK: 1 Threshold: 0.0042745620012283325 MMD_value: 0.005540922284126282
MMD-DK: 1 Threshold: NaN MMD_value: 0.0019452869892120361
MMD-DK: 1 Threshold: NaN MMD_value: 0.0018497109413146973
MMD-DK: 1 Threshold: NaN MMD_value: 0.0018939226865768433
MMD-DK: 1 Threshold: NaN MMD_value: -0.001065775752067566
MMD-DK: 1 Threshold: NaN MMD_value: -0.0006405115127563477
MMD-DK: 1 Threshold: NaN MMD_value: 0.0009906142950057983
MMD-DK: 1 Threshold: NaN MMD_value: 0.0004178732633590698
MMD-DK: 2 Threshold: 0.00369320809841156 MMD_value: 0.010013937950134277
MMD-DK: 2 Threshold: NaN MMD_value: -0.0006597787141799927
MMD-DK: 2 Threshold: NaN MMD_value: -0.000104561448097229
MMD-DK: 2 Threshold: NaN MMD_value: 0.0006604939699172974
MMD-DK: 2 Threshold: NaN MMD_value: 0

testing:   0%|          | 0/100 [00:00<?, ?it/s]

MMD-DK: 0 Threshold: NaN MMD_value: 2.168118953704834e-05
MMD-DK: 1 Threshold: 0.003725200891494751 MMD_value: 0.0033928751945495605
MMD-DK: 1 Threshold: NaN MMD_value: -0.00016160309314727783
MMD-DK: 1 Threshold: NaN MMD_value: 0.000902518630027771
MMD-DK: 2 Threshold: 0.0033742189407348633 MMD_value: 0.0066366493701934814
MMD-DK: 2 Threshold: NaN MMD_value: 0.0017336905002593994
MMD-DK: 2 Threshold: NaN MMD_value: 0.002062380313873291
MMD-DK: 3 Threshold: 0.003934696316719055 MMD_value: 0.003510013222694397
MMD-DK: 3 Threshold: NaN MMD_value: -0.001594245433807373
MMD-DK: 3 Threshold: NaN MMD_value: -0.0011935234069824219
MMD-DK: 3 Threshold: NaN MMD_value: 0.0012465566396713257
MMD-DK: 3 Threshold: NaN MMD_value: 0.0012897104024887085
MMD-DK: 4 Threshold: 0.002737075090408325 MMD_value: 0.009496763348579407
MMD-DK: 4 Threshold: NaN MMD_value: -0.0021296441555023193
MMD-DK: 4 Threshold: NaN MMD_value: -0.0012970566749572754
MMD-DK: 4 Threshold: NaN MMD_value: -0.00015866756439208984


testing:   0%|          | 0/100 [00:00<?, ?it/s]

MMD-DK: 0 Threshold: NaN MMD_value: 0.00021754205226898193
MMD-DK: 0 Threshold: NaN MMD_value: 0.003310278058052063
MMD-DK: 0 Threshold: NaN MMD_value: -9.472668170928955e-05
MMD-DK: 0 Threshold: NaN MMD_value: 0.0010477304458618164
MMD-DK: 1 Threshold: 0.002450913190841675 MMD_value: 0.0060079991817474365
MMD-DK: 1 Threshold: NaN MMD_value: 0.001810312271118164
MMD-DK: 1 Threshold: NaN MMD_value: 0.0018454194068908691
MMD-DK: 2 Threshold: 0.003255710005760193 MMD_value: 0.0032310187816619873
MMD-DK: 2 Threshold: NaN MMD_value: -0.0014673620462417603
MMD-DK: 2 Threshold: NaN MMD_value: -0.001038089394569397
MMD-DK: 2 Threshold: NaN MMD_value: 0.0011056959629058838
MMD-DK: 2 Threshold: NaN MMD_value: 0.0012513846158981323
MMD-DK: 3 Threshold: 0.003316521644592285 MMD_value: 0.009685426950454712
MMD-DK: 3 Threshold: NaN MMD_value: -0.001837015151977539
MMD-DK: 3 Threshold: NaN MMD_value: -0.0010136663913726807
MMD-DK: 3 Threshold: NaN MMD_value: -5.938112735748291e-05
MMD-DK: 4 Threshold

testing:   0%|          | 0/100 [00:00<?, ?it/s]

MMD-DK: 0 Threshold: NaN MMD_value: 0.002998724579811096
MMD-DK: 0 Threshold: NaN MMD_value: 0.0027691423892974854
MMD-DK: 0 Threshold: NaN MMD_value: 0.0002348572015762329
MMD-DK: 0 Threshold: NaN MMD_value: 0.0026709288358688354
MMD-DK: 1 Threshold: 0.0037310421466827393 MMD_value: 0.006576508283615112
MMD-DK: 1 Threshold: NaN MMD_value: 0.0016814619302749634
MMD-DK: 2 Threshold: 0.0031419843435287476 MMD_value: 0.003056555986404419
MMD-DK: 2 Threshold: NaN MMD_value: 0.0010900795459747314
MMD-DK: 2 Threshold: NaN MMD_value: -0.0008446872234344482
MMD-DK: 2 Threshold: NaN MMD_value: -0.0003588646650314331
MMD-DK: 2 Threshold: NaN MMD_value: 0.0015256255865097046
MMD-DK: 2 Threshold: NaN MMD_value: 0.0003440678119659424
MMD-DK: 3 Threshold: 0.0027239620685577393 MMD_value: 0.008399739861488342
MMD-DK: 3 Threshold: NaN MMD_value: -0.0006647855043411255
MMD-DK: 3 Threshold: NaN MMD_value: 0.00011980533599853516
MMD-DK: 3 Threshold: NaN MMD_value: 0.0008642077445983887
MMD-DK: 3 Threshol

testing:   0%|          | 0/100 [00:00<?, ?it/s]

MMD-DK: 0 Threshold: NaN MMD_value: 0.0009158551692962646
MMD-DK: 0 Threshold: NaN MMD_value: 0.002775430679321289
MMD-DK: 0 Threshold: NaN MMD_value: -0.0002250969409942627
MMD-DK: 1 Threshold: 0.0033142268657684326 MMD_value: 0.002302318811416626
MMD-DK: 2 Threshold: 0.003872305154800415 MMD_value: 0.006242036819458008
MMD-DK: 2 Threshold: NaN MMD_value: 0.0015052258968353271
MMD-DK: 2 Threshold: NaN MMD_value: 0.002088695764541626
MMD-DK: 2 Threshold: NaN MMD_value: 0.002548038959503174
MMD-DK: 2 Threshold: NaN MMD_value: -0.00124436616897583
MMD-DK: 2 Threshold: NaN MMD_value: -0.0011443495750427246
MMD-DK: 2 Threshold: NaN MMD_value: 0.0015640854835510254
MMD-DK: 2 Threshold: NaN MMD_value: -0.0003122091293334961
MMD-DK: 3 Threshold: 0.0041030943393707275 MMD_value: 0.007679849863052368
MMD-DK: 3 Threshold: NaN MMD_value: -0.0013006925582885742
MMD-DK: 3 Threshold: NaN MMD_value: -0.0012826919555664062
MMD-DK: 3 Threshold: NaN MMD_value: 0.0005139708518981934
MMD-DK: 3 Threshold: 

testing:   0%|          | 0/100 [00:00<?, ?it/s]

MMD-DK: 0 Threshold: NaN MMD_value: 0.0021208226680755615
MMD-DK: 0 Threshold: NaN MMD_value: 0.0026843249797821045
MMD-DK: 0 Threshold: NaN MMD_value: 0.0001020580530166626
MMD-DK: 0 Threshold: NaN MMD_value: 0.002063930034637451
MMD-DK: 1 Threshold: 0.00371745228767395 MMD_value: 0.005535393953323364
MMD-DK: 1 Threshold: NaN MMD_value: 0.0016814172267913818
MMD-DK: 1 Threshold: NaN MMD_value: 0.00177745521068573
MMD-DK: 1 Threshold: NaN MMD_value: 0.0019070208072662354
MMD-DK: 1 Threshold: NaN MMD_value: -0.0009503811597824097
MMD-DK: 1 Threshold: NaN MMD_value: -0.0006459355354309082
MMD-DK: 1 Threshold: NaN MMD_value: 0.0011588633060455322
MMD-DK: 1 Threshold: NaN MMD_value: 0.000530436635017395
MMD-DK: 2 Threshold: 0.0028507113456726074 MMD_value: 0.009248405694961548
MMD-DK: 2 Threshold: NaN MMD_value: -0.0007513463497161865
MMD-DK: 2 Threshold: NaN MMD_value: -1.7374753952026367e-05
MMD-DK: 2 Threshold: NaN MMD_value: 0.0004769265651702881
MMD-DK: 2 Threshold: NaN MMD_value: 0.0