In [1]:
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np
import os
import base64
from json_tricks import dumps, loads


In [None]:
from os import getenv

def running_on_binder():
    """Tell whether the notebook seems to be running on Binder.

    Quick-and-dirty heuristic: returns True exactly when the
    BINDER_SERVICE_HOST environment variable is defined (whatever its value).
    """
    binder_host = getenv('BINDER_SERVICE_HOST')
    return binder_host is not None

In [None]:
from jupyter_dash import JupyterDash

# Let JupyterDash infer the Jupyter proxy configuration, but only when the
# Binder heuristic says we are hosted there.
if running_on_binder():
    # needed when running on Binder
    JupyterDash.infer_jupyter_proxy_config()

from IPython.display import display, clear_output, HTML

# Inject a CSS rule that widens the notebook's `.container` element to 90%.
display(HTML("<style>.container { width:90% !important; }</style>"))

In [None]:
# Set the notebook autosave interval to 0, then wipe whatever the magic printed
# from the cell output.
%autosave 0
clear_output()

In [None]:
import tensorflow as tf

# Ask TensorFlow to grow the memory of the first GPU on demand.
physical_devices = tf.config.list_physical_devices('GPU')
try:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
except (IndexError, ValueError, RuntimeError):
    # IndexError: no GPU was listed, so physical_devices[0] does not exist.
    # ValueError / RuntimeError: invalid device, or virtual devices cannot be
    # modified once initialized.
    # Anything else (e.g. KeyboardInterrupt) is deliberately NOT swallowed.
    print("No GPU?")
    clear_output()

In [None]:
# Instantiate ResNet50 with the 'imagenet' weights; the progress messages are
# wiped from the cell output once the model is ready.
print("Setting up pre-trained keras ResNet50 model")
model = ResNet50(weights='imagenet')
print("Model ready")
clear_output()

In [None]:
import h5py

In [None]:
import urllib.request

# Fetch the calibration predictions once; later runs reuse the local copy.
if not os.path.exists('val_preds.h5'):
    print("Downloading MICP calibration data (190MB) - be patient!")
    # Download under a temporary name and rename only on success, so that an
    # interrupted transfer cannot leave a truncated 'val_preds.h5' behind that
    # would pass the existence check above on the next run.
    partial_download = 'val_preds.h5.part'
    try:
        urllib.request.urlretrieve("https://cml.rhul.ac.uk/people/ptocca/ILSVRC2012-CP/val_preds.h5",
                                   partial_download)
        os.replace(partial_download, 'val_preds.h5')
    finally:
        # Only present here if the download or the rename failed.
        if os.path.exists(partial_download):
            os.remove(partial_download)
    clear_output()

In [None]:
# Read the whole 'preds' dataset into memory; the file is closed right after.
# presumably an (objects, labels) array of calibration scores, given how
# micp_pValues indexes it — TODO confirm
with h5py.File('val_preds.h5', 'r') as f:
    preds_cal = f['preds'][:]


In [None]:
def pValues(calibrationAlphas, testAlphas, randomized=False):
    """Compute conformal p-values for test nonconformity scores.

    calibrationAlphas: 1-D sequence of calibration nonconformity scores
    testAlphas: scalar or array of test nonconformity scores
    randomized: if True, ties (calibration scores equal to the test score,
                plus the test score itself) are weighted by a uniform random
                number; otherwise they are all counted

    With n calibration scores the deterministic p-value is
        (#{cal >= test} + 1) / (n + 1)
    and the randomized one is
        (#{cal > test} + U * (#{cal == test} + 1)) / (n + 1),  U ~ Uniform(0, 1).

    Returns a p-value (or array of p-values) with the shape of testAlphas.
    """
    testAlphas = np.asarray(testAlphas)
    sortedCalAlphas = np.sort(calibrationAlphas)
    n_cal = len(calibrationAlphas)

    # Index of the first calibration score >= the test score.
    leftPositions = np.searchsorted(sortedCalAlphas, testAlphas)

    if randomized:
        # Index just past the last calibration score == the test score.
        rightPositions = np.searchsorted(sortedCalAlphas, testAlphas, side='right')
        ties = rightPositions - leftPositions + 1  # ties in cal set plus the test alpha itself
        # np.shape() (rather than len()) so that a scalar test alpha works too.
        randomizedTies = ties * np.random.uniform(size=np.shape(ties))
        return (n_cal - rightPositions + randomizedTies) / (n_cal + 1)

    return (n_cal - leftPositions + 1) / (n_cal + 1)

In [None]:
def rev_score(scores, label):
    """NCM: the negated score assigned to the hypothesised label.

    scores: 2-D array indexed as [object, label]
    label: column to read
    Returns one value per row of `scores`.
    """
    label_scores = scores[:, label]
    return -label_scores


def ratio_max_to_hypothetical(scores, label):
    """NCM: largest score among the *other* labels divided by the score of `label`.

    scores: 2-D array indexed as [object, label]
    label: column index of the hypothesised label
    Returns one ratio per row of `scores`.
    """
    # Builtin `bool` instead of the `np.bool` alias, which no longer exists in
    # current NumPy releases.
    mask = np.ones(scores.shape[1], dtype=bool)
    mask[label] = False  # exclude the hypothesised label from the maximum

    # `initial=0` is the value used where the mask excludes an element.
    return np.amax(scores, axis=1, where=mask, initial=0) / scores[:, label]


In [None]:
def micp_pValues(scores_cal, scores_test, y_cal, ncm):
    """p-values of a Mondrian Inductive Conformal Predictor.

    scores_cal, scores_test: arrays of shape (objects, labels) holding the
                             scores of the calibration set and of the test set
    y_cal: array of shape (objects,) with the calibration labels
    ncm: function of (scores, label) computing the nonconformity measure

    For every label, only the calibration objects carrying that label are used
    to compute the p-values of all test objects under that label hypothesis.
    Returns an array with one row per label.
    """
    n_labels = scores_test.shape[1]
    per_label = [
        pValues(ncm(scores_cal[y_cal == label], label), ncm(scores_test, label))
        for label in range(n_labels)
    ]
    return np.array(per_label)


In [None]:
# Directory the ground-truth and synset files are read from.
# The commented-out line is an alternative absolute location.
# ilsrvc_dir = "/mnt/d/Research/ILSVRC2012/"
ilsrvc_dir = "."

In [None]:
# Paths of the label files, all located in ilsrvc_dir.
gt_cal_file = os.path.join(ilsrvc_dir, "cal_gt.txt")
gt_test_file = os.path.join(ilsrvc_dir, "test_gt.txt")
lbls_file = os.path.join(ilsrvc_dir, "labels.txt")  # NOTE(review): not read anywhere in the visible cells

In [None]:
n_to_ki = {}  # NOTE(review): never filled in the visible cells
# Maps the 0-based line number of synset_words.txt to the text of that line.
ki_to_synset = {}
with open(os.path.join(ilsrvc_dir, 'synset_words.txt')) as f:
    for i, l in enumerate(f):
        # Drop the first 10 characters of the line (presumably the WordNet id
        # and its separator — TODO confirm) and the surrounding whitespace.
        ki_to_synset[i] = l[10:].strip()

In [None]:
# Integer labels of the calibration and test objects, one per line of each file.
# Builtin `int` replaces the `np.int` alias, which was only another name for it
# and no longer exists in current NumPy releases.
ground_truth_ki_cal = np.loadtxt(gt_cal_file, dtype=int)
ground_truth_ki_test = np.loadtxt(gt_test_file, dtype=int)

In [None]:
import io

In [None]:
import PIL.Image
import joblib

In [None]:
# On-disk memoization store for downloaded images.
# NOTE(review): '/dev/shm' is a Linux-specific location — confirm it exists on the target host.
mem = joblib.Memory('/dev/shm/joblib', verbose=0)


@mem.cache
def getImage(url):
    """Download an image and return it as a 224x224 RGB PIL image.

    url: address the image is fetched from
    Results are memoized through `mem`, so a given URL is downloaded only once
    per cache.
    """
    # The context manager guarantees the HTTP response is closed, also when
    # decoding fails; all pixel access happens before it is released.
    with urllib.request.urlopen(url) as response:
        img_data = PIL.Image.open(response)
        if img_data.mode != 'RGB':
            img_data = img_data.convert('RGB')
        img_data = img_data.resize((224, 224), resample=PIL.Image.NEAREST)
    return img_data


In [None]:
def get_prob_sets(preds, eps):
    """Build, for every row of `preds`, the set of most probable labels.

    preds: array of shape (objects, labels)
    eps: labels are added in decreasing order of probability; a label is kept
         when the cumulative probability of the labels ranked before it is
         below 1 - eps. The top-ranked label is always kept.

    Returns a list with one (label_indices, probabilities) pair per object.
    """
    ranking = np.argsort(-preds, axis=1)
    ranked_probs = np.take_along_axis(preds, ranking, axis=1)
    cumulative = np.cumsum(ranked_probs, axis=1)
    below_threshold = cumulative < 1 - eps

    # Shift the mask one position to the right so that the label that crosses
    # the threshold is still included; the first label is kept unconditionally.
    keep = np.empty_like(below_threshold)
    keep[:, 0] = True
    keep[:, 1:] = below_threshold[:, :-1]

    sets = []
    for row_ranking, row_keep, row_preds in zip(ranking, keep, preds):
        chosen = row_ranking[row_keep]
        sets.append((chosen, row_preds[chosen]))
    return sets


In [None]:
import dash
import dash_core_components as dcc
import dash_html_components as html
import dash_table
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate
from dash_extensions.callback import DashCallbackBlueprint

# Look at dash_reusable_components.py
# There are several instances in demo apps in the Dash Gallery.
# There are good ideas and useful snippets. For instance, you can find how to draw an image in dash-image-processing

# Also, there is dash_bootstrap_components

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def draw_ncm_histo(ecdf_ncm=None, ncm_test=None, sel_p_val=None, label_synset=None):
    """Build the two-axis Plotly figure shown in the NCM pane.

    ecdf_ncm: callable ECDF object exposing its support as `.x`; when None no
              trace is drawn and only title and axes are laid out
    ncm_test: sequence whose first element is the NCM of the test object
    sel_p_val: p-value plotted for the test object on the secondary y axis
    label_synset: label text used in the title, the marker name and the
                  secondary axis caption

    Returns the figure.
    """
    fig = make_subplots(specs=[[{"secondary_y": True}]])

    if ecdf_ncm is not None:
        # Curve of the calibration ECDF on the primary axis ...
        ecdf_curve = go.Scatter(x=list(ecdf_ncm.x), y=list(ecdf_ncm(ecdf_ncm.x)), name='NCM')
        fig.add_trace(ecdf_curve, secondary_y=False)
        # ... and a single marker for the test object on the secondary axis.
        test_marker = go.Scatter(x=[ncm_test[0]], y=[sel_p_val],
                                 mode='markers', name=label_synset)
        fig.add_trace(test_marker, secondary_y=True)

    figure_title = go.layout.Title(text="Histogram of NCM for '%s'" % label_synset,
                                   x=0.5, y=0.85,
                                   xanchor='center', yanchor='top')
    legend_position = dict(x=0.6, y=0.1)
    fig.update_layout(title=figure_title, legend=legend_position)

    fig.update_yaxes(title_text="ECDF of NCM of calibration examples", range=[0, 1], secondary_y=False)
    # The secondary axis runs from 1 down to 0.
    fig.update_yaxes(title_text="p-value for '%s'" % label_synset, range=[1, 0], secondary_y=True)

    fig.update_xaxes(title_text="NCM")

    return fig

In [None]:
# External CSS, plus MathJax loaded as an external script.
# see https://github.com/plotly/dash/issues/242
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
external_scripts = [r'https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-MML-AM_CHTML']
app = JupyterDash(__name__,
                       external_stylesheets=external_stylesheets,
                       external_scripts=external_scripts)

dcb = DashCallbackBlueprint()  # Needed to have two callbacks for the same Output

# Page heading and the box showing the label of the current image.
heading = html.H1("Demo of Conformal Prediction using ResNet50 on ImageNet data", style={'textAlign': 'center', 'margin-top': "2rem"})
desc = "ImageNet label"
desc_div = html.Div(id='desc_div', children=["ImageNet label:  ", desc],
                    style={'display': 'flex', 'padding': '5px',
                           'flex-direction': 'column', 'justify-content': 'center',
                           'align-items': 'center',
                           'width': 500,
                           'font-size': 18})

# Three-column grid (label / slider / text readout) used by create_slider_readout.
style_grid = {'display': 'grid', 'grid-template-columns': 'auto 70% auto', 'align-items': 'center',
              'justify-content': 'center', 'width': "80%", 'margin': '5px'}

# Cannot make this work
# (flexbox alternative; not referenced by the visible cells)
style_flex = {'display': 'flex', 'flex-direction': 'row', 'align-items': 'center',
              'justify-content': 'center', 'width': "80%"}


def create_slider_readout(id, dcb, label, initVal, slider_kwargs, input_kwargs):
    """Create a labelled slider coupled with an editable readout.

    id: prefix of the component ids; "<id>_slider", "<id>_input" and
        "<id>_sync" are created
    dcb: callback blueprint the synchronisation callbacks are registered on
    label: content placed before the slider
    initVal: initial content of the "<id>_sync" store
    slider_kwargs, input_kwargs: extra keyword arguments for dcc.Slider and
                                 dcc.Input respectively

    The store "<id>_sync" receives whatever the slider or the input emits and
    is then pushed back to whichever of the two is out of date.
    Returns the html.Div containing label, slider, input and store.
    """
    slider_id = id + "_slider"
    input_id = id + "_input"
    sync_id = id + "_sync"

    components = [label,
                  dcc.Slider(id=slider_id, className='centered-slider', **slider_kwargs),
                  dcc.Input(id=input_id, **input_kwargs),
                  dcc.Store(id=sync_id, data=initVal)]
    slider = html.Div(components, style=style_grid)

    # Both widgets write to the same store.
    @dcb.callback(Output(sync_id, "data"), [Input(input_id, "value")])
    def sync_input_value(value):
        return value

    @dcb.callback(Output(sync_id, "data"), [Input(slider_id, "value")])
    def sync_slider_value(value):
        return value

    @dcb.callback([Output(input_id, "value"), Output(slider_id, "value")],
                  [Input(sync_id, "data"), Input(sync_id, "modified_timestamp")],
                  [State(input_id, "value"), State(slider_id, "value")])
    def update_components(current_value, _, input_prev, slider_prev):
        # Only the widget that disagrees with the store is refreshed; leaving
        # the other untouched is what "breaks" the circular dependency.
        if current_value != input_prev:
            input_value = current_value
        else:
            input_value = dash.no_update
        if current_value != slider_prev:
            slider_value = current_value
        else:
            slider_value = dash.no_update
        return [input_value, slider_value]

    return slider


# NOTE(review): second instantiation of the blueprint in this cell; nothing has
# been registered on the previous instance at this point.
dcb = DashCallbackBlueprint()  # Needed to have two callbacks for the same Output

In [None]:
# Image placeholder: an empty base64 PNG data URI until update_pic fills it in.
ILSRVC_image = html.Img(id="ILSRVC_image",
                        src='data:image/png;base64, ', style={'margin': 5})

# Slider + text readout selecting the image index (1..2000, initially 1000).
pic_idx = create_slider_readout(id="pic_idx", dcb=dcb, label="Image index: ", initVal=1000,
                                slider_kwargs=dict(min=1, max=2000, step=1),
                                input_kwargs=dict(type="text", style={'width': '4em'}))

# Input pane: image, its label and the index selector stacked vertically.
imagenet_div = html.Div([ILSRVC_image, desc_div, pic_idx],
                        style={'display': 'flex', 'flex-direction': 'column', 'align-items': 'center',
                               'justify-content': 'center', 'width': "80%",
                               'border': '1px solid black', 'border-radius': '5px',
                               'margin': 10, 'background-color': 'white', 'padding': 10})
# Now Let's build the output pane

# Slider + text readout selecting the significance level (0..1, initially 0.2).
eps = create_slider_readout(id="eps", dcb=dcb, label="Significance level:", initVal=0.2,
                            slider_kwargs=dict(min=0.0, max=1.0, step=0.01),
                            input_kwargs=dict(type="text", style={'width': '4em'}))

# Placeholder rows shown in the two tables before the first callback fires.
# Keys must match the column *ids* of the tables: the CP table declares its
# p-value column with id "pValue" ("p-value" is only its display name), so the
# placeholder rows use "pValue" too — otherwise that column renders empty.
pr_data = [{"label": label_idx, "Prob": label_idx / 10.0} for label_idx in range(10)]
pval_data = [{"label": label_idx, "pValue": label_idx / 10.0} for label_idx in range(10)]

# Left output pane: heading, table of ResNet50 labels/probabilities and the
# store holding the serialized predictions for the current image.
resnet50_div = html.Div([
    html.Div(id="ResNet50 heading", children="No image", style={'font-size': 16}),
    dash_table.DataTable(id='ResNet50', data=pr_data,
                         columns=[{"name": "label", "id": "label"}, {"name": "Prob", "id": "Prob"}],
                         style_cell={'textAlign': 'left', 'textOverflow': 'ellipsis',
                                     'maxWidth': '5em', 'overflow': 'hidden', 'font-size': 14},
                         style_cell_conditional=[
                             {'if': {'column_id': 'Prob'},
                              'width': '70px'},
                         ],
                         cell_selectable=False,
                         fixed_rows={'headers': True},
                         style_as_list_view=True,
                         style_table={'height': '500px', 'width': "400px", 'margin': 5}),
    dcc.Store('test_preds')])  # written by update_pic

# Right output pane: heading and table of the conformal prediction set, plus
# the stores through which update_CP hands its results to update_NCM_histo.
# Unlike the ResNet50 table, cells here stay selectable: the active cell's row
# picks the label shown in the NCM plot.
CP_div = html.Div([
    html.Div(id="CP heading", children="No image", style={'font-size': 16}),
    dash_table.DataTable(id='CP', data=pval_data,
                         columns=[{"name": "label", "id": "label"}, {"name": "p-value", "id": "pValue"}],
                         fixed_rows={'headers': True},
                         style_cell={'textAlign': 'left', 'textOverflow': 'ellipsis',
                                     'maxWidth': '5em', 'overflow': 'hidden', 'font-size': 14},
                         style_cell_conditional=[
                             {'if': {'column_id': 'pValue'},
                              'width': '70px'},
                         ],
                         style_as_list_view=True,
                         style_table={'height': '500px', 'width': "400px", 'margin': 5}),
    dcc.Store('ps'),                # label indices in the prediction set
    dcc.Store('p_vals'),            # p-values of all labels
    dcc.Store('sorting_by_p_val')])  # order of 'ps' by decreasing p-value

# Selector of the nonconformity measure; the values are the labels understood
# by get_ncm_function.
NCM = dcc.RadioItems(id='NCM', options=[{'label': 'NegProb', 'value': 'NegProb'}, {'label': 'Ratio', 'value': 'Ratio'}],
                     value='NegProb', style={'font-size': 16})

# Plot of the NCM distribution; starts as an empty figure (no traces).
NCM_hist_output = dcc.Graph(id="NCM_hist_output", config={'displayModeBar': False},
                            figure=draw_ncm_histo(),
                            style={'width': 600})

# Middle output pane: NCM selector above the plot.
NCM_div = html.Div(["NCM", NCM, NCM_hist_output],
                   style={'display': 'flex', 'flex-direction': 'column', 'align-items': 'center', 'font-size': 16})

# The three output panes side by side.
output_div = html.Div([resnet50_div, NCM_div, CP_div],
                      style={'display': 'flex', 'flex-direction': 'row', 'justify-content': 'space-between', 'margin': 5})

# Predictions box: significance-level selector on top of the output panes.
preds_div = html.Div([eps, output_div],
                     style={'display': 'flex', 'flex-direction': 'column', 'align-items': 'center',
                            'width': '80%',
                            'justify-content': 'center', 
                            'border': '1px solid black', 'border-radius': '5px',
                            'margin': 10, 'background-color': 'white', 'padding': 10})

# Content of the "Demo" tab: heading, input box and predictions box.
main_tab = html.Div(children=[
    heading, imagenet_div, preds_div
],
    style={'display': 'flex', 'flex-direction': 'column',
           'align-items': 'center', 'justify-content': 'space-between',
           'background-color': 'lightgrey'}
)

# The "Notes" tab shows the content of a local HTML file verbatim.
with open("ILSRVC_CP_Notes.html") as f:
    notes = f.read()
    
import dash_dangerously_set_inner_html
# The HTML is injected as-is (no escaping); the file must therefore be trusted.
notes_tab = html.Div([
    dash_dangerously_set_inner_html.DangerouslySetInnerHTML(notes)])

# Top-level layout: two tabs, "Demo" and "Notes".
app.layout = html.Div([dcc.Tabs([dcc.Tab(label='Demo', children=main_tab),
                                 dcc.Tab(label='Notes', children=notes_tab)],
                                style={'height':40, 'width':'10em', 'padding':5})])


In [None]:
# NOTE
# In Dash, callbacks should not modify variables outside their local scope.
# The global variables below are never changed once they have been loaded,
# so reading them from the callbacks should be OK:
# ground_truth_ki_cal
# ground_truth_ki_test
# preds_cal
# ki_to_synset

@dcb.callback([Output("ILSRVC_image", "src"), Output('desc_div', 'children'), Output('test_preds', 'data')],
              [Input("pic_idx_sync", "data")])
def update_pic(i):
    """Show image number `i`, its ground-truth label and compute its predictions.

    i: 1-based image index coming from the "pic_idx_sync" store; None selects
       image 1000. The value may arrive as a string, because the readout next
       to the slider is a text input.

    Returns the image as a base64 PNG data URI, the ground-truth label text
    and the ResNet50 predictions serialized with json_tricks.
    Raises PreventUpdate (leaving the page unchanged) when `i` is not an
    integer or lies outside the available ground truth.
    """
    if i is None:
        i = 1000

    # Text typed in the readout arrives as str: "%08d" % "5" would raise.
    try:
        i = int(i)
    except (TypeError, ValueError):
        raise PreventUpdate
    # Index 0 would silently wrap around to the last ground-truth entry.
    if not 1 <= i <= len(ground_truth_ki_test):
        raise PreventUpdate

    if 0:  # for development environment
        img_file = os.path.join(".", "img", "ILSVRC2012_valsub_%08d.JPEG" % i)
        img_data = keras_image.load_img(img_file, target_size=(224, 224))
    else:
        url = """https://cml.rhul.ac.uk/people/ptocca/ILSVRC2012-CP/img/ILSVRC2012_valsub_%08d.JPEG""" % i
        img_data = getImage(url)

    # Encode the image as a PNG data URI for the html.Img component.
    output = io.BytesIO()
    img_data.save(output, format="PNG")
    img_encoded = 'data:image/png;base64, ' + base64.b64encode(output.getvalue()).decode("utf-8")

    # compute ResNet50 preds
    x = keras_image.img_to_array(img_data)
    x = np.expand_dims(x, axis=0)  # batch of one image
    x = preprocess_input(x)
    test_preds = model.predict(x)

    ## update ground truth widget
    lbl = ki_to_synset[ground_truth_ki_test[i - 1]]  # file is 0-based, index is 1-based

    return img_encoded, lbl, dumps(test_preds)


@dcb.callback([Output('ResNet50 heading', 'children'), Output("ResNet50", "data")],
              [Input("eps_sync", "data"), Input('test_preds', 'data')])
def update_ResNet(eps, test_preds_json):
    """Refresh the ResNet50 table with the most probable labels.

    eps: value of the "eps_sync" store; labels are listed up to an aggregate
         probability of 1 - eps. May arrive as a string, because the readout
         next to the slider is a text input.
    test_preds_json: predictions for the current image, serialized with
                     json_tricks

    Returns the heading text and the table rows.
    Raises PreventUpdate when there are no predictions yet or `eps` is not a
    number.
    """
    if test_preds_json is None or eps is None:
        raise PreventUpdate

    # Text typed in the readout arrives as str: `1 - "0.2"` would raise.
    try:
        eps = float(eps)
    except (TypeError, ValueError):
        raise PreventUpdate

    test_preds = loads(test_preds_json)
    # A single image: take the only (labels, probabilities) pair.
    labels, probs = get_prob_sets(test_preds.reshape(1, -1), eps=eps)[0]

    ## update resNet50 widget
    resnet50_heading = "ResNet50 (prob) at aggr prob %0.2f" % (1 - eps)
    resnet50_data = [dict(label=ki_to_synset[k], Prob="%0.3f" % pr) for k, pr in zip(labels, probs)]
    return resnet50_heading, resnet50_data


def get_ncm_function(ncm_label):
    """Return the nonconformity-measure function selected by `ncm_label`.

    ncm_label: 'NegProb' (rev_score) or 'Ratio' (ratio_max_to_hypothetical)
    Raises ValueError for any other label.
    """
    if ncm_label == 'NegProb':
        return rev_score
    if ncm_label == 'Ratio':
        return ratio_max_to_hypothetical
    # Fail with an explicit message instead of an UnboundLocalError.
    raise ValueError("Unknown NCM label %r (expected 'NegProb' or 'Ratio')" % (ncm_label,))


@dcb.callback([Output("CP heading", "children"),
               Output('CP', 'data'),
               Output('p_vals','data'),
               Output('ps','data'),
               Output('sorting_by_p_val','data')],
              [Input("eps_sync", "data"),
               Input("NCM", "value"),
               Input('test_preds', 'data'),
               ])
def update_CP(eps, ncm_label, test_preds_json):
    """Refresh the conformal prediction set for the current image.

    eps: significance level from the "eps_sync" store; may arrive as a string,
         because the readout next to the slider is a text input
    ncm_label: nonconformity measure selected in the "NCM" radio items
    test_preds_json: predictions for the current image, serialized with
                     json_tricks

    Returns the heading text, the table rows (labels with p-value > eps, by
    decreasing p-value) and, serialized with json_tricks, all p-values, the
    label indices of the prediction set and their ordering by p-value.
    Raises PreventUpdate when there are no predictions yet or `eps` is not a
    number.
    """
    if test_preds_json is None or eps is None:
        raise PreventUpdate

    # Text typed in the readout arrives as str: `p_vals > "0.2"` would raise.
    try:
        eps = float(eps)
    except (TypeError, ValueError):
        raise PreventUpdate

    test_preds = loads(test_preds_json)

    ncm_f = get_ncm_function(ncm_label)

    # One row of p-values per label.
    p_vals = micp_pValues(preds_cal, test_preds, ground_truth_ki_cal, ncm=ncm_f)

    # Prediction set: labels whose p-value exceeds the significance level.
    ps = np.argwhere(p_vals > eps)[:, 0].T
    ps_p_vals = p_vals[ps].flatten()
    sorting_by_p_val = np.argsort(ps_p_vals)[::-1]  # decreasing p-value
    ps_synset = [dict(label=ki_to_synset[k], pValue="%0.3f" % p) for k, p in
                 zip(ps[sorting_by_p_val], ps_p_vals[sorting_by_p_val])]

    ## update CP widget
    CP_heading = "CP (p-val) pred set at significance level %0.2f" % eps
    CP_table = ps_synset

    return CP_heading, CP_table, dumps(p_vals), dumps(ps), dumps(sorting_by_p_val)


from statsmodels.distributions.empirical_distribution import ECDF


@dcb.callback([Output("NCM_hist_output", "figure")],
              [Input("CP", "data"),
               Input("NCM", "value"),
               Input("CP","active_cell"),
               Input('ps', 'data'),
               Input('sorting_by_p_val', 'data'),
               Input('p_vals', 'data')],
              [State('test_preds', 'data')])
def update_NCM_histo(CP, ncm_label, selected, ps_json, sorting_by_p_val_json, p_vals_json, test_preds_json):
    """Redraw the NCM plot for the label selected in the CP table.

    CP: rows of the CP table (only checked for presence)
    ncm_label: nonconformity measure selected in the "NCM" radio items
    selected: active cell of the CP table, or None; its 'row' picks the label,
              row 0 is used when nothing is selected
    ps_json, sorting_by_p_val_json, p_vals_json, test_preds_json:
        json_tricks-serialized contents of the corresponding stores

    Returns the figure. Raises PreventUpdate while any of the stores is still
    empty.
    """
    stores = (ps_json, sorting_by_p_val_json, p_vals_json, test_preds_json)
    if CP is None or any(store is None for store in stores):
        # Nothing to deserialize yet.
        raise PreventUpdate

    # `selected` is None until a cell is clicked.
    try:
        idx = selected['row']
    except (TypeError, KeyError):
        idx = 0

    sorting_by_p_val = loads(sorting_by_p_val_json)
    test_preds = loads(test_preds_json)
    p_vals = loads(p_vals_json)
    ps = loads(ps_json)

    try:
        sel_p_val_label = ps[sorting_by_p_val[idx]]
    except IndexError:
        # Row not available in the current prediction set (e.g. it is empty):
        # fall back to the label with the largest p-value.
        sel_p_val_label = np.argmax(p_vals)

    ncm_f = get_ncm_function(ncm_label)
    ncm_cal = ncm_f(preds_cal, sel_p_val_label)
    ncm_test = ncm_f(test_preds, sel_p_val_label)

    label_synset = ki_to_synset[sel_p_val_label]
    if len(label_synset) > 15:
        label_synset = label_synset[:15] + "..."

    # Mondrian taxonomy: only calibration objects of the selected label.
    ncm_cal_mondrian = ncm_cal[ground_truth_ki_cal == sel_p_val_label]

    ecdf_ncm = ECDF(np.r_[ncm_cal_mondrian, ncm_test], side='left')  # TODO: check number of dimensions?

    fig = draw_ncm_histo(ecdf_ncm, ncm_test, p_vals[sel_p_val_label, 0], label_synset)

    return fig


# Hand all the callbacks collected on the blueprint over to the Dash app.
dcb.register(app)

In [None]:
# Start the app and display it inline in the notebook, with debug mode off.
app.run_server(mode='inline', width="100%", height="1250", debug=False)

In [None]:
# app._terminate_server_for_port("127.0.0.1",8050)
