In [7]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, optimizers, layers, metrics, Sequential

In [8]:
# A custom layer is implemented by subclassing layers.Layer.
class MyDense(layers.Layer):
    """Fully connected layer computing ``inputs @ kernel + bias`` with no activation.

    Args:
        input_dim: Size of the last axis of the incoming tensor (rows of the kernel).
        output_dim: Number of output units (columns of the kernel, length of the bias).
        **kwargs: Forwarded unchanged to ``layers.Layer.__init__`` (e.g. ``name``).
    """

    def __init__(self, input_dim, output_dim, **kwargs):
        super().__init__(**kwargs)
        # Create w and b through Layer.add_weight so that Keras tracks them as
        # variables of this layer. Arguments are passed by keyword because the
        # positional order of add_weight differs between Keras versions
        # (name first in tf.keras 2.x, shape first in Keras 3).
        self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim])
        self.bias = self.add_weight(name='b', shape=[output_dim])

    # Calling mydense(x) ==> mydense.__call__(x) ==> mydense.call(x)
    def call(self, inputs, training=None):
        """Apply the linear transform of this layer.

        Args:
            inputs: Tensor whose last axis has size ``input_dim``.
            training: Accepted for Keras API compatibility; not used here.

        Returns:
            ``inputs @ self.kernel + self.bias``.
        """
        return inputs @ self.kernel + self.bias

In [9]:
# A custom model is implemented by subclassing keras.Model.
class MyModel(keras.Model):
    """Five-layer fully connected network built from MyDense layers.

    Layer widths: 32*32*3 -> 256 -> 128 -> 64 -> 32 -> 10. ReLU follows each of
    the first four layers; the last layer's output is returned without an
    activation.
    """

    def __init__(self):
        super().__init__()
        # Each layer is kept as its own attribute (fc1 .. fc5).
        flat_dim = 32*32*3
        self.fc1 = MyDense(flat_dim, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    # Calling mymodel(x) ==> mymodel.__call__(x) ==> mymodel.call(x)
    def call(self, inputs, training=None):
        """Forward pass of the network.

        Args:
            inputs: Tensor that can be reshaped to ``[-1, 32*32*3]``.
            training: Accepted for Keras API compatibility; not used here.

        Returns:
            Tensor of shape ``[b, 10]`` produced by the final layer.
        """
        # Flatten every sample into a row: [b, 32*32*3]
        hidden = tf.reshape(inputs, [-1, 32*32*3])
        # [b, 32*32*3] => [b, 256] => [b, 128] => [b, 64] => [b, 32], ReLU after each
        for dense in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = tf.nn.relu(dense(hidden))
        # [b, 32] => [b, 10], no activation on the output layer
        return self.fc5(hidden)

In [10]:
# Data preprocessing
def process(x, y):
    """Rescale the image and one-hot encode the label.

    Args:
        x: Image tensor; cast to float32 and mapped through ``2 * x / 255 - 1``
            (values in 0..255 end up in -1..1).
        y: Label tensor; cast to int32, squeezed, then one-hot encoded with depth 10.

    Returns:
        Tuple ``(x, y)`` of the transformed image and the one-hot label.
    """
    # Author's note: for CIFAR-10, normalizing to the range -1..1 trains better.
    pixels = tf.cast(x, dtype=tf.float32)
    pixels = 2 * pixels / 255. - 1

    labels = tf.squeeze(tf.cast(y, dtype=tf.int32))
    labels = tf.one_hot(labels, depth=10)

    return pixels, labels

In [11]:
# Load the CIFAR-10 train and test splits as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

# Training pipeline: preprocess each sample, shuffle with a 10000-element
# buffer, then group into batches of 128.
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.map(process)
train_db = train_db.shuffle(10000)
train_db = train_db.batch(128)

# Test pipeline: same preprocessing and batch size, but no shuffling.
test_db = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(process)
    .batch(128)
)

In [20]:
# Instantiate the custom model.
network = MyModel()
# Configure training. compile() is available because MyModel subclasses keras.Model.
# The loss uses from_logits=True because the model's last layer has no softmax.
# `learning_rate` is the supported argument name; `lr` is a deprecated alias
# that newer Keras versions reject.
network.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

In [21]:
# Train the model.
# Training for more epochs can improve the result, but only up to a ceiling.
# Increasing the number of parameters per layer raises model capacity, at the
# risk of overfitting.
# The normalization applied to x can also be tuned in the process function.
network.fit(train_db, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7f7b17140048>

In [22]:
# Evaluate the model on test_db; with the compile settings used in this
# notebook the displayed list is [loss, accuracy].
network.evaluate(test_db)



[1.3888207191153417, 0.5127]

In [15]:
# Save all trained parameters to a checkpoint file.
network.save_weights('ckpt/weights.ckpt')
# Delete the model object; only the weights on disk remain.
del network

In [16]:
# A model saved with save_weights stores parameters only, so an identical
# network has to be rebuilt first and the parameters loaded into it.
network = MyModel()
# `learning_rate` is the supported argument name; `lr` is a deprecated alias
# that newer Keras versions reject.
network.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

network.load_weights('ckpt/weights.ckpt')
# Evaluation is expected to match the result of the model that was saved.
network.evaluate(test_db)



[1.3963348352456395, 0.507]

In [27]:
# Alternatively, export the whole network in the SavedModel format. This is a
# general-purpose format meant for deployment elsewhere; the author notes that
# TensorFlow Serving can load it.
tf.saved_model.save(network, 'saved_mode/')

INFO:tensorflow:Assets written to: saved_mode/assets


In [43]:
# Load the exported model back from disk.
imported = tf.saved_model.load('saved_mode/')
# Look up the 'serving_default' signature, which gives a callable object.
predictor = imported.signatures['serving_default']

In [44]:
# Run a prediction on a dummy input (a tensor of ones with shape [32, 32, 3]).
predictor(tf.ones([32,32,3]))

{'output_1': <tf.Tensor: id=35111, shape=(1, 10), dtype=float32, numpy=
 array([[ 1.7184494 , -7.043526  , -0.38373035,  0.98874104, -2.973388  ,
          0.6442143 , -5.308611  , -2.715285  , -1.3639593 , -5.4435244 ]],
       dtype=float32)>}

In [50]:
# A subclass of keras.Model cannot be saved directly in the HDF5 (.h5) format.
# The failure is the point of this cell, so the error is caught and displayed
# instead of being left uncaught, which would stop a "Run All" of the notebook.
try:
    network.save('saved_model.h5')
except NotImplementedError as err:
    print('NotImplementedError:', err)

NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format="tf") or using `save_weights`.