In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import os
os.chdir('/home/oldrain123/MMD/')
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

In [3]:
from tqdm.auto import tqdm

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms

from utils_HD import MatConvert, MMDu, TST_MMD_u

In [4]:
np.random.seed(819)
torch.manual_seed(819)
torch.cuda.manual_seed(819)
torch.backends.cudnn.deterministic = True
is_cuda = True

In [5]:
# Parameter Settings
n_epochs = 1000
batch_size = 200
lr = 0.0002
img_size = 64
channels = 3
n = 1000

In [6]:
dtype = torch.float
device = torch.device("cuda:0")
cuda = True if torch.cuda.is_available() else False
N_per = 100 # permutation times
alpha = 0.05 # test threshold
N1 = n # number of samples in one set
K = 10 # number of trails
J = 1 # number of test locations
N = 100 # number of test sets
N_f = 100.0 # number of test sets (float)

In [7]:
# Naming variables
ep_OPT = np.zeros([K])
s_OPT = np.zeros([K])
s0_OPT = np.zeros([K])
T_org_OPT = torch.zeros([K,J,3,64,64]) # Record test locations obtained by MMD-D
COM_Results = np.zeros([1,K])

In [8]:
class Featurizer_COM(nn.Module):
    def __init__(self):
        super(Featurizer_COM, self).__init__()
        
        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        
        ds_size = img_size // 2 ** 4
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 300)
        )
        
    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        feature = self.adv_layer(out)
        return feature 

In [9]:
# Configure data loader
dataset_test = datasets.CIFAR10(root='/data4/oldrain123/C2ST/data/cifar_data/cifar10', download=True,train=False,
                           transform=transforms.Compose([
                               transforms.Resize(img_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=10000,
                                             shuffle=True, num_workers=1)
# Obtain CIFAR10 images
for i, (imgs, Labels) in enumerate(dataloader_test):
    data_all = imgs
    label_all = Labels
Ind_all = np.arange(len(data_all))

# Obtain CIFAR10.1 images
data_new = np.load('/data4/oldrain123/C2ST/data/cifar10_1/cifar10.1_v4_data.npy')
data_T = np.transpose(data_new, [0,3,1,2])
ind_M = np.random.choice(len(data_T), len(data_T), replace=False)
data_T = data_T[ind_M]
TT = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trans = transforms.ToPILImage()
data_trans = torch.zeros([len(data_T),3,img_size,img_size])
data_T_tensor = torch.from_numpy(data_T)
for i in range(len(data_T)):
    d0 = trans(data_T_tensor[i])
    data_trans[i] = TT(d0)
Ind_v4_all = np.arange(len(data_T))

Files already downloaded and verified


In [10]:
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()

# Repeat experiments K times (K = 10) and report average test power (rejection rate)
for kk in tqdm(range(K)):
    torch.manual_seed(kk * 19 + N1)
    torch.cuda.manual_seed(kk * 19 + N1)
    np.random.seed(seed=1102 * (kk + 10) + N1)
    
    # Initialize deep networks for MMD-D (called featurizer)
    featurizer_com = Featurizer_COM()
    
    # Initialize parameters
    epsilonOPT = torch.log(MatConvert(np.random.rand(1) * 10 ** (-10), device, dtype))
    epsilonOPT.requires_grad = True
    sigmaOPT = MatConvert(np.ones(1) * np.sqrt(2 * 32 * 32), device, dtype)
    sigmaOPT.requires_grad = True
    sigma0OPT = MatConvert(np.ones(1) * np.sqrt(0.005), device, dtype)
    sigma0OPT.requires_grad = True
    if cuda:
        featurizer_com.cuda()

    # Collect CIFAR10 images
    Ind_tr = np.random.choice(len(data_all), N1, replace=False)
    Ind_te = np.delete(Ind_all, Ind_tr)
    train_data = []
    for i in Ind_tr:
       train_data.append([data_all[i], label_all[i]])

    dataloader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
    )

    # Collect CIFAR10.1 images
    np.random.seed(seed=819 * (kk + 9) + N1)
    Ind_tr_v4 = np.random.choice(len(data_T), N1, replace=False)
    Ind_te_v4 = np.delete(Ind_v4_all, Ind_tr_v4)
    New_CIFAR_tr = data_trans[Ind_tr_v4]
    New_CIFAR_te = data_trans[Ind_te_v4]

    # Initialize optimizers
    optimizer_COM = torch.optim.Adam(list(featurizer_com.parameters()) + [epsilonOPT] + [sigmaOPT] + [sigma0OPT], lr=lr, weight_decay=1e-4)
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # ----------------------------------------------------------------------------------------------------
    #  Training deep networks for MMD-D (called featurizer)
    # ----------------------------------------------------------------------------------------------------
    np.random.seed(seed=1102)
    torch.manual_seed(1102)
    torch.cuda.manual_seed(1102)
    for epoch in tqdm(range(n_epochs)):
        for i, (imgs, _) in enumerate(dataloader):
            if True:
                ind = np.random.choice(N1, imgs.shape[0], replace=False)
                Fake_imgs = New_CIFAR_tr[ind]
                # Adversarial ground truths
                valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
                fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

                # Configure input
                real_imgs = Variable(imgs.type(Tensor))
                Fake_imgs = Variable(Fake_imgs.type(Tensor))
                X = torch.cat([real_imgs, Fake_imgs], 0)
                Y = torch.cat([valid, fake], 0).squeeze().long()

                # ------------------------------
                #  Train deep network for MMD-D
                # ------------------------------
                # Initialize optimizer
                optimizer_COM.zero_grad()
                
                with autocast():
                    # Compute output of deep network
                    com_modelu_output = featurizer_com(X)
                
                    # Compute epsilon, sigma and sigma_0
                    ep = torch.exp(epsilonOPT) / (1 + torch.exp(epsilonOPT))
                    sigma = sigmaOPT ** 2
                    sigma0_u = sigma0OPT ** 2
                
                    # Compute Compute J (STAT_u)
                    COM_TEMP = MMDu(com_modelu_output, imgs.shape[0], X.view(X.shape[0],-1), sigma, sigma0_u, ep, complete=True)
                    com_mmd = - COM_TEMP[0]
                    com_mmd_std = torch.sqrt(COM_TEMP[1] + 10**(-5))
                    COM_STAT_u = torch.div(com_mmd, com_mmd_std)
                    gamma = COM_STAT_u.item()
                    COM_STAT = com_mmd + gamma*com_mmd_std
                
                if (epoch+1) % 100 == 0:
                    print("=" * 110)
                    print("epoch : ",epoch+1)
                    print("our mmd: ", -1 * com_mmd.item(), "our mmd_std: ", com_mmd_std.item(), "our statistic: ",
                    -1 * COM_STAT_u.item())
                    
                scaler.scale(COM_STAT).backward()
                scaler.step(optimizer_COM)
                scaler.update()
            else:
                break

    # Run two-sample test on the training set
    # Fetch training data
    s1 = data_all[Ind_tr]
    s2 = data_trans[Ind_tr_v4]
    S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
    Sv = S.view(2 * N1, -1)
    # Run two-sample test (MMD-D) on the training set
    # dk_h_u, dk_threshold_u, dk_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N1, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=False)
    # com_h_u, com_threshold_u, com_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N1, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=True)

    # Record best epsilon, sigma and sigma_0
    ep_OPT[kk] = ep.item()
    s_OPT[kk] = sigma.item()
    s0_OPT[kk] = sigma0_u.item()

    # Compute test power of MMD-D and baselines
    DK_H_u = np.zeros(N)
    COM_H_u = np.zeros(N)

    np.random.seed(1102)
    dk_count_u = 0
    com_count_u = 0

    for k in tqdm(range(N)):
        # Fetch test data
        np.random.seed(seed=1102 * (k + 1) + N1)
        data_all_te = data_all[Ind_te]
        N_te = len(data_trans)-N1
        Ind_N_te = np.random.choice(len(Ind_te), N_te, replace=False)
        s1 = data_all_te[Ind_N_te]
        s2 = data_trans[Ind_te_v4]
        S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
        Sv = S.view(2 * N_te, -1)
        # MMD-D
        com_h_u, com_threshold_u, com_mmd_value_u = TST_MMD_u(featurizer_com(S), N_per, N_te, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=True)

        # Gather results
        com_count_u = com_count_u + com_h_u
        print("\r","Ours:", com_count_u, "MMD: ", com_mmd_value_u, end="")
        COM_H_u[k] = com_h_u

    # Print test power of MMD-D and baselines
    print("Our Reject rate_u: ", COM_H_u.sum() / N_f)
    COM_Results[0, kk] = COM_H_u.sum() / N_f
    print(f"Test Power of Ours ({K} times): ")
    print(f"{COM_Results}")
    print(f"Average Test Power of Ours ({K} times): ")
    print("Ours: ", (COM_Results.sum(1) / (kk + 1))[0])

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)


epoch :  100
our mmd:  0.13617700338363647 our mmd_std:  0.009265010114698873 our statistic:  14.69798755724968
epoch :  100
our mmd:  0.16693958640098572 our mmd_std:  0.011058841387732974 our statistic:  15.095576520896987
epoch :  100
our mmd:  0.16912499070167542 our mmd_std:  0.010521567576439614 our statistic:  16.074124836719957
epoch :  100
our mmd:  0.1681123673915863 our mmd_std:  0.010494576325268453 our statistic:  16.018976105477606
epoch :  100
our mmd:  0.14410898089408875 our mmd_std:  0.008872870744901723 our statistic:  16.24152825362553
epoch :  200
our mmd:  0.17499412596225739 our mmd_std:  0.010269772557035865 our statistic:  17.03972751006722
epoch :  200
our mmd:  0.17281796038150787 our mmd_std:  0.00933593705401118 our statistic:  18.5110460130252
epoch :  200
our mmd:  0.1765669584274292 our mmd_std:  0.009694165733221169 our statistic:  18.213734248667492
epoch :  200
our mmd:  0.18588867783546448 our mmd_std:  0.010850584977613618 our statistic:  17.1316733

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 2 MMD:  -0.00017654895782470703Our Reject rate_u:  0.02
Test Power of Ours (10 times): 
[[0.02 0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.02


  0%|          | 0/1000 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.19779661297798157 our mmd_std:  0.011610705497372578 our statistic:  17.035710105879577
epoch :  100
our mmd:  0.20054417848587036 our mmd_std:  0.012027817521954573 our statistic:  16.67336390162336
epoch :  100
our mmd:  0.2043575942516327 our mmd_std:  0.012011682990534367 our statistic:  17.013235731635085
epoch :  100
our mmd:  0.19968706369400024 our mmd_std:  0.011377041008869958 our statistic:  17.551757397931226
epoch :  100
our mmd:  0.1919415295124054 our mmd_std:  0.011109198976236835 our statistic:  17.277711014356527
epoch :  200
our mmd:  0.19604375958442688 our mmd_std:  0.011744146369704275 our statistic:  16.69289137013399
epoch :  200
our mmd:  0.18996661901474 our mmd_std:  0.011574470510900846 our statistic:  16.412553717755753
epoch :  200
our mmd:  0.18186858296394348 our mmd_std:  0.010729749405573115 our statistic:  16.949937607068392
epoch :  200
our mmd:  0.19742894172668457 our mmd_std:  0.01199072113765237 our statistic:  16.4651432

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 59 MMD:  0.0005861781537532806Our Reject rate_u:  0.59
Test Power of Ours (10 times): 
[[0.02 0.59 0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.305


  0%|          | 0/1000 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.1592157632112503 our mmd_std:  0.009792332072443654 our statistic:  16.25922834656468
epoch :  100
our mmd:  0.15057986974716187 our mmd_std:  0.008930425121439628 our statistic:  16.861444746416243
epoch :  100
our mmd:  0.15617430210113525 our mmd_std:  0.009954893036913515 our statistic:  15.688194892906315
epoch :  100
our mmd:  0.17442843317985535 our mmd_std:  0.010937389469899944 our statistic:  15.94790362543897
epoch :  100
our mmd:  0.15857571363449097 our mmd_std:  0.009661741602122072 our statistic:  16.412746289929956
epoch :  200
our mmd:  0.17108532786369324 our mmd_std:  0.01022138679766028 our statistic:  16.737976093699476
epoch :  200
our mmd:  0.1646537482738495 our mmd_std:  0.009767747596458088 our statistic:  16.856879915033335
epoch :  200
our mmd:  0.171683669090271 our mmd_std:  0.010585312598964385 our statistic:  16.21904572823553
epoch :  200
our mmd:  0.16438639163970947 our mmd_std:  0.010257095671271692 our statistic:  16.0266021

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 7 MMD:  0.00019833818078041077Our Reject rate_u:  0.07
Test Power of Ours (10 times): 
[[0.02 0.59 0.07 0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.22666666666666666


  0%|          | 0/1000 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.1621769517660141 our mmd_std:  0.009580221153503408 our statistic:  16.928309813255964
epoch :  100
our mmd:  0.1930796504020691 our mmd_std:  0.011264746516401539 our statistic:  17.140168233785282
epoch :  100
our mmd:  0.18214048445224762 our mmd_std:  0.01040398026291784 our statistic:  17.506807957089062
epoch :  100
our mmd:  0.18119579553604126 our mmd_std:  0.010701223623495985 our statistic:  16.932250171672084
epoch :  100
our mmd:  0.1978575885295868 our mmd_std:  0.011472773474193643 our statistic:  17.245837632429424
epoch :  200
our mmd:  0.18503326177597046 our mmd_std:  0.010421399797062754 our statistic:  17.755125547349373
epoch :  200
our mmd:  0.19683188199996948 our mmd_std:  0.011414063132254858 our statistic:  17.24468138289377
epoch :  200
our mmd:  0.19062481820583344 our mmd_std:  0.011437981627785776 our statistic:  16.665948976763268
epoch :  200
our mmd:  0.17913317680358887 our mmd_std:  0.010435782876462244 our statistic:  17.1652

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 14 MMD:  0.00011524558067321777Our Reject rate_u:  0.14
Test Power of Ours (10 times): 
[[0.02 0.59 0.07 0.14 0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.20500000000000002


  0%|          | 0/1000 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.2127954661846161 our mmd_std:  0.011776326821974182 our statistic:  18.069765674942698
epoch :  100
our mmd:  0.18744245171546936 our mmd_std:  0.011323786464872105 our statistic:  16.552983606405935
epoch :  100
our mmd:  0.19281302392482758 our mmd_std:  0.011445923254378099 our statistic:  16.84556323152665
epoch :  100
our mmd:  0.188654825091362 our mmd_std:  0.010858731312056145 our statistic:  17.37356047127751
epoch :  100
our mmd:  0.1916559487581253 our mmd_std:  0.011079301375804576 our statistic:  17.298559020758407
epoch :  200
our mmd:  0.20306508243083954 our mmd_std:  0.011557685499428936 our statistic:  17.569701342095954
epoch :  200
our mmd:  0.19012907147407532 our mmd_std:  0.010659223645413884 our statistic:  17.837046843076426
epoch :  200
our mmd:  0.2131708860397339 our mmd_std:  0.012218361061673134 our statistic:  17.44676597489116
epoch :  200
our mmd:  0.19950218498706818 our mmd_std:  0.011304368254845792 our statistic:  17.6482383

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 17 MMD:  0.00010598450899124146Our Reject rate_u:  0.17
Test Power of Ours (10 times): 
[[0.02 0.59 0.07 0.14 0.17 0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.198


  0%|          | 0/1000 [00:00<?, ?it/s]

#### Epoch 20 / lr: 0.0005
[[0.23 0.94 0.25 0.37 0.6  0.1  0.36 0.31 0.23 0.4 ]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Baselines (10 times): 
MMD-D:  0.379