
Commit 039902b

more advanced loss function for the model
1 parent 130d1e5 commit 039902b

File tree

1 file changed (+38, -12 lines)

Core/CModelTrainer.py

Lines changed: 38 additions & 12 deletions
@@ -40,26 +40,52 @@ def _pointLoss(self, ytrue, ypred):
     tf.assert_equal(tf.shape(loss), tf.shape(ytrue))
     return tf.reduce_mean(loss, axis=-1)
 
-  def _trainStep(self, Data):
-    print('Instantiate _trainStep')
-    ###############
-    x, (y, ) = Data
-    y = y[..., 0, :]
-    losses = {}
-    with tf.GradientTape() as tape:
-      data = x['augmented']
+  def _trainOn(self, data, y_list):
+    def calculate_loss(predictions):
+      # select the smallest loss from the list of suggested points
+      losses = []
+      for y in y_list:
+        loss = self._pointLoss(y, predictions)[..., None]
+        losses.append(loss)
+        continue
+      losses = tf.concat(losses, axis=-1)
+      shp = tf.shape(y_list[0])
+      tf.assert_equal(tf.shape(losses), tf.concat([shp[:-1], [len(y_list)]], axis=0))
+      losses = tf.reduce_min(losses, axis=-1)
+      tf.assert_equal(tf.shape(losses), shp[:-1])
+      return tf.reduce_mean(losses)
+
     data = self._replaceByEmbeddings(data)
     predictions = self._model(data, training=True)
     intermediate = predictions['intermediate']
-      losses['final'] = tf.reduce_mean(self._pointLoss(y, predictions['result']))
+    finalPredictions = predictions['result']
+    losses = {}
+    losses['final'] = calculate_loss(finalPredictions)
     for name, encoder in self._intermediateEncoders.items():
       latent = intermediate[name]
       pts = encoder(latent, training=True)
-        loss = self._pointLoss(y, pts)
+      loss = calculate_loss(pts)
       losses['loss-%s' % name] = tf.reduce_mean(loss)
       continue
-      loss = sum(losses.values())
-      losses['loss'] = loss
+    return losses, tf.stop_gradient(finalPredictions)
+
+  def _trainStep(self, Data):
+    print('Instantiate _trainStep')
+    ###############
+    x, (y, ) = Data
+    y = y[..., 0, :]
+    losses = {}
+    with tf.GradientTape() as tape:
+      lossesClean, y_clean = self._trainOn(x['clean'], [y])
+      # ensure that the augmentations do not affect the predictions
+      lossesAugmented, _ = self._trainOn(x['augmented'], [y, y_clean])
+      assert lossesClean.keys() == lossesAugmented.keys(), 'Losses keys mismatch'
+      # combine losses
+      losses = {k: lossesClean[k] + lossesAugmented[k] for k in lossesClean.keys()}
+      # calculate total and final losses
+      losses['total-clean'] = sum(lossesClean.values())
+      losses['total-augmented'] = sum(lossesAugmented.values())
+      losses['loss'] = loss = sum([losses['total-clean'], losses['total-augmented']])
 
     self._optimizer.minimize(loss, tape.watched_variables(), tape=tape)
     ###############
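For reference, the core idea of the new calculate_loss helper is: score the predictions against every candidate target in y_list with the per-point loss, keep only the smallest loss per point, then average. Below is a minimal standalone sketch of that idea, not the repository's code: point_loss and min_candidate_loss are hypothetical names, a plain squared error stands in for _pointLoss, and TensorFlow 2.x is assumed.

import tensorflow as tf

def point_loss(y_true, y_pred):
  # per-point squared error, reduced over the coordinate axis -> (B, N)
  return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)

def min_candidate_loss(y_candidates, y_pred):
  # stack per-candidate losses along a new last axis -> (B, N, K)
  losses = tf.stack([point_loss(y, y_pred) for y in y_candidates], axis=-1)
  # keep only the smallest loss per point, then average over batch and points
  return tf.reduce_mean(tf.reduce_min(losses, axis=-1))

# usage with dummy data: batch of 4, 16 points, 2D coordinates, 2 candidate targets
y_pred = tf.random.uniform((4, 16, 2))
y_a = tf.random.uniform((4, 16, 2))
y_b = tf.random.uniform((4, 16, 2))
print(min_candidate_loss([y_a, y_b], y_pred).numpy())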

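The reworked _trainStep reads as a consistency-training scheme: the clean input is trained against the ground truth, while the augmented input is trained against both the ground truth and the stop-gradient clean predictions, so augmentations should not change what the model predicts. The following is a hedged sketch of that scheme with assumed names (model, optimizer, batch layout); it reuses min_candidate_loss from the sketch above and is not the repository's actual API.

import tensorflow as tf

def train_step(model, optimizer, batch):
  # assumed batch layout: x holds 'clean' and 'augmented' views of the same sample
  x, y = batch
  with tf.GradientTape() as tape:
    pred_clean = model(x['clean'], training=True)
    loss_clean = min_candidate_loss([y], pred_clean)

    pred_aug = model(x['augmented'], training=True)
    # the clean predictions act as an extra candidate target for the augmented pass,
    # so each augmented point is only penalized against whichever target it matches best
    loss_aug = min_candidate_loss([y, tf.stop_gradient(pred_clean)], pred_aug)

    loss = loss_clean + loss_aug
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return {'total-clean': loss_clean, 'total-augmented': loss_aug, 'loss': loss}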