In [1]:
import tensorflow as tf
from tensorflow import keras

网络训练过程常需要统计准确率、召回率等指标，Keras提供了常用的测量工具，位于`keras.metrics`模块中

Keras测量工具使用步骤如下：
1. 新建测量器
2. 写入数据
3. 读取统计数据
4. 清零测量器

`keras.metrics`提供了较多常用测量器类，如统计均值的Mean类、准确率的Accuracy类、余弦相似度的CosineSimilarity类等。

## 1.均值测量器例子
以Mean测量器为例，我们统计每个step的平均误差：

In [2]:
def preprocess(x,y):
    # 调用此函数会自动传入x，y
    # 标准化到0~1
    x=tf.cast(x,dtype=tf.float32)/255.
    x=tf.reshape(x,[-1,28*28]) # 打平
    y=tf.cast(y,dtype=tf.int32) # 转换成整型张量
    y=tf.one_hot(y,depth=10) # 进行one-hot编码
    return x,y

def load_data():
    # 加载MNIST
    (x,y),(x_val,y_val)=keras.datasets.mnist.load_data()
    batchsz=512
    # 构建数据集对象
    train_dataset=tf.data.Dataset.from_tensor_slices((x,y))
    train_dataset=train_dataset.shuffle(1000)
    #批量训练
    train_dataset=train_dataset.batch(batchsz)
    train_dataset=train_dataset.map(preprocess)
    train_dataset=train_dataset.repeat(20)

    # 加载验证/测试集
    val_dataset=tf.data.Dataset.from_tensor_slices((x_val,y_val))
    val_dataset=val_dataset.shuffle(1000).batch(batchsz).map(preprocess)
    return train_dataset,val_dataset

class MyModel(keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1=keras.layers.Dense(256,activation='relu')
        self.fc2=keras.layers.Dense(128,activation='relu')
        self.fc3=keras.layers.Dense(10,activation='relu')

    def call(self, inputs, training=None, mask=None):
        x=self.fc1(inputs)
        x=self.fc2(x)
        x=self.fc3(x)
        return x

model=MyModel()
model.build(input_shape=(None,28*28))

optimizer=keras.optimizers.RMSprop(0.001)
train_dataset,val_dataset=load_data()
loss_meter=keras.metrics.Mean() # 1.新建平均测量器

for epoch in range(2):
    for step,(x,y) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            out=model(x)
            loss=tf.losses.MSE(y,out)
            mean_mse_loss=tf.reduce_mean(loss)
            loss_meter.update_state(loss) # 2.写入数据，记录采样的数据
        grads=tape.gradient(mean_mse_loss,model.trainable_variables)
        optimizer.apply_gradients(zip(grads,model.trainable_variables))

        if step%500==0:
            print(f'epoch[{epoch}]-step[{step}] loss:{loss_meter.result()}') # 3.读取统计信息
            loss_meter.reset_states() # 4.清零测量器


epoch[0]-step[0] loss:0.13442061841487885
epoch[0]-step[500] loss:0.032660964876413345
epoch[0]-step[1000] loss:0.02246745117008686
epoch[0]-step[1500] loss:0.021138837561011314
epoch[0]-step[2000] loss:0.020540548488497734
epoch[1]-step[0] loss:0.020303547382354736
epoch[1]-step[500] loss:0.020146848633885384
epoch[1]-step[1000] loss:0.020012257620692253
epoch[1]-step[1500] loss:0.019957510754466057
epoch[1]-step[2000] loss:0.019898351281881332


## 2.准确率统计例子

In [3]:
model2=MyModel()
model2.build(input_shape=(None,28*28))

optimizer2=keras.optimizers.RMSprop(0.001)
acc_meter=keras.metrics.Accuracy()
loss_meter2=keras.metrics.Mean()

for epoch in range(2):
    for step,(x,y) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            out=model2(x)
            loss=tf.losses.MSE(y,out)
            mean_mse_loss=tf.reduce_mean(loss)
            loss_meter2.update_state(mean_mse_loss)

        grads=tape.gradient(mean_mse_loss,model2.trainable_variables)
        optimizer2.apply_gradients(zip(grads,model2.trainable_variables))

        if step%500==0:
            print(f'TRAIN => epoch[{epoch}]-step[{step}] loss: {loss_meter2.result()}')

        if step%500==0:
            total,total_correct=0,0
            acc_meter.reset_states()

            for step,(x,y) in enumerate(val_dataset):
                x=tf.reshape(x,(-1,28*28))
                out=model2(x)
                # 预测
                pred=tf.argmax(out,axis=1)
                pred=tf.cast(pred,dtype=tf.int32)
                y=tf.argmax(y,axis=1) # one-hot转为数字标签
                correct=tf.equal(pred,0)
                total_correct+=tf.reduce_sum(tf.cast(correct,dtype=tf.int32)).numpy()
                total+=x.shape[0]
                acc_meter.update_state(y,pred)
            print(f'EVAL => epoch[{epoch}]-step[{step}] acc: {acc_meter.result().numpy()}')



TRAIN => epoch[0]-step[0] loss: 0.1015140637755394
EVAL => epoch[0]-step[19] acc: 0.19499999284744263
TRAIN => epoch[0]-step[500] loss: 0.015589861199259758
EVAL => epoch[0]-step[19] acc: 0.9779000282287598
TRAIN => epoch[0]-step[1000] loss: 0.009537935256958008
EVAL => epoch[0]-step[19] acc: 0.9804999828338623
TRAIN => epoch[0]-step[1500] loss: 0.007007712032645941
EVAL => epoch[0]-step[19] acc: 0.9830999970436096
TRAIN => epoch[0]-step[2000] loss: 0.005563453771173954
EVAL => epoch[0]-step[19] acc: 0.984000027179718
TRAIN => epoch[1]-step[0] loss: 0.004849955905228853
EVAL => epoch[1]-step[19] acc: 0.9825999736785889
TRAIN => epoch[1]-step[500] loss: 0.004123594146221876
EVAL => epoch[1]-step[19] acc: 0.9815999865531921
TRAIN => epoch[1]-step[1000] loss: 0.003591967048123479
EVAL => epoch[1]-step[19] acc: 0.9822999835014343
TRAIN => epoch[1]-step[1500] loss: 0.003186717862263322
EVAL => epoch[1]-step[19] acc: 0.9836000204086304
TRAIN => epoch[1]-step[2000] loss: 0.0028664427809417248

In [None]:
import os
pid=os.getpid()
!kill -9 $pid
