In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import os
os.chdir('/home/oldrain123/MMD/')
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [3]:
from tqdm.auto import tqdm

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms

from utils_HD import MatConvert, MMDu, TST_MMD_u

In [4]:
np.random.seed(819)
torch.manual_seed(819)
torch.cuda.manual_seed(819)
torch.backends.cudnn.deterministic = True
is_cuda = True

In [5]:
# Parameter Settings
n_epochs = 1000
batch_size = 100
lr = 0.0002
img_size = 64
channels = 3
n = 1000

In [6]:
dtype = torch.float
device = torch.device("cuda:0")
cuda = True if torch.cuda.is_available() else False
N_per = 100 # permutation times
alpha = 0.05 # test threshold
N1 = n # number of samples in one set
K = 10 # number of trails
J = 1 # number of test locations
N = 100 # number of test sets
N_f = 100.0 # number of test sets (float)

In [7]:
# Naming variables
ep_OPT = np.zeros([K])
s_OPT = np.zeros([K])
s0_OPT = np.zeros([K])
T_org_OPT = torch.zeros([K,J,3,64,64]) # Record test locations obtained by MMD-D
COM_Results = np.zeros([1,K])

In [9]:
class Featurizer_COM(nn.Module):
    def __init__(self):
        super(Featurizer_COM, self).__init__()
        
        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        
        ds_size = img_size // 2 ** 4
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 300)
        )
        
    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        feature = self.adv_layer(out)
        return feature 

In [10]:
# Configure data loader
dataset_test = datasets.CIFAR10(root='/data4/oldrain123/C2ST/data/cifar_data/cifar10', download=True,train=False,
                           transform=transforms.Compose([
                               transforms.Resize(img_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=10000,
                                             shuffle=True, num_workers=1)
# Obtain CIFAR10 images
for i, (imgs, Labels) in enumerate(dataloader_test):
    data_all = imgs
    label_all = Labels
Ind_all = np.arange(len(data_all))

# Obtain CIFAR10.1 images
data_new = np.load('/data4/oldrain123/C2ST/data/cifar10_1/cifar10.1_v4_data.npy')
data_T = np.transpose(data_new, [0,3,1,2])
ind_M = np.random.choice(len(data_T), len(data_T), replace=False)
data_T = data_T[ind_M]
TT = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trans = transforms.ToPILImage()
data_trans = torch.zeros([len(data_T),3,img_size,img_size])
data_T_tensor = torch.from_numpy(data_T)
for i in range(len(data_T)):
    d0 = trans(data_T_tensor[i])
    data_trans[i] = TT(d0)
Ind_v4_all = np.arange(len(data_T))

Files already downloaded and verified


In [11]:
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()

# Repeat experiments K times (K = 10) and report average test power (rejection rate)
for kk in tqdm(range(K)):
    torch.manual_seed(kk * 19 + N1)
    torch.cuda.manual_seed(kk * 19 + N1)
    np.random.seed(seed=1102 * (kk + 10) + N1)
    
    # Initialize deep networks for MMD-D (called featurizer)
    featurizer_com = Featurizer_COM()
    
    # Initialize parameters
    epsilonOPT = torch.log(MatConvert(np.random.rand(1) * 10 ** (-10), device, dtype))
    epsilonOPT.requires_grad = True
    sigmaOPT = MatConvert(np.ones(1) * np.sqrt(2 * 32 * 32), device, dtype)
    sigmaOPT.requires_grad = True
    sigma0OPT = MatConvert(np.ones(1) * np.sqrt(0.005), device, dtype)
    sigma0OPT.requires_grad = True
    if cuda:
        featurizer_com.cuda()

    # Collect CIFAR10 images
    Ind_tr = np.random.choice(len(data_all), N1, replace=False)
    Ind_te = np.delete(Ind_all, Ind_tr)
    train_data = []
    for i in Ind_tr:
       train_data.append([data_all[i], label_all[i]])

    dataloader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
    )

    # Collect CIFAR10.1 images
    np.random.seed(seed=819 * (kk + 9) + N1)
    Ind_tr_v4 = np.random.choice(len(data_T), N1, replace=False)
    Ind_te_v4 = np.delete(Ind_v4_all, Ind_tr_v4)
    New_CIFAR_tr = data_trans[Ind_tr_v4]
    New_CIFAR_te = data_trans[Ind_te_v4]

    # Initialize optimizers
    optimizer_COM = torch.optim.Adam(list(featurizer_com.parameters()) + [epsilonOPT] + [sigmaOPT] + [sigma0OPT], lr=lr, weight_decay=1e-4)
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # ----------------------------------------------------------------------------------------------------
    #  Training deep networks for MMD-D (called featurizer)
    # ----------------------------------------------------------------------------------------------------
    np.random.seed(seed=1102)
    torch.manual_seed(1102)
    torch.cuda.manual_seed(1102)
    for epoch in tqdm(range(n_epochs)):
        for i, (imgs, _) in enumerate(dataloader):
            if True:
                ind = np.random.choice(N1, imgs.shape[0], replace=False)
                Fake_imgs = New_CIFAR_tr[ind]
                # Adversarial ground truths
                valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
                fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

                # Configure input
                real_imgs = Variable(imgs.type(Tensor))
                Fake_imgs = Variable(Fake_imgs.type(Tensor))
                X = torch.cat([real_imgs, Fake_imgs], 0)
                Y = torch.cat([valid, fake], 0).squeeze().long()

                # ------------------------------
                #  Train deep network for MMD-D
                # ------------------------------
                # Initialize optimizer
                optimizer_COM.zero_grad()
                
                with autocast():
                    # Compute output of deep network
                    com_modelu_output = featurizer_com(X)
                
                    # Compute epsilon, sigma and sigma_0
                    ep = torch.exp(epsilonOPT) / (1 + torch.exp(epsilonOPT))
                    sigma = sigmaOPT ** 2
                    sigma0_u = sigma0OPT ** 2
                
                    # Compute Compute J (STAT_u)
                    COM_TEMP = MMDu(com_modelu_output, imgs.shape[0], X.view(X.shape[0],-1), sigma, sigma0_u, ep, complete=True)
                    com_mmd = - COM_TEMP[0]
                    com_mmd_std = torch.sqrt(COM_TEMP[1]+10**(-5))
                    COM_STAT_u = torch.div(com_mmd, com_mmd_std)
                
                if (epoch+1) % 100 == 0:
                    print("=" * 110)
                    print("epoch : ",epoch+1)
                    print("our mmd: ", -1 * com_mmd.item(), "our mmd_std: ", com_mmd_std.item(), "our statistic: ",
                    -1 * COM_STAT_u.item())
                    
                scaler.scale(COM_STAT_u).backward()
                scaler.step(optimizer_COM)
                scaler.update()
            else:
                break

    # Run two-sample test on the training set
    # Fetch training data
    s1 = data_all[Ind_tr]
    s2 = data_trans[Ind_tr_v4]
    S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
    Sv = S.view(2 * N1, -1)
    # Run two-sample test (MMD-D) on the training set
    # dk_h_u, dk_threshold_u, dk_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N1, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=False)
    # com_h_u, com_threshold_u, com_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N1, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=True)

    # Record best epsilon, sigma and sigma_0
    ep_OPT[kk] = ep.item()
    s_OPT[kk] = sigma.item()
    s0_OPT[kk] = sigma0_u.item()

    # Compute test power of MMD-D and baselines
    DK_H_u = np.zeros(N)
    COM_H_u = np.zeros(N)

    np.random.seed(1102)
    dk_count_u = 0
    com_count_u = 0

    for k in tqdm(range(N)):
        # Fetch test data
        np.random.seed(seed=1102 * (k + 1) + N1)
        data_all_te = data_all[Ind_te]
        N_te = len(data_trans)-N1
        Ind_N_te = np.random.choice(len(Ind_te), N_te, replace=False)
        s1 = data_all_te[Ind_N_te]
        s2 = data_trans[Ind_te_v4]
        S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
        Sv = S.view(2 * N_te, -1)
        # MMD-D
        com_h_u, com_threshold_u, com_mmd_value_u = TST_MMD_u(featurizer_com(S), N_per, N_te, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=True)

        # Gather results
        com_count_u = com_count_u + com_h_u
        print("\r","Ours:", com_count_u, "MMD: ", com_mmd_value_u, end="")
        COM_H_u[k] = com_h_u

    # Print test power of MMD-D and baselines
    print("Our Reject rate_u: ", COM_H_u.sum() / N_f)
    COM_Results[0, kk] = COM_H_u.sum() / N_f
    print(f"Test Power of Ours ({K} times): ")
    print(f"{COM_Results}")
    print(f"Average Test Power of Ours ({K} times): ")
    print("Ours: ", (COM_Results.sum(1) / (kk + 1))[0])

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)


epoch :  100
our mmd:  0.12308130413293839 our mmd_std:  0.012299331879426688 our statistic:  10.00715366814507
epoch :  100
our mmd:  0.1291743963956833 our mmd_std:  0.012344086336470841 our statistic:  10.464476096058647
epoch :  100
our mmd:  0.15046954154968262 our mmd_std:  0.01484632566474202 our statistic:  10.135136797317267
epoch :  100
our mmd:  0.14840036630630493 our mmd_std:  0.014225419574685077 our statistic:  10.432055485406675
epoch :  100
our mmd:  0.1526522934436798 our mmd_std:  0.01458544609996901 our statistic:  10.466069559847343
epoch :  100
our mmd:  0.1659858673810959 our mmd_std:  0.013977815035015382 our statistic:  11.874950910803294
epoch :  100
our mmd:  0.14648213982582092 our mmd_std:  0.01313594421296856 our statistic:  11.151245578616672
epoch :  100
our mmd:  0.14234432578086853 our mmd_std:  0.01242041155475725 our statistic:  11.460516034699992
epoch :  100
our mmd:  0.13754673302173615 our mmd_std:  0.01203023744118448 our statistic:  11.43341797

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 2 MMD:  -0.0002692006528377533Our Reject rate_u:  0.02
Test Power of Ours (10 times): 
[[0.02 0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.02


  0%|          | 0/1000 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.19163602590560913 our mmd_std:  0.01698253421896966 our statistic:  11.284300884349156
epoch :  100
our mmd:  0.1886957287788391 our mmd_std:  0.015506551598410799 our statistic:  12.168774442293007
epoch :  100
our mmd:  0.18663442134857178 our mmd_std:  0.01600999346509708 our statistic:  11.657370239122711
epoch :  100
our mmd:  0.21468234062194824 our mmd_std:  0.01776464550013459 our statistic:  12.084808594706928
epoch :  100
our mmd:  0.20782554149627686 our mmd_std:  0.016778478557874288 our statistic:  12.386435443441473
epoch :  100
our mmd:  0.19705519080162048 our mmd_std:  0.017117099533247897 our statistic:  11.512183499246738
epoch :  100
our mmd:  0.18542730808258057 our mmd_std:  0.015854643656490294 our statistic:  11.695457312071067
epoch :  100
our mmd:  0.17839156091213226 our mmd_std:  0.01520974462592924 our statistic:  11.728767661753782
epoch :  100
our mmd:  0.19282221794128418 our mmd_std:  0.016090785811458355 our statistic:  11.9833

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 52 MMD:  0.000201348215341568Our Reject rate_u:  0.52
Test Power of Ours (10 times): 
[[0.02 0.52 0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.27


  0%|          | 0/1000 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.1686919927597046 our mmd_std:  0.013867323483626156 our statistic:  12.164711738273626
epoch :  100
our mmd:  0.1981753557920456 our mmd_std:  0.014962113867860767 our statistic:  13.245144205040063
epoch :  100
our mmd:  0.16689370572566986 our mmd_std:  0.012122747742120361 our statistic:  13.766986600389234
epoch :  100
our mmd:  0.15418237447738647 our mmd_std:  0.01152906118327605 our statistic:  13.373367703264686
epoch :  100
our mmd:  0.1777835488319397 our mmd_std:  0.01496182141794955 our statistic:  11.882480338834583
epoch :  100
our mmd:  0.17518886923789978 our mmd_std:  0.014861336588072973 our statistic:  11.788231038283485
epoch :  100
our mmd:  0.18950781226158142 our mmd_std:  0.014497445039837008 our statistic:  13.071807600638577
epoch :  100
our mmd:  0.17006580531597137 our mmd_std:  0.013963618650960255 our statistic:  12.179207236103961
epoch :  100
our mmd:  0.1978318989276886 our mmd_std:  0.013001123044818062 our statistic:  15.21652

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 27 MMD:  0.00020587071776390076Our Reject rate_u:  0.27
Test Power of Ours (10 times): 
[[0.02 0.52 0.27 0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.27


  0%|          | 0/1000 [00:00<?, ?it/s]

epoch :  100
our mmd:  0.18497399985790253 our mmd_std:  0.015145173868734076 our statistic:  12.213395597904995
epoch :  100
our mmd:  0.16356633603572845 our mmd_std:  0.01438983805850657 our statistic:  11.366794773554524
epoch :  100
our mmd:  0.1782696396112442 our mmd_std:  0.015401793734704216 our statistic:  11.57460245747589
epoch :  100
our mmd:  0.15228253602981567 our mmd_std:  0.013715582149365291 our statistic:  11.102885343941654
epoch :  100
our mmd:  0.180912584066391 our mmd_std:  0.015381969927934763 our statistic:  11.761340381887026
epoch :  100
our mmd:  0.1601111888885498 our mmd_std:  0.013056595131668382 our statistic:  12.262859288652132
epoch :  100
our mmd:  0.179611474275589 our mmd_std:  0.015600667781633092 our statistic:  11.513063209194696
epoch :  100
our mmd:  0.1922673135995865 our mmd_std:  0.01654257797553319 our statistic:  11.622572605306969
epoch :  100
our mmd:  0.17581042647361755 our mmd_std:  0.015787017483409205 our statistic:  11.136392713

#### Epoch 20 / lr: 0.0005
[[0.23 0.94 0.25 0.37 0.6  0.1  0.36 0.31 0.23 0.4 ]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Baselines (10 times): 
MMD-D:  0.379