In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "5"

In [3]:
from tqdm.auto import tqdm
import numpy as np
import jax.numpy as jnp

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
from utils import jnp_to_tensor
from utils_HD import MatConvert, MMDu, TST_MMD_u

In [4]:
np.random.seed(819)
torch.manual_seed(819)
torch.cuda.manual_seed(819)
torch.backends.cudnn.deterministic = True
is_cuda = True

In [5]:
# Parameter Settings
n_epochs = 800
batch_size = 200
lr = 0.000002
img_size = 64
channels = 3
n = 1000

In [6]:
dtype = torch.float
device = torch.device("cuda:0")
cuda = True if torch.cuda.is_available() else False
N_per = 100 # permutation times
alpha = 0.05 # test threshold
N1 = n # number of samples in one set
K = 10 # number of trails
J = 1 # number of test locations
N = 100 # number of test sets
N_f = 100.0 # number of test sets (float)

In [7]:
# Naming variables
ep_OPT = np.zeros([K])
s_OPT = np.zeros([K])
s0_OPT = np.zeros([K])
T_org_OPT = torch.zeros([K,J,3,64,64]) # Record test locations obtained by MMD-D
COM_Results = np.zeros([1,K])

In [8]:
# Define the deep network for MMD-D
class Featurizer_COM(nn.Module):
    def __init__(self):
        super(Featurizer_COM, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0)] #0.25
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = img_size // 2 ** 4
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 300))

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        feature = self.adv_layer(out)

        return feature

In [9]:
# Configure data loader
dataset_test = datasets.CIFAR10(root='/data4/oldrain123/C2ST/data/cifar_data/cifar10', download=True,train=False,
                           transform=transforms.Compose([
                               transforms.Resize(img_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=10000,
                                             shuffle=True, num_workers=1)
# Obtain CIFAR10 images
for i, (imgs, Labels) in enumerate(dataloader_test):
    data_all = imgs
    label_all = Labels
Ind_all = np.arange(len(data_all))

# Obtain CIFAR10.1 images
data_new = np.load('/data4/oldrain123/C2ST/data/cifar10_1/cifar10.1_v4_data.npy')
data_T = np.transpose(data_new, [0,3,1,2])
ind_M = np.random.choice(len(data_T), len(data_T), replace=False)
data_T = data_T[ind_M]
TT = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trans = transforms.ToPILImage()
data_trans = torch.zeros([len(data_T),3,img_size,img_size])
data_T_tensor = torch.from_numpy(data_T)
for i in range(len(data_T)):
    d0 = trans(data_T_tensor[i])
    data_trans[i] = TT(d0)
Ind_v4_all = np.arange(len(data_T))

Files already downloaded and verified


In [10]:
# Repeat experiments K times (K = 10) and report average test power (rejection rate)
for kk in tqdm(range(K)):
    torch.manual_seed(kk * 19 + N1)
    torch.cuda.manual_seed(kk * 19 + N1)
    np.random.seed(seed=1102 * (kk + 10) + N1)
    
    # Initialize deep networks for MMD-D (called featurizer)
    featurizer_com = Featurizer_COM()
    
    # Initialize parameters
    epsilonOPT = torch.log(MatConvert(np.random.rand(1) * 10 ** (-10), device, dtype))
    epsilonOPT.requires_grad = True
    sigmaOPT = MatConvert(np.ones(1) * np.sqrt(2 * 32 * 32), device, dtype)
    sigmaOPT.requires_grad = True
    sigma0OPT = MatConvert(np.ones(1) * np.sqrt(0.005), device, dtype)
    sigma0OPT.requires_grad = True
    if cuda:
        featurizer_com.cuda()

    # Collect CIFAR10 images
    Ind_tr = np.random.choice(len(data_all), N1, replace=False)
    Ind_te = np.delete(Ind_all, Ind_tr)
    train_data = []
    for i in Ind_tr:
       train_data.append([data_all[i], label_all[i]])

    dataloader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
    )

    # Collect CIFAR10.1 images
    np.random.seed(seed=819 * (kk + 9) + N1)
    Ind_tr_v4 = np.random.choice(len(data_T), N1, replace=False)
    Ind_te_v4 = np.delete(Ind_v4_all, Ind_tr_v4)
    New_CIFAR_tr = data_trans[Ind_tr_v4]
    New_CIFAR_te = data_trans[Ind_te_v4]

    # Initialize optimizers
    optimizer_COM = torch.optim.Adam(list(featurizer_com.parameters()) + [epsilonOPT] + [sigmaOPT] + [sigma0OPT], lr=lr, weight_decay=1e-4)
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # ----------------------------------------------------------------------------------------------------
    #  Training deep networks for MMD-D (called featurizer)
    # ----------------------------------------------------------------------------------------------------
    np.random.seed(seed=1102)
    torch.manual_seed(1102)
    torch.cuda.manual_seed(1102)
    for epoch in tqdm(range(n_epochs)):
        for i, (imgs, _) in enumerate(dataloader):
            if True:
                ind = np.random.choice(N1, imgs.shape[0], replace=False)
                Fake_imgs = New_CIFAR_tr[ind]
                # Adversarial ground truths
                valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
                fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

                # Configure input
                real_imgs = Variable(imgs.type(Tensor))
                Fake_imgs = Variable(Fake_imgs.type(Tensor))
                X = torch.cat([real_imgs, Fake_imgs], 0)
                Y = torch.cat([valid, fake], 0).squeeze().long()

                # ------------------------------
                #  Train deep network for MMD-D
                # ------------------------------
                # Initialize optimizer
                optimizer_COM.zero_grad()
                
                # Compute output of deep network
                com_modelu_output = featurizer_com(X)
                
                # Compute epsilon, sigma and sigma_0
                ep = torch.exp(epsilonOPT) / (1 + torch.exp(epsilonOPT))
                sigma = sigmaOPT ** 2
                sigma0_u = sigma0OPT ** 2
                
                # Compute Compute J (STAT_u)
                COM_TEMP = MMDu(com_modelu_output, imgs.shape[0], X.view(X.shape[0],-1), sigma, sigma0_u, ep, complete=True)
                com_mmd = - COM_TEMP[0]
                com_mmd_std_tmp_tensor = jnp_to_tensor(COM_TEMP[1])
                com_mmd_std = torch.sqrt(com_mmd_std_tmp_tensor + 10**(-6))
                COM_STAT_u = torch.div(com_mmd, com_mmd_std)
                # COM_LOSS = - torch.log(com_mmd) + torch.log(com_mmd_std)
                
                if (epoch+1) % 100 == 0:
                    print("=" * 110)
                    print("epoch : ",epoch+1)
                    print("our mmd: ", -1 * com_mmd.item(), "our mmd_std: ", com_mmd_std.item(), "our statistic: ",
                    -1 * COM_STAT_u.item())
                    
                # Compute gradient
                COM_STAT_u.backward()
                # COM_LOSS.backward()
                
                # Update weights using gradient descent
                optimizer_COM.step()
            else:
                break

    # Run two-sample test on the training set
    # Fetch training data
    s1 = data_all[Ind_tr]
    s2 = data_trans[Ind_tr_v4]
    S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
    Sv = S.view(2 * N1, -1)
    # Run two-sample test (MMD-D) on the training set
    # dk_h_u, dk_threshold_u, dk_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N1, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=False)
    # com_h_u, com_threshold_u, com_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N1, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=True)

    # Record best epsilon, sigma and sigma_0
    ep_OPT[kk] = ep.item()
    s_OPT[kk] = sigma.item()
    s0_OPT[kk] = sigma0_u.item()

    # Compute test power of MMD-D and baselines
    DK_H_u = np.zeros(N)
    COM_H_u = np.zeros(N)

    np.random.seed(1102)
    dk_count_u = 0
    com_count_u = 0

    for k in tqdm(range(N)):
        # Fetch test data
        np.random.seed(seed=1102 * (k + 1) + N1)
        data_all_te = data_all[Ind_te]
        N_te = len(data_trans)-N1
        Ind_N_te = np.random.choice(len(Ind_te), N_te, replace=False)
        s1 = data_all_te[Ind_N_te]
        s2 = data_trans[Ind_te_v4]
        S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
        Sv = S.view(2 * N_te, -1)
        # MMD-D
        com_h_u, com_threshold_u, com_mmd_value_u = TST_MMD_u(featurizer_com(S), N_per, N_te, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=True)

        # Gather results
        com_count_u = com_count_u + com_h_u
        print("\r","Ours:", com_count_u, "MMD: ", com_mmd_value_u, end="")
        COM_H_u[k] = com_h_u

    # Print test power of MMD-D and baselines
    print("Our Reject rate_u: ", COM_H_u.sum() / N_f)
    COM_Results[0, kk] = COM_H_u.sum() / N_f
    print(f"Test Power of Ours ({K} times): ")
    print(f"{COM_Results}")
    print(f"Average Test Power of Ours ({K} times): ")
    print("Ours: ", (COM_Results.sum(1) / (kk + 1))[0])

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/800 [00:00<?, ?it/s]

  valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
CUDA backend failed to initialize: Found cuPTI version 18, but JAX was built against version 20, which is newer. The copy of cuPTI that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


epoch :  100
our mmd:  0.00015447381883859634 our mmd_std:  0.000999522076800441 our statistic:  0.15454768076066988
epoch :  100
our mmd:  0.0003137248568236828 our mmd_std:  0.0010851629849573767 our statistic:  0.289103905286638
epoch :  100
our mmd:  0.0020655517000705004 our mmd_std:  0.001385703510079356 our statistic:  1.490615911012747
epoch :  100
our mmd:  0.0005798335187137127 our mmd_std:  0.0010418116538007114 our statistic:  0.5565627113100323
epoch :  100
our mmd:  0.0005315174348652363 our mmd_std:  0.0010889843777288444 our statistic:  0.4880854544247497
epoch :  200
our mmd:  0.00265656691044569 our mmd_std:  0.0015753561531715058 our statistic:  1.686327821869037
epoch :  200
our mmd:  0.0014071613550186157 our mmd_std:  0.001288827024649059 our statistic:  1.0918155253625121
epoch :  200
our mmd:  0.002007821574807167 our mmd_std:  0.0014502106035379908 our statistic:  1.384503443781756
epoch :  200
our mmd:  0.004040979780256748 our mmd_std:  0.0017398956899393918 

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 20 MMD:  8.754432201385498e-08Our Reject rate_u:  0.2
Test Power of Ours (10 times): 
[[0.2 0.  0.  0.  0.  0.  0.  0.  0.  0. ]]
Average Test Power of Ours (10 times): 
Ours:  0.2


  0%|          | 0/800 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.0012494567781686783 our mmd_std:  0.0011998904326788696 our statistic:  1.0413090596773467
epoch :  100
our mmd:  -1.7358921468257904e-05 our mmd_std:  0.0009817097991567836 our statistic:  -0.017682334925420873
epoch :  100
our mmd:  0.0005636196583509445 our mmd_std:  0.0010572225397814548 our statistic:  0.5331135471888955
epoch :  100
our mmd:  0.0005863131955265999 our mmd_std:  0.0011014702646815313 our statistic:  0.5323005207917446
epoch :  100
our mmd:  0.0003943825140595436 our mmd_std:  0.0010485068053684681 our statistic:  0.37613729547606417
epoch :  200
our mmd:  0.001680620014667511 our mmd_std:  0.0012535843440169725 our statistic:  1.3406517261392639
epoch :  200
our mmd:  0.0010223351418972015 our mmd_std:  0.001184023544745637 our statistic:  0.8634415645145207
epoch :  200
our mmd:  0.0007442329078912735 our mmd_std:  0.0010930578584452085 our statistic:  0.680872382135277
epoch :  200
our mmd:  0.0011769048869609833 our mmd_std:  0.00109318

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 83 MMD:  0.00021959096193313599Our Reject rate_u:  0.83
Test Power of Ours (10 times): 
[[0.2  0.83 0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.515


  0%|          | 0/800 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.0004688610788434744 our mmd_std:  0.0010695605639199329 our statistic:  0.43836795657938377
epoch :  100
our mmd:  6.01988285779953e-05 our mmd_std:  0.0010007426154365676 our statistic:  0.060154157172305434
epoch :  100
our mmd:  0.00012912158854305744 our mmd_std:  0.0010254586281685132 our statistic:  0.12591594141020668
epoch :  100
our mmd:  0.00023737316951155663 our mmd_std:  0.0010038937958999868 our statistic:  0.23645247184614038
epoch :  100
our mmd:  0.0005381288938224316 our mmd_std:  0.0010455282511486837 our statistic:  0.5146956987830879
epoch :  200
our mmd:  0.0003694770857691765 our mmd_std:  0.0010403178313445288 our statistic:  0.35515788986492375
epoch :  200
our mmd:  7.999874651432037e-05 our mmd_std:  0.0009974617065606718 our statistic:  0.08020232354599606
epoch :  200
our mmd:  0.0011962968856096268 our mmd_std:  0.0012505156861073346 our statistic:  0.9566428465471851
epoch :  200
our mmd:  0.0011016931384801865 our mmd_std:  0.001

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 43 MMD:  0.000133611261844635Our Reject rate_u:  0.43
Test Power of Ours (10 times): 
[[0.2  0.83 0.43 0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.48666666666666664


  0%|          | 0/800 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.0009400565177202225 our mmd_std:  0.0012610262349029829 our statistic:  0.7454694372734805
epoch :  100
our mmd:  0.00032315216958522797 our mmd_std:  0.0010128418171332567 our statistic:  0.3190549245882013
epoch :  100
our mmd:  -6.887130439281464e-05 our mmd_std:  0.0010770654781824872 our statistic:  -0.06394347027910754
epoch :  100
our mmd:  0.0005291476845741272 our mmd_std:  0.0011431068832476363 our statistic:  0.4629030691082765
epoch :  100
our mmd:  -0.0001305900514125824 our mmd_std:  0.0009648000697638361 our statistic:  -0.13535452111290608
epoch :  200
our mmd:  0.0008327942341566086 our mmd_std:  0.0011437485910412016 our statistic:  0.7281270033290109
epoch :  200
our mmd:  0.0016780830919742584 our mmd_std:  0.0013760499142819382 our statistic:  1.2194928937951568
epoch :  200
our mmd:  0.0003060735762119293 our mmd_std:  0.001000687199969079 our statistic:  0.30586338690190795
epoch :  200
our mmd:  0.0006097685545682907 our mmd_std:  0.0011

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 43 MMD:  0.00011749938130378723Our Reject rate_u:  0.43
Test Power of Ours (10 times): 
[[0.2  0.83 0.43 0.43 0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.47250000000000003


  0%|          | 0/800 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.00022037606686353683 our mmd_std:  0.001058437342655368 our statistic:  0.20820889247035218
epoch :  100
our mmd:  -7.370021194219589e-05 our mmd_std:  0.0010019455669954937 our statistic:  -0.07355710167289693
epoch :  100
our mmd:  0.00038770586252212524 our mmd_std:  0.00100178027730754 our statistic:  0.38701686517941103
epoch :  100
our mmd:  8.9199747890234e-05 our mmd_std:  0.0010068097907107757 our statistic:  0.08859642477976083
epoch :  100
our mmd:  -7.545296102762222e-05 our mmd_std:  0.000988637314595 our statistic:  -0.07632016303019262
epoch :  200
our mmd:  0.0013506179675459862 our mmd_std:  0.0013032274342308242 our statistic:  1.0363639776683586
epoch :  200
our mmd:  2.3671425879001617e-05 our mmd_std:  0.0010303041096088106 our statistic:  0.02297518340287827
epoch :  200
our mmd:  0.0014965459704399109 our mmd_std:  0.0012510674082751938 our statistic:  1.1962152962670096
epoch :  200
our mmd:  0.0006121313199400902 our mmd_std:  0.0010397

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 81 MMD:  0.00014093145728111267Our Reject rate_u:  0.81
Test Power of Ours (10 times): 
[[0.2  0.83 0.43 0.43 0.81 0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.54


  0%|          | 0/800 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.001369742676615715 our mmd_std:  0.00128344507393002 our statistic:  1.0672390306672372
epoch :  100
our mmd:  0.00033631641417741776 our mmd_std:  0.0010551609614665506 our statistic:  0.31873470158522277
epoch :  100
our mmd:  0.00036022067070007324 our mmd_std:  0.0010792158371305167 our statistic:  0.3337800079526718
epoch :  100
our mmd:  0.0007506078109145164 our mmd_std:  0.0011727659139102473 our statistic:  0.640032083139109
epoch :  100
our mmd:  -4.6055763959884644e-05 our mmd_std:  0.0009913596429722011 our statistic:  -0.04645717049950166
epoch :  200
our mmd:  0.0007912274450063705 our mmd_std:  0.0010937383758410336 our statistic:  0.7234156380386247
epoch :  200
our mmd:  0.0016404585912823677 our mmd_std:  0.001405942916198519 our statistic:  1.1668031271980428
epoch :  200
our mmd:  0.00253323744982481 our mmd_std:  0.0013977195310076718 our statistic:  1.8124075636250845
epoch :  200
our mmd:  0.0033508939668536186 our mmd_std:  0.00178889348

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 5 MMD:  -4.207715392112732e-05Our Reject rate_u:  0.05
Test Power of Ours (10 times): 
[[0.2  0.83 0.43 0.43 0.81 0.05 0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.4583333333333333


  0%|          | 0/800 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.00039527658373117447 our mmd_std:  0.0010972541142082117 our statistic:  0.3602416054884511
epoch :  100
our mmd:  0.0006722984835505486 our mmd_std:  0.0011803838468205279 our statistic:  0.5695592034416992
epoch :  100
our mmd:  0.001134118065237999 our mmd_std:  0.00131506924312092 our statistic:  0.8624017869557286
epoch :  100
our mmd:  0.00031423382461071014 our mmd_std:  0.0010878968970879318 our statistic:  0.28884522554650827
epoch :  100
our mmd:  0.0006824322044849396 our mmd_std:  0.0011279199578098414 our statistic:  0.6050360220685024
epoch :  200
our mmd:  0.003307056613266468 our mmd_std:  0.0017340717951483536 our statistic:  1.907104782235122
epoch :  200
our mmd:  0.0032137539237737656 our mmd_std:  0.001629153580864664 our statistic:  1.9726525243053412
epoch :  200
our mmd:  0.0028189094737172127 our mmd_std:  0.0017517952329448093 our statistic:  1.6091546664267142
epoch :  200
our mmd:  0.001073502004146576 our mmd_std:  0.001233499926387

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 48 MMD:  0.0001430753618478775Our Reject rate_u:  0.48
Test Power of Ours (10 times): 
[[0.2  0.83 0.43 0.43 0.81 0.05 0.48 0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.46142857142857147


  0%|          | 0/800 [00:00<?, ?it/s]

epoch :  100
our mmd:  4.2706989916041493e-05 our mmd_std:  0.001001177138346209 our statistic:  0.042656776988123084
epoch :  100
our mmd:  2.47609568759799e-05 our mmd_std:  0.0010009549213571974 our statistic:  0.024737334666786445
epoch :  100
our mmd:  -7.503200322389603e-06 our mmd_std:  0.00100038241358373 our statistic:  -0.0075003320935145565
epoch :  100
our mmd:  2.2179388906806707e-05 our mmd_std:  0.0010021259748997499 our statistic:  0.022132336115751793
epoch :  100
our mmd:  6.412438233383e-05 our mmd_std:  0.001000503433837811 our statistic:  0.06409211619380113
epoch :  200
our mmd:  0.0003256509080529213 our mmd_std:  0.0010222195756606062 our statistic:  0.3185723652792214
epoch :  200
our mmd:  0.0006926874630153179 our mmd_std:  0.0010462683969517023 our statistic:  0.6620552288814795
epoch :  200
our mmd:  0.001137056853622198 our mmd_std:  0.0011332317045822397 our statistic:  1.003375434189223
epoch :  200
our mmd:  0.0009372583590447903 our mmd_std:  0.0010984

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 27 MMD:  6.313435733318329e-05Our Reject rate_u:  0.27
Test Power of Ours (10 times): 
[[0.2  0.83 0.43 0.43 0.81 0.05 0.48 0.27 0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.4375


  0%|          | 0/800 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.00034018559381365776 our mmd_std:  0.0010420920628325058 our statistic:  0.32644485640644916
epoch :  100
our mmd:  9.1486144810915e-05 our mmd_std:  0.0010138285500792008 our statistic:  0.09023828023364311
epoch :  100
our mmd:  0.0008962894789874554 our mmd_std:  0.0011197494346707906 our statistic:  0.8004375365020541
epoch :  100
our mmd:  0.0012006768956780434 our mmd_std:  0.0011329498925079277 our statistic:  1.059779345598589
epoch :  100
our mmd:  0.00015435297973453999 our mmd_std:  0.0010113319296403448 our statistic:  0.15262346140839417
epoch :  200
our mmd:  0.004976268857717514 our mmd_std:  0.001761171768189661 our statistic:  2.82554430385442
epoch :  200
our mmd:  0.0018527843058109283 our mmd_std:  0.001321642019217212 our statistic:  1.4018805991869898
epoch :  200
our mmd:  0.0019243238493800163 our mmd_std:  0.0013343025031784849 our statistic:  1.4421945884055698
epoch :  200
our mmd:  0.002742327284067869 our mmd_std:  0.001520719623354

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 52 MMD:  3.449898213148117e-05Our Reject rate_u:  0.52
Test Power of Ours (10 times): 
[[0.2  0.83 0.43 0.43 0.81 0.05 0.48 0.27 0.52 0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.4466666666666666


  0%|          | 0/800 [00:00<?, ?it/s]

epoch :  100
our mmd:  4.903459921479225e-05 our mmd_std:  0.0010066901827213528 our statistic:  0.04870872891820462
epoch :  100
our mmd:  -4.038633778691292e-05 our mmd_std:  0.0009948432584584797 our statistic:  -0.04059567921231329
epoch :  100
our mmd:  0.0004174686037003994 our mmd_std:  0.0010192390591108329 our statistic:  0.4095885062181507
epoch :  100
our mmd:  3.494415432214737e-05 our mmd_std:  0.0009996660209990516 our statistic:  0.03495582883493898
epoch :  100
our mmd:  7.367413491010666e-05 our mmd_std:  0.0010171770466355086 our statistic:  0.07243000139827849
epoch :  200
our mmd:  0.0001370948739349842 our mmd_std:  0.0010165871871856658 our statistic:  0.1348579597137355
epoch :  200
our mmd:  0.0001971307210624218 our mmd_std:  0.0010063129919914513 our statistic:  0.195894043534416
epoch :  200
our mmd:  0.0007988037541508675 our mmd_std:  0.001062158671592174 our statistic:  0.7520568965025368
epoch :  200
our mmd:  0.0008340789936482906 our mmd_std:  0.0011080

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 29 MMD:  0.00025125034153461456Our Reject rate_u:  0.29
Test Power of Ours (10 times): 
[[0.2  0.83 0.43 0.43 0.81 0.05 0.48 0.27 0.52 0.29]]
Average Test Power of Ours (10 times): 
Ours:  0.43099999999999994


#### Epoch 20 / lr: 0.0005
[[0.23 0.94 0.25 0.37 0.6  0.1  0.36 0.31 0.23 0.4 ]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Baselines (10 times): 
MMD-D:  0.379