# Generative Causal Explanations of Black-Box Classifiers

## Setup

In [None]:
import IPython
from IPython.display import Image
import os

In [None]:
# Functions to display images

def show_image(fname):
    """Display the image stored in the file ``fname`` in the notebook output."""
    image = Image(filename=fname)
    IPython.display.display(image)

def show_all_images(directory):
    """Display every ``.png`` file located directly inside ``directory``.

    Files are shown in sorted filename order. ``os.listdir`` returns entries
    in arbitrary order, so without sorting the figures could appear in a
    different sequence from one run or machine to the next.

    Args:
        directory: Path of the folder to scan; sub-folders are not searched.
    """
    for name in sorted(os.listdir(directory)):
        if name.endswith(".png"):
            show_image(os.path.join(directory, name))

## Generating figures from pretrained models

In [None]:
from generate_figures import generate_figures

### MNIST 3/8

In [None]:
# Render all MNIST 3/8 figures as PNG files (saved under the figures/ folder).
generate_figures(implementation='MNIST_38', filetype='png')

# Show every PNG found in figures/MNIST_38 inline.
show_all_images('figures/MNIST_38')

In [None]:
# Display the PNGs already present in figures/MNIST_38 without regenerating them.
show_all_images('figures/MNIST_38')

### MNIST 1/4/9

In [None]:
# Render all MNIST 1/4/9 figures as PNG files (saved under the figures/ folder).
generate_figures(implementation='MNIST_149', filetype='png')

# Show every PNG found in figures/MNIST_149 inline.
show_all_images('figures/MNIST_149')

### FMNIST 0/3/4

In [None]:
# Render all FMNIST 0/3/4 figures as PNG files (saved under the figures/ folder).
generate_figures(implementation='FMNIST_034', filetype='png')

# Show every PNG found in figures/FMNIST_034 inline.
show_all_images('figures/FMNIST_034')

## Training models

NOTE: GCEs take a long time to train on a CPU. Also, by default, the GCE and visualization scripts always try to load the most recent model; this means that if training is started but cancelled before a checkpoint has been created, subsequent steps may break. Once training is done, the new "pretrained" model is used to create the figures.

### MNIST 3/8

#### Classifier

In [None]:
# Run the classifier training script for classes 3 and 8 with --max_epochs 20.
# NOTE(review): this script is given `--datasets` (plural) while the
# mnist_cvae_train.py cells pass `--dataset`; confirm each flag matches its
# script's argument parser.
%run mnist_classifier_train.py --classes 3 8 --max_epochs 20 \
                               --datasets traditional

#### GCE

In [None]:
# Run the GCE training script for classes 3 and 8 with --max_steps 8000.
# The remaining flags (batch size, learning rate, Nalpha, Nbeta, K, L, lamb)
# are passed to the script as-is; see mnist_cvae_train.py for their meaning.
%run mnist_cvae_train.py --classes 3 8  --max_steps 8000 \
                         --batch_size 64 --lr 5e-4 --Nalpha 100 --Nbeta 25 --K 1 --L 7 --lamb 0.05 \
                         --dataset traditional

#### Generate figures

In [None]:
# Regenerate the MNIST 3/8 figures as PNG files, then display them inline.
generate_figures(implementation='MNIST_38', filetype='png')
show_all_images('figures/MNIST_38')

### MNIST 1/4/9

#### Classifier

In [None]:
# Run the classifier training script for classes 1, 4 and 9 with --max_epochs 30.
%run mnist_classifier_train.py --classes 1 4 9 --max_epochs 30 \
                               --datasets traditional

#### GCE

In [None]:
# Run the GCE training script for classes 1, 4 and 9 with --max_steps 8000.
# The remaining flags (batch size, learning rate, Nalpha, Nbeta, K, L, lamb)
# are passed to the script as-is; see mnist_cvae_train.py for their meaning.
%run mnist_cvae_train.py --classes 1 4 9  --max_steps 8000 \
                         --batch_size 64 --lr 5e-4 --Nalpha 75 --Nbeta 25 --K 2 --L 2 --lamb 0.1 \
                         --dataset traditional

#### Generate figures

In [None]:
# Regenerate the MNIST 1/4/9 figures as PNG files, then display them inline.
generate_figures(implementation='MNIST_149', filetype='png')
show_all_images('figures/MNIST_149')

### FMNIST 0/3/4

#### Classifier

In [None]:
# Run the classifier training script on the `fashion` dataset for classes
# 0, 3 and 4 with --max_epochs 50 and --log_dir fmnist_cnn.
%run mnist_classifier_train.py --classes 0 3 4 --max_epochs 50 \
                               --datasets fashion --log_dir fmnist_cnn

#### GCE

In [None]:
# Run the GCE training script on the `fashion` dataset for classes 0, 3 and 4
# with --max_steps 8000 and --log_dir fmnist_gce.
# NOTE(review): --classifier_path is fmnist_cnn_034, whereas the FMNIST
# classifier cell passes --log_dir fmnist_cnn; presumably the classifier
# script appends the class digits to its log directory. Confirm.
%run mnist_cvae_train.py --classes 0 3 4  --max_steps 8000 \
                         --batch_size 32 --lr 1e-4 --Nalpha 100 --Nbeta 25 --K 2 --L 4 --lamb 0.05 \
                         --dataset fashion --log_dir fmnist_gce --classifier_path fmnist_cnn_034

#### Generate figures

In [None]:
# Regenerate the FMNIST 0/3/4 figures as PNG files, then display them inline.
generate_figures(implementation='FMNIST_034', filetype='png')
show_all_images('figures/FMNIST_034')