# 경사 하강법을 이용한 얕은 신경망 학습


In [14]:
import tensorflow as tf
import numpy as np

## 하이퍼 파라미터 설정

In [15]:
EPOCHS= 1000

## 네트워크 구조 정의
### 얕은 신경망
#### 입력 계층 : 2, 은닉 계층 : 128 (Sigmoid activation), 출력 계층 : 10 (Softmax activation)

In [16]:
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.d1 = tf.keras.layers.Dense(128, input_dim=2, activation='sigmoid')
        self.d2 = tf.keras.layers.Dense(10, activation='softmax')
        
    def call(self, x, training=None, mask=None):
        x = self.d1(x)
        return self.d2(x)

## 학습 루프 정의

In [17]:
@tf.function
def train_step(model, inputs, labels, loss_object, optimizer, train_loss, train_metric):
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables) #df(x)/dx
    
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_metric(labels, predictions)

## 데이터셋 생성, 전처리

In [18]:
np.random.seed(0)

pts = list()
labels = list()
center_pts = np.random.uniform(-8.0, 8.0, (10, 2))
for label, center_pt in enumerate(center_pts):
    for _ in range(100):
        pts.append(center_pt + np.random.randn(*center_pt.shape))
        labels.append(label)
        
pts = np.stack(pts, axis=0).astype(np.float32)
labels = np.stack(labels, axis=0)

train_ds = tf.data.Dataset.from_tensor_slices((pts, labels)).shuffle(1000).batch(32)


## 모델 생성

In [19]:
model = MyModel()

## 손실 함수 및 최적화 알고리즘 설정
### CrossEntropy, Adam Optimizer

In [20]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

## 평가 지표 설정
### Accuracy

In [21]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

## 학습 루프

In [23]:
for epoch in range(EPOCHS):
    for x, label in train_ds:
        train_step(model, x, label, loss_object, optimizer, train_loss, train_accuracy)
        
    template = 'Epoch {}, Loss: {}, Accuracy: {}'
    print(template.format(epoch + 1,
                         train_loss.result(),
                         train_accuracy.result() * 100))

Epoch 1, Loss: 0.37851738929748535, Accuracy: 87.62191772460938
Epoch 2, Loss: 0.3779905140399933, Accuracy: 87.62772369384766
Epoch 3, Loss: 0.3774819076061249, Accuracy: 87.63710021972656
Epoch 4, Loss: 0.37699422240257263, Accuracy: 87.64234161376953
Epoch 5, Loss: 0.3765145242214203, Accuracy: 87.65022277832031
Epoch 6, Loss: 0.37603121995925903, Accuracy: 87.65803527832031
Epoch 7, Loss: 0.3755563795566559, Accuracy: 87.66355895996094
Epoch 8, Loss: 0.3751174211502075, Accuracy: 87.67035675048828
Epoch 9, Loss: 0.3746238648891449, Accuracy: 87.67841339111328
Epoch 10, Loss: 0.37412312626838684, Accuracy: 87.68508911132812
Epoch 11, Loss: 0.3736244738101959, Accuracy: 87.692138671875
Epoch 12, Loss: 0.3731236755847931, Accuracy: 87.69913482666016
Epoch 13, Loss: 0.37263673543930054, Accuracy: 87.70475769042969
Epoch 14, Loss: 0.37220054864883423, Accuracy: 87.71292877197266
Epoch 15, Loss: 0.37173596024513245, Accuracy: 87.71717071533203
Epoch 16, Loss: 0.37125375866889954, Accurac

Epoch 132, Loss: 0.33472153544425964, Accuracy: 88.23970794677734
Epoch 133, Loss: 0.3345075249671936, Accuracy: 88.2407455444336
Epoch 134, Loss: 0.3342846632003784, Accuracy: 88.2454605102539
Epoch 135, Loss: 0.334054172039032, Accuracy: 88.24957275390625
Epoch 136, Loss: 0.3338248133659363, Accuracy: 88.25254821777344
Epoch 137, Loss: 0.333598792552948, Accuracy: 88.2571792602539
Epoch 138, Loss: 0.3333730697631836, Accuracy: 88.26067352294922
Epoch 139, Loss: 0.33314648270606995, Accuracy: 88.2644271850586
Epoch 140, Loss: 0.3329368531703949, Accuracy: 88.26787567138672
Epoch 141, Loss: 0.33271703124046326, Accuracy: 88.270751953125
Epoch 142, Loss: 0.3325214087963104, Accuracy: 88.27472686767578
Epoch 143, Loss: 0.3323056697845459, Accuracy: 88.27894592285156
Epoch 144, Loss: 0.3321245312690735, Accuracy: 88.28121185302734
Epoch 145, Loss: 0.33191147446632385, Accuracy: 88.28401947021484
Epoch 146, Loss: 0.3317216634750366, Accuracy: 88.28791046142578
Epoch 147, Loss: 0.3315386176

Epoch 259, Loss: 0.3143754303455353, Accuracy: 88.55744171142578
Epoch 260, Loss: 0.31425148248672485, Accuracy: 88.56004333496094
Epoch 261, Loss: 0.3141235113143921, Accuracy: 88.56159210205078
Epoch 262, Loss: 0.3139972388744354, Accuracy: 88.56312561035156
Epoch 263, Loss: 0.31386643648147583, Accuracy: 88.56548309326172
Epoch 264, Loss: 0.3137337565422058, Accuracy: 88.56700897216797
Epoch 265, Loss: 0.3136039078235626, Accuracy: 88.5693588256836
Epoch 266, Loss: 0.31347399950027466, Accuracy: 88.57148742675781
Epoch 267, Loss: 0.3133600056171417, Accuracy: 88.5736083984375
Epoch 268, Loss: 0.3132394254207611, Accuracy: 88.57530975341797
Epoch 269, Loss: 0.3131161034107208, Accuracy: 88.57720184326172
Epoch 270, Loss: 0.313008576631546, Accuracy: 88.5791015625
Epoch 271, Loss: 0.3128911256790161, Accuracy: 88.58098602294922
Epoch 272, Loss: 0.3127661645412445, Accuracy: 88.5840835571289
Epoch 273, Loss: 0.3126470148563385, Accuracy: 88.58655548095703
Epoch 274, Loss: 0.31251704692

Epoch 388, Loss: 0.30173230171203613, Accuracy: 88.75049591064453
Epoch 389, Loss: 0.301648885011673, Accuracy: 88.7522201538086
Epoch 390, Loss: 0.3015706241130829, Accuracy: 88.75312042236328
Epoch 391, Loss: 0.3014789819717407, Accuracy: 88.75467681884766
Epoch 392, Loss: 0.30140841007232666, Accuracy: 88.7562255859375
Epoch 393, Loss: 0.30132848024368286, Accuracy: 88.75728607177734
Epoch 394, Loss: 0.30124688148498535, Accuracy: 88.75882720947266
Epoch 395, Loss: 0.30118250846862793, Accuracy: 88.76020050048828
Epoch 396, Loss: 0.3010971248149872, Accuracy: 88.76107025146484
Epoch 397, Loss: 0.30101433396339417, Accuracy: 88.76227569580078
Epoch 398, Loss: 0.3009392321109772, Accuracy: 88.76396179199219
Epoch 399, Loss: 0.30086952447891235, Accuracy: 88.7649917602539
Epoch 400, Loss: 0.30078235268592834, Accuracy: 88.76683044433594
Epoch 401, Loss: 0.3006960153579712, Accuracy: 88.76817321777344
Epoch 402, Loss: 0.3006136119365692, Accuracy: 88.77032470703125
Epoch 403, Loss: 0.30

Epoch 517, Loss: 0.29308587312698364, Accuracy: 88.9047622680664
Epoch 518, Loss: 0.29302939772605896, Accuracy: 88.90543365478516
Epoch 519, Loss: 0.2929680347442627, Accuracy: 88.90664672851562
Epoch 520, Loss: 0.2929143011569977, Accuracy: 88.90731811523438
Epoch 521, Loss: 0.2928623557090759, Accuracy: 88.9079818725586
Epoch 522, Loss: 0.29280540347099304, Accuracy: 88.90892028808594
Epoch 523, Loss: 0.2927437126636505, Accuracy: 88.9105224609375
Epoch 524, Loss: 0.29268670082092285, Accuracy: 88.91131591796875
Epoch 525, Loss: 0.2926359176635742, Accuracy: 88.91278839111328
Epoch 526, Loss: 0.2925873398780823, Accuracy: 88.91411590576172
Epoch 527, Loss: 0.2925264239311218, Accuracy: 88.91570281982422
Epoch 528, Loss: 0.2924828827381134, Accuracy: 88.91742706298828
Epoch 529, Loss: 0.29243069887161255, Accuracy: 88.91793823242188
Epoch 530, Loss: 0.29236888885498047, Accuracy: 88.91911315917969
Epoch 531, Loss: 0.2923077344894409, Accuracy: 88.92056274414062
Epoch 532, Loss: 0.292

Epoch 643, Loss: 0.28672850131988525, Accuracy: 89.03623962402344
Epoch 644, Loss: 0.2867085933685303, Accuracy: 89.03711700439453
Epoch 645, Loss: 0.28665897250175476, Accuracy: 89.03812408447266
Epoch 646, Loss: 0.2866109609603882, Accuracy: 89.03935241699219
Epoch 647, Loss: 0.286570280790329, Accuracy: 89.04034423828125
Epoch 648, Loss: 0.28652098774909973, Accuracy: 89.04122161865234
Epoch 649, Loss: 0.2864700257778168, Accuracy: 89.04198455810547
Epoch 650, Loss: 0.2864273488521576, Accuracy: 89.04331970214844
Epoch 651, Loss: 0.2863822281360626, Accuracy: 89.04373168945312
Epoch 652, Loss: 0.28633618354797363, Accuracy: 89.04447937011719
Epoch 653, Loss: 0.2862938940525055, Accuracy: 89.04523468017578
Epoch 654, Loss: 0.286249041557312, Accuracy: 89.04621887207031
Epoch 655, Loss: 0.2862046957015991, Accuracy: 89.04730987548828
Epoch 656, Loss: 0.28615856170654297, Accuracy: 89.04793548583984
Epoch 657, Loss: 0.28611600399017334, Accuracy: 89.04925537109375
Epoch 658, Loss: 0.28

Epoch 773, Loss: 0.28159087896347046, Accuracy: 89.16477966308594
Epoch 774, Loss: 0.2815519869327545, Accuracy: 89.16572570800781
Epoch 775, Loss: 0.28151533007621765, Accuracy: 89.1669692993164
Epoch 776, Loss: 0.2814878523349762, Accuracy: 89.16800689697266
Epoch 777, Loss: 0.2814594507217407, Accuracy: 89.16864013671875
Epoch 778, Loss: 0.28142249584198, Accuracy: 89.1697769165039
Epoch 779, Loss: 0.2813868224620819, Accuracy: 89.17071533203125
Epoch 780, Loss: 0.28135502338409424, Accuracy: 89.17153930664062
Epoch 781, Loss: 0.2813252806663513, Accuracy: 89.17217254638672
Epoch 782, Loss: 0.28129297494888306, Accuracy: 89.17310333251953
Epoch 783, Loss: 0.28125423192977905, Accuracy: 89.17442321777344
Epoch 784, Loss: 0.2812170684337616, Accuracy: 89.17545318603516
Epoch 785, Loss: 0.2811761498451233, Accuracy: 89.1763687133789
Epoch 786, Loss: 0.2811441421508789, Accuracy: 89.17668914794922
Epoch 787, Loss: 0.2811135947704315, Accuracy: 89.1776123046875
Epoch 788, Loss: 0.2810821

Epoch 904, Loss: 0.2773009240627289, Accuracy: 89.2849349975586
Epoch 905, Loss: 0.2772690951824188, Accuracy: 89.28575134277344
Epoch 906, Loss: 0.2772384583950043, Accuracy: 89.28683471679688
Epoch 907, Loss: 0.2772122323513031, Accuracy: 89.28800201416016
Epoch 908, Loss: 0.27717795968055725, Accuracy: 89.2891616821289
Epoch 909, Loss: 0.27715709805488586, Accuracy: 89.2900619506836
Epoch 910, Loss: 0.2771303653717041, Accuracy: 89.2904281616211
Epoch 911, Loss: 0.2771042585372925, Accuracy: 89.2914047241211
Epoch 912, Loss: 0.27707937359809875, Accuracy: 89.29248046875
Epoch 913, Loss: 0.2770465314388275, Accuracy: 89.29336547851562
Epoch 914, Loss: 0.27701395750045776, Accuracy: 89.29425811767578
Epoch 915, Loss: 0.2769824266433716, Accuracy: 89.29558563232422
Epoch 916, Loss: 0.27695614099502563, Accuracy: 89.29630279541016
Epoch 917, Loss: 0.27692338824272156, Accuracy: 89.29673767089844
Epoch 918, Loss: 0.2768929898738861, Accuracy: 89.29771423339844
Epoch 919, Loss: 0.27685993

## 데이터셋 및 학습 파라미터 저장

In [24]:
np.savez_compressed('ch2_dataset.npz', inputs=pts, labels=labels)

W_h, b_h = model.d1.get_weights()
W_o, b_o = model.d2.get_weights()
W_h = np.transpose(W_h)
W_o = np.transpose(W_o)
np.savez_compressed('ch2_parameters.npz', 
                    W_h=W_h,
                    b_h=b_h,
                    W_o=W_o,
                    b_o=b_o)