In [3]:
import tensorflow as tf
tf.__version__

'2.0.0'

In [4]:
import numpy as np
import pandas as pd

#### 数据获取，预处理的类

In [5]:
class MNISTLoader():
    def __init__(self):
        mnist = tf.keras.datasets.mnist
        (self.train_data, self.train_label), (self.test_data, self.test_label) = mnist.load_data()
        # MNIST中的图像默认为uint8（0-255的数字）。以下代码将其归一化到0-1之间的浮点数，并在最后增加一维作为颜色通道
        self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)      # [60000, 28, 28, 1]
        self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)        # [10000, 28, 28, 1]
        self.train_label = self.train_label.astype(np.int32)    # [60000]
        self.test_label = self.test_label.astype(np.int32)      # [10000]
        self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]

    def get_batch(self, batch_size):
        # 从数据集中随机取出batch_size个元素并返回
        index = np.random.randint(0, np.shape(self.train_data)[0], batch_size)
        return self.train_data[index, :], self.train_label[index]

#### 模型类   tf.keras.layers    tf.keras.Model

In [6]:
class MLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)
        
    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.dense1(x)
        x = self.dense2(x)
        outputs = tf.nn.softmax(x)
        return outputs

#### 模型训练 tf.keras.losses     tf.keras.optimizer

In [7]:
num_epochs = 5
batch_size = 50
learning_rate = 0.001

In [8]:
model = MLP()
data_loader = MNISTLoader()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

In [9]:
num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

batch 0: loss 2.438133
batch 1: loss 2.273934
batch 2: loss 2.355415
batch 3: loss 2.272263
batch 4: loss 2.181853
batch 5: loss 2.091420
batch 6: loss 2.146718
batch 7: loss 1.954695
batch 8: loss 1.872318
batch 9: loss 1.837145
batch 10: loss 1.738694
batch 11: loss 1.768432
batch 12: loss 1.792785
batch 13: loss 1.599320
batch 14: loss 1.419369
batch 15: loss 1.598553
batch 16: loss 1.366261
batch 17: loss 1.573335
batch 18: loss 1.518774
batch 19: loss 1.447316
batch 20: loss 1.315905
batch 21: loss 1.121393
batch 22: loss 1.279834
batch 23: loss 1.298772
batch 24: loss 1.272601
batch 25: loss 1.194776
batch 26: loss 1.077628
batch 27: loss 1.015672
batch 28: loss 1.130749
batch 29: loss 0.904288
batch 30: loss 0.868203
batch 31: loss 1.070194
batch 32: loss 0.845088
batch 33: loss 0.907807
batch 34: loss 1.161140
batch 35: loss 0.839448
batch 36: loss 0.846559
batch 37: loss 0.854337
batch 38: loss 0.764163
batch 39: loss 0.887984
batch 40: loss 0.906739
batch 41: loss 0.798458
ba

batch 333: loss 0.172308
batch 334: loss 0.198127
batch 335: loss 0.509073
batch 336: loss 0.132321
batch 337: loss 0.396576
batch 338: loss 0.285402
batch 339: loss 0.125688
batch 340: loss 0.298166
batch 341: loss 0.354519
batch 342: loss 0.202544
batch 343: loss 0.190226
batch 344: loss 0.243647
batch 345: loss 0.224673
batch 346: loss 0.324189
batch 347: loss 0.235534
batch 348: loss 0.163767
batch 349: loss 0.413491
batch 350: loss 0.379918
batch 351: loss 0.286801
batch 352: loss 0.201032
batch 353: loss 0.193176
batch 354: loss 0.303939
batch 355: loss 0.208271
batch 356: loss 0.233065
batch 357: loss 0.148381
batch 358: loss 0.267533
batch 359: loss 0.343363
batch 360: loss 0.313611
batch 361: loss 0.205873
batch 362: loss 0.380588
batch 363: loss 0.362755
batch 364: loss 0.167483
batch 365: loss 0.158388
batch 366: loss 0.334325
batch 367: loss 0.410258
batch 368: loss 0.185753
batch 369: loss 0.388926
batch 370: loss 0.461327
batch 371: loss 0.261707
batch 372: loss 0.314858


batch 670: loss 0.232246
batch 671: loss 0.342955
batch 672: loss 0.307051
batch 673: loss 0.291257
batch 674: loss 0.400734
batch 675: loss 0.308997
batch 676: loss 0.388449
batch 677: loss 0.280194
batch 678: loss 0.115115
batch 679: loss 0.156057
batch 680: loss 0.169511
batch 681: loss 0.185613
batch 682: loss 0.184230
batch 683: loss 0.211868
batch 684: loss 0.247305
batch 685: loss 0.369734
batch 686: loss 0.355961
batch 687: loss 0.262209
batch 688: loss 0.074488
batch 689: loss 0.284408
batch 690: loss 0.184993
batch 691: loss 0.116622
batch 692: loss 0.291178
batch 693: loss 0.258720
batch 694: loss 0.138539
batch 695: loss 0.113394
batch 696: loss 0.228311
batch 697: loss 0.217063
batch 698: loss 0.160929
batch 699: loss 0.207107
batch 700: loss 0.181734
batch 701: loss 0.087337
batch 702: loss 0.355201
batch 703: loss 0.269486
batch 704: loss 0.168906
batch 705: loss 0.271224
batch 706: loss 0.210080
batch 707: loss 0.151846
batch 708: loss 0.311158
batch 709: loss 0.401927


batch 1005: loss 0.221929
batch 1006: loss 0.203036
batch 1007: loss 0.166824
batch 1008: loss 0.191108
batch 1009: loss 0.098934
batch 1010: loss 0.214455
batch 1011: loss 0.264127
batch 1012: loss 0.071164
batch 1013: loss 0.196858
batch 1014: loss 0.254317
batch 1015: loss 0.216082
batch 1016: loss 0.212308
batch 1017: loss 0.353487
batch 1018: loss 0.441830
batch 1019: loss 0.135525
batch 1020: loss 0.150710
batch 1021: loss 0.154704
batch 1022: loss 0.106424
batch 1023: loss 0.211322
batch 1024: loss 0.163746
batch 1025: loss 0.095586
batch 1026: loss 0.211833
batch 1027: loss 0.108249
batch 1028: loss 0.160094
batch 1029: loss 0.255276
batch 1030: loss 0.408355
batch 1031: loss 0.160982
batch 1032: loss 0.107724
batch 1033: loss 0.059702
batch 1034: loss 0.319899
batch 1035: loss 0.177109
batch 1036: loss 0.152864
batch 1037: loss 0.234080
batch 1038: loss 0.263151
batch 1039: loss 0.176015
batch 1040: loss 0.113232
batch 1041: loss 0.180857
batch 1042: loss 0.117127
batch 1043: 

batch 1343: loss 0.181572
batch 1344: loss 0.124535
batch 1345: loss 0.110635
batch 1346: loss 0.062804
batch 1347: loss 0.115794
batch 1348: loss 0.083502
batch 1349: loss 0.073093
batch 1350: loss 0.180666
batch 1351: loss 0.054115
batch 1352: loss 0.212064
batch 1353: loss 0.441950
batch 1354: loss 0.156293
batch 1355: loss 0.338918
batch 1356: loss 0.193677
batch 1357: loss 0.071389
batch 1358: loss 0.043808
batch 1359: loss 0.115723
batch 1360: loss 0.160234
batch 1361: loss 0.086074
batch 1362: loss 0.139353
batch 1363: loss 0.199011
batch 1364: loss 0.269550
batch 1365: loss 0.113000
batch 1366: loss 0.169264
batch 1367: loss 0.040457
batch 1368: loss 0.074780
batch 1369: loss 0.324395
batch 1370: loss 0.125488
batch 1371: loss 0.163521
batch 1372: loss 0.090384
batch 1373: loss 0.146167
batch 1374: loss 0.106601
batch 1375: loss 0.122729
batch 1376: loss 0.121462
batch 1377: loss 0.131891
batch 1378: loss 0.206210
batch 1379: loss 0.087353
batch 1380: loss 0.046529
batch 1381: 

batch 1677: loss 0.160194
batch 1678: loss 0.035753
batch 1679: loss 0.141109
batch 1680: loss 0.058940
batch 1681: loss 0.094044
batch 1682: loss 0.273682
batch 1683: loss 0.044207
batch 1684: loss 0.193323
batch 1685: loss 0.136382
batch 1686: loss 0.134447
batch 1687: loss 0.088205
batch 1688: loss 0.158309
batch 1689: loss 0.159827
batch 1690: loss 0.241754
batch 1691: loss 0.064703
batch 1692: loss 0.057464
batch 1693: loss 0.345379
batch 1694: loss 0.055618
batch 1695: loss 0.103882
batch 1696: loss 0.112979
batch 1697: loss 0.270604
batch 1698: loss 0.380239
batch 1699: loss 0.098131
batch 1700: loss 0.083716
batch 1701: loss 0.278071
batch 1702: loss 0.150444
batch 1703: loss 0.079158
batch 1704: loss 0.041771
batch 1705: loss 0.052621
batch 1706: loss 0.196216
batch 1707: loss 0.078263
batch 1708: loss 0.136213
batch 1709: loss 0.075868
batch 1710: loss 0.225519
batch 1711: loss 0.118236
batch 1712: loss 0.080280
batch 1713: loss 0.202467
batch 1714: loss 0.054207
batch 1715: 

batch 2007: loss 0.072340
batch 2008: loss 0.133862
batch 2009: loss 0.093813
batch 2010: loss 0.189375
batch 2011: loss 0.047397
batch 2012: loss 0.124968
batch 2013: loss 0.073413
batch 2014: loss 0.303389
batch 2015: loss 0.109084
batch 2016: loss 0.223872
batch 2017: loss 0.047719
batch 2018: loss 0.089301
batch 2019: loss 0.072166
batch 2020: loss 0.098508
batch 2021: loss 0.081114
batch 2022: loss 0.044827
batch 2023: loss 0.085451
batch 2024: loss 0.237212
batch 2025: loss 0.280838
batch 2026: loss 0.151866
batch 2027: loss 0.127723
batch 2028: loss 0.271027
batch 2029: loss 0.141278
batch 2030: loss 0.067672
batch 2031: loss 0.193344
batch 2032: loss 0.206371
batch 2033: loss 0.080495
batch 2034: loss 0.106537
batch 2035: loss 0.081373
batch 2036: loss 0.063803
batch 2037: loss 0.229890
batch 2038: loss 0.283576
batch 2039: loss 0.055973
batch 2040: loss 0.120511
batch 2041: loss 0.239391
batch 2042: loss 0.119733
batch 2043: loss 0.073612
batch 2044: loss 0.061366
batch 2045: 

batch 2339: loss 0.089268
batch 2340: loss 0.064107
batch 2341: loss 0.117457
batch 2342: loss 0.114749
batch 2343: loss 0.084663
batch 2344: loss 0.049583
batch 2345: loss 0.082707
batch 2346: loss 0.103206
batch 2347: loss 0.053971
batch 2348: loss 0.436312
batch 2349: loss 0.083268
batch 2350: loss 0.054404
batch 2351: loss 0.016921
batch 2352: loss 0.158012
batch 2353: loss 0.142246
batch 2354: loss 0.046114
batch 2355: loss 0.404361
batch 2356: loss 0.054854
batch 2357: loss 0.156947
batch 2358: loss 0.051046
batch 2359: loss 0.113070
batch 2360: loss 0.048772
batch 2361: loss 0.170007
batch 2362: loss 0.135803
batch 2363: loss 0.065859
batch 2364: loss 0.067526
batch 2365: loss 0.084251
batch 2366: loss 0.052549
batch 2367: loss 0.027904
batch 2368: loss 0.109133
batch 2369: loss 0.322382
batch 2370: loss 0.092979
batch 2371: loss 0.103315
batch 2372: loss 0.082282
batch 2373: loss 0.135728
batch 2374: loss 0.152199
batch 2375: loss 0.041260
batch 2376: loss 0.069801
batch 2377: 

batch 2681: loss 0.114992
batch 2682: loss 0.040958
batch 2683: loss 0.064496
batch 2684: loss 0.115168
batch 2685: loss 0.054857
batch 2686: loss 0.059141
batch 2687: loss 0.011904
batch 2688: loss 0.079529
batch 2689: loss 0.024099
batch 2690: loss 0.215777
batch 2691: loss 0.075046
batch 2692: loss 0.066942
batch 2693: loss 0.159111
batch 2694: loss 0.176741
batch 2695: loss 0.196946
batch 2696: loss 0.041107
batch 2697: loss 0.051088
batch 2698: loss 0.051604
batch 2699: loss 0.201741
batch 2700: loss 0.159544
batch 2701: loss 0.108727
batch 2702: loss 0.028287
batch 2703: loss 0.032168
batch 2704: loss 0.055052
batch 2705: loss 0.171080
batch 2706: loss 0.043639
batch 2707: loss 0.177002
batch 2708: loss 0.037802
batch 2709: loss 0.165160
batch 2710: loss 0.099388
batch 2711: loss 0.043725
batch 2712: loss 0.076239
batch 2713: loss 0.094707
batch 2714: loss 0.066269
batch 2715: loss 0.154890
batch 2716: loss 0.126170
batch 2717: loss 0.052647
batch 2718: loss 0.082448
batch 2719: 

batch 3020: loss 0.077873
batch 3021: loss 0.155828
batch 3022: loss 0.014657
batch 3023: loss 0.107585
batch 3024: loss 0.056097
batch 3025: loss 0.105639
batch 3026: loss 0.139832
batch 3027: loss 0.170285
batch 3028: loss 0.109127
batch 3029: loss 0.158520
batch 3030: loss 0.057416
batch 3031: loss 0.053786
batch 3032: loss 0.112596
batch 3033: loss 0.068411
batch 3034: loss 0.305801
batch 3035: loss 0.044172
batch 3036: loss 0.016322
batch 3037: loss 0.070881
batch 3038: loss 0.157743
batch 3039: loss 0.210448
batch 3040: loss 0.086598
batch 3041: loss 0.106504
batch 3042: loss 0.224849
batch 3043: loss 0.074590
batch 3044: loss 0.141114
batch 3045: loss 0.039578
batch 3046: loss 0.136249
batch 3047: loss 0.040294
batch 3048: loss 0.057529
batch 3049: loss 0.111494
batch 3050: loss 0.020595
batch 3051: loss 0.105711
batch 3052: loss 0.115008
batch 3053: loss 0.046849
batch 3054: loss 0.082091
batch 3055: loss 0.133074
batch 3056: loss 0.203819
batch 3057: loss 0.034299
batch 3058: 

batch 3359: loss 0.105540
batch 3360: loss 0.017435
batch 3361: loss 0.140355
batch 3362: loss 0.131217
batch 3363: loss 0.023856
batch 3364: loss 0.042953
batch 3365: loss 0.255985
batch 3366: loss 0.022413
batch 3367: loss 0.081400
batch 3368: loss 0.172822
batch 3369: loss 0.200968
batch 3370: loss 0.047161
batch 3371: loss 0.143140
batch 3372: loss 0.052726
batch 3373: loss 0.061457
batch 3374: loss 0.052740
batch 3375: loss 0.091821
batch 3376: loss 0.030775
batch 3377: loss 0.058270
batch 3378: loss 0.098163
batch 3379: loss 0.143585
batch 3380: loss 0.116482
batch 3381: loss 0.114432
batch 3382: loss 0.053491
batch 3383: loss 0.185666
batch 3384: loss 0.091311
batch 3385: loss 0.127679
batch 3386: loss 0.022625
batch 3387: loss 0.426715
batch 3388: loss 0.076241
batch 3389: loss 0.018241
batch 3390: loss 0.124654
batch 3391: loss 0.115805
batch 3392: loss 0.133751
batch 3393: loss 0.089501
batch 3394: loss 0.059743
batch 3395: loss 0.070447
batch 3396: loss 0.070344
batch 3397: 

batch 3698: loss 0.074212
batch 3699: loss 0.021574
batch 3700: loss 0.076546
batch 3701: loss 0.085733
batch 3702: loss 0.016000
batch 3703: loss 0.160426
batch 3704: loss 0.108519
batch 3705: loss 0.275406
batch 3706: loss 0.084754
batch 3707: loss 0.196894
batch 3708: loss 0.089914
batch 3709: loss 0.143554
batch 3710: loss 0.065484
batch 3711: loss 0.036847
batch 3712: loss 0.131308
batch 3713: loss 0.005070
batch 3714: loss 0.043243
batch 3715: loss 0.080142
batch 3716: loss 0.184699
batch 3717: loss 0.085642
batch 3718: loss 0.191736
batch 3719: loss 0.019275
batch 3720: loss 0.031902
batch 3721: loss 0.105319
batch 3722: loss 0.101683
batch 3723: loss 0.140294
batch 3724: loss 0.025922
batch 3725: loss 0.120470
batch 3726: loss 0.165678
batch 3727: loss 0.015648
batch 3728: loss 0.128222
batch 3729: loss 0.063549
batch 3730: loss 0.162487
batch 3731: loss 0.020553
batch 3732: loss 0.135446
batch 3733: loss 0.095988
batch 3734: loss 0.008094
batch 3735: loss 0.077960
batch 3736: 

batch 4040: loss 0.072249
batch 4041: loss 0.091931
batch 4042: loss 0.028185
batch 4043: loss 0.116460
batch 4044: loss 0.027178
batch 4045: loss 0.149237
batch 4046: loss 0.024256
batch 4047: loss 0.114092
batch 4048: loss 0.077570
batch 4049: loss 0.050464
batch 4050: loss 0.071449
batch 4051: loss 0.046450
batch 4052: loss 0.034529
batch 4053: loss 0.090841
batch 4054: loss 0.072922
batch 4055: loss 0.112073
batch 4056: loss 0.059751
batch 4057: loss 0.126895
batch 4058: loss 0.060096
batch 4059: loss 0.031818
batch 4060: loss 0.052398
batch 4061: loss 0.036490
batch 4062: loss 0.012176
batch 4063: loss 0.098178
batch 4064: loss 0.035308
batch 4065: loss 0.179858
batch 4066: loss 0.056316
batch 4067: loss 0.039370
batch 4068: loss 0.174071
batch 4069: loss 0.030910
batch 4070: loss 0.019627
batch 4071: loss 0.023283
batch 4072: loss 0.096282
batch 4073: loss 0.194965
batch 4074: loss 0.177297
batch 4075: loss 0.025663
batch 4076: loss 0.051847
batch 4077: loss 0.057345
batch 4078: 

batch 4381: loss 0.029183
batch 4382: loss 0.028607
batch 4383: loss 0.026518
batch 4384: loss 0.062160
batch 4385: loss 0.008079
batch 4386: loss 0.079679
batch 4387: loss 0.157912
batch 4388: loss 0.049998
batch 4389: loss 0.054488
batch 4390: loss 0.092487
batch 4391: loss 0.019647
batch 4392: loss 0.037965
batch 4393: loss 0.059613
batch 4394: loss 0.041037
batch 4395: loss 0.056484
batch 4396: loss 0.120272
batch 4397: loss 0.075093
batch 4398: loss 0.130533
batch 4399: loss 0.082249
batch 4400: loss 0.037358
batch 4401: loss 0.124255
batch 4402: loss 0.220542
batch 4403: loss 0.042931
batch 4404: loss 0.047431
batch 4405: loss 0.099772
batch 4406: loss 0.053369
batch 4407: loss 0.030118
batch 4408: loss 0.057212
batch 4409: loss 0.044214
batch 4410: loss 0.027357
batch 4411: loss 0.046567
batch 4412: loss 0.045217
batch 4413: loss 0.031394
batch 4414: loss 0.136500
batch 4415: loss 0.036946
batch 4416: loss 0.030478
batch 4417: loss 0.050636
batch 4418: loss 0.024017
batch 4419: 

batch 4722: loss 0.093903
batch 4723: loss 0.011096
batch 4724: loss 0.041609
batch 4725: loss 0.055827
batch 4726: loss 0.074875
batch 4727: loss 0.075745
batch 4728: loss 0.097463
batch 4729: loss 0.007652
batch 4730: loss 0.063148
batch 4731: loss 0.024279
batch 4732: loss 0.067345
batch 4733: loss 0.022485
batch 4734: loss 0.024651
batch 4735: loss 0.066455
batch 4736: loss 0.098375
batch 4737: loss 0.076060
batch 4738: loss 0.014577
batch 4739: loss 0.035120
batch 4740: loss 0.089022
batch 4741: loss 0.022179
batch 4742: loss 0.011320
batch 4743: loss 0.045622
batch 4744: loss 0.052889
batch 4745: loss 0.026092
batch 4746: loss 0.095596
batch 4747: loss 0.055879
batch 4748: loss 0.067991
batch 4749: loss 0.060729
batch 4750: loss 0.035595
batch 4751: loss 0.103296
batch 4752: loss 0.138553
batch 4753: loss 0.072969
batch 4754: loss 0.092223
batch 4755: loss 0.112830
batch 4756: loss 0.066137
batch 4757: loss 0.119887
batch 4758: loss 0.067026
batch 4759: loss 0.010911
batch 4760: 

batch 5054: loss 0.024231
batch 5055: loss 0.060392
batch 5056: loss 0.063589
batch 5057: loss 0.007975
batch 5058: loss 0.046136
batch 5059: loss 0.008734
batch 5060: loss 0.100550
batch 5061: loss 0.069121
batch 5062: loss 0.058437
batch 5063: loss 0.048394
batch 5064: loss 0.055730
batch 5065: loss 0.046472
batch 5066: loss 0.017050
batch 5067: loss 0.143175
batch 5068: loss 0.166556
batch 5069: loss 0.024362
batch 5070: loss 0.059877
batch 5071: loss 0.058309
batch 5072: loss 0.027666
batch 5073: loss 0.095333
batch 5074: loss 0.043421
batch 5075: loss 0.037905
batch 5076: loss 0.031751
batch 5077: loss 0.025664
batch 5078: loss 0.049571
batch 5079: loss 0.058766
batch 5080: loss 0.029005
batch 5081: loss 0.044393
batch 5082: loss 0.026949
batch 5083: loss 0.043574
batch 5084: loss 0.029485
batch 5085: loss 0.043897
batch 5086: loss 0.276486
batch 5087: loss 0.032624
batch 5088: loss 0.042376
batch 5089: loss 0.036515
batch 5090: loss 0.121533
batch 5091: loss 0.084589
batch 5092: 

batch 5398: loss 0.054161
batch 5399: loss 0.007547
batch 5400: loss 0.006234
batch 5401: loss 0.160426
batch 5402: loss 0.032967
batch 5403: loss 0.116467
batch 5404: loss 0.042949
batch 5405: loss 0.041698
batch 5406: loss 0.026382
batch 5407: loss 0.044358
batch 5408: loss 0.073422
batch 5409: loss 0.092524
batch 5410: loss 0.055302
batch 5411: loss 0.024991
batch 5412: loss 0.223779
batch 5413: loss 0.019264
batch 5414: loss 0.387550
batch 5415: loss 0.135067
batch 5416: loss 0.013767
batch 5417: loss 0.024681
batch 5418: loss 0.162643
batch 5419: loss 0.025067
batch 5420: loss 0.025076
batch 5421: loss 0.048527
batch 5422: loss 0.019485
batch 5423: loss 0.134592
batch 5424: loss 0.052995
batch 5425: loss 0.028899
batch 5426: loss 0.006011
batch 5427: loss 0.046829
batch 5428: loss 0.040654
batch 5429: loss 0.008383
batch 5430: loss 0.060415
batch 5431: loss 0.010276
batch 5432: loss 0.042624
batch 5433: loss 0.114477
batch 5434: loss 0.049449
batch 5435: loss 0.028644
batch 5436: 

batch 5740: loss 0.039588
batch 5741: loss 0.013281
batch 5742: loss 0.055334
batch 5743: loss 0.097276
batch 5744: loss 0.071569
batch 5745: loss 0.017326
batch 5746: loss 0.003502
batch 5747: loss 0.027420
batch 5748: loss 0.026941
batch 5749: loss 0.036586
batch 5750: loss 0.018612
batch 5751: loss 0.072531
batch 5752: loss 0.009556
batch 5753: loss 0.026269
batch 5754: loss 0.021206
batch 5755: loss 0.063031
batch 5756: loss 0.051354
batch 5757: loss 0.066381
batch 5758: loss 0.140855
batch 5759: loss 0.023639
batch 5760: loss 0.016897
batch 5761: loss 0.009292
batch 5762: loss 0.005251
batch 5763: loss 0.018382
batch 5764: loss 0.093803
batch 5765: loss 0.028754
batch 5766: loss 0.019885
batch 5767: loss 0.019958
batch 5768: loss 0.036867
batch 5769: loss 0.089943
batch 5770: loss 0.098732
batch 5771: loss 0.036460
batch 5772: loss 0.104822
batch 5773: loss 0.047763
batch 5774: loss 0.008732
batch 5775: loss 0.113346
batch 5776: loss 0.126947
batch 5777: loss 0.019885
batch 5778: 

#### 模型评估  tf.keras.metrics

In [10]:
sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
num_batches = int(data_loader.num_test_data // batch_size)
for batch_index in range(num_batches):
    start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
    y_pred = model.predict(data_loader.test_data[start_index: end_index])
    sparse_categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index],y_pred=y_pred)
print("test accuracy is: %f" % sparse_categorical_accuracy.result())

test accuracy is: 0.972900


### 卷积神经网络 CNN

In [11]:
class CNN(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.connv1 = tf.keras.layers.Conv2D(
            filters=32,             # 卷积层神经元（卷积核）数目
            kernel_size=[5, 5],     # 感受野大小
            padding='same',         # padding策略（vaild 或 same）
            activation=tf.nn.relu   # 激活函数
        )
        self.pool1 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        self.conv2 = tf.keras.layers.Conv2D(
            filters=64,
            kernel_size=[5, 5],
            padding='same',
            activation=tf.nn.relu
        )
        self.pool2 = tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)
        self.flatten = tf.keras.layers.Reshape(target_shape=(7 * 7 * 64,))
        self.dense1 = tf.keras.layers.Dense(units=1024, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(units=10)

    def call(self, inputs):
        x = self.conv1(inputs)                  # [batch_size, 28, 28, 32]
        x = self.pool1(x)                       # [batch_size, 14, 14, 32]
        x = self.conv2(x)                       # [batch_size, 14, 14, 64]
        x = self.pool2(x)                       # [batch_size, 7, 7, 64]
        x = self.flatten(x)                     # [batch_size, 7 * 7 * 64]
        x = self.dense1(x)                      # [batch_size, 1024]
        x = self.dense2(x)                      # [batch_size, 10]
        output = tf.nn.softmax(x)
        return output

In [12]:
modelCNN = CNN()

In [13]:
num_epochs = 5
batch_size = 50
learning_rate = 0.001

In [14]:
model = MLP()
data_loader = MNISTLoader()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

In [15]:
num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

batch 0: loss 2.396172
batch 1: loss 2.301820
batch 2: loss 2.206160
batch 3: loss 2.165649
batch 4: loss 2.141267
batch 5: loss 2.086282
batch 6: loss 1.957843
batch 7: loss 1.949489
batch 8: loss 1.818043
batch 9: loss 1.704629
batch 10: loss 1.711488
batch 11: loss 1.727079
batch 12: loss 1.677685
batch 13: loss 1.547471
batch 14: loss 1.432473
batch 15: loss 1.546335
batch 16: loss 1.529451
batch 17: loss 1.452398
batch 18: loss 1.457676
batch 19: loss 1.323920
batch 20: loss 1.156933
batch 21: loss 1.253210
batch 22: loss 1.163003
batch 23: loss 1.329647
batch 24: loss 1.223069
batch 25: loss 1.077659
batch 26: loss 1.068201
batch 27: loss 1.015373
batch 28: loss 1.052594
batch 29: loss 1.028634
batch 30: loss 0.971218
batch 31: loss 0.986218
batch 32: loss 0.970325
batch 33: loss 0.892468
batch 34: loss 0.880930
batch 35: loss 0.805289
batch 36: loss 0.876451
batch 37: loss 0.856156
batch 38: loss 0.920523
batch 39: loss 0.928826
batch 40: loss 0.807605
batch 41: loss 0.598919
ba

batch 336: loss 0.219123
batch 337: loss 0.188114
batch 338: loss 0.116074
batch 339: loss 0.175368
batch 340: loss 0.474938
batch 341: loss 0.178944
batch 342: loss 0.429519
batch 343: loss 0.332504
batch 344: loss 0.191827
batch 345: loss 0.514219
batch 346: loss 0.498515
batch 347: loss 0.475369
batch 348: loss 0.223633
batch 349: loss 0.351792
batch 350: loss 0.299564
batch 351: loss 0.277855
batch 352: loss 0.243060
batch 353: loss 0.407063
batch 354: loss 0.294628
batch 355: loss 0.267463
batch 356: loss 0.457696
batch 357: loss 0.329360
batch 358: loss 0.512212
batch 359: loss 0.316043
batch 360: loss 0.174431
batch 361: loss 0.225253
batch 362: loss 0.138698
batch 363: loss 0.256407
batch 364: loss 0.237275
batch 365: loss 0.223827
batch 366: loss 0.138545
batch 367: loss 0.153996
batch 368: loss 0.213599
batch 369: loss 0.297388
batch 370: loss 0.261974
batch 371: loss 0.227705
batch 372: loss 0.317586
batch 373: loss 0.288888
batch 374: loss 0.176673
batch 375: loss 0.370153


batch 672: loss 0.083702
batch 673: loss 0.062046
batch 674: loss 0.157105
batch 675: loss 0.176974
batch 676: loss 0.257782
batch 677: loss 0.186977
batch 678: loss 0.144210
batch 679: loss 0.170667
batch 680: loss 0.154602
batch 681: loss 0.109601
batch 682: loss 0.498901
batch 683: loss 0.140099
batch 684: loss 0.114049
batch 685: loss 0.182271
batch 686: loss 0.339528
batch 687: loss 0.193452
batch 688: loss 0.138961
batch 689: loss 0.215980
batch 690: loss 0.251689
batch 691: loss 0.118931
batch 692: loss 0.221623
batch 693: loss 0.275397
batch 694: loss 0.202496
batch 695: loss 0.373201
batch 696: loss 0.124747
batch 697: loss 0.338553
batch 698: loss 0.335709
batch 699: loss 0.251377
batch 700: loss 0.140034
batch 701: loss 0.247119
batch 702: loss 0.111636
batch 703: loss 0.137635
batch 704: loss 0.248955
batch 705: loss 0.158482
batch 706: loss 0.166410
batch 707: loss 0.180451
batch 708: loss 0.062551
batch 709: loss 0.373578
batch 710: loss 0.080802
batch 711: loss 0.068056


batch 1009: loss 0.156234
batch 1010: loss 0.109541
batch 1011: loss 0.112047
batch 1012: loss 0.070254
batch 1013: loss 0.190675
batch 1014: loss 0.199943
batch 1015: loss 0.145734
batch 1016: loss 0.192558
batch 1017: loss 0.107726
batch 1018: loss 0.341545
batch 1019: loss 0.219242
batch 1020: loss 0.195877
batch 1021: loss 0.266778
batch 1022: loss 0.105959
batch 1023: loss 0.114902
batch 1024: loss 0.124128
batch 1025: loss 0.195042
batch 1026: loss 0.237744
batch 1027: loss 0.134760
batch 1028: loss 0.091405
batch 1029: loss 0.094443
batch 1030: loss 0.575960
batch 1031: loss 0.217058
batch 1032: loss 0.155264
batch 1033: loss 0.070170
batch 1034: loss 0.179063
batch 1035: loss 0.133224
batch 1036: loss 0.252399
batch 1037: loss 0.315145
batch 1038: loss 0.308006
batch 1039: loss 0.054328
batch 1040: loss 0.096342
batch 1041: loss 0.206589
batch 1042: loss 0.051051
batch 1043: loss 0.103626
batch 1044: loss 0.115388
batch 1045: loss 0.036220
batch 1046: loss 0.149905
batch 1047: 

batch 1341: loss 0.154748
batch 1342: loss 0.087963
batch 1343: loss 0.133031
batch 1344: loss 0.142947
batch 1345: loss 0.127596
batch 1346: loss 0.133709
batch 1347: loss 0.097252
batch 1348: loss 0.229546
batch 1349: loss 0.233560
batch 1350: loss 0.072082
batch 1351: loss 0.096195
batch 1352: loss 0.156937
batch 1353: loss 0.074492
batch 1354: loss 0.140459
batch 1355: loss 0.132774
batch 1356: loss 0.082736
batch 1357: loss 0.134929
batch 1358: loss 0.075598
batch 1359: loss 0.103545
batch 1360: loss 0.085108
batch 1361: loss 0.203485
batch 1362: loss 0.072804
batch 1363: loss 0.127868
batch 1364: loss 0.319042
batch 1365: loss 0.071953
batch 1366: loss 0.040803
batch 1367: loss 0.233312
batch 1368: loss 0.200146
batch 1369: loss 0.133670
batch 1370: loss 0.256150
batch 1371: loss 0.131899
batch 1372: loss 0.061795
batch 1373: loss 0.075071
batch 1374: loss 0.193443
batch 1375: loss 0.180886
batch 1376: loss 0.140202
batch 1377: loss 0.265849
batch 1378: loss 0.311208
batch 1379: 

batch 1679: loss 0.234198
batch 1680: loss 0.046610
batch 1681: loss 0.060960
batch 1682: loss 0.151495
batch 1683: loss 0.108967
batch 1684: loss 0.137659
batch 1685: loss 0.049222
batch 1686: loss 0.222519
batch 1687: loss 0.077001
batch 1688: loss 0.134183
batch 1689: loss 0.043086
batch 1690: loss 0.192661
batch 1691: loss 0.231231
batch 1692: loss 0.087949
batch 1693: loss 0.211083
batch 1694: loss 0.222689
batch 1695: loss 0.172486
batch 1696: loss 0.075634
batch 1697: loss 0.252229
batch 1698: loss 0.077829
batch 1699: loss 0.131433
batch 1700: loss 0.141713
batch 1701: loss 0.231697
batch 1702: loss 0.283069
batch 1703: loss 0.072943
batch 1704: loss 0.049865
batch 1705: loss 0.098785
batch 1706: loss 0.135179
batch 1707: loss 0.227527
batch 1708: loss 0.341239
batch 1709: loss 0.253548
batch 1710: loss 0.118332
batch 1711: loss 0.048732
batch 1712: loss 0.078513
batch 1713: loss 0.163374
batch 1714: loss 0.206944
batch 1715: loss 0.130391
batch 1716: loss 0.170963
batch 1717: 

batch 2018: loss 0.073631
batch 2019: loss 0.068210
batch 2020: loss 0.353305
batch 2021: loss 0.141389
batch 2022: loss 0.114661
batch 2023: loss 0.216378
batch 2024: loss 0.067543
batch 2025: loss 0.298857
batch 2026: loss 0.059981
batch 2027: loss 0.065037
batch 2028: loss 0.033319
batch 2029: loss 0.039643
batch 2030: loss 0.295403
batch 2031: loss 0.111665
batch 2032: loss 0.118637
batch 2033: loss 0.025287
batch 2034: loss 0.146813
batch 2035: loss 0.075870
batch 2036: loss 0.316391
batch 2037: loss 0.052660
batch 2038: loss 0.160141
batch 2039: loss 0.059286
batch 2040: loss 0.097738
batch 2041: loss 0.039095
batch 2042: loss 0.077879
batch 2043: loss 0.150620
batch 2044: loss 0.068997
batch 2045: loss 0.091090
batch 2046: loss 0.085704
batch 2047: loss 0.063647
batch 2048: loss 0.120590
batch 2049: loss 0.070650
batch 2050: loss 0.091913
batch 2051: loss 0.109092
batch 2052: loss 0.199356
batch 2053: loss 0.133844
batch 2054: loss 0.121111
batch 2055: loss 0.097331
batch 2056: 

batch 2355: loss 0.130350
batch 2356: loss 0.246310
batch 2357: loss 0.034481
batch 2358: loss 0.020769
batch 2359: loss 0.121668
batch 2360: loss 0.063215
batch 2361: loss 0.049488
batch 2362: loss 0.202264
batch 2363: loss 0.031086
batch 2364: loss 0.060104
batch 2365: loss 0.047621
batch 2366: loss 0.154475
batch 2367: loss 0.317391
batch 2368: loss 0.017888
batch 2369: loss 0.135021
batch 2370: loss 0.067062
batch 2371: loss 0.094890
batch 2372: loss 0.037344
batch 2373: loss 0.153989
batch 2374: loss 0.190612
batch 2375: loss 0.274292
batch 2376: loss 0.094432
batch 2377: loss 0.141967
batch 2378: loss 0.027714
batch 2379: loss 0.156208
batch 2380: loss 0.114741
batch 2381: loss 0.080494
batch 2382: loss 0.098100
batch 2383: loss 0.133565
batch 2384: loss 0.023946
batch 2385: loss 0.073382
batch 2386: loss 0.094195
batch 2387: loss 0.027692
batch 2388: loss 0.044630
batch 2389: loss 0.062117
batch 2390: loss 0.065014
batch 2391: loss 0.087378
batch 2392: loss 0.056802
batch 2393: 

batch 2695: loss 0.038093
batch 2696: loss 0.284340
batch 2697: loss 0.109086
batch 2698: loss 0.089192
batch 2699: loss 0.108856
batch 2700: loss 0.043901
batch 2701: loss 0.110045
batch 2702: loss 0.068297
batch 2703: loss 0.069473
batch 2704: loss 0.053298
batch 2705: loss 0.057842
batch 2706: loss 0.066053
batch 2707: loss 0.043189
batch 2708: loss 0.075837
batch 2709: loss 0.043430
batch 2710: loss 0.064378
batch 2711: loss 0.028908
batch 2712: loss 0.202399
batch 2713: loss 0.096580
batch 2714: loss 0.029390
batch 2715: loss 0.043699
batch 2716: loss 0.091654
batch 2717: loss 0.150475
batch 2718: loss 0.107252
batch 2719: loss 0.075156
batch 2720: loss 0.276152
batch 2721: loss 0.133526
batch 2722: loss 0.031911
batch 2723: loss 0.154928
batch 2724: loss 0.169524
batch 2725: loss 0.108939
batch 2726: loss 0.199170
batch 2727: loss 0.114946
batch 2728: loss 0.137480
batch 2729: loss 0.088153
batch 2730: loss 0.118776
batch 2731: loss 0.050511
batch 2732: loss 0.044979
batch 2733: 

batch 3030: loss 0.046887
batch 3031: loss 0.082227
batch 3032: loss 0.024574
batch 3033: loss 0.137155
batch 3034: loss 0.038000
batch 3035: loss 0.017864
batch 3036: loss 0.069539
batch 3037: loss 0.052872
batch 3038: loss 0.103921
batch 3039: loss 0.105504
batch 3040: loss 0.100022
batch 3041: loss 0.043063
batch 3042: loss 0.189763
batch 3043: loss 0.048018
batch 3044: loss 0.039081
batch 3045: loss 0.033843
batch 3046: loss 0.086513
batch 3047: loss 0.150366
batch 3048: loss 0.018937
batch 3049: loss 0.073106
batch 3050: loss 0.200702
batch 3051: loss 0.206837
batch 3052: loss 0.075422
batch 3053: loss 0.073749
batch 3054: loss 0.106359
batch 3055: loss 0.022642
batch 3056: loss 0.055619
batch 3057: loss 0.140377
batch 3058: loss 0.067899
batch 3059: loss 0.244004
batch 3060: loss 0.153798
batch 3061: loss 0.095745
batch 3062: loss 0.057028
batch 3063: loss 0.056478
batch 3064: loss 0.043031
batch 3065: loss 0.045408
batch 3066: loss 0.078359
batch 3067: loss 0.164106
batch 3068: 

batch 3366: loss 0.041670
batch 3367: loss 0.064618
batch 3368: loss 0.041957
batch 3369: loss 0.049870
batch 3370: loss 0.060263
batch 3371: loss 0.060886
batch 3372: loss 0.174063
batch 3373: loss 0.052726
batch 3374: loss 0.044236
batch 3375: loss 0.025836
batch 3376: loss 0.161467
batch 3377: loss 0.031406
batch 3378: loss 0.139448
batch 3379: loss 0.069558
batch 3380: loss 0.012331
batch 3381: loss 0.227151
batch 3382: loss 0.053731
batch 3383: loss 0.098084
batch 3384: loss 0.178142
batch 3385: loss 0.083397
batch 3386: loss 0.049928
batch 3387: loss 0.082797
batch 3388: loss 0.024421
batch 3389: loss 0.063308
batch 3390: loss 0.286704
batch 3391: loss 0.062117
batch 3392: loss 0.053257
batch 3393: loss 0.085264
batch 3394: loss 0.110868
batch 3395: loss 0.049224
batch 3396: loss 0.122820
batch 3397: loss 0.076596
batch 3398: loss 0.024318
batch 3399: loss 0.120803
batch 3400: loss 0.052204
batch 3401: loss 0.138065
batch 3402: loss 0.022345
batch 3403: loss 0.029151
batch 3404: 

batch 3702: loss 0.138025
batch 3703: loss 0.092022
batch 3704: loss 0.168532
batch 3705: loss 0.070502
batch 3706: loss 0.052274
batch 3707: loss 0.053715
batch 3708: loss 0.093549
batch 3709: loss 0.108607
batch 3710: loss 0.010725
batch 3711: loss 0.133316
batch 3712: loss 0.027444
batch 3713: loss 0.044935
batch 3714: loss 0.102989
batch 3715: loss 0.044246
batch 3716: loss 0.035496
batch 3717: loss 0.120179
batch 3718: loss 0.179447
batch 3719: loss 0.039822
batch 3720: loss 0.045079
batch 3721: loss 0.082695
batch 3722: loss 0.022994
batch 3723: loss 0.096556
batch 3724: loss 0.146626
batch 3725: loss 0.023665
batch 3726: loss 0.060866
batch 3727: loss 0.122735
batch 3728: loss 0.053382
batch 3729: loss 0.050571
batch 3730: loss 0.025430
batch 3731: loss 0.015605
batch 3732: loss 0.080272
batch 3733: loss 0.026290
batch 3734: loss 0.098870
batch 3735: loss 0.028636
batch 3736: loss 0.098575
batch 3737: loss 0.053111
batch 3738: loss 0.153110
batch 3739: loss 0.032034
batch 3740: 

batch 4042: loss 0.027052
batch 4043: loss 0.049749
batch 4044: loss 0.073173
batch 4045: loss 0.128604
batch 4046: loss 0.073971
batch 4047: loss 0.035247
batch 4048: loss 0.025738
batch 4049: loss 0.024741
batch 4050: loss 0.019606
batch 4051: loss 0.119411
batch 4052: loss 0.037069
batch 4053: loss 0.067469
batch 4054: loss 0.053603
batch 4055: loss 0.027569
batch 4056: loss 0.079934
batch 4057: loss 0.082109
batch 4058: loss 0.049387
batch 4059: loss 0.276429
batch 4060: loss 0.075579
batch 4061: loss 0.052072
batch 4062: loss 0.089996
batch 4063: loss 0.048676
batch 4064: loss 0.043466
batch 4065: loss 0.021156
batch 4066: loss 0.004475
batch 4067: loss 0.084687
batch 4068: loss 0.058513
batch 4069: loss 0.060451
batch 4070: loss 0.071697
batch 4071: loss 0.046369
batch 4072: loss 0.037537
batch 4073: loss 0.049000
batch 4074: loss 0.019922
batch 4075: loss 0.008100
batch 4076: loss 0.047846
batch 4077: loss 0.040992
batch 4078: loss 0.081566
batch 4079: loss 0.043458
batch 4080: 

batch 4381: loss 0.148139
batch 4382: loss 0.053431
batch 4383: loss 0.050869
batch 4384: loss 0.049054
batch 4385: loss 0.082399
batch 4386: loss 0.038735
batch 4387: loss 0.127622
batch 4388: loss 0.020856
batch 4389: loss 0.013505
batch 4390: loss 0.094363
batch 4391: loss 0.025532
batch 4392: loss 0.057302
batch 4393: loss 0.046809
batch 4394: loss 0.051599
batch 4395: loss 0.014929
batch 4396: loss 0.105968
batch 4397: loss 0.021007
batch 4398: loss 0.019068
batch 4399: loss 0.077884
batch 4400: loss 0.055025
batch 4401: loss 0.029480
batch 4402: loss 0.019754
batch 4403: loss 0.018744
batch 4404: loss 0.011391
batch 4405: loss 0.037763
batch 4406: loss 0.134452
batch 4407: loss 0.012631
batch 4408: loss 0.089242
batch 4409: loss 0.173244
batch 4410: loss 0.082721
batch 4411: loss 0.076199
batch 4412: loss 0.012935
batch 4413: loss 0.173199
batch 4414: loss 0.067037
batch 4415: loss 0.012305
batch 4416: loss 0.323236
batch 4417: loss 0.065713
batch 4418: loss 0.093216
batch 4419: 

batch 4719: loss 0.071734
batch 4720: loss 0.068976
batch 4721: loss 0.022084
batch 4722: loss 0.058616
batch 4723: loss 0.030503
batch 4724: loss 0.013907
batch 4725: loss 0.057147
batch 4726: loss 0.065345
batch 4727: loss 0.035268
batch 4728: loss 0.058688
batch 4729: loss 0.084060
batch 4730: loss 0.062663
batch 4731: loss 0.041848
batch 4732: loss 0.034755
batch 4733: loss 0.138174
batch 4734: loss 0.072428
batch 4735: loss 0.125636
batch 4736: loss 0.021028
batch 4737: loss 0.098666
batch 4738: loss 0.050486
batch 4739: loss 0.041377
batch 4740: loss 0.019627
batch 4741: loss 0.089092
batch 4742: loss 0.017630
batch 4743: loss 0.209464
batch 4744: loss 0.002932
batch 4745: loss 0.030447
batch 4746: loss 0.036443
batch 4747: loss 0.028712
batch 4748: loss 0.134750
batch 4749: loss 0.021205
batch 4750: loss 0.012298
batch 4751: loss 0.066959
batch 4752: loss 0.006932
batch 4753: loss 0.031335
batch 4754: loss 0.053612
batch 4755: loss 0.134308
batch 4756: loss 0.077991
batch 4757: 

batch 5055: loss 0.044513
batch 5056: loss 0.005869
batch 5057: loss 0.070658
batch 5058: loss 0.068659
batch 5059: loss 0.024078
batch 5060: loss 0.036382
batch 5061: loss 0.097462
batch 5062: loss 0.022070
batch 5063: loss 0.015613
batch 5064: loss 0.048467
batch 5065: loss 0.030624
batch 5066: loss 0.055006
batch 5067: loss 0.063727
batch 5068: loss 0.024763
batch 5069: loss 0.031105
batch 5070: loss 0.011099
batch 5071: loss 0.083443
batch 5072: loss 0.011884
batch 5073: loss 0.011722
batch 5074: loss 0.041276
batch 5075: loss 0.026198
batch 5076: loss 0.015449
batch 5077: loss 0.035771
batch 5078: loss 0.010049
batch 5079: loss 0.078552
batch 5080: loss 0.036374
batch 5081: loss 0.009637
batch 5082: loss 0.022656
batch 5083: loss 0.011265
batch 5084: loss 0.044628
batch 5085: loss 0.050601
batch 5086: loss 0.025882
batch 5087: loss 0.008636
batch 5088: loss 0.020947
batch 5089: loss 0.054255
batch 5090: loss 0.090393
batch 5091: loss 0.025827
batch 5092: loss 0.011637
batch 5093: 

batch 5391: loss 0.032586
batch 5392: loss 0.061131
batch 5393: loss 0.070526
batch 5394: loss 0.027915
batch 5395: loss 0.015393
batch 5396: loss 0.022969
batch 5397: loss 0.008124
batch 5398: loss 0.017162
batch 5399: loss 0.024092
batch 5400: loss 0.122628
batch 5401: loss 0.013741
batch 5402: loss 0.145304
batch 5403: loss 0.006706
batch 5404: loss 0.061381
batch 5405: loss 0.005992
batch 5406: loss 0.101179
batch 5407: loss 0.030472
batch 5408: loss 0.046592
batch 5409: loss 0.134839
batch 5410: loss 0.059967
batch 5411: loss 0.035152
batch 5412: loss 0.012176
batch 5413: loss 0.021970
batch 5414: loss 0.014461
batch 5415: loss 0.046102
batch 5416: loss 0.106599
batch 5417: loss 0.068721
batch 5418: loss 0.023385
batch 5419: loss 0.048124
batch 5420: loss 0.039097
batch 5421: loss 0.028251
batch 5422: loss 0.052256
batch 5423: loss 0.041946
batch 5424: loss 0.183886
batch 5425: loss 0.078213
batch 5426: loss 0.108326
batch 5427: loss 0.035426
batch 5428: loss 0.018738
batch 5429: 

batch 5729: loss 0.079636
batch 5730: loss 0.087026
batch 5731: loss 0.032402
batch 5732: loss 0.025816
batch 5733: loss 0.011451
batch 5734: loss 0.032417
batch 5735: loss 0.084475
batch 5736: loss 0.013100
batch 5737: loss 0.006147
batch 5738: loss 0.011537
batch 5739: loss 0.038000
batch 5740: loss 0.005562
batch 5741: loss 0.018352
batch 5742: loss 0.121740
batch 5743: loss 0.015971
batch 5744: loss 0.031160
batch 5745: loss 0.060675
batch 5746: loss 0.011252
batch 5747: loss 0.037928
batch 5748: loss 0.005231
batch 5749: loss 0.014073
batch 5750: loss 0.012757
batch 5751: loss 0.289059
batch 5752: loss 0.046656
batch 5753: loss 0.060204
batch 5754: loss 0.014990
batch 5755: loss 0.029703
batch 5756: loss 0.015125
batch 5757: loss 0.022769
batch 5758: loss 0.018195
batch 5759: loss 0.024059
batch 5760: loss 0.036038
batch 5761: loss 0.019679
batch 5762: loss 0.058995
batch 5763: loss 0.028026
batch 5764: loss 0.084823
batch 5765: loss 0.019581
batch 5766: loss 0.055665
batch 5767: 

In [16]:
sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
num_batches = int(data_loader.num_test_data // batch_size)
for batch_index in range(num_batches):
    start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
    y_pred = model.predict(data_loader.test_data[start_index: end_index])
    sparse_categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index],y_pred=y_pred)
print("test accuracy is: %f" % sparse_categorical_accuracy.result())

test accuracy is: 0.972600


tf.keras.applications有预设好的经典卷积神经网络

In [17]:
modelNetV2 = tf.keras.applications.MobileNetV2()

In [18]:
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

batch 0: loss 0.097656
batch 1: loss 0.024151
batch 2: loss 0.060104
batch 3: loss 0.061753
batch 4: loss 0.032853
batch 5: loss 0.035715
batch 6: loss 0.057404
batch 7: loss 0.015279
batch 8: loss 0.016809
batch 9: loss 0.030208
batch 10: loss 0.019344
batch 11: loss 0.022033
batch 12: loss 0.035137
batch 13: loss 0.063398
batch 14: loss 0.006787
batch 15: loss 0.019287
batch 16: loss 0.016559
batch 17: loss 0.046116
batch 18: loss 0.082925
batch 19: loss 0.069273
batch 20: loss 0.054044
batch 21: loss 0.028652
batch 22: loss 0.011470
batch 23: loss 0.083892
batch 24: loss 0.018970
batch 25: loss 0.174090
batch 26: loss 0.066221
batch 27: loss 0.014929
batch 28: loss 0.020009
batch 29: loss 0.018535
batch 30: loss 0.153087
batch 31: loss 0.038780
batch 32: loss 0.036895
batch 33: loss 0.018429
batch 34: loss 0.105709
batch 35: loss 0.047600
batch 36: loss 0.080029
batch 37: loss 0.007364
batch 38: loss 0.007831
batch 39: loss 0.012204
batch 40: loss 0.013503
batch 41: loss 0.009332
ba

In [19]:
sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
num_batches = int(data_loader.num_test_data // batch_size)
for batch_index in range(num_batches):
    start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
    y_pred = model.predict(data_loader.test_data[start_index: end_index])
    sparse_categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index],y_pred=y_pred)
print("test accuracy is: %f" % sparse_categorical_accuracy.result())

test accuracy is: 0.972500


#### RNN

In [77]:
class DataLoader():
    def __init__(self):
        path = tf.keras.utils.get_file('nietzsche.txt',
            origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
        with open(path, encoding='utf-8') as f:
            self.raw_text = f.read().lower()
        self.chars = sorted(list(set(self.raw_text)))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
        self.text = [self.char_indices[c] for c in self.raw_text]

    def get_batch(self, seq_length, batch_size):
        seq = []
        next_char = []
        for i in range(batch_size):
            index = np.random.randint(0, len(self.text) - seq_length)
            seq.append(self.text[index:index+seq_length])
            next_char.append(self.text[index+seq_length])
        return np.array(seq), np.array(next_char)       # [batch_size, seq_length], [num_batch]

In [79]:
class RNN(tf.keras.Model):
    def __init__(self, num_chars, batch_size, seq_length):
        super().__init__()
        self.num_chars = num_chars
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.cell = tf.keras.layers.LSTMCell(units=256)
        self.dense = tf.keras.layers.Dense(units=self.num_chars)

    def call(self, inputs, from_logits=False):
        inputs = tf.one_hot(inputs, depth=self.num_chars)       # [batch_size, seq_length, num_chars]
        state = self.cell.get_initial_state(batch_size=self.batch_size, dtype=tf.float32)
        for t in range(self.seq_length):
            output, state = self.cell(inputs[:, t, :], state)
        logits = self.dense(output)
        if from_logits:
            return logits
        else:
            return tf.nn.softmax(logits)

In [84]:
num_batches = 100
seq_length = 40
batch_size = 50
learning_rate = 1e-3

steps:
* randomly select data from `DataLoader`
* calculate the predict value
* calculate the loss function
* calculate the dirivertive of the function. `tape.gradient`
* `optimizer.apply_gradients` to update the parameters

In [85]:
data_loader = DataLoader()

In [86]:
model = RNN(num_chars=len(data_loader.chars), batch_size=batch_size, seq_length=seq_length)

In [87]:
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(seq_length, batch_size)
    with tf.GradientTape() as tape:
        y_pred = modelRNN(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_pred=y_pred, y_true=y)
        loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, modelRNN.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, modelRNN.variables))

batch 0: loss 2.300979
batch 1: loss 2.425753
batch 2: loss 2.269445
batch 3: loss 2.435592
batch 4: loss 2.547582
batch 5: loss 2.410161
batch 6: loss 2.166223
batch 7: loss 2.303907
batch 8: loss 2.346618
batch 9: loss 2.537049
batch 10: loss 2.499943
batch 11: loss 2.077685
batch 12: loss 2.248216
batch 13: loss 2.012670
batch 14: loss 2.319809
batch 15: loss 2.266413
batch 16: loss 2.519976
batch 17: loss 2.334644
batch 18: loss 2.093357
batch 19: loss 2.661996
batch 20: loss 1.888312
batch 21: loss 2.135050
batch 22: loss 2.407631
batch 23: loss 2.454175
batch 24: loss 2.599524
batch 25: loss 2.216748
batch 26: loss 2.301447
batch 27: loss 2.187305
batch 28: loss 2.659596
batch 29: loss 2.703160
batch 30: loss 2.532776
batch 31: loss 2.870480
batch 32: loss 2.358199
batch 33: loss 2.418718
batch 34: loss 2.341374
batch 35: loss 1.977418
batch 36: loss 2.228302
batch 37: loss 2.138907
batch 38: loss 2.348046
batch 39: loss 2.205188
batch 40: loss 2.309159
batch 41: loss 2.554825
ba

关于文本生成的过程有一点需要特别注意。之前，我们一直使用 tf.argmax() 函数，将对应概率最大的值作为预测值。然而对于文本生成而言，这样的预测方式过于绝对，会使得生成的文本失去丰富性。于是，我们使用 np.random.choice() 函数按照生成的概率分布取样。这样，即使是对应概率较小的字符，也有机会被取样到。同时，我们加入一个 temperature 参数控制分布的形状，参数值越大则分布越平缓（最大值和最小值的差值越小），生成文本的丰富度越高；参数值越小则分布越陡峭，生成文本的丰富度越低。

In [88]:
def predict(self, inputs, temperature=1.):
    batch_size, _ = tf.shape(inputs)
    logits = self(inputs, from_logits=True)
    prob = tf.nn.softmax(logits / temperature).numpy()
    return np.array([np.random.choice(self.num_chars, p=prob[i,:]) 
                     for i in range(batch_size.numpy())])

In [91]:
# X_, _ = data_loader.get_batch(seq_length, 1)
# for diversity in [0.2, 0.5, 1.0, 1.2]:
#     X = X_
#     print("diversity %f:" % diversity)
#     for t in range(400):
#         y_pred = modelRNN.predict(X, diversity)
#         print(data_loader.indices_char[y_pred[0]], end='', flush=True)
#         X = np.concatenate([X[:, 1:], np.expand_dims(y_pred, axis=1)], axis=-1)
#     print("\n")

#### Reinforcement Learning RL