# Neural Network Learns to Play Games
In this example notebook, I will train a simple feedforward neural network to play LunarLanderContinuous using policy optimisation. I used this [notebook_here](https://github.com/JannesKlaas/sometimes_deep_sometimes_learning/blob/master/reinforcement.ipynb) as a base and develop according to it.

In [1]:
# importing the dependencies
import gym # RL environment
import numpy as np # matrix math
import matplotlib.pyplot as plt # plotting graphs
%matplotlib inline

# deep learning dependencies
from keras.models import Sequential
from keras.layers import Dense, Dropout

Using TensorFlow backend.
  return f(*args, **kwds)


In [2]:
env = gym.make('LunarLanderContinuous-v2')

In [3]:
action = env.action_space.sample()
print('[*]Sample action:',action)

[*]Sample action: [ 0.09762701  0.43037873]


In [4]:
print('[*]For action space we can have following values:')
print(env.action_space)
print(env.action_space.high)
print(env.action_space.low)
print('[*]For observation space we can have following values:')
print(env.observation_space)
print(env.observation_space.high)
print(env.observation_space.low)
print('[*]Sample Action')
print(env.action_space.sample())

[*]For action space we can have following values:
Box(2,)
[ 1.  1.]
[-1. -1.]
[*]For observation space we can have following values:
Box(8,)
[ inf  inf  inf  inf  inf  inf  inf  inf]
[-inf -inf -inf -inf -inf -inf -inf -inf]
[*]Sample Action
[ 0.20552675  0.08976637]


In [5]:
'''
Example of running a model, where
env.render() --> gives us a visual representation of the model
env.step(action) --> makes an action towards that step, returs a tuple of
    observation: the state after making the step
    reward: the reward that is obtained on taking that step
    done: whether the environment is over or not
    info: ??
'''
for i in range(20):
    observation = env.reset()
    for t in range(1000):
        env.render()
        print(observation)
        action = env.action_space.sample()
        observation, reward, done, info = env.step(action)
        if done:
            print(); print()
            print("Episode finished after {} timesteps".format(t+1))
            break

[ 0.0078722   0.94330836  0.79735179  0.17964249 -0.00911511 -0.180612    0.
  0.        ]
[ 0.01574459  0.94561883  0.79627414  0.15394576 -0.01803981 -0.17850987
  0.          0.        ]
[ 0.02367172  0.94753552  0.80310783  0.12762527 -0.02832155 -0.20565419
  0.          0.        ]
[ 0.03154516  0.94941569  0.79804859  0.12510719 -0.03890908 -0.21177022
  0.          0.        ]
[ 0.0394537   0.95153561  0.80151072  0.14101578 -0.04944677 -0.21077337
  0.          0.        ]
[ 0.04747105  0.95362292  0.81405563  0.13870319 -0.06165408 -0.24416862
  0.          0.        ]
[ 0.05542908  0.95531798  0.80658674  0.11252596 -0.07234517 -0.21384125
  0.          0.        ]
[ 0.06343536  0.95661269  0.8126256   0.08569936 -0.08424151 -0.23794873
  0.          0.        ]
[ 0.07146244  0.95809603  0.81634607  0.0980638  -0.09777476 -0.2706898   0.
  0.        ]
[ 0.07947998  0.96003447  0.81697798  0.12815834 -0.11289814 -0.30249584
  0.          0.        ]
[ 0.08753901  0.96157284  

   5.92798926e-04   5.80022931e-03   0.00000000e+00   0.00000000e+00]
[ -8.13102722e-04   9.61195094e-01  -3.15382749e-02   4.56177998e-01
   1.90868217e-03   2.63201803e-02   0.00000000e+00   0.00000000e+00]
[-0.00099239  0.96834908 -0.01986035  0.47693408  0.00381965  0.03822283
  0.          0.        ]
[-0.00118036  0.97550269 -0.02069705  0.47690179  0.00569596  0.03752981
  0.          0.        ]
[-0.00148029  0.98270035 -0.03137015  0.47983608  0.00704907  0.02706484
  0.          0.        ]
[-0.00182266  0.98949814 -0.0366764   0.45317383  0.00946546  0.04833255
  0.          0.        ]
[-0.00210142  0.99590054 -0.028707    0.42682317  0.0102805   0.01630205
  0.          0.        ]
[-0.00231915  1.00221234 -0.02290579  0.42077675  0.01140224  0.02243693
  0.          0.        ]
[-0.00253687  1.00812398 -0.02290925  0.39410448  0.01252323  0.02242196
  0.          0.        ]
[-0.00275469  1.01363574 -0.02291249  0.36743711  0.01364417  0.02242087
  0.          0.        ]

  0.          0.        ]
[-0.0807023   0.92411738 -0.28297474 -0.5120774   0.40119785  0.24679024
  0.          0.        ]
[-0.08374252  0.91664363 -0.31514764 -0.50144234  0.41316026  0.23924823
  0.          0.        ]
[-0.08709517  0.90959302 -0.34652584 -0.47337726  0.42532763  0.24334757
  0.          0.        ]
[-0.09055386  0.90240079 -0.35709994 -0.48292723  0.4375416   0.24427931
  0.          0.        ]
[-0.09436598  0.89526768 -0.39186826 -0.47890244  0.44917369  0.23264196
  0.          0.        ]
[-0.09821854  0.88772097 -0.39685569 -0.50689745  0.46192276  0.25498123
  0.          0.        ]
[-0.10225286  0.88046342 -0.41329432 -0.48714809  0.47282687  0.218082    0.
  0.        ]
[-0.10628767  0.87280684 -0.41328759 -0.51382367  0.48373088  0.21808021
  0.          0.        ]
[-0.11062393  0.86551028 -0.44373879 -0.49001525  0.49503553  0.22609282
  0.          0.        ]
[-0.11522827  0.85809975 -0.47007813 -0.49754492  0.50587702  0.21682992
  0.          0.  

  0.          0.        ]
[ 0.07632837  0.95141173  0.41473937 -0.05273872 -0.08563972 -0.05896095
  0.          0.        ]
[ 0.08040152  0.95065804  0.41180758 -0.05051274 -0.09010015 -0.08921675
  0.          0.        ]
[ 0.08452244  0.94979697  0.41647549 -0.05767161 -0.09446052 -0.0872152   0.
  0.        ]
[ 0.08851109  0.94889756  0.40182328 -0.06015371 -0.09738882 -0.05857108
  0.          0.        ]
[ 0.09249992  0.94759808  0.40183096 -0.08682623 -0.10031698 -0.05856885
  0.          0.        ]
[ 0.09655924  0.94639064  0.40712399 -0.0805799  -0.10150663 -0.02379258
  0.          0.        ]
[ 0.10061846  0.94478312  0.40712399 -0.10724669 -0.10269625 -0.02379258
  0.          0.        ]
[ 0.10488815  0.94339924  0.42962885 -0.09244308 -0.10533424 -0.05275966
  0.          0.        ]
[ 0.10922279  0.94190627  0.43452997 -0.09960763 -0.10639925 -0.02130001
  0.          0.        ]
[ 0.11383228  0.9406251   0.46315808 -0.08557126 -0.10859467 -0.04390851
  0.          0.  

  0.          0.        ]
[ 0.76827011  0.30608936  1.21374331 -1.06309662 -0.58535111 -0.07094678
  0.          0.        ]
[ 0.78087883  0.29032741  1.26344347 -1.05193621 -0.58840877 -0.0611548   0.
  0.        ]
[ 0.79364681  0.27456709  1.27973003 -1.05199413 -0.59188384 -0.06950166
  0.          0.        ]
[ 0.80641499  0.25840684  1.27972927 -1.0786616  -0.59535897 -0.06950213
  0.          0.        ]
[ 0.81918335  0.24184671  1.27972851 -1.10532913 -0.5988341  -0.06950202
  0.          0.        ]
[ 0.83195152  0.22488671  1.27972765 -1.13199654 -0.60230917 -0.06950194
  0.          0.        ]
[ 0.8446785   0.20755435  1.2744338  -1.15629044 -0.60440284 -0.04187359
  0.          0.        ]
[ 0.85747013  0.18977123  1.28281879 -1.18723323 -0.60879642 -0.08787194
  0.          0.        ]
[ 0.87022877  0.17160898  1.27860508 -1.21209844 -0.61210871 -0.06624548
  0.          0.        ]
[ 0.88314686  0.15351167  1.29308462 -1.2070844  -0.61364275 -0.03068082
  1.          0.  

  0.          0.        ]
[ 0.1439889   0.45350018  0.23196688 -0.65494582 -0.05841904 -0.05594952
  0.          0.        ]
[ 0.14621334  0.44328623  0.22355657 -0.68097115 -0.05952093 -0.0220373   0.
  0.        ]
[ 0.14830389  0.43328934  0.2085633  -0.66644204 -0.05903405  0.0097374   0.
  0.        ]
[ 0.15039597  0.42362157  0.20704577 -0.64443296 -0.05689123  0.04285631
  0.          0.        ]
[ 0.1525445   0.41441231  0.21265378 -0.6138656  -0.05471317  0.04356083
  0.          0.        ]
[ 0.15469303  0.40480314  0.21265376 -0.64053268 -0.05253512  0.04356082
  0.          0.        ]
[ 0.15694571  0.39541398  0.22273252 -0.62585532 -0.05002635  0.0501752   0.
  0.        ]
[ 0.1591404   0.38613902  0.21732497 -0.61825937 -0.0479042   0.04244277
  0.          0.        ]
[ 0.16139278  0.37645956  0.2245615  -0.64527607 -0.04723598  0.01336451
  0.          0.        ]
[ 0.16381054  0.36673999  0.24228587 -0.64798927 -0.04774801 -0.01024041
  0.          0.        ]
[ 0.1662

[-0.33069344  0.68245917 -0.80129089 -0.62494984  0.33931869  0.08327419
  0.          0.        ]
[-0.33872862  0.67267533 -0.80914555 -0.65358938  0.34521386  0.11790396
  0.          0.        ]
[-0.34697752  0.66290427 -0.82881842 -0.65235354  0.34934035  0.08252981
  0.          0.        ]
[-0.35518246  0.65275231 -0.82323189 -0.67746773  0.35221466  0.05748653
  0.          0.        ]
[-0.36331992  0.64223128 -0.81462698 -0.70161743  0.35314322  0.01857073
  0.          0.        ]
[-0.37155957  0.63186312 -0.82526712 -0.69152864  0.3545121   0.02737762
  0.          0.        ]
[ -3.79749537e-01   6.21114833e-01  -8.18967152e-01  -7.16545995e-01
   3.54481369e-01  -6.14280207e-04   0.00000000e+00   0.00000000e+00]
[-0.38789921  0.60998154 -0.81387243 -0.74195143  0.35332844 -0.02305876
  0.          0.        ]
[-0.39629254  0.59884101 -0.83676109 -0.7420742   0.35062921 -0.05398428
  0.          0.        ]
[-0.40484195  0.58767911 -0.85406094 -0.74391117  0.34970278 -0.01852

  0.          0.        ]
[ 0.28953238  0.51932517  0.89573498 -0.83122826 -0.41144854 -0.2070447   0.
  0.        ]
[ 0.29859934  0.50719796  0.9166749  -0.81142279 -0.42222136 -0.21545582
  0.          0.        ]
[ 0.30791416  0.4950326   0.941084   -0.81393153 -0.43263364 -0.20824554
  0.          0.        ]
[ 0.31733952  0.48300324  0.95261173 -0.80510108 -0.44362065 -0.21974015
  0.          0.        ]
[ 0.32698774  0.47108603  0.97492208 -0.79772034 -0.45469278 -0.22144198
  0.          0.        ]
[ 0.33663635  0.4587698   0.97491531 -0.82439626 -0.46576479 -0.2214401   0.
  0.        ]
[ 0.34622707  0.4460909   0.96740894 -0.84808159 -0.47498706 -0.18444555
  0.          0.        ]
[ 0.35587711  0.43297587  0.97497263 -0.87779287 -0.4860824  -0.22190738
  0.          0.        ]
[ 0.36552773  0.41946183  0.97496538 -0.90446866 -0.49717772 -0.22190549
  0.          0.        ]
[ 0.37517881  0.40554872  0.97495804 -0.9311444  -0.50827289 -0.22190359
  0.          0.        ]


  0.          0.        ]
[-0.25701046  0.70726624 -0.56820273 -0.56731319  0.17932987  0.02056181
  0.          0.        ]
[-0.26268225  0.69835845 -0.56820278 -0.5939799   0.18035798  0.0205618   0.
  0.        ]
[-0.26836438  0.68938676 -0.56946926 -0.59826552  0.18161762  0.02519294
  0.          0.        ]
[-0.2740212   0.68080474 -0.56770897 -0.57238344  0.18365172  0.04068181
  0.          0.        ]
[-0.27959905  0.67230321 -0.55843048 -0.56685212  0.18429837  0.01293332
  0.          0.        ]
[-0.28511672  0.66341658 -0.55084057 -0.59232527  0.18336806 -0.01860605
  0.          0.        ]
[-0.29070735  0.6548524  -0.55840807 -0.57086023  0.18271501 -0.01306101
  0.          0.        ]
[-0.29633713  0.64588175 -0.56332922 -0.59808693  0.18306954  0.00709055
  0.          0.        ]
[-0.30187306  0.63688412 -0.55275226 -0.5997413   0.18222366 -0.01691772
  0.          0.        ]
[-0.30740905  0.62748643 -0.5527523  -0.626408    0.18137778 -0.0169177   0.
  0.        ]


  0.          0.        ]
[ 0.18900328  0.71990882  0.56375332 -0.50668707 -0.20695646 -0.12539698
  0.          0.        ]
[ 0.19463215  0.71213551  0.56745005 -0.51886037 -0.21155705 -0.09201179
  0.          0.        ]
[ 0.20041494  0.70462809  0.5841063  -0.50133934 -0.21743903 -0.11763967
  0.          0.        ]
[ 0.20619793  0.69672099  0.5841054  -0.52800884 -0.22332101 -0.11763936
  0.          0.        ]
[ 0.21213884  0.68887845  0.59967499 -0.52368981 -0.22896518 -0.11288366
  0.          0.        ]
[ 0.21828661  0.68087264  0.62207761 -0.53488038 -0.23637827 -0.14826149
  0.          0.        ]
[ 0.22443581  0.67287265  0.62058272 -0.53425592 -0.24213427 -0.11512029
  0.          0.        ]
[ 0.2306529   0.6649747   0.62562919 -0.52717838 -0.2461212  -0.07973849
  0.          0.        ]
[ 0.23681746  0.65668793  0.61901741 -0.55288461 -0.24873555 -0.05228682
  0.          0.        ]
[ 0.24298201  0.6480012   0.61901727 -0.57955189 -0.2513499  -0.05228681
  0.      

[ 0.01387749  0.72839356  0.04764999 -0.57202295 -0.05539244 -0.041654    0.
  0.        ]
[ 0.01435957  0.719831    0.05230009 -0.57098916 -0.05943571 -0.08086559
  0.          0.        ]
[ 0.0150136   0.71148739  0.06886216 -0.55638034 -0.06285498 -0.06838544
  0.          0.        ]
[ 0.01571274  0.70274163  0.07450686 -0.58325144 -0.0674073  -0.09104632
  0.          0.        ]
[ 0.01653404  0.69384585  0.08857908 -0.59335448 -0.07378905 -0.12763506
  0.          0.        ]
[ 0.01749954  0.68501073  0.10248598 -0.58931084 -0.0796624  -0.11746716
  0.          0.        ]
[ 0.01846504  0.67577592  0.10248566 -0.61598028 -0.08553573 -0.11746672
  0.          0.        ]
[ 0.01943054  0.66614142  0.1024853  -0.64264978 -0.09140904 -0.11746634
  0.          0.        ]
[ 0.02046309  0.65609885  0.11089374 -0.66999327 -0.09898257 -0.15147097
  0.          0.        ]
[ 0.0214632   0.64634419  0.10811129 -0.65086403 -0.10701981 -0.16074501
  0.          0.        ]
[ 0.02253723  0.63

  0.          0.        ]
[-0.20021191  1.0786974  -0.6969409   0.11293381  0.25718075  0.13281078
  0.          0.        ]
[-0.20717053  1.08045931 -0.70264664  0.11625173  0.26410872  0.13855879
  0.          0.        ]
[-0.21440492  1.08256608 -0.72988272  0.13927514  0.27070561  0.13193778
  0.          0.        ]
[-0.22162294  1.08447858 -0.72855597  0.12623286  0.27762961  0.13847995
  0.          0.        ]
[-0.22885265  1.08633071 -0.7300817   0.12210256  0.2849319   0.14604598
  0.          0.        ]
[-0.23620172  1.08800107 -0.74330368  0.1096945   0.29358023  0.17296671
  0.          0.        ]
[-0.24355102  1.08927203 -0.74330096  0.08302179  0.30222851  0.17296579
  0.          0.        ]
[-0.25090046  1.09014378 -0.74329815  0.0563491   0.31087676  0.17296489
  0.          0.        ]
[-0.25831761  1.09058905 -0.7518609   0.02748163  0.32141194  0.21070361
  0.          0.        ]
[ -2.65735054e-01   1.09063531e+00  -7.51856565e-01   8.06083654e-04
   3.31947029e

  0.          0.        ]
[ 0.00903788  0.91965371  0.08693482 -0.17927628 -0.02793108 -0.09647474
  0.          0.        ]
[ 0.00990181  0.91728373  0.08918248 -0.15805796 -0.03068233 -0.05503001
  0.          0.        ]
[ 0.01088037  0.91506138  0.10017906 -0.14819864 -0.03298186 -0.04599458
  0.          0.        ]
[ 0.01199818  0.91303569  0.11353999 -0.13509275 -0.03472856 -0.03493371
  0.          0.        ]
[ 0.0131011   0.9110777   0.11049628 -0.13053153 -0.03492253 -0.00387952
  0.          0.        ]
[ 0.01417923  0.90922919  0.10820904 -0.12324683 -0.03531369 -0.007823    0.
  0.        ]
[ 0.01524448  0.90744153  0.10705428 -0.11919212 -0.03583976 -0.01052158
  0.          0.        ]
[ 0.0164979   0.90582896  0.12755901 -0.10756264 -0.03802959 -0.0437964   0.
  0.        ]
[ 0.017799    0.90419369  0.13217255 -0.10906781 -0.04006523 -0.0407131   0.
  0.        ]
[ 0.019168    0.90215817  0.14068841 -0.13581123 -0.0438061  -0.07481722
  0.          0.        ]
[ 0.0204

[ 0.2974864   0.28050209  0.67523603 -0.93689512 -0.56910902 -0.13062651
  0.          0.        ]
[ 0.30418301  0.26608439  0.67523317 -0.96356475 -0.57564032 -0.13062615
  0.          0.        ]
[ 0.31087999  0.25126701  0.67523022 -0.99023444 -0.58217162 -0.13062575
  0.          0.        ]
[ 0.31803331  0.23685961  0.71993017 -0.96252079 -0.58761019 -0.10877191
  0.          0.        ]
[ 0.32518682  0.22205244  0.71992812 -0.98918953 -0.59304875 -0.10877172
  0.          0.        ]
[ 0.33228941  0.206874    0.7134779  -1.01334966 -0.59688723 -0.07676939
  0.          0.        ]
[ 0.33942871  0.19126568  0.71823921 -1.04252141 -0.60204643 -0.10318402
  0.          0.        ]
[ 0.34656811  0.17525755  0.7182373  -1.06918996 -0.60720563 -0.10318387
  0.          0.        ]
[ 0.35418873  0.15965319  0.76659503 -1.04239718 -0.61265785 -0.10904454
  0.          0.        ]
[ 0.36193008  0.1438336   0.77969594 -1.05726115 -0.61938298 -0.13450284
  0.          0.        ]
[ 0.369671

  0.          0.        ]
[ 0.05120525  0.96426507  0.59060016  0.16979591 -0.05224394 -0.13248049
  0.          0.        ]
[ 0.05704451  0.9664157   0.59062066  0.14312771 -0.0588656  -0.1324455   0.
  0.        ]
[ 0.06278944  0.96898823  0.58185334  0.171199   -0.06616332 -0.14596781
  0.          0.        ]
[ 0.06857672  0.97115555  0.58717275  0.14409169 -0.07452983 -0.16734507
  0.          0.        ]
[ 0.07444954  0.97292042  0.59787517  0.11709966 -0.08503844 -0.21019156
  0.          0.        ]
[ 0.0803956   0.97460261  0.60662212  0.1114236  -0.09697495 -0.23875194
  0.          0.        ]
[ 0.08638191  0.97588149  0.61167188  0.08436479 -0.10992378 -0.25900009
  0.          0.        ]
[ 0.09260321  0.97741833  0.63437433  0.10151212 -0.12211336 -0.2438138   0.
  0.        ]
[ 0.09875679  0.97855855  0.6258604   0.07512418 -0.13258845 -0.20952058
  0.          0.        ]
[ 0.10489645  0.9796443   0.62610021  0.07125826 -0.14469185 -0.24208999
  0.          0.        ]


  0.          0.        ]
[-0.06381197  0.88913956 -0.74143782 -0.42729832  0.07884362  0.20647678
  0.          0.        ]
[-0.07107162  0.88277846 -0.73511987 -0.42457708  0.08792102  0.18156468
  0.          0.        ]
[-0.07833166  0.87601805 -0.73514524 -0.45125605  0.09699757  0.1815477   0.
  0.        ]
[-0.08563328  0.86925536 -0.74063587 -0.45156129  0.10740627  0.20819314
  0.          0.        ]
[-0.09312944  0.8623376  -0.76160684 -0.46208582  0.11933845  0.23866549
  0.          0.        ]
[-0.10069542  0.85576697 -0.76866207 -0.43905144  0.13136224  0.2404979   0.
  0.        ]
[-0.10831242  0.84879065 -0.77504606 -0.46631686  0.14467514  0.26628239
  0.          0.        ]
[-0.11593008  0.84141598 -0.77508311 -0.49299765  0.15798536  0.26622875
  0.          0.        ]
[-0.12354851  0.83364268 -0.77511778 -0.5196825   0.17129371  0.26619115
  0.          0.        ]
[-0.13133945  0.82594195 -0.79193377 -0.51491327  0.18418139  0.25777705
  0.          0.        ]


  0.          0.        ]
[ 0.13853216  0.98332816  0.81461334  0.17534374 -0.15700267 -0.15887495
  0.          0.        ]
[ 0.14659929  0.98557091  0.81463451  0.14866419 -0.16494358 -0.15883199
  0.          0.        ]
[ 0.15463495  0.98742276  0.81066351  0.12265357 -0.17205295 -0.14220046
  0.          0.        ]
[ 0.16273241  0.98923312  0.81485052  0.12008777 -0.17718109 -0.10257218
  0.          0.        ]
[ 0.17077084  0.99066029  0.80740404  0.09471842 -0.18075114 -0.0714075   0.
  0.        ]
[ 0.17900047  0.99228573  0.8261364   0.10796645 -0.18394653 -0.06391361
  0.          0.        ]
[ 0.18726282  0.9935031   0.83025494  0.08065417 -0.18799509 -0.08097814
  0.          0.        ]
[ 0.1956337   0.99504471  0.8430316   0.1020093  -0.19397637 -0.11963632
  0.          0.        ]
[ 0.20417805  0.99699164  0.85865955  0.12923595 -0.19822659 -0.0850053   0.
  0.        ]
[ 0.21288729  0.9988575   0.87432022  0.12393179 -0.20164846 -0.06843731
  0.          0.        ]


  0.          0.        ]
[ 0.08489714  0.84197283  0.56432047 -0.50253741 -0.08486468 -0.06798248
  0.          0.        ]
[ 0.09062777  0.83485728  0.57764635 -0.47463614 -0.08941465 -0.09100755
  0.          0.        ]
[ 0.09656096  0.8278739   0.59975791 -0.46595599 -0.09581093 -0.12793734
  0.          0.        ]
[ 0.10263596  0.82113848  0.6135633  -0.44942636 -0.10183493 -0.12049081
  0.          0.        ]
[ 0.10875731  0.81467343  0.61829119 -0.43143657 -0.10796228 -0.12255809
  0.          0.        ]
[ 0.11482582  0.80782099  0.61163702 -0.45718212 -0.11272174 -0.0951977   0.
  0.        ]
[ 0.1209403   0.80056105  0.61740918 -0.48446035 -0.1186537  -0.11864969
  0.          0.        ]
[ 0.1271184   0.79288645  0.6254148  -0.51225821 -0.12622391 -0.15141759
  0.          0.        ]
[ 0.1334136   0.78536158  0.63689013 -0.50229642 -0.13357504 -0.14703587
  0.          0.        ]
[ 0.13970766  0.77776595  0.63698473 -0.50706641 -0.14113478 -0.15120825
  0.          0.  

  0.          0.        ]
[-0.07865324  0.90211201 -0.35003726 -0.09740444  0.10490309  0.07107284
  0.          0.        ]
[-0.08206024  0.90026208 -0.34279642 -0.123473    0.10698733  0.04168856
  0.          0.        ]
[-0.08551798  0.8984335  -0.34667137 -0.12196618  0.10787655  0.01778601
  0.          0.        ]
[-0.08897419  0.89619065 -0.34651761 -0.14958584  0.10876582  0.01778521
  0.          0.        ]
[-0.09249353  0.89406977 -0.35280464 -0.14145532  0.10962645  0.01721256
  0.          0.        ]
[-0.09607391  0.89153863 -0.36047754 -0.16892784  0.11204784  0.04842779
  0.          0.        ]
[-0.09972382  0.88860436 -0.36918316 -0.19593396  0.116216    0.08336294
  0.          0.        ]
[-0.10337381  0.88527041 -0.36918292 -0.22260205  0.12038414  0.08336285
  0.          0.        ]
[-0.10701933  0.88220206 -0.36913576 -0.20493108  0.12494182  0.09115323
  0.          0.        ]
[-0.11066933  0.87905713 -0.36799827 -0.20991675  0.127919    0.05954368
  0.      

  0.          0.        ]
[-0.62276323  0.1295127  -0.80661459 -1.0399285  -0.03232239 -0.09991388
  0.          0.        ]
[-0.63082399  0.11419097 -0.80247812 -1.02153168 -0.03587518 -0.07105578
  0.          0.        ]
[-0.63888474  0.09846936 -0.80247812 -1.0481994  -0.03942795 -0.07105563
  0.          0.        ]
[-0.64688165  0.08274732 -0.7963203  -1.04822928 -0.04275877 -0.06661625
  0.          0.        ]
[-0.65491004  0.06707631 -0.79922986 -1.04484037 -0.04632443 -0.07131315
  0.          0.        ]
[-0.66285541  0.05100465 -0.78881903 -1.07163162 -0.05197493 -0.11301007
  0.          0.        ]
[-0.67085488  0.03522564 -0.79282732 -1.0521972  -0.05901262 -0.14075387
  0.          0.        ]
[-0.67888937  0.01925183 -0.79491515 -1.06528365 -0.06745394 -0.1688265   0.
  0.        ]
[-0.68692384  0.00287858 -0.79491572 -1.09195646 -0.07589523 -0.16882569
  0.          0.        ]
[-0.69500639 -0.0138946  -0.80094719 -1.11859728 -0.08313131 -0.14472157
  0.          0.  

[ 0.2472578   0.48737078  0.26804047 -0.8012104   0.19867818  0.34660254
  0.          0.        ]
[ 0.25018139  0.47500029  0.27698126 -0.82684326  0.21417512  0.30993912
  0.          0.        ]
[ 0.2531045   0.46223174  0.27698741 -0.85352955  0.22967181  0.30993397
  0.          0.        ]
[ 0.25595369  0.44904404  0.2677434  -0.88195788  0.24712215  0.34900703
  0.          0.        ]
[ 0.25884447  0.43547011  0.27307589 -0.90770887  0.26345149  0.32658641
  0.          0.        ]
[ 0.26152935  0.42198878  0.2541919  -0.90138919  0.27807844  0.29253938
  0.          0.        ]
[ 0.26421356  0.40810939  0.25419903 -0.92807312  0.29270515  0.29253511
  0.          0.        ]
[ 0.26671448  0.39429131  0.23612115 -0.92407875  0.30713606  0.28861828
  0.          0.        ]
[ 0.2690836   0.38061472  0.22410905 -0.91453609  0.32042167  0.26571262
  0.          0.        ]
[ 0.27145224  0.36653963  0.22411587 -0.94121679  0.33370715  0.26570938
  0.          0.        ]
[ 0.273895

  0.          0.        ]
[ 0.2568387   0.68368655  0.49559364 -0.56672618 -0.18977934 -0.10377719
  0.          0.        ]
[ 0.26174307  0.67479563  0.49559302 -0.59339517 -0.19496819 -0.10377696
  0.          0.        ]
[ 0.26664743  0.66550503  0.49559236 -0.62006404 -0.20015703 -0.10377676
  0.          0.        ]
[ 0.27168388  0.65612031  0.50847979 -0.62630666 -0.20502825 -0.0974246   0.
  0.        ]
[ 0.27672043  0.6463358   0.50847921 -0.65297534 -0.20989949 -0.09742448
  0.          0.        ]
[ 0.28175697  0.63615151  0.50847859 -0.67964395 -0.2147707  -0.09742432
  0.          0.        ]
[ 0.28703613  0.62592993  0.53443918 -0.68240617 -0.2213653  -0.1318921   0.
  0.        ]
[ 0.29244165  0.61562577  0.54681921 -0.68789743 -0.22771193 -0.12693249
  0.          0.        ]
[ 0.29777279  0.60493384  0.53749208 -0.71347586 -0.23214613 -0.08868414
  0.          0.        ]
[ 0.30306282  0.59384862  0.53232465 -0.73954379 -0.23552126 -0.06750249
  0.          0.        ]


  0.          0.        ]
[ 0.31582022  0.38203035  0.74643126 -1.26732801 -0.32800886 -0.10689957
  0.          0.        ]
[ 0.32323332  0.362638    0.74643011 -1.29399681 -0.33335382 -0.1068993   0.
  0.        ]
[ 0.33064651  0.34284592  0.74642901 -1.32066574 -0.33869877 -0.10689912
  0.          0.        ]
[ 0.33805981  0.32265402  0.74642782 -1.34733467 -0.34404373 -0.10689889
  0.          0.        ]
[ 0.3454731   0.30206237  0.74642663 -1.3740036  -0.34938866 -0.10689868
  0.          0.        ]
[ 0.35317307  0.28132811  0.77695842 -1.38398895 -0.35670602 -0.1463473   0.
  0.        ]
[ 0.3608222   0.26021376  0.7704957  -1.40902201 -0.36259493 -0.1177788   0.
  0.        ]
[ 0.36881027  0.23955851  0.80588074 -1.37882474 -0.37007684 -0.1496381   0.
  0.        ]
[ 0.37694921  0.21905237  0.82242031 -1.36930491 -0.37910631 -0.1805895   0.
  0.        ]
[ 0.38513288  0.19812513  0.82807999 -1.39776281 -0.38943911 -0.20665591
  0.          0.        ]
[ 0.39358177  0.1772307 

### Experience Replay and Train
We use something called the experience replay which nothing but a glorified word of memory. We store the events that have happened in form of tuples of [s, a, r, s']. From this we draw batches of experiences for input and target for training the neural network.

In [6]:
# experience replay class
class ExperieceReplay(object):
    def __init__(self, max_memory = 100, discount = 0.9):
        # max_memory: the number of memories in past we want to remember
        # discount: discount factor
        self.mem_len = max_memory
        self.discount = discount
        self.memory = list()
        
    def remember(self, states, game_over):
        # states: a tuple of information in following form <s,a,r,s'>
        # game_over: whether the game is complete or not
        self.memory.append([states, game_over])
        
        # to check for lenght
        if self.mem_len < len(self.memory):
            del self.memory[0]
            
    def get_batch(self, model, batch_size = 30):
        # model: the neural network model
        # batch_size: batch size to return
        
        # number of actions that can be taken
        num_actions = model.output_shape[-1]
        
        # dimensions of the game field
        env_dim = self.memory[0][0][0].shape[1]
        
        # we want to return an input and target vector with inputs from an obsereved state
        inputs = np.zeros((min(len(self.memory), batch_size), env_dim))
        
        # we are using q-learning so we need to get the output from the model
        # we make a target vector
        targets = np.zeros((min(len(self.memory), batch_size), num_actions))
        
        # we draw states to learn from randomly
        for i,idx in enumerate(np.random.randint(0, len(self.memory), size = inputs.shape[0])):
            
            state_t, action_t, reward_t, state_tp1 = self.memory[idx][0] # <s,a,r,s'>
            game_over = self.memory[idx][1] # whether the game ended in this state or not
            
            # adding state to inputs
            inputs[i:i+1] = state_t
            # adding targets, target values will not be affected by training
            targets[i] = model.predict(state_t)[0]
            
            '''
            Q-learning formula:
            if the game ended: target_value = reward
            if the game not ended: target_value = reward + (gamma*max_over_a'(q(s',a')))
            '''
            if game_over:
                targets[i, np.argmax(action_t)] = reward_t
            else:
                # q_sa is max_a'(q(s',a'))
                # state_tp1 = np.reshape(state_tp1, (1, state_tp1.shape[0]))
                targets[i, np.argmax(action_t)] = reward + self.discount*np.max(model.predict(state_tp1)[0])
                
        # returning the input and target matrices
        return inputs, targets

# function to normalise the actions, sometimes the predicted actions may produce out of bound numbers
# this may happen even after using tanh activation function in the last layer
def normalise(act, env):
    high = env.action_space.high
    low = env.action_space.low
    for i in range(len(act)):
        if act[i]>high[i]:
            act[i] = high[i]
        elif act[i]<low[i]:
            act[i] = low[i]
    return act
    
# function to train the model
def train(model, epochs, env, verbose = 1, disp_step = 50):
    win_ctr = 0 # winning counter
    win_hist = [] # history of winning
    # xxx = 0
    for e in range(epochs):
        loss = 0 # loss
        env.reset() # reseting the game
        game_over = False # initially the game is not over
        input_t = np.reshape(env.reset(), (1, env.reset().shape[0])) # initial observation
        while not game_over:
            input_tm1 = input_t
            # exploration vs. exploitation
            if np.random.rand() <= epsilon:
                # random for exploration
                # xxx = 1
                action = np.random.randint(env.action_space.low[0], env.action_space.high[0], size = (2,))
            else:
                # exploitation
                # xxx = -1
                action = model.predict(input_tm1)[0]
            
            # normalising actions
            action = normalise(action, env)
            
            # apply action and get new states
            input_t, reward, game_over, _ = env.step(action)
            input_t = np.reshape(input_t, (1, input_t.shape[0]))
            if reward >= 100:
                win_ctr += 1
            
            # to see the results 
            # env.render()
            
            # store experience
            exp_replay.remember([input_tm1, action, reward, input_t], game_over)
            
            # load new batch
            inputs, targets = exp_replay.get_batch(model, batch_size)
            
            # train on the new batch
            batch_loss = model.train_on_batch(inputs, targets)
            loss += batch_loss
        
        # print results
        if verbose > 0 and (e+1)%disp_step == 0:
            print("Epoch {:03d}/{:03d} | Loss {:.4f} | Win count {}".format(e+1, epochs, loss, win_ctr))
        
        win_hist.append(win_ctr)
        
    return win_hist

In [7]:
# This was junk
'''
action = np.random.randint(env.action_space.low[0], env.action_space.high[0], size = (2,))
print(action)
input_t = np.reshape(env.reset(), (1, env.reset().shape[0]))
input_tm1 = input_t
action = model.predict(input_tm1)[0]
print(action)
'''

'\naction = np.random.randint(env.action_space.low[0], env.action_space.high[0], size = (2,))\nprint(action)\ninput_t = np.reshape(env.reset(), (1, env.reset().shape[0]))\ninput_tm1 = input_t\naction = model.predict(input_tm1)[0]\nprint(action)\n'

In [8]:
x = env.action_space.sample()
env.reset()
observation, reward, done, info = env.step(x)
print(type(x))
print(type(x[0]))
print(x)
print(type(observation))
print(type(observation[0]))
print(type(reward))
print(type(done))

<class 'numpy.ndarray'>
<class 'numpy.float64'>
[-0.67130632 -0.49347708]
<class 'numpy.ndarray'>
<class 'numpy.float64'>
<class 'numpy.float64'>
<class 'bool'>


### Deep Learning Model
We now define our model which is a simple feed forward neural network.

In [9]:
def baseline_model(env):
    model = Sequential()
    model.add(Dense(100, input_shape = (8,)))
    model.add(Dropout(0.2))
    model.add(Dense(256))
    model.add(Dropout(0.4))
    model.add(Dense(env.action_space.shape[0], activation = 'tanh'))
    model.compile(loss = 'mse', optimizer = 'Adam')
    return model

In [10]:
# hyper-parameters
epsilon = 0.1 # exploration factor
max_memory = 500 # maximum history we want to keep
batch_size = 20 # batch size used for training

In [11]:
# define model
model = baseline_model(env)
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 100)               900       
_________________________________________________________________
dropout_1 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 256)               25856     
_________________________________________________________________
dropout_2 (Dropout)          (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 2)                 514       
Total params: 27,270
Trainable params: 27,270
Non-trainable params: 0
_________________________________________________________________


In [12]:
# initialize replay memory
exp_replay = ExperieceReplay(max_memory = max_memory)

#### Junk
Next few boxes deal with some testing, read the code and understant

In [67]:
y = ExperieceReplay()
y.remember([np.reshape(env.reset(), (1, env.reset().shape[0])),
            np.random.random([2,1]),
            reward,
            np.reshape(env.reset(),(1, env.reset().shape[0]))],
          done)
inputs, targets = y.get_batch(model)
print(inputs)
print(targets)

[[ 0.00439825  0.93361178  0.44548283 -0.46678724 -0.0050897  -0.10090834
   0.          0.        ]]
[[ 0.07666291 -0.2351837 ]]


In [68]:
input_t = np.reshape(env.reset(), (1, env.reset().shape[0]))
input_tm1 = input_t
print(input_tm1)
action = model.predict(input_tm1)[0]
print(action)
input_t, reward, game_over, _ = env.step(action)
print(input_t)
print(input_tm1.shape)
print(input_t.shape)

[[ -7.55691528e-04   9.38725859e-01  -7.65572667e-02  -1.25842905e-01
    8.82449152e-04   1.73413813e-02   0.00000000e+00   0.00000000e+00]]
[ 0.07220774 -0.06286184]
[ 0.00798979  0.95599776  0.40077605  0.51796977 -0.00960459 -0.09879674
  0.          0.        ]
(1, 8)
(8,)


## Training the model
The results are pure shit :(

In [13]:
epochs = 500 # number of games played
print('Training now ...')
hist = train(model, epochs, env, 1, 10) # train
print('... Training done')

Training now ...
Epoch 010/500 | Loss 3728.8628 | Win count 1
Epoch 020/500 | Loss 3721.0410 | Win count 1
Epoch 030/500 | Loss 6411.6606 | Win count 1
Epoch 040/500 | Loss 4694.6015 | Win count 2
Epoch 050/500 | Loss 5928.7886 | Win count 2
Epoch 060/500 | Loss 2979.6749 | Win count 3
Epoch 070/500 | Loss 2980.3857 | Win count 3
Epoch 080/500 | Loss 6902.3497 | Win count 4
Epoch 090/500 | Loss 3467.5217 | Win count 4
Epoch 100/500 | Loss 7893.6792 | Win count 4
Epoch 110/500 | Loss 4453.8760 | Win count 4
Epoch 120/500 | Loss 5936.6062 | Win count 5
Epoch 130/500 | Loss 5201.6378 | Win count 5
Epoch 140/500 | Loss 3465.3894 | Win count 5
Epoch 150/500 | Loss 3960.3436 | Win count 5
Epoch 160/500 | Loss 8635.0444 | Win count 5
Epoch 170/500 | Loss 2243.9956 | Win count 7
Epoch 180/500 | Loss 6428.7172 | Win count 7
Epoch 190/500 | Loss 4447.4796 | Win count 7
Epoch 200/500 | Loss 4943.1445 | Win count 7
Epoch 210/500 | Loss 4693.1799 | Win count 8
Epoch 220/500 | Loss 7658.6400 | Win c