In [None]:
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np
import os
import joblib

In [None]:
from IPython.display import display, clear_output, HTML
# Widen the classic-notebook main container to 90% of the window
# (presumably so the widget GUI built below has room to lay out).
_wide_container_css = "<style>.container { width:90% !important; }</style>"
display(HTML(_wide_container_css))

In [None]:
import tensorflow as tf

physical_devices = tf.config.list_physical_devices('GPU')
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# TF refuses mixed memory-growth settings across GPUs, so apply it to every visible
# GPU; with no GPU the loop is simply empty and execution stays on the CPU.
for gpu in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except (ValueError, RuntimeError):
        # ValueError: invalid device; RuntimeError: the runtime is already
        # initialised (e.g. this cell re-run after the model was built).
        # Memory growth is a best-effort tweak, so carry on without it.
        pass

In [None]:
# Build the ImageNet-weighted ResNet50 classifier (with its classification head).
# The progress messages are wiped by clear_output() once the model is ready.
print("Setting up pre-trained keras ResNet50 model")
model = ResNet50(include_top=True, weights='imagenet')
print("Model ready")
clear_output()

In [None]:
import h5py

In [None]:
import urllib.request
# Fetch the MICP calibration scores once. The download goes to a ".part" file and is
# renamed only after it completes, so an interrupted download can never leave a
# truncated val_preds.h5 that this existence check would then accept forever.
if not os.path.exists('val_preds.h5'):
    print("Downloading MICP calibration data (200MB) - be patient!")
    partial_path = 'val_preds.h5.part'
    urllib.request.urlretrieve("https://cml.rhul.ac.uk/people/ptocca/ILSVRC2012-CP/val_preds.h5",
                               partial_path)
    os.replace(partial_path, 'val_preds.h5')
    clear_output()

In [None]:
# Load the calibration scores used by micp_pValues.
# NOTE(review): presumably shape (num_validation_images, num_labels), with rows aligned
# to ground_truth_ki (that is how micp_pValues indexes it) -- confirm against the file.
with h5py.File('val_preds.h5', mode='r') as h5:
    preds = h5['preds'][:]

In [None]:
def pValues(calibrationAlphas,testAlphas,randomized=False,rng=None):
    """Conformal p-values of test nonconformity scores w.r.t. a calibration set.

    With ``n`` calibration scores, the p-value of a test score ``a`` is

    * deterministic: ``(#{cal >= a} + 1) / (n + 1)``
    * randomized (smoothed): ``(#{cal > a} + U * (#{cal == a} + 1)) / (n + 1)``
      with ``U ~ Uniform[0, 1)`` drawn independently for each test score.

    Parameters
    ----------
    calibrationAlphas : 1-D array-like
        Nonconformity scores of the calibration objects, in any order.
    testAlphas : scalar or array-like
        Nonconformity score(s) to evaluate.
    randomized : bool
        Return smoothed p-values (ties broken at random) instead of the
        conservative deterministic ones.
    rng : numpy ``Generator`` / ``RandomState``, optional
        Source of the uniform draws when ``randomized``; ``None`` uses the
        global ``np.random`` state.

    Returns
    -------
    ndarray (numpy scalar for a scalar ``testAlphas``) shaped like ``testAlphas``.
    """
    testAlphas = np.asarray(testAlphas)
    sortedCalAlphas = np.sort(calibrationAlphas)
    n_cal = len(sortedCalAlphas)

    # Number of calibration scores strictly below each test score.
    leftPositions = np.searchsorted(sortedCalAlphas,testAlphas)

    if randomized:
        # Number of calibration scores <= each test score.
        rightPositions = np.searchsorted(sortedCalAlphas,testAlphas,side='right')
        ties = rightPositions-leftPositions+1   # ties in cal set plus the test alpha itself
        draw_uniform = np.random.uniform if rng is None else rng.uniform
        # np.shape rather than len() so a scalar test score also works.
        randomizedTies = ties * draw_uniform(size=np.shape(ties))
        return (n_cal - rightPositions + randomizedTies)/(n_cal+1)
    return (n_cal - leftPositions + 1)/(n_cal+1)


In [None]:
# micp_pValues (next cell) returns p-values with shape (num_labels, num_test_objects)

In [None]:
def micp_pValues(preds,test_preds,y_cal):
    """Label-conditional (MICP) conformal p-values for every candidate label.

    The nonconformity score of an (object, label) pair is minus the model's score
    for that label. For each label, the test scores are calibrated only against
    the calibration objects whose true label is that label.

    Parameters
    ----------
    preds : array of shape (num_cal_objects, num_labels)
        Model scores of the calibration objects.
    test_preds : array of shape (num_test_objects, num_labels)
        Model scores of the test objects.
    y_cal : array of shape (num_cal_objects,)
        True label index of each calibration object.

    Returns
    -------
    ndarray of shape (num_labels, num_test_objects)
    """
    # Higher model score => lower nonconformity.
    cal_ncm, test_ncm = -preds, -test_preds
    num_labels = test_preds.shape[1]
    per_label_p = [
        pValues(cal_ncm[:, lbl][y_cal == lbl], test_ncm[:, lbl])
        for lbl in range(num_labels)
    ]
    return np.array(per_label_p)

In [None]:
# Directory holding the ILSVRC2012 metadata files (synset_words.txt, the mapping file
# and the validation ground truth). Defaults to the current directory; set the
# ILSVRC_DIR environment variable to use another location instead of hard-coding
# an absolute local path here.
ilsrvc_dir = os.environ.get("ILSVRC_DIR", ".")

In [None]:
# Validation ground-truth file and label file, both inside ilsrvc_dir.
gt_file, lbls_file = (
    os.path.join(ilsrvc_dir, fname)
    for fname in ("ILSVRC2012_validation_ground_truth.txt", "labels.txt")
)

In [None]:
# Mapping from 1-based ILSVRC2012 class ids to synset ids (parsed into ii_to_n).
mapping_file = os.path.join(ilsrvc_dir, 'ILSVRC2012_mapping.txt')

In [None]:
# Two lookups built from synset_words.txt, one line per model output index:
#   n_to_ki:      synset id (first token of the line) -> line index
#   ki_to_synset: line index -> first comma-separated name, used for display
n_to_ki = {}
ki_to_synset = {}
with open(os.path.join(ilsrvc_dir, 'synset_words.txt')) as words_file:
    for ki, line in enumerate(words_file):
        synset_id = line.split()[0].strip()
        # Assumes the id plus its separator occupy exactly the first 10 characters.
        display_name = line[10:].split(",")[0].strip()
        n_to_ki[synset_id] = ki
        ki_to_synset[ki] = display_name

In [None]:
# ii_to_n[k] is the second whitespace-separated field of line k (1-based) of the
# mapping file; slot 0 is a placeholder so 1-based class ids index directly.
ii_to_n = ["Error"]
with open(mapping_file) as mapping:
    ii_to_n.extend(row.split()[1].strip() for row in mapping)

In [None]:
# Translate each 1-based ILSVRC2012 class id into a model output index via its synset
# id; slot 0 mirrors the placeholder at ii_to_n[0].
ii_to_ki = [0, *(n_to_ki[ii_to_n[class_id]] for class_id in range(1, 1001))]

In [None]:
# One ground-truth model index per validation image, filled from gt_file below.
ground_truth_ki = np.zeros((50000,), dtype=np.int16)

In [None]:
# Line k of gt_file holds the 1-based ILSVRC2012 class id of validation image k;
# store it as a model output index via ii_to_ki.
n_lines = 0
with open(gt_file) as f:
    for n_lines, line in enumerate(f, start=1):
        ground_truth_ki[n_lines - 1] = ii_to_ki[int(line)]
# Unfilled slots would silently stay 0, which is itself a valid class index, so a
# short file must fail loudly rather than corrupt the calibration labels.
if n_lines != len(ground_truth_ki):
    raise ValueError("%s has %d labels, expected %d"
                     % (gt_file, n_lines, len(ground_truth_ki)))

In [None]:
from ipywidgets import IntSlider,Image,interactive,VBox,Textarea,Layout,FloatSlider,HBox,Label,Output
import io

In [None]:
import PIL.Image
import joblib

In [None]:
import tempfile

# joblib disk cache for downloaded images. /dev/shm (typically RAM-backed on Linux) is
# used when it exists; elsewhere (e.g. macOS, Windows) fall back to the temp directory.
_image_cache_root = '/dev/shm' if os.path.isdir('/dev/shm') else tempfile.gettempdir()
mem = joblib.Memory(os.path.join(_image_cache_root, 'joblib'), verbose=0)

@mem.cache
def getImage(url):
    """Download the image at ``url`` and return it as a 224x224 RGB ``PIL.Image``.

    224x224 is the size show_pic feeds to ResNet50. Results are memoised by joblib
    on ``url``, so revisiting an image does not download it again.
    """
    # Read the body inside ``with`` so the HTTP response is always closed, then let
    # PIL decode from an in-memory buffer.
    with urllib.request.urlopen(url) as response:
        img_data = PIL.Image.open(io.BytesIO(response.read()))
    if img_data.mode != 'RGB':
        img_data = img_data.convert('RGB')
    img_data = img_data.resize((224,224))
    return img_data

In [None]:
def get_prob_sets(preds, eps):
    """Smallest top-probability label set of each row whose mass reaches ``1 - eps``.

    Labels are added in decreasing-probability order. A label is included while the
    probability mass accumulated *before* it is still below ``1 - eps``, so the label
    that crosses the threshold is part of the set. A single confident prediction
    therefore yields a one-label set, not an empty one.

    Parameters
    ----------
    preds : array-like of shape (n_objects, n_labels)
        Class probabilities, one row per object.
    eps : float
        Allowed missing probability mass; ``eps >= 1`` gives empty sets.

    Returns
    -------
    list of ``(labels, probs)`` tuples, one per row: label indices in decreasing
    probability order and their probabilities.
    """
    preds = np.asarray(preds)
    order = np.argsort(-preds, axis=1)
    sorted_probs = np.take_along_axis(preds, order, axis=1)

    # Exclusive cumulative sum: mass of all higher-ranked labels.
    mass_before = np.zeros_like(sorted_probs)
    mass_before[:, 1:] = np.cumsum(sorted_probs[:, :-1], axis=1)
    set_masks = mass_before < 1 - eps

    return [(row_order[m], row_probs[m])
            for row_order, row_probs, m in zip(order, sorted_probs, set_masks)]

In [None]:
def _image_png_bytes(img_data):
    """Encode a PIL image as PNG bytes for the ipywidgets ``Image`` widget."""
    buffer = io.BytesIO()
    img_data.save(buffer, format="PNG")
    return buffer.getvalue()


def _resnet50_lines(test_preds, eps):
    """Format ResNet50's cumulative-probability label set as "name: prob" lines."""
    labels, probs = get_prob_sets(test_preds.reshape(1, -1), eps=eps)[0]
    return ["%s: %0.3f" % (ki_to_synset[k], pr) for k, pr in zip(labels, probs)]


def _cp_lines(test_preds, eps):
    """Format the MICP prediction set (labels with p-value > eps) as "name:p" lines, highest p first."""
    # NOTE(review): the calibration data (preds, ground_truth_ki) appears to cover every
    # validation image, including the one displayed -- confirm whether it should be held out.
    # micp_pValues returns shape (num_labels, num_test_objects); here there is one test object.
    p_vals = micp_pValues(preds, test_preds, ground_truth_ki)[:, 0]
    in_set = np.flatnonzero(p_vals > eps)
    in_set = in_set[np.argsort(p_vals[in_set])[::-1]]
    return [ki_to_synset[k] + ":%0.3f" % p_vals[k] for k in in_set]


def show_pic(i, eps):
    """Display validation image ``i`` and refresh every GUI panel for level ``eps``.

    ``i`` is the 1-based image number (as in ``ILSVRC2012_val_%08d.JPEG``). Reads the
    globals model, preds, ground_truth_ki and ki_to_synset, and writes to the widgets
    img, desc, resnet50 and CP.
    """
    url = "https://cml.rhul.ac.uk/people/ptocca/ILSVRC2012-CP/img/ILSVRC2012_val_%08d.JPEG" % i
    img_data = getImage(url)
    img.value = _image_png_bytes(img_data)

    # ResNet50 forward pass on the 224x224 RGB image, with ResNet50's own preprocessing.
    x = preprocess_input(np.expand_dims(keras_image.img_to_array(img_data), axis=0))
    test_preds = model.predict(x, verbose=0)  # verbose=0: no progress bar per call

    resnet_lines = _resnet50_lines(test_preds, eps)
    cp_lines = _cp_lines(test_preds, eps)

    # Update the label panels only after every computation has succeeded.
    # Image numbers are 1-based; ground_truth_ki is 0-based.
    desc.children[1].value = ki_to_synset[ground_truth_ki[i - 1]]
    resnet50.children[0].value = "ResNet50 at cumul prob %0.2f" % (1 - eps)
    resnet50.children[1].value = "\n".join(resnet_lines)
    CP.children[0].value = "Conformal Predictor at significance level %0.2f" % eps
    CP.children[1].value = "\n".join(cp_lines)

In [None]:
# Ground-truth label panel (its Textarea is filled in by show_pic).
desc = VBox([Label("True label"),
             Textarea("N/A",layout=Layout(height="100%"))])

# The current validation image, rendered as PNG by show_pic.
img = Image(layout=Layout(height="400px",width="400px"))

# ResNet50 and conformal-predictor panels; show_pic rewrites both Label and Textarea.
resnet50 = VBox([Label("ResNet50 Probability (top 5)"),
                 Textarea(layout=Layout(height="100%"))])
CP = VBox([Label("Conformal Predictor at significance level 'eps'"),
           Textarea(layout=Layout(height="100%"))])
labels = HBox([desc,resnet50,CP],layout=Layout(height="200px",align_content="stretch"))

# 1-based validation image number and significance level eps.
pic_idx = IntSlider(value=500, min=1, max=2000, continuous_update=False,
                    description="image #",
                    layout=Layout(width="90%", align_items='center'))
eps_slider = FloatSlider(value=0.1, min=0.01, max=1.0, step=0.01, continuous_update=False,
                         description="eps",
                         layout=Layout(width="90%", align_items='center'))

gui = VBox([img,pic_idx,eps_slider,labels],layout=Layout(align_items='center'))

# Initialise the panels from the sliders' current values so the displayed image
# always matches what the sliders show.
show_pic(pic_idx.value, eps_slider.value)


In [None]:
from ipywidgets import interactive_output

# Wire both sliders to show_pic. The returned Output widget captures whatever
# show_pic prints or raises; showing it under the GUI keeps failures (e.g. a failed
# image download) visible instead of lost in an undisplayed widget.
pic_output = interactive_output(show_pic, {"i": pic_idx, "eps": eps_slider})

VBox([gui, pic_output])