# MNIST with Single Layer NN

## Import MNIST dataset

In [1]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [2]:
import tensorflow as tf

In [3]:
tf.convert_to_tensor(mnist.train.images).get_shape()

TensorShape([Dimension(55000), Dimension(784)])

## MNIST training dataset
* 총 55,000장
* 가로, 세로 길이 = 28
* 데이터는 28 * 28 = 784 길이의 배열로 담겨있음

## Build flow of tensors

In [4]:
import tensorflow as tf

W = tf.Variable(tf.zeros([784, 10]))
x = tf.placeholder('float', [None, 784])
b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x, W) + b)

TensorFlow는 NumPy의 `shape broadcasting rule`을 따릅니다.

여기서 x*W의 shape는 `[n, 10]` b는 `[10]`인데, b의 한 dimension 크기가 `1`이므로 b가 복제되어 shape가 `[n, 10]`으로 동일해집니다.

## Print shapes

In [6]:
print('W', W.get_shape())
print('x', x.get_shape())
print('x*W', tf.matmul(x, W).get_shape())
print('b', b.get_shape())
print('y', y.get_shape())

W (784, 10)
x (?, 784)
x*W (?, 10)
b (10,)
y (?, 10)


## Cost function: Cross Entropy

In [7]:
y_label = tf.placeholder('float', [None, 10])
cross_entropy = -tf.reduce_sum(y_label * tf.log(y))

## Optimizer: Gradient Descent

In [8]:
optimizer = tf.train.GradientDescentOptimizer(0.01)

## Define `train` operation

In [None]:
train = optimizer.minimize(cross_entropy)

## Create TensorFlow session & initialize all variables

In [25]:
sess = tf.Session()
sess.run(tf.initialize_all_variables())

## Train

In [28]:
for i in range(1000):
    # Mini-Batch
    batch_xs, batch_ys = mnist.train.next_batch(100)
    
    # Train batch
    sess.run(train, feed_dict={x: batch_xs, y_label: batch_ys})
    
    # Calcurate accuracy
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    
    print(sess.run(accuracy, feed_dict={
                x: mnist.test.images, 
                y_label: mnist.test.labels
            }))

0.3824
0.3761
0.4591
0.5659
0.5737
0.6713
0.6828
0.7065
0.6812
0.6902
0.6198
0.6732
0.6742
0.7754
0.8117
0.7971
0.803
0.7916
0.6658
0.7178
0.7927
0.775
0.7929
0.8471
0.8283
0.8206
0.8109
0.7707
0.8242
0.835
0.797
0.8311
0.8164
0.8055
0.8535
0.8468
0.7966
0.7604
0.8626
0.8711
0.8676
0.8588
0.8349
0.8406
0.8477
0.85
0.8633
0.866
0.8486
0.8582
0.8748
0.7996
0.8495
0.8455
0.8599
0.8659
0.8716
0.8861
0.8836
0.8817
0.8737
0.8452
0.8729
0.8741
0.8605
0.8748
0.8598
0.8579
0.8234
0.8446
0.7972
0.8685
0.8363
0.8769
0.8234
0.85
0.8799
0.8836
0.8887
0.8627
0.8748
0.8768
0.8847
0.8883
0.8899
0.8825
0.8878
0.8815
0.8709
0.8036
0.8271
0.8428
0.8556
0.8611
0.8524
0.7729
0.8128
0.8862
0.8917
0.8944
0.8692
0.8741
0.8672
0.8818
0.8963
0.8915
0.8956
0.8925
0.8887
0.8947
0.8847
0.8948
0.8932
0.8683
0.8849
0.8959
0.8971
0.8994
0.891
0.897
0.8976
0.8934
0.8935
0.8707
0.8804
0.8958
0.8853
0.8983
0.8912
0.8963
0.9
0.8763
0.9003
0.8613
0.9008
0.8965
0.902
0.9023
0.903
0.9022
0.8998
0.8799
0.8836
0.8821
0.8527
0