In [1]:
import os
import glob
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras import layers
from datetime import datetime
%matplotlib inline
%load_ext tensorboard

### Hyperparameters

In [2]:
# Training schedule.
num_epochs, batch_size = 10, 32
# Optimizer step size and dropout probability used after each stage.
learning_rate, dropout_rate = 0.001, 0.5
# Model input: 32x32 images with 3 channels; 10 output classes.
input_shape, num_classes = (32, 32, 3), 10

### Build Model

In [3]:
def conv_stage(x, filters):
    """Two 3x3 same-padded Conv2D+ReLU layers, then 2x2 max-pooling and dropout."""
    for _ in range(2):
        x = layers.Conv2D(filters, (3, 3), padding='SAME')(x)
        x = layers.Activation('relu')(x)
    x = layers.MaxPool2D(pool_size=(2, 2))(x)
    return layers.Dropout(dropout_rate)(x)


inputs = layers.Input(input_shape)

# Feature extractor: a 32-filter stage followed by a 64-filter stage.
features = inputs
for n_filters in (32, 64):
    features = conv_stage(features, n_filters)

# Classifier head: Dense(512) -> Dense(128) -> Dense(num_classes) + softmax.
head = layers.Flatten()(features)
for units in (512, 128):
    head = layers.Dense(units)(head)
    head = layers.Activation('relu')(head)
head = layers.Dropout(dropout_rate)(head)
head = layers.Dense(num_classes)(head)
outputs = layers.Activation('softmax')(head)

model = tf.keras.Model(inputs=inputs, outputs=outputs, name='cnnv1')

In [6]:
# Integer (sparse) labels come from get_sparse_label, hence the sparse loss.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
model.compile(optimizer=optimizer,
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

### Data Preprocessing

In [7]:
os.listdir(os.path.join('dataset', 'cifar'))

['.DS_Store', 'test', 'labels.txt', 'train']

In [8]:
# glob returns files in arbitrary, filesystem-dependent order; sort before
# slicing so the 1000-file subset is identical on every machine and every run.
train_dataset_path = sorted(glob.glob('dataset/cifar/train/*.png'))[:1000]
test_dataset_path = sorted(glob.glob('dataset/cifar/test/*.png'))[:1000]
len(train_dataset_path), len(test_dataset_path)

(1000, 1000)

In [9]:
def get_class_name(path):
    """Return the class label encoded in a file name such as ``'123_frog.png'``.

    The label is the text after the last underscore in ``path``, with every
    ``'.png'`` occurrence removed.
    """
    tail = path.rsplit('_', 1)[-1]
    return tail.replace('.png', '')

In [10]:
# Class vocabulary, sorted by np.unique. A class's position in this array is
# the integer label that get_sparse_label assigns to it.
class_category = np.unique([get_class_name(path) for path in train_dataset_path])
# The vocabulary is derived only from the sampled training files. If a class
# were missing from the sample, every later label index would shift silently,
# so fail loudly instead.
assert len(class_category) == num_classes, class_category
class_category

array(['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
       'horse', 'ship', 'truck'], dtype='<U10')

In [11]:
def get_sparse_label(path):
    """Map an image path tensor to its integer class index.

    The class name is the text after the last '_' in ``path`` with a trailing
    ``.png`` removed. Its index is its position in the module-level
    ``class_category`` array.

    Args:
        path: scalar string tensor, as produced by
            ``tf.data.Dataset.from_tensor_slices`` over the path list.

    Returns:
        int64 scalar tensor holding the class index.

    Raises:
        InvalidArgumentError (at pipeline run time) if the name does not match
        exactly one entry of ``class_category``. Without this check, argmax
        over an all-zero mask would silently return 0.
    """
    fname = tf.strings.split(path, '_')[-1]
    # Escape the dot and anchor at the end. The bare pattern '.png' matches
    # any character followed by 'png' anywhere in the string.
    cls_name = tf.strings.regex_replace(fname, r'\.png$', '')
    matches = tf.equal(class_category, cls_name)
    check = tf.debugging.assert_equal(
        tf.reduce_sum(tf.cast(matches, tf.int32)), 1,
        message='file name must match exactly one class in class_category')
    # Force the check to run before the label is produced inside tf.data graphs.
    with tf.control_dependencies([check]):
        return tf.argmax(tf.cast(matches, tf.uint8))

def read_dataset(path):
    """Load one image file and its label.

    Args:
        path: scalar string tensor with the image file path.

    Returns:
        (image, label): ``image`` is a float32 tensor of shape (H, W, 3) scaled
        to [0, 1]; ``label`` is the int64 class index from get_sparse_label.
    """
    gfile = tf.io.read_file(path)
    # channels=3 forces RGB output even for grayscale/RGBA files.
    # expand_animations=False guarantees a rank-3 result, so later image ops
    # and the model's 3-channel Input see a known rank and channel count.
    image = tf.io.decode_image(gfile, channels=3, expand_animations=False)
    image = tf.cast(image, tf.float32) / 255.
    label = get_sparse_label(path)
    return image, label

def image_preprocess(image, label):
    """Training-time augmentation: randomly mirror ``image`` horizontally.

    ``label`` is passed through untouched.
    """
    flipped = tf.image.random_flip_left_right(image)
    return flipped, label


In [12]:
# Passed as num_parallel_calls below so tf.data chooses the map parallelism itself.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [13]:
# Shuffle the file paths *before* decoding and batching, so every epoch sees a
# fresh mix of examples. Shuffling after .batch() only reorders fixed batches.
# The shuffle buffer holds path strings, not decoded images, so a full-size
# buffer is cheap. shuffle() reshuffles on each pass through repeat().
train_dataset = tf.data.Dataset.from_tensor_slices(train_dataset_path)
train_dataset = train_dataset.shuffle(len(train_dataset_path))
train_dataset = train_dataset.map(read_dataset, num_parallel_calls=AUTOTUNE)
train_dataset = train_dataset.map(image_preprocess, num_parallel_calls=AUTOTUNE)
train_dataset = train_dataset.batch(batch_size)
train_dataset = train_dataset.repeat()
# Overlap input loading with training steps.
train_dataset = train_dataset.prefetch(AUTOTUNE)

In [14]:
# Validation pipeline over the held-out *test* files (previously built from
# train_dataset_path by mistake, which validated on training data).
# No augmentation and no shuffling: evaluation should be deterministic.
test_dataset = tf.data.Dataset.from_tensor_slices(test_dataset_path)
test_dataset = test_dataset.map(read_dataset, num_parallel_calls=AUTOTUNE)
test_dataset = test_dataset.batch(batch_size)
test_dataset = test_dataset.repeat()
test_dataset = test_dataset.prefetch(AUTOTUNE)

# Callbacks

## TensorBoard

In [59]:
# One timestamped sub-directory of ./logs per run, so TensorBoard can compare runs.
# Build the path from components instead of concatenating '/logs' by hand.
logdir = os.path.join(os.getcwd(), 'logs', datetime.now().strftime('%Y%m%d-%H%M%S'))
logdir

'/Users/slidemorning/Slideworkspace/Github/tensorflow-2.0/logs/20201106-184441'

In [60]:
tensorboard_kwargs = dict(
    log_dir=logdir,
    histogram_freq=1,   # weight histograms every epoch
    write_graph=True,   # model graph
    write_images=True,  # layer weights as images
)
tensorboard = tf.keras.callbacks.TensorBoard(**tensorboard_kwargs)

In [61]:
%tensorboard --logdir=/Users/slidemorning/Slideworkspace/Github/tensorflow-2.0/logs/

### Training

In [63]:
# Model.fit consumes tf.data datasets directly; fit_generator is deprecated.
# Both datasets repeat() forever, so explicit step counts bound each epoch
# and each validation pass.
model.fit(
    train_dataset,
    steps_per_epoch=len(train_dataset_path) // batch_size,
    validation_data=test_dataset,
    validation_steps=len(test_dataset_path) // batch_size,
    epochs=num_epochs,
    callbacks=[tensorboard],
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7fcaa9345950>

## Lambda