In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from sklearn import preprocessing
import RK4

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
r = RK4.Rocket()

dAngle = 0.05*np.pi
dThrottle = 0.05
dt = 0.1

#R = []
#RX = []
#RY = []
#V = []
#Fuel = []

def step(r, action):
    pt_vars = [r.rx, r.ry, r.vx, r.vy, r.fuelm]
    
    angle, throttle, staged = r.input_vars
    
    if action == 0:
        angle += dAngle
        throttle += dThrottle
    elif action == 1:
        angle -= dAngle
        throttle += dThrottle
    elif action == 2:
        angle += dAngle
        throttle -= dThrottle
    elif action == 3:
        angle -= dAngle
        throttle -= dThrottle
    elif action == 4:
        angle = angle
        throttle = throttle
    else:
        print("NO ACTION SELECTED")
        
    if throttle > 1:
        throttle = 1
    elif throttle < 0:
        throttle = 0

    
    #OUT OF FUEL
    if r.fuelm <= 0:
        print("Out Of Fuel")
        done = True
    #CRASHES
    elif np.sqrt(r.rx**2+r.ry**2) < r.sea:
        print("Crashed")
        done = True
    #TIME CUTOFF
    elif r.t > 100000:
        print("Time Limit")
        done = True
    #ACHIEVE ORBIT HERE DON"T FORGET TO INCLUDE THIS HERE 
    else:
        done = False
        
    r.input_vars = angle, throttle, staged 
    r.RK4_step(pt_vars, dt, r.input_vars)
    
    return r, done

'''
actions = {
(+throttle, +angle),
(+throttle, -angle),
(-throttle, +angle),
(-throttle, -angle),
(do nothing)
}
'''

def reward(state):
    rx, ry, vx, vy, m, has_staged = state[0]
    
    x = np.sqrt(rx**2+ry**2)
    v = np.sqrt(vx**2+vy**2)
    
    rw_x = 1/(1+np.exp(-0.00004*(x-r.karman)))
    rw_v = 1/(1+np.exp(-0.001*(v-7790/2)))
    rw_m = 0  #0.1*m/r.mi
    
    punish = 0
    if np.sqrt(r.rx**2+r.ry**2) < r.sea:
        punish = -2
    
    return rw_x + rw_v + rw_m + punish

In [3]:
inputs = keras.Input(shape=(6))
x = keras.layers.Dense(16, activation="relu")(inputs)
x = keras.layers.Dense(32, activation="relu")(x)
outputs = keras.layers.Dense(5, activation=None)(x)

model = keras.Model(inputs=inputs, outputs=outputs)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
              loss=tf.keras.losses.MeanSquaredError())

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


In [None]:
epochs = 2000
greed = 0.5
greed_decay = 0.99
discount_factor = 0.9

all_RX = []
all_RY = []
all_rewards =[]

for i in range(epochs):
    print("EPOCH: ", i)
    r = RK4.Rocket()
    state = np.array([[r.rx, r.ry, r.vx, r.vy, r.m, int(r.has_staged)]])
    greed *= greed_decay
    
    RX = []
    RY = []
    rewards = []
    
    done = False
    while not done:
        if np.random.random() < greed:
            action = np.random.randint(0, 5)
        else:
            action = np.argmax(model.predict(state))
            
        RX.append(r.rx)
        RY.append(r.ry)
            
        r, done = step(r, action)
        new_state = np.array([[r.rx, r.ry, r.vx, r.vy, r.m, int(r.has_staged)]])
                
        target = reward(state) + discount_factor * np.max(model.predict(new_state))
        rewards.append(reward(state))
        
        target_vector = np.zeros(5)
        target_vector[action] = target
        model.fit(state, np.array([target_vector]), epochs=1, verbose=0)
        state = new_state
        if done:
            print(reward(state))
        
    all_rewards.append(max(rewards))
    if i % 100 == 0:
        all_RX.append(RX)
        all_RY.append(RY)
        

EPOCH:  0
Crashed
-1.9576421020622474
EPOCH:  1
Crashed
-1.9613765567863861
EPOCH:  2
Crashed
-1.9612042657546214
EPOCH:  3
Crashed
-1.9606430090263989
EPOCH:  4
Crashed
-1.961285878048273
EPOCH:  5
Crashed
-1.9591385895673525
EPOCH:  6
Crashed
-1.9584220078290424
EPOCH:  7
Crashed
-1.9604480918449447
EPOCH:  8
Crashed
-1.961359199774587
EPOCH:  9
Crashed
-1.961692644636267
EPOCH:  10
Crashed
-1.960817078784599
EPOCH:  11
Crashed
-1.9609321070028676
EPOCH:  12
Crashed
-1.9612203115973799
EPOCH:  13
Crashed
-1.9610937986195645
EPOCH:  14
Crashed
-1.9611093290893347
EPOCH:  15
Crashed
-1.958389936145275
EPOCH:  16
Crashed
-1.9599701787632227
EPOCH:  17
Crashed
-1.9609689420258576
EPOCH:  18
Crashed
-1.9612423538863286
EPOCH:  19
Crashed
-1.9603903702159613
EPOCH:  20
Crashed
-1.9601780367701405
EPOCH:  21
Crashed
-1.9611777552383736
EPOCH:  22
Crashed
-1.9612259952196298
EPOCH:  23
Crashed
-1.961148479759094
EPOCH:  24
Crashed
-1.959165190299767
EPOCH:  25
Crashed
-1.960501146323413
EPOC

Crashed
-1.9601777637250937
EPOCH:  211
Crashed
-1.9611376398533549
EPOCH:  212
Crashed
-1.9611973095555468
EPOCH:  213
Crashed
-1.9609300540703647
EPOCH:  214
Crashed
-1.9588836697165521
EPOCH:  215
Crashed
-1.961029784434266
EPOCH:  216
Crashed
-1.9613577267035345
EPOCH:  217
Crashed
-1.960762676622143
EPOCH:  218
Crashed
-1.9608080866280462
EPOCH:  219
Crashed
-1.9610822891466106
EPOCH:  220
Crashed
-1.9609111112411013
EPOCH:  221
Crashed
-1.9613766820816974
EPOCH:  222
Crashed
-1.9613884564162551
EPOCH:  223
Crashed
-1.9611626793561132
EPOCH:  224
Crashed
-1.961040524901168
EPOCH:  225
Crashed
-1.960775175259834
EPOCH:  226
Crashed
-1.9611375890605112
EPOCH:  227
Crashed
-1.9592008154305076
EPOCH:  228
Crashed
-1.961409492772489
EPOCH:  229
Crashed
-1.9614563517049002
EPOCH:  230
Crashed
-1.9614462168038715
EPOCH:  231
Crashed
-1.9606495573586051
EPOCH:  232
Crashed
-1.9613873866180915
EPOCH:  233
Crashed
-1.961412152172423
EPOCH:  234
Crashed
-1.9613770643369333
EPOCH:  235
Crashe

Crashed
-1.9605207578706318
EPOCH:  417
Crashed
-1.959899839158017
EPOCH:  418
Crashed
-1.9609768279126591
EPOCH:  419
Crashed
-1.9608689106315276
EPOCH:  420
Crashed
-1.9610426228705593
EPOCH:  421
Crashed
-1.9603688980646874
EPOCH:  422
Crashed
-1.960557545224468
EPOCH:  423
Crashed
-1.9613044934693935
EPOCH:  424
Crashed
-1.960514665942163
EPOCH:  425
Crashed
-1.9601847951840383
EPOCH:  426
Crashed
-1.961343736525194
EPOCH:  427
Crashed
-1.9600713118652662
EPOCH:  428
Crashed
-1.9606715205266352
EPOCH:  429
Crashed
-1.9610680018336375
EPOCH:  430
Crashed
-1.9602373904908386
EPOCH:  431
Crashed
-1.960773046573763
EPOCH:  432
Crashed
-1.9606782452257763
EPOCH:  433
Crashed
-1.9614299701637186
EPOCH:  434
Crashed
-1.9611519475892827
EPOCH:  435
Crashed
-1.960901286856805
EPOCH:  436
Crashed
-1.9602265454756649
EPOCH:  437
Crashed
-1.9611200472874297
EPOCH:  438
Crashed
-1.9605739007516392
EPOCH:  439
Crashed
-1.9613141984009836
EPOCH:  440
Crashed
-1.9599925405055232
EPOCH:  441
Crashe

Crashed
-1.9606638692975635
EPOCH:  625
Crashed
-1.9608342247247157
EPOCH:  626
Crashed
-1.9608994333386982
EPOCH:  627
Crashed
-1.9615536652378989
EPOCH:  628
Crashed
-1.960880688184415
EPOCH:  629
Crashed
-1.9602496636437194
EPOCH:  630
Crashed
-1.9615263173455648
EPOCH:  631
Crashed
-1.9615033729919613
EPOCH:  632
Crashed
-1.9611941763229368
EPOCH:  633
Crashed
-1.9611939643289358
EPOCH:  634
Crashed
-1.9611939643289358
EPOCH:  635
Crashed
-1.9611939643289358
EPOCH:  636
Crashed
-1.9611939643289358
EPOCH:  637
Crashed
-1.9606938385038297
EPOCH:  638
Crashed
-1.9615536652378989
EPOCH:  639
Crashed
-1.9608804049785187
EPOCH:  640
Crashed
-1.9608787243293098
EPOCH:  641
Crashed
-1.9610716655199267
EPOCH:  642
Crashed
-1.960985882204125
EPOCH:  643
Crashed
-1.9608049848349105
EPOCH:  644
Crashed
-1.9610754338941403
EPOCH:  645
Crashed
-1.9611444808182896
EPOCH:  646
Crashed
-1.9611444808182896
EPOCH:  647
Crashed
-1.9614137518876424
EPOCH:  648
Crashed
-1.9613726567451113
EPOCH:  649
Cr

Crashed
-1.9614218307021145
EPOCH:  832
Crashed
-1.961415849553369
EPOCH:  833
Crashed
-1.9614285265277915
EPOCH:  834
Crashed
-1.961314631927072
EPOCH:  835
Crashed
-1.9614393111875894
EPOCH:  836
Crashed
-1.9614558948131373
EPOCH:  837
Crashed
-1.9615832884281084
EPOCH:  838
Crashed
-1.9614176590980872
EPOCH:  839
Crashed
-1.9614556113648045
EPOCH:  840
Crashed
-1.9614309711210036
EPOCH:  841
Crashed
-1.9613410553900608
EPOCH:  842
Crashed
-1.9614339335459356
EPOCH:  843
Crashed
-1.9614398643160655
EPOCH:  844
Crashed
-1.96138811191241
EPOCH:  845
Crashed
-1.9616580071608958
EPOCH:  846
Crashed
-1.961443579453802
EPOCH:  847
Crashed
-1.961420987329989
EPOCH:  848
Crashed
-1.961418540554124
EPOCH:  849
Crashed
-1.9613809171039822
EPOCH:  850
Crashed
-1.9613760506235909
EPOCH:  851
Crashed
-1.9611952785735072
EPOCH:  852
Crashed
-1.9612935355770584
EPOCH:  853
Crashed
-1.9614556584393197
EPOCH:  854
Crashed
-1.9614254275226242
EPOCH:  855
Crashed
-1.9613172301480561
EPOCH:  856
Crashed

Out Of Fuel
1.7478260844970537
EPOCH:  1036
Out Of Fuel
1.7478260844970537
EPOCH:  1037
Out Of Fuel
1.7478260844970537
EPOCH:  1038
Out Of Fuel
1.7478260844970537
EPOCH:  1039
Out Of Fuel
1.7478260844970537
EPOCH:  1040
Out Of Fuel
1.7478260844970537
EPOCH:  1041
Out Of Fuel
1.7478260844970537
EPOCH:  1042
Out Of Fuel
1.7478260844970537
EPOCH:  1043
Out Of Fuel
1.7478260844970537
EPOCH:  1044
Out Of Fuel
1.7478260844970537
EPOCH:  1045


In [None]:
#NEED TO PLOT SOMETHING TO SHOW PROGRESS
all_RX = np.array(all_RX)
all_RY = np.array(all_RY)


rad = 6378000
surface = plt.Circle((0, 0), rad, color='b', fill=False)
#karman = rad + 100000

ax = plt.gca()
for i in range(len(all_RX)):
    plt.plot(all_RX[i], all_RY[i], label=f"{i}")
    
ax.add_patch(surface)
plt.xlim(min([min(i) for i in all_RX]), max([max(i) for i in all_RX]))
plt.ylim(min([min(i) for i in all_RY]), max([max(i) for i in all_RY]))
plt.title("X vs. Y Position")
plt.axes().set_aspect('equal','datalim')
plt.legend()
plt.show()

In [None]:
plt.plot(all_rewards)