In [5]:
import gym
import torch 
import collections
import os
import numpy as np
import random
from utils import *
from exp_replay_memory import ReplayMemory




In [6]:
width=0.2

In [7]:
class PassAction(object):
    def __init__(self):
        self.actions={0:0,1:0,2:0,3:0}
        self.reward={0:-2.,1:-1.,2:1.,3:2.}
    def store(self,action):
        if action in self.actions:
            self.actions[action]+=1
        else:
            self.actions[action]=1
        
    def additional_reward(self,action):
        l = sorted(self.actions.items(), key=lambda x: x[1], reverse=True)
        for i in range(len(l)):
            
            if l[i][0]==action:
                
                return self.reward[i]
        
        return 

# SARSA

In [8]:
def epsilon_greedy(q_func, observation, eps, env_actions):
    prob = np.random.random()

    if prob < eps:
        return random.choice(range(env_actions))
    elif isinstance(q_func, CNN) or isinstance(q_func, LinearMapNet):
        with torch.no_grad():
            return q_func(observation).max(1)[1].item()
    else:
        qvals = [q_func[observation + (action, )] for action in range(env_actions)]
        return np.argmax(qvals)
    
def greedy(qstates_dict, observation, env_actions):
    qvals = [qstates_dict[observation + (action, )] for action in range(env_actions)]
    return max(qvals)

In [25]:
def sarsa_lander(env, n_episodes, gamma, lr, min_eps,width, print_freq=500, render_freq=500):
    q_observations = collections.defaultdict(float)   # note that the first insertion of a key initializes its value to 0.0
    return_per_ep = [0.0]
    epsilon = 1.0
    num_actions = env.action_space.n
    
    for i in range(n_episodes):
        t = 0

        # Initial episode state: S
        state_continuous=env.reset()
        curr_state = discretize_state(state_continuous)

        if state_continuous[1]>=1-width/2  and state_continuous[1]<=1+width/2 :
            curr_observation = (1,)
        else:
            curr_observation = discretize_state(state_continuous)
        # Choose A from S using policy π
        action = epsilon_greedy(q_observations, curr_observation, epsilon, num_actions)
        frame=0
        while True:
            
            # Create (S, A) pair
            qstate = curr_state + (action, )
            qobservation = curr_observation + (action, )

            # Take action A, earn immediate reward R and land into next state S'
            # S --> A --> R --> S'
            state_continuous, reward, done, _ = env.step(action)
            
            next_state = discretize_state(state_continuous)
            if state_continuous[1]>=1-width/2  and state_continuous[1]<=1+width/2 :
                next_observation = (1,)
            else:
                next_observation = discretize_state(state_continuous)
            
            
            # Next State: S'
            # Choose A' from S' using policy π
            next_action = epsilon_greedy(q_observations, next_observation, epsilon, num_actions)

            # create (S', A') pair
            new_qstate = next_state + (next_action, )
            new_qobservation = next_observation + (next_action, )

            ###################################################################
            # Policy evaluation step
            if not done:
                q_observations[qobservation] += lr * (reward + gamma * q_observations[new_qobservation] - q_observations[qobservation]) # (S', A') non terminal state
            else:
                q_observations[qobservation] += lr * (reward - q_observations[qobservation])    # (S', A') terminal state
            ###################################################################

            return_per_ep[-1] += reward

            if done:
                if (i + 1) % print_freq == 0:
                    print("\nEpisode finished after {} timesteps".format(t + 1))
                    print("Episode {}: Total Return = {}".format(i + 1, return_per_ep[-1]))
                    print("Total keys in q_observations dictionary = {}".format(len(q_observations)))

                if (i + 1) % 100 == 0:
                    mean_100ep_reward = round(np.mean(return_per_ep[-101:-1]), 1)
                    print("Last 100 episodes mean reward: {}".format(mean_100ep_reward))

                epsilon = decay_epsilon(epsilon, min_eps)
                return_per_ep.append(0.0)

                break

            curr_state = next_state
            curr_observation = next_observation
            action = next_action
            t += 1

    return return_per_ep,q_observations

In [113]:
n_episodes=10000
lr=0.1
gamma=0.99
final_eps=0.01
environment = gym.make("LunarLander-v2")

In [27]:
print("\nTraining Sarsa lander with arguments num_episodes={}, step-size={}, gamma={}, final_epsilon={} ..."\
                            .format(n_episodes, lr, gamma, final_eps))
sarsa_total_rewards,q_observations  = sarsa_lander(environment, n_episodes, gamma, lr, final_eps,width)
print("Done!")




Training Sarsa lander with arguments num_episodes=100, step-size=0.1, gamma=0.99, final_epsilon=0.01 ...
1.133715702327238
2.2912162280311477
1.211470288683388
0.5117465848840379
0.08350535302824483
1.5612628243585516
1.4286507442327274
-2.0388880937181737
1.5984739617936168
-1.29012125756111
-1.8067347820893918
0.9246776558653778
-0.056988643435288394
-1.5686124946942357
-2.9744081624803855
-2.2444012463649017
-1.2128418974776765
-2.19925152375626
-2.2917402741175295
-3.237688300859985
1.8893889985068995
2.0162691083599897
-3.4816605582737155
-1.3716700663988661
-0.058256752930265054
-1.8553161930636566
0.5245138339137838
-1.5297687922317482
-1.5133442280979625
0.5680200855459077
0.7243935519063995
-1.0861669228015376
-1.9672241629700409
-1.3867642517404295
-1.97029091516157
-2.7105209225272504
-2.137752324146163
1.0870746162320983
-1.6400331595129944
-1.2254431567442328
2.5000161531201117
-2.5083260672149366
-2.6307683590637496
-1.2568063027582912
-1.8553603793337174
-2.492854535928

-1.3498459885632883
-0.9665107957105181
-0.7197664358662894
3.552686570658989
-0.9223477871221917
-0.5029346224995084
-0.6072061516179588
-0.2007643814517894
4.481605494080026
-0.33719998048189836
2.4461971075202884
-0.3237875391493503
-0.7013725711900338
-0.7964599548221531
-0.1804007318716276
2.2357521029033594
0.0002691373857726387
2.216387620112653
-0.556899964248571
-0.5367336316877516
-0.7741791640374049
-1.008591816225844
-0.6473670088529389
1.4986285998748599
2.314879883727497
4.152035152872668
-0.7058021230403426
3.4941903128901233
-1.2440107750646223
2.4365849386396805
-1.2983388260966808
3.897551229486152
1.846326147347196
0.8024426852580404
-0.5334119342902273
-1.6853211908924368
-1.935186155387412
-1.6154945117469879
7.9405331887712975
6.416700185823545
-100
-2.8898832452357963
2.081154298006426
-2.786218046880964
1.9481905733223641
1.4227883189185786
1.1669424503841708
1.5682229116429767
2.083472237946411
2.442505461366268
1.3407169185104248
0.911560953485407
-4.566916336

-0.1027523548845852
-0.8577327895895224
0.5270074243352212
1.3999170435376527
-2.298955809160303
-2.3962565930048343
-1.6287794729873895
-2.5455765554873993
-1.8923118587714782
-0.6822944299230642
-2.052305431426362
-0.1828415951806221
-2.3221461767921596
-1.1779641804563596
-2.556473535444468
-3.8935782395940337
-4.108572608697698
7.43583424874812
-2.8060207141352294
25.891730683355092
-1.4136049309312568
-100
-1.0335622094895485
1.8824878423033067
1.9255421550818426
2.218069946864547
2.318901105872841
2.1881699945462585
2.075060518301741
2.137194271209438
1.9476252671418035
2.347219422324288
1.8937190600447604
2.559228320218723
0.3834354813844538
-3.1823319421409324
-3.617855948121472
-2.6221645718008504
1.1125155562969382
-2.8315885148697078
0.756966557929087
-3.0027225552671553
-3.4144769369169965
0.3075740153461595
-3.073846269563944
-2.62527927243235
-2.6875421810361773
-2.2427868146755245
-2.3488073813401797
-0.27880172825009025
3.033622763016706
-1.953887664328987
-1.7206247571

-1.4314644745031047
-2.8513073685352763
-2.7870845602940606
-4.695596880986204
-5.17968031661245
-2.7401380007708256
-3.256468541476947
-1.035865746948814
-2.0267340696182146
-2.0368273273857653
-3.0499303906689463
-3.2044425351863013
-3.781515979482424
-2.5397707353257033
-1.686316312738852
-5.823638658813468
-1.2398330751221056
-2.1497467078477825
-3.266734911978715
-7.1837368570349325
-1.131369503635226
-7.143787819260967
-1.2140002568217187
-0.9242869800875997
-1.7537150198721179
-0.653518268014152
-1.5881155948363812
-0.4158941829365108
-2.3279180413295237
-2.453610382396664
-1.7841915040963272
-4.542229868174604
-2.0008080984173944
-1.0687761959356112
-0.6233318563575085
-6.079573134862232
-5.8337848229947324
-1.10529020088964
-2.93797066388626
-6.763371199940923
-2.113569760328687
-3.331125919595563
-100
-1.022520716805559
-0.2689039029017113
-1.0035986121393137
-0.028780745307527694
-0.8113420043487736
-1.9175723724417526
-0.991444391848944
-1.9634875763258595
-2.25715779420735

-1.373262771256974
1.5998580369172088
1.4254455189975601
2.7188411689616716
-0.781468373113853
-0.7559893184627242
1.0321984831750342
-0.0993095547931648
-1.3005910604850601
5.2257513310283175
-0.05624724740064721
3.813677785506184
-0.6873470592023807
0.04511956498910877
0.30594115419566603
-0.9958255729593464
2.214258703201449
-1.1641760721202548
2.506499795225648
4.043577885356467
-1.3846804792700869
-0.07880656796072458
3.827702997294563
1.746312572717909
-0.7008944744942767
-0.7173109016672754
-0.7397354690006921
-0.061683800516617565
0.15291691399676324
-1.8741972479850506
-1.6821157822073178
0.820791940056506
-2.095824898959164
-2.4034077748300704
-3.13267753908829
-2.8390758999694867
-3.1903293484729702
-2.828576469959755
6.518067091757275
2.4743694738488387
-100
-1.1204446381478022
1.9296291368577954
-0.0989323418441927
0.1455482312586878
-2.232418008792449
2.5123145580235304
-1.3443945763291651
-1.4308525130800547
-1.5016039574893512
-0.5640997022185548
-0.439565099802137
1.74

-3.3968322596010965
-1.8342700893500978
-2.5314538721260647
-3.7744753277388257
-2.0795320514076137
-3.2251606155265904
6.705775630060555
15.155947792527183
-100
-0.3329010844747586
-0.1756095312528305
-0.027366234810328932
1.9601549295187055
-0.9503025265451015
0.9150746653348392
-1.3904619665233213
-2.118745537936063
2.0675770349433096
0.03663298390137129
0.28152242924733284
2.4101432185819194
-0.7635913030235315
-0.7787681577277681
0.10648817276146019
0.20638363908457905
0.5869674475199258
0.8518580381738456
-0.05395314151795105
2.705230127648787
0.9698787974197114
4.6308630999875335
0.9094816047878294
0.19895167038163208
-1.8708437631980235
-1.9584359771628772
-0.22703464566173465
-1.2941967728667407
-0.939657531632945
-1.4586660552068906
-1.4341038195327656
-1.7566031989850888
-1.5297595210409713
-1.3053750696731174
-0.7177919363998626
-0.4884329927689339
3.4026474006811727
2.1949809060084364
-1.1150288596003566
-0.6191108909783907
2.045278366589412
-1.004851625717463
-0.752924934

0.8091869383247012
-0.9120743138228977
1.920936616812878
-1.5486295090502142
-0.8215615581677014
-0.9104781745034245
-1.005211654620524
3.2116079156159687
1.3969711403020255
-2.079129704056099
-2.353567475948721
-0.9830507399049953
1.5297044080326259
-1.8901773975728133
-1.0594936735732847
-1.208468984514808
-2.657306017350861
-2.152291419934585
9.692025987479024
-100
-1.208203557519396
-1.7630377574354839
0.8989268472216849
-1.8842011765458142
-1.9899896408816733
-1.852689677486154
-1.4078810717421948
-1.559889204129945
3.0797244893990525
4.282126966744738
-2.0240728058514876
-2.2384219986139144
-2.4037642122698046
-2.5103010933243284
-1.531689923742788
2.9369145858799586
-1.8114163504164935
-1.7631758972314344
-1.7150298263931063
1.7861674532046947
1.3402037123692423
3.102499031089292
-1.9107221515682227
-2.476313569262827
1.1257477083192213
-2.76689911940869
-2.272584942772937
3.1170055942706485
2.7992990876184765
2.3650007253126946
-0.5983557678257683
0.8255881585543363
-3.21892322

-1.6553799329684296
-1.6187408722607586
-0.8155607652196
-0.6688641120718148
0.1514167039980066
-1.0270406003564005
-1.0191088407360382
2.4273862102295256
-1.1180076445489533
-0.5319272348559696
-0.4435417459128803
-1.8304078284279786
0.933854986907005
-1.0104448064728888
-0.9927181425574918
0.5818883508308772
-2.3034841502131655
7.195660565718384
1.7183895347158977
-100
0.3141720380559161
0.6406469622463373
1.0235453227385005
0.906784204864665
1.0174494977760957
1.0530948820566846
1.075322609528173
1.4066002704349512
1.5550727663294037
1.2427842594385379
-0.581051878108741
-1.1188197568475573
-1.7874695361773252
-0.5051736733118719
-0.5202689690480315
-0.9153107460311662
-0.1259056848940247
0.3390976186291084
1.2922142345683596
-1.0741915080611466
0.5178959660053237
-0.5187355322868712
-0.9696314673992379
-2.1164716344231067
-0.29448435498274533
-1.6025448384831396
-3.1688094357582473
-2.198121633288139
-2.2327137777145083
-2.8405127725683585
3.5060247588915674
-2.0455084682050426
1.5

-2.1262074242536912
-2.499034204353336
0.50926017184035
-1.4935283505552366
-1.4595597842414054
-1.9489136845840278
-1.9226832759198373
3.004515486293781
-1.9530660260597585
1.9365453511252155
-1.8520527279200962
-1.4015280842843356
3.552574376261089
-0.8849764940530054
-2.056328379475586
2.1746180071739163
-2.109311518037343
-0.8466389316798757
-1.4304790603154345
0.4244710586683141
1.898719369958269
1.4490327098622628
-2.0947675652871838
-0.6174104926580515
-0.4145977189630219
1.677381789707607
-0.836786112131847
-0.8096677392032063
3.5913501169639916
-0.18264677948047733
3.574130936243404
-0.5831195812130261
-0.5583230685723208
-0.5287192625951036
-0.4951413738057795
2.9032844517632954
2.237608971549048
-0.6727775567517256
3.6797218126348357
-0.7494959806818144
-0.7052696955562112
4.919520142442605
2.680651288747538
-0.342668456195496
1.8707713360476077
-2.1230294634205777
-2.522666576933203
1.565857879869111
0.5652532967007005
-2.2079112539649928
-1.9944023278301006
-1.789287152324

-0.7481560692910694
-1.2228400486555404
-0.7647321384644101
-0.8497178206529827
-0.330600115601128
-0.652068488272846
-0.6073615408852504
-0.6141702878898332
-0.2648312304924616
-0.24361377828517902
0.015107874515053937
-0.22342678010031136
3.2961528713106363
-0.2957234869706724
-0.25795077379336817
-0.2858352361487835
2.581422114080323
-0.032917921513557075
-0.221537298463943
-0.25593928211404093
4.294139681304171
-0.5366275732857264
-0.29021669473166756
5.825006780655127
-0.3292199835226757
1.523110053435812
-0.5980173364197583
2.714547251227043
-0.7753469793391901
-0.9738251670287912
-1.1930112328819007
5.016904467549762
1.9412218422966248
-2.065022619198969
2.178018710969792
1.0908577923913867
-2.994012640872596
11.248592624021892
2.8518618892116208
-100
2.033083110950089
2.742931365094195
-1.8758560455646205
2.5396656441671737
2.0521431116368376
-0.9677047870613649
-2.7885526146137054
-0.45530328317450425
-2.651637346151999
-2.1179653562604117
-2.5140865025027765
-2.81053254869033

1.5219555959783235
-2.3247138145004542
-1.6138326007146475
-2.6201322181802014
-0.8115451360270345
0.7102926734170694
-0.6046384045122568
-1.4173793209997996
1.1796816772198724
-2.43112935288019
-0.7115446847825797
-1.9109896454452724
-0.2792929303357983
-2.331973482850485
-2.497804221651164
-0.7352376061002122
0.8280673888099102
-1.3182567842531796
-1.30495255480389
-2.3555482248581527
-0.7797200683635868
-0.46457798587479604
-1.1123188482848718
1.5840365244982706
-0.39647966268858
-1.0563228303772974
-0.1514358865987606
-0.8377705855780562
0.2617555483276146
-0.5721984896668175
-1.4232687985372638
1.2601374342620602
1.0085360526496572
0.361312169395718
0.483254008771554
3.04449119778742
1.6554853004966332
0.4197181711706833
0.2172556291182002
-0.3151377686465082
0.7176686291568057
-0.08736015457353119
0.6820043839909726
2.9061830151107246
-0.8507125936050659
-0.1830559933922018
-0.9749068775676324
-0.3753131664147986
2.6962717909839116
0.12075691580898365
-1.5334755510564821
-0.70579

-0.962977710305239
3.3018372137635312
-0.080607785012279
1.7157974888689693
-0.7712821452571916
0.4230929261681922
-2.1713786196492877
-0.2917285973565231
2.3870247892010186
-0.7596591494981795
-0.33108860284798314
0.06745287762410726
-1.8243350200492319
0.5169403156286421
0.3866795219263793
2.017282994653482
-0.04023503694881356
-0.29765788798410425
-0.07861109923658091
-1.7845941008235104
-0.17188479764527528
1.690251070192528
2.0497192206794184
-2.35196769138406
1.2893450756929041
-2.4368192524902086
-2.483371815272362
-1.9142494815895361
-1.9849138600185938
-2.059762794686577
-3.1783712412834277
-2.431328821120502
-1.5498892040233432
-3.1985414599814406
-3.5766797162781914
-2.958248183375133
10.13204711686306
11.010108818121667
-100
-1.0782119433850312
-0.9446553669381217
0.1228595674797191
0.04286999475718403
-2.051973205065933
-1.1183454573580889
-0.2842145924096815
-0.18299149990761407
1.5890341646796855
2.2662976195578155
-1.0199933134715025
0.21918749130571882
-0.8712739218865

-1.0881602647189084
-0.011368234084318363
-1.9578642014639047
1.2047592786884536
1.6718195945870888
-2.561866389499271
2.576281335308022
-2.763052587746215
-1.0543117514403082
-0.21936262194195705
1.449090616198913
-2.889104141224293
-0.9649058799962222
0.25218033952326097
-1.697033447224726
-1.4262219685729634
-0.6141038585821537
-1.438109239501017
-1.4266642824198357
-0.9342593026350869
0.15751673945062522
-0.39801262029088774
-1.1727190458335486
-0.6996731530651801
0.13971645875552668
-1.2190364410242296
-2.1440679426969766
-0.42070223838406945
-1.3668742226603285
-0.581423008887241
-1.2167895837922629
-1.5244795652058258
-1.1584055372901787
-2.3760603155266424
-1.3669195386929403
-1.3620933811625662
-2.546283631654147
-1.3000639920888546
-2.52009230461985
-1.724861308870004
-0.45164013359084154
-0.2453054669513392
-2.87333449350898
-2.7729458286411157
-0.5995986128952768
-1.8066426880542281
-0.1444903528041823
-1.0719588699876113
0.2249509117621369
-0.8541800941283668
-2.0258727668

-0.1792421015956893
-0.6559514929927275
3.888928536083381
4.153716633217766
-0.891688173501052
1.5741163444630046
3.921965574112517
-0.47390647137402087
0.19500099446291302
-0.2454557863667037
-0.9806920389681295
-0.43671386016651925
-0.9868348136576515
1.818201526448928
-3.4886692598374225
-3.603553972360146
-3.7251229230672607
-1.1713080815149113
-3.9387354695441843
-4.38663115844713
-4.233876436231639
-3.705986800242384
-4.4954976986625175
10.138753450426437
-100
3.05936005665036
-1.8240974485100878
2.0342380282262527
-1.4069095555407773
-1.9679589065887046
1.4954363458363502
-1.720003542103285
-1.4209045349602423
-1.100546704044972
4.763791599291506
-1.5794453464288825
-1.3253732827106148
-1.6194132792036828
1.1464935544645527
-1.4695526488751796
4.403282888457949
-1.104410200025767
2.4709251870035702
4.59523491136643
1.1047134344761786
-1.2175135425818258
3.1691685633072213
4.28849197397376
-1.5348232840514686
1.2670347516517892
2.501153382521278
3.329577413157023
-1.9691209055241

1.02770892597469
-1.6409824884986108
0.9032167034840495
0.8337650901137124
1.0045322021643426
0.1007824192186888
0.067915998981249
1.0086110381680282
0.029020641201072978
-0.9718194845341657
-0.19621522257871787
0.02423121830817082
-0.9935238410472209
-1.0195174182059645
-1.5558317932488694
-0.9469174874078305
-1.687314651375998
-0.02638861781763352
1.4030598403145234
1.8579711611643972
-0.6579394963261109
-0.501778009671159
2.999353906209916
-1.6367346422710216
0.31828153956588495
-1.4143078504258415
-0.33900431681692567
0.11270169007434674
-0.5380826782777604
-0.9572496690588423
1.6906274773737209
1.848385404828872
1.7198777720476925
0.17292812407541305
-0.7820307645162359
-1.5012406154399958
1.200600185460371
-0.5640419284898144
-1.3251132114227449
3.966075743153408
1.8735955480502013
0.7263320440468817
1.1787148874274294
1.8453275735600527
-1.027319631397545
-0.32952154776054954
-0.3490258161472468
1.3664253309664047
-0.41706436360450994
-0.9465175571098985
0.41498226426986773
-0.5

-2.7485854746047664
12.142626695291323
1.351461267755327
-100
1.8670592471121494
-0.04550001829999245
-1.1575576008509643
-1.1644493336402206
-0.052637641014341624
-1.00540902667409
0.2749917157832169
1.695154926272562
-1.7507109138985857
-2.1878232240759687
0.8965259342350123
-0.08693444017778801
-1.002301160313749
-2.0027121221835658
-1.4517187076320284
-1.0858851360103188
-1.0812957812893274
-0.3198194194902999
0.2036432527506247
0.34824386784674854
-0.4835446660511309
1.4269685410583122
0.5602449756062924
-1.549116503091965
0.49597285710484695
2.74569355584124
-0.43440951010973095
1.0643495051511536
1.675709228633093
-1.5109609762821197
-0.7048273225561275
0.48587595501976577
-0.448810400528572
2.5542000444093107
1.9536132998380082
0.6362179415378659
0.679630604624408
1.7075402251405023
-0.9312554002709408
-0.9630471391555193
2.161759728336915
2.697354240557684
-1.355996428332246
0.05338934070644655
-0.5098108843532145
2.5373828553603177
0.6006416326180044
-0.6157889415164561
-0.62

In [115]:
# #save result
a=np.array(sarsa_total_rewards)
np.save(f'sarsa_total_rewards_width={width}.npy',a) 
with open(f"sarsa_qtable_width={width}.pkl", "wb") as pkl_handle:
    pickle.dump(q_observations, pkl_handle)

# sarsa_total_rewards = np.load(f'sarsa_total_rewards_width={width}.npy').tolist()

# DQN

In [108]:
def dqn_lander(env, n_episodes, gamma, lr, min_eps, width, \
                batch_size=32, memory_capacity=50000, \
                network='linear', learning_starts=1000, \
                train_freq=1, target_network_update_freq=1000, \
                print_freq=500, render_freq=500, save_freq=1000):

    # set device to run on
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    loss_function = torch.nn.MSELoss()

    # path to save checkpoints
    PATH = "./models"
    if not os.path.isdir(PATH):
        os.mkdir(PATH)

    num_actions = env.action_space.n
    
    input_shape = env.observation_space.shape[-1]+1
    
    qnet, qnet_optim = build_qnetwork(num_actions, lr, input_shape, network, device)
    qtarget_net, _ = build_qnetwork(num_actions, lr, input_shape, network, device)
    qtarget_net.load_state_dict(qnet.state_dict())
    qnet.train()
    qtarget_net.eval()
    replay_memory = ReplayMemory(memory_capacity)
    

    epsilon = 1.0 
    return_per_ep = [0.0] 
    saved_mean_reward = None
    t = 0

    for i in range(n_episodes):
        action_capacity=1000
        
        
        curr_state = lmn_input(env.reset())
        
        #add one dimension for blind or not, 1 for blind, 0 for not blind 
        curr_observation=torch.cat( (curr_state.squeeze(0),torch.tensor([0])) ).unsqueeze(0)
        curr_observation = torch.tensor([[0.,0.,0.,0.,0.,0.,0.,0.,1.]]) if curr_state[0][1]>=1-width/2 and curr_state[0][1]<=1+width/2 else curr_observation    

        

        while True:

            # choose action A using behaviour policy -> ε-greedy; use q-network
            action = epsilon_greedy(qnet, curr_observation.to(device), epsilon, num_actions)
            
            # take action A, earn immediate reward R and land into next state S'
            next_state, reward, done, _ = env.step(action)
            
            
            
            next_state = lmn_input(next_state)
            
            #add one dimension for blind or not, 1 for blind, 0 for not blind 
            next_observation =torch.cat( (next_state.squeeze(0),torch.tensor([0])) ).unsqueeze(0)
            next_observation =torch.tensor([[0.,0.,0.,0.,0.,0.,0.,0.,1.]]) if next_state[0][1]>=1-width/2 and  next_state[0][1]<=1+width/2 else next_observation

            # store transition (S, A, R, S', Done) in replay memory
            replay_memory.store(curr_observation, action, float(reward), next_observation, float(done))
            
            
            # if replay memory currently stores > 'learning_starts' transitions,
            # sample a random mini-batch and update q_network's parameters
            if t > learning_starts and t % train_freq == 0:
                curr_observations, actions, rewards, next_observations, dones = replay_memory.sample_minibatch(batch_size)
                #loss = 
                fit(qnet, \
                    qnet_optim, \
                    qtarget_net, \
                    loss_function, \
                    curr_observations, \
                    actions, \
                    rewards, \
                    next_observations, \
                    dones, \
                    gamma, \
                    num_actions, 
                    device)

            # periodically update q-target network's parameters
            if t > learning_starts and t % target_network_update_freq == 0:
                update_target_network(qnet, qtarget_net)

            t += 1
            return_per_ep[-1] += reward

            if done:
                if (i + 1) % print_freq == 0:
                    print("\nEpisode: {}".format(i + 1))
                    print("Episode return : {}".format(return_per_ep[-1]))
                    print("Total time-steps: {}".format(t))

                if (i + 1) % 100 == 0:
                    mean_100ep_reward = round(np.mean(return_per_ep[-101:-1]), 1)
                    print("\nLast 100 episodes mean reward: {}".format(mean_100ep_reward))

                if t > learning_starts and (i + 1) % save_freq == 0:
                    if saved_mean_reward is None or mean_100ep_reward > saved_mean_reward:
                        print("\nSaving model due to mean reward increase: {} -> {}".format(saved_mean_reward, mean_100ep_reward))
                        save_model(qnet, f"width={width}_episode={i+1}_rw={mean_100ep_reward}", PATH)
                        saved_mean_reward = mean_100ep_reward

                return_per_ep.append(0.0)
                epsilon = decay_epsilon(epsilon, min_eps)

                break
            current_state = next_state
            curr_observation = next_observation

    return return_per_ep

In [109]:
n_episodes= 10000
lr = 0.0005
gamma = 0.99
final_eps = 0.01

environment = gym.make("LunarLander-v2")

In [110]:
print("\nTraining DQN lander with arguments num_episodes={}, learning rate={}, gamma={}, final_epsilon={} ..."\
                            .format(n_episodes,lr, gamma, final_eps))
dqn_total_rewards = dqn_lander(environment, n_episodes, gamma, lr, final_eps,width)
print("Done!")


Training DQN lander with arguments num_episodes=100, learning rate=0.0005, gamma=0.99, final_epsilon=0.01 ...

Last 100 episodes mean reward: -153.8
Done!


In [116]:
# #save result
# a=np.array(dqn_total_rewards)
np.save(f'dqn_total_rewards_width={width}_lr={lr}_epochs={n_episodes}_dim=9.npy',a) 

dqn_total_rewards = np.load(f'dqn_total_rewards_width={width}.npy').tolist()

# DQN with reward

In [62]:
def dqn_lander_reward(env, n_episodes, gamma, lr, min_eps, width, \
                batch_size=32, memory_capacity=50000, \
                network='linear', learning_starts=1000, \
                train_freq=1, target_network_update_freq=1000, \
                print_freq=500, render_freq=500, save_freq=1000):

    # set device to run on
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    loss_function = torch.nn.MSELoss()

    # path to save checkpoints
    PATH = "./models"
    if not os.path.isdir(PATH):
        os.mkdir(PATH)

    num_actions = env.action_space.n   
    input_shape = env.observation_space.shape[-1]+1
    
    qnet, qnet_optim = build_qnetwork(num_actions, lr, input_shape, network, device)
    qtarget_net, _ = build_qnetwork(num_actions, lr, input_shape, network, device)
    qtarget_net.load_state_dict(qnet.state_dict())
    qnet.train()
    qtarget_net.eval()
    replay_memory = ReplayMemory(memory_capacity)
    

    epsilon = 1.0 
    return_per_ep = [0.0] 
    additional_rewards =[]
    saved_mean_reward = None
    t = 0

    for i in range(n_episodes):
        
        curr_state = lmn_input(env.reset())
        
        #add one dimension for blind or not, 1 for blind, 0 for not blind 
        curr_observation=torch.cat( (curr_state.squeeze(0),torch.tensor([0])) ).unsqueeze(0)
        curr_observation = torch.tensor([[0.,0.,0.,0.,0.,0.,0.,0.,1.]]) if curr_state[0][1]>=1-width/2 and curr_state[0][1]<=1+width/2 else curr_observation    
        
        # a decay factor to encourage exploration at early stage
        decay=np.exp(-0.001*n_episodes)
        past_actions=np.array([0,0,0,0])

        one_episode_reward=[]
        
        while True:

            # choose action A using behaviour policy -> ε-greedy; use q-network
            action = epsilon_greedy(qnet, curr_observation.to(device), epsilon, num_actions)
            

            # take action A, earn immediate reward R and land into next state S'
            next_state, reward, done, _ = env.step(action)

            past_actions[action]+=1
            c=np.array([0,0,0,0])
            c[action]=1
            
            
            add_rw=np.linalg.norm(c-past_actions/np.linalg.norm(past_actions)) *decay
#             print(add_rw)
            
            one_episode_reward.append(add_rw)
            reward+= add_rw

            
            next_state = lmn_input(next_state)
            
            #add one dimension for blind or not, 1 for blind, 0 for not blind 
            next_observation =torch.cat( (next_state.squeeze(0),torch.tensor([0])) ).unsqueeze(0)
            next_observation =torch.tensor([[0.,0.,0.,0.,0.,0.,0.,0.,1.]]) if next_state[0][1]>=1-width/2 and  next_state[0][1]<=1+width/2 else next_observation

            # store transition (S, A, R, S', Done) in replay memory
            replay_memory.store(curr_observation, action, float(reward), next_observation, float(done))
            
            # if replay memory currently stores > 'learning_starts' transitions,
            # sample a random mini-batch and update q_network's parameters
            if t > learning_starts and t % train_freq == 0:
                curr_observations, actions, rewards, next_observations, dones = replay_memory.sample_minibatch(batch_size)
                #loss = 
                fit(qnet, \
                    qnet_optim, \
                    qtarget_net, \
                    loss_function, \
                    curr_observations, \
                    actions, \
                    rewards, \
                    next_observations, \
                    dones, \
                    gamma, \
                    num_actions, 
                    device)

            # periodically update q-target network's parameters
            if t > learning_starts and t % target_network_update_freq == 0:
                update_target_network(qnet, qtarget_net)

            t += 1
            return_per_ep[-1] += reward

            if done:
                if (i + 1) % print_freq == 0:
                    print("\nEpisode: {}".format(i + 1))
                    print("Episode return : {}".format(return_per_ep[-1]))
                    print("Total time-steps: {}".format(t))

                if (i + 1) % 100 == 0:
                    mean_100ep_reward = round(np.mean(return_per_ep[-101:-1]), 1)
                    print("\nLast 100 episodes mean reward: {}".format(mean_100ep_reward))

                if t > learning_starts and (i + 1) % save_freq == 0:
                    if saved_mean_reward is None or mean_100ep_reward > saved_mean_reward:
                        print("\nSaving model due to mean reward increase: {} -> {}".format(saved_mean_reward, mean_100ep_reward))
                        save_model(qnet, f"width={width}_episode={i+1}_rw={mean_100ep_reward}_reward", PATH)
                        saved_mean_reward = mean_100ep_reward

                return_per_ep.append(0.0)
                epsilon = decay_epsilon(epsilon, min_eps)

                break
            current_state = next_state
            curr_observation = next_observation
        
        additional_rewards.append(np.sum(one_episode_reward))

    return return_per_ep,additional_rewards

In [63]:
n_episodes= 10000
lr = 0.0005
gamma = 0.99
final_eps = 0.01

environment = gym.make("LunarLander-v2")

In [64]:
print("\nTraining DQN lander with arguments num_episodes={}, learning rate={}, gamma={}, final_epsilon={} ..."\
                            .format(n_episodes,lr, gamma, final_eps))
dqn_total_rewards,additional_rewards = dqn_lander_reward(environment, n_episodes, gamma, lr, final_eps,width)
print("Done!")


Training DQN lander with arguments num_episodes=10000, learning rate=0.0005, gamma=0.99, final_epsilon=0.01 ...

Last 100 episodes mean reward: -148.9

Last 100 episodes mean reward: -75.7

Last 100 episodes mean reward: -118.9

Last 100 episodes mean reward: 3.0

Episode: 500
Episode return : 168.4824922625362
Total time-steps: 194830

Last 100 episodes mean reward: 71.6

Last 100 episodes mean reward: 165.6

Last 100 episodes mean reward: 209.4

Last 100 episodes mean reward: 205.5

Last 100 episodes mean reward: 219.5

Episode: 1000
Episode return : 257.0917478959576
Total time-steps: 389397

Last 100 episodes mean reward: 207.1

Saving model due to mean reward increase: None -> 207.1

Last 100 episodes mean reward: 211.1

Last 100 episodes mean reward: 232.9

Last 100 episodes mean reward: 216.8

Last 100 episodes mean reward: 175.5

Episode: 1500
Episode return : 246.03165751894392
Total time-steps: 553589

Last 100 episodes mean reward: 206.5

Last 100 episodes mean reward: 231.

KeyboardInterrupt: 

In [None]:
a=np.array(dqn_total_rewards)
b=np.array(additional_rewards)
np.save(f'dqn_total_rewards_width={width}_lr={lr}_epochs={n_episodes}_dim=9_reward.npy',a) 
np.save(f'dqn_total_rewards_width={width}_lr={lr}_epochs={n_episodes}_dim=9_additional_reward.npy',b) 


In [None]:
# 模型不变，跑结果
# 画图