# View a CNN for CIFAR-10 classification

In [1]:
import sys
from tensorflow import keras
import plotly.graph_objects as go
from ipywidgets import widgets

In [2]:
# Optional during library development: uncomment the two magics below so edits
# to dnnviewerlib are picked up without restarting the kernel.
#%load_ext autoreload
#%autoreload 2

# Make the parent directory importable so `dnnviewerlib` can be found.
# Guarded so that re-running this cell does not pile up duplicate entries.
if '../' not in sys.path:
    sys.path.append('../')

In [6]:
from dnnviewerlib.Grapher import Grapher
import dnnviewerlib.layers
import dnnviewerlib.bridge.tensorflow as tf_bridge

# Load model

In [4]:
# Class labels handed to the grapher further down, one per output class.
cifar10_classes = [
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

# Restore the saved Keras model and print its layer-by-layer summary.
model0 = keras.models.load_model('models/CIFAR-10_CNN5.h5')
model0.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv_0 (Conv2D)              (None, 30, 30, 32)        896       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 15, 15, 32)        0         
_________________________________________________________________
conv_1 (Conv2D)              (None, 13, 13, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 6, 6, 64)          0         
_________________________________________________________________
conv_2 (Conv2D)              (None, 4, 4, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1024)              0

# Grapher test

In [7]:
# Build an interactive view of the network: a Plotly FigureWidget showing the
# layers extracted from `model0`, plus a slider that selects how many top
# connections are drawn.
fig_widget = go.FigureWidget()
fig_widget.update_layout(margin=dict(l=10, r=10, b=10, t=10))

grapher = Grapher()

# Create all other layers from the Keras Sequential model.
# The colour names are presumably the labels of the three input channels and
# `cifar10_classes` the labels of the outputs -- TODO confirm in dnnviewerlib.
tf_bridge.keras_extract_sequential_network(grapher, model0, ['red', 'green', 'blue'], cifar10_classes)

# IntSlider works on integers, so use int literals for its value and bounds.
topn = widgets.IntSlider(
    value=3,
    min=1,
    max=4,
    step=1,
    description='Top N:',
    continuous_update=False
)

# Third and fourth arguments forwarded to plot_topn_connections; their exact
# meaning is defined in dnnviewerlib (presumably the layer and the unit whose
# connections are highlighted -- TODO confirm).
active_layer = grapher.layers[2]
active_unit = 10

grapher.plot_layers(fig_widget)
grapher.plot_topn_connections(fig_widget, topn.value, active_layer, active_unit)


def set_topn(change):
    """Redraw the top-N connections when the slider value changes.

    Args:
        change: change notification passed by `topn.observe`; because the
            observer is registered with names='value', `change['new']` holds
            the new slider value.
    """
    # Pass the same arguments as the initial draw above; calling with only the
    # slider value would put an int where the figure widget is expected.
    with fig_widget.batch_update():
        grapher.plot_topn_connections(fig_widget, change['new'], active_layer, active_unit)


topn.observe(set_topn, names='value')

fig_widget.update_layout(barmode='overlay')
top_bar = widgets.HBox(children=[topn])
main_widget = widgets.VBox([top_bar, fig_widget])

# Last expression of the cell: renders the slider above the figure.
main_widget

Ignored max_pooling2d
Ignored max_pooling2d_1
Ignored dropout


VBox(children=(HBox(children=(IntSlider(value=3, continuous_update=False, description='Top N:', max=4, min=1),…

In [None]:
# Pick the layer at index 7 of the loaded model and show it as the cell result.
# NOTE(review): the summary printed above is cut off after `dropout` (index 6)
# in this view, so which layer index 7 refers to cannot be confirmed here.
# `l` is read again by the next cell.
l = model0.layers[7]
l

In [None]:
# Number of entries returned by the selected layer's get_weights()
# (for a Keras layer presumably one array per weight tensor -- TODO confirm).
# Relies on `l` having been assigned by the previous cell.
len(l.get_weights())