In [1]:
# Load (or reload) the autoreload extension and use mode 2, which re-imports
# modules before each cell executes so edits to local .py files are picked up.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline

In [2]:
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import cifar10
import numpy as np

In [3]:
class LeNetBlock(keras.Model):
    """One LeNet stage: a 5x5 convolution followed by optional average pooling.

    Args:
        nf: Number of convolution filters (output channels).
        input_shape: Optional per-sample input shape. When truthy it is
            forwarded to the ``Conv2D`` layer; otherwise the layer infers its
            input shape on first call.
        padding: Padding mode passed to ``Conv2D`` (``'valid'`` or ``'same'``).
        pool: If True, apply ``AveragePooling2D`` (Keras defaults) after the
            convolution.
        activation: Activation applied by the convolution. Defaults to
            ``'tanh'``. Without a non-linearity the convolution and the
            average pooling are both linear, so stacked blocks collapse into a
            single linear map. Pass ``None`` to get the previous, purely
            linear behaviour.
    """

    def __init__(self, nf, input_shape=None, padding='valid', pool:bool=True, activation='tanh'):
        super(LeNetBlock, self).__init__()
        self.is_pool = pool

        # Build the Conv2D arguments once; only add input_shape when one was given
        # (a truthiness test already covers None and empty tuples).
        conv_kwargs = dict(filters=nf, kernel_size=(5, 5), padding=padding, activation=activation)
        if input_shape:
            conv_kwargs['input_shape'] = input_shape
        self.conv = layers.Conv2D(**conv_kwargs)

        if pool:
            self.pool = layers.AveragePooling2D()

    def call(self, x):
        """Apply the convolution (with its activation), then pooling if enabled."""
        x = self.conv(x)
        if self.is_pool:
            x = self.pool(x)
        return x

In [8]:
class LeNet(keras.Model):
    """LeNet-style image classifier built from three ``LeNetBlock`` stages.

    The body stacks blocks with 6, 16 and 120 filters (the last one without
    pooling). Its output is flattened, passed through an 84-unit dense layer
    and finally a softmax layer that yields class probabilities.

    Args:
        input_shape: Per-sample input shape, forwarded to the first block
            (e.g. ``(32, 32, 3)``).
        num_classes: Number of output classes, i.e. units in the softmax
            layer. Defaults to 10, the value that was previously hard-coded.
    """

    def __init__(self, input_shape, num_classes=10):
        super(LeNet, self).__init__(name='LeNet')

        self.body = keras.Sequential([LeNetBlock(nf=6, input_shape=input_shape),
                                     LeNetBlock(nf=16),
                                     LeNetBlock(nf=120, pool=False)])
        self.flatten = layers.Flatten()
        # A Dense layer has no activation by default; without one this layer
        # and the output layer would merge into a single linear map.
        self.dense = layers.Dense(84, activation='tanh')
        self.out = layers.Dense(num_classes, activation='softmax')

    def call(self, x):
        """Run the forward pass and return per-class probabilities."""
        x = self.body(x)
        x = self.flatten(x)
        x = self.dense(x)
        x = self.out(x)
        return x

In [9]:
# Instantiate the network for 32x32 inputs with 3 channels, matching the
# per-sample shape of the CIFAR-10 arrays loaded later in this notebook.
model = LeNet(input_shape=(32, 32, 3))

In [10]:
# Create the weights so that model.summary() can report parameter counts.
# The batch dimension is None (variable) rather than a fixed 1, because the
# model is later trained and evaluated with other batch sizes.
model.build((None, 32, 32, 3))

In [11]:
# Print each top-level layer with its parameter count, plus the totals.
model.summary()

Model: "LeNet"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_1 (Sequential)    multiple                  50992     
_________________________________________________________________
flatten_1 (Flatten)          multiple                  0         
_________________________________________________________________
dense_2 (Dense)              multiple                  10164     
_________________________________________________________________
dense_3 (Dense)              multiple                  850       
Total params: 62,006
Trainable params: 62,006
Non-trainable params: 0
_________________________________________________________________


# Data (cifar10)

In [12]:
# Load CIFAR-10 through Keras. It returns two (images, labels) pairs; the
# second pair is used as the validation set in this notebook.
(x_train, y_train), (x_valid, y_valid) = cifar10.load_data()

In [13]:
# Display the shape of every split: images are (N, 32, 32, 3), labels are (N, 1).
tuple(split.shape for split in (x_train, y_train, x_valid, y_valid))

((50000, 32, 32, 3), (50000, 1), (10000, 32, 32, 3), (10000, 1))

In [14]:
# Scale pixel intensities from [0, 255] to [0, 1].
# The uint8 guard makes this cell idempotent: re-running it will not divide
# already-scaled data by 255 a second time. Casting to float32 first avoids
# NumPy's promotion to float64, which would double the memory footprint.
if x_train.dtype == np.uint8:
    x_train = x_train.astype('float32') / 255.0
if x_valid.dtype == np.uint8:
    x_valid = x_valid.astype('float32') / 255.0

In [15]:
# Drop the trailing singleton axis so the labels go from (N, 1) to (N,).
y_train = np.squeeze(y_train)
y_valid = np.squeeze(y_valid)

In [None]:
# Sanity check: list the distinct label values present in the training labels.
np.unique(y_train)

In [16]:
# Labels are integer class ids (not one-hot vectors), hence the sparse
# variant of categorical cross-entropy.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [17]:
# Train with mini-batches of 16 and evaluate on the validation split.
# No `epochs` argument is passed, so the Keras default applies.
model.fit(
    x_train,
    y_train,
    batch_size=16,
    validation_data=(x_valid, y_valid),
)

Train on 50000 samples, validate on 10000 samples


<tensorflow.python.keras.callbacks.History at 0x1e782494128>