##### Copyright 2019 The TensorFlow Authors.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TensorFlow Addons Callbacks: TQDM

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/addons/blob/master/examples/tqdm_progress_bar.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/addons/blob/master/examples/tqdm_progress_bar.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

## Overview
This notebook demonstrates how to use `TQDMCallback` from TensorFlow Addons to display tqdm progress bars while training a Keras model.

## Setup

In [2]:
!pip install -q tqdm

!pip install -q ipywidgets
!jupyter nbextension enable --py widgetsnbextension --sys-prefix

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: OK


In [3]:
import tensorflow as tf
import tensorflow_addons as tfa

import tensorflow.keras as keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten

## Load and Normalize Data

In [4]:
# Load MNIST: load_data() returns the train and test splits as
# (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale raw pixel values (presumably 0-255 integers) into the [0, 1] range.
x_train = x_train / 255.0
x_test = x_test / 255.0

## Build a Simple MNIST Model

In [5]:
# A small fully connected classifier built with the Sequential API:
# flatten each 28x28 input, apply one hidden ReLU layer with dropout,
# then produce a 10-way softmax over the digit classes.
model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(128, activation='relu'),
    Dropout(0.2),
    Dense(10, activation='softmax'),
])

# sparse_categorical_crossentropy takes integer class labels (not one-hot).
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

## TQDMCallback example usage 1

In [6]:
# Initialize the TQDM callback with its default parameters.
tqdm_callback = tfa.callbacks.TQDMCallback()

# Train the model with tqdm_callback. Set verbose=0 to disable Keras' built-in
# progress bar, so that only the TQDM bars are displayed.
# The returned History object is kept in `history` for later inspection
# (e.g. history.history['loss']) instead of being echoed as the cell output.
history = model.fit(x_train, y_train,
                    batch_size=64,
                    epochs=10,
                    verbose=0,
                    callbacks=[tqdm_callback],
                    validation_data=(x_test, y_test))

HBox(children=(IntProgress(value=0, description='Training', layout=Layout(flex='2'), max=10, style=ProgressSty…

HBox(children=(IntProgress(value=0, description='Epoch: 0', layout=Layout(flex='2'), max=60000, style=Progress…




HBox(children=(IntProgress(value=0, description='Epoch: 1', layout=Layout(flex='2'), max=60000, style=Progress…




HBox(children=(IntProgress(value=0, description='Epoch: 2', layout=Layout(flex='2'), max=60000, style=Progress…




HBox(children=(IntProgress(value=0, description='Epoch: 3', layout=Layout(flex='2'), max=60000, style=Progress…




HBox(children=(IntProgress(value=0, description='Epoch: 4', layout=Layout(flex='2'), max=60000, style=Progress…




HBox(children=(IntProgress(value=0, description='Epoch: 5', layout=Layout(flex='2'), max=60000, style=Progress…




HBox(children=(IntProgress(value=0, description='Epoch: 6', layout=Layout(flex='2'), max=60000, style=Progress…




HBox(children=(IntProgress(value=0, description='Epoch: 7', layout=Layout(flex='2'), max=60000, style=Progress…




HBox(children=(IntProgress(value=0, description='Epoch: 8', layout=Layout(flex='2'), max=60000, style=Progress…




HBox(children=(IntProgress(value=0, description='Epoch: 9', layout=Layout(flex='2'), max=60000, style=Progress…





<tensorflow.python.keras.callbacks.History at 0x109abbb38>

In [7]:
# Score the trained model on the test split. The model is compiled with
# metrics=['accuracy'], so evaluate() returns [loss, accuracy].
score = model.evaluate(x_test, y_test, verbose=0)
test_loss, test_accuracy = score
print('Test loss:', test_loss)
print('Test accuracy:', test_accuracy)

Test loss: 0.06480617881095968
Test accuracy: 0.9803
