Skip to content

Commit 6f3979c

Browse files
fix misc
1 parent e9550bd commit 6f3979c

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

fit_stage.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import numpy as np
33

44
def train(model, memory, params):
5-
if len(memory) < params['batchSize']: return np.Inf
6-
75
modelClone = tf.keras.models.clone_model(model)
86
modelClone.set_weights(model.get_weights()) # use clone model for stability
97

learn_environment.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,18 @@ def testModel(EXPLORE_RATE):
8787
)
8888
print('Avg. train loss: %.4f' % trainLoss)
8989

90-
trainLoss = fit_stage.train(
91-
model, doomMemory,
92-
{
93-
'gamma': GAMMA,
94-
'batchSize': BATCH_SIZE,
95-
'steps': BOOTSTRAPPED_STEPS,
96-
'episodes': params['train doom episodes'](epoch),
97-
'alpha': params.get('doom alpha', lambda _: alpha)(epoch)
98-
}
99-
)
100-
print('Avg. train doom loss: %.4f' % trainLoss)
90+
if params['batchSize'] < len(doomMemory):
91+
trainLoss = fit_stage.train(
92+
model, doomMemory,
93+
{
94+
'gamma': GAMMA,
95+
'batchSize': BATCH_SIZE,
96+
'steps': BOOTSTRAPPED_STEPS,
97+
'episodes': params['train doom episodes'](epoch),
98+
'alpha': params.get('doom alpha', lambda _: alpha)(epoch)
99+
}
100+
)
101+
print('Avg. train doom loss: %.4f' % trainLoss)
101102
##################
102103
# test
103104
print('Testing...')

0 commit comments

Comments
 (0)