This article describes in detail how to save Keras models in TensorFlow and load them back.

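## Saving Weights
There are two formats:

- [HDF5](https://zhuanlan.zhihu.com/p/104145585) format: everything is saved in a single file
- TensorFlow format: the weights are saved as several checkpoint files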

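As a rough sketch of which format you get (assuming tf.keras in TensorFlow 2.x, as used in this article; Keras 3 changed the rules and, for example, expects `save_weights` paths to end in `.weights.h5`): a path ending in `.h5` produces a single HDF5 file, and any other path produces TensorFlow-format checkpoint files. The helpers `guess_save_format` and `expected_files` are hypothetical and only encode that rule for the single-shard case; they do not call TensorFlow.

```python
import os

# Hypothetical helpers (not TensorFlow APIs). Assumption: tf.keras in TF 2.x,
# where save_weights() without an explicit save_format writes HDF5 for a path
# ending in ".h5" and the TensorFlow checkpoint format otherwise.
def guess_save_format(filepath):
    """Return 'h5' for an HDF5-looking path, otherwise 'tf'."""
    return "h5" if filepath.endswith(".h5") else "tf"


def expected_files(filepath):
    """File names expected in the target directory (assumes a single shard)."""
    name = os.path.basename(filepath)
    if guess_save_format(filepath) == "h5":
        return [name]  # everything lives in one HDF5 file
    return sorted([
        "checkpoint",                    # text file naming the latest checkpoint
        name + ".index",                 # index of the saved variables
        name + ".data-00000-of-00001",   # the variable values (one shard)
    ])


print(expected_files("./checkpoints/h5_weights/weights.h5"))
# ['weights.h5']
print(expected_files("./checkpoints/tf_weights/weights"))
# ['checkpoint', 'weights.data-00000-of-00001', 'weights.index']
```

These names match the `ls -l` listings shown in the two subsections that follow.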
### HDF5 Format

First, train the model.


![image-20201112175426582](images/image-20201112175426582.png)

Then save the model weights and load them back.

```python
import os
import tensorflow as tf
from tensorflow import keras

# Limit TensorFlow to 1 GB (1024 MB) on the first GPU; skipped when no GPU is present
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
if gpus:
    tf.config.experimental.set_memory_growth(device=gpus[0], enable=True)
    tf.config.experimental.set_virtual_device_configuration(
        gpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)]
    )

def get_model():
    model = keras.models.Sequential()
    model.add(keras.layers.Flatten(input_shape=(28, 28)))
    model.add(keras.layers.Dense(128, activation='relu'))
    model.add(keras.layers.Dropout(0.2))
    model.add(keras.layers.Dense(10))  # raw logits, hence from_logits=True below
    model.compile(
        optimizer='adam',
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.metrics.SparseCategoricalAccuracy()],
        # metrics=['accuracy'],  # with this form, accuracy was computed incorrectly after load_model()
    )
    return model

def get_mnist():
    # Note: despite its name, this loads Fashion-MNIST
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0
    # Keep small subsets so the example runs quickly
    x_train, y_train = x_train[:1024], y_train[:1024]
    x_test, y_test = x_test[:512], y_test[:512]
    return x_train, y_train, x_test, y_test

def evaulate(model, x, y):
    loss, accuracy = model.evaluate(x, y, batch_size=256, verbose=0)
    print('loss is {:.3f}, accuracy is {:.3f}'.format(loss, accuracy))

x_train, y_train, x_test, y_test = get_mnist()

model = get_model()
print('-'*25 + 'Before Training' + '-'*25)
evaulate(model, x_test, y_test)

model.fit(
    x_train,
    y_train,
    batch_size=256,
    epochs=10,
    verbose=0,
    validation_split=0.5,
)

print('-'*25 + 'After Training' + '-'*25)
evaulate(model, x_test, y_test)
```

```text
-------------------------Before Training-------------------------
loss is 2.492, accuracy is 0.012
-------------------------After Training-------------------------
loss is 0.797, accuracy is 0.723
```


```python
# Save the weights to a single HDF5 file
checkpoint_dir = "./checkpoints/h5_weights"
checkpoint_path = os.path.join(checkpoint_dir, "weights.h5")
os.makedirs(checkpoint_dir, exist_ok=True)
model.save_weights(checkpoint_path)
! ls -l {checkpoint_dir}

# Load the weights into a freshly built (untrained) model
print('-'*50)
model = get_model()
model.load_weights(checkpoint_path)
evaulate(model, x_test, y_test)
```

```text
total 412
-rw-r--r--. 1 root root 421408 Nov 16 10:14 weights.h5
--------------------------------------------------
loss is 0.797, accuracy is 0.723
```


![image-20201112175611860](images/image-20201112175611860.png)

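### TensorFlow Checkpoint Format

Again, save the model weights and load them back. Because the path has no `.h5` suffix, `save_weights` writes the TensorFlow checkpoint format.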

```python
# Save the weights in the TensorFlow checkpoint format (no .h5 suffix)
checkpoint_dir = "./checkpoints/tf_weights"
checkpoint_path = os.path.join(checkpoint_dir, "weights")
os.makedirs(checkpoint_dir, exist_ok=True)
model.save_weights(checkpoint_path)
! ls -l {checkpoint_dir}

# Load the weights into a freshly built model
print('-'*50)
model = get_model()  # build the model
model.load_weights(checkpoint_path)
evaulate(model, x_test, y_test)
```

```text
total 408
-rw-r--r--. 1 root root     71 Nov 16 09:35 checkpoint
-rw-r--r--. 1 root root 407625 Nov 16 09:35 weights.data-00000-of-00001
-rw-r--r--. 1 root root    401 Nov 16 09:35 weights.index
--------------------------------------------------
loss is 0.805, accuracy is 0.725
```


![image-20201112175705517](images/image-20201112175705517.png)

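The code above produces three files:

- `checkpoint`: a small text file recording the path of the latest checkpoint (`model_checkpoint_path`) and of all retained checkpoints (`all_model_checkpoint_paths`). Its contents are:

  ```text
  model_checkpoint_path: "weights"
  all_model_checkpoint_paths: "weights"
  ```

- `weights.index`: the index file, which records where each variable is stored. In a distributed setting, the variables may be split across several shards.

- `weights.data-00000-of-00001`: the variable values themselves. With multiple shards there is one such file per shard.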

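The naming pattern and the `checkpoint` file can be illustrated in plain Python. The sketch below uses two hypothetical helpers: `shard_file_names` reproduces the `<prefix>.data-XXXXX-of-YYYYY` pattern for any number of shards, and `parse_checkpoint_state` reads the quoted `key: "value"` lines of the `checkpoint` file shown above. In real code, `tf.train.latest_checkpoint(checkpoint_dir)` reads that file for you.

```python
import re

def shard_file_names(prefix, num_shards):
    """Hypothetical helper: files written for a checkpoint split into num_shards shards."""
    data_files = [
        "{}.data-{:05d}-of-{:05d}".format(prefix, i, num_shards)
        for i in range(num_shards)
    ]
    return [prefix + ".index"] + data_files


def parse_checkpoint_state(text):
    """Hypothetical helper: collect the quoted `key: "value"` lines into lists."""
    state = {}
    for line in text.splitlines():
        match = re.match(r'\s*(\w+):\s*"(.*)"\s*$', line)
        if match:  # lines with unquoted values (e.g. timestamps) are skipped
            key, value = match.groups()
            state.setdefault(key, []).append(value)
    return state


print(shard_file_names("weights", 2))
# ['weights.index', 'weights.data-00000-of-00002', 'weights.data-00001-of-00002']

state = parse_checkpoint_state(
    'model_checkpoint_path: "weights"\n'
    'all_model_checkpoint_paths: "weights"\n'
)
print(state["model_checkpoint_path"][0])  # weights
```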
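## Saving the Entire Model

The overall approach is essentially the same as for saving weights, and again there are two formats:

- [HDF5](https://zhuanlan.zhihu.com/p/104145585) format: everything is saved in a single file
- TensorFlow SavedModel format: the model is saved as a directory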

### HDF5 Format

```python
# Start from a clean directory
! rm -rf ./checkpoints/h5_model
! ls -lh ./checkpoints/h5_model
```

```text
ls: cannot access './checkpoints/h5_model': No such file or directory
```


```python
# Save the entire model (architecture, weights, compile settings and optimizer state)
checkpoint_dir = "./checkpoints/h5_model"
checkpoint_path = os.path.join(checkpoint_dir, "weights.h5")
os.makedirs(checkpoint_dir, exist_ok=True)
model.save(checkpoint_path)
! ls -l {checkpoint_dir}
evaulate(model, x_test, y_test)

# Load the model back; load_model rebuilds it, so get_model() is not needed
print('-'*50)
model = tf.keras.models.load_model(checkpoint_path)
evaulate(model, x_test, y_test)
```

```text
total 416
-rw-r--r--. 1 root root 422808 Nov 16 09:36 weights.h5
loss is 0.805, accuracy is 0.725
--------------------------------------------------
loss is 0.805, accuracy is 0.725
```


```python
! ls -l {checkpoint_path}
```

```text
-rw-r--r--. 1 root root 422808 Nov 16 08:54 ./checkpoints/h5_model/weights.h5
```


### SavedModel Format

```python
def tree(path, indent=""):
    """Print the directory tree under path (the root as a full path, children by name)."""
    print(indent + (os.path.basename(path) if indent else path))
    if os.path.isdir(path):
        for child in os.listdir(path):
            tree(os.path.join(path, child), indent + "    ")

# Save the entire model; without an .h5 suffix, model.save writes a SavedModel directory
checkpoint_dir = "./checkpoints/tf_model"
checkpoint_path = os.path.join(checkpoint_dir, "weights")
os.makedirs(checkpoint_dir, exist_ok=True)
model.save(checkpoint_path)
evaulate(model, x_test, y_test)
print('-'*50)
tree(checkpoint_path)

# Load the model back; load_model rebuilds it, so get_model() is not needed
print('-'*50)
model = tf.keras.models.load_model(checkpoint_path)
evaulate(model, x_test, y_test)
```

```text
INFO:tensorflow:Assets written to: ./checkpoints/tf_model/weights/assets
loss is 0.805, accuracy is 0.725
--------------------------------------------------
./checkpoints/tf_model/weights
    variables
        variables.data-00000-of-00001
        variables.index
    assets
    saved_model.pb
--------------------------------------------------
loss is 0.805, accuracy is 0.725
```

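`os.listdir` returns entries in arbitrary order, so the `tree` output above may list `variables`, `assets` and `saved_model.pb` in a different order on another machine. A variant that sorts the entries and returns the lines instead of printing them gives reproducible output. In this sketch, `tree_lines` is a hypothetical name, and the temporary directory only imitates the SavedModel layout with empty files; it is not a real SavedModel.

```python
import os
import tempfile


def tree_lines(path, indent=""):
    """Return the directory tree under path as a list of indented lines."""
    # The root is shown as a full path, children by their base name only
    lines = [indent + (os.path.basename(path) if indent else path)]
    if os.path.isdir(path):
        for child in sorted(os.listdir(path)):  # sorted for a stable order
            lines.extend(tree_lines(os.path.join(path, child), indent + "    "))
    return lines


# Build a throwaway directory that imitates the SavedModel layout shown above
with tempfile.TemporaryDirectory() as root:
    export_dir = os.path.join(root, "weights")
    os.makedirs(os.path.join(export_dir, "variables"))
    os.makedirs(os.path.join(export_dir, "assets"))
    for name in ("saved_model.pb",
                 os.path.join("variables", "variables.index"),
                 os.path.join("variables", "variables.data-00000-of-00001")):
        open(os.path.join(export_dir, name), "w").close()  # empty placeholder file
    lines = tree_lines(export_dir)
    print("\n".join(lines))
```

Returning a list rather than printing also makes the helper easy to compare against an expected layout.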

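The `assets` directory is empty for this model, and `variables` uses the same index/data layout as the TensorFlow checkpoint format described earlier.

```python
! ls -l {checkpoint_path}/assets
```

```text
total 0
```

```python
! ls -l {checkpoint_path}/variables
```

```text
total 408
-rw-r--r--. 1 root root 409700 Nov 16 08:54 variables.data-00000-of-00001
-rw-r--r--. 1 root root    623 Nov 16 08:54 variables.index
```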



## Save Checkpoints

## Saving Custom Objects



## References

- [Save and load models](https://www.tensorflow.org/tutorials/keras/save_and_load#what_are_these_files)
- [Using the SavedModel format](https://www.tensorflow.org/guide/saved_model)
- [Making new Layers and Models via subclassing](https://www.tensorflow.org/guide/keras/custom_layers_and_models)

```python
def add(a, b=3):
    print(a + b)

add(9)
```

```text
12
```


```python
a = {'a1': 3, 'a2': 4}
b = (3, 4)
c = [3, 4]

print(type(a), a)
print(type(b), b)
print(type(c), c)
```

```text
<class 'dict'> {'a1': 3, 'a2': 4}
<class 'tuple'> (3, 4)
<class 'list'> [3, 4]
```


```python
def add(a, b=3, *args):
    print('='*50)
    print('a', type(a), a)
    print('b', type(b), b)
    print('args', type(args), args)

add(1, 11, 12, 13)

lst = [11, 12, 13]
add(1, *lst)          # note: 11 is bound to b, so b's default value is never used here
add(1, *lst, 14, 15)

# add(1, b=4, *lst)   # raises TypeError: add() got multiple values for argument 'b'
```

```text
==================================================
a <class 'int'> 1
b <class 'int'> 11
args <class 'tuple'> (12, 13)
==================================================
a <class 'int'> 1
b <class 'int'> 11
args <class 'tuple'> (12, 13)
==================================================
a <class 'int'> 1
b <class 'int'> 11
args <class 'tuple'> (12, 13, 14, 15)
```

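In the cell above, once extra positional arguments are passed, the first one always lands in `b`, so its default is never used. One common way around this (a Python 3 feature that the cell above does not use) is to declare `b` after `*args`, which makes it keyword-only: it keeps its default unless it is passed by name. `add_kw` is a hypothetical variant that returns its arguments so they can be checked.

```python
def add_kw(a, *args, b=3):
    """b comes after *args, so it can only be passed by keyword."""
    return a, b, args


lst = [11, 12, 13]
print(add_kw(1, *lst))       # (1, 3, (11, 12, 13))  default b is kept
print(add_kw(1, *lst, b=4))  # (1, 4, (11, 12, 13))
print(add_kw(1, b=4, *lst))  # (1, 4, (11, 12, 13))  no "multiple values" error now
```

The trade-off is that callers can no longer set `b` positionally, which is usually what you want once a function accepts `*args`.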

```python
def add(a, b=3, **kwargs):
    print('='*50)
    print('a', type(a), a)
    print('b', type(b), b)
    print('kwargs', type(kwargs), kwargs)

dct = {'c': 21, 'd': 22, 'e': 23}  # contents match the kwargs printed below

add(1, c=21, d=22, e=23)
add(1, 4, **dct)
add(1, b=4, **dct)
add(1, **dct)
add(1, **dct, b=4)
add(1, **dct, f=4)
add(1, f=4, **dct)
add(1, f=4, **dct, b=4, g=4)
```

```text
==================================================
a <class 'int'> 1
b <class 'int'> 3
kwargs <class 'dict'> {'c': 21, 'd': 22, 'e': 23}
==================================================
a <class 'int'> 1
b <class 'int'> 4
kwargs <class 'dict'> {'c': 21, 'd': 22, 'e': 23}
==================================================
a <class 'int'> 1
b <class 'int'> 4
kwargs <class 'dict'> {'c': 21, 'd': 22, 'e': 23}
==================================================
a <class 'int'> 1
b <class 'int'> 3
kwargs <class 'dict'> {'c': 21, 'd': 22, 'e': 23}
==================================================
a <class 'int'> 1
b <class 'int'> 4
kwargs <class 'dict'> {'c': 21, 'd': 22, 'e': 23}
==================================================
a <class 'int'> 1
b <class 'int'> 3
kwargs <class 'dict'> {'c': 21, 'd': 22, 'e': 23, 'f': 4}
==================================================
a <class 'int'> 1
b <class 'int'> 3
kwargs <class 'dict'> {'f': 4, 'c': 21, 'd': 22, 'e': 23}
==================================================
a <class 'int'> 1
b <class 'int'> 4
kwargs <class 'dict'> {'f': 4, 'c': 21, 'd': 22, 'e': 23, 'g': 4}
```

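The calls above never pass `b` both explicitly and through the unpacked dict. When the dict itself contains a key `b`, that key binds to the parameter `b` rather than ending up in `kwargs`; passing `b` a second time then raises a `TypeError`. Merging into one dict first (`{**opts, 'b': 4}`, where later keys win) is one way to override the value instead. `show` and `opts` are hypothetical names used only for this sketch.

```python
def show(a, b=3, **kwargs):
    """Hypothetical variant of add() that returns its arguments instead of printing them."""
    return a, b, kwargs


dct = {'c': 21, 'd': 22, 'e': 23}
opts = dict(dct, b=5)  # {'c': 21, 'd': 22, 'e': 23, 'b': 5}

# A key named 'b' inside the dict binds to the parameter b, not to kwargs
print(show(1, **opts))  # (1, 5, {'c': 21, 'd': 22, 'e': 23})

# Passing b both explicitly and through the dict is an error
try:
    show(1, b=4, **opts)
    error_message = None
except TypeError as exc:
    error_message = str(exc)
    print("TypeError:", exc)

# Merging first lets a later value override an earlier one
print(show(1, **{**opts, 'b': 4}))  # (1, 4, {'c': 21, 'd': 22, 'e': 23})
```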