# Reinforcement Learning with Keras interface: DQN

The goal of this notebook is to demonstrate how easy it is to do reinforcement learning with OpenMined and PySyft using the Keras interface. We will apply DQN to the game MountainCar-V0 from gym. The original code was written by Yash Patel. The original code can be found [here](https://towardsdatascience.com/reinforcement-learning-w-keras-openai-dqns-1eed3a5338c).

In [2]:
import syft
from syft import FloatTensor



In [3]:
import gym
import numpy as np
import random

from syft.interfaces.keras.models import Sequential
from syft.interfaces.keras.layers import Dense, Dropout
from syft.interfaces.keras.optimizers import SGD

from collections import deque
from syft import FloatTensor

Using TensorFlow backend.
lol... Just Kidding... Using OpenMined Backend

In [10]:
class DQN:
    def __init__(self, env):
        self.env     = env
        self.memory  = deque(maxlen=2000)
        
        self.gamma = 0.85
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.005
        self.tau = .125

        self.model        = self.create_model()
        self.target_model = self.create_model()

    def create_model(self):
        model   = Sequential()
        state_shape  = self.env.observation_space.shape
        model.add(Dense(24, input_shape=state_shape[0], activation="relu"))
        model.add(Dense(48, activation="relu"))
        model.add(Dense(24, activation="relu"))
        model.add(Dense(self.env.action_space.n))
        model.compile(loss='mean_squared_error',
                      optimizer=SGD(lr=0.01), metrics =[])
        return model

    def act(self, state):
        self.epsilon *= self.epsilon_decay
        self.epsilon = max(self.epsilon_min, self.epsilon)
        if np.random.random() < self.epsilon:
            return self.env.action_space.sample()
        return np.argmax(self.model.predict(state).to_numpy())

    def remember(self, state, action, reward, new_state, done):
        self.memory.append([state, action, reward, new_state, done])

    def replay(self):
        batch_size = 5
        state_store = []
        target_store = []
        if len(self.memory) < batch_size: 
            return
        samples = random.sample(self.memory, batch_size)
        for sample in samples:
            state, action, reward, new_state, done = sample
            target = self.target_model.predict(state).to_numpy()
            if done:
                target.data[0][action] = reward
            else:
                Q_future = self.target_model.predict(new_state).max()
                target[0][action] = (Q_future * self.gamma + reward).to_numpy()[0]
       
            state_store.append(state)
            target_store.append(target)
            
        state_store = np.array(state_store).reshape(batch_size,self.env.observation_space.shape[0])
        target_store = np.array(target_store).reshape(batch_size,self.env.action_space.n)
        
        self.model.fit(state_store, target_store, batch_size=1,epochs=1,verbose=True,validation_data=None)


    def target_train(self):
        weights = self.model.get_weights()
        print("Model weights")
        print(weights[0])
        target_weights = self.target_model.get_weights()
        print("Model target weights")
        print(target_weights[0])
        for i in range(len(target_weights)):
            target_weights[i] *= 0
            target_weights[i] += weights[i] * self.tau + target_weights[i] * (1 - self.tau)

    #def save_model(self, fn):
    #    self.model.save(fn)

def main():
    env     = gym.make("MountainCar-v0")
    gamma   = 0.9
    epsilon = .95

    trials  = 3
    trial_len = 4

    # updateTargetNetwork = 1000
    dqn_agent = DQN(env=env)
    steps = []
    for trial in range(trials):
        cur_state = env.reset().reshape(1,2)
        for step in range(trial_len):
            action = dqn_agent.act(cur_state)
            new_state, reward, done, _ = env.step(action)

            # reward = reward if not done else -20
            # MOD - convert new_state to FloatTensor
            new_state = new_state.reshape(1,2)
            dqn_agent.remember(cur_state, action, reward, new_state, done)
            
            dqn_agent.replay()       # internally iterates default (prediction) model
            dqn_agent.target_train() # iterates target model

            cur_state = new_state
            if done:
                break
        if step >= 2:
            print("Failed to complete in trial {}".format(trial))
            #if step % 10 == 0:
            #    dqn_agent.save_model("trial-{}.model".format(trial))
        else:
            print("Completed in {} trials".format(trial))
            #dqn_agent.save_model("success.model")
            break

if __name__ == "__main__":
    main()


[2018-01-22 23:58:11,316] Making new env: MountainCar-v0


Model weights
  -0.4942018  -0.4012343  -0.2896211   0.5738681   0.5133142  -0.4022536
  -0.5736486   0.6226808   0.5327068  -0.3060617  -0.02539152 -0.5397618
   0.3067479   0.1865416   0.04672527  0.2267746  -0.3978108   0.03644294
   0.2903181  -0.6522111  -0.2795804  -0.4814984  -0.241456   -0.5379501  
   0.4330761   0.4987098  -0.610293   -0.2167451  -0.555723   -0.02354956
  -0.3329354   0.6442267   0.1655313   0.3726228  -0.3428683  -0.5801873
  -0.1038545  -0.2236702  -0.6786835  -0.1848043   0.2161993  -0.2124214
  -0.4091783  -0.4776817  -0.3147609   0.5861519  -0.5587538  -0.3296073   
Model target weights
   0.2587354  -0.372409    0.09992349 -0.5706312   0.02272362  0.6834459
   0.6130362   0.1743304  -0.1788663   0.3275356   0.2429037  -0.377506
  -0.2932415   0.6251568  -0.06242311 -0.6493004  -0.1835307  -0.3939249
   0.4632618   0.04675138  0.5575235   0.2318198  -0.6693673   0.1919654  
   0.451363    0.2576147   0.308138    0.5794345   0.6278042   0.3118587
   0.265

Model weights
  -0.4945915  -0.401518   -0.2895979   0.5738681   0.5133142  -0.4027838
  -0.5731899   0.6226808   0.5327068  -0.3064558  -0.02547651 -0.5396988
   0.3067479   0.1865416   0.04672527  0.2267746  -0.3976278   0.03644294
   0.2903181  -0.6520575  -0.279848   -0.4812601  -0.2407986  -0.5376998  
   0.4330743   0.4987086  -0.6102929  -0.2167451  -0.555723   -0.02355149
  -0.3329337   0.6442267   0.1655313   0.3726215  -0.3428688  -0.580187
  -0.1038545  -0.2236702  -0.6786835  -0.1848043   0.2162     -0.2124214
  -0.4091783  -0.4776813  -0.3147619   0.586153   -0.558751   -0.3296064   
Model target weights
  -0.06177523 -0.05015429 -0.03620263  0.07173351  0.06416427 -0.0502817
  -0.07170608  0.07783511  0.06658835 -0.03825771 -0.00317394 -0.06747023
   0.03834349  0.0233177   0.00584066  0.02834683 -0.04972634  0.00455537
   0.03628976 -0.08152639 -0.03494754 -0.0601873  -0.030182   -0.06724376 
   0.05413451  0.06233873 -0.07628662 -0.02709314 -0.06946538 -0.0029437
  -0.0

Model weights
  -0.4948372  -0.4019829  -0.2897011   0.5738681   0.5133142  -0.4050698
  -0.5711624   0.6226808   0.5327068  -0.3087344  -0.02441805 -0.5399088
   0.3067479   0.1865416   0.04672527  0.2267746  -0.3969516   0.03644294
   0.2903181  -0.6506675  -0.2808182  -0.4814278  -0.2395779  -0.5368765  
   0.4330717   0.4987063  -0.6102927  -0.2167451  -0.555723   -0.02355777
  -0.3329285   0.6442267   0.1655313   0.3726161  -0.342868   -0.5801869
  -0.1038545  -0.2236702  -0.6786835  -0.1848043   0.2162019  -0.2124214
  -0.4091783  -0.4776784  -0.3147644   0.5861543  -0.5587456  -0.3296037   
Model target weights
  -0.06182394 -0.05018975 -0.03619973  0.07173351  0.06416427 -0.05034798
  -0.07164874  0.07783511  0.06658835 -0.03830697 -0.00318456 -0.06746235
   0.03834349  0.0233177   0.00584066  0.02834683 -0.04970347  0.00455537
   0.03628976 -0.08150718 -0.034981   -0.06015751 -0.03009983 -0.06721247 
   0.05413429  0.06233858 -0.07628661 -0.02709314 -0.06946538 -0.00294394
  -

Model weights
  -0.4955036  -0.4035053  -0.290013    0.5738681   0.5133142  -0.4092647
  -0.5677883   0.6226808   0.5327068  -0.3127538  -0.02252327 -0.5408115
   0.3067479   0.1865416   0.04672527  0.2267746  -0.3955      0.03644294
   0.2903181  -0.6484755  -0.2825652  -0.4817371  -0.2369915  -0.5355542  
   0.4330618   0.4986971  -0.6102924  -0.2167451  -0.555723   -0.02356916
  -0.33292     0.6442267   0.1655313   0.372609   -0.3428706  -0.580187
  -0.1038545  -0.2236702  -0.6786835  -0.1848043   0.2162065  -0.2124214
  -0.4091783  -0.477676   -0.3147694   0.5861606  -0.558729   -0.3295981   
Model target weights
  -0.06185465 -0.05024787 -0.03621264  0.07173351  0.06416427 -0.05063372
  -0.0713953   0.07783511  0.06658835 -0.0385918  -0.00305226 -0.0674886
   0.03834349  0.0233177   0.00584066  0.02834683 -0.04961894  0.00455537
   0.03628976 -0.08133344 -0.03510228 -0.06017847 -0.02994724 -0.06710956 
   0.05413396  0.06233829 -0.07628659 -0.02709314 -0.06946538 -0.00294472
  -0.

Model weights
  -0.4950638  -0.4055549  -0.2908484   0.5738681   0.5133142  -0.4140087
  -0.5641875   0.6226808   0.5327068  -0.318128   -0.01942068 -0.5432188
   0.3067479   0.1865416   0.04672527  0.2267746  -0.3931931   0.03644294
   0.2903181  -0.6462914  -0.2847232  -0.4834707  -0.2350039  -0.5347359  
   0.4330486   0.4986847  -0.6102924  -0.2167451  -0.555723   -0.02358327
  -0.3329107   0.6442267   0.1655313   0.3726006  -0.3428762  -0.5801877
  -0.1038545  -0.2236702  -0.6786835  -0.1848043   0.216211   -0.2124214
  -0.4091783  -0.4776727  -0.3147751   0.5861695  -0.5587062  -0.3295923   
Model target weights
  -0.06193795 -0.05043816 -0.03625163  0.07173351  0.06416427 -0.05115809
  -0.07097354  0.07783511  0.06658835 -0.03909422 -0.00281541 -0.06760143
   0.03834349  0.0233177   0.00584066  0.02834683 -0.0494375   0.00455537
   0.03628976 -0.08105943 -0.03532065 -0.06021713 -0.02962394 -0.06694428 
   0.05413273  0.06233713 -0.07628655 -0.02709314 -0.06946538 -0.00294614
  -

Model weights
  -0.4945005  -0.4079719  -0.2923104   0.5738681   0.5133142  -0.4193385
  -0.5603916   0.6226808   0.5327068  -0.3242715  -0.01603146 -0.5469432
   0.3067479   0.1865416   0.04672527  0.2267746  -0.3906312   0.03644294
   0.2903181  -0.6444327  -0.2872293  -0.4868779  -0.2331651  -0.5344478  
   0.4330309   0.49867    -0.6102975  -0.2167451  -0.555723   -0.02360367
  -0.332898    0.6442267   0.1655313   0.3725891  -0.3428859  -0.580193
  -0.1038545  -0.2236702  -0.6786835  -0.1848043   0.2162146  -0.2124214
  -0.4091783  -0.4776712  -0.3147832   0.5861743  -0.5586789  -0.3295878   
Model target weights
  -0.06188298 -0.05069437 -0.03635605  0.07173351  0.06416427 -0.05175108
  -0.07052343  0.07783511  0.06658835 -0.03976601 -0.00242758 -0.06790235
   0.03834349  0.0233177   0.00584066  0.02834683 -0.04914914  0.00455537
   0.03628976 -0.08078643 -0.0355904  -0.06043384 -0.02937549 -0.06684198 
   0.05413107  0.06233559 -0.07628655 -0.02709314 -0.06946538 -0.00294791
  -0

Model weights
  -0.4950635  -0.4127043  -0.2957468   0.5738681   0.5133142  -0.4270729
  -0.5567082   0.6226808   0.5327068  -0.3316379  -0.01292867 -0.5535142
   0.3067479   0.1865416   0.04672527  0.2267746  -0.3880762   0.03644294
   0.2903181  -0.6433718  -0.2904723  -0.4932896  -0.2317428  -0.5357793  
   0.4330114   0.4986511  -0.610306   -0.2167451  -0.555723   -0.02362927
  -0.3328857   0.6442267   0.1655313   0.3725749  -0.3428962  -0.580203
  -0.1038545  -0.2236702  -0.6786835  -0.1848043   0.216218   -0.2124214
  -0.4091783  -0.4776712  -0.314793    0.5861731  -0.5586523  -0.3295864   
Model target weights
  -0.06181257 -0.05099649 -0.0365388   0.07173351  0.06416427 -0.05241731
  -0.07004895  0.07783511  0.06658835 -0.04053394 -0.00200393 -0.06836791
   0.03834349  0.0233177   0.00584066  0.02834683 -0.0488289   0.00455537
   0.03628976 -0.08055409 -0.03590366 -0.06085973 -0.02914564 -0.06680598 
   0.05412887  0.06233374 -0.07628719 -0.02709314 -0.06946538 -0.00295046
  -0

Model weights
  -0.4965588  -0.4190992  -0.3007854   0.5738681   0.5133142  -0.436594
  -0.5531054   0.6226808   0.5327068  -0.339825   -0.01066683 -0.5623993
   0.3067479   0.1865416   0.04672527  0.2267746  -0.3859871   0.03644294
   0.2903181  -0.6427026  -0.294186   -0.5021192  -0.2302747  -0.5386384  
   0.4329872   0.4986242  -0.6103221  -0.2167451  -0.555723   -0.02366341
  -0.3328725   0.6442267   0.1655313   0.3725573  -0.3429117  -0.5802233
  -0.1038545  -0.2236702  -0.6786835  -0.1848043   0.2162186  -0.2124214
  -0.4091783  -0.4776721  -0.3148049   0.5861614  -0.5586247  -0.3295914   
Model target weights
  -0.06188293 -0.05158803 -0.03696835  0.07173351  0.06416427 -0.05338411
  -0.06958852  0.07783511  0.06658835 -0.04145473 -0.00161608 -0.06918927
   0.03834349  0.0233177   0.00584066  0.02834683 -0.04850953  0.00455537
   0.03628976 -0.08042147 -0.03630903 -0.06166121 -0.02896785 -0.06697241 
   0.05412643  0.06233139 -0.07628825 -0.02709314 -0.06946538 -0.00295366
  -0

Model weights
  -0.4982862  -0.4278694  -0.3073461   0.5738681   0.5133142  -0.449271
  -0.5518052   0.6226808   0.5327068  -0.3498059  -0.0091124  -0.574014
   0.3067479   0.1865416   0.04672527  0.2267746  -0.385453    0.03644294
   0.2903181  -0.6425768  -0.2982545  -0.5161409  -0.2296616  -0.5452388  
   0.4329616   0.4985921  -0.6103415  -0.2167451  -0.555723   -0.02370184
  -0.3328588   0.6442267   0.1655313   0.3725383  -0.3429314  -0.5802476
  -0.1038545  -0.2236702  -0.6786835  -0.1848043   0.2162159  -0.2124214
  -0.4091783  -0.4776712  -0.3148164   0.5861441  -0.5585962  -0.3296008   
Model target weights
  -0.06206985 -0.05238741 -0.03759817  0.07173351  0.06416427 -0.05457425
  -0.06913818  0.07783511  0.06658835 -0.04247812 -0.00133335 -0.07029992
   0.03834349  0.0233177   0.00584066  0.02834683 -0.04824839  0.00455537
   0.03628976 -0.08033783 -0.03677325 -0.0627649  -0.02878433 -0.0673298  
   0.05412341  0.06232803 -0.07629026 -0.02709314 -0.06946538 -0.00295793
  -0.