Skip to content

Commit

Permalink
[a3c] specify dir to save train logs
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Apr 17, 2020
1 parent 963e510 commit 21a6984
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions examples/A3C-Gym/train-atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,6 @@ def get_training_dataflow(self):


def train():
assert tf.test.is_gpu_available(), "Training requires GPUs!"
dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
logger.set_logger_dir(dirname)

# assign GPUs for training & inference
num_gpu = get_num_gpu()
global PREDICTOR_THREAD
Expand Down Expand Up @@ -275,9 +271,11 @@ def train():
parser.add_argument('--env', help='env', required=True)
parser.add_argument('--task', help='task to perform',
choices=['play', 'eval', 'train', 'dump_video'], default='train')
parser.add_argument('--output', help='output directory for submission', default='output_dir')
parser.add_argument('--output', help='output directory for logs and videos')
parser.add_argument('--episode', help='number of episode to eval', default=100, type=int)
args = parser.parse_args()
if args.output is None:
args.output = os.path.join('train_log', 'train-atari-{}'.format(args.env))

ENV_NAME = args.env
NUM_ACTIONS = get_player().action_space.n
Expand All @@ -303,4 +301,6 @@ def train():
get_player(train=False, dumpdir=args.output),
pred, args.episode)
else:
assert tf.test.is_gpu_available(), "Training requires GPUs!"
logger.set_logger_dir(args.output)
train()

0 comments on commit 21a6984

Please sign in to comment.