In [None]:
import os
import cv2
import glob
import shap
import numpy as np

from time import time
from scipy.io import loadmat
from sklearn.utils import shuffle
from skimage.transform import resize
from matplotlib import pyplot as plt
from itertools import combinations, permutations

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.utils import multi_gpu_model
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.mixed_precision import experimental as mixed_precision
from tensorflow.keras.callbacks import LearningRateScheduler, ModelCheckpoint

# Expose only GPU 0 to this process. The variable has to be in place before
# TensorFlow first initializes its GPU devices, so keep this line ahead of any
# cell that touches a device.
os.environ["CUDA_VISIBLE_DEVICES"]= "0"
# Only let TensorFlow log messages of severity ERROR through.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

In [None]:
# distributed GPU
# The strategy's scope is used further down when the model is built and compiled.
# NOTE(review): CUDA_VISIBLE_DEVICES is set to "0" in the import cell, so
# presumably a single replica is reported here - confirm that is intended.
strategy = tf.distribute.MirroredStrategy()
print("Num devices: %d" % strategy.num_replicas_in_sync)

In [None]:
# mixed precision
# Install "mixed_float16" as the global Keras policy; the model cell overrides
# the dtype of its final softmax layer to float32 explicitly.
# NOTE(review): this goes through the `experimental` mixed-precision module
# imported at the top; newer TensorFlow releases presumably need the
# non-experimental API instead - confirm against the installed version.
policy = mixed_precision.Policy("mixed_float16")
mixed_precision.set_policy(policy)

In [None]:
# data parameters
# Number of target classes: used for the one-hot encoding, the label
# smoothing, and the width of the classifier's output layer.
num_classes = 10

In [None]:
# the data, split between train and test sets
# Loads CIFAR-10 through Keras (downloaded on first use - presumably cached
# afterwards). x_* hold the images, y_* the class labels.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

print("x_train shape :", x_train.shape)
print("x_test shape :", x_test.shape)

In [None]:
# convert class vectors to binary class matrices
# One-hot encode the labels into `num_classes` columns.
# NOTE: this cell overwrites y_train / y_test in place, so it must run exactly
# once after the data is loaded; re-running it encodes the one-hot matrix again.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

print("y_train shape :", y_train.shape)
print("y_test shape :", y_test.shape)

In [None]:
# label smoothing
# Move `factor` of the probability mass from the true class to a uniform
# distribution over all classes:
#   true class  -> (1 - factor) + factor / num_classes
#   other class -> factor / num_classes
# Only the training targets are smoothed; y_test stays one-hot.
factor = 0.1

# Rebuild the hard one-hot targets from the current labels before smoothing, so
# re-running this cell does not smooth already-smoothed targets a second time.
# argmax recovers the true class from one-hot and from smoothed targets alike
# (the true class keeps the largest value as long as factor < 1).
y_train_onehot = keras.utils.to_categorical(np.argmax(y_train, axis=1), num_classes)
y_train = (1 - factor) * y_train_onehot + (factor / num_classes)

print("y_train shape :", y_train.shape)

In [None]:
# Preview a 3x3 grid of randomly drawn training images (one draw per panel).
fig, axes = plt.subplots(3, 3, figsize=(5, 5))
for ax in axes.flat:
    idx = np.random.randint(x_train.shape[0])
    ax.imshow(x_train[idx], interpolation="quadric")
plt.tight_layout()
plt.show()

In [None]:
# normalize
# Scale pixel values by 1/255 and store them as float32.
# The dtype guard keeps the cell idempotent: the freshly loaded arrays are
# integer-typed (presumably uint8), so an array that is already floating point
# was scaled by an earlier run and is left untouched. Without the guard a
# re-run would silently divide by 255 a second time.
if not np.issubdtype(x_train.dtype, np.floating):
    x_train = x_train.astype("float32") / 255.0
if not np.issubdtype(x_test.dtype, np.floating):
    x_test = x_test.astype("float32") / 255.0

print("x_train shape :", x_train.shape)
print("x_test shape :", x_test.shape)

In [None]:
# ResNet50
# Build and compile the classifier inside the distribution strategy's scope:
# an ImageNet-initialised ResNet50 backbone (no top) followed by global average
# pooling, dropout and a `num_classes`-way softmax head.
with strategy.scope():
    resnet50 = ResNet50(include_top=False, weights="imagenet")

    # Input size follows the training images; they are resized to 224x224
    # before entering the backbone.
    # NOTE(review): the images reach the backbone scaled by 1/255 (see the
    # normalize cell) and not via resnet50.preprocess_input - confirm this
    # mismatch with the ImageNet weights is intended.
    inputs = layers.Input(shape=(x_train.shape[1], x_train.shape[2], 3))
    x = layers.experimental.preprocessing.Resizing(224, 224)(inputs)
    x = resnet50(x)

    x = layers.GlobalAveragePooling2D()(x)
    # training=True is passed at call time, so dropout is requested for every
    # call of the model, including predict/evaluate; outputs are then stochastic.
    # NOTE(review): confirm this is deliberate (e.g. Monte-Carlo dropout).
    x = layers.Dropout(0.5)(x, training=True)

    x = layers.Dense(num_classes)(x)
    # The softmax is pinned to float32 regardless of the global mixed_float16 policy.
    outputs = layers.Activation("softmax", dtype="float32")(x)

    model = keras.Model(inputs=inputs, outputs=outputs)
    # The 1e-3 learning rate matches init_lr in lr_decay, which overrides it per epoch.
    optimizer = keras.optimizers.SGD(learning_rate=1e-3, momentum=0.9)
    model.compile(loss="categorical_crossentropy", optimizer=optimizer, metrics=["accuracy"])

In [None]:
def lr_decay(epoch):
    """Step-decay learning-rate schedule.

    Starts at 1e-3 and multiplies the rate by 0.85 once for every 5 completed
    epochs, counting epochs from 1 (i.e. the first drop applies at epoch
    index 4). Prints a message on each epoch where a drop takes effect.

    Args:
        epoch: zero-based epoch index.

    Returns:
        The learning rate for that epoch (a NumPy float).
    """
    base_rate, decay_factor, step_size = 1e-3, 0.85, 5

    epoch_number = 1 + epoch
    completed_steps = np.floor(epoch_number / step_size)
    rate = base_rate * np.power(decay_factor, completed_steps)

    # announce the new rate only on the epochs where it actually changes
    if epoch_number % step_size == 0:
        print('learning rate is decayed to %f' % rate)

    return rate

In [None]:
# Training run: step-decayed learning rate plus best-only checkpointing.
epochs = 70
batch_size = 256

lr_scheduler = LearningRateScheduler(lr_decay)
file_format = "./model/model-{epoch:02d}-{val_accuracy:.2f}.hdf5"

# Make sure the checkpoint directory exists before training starts; depending
# on the Keras version the checkpoint callback may not create it, and the
# first save would then fail after a full epoch of training.
os.makedirs(os.path.dirname(file_format), exist_ok=True)

# Keep a checkpoint only when validation accuracy improves on the best so far.
checkpoint = ModelCheckpoint(filepath=file_format,
                             monitor="val_accuracy",
                             save_best_only=True,
                             mode="max")

# validation_split holds out 10% of the training arrays for validation.
history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs, validation_split=0.1,
                    callbacks=[lr_scheduler, checkpoint])

In [None]:
# Training curves: loss and accuracy per epoch for the train and validation splits.
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend(["Train", "Val"])
plt.show()

# The x axis is the epoch index and the y axis is accuracy (the labels were
# previously "Accuracy" / "Loss", which mislabelled both axes).
plt.plot(history.history["accuracy"])
plt.plot(history.history["val_accuracy"])
plt.title("Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend(["Train", "Val"])
plt.show()

In [None]:
# Replace the in-memory model with a saved checkpoint and score it on the test set.
# NOTE(review): the file name is hard-coded; it follows the checkpoint pattern
# of the training cell (epoch 21, val_accuracy 0.94), so it only exists if a
# run produced exactly that checkpoint - update it after retraining.
model = load_model("./model/model-21-0.94.hdf5")
# score is [loss, accuracy], matching the loss and metrics given to compile().
score = model.evaluate(x_test, y_test, verbose=0)

print("Test loss :", score[0])
print("Test accuracy :", score[1])

### Segmentation

In [None]:
from skimage.segmentation import felzenszwalb
from skimage.segmentation import mark_boundaries

In [None]:
# Segmentation demo on a single test image: run Felzenszwalb segmentation,
# report how many distinct segment labels it produced, and draw the segment
# boundaries on top of the image.
img = x_test[29]

segments_fz = felzenszwalb(img, scale=100, sigma=0.5, min_size=50)
print("Felzenszwalb number of segments:", len(np.unique(segments_fz)))

plt.imshow(mark_boundaries(img, segments_fz))
plt.show()

### SHAP for Target

In [None]:
# Human-readable class names, indexed by class id; used to turn an argmax over
# the label / prediction vectors into a name for the figure titles.
idx_to_label = ["Airplane", "Automobile", "Bird", "Cat", 
                "Deer", "Dog", "Frog", "Horse", "Ship", "Truck"]

In [None]:
# Number of sampling iterations. The loop runs i = 0 .. repeat-1 and reports
# whenever i is a positive multiple of 1000, so 3001 makes i = 3000 the last report.
repeat = 3001

# Explain the first `idx` test images (with their one-hot labels).
idx = 256
data = x_test[:idx]
label = y_test[:idx]
# Scalar fill value for masked segments: the mean over every pixel and channel
# of the first 1000 training images.
mask_value = np.mean(x_train[:1000])

# Per-pixel accumulators (data's shape without the channel axis): `importance`
# sums the sampled prediction differences, `count` how often each pixel was sampled.
# Re-run this cell before re-running the sampling loop, which updates them in place.
count = np.zeros(data.shape[:-1])
importance = np.zeros(data.shape[:-1])

# batch_segment draws the number of unmasked segments from a set with 2 * level
# entries, using level_set_prob as the sampling probabilities - so
# level_set_prob must have exactly 2 * level entries that sum to 1.
level = 1
level_set_prob = [0.5] * 2

In [None]:
def batch_segment(img, level):
    """Create one pair of masked batches for a marginal-contribution sample.

    For every image in the batch: segment it with Felzenszwalb at a randomly
    chosen scale, randomly pick a set of segments to mask, and pick one of
    those masked segments as the feature under evaluation.

    Reads the module-level `level_set_prob` (sampling probabilities for the
    number of unmasked segments; must have 2 * level entries) and `mask_value`
    (fill value for masked pixels).

    Args:
        img: batch of images, indexed as img[i] per image.
        level: how many values at each end of the range 0 .. n_segments - 1
            are candidates for the number of segments left unmasked.

    Returns:
        (img_with_f, img_without_f): copies of `img` where the chosen segments
        are filled with `mask_value`. In img_without_f every chosen segment is
        masked; in img_with_f the evaluated segment is left visible.
    """
    with_feature = np.copy(img)
    without_feature = np.copy(img)

    offsets = np.arange(level)
    candidate_scales = [50, 100, 150, 250, 500, 1200]

    for sample_idx in range(img.shape[0]):
        # segment at a randomly chosen granularity
        chosen_scale = np.random.choice(candidate_scales, 1)[0]
        segments = felzenszwalb(img[sample_idx], scale=chosen_scale, sigma=0.5, min_size=20)

        # how many segments stay unmasked: one of the `level` smallest or
        # `level` largest admissible counts
        n_segments = np.unique(segments).shape[0]
        allowed_counts = np.concatenate([offsets, n_segments - offsets - 1])
        n_unmasked = np.random.choice(allowed_counts, 1, p=level_set_prob)[0]

        # segments to mask (drawn without replacement) and the evaluated one among them
        masked_ids = np.random.choice(np.arange(n_segments), replace=False, size=n_segments - n_unmasked)
        target_id = np.random.choice(masked_ids, 1)[0]

        # apply the fill value through boolean pixel masks
        in_masked = np.isin(segments, masked_ids)
        without_feature[sample_idx][in_masked] = mask_value
        with_feature[sample_idx][in_masked & (segments != target_id)] = mask_value

    return with_feature, without_feature

In [None]:
# determine m
# Monte-Carlo estimate of per-pixel importance. Each iteration masks random
# segments of every image, measures how much the model's top softmax score
# changes when one evaluated segment is revealed, and adds that difference to
# all pixels of the segment. `importance` / `count` (from the setup cell) are
# updated in place, so re-run the setup cell before re-running this one.

# plt.savefig does not create missing directories
os.makedirs("./images", exist_ok=True)

start_time = time()

for i in range(0, repeat):
    data_with_f, data_without_f = batch_segment(data, level)

    # get softmax value
    # NOTE(review): np.max takes the top score of each prediction separately, so
    # the two scores can belong to different classes. If the intent is the
    # contribution to one fixed target class, index both predictions with that
    # class instead - confirm.
    pred_with_f = np.max(model.predict(data_with_f), axis=1)
    pred_without_f = np.max(model.predict(data_without_f), axis=1)

    # marginal Shapley value
    diff = pred_with_f - pred_without_f

    # Pixels of the evaluated segment are exactly those where the two batches
    # differ. Compare every channel: a pixel whose first channel happens to
    # equal mask_value would be missed by a channel-0-only comparison.
    f_idx = np.where(np.any(data_with_f != data_without_f, axis=-1))

    # get importance
    # f_idx holds unique (image, row, col) positions, so in-place fancy-index
    # accumulation is safe; f_idx[0] maps each pixel to its image's difference.
    importance[f_idx] += diff[f_idx[0]]
    count[f_idx] += 1

    if i % 1000 == 0 and i > 0: # verbose
        print("For %dth iter, it takes %0.3f sec" % (i, time()-start_time))

        # Mean contribution per pixel; pixels never sampled so far stay NaN
        # (same values as a plain 0/0 division, without the runtime warning).
        heatmaps = np.divide(importance, count,
                             out=np.full_like(importance, np.nan), where=count > 0)
        # one batched prediction for the titles instead of one call per image
        preds = model.predict(data)

        for data_num in range(0, data.shape[0]):
            img_name = "./images/fig_segment_iter" + str(i) + "_data" + str(data_num) + ".png"

            true_obj = idx_to_label[int(np.argmax(label[data_num]))]
            pred_obj = idx_to_label[int(np.argmax(preds[data_num]))]

            fig, (ax1, ax2) = plt.subplots(1, 2)
            ax1.imshow(data[data_num], interpolation="quadric")
            ax1.set_title("True: %s, Pred: %s" % (true_obj, pred_obj))

            im2 = ax2.imshow(heatmaps[data_num], cmap="coolwarm")
            fig.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04)
            fig.tight_layout()
            fig.savefig(img_name, dpi=300, bbox_inches="tight")
            plt.show()
            # release the figure; hundreds are created per report
            plt.close(fig)

        print("\n")
        start_time = time()

# Final per-pixel mean contribution (NaN where a pixel was never sampled).
importance = np.divide(importance, count,
                       out=np.full_like(importance, np.nan), where=count > 0)