# 案例二（自定义类中使用已有功能）
# Case 2: using existing (built-in) Keras functionality inside a custom class

# In [1]:
from tensorflow.keras import layers
from tensorflow.keras import activations
from tensorflow.keras.layers import *
from tensorflow.keras.utils import plot_model, CustomObjectScope
from tensorflow.keras import utils
from tensorflow.keras.datasets import cifar10, mnist
from tensorflow.keras.models import load_model, save_model, Sequential, Model
import tensorflow.keras.backend as K
import tensorflow as tf
from tensorflow import keras
import numpy as np

import os

## 设置参数 (Set parameters)

# In [2]:
# Hyper-parameters used by the cells below:
#   units      -> width of CustomRNN's recurrent state / per-step output
#   timesteps  -> sequence length declared on the Input layer
#   input_dim  -> features per time step declared on the Input layer
#   batch_size -> fixed batch dimension declared on the Input layer
units, timesteps, input_dim, batch_size = 32, 10, 5, 16

## 定义层 (Define the layer)

# In [3]:
class CustomRNN(layers.Layer):
    """Simple recurrent layer assembled from built-in ``Dense`` layers.

    For every time step ``t`` the layer computes::

        y_t = dense1(x_t) + dense2(y_{t-1}),    with y_{-1} = 0

    where both projections use a ``tanh`` activation. The per-step outputs
    are then stacked and passed through a ``Dense(1)`` classifier.

    Args:
        units: Width of the recurrent state and of each per-step output.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.dense1 = Dense(units=units, activation='tanh')
        self.dense2 = Dense(units=units, activation='tanh')
        self.classifier = Dense(1)

    def call(self, inputs):
        """Unroll the recurrence over axis 1 of ``inputs``.

        Args:
            inputs: 3-D tensor indexed as ``inputs[:, t, :]``. Axis 1 is
                iterated with a Python ``range``, so its length must be
                known statically.

        Returns:
            ``classifier`` applied to the per-step outputs stacked on the
            last axis, i.e. a tensor of shape ``(batch, units, 1)``.
        """
        # Read the batch size dynamically. ``inputs.shape[0]`` is ``None``
        # when the model is built without a fixed batch (``Input(shape=...)``),
        # and ``tf.zeros`` cannot take ``None`` in its shape.
        batch = tf.shape(inputs)[0]
        state = tf.zeros(shape=(batch, self.units))
        outputs = []
        for t in range(inputs.shape[1]):
            x = inputs[:, t, :]
            h = self.dense1(x)
            y = h + self.dense2(state)
            state = y  # the new output becomes the next step's state
            outputs.append(y)
        # NOTE(review): stacking on axis=-1 yields (batch, units, timesteps), so
        # ``classifier`` mixes across time steps and its kernel size depends on
        # the sequence length. ``axis=1`` (per-step features of size ``units``)
        # may have been intended. Confirm before changing: it alters the output
        # shape and the saved weights.
        features = tf.stack(outputs, axis=-1)
        return self.classifier(features)

    def get_config(self):
        """Return the base layer config plus ``units`` so the layer can be re-created from a saved model."""
        config = super().get_config()
        config.update({"units": self.units})
        return config

## 搭建模型 (Build the model)

# In [4]:
# Functional model: a Conv1D front-end feeding the custom recurrent layer.
# batch_shape pins the batch dimension to `batch_size` (fully static input shape).
inputs = Input(batch_shape=(batch_size, timesteps, input_dim))
# 32 filters, kernel size 3; with the default 'valid' padding the time axis
# shrinks from `timesteps` to `timesteps - 2`.
x = layers.Conv1D(filters=32, kernel_size=3)(inputs)
outputs = CustomRNN(units=units)(x)

model = Model(inputs, outputs)

## 保存模型 (Save the model)

# In [5]:
# Write the model to disk in HDF5 format (selected by the `.h5` extension).
# CustomRNN.get_config supplies the `units` needed to rebuild the custom layer.
save_model(model, 'temp.h5')

## 从文件中载入模型 (Load the model from file)

# In [7]:
# Rebuild the model from the HDF5 file. CustomRNN is not a built-in layer, so
# its class is registered under its config name for the duration of the load.
with CustomObjectScope({"CustomRNN": CustomRNN}):
    new_model = load_model('temp.h5')

