In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import os

# Expose only GPU 0 to this process. This is set before torch/jax are imported
# so that CUDA picks it up when it initializes.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [3]:
from tqdm.auto import tqdm
import numpy as np
import jax.numpy as jnp

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms

from utils_HD import MatConvert, MMDu, TST_MMD_u

In [4]:
# Global seeding for reproducibility (torch CPU, torch CUDA, NumPy).
# Per-trial seeds are set again inside the experiment loop.
torch.manual_seed(819)
torch.cuda.manual_seed(819)
np.random.seed(819)
torch.backends.cudnn.deterministic = True  # request deterministic cuDNN kernels
is_cuda = True

In [5]:
# Parameter settings
n_epochs = 300   # featurizer training epochs per trial
batch_size = 100  # mini-batch size for the CIFAR10 loader (CIFAR10.1 batch is matched to it)
lr = 2e-5        # Adam learning rate
img_size = 64    # target size passed to transforms.Resize
channels = 3     # input image channels (RGB)
n = 1000         # number of training samples drawn from each dataset

In [6]:
dtype = torch.float
# Pick the GPU when available, otherwise fall back to CPU instead of hard-coding "cuda:0".
cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda else "cpu")
N_per = 100  # permutation times
alpha = 0.05  # test threshold
N1 = n  # number of samples in one set
K = 10  # number of trials
J = 1  # number of test locations
N = 100  # number of test sets
N_f = float(N)  # number of test sets (float); derived from N so the two cannot drift apart

In [7]:
# Per-trial records of the learned kernel parameters (epsilon, sigma, sigma_0)
ep_OPT = np.zeros([K])
s_OPT = np.zeros([K])
s0_OPT = np.zeros([K])
# Record test locations obtained by MMD-D; the image dimensions follow the
# notebook settings instead of a hard-coded 3 x 64 x 64.
T_org_OPT = torch.zeros([K, J, channels, img_size, img_size])
# Test power (rejection rate) of each of the K trials
COM_Results = np.zeros([1, K])

In [8]:
# Define the deep network for MMD-D
class Featurizer_COM(nn.Module):
    """Convolutional featurizer used as the deep kernel network for MMD-D.

    Four stride-2 conv blocks (in_channels -> 16 -> 32 -> 64 -> 128) followed by
    a linear layer that maps each image to a ``feature_dim``-dimensional vector.

    Args:
        in_channels: number of input image channels. Defaults to the
            notebook-level ``channels`` setting.
        image_size: input height/width (images are assumed square). Defaults
            to the notebook-level ``img_size`` setting.
        feature_dim: size of the output feature vector (default 300).
    """

    def __init__(self, in_channels=None, image_size=None, feature_dim=300):
        super().__init__()
        if in_channels is None:
            in_channels = channels
        if image_size is None:
            image_size = img_size

        def discriminator_block(in_filters, out_filters, bn=True):
            # Conv(3x3, stride 2, padding 1) halves the spatial size, rounding up.
            # Dropout2d(0) is a no-op kept from the original architecture (was 0.25).
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0)]
            if bn:
                # NOTE(review): 0.8 was originally passed positionally, which
                # BatchNorm2d interprets as ``eps`` (not ``momentum``). It is kept
                # as eps=0.8 to preserve the trained behavior.
                block.append(nn.BatchNorm2d(out_filters, eps=0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(in_channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Spatial size after the four stride-2 convolutions. Each conv maps
        # s -> (s + 2*1 - 3) // 2 + 1 == ceil(s / 2). The former `img_size // 2**4`
        # only agreed with this when the size was a multiple of 16.
        ds_size = image_size
        for _ in range(4):
            ds_size = (ds_size + 2 * 1 - 3) // 2 + 1
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, feature_dim))

    def forward(self, img):
        """Map a batch of images to feature vectors.

        Args:
            img: image batch, presumably (batch, in_channels, image_size, image_size).

        Returns:
            Tensor of shape (batch, feature_dim).
        """
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        return self.adv_layer(out)

In [9]:
# Configure data loader
# NOTE(review): machine-specific absolute paths; point these at your local copies.
CIFAR10_ROOT = '/data4/oldrain123/C2ST/data/cifar_data/cifar10'
CIFAR10_1_PATH = '/data4/oldrain123/C2ST/data/cifar10_1/cifar10.1_v4_data.npy'

# One preprocessing pipeline shared by CIFAR10 and CIFAR10.1, so both datasets
# are transformed identically: resize, convert to tensor, normalize each channel
# with mean 0.5 / std 0.5.
TT = transforms.Compose([transforms.Resize(img_size), transforms.ToTensor(),
                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

dataset_test = datasets.CIFAR10(root=CIFAR10_ROOT, download=True, train=False, transform=TT)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=10000,
                                              shuffle=True, num_workers=1)

# Obtain CIFAR10 images. Concatenate every batch: the previous loop kept only the
# last batch, which silently dropped data whenever the set exceeded batch_size.
batches = list(dataloader_test)
data_all = torch.cat([batch[0] for batch in batches], 0)
label_all = torch.cat([batch[1] for batch in batches], 0)
del batches
Ind_all = np.arange(len(data_all))

# Obtain CIFAR10.1 images
data_new = np.load(CIFAR10_1_PATH)
data_T = np.transpose(data_new, [0, 3, 1, 2])  # presumably NHWC -> NCHW; TODO confirm file layout
ind_M = np.random.choice(len(data_T), len(data_T), replace=False)  # random permutation of the images
data_T = data_T[ind_M]
trans = transforms.ToPILImage()
data_T_tensor = torch.from_numpy(data_T)
# Round-trip each image through PIL so it gets exactly the same TT pipeline as CIFAR10.
data_trans = torch.stack([TT(trans(img)) for img in data_T_tensor])
Ind_v4_all = np.arange(len(data_T))

Files already downloaded and verified


In [10]:
# Repeat experiments K times (K = 10) and report average test power (rejection rate)
for kk in tqdm(range(K)):
    # Per-trial seeds: each of the K trials uses a different but reproducible split.
    torch.manual_seed(kk * 19 + N1)
    torch.cuda.manual_seed(kk * 19 + N1)
    np.random.seed(seed=1102 * (kk + 10) + N1)

    # Initialize deep networks for MMD-D (called featurizer)
    featurizer_com = Featurizer_COM()

    # Initialize kernel parameters in an unconstrained parameterization; the
    # values actually used are derived in the training loop
    # (ep = sigmoid(epsilonOPT), sigma = sigmaOPT**2, sigma0_u = sigma0OPT**2).
    epsilonOPT = torch.log(MatConvert(np.random.rand(1) * 10 ** (-10), device, dtype))
    epsilonOPT.requires_grad = True
    sigmaOPT = MatConvert(np.ones(1) * np.sqrt(2 * 32 * 32), device, dtype)
    sigmaOPT.requires_grad = True
    sigma0OPT = MatConvert(np.ones(1) * np.sqrt(0.005), device, dtype)
    sigma0OPT.requires_grad = True
    if cuda:
        featurizer_com.cuda()

    # Collect N1 CIFAR10 training images; the remaining indices form the test pool.
    Ind_tr = np.random.choice(len(data_all), N1, replace=False)
    Ind_te = np.delete(Ind_all, Ind_tr)
    train_data = [[data_all[i], label_all[i]] for i in Ind_tr]
    dataloader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
    )

    # Collect N1 CIFAR10.1 training images; the rest are held out for testing.
    np.random.seed(seed=819 * (kk + 9) + N1)
    Ind_tr_v4 = np.random.choice(len(data_T), N1, replace=False)
    Ind_te_v4 = np.delete(Ind_v4_all, Ind_tr_v4)
    New_CIFAR_tr = data_trans[Ind_tr_v4]
    New_CIFAR_te = data_trans[Ind_te_v4]

    # Initialize optimizer over the network weights and the three kernel parameters
    optimizer_COM = torch.optim.Adam(
        list(featurizer_com.parameters()) + [epsilonOPT, sigmaOPT, sigma0OPT], lr=lr
    )

    # ----------------------------------------------------------------------------------------------------
    #  Training deep networks for MMD-D (called featurizer)
    # ----------------------------------------------------------------------------------------------------
    np.random.seed(seed=1102)
    torch.manual_seed(1102)
    torch.cuda.manual_seed(1102)
    for epoch in tqdm(range(n_epochs)):
        for i, (imgs, _) in enumerate(dataloader):
            # Pair this CIFAR10 batch with an equally sized random CIFAR10.1 batch.
            ind = np.random.choice(N1, imgs.shape[0], replace=False)
            real_imgs = imgs.to(device=device, dtype=dtype)
            Fake_imgs = New_CIFAR_tr[ind].to(device=device, dtype=dtype)
            X = torch.cat([real_imgs, Fake_imgs], 0)

            optimizer_COM.zero_grad()
            # Compute output of deep network
            com_modelu_output = featurizer_com(X)
            # Compute epsilon, sigma and sigma_0
            ep = torch.exp(epsilonOPT) / (1 + torch.exp(epsilonOPT))
            sigma = sigmaOPT ** 2
            sigma0_u = sigma0OPT ** 2
            # Compute J (STAT_u) = -MMD / std
            COM_TEMP = MMDu(com_modelu_output, imgs.shape[0], X.view(X.shape[0], -1), sigma, sigma0_u, ep, complete=True)
            com_mmd = -1 * (COM_TEMP[0])
            # NOTE(review): the variance estimate is routed through NumPy and
            # torch.from_numpy, so it carries no autograd history. Gradients flow
            # only through the MMD term (effectively grad(MMD) / std). Confirm this
            # is intended.
            com_mmd_std_tmp = np.abs(COM_TEMP[1])
            com_mmd_std_tmp_tensor = torch.from_numpy(np.array(com_mmd_std_tmp)).to(torch.float32)
            # NOTE(review): no epsilon is added here (cf. sqrt(var + 1e-8)); a zero
            # variance estimate yields an infinite statistic.
            com_mmd_std = torch.sqrt(com_mmd_std_tmp_tensor)
            COM_STAT_u = torch.div(com_mmd, com_mmd_std)

            # A negative variance estimate forces the objective to be negative.
            if COM_TEMP[1] < 0:
                COM_STAT_u = - torch.abs(COM_STAT_u)

            if (epoch + 1) % 100 == 0:
                print("=" * 110)
                print("epoch : ", epoch + 1)
                print("our mmd: ", -1 * com_mmd.item(), "our mmd_std: ", com_mmd_std.item(), "our statistic: ",
                      -1 * COM_STAT_u.item())

            # Compute gradient and update weights (minimizing -J maximizes the criterion)
            COM_STAT_u.backward()
            optimizer_COM.step()

    # Record epsilon, sigma and sigma_0 from the last optimization step
    ep_OPT[kk] = ep.item()
    s_OPT[kk] = sigma.item()
    s0_OPT[kk] = sigma0_u.item()

    # ----------------------------------------------------------------------------------------------------
    #  Test power of MMD-D over N held-out test sets
    # ----------------------------------------------------------------------------------------------------
    COM_H_u = np.zeros(N)
    com_count_u = 0

    # The held-out pools do not depend on k, so build them once per trial.
    data_all_te = data_all[Ind_te]
    N_te = len(data_trans) - N1
    s2 = data_trans[Ind_te_v4]

    for k in tqdm(range(N)):
        # Fetch test data: a fresh random CIFAR10 subset vs. the fixed CIFAR10.1 hold-out
        np.random.seed(seed=1102 * (k + 1) + N1)
        Ind_N_te = np.random.choice(len(Ind_te), N_te, replace=False)
        s1 = data_all_te[Ind_N_te]
        S = torch.cat([s1.cpu(), s2.cpu()], 0).to(device)
        Sv = S.view(2 * N_te, -1)
        # MMD-D; no gradients are needed at test time, so skip building the graph.
        with torch.no_grad():
            com_h_u, com_threshold_u, com_mmd_value_u = TST_MMD_u(featurizer_com(S), N_per, N_te, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=True)

        # Gather results
        com_count_u = com_count_u + com_h_u
        print("\r", "Ours:", com_count_u, "MMD: ", com_mmd_value_u, end="")
        COM_H_u[k] = com_h_u

    # Print test power of MMD-D
    print("Our Reject rate_u: ", COM_H_u.sum() / N_f)
    COM_Results[0, kk] = COM_H_u.sum() / N_f
    print(f"Test Power of Ours ({K} times): ")
    print(f"{COM_Results}")
    print(f"Average Test Power of Ours ({K} times): ")
    print("Ours: ", (COM_Results.sum(1) / (kk + 1))[0])

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/300 [00:00<?, ?it/s]

  valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
CUDA backend failed to initialize: Found cuPTI version 18, but JAX was built against version 20, which is newer. The copy of cuPTI that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


epoch :  100
our mmd:  0.002233915962278843 our mmd_std:  0.0012916913256049156 our statistic:  1.7294503450393677
epoch :  100
our mmd:  0.0061654821038246155 our mmd_std:  0.0025289563927799463 our statistic:  2.437954902648926
epoch :  100
our mmd:  0.008945474401116371 our mmd_std:  0.003397854743525386 our statistic:  2.632683038711548
epoch :  100
our mmd:  0.007460477761924267 our mmd_std:  0.0029259975999593735 our statistic:  2.5497212409973145
epoch :  100
our mmd:  0.015224603936076164 our mmd_std:  0.004931608680635691 our statistic:  3.0871477127075195
epoch :  100
our mmd:  0.010921857319772243 our mmd_std:  0.003300221636891365 our statistic:  3.309431314468384
epoch :  100
our mmd:  0.010709687136113644 our mmd_std:  0.0037687444128096104 our statistic:  2.841712236404419
epoch :  100
our mmd:  0.004723799414932728 our mmd_std:  0.001822067890316248 our statistic:  2.5925486087799072
epoch :  100
our mmd:  0.008850125595927238 our mmd_std:  0.002725842874497175 our stat

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 15 MMD:  -7.656402885913849e-06Our Reject rate_u:  0.15
Test Power of Ours (10 times): 
[[0.15 0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.15


  0%|          | 0/300 [00:00<?, ?it/s]

epoch :  100
our mmd:  5.133450031280518e-06 our mmd_std:  0.0010753674432635307 our statistic:  0.004773670807480812
epoch :  100
our mmd:  0.0012679770588874817 our mmd_std:  0.0017069708555936813 our statistic:  0.7428228855133057
epoch :  100
our mmd:  0.0008050538599491119 our mmd_std:  0.0008332578581757843 our statistic:  0.9661520719528198
epoch :  100
our mmd:  0.0022474341094493866 our mmd_std:  0.0010122522944584489 our statistic:  2.220231056213379
epoch :  100
our mmd:  0.0034379251301288605 our mmd_std:  0.002201561816036701 our statistic:  1.5615845918655396
epoch :  100
our mmd:  -0.0003032572567462921 our mmd_std:  0.0007983584655448794 our statistic:  0.37985101342201233
epoch :  100
our mmd:  0.0005818940699100494 our mmd_std:  0.0001591578620718792 our statistic:  3.656081438064575
epoch :  100
our mmd:  0.001464594155550003 our mmd_std:  0.0013382163597270846 our statistic:  1.0944374799728394
epoch :  100
our mmd:  0.0014075953513383865 our mmd_std:  0.00153239734

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 90 MMD:  0.0003214627504348755Our Reject rate_u:  0.9
Test Power of Ours (10 times): 
[[0.15 0.9  0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.525


  0%|          | 0/300 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.0010426081717014313 our mmd_std:  0.0015849281335249543 our statistic:  0.6578267812728882
epoch :  100
our mmd:  -0.0008924342691898346 our mmd_std:  0.0014330695848912 our statistic:  0.6227431297302246
epoch :  100
our mmd:  -0.0010189376771450043 our mmd_std:  0.0014203991740942001 our statistic:  0.7173601388931274
epoch :  100
our mmd:  -0.0020316168665885925 our mmd_std:  0.0016329704085364938 our statistic:  1.2441235780715942
epoch :  100
our mmd:  -0.0014815479516983032 our mmd_std:  0.0014338763430714607 our statistic:  1.033246636390686
epoch :  100
our mmd:  -0.0009102225303649902 our mmd_std:  0.0016061392379924655 our statistic:  0.5667145848274231
epoch :  100
our mmd:  0.00042654573917388916 our mmd_std:  0.0008012385223992169 our statistic:  0.5323579907417297
epoch :  100
our mmd:  0.00028518959879875183 our mmd_std:  0.0010714695090427995 our statistic:  0.2661667764186859
epoch :  100
our mmd:  0.0007398054003715515 our mmd_std:  0.00177291

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 63 MMD:  0.00016325898468494415Our Reject rate_u:  0.63
Test Power of Ours (10 times): 
[[0.15 0.9  0.63 0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.56


  0%|          | 0/300 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.006182124838232994 our mmd_std:  0.002318655839189887 our statistic:  2.6662538051605225
epoch :  100
our mmd:  0.0018658973276615143 our mmd_std:  0.0010087810223922133 our statistic:  1.8496555089950562
epoch :  100
our mmd:  0.004634648561477661 our mmd_std:  0.002154272049665451 our statistic:  2.1513757705688477
epoch :  100
our mmd:  -0.00037083961069583893 our mmd_std:  0.00042913536890409887 our statistic:  -0.8641553521156311
epoch :  100
our mmd:  0.00324944406747818 our mmd_std:  0.0019304367015138268 our statistic:  1.6832689046859741
epoch :  100
our mmd:  0.0031028706580400467 our mmd_std:  0.0019154839683324099 our statistic:  1.6198886632919312
epoch :  100
our mmd:  0.00477607361972332 our mmd_std:  0.0022627946455031633 our statistic:  2.110696792602539
epoch :  100
our mmd:  0.002135038375854492 our mmd_std:  0.001156988088041544 our statistic:  1.845341682434082
epoch :  100
our mmd:  0.006224973127245903 our mmd_std:  0.003057340858504176 o

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 38 MMD:  0.00032260268926620483Our Reject rate_u:  0.38
Test Power of Ours (10 times): 
[[0.15 0.9  0.63 0.38 0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.515


  0%|          | 0/300 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.0018117539584636688 our mmd_std:  0.0009340160177089274 our statistic:  1.9397462606430054
epoch :  100
our mmd:  0.0028307102620601654 our mmd_std:  0.0021101858001202345 our statistic:  1.3414506912231445
epoch :  100
our mmd:  0.00015755370259284973 our mmd_std:  0.0012205529492348433 our statistic:  0.12908387184143066
epoch :  100
our mmd:  0.00424872525036335 our mmd_std:  0.0029499188531190157 our statistic:  1.4402854442596436
epoch :  100
our mmd:  0.0017331354320049286 our mmd_std:  0.0016122752567753196 our statistic:  1.0749624967575073
epoch :  100
our mmd:  0.0055418238043785095 our mmd_std:  0.002492387080565095 our statistic:  2.2235004901885986
epoch :  100
our mmd:  3.5878270864486694e-05 our mmd_std:  0.0003524832136463374 our statistic:  0.10178717970848083
epoch :  100
our mmd:  0.001145884394645691 our mmd_std:  0.0010696843964979053 our statistic:  1.0712358951568604
epoch :  100
our mmd:  0.0009268559515476227 our mmd_std:  0.00095992564

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 57 MMD:  0.00021762773394584656Our Reject rate_u:  0.57
Test Power of Ours (10 times): 
[[0.15 0.9  0.63 0.38 0.57 0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.526


  0%|          | 0/300 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.0030839089304208755 our mmd_std:  0.0013473297003656626 our statistic:  2.2889044284820557
epoch :  100
our mmd:  0.0135405994951725 our mmd_std:  0.003913909196853638 our statistic:  3.4596099853515625
epoch :  100
our mmd:  0.010842767544090748 our mmd_std:  0.0037727367598563433 our statistic:  2.873979330062866
epoch :  100
our mmd:  0.0038216132670640945 our mmd_std:  0.0018062187591567636 our statistic:  2.1158087253570557
epoch :  100
our mmd:  0.003666697070002556 our mmd_std:  0.0017732653068378568 our statistic:  2.067765712738037
epoch :  100
our mmd:  0.003311866894364357 our mmd_std:  0.001550891320221126 our statistic:  2.135460376739502
epoch :  100
our mmd:  0.00861006136983633 our mmd_std:  0.0033876586239784956 our statistic:  2.5415964126586914
epoch :  100
our mmd:  0.006055695004761219 our mmd_std:  0.0028386611957103014 our statistic:  2.1332926750183105
epoch :  100
our mmd:  0.010219930671155453 our mmd_std:  0.0037592004518955946 our st

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 3 MMD:  -4.0675513446331024e-05Our Reject rate_u:  0.03
Test Power of Ours (10 times): 
[[0.15 0.9  0.63 0.38 0.57 0.03 0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.44333333333333336


  0%|          | 0/300 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.005221914499998093 our mmd_std:  0.002344887936487794 our statistic:  2.226935625076294
epoch :  100
our mmd:  0.007641952484846115 our mmd_std:  0.0029727709479629993 our statistic:  2.5706496238708496
epoch :  100
our mmd:  0.005985978990793228 our mmd_std:  0.002572661265730858 our statistic:  2.326765298843384
epoch :  100
our mmd:  0.0019083824008703232 our mmd_std:  0.0012439440470188856 our statistic:  1.5341384410858154
epoch :  100
our mmd:  0.0030144043266773224 our mmd_std:  0.0019935101736336946 our statistic:  1.5121088027954102
epoch :  100
our mmd:  0.008997542783617973 our mmd_std:  0.0032420740462839603 our statistic:  2.775242805480957
epoch :  100
our mmd:  0.008497695438563824 our mmd_std:  0.003004231257364154 our statistic:  2.828575611114502
epoch :  100
our mmd:  0.006791712716221809 our mmd_std:  0.002959428820759058 our statistic:  2.29494047164917
epoch :  100
our mmd:  0.0048507871106266975 our mmd_std:  0.0023296058643609285 our sta

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 19 MMD:  0.00018309429287910461Our Reject rate_u:  0.19
Test Power of Ours (10 times): 
[[0.15 0.9  0.63 0.38 0.57 0.03 0.19 0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.40714285714285714


  0%|          | 0/300 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.005612862296402454 our mmd_std:  0.0023292212281376123 our statistic:  2.409759044647217
epoch :  100
our mmd:  0.0029052970930933952 our mmd_std:  0.0013534484896808863 our statistic:  2.1465885639190674
epoch :  100
our mmd:  0.002242308109998703 our mmd_std:  0.0011549446498975158 our statistic:  1.9414852857589722
epoch :  100
our mmd:  0.006031271070241928 our mmd_std:  0.0021802629344165325 our statistic:  2.7663044929504395
epoch :  100
our mmd:  0.005683175288140774 our mmd_std:  0.0020382057409733534 our statistic:  2.788322687149048
epoch :  100
our mmd:  0.0064643314108252525 our mmd_std:  0.0025082803331315517 our statistic:  2.5771965980529785
epoch :  100
our mmd:  0.00525189284235239 our mmd_std:  0.0018173630814999342 our statistic:  2.8898425102233887
epoch :  100
our mmd:  0.004374371841549873 our mmd_std:  0.0019270931370556355 our statistic:  2.269932746887207
epoch :  100
our mmd:  0.00393357127904892 our mmd_std:  0.0015366104198619723 our

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 13 MMD:  2.997554838657379e-0535

#### Epoch 20 / lr: 0.0005
[[0.23 0.94 0.25 0.37 0.6  0.1  0.36 0.31 0.23 0.4 ]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Baselines (10 times): 
MMD-D:  0.379