# 경사 하강법을 이용한 얕은 신경망 학습


In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

## 하이퍼 파라미터 설정

In [2]:
EPOCHS = 1000

## 네트워크 구조 정의
### 얕은 신경망
#### 입력 계층 : 2, 은닉 계층 : 128 (Sigmoid activation), 출력 계층 : 10 (Softmax activation)

In [26]:
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.d1 = tf.keras.layers.Dense(128, input_dim=2, activation='sigmoid')
        self.d2 = tf.keras.layers.Dense(10, activation='softmax')
        
    def call(self, x, training=None, mask=None):
        x = self.d1(x)
        return self.d2(x)

## 학습 루프 정의

In [33]:
@tf.function
def train_step(model, inputs, labels, loss_object, optimizer, train_loss, train_metric):
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables) # grad(loss)  df(x)/dx
    
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_metric(labels, predictions)

## 데이터셋 생성, 전처리

In [34]:
np.random.seed(0)

pts = list()
labels = list()
center_pts = np.random.uniform(-8.0, 8.0, (10, 2))
for label, center_pt in enumerate(center_pts):
    for _ in range(100):
        pts.append(center_pt + np.random.randn(*center_pt.shape))
        labels.append(label)
        
pts = np.stack(pts, axis=0).astype(np.float32)
labels = np.stack(labels, axis=0)

train_ds = tf.data.Dataset.from_tensor_slices((pts, labels)).shuffle(1000).batch(32)

## 모델 생성

In [35]:
model = MyModel()

## 손실 함수 및 최적화 알고리즘 설정
### CrossEntropy, Adam Optimizer

In [36]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

## 평가 지표 설정
### Accuracy

In [37]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

## 학습 루프

In [40]:
for epoch in range(EPOCHS):
    for x, label in train_ds:
        train_step(model, x, label, loss_object, optimizer, train_loss, train_accuracy)
        
    template = 'Epoch {}, Loss: {}, Accuracy: {}'
    print(template.format(epoch +1,
                             train_loss.result(),
                             train_accuracy.result() * 100))

Epoch 1, Loss: 0.30435529351234436, Accuracy: 88.69740295410156
Epoch 2, Loss: 0.30426865816116333, Accuracy: 88.69792938232422
Epoch 3, Loss: 0.3041715919971466, Accuracy: 88.69932556152344
Epoch 4, Loss: 0.3040834665298462, Accuracy: 88.70123291015625
Epoch 5, Loss: 0.30398958921432495, Accuracy: 88.70347595214844
Epoch 6, Loss: 0.3039257228374481, Accuracy: 88.7052001953125
Epoch 7, Loss: 0.30385035276412964, Accuracy: 88.70674133300781
Epoch 8, Loss: 0.3037576973438263, Accuracy: 88.7082748413086
Epoch 9, Loss: 0.30367761850357056, Accuracy: 88.7098159790039
Epoch 10, Loss: 0.3035868704319, Accuracy: 88.71116638183594
Epoch 11, Loss: 0.30349498987197876, Accuracy: 88.71166229248047
Epoch 12, Loss: 0.30340492725372314, Accuracy: 88.7126693725586
Epoch 13, Loss: 0.3033137917518616, Accuracy: 88.71417999267578
Epoch 14, Loss: 0.3032410740852356, Accuracy: 88.71569061279297
Epoch 15, Loss: 0.3031456470489502, Accuracy: 88.717529296875
Epoch 16, Loss: 0.30305832624435425, Accuracy: 88.7

Epoch 131, Loss: 0.2948480546474457, Accuracy: 88.86976623535156
Epoch 132, Loss: 0.2947787940502167, Accuracy: 88.87065887451172
Epoch 133, Loss: 0.2947140634059906, Accuracy: 88.87168884277344
Epoch 134, Loss: 0.29466187953948975, Accuracy: 88.87300109863281
Epoch 135, Loss: 0.2945965528488159, Accuracy: 88.87445068359375
Epoch 136, Loss: 0.29454612731933594, Accuracy: 88.87533569335938
Epoch 137, Loss: 0.29448485374450684, Accuracy: 88.8755111694336
Epoch 138, Loss: 0.29441899061203003, Accuracy: 88.87596893310547
Epoch 139, Loss: 0.29435136914253235, Accuracy: 88.87740325927734
Epoch 140, Loss: 0.2942884862422943, Accuracy: 88.87841796875
Epoch 141, Loss: 0.2942271828651428, Accuracy: 88.87970733642578
Epoch 142, Loss: 0.29416200518608093, Accuracy: 88.88127136230469
Epoch 143, Loss: 0.2941035032272339, Accuracy: 88.88298034667969
Epoch 144, Loss: 0.2940468192100525, Accuracy: 88.88341522216797
Epoch 145, Loss: 0.2939951419830322, Accuracy: 88.88469696044922
Epoch 146, Loss: 0.2939

Epoch 260, Loss: 0.28791454434394836, Accuracy: 89.01561737060547
Epoch 261, Loss: 0.28786417841911316, Accuracy: 89.01668548583984
Epoch 262, Loss: 0.28781363368034363, Accuracy: 89.01809692382812
Epoch 263, Loss: 0.2877650260925293, Accuracy: 89.01879119873047
Epoch 264, Loss: 0.28771254420280457, Accuracy: 89.01972961425781
Epoch 265, Loss: 0.28766462206840515, Accuracy: 89.0207748413086
Epoch 266, Loss: 0.28761789202690125, Accuracy: 89.02159118652344
Epoch 267, Loss: 0.28756463527679443, Accuracy: 89.02251434326172
Epoch 268, Loss: 0.28751909732818604, Accuracy: 89.0235595703125
Epoch 269, Loss: 0.2874683439731598, Accuracy: 89.02423858642578
Epoch 270, Loss: 0.287422776222229, Accuracy: 89.02540588378906
Epoch 271, Loss: 0.28738129138946533, Accuracy: 89.02679443359375
Epoch 272, Loss: 0.28733476996421814, Accuracy: 89.02818298339844
Epoch 273, Loss: 0.28729522228240967, Accuracy: 89.02909088134766
Epoch 274, Loss: 0.28724485635757446, Accuracy: 89.03035736083984
Epoch 275, Loss:

Epoch 389, Loss: 0.2824917137622833, Accuracy: 89.15105438232422
Epoch 390, Loss: 0.2824551463127136, Accuracy: 89.15235137939453
Epoch 391, Loss: 0.28241395950317383, Accuracy: 89.15364074707031
Epoch 392, Loss: 0.282382071018219, Accuracy: 89.1544189453125
Epoch 393, Loss: 0.282349169254303, Accuracy: 89.15550231933594
Epoch 394, Loss: 0.2823067903518677, Accuracy: 89.15657806396484
Epoch 395, Loss: 0.28226393461227417, Accuracy: 89.15766143798828
Epoch 396, Loss: 0.28221964836120605, Accuracy: 89.15894317626953
Epoch 397, Loss: 0.2821780741214752, Accuracy: 89.16011810302734
Epoch 398, Loss: 0.2821462154388428, Accuracy: 89.16067504882812
Epoch 399, Loss: 0.28212615847587585, Accuracy: 89.16153717041016
Epoch 400, Loss: 0.282087504863739, Accuracy: 89.16260528564453
Epoch 401, Loss: 0.2820552587509155, Accuracy: 89.16397857666016
Epoch 402, Loss: 0.2820222079753876, Accuracy: 89.16524505615234
Epoch 403, Loss: 0.2819862961769104, Accuracy: 89.1658935546875
Epoch 404, Loss: 0.2819442

Epoch 516, Loss: 0.27827709913253784, Accuracy: 89.27200317382812
Epoch 517, Loss: 0.2782503068447113, Accuracy: 89.27294921875
Epoch 518, Loss: 0.2782178223133087, Accuracy: 89.27379608154297
Epoch 519, Loss: 0.2781880795955658, Accuracy: 89.27474212646484
Epoch 520, Loss: 0.2781613767147064, Accuracy: 89.27567291259766
Epoch 521, Loss: 0.27813294529914856, Accuracy: 89.27670288085938
Epoch 522, Loss: 0.2780970633029938, Accuracy: 89.27735900878906
Epoch 523, Loss: 0.27806615829467773, Accuracy: 89.27857208251953
Epoch 524, Loss: 0.2780330181121826, Accuracy: 89.2793197631836
Epoch 525, Loss: 0.27800193428993225, Accuracy: 89.2798843383789
Epoch 526, Loss: 0.27797043323516846, Accuracy: 89.2809066772461
Epoch 527, Loss: 0.2779357135295868, Accuracy: 89.28201293945312
Epoch 528, Loss: 0.2779049873352051, Accuracy: 89.28276062011719
Epoch 529, Loss: 0.2778737545013428, Accuracy: 89.28358459472656
Epoch 530, Loss: 0.2778422236442566, Accuracy: 89.28496551513672
Epoch 531, Loss: 0.2778115

Epoch 645, Loss: 0.2746112048625946, Accuracy: 89.37899780273438
Epoch 646, Loss: 0.27458423376083374, Accuracy: 89.37975311279297
Epoch 647, Loss: 0.2745620608329773, Accuracy: 89.38066864013672
Epoch 648, Loss: 0.274534672498703, Accuracy: 89.38142395019531
Epoch 649, Loss: 0.2745113670825958, Accuracy: 89.38225555419922
Epoch 650, Loss: 0.27448269724845886, Accuracy: 89.38300323486328
Epoch 651, Loss: 0.27445968985557556, Accuracy: 89.38367462158203
Epoch 652, Loss: 0.27443403005599976, Accuracy: 89.3842544555664
Epoch 653, Loss: 0.27440765500068665, Accuracy: 89.38484191894531
Epoch 654, Loss: 0.2743796408176422, Accuracy: 89.38574981689453
Epoch 655, Loss: 0.27435892820358276, Accuracy: 89.38624572753906
Epoch 656, Loss: 0.2743324339389801, Accuracy: 89.38666534423828
Epoch 657, Loss: 0.27430248260498047, Accuracy: 89.38741302490234
Epoch 658, Loss: 0.2742719054222107, Accuracy: 89.38806915283203
Epoch 659, Loss: 0.2742467522621155, Accuracy: 89.38872528076172
Epoch 660, Loss: 0.2

Epoch 772, Loss: 0.271564245223999, Accuracy: 89.46855926513672
Epoch 773, Loss: 0.2715420424938202, Accuracy: 89.46918487548828
Epoch 774, Loss: 0.2715189456939697, Accuracy: 89.46979522705078
Epoch 775, Loss: 0.2714928388595581, Accuracy: 89.4705581665039
Epoch 776, Loss: 0.27146583795547485, Accuracy: 89.47117614746094
Epoch 777, Loss: 0.2714409828186035, Accuracy: 89.47193908691406
Epoch 778, Loss: 0.27141711115837097, Accuracy: 89.47254943847656
Epoch 779, Loss: 0.27139413356781006, Accuracy: 89.47338104248047
Epoch 780, Loss: 0.27137139439582825, Accuracy: 89.47384643554688
Epoch 781, Loss: 0.271347314119339, Accuracy: 89.47445678710938
Epoch 782, Loss: 0.27132490277290344, Accuracy: 89.47514343261719
Epoch 783, Loss: 0.2713008522987366, Accuracy: 89.47582244873047
Epoch 784, Loss: 0.2712755799293518, Accuracy: 89.47672271728516
Epoch 785, Loss: 0.2712548077106476, Accuracy: 89.47695922851562
Epoch 786, Loss: 0.2712315022945404, Accuracy: 89.47756958007812
Epoch 787, Loss: 0.2712

Epoch 901, Loss: 0.2688314616680145, Accuracy: 89.54776763916016
Epoch 902, Loss: 0.2688116431236267, Accuracy: 89.54834747314453
Epoch 903, Loss: 0.2687913477420807, Accuracy: 89.54913330078125
Epoch 904, Loss: 0.2687746286392212, Accuracy: 89.54950714111328
Epoch 905, Loss: 0.26875948905944824, Accuracy: 89.54994201660156
Epoch 906, Loss: 0.26874127984046936, Accuracy: 89.55072021484375
Epoch 907, Loss: 0.26872164011001587, Accuracy: 89.55149841308594
Epoch 908, Loss: 0.2687003016471863, Accuracy: 89.55220794677734
Epoch 909, Loss: 0.2686786651611328, Accuracy: 89.55284118652344
Epoch 910, Loss: 0.2686573565006256, Accuracy: 89.55321502685547
Epoch 911, Loss: 0.26863807439804077, Accuracy: 89.55371856689453
Epoch 912, Loss: 0.2686171531677246, Accuracy: 89.55428314208984
Epoch 913, Loss: 0.2685982584953308, Accuracy: 89.55485534667969
Epoch 914, Loss: 0.26858431100845337, Accuracy: 89.55548858642578
Epoch 915, Loss: 0.2685617506504059, Accuracy: 89.55599212646484
Epoch 916, Loss: 0.2

## 데이터셋 및 학습 파라미터 저장

In [43]:
np.savez_compressed('ch2_dataset.npz', inputs=pts, labels=labels)

W_h, b_h = model.d1.get_weights()
W_o, b_o = model.d2.get_weights()
W_h = np.transpose(W_h)
W_o = np.transpose(W_o)
np.savez_compressed('ch2_parameters.npz',
                    W_h=W_h,
                    b_h=b_h,
                    W_o=W_o,
                    b_o=b_o,)
