# Automatic mixed precision training

In [1]:
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.mixed_precision import experimental as mixed_precision

In [2]:
# Build the mixed-precision policy and install it as the global default so
# that layers created afterwards use it.
# NOTE(review): this relies on the `experimental` mixed-precision module
# imported at the top of the notebook; newer TensorFlow releases expose a
# non-experimental equivalent — confirm which TF version is targeted.
POLICY_NAME = 'mixed_float16'
policy = mixed_precision.Policy(POLICY_NAME)
mixed_precision.set_policy(policy)

In [3]:
# The policy carries two separate dtypes: one used for computations and one
# used for variables. Show both.
compute_dtype = policy.compute_dtype
variable_dtype = policy.variable_dtype
print('Compute dtype: %s' % compute_dtype)
print('Variable dtype: %s' % variable_dtype)

Compute dtype: float16
Variable dtype: float32


In [12]:
# Read the loss-scale object attached to the policy and print its current
# state.
# NOTE(review): the execution counter (12) and the recorded
# `num_good_steps=60` suggest this cell was last run after the training cell
# rather than in reading order — re-run on a fresh kernel (Restart & Run All)
# to confirm the displayed output matches the notebook's top-to-bottom flow.
loss_scale = policy.loss_scale
print('Loss scale: %s' % loss_scale)

Loss scale: DynamicLossScale(current_loss_scale=32768.0, num_good_steps=60, initial_loss_scale=32768.0, increment_period=2000, multiplier=2.0)


In [4]:
inputs = keras.Input(shape=(784,), name='digits')

# Choose the hidden-layer width from the available hardware: wide layers when
# a GPU is present, narrow ones on CPUs so the model finishes in a reasonable
# amount of time.
gpu_devices = tf.config.list_physical_devices('GPU')
num_units = 4096 if gpu_devices else 64
if gpu_devices:
    print('The model will run with 4096 units on a GPU')
else:
    print('The model will run with 64 units on a CPU')

# Two equally sized ReLU layers stacked on top of the input. The layer objects
# are kept in their own names so their variables can be inspected afterwards.
dense1 = layers.Dense(num_units, activation='relu', name='dense_1')
dense2 = layers.Dense(num_units, activation='relu', name='dense_2')
x = dense2(dense1(inputs))

The model will run with 4096 units on a GPU


In [5]:
# Compare the dtype of the layer output with the dtype of the layer's weights.
activation_dtype_name = x.dtype.name
# 'kernel' is dense1's variable
kernel_dtype_name = dense1.kernel.dtype.name
print('x.dtype: %s' % activation_dtype_name)
print('dense1.kernel.dtype: %s' % kernel_dtype_name)

x.dtype: float16
dense1.kernel.dtype: float32


In [6]:
# INCORRECT: with the softmax fused into the Dense layer, both the softmax and
# the model output are float16, when they should be float32.
float16_predictions = layers.Dense(10, activation='softmax', name='predictions')
outputs = float16_predictions(x)
print('Outputs dtype: %s' % outputs.dtype.name)

Outputs dtype: float16


In [7]:
# CORRECT: softmax and model output are float32.
# The logits layer is fed from `dense2.output` rather than from whatever `x`
# is currently bound to, so re-running this cell does not stack another
# 'dense_logits' layer on top of the previous logits.
logits = layers.Dense(10, name='dense_logits')(dense2.output)
# Keep the original binding: after this cell `x` refers to the logits.
x = logits
# Passing dtype='float32' makes this one layer produce float32 output even
# though the preceding layers produce float16.
outputs = layers.Activation('softmax', dtype='float32', name='predictions')(logits)
print('Outputs dtype: %s' % outputs.dtype.name)

Outputs dtype: float32


In [8]:
# A 'linear' activation is the identity function, so the only effect of this
# layer is to cast 'outputs' to float32. In this particular case 'outputs' is
# float32 already, which makes the layer a no-op.
float32_cast = layers.Activation('linear', dtype='float32')
outputs = float32_cast(outputs)

In [9]:
# Assemble the functional model from the input and output tensors built above
# and configure it for integer-label classification.
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=keras.optimizers.RMSprop(),
    metrics=['accuracy'],
)

# Load MNIST and flatten every image into a row of 784 values, matching the
# 784-feature model input. The row count is inferred (-1) instead of being
# hard-coded, so the reshape does not depend on the exact number of samples.
# Dividing by 255 rescales the pixel values (presumably 8-bit, i.e. into
# [0, 1] — confirm against the dataset documentation).
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255
x_test = x_test.reshape(-1, 784).astype('float32') / 255

In [10]:
# Snapshot the model's weights before any training has been run.
# NOTE(review): nothing in the visible cells reads `initial_weights`;
# presumably a later cell restores them (e.g. via `model.set_weights`) to
# retrain from the same starting point — confirm.
initial_weights = model.get_weights()

In [11]:
# Train the model; a fifth of the training data is held out for validation.
fit_options = dict(batch_size=4096, epochs=5, validation_split=0.2)
history = model.fit(x_train, y_train, **fit_options)

# Score the trained model on the test split. The first two entries of the
# result are reported as loss and accuracy respectively.
test_scores = model.evaluate(x_test, y_test, verbose=2)
test_loss, test_accuracy = test_scores[0], test_scores[1]
print('Test loss:', test_loss)
print('Test accuracy:', test_accuracy)

Train on 48000 samples, validate on 12000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
10000/10000 - 1s - loss: 0.1166 - accuracy: 0.9634
Test loss: 0.11663010916002095
Test accuracy: 0.9634
