@@ -40,26 +40,52 @@ def _pointLoss(self, ytrue, ypred):
     tf.assert_equal(tf.shape(loss), tf.shape(ytrue))
     return tf.reduce_mean(loss, axis=-1)

-  def _trainStep(self, Data):
-    print('Instantiate _trainStep')
-    ###############
-    x, (y, ) = Data
-    y = y[..., 0, :]
-    losses = {}
-    with tf.GradientTape() as tape:
-      data = x['augmented']
+  def _trainOn(self, data, y_list):
+    def calculate_loss(predictions):
+      # select the smallest loss from the list of suggested points
+      losses = []
+      for y in y_list:
+        loss = self._pointLoss(y, predictions)[..., None]
+        losses.append(loss)
+        continue
+      losses = tf.concat(losses, axis=-1)
+      shp = tf.shape(y_list[0])
+      tf.assert_equal(tf.shape(losses), tf.concat([shp[:-1], [len(y_list)]], axis=0))
+      losses = tf.reduce_min(losses, axis=-1)
+      tf.assert_equal(tf.shape(losses), shp[:-1])
+      return tf.reduce_mean(losses)
+
     data = self._replaceByEmbeddings(data)
     predictions = self._model(data, training=True)
     intermediate = predictions['intermediate']
-    losses['final'] = tf.reduce_mean(self._pointLoss(y, predictions['result']))
+    finalPredictions = predictions['result']
+    losses = {}
+    losses['final'] = calculate_loss(finalPredictions)
     for name, encoder in self._intermediateEncoders.items():
       latent = intermediate[name]
       pts = encoder(latent, training=True)
-      loss = self._pointLoss(y, pts)
+      loss = calculate_loss(pts)
       losses['loss-%s' % name] = tf.reduce_mean(loss)
       continue
-    loss = sum(losses.values())
-    losses['loss'] = loss
+    return losses, tf.stop_gradient(finalPredictions)
+
+  def _trainStep(self, Data):
+    print('Instantiate _trainStep')
+    ###############
+    x, (y, ) = Data
+    y = y[..., 0, :]
+    losses = {}
+    with tf.GradientTape() as tape:
+      lossesClean, y_clean = self._trainOn(x['clean'], [y])
+      # ensure that the augmentations do not affect the predictions
+      lossesAugmented, _ = self._trainOn(x['augmented'], [y, y_clean])
+      assert lossesClean.keys() == lossesAugmented.keys(), 'Losses keys mismatch'
+      # combine the per-head losses from both passes
+      losses = {k: lossesClean[k] + lossesAugmented[k] for k in lossesClean.keys()}
+      # calculate the per-pass totals and the final loss
+      losses['total-clean'] = sum(lossesClean.values())
+      losses['total-augmented'] = sum(lossesAugmented.values())
+      losses['loss'] = loss = sum([losses['total-clean'], losses['total-augmented']])

     self._optimizer.minimize(loss, tape.watched_variables(), tape=tape)
     ###############
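
For reference, below is a minimal standalone sketch of the "best candidate target" loss that _trainOn introduces above: each candidate in y_list is scored with the point loss and only the smallest loss per point is kept, which is why the augmented pass in _trainStep may match either the ground truth or the stop-gradient clean predictions, whichever is closer. The point_loss helper, the tensor shapes, and the squared-error distance here are illustrative assumptions, not the project's actual _pointLoss.

import tensorflow as tf

def point_loss(ytrue, ypred):
  # hypothetical stand-in for self._pointLoss: squared error per coordinate,
  # averaged over the coordinate axis -> shape (batch, points)
  return tf.reduce_mean(tf.square(ytrue - ypred), axis=-1)

def min_over_targets_loss(y_list, predictions):
  # stack per-target losses along a new last axis -> (batch, points, len(y_list))
  losses = tf.stack([point_loss(y, predictions) for y in y_list], axis=-1)
  # keep only the smallest loss per point, then average to a scalar
  return tf.reduce_mean(tf.reduce_min(losses, axis=-1))

# toy usage: ground truth plus the clean-pass predictions as a second candidate
preds = tf.random.uniform((4, 16, 2))
y_true = tf.random.uniform((4, 16, 2))
y_clean = tf.random.uniform((4, 16, 2))
print(min_over_targets_loss([y_true, y_clean], preds))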