In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "5"

In [3]:
from tqdm.auto import tqdm
import numpy as np
import jax.numpy as jnp

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms

from utils_HD import MatConvert, MMDu, TST_MMD_u

In [4]:
np.random.seed(819)
torch.manual_seed(819)
torch.cuda.manual_seed(819)
torch.backends.cudnn.deterministic = True
is_cuda = True

In [5]:
# Parameter Settings
n_epochs = 1000
batch_size = 100
lr = 0.0002
img_size = 64
channels = 3
n = 1000

In [6]:
dtype = torch.float
device = torch.device("cuda:0")
cuda = True if torch.cuda.is_available() else False
N_per = 100 # permutation times
alpha = 0.05 # test threshold
N1 = n # number of samples in one set
K = 10 # number of trails
J = 1 # number of test locations
N = 100 # number of test sets
N_f = 100.0 # number of test sets (float)

In [7]:
# Naming variables
ep_OPT = np.zeros([K])
s_OPT = np.zeros([K])
s0_OPT = np.zeros([K])
T_org_OPT = torch.zeros([K,J,3,64,64]) # Record test locations obtained by MMD-D
DK_Results = np.zeros([1,K])

In [8]:
# Define the deep network for MMD-D
class Featurizer(nn.Module):
    def __init__(self):
        super(Featurizer, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0)] #0.25
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
            # *discriminator_block(128, 256), 
            
        )

        # The height and width of downsampled image
        ds_size = img_size // 2 ** 4
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 300))

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        feature = self.adv_layer(out)

        return feature

In [9]:
# Configure data loader
dataset_test = datasets.CIFAR10(root='/data4/oldrain123/C2ST/data/cifar_data/cifar10', download=True,train=False,
                           transform=transforms.Compose([
                               transforms.Resize(img_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=10000,
                                             shuffle=True, num_workers=1)
# Obtain CIFAR10 images
for i, (imgs, Labels) in enumerate(dataloader_test):
    data_all = imgs
    label_all = Labels
Ind_all = np.arange(len(data_all))

# Obtain CIFAR10.1 images
data_new = np.load('/data4/oldrain123/C2ST/data/cifar10_1/cifar10.1_v4_data.npy')
data_T = np.transpose(data_new, [0,3,1,2])
ind_M = np.random.choice(len(data_T), len(data_T), replace=False)
data_T = data_T[ind_M]
TT = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trans = transforms.ToPILImage()
data_trans = torch.zeros([len(data_T),3,img_size,img_size])
data_T_tensor = torch.from_numpy(data_T)
for i in range(len(data_T)):
    d0 = trans(data_T_tensor[i])
    data_trans[i] = TT(d0)
Ind_v4_all = np.arange(len(data_T))

Files already downloaded and verified


In [11]:
# Repeat experiments K times (K = 10) and report average test power (rejection rate)
for kk in tqdm(range(K)):
    if kk < 5:
        pass
    else:
        torch.manual_seed(kk * 19 + N1)
        torch.cuda.manual_seed(kk * 19 + N1)
        np.random.seed(seed=1102 * (kk + 10) + N1)
        
        # Initialize deep networks for MMD-D (called featurizer)
        featurizer = Featurizer()
        
        # Initialize parameters
        epsilonOPT = torch.log(MatConvert(np.random.rand(1) * 10 ** (-10), device, dtype))
        epsilonOPT.requires_grad = True
        sigmaOPT = MatConvert(np.ones(1) * np.sqrt(2 * 32 * 32), device, dtype)
        sigmaOPT.requires_grad = True
        sigma0OPT = MatConvert(np.ones(1) * np.sqrt(0.005), device, dtype)
        sigma0OPT.requires_grad = True
        if cuda:
            featurizer.cuda()

        # Collect CIFAR10 images
        Ind_tr = np.random.choice(len(data_all), N1, replace=False)
        Ind_te = np.delete(Ind_all, Ind_tr)
        train_data = []
        for i in Ind_tr:
            train_data.append([data_all[i], label_all[i]])

        dataloader = torch.utils.data.DataLoader(
            train_data,
            batch_size=batch_size,
            shuffle=True,
        )

        # Collect CIFAR10.1 images
        np.random.seed(seed=819 * (kk + 9) + N1)
        Ind_tr_v4 = np.random.choice(len(data_T), N1, replace=False)
        Ind_te_v4 = np.delete(Ind_v4_all, Ind_tr_v4)
        New_CIFAR_tr = data_trans[Ind_tr_v4]
        New_CIFAR_te = data_trans[Ind_te_v4]

        # Initialize optimizers
        optimizer_F = torch.optim.Adam(list(featurizer.parameters()) + [epsilonOPT] + [sigmaOPT] + [sigma0OPT], lr=lr)

        Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

        # ----------------------------------------------------------------------------------------------------
        #  Training deep networks for MMD-D (called featurizer)
        # ----------------------------------------------------------------------------------------------------
        np.random.seed(seed=1102)
        torch.manual_seed(1102)
        torch.cuda.manual_seed(1102)
        for epoch in tqdm(range(n_epochs)):
            for i, (imgs, _) in enumerate(dataloader):
                if True:
                    ind = np.random.choice(N1, imgs.shape[0], replace=False)
                    Fake_imgs = New_CIFAR_tr[ind]
                    # Adversarial ground truths
                    valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
                    fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

                    # Configure input
                    real_imgs = Variable(imgs.type(Tensor))
                    Fake_imgs = Variable(Fake_imgs.type(Tensor))
                    X = torch.cat([real_imgs, Fake_imgs], 0)
                    Y = torch.cat([valid, fake], 0).squeeze().long()

                    # ------------------------------
                    #  Train deep network for MMD-D
                    # ------------------------------
                    # Initialize optimizer
                    optimizer_F.zero_grad()
                    # Compute output of deep network
                    modelu_output = featurizer(X)
                    # Compute epsilon, sigma and sigma_0
                    ep = torch.exp(epsilonOPT) / (1 + torch.exp(epsilonOPT))
                    sigma = sigmaOPT ** 2
                    sigma0_u = sigma0OPT ** 2
                    # Compute Compute J (STAT_u)
                    DK_TEMP = MMDu(modelu_output, imgs.shape[0], X.view(X.shape[0],-1), sigma, sigma0_u, ep, complete=False)
                    dk_hsic_xx = torch.from_numpy(np.array(DK_TEMP[3])).to(torch.float32)
                    dk_hsic_yy = torch.from_numpy(np.array(DK_TEMP[4])).to(torch.float32)
                    dk_hsic_xy = torch.from_numpy(np.array(DK_TEMP[5])).to(torch.float32)
                    
                    dk_mmd = -1 * (DK_TEMP[0])
                    dk_mmd_std = torch.sqrt(DK_TEMP[1] + 10 ** (-8))
                    
                    DK_STAT_u = torch.div(dk_mmd, dk_mmd_std) + (dk_hsic_xx + dk_hsic_yy - dk_hsic_xy)
                    if (epoch+1) % 100 == 0:
                        print("-" * 50)
                        print("epoch : ",epoch+1)
                        print("dk mmd: ", -1 * dk_mmd.item(), "dk mmd_std: ", dk_mmd_std.item(), "dk statistic: ",
                        -1 * DK_STAT_u.item())
                        
                    # Compute gradient
                    DK_STAT_u.backward()                
                    # Update weights using gradient descent
                    optimizer_F.step()
                else:
                    break

        # Run two-sample test on the training set
        # Fetch training data
        s1 = data_all[Ind_tr]
        s2 = data_trans[Ind_tr_v4]
        S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
        Sv = S.view(2 * N1, -1)
        # Run two-sample test (MMD-D) on the training set
        # dk_h_u, dk_threshold_u, dk_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N1, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=False)
        # com_h_u, com_threshold_u, com_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N1, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=True)

        # Record best epsilon, sigma and sigma_0
        ep_OPT[kk] = ep.item()
        s_OPT[kk] = sigma.item()
        s0_OPT[kk] = sigma0_u.item()

        # Compute test power of MMD-D and baselines
        DK_H_u = np.zeros(N)

        np.random.seed(1102)
        dk_count_u = 0

        for k in tqdm(range(N)):
            # Fetch test data
            np.random.seed(seed=1102 * (k + 1) + N1)
            data_all_te = data_all[Ind_te]
            N_te = len(data_trans)-N1
            Ind_N_te = np.random.choice(len(Ind_te), N_te, replace=False)
            s1 = data_all_te[Ind_N_te]
            s2 = data_trans[Ind_te_v4]
            S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
            Sv = S.view(2 * N_te, -1)
            # MMD-D
            dk_h_u, dk_threshold_u, dk_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N_te, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=False)

            # Gather results
            dk_count_u = dk_count_u + dk_h_u
            print("\r","MMD-DK:", dk_count_u, "MMD: ", dk_mmd_value_u, end="")
            DK_H_u[k] = dk_h_u

        # Print test power of MMD-D and baselines
        print("DK Reject rate_u: ", DK_H_u.sum() / N_f)
        DK_Results[0, kk] = DK_H_u.sum() / N_f
        print(f"Test Power of DK ({K} times): ")
        print(f"{DK_Results}")
        print(f"Average Test Power of DK ({K} times): ")
        print("MMD-D: ", (DK_Results.sum(1) / (kk + 1))[0])

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
CUDA backend failed to initialize: Found cuPTI version 18, but JAX was built against version 20, which is newer. The copy of cuPTI that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


--------------------------------------------------
epoch :  100
dk mmd:  0.0110141197219491 dk mmd_std:  0.004874340258538723 dk statistic:  2.2495076656341553
--------------------------------------------------
epoch :  100
dk mmd:  0.010989113710820675 dk mmd_std:  0.004693980328738689 dk statistic:  2.3310365676879883
--------------------------------------------------
epoch :  100
dk mmd:  0.010231376625597477 dk mmd_std:  0.004025022964924574 dk statistic:  2.5318620204925537
--------------------------------------------------
epoch :  100
dk mmd:  0.00965351052582264 dk mmd_std:  0.00401795981451869 dk statistic:  2.392484426498413
--------------------------------------------------
epoch :  100
dk mmd:  0.008840330876410007 dk mmd_std:  0.003444228321313858 dk statistic:  2.5566160678863525
--------------------------------------------------
epoch :  100
dk mmd:  0.009438570588827133 dk mmd_std:  0.0053160833194851875 dk statistic:  1.7653346061706543
--------------------------------

  0%|          | 0/100 [00:00<?, ?it/s]

 MMD-DK: 42 MMD:  2.5668414309620857e-06DK Reject rate_u:  0.42
Test Power of DK (10 times): 
[[0.   0.   0.   0.   0.   0.42 0.   0.   0.   0.  ]]
Average Test Power of DK (10 times): 
MMD-D:  0.06999999999999999


  0%|          | 0/1000 [00:00<?, ?it/s]

--------------------------------------------------
epoch :  100
dk mmd:  0.009168184362351894 dk mmd_std:  0.0031489331740885973 dk statistic:  2.901414155960083
--------------------------------------------------
epoch :  100
dk mmd:  0.008601893670856953 dk mmd_std:  0.0032874085009098053 dk statistic:  2.6064929962158203
--------------------------------------------------
epoch :  100
dk mmd:  0.010494701564311981 dk mmd_std:  0.0036366195417940617 dk statistic:  2.8757541179656982
--------------------------------------------------
epoch :  100
dk mmd:  0.010571837425231934 dk mmd_std:  0.004137016367167234 dk statistic:  2.545320749282837
--------------------------------------------------
epoch :  100
dk mmd:  0.010406825691461563 dk mmd_std:  0.0039790039882063866 dk statistic:  2.605313301086426
--------------------------------------------------
epoch :  100
dk mmd:  0.01041500922292471 dk mmd_std:  0.0037551179993897676 dk statistic:  2.76344633102417
-----------------------------

  0%|          | 0/100 [00:00<?, ?it/s]

 MMD-DK: 75 MMD:  3.918085712939501e-05DK Reject rate_u:  0.75
Test Power of DK (10 times): 
[[0.   0.   0.   0.   0.   0.42 0.75 0.   0.   0.  ]]
Average Test Power of DK (10 times): 
MMD-D:  0.16714285714285712


  0%|          | 0/1000 [00:00<?, ?it/s]

--------------------------------------------------
epoch :  100
dk mmd:  0.0119848120957613 dk mmd_std:  0.00863969512283802 dk statistic:  1.3770465850830078
--------------------------------------------------
epoch :  100
dk mmd:  0.011191720142960548 dk mmd_std:  0.0080709895119071 dk statistic:  1.3765416145324707
--------------------------------------------------
epoch :  100
dk mmd:  0.010656350292265415 dk mmd_std:  0.007675367407500744 dk statistic:  1.378269910812378
--------------------------------------------------
epoch :  100
dk mmd:  0.01117322500795126 dk mmd_std:  0.008438385091722012 dk statistic:  1.3140122890472412
--------------------------------------------------
epoch :  100
dk mmd:  0.012472116388380527 dk mmd_std:  0.008071494288742542 dk statistic:  1.5350751876831055
--------------------------------------------------
epoch :  100
dk mmd:  0.012439735233783722 dk mmd_std:  0.00644381670281291 dk statistic:  1.9203537702560425
------------------------------------

  0%|          | 0/100 [00:00<?, ?it/s]

 MMD-DK: 67 MMD:  3.548571839928627e-05DK Reject rate_u:  0.67
Test Power of DK (10 times): 
[[0.   0.   0.   0.   0.   0.42 0.75 0.67 0.   0.  ]]
Average Test Power of DK (10 times): 
MMD-D:  0.22999999999999998


  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

#### Epoch 20 / lr: 0.0005
[[0.23 0.94 0.25 0.37 0.6  0.1  0.36 0.31 0.23 0.4 ]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Baselines (10 times): 
MMD-D:  0.379