In [None]:
!pip install tf-agents

Collecting tf-agents
  Downloading tf_agents-0.12.0-py3-none-any.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 5.2 MB/s 
Installing collected packages: tf-agents
Successfully installed tf-agents-0.12.0


In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import tensorflow as tf
import numpy as np
import unittest
import os
import io

from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.trajectories import time_step as ts
from tf_agents.networks import q_network
from tf_agents.agents.dqn import dqn_agent
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.metrics import tf_metrics
from tf_agents.drivers import py_driver
from tf_agents.drivers import dynamic_step_driver
from tf_agents.policies import random_tf_policy
from tf_agents.utils import common
from tf_agents.policies import policy_saver

# ENVIRONMENT

In [None]:
# ENVIRONMENT HYPERPARAMETERS

# Grid size: bays across the frame and number of levels
WIDTH, LEVELS = 4, 8

# Slots per level: (WIDTH + 1) columns, WIDTH beams and WIDTH brace slots
MAX_MEMBERS_PER_LEVEL = (WIDTH + 1) + 2 * WIDTH

# Slots in the whole frame
TOTAL_MEMBERS = LEVELS * MAX_MEMBERS_PER_LEVEL

# Brace limit (NOTE(review): not referenced by the cells visible here)
MAX_BRACE = 10

# The bottom DOUBLE_BRACE_LEVEL levels need double braces; levels are counted
# from the roof (0), so DOUBLE_BRACE_ROW is the first level that needs them.
DOUBLE_BRACE_LEVEL = 3
DOUBLE_BRACE_ROW = LEVELS - DOUBLE_BRACE_LEVEL

In [None]:
class FrameEnv(py_environment.PyEnvironment):
    """Frame-building game implemented as a TF-Agents Python environment.

    The frame is a flat ``int32`` vector of ``TOTAL_MEMBERS`` slots laid out
    level by level, roof level first.  Each level occupies
    ``MAX_MEMBERS_PER_LEVEL`` consecutive slots ordered as
    ``WIDTH + 1`` column slots, ``WIDTH`` beam slots, then the brace slots.

    Slot values: 0 - empty, 1 - column, 2 - beam, 3 - brace, 4 - agent cursor.
    Actions:     0 - move on without placing, 1 - place column,
                 2 - place beam, 3 - place brace.

    The cursor walks the vector one slot per step.  An action that does not
    match the zone the cursor is in ends the episode with a -1 reward.  When
    a level is completed it is checked by ``_check_floor``; a failed check
    also ends the episode, and passing the check on the last level wins.
    """

    def __init__(self):
        # Initialise the base-class bookkeeping (current time step, auto-reset
        # flag) so `step()` is safe to call even before an explicit `reset()`.
        super().__init__()

        # action values: 0 - move no placement, 1 - place column,
        # 2 - place beam, 3 - place brace
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=3, name='action')
        # state values: 0 - empty, 1 - column, 2 - beam, 3 - brace, 4 - agent
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(TOTAL_MEMBERS,), dtype=np.int32, minimum=0, maximum=4,
            name='observation')

        self._start_episode()

    def _start_episode(self):
        """Put the environment in its initial state: empty frame, cursor at slot 0."""
        self._state = np.zeros((TOTAL_MEMBERS,), dtype=np.int32)
        self._state[0] = 4  # set agent at start
        self._level_num = 0  # counts from the roof level (= 0) downwards
        self._episode_ended = False

    def _snapshot(self):
        """Return a copy of the state to hand out as an observation.

        `_step` mutates `self._state` in place; returning the live buffer
        would silently change observations of time steps already returned.
        """
        return self._state.copy()

    def action_spec(self):
        """Return the spec of the scalar action (integer in 0..3)."""
        return self._action_spec

    def observation_spec(self):
        """Return the spec of the flat frame observation."""
        return self._observation_spec

    def _reset(self):
        """Start a new episode and return its first time step."""
        self._start_episode()
        return ts.restart(self._snapshot())

    def _get_level(self, start_point):
        """Return the `MAX_MEMBERS_PER_LEVEL` slots beginning at `start_point`."""
        return self._state[start_point:start_point + MAX_MEMBERS_PER_LEVEL]

    def _get_level_lists(self, level_list):
        """Split one level into its (columns, beams, braces) slices."""
        first_beam = WIDTH + 1
        first_brace = WIDTH * 2 + 1
        return (level_list[:first_beam],
                level_list[first_beam:first_brace],
                level_list[first_brace:])

    def _get_row_dims(self, position):
        """Return (first, last) slot index of the level containing `position`."""
        level_index = position // MAX_MEMBERS_PER_LEVEL
        row_beg = level_index * MAX_MEMBERS_PER_LEVEL
        return row_beg, row_beg + MAX_MEMBERS_PER_LEVEL - 1

    def _check_column(self, position):
        """Return True when `position` lies in the column zone of its level."""
        row_beg, _ = self._get_row_dims(position)
        return bool(position < row_beg + (WIDTH + 1))

    def _check_beam(self, position):
        """Return True when `position` lies in the beam zone of its level."""
        row_beg, _ = self._get_row_dims(position)
        return bool(row_beg + (WIDTH + 1) <= position < row_beg + (WIDTH * 2 + 1))

    def _check_brace(self, position):
        """Return True when `position` lies in the brace zone of its level."""
        row_beg, _ = self._get_row_dims(position)
        return bool(position >= row_beg + (WIDTH * 2 + 1))

    def _check_floor(self, position):
        """Return True when the level starting at `position` passes the check.

        Floors are in this order:
        column, column, ..., beam, beam, ..., brace, brace, ...

        The rule applied depends on `self._level_num` (roof = 0):
          * roof level: at least one brace;
          * above `DOUBLE_BRACE_ROW`: at least one brace and a brace pattern
            identical to the level above;
          * at `DOUBLE_BRACE_ROW`: at least two braces, and at least as many
            braces shared with the level above as that level has;
          * below `DOUBLE_BRACE_ROW`: at least two braces and a brace pattern
            identical to the level above.
        """
        columns, beams, braces = self._get_level_lists(self._get_level(position))

        column_count = np.count_nonzero(columns == 1)
        beam_count = np.count_nonzero(beams == 2)
        brace_count = np.count_nonzero(braces == 3)

        # NOTE(review): a level passes this gate when EITHER all columns OR all
        # beams are present.  `_step` only ever reaches here with both
        # complete, so `and` may have been intended - confirm before relying
        # on this check for hand-built states.
        if column_count != (WIDTH + 1) and beam_count != WIDTH:
            return False

        if self._level_num == 0:
            return bool(brace_count > 0)

        # Brace layout of the level directly above this one.
        upper_level = self._get_level(position - MAX_MEMBERS_PER_LEVEL)
        _, _, upper_braces = self._get_level_lists(upper_level)
        upper_brace_count = np.count_nonzero(upper_braces == 3)
        identical_braces = np.array_equal(upper_braces, braces)
        shared_brace_count = np.count_nonzero(
            np.logical_and(upper_braces == 3, braces == 3))

        if self._level_num < DOUBLE_BRACE_ROW:
            return bool(brace_count > 0 and identical_braces)
        if self._level_num == DOUBLE_BRACE_ROW:
            return bool(brace_count > 1
                        and shared_brace_count >= upper_brace_count)
        return bool(brace_count > 1 and identical_braces)

    def _step(self, action):
        """Apply `action` at the cursor and return the resulting time step.

        Rewards: +0.1 for a valid skip/column/beam, -0.1 for a valid brace,
        -1.0 for an action that is invalid in the current zone (terminal).
        Completing a level replaces the step reward with 1.0 when the level
        passes `_check_floor` (plus 10.0 on the last level, which ends the
        episode), and otherwise subtracts 1.0 and ends the episode.

        Raises:
            ValueError: if `action` is not one of 0, 1, 2, 3.
        """
        if self._episode_ended:
            # The last action ended the episode. Ignore the current action and
            # start a new episode.
            return self.reset()

        position = np.where(self._state == 4)[0].item()
        next_position = position
        reward = 0

        # Which zone of the level the cursor is in.
        in_column = self._check_column(position)
        in_beam = self._check_beam(position)
        in_brace = self._check_brace(position)

        # For each action: is it allowed here, what gets written, what it earns.
        if action == 0:    # move to next member position
            allowed, slot_value, step_reward = not (in_column or in_beam), 0, 0.1
        elif action == 1:  # place column
            allowed, slot_value, step_reward = in_column, 1, 0.1
        elif action == 2:  # place beam
            allowed, slot_value, step_reward = in_beam, 2, 0.1
        elif action == 3:  # place brace
            allowed, slot_value, step_reward = in_brace, 3, -0.1
        else:
            raise ValueError('`action` should be 0 thru 3.')

        if allowed:
            self._state[position] = slot_value
            next_position += 1
            reward += step_reward
        else:
            self._episode_ended = True
            reward += -1.0

        # Check frame stability once the cursor has walked off the end of a level.
        if (not self._episode_ended
                and next_position % MAX_MEMBERS_PER_LEVEL == 0):
            row_start = next_position - MAX_MEMBERS_PER_LEVEL
            if self._check_floor(row_start):
                self._level_num += 1
                reward = 1.0

                # check for win
                if next_position == TOTAL_MEMBERS:
                    self._episode_ended = True
                    reward += 10.0
            else:
                # level is unstable - lose
                self._episode_ended = True
                reward += -1.0

        if self._episode_ended:
            return ts.termination(self._snapshot(), reward=reward)

        self._state[next_position] = 4
        return ts.transition(self._snapshot(), reward=reward)

# Tests

In [None]:
class TestFramebot(unittest.TestCase):
    """Unit tests for the `FrameEnv` helpers and step logic.

    The fixtures are hard-coded for the WIDTH = 4, LEVELS = 8 grid
    (13 slots per level, 104 slots in total).
    """

    _ROW_LEN = 13
    #              0  1  2  3  4  5  6  7  8  9 10 11 12
    _ROW_SINGLE = [1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 3, 0]
    _ROW_DOUBLE = [1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 3, 3]

    # A frame that passes every level check: five single-braced levels on top
    # of three double-braced ones.
    _VALID_FRAME = _ROW_SINGLE * 5 + _ROW_DOUBLE * 3

    # A frame in which every level fails its check.
    _INVALID_FRAME = [1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 0, 0,
                      1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 3, 0,
                      1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 3, 0, 0,
                      1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 3, 0,
                      1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 0, 0,
                      1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 3, 3,
                      1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 0, 3,
                      1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 3, 3]

    def setUp(self):
        # Fresh environment for every test.
        self.env = FrameEnv()

    def _floor_checks(self, frame):
        """Run `_check_floor` on every level of `frame`; return the results."""
        self.env._state = np.array(frame)
        results = []
        for level in range(len(frame) // self._ROW_LEN):
            self.env._level_num = level
            results.append(self.env._check_floor(level * self._ROW_LEN))
        return results

    def _assert_zone(self, position, column, beam, brace):
        """Assert which zone predicates hold for `position`."""
        self.assertEqual(column, self.env._check_column(position))
        self.assertEqual(beam, self.env._check_beam(position))
        self.assertEqual(brace, self.env._check_brace(position))

    def test_reset(self):
        # Dirty the state and force the episode to end.
        self.env._state[0] = 0
        self.env._level_num = 3
        self.env._episode_ended = True

        self.env.reset()

        self.assertEqual(4, self.env._state[0])
        self.assertEqual(0, self.env._level_num)
        self.assertFalse(self.env._episode_ended)

    def test_get_level(self):
        self.env._state = np.array(self._VALID_FRAME)
        np.testing.assert_array_equal(self._ROW_SINGLE, self.env._get_level(26))
        # The original test built this second expectation but never checked it.
        np.testing.assert_array_equal(self._ROW_DOUBLE, self.env._get_level(91))

    def test_get_level_lists(self):
        self.env._state = np.array(self._VALID_FRAME)
        level = self.env._get_level(26)
        columns, beams, braces = self.env._get_level_lists(level)
        np.testing.assert_array_equal([1, 1, 1, 1, 1], columns)
        np.testing.assert_array_equal([2, 2, 2, 2], beams)
        np.testing.assert_array_equal([0, 0, 3, 0], braces)

    def test__get_row_dims(self):
        start, end = self.env._get_row_dims(14)
        self.assertEqual(13, start)
        self.assertEqual(25, end)

    def test_check_column(self):
        self._assert_zone(14, column=True, beam=False, brace=False)

    def test_check_beam(self):
        self._assert_zone(18, column=False, beam=True, brace=False)

    def test_check_brace(self):
        self._assert_zone(22, column=False, beam=False, brace=True)

    def test_check_floor_ok(self):
        for level, passed in enumerate(self._floor_checks(self._VALID_FRAME)):
            with self.subTest(level=level):
                self.assertTrue(passed)

    def test_check_floor_ng(self):
        for level, passed in enumerate(self._floor_checks(self._INVALID_FRAME)):
            with self.subTest(level=level):
                self.assertFalse(passed)

    def test_step(self):
        # Playing the valid frame slot by slot must reproduce it in the state
        # and finish the episode with the win reward (1.0 + 10.0).
        time_step = None
        for slot, action in enumerate(self._VALID_FRAME):
            time_step = self.env._step(action)
            self.assertEqual(action, self.env._state[slot])

        self.assertTrue(self.env._episode_ended)
        self.assertTrue(time_step.is_last())
        self.assertAlmostEqual(11.0, float(time_step.reward))



# Execute the environment unit tests inside the notebook kernel.
# argv is overridden so unittest ignores the kernel's own command line, and
# exit=False keeps unittest from shutting the kernel down afterwards.
unittest.main(
    argv=[''],
    exit=False,
    verbosity=2,
)

test__get_row_dims (__main__.TestFramebot) ... ok
test_check_beam (__main__.TestFramebot) ... ok
test_check_brace (__main__.TestFramebot) ... ok
test_check_column (__main__.TestFramebot) ... ok
test_check_floor_ng (__main__.TestFramebot) ... ok
test_check_floor_ok (__main__.TestFramebot) ... ok
test_get_level (__main__.TestFramebot) ... ok
test_get_level_lists (__main__.TestFramebot) ... ok
test_reset (__main__.TestFramebot) ... ok
test_step (__main__.TestFramebot) ... ok

----------------------------------------------------------------------
Ran 10 tests in 0.032s

OK


<unittest.main.TestProgram at 0x7f909b5c6a50>

# TEST RUN

In [None]:
# Play one scripted, winning episode and report the total reward.
env = FrameEnv()
time_step = env.reset()
print(time_step)

# float(): keep a plain running total instead of accumulating in place into
# the reward array that belongs to the first time step.
cumulative_reward = float(time_step.reward)

single_brace_level = [1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 3, 0]
double_brace_level = [1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 3, 3]
actions = single_brace_level * 5 + double_brace_level * 3

for action in actions:
    time_step = env.step(action)
    cumulative_reward += float(time_step.reward)
    if time_step.is_last():
        print("done")
        break

print(time_step)
# The final step's reward was already added inside the loop; adding it again
# here (as before) double-counted it.
print('Final Reward = ', cumulative_reward)

TimeStep(
{'discount': array(1., dtype=float32),
 'observation': array([4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32),
 'reward': array(0., dtype=float32),
 'step_type': array(0, dtype=int32)})
done
TimeStep(
{'discount': array(0., dtype=float32),
 'observation': array([1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2,
       0, 0, 3, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1,
       2, 2, 2, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 3, 0, 1,
       1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 3, 3, 1, 1, 1, 1, 1, 2, 2, 2, 2, 0,
       0, 3, 3, 1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 3, 3], dtype=int32),
 'reward': array(11., dtype=float32),
 'step_type': array(2, dtype

# MODEL

In [None]:
frame_env = FrameEnv()
# environment check: runs a few episodes and validates every time step
# against the declared specs
utils.validate_py_environment(frame_env, episodes=5)

# Tensorflow environments
# Training and evaluation must not share one Python environment: both
# wrappers would step and reset the same underlying episode state.
train_env = tf_py_environment.TFPyEnvironment(frame_env)
eval_env = tf_py_environment.TFPyEnvironment(FrameEnv())

In [None]:
# Hyperparameters

# Hidden layer widths of the Q-network
fc_layer_params = [2056] * 2 + [1028]

# Data collection / replay settings
initial_collect_steps = 1_000
collect_steps_per_iteration = 1
target_update_period = 500
replay_buffer_capacity = 50_000  # was 100,000 - too big?

# Optimisation settings
batch_size = 64
learning_rate = 0.0001
# gradient_clipping = 0.9
gamma = 0.9    # need to really mess with this if things go sideways

In [None]:
# Agent

# Q-network: maps an observation to one Q-value per action.
q_net = q_network.QNetwork(
    input_tensor_spec=train_env.observation_spec(),
    action_spec=train_env.action_spec(),
    fc_layer_params=fc_layer_params,
    activation_fn=tf.keras.activations.relu,
)

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
# Shared step counter: incremented by the agent, reused by the checkpointer.
global_step = tf.compat.v1.train.get_or_create_global_step()

agent = dqn_agent.DqnAgent(
    time_step_spec=train_env.time_step_spec(),
    action_spec=train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    target_update_period=target_update_period,
    td_errors_loss_fn=common.element_wise_huber_loss,
    gamma=gamma,
    train_step_counter=global_step,
)

agent.initialize()

In [None]:
# Metrics and evaluation

# One instance of each metric, updated by the collect driver as an observer.
train_metrics = [
    metric_cls()
    for metric_cls in (
        tf_metrics.AverageReturnMetric,
        tf_metrics.AverageEpisodeLengthMetric,
    )
]

In [None]:
# Data Collection

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_capacity)

# Driver used during training: a few steps of the agent's collect policy per
# iteration, recorded in the replay buffer and in the training metrics.
collect_driver = dynamic_step_driver.DynamicStepDriver(
    train_env,
    agent.collect_policy,
    observers=[replay_buffer.add_batch] + train_metrics,
    num_steps=collect_steps_per_iteration)

# Initial data collection: seed the buffer with `initial_collect_steps` steps
# of a uniformly random policy.  Previously this ran `collect_driver` once,
# which stored only `collect_steps_per_iteration` step(s) and left
# `initial_collect_steps` unused.  Metrics are not attached here so they only
# describe the agent's own collect policy.
initial_collect_policy = random_tf_policy.RandomTFPolicy(
    train_env.time_step_spec(), train_env.action_spec())
initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
    train_env,
    initial_collect_policy,
    observers=[replay_buffer.add_batch],
    num_steps=initial_collect_steps)
initial_collect_driver.run()

# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, sample_batch_size=batch_size,
    num_steps=2)

iterator = iter(dataset)

Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.


# Training Function

In [None]:
# Training
def train_agent(num_iterations, log_interval=1000):
    """Run `num_iterations` collect-and-train iterations of the DQN agent.

    Each iteration runs `collect_driver` once, samples one batch from the
    replay-buffer `iterator` and applies one training step to `agent`.

    Args:
        num_iterations: number of collect/train iterations to run.
        log_interval: print loss and metrics whenever the agent's train step
            counter is a multiple of this value; 0 or None disables logging.

    Returns:
        A tuple `(all_train_loss, all_metrics)`: the loss of every iteration,
        and for every iteration the list of `train_metrics` results in order.
    """
    # The previous `try: %%time / except: pass` only timed an empty statement
    # and hid any error behind a bare except; time the whole run instead.
    import time
    start_time = time.perf_counter()

    # (Optional) Optimize by wrapping the train step in a graph using TF
    # function.  Kept local so repeated calls do not re-wrap `agent.train`.
    train_step = common.function(agent.train)

    all_train_loss = []
    all_metrics = []

    for _ in range(num_iterations):
        # Collect a few steps using collect_policy and save to the replay buffer.
        collect_driver.run()

        # Sample a batch of data from the buffer and update the agent's network.
        experience, _unused_info = next(iterator)
        loss = train_step(experience).loss.numpy()
        all_train_loss.append(loss)

        all_metrics.append(
            [metric.result().numpy() for metric in train_metrics])

        iteration = agent.train_step_counter.numpy()
        if log_interval and iteration % log_interval == 0:
            print("\nIteration: {}, loss:{:.2f}".format(iteration, loss))
            for metric in train_metrics:
                print('{}: {}'.format(metric.name, metric.result().numpy()))

    print('Wall time: {:.1f} s'.format(time.perf_counter() - start_time))
    return all_train_loss, all_metrics

# Checkpoint and PolicySaver

In [None]:
# Root directory for checkpoints and saved policies.  Defaults to the Google
# Drive folder used so far (which must be mounted first); set the
# FRAMEBOT_DATA_DIR environment variable to run somewhere else.
tempdir = os.environ.get(
    'FRAMEBOT_DATA_DIR', '/content/drive/MyDrive/framebot/data_chkpts_2/')

In [None]:
# Checkpointer

# Tracks the full training state (agent, policy, replay buffer and step
# counter) and keeps only the most recent checkpoint on disk.
checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
    ckpt_dir=checkpoint_dir,
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    replay_buffer=replay_buffer,
    global_step=global_step,
)

In [None]:
# Policy Saver
# Saves the agent's policy to `policy_dir` when `tf_policy_saver.save` is called.
policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(policy=agent.policy)

# TRAINING

In [None]:
def train_and_save(num_iterations):
    """Train for `num_iterations` iterations, then persist the results.

    After training, a checkpoint is written through `train_checkpointer` and
    the agent's policy is saved to `policy_dir`.
    """
    train_agent(num_iterations)

    # Persist the training state, then the policy.
    train_checkpointer.save(global_step=global_step)
    tf_policy_saver.save(export_dir=policy_dir)

In [None]:
def reload_train_save(num_iterations):
    """Restore the latest checkpoint (if any), then train and save again.

    Args:
        num_iterations: number of additional training iterations to run.
    """
    # Restores the objects registered with the checkpointer, or initialises
    # them when no checkpoint exists yet.
    train_checkpointer.initialize_or_restore()

    # The former `global_step = tf.compat.v1.train.get_global_step()` bound a
    # function-local name that was never read; `train_and_save` uses the
    # module-level `global_step`, so the assignment was dead code.

    # train
    train_and_save(num_iterations)

In [None]:
# TODO: add persistent counters to keep track of the
# total training epochs across all training 'sessions'

In [None]:
# Train the agent from scratch, then checkpoint and export the policy.
num_iterations = 200_000  # 200,000
train_and_save(num_iterations)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 6.44 µs
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))


Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))



Iteration: 1000, loss:0.01
AverageReturn: -0.20999988913536072
AverageEpisodeLength: 8.600000381469727

Iteration: 2000, loss:0.05
AverageReturn: -0.23999997973442078
AverageEpisodeLength: 7.199999809265137

Iteration: 3000, loss:0.01
AverageReturn: 0.6800000071525574
AverageEpisodeLength: 12.399999618530273

Iteration: 4000, loss:0.04
AverageReturn: -0.14999990165233612
AverageEpisodeLength: 7.699999809265137

Iteration: 5000, loss:0.07
AverageReturn: -0.19999992847442627
AverageEpisodeLength: 7.199999809265137

Iteration: 6000, loss:0.04
AverageReturn: 0.20000004768371582
AverageEpisodeLength: 9.600000381469727

Iteration: 7000, loss:0.05
AverageReturn: 0.29000014066696167
AverageEpisodeLength: 10.199999809265137

Iteration: 8000, loss:0.07
AverageReturn: 0.43000006675720215
AverageEpisodeLength: 11.600000381469727

Iteration: 9000, loss:0.07
AverageReturn: 0.25000011920928955
AverageEpisodeLength: 9.899999618530273

Iteration: 10000, loss:0.04
AverageReturn: 0.06999991834163666
Ave

In [None]:
# load and restore agent for training
#num_iterations = 10
#reload_train_save(num_iterations)