In [1]:
import tensorflow as tf
from tensorflow import keras
from keras import layers, optimizers, Sequential, metrics
from matplotlib import pyplot as plt
import  io

(x, y), (x_test, y_test) = keras.datasets.fashion_mnist.load_data() # 返回numpy
print(f"x shape:{x.shape}, y shape:{y.shape}")

def preprocess(x, y):

    x = tf.cast(x, dtype=tf.float32) / 255.
    y = tf.cast(y, dtype=tf.int32)
    return x,y

batchsz = 128
db = tf.data.Dataset.from_tensor_slices((x,y))
db = db.map(preprocess).shuffle(10000).batch(batchsz)
db_iter = iter(db)
sample = next(db_iter)
print('batch:', sample[0].shape, sample[1].shape)

db_test = tf.data.Dataset.from_tensor_slices((x_test,y_test))
db_test = db_test.map(preprocess).batch(batchsz)

x shape:(60000, 28, 28), y shape:(60000,)
batch: (128, 28, 28) (128,)


In [2]:
model = Sequential([
    layers.Dense(256, activation=tf.nn.relu), # [b, 784] => [b, 256]
    layers.Dense(128, activation=tf.nn.relu), # [b, 256] => [b, 128]
    layers.Dense(64, activation=tf.nn.relu), # [b, 128] => [b, 64]
    layers.Dense(32, activation=tf.nn.relu), # [b, 64] => [b, 32]
    layers.Dense(10) # [b, 32] => [b, 10], 330 = 32*10 + 10
])
model.build(input_shape=[None, 28*28])
model.summary()
# w = w - lr*grad
optimizer = optimizers.Adam(learning_rate=1e-3)

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 256)               200960    
                                                                 
 dense_1 (Dense)             (None, 128)               32896     
                                                                 
 dense_2 (Dense)             (None, 64)                8256      
                                                                 
 dense_3 (Dense)             (None, 32)                2080      
                                                                 
 dense_4 (Dense)             (None, 10)                330       
                                                                 
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________


In [3]:
for epoch in range(30):


    for step, (x,y) in enumerate(db):

        # x: [b, 28, 28] => [b, 784]
        # y: [b]
        x = tf.reshape(x, [-1, 28*28])

        with tf.GradientTape() as tape:
            # [b, 784] => [b, 10]
            logits = model(x)
            y_onehot = tf.one_hot(y, depth=10)
            # [b]
            loss_mse = tf.reduce_mean(tf.losses.MSE(y_onehot, logits))
            loss_ce = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
            loss_ce = tf.reduce_mean(loss_ce)

        grads = tape.gradient(loss_ce, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))


        if step % 100 == 0:
            print(epoch, step, 'loss:', float(loss_ce), float(loss_mse))


    # test
    total_correct = 0
    total_num = 0
    for x,y in db_test:

        # x: [b, 28, 28] => [b, 784]
        # y: [b]
        x = tf.reshape(x, [-1, 28*28])
        # [b, 10]
        logits = model(x)
        # logits => prob, [b, 10]
        prob = tf.nn.softmax(logits, axis=1)
        # [b, 10] => [b], int64
        pred = tf.argmax(prob, axis=1)
        pred = tf.cast(pred, dtype=tf.int32)
        # pred:[b]
        # y: [b]
        # correct: [b], True: equal, False: not equal
        correct = tf.equal(pred, y)
        correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))

        total_correct += int(correct)
        total_num += x.shape[0]

    acc = total_correct / total_num
    print(epoch, 'test acc:', acc)

0 0 loss: 2.328538179397583 0.18116885423660278
0 100 loss: 0.5295809507369995 35.6922721862793
0 200 loss: 0.5176625847816467 28.342437744140625
0 300 loss: 0.5652111172676086 40.81658172607422
0 400 loss: 0.42393407225608826 34.39350509643555
0 test acc: 0.8452
1 0 loss: 0.5498256087303162 30.525524139404297
1 100 loss: 0.25784212350845337 34.683197021484375
1 200 loss: 0.2835013270378113 32.9970817565918
1 300 loss: 0.5449256896972656 42.74700927734375
1 400 loss: 0.3391828238964081 39.91959762573242
1 test acc: 0.8587
2 0 loss: 0.3791736960411072 36.629451751708984
2 100 loss: 0.2740439474582672 37.08982849121094
2 200 loss: 0.4671878516674042 35.50667190551758
2 300 loss: 0.27072280645370483 40.191925048828125
2 400 loss: 0.3873215317726135 43.14110565185547
2 test acc: 0.8661
3 0 loss: 0.2713596522808075 41.80927276611328
3 100 loss: 0.30535274744033813 41.8309326171875
3 200 loss: 0.25934654474258423 47.52680206298828
3 300 loss: 0.2546522617340088 42.33328628540039
3 400 loss: 