
VIBI Experiments

This is a re-implementation of VIBI (arXiv, GitHub), including experiments for MNIST and CIFAR10, written in Python and PyTorch.

To run the experiments, first clone this repository and install the requirements.

git clone https://github.com/willisk/VIBI
cd VIBI
pip install -r requirements.txt

Run all experiments shown in the results below:

chmod +x run_experiments.sh
./run_experiments.sh

Alternatively, run the training script directly with the desired arguments (an example invocation follows the list of options).

python train.py

optional arguments
  --dataset {MNIST,CIFAR10}
  --cuda                Enable CUDA.
  --num_epochs NUM_EPOCHS
                        Number of training epochs for VIBI.
  --explainer_type {Unet,ResNet_2x,ResNet_4x,ResNet_8x}
  --xpl_channels {1,3}
  --k K                 Number of chunks.
  --beta BETA           beta in objective J = I(y,t) - beta * I(x,t).
  --num_samples NUM_SAMPLES
                        Number of samples used for estimating expectation over p(t|x).
  --resume_training     Resume training VIBI from the last saved checkpoint.
  --save_best           Save only the best model (measured by validation accuracy).
  --save_images_every_epoch
                        Save explanation images every epoch.
  --jump_start          Use pretrained model with beta=0 as starting point.
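
For example, a typical invocation matching the MNIST results below might look like this (the argument values are illustrative, not prescriptive):

python train.py --dataset MNIST --cuda --explainer_type ResNet_4x --k 4 --beta 0.01 --num_epochs 20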

VIBI Overview

The goal is to create interpretable explanations for black-box models. This is achieved with two neural networks: the explainer and the approximator. Given an input image, the explainer network produces a probability distribution over the input chunks. A relaxed k-hot vector is sampled from this distribution and used to mask the input, which is then fed into the approximator network. The approximator network aims to match the probability distribution of the black-box model's output.

The approach builds heavily on L2X (Learning to Explain). The main difference is that VIBI's additional information term effectively increases the entropy of the distribution p(z), whereas L2X only minimizes the cross-entropy H(p, q) between the black-box model's predictions and the approximator's output.
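
The sketch below illustrates this pipeline under a few assumptions: a Gumbel-softmax relaxation for the k-hot sampling, a square chunk grid, and a uniform-prior KL term as a surrogate for I(x, t). The names (sample_relaxed_k_hot, VIBIPipeline, vibi_loss) are hypothetical and do not correspond to the classes in this repository.

import torch
import torch.nn as nn
import torch.nn.functional as F

def sample_relaxed_k_hot(logits, k, tau=0.5):
    # Draw k independent Gumbel-softmax samples over the chunk logits and
    # combine them with an elementwise maximum (as in L2X) to obtain a
    # continuous relaxation of a k-hot selection vector.
    samples = []
    for _ in range(k):
        gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10) + 1e-10)
        samples.append(F.softmax((logits + gumbel) / tau, dim=-1))
    return torch.stack(samples, dim=0).max(dim=0).values

class VIBIPipeline(nn.Module):
    # Wires an explainer (image -> chunk logits) to an approximator
    # (masked image -> class logits). The chunk layout is assumed to be a
    # square grid of chunks_per_side x chunks_per_side patches.
    def __init__(self, explainer, approximator, k=4, chunks_per_side=4):
        super().__init__()
        self.explainer = explainer
        self.approximator = approximator
        self.k = k
        self.chunks_per_side = chunks_per_side

    def forward(self, x):
        chunk_logits = self.explainer(x)                   # (B, num_chunks)
        mask = sample_relaxed_k_hot(chunk_logits, self.k)  # (B, num_chunks)
        # Reshape the chunk mask into a grid and upsample it to pixel size.
        side = self.chunks_per_side
        mask_img = mask.view(-1, 1, side, side)
        mask_img = F.interpolate(mask_img, size=x.shape[-2:], mode='nearest')
        approx_logits = self.approximator(x * mask_img)
        return approx_logits, chunk_logits

def vibi_loss(approx_logits, blackbox_probs, chunk_logits, beta):
    # Sufficiency term: cross-entropy H(p, q) between the black-box
    # distribution p and the approximator's prediction q.
    ce = -(blackbox_probs * F.log_softmax(approx_logits, dim=-1)).sum(-1).mean()
    # Compression term: KL between the explainer's chunk distribution and a
    # uniform prior, one common variational surrogate for I(x, t)
    # (an assumption here; the paper derives its own bound).
    log_p = F.log_softmax(chunk_logits, dim=-1)
    num_chunks = chunk_logits.shape[-1]
    kl = (log_p.exp() * (log_p + torch.log(torch.tensor(float(num_chunks))))).sum(-1).mean()
    return ce + beta * kl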

MNIST Example Results

[Figures, left to right: test batch, explanation distribution, top-k explanation (MNIST_test_batch, MNIST_best_distribution, MNIST_best_top_k)]

Using explainer_type=ResNet_4x, k=4, beta=0.01.

CIFAR10 Example Results

[Figures, left to right: test batch, explanation distribution, top-k explanation (CIFAR10_test_batch_32, CIFAR10_test_distribution_32, CIFAR10_test_top_k_32)]

Using explainer_type=Unet, k=64, beta=0.001.

Green boxes indicate that the black-box model's prediction is correct; red boxes indicate incorrect predictions. The intensity of the outline color, computed as 1 - JS(p, q), shows how closely the approximator's top-k prediction matches the black-box model's output.
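
As a reference for how such a match score can be computed, here is a small sketch of a base-2 Jensen-Shannon divergence (bounded in [0, 1]); the repository's actual implementation may differ.

import torch

def js_divergence(p, q, eps=1e-10):
    # Base-2 Jensen-Shannon divergence between probability vectors p and q,
    # so the result lies in [0, 1] and 1 - JS(p, q) is a natural match score.
    m = 0.5 * (p + q)
    kl_pm = (p * ((p + eps) / (m + eps)).log2()).sum(-1)
    kl_qm = (q * ((q + eps) / (m + eps)).log2()).sum(-1)
    return 0.5 * (kl_pm + kl_qm)

# Outline intensity for a sample: 1 - JS(p, q), where p is the black-box
# output distribution and q the approximator's top-k prediction.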
