# 案例四（嵌套自定义层的保存与加载）

In [None]:
from tensorflow.keras import layers
from tensorflow.keras import activations
from tensorflow.keras.layers import *
from tensorflow.keras.utils import plot_model, CustomObjectScope
from tensorflow.keras import utils
from tensorflow.keras.datasets import cifar10, mnist
from tensorflow.keras.models import load_model, save_model, Sequential, Model
import tensorflow.keras.backend as K
import tensorflow as tf
from tensorflow import keras
import numpy as np

import os

## 定义自定义层一：Linear（全连接层）

In [None]:
class Linear(layers.Layer):
    """Fully connected layer computing ``inputs @ w + b``.

    Args:
        units: Size of the output's last dimension.
        input_dim: Accepted for backward compatibility but unused; the input
            size is inferred from ``input_shape[-1]`` in ``build``.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, units=32, input_dim=32, **kwargs):
        super(Linear, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Weights are created lazily so the kernel matches the real input size.
        # Explicit names keep the saved weight names unique; unnamed weights can
        # collide when the model is written with model.save('*.h5').
        self.w = self.add_weight(
            name="w",
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            name="b",
            shape=(self.units,),
            initializer="random_normal",
            trainable=True,
        )
        super(Linear, self).build(input_shape)

    def call(self, inputs):
        """Return ``inputs @ w + b`` (matmul over the last axis of ``inputs``)."""
        return tf.matmul(inputs, self.w) + self.b

    def get_config(self):
        """Return the base-layer config plus ``units`` so the layer can be rebuilt on load."""
        config = super(Linear, self).get_config()
        config.update({"units": self.units})
        return config

## 定义自定义层二：MLPBlock（嵌套使用 Linear）

In [None]:
class MLPBlock(layers.Layer):
    """Three-layer MLP built from nested ``Linear`` layers.

    Computes ``Linear(1)(relu(Linear(32)(relu(Linear(32)(inputs)))))``.

    Args:
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, **kwargs):
        super(MLPBlock, self).__init__(**kwargs)
        # Sub-layers must be created once and stored as attributes so Keras
        # tracks them. Their weights then show up in self.weights, get trained,
        # and are saved and restored with the model. Creating them inside
        # call() would produce fresh, untracked random weights on every call.
        self.linear_1 = Linear(32)
        self.linear_2 = Linear(32)
        self.linear_3 = Linear(1)

    def call(self, inputs):
        """Run the inputs through Linear(32) -> ReLU -> Linear(32) -> ReLU -> Linear(1)."""
        x = tf.nn.relu(self.linear_1(inputs))
        x = tf.nn.relu(self.linear_2(x))
        return self.linear_3(x)

    def get_config(self):
        """Return the base-layer config; the block has no extra constructor arguments."""
        return super(MLPBlock, self).get_config()

## 构建模型

In [None]:
# Build the model one layer at a time: an input of per-sample shape (3, 64)
# followed by the nested custom MLPBlock.
model = Sequential()
model.add(Input(shape=(3, 64)))
model.add(MLPBlock())

## 保存模型

In [None]:
# The '.h5' extension selects the HDF5 format.
save_model(model, 'temp.h5')

## 加载模型

In [None]:
# Register both custom classes by name so deserialisation can resolve them.
with CustomObjectScope({"MLPBlock": MLPBlock, "Linear": Linear}):
    new_model = load_model('temp.h5')