Skip to content

Commit df36145

Browse files
change NN
1 parent 68802d8 commit df36145

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

NN/networks.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,9 @@ def Face2StepModel(pointsN, eyeSize, latentSize, embeddingsSize):
9090

9191
def Step2LatentModel(latentSize, embeddingsSize):
9292
latents = L.Input((None, latentSize))
93-
embeddings = L.Input((None, embeddingsSize))
93+
embeddingsInput = L.Input((None, embeddingsSize))
9494
T = L.Input((None, 1))
95+
embeddings = embeddingsInput[..., :1] * 0.0
9596

9697
stepsData = latents
9798
intermediate = {}
@@ -115,14 +116,14 @@ def Step2LatentModel(latentSize, embeddingsSize):
115116
continue
116117
# # # # # # # # # # # # # # # # # # # # # # # # # # # # #
117118
latent = sMLP(sizes=[latentSize] * 1, activation='relu')(
118-
L.Concatenate(-1)([stepsData, temporal, encodedT, encodedT])
119+
L.Concatenate(-1)([stepsData, temporal, encodedT, embeddings])
119120
)
120121
latent = CFusingBlock()([stepsData, latent])
121122
return tf.keras.Model(
122123
inputs={
123124
'latent': latents,
124125
'time': T,
125-
'embeddings': embeddings,
126+
'embeddings': embeddingsInput,
126127
},
127128
outputs={
128129
'latent': latent,
@@ -195,9 +196,7 @@ def Face2LatentModel(
195196
}
196197
res['result'] = IntermediatePredictor(
197198
shift=0.0 if diffusion else 0.5 # shift points to the center, if not using diffusion
198-
)(
199-
L.Concatenate(-1)([res['latent'], emb])
200-
)
199+
)(res['latent'])
201200

202201
if diffusion:
203202
inputs['diffusionT'] = diffusionT

0 commit comments

Comments
 (0)