In [1]:
import numpy as np

from catch_ball import CatchBall
from dqn_agent import DQNAgent


if __name__ == "__main__":
    # Training entry point: runs n_epochs episodes of CatchBall with a
    # DQNAgent, performing experience replay after every environment step,
    # prints one summary line per epoch, and finally saves the model.

    # parameters
    n_epochs = 1000

    # environment, agent
    env = CatchBall()
    agent = DQNAgent(env.enable_actions, env.name)

    # variables
    # Cumulative across all epochs (never reset): number of steps whose
    # observed reward was exactly 1.
    win = 0

    for e in range(n_epochs):
        # reset the per-epoch accumulators (used only for the log line)
        frame = 0
        loss = 0.0
        Q_max = 0.0
        env.reset()
        state_t_1, reward_t, terminal = env.observe()

        while not terminal:
            # the previously observed state becomes the current state
            state_t = state_t_1

            # execute action in environment
            # NOTE(review): agent.exploration is presumably the exploration
            # rate used by select_action -- confirm in dqn_agent.
            action_t = agent.select_action(state_t, agent.exploration)
            env.execute_action(action_t)

            # observe environment
            state_t_1, reward_t, terminal = env.observe()

            # store experience
            agent.store_experience(state_t, action_t, reward_t, state_t_1, terminal)

            # experience replay
            agent.experience_replay()

            # for log
            frame += 1
            loss += agent.current_loss
            Q_max += np.max(agent.Q_values(state_t))
            if reward_t == 1:
                win += 1

        # Average the accumulated values over the steps taken this epoch.
        # If the environment was already terminal right after reset, no step
        # ran and frame is 0; clamp the divisor so the log line reports 0.0
        # instead of raising ZeroDivisionError.
        n_frames = max(frame, 1)
        print("EPOCH: {:03d}/{:03d} | WIN: {:03d} | LOSS: {:.4f} | Q_MAX: {:.4f}".format(
            e, n_epochs - 1, win, loss / n_frames, Q_max / n_frames))

    # save model
    agent.save_model()

EPOCH: 000/999 | WIN: 000 | LOSS: 0.0068 | Q_MAX: 0.0005
EPOCH: 001/999 | WIN: 000 | LOSS: 0.0477 | Q_MAX: 0.0002
EPOCH: 002/999 | WIN: 001 | LOSS: 0.0505 | Q_MAX: -0.0004
EPOCH: 003/999 | WIN: 001 | LOSS: 0.0419 | Q_MAX: -0.0001
EPOCH: 004/999 | WIN: 002 | LOSS: 0.0475 | Q_MAX: -0.0006
EPOCH: 005/999 | WIN: 002 | LOSS: 0.0490 | Q_MAX: -0.0004
EPOCH: 006/999 | WIN: 003 | LOSS: 0.0296 | Q_MAX: -0.0011
EPOCH: 007/999 | WIN: 004 | LOSS: 0.0475 | Q_MAX: -0.0005
EPOCH: 008/999 | WIN: 004 | LOSS: 0.0517 | Q_MAX: 0.0032
EPOCH: 009/999 | WIN: 005 | LOSS: 0.0430 | Q_MAX: 0.0072
EPOCH: 010/999 | WIN: 006 | LOSS: 0.0516 | Q_MAX: 0.0109
EPOCH: 011/999 | WIN: 006 | LOSS: 0.0528 | Q_MAX: 0.0154
EPOCH: 012/999 | WIN: 006 | LOSS: 0.0372 | Q_MAX: 0.0167
EPOCH: 013/999 | WIN: 006 | LOSS: 0.0533 | Q_MAX: 0.0178
EPOCH: 014/999 | WIN: 006 | LOSS: 0.0403 | Q_MAX: 0.0183
EPOCH: 015/999 | WIN: 006 | LOSS: 0.0357 | Q_MAX: 0.0145
EPOCH: 016/999 | WIN: 007 | LOSS: 0.0382 | Q_MAX: 0.0147
EPOCH: 017/999 | WIN: 008