In [1]:
import tensorflow as tf
import numpy as np

## 1.张量方式(save_weights/load_weights)

网络状态体现在网络的结构以及网络层的参数张量数据上。仅保存参数张量是最轻量的方式，但文件中不包含网络结构，恢复时必须先用代码创建结构完全相同的网络，再载入参数：

In [2]:
from tensorflow import keras
from tensorflow.keras import layers,datasets
# Fully connected classifier: four ReLU hidden layers narrowing from 256 to 32
# units, followed by a 10-unit output layer with no activation.
network=keras.Sequential(
    [layers.Dense(units,activation='relu') for units in (256,128,64,32)]
    +[layers.Dense(10)]
)
# Build for inputs of 28*28 flattened features; the leading 4 is the batch
# dimension handed to build().
network.build(input_shape=(4,28*28))

from tensorflow.keras import optimizers,losses

# Adam optimizer with learning rate 0.01. The cross-entropy loss is created
# with from_logits=True because the final Dense(10) layer has no activation.
# `learning_rate` is the supported keyword; `lr` is a deprecated alias that
# newer Keras releases no longer accept.
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'] # report classification accuracy
                )

def preprocess(x,y):
    """Convert raw images and integer labels into model-ready tensors.

    Intended as the function passed to ``Dataset.map``, which supplies
    ``x`` and ``y`` automatically.

    Args:
        x: image tensor; cast to float32, divided by 255 and reshaped to
            rows of 28*28 values.
        y: label tensor; cast to int32 and one-hot encoded with depth 10.

    Returns:
        Tuple of (flattened float32 images, one-hot labels).
    """
    scaled=tf.cast(x,dtype=tf.float32)/255.
    flat=tf.reshape(scaled,[-1,28*28])
    one_hot=tf.one_hot(tf.cast(y,dtype=tf.int32),depth=10)
    return flat,one_hot

def load_data(batchsz=512,shuffle_buffer=1000,repeat_count=20):
    """Build the MNIST training and validation ``tf.data`` pipelines.

    The defaults reproduce the previously hard-coded settings, so calling
    ``load_data()`` with no arguments behaves as before.

    Args:
        batchsz: number of samples per batch, used for both datasets.
        shuffle_buffer: buffer size passed to ``shuffle`` for both datasets.
        repeat_count: number of times the training dataset is repeated, so
            a single pass over it covers ``repeat_count`` passes over the
            training samples.

    Returns:
        Tuple ``(train_dataset, val_dataset)``; both yield batches that have
        already been run through ``preprocess``.
    """
    (x,y),(x_val,y_val)=datasets.mnist.load_data()

    # preprocess is mapped after batch(), so it receives whole batches.
    train_dataset=(
        tf.data.Dataset.from_tensor_slices((x,y))
        .shuffle(shuffle_buffer)
        .batch(batchsz)
        .map(preprocess)
        .repeat(repeat_count)
    )

    # Validation/test split goes through the same steps, without repetition.
    val_dataset=(
        tf.data.Dataset.from_tensor_slices((x_val,y_val))
        .shuffle(shuffle_buffer)
        .batch(batchsz)
        .map(preprocess)
    )
    return train_dataset,val_dataset

train_db,val_db=load_data()
# With epochs=1, a validation_freq of 2 would never trigger and
# validation_data would go unused; validate after every epoch instead.
history=network.fit(train_db,epochs=1,validation_data=val_db,validation_freq=1)

# Write only the parameter tensors to disk
network.save_weights('weights.ckpt')
print('saved weights.')
# `network` is deliberately kept alive: it is exported again with
# tf.saved_model.save in the SavedModel section further down.

# Re-create the same architecture. A weights-only file carries no structure,
# so the layers must match the network that wrote 'weights.ckpt'.
network1=keras.Sequential([
    layers.Dense(256,activation='relu'),
    layers.Dense(128,activation='relu'),
    layers.Dense(64,activation='relu'),
    layers.Dense(32,activation='relu'),
    layers.Dense(10)
])
# `learning_rate` replaces the deprecated `lr` alias, which newer Keras
# releases no longer accept.
network1.compile(
    optimizer=optimizers.Adam(learning_rate=0.01),
    loss=losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
network1.build(input_shape=(4,28*28))

# Read the saved parameters and write them into the freshly built network
print('loading weights……')
network1.load_weights('weights.ckpt')
print('loaded weights!')

saved weights.
loading weights……
loaded weights!


## 2.网络方式(save/load_model)

保存网络的结构以及参数张量数据，即整个模型，如下：

In [3]:
# Persist the complete model to a single HDF5 file, optimizer included
network1.save('model.h5',include_optimizer=True)
print('saved total model.')

# Restore the whole network straight from the file.
# NOTE(review): the original cell recorded "AttributeError: 'str' object has
# no attribute 'decode'" next to load_model; the captured output shows the
# load succeeding, so this presumably depends on the installed versions — confirm.
print('loading model……')
network2=keras.models.load_model('model.h5')
print('loaded total!')
network2.summary()

# The restored model must reproduce the source model's output on the same input
x=tf.random.uniform([1,28*28])
pred_original=network1.predict(x)
pred_restored=network2.predict(x)
assert np.allclose(pred_original,pred_restored)

saved total model.
loading model……
loaded total!
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_5 (Dense)              (4, 256)                  200960    
_________________________________________________________________
dense_6 (Dense)              (4, 128)                  32896     
_________________________________________________________________
dense_7 (Dense)              (4, 64)                   8256      
_________________________________________________________________
dense_8 (Dense)              (4, 32)                   2080      
_________________________________________________________________
dense_9 (Dense)              (4, 10)                   330       
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________


## 3.SavedModel方式

SavedModel 方式将网络结构与参数一并导出，具有平台无关性：恢复时不需要原始的网络定义代码，便于部署到其他平台。

In [4]:
# Export the trained network in SavedModel format
tf.saved_model.save(network,'saved_model.sm')
print('model has been saved!')

# Load the exported model back from disk
network3=tf.saved_model.load('saved_model.sm')

# Accumulate categorical accuracy over every validation batch
acc_meter=tf.metrics.CategoricalAccuracy()
for x,y in val_db:
    pred=network3(x) # forward pass through the restored model
    acc_meter.update_state(y,pred)

# Report the accuracy gathered across all batches
print(f'Test Acc: {acc_meter.result()}')


INFO:tensorflow:Assets written to: saved_model.sm/assets
model has been saved!
Test Acc: 0.9747999906539917


In [None]:
# Force-terminate the process running this notebook: look up our own PID and
# send it signal 9 through IPython's `!` shell escape.
# NOTE(review): presumably done to release resources held by TensorFlow — confirm.
import os
pid=os.getpid()
!kill -9 $pid
