In [None]:
from torch import randint
from torch import nn, optim
import torch 
import gym
import numpy as np

from collections import deque
import random

from scores.score_logger import ScoreLogger



In [None]:
ENV_NAME = "CartPole-v1"
GAMMA = 0.95
MEMORY = 1000000
BATCH_SIZE = 20
LEARNING_RATE = 0.01
EXPLORATION_DECAY = 0.995
EXPLORATION_MAX = 1.0
EXPLORATION_MIN = 0.01


In [228]:
class DQN:
    def __init__(self, observation_space, action_space):
        self.model = nn.Sequential(
            nn.Linear(observation_space, 24),
            nn.ReLU(),
            nn.Linear(24, 24),
            nn.ReLU(),
            nn.Linear(24, action_space)
        )
        self.observation_space = observation_space
        self.action_space = action_space
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        self.loss_fn = nn.MSELoss()
        self.exploration_rate = EXPLORATION_MAX
        self.discount = GAMMA
        self.memory = deque(maxlen=MEMORY)

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    # Sometimes act randomly. Do so less and less as the exploration rate decays.
    def act(self, state):
        if (np.random.rand() < self.exploration_rate):
            return random.randrange(self.action_space)
        # print(self.model(torch.from_numpy(state)).argmax().item())
        return self.model(torch.from_numpy(state)).argmax().item()
    
    def get_q_next(self, next_state):
        return self.discount * self.model(torch.from_numpy(next_state)).max()

    def experience_replay(self):
        # Don't replay if we don't have enough memory
        if len(self.memory) < BATCH_SIZE:
            return
            
        batch = random.sample(self.memory, BATCH_SIZE)
        # self.optimizer.zero_grad()
        for state, action, reward, next_state, terminal in batch:    
            q_update = reward
            # Update the q value for the action we took
            # Bellman inspired update
            # Current state rewards plus next state rewards discounted by gamma
            if not terminal:
                q_update = reward + self.get_q_next(next_state)
            else: 
                # create long tensor
                q_update = torch.tensor(q_update, dtype=torch.float32)
            
            ## Get the q_values for the current state
            q_values = self.model(torch.from_numpy(state))
            prediction, _ = torch.max(q_values, axis=1)  

            loss = self.loss_fn(prediction, q_update.reshape(1))

            # We reset the optimizer each time because we are training in batches of one
            self.optimizer.zero_grad()

            # Back propagate the loss
            loss.backward(retain_graph=True)

            # Update the weights
            self.optimizer.step()   
            
        # Decay the exploration rate
        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)     

    

In [229]:
# Create environment and a way to track the score
env = gym.make(ENV_NAME)
score_logger = ScoreLogger(ENV_NAME)

# Get the action and state space sizes from the environment
action_space = env.action_space.n
observation_space = env.observation_space.shape[0]

# Reset the environment and get the first state
state, info = env.reset(seed=42, return_info=True)

# Create the agent
dqn = DQN(observation_space, action_space)

run = 0 # run is the number of episodes
while run < 100:
    run += 1
    state = env.reset()
    state = np.reshape(state, [1, observation_space])
    step = 0
    while(True): 
        step += 1

        # Predict action then take action in environment
        action = dqn.act(state)
        state_next, reward, terminal, info = env.step(action)

        # Get set reward negative if game over
        reward = reward if not terminal else -reward
        state_next = np.reshape(state_next, [1, observation_space])

        # Store experience in memory
        dqn.remember(state, action, reward, state_next, terminal)
        state = state_next

        if terminal:
            print("Run: " + str(run) + ", exploration: " + str(dqn.exploration_rate) + ", score: " + str(step))
            score_logger.add_score(step, run)
            break
        
        # Experience replay - train model
        dqn.experience_replay()
        


  deprecation(
  deprecation(


Run: 1, exploration: 0.9229311239742362, score: 36
Scores: (min: 36, avg: 36, max: 36)

1
Run: 2, exploration: 0.8690529955452602, score: 13
Scores: (min: 13, avg: 24.5, max: 36)

0


  self._save_png(input_path=SCORES_CSV_PATH,


0
0
Run: 3, exploration: 0.8142285204175609, score: 14
Scores: (min: 13, avg: 21, max: 36)

0
0
0
0
0
0
0
0
Run: 4, exploration: 0.6935613678313175, score: 33
Scores: (min: 13, avg: 24, max: 36)

0
0
0
0
0
0
Run: 5, exploration: 0.6629680834613705, score: 10
Scores: (min: 10, avg: 21.2, max: 36)

0
0
0
0
0
Run: 6, exploration: 0.6305556603555866, score: 11
Scores: (min: 10, avg: 19.5, max: 36)

0
0
0
0
0
Run: 7, exploration: 0.6057704364907278, score: 9
Scores: (min: 9, avg: 18, max: 36)

0
0
0
0
0
0
0
Run: 8, exploration: 0.5590843898207511, score: 17
Scores: (min: 9, avg: 17.875, max: 36)

0
0
0
0
0
0
0
0
Run: 9, exploration: 0.5185893309484582, score: 16
Scores: (min: 9, avg: 17.666666666666668, max: 36)

0
0
0
0
0
0
0
Run: 10, exploration: 0.4932355662165453, score: 11
Scores: (min: 9, avg: 17, max: 36)

0
0
0
0
0
Run: 11, exploration: 0.46912134373457726, score: 11
Scores: (min: 9, avg: 16.454545454545453, max: 36)

0
0
0
0
0
0
0
0
Run: 12, exploration: 0.446186062443672, score: 1

In [None]:
prediction

In [None]:
v, i = torch.max(dqn.model(torch.from_numpy(state).float()), axis=1)

In [None]:
v

In [None]:
q_values = dqn.model(torch.from_numpy(state))

In [None]:
q_values[0][0] = 1

In [None]:
action_space

In [None]:
state = np.reshape(state, [1, observation_space])


In [None]:
np.random.rand()

In [None]:
dqn.model(state).argmax()

In [None]:
dqn.model(torch.from_numpy(state).float()).max().item()

In [None]:
state