File tree Expand file tree Collapse file tree 2 files changed +12
-13
lines changed Expand file tree Collapse file tree 2 files changed +12
-13
lines changed Original file line number Diff line number Diff line change 2
2
import numpy as np
3
3
4
4
def train (model , memory , params ):
5
- if len (memory ) < params ['batchSize' ]: return np .Inf
6
-
7
5
modelClone = tf .keras .models .clone_model (model )
8
6
modelClone .set_weights (model .get_weights ()) # use clone model for stability
9
7
Original file line number Diff line number Diff line change @@ -87,17 +87,18 @@ def testModel(EXPLORE_RATE):
87
87
)
88
88
print ('Avg. train loss: %.4f' % trainLoss )
89
89
90
- trainLoss = fit_stage .train (
91
- model , doomMemory ,
92
- {
93
- 'gamma' : GAMMA ,
94
- 'batchSize' : BATCH_SIZE ,
95
- 'steps' : BOOTSTRAPPED_STEPS ,
96
- 'episodes' : params ['train doom episodes' ](epoch ),
97
- 'alpha' : params .get ('doom alpha' , lambda _ : alpha )(epoch )
98
- }
99
- )
100
- print ('Avg. train doom loss: %.4f' % trainLoss )
90
+ if params ['batchSize' ] < len (doomMemory ):
91
+ trainLoss = fit_stage .train (
92
+ model , doomMemory ,
93
+ {
94
+ 'gamma' : GAMMA ,
95
+ 'batchSize' : BATCH_SIZE ,
96
+ 'steps' : BOOTSTRAPPED_STEPS ,
97
+ 'episodes' : params ['train doom episodes' ](epoch ),
98
+ 'alpha' : params .get ('doom alpha' , lambda _ : alpha )(epoch )
99
+ }
100
+ )
101
+ print ('Avg. train doom loss: %.4f' % trainLoss )
101
102
##################
102
103
# test
103
104
print ('Testing...' )
You can’t perform that action at this time.
0 commit comments