In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import os
import sys

# Make the local MMD project directory importable (the `utils_HD` module
# imported in the next cell is presumably located there — confirm).
# NOTE(review): absolute, user-specific path; adjust when running elsewhere.
sys.path.append('/home/oldrain123/MMD/')
# Restrict the process to a single physical GPU. This cell runs before the
# cell that imports torch, so the setting is in place before any CUDA use.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

In [3]:
from tqdm.auto import tqdm
import numpy as np
import jax.numpy as jnp

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms

from utils_HD import MatConvert, MMDu, TST_MMD_u

In [4]:
# Global reproducibility setup: request deterministic cuDNN kernels, then fix
# the NumPy and PyTorch (CPU and CUDA) random seeds to the same value.
torch.backends.cudnn.deterministic = True
np.random.seed(seed=819)
torch.manual_seed(seed=819)
torch.cuda.manual_seed(seed=819)
# Flag kept for compatibility; no other visible cell reads it.
is_cuda = True

In [5]:
# Parameter Settings
n_epochs = 1_000   # passes over the training pairs when fitting the featurizer
batch_size = 512   # mini-batch size of the training DataLoader
lr = 2e-5          # Adam learning rate
img_size = 64      # images are resized to img_size x img_size
channels = 3       # colour channels fed to the first convolution
n = 1_000          # samples drawn per set for training (copied into N1 below)

In [6]:
dtype = torch.float
# Detect CUDA first so that `device` and `cuda` can never disagree: the
# original hard-coded "cuda:0" even when no GPU was available.
cuda = torch.cuda.is_available()
# "cuda:0" is the first *visible* device (see CUDA_VISIBLE_DEVICES above).
device = torch.device("cuda:0" if cuda else "cpu")
N_per = 100 # permutation times
alpha = 0.05 # test threshold
N1 = n # number of samples in one set
K = 10 # number of trials
J = 1 # number of test locations
N = 100 # number of test sets
N_f = float(N) # number of test sets (float), derived so it cannot drift from N

In [7]:
# Naming variables
# Per-trial buffers, one slot per trial index kk in [0, K).
ep_OPT = np.zeros([K])   # epsilon recorded after training
s_OPT = np.zeros([K])    # sigma (= sigmaOPT ** 2) recorded after training
s0_OPT = np.zeros([K])   # sigma_0 (= sigma0OPT ** 2) recorded after training
# Record test locations obtained by MMD-D. The shape follows the configured
# image geometry instead of repeating the literals 3 / 64 / 64.
T_org_OPT = torch.zeros([K, J, channels, img_size, img_size])
DK_Results = np.zeros([1,K])  # rejection rate of each trial

In [8]:
# Define the deep network for MMD-D
class Featurizer(nn.Module):
    """Convolutional feature extractor used as the deep network for MMD-D.

    Four stride-2 convolution blocks (16, 32, 64, 128 filters) are followed by
    a single linear layer that maps the flattened activations to 300 features
    per image.

    Args:
        in_channels: Number of channels of the input images. ``None`` (the
            default) falls back to the notebook-level ``channels`` setting.
        image_size: Height/width of the (square) input images. ``None`` (the
            default) falls back to the notebook-level ``img_size`` setting.
    """

    def __init__(self, in_channels=None, image_size=None):
        super(Featurizer, self).__init__()

        # Keep ``Featurizer()`` working exactly as before by resolving the
        # notebook globals only when the caller did not pass explicit values.
        if in_channels is None:
            in_channels = channels
        if image_size is None:
            image_size = img_size

        def discriminator_block(in_filters, out_filters, bn=True):
            """Conv(k=3, s=2, p=1) -> LeakyReLU -> Dropout2d(0) [-> BatchNorm2d]."""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0)] #0.25
            if bn:
                # NOTE(review): the second positional argument of BatchNorm2d
                # is `eps`, so the original `BatchNorm2d(out_filters, 0.8)`
                # set eps=0.8 (not momentum). The value is kept, spelled out
                # explicitly, so trained results do not change.
                block.append(nn.BatchNorm2d(out_filters, eps=0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(in_channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
            # *discriminator_block(128, 256),
        )

        # The height and width of downsampled image. Each k=3/s=2/p=1
        # convolution maps H -> ceil(H / 2), so apply that once per conv
        # block. For sizes divisible by 2 ** 4 this equals the former
        # `img_size // 2 ** 4`; for other sizes the floor division produced a
        # Linear layer whose input size did not match the conv output.
        conv_layers = [m for m in self.model if isinstance(m, nn.Conv2d)]
        ds_size = image_size
        for _ in conv_layers:
            ds_size = (ds_size + 1) // 2
        last_filters = conv_layers[-1].out_channels
        self.adv_layer = nn.Sequential(
            nn.Linear(last_filters * ds_size ** 2, 300))

    def forward(self, img):
        """Return a (batch, 300) feature tensor for a batch of images."""
        out = self.model(img)
        # Flatten everything except the batch dimension for the linear layer.
        out = out.view(out.shape[0], -1)
        feature = self.adv_layer(out)

        return feature

In [9]:
# Configure data loader
# CIFAR-10 test split, resized to img_size and scaled to [-1, 1] per channel.
dataset_test = datasets.CIFAR10(
    root='/data4/oldrain123/C2ST/data/cifar_data/cifar10',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

dataloader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=10000, shuffle=True, num_workers=1)

# Obtain CIFAR10 images
# The loader is configured to deliver the 10000-image test split as a single
# shuffled batch, so one fetch returns every image together with its label.
data_all, label_all = next(iter(dataloader_test))
Ind_all = np.arange(len(data_all))

# Obtain CIFAR10.1 images
data_new = np.load('/data4/oldrain123/C2ST/data/cifar10_1/cifar10.1_v4_data.npy')
# Move the channel axis forward: (N, H, W, C) -> (N, C, H, W).
data_T = data_new.transpose(0, 3, 1, 2)
# Shuffle the images once (draws from the global NumPy RNG).
ind_M = np.random.choice(len(data_T), len(data_T), replace=False)
data_T = data_T[ind_M]
# Same resize / tensor / normalise pipeline as for CIFAR-10 above.
TT = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trans = transforms.ToPILImage()
data_trans = torch.zeros([len(data_T), 3, img_size, img_size])
data_T_tensor = torch.from_numpy(data_T)
# Round-trip every image through PIL so the torchvision transforms apply.
for idx, chw_img in enumerate(data_T_tensor):
    data_trans[idx] = TT(trans(chw_img))
Ind_v4_all = np.arange(len(data_T))

Files already downloaded and verified


In [10]:
# Repeat experiments K times (K = 10) and report average test power (rejection rate)
for kk in tqdm(range(K)):
    # NOTE(review): trials with kk > 5 are skipped, exactly as in the original
    # cell, so only the first six of the K trials actually run and the
    # remaining DK_Results entries stay 0. The running average printed at the
    # end of each trial divides by (kk + 1), i.e. by the trials done so far.
    if kk > 5:
        continue

    # Per-trial seeds.
    torch.manual_seed(kk * 19 + N1)
    torch.cuda.manual_seed(kk * 19 + N1)
    np.random.seed(seed=1102 * (kk + 10) + N1)

    # Initialize deep networks for MMD-D (called featurizer)
    featurizer = Featurizer()

    # Initialize parameters
    epsilonOPT = torch.log(MatConvert(np.random.rand(1) * 10 ** (-10), device, dtype))
    epsilonOPT.requires_grad = True
    sigmaOPT = MatConvert(np.ones(1) * np.sqrt(2 * 32 * 32), device, dtype)
    sigmaOPT.requires_grad = True
    sigma0OPT = MatConvert(np.ones(1) * np.sqrt(0.005), device, dtype)
    sigma0OPT.requires_grad = True
    if cuda:
        featurizer.cuda()

    # Collect CIFAR10 images: N1 random training indices, the rest for testing.
    Ind_tr = np.random.choice(len(data_all), N1, replace=False)
    Ind_te = np.delete(Ind_all, Ind_tr)
    train_data = [[data_all[i], label_all[i]] for i in Ind_tr]

    dataloader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
    )

    # Collect CIFAR10.1 images
    np.random.seed(seed=819 * (kk + 9) + N1)
    Ind_tr_v4 = np.random.choice(len(data_T), N1, replace=False)
    Ind_te_v4 = np.delete(Ind_v4_all, Ind_tr_v4)
    New_CIFAR_tr = data_trans[Ind_tr_v4]
    New_CIFAR_te = data_trans[Ind_te_v4]

    # Initialize optimizers: network weights and the three kernel parameters
    # are trained jointly.
    optimizer_F = torch.optim.Adam(list(featurizer.parameters()) + [epsilonOPT] + [sigmaOPT] + [sigma0OPT], lr=lr)

    # ----------------------------------------------------------------------------------------------------
    #  Training deep networks for MMD-D (called featurizer)
    # ----------------------------------------------------------------------------------------------------
    np.random.seed(seed=1102)
    torch.manual_seed(1102)
    torch.cuda.manual_seed(1102)
    for epoch in tqdm(range(n_epochs)):
        for imgs, _ in dataloader:
            # Pair the CIFAR-10 batch with an equally sized random batch of
            # CIFAR-10.1 training images.
            ind = np.random.choice(N1, imgs.shape[0], replace=False)
            Fake_imgs = New_CIFAR_tr[ind]

            # Configure input. (The unused `valid` / `fake` / `Y` label
            # tensors and the deprecated `Variable` wrappers were removed;
            # they did not take part in the loss.)
            real_imgs = imgs.to(device=device, dtype=dtype)
            Fake_imgs = Fake_imgs.to(device=device, dtype=dtype)
            X = torch.cat([real_imgs, Fake_imgs], 0)

            # ------------------------------
            #  Train deep network for MMD-D
            # ------------------------------
            # Initialize optimizer
            optimizer_F.zero_grad()
            # Compute output of deep network
            modelu_output = featurizer(X)
            # Compute epsilon, sigma and sigma_0
            ep = torch.exp(epsilonOPT) / (1 + torch.exp(epsilonOPT))
            sigma = sigmaOPT ** 2
            sigma0_u = sigma0OPT ** 2
            # Compute J (STAT_u)
            DK_TEMP = MMDu(modelu_output, imgs.shape[0], X.view(X.shape[0],-1), sigma, sigma0_u, ep, complete=False)
            # NOTE(review): the NumPy round trip creates fresh tensors, so
            # these three terms enter the loss as constants (no gradient flows
            # through them) — confirm this is intended.
            dk_hsic_xx = torch.from_numpy(np.array(DK_TEMP[3])).to(torch.float32)
            dk_hsic_yy = torch.from_numpy(np.array(DK_TEMP[4])).to(torch.float32)
            dk_hsic_xy = torch.from_numpy(np.array(DK_TEMP[5])).to(torch.float32)

            # Negated so that gradient descent maximises MMD / std.
            dk_mmd = -1 * (DK_TEMP[0])
            dk_mmd_std = torch.sqrt(DK_TEMP[1] + 10 ** (-8))

            DK_STAT_u = torch.div(dk_mmd, dk_mmd_std) + (dk_hsic_xx + dk_hsic_yy - dk_hsic_xy)
            if (epoch+1) % 100 == 0:
                print("-" * 50)
                print("epoch : ",epoch+1)
                print("dk mmd: ", -1 * dk_mmd.item(), "dk mmd_std: ", dk_mmd_std.item(), "dk statistic: ",
                -1 * DK_STAT_u.item())

            # Compute gradient
            DK_STAT_u.backward()
            # Update weights using gradient descent
            optimizer_F.step()

    # The optional two-sample test on the training split is disabled; the
    # tensors that were built only for it (and never used) are no longer
    # created. To re-enable it, concatenate data_all[Ind_tr] with
    # data_trans[Ind_tr_v4] and call TST_MMD_u as in the test loop below with
    # N1 samples per set.

    # Record best epsilon, sigma and sigma_0
    # NOTE(review): ep / sigma / sigma0_u are the values computed in the last
    # training iteration, i.e. before the final optimizer step; they are also
    # the values used for testing below. Kept as in the original.
    ep_OPT[kk] = ep.item()
    s_OPT[kk] = sigma.item()
    s0_OPT[kk] = sigma0_u.item()

    # Compute test power of MMD-D and baselines
    DK_H_u = np.zeros(N)

    np.random.seed(1102)
    dk_count_u = 0

    # Loop-invariant test data, hoisted out of the N-iteration loop: the
    # held-out CIFAR-10 pool, the test-set size and the CIFAR-10.1 test images
    # do not change between test sets.
    data_all_te = data_all[Ind_te]
    N_te = len(data_trans)-N1
    s2 = data_trans[Ind_te_v4]

    for k in tqdm(range(N)):
        # Fetch test data
        np.random.seed(seed=1102 * (k + 1) + N1)
        Ind_N_te = np.random.choice(len(Ind_te), N_te, replace=False)
        s1 = data_all_te[Ind_N_te]
        S = torch.cat([s1.cpu(), s2.cpu()], 0).to(device)
        Sv = S.view(2 * N_te, -1)
        # MMD-D
        # No gradients are needed for testing, so skip building the autograd
        # graph for the forward pass. NOTE(review): the featurizer is never
        # switched to eval(), so BatchNorm still uses batch statistics here;
        # kept to leave the reported numbers unchanged.
        with torch.no_grad():
            features_te = featurizer(S)
        dk_h_u, dk_threshold_u, dk_mmd_value_u = TST_MMD_u(features_te, N_per, N_te, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=False)

        # Gather results
        dk_count_u = dk_count_u + dk_h_u
        print("\r","MMD-DK:", dk_count_u, "MMD: ", dk_mmd_value_u, end="")
        DK_H_u[k] = dk_h_u

    # Print test power of MMD-D and baselines
    print("DK Reject rate_u: ", DK_H_u.sum() / N_f)
    DK_Results[0, kk] = DK_H_u.sum() / N_f
    print(f"Test Power of DK ({K} times): ")
    print(f"{DK_Results}")
    print(f"Average Test Power of DK ({K} times): ")
    print("MMD-D: ", (DK_Results.sum(1) / (kk + 1))[0])

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
CUDA backend failed to initialize: Found cuPTI version 18, but JAX was built against version 20, which is newer. The copy of cuPTI that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


mmd2:  5.3442083299160004e-05 	 var:  1.2270756997168064e-06
mmd2:  5.192245589569211e-05 	 var:  1.990687451325357e-06
mmd2:  5.914946086704731e-06 	 var:  1.0268922778777778e-06
mmd2:  9.648088598623872e-05 	 var:  2.3553584469482303e-06
mmd2:  4.830287070944905e-05 	 var:  1.4786201063543558e-06
mmd2:  4.8554560635238886e-05 	 var:  1.941465598065406e-06
mmd2:  6.739539094269276e-05 	 var:  1.5097903087735176e-06
mmd2:  7.796008139848709e-05 	 var:  2.360102371312678e-06
mmd2:  3.0437426175922155e-05 	 var:  1.094886101782322e-06
mmd2:  0.00012391043128445745 	 var:  3.544526407495141e-06
mmd2:  0.00012872030492872 	 var:  2.4683322408236563e-06
mmd2:  9.419111302122474e-05 	 var:  2.031432813964784e-06
mmd2:  9.196670725941658e-05 	 var:  1.6588965081609786e-06
mmd2:  8.279841858893633e-05 	 var:  2.6276902644895017e-06
mmd2:  0.0001607865560799837 	 var:  3.401684807613492e-06
mmd2:  9.228842100128531e-05 	 var:  1.934247848112136e-06
mmd2:  0.00010991084855049849 	 var:  3.400084

  0%|          | 0/100 [00:00<?, ?it/s]

mmd2:  2.2781023290008307e-06 	 var:  2.8584269934488788e-11
 MMD-DK: 0 MMD:  2.2781023290008307e-06mmd2:  5.7640718296170235e-06 	 var:  2.296701999766996e-11
 MMD-DK: 0 MMD:  5.7640718296170235e-06mmd2:  8.58324347063899e-06 	 var:  5.331953093797321e-11
 MMD-DK: 1 MMD:  8.58324347063899e-06mmd2:  -2.419808879494667e-06 	 var:  -4.4900383437615266e-12
 MMD-DK: 1 MMD:  -2.419808879494667e-06mmd2:  5.087640602141619e-06 	 var:  3.163806467644684e-11
 MMD-DK: 1 MMD:  5.087640602141619e-06mmd2:  6.796413799747825e-06 	 var:  1.427531706509786e-11
 MMD-DK: 2 MMD:  6.796413799747825e-06mmd2:  2.332933945581317e-06 	 var:  1.6935959445295243e-11
 MMD-DK: 2 MMD:  2.332933945581317e-06mmd2:  5.937152309343219e-06 	 var:  3.828609738774679e-11
 MMD-DK: 2 MMD:  5.937152309343219e-06mmd2:  8.413044270128012e-06 	 var:  5.100920373388276e-11
 MMD-DK: 3 MMD:  8.413044270128012e-06mmd2:  6.1409082263708115e-06 	 var:  2.476375480944316e-11
 MMD-DK: 3 MMD:  6.1409082263708115e-06mmd2:  8.78472928889

  0%|          | 0/1000 [00:00<?, ?it/s]

mmd2:  -1.1644558981060982e-05 	 var:  2.914417564170435e-06
mmd2:  9.932788088917732e-06 	 var:  1.2759555829688907e-06
mmd2:  -1.5848316252231598e-05 	 var:  2.9822949727531523e-06
mmd2:  2.279318869113922e-05 	 var:  1.882915967144072e-06
mmd2:  3.3921562135219574e-05 	 var:  2.831358870025724e-06
mmd2:  9.453389793634415e-06 	 var:  3.629218554124236e-06
mmd2:  0.00012146402150392532 	 var:  8.153961971402168e-06
mmd2:  0.0003063729964196682 	 var:  1.1833144526463002e-05
mmd2:  4.116399213671684e-05 	 var:  1.8618011381477118e-06
mmd2:  0.00010666670277714729 	 var:  6.508998922072351e-06
mmd2:  0.00010850420221686363 	 var:  4.703164449892938e-06
mmd2:  0.00010032625868916512 	 var:  4.0700906538404524e-06
mmd2:  6.8712979555130005e-06 	 var:  2.461183612467721e-06
mmd2:  8.580973371863365e-05 	 var:  3.937355359084904e-06
mmd2:  1.4271121472120285e-05 	 var:  1.9228718883823603e-06
mmd2:  6.251735612750053e-05 	 var:  4.984212864656001e-06
mmd2:  0.00016192346811294556 	 var:  5

  0%|          | 0/100 [00:00<?, ?it/s]

mmd2:  8.484814316034317e-06 	 var:  3.579030157293413e-11
 MMD-DK: 1 MMD:  8.484814316034317e-06mmd2:  1.7442565876990557e-05 	 var:  7.302161273393213e-11
 MMD-DK: 2 MMD:  1.7442565876990557e-05mmd2:  1.879414776340127e-05 	 var:  9.129925327116089e-11
 MMD-DK: 3 MMD:  1.879414776340127e-05mmd2:  2.7129019144922495e-05 	 var:  1.4287469775006066e-10
 MMD-DK: 4 MMD:  2.7129019144922495e-05mmd2:  2.969906199723482e-05 	 var:  1.458391885239507e-10
 MMD-DK: 5 MMD:  2.969906199723482e-05mmd2:  1.0577670764178038e-05 	 var:  4.9854097181332615e-11
 MMD-DK: 6 MMD:  1.0577670764178038e-05mmd2:  1.2754346244037151e-05 	 var:  4.8581338185031494e-11
 MMD-DK: 7 MMD:  1.2754346244037151e-05mmd2:  2.002331893891096e-05 	 var:  8.097272590303441e-11
 MMD-DK: 8 MMD:  2.002331893891096e-05mmd2:  1.714227255433798e-05 	 var:  5.796209064869591e-11
 MMD-DK: 9 MMD:  1.714227255433798e-05mmd2:  2.4064502213150263e-05 	 var:  9.42715289290946e-11
 MMD-DK: 10 MMD:  2.4064502213150263e-05mmd2:  1.40303163

  0%|          | 0/1000 [00:00<?, ?it/s]

mmd2:  -9.445357136428356e-05 	 var:  9.38682205742225e-07
mmd2:  -1.7430167645215988e-05 	 var:  6.463465979322791e-06
mmd2:  -3.453134559094906e-05 	 var:  2.4192704586312175e-06
mmd2:  -4.1351886466145515e-05 	 var:  2.656983269844204e-06
mmd2:  -4.136376082897186e-05 	 var:  2.7056230464950204e-06
mmd2:  -1.4712568372488022e-06 	 var:  3.0719093047082424e-06
mmd2:  -2.6677269488573074e-05 	 var:  2.5374683900736272e-06
mmd2:  3.722985275089741e-05 	 var:  2.587679773569107e-06
mmd2:  -4.799314774572849e-05 	 var:  2.6915477064903826e-06
mmd2:  4.670256748795509e-05 	 var:  4.7295980039052665e-06
mmd2:  9.527383372187614e-05 	 var:  8.876028005033731e-06
mmd2:  3.303214907646179e-05 	 var:  4.490932042244822e-06
mmd2:  -3.241468220949173e-06 	 var:  1.6118247003760189e-06
mmd2:  -5.9845391660928726e-05 	 var:  2.5013505364768207e-06
mmd2:  5.100807175040245e-05 	 var:  8.123541192617267e-06
mmd2:  3.2373471185564995e-05 	 var:  2.489825419615954e-06
mmd2:  1.2691598385572433e-05 	 v

  0%|          | 0/100 [00:00<?, ?it/s]

mmd2:  1.6588717699050903e-05 	 var:  3.5888334451348445e-11
 MMD-DK: 1 MMD:  1.6588717699050903e-05mmd2:  1.6968842828646302e-05 	 var:  4.391240469575462e-11
 MMD-DK: 2 MMD:  1.6968842828646302e-05mmd2:  1.2095901183784008e-05 	 var:  4.200230678528157e-11
 MMD-DK: 3 MMD:  1.2095901183784008e-05mmd2:  1.1006079148501158e-05 	 var:  2.267737642613892e-11
 MMD-DK: 4 MMD:  1.1006079148501158e-05mmd2:  1.9751809304580092e-05 	 var:  5.6236578761691825e-11
 MMD-DK: 5 MMD:  1.9751809304580092e-05mmd2:  1.2304371921345592e-05 	 var:  2.6778152752020006e-11
 MMD-DK: 6 MMD:  1.2304371921345592e-05mmd2:  7.49467290006578e-06 	 var:  2.5834614399998776e-11
 MMD-DK: 6 MMD:  7.49467290006578e-06mmd2:  1.9278813851997256e-05 	 var:  5.7764178523516365e-11
 MMD-DK: 7 MMD:  1.9278813851997256e-05mmd2:  1.2511183740571141e-05 	 var:  2.7538321130822798e-11
 MMD-DK: 8 MMD:  1.2511183740571141e-05mmd2:  1.3593671610578895e-05 	 var:  3.4829019848408187e-11
 MMD-DK: 9 MMD:  1.3593671610578895e-05mmd2:  

  0%|          | 0/1000 [00:00<?, ?it/s]

mmd2:  -4.00543212890625e-05 	 var:  1.680977948126383e-05
mmd2:  0.00015807710587978363 	 var:  3.3258394978474826e-05
mmd2:  0.0003879610449075699 	 var:  6.032284727552906e-05
mmd2:  0.00013324618339538574 	 var:  2.0149745978415012e-05
mmd2:  0.00016357749700546265 	 var:  2.892308111768216e-05
mmd2:  4.066154360771179e-05 	 var:  1.7281388863921165e-05
mmd2:  0.00020493380725383759 	 var:  3.132070560241118e-05
mmd2:  0.0002534892410039902 	 var:  4.8257912567351013e-05
mmd2:  0.00012254342436790466 	 var:  3.0168623197823763e-05
mmd2:  0.00022895447909832 	 var:  3.7408695789054036e-05
mmd2:  8.367002010345459e-05 	 var:  2.3528824385721236e-05
mmd2:  0.0002500899136066437 	 var:  4.242744034854695e-05
mmd2:  0.00019911490380764008 	 var:  3.3187709050253034e-05
mmd2:  0.0005015507340431213 	 var:  5.424470145953819e-05
mmd2:  0.00012610666453838348 	 var:  1.831366535043344e-05
mmd2:  0.000187702476978302 	 var:  3.3999393053818494e-05
mmd2:  0.00035967491567134857 	 var:  4.728

  0%|          | 0/100 [00:00<?, ?it/s]

mmd2:  7.4160052463412285e-06 	 var:  5.072137136254216e-11
 MMD-DK: 0 MMD:  7.4160052463412285e-06mmd2:  1.0254501830786467e-05 	 var:  6.260952476288605e-11
 MMD-DK: 1 MMD:  1.0254501830786467e-05mmd2:  3.0435039661824703e-06 	 var:  2.5232730408963958e-11
 MMD-DK: 1 MMD:  3.0435039661824703e-06mmd2:  1.5338649973273277e-05 	 var:  7.205646025740566e-11
 MMD-DK: 2 MMD:  1.5338649973273277e-05mmd2:  1.180986873805523e-05 	 var:  5.65145148703843e-11
 MMD-DK: 3 MMD:  1.180986873805523e-05mmd2:  2.3914442863315344e-05 	 var:  1.5510703267983267e-10
 MMD-DK: 4 MMD:  2.3914442863315344e-05mmd2:  7.692782673984766e-06 	 var:  4.0738134953567666e-11
 MMD-DK: 4 MMD:  7.692782673984766e-06mmd2:  1.113285543397069e-05 	 var:  2.9856762018409886e-11
 MMD-DK: 4 MMD:  1.113285543397069e-05mmd2:  1.3665761798620224e-05 	 var:  3.3687950214391574e-11
 MMD-DK: 5 MMD:  1.3665761798620224e-05mmd2:  1.6593257896602154e-05 	 var:  1.092432719367595e-10
 MMD-DK: 6 MMD:  1.6593257896602154e-05mmd2:  1.250

  0%|          | 0/1000 [00:00<?, ?it/s]

mmd2:  0.00025306129828095436 	 var:  1.8044549506157637e-05
mmd2:  4.773959517478943e-06 	 var:  3.7134959711693227e-06
mmd2:  -1.6445759683847427e-05 	 var:  4.338595317676663e-06
mmd2:  -2.3268163204193115e-05 	 var:  3.0006340239197016e-06
mmd2:  -5.4784584790468216e-05 	 var:  3.199424099875614e-06
mmd2:  0.0001026540994644165 	 var:  7.770431693643332e-06
mmd2:  3.858562558889389e-05 	 var:  7.483926310669631e-06
mmd2:  4.8148445785045624e-05 	 var:  6.100119207985699e-06
mmd2:  -3.516674041748047e-05 	 var:  4.513945896178484e-06
mmd2:  -6.672926247119904e-07 	 var:  2.5884000933729112e-06
mmd2:  -2.8927810490131378e-05 	 var:  3.101642505498603e-06
mmd2:  0.00021622143685817719 	 var:  1.353980042040348e-05
mmd2:  3.8714148104190826e-05 	 var:  3.4682961995713413e-06
mmd2:  0.00031133322045207024 	 var:  2.1417370589915663e-05
mmd2:  0.000138190109282732 	 var:  6.347152520902455e-06
mmd2:  5.131121724843979e-05 	 var:  6.991831469349563e-06
mmd2:  0.00027754809707403183 	 var:

  0%|          | 0/100 [00:00<?, ?it/s]

mmd2:  7.856055162847042e-06 	 var:  4.5046087887803845e-11
 MMD-DK: 0 MMD:  7.856055162847042e-06mmd2:  1.0298361303284764e-05 	 var:  2.0953168618753758e-11
 MMD-DK: 1 MMD:  1.0298361303284764e-05mmd2:  2.2640248062089086e-05 	 var:  7.732780225036104e-11
 MMD-DK: 2 MMD:  2.2640248062089086e-05mmd2:  2.0423787645995617e-05 	 var:  8.178197905908156e-11
 MMD-DK: 3 MMD:  2.0423787645995617e-05mmd2:  1.9130297005176544e-05 	 var:  6.842004850789011e-11
 MMD-DK: 4 MMD:  1.9130297005176544e-05mmd2:  1.1524301953613758e-05 	 var:  4.122343529328745e-11
 MMD-DK: 5 MMD:  1.1524301953613758e-05mmd2:  1.3566634152084589e-05 	 var:  7.29417626849308e-11
 MMD-DK: 6 MMD:  1.3566634152084589e-05mmd2:  1.6292440705001354e-05 	 var:  6.19769795841755e-11
 MMD-DK: 7 MMD:  1.6292440705001354e-05mmd2:  3.279210068285465e-05 	 var:  1.7838705674870888e-10
 MMD-DK: 8 MMD:  3.279210068285465e-05mmd2:  5.979236448183656e-06 	 var:  2.4005238303802366e-11
 MMD-DK: 8 MMD:  5.979236448183656e-06mmd2:  8.55035

  0%|          | 0/1000 [00:00<?, ?it/s]

mmd2:  -9.501539170742035e-05 	 var:  5.310666892910376e-06
mmd2:  0.00015288032591342926 	 var:  1.1653180990833789e-05
mmd2:  0.00012441398575901985 	 var:  9.847877663560212e-06
mmd2:  0.00014861347153782845 	 var:  1.1681193427648395e-05
mmd2:  0.0001563667319715023 	 var:  1.4722034393344074e-05
mmd2:  2.7795322239398956e-06 	 var:  7.888949767220765e-06
mmd2:  0.0004232097417116165 	 var:  3.6618301237467676e-05
mmd2:  9.419210255146027e-05 	 var:  6.844988092780113e-06
mmd2:  1.0007526725530624e-05 	 var:  1.0947289410978556e-05
mmd2:  8.51089134812355e-05 	 var:  9.777933883015066e-06
mmd2:  0.00014300737529993057 	 var:  9.925730410031974e-06
mmd2:  0.0002872804179787636 	 var:  2.389666042290628e-05
mmd2:  0.00018343515694141388 	 var:  1.2997639714740217e-05
mmd2:  0.00017143599689006805 	 var:  1.5308891306631267e-05
mmd2:  0.0002113720402121544 	 var:  1.3485303497873247e-05
mmd2:  0.00030031008645892143 	 var:  2.3562002752441913e-05
mmd2:  0.0001445012167096138 	 var:  1

  0%|          | 0/100 [00:00<?, ?it/s]

mmd2:  6.787420716136694e-06 	 var:  4.450280039255586e-11
 MMD-DK: 0 MMD:  6.787420716136694e-06mmd2:  1.1402240488678217e-05 	 var:  4.531981535242569e-11
 MMD-DK: 1 MMD:  1.1402240488678217e-05mmd2:  1.1419644579291344e-05 	 var:  4.1678836813546574e-11
 MMD-DK: 2 MMD:  1.1419644579291344e-05mmd2:  7.813883712515235e-06 	 var:  3.0670154627387945e-11
 MMD-DK: 3 MMD:  7.813883712515235e-06mmd2:  7.639988325536251e-06 	 var:  4.2489349164781124e-11
 MMD-DK: 4 MMD:  7.639988325536251e-06mmd2:  1.5033525414764881e-05 	 var:  3.6064539519262984e-11
 MMD-DK: 5 MMD:  1.5033525414764881e-05mmd2:  1.1706550139933825e-05 	 var:  5.3747515530962235e-11
 MMD-DK: 6 MMD:  1.1706550139933825e-05mmd2:  1.1056603398174047e-05 	 var:  4.985149785605185e-11
 MMD-DK: 7 MMD:  1.1056603398174047e-05mmd2:  6.156507879495621e-06 	 var:  1.7781118054010672e-11
 MMD-DK: 7 MMD:  6.156507879495621e-06mmd2:  1.426198286935687e-05 	 var:  5.2735043030937626e-11
 MMD-DK: 8 MMD:  1.426198286935687e-05mmd2:  1.8999

#### Epoch 20 / lr: 0.0005
[[0.23 0.94 0.25 0.37 0.6  0.1  0.36 0.31 0.23 0.4 ]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Baselines (10 times): 
MMD-D:  0.379

In [11]:
# Mean of ten manually entered values (displayed as the cell output).
# NOTE(review): these numbers are typed in by hand and do not match the
# DK_Results array printed above — record where they come from, or compute
# the mean from the stored results instead.
np.mean([0.25, 0.93, 0.94, 0.68, 0.8, 0.6, 0.93, 0.77, 0.78, 0.83])

0.7510000000000001