# Peering inside Histogram Bins: Histograms with Examples via Spark and Bokeh

<img src="https://github.com/pwais/oarphpy/blob/master/oarphpy_test/fixtures/test_histogram_with_examples_2_demo_click.png?raw=true" alt="hist-with-examples" style="width: 500px;"/>

A histogram is one of the most effective tools for exploring a new dataset.  In one graph, a histogram displays key information about the data's mean, variance, outliers, and periodic features.  Histograms are so important that several libraries make histogramming extremely easy:
 * In [Pandas](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.hist.html), the user can immediately plot a histogram from a DataFrame in a Jupyter notebook with a single function call.
 * [Tensorboard](https://github.com/tensorflow/tensorboard/blob/master/docs/r1/histograms.md) features a powerful temporal histogramming tool that can be critical for monitoring the weights of neural networks during training and debugging high-dimensional optimization problems.
 * [Bokeh](https://demo.bokeh.org/selection_histogram) provides a web-based histogram plotter with interactive tools and a nice Python API.

Histogram plots are often both surprising and boring: some bins have more items than expected, some bins have fewer, and a lot of bins are empty.  It's not unreasonable to immediately want to ask: can we peer inside a bin?  Which examples from my dataset are actually in there?  And, since this histogram only shows one dimension of the data, what might be some other dimensions or factors that are common among most of the things in that bin?

This tutorial will show you how OarphPy's `HistogramWithExamplesPlotter` helps you do exactly that!  What do you need?
 1. A DataFrame (Pandas or Spark) with at least one numeric or categorical column.
 2. A Python function for visualizing a row (or some portion of a row).  For example, a function to convert a row to a pretty string or HTML visualization.
 3. A Jupyter Notebook (like this one!) or a Python script to render the Bokeh HTML plot and display and/or save it to disk.
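As a rough illustration of item 2, here is a minimal sketch of a row-visualization function.  The helper name `row_to_html` is hypothetical and is not part of OarphPy; it assumes the row is dict-like (a Spark `Row` can be converted with `row.asDict()`, and a Pandas row with `to_dict()`).

```python
import html

def row_to_html(row):
    """Render a dict-like row as a small two-column HTML table.

    Hypothetical helper for illustration; not part of OarphPy.
    """
    cells = "".join(
        "<tr><th>%s</th><td>%s</td></tr>" % (
            html.escape(str(key)), html.escape(str(value)))
        for key, value in row.items()
    )
    return "<table>%s</table>" % cells

# Values are escaped so that stray markup in the data cannot break the page.
example_row = {"score_7": 0.93, "label": "<seven>"}
print(row_to_html(example_row))
```

Any function that maps a row to a string or HTML snippet would serve the same purpose; later in this tutorial the visualization is an image tag built from the pixel data.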

Why `HistogramWithExamplesPlotter`?
 * We'll use Spark to compute the actual histogram.  Spark provides multi-CPU (and even multi-machine) processing, so histogramming can scale to datasets too large for a single process.
 * We'll also use Spark to render visualizations for the bucket items.  Spark's RDD API helps accommodate arbitrary user visualization functions and runs computation in parallel (even across many machines).
 * We use Bokeh's Histogram tool because it supports the simple interactivity we need (click on a bucket to view examples) and plots work in any modern browser with no extra dependencies.
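To make the idea concrete, here is a small single-machine sketch of a "histogram with examples": it bins a column of values and remembers a few row indices per bin.  This is only an illustration using NumPy, not OarphPy's Spark-based implementation, and the function name and parameters are assumptions.

```python
import numpy as np

def histogram_with_examples(values, num_bins=20, max_examples=3):
    """Return (bin_edges, counts, examples); examples maps bin id -> row indices.

    Illustrative sketch only; not how OarphPy computes its histogram.
    """
    values = np.asarray(values, dtype=float)
    edges = np.linspace(values.min(), values.max(), num_bins + 1)
    # np.digitize is 1-based and places the maximum value past the last edge,
    # so shift to 0-based ids and clip the maximum into the last bin.
    bin_ids = np.clip(np.digitize(values, edges) - 1, 0, num_bins - 1)
    counts = np.bincount(bin_ids, minlength=num_bins)
    examples = {}
    for row_idx, bin_id in enumerate(bin_ids):
        bucket = examples.setdefault(int(bin_id), [])
        if len(bucket) < max_examples:
            bucket.append(row_idx)
    return edges, counts, examples

edges, counts, examples = histogram_with_examples(
    [0.0, 0.1, 0.1, 0.9, 1.0], num_bins=2)
print(counts.tolist(), examples)
```

The row indices per bin are what make "click on a bucket to view examples" possible: each bucket carries pointers back to the rows that landed in it.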


## A Motivating Example: Exploring Out-of-Distribution Robustness in MNIST

<img src="https://miro.medium.com/max/3744/1*SGPGG7oeSvVlV5sOSQ2iZw.png" width="600" />

The [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset is well-studied in Computer Vision and consists of 70,000 small pictures of hand-written digits.  (New to MNIST? Suppose you're the Post Office and you want to train a Computer Vision model that can read the ZIP code digits that people write on their mail.  MNIST has a sample of such handwritten digits).  Today, it's easy to train a convolutional neural network on MNIST and achieve over 98% accuracy.  We're going to do exactly that in the next notebook cell!

But MNIST is a relatively small dataset versus all the digits people have ever written on paper.  How robust is a trained MNIST model to new data?  What if we don't have labels for that new data?  In this tutorial, we're going to use `HistogramWithExamplesPlotter` to examine the scores that an MNIST-trained model gives to "corrupted" data never seen at training time.

First, let's train a basic MNIST model using Keras:

In [None]:
# Basic MNIST ConvNet c/o Keras
# https://keras.io/examples/vision/mnist_convnet/

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation="softmax"),
    ]
)

model.summary()

batch_size = 128
epochs = 5

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)

score = model.evaluate(x_test, y_test, verbose=0)
print("Test loss:", score[0])
print("Test accuracy:", score[1])

Now, let's run the trained model on the test set and pack the predictions into a Pandas DataFrame.  For each input, the network outputs a score for each class (the digits 0 through 9).  We'll take a look at the network's scores for the class "7", which is easy to confuse with a "1".
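Because the model's final layer uses a softmax activation, the ten scores for each image are non-negative and sum to 1.  The sketch below shows that computation in plain NumPy on made-up logits (the input values are invented for illustration and do not come from the trained model).

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    logits = np.asarray(logits, dtype=float)
    # Subtracting the row maximum avoids overflow without changing the result.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Made-up logits for one image that looks mostly like a 7 and a bit like a 1.
logits = np.array([[0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0]])
scores = softmax(logits)
score_7 = scores[0, 7]
score_1 = scores[0, 1]
print("score_7:", score_7, "score_1:", score_1)
```

A high `score_7` therefore necessarily means low scores for the other classes, which is worth keeping in mind when comparing histograms of different score columns.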

In [None]:
predictions = model.predict(x_test)
rows = []
for i, (x_i, pred_i) in enumerate(zip(x_test, predictions)):
    row = {}
    for classname, score in enumerate(pred_i):
        row[f'score_{classname}'] = score
    row['x_i'] = x_i.tolist()
    rows.append(row)

import pandas as pd
prediction_df = pd.DataFrame(rows)
prediction_df

Now let's histogram the network's scores for the 7 class:

In [None]:
%matplotlib inline 
prediction_df['score_7'].hist()

In [None]:
# Fetch the MNIST-C corruption code (skip the clone if it is already present)
!cd /opt && (git clone https://github.com/google-research/mnist-c || echo "have mnist-c")
import sys
sys.path.append('/opt/mnist-c')

In [None]:
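# Install the `wand` package and make mnist-c's pessimal_noise_matrix file
# available in the current working directory
!pip3 install wand
!ln -s /opt/mnist-c/pessimal_noise_matrix ./pessimal_noise_matrix || echo "symlink placed"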

In [None]:
import corruptions

# Try a corruption on a single test image.  Alternative:
# x_corrupted = corruptions.speckle_noise(x_test[100] * 255)
x_corrupted = corruptions.glass_blur(x_test[100] * 255, severity=4)
x_corrupted.shape

In [None]:
from oarphpy.plotting import img_to_img_tag

img_html = img_to_img_tag(x_corrupted)

def show_html(html):
    from IPython.display import display, HTML
    display(HTML(html))
show_html(img_html)

In [None]:
!pip3 install tqdm
from tqdm import tqdm
x_test_c = np.zeros_like(x_test)
for i in tqdm(range(len(x_test_c))):
    # Alternative corruption: corruptions.glass_blur
    xform = corruptions.rotate
    # Apply the corruption at the [0, 255] pixel scale, then rescale to [0, 1]
    x_test_c[i] = (1. / 255) * xform(x_test[i] * 255, severity=4)


In [None]:
predictions_c = model.predict(x_test_c)
rows = []
for i, (x_i, pred_i) in enumerate(zip(x_test_c, predictions_c)):
    row = {}
    for classname, score in enumerate(pred_i):
        row[f'score_{classname}'] = score
    row['x_i'] = x_i.tolist()
    rows.append(row)


import pandas as pd
prediction_c_df = pd.DataFrame(rows)
prediction_c_df

In [None]:
prediction_c_df['score_7'].hist()

In [None]:
from oarphpy.spark import NBSpark

spark = NBSpark.getOrCreate()


In [None]:
from bokeh.plotting import figure
from bokeh.io import output_notebook
from bokeh.io import show as bokeh_show
output_notebook()


In [None]:
from oarphpy import plotting as pl
class MyPlotter(pl.HistogramWithExamplesPlotter):
    NUM_BINS = 20
    def display_bucket(self, sub_pivot, bucket_id, irows):
        MAX_TO_VIZ = 150
        
        from oarphpy.plotting import img_to_img_tag
        htmls = []
        for row in irows:
            x_i = np.array(row['x_i'])
            img_html = img_to_img_tag(x_i)
            htmls.append(img_html)
            
            if len(htmls) >= MAX_TO_VIZ:
                break
        
        # Make a nice table
        N_COLS = 25
        from oarphpy.util import ichunked
        trs = [
            "<tr>%s</tr>" % ''.join("<td>%s</td>" % ihtml for ihtml in row)
            for row in ichunked(htmls, n=N_COLS) 
        ]
        table_html = "<table>%s</table>" % ''.join(trs)
        
        return bucket_id, table_html

plotter = MyPlotter()
fig = plotter.run(spark.createDataFrame(prediction_c_df), 'score_7')
bokeh_show(fig)

In [None]:
fig = plotter.run(spark.createDataFrame(prediction_c_df), 'score_1')
bokeh_show(fig)

More examples of `HistogramWithExamplesPlotter` are available in [this Google Drive folder](https://drive.google.com/drive/folders/1dOmkPvdFiGBMaYEddx1KK5vmCeYl2CyV?usp=sharing).

In [None]:
fig = plotter.run(spark.createDataFrame(prediction_df), 'score_7')
bokeh_show(fig)