In [1]:
import numpy as np 
import pandas as pd 
import tensorflow as tf

2024-04-23 10:24:35.568470: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
MAX_LEN = 300
BATCH_SIZE = 32
(x_train,y_train),(x_test,y_test) = tf.keras.datasets.reuters.load_data()


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/reuters.npz


[1m2110848/2110848[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 1us/step


In [23]:
x_train = tf.keras.utils.pad_sequences(x_train, maxlen=MAX_LEN, padding='post')
x_test = tf.keras.utils.pad_sequences(x_test, maxlen=MAX_LEN, padding='post')

MAX_WORDS = x_train.max()+1
CAT_NUM = y_train.max()+1

ds_train = tf.data.Dataset.from_tensor_slices((x_train,y_train)) \
          .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
          .prefetch(tf.data.experimental.AUTOTUNE).cache()
   
ds_test = tf.data.Dataset.from_tensor_slices((x_test,y_test)) \
          .shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
          .prefetch(tf.data.experimental.AUTOTUNE).cache()

In [30]:
tf.keras.backend.clear_session()
def create_model():
    
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Embedding(MAX_WORDS,7))
    model.add(tf.keras.layers.Conv1D(filters = 64,kernel_size = 5,activation = "relu"))
    model.add(tf.keras.layers.MaxPool1D(2))
    model.add(tf.keras.layers.Conv1D(filters = 32,kernel_size = 3,activation = "relu"))
    model.add(tf.keras.layers.MaxPool1D(2))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(CAT_NUM,activation = "softmax"))
    return(model)

def compile_model(model):
    model.compile(optimizer=tf.keras.optimizers.Nadam(),
                loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                metrics=[tf.keras.metrics.SparseCategoricalAccuracy(),tf.keras.metrics.SparseTopKCategoricalAccuracy(5)]) 
    return(model)
 
model = create_model()
model.summary()
model = compile_model(model)

## 一 内置fit方法

In [31]:
history = model.fit(ds_train,validation_data = ds_test,epochs = 10)

Epoch 1/10


[1m281/281[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 45ms/step - loss: 2.3573 - sparse_categorical_accuracy: 0.4047 - sparse_top_k_categorical_accuracy: 0.7162 - val_loss: 1.6515 - val_sparse_categorical_accuracy: 0.5726 - val_sparse_top_k_categorical_accuracy: 0.7622
Epoch 2/10
[1m281/281[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 42ms/step - loss: 1.5373 - sparse_categorical_accuracy: 0.6102 - sparse_top_k_categorical_accuracy: 0.7833 - val_loss: 1.5249 - val_sparse_categorical_accuracy: 0.6144 - val_sparse_top_k_categorical_accuracy: 0.7925
Epoch 3/10
[1m281/281[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 43ms/step - loss: 1.2061 - sparse_categorical_accuracy: 0.6865 - sparse_top_k_categorical_accuracy: 0.8550 - val_loss: 1.5744 - val_sparse_categorical_accuracy: 0.6394 - val_sparse_top_k_categorical_accuracy: 0.8170
Epoch 4/10
[1m281/281[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 40ms/step - loss: 0.8920 - sparse_categorical_accur

## 二，内置train_on_batch方法
该内置方法相比较fit方法更加灵活，可以不通过回调函数而直接在批次层次上更加精细地控制训练的过程。

In [35]:
def train_model(model,ds_train,ds_valid,epoches):

    for epoch in tf.range(1,epoches+1):
        model.reset_metrics()
        
        # 在后期降低学习率
        if epoch < 5:
            model.optimizer.learning_rate.assign(model.optimizer.learning_rate)
            tf.print("testing optimizer Learning Rate...\n\n")
        else:
            model.optimizer.learning_rate.assign(model.optimizer.learning_rate/2.0)
            tf.print("Lowering optimizer Learning Rate...\n\n")
        for x, y in ds_train:
            train_result = model.train_on_batch(x, y,return_dict=True)

        for x, y in ds_valid:
            valid_result = model.test_on_batch(x, y,return_dict=True)
            
        if epoch%1 ==0:
            tf.print("epoch = ",epoch)
            #print("train:",dict(zip(model.metrics_names,train_result)))
            #print("valid:",dict(zip(model.metrics_names,valid_result)))
            print("train:", train_result)
            print("valid:", valid_result)
train_model(model,ds_train,ds_test,10)

testing optimizer Learning Rate...




2024-04-23 11:39:33.342761: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


epoch =  1
train: {'loss': array(0.12660909, dtype=float32), 'sparse_categorical_accuracy': array(0.95479846, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.9997773, dtype=float32)}
valid: {'loss': array(0.829089, dtype=float32), 'sparse_categorical_accuracy': array(0.87887424, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.96099037, dtype=float32)}
testing optimizer Learning Rate...




2024-04-23 11:39:52.626016: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 11:43:33.242365: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


epoch =  2
train: {'loss': array(0.12136879, dtype=float32), 'sparse_categorical_accuracy': array(0.9540191, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.9997773, dtype=float32)}
valid: {'loss': array(0.856765, dtype=float32), 'sparse_categorical_accuracy': array(0.8768258, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.9609013, dtype=float32)}
testing optimizer Learning Rate...




2024-04-23 11:43:52.768618: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 11:48:57.172990: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


epoch =  3
train: {'loss': array(0.11697056, dtype=float32), 'sparse_categorical_accuracy': array(0.9539078, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.9997773, dtype=float32)}
valid: {'loss': array(0.87555087, dtype=float32), 'sparse_categorical_accuracy': array(0.87468827, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.96099037, dtype=float32)}
testing optimizer Learning Rate...




2024-04-23 11:49:28.020061: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 11:55:05.783566: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


epoch =  4
train: {'loss': array(0.1108659, dtype=float32), 'sparse_categorical_accuracy': array(0.9558005, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.999666, dtype=float32)}
valid: {'loss': array(0.8986177, dtype=float32), 'sparse_categorical_accuracy': array(0.8762914, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.96099037, dtype=float32)}
Lowering optimizer Learning Rate...




2024-04-23 11:56:26.988796: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 12:01:17.241393: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 12:01:53.663343: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


epoch =  5
train: {'loss': array(0.10228459, dtype=float32), 'sparse_categorical_accuracy': array(0.9613672, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.99988866, dtype=float32)}
valid: {'loss': array(0.9096785, dtype=float32), 'sparse_categorical_accuracy': array(0.88822585, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.96205914, dtype=float32)}
Lowering optimizer Learning Rate...




2024-04-23 12:06:23.593160: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


epoch =  6
train: {'loss': array(0.09321329, dtype=float32), 'sparse_categorical_accuracy': array(0.9618125, dtype=float32), 'sparse_top_k_categorical_accuracy': array(1., dtype=float32)}
valid: {'loss': array(0.87403274, dtype=float32), 'sparse_categorical_accuracy': array(0.8924118, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.96250445, dtype=float32)}
Lowering optimizer Learning Rate...




2024-04-23 12:06:45.973081: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 12:10:29.730996: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


epoch =  7
train: {'loss': array(0.08481575, dtype=float32), 'sparse_categorical_accuracy': array(0.9638165, dtype=float32), 'sparse_top_k_categorical_accuracy': array(1., dtype=float32)}
valid: {'loss': array(0.86022115, dtype=float32), 'sparse_categorical_accuracy': array(0.89330244, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.9625935, dtype=float32)}
Lowering optimizer Learning Rate...




2024-04-23 12:10:46.698953: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 12:14:48.342268: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


epoch =  8
train: {'loss': array(0.07914669, dtype=float32), 'sparse_categorical_accuracy': array(0.96526384, dtype=float32), 'sparse_top_k_categorical_accuracy': array(1., dtype=float32)}
valid: {'loss': array(0.8632489, dtype=float32), 'sparse_categorical_accuracy': array(0.8948165, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.9625935, dtype=float32)}
Lowering optimizer Learning Rate...




2024-04-23 12:22:09.934030: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 12:27:23.129998: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


epoch =  9
train: {'loss': array(0.07376754, dtype=float32), 'sparse_categorical_accuracy': array(0.96559787, dtype=float32), 'sparse_top_k_categorical_accuracy': array(1., dtype=float32)}
valid: {'loss': array(0.863851, dtype=float32), 'sparse_categorical_accuracy': array(0.89535093, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.96223724, dtype=float32)}
Lowering optimizer Learning Rate...




2024-04-23 12:27:41.761622: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 12:33:49.262331: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


epoch =  10
train: {'loss': array(0.07008261, dtype=float32), 'sparse_categorical_accuracy': array(0.9661545, dtype=float32), 'sparse_top_k_categorical_accuracy': array(1., dtype=float32)}
valid: {'loss': array(0.8640161, dtype=float32), 'sparse_categorical_accuracy': array(0.8957072, dtype=float32), 'sparse_top_k_categorical_accuracy': array(0.9621482, dtype=float32)}


2024-04-23 12:34:26.910381: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


## 三，自定义训练循环
自定义训练循环无需编译模型，直接利用优化器根据损失函数反向传播迭代参数，拥有最高的灵活性。

In [37]:
optimizer = tf.keras.optimizers.Nadam()
loss_func = tf.keras.losses.SparseCategoricalCrossentropy()

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

valid_loss = tf.keras.metrics.Mean(name='valid_loss')
valid_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='valid_accuracy')

@tf.function
def train_step(model, features, labels):
    with tf.GradientTape() as tape:
        predictions = model(features,training = True)
        loss = loss_func(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss.update_state(loss)
    train_metric.update_state(labels, predictions)
    

@tf.function
def valid_step(model, features, labels):
    predictions = model(features)
    batch_loss = loss_func(labels, predictions)
    valid_loss.update_state(batch_loss)
    valid_metric.update_state(labels, predictions)
    

def train_model(model,ds_train,ds_valid,epochs):
    for epoch in tf.range(1,epochs+1):
        
        for features, labels in ds_train:
            train_step(model,features,labels)

        for features, labels in ds_valid:
            valid_step(model,features,labels)

        logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}'
        
        if epoch%1 ==0:
            tf.print(tf.strings.format(
                logs,
                (epoch,train_loss.result(),train_metric.result(),valid_loss.result(),valid_metric.result())
            ))
            tf.print("")
            
        train_loss.reset_state()
        valid_loss.reset_state()
        train_metric.reset_state()
        valid_metric.reset_state()

train_model(model,ds_train,ds_test,10)



2024-04-23 13:05:07.247657: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch=1,Loss:0.106880799,Accuracy:0.959251821,Valid Loss:4.28723955,Valid Accuracy:0.562778294



2024-04-23 13:05:08.182491: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 13:05:16.623141: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch=2,Loss:0.0981276706,Accuracy:0.962146521,Valid Loss:4.33151579,Valid Accuracy:0.57435441



2024-04-23 13:05:18.793666: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 13:05:26.847513: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch=3,Loss:0.0937262177,Accuracy:0.962146521,Valid Loss:4.46681833,Valid Accuracy:0.573463917



2024-04-23 13:05:27.236776: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 13:05:35.346675: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch=4,Loss:0.0959016,Accuracy:0.962480545,Valid Loss:4.64248037,Valid Accuracy:0.567675889



2024-04-23 13:05:35.733429: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 13:05:43.538831: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch=5,Loss:0.0910790786,Accuracy:0.963927865,Valid Loss:4.90510798,Valid Accuracy:0.575244904



2024-04-23 13:05:43.941590: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 13:05:52.328652: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch=6,Loss:0.0933086872,Accuracy:0.962925851,Valid Loss:5.01254797,Valid Accuracy:0.568121076



2024-04-23 13:05:52.744401: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 13:06:01.219411: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch=7,Loss:0.0960104,Accuracy:0.961923838,Valid Loss:4.93034363,Valid Accuracy:0.569902062



2024-04-23 13:06:01.604873: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 13:06:09.843056: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch=8,Loss:0.0885622576,Accuracy:0.963259876,Valid Loss:4.97358465,Valid Accuracy:0.570792496



2024-04-23 13:06:10.238906: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 13:06:18.209178: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch=9,Loss:0.0884829164,Accuracy:0.962369204,Valid Loss:5.22685385,Valid Accuracy:0.570792496



2024-04-23 13:06:18.610751: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-23 13:06:26.674296: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch=10,Loss:0.0899516,Accuracy:0.960921824,Valid Loss:5.42199,Valid Accuracy:0.561442554



2024-04-23 13:06:29.304567: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
