In [1]:
from keras.layers import Dense, Input
from keras.models import Model
from keras.optimizers import Adam
import keras

from collections import deque

import random
import numpy as np

import gym

Using TensorFlow backend.


In [2]:
class DQN:
    def __init__(self, gamma, epsilon, epsilon_min, epsilon_decay, learning_rate, tau, env):
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.learning_rate = learning_rate
        self.tau = tau
        self.memory = deque(maxlen=2000)
        
        self.env = env
        
        self.model = self.createModel()
        self.targetModel = self.createModel()
        
        
    def createModel(self):
        input = Input(shape=self.env.observation_space.shape)
        x = Dense(24, activation='relu')(input)
        x = Dense(48, activation='relu')(x)
        x = Dense(24, activation='relu')(x)
        output = Dense(self.env.action_space.n)(x)
        
        model = Model(inputs=input, outputs=output)
        model.compile(loss='mse', optimizer=Adam(lr=self.learning_rate))
        
        return model
        
    def remember(self, state, action, reward, next_state, done):  # done表示是否需要重置环境
        self.memory.append((state, action, reward, next_state, done))
        
        
    def takeAction(self, state):
        self.epsilon *= self.epsilon_decay
        self.epsilon = max(self.epsilon, self.epsilon_min)
        
        if np.random.random() <= self.epsilon:
            return self.env.action_space.sample()
        
        return np.argmax(self.model.predict(state)[0])

    def replay(self, batch_size):  # 用Q-target-model的知识来更新Q-model
        if len(self.memory) < batch_size:
            return
        samples = random.sample(self.memory, batch_size)
        for sample in samples:
            state, action, reward, next_state, done = sample
            target = self.targetModel.predict(state)
            
            if done:
                target[0][action] = reward
            else:
                q_future = max(self.targetModel.predict(state)[0])
                target[0][action] = reward + q_future * self.gamma
            self.model.fit(state, target, epochs=1, verbose=0)
    
    
    def targetTrain(self):  # 用Q-model的网络参数来更新Q-target-model
        weights = self.model.get_weights()
        target_weights = self.targetModel.get_weights()
        for i in range(len(target_weights)):
            target_weights[i] = weights[i] * self.tau + target_weights[i] * (1 - self.tau)
        self.targetModel.set_weights(target_weights)
        
        
    def saveModel(self, filename):
        self.model.save(filename)

In [3]:
def main():
    env = gym.make('CartPole-v0')
    dqn = DQN(0.95, 1, 0.001, 0.95, 0.001, 0.125, env)
    
    for episode in range(1000):
        state = env.reset().reshape(1, 4)
        
        for step in range(10000):
            env.render()
            action = dqn.takeAction(state)
            print(action)
            next_state, reward, done, _ = env.step(action)
            
            next_state = next_state.reshape(1, 4)
            
            dqn.remember(state, action, reward, next_state, done)
            
            dqn.replay(32)
            
            dqn.targetTrain()
            
            if done:
                break
                
    sa
                
if __name__ =="__main__":
    main()
        
        

  result = entry_point.load(False)


0
1
1
1
0
1
1
1
1
1
1
1
1
1
0
0
1
0
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
1
0
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
0
0
0
0
0
1
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
0
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
1
0
1
1
0
0
0
0
1
0
1
0
0
1
0
0
0
0
1
0
1
0
0
0
0
0
1
0
0
0
0
0
1
0
0
0
0
1
1
1
1
1
1
1
1
1
1
0
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
0
0
0
1
0
1
0
1
1
0
1
0
1
0
0
0


KeyboardInterrupt: 