# Shallow Neural Network
- tf.keras로 구현

## 1. Import libraries

In [1]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

## 2. Set hyperparameters

In [2]:
EPOCHS = 1_000  # number of full passes over train_ds during training

## 3. Network Architecture

- Input layer : 2
- Hidden layer : 128 (Sigmoid)
- Output : 10 (Softmax, one unit per class)

In [3]:
class MyModel(tf.keras.Model):
    """Shallow classifier: one sigmoid hidden layer followed by a softmax output.

    Layers: Dense(128, sigmoid) -> Dense(10, softmax), i.e. one output
    probability per class for the 10-class dataset built in this notebook.
    """

    def __init__(self):
        super().__init__()
        # Hidden layer; `input_dim=2` records the expected number of input features.
        self.d1 = layers.Dense(128, input_dim=2, activation='sigmoid')
        # Output layer: 10 units with softmax -> class probabilities.
        self.d2 = layers.Dense(10, activation='softmax')

    def call(self, x, training=None, mask=None):
        """Forward pass: return softmax class probabilities for the batch `x`."""
        hidden = self.d1(x)
        return self.d2(hidden)

## 4. Train function

- 함수의 인자값 : model, input & label, loss_object, optimizer, train loss & train metric

In [16]:
@tf.function
def train_step(model, inputs, labels, loss_object, optimizer, train_loss, train_metric):
    """Run one optimisation step on a single mini-batch.

    Args:
        model: Keras model mapping `inputs` to predictions.
        inputs: Mini-batch of input features.
        labels: Targets for `inputs`, in whatever format `loss_object` expects.
        loss_object: Callable used as `loss_object(labels, pred)` to compute the loss.
        optimizer: Keras optimizer that applies the computed gradients.
        train_loss: Metric updated with the batch loss via `train_loss(loss)`.
        train_metric: Metric updated via `train_metric(labels, pred)` (e.g. accuracy).
    """
    with tf.GradientTape() as tape:
        # Forward pass recorded on the tape so gradients can be taken afterwards.
        pred = model(inputs, training=True)
        loss = loss_object(labels, pred)
    # Gradients of the loss with respect to every trainable weight of the model.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    train_loss(loss)
    # Use the predictions computed above; the original referenced an undefined
    # name `predictions`, which raised NameError when tf.function traced this step.
    train_metric(labels, pred)

## 5. Generate Data (synthetic 2-D clusters)

In [17]:
# Synthetic 10-class dataset: 10 cluster centres drawn uniformly from [-8, 8)^2,
# then 100 points per centre sampled as centre + standard-normal noise.
np.random.seed(0)  # fixed seed so the generated dataset is reproducible

center_pts = np.random.uniform(-8.0, 8.0, (10, 2))

point_rows = []
label_rows = []
for class_id, center in enumerate(center_pts):
    noise_shape = center.shape
    for _ in range(100):
        point_rows.append(center + np.random.randn(*noise_shape))
        label_rows.append(class_id)

# pts: (1000, 2) float32 features; labels: (1000,) integer class ids 0..9.
pts = np.stack(point_rows, axis=0).astype(np.float32)
labels = np.stack(label_rows, axis=0)

# Pair each point with its label, shuffle with a 1000-element buffer
# (the full dataset size here), and group into mini-batches of 32.
train_ds = (
    tf.data.Dataset.from_tensor_slices((pts, labels))
    .shuffle(1000)
    .batch(32)
)

## 6. Modeling

### (1) Model

In [18]:
# Instantiate the shallow network defined by MyModel.
model = MyModel()

### (2) Loss Object & Optimizer
- Loss Object : sparse categorical cross entropy, binary cross entropy, ...
- Optimizer : Adam, RMSprop, ...

In [19]:
# Adam with its default hyperparameters.
optimizer = tf.keras.optimizers.Adam()
# Integer class labels + softmax probabilities (not logits) from the model,
# so from_logits stays at its default of False (spelled out here).
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

### (3) Metrics (loss & accuracy)

In [20]:
# Streaming metrics: each accumulates over successive calls until it is reset.
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')  # integer labels vs. class probabilities
train_loss = tf.keras.metrics.Mean(name='train_loss')  # running mean of per-batch losses

## 7. Train

In [22]:
for epoch in range(EPOCHS):
    # One pass over all shuffled mini-batches; train_step updates the weights
    # and feeds the running loss/accuracy metrics.
    for x, label in train_ds:
        train_step(model, x, label, loss_object, optimizer, train_loss, train_accuracy)

    template = 'Epoch {}, Loss: {}, Accuracy: {}'
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100))
    # Clear the accumulators so each printed line reflects a single epoch.
    # `reset_state()` (TF >= 2.5) replaces the deprecated `reset_states()`,
    # which was removed in Keras 3 and would raise AttributeError there.
    train_loss.reset_state()
    train_accuracy.reset_state()

Epoch 1, Loss: 0.5004048347473145, Accuracy: 85.82125091552734
Epoch 2, Loss: 0.4983041286468506, Accuracy: 85.854736328125
Epoch 3, Loss: 0.49625158309936523, Accuracy: 85.8885498046875
Epoch 4, Loss: 0.49433252215385437, Accuracy: 85.91976165771484
Epoch 5, Loss: 0.49232837557792664, Accuracy: 85.9474868774414
Epoch 6, Loss: 0.49037522077560425, Accuracy: 85.9784927368164
Epoch 7, Loss: 0.4885392487049103, Accuracy: 86.0060806274414
Epoch 8, Loss: 0.48661574721336365, Accuracy: 86.03408813476562
Epoch 9, Loss: 0.4847148656845093, Accuracy: 86.0625
Epoch 10, Loss: 0.48288586735725403, Accuracy: 86.09590148925781
Epoch 11, Loss: 0.481080561876297, Accuracy: 86.1223373413086
Epoch 12, Loss: 0.4793313145637512, Accuracy: 86.14647674560547
Epoch 13, Loss: 0.47759753465652466, Accuracy: 86.1728744506836
Epoch 14, Loss: 0.4758540689945221, Accuracy: 86.20056915283203
Epoch 15, Loss: 0.47417938709259033, Accuracy: 86.227783203125
Epoch 16, Loss: 0.4724782705307007, Accuracy: 86.2519149780273

Epoch 130, Loss: 0.37232863903045654, Accuracy: 87.70140075683594
Epoch 131, Loss: 0.3718596398830414, Accuracy: 87.70574188232422
Epoch 132, Loss: 0.37139323353767395, Accuracy: 87.7143783569336
Epoch 133, Loss: 0.3709069490432739, Accuracy: 87.7177734375
Epoch 134, Loss: 0.3704715073108673, Accuracy: 87.72456359863281
Epoch 135, Loss: 0.369998574256897, Accuracy: 87.73001098632812
Epoch 136, Loss: 0.36956536769866943, Accuracy: 87.74095916748047
Epoch 137, Loss: 0.36911532282829285, Accuracy: 87.74544525146484
Epoch 138, Loss: 0.36872997879981995, Accuracy: 87.75116729736328
Epoch 139, Loss: 0.36827442049980164, Accuracy: 87.75851440429688
Epoch 140, Loss: 0.36783117055892944, Accuracy: 87.76496887207031
Epoch 141, Loss: 0.3674217760562897, Accuracy: 87.77345275878906
Epoch 142, Loss: 0.36699292063713074, Accuracy: 87.77729797363281
Epoch 143, Loss: 0.3665727972984314, Accuracy: 87.784423828125
Epoch 144, Loss: 0.36614084243774414, Accuracy: 87.7885971069336
Epoch 145, Loss: 0.365706

Epoch 258, Loss: 0.3326058089733124, Accuracy: 88.26325225830078
Epoch 259, Loss: 0.3324032723903656, Accuracy: 88.26698303222656
Epoch 260, Loss: 0.3321828842163086, Accuracy: 88.2701416015625
Epoch 261, Loss: 0.3319704830646515, Accuracy: 88.27355194091797
Epoch 262, Loss: 0.3317573070526123, Accuracy: 88.276123046875
Epoch 263, Loss: 0.3315620422363281, Accuracy: 88.27867126464844
Epoch 264, Loss: 0.33136776089668274, Accuracy: 88.28121185302734
Epoch 265, Loss: 0.3311796188354492, Accuracy: 88.28319549560547
Epoch 266, Loss: 0.3309692442417145, Accuracy: 88.28488159179688
Epoch 267, Loss: 0.3307721018791199, Accuracy: 88.287109375
Epoch 268, Loss: 0.33055976033210754, Accuracy: 88.29068756103516
Epoch 269, Loss: 0.33036962151527405, Accuracy: 88.29288482666016
Epoch 270, Loss: 0.3301982581615448, Accuracy: 88.29589080810547
Epoch 271, Loss: 0.32998892664909363, Accuracy: 88.29940795898438
Epoch 272, Loss: 0.32981833815574646, Accuracy: 88.30210876464844
Epoch 273, Loss: 0.329615414

Epoch 385, Loss: 0.31270846724510193, Accuracy: 88.55911254882812
Epoch 386, Loss: 0.31258729100227356, Accuracy: 88.56084442138672
Epoch 387, Loss: 0.3124847114086151, Accuracy: 88.56237030029297
Epoch 388, Loss: 0.31237372756004333, Accuracy: 88.5640869140625
Epoch 389, Loss: 0.3122500777244568, Accuracy: 88.56539154052734
Epoch 390, Loss: 0.31213298439979553, Accuracy: 88.56709289550781
Epoch 391, Loss: 0.31205540895462036, Accuracy: 88.5677719116211
Epoch 392, Loss: 0.3119337260723114, Accuracy: 88.5708999633789
Epoch 393, Loss: 0.31183382868766785, Accuracy: 88.57238006591797
Epoch 394, Loss: 0.3117145895957947, Accuracy: 88.57425689697266
Epoch 395, Loss: 0.31159400939941406, Accuracy: 88.57572937011719
Epoch 396, Loss: 0.3114878237247467, Accuracy: 88.57840728759766
Epoch 397, Loss: 0.31137320399284363, Accuracy: 88.57905578613281
Epoch 398, Loss: 0.3112492561340332, Accuracy: 88.58211517333984
Epoch 399, Loss: 0.3111279308795929, Accuracy: 88.58295440673828
Epoch 400, Loss: 0.3

Epoch 511, Loss: 0.30073294043540955, Accuracy: 88.76021575927734
Epoch 512, Loss: 0.30065664649009705, Accuracy: 88.76224517822266
Epoch 513, Loss: 0.30057811737060547, Accuracy: 88.7634506225586
Epoch 514, Loss: 0.30051082372665405, Accuracy: 88.76513671875
Epoch 515, Loss: 0.30045366287231445, Accuracy: 88.7665023803711
Epoch 516, Loss: 0.300375759601593, Accuracy: 88.76818084716797
Epoch 517, Loss: 0.3003001809120178, Accuracy: 88.76953125
Epoch 518, Loss: 0.3002300560474396, Accuracy: 88.7703857421875
Epoch 519, Loss: 0.30014660954475403, Accuracy: 88.77254486083984
Epoch 520, Loss: 0.3000752925872803, Accuracy: 88.77388000488281
Epoch 521, Loss: 0.30000582337379456, Accuracy: 88.77537536621094
Epoch 522, Loss: 0.2999351918697357, Accuracy: 88.77750396728516
Epoch 523, Loss: 0.2998611629009247, Accuracy: 88.77899169921875
Epoch 524, Loss: 0.299784392118454, Accuracy: 88.77983093261719
Epoch 525, Loss: 0.29970645904541016, Accuracy: 88.78002166748047
Epoch 526, Loss: 0.299626916646

Epoch 640, Loss: 0.29225558042526245, Accuracy: 88.92239379882812
Epoch 641, Loss: 0.2921989858150482, Accuracy: 88.92371368408203
Epoch 642, Loss: 0.2921343743801117, Accuracy: 88.92475891113281
Epoch 643, Loss: 0.2920728325843811, Accuracy: 88.92607879638672
Epoch 644, Loss: 0.29201674461364746, Accuracy: 88.92752075195312
Epoch 645, Loss: 0.291969358921051, Accuracy: 88.92922973632812
Epoch 646, Loss: 0.2919166088104248, Accuracy: 88.93026733398438
Epoch 647, Loss: 0.29185494780540466, Accuracy: 88.93196868896484
Epoch 648, Loss: 0.2918075919151306, Accuracy: 88.93326568603516
Epoch 649, Loss: 0.291748583316803, Accuracy: 88.93402099609375
Epoch 650, Loss: 0.2916886508464813, Accuracy: 88.9351806640625
Epoch 651, Loss: 0.29164209961891174, Accuracy: 88.93659973144531
Epoch 652, Loss: 0.29158464074134827, Accuracy: 88.93788146972656
Epoch 653, Loss: 0.29153409600257874, Accuracy: 88.93875885009766
Epoch 654, Loss: 0.29148659110069275, Accuracy: 88.93924713134766
Epoch 655, Loss: 0.29

Epoch 770, Loss: 0.28589850664138794, Accuracy: 89.0618667602539
Epoch 771, Loss: 0.2858579754829407, Accuracy: 89.062255859375
Epoch 772, Loss: 0.2858116626739502, Accuracy: 89.06287384033203
Epoch 773, Loss: 0.28576600551605225, Accuracy: 89.06394958496094
Epoch 774, Loss: 0.28572091460227966, Accuracy: 89.06558990478516
Epoch 775, Loss: 0.28568539023399353, Accuracy: 89.06665802001953
Epoch 776, Loss: 0.28563717007637024, Accuracy: 89.0677261352539
Epoch 777, Loss: 0.2855888307094574, Accuracy: 89.06879425048828
Epoch 778, Loss: 0.28554823994636536, Accuracy: 89.06974029541016
Epoch 779, Loss: 0.2855014503002167, Accuracy: 89.07022857666016
Epoch 780, Loss: 0.2854548394680023, Accuracy: 89.0712890625
Epoch 781, Loss: 0.2854131758213043, Accuracy: 89.07245635986328
Epoch 782, Loss: 0.28536832332611084, Accuracy: 89.07396697998047
Epoch 783, Loss: 0.2853361666202545, Accuracy: 89.07490539550781
Epoch 784, Loss: 0.2853028178215027, Accuracy: 89.07640838623047
Epoch 785, Loss: 0.2852629

Epoch 898, Loss: 0.28092774748802185, Accuracy: 89.19125366210938
Epoch 899, Loss: 0.2808905839920044, Accuracy: 89.19245910644531
Epoch 900, Loss: 0.2808522880077362, Accuracy: 89.19326782226562
Epoch 901, Loss: 0.28082263469696045, Accuracy: 89.19417572021484
Epoch 902, Loss: 0.28078731894493103, Accuracy: 89.19498443603516
Epoch 903, Loss: 0.28075477480888367, Accuracy: 89.19598388671875
Epoch 904, Loss: 0.2807270288467407, Accuracy: 89.19658660888672
Epoch 905, Loss: 0.2806956470012665, Accuracy: 89.19758605957031
Epoch 906, Loss: 0.2806625962257385, Accuracy: 89.19868469238281
Epoch 907, Loss: 0.2806266248226166, Accuracy: 89.19967651367188
Epoch 908, Loss: 0.28058695793151855, Accuracy: 89.20066833496094
Epoch 909, Loss: 0.28055328130722046, Accuracy: 89.20185852050781
Epoch 910, Loss: 0.28052350878715515, Accuracy: 89.20255279541016
Epoch 911, Loss: 0.2804909646511078, Accuracy: 89.20354461669922
Epoch 912, Loss: 0.280452698469162, Accuracy: 89.20521545410156
Epoch 913, Loss: 0.

## 8. Save / Load parameters

In [23]:
# Persist the generated dataset alongside the trained weights.
np.savez_compressed('ch2_dataset.npz', inputs=pts, labels=labels)

# Fetch [kernel, bias] for the hidden (d1) and output (d2) layers.
(W_h, b_h), (W_o, b_o) = model.d1.get_weights(), model.d2.get_weights()
# Save the kernels transposed (Keras Dense stores them as (inputs, units)).
W_h, W_o = W_h.T, W_o.T

np.savez_compressed(
    'ch2_parameters.npz',
    **dict(W_h=W_h, b_h=b_h, W_o=W_o, b_o=b_o),
)