In [1]:
import os
import cv2
import torch
import numpy as np
import gradio as gr
import networkx as nx
import matplotlib.pyplot as plt

from PIL import Image
from nbdt.model import SoftNBDT
from utils import load_vgg16, plot_decision_tree

torch.manual_seed(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'


classes = (
    'airplane', 'car', 'bird', 'cat', 'deer', 
    'dog', 'frog', 'horse', 'ship', 'truck'
)

# load checkpoint
state_dict = torch.load('./SoftNBDT_model.pt', map_location=torch.device('cpu'))
model_weights = state_dict['state_dict']

model = load_vgg16(num_classes=10).to(device)
model_nbdt = SoftNBDT(model=model, dataset='CIFAR10', hierarchy='induced-vgg16')
model_nbdt.load_state_dict(model_weights)
model_nbdt.eval()

  from .autonotebook import tqdm as notebook_tqdm
stty: 'standard input': Inappropriate ioctl for device
  warn(


not enough values to unpack (expected 2, got 0)


Using cache found in /home/vichshir/.cache/torch/hub/pytorch_vision_v0.10.0


SoftNBDT(
  (rules): SoftEmbeddedDecisionRules()
  (model): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   

In [2]:
def plot_decision_tree(decisions, nbdt):
    graph = nbdt.rules.tree.G
    
    labeldict = {}
    for wnid, node in nbdt.rules.tree.wnid_to_node.items():
        labeldict[wnid] = node.name

    color_map = []
    for node_name in list(labeldict.values()):
        if node_name in list(map(lambda x: x['name'], decisions)):
            color_map.append('royalblue')
        else:
            color_map.append('whitesmoke')

    pos = nx.nx_agraph.graphviz_layout(graph, prog="dot")
    nx.draw(graph, pos, 
            labels=labeldict, 
            with_labels=True, 
            node_color=color_map, 
            node_size=1400, 
            font_color='black',
            font_weight='bold',
            font_family='sans-serif',
            font_size=8,
            edge_color='lightgray')
    
    path_pos = [pos[n['node'].wnid] for n in decisions[:-1]]
    for idx, d in enumerate(decisions[1:]):
        x, y = path_pos[idx]
        prob = d['prob']
        plt.text(x+5, y-30, s=f'Prob. {prob:.0%}', 
                 horizontalalignment='center', 
                 fontsize='x-small', 
                 color='darkcyan', 
                 fontweight='bold')
    
    plt.savefig('./temp_img.png')


def show_image(img):
    # preprocessing
    img = cv2.resize(img, dsize=(32, 32), interpolation=cv2.INTER_LINEAR)
    img_torch = (torch.tensor(img).movedim(-1, 0) / 255).unsqueeze(0).to(device)
    
    # get predicted label
    pred_label = classes[torch.argmax(model_nbdt(img_torch), dim=1)]
    
    # generate decision plot
    plot_decision_tree(model_nbdt.forward_with_decisions(img_torch)[1][0], model_nbdt)
    fig = Image.open('./temp_img.png')
    fig = np.asarray(fig)
    os.remove('./temp_img.png')
    
    return pred_label, fig


app = gr.Interface(
    fn=show_image,
    inputs=gr.Image(show_label=False),
    outputs=[
        gr.Label(label='Predicted Class'),
        gr.Image(label='Why?')
    ],
)

app.launch()

Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


