# 8. 가치 기반 에이전트
『바닥부터 배우는 강화 학습』 8장에 있는 코드를 참고했습니다.

- colab에서 동작하기 위한 환경 설정

In [55]:
# Install Gym, pyvirtualdisplay and the classic-control extras. All installer
# output is redirected to /dev/null, so a failed install stays silent until
# the imports in the next cell raise.
# NOTE(review): no gym version is pinned, while the cells below use the
# 4-tuple `env.step` / bare-observation `env.reset` API — confirm the
# installed gym release still provides that API.
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!pip install gym[classic_control] > /dev/null 2>&1

In [56]:
import base64
import collections
import glob
import io
import random

import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from IPython import display as ipythondisplay
from IPython.display import HTML

# 8.2 딥 Q러닝

- Hyperparameters

In [57]:
# Training hyperparameters read as module-level globals by the cells below.
learning_rate = 5e-4   # step size handed to the Adam optimizer
gamma = 0.98           # discount applied to the bootstrapped next-state value
buffer_limit = 50_000  # maximum number of transitions kept in the replay buffer
batch_size = 32        # transitions drawn per gradient step

- Replay Buffer 클래스

In [58]:
class ReplayBuffer():
    """Bounded FIFO store of transitions used for experience replay.

    Every stored item is a 5-tuple ``(s, a, r, s_prime, done_mask)``. The
    capacity is taken from the module-level ``buffer_limit``; because the
    storage is a ``deque`` with ``maxlen``, appending to a full buffer
    discards the oldest transition.
    """

    def __init__(self):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        """Append one ``(s, a, r, s_prime, done_mask)`` tuple."""
        self.buffer.append(transition)

    def sample(self, n):
        """Draw ``n`` distinct transitions at random and stack them by field.

        Returns:
            A tuple ``(s, a, r, s_prime, done_mask)`` of tensors. ``s`` and
            ``s_prime`` are converted with ``dtype=torch.float``; ``a``, ``r``
            and ``done_mask`` each get a trailing dimension of size 1 and a
            dtype inferred by ``torch.tensor`` from the stored values.

        Raises:
            ValueError: propagated from ``random.sample`` when ``n`` exceeds
                the number of stored transitions.
        """
        picked = random.sample(self.buffer, n)

        # Wrap the scalar fields in one-element lists so they batch to (n, 1).
        rows = [
            (s, [a], [r], s_prime, [done_mask])
            for s, a, r, s_prime, done_mask in picked
        ]
        if rows:
            states, actions, rewards, next_states, masks = (
                list(column) for column in zip(*rows)
            )
        else:
            states, actions, rewards, next_states, masks = [], [], [], [], []

        return (
            torch.tensor(states, dtype=torch.float),
            torch.tensor(actions),
            torch.tensor(rewards),
            torch.tensor(next_states, dtype=torch.float),
            torch.tensor(masks),
        )

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

- Qnet 클래스: 액션-가치함수 딥러닝 모델

In [59]:
class Qnet(nn.Module):
    """Three-layer MLP that maps an observation to one Q-value per action.

    Args:
        state_dim: Length of the observation vector fed to the first layer.
        action_dim: Number of discrete actions, i.e. outputs of the last layer.
        hidden_dim: Width of both hidden layers.

    The defaults (4, 2, 128) reproduce the original fixed architecture, and
    the layer names are unchanged, so existing ``state_dict`` checkpoints
    still load.
    """

    def __init__(self, state_dim=4, action_dim=2, hidden_dim=128):
        super(Qnet, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return the Q-values for ``x``; the last dimension indexes actions."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def sample_action(self, obs, epsilon):
        """Pick an action epsilon-greedily for a single observation.

        Args:
            obs: Observation tensor accepted by ``forward``.
            epsilon: Probability of returning a uniformly random action.

        Returns:
            The chosen action index as a Python ``int``.
        """
        if random.random() < epsilon:
            # Exploration: no forward pass is needed for a random action.
            return random.randint(0, self.fc3.out_features - 1)

        # Exploitation: only the argmax is used, so skip autograd bookkeeping.
        with torch.no_grad():
            return self.forward(obs).argmax().item()

- 학습 함수

In [60]:
def train(q, q_target, memory, optimizer, n_updates=10, batch=None, discount=None):
    """Run a burst of Q-learning gradient steps on ``q`` from replayed data.

    Args:
        q: Online network being optimized; called on a batch of states.
        q_target: Network used only to evaluate the bootstrapped target.
        memory: Object whose ``sample(n)`` returns the tensors
            ``(s, a, r, s_prime, done_mask)``.
        optimizer: Optimizer constructed over ``q``'s parameters.
        n_updates: Number of sample-and-step iterations (default 10).
        batch: Mini-batch size; ``None`` uses the module-level ``batch_size``.
        discount: Discount factor; ``None`` uses the module-level ``gamma``.
    """
    if batch is None:
        batch = batch_size
    if discount is None:
        discount = gamma

    for _ in range(n_updates):
        s, a, r, s_prime, done_mask = memory.sample(batch)

        # Q-value of the action that was actually taken in each transition.
        q_a = q(s).gather(1, a)

        # The target is a constant for this step: evaluating it without
        # autograd keeps backward() from computing gradients for q_target.
        with torch.no_grad():
            max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1)
            # done_mask multiplies the bootstrap term, so a 0 entry removes it.
            target = r + discount * max_q_prime * done_mask

        loss = F.smooth_l1_loss(q_a, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

- 메인 함수

In [71]:
def main():
    """Train a DQN agent on CartPole-v1 and checkpoint the target network.

    Each episode is played epsilon-greedily with ``q`` and every transition
    is pushed to a ``ReplayBuffer``. Once the buffer holds more than 2000
    transitions, ``train`` is called after every episode. Every
    ``print_interval`` episodes the target network is synchronized with
    ``q``, the average episode score of the window is printed, and, once
    epsilon has reached its 1% floor, the target weights are saved to
    ``q_target.pth`` whenever that average beats the best seen so far.

    NOTE(review): this uses the ``env.reset()`` -> observation and
    ``env.step()`` -> 4-tuple Gym API; confirm the installed gym provides it.
    """
    env = gym.make('CartPole-v1')
    q = Qnet()
    q_target = Qnet()
    q_target.load_state_dict(q.state_dict())
    memory = ReplayBuffer()

    print_interval = 20
    optimizer = optim.Adam(q.parameters(), lr=learning_rate)

    score = 0.0            # reward summed over the current logging window
    window_episodes = 0    # episodes contributing to `score`
    max_score = 0.0        # best window average that triggered a checkpoint

    try:
        for n_epi in range(10000):
            epsilon = max(0.01, 0.08 - 0.01*(n_epi/200))  # Linear annealing from 8% to 1%
            s = env.reset()
            done = False

            while not done:
                a = q.sample_action(torch.from_numpy(s).float(), epsilon)
                s_prime, r, done, info = env.step(a)
                # NOTE(review): any `done` zeroes the bootstrap term; if the
                # env also reports time-limit cutoffs as done, those are
                # treated as terminal too — confirm that is intended.
                done_mask = 0.0 if done else 1.0
                # Rewards are scaled by 1/100 before being stored.
                memory.put((s, a, r/100.0, s_prime, done_mask))
                s = s_prime
                score += r
            window_episodes += 1

            if memory.size() > 2000:
                train(q, q_target, memory, optimizer)

            if n_epi % print_interval == 0 and n_epi != 0:
                q_target.load_state_dict(q.state_dict())
                # Divide by the episodes actually played: the first window
                # covers episodes 0..print_interval, one more than the rest.
                avg_score = score / window_episodes
                print("n_episode :{}, score : {:.1f}, n_buffer : {}, eps : {:.1f}%".format(
                    n_epi, avg_score, memory.size(), epsilon*100))
                # `max` returns the 0.01 literal itself once annealing is
                # finished, so the exact float comparison is reliable here.
                if epsilon == 0.01 and avg_score > max_score:
                    # Report the same average that the log line above shows.
                    print(f'>>>> save q_target.pth: {avg_score:.1f}')
                    torch.save(q_target.state_dict(), 'q_target.pth')
                    max_score = avg_score
                score = 0.0
                window_episodes = 0
    finally:
        # Release the environment even if training is interrupted or raises.
        env.close()

- 학습

In [72]:
# Run the training loop defined above; it prints a progress line periodically.
main() 

  deprecation(
  deprecation(


n_episode :20, score : 9.8, n_buffer : 197, eps : 7.9%
n_episode :40, score : 9.2, n_buffer : 381, eps : 7.8%
n_episode :60, score : 9.7, n_buffer : 575, eps : 7.7%
n_episode :80, score : 9.8, n_buffer : 770, eps : 7.6%
n_episode :100, score : 9.4, n_buffer : 959, eps : 7.5%
n_episode :120, score : 9.8, n_buffer : 1155, eps : 7.4%
n_episode :140, score : 9.8, n_buffer : 1350, eps : 7.3%
n_episode :160, score : 9.8, n_buffer : 1545, eps : 7.2%
n_episode :180, score : 9.2, n_buffer : 1730, eps : 7.1%
n_episode :200, score : 9.8, n_buffer : 1925, eps : 7.0%
n_episode :220, score : 9.8, n_buffer : 2122, eps : 6.9%
n_episode :240, score : 10.1, n_buffer : 2323, eps : 6.8%
n_episode :260, score : 9.6, n_buffer : 2514, eps : 6.7%
n_episode :280, score : 9.9, n_buffer : 2713, eps : 6.6%
n_episode :300, score : 12.1, n_buffer : 2955, eps : 6.5%
n_episode :320, score : 18.1, n_buffer : 3316, eps : 6.4%
n_episode :340, score : 72.0, n_buffer : 4756, eps : 6.3%
n_episode :360, score : 193.7, n_buf

- 결과확인 (준비)

In [83]:
# Build a fresh environment and network for evaluation, then restore the
# network weights from the checkpoint file 'q_target.pth'.
env = gym.make('CartPole-v1')
q_target = Qnet()
saved_weights = torch.load('q_target.pth')
q_target.load_state_dict(saved_weights)

  deprecation(
  deprecation(


<All keys matched successfully>

- 결과확인 (env record)

In [84]:
# Wrap the evaluation environment in gym's RecordVideo wrapper, passing
# './video' as the output folder (the folder show_video() later searches).
env = gym.wrappers.RecordVideo(env, './video')

  logger.warn(


- 실행

In [85]:
# Play one episode with epsilon = 0.0 (always the argmax action) and print
# the chosen action and the reward of every step.
s = env.reset()
done = False

while not done:
    obs_tensor = torch.from_numpy(s).float()
    action = q_target.sample_action(obs_tensor, 0.0)
    s_prime, r, done, info = env.step(action)
    print(action, r)
    s = s_prime

  logger.deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/
  deprecation(


0 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
1 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.

In [86]:
# play recorded video
def show_video(video_dir='video'):
    """Embed one recorded ``.mp4`` from ``video_dir`` in the notebook output.

    Args:
        video_dir: Directory searched (non-recursively) for ``*.mp4`` files.
            The default ``'video'`` keeps the original ``video/*.mp4`` lookup.

    The file that sorts first by name is base64-encoded into an inline HTML
    ``<video>`` element. Prints ``Could not find video`` if nothing matches.
    """
    # glob returns matches in arbitrary order; sort so the pick is stable.
    mp4list = sorted(glob.glob(f'{video_dir}/*.mp4'))
    if not mp4list:
        print("Could not find video")
        return

    # Read-only binary mode, closed deterministically by the context manager.
    with open(mp4list[0], 'rb') as video_file:
        encoded = base64.b64encode(video_file.read())

    ipythondisplay.display(HTML(data='''
            <video alt="test" autoplay loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
            </video>'''.format(encoded.decode('ascii'))))

In [87]:
# Display a recorded episode inline, or print a notice if no video exists.
show_video()