In [5]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from stable_baselines.common.atari_wrappers import make_atari, wrap_deepmind
import time

In [6]:
# Configuration parameters for the whole setup
seed = 42
gamma = 0.99  # Discount factor for past rewards
epsilon = 1.0  # Epsilon greedy parameter (current chance of a random action)
epsilon_min = 0.1  # Minimum epsilon greedy parameter
epsilon_max = 1.0  # Maximum epsilon greedy parameter
epsilon_interval = (
    epsilon_max - epsilon_min
)  # Total span epsilon is annealed over (divided by the frame budget per step)
batch_size = 32  # Size of batch taken from replay buffer
max_steps_per_episode = 10000

# Use the Baseline Atari environment because of Deepmind helper functions
env = make_atari("BreakoutNoFrameskip-v4")
# Warp the frames, grey scale, stack four frames and scale to smaller ratio
env = wrap_deepmind(env, frame_stack=True, scale=True)
env.seed(seed)

# Seeding only the environment is not enough for a repeatable run: the
# training loop draws exploration actions and replay indices from NumPy's
# global generator, and the network weights are initialised through
# TensorFlow. Seed both so Restart & Run All gives comparable results.
np.random.seed(seed)
tf.random.set_seed(seed)

# Size of the Q-network output layer.
# NOTE(review): presumably equal to env.action_space.n for Breakout - confirm.
num_actions = 4




In [None]:

def create_q_model():
    """Build the convolutional Q-network (architecture from the Deepmind paper).

    The network takes a stack of four 84x84 frames and produces one linear
    output per action (``num_actions`` values in total).

    Returns:
        keras.Model: an uncompiled functional model mapping an
        ``(84, 84, 4)`` input to ``num_actions`` Q-value estimates.
    """
    frames = layers.Input(shape=(84, 84, 4))

    # (filters, kernel size, stride) for each convolution over the screen
    conv_specs = ((32, 8, 4), (64, 4, 2), (64, 3, 1))
    features = frames
    for n_filters, kernel, stride in conv_specs:
        features = layers.Conv2D(
            n_filters, kernel, strides=stride, activation="relu"
        )(features)

    # Flatten the feature maps and pass them through the dense head
    flat = layers.Flatten()(features)
    hidden = layers.Dense(512, activation="relu")(flat)
    q_out = layers.Dense(num_actions, activation="linear")(hidden)

    return keras.Model(inputs=frames, outputs=q_out)


# The first model makes the predictions for Q-values which are used to
# choose an action.
model = create_q_model()
# Build a target model for the prediction of future rewards.
# The weights of a target model get updated every 10000 steps thus when the
# loss between the Q-values is calculated the target Q-value is stable.
model_target = create_q_model()
# Start the target network as an exact copy of the online network. Without
# this the two models carry independent random initialisations, so every TD
# target computed before the first scheduled sync comes from an unrelated
# network.
model_target.set_weights(model.get_weights())

# In the Deepmind paper they use RMSProp, however the Adam optimizer
# improves training time
optimizer = keras.optimizers.Adam(learning_rate=0.00025, clipnorm=1.0)

# Experience replay buffers (parallel lists: index i describes one transition)
action_history = []
state_history = []
state_next_history = []
rewards_history = []
done_history = []
# Rewards of the most recent episodes, used for the running average
episode_reward_history = []
running_reward = 0
episode_count = 0
frame_count = 0
# Number of frames to take random action and observe output
epsilon_random_frames = 50000
# Number of frames for exploration
epsilon_greedy_frames = 1000000.0
# Maximum replay length
# Note: The Deepmind paper suggests 1000000 however this causes memory issues
max_memory_length = 10000
# Train the model after 4 actions
update_after_actions = 4
# How often to update the target network
update_target_network = 10000
# Using huber loss for stability
loss_function = keras.losses.Huber()

# Main DQN training loop: play episodes with an epsilon-greedy policy, store
# every transition in the replay lists, periodically fit `model` on a sampled
# minibatch against targets from `model_target`, and stop once the mean
# reward over the last (up to) 100 episodes exceeds 40.
while True:  # Run until solved
    state = np.array(env.reset())
    episode_reward = 0

    # 1-based step counter that really allows max_steps_per_episode steps
    # (range(1, max_steps_per_episode) stopped one step short of the cap).
    for timestep in range(1, max_steps_per_episode + 1):
        # Show the attempts of the agent in a pop up window. Rendering is
        # optional; comment this line out for headless training.
        env.render()
        # time.sleep(0.01)  # uncomment to slow the playback down
        frame_count += 1

        # Use epsilon-greedy for exploration
        if frame_count < epsilon_random_frames or epsilon > np.random.rand(1)[0]:
            # Take random action
            action = np.random.choice(num_actions)
        else:
            # Predict action Q-values
            # From environment state
            state_tensor = tf.convert_to_tensor(state)
            state_tensor = tf.expand_dims(state_tensor, 0)
            action_probs = model(state_tensor, training=False)
            # Take best action (despite the name these are Q-values)
            action = tf.argmax(action_probs[0]).numpy()

        # Decay probability of taking random action
        epsilon -= epsilon_interval / epsilon_greedy_frames
        epsilon = max(epsilon, epsilon_min)

        # Apply the sampled action in our environment
        state_next, reward, done, _ = env.step(action)
        state_next = np.array(state_next)

        episode_reward += reward

        # Save actions and states in replay buffer
        action_history.append(action)
        state_history.append(state)
        state_next_history.append(state_next)
        done_history.append(done)
        rewards_history.append(reward)
        state = state_next

        # Update every fourth frame, once the replay buffer holds more
        # transitions than one batch
        if frame_count % update_after_actions == 0 and len(done_history) > batch_size:

            # Get indices of samples for replay buffers. An integer argument
            # samples from 0..n-1 directly, avoiding the conversion of a
            # range object to an array on every update.
            indices = np.random.choice(len(done_history), size=batch_size)

            # Using list comprehension to sample from replay buffer
            state_sample = np.array([state_history[i] for i in indices])
            state_next_sample = np.array([state_next_history[i] for i in indices])
            rewards_sample = [rewards_history[i] for i in indices]
            action_sample = [action_history[i] for i in indices]
            done_sample = tf.convert_to_tensor(
                [float(done_history[i]) for i in indices]
            )

            # Build the updated Q-values for the sampled future states
            # Use the target model for stability. The model is called
            # directly rather than through predict(): the input is a single
            # small batch and this runs on every update.
            future_rewards = model_target(state_next_sample, training=False)
            # Q value = reward + discount factor * expected future reward
            updated_q_values = rewards_sample + gamma * tf.reduce_max(
                future_rewards, axis=1
            )

            # For transitions that ended an episode, replace the target
            # with -1
            updated_q_values = updated_q_values * (1 - done_sample) - done_sample

            # Create a mask so we only calculate loss on the updated Q-values
            masks = tf.one_hot(action_sample, num_actions)

            with tf.GradientTape() as tape:
                # Train the model on the states and updated Q-values
                q_values = model(state_sample)

                # Apply the masks to the Q-values to get the Q-value for action taken
                q_action = tf.reduce_sum(tf.multiply(q_values, masks), axis=1)
                # Calculate loss between new Q-value and old Q-value
                loss = loss_function(updated_q_values, q_action)

            # Backpropagation
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if frame_count % update_target_network == 0:
            # Update the target network with new weights
            model_target.set_weights(model.get_weights())
            # Log details
            template = "running reward: {:.2f} at episode {}, frame count {}"
            print(template.format(running_reward, episode_count, frame_count))

        # Limit the state and reward history: drop the oldest transition
        # from every parallel list so they stay index-aligned
        if len(rewards_history) > max_memory_length:
            del rewards_history[:1]
            del state_history[:1]
            del state_next_history[:1]
            del action_history[:1]
            del done_history[:1]

        if done:
            break

    # Update running reward to check condition for solving
    # (mean over at most the 100 most recent episodes)
    episode_reward_history.append(episode_reward)
    if len(episode_reward_history) > 100:
        del episode_reward_history[:1]
    running_reward = np.mean(episode_reward_history)

    episode_count += 1
    print(str(episode_count) + ": " + str(running_reward))
    if running_reward > 40:  # Condition to consider the task solved
        print("Solved at episode {}!".format(episode_count))
        break



1: 2.0
2: 1.5
3: 1.0
4: 0.75
5: 0.6
6: 0.5
7: 0.42857142857142855
8: 0.375
9: 0.3333333333333333
10: 0.3
11: 0.2727272727272727
12: 0.25
13: 0.23076923076923078
14: 0.21428571428571427
15: 0.2
16: 0.1875
17: 0.17647058823529413
18: 0.16666666666666666
19: 0.15789473684210525
20: 0.15
21: 0.14285714285714285
22: 0.13636363636363635
23: 0.21739130434782608
24: 0.20833333333333334
25: 0.2
26: 0.19230769230769232
27: 0.18518518518518517
28: 0.17857142857142858
29: 0.1724137931034483
30: 0.16666666666666666
31: 0.16129032258064516
32: 0.15625
33: 0.21212121212121213
34: 0.20588235294117646
35: 0.22857142857142856
36: 0.2222222222222222
37: 0.21621621621621623
38: 0.2631578947368421
39: 0.2564102564102564
40: 0.25
41: 0.24390243902439024
42: 0.23809523809523808
43: 0.23255813953488372
44: 0.22727272727272727
45: 0.2222222222222222
46: 0.21739130434782608
47: 0.2127659574468085
48: 0.20833333333333334
49: 0.20408163265306123
50: 0.24
51: 0.27450980392156865
52: 0.28846153846153844
53: 0.30188

710: 0.39
711: 0.39
712: 0.39
713: 0.41
714: 0.4
715: 0.41
716: 0.41
717: 0.41
718: 0.41
719: 0.41
720: 0.41
721: 0.41
722: 0.43
723: 0.43
724: 0.44
725: 0.44
726: 0.44
727: 0.44
728: 0.46
729: 0.46
730: 0.46
731: 0.48
732: 0.48
733: 0.46
734: 0.47
735: 0.48
736: 0.46
737: 0.46
738: 0.46
739: 0.46
740: 0.47
741: 0.47
742: 0.47
743: 0.45
744: 0.45
745: 0.45
746: 0.47
747: 0.47
748: 0.48
749: 0.49
750: 0.49
751: 0.49
752: 0.51
753: 0.51
754: 0.51
755: 0.52
756: 0.52
757: 0.52
running reward: 0.52 at episode 757, frame count 30000
758: 0.52
759: 0.52
760: 0.54
761: 0.52
762: 0.52
763: 0.52
764: 0.51
765: 0.5
766: 0.5
767: 0.48
768: 0.48
769: 0.5
770: 0.5
771: 0.52
772: 0.53
773: 0.53
774: 0.53
775: 0.54
776: 0.54
777: 0.54
778: 0.54
779: 0.54
780: 0.53
781: 0.53
782: 0.53
783: 0.53
784: 0.53
785: 0.53
786: 0.51
787: 0.49
788: 0.48
789: 0.48
790: 0.47
791: 0.45
792: 0.45
793: 0.47
794: 0.47
795: 0.47
796: 0.47
797: 0.45
798: 0.44
799: 0.44
800: 0.44
801: 0.42
802: 0.41
803: 0.4
804: 0.41
8

1474: 0.4
1475: 0.4
1476: 0.42
1477: 0.42
1478: 0.4
1479: 0.39
1480: 0.4
1481: 0.4
1482: 0.4
1483: 0.4
1484: 0.4
1485: 0.4
1486: 0.4
1487: 0.42
1488: 0.42
1489: 0.42
1490: 0.42
1491: 0.42
1492: 0.42
1493: 0.41
1494: 0.42
1495: 0.41
1496: 0.39
1497: 0.39
1498: 0.39
1499: 0.41
1500: 0.42
1501: 0.42
1502: 0.42
1503: 0.42
1504: 0.4
1505: 0.39
1506: 0.37
1507: 0.36
1508: 0.36
1509: 0.36
1510: 0.35
1511: 0.35
1512: 0.35
1513: 0.35
1514: 0.35
1515: 0.35
1516: 0.35
1517: 0.37
1518: 0.37
1519: 0.37
1520: 0.37
1521: 0.39
1522: 0.39
1523: 0.39
1524: 0.39
1525: 0.39
1526: 0.39
1527: 0.37
1528: 0.38
1529: 0.37
1530: 0.37
1531: 0.36
1532: 0.36
1533: 0.34
1534: 0.34
1535: 0.34
1536: 0.34
1537: 0.34
1538: 0.34
1539: 0.34
1540: 0.34
1541: 0.32
1542: 0.32
1543: 0.34
1544: 0.33
1545: 0.34
1546: 0.32
1547: 0.33
1548: 0.32
1549: 0.32
1550: 0.32
1551: 0.32
1552: 0.32
1553: 0.32
1554: 0.32
1555: 0.32
1556: 0.32
1557: 0.32
1558: 0.3
1559: 0.3
1560: 0.3
1561: 0.3
1562: 0.3
1563: 0.3
1564: 0.3
1565: 0.3
1566: 0

2213: 0.46
2214: 0.47
2215: 0.47
2216: 0.45
2217: 0.44
2218: 0.43
2219: 0.42
2220: 0.42
2221: 0.42
2222: 0.44
2223: 0.43
2224: 0.42
2225: 0.42
2226: 0.42
2227: 0.42
2228: 0.42
2229: 0.42
2230: 0.42
2231: 0.44
2232: 0.44
2233: 0.45
2234: 0.46
2235: 0.46
2236: 0.48
2237: 0.48
2238: 0.49
2239: 0.5
2240: 0.49
2241: 0.51
2242: 0.51
2243: 0.51
2244: 0.5
2245: 0.51
2246: 0.49
2247: 0.49
2248: 0.48
2249: 0.48
2250: 0.47
2251: 0.45
2252: 0.45
2253: 0.44
2254: 0.44
2255: 0.43
2256: 0.41
2257: 0.41
2258: 0.41
2259: 0.43
2260: 0.43
2261: 0.43
2262: 0.43
2263: 0.43
2264: 0.43
2265: 0.43
2266: 0.43
2267: 0.43
2268: 0.43
2269: 0.43
2270: 0.43
2271: 0.43
2272: 0.41
2273: 0.41
2274: 0.4
2275: 0.39
2276: 0.39
2277: 0.37
2278: 0.37
2279: 0.39
2280: 0.39
2281: 0.37
2282: 0.37
2283: 0.37
2284: 0.36
2285: 0.35
2286: 0.35
2287: 0.35
2288: 0.37
2289: 0.37
2290: 0.38
2291: 0.38
2292: 0.38
2293: 0.38
2294: 0.38
2295: 0.38
2296: 0.4
2297: 0.4
2298: 0.38
2299: 0.37
2300: 0.37
2301: 0.37
2302: 0.39
2303: 0.39
2304

2951: 0.29
2952: 0.31
2953: 0.29
2954: 0.28
2955: 0.28
2956: 0.28
2957: 0.28
2958: 0.3
2959: 0.3
2960: 0.3
2961: 0.3
2962: 0.3
2963: 0.3
2964: 0.3
2965: 0.3
2966: 0.32
2967: 0.34
2968: 0.35
2969: 0.36
2970: 0.37
2971: 0.37
2972: 0.37
2973: 0.37
2974: 0.39
2975: 0.39
2976: 0.38
2977: 0.38
2978: 0.38
2979: 0.37
2980: 0.36
2981: 0.36
2982: 0.36
2983: 0.34
2984: 0.33
2985: 0.33
2986: 0.33
2987: 0.33
2988: 0.33
2989: 0.33
2990: 0.33
2991: 0.33
2992: 0.33
2993: 0.33
2994: 0.34
2995: 0.35
2996: 0.35
2997: 0.35
2998: 0.35
2999: 0.35
3000: 0.35
3001: 0.37
3002: 0.38
3003: 0.39
3004: 0.39
3005: 0.4
3006: 0.4
3007: 0.4
3008: 0.4
3009: 0.39
3010: 0.4
3011: 0.4
3012: 0.39
3013: 0.38
3014: 0.38
3015: 0.38
3016: 0.38
3017: 0.38
3018: 0.4
3019: 0.4
3020: 0.4
3021: 0.4
3022: 0.4
3023: 0.4
3024: 0.38
3025: 0.38
3026: 0.4
3027: 0.41
3028: 0.42
3029: 0.4
3030: 0.41
3031: 0.41
3032: 0.41
3033: 0.41
3034: 0.43
3035: 0.41
3036: 0.41
3037: 0.43
3038: 0.43
3039: 0.43
3040: 0.44
3041: 0.44
3042: 0.45
3043: 0.46