In [1]:
# Outline of this notebook:
#   1. Prepare the model.
#   2. Prepare benign and adversarial images (both must share the same value range).
#   3. Extract a simple NIC model.

In [2]:
from MNIST_models import *
import PI

# Dataset settings handed to PIInterface below.
# NOTE(review): the exact meaning of each key is defined in the PI module — confirm there.
meta_params = {
    'num_of_train_dataset': 1000,
    'num_of_test_dataset': 100,
    'is_flatten': False
}


# NOTE(review): this rebinds the name `PI` from the imported module to a
# PIInterface instance, so the module is no longer reachable as `PI` afterwards.
PI = PI.PIInterface(meta_params)
# load_model is not imported explicitly; presumably it comes from the
# `from MNIST_models import *` above — TODO confirm.
model = load_model('store/MNIST_CNN.pt')
PI.set_model(model)
# Report what eval_model returns for each split (labelled as accuracy in the prints).
print('train acc:', PI.eval_model('train'))
print('test acc:', PI.eval_model('test'))

1000 100
(100, 1, 28, 28)




train acc: 0.991
test acc: 0.98


In [3]:
import pickle
prefix = 'store/'
# LOAD
# 'None' denotes the benign (un-attacked) signatures; every other entry names
# the attack whose signatures are stored in '<prefix><attack>.txt'.
adv_types = ['None', 'FGSM', 'JSMA', 'CWL2', 'LINFPGD', 'LINFBI', 'ENL1', 'ST']
set_of_train_dataset, set_of_test_dataset = [], []

# Fraction of every signature set that goes to training; the remainder is
# held out for evaluation. The same ratio is used for benign and adversarial sets.
split_percentage = 0.8

for adv_type in adv_types:
    # benign signatures live in 'normal.txt', adversarial ones in '<attack>.txt'
    if adv_type == 'None':
        fn_name = prefix + 'normal.txt'
    else:
        fn_name = prefix + adv_type + '.txt'

    # NOTE(security): pickle.load can execute arbitrary code while unpickling;
    # only load signature files that were produced locally / by a trusted source.
    # The context manager guarantees the handle is closed even if loading fails.
    with open(fn_name, 'rb') as fp:
        set_of_signatures = pickle.load(fp)

    # separate and store for later training and evaluation
    # (ordered split: the first part is used for training, the tail for testing)
    split_line = int(len(set_of_signatures) * split_percentage)
    train_set_of_signatures = set_of_signatures[:split_line]
    test_set_of_signatures = set_of_signatures[split_line:]
    set_of_train_dataset.append(train_set_of_signatures)
    set_of_test_dataset.append(test_set_of_signatures)

    print(adv_type, len(set_of_signatures), len(train_set_of_signatures), len(test_set_of_signatures))

None 1000 800 200
FGSM 813 650 163
JSMA 966 772 194
CWL2 877 701 176
LINFPGD 978 782 196
LINFBI 970 776 194
ENL1 1000 800 200
ST 991 792 199


In [None]:
import torch
class NIC(nn.Module):
    """Small detector network that classifies a sample from four feature maps.

    Each of the four inputs is flattened and compressed by its own linear
    layer to a common width, the compressed vectors are summed, and a two-layer
    head maps the sum to two class probabilities. The surrounding notebook
    trains it with index 0 meaning "benign" and index 1 meaning "adversarial".
    """

    def __init__(self):
        # Zero-argument super() does not look up the global name `NIC`, which
        # the notebook later rebinds to a model instance.
        super().__init__()
        n = 128  # shared width of the compressed per-input representations
        m = 32   # width of the hidden layer in the classification head

        # NOTE: attribute names and creation order are kept unchanged so that
        # saved state_dicts and seeded initialisation stay compatible.
        self.fc_compress_1 = nn.Linear(16*24*24, n)
        self.fc_compress_2 = nn.Linear(16*10*10, n)
        self.fc_compress_3 = nn.Linear(32*3*3, n)
        self.fc_compress_4 = nn.Linear(64, n)
        # Explicit dim: logits are (batch, 2), so normalise over the class axis.
        # Omitting `dim` relies on a deprecated implicit choice and emits a warning.
        self.softmax = nn.Softmax(dim=1)
        self.relu = nn.ReLU()
        self.output1 = nn.Linear(n, m)
        self.output2 = nn.Linear(m, 2)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x1, x2, x3, x4):
        """Return class probabilities of shape (batch, 2).

        Args:
            x1: tensor reshapeable to (-1, 16*24*24).
            x2: tensor reshapeable to (-1, 16*10*10).
            x3: tensor reshapeable to (-1, 32*3*3).
            x4: tensor reshapeable to (-1, 64).

        All four inputs must yield the same batch size after flattening,
        otherwise the element-wise sum below fails.
        """
        x1 = x1.view(-1, 16*24*24)
        x2 = x2.view(-1, 16*10*10)
        x3 = x3.view(-1, 32*3*3)
        x4 = x4.view(-1, 64)

        # per-input compression: linear -> ReLU -> dropout
        x1 = self.relu(self.fc_compress_1(x1))
        x1 = self.dropout(x1)
        x2 = self.relu(self.fc_compress_2(x2))
        x2 = self.dropout(x2)
        x3 = self.relu(self.fc_compress_3(x3))
        x3 = self.dropout(x3)
        x4 = self.relu(self.fc_compress_4(x4))
        x4 = self.dropout(x4)

        # fuse the four compressed representations by element-wise addition
        s = torch.add(x1, x2)
        s = torch.add(s, x3)
        s = torch.add(s, x4)
        s = self.relu(s)
        s = self.dropout(s)

        # classification head
        s = self.relu(self.output1(s))
        s = self.dropout(s)

        s = self.output2(s)

        return self.softmax(s)
    
def test_guard_model(NIC, set_of_test_dataset, adv_types, verbose=False):
    NIC.eval()
    total_train_correct_count, total_train_count = 0, 0 
    for test_dataset, adv_type in zip(set_of_test_dataset, adv_types):
        current_count = 0
        for singatures in test_dataset:
            f1, f2, f3, f4 = preprocess(singatures)
            outputs = NIC.forward(f1, f2, f3, f4)
            if adv_type == 'None': label = torch.from_numpy(np.array([[1, 0]])).float()
            else: label = torch.from_numpy(np.array([[0, 1]])).float()

            prediction = (outputs.max(1, keepdim=True)[1]).item()     
            if adv_type == 'None': 
                if (prediction == 0): 
                    current_count += 1
            else: 
                if (prediction == 1): 
                    current_count += 1
            
        # record the current train set acc
        if verbose:
            if adv_type == 'None': 
                print('benign correct:', current_count, '/', len(test_dataset))
            else:
                print('adv (', adv_type, ') correct:', current_count, '/', len(test_dataset))

        total_train_correct_count += current_count
        total_train_count += len(test_dataset)

    acc = total_train_correct_count/total_train_count
    if verbose:
        print('acc:', acc)
        
        
    NIC.train()
    return acc
    
# NOTE(review): this rebinds `NIC` from the class to a model instance; the
# name is kept because the rest of the notebook refers to the model as `NIC`.
NIC = NIC()
NIC.train()
optimizer = torch.optim.Adam(NIC.parameters())
loss_func = nn.BCELoss()

epoches = 30

# per-epoch history: overall train/test accuracy and summed training loss
train_accs, test_accs, losses = [], [], []
# reserved for per-attack accuracies; test_guard_model only returns the overall value
set_train_sub_accs, set_test_sub_accs = [], []

# labeling ...
# The labelled training set does not change between epochs, so it is built once.
# Labels are one-hot rows: [1, 0] = benign, [0, 1] = adversarial.
benign_label = torch.tensor([[1.0, 0.0]])
adversarial_label = torch.tensor([[0.0, 1.0]])
train_dataset, train_labels = [], []
for dataset, adv_type in zip(set_of_train_dataset, adv_types):
    for signatures in dataset:
        if adv_type == 'None':
            # every benign sample is added twice (oversampling the benign class)
            for _ in range(2):
                train_dataset.append(signatures)
                train_labels.append(benign_label)
        else:
            train_dataset.append(signatures)
            train_labels.append(adversarial_label)

for epoch in range(epoches):
    # Accumulate a plain float: summing the loss tensors themselves would keep
    # every sample's autograd graph referenced for the whole epoch.
    total_loss = 0.0

    # shuffling
    shuffle_indexs = np.arange(len(train_dataset))
    np.random.shuffle(shuffle_indexs)

    # training (one sample per optimisation step)
    for index in shuffle_indexs:
        signatures, label = train_dataset[index], train_labels[index]
        f1, f2, f3, f4 = preprocess(signatures)
        outputs = NIC(f1, f2, f3, f4)

        # for recording the training process
        loss = loss_func(outputs, label)
        total_loss += loss.item()

        # Optimization (back-propagation)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('epoch:', (epoch+1), 'loss:', total_loss)
    train_acc = test_guard_model(NIC, set_of_train_dataset, adv_types, verbose=True)
    test_acc = test_guard_model(NIC, set_of_test_dataset, adv_types, verbose=True)
    print('train (acc):', train_acc)
    print('test (acc):', test_acc)
    print()

    # keep the history so the curves can be inspected or plotted afterwards
    losses.append(total_loss)
    train_accs.append(train_acc)
    test_accs.append(test_acc)
    






epoch: 1 loss: 1809.939697265625
benign correct: 722 / 800
adv ( FGSM ) correct: 650 / 650
adv ( JSMA ) correct: 769 / 772
adv ( CWL2 ) correct: 661 / 701
adv ( LINFPGD ) correct: 773 / 782
adv ( LINFBI ) correct: 745 / 776
adv ( ENL1 ) correct: 766 / 800
adv ( ST ) correct: 784 / 792
acc: 0.9665733574839454
benign correct: 180 / 200
adv ( FGSM ) correct: 163 / 163
adv ( JSMA ) correct: 193 / 194
adv ( CWL2 ) correct: 152 / 176
adv ( LINFPGD ) correct: 176 / 196
adv ( LINFBI ) correct: 176 / 194
adv ( ENL1 ) correct: 185 / 200
adv ( ST ) correct: 198 / 199
acc: 0.9349540078843627
train (acc): 0.9665733574839454
test (acc): 0.9349540078843627

epoch: 2 loss: 913.9996337890625
benign correct: 727 / 800
adv ( FGSM ) correct: 650 / 650
adv ( JSMA ) correct: 771 / 772
adv ( CWL2 ) correct: 675 / 701
adv ( LINFPGD ) correct: 780 / 782
adv ( LINFBI ) correct: 765 / 776
adv ( ENL1 ) correct: 786 / 800
adv ( ST ) correct: 787 / 792
acc: 0.9782644492013832
benign correct: 154 / 200
adv ( FGSM ) 

adv ( CWL2 ) correct: 172 / 176
adv ( LINFPGD ) correct: 187 / 196
adv ( LINFBI ) correct: 189 / 194
adv ( ENL1 ) correct: 198 / 200
adv ( ST ) correct: 198 / 199
acc: 0.9638633377135348
train (acc): 0.9934134694549646
test (acc): 0.9638633377135348

epoch: 14 loss: 266.3631896972656
benign correct: 765 / 800
adv ( FGSM ) correct: 650 / 650
adv ( JSMA ) correct: 772 / 772
adv ( CWL2 ) correct: 692 / 701
adv ( LINFPGD ) correct: 782 / 782
adv ( LINFBI ) correct: 776 / 776
adv ( ENL1 ) correct: 793 / 800
adv ( ST ) correct: 789 / 792
acc: 0.9911081837642022
benign correct: 164 / 200
adv ( FGSM ) correct: 163 / 163
adv ( JSMA ) correct: 194 / 194
adv ( CWL2 ) correct: 173 / 176
adv ( LINFPGD ) correct: 191 / 196
adv ( LINFBI ) correct: 191 / 194
adv ( ENL1 ) correct: 199 / 200
adv ( ST ) correct: 198 / 199
acc: 0.9678055190538765
train (acc): 0.9911081837642022
test (acc): 0.9678055190538765

epoch: 15 loss: 252.17510986328125
benign correct: 770 / 800
adv ( FGSM ) correct: 650 / 650
adv 