In [2]:
import tensorflow as tf
import numpy as np

import os

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
# Report the TensorFlow build in use (same string as tf.__version__).
print(tf.version.VERSION)

2.8.0


In [3]:
DATA_DIR = "./data"  # directory scanned below for data_*.bin / test_*.bin batch files

# Read Data

In [4]:
def _list_batch_files(data_dir, prefix):
    """Return the sorted paths of `*.bin` files in `data_dir` whose names start with `prefix`.

    os.listdir returns entries in arbitrary, filesystem-dependent order, so the
    result is sorted to make the file order (and therefore the order records are
    read) deterministic across machines and runs.
    """
    return sorted(
        os.path.join(data_dir, name)
        for name in os.listdir(data_dir)
        if name.startswith(prefix) and name.endswith('.bin')
    )


train_path_list = _list_batch_files(DATA_DIR, 'data_')
test_path_list = _list_batch_files(DATA_DIR, 'test_')

# Fail fast with a clear message instead of building pipelines over no files.
assert train_path_list, 'no data_*.bin training batches found in %s' % DATA_DIR
assert test_path_list, 'no test_*.bin test batches found in %s' % DATA_DIR

print(train_path_list)
print(test_path_list)

['./data\\data_batch_1.bin', './data\\data_batch_2.bin', './data\\data_batch_3.bin', './data\\data_batch_4.bin', './data\\data_batch_5.bin']
['./data\\test_batch.bin']


In [5]:
# Fixed-length binary record layout: one label byte, then the raw pixel bytes
# for a 3-channel 32x32 image.
LABEL_BYTE = 1
IMAGE_BYTES = 3 * 32 * 32
RECORD_BYTES = IMAGE_BYTES + LABEL_BYTE

In [6]:
def _load_dataset(data_path_list, num_parallel_calls=tf.data.AUTOTUNE):
    """Build an unbatched tf.data pipeline of (image, label) pairs from binary batch files.

    Args:
        data_path_list: paths to files made of fixed-length records, each
            RECORD_BYTES long (one label byte followed by the pixel bytes).
        num_parallel_calls: parallelism used when decoding records. Defaults to
            tf.data.AUTOTUNE; pass None for a sequential decode. tf.data keeps
            the output order deterministic either way (unless determinism is
            explicitly disabled via dataset options).

    Returns:
        A tf.data.Dataset yielding (image, label): image is a float32 tensor of
        shape (32, 32, 3) scaled into [0, 1], label is a uint8 scalar.
    """
    def _process_record(record):
        # Raw record bytes -> flat uint8 vector.
        value = tf.io.decode_raw(record, tf.uint8)
        label = value[0]
        # Pixel bytes are stored channel-major as (3, 32, 32); transpose to the
        # channels-last (32, 32, 3) layout used by the model's Conv2D.
        image = tf.reshape(value[1:], (3, 32, 32))
        image = tf.transpose(image, (1, 2, 0))
        image = tf.cast(image, tf.float32) / 255
        return image, label

    dataset = tf.data.FixedLengthRecordDataset(data_path_list, RECORD_BYTES)
    return dataset.map(_process_record, num_parallel_calls=num_parallel_calls)

class MyModel(Model):
    """Small CNN classifier: a 32-filter 3x3 ReLU conv, a 128-unit ReLU dense layer,
    and a 10-way softmax output."""

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable: saved weights are
        # tracked through these attributes.
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        # conv -> flatten -> hidden dense -> class probabilities
        features = self.flatten(self.conv1(x))
        return self.d2(self.d1(features))

In [7]:
# Instantiate the classifier. No input shape is given, so the layers create
# their weights on the first call.
model = MyModel()

In [8]:
# MyModel ends in a softmax, so the loss consumes probabilities, not logits
# (from_logits=False is the Keras default, spelled out here for clarity).
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam()

In [9]:
@tf.function
def train_step(images, labels):
    """Apply one optimizer update to `model` using a single batch."""
    with tf.GradientTape() as tape:
        batch_loss = loss_object(labels, model(images))
    params = model.trainable_variables
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))

# Running metrics: each test_step call folds one batch in; values accumulate
# until the metric is explicitly reset.
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='test_accuracy')

@tf.function
def test_step(images, labels):
    """Score one batch with `model` and accumulate it into test_loss / test_accuracy."""
    probs = model(images)
    test_loss(loss_object(labels, probs))
    test_accuracy(labels, probs)

In [10]:
# Training configuration.
EPOCHS = 5
SUMMARY_DIR = './summary'  # directory handed to tf.summary.create_file_writer

TRAIN_BATCH_SIZE = TEST_BATCH_SIZE = 32

In [11]:
import time

In [12]:
# Records are shuffled within a window of this many decoded examples and
# reshuffled every epoch, so batch composition differs between epochs.
# Memory: each decoded float32 32x32x3 image is ~12 KB, so ~120 MB for 10000.
SHUFFLE_BUFFER_SIZE = 10000
SHUFFLE_SEED = 42  # fixed op seed so the shuffle sequence is reproducible

train_dataset = (
    _load_dataset(train_path_list)
    .shuffle(SHUFFLE_BUFFER_SIZE, seed=SHUFFLE_SEED, reshuffle_each_iteration=True)
    .batch(TRAIN_BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)  # overlap input decoding with the train step
)
# Evaluation order does not matter, so the test set is only batched and prefetched.
test_dataset = (
    _load_dataset(test_path_list)
    .batch(TEST_BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

summary_writer = tf.summary.create_file_writer(SUMMARY_DIR)

In [13]:
for epoch in range(EPOCHS):
    # Reset at the start of each epoch so values left over from an interrupted
    # earlier run (common when re-running notebook cells) cannot leak into the
    # reported numbers.
    test_loss.reset_state()
    test_accuracy.reset_state()

    start = time.time()
    for images, labels in train_dataset:
        train_step(images, labels)

    for test_images, test_labels in test_dataset:
        test_step(test_images, test_labels)

    elapsed = time.time() - start
    print('elapsed: %f' % elapsed)

    epoch_loss = test_loss.result()
    epoch_accuracy = test_accuracy.result() * 100

    template = 'Epoch {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch + 1, epoch_loss, epoch_accuracy))

    # Write the epoch's test metrics through the summary writer created for
    # SUMMARY_DIR; step matches the 1-based epoch number printed above.
    with summary_writer.as_default():
        tf.summary.scalar('test_loss', epoch_loss, step=epoch + 1)
        tf.summary.scalar('test_accuracy', epoch_accuracy, step=epoch + 1)
    summary_writer.flush()

print('Training Finished.')

elapsed: 23.928967
Epoch 1, Test Loss: 1.25106680393219, Test Accuracy: 55.47999572753906
elapsed: 23.623842
Epoch 2, Test Loss: 1.2054194211959839, Test Accuracy: 57.59000015258789
elapsed: 24.076629
Epoch 3, Test Loss: 1.2196234464645386, Test Accuracy: 57.970001220703125
elapsed: 23.672711
Epoch 4, Test Loss: 1.2204262018203735, Test Accuracy: 58.82999801635742
elapsed: 23.587936
Epoch 5, Test Loss: 1.3552522659301758, Test Accuracy: 57.480003356933594
Training Finished.


In [20]:
# Persist only the trained weights (not the architecture). NOTE(review): the
# "cifa10" spelling is kept as-is in case other code loads from this exact path.
model.save_weights(filepath='./checkpoints/cifa10_checkpoint')