In [33]:
import tensorflow as tf
import datetime

In [34]:
(train_image, train_labels), (test_image, test_labels) = tf.keras.datasets.mnist.load_data()

# Images: append a trailing channel axis (the model expects a single channel,
# see input_shape=(None, None, 1)), then divide by 255 to scale the pixel values
# into [0, 1] and store them as float32.
train_image = tf.cast(tf.expand_dims(train_image, -1) / 255, tf.float32)
test_image = tf.cast(tf.expand_dims(test_image, -1) / 255, tf.float32)

# Labels: integer class ids as int64, suitable for the sparse cross-entropy loss.
train_labels = tf.cast(train_labels, tf.int64)
test_labels = tf.cast(test_labels, tf.int64)


In [35]:
# Training pipeline: an endless, shuffled stream of 128-example batches.
# NOTE: because of .repeat() this never ends on its own; consumers must bound it
# (model.fit uses steps_per_epoch, a manual loop must use .take(n)).
dataset = (
    tf.data.Dataset.from_tensor_slices((train_image, train_labels))
    .repeat()
    .shuffle(60000)
    .batch(128)
)

# Evaluation pipeline: same batching, no shuffling, and likewise repeated forever.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_image, test_labels))
    .repeat()
    .batch(128)
)

In [36]:
# Small fully-convolutional classifier: two 3x3 convolutions (default 'valid'
# padding), then global average pooling so any input height/width collapses to a
# fixed-length vector, then a 10-way softmax (one unit per digit class).
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(None, None, 1)))
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu'))
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(10, activation='softmax'))

In [37]:
# Integer class labels + softmax outputs -> sparse categorical cross-entropy.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [38]:
import os

# One timestamped run directory per execution (logs/<YYYYmmdd-HHMMSS>),
# so successive runs do not overwrite each other in TensorBoard.
log_dir = os.path.join(
    'logs',
    datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
)

In [39]:
# Keras TensorBoard callback writing fit() summaries under log_dir;
# histogram_freq=1 additionally records weight histograms every epoch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)

In [None]:
# Writer for the custom learning-rate scalar emitted by lr_sche(). It is installed
# as the default writer so tf.summary.scalar calls that name no writer land here.
# os.path.join keeps the path consistent with how log_dir itself was built.
file_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'lr'))
file_writer.set_as_default()

In [None]:
def lr_sche(epoch):
    """Step learning-rate schedule for tf.keras.callbacks.LearningRateScheduler.

    Returns 0.2 for epochs 0-5, 0.02 for 6-10, 0.01 for 11-20 and 0.005 from
    epoch 21 on. The chosen rate is also written as a 'learning_rate' scalar
    to the current default summary writer, using the epoch as the step.

    NOTE(review): 0.2 is far above Adam's default rate (0.001) and the model is
    compiled with optimizer='adam'; confirm this schedule is intended for it.
    """
    # (exclusive lower epoch bound, rate), checked from the latest stage down;
    # the first bound exceeded wins.
    for bound, rate in ((20, 0.005), (10, 0.01), (5, 0.02)):
        if epoch > bound:
            learning_rate = rate
            break
    else:
        learning_rate = 0.2
    tf.summary.scalar('learning_rate', data=learning_rate, step=epoch)
    return learning_rate

In [None]:
# Pass the schedule function itself: Keras calls it with the epoch index at the
# start of each epoch. Calling lr_sche() here (as before) raises TypeError
# because the required `epoch` argument is missing.
lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_sche)

In [40]:
model.fit(
    dataset,
    validation_data=test_dataset,
    epochs=25,
    # Both datasets repeat forever, so the batches per epoch and per validation
    # pass are given explicitly (number of full 128-example batches).
    steps_per_epoch=60000 // 128,
    validation_steps=10000 // 128,
    callbacks=[tensorboard_callback, lr_callback],
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x258d3d82d30>

In [41]:
# Jupyter-only magics, left commented out. Uncomment in a notebook kernel to
# load the TensorBoard extension needed by the %tensorboard magic.
# %load_ext tensorboard
# %matplotlib inline

In [42]:
# Uncomment (after %load_ext tensorboard) to show TensorBoard inline; it reads
# every run written under logs/, including the fit() run and the custom-loop runs.
# %tensorboard --logdir logs

Launching TensorBoard...

自定义训练中使用 TensorBoard(Using TensorBoard in a custom GradientTape training loop)




In [None]:
# Separate optimizer and loss object for the hand-written GradientTape loop.
# The model ends in a softmax, so the loss consumes probabilities
# (from_logits=False, which is also the Keras default).
optimizer = tf.keras.optimizers.Adam()
loss_fun = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

In [None]:
def loss(model, x, y):
    """Return the loss of `model` on one batch.

    Runs the model on `x` and scores the predictions against the labels `y`
    with the module-level `loss_fun`.
    """
    return loss_fun(y, model(x))

In [None]:
# Streaming metrics: they accumulate over every batch fed to them until they
# are explicitly reset (done once per epoch in train()).
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')


In [None]:
def train_step(model, images, labels):
    """Run one optimisation step on a batch and update the training metrics.

    Args:
        model: the Keras model being trained.
        images: a batch of input images.
        labels: the matching integer class labels.

    Side effects: applies one update of the module-level `optimizer` to
    `model.trainable_variables` and feeds `train_loss` / `train_accuracy`.
    """
    with tf.GradientTape() as tape:
        # training=True puts train/inference-dependent layers (Dropout,
        # BatchNormalization) in training mode should they be added later.
        predictions = model(images, training=True)
        batch_loss = loss_fun(labels, predictions)

    grads = tape.gradient(batch_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss(batch_loss)
    train_accuracy(labels, predictions)

In [None]:
def test_step(model, images, labels):
    """Evaluate `model` on one batch and update the test metrics.

    No gradients are taken; the batch loss and predictions only feed
    `test_loss` and `test_accuracy`.
    """
    # training=False runs the model in inference mode, matching train_step's
    # explicit training=True.
    predictions = model(images, training=False)
    batch_loss = loss_fun(labels, predictions)
    test_loss(batch_loss)
    test_accuracy(labels, predictions)

In [None]:
# logs/gradient_tape/<timestamp>/{train,test}: one timestamped run folder with a
# sub-directory per split. The old string concatenation dropped the separator and
# produced logs/gradient_tape<timestamp>/train.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_log_dir = os.path.join('logs', 'gradient_tape', current_time, 'train')
test_log_dir = os.path.join('logs', 'gradient_tape', current_time, 'test')

# One summary writer per split, each writing into its own log directory.
train_writer, test_writer = (
    tf.summary.create_file_writer(train_log_dir),
    tf.summary.create_file_writer(test_log_dir),
)


In [None]:
def train(epochs=10, train_steps=60000 // 128, test_steps=10000 // 128):
    """Custom training loop that logs per-epoch metrics to TensorBoard.

    For each epoch: run `train_step` over `train_steps` batches of `dataset`
    and write the epoch's loss/accuracy through `train_writer`; then run
    `test_step` over `test_steps` batches of `test_dataset` and write through
    `test_writer`; finally reset the four streaming metrics.

    `dataset` and `test_dataset` are built with `.repeat()` and never end on
    their own, so iterating them directly (as this loop used to) never
    finishes the first epoch. Every pass is therefore bounded with `.take()`;
    the default step counts match those given to `model.fit`.

    Args:
        epochs: number of epochs to run (default 10, as before).
        train_steps: training batches consumed per epoch.
        test_steps: evaluation batches consumed per epoch.
    """
    for epoch in range(epochs):
        for images, labels in dataset.take(train_steps):
            train_step(model, images, labels)
        with train_writer.as_default():
            tf.summary.scalar('loss', train_loss.result(), step=epoch)
            tf.summary.scalar('acc', train_accuracy.result(), step=epoch)

        for images, labels in test_dataset.take(test_steps):
            test_step(model, images, labels)
            print("*", end='')  # one star per evaluated batch as a progress marker
        with test_writer.as_default():
            tf.summary.scalar('loss', test_loss.result(), step=epoch)
            tf.summary.scalar('acc', test_accuracy.result(), step=epoch)

        # The metrics accumulate across calls; reset them so each epoch's
        # summaries describe only that epoch.
        for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
            metric.reset_states()

In [None]:
# Run the custom GradientTape loop. It keeps training the same `model` that
# model.fit() already trained above, now with the separate Adam optimizer
# created for this section.
train()
