In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import os
os.chdir('/home/oldrain123/MMD/')
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

In [3]:
from tqdm.auto import tqdm

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms

from utils_HD import MatConvert, MMDu, TST_MMD_u

In [4]:
np.random.seed(819)
torch.manual_seed(819)
torch.cuda.manual_seed(819)
torch.backends.cudnn.deterministic = True
is_cuda = True

In [5]:
# Parameter Settings
n_epochs = 27
batch_size = 1000
lr = 0.0002
img_size = 64
channels = 3
n = 1000

In [6]:
dtype = torch.float
device = torch.device("cuda:0")
cuda = True if torch.cuda.is_available() else False
N_per = 100 # permutation times
alpha = 0.05 # test threshold
N1 = n # number of samples in one set
K = 10 # number of trails
J = 1 # number of test locations
N = 100 # number of test sets
N_f = 100.0 # number of test sets (float)

In [7]:
# Naming variables
ep_OPT = np.zeros([K])
s_OPT = np.zeros([K])
s0_OPT = np.zeros([K])
T_org_OPT = torch.zeros([K,J,3,64,64]) # Record test locations obtained by MMD-D
COM_Results = np.zeros([1,K])

In [8]:
class Featurizer_COM(nn.Module):
    def __init__(self):
        super(Featurizer_COM, self).__init__()
        
        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        
        ds_size = img_size // 2 ** 4
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 300)
        )
        
    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        feature = self.adv_layer(out)
        return feature 

In [9]:
# Configure data loader
dataset_test = datasets.CIFAR10(root='/data4/oldrain123/C2ST/data/cifar_data/cifar10', download=True,train=False,
                           transform=transforms.Compose([
                               transforms.Resize(img_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=10000,
                                             shuffle=True, num_workers=1)
# Obtain CIFAR10 images
for i, (imgs, Labels) in enumerate(dataloader_test):
    data_all = imgs
    label_all = Labels
Ind_all = np.arange(len(data_all))

# Obtain CIFAR10.1 images
data_new = np.load('/data4/oldrain123/C2ST/data/cifar10_1/cifar10.1_v4_data.npy')
data_T = np.transpose(data_new, [0,3,1,2])
ind_M = np.random.choice(len(data_T), len(data_T), replace=False)
data_T = data_T[ind_M]
TT = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trans = transforms.ToPILImage()
data_trans = torch.zeros([len(data_T),3,img_size,img_size])
data_T_tensor = torch.from_numpy(data_T)
for i in range(len(data_T)):
    d0 = trans(data_T_tensor[i])
    data_trans[i] = TT(d0)
Ind_v4_all = np.arange(len(data_T))

Files already downloaded and verified


In [10]:
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()
# gamma = torch.nn.Parameter(MatConvert(np.random.rand(1) * 10 ** (2), device, dtype))

# Repeat experiments K times (K = 10) and report average test power (rejection rate)
for kk in tqdm(range(K)):
    torch.manual_seed(kk * 19 + N1)
    torch.cuda.manual_seed(kk * 19 + N1)
    np.random.seed(seed=1102 * (kk + 10) + N1)
    
    # Initialize deep networks for MMD-D (called featurizer)
    featurizer_com = Featurizer_COM()
    
    # Initialize parameters
    epsilonOPT = torch.log(MatConvert(np.random.rand(1) * 10 ** (-10), device, dtype))
    epsilonOPT.requires_grad = True
    sigmaOPT = MatConvert(np.ones(1) * np.sqrt(2 * 32 * 32), device, dtype)
    sigmaOPT.requires_grad = True
    sigma0OPT = MatConvert(np.ones(1) * np.sqrt(0.005), device, dtype)
    sigma0OPT.requires_grad = True
    if cuda:
        featurizer_com.cuda()

    # Collect CIFAR10 images
    Ind_tr = np.random.choice(len(data_all), N1, replace=False)
    Ind_te = np.delete(Ind_all, Ind_tr)
    train_data = []
    for i in Ind_tr:
       train_data.append([data_all[i], label_all[i]])

    dataloader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
    )

    # Collect CIFAR10.1 images
    np.random.seed(seed=819 * (kk + 9) + N1)
    Ind_tr_v4 = np.random.choice(len(data_T), N1, replace=False)
    Ind_te_v4 = np.delete(Ind_v4_all, Ind_tr_v4)
    New_CIFAR_tr = data_trans[Ind_tr_v4]
    New_CIFAR_te = data_trans[Ind_te_v4]

    # Initialize optimizers
    optimizer_COM = torch.optim.Adam(list(featurizer_com.parameters()) + [epsilonOPT] + [sigmaOPT] + [sigma0OPT], lr=lr, weight_decay=1e-4)
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # ----------------------------------------------------------------------------------------------------
    #  Training deep networks for MMD-D (called featurizer)
    # ----------------------------------------------------------------------------------------------------
    np.random.seed(seed=1102)
    torch.manual_seed(1102)
    torch.cuda.manual_seed(1102)
    for epoch in tqdm(range(n_epochs)):
        for i, (imgs, _) in enumerate(dataloader):
            if True:
                ind = np.random.choice(N1, imgs.shape[0], replace=False)
                Fake_imgs = New_CIFAR_tr[ind]
                # Adversarial ground truths
                valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
                fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

                # Configure input
                real_imgs = Variable(imgs.type(Tensor))
                Fake_imgs = Variable(Fake_imgs.type(Tensor))
                X = torch.cat([real_imgs, Fake_imgs], 0)
                Y = torch.cat([valid, fake], 0).squeeze().long()

                # ------------------------------
                #  Train deep network for MMD-D
                # ------------------------------
                # Initialize optimizer
                optimizer_COM.zero_grad()
                
                with autocast():
                    # Compute output of deep network
                    com_modelu_output = featurizer_com(X)
                
                    # Compute epsilon, sigma and sigma_0
                    ep = torch.exp(epsilonOPT) / (1 + torch.exp(epsilonOPT))
                    sigma = sigmaOPT ** 2
                    sigma0_u = sigma0OPT ** 2
                
                    # Compute Compute J (STAT_u)
                    COM_TEMP = MMDu(com_modelu_output, imgs.shape[0], X.view(X.shape[0],-1), sigma, sigma0_u, ep, complete=True)
                    com_mmd = - COM_TEMP[0]
                    # threshold = torch.tensor(1e-5)
                    # penalty = (-1 * com_mmd - threshold).clamp(min=0) ** 2
                    com_mmd_std = torch.sqrt(COM_TEMP[1] + 10**(-5))

                    COM_STAT_u = torch.div(com_mmd, com_mmd_std)
                    gamma = 10 ** 2
                    COM_STAT = com_mmd + gamma*com_mmd_std
                
                if (epoch+1) % 10 == 0:
                    print("=" * 110)
                    print("epoch : ",epoch+1)
                    print("our mmd: ", -1 * com_mmd.item(), "our mmd_std: ", com_mmd_std.item(), "our statistic: ",
                    -1 * COM_STAT_u.item())
                    
                scaler.scale(COM_STAT).backward()
                scaler.step(optimizer_COM)
                scaler.update()
            else:
                break

    # Run two-sample test on the training set
    # Fetch training data
    s1 = data_all[Ind_tr]
    s2 = data_trans[Ind_tr_v4]
    S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
    Sv = S.view(2 * N1, -1)
    # Run two-sample test (MMD-D) on the training set
    # dk_h_u, dk_threshold_u, dk_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N1, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=False)
    # com_h_u, com_threshold_u, com_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N1, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=True)

    # Record best epsilon, sigma and sigma_0
    ep_OPT[kk] = ep.item()
    s_OPT[kk] = sigma.item()
    s0_OPT[kk] = sigma0_u.item()

    # Compute test power of MMD-D and baselines
    DK_H_u = np.zeros(N)
    COM_H_u = np.zeros(N)

    np.random.seed(1102)
    dk_count_u = 0
    com_count_u = 0

    for k in tqdm(range(N)):
        # Fetch test data
        np.random.seed(seed=1102 * (k + 1) + N1)
        data_all_te = data_all[Ind_te]
        N_te = len(data_trans)-N1
        Ind_N_te = np.random.choice(len(Ind_te), N_te, replace=False)
        s1 = data_all_te[Ind_N_te]
        s2 = data_trans[Ind_te_v4]
        S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
        Sv = S.view(2 * N_te, -1)
        # MMD-D
        com_h_u, com_threshold_u, com_mmd_value_u = TST_MMD_u(featurizer_com(S), N_per, N_te, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=True)

        # Gather results
        com_count_u = com_count_u + com_h_u
        print("\r","Ours:", com_count_u, "MMD: ", com_mmd_value_u, end="")
        COM_H_u[k] = com_h_u

    # Print test power of MMD-D and baselines
    print("Our Reject rate_u: ", COM_H_u.sum() / N_f)
    COM_Results[0, kk] = COM_H_u.sum() / N_f
    print(f"Test Power of Ours ({K} times): ")
    print(f"{COM_Results}")
    print(f"Average Test Power of Ours ({K} times): ")
    print("Ours: ", (COM_Results.sum(1) / (kk + 1))[0])

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

  valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)


epoch :  10
our mmd:  0.0028620008379220963 our mmd_std:  0.003219379576146217 our statistic:  0.888991425282034
epoch :  20
our mmd:  0.010806561447679996 our mmd_std:  0.003369243959734039 our statistic:  3.207414356701271


  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 23 MMD:  7.189810276031494e-07Our Reject rate_u:  0.23
Test Power of Ours (10 times): 
[[0.23 0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.23


  0%|          | 0/27 [00:00<?, ?it/s]

epoch :  10
our mmd:  0.002383144572377205 our mmd_std:  0.003199177390327549 our statistic:  0.7449241731897854
epoch :  20
our mmd:  0.009383775293827057 our mmd_std:  0.003337348050974185 our statistic:  2.811746078173625


  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 85 MMD:  0.0002751350402832031Our Reject rate_u:  0.85
Test Power of Ours (10 times): 
[[0.23 0.85 0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.54


  0%|          | 0/27 [00:00<?, ?it/s]

epoch :  10
our mmd:  0.001443241722881794 our mmd_std:  0.0031824394700845767 our statistic:  0.45350170410104873
epoch :  20
our mmd:  0.007835584692656994 our mmd_std:  0.003302781165105821 our statistic:  2.3724201819486708


  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 36 MMD:  0.0001935381442308426Our Reject rate_u:  0.36
Test Power of Ours (10 times): 
[[0.23 0.85 0.36 0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.48


  0%|          | 0/27 [00:00<?, ?it/s]

epoch :  10
our mmd:  0.0026806723326444626 our mmd_std:  0.0032072691081319986 our statistic:  0.8358114776984709
epoch :  20
our mmd:  0.010435696691274643 our mmd_std:  0.0033483600766098277 our statistic:  3.1166590368144202


  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 35 MMD:  0.00014770962297916412Our Reject rate_u:  0.35
Test Power of Ours (10 times): 
[[0.23 0.85 0.36 0.35 0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.4475


  0%|          | 0/27 [00:00<?, ?it/s]

epoch :  10
our mmd:  0.002257036045193672 our mmd_std:  0.003199018996274458 our statistic:  0.7055400570681798
epoch :  20
our mmd:  0.009297732263803482 our mmd_std:  0.0033126118820121435 our statistic:  2.8067677696536735


  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 71 MMD:  0.0001420937478542328Our Reject rate_u:  0.71
Test Power of Ours (10 times): 
[[0.23 0.85 0.36 0.35 0.71 0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.5


  0%|          | 0/27 [00:00<?, ?it/s]

epoch :  10
our mmd:  0.003202039748430252 our mmd_std:  0.0032313244309268158 our statistic:  0.990937250925261
epoch :  20
our mmd:  0.01054326817393303 our mmd_std:  0.003352949533850999 our statistic:  3.1444756527020123


  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 5 MMD:  -4.476122558116913e-05Our Reject rate_u:  0.05
Test Power of Ours (10 times): 
[[0.23 0.85 0.36 0.35 0.71 0.05 0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.425


  0%|          | 0/27 [00:00<?, ?it/s]

epoch :  10
our mmd:  0.0037458911538124084 our mmd_std:  0.0032449920198228004 our statistic:  1.1543606674314597
epoch :  20
our mmd:  0.011906946077942848 our mmd_std:  0.003391073104422399 our statistic:  3.511261984418633


  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 41 MMD:  0.00014532171189785004Our Reject rate_u:  0.41
Test Power of Ours (10 times): 
[[0.23 0.85 0.36 0.35 0.71 0.05 0.41 0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.4228571428571429


  0%|          | 0/27 [00:00<?, ?it/s]

epoch :  10
our mmd:  0.00029863929376006126 our mmd_std:  0.003163575348572462 our statistic:  0.0943992985325353
epoch :  20
our mmd:  0.004717525094747543 our mmd_std:  0.0032374109590573327 our statistic:  1.4571906855227268


  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 25 MMD:  4.734843969345093e-05Our Reject rate_u:  0.25
Test Power of Ours (10 times): 
[[0.23 0.85 0.36 0.35 0.71 0.05 0.41 0.25 0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.40125


  0%|          | 0/27 [00:00<?, ?it/s]

epoch :  10
our mmd:  0.0022062789648771286 our mmd_std:  0.0032003904225017927 our statistic:  0.6893780675522856
epoch :  20
our mmd:  0.009832112118601799 our mmd_std:  0.0033368942261677893 our statistic:  2.946485999316003


  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 48 MMD:  6.54757022857666e-05Our Reject rate_u:  0.48
Test Power of Ours (10 times): 
[[0.23 0.85 0.36 0.35 0.71 0.05 0.41 0.25 0.48 0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.41


  0%|          | 0/27 [00:00<?, ?it/s]

epoch :  10
our mmd:  0.0012626005336642265 our mmd_std:  0.003175366544293894 our statistic:  0.39762355496662544
epoch :  20
our mmd:  0.006457749754190445 our mmd_std:  0.003278309662155848 our statistic:  1.9698412961830356


  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 33 MMD:  0.0003005731850862503Our Reject rate_u:  0.33
Test Power of Ours (10 times): 
[[0.23 0.85 0.36 0.35 0.71 0.05 0.41 0.25 0.48 0.33]]
Average Test Power of Ours (10 times): 
Ours:  0.40199999999999997


#### Epoch 20 / lr: 0.0005
[[0.23 0.94 0.25 0.37 0.6  0.1  0.36 0.31 0.23 0.4 ]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Baselines (10 times): 
MMD-D:  0.379