In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import os
import sys

sys.path.append('/home/oldrain123/MMD/')
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

In [3]:
from tqdm.auto import tqdm
import numpy as np
import jax.numpy as jnp

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
from utils import jnp_to_tensor
from utils_HD import MatConvert, MMDu, TST_MMD_u

In [4]:
np.random.seed(819)
torch.manual_seed(819)
torch.cuda.manual_seed(819)
torch.backends.cudnn.deterministic = True
is_cuda = True

In [5]:
# Parameter Settings
n_epochs = 1000
batch_size = 512
lr = 2e-5
img_size = 64
channels = 3
n = 1000

In [6]:
dtype = torch.float
device = torch.device("cuda:0")
cuda = True if torch.cuda.is_available() else False
N_per = 100 # permutation times
alpha = 0.05 # test threshold
N1 = n # number of samples in one set
K = 10 # number of trails
J = 1 # number of test locations
N = 100 # number of test sets
N_f = 100.0 # number of test sets (float)

In [7]:
# Naming variables
ep_OPT = np.zeros([K])
s_OPT = np.zeros([K])
s0_OPT = np.zeros([K])
T_org_OPT = torch.zeros([K,J,3,64,64]) # Record test locations obtained by MMD-D
COM_Results = np.zeros([1,K])

In [8]:
# Define the deep network for MMD-D
class Featurizer_COM(nn.Module):
    def __init__(self):
        super(Featurizer_COM, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0)] #0.25
            if bn:
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # The height and width of downsampled image
        ds_size = img_size // 2 ** 4
        self.adv_layer = nn.Sequential(
            nn.Linear(128 * ds_size ** 2, 300))

    def forward(self, img):
        out = self.model(img)
        out = out.view(out.shape[0], -1)
        feature = self.adv_layer(out)

        return feature

In [9]:
# Configure data loader
dataset_test = datasets.CIFAR10(root='/data4/oldrain123/C2ST/data/cifar_data/cifar10', download=True,train=False,
                           transform=transforms.Compose([
                               transforms.Resize(img_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=10000,
                                             shuffle=True, num_workers=1)
# Obtain CIFAR10 images
for i, (imgs, Labels) in enumerate(dataloader_test):
    data_all = imgs
    label_all = Labels
Ind_all = np.arange(len(data_all))

# Obtain CIFAR10.1 images
data_new = np.load('/data4/oldrain123/C2ST/data/cifar10_1/cifar10.1_v4_data.npy')
data_T = np.transpose(data_new, [0,3,1,2])
ind_M = np.random.choice(len(data_T), len(data_T), replace=False)
data_T = data_T[ind_M]
TT = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trans = transforms.ToPILImage()
data_trans = torch.zeros([len(data_T),3,img_size,img_size])
data_T_tensor = torch.from_numpy(data_T)
for i in range(len(data_T)):
    d0 = trans(data_T_tensor[i])
    data_trans[i] = TT(d0)
Ind_v4_all = np.arange(len(data_T))

Files already downloaded and verified


In [10]:
# Repeat experiments K times (K = 10) and report average test power (rejection rate)
for kk in tqdm(range(K)):
    torch.manual_seed(kk * 19 + N1)
    torch.cuda.manual_seed(kk * 19 + N1)
    np.random.seed(seed=1102 * (kk + 10) + N1)
    
    # Initialize deep networks for MMD-D (called featurizer)
    featurizer_com = Featurizer_COM()
    
    # Initialize parameters
    epsilonOPT = torch.log(MatConvert(np.random.rand(1) * 10 ** (-10), device, dtype))
    epsilonOPT.requires_grad = True
    sigmaOPT = MatConvert(np.ones(1) * np.sqrt(2 * 32 * 32), device, dtype)
    sigmaOPT.requires_grad = True
    sigma0OPT = MatConvert(np.ones(1) * np.sqrt(0.005), device, dtype)
    sigma0OPT.requires_grad = True
    if cuda:
        featurizer_com.cuda()

    # Collect CIFAR10 images
    Ind_tr = np.random.choice(len(data_all), N1, replace=False)
    Ind_te = np.delete(Ind_all, Ind_tr)
    train_data = []
    for i in Ind_tr:
       train_data.append([data_all[i], label_all[i]])

    dataloader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
    )

    # Collect CIFAR10.1 images
    np.random.seed(seed=819 * (kk + 9) + N1)
    Ind_tr_v4 = np.random.choice(len(data_T), N1, replace=False)
    Ind_te_v4 = np.delete(Ind_v4_all, Ind_tr_v4)
    New_CIFAR_tr = data_trans[Ind_tr_v4]
    New_CIFAR_te = data_trans[Ind_te_v4]

    # Initialize optimizers
    optimizer_COM = torch.optim.Adam(list(featurizer_com.parameters()) + [epsilonOPT] + [sigmaOPT] + [sigma0OPT], lr=lr, weight_decay=1e-4)
    scheduler_COM = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_COM, 'min', factor=0.5, patience=20, verbose=True)
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # ----------------------------------------------------------------------------------------------------
    #  Training deep networks for MMD-D (called featurizer)
    # ----------------------------------------------------------------------------------------------------
    np.random.seed(seed=1102)
    torch.manual_seed(1102)
    torch.cuda.manual_seed(1102)
    for epoch in tqdm(range(n_epochs)):
        for i, (imgs, _) in enumerate(dataloader):
            if True:
                ind = np.random.choice(N1, imgs.shape[0], replace=False)
                Fake_imgs = New_CIFAR_tr[ind]
                # Adversarial ground truths
                valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
                fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

                # Configure input
                real_imgs = Variable(imgs.type(Tensor))
                Fake_imgs = Variable(Fake_imgs.type(Tensor))
                X = torch.cat([real_imgs, Fake_imgs], 0)
                Y = torch.cat([valid, fake], 0).squeeze().long()

                # ------------------------------
                #  Train deep network for MMD-D
                # ------------------------------
                # Initialize optimizer
                optimizer_COM.zero_grad()
                
                # Compute output of deep network
                com_modelu_output = featurizer_com(X)
                
                # Compute epsilon, sigma and sigma_0
                ep = torch.exp(epsilonOPT) / (1 + torch.exp(epsilonOPT))
                sigma = sigmaOPT ** 2
                sigma0_u = sigma0OPT ** 2
                
                # Compute Compute J (STAT_u)
                COM_TEMP = MMDu(com_modelu_output, imgs.shape[0], X.view(X.shape[0],-1), sigma, sigma0_u, ep, complete=True)
                com_mmd = - COM_TEMP[0]
                com_mmd_std_tmp_tensor = jnp_to_tensor(np.abs(COM_TEMP[1]))
                com_mmd_std = torch.sqrt(com_mmd_std_tmp_tensor)
                COM_STAT_u = torch.div(com_mmd, com_mmd_std)
                # COM_LOSS = - torch.log(com_mmd) + torch.log(com_mmd_std)
                
                if (epoch+1) % 100 == 0:
                    print("=" * 110)
                    print("epoch : ",epoch+1)
                    print("our mmd: ", -1 * com_mmd.item(), "our mmd_std: ", com_mmd_std.item(), "our statistic: ",
                    -1 * COM_STAT_u.item())
                    
                # Compute gradient
                COM_STAT_u.backward()
                # COM_LOSS.backward()
                
                # Update weights using gradient descent
                optimizer_COM.step()
                scheduler_COM.step(COM_STAT_u.item())
            else:
                break

    # Compute test power of MMD-D and baselines
    DK_H_u = np.zeros(N)
    COM_H_u = np.zeros(N)

    np.random.seed(1102)
    dk_count_u = 0
    com_count_u = 0

    for k in tqdm(range(N)):
        # Fetch test data
        np.random.seed(seed=1102 * (k + 1) + N1)
        data_all_te = data_all[Ind_te]
        N_te = len(data_trans)-N1
        Ind_N_te = np.random.choice(len(Ind_te), N_te, replace=False)
        s1 = data_all_te[Ind_N_te]
        s2 = data_trans[Ind_te_v4]
        S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
        Sv = S.view(2 * N_te, -1)
        # MMD-D
        com_h_u, com_threshold_u, com_mmd_value_u = TST_MMD_u(featurizer_com(S), N_per, N_te, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=True)

        # Gather results
        com_count_u = com_count_u + com_h_u
        print("\r","Ours:", com_count_u, "MMD: ", com_mmd_value_u, end="")
        COM_H_u[k] = com_h_u

    # Print test power of MMD-D and baselines
    print("Our Reject rate_u: ", COM_H_u.sum() / N_f)
    COM_Results[0, kk] = COM_H_u.sum() / N_f
    print(f"Test Power of Ours ({K} times): ")
    print(f"{COM_Results}")
    print(f"Average Test Power of Ours ({K} times): ")
    print("Ours: ", (COM_Results.sum(1) / (kk + 1))[0])

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
CUDA backend failed to initialize: Found cuPTI version 18, but JAX was built against version 20, which is newer. The copy of cuPTI that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Epoch 00115: reducing learning rate of group 0 to 1.0000e-05.
Epoch 00137: reducing learning rate of group 0 to 5.0000e-06.
Epoch 00158: reducing learning rate of group 0 to 2.5000e-06.
Epoch 00191: reducing learning rate of group 0 to 1.2500e-06.
epoch :  100
our mmd:  0.004952181130647659 our mmd_std:  0.00109793539726967 our statistic:  4.510448559143527
epoch :  100
our mmd:  0.00493897870182991 our mmd_std:  0.001030375246625797 our statistic:  4.793378643366815
Epoch 00230: reducing learning rate of group 0 to 6.2500e-07.
Epoch 00251: reducing learning rate of group 0 to 3.1250e-07.
Epoch 00272: reducing learning rate of group 0 to 1.5625e-07.
Epoch 00293: reducing learning rate of group 0 to 7.8125e-08.
Epoch 00314: reducing learning rate of group 0 to 3.9063e-08.
Epoch 00335: reducing learning rate of group 0 to 1.9531e-08.
epoch :  200
our mmd:  0.004211977124214172 our mmd_std:  0.0009826677304082057 our statistic:  4.28626787455867
epoch :  200
our mmd:  0.005936749279499054

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 19 MMD:  5.2247196435928345e-06Our Reject rate_u:  0.19
Test Power of Ours (10 times): 
[[0.19 0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.19


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 00033: reducing learning rate of group 0 to 1.0000e-05.
Epoch 00054: reducing learning rate of group 0 to 5.0000e-06.
Epoch 00075: reducing learning rate of group 0 to 2.5000e-06.
Epoch 00096: reducing learning rate of group 0 to 1.2500e-06.
Epoch 00149: reducing learning rate of group 0 to 6.2500e-07.
Epoch 00170: reducing learning rate of group 0 to 3.1250e-07.
Epoch 00191: reducing learning rate of group 0 to 1.5625e-07.
epoch :  100
our mmd:  0.000616920180618763 our mmd_std:  0.00024225442812969894 our statistic:  2.546579583216016
epoch :  100
our mmd:  0.0008243517950177193 our mmd_std:  0.000301079738870677 our statistic:  2.7379849541180974
Epoch 00212: reducing learning rate of group 0 to 7.8125e-08.
Epoch 00233: reducing learning rate of group 0 to 3.9063e-08.
Epoch 00268: reducing learning rate of group 0 to 1.9531e-08.
epoch :  200
our mmd:  0.0010300762951374054 our mmd_std:  0.0003353424941366634 our statistic:  3.0717141822104255
epoch :  200
our mmd:  0.000395031

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 95 MMD:  0.00018450245261192322Our Reject rate_u:  0.95
Test Power of Ours (10 times): 
[[0.19 0.95 0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.57


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 00037: reducing learning rate of group 0 to 1.0000e-05.
Epoch 00062: reducing learning rate of group 0 to 5.0000e-06.
Epoch 00083: reducing learning rate of group 0 to 2.5000e-06.
Epoch 00104: reducing learning rate of group 0 to 1.2500e-06.
Epoch 00125: reducing learning rate of group 0 to 6.2500e-07.
Epoch 00146: reducing learning rate of group 0 to 3.1250e-07.
Epoch 00167: reducing learning rate of group 0 to 1.5625e-07.
Epoch 00188: reducing learning rate of group 0 to 7.8125e-08.
epoch :  100
our mmd:  0.00028167199343442917 our mmd_std:  0.00012178541209041069 our statistic:  2.312854951998047
epoch :  100
our mmd:  0.0006137210875749588 our mmd_std:  0.00022831761759908712 our statistic:  2.6880145913777818
Epoch 00209: reducing learning rate of group 0 to 3.9063e-08.
Epoch 00230: reducing learning rate of group 0 to 1.9531e-08.
epoch :  200
our mmd:  0.00045423442497849464 our mmd_std:  0.0001892619343719301 our statistic:  2.4000305528202572
epoch :  200
our mmd:  0.0004

  0%|          | 0/100 [00:00<?, ?it/s]

 Ours: 16 MMD:  6.969878450036049e-05Our Reject rate_u:  0.16
Test Power of Ours (10 times): 
[[0.19 0.95 0.16 0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Ours (10 times): 
Ours:  0.4333333333333333


  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

#### Epoch 20 / lr: 0.0005
[[0.23 0.94 0.25 0.37 0.6  0.1  0.36 0.31 0.23 0.4 ]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of Baselines (10 times): 
MMD-D:  0.379