# Histograms with Examples
_Peering inside Histogram Bins with Spark and Bokeh_

<img src="https://github.com/pwais/oarphpy/blob/master/notebooks/hist-with-examples-diag.jpg?raw=true" alt="hist-with-examples" />

A histogram is one of the most effective tools for exploring a new dataset.  In one graph, a histogram displays key information about the data's mean, variance, outliers, and periodic features.  Histograms are so important that several libraries make histogramming extremely easy:
 * In [Pandas](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.hist.html), the user can immediately plot a histogram from a DataFrame in a Jupyter notebook with a single function call.
 * [Tensorboard](https://github.com/tensorflow/tensorboard/blob/master/docs/r1/histograms.md) features a powerful temporal histogramming tool that can be critical for monitoring the weights of neural networks during training and debugging high-dimensional optimization problems.
 * [Bokeh](https://demo.bokeh.org/selection_histogram) provides a web-based histogram plotter with interactive tools and a nice Python API.

Histogram plots are often both surprising and boring: some bins have more items than expected, some have fewer, and many bins are empty.  It's natural to want to ask right away: can we peer inside a bin?  Which examples from my dataset are actually in there?  And, since the histogram only shows one dimension of the data, what other dimensions or factors do most of the items in that bin have in common?
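For a single column that fits in memory, the question "which examples are in this bin?" comes down to assigning each example a bin index. A minimal NumPy sketch of that idea (independent of OarphPy; the helper name `examples_per_bin` and the sample scores are hypothetical):

```python
import numpy as np

def examples_per_bin(values, n_bins=10):
    """Return (counts, edges, members), where members[b] lists the indices
    of the examples that fall in bin b."""
    values = np.asarray(values, dtype=float)
    counts, edges = np.histogram(values, bins=n_bins)
    # Digitizing against the interior edges yields bin ids 0..n_bins-1 and
    # puts the maximum value in the last bin, matching np.histogram, whose
    # bins are half-open except the last (closed) one.
    bin_ids = np.digitize(values, edges[1:-1])
    members = [np.flatnonzero(bin_ids == b).tolist() for b in range(n_bins)]
    return counts, edges, members

scores = [0.01, 0.02, 0.03, 0.45, 0.97, 0.98, 0.99, 0.99]
counts, edges, members = examples_per_bin(scores, n_bins=4)
print(counts.tolist(), members)
```

Each entry of `members` holds indices back into the original rows, so you can look up any other column (an image, a label, metadata) for the items in a bin.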

This tutorial will show you how OarphPy's `HistogramWithExamplesPlotter` helps you do exactly that!  What do you need?
 1. A DataFrame (Pandas or Spark) with at least one numeric or categorical column.
 2. A Python function for visualizing a row (or some portion of a row).  For example, a function to convert a row to a pretty string or HTML visualization.
 3. A Jupyter Notebook (like this one!) or a Python script to render the Bokeh HTML plot and display and/or save it to disk.
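Requirement 2 can be as simple as a function that turns one row into an HTML snippet. A minimal sketch, assuming rows arrive as dict-like objects (Pandas rows and Spark `Row`s both support `row[column]` lookup); the function name `row_to_html` and the example columns are hypothetical:

```python
import html

def row_to_html(row, columns):
    """Render selected fields of a dict-like row as a two-column HTML table.
    Values are escaped so that text such as '<7>' displays literally."""
    cells = ''.join(
        '<tr><th>%s</th><td>%s</td></tr>' % (html.escape(str(c)), html.escape(str(row[c])))
        for c in columns)
    return '<table>%s</table>' % cells

print(row_to_html({'score_7': 0.93, 'label': '<7>'}, ['score_7', 'label']))
```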

Why `HistogramWithExamplesPlotter`?
 * We'll use Spark to compute the actual histogram.  Spark distributes the work across multiple CPUs (and even multiple machines), so histogramming scales to large datasets.
 * We'll also use Spark to render visualizations of the items in each bucket.  Spark's RDD API accommodates arbitrary user visualization functions and runs them in parallel (even across many machines).
 * We use Bokeh's Histogram tool because it supports the simple interactivity we need (click on a bucket to view examples) and plots work in any modern browser with no extra dependencies.
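The scaling argument rests on histogram counts being mergeable: each worker counts its own partition, and the partial counts add up to the global histogram. A pure-Python sketch of that principle (an illustration only, not Spark's implementation; partitions are simulated with plain lists and the bin edges are assumed known in advance):

```python
from bisect import bisect_right
from functools import reduce

def count_partition(values, edges):
    """Count one partition's values into bins defined by sorted `edges`.
    Bins are half-open [e_i, e_{i+1}) except the last, which includes its right edge."""
    counts = [0] * (len(edges) - 1)
    for v in values:
        if v < edges[0] or v > edges[-1]:
            continue  # out-of-range values are ignored
        i = min(bisect_right(edges, v) - 1, len(counts) - 1)
        counts[i] += 1
    return counts

def merge_counts(a, b):
    # Partial histograms over identical edges combine by element-wise addition
    return [x + y for x, y in zip(a, b)]

edges = [0.0, 0.25, 0.5, 0.75, 1.0]
partitions = [[0.1, 0.2, 0.9], [0.3, 1.0], [0.6, 0.05]]
partials = [count_partition(p, edges) for p in partitions]  # "map" on each worker
total = reduce(merge_counts, partials)                       # "reduce" on the driver
print(total)
```

Because merging is just addition, the work per worker depends only on the size of its partition. PySpark's `RDD.histogram(buckets)` performs this kind of distributed bucket counting, with the same convention that buckets are open on the right except the last.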



## Notebook Setup

To run this notebook locally, try using the `oarphpy/full` dockerized environment:

```
docker run -it --rm --net=host oarphpy/full:0.1.1 jupyter notebook --allow-root --ip="*"
```

If you can't run the notebook locally, find an HTML-rendered copy [here](https://drive.google.com/file/d/1-uWxGQ7mrcY8aZMmBDc5AlV4kQAPLntR/view?usp=sharing).

In [None]:
import os
import sys
if 'google.colab' in sys.modules:
    !pip install oarphpy[spark]==0.1.1
    !pip install pyspark==3.3.2
    !apt-get update && apt-get install -y openjdk-11-jdk
    os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-11-openjdk-amd64'

## A Motivating Example: Exploring Out-of-Distribution Robustness in MNIST

<img src="https://miro.medium.com/max/3744/1*SGPGG7oeSvVlV5sOSQ2iZw.png" width="600" />
<center>An MNIST example input fed into a LeNet Network</center>


The [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset is well studied in Computer Vision and consists of 70,000 small (28x28 grayscale) pictures of handwritten digits, split into 60,000 training and 10,000 test images.  (New to MNIST? Suppose you're the Post Office and you want to train a Computer Vision model that can read the ZIP code digits that people write on their mail.  MNIST is a sample of such handwritten digits.)  Today, it's easy to train a convolutional neural network on MNIST and achieve over 98% accuracy.  We're going to do exactly that in the next few notebook cells!

But MNIST is a relatively small dataset compared to all the digits people have ever written on paper.  How robust is a trained MNIST model to new data?  What if we don't have labels for that new data?  In this tutorial, we're going to use `HistogramWithExamplesPlotter` to examine the scores that an MNIST-trained model gives to "corrupted" data never seen at training time.

First, let's train a basic MNIST model using PyTorch:

In [None]:
!pip install tqdm torch torchvision

In [None]:
# Basic MNIST ConvNet c/o Pytorch (with some small modifications noted)
# https://github.com/pytorch/examples/blob/40289773aa4916fad0d50967917b3ae8aa534fd6/mnist/main.py#L1

model_ckpt = '/opt/mnist_cnn.pt'

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from tqdm.notebook import tqdm

torch.manual_seed(1337)
# Use a GPU when one is available (e.g. a Colab GPU runtime), else the CPU
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# We need this many epochs to get a nice bimodal score distribution for the 7 class
N_EPOCHS = 15

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        
        # Output softmax probabilities instead of log-probabilities so the
        # per-class scores are easier to interpret
        # output = F.log_softmax(x, dim=1)
        output = F.softmax(x, dim=1)

        return output
    
def train(model, train_loader, optimizer, epoch):
    model.train()
    iter_train = tqdm(enumerate(train_loader), desc='train_batches', total=len(train_loader))
    for batch_idx, (data, target) in iter_train:
        data, target = data.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        output = model(data)
        
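        # Push the log into the loss instead of the net output (see forward()).
        # Clamping avoids log(0) = -inf if a probability underflows to zero.
        # loss = F.nll_loss(output, target)
        loss = F.nll_loss(torch.log(output.clamp(min=1e-12)), target)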
        
        loss.backward()
        optimizer.step()
    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        iter_test = tqdm(test_loader, desc='test_batches', total=len(test_loader))
        for data, target in iter_test:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            # The network outputs probabilities, so take the log before nll_loss
            test_loss += F.nll_loss(torch.log(output.clamp(min=1e-12)), target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


train_kwargs = {'batch_size': 128}
test_kwargs = {'batch_size': 1024}
    
transform = transforms.Compose([
    transforms.ToTensor(),
#     transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('/opt/mnist-data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('/opt/mnist-data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

model = Net().to(DEVICE)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)

scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

if os.path.exists(model_ckpt):
    print(f"Resuming from existing checkpoint {model_ckpt}")
    print(f"To re-train, delete the checkpoint: $ rm {model_ckpt}")
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    model.load_state_dict(torch.load(model_ckpt, map_location=DEVICE))
else:
    print(f"Training and saving to {model_ckpt}")

    for epoch in tqdm(range(1, N_EPOCHS + 1), desc='epoch'):
        train(model, train_loader, optimizer, epoch)
        test(model, test_loader)
        scheduler.step()


    torch.save(model.state_dict(), model_ckpt)

x_test = torch.cat([xx[0] for xx in test_loader])

def model_predict(x):
    model.cpu().eval()
    with torch.no_grad():
        prob = model(x)
        pred = prob.argmax(dim=1, keepdim=True)
    return prob, pred



Now, let's run the trained model on the test set and pack the predictions into a Pandas DataFrame.  For each input, the network outputs a score (a softmax probability) for each class (the digits 0 through 9).  We'll take a look at the raw network scores for the class "7," which is easy to confuse with a "1".
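The loop in the next cell builds one dict per row. Since the scores already form an `(N, 10)` array, a vectorized alternative can pass the array straight to `pd.DataFrame`. A sketch, assuming NumPy inputs (in this notebook, `prob.numpy()` and `x_test.numpy()`); the helper name `scores_to_df` is hypothetical and small random arrays stand in for real predictions:

```python
import numpy as np
import pandas as pd

def scores_to_df(prob, images):
    """prob: (N, C) array of class scores; images: (N, 1, H, W) or (N, H, W) array.
    Returns a DataFrame with columns score_0..score_{C-1} plus an 'x_i' column."""
    prob = np.asarray(prob)
    images = np.asarray(images)
    images = images.reshape(len(images), images.shape[-2], images.shape[-1])
    df = pd.DataFrame(prob, columns=[f'score_{c}' for c in range(prob.shape[1])])
    # Store each image as a nested list, like the 'x_i' column built in the loop
    df['x_i'] = [img.tolist() for img in images]
    return df

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))
demo_prob = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # row-wise softmax
demo_images = rng.random((4, 1, 28, 28))
demo_df = scores_to_df(demo_prob, demo_images)
```

One difference from the loop: the score columns keep the input dtype (float32 for PyTorch outputs) instead of Python floats.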

In [None]:
prob, pred = model_predict(x_test)

rows = []
for i, (x_i, score_i, pred_i) in enumerate(zip(x_test, prob, pred)):
    row = {}
    for classname, score in enumerate(score_i):
        row[f'score_{classname}'] = score.item()
    row['x_i'] = x_i.squeeze().tolist()
    rows.append(row)

import pandas as pd
prediction_df = pd.DataFrame(rows)
prediction_df

Now let's histogram the network's scores for the 7 class.  Since the network has high accuracy, the scores are cleanly bimodal, clustering near 0 and near 1.  For now, we'll use the built-in `pandas` `hist()` method, which gives us a histogram (though without examples or other visualization).
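`Series.hist()` draws the picture but doesn't return the counts. A sketch for getting the numbers behind a 10-bin plot with `np.histogram` (10 is the default `bins` for `Series.hist`); the bimodal `scores` below are made up for illustration and stand in for `prediction_df['score_7']`:

```python
import numpy as np

rng = np.random.default_rng(7)
# Made-up bimodal scores: most near 0 ("not a 7"), some near 1 ("a 7")
scores = np.concatenate([rng.beta(0.5, 20, size=900), rng.beta(20, 0.5, size=100)])

counts, edges = np.histogram(scores, bins=10)  # same 10 equal-width bins over [min, max]
for lo, hi, n in zip(edges[:-1], edges[1:], counts):
    print(f'[{lo:.2f}, {hi:.2f}): {n}')
```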

In [None]:
%matplotlib inline 
prediction_df['score_7'].hist()

## Simulating Out-of-Distribution Samples using MNIST-C: Corrupted MNIST

MNIST-C is a benchmark dataset derived from MNIST that has synthetic corruptions.  For example, in MNIST-C, digits are rotated, blurred, speckled, etc.  We will take the normal MNIST model we trained above, run inference on MNIST-C examples, and examine how well the corruption-unaware network generalizes.  This scenario simulates a common situation in production where one has a trained model, lots of unlabeled data, and little tooling for measuring or inspecting error in the wild.  We'll see how `HistogramWithExamplesPlotter` can be a useful tool for quick exploration. 
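A synthetic corruption is just a function from a clean image to a perturbed one. As an illustration, here is a speckle-style (multiplicative) noise corruption in NumPy. It follows the spirit of MNIST-C's corruptions but is a hypothetical re-implementation, not the benchmark's code; the function name and the severity-to-noise-scale table are assumptions:

```python
import numpy as np

# Hypothetical noise scales for severity levels 1..5 (MNIST-C's own values may differ)
SPECKLE_SCALES = [0.06, 0.1, 0.15, 0.2, 0.3]

def speckle_like(x, severity=1, rng=None):
    """x: image with values in [0, 255]. Returns a float image in [0, 255]."""
    rng = np.random.default_rng() if rng is None else rng
    scale = SPECKLE_SCALES[severity - 1]
    x = np.asarray(x, dtype=float) / 255.0
    # Multiplicative noise: perturbation is proportional to pixel intensity
    noisy = x + x * rng.normal(scale=scale, size=x.shape)
    return np.clip(noisy, 0.0, 1.0) * 255.0

img = np.zeros((28, 28))
img[8:20, 12:16] = 255.0  # a crude vertical stroke
out = speckle_like(img, severity=4, rng=np.random.default_rng(0))
```

Because the noise is multiplicative, the black background stays untouched and only the stroke pixels change.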

First, let's get the code for using MNIST-C:

In [None]:
!cd /opt && ((git clone https://github.com/pwais/oarphpy-mirror-mnist-c && \
              cd oarphpy-mirror-mnist-c && git checkout bba57e4ccc282f106907c5239958e72298451ea7) || echo "have mnist-c")
import sys
sys.path.append('/opt/oarphpy-mirror-mnist-c')

# Hard requirements of the MNIST-C `corruptions` module cloned above
!pip3 install scikit-image==0.19.3
!pip3 install wand scipy
!ln -s /opt/oarphpy-mirror-mnist-c/pessimal_noise_matrix ./pessimal_noise_matrix || echo "symlink placed"
!apt-get install -y libmagickwand-dev


Let's take a look at a single corrupted example:

In [None]:
import corruptions

# See more corruptions here: 
# https://github.com/google-research/mnist-c/blob/bba57e4ccc282f106907c5239958e72298451ea7/corruptions.py#L57 

x_to_corrupt = x_test[100].squeeze() * 255
# x_corrupted = corruptions.speckle_noise(x_to_corrupt)
# x_corrupted = corruptions.glass_blur(x_to_corrupt, severity=4)
x_corrupted = corruptions.rotate(x_to_corrupt, severity=4)

x_corrupted.shape

In [None]:
sys.path.append('/opt/oarphpy')
from oarphpy.plotting import img_to_img_tag

img_html = img_to_img_tag(x_to_corrupt)
img_html_c = img_to_img_tag(x_corrupted)

def show_html(html):
    # IPython.display is the public home of display() and HTML()
    from IPython.display import display, HTML
    display(HTML(html))
show_html('<b>Original:</b>' + img_html)
show_html('<b>Corrupted:</b>' + img_html_c)

Now, let's generate corrupted versions of the entire MNIST test set; afterwards we'll run our earlier model on these corrupted examples:
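The next cell rescales each test image from [0, 1] to [0, 255], corrupts it, and scales it back. The same pattern as a reusable helper, with clipping added as a safeguard against corruptions that overshoot the valid range. The helper name `corrupt_batch` is hypothetical, and `np.rot90` stands in for an MNIST-C corruption so the sketch runs on its own:

```python
import numpy as np

def corrupt_batch(x, corrupt_fn, **kwargs):
    """x: (N, 1, H, W) float array in [0, 1]. corrupt_fn maps an (H, W) image in
    [0, 255] to an (H, W) image in [0, 255]. Returns a new (N, 1, H, W) array."""
    x = np.asarray(x)
    out = np.zeros_like(x)
    for i in range(len(x)):
        corrupted = corrupt_fn(x[i, 0] * 255, **kwargs)
        out[i, 0] = np.clip(np.asarray(corrupted) / 255.0, 0.0, 1.0)
    return out

batch = np.random.default_rng(0).random((3, 1, 28, 28)).astype(np.float32)
rotated = corrupt_batch(batch, np.rot90)
```

In this notebook the equivalent call would be `corrupt_batch(x_test.numpy(), corruptions.rotate, severity=4)`.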

In [None]:
import numpy as np
x_test_c = np.zeros_like(x_test)
for i in tqdm(range(len(x_test_c)), total=len(x_test_c)):
    # xform = corruptions.glass_blur
    xform = corruptions.rotate
    x_test_c[i][0] = (1. / 255) * xform(x_test[i].squeeze() * 255, severity=4)


Let's make sure that worked and also declare a utility function for visualizing digits:
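`unit_digit_to_img_tag` below relies on OarphPy's `img_to_img_tag`. To show what such a helper typically does, here is a dependency-free sketch that encodes a grayscale array as a PNG and embeds it in an `<img>` tag as a base64 data URI. This is an assumed stand-in, not OarphPy's implementation; the function names are hypothetical:

```python
import base64
import struct
import zlib

import numpy as np

PNG_SIGNATURE = b'\x89PNG\r\n\x1a\n'

def _png_chunk(chunk_type, data):
    # A PNG chunk is: 4-byte length, 4-byte type, data, CRC-32 of type + data
    crc = zlib.crc32(chunk_type + data) & 0xFFFFFFFF
    return struct.pack('>I', len(data)) + chunk_type + data + struct.pack('>I', crc)

def gray_to_png(img):
    """Encode a 2-D uint8 array as an 8-bit grayscale, non-interlaced PNG."""
    img = np.ascontiguousarray(img, dtype=np.uint8)
    height, width = img.shape
    # Each scanline is prefixed with filter-type byte 0 ("None")
    raw = b''.join(b'\x00' + img[r].tobytes() for r in range(height))
    ihdr = struct.pack('>IIBBBBB', width, height, 8, 0, 0, 0, 0)  # bit depth 8, color type 0 (gray)
    return (PNG_SIGNATURE + _png_chunk(b'IHDR', ihdr)
            + _png_chunk(b'IDAT', zlib.compress(raw)) + _png_chunk(b'IEND', b''))

def unit_digit_to_data_uri_tag(x):
    """x: 2-D float array in [0, 1]. Returns an <img> tag with an inline PNG."""
    img = np.clip(np.round(255 * np.asarray(x, dtype=float)), 0, 255).astype(np.uint8)
    b64 = base64.b64encode(gray_to_png(img)).decode('ascii')
    return '<img src="data:image/png;base64,%s" />' % b64

tag = unit_digit_to_data_uri_tag(np.eye(28))
```

Inlining the image keeps each HTML snippet self-contained, which is convenient when snippets are generated on Spark workers and shipped back to the notebook.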

In [None]:
NUM_TO_SHOW = 10

def unit_digit_to_img_tag(x):
    # Clip before casting so values slightly outside [0, 1] don't wrap around in uint8
    img_char = np.clip(255 * x, 0, 255).astype('uint8')
    return img_to_img_tag(img_char)

for r in x_test_c[:NUM_TO_SHOW, ...]:
    x_i = r.squeeze()
    show_html(unit_digit_to_img_tag(x_i))

Now let's run inference on the corrupted data!

In [None]:
prob_c, pred_c = model_predict(torch.from_numpy(x_test_c))

In [None]:
# Aggregate inference results into a dataframe
rows = []
for i, (x_i, score_i, pred_i) in enumerate(zip(x_test_c, prob_c, pred_c)):
    row = {}
    for classname, score in enumerate(score_i):
        row[f'score_{classname}'] = score.item()
    row['x_i'] = x_i.squeeze().tolist()
    rows.append(row)

import pandas as pd
prediction_c_df = pd.DataFrame(rows)
prediction_c_df

Ok, so how did the network score the `7` class in the corrupted data?

In [None]:
prediction_c_df['score_7'].hist()

Huh, that score distribution is still bimodal, but it is *much more uniform* than the plot we saw earlier.  The shift suggests the model is making mistakes due to the corruptions.  But what sorts of mistakes?  If we needed to select some of these examples to label, which would we choose?  Let's use `HistogramWithExamplesPlotter` to "peer inside" the histogram buckets of the plot above.
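Conceptually, peering inside the buckets means assigning each row to a bucket and keeping a few rows per bucket for display. A Pandas-only sketch of that grouping (an illustration of the idea, not how `HistogramWithExamplesPlotter` is implemented); the helper name `sample_per_bucket` and the equal-width bins over [0, 1] are assumptions, chosen to mirror `NUM_BINS = 20` below:

```python
import numpy as np
import pandas as pd

def sample_per_bucket(df, column, n_bins=20, k=5):
    """Return {bucket_interval: DataFrame of up to k rows whose `column` falls in it}."""
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    buckets = pd.cut(df[column], bins=edges, include_lowest=True)
    # observed=True skips empty buckets
    return {interval: group.head(k)
            for interval, group in df.groupby(buckets, observed=True)}

demo = pd.DataFrame({'score_7': [0.01, 0.02, 0.48, 0.51, 0.97, 0.99, 1.0]})
per_bucket = sample_per_bucket(demo, 'score_7', n_bins=4, k=2)
for interval in sorted(per_bucket, key=lambda iv: iv.left):
    print(interval, per_bucket[interval]['score_7'].tolist())
```

Restricting attention to the middle buckets (scores far from both 0 and 1) is one way to shortlist uncertain examples for labeling.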

In [None]:
from oarphpy.spark import NBSpark
spark = NBSpark.getOrCreate()


In [None]:
from bokeh.plotting import figure
from bokeh.io import output_notebook
from bokeh.io import show as bokeh_show
output_notebook()


In [None]:
prediction_c_sdf = spark.createDataFrame(prediction_c_df)

# To see the text representation of a Spark Dataframe try:
# prediction_c_sdf.show()

In [None]:
from oarphpy import plotting as pl
class MyPlotter(pl.HistogramWithExamplesPlotter):
    NUM_BINS = 20
    def display_bucket(self, sub_pivot, bucket_id, irows):
        MAX_TO_VIZ = 50
        
        from oarphpy.plotting import img_to_img_tag
        htmls = []
        for row in irows:
            htmls.append(unit_digit_to_img_tag(np.array(row['x_i'])))
            
            if len(htmls) >= MAX_TO_VIZ:
                break
        
        # Make a nice table
        N_COLS = 25
        from oarphpy.util import ichunked
        trs = [
            "<tr>%s</tr>" % ''.join("<td>%s</td>" % ihtml for ihtml in row)
            for row in ichunked(htmls, n=N_COLS) 
        ]
        table_html = "<table>%s</table>" % ''.join(trs)
        
        return bucket_id, table_html

plotter = MyPlotter()
fig = plotter.run(prediction_c_sdf, 'score_7')
bokeh_show(fig)

In [None]:
fig = plotter.run(prediction_c_sdf, 'score_1')
bokeh_show(fig)

The two figures above examine the model's inferences on the _corrupted_ dataset.  Let's use `HistogramWithExamplesPlotter` to visualize the inference results on the original MNIST dataset to compare:
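Side-by-side plots are easier to compare when both use the same bin edges. A sketch that histograms two score arrays over identical edges and reports the per-bin shift in the fraction of examples. The helper name `compare_histograms` is hypothetical, the 20 bins over [0, 1] mirror `NUM_BINS = 20` above, and the arrays are made-up placeholders for `prediction_df['score_7']` and `prediction_c_df['score_7']`:

```python
import numpy as np

def compare_histograms(before, after, n_bins=20):
    """Histogram two score arrays over identical edges on [0, 1]; return the edges,
    both count vectors, and the per-bin change in the fraction of examples."""
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    c_before, _ = np.histogram(before, bins=edges)
    c_after, _ = np.histogram(after, bins=edges)
    # Normalize by dataset size so sets of different sizes remain comparable
    shift = c_after / max(len(after), 1) - c_before / max(len(before), 1)
    return edges, c_before, c_after, shift

before = np.array([0.0, 0.01, 0.02, 0.98, 0.99, 1.0])
after = np.array([0.0, 0.3, 0.45, 0.6, 0.99, 1.0])
edges, c_before, c_after, shift = compare_histograms(before, after, n_bins=4)
```

Positive entries in `shift` mark buckets that gained examples under corruption; with both histograms covering all scores, the shifts sum to zero.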

In [None]:
fig = plotter.run(spark.createDataFrame(prediction_df), 'score_7')
bokeh_show(fig)

For more examples of `HistogramWithExamplesPlotter`, see these [rendered HTML pages](https://drive.google.com/drive/folders/1dOmkPvdFiGBMaYEddx1KK5vmCeYl2CyV?usp=sharing).