In [1]:
import math
from tensorflow.keras import datasets

In [16]:
EPOCHS = 50
BATCH_SIZE = 8
NUM_CLASSES = 10
image_height = 32
image_width = 32
channels = 3
save_model_dir = "saved_model/model"
dataset_dir = "dataset/"
train_dir = dataset_dir + "train"
valid_dir = dataset_dir + "valid"
test_dir = dataset_dir + "test"

# choose a network
# model = "resnet18"
# model = "resnet34"
model = "resnet50"
# model = "resnet101"
# model = "resnet152"

In [30]:
import tensorflow as tf

class BasicBlock(tf.keras.layers.Layer):

    def __init__(self, filter_num, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(filters=filter_num,
                                            kernel_size=(3, 3),
                                            strides=stride,
                                            padding="same")
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.Conv2D(filters=filter_num,
                                            kernel_size=(3, 3),
                                            strides=1,
                                            padding="same")
        self.bn2 = tf.keras.layers.BatchNormalization()
        if stride != 1:
            self.downsample = tf.keras.Sequential()
            self.downsample.add(tf.keras.layers.Conv2D(filters=filter_num,
                                                       kernel_size=(1, 1),
                                                       strides=stride))
            self.downsample.add(tf.keras.layers.BatchNormalization())
        else:
            self.downsample = lambda x: x

    def call(self, inputs, training=None, **kwargs):
        residual = self.downsample(inputs)

        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=training)

        output = tf.nn.relu(tf.keras.layers.add([residual, x]))

        return output


class PreActBlock(tf.keras.layers.Layer):

    def __init__(self, filter_num, stride=1):
        super(PreActBlock, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(filters=filter_num,
                                            kernel_size=(3, 3),
                                            strides=stride,
                                            padding="same")
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.Conv2D(filters=filter_num,
                                            kernel_size=(3, 3),
                                            strides=1,
                                            padding="same")
        self.bn2 = tf.keras.layers.BatchNormalization()
        if stride != 1:
            self.downsample = tf.keras.Sequential()
            self.downsample.add(tf.keras.layers.Conv2D(filters=filter_num,
                                                       kernel_size=(1, 1),
                                                       strides=stride))
        else:
            self.downsample = lambda x: x

    def call(self, inputs, training=None, **kwargs):
        x = tf.nn.relu(self.bn1(inputs, training=training))
        residual = self.downsample(x)

        x = self.conv1(x)
        x = self.bn2(x, training=training)
        x = tf.nn.relu(x)
        x = self.conv2(x)

        output = tf.nn.relu(tf.keras.layers.add([residual, x]))

        return output


def make_basic_block_layer(filter_num, blocks, stride=1):
    res_block = tf.keras.Sequential()
    res_block.add(BasicBlock(filter_num, stride=stride))

    for _ in range(1, blocks):
        res_block.add(BasicBlock(filter_num, stride=1))

    return res_block


def make_preact_block_layer(filter_num, blocks, stride=1):
    res_block = tf.keras.Sequential()
    res_block.add(PreActBlock(filter_num, stride=stride))

    for _ in range(1, blocks):
        res_block.add(PreActBlock(filter_num, stride=1))

    return res_block

class ResNetTypeI(tf.keras.Model):
    def __init__(self, layer_params):
        super(ResNetTypeI, self).__init__()

        self.conv1 = tf.keras.layers.Conv2D(filters=64,
                                            kernel_size=(7, 7),
                                            strides=2,
                                            padding="same")
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.pool1 = tf.keras.layers.MaxPool2D(pool_size=(3, 3),
                                               strides=2,
                                               padding="same")

        self.layer1 = make_basic_block_layer(filter_num=64,
                                             blocks=layer_params[0])
        self.layer2 = make_basic_block_layer(filter_num=128,
                                             blocks=layer_params[1],
                                             stride=2)
        self.layer3 = make_basic_block_layer(filter_num=256,
                                             blocks=layer_params[2],
                                             stride=2)
        self.layer4 = make_basic_block_layer(filter_num=512,
                                             blocks=layer_params[3],
                                             stride=2)

        self.avgpool = tf.keras.layers.GlobalAveragePooling2D()
        self.fc = tf.keras.layers.Dense(units=NUM_CLASSES, activation=tf.keras.activations.softmax)

    def call(self, inputs, training=None, mask=None):
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = self.pool1(x)
        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)
        x = self.avgpool(x)
        output = self.fc(x)

        return output


class PreActResNet(tf.keras.Model):
    def __init__(self, layer_params):
        super(PreActResNet, self).__init__()

        self.conv1 = tf.keras.layers.Conv2D(filters=64,
                                            kernel_size=(7, 7),
                                            strides=2,
                                            padding="same")
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.pool1 = tf.keras.layers.MaxPool2D(pool_size=(3, 3),
                                               strides=2,
                                               padding="same")

        self.layer1 = make_preact_block_layer(filter_num=64,
                                             blocks=layer_params[0])
        self.layer2 = make_preact_block_layer(filter_num=128,
                                             blocks=layer_params[1],
                                             stride=2)
        self.layer3 = make_preact_block_layer(filter_num=256,
                                             blocks=layer_params[2],
                                             stride=2)
        self.layer4 = make_preact_block_layer(filter_num=512,
                                             blocks=layer_params[3],
                                             stride=2)

        self.avgpool = tf.keras.layers.GlobalAveragePooling2D()
        self.fc = tf.keras.layers.Dense(units=NUM_CLASSES, activation=tf.keras.activations.softmax)

    def call(self, inputs, training=None, mask=None):
        x = self.conv1(inputs)
        x = self.bn1(x, training=training)
        x = tf.nn.relu(x)
        x = self.pool1(x)
        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)
        x = self.avgpool(x)
        output = self.fc(x)

        return output


In [31]:
def resnet_18():
    return ResNetTypeI(layer_params=[2, 2, 2, 2])

def preact_resnet_18():
    return PreActResNet(layer_params=[2, 2, 2, 2])

def get_model():
    model = preact_resnet_18()
    model.build(input_shape=(None, image_height, image_width, channels))
    model.summary()
    return model

model = get_model()


Model: "pre_act_res_net_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_61 (Conv2D)           multiple                  9472      
_________________________________________________________________
batch_normalization_58 (Batc multiple                  256       
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 multiple                  0         
_________________________________________________________________
sequential_22 (Sequential)   (None, 8, 8, 64)          148736    
_________________________________________________________________
sequential_23 (Sequential)   (None, 4, 4, 128)         526720    
_________________________________________________________________
sequential_25 (Sequential)   (None, 2, 2, 256)         2102016   
_________________________________________________________________
sequential_27 (Sequential)   (None, 1, 1, 512)   

In [32]:
@tf.function
def train_step(images, labels):

    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_object(y_true=labels, y_pred=predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(grads_and_vars=zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()


train_images = tf.convert_to_tensor(train_images, dtype=tf.float32)
train_labels = tf.convert_to_tensor(train_labels, dtype=tf.float32)
train_dataset = tf.data.Dataset.from_tensor_slices((train_images/255, train_labels)).batch(batch_size=BATCH_SIZE)

loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adadelta()

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

for epoch in range(3):
        train_loss.reset_states()
        train_accuracy.reset_states()
        # valid_loss.reset_states()
        # valid_accuracy.reset_states()
        step = 0
        for images, labels in train_dataset:            
            step += 1
            train_step(images, labels)
            print("Epoch: {}/{}, step: {}, loss: {:.5f}, accuracy: {:.5f}".format(epoch + 1,
                                                                                     10,
                                                                                     step,
                                                                                    #  math.ceil(train_count / config.BATCH_SIZE),
                                                                                     train_loss.result(),
                                                                                     train_accuracy.result()))



        print("Epoch: {}/{}, train loss: {:.5f}, train accuracy: {:.5f}, ".format(epoch + 1,
                                                                  10,
                                                                  train_loss.result(),
                                                                  train_accuracy.result()
                                                                 ))


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch: 1/10, step: 1116, loss: 2.64442, accuracy: 0.11940
Epoch: 1/10, step: 1117, loss: 2.64455, accuracy: 0.11929
Epoch: 1/10, step: 1118, loss: 2.64493, accuracy: 0.11930
Epoch: 1/10, step: 1119, loss: 2.64490, accuracy: 0.11919
Epoch: 1/10, step: 1120, loss: 2.64508, accuracy: 0.11920
Epoch: 1/10, step: 1121, loss: 2.64489, accuracy: 0.11920
Epoch: 1/10, step: 1122, loss: 2.64465, accuracy: 0.11921
Epoch: 1/10, step: 1123, loss: 2.64488, accuracy: 0.11921
Epoch: 1/10, step: 1124, loss: 2.64461, accuracy: 0.11911
Epoch: 1/10, step: 1125, loss: 2.64463, accuracy: 0.11900
Epoch: 1/10, step: 1126, loss: 2.64475, accuracy: 0.11912
Epoch: 1/10, step: 1127, loss: 2.64497, accuracy: 0.11912
Epoch: 1/10, step: 1128, loss: 2.64545, accuracy: 0.11902
Epoch: 1/10, step: 1129, loss: 2.64585, accuracy: 0.11891
Epoch: 1/10, step: 1130, loss: 2.64567, accuracy: 0.11892
Epoch: 1/10, step: 1131, loss: 2.64590, accuracy: 0.11892
Epoch: 

KeyboardInterrupt: ignored