In [99]:
import tensorflow as tf
from tensorflow.keras import datasets,layers,optimizers,Sequential,metrics
from tensorflow import keras

In [40]:
# Training split -> (x, y); the test split -> (x_val, y_val), used as validation data below.
(x,y),(x_val,y_val) = datasets.mnist.load_data() 

In [42]:
# NOTE(review): the output below (a TensorShape of [60000, 784]) does not look like what
# `datasets.mnist.load_data()` returns directly (NumPy images of shape (60000, 28, 28));
# it appears `x` was reassigned in a cell that is no longer here (execution counts are
# out of order). Re-run top to bottom to refresh it.
x.shape

TensorShape([60000, 784])

In [101]:
def preprocess(x, y):
    """Turn one (image, label) pair into the model's (input, target) pair.

    The image is cast to float32, divided by 255 and flattened to a single
    28*28 = 784-element vector. The label is cast to int32 and one-hot encoded
    over 10 classes.
    """
    image = tf.reshape(tf.cast(x, dtype=tf.float32) / 255., [28 * 28])
    label = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=10)
    return image, label

In [102]:
# Input pipelines. The training data is shuffled *before* `preprocess`, so the shuffle
# buffer holds the raw images/labels rather than float32 784-vectors and one-hot float
# labels (a fraction of the memory). A buffer as large as the dataset gives a full shuffle,
# reshuffled every epoch.
batchsz = 128
db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.shuffle(len(x)).map(preprocess).batch(batchsz)
# Validation data is only preprocessed and batched; order does not matter for evaluation.
db_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
db_val = db_val.map(preprocess).batch(batchsz)

In [117]:
# Pull one batch to sanity-check the pipeline: flattened images and one-hot labels.
sample = next(iter(db))
assert sample[0].shape == (batchsz, 28 * 28), sample[0].shape
assert sample[1].shape == (batchsz, 10), sample[1].shape
sample[0].shape

TensorShape([128, 784])

In [105]:
# Build the network: ReLU hidden layers of 256, 128, 64 and 32 units, then a 10-unit
# output layer with no activation, so the model outputs logits.
network = Sequential(
    [layers.Dense(units, activation='relu') for units in (256, 128, 64, 32)]
    + [layers.Dense(10)]
)
network.build(input_shape=(None, 28 * 28))
network.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_15 (Dense)             (None, 256)               200960    
_________________________________________________________________
dense_16 (Dense)             (None, 128)               32896     
_________________________________________________________________
dense_17 (Dense)             (None, 64)                8256      
_________________________________________________________________
dense_18 (Dense)             (None, 32)                2080      
_________________________________________________________________
dense_19 (Dense)             (None, 10)                330       
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________


In [96]:
# Shape of the first Dense layer's kernel (its first trainable variable): inputs x units.
network.layers[0].kernel.shape

TensorShape([784, 256])

In [106]:
# Adam with lr=0.01; the loss takes logits because the last Dense layer has no softmax.
network.compile(
    optimizer=optimizers.Adam(learning_rate=0.01),
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [107]:
# Train for 5 epochs; validation_freq=2 runs validation on db_val every second epoch.
network.fit(x=db, epochs=5, validation_data=db_val, validation_freq=2)
network.evaluate(x=db_val)  # -> [loss, accuracy]

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


[0.13662420213222504, 0.9653000235557556]

In [75]:
# Predict on the training images. `network` was trained on the output of `preprocess`
# (float32, divided by 255, flattened to 28*28), so apply the same transform here;
# `mnist.load_data()` yields unscaled 28x28 uint8 images, which the first Dense layer
# (built for 784 inputs) cannot consume directly.
x_flat = tf.reshape(tf.cast(x, tf.float32) / 255., [-1, 28 * 28])
pred = network.predict(x_flat)

In [76]:
# Collapse the per-class logits to the predicted class id of each image.
# NOTE: this rebinds `pred`, so re-running this cell by itself fails on the 1-D result.
pred = tf.math.argmax(pred, axis=1)

In [64]:
# Predicted class id for every training image (compare against `y`).
pred

<tf.Tensor: shape=(60000,), dtype=int64, numpy=array([5, 0, 4, ..., 5, 6, 8])>

In [None]:
## Custom layers and networks

In [125]:
# Custom layer: a fully connected (affine) layer computing inputs @ kernel + bias.
class MyDense(layers.Layer):
    """Minimal Dense layer without an activation.

    Args:
        in_dim: size of the last axis of the inputs.
        out_dim: number of output units.
        **kwargs: forwarded to `layers.Layer` (e.g. `name`, `dtype`).
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        super(MyDense, self).__init__(**kwargs)
        # Keyword arguments keep this correct across Keras versions whose `add_weight`
        # positional order differs (name-first in tf.keras 2.x, shape-first in Keras 3).
        # The weight names 'w' / 'b' are unchanged.
        self.kernel = self.add_weight(name='w', shape=[in_dim, out_dim],
                                      initializer='glorot_uniform')
        # Bias starts at zero, as in `layers.Dense`. Without an explicit initializer,
        # add_weight would use glorot_uniform for the bias as well.
        self.bias = self.add_weight(name='b', shape=[out_dim],
                                    initializer='zeros')

    def call(self, inputs, training=None):
        # `training` is accepted for API compatibility; the layer behaves the same
        # in training and inference.
        return inputs @ self.kernel + self.bias

In [120]:
# Custom network: a 784-256-128-64-32-10 MLP assembled from MyDense layers.
class MyModel(keras.Model):
    """Five-layer MLP with the same layer sizes as the Sequential `network`.

    Hidden layers apply ReLU. The last layer returns raw logits (no softmax), which
    suits the `from_logits=True` loss this notebook compiles it with.
    """

    def __init__(self):
        super(MyModel, self).__init__()
        # Keep the attribute names fc1..fc5: object-based TF checkpoints
        # (save_weights / load_weights) track variables through them.
        self.fc1 = MyDense(28 * 28, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        hidden = inputs
        for dense in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = tf.nn.relu(dense(hidden))
        return self.fc5(hidden)

In [121]:
# Subclassed model, compiled with the same optimizer/loss/metrics as `network`.
network1 = MyModel()
network1.compile(
    optimizer=optimizers.Adam(learning_rate=0.01),
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [122]:
# Train for 5 epochs; validation_freq=2 runs validation on db_val every second epoch.
network1.fit(x=db, epochs=5, validation_data=db_val, validation_freq=2)
network1.evaluate(x=db_val)  # -> [loss, accuracy]

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


[0.11940108984708786, 0.9740999937057495]

In [130]:
# Save only the weights (not the architecture) of network1.
network1.save_weights(filepath='./network.ckpt')

In [131]:
# A fresh MyModel with identical compile settings, to receive network1's saved weights.
network2 = MyModel()
network2.compile(
    optimizer=optimizers.Adam(learning_rate=0.01),
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [132]:
# Restore the weights saved from network1 into network2.
network2.load_weights(filepath='./network.ckpt')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x158134f70>

In [133]:
# Expected to match network1's [loss, accuracy] if the weights round-tripped.
network2.evaluate(x=db_val)



[0.11940108984708786, 0.9740999937057495]

In [None]:
# Save the whole model (architecture, weights and compile/optimizer state) to HDF5.
network.save(filepath='model.h5')
# Load it back; load_model recompiles by default, so it can be evaluated directly.
net = tf.keras.models.load_model(filepath='model.h5')
net.evaluate(x=db_val)

In [None]:
# Export the model as a TensorFlow SavedModel, a language-neutral format that can be
# loaded from C++ and other TensorFlow runtimes.
export_dir = './saved_model'
tf.saved_model.save(network, export_dir)
imported = tf.saved_model.load(export_dir)
f = imported.signatures["serving_default"]
# Signature functions accept keyword arguments only, and the input key is generated at
# export time, so look it up instead of hard-coding it.
_, input_specs = f.structured_input_signature
input_name = next(iter(input_specs))
# The model was built for (None, 28*28) float inputs.
print(f(**{input_name: tf.ones([1, 28 * 28])}))