<a href="https://colab.research.google.com/github/wjdgoruds2/Machine_Learning/blob/main/CartPole.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!apt-get install -y python-numpy python-dev cmake zlib1g-dev libjpeg-dev xvfb xorg-dev python-opengl libboost-all-dev libsdl2-dev swig
!pip install pyvirtualdisplay
!pip install piglet
from pyvirtualdisplay import Display
display=Display(visible=0,size=(1400,900))
display.start()
!pip install gym
!pip install gym[atari]
from base64 import b64encode
from glob import glob
from IPython.display import HTML
from IPython import display as ipy_display
from gym import logger as gym_logger
from gym.wrappers import Monitor
gym_logger.set_level(40)

def show_video():
  mp4_list=glob('video/*.mp4')
  if mp4_list:
    mp4=mp4_list[0]
    video=open(mp4,'r+b').read()
    encoded=b64encode(video)
    ipy_display.display(HTML(data='''
            <video alt="gameplay" autoplay controls style="height: 400px;">
                <source src="data:video/mp4;base64,%s" type="video/mp4" />
            </video>
        ''' % (encoded.decode('ascii'))))
  else:
    print('No video found')

def wrap_env(env):
  env=Monitor(env,'./video',force=True)
  return env
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

#Hyperparameters
learning_rate = 0.0002
gamma         = 0.98

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.data = []
        
        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128, 2)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=0)
        return x
      
    def put_data(self, item):
        self.data.append(item)
        
    def train_net(self):
        R = 0
        self.optimizer.zero_grad()
        for r, prob in self.data[::-1]:
            R = r + gamma * R
            loss = -torch.log(prob) * R
            loss.backward()
        self.optimizer.step()
        self.data = []

def main():
    env = gym.make('CartPole-v1')
    pi = Policy()
    score = 0.0
    print_interval = 100
    
    
    for n_epi in range(3000):
        s = env.reset()
        done = False
        
        while not done: # CartPole-v1 forced to terminates at 500 step.
            prob = pi(torch.from_numpy(s).float())
            m = Categorical(prob)
            a = m.sample()
            s_prime, r, done, info = env.step(a.item())
            pi.put_data((r,prob[a]))
            s = s_prime
            score += r
            
        pi.train_net()
        
        if n_epi%print_interval==0 and n_epi!=0:
            print("# of episode :{}, avg score : {}".format(n_epi, score/print_interval))
            score = 0.0
    env=wrap_env(gym.make('CartPole-v1'))
    print(env.action_space)
    s=env.reset()
    print(s)
    for i in range(1000):
      env.render()
      prob=pi(torch.from_numpy(s).float())
      m=Categorical(prob)
      a=m.sample()
      s_prime,r,done,info=env.step(a.item())
      s=s_prime
      if done:
        break
    env.close()
    show_video()
    
if __name__ == '__main__':
    main()

Reading package lists... Done
Building dependency tree       
Reading state information... Done
libjpeg-dev is already the newest version (8c-2ubuntu8).
python-dev is already the newest version (2.7.15~rc1-1).
python-numpy is already the newest version (1:1.13.3-2ubuntu1).
zlib1g-dev is already the newest version (1:1.2.11.dfsg-0ubuntu2).
libboost-all-dev is already the newest version (1.65.1.0ubuntu1).
python-opengl is already the newest version (3.1.0+dfsg-1).
swig is already the newest version (3.0.12-1).
cmake is already the newest version (3.10.2-1ubuntu2.18.04.1).
xorg-dev is already the newest version (1:7.7+19ubuntu7.1).
libsdl2-dev is already the newest version (2.0.8+dfsg1-1ubuntu1.18.04.4).
xvfb is already the newest version (2:1.19.6-1ubuntu4.8).
0 upgraded, 0 newly installed, 0 to remove and 31 not upgraded.
# of episode :100, avg score : 24.94
# of episode :200, avg score : 25.91
# of episode :300, avg score : 32.03
# of episode :400, avg score : 39.42
# of episode :500, 