Skip to content

Commit 3787ab3

Browse files
misc
1 parent be83836 commit 3787ab3

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

Core/CModelTrainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
class CModelTrainer(CModelWrapper):
77
def __init__(self, timesteps, model='simple', **kwargs):
88
super().__init__(timesteps, model=model, **kwargs)
9-
self._compile()
9+
self.compile()
1010
# add signatures to help tensorflow optimize the graph
1111
specification = self._modelRaw['inputs specification']
1212
self._trainStep = tf.function(
@@ -27,7 +27,7 @@ def __init__(self, timesteps, model='simple', **kwargs):
2727
)
2828
return
2929

30-
def _compile(self):
30+
def compile(self):
3131
self._model.compile(optimizer=NNU.createOptimizer())
3232
return
3333

scripts/train.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,12 @@ def evaluate(onlyImproved=False):
7777
totalLoss = totalDist = 0.0
7878
for i, dataset in enumerate(datasets):
7979
loss, dist, T = _eval(dataset, model, os.path.join(folder, 'pred-%d.png' % i), args)
80-
if not onlyImproved:
80+
isImproved = loss < losses[i]
81+
if (not onlyImproved) or isImproved:
8182
print('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f' % (
8283
i + 1, len(datasets), T, loss, losses[i], dist
8384
))
84-
if loss < losses[i]:
85+
if isImproved:
8586
print('Test %d / %d | Improved %.5f => %.5f' % (i + 1, len(datasets), losses[i], loss))
8687
model.save(folder, postfix='best-%d' % i) # save the model separately
8788
losses[i] = loss
@@ -197,6 +198,7 @@ def averageModels(folder, model, noiseStd=0.0):
197198
# average the weights
198199
TV = [(x / N) + np.random.normal(0.0, noiseStd, x.shape) for x in TV]
199200
model._model.set_weights(TV)
201+
model.compile() # recompile the model with the new weights
200202
return
201203

202204
def main(args):

0 commit comments

Comments
 (0)