In [1]:
import tensorflow as tf

from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import Conv2D, AveragePooling2D, ZeroPadding2D
from tensorflow.keras.layers import Flatten, Dense

In [2]:
class FeatureExtractor(Layer):
  """LeNet 1, 4, 5에 공통으로 입력되는 layer들을 sub-classing으로 하나의 object로 만듦"""
  def __init__(self, filter1, filter2):
    super(FeatureExtractor, self).__init__()
    
    self.conv1 = Conv2D(filters=filter1, kernel_size=5, padding='valid', strides=1, activation='tanh')
    self.conv1_pool = AveragePooling2D(pool_size=2, strides=2)
    self.conv2 = Conv2D(filters=filter2, kernel_size=5, padding='valid', strides=1, activation='tanh')
    self.conv2_pool = AveragePooling2D(pool_size=2, strides=2)
  
  def call(self, x):
    x = self.conv1(x)
    x = self.conv1_pool(x)
    x = self.conv2(x)
    x = self.conv2_pool(x)
    return x

In [3]:
class LeNet1(Model):
  def __init__(self):
    super(LeNet1, self).__init__()
    
    # feature extractor
    self.feature_extractor = FeatureExtractor(4, 12)
    
    # classifier
    self.classifier = Sequential()
    self.classifier.add(Flatten())
    self.classifier.add(Dense(units=10, activation='softmax'))
    
  def call(self, x):
    x = self.feature_extractor(x)
    x = self.classifier(x)
    return x
  
model = LeNet1()
model.build(input_shape=(None, 28, 28, 1))
model.summary()

Model: "le_net1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
feature_extractor (FeatureEx multiple                  1316      
_________________________________________________________________
sequential (Sequential)      (None, 10)                1930      
Total params: 3,246
Trainable params: 3,246
Non-trainable params: 0
_________________________________________________________________


In [4]:
class LeNet4(Model):
  def __init__(self):
    super(LeNet4, self).__init__()
    
    # feature extractor
    self.zero_padding = ZeroPadding2D(padding=2)
    self.feature_extractor = FeatureExtractor(4, 16)
    
    # classifier
    self.classifier = Sequential()
    self.classifier.add(Flatten())
    self.classifier.add(Dense(units=120, activation='tanh'))
    self.classifier.add(Dense(units=10, activation='softmax'))
    
  def call(self, x):
    x = self.zero_padding(x)
    x = self.feature_extractor(x)
    x = self.classifier(x)
    return x
  
model = LeNet4()
model.build(input_shape=(None, 28, 28, 1))
model.summary()

Model: "le_net4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
zero_padding2d (ZeroPadding2 multiple                  0         
_________________________________________________________________
feature_extractor_1 (Feature multiple                  1720      
_________________________________________________________________
sequential_1 (Sequential)    (None, 10)                49330     
Total params: 51,050
Trainable params: 51,050
Non-trainable params: 0
_________________________________________________________________


In [5]:
class LeNet5(Model):
  def __init__(self):
    super(LeNet5, self).__init__()
    
    # feature extractor
    self.zero_padding = ZeroPadding2D(padding=2)
    self.feature_extractor = FeatureExtractor(6, 16)
    
    # classifier
    self.classifier = Sequential()
    self.classifier.add(Flatten())
    self.classifier.add(Dense(units=140, activation='tanh'))
    self.classifier.add(Dense(units=84, activation='tanh'))
    self.classifier.add(Dense(units=10, activation='softmax'))
    
  def call(self, x):
    x = self.zero_padding(x)
    x = self.feature_extractor(x)
    x = self.classifier(x)
    return x

model = LeNet5()
model.build(input_shape=(None, 28, 28, 1))
model.summary()

Model: "le_net5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
zero_padding2d_1 (ZeroPaddin multiple                  0         
_________________________________________________________________
feature_extractor_2 (Feature multiple                  2572      
_________________________________________________________________
sequential_2 (Sequential)    (None, 10)                68834     
Total params: 71,406
Trainable params: 71,406
Non-trainable params: 0
_________________________________________________________________
