In [None]:
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from two_layer_net import TwoLayerNet
from keras.datasets import mnist
from common.functions import softmax

Using TensorFlow backend.


In [None]:
# Load MNIST through Keras as ((train images, train labels), (test images, test labels)).
(x_train, t_train), (x_test, t_test) = mnist.load_data()

# Flatten each image into a 784-element row vector (1-D) and rescale the 0-255 pixel
# values to [0, 1] as float32 (normalisation).
x_train = x_train.reshape(-1, 784).astype(np.float32) / 255.0
x_test = x_test.reshape(-1, 784).astype(np.float32) / 255.0

In [None]:
# Private copy of the first training image, so later experiments cannot alter x_train.
x_base = np.copy(x_train[0])
x_base

In [None]:
# Render the flattened sample back as a 28x28 image.
fig, ax = plt.subplots()
ax.imshow(x_base.reshape(28, 28))
ax.set_title("x_base")
plt.show()

In [None]:
# Rebuild the 784-50-10 network and load its saved parameters.
# NOTE(review): load_params presumably unpickles the file -- only load trusted files.
net_config = dict(input_size=784, hidden_size=50, output_size=10)
network = TwoLayerNet(**net_config)
network.load_params("params_nn.pkl")
print("loaded Network Parameters!")

In [None]:
# Forward pass for the single sample. The output is presumably unnormalised scores,
# since softmax is applied to it separately when computing the confidence.
pred = network.predict(x_base)
pred

In [None]:
# Predicted digit = index of the highest score.
pred_label = pred.argmax()
pred_label

In [None]:
# Confidence of the prediction: the largest softmax probability.
# np.max instead of the builtin max: vectorised, and still well-defined if pred is ever 2-D
# (the builtin would iterate rows and fail on the ambiguous array comparison).
pred_score = np.max(softmax(pred))
pred_score

In [None]:
# Gradients of the loss for x_base as a batch of one, using its true label as the target.
# x_base is x_train[0], so its label is t_train[0]; this replaces the hard-coded digit 5,
# which silently went stale whenever x_base was changed.
base_label = int(t_train[0])
x = x_base.copy().reshape(1, 784)
grads = network.gradient(x, np.array([base_label]))
grads

In [None]:
# Print the complete W1 gradient without numpy's "..." summarisation.
# The context manager scopes the unlimited threshold to this statement. The original
# global np.set_printoptions leaked into every later cell's array output.
with np.printoptions(threshold=np.inf):
    print(grads['W1'])

In [None]:
print(np.shape(grads['W1']))  # (input size, hidden size) of the first-layer weight gradient

In [None]:
# Collapse the W1 gradient onto the hidden units by summing over the input axis (axis 0).
grads_w1_sum = grads['W1'].sum(axis=0)
grads_w1_sum

In [None]:
# One bucket per digit 0-9 for the indices of misclassified test images.
false_list = [[] for _ in range(10)]
false_list

In [None]:
# One bucket per digit 0-9 for the indices of correctly classified test images.
true_list = [[] for _ in range(10)]
true_list

In [None]:
# Bucket every test index by its true digit, split by whether the network predicts it
# correctly (true_list) or not (false_list).
# Both lists are rebuilt here: the original appended into lists created in other cells,
# so re-running this cell silently duplicated every index.
true_list = [[] for _ in range(10)]
false_list = [[] for _ in range(10)]

for idx, (image, true_label) in enumerate(zip(x_test, t_test)):
    predicted_label = np.argmax(network.predict(image))
    if predicted_label == true_label:
        true_list[true_label].append(idx)
    else:
        false_list[true_label].append(idx)

In [None]:
# Number of test images of digit 5 that the network classified correctly.
len(true_list[5])

In [None]:
# Number of test images of digit 5 that the network misclassified.
len(false_list[5])

In [None]:
def all_g(nums, label):
    """Accumulate the hidden-unit W1-gradient sums over a set of test samples.

    Args:
        nums: indices into ``x_test``.
        label: target class passed to ``network.gradient`` for every sample.

    Returns:
        A length-50 float array. Each sample's W1 gradient is summed over the input
        axis (axis 0), and these per-hidden-unit totals are added up over ``nums``.
    """
    per_sample = (
        network.gradient(x_test[idx].reshape(1, 784), np.array([label]))['W1'].sum(axis=0)
        for idx in nums
    )
    return sum(per_sample, np.zeros(50))
    

In [None]:
def compare_g(t_nums, f_nums, label):
    """Compare hidden-unit W1-gradient sums of correctly vs. wrongly classified samples.

    Plots the per-hidden-unit sums from ``all_g`` side by side (True = blue,
    False = green). Then prints, for each group, the 10 hidden units with the
    largest absolute sum together with that absolute value.

    Args:
        t_nums: ``x_test`` indices of correctly classified samples.
        f_nums: ``x_test`` indices of misclassified samples.
        label: target class used for every gradient computation.

    Returns:
        The element-wise difference ``t - f`` of the two gradient-sum vectors.
    """
    t = all_g(t_nums, label)
    f = all_g(f_nums, label)

    w = 0.4
    units = np.arange(len(t))
    plt.figure(figsize=(30, 10))
    plt.bar(units, t, color='b', label='True', width=w, align="center")
    plt.bar(units + w, f, color='g', label='False', width=w, align="center")
    plt.title(f"label {label}: sum of W1 gradients per hidden unit")
    plt.xlabel("hidden unit")
    plt.ylabel("sum of gradients")
    plt.legend(loc=2)
    plt.show()

    def print_top10(values):
        # Sort once by magnitude (descending); the original re-sorted for every printed row.
        magnitudes = np.abs(values)
        for unit in np.argsort(magnitudes)[::-1][:10]:
            print(unit, magnitudes[unit])

    print("True TOP10")
    print_top10(t)

    print("===============")

    print("False TOP10")
    print_top10(f)

    return t - f

In [None]:
def all_c(nums, label):
    """Count, per hidden unit, the samples whose W1-gradient column sum is non-zero.

    Args:
        nums: indices into ``x_test``.
        label: target class passed to ``network.gradient`` for every sample.

    Returns:
        A length-50 float array of counts (one increment per sample in which the
        hidden unit's gradient, summed over the input axis, is not exactly zero).
    """
    counts = np.zeros(50)
    for idx in nums:
        grad = network.gradient(x_test[idx].reshape(1, 784), np.array([label]))
        hidden_totals = grad['W1'].sum(axis=0)
        counts = counts + (hidden_totals != 0)
    return counts
    

In [None]:
def compare_c(t_nums, f_nums, label):
    """Compare per-hidden-unit non-zero-gradient counts of correct vs. wrong samples.

    Plots the counts from ``all_c`` side by side (True = blue, False = green) and
    prints each group's 10 highest counts. Then plots ``t - f`` and prints the 10
    hidden units with the largest absolute difference.

    Args:
        t_nums: ``x_test`` indices of correctly classified samples.
        f_nums: ``x_test`` indices of misclassified samples.
        label: target class used for every gradient computation.

    Returns:
        The element-wise difference ``t - f`` of the two count vectors.
    """
    t = all_c(t_nums, label)
    f = all_c(f_nums, label)
    units = np.arange(len(t))

    w = 0.4
    plt.figure(figsize=(30, 10))
    plt.bar(units, t, color='b', label='True', width=w, align="center")
    plt.bar(units + w, f, color='g', label='False', width=w, align="center")
    plt.title(f"label {label}: samples with non-zero W1 gradient per hidden unit")
    plt.xlabel("hidden unit")
    plt.ylabel("count")
    plt.legend(loc=2)
    plt.show()

    def print_top10(values):
        # Sort once (descending); the original re-sorted for every printed row.
        for unit in np.argsort(values)[::-1][:10]:
            print(unit, values[unit])

    print("True TOP10")
    print_top10(t)

    print("===============")

    print("False TOP10")
    print_top10(f)

    diff_tf = t - f

    # A fresh figure for the difference plot (the original cleared the current axes instead).
    plt.figure(figsize=(30, 10))
    plt.bar(units, diff_tf)
    plt.title(f"label {label}: True - False counts")
    plt.show()

    print("True - False TOP10")

    # Bug fix: the original wrote np.sort(abs(diff_tf)[::-1][i]). That sorts a single
    # scalar (numpy rejects sorting a 0-d value) instead of printing the i-th largest
    # magnitude.
    abs_diff = np.abs(diff_tf)
    for unit in np.argsort(abs_diff)[::-1][:10]:
        print(unit, abs_diff[unit])

    return diff_tf

In [None]:
# For every digit, contrast the hidden-unit gradient sums (compare_g) and non-zero counts
# (compare_c) of the first 10 correctly vs. the first 10 wrongly classified test samples.
for digit in range(10):
    print(digit)
    correct_idx = true_list[digit][:10]
    wrong_idx = false_list[digit][:10]

    compare_g(correct_idx, wrong_idx, digit)
    compare_c(correct_idx, wrong_idx, digit)

    print("#############################")

In [None]:
def all_g_x(nums, label):
    """Accumulate the per-input-pixel W1-gradient sums over a set of test samples.

    Args:
        nums: indices into ``x_test``.
        label: target class passed to ``network.gradient`` for every sample.

    Returns:
        A length-784 float array. Each sample's W1 gradient is summed over the hidden
        axis (axis 1), and these per-pixel totals are added up over ``nums``.
    """
    total = np.zeros(784)
    gradients = (network.gradient(x_test[n].reshape(1, 784), np.array([label])) for n in nums)
    for grad in gradients:
        total += grad['W1'].sum(axis=1)
    return total
    

In [None]:
def all_c_x(nums, label):
    """Count, per input pixel, the samples whose W1-gradient row sum is non-zero.

    Args:
        nums: indices into ``x_test``.
        label: target class passed to ``network.gradient`` for every sample.

    Returns:
        A length-784 float array of counts (one increment per sample in which the
        pixel's gradient, summed over the hidden axis, is not exactly zero).
    """
    counts = np.zeros(784)
    for n in nums:
        pixel_totals = network.gradient(x_test[n].reshape(1, 784), np.array([label]))['W1'].sum(axis=1)
        counts += pixel_totals != 0
    return counts
    

In [None]:
def compare_g_x(t_nums, f_nums, label, top_k=50):
    """Compare per-pixel W1-gradient sums of correctly vs. wrongly classified samples.

    Shows three 28x28 maps built from ``all_g_x``: True sums, False sums and their
    difference. Then shows the same maps restricted to each group's ``top_k`` pixels
    with the largest absolute gradient sum; the sparse False map is also clipped
    to [-1, 1].

    Args:
        t_nums: ``x_test`` indices of correctly classified samples.
        f_nums: ``x_test`` indices of misclassified samples.
        label: target class used for every gradient computation.
        top_k: number of pixels kept per group in the sparse maps (default 50).

    Returns:
        ``(t_img, f_img)``: length-784 arrays that are zero except at each group's
        top-k pixels, where they hold that group's own gradient sum.
    """
    t = all_g_x(t_nums, label)
    f = all_g_x(f_nums, label)

    def show_maps(true_map, false_map, titles, suptitle=None):
        # Three side-by-side 28x28 panels: True, False, True - False.
        fig = plt.figure(figsize=(20, 20))
        if suptitle is not None:
            fig.suptitle(suptitle)
        panels = (true_map, false_map, true_map - false_map)
        for pos, (img, title) in enumerate(zip(panels, titles), start=1):
            ax = fig.add_subplot(1, 3, pos)
            ax.set_title(title)
            ax.imshow(img.reshape(28, 28))
        plt.show()

    show_maps(t, f, ("True sum of grads", "False sum of grads", "True- False sum of grads"))

    # Keep only each group's top_k pixels by |gradient sum|; argsort is computed once
    # instead of once per selected pixel.
    t_top = np.argsort(np.abs(t))[::-1][:top_k]
    f_top = np.argsort(np.abs(f))[::-1][:top_k]

    t_img = np.zeros_like(t)
    f_img = np.zeros_like(f)
    t_img[t_top] = t[t_top]
    # Bug fix: the original assigned f[t_i], i.e. False values read at the *True* group's
    # pixel indices. Each sparse map now carries its own group's values.
    f_img[f_top] = f[f_top]

    f_img = f_img.clip(min=-1, max=1)

    show_maps(
        t_img,
        f_img,
        ("True sum of grads", "False sum of grads", "True - False sum of grads"),
        suptitle=f"top {top_k} pixels per group",
    )

    return t_img, f_img
    

In [None]:
def compare_c_x(t_nums, f_nums, label, top_k=50):
    """Compare per-pixel non-zero-gradient counts of correct vs. wrong samples.

    Shows three 28x28 maps built from ``all_c_x``: True counts, False counts and
    their difference. Then shows the same maps restricted to each group's ``top_k``
    highest-count pixels.

    Args:
        t_nums: ``x_test`` indices of correctly classified samples.
        f_nums: ``x_test`` indices of misclassified samples.
        label: target class used for every gradient computation.
        top_k: number of pixels kept per group in the sparse maps (default 50).

    Returns:
        ``(t_img, f_img)``: length-784 arrays that are zero except at each group's
        top-k pixels, where they hold that group's own count.
    """
    t = all_c_x(t_nums, label)
    f = all_c_x(f_nums, label)

    def show_maps(true_map, false_map, titles, suptitle=None):
        # Three side-by-side 28x28 panels: True, False, True - False.
        fig = plt.figure(figsize=(20, 20))
        if suptitle is not None:
            fig.suptitle(suptitle)
        panels = (true_map, false_map, true_map - false_map)
        for pos, (img, title) in enumerate(zip(panels, titles), start=1):
            ax = fig.add_subplot(1, 3, pos)
            ax.set_title(title)
            ax.imshow(img.reshape(28, 28))
        plt.show()

    show_maps(t, f, ("True counts of grads", "False counts of grads", "True- False counts of grads"))

    # Keep only each group's top_k pixels by count; argsort is computed once instead of
    # once per selected pixel.
    t_top = np.argsort(t)[::-1][:top_k]
    f_top = np.argsort(f)[::-1][:top_k]

    t_img = np.zeros_like(t)
    f_img = np.zeros_like(f)
    t_img[t_top] = t[t_top]
    # Bug fix: the original assigned f[t_i], i.e. False counts read at the *True* group's
    # pixel indices. Each sparse map now carries its own group's counts.
    f_img[f_top] = f[f_top]

    show_maps(
        t_img,
        f_img,
        ("True counts of grads", "False counts of grads", "True - False counts of grads"),
        suptitle=f"top {top_k} pixels per group",
    )

    return t_img, f_img
        
    
    

In [None]:
def check_miss(idxs, adv, label):
    """Subtract a perturbation from test images and return the misclassification rate.

    ``adv`` is subtracted from every ``x_test[i]`` (i in ``idxs``), and the result is
    clipped at 0 from below. Each image is drawn in a grid titled with the network's
    prediction and counted as a miss when that prediction differs from ``label``.

    Args:
        idxs: ``x_test`` indices of the images to perturb.
        adv: perturbation subtracted from every image (broadcast against each row).
        label: class every image is expected to keep.

    Returns:
        Fraction of perturbed images not predicted as ``label``; 0.0 if ``idxs`` is empty.
    """
    if len(idxs) == 0:
        return 0.0  # nothing to evaluate; the original raised ZeroDivisionError

    # Fancy indexing gathers the rows directly (replaces np.array(list(map(lambda ...)))).
    x_batch = (x_test[list(idxs)] - adv).clip(min=0)

    # Five columns as before; grow the rows past the original fixed 5x5 grid, which
    # raised for more than 25 images.
    n_cols = 5
    n_rows = max(5, (len(x_batch) + n_cols - 1) // n_cols)
    fig = plt.figure(figsize=(15, 3 * n_rows))

    miss = 0
    for pos, x in enumerate(x_batch, start=1):
        ax = fig.add_subplot(n_rows, n_cols, pos)
        ax.imshow(x.reshape(28, 28), 'gray')
        pre_label = np.argmax(network.predict(x))
        ax.set_title(pre_label)

        if pre_label != label:
            miss += 1

    plt.show()

    return miss / len(x_batch)

In [None]:
def check_miss_g(idxs, adv, label):
    """Add a perturbation to test images and return the misclassification rate.

    ``adv`` is added to every ``x_test[i]`` (i in ``idxs``); no clipping is applied.
    Each image is drawn in a grid titled with the network's prediction and counted
    as a miss when that prediction differs from ``label``.

    Args:
        idxs: ``x_test`` indices of the images to perturb.
        adv: perturbation added to every image (broadcast against each row).
        label: class every image is expected to keep.

    Returns:
        Fraction of perturbed images not predicted as ``label``; 0.0 if ``idxs`` is empty.
    """
    if len(idxs) == 0:
        return 0.0  # nothing to evaluate; the original raised ZeroDivisionError

    # Fancy indexing gathers the rows directly (replaces np.array(list(map(lambda ...)))).
    x_batch = x_test[list(idxs)] + adv

    # Five columns as before; grow the rows past the original fixed 5x5 grid, which
    # raised for more than 25 images.
    n_cols = 5
    n_rows = max(5, (len(x_batch) + n_cols - 1) // n_cols)
    fig = plt.figure(figsize=(15, 3 * n_rows))

    miss = 0
    for pos, x in enumerate(x_batch, start=1):
        ax = fig.add_subplot(n_rows, n_cols, pos)
        ax.imshow(x.reshape(28, 28), 'gray')
        pre_label = np.argmax(network.predict(x))
        ax.set_title(pre_label)

        if pre_label != label:
            miss += 1

    plt.show()

    return miss / len(x_batch)

In [None]:
# Per digit: build perturbations from the input-pixel gradient maps of misclassified
# samples and test whether they flip the predictions of correctly classified samples.
# The Japanese section headers printed below read "sum of gradients w.r.t. the inputs"
# and "count of involvement per input pixel".


def print_prediction(x):
    """Print the predicted class and its softmax confidence for one input vector."""
    scores = network.predict(x)
    print(np.argmax(scores))
    print(np.max(softmax(scores)))


def show_pair(original, perturbed, perturbed_title, cmap=None):
    """Show an unperturbed test image next to its perturbed version."""
    fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(20, 20))
    ax_left.set_title("test x")
    ax_left.imshow(original.reshape(28, 28), cmap=cmap)
    ax_right.set_title(perturbed_title)
    ax_right.imshow(perturbed.reshape(28, 28), cmap=cmap)
    plt.show()


for i in range(10):
    print(i)
    true_nums = true_list[i][:10]
    false_nums = false_list[i][:10]
    if not true_nums:
        # Guard: the experiments below perturb true_nums[0].
        print("no correctly classified samples; skipped")
        continue
    x_sample = x_test[true_nums[0]]

    print("入力値への勾配の合計")

    t_img_g, f_img_g = compare_g_x(true_nums, false_nums, i)

    print_prediction(t_img_g)
    print_prediction(f_img_g)
    print_prediction(t_img_g - f_img_g)

    # Add the misclassified group's top gradient pixels to the first correct sample.
    # The original left a debug print of a 784-element comparison here, followed by
    # exit(), which shuts down the kernel and made the rest of the loop unreachable.
    adv1 = (x_sample + f_img_g).clip(min=0, max=1)

    show_pair(x_sample, adv1, "test + false_g")

    print("test x")
    print_prediction(x_sample)

    print("test + false_g")
    print_prediction(adv1)

    # Apply the bare perturbation to every selected sample. The original passed adv1,
    # an already-perturbed full image, which added a whole extra digit to each sample.
    # This matches the check_miss call below, which receives the bare f_img_c.
    print(check_miss_g(true_nums, f_img_g, i))

    print("--------------------------")

    print("関わった入力値への回数")

    t_img_c, f_img_c = compare_c_x(true_nums, false_nums, i)

    # Binary mask of the misclassified group's most frequently involved pixels.
    f_img_c = (f_img_c != 0)

    print_prediction(t_img_c)
    print_prediction(f_img_c)
    print_prediction(t_img_c - f_img_c)

    adv = (x_sample - f_img_c).clip(min=0)

    show_pair(x_sample, adv, "test - false_c", cmap='gray')

    print("test x")
    print_prediction(x_sample)

    print("test - false_c")
    print_prediction(adv)

    print("////////////")
    print("other nums")

    print(check_miss(true_nums, f_img_c, i))

    print("#############################")
    
    