In [None]:
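# Simple implementation of DQN (work in progress: the dqn() function below is still a stub, and the training code that follows uses PPO)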

In [1]:
import numpy as np
import gym
import torch
import os

In [2]:
# Build the Gym environment that every later cell shares.
env_name = 'Acrobot-v1'
env = gym.make(env_name)

# Grab the initial observation and draw the first frame.
obs = env.reset()
env.render()

# Keep both spaces under module-level names and echo them for reference.
observation_space, action_space = env.observation_space, env.action_space
for label, space in (('obs:', observation_space), ('act:', action_space)):
    print(label, space)

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
obs: Box(6,)
act: Discrete(3)


In [24]:
from simple_net import PolNet

In [28]:
# Instantiate the local PolNet on the environment's observation/action spaces
# and bind it as `qfunc` (presumably the network meant to act as the
# Q-function in the DQN sketch).
# NOTE(review): the policy output shown further down reports 'pi' with a
# Softmax grad_fn, i.e. action probabilities rather than raw Q-values --
# confirm this network is what a Q-function should be built from.
qfunc = PolNet(observation_space, action_space)

In [30]:
from machina.pols import CategoricalPol
# Wrap the network in machina's CategoricalPol over the environment's spaces.
# This `pol` is only used for the quick check in the next cell; the name is
# rebound to a new policy in the PPO setup cell further down.
pol = CategoricalPol(observation_space, action_space, qfunc)

In [41]:
# Smoke test: query the policy on a batch built from three copies of the
# initial observation. The call yields a 3-tuple (unpacked as
# ac_real, ac, a_i in the rollout cells of this notebook).
pol.deterministic_ac_real(torch.tensor([obs,obs,obs],dtype=torch.float))

(array([2, 2, 2]), tensor([2, 2, 2]), {'pi': tensor([[0.3319, 0.3279, 0.3402],
          [0.3319, 0.3279, 0.3402],
          [0.3319, 0.3279, 0.3402]], grad_fn=<SoftmaxBackward>), 'hs': None})

In [None]:
# Written with reference to
# https://github.com/DeepX-inc/machina/blob/master/machina/algos/ppo_clip.py
def dqn(off_traj, pol, num_iteration,b_size):
    """Placeholder for a DQN training routine; not implemented yet.

    The body is a bare ``pass``, so calling this does nothing and returns
    ``None``.

    Args:
        off_traj: presumably the off-policy trajectory data to learn from
            -- TODO confirm once implemented.
        pol: presumably the policy / Q-function wrapper to train
            -- TODO confirm once implemented.
        num_iteration: presumably the number of update iterations
            -- TODO confirm once implemented.
        b_size: presumably the mini-batch size
            -- TODO confirm once implemented.
    """
    pass

## source code below

In [4]:
# Local network definitions plus machina's state-value-function wrapper.
from simple_net import PolNet,VNet
from machina.vfuncs import DeterministicSVfunc

# --- Hyperparameters -------------------------------------------------------
# Learning rates for the two Adam optimizers.
pol_lr = 1e-4
vf_lr = 3e-4

# PPO arguments. gamma and lam feed the return/advantage computation in the
# training loop; clip_param, epoch_per_iter, batch_size and max_grad_norm are
# forwarded to ppo_clip.train. kl_beta is defined here but is not passed to
# ppo_clip.train in the training cell of this notebook.
kl_beta = 1
gamma = 0.995
lam = 1
clip_param = 0.2
epoch_per_iter = 50
batch_size = 64
max_grad_norm = 10

# --- Models ----------------------------------------------------------------
# Policy: PolNet wrapped in a CategoricalPol (rebinds `pol` from the earlier cell).
pol_net = PolNet(observation_space, action_space)
pol = CategoricalPol(observation_space, action_space, pol_net)

# Value function: VNet wrapped in a DeterministicSVfunc.
vf_net = VNet(observation_space)
vf = DeterministicSVfunc(observation_space, vf_net)

# --- Optimizers ------------------------------------------------------------
# One Adam optimizer per network, each with its own learning rate.
optim_pol = torch.optim.Adam(pol_net.parameters(), lr=pol_lr)
optim_vf = torch.optim.Adam(vf_net.parameters(), lr=vf_lr)

In [5]:
# Register the environment and the policy with machina's episode sampler.
# num_parallel=2 and seed=42 are passed straight to EpiSampler.
# NOTE(review): the training log reports EpisodePerIter = 2, which suggests
# one episode per parallel worker per sampling call -- confirm against EpiSampler.
from machina.samplers import EpiSampler
sampler = EpiSampler(env, pol, num_parallel=2, seed=42)

## 2. Visualize behavior before training
You can check the initial policy's behavior.

In [6]:
import time

# Roll out the untrained policy and render it frame by frame.
o, done = env.reset(), False
for _ in range(150):  # 150 frames, about 10 seconds at 15 fps
    if done:
        # Episode boundary: pause briefly, then start a fresh episode.
        time.sleep(1)
        o = env.reset()
    # Query the policy's deterministic action for the current observation.
    obs_tensor = torch.tensor(o, dtype=torch.float)
    ac_real, ac, a_i = pol.deterministic_ac_real(obs_tensor)
    # Reshape to the action space's shape before handing it to the env.
    action = np.array(ac_real.reshape(pol.action_space.shape))
    o, r, done, e_i = env.step(action)
    time.sleep(1/15)  # throttle to 15 fps
    env.render()

## 3. Training
You can edit the training settings and train your policy. It takes several minutes.

In [7]:
# Training utilities from machina.
from machina.traj import epi_functional as ef
from machina import logger
from machina.utils import measure
from machina.traj import Traj
from machina.algos import ppo_clip

# machina automatically writes its logs (models, scores, ...) under this directory.
log_dir_name = 'garbage'
# Create the log directory together with its models/ sub-directory.
# makedirs(..., exist_ok=True) also covers the case where the log directory
# already exists but models/ does not; the previous "mkdir only if the log
# directory is missing" check left models/ absent in that case, so every
# torch.save in the training loop would fail.
os.makedirs(os.path.join(log_dir_name, 'models'), exist_ok=True)
score_file = os.path.join(log_dir_name, 'progress.csv')
logger.add_tabular_output(score_file)

# Counters and best-score record for the training loop.
total_epi = 0
total_step = 0
# Start below any reachable score so the first iteration always writes a
# "*_max" checkpoint. The training loop compares with a strict `>`, and the
# logged runs bottom out at exactly -500, so an initial value of -500 would
# never save a best model for a policy that stays at that score.
max_rew = float('-inf')

# How long to train: stop once this many episodes have been collected.
max_episodes = 100

# Step budget passed as max_steps to sampler.sample on every iteration.
# NOTE(review): the training log shows StepPerIter up to 1000, so this looks
# like a minimum sampling budget per iteration rather than a per-episode
# cap -- confirm against EpiSampler.sample.
max_steps_per_iter = 150

In [8]:
# Training loop: sample episodes, run a PPO update, log, and checkpoint.

# Objects to checkpoint, paired with their file names for the best-so-far
# snapshot and for the most recent snapshot.
best_ckpts = (
    (pol, 'pol_max.pkl'),
    (vf, 'vf_max.pkl'),
    (optim_pol, 'optim_pol_max.pkl'),
    (optim_vf, 'optim_vf_max.pkl'),
)
last_ckpts = (
    (pol, 'pol_last.pkl'),
    (vf, 'vf_last.pkl'),
    (optim_pol, 'optim_pol_last.pkl'),
    (optim_vf, 'optim_vf_last.pkl'),
)

while total_epi < max_episodes:
    # Collect fresh episodes with the current policy.
    with measure('sample'):
        epis = sampler.sample(pol, max_steps=max_steps_per_iter)

    # Turn the episodes into a trajectory and run the PPO update on it.
    with measure('train'):
        traj = Traj()
        traj.add_epis(epis)

        # Values, returns and (centered) advantages, then hidden-state masks.
        traj = ef.compute_vs(traj, vf)
        traj = ef.compute_rets(traj, gamma)
        traj = ef.compute_advs(traj, gamma, lam)
        traj = ef.centerize_advs(traj)
        traj = ef.compute_h_masks(traj)
        traj.register_epis()

        result_dict = ppo_clip.train(
            traj=traj, pol=pol, vf=vf, clip_param=clip_param,
            optim_pol=optim_pol, optim_vf=optim_vf,
            epoch=epoch_per_iter, batch_size=batch_size,
            max_grad_norm=max_grad_norm)

    # Update the running counters and compute this iteration's scores.
    total_epi += traj.num_epi
    step = traj.num_step
    total_step += step
    rewards = [np.sum(epi['rews']) for epi in epis]
    mean_rew = np.mean(rewards)
    logger.record_results(log_dir_name, result_dict, score_file,
                          total_epi, step, total_step,
                          rewards,
                          plot_title=env_name)

    # New best mean reward: snapshot models and optimizers as "*_max".
    if mean_rew > max_rew:
        for obj, fname in best_ckpts:
            torch.save(obj.state_dict(),
                       os.path.join(log_dir_name, 'models', fname))
        max_rew = mean_rew

    # Always refresh the "*_last" snapshot.
    for obj, fname in last_ckpts:
        torch.save(obj.state_dict(),
                   os.path.join(log_dir_name, 'models', fname))
    del traj
del sampler

2019-02-13 11:55:59.556658 JST | sample: 0.4890sec
2019-02-13 11:55:59.589596 JST | Optimizing...
2019-02-13 11:56:02.653787 JST | Optimization finished!
2019-02-13 11:56:02.654872 JST | train: 3.0972sec
2019-02-13 11:56:02.656035 JST | outdir /Users/kyoshiro/Downloads/machina/example/quickstart/garbage
2019-02-13 11:56:02.670284 JST | --------------  -------------
2019-02-13 11:56:02.671367 JST | PolLossAverage      0.0722895
2019-02-13 11:56:02.672435 JST | PolLossStd          0.861815
2019-02-13 11:56:02.673484 JST | PolLossMedian       0.393907
2019-02-13 11:56:02.674440 JST | PolLossMin         -1.5943
2019-02-13 11:56:02.675541 JST | PolLossMax          1.06612
2019-02-13 11:56:02.676926 JST | VfLossAverage    3218.36
2019-02-13 11:56:02.677964 JST | VfLossStd        3598.97
2019-02-13 11:56:02.678887 JST | VfLossMedian     1819.75
2019-02-13 11:56:02.679765 JST | VfLossMin          83.0641
2019-02-13 11:56:02.681362 JST | VfLossMax       16472.4
2019-02-13 11:56:02.682710 JST | 

2019-02-13 11:56:21.599135 JST | VfLossAverage    740.062
2019-02-13 11:56:21.600250 JST | VfLossStd       1656.02
2019-02-13 11:56:21.601337 JST | VfLossMedian      62.8886
2019-02-13 11:56:21.602599 JST | VfLossMin          0.767179
2019-02-13 11:56:21.603434 JST | VfLossMax       6542.1
2019-02-13 11:56:21.604175 JST | RewardAverage   -500
2019-02-13 11:56:21.604940 JST | RewardStd          0
2019-02-13 11:56:21.606093 JST | RewardMedian    -500
2019-02-13 11:56:21.607233 JST | RewardMin       -500
2019-02-13 11:56:21.608366 JST | RewardMax       -500
2019-02-13 11:56:21.609298 JST | EpisodePerIter     2
2019-02-13 11:56:21.610181 JST | TotalEpisode      12
2019-02-13 11:56:21.611100 JST | StepPerIter     1000
2019-02-13 11:56:21.611934 JST | TotalStep       5964
2019-02-13 11:56:21.613288 JST | --------------  ------------
2019-02-13 11:56:22.102690 JST | sample: 0.4727sec
2019-02-13 11:56:22.118885 JST | Optimizing...
2019-02-13 11:56:25.401233 JST | Optimization finished!
2019-02

2019-02-13 11:56:40.403626 JST | StepPerIter       976
2019-02-13 11:56:40.404439 JST | TotalStep       10923
2019-02-13 11:56:40.405391 JST | --------------  -------------
2019-02-13 11:56:40.931771 JST | sample: 0.5110sec
2019-02-13 11:56:40.949946 JST | Optimizing...
2019-02-13 11:56:44.102027 JST | Optimization finished!
2019-02-13 11:56:44.105742 JST | train: 3.1726sec
2019-02-13 11:56:44.107109 JST | outdir /Users/kyoshiro/Downloads/machina/example/quickstart/garbage
2019-02-13 11:56:44.112833 JST | --------------  -------------
2019-02-13 11:56:44.113724 JST | PolLossAverage      0.0854223
2019-02-13 11:56:44.115353 JST | PolLossStd          0.774902
2019-02-13 11:56:44.117130 JST | PolLossMedian       0.316919
2019-02-13 11:56:44.118391 JST | PolLossMin         -1.5071
2019-02-13 11:56:44.119681 JST | PolLossMax          1.16125
2019-02-13 11:56:44.120518 JST | VfLossAverage     567.286
2019-02-13 11:56:44.121361 JST | VfLossStd         868.347
2019-02-13 11:56:44.122190 JST | 

2019-02-13 11:57:02.845023 JST | PolLossAverage      0.121287
2019-02-13 11:57:02.846230 JST | PolLossStd          0.635785
2019-02-13 11:57:02.847420 JST | PolLossMedian       0.40246
2019-02-13 11:57:02.849304 JST | PolLossMin         -1.47822
2019-02-13 11:57:02.851299 JST | PolLossMax          0.985034
2019-02-13 11:57:02.852852 JST | VfLossAverage     127.705
2019-02-13 11:57:02.854872 JST | VfLossStd         187.869
2019-02-13 11:57:02.856728 JST | VfLossMedian       58.6139
2019-02-13 11:57:02.857950 JST | VfLossMin           1.20563
2019-02-13 11:57:02.859110 JST | VfLossMax         917.574
2019-02-13 11:57:02.860603 JST | RewardAverage    -500
2019-02-13 11:57:02.861674 JST | RewardStd           0
2019-02-13 11:57:02.862905 JST | RewardMedian     -500
2019-02-13 11:57:02.865530 JST | RewardMin        -500
2019-02-13 11:57:02.872399 JST | RewardMax        -500
2019-02-13 11:57:02.874027 JST | EpisodePerIter      2
2019-02-13 11:57:02.875369 JST | TotalEpisode       34
2019-02-1

2019-02-13 11:57:20.549802 JST | VfLossMax        1734.49
2019-02-13 11:57:20.550795 JST | RewardAverage    -362.5
2019-02-13 11:57:20.551889 JST | RewardStd          46.5
2019-02-13 11:57:20.552912 JST | RewardMedian     -362.5
2019-02-13 11:57:20.553942 JST | RewardMin        -409
2019-02-13 11:57:20.555123 JST | RewardMax        -316
2019-02-13 11:57:20.556427 JST | EpisodePerIter      2
2019-02-13 11:57:20.557368 JST | TotalEpisode       44
2019-02-13 11:57:20.558519 JST | StepPerIter       727
2019-02-13 11:57:20.559545 JST | TotalStep       21298
2019-02-13 11:57:20.560622 JST | --------------  ------------
2019-02-13 11:57:21.164099 JST | sample: 0.5795sec
2019-02-13 11:57:21.187040 JST | Optimizing...
2019-02-13 11:57:24.885204 JST | Optimization finished!
2019-02-13 11:57:24.887849 JST | train: 3.7225sec
2019-02-13 11:57:24.889141 JST | outdir /Users/kyoshiro/Downloads/machina/example/quickstart/garbage
2019-02-13 11:57:24.894370 JST | --------------  -------------
2019-02-13 

2019-02-13 11:57:39.714670 JST | --------------  ------------
2019-02-13 11:57:40.244439 JST | sample: 0.5122sec
2019-02-13 11:57:40.262308 JST | Optimizing...
2019-02-13 11:57:42.980513 JST | Optimization finished!
2019-02-13 11:57:42.983508 JST | train: 2.7382sec
2019-02-13 11:57:42.984694 JST | outdir /Users/kyoshiro/Downloads/machina/example/quickstart/garbage
2019-02-13 11:57:42.989950 JST | --------------  -------------
2019-02-13 11:57:42.990881 JST | PolLossAverage     -0.0050852
2019-02-13 11:57:42.991728 JST | PolLossStd          0.800565
2019-02-13 11:57:42.992960 JST | PolLossMedian       0.0123125
2019-02-13 11:57:42.994292 JST | PolLossMin         -0.98274
2019-02-13 11:57:42.995713 JST | PolLossMax          1.33365
2019-02-13 11:57:42.997102 JST | VfLossAverage     350.885
2019-02-13 11:57:42.998627 JST | VfLossStd         298.817
2019-02-13 11:57:43.000016 JST | VfLossMedian      184.981
2019-02-13 11:57:43.002274 JST | VfLossMin           6.1471
2019-02-13 11:57:43.003

2019-02-13 11:57:59.116728 JST | PolLossMedian       0.0469332
2019-02-13 11:57:59.118734 JST | PolLossMin         -2.51908
2019-02-13 11:57:59.119695 JST | PolLossMax          1.51558
2019-02-13 11:57:59.121055 JST | VfLossAverage     596.644
2019-02-13 11:57:59.122021 JST | VfLossStd         940.645
2019-02-13 11:57:59.123158 JST | VfLossMedian      192.325
2019-02-13 11:57:59.124501 JST | VfLossMin           1.10091
2019-02-13 11:57:59.125975 JST | VfLossMax        3798.45
2019-02-13 11:57:59.127478 JST | RewardAverage    -448.5
2019-02-13 11:57:59.130764 JST | RewardStd          51.5
2019-02-13 11:57:59.132440 JST | RewardMedian     -448.5
2019-02-13 11:57:59.133888 JST | RewardMin        -500
2019-02-13 11:57:59.135610 JST | RewardMax        -397
2019-02-13 11:57:59.136846 JST | EpisodePerIter      2
2019-02-13 11:57:59.138278 JST | TotalEpisode       66
2019-02-13 11:57:59.139134 JST | StepPerIter       898
2019-02-13 11:57:59.140189 JST | TotalStep       30868
2019-02-13 11:57:5

2019-02-13 11:58:16.892909 JST | RewardStd           5
2019-02-13 11:58:16.893932 JST | RewardMedian     -229
2019-02-13 11:58:16.894939 JST | RewardMin        -234
2019-02-13 11:58:16.896209 JST | RewardMax        -224
2019-02-13 11:58:16.897243 JST | EpisodePerIter      2
2019-02-13 11:58:16.898759 JST | TotalEpisode       76
2019-02-13 11:58:16.900296 JST | StepPerIter       460
2019-02-13 11:58:16.902506 JST | TotalStep       35128
2019-02-13 11:58:16.904307 JST | --------------  -------------
2019-02-13 11:58:17.224061 JST | sample: 0.3007sec
2019-02-13 11:58:17.234498 JST | Optimizing...
2019-02-13 11:58:19.057275 JST | Optimization finished!
2019-02-13 11:58:19.059358 JST | train: 1.8340sec
2019-02-13 11:58:19.060638 JST | outdir /Users/kyoshiro/Downloads/machina/example/quickstart/garbage
2019-02-13 11:58:19.064725 JST | --------------  -------------
2019-02-13 11:58:19.066697 JST | PolLossAverage     -0.0339037
2019-02-13 11:58:19.067972 JST | PolLossStd          0.795466
2019

2019-02-13 11:58:27.365847 JST | Optimizing...
2019-02-13 11:58:28.844088 JST | Optimization finished!
2019-02-13 11:58:28.845867 JST | train: 1.4872sec
2019-02-13 11:58:28.847088 JST | outdir /Users/kyoshiro/Downloads/machina/example/quickstart/garbage
2019-02-13 11:58:28.850346 JST | --------------  -------------
2019-02-13 11:58:28.851364 JST | PolLossAverage      0.0355306
2019-02-13 11:58:28.852384 JST | PolLossStd          0.957871
2019-02-13 11:58:28.853539 JST | PolLossMedian       0.422882
2019-02-13 11:58:28.854671 JST | PolLossMin         -1.21926
2019-02-13 11:58:28.855720 JST | PolLossMax          1.26208
2019-02-13 11:58:28.856653 JST | VfLossAverage     609.075
2019-02-13 11:58:28.857885 JST | VfLossStd         616.759
2019-02-13 11:58:28.859035 JST | VfLossMedian      289.739
2019-02-13 11:58:28.860489 JST | VfLossMin          35.0093
2019-02-13 11:58:28.861829 JST | VfLossMax        2502.92
2019-02-13 11:58:28.862967 JST | RewardAverage    -250.5
2019-02-13 11:58:28.86

2019-02-13 11:58:39.368995 JST | PolLossMax          1.78736
2019-02-13 11:58:39.370712 JST | VfLossAverage     485.349
2019-02-13 11:58:39.371813 JST | VfLossStd         771.761
2019-02-13 11:58:39.372729 JST | VfLossMedian       87.0832
2019-02-13 11:58:39.373829 JST | VfLossMin           2.07871
2019-02-13 11:58:39.374720 JST | VfLossMax        5557.67
2019-02-13 11:58:39.375687 JST | RewardAverage    -345
2019-02-13 11:58:39.376916 JST | RewardStd          77
2019-02-13 11:58:39.377832 JST | RewardMedian     -345
2019-02-13 11:58:39.378732 JST | RewardMin        -422
2019-02-13 11:58:39.379562 JST | RewardMax        -268
2019-02-13 11:58:39.380458 JST | EpisodePerIter      2
2019-02-13 11:58:39.381656 JST | TotalEpisode       98
2019-02-13 11:58:39.382846 JST | StepPerIter       692
2019-02-13 11:58:39.384387 JST | TotalStep       40290
2019-02-13 11:58:39.385793 JST | --------------  -------------
2019-02-13 11:58:39.833790 JST | sample: 0.4348sec
2019-02-13 11:58:39.849904 JST | 

- You can check the training progress in **garbage/Reward.png**.

## 4. Visualize behavior after training
You can check the best policy's behavior.

In [9]:
# Load the best policy checkpoint written by the training loop.
# The path is built from log_dir_name, the same way the training loop builds
# it when saving, so the two cannot drift apart if the log directory is renamed.
best_path = os.path.join(log_dir_name, 'models', 'pol_max.pkl')
# NOTE(review): best_pol wraps the same pol_net instance as `pol`, so loading
# the checkpoint here presumably overwrites the weights `pol` uses as well --
# build a fresh PolNet if the last and best policies must stay separate.
best_pol = CategoricalPol(observation_space, action_space, pol_net)
best_pol.load_state_dict(torch.load(best_path))

In [10]:
# Show the trained policy's behavior using the best checkpoint loaded above.
# Actions come from `best_pol`: the previous version queried `pol`, which left
# `best_pol` unused and showed the best policy only because both wrappers
# happen to share the same underlying network.
done = False
o = env.reset()
for _ in range(300): # show 300 frames (= 20 sec)
    if done:
        time.sleep(1) # pause at the episode boundary, then start over
        o = env.reset()
    ac_real, ac, a_i = best_pol.deterministic_ac_real(torch.tensor(o, dtype=torch.float))
    # Reshape to the action space's shape before handing it to the env.
    ac_real = ac_real.reshape(best_pol.action_space.shape)
    next_o, r, done, e_i = env.step(np.array(ac_real))
    o = next_o
    time.sleep(1/15) # 15fps
    env.render()

In [10]:
# close your environment
# Close the Gym environment now that the rollouts are finished.
env.close()