Skip to content

Commit 4f13e0a

Browse files
refactor: Adjust image augmentation parameters in training script
1 parent 28745d4 commit 4f13e0a

File tree

5 files changed

+60
-24
lines changed

5 files changed

+60
-24
lines changed

Core/CModelTrainer.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ def __init__(self, timesteps, model='simple', **kwargs):
2828
return
2929

3030
def compile(self):
31-
self._model.compile(optimizer=NNU.createOptimizer())
31+
self._optimizer = NNU.createOptimizer()
3232
return
3333

3434
def _pointLoss(self, ytrue, ypred):
@@ -46,22 +46,22 @@ def _trainStep(self, Data):
4646
x, (y, ) = Data
4747
y = y[..., 0, :]
4848
losses = {}
49-
parts = list(self._embeddings.values()) + [self._model]
50-
TV = sum([p.trainable_variables for p in parts], [])
51-
with tf.GradientTape(watch_accessed_variables=False) as tape:
52-
tape.watch(TV)
49+
with tf.GradientTape() as tape:
5350
data = x['augmented']
5451
data = self._replaceByEmbeddings(data)
5552
predictions = self._model(data, training=True)
56-
predictions = dict(**predictions['intermediate'], final=predictions['result'])
57-
for name, pts in predictions.items():
53+
intermediate = predictions['intermediate']
54+
losses['final'] = tf.reduce_mean(self._pointLoss(y, predictions['result']))
55+
for name, encoder in self._intermediateEncoders.items():
56+
latent = intermediate[name]
57+
pts = encoder(latent, training=True)
5858
loss = self._pointLoss(y, pts)
5959
losses['loss-%s' % name] = tf.reduce_mean(loss)
6060
continue
6161
loss = sum(losses.values())
6262
losses['loss'] = loss
6363

64-
self._model.optimizer.minimize(loss, TV, tape=tape)
64+
self._optimizer.minimize(loss, tape.watched_variables(), tape=tape)
6565
###############
6666
return losses
6767

Core/CModelWrapper.py

Lines changed: 33 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
from tensorflow.keras import layers as L
66

77
class CModelWrapper:
8-
def __init__(self, timesteps, model='simple', user=None, stats=None, **kwargs):
8+
def __init__(self, timesteps, model='simple', user=None, stats=None, use_encoders=True, **kwargs):
99
if user is None:
1010
user = {
1111
'userId': 0,
@@ -37,6 +37,15 @@ def __init__(self, timesteps, model='simple', user=None, stats=None, **kwargs):
3737
'placeId': L.Embedding(len(stats['placeId']), embeddings['size']),
3838
'screenId': L.Embedding(len(stats['screenId']), embeddings['size']),
3939
}
40+
self._intermediateEncoders = {}
41+
if use_encoders:
42+
shapes = self._modelRaw['intermediate shapes']
43+
for name, shape in shapes.items():
44+
enc = networks.IntermediatePredictor(name='%s-encoder' % name)
45+
enc.build(shape)
46+
self._intermediateEncoders[name] = enc
47+
continue
48+
4049
if 'weights' in kwargs:
4150
self.load(**kwargs['weights'])
4251
return
@@ -80,6 +89,15 @@ def save(self, folder=None, postfix=''):
8089
embeddings[nm] = weights
8190
continue
8291
np.savez_compressed(path.replace('.h5', '-embeddings.npz'), **embeddings)
92+
# save intermediate encoders
93+
if self._intermediateEncoders:
94+
encoders = {}
95+
for nm, encoder in self._intermediateEncoders.items():
96+
# save each variable separately
97+
for ww in encoder.trainable_variables:
98+
encoders['%s-%s' % (nm, ww.name)] = ww.numpy()
99+
continue
100+
np.savez_compressed(path.replace('.h5', '-intermediate-encoders.npz'), **encoders)
83101
return
84102

85103
def load(self, folder=None, postfix='', embeddings=False):
@@ -92,6 +110,16 @@ def load(self, folder=None, postfix='', embeddings=False):
92110
if not emb.built: emb.build((None, w.shape[0]))
93111
emb.set_weights([w]) # replace embeddings
94112
continue
113+
114+
if self._intermediateEncoders:
115+
encodersName = path.replace('.h5', '-intermediate-encoders.npz')
116+
if os.path.isfile(encodersName):
117+
encoders = np.load(encodersName)
118+
for nm, encoder in self._intermediateEncoders.items():
119+
for ww in encoder.trainable_variables:
120+
w = encoders['%s-%s' % (nm, ww.name)]
121+
ww.assign(w)
122+
continue
95123
return
96124

97125
def lock(self, isLocked):
@@ -101,4 +129,7 @@ def lock(self, isLocked):
101129
@property
102130
def timesteps(self):
103131
return self._timesteps
104-
132+
133+
def trainable_variables(self):
134+
parts = list(self._embeddings.values()) + [self._model] + list(self._intermediateEncoders.values())
135+
return sum([p.trainable_variables for p in parts], [])

NN/Utils.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ def call(self, x):
2727
coefs = tf.pow(self._base, powers)
2828
return tf.reduce_sum(x * coefs, axis=-1)
2929
############################################
30-
SMLP_GLOBAL_DROPOUT = 0.01
30+
SMLP_GLOBAL_DROPOUT = 0.0
3131
class sMLP(tf.keras.layers.Layer):
3232
def __init__(self, sizes, activation='linear', dropout=None, **kwargs):
3333
super().__init__(**kwargs)
@@ -40,7 +40,11 @@ def __init__(self, sizes, activation='linear', dropout=None, **kwargs):
4040
continue
4141
self._F = tf.keras.Sequential(layers, name=self.name + '/F')
4242
return
43-
43+
44+
def build(self, input_shape):
45+
self._F.build(input_shape)
46+
return super().build(input_shape)
47+
4448
def call(self, x, **kwargs):
4549
return self._F(x, **kwargs)
4650
############################################

NN/networks.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ def build(self, input_shape):
2525
sizes=[128, 64, 32], activation='relu',
2626
name='%s/MLP' % self.name
2727
)
28+
self._mlp.build(input_shape)
2829
self._decodePoints = L.Dense(2, name='%s/DecodePoints' % self.name)
2930
return super().build(input_shape)
3031

@@ -168,8 +169,6 @@ def Face2LatentModel(
168169
**stepsData['intermediate'],
169170
**res['intermediate'],
170171
}
171-
# drop all intermediate outputs
172-
res['intermediate'] = {}
173172

174173
inputs = {
175174
'points': points,
@@ -181,13 +180,10 @@ def Face2LatentModel(
181180
'screenId': screenIdEmb,
182181
}
183182

184-
intermediate = res['intermediate']
185-
IP = lambda x: IntermediatePredictor()(x) # own IntermediatePredictor for each output
186-
res['intermediate'] = {k: IP(x) for k, x in intermediate.items()}
187-
res['result'] = IP(res['latent'])
188-
183+
res['result'] = IntermediatePredictor()(res['latent'])
189184
main = tf.keras.Model(inputs=inputs, outputs=res)
190185
return {
186+
'intermediate shapes': {k: v.shape for k, v in res['intermediate'].items()},
191187
'main': main,
192188
'Face2Step': Face2Step,
193189
'Step2Latent': Step2Latent,

scripts/train.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -182,22 +182,24 @@ def _trainer_from(args):
182182
raise Exception('Unknown trainer: %s' % (args.trainer, ))
183183

184184
def averageModels(folder, model, noiseStd=0.0):
185-
TV = [np.zeros_like(x) for x in model._model.get_weights()]
185+
TV = [np.zeros_like(x) for x in model.trainable_variables()]
186186
N = 0
187187
for nm in glob.glob(os.path.join(folder, '*.h5')):
188188
if not('best' in nm): continue # only the best models
189189
model.load(nm, embeddings=True)
190190
# add the weights to the total
191-
weights = model._model.get_weights()
191+
weights = model.trainable_variables()
192192
for i in range(len(TV)):
193-
TV[i] += weights[i]
193+
TV[i] += weights[i].numpy()
194194
continue
195195
N += 1
196196
continue
197197

198198
# average the weights
199199
TV = [(x / N) + np.random.normal(0.0, noiseStd, x.shape) for x in TV]
200-
model._model.set_weights(TV)
200+
for v, new in zip(model.trainable_variables(), TV):
201+
v.assign(new)
202+
continue
201203
model.compile() # recompile the model with the new weights
202204
return
203205

@@ -230,7 +232,7 @@ def main(args):
230232
),
231233
)
232234
)
233-
model = dict(timesteps=timesteps, stats=stats)
235+
model = dict(timesteps=timesteps, stats=stats, use_encoders=args.with_enconders)
234236
if args.model is not None:
235237
model['weights'] = dict(folder=folder, postfix=args.model, embeddings=args.embeddings)
236238
if args.modelId is not None:
@@ -327,6 +329,9 @@ def performRandomSearch(epoch=0):
327329
'--restarts', type=int, default=1,
328330
help='Number of times to restart the model reinitializing the weights'
329331
)
332+
parser.add_argument(
333+
'--with-enconders', default=False, action='store_true',
334+
)
330335

331336
main(parser.parse_args())
332337
pass

0 commit comments

Comments
 (0)