In [None]:
import copy
import pylab
import random
import numpy as np
from environment import Env
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam


# 딥살사 인공신경망
class DeepSARSA(tf.keras.Model):
    """Fully connected Q-network used by the Deep SARSA agent.

    The network stacks two 30-unit ReLU layers followed by an output layer
    with ``action_size`` units and no activation, so the forward pass
    yields one Q-value per action for every input row.
    """

    def __init__(self, action_size):
        """Create the three dense layers.

        Args:
            action_size: Number of units in the output layer (one per action).
        """
        super().__init__()
        hidden_units = 30
        # Attribute names are kept as-is: Keras tracks layers by attribute,
        # so renaming them would change the keys of saved weights.
        self.fc1 = Dense(hidden_units, activation='relu')
        self.fc2 = Dense(hidden_units, activation='relu')
        self.fc_out = Dense(action_size)

    def call(self, x):
        """Return the Q-values produced by passing ``x`` through all layers."""
        hidden = self.fc2(self.fc1(x))
        return self.fc_out(hidden)


# 그리드월드 예제에서의 딥살사 에이전트
class DeepSARSAgent:
    """Epsilon-greedy Deep SARSA agent for the grid-world example.

    The agent owns a ``DeepSARSA`` Q-network and an Adam optimizer, picks
    actions with an epsilon-greedy policy and updates the network from a
    single ``<s, a, r, s', a'>`` sample at a time.

    States are indexed with ``[0]`` after the forward pass, so they are
    expected to be batches of one row (the training script in this file
    reshapes them to ``[1, state_size]``).
    """

    def __init__(self, state_size, action_size):
        """Store the problem sizes and build the model and optimizer.

        Args:
            state_size: Length of the state vector.
            action_size: Number of discrete actions.
        """
        # Sizes of the state vector and of the action set.
        self.state_size = state_size
        self.action_size = action_size

        # Deep SARSA hyperparameters.
        self.discount_factor = 0.99
        self.learning_rate = 0.001
        self.epsilon = 1.
        self.epsilon_decay = .9999
        self.epsilon_min = 0.01
        self.model = DeepSARSA(self.action_size)
        # `lr` is a deprecated alias of `learning_rate`; the explicit name
        # avoids the deprecation warning and keeps working on newer Keras.
        self.optimizer = Adam(learning_rate=self.learning_rate)

    def get_action(self, state):
        """Choose an action with the epsilon-greedy policy.

        Args:
            state: Batch of one state row fed to the Q-network.

        Returns:
            A random action index with probability ``epsilon``, otherwise
            the index of the largest predicted Q-value.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)

        q_values = self.model(state)
        return np.argmax(q_values[0])

    def train_model(self, state, action, reward, next_state, next_action, done):
        """Update the Q-network from one ``<s, a, r, s', a'>`` sample.

        Epsilon is decayed by ``epsilon_decay`` on every call until it is
        no longer above ``epsilon_min``.

        Args:
            state: Batch of one state row in which ``action`` was taken.
            action: Index of the action taken in ``state``.
            reward: Reward received for the transition.
            next_state: Batch of one row holding the resulting state.
            next_action: Index of the action chosen in ``next_state``.
            done: Truthy when the episode ended with this transition; the
                bootstrapped term is then dropped from the target.
        """
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        # The TD target is computed outside the gradient tape so that it is
        # treated as a constant (semi-gradient SARSA). Inside the tape the
        # gradient would also flow through Q(s', a') and pull the target
        # towards the prediction.
        next_q = self.model(next_state)[0][next_action]
        # When done is true there is no next state, so only the reward remains.
        target = reward + (1 - done) * self.discount_factor * next_q

        # Read the variables only after a forward pass: a subclassed model
        # creates its weights on the first call, so reading them earlier
        # yields an empty list and the first update would do nothing.
        model_params = self.model.trainable_variables
        with tf.GradientTape() as tape:
            # Trainable variables are watched by the tape automatically.
            q_values = self.model(state)[0]
            one_hot_action = tf.one_hot([action], self.action_size)
            predict = tf.reduce_sum(one_hot_action * q_values, axis=1)

            # Mean squared TD error.
            loss = tf.reduce_mean(tf.square(target - predict))

        # Step the parameters in the direction that reduces the loss.
        grads = tape.gradient(loss, model_params)
        self.optimizer.apply_gradients(zip(grads, model_params))


if __name__ == "__main__":
    # Local import: only this script section needs it.
    import os

    # Create the environment and the agent.
    env = Env(render_speed=0.01)
    state_size = 15
    action_space = [0, 1, 2, 3, 4]
    action_size = len(action_space)
    agent = DeepSARSAgent(state_size, action_size)

    # savefig does not create missing directories, so make sure the output
    # folders exist before training starts (the model folder is created as
    # well so that saving weights cannot fail for the same reason).
    os.makedirs("./save_graph", exist_ok=True)
    os.makedirs("save_model", exist_ok=True)

    scores, episodes = [], []

    EPISODES = 1000
    for e in range(EPISODES):
        done = False
        score = 0
        # Reset the environment and shape the state as a batch of one row.
        state = env.reset()
        state = np.reshape(state, [1, state_size])

        while not done:
            # Choose an action for the current state.
            action = agent.get_action(state)

            # Advance the environment one time step and collect the sample,
            # including the action that will be taken in the next state.
            next_state, reward, done = env.step(action)
            next_state = np.reshape(next_state, [1, state_size])
            next_action = agent.get_action(next_state)

            # Train the model on the collected sample.
            agent.train_model(state, action, reward, next_state,
                              next_action, done)
            score += reward
            state = next_state

            if done:
                # Report the result of the finished episode.
                print("episode: {:3d} | score: {:3d} | epsilon: {:.3f}".format(
                      e, score, agent.epsilon))

                scores.append(score)
                episodes.append(e)
                # Clear the figure first: the full history is re-plotted each
                # time, so without this one extra line is added per episode
                # and every save has to redraw all of them.
                pylab.clf()
                pylab.plot(episodes, scores, 'b')
                pylab.xlabel("episode")
                pylab.ylabel("score")
                pylab.savefig("./save_graph/graph.png")

        # Save the model weights every 100 episodes.
        if e % 100 == 0:
            agent.model.save_weights('save_model/model', save_format='tf')

2022-11-23 15:04:21.375394: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-11-23 15:04:21.375735: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
  super(Adam, self).__init__(name, **kwargs)


Metal device set to: Apple M1 Pro
episode:   0 | score:  -5 | epsilon: 0.993
episode:   1 | score: -34 | epsilon: 0.971
episode:   2 | score: -14 | epsilon: 0.965
episode:   3 | score:   1 | epsilon: 0.963
episode:   4 | score: -30 | epsilon: 0.940
episode:   5 | score:  -2 | epsilon: 0.937
episode:   6 | score: -12 | epsilon: 0.921
episode:   7 | score:  -9 | epsilon: 0.912
episode:   8 | score:   1 | epsilon: 0.910
episode:   9 | score:  -8 | epsilon: 0.902
episode:  10 | score:  -9 | epsilon: 0.899
episode:  11 | score:  -2 | epsilon: 0.893
episode:  12 | score:   0 | epsilon: 0.891
episode:  13 | score:  -2 | epsilon: 0.886
episode:  14 | score: -25 | epsilon: 0.867
episode:  15 | score:  -7 | epsilon: 0.859
episode:  16 | score:  -4 | epsilon: 0.851
episode:  17 | score:  -2 | epsilon: 0.847
episode:  18 | score: -38 | epsilon: 0.825
episode:  19 | score: -11 | epsilon: 0.816
episode:  20 | score:   1 | epsilon: 0.815
episode:  21 | score:   0 | epsilon: 0.813
episode:  22 | score

episode: 190 | score:  -1 | epsilon: 0.396
episode: 191 | score:  -1 | epsilon: 0.395
episode: 192 | score:  -3 | epsilon: 0.394
episode: 193 | score:  -1 | epsilon: 0.393
episode: 194 | score:   0 | epsilon: 0.392
episode: 195 | score:  -2 | epsilon: 0.390
episode: 196 | score:  -2 | epsilon: 0.389
episode: 197 | score:   0 | epsilon: 0.387
episode: 198 | score:   1 | epsilon: 0.387
episode: 199 | score:   0 | epsilon: 0.386
episode: 200 | score:   1 | epsilon: 0.385
episode: 201 | score:  -1 | epsilon: 0.384
episode: 202 | score:   0 | epsilon: 0.384
episode: 203 | score:   0 | epsilon: 0.383
episode: 204 | score:   1 | epsilon: 0.383
episode: 205 | score:   0 | epsilon: 0.382
episode: 206 | score:   1 | epsilon: 0.382
episode: 207 | score:   1 | epsilon: 0.381
episode: 208 | score:   1 | epsilon: 0.380
episode: 209 | score:   0 | epsilon: 0.379
episode: 210 | score:   1 | epsilon: 0.379
episode: 211 | score:  -1 | epsilon: 0.378
episode: 212 | score:   0 | epsilon: 0.378
episode: 21

episode: 381 | score:   1 | epsilon: 0.267
episode: 382 | score:   0 | epsilon: 0.267
episode: 383 | score:   1 | epsilon: 0.266
episode: 384 | score:   1 | epsilon: 0.266
episode: 385 | score:   0 | epsilon: 0.265
episode: 386 | score:   1 | epsilon: 0.264
episode: 387 | score:   1 | epsilon: 0.263
episode: 388 | score:   1 | epsilon: 0.263
episode: 389 | score:   0 | epsilon: 0.262
episode: 390 | score:   0 | epsilon: 0.262
episode: 391 | score:  -3 | epsilon: 0.261
episode: 392 | score:   0 | epsilon: 0.261
episode: 393 | score:  -1 | epsilon: 0.260
episode: 394 | score:   1 | epsilon: 0.259
episode: 395 | score:  -1 | epsilon: 0.257
episode: 396 | score:   1 | epsilon: 0.257
episode: 397 | score:   0 | epsilon: 0.257
episode: 398 | score:   1 | epsilon: 0.256
episode: 399 | score:   1 | epsilon: 0.256
episode: 400 | score:   1 | epsilon: 0.255
episode: 401 | score:   1 | epsilon: 0.255
episode: 402 | score:   0 | epsilon: 0.255
episode: 403 | score:   0 | epsilon: 0.254
episode: 40

episode: 572 | score:   0 | epsilon: 0.183
episode: 573 | score:   1 | epsilon: 0.183
episode: 574 | score:  -1 | epsilon: 0.183
episode: 575 | score:   1 | epsilon: 0.183
episode: 576 | score:   0 | epsilon: 0.182
episode: 577 | score:   1 | epsilon: 0.182
episode: 578 | score:   0 | epsilon: 0.181
episode: 579 | score:   1 | epsilon: 0.181
episode: 580 | score:   1 | epsilon: 0.181
episode: 581 | score:   1 | epsilon: 0.181
episode: 582 | score:   1 | epsilon: 0.180
episode: 583 | score:   0 | epsilon: 0.180
episode: 584 | score:   0 | epsilon: 0.180
episode: 585 | score:  -1 | epsilon: 0.179
episode: 586 | score:  -1 | epsilon: 0.179
episode: 587 | score:  -1 | epsilon: 0.178
episode: 588 | score:  -1 | epsilon: 0.177
episode: 589 | score:   1 | epsilon: 0.177
episode: 590 | score:   0 | epsilon: 0.176
episode: 591 | score:   0 | epsilon: 0.175
episode: 592 | score:   0 | epsilon: 0.174
episode: 593 | score:   1 | epsilon: 0.174
episode: 594 | score:   1 | epsilon: 0.174
episode: 59

episode: 763 | score:   1 | epsilon: 0.122
episode: 764 | score:   1 | epsilon: 0.122
episode: 765 | score:  -1 | epsilon: 0.122
episode: 766 | score:   1 | epsilon: 0.122
episode: 767 | score:   1 | epsilon: 0.122
episode: 768 | score:   1 | epsilon: 0.121
episode: 769 | score:   1 | epsilon: 0.121
episode: 770 | score:   1 | epsilon: 0.121
episode: 771 | score:   1 | epsilon: 0.121
episode: 772 | score:   0 | epsilon: 0.121
episode: 773 | score:   1 | epsilon: 0.120
episode: 774 | score:   1 | epsilon: 0.120
episode: 775 | score:   1 | epsilon: 0.120
episode: 776 | score:   1 | epsilon: 0.120
episode: 777 | score:   1 | epsilon: 0.120
episode: 778 | score:   1 | epsilon: 0.119
episode: 779 | score:   1 | epsilon: 0.119
episode: 780 | score:   1 | epsilon: 0.119
episode: 781 | score:   0 | epsilon: 0.118
episode: 782 | score:   1 | epsilon: 0.118
episode: 783 | score:   1 | epsilon: 0.117
episode: 784 | score:   1 | epsilon: 0.117
episode: 785 | score:   1 | epsilon: 0.117
episode: 78

episode: 954 | score:   0 | epsilon: 0.056
episode: 955 | score:   1 | epsilon: 0.056
episode: 956 | score:   1 | epsilon: 0.056
episode: 957 | score:   1 | epsilon: 0.055
episode: 958 | score:   0 | epsilon: 0.055
episode: 959 | score:   0 | epsilon: 0.055
episode: 960 | score:   0 | epsilon: 0.055
episode: 961 | score:   0 | epsilon: 0.055
episode: 962 | score:   0 | epsilon: 0.055
episode: 963 | score:   1 | epsilon: 0.055
episode: 964 | score:   1 | epsilon: 0.055
episode: 965 | score:   1 | epsilon: 0.055
episode: 966 | score:   1 | epsilon: 0.055
episode: 967 | score:   0 | epsilon: 0.055
episode: 968 | score:   0 | epsilon: 0.055
episode: 969 | score:   0 | epsilon: 0.054
episode: 970 | score:   1 | epsilon: 0.054
episode: 971 | score:   1 | epsilon: 0.053
episode: 972 | score:   0 | epsilon: 0.053
episode: 973 | score:  -1 | epsilon: 0.053
episode: 974 | score:   0 | epsilon: 0.053
