In [1]:
import tensorflow as tf

# Dataset Load

In [2]:
mnist = tf.keras.datasets.mnist

# load_data() returns ((train images, train labels), (test images, test labels)).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale pixel values into [0, 1] by dividing by 255 (assumes 8-bit pixels).
x_train = x_train / 255.0
x_test = x_test / 255.0

In [3]:
tuple(x_train.shape)  # raw images: (num_samples, height, width), no channel axis yet

(60000, 28, 28)

In [4]:
# Append a trailing channel axis so each image becomes (28, 28, 1), the
# (batch, height, width, channels) layout the Conv2D layer consumes, then cast
# to float32 for TensorFlow.
# The ndim guard keeps this cell idempotent: without it, re-running the cell
# would stack another axis each time, e.g. (60000, 28, 28, 1, 1).
if x_train.ndim == 3:
    x_train = x_train[..., tf.newaxis]
if x_test.ndim == 3:
    x_test = x_test[..., tf.newaxis]

x_train = x_train.astype("float32")
x_test = x_test.astype("float32")

In [5]:
tuple(x_train.shape)  # now (num_samples, height, width, channels)

(60000, 28, 28, 1)

In [6]:
# Input-pipeline settings, named once so they are easy to find and tune.
SEED = 42
BATCH_SIZE = 32
SHUFFLE_BUFFER = 10000

# Global TensorFlow seed: ops created after this point (the shuffle below and
# the weight initializers when the model is built later) derive their seeds
# from it, making runs repeatable. NOTE(review): this alone does not guarantee
# bit-identical results on GPU.
tf.random.set_seed(SEED)

# Training data is reshuffled every epoch; prefetch overlaps input preparation
# with the training step.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(SHUFFLE_BUFFER, seed=SEED)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Evaluation data keeps its original order.
test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Define Model

In [7]:
# PyTorch's `forward` method corresponds to Keras' `call` method.

class MyModel(tf.keras.Model):
    """Small CNN classifier for 28x28x1 images.

    Architecture: Conv2D(32 filters, 3x3, ReLU) -> Flatten -> Dense(64, ReLU)
    -> Dense(10).

    The final layer has no activation, so `call` returns raw logits. This is
    what `SparseCategoricalCrossentropy(from_logits=True)` expects: that loss
    applies softmax internally. Putting a softmax here as well would apply it
    twice and squash the gradients. Use `tf.nn.softmax(model(x))` when
    probabilities are needed.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
        self.flatten = tf.keras.layers.Flatten()
        self.d1 = tf.keras.layers.Dense(64, activation='relu')
        # No activation: emit logits for a from_logits=True loss.
        self.d2 = tf.keras.layers.Dense(10)

    def call(self, x):
        """Return a (batch, 10) tensor of class logits for the image batch `x`."""
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

In [8]:
# Instantiate the classifier; its layer weights are created when the model is built in the next cell.
model = MyModel()

In [9]:
# Create the weights for single-channel 28x28 inputs. The batch dimension is
# left as None because the model does not depend on batch size: the pipelines
# batch by 32, and the last batch of an epoch can be smaller.
model.build(input_shape=(None, 28, 28, 1))
model.summary()

Model: "my_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             multiple                  320       
                                                                 
 flatten (Flatten)           multiple                  0         
                                                                 
 dense (Dense)               multiple                  1384512   
                                                                 
 dense_1 (Dense)             multiple                  650       
                                                                 
Total params: 1,385,482
Trainable params: 1,385,482
Non-trainable params: 0
_________________________________________________________________


# Loss, Optimizer

In [10]:
# Cross-entropy over integer class labels. from_logits=True makes the loss
# apply softmax itself, so it must be given the model's raw (pre-softmax) outputs.
criterion = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
)

# Adam, with the learning rate spelled out explicitly (0.001 is Keras' default).
optimizer = tf.keras.optimizers.Adam(
    learning_rate=1e-3,
)

### 모델의 성능과 loss를 측정할 지표를 선택한다. (epoch마다 해당 지표를 바탕으로 결과 출력)

In [11]:
# Stateful metrics: every call folds one batch into a running value that
# result() reports. The training loop clears them at the start of each epoch.
train_loss, test_loss = (
    tf.keras.metrics.Mean(name='train_loss'),
    tf.keras.metrics.Mean(name='test_loss'),
)
train_accuracy, test_accuracy = (
    tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy'),
)

# Model Training
- `tf.GradientTape`를 사용하여 모델을 훈련시킨다.


In [12]:
# training

@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        preds = model(images, training=True)
        loss = criterion(labels, preds)
    
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))  # == pytorch: optimizer.step()

    train_loss(loss)
    train_accuracy(labels, preds)

In [13]:
# Evaluation step

@tf.function
def test_step(images, labels):
    """Score `model` on one batch in inference mode and record the test metrics.

    Args:
        images: a batch of input images fed to `model`.
        labels: integer class labels for `images`.
    """
    outputs = model(images, training=False)
    test_loss(criterion(labels, outputs))
    test_accuracy(labels, outputs)

In [14]:
NUM_EPOCHS = 5

for epoch in range(NUM_EPOCHS):
    # The metrics accumulate across calls, so clear them at the start of each
    # epoch to report per-epoch values. reset_state() is the current Keras API;
    # reset_states() is a deprecated alias that Keras 3 no longer provides.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_state()

    for images, labels in train_ds:
        train_step(images, labels)

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    print(
        f"Epoch {epoch + 1}, "
        f"Loss: {train_loss.result()}, "
        f"Accuracy: {train_accuracy.result() * 100}, "
        f"Test Loss: {test_loss.result()}, "
        f"Test Accuracy: {test_accuracy.result() * 100}"
    )

  return dispatch_target(*args, **kwargs)


Epoch 1, Loss: 0.1542178988456726, Accuracy: 95.37833404541016, Test Loss: 0.07374430447816849, Test Accuracy: 97.7199935913086
Epoch 2, Loss: 0.05163539573550224, Accuracy: 98.43333435058594, Test Loss: 0.05618084594607353, Test Accuracy: 98.1500015258789
Epoch 3, Loss: 0.030548909679055214, Accuracy: 99.06999969482422, Test Loss: 0.056958213448524475, Test Accuracy: 98.22999572753906
Epoch 4, Loss: 0.01769600436091423, Accuracy: 99.42666625976562, Test Loss: 0.05454782396554947, Test Accuracy: 98.47000122070312
Epoch 5, Loss: 0.01296925451606512, Accuracy: 99.5616683959961, Test Loss: 0.060278791934251785, Test Accuracy: 98.29999542236328
