In [2]:
import tensorflow.compat.v2 as tf

In [3]:
import tensorflow_datasets as tfds

In [4]:
tfds.disable_progress_bar()
tf.enable_v2_behavior()

In [5]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
    try_gcs=True,
)

In [6]:
def normalize_img(image, label):
    return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

In [7]:
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

In [8]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128,activation='relu'),
  tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(0.001),
    metrics=['accuracy'],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


<tensorflow.python.keras.callbacks.History at 0x7ff50c0ab550>

In [11]:
!mkdir -p weights

In [22]:
model.save('./checkpoints/seq_model')

INFO:tensorflow:Assets written to: ./checkpoints/seq_model/assets


INFO:tensorflow:Assets written to: ./checkpoints/seq_model/assets


In [23]:
!tree

[01;34m.[00m
├── [01;34mcheckpoints[00m
│   └── [01;34mseq_model[00m
│       ├── [01;34massets[00m
│       ├── saved_model.pb
│       └── [01;34mvariables[00m
│           ├── variables.data-00000-of-00001
│           └── variables.index
├── CreateModel.ipynb
└── LoadModel.ipynb

4 directories, 5 files


In [24]:
new_model = tf.keras.models.load_model('checkpoints/seq_model')

In [25]:
type(ds_train)

tensorflow.python.data.ops.dataset_ops.PrefetchDataset

In [27]:
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________


In [28]:
import numpy as np

In [30]:
randDataFloat = np.random.rand(1,28,28,1) * 255

In [32]:
randDataInt = randDataFloat.astype(int)

In [33]:
model.predict(randDataInt)

array([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]], dtype=float32)