# 9. 정책 기반 에이전트
『바닥부터 배우는 강화 학습』 9장에 있는 코드를 참고했습니다.

- colab에서 동작하기 위한 환경 설정

In [1]:
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!pip install gym[classic_control] > /dev/null 2>&1

In [2]:
import base64
import collections
import glob
import io
import random

import gym

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

from IPython import display as ipythondisplay
from IPython.display import HTML

# 9.2 REINFORCE 알고리즘

- Hyperparameters

In [3]:
# Step size handed to the Adam optimizer of the policy network.
learning_rate = 0.0002
# Discount factor applied when accumulating the episode return.
gamma         = 0.98

- 정책 네트워크 클래스

In [4]:
class Policy(nn.Module):
    """Two-layer softmax policy network trained with REINFORCE.

    Maps a 4-dimensional observation to a probability distribution over
    2 discrete actions. The (reward, action-probability) pairs of one
    episode are buffered with ``put_data`` and consumed by ``train_net``.

    Reads the module-level ``learning_rate`` (at construction time) and
    ``gamma`` (at training time).
    """

    def __init__(self):
        super().__init__()
        # Per-episode buffer of (reward, probability of the action taken).
        self.data = []

        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128, 2)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def forward(self, x):
        """Return action probabilities for observation(s) ``x``.

        ``x`` may be a single observation of shape ``(4,)`` or a batch of
        shape ``(N, 4)``; the result has shape ``(2,)`` or ``(N, 2)``.
        """
        x = F.relu(self.fc1(x))
        # Normalise over the action axis (the last one). For a single
        # observation this equals dim=0; for a batch it keeps each row a
        # proper distribution instead of normalising across the batch.
        return F.softmax(self.fc2(x), dim=-1)

    def put_data(self, item):
        """Buffer one ``(reward, action_probability)`` pair for training."""
        self.data.append(item)

    def train_net(self):
        """Apply one REINFORCE update from the buffered episode, then clear it.

        Walks the episode backwards so the discounted return ``R`` can be
        accumulated incrementally. Gradients of ``-log(prob) * R`` are summed
        across all steps before a single optimizer step.
        """
        self.optimizer.zero_grad()
        R = 0
        for r, prob in reversed(self.data):
            R = r + gamma * R
            loss = -torch.log(prob) * R
            # Each step owns its own graph, so backward() per step simply
            # accumulates into the parameters' .grad.
            loss.backward()
        self.optimizer.step()
        self.data = []

- 메인 함수

In [5]:
def main():
    """Train a REINFORCE policy on CartPole-v1 and checkpoint the best one.

    Runs 10000 episodes. After every ``print_interval`` completed episodes
    the average episode score of that window is printed, and the policy
    weights are saved to ``reinforce.pth`` whenever the window's total score
    exceeds the best total seen so far.
    """
    env = gym.make('CartPole-v1')
    pi = Policy()
    score = 0.0
    print_interval = 20

    max_score = 0.0

    try:
        for n_epi in range(10000):
            s = env.reset()
            done = False

            while not done:  # CartPole-v1 is forced to terminate at step 500.
                prob = pi(torch.from_numpy(s).float())
                m = Categorical(prob)
                a = m.sample()
                s_prime, r, done, info = env.step(a.item())
                pi.put_data((r, prob[a]))
                s = s_prime
                score += r

            pi.train_net()

            # Count completed episodes so every window holds exactly
            # print_interval episodes; testing n_epi itself made the first
            # window 21 episodes long while still dividing by 20.
            episodes_done = n_epi + 1
            if episodes_done % print_interval == 0:
                print("# of episode :{}, avg score : {}".format(episodes_done, score/print_interval))

                if score > max_score:
                    print(f'>>>> save reinforce.pth: {score:.1f}')
                    torch.save(pi.state_dict(), 'reinforce.pth')
                    max_score = score

                score = 0.0
    finally:
        # Release the environment even if training is interrupted.
        env.close()

- 학습

In [6]:
# Start REINFORCE training; progress is printed by main() as it runs.
main() 

  deprecation(
  deprecation(


# of episode :20, avg score : 22.4
>>>> save reinforce.pth: 448.0
# of episode :40, avg score : 24.45
>>>> save reinforce.pth: 489.0
# of episode :60, avg score : 25.4
>>>> save reinforce.pth: 508.0
# of episode :80, avg score : 26.95
>>>> save reinforce.pth: 539.0
# of episode :100, avg score : 25.65
# of episode :120, avg score : 31.4
>>>> save reinforce.pth: 628.0
# of episode :140, avg score : 26.3
# of episode :160, avg score : 30.3
# of episode :180, avg score : 39.45
>>>> save reinforce.pth: 789.0
# of episode :200, avg score : 34.4
# of episode :220, avg score : 37.8
# of episode :240, avg score : 41.05
>>>> save reinforce.pth: 821.0
# of episode :260, avg score : 36.1
# of episode :280, avg score : 37.8
# of episode :300, avg score : 33.8
# of episode :320, avg score : 46.85
>>>> save reinforce.pth: 937.0
# of episode :340, avg score : 46.85
# of episode :360, avg score : 42.7
# of episode :380, avg score : 51.55
>>>> save reinforce.pth: 1031.0
# of episode :400, avg score : 4

- 결과확인 (준비)

In [7]:
# Create a fresh environment and policy, then restore the weights from the
# 'reinforce.pth' checkpoint file.
env = gym.make('CartPole-v1')
pi = Policy()
pi.load_state_dict(torch.load('reinforce.pth'))

<All keys matched successfully>

- 결과확인 (env record)

In [8]:
# Wrap the environment so rollouts are recorded as video under ./video_reinforce.
env = gym.wrappers.RecordVideo(env, './video_reinforce')

  logger.warn(


- 실행

In [9]:
# Roll out a single episode with the loaded policy, printing the sampled
# action and reward of every step.
s, done = env.reset(), False

# Inference only: no_grad avoids building an autograd graph for each step.
with torch.no_grad():
    while not done:
        prob = pi(torch.from_numpy(s).float())
        m = Categorical(prob)
        action = m.sample()
        s_prime, r, done, info = env.step(action.item())
        s = s_prime
        print(action.item(), r)

# Close the environment once the episode is over so its resources are released.
env.close()

  logger.deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/
  deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/
  deprecation(


0 1.0
1 1.0
1 1.0
1 1.0
0 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
1 1.0
0 1.0
0 1.0
0 1.0
0 1.0
1 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
1 1.0
1 1.0
0 1.0
0 1.0
0 1.

In [10]:
# play recorded video
def show_video(video_dir='video_reinforce'):
    """Embed a recorded MP4 from ``video_dir`` in the notebook output.

    Args:
        video_dir: Directory searched (non-recursively) for ``*.mp4`` files.
            Defaults to the REINFORCE recording directory.

    The alphabetically first MP4 is base64-encoded into an inline HTML
    ``<video>`` element. Prints "Could not find video" when the directory
    contains no MP4 file.
    """
    # Sort so the chosen file does not depend on filesystem listing order.
    mp4list = sorted(glob.glob(f'{glob.escape(video_dir)}/*.mp4'))
    if not mp4list:
        print("Could not find video")
        return

    # Read-only binary mode; the context manager closes the handle.
    with open(mp4list[0], 'rb') as f:
        encoded = base64.b64encode(f.read())
    ipythondisplay.display(HTML(data='''
            <video alt="test" autoplay loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
            </video>'''.format(encoded.decode('ascii'))))

In [11]:
# Display the recorded REINFORCE rollout inline.
show_video()

# 9.3 액터-크리틱

- Hyperparameters

In [3]:
# Step size handed to the Adam optimizer of the actor-critic network.
learning_rate = 0.0002
# Discount factor used when forming the TD target.
gamma         = 0.98
# Maximum number of environment steps collected before each training update.
n_rollout     = 10

- 액터 크리틱 클래스

In [4]:
class ActorCritic(nn.Module):
    """Actor-critic network with a shared hidden layer and two heads.

    ``pi`` yields action probabilities over 2 actions and ``v`` yields a
    scalar state value, both from a 4-dimensional observation. Transitions
    are buffered with ``put_data`` and consumed by ``train_net``.

    Reads the module-level ``learning_rate`` (at construction time) and
    ``gamma`` (at training time).
    """

    def __init__(self):
        super().__init__()
        # Buffer of (s, a, r, s_prime, done) transitions awaiting training.
        self.data = []

        self.fc1 = nn.Linear(4, 256)
        self.fc_pi = nn.Linear(256, 2)
        self.fc_v = nn.Linear(256, 1)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def pi(self, x, softmax_dim=0):
        """Return action probabilities, normalised along ``softmax_dim``.

        Use the default ``0`` for a single observation and ``1`` for a batch
        of shape ``(N, 4)``.
        """
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=softmax_dim)
        return prob

    def v(self, x):
        """Return the estimated state value for observation(s) ``x``."""
        x = F.relu(self.fc1(x))
        v = self.fc_v(x)
        return v

    def put_data(self, transition):
        """Buffer one ``(s, a, r, s_prime, done)`` transition."""
        self.data.append(transition)

    def _stack_states(self, states):
        """Stack a list of observations into one ``(N, obs_dim)`` float tensor."""
        if not states:
            return torch.empty((0, self.fc1.in_features), dtype=torch.float)
        # Convert each observation on its own and stack: building a tensor
        # directly from a list of numpy arrays is the slow path that PyTorch
        # warns about.
        return torch.stack([torch.as_tensor(s, dtype=torch.float) for s in states])

    def make_batch(self):
        """Convert the buffered transitions to tensors and clear the buffer.

        Returns:
            ``(s, a, r, s_prime, done_mask)`` where ``s``/``s_prime`` have
            shape ``(N, 4)`` and the others ``(N, 1)``. Rewards are scaled
            by 1/100 and ``done_mask`` is 0.0 for terminal transitions and
            1.0 otherwise.
        """
        s_lst, a_lst, r_lst, s_prime_lst, done_lst = [], [], [], [], []
        for s, a, r, s_prime, done in self.data:
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r / 100.0])
            s_prime_lst.append(s_prime)
            done_lst.append([0.0 if done else 1.0])

        s_batch = self._stack_states(s_lst)
        s_prime_batch = self._stack_states(s_prime_lst)
        # reshape keeps the (N, 1) layout even when the buffer is empty.
        a_batch = torch.tensor(a_lst, dtype=torch.long).reshape(-1, 1)
        r_batch = torch.tensor(r_lst, dtype=torch.float).reshape(-1, 1)
        done_batch = torch.tensor(done_lst, dtype=torch.float).reshape(-1, 1)

        self.data = []
        return s_batch, a_batch, r_batch, s_prime_batch, done_batch

    def train_net(self):
        """Run one TD actor-critic update on the buffered transitions.

        Does nothing when the buffer is empty.
        """
        if not self.data:
            return

        s, a, r, s_prime, done = self.make_batch()
        # Evaluate V(s) once and reuse it for both the advantage and the
        # critic loss.
        v_s = self.v(s)
        # The done mask zeroes the bootstrap term on terminal transitions.
        td_target = r + gamma * self.v(s_prime) * done
        delta = td_target - v_s

        pi = self.pi(s, softmax_dim=1)
        pi_a = pi.gather(1, a)
        # Actor term uses the detached TD error as the advantage; the critic
        # term regresses V(s) towards the detached TD target.
        loss = -torch.log(pi_a) * delta.detach() + F.smooth_l1_loss(v_s, td_target.detach())

        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()

- 메인 함수

In [7]:
def main():
    """Train an actor-critic agent on CartPole-v1 and checkpoint the best one.

    Runs 10000 episodes, updating the network after at most ``n_rollout``
    environment steps (or at episode end). After every ``print_interval``
    completed episodes the average episode score of that window is printed,
    and the weights are saved to ``actor-critic.pth`` whenever the window's
    total score exceeds the best total seen so far.
    """
    env = gym.make('CartPole-v1')
    model = ActorCritic()
    print_interval = 20
    score = 0.0

    max_score = 0.0

    try:
        for n_epi in range(10000):
            done = False
            s = env.reset()
            while not done:
                # Collect a short rollout, then train on it.
                for t in range(n_rollout):
                    prob = model.pi(torch.from_numpy(s).float())
                    m = Categorical(prob)
                    a = m.sample().item()
                    s_prime, r, done, info = env.step(a)
                    model.put_data((s, a, r, s_prime, done))

                    s = s_prime
                    score += r

                    if done:
                        break

                model.train_net()

            # Count completed episodes so every window holds exactly
            # print_interval episodes; testing n_epi itself made the first
            # window 21 episodes long while still dividing by 20.
            episodes_done = n_epi + 1
            if episodes_done % print_interval == 0:
                print("# of episode :{}, avg score : {:.1f}".format(episodes_done, score/print_interval))

                if score > max_score:
                    print(f'>>>> save actor-critic.pth: {score:.1f}')
                    torch.save(model.state_dict(), 'actor-critic.pth')
                    max_score = score

                score = 0.0
    finally:
        # Release the environment even if training is interrupted.
        env.close()

- 학습

In [8]:
# Start actor-critic training; progress is printed by main() as it runs.
main() 

  deprecation(
  deprecation(


# of episode :20, avg score : 24.0
>>>> save actor-critic.pth: 480.0
# of episode :40, avg score : 19.4
# of episode :60, avg score : 24.1
>>>> save actor-critic.pth: 483.0
# of episode :80, avg score : 32.2
>>>> save actor-critic.pth: 644.0
# of episode :100, avg score : 36.8
>>>> save actor-critic.pth: 736.0
# of episode :120, avg score : 35.0
# of episode :140, avg score : 35.7
# of episode :160, avg score : 53.5
>>>> save actor-critic.pth: 1070.0
# of episode :180, avg score : 56.3
>>>> save actor-critic.pth: 1126.0
# of episode :200, avg score : 73.7
>>>> save actor-critic.pth: 1474.0
# of episode :220, avg score : 76.2
>>>> save actor-critic.pth: 1524.0
# of episode :240, avg score : 75.4
# of episode :260, avg score : 89.0
>>>> save actor-critic.pth: 1780.0
# of episode :280, avg score : 95.5
>>>> save actor-critic.pth: 1909.0
# of episode :300, avg score : 127.8
>>>> save actor-critic.pth: 2556.0
# of episode :320, avg score : 165.2
>>>> save actor-critic.pth: 3305.0
# of episo

- 결과확인 (준비)

In [9]:
# Create a fresh environment and network, then restore the weights from the
# 'actor-critic.pth' checkpoint file.
env = gym.make('CartPole-v1')
model = ActorCritic()
model.load_state_dict(torch.load('actor-critic.pth'))

  deprecation(
  deprecation(


<All keys matched successfully>

- 결과확인 (env record)

In [10]:
# Wrap the environment so rollouts are recorded as video under ./video_actor_critic.
env = gym.wrappers.RecordVideo(env, './video_actor_critic')

- 실행

In [11]:
# Roll out a single episode with the loaded actor-critic model, printing the
# sampled action and reward of every step.
s, done = env.reset(), False

# Inference only: no_grad avoids building an autograd graph for each step.
with torch.no_grad():
    while not done:
        prob = model.pi(torch.from_numpy(s).float())
        m = Categorical(prob)
        action = m.sample()
        s_prime, r, done, info = env.step(action.item())
        s = s_prime
        print(action.item(), r)

# Close the environment once the episode is over so its resources are released.
env.close()

  logger.deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/
  deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/
  deprecation(


1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
0 1.0
0 1.0
1 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
1 1.0
0 1.0
0 1.0
1 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
0 1.0
1 1.0
1 1.

In [12]:
# play recorded video
def show_video(video_dir='video_actor_critic'):
    """Embed a recorded MP4 from ``video_dir`` in the notebook output.

    Args:
        video_dir: Directory searched (non-recursively) for ``*.mp4`` files.
            Defaults to the actor-critic recording directory.

    The alphabetically first MP4 is base64-encoded into an inline HTML
    ``<video>`` element. Prints "Could not find video" when the directory
    contains no MP4 file.
    """
    # Sort so the chosen file does not depend on filesystem listing order.
    mp4list = sorted(glob.glob(f'{glob.escape(video_dir)}/*.mp4'))
    if not mp4list:
        print("Could not find video")
        return

    # Read-only binary mode; the context manager closes the handle.
    with open(mp4list[0], 'rb') as f:
        encoded = base64.b64encode(f.read())
    ipythondisplay.display(HTML(data='''
            <video alt="test" autoplay loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
            </video>'''.format(encoded.decode('ascii'))))

In [13]:
# Display the recorded actor-critic rollout inline.
show_video()