-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.lua
70 lines (62 loc) · 1.83 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
--
-- Copyright (c) 2016, Horizon Robotics, Inc.
-- All rights reserved.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Author: Yao Zhou, yao.zhou@hobot.cc
--
-- Load the project's init file only when the global `deeprl` is not already
-- set (e.g. when this script is run directly rather than from a session that
-- has loaded it).
if not deeprl then
    require 'init'
end

-- Command-line interface. Every option below becomes a field of `opt`,
-- keyed by the option name without the leading dash.
local cmd = torch.CmdLine()
cmd:text('params setting')
cmd:option('-seed', os.time(), 'initial random seed')
cmd:option('-win_height', 18, 'environment window height')
cmd:option('-win_width', 16, 'environment window width')
cmd:option('-max_men', 1000, 'max memory size')
cmd:option('-bsize', 16, 'training batch size')
cmd:option('-n_actions', 3, 'number of actions')
cmd:option('-discount', 0.9, 'discount factor gamma ')
cmd:option('-hid_dim', 128, 'dimension of hidden states')
cmd:option('-epoch', 1000, 'training epoch')
cmd:option('-epsilon', 1, 'epsilon, random sampling rate')
-- NOTE(review): torch.CmdLine appears to treat boolean options as flags that
-- flip the default, so passing `-duel` presumably turns dueling OFF — confirm.
cmd:option('-duel', true, 'using dueling network')
cmd:option('-task', 'pong', 'game option')
cmd:text()

-- parse arguments
local opt = cmd:parse(arg or {})

-- Seed both generators: `math.randomseed` only affects Lua's `math.random`,
-- while Torch keeps a separate generator. Without the second call, a run
-- started with an explicit `-seed` would still draw different values from
-- Torch's generator.
math.randomseed(opt.seed)
torch.manualSeed(opt.seed)
-- Environment settings: the window dimensions taken from the command line.
local env_config = {
    win_height = opt.win_height,
    win_width  = opt.win_width,
}

-- The agent's state size is the window area (height * width).
local n_states = opt.win_height * opt.win_width

-- Optimizer settings; the learning rate is fixed here and not exposed as a
-- command-line option.
local optim_config = { learningRate = 0.1 }

-- Agent settings, forwarded from the parsed options plus the derived values
-- above.
local agent_config = {
    max_men      = opt.max_men,
    bsize        = opt.bsize,
    n_actions    = opt.n_actions,
    n_states     = n_states,
    discount     = opt.discount,
    hid_dim      = opt.hid_dim,
    duel         = opt.duel,
    optim_config = optim_config,
}

-- Top-level configuration handed to the learner: it nests the environment
-- and agent tables and adds the task name, epsilon and epoch options.
local learner_config = {
    env_config   = env_config,
    agent_config = agent_config,
    task         = opt.task,
    epsilon      = opt.epsilon,
    epoch        = opt.epoch,
}
-- Build the learner from the assembled configuration.
local learner = deeprl.learner(learner_config)

-- The 'car' task uses its own run/test entry points; every other task value
-- goes through the generic ones.
-- NOTE(review): the numeric arguments (1000 / 100) are presumably the number
-- of test episodes — confirm against the learner implementation.
local is_car_task = (opt.task == 'car')
if not is_car_task then
    learner:run()
    learner:test(1000)
else
    learner:run_car()
    learner:test_car(100)
end