# Convolutional Filter Visualization

In [ ]:

import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential

from neural_feature_importance.conv_callbacks import ConvVarianceImportanceKeras
from conv_viz_utils import build_model, rank_filters, plot_filters, accuracy_with_filters

def load_mnist():
    """Load MNIST, scaled to [0, 1], with a channel axis and one-hot labels.

    Returns:
        A 4-tuple ``((x_train, y_train), (x_test, y_test), input_shape,
        kernel_size)``. The images are float32 with a trailing channel axis
        added. The labels are one-hot over 10 classes. ``input_shape`` is
        ``(28, 28, 1)`` and ``kernel_size`` is ``(8, 8)``.
    """
    train_split, test_split = mnist.load_data()

    def _prepare(split):
        # Scale raw 0..255 pixels into [0, 1], then append a channel axis.
        images, labels = split
        images = (images.astype("float32") / 255.0)[..., None]
        return images, to_categorical(labels, 10)

    return _prepare(train_split), _prepare(test_split), (28, 28, 1), (8, 8)

def load_digits_data():
    """Load scikit-learn's 8x8 digits dataset as a Keras-ready train/test split.

    The images are divided by 16.0 (the dataset's maximum pixel intensity)
    and cast to float32. A trailing channel axis is added, and the labels
    are one-hot over 10 classes. 20% of the samples are held out for
    testing, using a fixed ``random_state`` so the split is reproducible.

    Returns:
        A 4-tuple ``((x_train, y_train), (x_test, y_test), input_shape,
        kernel_size)``, where ``input_shape`` is ``(8, 8, 1)`` and
        ``kernel_size`` is ``(3, 3)``.
    """
    digits = load_digits()
    # Cast to float32 so the inputs match load_mnist(); dividing the integer
    # images by 16.0 would otherwise silently produce float64 arrays.
    x = (digits.images[..., None] / 16.0).astype("float32")
    y = to_categorical(digits.target, 10)
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=42
    )
    return (x_train, y_train), (x_test, y_test), (8, 8, 1), (3, 3)

# Registry of dataset loaders, keyed by the name printed when each one runs.
# Every loader returns
# ((x_train, y_train), (x_test, y_test), input_shape, kernel_size).
DATASETS = dict(
    mnist=load_mnist,
    digits=load_digits_data,
)

def run_dataset(loader, *, epochs=5, batch_size=32, top_ks=(2, 4, 6), verbose=1):
    """Train a model on one dataset, visualize its first layer's filters, and
    report accuracy when only the top-ranked filters are kept.

    Args:
        loader: Zero-argument callable returning ``((x_train, y_train),
            (x_test, y_test), input_shape, kernel_size)``, e.g. a value
            from ``DATASETS``.
        epochs: Number of training epochs passed to ``model.fit``.
        batch_size: Mini-batch size passed to ``model.fit``.
        top_ks: For each ``k``, the test accuracy is computed with only the
            first ``k`` filters of the ranking kept.
        verbose: Keras verbosity level for ``model.fit``.

    Returns:
        A dict mapping each ``k`` in ``top_ks`` to the value returned by
        ``accuracy_with_filters`` for that ``k``.
    """
    (x_train, y_train), (x_test, y_test), input_shape, kernel_size = loader()
    model = build_model(input_shape, kernel_size)
    callback = ConvVarianceImportanceKeras()
    model.fit(
        x_train,
        y_train,
        epochs=epochs,
        batch_size=batch_size,
        callbacks=[callback],
        verbose=verbose,
    )

    first_layer = model.layers[0]
    weights = first_layer.get_weights()[0]
    # Reshape the importance scores onto the first three kernel axes.
    # NOTE(review): this assumes build_model puts a Conv2D first, whose kernel
    # is (kh, kw, in_ch, n_filters), and that the callback's scores have
    # kh*kw*in_ch entries. TODO: confirm both.
    heatmap = callback.feature_importances_.reshape(weights.shape[:3])
    # The 0.0 third argument is passed through unchanged; its meaning is
    # defined by conv_viz_utils.rank_filters.
    order = rank_filters(weights, heatmap, 0.0)

    # A one-layer model that exposes the first layer's output for a single
    # test example, which is then plotted alongside the filters.
    conv_model = Sequential([first_layer])
    example_out = conv_model.predict(x_test[:1], verbose=0)[0]
    plot_filters(weights, heatmap, example_out, order)

    accuracies = {}
    for k in top_ks:
        acc = accuracy_with_filters(model, x_test, y_test, order[:k])
        print(f'Accuracy with top {k} filters:', acc)
        accuracies[k] = acc
    return accuracies

if __name__ == "__main__":
    # Guard the experiment so that importing this module (for example, to
    # reuse the loaders) does not start training. In a notebook cell,
    # __name__ is "__main__", so the loop still runs there as before.
    for name, loader in DATASETS.items():
        print('Running', name)
        run_dataset(loader)
