In [1]:
import tensorflow as tf
from tensorflow import keras

In [2]:
# Custom layer: a fully connected (affine) layer, out = input @ kernel + bias, no activation.
class MyDense(keras.layers.Layer):
    """Minimal dense layer with explicitly sized weights.

    Args:
        in_dim: size of the last axis of the input (rows of the kernel).
        out_dim: number of output units (columns of the kernel / bias length).
    """

    def __init__(self, in_dim, out_dim):
        super(MyDense, self).__init__()
        # `add_variable` is a deprecated alias of `add_weight` (removed in Keras 3).
        # Keywords are used because the positional order of `add_weight`'s
        # parameters differs between Keras versions; no initializer is passed,
        # so the library's default initializer is kept as before.
        self.kernel = self.add_weight(name='w', shape=[in_dim, out_dim])
        self.bias = self.add_weight(name='b', shape=[out_dim])

    def call(self, input, training=None):
        """Apply the affine map.

        `training` is accepted for Keras API compatibility but is unused:
        the layer behaves identically in training and inference.
        """
        # assumes input's last axis equals in_dim (required by the matmul)
        out = input @ self.kernel + self.bias
        return out

In [3]:
# Custom network: a 5-layer MLP assembled from MyDense layers.
class MyModel(keras.Model):
    """Multi-layer perceptron mapping 784 input features to 10 class probabilities.

    Layer widths are 784 -> 256 -> 128 -> 64 -> 32 -> 10. ReLU follows every
    layer except the last, whose output goes through softmax.
    """

    def __init__(self):
        super(MyModel, self).__init__()
        widths = (784, 256, 128, 64, 32, 10)
        # Creates self.fc1 ... self.fc5 in order; setattr goes through the
        # Model's attribute tracking exactly like a plain assignment would.
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, 'fc%d' % idx, MyDense(fan_in, fan_out))

    def call(self, inputs, training=None):
        """Forward pass returning softmax probabilities; `training` is unused."""
        hidden = inputs
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = tf.nn.relu(layer(hidden))
        return tf.nn.softmax(self.fc5(hidden))

In [4]:
# Load the Fashion-MNIST training and validation splits and summarize them.
(x_train, y_train), (x_val, y_val) = keras.datasets.fashion_mnist.load_data()

train_classes = set(y_train)  # distinct label values in the training set

print('训练集维度：', x_train.shape, y_train.shape)  # training-set shapes
print('数据范围:', x_train.min(), x_train.max())  # pixel value range
print('类别：', train_classes)  # distinct classes
print('类别数：', len(train_classes))  # number of classes
print('验证集维度：', x_val.shape, y_val.shape)  # validation-set shapes

训练集维度： (60000, 28, 28) (60000,)
数据范围: 0 255
类别： {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
类别数： 10
验证集维度： (10000, 28, 28) (10000,)


In [5]:
# Preprocessing applied to each batch produced by the tf.data pipelines.
def preprocess(x, y):
    """Scale pixels to [0, 1], flatten images to 784 features, one-hot the labels.

    Args:
        x: batch of image pixel values (the data cell reports a 0..255 range).
        y: batch of integer class labels (10 classes).

    Returns:
        (features, labels): a float32 tensor reshaped to (-1, 784) and a
        one-hot tensor with depth 10.
    """
    features = tf.reshape(tf.cast(x, dtype=tf.float32) / 255, [-1, 784])
    labels = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=10)
    return features, labels

In [6]:
# Build the batched input pipelines fed to the network.
# Samples must be shuffled *before* batch(): shuffling after batch() only
# permutes whole batches, so every batch would contain the same 200 samples
# in every epoch.
BATCH_SIZE = 200
SHUFFLE_BUFFER = 10000  # number of individual samples held in the shuffle buffer

db_train = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
    .map(preprocess)
)
# The validation set is not shuffled: order does not change the metrics, and a
# fixed order keeps `validation_steps`-limited evaluation reproducible.
db_val = (
    tf.data.Dataset.from_tensor_slices((x_val, y_val))
    .batch(BATCH_SIZE)
    .map(preprocess)
)

# Peek at one training batch to confirm the shapes fed to the model.
db_iter = iter(db_train)
sample = next(db_iter)
print('每次喂入的样本量:', sample[0].shape)
print('每次喂入的标签量:', sample[1].shape)

每次喂入的样本量: (200, 784)
每次喂入的标签量: (200, 10)


In [7]:
# Instantiate the network and train it with Adam.
model = MyModel()
model.compile(
    optimizer=tf.optimizers.Adam(learning_rate=0.01),
    # the model already ends in softmax, so the loss receives probabilities
    loss=tf.losses.categorical_crossentropy,
    metrics=['accuracy'],
)
# validation_steps=2: only 2 batches of db_val are evaluated after each epoch.
model.fit(db_train, epochs=10, validation_data=db_val, validation_steps=2)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x175ac800828>

In [8]:
# Continue training the model from the previous cell, now with plain SGD.
# compile() swaps in the new optimizer; it does not re-initialize layer weights.
model.compile(
    optimizer=tf.optimizers.SGD(learning_rate=0.01),
    loss=tf.losses.categorical_crossentropy,
    metrics=['accuracy'],
)
model.fit(db_train, epochs=10, validation_data=db_val, validation_steps=2)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x17668dc8e10>

In [9]:
# Final [loss, accuracy] over the full validation dataset.
model.evaluate(x=db_val)



[0.35045804679393766, 0.8814]