Skip to content

Commit 130d1e5

Browse files
adjust optimizer, increase model size, add LayerNormalization somewhere
1 parent df36145 commit 130d1e5

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

NN/Utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def __init__(self, mlp=None, **kwargs):
237237
if mlp is None: mlp = lambda x: x
238238
self._mlp = mlp
239239
self._norm = L.LayerNormalization()
240+
self._norm2 = L.LayerNormalization()
240241
return
241242

242243
def build(self, input_shapes):
@@ -253,14 +254,15 @@ def call(self, x):
253254
xhat = self._lastDense(xhat)
254255
x0 = x[0]
255256
x = tf.concat([x0, xhat], axis=-1)
256-
return self._combiner(x)
257+
res = self._combiner(x)
258+
return self._norm2(res)
257259
####################################
258260
# Hacky way to provide same optimizer for all models
259261
def createOptimizer(config=None):
260262
if config is None:
261263
config = {
262264
'learning_rate': 1e-4,
263-
'weight_decay': 1e-1,
265+
'weight_decay': 1e-4,
264266
'exclude_from_weight_decay': [
265267
'batch_normalization', 'bias',
266268
'CEL_', # exclude CCoordsEncodingLayer from weight decay

NN/networks.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,10 @@ def Face2StepModel(pointsN, eyeSize, latentSize, embeddingsSize):
6464
for i, EFeat in enumerate(encodedEFList):
6565
combined = CFusingBlock(name='F2S/ResMul-%d' % i)([
6666
combined,
67-
sMLP(sizes=[latentSize] * 1, activation='relu', name='F2S/MLP-%d' % i)(
68-
L.Concatenate(-1)([combined, encodedP, EFeat, embeddings])
67+
sMLP(sizes=[latentSize] * 3, activation='relu', name='F2S/MLP-%d' % i)(
68+
L.LayerNormalization()(
69+
L.Concatenate(-1)([combined, encodedP, EFeat, embeddings])
70+
)
6971
)
7072
])
7173
# save intermediate output
@@ -74,6 +76,7 @@ def Face2StepModel(pointsN, eyeSize, latentSize, embeddingsSize):
7476
continue
7577

7678
combined = L.Dense(latentSize, name='F2S/Combine')(combined)
79+
combined = L.LayerNormalization()(combined)
7780
# combined = CQuantizeLayer()(combined)
7881
return tf.keras.Model(
7982
inputs={
@@ -92,7 +95,7 @@ def Step2LatentModel(latentSize, embeddingsSize):
9295
latents = L.Input((None, latentSize))
9396
embeddingsInput = L.Input((None, embeddingsSize))
9497
T = L.Input((None, 1))
95-
embeddings = embeddingsInput[..., :1] * 0.0
98+
embeddings = embeddingsInput
9699

97100
stepsData = latents
98101
intermediate = {}
@@ -105,11 +108,11 @@ def Step2LatentModel(latentSize, embeddingsSize):
105108
intermediate['S2L/enc0'] = temporal
106109
# # # # # # # # # # # # # # # # # # # # # # # # # # # # #
107110
for blockId in range(3):
108-
temp = L.Concatenate(-1)([temporal, encodedT])
109-
for _ in range(1):
111+
temp = L.Concatenate(-1)([temporal, encodedT, embeddings])
112+
for _ in range(3):
110113
temp = L.LSTM(latentSize, return_sequences=True)(temp)
111-
temp = sMLP(sizes=[latentSize] * 1, activation='relu')(
112-
L.Concatenate(-1)([temporal, temp])
114+
temp = sMLP(sizes=[latentSize] * 3, activation='relu')(
115+
L.Concatenate(-1)([temporal, temp, encodedT, embeddings])
113116
)
114117
temporal = CFusingBlock()([temporal, temp])
115118
intermediate['S2L/ResLSTM-%d' % blockId] = temporal
@@ -165,6 +168,7 @@ def Face2LatentModel(
165168
# add diffusion features to the embeddings
166169
emb = L.Concatenate(-1)([emb, encodedDT, encodedDP])
167170

171+
emb = L.LayerNormalization()(emb)
168172
Face2Step = Face2StepModel(pointsN, eyeSize, latentSize, embeddingsSize=emb.shape[-1])
169173
Step2Latent = Step2LatentModel(latentSize, embeddingsSize=emb.shape[-1])
170174

@@ -196,7 +200,9 @@ def Face2LatentModel(
196200
}
197201
res['result'] = IntermediatePredictor(
198202
shift=0.0 if diffusion else 0.5 # shift points to the center, if not using diffusion
199-
)(res['latent'])
203+
)(
204+
L.Concatenate(-1)([res['latent'], T, emb])
205+
)
200206

201207
if diffusion:
202208
inputs['diffusionT'] = diffusionT

0 commit comments

Comments (0)