In [None]:
import gym
import ptan
import argparse

import torch
import torch.optim as optim

from tensorboardX import SummaryWriter

from lib import dqn_model, common

REWARD_STEPS_DEFAULT = 2


if __name__ == "__main__":
    # Train a DQN with n-step Bellman unrolling using the 'pong' hyperparameters.
    params = common.HYPERPARAMS['pong']

    # parse_known_args() rather than parse_args(): inside Jupyter the kernel
    # injects its own argv entries (e.g. "-f kernel.json"), which would make
    # parse_args() abort. Unknown arguments are ignored here.
    parser = argparse.ArgumentParser()
    parser.add_argument("--cpu", default=False, action="store_true",
                        help="Force CPU even if CUDA is available")
    parser.add_argument("-n", default=REWARD_STEPS_DEFAULT, type=int,
                        help="Count of steps to unroll Bellman")
    parser.add_argument("--no-render", default=False, action="store_true",
                        help="Do not render the environment every frame")
    args, _unknown = parser.parse_known_args()
    reward_steps = args.n

    # Previously hard-coded to "cuda", which crashes on CPU-only machines.
    use_cuda = torch.cuda.is_available() and not args.cpu
    device = torch.device("cuda" if use_cuda else "cpu")

    env = gym.make(params['env_name'])
    env = ptan.common.wrappers.wrap_dqn(env)

    writer = SummaryWriter(comment="-" + params['run_name'] + "-%d-step" % reward_steps)
    net = dqn_model.DQN(env.observation_space.shape, env.action_space.n).to(device)

    tgt_net = ptan.agent.TargetNet(net)
    selector = ptan.actions.EpsilonGreedyActionSelector(epsilon=params['epsilon_start'])
    epsilon_tracker = common.EpsilonTracker(selector, params)
    agent = ptan.agent.DQNAgent(net, selector, device=device)

    # steps_count=reward_steps: each stored transition spans n env steps, so the
    # Bellman target below must discount by gamma**n instead of gamma.
    exp_source = ptan.experience.ExperienceSourceFirstLast(
        env, agent, gamma=params['gamma'], steps_count=reward_steps)
    buffer = ptan.experience.ExperienceReplayBuffer(exp_source, buffer_size=params['replay_size'])
    optimizer = optim.Adam(net.parameters(), lr=params['learning_rate'])

    frame_idx = 0

    try:
        with common.RewardTracker(writer, params['stop_reward']) as reward_tracker:
            while True:
                if not args.no_render:
                    # Rendering every frame is expensive; disable with --no-render.
                    env.render()
                frame_idx += 1
                buffer.populate(1)
                epsilon_tracker.frame(frame_idx)

                new_rewards = exp_source.pop_total_rewards()
                if new_rewards:
                    # reward() returning truthy signals the stop reward was reached.
                    if reward_tracker.reward(new_rewards[0], frame_idx, selector.epsilon):
                        break

                # Warm-up: don't train until the replay buffer holds enough samples.
                if len(buffer) < params['replay_initial']:
                    continue

                optimizer.zero_grad()
                batch = buffer.sample(params['batch_size'])
                loss_v = common.calc_loss_dqn(batch, net, tgt_net.target_model,
                                              gamma=params['gamma'] ** reward_steps, device=device)
                loss_v.backward()
                optimizer.step()

                if frame_idx % params['target_net_sync'] == 0:
                    tgt_net.sync()
    finally:
        # NOTE(review): RewardTracker may already close the writer on exit;
        # tensorboardX's SummaryWriter.close() ignores a second call.
        writer.close()
        env.close()

912: done 1 games, mean reward -20.000, speed 89.47 f/s, eps 0.99
1668: done 2 games, mean reward -20.500, speed 125.69 f/s, eps 0.98
2667: done 3 games, mean reward -20.333, speed 122.63 f/s, eps 0.97
3472: done 4 games, mean reward -20.500, speed 125.53 f/s, eps 0.97
4270: done 5 games, mean reward -20.600, speed 125.90 f/s, eps 0.96
5121: done 6 games, mean reward -20.667, speed 126.20 f/s, eps 0.95
6163: done 7 games, mean reward -20.571, speed 126.47 f/s, eps 0.94
6921: done 8 games, mean reward -20.625, speed 123.67 f/s, eps 0.93
7737: done 9 games, mean reward -20.667, speed 124.91 f/s, eps 0.92
8497: done 10 games, mean reward -20.700, speed 124.83 f/s, eps 0.92
9284: done 11 games, mean reward -20.727, speed 118.34 f/s, eps 0.91
10103: done 12 games, mean reward -20.750, speed 93.11 f/s, eps 0.90
10942: done 13 games, mean reward -20.769, speed 33.79 f/s, eps 0.89
11806: done 14 games, mean reward -20.786, speed 33.77 f/s, eps 0.88
12678: done 15 games, mean reward -20.733, sp