In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "5"

In [3]:
from tqdm.auto import tqdm
import numpy as np
import jax.numpy as jnp

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms

from utils_HD import MatConvert, MMDu, TST_MMD_u

In [4]:
np.random.seed(819)
torch.manual_seed(819)
torch.cuda.manual_seed(819)
torch.backends.cudnn.deterministic = True
is_cuda = True

In [5]:
# Parameter Settings
n_epochs = 1000
batch_size = 100
lr = 0.0002
img_size = 64
channels = 3
n = 1000

In [6]:
dtype = torch.float
device = torch.device("cuda:0")
cuda = True if torch.cuda.is_available() else False
N_per = 100 # permutation times
alpha = 0.05 # test threshold
N1 = n # number of samples in one set
K = 10 # number of trails
J = 1 # number of test locations
N = 100 # number of test sets
N_f = 100.0 # number of test sets (float)

In [7]:
# Naming variables
ep_OPT = np.zeros([K])
s_OPT = np.zeros([K])
s0_OPT = np.zeros([K])
T_org_OPT = torch.zeros([K,J,3,64,64]) # Record test locations obtained by MMD-D
DK_Results = np.zeros([1,K])

In [8]:
# Define the deep network for MMD-D
class Featurizer(nn.Module):
    def __init__(self):
        super(Featurizer, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0)] #0.25
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
            # *discriminator_block(128, 256), 
            
        )

        # The height and width of downsampled image
        ds_size = img_size // 2 ** 4
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 300))

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        feature = self.adv_layer(out)

        return feature

In [9]:
# Configure data loader
dataset_test = datasets.CIFAR10(root='/data4/oldrain123/C2ST/data/cifar_data/cifar10', download=True,train=False,
                           transform=transforms.Compose([
                               transforms.Resize(img_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=10000,
                                             shuffle=True, num_workers=1)
# Obtain CIFAR10 images
for i, (imgs, Labels) in enumerate(dataloader_test):
    data_all = imgs
    label_all = Labels
Ind_all = np.arange(len(data_all))

# Obtain CIFAR10.1 images
data_new = np.load('/data4/oldrain123/C2ST/data/cifar10_1/cifar10.1_v4_data.npy')
data_T = np.transpose(data_new, [0,3,1,2])
ind_M = np.random.choice(len(data_T), len(data_T), replace=False)
data_T = data_T[ind_M]
TT = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trans = transforms.ToPILImage()
data_trans = torch.zeros([len(data_T),3,img_size,img_size])
data_T_tensor = torch.from_numpy(data_T)
for i in range(len(data_T)):
    d0 = trans(data_T_tensor[i])
    data_trans[i] = TT(d0)
Ind_v4_all = np.arange(len(data_T))

Files already downloaded and verified


In [10]:
# Repeat experiments K times (K = 10) and report average test power (rejection rate)
for kk in tqdm(range(K)):
    torch.manual_seed(kk * 19 + N1)
    torch.cuda.manual_seed(kk * 19 + N1)
    np.random.seed(seed=1102 * (kk + 10) + N1)
    
    # Initialize deep networks for MMD-D (called featurizer)
    featurizer = Featurizer()
    
    # Initialize parameters
    epsilonOPT = torch.log(MatConvert(np.random.rand(1) * 10 ** (-10), device, dtype))
    epsilonOPT.requires_grad = True
    sigmaOPT = MatConvert(np.ones(1) * np.sqrt(2 * 32 * 32), device, dtype)
    sigmaOPT.requires_grad = True
    sigma0OPT = MatConvert(np.ones(1) * np.sqrt(0.005), device, dtype)
    sigma0OPT.requires_grad = True
    if cuda:
        featurizer.cuda()

    # Collect CIFAR10 images
    Ind_tr = np.random.choice(len(data_all), N1, replace=False)
    Ind_te = np.delete(Ind_all, Ind_tr)
    train_data = []
    for i in Ind_tr:
       train_data.append([data_all[i], label_all[i]])

    dataloader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
    )

    # Collect CIFAR10.1 images
    np.random.seed(seed=819 * (kk + 9) + N1)
    Ind_tr_v4 = np.random.choice(len(data_T), N1, replace=False)
    Ind_te_v4 = np.delete(Ind_v4_all, Ind_tr_v4)
    New_CIFAR_tr = data_trans[Ind_tr_v4]
    New_CIFAR_te = data_trans[Ind_te_v4]

    # Initialize optimizers
    optimizer_F = torch.optim.Adam(list(featurizer.parameters()) + [epsilonOPT] + [sigmaOPT] + [sigma0OPT], lr=lr)

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # ----------------------------------------------------------------------------------------------------
    #  Training deep networks for MMD-D (called featurizer)
    # ----------------------------------------------------------------------------------------------------
    np.random.seed(seed=1102)
    torch.manual_seed(1102)
    torch.cuda.manual_seed(1102)
    for epoch in tqdm(range(n_epochs)):
        for i, (imgs, _) in enumerate(dataloader):
            if True:
                ind = np.random.choice(N1, imgs.shape[0], replace=False)
                Fake_imgs = New_CIFAR_tr[ind]
                # Adversarial ground truths
                valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
                fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

                # Configure input
                real_imgs = Variable(imgs.type(Tensor))
                Fake_imgs = Variable(Fake_imgs.type(Tensor))
                X = torch.cat([real_imgs, Fake_imgs], 0)
                Y = torch.cat([valid, fake], 0).squeeze().long()

                # ------------------------------
                #  Train deep network for MMD-D
                # ------------------------------
                # Initialize optimizer
                optimizer_F.zero_grad()
                # Compute output of deep network
                modelu_output = featurizer(X)
                # Compute epsilon, sigma and sigma_0
                ep = torch.exp(epsilonOPT) / (1 + torch.exp(epsilonOPT))
                sigma = sigmaOPT ** 2
                sigma0_u = sigma0OPT ** 2
                # Compute Compute J (STAT_u)
                DK_TEMP = MMDu(modelu_output, imgs.shape[0], X.view(X.shape[0],-1), sigma, sigma0_u, ep, complete=False)
                dk_hsic_xx = torch.from_numpy(np.array(DK_TEMP[3])).to(torch.float32)
                dk_hsic_yy = torch.from_numpy(np.array(DK_TEMP[4])).to(torch.float32)
                dk_hsic_xy = torch.from_numpy(np.array(DK_TEMP[5])).to(torch.float32)
                
                dk_mmd = -1 * (DK_TEMP[0])
                dk_mmd_std = torch.sqrt(DK_TEMP[1] + 10 ** (-8))
                
                DK_STAT_u = torch.div(dk_mmd, dk_mmd_std) + (dk_hsic_xx + dk_hsic_yy - dk_hsic_xy)
                if (epoch+1) % 100 == 0:
                    print("-" * 50)
                    print("epoch : ",epoch+1)
                    print("dk mmd: ", -1 * dk_mmd.item(), "dk mmd_std: ", dk_mmd_std.item(), "dk statistic: ",
                    -1 * DK_STAT_u.item())
                    
                # Compute gradient
                DK_STAT_u.backward()                
                # Update weights using gradient descent
                optimizer_F.step()
            else:
                break

    # Run two-sample test on the training set
    # Fetch training data
    s1 = data_all[Ind_tr]
    s2 = data_trans[Ind_tr_v4]
    S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
    Sv = S.view(2 * N1, -1)
    # Run two-sample test (MMD-D) on the training set
    # dk_h_u, dk_threshold_u, dk_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N1, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=False)
    # com_h_u, com_threshold_u, com_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N1, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=True)

    # Record best epsilon, sigma and sigma_0
    ep_OPT[kk] = ep.item()
    s_OPT[kk] = sigma.item()
    s0_OPT[kk] = sigma0_u.item()

    # Compute test power of MMD-D and baselines
    DK_H_u = np.zeros(N)

    np.random.seed(1102)
    dk_count_u = 0

    for k in tqdm(range(N)):
        # Fetch test data
        np.random.seed(seed=1102 * (k + 1) + N1)
        data_all_te = data_all[Ind_te]
        N_te = len(data_trans)-N1
        Ind_N_te = np.random.choice(len(Ind_te), N_te, replace=False)
        s1 = data_all_te[Ind_N_te]
        s2 = data_trans[Ind_te_v4]
        S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
        Sv = S.view(2 * N_te, -1)
        # MMD-D
        dk_h_u, dk_threshold_u, dk_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N_te, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=False)

        # Gather results
        dk_count_u = dk_count_u + dk_h_u
        print("\r","MMD-DK:", dk_count_u, "MMD: ", dk_mmd_value_u, end="")
        DK_H_u[k] = dk_h_u

    # Print test power of MMD-D and baselines
    print("DK Reject rate_u: ", DK_H_u.sum() / N_f)
    DK_Results[0, kk] = DK_H_u.sum() / N_f
    print(f"Test Power of DK ({K} times): ")
    print(f"{DK_Results}")
    print(f"Average Test Power of DK ({K} times): ")
    print("MMD-D: ", (DK_Results.sum(1) / (kk + 1))[0])

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
CUDA backend failed to initialize: Found cuPTI version 18, but JAX was built against version 20, which is newer. The copy of cuPTI that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


--------------------------------------------------
epoch :  100
dk mmd:  0.009275173768401146 dk mmd_std:  0.004235982429236174 dk statistic:  2.179488182067871
--------------------------------------------------
epoch :  100
dk mmd:  0.009926260448992252 dk mmd_std:  0.0036966900806874037 dk statistic:  2.6750881671905518
--------------------------------------------------
epoch :  100
dk mmd:  0.010025662370026112 dk mmd_std:  0.004053354728966951 dk statistic:  2.4632937908172607
--------------------------------------------------
epoch :  100
dk mmd:  0.010198758915066719 dk mmd_std:  0.0036343138199299574 dk statistic:  2.7961575984954834
--------------------------------------------------
epoch :  100
dk mmd:  0.009643965400755405 dk mmd_std:  0.0030855198856443167 dk statistic:  3.115452527999878
--------------------------------------------------
epoch :  100
dk mmd:  0.009220937266945839 dk mmd_std:  0.002899167826399207 dk statistic:  3.170462131500244
----------------------------

  0%|          | 0/100 [00:00<?, ?it/s]

MMDs:  [ 1.42086064e-05 -8.81287269e-06 -7.79062975e-06  1.00263860e-05
 -1.47968531e-05 -1.43563375e-05 -8.98679718e-06  1.32992864e-06
  8.54174141e-06  1.54660083e-05 -7.47467857e-06 -1.31340930e-05
  5.69969416e-07  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  

  0%|          | 0/1000 [00:00<?, ?it/s]

--------------------------------------------------
epoch :  100
dk mmd:  0.009882299229502678 dk mmd_std:  0.005693449638783932 dk statistic:  1.725608229637146
--------------------------------------------------
epoch :  100
dk mmd:  0.010174958035349846 dk mmd_std:  0.005934981629252434 dk statistic:  1.704331398010254
--------------------------------------------------
epoch :  100
dk mmd:  0.01096055842936039 dk mmd_std:  0.003907768055796623 dk statistic:  2.794710874557495
--------------------------------------------------
epoch :  100
dk mmd:  0.01110315416008234 dk mmd_std:  0.004413404036313295 dk statistic:  2.505676031112671
--------------------------------------------------
epoch :  100
dk mmd:  0.010988747701048851 dk mmd_std:  0.0050898059271276 dk statistic:  2.1488893032073975
--------------------------------------------------
epoch :  100
dk mmd:  0.009303654544055462 dk mmd_std:  0.00476284371688962 dk statistic:  1.943294644355774
--------------------------------------

  0%|          | 0/100 [00:00<?, ?it/s]

MMDs:  [-5.93520235e-06 -3.33413482e-07 -1.07511878e-05 -1.91153958e-06
 -1.02757476e-05  7.25593418e-06  6.55534677e-07 -2.71247700e-07
 -6.20959327e-07 -4.61284071e-06 -1.90469436e-05  1.24956714e-05
  2.82644760e-06 -5.39352186e-06  1.31165143e-06  6.46035187e-06
 -1.03828497e-05 -1.16254669e-05 -7.61542469e-06 -9.91986599e-06
  6.36395998e-06 -3.90433706e-06 -1.64564699e-05 -1.95507891e-06
 -1.22569036e-05 -1.90786086e-05 -5.56418672e-06 -8.50297511e-07
 -8.99226870e-06  9.94140282e-06  6.91600144e-06  7.99843110e-06
  3.09594907e-06 -5.77257015e-06 -1.51365530e-05 -1.23304781e-05
  1.30007975e-05 -1.90269202e-05  1.43641373e-05 -4.07756306e-06
  9.66596417e-07 -4.81819734e-06 -2.12935265e-06 -8.82730819e-06
  1.40838092e-05 -2.75380444e-06  9.29483213e-06  5.27920201e-06
  2.95345671e-06  1.80040952e-05 -8.68272036e-06  6.59842044e-06
  5.97524922e-06  6.48235437e-06  9.65897925e-06 -5.37233427e-06
  6.36349432e-06  1.22492202e-06 -7.58236274e-06  3.91597860e-06
 -1.57435425e-05 -

  0%|          | 0/1000 [00:00<?, ?it/s]

--------------------------------------------------
epoch :  100
dk mmd:  0.009453067556023598 dk mmd_std:  0.0026515170466154814 dk statistic:  3.5550477504730225
--------------------------------------------------
epoch :  100
dk mmd:  0.010934421792626381 dk mmd_std:  0.004330592695623636 dk statistic:  2.5148191452026367
--------------------------------------------------
epoch :  100
dk mmd:  0.011016441509127617 dk mmd_std:  0.0035131005570292473 dk statistic:  3.1256887912750244
--------------------------------------------------
epoch :  100
dk mmd:  0.009750076569616795 dk mmd_std:  0.003915238659828901 dk statistic:  2.4801347255706787
--------------------------------------------------
epoch :  100
dk mmd:  0.01021443773061037 dk mmd_std:  0.004689787048846483 dk statistic:  2.1678972244262695
--------------------------------------------------
epoch :  100
dk mmd:  0.0108552947640419 dk mmd_std:  0.0043825129978358746 dk statistic:  2.4667909145355225
----------------------------

  0%|          | 0/100 [00:00<?, ?it/s]

MMDs:  [-8.88248906e-06  1.03958882e-06  4.38047573e-06  5.86151145e-06
 -1.20827463e-05 -7.34534115e-06  1.81701034e-06 -1.38785690e-05
  4.63821925e-06  3.09059396e-06 -3.78675759e-06  9.96049494e-07
 -2.19524372e-05 -5.22704795e-07  1.78425107e-05 -8.09389167e-06
  1.27092935e-05  1.63486693e-05 -1.26683153e-05  1.47360843e-05
 -1.69018749e-05 -6.17164187e-06  1.54986046e-05  1.88336708e-06
  1.31765846e-05 -5.78910112e-06  7.45221041e-06 -6.10202551e-06
 -1.21637713e-05 -1.30105764e-05  7.30296597e-06 -1.05618965e-05
  1.28531829e-05  6.94161281e-06  2.44728290e-06  1.64059456e-05
 -8.49645585e-06 -9.86712985e-06 -2.34856270e-06 -3.44612636e-06
 -1.95833854e-06 -8.99075530e-06 -1.20308250e-05 -7.45430589e-06
 -5.57629392e-06  1.52401626e-05  9.83243808e-06  1.49419066e-05
 -5.93042932e-06  1.13381539e-05  1.52920838e-05 -9.86852683e-06
 -6.31948933e-06  4.55230474e-06  8.54767859e-06 -5.52437268e-06
 -8.84104520e-06  3.64985317e-06 -2.65426934e-08 -1.07989181e-05
  3.26661393e-06 -

  0%|          | 0/1000 [00:00<?, ?it/s]

--------------------------------------------------
epoch :  100
dk mmd:  0.01131890993565321 dk mmd_std:  0.004559747874736786 dk statistic:  2.4722063541412354
--------------------------------------------------
epoch :  100
dk mmd:  0.011300134472548962 dk mmd_std:  0.005146306939423084 dk statistic:  2.1856367588043213
--------------------------------------------------
epoch :  100
dk mmd:  0.011537613347172737 dk mmd_std:  0.00587574252858758 dk statistic:  1.9534802436828613
--------------------------------------------------
epoch :  100
dk mmd:  0.012638930231332779 dk mmd_std:  0.006317782681435347 dk statistic:  1.9904083013534546
--------------------------------------------------
epoch :  100
dk mmd:  0.012903961353003979 dk mmd_std:  0.007572372443974018 dk statistic:  1.6940007209777832
--------------------------------------------------
epoch :  100
dk mmd:  0.013117078691720963 dk mmd_std:  0.007389082573354244 dk statistic:  1.765078067779541
-------------------------------

  0%|          | 0/100 [00:00<?, ?it/s]

MMDs:  [-1.04657374e-05 -3.68896872e-05 -1.38226897e-05  9.04221088e-06
 -3.34172510e-05  2.06287950e-05  2.46213749e-05 -2.48849392e-05
 -1.62208453e-05 -1.13807619e-05 -1.03972852e-05 -2.13566236e-05
  2.48309225e-05 -6.67106360e-06  3.52184288e-05 -1.68085098e-05
 -1.75796449e-05  1.97133049e-05 -4.14256938e-05  2.95834616e-06
 -1.81002542e-06 -8.01030546e-06 -1.70338899e-05 -2.24113464e-05
 -2.24830583e-05  7.57398084e-06 -8.60262662e-06 -2.47196294e-05
  2.54414044e-05 -2.13882886e-05  2.40774825e-05 -2.87997536e-05
  7.91624188e-06 -6.73532486e-06 -1.79447234e-05  1.32932328e-05
 -4.56813723e-07  4.20017168e-05  5.22937626e-07 -2.70921737e-06
 -5.73461875e-06 -3.07899900e-05 -2.68947333e-05  2.51284800e-05
 -1.17165036e-05  3.55364755e-05  4.12734225e-05 -1.69617124e-05
 -3.36626545e-05 -1.10827386e-05 -2.56514177e-05  1.43009238e-05
 -2.10613944e-05  1.28014944e-05 -1.24070793e-05 -2.00355425e-05
  3.07383016e-06 -6.16721809e-06 -2.36933120e-05  2.25934200e-05
 -9.64198261e-06 -

  0%|          | 0/1000 [00:00<?, ?it/s]

--------------------------------------------------
epoch :  100
dk mmd:  0.010102467611432076 dk mmd_std:  0.004093593452125788 dk statistic:  2.4577536582946777
--------------------------------------------------
epoch :  100
dk mmd:  0.010749511420726776 dk mmd_std:  0.0032962497789412737 dk statistic:  3.251058340072632
--------------------------------------------------
epoch :  100
dk mmd:  0.009293824434280396 dk mmd_std:  0.0036061767023056746 dk statistic:  2.5671045780181885
--------------------------------------------------
epoch :  100
dk mmd:  0.009315036237239838 dk mmd_std:  0.003609145525842905 dk statistic:  2.5708208084106445
--------------------------------------------------
epoch :  100
dk mmd:  0.008922846987843513 dk mmd_std:  0.0032309002708643675 dk statistic:  2.7516350746154785
--------------------------------------------------
epoch :  100
dk mmd:  0.009284595027565956 dk mmd_std:  0.0043821679428219795 dk statistic:  2.108598232269287
--------------------------

  0%|          | 0/100 [00:00<?, ?it/s]

MMDs:  [-1.06892549e-06 -2.22434755e-06 -6.29643910e-06  6.32007141e-06
  1.21108023e-05  7.25814607e-06 -1.54273584e-06 -6.81821257e-06
 -3.15927900e-06 -1.25680817e-05 -4.89526428e-06  1.79159688e-05
 -5.00876922e-06 -6.30284194e-06 -1.56858005e-06 -1.50904525e-05
 -1.32829882e-06 -4.32075467e-06 -8.53557140e-07 -8.02974682e-06
  1.06707448e-05  1.22527126e-06  2.70907767e-05  1.43552897e-05
 -4.72040847e-06  5.56989107e-06 -2.78768130e-06 -8.64616595e-07
 -1.26170926e-05  1.71265565e-05  1.94242457e-05 -1.88942067e-06
  1.43772922e-06  5.31144906e-06 -2.27708369e-06  1.46864913e-05
  1.64960511e-06 -7.88713805e-06 -1.15885632e-05  1.30997505e-05
 -4.06731851e-06  1.85426325e-06  1.57069881e-05 -6.99353404e-06
 -1.43628567e-05  3.68733890e-06 -1.59849878e-06 -7.03078695e-06
 -4.78699803e-07 -2.91352626e-06 -7.96106178e-06 -3.19653191e-06
  8.13382212e-06 -4.27267514e-06  6.82380050e-06  7.67281745e-06
 -1.00128818e-06 -3.45276203e-06 -7.76152592e-06 -4.85917553e-07
 -7.89412297e-07 -

  0%|          | 0/1000 [00:00<?, ?it/s]

--------------------------------------------------
epoch :  100
dk mmd:  0.0110141197219491 dk mmd_std:  0.004874340258538723 dk statistic:  2.2495076656341553
--------------------------------------------------
epoch :  100
dk mmd:  0.010989113710820675 dk mmd_std:  0.004693980328738689 dk statistic:  2.3310365676879883
--------------------------------------------------
epoch :  100
dk mmd:  0.010231376625597477 dk mmd_std:  0.004025022964924574 dk statistic:  2.5318620204925537
--------------------------------------------------
epoch :  100
dk mmd:  0.00965351052582264 dk mmd_std:  0.00401795981451869 dk statistic:  2.392484426498413
--------------------------------------------------
epoch :  100
dk mmd:  0.008840330876410007 dk mmd_std:  0.003444228321313858 dk statistic:  2.5566160678863525
--------------------------------------------------
epoch :  100
dk mmd:  0.009438570588827133 dk mmd_std:  0.0053160833194851875 dk statistic:  1.7653346061706543
--------------------------------

  0%|          | 0/100 [00:00<?, ?it/s]

MMDs:  [ 4.87675425e-06 -1.37281604e-05  3.41667328e-06 -4.18268610e-06
  1.31538836e-05 -9.45001375e-06 -1.02147460e-05  3.91469803e-06
 -1.18417665e-06  8.19528941e-06 -3.44961882e-06 -1.10466499e-05
 -8.10460187e-06  5.44777140e-06  1.39088370e-05 -8.49214848e-06
 -1.35952141e-05 -6.29108399e-07  2.81896209e-05 -1.05310464e-05
  2.52970494e-06  1.04979845e-05  3.56207602e-06  1.01653859e-06
  1.26659870e-05  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  

  0%|          | 0/1000 [00:00<?, ?it/s]

--------------------------------------------------
epoch :  100
dk mmd:  0.009168184362351894 dk mmd_std:  0.0031489331740885973 dk statistic:  2.901414155960083
--------------------------------------------------
epoch :  100
dk mmd:  0.008601893670856953 dk mmd_std:  0.0032874085009098053 dk statistic:  2.6064929962158203
--------------------------------------------------
epoch :  100
dk mmd:  0.010494701564311981 dk mmd_std:  0.0036366195417940617 dk statistic:  2.8757541179656982
--------------------------------------------------
epoch :  100
dk mmd:  0.010571837425231934 dk mmd_std:  0.004137016367167234 dk statistic:  2.545320749282837
--------------------------------------------------
epoch :  100
dk mmd:  0.010406825691461563 dk mmd_std:  0.0039790039882063866 dk statistic:  2.605313301086426
--------------------------------------------------
epoch :  100
dk mmd:  0.01041500922292471 dk mmd_std:  0.0037551179993897676 dk statistic:  2.76344633102417
-----------------------------

  0%|          | 0/100 [00:00<?, ?it/s]

MMDs:  [ 6.37373887e-07  1.04091596e-05  1.49038387e-05  1.84902456e-05
  1.06095104e-05 -8.19028355e-06 -2.59256922e-06 -1.87975820e-05
  8.20390414e-06  2.91201286e-06 -1.10149849e-05 -3.05403955e-06
 -3.02563421e-07 -1.54320151e-06 -1.09383836e-06 -9.44710337e-06
  2.31887680e-06 -4.27209307e-06 -5.33880666e-07  1.53384171e-05
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  

  0%|          | 0/1000 [00:00<?, ?it/s]

--------------------------------------------------
epoch :  100
dk mmd:  0.0119848120957613 dk mmd_std:  0.00863969512283802 dk statistic:  1.3770465850830078
--------------------------------------------------
epoch :  100
dk mmd:  0.011191720142960548 dk mmd_std:  0.0080709895119071 dk statistic:  1.3765416145324707
--------------------------------------------------
epoch :  100
dk mmd:  0.010656350292265415 dk mmd_std:  0.007675367407500744 dk statistic:  1.378269910812378
--------------------------------------------------
epoch :  100
dk mmd:  0.01117322500795126 dk mmd_std:  0.008438385091722012 dk statistic:  1.3140122890472412
--------------------------------------------------
epoch :  100
dk mmd:  0.012472116388380527 dk mmd_std:  0.008071494288742542 dk statistic:  1.5350751876831055
--------------------------------------------------
epoch :  100
dk mmd:  0.012439735233783722 dk mmd_std:  0.00644381670281291 dk statistic:  1.9203537702560425
------------------------------------

  0%|          | 0/100 [00:00<?, ?it/s]

MMDs:  [-3.01850960e-05 -1.61603093e-05 -8.97655264e-06 -1.60108320e-05
 -1.07986853e-05  3.17296945e-05 -4.92669642e-06 -2.28886493e-05
 -1.21290796e-05  3.17229424e-05 -1.22077763e-05  1.00177713e-05
  2.92435288e-06  2.02488154e-05  3.08775343e-05  2.05016695e-05
  5.57536259e-06 -1.26413070e-05  1.56138558e-05  1.25272200e-05
  5.68339601e-06 -1.59535557e-06 -7.35046342e-06 -1.49537809e-05
 -1.67582184e-05 -5.01517206e-07 -2.71294266e-06  4.12953086e-05
  1.72830187e-05 -5.90225682e-06  2.66819261e-05 -1.65500678e-05
  2.79271044e-05  3.59443948e-06 -9.69087705e-06  1.67912804e-05
 -1.36918388e-05 -1.96830370e-05 -1.00512989e-05 -1.24047510e-05
 -1.07879750e-05 -8.94255936e-06 -2.34832987e-06 -1.27688982e-05
 -1.65589154e-06  9.72766429e-07 -2.29850411e-06 -7.00913370e-06
 -1.55824237e-05  1.24652870e-05 -1.77049078e-05  6.00423664e-05
 -3.25199217e-05 -9.15210694e-06 -1.05937943e-06 -1.64890662e-05
  1.61919743e-05 -2.61701643e-07 -1.55246817e-05  9.14931297e-06
  1.18119642e-05 -

  0%|          | 0/1000 [00:00<?, ?it/s]

--------------------------------------------------
epoch :  100
dk mmd:  0.009459104388952255 dk mmd_std:  0.003189158160239458 dk statistic:  2.9559226036071777
--------------------------------------------------
epoch :  100
dk mmd:  0.008613644167780876 dk mmd_std:  0.0030706943944096565 dk statistic:  2.7950215339660645
--------------------------------------------------
epoch :  100
dk mmd:  0.009402059949934483 dk mmd_std:  0.0030748238787055016 dk statistic:  3.047673463821411
--------------------------------------------------
epoch :  100
dk mmd:  0.008155299350619316 dk mmd_std:  0.0029403551016002893 dk statistic:  2.7634775638580322
--------------------------------------------------
epoch :  100
dk mmd:  0.008455194532871246 dk mmd_std:  0.0034046920482069254 dk statistic:  2.4733128547668457
--------------------------------------------------
epoch :  100
dk mmd:  0.008296551182866096 dk mmd_std:  0.00339986733160913 dk statistic:  2.430173397064209
---------------------------

  0%|          | 0/100 [00:00<?, ?it/s]

MMDs:  [-1.25577208e-06 -2.57929787e-06 -5.02052717e-06 -8.49657226e-06
  3.63402069e-06  7.93195795e-06 -5.31703699e-06  1.51604181e-05
  7.05569983e-06  4.13111411e-06 -4.35160473e-07  3.13948840e-06
 -6.60971273e-06  5.05161006e-06 -3.87232285e-06 -1.43062789e-05
 -3.36056110e-06  8.41566361e-07  2.53140461e-05  3.77453398e-06
  6.06407411e-07 -5.22180926e-06 -2.55159102e-06 -1.22305937e-06
 -6.88142609e-06  1.97542831e-05 -3.51632480e-06  2.48289434e-05
 -1.36275776e-06  1.53778819e-05  5.31307887e-06 -2.73785554e-06
  9.91672277e-06  3.97115946e-06  6.43194653e-07  2.53308099e-06
  4.99596354e-06  3.23669519e-06  1.01525802e-06  2.20040092e-05
  2.74522463e-05 -2.06031837e-06  2.38453504e-06  8.63592140e-06
  1.32108107e-06 -7.82660209e-07  4.35102265e-06  1.80462375e-05
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  

  0%|          | 0/1000 [00:00<?, ?it/s]

--------------------------------------------------
epoch :  100
dk mmd:  0.013018261641263962 dk mmd_std:  0.004431330598890781 dk statistic:  2.9276676177978516
--------------------------------------------------
epoch :  100
dk mmd:  0.012294463813304901 dk mmd_std:  0.003352038562297821 dk statistic:  3.6576483249664307
--------------------------------------------------
epoch :  100
dk mmd:  0.011401613242924213 dk mmd_std:  0.003723236732184887 dk statistic:  3.0521576404571533
--------------------------------------------------
epoch :  100
dk mmd:  0.010555597953498363 dk mmd_std:  0.003261633450165391 dk statistic:  3.2261579036712646
--------------------------------------------------
epoch :  100
dk mmd:  0.01042219065129757 dk mmd_std:  0.00412119971588254 dk statistic:  2.5187885761260986
--------------------------------------------------
epoch :  100
dk mmd:  0.010001448914408684 dk mmd_std:  0.0039781262166798115 dk statistic:  2.5039687156677246
-----------------------------

  0%|          | 0/100 [00:00<?, ?it/s]

MMDs:  [-1.71570573e-05  2.94563361e-05 -1.97917689e-05 -6.75115734e-06
 -2.98209488e-06 -8.35210085e-06 -1.88434497e-05 -4.98769805e-06
 -1.71368010e-05  1.68115366e-05  3.62470746e-06  4.47407365e-06
  3.17906961e-05 -1.67035032e-05  1.84951350e-05 -3.38954851e-06
 -4.28175554e-06 -1.14252325e-05 -4.26545739e-06  2.24611722e-05
 -1.69118866e-05  1.42478384e-05  1.32694840e-05  5.79794869e-06
  4.70187515e-05 -4.55998816e-06  1.03423372e-06 -9.10903327e-06
 -1.23165082e-05  7.71437772e-06  6.52414747e-06 -2.78395601e-06
 -1.85449608e-05 -3.35765071e-06  1.54820736e-05 -7.57328235e-06
 -6.38980418e-06 -1.41579658e-05 -1.41165219e-06 -7.28294253e-06
  1.36904418e-07 -9.39820893e-06  9.19052400e-06 -1.69449486e-05
 -2.30551232e-05  1.51386485e-06  6.69690780e-06  1.96609180e-05
  3.51700000e-05  2.60719098e-05  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  

#### Epoch 20 / lr: 0.0005
[[0.23 0.94 0.25 0.37 0.6  0.1  0.36 0.31 0.23 0.4 ]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Baselines (10 times): 
MMD-D:  0.379