In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "4"

In [39]:
import torch
from simclr import SimCLR
from model.resnet_simclr import ResNetSimCLR
import torch.backends.cudnn as cudnn
from tqdm.auto import tqdm
import numpy as np
import jax.numpy as jnp
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms

from utils_HD import MatConvert, MMDu, TST_MMD_u

In [40]:
class Config:
    def __init__(self):
        self.cifar10 = "/data4/oldrain123/C2ST/data/cifar_data/cifar10"
        self.cifar10_1 = "/data4/oldrain123/C2ST/data/cifar_data/cifar10.1_v4_data.npy"
        self.arch = 'resnet18'
        self.workers = 12
        self.epochs = 200
        self.batch_size = 256
        self.lr = 0.0003
        self.weight_decay = 1e-4
        self.seed = None
        self.disable_cuda = False
        self.fp16_precision = False
        self.out_dim = 128
        self.log_every_n_steps = 100
        self.temperature = 0.07
        self.n_views = 2
        self.gpu_index = 0

# Now you can create an instance of Config and access the attributes
args = Config()

if not args.disable_cuda and torch.cuda.is_available():
        args.device = torch.device('cuda')
        cudnn.deterministic = True 
        cudnn.benchmark = True 
else:
    args.device = torch.device('cpu')
    args.gpu_index = -1

In [36]:
model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim)
optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10000, eta_min=0,
                                                        last_epoch=-1)

simclr_model = SimCLR(args = args, model = model, optimizer = optimizer, scheduler = scheduler)



In [37]:
checkpoint_path = '/home/oldrain123/C2ST/runs/Nov09_23-18-07_brl2/checkpoint_0200.pth.tar'
checkpoint = torch.load(checkpoint_path, map_location=args.device)

In [38]:
simclr_model.model.load_state_dict(checkpoint['state_dict'])
simclr_model.optimizer.load_state_dict(checkpoint['optimizer'])

In [41]:
np.random.seed(819)
torch.manual_seed(819)
torch.cuda.manual_seed(819)
torch.backends.cudnn.deterministic = True
is_cuda = True

In [42]:
# Parameter Settings
n_epochs = 1000
batch_size = 100
lr = 0.0002
img_size = 64
channels = 3
n = 1000

In [43]:
dtype = torch.float
device = torch.device("cuda:0")
cuda = True if torch.cuda.is_available() else False
N_per = 100 # permutation times
alpha = 0.05 # test threshold
N1 = n # number of samples in one set
K = 10 # number of trails
J = 1 # number of test locations
N = 100 # number of test sets
N_f = 100.0 # number of test sets (float)

In [44]:
# Naming variables
ep_OPT = np.zeros([K])
s_OPT = np.zeros([K])
s0_OPT = np.zeros([K])
T_org_OPT = torch.zeros([K,J,3,64,64]) # Record test locations obtained by MMD-D
DK_Results = np.zeros([1,K])

In [45]:
# Configure data loader
dataset_test = datasets.CIFAR10(root='/data4/oldrain123/C2ST/data/cifar_data/cifar10', download=True,train=False,
                           transform=transforms.Compose([
                               transforms.Resize(img_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=10000,
                                             shuffle=True, num_workers=1)
# Obtain CIFAR10 images
for i, (imgs, Labels) in enumerate(dataloader_test):
    data_all = imgs
    label_all = Labels
Ind_all = np.arange(len(data_all))

# Obtain CIFAR10.1 images
data_new = np.load('/data4/oldrain123/C2ST/data/cifar10_1/cifar10.1_v4_data.npy')
data_T = np.transpose(data_new, [0,3,1,2])
ind_M = np.random.choice(len(data_T), len(data_T), replace=False)
data_T = data_T[ind_M]
TT = transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trans = transforms.ToPILImage()
data_trans = torch.zeros([len(data_T),3,img_size,img_size])
data_T_tensor = torch.from_numpy(data_T)
for i in range(len(data_T)):
    d0 = trans(data_T_tensor[i])
    data_trans[i] = TT(d0)
Ind_v4_all = np.arange(len(data_T))

Files already downloaded and verified


In [None]:
# Repeat experiments K times (K = 10) and report average test power (rejection rate)
for kk in tqdm(range(K)):
    torch.manual_seed(kk * 19 + N1)
    torch.cuda.manual_seed(kk * 19 + N1)
    np.random.seed(seed=1102 * (kk + 10) + N1)
    
    # Initialize parameters

    # Collect CIFAR10 images
    Ind_tr = np.random.choice(len(data_all), len(data_all), replace=False)
    Ind_te = Ind_tr
    train_data = []
    for i in Ind_tr:
       train_data.append([data_all[i], label_all[i]])

    dataloader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
    )

    # Collect CIFAR10.1 images
    np.random.seed(seed=819 * (kk + 9) + N1)
    Ind_tr_v4 = np.random.choice(len(data_T), len(data_T), replace=False)
    Ind_te_v4 = Ind_tr_v4
    New_CIFAR_tr = data_trans[Ind_tr_v4]
    New_CIFAR_te = data_trans[Ind_te_v4]

    # Initialize optimizers
    optimizer_F = torch.optim.Adam(list(simclr_model.model.parameters()), lr=lr)
    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # ----------------------------------------------------------------------------------------------------
    #  Training deep networks for MMD-D (called simclr_model)
    # ----------------------------------------------------------------------------------------------------
    np.random.seed(seed=1102)
    torch.manual_seed(1102)
    torch.cuda.manual_seed(1102)
    for epoch in tqdm(range(n_epochs)):
        for i, (imgs, _) in enumerate(dataloader):
            if True:
                ind = np.random.choice(N1, imgs.shape[0], replace=False)
                Fake_imgs = New_CIFAR_tr[ind]
                # Adversarial ground truths
                valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
                fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

                # Configure input
                real_imgs = Variable(imgs.type(Tensor))
                Fake_imgs = Variable(Fake_imgs.type(Tensor))
                X = torch.cat([real_imgs, Fake_imgs], 0)
                Y = torch.cat([valid, fake], 0).squeeze().long()

                # ------------------------------
                #  Train deep network for MMD-D
                # ------------------------------
                # Initialize optimizer
                optimizer_F.zero_grad()
                # Compute output of deep network
                modelu_output = simclr_model.model(X)
                print(modelu_output)

In [46]:
# Repeat experiments K times (K = 10) and report average test power (rejection rate)
for kk in tqdm(range(K)):
    torch.manual_seed(kk * 19 + N1)
    torch.cuda.manual_seed(kk * 19 + N1)
    np.random.seed(seed=1102 * (kk + 10) + N1)
    
    # Initialize deep networks for MMD-D (called simclr_model)
    # simclr_model = simclr_model()
    
    # Initialize parameters
    epsilonOPT = torch.log(MatConvert(np.random.rand(1) * 10 ** (-10), device, dtype))
    epsilonOPT.requires_grad = True
    sigmaOPT = MatConvert(np.ones(1) * np.sqrt(2 * 32 * 32), device, dtype)
    sigmaOPT.requires_grad = True
    sigma0OPT = MatConvert(np.ones(1) * np.sqrt(0.005), device, dtype)
    sigma0OPT.requires_grad = True

    # Collect CIFAR10 images
    Ind_tr = np.random.choice(len(data_all), len(data_all), replace=False)
    Ind_te = Ind_tr
    train_data = []
    for i in Ind_tr:
       train_data.append([data_all[i], label_all[i]])

    dataloader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
    )

    # Collect CIFAR10.1 images
    np.random.seed(seed=819 * (kk + 9) + N1)
    Ind_tr_v4 = np.random.choice(len(data_T), len(data_T), replace=False)
    Ind_te_v4 = Ind_tr_v4
    New_CIFAR_tr = data_trans[Ind_tr_v4]
    New_CIFAR_te = data_trans[Ind_te_v4]

    # Initialize optimizers
    optimizer_F = torch.optim.Adam(list(simclr_model.model.parameters()) + [epsilonOPT] + [sigmaOPT] + [sigma0OPT], lr=lr)

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    # ----------------------------------------------------------------------------------------------------
    #  Training deep networks for MMD-D (called simclr_model)
    # ----------------------------------------------------------------------------------------------------
    np.random.seed(seed=1102)
    torch.manual_seed(1102)
    torch.cuda.manual_seed(1102)

    # Run two-sample test on the training set
    # Fetch training data
    s1 = data_all[Ind_tr]
    s2 = data_trans[Ind_tr_v4]
    S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
    Sv = S.view(2 * N1, -1)
    # Run two-sample test (MMD-D) on the training set
    # dk_h_u, dk_threshold_u, dk_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N1, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=False)
    # com_h_u, com_threshold_u, com_mmd_value_u = TST_MMD_u(featurizer(S), N_per, N1, Sv, sigma, sigma0_u, ep, alpha, device, dtype, complete=True)

    # Record best epsilon, sigma and sigma_0
    ep_OPT[kk] = epsilonOPT.item()
    s_OPT[kk] = sigmaOPT
    s0_OPT[kk] = sigma0OPT

    # Compute test power of MMD-D and baselines
    DK_H_u = np.zeros(N)

    np.random.seed(1102)
    dk_count_u = 0

    for k in tqdm(range(N)):
        # Fetch test data
        np.random.seed(seed=1102 * (k + 1) + N1)
        data_all_te = data_all[Ind_te]
        N_te = len(data_trans)-N1
        Ind_N_te = np.random.choice(len(Ind_te), N_te, replace=False)
        s1 = data_all_te[Ind_N_te]
        s2 = data_trans[Ind_te_v4]
        S = torch.cat([s1.cpu(), s2.cpu()], 0).cuda()
        Sv = S.view(2 * N_te, -1)
        # MMD-D
        dk_h_u, dk_threshold_u, dk_mmd_value_u = TST_MMD_u(simclr_model.model(S), N_per, N_te, Sv, sigmaOPT, sigma0OPT, epsilonOPT, alpha, device, dtype, complete=False)

        # Gather results
        dk_count_u = dk_count_u + dk_h_u
        print("\r","MMD-DK:", dk_count_u, "MMD: ", dk_mmd_value_u, end="")
        DK_H_u[k] = dk_h_u

    # Print test power of MMD-D and baselines
    print("DK Reject rate_u: ", DK_H_u.sum() / N_f)
    DK_Results[0, kk] = DK_H_u.sum() / N_f
    print(f"Test Power of DK ({K} times): ")
    print(f"{DK_Results}")
    print(f"Average Test Power of DK ({K} times): ")
    print("MMD-D: ", (DK_Results.sum(1) / (kk + 1))[0])

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]



 MMD-DK: 5 MMD:  -7.303032134586829e-08DK Reject rate_u:  0.05
Test Power of DK (10 times): 
[[0.05 0.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
Average Test Power of DK (10 times): 
MMD-D:  0.05


  0%|          | 0/100 [00:00<?, ?it/s]

 MMD-DK: 0 MMD:  3.837164399556059e-0797